diff --git a/extra/kernel_search.py b/extra/kernel_search.py index d6c859e770..1a5c4f6b21 100644 --- a/extra/kernel_search.py +++ b/extra/kernel_search.py @@ -44,9 +44,9 @@ def apply_intervention(k, typ, *dat): def search(ast): # get baseline k = CLASTKernel(ast) - CL.time_sum = 0 - #k.hand_coded_optimizations() - k.codegen()(*k.bufs) + for i in range(3): + CL.time_sum = 0 + k.codegen()(*k.bufs) winning_interventions = [] best_time = baseline = CL.time_sum @@ -68,7 +68,7 @@ def search(ast): best_time = CL.time_sum winning_interventions.append(inter) - for i in range(100): + for i in range(200): try: test() except Exception as e: diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index a3553d05fc..a62911be46 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -305,7 +305,7 @@ class CLASTKernel(ASTKernel): # compile kernel self.fxn = CLProgram(function_name, ' '.join(self.kernel), op_estimate=self.info.flops) - mem_estimate = sum(prod(x.shape) for x in self.sts) + mem_estimate = sum(prod(x._base_shape) for x in self.bufs) if DEBUG >= 3 and len(self.bufs_to_delete): print(f"deleting buffers {self.bufs_to_delete}") def runner(*bufs):