mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
Merge branch 'master' into shrink_in_render
This commit is contained in:
@@ -219,7 +219,6 @@ class TestUOpValidationIssue(unittest.TestCase):
|
||||
class TestEdgeCases(unittest.TestCase):
|
||||
# add tests exposing new and diverse kinds of bugs that might impact real users here
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_circular_pad_negative(self):
|
||||
# negative pads with circular mode should wrap like PyTorch
|
||||
arr = np.arange(9).reshape(1, 1, 3, 3).astype(np.float32)
|
||||
|
||||
@@ -1989,9 +1989,7 @@ class TestOps(unittest.TestCase):
|
||||
self.helper_test_exception([(1,1,5,5)],
|
||||
lambda x: torch.nn.functional.pad(x, (3,6,0,0), mode="circular"), lambda x: x.pad((3,6,0,0), mode="circular"),
|
||||
expected=(RuntimeError, ValueError))
|
||||
with self.assertRaises(NotImplementedError):
|
||||
# negative pads with circular pads is not supported
|
||||
Tensor.randn(1,1,5,5).pad((3,-5,1,-5), mode="circular")
|
||||
helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (1,-2,2,-1), mode="circular"), lambda x: x.pad((1,-2,2,-1), mode="circular"))
|
||||
|
||||
def test_pad_reshape(self):
|
||||
helper_test_op([(1, 2)],
|
||||
|
||||
@@ -257,9 +257,11 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
return MovementMixin.pad(X.const_like(1).cast(dtypes.bool), pads).where(base, base.const_like(value))
|
||||
|
||||
def _pad_circular(self, pX:tuple[tuple[sint, sint], ...]) -> Self:
|
||||
if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, self.shape)): raise ValueError('Padding value causes wrapping around more than once.')
|
||||
if any(pB<0 or pA<0 for pB,pA in pX): raise NotImplementedError("Negative pads with circular pads is not supported")
|
||||
orig_shape, X = self.shape, self.repeat(tuple(1 + bool(pB) + bool(pA) for pB,pA in pX))
|
||||
# shrink first for negative pads, then wrap the non-negative remainder
|
||||
X = self.shrink(tuple((-smin(pB,0), smin(pA+sh,sh)) for (pB,pA),sh in zip(pX, self.shape)))
|
||||
pX = tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
|
||||
if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, X.shape)): raise ValueError('Padding value causes wrapping around more than once.')
|
||||
orig_shape, X = X.shape, X.repeat(tuple(1 + bool(pB) + bool(pA) for pB,pA in pX))
|
||||
return X.shrink(tuple((0 if pB == 0 else osh-pB, xsh if pA == 0 else xsh-osh+pA) for (pB,pA),osh,xsh in zip(pX, orig_shape, X.shape)))
|
||||
|
||||
def _pad_reflect_replicate(self, pX:tuple[tuple[sint, sint], ...], mode:str) -> Self:
|
||||
|
||||
Reference in New Issue
Block a user