more crap to remove without convs

This commit is contained in:
George Hotz
2022-07-17 13:02:27 -07:00
parent 5e96ed523a
commit cfabbbd6bb
2 changed files with 11 additions and 11 deletions

View File

@@ -4,7 +4,7 @@ import cProfile
import pstats
import unittest
import torch
from tinygrad.tensor import Tensor
from tinygrad.tensor import Tensor, Device
def start_profile():
import time
@@ -19,6 +19,7 @@ def stop_profile(pr, sort='cumtime'):
ps.sort_stats(sort)
ps.print_stats(0.2)
@unittest.skipUnless(getattr(Device, "OPENCL", None) is None or Device.DEFAULT != Device.OPENCL, "OOM on OpenCL")
class TestConvSpeed(unittest.TestCase):
def test_mnist(self):

View File

@@ -115,13 +115,12 @@ class GPUBuffer:
return type(x)(new_shape)._processing_op([("A", x)], GPUBuffer.code_for_op[op], None, GPUBuffer.start_for_op[op])
def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, start="0.0") -> GPUBuffer:
ints, params, ewbufs, conv_src = '', [], bufs, ''
global_size = [prod(ret.shape), 1, 1]
loop : List[Tuple[str, str]] = []
assert C is None
# this takes a ret index to an inp index, indexing 0 on the reduced strides
# if it's not a reduce, this should be a NOOP
view = View(ret.shape, strides_for_shape(bufs[0][1].shape))
assert C is None
loop : List[Tuple[str, str]] = []
if ret.shape != bufs[0][1].shape: # this is a reduce
# reverse operation of expand, this validates inputs
# generate loops with combined adjacent reduce axis
@@ -131,16 +130,16 @@ class GPUBuffer:
acc *= shp
kernel_name = "reduce" if len(loop) > 0 else "elementwise"
views = {name:buf.contiguous_view_constant_fold(name) for name, buf in ewbufs}
views = {name:buf.contiguous_view_constant_fold(name) for name, buf in bufs}
buf_types = [f"__global const float *{name}_g" for name, _ in bufs if name not in views or views[name][1]]
conv_prg = CLProgram(kernel_name, f"""{chr(10).join([x[0] for x in views.values()])}
__kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + [x[0] for x in params])}) {{ {ints}
float acc = {start}; int gid = get_global_id(0); {conv_src} int idx = gid; {view.expr.replace('//', '/')};
__kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types)}) {{
float acc = {start}; int gid = get_global_id(0); int idx = gid; {view.expr.replace('//', '/')};
{' '.join([ls for ls, _ in loop[::-1]])}
{chr(10).join([f' float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in ewbufs])}
{chr(10).join([f' float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in bufs])}
acc = {code};
{' '.join([le for _, le in loop])}
output[gid] = acc;
}}""", argdtypes=tuple(None if i < 1+len(buf_types) else np.int32 for i in range(1+len(buf_types)+len(params))))
conv_prg(global_size, None, ret.cl, *[buf.cl for name, buf in bufs if name not in views or views[name][1]], *[x[1] for x in params])
}}""", argdtypes=tuple(None if i < 1+len(buf_types) else np.int32 for i in range(1+len(buf_types))))
conv_prg([prod(ret.shape), 1, 1], None, ret.cl, *[buf.cl for name, buf in bufs if name not in views or views[name][1]])
return ret