mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
fix: Context (#1430)
* Fixed issue in Context * Cleaned up fix Now that DEBUG.value = 3 always works we can do so in __new__ as well.
This commit is contained in:
@@ -99,5 +99,12 @@ with Context(VARIABLE=1):
|
||||
test()
|
||||
self.assertEqual(VARIABLE.value, 0)
|
||||
|
||||
def test_context_exit_reverts_updated_values(self):
|
||||
D = ContextVar("D", 1)
|
||||
D.value = 2
|
||||
with Context(D=3):
|
||||
...
|
||||
assert D.value == 2, f"Expected D to be 2, but was {D.value}. Indicates that Context.__exit__ did not restore to the correct value."
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -30,10 +30,11 @@ class Context(contextlib.ContextDecorator):
|
||||
stack: ClassVar[List[dict[str, int]]] = [{}]
|
||||
def __init__(self, **kwargs): self.kwargs = kwargs
|
||||
def __enter__(self):
|
||||
for k,v in self.kwargs.items(): ContextVar._cache[k].value = v
|
||||
Context.stack.append(self.kwargs)
|
||||
Context.stack[-1] = {k:o.value for k,o in ContextVar._cache.items()} # Store current state.
|
||||
for k,v in self.kwargs.items(): ContextVar._cache[k].value = v # Update to new temporary state.
|
||||
Context.stack.append(self.kwargs) # Store the temporary state so we know what to undo later.
|
||||
def __exit__(self, *args):
|
||||
for k in Context.stack.pop(): ContextVar._cache[k].value = Context.stack[-1].get(k, Context.stack[0][k])
|
||||
for k in Context.stack.pop(): ContextVar._cache[k].value = Context.stack[-1].get(k, ContextVar._cache[k].value)
|
||||
|
||||
class ContextVar:
|
||||
_cache: ClassVar[Dict[str, ContextVar]] = {}
|
||||
@@ -42,7 +43,7 @@ class ContextVar:
|
||||
def __new__(cls, key, default_value):
|
||||
if key in ContextVar._cache: return ContextVar._cache[key]
|
||||
instance = ContextVar._cache[key] = super().__new__(cls)
|
||||
instance.value = Context.stack[0][key] = getenv(key, default_value)
|
||||
instance.value = getenv(key, default_value)
|
||||
return instance
|
||||
def __bool__(self): return bool(self.value)
|
||||
def __ge__(self, x): return self.value >= x
|
||||
|
||||
Reference in New Issue
Block a user