More sensible optimize_context

This commit is contained in:
Quantum 2017-08-13 20:21:17 -04:00
parent 97cea9c7e5
commit 821fcf9414
3 changed files with 28 additions and 12 deletions

View file

@ -92,13 +92,20 @@ def function():
register_callback(my_report_function) register_callback(my_report_function)
# Remove global callbacks for this block. # Remove global callbacks for this block.
with optimize_context([]): with optimize_context(renew=True):
pass pass
# or... # or...
@optimize_context([]) @optimize_context(renew=True)
def function():
pass
# Shortcut registration syntax.
with optimize_context(my_report_function):
pass
@optimize_context(my_report_function, renew=True)
def function(): def function():
pass pass
# Of course, you can specify a list of callbacks to enable exclusively as well.
``` ```
A sample short report: A sample short report:

View file

@ -39,8 +39,9 @@ def global_callback(report):
class optimize_context(with_metaclass(NoArgDecoratorMeta)): class optimize_context(with_metaclass(NoArgDecoratorMeta)):
def __init__(self, callbacks=None): def __init__(self, callbacks=None, reset=False):
self.callbacks = callbacks self.callbacks = callbacks or []
self.reset = reset
def __enter__(self): def __enter__(self):
try: try:
@ -48,13 +49,13 @@ class optimize_context(with_metaclass(NoArgDecoratorMeta)):
except AttributeError: except AttributeError:
self.old_context = None self.old_context = None
if self.callbacks is None: if self.reset:
if self.old_context is None: base_context = []
_local.callbacks = _global_callbacks[:] elif self.old_context is None:
base_context = _global_callbacks
else: else:
_local.callbacks = self.old_context[:] base_context = self.old_context
else: _local.callbacks = base_context + self.callbacks
_local.callbacks = self.callbacks[:]
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
if self.old_context is None: if self.old_context is None:

View file

@ -16,11 +16,19 @@ class OptimizeContextTest(TestCase):
config.register_callback(2) config.register_callback(2)
self.assertEqual(config.get_callbacks(), [1, 2]) self.assertEqual(config.get_callbacks(), [1, 2])
with optimize_context([]): with optimize_context(reset=True):
self.assertEqual(config.get_callbacks(), []) self.assertEqual(config.get_callbacks(), [])
config.register_callback(3) config.register_callback(3)
self.assertEqual(config.get_callbacks(), [3]) self.assertEqual(config.get_callbacks(), [3])
with optimize_context([3], reset=True):
self.assertEqual(config.get_callbacks(), [3])
with optimize_context([3]):
self.assertEqual(config.get_callbacks(), [1, 2, 3])
self.assertEqual(config.get_callbacks(), [1, 2])
config.register_callback(4) config.register_callback(4)
self.assertEqual(config.get_callbacks(), [1, 2, 4]) self.assertEqual(config.get_callbacks(), [1, 2, 4])