_pad_circular and _pad_reflect_replicate to mixin (#15944)

This commit is contained in:
chenyu
2026-04-27 16:07:05 -04:00
committed by GitHub
parent 8c174bdad4
commit fe38d6de94
3 changed files with 30 additions and 21 deletions

View File

@@ -243,6 +243,14 @@ class TestTensorUOpCat(unittest.TestCase):
def test_cat_3tensors(self): _check(self, _t(2, 3), lambda x: x.cat(x, x, dim=0))
def test_cat_neg_dim(self): _check(self, _t(2, 3, 4), lambda x: x.cat(x, dim=-1))
class TestTensorUOpPad(unittest.TestCase):
  """Exercises the private pad helpers `_pad_circular` and `_pad_reflect_replicate` on a `_t(4, 5)` input via `_check`."""
  # circular padding: both sides padded on dim 0, only the trailing side on dim 1
  def test_pad_circular(self):
    _check(self, _t(4, 5), lambda t: t._pad_circular(((1, 2), (0, 3))))
  # circular padding where every trailing pad is zero
  def test_pad_circular_zero_after(self):
    _check(self, _t(4, 5), lambda t: t._pad_circular(((1, 0), (2, 0))))
  def test_pad_reflect(self):
    _check(self, _t(4, 5), lambda t: t._pad_reflect_replicate(((1, 2), (0, 3)), "reflect"))
  # a negative trailing pad on dim 0 is combined with positive pads elsewhere
  def test_pad_reflect_negative(self):
    _check(self, _t(4, 5), lambda t: t._pad_reflect_replicate(((1, -1), (0, 2)), "reflect"))
  def test_pad_replicate(self):
    _check(self, _t(4, 5), lambda t: t._pad_reflect_replicate(((1, 2), (0, 3)), "replicate"))
  # a negative trailing pad on dim 0 is combined with positive pads elsewhere
  def test_pad_replicate_negative(self):
    _check(self, _t(4, 5), lambda t: t._pad_reflect_replicate(((1, -1), (0, 2)), "replicate"))
class TestTensorUOpStack(unittest.TestCase):
def test_stack_dim0(self): _check(self, _t(2, 3), lambda x: x.stack(x, dim=0))
def test_stack_dim1(self): _check(self, _t(2, 3), lambda x: x.stack(x, dim=1))

View File

@@ -202,6 +202,27 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
base = base.cast(least_upper_dtype(base.dtype, dtypes.from_py(value)))
return base + MovementMixin.pad(X.ones_like(), pads).cast(dtypes.bool).where(base.zeros_like(), base.full_like(value))
def _pad_circular(self, pX:tuple[tuple[sint, sint], ...]) -> Self:
if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, self.shape)): raise ValueError('Padding value causes wrapping around more than once.')
if any(pB<0 or pA<0 for pB,pA in pX): raise NotImplementedError("Negative pads with circular pads is not supported")
orig_shape, X = self.shape, self.repeat(tuple(1 + bool(pB) + bool(pA) for pB,pA in pX))
return X.shrink(tuple((0 if pB == 0 else osh-pB, xsh if pA == 0 else xsh-osh+pA) for (pB,pA),osh,xsh in zip(pX, orig_shape, X.shape)))
def _pad_reflect_replicate(self, pX:tuple[tuple[sint, sint], ...], mode:str) -> Self:
  """
  Pad one dimension at a time by concatenating mirrored (`mode == "reflect"`) or edge-repeated (any other `mode`) pieces.

  `pX` holds one `(before, after)` pair per dimension. Negative pads are clamped to zero while the pieces are built and
  are applied afterwards as a single `shrink`.

  Raises `ValueError` in reflect mode when a clamped pad is not smaller than the current size of its dimension.
  """
  out = self
  # negative pads add nothing here; they are handled by the crop at the very end
  grow = tuple((smax(before, 0), smax(after, 0)) for before, after in pX)
  for axis, (before, after) in enumerate(grow):
    if mode == "reflect":
      size = out.shape[axis]
      if before >= size or after >= size:
        raise ValueError(f"Padding ({before}, {after}) should be less than the input size={size} for dim={axis}.")
      # reversed reads that skip the edge element itself; a bound that would go negative is replaced by None,
      # because a negative number in a slice would be read as an index from the end
      tail_start = size-2 if size-2 >= 0 else None
      tail_stop = size-2-after if size-2-after >= 0 else None
      extras = []
      for sel, amount in ((slice(before, 0, -1), before), (slice(tail_start, tail_stop, -1), after)):
        if amount > 0: extras.append(out[[sel if i == axis else slice(None) for i in range(out.ndim)]])
        else: extras.append(None)
    else:
      # width-1 windows on the first and last element of this axis, expanded to the requested pad length
      first = tuple((0, 1) if i == axis else None for i in range(out.ndim))
      last = tuple((out.shape[i]-1, out.shape[i]) if i == axis else None for i in range(out.ndim))
      extras = []
      for edge, amount in ((first, before), (last, after)):
        if amount > 0: extras.append(out.shrink(edge).expand(tuple(amount if i == axis else None for i in range(out.ndim))))
        else: extras.append(None)
    lead, trail = extras
    head, *rest = [piece for piece in (lead, out, trail) if piece is not None]
    out = head.cat(*rest, dim=axis)
  # shrink after for negative pads (reflection/replication must see full data first)
  crop = tuple((-min(before, 0), min(after+size, size)) for (before, after), size in zip(pX, out.shape))
  return out.shrink(crop)
def _ufix_keep_dtype(self, x) -> bool:
  """
  Decide whether a scalar `x` should keep `self.dtype`.

  Mirrors the Tensor scalar-wrapping behavior: a float `self` always keeps its dtype; an int `self` keeps it only when
  `x` is an `int` or an `InvalidType` instance.
  """
  float_self = dtypes.is_float(self.dtype)
  if float_self: return float_self
  return dtypes.is_int(self.dtype) and isinstance(x, (int, InvalidType))

View File

@@ -11,7 +11,7 @@ from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata,
from tinygrad.helpers import suppress_finalizing, disable_gc
from tinygrad.gradient import compute_gradient
from tinygrad.mixin import OpMixin
from tinygrad.uop.ops import smax, UOp, Ops, sint, all_metadata, _index_to_concrete_int, Variable, _broadcast_shape
from tinygrad.uop.ops import UOp, Ops, sint, all_metadata, _index_to_concrete_int, Variable, _broadcast_shape
from tinygrad.schedule import create_linear_with_vars
from tinygrad.device import Buffer, canonicalize_device
from tinygrad.engine.realize import run_linear
@@ -910,26 +910,6 @@ class Tensor(OpMixin):
def _mop(self, op:Ops, arg) -> Tensor:
  """Run `UOp._mop` through `_apply_uop`, passing `op` as the only extra positional argument and `arg` by keyword."""
  extra = (op,)
  return self._apply_uop(UOp._mop, extra_args=extra, arg=arg)
def _rop(self, op:Ops, axis:tuple[int, ...]) -> Tensor:
  """Run `UOp._rop` through `_apply_uop`, forwarding `op` and `axis` as keyword arguments."""
  forwarded = {"op": op, "axis": axis}
  return self._apply_uop(UOp._rop, **forwarded)
def _pad_circular(self, pX:tuple[tuple[sint, sint], ...]) -> Tensor:
if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, self.shape)): raise ValueError('Padding value causes wrapping around more than once.')
if any(pB<0 or pA<0 for pB,pA in pX): raise NotImplementedError("Negative pads with circular pads is not supported")
orig_shape, X = self.shape, self.repeat(tuple(1 + bool(pB) + bool(pA) for pB,pA in pX))
return X.shrink(tuple((0 if pB == 0 else osh-pB, xsh if pA == 0 else xsh-osh+pA) for (pB,pA),osh,xsh in zip(pX, orig_shape, X.shape)))
def _pad_reflect_replicate(self, pX:tuple[tuple[sint, sint], ...], mode:str) -> Tensor:
  """
  Build reflect padding (`mode == "reflect"`) or edge-replicating padding (any other `mode`) dimension by dimension.

  `pX` is one `(before, after)` pair per dimension. Pads are clamped to be non-negative while the tensor is grown with
  `Tensor.cat`; negative pads are then applied as a final `shrink`.

  Raises `ValueError` in reflect mode when a clamped pad is not smaller than the current size of its dimension.
  """
  result, nonneg = self, tuple((smax(lo, 0), smax(hi, 0)) for lo, hi in pX)
  for dim, (lo, hi) in enumerate(nonneg):
    axes = range(result.ndim)
    if mode == "reflect":
      extent = result.shape[dim]
      if lo >= extent or hi >= extent: raise ValueError(f"Padding ({lo}, {hi}) should be less than the input size={extent} for dim={dim}.")
      # reversed slices that leave out the edge element; bounds that would go negative become None
      rev_lo = slice(lo, 0, -1)
      rev_hi = slice(extent-2 if extent-2 >= 0 else None, extent-2-hi if extent-2-hi >= 0 else None, -1)
      before = result[[rev_lo if i == dim else slice(None) for i in axes]] if lo > 0 else None
      after = result[[rev_hi if i == dim else slice(None) for i in axes]] if hi > 0 else None
    else:
      # single-element windows at both ends of this dimension, expanded to the pad length
      first_elem = tuple((0, 1) if i == dim else None for i in axes)
      last_elem = tuple((result.shape[i]-1, result.shape[i]) if i == dim else None for i in axes)
      before = result.shrink(first_elem).expand(tuple(lo if i == dim else None for i in axes)) if lo > 0 else None
      after = result.shrink(last_elem).expand(tuple(hi if i == dim else None for i in axes)) if hi > 0 else None
    result = Tensor.cat(*[part for part in (before, result, after) if part is not None], dim=dim)
  # shrink after for negative pads (reflection/replication must see full data first)
  return result.shrink(tuple((-min(lo, 0), min(hi+extent, extent)) for (lo, hi), extent in zip(pX, result.shape)))
def pad(self, padding:Sequence[sint]|Sequence[tuple[sint, sint]|None], mode:str="constant", value:float=0.0) -> Tensor:
"""
Returns a tensor with padding applied based on the input `padding`.