diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 3870c5581b..124bd77ca1 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -780,8 +780,8 @@ class TestLinearizerUOptimize(unittest.TestCase): k.hand_coded_optimizations() k.linearize() - store_val = [u.vin[-1] for u in k.uops if u.uop is UOps.STORE][0] - assert store_val.dtype == dtypes.float.vec(4) and store_val.uop != UOps.CAST + store_els = [u for u in k.uops if u.uop is UOps.STORE][0].vin[2:] + assert len(store_els) == 4 and all(el.dtype is dtypes.float for el in store_els) def test_grouped_store_locals_and_globals(self): if not Device[Device.DEFAULT].compiler.linearizer_opts.has_local or not Device[Device.DEFAULT].compiler.linearizer_opts.has_shared: diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index cf8be22bba..ef23437622 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -217,18 +217,16 @@ class UOpGraph: self.uops.remove(u) changed_something = True - # (recursively) remove childless uops - self.remove_childless() - # store float4 upcasts directly if possible - replaced_stores: Dict[UOp,UOp] = {} + replaced_stores: Dict[UOp,Tuple[UOp,...]] = {} for u in self.uops: if u.uop is not UOps.STORE or (val:=u.vin[-1]).uop is not UOps.CAST or cast(DType,val.dtype).count == 1: continue - if all(el.uop is UOps.GEP for el in val.vin): replaced_stores[u] = val.vin[0].vin[0] - elif all(el.uop is UOps.PHI for el in val.vin): replaced_stores[u] = phi_resolve_acc(val) + if all(el.uop is UOps.GEP for el in val.vin) or all(el.uop is UOps.PHI for el in val.vin): replaced_stores[u] = val.vin for prev,new in replaced_stores.items(): - self.uops.remove(prev.vin[-1]) # remove the old upcast - self.uops[self.uops.index(prev)].vin = (prev.vin[0],prev.vin[1],new) # replace with the float4 value + self.uops[self.uops.index(prev)].vin = (prev.vin[0],prev.vin[1],*new) # replace with the float4 elements + + # (recursively) remove childless uops + self.remove_childless() # add UOps.END* self.add_ends() diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 6f286b6882..ff3a4ec119 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -2,6 +2,7 @@ from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict, import math, functools from collections import defaultdict, Counter from tinygrad.codegen.linearizer import UOps, UOp +from tinygrad.codegen.uops import phi_resolve_acc from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps from tinygrad.helpers import prod, strip_parens, getenv from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType @@ -118,9 +119,14 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st kk("}") elif uop is UOps.STORE: assert vin[0].dtype is not None and vin[2].dtype is not None - if len(vin) > 3: kk(f"if ({r[vin[3]]}) {{") - kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)) - if len(vin) > 3: kk("}") + if len(vin) > 3 and len(set(u.uop for u in vin[2:])) == 1: + # find the vector's definition + val = r[phi_resolve_acc(vin[2])] if vin[2].uop is UOps.PHI else r[vin[2].vin[0]] + kk(lang.render_store(r[vin[0]], vin[0].dtype, val, vin[2].dtype.vec(len(vin[2:])), strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)) + else: + if len(vin) > 3: kk(f"if ({r[vin[3]]}) {{") + kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)) + if len(vin) > 3: kk("}") else: assert dtype is not None, f"None dtype for uop {uop}" if uop is UOps.LOOP: