diff --git a/test/test_multitensor.py b/test/test_multitensor.py index eb627c661b..b1af690231 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -358,23 +358,11 @@ class TestMultiTensor(unittest.TestCase): for p in get_parameters(m): p.shard_(devices_2).realize() GlobalCounters.reset() shard_output = m(fake_image_sharded).log_softmax().realize() - assert shard_output.lazydata.src[0].shape == (1, 1000) - assert shard_output.lazydata.src[1].shape == (1, 1000) shard_output_np = shard_output.numpy() np.testing.assert_allclose(real_output, shard_output_np, atol=1e-6, rtol=1e-6) - @unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV", "LLVM", "CPU"), "slow, and flaky on LLVM/CPU") - @unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers") - def test_data_parallel_resnet_train_step(self): - from extra.models.resnet import ResNet18 + def _test_model_train_step(self, m, fake_image, labels): from tinygrad.nn.optim import LARS - - fake_image = Tensor.rand((2, 3, 224//8, 224//8)) - fake_image_sharded = fake_image.shard(devices_2, axis=0) - labels = Tensor.randint(2, low=0, high=1000) - labels_sharded = labels.shard(devices_2, axis=0) - - m = ResNet18() optimizer = LARS(get_parameters(m), 0.1) # set requires_grad for all params optimizer.zero_grad() @@ -383,6 +371,8 @@ class TestMultiTensor(unittest.TestCase): output.backward() grad = m.conv1.weight.grad.numpy() + fake_image_sharded = fake_image.shard(devices_2, axis=0) + labels_sharded = labels.shard(devices_2, axis=0) for p in get_parameters(m): p.shard_(devices_2).realize() GlobalCounters.reset() optimizer.zero_grad() @@ -392,6 +382,26 @@ class TestMultiTensor(unittest.TestCase): # sometimes there is zeros in these grads... why? np.testing.assert_allclose(grad, shard_grad, atol=1e-5, rtol=1e-5) + @unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV", "LLVM", "CPU"), "slow, and flaky on LLVM/CPU") + @unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers") + def test_data_parallel_resnet_train_step(self): + from extra.models.resnet import ResNet18 + fake_image = Tensor.rand((2, 3, 224//8, 224//8)) + labels = Tensor.randint(2, low=0, high=1000) + m = ResNet18() + self._test_model_train_step(m, fake_image, labels) + + def test_data_parallel_simple_train_step(self): + class Model: + def __init__(self): self.conv1 = nn.Linear(128,128) + def __call__(self, x): return self.conv1(x) + def load_from_pretrained(self): pass + + fake_image = Tensor.rand((128,)) + labels = Tensor.randint(2, low=0, high=127) + m = Model() + self._test_model_train_step(m, fake_image, labels) + def test_assign_kv_cache_multi(self): bsz, max_context = 2, 8 @@ -833,7 +843,6 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase): a = t.shrink(((0, 2), (0, 8))) a.schedule() assert a.shape == (2, 8) - assert a.lazydata.real == (True, False, False, False) with self.assertRaises(AssertionError): # cannot pad sharded and non-sharded axis at the same time @@ -848,7 +857,6 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase): p = a.pad(((0, 6), (0, 0))) p.schedule() assert p.shape == (8, 8) - assert p.lazydata.real == (True, True, True, True) @given(strat.sampled_from([dtypes.float, dtypes.int, dtypes.int64, dtypes.int16])) def test_ops(self, dtype): @@ -861,7 +869,6 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase): b = Tensor(t.numpy()[0+2*i:2+2*i]) assert a.shape == b.shape == (2, 8) np.testing.assert_allclose(a.numpy(), b.numpy()) - assert a.lazydata.real == tuple(i==j for j in range(4)) # cast np.testing.assert_allclose(a.float().numpy(), b.float().numpy()) diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index a618fb3462..b775bcbcd5 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -239,7 +239,7 @@ class TinyJit(Generic[ReturnType]): return ret def add(self, ei:ExecItem): - self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None])) + self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None], ei.metadata, ei.fixedvars)) def reset(self): assert self.fxn is not None, "can't reset without function" @@ -310,7 +310,8 @@ class TinyJit(Generic[ReturnType]): # Exclude buffers involved in transfer ops to preserve parallelism. noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, BufferXfer) for b in ji.bufs} assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ") - jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache] + jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None], + item.metadata, item.fixedvars) for item in jit_cache] input_replace = get_input_replace(jit_cache, input_buffers) if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found") diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index e8e10b19f2..9fdf596c17 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -1,6 +1,6 @@ from typing import Optional, cast, Generator import time, pprint -from dataclasses import dataclass, replace +from dataclasses import dataclass, replace, field from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer @@ -120,8 +120,9 @@ class ExecItem: prg: Runner bufs: list[Optional[Buffer]] metadata: Optional[tuple[Metadata, ...]] = None + fixedvars: dict[Variable, int] = field(default_factory=dict) def run(self, _var_vals:Optional[dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]: - var_vals = {} if _var_vals is None else _var_vals + var_vals = self.fixedvars if _var_vals is None else (_var_vals|self.fixedvars) bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs] et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2) if do_update_stats: @@ -147,7 +148,8 @@ si_lowerer = PatternMatcher([ if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \ else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device)), list(ctx))), ]) -def lower_schedule_item(si:ScheduleItem) -> ExecItem: return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata) +def lower_schedule_item(si:ScheduleItem) -> ExecItem: + return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata, si.fixedvars) def lower_schedule(schedule:list[ScheduleItem]) -> Generator[tuple[ScheduleItem, ExecItem], None, None]: while len(schedule): @@ -177,7 +179,7 @@ def run_schedule(schedule:list[ScheduleItem], var_vals:Optional[dict[Variable, i ei.run(var_vals, do_update_stats=do_update_stats) # validate the output buffers match (NOTE: this is assuming the output is buffer 0) - lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata)).run(var_vals, do_update_stats=do_update_stats) + lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata, si.fixedvars)).run(var_vals, do_update_stats=do_update_stats) import numpy as np np.testing.assert_allclose(si.bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3) else: diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 981a56473f..27bf3dd44d 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from collections import deque, defaultdict from tinygrad.ops import UOp, Variable, Ops, UPat, PatternMatcher, graph_rewrite, buffers from tinygrad.device import Buffer @@ -11,6 +11,7 @@ class ScheduleItem: ast: UOp bufs: tuple[Buffer, ...] metadata: tuple[Metadata, ...] = () + fixedvars: dict[Variable, int] = field(default_factory=dict) # **** unbind Variables