Simple CUDA Runtime (#480)

* factor out opencl runtime

* don't use CL outside the runtime

* cuda runtime adds

* final_dimension

* tests pass with CUDA backend

* more cuda

* cuda simpler

* retain old functionality

* linter and typing

* move globalcounters out of runtimes

* oops, GlobalCounters in cuda

* MAX_OUTPUT_SHAPE=3 is fine for CUDA
This commit is contained in:
George Hotz
2023-01-27 16:26:24 -08:00
committed by GitHub
parent 6d5e1a8029
commit bd8a5c2ced
5 changed files with 137 additions and 100 deletions

View File

@@ -24,6 +24,7 @@ setup(name='tinygrad',
extras_require={
'gpu': ["pyopencl", "six"],
'llvm': ["llvmlite"],
'cuda': ["pycuda"],
'testing': [
"pytest",
"torch~=1.11.0",

View File

@@ -54,6 +54,8 @@ class TestOps(unittest.TestCase):
def test_add(self):
  # Runs helper_test_op with x+y as the reference function and Tensor.add as the
  # function under test, on two inputs that both have shape (45, 65).
  shapes = [(45,65), (45,65)]
  helper_test_op(shapes, lambda x,y: x+y, Tensor.add)
def test_add_simple(self):
  # Forward-only run of x+y against Tensor.add on two inputs of shape (256,).
  # NOTE: `(256)` is just the int 256 -- parentheses alone do not make a tuple.
  # The trailing comma turns each entry into a real one-element shape, matching
  # the tuple shapes every sibling test hands to helper_test_op.
  helper_test_op([(256,), (256,)], lambda x,y: x+y, Tensor.add, forward_only=True)
def test_broadcasted_add(self):
  # Same x+y lambda is used for both the reference and the tested function; the
  # second input has a size-1 last axis, so the add has to broadcast (45,1) -> (45,65).
  shapes = [(45,65), (45,1)]
  helper_test_op(shapes, lambda x,y: x+y, lambda x,y: x+y)
def test_broadcasted_add_2(self):

View File

@@ -1,8 +1,6 @@
from __future__ import annotations
import os, functools, time, platform
import os
import numpy as np
import pyopencl as cl # type: ignore
from collections import defaultdict
from typing import List, Tuple, Optional, Dict, Union, Set
from tinygrad.helpers import prod
from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, LazyOp, Op, ExplicitExecAST, GlobalCounters
@@ -11,97 +9,16 @@ from tinygrad.lazy import IMAGE
from tinygrad.shape import ShapeTracker, View, ZeroView
from tinygrad.shape.symbolic import Variable, ModNode
# True when platform.system() reports "Darwin".
OSX = platform.system() == "Darwin"
# Runtime selector: CUDA=0 (the default) imports the OpenCL wrappers, any other
# integer value imports the CUDA wrappers under the same class names.
CUDA = int(os.getenv("CUDA", "0"))
if not CUDA: from tinygrad.runtime.opencl import CLBuffer, CLImage, CLProgram, CL # NOTE: using CL will not work for the CUDA runtime # noqa: F401
else: from tinygrad.runtime.cuda import CLBuffer, CLImage, CLProgram # type: ignore
# The remaining switches are read once, at import time, from environment
# variables of the same name; all but PRINT_AST are parsed as integers.
VALIDHACKS = int(os.getenv("VALIDHACKS", "0")) # TODO: remove the need for this
NATIVE_EXPLOG = int(os.getenv("NATIVE_EXPLOG", "0")) # this is needed as a switch for the tests to pass
CLCACHE = int(os.getenv("CLCACHE", "1"))
FLOAT16 = int(os.getenv("FLOAT16", "0"))
# PRINT_AST is kept as the raw string (default "0"), not converted to int.
PRINT_AST = os.getenv("PRINT_AST", "0")
TEST_AST = int(os.getenv("TEST_AST", "0"))
class CLBuffer:
  """OpenCL buffer wrapper that recycles released buffers through CL.BUFFER_CACHE.

  The underlying pyopencl buffer is stored in `.cl`. CL.mem_used is only
  increased for freshly allocated buffers and only decreased when a buffer is
  dropped while CLCACHE is disabled.
  """
  def __init__(self, size):
    recycled = CL.BUFFER_CACHE[size]
    if recycled:
      # a buffer of exactly this size was released earlier: take it back
      self.cl = recycled.pop()
      return
    # TODO: on GPU OOM, clear the cache
    self.cl = cl.Buffer(CL().cl_ctx, cl.mem_flags.READ_WRITE, size)
    CL.mem_used += self.cl.size

  def __del__(self):
    nbytes = self.cl.size
    if not CLCACHE:
      CL.mem_used -= nbytes
      return
    # keep the allocation around for the next CLBuffer of the same size
    CL.BUFFER_CACHE[nbytes].append(self.cl)
class CLImage:
  """RGBA OpenCL image wrapper; channels are HALF_FLOAT when FLOAT16 is set, FLOAT otherwise.

  The pyopencl image is stored in `.cl`, and its footprint
  (row_pitch * height) is tracked in CL.mem_used for the object's lifetime.
  """
  fmt = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT)

  def __init__(self, shape):
    ctx = CL().cl_ctx
    # cl.Image is given the first two entries of `shape` in swapped order
    dims = (shape[1], shape[0])
    self.cl = cl.Image(ctx, cl.mem_flags.READ_WRITE, CLImage.fmt, shape=dims)
    image = self.cl
    CL.mem_used += image.row_pitch * image.height

  def __del__(self):
    image = self.cl
    CL.mem_used -= image.row_pitch * image.height
class CL:
  """Class-level OpenCL state: context, command queue, buffer cache and debug counters.

  All state lives on the class itself, so `CL()` only does work the first time
  it is called; later calls return immediately.
  """
  CACHE, kernel_count, mem_used, time_sum, ops_sum = None, -1, 0, 0.0, 0.0
  BUFFER_CACHE : Dict[int, List[cl.Buffer]] = defaultdict(list)
  cl_ctx : Optional[cl.Context] = None
  cl_queue : Optional[cl.CommandQueue] = None

  def __init__(self):
    if CL.cl_queue is None:  # not initted yet
      candidates = [dev for plat in cl.get_platforms() for dev in plat.get_devices(device_type=cl.device_type.GPU)]
      if not candidates:  # settle for CPU
        candidates = [dev for plat in cl.get_platforms() for dev in plat.get_devices(device_type=cl.device_type.CPU)]
      # CL_DEVICE picks the index into the discovered device list (default 0)
      CL.cl_ctx = cl.Context(devices=[candidates[int(os.getenv("CL_DEVICE", "0"))]])
      if len(candidates) > 1 or DEBUG >= 1:
        print(f"using {CL.cl_ctx.devices}")
      # this is an in-order command queue
      CL.cl_queue = cl.CommandQueue(self.cl_ctx, properties=cl.command_queue_properties.PROFILING_ENABLE)

  @staticmethod
  def enqueue_copy(a, b, is_blocking=False):
    # copies are refused while kernel launches are being recorded into CL.CACHE
    assert CL.CACHE is None, f"can't copy {a} -> {b} while caching"
    if DEBUG >= 1:
      if isinstance(b, np.ndarray):
        print(f"**CL** copy in {b.shape}")
      else:
        print(f"**CL** copy OUT {a.shape}")
    cl.enqueue_copy(CL().cl_queue, a, b, is_blocking=is_blocking)
@functools.lru_cache(maxsize=None)
class CLProgram:
  """A built OpenCL kernel that can be launched (or recorded into CL.CACHE) by calling it.

  The class is wrapped in functools.lru_cache, so constructing it again with
  the same arguments returns the already-built instance instead of rebuilding.

  Args:
    name: kernel function name as it appears in `prg`.
    prg: program source (or binary when `binary` is true).
    options: build options forwarded to `cl.Program.build`.
    argdtypes: if not None, passed to `set_scalar_arg_dtypes` on the kernel.
    rename: when true, later programs sharing `name` get a `_N<count>` suffix
      and `name(` is rewritten accordingly inside `prg`.
    binary: selects the `cl.Program(ctx, devices, [prg])` constructor.
    op_estimate: operation count used for the debug/GlobalCounters accounting.
  """
  # number of programs built so far per base name, used for the _N<count> suffix
  kernel_cnt : Dict[str, int] = defaultdict(int)

  def __init__(self, name:str, prg:str, options:Tuple[str, ...]=tuple(), argdtypes=None, rename=True, binary=False, op_estimate=0):
    self.name = f"{name}{('_N'+str(CLProgram.kernel_cnt[name])) if CLProgram.kernel_cnt[name] else ''}" if rename else name
    self.prg, self.options, self.argdtypes, self.op_estimate = prg.replace(f"{name}(", f"{self.name}(") if rename else prg, options, argdtypes, op_estimate
    self.clprogram = cl.Program(CL().cl_ctx, CL().cl_ctx.devices, [self.prg]) if binary else cl.Program(CL().cl_ctx, self.prg)  # type: ignore
    try:
      self.clprg = self.clprogram.build(options=list(self.options)).__getattr__(self.name)
    except cl.RuntimeError as e:
      if DEBUG >= 3: print("FAILED TO BUILD", self.prg)
      raise e
    if self.argdtypes is not None:
      self.clprg.set_scalar_arg_dtypes(self.argdtypes)
    CLProgram.kernel_cnt[name] += 1

  def __call__(self, *args):
    """Launch the kernel with `args`, or append the launch to CL.CACHE when caching.

    `args` is forwarded unchanged to the pyopencl kernel after the queue;
    args[0] and args[1] are what the debug output prints as the launch sizes.
    Returns the object returned by the kernel launch, or None while caching.
    """
    CL.kernel_count += 1
    if DEBUG >= 4: print(args[0], args[1], self.prg)
    caching = CL.CACHE is not None
    # Timing exists only for kernels that were really launched. A launch that
    # was merely recorded has no event, so reading e.profile used to raise
    # UnboundLocalError with DEBUG >= 2 on non-OSX hosts while caching.
    timed = DEBUG >= 2 and not caching
    if OSX and DEBUG >= 2: st = time.monotonic_ns()
    e = None
    if caching: CL.CACHE.append((self, args))
    else: e = self.clprg(CL().cl_queue, *args)
    if DEBUG >= 2:
      CL.cl_queue.finish()
    if timed:
      # NOTE: Profiling is (sadly) broken in OS X, so we take the real kernel time
      # BOUNTY: will paypal $50 to anyone who fixes this
      et = (time.monotonic_ns() - st) if OSX else (e.profile.end - e.profile.start)
    if DEBUG >= 1:
      if timed: CL.time_sum += et
      CL.ops_sum += self.op_estimate
      print(f"**CL** {CL.kernel_count:6d} {self.name:28s} args {len(args[2:]):5d} kernels {str(args[0]):18s} {str(args[1]):12s} OPs {self.op_estimate/1e6:7.1f}M/{CL.ops_sum/1e9:7.2f}G mem {CL.mem_used/1e9:5.2f} GB " +
            (f"tm {et/1e3:9.2f}us/{CL.time_sum/1e6:9.2f}ms ({self.op_estimate/et:8.2f} GFLOPS)" if timed else str()))
    GlobalCounters.global_ops += self.op_estimate
    GlobalCounters.global_mem += sum([x.size//4 for x in args[2:] if isinstance(x, cl.Buffer)])
    return e
# **** end CL wrappers ****
def group_float4(x):
  # Packs every four consecutive FLOAT tokens of x into one FLOAT4 token whose
  # text is "(float4)(a,b,c,d)"; x must hold a multiple of four FLOAT tokens.
  assert all(y.typ == Types.FLOAT for y in x) and len(x)%4 == 0
  packed = []
  for start in range(0, len(x), 4):
    lanes = ','.join(x[start+lane].tok for lane in range(4))
    packed.append(Token(f"(float4)({lanes})", Types.FLOAT4))
  return packed
@@ -192,6 +109,7 @@ class CLASTKernel(ASTKernel):
if isinstance(x.op, ReduceOps) and not do_reduce: return acc
values = ([acc] if isinstance(x.op, ReduceOps) else []) + [self.ast_parse(v, acc, do_reduce) for v in x.src]
code = CLASTKernel.code_for_op[x.op] # TODO: replace this with a function
if CUDA and x.op == UnaryOps.SIGN: self.prekernel.add("inline __device__ float sign(float x) { float val = (signbit(x) == 0.0f) ? 1.0f : -1.0f; return (x == 0.0f) ? 0.0f : val; }")
if len(values) == 2:
# TODO: sometimes this is split, sometimes it's multiply
if isinstance(x.op, ReduceOps) and values[0][0].typ == Types.FLOAT4 and len(values[0])*4 == len(values[1]): values[0] = split_float4(values[0])
@@ -317,7 +235,7 @@ class CLASTKernel(ASTKernel):
print("old:", self.shapes)
print("old:", self.strides)
self.hand_coded_optimizations()
if not CUDA: self.hand_coded_optimizations()
self.output_shape = list(self.shapes[0][:self.first_reduce]) + self.group_for_reduce
if DEBUG >= 3:
@@ -330,18 +248,18 @@ class CLASTKernel(ASTKernel):
self.loaded_keys : Dict[Tuple[int,int], Token] = {}
self.prekernel : Set[str] = set()
self.kernel : List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"]
self.kernel : List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(isinstance(buf._buf, CLImage) for buf in self.bufs) else []
# output_shape[-1] is get_global_id(0)
self.kernel += [f"int idx{len(self.output_shape)-1-i} = get_global_id({i}); /* {self.output_shape[-1-i]} */\n" for i in range(min(3, len(self.output_shape)))]
if len(self.output_shape) > 3:
MAX_OUTPUT_SHAPE = 3
self.kernel += [f"int idx{len(self.output_shape)-1-i} = {f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' if CUDA else f'get_global_id({i})'}; /* {self.output_shape[-1-i]} */\n" for i in range(min(MAX_OUTPUT_SHAPE, len(self.output_shape))) if self.output_shape[-1-i] != 1]
if len(self.output_shape) > MAX_OUTPUT_SHAPE:
# sometimes, there's more dimensions. compact all the dimensions into the first one
# TODO: these compactions should be searchable
final_dimension = len(self.output_shape)-3
for i in range(len(self.output_shape)-4, -1, -1):
final_dimension = len(self.output_shape)-MAX_OUTPUT_SHAPE
for i in range(final_dimension-1, -1, -1):
self.kernel += [f"int idx{i} = idx{final_dimension} % {self.output_shape[i]};", f"idx{final_dimension} = idx{final_dimension} / {self.output_shape[i]};\n"]
self.output_shape = [prod(self.output_shape[0:-2])] + list(self.output_shape[-2:])
self.output_shape = [prod(self.output_shape[0:final_dimension+1])] + list(self.output_shape[final_dimension+1:])
if DEBUG >= 3: print(f"replaced output shape with {self.output_shape}")
# early ast
@@ -383,16 +301,19 @@ class CLASTKernel(ASTKernel):
# kernel function definition
function_name = ("re_S" if self.reduceop else "ew_S") + '_'.join([str(x) for x in self.bufs[0].shape if x != 1])
buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if isinstance(x._buf, CLImage) else ("__global "+self.buftokens[i].decltype()) for i,x in enumerate(self.bufs)]
self.kernel = list(self.prekernel) + [f"__kernel void {function_name}(",] + \
buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if isinstance(x._buf, CLImage) else ("__global "+self.buftokens[i].decltype()) for i,x in enumerate(self.bufs)] if not CUDA else [self.buftokens[i].decltype() for i,x in enumerate(self.bufs)]
self.kernel = list(self.prekernel) + [f"{'__global__' if CUDA else '__kernel'} void {function_name}(",] + \
[', '.join([f'{t} data{i}' for i,t in enumerate(buftypes) if i not in self.bufs_to_delete])] + \
[") {\n"] + self.kernel
# compile kernel
self.fxn = CLProgram(function_name, ' '.join(self.kernel), op_estimate=self.info.flops)
mem_estimate = sum(prod(x) for x in self.shapes)
if DEBUG >= 3 and len(self.bufs_to_delete): print(f"deleting buffers {self.bufs_to_delete}")
def runner(*bufs):
GlobalCounters.global_ops += self.info.flops
GlobalCounters.global_mem += mem_estimate
clbufs = [x.cl for i,x in enumerate(bufs) if i not in self.bufs_to_delete]
return self.fxn(self.output_shape[::-1] if len(self.output_shape) > 0 else [1], (self.group_for_reduce[::-1] + [1]*(len(self.output_shape)-len(self.group_for_reduce))) if self.group_for_reduce else None, *clbufs)
return runner
@@ -418,7 +339,7 @@ class GPUBuffer(ExplicitExecAST):
if self._buf is None:
self._buf = CLImage(self._base_shape) if (len(self._base_shape) == 3 and self._base_shape[2] == 4 and IMAGE >= 2) else CLBuffer(4*prod(self._base_shape))
if self._backing is not None:
CL().enqueue_copy(self._buf.cl, self._backing, is_blocking=False)
self._buf.copyin(self._backing)
self._backing = None
return self._buf.cl
@@ -431,7 +352,7 @@ class GPUBuffer(ExplicitExecAST):
data = np.empty(self.shape, dtype=np.float32)
cl_buf = self.contiguous()
cl_buf = cl_buf if isinstance(cl_buf._buf, CLBuffer) else self.movement_op(MovementOps.RESHAPE, list(self.shape)+[1]).unary_op(UnaryOps.NOOP)
CL().enqueue_copy(data, cl_buf.cl, is_blocking=True)
cl_buf._buf.copyout(data)
return data
@classmethod

24
tinygrad/runtime/cuda.py Normal file
View File

@@ -0,0 +1,24 @@
import pycuda.autoinit # type: ignore # pylint: disable=unused-import # noqa: F401
import pycuda.driver as cuda # type: ignore
from pycuda.compiler import SourceModule # type: ignore
import numpy as np
from tinygrad.ops import DEBUG
class CLImage:
  """Stand-in for the OpenCL image type; constructing one always fails under CUDA."""
  def __init__(self, shape):
    raise NotImplementedError("CUDA runtime doesn't support images")
class CLBuffer:
  """Device allocation made with cuda.mem_alloc; the allocation handle is kept in `.cl`."""
  def __init__(self, size):
    self.cl = cuda.mem_alloc(size)

  def copyin(self, b:np.ndarray):
    """Copy the host array `b` into this device allocation."""
    # Deliberately synchronous: PyCUDA documents memcpy_htod_async as requiring
    # page-locked source memory, which an arbitrary numpy array is not, and the
    # caller may drop its reference to `b` as soon as this method returns.
    cuda.memcpy_htod(self.cl, b)

  def copyout(self, a:np.ndarray):
    """Copy this device allocation into the host array `a`."""
    cuda.memcpy_dtoh(a, self.cl)
class CLProgram:
  """Kernel compiled from source with pycuda's SourceModule and launched by calling the instance."""
  def __init__(self, name:str, prg:str, op_estimate:int=0):
    self.name = name
    self.op_estimate = op_estimate
    if DEBUG >= 4:
      print("CUDA compile", prg)
    module = SourceModule(prg)
    self.prg = module.get_function(name)

  def __call__(self, global_size, local_size, *args):
    # pad the grid with 1s so it has at least two entries
    missing = 2 - len(global_size)
    grid = global_size + [1] * missing
    if DEBUG >= 2:
      print("CUDA launch", grid, local_size)
    # local_size is only reported above; the launch always uses a (1,1,1) block
    self.prg(*args, block=(1,1,1), grid=tuple(grid))

View File

@@ -0,0 +1,89 @@
import os, functools, time, platform
import numpy as np
import pyopencl as cl # type: ignore
from typing import Dict, Optional, Tuple, List
from collections import defaultdict
from tinygrad.ops import DEBUG
OSX = platform.system() == "Darwin"
CLCACHE = int(os.getenv("CLCACHE", "1"))
FLOAT16 = int(os.getenv("FLOAT16", "0"))
class CL:
  """Class-level OpenCL state: context, command queue, buffer cache and debug counters.

  Everything is stored on the class, so only the first `CL()` call performs
  device discovery; every later call returns immediately.
  """
  CACHE, kernel_count, mem_used, time_sum, ops_sum = None, -1, 0, 0.0, 0.0
  BUFFER_CACHE : Dict[int, List[cl.Buffer]] = defaultdict(list)
  cl_ctx : Optional[cl.Context] = None
  cl_queue : Optional[cl.CommandQueue] = None

  def __init__(self):
    if CL.cl_queue is None:  # not initted yet
      found = [dev for plat in cl.get_platforms() for dev in plat.get_devices(device_type=cl.device_type.GPU)]
      if not found:  # settle for CPU
        found = [dev for plat in cl.get_platforms() for dev in plat.get_devices(device_type=cl.device_type.CPU)]
      # CL_DEVICE picks the index into the discovered device list (default 0)
      CL.cl_ctx = cl.Context(devices=[found[int(os.getenv("CL_DEVICE", "0"))]])
      if len(found) > 1 or DEBUG >= 1:
        print(f"using {CL.cl_ctx.devices}")
      # this is an in-order command queue
      CL.cl_queue = cl.CommandQueue(self.cl_ctx, properties=cl.command_queue_properties.PROFILING_ENABLE)

  @staticmethod
  def enqueue_copy(a, b, is_blocking=False):
    # copies are refused while kernel launches are being recorded into CL.CACHE
    assert CL.CACHE is None, f"can't copy {a} -> {b} while caching"
    if DEBUG >= 1:
      if isinstance(b, np.ndarray):
        print(f"**CL** copy in {b.shape}")
      else:
        print(f"**CL** copy OUT {a.shape}")
    cl.enqueue_copy(CL().cl_queue, a, b, is_blocking=is_blocking)
class CLBuffer:
  """OpenCL buffer wrapper that recycles released buffers through CL.BUFFER_CACHE.

  The pyopencl buffer is stored in `.cl`. CL.mem_used grows only when a new
  buffer is allocated and shrinks only when one is dropped with CLCACHE disabled.
  """
  def __init__(self, size):
    spare = CL.BUFFER_CACHE[size]
    if spare:
      # a buffer of exactly this size was released earlier: take it back
      self.cl = spare.pop()
      return
    # TODO: on GPU OOM, clear the cache
    self.cl = cl.Buffer(CL().cl_ctx, cl.mem_flags.READ_WRITE, size)
    CL.mem_used += self.cl.size

  def __del__(self):
    nbytes = self.cl.size
    if not CLCACHE:
      CL.mem_used -= nbytes
      return
    # keep the allocation around for the next CLBuffer of the same size
    CL.BUFFER_CACHE[nbytes].append(self.cl)

  def copyin(self, b:np.ndarray):
    # host -> device, non-blocking
    CL.enqueue_copy(self.cl, b, is_blocking=False)

  def copyout(self, a:np.ndarray):
    # device -> host, blocking so `a` is filled when this returns
    CL.enqueue_copy(a, self.cl, is_blocking=True)
class CLImage:
  """RGBA OpenCL image wrapper; channels are HALF_FLOAT when FLOAT16 is set, FLOAT otherwise.

  The pyopencl image is stored in `.cl`, and its footprint
  (row_pitch * height) is added to CL.mem_used on creation and removed on deletion.
  """
  fmt = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT)

  def __init__(self, shape):
    ctx = CL().cl_ctx
    # cl.Image is given the first two entries of `shape` in swapped order
    dims = (shape[1], shape[0])
    self.cl = cl.Image(ctx, cl.mem_flags.READ_WRITE, CLImage.fmt, shape=dims)
    img = self.cl
    CL.mem_used += img.row_pitch * img.height

  def __del__(self):
    img = self.cl
    CL.mem_used -= img.row_pitch * img.height
@functools.lru_cache(maxsize=None)
class CLProgram:
  """A built OpenCL kernel that can be launched (or recorded into CL.CACHE) by calling it.

  The class is wrapped in functools.lru_cache, so constructing it again with
  the same arguments returns the already-built instance instead of rebuilding.

  Args:
    name: kernel function name as it appears in `prg`.
    prg: program source (or binary when `binary` is true).
    options: build options forwarded to `cl.Program.build`.
    argdtypes: if not None, passed to `set_scalar_arg_dtypes` on the kernel.
    rename: when true, later programs sharing `name` get a `_N<count>` suffix
      and `name(` is rewritten accordingly inside `prg`.
    binary: selects the `cl.Program(ctx, devices, [prg])` constructor.
    op_estimate: operation count used for the debug OPs/GFLOPS report.
  """
  # number of programs built so far per base name, used for the _N<count> suffix
  kernel_cnt : Dict[str, int] = defaultdict(int)

  def __init__(self, name:str, prg:str, options:Tuple[str, ...]=tuple(), argdtypes=None, rename=True, binary=False, op_estimate=0):
    self.name = f"{name}{('_N'+str(CLProgram.kernel_cnt[name])) if CLProgram.kernel_cnt[name] else ''}" if rename else name
    self.prg, self.options, self.argdtypes, self.op_estimate = prg.replace(f"{name}(", f"{self.name}(") if rename else prg, options, argdtypes, op_estimate
    self.clprogram = cl.Program(CL().cl_ctx, CL().cl_ctx.devices, [self.prg]) if binary else cl.Program(CL().cl_ctx, self.prg)  # type: ignore
    try:
      self.clprg = self.clprogram.build(options=list(self.options)).__getattr__(self.name)
    except cl.RuntimeError as e:
      if DEBUG >= 3: print("FAILED TO BUILD", self.prg)
      raise e
    if self.argdtypes is not None:
      self.clprg.set_scalar_arg_dtypes(self.argdtypes)
    CLProgram.kernel_cnt[name] += 1

  def __call__(self, *args):
    """Launch the kernel with `args`, or append the launch to CL.CACHE when caching.

    `args` is forwarded unchanged to the pyopencl kernel after the queue;
    args[0] and args[1] are what the debug output prints as the launch sizes.
    Returns the object returned by the kernel launch, or None while caching.
    """
    CL.kernel_count += 1
    if DEBUG >= 4: print(args[0], args[1], self.prg)
    caching = CL.CACHE is not None
    # Timing exists only for kernels that were really launched. A launch that
    # was merely recorded has no event, so reading e.profile used to raise
    # UnboundLocalError with DEBUG >= 2 on non-OSX hosts while caching.
    timed = DEBUG >= 2 and not caching
    if OSX and DEBUG >= 2: st = time.monotonic_ns()
    e = None
    if caching: CL.CACHE.append((self, args))
    else: e = self.clprg(CL().cl_queue, *args)
    if DEBUG >= 2:
      CL.cl_queue.finish()
    if timed:
      # NOTE: Profiling is (sadly) broken in OS X, so we take the real kernel time
      # BOUNTY: will paypal $50 to anyone who fixes this
      et = (time.monotonic_ns() - st) if OSX else (e.profile.end - e.profile.start)
    if DEBUG >= 1:
      if timed: CL.time_sum += et
      CL.ops_sum += self.op_estimate
      print(f"**CL** {CL.kernel_count:6d} {self.name:28s} args {len(args[2:]):5d} kernels {str(args[0]):18s} {str(args[1]):12s} OPs {self.op_estimate/1e6:7.1f}M/{CL.ops_sum/1e9:7.2f}G mem {CL.mem_used/1e9:5.2f} GB " +
            (f"tm {et/1e3:9.2f}us/{CL.time_sum/1e6:9.2f}ms ({self.op_estimate/et:8.2f} GFLOPS)" if timed else str()))
    return e