fix: Context (#1430)

* Fixed issue in Context

* Cleaned up fix

Now that DEBUG.value = 3 always works we can do so in __new__ as well.
This commit is contained in:
Alex Telon
2023-08-04 16:53:48 +02:00
committed by GitHub
parent c08ed1949f
commit 7325bc914f
2 changed files with 12 additions and 4 deletions

View File

@@ -99,5 +99,12 @@ with Context(VARIABLE=1):
test()
self.assertEqual(VARIABLE.value, 0)
def test_context_exit_reverts_updated_values(self):
D = ContextVar("D", 1)
D.value = 2
with Context(D=3):
...
assert D.value == 2, f"Expected D to be 2, but was {D.value}. Indicates that Context.__exit__ did not restore to the correct value."
if __name__ == '__main__':
unittest.main()

View File

@@ -30,10 +30,11 @@ class Context(contextlib.ContextDecorator):
stack: ClassVar[List[dict[str, int]]] = [{}]
def __init__(self, **kwargs): self.kwargs = kwargs
def __enter__(self):
for k,v in self.kwargs.items(): ContextVar._cache[k].value = v
Context.stack.append(self.kwargs)
Context.stack[-1] = {k:o.value for k,o in ContextVar._cache.items()} # Store current state.
for k,v in self.kwargs.items(): ContextVar._cache[k].value = v # Update to new temporary state.
Context.stack.append(self.kwargs) # Store the temporary state so we know what to undo later.
def __exit__(self, *args):
for k in Context.stack.pop(): ContextVar._cache[k].value = Context.stack[-1].get(k, Context.stack[0][k])
for k in Context.stack.pop(): ContextVar._cache[k].value = Context.stack[-1].get(k, ContextVar._cache[k].value)
class ContextVar:
_cache: ClassVar[Dict[str, ContextVar]] = {}
@@ -42,7 +43,7 @@ class ContextVar:
def __new__(cls, key, default_value):
if key in ContextVar._cache: return ContextVar._cache[key]
instance = ContextVar._cache[key] = super().__new__(cls)
instance.value = Context.stack[0][key] = getenv(key, default_value)
instance.value = getenv(key, default_value)
return instance
def __bool__(self): return bool(self.value)
def __ge__(self, x): return self.value >= x