# Source code for utils

"""
Utility functions providing basic functionality for the :py:mod:`cluster_generator` package.
"""
import functools
import logging
import multiprocessing
import os
import pathlib as pt
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np
import yaml
from more_itertools import always_iterable
from numpy.random import RandomState
from scipy.integrate import quad
from unyt import kpc
from unyt import physical_constants as pc
from unyt import unyt_array, unyt_quantity

# -- configuration directory -- #
_config_directory = os.path.join(pt.Path(__file__).parents[0], "bin", "config.yaml")
_bin_directory = os.path.join(pt.Path(__file__).parents[0], "bin")


# defining the custom yaml loader for unit-ed objects
def _yaml_unit_constructor(loader: yaml.FullLoader, node: yaml.nodes.MappingNode):
    """Build a :py:class:`unyt.unyt_array` from a ``!unyt`` tagged YAML mapping.

    The mapping must contain an ``input_scalar`` entry holding the value(s);
    every other entry is forwarded to :py:class:`unyt.unyt_array` as a keyword
    argument.

    Raises
    ------
    KeyError
        If the mapping has no ``input_scalar`` entry.
    """
    # ``deep=True`` forces nested nodes (e.g. a list given as ``input_scalar``)
    # to be fully constructed before we use them. With shallow construction
    # PyYAML returns an empty list here and only fills it in after this
    # constructor has already built the array from it.
    kw = loader.construct_mapping(node, deep=True)
    i_s = kw.pop("input_scalar")
    return unyt_array(i_s, **kw)


def _yaml_lambda_loader(loader: yaml.FullLoader, node: yaml.nodes.ScalarNode):
    """Turn a ``!lambda`` tagged YAML scalar into a Python object.

    The scalar's text (for instance ``"lambda x: x ** 2"``) is passed to
    :py:func:`eval` and whatever that produces is returned.

    .. warning::

        This executes arbitrary code taken from the YAML document. Only load
        trusted configuration files with a loader that has this constructor
        registered.
    """
    # SECURITY: eval() on document content -- never use with untrusted YAML.
    source = loader.construct_scalar(node)
    return eval(source)


def _get_loader():
    """Return the YAML loader class used to read the package configuration.

    The returned class is a private subclass of :py:class:`yaml.FullLoader`
    that additionally understands the ``!unyt`` and ``!lambda`` tags.
    Registering the tags on a subclass keeps the global ``yaml.FullLoader``
    untouched. Other code in the same process that loads YAML with
    ``FullLoader`` therefore does not silently gain the ``eval``-based
    ``!lambda`` constructor.
    """

    class _ClusterGeneratorLoader(yaml.FullLoader):
        """FullLoader subclass carrying the cluster_generator custom tags."""

    # PyYAML's add_constructor copies the inherited constructor table the first
    # time it is called on a subclass, so these registrations stay local to it.
    _ClusterGeneratorLoader.add_constructor("!unyt", _yaml_unit_constructor)
    _ClusterGeneratorLoader.add_constructor("!lambda", _yaml_lambda_loader)
    return _ClusterGeneratorLoader


# Load the package configuration once, at import time. The file is only read,
# so it is opened read-only: "r+" would require write permission and fail on
# read-only installations (e.g. a system-wide site-packages).
try:
    with open(_config_directory, "r", encoding="utf-8") as config_file:
        cgparams = yaml.load(config_file, _get_loader())

except FileNotFoundError as er:
    raise FileNotFoundError(
        f"Couldn't find the configuration file! Is it at {_config_directory}? Error = {er.__repr__()}"
    ) from er
except yaml.YAMLError as er:
    raise yaml.YAMLError(
        f"The configuration file is corrupted! Error = {er.__repr__()}"
    ) from er


# -- Main package logger -- #
# Messages go to stdout only when the configuration asks for it explicitly;
# every other value of the ``stream`` setting selects stderr.
if cgparams["system"]["logging"]["main"]["stream"] in ("STDOUT", "stdout"):
    stream = sys.stdout
else:
    stream = sys.stderr
cgLogger = logging.getLogger("cluster_generator")

# One stream handler, formatted with the configured format string.
cg_sh = logging.StreamHandler(stream=stream)
formatter = logging.Formatter(cgparams["system"]["logging"]["main"]["format"])
cg_sh.setFormatter(formatter)
cgLogger.addHandler(cg_sh)

# The level comes from the configuration. Propagation is switched off so that
# records are not also handed to ancestor loggers' handlers (e.g. the root).
cgLogger.setLevel(cgparams["system"]["logging"]["main"]["level"])
cgLogger.propagate = False

# Alias for ``cgLogger``; this module logs through ``mylog``.
mylog = cgLogger

# -- Setting up the developer debugger -- #
# When enabled *and* given an output directory, the development logger writes
# DEBUG-level records to a timestamped file in that directory. In every other
# case it is disabled and does not propagate, so it emits nothing.
devLogger = logging.getLogger("development_logger")

if cgparams["system"]["logging"]["developer"][
    "enabled"
]:  # --> We do want to use the development logger.
    # -- checking if the user has specified a directory -- #
    if cgparams["system"]["logging"]["developer"]["output_directory"] is not None:
        from datetime import datetime

        # One log file per import of this module, named by local time
        # (MM-DD-YY_HH-MM-SS).
        dv_fh = logging.FileHandler(
            os.path.join(
                cgparams["system"]["logging"]["developer"]["output_directory"],
                f"{datetime.now().strftime('%m-%d-%y_%H-%M-%S')}.log",
            )
        )

        # The developer log reuses the main logger's format string.
        dv_formatter = logging.Formatter(
            cgparams["system"]["logging"]["main"]["format"]
        )

        dv_fh.setFormatter(dv_formatter)
        devLogger.addHandler(dv_fh)
        devLogger.setLevel("DEBUG")
        devLogger.propagate = False

    else:
        mylog.warning(
            "User enabled development logger but did not specify output directory. Dev logger will not be used."
        )
        # Actually switch the logger off, as the warning promises. Otherwise its
        # records would still propagate to ancestor loggers' handlers.
        devLogger.propagate = False
        devLogger.disabled = True
else:
    devLogger.propagate = False
    devLogger.disabled = True


class LogMute:
    """Context manager for muting logging output.

    Disables ``logger`` on entry. On exit it restores the ``disabled`` state
    the logger had before entering, so a logger that was already disabled
    stays disabled.

    Parameters
    ----------
    logger: logging.Logger
        The logger to mute.
    """

    def __init__(self, logger):
        self.logger = logger
        # State to restore on exit; overwritten by __enter__.
        self._was_disabled = False

    def __enter__(self):
        self._was_disabled = self.logger.disabled
        self.logger.disabled = True

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore rather than force-enable: unconditionally setting False
        # would re-enable loggers that were deliberately disabled.
        self.logger.disabled = self._was_disabled
def _enforce_style(func):
    """Decorator that runs ``func`` under the package's matplotlib style.

    The rcParams listed in ``cgparams["plotting"]["defaults"]`` are applied for
    the duration of each call. The previous values are restored afterwards,
    even if ``func`` raises.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # rc_context mutates the shared rcParams object and restores it in
        # place on exit. Rebinding ``plt.rcParams`` to a saved copy instead
        # would leave matplotlib's global rcParams modified and detach the
        # ``plt.rcParams`` name from the object matplotlib actually reads.
        with plt.rc_context(cgparams["plotting"]["defaults"]):
            return func(*args, **kwargs)

    return wrapper


# -- Physical constants in the package's internal units (kpc, Msun, Myr) -- #
mp = (pc.mp).to("Msun")
G = (pc.G).to("kpc**3/Msun/Myr**2")
kboltz = (pc.kboltz).to("Msun*kpc**2/Myr**2/K")
kpc_to_cm = (1.0 * kpc).to_value("cm")

# Hydrogen mass fraction; the remainder (1 - X_H) is treated as helium below.
X_H = cgparams["physics"]["hydrogen_abundance"]
# Mean molecular weight per particle (mu) and per electron (mue). These forms
# correspond to a fully ionized H/He plasma with Y = 1 - X_H.
mu = 1.0 / (2.0 * X_H + 0.75 * (1.0 - X_H))
mue = 1.0 / (X_H + 0.5 * (1.0 - X_H))


# -- Utility functions -- #
def _truncator_function(a, r, x):
    """Smooth truncation factor ``1 / (1 + (x / r) ** a)``."""
    return 1 / (1 + (x / r) ** a)


class TimeoutException(Exception):
    """Exception raised when a function runs out of its runtime allocation."""

    def __init__(self, msg="", func=None, max_time=None):
        self.msg = f"{msg} -- {str(func)} -- max_time={max_time} s"
        # Hand the message to Exception so str(exc) and tracebacks show it.
        super().__init__(self.msg)


def _daemon_process_runner(*args, **kwargs):
    """Child-process entry point used by ``time_limit``.

    Pops the sending pipe end (``"__send_end"``) and the target callable
    (``"__function"``) out of ``kwargs``, then calls the target with the
    remaining arguments. Either its return value or the exception it raised
    is sent back through the pipe.
    """
    send_end = kwargs.pop("__send_end")
    function = kwargs.pop("__function")
    try:
        result = function(*args, **kwargs)
    except Exception as e:
        send_end.send(e)
        return
    send_end.send(result)
def time_limit(function, max_execution_time, *args, **kwargs):
    """
    Assert a maximal time limit on functions with potentially problematic / unbounded execution times.

    .. warning::

        ``function`` is run in a separate :py:class:`multiprocessing.Process`.
        Its return value (or the exception it raises) is sent back through a
        multiprocessing pipe and must therefore be picklable.

    Parameters
    ----------
    function: callable
        The function to run under the time limit.
    max_execution_time: float
        The maximum runtime in seconds.
    args:
        arguments to pass to the function.
    kwargs:
        optional keyword arguments to pass to the function. A ``desc`` keyword,
        if present, is removed and used as the progress-bar description instead.

    Returns
    -------
    object
        Whatever ``function`` returned.

    Raises
    ------
    TimeoutException
        If ``function`` has not produced a result within ``max_execution_time`` seconds.
    RuntimeError
        If the child process exits without sending back any result.
    Exception
        Any exception raised by ``function`` in the child is re-raised here.
    """
    import time

    from tqdm import tqdm

    recv_end, send_end = multiprocessing.Pipe(False)
    kwargs["__send_end"] = send_end
    kwargs["__function"] = function

    # Options that configure the progress bar rather than ``function``.
    tqdm_kwargs = {}
    for key in ["desc"]:
        if key in kwargs:
            tqdm_kwargs[key] = kwargs.pop(key)

    N = 1000
    p = multiprocessing.Process(target=_daemon_process_runner, args=args, kwargs=kwargs)
    p.start()
    # The child has its own copy of ``send_end``. Closing the parent's copy
    # means a child that dies without sending anything shows up as EOF on
    # ``recv_end`` instead of leaving ``recv()`` blocked forever.
    send_end.close()

    for _ in tqdm(
        range(N),
        **tqdm_kwargs,
        bar_format="{desc}: {percentage:3.0f}%|{bar}| [{elapsed}<{remaining} - {postfix}]",
        colour="green",
        leave=False,
    ):
        time.sleep(max_execution_time / N)
        # Poll the pipe rather than waiting for the child to exit. A large
        # result keeps the child blocked in send() until the parent reads it.
        if recv_end.poll():
            break

    if not recv_end.poll():
        # Neither a result nor EOF arrived in time: the child is still running.
        p.terminate()
        p.join()
        recv_end.close()
        raise TimeoutException(
            "Failed to complete process within time limit.",
            func=function,
            max_time=max_execution_time,
        )

    try:
        # The child sends exactly one message, so read exactly once.
        result = recv_end.recv()
    except EOFError as err:
        p.join()
        raise RuntimeError(
            f"{function} exited without returning a result (exit code {p.exitcode})."
        ) from err
    finally:
        recv_end.close()
    p.join()

    if isinstance(result, Exception):
        raise result
    return result
def integrate_mass(profile, rr):
    """Integrate a profile against the spherical volume element.

    For every radius ``r`` in ``rr`` this evaluates
    ``4 * pi * integral_0^r profile(s) * s**2 ds`` independently with
    :py:func:`scipy.integrate.quad`. When ``profile`` is a density, this is
    the enclosed mass.

    Parameters
    ----------
    profile: callable
        Function of radius to integrate.
    rr: numpy.ndarray
        Upper integration limits.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``rr``.
    """

    def _shell_integrand(s):
        return profile(s) * s * s

    partial_integrals = [quad(_shell_integrand, 0, radius)[0] for radius in rr]
    return 4.0 * np.pi * np.asarray(partial_integrals, dtype=float).reshape(rr.shape)
def integrate(profile, rr):
    """Integrate ``profile`` from each radius in ``rr`` out to the last radius.

    Computes ``ret[i] = integral_{rr[i]}^{rr[-1]} profile(r) dr`` with
    :py:func:`scipy.integrate.quad`. Warnings raised while integrating are
    collected rather than shown. If any occur, one summary warning is logged.

    Parameters
    ----------
    profile: callable
        Function of radius to integrate.
    rr: numpy.ndarray
        Radii; the last entry is the common upper limit.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``rr`` (empty if ``rr`` is empty).
    """
    ret = np.zeros(rr.shape)
    if ret.size == 0:
        # Nothing to integrate; avoids an IndexError on rr[-1].
        return ret
    rmax = rr[-1]
    with warnings.catch_warnings(record=True) as w:
        # Without "always", a warning already emitted once from the same code
        # location is suppressed by the default filters and never recorded,
        # so problems in later calls would go unreported.
        warnings.simplefilter("always")
        for i, r in enumerate(rr):
            ret[i] = quad(profile, r, rmax)[0]
    if len(w) > 0:
        mylog.warning(
            f"Detected {len(w)} warnings during integration. Non-Physical regions may be present in your profiles."
        )
    return ret
def integrate_toinf(profile, rr):
    """Integrate ``profile`` from each radius in ``rr`` out to infinity.

    Computes ``ret[i] = integral_{rr[i]}^{inf} profile(r) dr`` in two pieces
    with :py:func:`scipy.integrate.quad`. The first piece runs from
    ``rr[i]`` to the last radius ``rr[-1]``. The second is the tail from
    ``rr[-1]`` to infinity, which is shared by every entry.

    Parameters
    ----------
    profile: callable
        Function of radius to integrate.
    rr: numpy.ndarray
        Lower integration limits.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``rr`` (empty if ``rr`` is empty).
    """
    ret = np.zeros(rr.shape)
    if ret.size == 0:
        # Nothing to integrate; avoids an IndexError on rr[-1].
        return ret
    rmax = rr[-1]
    for i, r in enumerate(rr):
        ret[i] = quad(profile, r, rmax)[0]
    # The tail beyond the last grid radius is identical for every entry, so
    # integrate it once and add it to all of them.
    ret[:] += quad(profile, rmax, np.inf, limit=100)[0]
    return ret
def generate_particle_radii(r, m, num_particles, r_max=None, prng=None):
    """Draw particle radii by inverse-transform sampling of a cumulative mass profile.

    Parameters
    ----------
    r: array
        Radial grid, assumed sorted ascending (it is searched with
        :py:func:`numpy.searchsorted`).
    m: array
        Cumulative mass on the same grid as ``r``.
    num_particles: int
        Number of radii to draw.
    r_max: float, optional
        If given, only grid points below ``r_max`` are used.
    prng: optional
        Anything accepted by ``parse_prng``.

    Returns
    -------
    tuple
        ``(radius, mtot)``: the sampled radii and the mass at the last grid
        point that was used.
    """
    prng = parse_prng(prng)
    n_keep = r.size if r_max is None else np.searchsorted(r, r_max)
    mtot = m[n_keep - 1]
    u = prng.uniform(size=num_particles)
    # Cumulative distribution anchored at (r=0, P=0) and normalized to end at 1.
    cdf = np.insert(m[:n_keep], 0, 0.0)
    cdf /= cdf[-1]
    radii_grid = np.insert(r[:n_keep], 0, 0.0)
    radius = np.interp(u, cdf, radii_grid, left=0.0, right=1.0)
    return radius, mtot
def ensure_ytquantity(x, default_units):
    """Return ``x`` as a :py:class:`unyt.unyt_quantity` expressed in ``default_units``.

    Accepted inputs:

    - a ``unyt_quantity``: rebuilt from its value and units, then converted;
    - a ``(value, units)`` tuple: built from the pair, then converted;
    - anything else: taken as a bare value already in ``default_units``.
    """
    if isinstance(x, tuple):
        value, units = x[0], x[1]
        return unyt_quantity(value, units).in_units(default_units)
    if isinstance(x, unyt_quantity):
        return unyt_quantity(x.v, x.units).in_units(default_units)
    return unyt_quantity(x, default_units)
def ensure_ytarray(arr, units):
    """Return ``arr`` as a :py:class:`unyt.unyt_array` converted to ``units``.

    Input that is not already a ``unyt_array`` is assumed to be expressed in
    ``units``.
    """
    with_units = arr if isinstance(arr, unyt_array) else unyt_array(arr, units)
    return with_units.to(units)
def parse_prng(prng):
    """Return a :py:class:`numpy.random.RandomState` for ``prng``.

    An existing ``RandomState`` is passed through unchanged (the same object).
    Anything else, such as ``None`` or an integer seed, is used to construct a
    new ``RandomState``.
    """
    return prng if isinstance(prng, RandomState) else RandomState(prng)
def ensure_list(x):
    """Return ``x`` as a list.

    Delegates to :py:func:`more_itertools.always_iterable`, so a non-iterable
    scalar becomes a one-element list. See that function for how strings and
    ``None`` are treated.
    """
    return [*always_iterable(x)]
def _closest_factors(val): assert isinstance(val, int), "Value must be integer." a, b, i = 1, val, 0 while a < b: i += 1 if val % i == 0: a = i b = val // a return (a, b)