Merge branch 'master' into shrink_in_render

This commit is contained in:
George Hotz
2026-05-31 09:29:52 -07:00
committed by GitHub
3 changed files with 6 additions and 7 deletions

View File

@@ -219,7 +219,6 @@ class TestUOpValidationIssue(unittest.TestCase):
class TestEdgeCases(unittest.TestCase):
# add tests exposing new and diverse kinds of bugs that might impact real users here
@unittest.expectedFailure
def test_circular_pad_negative(self):
# negative pads with circular mode should wrap like PyTorch
arr = np.arange(9).reshape(1, 1, 3, 3).astype(np.float32)

View File

@@ -1989,9 +1989,7 @@ class TestOps(unittest.TestCase):
self.helper_test_exception([(1,1,5,5)],
lambda x: torch.nn.functional.pad(x, (3,6,0,0), mode="circular"), lambda x: x.pad((3,6,0,0), mode="circular"),
expected=(RuntimeError, ValueError))
with self.assertRaises(NotImplementedError):
# negative pads with circular pads is not supported
Tensor.randn(1,1,5,5).pad((3,-5,1,-5), mode="circular")
helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (1,-2,2,-1), mode="circular"), lambda x: x.pad((1,-2,2,-1), mode="circular"))
def test_pad_reshape(self):
helper_test_op([(1, 2)],

View File

@@ -257,9 +257,11 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
return MovementMixin.pad(X.const_like(1).cast(dtypes.bool), pads).where(base, base.const_like(value))
def _pad_circular(self, pX:tuple[tuple[sint, sint], ...]) -> Self:
if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, self.shape)): raise ValueError('Padding value causes wrapping around more than once.')
if any(pB<0 or pA<0 for pB,pA in pX): raise NotImplementedError("Negative pads with circular pads is not supported")
orig_shape, X = self.shape, self.repeat(tuple(1 + bool(pB) + bool(pA) for pB,pA in pX))
# shrink first for negative pads, then wrap the non-negative remainder
X = self.shrink(tuple((-smin(pB,0), smin(pA+sh,sh)) for (pB,pA),sh in zip(pX, self.shape)))
pX = tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, X.shape)): raise ValueError('Padding value causes wrapping around more than once.')
orig_shape, X = X.shape, X.repeat(tuple(1 + bool(pB) + bool(pA) for pB,pA in pX))
return X.shrink(tuple((0 if pB == 0 else osh-pB, xsh if pA == 0 else xsh-osh+pA) for (pB,pA),osh,xsh in zip(pX, orig_shape, X.shape)))
def _pad_reflect_replicate(self, pX:tuple[tuple[sint, sint], ...], mode:str) -> Self: