Implement multithreading support.

This commit is contained in:
Quantum 2017-08-13 18:12:00 -04:00
parent e63764c72e
commit 7d0a41a2cd
4 changed files with 128 additions and 61 deletions

View file

@ -1,10 +1,72 @@
callbacks = [] import logging
from functools import wraps
from optimize_later.utils import NoArgDecoratorMeta
try:
import threading
except ImportError:
import dummy_threading as threading
log = logging.getLogger(__name__.rpartition('.')[0] or __name__)
_global_callbacks = []
_local = threading.local()
def get_callbacks():
try:
return _local.callbacks
except AttributeError:
return _global_callbacks
def register_callback(callback): def register_callback(callback):
callbacks.append(callback) get_callbacks().append(callback)
return callback return callback
def deregister_callback(callback): def deregister_callback(callback):
callbacks.remove(callback) get_callbacks().remove(callback)
def global_callback(report):
for callback in get_callbacks():
try:
callback(report)
except Exception:
log.exception('Failed to invoke global callback: %r', callback)
class optimize_context(object):
__metaclass__ = NoArgDecoratorMeta
def __init__(self, callbacks=None):
self.callbacks = callbacks
def __enter__(self):
try:
self.old_context = _local.callbacks
except AttributeError:
self.old_context = None
if self.callbacks is None:
if self.old_context is None:
_local.callbacks = _global_callbacks[:]
else:
_local.callbacks = self.old_context[:]
else:
_local.callbacks = self.callbacks[:]
def __exit__(self, exc_type, exc_val, exc_tb):
if self.old_context is None:
del _local.callbacks
else:
_local.callbacks = self.old_context
def __call__(self, function):
@wraps(function)
def wrapper(*args, **kwargs):
with optimize_context(self.callbacks):
return function(*args, **kwargs)
return wrapper

View file

@ -5,22 +5,14 @@ import time
from copy import copy from copy import copy
from functools import wraps from functools import wraps
from numbers import Number from numbers import Number
from types import FunctionType
from optimize_later import config from optimize_later.config import global_callback
from optimize_later.utils import NoArgDecoratorMeta
log = logging.getLogger(__name__.rpartition('.')[0] or __name__) log = logging.getLogger(__name__.rpartition('.')[0] or __name__)
timer = [time.time, time.clock][os.name == 'nt'] timer = [time.time, time.clock][os.name == 'nt']
def global_callback(report):
for callback in config.callbacks:
try:
callback(report)
except Exception:
log.exception('Failed to invoke global callback: %r', callback)
def _generate_default_name(): def _generate_default_name():
for entry in inspect.stack(): for entry in inspect.stack():
file, line = entry[1:3] file, line = entry[1:3]
@ -97,13 +89,6 @@ class OptimizeReport(object):
return self.short() return self.short()
class NoArgDecoratorMeta(type):
def __call__(cls, *args, **kwargs):
if len(args) == 1 and isinstance(args[0], FunctionType):
return cls()(args[0])
return super(NoArgDecoratorMeta, cls).__call__(*args, **kwargs)
class optimize_later(object): class optimize_later(object):
__metaclass__ = NoArgDecoratorMeta __metaclass__ = NoArgDecoratorMeta
@ -158,27 +143,3 @@ class optimize_later(object):
with copy(self): with copy(self):
return function(*args, **kwargs) return function(*args, **kwargs)
return wrapped return wrapped
class optimize_context(object):
__metaclass__ = NoArgDecoratorMeta
def __init__(self, callbacks=None):
self.callbacks = callbacks
def __enter__(self):
self.old_context = config.callbacks[:]
if self.callbacks is None:
config.callbacks[:] = self.old_context
else:
config.callbacks[:] = self.callbacks
def __exit__(self, exc_type, exc_val, exc_tb):
config.callbacks[:] = self.old_context
def __call__(self, function):
@wraps(function)
def wrapper(*args, **kwargs):
with optimize_context(self.callbacks):
return function(*args, **kwargs)
return wrapper

View file

@ -2,7 +2,59 @@ import time
from unittest import TestCase from unittest import TestCase
from optimize_later import config from optimize_later import config
from optimize_later.core import optimize_later, OptimizeReport, OptimizeBlock, optimize_context from optimize_later.core import optimize_later, OptimizeReport, OptimizeBlock
from optimize_later.config import optimize_context
class OptimizeContextTest(TestCase):
def test_optimize_context(self):
old_global, config._global_callbacks = config._global_callbacks, []
config.register_callback(1)
with optimize_context():
self.assertEqual(config.get_callbacks(), [1])
config.register_callback(2)
self.assertEqual(config.get_callbacks(), [1, 2])
with optimize_context([]):
self.assertEqual(config.get_callbacks(), [])
config.register_callback(3)
self.assertEqual(config.get_callbacks(), [3])
config.register_callback(4)
self.assertEqual(config.get_callbacks(), [1, 2, 4])
config.deregister_callback(2)
self.assertEqual(config.get_callbacks(), [1, 4])
config.deregister_callback(1)
self.assertEqual(config.get_callbacks(), [4])
self.assertEqual(config.get_callbacks(), [1])
config._global_callbacks = old_global
@optimize_context
def test_optimize_context_thread(self):
try:
import threading
except ImportError:
return
test = lambda report: None
seen = [None]
config.register_callback(test)
def thread_proc():
seen[0] = config.get_callbacks()
thread = threading.Thread(target=thread_proc)
thread.start()
thread.join()
self.assertIsInstance(seen[0], list)
self.assertNotIn(test, seen[0])
self.assertIn(test, config.get_callbacks())
class OptimizeLaterTest(TestCase): class OptimizeLaterTest(TestCase):
@ -130,19 +182,3 @@ class OptimizeLaterTest(TestCase):
self.assertEqual(len(reports), 10) self.assertEqual(len(reports), 10)
for report in reports: for report in reports:
self.assertReport(report) self.assertReport(report)
def test_optimize_context(self):
config.register_callback(1)
with optimize_context():
self.assertEqual(config.callbacks, [1])
config.register_callback(2)
self.assertEqual(config.callbacks, [1, 2])
with optimize_context([]):
self.assertEqual(config.callbacks, [])
config.register_callback(3)
self.assertEqual(config.callbacks, [3])
config.register_callback(4)
self.assertEqual(config.callbacks, [1, 2, 4])
self.assertEqual(config.callbacks, [1])

8
optimize_later/utils.py Normal file
View file

@ -0,0 +1,8 @@
from types import FunctionType
class NoArgDecoratorMeta(type):
def __call__(cls, *args, **kwargs):
if len(args) == 1 and isinstance(args[0], FunctionType):
return cls()(args[0])
return super(NoArgDecoratorMeta, cls).__call__(*args, **kwargs)