mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
more crap to remove without convs
This commit is contained in:
@@ -4,7 +4,7 @@ import cProfile
|
||||
import pstats
|
||||
import unittest
|
||||
import torch
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
|
||||
def start_profile():
|
||||
import time
|
||||
@@ -19,6 +19,7 @@ def stop_profile(pr, sort='cumtime'):
|
||||
ps.sort_stats(sort)
|
||||
ps.print_stats(0.2)
|
||||
|
||||
@unittest.skipUnless(getattr(Device, "OPENCL", None) is None or Device.DEFAULT != Device.OPENCL, "OOM on OpenCL")
|
||||
class TestConvSpeed(unittest.TestCase):
|
||||
|
||||
def test_mnist(self):
|
||||
|
||||
@@ -115,13 +115,12 @@ class GPUBuffer:
|
||||
return type(x)(new_shape)._processing_op([("A", x)], GPUBuffer.code_for_op[op], None, GPUBuffer.start_for_op[op])
|
||||
|
||||
def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, start="0.0") -> GPUBuffer:
|
||||
ints, params, ewbufs, conv_src = '', [], bufs, ''
|
||||
global_size = [prod(ret.shape), 1, 1]
|
||||
loop : List[Tuple[str, str]] = []
|
||||
assert C is None
|
||||
|
||||
# this takes a ret index to an inp index, indexing 0 on the reduced strides
|
||||
# if it's not a reduce, this should be a NOOP
|
||||
view = View(ret.shape, strides_for_shape(bufs[0][1].shape))
|
||||
assert C is None
|
||||
loop : List[Tuple[str, str]] = []
|
||||
if ret.shape != bufs[0][1].shape: # this is a reduce
|
||||
# reverse operation of expand, this validates inputs
|
||||
# generate loops with combined adjacent reduce axis
|
||||
@@ -131,16 +130,16 @@ class GPUBuffer:
|
||||
acc *= shp
|
||||
|
||||
kernel_name = "reduce" if len(loop) > 0 else "elementwise"
|
||||
views = {name:buf.contiguous_view_constant_fold(name) for name, buf in ewbufs}
|
||||
views = {name:buf.contiguous_view_constant_fold(name) for name, buf in bufs}
|
||||
buf_types = [f"__global const float *{name}_g" for name, _ in bufs if name not in views or views[name][1]]
|
||||
conv_prg = CLProgram(kernel_name, f"""{chr(10).join([x[0] for x in views.values()])}
|
||||
__kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + [x[0] for x in params])}) {{ {ints}
|
||||
float acc = {start}; int gid = get_global_id(0); {conv_src} int idx = gid; {view.expr.replace('//', '/')};
|
||||
__kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types)}) {{
|
||||
float acc = {start}; int gid = get_global_id(0); int idx = gid; {view.expr.replace('//', '/')};
|
||||
{' '.join([ls for ls, _ in loop[::-1]])}
|
||||
{chr(10).join([f' float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in ewbufs])}
|
||||
{chr(10).join([f' float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in bufs])}
|
||||
acc = {code};
|
||||
{' '.join([le for _, le in loop])}
|
||||
output[gid] = acc;
|
||||
}}""", argdtypes=tuple(None if i < 1+len(buf_types) else np.int32 for i in range(1+len(buf_types)+len(params))))
|
||||
conv_prg(global_size, None, ret.cl, *[buf.cl for name, buf in bufs if name not in views or views[name][1]], *[x[1] for x in params])
|
||||
}}""", argdtypes=tuple(None if i < 1+len(buf_types) else np.int32 for i in range(1+len(buf_types))))
|
||||
conv_prg([prod(ret.shape), 1, 1], None, ret.cl, *[buf.cl for name, buf in bufs if name not in views or views[name][1]])
|
||||
return ret
|
||||
|
||||
Reference in New Issue
Block a user