mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
69 lines
2.7 KiB
Python
69 lines
2.7 KiB
Python
#!/usr/bin/env python3
|
|
import subprocess, json, sys, os
|
|
|
|
# Test configuration; every value below can be overridden through the environment.
# Host reached over ssh; it runs the --server half of the test.
REMOTE_HOST = os.getenv("REMOTE_HOST", "192.168.52.154")
# PCI addresses of the local and remote NIC (note: the local one is read from MLX_PCI).
LOCAL_PCI, REMOTE_PCI = os.getenv("MLX_PCI", "0000:41:00.0"), os.getenv("REMOTE_PCI", "0000:41:00.0")
# IPs given to the local MLXDev and to the remote server (via MLX_IP) respectively.
LOCAL_IP, REMOTE_IP = os.getenv("LOCAL_IP", "10.0.0.1"), os.getenv("REMOTE_IP", "10.0.0.2")

# Base ssh argv; host key checking is disabled so the session never stops at an interactive prompt.
SSH = ["ssh", "-o", "StrictHostKeyChecking=no", REMOTE_HOST]
# Repository root: two levels above the directory containing this script (left unnormalized).
TINYGRAD = f"{os.path.dirname(os.path.abspath(__file__))}/../.."

# Mirror the local checkout into ~/tinygrad on the remote host, skipping git metadata and
# bytecode, so the remote server below runs the same code as this side.
print("syncing code to remote")
rsync_cmd = ["rsync", "-az"]
rsync_cmd += [f"--exclude={pattern}" for pattern in (".git", "__pycache__", "*.pyc")]
rsync_cmd += [TINYGRAD + "/", f"{REMOTE_HOST}:~/tinygrad/"]
subprocess.run(rsync_cmd, check=True)

# Start the remote half of the test: mlxdev.py in --server mode, run as root on the remote
# host over ssh. We exchange newline-delimited messages with it through stdin/stdout.
print("booting remote")
remote = subprocess.Popen(
  SSH + [f"cd ~/tinygrad && sudo PYTHONPATH=. MLX_DEBUG=1 MLX_PCI={REMOTE_PCI} MLX_IP={REMOTE_IP} python3 extra/mlx_driver/mlxdev.py --server"],
  stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=sys.stderr, text=True)

# Echo the server's output until it publishes its connection info as a JSON object.
# Only dicts are accepted: a bare debug line such as "0" or "null" is also valid JSON.
remote_info = None
for line in iter(remote.stdout.readline, ''):
  print(f" [remote] {line}", end='')
  try: parsed = json.loads(line.strip())
  except json.JSONDecodeError: continue
  if isinstance(parsed, dict):
    remote_info = parsed
    break

# Fail explicitly (an assert is stripped under `python -O`) and stop the ssh session instead
# of leaving it behind. NOTE(review): killing the local ssh client may not terminate the
# remote sudo process if no tty is allocated — confirm.
if not remote_info:
  remote.kill()
  remote.wait()
  raise RuntimeError("failed to get remote connection info")

# Bring up the local NIC and a queue pair on it. The repo root must be on sys.path before the
# project imports below, which is why they live here rather than at the top of the file.
print("booting local")
# TINYGRAD is the same repo root that was rsynced above; reuse it instead of recomputing it.
sys.path.insert(0, TINYGRAD)
from extra.mlx_driver.mlxdev import MLXDev, MLXQP
from tinygrad.runtime.support.system import PCIDevice

local_dev = MLXDev(PCIDevice("mlx5", LOCAL_PCI), ip=LOCAL_IP)
local_qp = MLXQP(local_dev)
# Our half of the connection info, in the same shape the remote's info is consumed in:
# QP number, 6-byte big-endian MAC as hex, and GID as hex.
local_info = {"qpn": local_qp.qpn, "mac": local_dev.mac.to_bytes(6,'big').hex(), "gid": local_dev.local_gid.hex()}

# Handshake: send our QP info to the server, wait for it to report that it is connected,
# then connect the local QP to the remote one.
remote.stdin.write(json.dumps(local_info) + "\n")
remote.stdin.flush()
# NOTE(review): plain substring match, so a line containing e.g. "disconnected" also matches.
for line in iter(remote.stdout.readline, ''):
  print(f" [remote] {line}", end='')
  if "connected" in line: break
else:
  # stdout reached EOF without the confirmation (the server most likely exited),
  # so don't go on to claim that both QPs are up.
  remote.kill()
  remote.wait()
  raise RuntimeError("remote closed its output before reporting 'connected'")

local_qp.connect(remote_info["qpn"], int(remote_info["mac"], 16), int(remote_info["gid"], 16))
print("both QPs in RTS")

# Wait for the server to publish the buffer we should write into, as a JSON object; the
# RDMA write below reads its "target_addr" and "rkey" keys. Only dicts are accepted,
# because bare debug lines such as "0" or "null" are also valid JSON.
remote_target = None
for line in iter(remote.stdout.readline, ''):
  print(f" [remote] {line}", end='')
  try: parsed = json.loads(line.strip())
  except json.JSONDecodeError: continue
  if isinstance(parsed, dict):
    remote_target = parsed
    break

# Fail explicitly (an assert is stripped under `python -O`) and stop the ssh session.
if not remote_target:
  remote.kill()
  remote.wait()
  raise RuntimeError("failed to get remote RDMA target info")

# Stage a short ASCII payload in a 0x1000-byte host buffer from alloc_sysmem, then push it
# into the remote buffer with a single RDMA WRITE.
test_msg = b"Test message, rdma works!"
src_mem, src_paddrs = local_dev.pci_dev.alloc_sysmem(0x1000)
# Copied one byte at a time: only integer item assignment on src_mem is relied on here
# (TODO confirm whether its type also supports slice assignment).
for offset, byte in enumerate(test_msg):
  src_mem[offset] = byte

target_addr = remote_target["target_addr"]
print(f"RDMA WRITE {len(test_msg)}B to remote phys 0x{target_addr:x}")
# Only the first entry of src_paddrs is passed as the source; the payload is far smaller
# than the 0x1000-byte allocation. (Presumably these are physical addresses — confirm.)
local_qp.rdma_write(target_addr, remote_target["rkey"], src_paddrs[0], local_dev.mkey, len(test_msg))

# Tell the server the write is finished, echo its output until the "AS TEXT" line
# (presumably its dump of the received buffer — confirm in mlxdev.py), then shut it down.
remote.stdin.write("done\n")
remote.stdin.flush()
saw_result = False
for line in iter(remote.stdout.readline, ''):
  print(f" [remote] {line}", end='')
  if "AS TEXT" in line:
    saw_result = True
    break

remote.stdin.close()
returncode = remote.wait()
# If stdout ended before the "AS TEXT" line, the remote never showed what it received,
# so don't report success.
if not saw_result:
  raise RuntimeError(f"remote exited (code {returncode}) before reporting the received data")
print("RDMA WRITE test complete")