mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-15 17:40:13 +08:00
fix sync of offset buffers in graphs (#4850)
* correctly sync offset buffers
* test
* style
* run less
* just use base
This commit is contained in:
@@ -7,6 +7,7 @@ from tinygrad.tensor import Tensor
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.helpers import CI
|
||||
from tinygrad.dtype import dtypes
|
||||
|
||||
def _simple_test(add, extract=lambda x: x, N=10):
|
||||
for _ in range(5):
|
||||
@@ -304,6 +305,21 @@ class TestJit(unittest.TestCase):
|
||||
np.testing.assert_allclose(a.numpy(), xc.numpy(), atol=1e-4, rtol=1e-5)
|
||||
np.testing.assert_allclose(b.numpy(), yc.numpy(), atol=1e-4, rtol=1e-5)
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU/CUDA/METAL in CI, fine to run on AMD/NV")
def test_jitted_view(self):
  # Jit a function that row-sums its input, adds 5, bitcasts the result to int32 and moves it
  # to a second device; every one of the 5 calls is compared with the same computation in numpy.
  src_dev = f"{Device.DEFAULT}:0"
  dst_dev = f"{Device.DEFAULT}:1"

  @TinyJit
  def f(a):
    row_sums = a.sum(axis=(1,))
    as_int32 = (row_sums + 5).bitcast(dtypes.int32)
    return as_int32.to(dst_dev).realize()

  for _ in range(5):
    a = Tensor.randn(10, 1000, device=src_dev).realize()
    result = f(a)
    # numpy's .view reinterprets the float bits as int32, mirroring the bitcast inside the jit
    expected = (a.numpy().sum(axis=(1,)) + 5).view(np.int32)
    np.testing.assert_allclose(expected, result.numpy(), atol=1e-4, rtol=1e-5)
|
||||
|
||||
@unittest.skip("Pending multioutput implementation #3607")
|
||||
class TestMultioutputJit(unittest.TestCase):
|
||||
|
||||
@@ -97,12 +97,12 @@ class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
|
||||
wait_nodes = []
|
||||
|
||||
for rawbuf in read + write:
|
||||
if id(rawbuf._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf._buf)])
|
||||
if id(rawbuf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf.base._buf)])
|
||||
for rawbuf in write:
|
||||
if id(rawbuf._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf._buf)))
|
||||
if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf)))
|
||||
|
||||
for rawbuf in read: self.r_dependency_map[id(rawbuf._buf)].append(new_dependency)
|
||||
for rawbuf in write: self.w_dependency_map[id(rawbuf._buf)] = new_dependency
|
||||
for rawbuf in read: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency)
|
||||
for rawbuf in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
|
||||
return list({id(x):x for x in wait_nodes}.values())
|
||||
|
||||
ReturnType = TypeVar('ReturnType')
|
||||
|
||||
Reference in New Issue
Block a user