mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
leave it to remove_childless try 1
This commit is contained in:
@@ -780,8 +780,8 @@ class TestLinearizerUOptimize(unittest.TestCase):
|
||||
k.hand_coded_optimizations()
|
||||
k.linearize()
|
||||
|
||||
store_val = [u.vin[-1] for u in k.uops if u.uop is UOps.STORE][0]
|
||||
assert store_val.dtype == dtypes.float.vec(4) and store_val.uop != UOps.CAST
|
||||
store_els = [u for u in k.uops if u.uop is UOps.STORE][0].vin[2:]
|
||||
assert len(store_els) == 4 and all(el.dtype is dtypes.float for el in store_els)
|
||||
|
||||
def test_grouped_store_locals_and_globals(self):
|
||||
if not Device[Device.DEFAULT].compiler.linearizer_opts.has_local or not Device[Device.DEFAULT].compiler.linearizer_opts.has_shared:
|
||||
|
||||
@@ -217,18 +217,16 @@ class UOpGraph:
|
||||
self.uops.remove(u)
|
||||
changed_something = True
|
||||
|
||||
# (recursively) remove childless uops
|
||||
self.remove_childless()
|
||||
|
||||
# store float4 upcasts directly if possible
|
||||
replaced_stores: Dict[UOp,UOp] = {}
|
||||
replaced_stores: Dict[UOp,Tuple[UOp,...]] = {}
|
||||
for u in self.uops:
|
||||
if u.uop is not UOps.STORE or (val:=u.vin[-1]).uop is not UOps.CAST or cast(DType,val.dtype).count == 1: continue
|
||||
if all(el.uop is UOps.GEP for el in val.vin): replaced_stores[u] = val.vin[0].vin[0]
|
||||
elif all(el.uop is UOps.PHI for el in val.vin): replaced_stores[u] = phi_resolve_acc(val)
|
||||
if all(el.uop is UOps.GEP for el in val.vin) or all(el.uop is UOps.PHI for el in val.vin): replaced_stores[u] = val.vin
|
||||
for prev,new in replaced_stores.items():
|
||||
self.uops.remove(prev.vin[-1]) # remove the old upcast
|
||||
self.uops[self.uops.index(prev)].vin = (prev.vin[0],prev.vin[1],new) # replace with the float4 value
|
||||
self.uops[self.uops.index(prev)].vin = (prev.vin[0],prev.vin[1],*new) # replace with the float4 elements
|
||||
|
||||
# (recursively) remove childless uops
|
||||
self.remove_childless()
|
||||
|
||||
# add UOps.END*
|
||||
self.add_ends()
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict,
|
||||
import math, functools
|
||||
from collections import defaultdict, Counter
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.codegen.uops import phi_resolve_acc
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.helpers import prod, strip_parens, getenv
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
|
||||
@@ -118,9 +119,14 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
|
||||
kk("}")
|
||||
elif uop is UOps.STORE:
|
||||
assert vin[0].dtype is not None and vin[2].dtype is not None
|
||||
if len(vin) > 3: kk(f"if ({r[vin[3]]}) {{")
|
||||
kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL))
|
||||
if len(vin) > 3: kk("}")
|
||||
if len(vin) > 3 and len(set(u.uop for u in vin[2:])) == 1:
|
||||
# find the vector's definition
|
||||
val = r[phi_resolve_acc(vin[2])] if vin[2].uop is UOps.PHI else r[vin[2].vin[0]]
|
||||
kk(lang.render_store(r[vin[0]], vin[0].dtype, val, vin[2].dtype.vec(len(vin[2:])), strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL))
|
||||
else:
|
||||
if len(vin) > 3: kk(f"if ({r[vin[3]]}) {{")
|
||||
kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL))
|
||||
if len(vin) > 3: kk("}")
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.LOOP:
|
||||
|
||||
Reference in New Issue
Block a user