mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
Merge remote-tracking branch 'origin/no_uop_mutability' into test_rewrite_map
This commit is contained in:
@@ -101,6 +101,10 @@ print(sched[-1].ast)
|
||||
# run that schedule
|
||||
run_schedule(sched)
|
||||
|
||||
# NOTE: UOps are no longer mutable, you have to fetch this from the becomes_map
|
||||
from tinygrad.ops import becomes_map
|
||||
out = becomes_map[out]
|
||||
|
||||
# check the data out
|
||||
assert out.realized is not None and out.realized.as_buffer().cast('I')[0] == 5
|
||||
|
||||
|
||||
@@ -8,8 +8,6 @@ def is_pattern_uop(u:UOp, pat:UPat): assert pat.match(u, {}), f"{u}\nis not\n{pa
|
||||
def is_pattern(ten:Tensor, pat:UPat): is_pattern_uop(ten.lazydata, pat)
|
||||
|
||||
class TestTensorMutates(unittest.TestCase):
|
||||
# this fails because uops are mutating
|
||||
@unittest.expectedFailure
|
||||
def test_mutate_add(self):
|
||||
a = Tensor([1,2,3])
|
||||
b = Tensor([4,5,6])
|
||||
|
||||
@@ -231,6 +231,8 @@ buffers:weakref.WeakKeyDictionary[UOp, Buffer] = weakref.WeakKeyDictionary() # t
|
||||
all_metadata:weakref.WeakKeyDictionary[UOp, Metadata] = weakref.WeakKeyDictionary()
|
||||
forced_realize:weakref.WeakSet[UOp] = weakref.WeakSet()
|
||||
|
||||
becomes_map: weakref.WeakKeyDictionary[UOp, UOp] = weakref.WeakKeyDictionary()
|
||||
|
||||
# NOTE: this should be frozen, but frozen is slower
|
||||
@dataclass(eq=False, slots=True)
|
||||
class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
@@ -471,12 +473,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
@property
|
||||
def forced_realize(self): return self in forced_realize
|
||||
|
||||
# *** danger zone ***
|
||||
# *** less danger zone ***
|
||||
|
||||
# CAUTION: MUTABILITY!
|
||||
def become(self, u:UOp):
|
||||
del UOpMetaClass.ucache[(self.op, self.dtype, self.src, self.arg)]
|
||||
self.op, self.dtype, self.src, self.arg = u.op, u.dtype, u.src, u.arg
|
||||
def become(self, u:UOp): becomes_map[self] = u
|
||||
|
||||
# *** uop movement ops ***
|
||||
|
||||
@@ -1205,7 +1204,9 @@ symbolic_simple = PatternMatcher([
|
||||
# NOTE: this can be wrong for loaded NaN
|
||||
(UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
|
||||
# ** constant folding **
|
||||
(UPat(GroupOp.ALU, name="a", src=UPat((Ops.VCONST, Ops.CONST))), lambda a: a.const_like(exec_alu(a.op, a.dtype, [x.arg for x in a.src], False))),
|
||||
# TODO: add const folding for Ops.THREEFRY
|
||||
(UPat(GroupOp.ALU, name="a", src=UPat((Ops.VCONST, Ops.CONST))),
|
||||
lambda a: a.const_like(exec_alu(a.op, a.dtype, [x.arg for x in a.src], False)) if a.op is not Ops.THREEFRY else None),
|
||||
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
|
||||
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
|
||||
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
|
||||
from __future__ import annotations
|
||||
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib
|
||||
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib, weakref
|
||||
from contextlib import ContextDecorator
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
|
||||
@@ -8,12 +8,17 @@ from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_u
|
||||
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
|
||||
from tinygrad.multi import MultiLazyBuffer
|
||||
from tinygrad.gradient import compute_gradient
|
||||
from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element
|
||||
from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element, becomes_map
|
||||
from tinygrad.device import Device, Buffer, BufferSpec
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from tinygrad.engine.memory import memory_planner
|
||||
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
|
||||
|
||||
# *** all in scope Tensors are here. this is the only way to get children ***
|
||||
# TODO: different "universes" for disconnected Tensors
|
||||
|
||||
all_tensors: weakref.WeakSet[Tensor] = weakref.WeakSet()
|
||||
|
||||
# **** start with two base classes, Tensor and Function ****
|
||||
|
||||
class Function:
|
||||
@@ -121,6 +126,11 @@ class Tensor(SimpleMathTrait):
|
||||
training: ClassVar[bool] = False
|
||||
no_grad: ClassVar[bool] = False
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
instance = super().__new__(cls)
|
||||
all_tensors.add(instance)
|
||||
return instance
|
||||
|
||||
def __init__(self, data:Union[None, ConstType, bytes, List, Tuple, UOp, MultiLazyBuffer, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
|
||||
device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
|
||||
if dtype is not None: dtype = to_dtype(dtype)
|
||||
@@ -217,6 +227,18 @@ class Tensor(SimpleMathTrait):
|
||||
NOTE: A Tensor can only be scheduled once.
|
||||
"""
|
||||
schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]))
|
||||
# TODO: becomes_map should be returned from create_schedule_with_vars
|
||||
|
||||
# NOTE: this is potentially a lot of Tensors. see above about the universes
|
||||
fixed_tensors: list[Tensor] = list(all_tensors)
|
||||
sink = UOp.sink(*[UOp.sink(*t.lazydata.lbs) if isinstance(t.lazydata, MultiLazyBuffer) else t.lazydata for t in fixed_tensors])
|
||||
new_sink = sink.substitute(becomes_map)
|
||||
becomes_map.clear()
|
||||
for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src):
|
||||
if s is ns: continue
|
||||
if isinstance(t.lazydata, MultiLazyBuffer): t.lazydata.lbs = list(ns.src)
|
||||
else: t.lazydata = ns
|
||||
|
||||
return memory_planner(schedule), var_vals
|
||||
|
||||
def schedule(self, *lst:Tensor) -> list[ScheduleItem]:
|
||||
|
||||
Reference in New Issue
Block a user