DISABLE_FAST_IDIV actually works as a ContextVar (#14378)

This commit is contained in:
Christopher Milan
2026-01-27 13:12:42 -08:00
committed by GitHub
parent 8c899e4aaf
commit f34efc1ad1
3 changed files with 42 additions and 26 deletions

View File

@@ -375,21 +375,9 @@ class TestLocalAccess(unittest.TestCase):
sres = uop(uops, Ops.LOAD, dtypes.int32, (smem.index(ofs),))
self.assertEqual(_test_uops_result(dtypes.int32, uops, sres), 42)
@unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "This only tests assembly backends")
class TestAssembly(unittest.TestCase):
def test_bitshift_left(self):
g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
c1 = UOp.const(dtypes.int, 2)
c2 = UOp.const(dtypes.int, 3)
l1 = g1.index(c1)
a1 = UOp(Ops.MUL, dtypes.int, (l1, c1))
a2 = UOp(Ops.MUL, dtypes.int, (l1, c2))
uops = to_uops_list([a1,a2], ren=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render(uops)
ops = [x.op for x in uops]
self.assertIn(Ops.SHL, ops)
self.assertIn(Ops.MUL, ops)
@unittest.skipIf(Device.DEFAULT == "METAL", "compiler bug")
@unittest.skipUnless(Ops.SHR in Device[Device.DEFAULT].renderer.code_for_op, "fast_idiv requires SHR")
class TestFastIdiv(unittest.TestCase):
def test_division_power_of_two(self):
for dt in (dtypes.int32, dtypes.uint32):
g = UOp(Ops.DEFINE_GLOBAL, dt.ptr(), (), 0)
@@ -402,6 +390,7 @@ class TestAssembly(unittest.TestCase):
self.assertIn(Ops.SHR, ops, f"For dtype={dt} divison by power of two did not simplify to shift")
self.assertNotIn(Ops.IDIV, ops, f"For dtype={dt} divison by power of two did not simplify to shift")
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't support long")
def test_fast_idiv_and_mod(self):
g = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(), (), 0)
c = UOp.const(dtypes.uint, 3)
@@ -420,6 +409,14 @@ class TestAssembly(unittest.TestCase):
self.assertIn(Ops.SHR, ops)
self.assertNotIn(Ops.MOD, ops)
def test_fast_idiv_remove_powers_of_two(self):
ridx = UOp.range(2**20, 0)
uops = to_uops_list([ridx//(7*64)], ren=Device[Device.DEFAULT].renderer)
ops = [x.op for x in uops]
# this requires shifting out the powers of two before doing fast_idiv
# (((ridx0>>6)*18725)>>17) instead of (int)((((long)(ridx0)*1198373)>>29))
self.assertNotIn(Ops.CAST, ops)
@unittest.expectedFailure
def test_fast_idiv_overflow(self):
# This will be possible with a slightly different method for fast_idiv
@@ -433,13 +430,32 @@ class TestAssembly(unittest.TestCase):
self.assertIn(Ops.SHR, ops)
self.assertNotIn(Ops.IDIV, ops)
def test_fast_idiv_remove_powers_of_two(self):
ridx = UOp.range(2**20, 0)
uops = to_uops_list([ridx//(7*64)], ren=Device[Device.DEFAULT].renderer)
def test_disable_fast_idiv(self):
g = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(), (), 0)
c = UOp.const(dtypes.uint, 3)
l = g.index(c)
a = UOp(Ops.IDIV, dtypes.uint, (l, c))
with Context(DISABLE_FAST_IDIV=1):
uops = to_uops_list([a], ren=Device[Device.DEFAULT].renderer)
ops = [x.op for x in uops]
# this requires shifting out the powers of two before doing fast_idiv
# (((ridx0>>6)*18725)>>17) instead of (int)((((long)(ridx0)*1198373)>>29))
self.assertNotIn(Ops.CAST, ops)
self.assertNotIn(Ops.SHR, ops)
self.assertIn(Ops.IDIV, ops)
@unittest.skipUnless(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "This only tests assembly backends")
class TestAssembly(unittest.TestCase):
def test_bitshift_left(self):
g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
c1 = UOp.const(dtypes.int, 2)
c2 = UOp.const(dtypes.int, 3)
l1 = g1.index(c1)
a1 = UOp(Ops.MUL, dtypes.int, (l1, c1))
a2 = UOp(Ops.MUL, dtypes.int, (l1, c2))
uops = to_uops_list([a1,a2], ren=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render(uops)
ops = [x.op for x in uops]
self.assertIn(Ops.SHL, ops)
self.assertIn(Ops.MUL, ops)
def test_mulacc_unrolled(self):
# test that acc = acc + a0*b0 + a1*b1 + a2*b2 + a3*b3

View File

@@ -1,6 +1,6 @@
from typing import cast
import itertools
from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, getenv, TracingKey, Context
from tinygrad.helpers import DISABLE_FAST_IDIV, DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, getenv, TracingKey, Context
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, pyrender
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
from tinygrad.renderer import Renderer, ProgramSpec
@@ -95,7 +95,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
# decompositions
supported_ops = tuple(ren.code_for_op.keys())
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)
pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2, bool(DISABLE_FAST_IDIV))
sink = graph_rewrite(sink, pm_decomp, ctx=ren.device, name="decompositions")
# final rules for the renderer (without sym)

View File

@@ -2,7 +2,7 @@ from typing import Callable
import math, functools
from tinygrad.dtype import dtypes, DType, promo_lattice
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import polyN, DISABLE_FAST_IDIV
from tinygrad.helpers import polyN
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher
TRANSCENDENTAL_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64)
@@ -318,7 +318,7 @@ def threefry2x32(x: UOp, key: UOp):
powers_of_two = {2**i:i for i in range(64)}
@functools.cache
def get_late_rewrite_patterns(ops:tuple[Ops, ...], force_transcendental):
def get_late_rewrite_patterns(ops:tuple[Ops, ...], force_transcendental, disable_fast_idiv):
pat: list[tuple[UPat, Callable]] = []
for op,f in ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)):
if op not in ops or force_transcendental:
@@ -342,7 +342,7 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], force_transcendental):
pat += [(UPat.var("x", dtypes.uints)//UPat.cvar("c"), lambda x,c: x >> v if (v:=powers_of_two.get(c.arg, 0)) else None)]
pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: (x+(l.const_like(l.vmin) if (l:=(x<0)).vmin==l.vmax else l).where(
c-1, 0)) >> v if (v:=powers_of_two.get(c.arg, 0)) else None)] # (x+(x<0).where(c-1, 0)) >> v
if not DISABLE_FAST_IDIV:
if not disable_fast_idiv:
pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("d", vec=False), lambda ctx, x, d: fast_idiv(ctx, x, d.arg))]
pat += [(UPat.var("x", dtypes.ints)%UPat.var("d"), lambda x, d: x-d*(x//d))]
if Ops.NEG in ops: