From 2a2f81dd3d7bd6043e711561bd3ed0d2739e5303 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 6 Jun 2026 09:49:09 -0700 Subject: [PATCH] remove ANON from addrspace, refactor marg (#16523) * remove ANON from addrspace, refactor marg * as_shape * as_shape is cached --- test/backend/test_uops.py | 3 ++- tinygrad/dtype.py | 6 +++--- tinygrad/uop/ops.py | 12 ++++++------ 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/test/backend/test_uops.py b/test/backend/test_uops.py index e1be405f1c..38ef195b1f 100644 --- a/test/backend/test_uops.py +++ b/test/backend/test_uops.py @@ -204,7 +204,8 @@ class TestLocalAccess(unittest.TestCase): out = Device[Device.DEFAULT].renderer.render(uops) # half is supported in wgsl, so it doesn't have to be packed corrected_size = size//(4//dtype.itemsize) if dtype != dtypes.half else size - self.assertIn(f"temp0: array<{Device[Device.DEFAULT].renderer.buf_map(dtype)},{corrected_size}>;", out) + # temp0: array<{Device[Device.DEFAULT].renderer.buf_map(dtype)},{corrected_size}>; + self.assertIn(f",{corrected_size}>;", out) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory") @unittest.skip("tinygrad doesn't support this behavior") diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index a76c8b6da7..3c85b318ab 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -3,7 +3,7 @@ from typing import Final, ClassVar, Callable, Literal import math, struct, ctypes, functools from dataclasses import dataclass, fields from tinygrad.helpers import ceildiv, getenv, prod, round_up, OSX -from enum import Enum, auto +from enum import IntEnum, auto class ConstFloat(float): """Float subclass that distinguishes -0.0 from 0.0 and where nan == nan.""" @@ -49,9 +49,9 @@ class DTypeMetaClass(type): DTypeMetaClass.dcache[args] = ret = super().__call__(*args) return ret -class AddrSpace(Enum): +class AddrSpace(IntEnum): def __repr__(self): return str(self) - GLOBAL = auto(); LOCAL = auto(); REG = auto(); ANON = auto() # noqa: E702 + GLOBAL = auto(); LOCAL = auto(); REG = auto() # noqa: E702 @dataclass(frozen=True, eq=False) class DType(metaclass=DTypeMetaClass): diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 2633d49b5e..492b118901 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -686,15 +686,15 @@ class UOp(OpMixin, metaclass=UOpMetaClass): case Ops.STACK: return self.src[i].sintify() case _: raise RuntimeError(f"no sgep on {self.op}") + @functools.cached_property + def as_shape(self) -> tuple[sint, ...]: + return tuple(ssimplify(self.sgep(i)) for i in range(max(self.dtype.count, len(self.src)))) + @functools.cached_property def marg(self): match self.op: - case Ops.RESHAPE | Ops.EXPAND: return tuple(ssimplify(self.src[1].sgep(i)) for i in range(self.src[1].dtype.count)) - case Ops.PAD | Ops.SHRINK: - # this is like broadcasting for shapes - return tuple(((ssimplify(self.src[1]) if self.src[1].shape == () else self.src[1].sgep(i)), - (ssimplify(self.src[2]) if self.src[2].shape == () else self.src[2].sgep(i))) - for i in range(max(self.src[1].dtype.count, self.src[2].dtype.count))) + case Ops.RESHAPE | Ops.EXPAND: return self.src[1].as_shape + case Ops.PAD | Ops.SHRINK: return tuple(zip(self.src[1].as_shape, self.src[2].as_shape)) case Ops.PERMUTE | Ops.FLIP: return self.arg case _: raise RuntimeError(f"{self.op} is not a MovementOp")