mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
sum hook override
This commit is contained in:
@@ -13,7 +13,15 @@ import io
|
||||
from extra.utils import fetch
|
||||
|
||||
from tinygrad.tensor import Tensor
|
||||
from models.transformer import TransformerBlock, layernorm
|
||||
|
||||
def layernorm(x, sz, eps=1e-5):
  """Normalize `x` over trailing groups of `sz` elements.

  `x` is flattened to rows of length `sz`; every row is shifted to zero mean
  and divided by sqrt(variance + eps), then the original shape is restored.
  No learned scale or bias is applied here.

  Args:
    x: tensor-like object providing `shape`, `reshape(shape=...)`,
      `mean(axis=...)`, `add`, `sqrt`, `div`, `-` and `*`.
    sz: number of elements in each normalized group (the row length after
      flattening).
    eps: constant added to the variance before the square root so the
      division stays finite for constant rows.

  Returns:
    A tensor with the same shape as `x`.
  """
  in_shape = x.shape
  rows = x.reshape(shape=(-1, sz))

  # Per-row mean, reshaped to a column so it broadcasts across each row.
  row_mean = rows.mean(axis=(-1,)).reshape(shape=[-1, 1])
  centered = rows - row_mean

  # Biased variance: mean of the squared deviations over each row.
  row_var = (centered * centered).mean(axis=(-1,))
  row_std = row_var.add(eps).reshape(shape=[-1, 1]).sqrt()

  return centered.div(row_std).reshape(shape=in_shape)
|
||||
|
||||
class ViTBlock:
|
||||
def __init__(self, embed_dim, num_heads, ff_dim):
|
||||
@@ -67,33 +75,28 @@ class ViTBlock:
|
||||
return x.reshape(shape=(bs, -1, embed_dim))
|
||||
|
||||
class ViT:
  """Vision Transformer built from a 16x16 patch embedding and 12 ViTBlocks.

  Only the embedding width is configurable; the head count (3), feed-forward
  width (768), token count (197) and number of output classes (1000) are
  fixed in the constructor.
  """

  def __init__(self, embed_dim=192):
    """Allocate placeholder parameters sized by `embed_dim`.

    Args:
      embed_dim: width of every token embedding; also the number of output
        channels of the patch convolution.
    """
    # Patch embedding: 3 input channels, 16x16 kernel (applied with stride 16).
    self.conv_weight = Tensor.uniform(embed_dim, 3, 16, 16)
    self.conv_bias = Tensor.zeros(embed_dim)
    # Single class token that forward() places in front of the patch tokens.
    self.cls_token = Tensor.ones(1, 1, embed_dim)
    self.tbs = [ViTBlock(embed_dim=embed_dim, num_heads=3, ff_dim=768) for i in range(12)]
    # One positional embedding per token; 197 entries in total.
    self.pos_embed = Tensor.ones(1, 197, embed_dim)
    # (weight, bias) pairs consumed by `.affine(...)` in forward().
    self.head = (Tensor.uniform(embed_dim, 1000), Tensor.zeros(1000))
    self.norm = (Tensor.uniform(embed_dim), Tensor.zeros(embed_dim))

  def patch_embed(self, x):
    """Convert an image batch into a sequence of patch embeddings.

    Args:
      x: input passed to `conv2d` with `self.conv_weight`
        (presumably (batch, 3, H, W) - confirm against callers).

    Returns:
      The convolution output with all axes after the channel axis flattened
      and then swapped with the channel axis, i.e. (batch, patches, channels).
    """
    x = x.conv2d(self.conv_weight, stride=16)
    x = x.add(self.conv_bias.reshape(shape=(1,-1,1,1)))
    # Use the actual channel count instead of a hard-coded width so the
    # reshape stays correct for any embed_dim.
    x = x.reshape(shape=(x.shape[0], x.shape[1], -1)).transpose(order=(0,2,1))
    return x

  def forward(self, x):
    """Run the full model and return the output of the classification head.

    Args:
      x: image batch accepted by `patch_embed`.

    Returns:
      The result of applying `self.head` to the first (class) token.
    """
    pe = self.patch_embed(x)
    # TODO: expand cls_token for batch
    x = self.cls_token.cat(pe, dim=1) + self.pos_embed
    for l in self.tbs:
      x = l(x)
    # Final normalization over the embedding axis, applied exactly once.
    x = layernorm(x, x.shape[-1]).affine(self.norm)
    # Token 0 is the class token concatenated in front of the patches above.
    return x[:, 0].affine(self.head)
|
||||
|
||||
Tensor.training = False
|
||||
@@ -101,8 +104,8 @@ m = ViT()
|
||||
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
dat = np.load(io.BytesIO(fetch("https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz")))
|
||||
for x in dat.keys():
|
||||
print(x, dat[x].shape, dat[x].dtype)
|
||||
#for x in dat.keys():
|
||||
# print(x, dat[x].shape, dat[x].dtype)
|
||||
|
||||
m.conv_weight.assign(np.transpose(dat['embedding/kernel'], (3,2,0,1)))
|
||||
m.conv_bias.assign(dat['embedding/bias'])
|
||||
@@ -179,8 +182,6 @@ print(x.shape, xp.shape)
|
||||
print(np.max(x.data), np.max(xp.detach().numpy()))
|
||||
print(np.max(np.abs(x.data - xp.detach().numpy())))
|
||||
|
||||
|
||||
|
||||
exit(0)
|
||||
"""
|
||||
|
||||
@@ -195,9 +196,7 @@ exit(0)
|
||||
out = m.forward(Tensor(img))
|
||||
outnp = out.cpu().data.ravel()
|
||||
choice = outnp.argmax()
|
||||
print(out.shape, choice, outnp[choice])
|
||||
|
||||
print(lbls[choice])
|
||||
print(out.shape, choice, outnp[choice], lbls[choice])
|
||||
|
||||
#lookup = dict([x.split(" ") for x in open("cache/classids.txt").read().strip().split("\n")])
|
||||
#cls = open("cache/imagenet21k_wordnet_ids.txt").read().strip().split("\n")
|
||||
|
||||
@@ -206,6 +206,10 @@ class Tensor:
|
||||
def dot(self, w):
|
||||
return self.matmul(w)
|
||||
|
||||
# Hook over the registered sum op, intended as the place to add keepdim support.
def sum(self, axis=None):
  """Sum the tensor over `axis` by delegating to `self._sum`.

  The wrapper currently forwards `axis` unchanged; no keepdim handling is
  implemented yet.

  Args:
    axis: passed straight through to `_sum`; the default `None` is also
      forwarded as-is.

  Returns:
    Whatever `self._sum(axis=axis)` returns.
  """
  return self._sum(axis=axis)
|
||||
|
||||
def mean(self, axis=None):
  """Average the tensor over `axis`.

  Computes `self.sum(axis=axis)` and scales it by the ratio of output element
  count to input element count, which is the reciprocal of the number of
  elements folded into each output value.

  Args:
    axis: forwarded unchanged to `self.sum`.

  Returns:
    The summed tensor multiplied by that scaling factor.
  """
  out = self.sum(axis=axis)
  # prod(out.shape) / prod(self.shape) == 1 / (elements reduced per output).
  scale = np.prod(out.shape) / np.prod(self.shape)
  return out * scale
|
||||
@@ -333,7 +337,10 @@ def register(name, fxn, device=Device.CPU):
|
||||
#f.cl_ctx, f.cl_queue, f.device = cl_ctx, cl_queue, tt.device
|
||||
f.device = tt.device
|
||||
return f.apply(f, *x, **kwargs)
|
||||
setattr(Tensor, name, dispatch)
|
||||
if getattr(Tensor, name, None) is not None:
|
||||
setattr(Tensor, "_"+name, dispatch)
|
||||
else:
|
||||
setattr(Tensor, name, dispatch)
|
||||
if name in ['add', 'sub', 'mul', 'pow', 'matmul']:
|
||||
setattr(Tensor, f"__{name}__", dispatch)
|
||||
setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(dispatch(self,x)))
|
||||
|
||||
Reference in New Issue
Block a user