sum hook override

This commit is contained in:
George Hotz
2021-11-29 17:14:24 -05:00
parent 8097b8f7d6
commit 70544e7e9f
2 changed files with 30 additions and 24 deletions

View File

@@ -13,7 +13,15 @@ import io
from extra.utils import fetch
from tinygrad.tensor import Tensor
from models.transformer import TransformerBlock, layernorm
def layernorm(x, sz, eps=1e-5):
in_shape = x.shape
x = x.reshape(shape=(-1, sz))
layer_mean = x.mean(axis=(-1,)).reshape(shape=[-1, 1])
y = (x - layer_mean)
layer_var = (y*y).mean(axis=(-1,))
ret = y.div(layer_var.add(eps).reshape(shape=[-1, 1]).sqrt())
return ret.reshape(shape=in_shape)
class ViTBlock:
def __init__(self, embed_dim, num_heads, ff_dim):
@@ -67,33 +75,28 @@ class ViTBlock:
return x.reshape(shape=(bs, -1, embed_dim))
class ViT:
def __init__(self):
self.conv_weight = Tensor.uniform(192, 3, 16, 16)
self.conv_bias = Tensor.zeros(192)
self.cls_token = Tensor.ones(1, 1, 192)
self.tbs = [ViTBlock(embed_dim=192, num_heads=3, ff_dim=768) for i in range(12)]
self.pos_embed = Tensor.ones(1, 197, 192)
self.head = (Tensor.uniform(192, 1000), Tensor.zeros(1000))
self.norm = (Tensor.uniform(192), Tensor.zeros(192))
def __init__(self, embed_dim=192):
self.conv_weight = Tensor.uniform(embed_dim, 3, 16, 16)
self.conv_bias = Tensor.zeros(embed_dim)
self.cls_token = Tensor.ones(1, 1, embed_dim)
self.tbs = [ViTBlock(embed_dim=embed_dim, num_heads=3, ff_dim=768) for i in range(12)]
self.pos_embed = Tensor.ones(1, 197, embed_dim)
self.head = (Tensor.uniform(embed_dim, 1000), Tensor.zeros(1000))
self.norm = (Tensor.uniform(embed_dim), Tensor.zeros(embed_dim))
def patch_embed(self, x):
x = x.conv2d(self.conv_weight, stride=16)
x = x.add(self.conv_bias.reshape(shape=(1,-1,1,1)))
x = x.reshape(shape=(x.shape[0], 192, -1)).transpose(order=(0,2,1))
x = x.reshape(shape=(x.shape[0], x.shape[1], -1)).transpose(order=(0,2,1))
return x
def forward(self, x):
pe = self.patch_embed(x)
print(x.shape)
# TODO: expand cls_token for batch
x = self.cls_token.cat(pe, dim=1) + self.pos_embed
print(x.shape)
print(x.mean())
for l in self.tbs:
x = l(x)
print(x.mean())
print(x.shape)
x = layernorm(x, 192).affine(self.norm)
x = layernorm(x, x.shape[-1]).affine(self.norm)
return x[:, 0].affine(self.head)
Tensor.training = False
@@ -101,8 +104,8 @@ m = ViT()
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
dat = np.load(io.BytesIO(fetch("https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz")))
for x in dat.keys():
print(x, dat[x].shape, dat[x].dtype)
#for x in dat.keys():
# print(x, dat[x].shape, dat[x].dtype)
m.conv_weight.assign(np.transpose(dat['embedding/kernel'], (3,2,0,1)))
m.conv_bias.assign(dat['embedding/bias'])
@@ -179,8 +182,6 @@ print(x.shape, xp.shape)
print(np.max(x.data), np.max(xp.detach().numpy()))
print(np.max(np.abs(x.data - xp.detach().numpy())))
exit(0)
"""
@@ -195,9 +196,7 @@ exit(0)
out = m.forward(Tensor(img))
outnp = out.cpu().data.ravel()
choice = outnp.argmax()
print(out.shape, choice, outnp[choice])
print(lbls[choice])
print(out.shape, choice, outnp[choice], lbls[choice])
#lookup = dict([x.split(" ") for x in open("cache/classids.txt").read().strip().split("\n")])
#cls = open("cache/imagenet21k_wordnet_ids.txt").read().strip().split("\n")

View File

@@ -206,6 +206,10 @@ class Tensor:
def dot(self, w):
return self.matmul(w)
# override for sum to support keepdim
def sum(self, axis=None):
return self._sum(axis=axis)
def mean(self, axis=None):
out = self.sum(axis=axis)
return out * (np.prod(out.shape)/np.prod(self.shape))
@@ -333,7 +337,10 @@ def register(name, fxn, device=Device.CPU):
#f.cl_ctx, f.cl_queue, f.device = cl_ctx, cl_queue, tt.device
f.device = tt.device
return f.apply(f, *x, **kwargs)
setattr(Tensor, name, dispatch)
if getattr(Tensor, name, None) is not None:
setattr(Tensor, "_"+name, dispatch)
else:
setattr(Tensor, name, dispatch)
if name in ['add', 'sub', 'mul', 'pow', 'matmul']:
setattr(Tensor, f"__{name}__", dispatch)
setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(dispatch(self,x)))