From dedbd970aaf6770edf6060cefec7314584ce4036 Mon Sep 17 00:00:00 2001 From: Rayan Hatout Date: Mon, 26 Jun 2023 21:55:42 +0100 Subject: [PATCH] Optimizations in lazy.py (#987) * optimizations in lazy.py * make mypy happy with stubs and fix the graph import hack * merge conflict in helpers.py --- test/test_custom_function.py | 3 +- test/test_tensor.py | 3 +- tinygrad/codegen/linearizer.py | 33 ++- tinygrad/graph.py | 12 +- tinygrad/helpers.py | 63 +++++- tinygrad/lazy.py | 367 ++++++++++++++++++--------------- tinygrad/mlops.py | 55 +++-- tinygrad/ops.py | 91 +++++--- tinygrad/runtime/lib.py | 4 + 9 files changed, 387 insertions(+), 244 deletions(-) diff --git a/test/test_custom_function.py b/test/test_custom_function.py index b3cc69c7bf..2583ff79c8 100644 --- a/test/test_custom_function.py +++ b/test/test_custom_function.py @@ -10,6 +10,7 @@ from tinygrad.helpers import prod, dtypes # `atan2_gpu` for GPUBuffers and `atan2_cpu` for CPUBuffers from tinygrad.lazy import LazyBuffer, create_lazybuffer, Device from tinygrad.ops import ASTRunner +from tinygrad.shape.shapetracker import ShapeTracker # we don't always have GPU support, so the type signature is the abstract CompiledBuffer instead of GPUBuffer def atan2_gpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer): @@ -39,7 +40,7 @@ class ATan2(Function): assert prod(a.shape) == prod(b.shape) and a.device == b.device, "shape or device mismatch" self.a, self.b = a, b ast = LazyOp(LoadOps.CUSTOM, (a.contiguous(), b.contiguous()), {"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device]) - return create_lazybuffer(a.device, a.shape, LoadOps, ast, max(a.dtype, b.dtype)) + return create_lazybuffer(a.device, ShapeTracker(a.shape), LoadOps, ast, max(a.dtype, b.dtype)) def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: denom = (self.a.binary_op(BinaryOps.MUL, self.a)).binary_op(BinaryOps.ADD, self.b.binary_op(BinaryOps.MUL, self.b)) return grad_output.binary_op(BinaryOps.MUL, self.b.binary_op(BinaryOps.DIV, denom)) if self.needs_input_grad[0] else None, \ diff --git a/test/test_tensor.py b/test/test_tensor.py index 7e9e997348..5ab3186b65 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -186,8 +186,7 @@ class TestTinygrad(unittest.TestCase): # assert Tensor.randn(1,0,2,5) == 0 # TODO: fix empty tensors def test_element_size(self): - for f in dataclasses.fields(dtypes): - dtype = f.default + for _, dtype in dtypes.fields().items(): assert dtype.itemsize == Tensor.randn(3, dtype=dtype).element_size(), f"Tensor.element_size() not matching Tensor.dtype.itemsize for {dtype}" if __name__ == '__main__': diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index fbec0029c3..0283cecc40 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -4,13 +4,12 @@ from collections import defaultdict from enum import Enum, auto from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mnum, DType, all_same -from tinygrad.ops import LazyOp, get_lazyops, get_buffers, FlopCounter, get_lazyop_info, map_buffers, UnaryOps +from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps from tinygrad.lazy import LazyBuffer from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, FusedOps from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape from tinygrad.shape.symbolic import Variable -# bottom ones are asm only class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); CAST = auto(); \ SPECIAL = auto(); DEFINE_REGISTER = auto(); LABEL = auto(); COND_BRANCH = auto() # noqa: E702 @@ -94,27 +93,27 @@ class Linearizer: self.ast = ast.src[0] if ast.op == MovementOps.RESHAPE else ast # get the output buffers - self.bufs = [output_buffer] + dedup(get_buffers(ast)) + self.bufs = [output_buffer] + dedup(ast.buffers) # key for lookup in cache (can change, str might not be right) # bufs are needed because kernels like f(x) = x + x and f(x, y) = x + y have the same str(ast), but are different kernels. # mapping the buffers to integers is required because a-b != b-a (and how would you tell a and b apart?) - self.key = f"ASTKernelKey ast={str(map_buffers({x:i for i,x in enumerate(self.bufs)}, ast))} bufs={self.bufs}" + self.key = (ast.map_buffers({x:i for i,x in enumerate(self.bufs)}).key, tuple([x.key for x in self.bufs])) def process(self) -> None: if hasattr(self, "sts"): return # already processed # fetch lazyop info - self.info: FlopCounter = get_lazyop_info(self.ast) + self.info: FlopCounter = get_lazyop_info(cast(LazyOp, self.ast)) self.mem_estimate: int = sum(x.dtype.itemsize*(x.realized.size if x.realized is not None else prod(x.shape)) for x in self.bufs if x is not None) # there's only allowed to be one reduceop - reduceops = [x for x in get_lazyops(self.ast) if x.op in ReduceOps] + reduceops = [x for x in self.ast.get_lazyops() if x.op in ReduceOps] assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast" self.reduceop = reduceops[0] if reduceops else None # get earlybufs, before the one reduce op - self.earlybufs = dedup(get_buffers(self.reduceop)) if self.reduceop else [] + self.earlybufs = dedup(self.reduceop.buffers) if self.reduceop else [] # create new shapetrackers inside this kernel, we will permute them self.sts: List[ShapeTracker] = [x.st.copy() for x in self.bufs] @@ -178,7 +177,7 @@ class Linearizer: for k,out_tokens in self._group_float4(i, load_offset).items(): idxs = [x[2]-out_tokens[0][2] for x in out_tokens] valids_okay = all_same([x[3] for x in out_tokens]) or (all_same([x[3]//4 for x in out_tokens]) and (out_tokens[0][3]//4)*4 == out_tokens[0][3]) - if any(idx.min != idx.max or idx.min != val for idx,val in zip(idxs, range(4))) or (out_tokens[0][2]//4)*4 != out_tokens[0][2] or not valids_okay: + if any([idx.min != idx.max or idx.min != val for idx,val in zip(idxs, range(4))]) or (out_tokens[0][2]//4)*4 != out_tokens[0][2] or not valids_okay: # idxs not in order, valids don't match, or idx doesn't evenly divide 4. use normal float for x in out_tokens: load_offset_new[x[1]] = x else: @@ -306,13 +305,13 @@ class Linearizer: loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, end_local_idxs+fake_reduce_idxs) # there's no AST here (and there's no shape for the reduce LazyOp) - self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.acc_offsets(-1)], loaded_buffers, ssa, do_reduce=True) + self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.acc_offsets(-1)], loaded_buffers, ssa, do_reduce=True) # type: ignore # end the late reduce loop self.uop(UOps.ENDLOOP, None, [], (end_local_idxs, "late_reduce")) # load latebufs - loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and not isinstance(b, LocalBuffer)}) + loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer}) # run late AST val = self.ast_parse(self.ast, acc, loaded_buffers, ssa) @@ -334,17 +333,17 @@ class Linearizer: return out def ast_parse(self, x, acc, loaded_buffers, ssa, do_reduce=False) -> List[Token]: - if not isinstance(x, LazyOp): return loaded_buffers[x] + if x.__class__ is not LazyOp: return loaded_buffers[x] if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, loaded_buffers, ssa) # cast isn't an ALU op if x.op in ReduceOps and not do_reduce: return acc # MULACC fusion. TODO: this is copied from Interpreted - if x.op == ReduceOps.SUM and isinstance(x.src[0], LazyOp) and x.src[0].op == BinaryOps.MUL: + if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == BinaryOps.MUL: x = LazyOp(FusedOps.MULACC, x.src[0].src, x.arg) - if x.op == ReduceOps.SUM and isinstance(x.src[0], LazyOp) and x.src[0].op == UnaryOps.CAST and isinstance(x.src[0].src[0], LazyOp) and x.src[0].src[0].op == BinaryOps.MUL: + if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL: x = LazyOp(FusedOps.MULACC, x.src[0].src[0].src, x.arg) values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src] # TODO: fold float4 into a single uop when possible. - if isinstance(x.op, (ReduceOps, FusedOps)): + if x.op.__class__ in {ReduceOps, FusedOps}: ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.supports_float4_alu)] else: ret = [(idx, self.uop(UOps.ALU, ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.supports_float4_alu and x.op!=BinaryOps.CMPEQ)] @@ -431,7 +430,7 @@ class Linearizer: # remove places where the shape is all ones # TODO: this should be factored in to multi shape stride if self.shape_len == 0: return - all_ones = [all(st.shape[i]==1 for st in self.sts) for i in range(self.shape_len)] + all_ones = [all([st.shape[i]==1 for st in self.sts]) for i in range(self.shape_len)] # keep at least 1 one if all(all_ones): all_ones[-1] = False self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None) @@ -456,14 +455,14 @@ class Linearizer: else: rets[j].append((shapes[j][i], strides[j][i])) # do the reshapes - for i,x in enumerate(rets): self.sts[i].reshape(tuple(y[0] for y in x)) + for i,x in enumerate(rets): self.sts[i].reshape(tuple([y[0] for y in x])) # ******************** GPU simplifiers ******************** def required_optimizations(self, early_only=False): for buf_index,buf in enumerate(self.bufs): unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes() if self.sts[buf_index].shape[i]%4 == 0] - if (not early_only or buf in self.earlybufs) and isinstance(self.bufs[buf_index].dtype, ImageDType): + if (not early_only or buf in self.earlybufs) and self.bufs[buf_index].dtype.__class__ is ImageDType: assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}" if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: self.shift_to(unit_stride_axes_mul_4[0], 4) diff --git a/tinygrad/graph.py b/tinygrad/graph.py index 91f670bbda..37d24e633e 100644 --- a/tinygrad/graph.py +++ b/tinygrad/graph.py @@ -5,13 +5,11 @@ except ImportError: nx = None # graph won't work from collections import defaultdict from typing import Dict, List, Optional -from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps, Op, OpType, LazyOp, get_buffers, get_lazyops -from tinygrad.lazy import LazyBuffer -from tinygrad.helpers import getenv, DEBUG, GlobalCounters +from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps, Op, OpType, LazyOp +from tinygrad.tensor import LazyBuffer +from tinygrad.helpers import GRAPH, GRAPHPATH, PRUNEGRAPH, DEBUG, GlobalCounters from tinygrad.runtime.lib import RawConst -GRAPH, PRUNEGRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("PRUNEGRAPH", 0), getenv("GRAPHPATH", "/tmp/net") - # **** debugging and graphing **** G = nx.DiGraph() if nx is not None else None @@ -52,8 +50,8 @@ def str_dtype(dtyp): def log_op(ret: LazyBuffer, ast: LazyOp, show_graph: Optional[bool] = None, phantom=False): if show_graph is None: show_graph = bool(GRAPH) if not DEBUG and not show_graph: return - op: List[Op] = [x.op for x in get_lazyops(ast)] - inp: List[LazyBuffer] = [x for x in get_buffers(ast) if not isinstance(x.realized, RawConst) or GRAPH > 1] + op: List[Op] = [x.op for x in ast.get_lazyops()] + inp: List[LazyBuffer] = [x for x in ast.buffers if not isinstance(x.realized, RawConst) or GRAPH > 1] oporder = [LoadOps, FusedOps, ReduceOps, BinaryOps, UnaryOps, MovementOps] optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0]) cnts[optype] += 1 diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 409ca72f4c..7fca7a83dc 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -1,18 +1,23 @@ from __future__ import annotations -import platform -from dataclasses import dataclass, asdict -import os, math, functools, time, re +import os, functools, platform, time, re +from weakref import KeyedRef, ref +from _weakref import _remove_dead_weakref # type: ignore import numpy as np -from typing import Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any +from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any +from math import prod # noqa: F401 # pylint:disable=unused-import + ShapeType = Tuple[int, ...] # NOTE: helpers is not allowed to import from anything else in tinygrad OSX = platform.system() == "Darwin" def dedup(x): return list(dict.fromkeys(x)) # retains list order -def prod(x:Union[List[int], Tuple[int, ...]]) -> int: return math.prod(x) -def argfix(*x): return tuple() if len(x) == 0 else tuple(x[0]) if isinstance(x[0], (tuple, list)) else tuple(x) +def argfix(*x): + if x[0].__class__ in {tuple, list}: + try: return tuple(x[0]) + except IndexError: return tuple() + return tuple(x) def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python -def all_same(items): return all(x == items[0] for x in items) if len(items) > 0 else True +def all_same(items): return all([x == items[0] for x in items]) if len(items) > 1 else True def colored(st, color, background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line def ansilen(s): return len(re.sub('\x1b\\[(K|.*?m)', '', s)) def partition(lst, fxn): return [x for x in lst if fxn(x)], [x for x in lst if not fxn(x)] @@ -43,6 +48,7 @@ class ContextVar: def value(self): return ContextVar.ctx_stack[-1][self.key] if self.key in ContextVar.ctx_stack[-1] else self.initial_value DEBUG, IMAGE = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0) +GRAPH, PRUNEGRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("PRUNEGRAPH", 0), getenv("GRAPHPATH", "/tmp/net") class Timing(object): def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled @@ -60,6 +66,8 @@ class DType(NamedTuple): np: Optional[type] # TODO: someday this will be removed with the "remove numpy" project sz: int = 1 def __repr__(self): return f"dtypes.{self.name}" + @property + def key(self): return (self.name) # dependent typing? class ImageDType(DType): @@ -70,7 +78,6 @@ class ImageDType(DType): super().__init__() def __repr__(self): return f"dtypes.{self.name}({self.shape})" -@dataclass class dtypes: @staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.uint8, dtypes.int32, dtypes.int64) @@ -79,7 +86,9 @@ class dtypes: @staticmethod def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint32, dtypes.uint64) @staticmethod - def from_np(x) -> DType: return asdict(dtypes())[np.dtype(x).name] + def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name] + @staticmethod + def fields() -> Dict[str, DType]: return DTYPES_DICT bool: Final[DType] = DType(0, 1, "bool", bool) float16: Final[DType] = DType(0, 2, "half", np.float16) half = float16 @@ -97,6 +106,9 @@ class dtypes: _half4: Final[DType] = DType(0, 2*4, "half4", None, 4) _float4: Final[DType] = DType(4, 4*4, "float4", None, 4) +# HACK: staticmethods are not callable in 3.8 so we have to compare the class +DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod} + class GlobalCounters: global_ops: ClassVar[int] = 0 global_mem: ClassVar[int] = 0 @@ -106,3 +118,36 @@ class GlobalCounters: cache: ClassVar[Optional[List[Tuple[Callable, Any]]]] = None @staticmethod def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0.0,0,None + +# Stripped down version of a WeakSet +class LightWeakSet: + __slots__ = 'data', '_remove', '__weakref__' + def __init__(self): + self.data = set() + def _remove(item, selfref=ref(self)): + self = selfref() + if self: self.data.discard(item) + self._remove = _remove + + def __len__(self): return len(self.data) + def add(self, item): self.data.add(ref(item, self._remove)) + def discard(self, item): self.data.discard(ref(item)) + +# Stripped down version of a WeakValueDictionary +class LightWeakValueDictionary: + __slots__ = 'data', '_remove', '__weakref__' + def __init__(self): + def remove(wr, selfref=ref(self), _atomic_removal=_remove_dead_weakref): + self = selfref() + if self: _atomic_removal(self.data, wr.key) + self._remove = remove + self.data = {} + + def __getitem__(self, key): + o = self.data[key]() + if o is None: raise KeyError(key) + else: return o + + def __setitem__(self, key, value): self.data[key] = KeyedRef(value, self._remove, key) + + def __contains__(self, key): return key in self.data \ No newline at end of file diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 359ed7ec06..86de3bca77 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -1,14 +1,16 @@ from __future__ import annotations -from typing import Optional, Tuple, Union, List, Dict, Any, cast +import operator +from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast import sys, importlib, inspect, functools, pathlib +from weakref import ref + import numpy as np -from weakref import WeakValueDictionary, ref, WeakSet -from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, ImageDType, DEBUG -from tinygrad.shape.shapetracker import ShapeTracker, get_contraction -from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, get_lazyops, get_buffers, map_buffers -from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped +from tinygrad.helpers import GRAPH, prod, getenv, DType, dtypes, flatten, ImageDType, LightWeakSet, LightWeakValueDictionary from tinygrad.runtime.ops_cpu import RawNumpyBuffer from tinygrad.runtime.ops_disk import RawDiskBuffer +from tinygrad.shape.shapetracker import MovementOps, ShapeTracker, get_contraction +from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, ReduceOps, LoadOps, OpType, LazyOp +from tinygrad.runtime.lib import RawBufferMapped, RawConst, RawBuffer # lazy can recurse a lot sys.setrecursionlimit(10000) @@ -25,23 +27,23 @@ PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT>=3, OPT>=3 def _ast_reduceops(self:LazyBuffer) -> LazyOp: # TODO: this can also corealize a binary op after the reduce, not just before src = self.op.src[0] - if MERGE_ELEMENTWISE_INTO_REDUCE and src.realized is None and src.optype == BinaryOps and len(src.children) <= 1: - src = src.op - return LazyOp(self.op.op, (src,), self.op.arg) + if MERGE_ELEMENTWISE_INTO_REDUCE and not src.realized and src.optype == BinaryOps and len(src.children) <= 1: + src = src.op # type: ignore + return LazyOp(self.op.op, (src,), self.op.arg, src.get_buffers()) # this supports late merging an upstream Reduce op and even an Elementwise op above that def _ast_binaryops(self:LazyBuffer) -> LazyOp: - real_srcs: Dict[LazyBuffer, Union[None, LazyOp, LazyBuffer]] = {x:None for x in get_buffers(self.op)} + real_srcs: Dict[LazyBuffer, Union[None, LazyOp, LazyBuffer]] = {x:None for x in self.op.buffers} # NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable but requires more thought how # TODO: this can also support late fusion of BinaryOps, required for test_fold_conv_sgd - psrcs: List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and x.realized is None and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1] + psrcs: List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and not x.realized and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1] intermediate_shape: Tuple[int, ...] = self.shape - if len(psrcs) >= 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE: + if MERGE_ONE_REDUCE_INTO_ELEMENTWISE and len(psrcs) >= 1: psrc = psrcs[0] # NOTE: right now we can't handle multiple, as we'd have to check for loop if psrc[1].optype == ReduceOps: top = _ast_reduceops(psrc[1]) real_srcs[psrc[0]] = top - real_srcs.update({x:x for x in get_buffers(top)}) # the reduce op buffers are not modified + real_srcs.update({x:x for x in top.buffers}) # the reduce op buffers are not modified # if the ReduceOp is followed by a reshape, we push this reshape before all the ElementwiseOp inputs if psrc[0].shape != psrc[1].shape: @@ -51,124 +53,91 @@ def _ast_binaryops(self:LazyBuffer) -> LazyOp: # reshape all the late ops into the output shape # NOTE: these RESHAPEs will return self if they don't change the shape for x in real_srcs.keys(): - if real_srcs[x] is None: real_srcs[x] = x.movement_op(MovementOps.RESHAPE, intermediate_shape) - ast = map_buffers(real_srcs, self.op) - return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape) if intermediate_shape != self.shape else ast + if not real_srcs[x]: real_srcs[x] = x.reshape_op(intermediate_shape) + ast = self.op.map_buffers(real_srcs) + return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape, ast.buffers) if intermediate_shape != self.shape else ast # **** lazy operations **** -def get_weakop(op:LazyOp) -> LazyOp: return LazyOp(op.op, tuple([get_weakop(x) if x.__class__ is LazyOp else ref(x) for x in op.src]), op.arg) -def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(root.op.src[0]) if getattr(root, 'op', None) and len(root.op.src) == 1 else root -def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(root.op.src[0], allow_contiguous) if root.realized is None and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root -def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(x.op.src[0]) if x.realized is None and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x) -def replace_with_movement_ops(y:Union[LazyOp, LazyBuffer], ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> LazyBuffer: - if isinstance(y, LazyBuffer): - for op, arg in ops: y = y.movement_op(op, arg) - return y - assert y.op in BinaryOps or y.op in UnaryOps - return elementwise_op(y.op, *[replace_with_movement_ops(z, ops) for z in y.src], arg=y.arg) # type: ignore +def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(cast(LazyBuffer, root.op.src[0])) if getattr(root, 'op', None) and len(root.op.src) == 1 else root +def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root +def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x) -lazycache: WeakValueDictionary[Tuple[str, DType, OpType, LazyOp], LazyBuffer] = WeakValueDictionary() -def create_lazybuffer(device:str, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp, dtype:DType): - st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape)) +lazycache: LightWeakValueDictionary = LightWeakValueDictionary() +def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType): # fromcpu aren't cached - if optype == LoadOps and op.op in [LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST]: return LazyBuffer(device, st, optype, op, dtype) + if optype == LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}: return LazyBuffer(device, st, optype, op, dtype) #print("create_lazybuffer", device, shape, optype, op, dtype) # NOTE: shape should be deterministic. annoying to cache with the ShapeTracker # get_weakop makes all the LazyBuffers in the op have a weakref - wop = (device, dtype, optype, get_weakop(op)) + wop = (device, dtype, optype, ref(op)) - if wop not in lazycache: lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype) - else: ret = lazycache[wop] + if wop in lazycache: return lazycache[wop] + lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype) return ret class LazyBuffer: + __slots__ = 'st', 'device', 'shape', 'optype', 'dtype', 'op', 'realized', 'output_buffer', 'children', 'node_id', '__weakref__' __deletable__ = ('op',) - def __init__(self, device:str, st:ShapeTracker, optype:OpType, src:Union[LazyOp, RawBuffer], dtype:DType): + def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, src:Optional[RawBuffer]=None): self.st = st # NOTE: this is not a copy! this should be a "read-only" ShapeTracker self.device, self.shape, self.optype, self.dtype = device, self.st.shape, optype, dtype - self.realized: Optional[RawBuffer] = src if isinstance(src, RawBuffer) else None + self.realized: Optional[RawBuffer] = src self.output_buffer: Optional[RawBuffer] = None # TODO: do we really need this? or can we just use realized # TODO: does children have to be a ref count instead of a set? can a Buffer be a double child? - self.children: WeakSet[LazyBuffer] = WeakSet() + self.children: LightWeakSet = LightWeakSet() # NOTE: op should be read only after construction of LazyBuffer - if isinstance(src, LazyOp): - self.op: LazyOp = src - for x in get_buffers(self.op): x.children.add(self) + if op: + self.op: LazyOp = op + for x in op.buffers: x.children.add(self) if not LAZY: self.realize() # log phantom ops to the graph - from tinygrad.graph import log_op, GRAPH - if GRAPH >= 3: log_op(self, self.op, phantom=True) + if GRAPH >= 3: + from tinygrad.graph import log_op + log_op(self, self.op, phantom=True) + + def __repr__(self): return f"" + @property + def key(self): + if self.realized: return (self.dtype.key, self.realized.key, self.st.key) + return (self.dtype.key, self.op.op, self.st.key) - def __repr__(self): return f"" def _device_extra_args(self) -> Dict[str, str]: return {"device": self.device.split(":", 1)[1]} if ":" in self.device else {} def realize(self:LazyBuffer) -> LazyBuffer: - if self.realized is None: + if not self.realized: # get real ops first - if self.op.op == LoadOps.CONTIGUOUS: - realized = self.op.src[0].realize().realized - if self.op.src[0].st.contiguous and not isinstance(realized, RawConst) and realized.size == prod(self.shape): - # no need to run an AST, this is already contiguous - self.realized = realized - else: - # TODO: remove UnaryOps.NOOP, replace with LoadOps.CONTIGUOUS. confusing with Compiled though - self.op = LazyOp(UnaryOps.NOOP, self.op.src) - elif self.op.op == LoadOps.CUSTOM: - # this needs to immediately realize - self.realized = self.op.arg(self, *[x.realize() for x in self.op.src]) - elif self.op.op == LoadOps.FROM: - rawbuf = self.op.src[0].realize() - # TODO: make this generic - if isinstance(rawbuf.realized, RawDiskBuffer) and issubclass(Device[self.device].buffer, RawBufferMapped): - self.realized = Device[self.device].buffer(prod(self.shape), self.dtype, **self._device_extra_args()) - rawbuf.realized.readinto(cast(RawBufferMapped, self.realized)._buffer()) - else: - self.realized = Device[self.device].buffer.fromCPU(rawbuf.toCPU(), **self._device_extra_args()) - elif self.optype == LoadOps: - if DEBUG >= 4: print(f"{self.op.op} {self.shape} {self.dtype} {self.op.arg}") - if self.op.op == LoadOps.EMPTY: - self.realized = Device[self.device].buffer(prod(self.shape), self.dtype, **self._device_extra_args()) - elif self.op.op == LoadOps.RAND: - rng = np.random.default_rng(self.op.arg) - assert self.dtype.np is not None, "internal dtypes don't work with LoadOps.RAND" - self.realized = Device[self.device].buffer.fromCPU(rng.random(size=self.shape, dtype=self.dtype.np), **self._device_extra_args()) - elif self.op.op == LoadOps.CONST: - if hasattr(Device[self.device].codegen, 'supports_constant_folding'): - self.realized = RawConst(1, self.dtype, float(self.op.arg)) - else: - self.realized = Device[self.device].buffer.fromCPU(np.array(self.op.arg, dtype=self.dtype.np), **self._device_extra_args()) - # these can be late folded and change the op to go further back in the graph - elif self.optype == ReduceOps: self.op = _ast_reduceops(self) - elif self.optype == BinaryOps: self.op = _ast_binaryops(self) # ISSUE: this can include a reshape - + if self.optype in REALIZE_DISPATCHER: + self.op = REALIZE_DISPATCHER[self.optype](self) + elif self.op.op in REALIZE_DISPATCHER: + REALIZE_DISPATCHER[self.op.op](self) # run the ast if we still have to, and log the op - if self.realized is None: - for x in get_buffers(self.op): x.realize() + if not self.realized: + for x in self.op.buffers: x.realize() # HACK: image shape can be wrong, hot cast it back to a normal float - if self.optype != MovementOps and isinstance(self.dtype, ImageDType) and (prod(self.shape) != prod(self.dtype.shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())): + if self.optype != MovementOps and self.dtype.__class__ is ImageDType and (prod(self.shape) != prod(cast(ImageDType, self.dtype).shape) or not any([self.shape[x]%4 == 0 for x in self.st.unit_stride_axes()])): if self.op.op == MovementOps.RESHAPE: # put CAST before the final RESHAPE self.op = LazyOp(MovementOps.RESHAPE, (LazyOp(UnaryOps.CAST, self.op.src, dtypes.float32),), self.op.arg) else: self.op = LazyOp(UnaryOps.CAST, (self.op,), dtypes.float32) self.dtype = dtypes.float32 - self.realized = Device[self.device].exec_ast(self.op, output=self, **self._device_extra_args()) - assert isinstance(self.realized, (RawConst, Device[self.device].buffer)), f"device mismatch on realized got {type(self.realized)} expected {self.device}" + assert self.realized and isinstance(self.realized, (RawConst, Device[self.device].buffer)), f"device mismatch on realized got {type(self.realized)} expected {self.device}" # HACK: allow hot casting of images assert self.realized.dtype == self.dtype or self.dtype.name.startswith("image"), f"dtype mismatch on realize got {self.realized.dtype} expected {self.dtype}" self.dtype = self.realized.dtype # log to the graph - from tinygrad.graph import log_op, GRAPH - if not isinstance(self.realized, RawConst) or GRAPH >= 2: log_op(self, self.op) + if self.realized.__class__ is not RawConst or GRAPH >= 2: + from tinygrad.graph import log_op + log_op(self, self.op) # no need to keep the op after realization del self.op @@ -176,17 +145,17 @@ class LazyBuffer: @staticmethod def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer: - return create_lazybuffer(device, shape, LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype) + return create_lazybuffer(device, ShapeTracker(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype) @staticmethod def fromCPU(x: np.ndarray) -> LazyBuffer: - return LazyBuffer("CPU", ShapeTracker(x.shape), LoadOps, RawNumpyBuffer.fromCPU(x), dtypes.from_np(x.dtype)) + return LazyBuffer("CPU", ShapeTracker(x.shape), LoadOps, LazyOp(LoadOps.EMPTY, (), None, ()), dtypes.from_np(x.dtype), RawNumpyBuffer.fromCPU(x)) # create a constant with the shape and dtype of self def const_like(self, val) -> LazyBuffer: # NOTE: dtypes.from_np(self.dtype.np) to deal with image types return self.loadop(LoadOps.CONST, tuple(), dtypes.from_np(self.dtype.np), self.device, arg=val) \ - .movement_op(MovementOps.RESHAPE, (1,)*len(self.shape)).movement_op(MovementOps.EXPAND, self.shape) + .reshape_op((1,)*len(self.shape)).expand_op(self.shape) # NOTE: we also have to copy the numpy array on the way out...otherwise the underlying Tensor could be freed and use after free. improve this? def toCPU(self): @@ -198,89 +167,99 @@ class LazyBuffer: def unary_op(self:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, self) def binary_op(self:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y) def contiguous(self:LazyBuffer) -> LazyBuffer: - if self.realized is None and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one - return create_lazybuffer(self.device, self.shape, LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,)), self.dtype) - - def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer: - if self.shape == tuple(new_shape): return self - srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,) - return create_lazybuffer(self.device, new_shape, ReduceOps, LazyOp(op, tuple(srcs), new_shape), self.dtype) - - # shrink -> stride -> permute -> reshape -> pad -> expand - def movement_op(self:LazyBuffer, op:MovementOps, arg:Tuple[Any, ...]) -> LazyBuffer: - # very instant nop - if op == MovementOps.RESHAPE and self.shape == arg: return self - - # TODO: look into why that copy is needed - local_st = ShapeTracker(self.shape).movement_op(op, arg) - - # instant nops - if local_st.contiguous and self.shape == local_st.shape: return self - - # two ops in a row is one op. merge them if unresolved - if self.realized is None and self.op.op == op: - # TODO: why is deleting self from children needed? shouldn't GC do it? - self.op.src[0].children.discard(self) - if op in [MovementOps.RESHAPE, MovementOps.EXPAND]: return self.op.src[0].movement_op(op, arg) - if op == MovementOps.SHRINK: return self.op.src[0].movement_op(op, tuple((b1+b2, b1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg))) - if op == MovementOps.PERMUTE: return self.op.src[0].movement_op(op, tuple(self.op.arg[i] for i in arg)) - if op == MovementOps.PAD: return self.op.src[0].movement_op(op, tuple((b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg))) - if op == MovementOps.STRIDE: return self.op.src[0].movement_op(op, tuple(i*j for i,j in zip(arg, self.op.arg))) - - # push permutes before reduce ops - if op == MovementOps.PERMUTE and PUSH_PERMUTES and self.realized is None and self.optype == ReduceOps: - # reduceops have one buffer input, permute it - narg = tuple(self.op.arg[arg[i]] for i in range(len(arg))) - src, rop = self.op.src[0], self.op.op - src.children.discard(self) - del self # TODO: why doesn't this delete remove it from the children - return src.movement_op(op, arg).reduce_op(rop, narg) - - # some permutes are actually just reshapes - if op == MovementOps.PERMUTE and local_st.contiguous: return self.movement_op(MovementOps.RESHAPE, tuple(self.shape[i] for i in arg)) - - # move permutes before expands (always, this is safe) - if op == MovementOps.PERMUTE and self.realized is None and self.op.op == MovementOps.EXPAND: - self.op.src[0].children.discard(self) - return self.op.src[0].movement_op(MovementOps.PERMUTE, arg).movement_op(MovementOps.EXPAND, tuple(self.op.arg[a] for a in arg)) - - # move permutes before reshapes if we can - if op == MovementOps.PERMUTE and PUSH_PERMUTES and self.realized is None and self.op.op == MovementOps.RESHAPE and isinstance(self.op.src[0], LazyBuffer): - if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape): - self.op.src[0].children.discard(self) # this changes nothing? - return self.op.src[0].movement_op(MovementOps.PERMUTE, tuple(flatten(shape_idx_groups[i] for i in arg))) \ - .movement_op(MovementOps.RESHAPE, ShapeTracker(self.st).movement_op(op, arg).shape) - - # if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead. NOTE: UnaryOps is never an OpType - if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and self.realized is None and (op in [MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE] or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and len(self.children) == 0: # and op != MovementOps.EXPAND and (op != MovementOps.PAD or (SHUFFLE_PAD_OPS and all(x.op != BinaryOps.DIV for x in get_lazyops(self.op)))): - return replace_with_movement_ops(self.op, [(op, arg)]) - - # create the buffer - ret = create_lazybuffer(self.device, ShapeTracker(self.st).movement_op(op, arg), MovementOps, LazyOp(op, (self,), arg), self.dtype) - - # if the ShapeTracker becomes contiguous, replace the whole thing with a reshape (or nothing if shapes match) - # NOTE: if ret is in the cache, it can already be realized - if REMOVE_MOVEMENT_NOPS and ret.realized is None and self.realized is None and ret.st.contiguous: + if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one + return create_lazybuffer(self.device, ShapeTracker(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None, (self,)), self.dtype) + + def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[int, ...], Tuple[Tuple[int, int], ...]]) -> LazyBuffer: + if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and len(self.children) == 0: + return self.op.replace_with_movement_ops([(op, arg)]) + ret = create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg, (self,)), self.dtype) + if REMOVE_MOVEMENT_NOPS and not self.realized and not ret.realized and ret.st.contiguous: # MovementOps aren't stacked any more, they each have one parent, find the root root = get_movementroot(self) if root.st.contiguous and root != self and prod(ret.st.shape) == prod(root.shape): - return root.movement_op(MovementOps.RESHAPE, ret.st.shape) - + return root.reshape_op(ret.st.shape) return ret + + def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer: + if self.shape == tuple(new_shape): return self + srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,) + return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype) + def reshape_op(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer: + if self.shape == arg: return self + if not self.realized and self.op.op == MovementOps.RESHAPE: + self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why?? + return self.op.src[0].reshape_op(arg) + return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).reshape(arg), MovementOps.RESHAPE, arg) + + def pad_op(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer: + if all([b == 0 and e == 0 for b,e in arg]): return self + if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad_op(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)])) + return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).pad(arg), MovementOps.PAD, arg) + + def expand_op(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer: + if self.shape == arg: return self + if not self.realized and self.op.op == MovementOps.EXPAND: + return self.op.src[0].expand_op(arg) + return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).expand(arg), MovementOps.EXPAND, arg) + + def permute_op(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer: + if arg == tuple(range(len(self.shape))): return self + if not self.realized and self.op.op == MovementOps.PERMUTE: return self.op.src[0].permute_op(tuple([self.op.arg[i] for i in arg])) + if not self.realized: + if PUSH_PERMUTES and self.optype == ReduceOps: + # reduceops have one buffer input, permute it + narg = tuple([self.op.arg[arg[i]] for i in range(len(arg))]) + src, rop = self.op.src[0], self.op.op + src.children.discard(self) + del self # TODO: why doesn't this delete remove it from the children + return src.permute_op(arg).reduce_op(cast(ReduceOps, rop), narg) + + # move permutes before expands (always, this is safe) + if self.op.op == MovementOps.EXPAND: + return self.op.src[0].permute_op(arg).expand_op(tuple([self.op.arg[a] for a in arg])) + + # move permutes before reshapes if we can + if PUSH_PERMUTES and self.op.op == MovementOps.RESHAPE and self.op.src[0].__class__ is LazyBuffer: + if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape): + self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why?? + return self.op.src[0].permute_op(tuple(flatten(shape_idx_groups[i] for i in arg))) \ + .reshape_op(ShapeTracker(self.st).permute(arg).shape) + return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).permute(arg), MovementOps.PERMUTE, arg) + + def shrink_op(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer: + if all([b - a == s for s, (a, b) in zip(self.shape, arg)]): return self + if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink_op(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)])) + return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).shrink(arg), MovementOps.SHRINK, arg) + + def stride_op(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer: + local_st = ShapeTracker(self.shape).stride(arg) + if self.shape == local_st.shape and local_st.contiguous: return self + if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride_op(tuple(map(operator.mul, arg, self.op.arg))) + return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).stride(arg), MovementOps.STRIDE, arg) + + def map_buffers(self, real_srcs: Dict[Any, Any]): return real_srcs.get(self, self) + def get_buffers(self) -> Tuple[LazyBuffer, ...]: return (self,) + def get_lazyops(self) -> List[Any]: return [] + def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer: + y = self + for op, arg in ops: y = MOVEMENT_OPS_DISPATCHER[op](y, arg) + return y + def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]: new_srcs = [] for x in srcs: - mops: List[Tuple[MovementOps, Tuple[Any, ...]]] = [] + mops: List[Tuple[MovementOps, Any]] = [] bx = x # backwalk all the movement ops. don't push PAD or EXPAND - while bx.realized is None and bx.optype == MovementOps and bx.op.op != MovementOps.EXPAND and (bx.op.op != MovementOps.PAD or SHUFFLE_PAD_OPS) and len(bx.children) <= 1: + while not bx.realized and bx.optype == MovementOps and bx.op.op != MovementOps.EXPAND and (SHUFFLE_PAD_OPS or bx.op.op != MovementOps.PAD) and len(bx.children) <= 1: assert isinstance(bx.op.op, MovementOps) mops.append((bx.op.op, bx.op.arg)) - bx = bx.op.src[0] + bx = cast(LazyBuffer, bx.op.src[0]) # NOTE: can't push pads with a div - if bx.realized is None and bx.optype == BinaryOps and len(bx.children) <= 1 and len(mops) and (all(x[0] != MovementOps.PAD for x in mops) or all(x.op != BinaryOps.DIV for x in get_lazyops(bx.op))): - new_srcs.append(replace_with_movement_ops(bx.op, mops[::-1])) + if not bx.realized and bx.optype == BinaryOps and len(bx.children) <= 1 and len(mops) and (all([x[0] != MovementOps.PAD for x in mops]) or all([x.op != BinaryOps.DIV for x in bx.op.get_lazyops()])): + new_srcs.append(bx.op.replace_with_movement_ops(mops[::-1])) else: new_srcs.append(x) return tuple(new_srcs) @@ -290,29 +269,30 @@ def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer, arg:Optional if SHUFFLE_MOVEMENT_OPS: srcs = _push_movement_ops(srcs) # get outputs now - out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max(x.dtype for x in srcs) if op != UnaryOps.CAST else cast(DType, arg) + out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max([x.dtype for x in srcs]) if op != UnaryOps.CAST else cast(DType, arg) # push all contiguous to the end of BinaryOps. kernels 198 -> 196 - if PUSH_CONTIGUOUS and any(x.realized is None and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs): - new_srcs = [] + if PUSH_CONTIGUOUS and any([not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs]): + new_srcs: List[LazyBuffer] = [] for x in srcs: - if x.realized is None and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1: + if not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1: x.op.src[0].children.discard(x) - new_srcs.append(x.op.src[0]) + new_srcs.append(cast(LazyBuffer, x.op.src[0])) else: new_srcs.append(x) return elementwise_op(op, *new_srcs, arg=arg).contiguous() if MERGE_ELEMENTWISE_OPS: # remove the buffers from any (childless) BinaryOps that feed into this - srcs = tuple(x.op if x.optype == BinaryOps and len(x.children) == 0 and x.realized is None else x for x in srcs) # type: ignore + srcs = tuple([x.op if x.optype == BinaryOps and len(x.children) == 0 and not x.realized else x for x in srcs]) # type: ignore - return create_lazybuffer(out_device, out_shape, BinaryOps, LazyOp(op, srcs, arg), out_dtype) + return create_lazybuffer(out_device, ShapeTracker(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype) class _Device: def __init__(self) -> None: self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")] self.DEFAULT: str = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, None) or self._default_device() + @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT def __getitem__(self, x:str) -> Union[Interpreted, Compiled]: return self._get_device(x.split(":")[0].upper()) @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none @@ -324,3 +304,58 @@ class _Device: except Exception: pass return "CPU" Device = _Device() + +def _realize_contiguous(buffer: LazyBuffer) -> None: + realized = buffer.op.src[0].realize().realized + if buffer.op.src[0].st.contiguous and realized.__class__ is not RawConst and cast(RawBuffer, realized).size == prod(buffer.shape): + # no need to run an AST, this is already contiguous + buffer.realized = realized + else: + # TODO: remove UnaryOps.NOOP, replace with LoadOps.CONTIGUOUS. confusing with Compiled though + buffer.op = LazyOp(UnaryOps.NOOP, buffer.op.src) + +def _realize_custom(buffer: LazyBuffer) -> None: + # this needs to immediately realize + buffer.realized = buffer.op.arg(buffer, *[x.realize() for x in buffer.op.src]) + +def _realize_from(buffer: LazyBuffer) -> None: + rawbuf = buffer.op.src[0].realize() + # TODO: make this generic + if isinstance(rawbuf.realized, RawDiskBuffer) and issubclass(Device[buffer.device].buffer, RawBufferMapped): + buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args()) + rawbuf.realized.readinto(cast(RawBufferMapped, buffer.realized)._buffer()) + else: + buffer.realized = Device[buffer.device].buffer.fromCPU(rawbuf.toCPU(), **buffer._device_extra_args()) + +def _realize_empty(buffer: LazyBuffer) -> None: + buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args()) + +def _realize_rand(buffer: LazyBuffer) -> None: + rng = np.random.default_rng(buffer.op.arg) + buffer.realized = Device[buffer.device].buffer.fromCPU(rng.random(size=buffer.shape, dtype=buffer.dtype.np), **buffer._device_extra_args()) # type: ignore + +def _realize_const(buffer: LazyBuffer) -> None: + if hasattr(Device[buffer.device].codegen, 'supports_constant_folding'): + buffer.realized = RawConst(1, buffer.dtype, float(buffer.op.arg)) + else: + buffer.realized = Device[buffer.device].buffer.fromCPU(np.array(buffer.op.arg, dtype=buffer.dtype.np), **buffer._device_extra_args()) + +REALIZE_DISPATCHER: Dict[Any, Callable] = { + LoadOps.CONTIGUOUS: _realize_contiguous, + LoadOps.CUSTOM: _realize_custom, + LoadOps.FROM: _realize_from, + LoadOps.EMPTY: _realize_empty, + LoadOps.RAND: _realize_rand, + LoadOps.CONST: _realize_const, + ReduceOps: _ast_reduceops, + BinaryOps: _ast_binaryops, +} + +MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = { + MovementOps.RESHAPE: LazyBuffer.reshape_op, + MovementOps.EXPAND: LazyBuffer.expand_op, + MovementOps.SHRINK: LazyBuffer.shrink_op, + MovementOps.PERMUTE: LazyBuffer.permute_op, + MovementOps.PAD: LazyBuffer.pad_op, + MovementOps.STRIDE: LazyBuffer.stride_op, +} \ No newline at end of file diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index bdb91f3cbc..fd3a4fa425 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -1,6 +1,6 @@ from typing import Tuple, Optional from tinygrad.helpers import argsort, ShapeType -from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps +from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps from tinygrad.tensor import Function from tinygrad.lazy import LazyBuffer import math @@ -10,6 +10,7 @@ class Contiguous(Function): def backward(self, grad_output): return grad_output class Cast(Function): + __slots__ = "input_dtype" def forward(self, x, dtype): self.input_dtype = x.dtype return x.cast(dtype) @@ -19,6 +20,7 @@ class Cast(Function): # ************* unary ops ************* class Sin(Function): + __slots__ = "x" def forward(self, x: LazyBuffer) -> LazyBuffer: self.x = x return x.unary_op(UnaryOps.SIN) @@ -26,6 +28,7 @@ class Sin(Function): return self.x.const_like(math.pi / 2).binary_op(BinaryOps.SUB, self.x).unary_op(UnaryOps.SIN).binary_op(BinaryOps.MUL, grad) # NOTE: maximum(x, 0) behaves differently where x=0 class Relu(Function): + __slots__ = "ret" def forward(self, x:LazyBuffer) -> LazyBuffer: self.ret = x.binary_op(BinaryOps.MAX, x.const_like(0)) return self.ret @@ -35,6 +38,7 @@ class Relu(Function): return mask.binary_op(BinaryOps.MUL, grad_output) class Log(Function): + __slots__ = "x" def forward(self, x:LazyBuffer) -> LazyBuffer: self.x = x return x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, x.const_like(math.log(2)/math.log(math.e))) @@ -43,6 +47,7 @@ class Log(Function): return grad_output.binary_op(BinaryOps.DIV, self.x) class Exp(Function): + __slots__ = "ret" def forward(self, x:LazyBuffer) -> LazyBuffer: self.ret = x.binary_op(BinaryOps.MUL, x.const_like(math.log(math.e)/math.log(2))).unary_op(UnaryOps.EXP2) return self.ret @@ -53,27 +58,29 @@ class Exp(Function): # ************* reduce ops ************* class Sum(Function): + __slots__ = "input_shape" def forward(self, x:LazyBuffer, new_shape:ShapeType) -> LazyBuffer: self.input_shape = x.shape return x.reduce_op(ReduceOps.SUM, new_shape) def backward(self, grad_output): - return grad_output.movement_op(MovementOps.EXPAND, self.input_shape) + return grad_output.expand_op(self.input_shape) class Max(Function): + __slots__ = "x", "ret" def forward(self, x:LazyBuffer, new_shape:ShapeType) -> LazyBuffer: self.x, self.ret = x, x.reduce_op(ReduceOps.MAX, new_shape) return self.ret def backward(self, grad_output:LazyBuffer) -> LazyBuffer: # 1s in locations where the max was chosen (can be two locations) - max_is_1s = self.x.binary_op(BinaryOps.CMPEQ, self.ret.movement_op(MovementOps.EXPAND, self.x.shape)) + max_is_1s = self.x.binary_op(BinaryOps.CMPEQ, self.ret.expand_op(self.x.shape)) # sum of locations, averaged - div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).movement_op(MovementOps.EXPAND, self.x.shape) + div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand_op(self.x.shape) max_is_amount = max_is_1s.binary_op(BinaryOps.DIV, div) - grad_output_expanded = grad_output.movement_op(MovementOps.EXPAND, self.x.shape) + grad_output_expanded = grad_output.expand_op(self.x.shape) return max_is_amount.binary_op(BinaryOps.MUL, grad_output_expanded) # ************* binary ops ************* @@ -83,6 +90,7 @@ class Equal(Function): return x.binary_op(BinaryOps.CMPEQ, y) class Maximum(Function): + __slots__ = "x", "y", "ret" def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: self.x, self.y = x, y self.ret = x.binary_op(BinaryOps.MAX, y) @@ -113,6 +121,7 @@ class Sub(Function): grad_output.const_like(0).binary_op(BinaryOps.SUB, grad_output) if self.needs_input_grad[1] else None class Mul(Function): + __slots__ = 'x', 'y' def forward(self, x:LazyBuffer, y:LazyBuffer): self.x, self.y = x, y return x.binary_op(BinaryOps.MUL, y) @@ -122,6 +131,7 @@ class Mul(Function): self.x.binary_op(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None class Pow(Function): + __slots__ = 'x', 'y', 'ret' def forward(self, x:LazyBuffer, y:LazyBuffer): self.x, self.y, self.ret = x, y, x.binary_op(BinaryOps.POW, y) return self.ret @@ -131,6 +141,7 @@ class Pow(Function): grad_output.binary_op(BinaryOps.MUL, self.x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, self.x.const_like(math.log(2)/math.log(math.e))).binary_op(BinaryOps.MUL, self.ret)) if self.needs_input_grad[1] else None class Div(Function): + __slots__ = 'x', 'y' def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: self.x, self.y = x, y return x.binary_op(BinaryOps.DIV, y) @@ -143,49 +154,55 @@ class Div(Function): # NOTE: this is sum in reverse class Expand(Function): + __slots__ = 'input_shape' def forward(self, x:LazyBuffer, shape:ShapeType) -> LazyBuffer: self.input_shape = x.shape - return x.movement_op(MovementOps.EXPAND, shape) + return x.expand_op(shape) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.reduce_op(ReduceOps.SUM, self.input_shape) class Reshape(Function): + __slots__ = 'input_shape' def forward(self, x:LazyBuffer, shape:ShapeType) -> LazyBuffer: self.input_shape = x.shape - return x.movement_op(MovementOps.RESHAPE, shape) + return x.reshape_op(shape) def backward(self, grad_output): - return grad_output.movement_op(MovementOps.RESHAPE, self.input_shape) + return grad_output.reshape_op(self.input_shape) class Permute(Function): + __slots__ = 'input_order' def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer: self.input_order = order - return x.movement_op(MovementOps.PERMUTE, order) + return x.permute_op(order) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.movement_op(MovementOps.PERMUTE, argsort(self.input_order)) + return grad_output.permute_op(argsort(self.input_order)) class Pad(Function): + __slots__ = 'narg' def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer: - self.narg = tuple((p[0], s+p[0]) for s,p in zip(x.shape, arg)) - return x.movement_op(MovementOps.PAD, arg) + self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)]) + return x.pad_op(arg) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.movement_op(MovementOps.SHRINK, self.narg) + return grad_output.shrink_op(self.narg) class Shrink(Function): + __slots__ = 'narg' def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer: - self.narg = tuple((p[0], s-p[1]) for s,p in zip(x.shape, arg)) - return x.movement_op(MovementOps.SHRINK, arg) + self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)]) + return x.shrink_op(arg) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.movement_op(MovementOps.PAD, self.narg) + return grad_output.pad_op(self.narg) class Flip(Function): + __slots__ = 'arg' def forward(self, x:LazyBuffer, axis:Tuple[int, ...]): - self.arg = tuple(-1 if i in axis else 1 for i in range(len(x.shape))) - return x.movement_op(MovementOps.STRIDE, self.arg) + self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))]) + return x.stride_op(self.arg) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.movement_op(MovementOps.STRIDE, self.arg) + return grad_output.stride_op(self.arg) \ No newline at end of file diff --git a/tinygrad/ops.py b/tinygrad/ops.py index d64c471c9e..6a54e85109 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1,10 +1,12 @@ from __future__ import annotations -import functools, operator, time +import functools, time from enum import Enum, auto -from typing import Union, Type, NamedTuple, Tuple, Any, List, Optional, Dict, Callable -from tinygrad.helpers import prod, DEBUG, getenv, GlobalCounters, DType, colored, ansilen +from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast +from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored from tinygrad.shape.shapetracker import MovementOps from tinygrad.runtime.lib import RawBuffer, RawConst +if TYPE_CHECKING: + from tinygrad.lazy import LazyBuffer # these are the llops your accelerator must implement, along with toCpu # the Enum class doesn't work with mypy, this is static. sorry it's ugly @@ -19,19 +21,62 @@ class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto( Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, FusedOps] OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[FusedOps]] -class LazyOp(NamedTuple): - op: Op - # Any == Union[LazyOp, LazyBuffer, DeviceBuffer] - src: Tuple[Any, ...] # type: ignore - arg: Any = None +class LazyOp: # TODO: add dest to support multiple outputs. on second thought, multiple outputs will have multiple LazyOps. + __slots__ = "op", "src", "arg", "buffers", "__weakref__" + op: Op + src: Tuple[Union[LazyOp, LazyBuffer], ...] + arg: Any + buffers: Tuple[LazyBuffer, ...] -# Any == Union[LazyBuffer, DeviceBuffer] -def get_buffers(op:LazyOp) -> List[Any]: return functools.reduce(operator.add, [get_buffers(x) if isinstance(x, LazyOp) else [x] for x in op.src], []) -def get_lazyops(op:LazyOp) -> List[LazyOp]: return functools.reduce(operator.add, [get_lazyops(x) for x in op.src if isinstance(x, LazyOp)], [op]) -def map_buffers(real_srcs:Dict[Any, Any], x:Any) -> LazyOp: - if len(real_srcs) and x in real_srcs: return map_buffers(real_srcs, real_srcs[x]) if isinstance(real_srcs[x], LazyOp) else real_srcs[x] - return LazyOp(x.op, tuple((map_buffers(real_srcs, y) if isinstance(y, LazyOp) else real_srcs[y]) for y in x.src), x.arg) + def __init__(self, op: Op, src: Tuple[Union[LazyOp, LazyBuffer], ...], arg: Any = None, buffers: Optional[Tuple[LazyBuffer, ...]] = None): + self.op = op + self.src = src + self.arg = arg + if not buffers: + buffers = tuple() + for s in src: + try: buffers += s.get_buffers() + except AttributeError: pass + self.buffers = buffers + + def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})" + def __eq__(self, __value: object) -> bool: + if __value.__class__ is not LazyOp: return False + __value = cast(LazyOp, __value) + return self.op == __value.op and self.src == __value.src and self.arg == __value.arg + def __hash__(self) -> int: return hash((self.op, self.src, self.arg)) + @property + def key(self): return (self.op, tuple(map(lambda x: getattr(x, "key", x), self.src)), getattr(self.arg, "key", self.arg)) + + # Any == Union[LazyBuffer, DeviceBuffer] + def map_buffers(self, real_srcs: Dict[Any, Any]): return LazyOp(self.op, tuple([y.map_buffers(real_srcs) for y in self.src]), self.arg) + + def get_buffers(self) -> Tuple[LazyBuffer, ...]: return self.buffers + def get_lazyops(self) -> List['LazyOp']: return [self] + [item for x in self.src for item in x.get_lazyops()] + + def replace_with_movement_ops(self: LazyOp, ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> 'LazyBuffer': + from tinygrad.lazy import elementwise_op + assert self.op in BinaryOps or self.op in UnaryOps + return elementwise_op(self.op, *[z.replace_with_movement_ops(ops) for z in self.src], arg=self.arg) # type: ignore + + @property + def st(self): raise NotImplementedError + @property + def children(self): raise NotImplementedError + @property + def shape(self): raise NotImplementedError + @property + def realized(self): raise NotImplementedError + @property + def optype(self): raise NotImplementedError + def realize(self): raise NotImplementedError + def reshape_op(self, _): raise NotImplementedError + def pad_op(self, _): raise NotImplementedError + def expand_op(self, _): raise NotImplementedError + def permute_op(self, _): raise NotImplementedError + def shrink_op(self, _): raise NotImplementedError + def stride_op(self, _): raise NotImplementedError # **************** for Interpreted Buffers **************** @@ -46,12 +91,12 @@ class Interpreted: self.codegen = None def exec_ast(self, ast:LazyOp, output=None, context=None, **kwargs): - if FusedOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL: - ast = LazyOp(FusedOps.MULACC, ast.src[0].src, ast.arg) + if FusedOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and ast.src[0].__class__ is LazyOp and ast.src[0].op == BinaryOps.MUL: + ast = LazyOp(FusedOps.MULACC, cast(LazyOp, ast.src[0]).src, ast.arg) created_context = context is None if context is None: context = dict() if not created_context and ast in context: return context[ast] - srcs = [self.exec_ast(x, context=context, **kwargs) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src] + srcs = [self.exec_ast(cast(LazyOp, x), context=context, **kwargs) if x.__class__ is LazyOp else self.from_lazybuffer(x) for x in ast.src] if DEBUG >= 3: st = time.perf_counter() ret = self.from_underlying(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else [])))) if DEBUG >= 3: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB {(time.perf_counter()-st)*1e3:7.2f} ms op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape) if hasattr(ret._buf, 'shape') else str(len(ret._buf)):30s} in({len(srcs)}):", list(set(x._buf.shape if hasattr(x._buf, 'shape') else len(x._buf) for x in srcs)), ast.arg if ast.arg is not None else "") @@ -90,7 +135,7 @@ class ASTRunner: return self def exec(self, bufs) -> Optional[float]: - rawbufs = [x.realized for x in bufs if x.realized is not None and not isinstance(x.realized, RawConst)] + rawbufs = [x.realized for x in bufs if x.realized is not None and x.realized.__class__ is not RawConst] if GlobalCounters.cache is not None: GlobalCounters.cache.append((self, rawbufs)) return self(rawbufs) @@ -114,21 +159,21 @@ class Compiled: def exec_ast(self, ast:LazyOp, output, **kwargs): # all movementops do nothing in a Compiled buffer! - if ast.op in MovementOps and not isinstance(ast.src[0], LazyOp) and ast.src[0].realized is not None: return ast.src[0].realized + if ast.op in MovementOps and ast.src[0].__class__ is not LazyOp and ast.src[0].realized: return ast.src[0].realized # check if we can reuse the output buffer # if it's aliased, don't use it # NOTE: this is pretty wrong actually, who knows where else this buffer is used? output.realized = output.output_buffer - if output.realized is not None: - if isinstance(output.realized, RawConst): output.realized = None # can't assign to RawConst - for a in get_buffers(ast): + if output.realized: + if output.realized.__class__ is RawConst: output.realized = None # can't assign to RawConst + for a in ast.buffers: if a.realized == output.realized and not a.st.contiguous: output.realized = None break # we don't have an output buffer, we have to create it - if output.realized is None: + if not output.realized: output.realized = self.buffer(prod(output.shape), output.dtype, **kwargs) # compilation time diff --git a/tinygrad/runtime/lib.py b/tinygrad/runtime/lib.py index e986c5f338..543938badb 100644 --- a/tinygrad/runtime/lib.py +++ b/tinygrad/runtime/lib.py @@ -14,6 +14,8 @@ class RawBuffer: # pylint: disable=abstract-method def __del__(self): # NOTE: if it fails on init (bad dtype), it won't have a _memsz if hasattr(self, '_memsz'): GlobalCounters.mem_used -= self._memsz def __repr__(self): return f"buffer<{self.size}, {self.dtype}>" + @property + def key(self): return (self.size, self.dtype.key) # NOTE: this interface allows for 0 copy @classmethod @@ -50,3 +52,5 @@ class RawBufferCopyInOut(RawBufferCopyIn): class RawConst(RawBuffer): # pylint: disable=abstract-method def __repr__(self): return f"const<{self._buf}, {self.dtype}>" + @property + def key(self): return (str(self._buf), self.dtype.key)