LINFA Tutorial 2 - Three dimensions

This LINFA tutorial guides you through the most common functionalities of LINFA with hands-on examples.

  • What is LINFA? LINFA is a library for variational inference with normalizing flow and adaptive annealing. LINFA accommodates computationally expensive models and difficult-to-sample posterior distributions with dependent parameters.

  • Why should I use LINFA? Designed as a general inference engine, LINFA allows the user to define custom input transformations, computational models, surrogates, and likelihood functions which will be discussed throughout the tutorial.

Tutorial outline

In this tutorial we will:

  1. Analyze and implement a physics-based model for a ballistic application.

  2. Generate a set of synthetic observations.

  3. Compute the model gradients using PyTorch and verify their correctness through a finite difference approximation.

  4. Perform inference tasks with LINFA:

    • Case 1: Variational inference with the original model.

    • Case 2: Variational inference with a lightweight neural network surrogate.

After going through this tutorial, you will be able to integrate your favorite physics-based model with LINFA to perform inference tasks.

In addition, we emphasize two special features available through LINFA:

  • Adaptively trained surrogate models (NoFAS module).

  • Adaptive annealing schedulers (AdaAnn module).

We encourage the user to take advantage of such features, especially when using physics-based models with computationally expensive evaluations.

Additional Resources

Background theory and examples for LINFA

More about LINFA library:

[1]:
! pip install linfa-vi
Requirement already satisfied: linfa-vi in /home/dschiava/.local/lib/python3.8/site-packages (1.1.10)
Requirement already satisfied: matplotlib==3.5.3 in /home/dschiava/.local/lib/python3.8/site-packages (from linfa-vi) (3.5.3)
Requirement already satisfied: torch==1.13.1 in /home/dschiava/.local/lib/python3.8/site-packages (from linfa-vi) (1.13.1)
Requirement already satisfied: tomli; python_version < "3.11" in /home/dschiava/.local/lib/python3.8/site-packages (from linfa-vi) (2.0.1)
Requirement already satisfied: numpy==1.19.5 in /home/dschiava/.local/lib/python3.8/site-packages (from linfa-vi) (1.19.5)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/lib/python3/dist-packages (from matplotlib==3.5.3->linfa-vi) (1.0.1)
Requirement already satisfied: pillow>=6.2.0 in /home/dschiava/.local/lib/python3.8/site-packages (from matplotlib==3.5.3->linfa-vi) (9.1.1)
Requirement already satisfied: fonttools>=4.22.0 in /home/dschiava/.local/lib/python3.8/site-packages (from matplotlib==3.5.3->linfa-vi) (4.33.3)
Requirement already satisfied: pyparsing>=2.2.1 in /usr/lib/python3/dist-packages (from matplotlib==3.5.3->linfa-vi) (2.4.6)
Requirement already satisfied: packaging>=20.0 in /home/dschiava/.local/lib/python3.8/site-packages (from matplotlib==3.5.3->linfa-vi) (21.3)
Requirement already satisfied: python-dateutil>=2.7 in /usr/lib/python3/dist-packages (from matplotlib==3.5.3->linfa-vi) (2.7.3)
Requirement already satisfied: cycler>=0.10 in /usr/lib/python3/dist-packages (from matplotlib==3.5.3->linfa-vi) (0.10.0)
Requirement already satisfied: typing-extensions in /home/dschiava/.local/lib/python3.8/site-packages (from torch==1.13.1->linfa-vi) (3.10.0.2)
Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66; platform_system == "Linux" in /home/dschiava/.local/lib/python3.8/site-packages (from torch==1.13.1->linfa-vi) (11.10.3.66)
Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99; platform_system == "Linux" in /home/dschiava/.local/lib/python3.8/site-packages (from torch==1.13.1->linfa-vi) (11.7.99)
Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == "Linux" in /home/dschiava/.local/lib/python3.8/site-packages (from torch==1.13.1->linfa-vi) (11.7.99)
Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96; platform_system == "Linux" in /home/dschiava/.local/lib/python3.8/site-packages (from torch==1.13.1->linfa-vi) (8.5.0.96)
Requirement already satisfied: setuptools in /home/dschiava/.local/lib/python3.8/site-packages (from nvidia-cublas-cu11==11.10.3.66; platform_system == "Linux"->torch==1.13.1->linfa-vi) (58.1.0)
Requirement already satisfied: wheel in /home/dschiava/.local/lib/python3.8/site-packages (from nvidia-cublas-cu11==11.10.3.66; platform_system == "Linux"->torch==1.13.1->linfa-vi) (0.37.0)
[2]:
## Import libraries ##
import os
import linfa
from linfa.run_experiment import experiment
from linfa.transform import Transformation
from linfa.nofas import Surrogate
import torch
import random
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

Problem definition

  • Our physics-based model phys consists of a simple ballistic model. It computes the outputs:

    • maximum height (m) \(x_{1}\),

    • final landing location (m) of the object \(x_{2}\),

    • total time of flight (s) \(x_{3}\),

    from the inputs:

    • starting position (m) \(z_{1}\),

    • initial velocity (m/s) \(z_{2}\),

    • angle (degrees) \(z_{3}\).

The model is described by the following equations

\[x_{1} = \frac{z_{2}^{2}\,\sin^{2}(z_{3})}{2\,g},\,\, x_{2} = z_{1} + \frac{z_{2}^{2}\,\sin(2\,z_{3})}{g},\,\, x_{3} = \frac{2\,z_{2}\,\sin(z_{3})}{g}.\]
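
As a quick sanity check, the three equations above can be evaluated directly. The following is a minimal sketch in plain Python, independent of LINFA and PyTorch; the function name `ballistic` is hypothetical, and the angle is assumed to be given in degrees and converted to radians before use, as done later in the tutorial.

```python
import math

G = 9.81  # gravitational acceleration (m/s^2), same value used in the tutorial


def ballistic(z1, z2, z3_deg):
    """Return (maximum height, landing location, time of flight)."""
    z3 = math.radians(z3_deg)                   # degrees -> radians
    x1 = z2**2 * math.sin(z3)**2 / (2.0 * G)    # maximum height (m)
    x2 = z1 + z2**2 * math.sin(2.0 * z3) / G    # landing location (m)
    x3 = 2.0 * z2 * math.sin(z3) / G            # time of flight (s)
    return x1, x2, x3


# Evaluate the model at z = (1 m, 5 m/s, 60 degrees)
x1, x2, x3 = ballistic(1.0, 5.0, 60.0)
print(f"x1 = {x1:.4f} m, x2 = {x2:.4f} m, x3 = {x3:.4f} s")
```

The same input is used as the true parameter set throughout the rest of the tutorial.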

Model identifiability

By considering fixed values for the outputs \((x_{1},x_{2},x_{3}) = (\widetilde{x}_{1},\widetilde{x}_{2},\widetilde{x}_{3})\), we can perform some algebraic manipulation to investigate structural identifiability. For example, if we derive \(z_{2}\) from the equation for \(x_{3}\) and we plug it back in the equation for \(x_{1}\), we get the equation

\[g\,\frac{\widetilde{x}_{3}^{2}}{8} = \widetilde{x}_{1}.\]

The maximum height and time of flight are, as expected, related by a deterministic condition, and therefore only one of them provides independent information for the solution of the inverse problem.

Due to this dependence, only two of the three observables are independent, while the model has three unknown inputs. This results in a non-identifiable inference task. In other words, there is an infinite number of input combinations \((z_{1},z_{2},z_{3})\) corresponding to the outputs \((\widetilde{x}_{1},\widetilde{x}_{2},\widetilde{x}_{3})\).
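
The dependence between maximum height and time of flight, \(x_{1} = g\,x_{3}^{2}/8\), follows from substituting \(z_{2}\sin(z_{3}) = g\,x_{3}/2\) into the expression for \(x_{1}\). It can also be checked numerically. The sketch below is illustrative only: it samples random velocities and angles (the sampling ranges are arbitrary assumptions) and measures the largest deviation from the relation.

```python
import math
import random

G = 9.81  # gravitational acceleration (m/s^2)


def height_and_time(z2, z3_deg):
    """Maximum height x1 and time of flight x3 for velocity z2 and angle z3."""
    z3 = math.radians(z3_deg)
    x1 = z2**2 * math.sin(z3)**2 / (2.0 * G)
    x3 = 2.0 * z2 * math.sin(z3) / G
    return x1, x3


rng = random.Random(0)  # fixed seed for reproducibility
max_gap = 0.0
for _ in range(1000):
    z2 = rng.uniform(0.5, 20.0)   # velocity range chosen arbitrarily
    z3 = rng.uniform(5.0, 85.0)   # angle range chosen arbitrarily
    x1, x3 = height_and_time(z2, z3)
    max_gap = max(max_gap, abs(x1 - G * x3**2 / 8.0))

print(f"largest |x1 - g*x3^2/8| over 1000 random inputs: {max_gap:.2e}")
```

The deviation stays at the level of floating-point round-off, which confirms that \(x_{1}\) and \(x_{3}\) carry the same information.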

A graphical explanation for this lack of identifiability is shown in the plot below.

Figure: Examples of trajectories resulting in the same landing distance and maximum height (or time of flight).

This picture shows how the final target location at \(x_{2}\) can be reached from multiple initial positions, velocities and angles. The lack of identifiability also translates into the existence of a one-dimensional manifold of inputs that correspond to the same outputs. This manifold can be determined from the following expressions in the form \(z_1(z_{3})\) and \(z_2(z_{3})\)

\[z_{1} = \widetilde{x}_{2} - \frac{g\cdot \widetilde{x}_{3}^{2}}{2}\cdot \left[\frac{\cos(z_{3})}{\sin(z_{3})}\right],\,\, z_{2} = \frac{g\cdot \widetilde{x}_{3}}{2\,\sin(z_{3})}.\]

These two curves are plotted below.

Figure: Two-dimensional projections of one-dimensional manifold where all parameters correspond to the same outputs. When performing inference we therefore expect the posterior distribution to be concentrated around such curves.
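
The manifold expressions \(z_{1}(z_{3})\) and \(z_{2}(z_{3})\) can be verified by picking a few angles, computing the corresponding starting position and velocity, and feeding them back into the forward model. The following is a minimal sketch under the assumption that the target outputs are those produced by \((1, 5, 60)\); the helper names `ballistic` and `manifold_point` and the chosen angles are hypothetical.

```python
import math

G = 9.81  # gravitational acceleration (m/s^2)


def ballistic(z1, z2, z3_deg):
    """Forward model: (maximum height, landing location, time of flight)."""
    z3 = math.radians(z3_deg)
    return (z2**2 * math.sin(z3)**2 / (2.0 * G),
            z1 + z2**2 * math.sin(2.0 * z3) / G,
            2.0 * z2 * math.sin(z3) / G)


# Target outputs, generated from the inputs (1 m, 5 m/s, 60 degrees)
target = ballistic(1.0, 5.0, 60.0)
_, x2_t, x3_t = target


def manifold_point(z3_deg):
    """Inputs (z1, z2, z3) on the manifold for a given angle in degrees."""
    z3 = math.radians(z3_deg)
    z1 = x2_t - 0.5 * G * x3_t**2 * math.cos(z3) / math.sin(z3)
    z2 = G * x3_t / (2.0 * math.sin(z3))
    return z1, z2, z3_deg


points = [manifold_point(a) for a in (40.0, 60.0, 80.0, 100.0, 140.0)]
for z in points:
    out = ballistic(*z)
    gap = max(abs(o - t) for o, t in zip(out, target))
    print(f"z = ({z[0]:7.3f}, {z[1]:6.3f}, {z[2]:5.1f})  max output gap = {gap:.1e}")
```

Every point on the curve reproduces the target outputs up to round-off, even though the inputs differ substantially.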

Implementation as a Python class

  • We first create a new Phys model class, having three member functions:

    • __init__ - A constructor.

    • genDataFile - A member function to create synthetic observations.

    • solve_t - A function to perform forward model evaluations.

Please refer to the comments below for additional implementation details.

[3]:
#### Implementation of the traditional trajectory motion physics problem ####
class Phys:

    ### Define constructor function for Phys class ###
    def __init__(self):
        ## Define input parameters (True value)
        # input[] = [starting_position, initial_velocity, angle] = [1(m), 5(m/s), 60(degs)]
        self.defParam = torch.Tensor([[1.0, 5.0, 60.0]])

        self.gConst = 9.81   # gravitational acceleration (m/s^2)
        self.stdRatio = 0.05 # standard deviation ratio
        self.data = None     # data set of model sample

    ### Define data file generator function ###
    # dataSize (int): size of sample (data)
    # dataFileName (String): name of the sample data file
    # store (Boolean): True if the user wishes to store the generated data file; False otherwise.
    def genDataFile(self, dataSize = 50, dataFileName="data_phys_3d.txt", store=True):
        def_out = self.solve_t(self.defParam)[0]
        self.data = def_out + self.stdRatio * torch.abs(def_out) * torch.normal(0, 1, size=(dataSize, 3))
        self.data = self.data.t().detach().numpy()
        if store: np.savetxt(dataFileName, self.data)
        return self.data

    ### Define forward model evaluation function ###
    # params (Tensor): input parameters storing starting position, initial velocity, and angle in corresponding order.
    def solve_t(self, params):
        z1, z2, z3 = torch.chunk(params, chunks=3, dim=1) # input parameters
        z3 = z3 * (np.pi / 180)                           # convert unit from degree to radians

        ## Output value calculation
        # output[] = [maximum_height, final_location, total_time]
        x = torch.cat(( (z2 * z2 * torch.sin(z3) * torch.sin(z3)) / (2.0 * self.gConst),  # x1: maxHeight
            z1 + ((z2 * z2 * torch.sin(2.0 * z3)) / self.gConst),                         # x2: finalLocation
            (2.0 * z2 * torch.sin(z3)) / self.gConst), 1)                                 # x3: totalTime
        return x

Generation of synthetic data

The genDataFile member function is designed to generate multiple synthetic outputs by adding Gaussian noise around the output corresponding to a default parameter set

\[\boldsymbol{z}^{*} = (1.0, 5.0, 60.0)\]

where the initial angle of the trajectory is measured in degrees. The following code generates 50 synthetic observations and stores them in the data_phys_3d.txt file.

[4]:
## Generate phys sample file ##
# Define model
model = Phys()
# Generate Data
physData = model.genDataFile()
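
The noise model used by genDataFile adds zero-mean Gaussian noise whose standard deviation is 5% of the magnitude of each noise-free output. The sketch below reproduces this idea with the Python standard library only, as an illustration rather than a replacement for the cell above; the seed and the helper name `ballistic` are assumptions.

```python
import math
import random

G = 9.81          # gravitational acceleration (m/s^2)
STD_RATIO = 0.05  # noise standard deviation relative to the output magnitude
N_OBS = 50        # number of synthetic observations


def ballistic(z1, z2, z3_deg):
    """Forward model: (maximum height, landing location, time of flight)."""
    z3 = math.radians(z3_deg)
    return (z2**2 * math.sin(z3)**2 / (2.0 * G),
            z1 + z2**2 * math.sin(2.0 * z3) / G,
            2.0 * z2 * math.sin(z3) / G)


x_star = ballistic(1.0, 5.0, 60.0)  # noise-free outputs at the true inputs
rng = random.Random(42)             # arbitrary seed, for reproducibility

# One row per output and one column per observation (3 x 50),
# the same layout as the array returned by genDataFile.
data = [[x + STD_RATIO * abs(x) * rng.gauss(0.0, 1.0) for _ in range(N_OBS)]
        for x in x_star]

means = [sum(row) / N_OBS for row in data]
for name, x, m in zip(("x1", "x2", "x3"), x_star, means):
    print(f"{name}: noise-free = {x:.4f}, sample mean = {m:.4f}")
```

The sample means should fall close to the noise-free outputs, with a spread of roughly 5% divided by the square root of the number of observations.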

Now that we have our model set up, we go on to our third step and check the computation of its gradient.

Check for Gradient Calculation

  • Prior to applying NoFAS to our Phys model, we check that the model gradient (more precisely the Jacobian, since the model has multiple outputs) is correctly computed by PyTorch.

  • Specifically, when the surrogate is not enabled, gradient calculation is completed straight through the model, so we want to ensure that this is correct before running an inference task.

  • Here we compute each gradient using (1) PyTorch and (2) a forward finite difference approximation, and compare the results provided by these two approaches.

Computing gradients through PyTorch

We define a new class to compute gradients. The class is constructed by specifying a model and a transformation, and provides member functions to compute the derivatives.

[5]:
#### Implementation of gradient calculation using PyTorch ####
class PytorchGrad:

    ### Define constructor function for PytorchGrad class ###
    def __init__(self, model, transform):
        # Define input parameters and enable gradient calculation
        self.z = torch.Tensor([[1.0, 5.0, 60.0]])
        self.z.requires_grad = True

        self.in_vals = transform.forward(self.z)

        self.out_val = model.solve_t(self.in_vals)
        self.out1, self.out2, self.out3 = torch.chunk(self.out_val, chunks=3, dim=1)

    # Compute gradients of each output using the backward function
    def back_x1(self):
        self.out1.backward()
        d1 = self.z.grad
        a, b, c = torch.chunk(d1, chunks=3, dim=1)
        return [a.item(), b.item(), c.item()]

    def back_x2(self):
        self.out2.backward()
        d2 = self.z.grad
        a, b, c = torch.chunk(d2, chunks=3, dim=1)
        return [a.item(), b.item(), c.item()]

    def back_x3(self):
        self.out3.backward()
        d3 = self.z.grad
        a, b, c = torch.chunk(d3, chunks=3, dim=1)
        return [a.item(), b.item(), c.item()]

We then use the class with the Phys model and an identity transformation as shown next.

[6]:
# Define Phys model
model = Phys()

# Set transformation information and define transformation
trsf_info = [['identity',0.0,0.0,0.0,0.0],
             ['identity',0,0.0,0.0,0.0],
             ['identity',0,0.0,0.0,0.0]]

transform = Transformation(trsf_info)

# List to store dx/dz values
dx_dz_pytorch = []

# Define PytorchGrad object and calculate gradient
pyGrad = PytorchGrad(model, transform)
dx_dz_pytorch.append(pyGrad.back_x1())

pyGrad = PytorchGrad(model, transform)
dx_dz_pytorch.append(pyGrad.back_x2())

pyGrad = PytorchGrad(model, transform)
dx_dz_pytorch.append(pyGrad.back_x3())

# convert to pandas DataFrame for readability
jacob_mat_2 = pd.DataFrame(dx_dz_pytorch, columns=['dz1', 'dz2', 'dz3'])
jacob_mat_2.index = ['dx1', 'dx2', 'dx3']
jacob_mat_2
[6]:
       dz1       dz2        dz3
dx1    0.0  0.382263   0.019260
dx2    1.0  0.882799  -0.044478
dx3    0.0  0.176560   0.008896
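
Because the model is available in closed form, the Jacobian can also be written out by hand and compared with the PyTorch values above. The sketch below differentiates the three model equations analytically; the factor \(\pi/180\) appears in the last column because the angle is supplied in degrees. The function name `jacobian` is hypothetical.

```python
import math

G = 9.81               # gravitational acceleration (m/s^2)
DEG = math.pi / 180.0  # radians per degree


def jacobian(z1, z2, z3_deg):
    """Closed-form Jacobian J[i][j] = d x_{i+1} / d z_{j+1}, angle in degrees."""
    z3 = z3_deg * DEG
    s, c = math.sin(z3), math.cos(z3)
    return [
        # x1 = z2^2 sin^2(z3) / (2 g)
        [0.0, z2 * s**2 / G, z2**2 * s * c / G * DEG],
        # x2 = z1 + z2^2 sin(2 z3) / g
        [1.0, 2.0 * z2 * math.sin(2.0 * z3) / G,
         2.0 * z2**2 * math.cos(2.0 * z3) / G * DEG],
        # x3 = 2 z2 sin(z3) / g
        [0.0, 2.0 * s / G, 2.0 * z2 * c / G * DEG],
    ]


J = jacobian(1.0, 5.0, 60.0)
for name, row in zip(("dx1", "dx2", "dx3"), J):
    print(name, "  ".join(f"{v: .6f}" for v in row))
```

The entries agree with the automatic differentiation results to the printed precision.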

Approximating gradients with finite differences

We now apply a forward finite difference approximation of the gradient to verify the results above.

[7]:
### Function that manually calculates a derivative ###
def getGrad(f_eps, f, eps):
    return (f_eps - f) / (eps)

### Function that returns a list of gradients ###
def gradList(f_eps1, f_eps2, f_eps3, f, eps):
    return [getGrad(f_eps1, f, eps).item(), getGrad(f_eps2, f, eps).item(), getGrad(f_eps3, f, eps).item()]
[8]:
# List to store dx/dz values
dx_dz = []
dx1_dz = []
dx2_dz = []
dx3_dz = []

# Set up parameters
eps = 1.0
z = torch.Tensor([[1.0, 5.0, 60.0]])
z_eps1 = torch.Tensor([[1.0 + eps, 5.0, 60.0]])
z_eps2 = torch.Tensor([[1.0, 5.0 + eps, 60.0]])
z_eps3 = torch.Tensor([[1.0, 5.0, 60.0 + eps]])

x1_eps1 = model.solve_t(z_eps1)[0,0]
x1_eps2 = model.solve_t(z_eps2)[0,0]
x1_eps3 = model.solve_t(z_eps3)[0,0]
x1_eps = model.solve_t(z)[0,0]

dx1_dz = gradList(x1_eps1, x1_eps2, x1_eps3, x1_eps, eps)
dx_dz.append(dx1_dz)

x2_eps1 = model.solve_t(z_eps1)[0,1]
x2_eps2 = model.solve_t(z_eps2)[0,1]
x2_eps3 = model.solve_t(z_eps3)[0,1]
x2_eps = model.solve_t(z)[0,1]

dx2_dz = gradList(x2_eps1, x2_eps2, x2_eps3, x2_eps, eps)
dx_dz.append(dx2_dz)

x3_eps1 = model.solve_t(z_eps1)[0,2]
x3_eps2 = model.solve_t(z_eps2)[0,2]
x3_eps3 = model.solve_t(z_eps3)[0,2]
x3_eps = model.solve_t(z)[0,2]

dx3_dz = gradList(x3_eps1, x3_eps2, x3_eps3, x3_eps, eps)
dx_dz.append(dx3_dz)

# convert to pandas DataFrame for readability
jacob_mat_3 = pd.DataFrame(dx_dz, columns=['dz1', 'dz2', 'dz3'])
jacob_mat_3.index = ['dx1', 'dx2', 'dx3']
jacob_mat_3
[8]:
       dz1       dz2        dz3
dx1    0.0  0.420489   0.019062
dx2    1.0  0.971078  -0.045814
dx3    0.0  0.176560   0.008761
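
With a step as large as eps = 1.0, the forward difference carries a truncation error proportional to eps, which explains the gap between this table and the PyTorch values. A central difference, whose error is proportional to eps squared, is a common alternative. The sketch below is not part of the original notebook: it computes both approximations with the same step, using the hypothetical helpers `ballistic` and `fd_jacobian`.

```python
import math

G = 9.81  # gravitational acceleration (m/s^2)


def ballistic(z):
    """Forward model for z = [z1, z2, z3 (degrees)]."""
    z1, z2, z3_deg = z
    z3 = math.radians(z3_deg)
    return [z2**2 * math.sin(z3)**2 / (2.0 * G),
            z1 + z2**2 * math.sin(2.0 * z3) / G,
            2.0 * z2 * math.sin(z3) / G]


def fd_jacobian(f, z, eps, central=False):
    """Finite difference Jacobian J[i][j] = d f_i / d z_j."""
    base = f(z)
    jac = [[0.0] * len(z) for _ in base]
    for j in range(len(z)):
        z_plus = list(z)
        z_plus[j] += eps
        f_plus = f(z_plus)
        if central:
            z_minus = list(z)
            z_minus[j] -= eps
            f_minus = f(z_minus)
            column = [(p - m) / (2.0 * eps) for p, m in zip(f_plus, f_minus)]
        else:
            column = [(p - b) / eps for p, b in zip(f_plus, base)]
        for i, value in enumerate(column):
            jac[i][j] = value
    return jac


z_star = [1.0, 5.0, 60.0]
J_fwd = fd_jacobian(ballistic, z_star, eps=1.0)
J_ctr = fd_jacobian(ballistic, z_star, eps=1.0, central=True)

for name, rf, rc in zip(("dx1", "dx2", "dx3"), J_fwd, J_ctr):
    print(name, "forward:", [f"{v:.6f}" for v in rf],
          "central:", [f"{v:.6f}" for v in rc])
```

For the same step size, the central difference lands much closer to the automatic differentiation result, at the cost of one extra model evaluation per input.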

Check the convergence of the finite difference approximation to the gradient

  • Note: if you’d like you can adjust the script below to check convergence for other components.

[9]:
## Focus: dx2_dz3

initial_eps = 15        # Initial change of value (eps)
k = 150                 # Number of iterations
dx2_dz3_list = []       # List to store results
pytorch_grad2 = -0.0445 # Pytorch gradient value

# Calculate for dx2_dz3 as eps decreases
for t in range(1, k):
    update_eps = initial_eps*(1/t)                             # updated eps value
    z_eps3 = torch.Tensor([[1.0, 5.0, 60.0 + update_eps]])     # update z_eps3
    x2_eps3 = model.solve_t(z_eps3)[0,1]                       # update x2_eps3
    dx2_dz3_list.append(getGrad(x2_eps3, x2_eps, update_eps))  # store result to dx2_dz3_list
[10]:
## Plot result to see convergence
plt.style.use('dark_background')

fig, ax = plt.subplots()
ax.plot(range(1,k), dx2_dz3_list, c = "red", linestyle = "solid", label = "FD Approximation")

plt.axhline(y = pytorch_grad2, color = 'blue', linestyle = '-', label = "Pytorch Gradient")
plt.legend(loc="lower right")
plt.title("Gradient Plot for dx2_dz3")
plt.ylabel("Gradient")
plt.xlabel("k-Iterations")
plt.show()
Figure: Forward finite difference approximation of dx2/dz3 (red) compared with the PyTorch gradient (blue) as a function of the iteration k.
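
The convergence seen in the plot can be quantified: for a forward difference, the error is expected to shrink roughly linearly with the step, so halving eps should approximately halve the error. The following sketch checks this for dx2/dz3, using the closed-form derivative as a reference and the same eps = 15/t schedule, evaluated only at a few values of t that double each time. The function name `x2` and the selected values of t are assumptions.

```python
import math

G = 9.81  # gravitational acceleration (m/s^2)


def x2(z3_deg, z1=1.0, z2=5.0):
    """Landing location as a function of the angle in degrees."""
    return z1 + z2**2 * math.sin(2.0 * math.radians(z3_deg)) / G


# Closed-form derivative of x2 with respect to the angle (per degree) at 60 degrees
exact = 2.0 * 5.0**2 * math.cos(math.radians(120.0)) / G * math.pi / 180.0

errors = []
for t in (1, 2, 4, 8, 16, 32):
    eps = 15.0 / t  # same schedule as in the tutorial loop
    fd = (x2(60.0 + eps) - x2(60.0)) / eps
    errors.append(abs(fd - exact))
    print(f"eps = {eps:8.4f}  FD = {fd: .6f}  |error| = {errors[-1]:.2e}")

# Ratio of consecutive errors; values near 2 indicate first-order convergence
ratios = [errors[i] / errors[i + 1] for i in range(len(errors) - 1)]
print("error ratios:", [f"{r:.3f}" for r in ratios])
```

The ratios approach 2 as eps decreases, which is the signature of a first-order accurate approximation.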

Now that we have confirmed that the model gradients are computed correctly, we go on to our fourth step: setting up and running the inference tasks.

Variational inference with full model

Definition of hyperparameters

The first step is to define all options and hyperparameters for the inference task. Additional detail for each hyperparameter can be found in the documentation or in the definition of the experiment class.

[11]:
# Experiment Setting
exp = experiment()
exp.flow_type        = 'maf'        # str: Type of flow
exp.n_blocks         = 5            # int: Number of layers
exp.hidden_size      = 100          # int: Hidden layer size for MADE in each layer
exp.n_hidden         = 1            # int: Number of hidden layers in each MADE
exp.activation_fn    = 'relu'       # str: Activation function used
exp.input_order      = 'sequential' # str: Input order for create_mask
exp.batch_norm_order = True         # boolean: Order to decide if batch_norm is used
exp.save_interval    = 5000         # int: How often to sample from normalizing flow

exp.input_size    = 3               # int: Dimensionality of input
exp.batch_size    = 250             # int: Number of samples generated
exp.true_data_num = 2               # int: Number of true model evaluations
exp.n_iter        = 25001           # int: Number of iterations
exp.lr            = 0.01            # float: Learning rate
exp.lr_decay      = 0.9999          # float: Learning rate decay
exp.log_interval  = 100             # int: How often to show loss stat

exp.run_nofas          = False      # boolean: to run experiment with nofas
exp.annealing          = False      # boolean: to run experiment with annealing
exp.calibrate_interval = 1000       # int: How often to update surrogate model
exp.budget             = 260        # int: Total number of true model evaluation

exp.surr_pre_it  = 20000            # int: Number of pre-training iterations for surrogate model
exp.surr_upd_it  = 6000             # int: Number of iterations for the surrogate model update
exp.surr_folder  = "./"
exp.use_new_surr = True             # boolean: to train a new surrogate rather than reuse an existing one

exp.results_file = 'results.txt'      # str: result text file name
exp.log_file     = 'log.txt'          # str: log text file name
exp.samples_file = 'samples.txt'      # str: sample text file name
exp.seed         = random.randint(0, 10 ** 9)  # int: Random seed used
exp.n_sample     = 5000               # int: Number of samples generated for the final results
exp.no_cuda      = True               # boolean: to run experiment with NO cuda

exp.optimizer    = 'RMSprop'          # str: Type of optimizer
exp.lr_scheduler = 'ExponentialLR'    # str: Type of scheduler

exp.device = torch.device('cuda:0' if torch.cuda.is_available() and not exp.no_cuda else 'cpu')
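
To get a feel for the combination of lr, lr_decay and n_iter above, it helps to compute the learning rate that an exponential schedule produces over the run. The sketch below assumes that the scheduler multiplies the learning rate by lr_decay once per iteration; this is an assumption about how the schedule is stepped, not something confirmed by the tutorial. The function name `lr_at` is hypothetical.

```python
LR0 = 0.01      # initial learning rate (exp.lr)
DECAY = 0.9999  # multiplicative decay factor (exp.lr_decay)


def lr_at(iteration):
    """Learning rate after `iteration` decay steps (assumes one step per iteration)."""
    return LR0 * DECAY**iteration


for it in (0, 5000, 10000, 15000, 20000, 25000):
    print(f"iteration {it:6d}: lr = {lr_at(it):.3e}")
```

Under this assumption the learning rate drops by a bit more than one order of magnitude over the 25000 iterations, which is consistent with a loss that settles as training proceeds.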

Define the transformation

Now we define the transformation of the parameters and assign it to the experiment.

[12]:
# Define transformation based on normalization rate
trsf_info = [['identity',0.0,0.0,0.0,0.0],
             ['identity',0.0,0.0,0.0,0.0],
             ['linear',-3,3,40.0,140.0]]
trsf = Transformation(trsf_info)
exp.transform = trsf
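
The third entry above, ['linear',-3,3,40.0,140.0], applies to the angle. A plausible reading, stated here as an assumption to be checked against the LINFA documentation, is that each entry has the form [type, a, b, c, d] and that a linear transformation maps the interval [a, b] affinely onto [c, d]. The sketch below illustrates this reading and the corresponding constant log-Jacobian; `linear_map` and `linear_log_jacobian` are hypothetical helpers, not LINFA functions.

```python
import math


def linear_map(x, a, b, c, d):
    """Affine map sending [a, b] onto [c, d] (assumed meaning of 'linear')."""
    return (x - a) / (b - a) * (d - c) + c


def linear_log_jacobian(a, b, c, d):
    """Log of the (constant) derivative of the affine map above."""
    return math.log((d - c) / (b - a))


a, b, c, d = -3.0, 3.0, 40.0, 140.0
for x in (-3.0, 0.0, 3.0):
    print(f"normalized value {x: .1f} -> angle {linear_map(x, a, b, c, d):6.1f} degrees")
print(f"log-Jacobian: {linear_log_jacobian(a, b, c, d):.4f}")
```

Under this reading, samples produced by the normalizing flow in roughly [-3, 3] are rescaled to angles between 40 and 140 degrees before the model is evaluated.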

Model and surrogate definition

We create an instance of the Phys model and assign None to the surrogate. Note that we have also specified exp.run_nofas = False and exp.annealing = False to switch off both the adaptive surrogate capability and annealing.

[13]:
# Define model
model = Phys()
exp.model = model

# Get data
model.data = np.loadtxt('./data_phys_3d.txt')

# Run experiment without surrogate
exp.surrogate = None

Log-likelihood definition

[14]:
## Define log density
# x: original, untransformed inputs
# model: our model
# surrogate: our surrogate model (None if not used)
# transform: our transformation
def log_density(x, model, surrogate, transform):

    # Compute transformation log Jacobian
    adjust = transform.compute_log_jacob_func(x)

    # Get the absolute values of the standard deviations
    stds = torch.abs(model.solve_t(model.defParam)) * model.stdRatio
    Data = torch.tensor(model.data).to(exp.device)

    # Check for surrogate
    if surrogate:
        modelOut = surrogate.forward(x)
    else:
        modelOut = model.solve_t(transform.forward(x))

    # Eval LL
    ll1 = -0.5 * np.prod(model.data.shape) * np.log(2.0 * np.pi)
    ll2 = (-0.5 * model.data.shape[1] * torch.log(torch.prod(stds))).item()
    ll3 = 0.0
    for i in range(3):
        ll3 += - 0.5 * torch.sum(((modelOut[:, i].unsqueeze(1) - Data[i, :].unsqueeze(0)) / stds[0, i]) ** 2, dim=1)
    negLL = -(ll1 + ll2 + ll3)
    res = -negLL.reshape(x.size(0), 1) + adjust

    return res
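
The core of the log density is an independent Gaussian likelihood for each output, with standard deviation equal to 5% of the noise-free output. The sketch below restates that likelihood in plain Python, using the standard Gaussian normalization constants (which do not depend on the inputs), and uses it to illustrate non-identifiability: two different inputs that produce the same outputs receive the same likelihood. The seed, the helper names and the chosen test points are assumptions.

```python
import math
import random

G = 9.81          # gravitational acceleration (m/s^2)
STD_RATIO = 0.05  # noise standard deviation relative to the output magnitude


def ballistic(z1, z2, z3_deg):
    """Forward model: (maximum height, landing location, time of flight)."""
    z3 = math.radians(z3_deg)
    return (z2**2 * math.sin(z3)**2 / (2.0 * G),
            z1 + z2**2 * math.sin(2.0 * z3) / G,
            2.0 * z2 * math.sin(z3) / G)


x_star = ballistic(1.0, 5.0, 60.0)
stds = [STD_RATIO * abs(x) for x in x_star]
rng = random.Random(0)  # arbitrary seed
data = [[x + s * rng.gauss(0.0, 1.0) for _ in range(50)]
        for x, s in zip(x_star, stds)]


def log_likelihood(z):
    """Gaussian log-likelihood of the synthetic data for inputs z."""
    out = ballistic(*z)
    n_obs = len(data[0])
    ll = 0.0
    for o, s, row in zip(out, stds, data):
        ll += -0.5 * n_obs * math.log(2.0 * math.pi) - n_obs * math.log(s)
        ll += -0.5 * sum(((o - d) / s) ** 2 for d in row)
    return ll


# A second input with a 45 degree angle that yields the same outputs as (1, 5, 60)
ang = math.radians(45.0)
z_alt = (x_star[1] - 0.5 * G * x_star[2]**2 * math.cos(ang) / math.sin(ang),
         G * x_star[2] / (2.0 * math.sin(ang)),
         45.0)

ll_true = log_likelihood((1.0, 5.0, 60.0))
ll_alt = log_likelihood(z_alt)
ll_off = log_likelihood((1.0, 6.0, 60.0))  # perturbed velocity, different outputs
print(f"true inputs: {ll_true:.3f}, alternative inputs: {ll_alt:.3f}, "
      f"perturbed velocity: {ll_off:.3f}")
```

The first two values coincide, while the perturbed input is strongly penalized, which is why the posterior is expected to concentrate along a curve rather than around a single point.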

Launch inference task

[15]:
## Run
print('')
print('--- TUTORIAL: Ballistic Example - Full model')

# Experiment Setting
exp.name = "phys_full_3d"             # str: Name of experiment
exp.output_dir   = './' + exp.name    # str: output directory location

# Assign logdensity
exp.model_logdensity = lambda x: log_density(x, model, exp.surrogate, trsf)

# Run VI
exp.run()

--- TUTORIAL: Ballistic Example - Full model

--- Running on device: cpu

VI NF (t=1.000): it:     100 | loss: 9.888e+03
VI NF (t=1.000): it:     200 | loss: 4.040e+03
VI NF (t=1.000): it:     300 | loss: 7.229e+02
VI NF (t=1.000): it:     400 | loss: 4.176e+02
VI NF (t=1.000): it:     500 | loss: 2.605e+02
VI NF (t=1.000): it:     600 | loss: 3.275e+02
VI NF (t=1.000): it:     700 | loss: 2.749e+02
VI NF (t=1.000): it:     800 | loss: 4.634e+02
VI NF (t=1.000): it:     900 | loss: 2.928e+02
VI NF (t=1.000): it:    1000 | loss: 1.815e+02
VI NF (t=1.000): it:    1100 | loss: 1.633e+02
VI NF (t=1.000): it:    1200 | loss: 2.868e+02
VI NF (t=1.000): it:    1300 | loss: 1.306e+02
VI NF (t=1.000): it:    1400 | loss: 2.091e+02
VI NF (t=1.000): it:    1500 | loss: 1.541e+02
VI NF (t=1.000): it:    1600 | loss: 1.281e+02
VI NF (t=1.000): it:    1700 | loss: 1.794e+02
VI NF (t=1.000): it:    1800 | loss: 1.748e+02
VI NF (t=1.000): it:    1900 | loss: 1.770e+02
VI NF (t=1.000): it:    2000 | loss: 1.271e+02
VI NF (t=1.000): it:    2100 | loss: 1.270e+02
VI NF (t=1.000): it:    2200 | loss: 8.753e+01
VI NF (t=1.000): it:    2300 | loss: 1.310e+02
VI NF (t=1.000): it:    2400 | loss: 2.871e+02
VI NF (t=1.000): it:    2500 | loss: 1.230e+02
VI NF (t=1.000): it:    2600 | loss: 1.586e+02
VI NF (t=1.000): it:    2700 | loss: 1.758e+02
VI NF (t=1.000): it:    2800 | loss: 1.233e+02
VI NF (t=1.000): it:    2900 | loss: 1.556e+02
VI NF (t=1.000): it:    3000 | loss: 1.827e+02
VI NF (t=1.000): it:    3100 | loss: 7.177e+01
VI NF (t=1.000): it:    3200 | loss: 1.189e+02
VI NF (t=1.000): it:    3300 | loss: 6.993e+01
VI NF (t=1.000): it:    3400 | loss: 6.559e+01
VI NF (t=1.000): it:    3500 | loss: 4.154e+01
VI NF (t=1.000): it:    3600 | loss: 4.389e+01
VI NF (t=1.000): it:    3700 | loss: 3.308e+01
VI NF (t=1.000): it:    3800 | loss: 3.680e+01
VI NF (t=1.000): it:    3900 | loss: 2.924e+01
VI NF (t=1.000): it:    4000 | loss: 2.567e+01
VI NF (t=1.000): it:    4100 | loss: 6.522e+01
VI NF (t=1.000): it:    4200 | loss: 1.913e+01
VI NF (t=1.000): it:    4300 | loss: 1.717e+01
VI NF (t=1.000): it:    4400 | loss: 1.692e+01
VI NF (t=1.000): it:    4500 | loss: 1.646e+01
VI NF (t=1.000): it:    4600 | loss: 1.614e+01
VI NF (t=1.000): it:    4700 | loss: 1.652e+01
VI NF (t=1.000): it:    4800 | loss: 1.680e+01
VI NF (t=1.000): it:    4900 | loss: 1.582e+01
--- Saving results at iteration 5000
VI NF (t=1.000): it:    5000 | loss: 1.567e+01
VI NF (t=1.000): it:    5100 | loss: 1.589e+01
VI NF (t=1.000): it:    5200 | loss: 1.559e+01
VI NF (t=1.000): it:    5300 | loss: 1.617e+01
VI NF (t=1.000): it:    5400 | loss: 1.537e+01
VI NF (t=1.000): it:    5500 | loss: 1.512e+01
VI NF (t=1.000): it:    5600 | loss: 1.540e+01
VI NF (t=1.000): it:    5700 | loss: 1.568e+01
VI NF (t=1.000): it:    5800 | loss: 1.521e+01
VI NF (t=1.000): it:    5900 | loss: 1.521e+01
VI NF (t=1.000): it:    6000 | loss: 1.512e+01
VI NF (t=1.000): it:    6100 | loss: 1.527e+01
VI NF (t=1.000): it:    6200 | loss: 1.507e+01
VI NF (t=1.000): it:    6300 | loss: 2.290e+01
VI NF (t=1.000): it:    6400 | loss: 1.513e+01
VI NF (t=1.000): it:    6500 | loss: 1.496e+01
VI NF (t=1.000): it:    6600 | loss: 1.546e+01
VI NF (t=1.000): it:    6700 | loss: 1.512e+01
VI NF (t=1.000): it:    6800 | loss: 1.495e+01
VI NF (t=1.000): it:    6900 | loss: 1.518e+01
VI NF (t=1.000): it:    7000 | loss: 1.477e+01
VI NF (t=1.000): it:    7100 | loss: 1.569e+01
VI NF (t=1.000): it:    7200 | loss: 1.501e+01
VI NF (t=1.000): it:    7300 | loss: 1.495e+01
VI NF (t=1.000): it:    7400 | loss: 1.473e+01
VI NF (t=1.000): it:    7500 | loss: 1.495e+01
VI NF (t=1.000): it:    7600 | loss: 1.495e+01
VI NF (t=1.000): it:    7700 | loss: 1.517e+01
VI NF (t=1.000): it:    7800 | loss: 1.489e+01
VI NF (t=1.000): it:    7900 | loss: 1.486e+01
VI NF (t=1.000): it:    8000 | loss: 1.480e+01
VI NF (t=1.000): it:    8100 | loss: 1.477e+01
VI NF (t=1.000): it:    8200 | loss: 1.480e+01
VI NF (t=1.000): it:    8300 | loss: 1.514e+01
VI NF (t=1.000): it:    8400 | loss: 1.473e+01
VI NF (t=1.000): it:    8500 | loss: 1.677e+01
VI NF (t=1.000): it:    8600 | loss: 1.479e+01
VI NF (t=1.000): it:    8700 | loss: 1.483e+01
VI NF (t=1.000): it:    8800 | loss: 1.469e+01
VI NF (t=1.000): it:    8900 | loss: 1.482e+01
VI NF (t=1.000): it:    9000 | loss: 1.462e+01
VI NF (t=1.000): it:    9100 | loss: 1.461e+01
VI NF (t=1.000): it:    9200 | loss: 1.473e+01
VI NF (t=1.000): it:    9300 | loss: 1.491e+01
VI NF (t=1.000): it:    9400 | loss: 1.454e+01
VI NF (t=1.000): it:    9500 | loss: 1.477e+01
VI NF (t=1.000): it:    9600 | loss: 1.445e+01
VI NF (t=1.000): it:    9700 | loss: 1.476e+01
VI NF (t=1.000): it:    9800 | loss: 1.480e+01
VI NF (t=1.000): it:    9900 | loss: 1.473e+01
--- Saving results at iteration 10000
VI NF (t=1.000): it:   10000 | loss: 1.469e+01
VI NF (t=1.000): it:   10100 | loss: 1.457e+01
VI NF (t=1.000): it:   10200 | loss: 1.524e+01
VI NF (t=1.000): it:   10300 | loss: 1.455e+01
VI NF (t=1.000): it:   10400 | loss: 1.484e+01
VI NF (t=1.000): it:   10500 | loss: 1.487e+01
VI NF (t=1.000): it:   10600 | loss: 1.469e+01
VI NF (t=1.000): it:   10700 | loss: 1.451e+01
VI NF (t=1.000): it:   10800 | loss: 1.464e+01
VI NF (t=1.000): it:   10900 | loss: 1.451e+01
VI NF (t=1.000): it:   11000 | loss: 1.486e+01
VI NF (t=1.000): it:   11100 | loss: 1.473e+01
VI NF (t=1.000): it:   11200 | loss: 1.453e+01
VI NF (t=1.000): it:   11300 | loss: 1.479e+01
VI NF (t=1.000): it:   11400 | loss: 1.494e+01
VI NF (t=1.000): it:   11500 | loss: 1.492e+01
VI NF (t=1.000): it:   11600 | loss: 1.470e+01
VI NF (t=1.000): it:   11700 | loss: 1.480e+01
VI NF (t=1.000): it:   11800 | loss: 1.456e+01
VI NF (t=1.000): it:   11900 | loss: 1.463e+01
VI NF (t=1.000): it:   12000 | loss: 1.452e+01
VI NF (t=1.000): it:   12100 | loss: 1.452e+01
VI NF (t=1.000): it:   12200 | loss: 1.451e+01
VI NF (t=1.000): it:   12300 | loss: 1.437e+01
VI NF (t=1.000): it:   12400 | loss: 1.461e+01
VI NF (t=1.000): it:   12500 | loss: 1.475e+01
VI NF (t=1.000): it:   12600 | loss: 1.444e+01
VI NF (t=1.000): it:   12700 | loss: 1.497e+01
VI NF (t=1.000): it:   12800 | loss: 1.448e+01
VI NF (t=1.000): it:   12900 | loss: 1.447e+01
VI NF (t=1.000): it:   13000 | loss: 1.449e+01
VI NF (t=1.000): it:   13100 | loss: 1.438e+01
VI NF (t=1.000): it:   13200 | loss: 1.470e+01
VI NF (t=1.000): it:   13300 | loss: 1.456e+01
VI NF (t=1.000): it:   13400 | loss: 1.448e+01
VI NF (t=1.000): it:   13500 | loss: 1.440e+01
VI NF (t=1.000): it:   13600 | loss: 1.488e+01
VI NF (t=1.000): it:   13700 | loss: 1.498e+01
VI NF (t=1.000): it:   13800 | loss: 1.466e+01
VI NF (t=1.000): it:   13900 | loss: 1.448e+01
VI NF (t=1.000): it:   14000 | loss: 1.476e+01
VI NF (t=1.000): it:   14100 | loss: 1.480e+01
VI NF (t=1.000): it:   14200 | loss: 1.451e+01
VI NF (t=1.000): it:   14300 | loss: 1.445e+01
VI NF (t=1.000): it:   14400 | loss: 1.443e+01
VI NF (t=1.000): it:   14500 | loss: 1.450e+01
VI NF (t=1.000): it:   14600 | loss: 1.448e+01
VI NF (t=1.000): it:   14700 | loss: 1.501e+01
VI NF (t=1.000): it:   14800 | loss: 1.516e+01
VI NF (t=1.000): it:   14900 | loss: 1.464e+01
--- Saving results at iteration 15000
VI NF (t=1.000): it:   15000 | loss: 1.438e+01
VI NF (t=1.000): it:   15100 | loss: 1.434e+01
VI NF (t=1.000): it:   15200 | loss: 1.457e+01
VI NF (t=1.000): it:   15300 | loss: 1.491e+01
VI NF (t=1.000): it:   15400 | loss: 1.447e+01
VI NF (t=1.000): it:   15500 | loss: 1.454e+01
VI NF (t=1.000): it:   15600 | loss: 1.523e+01
VI NF (t=1.000): it:   15700 | loss: 1.456e+01
VI NF (t=1.000): it:   15800 | loss: 1.454e+01
VI NF (t=1.000): it:   15900 | loss: 1.494e+01
VI NF (t=1.000): it:   16000 | loss: 1.438e+01
VI NF (t=1.000): it:   16100 | loss: 1.458e+01
VI NF (t=1.000): it:   16200 | loss: 1.423e+01
VI NF (t=1.000): it:   16300 | loss: 1.461e+01
VI NF (t=1.000): it:   16400 | loss: 1.463e+01
VI NF (t=1.000): it:   16500 | loss: 1.455e+01
VI NF (t=1.000): it:   16600 | loss: 1.455e+01
VI NF (t=1.000): it:   16700 | loss: 1.435e+01
VI NF (t=1.000): it:   16800 | loss: 1.452e+01
VI NF (t=1.000): it:   16900 | loss: 1.463e+01
VI NF (t=1.000): it:   17000 | loss: 1.444e+01
VI NF (t=1.000): it:   17100 | loss: 1.443e+01
VI NF (t=1.000): it:   17200 | loss: 1.445e+01
VI NF (t=1.000): it:   17300 | loss: 1.425e+01
VI NF (t=1.000): it:   17400 | loss: 1.432e+01
VI NF (t=1.000): it:   17500 | loss: 1.442e+01
VI NF (t=1.000): it:   17600 | loss: 1.452e+01
VI NF (t=1.000): it:   17700 | loss: 1.448e+01
VI NF (t=1.000): it:   17800 | loss: 1.456e+01
VI NF (t=1.000): it:   17900 | loss: 1.463e+01
VI NF (t=1.000): it:   18000 | loss: 1.443e+01
VI NF (t=1.000): it:   18100 | loss: 1.452e+01
VI NF (t=1.000): it:   18200 | loss: 1.435e+01
VI NF (t=1.000): it:   18300 | loss: 1.463e+01
VI NF (t=1.000): it:   18400 | loss: 1.432e+01
VI NF (t=1.000): it:   18500 | loss: 1.456e+01
VI NF (t=1.000): it:   18600 | loss: 1.468e+01
VI NF (t=1.000): it:   18700 | loss: 1.444e+01
VI NF (t=1.000): it:   18800 | loss: 1.433e+01
VI NF (t=1.000): it:   18900 | loss: 1.454e+01
VI NF (t=1.000): it:   19000 | loss: 1.463e+01
VI NF (t=1.000): it:   19100 | loss: 1.455e+01
VI NF (t=1.000): it:   19200 | loss: 1.444e+01
VI NF (t=1.000): it:   19300 | loss: 1.436e+01
VI NF (t=1.000): it:   19400 | loss: 1.450e+01
VI NF (t=1.000): it:   19500 | loss: 1.479e+01
VI NF (t=1.000): it:   19600 | loss: 1.438e+01
VI NF (t=1.000): it:   19700 | loss: 1.452e+01
VI NF (t=1.000): it:   19800 | loss: 1.516e+01
VI NF (t=1.000): it:   19900 | loss: 1.444e+01
--- Saving results at iteration 20000
VI NF (t=1.000): it:   20000 | loss: 1.429e+01
VI NF (t=1.000): it:   20100 | loss: 1.427e+01
VI NF (t=1.000): it:   20200 | loss: 1.458e+01
VI NF (t=1.000): it:   20300 | loss: 1.437e+01
VI NF (t=1.000): it:   20400 | loss: 1.501e+01
VI NF (t=1.000): it:   20500 | loss: 1.457e+01
VI NF (t=1.000): it:   20600 | loss: 1.435e+01
VI NF (t=1.000): it:   20700 | loss: 1.433e+01
VI NF (t=1.000): it:   20800 | loss: 1.477e+01
VI NF (t=1.000): it:   20900 | loss: 1.476e+01
VI NF (t=1.000): it:   21000 | loss: 1.453e+01
VI NF (t=1.000): it:   21100 | loss: 1.437e+01
VI NF (t=1.000): it:   21200 | loss: 1.454e+01
VI NF (t=1.000): it:   21300 | loss: 1.437e+01
VI NF (t=1.000): it:   21400 | loss: 1.439e+01
VI NF (t=1.000): it:   21500 | loss: 1.479e+01
VI NF (t=1.000): it:   21600 | loss: 1.445e+01
VI NF (t=1.000): it:   21700 | loss: 1.420e+01
VI NF (t=1.000): it:   21800 | loss: 1.458e+01
VI NF (t=1.000): it:   21900 | loss: 1.431e+01
VI NF (t=1.000): it:   22000 | loss: 1.480e+01
VI NF (t=1.000): it:   22100 | loss: 1.436e+01
VI NF (t=1.000): it:   22200 | loss: 1.434e+01
VI NF (t=1.000): it:   22300 | loss: 1.464e+01
VI NF (t=1.000): it:   22400 | loss: 1.441e+01
VI NF (t=1.000): it:   22500 | loss: 1.426e+01
VI NF (t=1.000): it:   22600 | loss: 1.429e+01
VI NF (t=1.000): it:   22700 | loss: 1.435e+01
VI NF (t=1.000): it:   22800 | loss: 1.443e+01
VI NF (t=1.000): it:   22900 | loss: 1.506e+01
VI NF (t=1.000): it:   23000 | loss: 1.455e+01
VI NF (t=1.000): it:   23100 | loss: 1.441e+01
VI NF (t=1.000): it:   23200 | loss: 1.454e+01
VI NF (t=1.000): it:   23300 | loss: 1.448e+01
VI NF (t=1.000): it:   23400 | loss: 1.432e+01
VI NF (t=1.000): it:   23500 | loss: 1.432e+01
VI NF (t=1.000): it:   23600 | loss: 1.425e+01
VI NF (t=1.000): it:   23700 | loss: 1.430e+01
VI NF (t=1.000): it:   23800 | loss: 1.431e+01
VI NF (t=1.000): it:   23900 | loss: 1.424e+01
VI NF (t=1.000): it:   24000 | loss: 1.442e+01
VI NF (t=1.000): it:   24100 | loss: 1.443e+01
VI NF (t=1.000): it:   24200 | loss: 1.438e+01
VI NF (t=1.000): it:   24300 | loss: 1.437e+01
VI NF (t=1.000): it:   24400 | loss: 1.441e+01
VI NF (t=1.000): it:   24500 | loss: 1.435e+01
VI NF (t=1.000): it:   24600 | loss: 1.474e+01
VI NF (t=1.000): it:   24700 | loss: 1.440e+01
VI NF (t=1.000): it:   24800 | loss: 1.421e+01
VI NF (t=1.000): it:   24900 | loss: 1.434e+01
--- Saving results at iteration 25000
VI NF (t=1.000): it:   25000 | loss: 1.427e+01

--- Simulation completed!

You can confirm that the run completed successfully by checking the newly created phys_full_3d folder in the tutorial directory.

Note also that LINFA provides a post-processing script to plot the main results, including the loss profile, two-dimensional slices of the posterior distribution, and two-dimensional slices of the posterior predictive distribution.

We can use the command below to generate the result plots:

[16]:
import linfa
! python3 -m linfa.plot_res -n phys_full_3d -i 25000 -f "./" -p 'png' -d
Plotting log...
Plotting posterior samples...
Plotting posterior predictive samples...

You can now visualize the results:

[17]:
from IPython.display import Image, display
display(Image(filename='./phys_full_3d/log_plot.png',width=300))
[Figure: loss profile (log_plot.png)]
[18]:
from IPython.display import Image, display
display(Image(filename='phys_full_3d/data_plot_phys_full_3d_25000_0_1.png',width=300))
display(Image(filename='phys_full_3d/data_plot_phys_full_3d_25000_0_2.png',width=300))
display(Image(filename='phys_full_3d/data_plot_phys_full_3d_25000_1_2.png',width=300))
[Figures: two-dimensional slices of the posterior predictive distribution for output pairs (0,1), (0,2) and (1,2)]
[19]:
from IPython.display import Image, display
display(Image(filename='phys_full_3d/params_plot_phys_full_3d_25000_0_1.png',width=300))
display(Image(filename='phys_full_3d/params_plot_phys_full_3d_25000_0_2.png',width=300))
display(Image(filename='phys_full_3d/params_plot_phys_full_3d_25000_1_2.png',width=300))
[Figures: two-dimensional slices of the posterior distribution for parameter pairs (0,1), (0,2) and (1,2)]
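The three pairwise file names in the cells above follow a single pattern, so they can be generated rather than typed out. The sketch below uses a small hypothetical helper (`pairwise_plot_files` is not part of LINFA) and assumes the naming convention seen above, `<name>/<kind>_plot_<name>_<iteration>_<i>_<j>.png`, for every pair of dimensions with i < j.

```python
from itertools import combinations

def pairwise_plot_files(name, iteration, kind, n_dims=3):
    """Return the plot file names for every pair of dimensions (i < j).

    Hypothetical helper: the naming pattern is inferred from the file
    names used in the cells above, not taken from the LINFA API.
    """
    return [
        f"{name}/{kind}_plot_{name}_{iteration}_{i}_{j}.png"
        for i, j in combinations(range(n_dims), 2)
    ]

# File names of the posterior slices saved at iteration 25000
params_files = pairwise_plot_files("phys_full_3d", 25000, "params")
for path in params_files:
    print(path)
```

In a notebook, each returned path can be passed to `display(Image(filename=path, width=300))` exactly as in the cells above.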

The results look as expected, concentrated around the fibers identified above.

However, even with our simple model, evaluating the physics-based model and computing gradients through it can be expensive, and for more demanding models this cost might make the inference task intractable.

In such cases, LINFA can build an adaptively trained surrogate model. Gradients are then computed through the surrogate instead of the physics-based model, which reduces the computational burden of this operation.

In addition, LINFA provides an adaptive annealing scheduler (AdaAnn), which makes it easier to sample from complicated densities.

Accordingly, in our last step we repeat the inference with the surrogate in place of the original model and observe how the adaptively trained surrogate reduces the computational cost.

AdaAnn: An adaptive annealing scheduler

We start by adding some options to activate the adaptive annealing scheduler and specify the associated hyperparameters.

[73]:
exp.annealing = True
exp.scheduler = 'AdaAnn' # str: type of annealing scheduler used
exp.tol       = 0.01     # float: tolerance for AdaAnn scheduler
exp.t0        = 0.001    # float: initial inverse temperature value
exp.N         = 250      # int: number of sample points during annealing
exp.N_1       = 250      # int: number of sample points at t=1
exp.T_0       = 500      # int: number of parameter updates at initial t0
exp.T         = 10       # int: number of parameter updates during annealing
exp.T_1       = 25000    # int: number of parameter updates at t=1
exp.M         = 1000     # int: number of sample points used to update temperature
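To give a feel for how `tol` and `M` enter the schedule, the sketch below assumes the step rule described for AdaAnn: the inverse temperature is incremented by the tolerance divided by the standard deviation of the log target density, estimated from `M` samples of the current approximation. It is a stand-alone illustration and not LINFA's implementation; the function `adaann_step` and the synthetic log-density values are assumptions made for this example.

```python
import random
import statistics

def adaann_step(t, log_density_samples, tol):
    """One annealing update: t_new = min(1, t + tol / std(log density)).

    Assumed form of the AdaAnn rule; `log_density_samples` stands in for
    the log target density evaluated at M samples from the current flow.
    """
    std = statistics.pstdev(log_density_samples)
    return min(1.0, t + tol / std)

random.seed(0)
t, tol, M = 0.001, 0.01, 1000   # same t0, tol and M as in the cell above
steps = 0
while t < 1.0:
    # Synthetic stand-in: the spread of the log density shrinks as t grows,
    # so the temperature increments become larger toward t = 1.
    spread = 5.0 / (1.0 + 20.0 * t)
    samples = [random.gauss(0.0, spread) for _ in range(M)]
    t = adaann_step(t, samples, tol)
    steps += 1

print(f"reached t = {t:.3f} after {steps} temperature updates")
```

A smaller `tol` gives smaller increments and therefore a longer, more gradual schedule; a larger `M` gives a less noisy estimate of the standard deviation at the cost of more density evaluations per update.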

NoFAS: Normalizing flow with an adaptively trained surrogate

Before specifying a surrogate model, we define the associated hyperparameters.

[74]:
exp.run_nofas          = True       # boolean: run the experiment with NoFAS
exp.calibrate_interval = 2000       # int: how often (in iterations) to update the surrogate model
exp.budget             = 2000       # int: total number of true model evaluations
exp.surr_pre_it        = 40000      # int: number of pre-training iterations for the surrogate model
exp.surr_upd_it        = 6000       # int: number of iterations for each surrogate model update
exp.use_new_surr       = True       # boolean: train a new surrogate instead of loading saved files
exp.surr_folder        = "./"       # str: folder where the surrogate files are stored
exp.true_data_num      = 15         # int: number of true model evaluations at each surrogate update
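The interplay between `calibrate_interval`, `true_data_num` and `budget` can be checked with a little arithmetic. The sketch below is a hypothetical helper, not LINFA code. It assumes that a surrogate update is triggered every `calibrate_interval` iterations, that each update costs `true_data_num` true model evaluations, and that the budget also counts the pre-training grid. The grid size of 6 points per input dimension (matching `gridnum=6` used when the surrogate is pre-trained) and the 30000-iteration horizon are illustrative assumptions.

```python
def surrogate_update_schedule(n_iterations, calibrate_interval,
                              true_data_num, budget, n_pretrain=0):
    """List (iteration, cumulative true-model evaluations) for each update.

    Assumes an update is triggered every `calibrate_interval` iterations
    and is skipped once it would exceed `budget`.
    """
    schedule, used = [], n_pretrain
    for it in range(calibrate_interval, n_iterations + 1, calibrate_interval):
        if used + true_data_num > budget:
            break
        used += true_data_num
        schedule.append((it, used))
    return schedule

schedule = surrogate_update_schedule(
    n_iterations=30000, calibrate_interval=2000,
    true_data_num=15, budget=2000,
    n_pretrain=6 ** 3,   # assumed: a 6-point grid per input dimension
)
print(len(schedule), "updates, last:", schedule[-1])
```

Under these assumptions the budget is far from exhausted over such a horizon, so the number of surrogate updates is set by `calibrate_interval` rather than by `budget`.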

In addition, we need to define the new surrogate and assign it to the current instance of the Experiment class. Note the following hyperparameter choices:

  • The memory_len parameter is set to 100 to use the adaptively collected samples for a larger number of iterations.

  • A parametric architecture can be specified for the default dense neural network surrogate. The parameter dnn_arch=[100,100] generates a dense neural network with two hidden layers of 100 neurons each. The parameter dnn_activation='silu' selects a SiLU activation function for all layers except the last.
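To make the architecture concrete, the following plain-Python sketch builds a dense network with the layer sizes implied by `dnn_arch=[100,100]` and applies SiLU to every layer except the last. It is not LINFA's `Surrogate` class. The input and output sizes of 3 are assumptions read from the three-row `limits` tensor and the output size passed in the call below, and the random weights are placeholders.

```python
import math
import random

def silu(x):
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def init_mlp(sizes, seed=0):
    """Random weights and zero biases for consecutive layer sizes."""
    rng = random.Random(seed)
    return [
        ([[rng.gauss(0.0, 1.0 / math.sqrt(n_in)) for _ in range(n_in)]
          for _ in range(n_out)], [0.0] * n_out)
        for n_in, n_out in zip(sizes[:-1], sizes[1:])
    ]

def forward(layers, x):
    """Apply each dense layer; SiLU on all layers except the last."""
    for k, (W, b) in enumerate(layers):
        x = [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]
        if k < len(layers) - 1:
            x = [silu(v) for v in x]
    return x

input_size, output_size, dnn_arch = 3, 3, [100, 100]   # assumed sizes
layers = init_mlp([input_size] + dnn_arch + [output_size])
n_params = sum(len(W) * len(W[0]) + len(b) for W, b in layers)
y = forward(layers, [1.0, 4.5, 0.0])
print(len(y), n_params)
```

With these sizes the network has three dense layers (3→100, 100→100, 100→3), which is small enough that evaluating it and differentiating through it is cheap compared with a typical physics-based solver.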

[75]:
exp.name = "phys_surr_3d"
exp.output_dir   = './' + exp.name

# Surrogate of the map from (transformed) inputs to the 3 model outputs
exp.surrogate = Surrogate(exp.name, lambda x: model.solve_t(trsf.forward(x)), exp.input_size, 3,
                            model_folder=exp.surr_folder, limits=torch.Tensor([[0, 6], [4, 5], [-3, 3]]),
                            memory_len=100, dnn_arch=[100,100], dnn_activation='silu', device=exp.device)
surr_filename = exp.surr_folder + exp.name
# Pre-train when a new surrogate is requested or the saved files are missing
if exp.use_new_surr or (not os.path.isfile(surr_filename + ".sur")) or (not os.path.isfile(surr_filename + ".npz")):
    print("Warning: Surrogate model files: {0}.npz and {0}.npz could not be found. ".format(surr_filename))
    # Evaluate the true model on a grid with 6 points per dimension
    exp.surrogate.gen_grid(gridnum=6)
    exp.surrogate.pre_train(exp.surr_pre_it, 0.03, 0.9999, 500, store=True)
# Load the surrogate
exp.surrogate.surrogate_load()
Success: Pre-Grid found.
Warning: Surrogate model files: ./phys_surr_3d.npz and ./phys_surr_3d.npz could not be found.

--- Pre-training surrogate model

SUR: PRE: it:       0 | loss: 3.003e+00
SUR: PRE: it:     500 | loss: 2.340e-01
SUR: PRE: it:    1000 | loss: 1.291e-01
SUR: PRE: it:    1500 | loss: 1.237e-01
SUR: PRE: it:    2000 | loss: 6.689e-02
SUR: PRE: it:    2500 | loss: 2.845e-02
SUR: PRE: it:    3000 | loss: 6.229e-02
SUR: PRE: it:    3500 | loss: 3.632e-02
SUR: PRE: it:    4000 | loss: 2.400e-02
SUR: PRE: it:    4500 | loss: 6.203e-03
SUR: PRE: it:    5000 | loss: 5.400e-03
SUR: PRE: it:    5500 | loss: 3.647e-02
SUR: PRE: it:    6000 | loss: 4.926e-03
SUR: PRE: it:    6500 | loss: 1.528e-02
SUR: PRE: it:    7000 | loss: 1.344e-02
SUR: PRE: it:    7500 | loss: 3.059e-02
SUR: PRE: it:    8000 | loss: 1.215e-02
SUR: PRE: it:    8500 | loss: 2.109e-02
SUR: PRE: it:    9000 | loss: 2.994e-03
SUR: PRE: it:    9500 | loss: 8.915e-03
SUR: PRE: it:   10000 | loss: 4.484e-03
SUR: PRE: it:   10500 | loss: 1.377e-02
SUR: PRE: it:   11000 | loss: 6.786e-03
SUR: PRE: it:   11500 | loss: 4.807e-03
SUR: PRE: it:   12000 | loss: 6.118e-03
SUR: PRE: it:   12500 | loss: 1.832e-03
SUR: PRE: it:   13000 | loss: 6.654e-03
SUR: PRE: it:   13500 | loss: 7.077e-03
SUR: PRE: it:   14000 | loss: 2.848e-03
SUR: PRE: it:   14500 | loss: 2.907e-03
SUR: PRE: it:   15000 | loss: 1.074e-03
SUR: PRE: it:   15500 | loss: 3.613e-03
SUR: PRE: it:   16000 | loss: 3.777e-03
SUR: PRE: it:   16500 | loss: 2.822e-03
SUR: PRE: it:   17000 | loss: 4.747e-03
SUR: PRE: it:   17500 | loss: 1.166e-03
SUR: PRE: it:   18000 | loss: 1.258e-03
SUR: PRE: it:   18500 | loss: 1.388e-03
SUR: PRE: it:   19000 | loss: 1.203e-03
SUR: PRE: it:   19500 | loss: 1.731e-03
SUR: PRE: it:   20000 | loss: 6.677e-04
SUR: PRE: it:   20500 | loss: 1.497e-03
SUR: PRE: it:   21000 | loss: 1.465e-03
SUR: PRE: it:   21500 | loss: 1.827e-03
SUR: PRE: it:   22000 | loss: 6.093e-04
SUR: PRE: it:   22500 | loss: 7.512e-04
SUR: PRE: it:   23000 | loss: 1.724e-03
SUR: PRE: it:   23500 | loss: 7.349e-04
SUR: PRE: it:   24000 | loss: 7.053e-04
SUR: PRE: it:   24500 | loss: 9.477e-04
SUR: PRE: it:   25000 | loss: 1.028e-03
SUR: PRE: it:   25500 | loss: 1.077e-03
SUR: PRE: it:   26000 | loss: 5.348e-04
SUR: PRE: it:   26500 | loss: 6.456e-04
SUR: PRE: it:   27000 | loss: 4.906e-04
SUR: PRE: it:   27500 | loss: 5.697e-04
SUR: PRE: it:   28000 | loss: 9.931e-04
SUR: PRE: it:   28500 | loss: 6.074e-04
SUR: PRE: it:   29000 | loss: 6.776e-04
SUR: PRE: it:   29500 | loss: 4.846e-04
SUR: PRE: it:   30000 | loss: 5.515e-04
SUR: PRE: it:   30500 | loss: 5.223e-04
SUR: PRE: it:   31000 | loss: 4.591e-04
SUR: PRE: it:   31500 | loss: 6.977e-04
SUR: PRE: it:   32000 | loss: 5.778e-04
SUR: PRE: it:   32500 | loss: 6.757e-04
SUR: PRE: it:   33000 | loss: 4.428e-04
SUR: PRE: it:   33500 | loss: 4.748e-04
SUR: PRE: it:   34000 | loss: 5.916e-04
SUR: PRE: it:   34500 | loss: 4.548e-04
SUR: PRE: it:   35000 | loss: 4.060e-04
SUR: PRE: it:   35500 | loss: 4.320e-04
SUR: PRE: it:   36000 | loss: 4.388e-04
SUR: PRE: it:   36500 | loss: 4.324e-04
SUR: PRE: it:   37000 | loss: 4.112e-04
SUR: PRE: it:   37500 | loss: 3.938e-04
SUR: PRE: it:   38000 | loss: 3.992e-04
SUR: PRE: it:   38500 | loss: 4.018e-04
SUR: PRE: it:   39000 | loss: 3.898e-04
SUR: PRE: it:   39500 | loss: 3.881e-04

--- Surrogate model pre-train complete

Success: [limits] loaded.
Success: [pre_grid] loaded.
Success: [grid_record] loaded.
[76]:
## Run
print('')
print('--- TUTORIAL: Ballistic Example - with NN surrogate and annealing')

# Assign logdensity
exp.model_logdensity = lambda x: log_density(x, model, exp.surrogate, trsf)

# Run VI
exp.run()

--- TUTORIAL: Ballistic Example - with NN surrogate and annealing

--- Running on device: cpu

VI NF (t=0.001): it:     100 | loss: 6.280e+00
VI NF (t=0.001): it:     200 | loss: 7.855e+00
VI NF (t=0.001): it:     300 | loss: 2.544e+00
VI NF (t=0.001): it:     400 | loss: 2.022e+00
VI NF (t=0.001): it:     500 | loss: 2.399e+00
VI NF (t=0.001): it:     600 | loss: 1.822e+00
VI NF (t=0.001): it:     700 | loss: 2.274e+00
VI NF (t=0.001): it:     800 | loss: 1.991e+00
VI NF (t=0.001): it:     900 | loss: 2.080e+00
VI NF (t=0.001): it:    1000 | loss: 2.283e+00
VI NF (t=0.001): it:    1100 | loss: 2.372e+00
VI NF (t=0.002): it:    1200 | loss: 2.342e+00
VI NF (t=0.002): it:    1300 | loss: 1.970e+00
VI NF (t=0.002): it:    1400 | loss: 2.086e+00
VI NF (t=0.002): it:    1500 | loss: 2.699e+00
VI NF (t=0.002): it:    1600 | loss: 2.418e+00
VI NF (t=0.002): it:    1700 | loss: 2.189e+00
VI NF (t=0.002): it:    1800 | loss: 2.209e+00
VI NF (t=0.002): it:    1900 | loss: 2.476e+00

--- Updating surrogate model

Std before inflation -> Std after inflation
1.382e+00 -> 1.382e+00
1.234e+00 -> 1.234e+00
1.221e+00 -> 1.221e+00

SUR: UPD: it:       0 | loss: 8.511e+00
SUR: UPD: it:     500 | loss: 4.022e-02
SUR: UPD: it:    1000 | loss: 8.212e-03
SUR: UPD: it:    1500 | loss: 6.960e-03
SUR: UPD: it:    2000 | loss: 1.507e-03
SUR: UPD: it:    2500 | loss: 1.597e-03
SUR: UPD: it:    3000 | loss: 9.604e-04
SUR: UPD: it:    3500 | loss: 8.631e-04
SUR: UPD: it:    4000 | loss: 8.193e-04
SUR: UPD: it:    4500 | loss: 7.685e-04
SUR: UPD: it:    5000 | loss: 7.177e-04
SUR: UPD: it:    5500 | loss: 6.888e-04

--- Surrogate model updated

VI NF (t=0.002): it:    2000 | loss: 2.739e+01
VI NF (t=0.002): it:    2100 | loss: 4.456e+00
VI NF (t=0.002): it:    2200 | loss: 3.160e+00
VI NF (t=0.003): it:    2300 | loss: 3.300e+00
VI NF (t=0.003): it:    2400 | loss: 3.334e+00
VI NF (t=0.003): it:    2500 | loss: 3.465e+00
VI NF (t=0.003): it:    2600 | loss: 3.199e+00
VI NF (t=0.003): it:    2700 | loss: 3.433e+00
VI NF (t=0.004): it:    2800 | loss: 3.548e+00
VI NF (t=0.004): it:    2900 | loss: 3.411e+00
VI NF (t=0.004): it:    3000 | loss: 3.381e+00
VI NF (t=0.005): it:    3100 | loss: 3.594e+00
VI NF (t=0.005): it:    3200 | loss: 3.623e+00
VI NF (t=0.005): it:    3300 | loss: 3.679e+00
VI NF (t=0.006): it:    3400 | loss: 3.812e+00
VI NF (t=0.006): it:    3500 | loss: 3.962e+00
VI NF (t=0.007): it:    3600 | loss: 4.088e+00
VI NF (t=0.007): it:    3700 | loss: 4.389e+00
VI NF (t=0.008): it:    3800 | loss: 4.197e+00
VI NF (t=0.008): it:    3900 | loss: 4.100e+00

--- Updating surrogate model

Std before inflation -> Std after inflation
2.019e+00 -> 2.019e+00
4.892e-01 -> 4.892e-01
1.294e+00 -> 1.294e+00

SUR: UPD: it:       0 | loss: 2.554e-01
SUR: UPD: it:     500 | loss: 5.976e-02
SUR: UPD: it:    1000 | loss: 9.066e-03
SUR: UPD: it:    1500 | loss: 5.433e-03
SUR: UPD: it:    2000 | loss: 3.665e-03
SUR: UPD: it:    2500 | loss: 3.088e-03
SUR: UPD: it:    3000 | loss: 2.456e-03
SUR: UPD: it:    3500 | loss: 1.683e-03
SUR: UPD: it:    4000 | loss: 1.429e-03
SUR: UPD: it:    4500 | loss: 1.279e-03
SUR: UPD: it:    5000 | loss: 1.188e-03
SUR: UPD: it:    5500 | loss: 1.141e-03

--- Surrogate model updated

VI NF (t=0.008): it:    4000 | loss: 8.947e+00
VI NF (t=0.009): it:    4100 | loss: 4.386e+00
VI NF (t=0.010): it:    4200 | loss: 4.437e+00
VI NF (t=0.010): it:    4300 | loss: 4.502e+00
VI NF (t=0.011): it:    4400 | loss: 4.569e+00
VI NF (t=0.012): it:    4500 | loss: 4.578e+00
VI NF (t=0.013): it:    4600 | loss: 4.666e+00
VI NF (t=0.015): it:    4700 | loss: 4.643e+00
VI NF (t=0.016): it:    4800 | loss: 4.929e+00
VI NF (t=0.017): it:    4900 | loss: 5.354e+00
--- Saving results at iteration 5000
VI NF (t=0.019): it:    5000 | loss: 5.007e+00
VI NF (t=0.020): it:    5100 | loss: 5.176e+00
VI NF (t=0.022): it:    5200 | loss: 5.177e+00
VI NF (t=0.023): it:    5300 | loss: 5.391e+00
VI NF (t=0.025): it:    5400 | loss: 5.523e+00
VI NF (t=0.028): it:    5500 | loss: 5.659e+00
VI NF (t=0.030): it:    5600 | loss: 5.629e+00
VI NF (t=0.032): it:    5700 | loss: 5.759e+00
VI NF (t=0.035): it:    5800 | loss: 5.938e+00
VI NF (t=0.038): it:    5900 | loss: 6.737e+00

--- Updating surrogate model

Std before inflation -> Std after inflation
1.042e+00 -> 1.042e+00
2.302e-01 -> 2.302e-01
8.577e-01 -> 8.577e-01

SUR: UPD: it:       0 | loss: 4.513e-02
SUR: UPD: it:     500 | loss: 2.741e-02
SUR: UPD: it:    1000 | loss: 2.510e-02
SUR: UPD: it:    1500 | loss: 9.926e-03
SUR: UPD: it:    2000 | loss: 8.245e-03
SUR: UPD: it:    2500 | loss: 5.986e-03
SUR: UPD: it:    3000 | loss: 5.046e-03
SUR: UPD: it:    3500 | loss: 4.360e-03
SUR: UPD: it:    4000 | loss: 4.052e-03
SUR: UPD: it:    4500 | loss: 3.875e-03
SUR: UPD: it:    5000 | loss: 3.784e-03
SUR: UPD: it:    5500 | loss: 3.425e-03

--- Surrogate model updated

VI NF (t=0.041): it:    6000 | loss: 6.748e+00
VI NF (t=0.044): it:    6100 | loss: 7.170e+00
VI NF (t=0.047): it:    6200 | loss: 6.787e+00
VI NF (t=0.050): it:    6300 | loss: 7.110e+00
VI NF (t=0.054): it:    6400 | loss: 7.226e+00
VI NF (t=0.058): it:    6500 | loss: 6.807e+00
VI NF (t=0.062): it:    6600 | loss: 6.888e+00
VI NF (t=0.066): it:    6700 | loss: 7.650e+00
VI NF (t=0.071): it:    6800 | loss: 7.605e+00
VI NF (t=0.076): it:    6900 | loss: 7.290e+00
VI NF (t=0.081): it:    7000 | loss: 7.314e+00
VI NF (t=0.086): it:    7100 | loss: 6.986e+00
VI NF (t=0.092): it:    7200 | loss: 7.425e+00
VI NF (t=0.098): it:    7300 | loss: 7.790e+00
VI NF (t=0.104): it:    7400 | loss: 7.841e+00
VI NF (t=0.111): it:    7500 | loss: 8.551e+00
VI NF (t=0.118): it:    7600 | loss: 8.344e+00
VI NF (t=0.126): it:    7700 | loss: 7.926e+00
VI NF (t=0.133): it:    7800 | loss: 8.191e+00
VI NF (t=0.141): it:    7900 | loss: 8.124e+00

--- Updating surrogate model

Std before inflation -> Std after inflation
1.110e+00 -> 1.110e+00
3.225e-01 -> 3.225e-01
8.923e-01 -> 8.923e-01

SUR: UPD: it:       0 | loss: 1.185e-02
SUR: UPD: it:     500 | loss: 3.751e-02
SUR: UPD: it:    1000 | loss: 1.469e-02
SUR: UPD: it:    1500 | loss: 1.219e-02
SUR: UPD: it:    2000 | loss: 9.400e-03
SUR: UPD: it:    2500 | loss: 7.327e-03
SUR: UPD: it:    3000 | loss: 6.174e-03
SUR: UPD: it:    3500 | loss: 5.771e-03
SUR: UPD: it:    4000 | loss: 5.432e-03
SUR: UPD: it:    4500 | loss: 5.144e-03
SUR: UPD: it:    5000 | loss: 4.914e-03
SUR: UPD: it:    5500 | loss: 4.784e-03

--- Surrogate model updated

VI NF (t=0.150): it:    8000 | loss: 9.014e+00
VI NF (t=0.160): it:    8100 | loss: 8.494e+00
VI NF (t=0.173): it:    8200 | loss: 8.526e+00
VI NF (t=0.186): it:    8300 | loss: 8.841e+00
VI NF (t=0.200): it:    8400 | loss: 9.537e+00
VI NF (t=0.213): it:    8500 | loss: 8.967e+00
VI NF (t=0.228): it:    8600 | loss: 9.638e+00
VI NF (t=0.241): it:    8700 | loss: 9.399e+00
VI NF (t=0.255): it:    8800 | loss: 9.370e+00
VI NF (t=0.272): it:    8900 | loss: 9.607e+00
VI NF (t=0.287): it:    9000 | loss: 9.764e+00
VI NF (t=0.308): it:    9100 | loss: 1.013e+01
VI NF (t=0.329): it:    9200 | loss: 1.014e+01
VI NF (t=0.346): it:    9300 | loss: 1.032e+01
VI NF (t=0.368): it:    9400 | loss: 1.089e+01
VI NF (t=0.388): it:    9500 | loss: 1.074e+01
VI NF (t=0.414): it:    9600 | loss: 1.057e+01
VI NF (t=0.440): it:    9700 | loss: 1.111e+01
VI NF (t=0.466): it:    9800 | loss: 1.147e+01
VI NF (t=0.496): it:    9900 | loss: 1.135e+01
--- Saving results at iteration 10000

--- Updating surrogate model

Std before inflation -> Std after inflation
1.002e+00 -> 1.002e+00
2.492e-01 -> 2.492e-01
8.356e-01 -> 8.356e-01

SUR: UPD: it:       0 | loss: 1.952e-02
SUR: UPD: it:     500 | loss: 3.685e-02
SUR: UPD: it:    1000 | loss: 4.391e-02
SUR: UPD: it:    1500 | loss: 1.828e-02
SUR: UPD: it:    2000 | loss: 1.372e-02
SUR: UPD: it:    2500 | loss: 1.123e-02
SUR: UPD: it:    3000 | loss: 1.047e-02
SUR: UPD: it:    3500 | loss: 9.877e-03
SUR: UPD: it:    4000 | loss: 9.245e-03
SUR: UPD: it:    4500 | loss: 8.520e-03
SUR: UPD: it:    5000 | loss: 8.276e-03
SUR: UPD: it:    5500 | loss: 8.089e-03

--- Surrogate model updated

VI NF (t=0.526): it:   10000 | loss: 1.302e+01
VI NF (t=0.556): it:   10100 | loss: 1.194e+01
VI NF (t=0.590): it:   10200 | loss: 1.188e+01
VI NF (t=0.622): it:   10300 | loss: 1.386e+01
VI NF (t=0.656): it:   10400 | loss: 1.259e+01
VI NF (t=0.696): it:   10500 | loss: 1.328e+01
VI NF (t=0.739): it:   10600 | loss: 1.384e+01
VI NF (t=0.786): it:   10700 | loss: 1.478e+01
VI NF (t=0.824): it:   10800 | loss: 1.376e+01
VI NF (t=0.879): it:   10900 | loss: 1.457e+01
VI NF (t=0.925): it:   11000 | loss: 1.514e+01
VI NF (t=0.968): it:   11100 | loss: 1.488e+01
VI NF (t=1.000): it:   11200 | loss: 1.529e+01
VI NF (t=1.000): it:   11300 | loss: 1.482e+01
VI NF (t=1.000): it:   11400 | loss: 1.533e+01
VI NF (t=1.000): it:   11500 | loss: 1.539e+01
VI NF (t=1.000): it:   11600 | loss: 1.552e+01
VI NF (t=1.000): it:   11700 | loss: 1.478e+01
VI NF (t=1.000): it:   11800 | loss: 1.485e+01
VI NF (t=1.000): it:   11900 | loss: 1.477e+01

--- Updating surrogate model

Std before inflation -> Std after inflation
7.592e-01 -> 7.592e-01
1.634e-01 -> 1.634e-01
6.678e-01 -> 6.678e-01

SUR: UPD: it:       0 | loss: 1.055e-02
SUR: UPD: it:     500 | loss: 4.042e-02
SUR: UPD: it:    1000 | loss: 1.554e-02
SUR: UPD: it:    1500 | loss: 9.839e-03
SUR: UPD: it:    2000 | loss: 8.273e-03
SUR: UPD: it:    2500 | loss: 6.075e-03
SUR: UPD: it:    3000 | loss: 5.406e-03
SUR: UPD: it:    3500 | loss: 4.990e-03
SUR: UPD: it:    4000 | loss: 4.826e-03
SUR: UPD: it:    4500 | loss: 4.693e-03
SUR: UPD: it:    5000 | loss: 4.594e-03
SUR: UPD: it:    5500 | loss: 4.512e-03

--- Surrogate model updated

VI NF (t=1.000): it:   12000 | loss: 1.655e+01
VI NF (t=1.000): it:   12100 | loss: 1.499e+01
VI NF (t=1.000): it:   12200 | loss: 1.533e+01
VI NF (t=1.000): it:   12300 | loss: 1.540e+01
VI NF (t=1.000): it:   12400 | loss: 1.526e+01
VI NF (t=1.000): it:   12500 | loss: 1.503e+01
VI NF (t=1.000): it:   12600 | loss: 1.489e+01
VI NF (t=1.000): it:   12700 | loss: 1.505e+01
VI NF (t=1.000): it:   12800 | loss: 1.513e+01
VI NF (t=1.000): it:   12900 | loss: 1.511e+01
VI NF (t=1.000): it:   13000 | loss: 1.467e+01
VI NF (t=1.000): it:   13100 | loss: 1.516e+01
VI NF (t=1.000): it:   13200 | loss: 1.529e+01
VI NF (t=1.000): it:   13300 | loss: 1.480e+01
VI NF (t=1.000): it:   13400 | loss: 1.510e+01
VI NF (t=1.000): it:   13500 | loss: 1.493e+01
VI NF (t=1.000): it:   13600 | loss: 1.490e+01
VI NF (t=1.000): it:   13700 | loss: 1.582e+01
VI NF (t=1.000): it:   13800 | loss: 1.479e+01
VI NF (t=1.000): it:   13900 | loss: 1.471e+01

--- Updating surrogate model

Std before inflation -> Std after inflation
8.356e-01 -> 8.356e-01
1.591e-01 -> 1.591e-01
7.226e-01 -> 7.226e-01

SUR: UPD: it:       0 | loss: 5.213e-03
SUR: UPD: it:     500 | loss: 4.175e-02
SUR: UPD: it:    1000 | loss: 1.323e-02
SUR: UPD: it:    1500 | loss: 1.298e-02
SUR: UPD: it:    2000 | loss: 8.124e-03
SUR: UPD: it:    2500 | loss: 5.622e-03
SUR: UPD: it:    3000 | loss: 5.102e-03
SUR: UPD: it:    3500 | loss: 4.973e-03
SUR: UPD: it:    4000 | loss: 4.886e-03
SUR: UPD: it:    4500 | loss: 4.781e-03
SUR: UPD: it:    5000 | loss: 4.731e-03
SUR: UPD: it:    5500 | loss: 4.692e-03

--- Surrogate model updated

VI NF (t=1.000): it:   14000 | loss: 1.474e+01
VI NF (t=1.000): it:   14100 | loss: 1.547e+01
VI NF (t=1.000): it:   14200 | loss: 1.506e+01
VI NF (t=1.000): it:   14300 | loss: 1.505e+01
VI NF (t=1.000): it:   14400 | loss: 1.505e+01
VI NF (t=1.000): it:   14500 | loss: 1.476e+01
VI NF (t=1.000): it:   14600 | loss: 1.471e+01
VI NF (t=1.000): it:   14700 | loss: 1.492e+01
VI NF (t=1.000): it:   14800 | loss: 1.486e+01
VI NF (t=1.000): it:   14900 | loss: 1.460e+01
--- Saving results at iteration 15000
VI NF (t=1.000): it:   15000 | loss: 1.498e+01
VI NF (t=1.000): it:   15100 | loss: 1.471e+01
VI NF (t=1.000): it:   15200 | loss: 1.479e+01
VI NF (t=1.000): it:   15300 | loss: 1.525e+01
VI NF (t=1.000): it:   15400 | loss: 1.502e+01
VI NF (t=1.000): it:   15500 | loss: 1.469e+01
VI NF (t=1.000): it:   15600 | loss: 1.497e+01
VI NF (t=1.000): it:   15700 | loss: 1.476e+01
VI NF (t=1.000): it:   15800 | loss: 1.509e+01
VI NF (t=1.000): it:   15900 | loss: 1.481e+01

--- Updating surrogate model

Std before inflation -> Std after inflation
1.021e+00 -> 1.021e+00
2.114e-01 -> 2.114e-01
8.731e-01 -> 8.731e-01

SUR: UPD: it:       0 | loss: 5.265e-03
SUR: UPD: it:     500 | loss: 9.765e-02
SUR: UPD: it:    1000 | loss: 2.880e-02
SUR: UPD: it:    1500 | loss: 1.977e-02
SUR: UPD: it:    2000 | loss: 7.858e-03
SUR: UPD: it:    2500 | loss: 6.165e-03
SUR: UPD: it:    3000 | loss: 5.518e-03
SUR: UPD: it:    3500 | loss: 5.372e-03
SUR: UPD: it:    4000 | loss: 5.272e-03
SUR: UPD: it:    4500 | loss: 5.220e-03
SUR: UPD: it:    5000 | loss: 5.185e-03
SUR: UPD: it:    5500 | loss: 5.157e-03

--- Surrogate model updated

VI NF (t=1.000): it:   16000 | loss: 1.480e+01
VI NF (t=1.000): it:   16100 | loss: 1.461e+01
VI NF (t=1.000): it:   16200 | loss: 1.462e+01
VI NF (t=1.000): it:   16300 | loss: 1.520e+01
VI NF (t=1.000): it:   16400 | loss: 1.488e+01
VI NF (t=1.000): it:   16500 | loss: 1.495e+01
VI NF (t=1.000): it:   16600 | loss: 1.460e+01
VI NF (t=1.000): it:   16700 | loss: 1.477e+01
VI NF (t=1.000): it:   16800 | loss: 1.465e+01
VI NF (t=1.000): it:   16900 | loss: 1.530e+01
VI NF (t=1.000): it:   17000 | loss: 1.481e+01
VI NF (t=1.000): it:   17100 | loss: 1.503e+01
VI NF (t=1.000): it:   17200 | loss: 1.491e+01
VI NF (t=1.000): it:   17300 | loss: 1.461e+01
VI NF (t=1.000): it:   17400 | loss: 1.475e+01
VI NF (t=1.000): it:   17500 | loss: 1.480e+01
VI NF (t=1.000): it:   17600 | loss: 1.467e+01
VI NF (t=1.000): it:   17700 | loss: 1.496e+01
VI NF (t=1.000): it:   17800 | loss: 1.480e+01
VI NF (t=1.000): it:   17900 | loss: 1.478e+01

--- Updating surrogate model

Std before inflation -> Std after inflation
7.920e-01 -> 7.920e-01
1.486e-01 -> 1.486e-01
6.749e-01 -> 6.749e-01

SUR: UPD: it:       0 | loss: 5.455e-03
SUR: UPD: it:     500 | loss: 5.221e-02
SUR: UPD: it:    1000 | loss: 1.615e-02
SUR: UPD: it:    1500 | loss: 1.359e-02
SUR: UPD: it:    2000 | loss: 1.158e-02
SUR: UPD: it:    2500 | loss: 6.820e-03
SUR: UPD: it:    3000 | loss: 5.785e-03
SUR: UPD: it:    3500 | loss: 5.638e-03
SUR: UPD: it:    4000 | loss: 5.521e-03
SUR: UPD: it:    4500 | loss: 5.369e-03
SUR: UPD: it:    5000 | loss: 5.291e-03
SUR: UPD: it:    5500 | loss: 5.234e-03

--- Surrogate model updated

VI NF (t=1.000): it:   18000 | loss: 1.520e+01
VI NF (t=1.000): it:   18100 | loss: 1.501e+01
VI NF (t=1.000): it:   18200 | loss: 1.488e+01
VI NF (t=1.000): it:   18300 | loss: 1.473e+01
VI NF (t=1.000): it:   18400 | loss: 1.489e+01
VI NF (t=1.000): it:   18500 | loss: 1.466e+01
VI NF (t=1.000): it:   18600 | loss: 1.477e+01
VI NF (t=1.000): it:   18700 | loss: 1.506e+01
VI NF (t=1.000): it:   18800 | loss: 1.486e+01
VI NF (t=1.000): it:   18900 | loss: 1.451e+01
VI NF (t=1.000): it:   19000 | loss: 1.470e+01
VI NF (t=1.000): it:   19100 | loss: 1.480e+01
VI NF (t=1.000): it:   19200 | loss: 1.473e+01
VI NF (t=1.000): it:   19300 | loss: 1.467e+01
VI NF (t=1.000): it:   19400 | loss: 1.461e+01
VI NF (t=1.000): it:   19500 | loss: 1.452e+01
VI NF (t=1.000): it:   19600 | loss: 1.493e+01
VI NF (t=1.000): it:   19700 | loss: 1.457e+01
VI NF (t=1.000): it:   19800 | loss: 1.483e+01
VI NF (t=1.000): it:   19900 | loss: 1.471e+01
--- Saving results at iteration 20000

--- Updating surrogate model

Std before inflation -> Std after inflation
7.021e-01 -> 7.021e-01
1.436e-01 -> 1.436e-01
6.034e-01 -> 6.034e-01

SUR: UPD: it:       0 | loss: 5.580e-03
SUR: UPD: it:     500 | loss: 5.468e-02
SUR: UPD: it:    1000 | loss: 2.160e-02
SUR: UPD: it:    1500 | loss: 7.003e-03
SUR: UPD: it:    2000 | loss: 6.942e-03
SUR: UPD: it:    2500 | loss: 5.077e-03
SUR: UPD: it:    3000 | loss: 4.563e-03
SUR: UPD: it:    3500 | loss: 4.358e-03
SUR: UPD: it:    4000 | loss: 4.254e-03
SUR: UPD: it:    4500 | loss: 4.091e-03
SUR: UPD: it:    5000 | loss: 4.030e-03
SUR: UPD: it:    5500 | loss: 3.989e-03

--- Surrogate model updated

VI NF (t=1.000): it:   20000 | loss: 1.506e+01
VI NF (t=1.000): it:   20100 | loss: 1.449e+01
VI NF (t=1.000): it:   20200 | loss: 1.475e+01
VI NF (t=1.000): it:   20300 | loss: 1.457e+01
VI NF (t=1.000): it:   20400 | loss: 1.459e+01
VI NF (t=1.000): it:   20500 | loss: 1.429e+01
VI NF (t=1.000): it:   20600 | loss: 1.447e+01
VI NF (t=1.000): it:   20700 | loss: 1.446e+01
VI NF (t=1.000): it:   20800 | loss: 1.519e+01
VI NF (t=1.000): it:   20900 | loss: 1.438e+01
VI NF (t=1.000): it:   21000 | loss: 1.441e+01
VI NF (t=1.000): it:   21100 | loss: 1.445e+01
VI NF (t=1.000): it:   21200 | loss: 1.451e+01
VI NF (t=1.000): it:   21300 | loss: 1.435e+01
VI NF (t=1.000): it:   21400 | loss: 1.436e+01
VI NF (t=1.000): it:   21500 | loss: 1.440e+01
VI NF (t=1.000): it:   21600 | loss: 1.467e+01
VI NF (t=1.000): it:   21700 | loss: 1.444e+01
VI NF (t=1.000): it:   21800 | loss: 1.508e+01
VI NF (t=1.000): it:   21900 | loss: 1.440e+01

--- Updating surrogate model

Std before inflation -> Std after inflation
9.966e-01 -> 9.966e-01
2.030e-01 -> 2.030e-01
8.306e-01 -> 8.306e-01

SUR: UPD: it:       0 | loss: 4.235e-03
SUR: UPD: it:     500 | loss: 2.240e-02
SUR: UPD: it:    1000 | loss: 1.279e-02
SUR: UPD: it:    1500 | loss: 5.627e-03
SUR: UPD: it:    2000 | loss: 7.259e-03
SUR: UPD: it:    2500 | loss: 4.469e-03
SUR: UPD: it:    3000 | loss: 4.221e-03
SUR: UPD: it:    3500 | loss: 4.128e-03
SUR: UPD: it:    4000 | loss: 4.043e-03
SUR: UPD: it:    4500 | loss: 3.994e-03
SUR: UPD: it:    5000 | loss: 3.958e-03
SUR: UPD: it:    5500 | loss: 3.918e-03

--- Surrogate model updated

VI NF (t=1.000): it:   22000 | loss: 1.447e+01
VI NF (t=1.000): it:   22100 | loss: 1.448e+01
VI NF (t=1.000): it:   22200 | loss: 1.424e+01
VI NF (t=1.000): it:   22300 | loss: 1.452e+01
VI NF (t=1.000): it:   22400 | loss: 1.463e+01
VI NF (t=1.000): it:   22500 | loss: 1.459e+01
VI NF (t=1.000): it:   22600 | loss: 1.450e+01
VI NF (t=1.000): it:   22700 | loss: 1.427e+01
VI NF (t=1.000): it:   22800 | loss: 1.440e+01
VI NF (t=1.000): it:   22900 | loss: 1.440e+01
VI NF (t=1.000): it:   23000 | loss: 1.421e+01
VI NF (t=1.000): it:   23100 | loss: 1.510e+01
VI NF (t=1.000): it:   23200 | loss: 1.433e+01
VI NF (t=1.000): it:   23300 | loss: 1.422e+01
VI NF (t=1.000): it:   23400 | loss: 1.432e+01
VI NF (t=1.000): it:   23500 | loss: 1.436e+01
VI NF (t=1.000): it:   23600 | loss: 1.450e+01
VI NF (t=1.000): it:   23700 | loss: 1.444e+01
VI NF (t=1.000): it:   23800 | loss: 1.457e+01
VI NF (t=1.000): it:   23900 | loss: 1.472e+01

--- Updating surrogate model

Std before inflation -> Std after inflation
9.116e-01 -> 9.116e-01
1.567e-01 -> 1.567e-01
7.682e-01 -> 7.682e-01

SUR: UPD: it:       0 | loss: 4.040e-03
SUR: UPD: it:     500 | loss: 3.107e-02
SUR: UPD: it:    1000 | loss: 1.386e-02
SUR: UPD: it:    1500 | loss: 1.236e-02
SUR: UPD: it:    2000 | loss: 5.685e-03
SUR: UPD: it:    2500 | loss: 4.995e-03
SUR: UPD: it:    3000 | loss: 4.473e-03
SUR: UPD: it:    3500 | loss: 4.348e-03
SUR: UPD: it:    4000 | loss: 4.203e-03
SUR: UPD: it:    4500 | loss: 4.146e-03
SUR: UPD: it:    5000 | loss: 4.108e-03
SUR: UPD: it:    5500 | loss: 4.072e-03

--- Surrogate model updated

VI NF (t=1.000): it:   24000 | loss: 1.438e+01
VI NF (t=1.000): it:   24100 | loss: 1.460e+01
VI NF (t=1.000): it:   24200 | loss: 1.442e+01
VI NF (t=1.000): it:   24300 | loss: 1.485e+01
VI NF (t=1.000): it:   24400 | loss: 1.480e+01
VI NF (t=1.000): it:   24500 | loss: 1.449e+01
VI NF (t=1.000): it:   24600 | loss: 1.427e+01
VI NF (t=1.000): it:   24700 | loss: 1.452e+01
VI NF (t=1.000): it:   24800 | loss: 1.428e+01
VI NF (t=1.000): it:   24900 | loss: 1.488e+01
--- Saving results at iteration 25000
VI NF (t=1.000): it:   25000 | loss: 1.454e+01
VI NF (t=1.000): it:   25100 | loss: 1.427e+01
VI NF (t=1.000): it:   25200 | loss: 1.442e+01
VI NF (t=1.000): it:   25300 | loss: 1.471e+01
VI NF (t=1.000): it:   25400 | loss: 1.429e+01
VI NF (t=1.000): it:   25500 | loss: 1.435e+01
VI NF (t=1.000): it:   25600 | loss: 1.427e+01
VI NF (t=1.000): it:   25700 | loss: 1.416e+01
VI NF (t=1.000): it:   25800 | loss: 1.433e+01
VI NF (t=1.000): it:   25900 | loss: 1.449e+01

--- Updating surrogate model

Std before inflation -> Std after inflation
9.151e-01 -> 9.151e-01
1.991e-01 -> 1.991e-01
7.685e-01 -> 7.685e-01

SUR: UPD: it:       0 | loss: 4.379e-03
SUR: UPD: it:     500 | loss: 1.899e-02
SUR: UPD: it:    1000 | loss: 1.188e-02
SUR: UPD: it:    1500 | loss: 1.609e-02
SUR: UPD: it:    2000 | loss: 6.333e-03
SUR: UPD: it:    2500 | loss: 5.559e-03
SUR: UPD: it:    3000 | loss: 5.315e-03
SUR: UPD: it:    3500 | loss: 5.064e-03
SUR: UPD: it:    4000 | loss: 4.760e-03
SUR: UPD: it:    4500 | loss: 4.654e-03
SUR: UPD: it:    5000 | loss: 4.542e-03
SUR: UPD: it:    5500 | loss: 4.472e-03

--- Surrogate model updated

VI NF (t=1.000): it:   26000 | loss: 1.506e+01
VI NF (t=1.000): it:   26100 | loss: 1.427e+01
VI NF (t=1.000): it:   26200 | loss: 1.439e+01
VI NF (t=1.000): it:   26300 | loss: 1.413e+01
VI NF (t=1.000): it:   26400 | loss: 1.418e+01
VI NF (t=1.000): it:   26500 | loss: 1.428e+01
VI NF (t=1.000): it:   26600 | loss: 1.426e+01
VI NF (t=1.000): it:   26700 | loss: 1.443e+01
VI NF (t=1.000): it:   26800 | loss: 1.485e+01
VI NF (t=1.000): it:   26900 | loss: 1.425e+01
VI NF (t=1.000): it:   27000 | loss: 1.436e+01
VI NF (t=1.000): it:   27100 | loss: 1.432e+01
VI NF (t=1.000): it:   27200 | loss: 1.427e+01
VI NF (t=1.000): it:   27300 | loss: 1.436e+01
VI NF (t=1.000): it:   27400 | loss: 1.433e+01
VI NF (t=1.000): it:   27500 | loss: 1.410e+01
VI NF (t=1.000): it:   27600 | loss: 1.420e+01
VI NF (t=1.000): it:   27700 | loss: 1.424e+01
VI NF (t=1.000): it:   27800 | loss: 1.422e+01
VI NF (t=1.000): it:   27900 | loss: 1.439e+01

--- Updating surrogate model

Std before inflation -> Std after inflation
1.127e+00 -> 1.127e+00
2.378e-01 -> 2.378e-01
9.542e-01 -> 9.542e-01

SUR: UPD: it:       0 | loss: 4.859e-03
SUR: UPD: it:     500 | loss: 5.057e-02
SUR: UPD: it:    1000 | loss: 3.723e-02
SUR: UPD: it:    1500 | loss: 7.700e-03
SUR: UPD: it:    2000 | loss: 7.524e-03
SUR: UPD: it:    2500 | loss: 5.830e-03
SUR: UPD: it:    3000 | loss: 5.132e-03
SUR: UPD: it:    3500 | loss: 4.714e-03
SUR: UPD: it:    4000 | loss: 4.449e-03
SUR: UPD: it:    4500 | loss: 4.357e-03
SUR: UPD: it:    5000 | loss: 4.285e-03
SUR: UPD: it:    5500 | loss: 4.239e-03

--- Surrogate model updated

VI NF (t=1.000): it:   28000 | loss: 1.476e+01
VI NF (t=1.000): it:   28100 | loss: 1.424e+01
VI NF (t=1.000): it:   28200 | loss: 1.447e+01
VI NF (t=1.000): it:   28300 | loss: 1.483e+01
VI NF (t=1.000): it:   28400 | loss: 1.455e+01
VI NF (t=1.000): it:   28500 | loss: 1.441e+01
VI NF (t=1.000): it:   28600 | loss: 1.437e+01
VI NF (t=1.000): it:   28700 | loss: 1.442e+01
VI NF (t=1.000): it:   28800 | loss: 1.445e+01
VI NF (t=1.000): it:   28900 | loss: 1.456e+01
VI NF (t=1.000): it:   29000 | loss: 1.415e+01
VI NF (t=1.000): it:   29100 | loss: 1.451e+01
VI NF (t=1.000): it:   29200 | loss: 1.441e+01
VI NF (t=1.000): it:   29300 | loss: 1.428e+01
VI NF (t=1.000): it:   29400 | loss: 1.484e+01
VI NF (t=1.000): it:   29500 | loss: 1.430e+01
VI NF (t=1.000): it:   29600 | loss: 1.434e+01
VI NF (t=1.000): it:   29700 | loss: 1.422e+01
VI NF (t=1.000): it:   29800 | loss: 1.420e+01
VI NF (t=1.000): it:   29900 | loss: 1.434e+01
--- Saving results at iteration 30000

--- Updating surrogate model

Std before inflation -> Std after inflation
1.039e+00 -> 1.039e+00
1.933e-01 -> 1.933e-01
8.661e-01 -> 8.661e-01

SUR: UPD: it:       0 | loss: 4.299e-03
SUR: UPD: it:     500 | loss: 4.884e-02
SUR: UPD: it:    1000 | loss: 1.782e-02
SUR: UPD: it:    1500 | loss: 7.554e-03
SUR: UPD: it:    2000 | loss: 6.015e-03
SUR: UPD: it:    2500 | loss: 5.871e-03
SUR: UPD: it:    3000 | loss: 4.793e-03
SUR: UPD: it:    3500 | loss: 4.311e-03
SUR: UPD: it:    4000 | loss: 4.257e-03
SUR: UPD: it:    4500 | loss: 4.165e-03
SUR: UPD: it:    5000 | loss: 4.105e-03
SUR: UPD: it:    5500 | loss: 4.059e-03

--- Surrogate model updated

VI NF (t=1.000): it:   30000 | loss: 1.469e+01
VI NF (t=1.000): it:   30100 | loss: 1.420e+01
VI NF (t=1.000): it:   30200 | loss: 1.419e+01
VI NF (t=1.000): it:   30300 | loss: 1.440e+01
VI NF (t=1.000): it:   30400 | loss: 1.422e+01
VI NF (t=1.000): it:   30500 | loss: 1.422e+01
VI NF (t=1.000): it:   30600 | loss: 1.440e+01
VI NF (t=1.000): it:   30700 | loss: 1.407e+01
VI NF (t=1.000): it:   30800 | loss: 1.430e+01
VI NF (t=1.000): it:   30900 | loss: 1.439e+01
VI NF (t=1.000): it:   31000 | loss: 1.416e+01
VI NF (t=1.000): it:   31100 | loss: 1.434e+01
VI NF (t=1.000): it:   31200 | loss: 1.424e+01
VI NF (t=1.000): it:   31300 | loss: 1.438e+01
VI NF (t=1.000): it:   31400 | loss: 1.445e+01
VI NF (t=1.000): it:   31500 | loss: 1.422e+01
VI NF (t=1.000): it:   31600 | loss: 1.431e+01
VI NF (t=1.000): it:   31700 | loss: 1.426e+01
VI NF (t=1.000): it:   31800 | loss: 1.424e+01
VI NF (t=1.000): it:   31900 | loss: 1.472e+01

--- Updating surrogate model

Std before inflation -> Std after inflation
1.178e+00 -> 1.178e+00
1.709e-01 -> 1.709e-01
1.001e+00 -> 1.001e+00

SUR: UPD: it:       0 | loss: 4.096e-03
SUR: UPD: it:     500 | loss: 1.765e-02
SUR: UPD: it:    1000 | loss: 1.049e-02
SUR: UPD: it:    1500 | loss: 8.594e-03
SUR: UPD: it:    2000 | loss: 5.135e-03
SUR: UPD: it:    2500 | loss: 3.869e-03
SUR: UPD: it:    3000 | loss: 3.777e-03
SUR: UPD: it:    3500 | loss: 3.582e-03
SUR: UPD: it:    4000 | loss: 3.485e-03
SUR: UPD: it:    4500 | loss: 3.382e-03
SUR: UPD: it:    5000 | loss: 3.325e-03
SUR: UPD: it:    5500 | loss: 3.285e-03

--- Surrogate model updated

VI NF (t=1.000): it:   32000 | loss: 1.432e+01
VI NF (t=1.000): it:   32100 | loss: 1.432e+01
VI NF (t=1.000): it:   32200 | loss: 1.431e+01
VI NF (t=1.000): it:   32300 | loss: 1.416e+01
VI NF (t=1.000): it:   32400 | loss: 1.449e+01
VI NF (t=1.000): it:   32500 | loss: 1.431e+01
VI NF (t=1.000): it:   32600 | loss: 1.424e+01
VI NF (t=1.000): it:   32700 | loss: 1.428e+01
VI NF (t=1.000): it:   32800 | loss: 1.434e+01
VI NF (t=1.000): it:   32900 | loss: 1.417e+01
VI NF (t=1.000): it:   33000 | loss: 1.411e+01
VI NF (t=1.000): it:   33100 | loss: 1.403e+01
VI NF (t=1.000): it:   33200 | loss: 1.426e+01
VI NF (t=1.000): it:   33300 | loss: 1.474e+01
VI NF (t=1.000): it:   33400 | loss: 1.425e+01
VI NF (t=1.000): it:   33500 | loss: 1.408e+01
VI NF (t=1.000): it:   33600 | loss: 1.409e+01
VI NF (t=1.000): it:   33700 | loss: 1.430e+01
VI NF (t=1.000): it:   33800 | loss: 1.415e+01
VI NF (t=1.000): it:   33900 | loss: 1.434e+01

--- Updating surrogate model

Std before inflation -> Std after inflation
1.422e+00 -> 1.422e+00
2.405e-01 -> 2.405e-01
1.184e+00 -> 1.184e+00

SUR: UPD: it:       0 | loss: 3.881e-03
SUR: UPD: it:     500 | loss: 2.671e-02
SUR: UPD: it:    1000 | loss: 1.024e-02
SUR: UPD: it:    1500 | loss: 1.053e-02
SUR: UPD: it:    2000 | loss: 4.400e-03
SUR: UPD: it:    2500 | loss: 5.529e-03
SUR: UPD: it:    3000 | loss: 3.854e-03
SUR: UPD: it:    3500 | loss: 3.495e-03
SUR: UPD: it:    4000 | loss: 3.393e-03
SUR: UPD: it:    4500 | loss: 3.330e-03
SUR: UPD: it:    5000 | loss: 3.300e-03
SUR: UPD: it:    5500 | loss: 3.260e-03

--- Surrogate model updated

VI NF (t=1.000): it:   34000 | loss: 1.463e+01
VI NF (t=1.000): it:   34100 | loss: 1.421e+01
VI NF (t=1.000): it:   34200 | loss: 1.422e+01
VI NF (t=1.000): it:   34300 | loss: 1.430e+01
VI NF (t=1.000): it:   34400 | loss: 1.412e+01
VI NF (t=1.000): it:   34500 | loss: 1.415e+01
VI NF (t=1.000): it:   34600 | loss: 1.426e+01
VI NF (t=1.000): it:   34700 | loss: 1.432e+01
VI NF (t=1.000): it:   34800 | loss: 1.421e+01
VI NF (t=1.000): it:   34900 | loss: 1.407e+01
--- Saving results at iteration 35000
VI NF (t=1.000): it:   35000 | loss: 1.433e+01
VI NF (t=1.000): it:   35100 | loss: 1.419e+01
VI NF (t=1.000): it:   35200 | loss: 1.401e+01
VI NF (t=1.000): it:   35300 | loss: 1.422e+01
VI NF (t=1.000): it:   35400 | loss: 1.439e+01
VI NF (t=1.000): it:   35500 | loss: 1.423e+01
VI NF (t=1.000): it:   35600 | loss: 1.439e+01
VI NF (t=1.000): it:   35700 | loss: 1.421e+01
VI NF (t=1.000): it:   35800 | loss: 1.411e+01
VI NF (t=1.000): it:   35900 | loss: 1.417e+01

--- Updating surrogate model

Std before inflation -> Std after inflation
9.419e-01 -> 9.419e-01
2.319e-01 -> 2.319e-01
7.815e-01 -> 7.815e-01

SUR: UPD: it:       0 | loss: 3.299e-03
SUR: UPD: it:     500 | loss: 3.264e-02
SUR: UPD: it:    1000 | loss: 2.505e-02
SUR: UPD: it:    1500 | loss: 1.314e-02
SUR: UPD: it:    2000 | loss: 4.835e-03
SUR: UPD: it:    2500 | loss: 4.117e-03
SUR: UPD: it:    3000 | loss: 3.915e-03
SUR: UPD: it:    3500 | loss: 3.822e-03
SUR: UPD: it:    4000 | loss: 3.669e-03
SUR: UPD: it:    4500 | loss: 3.588e-03
SUR: UPD: it:    5000 | loss: 3.563e-03
SUR: UPD: it:    5500 | loss: 3.535e-03

--- Surrogate model updated

VI NF (t=1.000): it:   36000 | loss: 1.419e+01
VI NF (t=1.000): it:   36100 | loss: 1.420e+01

--- Simulation completed!
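
The training output above uses two fixed line patterns: `VI NF (t=...): it: N | loss: X` for the normalizing flow iterations and `SUR: UPD: it: N | loss: X` for the surrogate updates. If you save this output to a text file, a short parser can recover the loss histories for your own checks. The snippet below is a minimal sketch and is not part of LINFA: the function name `parse_log` and the regular expressions are assumptions based only on the line format printed in this tutorial.

```python
import re

# Flow-training lines, e.g. "VI NF (t=1.000): it:   36000 | loss: 1.419e+01"
VI_LINE = re.compile(r"^VI NF \(t=([\d.]+)\): it:\s*(\d+) \| loss: ([\d.eE+-]+)")
# Surrogate-update lines, e.g. "SUR: UPD: it:    5500 | loss: 3.535e-03"
SUR_LINE = re.compile(r"^SUR: UPD: it:\s*(\d+) \| loss: ([\d.eE+-]+)")


def parse_log(text):
    """Hypothetical helper: return (vi, sur), two lists of (iteration, loss)."""
    vi, sur = [], []
    for line in text.splitlines():
        line = line.strip()
        m = VI_LINE.match(line)
        if m:
            # group(1) is the annealing temperature t, unused here
            vi.append((int(m.group(2)), float(m.group(3))))
            continue
        m = SUR_LINE.match(line)
        if m:
            sur.append((int(m.group(1)), float(m.group(2))))
    return vi, sur


# A few lines copied from the output above
sample = """\
VI NF (t=1.000): it:   35900 | loss: 1.417e+01
SUR: UPD: it:       0 | loss: 3.299e-03
SUR: UPD: it:    5500 | loss: 3.535e-03
VI NF (t=1.000): it:   36000 | loss: 1.419e+01
VI NF (t=1.000): it:   36100 | loss: 1.420e+01
"""
vi, sur = parse_log(sample)
print("last VI entry: ", vi[-1])
print("last SUR entry:", sur[-1])
```

Note that the surrogate iteration counter restarts from 0 at every update, so the `sur` list contains one short loss history per update rather than a single monotone sequence.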

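We generate the plots with the same command used for the full-model inference, changing only the folder name to `phys_surr_3d`. The `-i 30000` option matches one of the iterations at which results were saved in the log above.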

[77]:
import linfa
! python3 -m linfa.plot_res -n phys_surr_3d -i 30000 -f "./" -p 'png' -d
Plotting log...
Plotting posterior samples...
Plotting posterior predictive samples...

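You can now visualize the results: first the loss log, then the pairwise plots saved with the `data_plot` and `params_plot` prefixes.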

[78]:
from IPython.display import Image, display
display(Image(filename='phys_surr_3d/log_plot.png',width=300))
[Figure: phys_surr_3d/log_plot.png]
[79]:
from IPython.display import Image, display
display(Image(filename='phys_surr_3d/data_plot_phys_surr_3d_30000_0_1.png',width=300))
display(Image(filename='phys_surr_3d/data_plot_phys_surr_3d_30000_0_2.png',width=300))
display(Image(filename='phys_surr_3d/data_plot_phys_surr_3d_30000_1_2.png',width=300))
[Figures: data_plot_phys_surr_3d_30000 for index pairs (0, 1), (0, 2), and (1, 2)]
[80]:
from IPython.display import Image, display
display(Image(filename='phys_surr_3d/params_plot_phys_surr_3d_30000_0_1.png',width=300))
display(Image(filename='phys_surr_3d/params_plot_phys_surr_3d_30000_0_2.png',width=300))
display(Image(filename='phys_surr_3d/params_plot_phys_surr_3d_30000_1_2.png',width=300))
[Figures: params_plot_phys_surr_3d_30000 for index pairs (0, 1), (0, 2), and (1, 2)]
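
The file names displayed in the last two cells follow a regular pattern: a prefix (`data_plot` or `params_plot`), the folder name, the iteration, and a pair of indices covering every combination of the three dimensions. Rather than typing each name by hand, the list can be generated. The snippet below is a minimal sketch: `plot_files` and `show_existing` are hypothetical helpers, not LINFA functions, and the naming pattern is inferred only from the file names used in this tutorial.

```python
import os
from itertools import combinations


def plot_files(exp_name, iteration, n_dims, kind="params", folder=None):
    """Hypothetical helper: list pairwise plot files for one experiment.

    Assumes the pattern <folder>/<kind>_plot_<exp_name>_<iteration>_<i>_<j>.png
    seen in the cells above, with i < j over range(n_dims).
    """
    folder = exp_name if folder is None else folder
    return [
        f"{folder}/{kind}_plot_{exp_name}_{iteration}_{i}_{j}.png"
        for i, j in combinations(range(n_dims), 2)
    ]


def show_existing(files, width=300):
    """Display only the files that exist on disk (requires IPython)."""
    from IPython.display import Image, display  # imported lazily

    for name in files:
        if os.path.exists(name):
            display(Image(filename=name, width=width))
        else:
            print(f"missing: {name}")


# File names for the surrogate run shown in this tutorial
for name in plot_files("phys_surr_3d", 30000, 3, kind="params"):
    print(name)
```

In a notebook, `show_existing(plot_files("phys_surr_3d", 30000, 3, kind="data"))` would then replace the three explicit `display` calls, and the same helper extends to models with more than three dimensions.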