mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-15 09:33:03 +08:00
* jit: graphing as rewrite rule * f * +metal,cuda * x * cl * x * x * simpler * f * m * x * revert? * revert2 * back * back * t * x * m * x * c * x * l * x * comment * smaller * rv * x * x
45 lines
2.0 KiB
Python
45 lines
2.0 KiB
Python
import unittest
|
|
from tinygrad import Device, Tensor
|
|
from tinygrad.engine.jit import TinyJit
|
|
from tinygrad.uop.ops import UOp, Ops
|
|
from tinygrad.dtype import dtypes
|
|
from tinygrad.runtime.graph.hcq import HCQGraph
|
|
from tinygrad.runtime.support.hcq import HCQCompiled
|
|
from tinygrad.runtime.support.usb import USBMMIOInterface
|
|
from test.mockgpu.usb import MockUSB
|
|
|
|
@unittest.skipUnless(isinstance(Device[Device.DEFAULT], HCQCompiled), "HCQ device required to run")
class TestHCQUnit(unittest.TestCase):
  """Unit-level checks for HCQGraph helpers that drive the graph API directly instead of through a full model run."""

  @unittest.skipIf(Device.DEFAULT == "CPU", "requires non-CPU HCQ device")
  def test_supports_exec_item(self):
    """HCQGraph.supports_exec_item must stop accepting CPU work once the GPU timeline signal is backed by USB MMIO.

    With the default (local MMIO) signal view, a GPU graph accepts a GPU call, a CPU call, and a GPU+CPU device list.
    After the signal's backing view is swapped for a USBMMIOInterface, the GPU-only case is still accepted but both
    cases that involve the CPU are rejected. The original view is always restored, even if an assertion fails.

    Uses unittest assert methods rather than bare `assert`, so the checks still run under `python -O` and report
    which case failed.
    """
    d0, cpu_dev = Device[Device.DEFAULT], Device["CPU"]

    # Run a TinyJit function that touches both the default device and CPU a few times.
    # NOTE(review): its outputs are not used below; presumably this only makes sure both devices are fully
    # initialized before their graph state is inspected -- confirm whether it is still required.
    @TinyJit
    def f(inp, inp_cpu):
      return (inp + 1.0).contiguous().realize(), (inp_cpu + 1.0).contiguous().realize()
    inp, inp_cpu = Tensor.randn(10, 10, device=Device.DEFAULT).realize(), Tensor.randn(10, 10, device="CPU").realize()
    for _ in range(5): f(inp, inp_cpu)

    # construct minimal CALL UOps for supports_exec_item
    gpu_call = UOp(Ops.SINK).call(UOp.new_buffer(Device.DEFAULT, 1, dtypes.float))
    cpu_call = UOp(Ops.SINK).call(UOp.new_buffer("CPU", 1, dtypes.float))
    gpu_devs = [d0]

    # local MMIO: GPU works alone and with CPU in batch (cpu_support=True)
    self.assertIs(HCQGraph.supports_exec_item(gpu_devs, gpu_call), True, "local MMIO: GPU call on GPU graph should be supported")
    self.assertIs(HCQGraph.supports_exec_item(gpu_devs, cpu_call), True, "local MMIO: CPU call on GPU graph should be supported")
    self.assertIs(HCQGraph.supports_exec_item(gpu_devs + [cpu_dev], gpu_call), True, "local MMIO: GPU+CPU batch should be supported")

    # USB MMIO: GPU-only still works, but CPU batching must be rejected (cpu_support=False)
    orig_view = d0.timeline_signal.base_buf.view
    try:
      d0.timeline_signal.base_buf.view = USBMMIOInterface(MockUSB(bytearray(256)), 0, 16, fmt='B')
      self.assertIs(HCQGraph.supports_exec_item(gpu_devs, gpu_call), True, "USB MMIO: GPU call on GPU graph should be supported")
      self.assertIs(HCQGraph.supports_exec_item(gpu_devs, cpu_call), False, "USB MMIO: CPU call must be rejected")
      self.assertIs(HCQGraph.supports_exec_item(gpu_devs + [cpu_dev], gpu_call), False, "USB MMIO: GPU+CPU batch must be rejected")
    finally:
      # the device object is shared process-wide; always put the real signal view back
      d0.timeline_signal.base_buf.view = orig_view
|
|
|
|
# Run the tests in this module when the file is executed directly rather than collected by a test runner.
if __name__ == "__main__": unittest.main()
|