diff --git a/examples/vit.py b/examples/vit.py index 937e7b454a..5a714bfe37 100644 --- a/examples/vit.py +++ b/examples/vit.py @@ -13,7 +13,15 @@ import io from extra.utils import fetch from tinygrad.tensor import Tensor -from models.transformer import TransformerBlock, layernorm + +def layernorm(x, sz, eps=1e-5): + in_shape = x.shape + x = x.reshape(shape=(-1, sz)) + layer_mean = x.mean(axis=(-1,)).reshape(shape=[-1, 1]) + y = (x - layer_mean) + layer_var = (y*y).mean(axis=(-1,)) + ret = y.div(layer_var.add(eps).reshape(shape=[-1, 1]).sqrt()) + return ret.reshape(shape=in_shape) class ViTBlock: def __init__(self, embed_dim, num_heads, ff_dim): @@ -67,33 +75,28 @@ class ViTBlock: return x.reshape(shape=(bs, -1, embed_dim)) class ViT: - def __init__(self): - self.conv_weight = Tensor.uniform(192, 3, 16, 16) - self.conv_bias = Tensor.zeros(192) - self.cls_token = Tensor.ones(1, 1, 192) - self.tbs = [ViTBlock(embed_dim=192, num_heads=3, ff_dim=768) for i in range(12)] - self.pos_embed = Tensor.ones(1, 197, 192) - self.head = (Tensor.uniform(192, 1000), Tensor.zeros(1000)) - self.norm = (Tensor.uniform(192), Tensor.zeros(192)) + def __init__(self, embed_dim=192): + self.conv_weight = Tensor.uniform(embed_dim, 3, 16, 16) + self.conv_bias = Tensor.zeros(embed_dim) + self.cls_token = Tensor.ones(1, 1, embed_dim) + self.tbs = [ViTBlock(embed_dim=embed_dim, num_heads=3, ff_dim=768) for i in range(12)] + self.pos_embed = Tensor.ones(1, 197, embed_dim) + self.head = (Tensor.uniform(embed_dim, 1000), Tensor.zeros(1000)) + self.norm = (Tensor.uniform(embed_dim), Tensor.zeros(embed_dim)) def patch_embed(self, x): x = x.conv2d(self.conv_weight, stride=16) x = x.add(self.conv_bias.reshape(shape=(1,-1,1,1))) - x = x.reshape(shape=(x.shape[0], 192, -1)).transpose(order=(0,2,1)) + x = x.reshape(shape=(x.shape[0], x.shape[1], -1)).transpose(order=(0,2,1)) return x def forward(self, x): pe = self.patch_embed(x) - print(x.shape) # TODO: expand cls_token for batch x = self.cls_token.cat(pe, dim=1) + self.pos_embed - print(x.shape) - print(x.mean()) for l in self.tbs: x = l(x) - print(x.mean()) - print(x.shape) - x = layernorm(x, 192).affine(self.norm) + x = layernorm(x, x.shape[-1]).affine(self.norm) return x[:, 0].affine(self.head) Tensor.training = False @@ -101,8 +104,8 @@ m = ViT() # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py dat = np.load(io.BytesIO(fetch("https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz"))) -for x in dat.keys(): - print(x, dat[x].shape, dat[x].dtype) +#for x in dat.keys(): +# print(x, dat[x].shape, dat[x].dtype) m.conv_weight.assign(np.transpose(dat['embedding/kernel'], (3,2,0,1))) m.conv_bias.assign(dat['embedding/bias']) @@ -179,8 +182,6 @@ print(x.shape, xp.shape) print(np.max(x.data), np.max(xp.detach().numpy())) print(np.max(np.abs(x.data - xp.detach().numpy()))) - - exit(0) """ @@ -195,9 +196,7 @@ exit(0) out = m.forward(Tensor(img)) outnp = out.cpu().data.ravel() choice = outnp.argmax() -print(out.shape, choice, outnp[choice]) - -print(lbls[choice]) +print(out.shape, choice, outnp[choice], lbls[choice]) #lookup = dict([x.split(" ") for x in open("cache/classids.txt").read().strip().split("\n")]) #cls = open("cache/imagenet21k_wordnet_ids.txt").read().strip().split("\n") diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 644093a0c3..2f844c6db5 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -206,6 +206,10 @@ class Tensor: def dot(self, w): return self.matmul(w) + # override for sum to support keepdim + def sum(self, axis=None): + return self._sum(axis=axis) + def mean(self, axis=None): out = self.sum(axis=axis) return out * (np.prod(out.shape)/np.prod(self.shape)) @@ -333,7 +337,10 @@ def register(name, fxn, device=Device.CPU): #f.cl_ctx, f.cl_queue, f.device = cl_ctx, cl_queue, tt.device f.device = tt.device return f.apply(f, *x, **kwargs) - setattr(Tensor, name, dispatch) + if getattr(Tensor, name, None) is not None: + setattr(Tensor, "_"+name, dispatch) + else: + setattr(Tensor, name, dispatch) if name in ['add', 'sub', 'mul', 'pow', 'matmul']: setattr(Tensor, f"__{name}__", dispatch) setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(dispatch(self,x)))