mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
check the input length into argfix (#3610)
* check the input length into argfix it's possible to overlook setting keyword for kwargs and argfix silently truncates input * add test
This commit is contained in:
@@ -214,6 +214,13 @@ class TestTinygrad(unittest.TestCase):
|
||||
self.assertEqual(Tensor.empty(1,10,20).shape, (1,10,20))
|
||||
self.assertEqual(Tensor.empty((10,20,40)).shape, (10,20,40))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
Tensor.zeros((2, 2), 2, 2)
|
||||
with self.assertRaises(ValueError):
|
||||
Tensor.zeros((2, 2), (2, 2))
|
||||
with self.assertRaises(ValueError):
|
||||
Tensor.randn((128, 128), 0.0, 0.01)
|
||||
|
||||
def test_numel(self):
|
||||
assert Tensor.randn(10, 10).numel() == 100
|
||||
assert Tensor.randn(1,2,5).numel() == 10
|
||||
|
||||
@@ -17,7 +17,11 @@ OSX = platform.system() == "Darwin"
|
||||
CI = os.getenv("CI", "") != ""
|
||||
|
||||
def dedup(x:Iterable[T]): return list(dict.fromkeys(x)) # retains list order
|
||||
def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else x
|
||||
def argfix(*x):
|
||||
if x and x[0].__class__ in (tuple, list):
|
||||
if len(x) != 1: raise ValueError(f"bad arg {x}")
|
||||
return tuple(x[0])
|
||||
return x
|
||||
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
|
||||
def all_same(items:List[T]): return all(x == items[0] for x in items)
|
||||
def all_int(t: Sequence[Any]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
|
||||
@@ -207,9 +211,7 @@ def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
|
||||
_pack_, _fields_ = 1, fields
|
||||
return CStruct
|
||||
def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
|
||||
def flat_mv(mv:memoryview):
|
||||
if len(mv) == 0: return mv
|
||||
return mv.cast("B", shape=(mv.nbytes,))
|
||||
def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,))
|
||||
|
||||
# *** Helpers for CUDA-like APIs.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user