Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better handling of callables #15

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 42 additions & 33 deletions proof/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,27 @@
propagated to all dependent analyses.
"""

from __future__ import print_function

import bz2
from copy import deepcopy
from glob import glob
import hashlib
import inspect
import logging
import os

try:
import cPickle as pickle
except ImportError: # pragma: no cover
except ImportError: # pragma: no cover
import pickle

import six


logger = logging.getLogger(__name__)


class Cache(object):
"""
Utility class for managing cached data.
Expand All @@ -43,9 +50,8 @@ def get(self):
Get cached data from memory or disk.
"""
if self._data is None:
f = bz2.BZ2File(self._cache_path)
self._data = pickle.loads(f.read())
f.close()
with bz2.BZ2File(self._cache_path) as f:
self._data = pickle.loads(f.read())

return deepcopy(self._data)

Expand All @@ -55,9 +61,9 @@ def set(self, data):
"""
self._data = data

f = bz2.BZ2File(self._cache_path, 'w')
f.write(pickle.dumps(self._data))
f.close()
with bz2.BZ2File(self._cache_path, 'w') as f:
f.write(pickle.dumps(self._data))


def never_cache(func):
"""
Expand All @@ -67,6 +73,7 @@ def never_cache(func):

return func


class Analysis(object):
"""
An Analysis is a function whose source code fingerprint and output can be
Expand All @@ -77,20 +84,22 @@ class Analysis(object):
If a parent analysis changes then it and all it's children will be
refreshed.

:param func: A callable that implements the analysis. Must accept a `data`
:param _callable: A callable that implements the analysis. Must accept a `data`
argument that is the state inherited from its ancestors analysis.
:param cache_dir: Where to stored the cache files for this analysis.
:param _trace: The ancestors this analysis, if any. For internal use
only.
"""
def __init__(self, func, cache_dir='.proof', _trace=[]):
self._name = func.__name__
self._func = func
def __init__(self, _callable, cache_dir='.proof', _trace=[]):
self._name = _callable.__name__
self._callable = _callable
self._cache_dir = cache_dir
self._trace = _trace + [self]
self._child_analyses = []

self._cache_path = os.path.join(self._cache_dir, '%s.cache' % self._fingerprint())
self._cache_path = os.path.join(
self._cache_dir, '%s.cache' % self._fingerprint()
)
self._cache = Cache(self._cache_path)

self._registered_cache_paths = []
Expand All @@ -102,15 +111,22 @@ def _fingerprint(self):
"""
hasher = hashlib.md5()

history = '\n'.join([analysis._name for analysis in self._trace])
history = '\n'.join(analysis._name for analysis in self._trace)

# In Python 3 function names can be non-ascii identifiers
if six.PY3:
history = history.encode('utf-8')

hasher.update(history)

source = inspect.getsource(self._func)
function_or_method = (
inspect.isfunction(self._callable), inspect.ismethod(self._callable)
)

if any(function_or_method):
source = inspect.getsource(self._callable)
else:
source = inspect.getsource(self._callable.__class__)

# In Python 3 inspect.getsource returns unicode data
if six.PY3:
Expand All @@ -129,20 +145,16 @@ def _cleanup_cache_files(self):
if path not in self._registered_cache_paths:
os.remove(path)

def then(self, child_func):
def then(self, child_callable):
"""
Create a new analysis which will run after this one has completed with
access to the data it generated.

:param func: A callable that implements the analysis. Must accept a
:param child_callable: A callable that implements the analysis. Must accept a
`data` argument that is the state inherited from its ancestors
analysis.
"""
analysis = Analysis(
child_func,
cache_dir=self._cache_dir,
_trace=self._trace
)
analysis = Analysis(child_callable, self._cache_dir, self._trace)

self._child_analyses.append(analysis)

Expand Down Expand Up @@ -174,31 +186,28 @@ def run(self, refresh=False, _parent_cache=None):
if not os.path.exists(self._cache_dir):
os.makedirs(self._cache_dir)

do_not_cache = getattr(self._func, 'never_cache', False)
do_not_cache = getattr(self._callable, 'never_cache', False)

if refresh:
logger.info('Refreshing: {}'.format(self._name))

if refresh is True:
print('Refreshing: %s' % self._name)
elif do_not_cache:
refresh = True

print('Never cached: %s' % self._name)
logger.info('Never cached: {}'.format(self._name))
elif not self._cache.check():
refresh = True

print('Stale cache: %s' % self._name)
logger.info('Stale cache: {}'.format(self._name))

if refresh:
if _parent_cache:
local_data = _parent_cache.get()
else:
local_data = {}
local_data = _parent_cache.get() if _parent_cache else {}

self._func(local_data)
self._callable(local_data)

if not do_not_cache:
self._cache.set(local_data)
else:
print('Deferring to cache: %s' % self._name)
logger.info('Deferring to cache: {}'.format(self._name))

for analysis in self._child_analyses:
analysis.run(refresh=refresh, _parent_cache=self._cache)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_proof.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python
# -*- coding: utf8 -*-
# -*- coding: utf-8 -*-

from copy import deepcopy
from glob import glob
Expand Down Expand Up @@ -30,7 +30,7 @@ def setUp(self):
self.executed_stage_never_cache = 0

def tearDown(self):
shutil.rmtree(TEST_CACHE)
shutil.rmtree(TEST_CACHE, ignore_errors=True)

def stage1(self, data):
self.executed_stage1 += 1
Expand Down