diff --git a/extra/thunder/tiny/tk/group.py b/extra/thunder/tiny/tk/group.py index d1a853a1dd..0f2984592e 100644 --- a/extra/thunder/tiny/tk/group.py +++ b/extra/thunder/tiny/tk/group.py @@ -387,7 +387,7 @@ class Group: idxs = tuple(idx * rv.length if i == 3 else idx for i, idx in enumerate(idxs)) src_i = ((idxs[0] * src.shape[-3] + idxs[1]) * src.shape[-2] + idxs[2]) * src.shape[-1] + idxs[3] - for outer in self.ker.range(dst.shape[-2]): + for outer in self.ker.range(dst.shape[-2], track=False): src_i += outer * reductions + (laneid % reductions) src_load = srcf[src_i] diff --git a/extra/thunder/tiny/tk/kernel.py b/extra/thunder/tiny/tk/kernel.py index 76658dceb4..6541f326a9 100644 --- a/extra/thunder/tiny/tk/kernel.py +++ b/extra/thunder/tiny/tk/kernel.py @@ -89,7 +89,7 @@ class Kernel(AbstractContextManager): # end stores stores store_uops = [] - for _i in range(stores): + for _ in range(stores): store = self.store_stack.pop()[0] if hasattr(store, '_uop'): store_uops.append(store._uop) else: store_uops.append(store) @@ -97,7 +97,12 @@ class Kernel(AbstractContextManager): return uop.end(*rngs).sink(arg=KernelInfo(name=self.name, opts_to_apply=())).simplify() - def endrange(self): + def endrange(self, ranges:int=1): last_store = self.store_stack.pop() - last_range = self.range_stack.pop() - return last_store[1].after(last_store[0].end(last_range._rng)).reshape(last_store[1].shape) + + rngs = [] + for _ in range(ranges): + last_range = self.range_stack.pop() + rngs.append(last_range._rng) + + return last_store[1].after(last_store[0].end(*rngs)).reshape(last_store[1].shape)