111.51x speedup for reduce

This commit is contained in:
George Hotz
2023-01-29 03:06:00 -08:00
parent 45c0aa6e2d
commit bb0cdc2442

View File

@@ -29,7 +29,8 @@ def get_interventions(k):
p4.append((Interventions.SHIFT, (up_axis, amount, True)))
p4.append((Interventions.SHIFT, (up_axis, amount, False)))
max_up = max(st.shape[k.first_reduce] for st in k.sts)
p5 = [(Interventions.REDUCE, (max_up,))]
p5 = []
#p5 = [(Interventions.REDUCE, (max_up,))]
return p1+p2+p3+p4+p5
def apply_intervention(k, typ, dat):
@@ -37,7 +38,7 @@ def apply_intervention(k, typ, dat):
# swap axes
a1, a2 = dat
new_order = list(range(0, k.shape_len))
new_order[a1], new_order[a2] = new_order[a2], new_order[a1]
new_order[a1], new_order[a2] = new_order[a2], new_order[a1]
k.reshape_and_permute(None, new_order)
elif typ == Interventions.UPCAST:
if dat is not None:
@@ -156,6 +157,12 @@ if __name__ == "__main__":
buf2 = GPUBuffer(shape=ShapeTracker(shape=(1, 64, 128, 4, 4, 1, 1, 1, 1), views=[View((1, 64, 128, 4, 4, 1, 1, 1, 1), (0, 0, 0, 4, 1, 1, 1, 1, 1), 0)]), hostbuf=GPUBuffer(shape=(16,), force_create=True))
op2 = LazyOp(BinaryOps.ADD, (op1,buf2,), None)
ast = LazyOp(MovementOps.RESHAPE, (op2,), (64, 512, 4))
elif int(os.getenv("REDUCE", "0")):
buf0 = GPUBuffer(shape=ShapeTracker(shape=(32, 8, 112, 112), views=[View((32, 8, 112, 112), (12544, 401408, 112, 1), 0)]), hostbuf=GPUBuffer(shape=(8, 32, 112, 112), force_create=True))
op0 = LazyOp(ReduceOps.SUM, (buf0,), (32, 1, 1, 1))
buf1 = GPUBuffer(shape=ShapeTracker(shape=(32, 1, 1, 1), views=[View((32, 1, 1, 1), (0, 0, 0, 0), 0)]), hostbuf=GPUBuffer(shape=(1,), backing=np.array([9.964923e-06], dtype=np.float32)))
op1 = LazyOp(BinaryOps.MUL, (op0,buf1,), None)
ast = LazyOp(MovementOps.RESHAPE, (op1,), (1, 32, 1, 1))
elif int(os.getenv("BC", "0")):
# big conv
buf0 = GPUBuffer(shape=ShapeTracker(shape=(8, 1, 32, 112, 112, 3, 3, 3), views=[View((8, 3, 225, 225), (150528, 50176, 224, 1), 0), ZeroView((8, 3, 224, 224), ((0, 8), (0, 3), (0, 225), (0, 225))), View((8, 1, 32, 112, 112, 3, 3, 3), (151875, 151875, 0, 450, 2, 50625, 225, 1), 0)]), hostbuf=GPUBuffer(shape=(8, 3, 224, 224), force_create=True))