remove first contiguous in multi from_sharded (#5121)

second contiguous guarantees lbs are contiguous going into MultiLazyBuffer, don't need the first contiguous
This commit is contained in:
chenyu
2024-07-03 19:42:56 -04:00
committed by GitHub
parent f1ff65e763
commit e5ba385f03

View File

@@ -71,7 +71,7 @@ class MultiLazyBuffer:
@staticmethod
def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None):
lbs = [lb.contiguous() if lb.base != lb and not lb.is_unrealized_unmasked_const() else lb] * len(devices)
lbs = [lb] * len(devices)
sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)]
return MultiLazyBuffer([lb if lb.is_unrealized_unmasked_const() else lb.contiguous(allow_buffer_view=False) for lb in sharded_lbs], axis)