Files
dragonpilot/tinygrad_repo/test/test_sample.py
T
Vehicle Researcher db5cbadcf2 openpilot v0.9.9 release
date: 2025-06-05T19:54:08
master commit: 8aadf02b2fd91f4e1285e18c2c7feb32d93b66f5
2025-06-12 14:30:06 -07:00

21 lines
774 B
Python

import unittest
import numpy as np
from tinygrad import Tensor, Variable, Device
from tinygrad.helpers import OSX
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers")
class TestSample(unittest.TestCase):
  """Row sampling from a realized tensor through symbolic (Variable-bound) shrinks."""

  def test_sample(self):
    """Concatenating BS single-row symbolic shrinks must equal numpy fancy indexing with the same indices."""
    X = Tensor.rand(10000, 50).realize()
    BS = 16
    idxs = np.random.randint(0, X.shape[0], size=BS)
    # this uncovered a bug with arg sort order: every sampled row index is its own
    # Variable (named idx0..idx{BS-1}) bound to a concrete value
    batch = [Variable(f'idx{i}', 0, X.shape[0]-1).bind(s) for i, s in enumerate(idxs.tolist())]
    # take rows [v, v+1) along axis 0 for each bound variable, keep every column
    x = Tensor.cat(*[X.shrink(((v, v+1), None)) for v in batch])
    ret = x.numpy()
    base = X.numpy()[idxs]
    # idxs are random and unseeded: report them on failure rather than printing them on every run
    np.testing.assert_equal(ret, base, err_msg=f"sampled idxs: {idxs.tolist()}")
# Allow running this test file directly (e.g. `python test_sample.py`).
if __name__ == '__main__':
  unittest.main()