wmma: add reduce axis choice to TC action space (#4328)

* wmma: add reduce axis choice to TC action space

* add test for TC multi-reduce axis choice
This commit is contained in:
Francis Lam
2024-04-29 16:15:39 -07:00
committed by GitHub
parent 93abcd3113
commit a9a1fa6bbf
3 changed files with 50 additions and 22 deletions

View File

@@ -352,11 +352,11 @@ class Kernel:
axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): continue
axis_choices = list(itertools.product(axis_buf0, axis_buf1))
axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
if not(axis < len(axis_choices)): continue
s0, s1 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0] # s0 is n, s1 is m
axis_pads = [(x, tc.dims[i]) for i, x in enumerate([s0, s1, self.first_reduce]) if self.full_shape[x]%tc.dims[i] != 0]
s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
axis_pads = [(x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0]
if axis_pads and (opt_level < 2): continue
# tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
@@ -366,7 +366,7 @@ class Kernel:
try:
for axis, dim in axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
except KernelOptError: continue
self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]), append_opt=False)
self.apply_opt(Opt(OptOps.UNROLL, s2-self.first_reduce, tc.dims[2]), append_opt=False)
for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
for (tc_dim, tc_amt) in tc.threads:
@@ -379,7 +379,7 @@ class Kernel:
return False
def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, tc_opt:Optional[int]=getenv("TC_OPT")) -> bool:
def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, axis:int=0, tc_opt:int=getenv("TC_OPT")) -> bool:
""" Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).
@@ -396,7 +396,7 @@ class Kernel:
"""
if not self.opts.has_tensor_cores and use_tensor_cores != 2: return False
try: # check TC first and apply hand-coded opts if successful
self.apply_opt(Opt(OptOps.TC, 0, tc_opt))
self.apply_opt(Opt(OptOps.TC, axis, tc_opt))
if (tc_opts:=self.tensor_core_opts) is not None:
if extra_opts is not None: