From bb0cdc2442591db41e63f27de66b7201e66efb89 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 29 Jan 2023 03:06:00 -0800 Subject: [PATCH] 111.51x speedup for reduce --- extra/kernel_search.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/extra/kernel_search.py b/extra/kernel_search.py index a5c48d8c05..b7c38f46c7 100644 --- a/extra/kernel_search.py +++ b/extra/kernel_search.py @@ -29,7 +29,8 @@ def get_interventions(k): p4.append((Interventions.SHIFT, (up_axis, amount, True))) p4.append((Interventions.SHIFT, (up_axis, amount, False))) max_up = max(st.shape[k.first_reduce] for st in k.sts) - p5 = [(Interventions.REDUCE, (max_up,))] + p5 = [] + #p5 = [(Interventions.REDUCE, (max_up,))] return p1+p2+p3+p4+p5 def apply_intervention(k, typ, dat): @@ -37,7 +38,7 @@ def apply_intervention(k, typ, dat): # swap axes a1, a2 = dat new_order = list(range(0, k.shape_len)) - new_order[a1], new_order[a2] = new_order[a2], new_order[a1] + new_order[a1], new_order[a2] = new_order[a2], new_order[a1] k.reshape_and_permute(None, new_order) elif typ == Interventions.UPCAST: if dat is not None: @@ -156,6 +157,12 @@ if __name__ == "__main__": buf2 = GPUBuffer(shape=ShapeTracker(shape=(1, 64, 128, 4, 4, 1, 1, 1, 1), views=[View((1, 64, 128, 4, 4, 1, 1, 1, 1), (0, 0, 0, 4, 1, 1, 1, 1, 1), 0)]), hostbuf=GPUBuffer(shape=(16,), force_create=True)) op2 = LazyOp(BinaryOps.ADD, (op1,buf2,), None) ast = LazyOp(MovementOps.RESHAPE, (op2,), (64, 512, 4)) + elif int(os.getenv("REDUCE", "0")): + buf0 = GPUBuffer(shape=ShapeTracker(shape=(32, 8, 112, 112), views=[View((32, 8, 112, 112), (12544, 401408, 112, 1), 0)]), hostbuf=GPUBuffer(shape=(8, 32, 112, 112), force_create=True)) + op0 = LazyOp(ReduceOps.SUM, (buf0,), (32, 1, 1, 1)) + buf1 = GPUBuffer(shape=ShapeTracker(shape=(32, 1, 1, 1), views=[View((32, 1, 1, 1), (0, 0, 0, 0), 0)]), hostbuf=GPUBuffer(shape=(1,), backing=np.array([9.964923e-06], dtype=np.float32))) + op1 = LazyOp(BinaryOps.MUL, (op0,buf1,), None) + ast = LazyOp(MovementOps.RESHAPE, (op1,), (1, 32, 1, 1)) elif int(os.getenv("BC", "0")): # big conv buf0 = GPUBuffer(shape=ShapeTracker(shape=(8, 1, 32, 112, 112, 3, 3, 3), views=[View((8, 3, 225, 225), (150528, 50176, 224, 1), 0), ZeroView((8, 3, 224, 224), ((0, 8), (0, 3), (0, 225), (0, 225))), View((8, 1, 32, 112, 112, 3, 3, 3), (151875, 151875, 0, 450, 2, 50625, 225, 1), 0)]), hostbuf=GPUBuffer(shape=(8, 3, 224, 224), force_create=True))