From 6e0c16dfb030d6d49e697dbe6f1b074a3516625a Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 3 Jun 2024 19:39:02 +0800 Subject: [PATCH] cleanup render_reduceop (#4807) * update acc key * refactor return type * remove return type * run all reduces * set acc key [run_process_replay] * local_idxs are copied in render_reduceop [run_process_replay] --- tinygrad/codegen/linearizer.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index b75e2bf121..0698f618bc 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -96,8 +96,7 @@ class Linearizer(Kernel): acc_count = 0 for idx, valid, rep_idx in zip(e_idxs, e_valids, iter_idxs(expand_vars)): this_const, idx, valid = (invalid_value, NumNode(0), NumNode(1)) if valid.max == 0 else (const, idx, valid) - # todo: when multiple reduceops are supported, clearly disambiguate and test acc load keys are unique for each reduceop - key = f"{acc is not None}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501 + key = f"{'' if acc is None else self.reduceops.index(acc)}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501 if key not in self.load_cache: if acc is not None: self.load_cache[key] = self.uops.add(UOps.DEFINE_ACC, localtype, loop_ctx, (self.get_reduce_acc(acc), i, acc_count)) @@ -309,9 +308,9 @@ class Linearizer(Kernel): # end the late reduce loop self.load_cache.clear() - # all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have - # been rewritten with fake end_local_idxs. - return (accs, loaded_buffers, fake_reduce_idxs, local_idxs[:self.local_dims] + [NumNode(0) for i in range(self.group_for_reduces)], upcast_idxs) + # all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have + # been rewritten with fake end_local_idxs. + return local_idxs[:self.local_dims] + [NumNode(0) for _ in range(self.group_for_reduces)], upcast_idxs kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int) def linearize(self): @@ -401,9 +400,9 @@ class Linearizer(Kernel): fake_reduce_idxs = [x*0 for x in reduce_idxs] alias_buf_idxs = self.index_local_aliases(global_idxs,local_idxs,reduce_idxs,upcast_idxs,full_upcast_idxs) # render reduce op - for reduceop in [self.reduceop] if self.reduceop is not None else []: - accs,loaded_buffers,fake_reduce_idxs,local_idxs,upcast_idxs = \ - self.render_reduceop(reduceop,accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs,full_upcast_idxs,reduce_idxs,fake_reduce_idxs,alias_buf_idxs[reduceop]) + for reduceop in self.reduceops: + local_idxs, upcast_idxs = self.render_reduceop(reduceop,accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs, + full_upcast_idxs,reduce_idxs,fake_reduce_idxs,alias_buf_idxs[reduceop]) # load latebufs loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \