mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
oops, broke mem estimates
This commit is contained in:
@@ -44,9 +44,9 @@ def apply_intervention(k, typ, *dat):
|
||||
def search(ast):
|
||||
# get baseline
|
||||
k = CLASTKernel(ast)
|
||||
CL.time_sum = 0
|
||||
#k.hand_coded_optimizations()
|
||||
k.codegen()(*k.bufs)
|
||||
for i in range(3):
|
||||
CL.time_sum = 0
|
||||
k.codegen()(*k.bufs)
|
||||
|
||||
winning_interventions = []
|
||||
best_time = baseline = CL.time_sum
|
||||
@@ -68,7 +68,7 @@ def search(ast):
|
||||
best_time = CL.time_sum
|
||||
winning_interventions.append(inter)
|
||||
|
||||
for i in range(100):
|
||||
for i in range(200):
|
||||
try:
|
||||
test()
|
||||
except Exception as e:
|
||||
|
||||
@@ -305,7 +305,7 @@ class CLASTKernel(ASTKernel):
|
||||
|
||||
# compile kernel
|
||||
self.fxn = CLProgram(function_name, ' '.join(self.kernel), op_estimate=self.info.flops)
|
||||
mem_estimate = sum(prod(x.shape) for x in self.sts)
|
||||
mem_estimate = sum(prod(x._base_shape) for x in self.bufs)
|
||||
|
||||
if DEBUG >= 3 and len(self.bufs_to_delete): print(f"deleting buffers {self.bufs_to_delete}")
|
||||
def runner(*bufs):
|
||||
|
||||
Reference in New Issue
Block a user