Source code for pybnb.misc

"""
Miscellaneous utilities used for development.

Copyright by Gabriel A. Hackebeil (gabe.hackebeil@gmail.com).
"""
from types import FrameType
from typing import Union, Callable, Dict, List, Tuple, Any, Optional, IO, ContextManager
import logging
import signal
import numbers
import math


# Signature of a signal handler callable: it is invoked with the signal
# number and the stack frame that was executing when the signal arrived.
Handler = Callable[[int, FrameType], None]


def _cast_to_float_or_int(x):
    # type: (Union[int, float]) -> Union[int, float]
    """Casts a number to a float or int built-in type. Makes
    a reasonable attempt to preserve integrality."""
    if type(x) in (float, int):
        return x
    elif isinstance(x, numbers.Integral):
        x_ = int(x)
        assert x_ == x
        return x_
    else:
        x_ = float(x)
        assert x_ == x
        return x_


class MPI_InterruptHandler(object):
    """Context manager that temporarily routes SIGINT (and SIGUSR1,
    when the platform defines it) to a user-supplied handler and
    puts the previously registered handlers back on release."""

    _sigs = [signal.SIGINT]
    if hasattr(signal, "SIGUSR1"):
        # not available on windows
        _sigs.append(signal.SIGUSR1)

    __slots__ = ("_released", "_original_handlers", "_handler", "_disable")

    def __init__(self, handler, disable=False):
        # type: (Handler, bool) -> None
        self._released = True
        self._original_handlers = []  # type: List[Tuple[int, Any]]
        self._handler = handler
        self._disable = disable

    def __enter__(self):
        # type: () -> MPI_InterruptHandler
        if not self._disable:
            self._released = False
            # remember what was registered so release() can undo it
            saved = []
            for signum in self._sigs:
                saved.append((signum, signal.getsignal(signum)))
            self._original_handlers = saved

            def _forward(signum, frame):
                # type: (int, FrameType) -> None
                # run the user handler once, then restore the originals
                self._handler(signum, frame)
                self.release()

            for signum in self._sigs:
                signal.signal(signum, _forward)
        return self

    def __exit__(self, type, value, tb):
        # type: (Any, Any, Any) -> None
        self.release()

    def release(self):
        # type: () -> None
        """Restore the handlers saved on entry (no-op if already
        released or never entered)."""
        if self._released:
            return
        for signum, previous in self._original_handlers:
            # getsignal() reports None for handlers it cannot hand
            # back; those are left alone
            if previous is not None:
                signal.signal(signum, previous)
        self._released = True
def metric_format(num, unit="s", digits=1, align_unit=False):
    # type: (float, str, int, bool) -> str
    """Format and scale output with metric prefixes.

    Parameters
    ----------
    num : float or None
        The value to format. ``None`` is rendered as
        ``"<unknown>"``. Negative values are scaled by their
        magnitude and keep their sign.
    unit : str
        Unit label placed after the metric prefix. (default: "s")
    digits : int
        Number of digits printed after the decimal point.
        (default: 1)
    align_unit : bool
        If True, a result that needs no prefix gets one trailing
        space, making it as wide as a result carrying a
        one-character prefix. (default: False)

    Returns
    -------
    str
        The scaled number followed by a space, the prefix (if
        any) and the unit.

    Example
    -------

    >>> metric_format(0)
    '0.0 s'
    >>> metric_format(0, align_unit=True)
    '0.0 s '
    >>> metric_format(0.002, unit='B')
    '2.0 mB'
    >>> metric_format(2001, unit='B')
    '2.0 KB'
    >>> metric_format(2001, unit='B', digits=3)
    '2.001 KB'
    >>> metric_format(-2001, unit='B')
    '-2.0 KB'

    """
    if num is None:
        return "<unknown>"
    prefix = ""
    # Decide on the magnitude so that negative values take the same
    # path as their positive counterparts.
    if (abs(num) >= 1.0) or (num == 0.0):
        if abs(num) >= 1000.0:
            # NOTE(review): "Yi" is kept for output compatibility; the
            # SI prefix for 10**24 is "Y" - confirm before changing.
            for p in ["K", "M", "G", "T", "P", "E", "Z", "Yi"]:
                # Scale first, then test, so the value is never
                # divided past the last available prefix.
                num /= 1000.0
                prefix = p
                if abs(num) < 1000.0:
                    break
    else:
        for p in ["m", "u", "n", "p", "f"]:
            # Same scale-then-test order: values smaller than the
            # last prefix stay expressed in that prefix.
            num *= 1000.0
            prefix = p
            # rounding guards against e.g. 0.9999999 being pushed
            # down to the next prefix
            if abs(round(num, digits + 3)) >= 1:
                break
    number_format = "%." + str(digits) + "f"
    if (prefix == "") and align_unit:
        return (number_format + " %s ") % (num, unit)
    return (number_format + " %s%s") % (num, prefix, unit)
def time_format(num, digits=1, align_unit=False):
    # type: (float, int, bool) -> str
    """Format and scale output according to standard time
    units.

    Parameters
    ----------
    num : float or None
        A duration in seconds. ``None`` is rendered as
        ``"<unknown>"``. Negative values are scaled by their
        magnitude and keep their sign.
    digits : int
        Number of digits printed after the decimal point.
        (default: 1)
    align_unit : bool
        If True, a result with a one-character unit (s, m, h,
        d) gets one trailing space, making it as wide as a
        result with a two-character unit. (default: False)

    Returns
    -------
    str
        The scaled number followed by a space and the unit.

    Example
    -------

    >>> time_format(0)
    '0.0 s'
    >>> time_format(0, align_unit=True)
    '0.0 s '
    >>> time_format(0.002)
    '2.0 ms'
    >>> time_format(2001)
    '33.4 m'
    >>> time_format(2001, digits=3)
    '33.350 m'
    >>> time_format(-120)
    '-2.0 m'

    """
    if num is None:
        return "<unknown>"
    unit = "s"
    # Decide on the magnitude so that negative durations take the
    # same path as positive ones.
    if (abs(num) >= 1.0) or (num == 0.0):
        # seconds -> minutes -> hours -> days, stopping at the first
        # unit the value does not fill
        for divisor, larger_unit in ((60.0, "m"), (60.0, "h"), (24.0, "d")):
            if abs(num) < divisor:
                break
            num /= divisor
            unit = larger_unit
    else:
        for p in ["ms", "us", "ns", "ps", "fs"]:
            # Scale first, then test, so the value is never scaled
            # past the last available unit.
            num *= 1000.0
            unit = p
            if abs(round(num, digits + 3)) >= 1:
                break
    number_format = "%." + str(digits) + "f"
    if (len(unit) == 1) and align_unit:
        return (number_format + " %s ") % (num, unit)
    return (number_format + " %s") % (num, unit)
def get_gap_labels(gap, key="gap", format="f"):
    # type: (float, str, str) -> Tuple[int, str, str]
    """Build format strings wide and precise enough to print a
    given gap tolerance.

    Returns a 3-tuple: the field width, a right-aligned format
    string for the column label, and a right-aligned format
    string for the number (using `format` as the presentation
    type). Both format strings reference the field named `key`."""
    width = 10
    precision = 2
    # only the fractional part determines how many digits are needed
    fraction = gap - math.floor(gap)
    if fraction != 0:
        while (width < 20) and (fraction < (10 ** (2 - precision))):
            precision += 1
            # widen the column once the digits start crowding it
            if width - precision < 5:
                width += 1
    field_prefix = "{" + key + ":>" + str(width)
    label_format = field_prefix + "}"
    number_format = field_prefix + "." + str(precision) + format + "}"
    return width, label_format, number_format
class _NullCM(object): """A context manager that does nothing""" def __init__(self, obj): # type: (Any) -> None self.obj = obj def __enter__(self): # type: () -> Any return self.obj def __exit__(self, *args): # type: (Any) -> None pass
def as_stream(stream, mode="w", **kwds):
    # type: (Union[IO, str], str, Any) -> ContextManager[IO]
    """Accept either a filename or a file-like object and return
    something usable in a with statement.

    Parameters
    ----------
    stream : file-like object or string
        An existing file-like object or the name of a file
        to open.
    mode : string
        Assigned to the mode keyword of the built-in
        function ``open`` when the `stream` argument is a
        filename. (default: "w")
    **kwds
        Additional keywords passed to the built-in function
        ``open`` when the `stream` argument is a filename.

    Returns
    -------
    file-like object
        When `stream` is a string, the file returned by
        ``open``. Otherwise a dummy context wrapping the
        given object, so that it is not closed upon exit of
        the with block.

    Example
    -------

    >>> import tempfile
    >>> with tempfile.NamedTemporaryFile() as f:
    ...     # pass a file
    ...     with as_stream(f) as g:
    ...         assert g is f
    ...     assert not f.closed
    ...     f.close()
    ...     # pass a filename
    ...     with as_stream(f.name) as g:
    ...         assert not g.closed
    ...     assert g.closed

    """
    import six

    # anything that is not a string is treated as an already-open
    # stream that the caller owns
    if not isinstance(stream, six.string_types):
        return _NullCM(stream)
    return open(stream, mode=mode, **kwds)
def get_default_args(func):
    # type: (Callable[..., Any]) -> Dict[str, Any]
    """Get the default arguments for a function as a
    dictionary mapping argument name to default value.

    Parameters
    ----------
    func : callable
        The function (or other callable accepted by the
        ``inspect`` module) to examine.

    Returns
    -------
    dict
        Maps the name of every parameter that declares a
        default to that default value. Parameters without a
        default are omitted.

    Example
    -------

    >>> def f(a, b=None):
    ...     pass
    >>> get_default_args(f)
    {'b': None}

    """
    import inspect

    # Detect the feature instead of the interpreter version: this
    # avoids a third-party import and falls back cleanly wherever
    # inspect.signature does not exist.
    if hasattr(inspect, "signature"):
        parameters = inspect.signature(func).parameters
        return {
            name: parameter.default
            for name, parameter in parameters.items()
            if parameter.default is not inspect.Parameter.empty
        }
    # legacy interpreters: defaults line up with the trailing
    # positional argument names
    spec = inspect.getargspec(func)
    if spec.defaults is None:
        return {}
    return dict(zip(spec.args[-len(spec.defaults) :], spec.defaults))
def get_keyword_docs(doc):
    # type: (str) -> Dict[str, Dict[str, Any]]
    """Parses a numpy-style docstring to summarize
    information in the 'Parameters' section into a dictionary.

    Parameters
    ----------
    doc : str
        The docstring to parse. It must contain a
        ``Parameters`` heading underlined with ten dashes, and
        the section must be terminated by a blank line.

    Returns
    -------
    dict
        Maps each parameter name to a dictionary with the key
        ``"doc"`` (the description text joined onto one line),
        plus ``"default"`` when the description ends with
        ``(default: <expr>)`` and ``"choices"`` when the type
        is written as ``{choice, ...}``.

    Notes
    -----
    Default values and choices are evaluated with ``eval``, so
    only trusted docstrings should be passed in. Malformed input
    is reported through ``AssertionError``.
    """
    import re

    header_pattern = r".+ : .+(, optional)?"
    choices_pattern = r".+ : \{.+\}(, optional)?"
    default_pattern = r"\(default: .*\)"

    lines = doc.splitlines()
    # Locate the "Parameters" heading and its underline; the section
    # body starts on the line after the underline. The bounds check
    # keeps a trailing "Parameters" line from raising IndexError.
    for i_start, line in enumerate(lines):
        if (
            (line.strip() == "Parameters")
            and (i_start + 1 < len(lines))
            and (lines[i_start + 1].strip() == "----------")
        ):
            i_start = i_start + 2
            break
    else:  # pragma:nocover
        assert False
    # The section stops at the first blank line. The first two body
    # lines are skipped when looking for that terminator.
    for i_stop, line in enumerate(lines[i_start:], i_start):
        if i_stop <= i_start + 1:
            continue
        if line.strip() == "":
            break
    else:  # pragma:nocover
        assert False

    args = {}
    choices = {}
    last = None
    for i in range(i_start, i_stop):
        is_header = re.match(header_pattern, lines[i])
        if i == i_start:
            # the section must open with a "name : type" line
            assert is_header
        if is_header:
            if i != i_start:
                # finish the description of the previous parameter
                assert last is not None
                args[last] = args[last].strip()
            last = lines[i].split(" : ")[0].strip()
            if re.match(choices_pattern, lines[i]):
                assert lines[i].count("{") == 1
                assert lines[i].count("}") == 1
                opts = lines[i].split("{")[1].split("}")[0]
                # NOTE: eval on docstring text - trusted input only
                choices[last] = [eval(c.strip()) for c in opts.split(",")]
            args[last] = ""
        else:
            # continuation line of the current parameter description
            assert last is not None
            args[last] += lines[i].strip() + " "
    assert last is not None
    args[last] = args[last].strip()

    data = {}  # type: Dict[str, Dict[str, Any]]
    for key, val in args.items():
        data[key] = {"doc": val}
        default_ = re.search(default_pattern, val)
        if default_ is not None:
            # Split only on the first ": " so that default values
            # which themselves contain ": " (e.g. a dict) survive.
            split_ = default_.group(0)[1:-1].split(": ", 1)
            assert len(split_) == 2
            # NOTE: eval on docstring text - trusted input only
            data[key]["default"] = eval(split_[1])
            key_doc = re.split(default_pattern, val)
            assert len(key_doc) == 2
            # the default must be the last thing in the description
            assert key_doc[1].strip() == ""
            data[key]["doc"] = key_doc[0].strip()
        if key in choices:
            data[key]["choices"] = choices[key]
    return data
class _simple_stdout_filter(logging.Filter): def filter(self, record): # type: (logging.LogRecord) -> bool # only show WARNING or below return bool(record.levelno <= logging.WARNING) class _simple_stderr_filter(logging.Filter): def filter(self, record): # type: (logging.LogRecord) -> bool # only show ERROR or above return bool(record.levelno >= logging.ERROR)
def get_simple_logger(
    name="<local>",
    filename=None,
    stream=None,
    console=True,
    level=logging.INFO,
    formatter=None,
):
    # type: (str, Optional[str], Any, bool, int, Optional[logging.Formatter]) -> logging.Logger
    """Build a standalone ``logging.Logger`` that writes to any
    combination of a file, a stream, and the console, or that
    hides all output when none of them is requested.

    Parameters
    ----------
    name : string
        The name assigned to the ``logging.Logger``
        instance. (default: "<local>")
    filename : string, optional
        The name of a file to write to. (default: None)
    stream : file-like object, optional
        A file-like object to write to. (default: None)
    console : bool, optional
        If True, the logger will be configured to print
        output to the console: records at WARNING level or
        below go to stdout, records at ERROR level or above
        go to stderr. (default: True)
    level : int, optional
        The logging level to use. It is applied to the
        logger and to every handler. (default:
        ``logging.INFO``)
    formatter : ``logging.Formatter``, optional
        The logging formatter to use. When given, it is
        installed on every handler. (default: None)

    Returns
    -------
    ``logging.Logger``
        A logging object
    """
    log = logging.Logger(name, level=level)

    def _attach(handler, record_filter=None):
        # give the handler the requested level (and filter, if any)
        # before registering it on the logger
        handler.setLevel(level)
        if record_filter is not None:
            handler.addFilter(record_filter)
        log.addHandler(handler)

    if filename is not None:
        _attach(logging.FileHandler(filename))
    if stream is not None:
        _attach(logging.StreamHandler(stream))
    if console:
        import sys

        _attach(logging.StreamHandler(sys.stdout), _simple_stdout_filter())
        _attach(logging.StreamHandler(sys.stderr), _simple_stderr_filter())
    if formatter is not None:
        for registered in log.handlers:
            registered.setFormatter(formatter)
    # with no destination at all, silence the logger entirely
    if not ((filename is not None) or (stream is not None) or console):
        log.disabled = True
    return log
def _run_command_line_solver(problem, args):
    """Solve `problem` with ``pybnb.solve`` using parsed
    command-line arguments.

    Parameters
    ----------
    problem
        The problem object handed to ``pybnb.solve``. When
        ``args.nested_solver`` is set it is first wrapped in a
        ``pybnb.futures.NestedSolver`` configured from the
        ``nested_*`` arguments.
    args
        The namespace returned by the argument parser built in
        ``create_command_line_solver``. Every attribute that is
        not a command-line-only setting is forwarded to
        ``pybnb.solve`` as a keyword.

    Returns
    -------
    The object returned by ``pybnb.solve``.
    """
    import pybnb

    if args.nested_solver:
        problem = pybnb.futures.NestedSolver(
            problem,
            node_limit=args.nested_node_limit,
            time_limit=args.nested_time_limit,
            queue_limit=args.nested_queue_limit,
            track_bound=args.nested_track_bound,
            queue_strategy=args.nested_queue_strategy,
        )
    else:
        # Without --nested-solver the nested settings have no effect;
        # warn about each one that was moved away from its default.
        nested_solver_defaults = get_default_args(pybnb.futures.NestedSolver.__init__)
        # (attribute on args, NestedSolver keyword, command-line flag)
        nested_settings = (
            ("nested_node_limit", "node_limit", "--nested-node-limit"),
            ("nested_time_limit", "time_limit", "--nested-time-limit"),
            ("nested_queue_limit", "queue_limit", "--nested-queue-limit"),
            ("nested_track_bound", "track_bound", "--nested-disable-track-bound"),
            ("nested_queue_strategy", "queue_strategy", "--nested-queue-strategy"),
        )
        for attribute, keyword, flag in nested_settings:
            if (
                getattr(args, attribute) != nested_solver_defaults[keyword]
            ):  # pragma:nocover
                logging.getLogger("pybnb").warning(
                    "The user-specified %s "
                    "setting will be ignored. Did you forget the "
                    "--nested-solver flag?",
                    flag,
                )

    # Everything left in the namespace is a pybnb.solve keyword once
    # the command-line-only settings are stripped.
    solve_kwds = dict(vars(args))
    for name in (
        "disable_mpi",
        "nested_solver",
        "nested_node_limit",
        "nested_time_limit",
        "nested_queue_limit",
        "nested_track_bound",
        "nested_queue_strategy",
    ):
        del solve_kwds[name]
    # --profile is only registered when pstats could be imported, so
    # the attribute may legitimately be missing.
    solve_kwds.pop("profile", None)
    if args.disable_mpi:
        results = pybnb.solve(problem, comm=None, **solve_kwds)
    else:
        results = pybnb.solve(problem, **solve_kwds)
    return results
def create_command_line_solver(problem, parser=None):
    """Convert a given problem implementation to a
    command-line example by exposing the
    :func:`pybnb.solver.solve` function arguments using
    argparse.

    Parameters
    ----------
    problem
        The problem object to solve. It is passed through to
        the solver unchanged (or wrapped in a nested solver
        when ``--nested-solver`` is given).
    parser : ``argparse.ArgumentParser``, optional
        An existing parser to add the solver options to. When
        None, a new parser is created. (default: None)

    Notes
    -----
    The function parses the process command line with
    ``parser.parse_args()``, runs the solve (optionally under
    the profiler when ``--profile`` is positive), and returns
    None. Defaults and help texts are taken from the signatures
    and docstrings of ``pybnb.Solver`` and
    ``pybnb.futures.NestedSolver``; assertions verify that the
    two stay in sync.

    Raises
    ------
    ImportError
        If PyYAML is not installed, or if mpi4py is not
        installed and ``--disable-mpi`` was not given.
    """
    # argparse is needed for the custom Action class below even when
    # the caller supplies the parser, so import it unconditionally.
    import argparse
    import os
    import tempfile

    # for profiling
    try:
        import cProfile as profile
    except ImportError:  # pragma:nocover
        import profile
    try:
        import pstats

        pstats_available = True
    except ImportError:  # pragma:nocover
        pstats_available = False
    import pybnb

    try:
        import yaml  # noqa: F401
    except ImportError:  # pragma:nocover
        raise ImportError(
            "The PyYAML module is required to run the command-line solver."
        )
    from pybnb.convergence_checker import _auto_queue_tolerance

    if parser is None:
        parser = argparse.ArgumentParser(
            description="Run parallel branch and bound",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )

    #
    # Solver.__init__ options
    #
    solver_init_defaults = get_default_args(pybnb.Solver.__init__)
    solver_init_docs = get_keyword_docs(pybnb.Solver.__doc__)
    assert set(solver_init_defaults.keys()) == set(solver_init_docs.keys())
    # comm is controlled through --disable-mpi instead
    solver_init_defaults.pop("comm")
    solver_init_docs.pop("comm")
    assert len(solver_init_defaults) == len(solver_init_docs)
    for key in solver_init_defaults:
        assert solver_init_defaults[key] == solver_init_docs[key]["default"]
        assert "choices" not in solver_init_docs[key]
    parser.add_argument(
        "--dispatcher-rank",
        type=int,
        default=solver_init_defaults.pop("dispatcher_rank"),
        help=solver_init_docs["dispatcher_rank"]["doc"],
    )
    # every default is popped as its option is added, so anything left
    # over is a keyword that has no command-line option yet
    assert len(solver_init_defaults) == 0, str(solver_init_defaults)

    #
    # Solver.solve options
    #
    solve_defaults = get_default_args(pybnb.Solver.solve)
    solve_docs = get_keyword_docs(pybnb.Solver.solve.__doc__)
    solve_docs.pop("problem")
    assert set(solve_defaults.keys()) == set(solve_docs.keys())
    # keywords that are not exposed on the command line
    solve_defaults.pop("best_node")
    solve_docs.pop("best_node")
    solve_defaults.pop("initialize_queue")
    solve_docs.pop("initialize_queue")
    solve_defaults.pop("scale_function")
    solve_docs.pop("scale_function")
    solve_defaults.pop("log")
    solve_docs.pop("log")
    assert len(solve_defaults) == len(solve_docs)
    for key in solve_defaults:
        if key == "queue_tolerance":
            assert "default" not in solve_docs[key]
            assert solve_defaults[key] is _auto_queue_tolerance
        else:
            assert solve_defaults[key] == solve_docs[key]["default"]
        assert "choices" not in solve_docs[key]
        if key == "queue_strategy":
            solve_docs[key]["choices"] = [v_.value for v_ in pybnb.QueueStrategy]
            assert "**(D)**" in solve_docs[key]["doc"]
            solve_docs[key]["doc"] = (
                "**(D)** Sets the strategy for prioritizing "
                "nodes in the central dispatcher queue. Can "
                "also be set to a comma-separated list of "
                "choices to define a lexicographic sorting "
                "strategy."
            )

    class _QueueStrategyJoin(argparse.Action):  # pragma:nocover
        """Stores a single queue strategy name as a string and a
        comma-separated list of names as a tuple."""

        def __init__(self, option_strings, dest, nargs=None, **kwargs):
            if nargs is not None:
                raise ValueError("nargs not allowed")
            super(_QueueStrategyJoin, self).__init__(option_strings, dest, **kwargs)

        def __call__(self, parser, namespace, values, option_string=None):
            # Write to self.dest: this action serves both
            # --queue-strategy and --nested-queue-strategy, and each
            # must update its own attribute.
            if values in solve_docs["queue_strategy"]["choices"]:
                setattr(namespace, self.dest, values)
            else:
                assert "," in values
                vals = tuple(v.strip() for v in values.split(",") if v.strip())
                assert len(vals) > 0
                assert all(v in solve_docs["queue_strategy"]["choices"] for v in vals)
                setattr(namespace, self.dest, vals)

    class _QueueStrategyChoices(object):  # pragma:nocover
        """Container used for the argparse ``choices`` check that
        also accepts comma-separated combinations of the valid
        queue strategy names."""

        def __contains__(self, val):
            if val in solve_docs["queue_strategy"]["choices"]:
                return True
            if "," not in val:
                return False
            vals = [v.strip() for v in val.split(",") if v.strip()]
            if len(vals) == 0:
                return False
            return all(v in solve_docs["queue_strategy"]["choices"] for v in vals)

        def __iter__(self):
            return solve_docs["queue_strategy"]["choices"].__iter__()

    parser.add_argument(
        "--best-objective",
        type=float,
        default=solve_defaults.pop("best_objective"),
        help=solve_docs["best_objective"]["doc"],
    )
    tmp_ = solve_defaults.pop("disable_objective_call")
    assert not tmp_
    del tmp_
    parser.add_argument(
        "--disable-objective-call",
        default=False,
        action="store_true",
        help=solve_docs["disable_objective_call"]["doc"],
    )
    parser.add_argument(
        "--queue-strategy",
        type=str,
        choices=_QueueStrategyChoices(),
        action=_QueueStrategyJoin,
        default=solve_defaults.pop("queue_strategy"),
        help=solve_docs["queue_strategy"]["doc"],
    )
    parser.add_argument(
        "--absolute-gap",
        type=float,
        default=solve_defaults.pop("absolute_gap"),
        help=solve_docs["absolute_gap"]["doc"],
    )
    parser.add_argument(
        "--relative-gap",
        type=float,
        default=solve_defaults.pop("relative_gap"),
        help=solve_docs["relative_gap"]["doc"],
    )
    parser.add_argument(
        "--objective-stop",
        type=float,
        default=solve_defaults.pop("objective_stop"),
        help=solve_docs["objective_stop"]["doc"],
    )
    parser.add_argument(
        "--bound-stop",
        type=float,
        default=solve_defaults.pop("bound_stop"),
        help=solve_docs["bound_stop"]["doc"],
    )
    parser.add_argument(
        "--node-limit",
        type=int,
        default=solve_defaults.pop("node_limit"),
        help=solve_docs["node_limit"]["doc"],
    )
    parser.add_argument(
        "--time-limit",
        type=float,
        default=solve_defaults.pop("time_limit"),
        help=solve_docs["time_limit"]["doc"],
    )
    parser.add_argument(
        "--queue-limit",
        type=int,
        default=solve_defaults.pop("queue_limit"),
        help=solve_docs["queue_limit"]["doc"],
    )
    # track_bound defaults to True, so the flag switches it off
    val = solve_defaults.pop("track_bound")
    assert val
    parser.add_argument(
        "--disable-track-bound",
        action="store_false",
        dest="track_bound",
        default=True,
        help=solve_docs["track_bound"]["doc"],
    )

    def _float_or_None(val):  # pragma:nocover
        # allow the literal text "None" on the command line
        if val == "None":
            return None
        return float(val)

    parser.add_argument(
        "--queue-tolerance",
        type=_float_or_None,
        default=solve_defaults.pop("queue_tolerance"),
        help=solve_docs["queue_tolerance"]["doc"],
    )
    parser.add_argument(
        "--branch-tolerance",
        type=_float_or_None,
        default=solve_defaults.pop("branch_tolerance"),
        help=solve_docs["branch_tolerance"]["doc"],
    )
    parser.add_argument(
        "--comparison-tolerance",
        type=float,
        default=solve_defaults.pop("comparison_tolerance"),
        help=solve_docs["comparison_tolerance"]["doc"],
    )
    parser.add_argument(
        "--log-interval-seconds",
        type=float,
        default=solve_defaults.pop("log_interval_seconds"),
        help=solve_docs["log_interval_seconds"]["doc"],
    )
    val = solve_defaults.pop("log_new_incumbent")
    assert val
    parser.add_argument(
        "--disable-log-new-incumbent",
        action="store_false",
        dest="log_new_incumbent",
        default=True,
        help=solve_docs["log_new_incumbent"]["doc"],
    )
    val = solve_defaults.pop("disable_signal_handlers")
    assert not val
    parser.add_argument(
        "--disable-signal-handlers",
        action="store_true",
        default=False,
        help=solve_docs["disable_signal_handlers"]["doc"],
    )
    assert len(solve_defaults) == 0, str(solve_defaults)

    #
    # options that exist only on the command line
    #
    parser.add_argument(
        "--log-filename",
        type=str,
        default=None,
        help=("A filename to store solver output into."),
    )
    parser.add_argument(
        "--results-filename",
        type=str,
        default=None,
        help=(
            "When set, saves the solver results into a "
            "YAML-formatted file with the given name."
        ),
    )
    parser.add_argument(
        "--disable-mpi",
        default=False,
        action="store_true",
        help=(
            "Do not attempt to import mpi4py.MPI. "
            "Enabling this option is equivalent to "
            "creating a Solver with `comm=None`."
        ),
    )
    if pstats_available:
        parser.add_argument(
            "--profile",
            dest="profile",
            type=int,
            default=0,
            help=(
                "Enable profiling by setting this "
                "option to a positive integer (the "
                "maximum number of functions to "
                "profile)."
            ),
        )
    parser.add_argument(
        "--version", action="version", version="pybnb " + str(pybnb.__version__)
    )

    #
    # NestedSolver options
    #
    nested_solver_defaults = get_default_args(pybnb.futures.NestedSolver.__init__)
    nested_solver_docs = get_keyword_docs(pybnb.futures.NestedSolver.__doc__)
    nested_solver_docs.pop("problem")
    assert len(nested_solver_defaults) == len(nested_solver_docs)
    parser.add_argument(
        "--nested-solver",
        action="store_true",
        default=False,
        help=(
            "**(W)** Wraps the problem in a "
            ":class:`pybnb.futures.NestedSolver` object. "
            "See additional --nested-solver-* options."
        ),
    )
    parser.add_argument(
        "--nested-node-limit",
        type=int,
        default=nested_solver_defaults.pop("node_limit"),
        help=nested_solver_docs["node_limit"]["doc"],
    )
    parser.add_argument(
        "--nested-time-limit",
        type=float,
        default=nested_solver_defaults.pop("time_limit"),
        help=nested_solver_docs["time_limit"]["doc"],
    )
    parser.add_argument(
        "--nested-queue-limit",
        type=int,
        default=nested_solver_defaults.pop("queue_limit"),
        help=nested_solver_docs["queue_limit"]["doc"],
    )
    val = nested_solver_defaults.pop("track_bound")
    assert val
    parser.add_argument(
        "--nested-disable-track-bound",
        action="store_false",
        dest="nested_track_bound",
        default=True,
        help=nested_solver_docs["track_bound"]["doc"],
    )
    parser.add_argument(
        "--nested-queue-strategy",
        type=str,
        choices=_QueueStrategyChoices(),
        action=_QueueStrategyJoin,
        default=nested_solver_defaults.pop("queue_strategy"),
        help=nested_solver_docs["queue_strategy"]["doc"],
    )
    assert len(nested_solver_defaults) == 0, str(nested_solver_defaults)

    args = parser.parse_args()

    try:
        import mpi4py  # noqa: F401
    except ImportError:  # pragma:nocover
        if not args.disable_mpi:
            raise ImportError(
                "The mpi4py module is not "
                "available. To run this script "
                "without it, use the '--disable-mpi' "
                "option"
            )

    if pstats_available and (args.profile):  # pragma:nocover
        #
        # Call the main routine with profiling.
        #
        # the profiler writes its statistics to a temporary file that
        # is removed again once the reports have been printed
        handle, tfile = tempfile.mkstemp()
        os.close(handle)
        try:
            profile.runctx(
                "_run_command_line_solver(problem, args)", globals(), locals(), tfile
            )
            p = pstats.Stats(tfile).strip_dirs()
            p.sort_stats("time", "cumulative")
            p = p.print_stats(args.profile)
            p.print_callers(args.profile)
            p.print_callees(args.profile)
            p = p.sort_stats("cumulative", "calls")
            p.print_stats(args.profile)
            p.print_callers(args.profile)
            p.print_callees(args.profile)
            p = p.sort_stats("calls")
            p.print_stats(args.profile)
            p.print_callers(args.profile)
            p.print_callees(args.profile)
        finally:
            os.remove(tfile)
    else:
        _run_command_line_solver(problem, args)