From ec3efd2919d2b757b16eea277cc5236caecdd343 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 18 Jul 2025 14:42:15 -0400 Subject: [PATCH] move upcast before reduce (#11250) * move upcast before reduce upcast goes to end of global+local+upcast * r_196_32_4_24_8 --- test/test_linearizer.py | 4 ++-- test/test_quantize_onnx.py | 2 +- tinygrad/opt/kernel.py | 9 +++++++-- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 73c6ca0daa..a1002250bd 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -1356,9 +1356,9 @@ class TestKernelOpts(unittest.TestCase): [Opt(OptOps.PADTO, 2, 8)], ]) with self.assertRaises(KernelOptError): - helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 0, 0), Opt(OptOps.PADTO, 2, 8)]]) + helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 0, 0), Opt(OptOps.PADTO, 1, 8)]]) with self.assertRaises(KernelOptError): - helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 1, 0), Opt(OptOps.PADTO, 2, 8)]]) + helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 1, 0), Opt(OptOps.PADTO, 1, 8)]]) with self.assertRaises(KernelOptError): helper_linearizer_opt(a@b, [[Opt(OptOps.UNROLL, 0, 0), Opt(OptOps.PADTO, 2, 8)]]) diff --git a/test/test_quantize_onnx.py b/test/test_quantize_onnx.py index ca9925eaf7..c3d7fa9baf 100644 --- a/test/test_quantize_onnx.py +++ b/test/test_quantize_onnx.py @@ -307,7 +307,7 @@ typedef signed char signed_char128 __attribute__((aligned(128),vector_size(128)) typedef unsigned char unsigned_char8 __attribute__((aligned(8),vector_size(8))); typedef unsigned char unsigned_char4 __attribute__((aligned(4),vector_size(4))); typedef unsigned char unsigned_char128 __attribute__((aligned(128),vector_size(128))); -__attribute__((noinline)) void r_196_24_8_32_4(unsigned char* restrict __attribute__((align_value(128))) data0, unsigned char* restrict __attribute__((align_value(128))) data1, signed char* restrict __attribute__((align_value( +__attribute__((noinline)) void r_196_32_4_24_8(unsigned char* restrict __attribute__((align_value(128))) data0, unsigned char* restrict __attribute__((align_value(128))) data1, signed char* restrict __attribute__((align_value( 128))) data2, int* restrict __attribute__((align_value(128))) data3) { int32 cast0 = (int32){0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; int32 val0 = *((int32*)((data3+0))); diff --git a/tinygrad/opt/kernel.py b/tinygrad/opt/kernel.py index bfcaf555a4..cde589afca 100644 --- a/tinygrad/opt/kernel.py +++ b/tinygrad/opt/kernel.py @@ -303,7 +303,7 @@ class Kernel: # NOTE: assume the first get_local_axes() LOCAL are for TC check(not (self.tensor_core and axis in self.axes_of(AxisType.LOCAL)[:len(self.tensor_core.get_local_axes())]), "can't upcast TC locals") check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16") - self.shift_to(axis, amt, AxisType.UPCAST, insert_at=None) + self.shift_to(axis, amt, AxisType.UPCAST, insert_at=max(self.axes_of(AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP, AxisType.UPCAST))+1) elif opt.op is OptOps.NOLOCALS: check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals") check(AxisType.LOCAL not in self.axis_types and self.group_for_reduces == 0, "can't have no locals with locals") @@ -489,9 +489,14 @@ class Kernel: ret = ret.replace(arg = (op.arg[0], axes)) if self.group_for_reduces and grouped_axes: local_axes = tuple([i for i,t in enumerate(self.axis_types) if t in (AxisType.LOCAL, AxisType.UPCAST) or i in grouped_axes]) + slocal, supcast, sgroup = sorted(self.axes_of(AxisType.LOCAL)), sorted(self.axes_of(AxisType.UPCAST)), sorted(grouped_axes) + # NOTE: start with UPCAST at the end so it has stride 1 and can merge + base_shape = tuple([s for i,s in enumerate(self.full_shape) if i in slocal] + [s for i,s in enumerate(self.full_shape) if i in sgroup] + \ + [s for i,s in enumerate(self.full_shape) if i in supcast]) + permute_axes = tuple([local_axes.index(i) for i in slocal+sgroup+supcast]) local_shape = tuple([s if i in local_axes else 1 for i,s in enumerate(self.full_shape)]) local_src_shape = tuple([self.full_shape[i] if i in self.axes_of(AxisType.GLOBAL) else s for i,s in enumerate(local_shape)]) - st = ShapeTracker.from_shape(local_shape).expand(local_src_shape) + st = ShapeTracker.from_shape(base_shape).permute(permute_axes).reshape(local_shape).expand(local_src_shape) local_size = st.real_size() local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}") local_load = local_buffer.view(st).load(local_buffer.view(st).store(ret))