Source code for queens.iterators.adaptive_sampling

#
# SPDX-License-Identifier: LGPL-3.0-or-later
# Copyright (c) 2024-2025, QUEENS contributors.
#
# This file is part of QUEENS.
#
# QUEENS is free software: you can redistribute it and/or modify it under the terms of the GNU
# Lesser General Public License as published by the Free Software Foundation, either version 3 of
# the License, or (at your option) any later version. QUEENS is distributed in the hope that it will
# be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. You
# should have received a copy of the GNU Lesser General Public License along with QUEENS. If not,
# see <https://www.gnu.org/licenses/>.
#
"""Adaptive sampling iterator."""

import logging
import pickle
import types

import jax
import jax.numpy as jnp
import numpy as np
from jax import jit

from queens.iterators._iterator import Iterator
from queens.iterators.sequential_monte_carlo_chopin import SequentialMonteCarloChopin
from queens.utils.io import load_result

_logger = logging.getLogger(__name__)
jax.config.update("jax_enable_x64", True)


[docs] class AdaptiveSampling(Iterator): """Adaptive sampling iterator. Attributes: likelihood_model (Model): Likelihood model (Only Gaussian Likelihood supported) solving_iterator (Iterator): Iterator to solve inverse problem (only SequentialMonteCarloChopin works out of the box) num_new_samples (int): Number of new training samples in each adaptive step num_steps (int): Number of adaptive sampling steps seed (int, opt): Seed for random number generation restart_file (str, opt): Result file path for restarts cs_div_criterion (float): Cauchy-Schwarz divergence stopping criterion threshold x_train (np.ndarray): Training input samples x_train_new (np.ndarray): Newly drawn training samples y_train (np.ndarray): Training likelihood output samples model_outputs (np.ndarray): Training model output samples """ def __init__( self, model, parameters, global_settings, likelihood_model, initial_train_samples, solving_iterator, num_new_samples, num_steps, seed=41, restart_file=None, cs_div_criterion=0.01, ): """Initialise AdaptiveSampling. Args: model (Model): Model to be evaluated by iterator. parameters (Parameters): Parameters object. global_settings (GlobalSettings): settings of the QUEENS experiment including its name and the output directory. likelihood_model (Model): Likelihood model (Only Gaussian Likelihood supported). initial_train_samples (np.ndarray): Initial training samples for surrogate model. solving_iterator (Iterator): Iterator to solve inverse problem (only SequentialMonteCarloChopin works out of the box). num_new_samples (int): Number of new training samples in each adaptive step. num_steps (int): Number of adaptive sampling steps. seed (int, opt): Seed for random number generation. restart_file (str, opt): Result file path for restarts. cs_div_criterion (float): Cauchy-Schwarz divergence stopping criterion threshold. """ super().__init__(model, parameters, global_settings) self.seed = seed self.likelihood_model = likelihood_model self.solving_iterator = solving_iterator self.num_new_samples = num_new_samples self.num_steps = num_steps self.restart_file = restart_file self.cs_div_criterion = cs_div_criterion self.x_train_new = initial_train_samples self.x_train = np.empty((0, self.parameters.num_parameters)) self.y_train = np.empty((0, 1)) self.model_outputs = np.empty((0, self.likelihood_model.y_obs.size))
[docs] def pre_run(self): """Pre run.""" np.random.seed(self.seed) if self.restart_file: results = load_result(self.restart_file) self.x_train = results["x_train"][-1] self.model_outputs = results["model_outputs"][-1] self.y_train = results["y_train"][-1] self.x_train_new = results["x_train_new"][-1]
[docs] def core_run(self): """Core run.""" for i in range(self.num_steps): _logger.info("Step: %i / %i", i + 1, self.num_steps) self.x_train = np.concatenate([self.x_train, self.x_train_new], axis=0) self.y_train = self.eval_log_likelihood().reshape(-1, 1) _logger.info("Number of solver evaluations: %i", self.x_train.shape[0]) self.model.initialize(self.x_train, self.y_train, self.likelihood_model.y_obs.size) random_state = np.random.get_state() self.solving_iterator.pre_run() # We don't want that the random seed is set here. np.random.set_state(random_state) def _m(self_, _, xp): x_train_ml = self.x_train[np.argmax(self.y_train[:, 0])] epn = xp.shared["exponents"][-1] target = self_.current_target(epn) particles = np.lib.recfunctions.structured_to_unstructured(xp.theta) if not (particles == x_train_ml).all(-1).any(): for j, par in enumerate(xp.theta.dtype.names): xp.theta[par][0] = x_train_ml[j] target(xp) return self_.move(xp, target) if isinstance(self.solving_iterator, SequentialMonteCarloChopin): self.solving_iterator.smc_obj.fk.M = types.MethodType( _m, self.solving_iterator.smc_obj.fk ) self.solving_iterator.core_run() particles, weights, log_posterior = self.solving_iterator.get_particles_and_weights() self.x_train_new = self.choose_new_samples(particles, weights) cs_div = self.write_results(particles, weights, log_posterior, i) if cs_div < self.cs_div_criterion: break
[docs] def eval_log_likelihood(self): """Evaluate log likelihood. Returns: log_likelihood (np.ndarray): Log likelihood """ model_output = self.likelihood_model.forward_model.evaluate(self.x_train_new)["result"] self.model_outputs = np.concatenate([self.model_outputs, model_output], axis=0) if self.likelihood_model.noise_type.startswith("MAP"): self.likelihood_model.update_covariance(model_output) log_likelihood = self.likelihood_model.normal_distribution.logpdf(self.model_outputs) log_likelihood -= self.likelihood_model.normal_distribution.logpdf_const return log_likelihood
[docs] def choose_new_samples(self, particles, weights): """Choose new training samples. Choose new training samples from approximated posterior distribution. Args: particles (np.ndarray): Unique particles of approximated posterior. weights (np.ndarray): Unique non-zero particle weights of approximated posterior. Returns: x_train_new (np.ndarray): New training samples """ # Filter particles, that are present in training sample set indices = (particles[:, np.newaxis] == self.x_train).all(-1).any(-1) particles = particles[~indices] weights = weights[~indices] weights /= np.sum(weights) if len(weights) == 0: _logger.warning( "Adaptive sampling of new training samples failed. " "Drawing new training samples from prior..." ) return self.parameters.draw_samples(self.num_new_samples) num_adaptive_samples = min(len(weights), self.num_new_samples) indices = np.random.choice( np.arange(len(weights)), num_adaptive_samples, p=weights, replace=False ) x_train_new = particles[indices] if num_adaptive_samples < self.num_new_samples: num_prior_samples = self.num_new_samples - num_adaptive_samples _logger.warning( "Adaptive sampling of new training samples partly failed. " "Drawing %i new training samples from prior...", num_prior_samples, ) prior_samples = self.parameters.draw_samples(num_prior_samples) x_train_new = np.concatenate([x_train_new, prior_samples], axis=0) return x_train_new
[docs] def write_results(self, particles, weights, log_posterior, iteration): """Write results to output file and calculate cs_div. Args: particles (np.ndarray): Particles of approximated posterior weights (np.ndarray): Particle weights of approximated posterior log_posterior (np.ndarray): Log posterior value of particles iteration (int): Iteration count Returns: cs_div (float): Maximum Cauchy-Schwarz divergence between marginals of the current and previous step """ result_file = self.global_settings.result_file(".pickle") if iteration == 0 and not self.restart_file: results = { "x_train": [], "model_outputs": [], "y_train": [], "x_train_new": [], "particles": [], "weights": [], "log_posterior": [], "cs_div": [], } cs_div = np.nan else: results = load_result(result_file) particles_prev = results["particles"][-1] weights_prev = results["weights"][-1] samples_prev = particles_prev[ np.random.choice(np.arange(weights_prev.size), 5_000, p=weights_prev) ] samples_curr = particles[np.random.choice(np.arange(weights.size), 5_000, p=weights)] cs_div = float(cauchy_schwarz_divergence(samples_prev, samples_curr)) _logger.info("Cauchy Schwarz divergence: %.2e", cs_div) results["x_train"].append(self.x_train) results["model_outputs"].append(self.model_outputs) results["y_train"].append(self.y_train) results["x_train_new"].append(self.x_train_new) results["particles"].append(particles) results["weights"].append(weights) results["log_posterior"].append(log_posterior) results["cs_div"].append(cs_div) with open(result_file, "wb") as handle: pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL) return cs_div
[docs] def post_run(self): """Post run."""
[docs] @jit def cauchy_schwarz_divergence(samples_1, samples_2): """Maximum Cauchy-Schwarz divergence between marginals of two sample sets. Args: samples_1 (np.ndarray): Sample set 1 samples_2 (np.ndarray): Sample set 2 Returns: cs_div_max (np.ndarray): Maximum Cauchy-Schwarz divergence between marginals of two sample sets. """ n_1 = samples_1.shape[0] n_2 = samples_2.shape[0] factor_1 = n_1 ** (-1.0 / 5) factor_2 = n_2 ** (-1.0 / 5) var_1 = jnp.var(samples_1, axis=0) * factor_1**2 var_2 = jnp.var(samples_2, axis=0) * factor_2**2 def normalizing_factor(variance): return (2 * jnp.pi * variance) ** (-1 / 2) def normal(x_1, x_2, variance): d = x_1[:, jnp.newaxis, :] - x_2[jnp.newaxis, :, :] norm = d**2 / variance return normalizing_factor(variance), -0.5 * norm z_1_2 = normal(samples_1, samples_2, var_1 + var_2) z_1_1 = normal(samples_1, samples_1, var_1 + var_1) z_1_1 = z_1_1[0] * jnp.exp(z_1_1[1]) z_2_2 = normal(samples_2, samples_2, var_2 + var_2) z_2_2 = z_2_2[0] * jnp.exp(z_2_2[1]) max_1_2 = jnp.max(z_1_2[1], axis=(0, 1)) term_1 = ( -jnp.log(jnp.sum(1 / n_1 * 1 / n_2 * z_1_2[0] * jnp.exp(z_1_2[1] - max_1_2), axis=(0, 1))) - max_1_2 ) term_2 = 0.5 * jnp.log( 1 / n_1 * normalizing_factor(var_1) + jnp.sum( 2 / n_1**2 * z_1_1 * jnp.tri(*z_1_1.shape[:2], k=-1)[:, :, jnp.newaxis], axis=(0, 1) ) ) term_3 = 0.5 * jnp.log( 1 / n_2 * normalizing_factor(var_2) + jnp.sum( 2 / n_2**2 * z_2_2 * jnp.tri(*z_1_1.shape[:2], k=-1)[:, :, jnp.newaxis], axis=(0, 1) ) ) cs_div_max = jnp.max(term_1 + term_2 + term_3) return cs_div_max