From 7325bc914ffe72f042dd6b7a5fdcae87800a5851 Mon Sep 17 00:00:00 2001 From: Alex Telon Date: Fri, 4 Aug 2023 16:53:48 +0200 Subject: [PATCH] fix: Context (#1430) * Fixed issue in Context * Cleaned up fix Now that DEBUG.value = 3 always works we can do so in __new__ as well. --- test/test_helpers.py | 7 +++++++ tinygrad/helpers.py | 9 +++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/test/test_helpers.py b/test/test_helpers.py index 62822f1555..deacd97a6a 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -99,5 +99,12 @@ with Context(VARIABLE=1): test() self.assertEqual(VARIABLE.value, 0) + def test_context_exit_reverts_updated_values(self): + D = ContextVar("D", 1) + D.value = 2 + with Context(D=3): + ... + assert D.value == 2, f"Expected D to be 2, but was {D.value}. Indicates that Context.__exit__ did not restore to the correct value." + if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index ffca0eee16..3cf9d885be 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -30,10 +30,11 @@ class Context(contextlib.ContextDecorator): stack: ClassVar[List[dict[str, int]]] = [{}] def __init__(self, **kwargs): self.kwargs = kwargs def __enter__(self): - for k,v in self.kwargs.items(): ContextVar._cache[k].value = v - Context.stack.append(self.kwargs) + Context.stack[-1] = {k:o.value for k,o in ContextVar._cache.items()} # Store current state. + for k,v in self.kwargs.items(): ContextVar._cache[k].value = v # Update to new temporary state. + Context.stack.append(self.kwargs) # Store the temporary state so we know what to undo later. def __exit__(self, *args): - for k in Context.stack.pop(): ContextVar._cache[k].value = Context.stack[-1].get(k, Context.stack[0][k]) + for k in Context.stack.pop(): ContextVar._cache[k].value = Context.stack[-1].get(k, ContextVar._cache[k].value) class ContextVar: _cache: ClassVar[Dict[str, ContextVar]] = {} @@ -42,7 +43,7 @@ class ContextVar: def __new__(cls, key, default_value): if key in ContextVar._cache: return ContextVar._cache[key] instance = ContextVar._cache[key] = super().__new__(cls) - instance.value = Context.stack[0][key] = getenv(key, default_value) + instance.value = getenv(key, default_value) return instance def __bool__(self): return bool(self.value) def __ge__(self, x): return self.value >= x