# Bayesian FNO on Darcy Flow
| Metadata | Value |
|---|---|
| Level | Intermediate |
| Runtime | ~5 min (GPU) / ~20 min (CPU) |
| Prerequisites | JAX, Flax NNX, Variational Inference |
| Format | Python + Jupyter |
| Memory | ~2 GB RAM |
## Overview
This example wraps a standard Fourier Neural Operator (FNO) with the Amortized Variational Framework to enable uncertainty quantification. The framework wraps the model instead of modifying it, so the same approach is intended to work with other existing neural operators.
Key Concepts:
- Base Model: Standard FNO for Darcy flow prediction
- Variational Wrapper: `AmortizedVariationalFramework` adds uncertainty
- Amortization Network: Predicts posterior parameters from the input
- Monte Carlo: Sample-based uncertainty estimation
## What You'll Learn
- Create a base FNO model using `FourierNeuralOperator`
- Wrap it with `AmortizedVariationalFramework` for uncertainty
- Configure variational inference with `VariationalConfig`
- Estimate predictive uncertainty via perturbation sampling
## Coming from Standard FNO?

| Standard FNO | Bayesian FNO (This Example) |
|---|---|
| Point predictions | Predictions + uncertainty |
| `model(x)` returns `y` | Returns mean and variance |
| MSE loss | ELBO loss (MSE + KL) |
| No uncertainty | Epistemic + aleatoric decomposition |
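
The "mean and variance" output and the epistemic/aleatoric split are commonly estimated from Monte Carlo samples with the law of total variance: the average of the per-sample noise variances is the aleatoric part, and the variance of the per-sample means is the epistemic part. A minimal NumPy sketch, assuming each sample yields both a predicted mean and a predicted noise variance; `decompose_uncertainty` is a hypothetical helper, not part of the opifex API.

```python
import numpy as np

def decompose_uncertainty(means, variances):
    """Split predictive variance from S Monte Carlo samples (hypothetical helper).

    means, variances: arrays of shape (S, ...) with each sample's predicted
    mean and predicted (aleatoric) noise variance.
    """
    predictive_mean = means.mean(axis=0)
    aleatoric = variances.mean(axis=0)  # average predicted noise variance
    epistemic = means.var(axis=0)       # spread of the sample means
    return predictive_mean, aleatoric, epistemic, aleatoric + epistemic

# Toy usage: 5 samples of a 2-point prediction
rng = np.random.default_rng(0)
means = 1.0 + 0.1 * rng.standard_normal((5, 2))
variances = np.full((5, 2), 0.04)
mean, alea, epi, total = decompose_uncertainty(means, variances)
```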
Key differences:
- Wrapper pattern: Base FNO wrapped with variational framework
- Amortization: Additional network predicts posterior parameters
- Overhead: the amortization network adds ~73x as many parameters as the base FNO (3,900,452 vs. 53,473)
## Files
- Python Script: `examples/uncertainty/bayesian_fno.py`
- Jupyter Notebook: `examples/uncertainty/bayesian_fno.ipynb`
## Quick Start
### Run the Python Script
### Run the Jupyter Notebook
## Core Concepts
### Amortized Variational Inference
Traditional variational inference optimizes a separate set of posterior parameters for each datapoint. Amortized inference instead trains a neural network to predict posterior parameters directly from the input, so inference at test time is a single forward pass rather than a per-input optimization.
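
A minimal NumPy sketch of the amortization idea: a small network maps a flattened input to the mean and log standard deviation of a Gaussian posterior, and a weight sample is drawn with the reparameterization trick. The sizes, `W1`, `W2`, and both functions are hypothetical and for illustration only; this is not how `UncertaintyEncoder` is implemented.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: a flattened 8x8 input, a 16-unit hidden layer,
# and a posterior over 10 weights (a mean and a log-std for each).
input_dim, hidden_dim, num_weights = 64, 16, 10
W1 = 0.1 * rng.standard_normal((input_dim, hidden_dim))
W2 = 0.1 * rng.standard_normal((hidden_dim, 2 * num_weights))

def amortized_posterior(x):
    """Map one input to Gaussian posterior parameters in a single forward pass."""
    h = np.tanh(x.reshape(-1) @ W1)  # flatten, then hidden layer
    out = h @ W2
    mu, log_sigma = out[:num_weights], out[num_weights:]
    return mu, log_sigma

def sample_weights(mu, log_sigma, rng):
    """Reparameterized sample: w = mu + sigma * eps, with eps ~ N(0, I)."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(log_sigma) * eps

x = rng.standard_normal((8, 8))
mu, log_sigma = amortized_posterior(x)
w = sample_weights(mu, log_sigma, rng)
```

For a new input this costs one forward pass through the encoder, whereas per-datapoint VI would run an optimization loop for every input.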
### Variational Framework Components

| Component | Role |
|---|---|
| `MeanFieldGaussian` | Variational posterior over weights |
| `UncertaintyEncoder` | Amortization network |
| `AmortizedVariationalFramework` | Combines base model with VI |
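
For a mean-field Gaussian posterior (an independent Gaussian per weight) and a zero-mean isotropic Gaussian prior, the KL term in the ELBO has a closed form. The sketch below assumes that `MeanFieldGaussian` and `PriorConfig(prior_scale=1.0)` (used in Step 2) correspond to this standard setup; the prior in opifex may differ, since the API reference describes physics-informed priors. `kl_mean_field_to_prior` is a hypothetical helper.

```python
import numpy as np

def kl_mean_field_to_prior(mu, log_sigma, prior_scale=1.0):
    """KL( N(mu, diag(sigma^2)) || N(0, prior_scale^2 I) ), summed over weights.

    Assumes an isotropic zero-mean Gaussian prior (hypothetical helper).
    """
    sigma2 = np.exp(2.0 * log_sigma)
    p2 = prior_scale ** 2
    # Per-weight closed form: log(p / s) + (s^2 + mu^2) / (2 p^2) - 1/2
    kl = np.log(prior_scale) - log_sigma + (sigma2 + mu ** 2) / (2.0 * p2) - 0.5
    return kl.sum()

kl_zero = kl_mean_field_to_prior(np.zeros(3), np.zeros(3))   # posterior equals prior
kl_shift = kl_mean_field_to_prior(np.ones(4), np.zeros(4))   # means shifted by 1
```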
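## Implementation
### Step 1: Create Base FNO

```python
from flax import nnx

from opifex.neural.operators.fno.base import FourierNeuralOperator

# Standard (deterministic) FNO mapping a 1-channel input field
# to a 1-channel output field
base_fno = FourierNeuralOperator(
    in_channels=1,
    out_channels=1,
    hidden_channels=32,
    num_layers=4,
    modes=12,  # Fourier modes kept per spectral layer
    rngs=nnx.Rngs(42),
)
```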
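
`modes=12` sets how many low-frequency Fourier modes each spectral layer keeps. The simplified single-channel NumPy sketch below illustrates that truncation on a 64x64 grid (matching `input_dim=64 * 64 * 1` in Step 2): frequencies outside the kept block are zeroed, and kept ones are scaled by complex weights. A real FNO layer also mixes channels and adds a pointwise linear path; `spectral_conv2d` is a hypothetical illustration, not the opifex implementation.

```python
import numpy as np

def spectral_conv2d(u, weights, modes):
    """Single-channel Fourier layer sketch: keep low modes, scale them, invert.

    u: real field of shape (H, W), with 2 * modes <= H.
    weights: complex array of shape (2, modes, modes) for the
    positive- and negative-frequency row blocks.
    """
    H, W = u.shape
    u_hat = np.fft.rfft2(u)  # shape (H, W // 2 + 1)
    out_hat = np.zeros_like(u_hat)
    out_hat[:modes, :modes] = u_hat[:modes, :modes] * weights[0]
    out_hat[-modes:, :modes] = u_hat[-modes:, :modes] * weights[1]
    return np.fft.irfft2(out_hat, s=(H, W))  # back to a real (H, W) field

N, modes = 64, 12
weights = np.ones((2, modes, modes), dtype=complex)  # identity scaling
rows = np.arange(N)[:, None]
low = np.cos(2 * np.pi * 1 * rows / N) * np.ones((1, N))    # frequency 1
high = np.cos(2 * np.pi * 20 * rows / N) * np.ones((1, N))  # frequency 20
```

With unit weights, the frequency-1 pattern passes through unchanged, while the frequency-20 pattern lies outside the kept modes and is removed.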
### Step 2: Wrap with Variational Framework

```python
from flax import nnx

from opifex.neural.bayesian import (
    AmortizedVariationalFramework,
    PriorConfig,
    VariationalConfig,
)

prior_config = PriorConfig(prior_scale=1.0)

variational_config = VariationalConfig(
    input_dim=64 * 64 * 1,  # flattened 64x64 single-channel input
    hidden_dims=(64, 32),   # amortization network hidden layers
    num_samples=5,          # Monte Carlo samples
    kl_weight=1e-4,         # weight on the KL term
)

# Wrap the base FNO from Step 1; the amortization network is added on top
bayesian_fno = AmortizedVariationalFramework(
    base_model=base_fno,
    prior_config=prior_config,
    variational_config=variational_config,
    rngs=nnx.Rngs(43),
)
```
Terminal Output:

```text
Creating Bayesian FNO with variational framework...
Total parameters (FNO + amortization): 3,953,925
Amortization network added: 3,900,452 params
```
### Step 3: Training
The base FNO is trained with a standard MSE loss rather than the full ELBO, which requires `distrax` (see Note on Dependencies).
Terminal Output:

```text
Training Bayesian FNO...
Epoch 1/20: loss = 0.006995
Epoch 5/20: loss = 0.000382
Epoch 10/20: loss = 0.001228
Epoch 15/20: loss = 0.000237
Epoch 20/20: loss = 0.000273
Training time: 36.6s
Final loss: 0.000200
```
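
The "ELBO loss (MSE + KL)" in the comparison table adds the KL divergence between the weight posterior and the prior, scaled by `kl_weight` (1e-4 in Step 2). A minimal sketch, assuming MSE stands in for the Gaussian negative log-likelihood (equivalent up to scale and an additive constant when the noise variance is fixed); `negative_elbo` is a hypothetical helper.

```python
import numpy as np

def negative_elbo(pred, target, kl, kl_weight=1e-4):
    """MSE data term plus a down-weighted KL term (hypothetical helper)."""
    mse = np.mean((pred - target) ** 2)
    return mse + kl_weight * kl

# MSE of 2.0 plus 1e-4 * 100 from the KL term
loss = negative_elbo(np.array([1.0, 2.0]), np.array([1.0, 0.0]), kl=100.0)
```

With `kl=0` the objective reduces to the plain MSE loss used in this example.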
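### Step 4: Uncertainty Estimation

```python
import jax
import jax.numpy as jnp

# Perturbation-based uncertainty estimation (NUM_SAMPLES, SEED,
# test_inputs and base_fno are defined earlier in the script)
preds_list = []
for i in range(NUM_SAMPLES):
    # Perturb inputs with small Gaussian noise, using a distinct key per sample
    noisy_input = test_inputs + 0.01 * jax.random.normal(
        jax.random.PRNGKey(SEED + i), test_inputs.shape
    )
    preds_list.append(base_fno(noisy_input))
preds = jnp.stack(preds_list)  # shape: (NUM_SAMPLES, *prediction_shape)
uncertainty = jnp.std(preds, axis=0)  # per-point spread across samples
```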
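
A self-contained NumPy version of the same perturbation loop, with a toy function standing in for `base_fno` (the helper and toy model are hypothetical). The spread measured this way reflects the model's sensitivity to small input noise, not a posterior over weights, which is why this example describes it as simplified.

```python
import numpy as np

def perturbation_uncertainty(model, x, num_samples=5, noise_scale=0.01, seed=0):
    """Run `model` on noisy copies of `x`; return per-point mean and std."""
    rng = np.random.default_rng(seed)
    preds = np.stack([
        model(x + noise_scale * rng.standard_normal(x.shape))
        for _ in range(num_samples)
    ])
    return preds.mean(axis=0), preds.std(axis=0)

def toy_model(x):
    """Fixed nonlinear map standing in for base_fno (hypothetical)."""
    return 2.0 * np.sin(x)

x = np.linspace(0.0, 1.0, 16)
mean, std = perturbation_uncertainty(toy_model, x)
```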
Terminal Output:

```text
Results:
  Relative L2 Error: 5.9885
  MSE: 0.000265
  Mean Uncertainty: 0.000897
Uncertainty calibration analysis...
  Error-Uncertainty Correlation: 0.6306
  1-sigma coverage: 5.9%
  2-sigma coverage: 11.4%
```
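
The output does not show how the calibration numbers are defined. A common choice, assumed here, is the Pearson correlation between absolute error and predicted standard deviation, plus the fraction of points whose absolute error lies within k standard deviations. For a well-calibrated Gaussian predictive distribution, 1-sigma and 2-sigma coverage should be about 68.3% and 95.4%, so 5.9% and 11.4% indicate that the perturbation-based uncertainty is much too small (overconfident), even though it correlates with the error. `calibration_metrics` is a hypothetical helper.

```python
import numpy as np

def calibration_metrics(pred, target, sigma):
    """Error-uncertainty correlation and 1-/2-sigma coverage (hypothetical helper)."""
    abs_err = np.abs(pred - target).ravel()
    sigma = sigma.ravel()
    corr = np.corrcoef(abs_err, sigma)[0, 1]  # Pearson correlation
    cov1 = np.mean(abs_err <= 1.0 * sigma)    # fraction within 1 sigma
    cov2 = np.mean(abs_err <= 2.0 * sigma)    # fraction within 2 sigma
    return corr, cov1, cov2

corr, cov1, cov2 = calibration_metrics(
    pred=np.array([0.5, 1.5, 2.5, 3.5]),
    target=np.zeros(4),
    sigma=np.array([1.0, 1.0, 2.0, 2.0]),
)
```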
## Visualization


## Results Summary
| Metric | Value |
|---|---|
| Relative L2 Error | ~6.0 |
| MSE | 0.000265 |
| Mean Uncertainty | 0.000897 |
| Error-Uncertainty Corr | 0.63 |
| Training Time | ~37s |
| Base FNO Parameters | 53,473 |
| Total Parameters | 3,953,925 |
## Next Steps
### Experiments to Try
- Full ELBO training: Install `distrax` for true variational training
- Tune amortization: Adjust `hidden_dims` in `VariationalConfig`
- More MC samples: Increase `num_samples` for better uncertainty estimates
- Different base models: Try TFNO, UNO, or SFNO as the base model
### Related Examples
| Example | Level | What You'll Learn |
|---|---|---|
| UQNO on Darcy | Intermediate | Built-in Bayesian convolutions |
| FNO on Darcy | Beginner | Standard FNO training |
| Calibration Methods | Intermediate | Post-hoc calibration |
## API Reference
- `AmortizedVariationalFramework`: Wraps a base model with VI
- `VariationalConfig`: Configuration for variational inference
- `PriorConfig`: Configuration for physics-informed priors
- `MeanFieldGaussian`: Mean-field Gaussian posterior
- `UncertaintyEncoder`: Amortization network
## Troubleshooting

| Issue | Solution |
|---|---|
| `distrax` import error | Install: `uv pip install tf-keras distrax` |
| Memory issues | Reduce amortization `hidden_dims` |
| Poor calibration | Use more MC samples; tune the perturbation scale |
| High parameter count | Use a smaller amortization network |
## Note on Dependencies
Full variational inference with ELBO training requires `distrax` (together with `tf-keras`), installable with `uv pip install tf-keras distrax`.
This example uses simplified perturbation-based uncertainty for broader compatibility.