check the input length into argfix (#3610)

* check the input length into argfix

it's possible to overlook setting keyword for kwargs and argfix silently truncates input

* add test
This commit is contained in:
chenyu
2024-03-04 19:50:17 -05:00
committed by GitHub
parent 7db6dd725d
commit 282bbd5acb
2 changed files with 13 additions and 4 deletions

View File

@@ -214,6 +214,13 @@ class TestTinygrad(unittest.TestCase):
self.assertEqual(Tensor.empty(1,10,20).shape, (1,10,20))
self.assertEqual(Tensor.empty((10,20,40)).shape, (10,20,40))
with self.assertRaises(ValueError):
Tensor.zeros((2, 2), 2, 2)
with self.assertRaises(ValueError):
Tensor.zeros((2, 2), (2, 2))
with self.assertRaises(ValueError):
Tensor.randn((128, 128), 0.0, 0.01)
def test_numel(self):
assert Tensor.randn(10, 10).numel() == 100
assert Tensor.randn(1,2,5).numel() == 10

View File

@@ -17,7 +17,11 @@ OSX = platform.system() == "Darwin"
CI = os.getenv("CI", "") != ""
def dedup(x:Iterable[T]): return list(dict.fromkeys(x)) # retains list order
def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else x
def argfix(*x):
if x and x[0].__class__ in (tuple, list):
if len(x) != 1: raise ValueError(f"bad arg {x}")
return tuple(x[0])
return x
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
def all_same(items:List[T]): return all(x == items[0] for x in items)
def all_int(t: Sequence[Any]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
@@ -207,9 +211,7 @@ def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
_pack_, _fields_ = 1, fields
return CStruct
def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
def flat_mv(mv:memoryview):
if len(mv) == 0: return mv
return mv.cast("B", shape=(mv.nbytes,))
def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,))
# *** Helpers for CUDA-like APIs.