mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
use Ops.REDUCE (#9721)
* decrease bert python time [pr]
* order copies
* Revert "order copies"
This reverts commit 3f62c8693b.
* rewrite count
* Ops.REDUCE
* acc first in the add chain
* Fix tensor core acc
* arange patterns look good
* fix multireduce gate
* reduce rewrite rule
* bump that to 15 minutes
* multiwmma isn't fusing
* gep through wmma is gep pushing
* bump that timeout too, it's all env setup
* add failing test
This commit is contained in:
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
@@ -149,7 +149,7 @@ jobs:
|
||||
torchbackend:
|
||||
name: Torch Backend Tests
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v4
|
||||
@@ -186,7 +186,7 @@ jobs:
|
||||
torchbackendmore:
|
||||
name: Torch Backend Tests More
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -46,6 +46,8 @@ class TestArange(unittest.TestCase):
|
||||
def test_complexity_w_group_top(self): return self.test_complexity([Opt(OptOps.GROUPTOP, 0, 16)], limit=106496)
|
||||
|
||||
def test_complexity_w_local(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16)], limit=0)
|
||||
@unittest.skip("doesn't work yet. TODO: this absolutely should work")
|
||||
def test_complexity_w_local_unroll4(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.UNROLL, 0, 4)], limit=0)
|
||||
@unittest.skip("doesn't work yet")
|
||||
def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.PADTO, axis=1, arg=32)])
|
||||
|
||||
|
||||
@@ -1825,8 +1825,9 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[]
|
||||
apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]) -> list[Kernel]:
|
||||
lins: list[Kernel] = []
|
||||
outbufs = [real_bufs[x.src[0].arg] for x in realized_ast.src]
|
||||
device = real_bufs[0].device
|
||||
|
||||
def get_prg(k:Kernel): return CompiledRunner(replace(k.to_program(), device=Device.DEFAULT))
|
||||
def get_prg(k:Kernel): return CompiledRunner(replace(k.to_program(), device=device))
|
||||
|
||||
def check_opt(opts, create_k, expected_color_size):
|
||||
k = create_k()
|
||||
|
||||
@@ -120,17 +120,19 @@ class TestUOpsStats(unittest.TestCase):
|
||||
# NOTE: ops also include indexing ops
|
||||
assert expected_ops <= ops and ops <= expected_ops * 2
|
||||
|
||||
def test_simple_matmul(self):
|
||||
a = Tensor.empty(1024,1024)
|
||||
b = Tensor.empty(1024,1024)
|
||||
def test_simple_matmul(self, M=1024, N=1024, K=1024):
|
||||
a = Tensor.empty(M,N)
|
||||
b = Tensor.empty(N,K)
|
||||
c = a@b
|
||||
ops, mem = get_stats(c)
|
||||
expected_ops = c.numel() * 1024 * 2
|
||||
expected_ops = c.numel() * N * 2
|
||||
required_mem = a.nbytes() + b.nbytes() + c.nbytes()
|
||||
assert expected_ops <= ops and ops <= expected_ops * 1.2
|
||||
# NOTE: it's hard to assert on the memory here, all depends on caching
|
||||
assert required_mem <= mem
|
||||
|
||||
def test_simple_matmul_8192(self): self.test_simple_matmul(8192, 8192, 8192)
|
||||
|
||||
#MULACC should have the same stats as MUL + ADD
|
||||
def test_mulacc(self):
|
||||
globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
|
||||
@@ -154,7 +156,7 @@ class TestUOpsStats(unittest.TestCase):
|
||||
|
||||
self.assertEqual(flops_mem(uops), flops_mem(uops_fma))
|
||||
|
||||
N = 100
|
||||
N = 64
|
||||
@unittest.skipIf(getenv("PTX"), "wrong in PTX") # maybe?
|
||||
class TestStatsOptimized(unittest.TestCase):
|
||||
@classmethod
|
||||
@@ -174,6 +176,14 @@ class TestStatsOptimized(unittest.TestCase):
|
||||
self.check_gemm(p)
|
||||
self.assertEqual(p.estimates.lds, 2*N*N*N*4 + 4*N*N)
|
||||
|
||||
def test_gemm_tc_unroll(self):
  # GEMM kernel with tensor cores plus an UNROLL(0, 2) opt; skipped when tensor cores can't be applied.
  kernel = Kernel(self.ast_gemm)
  tc_applied = kernel.apply_tensor_cores()
  if not tc_applied:
    self.skipTest("no tensor cores")
  kernel.apply_opt(Opt(OptOps.UNROLL, 0, 2))
  program = kernel.to_program()
  # the generated source is printed to make failures of check_gemm easier to inspect
  print(program.src)
  self.check_gemm(program)
|
||||
|
||||
# this is a good lesson about why UPCASTing is a good idea
|
||||
|
||||
def test_gemm_one_upcasted(self):
|
||||
|
||||
28
test/unit/test_linearizer_rewrite.py
Normal file
28
test/unit/test_linearizer_rewrite.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, Context, Device
|
||||
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
|
||||
|
||||
class TestLinearizerRewrite(unittest.TestCase):
  """Smoke tests: lower the last scheduled kernel with UPCAST+UNROLL opts and print the generated source.

  Nothing is asserted about the output; a test fails only if scheduling or lowering raises.
  """

  def _lower_last_kernel(self, out):
    # Shared body of both tests: schedule `out`, build a Kernel from the last schedule item using the
    # CPU renderer, apply UPCAST(0, 4) then UNROLL(0, 4), and print the program source.
    # NOTE(review): the whole sequence is assumed to run inside the Context block -- confirm.
    with Context(SPLIT_REDUCEOP=0, DEVECTORIZE=0):
      kernel = Kernel(out.schedule()[-1].ast, Device["CPU"].renderer)
      for opt in (Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)):
        kernel.apply_opt(opt)
      print(kernel.to_program().src)

  def test_reduction(self):
    # the input is realized and the reduction expression is built before the Context is entered
    ones = Tensor.ones((64,64), device="NULL").contiguous().realize()
    self._lower_last_kernel((ones*2).sum(axis=1))

  def test_arange(self):
    self._lower_last_kernel(Tensor.arange(32, device="NULL"))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,10 +1,10 @@
|
||||
from typing import Optional, Any, Callable, cast
|
||||
import functools, operator, itertools
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.dtype import dtypes, ImageDType, PtrDType
|
||||
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve
|
||||
from tinygrad.ops import graph_rewrite, GroupOp
|
||||
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
|
||||
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, resolve, graph_rewrite, GroupOp, identity_element
|
||||
from tinygrad.codegen.symbolic import symbolic_simple, split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat, gep_pushing
|
||||
from tinygrad.helpers import getenv, flatten, TRANSCENDENTAL, AMX, prod, DEVECTORIZE
|
||||
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
from tinygrad.renderer import Renderer
|
||||
@@ -280,6 +280,38 @@ pm_render = PatternMatcher([
|
||||
lambda store,idx: UOp(Ops.STORE, src=store.src+(UOp(Ops.IF, src=(idx.src[2],)),))),
|
||||
])
|
||||
|
||||
# *** Ops.REDUCE -> Ops.DEFINE_ACC+Ops.ASSIGN ***
|
||||
|
||||
@dataclass
class ReduceContext:
  """Mutable context threaded through the REDUCE rewrite; it only tracks accumulator numbering."""
  # number given to the next DEFINE_ACC built by reduce_to_acc, which increments it after each use
  acc_num: int = 0
|
||||
|
||||
def reduce_to_acc(ctx:ReduceContext, red:UOp):
  """Rewrite one Ops.REDUCE into a chain of ALU ops, accumulating through DEFINE_ACC/ASSIGN when it has ranges.

  red.src[0] is the value being reduced and red.src[1:] are the ranges reduced over; red.arg is the ALU op.
  Returns the plain ALU chain when there are no ranges, otherwise `acc.assign(chain)` where the chain
  starts with the accumulator. ctx.acc_num numbers the accumulator and is incremented once per accumulator.
  """
  inp, reduce_range = red.src[0], red.src[1:]
  if inp.dtype == red.dtype:
    lst = [inp]
  else:
    # horizontal component: the input is wider than the output, so split it into `horizontal_amount`
    # strided slices (slice i takes elements i, i+horizontal_amount, ...) that are combined lane-wise below.
    # e.g. 8 input elements reduced to 4 lanes: lane k combines input elements 2k and 2k+1
    horizontal_amount = inp.dtype.count//red.dtype.count
    lst = [inp.gep(tuple(range(start, inp.dtype.count, horizontal_amount))) for start in range(horizontal_amount)]
  assert all(x.dtype == red.dtype for x in lst), f"horizontal reduction mismatch {lst[0].dtype} != {red.dtype}"

  def combine(lhs, rhs): return lhs.alu(red.arg, rhs)

  # no range: only the horizontal pieces (or the single input) need combining
  if len(reduce_range) == 0: return functools.reduce(combine, lst)

  # range reduction: the accumulator starts at the op's identity element and is the first term of the chain
  acc = UOp(Ops.DEFINE_ACC, red.dtype, (red.const_like(identity_element(red.arg, red.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,))
  ctx.acc_num += 1
  return acc.assign(functools.reduce(combine, [acc] + lst))
|
||||
|
||||
def _wmma_absorb_add(add:UOp, wmma:UOp):
  # rebuild the WMMA with `add` summed into its third source, keeping op, dtype and arg unchanged
  return UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)

# rules that remove Ops.REDUCE from the graph
pm_reduce = PatternMatcher([
  # REDUCE -> DEFINE_ACC+ASSIGN
  (UPat(Ops.REDUCE, name="red"), reduce_to_acc),
  # tensor core built in accumulate: WMMA + x becomes a WMMA whose third source includes x
  (UPat(Ops.WMMA, name="wmma") + UPat.var("add"), _wmma_absorb_add),
])
|
||||
|
||||
# *** uop graph ***
|
||||
|
||||
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
||||
@@ -287,6 +319,9 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
||||
supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
|
||||
extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
|
||||
|
||||
# remove reduce
|
||||
sink = graph_rewrite(sink, pm_reduce+gep_pushing, ctx=ReduceContext(), name="remove_reduce")
|
||||
|
||||
# devectorize is optional
|
||||
if DEVECTORIZE >= 2: sink = graph_rewrite(sink, sym+load_store_folding+load_store_indexing, ctx=opts)
|
||||
elif DEVECTORIZE: sink = graph_rewrite(sink, sym+devectorize+load_store_folding+correct_load_store+load_store_indexing, ctx=opts)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# the job of the lowerer is to do indexing
|
||||
import functools, itertools, operator, math
|
||||
import itertools, operator, math
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
from tinygrad.dtype import dtypes, PtrDType, least_upper_dtype
|
||||
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
|
||||
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, sint_to_uop
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.helpers import all_int, prod, partition, flatten, unwrap, QUANTIZE
|
||||
from tinygrad.codegen.expander import expand_rewrite
|
||||
@@ -116,17 +116,10 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
|
||||
assert all(x.op is Ops.UNROLL for x in reduce_expand), f"not all UNROLLS in {reduce_expand} for {x.axis_arg}"
|
||||
alu_op: Ops = x.arg[0]
|
||||
ret = x.src[0]
|
||||
# create acc
|
||||
acc = UOp(Ops.DEFINE_ACC, x.dtype, (x.const_like(identity_element(alu_op, x.dtype.scalar())),) + tuple(reduce_range), (ctx.acc_num,))
|
||||
ctx.acc_num += 1
|
||||
if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
|
||||
ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
|
||||
ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [acc]+[ret.gep(i) for i in range(ret.dtype.count)])
|
||||
else:
|
||||
ret = acc.alu(alu_op, ret)
|
||||
if not len(reduce_range): return ret
|
||||
# create ACC and assign
|
||||
return acc.assign(ret)
|
||||
# REDUCE supports both "horizontal" reduction and range reduction. the horizontal elements are taken in the nearest group
|
||||
return UOp(Ops.REDUCE, x.dtype, (ret,)+tuple(reduce_range), alu_op)
|
||||
|
||||
def lower_load_store(ctx: IndexContext, x: UOp):
|
||||
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
|
||||
@@ -135,8 +128,8 @@ def lower_load_store(ctx: IndexContext, x: UOp):
|
||||
barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else ()
|
||||
return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + barrier)
|
||||
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
|
||||
if cast(PtrDType, x.src[0].dtype).local and x.src[2].op is Ops.ASSIGN:
|
||||
reduce_input = x.src[2].src[1].src[1] if x.src[2].src[1].src[1] is not x.src[2].src[0] else x.src[2].src[1].src[0]
|
||||
if cast(PtrDType, x.src[0].dtype).local and x.src[2].op is Ops.REDUCE:
|
||||
reduce_input = x.src[2].src[0]
|
||||
store_back = reduce_input.op is Ops.LOAD and cast(PtrDType, reduce_input.src[0].dtype).local
|
||||
else: store_back = False
|
||||
# NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
|
||||
|
||||
@@ -172,6 +172,19 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split
|
||||
if which is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd
|
||||
return rem//(c//gcd)+quo
|
||||
|
||||
def gep_through_wmma(gep:UOp, wmma:UOp):
  """Push a GEP on a WMMA output into the WMMA's sources, or return None if the indices don't allow it.

  The GEP indices must be made of runs of `out_sz` consecutive values (each run is base, base+1, ...);
  otherwise the rewrite does not apply. On success a new WMMA with gep.dtype is returned whose sources
  are GEPs of the original sources.
  """
  # wmma.arg[6] is zipped with wmma.src below and its last entry sizes the output
  # NOTE(review): presumably a per-operand list of (axis, size) pairs -- confirm
  per_src_axes = wmma.arg[6]
  out_sz = prod(ax[1] for ax in per_src_axes[-1])
  wmma_idxs = gep.arg[::out_sz]
  # every offset within a run must equal the run's first index plus that offset
  if any(tuple(x-off for x in gep.arg[off::out_sz]) != wmma_idxs for off in range(out_sz)): return None
  tsrcs = []
  for src, axes in zip(wmma.src, per_src_axes):
    ssz = prod(ax[1] for ax in axes)
    # for each selected output run, take the matching block of ssz elements from this source
    picks = tuple(j for w in wmma_idxs for j in range((w//out_sz)*ssz, (w//out_sz)*ssz + ssz))
    tsrcs.append(src.gep(picks))
  return UOp(Ops.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)
|
||||
|
||||
gep_pushing = PatternMatcher([
|
||||
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
|
||||
@@ -193,6 +206,8 @@ gep_pushing = PatternMatcher([
|
||||
if not isinstance(x.dtype, PtrDType) else None),
|
||||
# VECTORIZE on same GEP
|
||||
(UPat(Ops.VECTORIZE, name="v", src=UPat(Ops.GEP, src=(UPat.var("x"),))), lambda v,x: x.gep(tuple(get_single_element(i.arg) for i in v.src))),
|
||||
# push some GEPs through WMMAs
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
|
||||
])
|
||||
|
||||
commutative = PatternMatcher([
|
||||
@@ -395,19 +410,6 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
|
||||
for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
|
||||
return ret
|
||||
|
||||
def gep_through_wmma(gep:UOp, wmma:UOp):
  """Push a GEP on a WMMA output into the WMMA's sources, or return None if the indices don't allow it.

  The GEP indices must be made of runs of `out_sz` consecutive values (each run is base, base+1, ...);
  otherwise the rewrite does not apply. On success a new WMMA with gep.dtype is returned whose sources
  are GEPs of the original sources.
  """
  # wmma.arg[6] is zipped with wmma.src below and its last entry sizes the output
  # NOTE(review): presumably a per-operand list of (axis, size) pairs -- confirm
  per_src_axes = wmma.arg[6]
  out_sz = prod(ax[1] for ax in per_src_axes[-1])
  wmma_idxs = gep.arg[::out_sz]
  # every offset within a run must equal the run's first index plus that offset
  if any(tuple(x-off for x in gep.arg[off::out_sz]) != wmma_idxs for off in range(out_sz)): return None
  tsrcs = []
  for src, axes in zip(wmma.src, per_src_axes):
    ssz = prod(ax[1] for ax in axes)
    # for each selected output run, take the matching block of ssz elements from this source
    picks = tuple(j for w in wmma_idxs for j in range((w//out_sz)*ssz, (w//out_sz)*ssz + ssz))
    tsrcs.append(src.gep(picks))
  return UOp(Ops.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)
|
||||
|
||||
acc_pat, rng_pat = UPat(Ops.DEFINE_ACC, name="acc"), UPat(Ops.RANGE, name="rng")
|
||||
rng_aug = UPat.any(rng_pat, UPat.var("add")+rng_pat, UPat.var("mul")*rng_pat, UPat.var("add")+UPat.var("mul")*rng_pat)
|
||||
|
||||
@@ -431,14 +433,9 @@ sym = symbolic_flat+PatternMatcher([
|
||||
# VECTORIZE void is SINK
|
||||
(UPat(Ops.VECTORIZE, dtype=dtypes.void, src=UPat(Ops.BARRIER, name='b')), lambda b: b),
|
||||
(UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
|
||||
# push some GEPs through WMMAs
|
||||
(UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
|
||||
# tensor core with a 0 input is acc
|
||||
(UPat(Ops.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
|
||||
(UPat(Ops.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
|
||||
# tensor core cleanups
|
||||
(UPat.var("add") + UPat(Ops.WMMA, name="wmma"),
|
||||
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
|
||||
# threefry + remove longs
|
||||
(UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32),
|
||||
(UPat.var('x', dtypes.uint32).cast(dtypes.uint64).cast(dtypes.uint32), lambda x: x), # cast there and back is noop (TODO: genericize)
|
||||
@@ -489,4 +486,7 @@ sym = symbolic_flat+PatternMatcher([
|
||||
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")), lambda x,d: 1-d), # x*/(1+x) -> 1-1/(1+x)
|
||||
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")*UPat.var("y")), lambda x,y,d: y*(1-d)),
|
||||
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")+UPat.var("y")), lambda x,y,d: (1-d)+x*y),
|
||||
# move const multiply after REDUCE. TODO: enable later
|
||||
#(UPat(Ops.REDUCE, src=(UPat.var("x")*UPat.cvar("c", vec=False),), arg=Ops.ADD, name="r", allow_any_len=True),
|
||||
# lambda x,c,r: r.replace(src=(x,)+r.src[1:])*c.arg),
|
||||
])
|
||||
|
||||
@@ -42,7 +42,7 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer]
|
||||
if isinstance(ji.prg, ViewOp): continue
|
||||
ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
|
||||
if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.dev
|
||||
elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
|
||||
elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD", "NULL"}:
|
||||
ji_graph_dev = Device[ji.bufs[0].device]
|
||||
|
||||
graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None
|
||||
|
||||
@@ -715,6 +715,7 @@ class UPat(MathTrait):
|
||||
self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
|
||||
self.src: Any = None
|
||||
assert self.name != "ctx", "UPat can't be named ctx"
|
||||
assert dtype is None or isinstance(dtype, DType) or all(isinstance(x, DType) for x in dtype), f"invalid dtype {dtype}"
|
||||
|
||||
# try all permutations if it's a list
|
||||
if isinstance(src, list): self.src = list(itertools.permutations(src)) if not all_same(src) else [src]
|
||||
|
||||
@@ -10,9 +10,11 @@ class NullProgram:
|
||||
return 1e-4
|
||||
|
||||
class NullAllocator(Allocator):
  """Allocator stub: allocation, copy-in, copy-out and transfer all do nothing and return None."""
  dev = None
  def _alloc(self, size, options): return None
  def _copyin(self, dest, src:memoryview): return None
  def _copyout(self, dest:memoryview, src): return None
  def _transfer(self, dest, src, sz:int, src_dev, dest_dev): return None
|
||||
|
||||
class NullGraph(MultiGraphRunner):
|
||||
def __call__(self, input_rawbuffers, var_vals, wait=False) -> float|None: return 1e-3
|
||||
|
||||
Reference in New Issue
Block a user