From abb10c83cd0fcdc5817243d01ea976928eb73101 Mon Sep 17 00:00:00 2001 From: qazal Date: Fri, 19 Apr 2024 07:18:21 +0300 Subject: [PATCH] tunable multi output fusion --- tinygrad/engine/schedule.py | 4 ++-- tinygrad/helpers.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 2ba0803f7c..e433191dc0 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import Tuple, List, Dict, Optional, Set, DefaultDict from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps from tinygrad.features.graph import log_lazybuffer, realized_lazybuffer -from tinygrad.helpers import GRAPH, DEBUG, GlobalCounters, prod, dedup, all_int, merge_dicts, getenv +from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, GlobalCounters, prod, dedup, all_int, merge_dicts, getenv from tinygrad.shape.symbolic import Variable from tinygrad.dtype import ImageDType, dtypes from tinygrad.lazy import LazyBuffer @@ -213,7 +213,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul output_groups: DefaultDict[Tuple, List[LazyBuffer]] = defaultdict(list) for r in realizes: if r.realized is not None or r.op is LoadOps.CONST or r in seen: continue - output_groups[(reduce_for_op[r], ) if r in reduce_for_op else (r, )].append(r) + output_groups[(reduce_for_op[r], ) if r in reduce_for_op and MULTIOUTPUT else (r, )].append(r) # preschedule all buffers in realizes prescheduled = {group[0]:_schedule_group(tuple(group), realizes, reduce_for_op) for group in output_groups.values()} diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 3cacf7dfa4..6de7afe507 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -97,6 +97,7 @@ class ContextVar: DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0) WINO, THREEFRY, CACHECOLLECTING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CACHECOLLECTING", 1) GRAPH, GRAPHPATH, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("RING", 1) +MULTIOUTPUT = ContextVar("MULTIOUTPUT", 1) # **************** global state Counters ****************