mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
fix disktensor offset issue (#3532)
This commit is contained in:
@@ -205,7 +205,6 @@ class TestDiskTensor(unittest.TestCase):
|
||||
helper_test_disk_tensor("dt5", [1,2,3,4,5], lambda x: x.reshape((1,5)))
|
||||
helper_test_disk_tensor("dt6", [1,2,3,4], lambda x: x.reshape((2,2)))
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_assign_to_different_dtype(self):
|
||||
# NOTE: this is similar to Y_train in fetch_cifar
|
||||
t = Tensor.empty(10, device=f'disk:{temp("dt7")}', dtype=dtypes.int64)
|
||||
|
||||
@@ -64,12 +64,11 @@ class DiskRunner(JITRunner):
|
||||
assert view.mask is None, "view cannot have a mask"
|
||||
assert strides_for_shape(view.shape) == view.strides, "disk tensors don't support strides"
|
||||
self.new_size = prod(view.shape)
|
||||
self.new_offset = view.offset
|
||||
self.new_offset = view.offset * top_src.arg.dtype.itemsize
|
||||
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Any, int], wait=False, jit=False):
|
||||
assert len(rawbufs) == 2
|
||||
src = rawbufs[1]._buf
|
||||
# TODO: src.dtype.itemsize or self.new_dtype.itemsize?
|
||||
rawbufs[0]._buf = DiskBuffer(src.ud, self.new_size, self.new_dtype, offset=src.offset+self.new_offset*src.dtype.itemsize)
|
||||
rawbufs[0]._buf = DiskBuffer(src.ud, self.new_size, self.new_dtype, offset=src.offset+self.new_offset)
|
||||
|
||||
class DiskDevice(Compiled):
|
||||
def __init__(self, device:str): super().__init__(device, DiskAllocator(device[len("disk:"):]), None, None)
|
||||
|
||||
Reference in New Issue
Block a user