import torch
import torch.nn.functional as F

# Elementwise GLU forward: z = sigmoid(x) * y, computed in float32 for
# numerical stability regardless of the input dtype T.
glu_fwd_codestring = """
template <typename T> T glu_fwd(T x, T y) {
    return float(y) / (1.0f + ::exp(-float(x)));
}
"""

# Elementwise GLU backward: given the upstream gradient g, writes the input
# gradients dx and dy through the output references.
glu_bwd_codestring = """
template <typename T> T glu_bwd(T x, T y, T g, T& dx, T& dy) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    dx = x_sigmoid * (1.0f - x_sigmoid) * float(g) * float(y);
    dy = x_sigmoid * float(g);
}
"""

# GLU backward that additionally recomputes the forward output
# z = sigmoid(x) * y, so the forward pass does not need to keep z around.
glu_bwd_with_output_codestring = """
template <typename T> T glu_bwd_with_output(T x, T y, T g, T& dx, T& dy, T& z) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    dx = x_sigmoid * (1.0f - x_sigmoid) * float(g) * float(y);
    dy = x_sigmoid * float(g);
    z = x_sigmoid * float(y);
}
"""

# Compile the elementwise kernels with the CUDA jiterator (CUDA tensors only).
glu_fwd = torch.cuda.jiterator._create_jit_fn(glu_fwd_codestring)
glu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(glu_bwd_codestring, num_outputs=2)
glu_bwd_with_output = torch.cuda.jiterator._create_multi_output_jit_fn(glu_bwd_with_output_codestring, num_outputs=3)
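

# Sanity sketch: the fused forward kernel should agree with the eager PyTorch
# reference sigmoid(x) * y. `_check_glu_fwd` is a hypothetical helper added
# only for illustration; it assumes a CUDA device, since jiterator kernels
# accept CUDA tensors only.
def _check_glu_fwd(shape=(4, 8), device="cuda"):
    x = torch.randn(shape, device=device)
    y = torch.randn(shape, device=device)
    ref = torch.sigmoid(x) * y  # eager reference for the fused forward
    torch.testing.assert_close(glu_fwd(x, y), ref)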


class GLULinearFunction(torch.autograd.Function):
    r"""
    Gated Linear Unit (GLU) function followed by a linear transformation.

    .. math::
        \text{GLULinear}(x, y, W, b) = (\sigma(x) * y) W + b

    where :math:`\sigma` is the sigmoid function.

    This simple wrapper discards the intermediate result of GLU(x, y) to save
    memory; the backward pass recomputes it from x and y.
    """

    @staticmethod
    def forward(ctx, x, y, weight, bias):
        z = glu_fwd(x, y)
        out = F.linear(z.to(weight.dtype), weight, bias)
        # Save only the inputs; z is recomputed in backward to save memory.
        ctx.save_for_backward(x, y, weight)
        ctx.linear_bias_is_none = bias is None
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        x, y, weight = ctx.saved_tensors
        dout = dout.reshape(-1, dout.shape[-1])
        # Gradient w.r.t. the GLU output: dz = dout @ weight.
        dz = F.linear(dout, weight.t()).view_as(x)
        # Recompute z = sigmoid(x) * y while computing dx and dy.
        dx, dy, z = glu_bwd_with_output(x, y, dz)
        # Weight gradient accumulated over the flattened batch: dW = dout^T @ z.
        dlinear_weight = torch.einsum("bo,bi->oi", dout, z.reshape(-1, z.shape[-1]))
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        return dx, dy, dlinear_weight, dlinear_bias


glu_linear = GLULinearFunction.apply
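

# Usage sketch: compare glu_linear against the unfused eager composition,
# including gradients. Illustrative only; assumes a CUDA device and uses
# loose tolerances because the jiterator kernels compute in float32.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(3, 8, device="cuda", requires_grad=True)
    y = torch.randn(3, 8, device="cuda", requires_grad=True)
    weight = torch.randn(5, 8, device="cuda", requires_grad=True)
    bias = torch.randn(5, device="cuda", requires_grad=True)

    # Forward agreement with the unfused composition.
    out = glu_linear(x, y, weight, bias)
    ref = F.linear(torch.sigmoid(x) * y, weight, bias)
    torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)

    # Backward agreement: differentiate both graphs with the same cotangent.
    grad_out = torch.randn_like(out)
    grads = torch.autograd.grad(out, (x, y, weight, bias), grad_out)
    grads_ref = torch.autograd.grad(ref, (x, y, weight, bias), grad_out)
    for g, g_ref in zip(grads, grads_ref):
        torch.testing.assert_close(g, g_ref, rtol=1e-4, atol=1e-4)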