Implement multithreading support.

This commit is contained in:
Quantum 2017-08-13 18:12:00 -04:00
parent e63764c72e
commit 7d0a41a2cd
4 changed files with 128 additions and 61 deletions

View file

@ -1,10 +1,72 @@
callbacks = []
import logging
from functools import wraps
from optimize_later.utils import NoArgDecoratorMeta
try:
import threading
except ImportError:
import dummy_threading as threading
log = logging.getLogger(__name__.rpartition('.')[0] or __name__)
_global_callbacks = []
_local = threading.local()
def get_callbacks():
try:
return _local.callbacks
except AttributeError:
return _global_callbacks
def register_callback(callback):
callbacks.append(callback)
get_callbacks().append(callback)
return callback
def deregister_callback(callback):
callbacks.remove(callback)
get_callbacks().remove(callback)
def global_callback(report):
for callback in get_callbacks():
try:
callback(report)
except Exception:
log.exception('Failed to invoke global callback: %r', callback)
class optimize_context(object):
__metaclass__ = NoArgDecoratorMeta
def __init__(self, callbacks=None):
self.callbacks = callbacks
def __enter__(self):
try:
self.old_context = _local.callbacks
except AttributeError:
self.old_context = None
if self.callbacks is None:
if self.old_context is None:
_local.callbacks = _global_callbacks[:]
else:
_local.callbacks = self.old_context[:]
else:
_local.callbacks = self.callbacks[:]
def __exit__(self, exc_type, exc_val, exc_tb):
if self.old_context is None:
del _local.callbacks
else:
_local.callbacks = self.old_context
def __call__(self, function):
@wraps(function)
def wrapper(*args, **kwargs):
with optimize_context(self.callbacks):
return function(*args, **kwargs)
return wrapper

View file

@ -5,22 +5,14 @@ import time
from copy import copy
from functools import wraps
from numbers import Number
from types import FunctionType
from optimize_later import config
from optimize_later.config import global_callback
from optimize_later.utils import NoArgDecoratorMeta
log = logging.getLogger(__name__.rpartition('.')[0] or __name__)
timer = [time.time, time.clock][os.name == 'nt']
def global_callback(report):
for callback in config.callbacks:
try:
callback(report)
except Exception:
log.exception('Failed to invoke global callback: %r', callback)
def _generate_default_name():
for entry in inspect.stack():
file, line = entry[1:3]
@ -97,13 +89,6 @@ class OptimizeReport(object):
return self.short()
class NoArgDecoratorMeta(type):
def __call__(cls, *args, **kwargs):
if len(args) == 1 and isinstance(args[0], FunctionType):
return cls()(args[0])
return super(NoArgDecoratorMeta, cls).__call__(*args, **kwargs)
class optimize_later(object):
__metaclass__ = NoArgDecoratorMeta
@ -158,27 +143,3 @@ class optimize_later(object):
with copy(self):
return function(*args, **kwargs)
return wrapped
class optimize_context(object):
__metaclass__ = NoArgDecoratorMeta
def __init__(self, callbacks=None):
self.callbacks = callbacks
def __enter__(self):
self.old_context = config.callbacks[:]
if self.callbacks is None:
config.callbacks[:] = self.old_context
else:
config.callbacks[:] = self.callbacks
def __exit__(self, exc_type, exc_val, exc_tb):
config.callbacks[:] = self.old_context
def __call__(self, function):
@wraps(function)
def wrapper(*args, **kwargs):
with optimize_context(self.callbacks):
return function(*args, **kwargs)
return wrapper

View file

@ -2,7 +2,59 @@ import time
from unittest import TestCase
from optimize_later import config
from optimize_later.core import optimize_later, OptimizeReport, OptimizeBlock, optimize_context
from optimize_later.core import optimize_later, OptimizeReport, OptimizeBlock
from optimize_later.config import optimize_context
class OptimizeContextTest(TestCase):
def test_optimize_context(self):
old_global, config._global_callbacks = config._global_callbacks, []
config.register_callback(1)
with optimize_context():
self.assertEqual(config.get_callbacks(), [1])
config.register_callback(2)
self.assertEqual(config.get_callbacks(), [1, 2])
with optimize_context([]):
self.assertEqual(config.get_callbacks(), [])
config.register_callback(3)
self.assertEqual(config.get_callbacks(), [3])
config.register_callback(4)
self.assertEqual(config.get_callbacks(), [1, 2, 4])
config.deregister_callback(2)
self.assertEqual(config.get_callbacks(), [1, 4])
config.deregister_callback(1)
self.assertEqual(config.get_callbacks(), [4])
self.assertEqual(config.get_callbacks(), [1])
config._global_callbacks = old_global
@optimize_context
def test_optimize_context_thread(self):
try:
import threading
except ImportError:
return
test = lambda report: None
seen = [None]
config.register_callback(test)
def thread_proc():
seen[0] = config.get_callbacks()
thread = threading.Thread(target=thread_proc)
thread.start()
thread.join()
self.assertIsInstance(seen[0], list)
self.assertNotIn(test, seen[0])
self.assertIn(test, config.get_callbacks())
class OptimizeLaterTest(TestCase):
@ -130,19 +182,3 @@ class OptimizeLaterTest(TestCase):
self.assertEqual(len(reports), 10)
for report in reports:
self.assertReport(report)
def test_optimize_context(self):
config.register_callback(1)
with optimize_context():
self.assertEqual(config.callbacks, [1])
config.register_callback(2)
self.assertEqual(config.callbacks, [1, 2])
with optimize_context([]):
self.assertEqual(config.callbacks, [])
config.register_callback(3)
self.assertEqual(config.callbacks, [3])
config.register_callback(4)
self.assertEqual(config.callbacks, [1, 2, 4])
self.assertEqual(config.callbacks, [1])

8
optimize_later/utils.py Normal file
View file

@ -0,0 +1,8 @@
from types import FunctionType
class NoArgDecoratorMeta(type):
def __call__(cls, *args, **kwargs):
if len(args) == 1 and isinstance(args[0], FunctionType):
return cls()(args[0])
return super(NoArgDecoratorMeta, cls).__call__(*args, **kwargs)