make slow tests faster (#16244)

This commit is contained in:
chenyu
2026-05-18 11:42:02 -04:00
committed by GitHub
parent 981c12182f
commit 5ae4dbd599
2 changed files with 14 additions and 20 deletions

View File

@@ -368,8 +368,8 @@ class TestRandomness(unittest.TestCase):
@TinyJit
def sample_one(): return Tensor(w).multinomial(1, replacement=False).realize()
tiny_samples = [sample_one().item() for _ in range(1000)]
torch_samples = [torch.tensor(w).multinomial(1, replacement=False).item() for _ in range(1000)]
tiny_samples = [sample_one().item() for _ in range(400)]
torch_samples = [torch.tensor(w).multinomial(1, replacement=False).item() for _ in range(400)]
self.assertTrue(equal_distribution(lambda *_: Tensor(tiny_samples), lambda _: torch.tensor(torch_samples)))
w = list(range(32))
@@ -384,8 +384,8 @@ class TestRandomness(unittest.TestCase):
@TinyJit
def sample_three(): return Tensor(w).multinomial(3, replacement=False).realize()
tiny_draws = np.array([sample_three().numpy() for _ in range(1000)])
torch_draws = np.array([torch.tensor(w).multinomial(3, replacement=False).numpy() for _ in range(1000)])
tiny_draws = np.array([sample_three().numpy() for _ in range(400)])
torch_draws = np.array([torch.tensor(w).multinomial(3, replacement=False).numpy() for _ in range(400)])
for pos in range(3):
self.assertTrue(equal_distribution(lambda *_: Tensor(tiny_draws[:, pos]), lambda _: torch.tensor(torch_draws[:, pos])))
@@ -415,7 +415,7 @@ class TestRandomness(unittest.TestCase):
def test_rand_chain(self):
# NOTE: this fails if property propagates deeper than stack limit
for _ in range(833): Tensor.rand(1)
Tensor.rand(1).realize()
Tensor.rand(1).schedule_linear()
def test_random_counter_overflow(self):
device = Device.DEFAULT

View File

@@ -5,7 +5,7 @@
import gc, unittest, functools
import numpy as np
from typing import cast
from hypothesis import assume, given, settings, strategies as strat
from hypothesis import assume, given, strategies as strat
from tinygrad import nn, dtypes, Device, Tensor, Variable
from tinygrad.device import is_dtype_supported
@@ -113,12 +113,6 @@ class TestSchedule(unittest.TestCase):
run_linear(*check_schedule(b, 1))
np.testing.assert_allclose(b.numpy(), np.broadcast_to(a.numpy().astype(np.float16), (2, 4, 4))+2, rtol=1e-3)
def test_indexing_scalars_simple(self):
X = Tensor.randn(2, 2).realize()
xt = X[Tensor(1)][Tensor(0)]
run_linear(*check_schedule(xt, 1))
np.testing.assert_equal(xt.numpy(), X.numpy()[1][0])
@unittest.skipIf(CI and Device.DEFAULT == "NV", "crashes on NV CI")
def test_add_chain_buffers(self):
N = 31
@@ -130,14 +124,14 @@ class TestSchedule(unittest.TestCase):
root = root + functools.reduce(lambda a,b:a+b, bufs[i:i+X])
self.assertEqual(root.item(), sum(range(N)))
@given(strat.sampled_from(range(2,4)), strat.sampled_from(range(2,4)), strat.sampled_from(range(0,4)), strat.sampled_from(range(0,4)))
@settings(deadline=None)
def test_indexing_scalars(self, x, y, a, b):
assume(a<x and b<y)
X = Tensor.randn(x, y).realize()
xt = X[Tensor(a)][Tensor(b)]
run_linear(*check_schedule(xt, 1))
np.testing.assert_equal(xt.numpy(), X.numpy()[a][b])
def test_indexing_scalars(self):
# cover each shape at all index corners
for x, y in [(2,2), (2,3), (3,2), (3,3)]:
for a, b in [(0,0), (0,y-1), (x-1,0), (x-1,y-1)]:
X = Tensor.randn(x, y).realize()
xt = X[Tensor(a)][Tensor(b)]
run_linear(*check_schedule(xt, 1))
np.testing.assert_equal(xt.numpy(), X.numpy()[a][b])
def test_push_pads_elementwise(self):
x = Tensor.full((4,4), 2.).contiguous().realize()