mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
move upcast before reduce (#11250)
* move upcast before reduce
  upcast goes to end of global+local+upcast
* r_196_32_4_24_8
This commit is contained in:
@@ -1356,9 +1356,9 @@ class TestKernelOpts(unittest.TestCase):
|
||||
[Opt(OptOps.PADTO, 2, 8)],
|
||||
])
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 0, 0), Opt(OptOps.PADTO, 2, 8)]])
|
||||
helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 0, 0), Opt(OptOps.PADTO, 1, 8)]])
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 1, 0), Opt(OptOps.PADTO, 2, 8)]])
|
||||
helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 1, 0), Opt(OptOps.PADTO, 1, 8)]])
|
||||
with self.assertRaises(KernelOptError):
|
||||
helper_linearizer_opt(a@b, [[Opt(OptOps.UNROLL, 0, 0), Opt(OptOps.PADTO, 2, 8)]])
|
||||
|
||||
|
||||
@@ -307,7 +307,7 @@ typedef signed char signed_char128 __attribute__((aligned(128),vector_size(128))
|
||||
typedef unsigned char unsigned_char8 __attribute__((aligned(8),vector_size(8)));
|
||||
typedef unsigned char unsigned_char4 __attribute__((aligned(4),vector_size(4)));
|
||||
typedef unsigned char unsigned_char128 __attribute__((aligned(128),vector_size(128)));
|
||||
__attribute__((noinline)) void r_196_24_8_32_4(unsigned char* restrict __attribute__((align_value(128))) data0, unsigned char* restrict __attribute__((align_value(128))) data1, signed char* restrict __attribute__((align_value(
|
||||
__attribute__((noinline)) void r_196_32_4_24_8(unsigned char* restrict __attribute__((align_value(128))) data0, unsigned char* restrict __attribute__((align_value(128))) data1, signed char* restrict __attribute__((align_value(
|
||||
128))) data2, int* restrict __attribute__((align_value(128))) data3) {
|
||||
int32 cast0 = (int32){0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
|
||||
int32 val0 = *((int32*)((data3+0)));
|
||||
|
||||
@@ -303,7 +303,7 @@ class Kernel:
|
||||
# NOTE: assume the first get_local_axes() LOCAL are for TC
|
||||
check(not (self.tensor_core and axis in self.axes_of(AxisType.LOCAL)[:len(self.tensor_core.get_local_axes())]), "can't upcast TC locals")
|
||||
check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16")
|
||||
self.shift_to(axis, amt, AxisType.UPCAST, insert_at=None)
|
||||
self.shift_to(axis, amt, AxisType.UPCAST, insert_at=max(self.axes_of(AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP, AxisType.UPCAST))+1)
|
||||
elif opt.op is OptOps.NOLOCALS:
|
||||
check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
|
||||
check(AxisType.LOCAL not in self.axis_types and self.group_for_reduces == 0, "can't have no locals with locals")
|
||||
@@ -489,9 +489,14 @@ class Kernel:
|
||||
ret = ret.replace(arg = (op.arg[0], axes))
|
||||
if self.group_for_reduces and grouped_axes:
|
||||
local_axes = tuple([i for i,t in enumerate(self.axis_types) if t in (AxisType.LOCAL, AxisType.UPCAST) or i in grouped_axes])
|
||||
slocal, supcast, sgroup = sorted(self.axes_of(AxisType.LOCAL)), sorted(self.axes_of(AxisType.UPCAST)), sorted(grouped_axes)
|
||||
# NOTE: start with UPCAST at the end so it has stride 1 and can merge
|
||||
base_shape = tuple([s for i,s in enumerate(self.full_shape) if i in slocal] + [s for i,s in enumerate(self.full_shape) if i in sgroup] + \
|
||||
[s for i,s in enumerate(self.full_shape) if i in supcast])
|
||||
permute_axes = tuple([local_axes.index(i) for i in slocal+sgroup+supcast])
|
||||
local_shape = tuple([s if i in local_axes else 1 for i,s in enumerate(self.full_shape)])
|
||||
local_src_shape = tuple([self.full_shape[i] if i in self.axes_of(AxisType.GLOBAL) else s for i,s in enumerate(local_shape)])
|
||||
st = ShapeTracker.from_shape(local_shape).expand(local_src_shape)
|
||||
st = ShapeTracker.from_shape(base_shape).permute(permute_axes).reshape(local_shape).expand(local_src_shape)
|
||||
local_size = st.real_size()
|
||||
local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}")
|
||||
local_load = local_buffer.view(st).load(local_buffer.view(st).store(ret))
|
||||
|
||||
Reference in New Issue
Block a user