schedule the loadops like everything else (#1964)

* schedule the loadops like everything else

* unify loadops with other things we schedule

* delete all the ops

* fix symbolic jit
This commit is contained in:
George Hotz
2023-10-04 02:36:04 -07:00
committed by GitHub
parent fb4d830a2a
commit 0945848b5f
4 changed files with 47 additions and 40 deletions

View File

@@ -135,7 +135,7 @@ assert len(lazyop.src) == 2
# again, a LazyOp AST is like a GPU kernel: you have to copy the data onto the device first
assert lazyop.src[0].op.op == LoadOps.FROM
assert lazyop.src[0].op.src[0].device == "CPU"
assert lazyop.src[0].op.src[0].realized._buf[0] == 2, "the src of the FROM LazyOP is a LazyBuffer on the CPU holding [2.]"
assert lazyop.src[0].op.src[0].op.src[0].realized._buf[0] == 2, "the src of the FROM LazyOP is a LazyBuffer on the CPU holding [2.]"
assert result.lazydata.realized is None, "the LazyBuffer is not realized yet"
# now we realize the LazyBuffer

View File

@@ -166,7 +166,6 @@ class LazyBuffer:
if self.optype is MovementOps: return self.base.schedule(seen)
op = self.op if self.op.op != LoadOps.CONTIGUOUS else LazyOp(UnaryOps.NOOP, self.op.src)
if op.op in LoadOps: return [(self.op, self, ())]
if self.optype is BinaryOps: op = _ast_binaryops(op, self.shape)
elif self.optype is ReduceOps: op = _ast_reduceops(op)
@@ -177,21 +176,27 @@ class LazyBuffer:
else: op = LazyOp(UnaryOps.CAST, (op,), (dtypes.float32, False))
self.dtype = dtypes.float32
# contiguous can be a copy. must do this after the image hack
if self.op.op == LoadOps.CONTIGUOUS:
src = cast(LazyBuffer, self.op.src[0])
if src.st.contiguous and src.st.size() == src.base.st.size() and not src.is_unrealized_const():
return src.schedule(seen) + [(self.op, self, ())]
# realize the past and exec the AST
ret = []
for x in op.buffers: ret += x.schedule(seen)
# TODO: this belongs in the schedule in some way
self.var_vals = dict(sorted(merge_dicts([buf.var_vals for buf in op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
self.var_vals = dict(sorted(merge_dicts([self.var_vals] + [buf.var_vals for buf in op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
# contiguous can be a copy. must do this after the image hack
if self.op.op == LoadOps.CONTIGUOUS:
src = cast(LazyBuffer, self.op.src[0])
if src.st.contiguous and src.st.size() == src.base.st.size() and not src.is_unrealized_const():
return ret + [(self.op, self, (src,))]
# run the ast and log the op
op, base_bufs = _replace_bufferops(op)
# confirm the LoadOps are contiguous and in order
if op.op in LoadOps:
for i,s in enumerate(op.src):
assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
return ret + [(op, self, tuple(base_bufs))]
def realize(self:LazyBuffer) -> LazyBuffer:

View File

@@ -19,38 +19,16 @@ def run_schedule(schedule:List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]
from extra.utils import print_tree # type: ignore
print_tree(op)
if op.op in LoadOps:
LOAD_OPS_DISPATCHER[cast(LoadOps, op.op)](out)
# TODO: why can't we delete these ops?
# NOTE: load op buffers are promised to be in order by the scheduler
LOAD_OPS_DISPATCHER[cast(LoadOps, op.op)](out, *buffers)
else:
out.realized = Device[out.device].exec_ast(op, output=out, inputs=[x.realized for x in buffers], var_vals=out.var_vals, **out._device_extra_args())
del out.op
for v in out.views: del v.op
del out.op
for v in out.views: del v.op
assert out.realized and isinstance(out.realized, Device[out.device].buffer), f"device mismatch on realized got {type(out.realized)} expected {out.device}"
assert out.realized.dtype == out.dtype, "realized dtype is incorrect"
def _realize_contiguous(buffer: LazyBuffer) -> None:
# this is just a copy now, if it's not a copy schedule will handle it
src = cast(LazyBuffer, buffer.op.src[0])
buffer.realized = src.realized
assert buffer.dtype == src.dtype, f"contiguous dtype mismatch, expecting {buffer.dtype}, got {src.dtype}"
def _realize_custom(buffer: LazyBuffer) -> None:
# this needs to immediately realize
buffer.realized = buffer.op.arg(buffer, *[x.realize() for x in buffer.op.src])
def _realize_from(buffer: LazyBuffer) -> None:
rawbuf = cast(LazyBuffer, buffer.op.src[0]).contiguous().realize()
assert rawbuf.realized, "realize failed?"
if DEBUG >= 3: print(f"*** copy {buffer.device} <- {rawbuf.device} size {rawbuf.realized.size} dtype {rawbuf.realized.dtype}")
# TODO: make this generic
if isinstance(rawbuf.realized, RawDiskBuffer) and issubclass(Device[buffer.device].buffer, RawBufferMapped):
assert all_int(buffer.shape), "does not support symbolic shape"
buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
rawbuf.prepare_transfer().readinto(cast(RawBufferMapped, buffer.realized)._buffer())
elif isinstance(rawbuf.realized, RawBufferTransfer) and issubclass(Device[buffer.device].buffer, RawBufferTransfer) and P2P >= 1:
buffer.realized = cast(RawBufferTransfer, Device[buffer.device].buffer).transfer(rawbuf.realized, buffer.shape, buffer.dtype, **buffer._device_extra_args())
else:
buffer.realized = Device[buffer.device].buffer.fromCPU(rawbuf.toCPU(), **buffer._device_extra_args())
# *** zero op LoadOps ***
def _realize_empty(buffer: LazyBuffer) -> None:
assert all_int(buffer.shape), "does not support symbolic shape"
@@ -61,10 +39,34 @@ def _realize_rand(buffer: LazyBuffer) -> None:
rng = np.random.default_rng(buffer.op.arg)
buffer.realized = Device[buffer.device].buffer.fromCPU(rng.random(size=prod(buffer.shape), dtype=np.float32).astype(dtype=buffer.dtype.np, copy=False), **buffer._device_extra_args())
# *** one op LoadOps ***
def _realize_contiguous(buffer: LazyBuffer, src: LazyBuffer) -> None:
# this is just a copy now; if it's not a copy, schedule will handle it
buffer.realized = src.realized
assert buffer.dtype == src.dtype, f"contiguous dtype mismatch, expecting {buffer.dtype}, got {src.dtype}"
def _realize_from(buffer: LazyBuffer, src: LazyBuffer) -> None:
if DEBUG >= 3: print(f"*** copy {buffer.device} <- {src.device} size {src.realized.size} dtype {src.realized.dtype}")
# TODO: make this generic
if isinstance(src.realized, RawDiskBuffer) and issubclass(Device[buffer.device].buffer, RawBufferMapped):
assert all_int(buffer.shape), "does not support symbolic shape"
buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
src.prepare_transfer().readinto(cast(RawBufferMapped, buffer.realized)._buffer())
elif isinstance(src.realized, RawBufferTransfer) and issubclass(Device[buffer.device].buffer, RawBufferTransfer) and P2P >= 1:
buffer.realized = cast(RawBufferTransfer, Device[buffer.device].buffer).transfer(src.realized, buffer.shape, buffer.dtype, **buffer._device_extra_args())
else:
buffer.realized = Device[buffer.device].buffer.fromCPU(src.toCPU(), **buffer._device_extra_args())
# *** n op LoadOps ***
def _realize_custom(buffer: LazyBuffer, *inputs: LazyBuffer) -> None:
buffer.realized = buffer.op.arg(buffer, *inputs)
LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = {
LoadOps.CONTIGUOUS: _realize_contiguous,
LoadOps.CUSTOM: _realize_custom,
LoadOps.FROM: _realize_from,
LoadOps.EMPTY: _realize_empty,
LoadOps.RAND: _realize_rand,
LoadOps.CONTIGUOUS: _realize_contiguous,
LoadOps.FROM: _realize_from,
LoadOps.CUSTOM: _realize_custom,
}

View File

@@ -70,7 +70,7 @@ class Tensor:
data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
else: raise RuntimeError(f"can't create Tensor from {data}")
self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data.contiguous())
def __repr__(self):
return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad else None)!r}>"