diff --git a/tinygrad/multi.py b/tinygrad/multi.py index 848cbe4bd8..7363c14cbc 100644 --- a/tinygrad/multi.py +++ b/tinygrad/multi.py @@ -71,7 +71,7 @@ class MultiLazyBuffer: @staticmethod def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None): - lbs = [lb.contiguous() if lb.base != lb and not lb.is_unrealized_unmasked_const() else lb] * len(devices) + lbs = [lb] * len(devices) sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)] return MultiLazyBuffer([lb if lb.is_unrealized_unmasked_const() else lb.contiguous(allow_buffer_view=False) for lb in sharded_lbs], axis)