Physics-Informed Neural Networks (PINNs)

Physics-Informed Neural Networks (PINNs) incorporate physical laws, described by differential equations, directly into the neural network training process. The network learns from both data and physics, which reduces data requirements and promotes physically consistent solutions.

Overview

PINNs leverage automatic differentiation to embed PDEs into the loss function:

  • Data-driven learning from observed measurements
  • Physics enforcement through PDE residual minimization
  • Boundary/initial conditions as soft or hard constraints
  • No mesh required - operates on collocation points
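The difference between soft and hard constraints in the list above can be illustrated with a small example. This is a minimal sketch, assuming a 1D domain [0, 1] with homogeneous Dirichlet conditions; `raw_net` is a hypothetical stand-in for any unconstrained network, not an opifex function.

```python
import jax.numpy as jnp


def raw_net(x):
    """Hypothetical stand-in for an unconstrained network N(x)."""
    return jnp.cos(3.0 * x) + 0.5


def hard_constrained_u(x):
    """Hard constraint: u(x) = x (1 - x) N(x) vanishes at x = 0 and x = 1 by construction."""
    return x * (1.0 - x) * raw_net(x)


def soft_bc_penalty(x_boundary):
    """Soft constraint: the boundary mismatch of the raw network becomes a loss term."""
    return jnp.mean(raw_net(x_boundary) ** 2)


x_boundary = jnp.array([0.0, 1.0])
hard_values = hard_constrained_u(x_boundary)  # zero regardless of the network weights
soft_penalty = soft_bc_penalty(x_boundary)    # nonzero until training drives it down
```

With the hard form no boundary loss term is needed, at the cost of designing an ansatz that fits the geometry. The soft form works for any geometry but only satisfies the conditions approximately.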

Survey Reference

This framework implements methodologies from the full PINN survey (arXiv:2601.10222v1).

Theoretical Foundation

Problem Formulation

Consider a PDE of the form:

\[\mathcal{L}[u](x) = f(x), \quad x \in \Omega\]

with boundary conditions:

\[\mathcal{B}[u](x) = g(x), \quad x \in \partial\Omega\]

A neural network \(u_\theta(x)\) approximates the solution by minimizing:

\[\mathcal{L}_{total} = \lambda_{pde} \mathcal{L}_{pde} + \lambda_{bc} \mathcal{L}_{bc} + \lambda_{data} \mathcal{L}_{data}\]

Loss Components

PDE Residual Loss:

\[\mathcal{L}_{pde} = \frac{1}{N_r} \sum_{i=1}^{N_r} \left| \mathcal{L}[u_\theta](x_i) - f(x_i) \right|^2\]

Boundary Condition Loss:

\[\mathcal{L}_{bc} = \frac{1}{N_b} \sum_{i=1}^{N_b} \left| \mathcal{B}[u_\theta](x_i) - g(x_i) \right|^2\]

Data Loss:

\[\mathcal{L}_{data} = \frac{1}{N_d} \sum_{i=1}^{N_d} \left| u_\theta(x_i) - u^{obs}_i \right|^2\]
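The three loss terms can be assembled in a few lines of JAX. The following is a minimal sketch, assuming a 1D problem with the operator chosen as the second derivative and a Dirichlet boundary operator. The names `init_mlp`, `u_theta`, and `total_loss` are illustrative and not part of opifex.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Build a list of (W, b) pairs for a small fully connected network."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in), jnp.zeros(n_out))
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]


def u_theta(params, x):
    """Scalar network output for a single scalar input x."""
    h = jnp.atleast_1d(x)
    for w, b in params[:-1]:
        h = jnp.tanh(h @ w + b)
    w, b = params[-1]
    return (h @ w + b)[0]


def total_loss(params, x_r, x_b, g_b, x_d, u_obs, f,
               lam_pde=1.0, lam_bc=1.0, lam_data=1.0):
    # PDE term: here L[u] = u'' (an illustrative choice of operator)
    u_xx_fn = jax.grad(jax.grad(u_theta, argnums=1), argnums=1)
    u_xx = jax.vmap(u_xx_fn, in_axes=(None, 0))(params, x_r)
    loss_pde = jnp.mean((u_xx - f(x_r)) ** 2)

    # Boundary term: here B[u] = u (Dirichlet)
    u_b = jax.vmap(u_theta, in_axes=(None, 0))(params, x_b)
    loss_bc = jnp.mean((u_b - g_b) ** 2)

    # Data term: mismatch against observations
    u_d = jax.vmap(u_theta, in_axes=(None, 0))(params, x_d)
    loss_data = jnp.mean((u_d - u_obs) ** 2)

    return lam_pde * loss_pde + lam_bc * loss_bc + lam_data * loss_data


params = init_mlp(jax.random.key(0), [1, 16, 1])
x_r = jnp.linspace(0.0, 1.0, 32)                 # collocation points
x_b, g_b = jnp.array([0.0, 1.0]), jnp.zeros(2)   # boundary points and values
x_d, u_obs = jnp.array([0.5]), jnp.array([0.1])  # one made-up observation
f = lambda x: -jnp.sin(jnp.pi * x)
loss = total_loss(params, x_r, x_b, g_b, x_d, u_obs, f)
```

The weights `lam_pde`, `lam_bc`, and `lam_data` correspond to the λ coefficients in the total loss; the observation value here is made up purely to exercise the data term.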

Multi-Scale PINNs

The opifex library provides a specialized MultiScalePINN architecture designed to capture physics phenomena across multiple scales.

opifex.neural.pinns.multi_scale.MultiScalePINN

MultiScalePINN(input_dim: int, output_dim: int, scales: list[int], hidden_dims: list[int], *, activation: Callable[[Array], Array] = gelu, rngs: Rngs)

Bases: Module

Multi-Scale Physics-Informed Neural Network.

This architecture processes input at multiple scale levels to capture multi-scale physics phenomena. Each scale network captures information at different resolutions, and the outputs are combined to form the final prediction.

The architecture is particularly effective for:

  • Fluid dynamics with multiple scales (boundary layers, turbulence)
  • Heat transfer with different thermal scales
  • Electromagnetic phenomena with multiple wavelengths
  • Quantum mechanics with multi-scale wave functions

Parameters:

  • input_dim (int, required): Input dimensionality (spatial coordinates)
  • output_dim (int, required): Output dimensionality (solution fields)
  • scales (list[int], required): List of scale factors for multi-scale processing
  • hidden_dims (list[int], required): Hidden layer dimensions for each scale network
  • activation (Callable[[Array], Array], default: gelu): Activation function
  • rngs (Rngs, required): Random number generators

Source code in opifex/neural/pinns/multi_scale.py
def __init__(
    self,
    input_dim: int,
    output_dim: int,
    scales: list[int],
    hidden_dims: list[int],
    *,
    activation: Callable[[Array], Array] = nnx.gelu,
    rngs: nnx.Rngs,
):
    """Initialize Multi-Scale PINN.

    Args:
        input_dim: Input dimensionality (spatial coordinates)
        output_dim: Output dimensionality (solution fields)
        scales: List of scale factors for multi-scale processing
        hidden_dims: Hidden layer dimensions for each scale network
        activation: Activation function
        rngs: Random number generators
    """
    super().__init__()
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.scales = scales
    self.num_scales = len(scales)
    self.activation = activation

    # Create scale-specific networks
    scale_networks_temp = []
    for _i, _scale in enumerate(scales):
        # Each scale network processes scaled input coordinates
        layers = []

        # Input layer
        layers.append(
            nnx.Linear(
                in_features=input_dim,
                out_features=hidden_dims[0],
                rngs=rngs,
            )
        )

        # Hidden layers
        for j in range(len(hidden_dims) - 1):
            layers.append(
                nnx.Linear(
                    in_features=hidden_dims[j],
                    out_features=hidden_dims[j + 1],
                    rngs=rngs,
                )
            )

        # Output layer for this scale
        layers.append(
            nnx.Linear(
                in_features=hidden_dims[-1],
                out_features=output_dim,
                rngs=rngs,
            )
        )

        scale_network = nnx.Sequential(*layers)
        scale_networks_temp.append(scale_network)

    # Register the scale networks once, after the loop has built all of them
    self.scale_networks = nnx.List(scale_networks_temp)

    # Combination weights for multi-scale fusion
    self.scale_weights = nnx.Linear(
        in_features=self.num_scales * output_dim,
        out_features=output_dim,
        rngs=rngs,
    )

get_scale_outputs

get_scale_outputs(x: Array) -> list[Array]

Get outputs from individual scale networks.

This is useful for analysis and debugging multi-scale behavior.

Parameters:

  • x (Array, required): Input coordinates (batch_size, input_dim)

Returns:

  • list[Array]: List of outputs from each scale network

Source code in opifex/neural/pinns/multi_scale.py
def get_scale_outputs(self, x: Array) -> list[Array]:
    """Get outputs from individual scale networks.

    This is useful for analysis and debugging multi-scale behavior.

    Args:
        x: Input coordinates (batch_size, input_dim)

    Returns:
        List of outputs from each scale network
    """
    scale_outputs = []

    for scale, network in zip(self.scales, self.scale_networks, strict=False):
        scaled_x = x * scale

        h = scaled_x
        for layer in network.layers[:-1]:
            h = layer(h)
            h = self.activation(h)

        scale_output = network.layers[-1](h)
        scale_outputs.append(scale_output)

    return scale_outputs
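The forward pass that combines these per-scale outputs is not shown on this page. Since `scale_weights` is a linear layer with `num_scales * output_dim` input features, one plausible reading is that the outputs are concatenated along the feature axis and then mapped linearly. The sketch below illustrates that assumption with plain arrays; `fuse_scale_outputs` is a hypothetical helper, not the library's method.

```python
import jax
import jax.numpy as jnp


def fuse_scale_outputs(scale_outputs, fusion_w, fusion_b):
    """Concatenate per-scale outputs along features, then apply a linear map."""
    stacked = jnp.concatenate(scale_outputs, axis=-1)  # (batch, num_scales * output_dim)
    return stacked @ fusion_w + fusion_b               # (batch, output_dim)


# Hypothetical sizes: 3 scales, 2 output fields, 5 input points
num_scales, output_dim, batch = 3, 2, 5
keys = jax.random.split(jax.random.key(0), num_scales + 1)
scale_outputs = [jax.random.normal(k, (batch, output_dim)) for k in keys[:-1]]
fusion_w = jax.random.normal(keys[-1], (num_scales * output_dim, output_dim))
fusion_b = jnp.zeros(output_dim)

fused = fuse_scale_outputs(scale_outputs, fusion_w, fusion_b)
```

Comparing the individual entries of `scale_outputs` against `fused` is one way to see which scale dominates the prediction in a given region.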

compute_derivatives

compute_derivatives(x: Array, order: int = 1) -> dict[str, Array]

Compute derivatives of the multi-scale solution.

This is essential for physics-informed training where PDE residuals require derivative computations.

Parameters:

  • x (Array, required): Input coordinates (batch_size, input_dim)
  • order (int, default: 1): Derivative order (1 for first derivatives, 2 for second)

Returns:

  • dict[str, Array]: Dictionary containing derivative tensors

Source code in opifex/neural/pinns/multi_scale.py
def compute_derivatives(self, x: Array, order: int = 1) -> dict[str, Array]:
    """Compute derivatives of the multi-scale solution.

    This is essential for physics-informed training where PDE residuals
    require derivative computations.

    Args:
        x: Input coordinates (batch_size, input_dim)
        order: Derivative order (1 for first derivatives, 2 for second)

    Returns:
        Dictionary containing derivative tensors
    """

    def solution_fn(coords):
        return self(coords.reshape(1, -1)).squeeze()

    derivatives = {}

    if order >= 1:
        # First derivatives
        grad_fn = jax.grad(solution_fn)
        if x.ndim == 2:
            # Batch processing
            derivatives["grad"] = jax.vmap(grad_fn)(x)
        else:
            derivatives["grad"] = grad_fn(x)

    if order >= 2:
        # Second derivatives (Laplacian)
        def laplacian_fn(coords):
            grad_fn = jax.grad(solution_fn)
            hessian_fn = jax.jacfwd(grad_fn)
            hessian = hessian_fn(coords)
            return jnp.trace(hessian)  # Laplacian = trace of Hessian

        if x.ndim == 2:
            derivatives["laplacian"] = jax.vmap(laplacian_fn)(x)
        else:
            derivatives["laplacian"] = laplacian_fn(x)

    return derivatives
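The same `jax.grad`, `jax.jacfwd`, and `jax.vmap` pattern can be checked against a function with known derivatives. The sketch below replaces the network with the analytic stand-in u(x, y) = sin(x) sin(y), whose Laplacian is -2u; the stand-in is chosen here for illustration only.

```python
import jax
import jax.numpy as jnp


def solution_fn(coords):
    """Analytic stand-in for the network: u(x, y) = sin(x) * sin(y)."""
    return jnp.sin(coords[0]) * jnp.sin(coords[1])


grad_fn = jax.grad(solution_fn)


def laplacian_fn(coords):
    """Laplacian as the trace of the Hessian (forward-over-reverse)."""
    hessian = jax.jacfwd(grad_fn)(coords)
    return jnp.trace(hessian)


x = jnp.array([[0.3, 0.7], [1.1, 0.2], [0.5, 0.5]])
grads = jax.vmap(grad_fn)(x)            # shape (3, 2): [u_x, u_y] per point
laplacians = jax.vmap(laplacian_fn)(x)  # shape (3,)
expected_laplacians = -2.0 * jax.vmap(solution_fn)(x)
```

Two properties of this pattern are worth keeping in mind. `jax.grad` requires a scalar output, so it applies as written when the solution has a single field. The trace also runs over every input coordinate, so if time is one of the inputs the result includes the second time derivative as well.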

Factory Functions

Create a Multi-Scale PINN for heat equation problems.

Parameters:

  • spatial_dim (int, required): Spatial dimensionality (1D, 2D, or 3D)
  • scales (list[int] | None, default: None): Scale factors (default: [1, 2, 4] for multi-scale)
  • hidden_dims (list[int] | None, default: None): Hidden layer dimensions (default: [64, 32])
  • rngs (Rngs, required): Random number generators

Returns:

  • MultiScalePINN: Configured Multi-Scale PINN for heat equation

Create a Multi-Scale PINN for Navier-Stokes equations.

Parameters:

  • spatial_dim (int, required): Spatial dimensionality (2D or 3D)
  • scales (list[int] | None, default: None): Scale factors (default: [1, 2, 4, 8] for turbulence)
  • hidden_dims (list[int] | None, default: None): Hidden layer dimensions (default: [128, 64, 32])
  • rngs (Rngs, required): Random number generators

Returns:

  • MultiScalePINN: Configured Multi-Scale PINN for Navier-Stokes

Building Custom PINNs

You can build custom PINNs by combining opifex.neural.base.StandardMLP with opifex.core.problems.PDEProblem.

Basic Example

import jax
import jax.numpy as jnp
from flax import nnx
from opifex.neural.base import StandardMLP
from opifex.core.problems import create_pde_problem

# 1. Define the PDE (e.g., 1D Poisson equation: u_xx = -f)
def poisson_residual(model, x):
    """Compute PDE residual for Poisson equation."""
    def u_scalar(xi):
        return model(xi.reshape(1, -1)).squeeze()

    # Compute second derivative
    u_xx = jax.vmap(lambda xi: jax.hessian(u_scalar)(xi).squeeze())(x)
    f = jnp.sin(jnp.pi * x[:, 0])
    return u_xx + f

# 2. Create the Neural Network
rngs = nnx.Rngs(0)
model = StandardMLP(
    layer_sizes=[1, 64, 64, 1],
    activation='tanh',
    rngs=rngs
)

# 3. Define loss function
def loss_fn(model, x_interior, x_boundary):
    # PDE residual
    residual = poisson_residual(model, x_interior)
    pde_loss = jnp.mean(residual ** 2)

    # Boundary conditions (u(0) = u(1) = 0)
    bc_pred = model(x_boundary)
    bc_loss = jnp.mean(bc_pred ** 2)

    return pde_loss + 10.0 * bc_loss

# 4. Training
import optax

opt = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

# Generate training points
x_interior = jax.random.uniform(jax.random.key(0), (1000, 1))
x_boundary = jnp.array([[0.0], [1.0]])

for step in range(5000):
    loss, grads = nnx.value_and_grad(
        lambda m: loss_fn(m, x_interior, x_boundary)
    )(model)
    opt.update(model, grads)

    if step % 500 == 0:
        print(f"Step {step}: loss = {loss:.4e}")

2D Laplace Equation Example

import jax
import jax.numpy as jnp
from flax import nnx

class LaplacePINN(nnx.Module):
    """PINN for solving the Laplace equation."""

    def __init__(self, hidden_dims: list[int], rngs: nnx.Rngs):
        layers = []
        dims = [2, *hidden_dims, 1]  # 2D input, scalar output
        for i in range(len(dims) - 1):
            layers.append(nnx.Linear(dims[i], dims[i+1], rngs=rngs))
        self.layers = nnx.List(layers)

    def __call__(self, x):
        """Forward pass."""
        h = x
        for layer in list(self.layers)[:-1]:
            h = nnx.tanh(layer(h))
        return list(self.layers)[-1](h)

    def compute_residual(self, x):
        """Compute Laplace equation residual: u_xx + u_yy = 0."""
        def u_scalar(xi):
            return self(xi.reshape(1, -1)).squeeze()

        def laplacian(xi):
            hess = jax.hessian(u_scalar)(xi)
            return hess[0, 0] + hess[1, 1]  # u_xx + u_yy

        return jax.vmap(laplacian)(x)

# Create the model and the training points
model = LaplacePINN(hidden_dims=[64, 64, 64], rngs=nnx.Rngs(0))

# Domain: unit square [0, 1]^2
x_interior = jax.random.uniform(jax.random.key(0), (1000, 2))

# Boundary: known Dirichlet conditions
# (simplified - in practice, sample all four boundaries)
x_boundary = jnp.vstack([
    jnp.column_stack([jnp.zeros(25), jnp.linspace(0, 1, 25)]),
    jnp.column_stack([jnp.ones(25), jnp.linspace(0, 1, 25)]),
])
u_boundary = jnp.sin(jnp.pi * x_boundary[:, 1])  # Example BC
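The 2D example stops after building the boundary data. A loss function in the spirit of the basic example could look like the sketch below. It assumes only that the model is callable and exposes `compute_residual`, as `LaplacePINN` does; `laplace_loss` and `HarmonicStub` are illustrative names, and the stub replaces a trained network with the exact harmonic function x² - y² so the result can be checked.

```python
import jax
import jax.numpy as jnp


def laplace_loss(model, x_interior, x_boundary, u_boundary, bc_weight=10.0):
    """PDE residual loss plus weighted Dirichlet boundary loss."""
    pde_loss = jnp.mean(model.compute_residual(x_interior) ** 2)
    # The model returns (N, 1); squeeze so it does not broadcast against (N,) to (N, N)
    u_pred = model(x_boundary).squeeze(-1)
    bc_loss = jnp.mean((u_pred - u_boundary) ** 2)
    return pde_loss + bc_weight * bc_loss


class HarmonicStub:
    """Stand-in with the same interface as LaplacePINN; u(x, y) = x^2 - y^2."""

    def __call__(self, x):
        return (x[:, 0] ** 2 - x[:, 1] ** 2)[:, None]

    def compute_residual(self, x):
        def u_scalar(xi):
            return self(xi.reshape(1, -1)).squeeze()

        return jax.vmap(lambda xi: jnp.trace(jax.hessian(u_scalar)(xi)))(x)


stub = HarmonicStub()
x_interior = jax.random.uniform(jax.random.key(0), (100, 2))
x_boundary = jnp.column_stack([jnp.zeros(25), jnp.linspace(0.0, 1.0, 25)])
u_exact = stub(x_boundary).squeeze(-1)  # boundary data consistent with the stub
loss_value = laplace_loss(stub, x_interior, x_boundary, u_exact)
```

The `squeeze(-1)` matters in practice: subtracting a `(N,)` target from a `(N, 1)` prediction silently broadcasts to `(N, N)` and produces a wrong loss without raising an error.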

Training Enhancements

Opifex provides several techniques to improve PINN training:

Loss Balancing

Use GradNorm for automatic multi-task loss balancing:

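import jax.numpy as jnp
from flax import nnx
from opifex.core.physics.gradnorm import GradNormBalancer

balancer = GradNormBalancer(num_losses=3, rngs=nnx.Rngs(0))
# pde_loss, bc_loss, data_loss: scalar loss terms computed by your loss function
losses = jnp.array([pde_loss, bc_loss, data_loss])
weighted_loss = balancer.compute_weighted_loss(losses)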
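To give a feel for why gradient magnitudes are a useful balancing signal, the sketch below implements a much simpler heuristic than GradNorm: each loss is weighted inversely to the norm of its gradient. This is not GradNorm itself and not opifex's implementation; the function names and the toy losses are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def grad_norm(loss_fn, params):
    """L2 norm of d(loss)/d(params), taken over the whole parameter pytree."""
    grads = jax.grad(loss_fn)(params)
    leaves = jax.tree_util.tree_leaves(grads)
    return jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))


def inverse_grad_norm_weights(loss_fns, params, eps=1e-8):
    """Weights proportional to 1 / ||grad L_i||, rescaled to sum to the number of losses."""
    norms = jnp.stack([grad_norm(fn, params) for fn in loss_fns])
    inv = 1.0 / (norms + eps)
    return inv * len(loss_fns) / jnp.sum(inv)


params = {"w": jnp.array([1.0, -2.0]), "b": jnp.array(0.5)}
loss_fns = [
    lambda p: jnp.sum(p["w"] ** 2),         # term with a large gradient
    lambda p: 0.01 * jnp.sum(p["w"] ** 2),  # same shape, 100x smaller gradient
    lambda p: p["b"] ** 2,                  # term acting on a different parameter
]
weights = inverse_grad_norm_weights(loss_fns, params)
```

The term with the smallest gradient receives the largest weight, so no single loss dominates the parameter update. GradNorm refines this idea by learning the weights and accounting for each task's relative training rate.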

Adaptive Sampling

Use RAD sampling to focus on high-residual regions:

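import jax
from opifex.training.adaptive_sampling import RADSampler

sampler = RADSampler()
key = jax.random.key(0)
# all_points: candidate collocation points; model must expose compute_residual
residuals = model.compute_residual(all_points)
batch = sampler.sample(all_points, residuals, batch_size=256, key=key)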
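The idea behind residual-based adaptive sampling can be sketched in plain JAX. The version below draws points with probability proportional to |r|^k / mean(|r|^k) + c, a common form of the residual-based adaptive distribution. Whether `RADSampler` uses exactly this density, and the names `rad_sample`, `k`, and `c`, are assumptions.

```python
import jax
import jax.numpy as jnp


def rad_sample(points, residuals, batch_size, key, k=1.0, c=1.0):
    """Sample rows of `points` with probability increasing in |residual|."""
    err = jnp.abs(residuals) ** k
    density = err / jnp.mean(err) + c  # c > 0 keeps low-residual regions covered
    probs = density / jnp.sum(density)
    idx = jax.random.choice(
        key, points.shape[0], shape=(batch_size,), replace=False, p=probs
    )
    return points[idx]


# Synthetic residuals sharply peaked around x = 0.8
all_points = jnp.linspace(0.0, 1.0, 1000)[:, None]
residuals = jnp.exp(-(((all_points[:, 0] - 0.8) / 0.05) ** 2))
batch = rad_sample(all_points, residuals, batch_size=200, key=jax.random.key(0))
frac_near_peak = jnp.mean(jnp.abs(batch[:, 0] - 0.8) < 0.1)
```

Larger `k` concentrates samples more aggressively on high-residual regions, while larger `c` moves the distribution back toward uniform sampling.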

Second-Order Optimization

Use hybrid optimizers for faster convergence:

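# Assumption: the config class and enum are exported by the same module
from opifex.optimization.second_order import HybridOptimizer, HybridOptimizerConfig, SwitchCriterion

optimizer = HybridOptimizer(HybridOptimizerConfig(
    first_order_steps=1000,
    switch_criterion=SwitchCriterion.LOSS_VARIANCE,
))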

Multilevel Training

Use multilevel training for hierarchical convergence:

from opifex.training.multilevel import (
    CascadeTrainer, MultilevelAdam, create_network_hierarchy, prolongate,
)

hierarchy = create_network_hierarchy(
    input_dim=2, output_dim=1,
    base_hidden_dims=[64, 64], num_levels=3,
    coarsening_factor=0.5, rngs=nnx.Rngs(0),
)
trainer = CascadeTrainer(
    hierarchy=hierarchy,
    optimizer=MultilevelAdam(learning_rate=1e-3),
    prolongate_fn=prolongate,
)

NTK Diagnostics

Use NTK analysis to diagnose training issues:

from opifex.core.physics.ntk import NTKSpectralAnalyzer

analyzer = NTKSpectralAnalyzer(model)
diagnostics = analyzer.analyze(x_train, learning_rate=1e-3)
print(f"Condition number: {diagnostics.condition_number}")
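For intuition about what such a diagnostic measures, the empirical NTK of a small network can be computed directly as the Gram matrix of the parameter Jacobian. This is a minimal sketch under that standard definition and does not reproduce `NTKSpectralAnalyzer`; the network `mlp` and the helper `empirical_ntk` are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def mlp(params, x):
    """Tiny tanh MLP mapping (n, 1) inputs to (n,) outputs."""
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return (h @ params["w2"] + params["b2"])[:, 0]


def empirical_ntk(params, x):
    """K = J J^T, where J is the Jacobian of the outputs w.r.t. all parameters."""
    flat, unravel = ravel_pytree(params)
    jac = jax.jacobian(lambda p: mlp(unravel(p), x))(flat)  # (n, num_params)
    return jac @ jac.T


k1, k2 = jax.random.split(jax.random.key(0))
params = {
    "w1": jax.random.normal(k1, (1, 32)),
    "b1": jnp.zeros(32),
    "w2": jax.random.normal(k2, (32, 1)) / jnp.sqrt(32.0),
    "b2": jnp.zeros(1),
}
x_train = jnp.linspace(-1.0, 1.0, 8)[:, None]

ntk = empirical_ntk(params, x_train)
eigvals = jnp.linalg.eigvalsh(ntk)  # ascending order
condition_number = eigvals[-1] / jnp.maximum(eigvals[0], 1e-12)
```

A large condition number means some error components decay far more slowly than others under gradient descent, which is one common explanation for stalled PINN training.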

Advanced Methods

Domain Decomposition

For large or complex domains, use domain decomposition methods:

Method | Description | Best For
XPINN | Explicit interface conditions | Non-overlapping domains
FBPINN | Window function blending | Smooth solutions
CPINN | Conservation enforcement | Conservation laws
APINN | Learned gating | Unknown optimal decomposition

from opifex.neural.pinns.domain_decomposition import XPINN, Subdomain, Interface

model = XPINN(
    input_dim=2, output_dim=1,
    subdomains=subdomains,
    interfaces=interfaces,
    hidden_dims=[32, 32],
    rngs=nnx.Rngs(0),
)

Common PDEs

Heat Equation

\[\frac{\partial u}{\partial t} = \alpha \nabla^2 u\]

def heat_residual(model, x, t, alpha=0.01):
    """x: spatial coords, t: time."""
    xt = jnp.column_stack([x, t])

    def u_scalar(xi):
        return model(xi.reshape(1, -1)).squeeze()

    # Compute derivatives
    def compute_derivs(xi):
        grad_u = jax.grad(u_scalar)(xi)
        hess_u = jax.hessian(u_scalar)(xi)
        u_t = grad_u[-1]  # Time derivative
        laplacian = jnp.sum(jnp.diag(hess_u)[:-1])  # Spatial Laplacian
        return u_t - alpha * laplacian

    return jax.vmap(compute_derivs)(xt)
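A residual function like this can be sanity-checked before training by passing in an exact solution in place of the network. The sketch below restates the residual so it runs on its own and uses u(x, t) = exp(-α π² t) sin(π x), which satisfies the 1D heat equation; `exact_model` is a stand-in introduced for this check.

```python
import jax
import jax.numpy as jnp

ALPHA = 0.01


def exact_model(xt):
    """Stand-in 'model': exact solution evaluated on (n, 2) inputs, returns (n, 1)."""
    x, t = xt[:, 0], xt[:, 1]
    return (jnp.exp(-ALPHA * jnp.pi ** 2 * t) * jnp.sin(jnp.pi * x))[:, None]


def heat_residual(model, x, t, alpha=ALPHA):
    """Residual u_t - alpha * u_xx with time as the last input column."""
    xt = jnp.column_stack([x, t])

    def u_scalar(xi):
        return model(xi.reshape(1, -1)).squeeze()

    def compute_derivs(xi):
        grad_u = jax.grad(u_scalar)(xi)
        hess_u = jax.hessian(u_scalar)(xi)
        return grad_u[-1] - alpha * jnp.sum(jnp.diag(hess_u)[:-1])

    return jax.vmap(compute_derivs)(xt)


x = jnp.linspace(0.1, 0.9, 9)
t = jnp.linspace(0.0, 1.0, 9)
residual = heat_residual(exact_model, x, t)                         # close to zero
wrong_residual = heat_residual(exact_model, x, t, alpha=2 * ALPHA)  # clearly nonzero
```

A residual that does not vanish on a known solution usually points to a sign error or a mixed-up column order in the input.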

Burgers' Equation

\[\frac{\partial u}{\partial t} + u \frac{\partial u}{\partial x} = \nu \frac{\partial^2 u}{\partial x^2}\]

def burgers_residual(model, x, t, nu=0.01):
    xt = jnp.column_stack([x, t])

    def u_scalar(xi):
        return model(xi.reshape(1, -1)).squeeze()

    def compute_derivs(xi):
        u = u_scalar(xi)
        grad_u = jax.grad(u_scalar)(xi)
        hess_u = jax.hessian(u_scalar)(xi)
        u_x, u_t = grad_u[0], grad_u[1]
        u_xx = hess_u[0, 0]
        return u_t + u * u_x - nu * u_xx

    return jax.vmap(compute_derivs)(xt)

Navier-Stokes Equations

\[\frac{\partial \vec{u}}{\partial t} + (\vec{u} \cdot \nabla)\vec{u} = -\nabla p + \nu \nabla^2 \vec{u}\]

\[\nabla \cdot \vec{u} = 0\]

def navier_stokes_residual(model, xy, t, nu=0.01):
    """Model outputs [u, v, p]."""
    xyt = jnp.column_stack([xy, t])

    def field(xi):
        return model(xi.reshape(1, -1)).squeeze()  # [u, v, p]

    def compute_residuals(xi):
        # Get field values and derivatives
        uvp = field(xi)
        u, v, p = uvp[0], uvp[1], uvp[2]

        jac = jax.jacfwd(field)(xi)  # Shape: (3, 3) for [u,v,p] x [x,y,t]
        u_x, u_y, u_t = jac[0, 0], jac[0, 1], jac[0, 2]
        v_x, v_y, v_t = jac[1, 0], jac[1, 1], jac[1, 2]
        p_x, p_y = jac[2, 0], jac[2, 1]

        hess = jax.hessian(lambda xi: field(xi)[0])(xi)
        u_xx, u_yy = hess[0, 0], hess[1, 1]

        hess_v = jax.hessian(lambda xi: field(xi)[1])(xi)
        v_xx, v_yy = hess_v[0, 0], hess_v[1, 1]

        # Momentum equations
        res_u = u_t + u*u_x + v*u_y + p_x - nu*(u_xx + u_yy)
        res_v = v_t + u*v_x + v*v_y + p_y - nu*(v_xx + v_yy)

        # Continuity
        res_cont = u_x + v_y

        return jnp.array([res_u, res_v, res_cont])

    return jax.vmap(compute_residuals)(xyt)

Best Practices

Network Architecture

  • Activation: tanh for smooth solutions, gelu for faster training
  • Depth: 3-5 layers for most problems
  • Width: 32-128 neurons per layer
  • Input normalization: Scale inputs to [-1, 1] or [0, 1]
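Input normalization interacts with the derivative computation. If the scaling is applied inside the function being differentiated, automatic differentiation applies the chain rule and the derivatives come out in physical units. The following is a minimal sketch with hypothetical bounds and a stand-in network `raw_net`.

```python
import jax
import jax.numpy as jnp

# Hypothetical physical bounds: x in [0, 10], t in [0, 2]
LOWER = jnp.array([0.0, 0.0])
UPPER = jnp.array([10.0, 2.0])


def normalize_inputs(x):
    """Affinely map each coordinate from [LOWER, UPPER] to [-1, 1]."""
    return 2.0 * (x - LOWER) / (UPPER - LOWER) - 1.0


def raw_net(z):
    """Stand-in network acting on normalized coordinates."""
    return jnp.sum(jnp.sin(z))


def u_scalar(x_phys):
    """Solution as a function of physical coordinates."""
    return raw_net(normalize_inputs(x_phys))


x_phys = jnp.array([5.0, 1.0])          # center of the physical box
z = normalize_inputs(x_phys)            # maps to the origin
grad_phys = jax.grad(u_scalar)(x_phys)  # includes the factor 2 / (UPPER - LOWER)
```

Differentiating `raw_net` with respect to `z` instead would give derivatives in normalized units, and the PDE coefficients would then need to be rescaled by hand.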

Collocation Point Selection

  • Interior: 1000-10000 points (problem-dependent)
  • Boundary: 100-1000 points per boundary segment
  • Distribution: Use adaptive sampling for efficiency
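As a starting point before adaptive sampling, interior and boundary points can be drawn uniformly. The sketch below assumes the unit square as the domain and samples all four edges; `sample_unit_square` is an illustrative helper, not an opifex function.

```python
import jax
import jax.numpy as jnp


def sample_unit_square(key, n_interior, n_per_edge):
    """Uniform interior points plus uniform points on each of the four edges."""
    k_int, k_edge = jax.random.split(key)
    interior = jax.random.uniform(k_int, (n_interior, 2))
    s = jax.random.uniform(k_edge, (4, n_per_edge))  # one parameter row per edge
    zeros, ones = jnp.zeros(n_per_edge), jnp.ones(n_per_edge)
    boundary = jnp.concatenate([
        jnp.column_stack([zeros, s[0]]),  # left edge,   x = 0
        jnp.column_stack([ones, s[1]]),   # right edge,  x = 1
        jnp.column_stack([s[2], zeros]),  # bottom edge, y = 0
        jnp.column_stack([s[3], ones]),   # top edge,    y = 1
    ])
    return interior, boundary


interior, boundary = sample_unit_square(jax.random.key(0), n_interior=2000, n_per_edge=100)
```

Resampling with a fresh key every few hundred steps is a cheap way to reduce overfitting to one fixed set of collocation points.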

Loss Weighting

  • Start with equal weights
  • Use GradNorm for automatic balancing
  • Increase BC weights if constraints are violated
  • Monitor individual loss components

Training Strategy

  1. Warmup: Use Adam with learning rate warmup
  2. Main training: Continue with Adam or switch to hybrid
  3. Fine-tuning: Use L-BFGS for final convergence

See Also