From d79bf356c27ee659f8bca52aa026baa70e65e84d Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 5 May 2026 17:34:44 +0300 Subject: [PATCH 1/4] viz: add CALL -> codegen link (#16044) * work * cleaner * details * rm --- tinygrad/viz/index.html | 13 +++++++++++++ tinygrad/viz/js/index.js | 11 +++++++---- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/tinygrad/viz/index.html b/tinygrad/viz/index.html index 2d8503b121..35804472bc 100644 --- a/tinygrad/viz/index.html +++ b/tinygrad/viz/index.html @@ -122,6 +122,19 @@ font-size: 6px; fill: #08090e; } + g.tag.ref * { + cursor: pointer; + } + g.tag.ref circle { + fill: #9FDDE6; + stroke: #4a4b57; + } + g.tag.ref path { + fill: none; + stroke: #08090e; + stroke-width: 0.7; + stroke-linejoin: miter; + } .label :is(text, p) { font-weight: 350; } diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 9465341d37..5df0f87d52 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -56,9 +56,10 @@ function intersectRect(r1, r2) { return {x:r1.x+dx*scale, y:r1.y+dy*scale}; } -function addTags(root) { +function addTags(root, path) { root.selectAll("circle").data(d => [d]).join("circle").attr("r", 5); - root.selectAll("text").data(d => [d]).join("text").text(d => d).attr("dy", "0.35em"); + if (path != null) root.selectAll("path").data(d => [d]).join("path").attr("d", path); + else root.selectAll("text").data(d => [d]).join("text").text(d => d).attr("dy", "0.35em"); } const drawGraph = (data) => { @@ -116,9 +117,11 @@ const drawGraph = (data) => { }); addTags(nodes.selectAll("g.tag").data(d => d.tag != null ? [d] : []).join("g").attr("class", "tag") .attr("transform", d => `translate(${-d.width/2+8}, ${-d.height/2+8})`).datum(e => e.tag)); - addTags(nodes.selectAll("g.type").data(d => d.callNode ? [d] : []).join("g") - .attr("class", d => `tag ${d.collapsed ? 'collapsed' : 'expanded'}`) + addTags(nodes.selectAll("g.type").data(d => d.callNode ? 
[d] : []).join("g").attr("class", d => `tag ${d.collapsed ? 'collapsed' : 'expanded'}`) .attr("transform", d => `translate(${-d.width/2}, ${0})`).datum(d => d.collapsed ? "+" : "−")); + addTags(nodes.selectAll("g.ref").data(d => d.ref != null ? [d] : []).join("g").attr("class", "tag ref") + .attr("transform", d => `translate(${d.width/2-2}, ${-d.height/2+2})`).on("click", (e,d) => { e.stopPropagation(); switchCtx(d.ref); }), + "M-1.7 1.7 L1.7 -1.7 M-0.55 -1.7 H1.7 V0.55"); // draw edges const line = d3.line().x(d => d.x).y(d => d.y).curve(d3.curveBasis), edges = g.edges(); d3.select("#edges").selectAll("path.edgePath").data(edges).join("path").attr("class", "edgePath").attr("d", (e) => { From 9c37a0c75da7cf1ac958285b7b23312d9713e354 Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 5 May 2026 11:42:14 -0400 Subject: [PATCH 2/4] Ops.FLOORDIV and Ops.FLOORMOD (#16038) * Ops.FLOORDIV and Ops.FLOORMOD lowered into IDIV and MOD in get_late_rewrite_patterns * still need this * exclude * like that? 
--- .github/workflows/test.yml | 2 +- test/null/test_uop_vmin_vmax.py | 14 ++++++++++++++ test/null/test_uops.py | 14 ++++++++++++++ tinygrad/helpers.py | 3 +++ tinygrad/tensor.py | 6 ++---- tinygrad/uop/__init__.py | 5 +++-- tinygrad/uop/decompositions.py | 19 +++++++++++++++++-- tinygrad/uop/ops.py | 17 ++++++++++++++--- tinygrad/uop/spec.py | 6 +++--- tinygrad/uop/validate.py | 4 +++- 10 files changed, 74 insertions(+), 16 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d31083ac3f..8437e7e64f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -333,7 +333,7 @@ jobs: deps: testing_unit python-version: '3.14' - name: Test SPEC=2 - run: SPEC=2 pytest --maxfail=10 -n auto --durations=30 test/unit test/backend test/opt --ignore test/backend/test_custom_kernel.py --ignore test/unit/test_hashing.py --timeout 60 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }} + run: SPEC=2 pytest --maxfail=10 -n auto --durations=30 test/unit test/backend test/opt --ignore test/backend/test_custom_kernel.py --ignore test/unit/test_hashing.py --timeout 60 -k "not test_setitem_big and not test_conv2d_ceildiv_edge_case" --splits 2 --group ${{ matrix.group }} fuzzing: name: Fuzzing diff --git a/test/null/test_uop_vmin_vmax.py b/test/null/test_uop_vmin_vmax.py index a341f78f1e..583326c2b2 100644 --- a/test/null/test_uop_vmin_vmax.py +++ b/test/null/test_uop_vmin_vmax.py @@ -192,6 +192,20 @@ class TestVminVmaxDivMod(unittest.TestCase): self.assertEqual(uop.vmin, 3) self.assertEqual(uop.vmax, 6) + def test_vmin_vmax_floordiv_floormod(self): + # FLOORDIV/FLOORMOD ranges differ from IDIV/MOD when the dividend can be negative + x = UOp.variable('x', -7, 7) + floordiv = x.alu(Ops.FLOORDIV, x.const_like(3)) + self.assertEqual(floordiv.vmin, -3) + self.assertEqual(floordiv.vmax, 2) + floormod = x.alu(Ops.FLOORMOD, x.const_like(3)) + self.assertEqual(floormod.vmin, 0) + self.assertEqual(floormod.vmax, 2) + # negative 
const divisor: floormod range is [c+1, 0] + floormod_neg = x.alu(Ops.FLOORMOD, x.const_like(-3)) + self.assertEqual(floormod_neg.vmin, -2) + self.assertEqual(floormod_neg.vmax, 0) + # cross 0 x = UOp.variable('x', -10, 10) uop = x // -2 diff --git a/test/null/test_uops.py b/test/null/test_uops.py index 9b20ead8c5..28e0784ff5 100644 --- a/test/null/test_uops.py +++ b/test/null/test_uops.py @@ -46,6 +46,20 @@ class TestExecALU(unittest.TestCase): self.assertEqual(exec_alu(Ops.IDIV, dtypes.int8, (7, -3)), -2) self.assertEqual(exec_alu(Ops.IDIV, dtypes.int8, (-50, 6)), -8) + def test_floordiv(self): + self.assertEqual(exec_alu(Ops.FLOORDIV, dtypes.int8, (8, 2)), 4) + self.assertEqual(exec_alu(Ops.FLOORDIV, dtypes.int8, (7, 3)), 2) + self.assertEqual(exec_alu(Ops.FLOORDIV, dtypes.int8, (7, -3)), -3) + self.assertEqual(exec_alu(Ops.FLOORDIV, dtypes.int8, (-7, 3)), -3) + self.assertEqual(exec_alu(Ops.FLOORDIV, dtypes.int8, (-50, 6)), -9) + + def test_floormod(self): + self.assertEqual(exec_alu(Ops.FLOORMOD, dtypes.int8, (8, 2)), 0) + self.assertEqual(exec_alu(Ops.FLOORMOD, dtypes.int8, (7, 3)), 1) + self.assertEqual(exec_alu(Ops.FLOORMOD, dtypes.int8, (7, -3)), -2) + self.assertEqual(exec_alu(Ops.FLOORMOD, dtypes.int8, (-7, 3)), 2) + self.assertEqual(exec_alu(Ops.FLOORMOD, dtypes.int8, (-50, 6)), 4) + np.testing.assert_allclose(exec_alu(Ops.MUL, dtypes.float32, (7.0, exec_alu(Ops.RECIPROCAL, dtypes.float32, (3.0,)))), 2+(1.0/3.0)) np.testing.assert_allclose(exec_alu(Ops.MUL, dtypes.float32, (7.0, exec_alu(Ops.RECIPROCAL, dtypes.float32, (-3.0,)))), -2-(1.0/3.0)) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 12e1a13e38..bbe84e5602 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -64,6 +64,9 @@ def next_power2(x): return 1 if x == 0 else 1 << (x - 1).bit_length() # cstyle div and mod def cdiv(x:int, y:int) -> int: return abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0 def cmod(x:int, y:int) -> int: return x-cdiv(x,y)*y +# python floor div and mod 
+def floordiv(x:int, y:int) -> int: return x//y if y != 0 else 0 +def floormod(x:int, y:int) -> int: return x-floordiv(x,y)*y def lo32(x:Any) -> Any: return x & 0xFFFFFFFF # Any is sint def hi32(x:Any) -> Any: return x >> 32 # Any is sint def data64(data:Any) -> tuple[Any, Any]: return (data >> 32, data & 0xFFFFFFFF) # Any is sint diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index dd77c1c133..60b81fdb21 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1307,10 +1307,7 @@ class Tensor(OpMixin): numerator, denominator = self._broadcasted(x, reverse) if dtypes.is_int(numerator.dtype): if rounding_mode == "trunc": return numerator.idiv(denominator) - if rounding_mode == "floor": - truncate_div, truncate_mod = numerator.idiv(denominator), numerator._binop(Ops.MOD, denominator, False) - opposite_sign = ((numerator>0)&(denominator<0)) | ((numerator<0)&(denominator>0)) - return (opposite_sign&(truncate_mod!=0)).where(truncate_div-1, truncate_div) + if rounding_mode == "floor": return numerator._binop(Ops.FLOORDIV, denominator, False) d = numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal() output_dtype = numerator.dtype if dtypes.is_int(numerator.dtype) else d.dtype if rounding_mode == "trunc": return d.trunc().cast(output_dtype) @@ -1328,6 +1325,7 @@ class Tensor(OpMixin): ``` """ a, b = self._broadcasted(x, reverse) + if dtypes.is_int(a.dtype): return a._binop(Ops.FLOORMOD, b, False) return a - a.div(b, rounding_mode="floor") * b def fmod(self, x:Tensor|ConstType) -> Tensor: diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py index 4c6d77f05c..d3937d4da8 100644 --- a/tinygrad/uop/__init__.py +++ b/tinygrad/uop/__init__.py @@ -65,6 +65,7 @@ class Ops(FastEnum): CMPLT = auto(); CMPNE = auto(); CMPEQ = auto() XOR = auto(); OR = auto(); AND = auto() THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() + FLOORDIV = auto(); FLOORMOD = auto() # TernaryOps WHERE = auto(); 
MULACC = auto() @@ -110,7 +111,7 @@ class Ops(FastEnum): class GroupOp: Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIPROCAL, Ops.NEG, Ops.TRUNC} Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ, - Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB, Ops.FDIV, Ops.POW} + Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB, Ops.FDIV, Ops.POW, Ops.FLOORDIV, Ops.FLOORMOD} Ternary = {Ops.WHERE, Ops.MULACC} ALU = set.union(Unary, Binary, Ternary) @@ -137,6 +138,6 @@ class GroupOp: Comparison = {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ} # do not preserve f(0) = 0 - UnsafePad = {Ops.RECIPROCAL, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW} + UnsafePad = {Ops.RECIPROCAL, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW, Ops.FLOORDIV} All = set(Ops) diff --git a/tinygrad/uop/decompositions.py b/tinygrad/uop/decompositions.py index 086bd270d1..435f1cfcba 100644 --- a/tinygrad/uop/decompositions.py +++ b/tinygrad/uop/decompositions.py @@ -438,16 +438,31 @@ def get_transcendental_patterns(ops:tuple[Ops, ...], force_transcendental:bool) if Ops.SQRT not in ops or force_transcendental: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5)))) return PatternMatcher(pat) +def floordiv_to_idiv(d:UOp, a:UOp, b:UOp) -> UOp: + if (a.vmin >= 0 and b.vmin > 0) or (a.vmax <= 0 and b.vmax < 0): return a.alu(Ops.IDIV, b) + return a.alu(Ops.IDIV, b) - (a.alu(Ops.MOD, b).ne(0) & (a<0).ne(b<0)).cast(d.dtype) + +def floormod_to_mod(d:UOp, a:UOp, b:UOp) -> UOp: + if (a.vmin >= 0 and b.vmin > 0) or (a.vmax <= 0 and b.vmax < 0): return a.alu(Ops.MOD, b) + r = a.alu(Ops.MOD, b) + # use where instead of mul to avoid being fused into MULACC (which int64 long-decomp doesn't handle) + return r + (r.ne(0) & (a<0).ne(b<0)).where(b, b.const_like(0)) + powers_of_two: dict[int, int] = {2**i:i for i in range(64)} @functools.cache def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> PatternMatcher: - pat: 
list[tuple[UPat, Callable]] = [] + pat: list[tuple[UPat, Callable]] = [ + (UPat(Ops.FLOORDIV, name="d", src=(UPat.var("a"), UPat.var("b"))), floordiv_to_idiv), + (UPat(Ops.FLOORMOD, name="d", src=(UPat.var("a"), UPat.var("b"))), floormod_to_mod), + ] # no real hardware supports THREEFRY, but NullRenderer does if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32)) # MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends) if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0]))) # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1) - if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)] + # TODO: drop the x.vmin>=0 guard once UOp `%` lowers to FLOORMOD instead of MOD + if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), + lambda x,c: x & (c.arg-1) if c.arg in powers_of_two and x.vmin >= 0 else None)] if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(), lambda x,y: (x | y).logical_not())] # rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index cf96590681..042879772c 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -8,7 +8,8 @@ from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, DTypeLike, to_d from tinygrad.dtype import ConstFloat, PyConst, storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar from tinygrad.device import Buffer, MultiBuffer, canonicalize_device from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA -from tinygrad.helpers import PROFILE, 
dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CAPTURE_PROCESS_REPLAY +from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, floordiv, floormod, diskcache_put, to_function_name, cpu_profile, TracingKey +from tinygrad.helpers import VIZ, SPEC, CAPTURE_PROCESS_REPLAY from tinygrad.helpers import colored, ansilen, printable if TYPE_CHECKING: from tinygrad.renderer import Estimates @@ -805,7 +806,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): # is f a monotonically increasing function regards its input if self.op in GroupOp.Irreducible: return True if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing() - if self.op in (Ops.MUL, Ops.IDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing() + if self.op in (Ops.MUL, Ops.IDIV, Ops.FLOORDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing() return False # False if not sure def const_factor(self) -> int: """largest known int that divides self""" @@ -866,6 +867,15 @@ class UOp(OpMixin, metaclass=UOpMetaClass): assert isinstance(s0_vmin, int) and isinstance(s0_vmax, int) and isinstance(s1_vmin, int) and isinstance(s1_vmax, int) if s1_vmin*s1_vmax>0: return min(vals:=(cdiv(s0_vmin, s1_vmin), cdiv(s0_vmin, s1_vmax), cdiv(s0_vmax, s1_vmin), cdiv(s0_vmax, s1_vmax))), max(vals) + if self.op is Ops.FLOORDIV: + assert isinstance(s0_vmin, int) and isinstance(s0_vmax, int) and isinstance(s1_vmin, int) and isinstance(s1_vmax, int) + if s1_vmin*s1_vmax>0: return min(vals:=(s0_vmin//s1_vmin, s0_vmin//s1_vmax, s0_vmax//s1_vmin, s0_vmax//s1_vmax)), max(vals) + if self.op is Ops.FLOORMOD: + assert isinstance(s0_vmin, int) and isinstance(s0_vmax, int) and isinstance(s1_vmin, int) and isinstance(s1_vmax, int) + if (c:=s1_vmin) == s1_vmax > 0: return (s0_vmin%c, s0_vmax%c) if s0_vmin//c == s0_vmax//c else (0, c-1) + if (c:=s1_vmin) == s1_vmax < 0: return (s0_vmin%c, s0_vmax%c) if 
s0_vmin//c == s0_vmax//c else (c+1, 0) + if s1_vmin > 0: return (0, s1_vmax-1) + if s1_vmax < 0: return (s1_vmin+1, 0) if self.op is Ops.XOR and s1_vmin == s1_vmax == -1 and isinstance(s0_vmin, int) and isinstance(s0_vmax, int): return ~s0_vmax, ~s0_vmin if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax) if self.op is Ops.CMPLT: return (s0_vmax 1: diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 64db03a1fd..4ba226ec54 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -45,7 +45,7 @@ shared_spec = PatternMatcher([ (UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), dtype=dtypes.bool, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x.dtype.base == y.dtype.base), # and SHL/SHR, the shift distance can be an int (UPat((Ops.SHL, Ops.SHR), src=(UPat.var("x"), UPat.var("y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)), - (UPat((Ops.IDIV, Ops.MOD), name="x"), lambda x: None if dtypes.is_int(x.dtype) else False), + (UPat((Ops.IDIV, Ops.MOD, Ops.FLOORDIV, Ops.FLOORMOD), name="x"), lambda x: None if dtypes.is_int(x.dtype) else False), (UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)), # CAST @@ -76,7 +76,7 @@ movement_ops = PatternMatcher([ # inputs to movement ops (UPat((Ops.STACK, Ops.VCONST), dtype=dtypes.weakint), lambda: True), - (UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.weakint), lambda: True), + (UPat({Ops.ADD, Ops.MUL, Ops.IDIV, Ops.FLOORDIV}, dtype=dtypes.weakint), lambda: True), # AFTER on Movement Op, INDEX, BUFFER, COPY, or BITCAST (UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.INDEX, Ops.MULTI, Ops.CONTIGUOUS, Ops.BUFFER, Ops.BITCAST, Ops.COPY})),), @@ -282,7 +282,7 @@ full_spec = PatternMatcher([ # where on index in rhs position is fine (UPat(Ops.WHERE, dtype=dtypes.weakint, src=(UPat(dtype=dtypes.bool), UPat(), UPat(dtype=dtypes.weakint))), lambda: True), # allow index dtype on a restricted set of UOps - (UPat((Ops.ADD, Ops.MUL, 
Ops.MOD, Ops.IDIV, Ops.MAX, + (UPat((Ops.ADD, Ops.MUL, Ops.MOD, Ops.IDIV, Ops.FLOORDIV, Ops.FLOORMOD, Ops.MAX, Ops.SPECIAL, Ops.CAST, Ops.RANGE, Ops.VCONST, Ops.STACK), dtype=dtypes.weakint), lambda: True), # while BIND is being casted diff --git a/tinygrad/uop/validate.py b/tinygrad/uop/validate.py index 776460c93d..4463e943e5 100644 --- a/tinygrad/uop/validate.py +++ b/tinygrad/uop/validate.py @@ -10,13 +10,15 @@ if z3.get_version() < (4, 12, 4, 0): # IDIV is truncated division but z3 does euclidian division (floor if b>0 ceil otherwise); mod by power of two sometimes uses Ops.AND def z3_cdiv(a:z3.ArithRef, b:z3.ArithRef) -> z3.ArithRef:return z3.If((a<0), z3.If(0 z3.ArithRef: return z3.If(b > 0, a/b, (-a)/(-b)) def z3_xor(a:z3.ExprRef, b:z3.ExprRef) -> z3.ExprRef: if isinstance(a, z3.BoolRef): return a^b # x ^ -1 = -(x+1), i.e. bitwise NOT if isinstance(b, z3.IntNumRef) and b.as_long() == -1: return -(a+1) if isinstance(a, z3.IntNumRef) and a.as_long() == -1: return -(b+1) raise RuntimeError(f"z3 int XOR only supports XOR with -1, got {a=} {b=}") -z3_alu: dict[Ops, Callable[..., z3.ExprRef]] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, +z3_alu: dict[Ops, Callable[..., z3.ExprRef]] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, Ops.FLOORDIV: z3_floordiv, + Ops.FLOORMOD: lambda a,b: a-z3_floordiv(a,b)*b, Ops.SHR: lambda a,b: a/(2**b.as_long()), Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: z3.If, Ops.XOR: z3_xor, Ops.MAX: lambda a,b: z3.If(a tuple[z3.ArithRef, z3.BoolRef]: From cee17e0d2f6a7e2534e4181bdd12705cf6431216 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 5 May 2026 21:40:53 +0300 Subject: [PATCH 3/4] viz: fix diff color (#16045) --- tinygrad/viz/js/index.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 
5df0f87d52..f86606bc09 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -1079,7 +1079,7 @@ async function main() { metadata.appendChild(codeBlock(upat[1], "python", { loc:upat[0], wrap:true })); const diffCode = metadata.appendChild(document.createElement("pre")).appendChild(document.createElement("code")); for (const line of diff) { - diffCode.appendChild(colored([{st:line, color:line.startsWith("+") ? "#3aa56d" : line.startsWith("−") ? "#d14b4b" : "#f0f0f5"}])); + diffCode.appendChild(colored([{st:line, color:line.startsWith("+") ? "#3aa56d" : line.startsWith("-") ? "#d14b4b" : "#f0f0f5"}])); diffCode.appendChild(document.createElement("br")); } diffCode.className = "wrap"; From 5fa0016ffcea0457aaf4619c91e7a660eb4bacd4 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Tue, 5 May 2026 22:41:13 +0300 Subject: [PATCH 4/4] supports_exec_item -> supports_uop (#16033) --- test/unit/test_hcq_graph.py | 16 ++++++++-------- test/unit/test_metal_graph.py | 12 ++++++------ tinygrad/engine/jit.py | 8 ++++---- tinygrad/runtime/graph/hcq.py | 2 +- tinygrad/runtime/graph/metal.py | 4 ++-- 5 files changed, 21 insertions(+), 21 deletions(-) diff --git a/test/unit/test_hcq_graph.py b/test/unit/test_hcq_graph.py index d8f0322a4b..270aa988d4 100644 --- a/test/unit/test_hcq_graph.py +++ b/test/unit/test_hcq_graph.py @@ -11,7 +11,7 @@ from test.mockgpu.usb import MockUSB @unittest.skipUnless(issubclass(type(Device[Device.DEFAULT]), HCQCompiled), "HCQ device required to run") class TestHCQUnit(unittest.TestCase): @unittest.skipIf(Device.DEFAULT == "CPU", "requires non-CPU HCQ device") - def test_supports_exec_item(self): + def test_supports_uop(self): d0, cpu_dev = Device[Device.DEFAULT], Device["CPU"] @TinyJit @@ -20,23 +20,23 @@ class TestHCQUnit(unittest.TestCase): inp, inp_cpu = Tensor.randn(10, 10, device=Device.DEFAULT).realize(), Tensor.randn(10, 10, device="CPU").realize() for _ in range(5): f(inp, inp_cpu) - # 
construct minimal CALL UOps for supports_exec_item (graphs only see PROGRAMs after compile_linear) + # construct minimal CALL UOps for supports_uop (graphs only see PROGRAMs after compile_linear) gpu_call = UOp(Ops.PROGRAM).call(UOp.new_buffer(Device.DEFAULT, 1, dtypes.float)) cpu_call = UOp(Ops.PROGRAM).call(UOp.new_buffer("CPU", 1, dtypes.float)) gpu_devs = [d0] # local MMIO: GPU works alone and with CPU in batch (cpu_support=True) - assert HCQGraph.supports_exec_item(gpu_devs, gpu_call) is True - assert HCQGraph.supports_exec_item(gpu_devs, cpu_call) is True - assert HCQGraph.supports_exec_item(gpu_devs + [cpu_dev], gpu_call) is True + assert HCQGraph.supports_uop(gpu_devs, gpu_call) is True + assert HCQGraph.supports_uop(gpu_devs, cpu_call) is True + assert HCQGraph.supports_uop(gpu_devs + [cpu_dev], gpu_call) is True # USB MMIO: GPU-only still works, but CPU batching must be rejected (cpu_support=False) orig_view = d0.timeline_signal.base_buf.view try: d0.timeline_signal.base_buf.view = USBMMIOInterface(MockUSB(bytearray(256)), 0, 16, fmt='B') - assert HCQGraph.supports_exec_item(gpu_devs, gpu_call) is True - assert HCQGraph.supports_exec_item(gpu_devs, cpu_call) is False - assert HCQGraph.supports_exec_item(gpu_devs + [cpu_dev], gpu_call) is False + assert HCQGraph.supports_uop(gpu_devs, gpu_call) is True + assert HCQGraph.supports_uop(gpu_devs, cpu_call) is False + assert HCQGraph.supports_uop(gpu_devs + [cpu_dev], gpu_call) is False finally: d0.timeline_signal.base_buf.view = orig_view diff --git a/test/unit/test_metal_graph.py b/test/unit/test_metal_graph.py index 74c733ba90..3038dcfb34 100644 --- a/test/unit/test_metal_graph.py +++ b/test/unit/test_metal_graph.py @@ -27,18 +27,18 @@ class TestMetalGraph(unittest.TestCase): c.src = (MagicMock(op=Ops.PROGRAM),) + tuple(bufs) return c - def test_supports_exec_item_normal_offset(self): - assert self.MetalGraph.supports_exec_item([self.dev], self.call(self.metal_buf(0), self.metal_buf(100), 
self.metal_buf(0xFFFFFFFF))) is True + def test_supports_uop_normal_offset(self): + assert self.MetalGraph.supports_uop([self.dev], self.call(self.metal_buf(0), self.metal_buf(100), self.metal_buf(0xFFFFFFFF))) is True - def test_supports_exec_item_overflow_offset(self): - assert self.MetalGraph.supports_exec_item([self.dev], self.call(self.metal_buf(0), self.metal_buf(0x100000000))) is False + def test_supports_uop_overflow_offset(self): + assert self.MetalGraph.supports_uop([self.dev], self.call(self.metal_buf(0), self.metal_buf(0x100000000))) is False - def test_supports_exec_item_nonmetal_buf(self): + def test_supports_uop_nonmetal_buf(self): # non-BUFFER_VIEW ops should not be checked for offset buf = MagicMock() buf.op = Ops.BUFFER buf.device = Device.DEFAULT - self.MetalGraph.supports_exec_item([self.dev], self.call(buf)) + self.MetalGraph.supports_uop([self.dev], self.call(buf)) if __name__ == "__main__": unittest.main() diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 6f7fe24763..6c8d3fff3f 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -48,8 +48,8 @@ def graph_split_rewrite(linear:UOp, max_batch_size:int=0) -> UOp: devs = dedup([Device[x] for b in si.src[1:] if b.op is not Ops.BIND for x in (b.device if isinstance(b.device, tuple) else (b.device,))]) graph_t = graph_class(devs[0]) if devs[0].graph is not None else None - can_graph = graph_t is not None and graph_t.supports_exec_item(devs, si) - can_extend = can_graph and graph_t is not None and (not current_batch_devs or graph_t.supports_exec_item(current_batch_devs, si)) \ + can_graph = graph_t is not None and graph_t.supports_uop(devs, si) + can_extend = can_graph and graph_t is not None and (not current_batch_devs or graph_t.supports_uop(current_batch_devs, si)) \ and (max_batch_size == 0 or len(current_batch) < max_batch_size) if not can_extend and current_batch: flush_batch() @@ -166,13 +166,13 @@ class GraphRunner: for x in (b.device if isinstance(b.device, 
tuple) else (b.device,))]) @staticmethod - def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool: + def supports_uop(batch_devs:list[Compiled], new_call:UOp) -> bool: return new_call.src[0].op is Ops.PROGRAM and len(GraphRunner._all_devs(batch_devs, new_call)) == 1 # a marker for your graph supporting multiple devices of the same type class MultiGraphRunner(GraphRunner): @staticmethod - def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool: + def supports_uop(batch_devs:list[Compiled], new_call:UOp) -> bool: # Devices must be the same type return new_call.src[0].op in (Ops.PROGRAM, Ops.COPY) and len(dedup([type(d) for d in GraphRunner._all_devs(batch_devs, new_call)])) == 1 diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py index 2cc769f887..e4da62d609 100644 --- a/tinygrad/runtime/graph/hcq.py +++ b/tinygrad/runtime/graph/hcq.py @@ -316,7 +316,7 @@ class HCQGraph(MultiGraphRunner): for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferSpec(cpu_access=True)) @staticmethod - def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool: + def supports_uop(batch_devs:list[Compiled], new_call:UOp) -> bool: # Check if all devices are HCQ all_devs = cast(list[HCQCompiled], GraphRunner._all_devs(batch_devs, new_call)) if not all(issubclass(type(d), HCQCompiled) for d in all_devs): return False diff --git a/tinygrad/runtime/graph/metal.py b/tinygrad/runtime/graph/metal.py index b33305c607..25c89257e5 100644 --- a/tinygrad/runtime/graph/metal.py +++ b/tinygrad/runtime/graph/metal.py @@ -107,7 +107,7 @@ class MetalGraph(GraphRunner): self.collect_timestamps() @staticmethod - def supports_exec_item(batch_devs, new_call:UOp) -> bool: + def supports_uop(batch_devs, new_call:UOp) -> bool: # Metal ICB replay encodes offsets as uint32; reject if any Metal buffer offset exceeds 32-bit range. 
if any(b.op is Ops.BUFFER_VIEW and b.arg[1] * b.dtype.itemsize > 0xFFFFFFFF for b in new_call.src[1:]): return False - return GraphRunner.supports_exec_item(batch_devs, new_call) + return GraphRunner.supports_uop(batch_devs, new_call)