support image dtype in cloud [pr] (#7482)

* support image dtype in cloud [pr]

* remove outdated osx hack

* unused imports
This commit is contained in:
George Hotz
2024-11-02 23:54:27 +08:00
committed by GitHub
parent 24d7fde63d
commit 72a9ac27e9
5 changed files with 25 additions and 18 deletions

View File

@@ -314,7 +314,10 @@ jobs:
run: CLANG=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
- if: ${{ matrix.task == 'onnx' }}
name: Run CLOUD=1 Test
run: CLOUDDEV=CLANG CLOUD=1 python3 test/test_ops.py TestOps.test_tiny_add
run: |
CLOUDDEV=CLANG CLOUD=1 python3 test/test_tiny.py
CLOUDDEV=GPU CLOUD=1 python3 test/test_tiny.py
CLOUDDEV=GPU IMAGE=2 CLOUD=1 python3 test/test_tiny.py
- if: ${{ matrix.task == 'onnx' }}
name: Test Action Space
run: PYTHONPATH="." GPU=1 python3 extra/optimization/get_action_space.py

View File

@@ -5,8 +5,8 @@ if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"
if "JIT_BATCH_SIZE" not in os.environ: os.environ["JIT_BATCH_SIZE"] = "0"
from tinygrad import fetch, Tensor, TinyJit, Device, Context, GlobalCounters
from tinygrad.helpers import OSX, DEBUG, getenv
from tinygrad import fetch, Tensor, TinyJit, Context, GlobalCounters
from tinygrad.helpers import DEBUG, getenv
from tinygrad.tensor import _from_np_dtype
import onnx
@@ -17,12 +17,6 @@ OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/comm
OUTPUT = "/tmp/openpilot.pkl"
def compile():
# hack to fix GPU on OSX: max doesn't work on half, see test/external/external_gpu_fail_osx.py
if OSX:
from tinygrad.ops import BinaryOps
from tinygrad.renderer.cstyle import ClangRenderer, CStyleLanguage
CStyleLanguage.code_for_op[BinaryOps.MAX] = ClangRenderer.code_for_op[BinaryOps.MAX]
Tensor.no_grad = True
Tensor.training = False

View File

@@ -1,6 +1,7 @@
# basic self-contained tests of the external functionality of tinygrad
import unittest
from tinygrad import Tensor, Context, Variable, TinyJit, dtypes, Device
from tinygrad.helpers import IMAGE
class TestTiny(unittest.TestCase):
@@ -22,7 +23,7 @@ class TestTiny(unittest.TestCase):
a = Tensor.ones(N,N).contiguous()
b = Tensor.eye(N).contiguous()
self.assertListEqual((out:=a@b).flatten().tolist(), [1.0]*(N*N))
self.assertEqual(out.dtype, out_dtype)
if IMAGE < 2: self.assertEqual(out.dtype, out_dtype)
# *** JIT (for Python speed) ***

View File

@@ -9,8 +9,9 @@ from typing import Tuple, Optional, Dict, Any, DefaultDict
from collections import defaultdict
import multiprocessing, functools, http.client, hashlib, json, time, contextlib, os, binascii
from dataclasses import dataclass, field
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap
from tinygrad.device import Compiled, Allocator, Compiler, Device
from tinygrad.dtype import dtypes
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, prod
from tinygrad.device import Compiled, Allocator, Compiler, Device, BufferOptions
from http.server import HTTPServer, BaseHTTPRequestHandler
# ***** backend *****
@@ -18,7 +19,8 @@ from http.server import HTTPServer, BaseHTTPRequestHandler
@dataclass
class CloudSession:
programs: Dict[Tuple[str, str], Any] = field(default_factory=dict)
buffers: Dict[int, Tuple[Any, int]] = field(default_factory=dict)
# TODO: the buffer should track this internally
buffers: Dict[int, Tuple[Any, int, Optional[BufferOptions]]] = field(default_factory=dict)
buffer_num = 0
class CloudHandler(BaseHTTPRequestHandler):
@@ -49,16 +51,20 @@ class CloudHandler(BaseHTTPRequestHandler):
ret = json.dumps((cls.__module__, cls.__name__, args)).encode()
elif self.path.startswith("/alloc") and method == "POST":
size = int(self.path.split("=")[-1])
buffer_options: Optional[BufferOptions] = None
if 'image' in self.path:
image_shape = tuple([int(x) for x in self.path.split("=")[-2].split("&")[0].split(",")])
buffer_options = BufferOptions(image=dtypes.imageh(image_shape) if prod(image_shape)*2 == size else dtypes.imagef(image_shape))
session.buffer_num += 1
session.buffers[session.buffer_num] = (Device[CloudHandler.dname].allocator.alloc(size), size)
session.buffers[session.buffer_num] = (Device[CloudHandler.dname].allocator.alloc(size, buffer_options), size, buffer_options)
ret = str(session.buffer_num).encode()
elif self.path.startswith("/buffer"):
key = int(self.path.split("/")[-1])
buf,sz = session.buffers[key]
buf,sz,buffer_options = session.buffers[key]
if method == "GET": Device[CloudHandler.dname].allocator.copyout(memoryview(ret:=bytearray(sz)), buf)
elif method == "PUT": Device[CloudHandler.dname].allocator.copyin(buf, memoryview(bytearray(self.get_data())))
elif method == "DELETE":
Device[CloudHandler.dname].allocator.free(buf,sz)
Device[CloudHandler.dname].allocator.free(buf,sz,buffer_options)
del session.buffers[key]
else: return self._fail()
elif self.path.startswith("/program"):
@@ -100,7 +106,10 @@ class CloudAllocator(Allocator):
def __init__(self, device:CloudDevice):
self.device = device
super().__init__()
def _alloc(self, size:int, options) -> int: return int(self.device.send("POST", f"alloc?size={size}"))
def _alloc(self, size:int, options) -> int:
# TODO: ideally we shouldn't have to deal with images here
extra = ("image="+','.join([str(x) for x in options.image.shape])+"&") if options.image is not None else ""
return int(self.device.send("POST", f"alloc?{extra}size={size}"))
def _free(self, opaque, options):
with contextlib.suppress(ConnectionRefusedError, http.client.CannotSendRequest, http.client.RemoteDisconnected):
self.device.send("DELETE", f"buffer/{opaque}", data=b"")

View File

@@ -261,7 +261,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
def _data(self) -> memoryview:
if 0 in self.shape: return memoryview(bytearray(0))
# NOTE: this realizes on the object from as_buffer being a Python object
cpu = self.cast(self.dtype.scalar()).contiguous().to("CLANG").realize()
cpu = self.cast(self.dtype.base).contiguous().to("CLANG").realize()
buf = cast(Buffer, cast(LazyBuffer, cpu.lazydata).base.realized)
if self.device != "CLANG": buf.options = BufferOptions(nolru=True)
return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False)