remove _arg_int32 internal type (#2767)

in DEFINE_GLOBAL, PtrDtype(int32) is buffer and int32 is int
This commit is contained in:
chenyu
2023-12-14 14:17:14 -05:00
committed by GitHub
parent 8a2a2257b4
commit 5235cdee3d
8 changed files with 26 additions and 30 deletions

View File

@@ -1,7 +1,7 @@
from typing import Dict, List, Final, Callable, DefaultDict
from collections import defaultdict
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Op
from tinygrad.helpers import DType, dtypes, ImageDType, DEBUG, getenv
from tinygrad.helpers import DType, PtrDType, dtypes, ImageDType, DEBUG, getenv
from tinygrad.codegen.linearizer import UOp, UOps
from triton.compiler import compile as triton_compile
import linecache
@@ -9,7 +9,7 @@ import math
import re
triton_dtypes = {dtypes.double: "tl.float64", dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.bool: "tl.int1", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64", dtypes.uint32: "tl.uint32", dtypes.uint64: "tl.uint64", dtypes.int16: "tl.int16", dtypes.uint16: "tl.uint16"}
signature_dtypes = {dtypes.double: "*fp64",dtypes.float32: "*fp32", dtypes.float16: "*fp16", dtypes.bool: "*i8", dtypes.int8: "*i1", dtypes.uint8: "*u8", dtypes._arg_int32: "i32", dtypes.int32: "*i32", dtypes.int64: "*i64", dtypes.uint32: "*u32", dtypes.uint64: "*u64", dtypes.int16: "*i16", dtypes.uint16: "*u16"}
signature_dtypes = {dtypes.double: "fp64",dtypes.float32: "fp32", dtypes.float16: "fp16", dtypes.bool: "i8", dtypes.int8: "i1", dtypes.uint8: "u8", dtypes.int32: "i32", dtypes.int64: "i64", dtypes.uint32: "u32", dtypes.uint64: "u64", dtypes.int16: "i16", dtypes.uint16: "u16"}
def next_power_of_2(x):
return 1 << (x - 1).bit_length()
@@ -98,7 +98,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
kk(f"{'if '+r[vin[3]]+': ' if len(vin)>3 else ''}tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) ")
elif uop == UOps.DEFINE_GLOBAL:
bufs.append(args)
signatures.append(signature_dtypes[args[1]])
signatures.append("*" if isinstance(args[1], PtrDType) else "" + signature_dtypes[args[1]])
r[u] = args[0]
elif uop == UOps.SPECIAL:
dims.append(args[1])

View File

@@ -236,8 +236,7 @@ class TestTypeSpec(unittest.TestCase):
assert Tensor.ones([2,3,0]).sum(2).dtype == Tensor.default_type
# assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int
# TODO: better way to write a set of core dtypes?
core_types = [d for d in DTYPES_DICT.values() if d not in [dtypes._arg_int32]]
core_types = list(DTYPES_DICT.values())
class TestTypePromotion(unittest.TestCase):
@given(st.sampled_from(core_types))
def test_self_promo_to_self(self, dtype):

View File

@@ -46,18 +46,18 @@ class TestUOps(unittest.TestCase):
def _equal(self, v1, v2):
if not (math.isnan(v1) and math.isnan(v2)): self.assertAlmostEqual(v1, v2, places=5) if v1.dtype != np.bool_ else self.assertEqual(v1, v2)
def _test_uop_fxn(self, bop, fxn, dt=dtypes.float32):
def _test_uop_fxn(self, bop, fxn, dt=PtrDType(dtypes.float32)):
for f in [_test_single_value, _test_single_value_const]:
for a in [-2.0, 0.0, 1.0]:
self._equal(f([a], bop, dt), fxn(a))
def _test_bop_fxn(self, bop, fxn, dt=dtypes.float32, no_b_zero=False):
def _test_bop_fxn(self, bop, fxn, dt=PtrDType(dtypes.float32), no_b_zero=False):
for f in [_test_single_value, _test_single_value_const]:
for a in [-2.0, 0.0, 1.0]:
for b in [-3.0, 1.0] + ([] if no_b_zero else [0.0]):
self._equal(f([a,b], bop, dt), fxn(a,b))
def _test_top_fxn(self, bop, fxn, dt=dtypes.float32):
def _test_top_fxn(self, bop, fxn, dt=PtrDType(dtypes.float32)):
for f in [_test_single_value, _test_single_value_const]:
for a in [-2.0, 0, 1]:
for b in [-3.0, 3.0]:
@@ -88,15 +88,15 @@ class TestFloatUOps(TestUOps):
# TODO: fix this on all the backends
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or getenv('ARM64', False), "only test for compiled backends, broken on some")
class TestNonFloatUOps(TestUOps):
def test_neg_int32(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a, dtypes.int32)
def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), dtypes.int32)
def test_sub_int32(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: int(a)-int(b), dtypes.int32)
def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), dtypes.int32)
def test_div_int32(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), dtypes.int32, no_b_zero=True)
def test_mod_int32(self): self._test_bop_fxn(BinaryOps.MOD, lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], dtypes.int32, no_b_zero=True)
def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b), dtypes.int32)
def test_neg_int32(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a, PtrDType(dtypes.int32))
def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), PtrDType(dtypes.int32))
def test_sub_int32(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: int(a)-int(b), PtrDType(dtypes.int32))
def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), PtrDType(dtypes.int32))
def test_div_int32(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), PtrDType(dtypes.int32), no_b_zero=True)
def test_mod_int32(self): self._test_bop_fxn(BinaryOps.MOD, lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], PtrDType(dtypes.int32), no_b_zero=True)
def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b), PtrDType(dtypes.int32))
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "no bool storage buffer on webgpu")
def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), dtypes.bool)
def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), PtrDType(dtypes.bool))
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -122,7 +122,7 @@ class TestSafetensors(unittest.TestCase):
def test_save_all_dtypes(self):
for dtype in dtypes.fields().values():
if dtype in [dtypes.bfloat16, dtypes._arg_int32]: continue # not supported in numpy
if dtype in [dtypes.bfloat16]: continue # not supported in numpy
path = temp("ones.safetensors")
ones = Tensor.rand((10,10), dtype=dtype)
safe_save(get_state_dict(ones), path)

View File

@@ -184,11 +184,11 @@ class Linearizer(Kernel):
# add global buffers
for i,buf in enumerate(self.bufs):
if isinstance(buf, MemBuffer):
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", buf.dtype)) # noqa: E501
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, dtype:=PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", dtype)) # noqa: E501
# add var vals
for var in vars_from_ast(self.ast):
assert var.expr is not None
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32))
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes.int32))
# define local buffers
for lb in self.local_alias.values():
self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size()))

View File

@@ -174,9 +174,6 @@ class dtypes:
# it has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
bfloat16: Final[DType] = DType(10, 2, "__bf16", None)
# NOTE: these are internal dtypes, should probably check for that
_arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)
# NOTE: these are image dtypes
@staticmethod
def imageh(shp): return ImageDType(100, 2, "imageh", np.float16, shp, dtypes.float32)

View File

@@ -3,7 +3,7 @@ import math, functools
from collections import defaultdict
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.helpers import ImageDType, dtypes, prod, DType, strip_parens, getenv
from tinygrad.helpers import ImageDType, dtypes, prod, DType, PtrDType, strip_parens, getenv
class CStyleLanguage(NamedTuple):
size_prefix: str = "int"
@@ -83,8 +83,8 @@ class CStyleLanguage(NamedTuple):
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else "" # noqa: E501
buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
self.arg_int_prefix if dtype == dtypes._arg_int32 else
("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix) for i,(name,dtype) in enumerate(bufs)]
("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
self.arg_int_prefix if dtype == dtypes.int else None) for i,(name,dtype) in enumerate(bufs)]
prg = ''.join([f"{self.kernel_prefix}void {f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''}{function_name}(",] +
[', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
[") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
@@ -351,7 +351,7 @@ class WGSLLanguage(CStyleLanguage):
local_size = local_size[::-1] if local_size else [1]
bind_it = iter(range(len(bufs)))
prg = "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
prg += "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) {'var<uniform>' if dtype == dtypes._arg_int32 else 'var<storage,read_write>'} {name}: {'i32' if dtype == dtypes._arg_int32 else f'array<{self.type_map[dtype]}>'};" for name,dtype in bufs]) # noqa: E501
prg += "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) {'var<storage,read_write>' if isinstance(dtype, PtrDType) else 'var<uniform>'} {name}: {f'array<{self.type_map[dtype]}>' if isinstance(dtype, PtrDType) else 'i32'};" for name,dtype in bufs]) # noqa: E501
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}" # noqa: E501
return prg

View File

@@ -1,7 +1,7 @@
from typing import Final, Dict, Callable, Any, List, Optional, Tuple
from llvmlite import ir
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.helpers import DType, dtypes
from tinygrad.helpers import DType, PtrDType, dtypes
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
@@ -33,7 +33,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(),
dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64),
dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16),
dtypes.int32: ir.IntType(32), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16),
dtypes.uint32:ir.IntType(32), dtypes.uint64:ir.IntType(64)}
def cast(bb, val, input_type, output_type, bitcast=False):
@@ -84,7 +84,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
# create llvm function
func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values()]
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if dt!=dtypes._arg_int32 else x for x,dt in func_dtypes]), name=function_name) # noqa: E501
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=function_name) # noqa: E501
for a in func.args:
if a.type.is_pointer: a.add_attribute("noalias")
@@ -99,7 +99,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type
for bufname,dtype in buf_to_dtype.items():
if dtype == dtypes._arg_int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
for u in uops:
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg