Merge branch 'master' into move_gates_to_load_store

This commit is contained in:
George Hotz
2026-05-05 13:55:08 -07:00
committed by GitHub
17 changed files with 116 additions and 42 deletions

View File

@@ -333,7 +333,7 @@ jobs:
deps: testing_unit
python-version: '3.14'
- name: Test SPEC=2
run: SPEC=2 pytest --maxfail=10 -n auto --durations=30 test/unit test/backend test/opt --ignore test/backend/test_custom_kernel.py --ignore test/unit/test_hashing.py --timeout 60 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }}
# pytest keeps a single -k expression (a repeated -k overrides the earlier one), so both exclusions are combined with `and`
run: SPEC=2 pytest --maxfail=10 -n auto --durations=30 test/unit test/backend test/opt --ignore test/backend/test_custom_kernel.py --ignore test/unit/test_hashing.py --timeout 60 -k "not test_setitem_big and not test_conv2d_ceildiv_edge_case" --splits 2 --group ${{ matrix.group }}
fuzzing:
name: Fuzzing

View File

@@ -192,6 +192,20 @@ class TestVminVmaxDivMod(unittest.TestCase):
self.assertEqual(uop.vmin, 3)
self.assertEqual(uop.vmax, 6)
def test_vmin_vmax_floordiv_floormod(self):
# FLOORDIV/FLOORMOD ranges differ from IDIV/MOD when the dividend can be negative
x = UOp.variable('x', -7, 7)
floordiv = x.alu(Ops.FLOORDIV, x.const_like(3))
self.assertEqual(floordiv.vmin, -3)
self.assertEqual(floordiv.vmax, 2)
floormod = x.alu(Ops.FLOORMOD, x.const_like(3))
self.assertEqual(floormod.vmin, 0)
self.assertEqual(floormod.vmax, 2)
# negative const divisor: floormod range is [c+1, 0]
floormod_neg = x.alu(Ops.FLOORMOD, x.const_like(-3))
self.assertEqual(floormod_neg.vmin, -2)
self.assertEqual(floormod_neg.vmax, 0)
# cross 0
x = UOp.variable('x', -10, 10)
uop = x // -2

View File

@@ -46,6 +46,20 @@ class TestExecALU(unittest.TestCase):
self.assertEqual(exec_alu(Ops.IDIV, dtypes.int8, (7, -3)), -2)
self.assertEqual(exec_alu(Ops.IDIV, dtypes.int8, (-50, 6)), -8)
def test_floordiv(self):
# FLOORDIV rounds toward negative infinity: mixed-sign operands land one below the truncating IDIV result (7 and -3 give -3 here vs -2 for IDIV)
self.assertEqual(exec_alu(Ops.FLOORDIV, dtypes.int8, (8, 2)), 4)
self.assertEqual(exec_alu(Ops.FLOORDIV, dtypes.int8, (7, 3)), 2)
self.assertEqual(exec_alu(Ops.FLOORDIV, dtypes.int8, (7, -3)), -3)
self.assertEqual(exec_alu(Ops.FLOORDIV, dtypes.int8, (-7, 3)), -3)
self.assertEqual(exec_alu(Ops.FLOORDIV, dtypes.int8, (-50, 6)), -9)
def test_floormod(self):
self.assertEqual(exec_alu(Ops.FLOORMOD, dtypes.int8, (8, 2)), 0)
self.assertEqual(exec_alu(Ops.FLOORMOD, dtypes.int8, (7, 3)), 1)
self.assertEqual(exec_alu(Ops.FLOORMOD, dtypes.int8, (7, -3)), -2)
self.assertEqual(exec_alu(Ops.FLOORMOD, dtypes.int8, (-7, 3)), 2)
self.assertEqual(exec_alu(Ops.FLOORMOD, dtypes.int8, (-50, 6)), 4)
np.testing.assert_allclose(exec_alu(Ops.MUL, dtypes.float32, (7.0, exec_alu(Ops.RECIPROCAL, dtypes.float32, (3.0,)))), 2+(1.0/3.0))
np.testing.assert_allclose(exec_alu(Ops.MUL, dtypes.float32, (7.0, exec_alu(Ops.RECIPROCAL, dtypes.float32, (-3.0,)))), -2-(1.0/3.0))

View File

@@ -11,7 +11,7 @@ from test.mockgpu.usb import MockUSB
@unittest.skipUnless(issubclass(type(Device[Device.DEFAULT]), HCQCompiled), "HCQ device required to run")
class TestHCQUnit(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT == "CPU", "requires non-CPU HCQ device")
def test_supports_exec_item(self):
def test_supports_uop(self):
d0, cpu_dev = Device[Device.DEFAULT], Device["CPU"]
@TinyJit
@@ -20,23 +20,23 @@ class TestHCQUnit(unittest.TestCase):
inp, inp_cpu = Tensor.randn(10, 10, device=Device.DEFAULT).realize(), Tensor.randn(10, 10, device="CPU").realize()
for _ in range(5): f(inp, inp_cpu)
# construct minimal CALL UOps for supports_exec_item (graphs only see PROGRAMs after compile_linear)
# construct minimal CALL UOps for supports_uop (graphs only see PROGRAMs after compile_linear)
gpu_call = UOp(Ops.PROGRAM).call(UOp.new_buffer(Device.DEFAULT, 1, dtypes.float))
cpu_call = UOp(Ops.PROGRAM).call(UOp.new_buffer("CPU", 1, dtypes.float))
gpu_devs = [d0]
# local MMIO: GPU works alone and with CPU in batch (cpu_support=True)
assert HCQGraph.supports_exec_item(gpu_devs, gpu_call) is True
assert HCQGraph.supports_exec_item(gpu_devs, cpu_call) is True
assert HCQGraph.supports_exec_item(gpu_devs + [cpu_dev], gpu_call) is True
assert HCQGraph.supports_uop(gpu_devs, gpu_call) is True
assert HCQGraph.supports_uop(gpu_devs, cpu_call) is True
assert HCQGraph.supports_uop(gpu_devs + [cpu_dev], gpu_call) is True
# USB MMIO: GPU-only still works, but CPU batching must be rejected (cpu_support=False)
orig_view = d0.timeline_signal.base_buf.view
try:
d0.timeline_signal.base_buf.view = USBMMIOInterface(MockUSB(bytearray(256)), 0, 16, fmt='B')
assert HCQGraph.supports_exec_item(gpu_devs, gpu_call) is True
assert HCQGraph.supports_exec_item(gpu_devs, cpu_call) is False
assert HCQGraph.supports_exec_item(gpu_devs + [cpu_dev], gpu_call) is False
assert HCQGraph.supports_uop(gpu_devs, gpu_call) is True
assert HCQGraph.supports_uop(gpu_devs, cpu_call) is False
assert HCQGraph.supports_uop(gpu_devs + [cpu_dev], gpu_call) is False
finally:
d0.timeline_signal.base_buf.view = orig_view

View File

@@ -27,18 +27,18 @@ class TestMetalGraph(unittest.TestCase):
c.src = (MagicMock(op=Ops.PROGRAM),) + tuple(bufs)
return c
def test_supports_exec_item_normal_offset(self):
assert self.MetalGraph.supports_exec_item([self.dev], self.call(self.metal_buf(0), self.metal_buf(100), self.metal_buf(0xFFFFFFFF))) is True
def test_supports_uop_normal_offset(self):
assert self.MetalGraph.supports_uop([self.dev], self.call(self.metal_buf(0), self.metal_buf(100), self.metal_buf(0xFFFFFFFF))) is True
def test_supports_exec_item_overflow_offset(self):
assert self.MetalGraph.supports_exec_item([self.dev], self.call(self.metal_buf(0), self.metal_buf(0x100000000))) is False
def test_supports_uop_overflow_offset(self):
assert self.MetalGraph.supports_uop([self.dev], self.call(self.metal_buf(0), self.metal_buf(0x100000000))) is False
def test_supports_exec_item_nonmetal_buf(self):
def test_supports_uop_nonmetal_buf(self):
# non-BUFFER_VIEW ops should not be checked for offset
buf = MagicMock()
buf.op = Ops.BUFFER
buf.device = Device.DEFAULT
self.MetalGraph.supports_exec_item([self.dev], self.call(buf))
self.MetalGraph.supports_uop([self.dev], self.call(buf))
if __name__ == "__main__":
unittest.main()

View File

@@ -48,8 +48,8 @@ def graph_split_rewrite(linear:UOp, max_batch_size:int=0) -> UOp:
devs = dedup([Device[x] for b in si.src[1:] if b.op is not Ops.BIND for x in (b.device if isinstance(b.device, tuple) else (b.device,))])
graph_t = graph_class(devs[0]) if devs[0].graph is not None else None
can_graph = graph_t is not None and graph_t.supports_exec_item(devs, si)
can_extend = can_graph and graph_t is not None and (not current_batch_devs or graph_t.supports_exec_item(current_batch_devs, si)) \
can_graph = graph_t is not None and graph_t.supports_uop(devs, si)
can_extend = can_graph and graph_t is not None and (not current_batch_devs or graph_t.supports_uop(current_batch_devs, si)) \
and (max_batch_size == 0 or len(current_batch) < max_batch_size)
if not can_extend and current_batch: flush_batch()
@@ -166,13 +166,13 @@ class GraphRunner:
for x in (b.device if isinstance(b.device, tuple) else (b.device,))])
@staticmethod
def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool:
def supports_uop(batch_devs:list[Compiled], new_call:UOp) -> bool:
return new_call.src[0].op is Ops.PROGRAM and len(GraphRunner._all_devs(batch_devs, new_call)) == 1
# a marker for your graph supporting multiple devices of the same type
class MultiGraphRunner(GraphRunner):
@staticmethod
def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool:
def supports_uop(batch_devs:list[Compiled], new_call:UOp) -> bool:
# Devices must be the same type
return new_call.src[0].op in (Ops.PROGRAM, Ops.COPY) and len(dedup([type(d) for d in GraphRunner._all_devs(batch_devs, new_call)])) == 1

View File

@@ -64,6 +64,9 @@ def next_power2(x): return 1 if x == 0 else 1 << (x - 1).bit_length()
# cstyle div and mod
def cdiv(x:int, y:int) -> int: return abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0
def cmod(x:int, y:int) -> int: return x-cdiv(x,y)*y
# python floor div and mod
def floordiv(x:int, y:int) -> int:
  # floor division (rounds toward negative infinity); a zero divisor yields 0 instead of raising
  if y == 0: return 0
  return x//y
def floormod(x:int, y:int) -> int:
  # remainder matching floordiv, so floordiv(x,y)*y + floormod(x,y) == x; a zero divisor returns x unchanged
  quotient = floordiv(x, y)
  return x - quotient*y
def lo32(x:Any) -> Any: return x & 0xFFFFFFFF # Any is sint
def hi32(x:Any) -> Any: return x >> 32 # Any is sint
def data64(data:Any) -> tuple[Any, Any]: return (data >> 32, data & 0xFFFFFFFF) # Any is sint

View File

@@ -316,7 +316,7 @@ class HCQGraph(MultiGraphRunner):
for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferSpec(cpu_access=True))
@staticmethod
def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool:
def supports_uop(batch_devs:list[Compiled], new_call:UOp) -> bool:
# Check if all devices are HCQ
all_devs = cast(list[HCQCompiled], GraphRunner._all_devs(batch_devs, new_call))
if not all(issubclass(type(d), HCQCompiled) for d in all_devs): return False

View File

@@ -107,7 +107,7 @@ class MetalGraph(GraphRunner):
self.collect_timestamps()
@staticmethod
def supports_exec_item(batch_devs, new_call:UOp) -> bool:
def supports_uop(batch_devs, new_call:UOp) -> bool:
# Metal ICB replay encodes offsets as uint32; reject if any Metal buffer offset exceeds 32-bit range.
if any(b.op is Ops.BUFFER_VIEW and b.arg[1] * b.dtype.itemsize > 0xFFFFFFFF for b in new_call.src[1:]): return False
return GraphRunner.supports_exec_item(batch_devs, new_call)
return GraphRunner.supports_uop(batch_devs, new_call)

View File

@@ -1307,10 +1307,7 @@ class Tensor(OpMixin):
numerator, denominator = self._broadcasted(x, reverse)
if dtypes.is_int(numerator.dtype):
if rounding_mode == "trunc": return numerator.idiv(denominator)
if rounding_mode == "floor":
truncate_div, truncate_mod = numerator.idiv(denominator), numerator._binop(Ops.MOD, denominator, False)
opposite_sign = ((numerator>0)&(denominator<0)) | ((numerator<0)&(denominator>0))
return (opposite_sign&(truncate_mod!=0)).where(truncate_div-1, truncate_div)
if rounding_mode == "floor": return numerator._binop(Ops.FLOORDIV, denominator, False)
d = numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal()
output_dtype = numerator.dtype if dtypes.is_int(numerator.dtype) else d.dtype
if rounding_mode == "trunc": return d.trunc().cast(output_dtype)
@@ -1328,6 +1325,7 @@ class Tensor(OpMixin):
```
"""
a, b = self._broadcasted(x, reverse)
if dtypes.is_int(a.dtype): return a._binop(Ops.FLOORMOD, b, False)
return a - a.div(b, rounding_mode="floor") * b
def fmod(self, x:Tensor|ConstType) -> Tensor:

View File

@@ -65,6 +65,7 @@ class Ops(FastEnum):
CMPLT = auto(); CMPNE = auto(); CMPEQ = auto()
XOR = auto(); OR = auto(); AND = auto()
THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto()
FLOORDIV = auto(); FLOORMOD = auto()
# TernaryOps
WHERE = auto(); MULACC = auto()
@@ -110,7 +111,7 @@ class Ops(FastEnum):
class GroupOp:
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIPROCAL, Ops.NEG, Ops.TRUNC}
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ,
Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB, Ops.FDIV, Ops.POW}
Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB, Ops.FDIV, Ops.POW, Ops.FLOORDIV, Ops.FLOORMOD}
Ternary = {Ops.WHERE, Ops.MULACC}
ALU = set.union(Unary, Binary, Ternary)
@@ -137,6 +138,6 @@ class GroupOp:
Comparison = {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}
# do not preserve f(0) = 0
UnsafePad = {Ops.RECIPROCAL, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}
UnsafePad = {Ops.RECIPROCAL, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW, Ops.FLOORDIV}
All = set(Ops)

View File

@@ -438,16 +438,31 @@ def get_transcendental_patterns(ops:tuple[Ops, ...], force_transcendental:bool)
if Ops.SQRT not in ops or force_transcendental: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
return PatternMatcher(pat)
def floordiv_to_idiv(d:UOp, a:UOp, b:UOp) -> UOp:
  """
  Lower a FLOORDIV of `a` by `b` to truncating IDIV.

  `d` is the matched op; only its dtype is used, for the correction term.
  """
  trunc_div = a.alu(Ops.IDIV, b)
  # when the value ranges prove both operands share a sign, truncation and floor agree
  signs_known_equal = (a.vmin >= 0 and b.vmin > 0) or (a.vmax <= 0 and b.vmax < 0)
  if signs_known_equal: return trunc_div
  # otherwise subtract one exactly when there is a nonzero remainder and the operand signs differ
  needs_fixup = a.alu(Ops.MOD, b).ne(0) & (a<0).ne(b<0)
  return trunc_div - needs_fixup.cast(d.dtype)
def floormod_to_mod(d:UOp, a:UOp, b:UOp) -> UOp:
  """
  Lower a FLOORMOD of `a` by `b` to truncating MOD.

  `d` is the matched op; it is accepted for the pattern signature but not read here.
  """
  trunc_mod = a.alu(Ops.MOD, b)
  # when the value ranges prove both operands share a sign, the truncating remainder is already the floor remainder
  if (a.vmin >= 0 and b.vmin > 0) or (a.vmax <= 0 and b.vmax < 0): return trunc_mod
  # a nonzero remainder with differing operand signs must be shifted by the divisor
  needs_fixup = trunc_mod.ne(0) & (a<0).ne(b<0)
  # use where instead of mul to avoid being fused into MULACC (which int64 long-decomp doesn't handle)
  return trunc_mod + needs_fixup.where(b, b.const_like(0))
powers_of_two: dict[int, int] = {2**i:i for i in range(64)}
@functools.cache
def get_late_rewrite_patterns(ops:tuple[Ops, ...], disable_fast_idiv:bool) -> PatternMatcher:
pat: list[tuple[UPat, Callable]] = []
pat: list[tuple[UPat, Callable]] = [
(UPat(Ops.FLOORDIV, name="d", src=(UPat.var("a"), UPat.var("b"))), floordiv_to_idiv),
(UPat(Ops.FLOORMOD, name="d", src=(UPat.var("a"), UPat.var("b"))), floormod_to_mod),
]
# no real hardware supports THREEFRY, but NullRenderer does
if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32))
# MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends)
if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])))
# rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
# TODO: drop the x.vmin>=0 guard once UOp `%` lowers to FLOORMOD instead of MOD
if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"),
lambda x,c: x & (c.arg-1) if c.arg in powers_of_two and x.vmin >= 0 else None)]
if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(),
lambda x,y: (x | y).logical_not())]
# rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)

View File

@@ -8,7 +8,8 @@ from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, DTypeLike, to_d
from tinygrad.dtype import ConstFloat, PyConst, storage_fmt_for_dtype, to_storage_scalar, from_storage_scalar
from tinygrad.device import Buffer, MultiBuffer, canonicalize_device
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC, CAPTURE_PROCESS_REPLAY
from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, floordiv, floormod, diskcache_put, to_function_name, cpu_profile, TracingKey
from tinygrad.helpers import VIZ, SPEC, CAPTURE_PROCESS_REPLAY
from tinygrad.helpers import colored, ansilen, printable
if TYPE_CHECKING:
from tinygrad.renderer import Estimates
@@ -806,7 +807,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# is f a monotonically increasing function with regard to its input
if self.op in GroupOp.Irreducible: return True
if self.op is Ops.ADD: return self.src[0].is_increasing() and self.src[1].is_increasing()
if self.op in (Ops.MUL, Ops.IDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing()
if self.op in (Ops.MUL, Ops.IDIV, Ops.FLOORDIV) and self.src[1].op is Ops.CONST and self.src[1].arg >= 0: return self.src[0].is_increasing()
return False # False if not sure
def const_factor(self) -> int:
"""largest known int that divides self"""
@@ -867,6 +868,15 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
assert isinstance(s0_vmin, int) and isinstance(s0_vmax, int) and isinstance(s1_vmin, int) and isinstance(s1_vmax, int)
if s1_vmin*s1_vmax>0:
return min(vals:=(cdiv(s0_vmin, s1_vmin), cdiv(s0_vmin, s1_vmax), cdiv(s0_vmax, s1_vmin), cdiv(s0_vmax, s1_vmax))), max(vals)
if self.op is Ops.FLOORDIV:
assert isinstance(s0_vmin, int) and isinstance(s0_vmax, int) and isinstance(s1_vmin, int) and isinstance(s1_vmax, int)
if s1_vmin*s1_vmax>0: return min(vals:=(s0_vmin//s1_vmin, s0_vmin//s1_vmax, s0_vmax//s1_vmin, s0_vmax//s1_vmax)), max(vals)
if self.op is Ops.FLOORMOD:
assert isinstance(s0_vmin, int) and isinstance(s0_vmax, int) and isinstance(s1_vmin, int) and isinstance(s1_vmax, int)
if (c:=s1_vmin) == s1_vmax > 0: return (s0_vmin%c, s0_vmax%c) if s0_vmin//c == s0_vmax//c else (0, c-1)
if (c:=s1_vmin) == s1_vmax < 0: return (s0_vmin%c, s0_vmax%c) if s0_vmin//c == s0_vmax//c else (c+1, 0)
if s1_vmin > 0: return (0, s1_vmax-1)
if s1_vmax < 0: return (s1_vmin+1, 0)
if self.op is Ops.XOR and s1_vmin == s1_vmax == -1 and isinstance(s0_vmin, int) and isinstance(s0_vmax, int): return ~s0_vmax, ~s0_vmin
if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
if self.op is Ops.CMPLT: return (s0_vmax<s1_vmin, s0_vmin<s1_vmax)
@@ -1053,7 +1063,8 @@ python_alu: dict[Ops, Callable] = {
Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow, Ops.TRUNC: math.trunc,
Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
Ops.MOD: cmod, Ops.IDIV: cdiv, Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z, Ops.CMPEQ: operator.eq}
Ops.MOD: cmod, Ops.IDIV: cdiv, Ops.FLOORDIV: floordiv, Ops.FLOORMOD: floormod,
Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z, Ops.CMPEQ: operator.eq}
def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
if dtype.count > 1:

View File

@@ -45,7 +45,7 @@ shared_spec = PatternMatcher([
(UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), dtype=dtypes.bool, src=(UPat.var("x"), UPat.var("y"))), lambda x,y: x.dtype.base == y.dtype.base),
# and SHL/SHR, the shift distance can be an int
(UPat((Ops.SHL, Ops.SHR), src=(UPat.var("x"), UPat.var("y")), name="a"), lambda a,x,y: a.dtype == x.dtype and y.dtype in (x.dtype, dtypes.uint)),
(UPat((Ops.IDIV, Ops.MOD), name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
(UPat((Ops.IDIV, Ops.MOD, Ops.FLOORDIV, Ops.FLOORMOD), name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
(UPat(GroupOp.ALU, name="x"), lambda x: all(x.dtype.base == y.dtype.base for y in x.src)),
# CAST
@@ -76,7 +76,7 @@ movement_ops = PatternMatcher([
# inputs to movement ops
(UPat((Ops.STACK, Ops.VCONST), dtype=dtypes.weakint), lambda: True),
(UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.weakint), lambda: True),
(UPat({Ops.ADD, Ops.MUL, Ops.IDIV, Ops.FLOORDIV}, dtype=dtypes.weakint), lambda: True),
# AFTER on Movement Op, INDEX, BUFFER, COPY, or BITCAST
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.INDEX, Ops.MULTI, Ops.CONTIGUOUS, Ops.BUFFER, Ops.BITCAST, Ops.COPY})),),
@@ -283,7 +283,7 @@ full_spec = PatternMatcher([
# where on index in rhs position is fine
(UPat(Ops.WHERE, dtype=dtypes.weakint, src=(UPat(dtype=dtypes.bool), UPat(), UPat(dtype=dtypes.weakint))), lambda: True),
# allow index dtype on a restricted set of UOps
(UPat((Ops.ADD, Ops.MUL, Ops.MOD, Ops.IDIV, Ops.MAX,
(UPat((Ops.ADD, Ops.MUL, Ops.MOD, Ops.IDIV, Ops.FLOORDIV, Ops.FLOORMOD, Ops.MAX,
Ops.SPECIAL, Ops.CAST, Ops.RANGE, Ops.VCONST, Ops.STACK), dtype=dtypes.weakint), lambda: True),
# while BIND is being casted

View File

@@ -10,13 +10,15 @@ if z3.get_version() < (4, 12, 4, 0):
# IDIV is truncated division but z3 does Euclidean division (floor if b>0, ceil otherwise); mod by power of two sometimes uses Ops.AND
def z3_cdiv(a:z3.ArithRef, b:z3.ArithRef) -> z3.ArithRef:return z3.If((a<0), z3.If(0<b, (a+(b-1))/b, (a-(b+1))/b), a/b)
def z3_floordiv(a:z3.ArithRef, b:z3.ArithRef) -> z3.ArithRef:
  # z3's `/` floors only for a positive divisor, so for the other case both operands are negated to make the divisor positive
  quotient_pos_divisor = a/b
  quotient_neg_divisor = (-a)/(-b)
  return z3.If(b > 0, quotient_pos_divisor, quotient_neg_divisor)
def z3_xor(a:z3.ExprRef, b:z3.ExprRef) -> z3.ExprRef:
if isinstance(a, z3.BoolRef): return a^b
# x ^ -1 = -(x+1), i.e. bitwise NOT
if isinstance(b, z3.IntNumRef) and b.as_long() == -1: return -(a+1)
if isinstance(a, z3.IntNumRef) and a.as_long() == -1: return -(b+1)
raise RuntimeError(f"z3 int XOR only supports XOR with -1, got {a=} {b=}")
z3_alu: dict[Ops, Callable[..., z3.ExprRef]] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv,
z3_alu: dict[Ops, Callable[..., z3.ExprRef]] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, Ops.FLOORDIV: z3_floordiv,
Ops.FLOORMOD: lambda a,b: a-z3_floordiv(a,b)*b,
Ops.SHR: lambda a,b: a/(2**b.as_long()), Ops.SHL: lambda a,b: a*(2**b.as_long()),
Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: z3.If, Ops.XOR: z3_xor, Ops.MAX: lambda a,b: z3.If(a<b, b, a),}
def create_bounded(name:str, vmin:int, vmax:int, z3ctx:z3.Context) -> tuple[z3.ArithRef, z3.BoolRef]:

View File

@@ -122,6 +122,19 @@
font-size: 6px;
fill: #08090e;
}
g.tag.ref * {
cursor: pointer;
}
g.tag.ref circle {
fill: #9FDDE6;
stroke: #4a4b57;
}
g.tag.ref path {
fill: none;
stroke: #08090e;
stroke-width: 0.7;
stroke-linejoin: miter;
}
.label :is(text, p) {
font-weight: 350;
}

View File

@@ -56,9 +56,10 @@ function intersectRect(r1, r2) {
return {x:r1.x+dx*scale, y:r1.y+dy*scale};
}
function addTags(root) {
function addTags(root, path) {
root.selectAll("circle").data(d => [d]).join("circle").attr("r", 5);
root.selectAll("text").data(d => [d]).join("text").text(d => d).attr("dy", "0.35em");
if (path != null) root.selectAll("path").data(d => [d]).join("path").attr("d", path);
else root.selectAll("text").data(d => [d]).join("text").text(d => d).attr("dy", "0.35em");
}
const drawGraph = (data) => {
@@ -116,9 +117,11 @@ const drawGraph = (data) => {
});
addTags(nodes.selectAll("g.tag").data(d => d.tag != null ? [d] : []).join("g").attr("class", "tag")
.attr("transform", d => `translate(${-d.width/2+8}, ${-d.height/2+8})`).datum(e => e.tag));
addTags(nodes.selectAll("g.type").data(d => d.callNode ? [d] : []).join("g")
.attr("class", d => `tag ${d.collapsed ? 'collapsed' : 'expanded'}`)
addTags(nodes.selectAll("g.type").data(d => d.callNode ? [d] : []).join("g").attr("class", d => `tag ${d.collapsed ? 'collapsed' : 'expanded'}`)
.attr("transform", d => `translate(${-d.width/2}, ${0})`).datum(d => d.collapsed ? "+" : ""));
addTags(nodes.selectAll("g.ref").data(d => d.ref != null ? [d] : []).join("g").attr("class", "tag ref")
.attr("transform", d => `translate(${d.width/2-2}, ${-d.height/2+2})`).on("click", (e,d) => { e.stopPropagation(); switchCtx(d.ref); }),
"M-1.7 1.7 L1.7 -1.7 M-0.55 -1.7 H1.7 V0.55");
// draw edges
const line = d3.line().x(d => d.x).y(d => d.y).curve(d3.curveBasis), edges = g.edges();
d3.select("#edges").selectAll("path.edgePath").data(edges).join("path").attr("class", "edgePath").attr("d", (e) => {
@@ -1076,7 +1079,7 @@ async function main() {
metadata.appendChild(codeBlock(upat[1], "python", { loc:upat[0], wrap:true }));
const diffCode = metadata.appendChild(document.createElement("pre")).appendChild(document.createElement("code"));
for (const line of diff) {
diffCode.appendChild(colored([{st:line, color:line.startsWith("+") ? "#3aa56d" : line.startsWith("") ? "#d14b4b" : "#f0f0f5"}]));
diffCode.appendChild(colored([{st:line, color:line.startsWith("+") ? "#3aa56d" : line.startsWith("-") ? "#d14b4b" : "#f0f0f5"}]));
diffCode.appendChild(document.createElement("br"));
}
diffCode.className = "wrap";