Experiment Class

class run_experiment.experiment

Bases: object

Defines an instance of variational inference

This class is the core class of the LINFA library and defines the default hyperparameter values and the functions used for inference.

M

Number of Monte Carlo samples used to compute the denominator of the AdaAnn formula

Type

int
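A minimal sketch of where M enters an AdaAnn-style update. It assumes that the denominator is the Monte Carlo estimate sqrt(E[(log p)^2] - E[log p]^2) computed from M log-posterior values, and that the increment of the inverse temperature is tol divided by that quantity, with tol the KL tolerance. The helper adaann_step and the exact numerator are assumptions, not LINFA's code.

```python
import math
import random


def adaann_step(log_p_samples, tol):
    """Hypothetical helper: AdaAnn-style increment of the inverse temperature.

    ASSUMPTION: increment = tol / sqrt(E[(log p)^2] - E[log p]^2), where both
    expectations are estimated from the M values in log_p_samples (the
    log-posterior evaluated at M samples drawn from the current flow).
    """
    M = len(log_p_samples)
    mean = sum(log_p_samples) / M
    second_moment = sum(v * v for v in log_p_samples) / M
    variance = max(second_moment - mean * mean, 0.0)  # clip tiny negative round-off
    if variance == 0.0:
        return math.inf  # no spread in log p: the update below jumps straight to t = 1
    return tol / math.sqrt(variance)


# Usage: advance t from t0 using M = 100 synthetic log-posterior values.
rng = random.Random(0)
M = 100
log_p = [rng.gauss(-5.0, 2.0) for _ in range(M)]
t = 0.01
t = min(1.0, t + adaann_step(log_p, tol=0.01))
```

A larger M gives a less noisy denominator at the cost of more evaluations of the log-posterior per temperature update.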

N

Number of batch samples generated for $t<1$ at each iteration

Type

int

N_1

Number of batch samples generated for $t=1$ at each iteration

Type

int

T

Number of parameter updates for each temperature for $t<1$

Type

int

T_0

Number of parameter updates at the initial inverse temperature $t_0$

Type

int

T_1

Number of parameter updates at $t=1$

Type

int
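Read together, T_0, T and T_1 split the parameter updates into annealing stages. The sketch below expands a sequence of inverse temperatures into the temperature used at each update; it assumes the scheduler has already produced that sequence and that t_0 receives only its T_0 updates. The helper update_schedule is hypothetical.

```python
def update_schedule(temperatures, T_0, T, T_1):
    """Hypothetical helper: inverse temperature used at each parameter update.

    temperatures[0] is t_0, temperatures[-1] must be 1.0, and the entries in
    between are the intermediate inverse temperatures t < 1.
    """
    if len(temperatures) < 2 or temperatures[-1] != 1.0:
        raise ValueError("temperatures must start at t_0 and end at 1.0")
    schedule = [temperatures[0]] * T_0   # T_0 updates at t_0
    for t in temperatures[1:-1]:
        schedule.extend([t] * T)         # T updates per intermediate t < 1
    schedule.extend([1.0] * T_1)         # T_1 updates at t = 1
    return schedule


s = update_schedule([0.25, 0.5, 0.75, 1.0], T_0=3, T=2, T_1=4)
print(len(s))  # 11, i.e. 3 + 2 * 2 + 4
```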

activation_fn

Activation function used (either ‘relu’, ‘tanh’ or ‘sigmoid’)

Type

str

annealing

Flag to activate an annealing scheduler

Type

bool

batch_norm_order

Flag indicating whether batch normalization is used

Type

bool

batch_size

Number of batch samples generated at every iteration from the base distribution

Type

int

budget

Maximum number of allowed evaluations of the true model

Type

int

calibrate_interval

How often the surrogate model is updated

Type

int
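The budget, calibrate_interval and true_data_num attributes jointly bound how many times the surrogate can be refined. A sketch under stated assumptions: pre-training consumes a hypothetical number initial_evals of true-model evaluations, every surrogate update adds true_data_num more, and updates stop once the next one would exceed the budget. The helper name is hypothetical.

```python
def surrogate_update_iterations(n_iter, calibrate_interval, budget,
                                true_data_num, initial_evals):
    """Hypothetical helper: iterations at which the surrogate is recalibrated.

    ASSUMPTIONS: pre-training uses initial_evals true-model evaluations, each
    update adds true_data_num more, and no update may exceed the budget.
    """
    used = initial_evals
    updates = []
    for it in range(1, n_iter + 1):
        if it % calibrate_interval == 0 and used + true_data_num <= budget:
            used += true_data_num
            updates.append(it)
    return updates, used


updates, used = surrogate_update_iterations(
    n_iter=1000, calibrate_interval=100, budget=60, true_data_num=4, initial_evals=50)
print(updates, used)  # [100, 200] 58
```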

flow_type

Type of flow (‘maf’ or ‘realnvp’)

Type

str

hidden_size

Hidden layer size for MADE in each layer

Type

int

input_order

Input order for create_mask (either ‘sequential’ or ‘random’)

Type

str
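In MADE, the input order fixes which inputs each output may depend on, and the masks zero out every connection that would violate that order. A minimal pure-Python sketch of the standard MADE mask construction; the function name, its signature and the deterministic hidden-degree assignment are assumptions and may differ from LINFA's create_mask.

```python
import random


def create_masks(input_size, hidden_size, n_hidden, input_order="sequential", seed=0):
    """Hypothetical sketch of MADE mask construction (not LINFA's create_mask).

    Returns (masks, input_degrees); masks[k][j][i] == 1 means unit j of layer
    k + 1 may receive input from unit i of layer k.
    """
    if input_size < 2:
        raise ValueError("MADE needs at least two inputs")
    rng = random.Random(seed)
    in_degrees = list(range(1, input_size + 1))  # degree = position in the ordering
    if input_order == "random":
        rng.shuffle(in_degrees)
    elif input_order != "sequential":
        raise ValueError("input_order must be 'sequential' or 'random'")

    degrees = [in_degrees]
    for _ in range(n_hidden):
        # Hidden degrees in 1..D-1 give every hidden unit at least one input and one output.
        degrees.append([i % (input_size - 1) + 1 for i in range(hidden_size)])

    masks = []
    for prev, nxt in zip(degrees[:-1], degrees[1:]):
        # A hidden unit of degree m sees units of the previous layer with degree <= m.
        masks.append([[int(d_out >= d_in) for d_in in prev] for d_out in nxt])
    # Output d may depend only on inputs with degree < d (strict inequality).
    masks.append([[int(d_out > d_in) for d_in in degrees[-1]] for d_out in in_degrees])
    return masks, in_degrees


masks, deg = create_masks(3, 4, 1)
print(masks[0])  # [[1, 0, 0], [1, 1, 0], [1, 0, 0], [1, 1, 0]]
print(masks[1])  # [[0, 0, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]]
```

With input_order="random" only the input degrees are permuted; the strict inequality in the output mask is what guarantees the autoregressive property.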

input_size

Number of input parameters

Type

int

linear_step

Fixed step size for the Linear annealing scheduler

Type

double
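A Linear scheduler advances the inverse temperature by the fixed amount linear_step. A sketch assuming the sequence starts at t0, grows by linear_step per stage and is clipped at 1; the helper linear_schedule is hypothetical.

```python
def linear_schedule(t0, linear_step):
    """Hypothetical helper: inverse temperatures visited by a Linear scheduler.

    ASSUMPTION: t grows by linear_step per stage and is clipped at 1.
    """
    if not 0.0 <= t0 <= 1.0 or linear_step <= 0.0:
        raise ValueError("need 0 <= t0 <= 1 and linear_step > 0")
    temps = [t0]
    while temps[-1] < 1.0:
        temps.append(min(1.0, temps[-1] + linear_step))
    return temps


print(linear_schedule(0.25, 0.25))  # [0.25, 0.5, 0.75, 1.0]
```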

log_file

Name of the file where the log profile statistics are written

Type

str

log_interval

How often the loss statistics are printed

Type

int

lr

Learning rate

Type

double

lr_decay

Learning rate decay

Type

double

lr_scheduler

Type of learning rate scheduler used (either ‘StepLR’ or ‘ExponentialLR’)

Type

str

lr_step

Number of steps for StepLR learning rate scheduler

Type

int
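For reference, PyTorch's StepLR multiplies the learning rate by gamma every step_size scheduler steps, while ExponentialLR multiplies it by gamma at every step. The closed-form sketch below assumes that lr_decay plays the role of gamma and lr_step the role of step_size; that mapping, and the helper learning_rate, are assumptions.

```python
def learning_rate(iteration, lr, lr_decay, lr_scheduler="StepLR", lr_step=1):
    """Closed-form learning rate after `iteration` scheduler steps.

    Mirrors PyTorch's StepLR and ExponentialLR. ASSUMPTION: lr_decay acts as
    gamma and lr_step as step_size.
    """
    if lr_scheduler == "StepLR":
        return lr * lr_decay ** (iteration // lr_step)
    if lr_scheduler == "ExponentialLR":
        return lr * lr_decay ** iteration
    raise ValueError("lr_scheduler must be 'StepLR' or 'ExponentialLR'")


print(learning_rate(25, lr=0.01, lr_decay=0.5, lr_scheduler="StepLR", lr_step=10))
print(learning_rate(3, lr=1.0, lr_decay=0.5, lr_scheduler="ExponentialLR"))  # 0.125
```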

n_blocks

Number of layers (blocks) in the normalizing flow

Type

int

n_hidden

Number of hidden layers in each MADE

Type

int

n_iter

Total number of iterations

Type

int

n_sample

Number of batch samples used to print results every save_interval iterations

Type

int

no_cuda

Flag to disable CUDA and run on the CPU

Type

bool

optimizer

Type of optimizer used (either ‘Adam’ or ‘RMSprop’)

Type

str

output_dir

Name of the output folder

Type

str

run()

Runs an instance of the inference problem

run_nofas

Activate NoFAS and the use of a surrogate model

Type

bool

save_interval

Save interval for all results

Type

int

scheduler

Type of annealing scheduler (either ‘AdaAnn’ or ‘Linear’)

Type

str

seed

Random seed

Type

int

store_nf_interval

Save interval for normalizing flow parameters

Type

int

store_surr_interval

Save interval for surrogate model (None for no save)

Type

int

surr_folder

Folder where the surrogate model is stored

Type

str

surr_pre_it

Number of pre-training iterations for surrogate model

Type

int

surr_upd_it

Number of iterations for the surrogate model update

Type

int

surrogate_type

Type of surrogate model (‘surrogate’ or ‘discrepancy’)

Type

str

t0

Initial value for the inverse temperature

Type

double

tol

KL divergence tolerance for the AdaAnn scheduler

Type

double

train(nf, optimizer, iteration, log, sampling=True, t=1)

Parameter update for normalizing flow and surrogate model

This is the function where the ELBO loss function is evaluated, the results are saved and the surrogate model is updated.

Parameters
  • nf (instance of normalizing flow) – the normalizing flow architecture used for variational inference

  • optimizer (instance of PyTorch optimizer) – the selected PyTorch optimizer

  • iteration (int) – current iteration number

  • log (list of lists) – stores a log of [iteration, annealing temperature, loss value]

  • sampling (bool) – Flag indicating the sampling stage

  • t (double) – current inverse temperature for annealing scheduler

Returns

None
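In normalizing-flow variational inference, the quantity minimized at inverse temperature t is commonly the negative annealed ELBO, E_q[log q(x) - t log p(x)], where log q(x) follows from the base density and the log-determinant of the flow Jacobian. The sketch below estimates it by Monte Carlo for a one-dimensional affine flow; the function name, the affine flow and the standard-normal target are illustrative assumptions, not what train() evaluates internally.

```python
import math
import random

LOG_2PI = math.log(2.0 * math.pi)


def annealed_neg_elbo(mu, log_sigma, log_target, t, n_samples, rng):
    """Monte Carlo estimate of E_q[log q(x) - t * log p(x)] for x = mu + sigma * z.

    z ~ N(0, 1) is the base distribution; for this affine map
    log|det J| = log_sigma, so log q(x) = log N(z; 0, 1) - log_sigma.
    """
    sigma = math.exp(log_sigma)
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(0.0, 1.0)
        x = mu + sigma * z
        log_q = -0.5 * z * z - 0.5 * LOG_2PI - log_sigma  # change of variables
        total += log_q - t * log_target(x)
    return total / n_samples


def log_std_normal(x):
    return -0.5 * x * x - 0.5 * LOG_2PI


rng = random.Random(0)
# A flow that already matches the target gives zero loss at t = 1.
print(annealed_neg_elbo(0.0, 0.0, log_std_normal, t=1.0, n_samples=1000, rng=rng))  # 0.0
```

Because the target here is normalized, the estimate at t = 1 approaches KL(q || p); shifting the flow to mu = 1 gives a value near 0.5. With an unnormalized posterior the loss differs from the KL divergence by a constant, which does not affect the gradients.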

true_data_num

Number of true model evaluations at each surrogate update

Type

double

use_new_surr

Start by pre-training a new surrogate and ignore existing surrogates

Type

bool