diff --git a/examples/transformer.py b/examples/transformer.py index 27dbb784be..c9e03b5ece 100755 --- a/examples/transformer.py +++ b/examples/transformer.py @@ -29,7 +29,7 @@ def layernorm(x, sz, eps=1e-5): layer_mean = x.mean(axis=(1,)) y = (x - layer_mean.reshape(shape=[-1, 1])) layer_var = (y*y).mean(axis=(1,)) - ret = y.div(layer_var.add(eps).reshape(shape=[-1, 1])) + ret = y.div(layer_var.add(eps).reshape(shape=[-1, 1]).sqrt()) return ret.reshape(shape=in_shape) class TransformerBlock: @@ -52,7 +52,8 @@ class TransformerBlock: def __call__(self, x): # bs x T x embed_dim bs = x.shape[0] - inputs = x.reshape(shape=(-1, self.num_heads * self.head_size)) + embed_dim = self.num_heads * self.head_size + inputs = x.reshape(shape=(-1, embed_dim)) # run multi head attention (bs, T, num_heads, head_size) query, key, value = [inputs.dot(y) \ @@ -66,13 +67,15 @@ class TransformerBlock: score = query.dot(key) * (1 / np.sqrt(self.head_size)) weights = score.softmax() # (bs, num_heads, T, T) attention = weights.dot(value).transpose(order=(0,2,1,3)) # (bs, T, num_heads, head_size) - x = inputs + attention.reshape(shape=(-1, self.num_heads * self.head_size)).dot(self.final) - x = layernorm(x, self.num_heads * self.head_size) + + x = inputs + attention.reshape(shape=(-1, embed_dim)).dot(self.final) + x = layernorm(x, embed_dim) x = x + x.dot(self.ff1).relu().dot(self.ff2) - x = layernorm(x, self.num_heads * self.head_size) - return x.reshape(shape=(bs, -1, self.num_heads * self.head_size)) + x = layernorm(x, embed_dim) + return x.reshape(shape=(bs, -1, embed_dim)) class Transformer: + # L = cnt, H = embed_dim, A = num_heads def __init__(self, syms, maxlen, cnt, embed_dim, num_heads): self.maxlen, self.syms = maxlen, syms self.embed = Tensor.uniform(maxlen+syms, embed_dim, requires_grad=False)