From f34efc1ad1e27d50bbf9f45ca32404385e84f02e Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Tue, 27 Jan 2026 13:12:42 -0800 Subject: [PATCH] DISABLE_FAST_IDIV actually works as a ContextVar (#14378) --- test/test_uops.py | 58 ++++++++++++++++++++++------------ tinygrad/codegen/__init__.py | 4 +-- tinygrad/uop/decompositions.py | 6 ++-- 3 files changed, 42 insertions(+), 26 deletions(-) diff --git a/test/test_uops.py b/test/test_uops.py index cd2c3638cb..c55bc8ef82 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -375,21 +375,9 @@ class TestLocalAccess(unittest.TestCase): sres = uop(uops, Ops.LOAD, dtypes.int32, (smem.index(ofs),)) self.assertEqual(_test_uops_result(dtypes.int32, uops, sres), 42) -@unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "This only tests assembly backends") -class TestAssembly(unittest.TestCase): - def test_bitshift_left(self): - g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0) - c1 = UOp.const(dtypes.int, 2) - c2 = UOp.const(dtypes.int, 3) - l1 = g1.index(c1) - a1 = UOp(Ops.MUL, dtypes.int, (l1, c1)) - a2 = UOp(Ops.MUL, dtypes.int, (l1, c2)) - uops = to_uops_list([a1,a2], ren=Device[Device.DEFAULT].renderer) - Device[Device.DEFAULT].renderer.render(uops) - ops = [x.op for x in uops] - self.assertIn(Ops.SHL, ops) - self.assertIn(Ops.MUL, ops) - +@unittest.skipIf(Device.DEFAULT == "METAL", "compiler bug") +@unittest.skipUnless(Ops.SHR in Device[Device.DEFAULT].renderer.code_for_op, "fast_idiv requires SHR") +class TestFastIdiv(unittest.TestCase): def test_division_power_of_two(self): for dt in (dtypes.int32, dtypes.uint32): g = UOp(Ops.DEFINE_GLOBAL, dt.ptr(), (), 0) @@ -402,6 +390,7 @@ class TestAssembly(unittest.TestCase): self.assertIn(Ops.SHR, ops, f"For dtype={dt} divison by power of two did not simplify to shift") self.assertNotIn(Ops.IDIV, ops, f"For dtype={dt} divison by power of two did not simplify to shift") + @unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't support long") def test_fast_idiv_and_mod(self): g = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(), (), 0) c = UOp.const(dtypes.uint, 3) @@ -420,6 +409,14 @@ class TestAssembly(unittest.TestCase): self.assertIn(Ops.SHR, ops) self.assertNotIn(Ops.MOD, ops) + def test_fast_idiv_remove_powers_of_two(self): + ridx = UOp.range(2**20, 0) + uops = to_uops_list([ridx//(7*64)], ren=Device[Device.DEFAULT].renderer) + ops = [x.op for x in uops] + # this requires shifting out the powers of two before doing fast_idiv + # (((ridx0>>6)*18725)>>17) instead of (int)((((long)(ridx0)*1198373)>>29)) + self.assertNotIn(Ops.CAST, ops) + @unittest.expectedFailure def test_fast_idiv_overflow(self): # This will be possible with a slightly different method for fast_idiv @@ -433,13 +430,32 @@ class TestAssembly(unittest.TestCase): self.assertIn(Ops.SHR, ops) self.assertNotIn(Ops.IDIV, ops) - def test_fast_idiv_remove_powers_of_two(self): - ridx = UOp.range(2**20, 0) - uops = to_uops_list([ridx//(7*64)], ren=Device[Device.DEFAULT].renderer) + def test_disable_fast_idiv(self): + g = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(), (), 0) + c = UOp.const(dtypes.uint, 3) + l = g.index(c) + a = UOp(Ops.IDIV, dtypes.uint, (l, c)) + with Context(DISABLE_FAST_IDIV=1): + uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer) ops = [x.op for x in uops] - # this requires shifting out the powers of two before doing fast_idiv - # (((ridx0>>6)*18725)>>17) instead of (int)((((long)(ridx0)*1198373)>>29)) - self.assertNotIn(Ops.CAST, ops) + self.assertNotIn(Ops.SHR, ops) + self.assertIn(Ops.IDIV, ops) + + +@unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "This only tests assembly backends") +class TestAssembly(unittest.TestCase): + def test_bitshift_left(self): + g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0) + c1 = UOp.const(dtypes.int, 2) + c2 = UOp.const(dtypes.int, 3) + l1 = g1.index(c1) + a1 = UOp(Ops.MUL, dtypes.int, (l1, c1)) + a2 = UOp(Ops.MUL, dtypes.int, (l1, c2)) + uops = to_uops_list([a1,a2], ren=Device[Device.DEFAULT].renderer) + Device[Device.DEFAULT].renderer.render(uops) + ops = [x.op for x in uops] + self.assertIn(Ops.SHL, ops) + self.assertIn(Ops.MUL, ops) def test_mulacc_unrolled(self): # test that acc = acc + a0*b0 + a1*b1 + a2*b2 + a3*b3 diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index 24ae316d4a..bc14a1b0e7 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -1,6 +1,6 @@ from typing import cast import itertools -from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, getenv, TracingKey, Context +from tinygrad.helpers import DISABLE_FAST_IDIV, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, getenv, TracingKey, Context from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, pyrender from tinygrad.uop.spec import type_verify, program_spec, kernel_spec from tinygrad.renderer import Renderer, ProgramSpec @@ -95,7 +95,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) - # decompositions supported_ops = tuple(ren.code_for_op.keys()) - pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2) + pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2, bool(DISABLE_FAST_IDIV)) sink = graph_rewrite(sink, pm_decomp, ctx=ren.device, name="decompositions") # final rules for the renderer (without sym) diff --git a/tinygrad/uop/decompositions.py b/tinygrad/uop/decompositions.py index 4f1f914e0f..2d0cf5fb2b 100644 --- a/tinygrad/uop/decompositions.py +++ b/tinygrad/uop/decompositions.py @@ -2,7 +2,7 @@ from typing import Callable import math, functools from tinygrad.dtype import dtypes, DType, promo_lattice from tinygrad.device import is_dtype_supported -from tinygrad.helpers import polyN, DISABLE_FAST_IDIV +from tinygrad.helpers import polyN from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher TRANSCENDENTAL_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64) @@ -318,7 +318,7 @@ def threefry2x32(x: UOp, key: UOp): powers_of_two = {2**i:i for i in range(64)} @functools.cache -def get_late_rewrite_patterns(ops:tuple[Ops, ...], force_transcendental): +def get_late_rewrite_patterns(ops:tuple[Ops, ...], force_transcendental, disable_fast_idiv): pat: list[tuple[UPat, Callable]] = [] for op,f in ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)): if op not in ops or force_transcendental: @@ -342,7 +342,7 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], force_transcendental): pat += [(UPat.var("x", dtypes.uints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) else None)] pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: (x+(l.const_like(l.vmin) if (l:=(x<0)).vmin==l.vmax else l).where( c-1, 0)) >> v if (v:=powers_of_two.get(c.arg, 0)) else None)] # (x+(x<0).where(c-1, 0)) >> v - if not DISABLE_FAST_IDIV: + if not disable_fast_idiv: pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("d", vec=False), lambda ctx, x, d: fast_idiv(ctx, x, d.arg))] pat += [(UPat.var("x", dtypes.ints)%UPat.var("d"), lambda x, d: x-d*(x//d))] if Ops.NEG in ops: