Source code for metacells.utilities.parallel

"""
Parallel
--------

Due to the notorious GIL, using multiple Python threads is essentially useless. This leaves us with
two options for using multiple processors, which is mandatory for reasonable performance on the
large data sets we work on:

* Use multiple threads in the internal C++ implementation of some Python functions; this is done by both numpy and the
  C++ extension functions provided by this package, and works even for reasonably small sized work, such as sorting each
  of the rows of a large matrix.

* Use Python multi-processing. This is costly and works only for large sized work, such as computing metacells for
  different piles.

Each of these two approaches works tolerably well on its own, even though both are sub-optimal. The
problem starts when we want to combine them. Consider a server with 50 processors. Invoking
``corrcoef`` on a large matrix will use them all. This is great if one computes metacells for a
single pile. Suppose, however, you want to compute metacells for 50 piles, and do so using
multi-processing. Each of the 50 sub-processes will invoke ``corrcoef``, which will spawn
50 internal threads, resulting in the operating system seeing 2500 processes competing for the same
50 hardware processors. "This does not end well."

You would expect that, two decades after multi-core systems became available, this would have been
solved "out of the box" by the parallel frameworks (Python, OpenMP, TBB, etc.) all agreeing to
cooperate with each other. However, somehow this isn't seen as important by the people maintaining
these frameworks; in fact, most of them don't properly handle nested parallelism within their own
framework, never mind playing well with others.

So in practice, while languages built for parallelism (such as Julia and Rust) deal well with nested
parallel constructs, using a mixture of older serial languages (such as Python and C++) puts us in a
swamp, and "you can't build a castle in a swamp". In our case, numpy uses some underlying parallel
threads framework, our own extensions use OpenMP parallel threads, and we are forced to use the
Python-multi-processing framework itself on top of both, and each of these frameworks is blind to
the others.

As a crude band-aid, we force both whatever-numpy-uses and OpenMP to use a specific number of
threads. So, when we use multi-processing, we limit each sub-process to use fewer internal threads,
such that the total will be at most 50. This is very sub-optimal, but at least it doesn't bring the
server to its knees trying to deal with a total load of 2500 processes.

A final twist on all this is that hyper-threading is (worse than) useless for heavy compute threads.
We therefore by default only use one thread per physical core. We get the number of physical cores
using the ``psutil`` package.

.. todo::

    Re-implement all the package in a single language more suitable for scientific computing. Julia
    is looking like a good combination of convenience and performance...
"""
import ctypes
import os
import sys
from math import ceil
from multiprocessing import Value
from multiprocessing import get_context
from threading import current_thread
from typing import Any
from typing import Callable
from typing import List
from typing import Optional
from typing import Tuple
from typing import TypeVar

import psutil  # type: ignore
from threadpoolctl import threadpool_limits  # type: ignore

import metacells.utilities.documentation as utd
import metacells.utilities.logging as utl
import metacells.utilities.progress as utp
import metacells.utilities.timing as utm

if "sphinx" not in sys.argv[0]:
    import metacells.extensions as xt  # type: ignore

__all__ = [
    "is_main_process",
    "set_processors_count",
    "get_processors_count",
    "parallel_map",
]


# Number of processors this process may use. Zero until ``set_processors_count`` runs; inside a forked
# ``parallel_map`` worker it is reduced to that worker's share of the parent's processors (see ``_invocation``).
PROCESSORS_COUNT = 0

# PID of the process that imported this module; workers compare their own PID against it to know they are not it.
MAIN_PROCESS_PID = os.getpid()

# ``True`` in the main process, ``False`` in a worker that has run its first invocation, and ``None`` in the parent
# while a forking ``parallel_map`` is in flight (forked workers inherit ``None`` and resolve it on first use).
IS_MAIN_PROCESS: Optional[bool] = True

# Incremented by ``parallel_map`` each time it actually forks (not for the serial short-cut).
MAP_INDEX = 0
# Index of this worker within the current ``parallel_map`` (claimed from ``NEXT_PROCESS_INDEX``).
PROCESS_INDEX = 0

# Number of worker processes used by the current (or most recent) ``parallel_map``.
PROCESSES_COUNT = 0
# Lock-protected shared counter from which each worker claims its ``PROCESS_INDEX``; the parent resets it to 0 before
# forking. It is created at import time so the forked workers share the same underlying value as the parent.
NEXT_PROCESS_INDEX = Value(ctypes.c_int32, lock=True)
# The function being mapped; set by the parent just before forking (so workers inherit it) and cleared afterwards.
PARALLEL_FUNCTION: Optional[Callable[[int], Any]] = None


def is_main_process() -> bool:
    """
    Report whether the current process is the main process rather than a worker forked by
    :py:func:`parallel_map`.

    While a forking :py:func:`parallel_map` is in progress the role is temporarily unknown (``None``), which is
    reported as ``False``.
    """
    if IS_MAIN_PROCESS is None:
        return False
    return bool(IS_MAIN_PROCESS)
def set_processors_count(processors: int) -> None:
    """
    Set the (maximal) number of processors to use in parallel.

    The default value of ``0`` means using all the available physical processors. Note that if hyper-threading is
    enabled, this would be less than (typically half of) the number of logical processors in the system. This is
    intentional, as there's no value - actually, negative value - in running multiple heavy computations on
    hyper-threads of the same physical processor. If the number of physical processors can't be determined
    (``psutil`` reports ``None`` on some platforms), this falls back to the number of logical processors reported by
    ``os.cpu_count()``, or to ``1`` if even that is unknown.

    Otherwise, the value is the actual (positive) number of processors to use. Override this by setting the
    ``METACELLS_PROCESSORS_COUNT`` environment variable or by invoking this function from the main process.

    Besides recording the count, this limits the ``threadpoolctl``-managed thread pools and the C++ extension to that
    many threads, and sets the ``OMP_NUM_THREADS`` and ``MKL_NUM_THREADS`` environment variables accordingly.
    """
    # Only the main process may change the global processor budget; workers get their share in ``_invocation``.
    assert IS_MAIN_PROCESS

    if processors == 0:
        # ``psutil.cpu_count(logical=False)`` returns ``None`` when undetermined; without a fallback the assertion
        # below would raise a ``TypeError`` (comparing ``None`` with ``0``) at import time.
        processors = psutil.cpu_count(logical=False) or os.cpu_count() or 1
    assert processors > 0

    global PROCESSORS_COUNT
    PROCESSORS_COUNT = processors

    threadpool_limits(limits=PROCESSORS_COUNT)
    xt.set_threads_count(PROCESSORS_COUNT)
    os.environ["OMP_NUM_THREADS"] = str(PROCESSORS_COUNT)
    os.environ["MKL_NUM_THREADS"] = str(PROCESSORS_COUNT)
if "sphinx" not in sys.argv[0]:
    # Initialize the processor budget at import time (skipped while Sphinx builds the documentation). ``0`` - the
    # default - means "all physical processors". An empty (or whitespace-only) ``METACELLS_PROCESSORS_COUNT`` is
    # treated as unset instead of failing the whole import with ``ValueError`` from ``int("")``.
    set_processors_count(int(os.environ.get("METACELLS_PROCESSORS_COUNT", "").strip() or "0"))
def get_processors_count() -> int:
    """
    Give the number of processors this process is allowed to use.

    In the main process this is the value chosen by :py:func:`set_processors_count`; inside a
    :py:func:`parallel_map` worker it is that worker's share.
    """
    allowed = PROCESSORS_COUNT
    assert allowed > 0
    return allowed
T = TypeVar("T")


@utd.expand_doc()
def parallel_map(
    function: Callable[[int], T],
    invocations: int,
    *,
    max_processors: int = 0,
    hide_from_progress_bar: bool = False,
) -> List[T]:
    """
    Execute ``function``, in parallel, ``invocations`` times. Each invocation is given the invocation's index as its
    single argument.

    For our simple pipelines, only the main process is allowed to execute functions in parallel processes, that is, we
    do not support nested ``parallel_map`` calls.

    This uses :py:func:`get_processors_count` processes. If ``max_processors`` (default: {max_processors}) is zero, use
    all available processors. Otherwise, further reduces the number of processes used to at most the specified value.

    If this ends up using a single process, runs the function serially. Otherwise, fork new processes to execute the
    function invocations (using ``multiprocessing.get_context('fork').Pool.imap_unordered``). The downside is that this
    is slow, and you need to set up **mutable** shared memory (e.g. for large results) in advance. The upside is that
    each of these processes starts with a shared memory copy(-on-write) of the full Python state, that is, all the
    inputs for the function are available "for free".

    The results are returned in invocation-index order, regardless of the order in which the sub-processes complete
    them. While the sub-processes run, the ``OMP_NUM_THREADS`` and ``MKL_NUM_THREADS`` environment variables are set
    to (an upper bound on) each sub-process's share of the processors; their previous values (or absence) are
    restored afterwards.

    If a progress bar is active at the time of invoking ``parallel_map``, and ``hide_from_progress_bar`` is not set,
    then it is assumed the parallel map will cover all the current (slice of) the progress bar, and it is reported into
    it in increments of ``1/invocations``.

    .. todo::

        It is currently only possible to invoke :py:func:`parallel_map` from the main application thread (that is, it
        does not nest).
    """
    if invocations == 0:
        return []

    assert function.__is_timed__  # type: ignore

    global IS_MAIN_PROCESS
    assert IS_MAIN_PROCESS

    global PROCESSES_COUNT
    PROCESSES_COUNT = min(PROCESSORS_COUNT, invocations)
    if max_processors != 0:
        assert max_processors > 0
        PROCESSES_COUNT = min(PROCESSES_COUNT, max_processors)

    if PROCESSES_COUNT == 1:
        return [function(index) for index in range(invocations)]

    # Reset the shared counter from which each forked sub-process claims its process index (see ``_invocation``).
    NEXT_PROCESS_INDEX.value = 0  # type: ignore

    global PARALLEL_FUNCTION
    assert PARALLEL_FUNCTION is None

    global MAP_INDEX
    MAP_INDEX += 1

    # Remember the thread-count variables exactly as we found them so they can be put back afterwards.
    thread_variables = ("OMP_NUM_THREADS", "MKL_NUM_THREADS")
    saved_values = {name: os.environ.get(name) for name in thread_variables}

    # Each sub-process gets (rounded up) its share of this process's processors. ``_invocation`` then computes the
    # exact share per sub-process; this is what forked children inherit before they do.
    threads_per_process = str(ceil(PROCESSORS_COUNT / PROCESSES_COUNT))

    try:
        for name in thread_variables:
            os.environ[name] = threads_per_process

        # Both are inherited by the forked sub-processes: the function to run, and ``None`` which makes each
        # sub-process work out its role (and its processors share) on its first invocation.
        PARALLEL_FUNCTION = function
        IS_MAIN_PROCESS = None

        results: List[Optional[T]] = [None] * invocations
        utm.flush_timing()
        with utm.timed_step("parallel_map"):
            utm.timed_parameters(index=MAP_INDEX, processes=PROCESSES_COUNT)
            with get_context("fork").Pool(PROCESSES_COUNT) as pool:
                # Results arrive in completion order; ``_invocation`` returns the index so each lands in its slot.
                for index, result in pool.imap_unordered(_invocation, range(invocations)):
                    if utp.has_progress_bar() and not hide_from_progress_bar:
                        utp.did_progress(1 / invocations)
                    results[index] = result
        return results  # type: ignore

    finally:
        IS_MAIN_PROCESS = True
        PARALLEL_FUNCTION = None
        for name, value in saved_values.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value
def _invocation(index: int) -> Tuple[int, Any]:
    """
    Run a single invocation of the current :py:func:`parallel_map` function inside a forked worker.

    The first call in each worker (recognized by ``IS_MAIN_PROCESS`` still holding the ``None`` inherited from the
    parent) first configures the worker via ``_initialize_worker``. Returns ``(index, result)`` so the parent can place
    results that complete out of order.
    """
    if IS_MAIN_PROCESS is None:
        _initialize_worker()

    mapped = PARALLEL_FUNCTION
    assert mapped is not None
    return index, mapped(index)


def _initialize_worker() -> None:
    """
    One-time setup of a forked worker: resolve its role, claim a unique process index, name its thread, and restrict
    its thread pools to its slice of the parent's processors.
    """
    global IS_MAIN_PROCESS, PROCESS_INDEX, PROCESSORS_COUNT

    IS_MAIN_PROCESS = os.getpid() == MAIN_PROCESS_PID
    assert not IS_MAIN_PROCESS

    # Atomic read-and-increment of the lock-protected shared counter (reset to 0 by the parent before forking).
    with NEXT_PROCESS_INDEX:
        PROCESS_INDEX = NEXT_PROCESS_INDEX.value  # type: ignore
        NEXT_PROCESS_INDEX.value += 1  # type: ignore

    current_thread().name = f"#{MAP_INDEX}.{PROCESS_INDEX}"
    utm.in_parallel_map(MAP_INDEX, PROCESS_INDEX)

    # Cut [0, PROCESSORS_COUNT) into PROCESSES_COUNT nearly-equal contiguous slices; keep the size of this worker's.
    parent_processors = PROCESSORS_COUNT
    slice_start = int(round(parent_processors * PROCESS_INDEX / PROCESSES_COUNT))
    slice_stop = int(round(parent_processors * (PROCESS_INDEX + 1) / PROCESSES_COUNT))
    PROCESSORS_COUNT = slice_stop - slice_start

    assert PROCESSORS_COUNT > 0
    utl.logger().debug("PROCESSORS: %s", PROCESSORS_COUNT)

    threadpool_limits(limits=PROCESSORS_COUNT)
    xt.set_threads_count(PROCESSORS_COUNT)
    worker_threads = str(PROCESSORS_COUNT)
    os.environ["OMP_NUM_THREADS"] = worker_threads
    os.environ["MKL_NUM_THREADS"] = worker_threads