mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-08 05:54:59 +08:00
remove ANON from addrspace, refactor marg (#16523)
* remove ANON from addrspace, refactor marg * as_shape * as_shape is cached
This commit is contained in:
@@ -204,7 +204,8 @@ class TestLocalAccess(unittest.TestCase):
|
||||
out = Device[Device.DEFAULT].renderer.render(uops)
|
||||
# half is supported in wgsl, so it doesn't have to be packed
|
||||
corrected_size = size//(4//dtype.itemsize) if dtype != dtypes.half else size
|
||||
self.assertIn(f"temp0: array<{Device[Device.DEFAULT].renderer.buf_map(dtype)},{corrected_size}>;", out)
|
||||
# temp0: array<{Device[Device.DEFAULT].renderer.buf_map(dtype)},{corrected_size}>;
|
||||
self.assertIn(f",{corrected_size}>;", out)
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
|
||||
@unittest.skip("tinygrad doesn't support this behavior")
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Final, ClassVar, Callable, Literal
|
||||
import math, struct, ctypes, functools
|
||||
from dataclasses import dataclass, fields
|
||||
from tinygrad.helpers import ceildiv, getenv, prod, round_up, OSX
|
||||
from enum import Enum, auto
|
||||
from enum import IntEnum, auto
|
||||
|
||||
class ConstFloat(float):
|
||||
"""Float subclass that distinguishes -0.0 from 0.0 and where nan == nan."""
|
||||
@@ -49,9 +49,9 @@ class DTypeMetaClass(type):
|
||||
DTypeMetaClass.dcache[args] = ret = super().__call__(*args)
|
||||
return ret
|
||||
|
||||
class AddrSpace(Enum):
|
||||
class AddrSpace(IntEnum):
|
||||
def __repr__(self): return str(self)
|
||||
GLOBAL = auto(); LOCAL = auto(); REG = auto(); ANON = auto() # noqa: E702
|
||||
GLOBAL = auto(); LOCAL = auto(); REG = auto() # noqa: E702
|
||||
|
||||
@dataclass(frozen=True, eq=False)
|
||||
class DType(metaclass=DTypeMetaClass):
|
||||
|
||||
@@ -686,15 +686,15 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
case Ops.STACK: return self.src[i].sintify()
|
||||
case _: raise RuntimeError(f"no sgep on {self.op}")
|
||||
|
||||
@functools.cached_property
|
||||
def as_shape(self) -> tuple[sint, ...]:
|
||||
return tuple(ssimplify(self.sgep(i)) for i in range(max(self.dtype.count, len(self.src))))
|
||||
|
||||
@functools.cached_property
|
||||
def marg(self):
|
||||
match self.op:
|
||||
case Ops.RESHAPE | Ops.EXPAND: return tuple(ssimplify(self.src[1].sgep(i)) for i in range(self.src[1].dtype.count))
|
||||
case Ops.PAD | Ops.SHRINK:
|
||||
# this is like broadcasting for shapes
|
||||
return tuple(((ssimplify(self.src[1]) if self.src[1].shape == () else self.src[1].sgep(i)),
|
||||
(ssimplify(self.src[2]) if self.src[2].shape == () else self.src[2].sgep(i)))
|
||||
for i in range(max(self.src[1].dtype.count, self.src[2].dtype.count)))
|
||||
case Ops.RESHAPE | Ops.EXPAND: return self.src[1].as_shape
|
||||
case Ops.PAD | Ops.SHRINK: return tuple(zip(self.src[1].as_shape, self.src[2].as_shape))
|
||||
case Ops.PERMUTE | Ops.FLIP: return self.arg
|
||||
case _: raise RuntimeError(f"{self.op} is not a MovementOp")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user