# Source code for ewoksjob.worker.slurm

"""Pool that redirects tasks to a Slurm cluster."""

import atexit
import datetime
import logging
import weakref
from functools import wraps
from typing import Any
from typing import Callable
from typing import Optional

try:
    import gevent
    from gevent import GreenletExit
except ImportError:
    # Avoid an error on import; only raise an error when this pool is actually used.
    gevent = NotImplemented
    GreenletExit = NotImplemented

from celery import signals
from celery.concurrency.gevent import TaskPool as _TaskPool

try:
    from pyslurmutils.concurrent.futures import SlurmRestExecutor
    from pyslurmutils.concurrent.futures import SlurmRestFuture
except ImportError:
    SlurmRestExecutor = None
    SlurmRestFuture = Any

from .executor import ExecuteType
from .executor import set_execute_getter

__all__ = ("TaskPool",)

logger = logging.getLogger(__name__)


class TaskPool(_TaskPool):
    """SLURM Task Pool.

    Celery gevent task pool that owns a ``SlurmRestExecutor`` (pyslurmutils)
    and registers it as the target to which ewoks tasks are forwarded.
    """

    # Extra keyword arguments for ``SlurmRestExecutor``. They are applied after
    # (and therefore override) ``max_workers`` / ``max_tasks_per_worker``.
    EXECUTOR_OPTIONS = dict()

    # Maximum time to block while waiting for the executor cleanup.
    SLURM_SHUTDOWN_TIMEOUT = 60.0  # seconds

    def __init__(self, *args, **kwargs):
        if SlurmRestExecutor is None:
            raise RuntimeError("requires pyslurmutils")
        super().__init__(*args, **kwargs)
        self._slurm_executor = None
        # Greenlet tearing down the most recently detached executor (or None).
        self._slurm_cleanup_task = None
        signals.worker_shutdown.connect(
            self._blocking_wait_for_slurm_cleanup, weak=False
        )
        atexit.register(self._blocking_wait_for_slurm_cleanup)
        self._create_slurm_executor()

    def restart(self):
        """Replace the current Slurm executor by a freshly created one."""
        self._safe_remove_slurm_executor()
        self._create_slurm_executor()

    def on_stop(self):
        """Clean up the Slurm executor, then stop the gevent pool."""
        self._safe_remove_slurm_executor()
        super().on_stop()

    def _safe_remove_slurm_executor(self):
        """
        Initiate cleanup. If we're NOT in a gevent hub callback, block until the
        executor is cleaned up (at most ``SLURM_SHUTDOWN_TIMEOUT`` seconds).
        If we ARE in a hub callback, only kick off cleanup and return immediately;
        final waiting happens in worker_shutdown/atexit.
        """
        self._start_slurm_cleanup()
        self._wait_for_slurm_cleanup(timeout=self.SLURM_SHUTDOWN_TIMEOUT)

    def _wait_for_slurm_cleanup(self, timeout=None):
        """
        Wait until the cleanup greenlet finishes, at most ``timeout`` seconds
        (``None`` waits indefinitely). Safe to call outside gevent hub callbacks.
        If called inside a hub callback, it just returns.
        """
        if self._is_in_gevent_callback():
            return
        task = self._slurm_cleanup_task
        if not task:
            # No cleanup was started, or it already finished
            # (a greenlet is only truthy while it is running).
            return
        task.join(timeout=timeout)
        if task:
            # Still running after the join: we timed out.
            logger.warning(
                "Timed out waiting for SLURM executor cleanup (%.1fs).",
                -1 if timeout is None else timeout,
            )

    def _blocking_wait_for_slurm_cleanup(self, **_):
        """
        Runs on Celery's worker_shutdown signal and at process exit.
        Not a gevent hub callback → can block safely.
        """
        self._wait_for_slurm_cleanup(timeout=self.SLURM_SHUTDOWN_TIMEOUT)

    def _is_in_gevent_callback(self):
        """Return True when running in the gevent hub greenlet itself."""
        # ``gevent`` is ``NotImplemented`` when it could not be imported.
        if gevent is None or gevent is NotImplemented:
            return False
        hub = gevent.get_hub()
        return gevent.getcurrent() is hub

    def _create_slurm_executor(self):
        """Create the Slurm executor and register it for task forwarding."""
        maxtasksperchild = self.options.get("maxtasksperchild")
        if maxtasksperchild is None:
            logger.warning(
                "The 'slurm' pool does not support Slurm jobs which execute an unlimited number of celery jobs. "
                "Use '--max-tasks-per-child=1' to remove this warning."
            )
            maxtasksperchild = 1
        kwargs = {
            "max_workers": self.limit,
            "max_tasks_per_worker": maxtasksperchild,
            **self.EXECUTOR_OPTIONS,
        }
        executor = SlurmRestExecutor(**kwargs)
        # Read back by the module-level execute getter (timeouts).
        executor._celery_options = dict(self.options)
        self._slurm_executor = executor
        _set_slurm_executor(executor)

    def _start_slurm_cleanup(self):
        """
        Start cleaning up the current executor. Never blocks.

        The executor is bound to the cleanup greenlet and detached from the pool
        right away. This way, an executor created afterwards (e.g. by ``restart``)
        is never touched by this cleanup.
        """
        executor = self._slurm_executor
        if executor is None:
            # Nothing to clean. Keep the handle of a cleanup that may still be
            # running so that waiters can still join it.
            return
        # Request non-blocking shutdown
        executor.shutdown(wait=False)
        # Do the blocking part in a greenlet; chain any earlier cleanup so that
        # joining the newest handle also waits for the older ones.
        self._slurm_cleanup_task = gevent.spawn(
            self._blocking_cleanup_main, executor, self._slurm_cleanup_task
        )
        self._slurm_executor = None

    def _blocking_cleanup_main(self, executor=None, previous_cleanup=None):
        """
        Runs in a greenlet; allowed to block (e.g., Thread.join inside executor).

        :param executor: executor to tear down. When omitted, the executor that
            is currently attached to the pool is detached and torn down.
        :param previous_cleanup: earlier cleanup greenlet (or None), joined after
            tearing down ``executor``.
        """
        if executor is None:
            executor, self._slurm_executor = self._slurm_executor, None
        if executor is not None:
            try:
                # __exit__ may perform blocking joins internally; that's fine here.
                executor.__exit__(None, None, None)
            except Exception:
                logger.exception("Error while cleaning up SLURM executor")
            finally:
                logger.debug("SLURM executor cleanup complete")
        if previous_cleanup is not None:
            previous_cleanup.join()
# Weak proxy to the executor of the most recently created TaskPool
# (None until ``_set_slurm_executor`` is called).
_SLURM_EXECUTOR = None


def _set_slurm_executor(slurm_executor):
    """Register ``slurm_executor`` as the target for forwarded ewoks tasks.

    Only a ``weakref.proxy`` is stored, so this module does not keep the
    executor alive. ``_get_execute_method`` is handed to ``set_execute_getter``
    (presumably consulted by the ewoks executor for each task — confirm in
    ``.executor``).
    """
    global _SLURM_EXECUTOR
    _SLURM_EXECUTOR = weakref.proxy(slurm_executor)
    set_execute_getter(_get_execute_method)


def _get_execute_method() -> Optional[ExecuteType]:
    """Return an execute function that forwards tasks to the registered executor.

    Returns ``None`` when no executor was registered or when it has already
    been garbage-collected.
    """
    try:
        submit = _SLURM_EXECUTOR.submit
    except (AttributeError, ReferenceError):
        # TaskPool is not instantiated
        return None
    celery_options = _SLURM_EXECUTOR._celery_options
    return _slurm_execute_method(
        submit, celery_options.get("timeout"), celery_options.get("soft_timeout")
    )


_SubmitType = Callable[[Callable, Any, Any], SlurmRestFuture]


def _format_slurm_time_limit(seconds: float) -> str:
    """Format a duration (rounded to whole seconds) as a Slurm time string.

    Durations below one day give "H:MM:SS" (identical to ``str(timedelta)``).
    Longer ones give "D-H:MM:SS", because ``str(timedelta)`` would produce
    "1 day, H:MM:SS", which is not a Slurm time format.

    >>> _format_slurm_time_limit(3661)
    '1:01:01'
    >>> _format_slurm_time_limit(90061)
    '1-1:01:01'
    """
    total = round(seconds)
    days, remainder = divmod(total, 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, secs = divmod(remainder, 60)
    clock = f"{hours}:{minutes:02d}:{secs:02d}"
    return f"{days}-{clock}" if days else clock


def _slurm_execute_method(
    submit: _SubmitType, timeout: Optional[float], soft_timeout: Optional[float]
) -> ExecuteType:
    """Instead of executing the celery task, forward the ewoks task to Slurm.

    :param submit: submit method of the Slurm executor.
    :param timeout: celery hard time limit in seconds (or None).
    :param soft_timeout: celery soft time limit in seconds (or None).
    :returns: function that submits the ewoks task and blocks on its result.
    """
    if timeout is not None:
        time_limit_sec = timeout
    elif soft_timeout is not None:
        # No hard limit: allow 10 seconds beyond the soft limit.
        time_limit_sec = soft_timeout + 10
    else:
        time_limit_sec = None

    @wraps(submit)
    def execute(ewoks_task: Callable, *args, **kwargs):
        if time_limit_sec is not None:
            # Copy the nested dicts so that caller-owned arguments are not mutated.
            slurm_arguments = dict(kwargs.get("slurm_arguments") or {})
            parameters = dict(slurm_arguments.get("parameters") or {})
            # An explicitly provided time limit takes precedence.
            parameters.setdefault(
                "time_limit", _format_slurm_time_limit(time_limit_sec)
            )
            slurm_arguments["parameters"] = parameters
            kwargs["slurm_arguments"] = slurm_arguments
        future = submit(ewoks_task, *args, **kwargs)
        try:
            return future.result()
        except GreenletExit:
            # The greenlet running this task is being killed: cancel the job too.
            _ensure_cancel_job(future)
            raise

    return execute


def _ensure_cancel_job(future: SlurmRestFuture) -> None:
    """Abort the Slurm job of ``future``, retrying when interrupted by GreenletExit."""
    while True:
        try:
            logger.info("Cancel Slurm job %s", future.job_id)
            future.abort()
        except GreenletExit:
            # Killed again while cancelling: keep trying.
            continue
        return