mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
Move var_vals from ShapeTracker to LazyBuffer (#1819)
This commit is contained in:
@@ -76,8 +76,8 @@ class TransformerBlock:
|
||||
cache_k = cache_k.reshape(cache_k.shape[0], start_pos_var, cache_k.shape[2], cache_k.shape[3])
|
||||
cache_v = cache_v.reshape(cache_v.shape[0], start_pos_var, cache_v.shape[2], cache_v.shape[3])
|
||||
# need this because we don't reshape back to int shape in the jitted path and we don't have the correct var_vals in cache
|
||||
cache_k.lazydata.st.var_vals[start_pos_var] = start_pos
|
||||
cache_v.lazydata.st.var_vals[start_pos_var] = start_pos
|
||||
cache_k.lazydata.var_vals[start_pos_var] = start_pos
|
||||
cache_v.lazydata.var_vals[start_pos_var] = start_pos
|
||||
|
||||
output, cache_k, cache_v = self.attn(self.ln_1(x), cache_k, cache_v, start_pos, mask, jit_ctx=jit_ctx)
|
||||
h = x + output
|
||||
@@ -113,7 +113,7 @@ class Transformer:
|
||||
if seqlen == 1 and start_pos > 0 and getenv("JIT"):
|
||||
start_pos_var = Variable("start_pos", 1, MAX_CONTEXT)
|
||||
pos = self.allpos.shrink(((0, self.allpos.shape[0]), (start_pos_var, start_pos_var+seqlen)))
|
||||
pos.lazydata.st.var_vals[start_pos_var] = start_pos
|
||||
pos.lazydata.var_vals[start_pos_var] = start_pos
|
||||
h = self.embed_jitted(tokens, pos)
|
||||
for i, (hi, (cache_k, cache_v)) in enumerate(zip(self.h_jitted, self.kv_caches)):
|
||||
h, cache_k, cache_v = hi(h, cache_k, cache_v, start_pos=start_pos, mask=None, jit_ctx={start_pos_var: start_pos})
|
||||
|
||||
@@ -82,7 +82,7 @@ class Attention:
|
||||
keys, values = xk, xv
|
||||
else:
|
||||
assert cache_k is not None and cache_v is not None, "no cache"
|
||||
assert start_pos == sym_infer(cache_k.shape[1], cache_k.lazydata.st.var_vals) == sym_infer(cache_v.shape[1], cache_v.lazydata.st.var_vals), f"cache has wrong shape, not ({start_pos} == {sym_infer(cache_k.shape[1], cache_k.lazydata.st.var_vals)} == {sym_infer(cache_v.shape[1], cache_v.lazydata.st.var_vals)})"
|
||||
assert start_pos == sym_infer(cache_k.shape[1], cache_k.lazydata.var_vals) == sym_infer(cache_v.shape[1], cache_v.lazydata.var_vals), f"cache has wrong shape, not ({start_pos} == {sym_infer(cache_k.shape[1], cache_k.lazydata.var_vals)} == {sym_infer(cache_v.shape[1], cache_v.lazydata.var_vals)})"
|
||||
assert seqlen == xk.shape[1] and seqlen == xv.shape[1], "seqlen is wrong shape?!?"
|
||||
keys, values = cache_k.cat(xk, dim=1), cache_v.cat(xv, dim=1)
|
||||
|
||||
@@ -121,12 +121,12 @@ class TransformerBlock:
|
||||
cache_k = cache_k.reshape(cache_k.shape[0], pos, cache_k.shape[2], cache_k.shape[3])
|
||||
cache_v = cache_v.reshape(cache_v.shape[0], pos, cache_v.shape[2], cache_v.shape[3])
|
||||
# need this because we don't reshape back to int shape in the jitted path and we don't have the correct var_vals in cache
|
||||
cache_k.lazydata.st.var_vals[pos] = start_pos
|
||||
cache_v.lazydata.st.var_vals[pos] = start_pos
|
||||
cache_k.lazydata.var_vals[pos] = start_pos
|
||||
cache_v.lazydata.var_vals[pos] = start_pos
|
||||
|
||||
# get only the part of freqs_cis that we are using.
|
||||
freqs_cis = freqs_cis.shrink(((0, freqs_cis.shape[0]), (pos, pos+seqlen), (0, freqs_cis.shape[2]), (0, freqs_cis.shape[3]), (0, freqs_cis.shape[4])))
|
||||
freqs_cis.lazydata.st.var_vals[pos] = start_pos
|
||||
freqs_cis.lazydata.var_vals[pos] = start_pos
|
||||
else:
|
||||
freqs_cis = freqs_cis.shrink(((0, freqs_cis.shape[0]), (start_pos, start_pos+seqlen), (0, freqs_cis.shape[2]), (0, freqs_cis.shape[3]), (0, freqs_cis.shape[4])))
|
||||
|
||||
@@ -158,7 +158,7 @@ class Transformer:
|
||||
if seqlen == 1 and JIT:
|
||||
pos = Variable("pos", 1, 1024)
|
||||
freqs_cis = self.freqs_cis.shrink(((0, self.freqs_cis.shape[0]), (pos, pos+seqlen),(0, self.freqs_cis.shape[2]),(0, self.freqs_cis.shape[3]),(0, self.freqs_cis.shape[4])))
|
||||
freqs_cis.lazydata.st.var_vals[pos] = start_pos
|
||||
freqs_cis.lazydata.var_vals[pos] = start_pos
|
||||
h = self.tok_embeddings_jitted(tokens)
|
||||
for i, (layer, (cache_k, cache_v)) in enumerate(zip(self.layers_jitted, self.kv_caches)):
|
||||
h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=self.freqs_cis, mask=None, jit_ctx={pos: start_pos})
|
||||
|
||||
@@ -43,7 +43,7 @@ class ATan2(Function):
|
||||
assert prod(a.shape) == prod(b.shape) and a.device == b.device, "shape or device mismatch"
|
||||
self.a, self.b = a, b
|
||||
ast = LazyOp(LoadOps.CUSTOM, (a.contiguous(), b.contiguous()), {"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device])
|
||||
return create_lazybuffer(a.device, ShapeTracker(a.shape), LoadOps, ast, max(a.dtype, b.dtype))
|
||||
return create_lazybuffer(a.device, ShapeTracker(a.shape), LoadOps, ast, max(a.dtype, b.dtype), {})
|
||||
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
|
||||
denom = (self.a.e(BinaryOps.MUL, self.a)).e(BinaryOps.ADD, self.b.e(BinaryOps.MUL, self.b))
|
||||
return grad_output.e(BinaryOps.MUL, self.b.e(BinaryOps.DIV, denom)) if self.needs_input_grad[0] else None, \
|
||||
|
||||
@@ -171,7 +171,7 @@ class TestSymbolicJit(unittest.TestCase):
|
||||
for i in range(1, 5):
|
||||
a = Tensor.rand(7, 11)
|
||||
symbolic = a.shrink(((3,5),(vi,vi+2)))
|
||||
symbolic.lazydata.st.var_vals[vi] = i
|
||||
symbolic.lazydata.var_vals[vi] = i
|
||||
symbolic = jf(symbolic).numpy()
|
||||
expected = f(a.shrink(((3,5),(i,i+2)))).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
@@ -123,7 +123,7 @@ class TestSymbolicOps(unittest.TestCase):
|
||||
for i in range(1, 5):
|
||||
a = Tensor.rand(7, 11)
|
||||
symbolic = a.shrink(((3,5),(vi,vi+2)))
|
||||
symbolic.lazydata.st.var_vals[vi] = i
|
||||
symbolic.lazydata.var_vals[vi] = i
|
||||
symbolic = symbolic.numpy()
|
||||
expected = a.shrink(((3,5),(i,i+2))).numpy()
|
||||
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
|
||||
|
||||
@@ -45,28 +45,28 @@ class TestSymbolicReshape(unittest.TestCase):
|
||||
for i in range(1, 6):
|
||||
t = Tensor.rand(i, 4).reshape(vi, 4)
|
||||
assert t.shape == (vi, 4)
|
||||
assert t.lazydata.st.var_vals[vi] == i
|
||||
assert t.lazydata.var_vals[vi] == i
|
||||
t = Tensor.rand(i, 6).reshape(vi, 2, 3)
|
||||
assert t.shape == (vi, 2, 3)
|
||||
assert t.lazydata.st.var_vals[vi] == i
|
||||
assert t.lazydata.var_vals[vi] == i
|
||||
|
||||
def test_reshape_symbols_reshape_ints(self):
|
||||
vi = Variable("i", 1, 5)
|
||||
for i in range(1, 6):
|
||||
t = Tensor.rand(i, 4).reshape(vi, 4)
|
||||
assert t.shape == (vi, 4)
|
||||
assert t.lazydata.st.var_vals == {vi: i}
|
||||
assert t.lazydata.var_vals == {vi: i}
|
||||
t = t.reshape(i, 4)
|
||||
assert t.shape == (i, 4)
|
||||
assert t.lazydata.st.var_vals == {}
|
||||
assert t.lazydata.var_vals == {vi: i}
|
||||
|
||||
def test_reshape_reuse_var_same_value_ok(self):
|
||||
vi = Variable("i", 1, 5)
|
||||
for i in range(1, 6):
|
||||
a = Tensor.rand(i, 4).reshape(vi, 4)
|
||||
b = Tensor.rand(i, 3).reshape(vi, 3)
|
||||
assert a.lazydata.st.var_vals[vi] == i
|
||||
assert b.lazydata.st.var_vals[vi] == i
|
||||
assert a.lazydata.var_vals[vi] == i
|
||||
assert b.lazydata.var_vals[vi] == i
|
||||
|
||||
def test_reshape_reuse_var_different_value_ok(self):
|
||||
vi = Variable("i", 1, 10)
|
||||
@@ -74,8 +74,8 @@ class TestSymbolicReshape(unittest.TestCase):
|
||||
a = Tensor.rand(i, 4).reshape(vi, 2)
|
||||
b = Tensor.rand(i, 3).reshape(vi, 3)
|
||||
# a and b have different values of vi
|
||||
assert a.lazydata.st.var_vals[vi] == 2 * i
|
||||
assert b.lazydata.st.var_vals[vi] == i
|
||||
assert a.lazydata.var_vals[vi] == 2 * i
|
||||
assert b.lazydata.var_vals[vi] == i
|
||||
|
||||
def test_reshape_into_symbols_bad_shape(self):
|
||||
vi = Variable("i", 1, 10)
|
||||
@@ -115,10 +115,10 @@ class TestSymbolicExpand(unittest.TestCase):
|
||||
vj = Variable("j", 1, 5)
|
||||
a = Tensor([[1], [2], [3]]).expand((3, vi))
|
||||
assert a.shape == (3, vi)
|
||||
assert a.lazydata.st.var_vals == {}
|
||||
assert a.lazydata.var_vals == {}
|
||||
a = a.reshape(3, vi, 1).expand((3, vi, vj))
|
||||
assert a.shape == (3, vi, vj)
|
||||
assert a.lazydata.st.var_vals == {}
|
||||
assert a.lazydata.var_vals == {}
|
||||
|
||||
def test_plus_expands_constant(self):
|
||||
vi = Variable("i", 1, 5)
|
||||
@@ -152,18 +152,18 @@ class TestShapeTrackerVarVals(unittest.TestCase):
|
||||
vi = Variable("i", 1, 5)
|
||||
vj = Variable("j", 1, 5)
|
||||
t = Tensor.rand(3, 4).reshape(3, vi).reshape(4, vj)
|
||||
assert t.lazydata.st.var_vals == {vi: 4, vj: 3}
|
||||
assert t.lazydata.var_vals == {vi: 4, vj: 3}
|
||||
|
||||
def test_lazy_check_var_vals(self):
|
||||
vi = Variable("i", 1, 5)
|
||||
a = Tensor.rand(3, 4).reshape(3, vi)
|
||||
b = Tensor.rand(5, 6).reshape(vi, 6)
|
||||
assert a.lazydata.st.var_vals == {vi: 4}
|
||||
assert b.lazydata.st.var_vals == {vi: 5}
|
||||
assert a.lazydata.var_vals == {vi: 4}
|
||||
assert b.lazydata.var_vals == {vi: 5}
|
||||
c = a@b
|
||||
# shapetracker works with symbolic shape and doesn't check / propagate the underlying variable values
|
||||
# shapetracker works with symbolic shape and doesn't check the underlying variable values
|
||||
assert c.shape == (3, 6)
|
||||
assert c.lazydata.st.var_vals == {}
|
||||
assert c.lazydata.var_vals == {vi: 4}
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -197,7 +197,7 @@ class Linearizer(OptimizedKernel):
|
||||
for i,b in enumerate(self.bufs):
|
||||
if b.realized in arg_bufs: self.buf_uops[i] = arg_bufs[b.realized]
|
||||
# add variables from symbolic shapes
|
||||
for var in sorted(set(v for buf in self.ast.buffers for v in buf.st.var_vals), key=lambda k: k.key):
|
||||
for var in sorted(set(v for buf in self.ast.buffers for v in buf.var_vals), key=lambda k: k.key):
|
||||
assert var.expr is not None
|
||||
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32))
|
||||
# define local buffers
|
||||
|
||||
@@ -30,7 +30,7 @@ class TinyJit:
|
||||
assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"
|
||||
if self.cnt >= 2:
|
||||
try: var_vals: Dict[Variable, int] = kwargs["jit_ctx"]
|
||||
except KeyError: var_vals = merge_dicts([arg.lazydata.st.var_vals for arg in args if arg.__class__ is Tensor])
|
||||
except KeyError: var_vals = merge_dicts([arg.lazydata.var_vals for arg in args if arg.__class__ is Tensor])
|
||||
if len(var_vals) > 1: var_vals = dict(sorted(var_vals.items(), key=lambda kv: kv[0].key))
|
||||
for (j,i),(input_name, expected_st, expected_type) in self.input_replace.items():
|
||||
assert input_rawbuffers[input_name][0].dtype == expected_type, f"type mismatch in JIT, {input_rawbuffers[input_name][0].dtype} != {expected_type}"
|
||||
|
||||
@@ -5,10 +5,10 @@ from weakref import ref, WeakSet, WeakValueDictionary
|
||||
|
||||
import numpy as np
|
||||
from tinygrad.graph import log_op
|
||||
from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType
|
||||
from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition
|
||||
from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View, get_contraction
|
||||
from tinygrad.shape.symbolic import Node
|
||||
from tinygrad.shape.symbolic import Node, Variable
|
||||
|
||||
from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped, RawBufferTransfer
|
||||
from tinygrad.runtime.ops_cpu import RawNumpyBuffer
|
||||
@@ -96,25 +96,27 @@ def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: ret
|
||||
def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)
|
||||
|
||||
lazycache: WeakValueDictionary = WeakValueDictionary()
|
||||
def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType):
|
||||
def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, var_vals:Dict[Variable,int]):
|
||||
# fromcpu aren't cached
|
||||
if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype)
|
||||
if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype, var_vals)
|
||||
|
||||
# wop is the deduping key. i feel this used to compare more deeply
|
||||
wop = (device, dtype, optype, ref(op))
|
||||
wop = (device, dtype, optype, ref(op), tuple(sorted(var_vals.keys())))
|
||||
if wop in lazycache:
|
||||
for x in op.buffers: x.children.add(lazycache[wop])
|
||||
return lazycache[wop]
|
||||
|
||||
lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype)
|
||||
lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype, var_vals)
|
||||
return ret
|
||||
|
||||
UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
|
||||
|
||||
class LazyBuffer:
|
||||
__deletable__ = ('op',)
|
||||
def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, src:Optional[RawBuffer]=None):
|
||||
def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, var_vals:Dict[Variable,int], src:Optional[RawBuffer]=None):
|
||||
self.st: ShapeTracker = st # NOTE: this is not a copy! this should be a "read-only" ShapeTracker
|
||||
self.var_vals: Dict[Variable, int] = var_vals
|
||||
self.var_vals_key: Tuple[Variable, ...] = tuple(sorted(self.var_vals.keys()))
|
||||
self.device, self.shape, self.optype, self.dtype = device, self.st.shape, optype, dtype
|
||||
self.realized: Optional[RawBuffer] = src
|
||||
self.output_buffer: Optional[RawBuffer] = None # TODO: do we really need this? or can we just use realized
|
||||
@@ -132,8 +134,8 @@ class LazyBuffer:
|
||||
def __repr__(self): return f"<LB {self.shape} {self.dtype} op={self.op.op if not self.realized else self.realized} st={self.st}>"
|
||||
@property
|
||||
def key(self):
|
||||
if self.realized: return (self.dtype, self.realized.key, self.st.key)
|
||||
return (self.dtype, self.op.op, self.st.key)
|
||||
if self.realized: return (self.dtype, self.realized.key, self.st.key, self.var_vals_key)
|
||||
return (self.dtype, self.op.op, self.st.key, self.var_vals_key)
|
||||
|
||||
def _device_extra_args(self) -> Dict[str, str]: return {"device": self.device.split(":", 1)[1]} if ":" in self.device else {}
|
||||
|
||||
@@ -174,7 +176,7 @@ class LazyBuffer:
|
||||
|
||||
@staticmethod
|
||||
def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
|
||||
return create_lazybuffer(device, ShapeTracker(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype)
|
||||
return create_lazybuffer(device, ShapeTracker(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype, {})
|
||||
|
||||
# create a constant with the shape and dtype of self
|
||||
def const(self, val:Union[float, int]) -> LazyBuffer:
|
||||
@@ -183,11 +185,11 @@ class LazyBuffer:
|
||||
|
||||
def contiguous(self:LazyBuffer) -> LazyBuffer:
|
||||
if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
|
||||
return create_lazybuffer(self.device, ShapeTracker(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype)
|
||||
return create_lazybuffer(self.device, ShapeTracker(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, self.var_vals)
|
||||
|
||||
@staticmethod
|
||||
def fromCPU(x: np.ndarray) -> LazyBuffer:
|
||||
return LazyBuffer("CPU", ShapeTracker(x.shape, [View(x.shape, tuple(st//x.itemsize for st in x.strides))]), LoadOps, LazyOp(LoadOps.EMPTY, (), None), dtypes.from_np(x.dtype), RawNumpyBuffer.fromCPU(x))
|
||||
return LazyBuffer("CPU", ShapeTracker(x.shape, [View(x.shape, tuple(st//x.itemsize for st in x.strides))]), LoadOps, LazyOp(LoadOps.EMPTY, (), None), dtypes.from_np(x.dtype), {}, RawNumpyBuffer.fromCPU(x))
|
||||
|
||||
def toCPU(self) -> np.ndarray:
|
||||
assert self.dtype.np, f"{self.dtype} is not supported in toCPU"
|
||||
@@ -220,7 +222,7 @@ class LazyBuffer:
|
||||
# remove the buffers from any (childless) BinaryOps that feed into this
|
||||
srcs = tuple([x.op if x.optype == BinaryOps and not x.children and not x.realized else x for x in srcs]) # type: ignore
|
||||
|
||||
return create_lazybuffer(out_device, ShapeTracker(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype)
|
||||
return create_lazybuffer(out_device, ShapeTracker(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype, self.var_vals)
|
||||
|
||||
def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[Union[Node,int], ...], Tuple[Tuple[int, int], ...]]) -> LazyBuffer:
|
||||
if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and not self.children:
|
||||
@@ -230,12 +232,12 @@ class LazyBuffer:
|
||||
root = get_movementroot(self)
|
||||
if root.st.contiguous and root != self and prod(st.shape) == prod(root.shape):
|
||||
return root.reshape(st.shape)
|
||||
return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype)
|
||||
return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, self.var_vals)
|
||||
|
||||
def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
|
||||
if self.shape == tuple(new_shape): return self
|
||||
srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
|
||||
return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype)
|
||||
return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype, self.var_vals)
|
||||
|
||||
def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
|
||||
if any(not isinstance(s, int) for s in self.shape) or prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach.
|
||||
@@ -246,8 +248,19 @@ class LazyBuffer:
|
||||
|
||||
def reshape(self:LazyBuffer, arg:Tuple[Union[Node, int], ...]) -> LazyBuffer:
|
||||
if self.shape == arg: return self
|
||||
new_ints, new_nodes = partition(arg, lambda s: isinstance(s, int))
|
||||
if new_nodes and all(isinstance(s, int) for s in self.shape):
|
||||
# reshape from all int shape into shape with a variable, update the variable value
|
||||
assert len(new_nodes) == 1 and isinstance(new_nodes[0], Variable), "only support adding one Variable to the int shape"
|
||||
new_var, new_val = new_nodes[0], prod(self.shape) // prod(new_ints)
|
||||
if new_var not in self.var_vals:
|
||||
assert new_var.min <= new_val <= new_var.max, f"variable value {new_val} out of range [{new_var.min}, {new_var.max}]"
|
||||
self.var_vals[new_var] = new_val
|
||||
else: assert self.var_vals[new_var] == new_val, f"value conflicts, was {self.var_vals[new_var]}, set to {new_val}"
|
||||
if not self.realized and self.op.op == MovementOps.RESHAPE:
|
||||
assert isinstance(self.op.src[0], LazyBuffer)
|
||||
self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
|
||||
self.op.src[0].var_vals = self.var_vals
|
||||
return self.op.src[0].reshape(arg)
|
||||
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).reshape(arg), MovementOps.RESHAPE, arg)
|
||||
|
||||
|
||||
@@ -198,7 +198,7 @@ class Compiled:
|
||||
from tinygrad.jit import CacheCollector
|
||||
CacheCollector._mark_output_buffer(output.output_buffer)
|
||||
# update the output var_vals from src
|
||||
output.st.var_vals = dict(sorted(merge_dicts([buf.st.var_vals for buf in ast.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
|
||||
output.var_vals = dict(sorted(merge_dicts([buf.var_vals for buf in ast.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
|
||||
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
k = Linearizer(ast, output, self.linearizer_opts)
|
||||
@@ -218,5 +218,5 @@ class Compiled:
|
||||
|
||||
if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)
|
||||
|
||||
prg.exec(k.bufs, var_vals=output.st.var_vals)
|
||||
prg.exec(k.bufs, var_vals=output.var_vals)
|
||||
return output.realized
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
from typing import Dict, Tuple, Union, List, Optional, NamedTuple
|
||||
from tinygrad.helpers import prod, DEBUG, partition
|
||||
from typing import Tuple, Union, List, Optional, NamedTuple
|
||||
from tinygrad.helpers import prod, DEBUG
|
||||
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, is_sym_int
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@@ -127,11 +127,10 @@ def get_unsafe_resize_offset(strides, arg):
|
||||
return sum([s * x[0] for s, x in zip(strides,arg)])
|
||||
|
||||
class ShapeTracker:
|
||||
__slots__ = "views", "var_vals"
|
||||
__slots__ = "views"
|
||||
def __init__(self, shape:Union[ShapeTracker, Tuple[Union[Node,int], ...]], views:Optional[List[View]]=None):
|
||||
self.views: List[View] = views if views is not None else [*shape.views] if isinstance(shape, ShapeTracker) else [View(shape)]
|
||||
self.var_vals: Dict[Variable, int] = shape.var_vals if isinstance(shape, ShapeTracker) else {}
|
||||
def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views}, var_vals={self.var_vals})"
|
||||
def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views})"
|
||||
def copy(self) -> ShapeTracker: return ShapeTracker(self.views[-1].shape, [*self.views])
|
||||
|
||||
@property
|
||||
@@ -141,7 +140,7 @@ class ShapeTracker:
|
||||
def shape(self) -> Tuple[int, ...]: return self.views[-1].shape # NOTE: real type is Tuple[Union[Node, int], ...] but mypy complains about prod(shape)
|
||||
|
||||
@property
|
||||
def key(self) -> Tuple[Tuple[View, ...], Tuple[Variable, ...]]: return tuple(self.views), tuple(sorted(self.var_vals.keys()))
|
||||
def key(self) -> Tuple[View, ...]: return tuple(self.views)
|
||||
|
||||
# this is the real size (ish)
|
||||
def size(self): return prod([s for s,st in zip(self.views[-1].shape, self.views[-1].strides) if st != 0])
|
||||
@@ -233,16 +232,6 @@ class ShapeTracker:
|
||||
|
||||
def reshape(self, new_shape: Tuple[Union[Node,int], ...]):
|
||||
if self.views[-1].shape == new_shape: return self
|
||||
new_ints, new_nodes = partition(new_shape, lambda s: isinstance(s, int))
|
||||
if new_nodes and all(isinstance(s, int) for s in self.shape):
|
||||
# reshape from all int shape into shape with a variable, update the variable value
|
||||
assert len(new_nodes) == 1 and isinstance(new_nodes[0], Variable), "only support adding one Variable to the int shape"
|
||||
new_var, new_val = new_nodes[0], prod(self.shape) // prod(new_ints)
|
||||
if new_var not in self.var_vals:
|
||||
assert new_var.min <= new_val <= new_var.max, f"variable value {new_val} out of range [{new_var.min}, {new_var.max}]"
|
||||
self.var_vals[new_var] = new_val
|
||||
else: assert self.var_vals[new_var] == new_val, f"value conflicts, was {self.var_vals[new_var]}, set to {new_val}"
|
||||
elif not new_nodes: self.var_vals = {}
|
||||
assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}"
|
||||
# only check size for int shapes. we don't check symbolic here as long as the reshape itself can be done
|
||||
if all(isinstance(s, int) for s in self.shape) and all(isinstance(s, int) for s in new_shape):
|
||||
|
||||
Reference in New Issue
Block a user