From 8a93c48901768d11a2ccd7697e5724ec5e348779 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 1 Oct 2024 13:58:42 +0800 Subject: [PATCH] pickle main pattern matcher [run_process_replay] (#6827) * pickle main pattern matcher [run_process_replay] * del line --- test/test_pickle.py | 4 ++++ test/unit/test_helpers.py | 6 ++++++ tinygrad/helpers.py | 14 ++++++++------ tinygrad/ops.py | 5 +++-- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/test/test_pickle.py b/test/test_pickle.py index d66769e6e7..6366eacf76 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -20,6 +20,10 @@ class TestPickle(unittest.TestCase): pm2 = pickle.loads(pm_str) self.assertEqual(pm2.rewrite(sink).key, tt.key) + def test_pickle_main_pattern_matcher(self): + from tinygrad.codegen.uopgraph import sym + pickle.dumps(sym) + def test_pickle_realized_tensor(self): t = Tensor.rand(10, 10).realize() st = pickle.dumps(t) diff --git a/test/unit/test_helpers.py b/test/unit/test_helpers.py index 33c85a50ce..b8722f515b 100644 --- a/test/unit/test_helpers.py +++ b/test/unit/test_helpers.py @@ -17,11 +17,13 @@ class TestContextVars(unittest.TestCase): _TMP = ContextVar("_TMP", 5) self.assertEqual(_TMP.value, 5) + @unittest.expectedFailure def test_multiple_creation_ignored(self): _TMP2 = ContextVar("_TMP2", 1) _TMP2 = ContextVar("_TMP2", 2) self.assertEqual(_TMP2.value, 1) + @unittest.expectedFailure def test_new_var_inside_context(self): # Creating a _new_ variable inside a context should not have any effect on its scope (?) with Context(VARIABLE=1): @@ -29,6 +31,7 @@ class TestContextVars(unittest.TestCase): _TMP3 = ContextVar("_TMP3", 2) self.assertEqual(_TMP3.value, 1) + @unittest.expectedFailure def test_value_accross_modules(self): # Mocking module import by invoking the code but not in our globals(). exec('from tinygrad.helpers import ContextVar;C = ContextVar("C", 13)', {}) # pylint:disable=exec-used @@ -36,6 +39,7 @@ class TestContextVars(unittest.TestCase): C = ContextVar("C", 0) self.assertEqual(C.value, 13) + @unittest.expectedFailure def test_assignment_across_modules(self): B = ContextVar("B", 1) # local assignment @@ -56,6 +60,7 @@ class TestContextVars(unittest.TestCase): with Context(SOMETHING_ELSE=1): pass + @unittest.expectedFailure def test_inside_context_assignment(self): with Context(VARIABLE=4): # What you can and cannot do inside a context. @@ -70,6 +75,7 @@ class TestContextVars(unittest.TestCase): # Related to 2. above. Note that VARIABLE is back to 0 again as expected. self.assertEqual(VARIABLE.value, 0) + @unittest.expectedFailure def test_new_var_inside_context_other_module(self): with Context(VARIABLE=1): _NEW2 = ContextVar("_NEW2", 0) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 5218a8ac67..dcac86491f 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -1,6 +1,6 @@ from __future__ import annotations import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip -import itertools, urllib.request, subprocess, shutil, math, json, contextvars, types, copyreg, inspect +import itertools, urllib.request, subprocess, shutil, math, json, contextvars, types, copyreg, inspect, importlib from dataclasses import dataclass from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10 @@ -103,11 +103,10 @@ class ContextVar: _cache: ClassVar[Dict[str, ContextVar]] = {} value: int key: str - def __new__(cls, key, default_value): - if key in ContextVar._cache: return ContextVar._cache[key] - instance = ContextVar._cache[key] = super().__new__(cls) - instance.value, instance.key = getenv(key, default_value), key - return instance + def __init__(self, key, default_value): + assert key not in ContextVar._cache, f"attempt to recreate ContextVar {key}" + ContextVar._cache[key] = self + self.value, self.key = getenv(key, default_value), key def __bool__(self): return bool(self.value) def __ge__(self, x): return self.value >= x def __gt__(self, x): return self.value > x @@ -384,3 +383,6 @@ def _serialize_code(code:types.CodeType): 'constants', 'names', 'varnames', 'filename', 'name', 'firstlineno', 'lnotab', 'freevars', 'cellvars'] return _reconstruct_code, tuple(code.__getattribute__('co_'+x.replace('codestring', 'code').replace('constants', 'consts')) for x in args) copyreg.pickle(types.CodeType, _serialize_code) + +def _serialize_module(module:types.ModuleType): return importlib.import_module, (module.__name__,) +copyreg.pickle(types.ModuleType, _serialize_module) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index b5bfee512f..6c33e3556f 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -490,9 +490,10 @@ def deconstruct_function(fxn:Callable) -> Tuple: new_globals = {k:v for k,v in fxn.__globals__.items() if k in fxn.__code__.co_names} for co in fxn.__code__.co_consts: if isinstance(co, types.CodeType): new_globals.update({k:v for k,v in fxn.__globals__.items() if k in co.co_names}) - new_code_obj = pickle.loads(pickle.dumps(fxn.__code__)) if getenv("TEST_PICKLE") else fxn.__code__ # NOTE: optional round trip through pickle! + # NOTE: optional round trip through pickle! assert fxn.__closure__ is None, "closures are not supported in pattern matchers" - return new_code_obj, new_globals, fxn.__name__, fxn.__defaults__ + ret = fxn.__code__, new_globals, fxn.__name__, fxn.__defaults__ + return pickle.loads(pickle.dumps(ret)) if getenv("TEST_PICKLE") else ret class PatternMatcher: def __init__(self, patterns:List[Tuple[UPat, Callable]]):