Merge remote-tracking branch 'origin/no_uop_mutability' into test_rewrite_map

This commit is contained in:
qazal
2024-12-30 22:42:09 +02:00
4 changed files with 35 additions and 10 deletions

View File

@@ -101,6 +101,10 @@ print(sched[-1].ast)
# run that schedule
run_schedule(sched)
# NOTE: UOps are no longer mutable; you have to fetch the replacement for `out` from the becomes_map
from tinygrad.ops import becomes_map
out = becomes_map[out]
# check the data out
assert out.realized is not None and out.realized.as_buffer().cast('I')[0] == 5

View File

@@ -8,8 +8,6 @@ def is_pattern_uop(u:UOp, pat:UPat): assert pat.match(u, {}), f"{u}\nis not\n{pa
# Tensor-level wrapper: runs the is_pattern_uop check on the tensor's lazydata against pat.
def is_pattern(ten:Tensor, pat:UPat): is_pattern_uop(ten.lazydata, pat)
class TestTensorMutates(unittest.TestCase):
# this fails because uops are mutating
@unittest.expectedFailure
def test_mutate_add(self):
a = Tensor([1,2,3])
b = Tensor([4,5,6])

View File

@@ -231,6 +231,8 @@ buffers:weakref.WeakKeyDictionary[UOp, Buffer] = weakref.WeakKeyDictionary() # t
all_metadata:weakref.WeakKeyDictionary[UOp, Metadata] = weakref.WeakKeyDictionary()
forced_realize:weakref.WeakSet[UOp] = weakref.WeakSet()
becomes_map: weakref.WeakKeyDictionary[UOp, UOp] = weakref.WeakKeyDictionary()
# NOTE: this should be frozen, but frozen is slower
@dataclass(eq=False, slots=True)
class UOp(MathTrait, metaclass=UOpMetaClass):
@@ -471,12 +473,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@property
def forced_realize(self): return self in forced_realize
# *** danger zone ***
# *** less danger zone ***
# CAUTION: MUTABILITY!
def become(self, u:UOp):
del UOpMetaClass.ucache[(self.op, self.dtype, self.src, self.arg)]
self.op, self.dtype, self.src, self.arg = u.op, u.dtype, u.src, u.arg
# Record `u` as the replacement for this UOp in the module-level becomes_map.
# Only the mapping entry is written here; no attribute of self is assigned.
def become(self, u:UOp): becomes_map[self] = u
# *** uop movement ops ***
@@ -1205,7 +1204,9 @@ symbolic_simple = PatternMatcher([
# NOTE: this can be wrong for loaded NaN
(UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
# ** constant folding **
(UPat(GroupOp.ALU, name="a", src=UPat((Ops.VCONST, Ops.CONST))), lambda a: a.const_like(exec_alu(a.op, a.dtype, [x.arg for x in a.src], False))),
# TODO: add const folding for Ops.THREEFRY
(UPat(GroupOp.ALU, name="a", src=UPat((Ops.VCONST, Ops.CONST))),
lambda a: a.const_like(exec_alu(a.op, a.dtype, [x.arg for x in a.src], False)) if a.op is not Ops.THREEFRY else None),
# bool MUL is AND, ADD/MAX is OR. prevents other rules from rewriting bool ADD/MUL incorrectly
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),

View File

@@ -1,6 +1,6 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib, weakref
from contextlib import ContextDecorator
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
@@ -8,12 +8,17 @@ from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_u
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
from tinygrad.multi import MultiLazyBuffer
from tinygrad.gradient import compute_gradient
from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element
from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element, becomes_map
from tinygrad.device import Device, Buffer, BufferSpec
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.memory import memory_planner
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
# *** all in scope Tensors are here. this is the only way to get children ***
# TODO: different "universes" for disconnected Tensors
all_tensors: weakref.WeakSet[Tensor] = weakref.WeakSet()
# **** start with two base classes, Tensor and Function ****
class Function:
@@ -121,6 +126,11 @@ class Tensor(SimpleMathTrait):
training: ClassVar[bool] = False
no_grad: ClassVar[bool] = False
def __new__(cls, *args, **kwargs):
# Create the instance and add it to the module-level all_tensors WeakSet before returning it.
# args/kwargs are accepted but not forwarded to super().__new__.
instance = super().__new__(cls)
all_tensors.add(instance)
return instance
def __init__(self, data:Union[None, ConstType, bytes, List, Tuple, UOp, MultiLazyBuffer, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
if dtype is not None: dtype = to_dtype(dtype)
@@ -217,6 +227,18 @@ class Tensor(SimpleMathTrait):
NOTE: A Tensor can only be scheduled once.
"""
schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]))
# TODO: becomes_map should be returned from create_schedule_with_vars
# NOTE: this is potentially a lot of Tensors. see the TODO next to all_tensors about different "universes"
fixed_tensors: list[Tensor] = list(all_tensors)
sink = UOp.sink(*[UOp.sink(*t.lazydata.lbs) if isinstance(t.lazydata, MultiLazyBuffer) else t.lazydata for t in fixed_tensors])
new_sink = sink.substitute(becomes_map)
becomes_map.clear()
for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src):
if s is ns: continue
if isinstance(t.lazydata, MultiLazyBuffer): t.lazydata.lbs = list(ns.src)
else: t.lazydata = ns
return memory_planner(schedule), var_vals
def schedule(self, *lst:Tensor) -> list[ScheduleItem]: