mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
improve multi tests + add support for fixedvars [pr] (#10281)
* improve multi tests + add support for fixedvars [pr]
* add support for fixedvars
This commit is contained in:
@@ -358,23 +358,11 @@ class TestMultiTensor(unittest.TestCase):
|
||||
for p in get_parameters(m): p.shard_(devices_2).realize()
|
||||
GlobalCounters.reset()
|
||||
shard_output = m(fake_image_sharded).log_softmax().realize()
|
||||
assert shard_output.lazydata.src[0].shape == (1, 1000)
|
||||
assert shard_output.lazydata.src[1].shape == (1, 1000)
|
||||
shard_output_np = shard_output.numpy()
|
||||
np.testing.assert_allclose(real_output, shard_output_np, atol=1e-6, rtol=1e-6)
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV", "LLVM", "CPU"), "slow, and flaky on LLVM/CPU")
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers")
|
||||
def test_data_parallel_resnet_train_step(self):
|
||||
from extra.models.resnet import ResNet18
|
||||
def _test_model_train_step(self, m, fake_image, labels):
|
||||
from tinygrad.nn.optim import LARS
|
||||
|
||||
fake_image = Tensor.rand((2, 3, 224//8, 224//8))
|
||||
fake_image_sharded = fake_image.shard(devices_2, axis=0)
|
||||
labels = Tensor.randint(2, low=0, high=1000)
|
||||
labels_sharded = labels.shard(devices_2, axis=0)
|
||||
|
||||
m = ResNet18()
|
||||
optimizer = LARS(get_parameters(m), 0.1) # set requires_grad for all params
|
||||
|
||||
optimizer.zero_grad()
|
||||
@@ -383,6 +371,8 @@ class TestMultiTensor(unittest.TestCase):
|
||||
output.backward()
|
||||
grad = m.conv1.weight.grad.numpy()
|
||||
|
||||
fake_image_sharded = fake_image.shard(devices_2, axis=0)
|
||||
labels_sharded = labels.shard(devices_2, axis=0)
|
||||
for p in get_parameters(m): p.shard_(devices_2).realize()
|
||||
GlobalCounters.reset()
|
||||
optimizer.zero_grad()
|
||||
@@ -392,6 +382,26 @@ class TestMultiTensor(unittest.TestCase):
|
||||
# sometimes there are zeros in these grads... why?
|
||||
np.testing.assert_allclose(grad, shard_grad, atol=1e-5, rtol=1e-5)
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV", "LLVM", "CPU"), "slow, and flaky on LLVM/CPU")
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers")
|
||||
def test_data_parallel_resnet_train_step(self):
|
||||
from extra.models.resnet import ResNet18
|
||||
fake_image = Tensor.rand((2, 3, 224//8, 224//8))
|
||||
labels = Tensor.randint(2, low=0, high=1000)
|
||||
m = ResNet18()
|
||||
self._test_model_train_step(m, fake_image, labels)
|
||||
|
||||
def test_data_parallel_simple_train_step(self):
|
||||
class Model:
|
||||
def __init__(self): self.conv1 = nn.Linear(128,128)
|
||||
def __call__(self, x): return self.conv1(x)
|
||||
def load_from_pretrained(self): pass
|
||||
|
||||
fake_image = Tensor.rand((128,))
|
||||
labels = Tensor.randint(2, low=0, high=127)
|
||||
m = Model()
|
||||
self._test_model_train_step(m, fake_image, labels)
|
||||
|
||||
def test_assign_kv_cache_multi(self):
|
||||
bsz, max_context = 2, 8
|
||||
|
||||
@@ -833,7 +843,6 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
|
||||
a = t.shrink(((0, 2), (0, 8)))
|
||||
a.schedule()
|
||||
assert a.shape == (2, 8)
|
||||
assert a.lazydata.real == (True, False, False, False)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
# cannot pad sharded and non-sharded axis at the same time
|
||||
@@ -848,7 +857,6 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
|
||||
p = a.pad(((0, 6), (0, 0)))
|
||||
p.schedule()
|
||||
assert p.shape == (8, 8)
|
||||
assert p.lazydata.real == (True, True, True, True)
|
||||
|
||||
@given(strat.sampled_from([dtypes.float, dtypes.int, dtypes.int64, dtypes.int16]))
|
||||
def test_ops(self, dtype):
|
||||
@@ -861,7 +869,6 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
|
||||
b = Tensor(t.numpy()[0+2*i:2+2*i])
|
||||
assert a.shape == b.shape == (2, 8)
|
||||
np.testing.assert_allclose(a.numpy(), b.numpy())
|
||||
assert a.lazydata.real == tuple(i==j for j in range(4))
|
||||
# cast
|
||||
np.testing.assert_allclose(a.float().numpy(), b.float().numpy())
|
||||
|
||||
|
||||
@@ -239,7 +239,7 @@ class TinyJit(Generic[ReturnType]):
|
||||
return ret
|
||||
|
||||
def add(self, ei:ExecItem):
|
||||
self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None]))
|
||||
self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None], ei.metadata, ei.fixedvars))
|
||||
|
||||
def reset(self):
|
||||
assert self.fxn is not None, "can't reset without function"
|
||||
@@ -310,7 +310,8 @@ class TinyJit(Generic[ReturnType]):
|
||||
# Exclude buffers involved in transfer ops to preserve parallelism.
|
||||
noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, BufferXfer) for b in ji.bufs}
|
||||
assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
|
||||
jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache]
|
||||
jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None],
|
||||
item.metadata, item.fixedvars) for item in jit_cache]
|
||||
|
||||
input_replace = get_input_replace(jit_cache, input_buffers)
|
||||
if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional, cast, Generator
|
||||
import time, pprint
|
||||
from dataclasses import dataclass, replace
|
||||
from dataclasses import dataclass, replace, field
|
||||
from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
|
||||
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU
|
||||
from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
|
||||
@@ -120,8 +120,9 @@ class ExecItem:
|
||||
prg: Runner
|
||||
bufs: list[Optional[Buffer]]
|
||||
metadata: Optional[tuple[Metadata, ...]] = None
|
||||
fixedvars: dict[Variable, int] = field(default_factory=dict)
|
||||
def run(self, _var_vals:Optional[dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
|
||||
var_vals = {} if _var_vals is None else _var_vals
|
||||
var_vals = self.fixedvars if _var_vals is None else (_var_vals|self.fixedvars)
|
||||
bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
|
||||
et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
|
||||
if do_update_stats:
|
||||
@@ -147,7 +148,8 @@ si_lowerer = PatternMatcher([
|
||||
if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \
|
||||
else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device)), list(ctx))),
|
||||
])
|
||||
def lower_schedule_item(si:ScheduleItem) -> ExecItem: return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata)
|
||||
def lower_schedule_item(si:ScheduleItem) -> ExecItem:
|
||||
return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata, si.fixedvars)
|
||||
|
||||
def lower_schedule(schedule:list[ScheduleItem]) -> Generator[tuple[ScheduleItem, ExecItem], None, None]:
|
||||
while len(schedule):
|
||||
@@ -177,7 +179,7 @@ def run_schedule(schedule:list[ScheduleItem], var_vals:Optional[dict[Variable, i
|
||||
ei.run(var_vals, do_update_stats=do_update_stats)
|
||||
|
||||
# validate the output buffers match (NOTE: this is assuming the output is buffer 0)
|
||||
lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata)).run(var_vals, do_update_stats=do_update_stats)
|
||||
lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata, si.fixedvars)).run(var_vals, do_update_stats=do_update_stats)
|
||||
import numpy as np
|
||||
np.testing.assert_allclose(si.bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
|
||||
else:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from collections import deque, defaultdict
|
||||
from tinygrad.ops import UOp, Variable, Ops, UPat, PatternMatcher, graph_rewrite, buffers
|
||||
from tinygrad.device import Buffer
|
||||
@@ -11,6 +11,7 @@ class ScheduleItem:
|
||||
ast: UOp
|
||||
bufs: tuple[Buffer, ...]
|
||||
metadata: tuple[Metadata, ...] = ()
|
||||
fixedvars: dict[Variable, int] = field(default_factory=dict)
|
||||
|
||||
# **** unbind Variables
|
||||
|
||||
|
||||
Reference in New Issue
Block a user