mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
zero fold (#1748)
* add constant fold
* err, it's just zero folding
* self store fold + caching
* prints and more folds
* simpler winograd kernels
* remove childless uops
This commit is contained in:
@@ -55,5 +55,35 @@ class TestLinearizer(unittest.TestCase):
|
||||
num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
|
||||
assert num_ops <= 1, "more alu uops than needed"
|
||||
|
||||
def test_zero_fold(self):
  # stacking two single-element tensors must linearize with zero ALU uops
  if not isinstance(Device[Device.DEFAULT], Compiled): self.skipTest("Only Compiled uses linearizer")

  first = Tensor.randn(1).realize()
  second = Tensor.randn(1).realize()
  stacked = Tensor.stack([first, second])
  ast = stacked.lazydata.op  # the AST is read before the tensor is realized
  stacked = stacked.realize()  # realize an output buffer

  k = Linearizer(ast, stacked.lazydata, Device[Device.DEFAULT].linearizer_opts)
  k.process()
  k.upcast()
  k.linearize()

  # every ALU uop that survived linearization counts against the test
  alu_uops = [u for u in k.uops if u.uop == UOps.ALU]
  num_ops = len(alu_uops)
  assert num_ops == 0, "more alu uops than needed"
|
||||
|
||||
@unittest.skip("constant folding not supported yet")
def test_constant_fold(self):
  # multiplying two constant tensors should need neither LOAD nor ALU uops
  if not isinstance(Device[Device.DEFAULT], Compiled): self.skipTest("Only Compiled uses linearizer")

  lhs = Tensor(2)
  rhs = Tensor(3)
  product = lhs * rhs
  ast = product.lazydata.op  # the AST is read before the tensor is realized
  product = product.realize()  # realize an output buffer

  k = Linearizer(ast, product.lazydata, Device[Device.DEFAULT].linearizer_opts)
  k.process()
  k.linearize()

  # uop kinds that should not appear in the linearized kernel
  unwanted = (UOps.LOAD, UOps.ALU)
  num_ops = sum(1 for u in k.uops if u.uop in unwanted)
  assert num_ops <= 0, "more load or alu uops than needed"
|
||||
|
||||
# allow running this test file directly as a script
if __name__ == "__main__":
  unittest.main()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from typing import List, Tuple, Any, Optional, cast, DefaultDict, NamedTuple, Dict, Iterator, Union, Sequence, Final
|
||||
from typing import List, Tuple, Any, Optional, cast, DefaultDict, NamedTuple, Dict, Iterator, Union, Sequence, Final, Set
|
||||
import itertools, math, functools
|
||||
from collections import defaultdict
|
||||
from enum import Enum, auto
|
||||
@@ -93,8 +93,8 @@ class Linearizer(OptimizedKernel):
|
||||
render_b:UOp = cast(UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
|
||||
return self.uop(UOps.ALU, dtype, (a, render_b), op, cachable=True)
|
||||
|
||||
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.uop(UOps.SPECIAL, dtypes.int32, tuple(), self),
|
||||
NumNode: lambda self, ops, ctx: ctx.uop(UOps.CONST, dtypes.int32, tuple(), self.b),
|
||||
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.uop(UOps.SPECIAL, dtypes.int32, tuple(), self, cachable=True),
|
||||
NumNode: lambda self, ops, ctx: ctx.uop(UOps.CONST, dtypes.int32, tuple(), self.b, cachable=True),
|
||||
MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
|
||||
DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV),
|
||||
ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD),
|
||||
@@ -133,11 +133,11 @@ class Linearizer(OptimizedKernel):
|
||||
assert valid.min == 1
|
||||
self.load_cache[key] = self.uop(UOps.DEFINE_ACC, localtype, [], this_const)
|
||||
elif this_const is not None:
|
||||
self.load_cache[key] = self.uop(UOps.CONST, localtype, [], this_const)
|
||||
self.load_cache[key] = self.uop(UOps.CONST, localtype, [], this_const, cachable=True)
|
||||
if valid.min == 0 and valid.max == 1:
|
||||
valid_rendered = valid.render(self.render_ops, self)
|
||||
alt = self.uop(UOps.CONST, localtype, [], invalid_value)
|
||||
self.load_cache[key] = self.uop(UOps.ALU, localtype, [valid_rendered, self.load_cache[key], alt], TernaryOps.WHERE)
|
||||
alt = self.uop(UOps.CONST, localtype, [], invalid_value, cachable=True)
|
||||
self.load_cache[key] = self.uop(UOps.ALU, localtype, [valid_rendered, self.load_cache[key], alt], TernaryOps.WHERE, cachable=True)
|
||||
else:
|
||||
self.load_cache[key] = self.uop(UOps.LOAD, localtype, [], MemOp(self.get_buffer_name(i), idx, self.bufs[i].__class__ is LocalBuffer, self.bufs[i].dtype, valid, invalid_value))
|
||||
ret.append(self.uop(UOps.GEP, dtypes.float32, [self.load_cache[key]], expanded_nodes[dim].index(_idx[dim])) if localtype != dtypes.float else self.load_cache[key])
|
||||
@@ -359,10 +359,35 @@ class Linearizer(OptimizedKernel):
|
||||
# end the global (and maybe local) loop
|
||||
self.uop(UOps.ENDLOOP, None, [], (loop_global_idxs+loop_local_idxs, "global+local") if not self.group_for_reduce else (loop_global_idxs, "global"))
|
||||
|
||||
# (recursively) remove childless uops
|
||||
UOPS_WO_SIDE_EFFECTS = {UOps.CONST, UOps.ALU, UOps.LOAD, UOps.CAST, UOps.GEP}
|
||||
while 1:
|
||||
has_child: Set[UOp] = set()
|
||||
for ru in self.uops:
|
||||
for vu in ru.vin:
|
||||
has_child.add(vu)
|
||||
nu: List[UOp] = [x for x in self.uops if x in has_child or x.uop not in UOPS_WO_SIDE_EFFECTS]
|
||||
if len(nu) == len(self.uops): break
|
||||
if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
|
||||
self.uops = nu
|
||||
|
||||
return self
|
||||
|
||||
def uop(self, uop:UOps, dtype:Optional[DType], vin:Union[Tuple[UOp, ...], List[UOp]], arg:Any=None, cachable=False) -> UOp:
|
||||
key = (uop, dtype, tuple(vin), arg)
|
||||
if uop == UOps.STORE and len(vin) == 2 and vin[0] == vin[1]: return vin[0] # self store is noop
|
||||
if uop == UOps.ALU:
|
||||
# rewrites. NOTE: the rewritten NEG op is still around...
|
||||
if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, [vin[0], vin[1].vin[0]], BinaryOps.SUB, cachable=cachable)
|
||||
# constant folding
|
||||
if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.uop(UOps.CONST, dtype, [], -vin[0].arg, cachable=True)
|
||||
# zero folding
|
||||
for x in [0,1]:
|
||||
if arg == BinaryOps.ADD and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
|
||||
if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 1.0: return vin[1-x]
|
||||
if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[x]
|
||||
if arg == BinaryOps.SUB and vin[1].uop == UOps.CONST and vin[1].arg == 0.0: return vin[0]
|
||||
if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
|
||||
if cachable and key in self.saved_exprs: return self.saved_exprs[key]
|
||||
self.uops.append(UOp(uop, dtype, tuple(vin), arg, len(self.uops)))
|
||||
if DEBUG >= 4: print(self.uops[-1])
|
||||
|
||||
@@ -105,7 +105,7 @@ class Interpreted:
|
||||
if DEBUG >= 3: st = time.perf_counter()
|
||||
ret = self.from_underlying(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
|
||||
if output is not None and ret.dtype != output.dtype and UnaryOps.CAST in self.fxn_for_op: ret = self.from_underlying(self.fxn_for_op[UnaryOps.CAST](self.to_underlying(ret), (output.dtype, False))) # Do manual casting of ret if it does not match the required output dtype.
|
||||
if DEBUG >= 3: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB {(time.perf_counter()-st)*1e3:7.2f} ms op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape) if hasattr(ret._buf, 'shape') else str(len(ret._buf)):30s} in({len(srcs)}):", list(set(x._buf.shape if hasattr(x._buf, 'shape') else len(x._buf) for x in srcs)), ast.arg if ast.arg is not None else "")
|
||||
if DEBUG >= 5 or (self.buffer != FlopCounter and DEBUG >= 3): print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB {(time.perf_counter()-st)*1e3:7.2f} ms op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape) if hasattr(ret._buf, 'shape') else str(len(ret._buf)):30s} in({len(srcs)}):", list(set(x._buf.shape if hasattr(x._buf, 'shape') else len(x._buf) for x in srcs)), ast.arg if ast.arg is not None else "")
|
||||
if not created_context: context[ast] = ret
|
||||
if output is not None and output.output_buffer is not None:
|
||||
assert output.output_buffer.size == ret.size, output.output_buffer.dtype == ret.dtype
|
||||
|
||||
@@ -173,7 +173,8 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
|
||||
elif uop == UOps.ALU:
|
||||
assert dtype is not None
|
||||
val = lang.code_for_op[args](*[r[x] for x in vin])
|
||||
if child_count[u] == 1: r[u] = val
|
||||
assert child_count[u] != 0, f"childless ALU op found {u}"
|
||||
if child_count[u] <= 1: r[u] = val
|
||||
else:
|
||||
r[u] = ssa('alu')
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {val};")
|
||||
|
||||
@@ -114,7 +114,7 @@ def _reshape(view: View, new_shape:Tuple[int, ...]) -> Tuple[View, bool]:
|
||||
new_view = View(new_shape)
|
||||
if view.contiguous: return new_view, False # NOTE: if it's contiguous it can't have an offset
|
||||
if (merged_view := merge_views(view, new_view)) is not None: return merged_view, False
|
||||
if DEBUG >= 4: print(f"WARNING: creating new view with reshape {view} -> {new_shape}")
|
||||
if DEBUG >= 5: print(f"WARNING: creating new view with reshape {view} -> {new_shape}")
|
||||
return new_view, True
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
|
||||
@@ -603,6 +603,7 @@ class Tensor:
|
||||
def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor:
  # a falsy non-Tensor operand (e.g. 0) without reverse leaves the tensor unchanged
  if x.__class__ is not Tensor and not x and not reverse: return self
  return mlops.Sub.apply(*self._broadcasted(x, reverse))
|
||||
def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor:
  # Tensor operands always go through the broadcasted Mul op
  if x.__class__ is Tensor: return mlops.Mul.apply(*self._broadcasted(x, reverse))
  # scalar fast paths: 0 -> Zero op, -1 -> negation, 1 -> self unchanged
  if x == 0.0: return mlops.Zero.apply(self)
  if x == -1.0: return -self
  if x != 1.0: return mlops.Mul.apply(*self._broadcasted(x, reverse))
  return self
|
||||
def div(self, x:Union[Tensor, float], reverse=False) -> Tensor:
  # a real Div op is used for Tensor operands, reversed division, a falsy divisor, or a non-float self
  needs_div = x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype)
  if needs_div: return mlops.Div.apply(*self._broadcasted(x, reverse))
  # otherwise divide by multiplying with the reciprocal of the scalar
  return self.mul(1/x)
|
||||
def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user