mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
refactor: ops_cloud -> ops_remote [pr] (#10166)
This commit is contained in:
24
.github/workflows/test.yml
vendored
24
.github/workflows/test.yml
vendored
@@ -471,11 +471,11 @@ jobs:
|
||||
run: CPU=1 PYTHONPATH=. python3 test/external/external_test_onnx_ops.py
|
||||
- name: Test Quantize ONNX
|
||||
run: CPU=1 PYTHONPATH=. python3 test/test_quantize_onnx.py
|
||||
- name: Run CLOUD=1 Test
|
||||
- name: Run REMOTE=1 Test
|
||||
run: |
|
||||
CLOUDDEV=CPU CLOUD=1 python3 -m pytest test/test_tiny.py test/test_jit.py
|
||||
CLOUDDEV=GPU CLOUD=1 python3 -m pytest test/test_tiny.py test/test_image_dtype.py test/test_jit.py
|
||||
CLOUDDEV=GPU IMAGE=2 CLOUD=1 python3 -m pytest test/test_tiny.py test/test_image_dtype.py
|
||||
REMOTEDEV=CPU REMOTE=1 python3 -m pytest test/test_tiny.py test/test_jit.py
|
||||
REMOTEDEV=GPU REMOTE=1 python3 -m pytest test/test_tiny.py test/test_image_dtype.py test/test_jit.py
|
||||
REMOTEDEV=GPU IMAGE=2 REMOTE=1 python3 -m pytest test/test_tiny.py test/test_image_dtype.py
|
||||
- name: Test Optimization Helpers
|
||||
run: PYTHONPATH="." DEBUG=1 python3 extra/optimization/test_helpers.py
|
||||
- name: Test Action Space
|
||||
@@ -770,27 +770,27 @@ jobs:
|
||||
- name: Run WEBGPU Efficientnet
|
||||
run: node test/web/test_webgpu.js
|
||||
|
||||
osxcloud:
|
||||
name: MacOS (cloud)
|
||||
osxremote:
|
||||
name: MacOS (remote)
|
||||
runs-on: macos-15
|
||||
timeout-minutes: 10
|
||||
env:
|
||||
CLOUD: 1
|
||||
CLOUDDEV: METAL
|
||||
REMOTE: 1
|
||||
REMOTEDEV: METAL
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v4
|
||||
- name: Setup Environment
|
||||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: macos-cloud
|
||||
key: macos-remote
|
||||
deps: testing_minimal
|
||||
- name: Check Device.DEFAULT and print some source
|
||||
run: |
|
||||
python -c "from tinygrad import Device; assert Device.DEFAULT == 'CLOUD', Device.DEFAULT"
|
||||
python -c "from tinygrad import Device; assert Device.default.properties['clouddev'] == 'METAL', Device.default.properties['clouddev']"
|
||||
python -c "from tinygrad import Device; assert Device.DEFAULT == 'REMOTE', Device.DEFAULT"
|
||||
python -c "from tinygrad import Device; assert Device.default.properties['remotedev'] == 'METAL', Device.default.properties['remotedev']"
|
||||
DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
|
||||
- name: Run CLOUD=1 Test
|
||||
- name: Run REMOTE=1 Test
|
||||
run: |
|
||||
python3 -m pytest test/test_tiny.py test/test_jit.py
|
||||
|
||||
|
||||
2
test/external/external_test_example.py
vendored
2
test/external/external_test_example.py
vendored
@@ -7,7 +7,7 @@ def multidevice_test(fxn):
|
||||
exclude_devices = getenv("EXCLUDE_DEVICES", "").split(",")
|
||||
def ret(self):
|
||||
for device in Device._devices:
|
||||
if device in ["CLOUD", "DISK", "NPY", "FAKE", "DSP", "NULL"]: continue
|
||||
if device in ["REMOTE", "DISK", "NPY", "FAKE", "DSP", "NULL"]: continue
|
||||
if not CI: print(device)
|
||||
if device in exclude_devices:
|
||||
if not CI: print(f"WARNING: {device} test is excluded")
|
||||
|
||||
@@ -65,5 +65,5 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None):
|
||||
return out_buf.cast(uop.dtype.fmt).tolist()[0]
|
||||
|
||||
def not_support_multi_device():
|
||||
# CLOUD doesn't support multi device anywhere, GPU, CUDA and METAL don't support multi device if in CI
|
||||
return Device.DEFAULT == "CLOUD" or (CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"))
|
||||
# REMOTE doesn't support multi device anywhere, GPU, CUDA and METAL don't support multi device if in CI
|
||||
return Device.DEFAULT == "REMOTE" or (CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"))
|
||||
|
||||
@@ -7,7 +7,7 @@ from tinygrad.engine.realize import lower_schedule
|
||||
from tinygrad.helpers import prod, unwrap
|
||||
|
||||
IMAGE_SUPPORTED_DEVICES = ("QCOM", "GPU")
|
||||
REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "CLOUD" else Device['CLOUD'].properties['clouddev'])
|
||||
REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties['remotedev'])
|
||||
|
||||
@unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
|
||||
class TestImageCopy(unittest.TestCase):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# the CLOUD=1 device is a process boundary between the frontend/runtime
|
||||
# the REMOTE=1 device is a process boundary between the frontend/runtime
|
||||
# normally tinygrad is frontend <-> middleware <-> runtime <-> hardware
|
||||
# with CLOUD tinygrad is frontend <-> middleware <-> CloudDevice ///HTTP/// cloud_server <-> runtime <-> hardware
|
||||
# with REMOTE tinygrad is frontend <-> middleware <-> RemoteDevice ///HTTP/// remote_server <-> runtime <-> hardware
|
||||
# this client and server can be on the same machine, same network, or just same internet
|
||||
# it should be a secure (example: no use of pickle) boundary. HTTP is used for RPC
|
||||
|
||||
@@ -17,28 +17,28 @@ from tinygrad.device import Compiled, Buffer, Allocator, Compiler, Device, Buffe
|
||||
|
||||
# ***** API *****
|
||||
|
||||
class CloudRequest: pass
|
||||
class RemoteRequest: pass
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BufferAlloc(CloudRequest): buffer_num: int; size: int; options: BufferSpec # noqa: E702
|
||||
class BufferAlloc(RemoteRequest): buffer_num: int; size: int; options: BufferSpec # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BufferFree(CloudRequest): buffer_num: int # noqa: E702
|
||||
class BufferFree(RemoteRequest): buffer_num: int # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CopyIn(CloudRequest): buffer_num: int; datahash: str # noqa: E702
|
||||
class CopyIn(RemoteRequest): buffer_num: int; datahash: str # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CopyOut(CloudRequest): buffer_num: int
|
||||
class CopyOut(RemoteRequest): buffer_num: int
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProgramAlloc(CloudRequest): name: str; datahash: str # noqa: E702
|
||||
class ProgramAlloc(RemoteRequest): name: str; datahash: str # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProgramFree(CloudRequest): name: str; datahash: str # noqa: E702
|
||||
class ProgramFree(RemoteRequest): name: str; datahash: str # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProgramExec(CloudRequest):
|
||||
class ProgramExec(RemoteRequest):
|
||||
name: str; datahash: str; bufs: tuple[int, ...]; vals: tuple[int, ...] # noqa: E702
|
||||
global_size: Optional[tuple[int, ...]]; local_size: Optional[tuple[int, ...]]; wait: bool # noqa: E702
|
||||
|
||||
@@ -51,13 +51,13 @@ def safe_eval(node): return eval_fxns[node.__class__](node)
|
||||
|
||||
class BatchRequest:
|
||||
def __init__(self):
|
||||
self._q: list[CloudRequest] = []
|
||||
self._q: list[RemoteRequest] = []
|
||||
self._h: dict[str, bytes] = {}
|
||||
def h(self, d:bytes) -> str:
|
||||
binhash = hashlib.sha256(d).digest()
|
||||
self._h[datahash:=binascii.hexlify(binhash).decode()] = binhash+struct.pack("<Q", len(d))+d
|
||||
return datahash
|
||||
def q(self, x:CloudRequest): self._q.append(x)
|
||||
def q(self, x:RemoteRequest): self._q.append(x)
|
||||
def serialize(self) -> bytes:
|
||||
self.h(repr(self._q).encode())
|
||||
return b''.join(self._h.values())
|
||||
@@ -73,21 +73,21 @@ class BatchRequest:
|
||||
# ***** backend *****
|
||||
|
||||
@dataclass
|
||||
class CloudSession:
|
||||
class RemoteSession:
|
||||
programs: dict[tuple[str, str], Any] = field(default_factory=dict)
|
||||
buffers: dict[int, Buffer] = field(default_factory=dict)
|
||||
|
||||
class CloudHandler(BaseHTTPRequestHandler):
|
||||
class RemoteHandler(BaseHTTPRequestHandler):
|
||||
protocol_version = 'HTTP/1.1'
|
||||
device: str
|
||||
sessions: defaultdict[str, CloudSession] = defaultdict(CloudSession)
|
||||
sessions: defaultdict[str, RemoteSession] = defaultdict(RemoteSession)
|
||||
|
||||
def setup(self):
|
||||
super().setup()
|
||||
print(f"connection established with {self.client_address}, socket: {self.connection.fileno()}")
|
||||
|
||||
def _do(self, method):
|
||||
session = CloudHandler.sessions[unwrap(self.headers.get("Cookie")).split("session=")[1]]
|
||||
session = RemoteHandler.sessions[unwrap(self.headers.get("Cookie")).split("session=")[1]]
|
||||
ret, status_code = b"", 200
|
||||
if self.path == "/batch" and method == "POST":
|
||||
# TODO: streaming deserialize?
|
||||
@@ -98,13 +98,13 @@ class CloudHandler(BaseHTTPRequestHandler):
|
||||
match c:
|
||||
case BufferAlloc():
|
||||
assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
|
||||
session.buffers[c.buffer_num] = Buffer(CloudHandler.device, c.size, dtypes.uint8, options=c.options, preallocate=True)
|
||||
session.buffers[c.buffer_num] = Buffer(RemoteHandler.device, c.size, dtypes.uint8, options=c.options, preallocate=True)
|
||||
case BufferFree(): del session.buffers[c.buffer_num]
|
||||
case CopyIn(): session.buffers[c.buffer_num].copyin(memoryview(bytearray(req._h[c.datahash])))
|
||||
case CopyOut(): session.buffers[c.buffer_num].copyout(memoryview(ret:=bytearray(session.buffers[c.buffer_num].nbytes)))
|
||||
case ProgramAlloc():
|
||||
lib = Device[CloudHandler.device].compiler.compile_cached(req._h[c.datahash].decode())
|
||||
session.programs[(c.name, c.datahash)] = Device[CloudHandler.device].runtime(c.name, lib)
|
||||
lib = Device[RemoteHandler.device].compiler.compile_cached(req._h[c.datahash].decode())
|
||||
session.programs[(c.name, c.datahash)] = Device[RemoteHandler.device].runtime(c.name, lib)
|
||||
case ProgramFree(): del session.programs[(c.name, c.datahash)]
|
||||
case ProgramExec():
|
||||
bufs = [session.buffers[x]._buf for x in c.bufs]
|
||||
@@ -112,8 +112,8 @@ class CloudHandler(BaseHTTPRequestHandler):
|
||||
r = session.programs[(c.name, c.datahash)](*bufs, vals=c.vals, wait=c.wait, **extra_args)
|
||||
if r is not None: ret = str(r).encode()
|
||||
elif self.path == "/properties" and method == "GET":
|
||||
cls, args = Device[CloudHandler.device].renderer.__reduce__()
|
||||
ret = json.dumps({'clouddev': CloudHandler.device, 'renderer': (cls.__module__, cls.__name__, args)}).encode()
|
||||
cls, args = Device[RemoteHandler.device].renderer.__reduce__()
|
||||
ret = json.dumps({'remotedev': RemoteHandler.device, 'renderer': (cls.__module__, cls.__name__, args)}).encode()
|
||||
else: status_code = 404
|
||||
self.send_response(status_code)
|
||||
self.send_header('Content-Length', str(len(ret)))
|
||||
@@ -123,16 +123,16 @@ class CloudHandler(BaseHTTPRequestHandler):
|
||||
def do_GET(self): return self._do("GET")
|
||||
def do_POST(self): return self._do("POST")
|
||||
|
||||
def cloud_server(port:int):
|
||||
CloudHandler.device = getenv("CLOUDDEV", next(Device.get_available_devices()) if Device.DEFAULT == "CLOUD" else Device.DEFAULT)
|
||||
print(f"start cloud server on {port} with device {CloudHandler.device}")
|
||||
server = HTTPServer(('', port), CloudHandler)
|
||||
def remote_server(port:int):
|
||||
RemoteHandler.device = getenv("REMOTEDEV", next(Device.get_available_devices()) if Device.DEFAULT == "REMOTE" else Device.DEFAULT)
|
||||
print(f"start remote server on {port} with device {RemoteHandler.device}")
|
||||
server = HTTPServer(('', port), RemoteHandler)
|
||||
server.serve_forever()
|
||||
|
||||
# ***** frontend *****
|
||||
|
||||
class CloudAllocator(Allocator):
|
||||
def __init__(self, dev:CloudDevice):
|
||||
class RemoteAllocator(Allocator):
|
||||
def __init__(self, dev:RemoteDevice):
|
||||
self.device = dev
|
||||
super().__init__()
|
||||
# TODO: ideally we shouldn't have to deal with images here
|
||||
@@ -149,8 +149,8 @@ class CloudAllocator(Allocator):
|
||||
assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}"
|
||||
dest[:] = resp
|
||||
|
||||
class CloudProgram:
|
||||
def __init__(self, dev:CloudDevice, name:str, lib:bytes):
|
||||
class RemoteProgram:
|
||||
def __init__(self, dev:RemoteDevice, name:str, lib:bytes):
|
||||
self.dev, self.name = dev, name
|
||||
self.datahash = self.dev.req.h(lib)
|
||||
self.dev.req.q(ProgramAlloc(self.name, self.datahash))
|
||||
@@ -161,11 +161,11 @@ class CloudProgram:
|
||||
self.dev.req.q(ProgramExec(self.name, self.datahash, bufs, vals, global_size, local_size, wait))
|
||||
if wait: return float(self.dev.batch_submit())
|
||||
|
||||
class CloudDevice(Compiled):
|
||||
class RemoteDevice(Compiled):
|
||||
def __init__(self, device:str):
|
||||
if (host:=getenv("HOST", "")) != "": self.host = host
|
||||
else:
|
||||
multiprocessing.Process(target=cloud_server, args=(6667,), name="MainProcess", daemon=True).start()
|
||||
multiprocessing.Process(target=remote_server, args=(6667,), name="MainProcess", daemon=True).start()
|
||||
self.host = "127.0.0.1:6667"
|
||||
|
||||
# state for the connection
|
||||
@@ -173,7 +173,7 @@ class CloudDevice(Compiled):
|
||||
self.buffer_num = 0
|
||||
self.req: BatchRequest = BatchRequest()
|
||||
|
||||
if DEBUG >= 1: print(f"cloud with host {self.host}")
|
||||
if DEBUG >= 1: print(f"remote with host {self.host}")
|
||||
while 1:
|
||||
try:
|
||||
self.conn = http.client.HTTPConnection(self.host, timeout=60.0)
|
||||
@@ -182,13 +182,13 @@ class CloudDevice(Compiled):
|
||||
except Exception as e:
|
||||
print(e)
|
||||
time.sleep(0.1)
|
||||
if DEBUG >= 1: print(f"remote has device {self.properties['clouddev']}")
|
||||
if DEBUG >= 1: print(f"remote has device {self.properties['remotedev']}")
|
||||
# TODO: how do we have BEAM cached on the backend? this should just send a specification of the compute. rethink what goes in Renderer
|
||||
renderer = self.properties['renderer']
|
||||
if not renderer[0].startswith("tinygrad.renderer.") or not renderer[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {renderer}")
|
||||
renderer_class = fromimport(renderer[0], renderer[1]) # TODO: is this secure?
|
||||
if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {renderer}")
|
||||
super().__init__(device, CloudAllocator(self), renderer_class(*renderer[2]), Compiler(), functools.partial(CloudProgram, self))
|
||||
super().__init__(device, RemoteAllocator(self), renderer_class(*renderer[2]), Compiler(), functools.partial(RemoteProgram, self))
|
||||
|
||||
def __del__(self):
|
||||
# TODO: this is never being called
|
||||
@@ -209,4 +209,4 @@ class CloudDevice(Compiled):
|
||||
assert response.status == 200, f"failed on {method} {path}"
|
||||
return response.read()
|
||||
|
||||
if __name__ == "__main__": cloud_server(getenv("PORT", 6667))
|
||||
if __name__ == "__main__": remote_server(getenv("PORT", 6667))
|
||||
Reference in New Issue
Block a user