mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
remove _arg_int32 internal type (#2767)
in DEFINE_GLOBAL, PtrDType(int32) is a buffer and int32 is an int
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from typing import Dict, List, Final, Callable, DefaultDict
|
||||
from collections import defaultdict
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Op
|
||||
from tinygrad.helpers import DType, dtypes, ImageDType, DEBUG, getenv
|
||||
from tinygrad.helpers import DType, PtrDType, dtypes, ImageDType, DEBUG, getenv
|
||||
from tinygrad.codegen.linearizer import UOp, UOps
|
||||
from triton.compiler import compile as triton_compile
|
||||
import linecache
|
||||
@@ -9,7 +9,7 @@ import math
|
||||
import re
|
||||
|
||||
triton_dtypes = {dtypes.double: "tl.float64", dtypes.float32: "tl.float32", dtypes.float16: "tl.float16", dtypes.bool: "tl.int1", dtypes.int8: "tl.int8", dtypes.uint8: "tl.uint8", dtypes.int32: "tl.int32", dtypes.int64: "tl.int64", dtypes.uint32: "tl.uint32", dtypes.uint64: "tl.uint64", dtypes.int16: "tl.int16", dtypes.uint16: "tl.uint16"}
|
||||
signature_dtypes = {dtypes.double: "*fp64",dtypes.float32: "*fp32", dtypes.float16: "*fp16", dtypes.bool: "*i8", dtypes.int8: "*i1", dtypes.uint8: "*u8", dtypes._arg_int32: "i32", dtypes.int32: "*i32", dtypes.int64: "*i64", dtypes.uint32: "*u32", dtypes.uint64: "*u64", dtypes.int16: "*i16", dtypes.uint16: "*u16"}
|
||||
signature_dtypes = {dtypes.double: "fp64",dtypes.float32: "fp32", dtypes.float16: "fp16", dtypes.bool: "i8", dtypes.int8: "i1", dtypes.uint8: "u8", dtypes.int32: "i32", dtypes.int64: "i64", dtypes.uint32: "u32", dtypes.uint64: "u64", dtypes.int16: "i16", dtypes.uint16: "u16"}
|
||||
|
||||
def next_power_of_2(x):
  """Return 2 raised to the bit length of ``x - 1``.

  For a positive integer ``x`` this is the smallest power of two that is
  greater than or equal to ``x`` (exact powers of two map to themselves).
  ``x == 0`` yields 2, because ``(-1).bit_length()`` is 1.
  """
  exponent = (x - 1).bit_length()
  return 2 ** exponent
|
||||
@@ -98,7 +98,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
|
||||
kk(f"{'if '+r[vin[3]]+': ' if len(vin)>3 else ''}tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) ")
|
||||
elif uop == UOps.DEFINE_GLOBAL:
|
||||
bufs.append(args)
|
||||
signatures.append(signature_dtypes[args[1]])
|
||||
# The conditional must be parenthesized: a conditional expression binds looser than `+`, so the unparenthesized form evaluates to a bare "*" for pointer args and drops the element type.
signatures.append(("*" if isinstance(args[1], PtrDType) else "") + signature_dtypes[args[1]])
|
||||
r[u] = args[0]
|
||||
elif uop == UOps.SPECIAL:
|
||||
dims.append(args[1])
|
||||
|
||||
@@ -236,8 +236,7 @@ class TestTypeSpec(unittest.TestCase):
|
||||
assert Tensor.ones([2,3,0]).sum(2).dtype == Tensor.default_type
|
||||
# assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int
|
||||
|
||||
# TODO: better way to write a set of core dtypes?
|
||||
core_types = [d for d in DTYPES_DICT.values() if d not in [dtypes._arg_int32]]
|
||||
core_types = list(DTYPES_DICT.values())
|
||||
class TestTypePromotion(unittest.TestCase):
|
||||
@given(st.sampled_from(core_types))
|
||||
def test_self_promo_to_self(self, dtype):
|
||||
|
||||
@@ -46,18 +46,18 @@ class TestUOps(unittest.TestCase):
|
||||
def _equal(self, v1, v2):
|
||||
if not (math.isnan(v1) and math.isnan(v2)): self.assertAlmostEqual(v1, v2, places=5) if v1.dtype != np.bool_ else self.assertEqual(v1, v2)
|
||||
|
||||
def _test_uop_fxn(self, bop, fxn, dt=dtypes.float32):
|
||||
def _test_uop_fxn(self, bop, fxn, dt=PtrDType(dtypes.float32)):
|
||||
for f in [_test_single_value, _test_single_value_const]:
|
||||
for a in [-2.0, 0.0, 1.0]:
|
||||
self._equal(f([a], bop, dt), fxn(a))
|
||||
|
||||
def _test_bop_fxn(self, bop, fxn, dt=dtypes.float32, no_b_zero=False):
|
||||
def _test_bop_fxn(self, bop, fxn, dt=PtrDType(dtypes.float32), no_b_zero=False):
|
||||
for f in [_test_single_value, _test_single_value_const]:
|
||||
for a in [-2.0, 0.0, 1.0]:
|
||||
for b in [-3.0, 1.0] + ([] if no_b_zero else [0.0]):
|
||||
self._equal(f([a,b], bop, dt), fxn(a,b))
|
||||
|
||||
def _test_top_fxn(self, bop, fxn, dt=dtypes.float32):
|
||||
def _test_top_fxn(self, bop, fxn, dt=PtrDType(dtypes.float32)):
|
||||
for f in [_test_single_value, _test_single_value_const]:
|
||||
for a in [-2.0, 0, 1]:
|
||||
for b in [-3.0, 3.0]:
|
||||
@@ -88,15 +88,15 @@ class TestFloatUOps(TestUOps):
|
||||
# TODO: fix this on all the backends
|
||||
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or getenv('ARM64', False), "only test for compiled backends, broken on some")
|
||||
class TestNonFloatUOps(TestUOps):
|
||||
def test_neg_int32(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a, dtypes.int32)
|
||||
def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), dtypes.int32)
|
||||
def test_sub_int32(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: int(a)-int(b), dtypes.int32)
|
||||
def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), dtypes.int32)
|
||||
def test_div_int32(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), dtypes.int32, no_b_zero=True)
|
||||
def test_mod_int32(self): self._test_bop_fxn(BinaryOps.MOD, lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], dtypes.int32, no_b_zero=True)
|
||||
def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b), dtypes.int32)
|
||||
def test_neg_int32(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a, PtrDType(dtypes.int32))
|
||||
def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), PtrDType(dtypes.int32))
|
||||
def test_sub_int32(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: int(a)-int(b), PtrDType(dtypes.int32))
|
||||
def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), PtrDType(dtypes.int32))
|
||||
def test_div_int32(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), PtrDType(dtypes.int32), no_b_zero=True)
|
||||
def test_mod_int32(self): self._test_bop_fxn(BinaryOps.MOD, lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], PtrDType(dtypes.int32), no_b_zero=True)
|
||||
def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b), PtrDType(dtypes.int32))
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "no bool storage buffer on webgpu")
|
||||
def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), dtypes.bool)
|
||||
def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), PtrDType(dtypes.bool))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -122,7 +122,7 @@ class TestSafetensors(unittest.TestCase):
|
||||
|
||||
def test_save_all_dtypes(self):
|
||||
for dtype in dtypes.fields().values():
|
||||
if dtype in [dtypes.bfloat16, dtypes._arg_int32]: continue # not supported in numpy
|
||||
if dtype in [dtypes.bfloat16]: continue # not supported in numpy
|
||||
path = temp("ones.safetensors")
|
||||
ones = Tensor.rand((10,10), dtype=dtype)
|
||||
safe_save(get_state_dict(ones), path)
|
||||
|
||||
@@ -184,11 +184,11 @@ class Linearizer(Kernel):
|
||||
# add global buffers
|
||||
for i,buf in enumerate(self.bufs):
|
||||
if isinstance(buf, MemBuffer):
|
||||
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", buf.dtype)) # noqa: E501
|
||||
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, dtype:=PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", dtype)) # noqa: E501
|
||||
# add var vals
|
||||
for var in vars_from_ast(self.ast):
|
||||
assert var.expr is not None
|
||||
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32))
|
||||
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes.int32))
|
||||
# define local buffers
|
||||
for lb in self.local_alias.values():
|
||||
self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size()))
|
||||
|
||||
@@ -174,9 +174,6 @@ class dtypes:
|
||||
# it has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
|
||||
bfloat16: Final[DType] = DType(10, 2, "__bf16", None)
|
||||
|
||||
# NOTE: these are internal dtypes, should probably check for that
|
||||
_arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)
|
||||
|
||||
# NOTE: these are image dtypes
|
||||
@staticmethod
|
||||
def imageh(shp): return ImageDType(100, 2, "imageh", np.float16, shp, dtypes.float32)
|
||||
|
||||
@@ -3,7 +3,7 @@ import math, functools
|
||||
from collections import defaultdict
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.helpers import ImageDType, dtypes, prod, DType, strip_parens, getenv
|
||||
from tinygrad.helpers import ImageDType, dtypes, prod, DType, PtrDType, strip_parens, getenv
|
||||
|
||||
class CStyleLanguage(NamedTuple):
|
||||
size_prefix: str = "int"
|
||||
@@ -83,8 +83,8 @@ class CStyleLanguage(NamedTuple):
|
||||
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
|
||||
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else "" # noqa: E501
|
||||
buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
|
||||
self.arg_int_prefix if dtype == dtypes._arg_int32 else
|
||||
("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix) for i,(name,dtype) in enumerate(bufs)]
|
||||
("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
|
||||
self.arg_int_prefix if dtype == dtypes.int else None) for i,(name,dtype) in enumerate(bufs)]
|
||||
prg = ''.join([f"{self.kernel_prefix}void {f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''}{function_name}(",] +
|
||||
[', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
|
||||
[") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
|
||||
@@ -351,7 +351,7 @@ class WGSLLanguage(CStyleLanguage):
|
||||
local_size = local_size[::-1] if local_size else [1]
|
||||
bind_it = iter(range(len(bufs)))
|
||||
prg = "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
|
||||
prg += "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) {'var<uniform>' if dtype == dtypes._arg_int32 else 'var<storage,read_write>'} {name}: {'i32' if dtype == dtypes._arg_int32 else f'array<{self.type_map[dtype]}>'};" for name,dtype in bufs]) # noqa: E501
|
||||
prg += "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) {'var<storage,read_write>' if isinstance(dtype, PtrDType) else 'var<uniform>'} {name}: {f'array<{self.type_map[dtype]}>' if isinstance(dtype, PtrDType) else 'i32'};" for name,dtype in bufs]) # noqa: E501
|
||||
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}" # noqa: E501
|
||||
return prg
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Final, Dict, Callable, Any, List, Optional, Tuple
|
||||
from llvmlite import ir
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.helpers import DType, dtypes
|
||||
from tinygrad.helpers import DType, PtrDType, dtypes
|
||||
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
|
||||
|
||||
MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
|
||||
@@ -33,7 +33,7 @@ code_for_op: Final[Dict[Op, Callable]] = {
|
||||
|
||||
dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(),
|
||||
dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64),
|
||||
dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16),
|
||||
dtypes.int32: ir.IntType(32), dtypes.int16:ir.IntType(16), dtypes.uint16:ir.IntType(16),
|
||||
dtypes.uint32:ir.IntType(32), dtypes.uint64:ir.IntType(64)}
|
||||
|
||||
def cast(bb, val, input_type, output_type, bitcast=False):
|
||||
@@ -84,7 +84,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
|
||||
|
||||
# create llvm function
|
||||
func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values()]
|
||||
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if dt!=dtypes._arg_int32 else x for x,dt in func_dtypes]), name=function_name) # noqa: E501
|
||||
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=function_name) # noqa: E501
|
||||
for a in func.args:
|
||||
if a.type.is_pointer: a.add_attribute("noalias")
|
||||
|
||||
@@ -99,7 +99,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
|
||||
lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type
|
||||
|
||||
for bufname,dtype in buf_to_dtype.items():
|
||||
if dtype == dtypes._arg_int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
|
||||
if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
|
||||
|
||||
for u in uops:
|
||||
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
|
||||
|
||||
Reference in New Issue
Block a user