improve multi tests + add support for fixedvars [pr] (#10281)

* improve multi tests + add support for fixedvars [pr]

* add support for fixedvars
This commit is contained in:
George Hotz
2025-05-13 09:27:00 -07:00
committed by GitHub
parent 8a906cb124
commit 5f64bbc63d
4 changed files with 34 additions and 23 deletions

View File

@@ -358,23 +358,11 @@ class TestMultiTensor(unittest.TestCase):
for p in get_parameters(m): p.shard_(devices_2).realize()
GlobalCounters.reset()
shard_output = m(fake_image_sharded).log_softmax().realize()
assert shard_output.lazydata.src[0].shape == (1, 1000)
assert shard_output.lazydata.src[1].shape == (1, 1000)
shard_output_np = shard_output.numpy()
np.testing.assert_allclose(real_output, shard_output_np, atol=1e-6, rtol=1e-6)
@unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV", "LLVM", "CPU"), "slow, and flaky on LLVM/CPU")
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers")
def test_data_parallel_resnet_train_step(self):
from extra.models.resnet import ResNet18
def _test_model_train_step(self, m, fake_image, labels):
from tinygrad.nn.optim import LARS
fake_image = Tensor.rand((2, 3, 224//8, 224//8))
fake_image_sharded = fake_image.shard(devices_2, axis=0)
labels = Tensor.randint(2, low=0, high=1000)
labels_sharded = labels.shard(devices_2, axis=0)
m = ResNet18()
optimizer = LARS(get_parameters(m), 0.1) # set requires_grad for all params
optimizer.zero_grad()
@@ -383,6 +371,8 @@ class TestMultiTensor(unittest.TestCase):
output.backward()
grad = m.conv1.weight.grad.numpy()
fake_image_sharded = fake_image.shard(devices_2, axis=0)
labels_sharded = labels.shard(devices_2, axis=0)
for p in get_parameters(m): p.shard_(devices_2).realize()
GlobalCounters.reset()
optimizer.zero_grad()
@@ -392,6 +382,26 @@ class TestMultiTensor(unittest.TestCase):
# sometimes there is zeros in these grads... why?
np.testing.assert_allclose(grad, shard_grad, atol=1e-5, rtol=1e-5)
@unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV", "LLVM", "CPU"), "slow, and flaky on LLVM/CPU")
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers")
def test_data_parallel_resnet_train_step(self):
from extra.models.resnet import ResNet18
fake_image = Tensor.rand((2, 3, 224//8, 224//8))
labels = Tensor.randint(2, low=0, high=1000)
m = ResNet18()
self._test_model_train_step(m, fake_image, labels)
def test_data_parallel_simple_train_step(self):
class Model:
def __init__(self): self.conv1 = nn.Linear(128,128)
def __call__(self, x): return self.conv1(x)
def load_from_pretrained(self): pass
fake_image = Tensor.rand((128,))
labels = Tensor.randint(2, low=0, high=127)
m = Model()
self._test_model_train_step(m, fake_image, labels)
def test_assign_kv_cache_multi(self):
bsz, max_context = 2, 8
@@ -833,7 +843,6 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
a = t.shrink(((0, 2), (0, 8)))
a.schedule()
assert a.shape == (2, 8)
assert a.lazydata.real == (True, False, False, False)
with self.assertRaises(AssertionError):
# cannot pad sharded and non-sharded axis at the same time
@@ -848,7 +857,6 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
p = a.pad(((0, 6), (0, 0)))
p.schedule()
assert p.shape == (8, 8)
assert p.lazydata.real == (True, True, True, True)
@given(strat.sampled_from([dtypes.float, dtypes.int, dtypes.int64, dtypes.int16]))
def test_ops(self, dtype):
@@ -861,7 +869,6 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
b = Tensor(t.numpy()[0+2*i:2+2*i])
assert a.shape == b.shape == (2, 8)
np.testing.assert_allclose(a.numpy(), b.numpy())
assert a.lazydata.real == tuple(i==j for j in range(4))
# cast
np.testing.assert_allclose(a.float().numpy(), b.float().numpy())

View File

@@ -239,7 +239,7 @@ class TinyJit(Generic[ReturnType]):
return ret
def add(self, ei:ExecItem):
self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None]))
self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None], ei.metadata, ei.fixedvars))
def reset(self):
assert self.fxn is not None, "can't reset without function"
@@ -310,7 +310,8 @@ class TinyJit(Generic[ReturnType]):
# Exclude buffers involved in transfer ops to preserve parallelism.
noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, BufferXfer) for b in ji.bufs}
assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache]
jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None],
item.metadata, item.fixedvars) for item in jit_cache]
input_replace = get_input_replace(jit_cache, input_buffers)
if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")

View File

@@ -1,6 +1,6 @@
from typing import Optional, cast, Generator
import time, pprint
from dataclasses import dataclass, replace
from dataclasses import dataclass, replace, field
from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU
from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
@@ -120,8 +120,9 @@ class ExecItem:
prg: Runner
bufs: list[Optional[Buffer]]
metadata: Optional[tuple[Metadata, ...]] = None
fixedvars: dict[Variable, int] = field(default_factory=dict)
def run(self, _var_vals:Optional[dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
var_vals = {} if _var_vals is None else _var_vals
var_vals = self.fixedvars if _var_vals is None else (_var_vals|self.fixedvars)
bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
if do_update_stats:
@@ -147,7 +148,8 @@ si_lowerer = PatternMatcher([
if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \
else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device)), list(ctx))),
])
def lower_schedule_item(si:ScheduleItem) -> ExecItem: return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata)
def lower_schedule_item(si:ScheduleItem) -> ExecItem:
return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata, si.fixedvars)
def lower_schedule(schedule:list[ScheduleItem]) -> Generator[tuple[ScheduleItem, ExecItem], None, None]:
while len(schedule):
@@ -177,7 +179,7 @@ def run_schedule(schedule:list[ScheduleItem], var_vals:Optional[dict[Variable, i
ei.run(var_vals, do_update_stats=do_update_stats)
# validate the output buffers match (NOTE: this is assuming the output is buffer 0)
lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata)).run(var_vals, do_update_stats=do_update_stats)
lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata, si.fixedvars)).run(var_vals, do_update_stats=do_update_stats)
import numpy as np
np.testing.assert_allclose(si.bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
else:

View File

@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from collections import deque, defaultdict
from tinygrad.ops import UOp, Variable, Ops, UPat, PatternMatcher, graph_rewrite, buffers
from tinygrad.device import Buffer
@@ -11,6 +11,7 @@ class ScheduleItem:
ast: UOp
bufs: tuple[Buffer, ...]
metadata: tuple[Metadata, ...] = ()
fixedvars: dict[Variable, int] = field(default_factory=dict)
# **** unbind Variables