Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pickling support for arrays, buffers #786

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
167 changes: 166 additions & 1 deletion pyopencl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@
MemoryObject,
MemoryMap,
Buffer,
PooledBuffer,

_Program,
Kernel,
Expand Down Expand Up @@ -197,7 +198,7 @@
enqueue_migrate_mem_objects, unload_platform_compiler)

if get_cl_header_version() >= (2, 0):
from pyopencl._cl import SVM, SVMAllocation, SVMPointer
from pyopencl._cl import SVM, SVMAllocation, SVMPointer, PooledSVM

if _cl.have_gl():
from pyopencl._cl import ( # noqa: F401
Expand Down Expand Up @@ -2407,4 +2408,168 @@ def fsvm_empty_like(ctx, ary, alignment=None):
_KERNEL_ARG_CLASSES = (*_KERNEL_ARG_CLASSES, SVM)


# {{{ pickling support

import threading
from contextlib import contextmanager


# Thread-local storage for the queue/allocator that the pickling hooks use.
_QUEUE_FOR_PICKLING_TLS = threading.local()


@contextmanager
def queue_for_pickling(queue, alloc=None):
    r"""A context manager that, for the current thread, sets the command queue
    to be used for pickling and unpickling :class:`Array`\ s and :class:`Buffer`\ s
    to *queue*.

    :arg queue: the command queue to store for the duration of the context.
    :arg alloc: stored alongside *queue* and handed back to the unpickling
        hooks; defaults to *None*.
    :raises RuntimeError: if a queue is already stored for this thread, i.e.
        the context manager is being nested.

    Yields *None*. On exit (normal or exceptional), both stored values are
    reset to *None*.
    """
    tls = _QUEUE_FOR_PICKLING_TLS

    # An attribute that was never set in this thread counts as "not active".
    if getattr(tls, "queue", None) is not None:
        raise RuntimeError("queue_for_pickling should not be called "
                "inside the context of its own invocation.")

    tls.queue, tls.alloc = queue, alloc
    try:
        yield
    finally:
        tls.queue = tls.alloc = None


def _get_queue_for_pickling(obj):
    """Return the ``(queue, alloc)`` pair stored by :func:`queue_for_pickling`
    for the current thread.

    :arg obj: the object being pickled or unpickled. Only its type name is
        used, to build the error message.
    :returns: a tuple ``(queue, alloc)``.
    :raises RuntimeError: if no queue is currently stored for this thread.
    """
    try:
        queue = _QUEUE_FOR_PICKLING_TLS.queue
        alloc = _QUEUE_FOR_PICKLING_TLS.alloc
    except AttributeError:
        # Nothing has been stored in this thread yet.
        queue = alloc = None

    if queue is None:
        # This helper serves both the __getstate__ and the __setstate__
        # hooks, so the message must cover both directions.
        raise RuntimeError(
                f"{type(obj).__name__} instances can only be pickled or "
                "unpickled while queue_for_pickling is active.")

    return queue, alloc


def _getstate_buffer(self):
    """Build the pickle state for a :class:`Buffer`.

    The buffer contents are copied into a host-side :class:`bytearray` using
    the queue made active by :func:`queue_for_pickling`.

    :returns: a dict with the keys ``"size"``, ``"flags"`` and
        ``"_pickle_data"`` (the host copy of the contents).
    :raises RuntimeError: if :func:`queue_for_pickling` is not active.
    """
    import pyopencl as cl

    queue, _ = _get_queue_for_pickling(self)

    nbytes = self.size
    mem_flags = self.flags

    # Host-side staging area that receives the buffer contents.
    host_copy = bytearray(nbytes)
    cl.enqueue_copy(queue, host_copy, self)

    return {
        "size": nbytes,
        "flags": mem_flags,
        "_pickle_data": host_copy,
    }


def _setstate_buffer(self, state):
    """Re-create a :class:`Buffer` in place from pickled *state*.

    The buffer is created in the context of the queue made active by
    :func:`queue_for_pickling` and is initialized from the pickled host data.

    :arg state: a dict with the keys ``"size"``, ``"flags"`` and
        ``"_pickle_data"``.
    :raises RuntimeError: if :func:`queue_for_pickling` is not active.
    """
    import pyopencl as cl
    queue, _alloc = _get_queue_for_pickling(self)

    size = state["size"]
    host_data = state["_pickle_data"]

    # The contents are restored via COPY_HOST_PTR. OpenCL treats USE_HOST_PTR
    # and COPY_HOST_PTR as mutually exclusive, so a buffer that was originally
    # created with USE_HOST_PTR must not carry that flag over. Its original
    # host memory is not part of the pickle in any case.
    mf = cl.mem_flags
    flags = (state["flags"] & ~mf.USE_HOST_PTR) | mf.COPY_HOST_PTR

    Buffer.__init__(self, queue.context, flags, size, host_data)


Buffer.__getstate__ = _getstate_buffer
Buffer.__setstate__ = _setstate_buffer


def _getstate_pooledbuffer(self):
    """Build the pickle state for a :class:`PooledBuffer`.

    Copies the buffer contents to the host using the queue made active by
    :func:`queue_for_pickling`.

    :returns: a dict with the keys ``"size"``, ``"flags"`` and
        ``"_pickle_data"`` (a :class:`bytearray` holding the contents).
    :raises RuntimeError: if :func:`queue_for_pickling` is not active.
    """
    import pyopencl as cl
    queue, _ = _get_queue_for_pickling(self)

    state = dict(size=self.size, flags=self.flags)

    # Stage the contents in host memory so that they can be pickled.
    contents = bytearray(self.size)
    cl.enqueue_copy(queue, contents, self)

    state["_pickle_data"] = contents
    return state


def _setstate_pooledbuffer(self, state):
    """Refuse to unpickle a :class:`PooledBuffer`.

    Restoring a pooled buffer in place is not implemented. Instead of
    returning silently and leaving *self* uninitialized, this raises an
    explicit error.

    :arg state: the pickled state; currently unused.
    :raises RuntimeError: if :func:`queue_for_pickling` is not active.
    :raises NotImplementedError: otherwise, always.
    """
    # Keep reporting a missing pickling context first.
    _get_queue_for_pickling(self)

    # FIXME: Unclear what to do here - PooledBuffer does not have __init__,
    # so there is no way to initialize *self* from *state* in place.
    raise NotImplementedError(
            f"Unpickling {type(self).__name__} instances is not supported yet.")


PooledBuffer.__getstate__ = _getstate_pooledbuffer
PooledBuffer.__setstate__ = _setstate_pooledbuffer


# SVM types are only imported for OpenCL headers >= 2.0, so their pickling
# hooks are installed under the same condition.
if get_cl_header_version() >= (2, 0):
    def _getstate_svmallocation(self):
        """Build the pickle state for an :class:`SVMAllocation`.

        :returns: a dict with the keys ``"size"`` and ``"_pickle_data"``
            (a :class:`bytearray` holding a host copy of the contents).
        :raises RuntimeError: if :func:`queue_for_pickling` is not active.
        """
        import pyopencl as cl

        state = {"size": self.size}

        queue, _alloc = _get_queue_for_pickling(self)

        # Stage the contents in host memory so that they can be pickled.
        host_copy = bytearray(self.size)
        cl.enqueue_copy(queue, host_copy, self)

        state["_pickle_data"] = host_copy
        return state

    def _setstate_svmallocation(self, state):
        """Re-create an :class:`SVMAllocation` in place from pickled *state*.

        A fresh allocation of the pickled size is made in the context of the
        queue made active by :func:`queue_for_pickling`, and the pickled
        contents are copied into it.

        :arg state: a dict with the keys ``"size"`` and ``"_pickle_data"``.
        :raises RuntimeError: if :func:`queue_for_pickling` is not active.
        """
        import pyopencl as cl

        queue, _alloc = _get_queue_for_pickling(self)

        size = state["size"]
        host_data = state["_pickle_data"]

        SVMAllocation.__init__(self, queue.context, size, alignment=0, flags=0,
                               queue=queue)
        cl.enqueue_copy(queue, self, host_data)

    SVMAllocation.__getstate__ = _getstate_svmallocation
    SVMAllocation.__setstate__ = _setstate_svmallocation

    def _getstate_pooled_svm(self):
        """Build the pickle state for a :class:`PooledSVM`.

        :returns: a dict with the keys ``"size"`` and ``"_pickle_data"``
            (a :class:`bytearray` holding a host copy of the contents).
        :raises RuntimeError: if :func:`queue_for_pickling` is not active.
        """
        import pyopencl as cl

        state = {"size": self.size}

        queue, _alloc = _get_queue_for_pickling(self)

        host_copy = bytearray(self.size)
        cl.enqueue_copy(queue, host_copy, self)

        state["_pickle_data"] = host_copy
        return state

    def _setstate_pooled_svm(self, state):
        """Refuse to unpickle a :class:`PooledSVM`.

        Restoring a pooled SVM allocation in place is not implemented.
        Instead of returning silently and leaving *self* uninitialized,
        this raises an explicit error.

        :arg state: the pickled state; currently unused.
        :raises RuntimeError: if :func:`queue_for_pickling` is not active.
        :raises NotImplementedError: otherwise, always.
        """
        # Keep reporting a missing pickling context first.
        _get_queue_for_pickling(self)

        # FIXME: Unclear what to do here - PooledSVM does not have __init__,
        # so there is no way to initialize *self* from *state* in place.
        raise NotImplementedError(
                f"Unpickling {type(self).__name__} instances is not "
                "supported yet.")

    PooledSVM.__getstate__ = _getstate_pooled_svm
    PooledSVM.__setstate__ = _setstate_pooled_svm

# }}}

# vim: foldmethod=marker
43 changes: 42 additions & 1 deletion pyopencl/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class DoubleDowncastWarning(UserWarning):


# Warning text for results that numpy semantics would make double precision
# on a device without double-precision support.
_DOUBLE_DOWNCAST_WARNING = (
        "The operation you requested would result in a double-precision "
        "quantity according to numpy semantics. Since your device does not "
        "support double precision, a single-precision quantity is being returned.")

Expand Down Expand Up @@ -705,6 +705,47 @@ def __init__(
"than expected, potentially leading to crashes.",
InconsistentOpenCLQueueWarning, stacklevel=2)

# {{{ Pickling

def __getstate__(self):
    """Return the picklable state of this array.

    The state is a shallow copy of the instance ``__dict__``, so attributes
    added by subclasses are carried along. The ``allocator``, ``context``,
    ``events`` and ``queue`` entries are removed from it.

    :raises RuntimeError: if no queue is stored in
        ``cl._QUEUE_FOR_PICKLING_TLS`` for the current thread.
    """
    # A missing thread-local, or a missing attribute on it, counts as
    # "no queue active".
    tls = getattr(cl, "_QUEUE_FOR_PICKLING_TLS", None)
    if getattr(tls, "queue", None) is None:
        raise RuntimeError("CL Array instances can only be pickled while "
                "cl.queue_for_pickling is active.")

    state = dict(self.__dict__)

    # These entries are not stored in the pickle.
    for transient in ("allocator", "context", "events", "queue"):
        del state[transient]

    return state

def __setstate__(self, state):
    """Restore this array from pickled *state*.

    *state* is merged into the instance ``__dict__``. Afterwards the
    ``allocator``, ``context``, ``events`` and ``queue`` attributes are set
    from the queue and allocator stored in ``cl._QUEUE_FOR_PICKLING_TLS``.

    :arg state: a dict of instance attributes.
    :raises RuntimeError: if no queue is stored for the current thread.
    """
    try:
        tls = cl._QUEUE_FOR_PICKLING_TLS
        queue, alloc = tls.queue, tls.alloc
    except AttributeError:
        # Nothing has been stored in this thread yet.
        queue = alloc = None

    if queue is None:
        raise RuntimeError("CL Array instances can only be unpickled while "
                "cl.queue_for_pickling is active.")

    self.__dict__.update(state)

    # Attributes that are absent from the pickle are bound to the active
    # queue and allocator.
    self.allocator = alloc
    self.context = queue.context
    self.events = []
    self.queue = queue

# }}}

@property
def ndim(self):
    """The number of dimensions, i.e. the length of :attr:`shape`."""
    dims = self.shape
    return len(dims)
Expand Down
90 changes: 90 additions & 0 deletions test/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2393,6 +2393,96 @@ def test_xdg_cache_home(ctx_factory):
# }}}


# {{{ test pickling

from pytools.tag import Taggable


class TaggableCLArray(cl_array.Array, Taggable):
    """An :class:`Array` subclass with an extra ``tags`` attribute, used to
    check that pickling carries subclass state along."""

    def __init__(self, cq, shape, dtype, tags):
        super().__init__(cq=cq, shape=shape, dtype=dtype)

        # Extra per-instance state that has to survive a pickle round trip.
        self.tags = tags


@pytest.mark.parametrize("use_mempool", [False, True])
def test_array_pickling(ctx_factory, use_mempool):
    """Round-trip arrays (plain, subclassed, SVM-backed) through pickle."""
    import pickle

    context = ctx_factory()
    queue = cl.CommandQueue(context)

    alloc = (
        cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue))
        if use_mempool else None)

    a = np.array([1, 2, 3, 4, 5]).astype(np.float32)
    a_gpu = cl_array.to_device(queue, a, allocator=alloc)

    # Pickling without an active queue_for_pickling must be refused.
    with pytest.raises(RuntimeError):
        pickle.dumps(a_gpu)

    with cl.queue_for_pickling(queue):
        serialized = pickle.dumps(a_gpu)
        a_gpu_pickled = pickle.loads(serialized)
        assert np.all(a_gpu_pickled.get() == a)

    # {{{ subclass test

    a_gpu_tagged = TaggableCLArray(queue, a.shape, a.dtype, tags={"foo", "bar"})
    a_gpu_tagged.set(a)

    with cl.queue_for_pickling(queue):
        serialized = pickle.dumps(a_gpu_tagged)
        a_gpu_tagged_pickled = pickle.loads(serialized)

    # Both the data and the subclass attribute must come back unchanged.
    assert np.all(a_gpu_tagged_pickled.get() == a)
    assert a_gpu_tagged_pickled.tags == a_gpu_tagged.tags

    # }}}

    # {{{ SVM test

    from pyopencl.characterize import has_coarse_grain_buffer_svm

    if has_coarse_grain_buffer_svm(queue.device):
        from pyopencl.tools import SVMAllocator, SVMPool

        alloc = SVMAllocator(context, alignment=0, queue=queue)
        if use_mempool:
            alloc = SVMPool(alloc)

        a_dev = cl_array.to_device(queue, a, allocator=alloc)

        with cl.queue_for_pickling(queue, alloc):
            serialized = pickle.dumps(a_dev)
            a_dev_pickled = pickle.loads(serialized)

        assert np.all(a_dev_pickled.get() == a)
        # The allocator handed to queue_for_pickling is attached on unpickling.
        assert a_dev_pickled.allocator is alloc

# }}}


def test_buffer_pickling(ctx_factory):
    """Round-trip a :class:`pyopencl.Buffer` through pickle."""
    import pickle

    context = ctx_factory()
    queue = cl.CommandQueue(context)

    a = np.array([1, 2, 3, 4, 5]).astype(np.float32)

    a_gpu = cl.Buffer(context, cl.mem_flags.READ_WRITE, a.nbytes)
    cl.enqueue_copy(queue, a_gpu, a)

    # Pickling without an active queue_for_pickling must be refused.
    with pytest.raises(cl.RuntimeError):
        pickle.dumps(a_gpu)

    with cl.queue_for_pickling(queue):
        serialized = pickle.dumps(a_gpu)
        a_gpu_pickled = pickle.loads(serialized)

    # Read the restored buffer back to the host and compare with the source.
    a_new = np.empty_like(a)
    cl.enqueue_copy(queue, a_new, a_gpu_pickled)
    assert np.all(a_new == a)

# }}}


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down
Loading