# Patch overview (this text precedes the first "diff --git" header and belongs to no hunk):
#  - examples/hlb_cifar10.py: optional TORCHWEIGHTS path that initialises the model from the torch SpeedyResNet.
#  - examples/hlb_cifar10_torch.py: moves the stem conv/batchnorm/relu, the max-pool and the linear head into self.net.
#  - tinygrad/runtime/ops_hip.py: new file defining RawHIPBuffer, HIPProgram, HIPCodegen and HIPBuffer.
diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py
index ebbd70dd30..8cd6548dcc 100644
--- a/examples/hlb_cifar10.py
+++ b/examples/hlb_cifar10.py
@@ -83,6 +83,17 @@ def train_cifar():
   print(X_train.shape, Y_train.shape)
   Xt, Yt = fetch_batch(X_test, Y_test, BS=BS)
   model = SpeedyResNet()
+
+  # init weights with torch
+  if getenv("TORCHWEIGHTS"):  # opt-in path; skipped entirely when getenv("TORCHWEIGHTS") is falsy
+    from examples.hlb_cifar10_torch import SpeedyResNet as SpeedyResNetTorch  # imported lazily, only on this path
+    torch_model = SpeedyResNetTorch()
+    model_state_dict = optim.get_state_dict(model)
+    torch_state_dict = torch_model.state_dict()
+    for k,v in torch_state_dict.items():  # every torch key is looked up in model_state_dict, so the key sets must line up
+      print(f"initting {k} from torch")
+      model_state_dict[k].assign(Tensor(v.detach().numpy())).realize()  # torch tensor -> numpy -> Tensor, assigned to the same-named entry
+
   if getenv("ADAM"):
     optimizer = optim.Adam(optim.get_parameters(model), lr=Tensor([0.001]).realize())
   else:
diff --git a/examples/hlb_cifar10_torch.py b/examples/hlb_cifar10_torch.py
index 73fca1e51a..9a2a9afafd 100644
--- a/examples/hlb_cifar10_torch.py
+++ b/examples/hlb_cifar10_torch.py
@@ -30,28 +30,28 @@ class ConvGroup(nn.Module):
     x = self.norm[2](self.conv[2](x) * mult).relu()
     return x + residual
 
+class GlobalMaxPool(nn.Module):  # module form of the amax step so it can sit inside self.net
+  def forward(self, x): return torch.amax(x, dim=(2,3))  # max over dims 2 and 3 (presumably H and W of an NCHW tensor - confirm)
+
 class SpeedyResNet(nn.Module):
   def __init__(self):
     super().__init__()
     # TODO: add whitening
-    self.ic = nn.Conv2d(3, 64, kernel_size=1)
-    self.ib = nn.BatchNorm2d(64, track_running_stats=False, eps=1e-12, momentum=0.8)
     self.net = nn.ModuleList([
+      nn.Conv2d(3, 64, kernel_size=1),  # formerly self.ic
+      nn.BatchNorm2d(64, track_running_stats=False, eps=1e-12, momentum=0.8),  # formerly self.ib
+      nn.ReLU(),  # replaces the x.relu() call removed from forward
       ConvGroup(64, 128, short=False),
       ConvGroup(128, 256, short=True),
       ConvGroup(256, 512, short=False),
+      GlobalMaxPool(),  # replaces the torch.amax(x, dim=(2,3)) call removed from forward
+      nn.Linear(512, num_classes, bias=False)  # formerly self.lin
     ])
-    self.lin = nn.Linear(512, num_classes, bias=False)
 
   # note, pytorch just uses https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html instead of log_softmax
   def forward(self, x):
-    x = self.ic(x)
-    x = self.ib(x)
-    x = x.relu()
     for layer in self.net:
       x = layer(x)
-    x = torch.amax(x, dim=(2,3))
-    x = self.lin(x)
     return x.log_softmax(-1)
 
 def train_step_jitted(model, optimizer, X, Y):
diff --git a/tinygrad/runtime/ops_hip.py b/tinygrad/runtime/ops_hip.py
new file mode 100644
index 0000000000..6163e9a2a0
--- /dev/null
+++ b/tinygrad/runtime/ops_hip.py
@@ -0,0 +1,57 @@
+import numpy as np
+import ctypes
+from pyhip import hip, hiprtc # type: ignore
+from tinygrad.helpers import DEBUG
+from tinygrad.ops import Compiled
+from tinygrad.runtime.lib import RawBufferCopyInOut
+from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage
+
+# The default HIP stream is used for everything.
+
+class RawHIPBuffer(RawBufferCopyInOut):  # buffer whose backing storage comes from hip.hipMalloc
+  def __init__(self, size, dtype):
+    self.buf_sz = size * dtype.itemsize  # element count times per-element size; reused as the length of every copy
+    super().__init__(size, dtype, hip.hipMalloc(self.buf_sz))  # NOTE(review): no matching free is visible in this file - confirm where the allocation is released
+  def _copyin(self, x:np.ndarray): hip.hipMemcpyAsync_htod(self._buf, x.ctypes.data, self.buf_sz, 0)  # async copy from x's memory into the buffer; nothing here waits for it
+  def _copyout(self, x:np.ndarray): hip.hipMemcpyAsync_dtoh(x.ctypes.data, self._buf, self.buf_sz, 0)  # async copy into x's memory; NOTE(review): confirm the caller synchronizes before reading x
+
+class HIPProgram:  # loads one kernel (compiling it first unless binary=True); calling the instance launches it
+  def __init__(self, name:str, prg:str, binary=False):  # binary=True skips hiprtc and hands prg straight to hipModuleLoadData
+    try:
+      if not binary:
+        prog = hiprtc.hiprtcCreateProgram(prg, name, [], [])
+        device_properties = hip.hipGetDeviceProperties(0)
+        hiprtc.hiprtcCompileProgram(prog, [f'--offload-arch={device_properties.gcnArchName}'])  # build for the arch reported by device 0
+        prg = hiprtc.hiprtcGetCode(prog)  # prg now holds the compiled code instead of the source
+    except hip.hipError as e:
+      if DEBUG >= 3: print("FAILED TO BUILD", prg)
+      raise e
+    if DEBUG >= 5: print(prg)
+    module = hip.hipModuleLoadData(prg)
+    self.prg = hip.hipModuleGetFunction(module, name)  # handle used by every launch in __call__
+
+  def __call__(self, global_size, local_size, *args, wait=False):  # returns hipEventElapsedTime*1e-3 when wait=True, otherwise None
+    local_size = (local_size + [1] * (3 - len(local_size))) if local_size is not None else (1,1,1)  # pad to 3 dims; None means (1,1,1)
+    global_size = global_size + [1] * (3 - len(global_size))  # pad to 3 dims; list concatenation, so a list is expected
+    assert all(x%y == 0 for x,y in zip(global_size, local_size)), f"local:{local_size} must divide global:{global_size}"
+    global_size = [x//y for x,y in zip(global_size, local_size)]  # global divided by local, passed below as the first three launch dimensions
+    if wait:
+      start, end = hip.hipEventCreate(), hip.hipEventCreate()  # NOTE(review): these events are never destroyed in this file
+      hip.hipEventRecord(start)
+    class PackageStruct(ctypes.Structure):  # rebuilt on every call: one void* field per kernel argument
+      _fields_ = [(f'field{idx}', ctypes.c_void_p) for idx in range(len(args))]
+    struct = PackageStruct(*[data._buf for data in args])  # every arg must expose ._buf
+    hip.hipModuleLaunchKernel(self.prg, global_size[0], global_size[1], global_size[2], local_size[0], local_size[1], local_size[2], 0, 0, struct)
+    if wait:
+      hip.hipEventRecord(end)
+      hip.hipEventSynchronize(end)  # block until the end event has completed before reading the timing
+      return hip.hipEventElapsedTime(start, end)*1e-3  # presumably milliseconds scaled to seconds - confirm against pyhip
+
+class HIPCodegen(CStyleCodegen):  # CStyleCodegen configured with HIP-flavoured keywords
+  lang = CStyleLanguage(
+    kernel_prefix = "#define INFINITY (__builtin_inff())\nextern \"C\" __global__", smem_prefix = "__shared__ ", barrier = "__syncthreads();", float4 = "make_float4",
+    half_prekernel = "",
+    gid = [f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' for i in range(3)],  # chr(120+i) yields 'x', 'y', 'z'
+    lid = [f'threadIdx.{chr(120+i)}' for i in range(3)])
+
+HIPBuffer = Compiled(RawHIPBuffer, HIPCodegen, HIPProgram, hip.hipDeviceSynchronize)  # last argument is presumably the device-synchronize hook - confirm against Compiled