mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
schedule the loadops like everything else (#1964)
* schedule the loadops like everything else * unify loadops with other things we schedule * delete all the ops * fix symbolic jit
This commit is contained in:
@@ -135,7 +135,7 @@ assert len(lazyop.src) == 2
|
||||
# again, a LazyOp AST is like a GPU kernel. you have to copy the data on the device first
|
||||
assert lazyop.src[0].op.op == LoadOps.FROM
|
||||
assert lazyop.src[0].op.src[0].device == "CPU"
|
||||
assert lazyop.src[0].op.src[0].realized._buf[0] == 2, "the src of the FROM LazyOP is a LazyBuffer on the CPU holding [2.]"
|
||||
assert lazyop.src[0].op.src[0].op.src[0].realized._buf[0] == 2, "the src of the FROM LazyOP is a LazyBuffer on the CPU holding [2.]"
|
||||
assert result.lazydata.realized is None, "the LazyBuffer is not realized yet"
|
||||
|
||||
# now we realize the LazyBuffer
|
||||
|
||||
@@ -166,7 +166,6 @@ class LazyBuffer:
|
||||
if self.optype is MovementOps: return self.base.schedule(seen)
|
||||
|
||||
op = self.op if self.op.op != LoadOps.CONTIGUOUS else LazyOp(UnaryOps.NOOP, self.op.src)
|
||||
if op.op in LoadOps: return [(self.op, self, ())]
|
||||
|
||||
if self.optype is BinaryOps: op = _ast_binaryops(op, self.shape)
|
||||
elif self.optype is ReduceOps: op = _ast_reduceops(op)
|
||||
@@ -177,21 +176,27 @@ class LazyBuffer:
|
||||
else: op = LazyOp(UnaryOps.CAST, (op,), (dtypes.float32, False))
|
||||
self.dtype = dtypes.float32
|
||||
|
||||
# contiguous can be a copy. must do this after the image hack
|
||||
if self.op.op == LoadOps.CONTIGUOUS:
|
||||
src = cast(LazyBuffer, self.op.src[0])
|
||||
if src.st.contiguous and src.st.size() == src.base.st.size() and not src.is_unrealized_const():
|
||||
return src.schedule(seen) + [(self.op, self, ())]
|
||||
|
||||
# realize the past and exec the AST
|
||||
ret = []
|
||||
for x in op.buffers: ret += x.schedule(seen)
|
||||
|
||||
# TODO: this belongs in the schedule in some way
|
||||
self.var_vals = dict(sorted(merge_dicts([buf.var_vals for buf in op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
|
||||
self.var_vals = dict(sorted(merge_dicts([self.var_vals] + [buf.var_vals for buf in op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
|
||||
|
||||
# contiguous can be a copy. must do this after the image hack
|
||||
if self.op.op == LoadOps.CONTIGUOUS:
|
||||
src = cast(LazyBuffer, self.op.src[0])
|
||||
if src.st.contiguous and src.st.size() == src.base.st.size() and not src.is_unrealized_const():
|
||||
return ret + [(self.op, self, (src,))]
|
||||
|
||||
# run the ast and log the op
|
||||
op, base_bufs = _replace_bufferops(op)
|
||||
|
||||
# confirm the LoadOps are contiguous and in order
|
||||
if op.op in LoadOps:
|
||||
for i,s in enumerate(op.src):
|
||||
assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
|
||||
|
||||
return ret + [(op, self, tuple(base_bufs))]
|
||||
|
||||
def realize(self:LazyBuffer) -> LazyBuffer:
|
||||
|
||||
@@ -19,38 +19,16 @@ def run_schedule(schedule:List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]
|
||||
from extra.utils import print_tree # type: ignore
|
||||
print_tree(op)
|
||||
if op.op in LoadOps:
|
||||
LOAD_OPS_DISPATCHER[cast(LoadOps, op.op)](out)
|
||||
# TODO: why can't we delete these ops?
|
||||
# NOTE: load op buffers are promised to be in order by the scheduler
|
||||
LOAD_OPS_DISPATCHER[cast(LoadOps, op.op)](out, *buffers)
|
||||
else:
|
||||
out.realized = Device[out.device].exec_ast(op, output=out, inputs=[x.realized for x in buffers], var_vals=out.var_vals, **out._device_extra_args())
|
||||
del out.op
|
||||
for v in out.views: del v.op
|
||||
del out.op
|
||||
for v in out.views: del v.op
|
||||
assert out.realized and isinstance(out.realized, Device[out.device].buffer), f"device mismatch on realized got {type(out.realized)} expected {out.device}"
|
||||
assert out.realized.dtype == out.dtype, "realized dtype is incorrect"
|
||||
|
||||
def _realize_contiguous(buffer: LazyBuffer) -> None:
|
||||
# this is just a copy now, if it's not a copy schedule will handle it
|
||||
src = cast(LazyBuffer, buffer.op.src[0])
|
||||
buffer.realized = src.realized
|
||||
assert buffer.dtype == src.dtype, f"contiguous dtype mismatch, expecting {buffer.dtype}, got {src.dtype}"
|
||||
|
||||
def _realize_custom(buffer: LazyBuffer) -> None:
|
||||
# this needs to immediately realize
|
||||
buffer.realized = buffer.op.arg(buffer, *[x.realize() for x in buffer.op.src])
|
||||
|
||||
def _realize_from(buffer: LazyBuffer) -> None:
|
||||
rawbuf = cast(LazyBuffer, buffer.op.src[0]).contiguous().realize()
|
||||
assert rawbuf.realized, "realize failed?"
|
||||
if DEBUG >= 3: print(f"*** copy {buffer.device} <- {rawbuf.device} size {rawbuf.realized.size} dtype {rawbuf.realized.dtype}")
|
||||
# TODO: make this generic
|
||||
if isinstance(rawbuf.realized, RawDiskBuffer) and issubclass(Device[buffer.device].buffer, RawBufferMapped):
|
||||
assert all_int(buffer.shape), "does not support symbolic shape"
|
||||
buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
|
||||
rawbuf.prepare_transfer().readinto(cast(RawBufferMapped, buffer.realized)._buffer())
|
||||
elif isinstance(rawbuf.realized, RawBufferTransfer) and issubclass(Device[buffer.device].buffer, RawBufferTransfer) and P2P >= 1:
|
||||
buffer.realized = cast(RawBufferTransfer, Device[buffer.device].buffer).transfer(rawbuf.realized, buffer.shape, buffer.dtype, **buffer._device_extra_args())
|
||||
else:
|
||||
buffer.realized = Device[buffer.device].buffer.fromCPU(rawbuf.toCPU(), **buffer._device_extra_args())
|
||||
# *** zero op LoadOps ***
|
||||
|
||||
def _realize_empty(buffer: LazyBuffer) -> None:
|
||||
assert all_int(buffer.shape), "does not support symbolic shape"
|
||||
@@ -61,10 +39,34 @@ def _realize_rand(buffer: LazyBuffer) -> None:
|
||||
rng = np.random.default_rng(buffer.op.arg)
|
||||
buffer.realized = Device[buffer.device].buffer.fromCPU(rng.random(size=prod(buffer.shape), dtype=np.float32).astype(dtype=buffer.dtype.np, copy=False), **buffer._device_extra_args())
|
||||
|
||||
# *** one op LoadOps ***
|
||||
|
||||
def _realize_contiguous(buffer: LazyBuffer, src: LazyBuffer) -> None:
|
||||
# this is just a copy now, if it's not a copy schedule will handle it
|
||||
buffer.realized = src.realized
|
||||
assert buffer.dtype == src.dtype, f"contiguous dtype mismatch, expecting {buffer.dtype}, got {src.dtype}"
|
||||
|
||||
def _realize_from(buffer: LazyBuffer, src: LazyBuffer) -> None:
|
||||
if DEBUG >= 3: print(f"*** copy {buffer.device} <- {src.device} size {src.realized.size} dtype {src.realized.dtype}")
|
||||
# TODO: make this generic
|
||||
if isinstance(src.realized, RawDiskBuffer) and issubclass(Device[buffer.device].buffer, RawBufferMapped):
|
||||
assert all_int(buffer.shape), "does not support symbolic shape"
|
||||
buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
|
||||
src.prepare_transfer().readinto(cast(RawBufferMapped, buffer.realized)._buffer())
|
||||
elif isinstance(src.realized, RawBufferTransfer) and issubclass(Device[buffer.device].buffer, RawBufferTransfer) and P2P >= 1:
|
||||
buffer.realized = cast(RawBufferTransfer, Device[buffer.device].buffer).transfer(src.realized, buffer.shape, buffer.dtype, **buffer._device_extra_args())
|
||||
else:
|
||||
buffer.realized = Device[buffer.device].buffer.fromCPU(src.toCPU(), **buffer._device_extra_args())
|
||||
|
||||
# *** n op LoadOps ***
|
||||
|
||||
def _realize_custom(buffer: LazyBuffer, *inputs: LazyBuffer) -> None:
|
||||
buffer.realized = buffer.op.arg(buffer, *inputs)
|
||||
|
||||
LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = {
|
||||
LoadOps.CONTIGUOUS: _realize_contiguous,
|
||||
LoadOps.CUSTOM: _realize_custom,
|
||||
LoadOps.FROM: _realize_from,
|
||||
LoadOps.EMPTY: _realize_empty,
|
||||
LoadOps.RAND: _realize_rand,
|
||||
LoadOps.CONTIGUOUS: _realize_contiguous,
|
||||
LoadOps.FROM: _realize_from,
|
||||
LoadOps.CUSTOM: _realize_custom,
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ class Tensor:
|
||||
data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
|
||||
else: raise RuntimeError(f"can't create Tensor from {data}")
|
||||
|
||||
self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
|
||||
self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data.contiguous())
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad else None)!r}>"
|
||||
|
||||
Reference in New Issue
Block a user