mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
wmma: add reduce axis choice to TC action space (#4328)
* wmma: add reduce axis choice to TC action space * add test for TC multi-reduce axis choice
This commit is contained in:
@@ -352,11 +352,11 @@ class Kernel:
|
||||
axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
|
||||
if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): continue
|
||||
|
||||
axis_choices = list(itertools.product(axis_buf0, axis_buf1))
|
||||
axis_choices = list(itertools.product(axis_buf0, axis_buf1, range(self.first_reduce, self.shape_len)))
|
||||
if not(axis < len(axis_choices)): continue
|
||||
|
||||
s0, s1 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0] # s0 is n, s1 is m
|
||||
axis_pads = [(x, tc.dims[i]) for i, x in enumerate([s0, s1, self.first_reduce]) if self.full_shape[x]%tc.dims[i] != 0]
|
||||
s0, s1, s2 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0], axis_choices[-(axis+1)][2] # s0 is n, s1 is m, s2 is k
|
||||
axis_pads = [(x, tc.dims[i]) for i, x in enumerate([s0, s1, s2]) if self.full_shape[x]%tc.dims[i] != 0]
|
||||
if axis_pads and (opt_level < 2): continue
|
||||
|
||||
# tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
|
||||
@@ -366,7 +366,7 @@ class Kernel:
|
||||
try:
|
||||
for axis, dim in axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
|
||||
except KernelOptError: continue
|
||||
self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]), append_opt=False)
|
||||
self.apply_opt(Opt(OptOps.UNROLL, s2-self.first_reduce, tc.dims[2]), append_opt=False)
|
||||
for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
|
||||
if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
|
||||
for (tc_dim, tc_amt) in tc.threads:
|
||||
@@ -379,7 +379,7 @@ class Kernel:
|
||||
return False
|
||||
|
||||
|
||||
def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, tc_opt:Optional[int]=getenv("TC_OPT")) -> bool:
|
||||
def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, axis:int=0, tc_opt:int=getenv("TC_OPT")) -> bool:
|
||||
""" Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
|
||||
Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).
|
||||
|
||||
@@ -396,7 +396,7 @@ class Kernel:
|
||||
"""
|
||||
if not self.opts.has_tensor_cores and use_tensor_cores != 2: return False
|
||||
try: # check TC first and apply hand-coded opts if successful
|
||||
self.apply_opt(Opt(OptOps.TC, 0, tc_opt))
|
||||
self.apply_opt(Opt(OptOps.TC, axis, tc_opt))
|
||||
|
||||
if (tc_opts:=self.tensor_core_opts) is not None:
|
||||
if extra_opts is not None:
|
||||
|
||||
Reference in New Issue
Block a user