diff --git a/proof/analysis.py b/proof/analysis.py index f4ce294..e1db426 100644 --- a/proof/analysis.py +++ b/proof/analysis.py @@ -10,20 +10,27 @@ propagated to all dependent analyses. """ +from __future__ import print_function + import bz2 from copy import deepcopy from glob import glob import hashlib import inspect +import logging import os try: import cPickle as pickle -except ImportError: # pragma: no cover +except ImportError: # pragma: no cover import pickle import six + +logger = logging.getLogger(__name__) + + class Cache(object): """ Utility class for managing cached data. @@ -43,9 +50,8 @@ def get(self): Get cached data from memory or disk. """ if self._data is None: - f = bz2.BZ2File(self._cache_path) - self._data = pickle.loads(f.read()) - f.close() + with bz2.BZ2File(self._cache_path) as f: + self._data = pickle.loads(f.read()) return deepcopy(self._data) @@ -55,9 +61,9 @@ def set(self, data): """ self._data = data - f = bz2.BZ2File(self._cache_path, 'w') - f.write(pickle.dumps(self._data)) - f.close() + with bz2.BZ2File(self._cache_path, 'w') as f: + f.write(pickle.dumps(self._data)) + def never_cache(func): """ @@ -67,6 +73,7 @@ def never_cache(func): return func + class Analysis(object): """ An Analysis is a function whose source code fingerprint and output can be @@ -77,20 +84,22 @@ class Analysis(object): If a parent analysis changes then it and all it's children will be refreshed. - :param func: A callable that implements the analysis. Must accept a `data` + :param _callable: A callable that implements the analysis. Must accept a `data` argument that is the state inherited from its ancestors analysis. :param cache_dir: Where to stored the cache files for this analysis. :param _trace: The ancestors this analysis, if any. For internal use only. """ - def __init__(self, func, cache_dir='.proof', _trace=[]): - self._name = func.__name__ - self._func = func + def __init__(self, _callable, cache_dir='.proof', _trace=[]): + self._name = _callable.__name__ + self._callable = _callable self._cache_dir = cache_dir self._trace = _trace + [self] self._child_analyses = [] - self._cache_path = os.path.join(self._cache_dir, '%s.cache' % self._fingerprint()) + self._cache_path = os.path.join( + self._cache_dir, '%s.cache' % self._fingerprint() + ) self._cache = Cache(self._cache_path) self._registered_cache_paths = [] @@ -102,7 +111,7 @@ def _fingerprint(self): """ hasher = hashlib.md5() - history = '\n'.join([analysis._name for analysis in self._trace]) + history = '\n'.join(analysis._name for analysis in self._trace) # In Python 3 function names can be non-ascii identifiers if six.PY3: @@ -110,7 +119,14 @@ def _fingerprint(self): hasher.update(history) - source = inspect.getsource(self._func) + function_or_method = ( + inspect.isfunction(self._callable), inspect.ismethod(self._callable) + ) + + if any(function_or_method): + source = inspect.getsource(self._callable) + else: + source = inspect.getsource(self._callable.__class__) # In Python 3 inspect.getsource returns unicode data if six.PY3: @@ -129,20 +145,16 @@ def _cleanup_cache_files(self): if path not in self._registered_cache_paths: os.remove(path) - def then(self, child_func): + def then(self, child_callable): """ Create a new analysis which will run after this one has completed with access to the data it generated. - :param func: A callable that implements the analysis. Must accept a + :param child_callable: A callable that implements the analysis. Must accept a `data` argument that is the state inherited from its ancestors analysis. """ - analysis = Analysis( - child_func, - cache_dir=self._cache_dir, - _trace=self._trace - ) + analysis = Analysis(child_callable, self._cache_dir, self._trace) self._child_analyses.append(analysis) @@ -174,31 +186,28 @@ def run(self, refresh=False, _parent_cache=None): if not os.path.exists(self._cache_dir): os.makedirs(self._cache_dir) - do_not_cache = getattr(self._func, 'never_cache', False) + do_not_cache = getattr(self._callable, 'never_cache', False) + + if refresh: + logger.info('Refreshing: {}'.format(self._name)) - if refresh is True: - print('Refreshing: %s' % self._name) elif do_not_cache: refresh = True - - print('Never cached: %s' % self._name) + logger.info('Never cached: {}'.format(self._name)) elif not self._cache.check(): refresh = True - print('Stale cache: %s' % self._name) + logger.info('Stale cache: {}'.format(self._name)) if refresh: - if _parent_cache: - local_data = _parent_cache.get() - else: - local_data = {} + local_data = _parent_cache.get() if _parent_cache else {} - self._func(local_data) + self._callable(local_data) if not do_not_cache: self._cache.set(local_data) else: - print('Deferring to cache: %s' % self._name) + logger.info('Deferring to cache: {}'.format(self._name)) for analysis in self._child_analyses: analysis.run(refresh=refresh, _parent_cache=self._cache) diff --git a/tests/test_proof.py b/tests/test_proof.py index acf6619..5dff43d 100644 --- a/tests/test_proof.py +++ b/tests/test_proof.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -# -*- coding: utf8 -*- +# -*- coding: utf-8 -*- from copy import deepcopy from glob import glob @@ -30,7 +30,7 @@ def setUp(self): self.executed_stage_never_cache = 0 def tearDown(self): - shutil.rmtree(TEST_CACHE) + shutil.rmtree(TEST_CACHE, ignore_errors=True) def stage1(self, data): self.executed_stage1 += 1