mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
good kernel with changes in lowerer
This commit is contained in:
@@ -3,7 +3,7 @@ from tinygrad.helpers import getenv
|
||||
from tinygrad import dtypes, Tensor
|
||||
dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
|
||||
acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
|
||||
N = getenv("N", 4096)
|
||||
N = getenv("N", 64)
|
||||
M = getenv("M", N)
|
||||
K = getenv("K", N)
|
||||
CNT = getenv("CNT", 10)
|
||||
|
||||
@@ -694,7 +694,7 @@ class Kernel:
|
||||
fix_st2 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((0,0), (1,1), (1,2), (0,2), (1,0)), ((0,1), (0,3), (1,3)))
|
||||
elif self.opts.device == "CLANG":
|
||||
reduce_axes, fix_st1, fix_st2 = [], None, None
|
||||
upcast_axis = (self.shape_len-self.upcasted+1, self.shape_len-self.upcasted, self.shape_len-self.upcasted)
|
||||
upcast_axis = (self.shape_len-self.upcasted+1, self.shape_len-self.upcasted, self.shape_len-self.upcasted+1)
|
||||
elif self.opts.device in {"CUDA", "NV"}:
|
||||
reduce_axes = [self.shape_len-self.upcasted, self.shape_len-self.upcasted+1]
|
||||
upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted+2, self.shape_len-self.upcasted+2)
|
||||
|
||||
@@ -175,7 +175,9 @@ class IndependentLowerer:
|
||||
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=(upcast_axis[0],)),
|
||||
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=(upcast_axis[1],)),
|
||||
UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
|
||||
return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=((upcast_axis[2], wmma_sz[2]),))
|
||||
|
||||
con = tuple(UOp(UOps.CONTRACT, dtype.vec(4), tuple([UOp(UOps.GEP, dtype, (ret,), i+j*4) for i in range(4)])) for j in range(4))
|
||||
return UOp(UOps.EXPAND, dtype, con, arg=((upcast_axis[2], wmma_sz[2]//4),))
|
||||
# NOTE: always using ridxs is fine here
|
||||
return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op)
|
||||
return UOp.alu(x.op, *in_uops)
|
||||
|
||||
Reference in New Issue
Block a user