Source code for gwpopulation_pipe.data_analysis

"""
Functions for running stochastic sampling with Bilby for pre-collected posteriors.

The module provides the `gwpopulation_pipe_analysis` executable.

Many of the other functions take an `args` object (for example the namespace
returned by the `gwpopulation_pipe` parser) that must provide the attributes
defined by that parser.
"""

#!/usr/bin/env python3

import inspect
import json
import os
import sys
from importlib import import_module

import matplotlib

matplotlib.use("agg")  # noqa

import dill
import numpy as np
import pandas as pd
from bilby.core.sampler import run_sampler
from bilby.core.prior import Constraint, LogUniform, ConditionalPriorDict
from bilby.core.utils import (
    infer_args_from_function_except_n_args,
    logger,
    decode_bilby_json,
)
from bilby_pipe.utils import convert_string_to_dict
from gwpopulation.backend import set_backend
from gwpopulation.conversions import convert_to_beta_parameters
from gwpopulation.hyperpe import HyperparameterLikelihood, RateLikelihood
from gwpopulation.models.mass import (
    BrokenPowerLawPeakSmoothedMassDistribution,
    BrokenPowerLawSmoothedMassDistribution,
    MultiPeakSmoothedMassDistribution,
    SinglePeakSmoothedMassDistribution,
    two_component_primary_mass_ratio,
)
from gwpopulation.models.spin import (
    iid_spin,
    iid_spin_magnitude_beta,
    iid_spin_orientation_gaussian_isotropic,
    independent_spin_magnitude_beta,
    independent_spin_orientation_gaussian_isotropic,
)
from gwpopulation.utils import to_numpy
from scipy.stats import gamma
from tqdm.auto import trange

from . import vt_helper
from .parser import create_parser as create_main_parser
from .utils import (
    get_path_or_local,
    prior_conversion,
    KNOWN_ARGUMENTS,
    MinimumEffectiveSamplesLikelihood,
)


def _str_to_bool(value):
    """
    Convert a command-line / config-file value to a ``bool``.

    ``argparse`` with ``type=bool`` calls ``bool(string)``, which is ``True``
    for every non-empty string (including ``"False"``), so an explicit
    converter is required for boolean options.

    Parameters
    ----------
    value: str or bool
        The raw value. ``bool`` instances are returned unchanged.

    Returns
    -------
    bool

    Raises
    ------
    ValueError
        If the string is not a recognised boolean spelling. ``argparse``
        reports this as an invalid argument value.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise ValueError(f"Cannot interpret {value!r} as a boolean.")


def create_parser():
    """
    Extend the main ``gwpopulation_pipe`` parser with analysis-specific options.

    Adds ``--prior``, ``--models``, ``--vt-models``, ``--max-samples`` and
    ``--rate`` to the parser returned by ``create_main_parser``.

    Returns
    -------
    parser
        The extended parser.
    """
    parser = create_main_parser()
    parser.add_argument("--prior", help="Prior file readable by bilby.")
    parser.add_argument(
        "--models",
        type=str,
        action="append",
        help="Model functions to evaluate, default is "
        "two component mass and iid spins.",
    )
    parser.add_argument(
        "--vt-models",
        type=str,
        action="append",
        help="Model functions to evaluate for selection, default is no model",
    )
    parser.add_argument(
        "--max-samples",
        # Integer default to match ``type=int``; argparse does not apply
        # ``type`` to non-string defaults.
        default=10**10,
        type=int,
        help="Maximum number of posterior samples per event",
    )
    parser.add_argument(
        "--rate",
        default=False,
        # ``type=bool`` would turn the string "False" into True.
        type=_str_to_bool,
        help="Whether to sample in the merger rate.",
    )
    return parser
def load_prior(args):
    """
    Read the hyper-prior and optionally attach a prior on the merger rate.

    Parameters
    ----------
    args: argparse.Namespace-like
        Needs ``prior_file`` (resolved through ``get_path_or_local``) and
        ``rate``.

    Returns
    -------
    bilby.core.prior.ConditionalPriorDict
        The hyper-prior with ``prior_conversion`` as its conversion function.
        When ``args.rate`` is truthy, a reflective log-uniform ``rate`` prior
        on [0.1, 1000] is added.
    """
    prior_path = get_path_or_local(args.prior_file)
    # JSON files need the dedicated reader; any other file is handed to the
    # ConditionalPriorDict constructor.
    if prior_path.endswith(".json"):
        reader = ConditionalPriorDict.from_json
    else:
        reader = ConditionalPriorDict
    hyper_prior = reader(filename=prior_path)
    hyper_prior.conversion_function = prior_conversion

    if not args.rate:
        return hyper_prior

    hyper_prior["rate"] = LogUniform(
        minimum=1e-1,
        maximum=1e3,
        name="rate",
        latex_label="$R$",
        boundary="reflective",
    )
    return hyper_prior
# Short model names that ``_load_model`` resolves directly. Several keys are
# aliases that point at the same gwpopulation function or class.
MODEL_MAP = {
    "two_component_primary_mass_ratio": two_component_primary_mass_ratio,
    "iid_spin": iid_spin,
    "iid_spin_magnitude": iid_spin_magnitude_beta,
    "ind_spin_magnitude": independent_spin_magnitude_beta,
    "iid_spin_orientation": iid_spin_orientation_gaussian_isotropic,
    "two_comp_iid_spin_orientation": iid_spin_orientation_gaussian_isotropic,
    "ind_spin_orientation": independent_spin_orientation_gaussian_isotropic,
    "SmoothedMassDistribution": SinglePeakSmoothedMassDistribution,
    "SinglePeakSmoothedMassDistribution": SinglePeakSmoothedMassDistribution,
    "BrokenPowerLawSmoothedMassDistribution": BrokenPowerLawSmoothedMassDistribution,
    "MultiPeakSmoothedMassDistribution": MultiPeakSmoothedMassDistribution,
    "BrokenPowerLawPeakSmoothedMassDistribution": BrokenPowerLawPeakSmoothedMassDistribution,
}


def _model_class(args):
    """
    Pick the container class used to combine individual population models.

    Parameters
    ----------
    args: argparse.Namespace-like
        Needs ``cosmo``, ``backend`` and, when ``cosmo`` is set, ``cosmology``.

    Returns
    -------
    callable
        ``CosmoModel`` with ``cosmo_model=args.cosmology`` bound (cosmological
        analyses), ``NonCachingModel`` for the jax backend, or bilby's
        ``Model`` otherwise. Each is imported lazily.
    """
    if args.cosmo:
        from functools import partial

        from gwpopulation.experimental.cosmo_models import CosmoModel

        return partial(CosmoModel, cosmo_model=args.cosmology)
    if args.backend == "jax":
        from gwpopulation.experimental.jax import NonCachingModel

        return NonCachingModel
    from bilby.hyper.model import Model

    return Model
def load_model(args):
    """
    Build the population model from the model specifications in ``args``.

    If ``args.models`` is ``None`` it is replaced (on ``args`` itself) by a
    default of a two-component mass model, iid spin magnitudes and
    orientations, and ``gwpopulation.models.redshift.PowerLawRedshift``.

    Parameters
    ----------
    args: argparse.Namespace-like
        Needs ``models`` (a mapping whose values are model specifications
        understood by ``_load_model``) plus the attributes read by
        ``_model_class`` and ``_load_model``.

    Returns
    -------
    The container returned by ``_model_class(args)`` wrapping every
    resolved model, in the order of ``args.models.values()``.
    """
    if args.models is None:
        args.models = {
            "mass": "two_component_primary_mass_ratio",
            "mag": "iid_spin_magnitude",
            "tilt": "iid_spin_orientation",
            "redshift": "gwpopulation.models.redshift.PowerLawRedshift",
        }
    container = _model_class(args)
    components = [_load_model(spec, args) for spec in args.models.values()]
    return container(components)
def load_vt(args):
    """
    Build the selection function used in the analysis.

    Parameters
    ----------
    args: argparse.Namespace-like
        Needs ``vt_function``, ``vt_file``, ``vt_models``,
        ``vt_ifar_threshold``, ``vt_snr_threshold`` and the attributes read by
        ``_model_class`` / ``_load_model``.

    Returns
    -------
    callable
        ``vt_helper.dummy_selection`` if ``args.vt_function`` is empty or
        ``args.vt_file`` is the string ``"None"``. Otherwise the object built
        by ``vt_helper.<args.vt_function>``. If ``vt_helper`` has no such
        attribute, a warning is logged and
        ``vt_helper.injection_resampling_vt`` is used instead.
    """
    if args.vt_function == "" or args.vt_file == "None":
        return vt_helper.dummy_selection
    cls = _model_class(args)
    vt_model = cls([_load_model(model, args) for model in args.vt_models.values()])
    # Only a failed *lookup* may trigger the fallback. An AttributeError raised
    # while constructing the selection function is a genuine error and must
    # not be masked by silently retrying with a different function.
    if not hasattr(vt_helper, args.vt_function):
        logger.warning(
            f"No function {args.vt_function} in vt_helper, "
            "falling back to injection_resampling_vt."
        )
        return vt_helper.injection_resampling_vt(
            vt_file=args.vt_file,
            model=vt_model,
            ifar_threshold=args.vt_ifar_threshold,
            snr_threshold=args.vt_snr_threshold,
        )
    vt_func = getattr(vt_helper, args.vt_function)
    return vt_func(
        args.vt_file,
        model=vt_model,
        ifar_threshold=args.vt_ifar_threshold,
        snr_threshold=args.vt_snr_threshold,
    )
def _load_model(model, args):
    """
    Resolve one model specification into a callable or model instance.

    Three specification forms are accepted:

    - a path ending in ``.json`` whose content names a ``module``, a
      ``class`` and optional ``kwargs``; the class is instantiated with
      those kwargs;
    - a dotted import path ``package.module.name``;
    - a key of ``MODEL_MAP``.

    If the result is a class, it is instantiated. Names containing
    "redshift" get ``z_max=args.max_redshift``. Names containing "mass" get
    ``mmin=args.minimum_mass`` and ``mmax=args.maximum_mass``. If that
    raises ``TypeError``, the class is instantiated without arguments.

    Parameters
    ----------
    model: str
        The model specification.
    args: argparse.Namespace-like
        Needs ``max_redshift``, ``minimum_mass`` and ``maximum_mass`` when a
        class is instantiated.

    Returns
    -------
    The resolved model function or instance.

    Raises
    ------
    KeyError
        If a JSON specification lacks ``module`` or ``class``.
    ValueError
        If the specification matches none of the accepted forms.
    """
    if model.endswith(".json"):
        model = get_path_or_local(model)
        with open(model, "r") as json_file:
            spec = json.load(json_file, object_hook=decode_bilby_json)
        try:
            model_cls = getattr(import_module(spec["module"]), spec["class"])
            resolved = model_cls(**spec.get("kwargs", dict()))
            logger.info(f"Using {model_cls} from {spec['module']}.")
        except KeyError:
            logger.error(f"Failed to load {model} from json file.")
            raise
    elif "." in model:
        module, _, function = model.rpartition(".")
        resolved = getattr(import_module(module), function)
        logger.info(f"Using {function} from {module}.")
    elif model in MODEL_MAP:
        resolved = MODEL_MAP[model]
        logger.info(f"Using {model}.")
    else:
        raise ValueError(f"Model {model} not found.")

    if not inspect.isclass(resolved):
        return resolved

    lowered = model.lower()
    if "redshift" in lowered:
        kwargs = {"z_max": args.max_redshift}
    elif "mass" in lowered:
        kwargs = {"mmin": args.minimum_mass, "mmax": args.maximum_mass}
    else:
        kwargs = {}
    try:
        resolved = resolved(**kwargs)
        logger.info(f"Created {model} with arguments {kwargs}")
    except TypeError:
        # Best effort: some classes do not accept these keyword arguments.
        logger.warning(f"Failed to instantiate {model} with arguments {kwargs}")
        resolved = resolved()
    return resolved
# create_likelihood(args, posteriors, model, selection): build the
# hyperparameter likelihood for the analysis.
#
# Class selection:
# - ``args.rate`` truthy -> ``RateLikelihood``; combining it with
#   ``args.enforce_minimum_neffective_per_event`` raises ``ValueError``.
# - otherwise, ``args.enforce_minimum_neffective_per_event`` truthy ->
#   ``MinimumEffectiveSamplesLikelihood``.
# - otherwise -> ``HyperparameterLikelihood``.
# ``selection.enforce_convergence`` is set to ``False``.
# NOTE(review): this source was flattened onto one line, so it cannot be
# recovered here whether that assignment belongs only to the final ``else``
# branch or runs for every likelihood class -- confirm upstream before
# reformatting.
# The likelihood is built with ``convert_to_beta_parameters`` as the
# conversion function, ``selection`` as the selection function,
# ``max_samples=args.max_samples`` and ``cupy=True`` only when
# ``args.backend == "cupy"``, and is returned.
[docs] def create_likelihood(args, posteriors, model, selection): if args.rate: if args.enforce_minimum_neffective_per_event: raise ValueError( "No likelihood available to enforce convergence of Monte Carlo integrals " "while sampling over rate." ) likelihood_class = RateLikelihood elif args.enforce_minimum_neffective_per_event: likelihood_class = MinimumEffectiveSamplesLikelihood else: likelihood_class = HyperparameterLikelihood selection.enforce_convergence = False likelihood = likelihood_class( posteriors, model, conversion_function=convert_to_beta_parameters, selection_function=selection, max_samples=args.max_samples, cupy=args.backend == "cupy", ) return likelihood
def get_sampler_kwargs(args):
    """
    Build the keyword arguments passed to ``bilby.run_sampler``.

    Parameters
    ----------
    args: argparse.Namespace-like
        Needs ``sampler_kwargs`` and ``sampler``. ``sampler_kwargs`` may be:

        - the literal string ``"Default"``: no extra keyword arguments;
        - a ``dict``: used as given, without the built-in defaults;
        - any other value: parsed with
          ``bilby_pipe.utils.convert_string_to_dict`` and merged over the
          built-in defaults ``nlive=500, nact=2, walks=5``.

    Returns
    -------
    dict
        A new dictionary; ``args.sampler_kwargs`` is never modified. For the
        ``cpnest`` sampler a random integer ``seed`` in ``[0, 10**6)`` is added
        when none is given.
    """
    if args.sampler_kwargs == "Default":
        sampler_kwargs = dict()
    elif isinstance(args.sampler_kwargs, dict):
        # Copy so that adding a seed below does not mutate the caller's dict.
        sampler_kwargs = dict(args.sampler_kwargs)
    else:
        sampler_kwargs = dict(nlive=500, nact=2, walks=5)
        sampler_kwargs.update(convert_string_to_dict(args.sampler_kwargs))
    if args.sampler == "cpnest" and "seed" not in sampler_kwargs:
        # Integer bounds: ``randint`` is specified for integer limits.
        sampler_kwargs["seed"] = np.random.randint(0, 10**6)
    return sampler_kwargs
def compute_rate_posterior(posterior, selection):
    r"""
    Compute the rate posterior as a post-processing step.

    This method is the same as described in https://dcc.ligo.org/T2000100.
    To get the rate at :math:`z=0` we stop after step four.

    The total surveyed four-volume is given as

    .. math::

        V_{\rm tot}(\Lambda) = T_{\rm obs} \int dz \frac{1}{1+z}
        \frac{dV_c}{dz} \psi(z|\Lambda)

    Note that :math:`\psi(z=0|\Lambda) = 1`.

    The sensitive four-volume is then :math:`\mu V_{\rm tot}`, where
    :math:`\mu` is the fraction of injections that are found.

    We draw samples from the gamma distribution with shape parameter (and
    therefore mean) ``N_EVENTS``, the number of analysed events. These
    samples are then divided by the sensitive four-volume to give the
    average rate over the surveyed volume, in units of
    :math:`{\rm Gpc}^{-3}\,{\rm yr}^{-1}`.

    When ``selection`` is ``vt_helper.dummy_selection`` there is no
    sensitive volume, so only the gamma draws are stored (as
    ``log_10_rate``) and no other columns are added.

    Parameters
    ----------
    posterior: pd.DataFrame
        DataFrame containing the posterior samples; modified in place.
    selection: vt_helper.InjectionResamplingVT
        Object that computes:

        - the mean and variance of the survey completeness
        - the total surveyed 4-volume weighted by the redshift distribution

    Returns
    -------
    None
        For a real selection function, adds the ``selection``,
        ``pdet_n_effective``, ``surveyed_hypervolume``, ``log_10_rate`` and
        ``rate`` columns to ``posterior``.
    """
    # One shape parameter for both branches.
    n_events = int(vt_helper.N_EVENTS)

    if selection == vt_helper.dummy_selection:
        # NOTE(review): unlike the branch below, no linear ``rate`` column is
        # written here.
        posterior["log_10_rate"] = np.log10(gamma(a=n_events).rvs(len(posterior)))
        return

    from .utils import maybe_jit

    efficiencies = list()
    n_effective = list()
    surveyed_hypervolume = list()
    func = maybe_jit(selection.detection_efficiency)
    for ii in trange(len(posterior), file=sys.stdout):
        parameters = dict(posterior.iloc[ii])
        efficiency, variance = func(parameters)
        efficiencies.append(float(efficiency))
        # Monte Carlo effective sample size of the efficiency estimate.
        n_effective.append(float(efficiency**2 / variance))
        surveyed_hypervolume.append(
            float(selection.surveyed_hypervolume(parameters))
        )
    posterior["selection"] = efficiencies
    posterior["pdet_n_effective"] = n_effective
    posterior["surveyed_hypervolume"] = surveyed_hypervolume
    posterior["log_10_rate"] = np.log10(
        gamma(a=n_events).rvs(len(posterior))
        / posterior["surveyed_hypervolume"]
        / posterior["selection"]
    )
    posterior["rate"] = 10 ** posterior["log_10_rate"]
def fill_minimum_n_effective(posterior, likelihood):
    """
    Add, for every hyper-sample, the smallest per-event effective sample size.

    Each row of ``posterior`` is passed through
    ``likelihood.conversion_function``. The converted values are pushed into
    ``likelihood.parameters`` and ``likelihood.hyper_prior.parameters``, and
    ``likelihood.per_event_bayes_factors_and_n_effective()`` is evaluated.
    The minimum of the returned n_effective values is stored in a new
    ``min_event_n_effective`` column. Afterwards the likelihood holds the
    parameters of the last row.

    If the likelihood has no ``per_event_bayes_factors_and_n_effective``
    method, a message is logged and ``posterior`` is left untouched.

    Parameters
    ----------
    posterior: pd.DataFrame
        DataFrame containing the posterior distribution; modified in place.
    likelihood: gwpopulation.hyperpe.HyperparameterLikelihood
        The likelihood used in the analysis.

    Returns
    -------
    None
    """
    if not hasattr(likelihood, "per_event_bayes_factors_and_n_effective"):
        logger.info(
            "Likelihood has no method 'per_event_bayes_factors_and_n_effective'"
            " skipping n_effective calculation."
        )
        return

    smallest = []
    for row in trange(len(posterior), file=sys.stdout):
        converted, _ = likelihood.conversion_function(dict(posterior.iloc[row]))
        likelihood.parameters.update(converted)
        likelihood.hyper_prior.parameters.update(converted)
        _, per_event_n_eff = likelihood.per_event_bayes_factors_and_n_effective()
        smallest.append(float(min(per_event_n_eff)))
    posterior["min_event_n_effective"] = smallest
def resample_single_event_posteriors(likelihood, result, save=True):
    """
    Resample the single-event posteriors using the population-informed prior.

    Parameters
    ----------
    likelihood: gwpopulation.hyperpe.HyperparameterLikelihood
        The likelihood object to use. Its ``data`` dict gains ``"prior"``
        and ``"weights"`` entries as a side effect.
    result: bilby.core.result.Result
        The result whose posterior should be used for the reweighting.
        ``outdir``, ``label`` and ``meta_data["event_ids"]`` are used when
        saving.
    save: bool
        Whether to save the samples to
        ``{result.outdir}/{result.label}_samples.pkl`` with ``dill``. If
        ``False`` the samples are returned instead.

    Returns
    -------
    None when ``save`` is true; otherwise a tuple of

    original_samples: dict
        The input samples with the new prior weights in a new ``weights``
        entry, converted to numpy.
    reweighted_samples: dict
        The input samples resampled according to the new prior weights,
        converted to numpy. Note that this will cause samples to be repeated.
    """

    def _as_numpy(samples):
        # Convert every entry (possibly a backend array) to numpy.
        return {key: to_numpy(samples[key]) for key in samples}

    # NOTE(review): this is ``likelihood.data`` itself, not a copy. Adding
    # "prior" before resampling presumably lets posterior_predictive_resample
    # carry it into the reweighted samples -- confirm before changing.
    original_samples = likelihood.data
    original_samples["prior"] = likelihood.sampling_prior
    reweighted, weights = likelihood.posterior_predictive_resample(
        result.posterior, return_weights=True
    )
    original_samples["weights"] = weights

    original_samples = _as_numpy(original_samples)
    reweighted_samples = _as_numpy(reweighted)
    if not save:
        return original_samples, reweighted_samples

    with open(f"{result.outdir}/{result.label}_samples.pkl", "wb") as ff:
        dill.dump(
            dict(
                original=original_samples,
                reweighted=reweighted_samples,
                names=result.meta_data["event_ids"],
                label=result.label,
            ),
            file=ff,
        )
def generate_extra_statistics(posterior, likelihood):
    """
    Evaluate ``likelihood.generate_extra_statistics`` for every posterior row.

    Parameters
    ----------
    posterior: pd.DataFrame
        Hyper-posterior samples; each row is passed on as a ``dict``.
    likelihood:
        Object with a ``generate_extra_statistics`` method. The method is
        wrapped with ``maybe_jit`` before use.

    Returns
    -------
    pd.DataFrame
        A new DataFrame with one row per input sample, containing every key
        returned by ``generate_extra_statistics`` cast to ``float``.
    """
    from .utils import maybe_jit

    extra_statistics = maybe_jit(likelihood.generate_extra_statistics)
    rows = []
    for index in trange(len(posterior), file=sys.stdout):
        statistics = extra_statistics(dict(posterior.iloc[index]))
        rows.append({name: float(statistics[name]) for name in statistics})
    return pd.DataFrame(rows)
# main(): command-line entry point that runs the full population analysis.
#
# NOTE(review): this function's source was flattened onto two lines and its
# original indentation is lost. The nesting of some statements therefore
# cannot be recovered from this text -- e.g. whether
# ``search_keys.extend(...)`` runs once per model function or once after the
# loop. Confirm against the upstream file before reformatting.
#
# Steps, in order:
# 1. Parse the known command-line arguments (unknown ones are ignored), set
#    the array backend and create ``args.run_dir``.
# 2. Load the per-event posteriors from ``<run_dir>/data/<data_label>.pkl``.
#    Unless ``args.cosmo`` is set, keep only samples with
#    ``redshift < args.max_redshift``. Set ``vt_helper.N_EVENTS`` to the
#    number of events.
# 3. Read event IDs from ``<run_dir>/data/<data_label>_posterior_files.txt``
#    (the text before the first ":" on each line).
# 4. Convert ``args.models`` / ``args.vt_models`` to dicts, then build the
#    hyper-prior, the population model and the selection function.
# 5. Collect the hyperparameter names used by the model functions (from
#    ``variable_names`` or the function signature, plus ``KNOWN_ARGUMENTS``,
#    plus any ``cosmology_names`` on the model). Delete prior entries that
#    are unused, except ``rate`` and ``Constraint`` priors; ``Constraint``
#    priors are deleted as well when the sampler is ``numpyro``.
# 6. Build the likelihood and evaluate it once at a random prior draw
#    (presumably a smoke test before sampling -- not confirmed here).
# 7. Optionally read injection parameters (row ``args.injection_index`` of
#    ``args.injection_file``). For the jax backend, wrap the likelihood in
#    ``JittedLikelihood`` before sampling, or after sampling when the
#    sampler is ``numpyro``. Run the sampler, writing HDF5 output to
#    ``<run_dir>/result``.
# 8. Post-process: rate posterior, minimum per-event n_effective, extra
#    statistics and the likelihood's conversion function; then re-save.
# 9. Rebuild the model, selection and likelihood from scratch and resample
#    the single-event posteriors with the population-informed prior
#    (presumably to use a fresh, non-jitted likelihood -- not confirmed).
[docs] def main(): parser = create_parser() args, _ = parser.parse_known_args(sys.argv[1:]) set_backend(args.backend) os.makedirs(args.run_dir, exist_ok=True) posterior_file = os.path.join(args.run_dir, "data", f"{args.data_label}.pkl") posteriors = pd.read_pickle(posterior_file) if not args.cosmo: for ii, post in enumerate(posteriors): posteriors[ii] = post[post["redshift"] < args.max_redshift] vt_helper.N_EVENTS = len(posteriors) event_ids = list() with open( os.path.join(args.run_dir, "data", f"{args.data_label}_posterior_files.txt"), "r", ) as ff: for line in ff.readlines(): event_ids.append(line.split(":")[0]) logger.info(f"Loaded {len(posteriors)} posteriors") args.models = convert_string_to_dict( str(args.models).replace("[", "{").replace("]", "}") ) args.vt_models = convert_string_to_dict( str(args.vt_models).replace("[", "{").replace("]", "}") ) hyper_prior = load_prior(args) model = load_model(args) selection = load_vt(args) search_keys = list() ignore = ["dataset", "self", "cls"] for func in model.models: if hasattr(func, "variable_names"): param_keys = func.variable_names else: param_keys = infer_args_from_function_except_n_args(func, n=0) param_keys = set(param_keys) param_keys.update(KNOWN_ARGUMENTS.get(func, set())) for key in param_keys: if key in search_keys or key in ignore: continue search_keys.append(key) search_keys.extend(getattr(model, "cosmology_names", list())) logger.info(f"Identified keys: {', '.join(search_keys)}") for key in list(hyper_prior.keys()): if ( key not in search_keys and key != "rate" and not isinstance(hyper_prior[key], Constraint) ): del hyper_prior[key] elif (isinstance(hyper_prior[key], Constraint)) and (args.sampler == "numpyro"): del hyper_prior[key] likelihood = create_likelihood(args, posteriors, model, selection) likelihood.parameters.update(hyper_prior.sample()) likelihood.log_likelihood_ratio() if args.injection_file is not None: injections = pd.read_json(args.injection_file) injection_parameters = 
dict(injections.iloc[args.injection_index]) else: injection_parameters = None if args.backend == "jax" and args.sampler != "numpyro": from gwpopulation.experimental.jax import JittedLikelihood likelihood = JittedLikelihood(likelihood) result = run_sampler( likelihood=likelihood, priors=hyper_prior, label=args.label, sampler=args.sampler, outdir=os.path.join(args.run_dir, "result"), injection_parameters=injection_parameters, save="hdf5", **get_sampler_kwargs(args), ) result.meta_data["models"] = args.models result.meta_data["vt_models"] = args.vt_models result.meta_data["event_ids"] = event_ids if args.backend == "jax" and args.sampler == "numpyro": from gwpopulation.experimental.jax import JittedLikelihood likelihood = JittedLikelihood(likelihood) logger.info("Computing rate posterior") compute_rate_posterior(posterior=result.posterior, selection=selection) logger.info("Computing n effectives") fill_minimum_n_effective(posterior=result.posterior, likelihood=likelihood) logger.info("Generating extra statistics") result.posterior = generate_extra_statistics( posterior=result.posterior, likelihood=likelihood ) result.posterior = likelihood.conversion_function(result.posterior)[0] result.save_to_file(extension="hdf5", overwrite=True) logger.info("Resampling single event posteriors") model = load_model(args) selection = load_vt(args) likelihood = create_likelihood(args, posteriors, model, selection) likelihood.hyper_prior.parameters = likelihood.parameters resample_single_event_posteriors(likelihood, result, save=True)