Variational Inference with Normalizing Flow
Consider the problem of estimating (in a Bayesian sense) the parameters \(\boldsymbol{z}\in\boldsymbol{\mathcal{Z}}\) of a physics-based or statistical model
\[
\boldsymbol{x} = \boldsymbol{f}(\boldsymbol{z}) + \boldsymbol{\varepsilon},
\]
from the observations \(\boldsymbol{x}\in\boldsymbol{\mathcal{X}}\) and a known statistical characterization of the error \(\boldsymbol{\varepsilon}\).
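To make this setting concrete, the Python sketch below makes three assumptions: the model is additive, \(\boldsymbol{x} = \boldsymbol{f}(\boldsymbol{z}) + \boldsymbol{\varepsilon}\); the errors are independent Gaussians \(\varepsilon_i\sim\mathcal{N}(0,\sigma^2)\); and the prior is standard normal. The linear forward model `linear_model`, the noise level `sigma` and all function names are hypothetical choices made for illustration only. The function returns the log-likelihood plus the log-prior, which equals \(\log p(\boldsymbol{z}|\boldsymbol{x})\) up to the unknown constant \(\log p(\boldsymbol{x})\).

```python
import math
from typing import Callable, List, Sequence

ForwardModel = Callable[[Sequence[float]], List[float]]


def gaussian_log_likelihood(z: Sequence[float], x_obs: Sequence[float],
                            forward_model: ForwardModel, sigma: float) -> float:
    """Log-likelihood of x_obs under x = f(z) + eps, eps_i ~ N(0, sigma^2) i.i.d. (assumed)."""
    x_pred = forward_model(z)
    n = len(x_obs)
    sq_err = sum((xo - xp) ** 2 for xo, xp in zip(x_obs, x_pred))
    return -0.5 * sq_err / sigma**2 - n * math.log(sigma) - 0.5 * n * math.log(2.0 * math.pi)


def standard_normal_log_prior(z: Sequence[float]) -> float:
    """Log-density of an (assumed) standard normal prior p(z)."""
    return -0.5 * sum(zi * zi for zi in z) - 0.5 * len(z) * math.log(2.0 * math.pi)


def unnormalized_log_posterior(z: Sequence[float], x_obs: Sequence[float],
                               forward_model: ForwardModel, sigma: float) -> float:
    """log p(z|x) up to the unknown additive constant -log p(x)."""
    return (gaussian_log_likelihood(z, x_obs, forward_model, sigma)
            + standard_normal_log_prior(z))


def linear_model(z: Sequence[float]) -> List[float]:
    """Hypothetical forward model f(z) = (z0 + z1, z0 - z1)."""
    return [z[0] + z[1], z[0] - z[1]]


x_obs = [1.5, 0.5]  # synthetic observations, consistent with z = (1.0, 0.5)
print(unnormalized_log_posterior([1.0, 0.5], x_obs, linear_model, sigma=0.1))
```

Only the log-likelihood depends on the error model. Swapping in a different noise distribution changes `gaussian_log_likelihood` and leaves the rest untouched.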
We tackle this problem with variational inference and normalizing flow. A normalizing flow (NF) is a nonlinear transformation \(F:\mathbb{R}^{d}\times \boldsymbol{\Lambda} \to \mathbb{R}^{d}\) designed to map an easy-to-sample base distribution \(q_{0}(\boldsymbol{z}_{0})\) into a close approximation \(q_{K}(\boldsymbol{z}_{K})\) of a desired target posterior density \(p(\boldsymbol{z}|\boldsymbol{x})\). This transformation can be determined by composing \(K\) bijections
\[
\boldsymbol{z}_{K} = F(\boldsymbol{z}_{0}) = F_{K}\circ F_{K-1}\circ\cdots\circ F_{k}\circ\cdots\circ F_{1}(\boldsymbol{z}_{0}),
\]
and evaluating the transformed density through the change of variable formula (see [V+09]).
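For \(\boldsymbol{z}_k = F_k(\boldsymbol{z}_{k-1})\), the change of variable formula gives \(\log q_{K}(\boldsymbol{z}_{K}) = \log q_{0}(\boldsymbol{z}_{0}) - \sum_{k=1}^{K}\log\left|\det \partial \boldsymbol{z}_{k}/\partial\boldsymbol{z}_{k-1}\right|\). The log-density can therefore be accumulated layer by layer while a sample is pushed through the flow. The Python sketch below illustrates this with two simplifying assumptions: element-wise affine bijections \(F_k(\boldsymbol{z}) = \boldsymbol{a}_k\odot\boldsymbol{z} + \boldsymbol{b}_k\) and a standard normal base \(q_0\). Practical flows use far more expressive layers. These were chosen because the resulting \(q_K\) is Gaussian and can be checked in closed form. The class and function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


class ElementwiseAffine:
    """Bijection F_k(z) = a * z + b applied component-wise (requires a_i != 0)."""

    def __init__(self, a: Sequence[float], b: Sequence[float]):
        if any(ai == 0.0 for ai in a):
            raise ValueError("all scales must be non-zero for F_k to be invertible")
        self.a, self.b = list(a), list(b)

    def forward(self, z: Sequence[float]) -> Tuple[List[float], float]:
        z_new = [ai * zi + bi for ai, zi, bi in zip(self.a, z, self.b)]
        # The Jacobian is diagonal, so log|det| is a sum of d terms: O(d) cost.
        log_det = sum(math.log(abs(ai)) for ai in self.a)
        return z_new, log_det


def std_normal_logpdf(z: Sequence[float]) -> float:
    """Log-density of the (assumed) base distribution q_0 = N(0, I)."""
    return -0.5 * sum(zi * zi for zi in z) - 0.5 * len(z) * math.log(2.0 * math.pi)


def flow_forward(layers: Sequence[ElementwiseAffine],
                 z0: Sequence[float]) -> Tuple[List[float], float]:
    """Push z0 ~ q_0 through F = F_K o ... o F_1; return (z_K, log q_K(z_K))."""
    z, log_q = list(z0), std_normal_logpdf(z0)
    for layer in layers:  # F_1 is applied first, F_K last
        z, log_det = layer.forward(z)
        log_q -= log_det  # change of variable formula, one layer at a time
    return z, log_q


layers = [ElementwiseAffine([2.0, 3.0], [0.0, 1.0]),
          ElementwiseAffine([0.5, -1.0], [1.0, 0.0])]
zK, log_qK = flow_forward(layers, [0.1, -0.2])
```

With these two layers, the composed map sends \(\mathcal{N}(\boldsymbol{0},\boldsymbol{I})\) to independent components \(\mathcal{N}(1,1)\) and \(\mathcal{N}(-1,3^2)\), which provides an exact check on `log_qK`. Any layer that returns its output together with \(\log|\det|\) can replace `ElementwiseAffine` without changing `flow_forward`.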
In the context of variational inference, we seek to determine an optimal set of parameters \(\boldsymbol{\lambda}\in\boldsymbol{\Lambda}\) so that \(q_{K}(\boldsymbol{z}_{K})\approx p(\boldsymbol{z}|\boldsymbol{x})\). Given observations \(\boldsymbol{x}\in\boldsymbol{\mathcal{X}}\), a likelihood function \(l_{\boldsymbol{z}}(\boldsymbol{x})\) (informed by the distribution of the error \(\boldsymbol{\varepsilon}\)) and a prior \(p(\boldsymbol{z})\), an NF-based approximation \(q_K(\boldsymbol{z})\) of the posterior distribution \(p(\boldsymbol{z}|\boldsymbol{x})\) can be computed by maximizing a lower bound on the log marginal likelihood \(\log p(\boldsymbol{x})\) (the so-called evidence lower bound, or ELBO) or, equivalently, by minimizing a free energy bound (see, e.g., [RM15]).
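Following [RM15], the free energy can be written as an expectation over the base distribution:
\[
\mathcal{F}(\boldsymbol{x}) = \mathbb{E}_{q_0(\boldsymbol{z}_0)}\left[\log q_0(\boldsymbol{z}_0)\right] - \mathbb{E}_{q_0(\boldsymbol{z}_0)}\left[\log\left(l_{\boldsymbol{z}_K}(\boldsymbol{x})\,p(\boldsymbol{z}_K)\right)\right] - \mathbb{E}_{q_0(\boldsymbol{z}_0)}\left[\sum_{k=1}^{K}\log\left|\det\frac{\partial \boldsymbol{z}_k}{\partial \boldsymbol{z}_{k-1}}\right|\right].
\]
This quantity equals minus the ELBO and satisfies \(\mathcal{F}(\boldsymbol{x})\geq -\log p(\boldsymbol{x})\), with equality when \(q_K\) coincides with the posterior. The Python sketch below estimates \(\mathcal{F}\) by Monte Carlo under assumptions chosen so that \(\log p(x)\) is available in closed form: a one-dimensional conjugate problem (prior \(z\sim\mathcal{N}(0,1)\), likelihood \(x|z\sim\mathcal{N}(z,\sigma^2)\)) and a single affine layer \(z_K = m + s\,z_0\). These choices and the function names are hypothetical. In a real setting \(m\) and \(s\) (that is, \(\boldsymbol{\lambda}\)) would typically be optimized with a stochastic gradient method applied to this estimate.

```python
import math
import random


def log_normal(v: float, mean: float, std: float) -> float:
    """Log-density of N(mean, std^2) evaluated at v."""
    return -0.5 * ((v - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def free_energy_estimate(m: float, s: float, x_obs: float, sigma: float,
                         n_samples: int = 10_000, seed: int = 0) -> float:
    """Monte Carlo estimate of the free energy for the one-layer flow z_K = m + s * z_0."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        z0 = rng.gauss(0.0, 1.0)                  # z_0 ~ q_0 = N(0, 1)
        zK = m + s * z0                           # z_K = F(z_0)
        log_q0 = log_normal(z0, 0.0, 1.0)         # log q_0(z_0)
        log_det = math.log(abs(s))                # log|dz_K/dz_0|
        # log l_z(x) + log p(z) for the assumed Gaussian likelihood and prior
        log_joint = log_normal(x_obs, zK, sigma) + log_normal(zK, 0.0, 1.0)
        total += log_q0 - log_joint - log_det
    return total / n_samples


x_obs, sigma = 1.0, 1.0
post_mean = x_obs / (1.0 + sigma**2)                                   # exact posterior mean
post_std = math.sqrt(sigma**2 / (1.0 + sigma**2))                      # exact posterior std
neg_log_evidence = -log_normal(x_obs, 0.0, math.sqrt(1.0 + sigma**2))  # -log p(x)
print(free_energy_estimate(post_mean, post_std, x_obs, sigma), neg_log_evidence)  # agree up to round-off
print(free_energy_estimate(2.0, 0.3, x_obs, sigma))  # larger than -log p(x): q_K differs from the posterior
```

When \(q_K\) matches the posterior exactly, every sample contributes \(-\log p(x)\), so the estimate has zero variance. For any other \(q_K\), the excess over \(-\log p(x)\) is the Kullback-Leibler divergence from \(q_K\) to the posterior.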
For computational convenience, normalizing flow transformations are chosen to be easily invertible, with Jacobian determinants that can be computed at a cost growing linearly with the problem dimensionality. Approaches in the literature include RealNVP [DSDB16], GLOW [KD18], and autoregressive transformations such as MAF [PPM18] and IAF [KSJ+16].
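An affine coupling layer of the kind introduced in RealNVP [DSDB16] shows how easy invertibility and a linear-cost Jacobian determinant can be obtained together. The first block of components passes through unchanged and parameterizes an element-wise scale and shift of the second block, \(\boldsymbol{y}_1 = \boldsymbol{x}_1\), \(\boldsymbol{y}_2 = \boldsymbol{x}_2\odot\exp(\boldsymbol{s}(\boldsymbol{x}_1)) + \boldsymbol{t}(\boldsymbol{x}_1)\). The Jacobian is therefore triangular, and \(\log|\det|\) is simply the sum of the log-scales. The pure-Python sketch below is illustrative only. The "networks" `scale_fn` and `shift_fn` are hypothetical fixed functions standing in for trainable neural networks, and the class name is not taken from any library.

```python
import math
from typing import Callable, List, Sequence, Tuple

BlockFn = Callable[[Sequence[float]], List[float]]


class AffineCoupling:
    """Affine coupling layer in the style of RealNVP (illustrative sketch).

    y1 = x1,   y2 = x2 * exp(s(x1)) + t(x1)
    where x1 holds the first d1 components and x2 the remaining ones.
    """

    def __init__(self, d1: int, scale_fn: BlockFn, shift_fn: BlockFn):
        self.d1, self.scale_fn, self.shift_fn = d1, scale_fn, shift_fn

    def forward(self, x: Sequence[float]) -> Tuple[List[float], float]:
        x1, x2 = list(x[:self.d1]), list(x[self.d1:])
        s, t = self.scale_fn(x1), self.shift_fn(x1)
        y2 = [xi * math.exp(si) + ti for xi, si, ti in zip(x2, s, t)]
        # Triangular Jacobian with diagonal (1, ..., 1, exp(s)): log|det| = sum(s),
        # a sum over d - d1 terms, i.e. linear in the dimension.
        return x1 + y2, sum(s)

    def inverse(self, y: Sequence[float]) -> List[float]:
        y1, y2 = list(y[:self.d1]), list(y[self.d1:])
        s, t = self.scale_fn(y1), self.shift_fn(y1)  # y1 == x1, so s and t are recomputable
        x2 = [(yi - ti) * math.exp(-si) for yi, si, ti in zip(y2, s, t)]
        return y1 + x2


def scale_fn(x1: Sequence[float]) -> List[float]:
    """Hypothetical stand-in for the scale network s(.)."""
    return [math.tanh(0.5 * x1[0] - 0.3 * x1[1]), math.tanh(0.2 * x1[0] + 0.7 * x1[1])]


def shift_fn(x1: Sequence[float]) -> List[float]:
    """Hypothetical stand-in for the shift network t(.)."""
    return [x1[0] * x1[1], x1[0] - 2.0 * x1[1]]


layer = AffineCoupling(d1=2, scale_fn=scale_fn, shift_fn=shift_fn)
x = [0.4, -1.1, 2.0, 0.3]
y, log_det = layer.forward(x)
x_back = layer.inverse(y)  # recovers x up to round-off, without inverting s or t
```

Because `scale_fn` and `shift_fn` are never inverted, they can be arbitrary functions. In practice, successive coupling layers alternate or permute which components are held fixed, so that every component is eventually transformed.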