Files
StarPilot/tinygrad_repo/test/unit/test_masked_tensor.py
T
firestar5683 d0e1db6766 StarPilot
2026-03-22 03:15:05 -05:00

30 lines
750 B
Python

import unittest
from tinygrad.tensor import Tensor
class TestMaskedTensor(unittest.TestCase):
    """Elementwise arithmetic on tensors that were extended with ``Tensor.pad``.

    Each test builds at least one operand by padding a two-element tensor of
    ones on its trailing side, then checks that the positions introduced by
    the padding come out as zero in the result.

    ``self.assertEqual`` is used instead of bare ``assert`` so the checks are
    not stripped when Python runs with ``-O`` and so failures report both
    values.
    """

    def test_mul_masked(self):
        """Multiply an unpadded length-5 tensor by a padded length-5 tensor."""
        full = Tensor([1, 1, 1, 1, 1])
        padded = Tensor([1, 1]).pad(((0, 3),))
        product = full * padded
        self.assertEqual(product.shape, full.shape)
        # Only the two leading positions carry data on both sides; the three
        # trailing positions come from the padding and are expected to be zero.
        self.assertEqual(product.data().tolist(), [1.0, 1.0, 0.0, 0.0, 0.0])

    def test_mul_both_masked(self):
        """Multiply two tensors that were both padded from length 2 to 5."""
        lhs = Tensor([1, 1]).pad(((0, 3),))
        rhs = Tensor([1, 1]).pad(((0, 3),))
        product = lhs * rhs
        self.assertEqual(product.shape, lhs.shape)
        self.assertEqual(product.data().tolist(), [1.0, 1.0, 0.0, 0.0, 0.0])

    def test_add_masked(self):
        """Add two tensors that were both padded from length 2 to 4."""
        lhs = Tensor([1, 1]).pad(((0, 2),))
        rhs = Tensor([1, 1]).pad(((0, 2),))
        total = lhs + rhs
        # Shape check mirrors the two multiplication tests.
        self.assertEqual(total.shape, lhs.shape)
        # The padded positions contribute 0 + 0, the data positions 1 + 1.
        self.assertEqual(total.data().tolist(), [2.0, 2.0, 0.0, 0.0])
# Run the test cases above when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()