import torch
import torch.nn.functional as F


glu_fwd_codestring = """
template <typename T> T glu_fwd(T x, T y) {
    return float(y) / (1.0f + ::exp(-float(x)));
}
"""
glu_bwd_codestring = """
template <typename T> T glu_bwd(T x, T y, T g, T& dx, T& dy) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    dx = x_sigmoid * (1.0f - x_sigmoid) * float(g) * float(y);
    dy = x_sigmoid * float(g);
}
"""

glu_bwd_with_output_codestring = """
template <typename T> T glu_bwd_with_output(T x, T y, T g, T& dx, T& dy, T& z) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    dx = x_sigmoid * (1.0f - x_sigmoid) * float(g) * float(y);
    dy = x_sigmoid * float(g);
    z = x_sigmoid * float(y);
}
"""

glu_fwd = torch.cuda.jiterator._create_jit_fn(glu_fwd_codestring)
glu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(glu_bwd_codestring, num_outputs=2)
glu_bwd_with_output = torch.cuda.jiterator._create_multi_output_jit_fn(glu_bwd_with_output_codestring, num_outputs=3)
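
# A minimal sanity check (illustrative sketch, not part of the original code):
# the jiterator kernel should match the eager reference y * torch.sigmoid(x).
# Note that jiterator kernels only run on CUDA tensors.
def _check_glu_fwd():
    x = torch.randn(1024, device="cuda")
    y = torch.randn(1024, device="cuda")
    torch.testing.assert_close(glu_fwd(x, y), y * torch.sigmoid(x))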


class GLULinearFunction(torch.autograd.Function):
    r"""
    Gated Linear Unit (GLU) function followed by a linear transformation.

    .. math::
        \text{GLULinear}(x, y, W, b) = (sh(x) * y) W + b

    This simple wrap discards the intermediate results of GLU(x, y) to save memory.
    """

    @staticmethod
    def forward(ctx, x, y, weight, bias):
        z = glu_fwd(x, y)
        out = F.linear(z.to(weight.dtype), weight, bias)
        # We don't store z; it will be recomputed in the backward pass to save memory.
        ctx.save_for_backward(x, y, weight)
        ctx.linear_bias_is_none = bias is None
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        x, y, weight = ctx.saved_tensors
        # Flatten all leading batch dimensions so the matmuls below are 2D.
        dout = dout.reshape(-1, dout.shape[-1])
        # Gradient w.r.t. the GLU output z: dz = dout @ weight.
        dz = F.linear(dout, weight.t()).view_as(x)
        # Recompute z together with dx and dy in a single fused kernel.
        dx, dy, z = glu_bwd_with_output(x, y, dz)
        dlinear_weight = torch.einsum("bo,bi->oi", dout, z.reshape(-1, z.shape[-1]))
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        return dx, dy, dlinear_weight, dlinear_bias

glu_linear = GLULinearFunction.apply
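
# Usage sketch (assumes a CUDA device and float32 tensors; shapes are illustrative):
if __name__ == "__main__":
    batch, d_in, d_out = 4, 128, 256
    x = torch.randn(batch, d_in, device="cuda", requires_grad=True)
    y = torch.randn(batch, d_in, device="cuda", requires_grad=True)
    weight = torch.randn(d_out, d_in, device="cuda", requires_grad=True)
    bias = torch.randn(d_out, device="cuda", requires_grad=True)

    out = glu_linear(x, y, weight, bias)  # (batch, d_out)
    out.sum().backward()  # populates x.grad, y.grad, weight.grad, bias.grad

    # Cross-check against the eager composition that materializes GLU(x, y).
    ref = F.linear(torch.sigmoid(x) * y, weight, bias)
    torch.testing.assert_close(out, ref)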