diff --git a/test/backend/test_edgecases.py b/test/backend/test_edgecases.py index b3039a71fb..89bd245ae3 100644 --- a/test/backend/test_edgecases.py +++ b/test/backend/test_edgecases.py @@ -219,7 +219,6 @@ class TestUOpValidationIssue(unittest.TestCase): class TestEdgeCases(unittest.TestCase): # add tests exposing new and diverse kinds of bugs that might impact real users here - @unittest.expectedFailure def test_circular_pad_negative(self): # negative pads with circular mode should wrap like PyTorch arr = np.arange(9).reshape(1, 1, 3, 3).astype(np.float32) diff --git a/test/backend/test_ops.py b/test/backend/test_ops.py index 67d3833700..7c04506b7c 100644 --- a/test/backend/test_ops.py +++ b/test/backend/test_ops.py @@ -1989,9 +1989,7 @@ class TestOps(unittest.TestCase): self.helper_test_exception([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (3,6,0,0), mode="circular"), lambda x: x.pad((3,6,0,0), mode="circular"), expected=(RuntimeError, ValueError)) - with self.assertRaises(NotImplementedError): - # negative pads with circular pads is not supported - Tensor.randn(1,1,5,5).pad((3,-5,1,-5), mode="circular") + helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (1,-2,2,-1), mode="circular"), lambda x: x.pad((1,-2,2,-1), mode="circular")) def test_pad_reshape(self): helper_test_op([(1, 2)], diff --git a/tinygrad/mixin/__init__.py b/tinygrad/mixin/__init__.py index a30659b389..0e5b79e86e 100644 --- a/tinygrad/mixin/__init__.py +++ b/tinygrad/mixin/__init__.py @@ -257,9 +257,11 @@ class OpMixin(ElementwiseMixin, ReduceMixin): return MovementMixin.pad(X.const_like(1).cast(dtypes.bool), pads).where(base, base.const_like(value)) def _pad_circular(self, pX:tuple[tuple[sint, sint], ...]) -> Self: - if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, self.shape)): raise ValueError('Padding value causes wrapping around more than once.') - if any(pB<0 or pA<0 for pB,pA in pX): raise NotImplementedError("Negative pads with circular pads is not supported") - orig_shape, X = self.shape, self.repeat(tuple(1 + bool(pB) + bool(pA) for pB,pA in pX)) + # shrink first for negative pads, then wrap the non-negative remainder + X = self.shrink(tuple((-smin(pB,0), smin(pA+sh,sh)) for (pB,pA),sh in zip(pX, self.shape))) + pX = tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX) + if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, X.shape)): raise ValueError('Padding value causes wrapping around more than once.') + orig_shape, X = X.shape, X.repeat(tuple(1 + bool(pB) + bool(pA) for pB,pA in pX)) return X.shrink(tuple((0 if pB == 0 else osh-pB, xsh if pA == 0 else xsh-osh+pA) for (pB,pA),osh,xsh in zip(pX, orig_shape, X.shape))) def _pad_reflect_replicate(self, pX:tuple[tuple[sint, sint], ...], mode:str) -> Self: