* add constant fold

* err, it's just zero folding

* self store fold + caching

* prints and more folds

* simpler winograd kernels

* remove childless uops
This commit is contained in:
George Hotz
2023-09-03 13:48:11 -07:00
committed by GitHub
parent e17b1af160
commit ed194a1d3b
6 changed files with 66 additions and 9 deletions

View File

@@ -55,5 +55,35 @@ class TestLinearizer(unittest.TestCase):
num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
assert num_ops <= 1, "more alu uops than needed"
def test_zero_fold(self):
if not isinstance(Device[Device.DEFAULT], Compiled):
self.skipTest("Only Compiled uses linearizer")
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = Tensor.stack([a, b])
ast = r.lazydata.op
r = r.realize() # realize an output buffer
k = Linearizer(ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
k.process()
k.upcast()
k.linearize()
num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
assert num_ops == 0, "more alu uops than needed"
@unittest.skip("constant folding not supported yet")
def test_constant_fold(self):
if not isinstance(Device[Device.DEFAULT], Compiled):
self.skipTest("Only Compiled uses linearizer")
a, b = Tensor(2), Tensor(3)
r = a * b
ast = r.lazydata.op
r = r.realize() # realize an output buffer
k = Linearizer(ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
k.process()
k.linearize()
num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]])
assert num_ops <= 0, "more load or alu uops than needed"
if __name__ == '__main__':
unittest.main()

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import List, Tuple, Any, Optional, cast, DefaultDict, NamedTuple, Dict, Iterator, Union, Sequence, Final
from typing import List, Tuple, Any, Optional, cast, DefaultDict, NamedTuple, Dict, Iterator, Union, Sequence, Final, Set
import itertools, math, functools
from collections import defaultdict
from enum import Enum, auto
@@ -93,8 +93,8 @@ class Linearizer(OptimizedKernel):
render_b:UOp = cast(UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
return self.uop(UOps.ALU, dtype, (a, render_b), op, cachable=True)
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.uop(UOps.SPECIAL, dtypes.int32, tuple(), self),
NumNode: lambda self, ops, ctx: ctx.uop(UOps.CONST, dtypes.int32, tuple(), self.b),
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.uop(UOps.SPECIAL, dtypes.int32, tuple(), self, cachable=True),
NumNode: lambda self, ops, ctx: ctx.uop(UOps.CONST, dtypes.int32, tuple(), self.b, cachable=True),
MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV),
ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD),
@@ -133,11 +133,11 @@ class Linearizer(OptimizedKernel):
assert valid.min == 1
self.load_cache[key] = self.uop(UOps.DEFINE_ACC, localtype, [], this_const)
elif this_const is not None:
self.load_cache[key] = self.uop(UOps.CONST, localtype, [], this_const)
self.load_cache[key] = self.uop(UOps.CONST, localtype, [], this_const, cachable=True)
if valid.min == 0 and valid.max == 1:
valid_rendered = valid.render(self.render_ops, self)
alt = self.uop(UOps.CONST, localtype, [], invalid_value)
self.load_cache[key] = self.uop(UOps.ALU, localtype, [valid_rendered, self.load_cache[key], alt], TernaryOps.WHERE)
alt = self.uop(UOps.CONST, localtype, [], invalid_value, cachable=True)
self.load_cache[key] = self.uop(UOps.ALU, localtype, [valid_rendered, self.load_cache[key], alt], TernaryOps.WHERE, cachable=True)
else:
self.load_cache[key] = self.uop(UOps.LOAD, localtype, [], MemOp(self.get_buffer_name(i), idx, self.bufs[i].__class__ is LocalBuffer, self.bufs[i].dtype, valid, invalid_value))
ret.append(self.uop(UOps.GEP, dtypes.float32, [self.load_cache[key]], expanded_nodes[dim].index(_idx[dim])) if localtype != dtypes.float else self.load_cache[key])
@@ -359,10 +359,35 @@ class Linearizer(OptimizedKernel):
# end the global (and maybe local) loop
self.uop(UOps.ENDLOOP, None, [], (loop_global_idxs+loop_local_idxs, "global+local") if not self.group_for_reduce else (loop_global_idxs, "global"))
# (recursively) remove childless uops
UOPS_WO_SIDE_EFFECTS = {UOps.CONST, UOps.ALU, UOps.LOAD, UOps.CAST, UOps.GEP}
while 1:
has_child: Set[UOp] = set()
for ru in self.uops:
for vu in ru.vin:
has_child.add(vu)
nu: List[UOp] = [x for x in self.uops if x in has_child or x.uop not in UOPS_WO_SIDE_EFFECTS]
if len(nu) == len(self.uops): break
if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
self.uops = nu
return self
def uop(self, uop:UOps, dtype:Optional[DType], vin:Union[Tuple[UOp, ...], List[UOp]], arg:Any=None, cachable=False) -> UOp:
key = (uop, dtype, tuple(vin), arg)
if uop == UOps.STORE and len(vin) == 2 and vin[0] == vin[1]: return vin[0] # self store is noop
if uop == UOps.ALU:
# rewrites. NOTE: the rewritten NEG op is still around...
if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, [vin[0], vin[1].vin[0]], BinaryOps.SUB, cachable=cachable)
# constant folding
if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.uop(UOps.CONST, dtype, [], -vin[0].arg, cachable=True)
# zero folding
for x in [0,1]:
if arg == BinaryOps.ADD and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 1.0: return vin[1-x]
if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[x]
if arg == BinaryOps.SUB and vin[1].uop == UOps.CONST and vin[1].arg == 0.0: return vin[0]
if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
if cachable and key in self.saved_exprs: return self.saved_exprs[key]
self.uops.append(UOp(uop, dtype, tuple(vin), arg, len(self.uops)))
if DEBUG >= 4: print(self.uops[-1])

View File

@@ -105,7 +105,7 @@ class Interpreted:
if DEBUG >= 3: st = time.perf_counter()
ret = self.from_underlying(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
if output is not None and ret.dtype != output.dtype and UnaryOps.CAST in self.fxn_for_op: ret = self.from_underlying(self.fxn_for_op[UnaryOps.CAST](self.to_underlying(ret), (output.dtype, False))) # Do manual casting of ret if it does not match the required output dtype.
if DEBUG >= 3: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB {(time.perf_counter()-st)*1e3:7.2f} ms op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape) if hasattr(ret._buf, 'shape') else str(len(ret._buf)):30s} in({len(srcs)}):", list(set(x._buf.shape if hasattr(x._buf, 'shape') else len(x._buf) for x in srcs)), ast.arg if ast.arg is not None else "")
if DEBUG >= 5 or (self.buffer != FlopCounter and DEBUG >= 3): print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB {(time.perf_counter()-st)*1e3:7.2f} ms op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape) if hasattr(ret._buf, 'shape') else str(len(ret._buf)):30s} in({len(srcs)}):", list(set(x._buf.shape if hasattr(x._buf, 'shape') else len(x._buf) for x in srcs)), ast.arg if ast.arg is not None else "")
if not created_context: context[ast] = ret
if output is not None and output.output_buffer is not None:
assert output.output_buffer.size == ret.size, output.output_buffer.dtype == ret.dtype

View File

@@ -173,7 +173,8 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
elif uop == UOps.ALU:
assert dtype is not None
val = lang.code_for_op[args](*[r[x] for x in vin])
if child_count[u] == 1: r[u] = val
assert child_count[u] != 0, f"childless ALU op found {u}"
if child_count[u] <= 1: r[u] = val
else:
r[u] = ssa('alu')
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {val};")

View File

@@ -114,7 +114,7 @@ def _reshape(view: View, new_shape:Tuple[int, ...]) -> Tuple[View, bool]:
new_view = View(new_shape)
if view.contiguous: return new_view, False # NOTE: if it's contiguous it can't have an offset
if (merged_view := merge_views(view, new_view)) is not None: return merged_view, False
if DEBUG >= 4: print(f"WARNING: creating new view with reshape {view} -> {new_shape}")
if DEBUG >= 5: print(f"WARNING: creating new view with reshape {view} -> {new_shape}")
return new_view, True
@functools.lru_cache(maxsize=None)

View File

@@ -603,6 +603,7 @@ class Tensor:
def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Sub.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x or reverse else self
def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor:
if x.__class__ is not Tensor and x == 0.0: return mlops.Zero.apply(self)
if x.__class__ is not Tensor and x == -1.0: return -self
return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self
def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) else self.mul(1/x)
def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor: