Move var_vals from ShapeTracker to LazyBuffer (#1819)

This commit is contained in:
chenyu
2023-09-08 09:25:10 -07:00
committed by GitHub
parent 7ac65a93b4
commit ebcda8a714
11 changed files with 63 additions and 61 deletions

View File

@@ -76,8 +76,8 @@ class TransformerBlock:
cache_k = cache_k.reshape(cache_k.shape[0], start_pos_var, cache_k.shape[2], cache_k.shape[3])
cache_v = cache_v.reshape(cache_v.shape[0], start_pos_var, cache_v.shape[2], cache_v.shape[3])
# need this because we don't reshape back to int shape in the jitted path and we don't have the correct var_vars in cache
cache_k.lazydata.st.var_vals[start_pos_var] = start_pos
cache_v.lazydata.st.var_vals[start_pos_var] = start_pos
cache_k.lazydata.var_vals[start_pos_var] = start_pos
cache_v.lazydata.var_vals[start_pos_var] = start_pos
output, cache_k, cache_v = self.attn(self.ln_1(x), cache_k, cache_v, start_pos, mask, jit_ctx=jit_ctx)
h = x + output
@@ -113,7 +113,7 @@ class Transformer:
if seqlen == 1 and start_pos > 0 and getenv("JIT"):
start_pos_var = Variable("start_pos", 1, MAX_CONTEXT)
pos = self.allpos.shrink(((0, self.allpos.shape[0]), (start_pos_var, start_pos_var+seqlen)))
pos.lazydata.st.var_vals[start_pos_var] = start_pos
pos.lazydata.var_vals[start_pos_var] = start_pos
h = self.embed_jitted(tokens, pos)
for i, (hi, (cache_k, cache_v)) in enumerate(zip(self.h_jitted, self.kv_caches)):
h, cache_k, cache_v = hi(h, cache_k, cache_v, start_pos=start_pos, mask=None, jit_ctx={start_pos_var: start_pos})

View File

@@ -82,7 +82,7 @@ class Attention:
keys, values = xk, xv
else:
assert cache_k is not None and cache_v is not None, "no cache"
assert start_pos == sym_infer(cache_k.shape[1], cache_k.lazydata.st.var_vals) == sym_infer(cache_v.shape[1], cache_v.lazydata.st.var_vals), f"cache has wrong shape, not ({start_pos} == {sym_infer(cache_k.shape[1], cache_k.lazydata.st.var_vals)} == {sym_infer(cache_v.shape[1], cache_v.lazydata.st.var_vals)})"
assert start_pos == sym_infer(cache_k.shape[1], cache_k.lazydata.var_vals) == sym_infer(cache_v.shape[1], cache_v.lazydata.var_vals), f"cache has wrong shape, not ({start_pos} == {sym_infer(cache_k.shape[1], cache_k.lazydata.var_vals)} == {sym_infer(cache_v.shape[1], cache_v.lazydata.var_vals)})"
assert seqlen == xk.shape[1] and seqlen == xv.shape[1], "seqlen is wrong shape?!?"
keys, values = cache_k.cat(xk, dim=1), cache_v.cat(xv, dim=1)
@@ -121,12 +121,12 @@ class TransformerBlock:
cache_k = cache_k.reshape(cache_k.shape[0], pos, cache_k.shape[2], cache_k.shape[3])
cache_v = cache_v.reshape(cache_v.shape[0], pos, cache_v.shape[2], cache_v.shape[3])
# need this because we don't reshape back to int shape in the jitted path and we don't have the correct var_vars in cache
cache_k.lazydata.st.var_vals[pos] = start_pos
cache_v.lazydata.st.var_vals[pos] = start_pos
cache_k.lazydata.var_vals[pos] = start_pos
cache_v.lazydata.var_vals[pos] = start_pos
# get only the part of freqs_cis that we are using.
freqs_cis = freqs_cis.shrink(((0, freqs_cis.shape[0]), (pos, pos+seqlen), (0, freqs_cis.shape[2]), (0, freqs_cis.shape[3]), (0, freqs_cis.shape[4])))
freqs_cis.lazydata.st.var_vals[pos] = start_pos
freqs_cis.lazydata.var_vals[pos] = start_pos
else:
freqs_cis = freqs_cis.shrink(((0, freqs_cis.shape[0]), (start_pos, start_pos+seqlen), (0, freqs_cis.shape[2]), (0, freqs_cis.shape[3]), (0, freqs_cis.shape[4])))
@@ -158,7 +158,7 @@ class Transformer:
if seqlen == 1 and JIT:
pos = Variable("pos", 1, 1024)
freqs_cis = self.freqs_cis.shrink(((0, self.freqs_cis.shape[0]), (pos, pos+seqlen),(0, self.freqs_cis.shape[2]),(0, self.freqs_cis.shape[3]),(0, self.freqs_cis.shape[4])))
freqs_cis.lazydata.st.var_vals[pos] = start_pos
freqs_cis.lazydata.var_vals[pos] = start_pos
h = self.tok_embeddings_jitted(tokens)
for i, (layer, (cache_k, cache_v)) in enumerate(zip(self.layers_jitted, self.kv_caches)):
h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=self.freqs_cis, mask=None, jit_ctx={pos: start_pos})

View File

@@ -43,7 +43,7 @@ class ATan2(Function):
assert prod(a.shape) == prod(b.shape) and a.device == b.device, "shape or device mismatch"
self.a, self.b = a, b
ast = LazyOp(LoadOps.CUSTOM, (a.contiguous(), b.contiguous()), {"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device])
return create_lazybuffer(a.device, ShapeTracker(a.shape), LoadOps, ast, max(a.dtype, b.dtype))
return create_lazybuffer(a.device, ShapeTracker(a.shape), LoadOps, ast, max(a.dtype, b.dtype), {})
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
denom = (self.a.e(BinaryOps.MUL, self.a)).e(BinaryOps.ADD, self.b.e(BinaryOps.MUL, self.b))
return grad_output.e(BinaryOps.MUL, self.b.e(BinaryOps.DIV, denom)) if self.needs_input_grad[0] else None, \

View File

@@ -171,7 +171,7 @@ class TestSymbolicJit(unittest.TestCase):
for i in range(1, 5):
a = Tensor.rand(7, 11)
symbolic = a.shrink(((3,5),(vi,vi+2)))
symbolic.lazydata.st.var_vals[vi] = i
symbolic.lazydata.var_vals[vi] = i
symbolic = jf(symbolic).numpy()
expected = f(a.shrink(((3,5),(i,i+2)))).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

View File

@@ -123,7 +123,7 @@ class TestSymbolicOps(unittest.TestCase):
for i in range(1, 5):
a = Tensor.rand(7, 11)
symbolic = a.shrink(((3,5),(vi,vi+2)))
symbolic.lazydata.st.var_vals[vi] = i
symbolic.lazydata.var_vals[vi] = i
symbolic = symbolic.numpy()
expected = a.shrink(((3,5),(i,i+2))).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

View File

@@ -45,28 +45,28 @@ class TestSymbolicReshape(unittest.TestCase):
for i in range(1, 6):
t = Tensor.rand(i, 4).reshape(vi, 4)
assert t.shape == (vi, 4)
assert t.lazydata.st.var_vals[vi] == i
assert t.lazydata.var_vals[vi] == i
t = Tensor.rand(i, 6).reshape(vi, 2, 3)
assert t.shape == (vi, 2, 3)
assert t.lazydata.st.var_vals[vi] == i
assert t.lazydata.var_vals[vi] == i
def test_reshape_symbols_reshape_ints(self):
vi = Variable("i", 1, 5)
for i in range(1, 6):
t = Tensor.rand(i, 4).reshape(vi, 4)
assert t.shape == (vi, 4)
assert t.lazydata.st.var_vals == {vi: i}
assert t.lazydata.var_vals == {vi: i}
t = t.reshape(i, 4)
assert t.shape == (i, 4)
assert t.lazydata.st.var_vals == {}
assert t.lazydata.var_vals == {vi: i}
def test_reshape_reuse_var_same_value_ok(self):
vi = Variable("i", 1, 5)
for i in range(1, 6):
a = Tensor.rand(i, 4).reshape(vi, 4)
b = Tensor.rand(i, 3).reshape(vi, 3)
assert a.lazydata.st.var_vals[vi] == i
assert b.lazydata.st.var_vals[vi] == i
assert a.lazydata.var_vals[vi] == i
assert b.lazydata.var_vals[vi] == i
def test_reshape_reuse_var_different_value_ok(self):
vi = Variable("i", 1, 10)
@@ -74,8 +74,8 @@ class TestSymbolicReshape(unittest.TestCase):
a = Tensor.rand(i, 4).reshape(vi, 2)
b = Tensor.rand(i, 3).reshape(vi, 3)
# a and b have different values of vi
assert a.lazydata.st.var_vals[vi] == 2 * i
assert b.lazydata.st.var_vals[vi] == i
assert a.lazydata.var_vals[vi] == 2 * i
assert b.lazydata.var_vals[vi] == i
def test_reshape_into_symbols_bad_shape(self):
vi = Variable("i", 1, 10)
@@ -115,10 +115,10 @@ class TestSymbolicExpand(unittest.TestCase):
vj = Variable("j", 1, 5)
a = Tensor([[1], [2], [3]]).expand((3, vi))
assert a.shape == (3, vi)
assert a.lazydata.st.var_vals == {}
assert a.lazydata.var_vals == {}
a = a.reshape(3, vi, 1).expand((3, vi, vj))
assert a.shape == (3, vi, vj)
assert a.lazydata.st.var_vals == {}
assert a.lazydata.var_vals == {}
def test_plus_expands_constant(self):
vi = Variable("i", 1, 5)
@@ -152,18 +152,18 @@ class TestShapeTrackerVarVals(unittest.TestCase):
vi = Variable("i", 1, 5)
vj = Variable("j", 1, 5)
t = Tensor.rand(3, 4).reshape(3, vi).reshape(4, vj)
assert t.lazydata.st.var_vals == {vi: 4, vj: 3}
assert t.lazydata.var_vals == {vi: 4, vj: 3}
def test_lazy_check_var_vals(self):
vi = Variable("i", 1, 5)
a = Tensor.rand(3, 4).reshape(3, vi)
b = Tensor.rand(5, 6).reshape(vi, 6)
assert a.lazydata.st.var_vals == {vi: 4}
assert b.lazydata.st.var_vals == {vi: 5}
assert a.lazydata.var_vals == {vi: 4}
assert b.lazydata.var_vals == {vi: 5}
c = a@b
# shapetracker works with symbolic shape and doesn't check / propagate the underlying variable values
# shapetracker works with symbolic shape and doesn't check the underlying variable values
assert c.shape == (3, 6)
assert c.lazydata.st.var_vals == {}
assert c.lazydata.var_vals == {vi: 4}
if __name__ == '__main__':
unittest.main()

View File

@@ -197,7 +197,7 @@ class Linearizer(OptimizedKernel):
for i,b in enumerate(self.bufs):
if b.realized in arg_bufs: self.buf_uops[i] = arg_bufs[b.realized]
# add variables from symbolic shapes
for var in sorted(set(v for buf in self.ast.buffers for v in buf.st.var_vals), key=lambda k: k.key):
for var in sorted(set(v for buf in self.ast.buffers for v in buf.var_vals), key=lambda k: k.key):
assert var.expr is not None
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32))
# define local buffers

View File

@@ -30,7 +30,7 @@ class TinyJit:
assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"
if self.cnt >= 2:
try: var_vals: Dict[Variable, int] = kwargs["jit_ctx"]
except KeyError: var_vals = merge_dicts([arg.lazydata.st.var_vals for arg in args if arg.__class__ is Tensor])
except KeyError: var_vals = merge_dicts([arg.lazydata.var_vals for arg in args if arg.__class__ is Tensor])
if len(var_vals) > 1: var_vals = dict(sorted(var_vals.items(), key=lambda kv: kv[0].key))
for (j,i),(input_name, expected_st, expected_type) in self.input_replace.items():
assert input_rawbuffers[input_name][0].dtype == expected_type, f"type mismatch in JIT, {input_rawbuffers[input_name][0].dtype} != {expected_type}"

View File

@@ -5,10 +5,10 @@ from weakref import ref, WeakSet, WeakValueDictionary
import numpy as np
from tinygrad.graph import log_op
from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType
from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition
from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp
from tinygrad.shape.shapetracker import ShapeTracker, View, get_contraction
from tinygrad.shape.symbolic import Node
from tinygrad.shape.symbolic import Node, Variable
from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped, RawBufferTransfer
from tinygrad.runtime.ops_cpu import RawNumpyBuffer
@@ -96,25 +96,27 @@ def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: ret
def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)
lazycache: WeakValueDictionary = WeakValueDictionary()
def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType):
def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, var_vals:Dict[Variable,int]):
# fromcpu aren't cached
if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype)
if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype, var_vals)
# wop is the deduping key. i feel this used to compare more deeply
wop = (device, dtype, optype, ref(op))
wop = (device, dtype, optype, ref(op), tuple(sorted(var_vals.keys())))
if wop in lazycache:
for x in op.buffers: x.children.add(lazycache[wop])
return lazycache[wop]
lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype)
lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype, var_vals)
return ret
UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
class LazyBuffer:
__deletable__ = ('op',)
def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, src:Optional[RawBuffer]=None):
def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, var_vals:Dict[Variable,int], src:Optional[RawBuffer]=None):
self.st: ShapeTracker = st # NOTE: this is not a copy! this should be a "read-only" ShapeTracker
self.var_vals: Dict[Variable, int] = var_vals
self.var_vals_key: Tuple[Variable, ...] = tuple(sorted(self.var_vals.keys()))
self.device, self.shape, self.optype, self.dtype = device, self.st.shape, optype, dtype
self.realized: Optional[RawBuffer] = src
self.output_buffer: Optional[RawBuffer] = None # TODO: do we really need this? or can we just use realized
@@ -132,8 +134,8 @@ class LazyBuffer:
def __repr__(self): return f"<LB {self.shape} {self.dtype} op={self.op.op if not self.realized else self.realized} st={self.st}>"
@property
def key(self):
if self.realized: return (self.dtype, self.realized.key, self.st.key)
return (self.dtype, self.op.op, self.st.key)
if self.realized: return (self.dtype, self.realized.key, self.st.key, self.var_vals_key)
return (self.dtype, self.op.op, self.st.key, self.var_vals_key)
def _device_extra_args(self) -> Dict[str, str]: return {"device": self.device.split(":", 1)[1]} if ":" in self.device else {}
@@ -174,7 +176,7 @@ class LazyBuffer:
@staticmethod
def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
return create_lazybuffer(device, ShapeTracker(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype)
return create_lazybuffer(device, ShapeTracker(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype, {})
# create a constant with the shape and dtype of self
def const(self, val:Union[float, int]) -> LazyBuffer:
@@ -183,11 +185,11 @@ class LazyBuffer:
def contiguous(self:LazyBuffer) -> LazyBuffer:
if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
return create_lazybuffer(self.device, ShapeTracker(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype)
return create_lazybuffer(self.device, ShapeTracker(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, self.var_vals)
@staticmethod
def fromCPU(x: np.ndarray) -> LazyBuffer:
return LazyBuffer("CPU", ShapeTracker(x.shape, [View(x.shape, tuple(st//x.itemsize for st in x.strides))]), LoadOps, LazyOp(LoadOps.EMPTY, (), None), dtypes.from_np(x.dtype), RawNumpyBuffer.fromCPU(x))
return LazyBuffer("CPU", ShapeTracker(x.shape, [View(x.shape, tuple(st//x.itemsize for st in x.strides))]), LoadOps, LazyOp(LoadOps.EMPTY, (), None), dtypes.from_np(x.dtype), {}, RawNumpyBuffer.fromCPU(x))
def toCPU(self) -> np.ndarray:
assert self.dtype.np, f"{self.dtype} is not supported in toCPU"
@@ -220,7 +222,7 @@ class LazyBuffer:
# remove the buffers from any (childless) BinaryOps that feed into this
srcs = tuple([x.op if x.optype == BinaryOps and not x.children and not x.realized else x for x in srcs]) # type: ignore
return create_lazybuffer(out_device, ShapeTracker(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype)
return create_lazybuffer(out_device, ShapeTracker(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype, self.var_vals)
def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[Union[Node,int], ...], Tuple[Tuple[int, int], ...]]) -> LazyBuffer:
if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and not self.children:
@@ -230,12 +232,12 @@ class LazyBuffer:
root = get_movementroot(self)
if root.st.contiguous and root != self and prod(st.shape) == prod(root.shape):
return root.reshape(st.shape)
return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype)
return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, self.var_vals)
def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
if self.shape == tuple(new_shape): return self
srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype)
return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype, self.var_vals)
def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
if any(not isinstance(s, int) for s in self.shape) or prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach.
@@ -246,8 +248,19 @@ class LazyBuffer:
def reshape(self:LazyBuffer, arg:Tuple[Union[Node, int], ...]) -> LazyBuffer:
if self.shape == arg: return self
new_ints, new_nodes = partition(arg, lambda s: isinstance(s, int))
if new_nodes and all(isinstance(s, int) for s in self.shape):
# reshape from all int shape into shape with a variable, update the variable value
assert len(new_nodes) == 1 and isinstance(new_nodes[0], Variable), "only support adding one Variable to the int shape"
new_var, new_val = new_nodes[0], prod(self.shape) // prod(new_ints)
if new_var not in self.var_vals:
assert new_var.min <= new_val <= new_var.max, f"variable value {new_val} out of range [{new_var.min}, {new_var.max}]"
self.var_vals[new_var] = new_val
else: assert self.var_vals[new_var] == new_val, f"value conflicts, was {self.var_vals[new_var]}, set to {new_val}"
if not self.realized and self.op.op == MovementOps.RESHAPE:
assert isinstance(self.op.src[0], LazyBuffer)
self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
self.op.src[0].var_vals = self.var_vals
return self.op.src[0].reshape(arg)
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).reshape(arg), MovementOps.RESHAPE, arg)

View File

@@ -198,7 +198,7 @@ class Compiled:
from tinygrad.jit import CacheCollector
CacheCollector._mark_output_buffer(output.output_buffer)
# update the output var_vals from src
output.st.var_vals = dict(sorted(merge_dicts([buf.st.var_vals for buf in ast.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
output.var_vals = dict(sorted(merge_dicts([buf.var_vals for buf in ast.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
from tinygrad.codegen.linearizer import Linearizer
k = Linearizer(ast, output, self.linearizer_opts)
@@ -218,5 +218,5 @@ class Compiled:
if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)
prg.exec(k.bufs, var_vals=output.st.var_vals)
prg.exec(k.bufs, var_vals=output.var_vals)
return output.realized

View File

@@ -1,8 +1,8 @@
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
from __future__ import annotations
import functools
from typing import Dict, Tuple, Union, List, Optional, NamedTuple
from tinygrad.helpers import prod, DEBUG, partition
from typing import Tuple, Union, List, Optional, NamedTuple
from tinygrad.helpers import prod, DEBUG
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, is_sym_int
@functools.lru_cache(maxsize=None)
@@ -127,11 +127,10 @@ def get_unsafe_resize_offset(strides, arg):
return sum([s * x[0] for s, x in zip(strides,arg)])
class ShapeTracker:
__slots__ = "views", "var_vals"
__slots__ = "views"
def __init__(self, shape:Union[ShapeTracker, Tuple[Union[Node,int], ...]], views:Optional[List[View]]=None):
self.views: List[View] = views if views is not None else [*shape.views] if isinstance(shape, ShapeTracker) else [View(shape)]
self.var_vals: Dict[Variable, int] = shape.var_vals if isinstance(shape, ShapeTracker) else {}
def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views}, var_vals={self.var_vals})"
def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views})"
def copy(self) -> ShapeTracker: return ShapeTracker(self.views[-1].shape, [*self.views])
@property
@@ -141,7 +140,7 @@ class ShapeTracker:
def shape(self) -> Tuple[int, ...]: return self.views[-1].shape # NOTE: real type is Tuple[Union[Node, int], ...] but mypy complains about prod(shape)
@property
def key(self) -> Tuple[Tuple[View, ...], Tuple[Variable, ...]]: return tuple(self.views), tuple(sorted(self.var_vals.keys()))
def key(self) -> Tuple[View, ...]: return tuple(self.views)
# this is the real size (ish)
def size(self): return prod([s for s,st in zip(self.views[-1].shape, self.views[-1].strides) if st != 0])
@@ -233,16 +232,6 @@ class ShapeTracker:
def reshape(self, new_shape: Tuple[Union[Node,int], ...]):
if self.views[-1].shape == new_shape: return self
new_ints, new_nodes = partition(new_shape, lambda s: isinstance(s, int))
if new_nodes and all(isinstance(s, int) for s in self.shape):
# reshape from all int shape into shape with a variable, update the variable value
assert len(new_nodes) == 1 and isinstance(new_nodes[0], Variable), "only support adding one Variable to the int shape"
new_var, new_val = new_nodes[0], prod(self.shape) // prod(new_ints)
if new_var not in self.var_vals:
assert new_var.min <= new_val <= new_var.max, f"variable value {new_val} out of range [{new_var.min}, {new_var.max}]"
self.var_vals[new_var] = new_val
else: assert self.var_vals[new_var] == new_val, f"value conflicts, was {self.var_vals[new_var]}, set to {new_val}"
elif not new_nodes: self.var_vals = {}
assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}"
# only check size for int shapes. we don't check symbolic here as long as the reshape itself can be done
if all(isinstance(s, int) for s in self.shape) and all(isinstance(s, int) for s in new_shape):