More sensible optimize_context

This commit is contained in:
Quantum 2017-08-13 20:21:17 -04:00
parent 97cea9c7e5
commit 821fcf9414
3 changed files with 28 additions and 12 deletions

View file

@ -92,13 +92,20 @@ def function():
register_callback(my_report_function)
# Remove global callbacks for this block.
with optimize_context([]):
with optimize_context(renew=True):
pass
# or...
@optimize_context([])
@optimize_context(renew=True)
def function():
pass
# Shortcut registration syntax.
with optimize_context(my_report_function):
pass
@optimize_context(my_report_function, renew=True)
def function():
pass
# Of course, you can specify a list of callbacks to enable exclusively as well.
```
A sample short report:

View file

@ -39,8 +39,9 @@ def global_callback(report):
class optimize_context(with_metaclass(NoArgDecoratorMeta)):
def __init__(self, callbacks=None):
self.callbacks = callbacks
def __init__(self, callbacks=None, reset=False):
self.callbacks = callbacks or []
self.reset = reset
def __enter__(self):
try:
@ -48,13 +49,13 @@ class optimize_context(with_metaclass(NoArgDecoratorMeta)):
except AttributeError:
self.old_context = None
if self.callbacks is None:
if self.old_context is None:
_local.callbacks = _global_callbacks[:]
else:
_local.callbacks = self.old_context[:]
if self.reset:
base_context = []
elif self.old_context is None:
base_context = _global_callbacks
else:
_local.callbacks = self.callbacks[:]
base_context = self.old_context
_local.callbacks = base_context + self.callbacks
def __exit__(self, exc_type, exc_val, exc_tb):
if self.old_context is None:

View file

@ -16,11 +16,19 @@ class OptimizeContextTest(TestCase):
config.register_callback(2)
self.assertEqual(config.get_callbacks(), [1, 2])
with optimize_context([]):
with optimize_context(reset=True):
self.assertEqual(config.get_callbacks(), [])
config.register_callback(3)
self.assertEqual(config.get_callbacks(), [3])
with optimize_context([3], reset=True):
self.assertEqual(config.get_callbacks(), [3])
with optimize_context([3]):
self.assertEqual(config.get_callbacks(), [1, 2, 3])
self.assertEqual(config.get_callbacks(), [1, 2])
config.register_callback(4)
self.assertEqual(config.get_callbacks(), [1, 2, 4])