Source code for pybnb.mpi_utils

"""
Various utility functions for MPI.

Copyright by Gabriel A. Hackebeil (gabe.hackebeil@gmail.com).
"""
from typing import Optional, List, Any
import array

from six.moves import xrange as range

# Shared ``[buffer, datatype]`` pair used by recv_nothing and send_nothing
# when receiving or sending an empty message. It starts out as None and is
# built on first use so that importing this module does not trigger an
# early mpi4py import.
_nothing = None  # type: Optional[List[Any]]


# avoids generating a deprecation warning in python 3.7
def _array_to_string(out):
    """Decode the raw bytes stored in an array as a UTF-8 string."""
    # array.tobytes was added in python 3.2; it is preferred whenever it
    # exists and the older tostring spelling is only used as a fallback
    raw = out.tobytes() if hasattr(out, "tobytes") else out.tostring()
    return raw.decode("utf8")


class Message(object):
    """A helper class for probing for and receiving messages.

    A single instance of this class is meant to be reused.

    Parameters
    ----------
    comm : :class:`mpi4py.MPI.Comm`
        The MPI communicator to use.
    """

    __slots__ = ("status", "data", "comm")

    def __init__(self, comm):
        # imported here rather than at module level so that importing
        # this module does not import mpi4py
        import mpi4py.MPI

        self.comm = comm
        # a single status object is reused for every probe / receive
        self.status = mpi4py.MPI.Status()
        self.data = None

    def probe(self, **kwds):
        """Perform a blocking test for a message.

        The status object owned by this instance is updated by the
        probe and any previously received data is discarded.

        Parameters
        ----------
        **kwds
            Additional keywords (e.g., ``source`` or ``tag``) that
            are forwarded to ``comm.Probe``. The ``status`` keyword
            is supplied by this method and must not be given.
        """
        # forward the caller's keywords; silently ignoring them would
        # turn a probe for a specific source / tag into a probe for any
        # message
        self.comm.Probe(status=self.status, **kwds)
        self.data = None

    def recv(self, datatype=None, data=None):
        """Complete the receive for the most recent message probe.

        When the probed message is not empty, the received data is
        stored on the ``data`` attribute (as a numeric array or a
        string, depending on the datatype keyword). When the probed
        message is empty, it is consumed and the ``data`` attribute
        is left unchanged.

        Parameters
        ----------
        datatype : {``mpi4py.MPI.DOUBLE``, ``mpi4py.MPI.CHAR``}, optional
            An MPI datatype used to interpret the received
            data. If None, the element count is queried without
            a datatype argument and None is forwarded to
            :func:`recv_data`. (default: None)
            NOTE(review): earlier documentation stated that
            ``mpi4py.MPI.DOUBLE`` is used in this case; this
            method does not select it explicitly - confirm.
        data : array.array or None, optional
            An existing data array to store data into. If None,
            one will be created. (default: None)
        """
        assert not self.status.Get_error()
        if datatype is None:
            count = self.status.Get_count()
        else:
            count = self.status.Get_count(datatype=datatype)
        if count == 0:
            # nothing to unpack, but the message still must be received
            recv_nothing(self.comm, self.status)
        else:
            self.data = recv_data(self.comm, self.status, datatype=datatype, out=data)

    @property
    def tag(self):
        """The tag stored on the status object."""
        return self.status.Get_tag()

    @property
    def source(self):
        """The source rank stored on the status object."""
        return self.status.Get_source()


def recv_nothing(comm, status=None):
    """A helper function for receiving an empty message. This
    function is not thread safe.

    Parameters
    ----------
    comm : :class:`mpi4py.MPI.Comm`
        An MPI communicator.
    status : :class:`mpi4py.MPI.Status`, optional
        An MPI status object that has been populated with
        information about the message to be received via a
        probe. If None, a new status object will be created
        and an empty message will be expected from any
        source with any tag. (default: None)

    Returns
    -------
    status : :class:`mpi4py.MPI.Status`
        If the original status argument was not None, it
        will be returned after being updated by the
        receive. Otherwise, the status object that was
        created will be returned.
    """
    global _nothing
    import mpi4py.MPI

    # the shared empty buffer is created on first use (this unguarded
    # update of a module global is why the function is not thread safe)
    if _nothing is None:
        _nothing = [array.array("B", []), mpi4py.MPI.CHAR]
    if status is None:
        # no probe information is available: leave the source and tag
        # at the defaults of comm.Recv so any sender / tag is matched
        status = mpi4py.MPI.Status()
        comm.Recv(_nothing, status=status)
    else:
        # the probed message must really be empty before it is received
        # into the zero-length buffer
        assert not status.Get_error()
        assert status.Get_count(mpi4py.MPI.CHAR) == 0
        comm.Recv(
            _nothing, source=status.Get_source(), tag=status.Get_tag(), status=status
        )
    assert not status.Get_error()
    assert status.Get_count(mpi4py.MPI.CHAR) == 0
    return status


def send_nothing(comm, dest, tag=0):
    """Send an empty message carrying only a tag. This function
    is not thread safe.

    Parameters
    ----------
    comm : :class:`mpi4py.MPI.Comm`
        An MPI communicator.
    dest : int
        The rank of the process the message is sent to.
    tag : int, optional
        A valid MPI tag to attach to the message. (default: 0)
    """
    global _nothing
    from mpi4py import MPI

    if _nothing is None:
        # build the shared empty buffer the first time it is needed
        empty_buffer = array.array("B", [])
        _nothing = [empty_buffer, MPI.CHAR]
    comm.Send(_nothing, dest, tag=tag)


def recv_data(comm, status, datatype, out=None):
    """Receive numeric or string data that was sent with the
    lower-level buffer-based mpi4py routines.

    Parameters
    ----------
    comm : :class:`mpi4py.MPI.Comm`
        An MPI communicator.
    status : :class:`mpi4py.MPI.Status`
        An MPI status object that has been populated with
        information about the message to be received via a
        probe.
    datatype : :class:`mpi4py.MPI.Datatype`
        An MPI datatype used to interpret the received
        data. If the datatype is :obj:`mpi4py.MPI.CHAR`,
        the received data will be converted to a string.
    out : buffer-like object, optional
        A buffer-like object that is compatible with the
        datatype argument and can be passed to
        comm.Recv. Can only be left as None when the
        datatype is :obj:`mpi4py.MPI.CHAR`, in which case
        it must be left as None.

    Returns
    -------
    string or user-provided data buffer
    """
    from mpi4py import MPI

    assert not status.Get_error()
    count = status.Get_count(datatype)
    is_text = datatype == MPI.CHAR
    if is_text:
        # text is received into a freshly allocated byte array that is
        # sized to exactly fit the probed message
        assert out is None
        out = array.array("B", b"\0") * count
    assert not ((out is None) or (len(out) < count))
    comm.Recv(
        [out, datatype], source=status.Get_source(), tag=status.Get_tag(), status=status
    )
    assert not status.Get_error()
    if is_text:
        return _array_to_string(out)
    return out


def dispatched_partition(comm, items, root=0):
    """A generator that partitions the list of items across
    processes in the communicator. If the communicator size
    is greater than 1, the root process will be yielded no
    items and instead will serve them dynamically by sending
    list indices to workers as work requests are received.

    Parameters
    ----------
    comm : :class:`mpi4py.MPI.Comm` or None
        An MPI communicator or None in the serial
        processing case.
    items : list
        The list of items to partition. This function
        assumes each process has an identical copy of the
        items list. Therefore, items in the list are not
        transferred (only indices).
    root : integer, optional
        An integer indicating which process rank should be
        designated as the dispatcher. Must be 0 in the
        serial processing case. (default: 0)

    Yields
    ------
    Items from the list. In the serial case every item is
    yielded in order; otherwise each non-root process is
    yielded the items whose indices it is sent by the root.
    """
    assert root >= 0
    n_items = len(items)
    if n_items == 0:
        # nothing to hand out, so no communication takes place
        return

    if (comm is None) or (comm.size == 1):
        assert root == 0
        for item in items:
            yield item
        return

    from mpi4py import MPI

    # item indices travel as message tags and the value N is used as
    # the "no more work" tag, so N itself has to be a legal tag
    # (it would be pretty easy to refactor this code to avoid this
    # limitation)
    assert n_items <= MPI.COMM_WORLD.Get_attr(MPI.TAG_UB)
    empty_msg = [array.array("b", []), MPI.CHAR]

    if comm.rank != root:
        # worker: the tag of every message from the root is the index
        # of the next item to process
        status = MPI.Status()
        comm.Recv(empty_msg, source=root, status=status)
        if status.Get_tag() >= n_items:
            # there was no item left for this process; acknowledge so
            # the root can complete its matching receive
            comm.Send(empty_msg, root)
            return
        while status.Get_tag() < n_items:
            yield items[status.Get_tag()]
            # request more work and wait for the next index
            comm.Sendrecv(
                empty_msg, root, recvbuf=empty_msg, source=root, status=status
            )
        return

    # dispatcher: remember the last tag that was sent to each worker
    last_sent = {}
    pending = []
    next_index = 0
    # hand one index to every worker up front
    for worker in range(comm.size):
        if worker == root:
            continue
        last_sent[worker] = next_index
        pending.append(comm.Isend(empty_msg, worker, tag=next_index))
        next_index += 1
    # serve the remaining indices in the order work requests arrive
    status = MPI.Status()
    while next_index < n_items:
        comm.Recv(empty_msg, status=status)
        requester = status.Get_source()
        last_sent[requester] = next_index
        pending.append(comm.Isend(empty_msg, requester, tag=next_index))
        next_index += 1
    for worker in last_sent:
        if last_sent[worker] < n_items:
            # this worker is still processing a real item and will ask
            # for more; answer with the stop tag
            pending.append(comm.Isend(empty_msg, worker, tag=n_items))
        # every worker sends exactly one message that has not been
        # received yet (its last work request or its acknowledgement)
        pending.append(comm.Irecv(empty_msg, worker))
    MPI.Request.Waitall(pending)