fix device canonicalize for :0 in middle [pr] (#8193)

replace is wrong because it does not check if `:0` is at the end. use re.sub instead
This commit is contained in:
chenyu
2024-12-12 16:32:36 -05:00
committed by GitHub
parent 40a4c603b9
commit d47530c0d4
2 changed files with 17 additions and 18 deletions

View File

@@ -8,18 +8,17 @@ from tinygrad.helpers import diskcache_get, diskcache_put, getenv
class TestDevice(unittest.TestCase):
def test_canonicalize(self):
assert Device.canonicalize(None) == Device.DEFAULT
assert Device.canonicalize("CPU") == "CPU"
assert Device.canonicalize("cpu") == "CPU"
assert Device.canonicalize("GPU") == "GPU"
assert Device.canonicalize("GPU:0") == "GPU"
assert Device.canonicalize("gpu:0") == "GPU"
assert Device.canonicalize("GPU:1") == "GPU:1"
assert Device.canonicalize("gpu:1") == "GPU:1"
assert Device.canonicalize("GPU:2") == "GPU:2"
assert Device.canonicalize("disk:/dev/shm/test") == "DISK:/dev/shm/test"
# TODO: fix this
# assert Device.canonicalize("disk:000.txt") == "DISK:000.txt"
self.assertEqual(Device.canonicalize(None), Device.DEFAULT)
self.assertEqual(Device.canonicalize("CPU"), "CPU")
self.assertEqual(Device.canonicalize("cpu"), "CPU")
self.assertEqual(Device.canonicalize("GPU"), "GPU")
self.assertEqual(Device.canonicalize("GPU:0"), "GPU")
self.assertEqual(Device.canonicalize("gpu:0"), "GPU")
self.assertEqual(Device.canonicalize("GPU:1"), "GPU:1")
self.assertEqual(Device.canonicalize("gpu:1"), "GPU:1")
self.assertEqual(Device.canonicalize("GPU:2"), "GPU:2")
self.assertEqual(Device.canonicalize("disk:/dev/shm/test"), "DISK:/dev/shm/test")
self.assertEqual(Device.canonicalize("disk:000.txt"), "DISK:000.txt")
def test_getitem_not_exist(self):
with self.assertRaises(ModuleNotFoundError):
@@ -34,15 +33,15 @@ class TestCompiler(unittest.TestCase):
diskcache_put("key", "123", None) # clear cache
getenv.cache_clear()
with patch.dict(os.environ, {"DISABLE_COMPILER_CACHE": "0"}, clear=True):
assert MockCompiler("key").compile_cached("123") == str.encode("123")
assert diskcache_get("key", "123") == str.encode("123")
self.assertEqual(MockCompiler("key").compile_cached("123"), str.encode("123"))
self.assertEqual(diskcache_get("key", "123"), str.encode("123"))
def test_compile_cached_disabled(self):
diskcache_put("disabled_key", "123", None) # clear cache
getenv.cache_clear()
with patch.dict(os.environ, {"DISABLE_COMPILER_CACHE": "1"}, clear=True):
assert MockCompiler("disabled_key").compile_cached("123") == str.encode("123")
assert diskcache_get("disabled_key", "123") is None
self.assertEqual(MockCompiler("disabled_key").compile_cached("123"), str.encode("123"))
self.assertIsNone(diskcache_get("disabled_key", "123"))
def test_device_compile(self):
getenv.cache_clear()

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from dataclasses import dataclass, replace
from collections import defaultdict
from typing import Optional, Dict, Tuple, Any, Iterator
import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, contextlib, sys
import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, contextlib, sys, re
from tinygrad.helpers import CI, OSX, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes
from tinygrad.renderer import Renderer
@@ -14,7 +14,7 @@ class _Device:
def __init__(self) -> None:
self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
def _canonicalize(self, device:str) -> str: return ((d:=device.split(":", 1)[0].upper()) + device[len(d):]).replace(":0", "")
def _canonicalize(self, device:str) -> str: return re.sub(r":0$", "", (d:=device.split(":", 1)[0].upper()) + device[len(d):])
# NOTE: you can't cache canonicalize in case Device.DEFAULT changes
def canonicalize(self, device:Optional[str]) -> str: return self._canonicalize(device) if device is not None else Device.DEFAULT
def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))