Files
tinygrad/test/null/test_simplify_valid_idx.py
George Hotz 1e7f1dcf49 add ParamArgs [pr] (#16421)
* add ParamArgs

* fix export

* cleanups

* fixes

* simpler
2026-05-28 19:17:17 -07:00

601 lines
24 KiB
Python

import unittest, itertools
from tinygrad.codegen.late.devectorizer import load_store_indexing
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp, Ops, graph_rewrite
from tinygrad.uop.symbolic import simplify_valid, sym, pm_move_where_on_load
from tinygrad.helpers import Context
from test.helpers import full_rewrite
from test.null.test_uop_symbolic import check_uop_against_string
def simplify_valid_idx(sink: UOp) -> UOp:
  """Rewrite `sink` with the symbolic rules plus `pm_move_where_on_load`.

  Symbolic-only idx + valid simplification (no late lowering of FLOORDIV/FLOORMOD).
  """
  rules = sym+pm_move_where_on_load
  return graph_rewrite(sink, rules, name="simplify_valid_idx")
def simplify_image_idx(sink: UOp) -> UOp:
  """Rewrite `sink` with the symbolic rules, `pm_move_where_on_load` and `load_store_indexing`.

  Image-aware idx + valid simplification: adds the codegen-layer matcher that drops provably in-bounds gates.
  """
  rules = sym+pm_move_where_on_load+load_store_indexing
  return graph_rewrite(sink, rules, name="simplify_image_idx")
def get_gated_load_uop(valid:UOp, idx:UOp):
  """Build a float LOAD from param 0 at `idx`, with `valid` attached to the index via `idx.valid(valid)`.

  The LOAD's second source is the float constant 0.0 (presumably the value used when the gate is false).
  """
  gated_ptr = UOp.param(0, dtypes.float.ptr()).index(idx.valid(valid), ptr=True)
  fallback = UOp.const(dtypes.float, 0.0)
  return UOp(Ops.LOAD, dtypes.float, (gated_ptr, fallback))
def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UOp]):
  """Build a float4 LOAD from an `imagef(image_shape)` param 0, with both coordinates gated by `valid`.

  Note that the INDEX receives `idx[1]` first and `idx[0]` second. The LOAD's second source is a
  4-wide STACK of the float constant 0.0.
  """
  image = UOp.param(0, dtypes.imagef(image_shape))
  gated_ptr = image.index(idx[1].valid(valid), idx[0].valid(valid), ptr=True)
  zeros = UOp(Ops.STACK, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0.0),) * 4)
  return UOp(Ops.LOAD, dtypes.float.vec(4), (gated_ptr, zeros))
def Special(expr, nmax):
  """Return a weakint SPECIAL UOp whose only source is the constant `nmax` and whose last constructor argument is `expr`."""
  bound = UOp.const(dtypes.weakint, nmax)
  return UOp(Ops.SPECIAL, dtypes.weakint, (bound,), expr)
def Variable(expr, nmin, nmax):
  """Shorthand for `UOp.variable(expr, nmin, nmax)`."""
  return UOp.variable(expr, nmin, nmax)
def Range(n, nmax):
  """Shorthand for `UOp.range(nmax, n)`; note the argument order is swapped relative to this helper."""
  return UOp.range(nmax, n)
class TestHelpers(unittest.TestCase):
  """Tests for helper predicates on UOp expressions."""

  def test_is_increasing(self):
    """`is_increasing` is False for the modulo expression and True for the affine / floordiv ones and for ranges."""
    idx1 = Special("idx1", 32)
    idx2 = Special("idx2", 64)
    ridx0 = Variable("ridx0", 0, 5)
    ridx1 = Variable("ridx1", 0, 2)
    ridx2 = Variable("ridx2", 0, 2)
    # (ridx0+(idx1*48)+(ridx2*6)+(-6)),((idx2*2)+ridx1+(-1)))
    f0 = ((idx1*24)+(ridx2*3)+ridx0+765)%768
    f1 = ridx0+(idx1*48)+(ridx2*6)+(-6)
    f2 = (idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)
    f3 = (idx2*2)+ridx1+(-1)
    self.assertFalse(f0.is_increasing())
    for expr in (f1, f2, f3):
      self.assertTrue(expr.is_increasing())

    rng = UOp.range(5, 2)
    for expr in (rng, rng+2):
      self.assertTrue(expr.is_increasing())
class TestValidIdxSimplification(unittest.TestCase):
  """Index/valid simplification of gated loads using the symbolic-only rewrite (`simplify_valid_idx`)."""

  def check(self, load, sidx, svalid, extra=()):
    """Rewrite `load` (sunk together with `extra`) and compare the resulting idx and valid against the expected strings.

    `sidx` / `svalid` are the expected renderings passed to `check_uop_against_string`.
    `extra` holds additional UOps placed in the same sink as `load`.
    """
    rewritten = simplify_valid_idx(UOp.sink(load, *extra)).src[0]
    off = rewritten.src[0].src[1]
    check_uop_against_string(self, off.get_idx(), sidx)
    check_uop_against_string(self, off.get_valid(), svalid)

  def test_cumsum(self):
    gidx0 = Special("gidx0", 5)
    lidx0 = Special("lidx0", 4)
    gate = (gidx0*4+lidx0<19).ne(True)
    idx = gidx0*4+lidx0-19
    load = get_gated_load_uop(gate, idx)
    self.check(load,
      "0",
      "(((lidx0+(gidx0*4))<19)!=True)")

  def test_simplify_within_valid1(self):
    ridx0 = Range(0, 4)
    ridx1 = Range(1, 4)
    ridx2 = Range(2, 4)
    ridx3 = Range(3, 4)
    valid = ((ridx0*3+ridx1)<8) & ((((ridx0*3+ridx1)//8+ridx2*3+ridx3)%4)<2)
    idx = ridx0+ridx1+ridx2+ridx3
    load = get_gated_load_uop(valid, idx)
    self.check(load,
      "(((r0+r1)+r2)+r3)",
      "((((r0*3)+r1)<8)&((((r2*3)+r3)%4)<2))")

  def test_simplify_within_valid2(self):
    gidx0 = Special("gidx0", 56)
    ridx0 = Range(0, 3)
    alu0 = gidx0+ridx0
    valid = (alu0 < 57) & (alu0 >= 1)
    self.assertIsNone(simplify_valid(valid))

  def test_valid_order_matters1(self):
    ridx0 = Range(0, 2)
    v0 = ridx0<1
    v1 = ((ridx0*5+1)%6)<5
    self.assertEqual(simplify_valid(v0&v1).render(), "(r0<1)")
    self.assertEqual(simplify_valid(v1&v0).render(), "(r0<1)")

  def test_valid_order_matters2(self):
    gidx0 = Special("gidx0", 13)
    gidx1 = Special("gidx1", 13)
    ridx0 = Range(0, 4)
    alu0 = (gidx1+(ridx0*13))
    v0 = (gidx0+11)%14<11
    v1 = (alu0+((gidx0+39)//42))%14<11
    v2 = gidx0<3
    v3 = alu0<42
    # every ordering of the four clauses must simplify to False
    for v in itertools.permutations([v0,v1,v2,v3]):
      self.assertEqual(simplify_valid(v[0]&v[1]&v[2]&v[3]).render(), "False")

  def test_simplify_valid_from_div(self):
    x = Variable("x", -100, 100)
    valid = ((x<0)&((100%x).cast(dtypes.bool)))
    # NOTE: this simplifies the (100%x) part somehow, still has two clauses
    self.assertIsNotNone(simplify_valid(valid))
    # NOTE(review): the clause count below is taken from the input `valid`, not from the
    # result of simplify_valid -- confirm that is what this test intends to check
    self.assertEqual(len(list(valid.split_uop(Ops.AND))), 2)

  @unittest.expectedFailure  # TODO: fix
  def test_from_merge_views(self):
    # taken from test_merges_from_fuzzer1
    # generated by
    # v0 = View(shape=(2, 4), strides=(2, 1), offset=-2, mask=((0, 2), (2, 4)), contiguous=False)
    # v1 = View(shape=(2, 4, 2, 2), strides=(4, 0, -2, -1), offset=3, mask=None, contiguous=False)
    # s = ShapeTracker((v0, v1))
    # idx, valid = s.to_indexed_uops()
    # print(f"{idx.render()=}")
    # print(f"{valid.render()=}")
    # s = ShapeTracker((View(shape=(2, 4, 2, 2), strides=(2, 0, 0, -1), offset=1, mask=((0, 2), (0, 4), (0, 1), (0, 2)), contiguous=False),))
    # idx, valid = s.to_indexed_uops()
    # print(f"{idx.render()=}")
    # print(f"{valid.render()=}")
    ridx0 = Range(0, 2)
    ridx2 = Range(2, 2)
    ridx3 = Range(3, 2)
    idx = (((ridx0*2)+((((ridx2*2)+(ridx3*3))+3)%4))+-2)
    valid = ((((((ridx2*2)+(ridx3*3))+3)%4)<2)!=True) # noqa: E712
    load = get_gated_load_uop(valid, idx)
    self.check(load,
      "(((r0*2)+(r3*-1))+1)",
      "(r2<1)")

  def test_load_in_valid(self):
    # from FUSE_ARANGE=1 python test/test_ops.py TestOps.test_scatter_add
    # can lead to OOB
    ridx2 = Range(2, 4)
    lidx0 = Special("lidx0", 3)
    gidx0 = Special("gidx0", 2)
    idx=(((lidx0+(gidx0*3))+(ridx2*5))+40)
    valid = (lidx0+(gidx0*3)) < 5
    val7 = get_gated_load_uop(valid, idx)
    valid2 = valid & val7.cast(dtypes.bool).logical_not()
    self.assertIsNone(simplify_valid(valid2))

  def test_valid_becomes_const1(self):
    # from DSP mobilenetv2
    ridx0 = Range(0, 30)
    ridx1 = Range(1, 7)
    ridx2 = Range(2, 2)
    alu11 = (ridx1+ridx2)
    alu15 = ((alu11+1)//7)
    idx = (alu15*-31)+(((((alu11+218)//224)+ridx0)%30)*1568)
    valid = (ridx2<1)&(ridx1<6)
    load = get_gated_load_uop(valid, idx)
    # prevent ridx1 and ridx2 from being shrunk
    red = load.reduce(ridx1, ridx2, arg=Ops.ADD)
    self.check(load,
      "(r0*1568)",
      "((r2<1)&(r1<6))",
      extra=(red,))

  def test_valid_becomes_const1_z3(self):
    """Cross-check the expected result of `test_valid_becomes_const1` with z3.

    Uses unittest assertions rather than bare `assert` so the checks still run under `python -O`.
    """
    from z3 import Ints, Solver, And, If, Not, unsat
    ridx0, ridx1, ridx2 = Ints('ridx0 ridx1 ridx2')
    alu11 = (ridx1+ridx2)
    alu15 = ((alu11+1)/7)
    idx = (alu15*-31)+(((((alu11+218)/224)+ridx0)%30)*1568)
    valid = (ridx2<1)&(ridx1<6)
    load = If(valid, idx, 0)
    # same variable ranges as the Range() definitions in test_valid_becomes_const1
    bounds = And(0<=ridx0, ridx0<30, 0<=ridx1, ridx1<7, 0<=ridx2, ridx2<2)

    # correct simplification
    s = Solver()
    s.add(bounds)
    simplified_idx = (ridx0*1568)
    simplified_load = If(valid, simplified_idx, 0)
    s.add(Not(load == simplified_load))  # Check if they are NOT equivalent
    # s.model() is only available when a counterexample exists, so build the message lazily
    if s.check() != unsat: self.fail(f"The expressions are not equivalent. {s.model()=}")

    # new solver for a wrong simplified expression
    s = Solver()
    s.add(bounds)
    wrong_simplified_idx = (ridx0*1567)+ridx1
    wrong_simplified_load = If(valid, wrong_simplified_idx, 0)
    s.add(Not(load == wrong_simplified_load))  # Check if they are NOT equivalent
    self.assertNotEqual(s.check(), unsat, "The expressions are equivalent??")
    print("The expressions are not equivalent.")
    print(s.model())

  def test_valid_becomes_const2(self):
    ridx0 = Range(0, 4)
    ridx1 = Range(1, 4)
    ridx2 = Range(2, 4)
    ridx3 = Range(3, 4)
    # TODO: this should also work without the extra nesting
    idx = (((ridx0+ridx1)+(ridx2+ridx3)+28)//30)
    valid = ((ridx0+ridx1)<1).ne(True) & ((ridx2+ridx3)<1).ne(True)
    load = get_gated_load_uop(valid, idx)
    self.check(load,
      "1",
      "((((r0+r1)<1)!=True)&(((r2+r3)<1)!=True))")

  def test_valid_with_non_const_rhs(self):
    ridx0 = Range(0, 1024)
    ridx1 = Range(1, 4)
    ridx2 = Range(2, 4)
    valid = (ridx0<(ridx1*4 + ridx2))&(ridx0<-1).ne(True)
    idx = ridx0
    load = get_gated_load_uop(valid, idx)
    self.check(load,
      "r0",
      "(r0<((r1*4)+r2))")
class TestImageSimplification(unittest.TestCase):
  """Index/valid simplification of image loads using the image-aware rewrite (`simplify_image_idx`)."""

  def check(self, load, svalid, sidx0, sidx1):
    """Rewrite `load` and compare both image coordinates and the shared valid against the expected strings.

    `sidx0` is checked against the idx of the INDEX's third source and `sidx1` against its second source.
    `svalid=None` means the valid is expected to be the constant True.
    """
    off = simplify_image_idx(load.sink()).src[0].src[0]
    self.assertEqual(len(off.src), 3)
    src1, src2 = off.src[1], off.src[2]
    idx0, idx1 = src2.get_idx(), src1.get_idx()
    check_uop_against_string(self, idx0, sidx0)
    check_uop_against_string(self, idx1, sidx1)
    # both coordinates must carry the same gate
    self.assertEqual(src1.get_valid(), src2.get_valid())
    if svalid is None:
      self.assertEqual(src1.get_valid(), UOp.const(dtypes.bool, True), "svalid is None but valid is not True")
      return
    check_uop_against_string(self, src1.get_valid(), svalid)

  def test_idx_gt_c(self):
    # (idx1 < c+1).ne(True) ? (..., idx1-1+c) : 0 can drop the valid
    # (idx1 < c+1).ne(True) -> idx > c
    gidx0 = Special("gidx0", 32)
    gidx1 = Special("gidx1", 32)
    shape = (10, 10, 4)
    load = get_load_image_uop(shape, (gidx1<1).ne(True), (gidx0, gidx1-1))
    self.check(load, None, "gidx0", "(gidx1+-1)")
    load = get_load_image_uop(shape, (gidx1<1).ne(True), (gidx0, gidx1-2))
    self.check(load, None, "gidx0", "(gidx1+-2)")

    # should match any one of the AND clause and drop the matched statement from valid
    valid = (gidx0<1).ne(True) & (gidx1<1).ne(True)
    load = get_load_image_uop(shape, valid, (gidx0+1, gidx1-1))
    self.check(load, "((gidx0<1)!=True)", "(gidx0+1)", "(gidx1+-1)")

    valid = (gidx0<1).ne(True) & (gidx1<1).ne(True)
    load = get_load_image_uop(shape, valid, (gidx0, gidx1-1))
    self.check(load, "((gidx0<1)!=True)", "gidx0", "(gidx1+-1)")

  def test_idx_lt_bound(self):
    # (idx1 < image_bound) ? (..., idx1) : 0 can drop the valid
    gidx0 = Special("gidx0", 32)
    gidx1 = Special("gidx1", 32)
    load = get_load_image_uop((10, 10, 4), gidx1<10, (gidx0, gidx1))
    self.check(load, None, "gidx0", "gidx1")

    # same thing, valid has a div
    load = get_load_image_uop((10, 10, 4), gidx1//2<5, (gidx0, gidx1))
    self.check(load, None, "gidx0", "gidx1")

    # 10x20 image, not out of bound
    load = get_load_image_uop((20, 10, 4), gidx1<10, (gidx0, gidx1))
    self.check(load, "(gidx1<10)", "gidx0", "gidx1")

  def test_generic_idx_lt_bound(self):
    # (idx1 < image_bound - c) ? (..., idx1 + c) : 0 can drop the valid
    gidx0 = Special("gidx0", 32)
    gidx1 = Special("gidx1", 32)
    shape = (10, 10, 4)
    load = get_load_image_uop(shape, (gidx1<8), (gidx0, gidx1+2))
    self.check(load, None, "gidx0", "(gidx1+2)")
    load = get_load_image_uop(shape, (gidx1<5), (gidx0, gidx1+5))
    self.check(load, None, "gidx0", "(gidx1+5)")

  def test_valid_empty_set(self):
    gidx0 = Special("gidx0", 32)
    gidx1 = Special("gidx1", 32)
    shape = (32, 32, 4)
    idx = (gidx0%2, gidx1+2)
    # not empty
    load = get_load_image_uop(shape, gidx0<8, idx)
    self.check(load, "(gidx0<8)", "(gidx0%2)", "(gidx1+2)")

    # empty -> invalid
    load = get_load_image_uop(shape, (gidx0<8) & (gidx0<8).ne(True), idx)
    with Context(NOOPT=1, SPEC=0):
      load = full_rewrite(load.sink()).src[0]
    # the contradictory gate leaves a 4-wide STACK in place of the load
    self.assertEqual(load.op, Ops.STACK)
    self.assertEqual(load.dtype.count, 4)

  def test_openpilot_conv1(self):
    # first conv in openpilot
    # kernel in tinygrad ae5d1407ee844a97a52ad3756835d38e7e2b9e1b https://gist.github.com/chenyuxyz/39c2d4e9a076b46731c67d345ff066b6
    idx1 = Special("idx1", 32)
    idx2 = Special("idx2", 64)
    # ridx0 = Variable("ridx0", 0, 5)
    # ridx1 = Variable("ridx1", 0, 2)
    # ridx2 = Variable("ridx2", 0, 2)
    ridx0 = Range(0, 6)
    ridx1 = Range(1, 3)
    ridx2 = Range(2, 3)
    alu1 = ((idx2*2)+ridx1)
    alu4 = ((idx1*48)+(ridx2*6)+ridx0)
    valid = ((((idx2*2)+(ridx1))<1).ne(True))&((((idx1*8)+(ridx2))<1).ne(True))
    shape = (128, 1536, 4)
    idx = ((alu4+1530)%1536, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))
    load = get_load_image_uop(shape, valid, idx)
    self.check(load, None, "((((idx1*48)+(r2*6))+r0)+-6)", "(((idx2*2)+r1)+-1)")

  def test_openpilot_conv2(self):
    # conv in test/external/external_test_valid_remove.py
    idx1 = Special("idx1", 32)
    idx2 = Special("idx2", 64)
    # ridx0 = Variable("ridx0", 0, 2)
    # ridx1 = Variable("ridx1", 0, 2)
    # ridx2 = Variable("ridx2", 0, 2)
    ridx0 = Range(0, 3)
    ridx1 = Range(1, 3)
    ridx2 = Range(2, 3)
    alu1 = ((idx2*2)+ridx1)
    alu3 = ((idx1*24)+(ridx2*3)+ridx0)
    valid = ((((idx2*2)+ridx1)<1).ne(True))&((((idx1*8)+ridx2)<1).ne(True))
    shape = (128, 768, 4)
    idx = ((alu3+765)%768, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))
    load = get_load_image_uop(shape, valid, idx)
    self.check(load, None, "((((idx1*24)+(r2*3))+r0)+-3)", "(((idx2*2)+r1)+-1)")

  def test_openpilot_conv3(self):
    # in openpilot 0.9.7
    idx0 = Special("idx0", 64)
    idx1 = Special("idx1", 2)
    idx2 = Special("idx2", 4)
    ridx0 = Range(0, 7)
    ridx1 = Range(1, 7)
    alu2 = ((idx2*2)+ridx0)
    alu4 = ((idx1*8)+ridx1)
    alu6 = ((idx1*512)+(ridx1*64)+idx0)
    valid = (alu2<11)&(alu4<3).ne(True)
    shape = (8, 1024, 4)
    idx = (((alu6+832)%1024),(alu2+((idx1+((ridx1+5)//8)+1)//2)+(-4)))
    load = get_load_image_uop(shape, valid, idx)
    self.check(load,
      "((((idx2*2)+r0)<11)&((((idx1*8)+r1)<3)!=True))",
      "(idx0+(idx1*512+r1*64)+-192)",
      "((((idx2*2)+r0)+(((idx1+((r1+5)//8))+1)//2))+-4)")

  def test_simplify1(self):
    # idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)
    gidx = Special("gidx", 512)
    valid = (gidx<488) & (gidx<480).ne(True)
    idx = ((gidx*3+18)%26, (gidx*3+18)//26-56)
    load = get_load_image_uop((1, 26, 4), valid, idx)
    self.check(load, None, "((gidx*3)+-1438)", "0")

  def test_simplify2(self):
    # from DEV=CL DEBUG=4 FORWARD_ONLY=1 IMAGE=2 python3 test/test_ops.py TestOps.test_simple_padding_conv2d
    lidx = Special("lidx", 4)
    valid = (lidx<3) & (lidx<1).ne(True)
    idx = ((lidx+1)%2, (lidx+1)//2-1)
    load = get_load_image_uop((1, 2, 4), valid, idx)
    self.check(load, None, "(lidx+-1)", "0")

  def test_simplify3(self):
    # from openpilot
    idx0 = Special("idx0", 265)
    valid = (idx0<201).ne(True)
    idx = ((idx0+55)%64, (idx0+55)//64-4)
    load = get_load_image_uop((1, 64, 4), valid, idx)
    self.check(load, None, "(idx0+-201)", "0")

  def test_simplify4(self):
    idx0 = Special("idx0", 512)
    shape = (4, 64, 4)
    alu2 = ((idx0*4+1)%32)
    alu3 = ((idx0*4+2)%32)
    alu4 = ((idx0*4+3)%32)
    alu5 = (idx0*4%32)
    alu8 = (idx0//8%32//4)
    alu9 = idx0<256
    # each case only differs in which alu feeds the coordinates and in the expected first coordinate
    cases = (
      (alu2, "(idx0%2*32+idx0//32+8)"),
      (alu3, "(idx0%2*32+idx0//32+16)"),
      (alu4, "(idx0%2*32+idx0//32+24)"),
      (alu5, "(idx0%2*32+idx0//32)"),
    )
    for alu, expected_idx0 in cases:
      load = get_load_image_uop(shape, alu9, (((alu8+(alu*8))%64),(alu//8)))
      self.check(load, "(idx0<256)", expected_idx0, "(idx0//2%4)")

  def test_simplify5(self):
    # openpilot 0.9.7, chunk replacement to simplify
    shape = (10, 384, 4)
    idx0 = Special("idx0", 16)
    idx1 = Special("idx1", 24)
    alu0 = idx0*4
    alu1 = (idx1*256)+alu0
    alu2 = idx1//3
    alu3 = ((alu1+1)%768)
    idx = ((idx0+((((alu3//640)+alu2)%8)*16)+128),((alu3//64)%10))
    valid = alu3<640
    load = get_load_image_uop(shape, valid, idx)
    self.check(load, None, "((idx0+((idx1//3)*16))+128)", "((idx1%3)*4)")

  def test_simplify6(self):
    # from openpilot
    # the valid implies the numerator of the div/mod is positive and can be simplified with floordiv rules
    idx1 = Special("idx1", 16)
    idx2 = Special("idx2", 64)
    ridx3 = Range(3, 3)
    ridx4 = Range(4, 3)
    ridx5 = Range(5, 3)
    alu0 = ((idx2*1536)+(ridx4*768)+ridx3+(idx1*24)+(ridx5*3)+-771)%768
    alu1 = ((idx2*1536)+(ridx4*768)+ridx3+(idx1*24)+(ridx5*3)+-771)//768
    valid = (((idx2+ridx4)<1)!=1)&(((idx1+ridx5)<1)!=1)
    load = get_load_image_uop((128, 768, 4), valid, (alu0, alu1))
    self.check(load, None, "((((idx1*24)+r3)+(r5*3))+-3)", "(((idx2*2)+r4)+-1)")

  def test_simplify7(self):
    # DEBUG=2 ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1397 ALLOWED_GATED_READ_IMAGE=94 FLOAT16=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 # noqa: E501
    # kernel 143
    gidx0 = Special("gidx0", 32)
    lidx0 = Special("lidx0", 16)
    lidx1 = Special("lidx1", 8)
    r0 = Range(0, 7)
    # buf.render()='UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((32, 1024, 4)), arg=1, src=())'
    alu0 = ((gidx0*2+(lidx0*128+r0*64+lidx1*8+-183)%64*64+(lidx0*128+r0*64+lidx1*8+-183)//64%32*4096+1)//4%1024)
    alu1 = ((gidx0*2+(lidx0*128+r0*64+lidx1*8+-183)%64*64+(lidx0*128+r0*64+lidx1*8+-183)//64%32*4096+1)//4096)
    valid = ((lidx1<7)&((((lidx0*2+r0)<3)!=1)&((lidx0*2+r0)<35)))
    load = get_load_image_uop((32, 1024, 4), valid, (alu0, alu1))
    self.check(load, None, "(lidx1*128+gidx0//2+144)", "(lidx0*2+r0+-3)")

    # same idx, written without the inline simplification of the inner div/mod
    alu0 = ((gidx0*2+lidx1*512+(lidx0*8192+r0*4096)+-11711)//4%1024)
    alu1 = (lidx0*2+r0+-3)
    valid = ((lidx1<7)&((((lidx0*2+r0)<3)!=1)&((lidx0*2+r0)<35)))
    load = get_load_image_uop((32, 1024, 4), valid, (alu0, alu1))
    self.check(load, None, "(lidx1*128+gidx0//2+144)", "(lidx0*2+r0+-3)")

  def test_simplify8(self):
    # from openpilot compile3, kernel r_4_16_8_16_4_4_3_3n1
    # valid guarantees A >= 0, so divmod simplifies and gate is removed
    gidx0 = Special("gidx0", 16)
    gidx1 = Special("gidx1", 4)
    lidx0 = Special("lidx0", 8)
    lidx1 = Special("lidx1", 16)
    A = gidx0 + gidx1*8192 + lidx0*1024 + lidx1*64 - 1040
    valid = ((lidx1 < 1).ne(True)) & (((gidx1 + lidx0) < 1).ne(True))
    load = get_load_image_uop((32, 1024, 4), valid, (A % 1024, A // 1024))
    self.check(load, None, "(gidx0+lidx1*64+-16)", "(lidx0+gidx1*8+-1)")

  def test_simplify9(self):
    # from openpilot compile3, kernel r_32_16_8_4_4_7_7 (image 1x16384)
    # valid guarantees A1 >= 0 and A1 < 512, gate should be removable
    gidx0 = Special("gidx0", 32)
    lidx0 = Special("lidx0", 16)
    lidx1 = Special("lidx1", 8)
    r0 = Range(0, 7)
    A1 = lidx0*32 + r0*32 + lidx1*4 - 99
    valid = ((lidx1 < 1).ne(True)) & ((lidx0 + r0) < 3).ne(True) & ((lidx0 + r0) < 19)
    alu0 = gidx0 + (A1 % 32)*32 + (A1 // 32 % 16)*1024
    load = get_load_image_uop((1, 16384, 4), valid, (alu0, UOp.const(dtypes.weakint, 0)))
    # accept either the fully folded form or, as a fallback, the form that still carries the gate
    try:
      self.check(load, None, "(gidx0+lidx0*1024+r0*1024+lidx1*128+-3168)", "0")
    except AssertionError:
      # TODO: fold valid
      self.check(load, "(((lidx1<1)!=True)&(((lidx0+r0)<3)!=True)&((lidx0+r0)<19))",
        "(gidx0+lidx1*128+(lidx0*1024+r0*1024)+-3168)", "0")

  def test_simplify10(self):
    # from openpilot compile3, kernel r_16_8_4_4_4_4_7_7 (image 1x8192)
    # valid guarantees A1 >= 0 and A1 < 128, gate should be removable
    gidx0 = Special("gidx0", 16)
    lidx0 = Special("lidx0", 8)
    lidx1 = Special("lidx1", 4)
    lidx2 = Special("lidx2", 4)
    r0 = Range(0, 7)
    A1 = lidx0*16 + r0*16 + lidx1*4 - 51
    valid = ((lidx1 < 1).ne(True)) & ((lidx0 + r0) < 3).ne(True) & ((lidx0 + r0) < 11)
    alu0 = lidx2 + gidx0*4 + (A1 % 16)*64 + (A1 // 16 % 8)*1024
    load = get_load_image_uop((1, 8192, 4), valid, (alu0, UOp.const(dtypes.weakint, 0)))
    # accept either the fully folded form or, as a fallback, the form that still carries the gate
    try:
      self.check(load, None, "(lidx2+gidx0*4+lidx0*1024+r0*1024+lidx1*256+-3264)", "0")
    except AssertionError:
      # TODO: fold valid
      self.check(load, "(((lidx1<1)!=True)&(((lidx0+r0)<3)!=True)&((lidx0+r0)<11))",
        "(lidx2+gidx0*4+lidx1*256+(lidx0*1024+r0*1024)+-3264)", "0")
class TestDropTrueGate(unittest.TestCase):
  """A gate that is the constant True should not survive on an INDEX after rewriting."""

  def test_drop_true_gate_on_index(self):
    # test that INDEX with a constant True valid gets simplified to drop the valid
    # (graph_rewrite, sym and load_store_indexing are the module-level imports; no local re-import needed)
    buf = UOp.param(0, dtypes.int.ptr())
    idx = UOp.const(dtypes.weakint, 0)
    true_gate = UOp.const(dtypes.bool, True)
    index_with_gate = UOp(Ops.INDEX, dtypes.int.ptr(), (buf, idx.valid(true_gate)))

    # apply the optimization
    result = graph_rewrite(index_with_gate, sym+load_store_indexing)

    # the True valid should be dropped (INDEX should only have 2 sources)
    self.assertEqual(len(result.src), 2, "True valid should be dropped from INDEX")
class TestRangeShrink(unittest.TestCase):
  """Checks the bound of the RANGE UOps left after `full_rewrite` when a range is guarded by a gate."""

  def get_ranges(self, sink):
    """Run `full_rewrite` on `sink` under NOOPT=1, SPEC=0 and return every RANGE UOp in the result."""
    with Context(NOOPT=1, SPEC=0):
      result = full_rewrite(sink)
    return [u for u in result.toposort() if u.op is Ops.RANGE]

  def _check_single_range(self, sink, bound):
    """Assert that rewriting `sink` leaves exactly one RANGE whose first source has arg `bound`."""
    ranges = self.get_ranges(sink)
    self.assertEqual(len(ranges), 1)
    self.assertEqual(ranges[0].src[0].arg, bound)

  def test_range_shrink_single_guard(self):
    # range 0..203 guarded by r < 4 everywhere -> shrink to 0..3
    r = Range(0, 204)
    load = get_gated_load_uop(r < UOp.const(dtypes.weakint, 4), r)
    self._check_single_range(load.sink(), 4)

  def test_range_shrink_picks_max_guard(self):
    # two loads guard the same range with r < 4 and r < 8 -> shrink to max(4, 8) = 8
    r = Range(0, 204)
    load1 = get_gated_load_uop(r < UOp.const(dtypes.weakint, 4), r)
    load2 = get_gated_load_uop(r < UOp.const(dtypes.weakint, 8), r)
    self._check_single_range(UOp.sink(load1, load2), 8)

  def test_range_no_shrink_guard_ge_max(self):
    # guard r < 300 with range max 204 -> no shrink (guard doesn't constrain)
    r = Range(0, 204)
    load = get_gated_load_uop(r < UOp.const(dtypes.weakint, 300), r)
    self._check_single_range(load.sink(), 204)

  def test_range_no_shrink_when_unguarded_elsewhere(self):
    # one load guards r < 4, but another load uses r without a gate -> no shrink
    r = Range(0, 204)
    load1 = get_gated_load_uop(r < UOp.const(dtypes.weakint, 4), r)
    ungated_ptr = UOp.param(1, dtypes.float.ptr()).index(r, ptr=True)
    load2 = UOp(Ops.LOAD, dtypes.float, (ungated_ptr,))
    self._check_single_range(UOp.sink(load1, load2), 204)

  def test_range_no_shrink_when_used_in_reduce(self):
    # range used in both a gated load AND directly in the reduce expression -> no shrink
    r = Range(0, 204)
    gated_load = get_gated_load_uop(r < UOp.const(dtypes.weakint, 4), r)
    red = (r.cast(dtypes.float) + gated_load).reduce(r, arg=Ops.ADD)
    self._check_single_range(red.sink(), 204)

  def test_range_shrink_to_single_iteration(self):
    # guard r < 1 shrinks range to 1 -> single iteration, range eliminated entirely
    r = Range(0, 204)
    load = get_gated_load_uop(r < UOp.const(dtypes.weakint, 1), r)
    ranges = self.get_ranges(load.sink())
    self.assertEqual(len(ranges), 0)

  def test_range_shrink_store_where_invalid(self):
    # emulates mask.where(x.pad_to(mask.shape), Invalid): range should shrink accordingly
    from tinygrad.dtype import Invalid
    r = Range(0, 204)
    x = (r < 4).where(UOp.const(dtypes.float, 1), Invalid)
    store = UOp.param(0, dtypes.float.ptr()).index(r).store((r < 4).where(x, 0))
    self._check_single_range(store.sink(), 4)

  def test_range_shrink_store_where_invalid_flipped(self):
    # above, but flipped
    from tinygrad.dtype import Invalid
    r = Range(0, 204)
    x = (r < 4).where(UOp.const(dtypes.float, 1), Invalid)
    store = UOp.param(0, dtypes.float.ptr()).index(r).store((r < 4).where(0, x))
    self._check_single_range(store.sink(), 4)
# run the whole module through unittest's runner when executed as a script
if __name__ == '__main__':
  unittest.main()