# tinygrad/test/backend/test_linearizer.py
# (file-viewer header residue: 2026-06-07 12:07:00 -07:00 · 460 lines · 24 KiB · Python)

import numpy as np
import unittest
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType, buffers
from tinygrad.device import Device, Buffer
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.realize import run_linear
from tinygrad.codegen import to_program
from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT, DEV
from tinygrad.dtype import DType, dtypes, PtrDType, AddrSpace
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.renderer.isa import ISARenderer
from test.helpers import replace_opts
# True when the active device interface name starts with "MOCK" (presumably an emulated/mock GPU backend — confirm);
# used in skip conditions for tests whose expectations differ on such backends.
MOCKGPU = DEV.interface.startswith("MOCK")
from tinygrad.uop.render import print_uops # noqa: F401 # pylint: disable=unused-import
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, ISARenderer), "isa backends don't preserve the op spec when lowering")
class TestLinearizer(unittest.TestCase):
  """Structural tests on the uops produced by `to_program` for small tensor programs.

  Most tests build a kernel AST with `helper_linearizer_opt` (which also checks numerical results against an
  unoptimized baseline), lower it with `to_program`, and assert on the linearized uop list `program.src[2].src`.
  """
  def test_arg_dedup(self):
    # NOTE: this realize exists because Tensor.numpy calls .contiguous() internally
    # without contiguous folding, rand.to("CPU") and rand.contiguous().to("CPU") are different UOps.
    # this test asserts they are the identical Buffer
    # having different buffers is fine for correctness, because the outputs match.
    a, b = Tensor.randn(4).realize(), Tensor.randn(4).realize()
    np_a, np_b = a.numpy(), b.numpy()
    c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),))))
    linear = c.schedule_linear()
    run_linear(linear)
    # buffers of the last call (BIND entries excluded); everything after the first must be exactly a's and b's buffers
    rawbufs = [s.buffer for s in linear.src[-1].src[1:] if s.op is not Ops.BIND]
    assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.uop.base.realized, b.uop.base.realized}
    np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
    np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4)

  def test_load_removed(self):
    # a WHERE with a constant condition must select the right input
    a = Tensor.rand(1).realize()
    b = Tensor.rand(1).realize()
    ta = Tensor.where(Tensor(True), a, b).numpy()
    tb = Tensor.where(Tensor(False), a, b).numpy()
    np.testing.assert_equal(a.numpy(), ta)
    np.testing.assert_equal(b.numpy(), tb)

  @unittest.skip("TODO: some backends insert more casts")
  def test_cast_there_and_back(self):
    tst = Tensor.ones(16, dtype=dtypes.int).contiguous().realize()
    out = tst.neg().cast(dtypes.char).cast(dtypes.int).cast(dtypes.char) * 2
    ast = helper_linearizer_opt(out)
    uops = tuple(to_program(replace_opts(ast, []), renderer=Device[Device.DEFAULT].renderer).src[2].src)
    self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1)

  @unittest.expectedFailure
  def test_cast_back_and_there(self):
    tst = Tensor.ones(16, dtype=dtypes.int).contiguous().realize()
    out = tst.neg().cast(dtypes.char).cast(dtypes.int) * 2
    ast = helper_linearizer_opt(out)
    uops = tuple(to_program(replace_opts(ast, []), renderer=Device[Device.DEFAULT].renderer).src[2].src)
    self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 0)

  @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken on ptx")
  def test_late_bias_load(self):
    img = Tensor.empty(1, 3, 16, 16)
    w = Tensor.empty(16, 3, 3, 3)
    b = Tensor.empty(16)
    out = img.conv2d(w, b)
    ast = helper_linearizer_opt(out)
    uops = tuple(to_program(replace_opts(ast, []), renderer=Device[Device.DEFAULT].renderer).src[2].src)
    # slice at the last loop end
    uslice = [i for i,u in enumerate(uops) if u.op == Ops.END][-1]
    # only valid test if outermost range is the reduce
    if uops[uslice].src[-1].arg[-1] == AxisType.REDUCE:
      load_idxs = [u.src[0] for u in uops[uslice+1:] if u.op == Ops.LOAD]
      # assert that there is a global load after the reduce ends
      assert any(u.addrspace == AddrSpace.GLOBAL for u in load_idxs)

  def _test_no_nested_ranges(self, lins, skip=None):
    # NOTE(review): no caller is visible here; each element of `lins` must expose a `.uops` list — confirm before reuse.
    for l in lins:
      range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.DEFINE_REG])
      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.END and u.src[0] in range_in_acc)]
      # two consecutive RANGEs (or ENDs) mean one accumulator loop is nested in another.
      # start at 1 so the first op is not compared with the last one through ranges[-1]
      for i in range(1, len(ranges)):
        if skip and i in skip: continue
        assert ranges[i-1] != ranges[i], f"multireduce nested the ranges! {ranges[i-1]}, {ranges[i]}"

  def test_two_nested_range(self):
    a = Tensor.randn(2, ).realize()
    out = a.reshape(2, 1).expand(2, 3).sum()
    ast = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)).sum()])
    uops = tuple(to_program(replace_opts(ast, []), renderer=Device[Device.DEFAULT].renderer).src[2].src)
    ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
    assert len(ranges) == 1 # NOTE: it collapses now

  def test_three_nested_range(self):
    a = Tensor.randn(2, ).realize()
    out = a.reshape(2, 1).expand(2, 3).expand(2, 2, 3).sum()
    ast = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)), (2, 2, 3)).sum()])
    uops = tuple(to_program(replace_opts(ast, []), renderer=Device[Device.DEFAULT].renderer).src[2].src)
    ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
    assert len(ranges) == 1 # NOTE: it collapses now

  def test_two_nested_range_alt_indexing(self):
    a = Tensor([2, 2]).realize()
    out = a.reshape(2, 1).pad(((1, 1), (1, 1)), value=2).sum()
    ast = helper_linearizer_opt(out, wanna_output=[24])
    uops = tuple(to_program(replace_opts(ast, []), renderer=Device[Device.DEFAULT].renderer).src[2].src)
    ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
    # RANGE -> ALU -> RANGE -> ALU + LOAD -> STORE
    assert any(x.op in GroupOp.ALU for x in uops[ranges[0]:ranges[1]])
    # the index of the load doesn't depend on the second range
    assert any(x.op is Ops.LOAD for x in uops[ranges[0]:ranges[1]])
    assert any(x.op in {*GroupOp.ALU, Ops.LOAD} for x in uops[ranges[1]:])

  def test_range_outer_op_before_phi(self):
    a = Tensor.randn(4, 1).realize()
    b = Tensor.randn(1, 1).realize()
    out = (a + b[0]).sum() + b[0]
    ast = helper_linearizer_opt(out, wanna_output=[(a.numpy()+b.numpy()[0]).sum()+b.numpy()])
    uops = tuple(to_program(replace_opts(ast, []), renderer=Device[Device.DEFAULT].renderer).src[2].src)
    ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
    # LOAD -> RANGE -> LOAD -> STORE
    assert len([x for x in uops[:ranges[0]] if x.op is Ops.LOAD]) == 1

  def test_range_outer_op_before_phi_nested_range(self):
    a = Tensor.randn(2, ).realize()
    b = Tensor.randn(1, 1).realize()
    out = (a.reshape(2, 1).expand(2, 3) + b[0]).sum() + b[0]
    ast = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)) + b.numpy()[0]).sum() + b.numpy()])
    uops = tuple(to_program(replace_opts(ast, []), renderer=Device[Device.DEFAULT].renderer).src[2].src)
    ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
    assert len(ranges) == 1 # NOTE: it collapses now

  def test_load_dedup(self):
    # for different leaves in the AST, the same loads may occur.
    a = Tensor.randn(4).realize()
    # these are of size 3 to avoid float4 coalesce
    r = a[:-1] + a[1:]
    uops = tuple(to_program(replace_opts(r.schedule_linear().src[-1].src[0], [Opt(op=OptOps.UPCAST, axis=0, arg=0)]),
                            renderer=Device[Device.DEFAULT].renderer).src[2].src)
    num_loads = len([uop for uop in uops if uop.op is Ops.LOAD])
    assert num_loads <= 4, "more load uops than needed"
    assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"

  @unittest.skip("this is handled at higher level now")
  def test_upcast_cse(self):
    # when upcasting, within a subtree, there may be common expressions.
    a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
    r = a.expand([2]) + b.expand([2])
    uops = tuple(to_program(replace_opts(r.schedule_linear().src[-1].src[0], [Opt(op=OptOps.UPCAST, axis=0, arg=0)]),
                            renderer=Device[Device.DEFAULT].renderer).src[2].src)
    num_ops = len([uop for uop in uops if uop.op in GroupOp.ALU])
    assert num_ops <= 1, "more alu uops than needed"

  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  def test_reduce_upcast(self):
    x, w = Tensor.randn((1,1,3)).realize(), Tensor.randn((1,1,2)).realize()
    r = Tensor.conv2d(x,w,padding=1).relu()
    uops = tuple(to_program(replace_opts(r.schedule_linear().src[-1].src[0],
      [Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]), renderer=Device[Device.DEFAULT].renderer).src[2].src)
    accs = [u for u in uops if u.op is Ops.DEFINE_REG]
    stores = [u for u in uops if u.op is Ops.STORE]
    assert len(accs) == 0 # it's removed now
    assert len(stores) == 1

  # NOTE: can reenable, it does work. it just makes BEAM slow
  @unittest.expectedFailure
  @unittest.skipUnless(Device.DEFAULT == "CPU", "test only for CPU")
  def test_upcast_with_locals_cpu(self):
    out = Tensor.ones(64,64).contiguous() @ Tensor.ones(64,64).contiguous()
    prg = to_program(replace_opts(out.schedule_linear().src[-1].src[0], [Opt(OptOps.LOCAL, axis=0, arg=4)]),
                     renderer=Device[Device.DEFAULT].renderer)
    self.assertEqual(len(prg.src[3].arg.split("for")), 5)

  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken on ptx for some reason")
  def test_upcast_with_locals(self):
    x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
    r = (x@y).relu()
    opts_to_apply = [Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)]
    program = to_program(replace_opts(r.schedule_linear().src[-1].src[0], opts_to_apply), renderer=Device[Device.DEFAULT].renderer)
    stores = [u for u in tuple(program.src[2].src) if u.op is Ops.STORE and u.src[0].addrspace != AddrSpace.REG]
    # the first store is to lds and can be upcasted
    assert stores[0].src[1].max_numel() == 4
    assert any(x.addrspace is AddrSpace.LOCAL for x in stores[0].toposort())
    # the second store is to gds with no upcasts
    assert stores[1].src[1].max_numel() == 1
    assert stores[1].src[1].dtype == dtypes.float
    assert any(x.op is Ops.PARAM for x in stores[1].toposort())

  def test_zero_fold(self):
    a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
    r = Tensor.stack(a, b)
    uops = tuple(to_program(replace_opts(r.schedule_linear().src[-1].src[0], [Opt(op=OptOps.UPCAST, axis=0, arg=0)]),
                            renderer=Device[Device.DEFAULT].renderer).src[2].src)
    num_ops = len([uop for uop in uops if uop.op in GroupOp.ALU])
    assert num_ops == 0, "more alu uops than needed"

  def test_sum_acc_dtype(self):
    # the accumulator (BUFFER or DEFINE_REG) of a sum must be widened to acc_dtype
    dts = Device[Device.DEFAULT].renderer.supported_dtypes()
    for tensor_dtype, acc_dtype in (
      (dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
      if tensor_dtype in dts and acc_dtype in dts:
        a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
        realized_ast = a.schedule_linear().src[-1].src[0]
        program = to_program(replace_opts(realized_ast, []), renderer=Device[Device.DEFAULT].renderer)
        local = [uop for uop in tuple(program.src[2].src) if uop.op in (Ops.BUFFER, Ops.DEFINE_REG)]
        assert local[0].dtype.base == acc_dtype

  def test_arg_acc_dtype(self):
    def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
      # lower the last kernel of `c` and check the dtype of its first accumulator
      realized_ast = c.schedule_linear().src[-1].src[0]
      program = to_program(replace_opts(realized_ast, []), renderer=Device[Device.DEFAULT].renderer)
      local = [uop for uop in tuple(program.src[2].src) if uop.op in (Ops.BUFFER, Ops.DEFINE_REG)]
      self.assertEqual(local[0].dtype.base, expected_dtype)
    # (tensor dtype, dtype= argument, expected accumulator dtype); None means "use the default accumulator"
    tests = (
      (dtypes.float16, None, dtypes.float),
      (dtypes.bfloat16, None, dtypes.float),
      (dtypes.float, None, dtypes.float),
      (dtypes.float16, dtypes.float16, dtypes.float16),
      (dtypes.bfloat16, dtypes.bfloat16, dtypes.bfloat16),
      (dtypes.float, dtypes.float16, dtypes.float16),
    )
    dts = Device[Device.DEFAULT].renderer.supported_dtypes()
    for tensor_dtype, acc_dtype, expected_dtype in tests:
      # None is not a DType, so `None in dts` is always False; without the `is None` check the default cases never ran
      if tensor_dtype in dts and (acc_dtype is None or acc_dtype in dts) and expected_dtype in dts:
        a, b = Tensor.rand(8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, dtype=tensor_dtype)
        helper_arg_acc_dtype(a.sum(dtype=acc_dtype), expected_dtype)
        helper_arg_acc_dtype(a.matmul(b, dtype=acc_dtype), expected_dtype)
        helper_arg_acc_dtype(Tensor.einsum("ki,ij->kj", a, b, dtype=acc_dtype), expected_dtype)
        d, w = Tensor.rand(4, 8, 8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, 2, 2, dtype=tensor_dtype)
        helper_arg_acc_dtype(d.conv2d(w, dtype=acc_dtype), expected_dtype)

  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  def test_simple_unroll_no_between_phi_dependencies(self):
    x, y = Tensor.empty(64, 64), Tensor.empty(64, 64)
    r = (x@y).relu()
    opt = [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4)]
    ast = helper_linearizer_opt(r, [opt])
    # the uops graph is DEFINE_REG -> 4x STORE 0.0 -> RANGE -> 4x ALU -> 4x STORE -> ENDRANGE
    uops = tuple(to_program(replace_opts(ast, opt), renderer=Device[Device.DEFAULT].renderer).src[2].src)
    begin_range = [i for i, x in enumerate(uops) if x.op is Ops.RANGE][-1]
    end_range = [i for i, x in enumerate(uops) if x.op is Ops.END][0]
    for u in uops:
      if u.op is Ops.STORE and isinstance(dt:=u.src[0].dtype, PtrDType) and dt.addrspace is AddrSpace.REG:
        if uops.index(u) < begin_range:
          # register initialization before the loop stores a constant
          assert u.src[1].op is Ops.CONST
        else:
          assert u.src[1].op in GroupOp.ALU
          assert begin_range < uops.index(u) < end_range
      # children of END are placed after ENDRANGE
      if any(x.op is Ops.END and x.src[1].op in GroupOp.ALU for x in u.src):
        assert end_range < uops.index(u)

  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  def test_default_global_reversed(self):
    # shrink so that the dims do not collapse
    t = Tensor.ones(5, 6, 7).contiguous().realize().shrink(((0, 4), (0, 5), (0, 6)))
    ast = helper_linearizer_opt(t+1)
    uops = tuple(to_program(replace_opts(ast, []), renderer=Device[Device.DEFAULT].renderer).src[2].src)
    idxs = dedup([uop for uop in uops if uop.op is Ops.SPECIAL])
    idxs = sorted(idxs, key=lambda uop: uop.arg)
    # gidx0 gets the innermost (last) dimension
    assert (idxs[0].arg, idxs[0].src[0].arg) == ('gidx0', 6), idxs[0]
    assert (idxs[1].arg, idxs[1].src[0].arg) == ('gidx1', 5), idxs[1].arg
    assert (idxs[2].arg, idxs[2].src[0].arg) == ('gidx2', 4), idxs[2].arg

  def test_sum_collapse(self):
    t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
    sched = [si for si in t.schedule_linear().src if si.src[0].op is Ops.SINK]
    # sum_collapse is a full collapse now
    assert len(sched) == 1
    assert not any(u.op is Ops.REDUCE and len(u.arg[1]) > 0 for u in sched[0].src[0].toposort()), "found reduce in sum collapse"

  def test_assign_fold(self):
    a = Tensor.ones(4, 4).contiguous().realize()
    m = Tensor.ones(4, 4).shrink(((1, 2), None)).pad(((1, 2), None))
    a.assign(a+m)
    a.realize()
    np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])

  @unittest.skipIf(MOCKGPU and isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, CUDARenderer)), "PTX indexes differently. might be ok?")
  def test_where_fold(self):
    a = Tensor.ones(4, 4).contiguous().realize()
    b = a.shrink(((1, 2), None)).pad(((1, 2), None))
    a.assign(b.where(2, a))
    linear, var_vals = a.linear_with_vars()
    assert len(linear.src) == 1
    run_linear(linear, var_vals)
    np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
    program = to_program(replace_opts(linear.src[-1].src[0], []), renderer=Device[Device.DEFAULT].renderer)
    assert not any(u.op == Ops.WHERE for u in tuple(program.src[2].src)), "found where where where should be folded"

  def test_phi_simplification(self):
    def helper(t, max_ops=0):
      ast = helper_linearizer_opt(t)
      uops = tuple(to_program(ast, renderer=Device[Device.DEFAULT].renderer).src[2].src)
      # ignore kernel optimized IF statements for now
      if if_op:=next((u for u in uops if u.op is Ops.IF), None):
        uops = uops[:uops.index(if_op)]
      assert len({u.op for u in uops if u.op in {Ops.RANGE, Ops.SPECIAL}}) == 1, "has either specials or ranges, not both"
      reg_stores = [u for u in uops if u.op is Ops.STORE and isinstance(dt:=u.src[0].dtype, PtrDType) and dt.addrspace == AddrSpace.REG]
      assert len(reg_stores) == 0, "STORE to reg should have been simplified"
      assert len([u for u in uops if u.op is Ops.MAX]) <= max_ops, "no unnecessary MAX ops"
    helper(Tensor.arange(5.5, (3.5*300), 3.5).clone(), max_ops=2)
    helper(Tensor.arange(-1, -100, -5).clone(), max_ops=2)
    # NOTE: both of these split the reduce (this just wasn't tracked before)
    #helper(Tensor.arange(-3.2, 6.7, 0.64), max_ops=2)
    #helper(Tensor.arange(256), max_ops=2)
    helper(Tensor.arange(255).clone(), max_ops=2)

  @unittest.skip("test implicitly depends on certain optimizations")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken on ptx for some reason")
  def test_grouped_store_phis(self):
    """
    float4 acc0 = float4(0.0,0.0,0.0,0.0);
    {
      acc0 = // ...
    }
    *((device float4*)(data0+alu2)) = float4(acc0.x,acc0.y,acc0.z,acc0.w);
    simplifies to:
    *((device float4*)(data0+alu2)) = acc0;
    """
    x, y = Tensor.empty(64,64), Tensor.empty(64,64)
    out = x.matmul(y)
    with Context(TC=0):
      ast = helper_linearizer_opt(out)
      uops = tuple(to_program(ast, renderer=Device[Device.DEFAULT].renderer).src[2].src)
    # check that the float4 cast collapses
    store_vals = [u.src[1] for u in uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG]
    for val in store_vals:
      assert val.dtype == dtypes.float.vec(4) # and val.op is not Ops.VECTORIZE

  @unittest.skip("test implicitly depends on certain optimizations")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  def test_grouped_store_values(self):
    x = Tensor.randn((4,3,6,6)).realize()
    out = x.flip((0,1)).contiguous()
    ast = helper_linearizer_opt(out)
    store_val = [u.src[1] for u in tuple(to_program(ast, renderer=Device[Device.DEFAULT].renderer).src[2].src) if u.op is Ops.STORE][0]
    assert store_val.dtype == dtypes.float.vec(4) and store_val.op is not Ops.STACK

  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  def test_grouped_store_locals_and_globals(self):
    x, y = Tensor.empty(64, 64), Tensor.empty(64, 64)
    out = x@y
    opt = [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8),
           Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 2)] # upcast accs in both reduces
    ast = helper_linearizer_opt(out, opts=[opt])
    uops = tuple(to_program(replace_opts(ast, opt), renderer=Device[Device.DEFAULT].renderer).src[2].src)
    # classify each store by whether its address expression reaches a DEFINE_LOCAL or a PARAM.
    # toposort() visits each node once, unlike a naive recursive walk which re-walks shared subgraphs.
    local_stores = [u for u in uops if u.op is Ops.STORE and any(x.op is Ops.DEFINE_LOCAL for x in u.src[0].toposort())]
    global_stores = [u for u in uops if u.op is Ops.STORE and any(x.op is Ops.PARAM for x in u.src[0].toposort())]
    barrier = [u for u in uops if u.op is Ops.BARRIER]
    assert len(barrier) == 1
    # check that the float4 cast collapses for all stores
    for store in local_stores+global_stores:
      assert store.src[1].max_numel() > 1 # and store.src[2].op is not Ops.VECTORIZE
    # # check the children's vins
    # TODO: src ALU are not the same, should it?
    # assert barrier.src == tuple(local_stores)
    assert any(u.op is Ops.IF for u in uops)

  @unittest.skip("test implicitly depends on certain optimizations")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken on ptx for some reason")
  def test_grouped_store_local_only(self):
    x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
    r = (x@y).relu()
    ast = helper_linearizer_opt(r)
    uops = tuple(to_program(ast, renderer=Device[Device.DEFAULT].renderer).src[2].src)
    stores = [u for u in uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG]
    # the float4 value stores directly in lds and we skip upcast
    self.assertEqual(stores[0].src[1].dtype, dtypes.float.vec(4))
    #assert stores[0].src[-1].op is not Ops.VECTORIZE
    # the global store doesn't change
    assert stores[1].src[1].dtype == dtypes.float
# *** helpers ***
def helper_realized_ast(r:Tensor|list[Tensor]) -> tuple[UOp, list[Buffer]]:
  """Schedule `r`, execute every kernel but the final one, and return the final kernel's AST and buffers.

  The first len(ast.src) buffers of the final call (its outputs) are swapped for freshly allocated Buffers of the
  same device/size/dtype. The remaining (input) buffers are passed through unchanged. Every returned buffer is
  guaranteed allocated.
  Raises AssertionError if the final call's AST is not a SINK.
  """
  tensors = [r] if isinstance(r, Tensor) else r
  linear, var_vals = Tensor.linear_with_vars(*tensors)
  # run everything before the last call so the last kernel's inputs are realized
  run_linear(UOp(Ops.LINEAR, src=linear.src[:-1]), var_vals)
  last_call = linear.src[-1]
  ast = last_call.src[0]
  assert ast.op is Ops.SINK, f"helper_realized_ast expects a SINK {last_call}"
  # BIND sources carry no buffer, so they are skipped
  call_bufs = [s.buffer for s in last_call.src[1:] if s.op is not Ops.BIND]
  n_outputs = len(ast.src)
  bufs = []
  for idx, buf in enumerate(call_bufs):
    if idx < n_outputs: bufs.append(Buffer(buf.device, buf.size, buf.dtype).allocate())
    else: bufs.append(buf)
  for buf in bufs: buf.ensure_allocated()
  return ast, bufs
def helper_linearizer_ast(ast:UOp, inputs:list[Tensor], *args, **kwargs):
  """Check `ast` against already-realized `inputs` by allocating one fresh output buffer per SINK source.

  Output buffers go on the device of the last input (or Device.DEFAULT if there are no inputs). Each buffer is sized
  by the SINK source and typed by that source's src[1] dtype. Extra args are forwarded to
  `_helper_linearizer_opt_ast`.
  """
  assert isinstance(ast, UOp), "ast must be UOp"
  in_bufs = [t.uop.base.buffer for t in inputs]
  out_bufs = []
  for sink_src in ast.src:
    device = in_bufs[-1].device if in_bufs else Device.DEFAULT
    out_bufs.append(Buffer(device, sink_src.size, sink_src.src[1].dtype).allocate())
  # outputs come first, matching the kernel's buffer order
  _helper_linearizer_opt_ast(ast, out_bufs + in_bufs, *args, **kwargs)
def helper_linearizer_opt(r:Tensor|list[Tensor], *args, **kwargs):
  """Realize `r` up to its last kernel, verify that kernel via `_helper_linearizer_opt_ast`, and return its AST."""
  kernel_ast, kernel_bufs = helper_realized_ast(r)
  _helper_linearizer_opt_ast(kernel_ast, kernel_bufs, *args, **kwargs)
  return kernel_ast
def copyout_outputs(outbufs:list[Buffer]) -> list[np.ndarray]:
  """Return one flat numpy array per buffer, built from its memoryview and typed by the buffer's dtype."""
  arrays = []
  for buf in outbufs:
    np_dtype = _to_np_dtype(buf.dtype)
    arrays.append(np.frombuffer(buf.as_memoryview(), np_dtype))
  return arrays
def reset_bufs(bufs:list[Buffer]):
  """Overwrite the full byte contents of every buffer in `bufs` with zeros."""
  for buf in bufs:
    nbytes = buf.size * buf.dtype.itemsize
    zeros = np.zeros(nbytes, dtype=np.uint8)
    buf.copyin(zeros.data)
def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=None,
                               apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=None, wanna_output=None):
  """Run `realized_ast` unoptimized, as-is, and under each opt list in `opts`, checking every run's outputs.

  realized_ast: kernel SINK; the first len(realized_ast.src) entries of `real_bufs` are its outputs.
  real_bufs: output buffers followed by input buffers, in the order passed to `ast.call`.
  opts: iterable of Opt lists to check (default: none).
  apply_tc: if True, prepend Opt(OptOps.TC, 0, (TC_SELECT.value, TC_OPT.value, 1)) to every list in `opts`.
  atol, rtol: tolerances for np.testing.assert_allclose.
  color_sizes: accepted for backward compatibility; not used.
  wanna_output: expected outputs, one per output buffer. If None/empty, the unoptimized run becomes the reference.
  Raises AssertionError when any run's outputs disagree with the reference.
  """
  # None sentinels instead of shared mutable [] defaults
  opts = [] if opts is None else opts
  outbufs = real_bufs[:len(realized_ast.src)]
  expected = [np.array(x).flatten() for x in (() if wanna_output is None else wanna_output)]
  # map fresh buffer UOps to the real buffers in the global `buffers` registry so `ast.call` runs on them
  buf_uops = [UOp.new_buffer(b.device, b.size, b.dtype) for b in real_bufs]
  for u, b in zip(buf_uops, real_bufs): buffers[u] = b

  def run_prg(prg_opts):
    # None: run the AST as-is; a list (possibly empty) replaces its opts
    ast = realized_ast if prg_opts is None else replace_opts(realized_ast, list(prg_opts))
    run_linear(UOp(Ops.LINEAR, src=(ast.call(*buf_uops),)))

  def check_outputs():
    # `expected` is read at call time, so it sees the reference chosen below
    for got, want in zip(copyout_outputs(outbufs), expected): np.testing.assert_allclose(got, want, atol=atol, rtol=rtol)

  # baseline: no optimizations at all; it is the reference unless expected outputs were provided
  run_prg(())
  if len(expected) == 0: expected = copyout_outputs(outbufs)
  else: check_outputs()

  # check correctness of handcoded optimizations (AST run as-is)
  reset_bufs(outbufs)
  run_prg(None)
  check_outputs()

  # check custom transformations, if any
  for extra in opts:
    reset_bufs(outbufs)
    run_prg(([Opt(OptOps.TC, 0, (TC_SELECT.value, TC_OPT.value, 1))] if apply_tc else []) + extra)
    check_outputs()
# run this module's tests when executed directly (e.g. `python test_linearizer.py`)
if __name__ == "__main__": unittest.main()