From 03cf0afa4f2ad477fb41c22b3dbecb6023b2966f Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 1 Nov 2023 23:01:32 -0700 Subject: [PATCH] move all to compile api (#2203) * move metal+clang to compile api * all to the new style * remove binary arg * fix triton * fixup tests * fix clang * diskcache is generic * __wrapped__ * compile_gpu * fix thneed * keep the src in the ASTRunner * lib * move compile_gpu * compile_gpu in device * put compiler in astrunner * test reverts * triton compiler * ugh, that too --- docs/abstractions.py | 4 +-- extra/thneed.py | 20 +++--------- test/external/external_test_speed_llama.py | 2 +- test/test_custom_function.py | 2 +- test/test_kernel_cache.py | 36 ++++++++++------------ test/test_uops.py | 2 +- tinygrad/helpers.py | 20 ++++++------ tinygrad/ops.py | 15 ++++----- tinygrad/renderer/cstyle.py | 2 +- tinygrad/renderer/llvmir.py | 2 +- tinygrad/renderer/triton.py | 5 +-- tinygrad/runtime/ops_clang.py | 24 +++++++-------- tinygrad/runtime/ops_cuda.py | 20 ++++++------ tinygrad/runtime/ops_gpu.py | 23 +++++++++----- tinygrad/runtime/ops_hip.py | 19 ++++++------ tinygrad/runtime/ops_llvm.py | 31 ++++++++++--------- tinygrad/runtime/ops_metal.py | 33 ++++++++++---------- tinygrad/runtime/ops_webgpu.py | 4 +-- 18 files changed, 128 insertions(+), 136 deletions(-) diff --git a/docs/abstractions.py b/docs/abstractions.py index fe65b072a0..ca2371eec5 100644 --- a/docs/abstractions.py +++ b/docs/abstractions.py @@ -217,7 +217,7 @@ from tinygrad.runtime.lib import RawMallocBuffer # ClangProgram is the simplest runtime (in tinygrad/runtime/ops_clang.py, code 7/10) # __init__ calls clang, and __call__ calls the function in the *.so outputted by clang # in CLANG, global_size and local_size are ignored -from tinygrad.runtime.ops_clang import ClangProgram +from tinygrad.runtime.ops_clang import ClangProgram, compile_clang # a concrete example looks like this, this adds two size 1 RawBuffer # first we create two numpy buffers containing 2 and 3 @@ -229,7 +229,7 @@ input_a, input_b = RawMallocBuffer.fromCPU(numpy_a), RawMallocBuffer.fromCPU(num output = RawMallocBuffer(1, dtypes.float32) # compile the program, run it, and 2+3 does indeed equal 5 -program = ClangProgram("add", f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}") +program = ClangProgram("add", compile_clang(f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}")) program(None, None, output, input_a, input_b) # NOTE: the None are for global_size and local_size print(output.toCPU()) assert output.toCPU()[0] == 5, "it's still 5" diff --git a/extra/thneed.py b/extra/thneed.py index 239b089957..b880b58fdc 100644 --- a/extra/thneed.py +++ b/extra/thneed.py @@ -4,7 +4,7 @@ import struct import json import traceback import numpy as np -from tinygrad.runtime.ops_gpu import CLProgram +from tinygrad.runtime.ops_gpu import CLProgram, compile_gpu from tinygrad.helpers import DEBUG, getenv from collections import defaultdict import pyopencl as cl @@ -104,21 +104,11 @@ class Thneed: if 'data' in o: self.buffers_to_save.add(buf) - # load in the programs (this isn't used) - prgs = {} - for k,v in jdat['programs'].items(): - print("building", k) - try: - prgs[k] = CLProgram(k, v, rename=False) - except Exception: - print("FAILED", k) - traceback.print_exc() - exit(0) - # load binaries + prgs = {} for o in jdat['binaries']: nptr = ptr + o['length'] - prgs[o['name']] = CLProgram(o['name'], weights[ptr:nptr], binary=True) + prgs[o['name']] = CLProgram(o['name'], weights[ptr:nptr]) ptr = nptr # populate the cl_cache @@ -208,7 +198,7 @@ class Thneed: # zero out the buffer cl.enqueue_copy(CL.cl_queue[0], buf, b'\x00'*buf.size, is_blocking=True) - CLProgram("from_image_strided", """ + CLProgram("from_image_strided", compile_gpu(""" __kernel void from_image_strided(read_only image2d_t in, __global float4 *out, int row_pitch) { const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; int2 l; @@ -216,7 +206,7 @@ class Thneed: l.x = get_global_id(0); out[l.y*row_pitch + l.x] = read_imagef(in, smp, l); } - """, argdtypes=(None, None, np.int32))(a.shape, None, a, buf, row_pitch//(4*(2 if FLOAT16 else 4))) + """), argdtypes=(None, None, np.int32))(a.shape, None, a, buf, row_pitch//(4*(2 if FLOAT16 else 4))) # multiple of 32 isn't enough jdat['objects'].append({ diff --git a/test/external/external_test_speed_llama.py b/test/external/external_test_speed_llama.py index 9bd99f0103..1b1ea3e728 100644 --- a/test/external/external_test_speed_llama.py +++ b/test/external/external_test_speed_llama.py @@ -11,7 +11,7 @@ from tinygrad.helpers import dtypes, prod from tinygrad.runtime.lib import RawBuffer class FakeProgram: - def __init__(self, name:str, prg:str, binary:bool): pass + def __init__(self, name:str, prg:str): pass def __call__(self, global_size, local_size, *bufs, wait=False): pass class RawFakeBuffer(RawBuffer): diff --git a/test/test_custom_function.py b/test/test_custom_function.py index e263e22d08..7d4cca0086 100644 --- a/test/test_custom_function.py +++ b/test/test_custom_function.py @@ -24,7 +24,7 @@ def atan2_gpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer): __kernel void atan2_gpu(global float *c, global float *a, global float *b) { int idx = get_global_id(0); c[idx] = atan2(a[idx], b[idx]); - }""", global_size=[prod(ret.shape)]).build(Device[ret.device].runtime).exec([ret.realized, a.realized, b.realized]) + }""", global_size=[prod(ret.shape)]).build(Device[ret.device].compiler, Device[ret.device].runtime).exec([ret.realized, a.realized, b.realized]) return ret.realized def atan2_cpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer): diff --git a/test/test_kernel_cache.py b/test/test_kernel_cache.py index 83c87d1dfe..82a38f5319 100644 --- a/test/test_kernel_cache.py +++ b/test/test_kernel_cache.py @@ -2,36 +2,34 @@ import unittest import secrets import string -import tempfile -import pathlib from tinygrad.tensor import Tensor from tinygrad.ops import Device -from tinygrad.helpers import cache_compiled -import tinygrad.runtime.ops_clang +from tinygrad.helpers import diskcache def generate_random_string(length=16): alphabet = string.ascii_letters + string.digits return ''.join(secrets.choice(alphabet) for _ in range(length)) +compile_call_count = 0 + +@diskcache +def helper_test_compile(prg:str) -> bytes: + global compile_call_count + compile_call_count += 1 + return prg.encode() + class TestKernelCache(unittest.TestCase): - compile_call_count = 0 - - @cache_compiled - def __helper_test_compile(self, prg, output_file=pathlib.Path(tempfile.mktemp()), **kwargs): - self.compile_call_count += 1 - return prg.encode() - def test_compile_cache(self): prg1 = generate_random_string(64) + "a" prg2 = generate_random_string(64) + "b" - cold_compile_res = self.__helper_test_compile(prg1) - warm_compile_res = self.__helper_test_compile(prg1) + cold_compile_res = helper_test_compile(prg1) + warm_compile_res = helper_test_compile(prg1) assert cold_compile_res == warm_compile_res == prg1.encode() - assert self.compile_call_count == 1 + assert compile_call_count == 1 - prg2_res = self.__helper_test_compile(prg2) + prg2_res = helper_test_compile(prg2) assert prg2_res == prg2.encode() - assert self.compile_call_count == 2 + assert compile_call_count == 2 def test_kernel_cache_in_action(self): if Device.DEFAULT not in ["CLANG"]: @@ -42,15 +40,15 @@ class TestKernelCache(unittest.TestCase): x = a + b x.realize() - orig_compile_func = tinygrad.runtime.ops_clang.ClangBuffer.runtime.compile - tinygrad.runtime.ops_clang.ClangBuffer.runtime.compile = None # making it not callable + orig_compile_func = Device['CLANG'].compiler + Device['CLANG'].compiler = None # making it not callable a1 = Tensor.rand(4,4) b1 = Tensor.rand(4,4) x1 = a1 + b1 x1.realize() # Same kernel should be from cache. - tinygrad.runtime.ops_clang.ClangBuffer.runtime.compile = orig_compile_func + Device['CLANG'].compiler = orig_compile_func if __name__ == "__main__": unittest.main() diff --git a/test/test_uops.py b/test/test_uops.py index 23aa787a67..764e2efe9d 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -8,7 +8,7 @@ from tinygrad.codegen.linearizer import UOps, UOp def _uops_to_prg(uops): src, runtime_args = Device[Device.DEFAULT].renderer("test", uops) - return ASTRunner("test", src, [1], [1], runtime_args=runtime_args).build(Device[Device.DEFAULT].runtime) + return ASTRunner("test", src, [1], [1], runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime) def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp: uops.append(UOp(uop, dtype, tuple(vin), arg, len(uops))) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 418acd1da8..857b5bbafc 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -153,22 +153,12 @@ class GlobalCounters: @staticmethod def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count = 0,0,0.0,0 -# *** compiled cache decorator *** - -def cache_compiled(func): - if getenv("DISABLE_COMPILER_CACHE"): return func - def wrapper(self, prg:str, *args, **kwargs) -> bytes: - table, key = f"compiler_cache_{type(self).__name__}", hashlib.sha256(prg.encode()).hexdigest() - if (ret:=diskcache_get(table, key)): return ret - return diskcache_put(table, key, func(self, prg, *args, **kwargs)) - return wrapper - # *** universal database cache *** CACHEDB = getenv("CACHEDB", "/tmp/tinygrad_cache") CACHELEVEL = getenv("CACHELEVEL", 2) -VERSION = 5 +VERSION = 6 _db_connection = None def db_connection(): global _db_connection @@ -207,3 +197,11 @@ def diskcache_put(table:str, key:Union[Dict, str, int], val:Any): conn.commit() cur.close() return val + +def diskcache(func): + def wrapper(*args, **kwargs) -> bytes: + table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest() + if (ret:=diskcache_get(table, key)): return ret + return diskcache_put(table, key, func(*args, **kwargs)) + setattr(wrapper, "__wrapped__", func) + return wrapper diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 447ebed4dd..e464d0a9da 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -194,8 +194,8 @@ class GraphBatchExecutor(BasicBatchExecutor): def exec_instance(self, instid): raise NotImplementedError("must be implemented") class ASTRunner: - def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None): - if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args or not runtime_args['binary']): print(prg) + def __init__(self, name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None): + if DEBUG >= 4: print(prg) self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {} def optimize_local_size(self, global_size:List[int], rawbufs:List[RawBuffer]) -> List[int]: @@ -211,8 +211,9 @@ class ASTRunner: return float('inf') return min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])[1] - def build(self, runtime, batch_exec=BasicBatchExecutor): - self.clprg, self.batch_exec = runtime(self.name, self.prg, **self.runtime_args), batch_exec + def build(self, compiler, runtime, batch_exec=BasicBatchExecutor): + self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg) + self.clprg, self.batch_exec = runtime(self.name, self.lib, **self.runtime_args), batch_exec return self def exec(self, rawbufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False, optimizing=False) -> Optional[float]: @@ -243,8 +244,8 @@ class ASTRunner: return et class Compiled: - def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, runtime, synchronize=lambda: None, batch_exec=BasicBatchExecutor): - self.buffer, self.linearizer_opts, self.renderer, self.runtime, self.synchronize, self.batch_exec = buffer, linearizer_opts, renderer, runtime, synchronize, batch_exec + def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, compiler, runtime, synchronize=lambda: None, batch_exec=BasicBatchExecutor): + self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize, self.batch_exec = buffer, linearizer_opts, renderer, compiler, runtime, synchronize, batch_exec self.method_cache: Dict[LazyOp, ASTRunner] = {} def to_program(self, k): @@ -252,7 +253,7 @@ class Compiled: src, runtime_args = self.renderer(k.function_name, k.uops) return ASTRunner(k.function_name, src, k.global_size, k.local_size, op_estimate=k.info.flops, mem_estimate=k.mem_estimate, - display_name=k.display_name, runtime_args=runtime_args).build(self.runtime, self.batch_exec) + display_name=k.display_name, runtime_args=runtime_args).build(self.compiler, self.runtime, self.batch_exec) def exec_ast(self, ast:LazyOp, output, inputs, var_vals, **kwargs): # check if we can reuse the output buffer diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index d4317ba1db..070a26445e 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -209,4 +209,4 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu else: raise RuntimeError(f"failed to render {uop}") - return lang.render_kernel(function_name, kernel, bufs, local_size, prekernel), {"binary":False} + return lang.render_kernel(function_name, kernel, bufs, local_size, prekernel), {} diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 1f0f8243d4..f0fc1e2339 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -144,4 +144,4 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]: lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin]) bb[-1].ret_void() - return str(module), {"binary":False} + return str(module), {} diff --git a/tinygrad/renderer/triton.py b/tinygrad/renderer/triton.py index 2153131bde..af6a67b218 100644 --- a/tinygrad/renderer/triton.py +++ b/tinygrad/renderer/triton.py @@ -118,7 +118,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]): for x in local_size: acc_local_size *= next_power_of_2(x) local_size = [acc_local_size] + [1] * (len(local_size) - 1) - if DEBUG >=4: print(prg) + if DEBUG >= 4: print(prg) getlines = linecache.getlines linecache.getlines = lambda filename, module_globals=None: prg.splitlines(keepends=True) if "" == filename else getlines(filename, module_globals) exec(compile(prg, "", "exec"), globals()) # pylint: disable=W0122\ @@ -126,4 +126,5 @@ def uops_to_triton(function_name:str, uops:List[UOp]): prg = remove_single_scalar_curly_braces(compiled.asm["ptx"].split(".file")[0].split(".visible .func")[0]) max_local_size = [int(x) for x in prg.split(".maxntid ")[1].split("\n")[0].split(", ")] for i in range(len(local_size)): local_size[i] = min(local_size[i], max_local_size[i]) - return prg, {"binary":True, "shared":compiled.metadata["shared"], "local_size_override":local_size + [1]*(3-len(local_size))} + + return prg, {"shared":compiled.metadata["shared"], "local_size_override":local_size + [1]*(3-len(local_size))} diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py index 843704e1eb..bc83f7c1b9 100644 --- a/tinygrad/runtime/ops_clang.py +++ b/tinygrad/runtime/ops_clang.py @@ -1,7 +1,7 @@ import time, ctypes, subprocess, platform, functools, pathlib, tempfile from typing import Any from tinygrad.ops import Compiled -from tinygrad.helpers import cache_compiled +from tinygrad.helpers import diskcache from tinygrad.runtime.lib import RawMallocBuffer from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage @@ -13,26 +13,24 @@ args = { CLANG_PROGRAM_HEADER = '#include \n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#include \n' -class ClangProgram: - def __init__(self, name:str, prg:str, binary=False): - self.prg: bytes = prg if binary else self.compile(CLANG_PROGRAM_HEADER+prg) +@diskcache +def compile_clang(prg:str, header:str=CLANG_PROGRAM_HEADER) -> bytes: + # TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here + with tempfile.NamedTemporaryFile(delete=True) as output_file: + subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+str(output_file.name)).split(), input=(header+prg).encode('utf-8')) + return pathlib.Path(output_file.name).read_bytes() +class ClangProgram: + def __init__(self, name:str, prg:bytes): # write to disk so we can load it with tempfile.NamedTemporaryFile(delete=True) as cached_file_path: - pathlib.Path(cached_file_path.name).write_bytes(self.prg) + pathlib.Path(cached_file_path.name).write_bytes(prg) self.fxn: Any = ctypes.CDLL(str(cached_file_path.name))[name] - @cache_compiled - def compile(self, prg) -> bytes: - # TODO: sadly clang doesn't like the use of /dev/stdout here - with tempfile.NamedTemporaryFile(delete=True) as output_file: - subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+str(output_file.name)).split(), input=prg.encode('utf-8')) - return pathlib.Path(output_file.name).read_bytes() - def __call__(self, unused_global_size, unused_local_size, *args, wait=False): if wait: st = time.perf_counter() self.fxn(*[x._buf if isinstance(x, RawMallocBuffer) else x for x in args]) if wait: return time.perf_counter()-st renderer = functools.partial(uops_to_cstyle, CStyleLanguage(buffer_suffix=" restrict", arg_int_prefix="const int")) -ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram) +ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, compile_clang, ClangProgram) diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py index a987cf29fe..178d899cca 100644 --- a/tinygrad/runtime/ops_cuda.py +++ b/tinygrad/runtime/ops_cuda.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import Optional, List, Any, Tuple import numpy as np from pycuda.compiler import compile as cuda_compile # type: ignore -from tinygrad.helpers import DEBUG, getenv, colored, cache_compiled +from tinygrad.helpers import DEBUG, getenv, colored, diskcache from tinygrad.ops import Compiled, GraphBatchExecutor, ASTRunner from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator from tinygrad.codegen.kernel import LinearizerOptions @@ -88,9 +88,12 @@ class CUDAGraph(GraphBatchExecutor): def exec_instance(self, instid): self.graphs[instid][0].launch() +@diskcache +def compile_cuda(prg) -> bytes: return cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets']) + class CUDAProgram: - def __init__(self, name:str, prg:str, binary=False, shared=0, local_size_override=None): - if not binary: prg = self.compile(prg).decode('utf-8') + def __init__(self, name:str, _prg:bytes, shared=0, local_size_override=None): + prg = _prg.decode('utf-8') if DEBUG >= 5: print(pretty_ptx(prg)) if DEBUG >= 6: try: @@ -102,10 +105,6 @@ class CUDAProgram: # TODO: name is wrong, so we get it from the ptx using hacks self.prg, self.shared, self.local_size_override = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0]), shared, local_size_override - @cache_compiled - def compile(self, prg) -> bytes: - return cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets']) - def __call__(self, global_size, local_size, *args, wait=False): if wait: start, end = cuda.Event(), cuda.Event() @@ -118,7 +117,8 @@ class CUDAProgram: if getenv("TRITON") == 1: from tinygrad.renderer.triton import uops_to_triton - TritonRenderer = uops_to_triton - CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024], has_shared=False), TritonRenderer, CUDAProgram, cuda.Context.synchronize) + CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024], has_shared=False), + uops_to_triton, lambda x: x.encode('utf-8'), CUDAProgram, cuda.Context.synchronize) else: - CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False if getenv("PTX") else True, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]), CUDARenderer, CUDAProgram, cuda.Context.synchronize, CUDAGraph) + CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False if getenv("PTX") else True, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]), + CUDARenderer, compile_cuda, CUDAProgram, cuda.Context.synchronize, CUDAGraph) diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py index 9c2715a8b0..6ad6066d27 100644 --- a/tinygrad/runtime/ops_gpu.py +++ b/tinygrad/runtime/ops_gpu.py @@ -1,9 +1,11 @@ from __future__ import annotations +import os +os.environ['PYOPENCL_NO_CACHE'] = '1' import pathlib import numpy as np import pyopencl as cl # type: ignore from typing import Optional, List -from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport +from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport, diskcache from tinygrad.ops import Compiled from tinygrad.renderer.opencl import OpenCLRenderer from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer @@ -61,23 +63,28 @@ class CLBuffer(RawBufferCopyInOut, RawBufferTransfer): cl.enqueue_copy_buffer_p2p_amd(CL.cl_platform, CL.cl_queue[x._buf.device], x._buf, self._buf, x.size * x.dtype.itemsize).wait() else: raise NotImplementedError("p2p transfer between devices not implemented on non-amd") +@diskcache +def compile_gpu(prg:str) -> bytes: + clprg = cl.Program(CL.cl_ctxs[0], prg) + clprg.build() + return clprg.get_info(cl.program_info.BINARIES)[0] + class CLProgram: - def __init__(self, name:str, prg:str, binary=False, argdtypes=None, options=None): - self.name, self.clprograms = name, [cl.Program(ctx, ctx.devices, [prg]*len(ctx.devices)) if binary else cl.Program(ctx, prg) for ctx in CL.cl_ctxs] # type: ignore + def __init__(self, name:str, prg:bytes, argdtypes=None, options=None): + self.name, self.clprograms = name, [cl.Program(ctx, ctx.devices, [prg]*len(ctx.devices)) for ctx in CL.cl_ctxs] # type: ignore self._clprgs = [clprogram.build(options=options) for clprogram in self.clprograms] self.clprgs = [clprg.__getattr__(name) for clprg in self._clprgs] if DEBUG >= 5 and not OSX: if 'Adreno' in CL.cl_ctxs[0].devices[0].name: - fromimport('disassemblers.adreno', 'disasm')(self.binary()) + fromimport('disassemblers.adreno', 'disasm')(prg) elif CL.cl_ctxs[0].devices[0].name.startswith('gfx'): - asm = early_exec(([ROCM_LLVM_PATH / "llvm-objdump", '-d', '-'], self.binary())) + asm = early_exec(([ROCM_LLVM_PATH / "llvm-objdump", '-d', '-'], prg)) print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x])) else: # print the PTX for NVIDIA. TODO: probably broken for everything else - print(self.binary().decode('utf-8')) + print(prg.decode('utf-8')) if argdtypes is not None: self.set_argdtypes(argdtypes) - def binary(self): return self.clprograms[0].get_info(cl.program_info.BINARIES)[0] def set_argdtypes(self, argdtypes): self.argdtypes, _ = argdtypes, [clprg.set_scalar_arg_dtypes(argdtypes) for clprg in self.clprgs] @staticmethod @@ -100,4 +107,4 @@ class CLProgram: return None return None -GPUBuffer = Compiled(CLBuffer, LinearizerOptions(), OpenCLRenderer, CLProgram, CL.synchronize) +GPUBuffer = Compiled(CLBuffer, LinearizerOptions(), OpenCLRenderer, compile_gpu, CLProgram, CL.synchronize) diff --git a/tinygrad/runtime/ops_hip.py b/tinygrad/runtime/ops_hip.py index e1b9491d02..af3df91132 100644 --- a/tinygrad/runtime/ops_hip.py +++ b/tinygrad/runtime/ops_hip.py @@ -2,7 +2,7 @@ import numpy as np import ctypes, functools import extra.hip_wrapper as hip from typing import Tuple, Any, List -from tinygrad.helpers import DEBUG, getenv, cache_compiled +from tinygrad.helpers import DEBUG, getenv, diskcache from tinygrad.ops import Compiled, ASTRunner, GraphBatchExecutor from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer from tinygrad.codegen.kernel import LinearizerOptions @@ -78,10 +78,15 @@ class RawHIPBuffer(RawBufferCopyInOut, RawBufferTransfer): hip.hipSetDevice(x._device) hip.hipMemcpy(self._buf, x._buf, self.size * self.dtype.itemsize, hip.hipMemcpyDeviceToDevice) +@diskcache +def compile_hip(prg) -> bytes: + prog = hip.hiprtcCreateProgram(prg, "", [], []) + hip.hiprtcCompileProgram(prog, [f'--offload-arch={hip.hipGetDeviceProperties(HIP.default_device).gcnArchName}']) + return hip.hiprtcGetCode(prog) + class HIPProgram: - def __init__(self, name:str, prg:str, binary=False): + def __init__(self, name:str, prg:bytes): self.modules, self.prgs = [], [] - prg = prg if binary else self.compile(prg, name) if DEBUG >= 6: asm = early_exec((["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], prg)) @@ -92,12 +97,6 @@ class HIPProgram: self.modules.append(hip.hipModuleLoadData(prg)) self.prgs.append(hip.hipModuleGetFunction(self.modules[-1], name)) - @cache_compiled - def compile(self, prg, name) -> bytes: - prog = hip.hiprtcCreateProgram(prg, name, [], []) - hip.hiprtcCompileProgram(prog, [f'--offload-arch={hip.hipGetDeviceProperties(HIP.default_device).gcnArchName}']) - return hip.hiprtcGetCode(prog) - def __call__(self, global_size, local_size, *args, wait=False): hip.hipSetDevice(args[0]._device) if wait: @@ -138,4 +137,4 @@ __device__ void vstore_half4(float4 data, size_t offset, half *p) { *(p + offset """, gid = [f'blockIdx.{chr(120+i)}' for i in range(3)], lid = [f'threadIdx.{chr(120+i)}' for i in range(3)])) -HIPBuffer = Compiled(RawHIPBuffer, LinearizerOptions(device="HIP"), renderer, HIPProgram, hip.hipDeviceSynchronize, HIPGraph) +HIPBuffer = Compiled(RawHIPBuffer, LinearizerOptions(device="HIP"), renderer, compile_hip, HIPProgram, hip.hipDeviceSynchronize, HIPGraph) diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py index a8c935efe3..16c2900611 100644 --- a/tinygrad/runtime/ops_llvm.py +++ b/tinygrad/runtime/ops_llvm.py @@ -1,7 +1,7 @@ import time, ctypes from typing import ClassVar from tinygrad.ops import Compiled -from tinygrad.helpers import getenv, DEBUG, cache_compiled +from tinygrad.helpers import getenv, DEBUG, diskcache from ctypes import CFUNCTYPE from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.renderer.llvmir import uops_to_llvm_ir @@ -9,6 +9,8 @@ from tinygrad.runtime.lib import RawMallocBuffer import llvmlite.binding as llvm # type: ignore +LLVMOPT = bool(getenv("LLVMOPT")) + class LLVM: target_machine: ClassVar[llvm.targets.TargetMachine] = None engine: ClassVar[llvm.executionengine.ExecutionEngine] = None @@ -26,7 +28,7 @@ class LLVM: LLVM.target_machine.add_analysis_passes(LLVM.optimizer) # TODO: this makes compile times so much faster - if getenv("LLVMOPT"): + if LLVMOPT: llvm.set_option(str(), '-force-vector-interleave=4') # this makes sum the same speed as torch, it also doubles the (slow) conv speed if DEBUG >= 4: llvm.set_option(str(), '--debug-only=loop-vectorize') #llvm.set_option(str(), '--debug') @@ -44,19 +46,18 @@ class LLVM: backing_mod.triple = llvm.get_process_triple() LLVM.engine = llvm.create_mcjit_compiler(backing_mod, LLVM.target_machine) -class LLVMProgram: - def __init__(self, name:str, prg:str, binary=False): - self.prg = prg if binary else self.compile(prg) - LLVM().engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(self.prg)) - self.fxn = LLVM.engine.get_function_address(name) +@diskcache +def compile_llvm(prg, llvmopt=LLVMOPT) -> bytes: + mod = llvm.parse_assembly(prg) + mod.verify() + LLVM().optimizer.run(mod) + if DEBUG >= 5: print(LLVM.target_machine.emit_assembly(mod)) + return LLVM.target_machine.emit_object(mod) - @cache_compiled - def compile(self, prg) -> bytes: - mod = llvm.parse_assembly(prg) - mod.verify() - LLVM().optimizer.run(mod) - if DEBUG >= 5: print(LLVM.target_machine.emit_assembly(mod)) - return LLVM.target_machine.emit_object(mod) +class LLVMProgram: + def __init__(self, name:str, lib:bytes): + LLVM().engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(lib)) + self.fxn = LLVM.engine.get_function_address(name) def __call__(self, unused_global_size, unused_local_size, *bufs, wait=False): cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.c_void_p for _ in bufs])(self.fxn) @@ -64,4 +65,4 @@ class LLVMProgram: cfunc(*[x._buf if not isinstance(x, int) else x for x in bufs]) if wait: return time.perf_counter()-st -LLVMBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False), uops_to_llvm_ir, LLVMProgram) +LLVMBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False), uops_to_llvm_ir, compile_llvm, LLVMProgram) diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py index 43320c7e88..179e64c22a 100644 --- a/tinygrad/runtime/ops_metal.py +++ b/tinygrad/runtime/ops_metal.py @@ -3,7 +3,7 @@ import os, subprocess, pathlib, ctypes, tempfile import Metal, Cocoa, libdispatch # type: ignore from typing import List, Any, Tuple from tinygrad.codegen.kernel import LinearizerOptions -from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, cache_compiled +from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache from tinygrad.ops import Compiled, ASTRunner, BasicBatchExecutor from tinygrad.renderer.metal import MetalRenderer from tinygrad.runtime.lib import RawBufferMapped, LRUAllocator @@ -58,9 +58,21 @@ def unwrap(x): assert err is None, str(err) return ret +@diskcache +def compile_metal(prg, use_xcode=bool(getenv("METAL_XCODE"))) -> bytes: + if use_xcode: + # NOTE: if you run llvm-dis on "air" you can see the llvm bytecode + air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8')) + return subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air) + options = Metal.MTLCompileOptions.alloc().init() + library = unwrap(METAL.device.newLibraryWithSource_options_error_(prg, options, None)) + # TODO: avoid file write here? + with tempfile.NamedTemporaryFile(delete=True) as output_file: + library.serializeToURL_error_(Cocoa.NSURL.URLWithString_(f"file://{output_file.name}"), None) + return pathlib.Path(output_file.name).read_bytes() + class MetalProgram: - def __init__(self, name:str, prg:str, binary:bool=False): - lib = prg if binary else self.compile(prg) + def __init__(self, name:str, lib:bytes): data = libdispatch.dispatch_data_create(lib, len(lib), None, None) self.library = unwrap(METAL.device.newLibraryWithData_error_(data, None)) self.fxn = self.library.newFunctionWithName_(name) @@ -71,19 +83,6 @@ class MetalProgram: os.system(f"cd {pathlib.Path(__file__).parents[2]}/disassemblers/applegpu && python3 compiler_explorer.py {shader.name}") self.pipeline_state = unwrap(METAL.device.newComputePipelineStateWithFunction_error_(self.fxn, None)) - @cache_compiled - def compile(self, prg) -> bytes: - if getenv("METAL_XCODE"): - # NOTE: if you run llvm-dis on "air" you can see the llvm bytecode - air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8')) - return subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air) - options = Metal.MTLCompileOptions.alloc().init() - library = unwrap(METAL.device.newLibraryWithSource_options_error_(prg, options, None)) - # TODO: avoid file write here? - with tempfile.NamedTemporaryFile(delete=True) as output_file: - library.serializeToURL_error_(Cocoa.NSURL.URLWithString_(f"file://{output_file.name}"), None) - return pathlib.Path(output_file.name).read_bytes() - def __call__(self, global_size, local_size, *bufs, wait=False): assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}" command_buffer = METAL.mtl_queue.commandBuffer() @@ -101,4 +100,4 @@ class MetalProgram: return command_buffer.GPUEndTime() - command_buffer.GPUStartTime() METAL.mtl_buffers_in_flight.append(command_buffer) -MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, MetalProgram, METAL.synchronize, MetalBatchExecutor) +MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, MetalBatchExecutor) diff --git a/tinygrad/runtime/ops_webgpu.py b/tinygrad/runtime/ops_webgpu.py index 4d58b65204..2fa9bbdd72 100644 --- a/tinygrad/runtime/ops_webgpu.py +++ b/tinygrad/runtime/ops_webgpu.py @@ -12,7 +12,7 @@ import wgpu # type: ignore wgpu_device = get_default_device() class WebGPUProgram: - def __init__(self, name: str, prg: str, binary=False): self.name,self.prg = name,wgpu_device.create_shader_module(code=prg) + def __init__(self, name: str, prg: str): self.name,self.prg = name,wgpu_device.create_shader_module(code=prg) def __call__(self, global_size, local_size, *bufs, wait=False): assert len(bufs) <= 8, "WEBGPU only supports 8 buffers" binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.storage}} for i in range(len(bufs))] @@ -42,4 +42,4 @@ class RawWebGPUBuffer(RawBufferCopyIn): def toCPU(self) -> np.ndarray: return np.frombuffer(wgpu_device.queue.read_buffer(self._buf, 0), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore renderer = functools.partial(uops_to_cstyle, WGSLLanguage()) -WebGpuBuffer = Compiled(RawWebGPUBuffer, LinearizerOptions(supports_float4=False, local_max=[256, 256, 64], global_max=[65535, 65535, 65535]), renderer, WebGPUProgram) +WebGpuBuffer = Compiled(RawWebGPUBuffer, LinearizerOptions(supports_float4=False, local_max=[256, 256, 64], global_max=[65535, 65535, 65535]), renderer, lambda x: x, WebGPUProgram)