diff --git a/test/test_schedule.py b/test/test_schedule.py index abd9a71031..cb0ad9e3e8 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1,7 +1,6 @@ # this will be the new test_ops for the next level # schedule confirms the right things are capable of fusing # NOTE: this has overlap with external_test_opt.py -# ruff: noqa: E501 import unittest import numpy as np @@ -12,7 +11,6 @@ from tinygrad import nn, dtypes, Device, Tensor from tinygrad.device import is_dtype_supported from tinygrad.dtype import DType, ImageDType from tinygrad.shape.shapetracker import ShapeTracker -from tinygrad.shape.view import View from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, merge_views, GroupOp from tinygrad.codegen.symbolic import symbolic_simple from tinygrad.spec import type_verify, shape_spec @@ -1958,47 +1956,14 @@ class TestSwizzle(unittest.TestCase): t_np = (x.numpy()*y.numpy()).sum(axis=(0, 2)).reshape(1, 4, 1).transpose(0, 2, 1)+z.numpy() np.testing.assert_allclose(t.numpy(), t_np, atol=1e-6, rtol=1e-3) - @unittest.skip("this swizzle can't be decided after the ADD") + @unittest.skip("TODO: this swizzle isn't resolvable when there's a mask") def test_swizzle_failure_permute(self): - sink = UOp(Ops.SINK, dtypes.void, arg=None, src=( - UOp(Ops.STORE, dtypes.void, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float, arg=(20, 65), src=(UOp(Ops.DEVICE, arg="METAL"),)), - UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 65), strides=(0, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=( - UOp(Ops.ADD, dtypes.float, arg=None, src=( - x6:=UOp(Ops.MUL, dtypes.float, arg=None, src=( - UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float, arg=(8, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)), - x10:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(65, 1), offset=0, mask=None, contiguous=True),)), src=()),)), - UOp(Ops.WHERE, dtypes.float, arg=None, src=( - x12:=UOp(Ops.VALID, dtypes.bool, arg=None, src=( - UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(Ops.CONST, dtypes.float, arg=1.0, src=()), - x15:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=()),)),)), - UOp(Ops.WHERE, dtypes.float, arg=None, src=( - x12, - UOp(Ops.CONST, dtypes.float, arg=0.0003418803389649838, src=()), - x15,)),)), - x6,)),)), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=( - UOp(Ops.MUL, dtypes.float, arg=None, src=( - UOp(Ops.WHERE, dtypes.float, arg=None, src=( - x12, - UOp(Ops.CONST, dtypes.float, arg=-1.0, src=()), - x15,)), - UOp(Ops.MUL, dtypes.float, arg=None, src=( - UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float, arg=(2, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)), - x10,)), - UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(1, 89), offset=44, mask=None, contiguous=False),)), src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=( - UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float, arg=(4, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)), - UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(65, 45, 90), strides=(1, 0, 65), offset=0, mask=((0, 65), (0, 45), (0, 45)), contiguous=False), View(shape=(65, 4094), strides=(4050, 1), offset=0, mask=((0, 65), (0, 4050)), contiguous=False), View(shape=(1, 65, 46, 89), strides=(0, 4094, 89, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)),)),)),)) - ret = swizzle_rewrite(sink) - self.assertEqual(swizzle_cnt(ret), 0) + a = Tensor.empty(45,65).T.reshape(65,1,45).pad((None,None,(0,45))).expand(65,45,90) + b = Tensor.empty(45,65) + a_reduce = a.sum(axis=(2,), keepdim=True).sum(axis=(1,)) + b_reduce = b.sum(axis=(0,)) + t = a_reduce+b_reduce + with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1): run_schedule(check_schedule(t, 1)) def store_val(si:ScheduleItem): return si.ast.src[0].src[2] zero_pm = UPat(Ops.CONST, arg=0)