diff --git a/examples/mlperf/model_spec.py b/examples/mlperf/model_spec.py index 7c9b839902..ad1716275c 100644 --- a/examples/mlperf/model_spec.py +++ b/examples/mlperf/model_spec.py @@ -27,6 +27,8 @@ if __name__ == "__main__": from models.unet3d import UNet3D mdl = UNet3D() mdl.load_from_pretrained() + img = Tensor.randn(1, 1, 5, 224, 224) + test_model(mdl, img) # RNNT diff --git a/models/unet3d.py b/models/unet3d.py index ff6935aab6..687e4dc74c 100644 --- a/models/unet3d.py +++ b/models/unet3d.py @@ -1,31 +1,42 @@ # https://github.com/wolny/pytorch-3dunet from pathlib import Path -from extra.utils import download_file, fake_torch_load +from extra.utils import download_file, fake_torch_load, get_child import tinygrad.nn as nn class SingleConv: def __init__(self, in_channels, out_channels): self.groupnorm = nn.GroupNorm(1, in_channels) # 1 group? - self.conv = nn.Conv2d(in_channels, out_channels, (3,3,3), bias=False) + # TODO: make 2D conv generic for 3D, might already work with kernel_size=(3,3,3) + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False) def __call__(self, x): return self.conv(self.groupnorm(x)).relu() -def get_basic_module(c0, c1, c2): return {"SingleConv1": SingleConv(c0, c1), "SingleConv2": SingleConv(c1, c2)} +class BasicModule: + def __init__(self, c0, c1, c2): + self.basic_module = {"SingleConv1": SingleConv(c0, c1), "SingleConv2": SingleConv(c1, c2)} + def __call__(self, x): + return self.basic_module['SingleConv2'](self.basic_module['SingleConv1'](x)) class UNet3D: def __init__(self): ups = [16,32,64,128,256] - self.encoders = [get_basic_module(ups[i] if i != 0 else 1, ups[i], ups[i+1]) for i in range(4)] - self.decoders = [get_basic_module(ups[-1-i] + ups[-2+i], ups[-2+i], ups[-2+i]) for i in range(3)] + self.encoders = [BasicModule(ups[i] if i != 0 else 1, ups[i], ups[i+1]) for i in range(4)] + self.decoders = [BasicModule(ups[-1-i] + ups[-2-i], ups[-2-i], ups[-2-i]) for i in range(3)] self.final_conv = nn.Conv2d(32, 1, (1,1,1)) def __call__(self, x): - # TODO: make 2D conv generic for 3D, might already work with kernel_size=(3,3,3) - pass + intermediates = [x] + for e in self.encoders: intermediates.append(e(intermediates[-1])) + ret = intermediates[-1] + for d,i in zip(self.decoders, intermediates[:-1][::-1]): ret = d(ret.cat(i, dim=1)) + return ret def load_from_pretrained(self): fn = Path(__file__).parent.parent / "weights/unet-3d.ckpt" download_file("https://oc.embl.de/index.php/s/61s67Mg5VQy7dh9/download?path=%2FLateral-Root-Primordia%2Funet_bce_dice_ds1x&files=best_checkpoint.pytorch", fn) - state = fake_torch_load(open(fn, "rb").read())['model_state_dict'] - for x in state.keys(): - print(x, state[x].shape) + state_dict = fake_torch_load(open(fn, "rb").read())['model_state_dict'] + for k, v in state_dict.items(): + print(k, v.shape) + obj = get_child(self, k) + assert obj.shape == v.shape, (k, obj.shape, v.shape) + obj.assign(v.numpy()) diff --git a/test/test_ops.py b/test/test_ops.py index ce323aa806..75b043f0f4 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -3,7 +3,7 @@ import time import numpy as np import unittest from tinygrad.tensor import Tensor -from tinygrad.helpers import getenv +from tinygrad.helpers import getenv, IMAGE from tinygrad.lazy import Device FORWARD_ONLY = getenv("FORWARD_ONLY", 0) @@ -346,6 +346,18 @@ class TestOps(unittest.TestCase): lambda x,w: torch.nn.functional.conv2d(x,w).relu(), lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5) + @unittest.skipIf(IMAGE>0, "no conv3d on images") + def test_simple_conv3d(self): + helper_test_op([(1,4,9,9,9), (4,4,3,3,3)], + lambda x,w: torch.nn.functional.conv3d(x,w).relu(), + lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5) + + @unittest.skipIf(IMAGE>0, "no conv3d on images") + def test_padded_conv3d(self): + helper_test_op([(1,4,9,9,9), (4,4,3,3,3)], + lambda x,w: torch.nn.functional.conv3d(x,w,padding=1).relu(), + lambda x,w: Tensor.conv2d(x,w,padding=[1,1,1,1,1,1]).relu(), atol=1e-4, grad_rtol=1e-5) + def test_simple_conv2d_m4(self): helper_test_op([(1,16,18,18), (16,16,3,3)], lambda x,w: torch.nn.functional.conv2d(x,w).relu(), @@ -580,10 +592,10 @@ class TestOps(unittest.TestCase): for dim in range(-1, 3): helper_test_op([(45, 65, 3), (45, 65, 3), (45, 65, 3)], lambda x, y, z: torch.stack((x, y, z), dim=dim), lambda x, y, z: Tensor.stack([x, y, z], dim=dim)) - + with self.assertRaises(IndexError): Tensor.stack([x], dim=77) - + def test_repeat(self): x = Tensor.randn(45, 65, 3) base_repeats = [2, 4, 3] @@ -597,7 +609,7 @@ class TestOps(unittest.TestCase): with self.assertRaises(AssertionError): x.repeat((2, 0, 4)) - + def test_clip(self): helper_test_op([(45,65)], lambda x: x.clip(-2.3, 1.2), lambda x: x.clip(-2.3, 1.2)) diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py index 8d0c530e60..2da901f2dc 100644 --- a/tinygrad/nn/__init__.py +++ b/tinygrad/nn/__init__.py @@ -36,9 +36,9 @@ class BatchNorm2d: # TODO: is this good weight init? class Conv2d: def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): - self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else (kernel_size[0], kernel_size[1]) + self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size) self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups - self.weight = Tensor.glorot_uniform(out_channels, in_channels//groups, self.kernel_size[0], self.kernel_size[1]) + self.weight = Tensor.glorot_uniform(out_channels, in_channels//groups, *self.kernel_size) self.bias = Tensor.zeros(out_channels) if bias else None def __call__(self, x): @@ -65,7 +65,7 @@ class GroupNorm: if self.weight is None or self.bias is None: return x # elementwise_affine on channels - return x * self.weight.reshape(1, -1, 1, 1) + self.bias.reshape(1, -1, 1, 1) + return x * self.weight.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)]) + self.bias.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)]) class LayerNorm: def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True): diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 64fb836cab..4199c99f11 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -95,7 +95,7 @@ class ASTRunner: if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(rawbufs, allow_cache=(getenv("OPTLOCAL") >= 2)) if et := self.clprg(self.global_size, self.local_size, *rawbufs, wait=force_wait or DEBUG>=2): GlobalCounters.time_sum_s += et if DEBUG >= 2: - print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(28-len(self.name))) if self.display_name is not None else self.name:26s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " + + print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(29-len(self.name))) if self.display_name is not None else self.name:26s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " + (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/(et*1e9):8.2f} GFLOPS, {self.mem_estimate/(et*1e9):7.2f} GB/s)")) GlobalCounters.kernel_count += 1 GlobalCounters.global_ops += self.op_estimate diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index c5505b8fa3..e40ed7a9b7 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -184,7 +184,7 @@ class ShapeTracker: def reshape(self, new_shape: Tuple[int, ...]): if self.shape == new_shape: return self - assert all(isinstance(x, int) and x != 0 for x in new_shape), f"shape must be ints and can't contain 0 {new_shape}" + assert all(isinstance(x, int) and x > 0 for x in new_shape), f"shape must be ints and can't contain 0 or negative numbers {new_shape}" assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}" # check if this is adding or removing 1s (only) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 7c2de6f100..3ed17efa6a 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -281,7 +281,10 @@ class Tensor: return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:]) # (padding_left, padding_right, padding_top, padding_bottom) - def pad2d(self, padding:Union[List[int], Tuple[int, ...]]): return self.slice(((0,self.shape[0]), (0,self.shape[1]), (-padding[2],self.shape[2]+padding[3]), (-padding[0],self.shape[3]+padding[1]))) + def pad2d(self, padding:Union[List[int], Tuple[int, ...]]): + slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1] + return self.slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc) + @property def T(self) -> Tensor: return self.transpose() def transpose(self, ax1=1, ax2=0) -> Tensor: @@ -360,14 +363,14 @@ class Tensor: def max_pool2d(self, kernel_size=(2,2), stride=None): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor: - (bs,cin_,_,_), (cout,cin,H,W) = self.shape, weight.shape - assert groups*cin == cin_, f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" - padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]]) + (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:] + assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" + padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) >= 4 else [padding[1], padding[1], padding[0], padding[0]]) # conv2d is a pooling op (with padding) - x = self.pad2d(padding_)._pool((H,W), stride, dilation) # (bs, groups*cin, oy, ox, H, W) - rcout, oy, ox = cout//groups, x.shape[2], x.shape[3] - x = x.reshape(bs, groups, cin, 1, oy, ox, H, W).expand(bs, groups, cin, rcout, oy, ox, H, W).permute(0,1,3,4,5,2,6,7) + x = self.pad2d(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W) + rcout, oyx = cout//groups, x.shape[2:-len(HW)] + x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) # expand the channels with the pool # TODO: this reduces the number of kernels, but it's slower! @@ -375,9 +378,9 @@ class Tensor: #rcout, oy, ox = x.shape[2:5] #x = x.reshape(bs, groups, cin, rcout, oy, ox, H, W).permute(0,1,3,4,5,2,6,7) - # conv! broadcasted to (bs, groups, rcout, oy, ox, cin, H, W) - ret = (x * weight.reshape(1, groups, rcout, 1, 1, cin, H, W)).sum((-3, -2, -1), keepdim=True).reshape(bs, cout, oy, ox) - return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1)) + # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW) + ret = (x * weight.reshape(1, groups, rcout, *[1 for _ in range(len(oyx))], cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx) + return ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))])) def dot(self, w:Tensor) -> Tensor: x = self.reshape(*self.shape[0:-1], 1, self.shape[-1])