mirror of
https://github.com/quantum5/optimize-later.git
synced 2025-04-24 12:32:04 -04:00
Implement multithreading support.
This commit is contained in:
parent
e63764c72e
commit
7d0a41a2cd
|
@ -1,10 +1,72 @@
|
|||
callbacks = []
|
||||
import logging
|
||||
from functools import wraps
|
||||
|
||||
from optimize_later.utils import NoArgDecoratorMeta
|
||||
|
||||
try:
|
||||
import threading
|
||||
except ImportError:
|
||||
import dummy_threading as threading
|
||||
|
||||
log = logging.getLogger(__name__.rpartition('.')[0] or __name__)
|
||||
|
||||
_global_callbacks = []
|
||||
_local = threading.local()
|
||||
|
||||
|
||||
def get_callbacks():
|
||||
try:
|
||||
return _local.callbacks
|
||||
except AttributeError:
|
||||
return _global_callbacks
|
||||
|
||||
|
||||
def register_callback(callback):
|
||||
callbacks.append(callback)
|
||||
get_callbacks().append(callback)
|
||||
return callback
|
||||
|
||||
|
||||
def deregister_callback(callback):
|
||||
callbacks.remove(callback)
|
||||
get_callbacks().remove(callback)
|
||||
|
||||
|
||||
def global_callback(report):
|
||||
for callback in get_callbacks():
|
||||
try:
|
||||
callback(report)
|
||||
except Exception:
|
||||
log.exception('Failed to invoke global callback: %r', callback)
|
||||
|
||||
|
||||
class optimize_context(object):
|
||||
__metaclass__ = NoArgDecoratorMeta
|
||||
|
||||
def __init__(self, callbacks=None):
|
||||
self.callbacks = callbacks
|
||||
|
||||
def __enter__(self):
|
||||
try:
|
||||
self.old_context = _local.callbacks
|
||||
except AttributeError:
|
||||
self.old_context = None
|
||||
|
||||
if self.callbacks is None:
|
||||
if self.old_context is None:
|
||||
_local.callbacks = _global_callbacks[:]
|
||||
else:
|
||||
_local.callbacks = self.old_context[:]
|
||||
else:
|
||||
_local.callbacks = self.callbacks[:]
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.old_context is None:
|
||||
del _local.callbacks
|
||||
else:
|
||||
_local.callbacks = self.old_context
|
||||
|
||||
def __call__(self, function):
|
||||
@wraps(function)
|
||||
def wrapper(*args, **kwargs):
|
||||
with optimize_context(self.callbacks):
|
||||
return function(*args, **kwargs)
|
||||
return wrapper
|
||||
|
|
|
@ -5,22 +5,14 @@ import time
|
|||
from copy import copy
|
||||
from functools import wraps
|
||||
from numbers import Number
|
||||
from types import FunctionType
|
||||
|
||||
from optimize_later import config
|
||||
from optimize_later.config import global_callback
|
||||
from optimize_later.utils import NoArgDecoratorMeta
|
||||
|
||||
log = logging.getLogger(__name__.rpartition('.')[0] or __name__)
|
||||
timer = [time.time, time.clock][os.name == 'nt']
|
||||
|
||||
|
||||
def global_callback(report):
|
||||
for callback in config.callbacks:
|
||||
try:
|
||||
callback(report)
|
||||
except Exception:
|
||||
log.exception('Failed to invoke global callback: %r', callback)
|
||||
|
||||
|
||||
def _generate_default_name():
|
||||
for entry in inspect.stack():
|
||||
file, line = entry[1:3]
|
||||
|
@ -97,13 +89,6 @@ class OptimizeReport(object):
|
|||
return self.short()
|
||||
|
||||
|
||||
class NoArgDecoratorMeta(type):
|
||||
def __call__(cls, *args, **kwargs):
|
||||
if len(args) == 1 and isinstance(args[0], FunctionType):
|
||||
return cls()(args[0])
|
||||
return super(NoArgDecoratorMeta, cls).__call__(*args, **kwargs)
|
||||
|
||||
|
||||
class optimize_later(object):
|
||||
__metaclass__ = NoArgDecoratorMeta
|
||||
|
||||
|
@ -158,27 +143,3 @@ class optimize_later(object):
|
|||
with copy(self):
|
||||
return function(*args, **kwargs)
|
||||
return wrapped
|
||||
|
||||
|
||||
class optimize_context(object):
|
||||
__metaclass__ = NoArgDecoratorMeta
|
||||
|
||||
def __init__(self, callbacks=None):
|
||||
self.callbacks = callbacks
|
||||
|
||||
def __enter__(self):
|
||||
self.old_context = config.callbacks[:]
|
||||
if self.callbacks is None:
|
||||
config.callbacks[:] = self.old_context
|
||||
else:
|
||||
config.callbacks[:] = self.callbacks
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
config.callbacks[:] = self.old_context
|
||||
|
||||
def __call__(self, function):
|
||||
@wraps(function)
|
||||
def wrapper(*args, **kwargs):
|
||||
with optimize_context(self.callbacks):
|
||||
return function(*args, **kwargs)
|
||||
return wrapper
|
||||
|
|
|
@ -2,7 +2,59 @@ import time
|
|||
from unittest import TestCase
|
||||
|
||||
from optimize_later import config
|
||||
from optimize_later.core import optimize_later, OptimizeReport, OptimizeBlock, optimize_context
|
||||
from optimize_later.core import optimize_later, OptimizeReport, OptimizeBlock
|
||||
from optimize_later.config import optimize_context
|
||||
|
||||
|
||||
class OptimizeContextTest(TestCase):
|
||||
def test_optimize_context(self):
|
||||
old_global, config._global_callbacks = config._global_callbacks, []
|
||||
|
||||
config.register_callback(1)
|
||||
with optimize_context():
|
||||
self.assertEqual(config.get_callbacks(), [1])
|
||||
config.register_callback(2)
|
||||
self.assertEqual(config.get_callbacks(), [1, 2])
|
||||
|
||||
with optimize_context([]):
|
||||
self.assertEqual(config.get_callbacks(), [])
|
||||
config.register_callback(3)
|
||||
self.assertEqual(config.get_callbacks(), [3])
|
||||
|
||||
config.register_callback(4)
|
||||
self.assertEqual(config.get_callbacks(), [1, 2, 4])
|
||||
|
||||
config.deregister_callback(2)
|
||||
self.assertEqual(config.get_callbacks(), [1, 4])
|
||||
|
||||
config.deregister_callback(1)
|
||||
self.assertEqual(config.get_callbacks(), [4])
|
||||
|
||||
self.assertEqual(config.get_callbacks(), [1])
|
||||
|
||||
config._global_callbacks = old_global
|
||||
|
||||
@optimize_context
|
||||
def test_optimize_context_thread(self):
|
||||
try:
|
||||
import threading
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
test = lambda report: None
|
||||
seen = [None]
|
||||
config.register_callback(test)
|
||||
|
||||
def thread_proc():
|
||||
seen[0] = config.get_callbacks()
|
||||
|
||||
thread = threading.Thread(target=thread_proc)
|
||||
thread.start()
|
||||
thread.join()
|
||||
|
||||
self.assertIsInstance(seen[0], list)
|
||||
self.assertNotIn(test, seen[0])
|
||||
self.assertIn(test, config.get_callbacks())
|
||||
|
||||
|
||||
class OptimizeLaterTest(TestCase):
|
||||
|
@ -130,19 +182,3 @@ class OptimizeLaterTest(TestCase):
|
|||
self.assertEqual(len(reports), 10)
|
||||
for report in reports:
|
||||
self.assertReport(report)
|
||||
|
||||
def test_optimize_context(self):
|
||||
config.register_callback(1)
|
||||
with optimize_context():
|
||||
self.assertEqual(config.callbacks, [1])
|
||||
config.register_callback(2)
|
||||
self.assertEqual(config.callbacks, [1, 2])
|
||||
|
||||
with optimize_context([]):
|
||||
self.assertEqual(config.callbacks, [])
|
||||
config.register_callback(3)
|
||||
self.assertEqual(config.callbacks, [3])
|
||||
|
||||
config.register_callback(4)
|
||||
self.assertEqual(config.callbacks, [1, 2, 4])
|
||||
self.assertEqual(config.callbacks, [1])
|
||||
|
|
8
optimize_later/utils.py
Normal file
8
optimize_later/utils.py
Normal file
|
@ -0,0 +1,8 @@
|
|||
from types import FunctionType
|
||||
|
||||
|
||||
class NoArgDecoratorMeta(type):
|
||||
def __call__(cls, *args, **kwargs):
|
||||
if len(args) == 1 and isinstance(args[0], FunctionType):
|
||||
return cls()(args[0])
|
||||
return super(NoArgDecoratorMeta, cls).__call__(*args, **kwargs)
|
Loading…
Reference in a new issue