use Ops.REDUCE (#9721)

* decrease bert python time [pr]

* order copies

* Revert "order copies"

This reverts commit 3f62c8693b.

* rewrite count

* Ops.REDUCE

* acc first in the add chain

* Fix tensor core acc

* arange patterns look good

* fix multireduce gate

* reduce rewrite rule

* bump that to 15 minutes

* multiwmma isn't fusing

* gep through wmma is gep pushing

* bump that timeout too, it's all env setup

* add failing test
This commit is contained in:
George Hotz
2025-04-04 10:14:34 +08:00
committed by GitHub
parent 949459fdd6
commit cac8bcf8b5
11 changed files with 115 additions and 43 deletions

View File

@@ -149,7 +149,7 @@ jobs:
torchbackend:
name: Torch Backend Tests
runs-on: ubuntu-latest
timeout-minutes: 10
timeout-minutes: 15
steps:
- name: Checkout Code
uses: actions/checkout@v4
@@ -186,7 +186,7 @@ jobs:
torchbackendmore:
name: Torch Backend Tests More
runs-on: ubuntu-latest
timeout-minutes: 10
timeout-minutes: 15
steps:
- name: Checkout Code
uses: actions/checkout@v4

View File

@@ -46,6 +46,8 @@ class TestArange(unittest.TestCase):
def test_complexity_w_group_top(self): return self.test_complexity([Opt(OptOps.GROUPTOP, 0, 16)], limit=106496)
def test_complexity_w_local(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16)], limit=0)
@unittest.skip("doesn't work yet. TODO: this absolutely should work")
def test_complexity_w_local_unroll4(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.UNROLL, 0, 4)], limit=0)
@unittest.skip("doesn't work yet")
def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.PADTO, axis=1, arg=32)])

View File

@@ -1825,8 +1825,9 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[]
apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]) -> list[Kernel]:
lins: list[Kernel] = []
outbufs = [real_bufs[x.src[0].arg] for x in realized_ast.src]
device = real_bufs[0].device
def get_prg(k:Kernel): return CompiledRunner(replace(k.to_program(), device=Device.DEFAULT))
def get_prg(k:Kernel): return CompiledRunner(replace(k.to_program(), device=device))
def check_opt(opts, create_k, expected_color_size):
k = create_k()

View File

@@ -120,17 +120,19 @@ class TestUOpsStats(unittest.TestCase):
# NOTE; ops also include indexing ops
assert expected_ops <= ops and ops <= expected_ops * 2
def test_simple_matmul(self):
a = Tensor.empty(1024,1024)
b = Tensor.empty(1024,1024)
def test_simple_matmul(self, M=1024, N=1024, K=1024):
a = Tensor.empty(M,N)
b = Tensor.empty(N,K)
c = a@b
ops, mem = get_stats(c)
expected_ops = c.numel() * 1024 * 2
expected_ops = c.numel() * N * 2
required_mem = a.nbytes() + b.nbytes() + c.nbytes()
assert expected_ops <= ops and ops <= expected_ops * 1.2
# NOTE: it's hard to assert on the memory here, all depends on caching
assert required_mem <= mem
def test_simple_matmul_8192(self): self.test_simple_matmul(8192, 8192, 8192)
#MULACC should have the same stats as MUL + ADD
def test_mulacc(self):
globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
@@ -154,7 +156,7 @@ class TestUOpsStats(unittest.TestCase):
self.assertEqual(flops_mem(uops), flops_mem(uops_fma))
N = 100
N = 64
@unittest.skipIf(getenv("PTX"), "wrong in PTX") # maybe?
class TestStatsOptimized(unittest.TestCase):
@classmethod
@@ -174,6 +176,14 @@ class TestStatsOptimized(unittest.TestCase):
self.check_gemm(p)
self.assertEqual(p.estimates.lds, 2*N*N*N*4 + 4*N*N)
def test_gemm_tc_unroll(self):
k = Kernel(self.ast_gemm)
if not k.apply_tensor_cores(): self.skipTest("no tensor cores")
k.apply_opt(Opt(OptOps.UNROLL, 0, 2))
p = k.to_program()
print(p.src)
self.check_gemm(p)
# this is a good lesson about why UPCASTing is a good idea
def test_gemm_one_upcasted(self):

View File

@@ -0,0 +1,28 @@
import unittest
from tinygrad import Tensor, Context, Device
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
class TestLinearizerRewrite(unittest.TestCase):
def test_reduction(self):
t = Tensor.ones((64,64), device="NULL").contiguous().realize()
out = (t*2).sum(axis=1)
with Context(SPLIT_REDUCEOP=0, DEVECTORIZE=0):
si = out.schedule()[-1]
k = Kernel(si.ast, Device["CPU"].renderer)
k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
prg = k.to_program()
print(prg.src)
def test_arange(self):
out = Tensor.arange(32, device="NULL")
with Context(SPLIT_REDUCEOP=0, DEVECTORIZE=0):
si = out.schedule()[-1]
k = Kernel(si.ast, Device["CPU"].renderer)
k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
prg = k.to_program()
print(prg.src)
if __name__ == '__main__':
unittest.main()

View File

@@ -1,10 +1,10 @@
from typing import Optional, Any, Callable, cast
import functools, operator, itertools
from collections import defaultdict
from dataclasses import dataclass
from tinygrad.dtype import dtypes, ImageDType, PtrDType
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve
from tinygrad.ops import graph_rewrite, GroupOp
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve, graph_rewrite, GroupOp, identity_element
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat, gep_pushing
from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
from tinygrad.renderer import Renderer
@@ -280,6 +280,38 @@ pm_render = PatternMatcher([
lambda store,idx: UOp(Ops.STORE, src=store.src+(UOp(Ops.IF, src=(idx.src[2],)),))),
])
# *** Ops.REDUCE -> Ops.DEFINE_ACC+Ops.ASSIGN ***
@dataclass
class ReduceContext:
acc_num: int = 0
def reduce_to_acc(ctx:ReduceContext, red:UOp):
inp, reduce_range = red.src[0], red.src[1:]
# if this has a horizontal reduction component, do that first
if inp.dtype != red.dtype:
# NOTE: [0 1 2 3 4 5 6 7] -> [0+4, 1+5, 2+6, 3+7]
horizontal_amount = inp.dtype.count//red.dtype.count
lst = [inp.gep(tuple(range(i, inp.dtype.count, horizontal_amount))) for i in range(0, horizontal_amount)]
else:
lst = [inp]
assert all(x.dtype == red.dtype for x in lst), f"horizontal reduction mismatch {lst[0].dtype} != {red.dtype}"
# if we have a range
if len(reduce_range) != 0:
acc = UOp(Ops.DEFINE_ACC, red.dtype, (red.const_like(identity_element(red.arg, red.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,))
lst = [acc] + lst # put acc as the first element
ctx.acc_num += 1
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
return acc.assign(ret) if len(reduce_range) != 0 else ret
pm_reduce = PatternMatcher([
# REDUCE -> DEFINE_ACC+ASSIGN
(UPat(Ops.REDUCE, name="red"), reduce_to_acc),
# tensor core built in accumulate
(UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
])
# *** uop graph ***
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
@@ -287,6 +319,9 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
# remove reduce
sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce")
# devectorize is optional
if DEVECTORIZE >= 2: sink = graph_rewrite(sink, sym+load_store_folding+load_store_indexing, ctx=opts)
elif DEVECTORIZE: sink = graph_rewrite(sink, sym+devectorize+load_store_folding+correct_load_store+load_store_indexing, ctx=opts)

View File

@@ -1,9 +1,9 @@
# the job of the lowerer is to do indexing
import functools, itertools, operator, math
import itertools, operator, math
from dataclasses import dataclass
from typing import cast
from tinygrad.dtype import dtypes, PtrDType, least_upper_dtype
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, sint_to_uop
from tinygrad.renderer import Renderer
from tinygrad.helpers import all_int, prod, partition, flatten, unwrap, QUANTIZE
from tinygrad.codegen.expander import expand_rewrite
@@ -116,17 +116,10 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand} for {x.axis_arg}"
alu_op: Ops = x.arg[0]
ret = x.src[0]
# create acc
acc = UOp(Ops.DEFINE_ACC, x.dtype, (x.const_like(identity_element(alu_op, x.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,))
ctx.acc_num += 1
if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [acc]+[ret.gep(i) for i in range(ret.dtype.count)])
else:
ret = acc.alu(alu_op, ret)
if not len(reduce_range): return ret
# create ACC and assign
return acc.assign(ret)
# REDUCE supports both "horizonal" reduction and range reduction. the horizonal elements are taken in the nearest group
return UOp(Ops.REDUCE, x.dtype, (ret,)+tuple(reduce_range), alu_op)
def lower_load_store(ctx: IndexContext, x: UOp):
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
@@ -135,8 +128,8 @@ def lower_load_store(ctx: IndexContext, x: UOp):
barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else ()
return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + barrier)
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
if cast(PtrDType, x.src[0].dtype).local and x.src[2].op is Ops.ASSIGN:
reduce_input = x.src[2].src[1].src[1] if x.src[2].src[1].src[1] is not x.src[2].src[0] else x.src[2].src[1].src[0]
if cast(PtrDType, x.src[0].dtype).local and x.src[2].op is Ops.REDUCE:
reduce_input = x.src[2].src[0]
store_back = reduce_input.op is Ops.LOAD and cast(PtrDType, reduce_input.src[0].dtype).local
else: store_back = False
# NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes

View File

@@ -172,6 +172,19 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split
if which is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd
return rem//(c//gcd)+quo
def gep_through_wmma(gep:UOp, wmma:UOp):
out_sz = prod(x[1] for x in wmma.arg[6][-1])
wmma_idxs = gep.arg[::out_sz]
for i in range(out_sz):
if tuple(x-i for x in gep.arg[i::out_sz]) != wmma_idxs: return None
tsrcs = []
for s,sz in zip(wmma.src, wmma.arg[6]):
src_args = []
ssz = prod(x[1] for x in sz)
for w in wmma_idxs: src_args += list(range((w//out_sz)*ssz, (w//out_sz)*ssz + ssz))
tsrcs.append(s.gep(tuple(src_args)))
return UOp(Ops.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)
gep_pushing = PatternMatcher([
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
@@ -193,6 +206,8 @@ gep_pushing = PatternMatcher([
if not isinstance(x.dtype, PtrDType) else None),
# VECTORIZE on same GEP
(UPat(Ops.VECTORIZE, name="v", src=UPat(Ops.GEP, src=(UPat.var("x"),))), lambda v,x: x.gep(tuple(get_single_element(i.arg) for i in v.src))),
# push some GEPs through WMMAs
(UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
])
commutative = PatternMatcher([
@@ -395,19 +410,6 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
return ret
def gep_through_wmma(gep:UOp, wmma:UOp):
out_sz = prod(x[1] for x in wmma.arg[6][-1])
wmma_idxs = gep.arg[::out_sz]
for i in range(out_sz):
if tuple(x-i for x in gep.arg[i::out_sz]) != wmma_idxs: return None
tsrcs = []
for s,sz in zip(wmma.src, wmma.arg[6]):
src_args = []
ssz = prod(x[1] for x in sz)
for w in wmma_idxs: src_args += list(range((w//out_sz)*ssz, (w//out_sz)*ssz + ssz))
tsrcs.append(s.gep(tuple(src_args)))
return UOp(Ops.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)
acc_pat, rng_pat = UPat(Ops.DEFINE_ACC, name="acc"), UPat(Ops.RANGE, name="rng")
rng_aug = UPat.any(rng_pat, UPat.var("add")+rng_pat, UPat.var("mul")*rng_pat, UPat.var("add")+UPat.var("mul")*rng_pat)
@@ -431,14 +433,9 @@ sym = symbolic_flat+PatternMatcher([
# VECTORIZE void is SINK
(UPat(Ops.VECTORIZE, dtype=dtypes.void, src=UPat(Ops.BARRIER, name='b')), lambda b: b),
(UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
# push some GEPs through WMMAs
(UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
# tensor core with a 0 input is acc
(UPat(Ops.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
(UPat(Ops.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
# tensor core cleanups
(UPat.var("add") + UPat(Ops.WMMA, name="wmma"),
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
# threefry + remove longs
(UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32),
(UPat.var('x', dtypes.uint32).cast(dtypes.uint64).cast(dtypes.uint32), lambda x: x), # cast there and back is noop (TODO: genericize)
@@ -489,4 +486,7 @@ sym = symbolic_flat+PatternMatcher([
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")), lambda x,d: 1-d), # x*/(1+x) -> 1-1/(1+x)
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")*UPat.var("y")), lambda x,y,d: y*(1-d)),
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")+UPat.var("y")), lambda x,y,d: (1-d)+x*y),
# move const multiply after REDUCE. TODO: enable later
#(UPat(Ops.REDUCE, src=(UPat.var("x")*UPat.cvar("c", vec=False),), arg=Ops.ADD, name="r", allow_any_len=True),
# lambda x,c,r: r.replace(src=(x,)+r.src[1:])*c.arg),
])

View File

@@ -42,7 +42,7 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer]
if isinstance(ji.prg, ViewOp): continue
ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.dev
elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD", "NULL"}:
ji_graph_dev = Device[ji.bufs[0].device]
graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None

View File

@@ -715,6 +715,7 @@ class UPat(MathTrait):
self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
self.src: Any = None
assert self.name != "ctx", "UPat can't be named ctx"
assert dtype is None or isinstance(dtype, DType) or all(isinstance(x, DType) for x in dtype), f"invalid dtype {dtype}"
# try all permutations if it's a list
if isinstance(src, list): self.src = list(itertools.permutations(src)) if not all_same(src) else [src]

View File

@@ -10,9 +10,11 @@ class NullProgram:
return 1e-4
class NullAllocator(Allocator):
dev = None
def _alloc(self, size, options): pass
def _copyin(self, dest, src:memoryview): pass
def _copyout(self, dest:memoryview, src): pass
def _transfer(self, dest, src, sz:int, src_dev, dest_dev): pass
class NullGraph(MultiGraphRunner):
def __call__(self, input_rawbuffers, var_vals, wait=False) -> float|None: return 1e-3