diff --git a/test/test_ops.py b/test/test_ops.py index 8878af0eb6..48e2d3fc65 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1221,7 +1221,7 @@ class TestOps(unittest.TestCase): def test_einsum_shape_check(self): a = Tensor.zeros(3,8,10,5) b = Tensor.zeros(11,5,13,16,8) - with self.assertRaises(AssertionError): + with self.assertRaises(RuntimeError): Tensor.einsum('pqrs,tuqvr->pstuv',a,b) def test_einsum_arity_check1(self): diff --git a/test/unit/test_helpers.py b/test/unit/test_helpers.py index 2896f60e77..3000ef89ed 100644 --- a/test/unit/test_helpers.py +++ b/test/unit/test_helpers.py @@ -93,7 +93,7 @@ class TestMergeDicts(unittest.TestCase): assert merge_dicts([a, b]) == {"a": 1, "b": 2, "c": 3} assert merge_dicts([a, c]) == a assert merge_dicts([a, b, c]) == {"a": 1, "b": 2, "c": 3} - with self.assertRaises(AssertionError): + with self.assertRaises(RuntimeError): merge_dicts([a, d]) class TestStripParens(unittest.TestCase): diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 94378141de..53ba564de4 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -56,7 +56,7 @@ def i2u(bits: int, value: int): return value if value >= 0 else (1< bool: return str(type(x)) == "" def merge_dicts(ds:Iterable[dict[T,U]]) -> dict[T,U]: kvs = set([(k,v) for d in ds for k,v in d.items()]) - assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key" + if len(kvs) != len(set(kv[0] for kv in kvs)): raise RuntimeError(f"{kvs} contains different values for the same key") return {k:v for d in ds for k,v in d.items()} def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> tuple[list[T], list[T]]: ret:tuple[list[T], list[T]] = ([], [])