# Patch summary (two files touched; all statements below are read off the +/- lines):
#
# * examples/gpt2.py, hunk located under `class Attention`:
#     the saved cache changes from `cache_k, cache_v = keys, values`
#     to `cache_k, cache_v = keys.realize(), values.realize()`.
#     The following transpose / scaled_dot_product_attention / c_proj lines are
#     unchanged context.
#
# * tinygrad/features/search.py, module-level `actions` list:
#     - OptOps.LOCAL amounts grow from [2,3,4,8,16] to [2,3,4,8,13,16,29]
#       (still generated for axis in range(5)).
#     - A generated OptOps.GROUPTOP family is added: amt in [13,16,29,32,256]
#       for axis in range(3).
#     - The six hand-written GROUPTOP entries (axis 0..2, amt 16 and 256) are
#       removed from the trailing literal list; the (axis, amt) pairs they listed
#       are all produced by the generated family above.
#     - The LOCAL(axis=0, amt=32), GROUP and UPCASTMID entries are unchanged.
#
# NOTE(review): the patch text below has lost its line breaks -- headers, hunk
# markers and hunk bodies all sit on one physical line -- so presumably it cannot
# be applied as-is. Restore the original newlines and indentation from the
# upstream commit rather than guessing at them.
diff --git a/examples/gpt2.py b/examples/gpt2.py index ade7e23393..da1c80bc1c 100644 --- a/examples/gpt2.py +++ b/examples/gpt2.py @@ -50,7 +50,7 @@ class Attention: keys, values = cache_k.cat(xk, dim=1), cache_v.cat(xv, dim=1) # save the cache - cache_k, cache_v = keys, values + cache_k, cache_v = keys.realize(), values.realize() xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2) output = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1) return self.c_proj(output), cache_k, cache_v diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py index 503a0ae203..d50e845380 100644 --- a/tinygrad/features/search.py +++ b/tinygrad/features/search.py @@ -9,13 +9,11 @@ from collections import defaultdict from tinygrad.codegen.optimizer import Opt, OptOps actions = flatten([[Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,7]] for axis in range(6)]) actions += flatten([[Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4]] for axis in range(4)]) -actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,16]] for axis in range(5)]) +actions += flatten([[Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29]] for axis in range(5)]) +actions += flatten([[Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,256]] for axis in range(3)]) actions += [ Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8), - Opt(op=OptOps.GROUPTOP, axis=0, amt=16), Opt(op=OptOps.GROUPTOP, axis=0, amt=256), - Opt(op=OptOps.GROUPTOP, axis=1, amt=16), Opt(op=OptOps.GROUPTOP, axis=1, amt=256), - Opt(op=OptOps.GROUPTOP, axis=2, amt=16), Opt(op=OptOps.GROUPTOP, axis=2, amt=256), Opt(op=OptOps.UPCASTMID, axis=1, amt=4), ]