diff --git a/test/unit/test_allreduce.py b/test/unit/test_allreduce.py index 96ea150dc6..2404408fe8 100644 --- a/test/unit/test_allreduce.py +++ b/test/unit/test_allreduce.py @@ -1,5 +1,5 @@ import unittest -from tinygrad import Tensor, Device +from tinygrad import Tensor from tinygrad.helpers import Context from tinygrad.ops import Ops @@ -7,7 +7,7 @@ class TestRingAllReduce(unittest.TestCase): def test_schedule_ring(self): with Context(RING=2): N = 6 - ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(N)) + ds = tuple(f"CPU:{i}" for i in range(N)) t = Tensor.empty(N, N*100).shard(ds, axis=0).realize() schedules = t.sum(0).schedule_with_vars()[0] copies = [si for si in schedules if si.ast.op is Ops.COPY]