Source code for dessn.framework.fitter

import logging
import os
import pickle
import shutil
import socket
import sys
from collections import OrderedDict

import numpy as np

from dessn.utility.doJob import write_jobscript_slurm


class Fitter(object):
    def __init__(self, temp_dir):
        self.models = []
        self.simulations = []
        self.num_cosmologies = 30
        self.num_walkers = 10
        self.num_cpu = self.num_cosmologies * self.num_walkers
        self.logger = logging.getLogger(__name__)
        self.temp_dir = temp_dir
        self.max_steps = 3000
        if not os.path.exists(temp_dir):
            os.makedirs(temp_dir)
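    # Note on the defaults above (comment added for this listing): the CPU budget
    # is sized so each (cosmology, walker) pair can run on its own core,
    # i.e. 30 cosmologies * 10 walkers = 300 CPUs unless set_num_cpu() overrides it.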
    def set_models(self, *models):
        self.models = models
        return self
    def set_max_steps(self, max_steps):
        self.max_steps = max_steps
        return self
    def set_simulations(self, *simulations):
        self.simulations = simulations
        return self
    def set_num_cosmologies(self, num_cosmologies):
        self.num_cosmologies = num_cosmologies
        return self
    def set_num_cpu(self, num_cpu=None):
        if num_cpu is None:
            self.num_cpu = self.num_cosmologies * self.num_walkers
        else:
            self.num_cpu = num_cpu
        return self
    def set_num_walkers(self, num_walkers):
        self.num_walkers = num_walkers
        return self
    def get_num_jobs(self):
        num_jobs = len(self.models) * len(self.simulations) * self.num_cosmologies * self.num_walkers
        return num_jobs
    def get_indexes_from_index(self, index):
        num_simulations = len(self.simulations)
        num_cosmo = self.num_cosmologies
        num_walkers = self.num_walkers
        num_per_model_sim = num_cosmo * num_walkers
        num_per_model = num_simulations * num_per_model_sim
        model_index = index // num_per_model
        index -= model_index * num_per_model
        sim_index = index // num_per_model_sim
        index -= sim_index * num_per_model_sim
        cosmo_index = index // num_walkers
        walker_index = index % num_walkers
        return model_index, sim_index, cosmo_index, walker_index
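    # Worked example of the decomposition above (illustrative numbers, not taken
    # from any run): with 2 models, 2 simulations, 30 cosmologies and 10 walkers
    # there are 2 * 2 * 30 * 10 = 1200 jobs, and index 742 decodes as
    #     num_per_model_sim = 30 * 10 = 300, num_per_model = 2 * 300 = 600
    #     model_index  = 742 // 600 = 1   (remainder 142)
    #     sim_index    = 142 // 300 = 0   (remainder 142)
    #     cosmo_index  = 142 // 10  = 14
    #     walker_index = 142 %  10  = 2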
    def run_fit(self, model_index, simulation_index, cosmo_index, walker_index, num_cores=1):
        model = self.models[model_index]
        sim = self.simulations[simulation_index]
        out_file = self.temp_dir + "/stan_%d_%d_%d_%d.pkl" % (
            model_index, simulation_index, cosmo_index, walker_index)
        if num_cores == 1:
            w, n = 1000, self.max_steps
        else:
            w, n = 500, 1000
        data = model.get_data(sim, cosmo_index)
        self.logger.info("Running Stan job, saving to %s" % out_file)
        import pystan  # deferred import: pystan is only needed where the fit runs
        sm = pystan.StanModel(file=model.get_stan_file(), model_name="Cosmology")
        fit = sm.sampling(data=data, iter=n, warmup=w, chains=num_cores,
                          init=model.get_init_wrapped(**data))
        self.logger.info("Stan finished sampling")

        # Get the model parameters that Stan actually sampled
        params = [p for p in model.get_parameters() if p in fit.sim["pars_oi"]]
        self.logger.info("Saving parameters: %s" % params)
        if "weight" in fit.sim["pars_oi"]:
            self.logger.debug("Found weight to save")
            params.append("weight")
        if "posterior" in fit.sim["pars_oi"]:
            self.logger.debug("Found posterior to save")
            params.append("posterior")
        dictionary = fit.extract(pars=params)

        # Convert log-scale parameters back to linear scale for easier inspection
        for key in list(dictionary.keys()):
            if key.startswith("log_"):
                dictionary[key[4:]] = np.exp(dictionary[key])
                del dictionary[key]

        # Correct the chains if there is a weight function
        dictionary = model.correct_chain(dictionary, sim, data)

        with open(out_file, 'wb') as output:
            pickle.dump(dictionary, output)
        self.logger.info("Saved chain to %s" % out_file)
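    # Illustration of the log-scale conversion in run_fit above (hypothetical key
    # name): a chain entry {"log_sigma": array([0.0, 1.0])} is replaced by
    # {"sigma": array([1.0, 2.718...])} before the pickle is written.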
    def is_laptop(self):
        return "science" in socket.gethostname()
    def fit(self, file):
        num_jobs = self.get_num_jobs()
        num_models = len(self.models)
        num_simulations = len(self.simulations)
        self.logger.info("With %d models, %d simulations, %d cosmologies and %d walkers, have %d jobs"
                         % (num_models, num_simulations, self.num_cosmologies, self.num_walkers, num_jobs))
        if self.is_laptop():
            self.logger.info("Running Stan locally with 4 cores.")
            self.run_fit(0, 0, 0, 0, num_cores=4)
        else:
            if len(sys.argv) == 1:
                h = socket.gethostname()
                partition = "regular" if "edison" in h else "smp"
                if os.path.exists(self.temp_dir):
                    self.logger.info("Deleting %s" % self.temp_dir)
                    shutil.rmtree(self.temp_dir)
                filename = write_jobscript_slurm(file, name=os.path.basename(file),
                                                 num_tasks=self.get_num_jobs(),
                                                 num_cpu=self.num_cpu,
                                                 delete=True, partition=partition)
                self.logger.info("Running batch job at %s" % filename)
                os.system("sbatch %s" % filename)
            else:
                index = int(sys.argv[1])
                mi, si, ci, wi = self.get_indexes_from_index(index)
                self.logger.info("Running model %d, sim %d, cosmology %d, walker number %d"
                                 % (mi, si, ci, wi))
                self.run_fit(mi, si, ci, wi)
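    # Sketch of how fit() is typically driven (the script name and the model and
    # simulation classes below are illustrative placeholders, not part of this
    # module):
    #
    #     # run_cosmology.py
    #     fitter = Fitter("temp/run_cosmology")
    #     fitter.set_models(MyModel()).set_simulations(MySimulation())
    #     fitter.fit(__file__)
    #
    # With no command-line arguments this submits a SLURM batch job (or, on a
    # laptop, runs a single fit locally with 4 cores); each SLURM task then
    # re-invokes the script as "python run_cosmology.py <index>", which takes the
    # else branch above and runs the single fit that index decodes to.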
    def load_file(self, filename):
        with open(filename, 'rb') as f:
            chain = pickle.load(f)
        self.logger.debug("Loaded pickle from %s" % filename)
        return chain
    def get_result_from_chain(self, chain, simulation_index, model_index, cosmo_index,
                              convert_names=True, max_deviation=2.5):
        sims = self.simulations[simulation_index]
        if not isinstance(sims, list):
            sims = [sims]
        truth_list = [s.get_truth_values_dict() for s in sims]
        truth = {k: [t[k] for t in truth_list] for k in truth_list[0].keys()}
        for k in truth:
            if isinstance(truth[k][0], np.ndarray):
                truth[k] = np.concatenate([a.flatten() for a in truth[k]])
        mapping = self.models[model_index].get_labels()

        stan_weight = chain.get("weight")
        new_weight = chain.get("new_weight")
        if new_weight is not None:
            # Shift the log-weights down by max_deviation standard deviations and
            # clip anything still positive, so the exponentiated weights are capped at 1
            new_weight -= max_deviation * np.std(new_weight)
            new_weight[new_weight > 0] = 0
            new_weight = np.exp(new_weight)
        posterior = chain.get("posterior")
        parameters = list(mapping.keys())
        if convert_names:
            truth = {mapping[k]: truth.get(k) for k in mapping if k in truth.keys()}
        temp_list = []
        for p in parameters:
            try:
                vals = chain.get(p)
                if vals is None:
                    continue
                label = mapping.get(p) if convert_names else p
                if r"%d" in label:
                    # Vector parameter: split it into one entry per component
                    if len(vals.shape) > 2:
                        vals = vals.reshape((vals.shape[0], -1))
                    num_d = 1 if len(vals.shape) < 2 else vals.shape[1]
                    for i in range(num_d):
                        if len(vals.shape) < 2:
                            temp_list.append([mapping[p] % i, vals])
                        else:
                            temp_list.append([mapping[p] % i, vals[:, i]])
                        if truth.get(mapping[p]) is not None:
                            if len(truth[mapping[p]]) <= i:
                                self.logger.warning("Truth values don't line up for %s %d" % (p, i))
                                truth[mapping[p] % i] = 0
                            else:
                                truth[mapping[p] % i] = truth[mapping[p]][i]
                    if truth.get(mapping[p]) is not None:
                        del truth[mapping[p]]
                else:
                    try:
                        truth[mapping[p]] = truth[mapping[p]][0]
                    except KeyError:
                        pass
                    if convert_names:
                        temp_list.append([mapping[p], vals])
                    else:
                        temp_list.append([p, vals])
            except KeyError:
                self.logger.warning("Key error on %s" % p)

        # Replace generic systematic labels with the model's descriptive ones
        _, sys_labels = self.models[model_index].get_systematic_labels(self.simulations[simulation_index])
        for i in range(len(temp_list)):
            label, _ = temp_list[i]
            if r"\delta \mathcal{Z}" in label:
                n = int(label.split(r"\delta \mathcal{Z}_{")[1].split("}")[0])
                if n >= len(sys_labels):
                    temp_list[i][0] = r'$\delta [ Unknown ]$'
                else:
                    temp_list[i][0] = sys_labels[n]
        result = OrderedDict(temp_list)
        return (self.models[model_index], self.simulations[simulation_index], cosmo_index,
                result, truth, new_weight, stan_weight, posterior)
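    # Numeric illustration of the new_weight capping in get_result_from_chain
    # (made-up values): for log-weights [-1, 0, 1], std is about 0.816, so with
    # max_deviation = 2.5 every shifted value is already negative, nothing gets
    # clipped to 0, and the final weights exp(log_w - 2.041) are roughly
    # [0.048, 0.130, 0.353], all bounded above by exp(0) = 1.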
    def load(self, split_models=True, split_sims=True, split_cosmo=False, convert_names=True,
             max_deviation=2.5, squeeze=True):
        files = [f for f in os.listdir(self.temp_dir) if f.endswith(".pkl")]
        # Sort by the (model, sim, cosmo, walker) indexes encoded in the filenames
        files.sort(key=lambda s: list(map(int, s.replace(".", "_").split('_')[1:-1])))
        self.logger.debug("Found chain files: %s" % files)
        filenames = [self.temp_dir + "/" + f for f in files]
        model_indexes = [int(f.split("_")[1]) for f in files]
        sim_indexes = [int(f.split("_")[2]) for f in files]
        cosmo_indexes = [int(f.split("_")[3]) for f in files]
        chains = [self.load_file(f) for f in filenames]

        results = []
        prev_model, prev_sim, prev_cosmo = 0, 0, 0
        stacked = None
        for c, mi, si, ci in zip(chains, model_indexes, sim_indexes, cosmo_indexes):
            if (prev_cosmo != ci and split_cosmo) or (prev_model != mi and split_models) \
                    or (prev_sim != si and split_sims):
                if stacked is not None:
                    results.append(self.get_result_from_chain(stacked, prev_sim, prev_model, prev_cosmo,
                                                              convert_names=convert_names,
                                                              max_deviation=max_deviation))
                stacked = None
                prev_model = mi
                prev_sim = si
                prev_cosmo = ci
            if stacked is None:
                stacked = c
            else:
                for key in list(c.keys()):
                    stacked[key] = np.concatenate((stacked[key], c[key]))
        results.append(self.get_result_from_chain(stacked, si, mi, ci,
                                                  convert_names=convert_names,
                                                  max_deviation=max_deviation))
        if squeeze and len(results) == 1:
            return results[0]
        return results
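
# A minimal self-test sketch, an addition to this listing rather than part of the
# original module: it exercises only the pure index arithmetic, using strings as
# stand-in models and simulations since get_num_jobs() and get_indexes_from_index()
# only ever take their len().
if __name__ == "__main__":
    fitter = Fitter("temp_fitter_check")  # NB: __init__ creates this directory
    fitter.set_models("model_a", "model_b").set_simulations("sim_a")
    fitter.set_num_cosmologies(3).set_num_walkers(4)
    for index in range(fitter.get_num_jobs()):  # 2 * 1 * 3 * 4 = 24 jobs
        mi, si, ci, wi = fitter.get_indexes_from_index(index)
        assert 0 <= mi < 2 and si == 0 and 0 <= ci < 3 and 0 <= wi < 4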