mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-08 05:54:59 +08:00
remove unused device arg from _get_winograd_matcols (#16527)
This commit is contained in:
@@ -60,7 +60,7 @@ def _frompy(x:list|tuple|bytes, dtype:DType, device:str|tuple[str,...]) -> UOp:
|
||||
ret.buffer.allocate(memoryview(data if device != "PYTHON" else bytearray(data)))
|
||||
return ret
|
||||
|
||||
def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], device:str|tuple[str, ...]|None, dtype:DType) -> list[list[Tensor]]:
|
||||
def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], dtype:DType) -> list[list[Tensor]]:
|
||||
return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), dtype=dtype, buffer=False) for m in mat], dim=dim)
|
||||
for k in range(len(mat[0]))] for dim in range(dims)]
|
||||
|
||||
@@ -70,7 +70,7 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
|
||||
# due to realize-before-expand rule in lazy.py, we must operate in this order: reshape -> expand -> arithmetic
|
||||
t_ = t.reshape(t.shape[:dims] + (1,) * dims + t.shape[dims:]).expand(t.shape[:dims] + (len(mat),) * dims + t.shape[dims:]) # add output dims
|
||||
# precalculate mat columns for each dim; prod(itertools.product(matcols)) gives the columns of kron(mat, mat, ...)
|
||||
matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device, t_.dtype)
|
||||
matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.dtype)
|
||||
# multiply each element of t_ by the corresponding stacked column of kron(mat, mat), producing only one view for each element of t
|
||||
ret = sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims))
|
||||
assert isinstance(ret, Tensor), "sum didn't return a Tensor"
|
||||
|
||||
Reference in New Issue
Block a user