leave it to remove_childless try 1

This commit is contained in:
qazal
2024-03-03 16:17:49 +02:00
parent 5f728a6ab6
commit bf25e935f8
3 changed files with 17 additions and 13 deletions

View File

@@ -780,8 +780,8 @@ class TestLinearizerUOptimize(unittest.TestCase):
k.hand_coded_optimizations()
k.linearize()
store_val = [u.vin[-1] for u in k.uops if u.uop is UOps.STORE][0]
assert store_val.dtype == dtypes.float.vec(4) and store_val.uop != UOps.CAST
store_els = [u for u in k.uops if u.uop is UOps.STORE][0].vin[2:]
assert len(store_els) == 4 and all(el.dtype is dtypes.float for el in store_els)
def test_grouped_store_locals_and_globals(self):
if not Device[Device.DEFAULT].compiler.linearizer_opts.has_local or not Device[Device.DEFAULT].compiler.linearizer_opts.has_shared:

View File

@@ -217,18 +217,16 @@ class UOpGraph:
self.uops.remove(u)
changed_something = True
# (recursively) remove childless uops
self.remove_childless()
# store float4 upcasts directly if possible
replaced_stores: Dict[UOp,UOp] = {}
replaced_stores: Dict[UOp,Tuple[UOp,...]] = {}
for u in self.uops:
if u.uop is not UOps.STORE or (val:=u.vin[-1]).uop is not UOps.CAST or cast(DType,val.dtype).count == 1: continue
if all(el.uop is UOps.GEP for el in val.vin): replaced_stores[u] = val.vin[0].vin[0]
elif all(el.uop is UOps.PHI for el in val.vin): replaced_stores[u] = phi_resolve_acc(val)
if all(el.uop is UOps.GEP for el in val.vin) or all(el.uop is UOps.PHI for el in val.vin): replaced_stores[u] = val.vin
for prev,new in replaced_stores.items():
self.uops.remove(prev.vin[-1]) # remove the old upcast
self.uops[self.uops.index(prev)].vin = (prev.vin[0],prev.vin[1],new) # replace with the float4 value
self.uops[self.uops.index(prev)].vin = (prev.vin[0],prev.vin[1],*new) # replace with the float4 elements
# (recursively) remove childless uops
self.remove_childless()
# add UOps.END*
self.add_ends()

View File

@@ -2,6 +2,7 @@ from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict,
import math, functools
from collections import defaultdict, Counter
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.codegen.uops import phi_resolve_acc
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.helpers import prod, strip_parens, getenv
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
@@ -118,9 +119,14 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
kk("}")
elif uop is UOps.STORE:
assert vin[0].dtype is not None and vin[2].dtype is not None
if len(vin) > 3: kk(f"if ({r[vin[3]]}) {{")
kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL))
if len(vin) > 3: kk("}")
if len(vin) > 3 and len(set(u.uop for u in vin[2:])) == 1:
# find the vector's definition
val = r[phi_resolve_acc(vin[2])] if vin[2].uop is UOps.PHI else r[vin[2].vin[0]]
kk(lang.render_store(r[vin[0]], vin[0].dtype, val, vin[2].dtype.vec(len(vin[2:])), strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL))
else:
if len(vin) > 3: kk(f"if ({r[vin[3]]}) {{")
kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL))
if len(vin) > 3: kk("}")
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.LOOP: