mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
things that are only used in one place don't belong in helpers [pr] (#6878)
* things that are only used in one place don't belong in helpers [pr] * pretty print moved
This commit is contained in:
@@ -3,8 +3,8 @@ from typing import Dict, Union, Tuple, Any, List, cast
|
||||
import functools, hashlib
|
||||
from enum import Enum, auto
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import dedup, pretty_print, prod
|
||||
from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, UOp, UOps
|
||||
from tinygrad.helpers import dedup, prod
|
||||
from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, UOp, UOps, pretty_print
|
||||
from tinygrad.dtype import ImageDType, PtrDType, dtypes, DType, ConstType
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest, struct, contextlib, tempfile, pathlib, json, time, atexit, random
|
||||
from tinygrad import Device, Tensor, dtypes, TinyJit
|
||||
from tinygrad.helpers import CI, getenv, Context, ProfileLogger
|
||||
from tinygrad.device import Buffer, BufferOptions, HCQCompiled
|
||||
from tinygrad.helpers import CI, getenv, Context
|
||||
from tinygrad.device import Buffer, BufferOptions, ProfileLogger, HCQCompiled
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import get_runner
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import gzip, unittest
|
||||
from PIL import Image
|
||||
from tinygrad.helpers import Context, ContextVar
|
||||
from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, get_contraction, get_shape
|
||||
from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv
|
||||
from tinygrad.tensor import get_shape
|
||||
from tinygrad.codegen.lowerer import get_contraction
|
||||
from tinygrad.shape.symbolic import Variable, NumNode
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -10,13 +10,13 @@ from tinygrad.ops import resolve
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.renderer import Renderer, TensorCore, Program
|
||||
from tinygrad.dtype import ImageDType, PtrDType
|
||||
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, get_contraction, to_function_name, diskcache_put
|
||||
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put
|
||||
from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite
|
||||
from tinygrad.codegen.lowerer import ast_to_uop
|
||||
from tinygrad.codegen.lowerer import ast_to_uop, get_contraction
|
||||
|
||||
class OptOps(Enum):
|
||||
TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# the job of the lowerer is to do indexing
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import functools, itertools, operator
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple, cast, Optional
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, variable_to_uop
|
||||
@@ -8,7 +8,14 @@ from tinygrad.shape.symbolic import sint
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat, resolve
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.helpers import all_int, get_contraction, prod, partition, flatten
|
||||
from tinygrad.helpers import all_int, prod, partition, flatten
|
||||
|
||||
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
|
||||
def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
|
||||
acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
|
||||
try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
|
||||
except ValueError: return None
|
||||
return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
|
||||
|
||||
# ***** indexing *****
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from __future__ import annotations
|
||||
import multiprocessing, decimal, statistics, random
|
||||
import multiprocessing, decimal, statistics, random, json
|
||||
from dataclasses import dataclass, replace
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type, Iterator
|
||||
from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type, Iterator, Union
|
||||
import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib, array
|
||||
from tinygrad.helpers import SAVE_SCHEDULE, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE
|
||||
from tinygrad.helpers import SAVE_SCHEDULE, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILEPATH, PROFILE
|
||||
from tinygrad.dtype import DType, ImageDType
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
@@ -495,6 +495,44 @@ class HCQProgram:
|
||||
if wait: self.device.timeline_signal.wait(self.device.timeline_value - 1)
|
||||
return (float(sig_en.timestamp - sig_st.timestamp) / 1e6) if wait else None
|
||||
|
||||
class ProfileLogger:
|
||||
writers: int = 0
|
||||
mjson: List[Dict] = []
|
||||
actors: Dict[Union[str, Tuple[str, str]], int] = {}
|
||||
|
||||
def __init__(self): self.events, self.deps, ProfileLogger.writers = [], [], ProfileLogger.writers + 1
|
||||
|
||||
def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None, args=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor, args)]
|
||||
|
||||
def _ensure_actor(self, actor_name, subactor_name):
|
||||
if actor_name not in self.actors:
|
||||
self.actors[actor_name] = (pid:=len(self.actors))
|
||||
self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
|
||||
|
||||
if (subactor_key:=(actor_name,subactor_name)) not in self.actors:
|
||||
self.actors[subactor_key] = (tid:=len(self.actors))
|
||||
self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
|
||||
|
||||
return self.actors[actor_name], self.actors.get(subactor_key, -1)
|
||||
|
||||
def __del__(self):
|
||||
# perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
|
||||
for name, st, et, actor_name, subactor_name, args in self.events:
|
||||
pid, tid = self._ensure_actor(actor_name,subactor_name)
|
||||
args = {k: (v if v.__class__ is str else v(et-st)) for k, v in args.items()} if args is not None else None
|
||||
self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts": st, "dur": et-st, "args": args})
|
||||
|
||||
for en,st,dep_actor_name,dep_subactor_name,actor_name,subactor_name in self.deps:
|
||||
dep_pid, dep_tid = self._ensure_actor(dep_actor_name,dep_subactor_name)
|
||||
pid, tid = self._ensure_actor(actor_name,subactor_name)
|
||||
self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": len(self.mjson), "ts": en, "bp": "e"})
|
||||
self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": len(self.mjson)-1, "ts": st, "bp": "e"})
|
||||
|
||||
ProfileLogger.writers -= 1
|
||||
if ProfileLogger.writers == 0 and len(self.mjson) > 0:
|
||||
with open(PROFILEPATH.value, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
|
||||
print(f"Saved profile to {PROFILEPATH.value}. Use https://ui.perfetto.dev/ to open it.")
|
||||
|
||||
class HCQCompiled(Compiled):
|
||||
"""
|
||||
A base class for devices compatible with the HCQ (Hardware Command Queue) API.
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from __future__ import annotations
|
||||
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip
|
||||
import itertools, urllib.request, subprocess, shutil, math, json, contextvars, types, copyreg, inspect, importlib
|
||||
import urllib.request, subprocess, shutil, math, contextvars, types, copyreg, inspect, importlib
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
|
||||
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
|
||||
from typing_extensions import TypeGuard
|
||||
from tinygrad.shape.shapetracker import sint
|
||||
|
||||
T = TypeVar("T")
|
||||
U = TypeVar("U")
|
||||
@@ -68,21 +67,6 @@ def get_child(obj, key):
|
||||
else: obj = getattr(obj, k)
|
||||
return obj
|
||||
|
||||
def get_shape(x) -> Tuple[int, ...]:
|
||||
if not hasattr(x, "__len__") or not hasattr(x, "__getitem__") or isinstance(x, str): return ()
|
||||
if (aapi := (hasattr(x, "shape") and x.shape == ())): return ()
|
||||
subs = [get_shape(xi) for xi in x]
|
||||
if not all_same(subs): raise ValueError(f"inhomogeneous shape from {x}")
|
||||
slen = 1 if aapi else len(subs)
|
||||
return (slen,) + (subs[0] if subs else ())
|
||||
|
||||
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
|
||||
def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
|
||||
acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
|
||||
try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
|
||||
except ValueError: return None
|
||||
return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@@ -171,44 +155,6 @@ class Profiling(contextlib.ContextDecorator):
|
||||
colored(_format_fcn(fcn).ljust(50), "yellow"),
|
||||
colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if scallers else '')
|
||||
|
||||
class ProfileLogger:
|
||||
writers: int = 0
|
||||
mjson: List[Dict] = []
|
||||
actors: Dict[Union[str, Tuple[str, str]], int] = {}
|
||||
|
||||
def __init__(self): self.events, self.deps, ProfileLogger.writers = [], [], ProfileLogger.writers + 1
|
||||
|
||||
def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None, args=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor, args)]
|
||||
|
||||
def _ensure_actor(self, actor_name, subactor_name):
|
||||
if actor_name not in self.actors:
|
||||
self.actors[actor_name] = (pid:=len(self.actors))
|
||||
self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
|
||||
|
||||
if (subactor_key:=(actor_name,subactor_name)) not in self.actors:
|
||||
self.actors[subactor_key] = (tid:=len(self.actors))
|
||||
self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
|
||||
|
||||
return self.actors[actor_name], self.actors.get(subactor_key, -1)
|
||||
|
||||
def __del__(self):
|
||||
# perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
|
||||
for name, st, et, actor_name, subactor_name, args in self.events:
|
||||
pid, tid = self._ensure_actor(actor_name,subactor_name)
|
||||
args = {k: (v if v.__class__ is str else v(et-st)) for k, v in args.items()} if args is not None else None
|
||||
self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts": st, "dur": et-st, "args": args})
|
||||
|
||||
for en,st,dep_actor_name,dep_subactor_name,actor_name,subactor_name in self.deps:
|
||||
dep_pid, dep_tid = self._ensure_actor(dep_actor_name,dep_subactor_name)
|
||||
pid, tid = self._ensure_actor(actor_name,subactor_name)
|
||||
self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": len(self.mjson), "ts": en, "bp": "e"})
|
||||
self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": len(self.mjson)-1, "ts": st, "bp": "e"})
|
||||
|
||||
ProfileLogger.writers -= 1
|
||||
if ProfileLogger.writers == 0 and len(self.mjson) > 0:
|
||||
with open(PROFILEPATH.value, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
|
||||
print(f"Saved profile to {PROFILEPATH.value}. Use https://ui.perfetto.dev/ to open it.")
|
||||
|
||||
# *** universal database cache ***
|
||||
|
||||
_cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache"))
|
||||
@@ -363,16 +309,6 @@ class tqdm:
|
||||
class trange(tqdm):
|
||||
def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)
|
||||
|
||||
def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
|
||||
def dfs(x:Any, cache:dict):
|
||||
for s in srcfn(x) or []:
|
||||
cache.setdefault(s, [len(cache), 0, False])[1] += 1
|
||||
if cache[s][1] == 1: dfs(s, cache)
|
||||
if cache is None: dfs(x, cache:={})
|
||||
if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d} x{cx[0]}"
|
||||
cx[2], srcs = True, ('None' if srcfn(x) is None else ''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x)))
|
||||
return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
|
||||
|
||||
# *** universal support for code object pickling
|
||||
|
||||
def _reconstruct_code(*args): return types.CodeType(*args)
|
||||
|
||||
@@ -6,7 +6,7 @@ from enum import auto, IntEnum, Enum
|
||||
from dataclasses import dataclass, field
|
||||
from weakref import WeakValueDictionary
|
||||
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
|
||||
from tinygrad.helpers import ContextVar, pretty_print, prod, getenv, all_same
|
||||
from tinygrad.helpers import ContextVar, prod, getenv, all_same
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
@@ -159,6 +159,17 @@ def resolve(x, default:bool=True):
|
||||
def smax(lst): return max(lst, key=lambda x: x if isinstance(x, int) else x.vmax)
|
||||
def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop
|
||||
|
||||
# used for UOp and UPat
|
||||
def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
|
||||
def dfs(x:Any, cache:dict):
|
||||
for s in srcfn(x) or []:
|
||||
cache.setdefault(s, [len(cache), 0, False])[1] += 1
|
||||
if cache[s][1] == 1: dfs(s, cache)
|
||||
if cache is None: dfs(x, cache:={})
|
||||
if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d} x{cx[0]}"
|
||||
cx[2], srcs = True, ('None' if srcfn(x) is None else ''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x)))
|
||||
return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
|
||||
|
||||
ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary()
|
||||
class UOp(MathTrait):
|
||||
def __reduce__(self): return UOp, (self.op, self.dtype, self.src, self.arg)
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Seque
|
||||
from collections import defaultdict
|
||||
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype
|
||||
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, get_shape, fully_flatten, dedup
|
||||
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
|
||||
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.multi import MultiLazyBuffer
|
||||
@@ -57,6 +57,14 @@ def _fromnp(x: 'np.ndarray') -> LazyBuffer: # type: ignore [name-defined] # noq
|
||||
del ret.srcs
|
||||
return ret
|
||||
|
||||
def get_shape(x) -> Tuple[int, ...]:
|
||||
if not hasattr(x, "__len__") or not hasattr(x, "__getitem__") or isinstance(x, str): return ()
|
||||
if (aapi := (hasattr(x, "shape") and x.shape == ())): return ()
|
||||
subs = [get_shape(xi) for xi in x]
|
||||
if not all_same(subs): raise ValueError(f"inhomogeneous shape from {x}")
|
||||
slen = 1 if aapi else len(subs)
|
||||
return (slen,) + (subs[0] if subs else ())
|
||||
|
||||
def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
|
||||
if isinstance(x, bytes): ret, data = LazyBuffer.metaop(MetaOps.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user