diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py index a1f309b2e2..8e7eb7746f 100644 --- a/test/models/test_real_world.py +++ b/test/models/test_real_world.py @@ -111,7 +111,7 @@ class TestRealWorld(unittest.TestCase): loss.backward() optimizer.step() - helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 63) + helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 104) @unittest.skipIf(CI and Device.DEFAULT in {"CLANG", "GPU", "LLVM"}, "slow") def test_train_cifar(self): diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 6bcdf13a62..f53f15626a 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -212,7 +212,7 @@ class UOpMetaClass(type): return created # some uops map to other stuff -buffers:weakref.WeakKeyDictionary[UOp, Buffer] = weakref.WeakKeyDictionary() # this maps BUFFER uops to their device Buffers +buffers:Dict[UOp, weakref.ReferenceType[Buffer]] = {} # this maps BUFFER uops to their device Buffers realized:weakref.WeakKeyDictionary[UOp, UOp] = weakref.WeakKeyDictionary() # this maps realized ops to a BUFFER uop forced_realize:weakref.WeakSet[UOp] = weakref.WeakSet() @@ -223,9 +223,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): dtype:DType = dtypes.void src:Tuple[UOp, ...] = tuple() arg:Any = None - def __del__(self): - if self.op is Ops.BUFFER: self.buffer.ref(-1) - del UOpMetaClass.ucache[(self.op, self.dtype, self.src, self.arg)] + def __del__(self): del UOpMetaClass.ucache[(self.op, self.dtype, self.src, self.arg)] def __reduce__(self): return UOp, (self.op, self.dtype, self.src, self.arg) def replace(self, **kwargs) -> UOp: new_args = (kwargs.pop("op", self.op), kwargs.pop("dtype", self.dtype), kwargs.pop("src", self.src), kwargs.pop("arg", self.arg)) @@ -440,14 +438,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass): @property def buffer(self) -> Buffer: if self.base.realized is not None: return self.base.realized - if (ret:=buffers.get(self)) is not None: return ret + if (wret:=buffers.get(self)) is not None and (ret:=wret()) is not None: return ret if self.op is Ops.VIEW: assert unwrap(self.st).contiguous, "VIEW only works here if it's contiguous" return self.src[0].buffer assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}" from tinygrad.device import Buffer - buffers[self] = ret = Buffer(*self.arg[1]) - return ret + buffers[self] = weakref.ref(created:=Buffer(*self.arg[1])) + return created @property def realized(self) -> Optional[Buffer]: return real_buf_uop.buffer if (real_buf_uop:=realized.get(self)) is not None else None @property