From bb5ded85ccb167151fe6ceb91ee76bd35ecc1c99 Mon Sep 17 00:00:00 2001 From: eliotgolding <177857289+eliotgolding@users.noreply.github.com> Date: Tue, 4 Feb 2025 23:47:33 +0000 Subject: [PATCH] Don't rewrite idiv to rshift when numerator is negative (#8885) * more conditions for shift rewrite mul/idiv * make ptx test uint so the new condition is true * delete idiv test * rewrite to 0 is wrong for idiv, as denominator is cast to 0 before division * mul/div by 2**(large count) is unsupported anyway --- test/test_ops.py | 6 ++++++ test/test_uops.py | 12 ++++++------ tinygrad/codegen/rewriter.py | 4 ++-- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 85f3e31f48..091201244f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -725,6 +725,12 @@ class TestOps(unittest.TestCase): helper_test_op([], lambda: tor.__rshift__(2), lambda: ten.__rshift__(2).cast(dtypes.int32), forward_only=True) helper_test_op([], lambda: tor.bitwise_right_shift(2), lambda: ten.rshift(2).cast(dtypes.int32), forward_only=True) + def test_idiv_shift_rewrite_negative(self): + a = Tensor(-5).idiv(2).item() + b = Tensor(-5).contiguous().idiv(2).item() + self.assertEqual(a, b) + self.assertEqual(Tensor(-1).contiguous().idiv(4).item(), 0) # NOTE this is trunc-div behaviour + def test_sin(self): helper_test_op([(45,65)], lambda x: x.sin()) helper_test_op([()], lambda x: x.sin()) diff --git a/test/test_uops.py b/test/test_uops.py index ddf305cf03..23af1589c7 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -355,12 +355,12 @@ class TestAssembly(unittest.TestCase): self.assertIn(Ops.MUL, ops) def test_bitshift_right(self): - g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0) - c1 = UOp(Ops.CONST, dtypes.int, (), 2) - c2 = UOp(Ops.CONST, dtypes.int, (), 3) - l1 = UOp(Ops.LOAD, dtypes.int, (g1.index(c1),)) - a1 = UOp(Ops.IDIV, dtypes.int, (l1, c1)) - a2 = UOp(Ops.IDIV, dtypes.int, (l1, c2)) + g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(), (), 0) + c1 = UOp(Ops.CONST, dtypes.uint, (), 2) + c2 = UOp(Ops.CONST, dtypes.uint, (), 3) + l1 = UOp(Ops.LOAD, dtypes.uint, (g1.index(c1),)) + a1 = UOp(Ops.IDIV, dtypes.uint, (l1, c1)) + a2 = UOp(Ops.IDIV, dtypes.uint, (l1, c2)) uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer) Device[Device.DEFAULT].renderer.render("test", uops) ops = [x.op for x in uops] diff --git a/tinygrad/codegen/rewriter.py b/tinygrad/codegen/rewriter.py index adeb7d4b06..22da6d9871 100644 --- a/tinygrad/codegen/rewriter.py +++ b/tinygrad/codegen/rewriter.py @@ -3,7 +3,7 @@ from typing import Optional, Any, Callable import functools, itertools, operator from collections import defaultdict from tinygrad.dtype import dtypes, ImageDType, PtrDType -from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple +from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple, resolve from tinygrad.ops import graph_rewrite, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid, GroupOp from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES @@ -131,7 +131,7 @@ def get_late_rewrite_patterns(ops, force_transcendental=False): if Ops.SHL in ops and Ops.SHR in ops: pat += [ (UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << powers_of_two[c.arg] if c.arg in powers_of_two else None), - (UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: x >> powers_of_two[c.arg] if c.arg in powers_of_two else None) + (UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: x >> powers_of_two[c.arg] if c.arg in powers_of_two and resolve(x>=0,False) else None) ] if Ops.NEG in ops: pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]