mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
111.51x speedup for reduce
This commit is contained in:
@@ -29,7 +29,8 @@ def get_interventions(k):
|
||||
p4.append((Interventions.SHIFT, (up_axis, amount, True)))
|
||||
p4.append((Interventions.SHIFT, (up_axis, amount, False)))
|
||||
max_up = max(st.shape[k.first_reduce] for st in k.sts)
|
||||
p5 = [(Interventions.REDUCE, (max_up,))]
|
||||
p5 = []
|
||||
#p5 = [(Interventions.REDUCE, (max_up,))]
|
||||
return p1+p2+p3+p4+p5
|
||||
|
||||
def apply_intervention(k, typ, dat):
|
||||
@@ -37,7 +38,7 @@ def apply_intervention(k, typ, dat):
|
||||
# swap axes
|
||||
a1, a2 = dat
|
||||
new_order = list(range(0, k.shape_len))
|
||||
new_order[a1], new_order[a2] = new_order[a2], new_order[a1]
|
||||
new_order[a1], new_order[a2] = new_order[a2], new_order[a1]
|
||||
k.reshape_and_permute(None, new_order)
|
||||
elif typ == Interventions.UPCAST:
|
||||
if dat is not None:
|
||||
@@ -156,6 +157,12 @@ if __name__ == "__main__":
|
||||
buf2 = GPUBuffer(shape=ShapeTracker(shape=(1, 64, 128, 4, 4, 1, 1, 1, 1), views=[View((1, 64, 128, 4, 4, 1, 1, 1, 1), (0, 0, 0, 4, 1, 1, 1, 1, 1), 0)]), hostbuf=GPUBuffer(shape=(16,), force_create=True))
|
||||
op2 = LazyOp(BinaryOps.ADD, (op1,buf2,), None)
|
||||
ast = LazyOp(MovementOps.RESHAPE, (op2,), (64, 512, 4))
|
||||
elif int(os.getenv("REDUCE", "0")):
|
||||
buf0 = GPUBuffer(shape=ShapeTracker(shape=(32, 8, 112, 112), views=[View((32, 8, 112, 112), (12544, 401408, 112, 1), 0)]), hostbuf=GPUBuffer(shape=(8, 32, 112, 112), force_create=True))
|
||||
op0 = LazyOp(ReduceOps.SUM, (buf0,), (32, 1, 1, 1))
|
||||
buf1 = GPUBuffer(shape=ShapeTracker(shape=(32, 1, 1, 1), views=[View((32, 1, 1, 1), (0, 0, 0, 0), 0)]), hostbuf=GPUBuffer(shape=(1,), backing=np.array([9.964923e-06], dtype=np.float32)))
|
||||
op1 = LazyOp(BinaryOps.MUL, (op0,buf1,), None)
|
||||
ast = LazyOp(MovementOps.RESHAPE, (op1,), (1, 32, 1, 1))
|
||||
elif int(os.getenv("BC", "0")):
|
||||
# big conv
|
||||
buf0 = GPUBuffer(shape=ShapeTracker(shape=(8, 1, 32, 112, 112, 3, 3, 3), views=[View((8, 3, 225, 225), (150528, 50176, 224, 1), 0), ZeroView((8, 3, 224, 224), ((0, 8), (0, 3), (0, 225), (0, 225))), View((8, 1, 32, 112, 112, 3, 3, 3), (151875, 151875, 0, 450, 2, 50625, 225, 1), 0)]), hostbuf=GPUBuffer(shape=(8, 3, 224, 224), force_create=True))
|
||||
|
||||
Reference in New Issue
Block a user