# Physics-Informed Neural Networks (PINNs)

Physics-Informed Neural Networks (PINNs) incorporate physical laws, described by differential equations, directly into the neural network training process. This lets the network learn from both data and physics, which reduces data requirements and encourages physically consistent solutions.
## Overview
PINNs leverage automatic differentiation to embed PDEs into the loss function:
- Data-driven learning from observed measurements
- Physics enforcement through PDE residual minimization
- Boundary/initial conditions as soft or hard constraints
- No mesh required - operates on collocation points
**Survey Reference:** This framework implements methodologies from the full PINN survey (arXiv:2601.10222v1).
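## Theoretical Foundation

### Problem Formulation

Consider a PDE of the form:

$$\mathcal{L}[u](x) = f(x), \quad x \in \Omega$$

with boundary conditions:

$$\mathcal{B}[u](x) = g(x), \quad x \in \partial\Omega$$

where \(\mathcal{L}\) is a differential operator and \(\mathcal{B}\) a boundary operator. A neural network \(u_\theta(x)\) approximates the solution by minimizing a weighted sum of the loss components defined below:

$$\mathcal{L}_{total}(\theta) = \lambda_{pde}\,\mathcal{L}_{pde} + \lambda_{bc}\,\mathcal{L}_{bc} + \lambda_{data}\,\mathcal{L}_{data}$$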
### Loss Components

**PDE residual loss**, evaluated at \(N_r\) interior collocation points:

$$\mathcal{L}_{pde} = \frac{1}{N_r} \sum_{i=1}^{N_r} \left| \mathcal{L}[u_\theta](x_i) - f(x_i) \right|^2$$

**Boundary condition loss**, evaluated at \(N_b\) boundary points:

$$\mathcal{L}_{bc} = \frac{1}{N_b} \sum_{i=1}^{N_b} \left| \mathcal{B}[u_\theta](x_i) - g(x_i) \right|^2$$

**Data loss**, evaluated at \(N_d\) observation points with measured values \(u^{obs}_i\):

$$\mathcal{L}_{data} = \frac{1}{N_d} \sum_{i=1}^{N_d} \left| u_\theta(x_i) - u^{obs}_i \right|^2$$
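All three loss terms can be computed directly with JAX autodiff. The following is a minimal, self-contained illustration, not opifex code. It assumes a plain tanh MLP stored as a list of `(W, b)` pairs (`init_mlp`, `u_theta` and `pinn_losses` are hypothetical helpers) and the 1D problem \(u'' = -\sin(\pi x)\) on [0, 1] with \(u(0) = u(1) = 0\). The single observation is taken from the exact solution \(\sin(\pi x)/\pi^2\), and the loss weights are illustrative.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Hypothetical helper: tanh-MLP parameters as a list of (W, b) pairs."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]


def u_theta(params, x):
    """Scalar network output at one point x of shape (d,)."""
    h = x
    for W, b in params[:-1]:
        h = jnp.tanh(h @ W + b)
    W, b = params[-1]
    return (h @ W + b)[0]


def pinn_losses(params, x_r, x_b, g_b, x_d, u_obs):
    """(L_pde, L_bc, L_data) for L[u] = u_xx, f(x) = -sin(pi x), B[u] = u."""
    def u_1d(s):  # scalar -> scalar, so jax.grad can be applied twice
        return u_theta(params, jnp.atleast_1d(s))

    # PDE residual L[u](x_i) - f(x_i) at the collocation points
    u_xx = jax.vmap(jax.grad(jax.grad(u_1d)))(x_r[:, 0])
    f = -jnp.sin(jnp.pi * x_r[:, 0])
    loss_pde = jnp.mean((u_xx - f) ** 2)

    # Dirichlet boundary misfit B[u](x_i) - g(x_i)
    u_b = jax.vmap(lambda x: u_theta(params, x))(x_b)
    loss_bc = jnp.mean((u_b - g_b) ** 2)

    # Misfit against observed values
    u_d = jax.vmap(lambda x: u_theta(params, x))(x_d)
    loss_data = jnp.mean((u_d - u_obs) ** 2)
    return loss_pde, loss_bc, loss_data


params = init_mlp(jax.random.key(0), [1, 32, 32, 1])
x_r = jax.random.uniform(jax.random.key(1), (128, 1))  # N_r = 128 collocation points
x_b = jnp.array([[0.0], [1.0]])                        # N_b = 2 boundary points
g_b = jnp.zeros(2)                                     # u(0) = u(1) = 0
x_d = jnp.array([[0.5]])                               # N_d = 1 observation point
u_obs = jnp.array([1.0 / jnp.pi**2])                   # exact u(0.5) = 1 / pi^2

weights = (1.0, 10.0, 1.0)  # illustrative lambda_pde, lambda_bc, lambda_data
total_loss = sum(w * l for w, l in zip(weights, pinn_losses(params, x_r, x_b, g_b, x_d, u_obs)))
```

Because `total_loss` is an ordinary differentiable function of `params`, `jax.grad` of it gives the parameter update for any gradient-based optimizer.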
## Multi-Scale PINNs

The opifex library provides a specialized `MultiScalePINN` architecture designed to capture physical phenomena across multiple scales.
### `opifex.neural.pinns.multi_scale.MultiScalePINN`

```python
MultiScalePINN(input_dim: int, output_dim: int, scales: list[int], hidden_dims: list[int], *, activation: Callable[[Array], Array] = gelu, rngs: Rngs)
```

Bases: `Module`

Multi-Scale Physics-Informed Neural Network.

This architecture processes input at multiple scale levels to capture multi-scale physics phenomena. Each scale network captures information at different resolutions, and the outputs are combined to form the final prediction.

The architecture is particularly effective for:

- Fluid dynamics with multiple scales (boundary layers, turbulence)
- Heat transfer with different thermal scales
- Electromagnetic phenomena with multiple wavelengths
- Quantum mechanics with multi-scale wave functions
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `input_dim` | `int` | Input dimensionality (spatial coordinates) | required |
| `output_dim` | `int` | Output dimensionality (solution fields) | required |
| `scales` | `list[int]` | List of scale factors for multi-scale processing | required |
| `hidden_dims` | `list[int]` | Hidden layer dimensions for each scale network | required |
| `activation` | `Callable[[Array], Array]` | Activation function | `gelu` |
| `rngs` | `Rngs` | Random number generators | required |
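The description above does not specify how the scale networks are combined. The following pure-JAX sketch shows one plausible design. It assumes each scale network receives the input multiplied by its scale factor and that the outputs are summed. `init_mlp`, `mlp`, `init_multiscale` and `multiscale_forward` are hypothetical names, and `MultiScalePINN` may scale or combine differently, for example with learned weights.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]


def mlp(params, x, activation=jax.nn.gelu):
    for W, b in params[:-1]:
        x = activation(x @ W + b)
    W, b = params[-1]
    return x @ W + b


def init_multiscale(key, input_dim, output_dim, scales, hidden_dims):
    """One independent sub-network per scale factor (hypothetical layout)."""
    keys = jax.random.split(key, len(scales))
    return [init_mlp(k, [input_dim, *hidden_dims, output_dim]) for k in keys]


def multiscale_forward(nets, scales, x):
    """Sum of sub-network outputs; sub-network i sees the input scaled by scales[i]."""
    per_scale = [mlp(p, s * x) for p, s in zip(nets, scales)]
    return sum(per_scale), per_scale


scales = [1, 2, 4]
nets = init_multiscale(jax.random.key(0), input_dim=2, output_dim=1,
                       scales=scales, hidden_dims=[64, 32])
x = jax.random.uniform(jax.random.key(1), (8, 2))
u, per_scale = multiscale_forward(nets, scales, x)  # per_scale mirrors get_scale_outputs
```

Multiplying the input by a larger factor lets a sub-network represent higher-frequency content with similarly sized weights. This is the usual motivation for input-scaled multi-scale networks.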
#### `get_scale_outputs`

```python
get_scale_outputs(x: Array) -> list[Array]
```
Get outputs from individual scale networks.
This is useful for analysis and debugging multi-scale behavior.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input coordinates `(batch_size, input_dim)` | required |

Returns:

| Type | Description |
|---|---|
| `list[Array]` | List of outputs from each scale network |
#### `compute_derivatives`
Compute derivatives of the multi-scale solution.
This is essential for physics-informed training where PDE residuals require derivative computations.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input coordinates `(batch_size, input_dim)` | required |
| `order` | `int` | Derivative order (1 for first derivatives, 2 for second) | `1` |

Returns:

| Type | Description |
|---|---|
| `dict[str, Array]` | Dictionary containing derivative tensors |
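The keys of the returned dictionary are not documented here. The following pure-JAX sketch performs the same kind of computation for a scalar field. The keys `"first"` and `"second"` are hypothetical and not necessarily those used by `MultiScalePINN.compute_derivatives`.

```python
import jax
import jax.numpy as jnp


def compute_derivatives_sketch(u, x, order=1):
    """Derivatives of a scalar field u: (d,) -> () at a batch x of shape (N, d).

    Hypothetical keys:
      "first":  (N, d)     gradient  du/dx_j
      "second": (N, d, d)  Hessian   d2u/dx_j dx_k  (only when order >= 2)
    """
    derivs = {"first": jax.vmap(jax.grad(u))(x)}
    if order >= 2:
        derivs["second"] = jax.vmap(jax.hessian(u))(x)
    return derivs


def u(p):
    """Example field u(x0, x1) = sin(x0) * x1^2."""
    return jnp.sin(p[0]) * p[1] ** 2


x = jnp.array([[0.0, 1.0], [jnp.pi / 2, 2.0]])
derivs = compute_derivatives_sketch(u, x, order=2)
```

Second derivatives are what second-order PDE residuals (diffusion, Laplace, viscous terms) consume. Computing them only when `order >= 2` avoids the extra Hessian cost for first-order problems.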
### Factory Functions

Create a Multi-Scale PINN for heat equation problems.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `spatial_dim` | `int` | Spatial dimensionality (1D, 2D, or 3D) | required |
| `scales` | `list[int] \| None` | Scale factors (default: `[1, 2, 4]` for multi-scale) | `None` |
| `hidden_dims` | `list[int] \| None` | Hidden layer dimensions (default: `[64, 32]`) | `None` |
| `rngs` | `Rngs` | Random number generators | required |

Returns:

| Type | Description |
|---|---|
| `MultiScalePINN` | Configured Multi-Scale PINN for the heat equation |
Create a Multi-Scale PINN for Navier-Stokes equations.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `spatial_dim` | `int` | Spatial dimensionality (2D or 3D) | required |
| `scales` | `list[int] \| None` | Scale factors (default: `[1, 2, 4, 8]` for turbulence) | `None` |
| `hidden_dims` | `list[int] \| None` | Hidden layer dimensions (default: `[128, 64, 32]`) | `None` |
| `rngs` | `Rngs` | Random number generators | required |

Returns:

| Type | Description |
|---|---|
| `MultiScalePINN` | Configured Multi-Scale PINN for Navier-Stokes |
## Building Custom PINNs

You can build custom PINNs by combining `opifex.neural.base.StandardMLP` with `opifex.core.problems.PDEProblem`.

### Basic Example
```python
import jax
import jax.numpy as jnp
import optax
from flax import nnx

from opifex.neural.base import StandardMLP


# 1. Define the PDE: 1D Poisson equation u_xx = -f with f(x) = sin(pi x)
#    on [0, 1] and u(0) = u(1) = 0. Exact solution: u(x) = sin(pi x) / pi**2.
def poisson_residual(model, x):
    """Compute the residual u_xx + f at points x of shape (N, 1)."""

    def u_scalar(xi):
        return model(xi.reshape(1, -1)).squeeze()

    # Second derivative from the Hessian of the scalar network output
    u_xx = jax.vmap(lambda xi: jax.hessian(u_scalar)(xi).squeeze())(x)
    f = jnp.sin(jnp.pi * x[:, 0])
    return u_xx + f


# 2. Create the neural network
rngs = nnx.Rngs(0)
model = StandardMLP(
    layer_sizes=[1, 64, 64, 1],
    activation="tanh",
    rngs=rngs,
)


# 3. Define the loss function
def loss_fn(model, x_interior, x_boundary):
    # PDE residual on interior collocation points
    residual = poisson_residual(model, x_interior)
    pde_loss = jnp.mean(residual**2)
    # Boundary conditions u(0) = u(1) = 0
    bc_pred = model(x_boundary)
    bc_loss = jnp.mean(bc_pred**2)
    return pde_loss + 10.0 * bc_loss


# 4. Training
opt = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)


@nnx.jit  # compile the whole step once instead of re-tracing every iteration
def train_step(model, opt, x_interior, x_boundary):
    loss, grads = nnx.value_and_grad(loss_fn)(model, x_interior, x_boundary)
    opt.update(model, grads)
    return loss


# Generate training points
x_interior = jax.random.uniform(jax.random.key(0), (1000, 1))
x_boundary = jnp.array([[0.0], [1.0]])

for step in range(5000):
    loss = train_step(model, opt, x_interior, x_boundary)
    if step % 500 == 0:
        print(f"Step {step}: loss = {loss:.4e}")
```
### 2D Laplace Equation Example

```python
import jax
import jax.numpy as jnp
import optax
from flax import nnx


class LaplacePINN(nnx.Module):
    """PINN for solving the Laplace equation."""

    def __init__(self, hidden_dims: list[int], rngs: nnx.Rngs):
        layers = []
        dims = [2, *hidden_dims, 1]  # 2D input, scalar output
        for i in range(len(dims) - 1):
            layers.append(nnx.Linear(dims[i], dims[i + 1], rngs=rngs))
        self.layers = nnx.List(layers)

    def __call__(self, x):
        """Forward pass."""
        h = x
        for layer in list(self.layers)[:-1]:
            h = nnx.tanh(layer(h))
        return list(self.layers)[-1](h)

    def compute_residual(self, x):
        """Compute the Laplace residual u_xx + u_yy at points x of shape (N, 2)."""

        def u_scalar(xi):
            return self(xi.reshape(1, -1)).squeeze()

        def laplacian(xi):
            hess = jax.hessian(u_scalar)(xi)
            return hess[0, 0] + hess[1, 1]  # u_xx + u_yy

        return jax.vmap(laplacian)(x)


model = LaplacePINN(hidden_dims=[64, 64, 64], rngs=nnx.Rngs(0))

# Domain: unit square [0, 1]^2
x_interior = jax.random.uniform(jax.random.key(0), (1000, 2))

# Boundary: known Dirichlet conditions on x = 0 and x = 1
# (simplified; in practice, sample all four boundaries)
x_boundary = jnp.vstack([
    jnp.column_stack([jnp.zeros(25), jnp.linspace(0, 1, 25)]),
    jnp.column_stack([jnp.ones(25), jnp.linspace(0, 1, 25)]),
])
u_boundary = jnp.sin(jnp.pi * x_boundary[:, 1])  # Example BC


def loss_fn(model):
    pde_loss = jnp.mean(model.compute_residual(x_interior) ** 2)
    bc_loss = jnp.mean((model(x_boundary)[:, 0] - u_boundary) ** 2)
    return pde_loss + 10.0 * bc_loss


opt = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)


@nnx.jit
def train_step(model, opt):
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    opt.update(model, grads)
    return loss


for step in range(5000):
    loss = train_step(model, opt)
    if step % 1000 == 0:
        print(f"Step {step}: loss = {loss:.4e}")
```
## Training Enhancements

Opifex provides several techniques to improve PINN training:

### Loss Balancing

Use GradNorm for automatic multi-task loss balancing:

```python
import jax.numpy as jnp
from flax import nnx

from opifex.core.physics.gradnorm import GradNormBalancer

balancer = GradNormBalancer(num_losses=3, rngs=nnx.Rngs(0))
# pde_loss, bc_loss, data_loss: scalar losses computed inside your loss function
losses = jnp.array([pde_loss, bc_loss, data_loss])
weighted_loss = balancer.compute_weighted_loss(losses)
```
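The internals of `GradNormBalancer` are not shown above. For intuition, here is a minimal plain-JAX sketch of the standard GradNorm weight update. `per_loss_grad_norms` and `gradnorm_update` are hypothetical names, and `alpha` and `lr` are illustrative values. GradNorm as originally formulated measures gradient norms at the last shared layer, whereas this sketch uses all parameters for simplicity.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def per_loss_grad_norms(loss_fns, params):
    """||grad_theta L_i|| for each unweighted loss (here over all parameters)."""
    return jnp.array([jnp.linalg.norm(ravel_pytree(jax.grad(f)(params))[0])
                      for f in loss_fns])


def gradnorm_update(weights, grad_norms, losses, initial_losses, alpha=1.5, lr=0.025):
    """One GradNorm step on the loss weights w_i.

    G_i = ||grad(w_i * L_i)|| = w_i * grad_norms[i]. Each G_i is pushed toward
    mean(G) * r_i**alpha, where r_i is the relative inverse training rate.
    """
    num_tasks = weights.shape[0]
    ratios = losses / initial_losses   # L_i(t) / L_i(0)
    r = ratios / jnp.mean(ratios)      # r_i > 1 means task i is lagging behind
    # Computed outside gradnorm_loss, so the target is a constant w.r.t. w
    target = jnp.mean(weights * grad_norms) * r**alpha

    def gradnorm_loss(w):
        return jnp.sum(jnp.abs(w * grad_norms - target))

    new_w = weights - lr * jax.grad(gradnorm_loss)(weights)
    new_w = jnp.maximum(new_w, 1e-6)            # keep weights positive
    return num_tasks * new_w / jnp.sum(new_w)   # renormalise so the weights sum to T


# Toy example: two losses on shared params; loss 0 has not decreased, loss 1 has halved
params = jnp.array([1.0, 2.0])
norms = per_loss_grad_norms([lambda p: jnp.sum(p**2), lambda p: 3.0 * jnp.sum(p)], params)
new_w = gradnorm_update(jnp.ones(2), norms,
                        losses=jnp.array([1.0, 0.5]), initial_losses=jnp.ones(2))
```

The loss that has made less relative progress ends up with the larger weight. Renormalizing keeps the overall loss scale fixed, so the balancer only redistributes emphasis between terms.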
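### Adaptive Sampling

Use residual-based adaptive distribution (RAD) sampling to concentrate collocation points in high-residual regions:

```python
import jax

from opifex.training.adaptive_sampling import RADSampler

sampler = RADSampler()
# all_points: candidate collocation points of shape (N, d)
# model: a PINN exposing compute_residual, such as LaplacePINN above
key = jax.random.key(0)
residuals = model.compute_residual(all_points)
batch = sampler.sample(all_points, residuals, batch_size=256, key=key)
```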
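The defaults of `RADSampler` are not documented here. The following is a minimal sketch of RAD sampling. It assumes the commonly used density \(p(x) \propto |r(x)|^k / \mathrm{mean}(|r|^k) + c\) with \(k = c = 1\). `rad_sample` is a hypothetical helper, and the residual is a synthetic stand-in.

```python
import jax
import jax.numpy as jnp


def rad_sample(key, points, residuals, batch_size, k=1.0, c=1.0):
    """Sample `batch_size` distinct rows of `points` with probability
    proportional to |r|^k / mean(|r|^k) + c.

    Larger k concentrates samples on high-residual points; c > 0 keeps a
    uniform component so low-residual regions are not abandoned.
    """
    err = jnp.abs(residuals) ** k
    density = err / jnp.mean(err) + c
    probs = density / jnp.sum(density)
    idx = jax.random.choice(key, points.shape[0], shape=(batch_size,),
                            replace=False, p=probs)
    return points[idx]


points = jax.random.uniform(jax.random.key(0), (1000, 2))
# Stand-in residual that grows toward the corner (1, 1); in a PINN this
# would be model.compute_residual(points)
residuals = jnp.exp(5.0 * (points.sum(axis=1) - 2.0))
batch = rad_sample(jax.random.key(1), points, residuals, batch_size=256)
```

Residuals are typically recomputed and the batch redrawn every few hundred steps, so the sampling density tracks where the current model is least accurate.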
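### Second-Order Optimization

Use a hybrid optimizer, which switches from a first-order method to a second-order one, for faster convergence:

```python
# Assumption: the config and criterion types are exported alongside HybridOptimizer
from opifex.optimization.second_order import (
    HybridOptimizer, HybridOptimizerConfig, SwitchCriterion,
)

optimizer = HybridOptimizer(HybridOptimizerConfig(
    first_order_steps=1000,
    switch_criterion=SwitchCriterion.LOSS_VARIANCE,
))
```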
### Multilevel Training

Use multilevel training for hierarchical convergence:

```python
from flax import nnx

from opifex.training.multilevel import (
    CascadeTrainer, MultilevelAdam, create_network_hierarchy, prolongate,
)

# Build a three-level hierarchy of networks
hierarchy = create_network_hierarchy(
    input_dim=2, output_dim=1,
    base_hidden_dims=[64, 64], num_levels=3,
    coarsening_factor=0.5, rngs=nnx.Rngs(0),
)
trainer = CascadeTrainer(
    hierarchy=hierarchy,
    optimizer=MultilevelAdam(learning_rate=1e-3),
    prolongate_fn=prolongate,
)
```
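Prolongation transfers a trained coarse network to the next finer level. This sketch assumes coarser levels are narrower networks, as `coarsening_factor=0.5` suggests. It shows one common construction: copy the coarse weights into a wider network so that its output is initially unchanged. The helpers are hypothetical, and opifex's `prolongate` may use a different operator.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]


def mlp(params, x):
    for W, b in params[:-1]:
        x = jnp.tanh(x @ W + b)
    W, b = params[-1]
    return x @ W + b


def prolongate_widen(key, coarse, fine_sizes, noise=1e-2):
    """Copy a coarse MLP into a wider one without changing its output.

    New hidden units get small random incoming weights but zero outgoing
    weights, so they start silent yet still receive gradients.
    """
    fine = []
    keys = jax.random.split(key, len(coarse))
    for k, (W, b), m, n in zip(keys, coarse, fine_sizes[:-1], fine_sizes[1:]):
        m0, n0 = W.shape
        W_f = noise * jax.random.normal(k, (m, n))  # incoming weights of new units
        W_f = W_f.at[m0:, :].set(0.0)               # outgoing weights of new units = 0
        W_f = W_f.at[:m0, :n0].set(W)               # embed the coarse block
        b_f = jnp.zeros(n).at[:n0].set(b)
        fine.append((W_f, b_f))
    return fine


coarse = init_mlp(jax.random.key(0), [2, 32, 32, 1])   # e.g. a level at 0.5x width
fine = prolongate_widen(jax.random.key(1), coarse, [2, 64, 64, 1])
x = jax.random.uniform(jax.random.key(2), (16, 2))
```

Zeroing both the incoming and the outgoing weights of the new units would leave them with zero gradients forever. Keeping small random incoming weights avoids that while still preserving the coarse solution exactly at the moment of transfer.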
### NTK Diagnostics

Use NTK analysis to diagnose training issues:

```python
from opifex.core.physics.ntk import NTKSpectralAnalyzer

# model: a PINN (e.g. LaplacePINN above); x_train: collocation points of shape (N, d)
analyzer = NTKSpectralAnalyzer(model)
diagnostics = analyzer.analyze(x_train, learning_rate=1e-3)
print(f"Condition number: {diagnostics.condition_number}")
```
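The internals of `NTKSpectralAnalyzer` are not shown here. The following pure-JAX sketch computes the quantity such an analysis typically starts from: the empirical NTK \(K = J J^T\), where \(J\) is the Jacobian of the network outputs at the training points with respect to all parameters. `init_mlp`, `mlp` and `empirical_ntk` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def init_mlp(key, sizes):
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]


def mlp(params, x):
    for W, b in params[:-1]:
        x = jnp.tanh(x @ W + b)
    W, b = params[-1]
    return x @ W + b


def empirical_ntk(params, x):
    """K = J J^T with J[i, :] = d u_theta(x_i) / d theta (scalar-output network)."""
    flat, unravel = ravel_pytree(params)
    jac = jax.jacrev(lambda p: mlp(unravel(p), x)[:, 0])(flat)  # (N, P)
    return jac @ jac.T


params = init_mlp(jax.random.key(0), [2, 32, 32, 1])
x = jax.random.uniform(jax.random.key(1), (16, 2))
K = empirical_ntk(params, x)
eigs = jnp.linalg.eigvalsh(K)             # ascending order
condition_number = eigs[-1] / eigs[0]
# For the loss 0.5 * ||u_theta(x) - y||^2 in the linearised regime, plain
# gradient descent is stable only for learning rates below 2 / lambda_max.
max_stable_lr = 2.0 / eigs[-1]
```

A very large condition number means some directions in function space converge far more slowly than others under gradient descent. This is one common explanation for PINNs that fit the boundary data but stall on the PDE residual.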
## Advanced Methods

### Domain Decomposition

For large or complex domains, use domain decomposition methods:
| Method | Description | Best For |
|---|---|---|
| XPINN | Explicit interface conditions | Non-overlapping domains |
| FBPINN | Window function blending | Smooth solutions |
| CPINN | Conservation enforcement | Conservation laws |
| APINN | Learned gating | Unknown optimal decomposition |
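```python
from flax import nnx

from opifex.neural.pinns.domain_decomposition import XPINN, Subdomain, Interface

# subdomains: Subdomain objects; interfaces: Interface objects describing where
# neighbouring subdomains meet (their construction is not shown here)
model = XPINN(
    input_dim=2, output_dim=1,
    subdomains=subdomains,
    interfaces=interfaces,
    hidden_dims=[32, 32],
    rngs=nnx.Rngs(0),
)
```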
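The "explicit interface conditions" that distinguish XPINN can be illustrated in one dimension. This sketch assumes two subdomain networks on [0, 0.5] and [0.5, 1] for the Poisson problem \(u'' = -\sin(\pi x)\). The sketch uses two interface terms. The first pulls each network toward the average solution at the interface. The second penalizes any mismatch between the two PDE residuals there. All names are hypothetical, and opifex's `XPINN` may weight or define these terms differently.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]


def u_net(params, s):
    """Scalar tanh network evaluated at a scalar input s."""
    h = jnp.atleast_1d(s)
    for W, b in params[:-1]:
        h = jnp.tanh(h @ W + b)
    W, b = params[-1]
    return (h @ W + b)[0]


def residual(params, s):
    """Poisson residual u'' + sin(pi s) of one subdomain network."""
    u_ss = jax.grad(jax.grad(lambda r: u_net(params, r)))(s)
    return u_ss + jnp.sin(jnp.pi * s)


def xpinn_interface_loss(p1, p2, s_iface):
    """Interface terms coupling two subdomain networks at points s_iface."""
    u1 = jax.vmap(lambda s: u_net(p1, s))(s_iface)
    u2 = jax.vmap(lambda s: u_net(p2, s))(s_iface)
    u_avg = 0.5 * (u1 + u2)
    # Solution continuity: each network is pulled toward the interface average
    avg_term = jnp.mean((u1 - u_avg) ** 2) + jnp.mean((u2 - u_avg) ** 2)
    # Residual continuity: both networks should give the same PDE residual there
    r1 = jax.vmap(lambda s: residual(p1, s))(s_iface)
    r2 = jax.vmap(lambda s: residual(p2, s))(s_iface)
    return avg_term + jnp.mean((r1 - r2) ** 2)


# Two subdomains [0, 0.5] and [0.5, 1] sharing the interface point s = 0.5
p1 = init_mlp(jax.random.key(0), [1, 16, 16, 1])
p2 = init_mlp(jax.random.key(1), [1, 16, 16, 1])
s_iface = jnp.array([0.5])
```

In a full XPINN loss, each subdomain network also has its own PDE-residual and boundary/data terms on points inside its subdomain. The interface terms are what couple the otherwise independent networks.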
## Common PDEs

### Heat Equation

```python
import jax
import jax.numpy as jnp


def heat_residual(model, x, t, alpha=0.01):
    """Residual u_t - alpha * laplacian(u).

    x: spatial coordinates, shape (N,) or (N, d); t: time, shape (N,).
    The model maps a (1, d + 1) input [x, t] to a (1, 1) output.
    """
    xt = jnp.column_stack([x, t])

    def u_scalar(xi):
        return model(xi.reshape(1, -1)).squeeze()

    # Compute derivatives
    def compute_derivs(xi):
        grad_u = jax.grad(u_scalar)(xi)
        hess_u = jax.hessian(u_scalar)(xi)
        u_t = grad_u[-1]  # Time derivative (last input)
        laplacian = jnp.sum(jnp.diag(hess_u)[:-1])  # Spatial Laplacian
        return u_t - alpha * laplacian

    return jax.vmap(compute_derivs)(xt)
```
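A cheap way to test a residual function before training is to feed it an exact solution in place of a network. The residual should then vanish up to float32 round-off. The sketch below repeats `heat_residual` so that it runs on its own. It uses \(u(x, t) = e^{-\alpha \pi^2 t} \sin(\pi x)\), an exact solution of \(u_t = \alpha u_{xx}\), as a stand-in model (`exact_heat` is a hypothetical helper).

```python
import jax
import jax.numpy as jnp


def heat_residual(model, x, t, alpha=0.01):
    xt = jnp.column_stack([x, t])

    def u_scalar(xi):
        return model(xi.reshape(1, -1)).squeeze()

    def compute_derivs(xi):
        u_t = jax.grad(u_scalar)(xi)[-1]
        laplacian = jnp.sum(jnp.diag(jax.hessian(u_scalar)(xi))[:-1])
        return u_t - alpha * laplacian

    return jax.vmap(compute_derivs)(xt)


def exact_heat(xt, alpha=0.01):
    """u(x, t) = exp(-alpha pi^2 t) sin(pi x); input (N, 2), output (N, 1)."""
    x, t = xt[:, 0], xt[:, 1]
    return (jnp.exp(-alpha * jnp.pi**2 * t) * jnp.sin(jnp.pi * x))[:, None]


x = jax.random.uniform(jax.random.key(0), (64,))
t = jax.random.uniform(jax.random.key(1), (64,))
res = heat_residual(exact_heat, x, t)  # ~0 everywhere
```

Passing a mismatched `alpha` to `heat_residual` should produce a clearly nonzero residual. This makes the check sensitive to sign and coefficient mistakes, not just to crashes.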
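### Burgers' Equation

```python
import jax
import jax.numpy as jnp


def burgers_residual(model, x, t, nu=0.01):
    """Residual u_t + u * u_x - nu * u_xx of the 1D viscous Burgers equation.

    x, t: shape (N,); the model maps a (1, 2) input [x, t] to a (1, 1) output.
    """
    xt = jnp.column_stack([x, t])

    def u_scalar(xi):
        return model(xi.reshape(1, -1)).squeeze()

    def compute_derivs(xi):
        u = u_scalar(xi)
        grad_u = jax.grad(u_scalar)(xi)
        hess_u = jax.hessian(u_scalar)(xi)
        u_x, u_t = grad_u[0], grad_u[1]
        u_xx = hess_u[0, 0]
        return u_t + u * u_x - nu * u_xx

    return jax.vmap(compute_derivs)(xt)
```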
### Navier-Stokes (2D Incompressible)

```python
import jax
import jax.numpy as jnp


def navier_stokes_residual(model, xy, t, nu=0.01):
    """Model outputs [u, v, p]; unit density is assumed.

    xy: shape (N, 2); t: shape (N,). Returns residuals of shape (N, 3):
    x-momentum, y-momentum, continuity.
    """
    xyt = jnp.column_stack([xy, t])

    def field(xi):
        return model(xi.reshape(1, -1)).squeeze()  # [u, v, p]

    def compute_residuals(xi):
        # Get field values and derivatives
        uvp = field(xi)
        u, v, p = uvp[0], uvp[1], uvp[2]
        jac = jax.jacfwd(field)(xi)  # Shape (3, 3): [u, v, p] x [x, y, t]
        u_x, u_y, u_t = jac[0, 0], jac[0, 1], jac[0, 2]
        v_x, v_y, v_t = jac[1, 0], jac[1, 1], jac[1, 2]
        p_x, p_y = jac[2, 0], jac[2, 1]
        hess = jax.hessian(lambda xi: field(xi)[0])(xi)
        u_xx, u_yy = hess[0, 0], hess[1, 1]
        hess_v = jax.hessian(lambda xi: field(xi)[1])(xi)
        v_xx, v_yy = hess_v[0, 0], hess_v[1, 1]
        # Momentum equations
        res_u = u_t + u * u_x + v * u_y + p_x - nu * (u_xx + u_yy)
        res_v = v_t + u * v_x + v * v_y + p_y - nu * (v_xx + v_yy)
        # Continuity
        res_cont = u_x + v_y
        return jnp.array([res_u, res_v, res_cont])

    return jax.vmap(compute_residuals)(xyt)
```
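The same exact-solution check works for the Navier-Stokes residual. This sketch uses the 2D Taylor-Green vortex, an exact solution of the incompressible equations with unit density: \(u = \cos x \sin y\, F\), \(v = -\sin x \cos y\, F\), \(p = -\tfrac14(\cos 2x + \cos 2y) F^2\), with \(F = e^{-2\nu t}\). It repeats a compact `navier_stokes_residual` so that it runs on its own (`taylor_green` is a hypothetical helper).

```python
import jax
import jax.numpy as jnp


def navier_stokes_residual(model, xy, t, nu=0.01):
    xyt = jnp.column_stack([xy, t])

    def field(xi):
        return model(xi.reshape(1, -1)).squeeze()  # [u, v, p]

    def compute_residuals(xi):
        u, v, _ = field(xi)
        jac = jax.jacfwd(field)(xi)  # rows: u, v, p; columns: x, y, t
        u_x, u_y, u_t = jac[0]
        v_x, v_y, v_t = jac[1]
        p_x, p_y = jac[2, 0], jac[2, 1]
        hess_u = jax.hessian(lambda z: field(z)[0])(xi)
        hess_v = jax.hessian(lambda z: field(z)[1])(xi)
        res_u = u_t + u * u_x + v * u_y + p_x - nu * (hess_u[0, 0] + hess_u[1, 1])
        res_v = v_t + u * v_x + v * v_y + p_y - nu * (hess_v[0, 0] + hess_v[1, 1])
        return jnp.array([res_u, res_v, u_x + v_y])

    return jax.vmap(compute_residuals)(xyt)


def taylor_green(xyt, nu=0.01):
    """Exact Taylor-Green vortex; input (N, 3) = [x, y, t], output (N, 3) = [u, v, p]."""
    x, y, t = xyt[:, 0], xyt[:, 1], xyt[:, 2]
    F = jnp.exp(-2.0 * nu * t)
    u = jnp.cos(x) * jnp.sin(y) * F
    v = -jnp.sin(x) * jnp.cos(y) * F
    p = -0.25 * (jnp.cos(2 * x) + jnp.cos(2 * y)) * F**2
    return jnp.stack([u, v, p], axis=1)


xy = 2 * jnp.pi * jax.random.uniform(jax.random.key(0), (64, 2))
t = jax.random.uniform(jax.random.key(1), (64,))
res = navier_stokes_residual(taylor_green, xy, t, nu=0.01)  # ~0 in all three columns
```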
## Best Practices

### Network Architecture

- Activation: `tanh` for smooth solutions, `gelu` for faster training
- Depth: 3-5 layers for most problems
- Width: 32-128 neurons per layer
- Input normalization: Scale inputs to [-1, 1] or [0, 1]
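Scaling inputs to [-1, 1] is an affine map per coordinate. Here is a minimal sketch, where `normalize_inputs` is a hypothetical helper and the domain bounds are illustrative:

```python
import jax.numpy as jnp


def normalize_inputs(x, lower, upper):
    """Affinely map coordinates from the box [lower, upper] to [-1, 1], per dimension."""
    lower, upper = jnp.asarray(lower), jnp.asarray(upper)
    return 2.0 * (x - lower) / (upper - lower) - 1.0


# Example domain: x in [0, 10], t in [0, 2]
pts = jnp.array([[0.0, 0.0], [10.0, 2.0], [5.0, 1.0]])
out = normalize_inputs(pts, lower=[0.0, 0.0], upper=[10.0, 2.0])
# out == [[-1, -1], [1, 1], [0, 0]]
```

Apply the map inside the model's forward pass rather than to the stored collocation points. Autodiff derivatives are then still taken with respect to the physical coordinates, and JAX includes the `2 / (upper - lower)` chain-rule factor automatically.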
### Collocation Point Selection
- Interior: 1000-10000 points (problem-dependent)
- Boundary: 100-1000 points per boundary segment
- Distribution: Use adaptive sampling for efficiency
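For a rectangular domain, interior and boundary points can be drawn separately. The following sketch covers all four edges of the unit square. `sample_unit_square` is a hypothetical helper, and uniform sampling is just the simplest starting point, not a recommendation over adaptive schemes.

```python
import jax
import jax.numpy as jnp


def sample_unit_square(key, n_interior, n_per_side):
    """Uniform interior points plus uniform points on all four edges of [0, 1]^2."""
    k_int, k_edge = jax.random.split(key)
    interior = jax.random.uniform(k_int, (n_interior, 2))
    s = jax.random.uniform(k_edge, (4, n_per_side))  # position along each edge
    zeros, ones = jnp.zeros(n_per_side), jnp.ones(n_per_side)
    boundary = jnp.concatenate([
        jnp.column_stack([zeros, s[0]]),  # x = 0
        jnp.column_stack([ones, s[1]]),   # x = 1
        jnp.column_stack([s[2], zeros]),  # y = 0
        jnp.column_stack([s[3], ones]),   # y = 1
    ])
    return interior, boundary


interior, boundary = sample_unit_square(jax.random.key(0), n_interior=1000, n_per_side=100)
```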
### Loss Weighting
- Start with equal weights
- Use GradNorm for automatic balancing
- Increase BC weights if constraints are violated
- Monitor individual loss components
### Training Strategy
- Warmup: Use Adam with learning rate warmup
- Main training: Continue with Adam or switch to hybrid
- Fine-tuning: Use L-BFGS for final convergence
## See Also
- Domain Decomposition PINNs - Large-scale problems
- NTK Analysis - Training diagnostics
- Adaptive Sampling - Efficient collocation
- GradNorm - Loss balancing
- Second-Order Optimization - Fast convergence
- Multilevel Training - Hierarchical training
- Training Guide - General training procedures
- API Reference - Complete API documentation