Extensible serializers support #326

Closed · wants to merge 17 commits
12 changes: 12 additions & 0 deletions docs/index.rst
@@ -350,6 +350,18 @@ The following configuration values exist for Flask-Caching:
``CACHE_DEFAULT_TIMEOUT``          The default timeout that is used if no
                                   timeout is specified. Unit of time is
                                   seconds.
``CACHE_SERIALIZER``               Pickle-like serialization implementation.
                                   It must support load(-s) and dump(-s)
                                   methods and operate on binary
                                   strings/files. May be an object, an
                                   import string, or a predefined
                                   implementation name (``"json"`` or
                                   ``"pickle"``). Defaults to ``"pickle"``,
                                   but the pickle module is not secure
                                   (CVE-2021-33026); consider using another
                                   serializer (e.g. JSON).
``CACHE_SERIALIZER_ERROR``         The exception to treat as a
                                   deserialization error. May be an object,
                                   an import string, or a predefined error
                                   name (``"JSONError"`` or
                                   ``"PickleError"``). Defaults to
                                   ``"PickleError"``.
``CACHE_IGNORE_ERRORS``            If set to ``True``, any errors that
                                   occurred during the deletion process will
                                   be ignored. However, if it is set to
                                   ``False`` it will stop on the
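For illustration, opting into the JSON serializer from application config might look like this (a sketch based on the option names above; the cache type and directory are placeholders):

from flask import Flask
from flask_caching import Cache

app = Flask(__name__)
app.config.update(
    CACHE_TYPE="FileSystemCache",
    CACHE_DIR="/tmp/flask-cache",
    CACHE_SERIALIZER="json",
    CACHE_SERIALIZER_ERROR="JSONError",
)
cache = Cache(app)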
54 changes: 52 additions & 2 deletions src/flask_caching/__init__.py
@@ -152,6 +152,33 @@ def make_template_fragment_key(fragment_name: str, vary_on: List[str] = None) ->
return TEMPLATE_FRAGMENT_KEY_TEMPLATE % (fragment_name, "_".join(vary_on))


def load_module(
    module: Union[str, Any],
    lookup_obj: Optional[Any] = None,
    return_back: bool = False,
) -> Any:
    """Dynamically load a module or object.

    :param module: Module name, dotted import string, or object
    :param lookup_obj: Object to look a plain ``module`` name up on
    :param return_back: Return ``module`` unchanged if it is not a string
    :returns: The loaded module or object
    :raises ImportError: When the module cannot be loaded
    """
if isinstance(module, str):
if "." in module:
return import_string(module)
elif lookup_obj is not None:
try:
return getattr(lookup_obj, module)
except AttributeError:
pass
elif return_back:
return module

raise ImportError("Could not load %s" % module)
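
For illustration, the three resolution paths of ``load_module`` might be exercised like this (a sketch; it assumes ``flask_caching.serialization`` exposes a ``json`` implementation, matching the predefined names in the docs):

from flask_caching import load_module, serialization

load_module("flask_caching.serialization.json")    # dotted string -> import_string
load_module("json", lookup_obj=serialization)      # plain name -> attribute lookup
load_module(serialization.json, return_back=True)  # object -> returned unchanged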


class Cache:
"""This class is used to control the cache objects."""

@@ -201,6 +228,8 @@ def init_app(self, app: Flask, config=None) -> None:
config.setdefault("CACHE_TYPE", "null")
config.setdefault("CACHE_NO_NULL_WARNING", False)
config.setdefault("CACHE_SOURCE_CHECK", False)
config.setdefault("CACHE_SERIALIZER", "pickle")
config.setdefault("CACHE_SERIALIZER_ERROR", "PickleError")

if config["CACHE_TYPE"] == "null" and not config["CACHE_NO_NULL_WARNING"]:
warnings.warn(
@@ -236,8 +265,23 @@ def _set_cache(self, app: Flask, config) -> None:
plain_name_used = False

cache_factory = import_string(import_me)

from . import serialization

cache_args = config["CACHE_ARGS"][:]
cache_options = {
"default_timeout": config["CACHE_DEFAULT_TIMEOUT"],
"serializer_impl": load_module(
config["CACHE_SERIALIZER"],
lookup_obj=serialization,
return_back=True
),
"serializer_error": load_module(
config["CACHE_SERIALIZER_ERROR"],
lookup_obj=serialization,
return_back=True
)
}

if isinstance(cache_factory, type) and issubclass(cache_factory, BaseCache):
cache_factory = cache_factory.factory
@@ -313,7 +357,7 @@ def unlink(self, *args, **kwargs) -> bool:

def cached(
self,
        timeout: Optional[int] = None,
key_prefix: str = "view/%s",
unless: Optional[Callable] = None,
forced_update: Optional[Callable] = None,
@@ -323,6 +367,7 @@
cache_none: bool = False,
make_cache_key: Optional[Callable] = None,
source_check: Optional[bool] = None,
force_tuple: bool = True,
) -> Callable:
"""Decorator. Use this to cache a function. By default the cache key
is `view/request.path`. You are able to use this decorator with any
@@ -423,6 +468,8 @@ def get_list():
formed with the function's source code hash in
addition to other parameters that may be included
in the formation of the key.
        :param force_tuple: Default ``True``. Cast a two-element list returned
                            by the cached view back to a tuple. JSON has no
                            tuple type, but Flask expects ``(body, status)``
                            tuples (see the sketch after this diff).
"""

def decorator(f):
@@ -471,6 +518,9 @@ def decorated_function(*args, **kwargs):
found = False
else:
found = self.cache.has(cache_key)
                        elif force_tuple and isinstance(rv, list) and len(rv) == 2:
                            # JSON round-trips tuples as lists; restore the
                            # (body, status) tuple that Flask expects
                            rv = tuple(rv)
except Exception:
if self.app.debug:
raise
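For illustration, ``force_tuple`` keeps a view that returns a ``(body, status)`` pair working under the JSON serializer, which round-trips tuples as lists (a sketch; ``app`` and ``cache`` are assumed to be configured as above):

@app.route("/status")
@cache.cached(timeout=60)  # force_tuple=True by default
def status():
    # Stored as ["OK", 200] by the JSON serializer; cast back to a
    # tuple on a cache hit so Flask sees the (body, status) pair.
    return "OK", 200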
56 changes: 54 additions & 2 deletions src/flask_caching/backends/base.py
@@ -11,6 +11,34 @@
"""
import warnings

from cachelib import BaseCache as CachelibBaseCache

from flask_caching.serialization import pickle, PickleError


def iteritems_wrapper(mappingorseq):
"""Wrapper for efficient iteration over mappings represented by dicts
or sequences::

    >>> for k, v in iteritems_wrapper((i, i*i) for i in range(5)):
    ...     assert k*k == v

    >>> for k, v in iteritems_wrapper(dict((i, i*i) for i in range(5))):
    ...     assert k*k == v

"""
if hasattr(mappingorseq, "items"):
return mappingorseq.items()
return mappingorseq


def extract_serializer_args(data):
    """Pop all ``serializer_``-prefixed keys out of ``data`` and return
    them as a separate dict."""
    result = {}
    serializer_prefix = "serializer_"
    for key in tuple(data.keys()):
        if key.startswith(serializer_prefix):
            result[key] = data.pop(key)
    return result
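
For illustration, ``extract_serializer_args`` splits serializer options out of a backend's keyword arguments in place (a sketch with hypothetical values):

options = {"threshold": 500, "serializer_impl": "json", "serializer_error": "JSONError"}
serializer_args = extract_serializer_args(options)
assert serializer_args == {"serializer_impl": "json", "serializer_error": "JSONError"}
assert options == {"threshold": 500}  # only backend-specific options remain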



class BaseCache(CachelibBaseCache):
"""Baseclass for the cache systems. All the cache systems implement this
@@ -19,12 +47,36 @@ class BaseCache(CachelibBaseCache):
:param default_timeout: The default timeout (in seconds) that is used if
no timeout is specified on :meth:`set`. A timeout
of 0 indicates that the cache never expires.
    :param serializer_impl: Pickle-like serialization implementation. It must
                            support load(-s) and dump(-s) methods and operate
                            on binary strings/files.
    :param serializer_error: The deserialization exception raised by the
                             given implementation.
"""

def __init__(
self,
default_timeout=300,
serializer_impl=pickle,
serializer_error=PickleError
):
CachelibBaseCache.__init__(
self,
default_timeout=default_timeout
)

self.default_timeout = default_timeout
self.ignore_errors = False

if serializer_impl is pickle:
            warnings.warn(
                "The pickle serializer is not secure and may "
                "lead to remote code execution. "
                "Consider using another serializer (e.g. JSON)."
            )
self._serializer = serializer_impl
self._serialization_error = serializer_error
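
Any object with binary load(-s)/dump(-s) methods satisfies this duck-typed interface; a minimal JSON-based sketch (illustrative, not part of this PR):

import json


class BinaryJson:
    """Pickle-like serializer: binary strings plus binary file objects."""

    @staticmethod
    def dumps(obj):
        return json.dumps(obj).encode("utf-8")

    @staticmethod
    def loads(data):
        return json.loads(data.decode("utf-8"))

    @staticmethod
    def dump(obj, fp):
        fp.write(BinaryJson.dumps(obj))

    @staticmethod
    def load(fp):
        return json.loads(fp.read().decode("utf-8"))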

@classmethod
def factory(cls, app, config, args, kwargs):
return cls()
107 changes: 94 additions & 13 deletions src/flask_caching/backends/filesystemcache.py
@@ -11,12 +11,12 @@
import hashlib
import logging
import os
import tempfile
from time import time

from cachelib import FileSystemCache as CachelibFileSystemCache

from flask_caching.backends.base import BaseCache, extract_serializer_args

logger = logging.getLogger(__name__)

@@ -53,9 +53,14 @@ def __init__(
mode=0o600,
hash_method=hashlib.md5,
ignore_errors=False,
**kwargs
):

BaseCache.__init__(self, default_timeout=default_timeout)
BaseCache.__init__(
self,
default_timeout=default_timeout,
**extract_serializer_args(kwargs)
)
CachelibFileSystemCache.__init__(
self,
cache_dir=cache_dir,
@@ -100,7 +105,7 @@ def _prune(self):
try:
remove = False
with open(fname, "rb") as f:
expires, _ = self._serializer.load(f)
remove = (expires != 0 and expires <= now) or idx % 3 == 0
if remove:
os.remove(fname)
@@ -117,37 +122,113 @@ def get(self, key):
filename = self._get_filename(key)
try:
with open(filename, "rb") as f:
data = self._serializer.load(f)
if isinstance(data, int):
# backward compatibility
# should be removed in the next major release
pickle_time = data
result = self._serializer.load(f)
else:
pickle_time, result = data
expired = pickle_time != 0 and pickle_time < time()
                if expired:
                    result = None
                    self.delete(key)
                else:
                    hit_or_miss = "hit"
except FileNotFoundError:
pass
        except Exception as exc:
            if isinstance(exc, (OSError, self._serialization_error)):
                logger.error("get key %r -> %s", key, exc)
            else:
                raise
expiredstr = "(expired)" if expired else ""
logger.debug("get key %r -> %s %s", key, hit_or_miss, expiredstr)
return result
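
The record layout this implies: each cache file stores one serialized ``(expires, value)`` pair, where ``expires`` is an absolute epoch timestamp or ``0`` for no expiry. A sketch of reading one back manually (``cache`` being a configured FileSystemCache):

with open(cache._get_filename("user:1"), "rb") as f:
    expires, value = cache._serializer.load(f)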

def add(self, key, value, timeout=None):
filename = self._get_filename(key)
added = False
should_add = not os.path.exists(filename)
if should_add:
added = self.set(key, value, timeout)
addedstr = "added" if added else "not added"
logger.debug("add key %r -> %s", key, addedstr)
return should_add

def set(self, key, value, timeout=None, mgmt_element=False):
result = False

# Management elements have no timeout
if mgmt_element:
timeout = 0

# Don't prune on management element update, to avoid loop
else:
self._prune()

timeout = self._normalize_timeout(timeout)
filename = self._get_filename(key)
try:
fd, tmp = tempfile.mkstemp(
suffix=self._fs_transaction_suffix, dir=self._path
)
with os.fdopen(fd, "wb") as f:
self._serializer.dump((timeout, value), f)

# https://github.com/sh4nks/flask-caching/issues/238#issuecomment-801897606
is_new_file = not os.path.exists(filename)
if not is_new_file:
os.remove(filename)
os.replace(tmp, filename)

os.chmod(filename, self._mode)
except OSError as exc:
logger.error("set key %r -> %s", key, exc)
else:
result = True
logger.debug("set key %r", key)
# Management elements should not count towards threshold
if not mgmt_element and is_new_file:
self._update_count(delta=1)
return result

def delete(self, key, mgmt_element=False):
deleted = False
try:
os.remove(self._get_filename(key))
        except FileNotFoundError:
            logger.debug("delete key %r -> no such key", key)
        except OSError as exc:
            logger.error("delete key %r -> %s", key, exc)
else:
deleted = True
logger.debug("deleted key %r", key)
# Management elements should not count towards threshold
if not mgmt_element:
self._update_count(delta=-1)
return deleted

def has(self, key):
result = False
expired = False
filename = self._get_filename(key)
try:
with open(filename, "rb") as f:
pickle_time, _ = self._serializer.load(f)
expired = pickle_time != 0 and pickle_time < time()
if expired:
self.delete(key)
else:
result = True
except FileNotFoundError:
pass
        except Exception as exc:
            if isinstance(exc, (OSError, self._serialization_error)):
                logger.error("has key %r -> %s", key, exc)
            else:
                raise
expiredstr = "(expired)" if expired else ""
logger.debug("has key %r -> %s %s", key, result, expiredstr)
return result
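
Putting the pieces together, the backend can also be constructed directly; a sketch, assuming ``flask_caching.serialization`` exposes ``json`` and ``JSONError`` as the predefined names suggest:

from flask_caching import serialization
from flask_caching.backends.filesystemcache import FileSystemCache

cache = FileSystemCache(
    "/tmp/fs-cache",
    serializer_impl=serialization.json,
    serializer_error=serialization.JSONError,
)
cache.set("answer", 42)
assert cache.get("answer") == 42
assert cache.has("answer")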