good kernel with changes in lowerer

This commit is contained in:
p4sscode
2024-07-22 13:38:08 -03:00
parent 3b9c628828
commit 975e2b5a4e
4 changed files with 11 additions and 3 deletions

6
.env Normal file
View File

@@ -0,0 +1,6 @@
DEBUG=5
GRAPH=1
CLANG=1
AMX=1
GRAPHUOPS=1
# PYTHON=1

View File

@@ -3,7 +3,7 @@ from tinygrad.helpers import getenv
from tinygrad import dtypes, Tensor
dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
N = getenv("N", 4096)
N = getenv("N", 64)
M = getenv("M", N)
K = getenv("K", N)
CNT = getenv("CNT", 10)

View File

@@ -694,7 +694,7 @@ class Kernel:
fix_st2 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((0,0), (1,1), (1,2), (0,2), (1,0)), ((0,1), (0,3), (1,3)))
elif self.opts.device == "CLANG":
reduce_axes, fix_st1, fix_st2 = [], None, None
upcast_axis = (self.shape_len-self.upcasted+1, self.shape_len-self.upcasted, self.shape_len-self.upcasted)
upcast_axis = (self.shape_len-self.upcasted+1, self.shape_len-self.upcasted, self.shape_len-self.upcasted+1)
elif self.opts.device in {"CUDA", "NV"}:
reduce_axes = [self.shape_len-self.upcasted, self.shape_len-self.upcasted+1]
upcast_axis = (self.shape_len-self.upcasted, self.shape_len-self.upcasted+2, self.shape_len-self.upcasted+2)

View File

@@ -175,7 +175,9 @@ class IndependentLowerer:
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=(upcast_axis[0],)),
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=(upcast_axis[1],)),
UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=((upcast_axis[2], wmma_sz[2]),))
con = tuple(UOp(UOps.CONTRACT, dtype.vec(4), tuple([UOp(UOps.GEP, dtype, (ret,), i+j*4) for i in range(4)])) for j in range(4))
return UOp(UOps.EXPAND, dtype, con, arg=((upcast_axis[2], wmma_sz[2]//4),))
# NOTE: always using ridxs is fine here
return UOp(UOps.REDUCE, dtype, (in_uops[0],) + tuple(self.ridxs[i] for i in x.arg), x.op)
return UOp.alu(x.op, *in_uops)