mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-15 09:33:03 +08:00
use supportsFamily API to detect GPU generation instead of parsing ICB debug description strings. also adds metal4.0 compiler target.
67 lines
2.4 KiB
Python
67 lines
2.4 KiB
Python
import unittest
|
|
from tinygrad.device import CompileError, Device, BufferSpec
|
|
if Device.DEFAULT=="METAL":
|
|
from tinygrad.runtime.ops_metal import MetalDevice, MetalCompiler, MetalProgram
|
|
@unittest.skipIf(Device.DEFAULT!="METAL", "Metal support required")
class TestMetal(unittest.TestCase):
  # Tests that talk to the Metal runtime directly (allocator, compiler, program
  # loading, device properties). The whole class is skipped unless METAL is the
  # default device, because MetalDevice/MetalCompiler/MetalProgram are only
  # imported in that case.

  def test_alloc_oom(self):
    # Requesting 10**19 bytes must surface as MemoryError.
    device = MetalDevice("metal")
    with self.assertRaises(MemoryError):
      device.allocator.alloc(10000000000000000000)

  def test_compile_error(self):
    # Source that is not Metal must be reported as CompileError.
    compiler = MetalCompiler()
    with self.assertRaises(CompileError):
      compiler.compile("this is not valid metal")

  def test_compile_success(self):
    # A well-formed kernel must compile and hand back a result.
    compiler = MetalCompiler()
    ret = compiler.compile("""
#include <metal_stdlib>
using namespace metal;
kernel void E_4n1(device int* data0, const device int* data1, const device int* data2,
uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
int val0 = *(data1+0);
int val1 = *(data1+1);
int val2 = *(data1+2);
int val3 = *(data1+3);
int val4 = *(data2+0);
int val5 = *(data2+1);
int val6 = *(data2+2);
int val7 = *(data2+3);
*(data0+0) = (val0+val4);
*(data0+1) = (val1+val5);
*(data0+2) = (val2+val6);
*(data0+3) = (val3+val7);
}
""")
    # assertIsNotNone (rather than a bare assert) still runs under `python -O`
    # and reports through the unittest failure machinery.
    self.assertIsNotNone(ret)

  def test_failed_newLibraryWithData(self):
    # Loading a truncated compiled program must raise RuntimeError.
    device = MetalDevice("metal")
    compiler = MetalCompiler()
    compiled = compiler.compile("""
#include <metal_stdlib>
kernel void r_5(device int* data0, const device int* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]){
data0[0] = 0;
}
""")
    # Corrupt the compiled program *before* entering assertRaises so that only
    # the MetalProgram construction can satisfy the expected RuntimeError.
    compiled = compiled[:40]
    with self.assertRaises(RuntimeError):
      MetalProgram(device, "r_5", compiled)

  def test_free(self):
    # The allocated size reported by the system device must grow by exactly
    # `size` after alloc and return to the starting value after free.
    # NOTE(review): nolru=True presumably keeps the buffer out of an LRU cache
    # so the free reaches the device immediately -- confirm against the allocator.
    size = 2**16
    device = Device['METAL']
    before = device.sysdevice.currentAllocatedSize()

    buf = device.allocator.alloc(size, BufferSpec(nolru=True))
    self.assertEqual(curr:=device.sysdevice.currentAllocatedSize(), before+size, msg=f"{curr=} - {before=}")
    device.allocator.free(buf, buf.size, BufferSpec(nolru=True))
    self.assertEqual(curr:=device.sysdevice.currentAllocatedSize(), before, msg=f"{curr=} - {before=}")

  def test_gpu_family(self):
    # The detected GPU family must lie in the range (0, 15].
    device = Device['METAL']
    self.assertGreater(device.gpu_family, 0)
    self.assertLessEqual(device.gpu_family, 15)