mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
support image dtype in cloud [pr] (#7482)
* support image dtype in cloud [pr]
* remove outdated osx hack
* unused imports
This commit is contained in:
5
.github/workflows/test.yml
vendored
5
.github/workflows/test.yml
vendored
@@ -314,7 +314,10 @@ jobs:
|
||||
run: CLANG=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
|
||||
- if: ${{ matrix.task == 'onnx' }}
|
||||
name: Run CLOUD=1 Test
|
||||
run: CLOUDDEV=CLANG CLOUD=1 python3 test/test_ops.py TestOps.test_tiny_add
|
||||
run: |
|
||||
CLOUDDEV=CLANG CLOUD=1 python3 test/test_tiny.py
|
||||
CLOUDDEV=GPU CLOUD=1 python3 test/test_tiny.py
|
||||
CLOUDDEV=GPU IMAGE=2 CLOUD=1 python3 test/test_tiny.py
|
||||
- if: ${{ matrix.task == 'onnx' }}
|
||||
name: Test Action Space
|
||||
run: PYTHONPATH="." GPU=1 python3 extra/optimization/get_action_space.py
|
||||
|
||||
@@ -5,8 +5,8 @@ if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
|
||||
if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"
|
||||
if "JIT_BATCH_SIZE" not in os.environ: os.environ["JIT_BATCH_SIZE"] = "0"
|
||||
|
||||
from tinygrad import fetch, Tensor, TinyJit, Device, Context, GlobalCounters
|
||||
from tinygrad.helpers import OSX, DEBUG, getenv
|
||||
from tinygrad import fetch, Tensor, TinyJit, Context, GlobalCounters
|
||||
from tinygrad.helpers import DEBUG, getenv
|
||||
from tinygrad.tensor import _from_np_dtype
|
||||
|
||||
import onnx
|
||||
@@ -17,12 +17,6 @@ OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/comm
|
||||
OUTPUT = "/tmp/openpilot.pkl"
|
||||
|
||||
def compile():
|
||||
# hack to fix GPU on OSX: max doesn't work on half, see test/external/external_gpu_fail_osx.py
|
||||
if OSX:
|
||||
from tinygrad.ops import BinaryOps
|
||||
from tinygrad.renderer.cstyle import ClangRenderer, CStyleLanguage
|
||||
CStyleLanguage.code_for_op[BinaryOps.MAX] = ClangRenderer.code_for_op[BinaryOps.MAX]
|
||||
|
||||
Tensor.no_grad = True
|
||||
Tensor.training = False
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# basic self-contained tests of the external functionality of tinygrad
|
||||
import unittest
|
||||
from tinygrad import Tensor, Context, Variable, TinyJit, dtypes, Device
|
||||
from tinygrad.helpers import IMAGE
|
||||
|
||||
class TestTiny(unittest.TestCase):
|
||||
|
||||
@@ -22,7 +23,7 @@ class TestTiny(unittest.TestCase):
|
||||
a = Tensor.ones(N,N).contiguous()
|
||||
b = Tensor.eye(N).contiguous()
|
||||
self.assertListEqual((out:=a@b).flatten().tolist(), [1.0]*(N*N))
|
||||
self.assertEqual(out.dtype, out_dtype)
|
||||
if IMAGE < 2: self.assertEqual(out.dtype, out_dtype)
|
||||
|
||||
# *** JIT (for Python speed) ***
|
||||
|
||||
|
||||
@@ -9,8 +9,9 @@ from typing import Tuple, Optional, Dict, Any, DefaultDict
|
||||
from collections import defaultdict
|
||||
import multiprocessing, functools, http.client, hashlib, json, time, contextlib, os, binascii
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap
|
||||
from tinygrad.device import Compiled, Allocator, Compiler, Device
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, prod
|
||||
from tinygrad.device import Compiled, Allocator, Compiler, Device, BufferOptions
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
|
||||
# ***** backend *****
|
||||
@@ -18,7 +19,8 @@ from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
@dataclass
|
||||
class CloudSession:
|
||||
programs: Dict[Tuple[str, str], Any] = field(default_factory=dict)
|
||||
buffers: Dict[int, Tuple[Any, int]] = field(default_factory=dict)
|
||||
# TODO: the buffer should track this internally
|
||||
buffers: Dict[int, Tuple[Any, int, Optional[BufferOptions]]] = field(default_factory=dict)
|
||||
buffer_num = 0
|
||||
|
||||
class CloudHandler(BaseHTTPRequestHandler):
|
||||
@@ -49,16 +51,20 @@ class CloudHandler(BaseHTTPRequestHandler):
|
||||
ret = json.dumps((cls.__module__, cls.__name__, args)).encode()
|
||||
elif self.path.startswith("/alloc") and method == "POST":
|
||||
size = int(self.path.split("=")[-1])
|
||||
buffer_options: Optional[BufferOptions] = None
|
||||
if 'image' in self.path:
|
||||
image_shape = tuple([int(x) for x in self.path.split("=")[-2].split("&")[0].split(",")])
|
||||
buffer_options = BufferOptions(image=dtypes.imageh(image_shape) if prod(image_shape)*2 == size else dtypes.imagef(image_shape))
|
||||
session.buffer_num += 1
|
||||
session.buffers[session.buffer_num] = (Device[CloudHandler.dname].allocator.alloc(size), size)
|
||||
session.buffers[session.buffer_num] = (Device[CloudHandler.dname].allocator.alloc(size, buffer_options), size, buffer_options)
|
||||
ret = str(session.buffer_num).encode()
|
||||
elif self.path.startswith("/buffer"):
|
||||
key = int(self.path.split("/")[-1])
|
||||
buf,sz = session.buffers[key]
|
||||
buf,sz,buffer_options = session.buffers[key]
|
||||
if method == "GET": Device[CloudHandler.dname].allocator.copyout(memoryview(ret:=bytearray(sz)), buf)
|
||||
elif method == "PUT": Device[CloudHandler.dname].allocator.copyin(buf, memoryview(bytearray(self.get_data())))
|
||||
elif method == "DELETE":
|
||||
Device[CloudHandler.dname].allocator.free(buf,sz)
|
||||
Device[CloudHandler.dname].allocator.free(buf,sz,buffer_options)
|
||||
del session.buffers[key]
|
||||
else: return self._fail()
|
||||
elif self.path.startswith("/program"):
|
||||
@@ -100,7 +106,10 @@ class CloudAllocator(Allocator):
|
||||
def __init__(self, device:CloudDevice):
|
||||
self.device = device
|
||||
super().__init__()
|
||||
def _alloc(self, size:int, options) -> int: return int(self.device.send("POST", f"alloc?size={size}"))
|
||||
def _alloc(self, size:int, options) -> int:
|
||||
# TODO: ideally we shouldn't have to deal with images here
|
||||
extra = ("image="+','.join([str(x) for x in options.image.shape])+"&") if options.image is not None else ""
|
||||
return int(self.device.send("POST", f"alloc?{extra}size={size}"))
|
||||
def _free(self, opaque, options):
|
||||
with contextlib.suppress(ConnectionRefusedError, http.client.CannotSendRequest, http.client.RemoteDisconnected):
|
||||
self.device.send("DELETE", f"buffer/{opaque}", data=b"")
|
||||
|
||||
@@ -261,7 +261,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
def _data(self) -> memoryview:
|
||||
if 0 in self.shape: return memoryview(bytearray(0))
|
||||
# NOTE: this realizes on the object from as_buffer being a Python object
|
||||
cpu = self.cast(self.dtype.scalar()).contiguous().to("CLANG").realize()
|
||||
cpu = self.cast(self.dtype.base).contiguous().to("CLANG").realize()
|
||||
buf = cast(Buffer, cast(LazyBuffer, cpu.lazydata).base.realized)
|
||||
if self.device != "CLANG": buf.options = BufferOptions(nolru=True)
|
||||
return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False)
|
||||
|
||||
Reference in New Issue
Block a user