improved caching for pointer arithmetics in ptx (#3922)

* improved caching for pointer arithmetics

* Add test for pointer arithmetics caching

* Refactor test
This commit is contained in:
Szymon Ożóg
2024-04-04 16:33:48 +02:00
committed by GitHub
parent 68fe3527f1
commit ba118abfec
2 changed files with 42 additions and 12 deletions

View File

@@ -2,6 +2,7 @@ from typing import Optional, Tuple, Any, List
import unittest, math
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.device import Buffer, Device, CompiledASTRunner
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
@@ -232,5 +233,26 @@ class TestLocalAccess(unittest.TestCase):
sres = uop(uops, UOps.LOAD, dtypes.int32, (smem, ofs))
self.assertEqual(_test_uops_result(dtypes.int32, uops, sres), 42)
@unittest.skipUnless(Device.DEFAULT in {"CUDA"} and getenv("PTX"), "This only tests assembly backends")
class TestAssembly(unittest.TestCase):
def test_pointer_arithmetics_caching(self):
uops = UOpGraph()
u1 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple(), (0, 'data0', True))
u2 = uops.add(UOps.SPECIAL, dtypes.int, tuple(), (0, 'gidx0', 9))
u3 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=42)
u4 = uops.add(UOps.ALU, dtypes.int, (u2, u3), BinaryOps.MUL)
u5 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=0)
u6 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=1)
u7 = uops.add(UOps.ALU, dtypes.int, (u4, u5), BinaryOps.ADD)
u8 = uops.add(UOps.ALU, dtypes.int, (u4, u6), BinaryOps.ADD)
u9 = uops.add(UOps.LOAD, dtypes.int, (u1, u7))
u10 = uops.add(UOps.LOAD, dtypes.int, (u1, u8))
_uops_to_prg(uops)
self.assertEqual(u9.vin[0], u10.vin[0])
self.assertEqual(u9.vin[1].uop, UOps.CONST)
self.assertEqual(u9.vin[1].arg, u5.arg*dtypes.float.itemsize)
self.assertEqual(u10.vin[1].uop, UOps.CONST)
self.assertEqual(u10.vin[1].arg, u6.arg*dtypes.float.itemsize)
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -13,6 +13,26 @@ def render_val(x, dtype):
return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
def ptr_ar(root, uops):
assert root.arg in {'.shared', '.global', None}
if root.arg is None: root.arg = '.shared' if root.vin[0].uop is UOps.DEFINE_LOCAL else '.global' # move this to the argL
val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=root.vin[0].dtype.itemsize, insert_before=uops.uops.index(root))
if root.vin[1].uop is UOps.ALU and root.vin[1].arg in [BinaryOps.ADD, BinaryOps.SUB] and root.vin[1].vin[1].uop is UOps.CONST:
offset = uops.add(UOps.ALU, dtypes.int, (root.vin[1].vin[0], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
offset = uops.add(UOps.CAST, dtypes.uint64, (offset,), insert_before=uops.uops.index(root))
cache = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], offset), arg=BinaryOps.ADD, insert_before=uops.uops.index(root))
ptr = uops.add(UOps.ALU, dtypes.int, (root.vin[1].vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
if root.vin[1].arg == BinaryOps.SUB: ptr = uops.add(UOps.ALU, dtypes.int, (ptr,), arg=UnaryOps.NEG, insert_before=uops.uops.index(root))
root.vin = (cache, ptr) + root.vin[2:]
else:
ptr = uops.add(UOps.ALU, dtypes.int, (root.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
if ptr.uop is UOps.CONST: root.vin = (root.vin[0], ptr) + root.vin[2:]
else:
zero = uops.add(UOps.CONST, dtypes.int, tuple(), arg=0, cachable=False, insert_before=uops.uops.index(root))
bptr = uops.add(UOps.CAST, dtypes.uint64, (ptr,), insert_before=uops.uops.index(root))
fptr = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], bptr), arg=BinaryOps.ADD, insert_before=uops.uops.index(root))
root.vin = (fptr, zero) + root.vin[2:]
class AssemblyLanguage(NamedTuple):
kernel_prefix: str = ""
barrier: str = ""
@@ -43,18 +63,6 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
kernel:List[str] = []
bufs = []
def ptr_ar(root, uops):
assert root.arg in {'.shared', '.global', None}
if root.arg is None: root.arg = '.shared' if root.vin[0].uop is UOps.DEFINE_LOCAL else '.global' # move this to the argL
val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=root.vin[0].dtype.itemsize, insert_before=uops.uops.index(root))
ptr = uops.add(UOps.ALU, dtypes.int, (root.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
if ptr.uop is UOps.CONST: root.vin = (root.vin[0], ptr) + root.vin[2:]
else:
zero = uops.add(UOps.CONST, dtypes.int, tuple(), arg=0, cachable=False, insert_before=uops.uops.index(root))
bptr = uops.add(UOps.CAST, dtypes.uint64, (ptr,), insert_before=uops.uops.index(root))
fptr = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], bptr), arg=BinaryOps.ADD, insert_before=uops.uops.index(root))
root.vin = (fptr, zero) + root.vin[2:]
matcher = PatternMatcher([
({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"dtype": dtypes.bool},{})},
lambda root: UOp(UOps.ALU, dtypes.bool, (UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR),), UnaryOps.NEG)),