mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
fix DEBUG=2 output for copy runners [pr] (#8579)
* fix DEBUG=2 output for copy runners [pr] * itemsize is constant
This commit is contained in:
@@ -7,6 +7,7 @@ from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.ops import Ops, UOp
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
|
||||
from tinygrad.device import Device
|
||||
|
||||
def flops_mem(uops, ignore_indexing=False):
|
||||
est = Estimates.from_uops(uops, ignore_indexing)
|
||||
@@ -64,6 +65,15 @@ class TestMemoryCount(unittest.TestCase):
|
||||
_, mem = get_stats(a.assign(a+a))
|
||||
self.assertEqual(mem, 1024*1024*2) # 1 read + 1 write
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "CLANG", "test copy to CLANG from other device")
|
||||
def test_copyout(self):
|
||||
a = Tensor.empty(32, dtype=dtypes.uint8).to("CLANG")
|
||||
_, mem = get_stats(a)
|
||||
self.assertEqual(mem, 32*1)
|
||||
a = Tensor.empty(32, dtype=dtypes.uint32).to("CLANG")
|
||||
_, mem = get_stats(a)
|
||||
self.assertEqual(mem, 32*4)
|
||||
|
||||
# NOTE: this still isn't testing unroll using the acc
|
||||
@unittest.skipUnless(getenv("PYTHON"), "only run test on emulated tensor cores")
|
||||
class TestUOpsStatsMatmulHalf(unittest.TestCase):
|
||||
|
||||
@@ -141,9 +141,9 @@ class ExecItem:
|
||||
si_lowerer = PatternMatcher([
|
||||
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: (runner:=get_runner(ctx[0].device, sink), [ctx[x] for x in runner.p.globals])),
|
||||
(UPat(Ops.BUFFER_VIEW), lambda ctx: (ViewOp(ctx[0]), list(ctx))),
|
||||
(UPat(Ops.COPY, name="copy"), lambda ctx,copy: ((BufferXfer(copy.size, ctx[0].device, ctx[1].device) \
|
||||
(UPat(Ops.COPY, name="copy"), lambda ctx,copy: ((BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
|
||||
if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \
|
||||
else BufferCopy(copy.size, ctx[0].device, ctx[1].device)), list(ctx))),
|
||||
else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device)), list(ctx))),
|
||||
])
|
||||
def lower_schedule_item(si:ScheduleItem) -> ExecItem: return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user