# This module contains modified code from Dask which is
# licensed under BSD 3-Clause License for the following holder:
# Copyright (c) 2014, Anaconda, Inc. and contributors.
from __future__ import annotations
import multiprocessing
import multiprocessing.pool
import os
from collections.abc import Hashable, Mapping, Sequence
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from queue import Queue
from warnings import warn
from dask import config
from dask.callbacks import local_callbacks, unpack_callbacks
from dask.core import _execute_task, flatten, get_dependencies
from dask.local import (MultiprocessingPoolExecutor, batch_execute_tasks,
default_get_id, default_pack_exception, finish_task,
identity, nested_get, queue_get, start_state_from_dask)
from dask.multiprocessing import (_dumps, _loads, _process_get_id, get_context,
initialize_worker_process, pack_exception,
reraise)
from dask.optimization import cull, fuse
from dask.order import order
from dask.system import CPU_COUNT
from dask.utils import ensure_dict
from feste import context
def get_async(submit, num_workers, dsk, result, cache=None,  # type: ignore
              get_id=default_get_id, rerun_exceptions_locally=None,
              pack_exception=default_pack_exception, raise_exception=reraise,
              callbacks=None, dumps=identity, loads=identity, chunksize=None,
              **kwargs):
    """Run a dask graph asynchronously on an executor.

    This is mostly Dask's ``get_async`` with changes to introduce
    optimization during execution, with batching being an example.

    Parameters
    ----------
    submit : callable
        ``executor.submit``-like callable; must return a future that
        supports ``add_done_callback``.
    num_workers : int
        Number of workers backing the executor; used to size batches.
    dsk : Mapping
        The dask task graph.
    result : key or list of keys
        Key(s) whose values are computed and returned.
    cache : MutableMapping, optional
        Pre-seeded cache of intermediate results.
    get_id : callable, optional
        Returns an identifier for the worker running a task.
    rerun_exceptions_locally : bool, optional
        If true, re-execute a failed task in this process so the user sees
        a local traceback. Defaults to the
        ``multiprocessing.rerun_exceptions_locally`` context setting.
    pack_exception, raise_exception : callable, optional
        Hooks to serialize a remote exception and to re-raise it here.
    callbacks : sequence, optional
        Dask callback tuples ``(start, start_state, pretask, posttask,
        finish)``.
    dumps, loads : callable, optional
        Serialization hooks applied to task payloads and results.
    chunksize : int, optional
        Tasks submitted per batch; ``-1`` spreads all ready tasks evenly
        over the workers. Defaults to the ``multiprocessing.chunk_size``
        context setting.

    Returns
    -------
    The computed value(s) for ``result``, preserving its nesting.

    Raises
    ------
    ValueError
        If the graph has waiting tasks but none are ready to run
        (i.e. the graph is not traversable from its inputs).
    """
    chunksize = chunksize or context.get("multiprocessing.chunk_size")
    # Completed batch futures are handed back to this thread via the queue.
    queue: Queue = Queue()
    if isinstance(result, list):
        result_flat = set(flatten(result))
    else:
        result_flat = {result}
    results = set(result_flat)
    dsk = dict(dsk)
    with local_callbacks(callbacks) as callbacks:
        _, _, pretask_cbs, posttask_cbs, _ = unpack_callbacks(callbacks)
        started_cbs = []
        succeeded = False
        # if start_state_from_dask fails, we will have something
        # to pass to the final block.
        state = {}
        try:
            for cb in callbacks:
                if cb[0]:
                    cb[0](dsk)
                started_cbs.append(cb)
            keyorder = order(dsk)
            state = start_state_from_dask(dsk, cache=cache, sortkey=keyorder.get)
            for _, start_state, _, _, _ in callbacks:
                if start_state:
                    start_state(dsk, state)
            if rerun_exceptions_locally is None:
                rerun_exceptions_locally = \
                    context.get("multiprocessing.rerun_exceptions_locally")
            if state["waiting"] and not state["ready"]:
                raise ValueError("Found no accessible jobs in dask")

            def fire_tasks(chunksize: int) -> None:
                """Fire off a task to the thread pool"""
                # Determine chunksize and/or number of tasks to submit
                nready = len(state["ready"])
                if chunksize == -1:
                    # Spread everything: ceil(nready / num_workers) per batch.
                    ntasks = nready
                    chunksize = -(ntasks // -num_workers)
                else:
                    # Only submit what the currently idle workers can absorb;
                    # -(a // -b) is ceiling division.
                    used_workers = -(len(state["running"]) // -chunksize)
                    avail_workers = max(num_workers - used_workers, 0)
                    ntasks = min(nready, chunksize * avail_workers)
                # Prep all ready tasks for submission
                args = []
                for _ in range(ntasks):
                    # Get the next task to compute (most recently added)
                    key = state["ready"].pop()
                    # Notify task is running
                    state["running"].add(key)
                    for f in pretask_cbs:
                        f(key, dsk, state)
                    # Prep args to send: the task plus its dependencies' values.
                    data = {
                        dep: state["cache"][dep] for dep in get_dependencies(dsk, key)
                    }
                    args.append(
                        (
                            key,
                            dumps((dsk[key], data)),
                            dumps,
                            loads,
                            get_id,
                            pack_exception,
                        )
                    )
                # Batch submit: one future per chunk of up to `chunksize` tasks.
                for i in range(-(len(args) // -chunksize)):
                    each_args = args[i * chunksize:(i + 1) * chunksize]
                    if not each_args:
                        break
                    fut = submit(batch_execute_tasks, each_args)
                    # When the batch completes, its future lands on the queue.
                    fut.add_done_callback(queue.put)

            # Main loop, wait on tasks to finish, insert new ones
            while state["waiting"] or state["ready"] or state["running"]:
                fire_tasks(chunksize)
                for key, res_info, failed in queue_get(queue).result():
                    if failed:
                        exc, tb = loads(res_info)
                        if rerun_exceptions_locally:
                            data = {
                                dep: state["cache"][dep]
                                for dep in get_dependencies(dsk, key)
                            }
                            task = dsk[key]
                            _execute_task(task, data)  # Re-execute locally
                        else:
                            raise_exception(exc, tb)
                    res, worker_id = loads(res_info)
                    state["cache"][key] = res
                    finish_task(dsk, key, state, results, keyorder.get)
                    for f in posttask_cbs:
                        f(key, res, dsk, state, worker_id)
            succeeded = True
        finally:
            # Always run finish callbacks, flagging whether we errored out.
            for _, _, _, _, finish in started_cbs:
                if finish:
                    finish(dsk, state, not succeeded)
    return nested_get(result, state["cache"])
def get_multiprocessing(dsk: Mapping, keys: Sequence[Hashable] | Hashable,  # type: ignore
                        num_workers=None, func_loads=None, func_dumps=None,
                        optimize_graph=True, pool=None, initializer=None,
                        chunksize=None, **kwargs):
    """Compute ``keys`` from the dask graph ``dsk`` on a process pool.

    Creates a ``ProcessPoolExecutor`` (unless ``pool`` is supplied),
    optionally optimizes the graph via cull/fuse, then drives execution
    through :func:`get_async`.

    Parameters
    ----------
    dsk : Mapping
        The dask task graph.
    keys : key or sequence of keys
        Key(s) to compute.
    num_workers : int, optional
        Pool size; defaults to the ``multiprocessing.num_workers`` context
        setting, falling back to ``CPU_COUNT``.
    func_loads, func_dumps : callable, optional
        Serialization hooks; default to the ``multiprocessing.func_loads``
        / ``multiprocessing.func_dumps`` context settings, then Dask's
        ``_loads`` / ``_dumps``.
    optimize_graph : bool, optional
        If true (default), fuse the culled graph before execution.
    pool : Executor or multiprocessing.pool.Pool, optional
        Existing pool to reuse; it is NOT shut down on exit.
    initializer : callable, optional
        Per-worker initializer; ignored (with a warning) when ``pool`` is
        provided.
    chunksize : int, optional
        Batch size forwarded to :func:`get_async`; defaults to the
        ``multiprocessing.chunk_size`` context setting.

    Returns
    -------
    The computed value(s) for ``keys``.
    """
    chunksize = chunksize or context.get("multiprocessing.chunk_size")
    pool = pool or config.get("pool", None)
    initializer = initializer or config.get("multiprocessing.initializer", None)
    num_workers = num_workers or context.get("multiprocessing.num_workers") or CPU_COUNT
    if pool is None:
        # In order to get consistent hashing in subprocesses, we need to set a
        # consistent seed for the Python hash algorithm. Unfortunately, there
        # is no way to specify environment variables only for the Pool
        # processes, so we have to rely on environment variables being
        # inherited.
        if os.environ.get("PYTHONHASHSEED") in (None, "0"):
            os.environ["PYTHONHASHSEED"] = "42"
        mp_context = get_context()
        initializer = partial(initialize_worker_process, user_initializer=initializer)
        pool = ProcessPoolExecutor(
            num_workers, mp_context=mp_context, initializer=initializer
        )
        # We created the pool, so we are responsible for shutting it down.
        cleanup = True
    else:
        if initializer is not None:
            warn(
                "The ``initializer`` argument is ignored when ``pool`` is provided. "
                "The user should configure ``pool`` with the needed ``initializer`` "
                "on creation.",
                stacklevel=2,  # attribute the warning to the caller
            )
        if isinstance(pool, multiprocessing.pool.Pool):
            # Wrap a raw multiprocessing Pool in an Executor-like interface.
            pool = MultiprocessingPoolExecutor(pool)
        cleanup = False
    # Optimize Dask: drop unreachable tasks, then optionally fuse chains.
    dsk = ensure_dict(dsk)
    dsk2, dependencies = cull(dsk, keys)
    if optimize_graph:
        dsk3, dependencies = fuse(dsk2, keys, dependencies)
    else:
        dsk3 = dsk2
    # We specify marshalling functions in order to catch serialization
    # errors and report them to the user.
    loads = func_loads or context.get("multiprocessing.func_loads") or _loads
    dumps = func_dumps or context.get("multiprocessing.func_dumps") or _dumps
    # Note former versions used a multiprocessing Manager to share
    # a Queue between parent and workers, but this is fragile on Windows
    # (issue #1652).
    try:
        # Run
        result = get_async(
            pool.submit,
            pool._max_workers,
            dsk3,
            keys,
            get_id=_process_get_id,
            dumps=dumps,
            loads=loads,
            pack_exception=pack_exception,
            raise_exception=reraise,
            chunksize=chunksize,
            **kwargs,
        )
    finally:
        # Only shut down pools we created ourselves.
        if cleanup:
            pool.shutdown()
    return result