mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 16:37:04 +08:00
47 lines
1.5 KiB
Python
47 lines
1.5 KiB
Python
import unittest
|
|
from unittest.mock import MagicMock
|
|
from tinygrad import Device
|
|
from tinygrad.uop.ops import Ops, UOp
|
|
from tinygrad.dtype import dtypes
|
|
|
|
@unittest.skipUnless(Device.DEFAULT == "METAL", "Metal device required to run")
|
|
class TestMetalGraph(unittest.TestCase):
|
|
def setUp(self):
|
|
from tinygrad.runtime.graph.metal import MetalGraph
|
|
self.MetalGraph = MetalGraph
|
|
self.dev = Device[Device.DEFAULT]
|
|
|
|
def metal_buf(self, offset):
|
|
buf = MagicMock()
|
|
if offset > 0:
|
|
buf.op = Ops.SLICE
|
|
src = MagicMock()
|
|
src.dtype = dtypes.uint8
|
|
buf.src = (src, UOp.const(dtypes.weakint, offset))
|
|
buf.dtype = dtypes.uint8
|
|
else:
|
|
buf.op = Ops.BUFFER
|
|
buf.device = Device.DEFAULT
|
|
return buf
|
|
|
|
def call(self, *bufs):
|
|
c = MagicMock()
|
|
c.src = (MagicMock(op=Ops.PROGRAM),) + tuple(bufs)
|
|
return c
|
|
|
|
def test_supports_uop_normal_offset(self):
|
|
assert self.MetalGraph.supports_uop([self.dev], self.call(self.metal_buf(0), self.metal_buf(100), self.metal_buf(0xFFFFFFFF))) is True
|
|
|
|
def test_supports_uop_overflow_offset(self):
|
|
assert self.MetalGraph.supports_uop([self.dev], self.call(self.metal_buf(0), self.metal_buf(0x100000000))) is False
|
|
|
|
def test_supports_uop_nonmetal_buf(self):
|
|
# non-SLICE ops should not be checked for offset
|
|
buf = MagicMock()
|
|
buf.op = Ops.BUFFER
|
|
buf.device = Device.DEFAULT
|
|
self.MetalGraph.supports_uop([self.dev], self.call(buf))
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|