Files
dragonpilot/tinygrad_repo/test/test_sample.py
T
Adeeb Shihadeh 100f89a161 openpilot v0.9.9 release (#35334)
* openpilot v0.9.9 release

date: 2025-06-05T19:54:08
master commit: 8aadf02b2fd91f4e1285e18c2c7feb32d93b66f5

* AGNOS 12.4 (#35558)

agnos12.4

---------

Co-authored-by: Vehicle Researcher <user@comma.ai>
Co-authored-by: Maxime Desroches <desroches.maxime@gmail.com>
2025-06-17 16:32:08 -07:00

21 lines
774 B
Python

import unittest
import numpy as np
from tinygrad import Tensor, Variable, Device
from tinygrad.helpers import OSX
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and not OSX, "WEBGPU Vulkan can only run kernels with up to 10 buffers")
class TestSample(unittest.TestCase):
    """Row sampling through shrinks whose bounds are bound symbolic variables.

    The whole case is skipped when the default device is WEBGPU on a non-OSX
    host (see the skip reason on the decorator).
    """

    def test_sample(self):
        """Check that concatenated one-row shrinks match numpy fancy indexing.

        Random row indices are drawn with numpy, each one is bound to its own
        ``Variable``, a one-row slice of the source tensor is taken per
        variable, and the concatenation of those slices is compared against
        indexing the numpy copy of the source with the same indices.
        """
        source = Tensor.rand(10000, 50).realize()
        n_rows = source.shape[0]
        batch_size = 16
        row_ids = np.random.randint(0, n_rows, size=batch_size)

        # this uncovered a bug with arg sort order
        bound_vars = [
            Variable(f'idx{i}', 0, n_rows - 1).bind(s)
            for i, s in enumerate(row_ids.tolist())
        ]
        # One single-row slice per bound variable: rows [v, v+1), all columns.
        row_slices = [source.shrink(((v, v + 1), None)) for v in bound_vars]
        gathered = Tensor.cat(*row_slices)

        # Emit the sampled indices so a failing run can be reproduced.
        print(row_ids)

        actual = gathered.numpy()
        expected = source.numpy()[row_ids]
        np.testing.assert_equal(actual, expected)
# Run this module's tests when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()