mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
improved caching for pointer arithmetics in ptx (#3922)
* improved caching for pointer arithmetics * Add test for pointer arithmetics caching * Refactor test
This commit is contained in:
@@ -2,6 +2,7 @@ from typing import Optional, Tuple, Any, List
|
||||
import unittest, math
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType
|
||||
from tinygrad.device import Buffer, Device, CompiledASTRunner
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
@@ -232,5 +233,26 @@ class TestLocalAccess(unittest.TestCase):
|
||||
sres = uop(uops, UOps.LOAD, dtypes.int32, (smem, ofs))
|
||||
self.assertEqual(_test_uops_result(dtypes.int32, uops, sres), 42)
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT in {"CUDA"} and getenv("PTX"), "This only tests assembly backends")
|
||||
class TestAssembly(unittest.TestCase):
|
||||
def test_pointer_arithmetics_caching(self):
|
||||
uops = UOpGraph()
|
||||
u1 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple(), (0, 'data0', True))
|
||||
u2 = uops.add(UOps.SPECIAL, dtypes.int, tuple(), (0, 'gidx0', 9))
|
||||
u3 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=42)
|
||||
u4 = uops.add(UOps.ALU, dtypes.int, (u2, u3), BinaryOps.MUL)
|
||||
u5 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=0)
|
||||
u6 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=1)
|
||||
u7 = uops.add(UOps.ALU, dtypes.int, (u4, u5), BinaryOps.ADD)
|
||||
u8 = uops.add(UOps.ALU, dtypes.int, (u4, u6), BinaryOps.ADD)
|
||||
u9 = uops.add(UOps.LOAD, dtypes.int, (u1, u7))
|
||||
u10 = uops.add(UOps.LOAD, dtypes.int, (u1, u8))
|
||||
_uops_to_prg(uops)
|
||||
self.assertEqual(u9.vin[0], u10.vin[0])
|
||||
self.assertEqual(u9.vin[1].uop, UOps.CONST)
|
||||
self.assertEqual(u9.vin[1].arg, u5.arg*dtypes.float.itemsize)
|
||||
self.assertEqual(u10.vin[1].uop, UOps.CONST)
|
||||
self.assertEqual(u10.vin[1].arg, u6.arg*dtypes.float.itemsize)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -13,6 +13,26 @@ def render_val(x, dtype):
|
||||
return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
|
||||
return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
|
||||
|
||||
def ptr_ar(root, uops):
|
||||
assert root.arg in {'.shared', '.global', None}
|
||||
if root.arg is None: root.arg = '.shared' if root.vin[0].uop is UOps.DEFINE_LOCAL else '.global' # move this to the argL
|
||||
val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=root.vin[0].dtype.itemsize, insert_before=uops.uops.index(root))
|
||||
if root.vin[1].uop is UOps.ALU and root.vin[1].arg in [BinaryOps.ADD, BinaryOps.SUB] and root.vin[1].vin[1].uop is UOps.CONST:
|
||||
offset = uops.add(UOps.ALU, dtypes.int, (root.vin[1].vin[0], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
|
||||
offset = uops.add(UOps.CAST, dtypes.uint64, (offset,), insert_before=uops.uops.index(root))
|
||||
cache = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], offset), arg=BinaryOps.ADD, insert_before=uops.uops.index(root))
|
||||
ptr = uops.add(UOps.ALU, dtypes.int, (root.vin[1].vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
|
||||
if root.vin[1].arg == BinaryOps.SUB: ptr = uops.add(UOps.ALU, dtypes.int, (ptr,), arg=UnaryOps.NEG, insert_before=uops.uops.index(root))
|
||||
root.vin = (cache, ptr) + root.vin[2:]
|
||||
else:
|
||||
ptr = uops.add(UOps.ALU, dtypes.int, (root.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
|
||||
if ptr.uop is UOps.CONST: root.vin = (root.vin[0], ptr) + root.vin[2:]
|
||||
else:
|
||||
zero = uops.add(UOps.CONST, dtypes.int, tuple(), arg=0, cachable=False, insert_before=uops.uops.index(root))
|
||||
bptr = uops.add(UOps.CAST, dtypes.uint64, (ptr,), insert_before=uops.uops.index(root))
|
||||
fptr = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], bptr), arg=BinaryOps.ADD, insert_before=uops.uops.index(root))
|
||||
root.vin = (fptr, zero) + root.vin[2:]
|
||||
|
||||
class AssemblyLanguage(NamedTuple):
|
||||
kernel_prefix: str = ""
|
||||
barrier: str = ""
|
||||
@@ -43,18 +63,6 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
|
||||
kernel:List[str] = []
|
||||
bufs = []
|
||||
|
||||
def ptr_ar(root, uops):
|
||||
assert root.arg in {'.shared', '.global', None}
|
||||
if root.arg is None: root.arg = '.shared' if root.vin[0].uop is UOps.DEFINE_LOCAL else '.global' # move this to the argL
|
||||
val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=root.vin[0].dtype.itemsize, insert_before=uops.uops.index(root))
|
||||
ptr = uops.add(UOps.ALU, dtypes.int, (root.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
|
||||
if ptr.uop is UOps.CONST: root.vin = (root.vin[0], ptr) + root.vin[2:]
|
||||
else:
|
||||
zero = uops.add(UOps.CONST, dtypes.int, tuple(), arg=0, cachable=False, insert_before=uops.uops.index(root))
|
||||
bptr = uops.add(UOps.CAST, dtypes.uint64, (ptr,), insert_before=uops.uops.index(root))
|
||||
fptr = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], bptr), arg=BinaryOps.ADD, insert_before=uops.uops.index(root))
|
||||
root.vin = (fptr, zero) + root.vin[2:]
|
||||
|
||||
matcher = PatternMatcher([
|
||||
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"dtype": dtypes.bool},{})},
|
||||
lambda root: UOp(UOps.ALU, dtypes.bool, (UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR),), UnaryOps.NEG)),
|
||||
|
||||
Reference in New Issue
Block a user