Source code for ewoksjob.client.local.pool

import sys
from contextlib import contextmanager
from typing import Mapping, Optional, Tuple
from uuid import uuid4
import multiprocessing
import weakref
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import Future as NativeFuture

try:
    from pyslurmutils.client.errors import RemoteExit
    from pyslurmutils.concurrent.futures import SlurmRestFuture
    from pyslurmutils.concurrent.futures import SlurmRestExecutor
except ImportError:
    SlurmRestExecutor = None
    SlurmRestFuture = None
    RemoteExit = None

from .futures import LocalFuture


__all__ = ["get_active_pool", "pool_context"]


_EWOKS_WORKER_POOL = None


[docs] def get_active_pool(raise_on_missing: Optional[bool] = True): if raise_on_missing and _EWOKS_WORKER_POOL is None: raise RuntimeError("No worker pool is available") return _EWOKS_WORKER_POOL
[docs] @contextmanager def pool_context(*args, **kwargs): global _EWOKS_WORKER_POOL if _EWOKS_WORKER_POOL is None: with _LocalPool(*args, **kwargs) as pool_obj: _EWOKS_WORKER_POOL = pool_obj try: yield pool_obj finally: _EWOKS_WORKER_POOL = None else: yield _EWOKS_WORKER_POOL
class _LocalPool: def __init__( self, *args, pool_type: Optional[str] = None, context: Optional[str] = None, **kwargs, ) -> None: if pool_type is None: pool_type = "process" if context is None: context = "spawn" if pool_type == "process": if context: if sys.version_info >= (3, 7): kwargs["mp_context"] = multiprocessing.get_context(context) else: multiprocessing.set_start_method(context, force=True) self._executor = ProcessPoolExecutor(*args, **kwargs) elif pool_type == "thread": self._executor = ThreadPoolExecutor(*args, **kwargs) elif pool_type == "slurm": if SlurmRestExecutor is None: raise RuntimeError("requires pyslurmutils") self._executor = SlurmRestExecutor(*args, **kwargs) else: raise ValueError(f"Unknown pool type '{pool_type}'") self._pool_type = pool_type self._tasks = weakref.WeakValueDictionary() def __enter__(self): self._executor.__enter__() return self def __exit__(self, exc_type, exc_val, exc_tb): return self._executor.__exit__(exc_type, exc_val, exc_tb) def shutdown(self, **kw): return self._executor.shutdown(**kw) @property def pool_type(self): return self._pool_type def submit( self, func, uuid: Optional[str] = None, args: Optional[Tuple] = tuple(), kwargs: Optional[Mapping] = None, ) -> LocalFuture: """Like celery.send_task""" if kwargs is None: kwargs = dict() if uuid is None: uuid = str(uuid4()) native_future = self._executor.submit(func, *args, **kwargs) future = LocalFuture(uuid, native_future) self._tasks[uuid] = future return future def get_future(self, uuid: str) -> LocalFuture: future = self._tasks.get(uuid) if future is not None: return future if self.pool_type == "slurm": future = SlurmRestFuture() else: future = NativeFuture() return LocalFuture(uuid, future) def get_unfinished_uuids(self) -> list: return [uuid for uuid, future in self._tasks.items() if not future.done()]