Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to spawn workers inside daemon #1067

Open
wants to merge 21 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/orion/executor/multiprocess_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def Process(*args, **kwds):
if v.major == 3 and v.minor >= 8:
args = args[1:]

if Pool.ALLOW_DAEMON:
return Process(*args, **kwds)
if not Pool.ALLOW_DAEMON:
return PyPool.Process(*args, **kwds)

return _Process(*args, **kwds)

Expand Down Expand Up @@ -167,13 +167,18 @@ def __init__(self, n_workers=-1, backend="multiprocess", **kwargs):
if n_workers <= 0:
n_workers = multiprocessing.cpu_count()

self.pool_config = {"n_workers": n_workers, "backend": backend}
self.pool = PoolExecutor.BACKENDS.get(backend, ThreadPool)(n_workers)

def __setstate__(self, state):
self.pool = state["pool"]
log.warning("Nesting multiprocess executor")
bouthilx marked this conversation as resolved.
Show resolved Hide resolved
self.pool_config = state["pool_config"]
backend = self.pool_config.get("backend", ThreadPool)
bouthilx marked this conversation as resolved.
Show resolved Hide resolved
n_workers = self.pool_config.get("n_workers", -1)
self.pool = PoolExecutor.BACKENDS.get(backend, ThreadPool)(n_workers)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit unsure about this part. If the object is serialized and passed to the subprocess, the deserialization step will have the effect or creating another pool of n_workers, no?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we maybe able to pass a queue instead to avoid creating multiple pools
but nesting the executor in general is a bit of a nono


def __getstate__(self):
return dict(pool=self.pool)
return {"pool_config": self.pool_config}

def __enter__(self):
return self
Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/client/test_experiment_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,7 +1000,7 @@ def main(*args, **kwargs):


def test_run_experiment_twice():
""""""
"""Makes sure the executor is not freed after workon"""

with create_experiment(config, base_trial) as (cfg, experiment, client):
client.workon(main, max_trials=10)
Expand Down
40 changes: 31 additions & 9 deletions tests/unittests/executor/test_executor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import multiprocessing
import time

import pytest
Expand Down Expand Up @@ -235,7 +236,7 @@ def nested(executor):
return sum(f.get() for f in futures)


@pytest.mark.parametrize("backend", [xfail_dask_if_not_installed(Dask), SingleExecutor])
@pytest.mark.parametrize("backend", backends)
def test_nested_submit(backend):
with backend(5) as executor:
futures = [executor.submit(nested, executor) for i in range(5)]
Expand All @@ -246,17 +247,38 @@ def test_nested_submit(backend):
assert r.value == 35


@pytest.mark.parametrize("backend", [multiprocess, thread])
def test_nested_submit_failure(backend):
def inc(a):
return a + 1


def nested_pool():
data = [1, 2, 3, 4, 5, 6]
with multiprocessing.Pool(5) as p:
result = p.map_async(inc, data)
result.wait()
data = result.get()

return sum(data)


@pytest.mark.parametrize("backend", backends)
def test_nested_submit_pool(backend):
if backend is Dask:
pytest.skip("Dask does not support nesting")

with backend(5) as executor:
futures = [executor.submit(nested_pool) for i in range(5)]

results = executor.async_get(futures, timeout=2)

if backend == multiprocess:
exception = NotImplementedError
elif backend == thread:
exception = TypeError
for r in results:
assert r.value == 27

with pytest.raises(exception):
[executor.submit(nested, executor) for i in range(5)]

@pytest.mark.parametrize("backend", [multiprocess, thread])
bouthilx marked this conversation as resolved.
Show resolved Hide resolved
def test_nested_submit_works(backend):
with backend(5) as executor:
[executor.submit(nested, executor) for i in range(5)]


@pytest.mark.parametrize("executor", executors)
Expand Down