
Neural Network API Reference

The opifex.neural package provides the building blocks for scientific machine learning models, built on top of Flax NNX.

Base Architectures

Standard MLP

opifex.neural.base.StandardMLP

StandardMLP(layer_sizes: list[int], activation: str = 'gelu', dropout_rate: float = 0.0, use_bias: bool = True, apply_final_dropout: bool = False, *, dtype: Any | None = None, param_dtype: Any = float32, rngs: Rngs, kernel_init: Callable = xavier_uniform(), bias_init: Callable = zeros)

Bases: Module

Multi-layer perceptron (MLP) implemented with Flax NNX.

Follows Flax NNX best practices, including:

  • Proper RNG handling through the keyword-only rngs parameter
  • Configurable activation functions (GELU by default)
  • Dropout with deterministic control
  • Custom initialization strategies following NNX patterns
  • Automatic differentiation with JAX
  • Performance-optimized state management

Attributes:

  • layer_sizes: List of layer sizes, including the input and output dimensions
  • activation: Name of the activation function to use
  • dropout_rate: Dropout probability (0.0 means no dropout)
  • use_bias: Whether to include bias terms in the linear layers
  • apply_final_dropout: Whether to apply dropout after the final layer
  • layers: Sequence of linear transformation layers
  • activation_fn: The resolved activation function (from get_activation)
  • dropout (Dropout | None): Dropout layer (None if dropout_rate is 0)

Parameters:

  • layer_sizes (list[int], required): List of layer sizes, e.g., [input_dim, hidden1, hidden2, output_dim]
  • activation (str, default 'gelu'): Activation function name ('gelu', 'tanh', 'relu', 'sigmoid', 'silu')
  • dropout_rate (float, default 0.0): Dropout probability for regularization (0.0 = no dropout)
  • use_bias (bool, default True): Whether to use bias in the linear projections
  • apply_final_dropout (bool, default False): Whether to apply dropout after the final layer (useful for some transformer-style architectures)
  • dtype (Any | None, default None): Computation dtype for the NNX linear layers; None preserves the Flax default promotion behavior
  • param_dtype (Any, default float32): Parameter storage dtype for the NNX linear layers
  • rngs (Rngs, required): Flax NNX random number generator state (keyword-only)
  • kernel_init (Callable, default xavier_uniform()): Kernel initialization function
  • bias_init (Callable, default zeros): Bias initialization function

Source code in opifex/neural/base.py
def __init__(
    self,
    layer_sizes: list[int],
    activation: str = "gelu",
    dropout_rate: float = 0.0,
    use_bias: bool = True,
    apply_final_dropout: bool = False,
    *,
    dtype: Any | None = None,
    param_dtype: Any = jnp.float32,
    rngs: nnx.Rngs,
    kernel_init: Callable = nnx.initializers.xavier_uniform(),
    bias_init: Callable = nnx.initializers.zeros,
):
    """Initialize the StandardMLP following modern NNX patterns.

    Args:
        layer_sizes: List of layer sizes, e.g.,
            [input_dim, hidden1, hidden2, output_dim]
        activation: Activation function name
            ('gelu', 'tanh', 'relu', 'sigmoid', 'silu')
            Default is 'gelu' for modern neural networks
        dropout_rate: Dropout probability for regularization
            (0.0 = no dropout)
        use_bias: Whether to use bias in linear projections
        apply_final_dropout: Whether to apply dropout after final layer
            (useful for some transformer-style architectures)
        dtype: Computation dtype for NNX linear layers. ``None`` preserves
            the Flax default promotion behavior.
        param_dtype: Parameter storage dtype for NNX linear layers.
        rngs: FLAX NNX random number generator state (keyword-only)
        kernel_init: Kernel initialization function (callable)
        bias_init: Bias initialization function (callable)
    """
    super().__init__()

    # Store configuration
    self.layer_sizes = layer_sizes
    self.activation = activation
    self.dropout_rate = dropout_rate
    self.use_bias = use_bias
    self.apply_final_dropout = apply_final_dropout
    self.dtype = dtype
    self.param_dtype = param_dtype

    # Validate layer sizes
    if len(layer_sizes) < 2:
        raise ValueError("layer_sizes must have at least 2 elements (input and output)")

    # Create layers following NNX patterns (use nnx.List for Flax 0.12.0+)
    layers = []
    for i in range(len(layer_sizes) - 1):
        layer = nnx.Linear(
            in_features=layer_sizes[i],
            out_features=layer_sizes[i + 1],
            use_bias=use_bias,
            kernel_init=kernel_init,
            bias_init=bias_init,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        layers.append(layer)
    self.layers = nnx.List(layers)

    # Set activation function using the activation library
    self.activation_fn = get_activation(activation)

    # Initialize dropout if needed - pass rngs directly
    if dropout_rate > 0.0:
        self.dropout: nnx.Dropout | None = nnx.Dropout(rate=dropout_rate, rngs=rngs)
    else:
        self.dropout = None
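
The layer construction above can be illustrated without Flax. The following is a minimal plain-JAX sketch, not the opifex implementation: it builds one (kernel, bias) pair per consecutive pair in layer_sizes, using Xavier-uniform kernels and zero biases to mirror the default kernel_init and bias_init. Applying the activation between layers and none after the final layer is an assumption, since the forward pass is not shown here; init_mlp and mlp_forward are hypothetical names.

```python
import jax
import jax.numpy as jnp

def init_mlp(layer_sizes, key):
    """Create one (kernel, bias) pair per consecutive pair of entries in layer_sizes."""
    if len(layer_sizes) < 2:
        raise ValueError("layer_sizes must have at least 2 elements (input and output)")
    kernel_init = jax.nn.initializers.xavier_uniform()
    keys = jax.random.split(key, len(layer_sizes) - 1)
    params = []
    for k, n_in, n_out in zip(keys, layer_sizes[:-1], layer_sizes[1:]):
        kernel = kernel_init(k, (n_in, n_out), jnp.float32)
        bias = jnp.zeros((n_out,), jnp.float32)  # mirrors bias_init=zeros
        params.append((kernel, bias))
    return params

def mlp_forward(params, x, activation=jax.nn.gelu):
    """Hidden layers apply the activation; the final layer is kept linear (assumed)."""
    for kernel, bias in params[:-1]:
        x = activation(x @ kernel + bias)
    kernel, bias = params[-1]
    return x @ kernel + bias

params = init_mlp([3, 64, 64, 1], jax.random.PRNGKey(0))
print([kernel.shape for kernel, _ in params])       # [(3, 64), (64, 64), (64, 1)]
print(mlp_forward(params, jnp.ones((8, 3))).shape)  # (8, 1)
```

A list of length L therefore yields L - 1 linear layers, which is why fewer than two entries is rejected.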

Quantum MLP

opifex.neural.base.QuantumMLP

QuantumMLP(layer_sizes: list[int], activation: str = 'tanh', enforce_symmetry: bool = True, dropout_rate: float = 0.0, use_bias: bool = True, apply_final_dropout: bool = False, symmetry_type: str = 'permutation', *, rngs: Rngs, kernel_init: Callable = xavier_uniform(), bias_init: Callable = zeros)

Bases: Module

Quantum-aware multi-layer perceptron for molecular and quantum systems.

Follows Flax NNX best practices and adds quantum-specific features:

  • Proper RNG handling through the keyword-only rngs parameter
  • Symmetry enforcement for molecular systems
  • Specialized initialization for quantum properties
  • Physics-informed constraints with numerical stability
  • Dropout with deterministic control
  • Quantum-specific energy and force computation methods

Attributes:

  • layer_sizes: List of layer sizes, including the input and output dimensions
  • activation: Activation function name
  • enforce_symmetry: Whether to enforce permutation symmetry
  • dropout_rate: Dropout probability for regularization
  • use_bias: Whether to use bias in the linear layers
  • apply_final_dropout: Whether to apply dropout after the final layer
  • layers: Sequence of linear layers
  • activation_fn: Activation function
  • dropout (Dropout | None): Dropout layer (None unless dropout_rate > 0)

Parameters:

  • layer_sizes (list[int], required): List of layer sizes for the network architecture
  • activation (str, default 'tanh'): Activation function name ('gelu', 'tanh', 'relu', 'sigmoid', 'silu'); 'tanh' is the default for quantum neural networks
  • enforce_symmetry (bool, default True): Whether to enforce molecular symmetries
  • dropout_rate (float, default 0.0): Dropout probability for regularization (0.0 = no dropout)
  • use_bias (bool, default True): Whether to use bias in the linear projections
  • apply_final_dropout (bool, default False): Whether to apply dropout after the final layer (useful for quantum transformer-style architectures)
  • symmetry_type (str, default 'permutation'): Type of symmetry to enforce ('permutation', 'rotation', 'both')
  • rngs (Rngs, required): Flax NNX random number generator state (keyword-only)
  • kernel_init (Callable, default xavier_uniform()): Kernel initialization function; quantum-aware scaling is applied to it
  • bias_init (Callable, default zeros): Bias initialization function; quantum-aware scaling is applied to it

Source code in opifex/neural/base.py
def __init__(
    self,
    layer_sizes: list[int],
    activation: str = "tanh",
    enforce_symmetry: bool = True,
    dropout_rate: float = 0.0,
    use_bias: bool = True,
    apply_final_dropout: bool = False,
    symmetry_type: str = "permutation",
    *,
    rngs: nnx.Rngs,
    kernel_init: Callable = nnx.initializers.xavier_uniform(),
    bias_init: Callable = nnx.initializers.zeros,
):
    """Initialize Quantum MLP following modern NNX patterns.

    Args:
        layer_sizes: List of layer sizes for the network architecture
        activation: Activation function name
            ('gelu', 'tanh', 'relu', 'sigmoid', 'silu')
            Default is 'tanh' for quantum neural networks
        enforce_symmetry: Whether to enforce molecular symmetries
        dropout_rate: Dropout probability for regularization
            (0.0 = no dropout)
        use_bias: Whether to use bias in linear projections
        apply_final_dropout: Whether to apply dropout after final layer
            (useful for quantum transformer-style architectures)
        symmetry_type: Type of symmetry to enforce
            ('permutation', 'rotation', 'both')
        rngs: FLAX NNX random number generator state (keyword-only)
        kernel_init: Kernel initialization function
            (callable, quantum-aware)
        bias_init: Bias initialization function
            (callable, quantum-aware)
    """
    super().__init__()

    # Store configuration
    self.layer_sizes = layer_sizes
    self.activation = activation
    self.enforce_symmetry = enforce_symmetry
    self.symmetry_type = symmetry_type
    self.dropout_rate = dropout_rate
    self.use_bias = use_bias
    self.apply_final_dropout = apply_final_dropout

    # Validate layer sizes
    if len(layer_sizes) < 2:
        raise ValueError("layer_sizes must have at least 2 elements (input and output)")

    # Apply quantum-aware initialization scaling if needed
    quantum_kernel_init = self._apply_quantum_scaling(kernel_init)
    quantum_bias_init = self._apply_quantum_scaling(bias_init)

    # Create layers with quantum-aware initialization (use nnx.List)
    layers = []
    for i in range(len(layer_sizes) - 1):
        layer = nnx.Linear(
            in_features=layer_sizes[i],
            out_features=layer_sizes[i + 1],
            use_bias=use_bias,
            kernel_init=quantum_kernel_init,
            bias_init=quantum_bias_init,
            rngs=rngs,
        )
        layers.append(layer)
    self.layers = nnx.List(layers)

    # Set activation function optimized for quantum calculations
    self.activation_fn = get_activation(activation)

    # Initialize dropout if needed - pass rngs directly
    if dropout_rate > 0.0:
        self.dropout: nnx.Dropout | None = nnx.Dropout(rate=dropout_rate, rngs=rngs)
    else:
        self.dropout = None

    # Setup symmetry constraints if needed
    self._setup_symmetry_constraints()
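
When enforce_symmetry is enabled with symmetry_type='permutation', reordering identical atoms should leave the predicted energy unchanged. One way to test this property is to compare an energy function before and after reordering the atoms. The sketch below is a hypothetical test helper in plain JAX; toy_pair_energy and toy_flat_energy are illustrative stand-ins, not opifex functions. A pair-distance energy passes the check, while a linear readout of the flattened coordinates does not.

```python
import jax.numpy as jnp

def toy_pair_energy(positions):
    """Hypothetical permutation-invariant energy: sum of 1/r over all atom pairs."""
    n = positions.shape[0]
    diff = positions[:, None, :] - positions[None, :, :]
    # Adding the identity keeps the diagonal distance finite; those terms are masked out below
    r = jnp.sqrt(jnp.sum(diff**2, axis=-1) + jnp.eye(n))
    return 0.5 * jnp.sum((1.0 - jnp.eye(n)) / r)

def toy_flat_energy(positions):
    """Hypothetical linear readout of flattened coordinates (not permutation invariant)."""
    weights = jnp.arange(positions.size, dtype=positions.dtype)
    return jnp.dot(positions.reshape(-1), weights)

def is_permutation_invariant(energy_fn, positions, atol=1e-5):
    """Compare the energy before and after reversing the atom order."""
    return bool(jnp.allclose(energy_fn(positions), energy_fn(positions[::-1]), atol=atol))

positions = jnp.arange(12.0).reshape(4, 3)
print(is_permutation_invariant(toy_pair_energy, positions))  # True
print(is_permutation_invariant(toy_flat_energy, positions))  # False
```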

compute_energy

compute_energy(positions: Array, *, deterministic: bool = True) -> Array

Compute energy for given atomic positions.

Parameters:

  • positions (Array, required): Atomic positions of shape (batch, n_atoms, 3), (n_atoms, 3), or flattened (n_atoms*3,)
  • deterministic (bool, default True): Whether to use deterministic mode (True for inference)

Returns:

  • Array: Energy of shape (batch_size, 1) for 2D and 3D inputs; a flattened 1D input returns a scalar energy

Source code in opifex/neural/base.py
def compute_energy(
    self,
    positions: jax.Array,
    *,
    deterministic: bool = True,
) -> jax.Array:
    """Compute energy for given atomic positions.

    Args:
        positions: Atomic positions array of shape (batch, n_atoms, 3)
            or (n_atoms, 3) or flattened (n_atoms*3,)
        deterministic: Whether to use deterministic mode
            (True for inference)

    Returns:
        Energy array with shape (batch_size, 1) for 2D/3D inputs,
        or a scalar energy for flattened 1D input
    """
    # Handle flattened 1D input (common in tests)
    if positions.ndim == 1:
        # Assume 3D coordinates, create batch with single item
        if positions.shape[0] % 3 != 0:
            raise ValueError(
                f"Flattened positions length {positions.shape[0]} must be divisible by 3"
            )
        flat_positions = positions[None, :]  # Add batch dimension
    # Handle both batched and single inputs for 2D/3D
    elif positions.ndim == 2:
        # Single molecule: shape (n_atoms, 3)
        flat_positions = positions.flatten()[None, :]  # Add batch dim after flatten
    elif positions.ndim == 3:
        # Batched molecules: shape (batch, n_atoms, 3)
        # Flatten each batch item separately
        batch_size = positions.shape[0]
        flat_positions = positions.reshape(batch_size, -1)  # (batch, n_atoms*3)
    else:
        raise ValueError(f"Expected positions with 1, 2 or 3 dimensions, got {positions.ndim}")

    # Forward pass to get energy
    energy = self(flat_positions, deterministic=deterministic)

    # For 1D input, return scalar for API consistency with test expectations
    if positions.ndim == 1:
        return energy.squeeze()  # Return scalar energy

    # Ensure energy has shape (batch_size, 1) for API consistency
    if energy.ndim == 2 and energy.shape[1] == 1:
        return energy  # Already correct shape
    # Reshape to (batch_size, 1)
    return energy.reshape(-1, 1)
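
The input handling in compute_energy reduces all three accepted layouts to a (batch, n_atoms*3) matrix before the forward pass. The sketch below mirrors that logic in plain JAX with a hypothetical toy_model standing in for the network, so the resulting output shapes can be checked directly.

```python
import jax.numpy as jnp

def to_flat_batch(positions):
    """Reduce 1D, 2D, or 3D position arrays to shape (batch, n_atoms * 3)."""
    if positions.ndim == 1:
        if positions.shape[0] % 3 != 0:
            raise ValueError("Flattened positions length must be divisible by 3")
        return positions[None, :]                         # (1, n_atoms*3)
    if positions.ndim == 2:
        return positions.reshape(1, -1)                   # (n_atoms, 3) -> (1, n_atoms*3)
    if positions.ndim == 3:
        return positions.reshape(positions.shape[0], -1)  # (batch, n_atoms*3)
    raise ValueError(f"Expected 1, 2 or 3 dimensions, got {positions.ndim}")

def toy_model(flat_positions):
    """Hypothetical stand-in for the network: one energy value per row."""
    return jnp.sum(flat_positions**2, axis=-1, keepdims=True)

def compute_energy_sketch(positions):
    energy = toy_model(to_flat_batch(positions))
    if positions.ndim == 1:
        return energy.squeeze()   # scalar for flattened input
    return energy.reshape(-1, 1)  # (batch_size, 1) otherwise

print(compute_energy_sketch(jnp.ones(6)).shape)          # ()
print(compute_energy_sketch(jnp.ones((2, 3))).shape)     # (1, 1)
print(compute_energy_sketch(jnp.ones((5, 2, 3))).shape)  # (5, 1)
```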

compute_forces

compute_forces(positions: Array, *, deterministic: bool = True) -> Array

Compute forces as negative gradient of energy.

Parameters:

  • positions (Array, required): Atomic positions of shape (batch, n_atoms, 3), (n_atoms, 3), or flattened (n_atoms*3,)
  • deterministic (bool, default True): Whether to use deterministic mode (True for inference)

Returns:

  • Array: Forces with the same shape as the input: (batch, n_atoms, 3), (n_atoms, 3), or (n_atoms*3,)

Source code in opifex/neural/base.py
def compute_forces(
    self,
    positions: jax.Array,
    *,
    deterministic: bool = True,
) -> jax.Array:
    """Compute forces as negative gradient of energy.

    Args:
        positions: Atomic positions array of shape (batch, n_atoms, 3)
            or (n_atoms, 3) or flattened (n_atoms*3,)
        deterministic: Whether to use deterministic mode
            (True for inference)

    Returns:
        Forces array of shape (batch, n_atoms, 3) or (n_atoms, 3) or (n_atoms*3,)
        matching the input shape
    """
    # Handle flattened 1D input (common in tests)
    if positions.ndim == 1:
        # Assume 3D coordinates, reshape to (n_atoms, 3)
        if positions.shape[0] % 3 != 0:
            raise ValueError(
                f"Flattened positions length {positions.shape[0]} must be divisible by 3"
            )
        n_atoms = positions.shape[0] // 3
        positions_reshaped = positions.reshape(n_atoms, 3)

        def energy_fn_1d(pos):
            return self._compute_energy_scalar(pos, deterministic=deterministic)

        # Compute forces for reshaped positions
        forces_reshaped = -jax.grad(energy_fn_1d)(positions_reshaped)

        # Return in original flattened shape
        return forces_reshaped.flatten()

    def energy_fn_2d3d(pos):
        return self._compute_energy_scalar(pos, deterministic=deterministic)

    if positions.ndim == 2:
        # Single molecule case
        return -jax.grad(energy_fn_2d3d)(positions)
    if positions.ndim == 3:
        # Batched case - use vmap to handle batch dimension
        batched_grad = jax.vmap(jax.grad(energy_fn_2d3d))
        return -batched_grad(positions)
    raise ValueError(f"Expected positions with 1, 2 or 3 dimensions, got {positions.ndim}")
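
compute_forces differentiates a scalar energy with jax.grad and uses jax.vmap for batched input. The sketch below reproduces that dispatch with a hypothetical harmonic energy, E = 0.5 * K * sum(r**2), standing in for _compute_energy_scalar; its exact forces are -K * r, which makes the sign convention easy to verify.

```python
import jax
import jax.numpy as jnp

K = 2.0  # hypothetical spring constant

def toy_energy_scalar(pos):
    """Hypothetical stand-in for _compute_energy_scalar: E = 0.5 * K * sum(r**2)."""
    return 0.5 * K * jnp.sum(pos**2)

def forces_sketch(positions):
    """Forces as the negative energy gradient, returned in the input's shape."""
    if positions.ndim == 1:
        pos = positions.reshape(-1, 3)  # (n_atoms, 3)
        return -jax.grad(toy_energy_scalar)(pos).reshape(-1)
    if positions.ndim == 2:
        return -jax.grad(toy_energy_scalar)(positions)
    if positions.ndim == 3:
        # vmap maps the per-molecule gradient over the leading batch axis
        return -jax.vmap(jax.grad(toy_energy_scalar))(positions)
    raise ValueError(f"Expected 1, 2 or 3 dimensions, got {positions.ndim}")

single = jnp.ones((4, 3))
batch = jnp.stack([single, 2.0 * single])
print(forces_sketch(single)[0])    # equals -K * r for the first atom
print(forces_sketch(batch).shape)  # (2, 4, 3)
```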

compute_energy_and_forces

compute_energy_and_forces(positions: Array, *, deterministic: bool = True) -> tuple[Array, Array]

Efficiently compute both energy and forces.

Parameters:

  • positions (Array, required): Atomic positions of shape (n_atoms, 3) or flattened (n_atoms*3,)
  • deterministic (bool, default True): Whether to use deterministic mode (True for inference)

Returns:

  • tuple[Array, Array]: (energy, forces), where energy is a scalar and forces has the same shape as the input, (n_atoms, 3) or (n_atoms*3,)

Source code in opifex/neural/base.py
def compute_energy_and_forces(
    self,
    positions: jax.Array,
    *,
    deterministic: bool = True,
) -> tuple[jax.Array, jax.Array]:
    """Efficiently compute both energy and forces.

    Args:
        positions: Atomic positions array of shape (n_atoms, 3) or
            flattened (n_atoms*3,)
        deterministic: Whether to use deterministic mode
            (True for inference)

    Returns:
        Tuple of (energy, forces) where energy is scalar and
        forces has shape (n_atoms, 3) or (n_atoms*3,) matching input
    """
    # Handle flattened 1D input (common in tests)
    if positions.ndim == 1:
        # Assume 3D coordinates, reshape to (n_atoms, 3)
        if positions.shape[0] % 3 != 0:
            raise ValueError(
                f"Flattened positions length {positions.shape[0]} must be divisible by 3"
            )
        n_atoms = positions.shape[0] // 3
        positions_reshaped = positions.reshape(n_atoms, 3)

        def energy_fn_1d(pos):
            return self._compute_energy_scalar(pos, deterministic=deterministic)

        # Use value_and_grad for efficiency
        energy, grad_energy = jax.value_and_grad(energy_fn_1d)(positions_reshaped)
        forces = -grad_energy  # Forces are negative gradient

        # Return forces in original flattened shape
        return energy, forces.flatten()

    def energy_fn_2d(pos):
        return self._compute_energy_scalar(pos, deterministic=deterministic)

    # Use value_and_grad for efficiency
    energy, grad_energy = jax.value_and_grad(energy_fn_2d)(positions)
    forces = -grad_energy  # Forces are negative gradient
    return energy, forces
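
compute_energy_and_forces relies on jax.value_and_grad, which returns the energy and its gradient from one call instead of running the forward pass twice. The sketch below applies the same pattern to a hypothetical two-atom harmonic bond; bond_energy, K_BOND, and R_EQ are illustrative and not part of opifex. Because the bond is stretched beyond its equilibrium length, the forces pull the atoms together and sum to zero.

```python
import jax
import jax.numpy as jnp

K_BOND = 1.0  # hypothetical force constant
R_EQ = 1.0    # hypothetical equilibrium bond length

def bond_energy(pos):
    """Hypothetical two-atom harmonic bond: E = 0.5 * K_BOND * (|r0 - r1| - R_EQ)**2."""
    d = jnp.linalg.norm(pos[0] - pos[1])
    return 0.5 * K_BOND * (d - R_EQ) ** 2

def energy_and_forces(pos):
    # A single value_and_grad call yields both the energy and dE/dpos
    energy, grad = jax.value_and_grad(bond_energy)(pos)
    return energy, -grad  # forces are the negative gradient

pos = jnp.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]])
energy, forces = energy_and_forces(pos)
print(energy)  # 0.125
print(forces)  # atom 0 pushed toward +x, atom 1 toward -x
```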

Neural Quantum

opifex.neural.quantum

Neural quantum chemistry modules for scientific machine learning.

NeuralDFT

NeuralDFT(*, grid_size: int = 1000, convergence_threshold: float = 1e-08, max_scf_iterations: int = 100, xc_functional_type: str = 'neural', mixing_strategy: str = 'neural', use_neural_scf: bool = True, chemical_accuracy_target: float = 0.043, enable_high_precision: bool = True, rngs: Rngs)

Bases: Module

Neural Density Functional Theory Framework.

Combines neural exchange-correlation (XC) functionals with neural-enhanced self-consistent field (SCF) solvers for DFT calculations. It is designed to reach chemical accuracy on molecular systems, can use float64 for precision-critical calculations, and handles quantum constraints.

Fully compliant with modern Flax NNX patterns.

Parameters:

  • grid_size (int, default 1000): Size of the electron density grid
  • convergence_threshold (float, default 1e-08): SCF convergence threshold in Hartree
  • max_scf_iterations (int, default 100): Maximum number of SCF iterations
  • xc_functional_type (str, default 'neural'): Type of XC functional ('neural', 'lda', 'pbe')
  • mixing_strategy (str, default 'neural'): Density mixing strategy ('neural', 'diis', 'simple')
  • use_neural_scf (bool, default True): Whether to use neural SCF solver enhancements
  • chemical_accuracy_target (float, default 0.043): Target accuracy in Hartree
  • enable_high_precision (bool, default True): Whether to use float64 for critical calculations
  • rngs (Rngs, required): Random number generators (keyword-only)

compute_energy

compute_energy(molecular_system: Any, *, density: Array | None = None, deterministic: bool = True) -> DFTResult

Compute total energy using neural DFT with enhanced precision.

Parameters:

  • molecular_system (Any, required): Molecular system to compute the energy for
  • density (Array | None, default None): Optional initial density guess
  • deterministic (bool, default True): Whether to use deterministic mode

Returns:

  • DFTResult: DFT calculation result with precision diagnostics

predict_chemical_accuracy

predict_chemical_accuracy(molecular_system: Any, reference_energy: float | None = None) -> dict[str, Any]

Predict chemical accuracy with enhanced diagnostics.

NeuralSCFSolver

NeuralSCFSolver(convergence_threshold: float = 1e-08, max_iterations: int = 100, mixing_strategy: str = 'neural', grid_size: int = 1000, chemical_accuracy_target: float = 1e-06, *, rngs: Rngs)

Bases: Module

Neural-enhanced self-consistent field (SCF) solver with full convergence analysis.

Implements neural acceleration of SCF convergence through:

  1. Density mixing driven by neural networks
  2. Convergence prediction with chemical-accuracy assessment
  3. Stability monitoring and adaptive recovery mechanisms
  4. High-precision numerical methods for quantum accuracy

Parameters:

  • convergence_threshold (float, default 1e-08): Energy convergence threshold
  • max_iterations (int, default 100): Maximum number of SCF iterations
  • mixing_strategy (str, default 'neural'): Density mixing strategy ('neural' or 'linear')
  • grid_size (int, default 1000): Size of the density grid for molecular calculations
  • chemical_accuracy_target (float, default 1e-06): Target accuracy for chemical predictions
  • rngs (Rngs, required): Random number generators for neural components
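
For comparison with the 'linear' mixing_strategy, plain linear density mixing can be written as a fixed-point iteration, rho <- (1 - alpha) * rho + alpha * F(rho), stopped when the update falls below a threshold or after max_iterations. The sketch below is a generic illustration, not the NeuralSCFSolver algorithm; update_fn is a hypothetical toy map whose fixed point is 2.0. It uses a threshold of 1e-6 because a 1e-08 threshold is below float32 resolution for densities of order one and generally calls for float64 (for example, via jax.config.update("jax_enable_x64", True)).

```python
import jax.numpy as jnp

def scf_linear_mixing(update_fn, initial_density, alpha=0.3,
                      convergence_threshold=1e-6, max_iterations=100):
    """Iterate rho <- (1 - alpha) * rho + alpha * update_fn(rho) until the update is small."""
    rho = initial_density
    for iteration in range(1, max_iterations + 1):
        rho_new = (1.0 - alpha) * rho + alpha * update_fn(rho)
        # Max-norm of the density change serves as the convergence measure here
        if float(jnp.max(jnp.abs(rho_new - rho))) < convergence_threshold:
            return rho_new, iteration, True
        rho = rho_new
    return rho, max_iterations, False

def update_fn(rho):
    """Hypothetical contraction standing in for 'build Hamiltonian, re-solve, new density'."""
    return 0.5 * rho + 1.0  # fixed point at rho = 2.0

rho, n_iter, converged = scf_linear_mixing(update_fn, jnp.zeros(5), alpha=0.5)
print(converged, n_iter)
```

Smaller alpha values damp oscillations at the cost of more iterations; adaptive or learned mixing aims to choose this trade-off automatically.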

solve_scf

solve_scf(molecular_system: MolecularSystem, initial_density: Array, hamiltonian_fn: Callable | None = None, *, deterministic: bool = False) -> SCFResult

Solve SCF equations with neural acceleration and full analysis.

Parameters:

  • molecular_system (MolecularSystem, required): Molecular system to solve
  • initial_density (Array, required): Initial electron density guess
  • hamiltonian_fn (Callable | None, default None): Optional custom Hamiltonian function
  • deterministic (bool, default False): Whether to use deterministic computation

Returns:

  • SCFResult: Full SCF result with convergence analysis

predict_convergence_iterations

predict_convergence_iterations(molecular_system: MolecularSystem, initial_density: Array, *, deterministic: bool = False) -> int

Predict number of iterations required for convergence.

Parameters:

  • molecular_system (MolecularSystem, required): Molecular system to analyze
  • initial_density (Array, required): Initial density guess
  • deterministic (bool, default False): Whether to use deterministic computation

Returns:

  • int: Predicted number of iterations for convergence

NeuralXCFunctional

NeuralXCFunctional(hidden_sizes: Sequence[int] = (128, 128, 64), activation: Callable = gelu, use_attention: bool = True, num_attention_heads: int = 8, use_advanced_features: bool = True, dropout_rate: float = 0.0, *, rngs: Rngs)

Bases: Module

Neural exchange-correlation functional for DFT calculations.

Implements a modern neural XC functional with attention mechanisms for capturing non-local correlations, enhanced physics constraints, and chemical accuracy optimization.

Parameters:

  • hidden_sizes (Sequence[int], default (128, 128, 64)): Hidden layer sizes
  • activation (Callable, default gelu): Activation function to use
  • use_attention (bool, default True): Whether to use an attention mechanism for non-local correlations
  • num_attention_heads (int, default 8): Number of attention heads
  • use_advanced_features (bool, default True): Whether to include advanced physics features
  • dropout_rate (float, default 0.0): Dropout rate for regularization
  • rngs (Rngs, required): Random number generators

compute_functional_derivative

compute_functional_derivative(density: Array, gradients: Array, *, deterministic: bool = False) -> Array

Compute functional derivative of XC energy with respect to density.

Parameters:

  • density (Array, required): Electron density
  • gradients (Array, required): Density gradients
  • deterministic (bool, default False): Whether to use deterministic computation

Returns:

  • Array: Functional derivative ∂E_xc/∂ρ, computed with enhanced numerical stability
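
On a discretized grid, a functional derivative can be obtained by automatic differentiation: if E_xc is approximated as a weighted sum over grid points, the partial derivative with respect to each density value carries that point's quadrature weight, so dividing by the weight recovers δE_xc/δρ. The sketch below checks this on the local-density (Dirac/Slater) exchange functional, E_x = C_x ∫ ρ^(4/3) dr, whose analytic potential is (4/3) C_x ρ^(1/3). It uses only the density (no gradients), and the uniform grid weights are a hypothetical choice; it illustrates the technique and is not the NeuralXCFunctional implementation.

```python
import jax
import jax.numpy as jnp

C_X = -0.75 * (3.0 / jnp.pi) ** (1.0 / 3.0)  # LDA exchange constant (spin-unpolarized)

def lda_exchange_energy(density, weights):
    """E_x = C_X * sum_i w_i * rho_i**(4/3), a quadrature of C_X * integral rho**(4/3) dr."""
    return jnp.sum(weights * C_X * density ** (4.0 / 3.0))

def functional_derivative(energy_fn, density, weights):
    """dE/d(rho_i) includes the weight w_i; dividing it out approximates v(r_i)."""
    return jax.grad(energy_fn)(density, weights) / weights

density = jnp.linspace(0.1, 1.0, 16)
weights = jnp.full_like(density, 0.05)  # uniform grid spacing (hypothetical)
v_autodiff = functional_derivative(lda_exchange_energy, density, weights)
v_analytic = (4.0 / 3.0) * C_X * density ** (1.0 / 3.0)
print(jnp.max(jnp.abs(v_autodiff - v_analytic)))  # agreement to float32 round-off
```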

assess_chemical_accuracy

assess_chemical_accuracy(density: Array, gradients: Array, reference_energy: Array | None = None, *, deterministic: bool = False) -> dict[str, float]

Assess chemical accuracy of XC functional predictions.

Parameters:

  • density (Array, required): Electron density
  • gradients (Array, required): Density gradients
  • reference_energy (Array | None, default None): Optional reference XC energy for comparison
  • deterministic (bool, default False): Whether to use deterministic computation

Returns:

  • dict[str, float]: Dictionary of accuracy metrics
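
Chemical accuracy is conventionally taken as 1 kcal/mol (about 1.6 mHartree). A metrics dictionary of the kind assess_chemical_accuracy returns can be sketched as below; the key names and the choice of metrics are hypothetical, and the inputs are assumed to be energies in Hartree.

```python
import jax.numpy as jnp

HARTREE_TO_KCAL_PER_MOL = 627.509  # 1 Hartree is about 627.509 kcal/mol
CHEMICAL_ACCURACY_KCAL = 1.0       # conventional chemical-accuracy threshold

def accuracy_metrics(predicted, reference):
    """Hypothetical metric dictionary for energies given in Hartree."""
    error_kcal = jnp.abs(predicted - reference) * HARTREE_TO_KCAL_PER_MOL
    return {
        "mae_kcal_per_mol": float(jnp.mean(error_kcal)),
        "max_error_kcal_per_mol": float(jnp.max(error_kcal)),
        "fraction_within_chemical_accuracy": float(
            jnp.mean(error_kcal <= CHEMICAL_ACCURACY_KCAL)
        ),
    }

reference = jnp.array([-1.0, -2.0, -3.0])
predicted = reference + jnp.array([0.0005, 0.001, 0.003])  # errors in Hartree
metrics = accuracy_metrics(predicted, reference)
print(metrics)
```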

Neural Operators

opifex.neural.operators

Opifex Neural Operators: Full Operator Learning Library

This module provides a broad collection of neural operators for scientific machine learning, including the major variants from the neuraloperator repository as well as additional architectures.

The library includes:

  • Fourier Neural Operators (FNO, TFNO, U-FNO, SFNO, Local FNO, AM-FNO)
  • Deep Operator Networks (DeepONet and variants)
  • Specialized operators (GINO, MGNO, UQNO, LNO, WNO, GNO)
  • Physics-informed operators (PINO)
  • Graph-based operators
  • Uncertainty quantification operators

All operators are built with JAX and Flax NNX and support automatic differentiation, just-in-time compilation, and multi-device parallelization.

AdaptiveDeepONet

AdaptiveDeepONet(branch_input_dim: int, trunk_input_dim: int, base_latent_dim: int, *, num_resolution_levels: int = 3, adaptive_latent_scaling: bool = True, use_residual_connections: bool = True, activation: str = 'tanh', rngs: Rngs)

Bases: Module

Adaptive DeepONet with dynamic architecture adjustment.

This variant can adapt its architecture based on problem complexity and provides multiple resolution levels for different accuracy requirements.

Parameters:

  • branch_input_dim (int, required): Branch network input dimension
  • trunk_input_dim (int, required): Trunk network input dimension
  • base_latent_dim (int, required): Base latent dimension (scaled for the different resolution levels)
  • num_resolution_levels (int, default 3): Number of resolution levels
  • adaptive_latent_scaling (bool, default True): Whether to scale latent dimensions adaptively
  • use_residual_connections (bool, default True): Whether to use residual connections
  • activation (str, default 'tanh'): Activation function name
  • rngs (Rngs, required): Random number generators

DeepONet

DeepONet(branch_sizes: list[int], trunk_sizes: list[int], *, activation: str = 'gelu', output_activation: str | None = None, use_bias: bool = True, rngs: Rngs)

Bases: Module

Deep Operator Network for learning function-to-function mappings.

DeepONet learns to approximate nonlinear operators G that map functions to functions: G: u → G(u), where u and G(u) are functions.

The architecture consists of:

  • Branch network: processes the input function u evaluated at sensor locations
  • Trunk network: processes the evaluation locations y
  • Dot-product combination of the branch and trunk outputs

Fully compliant with modern Flax NNX patterns.

Parameters:

  • branch_sizes (list[int], required): Layer sizes for the branch network, [input_sensors, hidden1, hidden2, ..., output_dim]
  • trunk_sizes (list[int], required): Layer sizes for the trunk network, [location_dim, hidden1, hidden2, ..., output_dim]; its output_dim must match the branch output_dim for the dot product
  • activation (str, default 'gelu'): Activation function name for the hidden layers
  • output_activation (str | None, default None): Optional activation for the final output (None means no output activation)
  • use_bias (bool, default True): Whether to use bias in the linear layers
  • rngs (Rngs, required): Random number generators (keyword-only)
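
The branch/trunk dot product can be sketched in plain JAX as follows. This is an illustration under stated assumptions, not the opifex DeepONet: it evaluates every input function at every location (an (n_functions, n_points) output), uses GELU hidden layers with no output activation, and omits any output bias. init_mlp, mlp, and deeponet_forward are hypothetical helpers. The last entries of branch_sizes and trunk_sizes are equal so that the dot product is defined.

```python
import jax
import jax.numpy as jnp

def init_mlp(sizes, key):
    """Hypothetical helper: one (W, b) pair per consecutive pair in sizes."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in), jnp.zeros(n_out))
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]

def mlp(params, x):
    for w, b in params[:-1]:
        x = jax.nn.gelu(x @ w + b)
    w, b = params[-1]
    return x @ w + b  # no output activation (output_activation=None)

def deeponet_forward(branch_params, trunk_params, u_sensors, y_points):
    """G(u)(y) = sum_k b_k(u) * t_k(y) for every (function, location) pair."""
    b = mlp(branch_params, u_sensors)  # (n_functions, p)
    t = mlp(trunk_params, y_points)    # (n_points, p)
    return jnp.einsum("fp,np->fn", b, t)

branch_sizes = [100, 64, 32]  # 100 sensors -> latent dim p = 32
trunk_sizes = [1, 64, 32]     # 1D locations -> same latent dim p = 32
kb, kt = jax.random.split(jax.random.PRNGKey(0))
branch_params = init_mlp(branch_sizes, kb)
trunk_params = init_mlp(trunk_sizes, kt)
u = jax.random.normal(jax.random.PRNGKey(1), (4, 100))  # 4 input functions
y = jnp.linspace(0.0, 1.0, 50)[:, None]                 # 50 query locations
out = deeponet_forward(branch_params, trunk_params, u, y)
print(out.shape)  # (4, 50)
```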

get_branch_output

get_branch_output(branch_input: Array, *, deterministic: bool = True) -> Array

Get branch network output for analysis purposes.

Parameters:

  • branch_input (Array, required): Function values at sensor locations
  • deterministic (bool, default True): Whether to use deterministic mode

Returns:

  • Array: Branch network output

get_trunk_output

get_trunk_output(trunk_input: Array, *, deterministic: bool = True) -> Array

Get trunk network output for analysis purposes.

Parameters:

  • trunk_input (Array, required): Evaluation locations
  • deterministic (bool, default True): Whether to use deterministic mode

Returns:

  • Array: Trunk network output

FourierEnhancedDeepONet

FourierEnhancedDeepONet(branch_sizes: list[int], trunk_sizes: list[int], *, fourier_modes: int = 16, use_spectral_branch: bool = True, use_spectral_trunk: bool = False, activation: str = 'tanh', rngs: Rngs)

Bases: Module

Fourier-Enhanced DeepONet combining spectral and operator learning.

This variant integrates Fourier Neural Operator concepts into the DeepONet architecture for improved performance on problems with spectral structure.

Parameters:

  • branch_sizes (list[int], required): Branch network layer sizes [input, hidden..., output]
  • trunk_sizes (list[int], required): Trunk network layer sizes [input, hidden..., output]
  • fourier_modes (int, default 16): Number of Fourier modes for the spectral layers
  • use_spectral_branch (bool, default True): Whether to use spectral convolution in the branch network
  • use_spectral_trunk (bool, default False): Whether to use spectral convolution in the trunk network
  • activation (str, default 'tanh'): Activation function name
  • rngs (Rngs, required): Random number generators

MultiPhysicsDeepONet

MultiPhysicsDeepONet(branch_input_dim: int, trunk_input_dim: int, branch_hidden_dims: list[int], trunk_hidden_dims: list[int], latent_dim: int, *, num_physics_systems: int = 1, use_attention: bool = True, attention_heads: int = 8, physics_constraints: list[str] | None = None, sensor_optimization: bool = False, num_sensors: int | None = None, activation: Callable[[Array], Array] = tanh, rngs: Rngs)

Bases: Module

Enhanced DeepONet with multi-physics support and attention mechanisms.

Extends the basic DeepONet architecture with physics-aware attention, multi-physics coupling, and sensor optimization for improved operator learning.

Parameters:

  • branch_input_dim (int, required): Branch network input dimension
  • trunk_input_dim (int, required): Trunk network input dimension
  • branch_hidden_dims (list[int], required): Branch network hidden dimensions
  • trunk_hidden_dims (list[int], required): Trunk network hidden dimensions
  • latent_dim (int, required): Latent dimension for the inner product
  • num_physics_systems (int, default 1): Number of physics systems to handle
  • use_attention (bool, default True): Whether to use physics-aware attention
  • attention_heads (int, default 8): Number of attention heads
  • physics_constraints (list[str] | None, default None): List of physics constraints to enforce
  • sensor_optimization (bool, default False): Whether to use sensor optimization
  • num_sensors (int | None, default None): Number of sensors (required if sensor_optimization=True)
  • activation (Callable[[Array], Array], default tanh): Activation function
  • rngs (Rngs, required): Random number generators

branch_nets property

branch_nets: list[Module]

Get branch networks from all physics operators.

get_sensor_positions

get_sensor_positions() -> Array | None

Get current sensor positions if sensor optimization is enabled.

set_physics_constraints

set_physics_constraints(constraints: list[str]) -> None

Update physics constraints for attention mechanism.

AmortizedFourierNeuralOperator

AmortizedFourierNeuralOperator(in_channels: int, out_channels: int, hidden_channels: int = 32, modes: Sequence[int] = (16, 16), num_layers: int = 4, kernel_hidden_dim: int = 128, kernel_layers: int = 3, max_frequency: float = 10.0, activation: Callable = gelu, use_layer_norm: bool = False, use_kernel_regularization: bool = True, *, rngs: Rngs)

Bases: Module

Amortized Fourier Neural Operator with neural kernel parameterization.

get_regularization_loss

get_regularization_loss(x: Array) -> Array

Compute regularization loss on demand.

get_kernel_analysis

get_kernel_analysis(freq_range: tuple[float, float], num_points: int = 100) -> dict[str, Array]

Analyze learned kernel functions.

AmortizedSpectralConvolution

AmortizedSpectralConvolution(in_channels: int, out_channels: int, modes: Sequence[int], kernel_hidden_dim: int = 128, kernel_layers: int = 3, max_frequency: float = 10.0, use_kernel_regularization: bool = True, *, rngs: Rngs)

Bases: Module

Amortized spectral convolution with neural kernel parameterization.

KernelNetwork

KernelNetwork(freq_dim: int, output_dim: int, hidden_dim: int = 128, num_layers: int = 3, activation: Callable = gelu, use_frequency_encoding: bool = True, max_frequency: float = 10.0, *, rngs: Rngs)

Bases: Module

Neural network to parameterize Fourier kernels.
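
The idea behind amortized spectral convolution is that the Fourier weights are produced by a network evaluated at each frequency, rather than stored as one free parameter per mode, so the parameter count does not depend on the number of modes. The 1D sketch below illustrates that idea; the kernel MLP shape, the scaling of frequencies by max_frequency, and the split of the network output into real and imaginary parts are assumptions, not the KernelNetwork implementation. The same parameters serve 8 or 16 modes.

```python
import jax
import jax.numpy as jnp

IN_CH, OUT_CH = 2, 3

def init_kernel_network(key, hidden=32):
    """Hypothetical MLP: scaled frequency (1 value) -> real and imaginary weight entries."""
    sizes = [1, hidden, hidden, 2 * IN_CH * OUT_CH]
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (a, b)) / jnp.sqrt(a), jnp.zeros(b))
        for k, a, b in zip(keys, sizes[:-1], sizes[1:])
    ]

def kernel_network(params, freqs):
    """freqs: (modes, 1) -> complex channel-mixing weights of shape (modes, IN_CH, OUT_CH)."""
    h = freqs
    for w, b in params[:-1]:
        h = jax.nn.gelu(h @ w + b)
    w, b = params[-1]
    real, imag = jnp.split(h @ w + b, 2, axis=-1)
    return (real + 1j * imag).reshape(freqs.shape[0], IN_CH, OUT_CH)

def amortized_spectral_conv_1d(params, x, modes, max_frequency=10.0):
    """x: (n_grid, IN_CH) -> (n_grid, OUT_CH); weights are generated, not stored per mode."""
    n = x.shape[0]
    x_hat = jnp.fft.rfft(x, axis=0)[:modes]
    freqs = jnp.arange(modes, dtype=jnp.float32)[:, None] / max_frequency  # scaling assumed
    y_low = jnp.einsum("mi,mio->mo", x_hat, kernel_network(params, freqs))
    y_hat = jnp.zeros((n // 2 + 1, OUT_CH), y_low.dtype).at[:modes].set(y_low)
    return jnp.fft.irfft(y_hat, n=n, axis=0)

params = init_kernel_network(jax.random.PRNGKey(0))
x = jax.random.normal(jax.random.PRNGKey(1), (64, IN_CH))
y8 = amortized_spectral_conv_1d(params, x, modes=8)    # same parameters...
y16 = amortized_spectral_conv_1d(params, x, modes=16)  # ...different number of modes
print(y8.shape, y16.shape)  # (64, 3) (64, 3)
```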

FourierLayer

FourierLayer(in_channels: int, out_channels: int, modes: int, *, activation: Callable[[Array], Array] = gelu, spatial_dims: int = 2, rngs: Rngs)

Bases: Module

Fourier layer combining spectral convolution with activation.

This layer performs:
1. FFT to transform the input to the spectral domain
2. Spectral convolution
3. Inverse FFT to transform back to the spatial domain
4. Linear transformation and activation, with a residual connection

Fully compliant with modern Flax NNX patterns.
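The four steps above can be sketched in NumPy for a single 1-D sample. The sketch follows the standard FNO update σ(K(v) + Wv) from the original FNO formulation. Here K is the truncated spectral convolution and W is a pointwise (1×1) linear map that acts as the skip path. The exact residual wiring in opifex may differ, and the function names are hypothetical.

```python
import numpy as np

def gelu(z):
    """Tanh approximation of GELU."""
    return 0.5 * z * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (z + 0.044715 * z**3)))

def fourier_layer_1d(x, spectral_w, pointwise_w, modes, activation=gelu):
    """x: (in_ch, n); spectral_w: (in_ch, out_ch, modes) complex; pointwise_w: (out_ch, in_ch)."""
    n = x.shape[-1]
    x_hat = np.fft.rfft(x, axis=-1)                              # 1. FFT: (in_ch, n//2 + 1)
    out_hat = np.zeros((spectral_w.shape[1], x_hat.shape[-1]), dtype=complex)
    # 2. Spectral convolution: mix channels for the lowest `modes` frequencies, drop the rest
    out_hat[:, :modes] = np.einsum("im,iom->om", x_hat[:, :modes], spectral_w)
    spectral = np.fft.irfft(out_hat, n=n, axis=-1)               # 3. inverse FFT: (out_ch, n)
    return activation(spectral + pointwise_w @ x)                # 4. linear skip path + activation

rng = np.random.default_rng(0)
in_ch, out_ch, n, modes = 2, 3, 32, 4
spectral_w = rng.normal(size=(in_ch, out_ch, modes)) + 1j * rng.normal(size=(in_ch, out_ch, modes))
pointwise_w = rng.normal(size=(out_ch, in_ch))
x = rng.normal(size=(in_ch, n))
y = fourier_layer_1d(x, spectral_w, pointwise_w, modes)         # (3, 32)
```

The spectral weights are stored per mode rather than per grid point. As a result, the same layer can be applied to inputs sampled at a different resolution n.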

Parameters:

Name Type Description Default
in_channels int

Number of input channels

required
out_channels int

Number of output channels

required
modes int

Number of Fourier modes

required
activation Callable[[Array], Array]

Activation function

gelu
spatial_dims int

Number of spatial dimensions (1, 2, or 3). Controls which spectral weights are allocated, so no unused (dead) parameters are created.

2
rngs Rngs

Random number generators (keyword-only)

required

FourierNeuralOperator

FourierNeuralOperator(in_channels: int, out_channels: int, hidden_channels: int, modes: int, num_layers: int, *, activation: Callable[[Array], Array] = gelu, factorization_type: str | None = None, factorization_rank: int | None = None, use_mixed_precision: bool = False, domain_padding: int = 0, spatial_dims: int = 2, rngs: Rngs)

Bases: Module

Fourier Neural Operator for learning solution operators of PDEs.

Implements the complete FNO architecture with optional tensor factorization and mixed-precision training. Fully compliant with modern Flax NNX patterns.

Parameters:

Name Type Description Default
in_channels int

Number of input channels

required
out_channels int

Number of output channels

required
hidden_channels int

Number of hidden channels

required
modes int

Number of Fourier modes

required
num_layers int

Number of Fourier layers

required
activation Callable[[Array], Array]

Activation function

gelu
factorization_type str | None

Optional tensor factorization ('tucker', 'cp')

None
factorization_rank int | None

Rank for tensor factorization

None
use_mixed_precision bool

Whether to use mixed precision

False
domain_padding int

Number of grid points to pad along each spatial dimension (reduces the Gibbs phenomenon for non-periodic problems). A value of 2 is recommended for Darcy flow.

0
spatial_dims int

Number of spatial dimensions (1, 2, or 3). Determines which spectral weights are allocated per layer.

2
rngs Rngs

Random number generators (keyword-only)

required

count_parameters

count_parameters() -> int

Count total number of trainable parameters in the model.

FactorizedFourierLayer

FactorizedFourierLayer(in_channels: int, out_channels: int, modes: int, factorization_type: str, factorization_rank: int, *, activation: Callable[[Array], Array] = gelu, rngs: Rngs)

Bases: Module

Fourier layer with tensor factorization for parameter reduction.

Implements Tucker or CP factorization of the spectral convolution weights to achieve significant parameter reduction (up to 95%) while maintaining performance.
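The sketch below shows where the savings come from. It counts parameters for a dense 2-D spectral weight tensor of shape (in_channels, out_channels, m1, m2) and for its CP and Tucker factorizations, then rebuilds a full tensor from CP factors. These are the textbook CP and Tucker parameterizations. The exact layout opifex uses, and how `factorization_rank` maps onto these ranks, are assumptions. Complex weights double every count equally, so the ratios are unchanged. The counts cover only the spectral weights, not biases or pointwise layers, so savings for the whole layer are smaller.

```python
from math import prod

import numpy as np

def dense_params(in_ch, out_ch, modes):
    return in_ch * out_ch * prod(modes)

def cp_params(in_ch, out_ch, modes, rank):
    # One (dim, rank) factor matrix per tensor axis
    return rank * (in_ch + out_ch + sum(modes))

def tucker_params(in_ch, out_ch, modes, rank):
    dims = (in_ch, out_ch, *modes)
    core = rank ** len(dims)          # (rank, rank, rank, rank) core tensor
    return core + rank * sum(dims)    # plus one (dim, rank) factor matrix per axis

def cp_reconstruct(a, b, c, d):
    # Sum of `rank` rank-1 outer products -> full (in_ch, out_ch, m1, m2) tensor
    return np.einsum("ir,or,xr,yr->ioxy", a, b, c, d)

in_ch, out_ch, modes, rank = 64, 64, (16, 16), 16
dense = dense_params(in_ch, out_ch, modes)       # 1_048_576
cp = cp_params(in_ch, out_ch, modes, rank)       # 2_560
tucker = tucker_params(in_ch, out_ch, modes, 8)  # 5_376

rng = np.random.default_rng(0)
factors = [rng.normal(size=(dim, rank)) for dim in (in_ch, out_ch, *modes)]
weights = cp_reconstruct(*factors)               # shape (64, 64, 16, 16)
```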

Parameters:

Name Type Description Default
in_channels int

Number of input channels

required
out_channels int

Number of output channels

required
modes int

Number of Fourier modes

required
factorization_type str

Type of factorization ("tucker" or "cp")

required
factorization_rank int

Rank for factorization

required
activation Callable[[Array], Array]

Activation function

gelu
rngs Rngs

Random number generators

required

get_parameter_count

get_parameter_count() -> dict[str, int | float]

Get parameter count breakdown for analysis.

LocalFourierLayer

LocalFourierLayer(in_channels: int, out_channels: int, modes: Sequence[int], kernel_size: int = 3, activation: Callable = gelu, mixing_weight: float = 0.5, *, rngs: Rngs)

Bases: Module

Fourier layer with local convolution for capturing short-range interactions.

Combines global spectral convolution with local spatial convolution so that both long-range and short-range features are captured.
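Below is a minimal 1-D sketch of the two-branch idea. It assumes the branches are blended as a convex combination w * global + (1 - w) * local before the activation, with w playing the role of `mixing_weight`. The blending rule, the circular padding for the local convolution, and the tanh activation are all assumptions, and the functions are hypothetical. Returning the two branch outputs mirrors what `get_mixing_analysis` reports.

```python
import numpy as np

def spectral_branch(x, weights, modes):
    """Global path: truncated Fourier-domain channel mixing. x: (in_ch, n)."""
    n = x.shape[-1]
    x_hat = np.fft.rfft(x, axis=-1)
    out_hat = np.zeros((weights.shape[1], x_hat.shape[-1]), dtype=complex)
    out_hat[:, :modes] = np.einsum("im,iom->om", x_hat[:, :modes], weights)
    return np.fft.irfft(out_hat, n=n, axis=-1)

def local_branch(x, kernel):
    """Local path: 1-D convolution (cross-correlation) with circular padding.
    kernel: (out_ch, in_ch, k) with odd k, so the output length equals n."""
    n, k = x.shape[-1], kernel.shape[-1]
    xp = np.pad(x, ((0, 0), (k // 2, k // 2)), mode="wrap")
    windows = np.stack([xp[:, j:j + n] for j in range(k)], axis=-1)   # (in_ch, n, k)
    return np.einsum("oik,ink->on", kernel, windows)

def local_fourier_layer(x, spec_w, conv_k, modes, mixing_weight=0.5):
    g = spectral_branch(x, spec_w, modes)
    loc = local_branch(x, conv_k)
    return np.tanh(mixing_weight * g + (1.0 - mixing_weight) * loc), g, loc

rng = np.random.default_rng(1)
in_ch, out_ch, n, modes, k = 2, 2, 16, 3, 3
x = rng.normal(size=(in_ch, n))
spec_w = rng.normal(size=(in_ch, out_ch, modes)) + 0j
conv_k = rng.normal(size=(out_ch, in_ch, k))
y, g, loc = local_fourier_layer(x, spec_w, conv_k, modes, mixing_weight=0.7)
```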

Parameters:

Name Type Description Default
in_channels int

Number of input channels

required
out_channels int

Number of output channels

required
modes Sequence[int]

Fourier modes for spectral convolution

required
kernel_size int

Kernel size for local convolution

3
activation Callable

Activation function

gelu
mixing_weight float

Weight for combining spectral and local branches

0.5
rngs Rngs

Random number generator state

required

get_mixing_analysis

get_mixing_analysis(x: Array) -> tuple[Array, Array, Array]

Analyze global vs local contributions for this layer.

Parameters:

Name Type Description Default
x Array

Input tensor (batch, in_channels, *spatial).

required

Returns:

Type Description
tuple[Array, Array, Array]

Tuple of (global_features, local_features, mixing_weights), where mixing_weights is a scalar array holding the weight of the spectral (global) branch.

LocalFourierNeuralOperator

LocalFourierNeuralOperator(in_channels: int, out_channels: int, hidden_channels: int, modes: Sequence[int], num_layers: int = 4, kernel_size: int = 3, use_adaptive_mixing: bool = True, use_residual_connections: bool = True, activation: Callable = gelu, *, rngs: Rngs)

Bases: Module

Local Fourier Neural Operator combining global and local operations.

This operator is designed for problems that require both:
- Long-range dependencies (captured by Fourier operations)
- Local features and fine details (captured by convolutions)

Examples include:
- Turbulent flows with both large-scale structures and small eddies
- Wave propagation with local scattering and global modes
- Multi-physics problems with different characteristic scales

Parameters:

Name Type Description Default
in_channels int

Number of input channels

required
out_channels int

Number of output channels

required
hidden_channels int

Hidden layer width

required
modes Sequence[int]

Fourier modes for global operations

required
num_layers int

Number of Local Fourier layers

4
kernel_size int

Kernel size for local convolutions

3
use_adaptive_mixing bool

Whether to use adaptive feature mixing

True
use_residual_connections bool

Whether to use residual connections

True
activation Callable

Activation function

gelu
rngs Rngs

Random number generator state

required

analyze_global_local_contributions

analyze_global_local_contributions(x: Array) -> dict[str, list[Array]]

Analyze global vs local contributions at each layer.

Returns:

Type Description
dict[str, list[Array]]

Dictionary with global and local feature maps

MultiScaleFourierNeuralOperator

MultiScaleFourierNeuralOperator(in_channels: int, out_channels: int, hidden_channels: int, modes_per_scale: list[int], num_layers_per_scale: list[int], *, spatial_dims: int = 2, activation: Callable[[Array], Array] = gelu, use_cross_scale_attention: bool = True, attention_heads: int = 8, dropout_rate: float = 0.0, use_gradient_checkpointing: bool = True, rngs: Rngs)

Bases: Module

Multi-Scale Fourier Neural Operator for hierarchical resolution handling.

This operator learns operators across multiple scales simultaneously, enabling efficient handling of multi-scale physics problems like turbulence, multi-phase flows, and hierarchical material structures.

Features:
- Hierarchical spectral convolutions at different resolution levels
- Adaptive scale selection based on input characteristics
- Cross-scale information exchange through attention mechanisms
- Memory-efficient implementation with gradient checkpointing

Parameters:

Name Type Description Default
in_channels int

Number of input channels

required
out_channels int

Number of output channels

required
hidden_channels int

Hidden channel dimension

required
modes_per_scale list[int]

List of Fourier modes for each scale

required
num_layers_per_scale list[int]

List of layer counts for each scale

required
spatial_dims int

Number of spatial dimensions (1 or 2)

2
activation Callable[[Array], Array]

Activation function

gelu
use_cross_scale_attention bool

Whether to use cross-scale attention

True
attention_heads int

Number of attention heads

8
dropout_rate float

Dropout rate for regularization

0.0
use_gradient_checkpointing bool

Whether to use gradient checkpointing

True
rngs Rngs

Random number generators

required

SphericalFourierNeuralOperator

SphericalFourierNeuralOperator(in_channels: int, out_channels: int, hidden_channels: int, lmax: int, mmax: int | None = None, num_layers: int = 4, activation: Callable = gelu, use_real_sht: bool = False, *, rngs: Rngs)

Bases: Module

Spherical Fourier Neural Operator for data on spherical domains.

Uses spherical harmonic transforms instead of regular FFTs, which makes it well suited to:
- Global atmospheric modeling
- Ocean circulation
- Planetary science
- Any data naturally defined on a sphere

Parameters:

Name Type Description Default
in_channels int

Number of input channels

required
out_channels int

Number of output channels

required
hidden_channels int

Hidden layer width

required
lmax int

Maximum spherical harmonic degree

required
mmax int | None

Maximum azimuthal order (if None, uses lmax)

None
num_layers int

Number of SFNO layers

4
activation Callable

Activation function

gelu
use_real_sht bool

Whether to use real spherical harmonics

False
rngs Rngs

Random number generator state

required

get_spherical_modes

get_spherical_modes(x: Array) -> Array

Get spherical harmonic coefficients for analysis.

Parameters:

Name Type Description Default
x Array

Input tensor on sphere

required

Returns:

Type Description
Array

Spherical harmonic coefficients

compute_power_spectrum

compute_power_spectrum(x: Array) -> Array

Compute spherical harmonic power spectrum.

Parameters:

Name Type Description Default
x Array

Input tensor on sphere

required

Returns:

Type Description
Array

Power spectrum as function of spherical harmonic degree l
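A common definition of the degree power spectrum is P(l) = Σ_{m=-l}^{l} |a_lm|², the total power at degree l. The sketch below computes it from the complex coefficients of a real-valued field stored for m ≥ 0 only, and it relies on the symmetry |a_{l,-m}| = |a_{l,m}|. The (lmax + 1, mmax + 1) storage layout and the absence of a 1/(2l + 1) normalization are assumptions. `compute_power_spectrum` may follow a different convention.

```python
import numpy as np

def degree_power_spectrum(coeffs):
    """coeffs: complex array (lmax + 1, mmax + 1) of a_lm for m >= 0 (real field).
    Entries with m > l are not valid harmonics and are ignored."""
    lmax, mmax = coeffs.shape[0] - 1, coeffs.shape[1] - 1
    l = np.arange(lmax + 1)[:, None]
    m = np.arange(mmax + 1)[None, :]
    valid = m <= l
    # m = 0 appears once; each m > 0 also stands for its -m partner of equal magnitude
    multiplicity = np.where(m == 0, 1.0, 2.0)
    return np.sum(np.abs(coeffs) ** 2 * multiplicity * valid, axis=1)

coeffs = np.zeros((4, 4), dtype=complex)
coeffs[2, 0] = 1.0
coeffs[2, 1] = 1.0j
coeffs[3, 3] = 2.0
coeffs[1, 2] = 5.0                          # m > l: ignored
spectrum = degree_power_spectrum(coeffs)    # [0, 0, 3, 8]
```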

SphericalHarmonicConvolution

SphericalHarmonicConvolution(in_channels: int, out_channels: int, lmax: int, mmax: int | None = None, *, rngs: Rngs)

Bases: Module

Spherical harmonic convolution for spherical domains.

Operates in spherical harmonic space analogous to how standard FNO operates in Fourier space, but adapted for spherical geometry.

Parameters:

Name Type Description Default
in_channels int

Number of input channels

required
out_channels int

Number of output channels

required
lmax int

Maximum spherical harmonic degree (controls resolution)

required
mmax int | None

Maximum azimuthal order (if None, uses lmax)

None
rngs Rngs

Random number generator state

required

TensorizedFourierNeuralOperator

TensorizedFourierNeuralOperator(in_channels: int, out_channels: int, hidden_channels: int = 64, modes: Sequence[int] = (16, 16), num_layers: int = 4, factorization: Literal['tucker', 'cp', 'tt'] = 'tucker', rank: float = 0.1, *, rngs: Rngs)

Bases: Module

Simplified Tensorized FNO with stable implementations.

TensorizedSpectralConvolution

TensorizedSpectralConvolution(in_channels: int, out_channels: int, modes: Sequence[int], decomposition_type: Literal['tucker', 'cp', 'tt'] = 'tucker', rank: float = 0.1, *, rngs: Rngs)

Bases: Module

Simplified tensorized spectral convolution for stability.

get_compression_stats

get_compression_stats() -> dict[str, float]

Get compression statistics.

UFNODecoderBlock

UFNODecoderBlock(in_channels: int, skip_channels: int, out_channels: int, modes: Sequence[int], upsample_factor: int = 2, activation: Callable = gelu, *, rngs: Rngs)

Bases: Module

Clean U-FNO decoder block with standardized tensor operations.

Performs: upsampling + skip fusion + spectral convolution

UFNOEncoderBlock

UFNOEncoderBlock(in_channels: int, out_channels: int, modes: Sequence[int], downsample_factor: int = 2, activation: Callable = gelu, *, rngs: Rngs)

Bases: Module

Clean U-FNO encoder block with standardized tensor operations.

Performs: spectral convolution + skip connection + downsampling

UFourierNeuralOperator

UFourierNeuralOperator(in_channels: int, out_channels: int, hidden_channels: int, modes: Sequence[int], num_levels: int = 3, downsample_factor: int = 2, activation: Callable = gelu, *, rngs: Rngs)

Bases: Module

U-Net style Fourier Neural Operator with clean, standardized architecture.

Features:
- Consistent tensor dimension handling
- Standardized spectral operations
- Clean encoder-decoder structure
- Proper channel management throughout

GraphNeuralOperator

GraphNeuralOperator(node_dim: int, hidden_dim: int, num_layers: int, *, edge_dim: int = 0, activation: Callable[[Array], Array] = gelu, rngs: Rngs)

Bases: Module

Graph Neural Operator for learning operators on irregular domains.

Implements message passing neural networks with geometric awareness for learning operators on graph-structured data. Suitable for irregular meshes, molecular systems, and other graph-based scientific computing applications.

Parameters:

Name Type Description Default
node_dim int

Dimension of node features

required
hidden_dim int

Hidden dimension for message passing

required
num_layers int

Number of message passing layers

required
edge_dim int

Dimension of edge features (0 for no edge features)

0
activation Callable[[Array], Array]

Activation function

gelu
rngs Rngs

Random number generators

required

MessagePassingLayer

MessagePassingLayer(node_dim: int, edge_dim: int, hidden_dim: int, *, activation: Callable[[Array], Array] = gelu, rngs: Rngs)

Bases: Module

Message passing layer for graph neural networks.

Implements the message passing paradigm:
1. Compute messages between connected nodes
2. Aggregate messages at each node
3. Update node features based on aggregated messages
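The three steps can be sketched in NumPy with an edge list. The choices below are common defaults assumed for illustration: messages computed from the concatenated source, target, and edge features; sum aggregation; tanh; and one weight matrix per step. They are not necessarily what `MessagePassingLayer` does internally.

```python
import numpy as np

def message_passing_step(h, edge_index, edge_attr, w_msg, w_upd):
    """h: (N, node_dim); edge_index: (2, E) rows (source, target); edge_attr: (E, edge_dim)."""
    src, dst = edge_index
    # 1. Compute one message per edge from both endpoints and the edge features
    messages = np.tanh(np.concatenate([h[src], h[dst], edge_attr], axis=-1) @ w_msg)
    # 2. Sum-aggregate incoming messages at each target node
    agg = np.zeros((h.shape[0], messages.shape[1]))
    np.add.at(agg, dst, messages)
    # 3. Update each node from its previous features and its aggregated messages
    return np.tanh(np.concatenate([h, agg], axis=-1) @ w_upd)

rng = np.random.default_rng(0)
num_nodes, node_dim, edge_dim, hidden = 4, 3, 2, 8
h = rng.normal(size=(num_nodes, node_dim))
h[3] = 0.0                                   # node 3: zero features, no incoming edges
edge_index = np.array([[0, 1, 2, 0], [1, 2, 0, 2]])
edge_attr = rng.normal(size=(edge_index.shape[1], edge_dim))
w_msg = rng.normal(size=(2 * node_dim + edge_dim, hidden))
w_upd = rng.normal(size=(node_dim + hidden, node_dim))
h_new = message_passing_step(h, edge_index, edge_attr, w_msg, w_upd)
```

Sum aggregation makes the update independent of edge ordering. Stacking `num_layers` such steps lets information propagate that many hops across the mesh.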

Parameters:

Name Type Description Default
node_dim int

Dimension of node features

required
edge_dim int

Dimension of edge features

required
hidden_dim int

Hidden dimension for message computation

required
activation Callable[[Array], Array]

Activation function

gelu
rngs Rngs

Random number generators

required

PhysicsAwareAttention

PhysicsAwareAttention(embed_dim: int, num_heads: int, *, physics_constraints: list[str] | None = None, dropout_rate: float = 0.0, rngs: Rngs)

Bases: Module

Physics-aware attention mechanism with constraint enforcement.

Integrates physics constraints into the attention mechanism to ensure physically meaningful attention patterns.

Parameters:

Name Type Description Default
embed_dim int

Embedding dimension

required
num_heads int

Number of attention heads

required
physics_constraints list[str] | None

List of physics constraints to enforce

None
dropout_rate float

Dropout rate for attention weights

0.0
rngs Rngs

Random number generators

required

PhysicsCrossAttention

PhysicsCrossAttention(embed_dim: int, num_heads: int, physics_constraints: list[str], num_physics_systems: int, *, conservation_weight: float = 0.1, adaptive_weighting: bool = True, cross_system_coupling: bool = True, dropout_rate: float = 0.0, rngs: Rngs)

Bases: Module

Physics-Cross-Attention mechanism for enhanced multi-physics coupling.

Implements cross-attention between different physics systems with conservation law enforcement and adaptive weighting based on physics constraints.

Parameters:

Name Type Description Default
embed_dim int

Embedding dimension

required
num_heads int

Number of attention heads

required
physics_constraints list[str]

List of physics constraints to enforce

required
num_physics_systems int

Number of different physics systems

required
conservation_weight float

Weight for conservation law enforcement

0.1
adaptive_weighting bool

Whether to use adaptive constraint weighting

True
cross_system_coupling bool

Whether to enable cross-system coupling

True
dropout_rate float

Dropout rate for attention weights

0.0
rngs Rngs

Random number generators

required

forward_with_conservation

forward_with_conservation(x: Array, *, physics_info: Array | None = None, training: bool = False) -> tuple[Array, Array]

Forward pass with conservation loss computation.

Parameters:

Name Type Description Default
x Array

Input tensor

required
physics_info Array | None

Physics constraint information

None
training bool

Whether in training mode

False

Returns:

Type Description
tuple[Array, Array]

Tuple of (output, conservation_loss)

PhysicsInformedOperator

PhysicsInformedOperator(layer_sizes: list[int], physics_type: str = 'pde', *, activation: str = 'gelu', physics_weight: float = 1.0, data_weight: float = 1.0, use_bias: bool = True, rngs: Rngs)

Bases: Module

Physics-Informed Neural Operator with embedded physical constraints.

This operator combines standard neural operator architectures with physics-based constraints and differential operators to ensure physically consistent solutions.

Fully compliant with modern Flax NNX patterns.

Parameters:

Name Type Description Default
layer_sizes list[int]

Layer sizes for the neural network [input_dim, hidden1, hidden2, ..., output_dim]

required
physics_type str

Type of physics constraint ('pde', 'conservation', 'symmetry')

'pde'
activation str

Activation function name

'gelu'
physics_weight float

Weight for physics loss component

1.0
data_weight float

Weight for data loss component

1.0
use_bias bool

Whether to use bias in linear layers

True
rngs Rngs

Random number generators (keyword-only)

required

compute_physics_loss

compute_physics_loss(coordinates: Array, *, deterministic: bool = True) -> Array

Compute physics-based loss components.

Parameters:

Name Type Description Default
coordinates Array

Space-time coordinates

required
deterministic bool

Whether to use deterministic mode

True

Returns:

Type Description
Array

Physics loss value

compute_total_loss

compute_total_loss(coordinates: Array, target_solution: Array | None = None, *, deterministic: bool = True) -> dict[str, Array]

Compute total loss combining data and physics components.

Parameters:

Name Type Description Default
coordinates Array

Space-time coordinates

required
target_solution Array | None

Target solution (optional, for supervised learning)

None
deterministic bool

Whether to use deterministic mode

True

Returns:

Type Description
dict[str, Array]

Dictionary containing individual loss components and total loss
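The following minimal sketch shows how data and physics terms are typically combined. It assumes total = data_weight * data_loss + physics_weight * physics_loss, with both terms computed as mean-squared errors. The PDE (u'' = f on a uniform 1-D grid), the finite-difference residual standing in for automatic differentiation, and the dictionary keys are all hypothetical.

```python
import numpy as np

def total_loss(u_pred, x, source, u_target=None, *, physics_weight=1.0, data_weight=1.0):
    """Weighted data + physics loss for the 1-D problem u''(x) = source(x)."""
    dx = x[1] - x[0]
    # Second derivative at interior points by central differences
    u_xx = (u_pred[2:] - 2.0 * u_pred[1:-1] + u_pred[:-2]) / dx**2
    physics = np.mean((u_xx - source[1:-1]) ** 2)        # mean squared PDE residual
    data = np.mean((u_pred - u_target) ** 2) if u_target is not None else 0.0
    return {
        "data_loss": data,
        "physics_loss": physics,
        "total_loss": data_weight * data + physics_weight * physics,
    }

x = np.linspace(0.0, 1.0, 201)
u_exact = np.sin(np.pi * x)
f = -np.pi**2 * np.sin(np.pi * x)                         # u'' for u = sin(pi x)
losses = total_loss(u_exact + 0.01, x, f, u_target=u_exact, physics_weight=0.5)
```

A constant offset leaves u'' unchanged. The shifted prediction therefore has a near-zero physics loss but a nonzero data loss, which shows why both terms are needed.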

GeometryAttention

GeometryAttention(feature_dim: int, geometry_dim: int, num_heads: int = 8, use_distance_attention: bool = True, *, rngs: Rngs)

Bases: Module

Geometry-aware attention mechanism.

Computes attention weights based on both feature similarity and geometric relationships with proper dimension handling.
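One simple way to combine the two signals, assumed here for illustration, is to subtract a scaled pairwise distance from the scaled dot-product scores before the softmax. When features are equally similar, nearby points then receive more weight. This single-head NumPy sketch is not the actual formula used by `GeometryAttention`. `distance_scale` and the projection matrices are hypothetical.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)      # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def geometry_attention(features, coords, wq, wk, wv, distance_scale=1.0):
    """Single-head attention. features: (N, d); coords: (N, coord_dim)."""
    q, k, v = features @ wq, features @ wk, features @ wv
    scores = q @ k.T / np.sqrt(q.shape[-1])                                  # feature similarity
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)  # (N, N)
    weights = softmax(scores - distance_scale * dist, axis=-1)               # distance penalty
    return weights @ v, weights

rng = np.random.default_rng(0)
coords = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 0.0]])
features = rng.normal(size=(3, 4))
wq, wk, wv = (rng.normal(size=(4, 4)) for _ in range(3))
out, weights = geometry_attention(features, coords, wq, wk, wv)   # out: (3, 4)
```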

Parameters:

Name Type Description Default
feature_dim int

Dimension of feature vectors

required
geometry_dim int

Dimension of geometry embeddings

required
num_heads int

Number of attention heads

8
use_distance_attention bool

Whether to include distance-based attention

True
rngs Rngs

Random number generator state

required

GeometryEncoder

GeometryEncoder(coord_dim: int, hidden_dim: int, output_dim: int, use_positional_encoding: bool = True, max_position: float = 10000.0, *, rngs: Rngs)

Bases: Module

Encoder for geometric coordinates with positional encoding.

Transforms coordinate information into rich geometric embeddings suitable for neural operator processing.
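Transformer-style sinusoidal encodings apply sines and cosines at geometrically spaced frequencies to each coordinate, and `max_position` sets the slowest frequency. The sketch below is an assumed version of that encoding applied to coordinates. The exact frequency schedule in `GeometryEncoder` and the layers that follow the encoding are not shown.

```python
import numpy as np

def sinusoidal_coordinate_encoding(coords, num_freqs, max_position=10000.0):
    """coords: (N, coord_dim) -> (N, coord_dim * 2 * num_freqs)."""
    # Frequencies fall geometrically from 1 toward 1 / max_position
    freqs = 1.0 / max_position ** (np.arange(num_freqs) / num_freqs)   # (F,)
    angles = coords[:, :, None] * freqs                               # (N, D, F)
    enc = np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)   # (N, D, 2F)
    return enc.reshape(coords.shape[0], -1)

coords = np.random.default_rng(0).uniform(-1.0, 1.0, size=(5, 2))
enc = sinusoidal_coordinate_encoding(coords, num_freqs=8)            # (5, 32)
```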

Parameters:

Name Type Description Default
coord_dim int

Dimension of input coordinates

required
hidden_dim int

Hidden layer dimension

required
output_dim int

Output embedding dimension

required
use_positional_encoding bool

Whether to use sinusoidal positional encoding

True
max_position float

Maximum position for encoding

10000.0
rngs Rngs

Random number generator state

required

GeometryInformedNeuralOperator

GeometryInformedNeuralOperator(in_channels: int, out_channels: int, hidden_channels: int = 64, modes: Sequence[int] = (16, 16), num_layers: int = 4, geometry_dim: int = 32, coord_dim: int = 2, use_geometry_attention: bool = True, use_spectral_conv: bool = True, *, rngs: Rngs)

Bases: Module

Complete Geometry-Informed Neural Operator.

Advanced neural operator that incorporates geometric information throughout the network for improved performance on spatially complex problems.

Parameters:

Name Type Description Default
in_channels int

Number of input channels

required
out_channels int

Number of output channels

required
hidden_channels int

Hidden channel dimension

64
modes Sequence[int]

Fourier modes for spectral convolution

(16, 16)
num_layers int

Number of GINO blocks

4
geometry_dim int

Dimension of geometry embeddings

32
coord_dim int

Coordinate dimension

2
use_geometry_attention bool

Whether to use geometry attention

True
use_spectral_conv bool

Whether to use spectral convolution

True
rngs Rngs

Random number generator state

required

GINOBlock

GINOBlock(in_channels: int, out_channels: int, modes: Sequence[int], geometry_dim: int, coord_dim: int = 2, use_geometry_attention: bool = True, use_spectral_conv: bool = True, activation: Callable[[Array], Array] = gelu, *, rngs: Rngs)

Bases: Module

Single GINO block with spectral convolution and geometry attention.

Combines spectral convolutions with geometry-aware processing for enhanced spatial understanding.

Parameters:

Name Type Description Default
in_channels int

Number of input channels

required
out_channels int

Number of output channels

required
modes Sequence[int]

Fourier modes for spectral convolution

required
geometry_dim int

Dimension of geometry embeddings

required
coord_dim int

Dimension of coordinates

2
use_geometry_attention bool

Whether to use geometry attention

True
use_spectral_conv bool

Whether to use spectral convolution

True
activation Callable[[Array], Array]

Activation function

gelu
rngs Rngs

Random number generator state

required

LatentNeuralOperator

LatentNeuralOperator(in_channels: int, out_channels: int, latent_dim: int, num_latent_tokens: int, *, num_attention_heads: int = 8, num_encoder_layers: int = 4, num_decoder_layers: int = 4, physics_constraints: list[str] | None = None, dropout_rate: float = 0.0, activation: Callable[[Array], Array] = gelu, rngs: Rngs)

Bases: Module

Latent Neural Operator with attention-based latent representations.

This operator learns compact latent representations of function spaces using attention mechanisms, enabling efficient learning of complex operator mappings with reduced computational overhead.

Features:
- Learnable latent space for function representation
- Multi-head attention for function-to-latent and latent-to-function mappings
- Physics-aware attention constraints
- Efficient inference through latent space operations
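The encode/process/decode pattern can be sketched with plain scaled dot-product attention. A small set of T latent tokens attends to the N input samples, the tokens then attend among themselves, and the query points (here, the input points themselves) attend back to the tokens. The attention matrices are T×N and T×T rather than N×N, which reduces cost when T ≪ N. Learned query/key/value projections, multiple heads, and physics constraints are omitted. Random tokens stand in for learned ones.

```python
import numpy as np

def cross_attention(queries, context):
    """Scaled dot-product attention of queries (Q, d) over context (C, d), no projections."""
    scores = queries @ context.T / np.sqrt(queries.shape[-1])   # (Q, C)
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ context

rng = np.random.default_rng(0)
num_points, num_latent, d = 1024, 16, 32
inputs = rng.normal(size=(num_points, d))    # lifted samples of the input function
latents = rng.normal(size=(num_latent, d))   # stand-in for learned latent tokens

z = cross_attention(latents, inputs)         # encode: 16 x 1024 attention matrix
z = z + cross_attention(z, z)                # process: 16 x 16 self-attention in latent space
out = cross_attention(inputs, z)             # decode: 1024 x 16 attention matrix
```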

Parameters:

Name Type Description Default
in_channels int

Number of input channels

required
out_channels int

Number of output channels

required
latent_dim int

Dimension of latent space

required
num_latent_tokens int

Number of latent tokens

required
num_attention_heads int

Number of attention heads

8
num_encoder_layers int

Number of encoder layers

4
num_decoder_layers int

Number of decoder layers

4
physics_constraints list[str] | None

List of physics constraints

None
dropout_rate float

Dropout rate

0.0
activation Callable[[Array], Array]

Activation function

gelu
rngs Rngs

Random number generators

required

MGNOLayer

MGNOLayer(channels: int, max_multipole_order: int = 4, use_local_messages: bool = True, dropout_rate: float = 0.1, *, rngs: Rngs)

Bases: Module

MGNO layer with numerical stability and robust message passing.

Combines multipole expansion with local graph neural network operations for handling both long-range and short-range interactions.

Parameters:

Name Type Description Default
channels int

Number of feature channels

required
max_multipole_order int

Maximum multipole expansion order

4
use_local_messages bool

Whether to use local message passing

True
dropout_rate float

Dropout rate for regularization

0.1
rngs Rngs

Random number generator state

required

MultipoleExpansion

MultipoleExpansion(channels: int, max_order: int = 4, epsilon: float = 1e-08, stabilization_factor: float = 0.1, *, rngs: Rngs)

Bases: Module

Numerically stable multipole expansion layer.

Computes multipole moments with proper numerical stability to prevent overflow and NaN generation in hierarchical computations.
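The classical 2-D multipole moments M_k = Σ_i q_i (z_i - z_c)^k grow like R^k for a cluster of radius R, which is the source of the overflow problem. A standard remedy, assumed here and not confirmed as the exact scheme in `MultipoleExpansion`, computes moments of the rescaled offsets (z_i - z_c) / R. These offsets have magnitude at most 1, and R is carried separately. `epsilon` guards the division for a degenerate cluster. The function name is hypothetical.

```python
import numpy as np

def stable_multipole_moments(charges, positions, center, max_order=4, epsilon=1e-8):
    """2-D moments of complex offsets z = (x - x_c) + i (y - y_c), computed on a rescaled radius."""
    z = (positions[:, 0] + 1j * positions[:, 1]) - (center[0] + 1j * center[1])
    radius = np.max(np.abs(z)) + epsilon               # cluster radius R
    z_scaled = z / radius                              # |z_scaled| <= 1, so powers stay bounded
    orders = np.arange(max_order + 1)
    scaled_moments = (charges[:, None] * z_scaled[:, None] ** orders).sum(axis=0)
    return scaled_moments, radius                      # true M_k = scaled_moments[k] * radius**k

charges = np.array([1.0, 1.0])
positions = np.array([[1.0, 0.0], [0.0, 2.0]])
scaled, radius = stable_multipole_moments(charges, positions, np.zeros(2))
```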

Parameters:

Name Type Description Default
channels int

Number of feature channels

required
max_order int

Maximum multipole order

4
epsilon float

Small constant for numerical stability

1e-08
stabilization_factor float

Factor for moment normalization

0.1
rngs Rngs

Random number generator state

required

MultipoleGraphNeuralOperator

MultipoleGraphNeuralOperator(in_features: int, out_features: int, hidden_features: int = 64, num_layers: int = 3, max_degree: int = 4, use_local_messages: bool = True, dropout_rate: float = 0.1, *, rngs: Rngs)

Bases: Module

Complete Multipole Graph Neural Operator with numerical stability.

Neural operator for systems with long-range interactions such as molecular dynamics, N-body simulations, and plasma physics.

Parameters:

Name Type Description Default
in_features int

Number of input feature channels

required
out_features int

Number of output feature channels

required
hidden_features int

Hidden layer width

64
num_layers int

Number of MGNO layers

3
max_degree int

Maximum multipole expansion order

4
use_local_messages bool

Whether to use local message passing

True
dropout_rate float

Dropout rate for regularization

0.1
rngs Rngs

Random number generator state

required

OperatorNetwork

OperatorNetwork(operator_type: str, config: dict[str, Any], *, rngs: Rngs)

Bases: Module

Unified interface for different operator network types.

This class provides a common interface for different neural operator architectures (FNO, DeepONet, etc.) to enable easy experimentation and comparison.

Parameters:

Name Type Description Default
operator_type str

Type of operator ('fno', 'deeponet', 'fourier_deeponet', 'adaptive_deeponet', etc.)

required
config dict[str, Any]

Configuration dictionary for the operator

required
rngs Rngs

Random number generators

required

BayesianLinear

BayesianLinear(in_features: int, out_features: int, prior_std: float = 1.0, *, rngs: Rngs)

Bases: Module

Bayesian linear layer with weight uncertainty.

Implements a variational Bayesian linear layer in which the weights are distributions rather than point estimates.

Parameters:

Name Type Description Default
in_features int

Number of input features

required
out_features int

Number of output features

required
prior_std float

Standard deviation of weight prior

1.0
rngs Rngs

Random number generator state

required

kl_divergence

kl_divergence() -> Array

Compute KL divergence between posterior and prior.

Returns:

Type Description
Array

KL divergence scalar
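Consider a mean-field Gaussian posterior N(μ, σ²) for each weight and a zero-mean Gaussian prior with standard deviation prior_std = s. The KL divergence then has a closed form summed over weights: Σ [log(s/σ) + (σ² + μ²)/(2s²) - 1/2]. The softplus parameterization σ = softplus(ρ) and the reparameterized sampling below follow the common Bayes-by-Backprop recipe. Whether `BayesianLinear` uses exactly this parameterization is an assumption.

```python
import numpy as np

def softplus(x):
    return np.logaddexp(0.0, x)          # log(1 + exp(x)) without overflow

def gaussian_kl(mu, sigma, prior_std=1.0):
    """Sum over weights of KL(N(mu, sigma^2) || N(0, prior_std^2))."""
    return np.sum(np.log(prior_std / sigma) + (sigma**2 + mu**2) / (2.0 * prior_std**2) - 0.5)

def bayesian_linear(x, w_mu, w_rho, b, rng):
    """One stochastic forward pass with a freshly sampled weight matrix."""
    sigma = softplus(w_rho)                                   # keeps sigma > 0
    w = w_mu + sigma * rng.standard_normal(w_mu.shape)        # reparameterization trick
    return x @ w + b, gaussian_kl(w_mu, sigma)

rng = np.random.default_rng(0)
w_mu = 0.1 * rng.normal(size=(3, 2))
w_rho = np.full((3, 2), -3.0)                                 # sigma = softplus(-3), about 0.049
y, kl = bayesian_linear(rng.normal(size=(5, 3)), w_mu, w_rho, np.zeros(2), rng)
```

During training, this KL term is typically added to the data negative log-likelihood, often scaled by 1 / (dataset size), to form the variational objective.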

BayesianSpectralConvolution

BayesianSpectralConvolution(in_channels: int, out_channels: int, modes: Sequence[int], prior_std: float = 1.0, *, rngs: Rngs)

Bases: Module

Bayesian spectral convolution with proper shape handling.

Implements spectral convolution in the Fourier domain with Bayesian weights for uncertainty quantification.

Parameters:

Name Type Description Default
in_channels int

Number of input channels

required
out_channels int

Number of output channels

required
modes Sequence[int]

Fourier modes for each spatial dimension

required
prior_std float

Standard deviation of weight prior

1.0
rngs Rngs

Random number generator state

required

kl_divergence

kl_divergence() -> Array

Compute KL divergence for weight distributions.

UncertaintyQuantificationNeuralOperator

UncertaintyQuantificationNeuralOperator(in_channels: int, out_channels: int, hidden_channels: int = 64, modes: Sequence[int] = (16, 16), num_layers: int = 4, use_epistemic: bool = True, use_aleatoric: bool = True, ensemble_size: int = 10, *, rngs: Rngs)

Bases: Module

Complete Uncertainty Quantification Neural Operator.

Neural operator with built-in uncertainty quantification for safety-critical applications and robust predictions.

Parameters:

Name Type Description Default
in_channels int

Number of input channels

required
out_channels int

Number of output channels

required
hidden_channels int

Hidden layer width

64
modes Sequence[int]

Fourier modes for spectral convolution

(16, 16)
num_layers int

Number of UQNO layers

4
use_epistemic bool

Whether to use epistemic uncertainty

True
use_aleatoric bool

Whether to use aleatoric uncertainty

True
ensemble_size int

Ensemble size (number of Monte Carlo samples)

10
rngs Rngs

Random number generator state

required

predict_with_uncertainty

predict_with_uncertainty(x: Array, num_samples: int = 10, key: Array | None = None) -> dict[str, Array]

Predict with Monte Carlo uncertainty estimation.

Parameters:

Name Type Description Default
x Array

Input tensor

required
num_samples int

Number of Monte Carlo samples

10
key Array | None

Random key for sampling

None

Returns:

Type Description
dict[str, Array]

Dictionary with prediction statistics
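The sketch below illustrates Monte Carlo uncertainty estimation. It assumes that each stochastic forward pass returns a predictive mean and an aleatoric variance. By the law of total variance, the total predictive variance is the variance of the sampled means (epistemic) plus the mean of the predicted variances (aleatoric). The `sample_fn` callable and the returned keys are hypothetical and need not match the output of `predict_with_uncertainty`.

```python
import numpy as np

def mc_uncertainty(sample_fn, x, num_samples=10, seed=0):
    """sample_fn(x, rng) -> (mean, variance) for one stochastic forward pass."""
    rng = np.random.default_rng(seed)
    draws = [sample_fn(x, rng) for _ in range(num_samples)]
    means = np.stack([m for m, _ in draws])
    variances = np.stack([v for _, v in draws])
    epistemic = means.var(axis=0)            # disagreement between sampled models
    aleatoric = variances.mean(axis=0)       # average predicted data noise
    return {
        "mean": means.mean(axis=0),
        "epistemic_var": epistemic,
        "aleatoric_var": aleatoric,
        "total_std": np.sqrt(epistemic + aleatoric),
    }

def noisy_model(x, rng):
    # Hypothetical model: random slope (weight uncertainty) plus constant noise variance
    slope = 1.0 + 0.1 * rng.standard_normal()
    return slope * x, np.full_like(x, 0.01)

stats = mc_uncertainty(noisy_model, np.linspace(0.0, 1.0, 5), num_samples=200)
```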

kl_divergence

kl_divergence() -> Array

Compute total KL divergence for all Bayesian layers.

UQNOLayer

UQNOLayer(in_channels: int, out_channels: int, modes: Sequence[int], use_skip_connection: bool = True, *, rngs: Rngs)

Bases: Module

UQNO layer with proper shape handling for skip connections.

Combines Bayesian spectral convolution with local operations and proper channel dimension handling.

Parameters:

Name Type Description Default
in_channels int

Number of input channels

required
out_channels int

Number of output channels

required
modes Sequence[int]

Fourier modes for spectral convolution

required
use_skip_connection bool

Whether to use skip connections

True
rngs Rngs

Random number generator state

required

kl_divergence

kl_divergence() -> Array

Get KL divergence from spectral convolution.

WaveletNeuralOperator

WaveletNeuralOperator(in_channels: int, out_channels: int, hidden_channels: int, num_levels: int, *, wavelet_type: str = 'db4', mode: str = 'symmetric', activation: Callable[[Array], Array] = gelu, use_learnable_wavelets: bool = False, rngs: Rngs)

Bases: Module

Wavelet Neural Operator for multi-scale wavelet-based learning.

This operator uses wavelet transforms to capture multi-scale features in the input functions, enabling efficient learning of operators with multi-scale characteristics like turbulence and material heterogeneity.

Features:
- Discrete Wavelet Transform (DWT) for multi-scale decomposition
- Learnable wavelet coefficients processing
- Multi-resolution reconstruction
- Adaptive wavelet basis selection
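The multi-scale decomposition and reconstruction can be sketched with the orthonormal Haar wavelet. Haar is used here instead of the default 'db4' because its filters are short enough to write inline. The per-level scaling in `wavelet_layer` is a hypothetical stand-in for learned coefficient processing. It only shows where learnable operations would act.

```python
import numpy as np

def haar_dwt(x, levels):
    """Orthonormal multi-level Haar DWT; len(x) must be divisible by 2**levels."""
    approx, details = x, []
    for _ in range(levels):
        even, odd = approx[0::2], approx[1::2]
        details.append((even - odd) / np.sqrt(2.0))   # high-pass (detail) coefficients
        approx = (even + odd) / np.sqrt(2.0)          # low-pass (approximation) coefficients
    return approx, details                            # details ordered fine -> coarse

def haar_idwt(approx, details):
    """Inverse of haar_dwt: rebuild from the coarsest level to the finest."""
    for detail in reversed(details):
        signal = np.empty(2 * approx.size)
        signal[0::2] = (approx + detail) / np.sqrt(2.0)
        signal[1::2] = (approx - detail) / np.sqrt(2.0)
        approx = signal
    return approx

def wavelet_layer(x, level_weights, approx_weight):
    """Scale each level's coefficients (stand-in for learned processing), then reconstruct."""
    approx, details = haar_dwt(x, len(level_weights))
    details = [w * d for w, d in zip(level_weights, details)]
    return haar_idwt(approx_weight * approx, details)

x = np.random.default_rng(0).normal(size=16)
smooth = wavelet_layer(x, level_weights=[0.0, 0.0, 1.0], approx_weight=1.0)  # drop two finest levels
```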

Parameters:

Name Type Description Default
in_channels int

Number of input channels

required
out_channels int

Number of output channels

required
hidden_channels int

Hidden channel dimension

required
num_levels int

Number of wavelet decomposition levels

required
wavelet_type str

Type of wavelet (e.g., 'db4', 'haar')

'db4'
mode str

Boundary (signal extension) mode used by the wavelet transform

'symmetric'
activation Callable[[Array], Array]

Activation function

gelu
use_learnable_wavelets bool

Whether to use learnable wavelet bases

False
rngs Rngs

Random number generators

required

create_high_frequency_amfno

create_high_frequency_amfno(in_channels: int, out_channels: int, modes: Sequence[int] = (128, 128), **kwargs) -> AmortizedFourierNeuralOperator

Create AM-FNO optimized for high-frequency problems.

create_shock_amfno

create_shock_amfno(in_channels: int = 3, out_channels: int = 3, modes: Sequence[int] = (96, 96), **kwargs) -> AmortizedFourierNeuralOperator

Create AM-FNO for problems with shocks/discontinuities.

create_wave_amfno

create_wave_amfno(in_channels: int = 2, out_channels: int = 2, modes: Sequence[int] = (64, 64), **kwargs) -> AmortizedFourierNeuralOperator

Create AM-FNO for wave propagation problems.

create_multiphysics_local_fno

create_multiphysics_local_fno(in_channels: int = 5, out_channels: int = 5, modes: Sequence[int] = (24, 24), **kwargs) -> LocalFourierNeuralOperator

Create Local FNO for multi-physics problems.

create_turbulence_local_fno

create_turbulence_local_fno(in_channels: int = 3, out_channels: int = 3, modes: Sequence[int] = (32, 32), **kwargs) -> LocalFourierNeuralOperator

Create Local FNO optimized for turbulent flow modeling.

create_wave_local_fno

create_wave_local_fno(in_channels: int = 2, out_channels: int = 2, modes: Sequence[int] = (64, 64), **kwargs) -> LocalFourierNeuralOperator

Create Local FNO for wave propagation with scattering.

create_climate_sfno

create_climate_sfno(in_channels: int = 5, out_channels: int = 5, lmax: int = 32, **kwargs) -> SphericalFourierNeuralOperator

Create SFNO optimized for global climate modeling.

create_ocean_sfno

create_ocean_sfno(in_channels: int = 4, out_channels: int = 4, lmax: int = 48, **kwargs) -> SphericalFourierNeuralOperator

Create SFNO for global ocean circulation modeling.

create_planetary_sfno

create_planetary_sfno(in_channels: int = 3, out_channels: int = 3, lmax: int = 16, **kwargs) -> SphericalFourierNeuralOperator

Create SFNO for planetary-scale phenomena.

create_weather_sfno

create_weather_sfno(in_channels: int = 7, out_channels: int = 7, lmax: int = 64, **kwargs) -> SphericalFourierNeuralOperator

Create SFNO for high-resolution weather prediction.

create_cp_fno

create_cp_fno(in_channels: int, out_channels: int, hidden_channels: int = 64, modes: Sequence[int] = (16, 16), rank: float = 0.1, num_layers: int = 4, *, rngs: Rngs) -> TensorizedFourierNeuralOperator

Create CP factorized FNO.

create_tt_fno

create_tt_fno(in_channels: int, out_channels: int, hidden_channels: int = 64, modes: Sequence[int] = (16, 16), rank: float = 0.1, num_layers: int = 4, *, rngs: Rngs) -> TensorizedFourierNeuralOperator

Create Tensor Train factorized FNO.

create_tucker_fno

create_tucker_fno(in_channels: int, out_channels: int, hidden_channels: int = 64, modes: Sequence[int] = (16, 16), rank: float = 0.1, num_layers: int = 4, *, rngs: Rngs) -> TensorizedFourierNeuralOperator

Create Tucker factorized FNO.
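The rank argument of these three factories is a float rather than an integer. In tensorized-FNO libraries a float rank is commonly read as a fraction of the dense weight's parameter count. Assuming opifex follows that convention (this reference does not confirm it), the sketch below shows how rank=0.1 translates into an integer CP rank for one spectral weight of shape (hidden_channels, hidden_channels, modes_1, modes_2). The helper cp_rank_from_fraction is hypothetical; Tucker and Tensor Train ranks are derived differently because their parameter counts have a different form.

```python
import math

def cp_rank_from_fraction(shape, rank_fraction):
    """Integer CP rank whose factor matrices hold roughly `rank_fraction`
    of the dense tensor's parameters.

    A rank-R CP tensor stores one (dim x R) factor matrix per mode,
    i.e. R * sum(shape) numbers (ignoring any per-component weights).
    """
    dense_params = math.prod(shape)
    rank = int(rank_fraction * dense_params / sum(shape))
    return max(rank, 1)

shape = (64, 64, 16, 16)  # hidden x hidden x modes, using create_cp_fno's defaults
rank = cp_rank_from_fraction(shape, 0.1)
dense = math.prod(shape)
factorized = rank * sum(shape)
print(rank, dense, factorized)  # 655 1048576 104800
```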

create_deep_ufno

create_deep_ufno(in_channels: int, out_channels: int, hidden_channels: int = 32, modes: Sequence[int] = (32, 32), **kwargs) -> UFourierNeuralOperator

Create deep U-FNO (5 levels) for complex multi-scale problems.

create_shallow_ufno

create_shallow_ufno(in_channels: int, out_channels: int, hidden_channels: int = 64, modes: Sequence[int] = (16, 16), **kwargs) -> UFourierNeuralOperator

Create shallow U-FNO (2 levels) for simple multi-scale problems.

create_turbulence_ufno

create_turbulence_ufno(in_channels: int = 4, out_channels: int = 3, **kwargs) -> UFourierNeuralOperator

Create U-FNO optimized for turbulent flow modeling.

create_3d_gino

create_3d_gino(in_channels: int, out_channels: int, *, rngs: Rngs) -> GeometryInformedNeuralOperator

Create GINO optimized for 3D problems.

create_adaptive_mesh_gino

create_adaptive_mesh_gino(in_channels: int, out_channels: int, *, rngs: Rngs) -> GeometryInformedNeuralOperator

Create GINO for adaptive mesh refinement.

create_cad_gino

create_cad_gino(in_channels: int, out_channels: int, *, rngs: Rngs) -> GeometryInformedNeuralOperator

Create GINO optimized for CAD geometries.

create_multiscale_gino

create_multiscale_gino(in_channels: int, out_channels: int, *, rngs: Rngs) -> GeometryInformedNeuralOperator

Create GINO for multiscale problems.

create_molecular_mgno

create_molecular_mgno(in_features: int, out_features: int, *, rngs: Rngs) -> MultipoleGraphNeuralOperator

Create MGNO optimized for molecular dynamics simulations.

create_nbody_mgno

create_nbody_mgno(in_features: int, out_features: int, *, rngs: Rngs) -> MultipoleGraphNeuralOperator

Create MGNO for N-body gravitational simulations.

create_plasma_mgno

create_plasma_mgno(in_features: int, out_features: int, *, rngs: Rngs) -> MultipoleGraphNeuralOperator

Create MGNO for plasma physics simulations.

create_bayesian_inverse_uqno

create_bayesian_inverse_uqno(in_channels: int, out_channels: int, *, rngs: Rngs) -> UncertaintyQuantificationNeuralOperator

Create UQNO for Bayesian inverse problems.

create_robust_design_uqno

create_robust_design_uqno(in_channels: int, out_channels: int, *, rngs: Rngs) -> UncertaintyQuantificationNeuralOperator

Create UQNO for robust engineering design.

create_safety_critical_uqno

create_safety_critical_uqno(in_channels: int, out_channels: int, *, rngs: Rngs) -> UncertaintyQuantificationNeuralOperator

Create UQNO for safety-critical applications.

create_operator

create_operator(operator_type: str, **kwargs: Any) -> Any

Factory function to create any operator by name.

Parameters:

Name Type Description Default
operator_type str

Type of operator to create

required
**kwargs Any

Arguments for operator initialization

{}

Returns:

Type Description
Any

Initialized operator instance

Raises:

Type Description
ValueError

If operator_type is not recognized

Example

from flax import nnx

rngs = nnx.Rngs(0)

# Create a Tensorized FNO
tfno = create_operator(
    "TFNO",
    in_channels=3,
    out_channels=1,
    hidden_channels=64,
    modes=(16, 16),
    factorization="tucker",
    rank=0.1,
    rngs=rngs,
)

recommend_operator

recommend_operator(application: str) -> dict[str, Any]

Recommend the best operator for a specific application.

Parameters:

Name Type Description Default
application str

Application domain

required

Returns:

Type Description
dict[str, Any]

Dictionary with recommendations

Example

rec = recommend_operator("turbulent_flow")
print(f"Recommended: {rec['primary']}")
print(f"Reason: {rec['reason']}")

list_operators

list_operators(category: str | None = None) -> dict[str, Sequence[str]]

List available operators by category.

Parameters:

Name Type Description Default
category str | None

Optional category filter

None

Returns:

Type Description
dict[str, Sequence[str]]

Dictionary of operators by category

get_operator_info

get_operator_info(operator_type: str) -> dict[str, Any]

Get detailed information about a specific operator.

Parameters:

Name Type Description Default
operator_type str

Type of operator

required

Returns:

Type Description
dict[str, Any]

Dictionary with operator information

Bayesian Networks

opifex.neural.bayesian

Bayesian neural network components with uncertainty quantification.

BlackJAXIntegration

BlackJAXIntegration(base_model: Module, sampler_type: str = 'nuts', num_warmup: int = 1000, num_samples: int = 1000, step_size: float = 0.001, *, rngs: Rngs)

Bases: Module

BlackJAX MCMC sampling integration for Bayesian neural networks.

Provides MCMC sampling capabilities for full Bayesian inference on neural network parameters, supporting multiple sampling algorithms (NUTS, HMC, MALA).

Parameters:

Name Type Description Default
base_model Module

Neural network model for Bayesian inference

required
sampler_type str

MCMC sampler type ('nuts', 'hmc', 'mala')

'nuts'
num_warmup int

Number of warmup steps for sampler adaptation

1000
num_samples int

Number of posterior samples to generate

1000
step_size float

Initial step size for MCMC sampling

0.001
rngs Rngs

Random number generators

required
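Of the three supported samplers, MALA is the simplest to write out by hand. The sketch below is a plain NumPy Metropolis-adjusted Langevin sampler for a generic log-density. It is not the BlackJAX implementation (which runs in JAX and can adapt the step size during warmup), and the names mala, log_prob and grad_log_prob are hypothetical. It shows the two ingredients that NUTS, HMC and MALA share: a gradient-informed proposal and a Metropolis-Hastings correction.

```python
import numpy as np

def mala(log_prob, grad_log_prob, x0, step_size, num_warmup, num_samples, seed=0):
    """Metropolis-adjusted Langevin algorithm with a fixed step size."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x0, dtype=float)

    def log_q(x_to, x_from):
        # Log density (up to a constant) of proposing x_to from x_from,
        # where the proposal is N(x_from + step_size * grad, 2 * step_size * I).
        mean = x_from + step_size * grad_log_prob(x_from)
        return -np.sum((x_to - mean) ** 2) / (4.0 * step_size)

    samples = []
    for i in range(num_warmup + num_samples):
        noise = rng.standard_normal(x.shape)
        proposal = x + step_size * grad_log_prob(x) + np.sqrt(2.0 * step_size) * noise
        # Metropolis-Hastings correction for the asymmetric Langevin proposal.
        log_accept = (log_prob(proposal) + log_q(x, proposal)
                      - log_prob(x) - log_q(proposal, x))
        if np.log(rng.uniform()) < log_accept:
            x = proposal
        if i >= num_warmup:  # keep only post-warmup draws
            samples.append(x.copy())
    return np.stack(samples)  # shape (num_samples, dim)

# Target: standard 2-D Gaussian, log p(x) = -||x||^2 / 2 + const.
draws = mala(lambda x: -0.5 * np.sum(x**2), lambda x: -x,
             x0=np.zeros(2), step_size=0.5, num_warmup=1000, num_samples=20000)
print(draws.shape)  # (20000, 2)
print(draws.mean(axis=0), draws.var(axis=0))  # close to [0, 0] and [1, 1]
```

Here warmup only discards early draws; in BlackJAX-style workflows the warmup phase typically also tunes the step size, which is what the step_size argument above would initialize.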

sample_posterior

sample_posterior(x_data: Array, y_data: Array, *, rngs: Rngs | None = None) -> Array

Sample from posterior distribution using MCMC.

Parameters:

Name Type Description Default
x_data Array

Input training data

required
y_data Array

Target training data

required
rngs Rngs | None

Random number generators

None

Returns:

Type Description
Array

Posterior samples as an array of shape (num_samples, num_params)

posterior_predictive

posterior_predictive(x_test: Array, posterior_samples: Array) -> Array

Generate posterior predictive samples.

Parameters:

Name Type Description Default
x_test Array

Test input data

required
posterior_samples Array

Posterior parameter samples

required

Returns:

Type Description
Array

Predictive samples of shape (num_samples, num_test, output_dim)
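The output shape follows from running the model once per posterior sample. Assuming each row of posterior_samples is a flattened parameter vector, a minimal NumPy sketch for a hypothetical linear model y = x @ W + b looks like this (the unflatten layout and the model are illustrative, not opifex's):

```python
import numpy as np

def unflatten(theta, input_dim, output_dim):
    """Split a flat parameter vector into a weight matrix and a bias (hypothetical layout)."""
    w = theta[: input_dim * output_dim].reshape(input_dim, output_dim)
    b = theta[input_dim * output_dim :]
    return w, b

def posterior_predictive(x_test, posterior_samples, output_dim):
    input_dim = x_test.shape[1]
    preds = []
    for theta in posterior_samples:       # one forward pass per posterior draw
        w, b = unflatten(theta, input_dim, output_dim)
        preds.append(x_test @ w + b)      # (num_test, output_dim)
    return np.stack(preds)                # (num_samples, num_test, output_dim)

rng = np.random.default_rng(0)
x_test = rng.standard_normal((5, 3))
samples = rng.standard_normal((100, 3 * 2 + 2))  # (num_samples, num_params)
pred = posterior_predictive(x_test, samples, output_dim=2)
print(pred.shape)  # (100, 5, 2)
```

In JAX the Python loop would normally be replaced by jax.vmap over the sample axis, which yields the same stacked shape.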

compute_posterior_statistics

compute_posterior_statistics(posterior_samples: Array) -> dict[str, Any]

Compute posterior statistics from samples.

Parameters:

Name Type Description Default
posterior_samples Array

Posterior parameter samples

required

Returns:

Type Description
dict[str, Any]

Dictionary with mean, std, and credible intervals
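A minimal sketch of such statistics, assuming an equal-tailed 95% credible interval taken from sample percentiles. The interval level and the dictionary keys are illustrative choices, not necessarily the ones compute_posterior_statistics returns.

```python
import numpy as np

def posterior_statistics(posterior_samples, level=0.95):
    """Per-parameter mean, standard deviation and equal-tailed credible interval."""
    tail = 100.0 * (1.0 - level) / 2.0
    lower, upper = np.percentile(posterior_samples, [tail, 100.0 - tail], axis=0)
    return {
        "mean": posterior_samples.mean(axis=0),
        "std": posterior_samples.std(axis=0, ddof=1),
        "credible_interval": (lower, upper),
    }

rng = np.random.default_rng(0)
samples = rng.normal(loc=[0.0, 5.0], scale=[1.0, 0.1], size=(4000, 2))
stats = posterior_statistics(samples)
print(stats["mean"])                  # close to [0.0, 5.0]
print(stats["credible_interval"][0])  # close to [-1.96, 4.80]
```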

integrate_with_variational_framework

integrate_with_variational_framework(variational_framework: AmortizedVariationalFramework, x_data: Array, y_data: Array, *, rngs: Rngs) -> dict[str, Any]

Integrate BlackJAX sampling with a variational framework.

Parameters:

Name Type Description Default
variational_framework AmortizedVariationalFramework

Variational framework instance

required
x_data Array

Training input data

required
y_data Array

Training target data

required
rngs Rngs

Random number generators

required

Returns:

Type Description
dict[str, Any]

Dictionary with MCMC samples and variational comparison

CalibrationTools

CalibrationTools(*, rngs: Rngs)

Bases: Module

Tools for assessing and improving the calibration of uncertainty estimates.

Parameters:

Name Type Description Default
rngs Rngs

Random number generators

required

assess_calibration

assess_calibration(predictions: Array, uncertainties: Array, true_values: Array, num_bins: int = 10) -> dict[str, float | dict[str, Array]]

Assess calibration quality of uncertainty estimates.

Parameters:

Name Type Description Default
predictions Array

Model predictions

required
uncertainties Array

Predicted uncertainties

required
true_values Array

Ground truth values

required
num_bins int

Number of bins for reliability diagram

10

Returns:

Type Description
dict[str, float | dict[str, Array]]

Dictionary with calibration metrics

compute_reliability_diagram

compute_reliability_diagram(confidences: Array, accuracies: Array, num_bins: int = 10) -> dict[str, Array]

Compute reliability diagram data.

Parameters:

Name Type Description Default
confidences Array

Predicted confidence values

required
accuracies Array

Binary accuracy indicators

required
num_bins int

Number of bins for the diagram

10

Returns:

Type Description
dict[str, Array]

Dictionary with binned confidence and accuracy data
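A reliability diagram bins predictions by confidence and compares the average confidence in each bin with the observed accuracy; a well-calibrated model lies on the diagonal. A minimal NumPy sketch with equal-width bins on [0, 1] (the binning scheme and the output keys are assumptions, not necessarily what compute_reliability_diagram returns):

```python
import numpy as np

def reliability_diagram(confidences, accuracies, num_bins=10):
    confidences = np.asarray(confidences, dtype=float)
    accuracies = np.asarray(accuracies, dtype=float)
    edges = np.linspace(0.0, 1.0, num_bins + 1)
    # Equal-width bins; a confidence of exactly 1.0 goes into the last bin.
    bin_ids = np.minimum((confidences * num_bins).astype(int), num_bins - 1)
    counts = np.bincount(bin_ids, minlength=num_bins)
    conf_sums = np.bincount(bin_ids, weights=confidences, minlength=num_bins)
    acc_sums = np.bincount(bin_ids, weights=accuracies, minlength=num_bins)
    nonempty = counts > 0
    # Empty bins are reported as NaN rather than 0 so they are not plotted.
    mean_conf = np.where(nonempty, conf_sums / np.maximum(counts, 1), np.nan)
    mean_acc = np.where(nonempty, acc_sums / np.maximum(counts, 1), np.nan)
    return {"bin_edges": edges, "counts": counts,
            "mean_confidence": mean_conf, "mean_accuracy": mean_acc}

diagram = reliability_diagram([0.15, 0.15, 0.85, 0.95], [0, 1, 1, 1], num_bins=2)
print(diagram["counts"])           # [2 2]
print(diagram["mean_confidence"])  # approximately [0.15, 0.9]
print(diagram["mean_accuracy"])    # [0.5, 1.0]
```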

platt_scaling

platt_scaling(logits: Array, labels: Array, validation_logits: Array) -> tuple[float, float]

Apply Platt scaling for probability calibration.

Parameters:

Name Type Description Default
logits Array

Training logits for fitting scaling parameters

required
labels Array

Training labels

required
validation_logits Array

Validation logits to calibrate

required

Returns:

Type Description
tuple[float, float]

Tuple of (slope, intercept) scaling parameters

isotonic_regression_calibration

isotonic_regression_calibration(confidences: Array, accuracies: Array) -> Array

Apply isotonic regression for calibration.

Parameters:

Name Type Description Default
confidences Array

Predicted confidence values

required
accuracies Array

Binary accuracy indicators

required

Returns:

Type Description
Array

Calibrated confidence values

ConformalPrediction

ConformalPrediction(alpha: float = 0.1, *, rngs: Rngs)

Bases: Module

Conformal prediction for calibrated uncertainty intervals.

Provides prediction intervals with finite-sample coverage guarantees based on conformal prediction theory.

Parameters:

Name Type Description Default
alpha float

Miscoverage level (1-alpha is the target coverage)

0.1
rngs Rngs

Random number generators

required

calibrate

calibrate(predictions: Array, true_values: Array) -> None

Calibrate the conformal predictor on a calibration set.

Parameters:

Name Type Description Default
predictions Array

Model predictions on calibration set

required
true_values Array

True values for calibration set

required

predict_intervals

predict_intervals(predictions: Array) -> tuple[Array, Array]

Compute conformal prediction intervals.

Parameters:

Name Type Description Default
predictions Array

Model predictions for test set

required

Returns:

Type Description
tuple[Array, Array]

Tuple of (lower_bounds, upper_bounds) for prediction intervals

compute_coverage

compute_coverage(lower_bounds: Array, upper_bounds: Array, true_values: Array) -> float

Compute empirical coverage of prediction intervals.

Parameters:

Name Type Description Default
lower_bounds Array

Lower bounds of prediction intervals

required
upper_bounds Array

Upper bounds of prediction intervals

required
true_values Array

True values

required

Returns:

Type Description
float

Empirical coverage rate

IsotonicRegression

IsotonicRegression(n_bins: int = 100, *, rngs: Rngs)

Bases: Module

Isotonic regression for calibration.

Non-parametric calibration method that learns a monotonic mapping from confidence scores to calibrated probabilities.

Parameters:

Name Type Description Default
n_bins int

Number of bins for isotonic regression

100
rngs Rngs

Random number generators

required

fit

fit(confidences: Array, labels: Array) -> None

Fit the isotonic mapping using the pool-adjacent-violators (PAV) algorithm.

Parameters:

Name Type Description Default
confidences Array

Training confidence scores

required
labels Array

Binary labels (0 or 1)

required
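The pool-adjacent-violators algorithm sorts the training points by confidence, then repeatedly merges neighbouring blocks whose averaged labels decrease until the fitted sequence is non-decreasing. A minimal NumPy sketch with hypothetical function names; it ignores the n_bins parameter, whose role in the opifex implementation is not described here.

```python
import numpy as np

def pool_adjacent_violators(values):
    """Least-squares non-decreasing fit to `values` (equal weights)."""
    blocks = []  # each block: [mean, size]
    for v in values:
        blocks.append([float(v), 1])
        # Merge backwards while the last two blocks violate monotonicity.
        while len(blocks) > 1 and blocks[-2][0] > blocks[-1][0]:
            mean2, size2 = blocks.pop()
            mean1, size1 = blocks.pop()
            size = size1 + size2
            blocks.append([(mean1 * size1 + mean2 * size2) / size, size])
    return np.concatenate([np.full(size, mean) for mean, size in blocks])

def fit_isotonic(confidences, labels):
    order = np.argsort(confidences)
    x = np.asarray(confidences, dtype=float)[order]
    y = pool_adjacent_violators(np.asarray(labels, dtype=float)[order])
    # Calibrated mapping: piecewise-linear interpolation between the fitted points.
    return lambda c: np.interp(c, x, y)

print(pool_adjacent_violators([1, 3, 2, 4]))  # values 1.0, 2.5, 2.5, 4.0
calibrate = fit_isotonic([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])
print(calibrate([0.38, 0.9]))  # values 0.5, 1.0
```

Because the fitted values are block averages of 0/1 labels, they can be read directly as empirical probabilities.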

PlattScaling

PlattScaling(*, rngs: Rngs)

Bases: Module

Platt scaling for probabilistic calibration.

Applies an affine rescaling to the logits followed by a sigmoid, improving the calibration of binary classifiers.

Parameters:

Name Type Description Default
rngs Rngs

Random number generators

required

fit

fit(logits: Array, labels: Array, max_iterations: int = 100) -> None

Fit Platt scaling parameters using maximum likelihood.

Parameters:

Name Type Description Default
logits Array

Training logits

required
labels Array

Binary labels (0 or 1)

required
max_iterations int

Maximum number of optimization iterations

100
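Platt scaling models the calibrated probability as sigmoid(a * logit + b) and fits the slope a and intercept b by maximizing the Bernoulli log-likelihood. The sketch below uses Newton's method on the two parameters; whether PlattScaling.fit uses Newton steps or a first-order optimizer is not stated here, and fit_platt is a hypothetical name. It also omits the smoothed target labels used in Platt's original formulation.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fit_platt(logits, labels, max_iterations=100, tol=1e-10):
    """Maximum-likelihood fit of p = sigmoid(a * logit + b) via Newton's method."""
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels, dtype=float)
    design = np.column_stack([logits, np.ones_like(logits)])  # columns: logit, bias
    params = np.zeros(2)                                       # [a, b]
    for _ in range(max_iterations):
        p = sigmoid(design @ params)
        grad = design.T @ (p - labels)  # gradient of the negative log-likelihood
        hess = design.T @ (design * (p * (1.0 - p))[:, None])
        step = np.linalg.solve(hess, grad)
        params -= step
        if np.max(np.abs(step)) < tol:
            break
    return float(params[0]), float(params[1])

# Synthetic, deliberately over-confident logits: true probabilities follow sigmoid(0.5 * z - 0.3).
rng = np.random.default_rng(0)
z = rng.normal(scale=3.0, size=5000)
y = (rng.uniform(size=z.size) < sigmoid(0.5 * z - 0.3)).astype(float)
a, b = fit_platt(z, y)
print(a, b)  # close to 0.5 and -0.3
```

A fitted slope below 1 indicates the raw logits were over-confident; applying sigmoid(a * logit + b) to new logits yields the calibrated probabilities.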

TemperatureScaling

TemperatureScaling(physics_constraints: Sequence[str] = (), adaptive: bool = False, learning_rate: float = 0.01, constraint_strength: float = 1.0, *, rngs: Rngs)

Bases: Module

Temperature scaling for uncertainty calibration.

Applies learnable temperature scaling to improve calibration of probabilistic predictions while respecting physics constraints.

Parameters:

Name Type Description Default
physics_constraints Sequence[str]

List of physics constraints to enforce

()
adaptive bool

Whether to use adaptive temperature learning

False
learning_rate float

Learning rate for temperature optimization

0.01
constraint_strength float

Strength of physics constraint enforcement

1.0
rngs Rngs

Random number generators

required

apply_physics_aware_calibration

apply_physics_aware_calibration(predictions: Array, inputs: Array) -> tuple[Array, float]

Apply physics-aware temperature scaling with constraint enforcement.

Parameters:

Name Type Description Default
predictions Array

Model predictions to calibrate

required
inputs Array

Input data for constraint evaluation

required

Returns:

Type Description
tuple[Array, float]

Tuple of (calibrated_predictions, physics_constraint_penalty)

optimize_temperature

optimize_temperature(logits: Array, labels: Array) -> float

Optimize temperature parameter for calibration.

Parameters:

Name Type Description Default
logits Array

Model logits for validation data

required
labels Array

True labels for validation data

required

Returns:

Type Description
float

Optimized temperature value
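Temperature scaling divides the logits by a single scalar T > 0 and chooses T to minimize the negative log-likelihood of the validation labels; T > 1 softens over-confident predictions. A minimal NumPy sketch that searches over a log-spaced grid of temperatures. optimize_temperature may instead use a gradient-based optimizer (the learning_rate parameter suggests so), and the helper names here are hypothetical.

```python
import numpy as np

def log_softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def nll(logits, labels, temperature):
    log_probs = log_softmax(logits / temperature)
    return -log_probs[np.arange(labels.size), labels].mean()

def optimize_temperature(logits, labels, grid=np.exp(np.linspace(-3.0, 3.0, 601))):
    """Return the grid temperature with the lowest validation NLL."""
    losses = [nll(logits, labels, t) for t in grid]
    return float(grid[int(np.argmin(losses))])

# Synthetic check: labels drawn from softmax(logits / 2), so the best T should be near 2.
rng = np.random.default_rng(0)
logits = rng.normal(scale=4.0, size=(5000, 3))
probs = np.exp(log_softmax(logits / 2.0))
labels = np.array([rng.choice(3, p=p) for p in probs])
temperature = optimize_temperature(logits, labels)
print(temperature)  # close to 2.0
```

Searching over log T keeps the temperature positive and gives uniform relative resolution across the grid.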

optimize_temperature_with_physics_constraints

optimize_temperature_with_physics_constraints(predictions: Array, targets: Array, inputs: Array) -> float

Optimize temperature parameter with physics constraint awareness.

Parameters:

Name Type Description Default
predictions Array

Model predictions

required
targets Array

Target values

required
inputs Array

Input data for constraint evaluation

required

Returns:

Type Description
float

Optimized temperature value

adaptive_temperature_scaling

adaptive_temperature_scaling(predictions: Array, uncertainties: Array, true_values: Array) -> Array

Apply adaptive temperature scaling based on uncertainty quality.

Parameters:

Name Type Description Default
predictions Array

Model predictions

required
uncertainties Array

Predicted uncertainties

required
true_values Array

Ground truth values

required

Returns:

Type Description
Array

Adaptively calibrated temperatures

ConformalConfig dataclass

ConformalConfig(alpha: float = 0.1)

Configuration for conformal prediction.

Attributes:

Name Type Description
alpha float

Miscoverage level. The target coverage probability is 1 - alpha. Must be in the open interval (0, 1).

ConformalPredictor

ConformalPredictor(model: Module, config: ConformalConfig | None = None)

Split conformal prediction for calibrated prediction intervals.

Wraps any point predictor (PINN, neural operator, etc.) and provides calibrated prediction intervals without parametric distributional assumptions; the coverage guarantee requires only that calibration and test data be exchangeable.

The predictor must be calibrated on a held-out calibration set before prediction intervals can be computed.

Attributes:

Name Type Description
model

The wrapped NNX module used for point predictions.

config

Conformal prediction configuration.

Parameters:

Name Type Description Default
model Module

Any Flax NNX module that maps inputs to predictions. Must implement __call__(x) -> jax.Array.

required
config ConformalConfig | None

Conformal prediction configuration. If None, uses default ConformalConfig(alpha=0.1).

None

calibrate

calibrate(x_cal: Array, y_cal: Array) -> None

Compute nonconformity scores on a calibration set.

Runs the wrapped model on x_cal, computes absolute residuals against y_cal, and stores the conformal quantile.

Parameters:

Name Type Description Default
x_cal Array

Calibration inputs with shape (n, ...).

required
y_cal Array

Calibration targets with shape (n, ...).

required

predict_with_intervals

predict_with_intervals(x: Array) -> tuple[Array, Array, Array]

Return point predictions with calibrated prediction intervals.

Parameters:

Name Type Description Default
x Array

Input array with shape (n, ...).

required

Returns:

Type Description
tuple[Array, Array, Array]

A tuple of (predictions, lower_bounds, upper_bounds), where each array has the same shape as the model output.

Raises:

Type Description
RuntimeError

If calibrate() has not been called yet.
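The calibrate / predict_with_intervals contract described above can be sketched in a few lines of NumPy. This is a minimal split-conformal implementation assuming absolute-residual scores (as calibrate states) and the standard finite-sample quantile, the ceil((n + 1)(1 - alpha))-th smallest calibration score. The class name SplitConformal, the coverage helper, and the use of a plain Python callable in place of an NNX module are illustrative assumptions.

```python
import math
import numpy as np

class SplitConformal:
    def __init__(self, model, alpha=0.1):
        if not 0.0 < alpha < 1.0:
            raise ValueError("alpha must be in the open interval (0, 1)")
        self.model = model
        self.alpha = alpha
        self.quantile = None

    def calibrate(self, x_cal, y_cal):
        # Nonconformity scores: absolute residuals on the calibration set.
        scores = np.abs(np.asarray(self.model(x_cal)) - np.asarray(y_cal)).ravel()
        n = scores.size
        k = math.ceil((n + 1) * (1.0 - self.alpha))  # rank of the conformal quantile
        # With too few calibration points the interval is unbounded.
        self.quantile = np.inf if k > n else float(np.sort(scores)[k - 1])

    def predict_with_intervals(self, x):
        if self.quantile is None:
            raise RuntimeError("calibrate() must be called before predict_with_intervals()")
        preds = np.asarray(self.model(x))
        return preds, preds - self.quantile, preds + self.quantile

def coverage(lower, upper, y):
    """Empirical fraction of targets that fall inside their intervals."""
    return float(np.mean((y >= lower) & (y <= upper)))

conformal = SplitConformal(model=lambda x: 2.0 * x, alpha=0.25)
x_cal = np.zeros(9)
y_cal = np.arange(1.0, 10.0)  # absolute residuals 1, 2, ..., 9
conformal.calibrate(x_cal, y_cal)
print(conformal.quantile)     # 8.0, the 8th smallest score since ceil(10 * 0.75) = 8
preds, lo, hi = conformal.predict_with_intervals(np.array([0.0, 1.0]))
print(lo, hi)                 # lower [-8, -6], upper [8, 10]
```

The (n + 1) correction is what turns the empirical quantile into a finite-sample guarantee: for exchangeable data, coverage is at least 1 - alpha.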

ConservationLawPriors

ConservationLawPriors(conservation_laws: Sequence[str] = ('energy', 'momentum', 'mass'), uncertainty_scale: float = 0.1, prior_strength: float = 1.0, adaptive_weighting: bool = True, *, rngs: Rngs)

Bases: Module

Conservation law priors for uncertainty estimation.

This class implements physics-aware priors that incorporate conservation laws directly into uncertainty quantification, enabling physically consistent uncertainty estimates.

Parameters:

Name Type Description Default
conservation_laws Sequence[str]

List of conservation laws to enforce

('energy', 'momentum', 'mass')
uncertainty_scale float

Scale factor for uncertainty estimates

0.1
prior_strength float

Strength of physics constraints in prior

1.0
adaptive_weighting bool

Whether to use adaptive constraint weighting

True
rngs Rngs

Random number generators

required

compute_physics_aware_uncertainty

compute_physics_aware_uncertainty(predictions: Array, model_uncertainty: Array, physics_state: Array) -> Array

Compute physics-aware uncertainty estimates.

Parameters:

Name Type Description Default
predictions Array

Model predictions

required
model_uncertainty Array

Basic model uncertainty

required
physics_state Array

Physical state variables for constraint evaluation

required

Returns:

Type Description
Array

Physics-aware uncertainty estimates

sample_physics_constrained_params

sample_physics_constrained_params(base_params: Array, constraint_strength: float = 1.0) -> Array

Sample parameters that satisfy physics constraints.

Parameters:

Name Type Description Default
base_params Array

Base parameter samples

required
constraint_strength float

Strength of constraint enforcement

1.0

Returns:

Type Description
Array

Physics-constrained parameter samples

DomainSpecificPriors

DomainSpecificPriors(domain: str = 'quantum_chemistry', parameter_ranges: dict[str, tuple[float, float]] | None = None, distribution_types: dict[str, str] | None = None, correlation_structure: str = 'independent', *, rngs: Rngs)

Bases: Module

Domain-specific prior distributions for scientific computing.

Provides specialized priors for different scientific domains including quantum mechanics, molecular dynamics, fluid dynamics, and materials science.

Parameters:

Name Type Description Default
domain str

Scientific domain (quantum_chemistry, molecular_dynamics, etc.)

'quantum_chemistry'
parameter_ranges dict[str, tuple[float, float]] | None

Custom parameter ranges for specific parameters

None
distribution_types dict[str, str] | None

Distribution types for each parameter

None
correlation_structure str

Correlation structure between parameters

'independent'
rngs Rngs

Random number generators

required

sample_domain_priors

sample_domain_priors(sample_shape: tuple[int, ...], parameter_type: str) -> Array

Sample from domain-specific priors.

Parameters:

Name Type Description Default
sample_shape tuple[int, ...]

Shape of samples to generate

required
parameter_type str

Type of parameter to sample

required

Returns:

Type Description
Array

Samples from domain-specific prior distribution

evaluate_prior_log_prob

evaluate_prior_log_prob(values: Array, parameter_type: str) -> Array

Evaluate log probability under domain-specific prior.

Parameters:

Name Type Description Default
values Array

Parameter values to evaluate

required
parameter_type str

Type of parameter

required

Returns:

Type Description
Array

Log probability under domain prior

HierarchicalBayesianFramework

HierarchicalBayesianFramework(hierarchy_levels: int = 3, level_dimensions: Sequence[int] = (64, 32, 16), uncertainty_propagation: str = 'multiplicative', correlation_structure: str = 'exchangeable', *, rngs: Rngs)

Bases: Module

Hierarchical Bayesian framework for multi-level uncertainty estimation.

Implements hierarchical models that can capture uncertainty at multiple scales and levels, suitable for complex scientific computing applications.

Parameters:

Name Type Description Default
hierarchy_levels int

Number of hierarchy levels

3
level_dimensions Sequence[int]

Dimensions for each hierarchy level

(64, 32, 16)
uncertainty_propagation str

How uncertainty propagates between levels

'multiplicative'
correlation_structure str

Correlation structure between levels

'exchangeable'
rngs Rngs

Random number generators

required

sample_hierarchical_parameters

sample_hierarchical_parameters(sample_shape: tuple[int, ...], level: int = 0) -> Array

Sample parameters from hierarchical model at specified level.

Parameters:

Name Type Description Default
sample_shape tuple[int, ...]

Shape of samples to generate

required
level int

Hierarchy level to sample from

0

Returns:

Type Description
Array

Hierarchical parameter samples

propagate_uncertainty_hierarchically

propagate_uncertainty_hierarchically(base_uncertainty: Array, target_level: int) -> Array

Propagate uncertainty through hierarchy levels.

Parameters:

Name Type Description Default
base_uncertainty Array

Base uncertainty estimates

required
target_level int

Target hierarchy level

required

Returns:

Type Description
Array

Hierarchically propagated uncertainty

compute_hierarchical_log_prob

compute_hierarchical_log_prob(values: Array, level: int) -> Array

Compute log probability under hierarchical model.

Parameters:

Name Type Description Default
values Array

Parameter values to evaluate

required
level int

Hierarchy level

required

Returns:

Type Description
Array

Log probability under hierarchical model

adaptive_hierarchy_weighting

adaptive_hierarchy_weighting(observed_data: Array, predictions: Array) -> Array

Adaptively weight hierarchy levels based on data fit.

Parameters:

Name Type Description Default
observed_data Array

Observed data for adaptation

required
predictions Array

Model predictions at different levels

required

Returns:

Type Description
Array

Adaptive weights for hierarchy levels

PhysicsAwareUncertaintyPropagation

PhysicsAwareUncertaintyPropagation(conservation_laws: Sequence[str] = ('energy', 'momentum'), constraint_tolerance: float = 1e-06, uncertainty_inflation: float = 1.1, correlation_aware: bool = True, *, rngs: Rngs)

Bases: Module

Physics-aware uncertainty propagation for scientific computing.

Propagates uncertainty through physics-informed models while respecting conservation laws and physical constraints.

Parameters:

Name Type Description Default
conservation_laws Sequence[str]

Conservation laws to respect during propagation

('energy', 'momentum')
constraint_tolerance float

Tolerance for constraint violations

1e-06
uncertainty_inflation float

Factor to inflate uncertainty for safety

1.1
correlation_aware bool

Whether to account for parameter correlations

True
rngs Rngs

Random number generators

required

propagate_with_physics_constraints

propagate_with_physics_constraints(input_uncertainty: Array, model_jacobian: Array, physics_state: Array) -> Array

Propagate uncertainty while respecting physics constraints.

Parameters:

Name Type Description Default
input_uncertainty Array

Input uncertainty estimates

required
model_jacobian Array

Jacobian of the model with respect to its inputs

required
physics_state Array

Current physics state for constraint evaluation

required

Returns:

Type Description
Array

Physics-constrained uncertainty propagation
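A common baseline for propagating uncertainty through a model is first-order (linear) propagation, cov_out ≈ J @ cov_in @ J.T, where J is the model Jacobian. Assuming input_uncertainty is a covariance matrix and that uncertainty_inflation simply scales the result (both are assumptions; the reference does not give the formula or say how physics_state enters), a sketch with a hypothetical propagate_linear helper:

```python
import numpy as np

def propagate_linear(input_cov, jacobian, inflation=1.1):
    """First-order propagation of a covariance through a model with the given Jacobian."""
    input_cov = np.asarray(input_cov, dtype=float)
    jacobian = np.asarray(jacobian, dtype=float)
    output_cov = jacobian @ input_cov @ jacobian.T
    return inflation * output_cov  # conservative scaling, cf. uncertainty_inflation

jac = np.array([[2.0, 0.0],
                [1.0, 1.0]])
cov = propagate_linear(np.eye(2), jac)
print(cov)  # 1.1 * [[4, 2], [2, 2]]
```

Even with independent inputs, the shared dependence of both outputs on the first input produces a non-zero off-diagonal term, which is the kind of correlation the correlation_aware flag presumably accounts for.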

compute_physics_informed_confidence

compute_physics_informed_confidence(predictions: Array, uncertainties: Array, physics_state: Array) -> Array

Compute physics-informed confidence intervals.

Parameters:

Name Type Description Default
predictions Array

Model predictions

required
uncertainties Array

Uncertainty estimates

required
physics_state Array

Physics state for constraint evaluation

required

Returns:

Type Description
Array

Physics-informed confidence measures

uncertainty_aware_constraint_projection

uncertainty_aware_constraint_projection(parameters: Array, uncertainties: Array) -> tuple[Array, Array]

Project parameters to satisfy constraints while accounting for uncertainty.

Parameters:

Name Type Description Default
parameters Array

Parameter values to project

required
uncertainties Array

Parameter uncertainties

required

Returns:

Type Description
tuple[Array, Array]

Tuple of (projected_parameters, adjusted_uncertainties)
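One standard way to project uncertain parameters onto a linear constraint A @ theta = b (for example a conserved total) is to condition a Gaussian on that constraint: the mean is corrected along cov @ A.T, so more uncertain parameters absorb more of the correction, and the covariance shrinks in the constrained direction. Whether uncertainty_aware_constraint_projection uses this update is not stated; the sketch below, with a hypothetical project_onto_constraint helper and the constraint passed explicitly, only illustrates the idea.

```python
import numpy as np

def project_onto_constraint(mean, cov, A, b):
    """Condition N(mean, cov) on the linear constraint A @ theta = b."""
    mean = np.asarray(mean, dtype=float)
    cov = np.asarray(cov, dtype=float)
    A = np.atleast_2d(np.asarray(A, dtype=float))
    b = np.atleast_1d(np.asarray(b, dtype=float))
    gain = cov @ A.T @ np.linalg.inv(A @ cov @ A.T)  # uncertainty-weighted correction
    projected_mean = mean - gain @ (A @ mean - b)
    projected_cov = cov - gain @ A @ cov
    # Return per-parameter standard deviations (clip tiny negatives from round-off).
    return projected_mean, np.sqrt(np.clip(np.diag(projected_cov), 0.0, None))

# Two quantities that must sum to 4; the second is more uncertain, so it moves more.
theta, sigma = project_onto_constraint([1.0, 2.0], np.diag([1.0, 3.0]), A=[[1.0, 1.0]], b=4.0)
print(theta, sigma)  # theta = [1.25, 2.75], sigma = [0.866..., 0.866...]
```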

PhysicsInformedPriors

PhysicsInformedPriors(conservation_laws: Sequence[str] = (), boundary_conditions: Sequence[str] = (), constraint_weights: Array | None = None, penalty_weight: float = 1.0, *, rngs: Rngs)

Bases: Module

Physics-informed prior constraints for Bayesian models.

Enforces conservation laws, boundary conditions, and other physical constraints through learnable constraint weights and penalty functions.

Parameters:

Name Type Description Default
conservation_laws Sequence[str]

List of conservation laws to enforce

()
boundary_conditions Sequence[str]

List of boundary conditions to enforce

()
constraint_weights Array | None

Optional custom weights for constraints

None
penalty_weight float

Weight for constraint violation penalties

1.0
rngs Rngs

Random number generators

required

apply_constraints

apply_constraints(params: Array) -> Array

Apply physics constraints to sampled parameters.

Parameters:

Name Type Description Default
params Array

Unconstrained parameter samples

required

Returns:

Type Description
Array

Constrained parameters that satisfy physics laws

compute_violation_penalty

compute_violation_penalty(params: Array) -> float

Compute penalty for physics constraint violations.

Parameters:

Name Type Description Default
params Array

Parameter values to evaluate

required

Returns:

Type Description
float

Violation penalty (higher = more violation)

check_physical_plausibility

check_physical_plausibility(params: Array) -> float

Check physical plausibility of parameters.

Parameters:

Name Type Description Default
params Array

Parameter values to check

required

Returns:

Type Description
float

Plausibility score between 0 (implausible) and 1 (plausible)

AdvancedAleatoricUncertainty

Advanced aleatoric uncertainty estimation methods.

distributional_uncertainty staticmethod

distributional_uncertainty(distribution_params: dict[str, Float[Array, 'batch ...']], distribution_type: str = 'gaussian') -> Float[Array, 'batch output']

Compute aleatoric uncertainty from distributional outputs.

AdvancedEpistemicUncertainty

Advanced epistemic uncertainty estimation methods.

compute_ensemble_disagreement staticmethod

compute_ensemble_disagreement(ensemble_predictions: Float[Array, 'models batch output'], aggregation_method: str = 'variance') -> Float[Array, 'batch output']

Compute epistemic uncertainty from ensemble disagreement.

compute_predictive_diversity staticmethod

compute_predictive_diversity(ensemble_predictions: Float[Array, 'models batch output'], diversity_metric: str = 'pairwise_distance') -> Float[Array, 'batch output']

Compute predictive diversity as a measure of epistemic uncertainty.
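Both quantities reduce an ensemble of shape (models, batch, output) over the model axis. Assuming 'variance' means the per-element variance across members and 'pairwise_distance' the mean absolute difference over all member pairs (the exact definitions in opifex are not given here), a sketch:

```python
import itertools
import numpy as np

def ensemble_variance(ensemble_predictions):
    return np.var(ensemble_predictions, axis=0)  # (batch, output)

def mean_pairwise_distance(ensemble_predictions):
    preds = np.asarray(ensemble_predictions, dtype=float)
    pairs = itertools.combinations(range(preds.shape[0]), 2)
    distances = [np.abs(preds[i] - preds[j]) for i, j in pairs]
    return np.mean(distances, axis=0)            # (batch, output)

ensemble = np.array([[[0.0], [1.0]],
                     [[1.0], [1.0]],
                     [[5.0], [1.0]]])  # 3 models, batch of 2, 1 output
print(ensemble_variance(ensemble))       # [[14/3], [0]]
print(mean_pairwise_distance(ensemble))  # [[10/3], [0]]
```

Where the members agree (the second batch element) both measures are zero; the pairwise distance is less sensitive than the variance to a single outlying member.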

AdvancedUncertaintyAggregator

Advanced uncertainty aggregation with multiple sources and weighting.

weighted_uncertainty_aggregation staticmethod

weighted_uncertainty_aggregation(uncertainty_sources: list[Float[Array, 'batch output']], weights: Float[Array, 'sources'] | None = None, aggregation_method: str = 'weighted_variance') -> Float[Array, 'batch output']

Aggregate uncertainties from multiple sources with optional weighting.

adaptive_weighting staticmethod

adaptive_weighting(uncertainty_sources: list[Float[Array, 'batch output']], reliability_scores: list[Float[Array, 'batch']] | None = None, adaptation_method: str = 'reliability_based') -> Float[Array, 'sources batch']

Compute adaptive weights for uncertainty sources based on reliability.

uncertainty_quality_assessment staticmethod

uncertainty_quality_assessment(predictions: Float[Array, 'batch output'], uncertainties: Float[Array, 'batch output'], true_values: Float[Array, 'batch output'] | None = None) -> dict[str, float]

Assess the quality of uncertainty estimates.

AleatoricUncertainty

Aleatoric (data) uncertainty estimation.

homoscedastic_uncertainty staticmethod

homoscedastic_uncertainty(_predictions: Float[Array, 'batch output'], log_variance: Float[Array, 'batch output']) -> Float[Array, 'batch output']

Compute homoscedastic (constant) aleatoric uncertainty.

heteroscedastic_uncertainty staticmethod

heteroscedastic_uncertainty(input_dependent_variance: Float[Array, 'batch output']) -> Float[Array, 'batch output']

Compute heteroscedastic (input-dependent) aleatoric uncertainty.

predictive_variance staticmethod

predictive_variance(predictions: Float[Array, 'samples batch output'], individual_variances: Float[Array, 'samples batch output']) -> Float[Array, 'batch output']

Compute the total predictive variance, including the aleatoric component.
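Given per-sample predictive means and variances, the law of total variance splits the total into the mean of the individual (aleatoric) variances plus the variance of the means (the epistemic spread). Assuming predictive_variance uses this decomposition, a sketch:

```python
import numpy as np

def predictive_variance(predictions, individual_variances):
    """Law of total variance over the leading `samples` axis."""
    aleatoric = np.mean(individual_variances, axis=0)  # E[Var(y | theta)]
    epistemic = np.var(predictions, axis=0)            # Var(E[y | theta])
    return aleatoric + epistemic

preds = np.array([[[1.0]], [[3.0]]])      # 2 samples, batch 1, output 1
variances = np.array([[[0.5]], [[1.5]]])
print(predictive_variance(preds, variances))  # [[2.]]  (1.0 aleatoric + 1.0 epistemic)
```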

noise_estimation staticmethod

noise_estimation(residuals: Float[Array, 'batch output'], predictions: Float[Array, 'batch output']) -> Float[Array, 'batch output']

Estimate aleatoric uncertainty from residuals.

CalibrationAssessment

Enhanced uncertainty calibration assessment tools.

expected_calibration_error staticmethod

expected_calibration_error(confidences: Float[Array, 'n_samples'], accuracies: Float[Array, 'n_samples'], n_bins: int = 10) -> float

Compute Expected Calibration Error (ECE).

maximum_calibration_error staticmethod

maximum_calibration_error(confidences: Float[Array, 'n_samples'], accuracies: Float[Array, 'n_samples'], n_bins: int = 10) -> float

Compute Maximum Calibration Error (MCE).
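With confidence bins B_1, ..., B_M over n samples, ECE = sum over m of (|B_m| / n) * |acc(B_m) - conf(B_m)|, and MCE = max over non-empty bins of |acc(B_m) - conf(B_m)|. A NumPy sketch of both, assuming equal-width bins (the binning details of CalibrationAssessment may differ):

```python
import numpy as np

def calibration_errors(confidences, accuracies, n_bins=10):
    confidences = np.asarray(confidences, dtype=float)
    accuracies = np.asarray(accuracies, dtype=float)
    bin_ids = np.minimum((confidences * n_bins).astype(int), n_bins - 1)
    ece, mce = 0.0, 0.0
    for b in range(n_bins):
        in_bin = bin_ids == b
        if not in_bin.any():
            continue  # empty bins contribute nothing
        gap = abs(accuracies[in_bin].mean() - confidences[in_bin].mean())
        ece += in_bin.mean() * gap  # weight = fraction of samples in the bin
        mce = max(mce, gap)
    return float(ece), float(mce)

ece, mce = calibration_errors([0.2, 0.3, 0.7, 0.9], [0, 1, 1, 0], n_bins=2)
print(ece, mce)  # approximately 0.275 and 0.3
```

ECE averages the miscalibration weighted by how many samples fall in each bin, while MCE reports the worst bin, which matters in safety-critical settings.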

reliability_diagram_data staticmethod

reliability_diagram_data(confidences: Float[Array, 'n_samples'], accuracies: Float[Array, 'n_samples'], n_bins: int = 10) -> dict[str, Array]

Compute reliability diagram data for visualization.

assess_calibration

assess_calibration(confidences: Float[Array, 'n_samples'], accuracies: Float[Array, 'n_samples'], n_bins: int = 10) -> CalibrationMetrics

Assess overall calibration with multiple metrics.

CalibrationMetrics dataclass

CalibrationMetrics(expected_calibration_error: float, maximum_calibration_error: float, reliability_diagram: dict[str, Array], confidence_histogram: Array, accuracy_histogram: Array)

Uncertainty calibration assessment metrics.

DistributionalAleatoricUncertainty

Distributional modeling of aleatoric uncertainty.

sample_gaussian

sample_gaussian(mean: Float[Array, 'batch output'], log_std: Float[Array, 'batch output'], num_samples: int) -> Float[Array, 'samples batch output']

Sample from Gaussian distributional output.

Parameters:

Name Type Description Default
mean Float[Array, 'batch output']

Mean predictions

required
log_std Float[Array, 'batch output']

Log standard deviation predictions

required
num_samples int

Number of samples to draw

required

Returns:

Type Description
Float[Array, 'samples batch output']

Samples from the distributional output

compute_gaussian_uncertainty

compute_gaussian_uncertainty(mean: Float[Array, 'batch output'], log_std: Float[Array, 'batch output']) -> Float[Array, 'batch output']

Compute uncertainty from Gaussian distributional parameters.

Parameters:

Name Type Description Default
mean Float[Array, 'batch output']

Mean predictions

required
log_std Float[Array, 'batch output']

Log standard deviation predictions

required

Returns:

Type Description
Float[Array, 'batch output']

Aleatoric uncertainty (variance)
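Both methods rest on the Gaussian reparameterization: samples are mean + exp(log_std) * eps with eps ~ N(0, 1), and the variance is exp(2 * log_std). The sketch below assumes those formulas. The documented sample_gaussian signature has no RNG argument, so passing an explicit jax.random key is an assumption of this sketch, and both helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def sample_gaussian_sketch(key, mean, log_std, num_samples):
    """Reparameterized draws: mean + exp(log_std) * eps, eps ~ N(0, 1).

    The explicit `key` argument is an assumption; not the opifex signature.
    """
    eps = jax.random.normal(key, (num_samples,) + mean.shape, dtype=mean.dtype)
    return mean + jnp.exp(log_std) * eps  # broadcasts to (samples, batch, output)


def gaussian_variance_sketch(log_std):
    """Aleatoric variance sigma^2 = exp(2 * log_std)."""
    return jnp.exp(2.0 * log_std)


mean = jnp.zeros((4, 2))
log_std = jnp.log(jnp.full((4, 2), 0.5))    # sigma = 0.5 everywhere
samples = sample_gaussian_sketch(jax.random.PRNGKey(0), mean, log_std, num_samples=2000)
print(samples.shape)                         # (2000, 4, 2)
print(gaussian_variance_sketch(log_std)[0])  # approximately [0.25 0.25]
```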

compute_mixture_uncertainty

compute_mixture_uncertainty(mixture_weights: Float[Array, 'batch components'], means: Float[Array, 'batch components output'], log_stds: Float[Array, 'batch components output']) -> Float[Array, 'batch output']

Compute uncertainty from mixture of Gaussians.

Parameters:

Name Type Description Default
mixture_weights Float[Array, 'batch components']

Mixture component weights

required
means Float[Array, 'batch components output']

Component means

required
log_stds Float[Array, 'batch components output']

Component log standard deviations

required

Returns:

Type Description
Float[Array, 'batch output']

Total uncertainty from mixture model
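For a Gaussian mixture, the total variance follows from the law of total variance: sum_k pi_k (sigma_k^2 + mu_k^2) minus the squared mixture mean (sum_k pi_k mu_k)^2. The sketch below assumes the mixture weights are already normalized to sum to 1 and that log_stds are natural-log standard deviations; `mixture_variance` is a hypothetical helper, not the opifex method.

```python
import jax.numpy as jnp


def mixture_variance(mixture_weights, means, log_stds):
    """Total variance of a Gaussian mixture (assumes weights sum to 1 per example).

    mixture_weights: (batch, components)
    means, log_stds: (batch, components, output)
    returns:         (batch, output)
    """
    w = mixture_weights[..., None]                     # (batch, components, 1)
    variances = jnp.exp(2.0 * log_stds)                # sigma_k^2
    mixture_mean = jnp.sum(w * means, axis=1)          # E[y]
    second_moment = jnp.sum(w * (variances + means**2), axis=1)  # E[y^2]
    return second_moment - mixture_mean**2             # Var[y] = E[y^2] - E[y]^2


weights = jnp.array([[0.5, 0.5]])
means = jnp.array([[[-1.0], [1.0]]])   # two unit-variance components at -1 and +1
log_stds = jnp.zeros((1, 2, 1))
print(mixture_variance(weights, means, log_stds))  # [[2.]]  (1 within + 1 between)
```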

EnhancedUncertaintyComponents dataclass

EnhancedUncertaintyComponents(epistemic_ensemble: Float[Array, 'batch output'], aleatoric_distributional: Float[Array, 'batch output'], total_uncertainty: Float[Array, 'batch output'], uncertainty_breakdown: dict[str, Float[Array, 'batch output']], epistemic_dropout: Float[Array, 'batch output'] | None = None)

Enhanced uncertainty components with multiple sources.

EnhancedUncertaintyQuantifier

EnhancedUncertaintyQuantifier(ensemble_size: int = 5, distributional_output: bool = True, multi_source_aggregation: bool = True, confidence_level: float = 0.95)

Enhanced uncertainty quantifier with multiple decomposition methods.

Parameters:

Name Type Description Default
ensemble_size int

Number of models in ensemble

5
distributional_output bool

Whether to use distributional outputs

True
multi_source_aggregation bool

Whether to aggregate multiple uncertainty sources

True
confidence_level float

Confidence level for intervals

0.95

enhanced_decompose_uncertainty

enhanced_decompose_uncertainty(ensemble_predictions: Float[Array, 'models batch output'], distributional_std: Float[Array, 'batch output'] | None = None, inputs: Float[Array, 'batch input_dim'] | None = None, dropout_predictions: Float[Array, 'samples batch output'] | None = None) -> EnhancedUncertaintyComponents

Enhanced uncertainty decomposition with multiple sources.

Parameters:

Name Type Description Default
ensemble_predictions Float[Array, 'models batch output']

Predictions from ensemble models

required
distributional_std Float[Array, 'batch output'] | None

Standard deviation from distributional output

None
inputs Float[Array, 'batch input_dim'] | None

Input data for context-dependent uncertainty

None
dropout_predictions Float[Array, 'samples batch output'] | None

Predictions with dropout for additional epistemic uncertainty

None

Returns:

Type Description
EnhancedUncertaintyComponents

Enhanced uncertainty components with detailed breakdown

EnsembleEpistemicUncertainty

EnsembleEpistemicUncertainty(num_models: int)

Ensemble-based epistemic uncertainty estimation.

Parameters:

Name Type Description Default
num_models int

Number of models in the ensemble

required

add_model

add_model(model: Any) -> None

Add a model to the ensemble.

Parameters:

Name Type Description Default
model Any

Neural network model to add to ensemble

required

aggregate_predictions

aggregate_predictions(ensemble_predictions: Float[Array, 'models batch output'], method: str = 'mean') -> Float[Array, 'batch output']

Aggregate predictions from ensemble models.

Parameters:

Name Type Description Default
ensemble_predictions Float[Array, 'models batch output']

Predictions from all ensemble models

required
method str

Aggregation method ("mean", "median", "weighted_mean")

'mean'

Returns:

Type Description
Float[Array, 'batch output']

Aggregated predictions

compute_epistemic_uncertainty

compute_epistemic_uncertainty(ensemble_predictions: Float[Array, 'models batch output']) -> Float[Array, 'batch output']

Compute epistemic uncertainty from ensemble predictions.

Parameters:

Name Type Description Default
ensemble_predictions Float[Array, 'models batch output']

Predictions from all ensemble models

required

Returns:

Type Description
Float[Array, 'batch output']

Epistemic uncertainty (variance across models)
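Ensemble epistemic uncertainty is typically the variance of the member predictions around the ensemble mean, taken over the leading models axis, and aggregation reduces over the same axis. The sketch below covers the "mean" and "median" methods only, since the weights used by "weighted_mean" are not documented here; it assumes the population variance (ddof=0), and the helper names are hypothetical.

```python
import jax.numpy as jnp


def aggregate_sketch(ensemble_predictions, method="mean"):
    """Reduce over the leading 'models' axis (hypothetical helper)."""
    if method == "mean":
        return jnp.mean(ensemble_predictions, axis=0)
    if method == "median":
        return jnp.median(ensemble_predictions, axis=0)
    raise ValueError(f"method not covered by this sketch: {method}")


def epistemic_variance_sketch(ensemble_predictions):
    """Spread of the members around the ensemble mean (ddof=0 assumed)."""
    return jnp.var(ensemble_predictions, axis=0)


preds = jnp.array([[[1.0]], [[2.0]], [[6.0]]])  # 3 models, batch=1, output=1
print(aggregate_sketch(preds, "mean"))     # [[3.]]
print(aggregate_sketch(preds, "median"))   # [[2.]]
print(epistemic_variance_sketch(preds))    # approximately [[4.667]]  (14 / 3)
```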

compute_prediction_disagreement

compute_prediction_disagreement(ensemble_predictions: Float[Array, 'models batch output']) -> Float[Array, 'batch output']

Compute prediction disagreement metric.

Parameters:

Name Type Description Default
ensemble_predictions Float[Array, 'models batch output']

Predictions from all ensemble models

required

Returns:

Type Description
Float[Array, 'batch output']

Disagreement metric (pairwise prediction variance)

EpistemicUncertainty

Epistemic (model) uncertainty estimation.

compute_variance staticmethod

compute_variance(predictions: Float[Array, 'samples batch output']) -> Float[Array, 'batch output']

Compute epistemic uncertainty as variance across model samples.

compute_entropy staticmethod

compute_entropy(predictions: Float[Array, 'samples batch classes']) -> Float[Array, 'batch classes']

Compute predictive entropy for classification tasks.

compute_mutual_information staticmethod

compute_mutual_information(predictions: Float[Array, 'samples batch classes']) -> Float[Array, 'batch classes']

Compute mutual information between predictions and model parameters.

compute_variance_of_expected staticmethod

compute_variance_of_expected(predictions: Float[Array, 'samples batch output']) -> Float[Array, 'batch output']

Compute variance of expected predictions (pure epistemic uncertainty).
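For classification, a widely used decomposition (often called BALD) takes the entropy of the sample-averaged class probabilities as the total predictive uncertainty and subtracts the average per-sample entropy; the remainder is the mutual information, i.e. the epistemic part. The sketch below assumes predictions holds per-sample class probabilities and sums over classes to return one value per example, whereas the documented signatures return a batch x classes array, so opifex may keep per-class terms instead. Function names are hypothetical.

```python
import jax.numpy as jnp


def predictive_entropy(probs, eps=1e-12):
    """H[E_s p_s], summed over classes -> (batch,). Sketch, not the opifex method."""
    mean_probs = jnp.mean(probs, axis=0)
    return -jnp.sum(mean_probs * jnp.log(mean_probs + eps), axis=-1)


def mutual_information(probs, eps=1e-12):
    """BALD-style MI: H[E_s p_s] - E_s H[p_s] -> (batch,)."""
    per_sample_entropy = -jnp.sum(probs * jnp.log(probs + eps), axis=-1)  # (samples, batch)
    return predictive_entropy(probs, eps) - jnp.mean(per_sample_entropy, axis=0)


# Two confident samples that disagree: all uncertainty is epistemic.
probs = jnp.array([[[1.0, 0.0]], [[0.0, 1.0]]])  # (samples=2, batch=1, classes=2)
print(predictive_entropy(probs))   # approximately [0.6931]  (log 2)
print(mutual_information(probs))   # approximately [0.6931]
```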

MultiSourceUncertaintyAggregator

Aggregation of uncertainty from multiple sources.

aggregate_uncertainties

aggregate_uncertainties(epistemic_sources: list[Float[Array, 'batch output']], aleatoric_sources: list[Float[Array, 'batch output']], method: str = 'variance_sum', epistemic_weights: Array | None = None, aleatoric_weights: Array | None = None) -> Float[Array, 'batch output']

Aggregate uncertainties from multiple sources.

Parameters:

Name Type Description Default
epistemic_sources list[Float[Array, 'batch output']]

List of epistemic uncertainty estimates

required
aleatoric_sources list[Float[Array, 'batch output']]

List of aleatoric uncertainty estimates

required
method str

Aggregation method ("variance_sum", "weighted_sum", "max")

'variance_sum'
epistemic_weights Array | None

Weights for epistemic sources

None
aleatoric_weights Array | None

Weights for aleatoric sources

None

Returns:

Type Description
Float[Array, 'batch output']

Total aggregated uncertainty
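The three documented method names suggest the following semantics, sketched here under stated assumptions: "variance_sum" adds all sources as independent variances, "weighted_sum" scales each source by its weight (defaulting to 1 when no weights are given), and "max" keeps the largest source elementwise. These interpretations and the name `aggregate_uncertainties_sketch` are assumptions, not the opifex implementation.

```python
import jax.numpy as jnp


def aggregate_uncertainties_sketch(epistemic_sources, aleatoric_sources, method="variance_sum",
                                   epistemic_weights=None, aleatoric_weights=None):
    """Hypothetical reading of the three aggregation methods."""
    epi = jnp.stack(epistemic_sources)   # (n_epistemic, batch, output)
    ale = jnp.stack(aleatoric_sources)   # (n_aleatoric, batch, output)
    if method == "variance_sum":
        # Independent sources: variances add.
        return epi.sum(axis=0) + ale.sum(axis=0)
    if method == "weighted_sum":
        we = jnp.ones(epi.shape[0]) if epistemic_weights is None else jnp.asarray(epistemic_weights)
        wa = jnp.ones(ale.shape[0]) if aleatoric_weights is None else jnp.asarray(aleatoric_weights)
        return jnp.tensordot(we, epi, axes=1) + jnp.tensordot(wa, ale, axes=1)
    if method == "max":
        # Largest single source, elementwise.
        return jnp.max(jnp.concatenate([epi, ale]), axis=0)
    raise ValueError(f"unknown method: {method}")


epistemic = [jnp.array([[0.1]]), jnp.array([[0.2]])]
aleatoric = [jnp.array([[0.3]])]
print(aggregate_uncertainties_sketch(epistemic, aleatoric))         # approximately [[0.6]]
print(aggregate_uncertainties_sketch(epistemic, aleatoric, "max"))  # [[0.3]]
```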

compute_uncertainty_breakdown

compute_uncertainty_breakdown(epistemic_sources: list[Float[Array, 'batch output']], aleatoric_sources: list[Float[Array, 'batch output']], source_names: list[str] | None = None) -> dict[str, Float[Array, 'batch output']]

Compute detailed uncertainty breakdown by source.

Parameters:

Name Type Description Default
epistemic_sources list[Float[Array, 'batch output']]

List of epistemic uncertainty estimates

required
aleatoric_sources list[Float[Array, 'batch output']]

List of aleatoric uncertainty estimates

required
source_names list[str] | None

Names for uncertainty sources

None

Returns:

Type Description
dict[str, Float[Array, 'batch output']]

Dictionary mapping source names to uncertainty values

UncertaintyComponents dataclass

UncertaintyComponents(epistemic: Float[Array, ...], aleatoric: Float[Array, ...], total: Float[Array, ...])

Decomposed uncertainty components.

UncertaintyIntegrationResults dataclass

UncertaintyIntegrationResults(predictions: Float[Array, 'batch output'], uncertainty_components: UncertaintyComponents, calibration_metrics: CalibrationMetrics, confidence_intervals: tuple[Float[Array, 'batch output'], Float[Array, 'batch output']], prediction_intervals: tuple[Float[Array, 'batch output'], Float[Array, 'batch output']])

Results from uncertainty propagation through model pipeline.

UncertaintyQuantifier

UncertaintyQuantifier(num_samples: int = 100, confidence_level: float = 0.95)

Uncertainty quantification interface that combines uncertainty decomposition, confidence and prediction intervals, and calibration assessment.

decompose_uncertainty

decompose_uncertainty(predictions: Float[Array, 'samples batch output'], aleatoric_variance: Float[Array, 'samples batch output'] | None = None) -> UncertaintyComponents

Decompose total uncertainty into epistemic and aleatoric components.

enhanced_uncertainty_decomposition

enhanced_uncertainty_decomposition(predictions: Float[Array, 'samples batch output'], true_values: Float[Array, 'batch output'] | None = None, inputs: Float[Array, 'batch input_dim'] | None = None) -> UncertaintyComponents

Decompose uncertainty, optionally using true values and inputs as additional context.

compute_confidence_intervals

compute_confidence_intervals(predictions: Float[Array, 'samples batch output'], confidence_level: float | None = None) -> tuple[Float[Array, 'batch output'], Float[Array, 'batch output']]

Compute confidence intervals from prediction samples.

compute_prediction_intervals

compute_prediction_intervals(mean_predictions: Float[Array, 'batch output'], total_variance: Float[Array, 'batch output'], confidence_level: float | None = None) -> tuple[Float[Array, 'batch output'], Float[Array, 'batch output']]

Compute prediction intervals using Gaussian assumption.
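Two common constructions match these descriptions: an empirical percentile interval taken over the samples axis, and a Gaussian interval mean ± z * sqrt(variance), with z the standard-normal quantile at (1 + confidence_level) / 2 (about 1.96 for 0.95). The sketch below assumes those constructions; whether compute_confidence_intervals uses percentiles is not stated here, and the helper names are hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def confidence_interval(predictions, confidence_level=0.95):
    """Empirical percentile interval over the leading 'samples' axis (assumed method)."""
    alpha = 1.0 - confidence_level
    lower = jnp.quantile(predictions, alpha / 2, axis=0)
    upper = jnp.quantile(predictions, 1.0 - alpha / 2, axis=0)
    return lower, upper


def prediction_interval(mean_predictions, total_variance, confidence_level=0.95):
    """Gaussian interval: mean +/- z * sqrt(variance)."""
    z = norm.ppf(0.5 + confidence_level / 2)  # about 1.96 for a 95 % level
    half_width = z * jnp.sqrt(total_variance)
    return mean_predictions - half_width, mean_predictions + half_width


lo, hi = prediction_interval(jnp.array([[0.0]]), jnp.array([[4.0]]))
print(lo, hi)  # approximately [[-3.92]] [[3.92]]
```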

propagate_uncertainty

propagate_uncertainty(predictions: Float[Array, 'samples batch output'], inputs: Float[Array, 'batch input_dim'], true_values: Float[Array, 'batch output'] | None = None) -> UncertaintyIntegrationResults

Propagate uncertainty through the entire prediction pipeline.

AmortizedVariationalFramework

AmortizedVariationalFramework(base_model: Module, prior_config: PriorConfig, variational_config: VariationalConfig, *, rngs: Rngs)

Bases: Module

Variational framework with amortized uncertainty estimation.

This framework combines a base neural network model with variational Bayesian inference capabilities, enabling uncertainty quantification through amortized variational inference.

Parameters:

Name Type Description Default
base_model Module

Base neural network model to augment with uncertainty.

required
prior_config PriorConfig

Configuration for physics-informed priors.

required
variational_config VariationalConfig

Configuration for variational inference.

required
rngs Rngs

Random number generator state.

required

predict_with_uncertainty

predict_with_uncertainty(x: Float[Array, 'batch input_dim'], num_samples: int | None = None, *, rngs: Rngs) -> tuple[Float[Array, 'batch output_dim'], Float[Array, 'batch output_dim']]

Forward pass with uncertainty quantification.

Parameters:

Name Type Description Default
x Float[Array, 'batch input_dim']

Input tensor of shape (batch_size, input_dim).

required
num_samples int | None

Number of Monte Carlo samples for uncertainty estimation.

None
rngs Rngs

Random number generator state.

required

Returns:

Type Description
tuple[Float[Array, 'batch output_dim'], Float[Array, 'batch output_dim']]

Tuple of (mean_prediction, uncertainty) both of shape (batch_size, output_dim).

compute_elbo

compute_elbo(x: Float[Array, 'batch input_dim'], y: Float[Array, 'batch output_dim'], num_samples: int | None = None, *, rngs: Rngs) -> Float[Array, '']

Compute the evidence lower bound (ELBO).

Parameters:

Name Type Description Default
x Float[Array, 'batch input_dim']

Input tensor of shape (batch_size, input_dim).

required
y Float[Array, 'batch output_dim']

Target tensor of shape (batch_size, output_dim).

required
num_samples int | None

Number of Monte Carlo samples for ELBO estimation.

None
rngs Rngs

Random number generator state.

required

Returns:

Type Description
Float[Array, '']

ELBO scalar value (higher is better).
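A Monte Carlo ELBO estimate averages the data log-likelihood over posterior samples and subtracts the KL term, which VariationalConfig.kl_weight scales. The sketch below assumes per-sample log-likelihoods are already computed and summed over the batch; whether opifex sums or averages over the batch, and which likelihood it uses, are not specified here. `elbo_estimate` is a hypothetical helper.

```python
import jax.numpy as jnp


def elbo_estimate(log_likelihoods, kl, kl_weight=1.0):
    """Monte Carlo ELBO = E_q[log p(y | x, theta)] - kl_weight * KL(q || p).

    Sketch under assumptions; not the opifex implementation.
    log_likelihoods: (samples, batch) values of log p(y_i | x_i, theta_s), theta_s ~ q
    kl:              scalar KL divergence from the variational posterior q to the prior p
    """
    per_sample = jnp.sum(log_likelihoods, axis=1)  # total data log-likelihood per draw
    return jnp.mean(per_sample) - kl_weight * kl   # average over posterior draws


log_lik = jnp.array([[-1.0, -2.0], [-3.0, -2.0]])   # 2 draws, batch of 2
print(elbo_estimate(log_lik, kl=0.5, kl_weight=2.0))  # -5.0
```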

sample_predictive_distribution

sample_predictive_distribution(x: Float[Array, 'batch input_dim'], num_samples: int | None = None, *, rngs: Rngs) -> Float[Array, 'samples batch output_dim']

Sample from predictive distribution.

Parameters:

Name Type Description Default
x Float[Array, 'batch input_dim']

Input tensor of shape (batch_size, input_dim).

required
num_samples int | None

Number of predictive samples to generate.

None
rngs Rngs

Random number generator state.

required

Returns:

Type Description
Float[Array, 'samples batch output_dim']

Predictive samples of shape (num_samples, batch_size, output_dim).

MeanFieldGaussian

MeanFieldGaussian(num_params: int, *, rngs: Rngs)

Bases: Module

Mean-field Gaussian variational posterior.

This class implements a factorized Gaussian posterior distribution for variational inference in neural networks.

Parameters:

Name Type Description Default
num_params int

Number of parameters in the posterior.

required
rngs Rngs

Random number generator state.

required

sample

sample(num_samples: int, *, rngs: Rngs) -> Float[Array, 'samples params']

Sample from variational posterior.

Parameters:

Name Type Description Default
num_samples int

Number of samples to draw.

required
rngs Rngs

Random number generator state.

required

Returns:

Type Description
Float[Array, 'samples params']

Array of shape (num_samples, num_params) containing parameter samples.

log_prob

log_prob(samples: Float[Array, 'samples params']) -> Float[Array, 'samples']

Compute log probability of samples.

Parameters:

Name Type Description Default
samples Float[Array, 'samples params']

Parameter samples of shape (num_samples, num_params).

required

Returns:

Type Description
Float[Array, 'samples']

Log probabilities for each sample of shape (num_samples,).

kl_divergence

kl_divergence(prior_mean: float = 0.0, prior_std: float = 1.0) -> Float[Array, '']

Compute KL divergence from prior.

Parameters:

Name Type Description Default
prior_mean float

Mean of the prior distribution.

0.0
prior_std float

Standard deviation of the prior distribution.

1.0

Returns:

Type Description
Float[Array, '']

KL divergence scalar value.
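A mean-field Gaussian posterior keeps one mean and one standard deviation per parameter, so sampling, log-density, and KL all factorize over parameters; the KL to a N(prior_mean, prior_std^2) prior has the closed form log(prior_std / sigma) + (sigma^2 + (mu - prior_mean)^2) / (2 prior_std^2) - 1/2 per parameter. The sketch below assumes a log-sigma parameterization, which is an assumption about the opifex internals; all function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def mf_sample(key, mu, log_sigma, num_samples):
    """Draw (num_samples, num_params) samples as mu + sigma * eps (log-sigma assumed)."""
    eps = jax.random.normal(key, (num_samples, mu.shape[0]))
    return mu + jnp.exp(log_sigma) * eps


def mf_log_prob(samples, mu, log_sigma):
    """Factorized Gaussian log-density, summed over parameters -> (num_samples,)."""
    z = (samples - mu) / jnp.exp(log_sigma)
    per_param = -0.5 * z**2 - log_sigma - 0.5 * jnp.log(2.0 * jnp.pi)
    return jnp.sum(per_param, axis=-1)


def mf_kl(mu, log_sigma, prior_mean=0.0, prior_std=1.0):
    """Closed-form KL(N(mu, sigma^2) || N(prior_mean, prior_std^2)), summed over parameters."""
    sigma2 = jnp.exp(2.0 * log_sigma)
    kl = (jnp.log(prior_std) - log_sigma
          + (sigma2 + (mu - prior_mean) ** 2) / (2.0 * prior_std**2) - 0.5)
    return jnp.sum(kl)


mu = jnp.array([0.0, 1.0, -1.0])
log_sigma = jnp.log(jnp.array([1.0, 0.5, 2.0]))
samples = mf_sample(jax.random.PRNGKey(0), mu, log_sigma, num_samples=4)
print(samples.shape, mf_log_prob(samples, mu, log_sigma).shape)  # (4, 3) (4,)
print(mf_kl(jnp.zeros(3), jnp.zeros(3)))                         # 0.0 (posterior equals prior)
```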

PriorConfig dataclass

PriorConfig(conservation_laws: Sequence[str] = (), boundary_conditions: Sequence[str] = (), physics_constraints: Sequence[str] = (), prior_scale: float = 1.0)

Configuration for physics-informed priors.

Attributes:

Name Type Description
conservation_laws Sequence[str]

List of conservation laws to enforce (e.g., ['energy', 'momentum']).

boundary_conditions Sequence[str]

List of boundary conditions to incorporate.

physics_constraints Sequence[str]

List of physics constraints to respect.

prior_scale float

Scale parameter for the prior distribution.

UncertaintyEncoder

UncertaintyEncoder(input_dim: int, hidden_dims: Sequence[int], output_dim: int, *, rngs: Rngs)

Bases: Module

Neural network for amortized uncertainty estimation.

This encoder network predicts the parameters of the variational posterior directly from input data, enabling amortized variational inference.

Parameters:

Name Type Description Default
input_dim int

Dimensionality of input features.

required
hidden_dims Sequence[int]

Sequence of hidden layer dimensions.

required
output_dim int

Dimensionality of output (typically 2 * num_params for mean and log_std).

required
rngs Rngs

Random number generator state.

required

VariationalConfig dataclass

VariationalConfig(input_dim: int, hidden_dims: Sequence[int] = (64, 32), num_samples: int = 10, kl_weight: float = 1.0, temperature: float = 1.0)

Configuration for variational inference.

Attributes:

Name Type Description
input_dim int

Dimensionality of input features.

hidden_dims Sequence[int]

Tuple of hidden layer dimensions for the encoder.

num_samples int

Number of samples to draw during inference.

kl_weight float

Weight for the KL divergence term in ELBO.

temperature float

Temperature parameter for the variational distribution.

Domain Decomposition PINNs

Domain decomposition methods for physics-informed neural networks, enabling efficient training on complex geometries.

Base Classes

opifex.neural.pinns.domain_decomposition.base

Base classes for Domain Decomposition PINNs.

This module provides the foundational classes for domain decomposition approaches to physics-informed neural networks.

Key Classes
  • Subdomain: Represents a subdomain region in the computational domain
  • Interface: Represents the interface between adjacent subdomains
  • DomainDecompositionPINN: Abstract base class for DD-PINN variants
Design Principles
  • Each subdomain has its own neural network
  • Interfaces enforce continuity and flux matching
  • Window functions provide smooth blending (for FBPINN variants)
References
  • Survey Section 8.3: Domain Decomposition Methods

Subdomain dataclass

Subdomain(id: int, bounds: Float[Array, 'dim 2'], overlap: float = 0.0)

Representation of a subdomain in the computational domain.

A subdomain is a rectangular region defined by its bounds in each spatial dimension.

Attributes:

Name Type Description
id int

Unique identifier for this subdomain

bounds Float[Array, 'dim 2']

Array of shape (dim, 2) with [min, max] for each dimension

overlap float

Optional overlap with neighboring subdomains (for Schwarz methods)

center property

center: Float[Array, ' dim']

Compute the center of the subdomain.

volume property

volume: Float[Array, '']

Compute the volume (area in 2D, length in 1D) of the subdomain.

contains

contains(x: Float[Array, ' dim']) -> Array

Check if a point is inside this subdomain.

Parameters:

Name Type Description Default
x Float[Array, ' dim']

Point coordinates of shape (dim,)

required

Returns:

Type Description
Array

Boolean array (scalar) indicating if point is inside subdomain
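For an axis-aligned box with bounds of shape (dim, 2), the center is the midpoint of each [min, max] pair, the volume is the product of the side lengths, and containment is a per-dimension range check. A minimal sketch with plain jax.numpy; treating the boundary as inside (closed intervals) and ignoring the overlap attribute are assumptions, and `contains` here is a hypothetical free function rather than the Subdomain method.

```python
import jax.numpy as jnp

bounds = jnp.array([[0.0, 2.0],    # x in [0, 2]
                    [1.0, 4.0]])   # y in [1, 4]

center = jnp.mean(bounds, axis=1)               # midpoint per dimension
volume = jnp.prod(bounds[:, 1] - bounds[:, 0])  # 2 * 3 = 6 (an area in 2D)


def contains(bounds, x):
    """Closed-box test, min <= x <= max in every dimension (boundary handling assumed)."""
    return jnp.all((x >= bounds[:, 0]) & (x <= bounds[:, 1]))


print(center, volume)                            # [1.  2.5] 6.0
print(contains(bounds, jnp.array([1.0, 3.0])))   # True
print(contains(bounds, jnp.array([2.5, 3.0])))   # False
```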

Interface dataclass

Interface(subdomain_ids: tuple[int, int], points: Float[Array, 'num_points dim'], normal: Float[Array, ' dim'])

Representation of an interface between two subdomains.

The interface stores sample points for enforcing continuity conditions between adjacent subdomains.

Attributes:

Name Type Description
subdomain_ids tuple[int, int]

Tuple of (left_id, right_id) for adjacent subdomains

points Float[Array, 'num_points dim']

Sample points on the interface, shape (num_points, dim)

normal Float[Array, ' dim']

Outward normal vector from first subdomain, shape (dim,)

DomainDecompositionPINN

DomainDecompositionPINN(input_dim: int, output_dim: int, subdomains: Sequence[Subdomain], interfaces: Sequence[Interface], hidden_dims: Sequence[int], *, activation: Callable[[Array], Array] = tanh, rngs: Rngs)

Bases: Module

Base class for Domain Decomposition PINNs.

This class provides the infrastructure for training separate networks on subdomains with interface coupling conditions.

Attributes:

Name Type Description
input_dim

Input spatial dimension

output_dim

Output dimension (solution fields)

subdomains

List of subdomain definitions

interfaces

List of interface definitions

networks

List of subdomain networks

Parameters:

Name Type Description Default
input_dim int

Input spatial dimension

required
output_dim int

Output dimension

required
subdomains Sequence[Subdomain]

List of subdomain definitions

required
interfaces Sequence[Interface]

List of interface definitions

required
hidden_dims Sequence[int]

Hidden layer dimensions (shared across subdomains)

required
activation Callable[[Array], Array]

Activation function

tanh
rngs Rngs

Random number generators

required

get_subdomain_outputs

get_subdomain_outputs(x: Float[Array, ...]) -> list[Float[Array, 'batch out']]

Get outputs from all subdomain networks.

Parameters:

Name Type Description Default
x Float[Array, ...]

Input coordinates

required

Returns:

Type Description
list[Float[Array, 'batch out']]

List of outputs from each subdomain network

compute_interface_residual

compute_interface_residual() -> Float[Array, '']

Compute interface continuity residual.

Enforces u_left = u_right at interface points.

Returns:

Type Description
Float[Array, '']

Scalar interface residual (MSE of discontinuity)

compute_flux_residual

compute_flux_residual(derivative_fn: Callable[[Module, Float[Array, ...]], Float[Array, ...]]) -> Float[Array, '']

Compute interface flux continuity residual.

Enforces (du/dn)_left = (du/dn)_right at interface points.

Parameters:

Name Type Description Default
derivative_fn Callable[[Module, Float[Array, ...]], Float[Array, ...]]

Function to compute gradient of network output

required

Returns:

Type Description
Float[Array, '']

Scalar flux residual

SubdomainNetwork

SubdomainNetwork(input_dim: int, output_dim: int, hidden_dims: Sequence[int], *, activation: Callable[[Array], Array] = tanh, rngs: Rngs)

Bases: Module

Neural network for a single subdomain.

A simple MLP that processes inputs for a specific subdomain.

Parameters:

Name Type Description Default
input_dim int

Input dimension

required
output_dim int

Output dimension

required
hidden_dims Sequence[int]

List of hidden layer dimensions

required
activation Callable[[Array], Array]

Activation function

tanh
rngs Rngs

Random number generators

required

uniform_partition

uniform_partition(bounds: Float[Array, 'dim 2'], num_partitions: tuple[int, ...], interface_points: int = 10) -> tuple[list[Subdomain], list[Interface]]

Create uniform partition of a rectangular domain.

Parameters:

Name Type Description Default
bounds Float[Array, 'dim 2']

Domain bounds, shape (dim, 2) with [min, max] for each dimension

required
num_partitions tuple[int, ...]

Number of partitions in each dimension

required
interface_points int

Number of sample points per interface

10

Returns:

Type Description
tuple[list[Subdomain], list[Interface]]

Tuple of (subdomains, interfaces)
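A uniform partition splits each dimension into equal intervals and takes their Cartesian product; neighboring boxes (one index apart along one axis) share a face where an interface would be placed. The sketch below builds only the box bounds and the adjacent id pairs. Sampling interface_points points on each shared face is left out, the cell-numbering order is an assumption that may not match opifex, and `uniform_boxes` is a hypothetical helper.

```python
import itertools

import jax.numpy as jnp


def uniform_boxes(bounds, num_partitions):
    """Return per-cell (dim, 2) bounds and face-adjacent (id, id) pairs (hypothetical helper)."""
    edges = [jnp.linspace(lo, hi, n + 1)
             for (lo, hi), n in zip(bounds.tolist(), num_partitions)]
    cells = list(itertools.product(*[range(n) for n in num_partitions]))
    cell_id = {cell: i for i, cell in enumerate(cells)}  # numbering order is an assumption
    boxes = [jnp.stack([edges[d][c:c + 2] for d, c in enumerate(cell)]) for cell in cells]
    adjacent = []
    for cell in cells:
        for d in range(len(num_partitions)):
            if cell[d] + 1 < num_partitions[d]:  # a neighbor exists in the +d direction
                neighbor = cell[:d] + (cell[d] + 1,) + cell[d + 1:]
                adjacent.append((cell_id[cell], cell_id[neighbor]))
    return boxes, adjacent


boxes, adjacent = uniform_boxes(jnp.array([[0.0, 1.0], [0.0, 1.0]]), (2, 2))
print(len(boxes), adjacent)  # 4 [(0, 2), (0, 1), (1, 3), (2, 3)]
print(boxes[3])              # bounds of the upper-right cell, [0.5, 1] x [0.5, 1]
```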

XPINN (Extended PINN)

opifex.neural.pinns.domain_decomposition.xpinn

Extended Physics-Informed Neural Network (XPINN).

XPINN extends the PINN framework to handle domain decomposition with explicit interface conditions for continuity and flux matching.

Key Features
  • Separate networks for each subdomain
  • Interface continuity conditions (u_left = u_right)
  • Flux continuity conditions (du/dn_left = du/dn_right)
  • Weighted loss combination for interface enforcement
References
  • Jagtap & Karniadakis (2020): Extended Physics-Informed Neural Networks
  • Survey Section 8.3.1: XPINNs
  • GitHub: https://github.com/AmeyaJagtap/XPINNs

XPINN

XPINN(input_dim: int, output_dim: int, subdomains: Sequence[Subdomain], interfaces: Sequence[Interface], hidden_dims: Sequence[int], *, config: XPINNConfig | None = None, activation: Callable[[Array], Array] = tanh, rngs: Rngs)

Bases: DomainDecompositionPINN

Extended Physics-Informed Neural Network.

XPINN decomposes the computational domain into non-overlapping subdomains, training a separate neural network for each subdomain. Interface conditions enforce solution continuity and flux matching between adjacent subdomains.

The total loss includes:
  • Data loss (if available)
  • PDE residual loss (per subdomain)
  • Interface continuity loss: ||u_left - u_right||²
  • Interface flux loss: ||∂u/∂n_left - ∂u/∂n_right||²

Attributes:

Name Type Description
config

XPINN configuration with loss weights

input_dim

Spatial dimension

output_dim

Solution dimension

subdomains

List of subdomain definitions

interfaces

List of interface definitions

networks

List of subdomain networks

Example

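```python
import jax.numpy as jnp
from flax import nnx

from opifex.neural.pinns.domain_decomposition.base import Interface, Subdomain
from opifex.neural.pinns.domain_decomposition.xpinn import XPINN

# Two non-overlapping 1D subdomains that share the interface x = 0.5
subdomains = [
    Subdomain(id=0, bounds=jnp.array([[0.0, 0.5]])),
    Subdomain(id=1, bounds=jnp.array([[0.5, 1.0]])),
]
interfaces = [
    Interface(subdomain_ids=(0, 1), points=jnp.array([[0.5]]),
              normal=jnp.array([1.0])),
]
model = XPINN(
    input_dim=1, output_dim=1,
    subdomains=subdomains, interfaces=interfaces,
    hidden_dims=[32, 32], rngs=nnx.Rngs(0),
)
```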

Parameters:

Name Type Description Default
input_dim int

Spatial dimension

required
output_dim int

Solution dimension

required
subdomains Sequence[Subdomain]

List of subdomain definitions

required
interfaces Sequence[Interface]

List of interface definitions

required
hidden_dims Sequence[int]

Hidden layer dimensions for subdomain networks

required
config XPINNConfig | None

XPINN configuration. Uses defaults if None.

None
activation Callable[[Array], Array]

Activation function

tanh
rngs Rngs

Random number generators

required

compute_continuity_loss

compute_continuity_loss() -> Float[Array, '']

Compute interface continuity loss.

Delegates to the base-class method compute_interface_residual.

Returns:

Type Description
Float[Array, '']

Scalar continuity loss (MSE of discontinuity)

compute_flux_loss

compute_flux_loss() -> Float[Array, '']

Compute interface flux continuity loss.

Enforces ∂u/∂n_left = ∂u/∂n_right at all interface points, where n is the interface normal direction.

Returns:

Type Description
Float[Array, '']

Scalar flux loss (MSE of flux discontinuity)

compute_interface_loss

compute_interface_loss() -> Float[Array, '']

Compute total weighted interface loss.

Combines continuity and flux losses with configured weights.

Returns:

Type Description
Float[Array, '']

Scalar total interface loss
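The interface terms reduce to two mean-squared jumps at the interface points: the solution jump u_left - u_right and the normal-derivative jump grad(u_left)·n - grad(u_right)·n, combined with continuity_weight and flux_weight from XPINNConfig. The sketch below uses plain scalar functions as hypothetical stand-ins for the two subdomain networks and jax.grad for the normal derivative; it illustrates the formulas only and is not the XPINN implementation.

```python
import jax
import jax.numpy as jnp


def normal_derivative(u, x, n):
    """du/dn = grad u(x) . n for a scalar-valued u: R^dim -> R."""
    return jnp.dot(jax.grad(u)(x), n)


def interface_loss(u_left, u_right, points, normal, continuity_weight=1.0, flux_weight=1.0):
    """Weighted sum of continuity and flux MSE at interface points (sketch)."""
    jump = jax.vmap(u_left)(points) - jax.vmap(u_right)(points)
    flux_jump = jax.vmap(lambda x: normal_derivative(u_left, x, normal)
                         - normal_derivative(u_right, x, normal))(points)
    return continuity_weight * jnp.mean(jump**2) + flux_weight * jnp.mean(flux_jump**2)


def u_left(x):   # hypothetical network on [0, 0.5]: u = x^2, du/dx = 1 at x = 0.5
    return x[0] ** 2


def u_right(x):  # hypothetical network on [0.5, 1]: u = x - 1/4, du/dx = 1
    return x[0] - 0.25


points = jnp.array([[0.5]])
normal = jnp.array([1.0])
print(interface_loss(u_left, u_right, points, normal))  # 0.0 (value and flux both match)
```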

compute_subdomain_residual

compute_subdomain_residual(subdomain_id: int, residual_fn: Callable[[Callable[[Float[Array, ...]], Float[Array, 'batch out']], Float[Array, ...]], Float[Array, ' batch']], collocation_points: Float[Array, ...]) -> Float[Array, '']

Compute PDE residual for a specific subdomain.

Parameters:

Name Type Description Default
subdomain_id int

ID of the subdomain

required
residual_fn Callable[[Callable[[Float[Array, ...]], Float[Array, 'batch out']], Float[Array, ...]], Float[Array, ' batch']]

Function that computes PDE residual given network and points

required
collocation_points Float[Array, ...]

Points where to evaluate residual

required

Returns:

Type Description
Float[Array, '']

Scalar residual loss for this subdomain

compute_total_residual

compute_total_residual(residual_fn: Callable[[Callable[[Float[Array, ...]], Float[Array, 'batch out']], Float[Array, ...]], Float[Array, ' batch']], collocation_points_per_subdomain: Sequence[Float[Array, ...]]) -> Float[Array, '']

Compute total PDE residual across all subdomains.

Parameters:

Name Type Description Default
residual_fn Callable[[Callable[[Float[Array, ...]], Float[Array, 'batch out']], Float[Array, ...]], Float[Array, ' batch']]

Function that computes PDE residual

required
collocation_points_per_subdomain Sequence[Float[Array, ...]]

Collocation points for each subdomain

required

Returns:

Type Description
Float[Array, '']

Scalar total residual loss

XPINNConfig dataclass

XPINNConfig(continuity_weight: float = 1.0, flux_weight: float = 1.0, residual_weight: float = 1.0, average_residual_weight: float = 0.0)

Configuration for XPINN training.

Attributes:

Name Type Description
continuity_weight float

Weight for interface continuity loss (u_left = u_right)

flux_weight float

Weight for interface flux continuity loss (du/dn matching)

residual_weight float

Weight for PDE residual loss in each subdomain

average_residual_weight float

Weight for residual averaging at interfaces

FBPINN (Finite Basis PINN)

opifex.neural.pinns.domain_decomposition.fbpinn

Finite Basis Physics-Informed Neural Network (FBPINN).

FBPINN uses smooth window functions to create a partition of unity, enabling smooth blending of subdomain solutions without explicit interface conditions.

Key Features
  • Smooth window functions (cosine, Gaussian)
  • Partition of unity through normalization
  • No explicit interface conditions needed
  • Naturally handles overlapping subdomains
References
  • Moseley et al. (2023): Finite Basis Physics-Informed Neural Networks
  • Survey Section 8.3.2: FBPINNs
  • GitHub: https://github.com/benmoseley/FBPINNs

FBPINN

FBPINN(input_dim: int, output_dim: int, subdomains: Sequence[Subdomain], interfaces: Sequence, hidden_dims: Sequence[int], *, config: FBPINNConfig | None = None, activation: Callable[[Array], Array] = tanh, rngs: Rngs)

Bases: DomainDecompositionPINN

Finite Basis Physics-Informed Neural Network.

FBPINN decomposes the computational domain into overlapping subdomains, using smooth window functions to blend subdomain network outputs. This creates a partition of unity that ensures smooth global solutions.

The output is computed as:

u(x) = Σᵢ wᵢ(x) * uᵢ(x) / Σⱼ wⱼ(x)

where wᵢ(x) is the window function for subdomain i and uᵢ(x) is the network output for subdomain i.
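Given the window weights (shape (batch, num_subdomains), as compute_window_weights returns) and the per-subdomain network outputs, the blend is a weighted average over the subdomain axis. A minimal sketch of that formula; the small eps guarding points covered by no window is an assumption of the sketch, and `blend` is a hypothetical name.

```python
import jax.numpy as jnp


def blend(window_weights, subdomain_outputs, eps=1e-12):
    """u(x) = sum_i w_i(x) u_i(x) / sum_j w_j(x) (sketch, eps is an assumption).

    window_weights:    (batch, num_subdomains)
    subdomain_outputs: (batch, num_subdomains, output_dim)
    """
    w = window_weights[..., None]                       # (batch, num_subdomains, 1)
    return jnp.sum(w * subdomain_outputs, axis=1) / (jnp.sum(w, axis=1) + eps)


weights = jnp.array([[1.0, 3.0]])       # a point in the overlap of two subdomains
outputs = jnp.array([[[2.0], [6.0]]])   # u_0(x) = 2, u_1(x) = 6
print(blend(weights, outputs))          # [[5.]]  = (1*2 + 3*6) / (1 + 3)
```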

Attributes:

Name Type Description
config

FBPINN configuration

windows

List of window functions for each subdomain

Example

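```python
import jax.numpy as jnp
from flax import nnx

from opifex.neural.pinns.domain_decomposition.base import Subdomain
from opifex.neural.pinns.domain_decomposition.fbpinn import FBPINN

# Two overlapping 1D subdomains; the region [0.4, 0.6] is shared
subdomains = [
    Subdomain(id=0, bounds=jnp.array([[0.0, 0.6]])),
    Subdomain(id=1, bounds=jnp.array([[0.4, 1.0]])),
]
model = FBPINN(
    input_dim=1, output_dim=1,
    subdomains=subdomains, interfaces=[],  # window blending replaces interface terms
    hidden_dims=[32, 32], rngs=nnx.Rngs(0),
)
```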

Parameters:

Name Type Description Default
input_dim int

Spatial dimension

required
output_dim int

Solution dimension

required
subdomains Sequence[Subdomain]

List of subdomain definitions (should overlap)

required
interfaces Sequence

List of interface definitions. FBPINN needs no explicit interface conditions, so an empty list may be passed.

required
hidden_dims Sequence[int]

Hidden layer dimensions for subdomain networks

required
config FBPINNConfig | None

FBPINN configuration. Uses defaults if None.

None
activation Callable[[Array], Array]

Activation function

tanh
rngs Rngs

Random number generators

required

compute_window_weights

compute_window_weights(x: Float[Array, ...]) -> Float[Array, 'batch num_subdomains']

Compute window weights for all subdomains.

Parameters:

Name Type Description Default
x Float[Array, ...]

Input coordinates

required

Returns:

Type Description
Float[Array, 'batch num_subdomains']

Window weights, shape (batch, num_subdomains)

FBPINNConfig dataclass

FBPINNConfig(window_type: Literal['cosine', 'gaussian'] = 'cosine', normalize_windows: bool = True, overlap_factor: float = 0.2, gaussian_sigma: float = 0.25)

Configuration for FBPINN training.

Attributes:

Name Type Description
window_type Literal['cosine', 'gaussian']

Type of window function ("cosine" or "gaussian")

normalize_windows bool

Whether to normalize window weights to sum to 1

overlap_factor float

Factor controlling subdomain overlap (for auto-partitioning)

gaussian_sigma float

Sigma parameter for Gaussian windows

WindowFunction

WindowFunction(subdomain: Subdomain)

Bases: ABC

Abstract base class for window functions.

Window functions define the influence region of each subdomain network. They should be smooth, have compact support within the subdomain, and enable partition of unity when combined.

Parameters:

Name Type Description Default
subdomain Subdomain

The subdomain this window is associated with

required

CosineWindow

CosineWindow(subdomain: Subdomain)

Bases: WindowFunction

Cosine-based window function.

w(x) = 0.5 * (1 + cos(π * r)) for r < 1, else 0

where r is the normalized distance from the subdomain center, scaled by the subdomain half-width.

This creates a smooth bump function that is 1 at the center and 0 at the boundary.
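The cosine window takes only a few lines of JAX. This is a sketch, not the opifex implementation. It assumes subdomain bounds are stored as a (dim, 2) array of [lower, upper] rows, as in the Subdomain examples elsewhere on this page. For dim > 1 it assumes r is the Euclidean norm of the offset from the center, with each dimension scaled by its half-width.

```python
import jax.numpy as jnp


def cosine_window(x, bounds):
    """Sketch of w(x) = 0.5 * (1 + cos(pi * r)) for r < 1, else 0.

    x: (batch, dim) points; bounds: (dim, 2) array of [lower, upper] per dimension.
    """
    lower, upper = bounds[:, 0], bounds[:, 1]
    center = 0.5 * (lower + upper)
    half_width = 0.5 * (upper - lower)
    # Normalized distance from the center (assumed Euclidean over scaled offsets)
    r = jnp.linalg.norm((x - center) / half_width, axis=-1)
    return jnp.where(r < 1.0, 0.5 * (1.0 + jnp.cos(jnp.pi * r)), 0.0)


bounds = jnp.array([[0.0, 0.5]])
x = jnp.array([[0.25], [0.375], [0.5]])
w = cosine_window(x, bounds)  # approx. [1.0, 0.5, 0.0]
```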

GaussianWindow

GaussianWindow(subdomain: Subdomain, sigma: float = 0.25)

Bases: WindowFunction

Gaussian-based window function.

w(x) = exp(-||x - center||² / (2 * σ²))

where σ controls the width of the Gaussian.

Parameters:

Name Type Description Default
subdomain Subdomain

The subdomain this window is associated with

required
sigma float

Standard deviation of the Gaussian (relative to subdomain size)

0.25
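Here is a matching sketch of the Gaussian window. It makes σ relative to subdomain size by dividing the offset from the center by the subdomain width in each dimension. That scaling, the (dim, 2) bounds layout, and the function name are assumptions for illustration.

```python
import jax.numpy as jnp


def gaussian_window(x, bounds, sigma=0.25):
    """Sketch of w(x) = exp(-||x - center||^2 / (2 * sigma^2)) with size-relative sigma.

    x: (batch, dim) points; bounds: (dim, 2) array of [lower, upper] per dimension.
    """
    lower, upper = bounds[:, 0], bounds[:, 1]
    center = 0.5 * (lower + upper)
    # Express offsets in units of the subdomain width so sigma is size-relative
    scaled = (x - center) / (upper - lower)
    sq_dist = jnp.sum(scaled**2, axis=-1)
    return jnp.exp(-sq_dist / (2.0 * sigma**2))


bounds = jnp.array([[0.0, 1.0]])
w = gaussian_window(jnp.array([[0.5], [0.75]]), bounds)  # [1.0, exp(-0.5)]
```

Unlike the cosine window, the Gaussian never reaches exactly zero, so its support is not strictly compact. Normalization still makes the weights sum to 1.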

CPINN (Conservative PINN)

opifex.neural.pinns.domain_decomposition.cpinn

Conservative Physics-Informed Neural Network (cPINN).

cPINN extends XPINN with explicit flux conservation at interfaces, enforcing strong conservation properties required for conservation laws.

Key Features
  • Explicit flux computation at interfaces
  • Strong conservation enforcement
  • Weighted combination of continuity and flux losses
References
  • Jagtap et al. (2020): Conservative physics-informed neural networks
  • Survey Section 8.3.2: Conservative PINNs

CPINN

CPINN(input_dim: int, output_dim: int, subdomains: Sequence[Subdomain], interfaces: Sequence[Interface], hidden_dims: Sequence[int], *, config: CPINNConfig | None = None, activation: Callable[[Array], Array] = tanh, rngs: Rngs)

Bases: DomainDecompositionPINN

Conservative Physics-Informed Neural Network.

cPINN enforces strong conservation at subdomain interfaces by explicitly computing and matching fluxes across boundaries.

The total interface loss includes:
  • Continuity loss: ||u_left - u_right||²
  • Flux conservation loss: ||F_left · n - F_right · n||²

where F = ∇u is the flux (gradient) of the solution.

Attributes:

Name Type Description
config

cPINN configuration with loss weights

input_dim

Spatial dimension

output_dim

Solution dimension

subdomains

List of subdomain definitions

interfaces

List of interface definitions

networks

List of subdomain networks

Example

>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from opifex.neural.pinns.domain_decomposition.cpinn import CPINN
>>> subdomains = [
...     Subdomain(id=0, bounds=jnp.array([[0.0, 0.5]])),
...     Subdomain(id=1, bounds=jnp.array([[0.5, 1.0]])),
... ]
>>> interfaces = [
...     Interface(subdomain_ids=(0, 1), points=jnp.array([[0.5]]),
...               normal=jnp.array([1.0]))
... ]
>>> model = CPINN(
...     input_dim=1, output_dim=1,
...     subdomains=subdomains, interfaces=interfaces,
...     hidden_dims=[32, 32], rngs=nnx.Rngs(0)
... )

Parameters:

Name Type Description Default
input_dim int

Spatial dimension

required
output_dim int

Solution dimension

required
subdomains Sequence[Subdomain]

List of subdomain definitions

required
interfaces Sequence[Interface]

List of interface definitions

required
hidden_dims Sequence[int]

Hidden layer dimensions for subdomain networks

required
config CPINNConfig | None

cPINN configuration. Uses defaults if None.

None
activation Callable[[Array], Array]

Activation function

tanh
rngs Rngs

Random number generators

required

compute_continuity_loss

compute_continuity_loss() -> Float[Array, '']

Compute interface continuity loss.

Delegates to the base-class method compute_interface_residual.

Returns:

Type Description
Float[Array, '']

Scalar continuity loss (MSE of discontinuity)

compute_flux_conservation_loss

compute_flux_conservation_loss() -> Float[Array, '']

Compute flux conservation loss at interfaces.

Enforces F_left · n = F_right · n at all interface points, where F = ∇u is the flux.

Returns:

Type Description
Float[Array, '']

Scalar flux conservation loss

compute_interface_loss

compute_interface_loss() -> Float[Array, '']

Compute total weighted interface loss.

Combines continuity and flux conservation losses with configured weights.

Returns:

Type Description
Float[Array, '']

Scalar total interface loss
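Both interface terms can be illustrated on a single interface between two scalar subdomain functions. This is a minimal sketch under stated assumptions, not the opifex implementation. u_left and u_right are hypothetical callables that map one point of shape (dim,) to a scalar. The total is assumed to be a plain weighted sum, with weights named after the continuity_weight and flux_weight fields of CPINNConfig.

```python
import jax
import jax.numpy as jnp


def interface_losses(u_left, u_right, points, normal):
    """Continuity and flux-matching losses at one interface (illustrative sketch).

    u_left, u_right: callables mapping a point of shape (dim,) to a scalar
    points: (n_points, dim) interface points; normal: (dim,) unit normal
    """
    # Continuity: mean of (u_left - u_right)^2 over interface points
    jump = jax.vmap(u_left)(points) - jax.vmap(u_right)(points)
    continuity = jnp.mean(jump**2)
    # Flux F = grad u, projected onto the interface normal on each side
    flux_left = jax.vmap(jax.grad(u_left))(points) @ normal
    flux_right = jax.vmap(jax.grad(u_right))(points) @ normal
    flux = jnp.mean((flux_left - flux_right) ** 2)
    return continuity, flux


def total_interface_loss(continuity, flux, continuity_weight=1.0, flux_weight=1.0):
    # Assumed combination: a weighted sum of the two terms
    return continuity_weight * continuity + flux_weight * flux


# Solutions that agree at x = 0.5 but have slopes 2 and 1 there
u_left = lambda x: 2.0 * x[0]
u_right = lambda x: 0.5 + x[0]
points = jnp.array([[0.5]])
normal = jnp.array([1.0])
continuity, flux = interface_losses(u_left, u_right, points, normal)  # 0.0, 1.0
total = total_interface_loss(continuity, flux)
```

The example is continuous at the interface but not conservative. The values match while the normal fluxes differ, so only the flux term is nonzero. This is the case that a continuity-only interface loss would miss.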

CPINNConfig dataclass

CPINNConfig(flux_weight: float = 1.0, continuity_weight: float = 1.0, conservation_weight: float = 0.1)

Configuration for cPINN training.

Attributes:

Name Type Description
flux_weight float

Weight for flux conservation loss at interfaces

continuity_weight float

Weight for solution continuity loss

conservation_weight float

Weight for global conservation enforcement

APINN (Augmented PINN)

opifex.neural.pinns.domain_decomposition.apinn

Augmented Physics-Informed Neural Network (APINN).

APINN uses a learnable gating network to smoothly blend subdomain solutions, allowing the model to learn optimal subdomain selection.

Key Features
  • Learnable gating network for subdomain weighting
  • Temperature-controlled softmax for soft/hard selection
  • Differentiable blending for end-to-end training
References
  • Survey Section 8.3.3: Augmented PINNs

APINN

APINN(input_dim: int, output_dim: int, subdomains: Sequence[Subdomain], interfaces: Sequence[Interface], hidden_dims: Sequence[int], *, config: APINNConfig | None = None, activation: Callable[[Array], Array] = tanh, rngs: Rngs)

Bases: DomainDecompositionPINN

Augmented Physics-Informed Neural Network.

APINN uses a learnable gating network to determine how to blend solutions from different subdomains. Unlike FBPINN, which uses fixed window functions, APINN learns the blending weights during training.

The output is computed as:

u(x) = Σᵢ gᵢ(x) * uᵢ(x)

where gᵢ(x) are the learned gating weights (sum to 1) and uᵢ(x) are the subdomain network outputs.

Attributes:

Name Type Description
config

APINN configuration

gating_network

Network that produces blending weights

input_dim

Spatial dimension

output_dim

Solution dimension

subdomains

List of subdomain definitions

interfaces

List of interface definitions

networks

List of subdomain networks

Example

>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from opifex.neural.pinns.domain_decomposition.apinn import APINN
>>> subdomains = [
...     Subdomain(id=0, bounds=jnp.array([[0.0, 0.5]])),
...     Subdomain(id=1, bounds=jnp.array([[0.5, 1.0]])),
... ]
>>> interfaces = [
...     Interface(subdomain_ids=(0, 1), points=jnp.array([[0.5]]),
...               normal=jnp.array([1.0]))
... ]
>>> model = APINN(
...     input_dim=1, output_dim=1,
...     subdomains=subdomains, interfaces=interfaces,
...     hidden_dims=[32, 32], rngs=nnx.Rngs(0)
... )

Parameters:

Name Type Description Default
input_dim int

Spatial dimension

required
output_dim int

Solution dimension

required
subdomains Sequence[Subdomain]

List of subdomain definitions

required
interfaces Sequence[Interface]

List of interface definitions

required
hidden_dims Sequence[int]

Hidden layer dimensions for subdomain networks

required
config APINNConfig | None

APINN configuration. Uses defaults if None.

None
activation Callable[[Array], Array]

Activation function

tanh
rngs Rngs

Random number generators

required

get_gating_weights

get_gating_weights(x: Float[Array, 'batch dim']) -> Float[Array, 'batch num_subdomains']

Get gating weights for given points.

Parameters:

Name Type Description Default
x Float[Array, 'batch dim']

Input coordinates

required

Returns:

Type Description
Float[Array, 'batch num_subdomains']

Gating weights for each subdomain

compute_interface_loss

compute_interface_loss() -> Float[Array, '']

Compute weighted interface continuity loss.

Delegates the continuity computation to the base-class method compute_interface_residual and scales the result by the configured continuity weight.

Returns:

Type Description
Float[Array, '']

Scalar interface loss

APINNConfig dataclass

APINNConfig(temperature: float = 1.0, gating_hidden_dims: list[int] = [16, 16], continuity_weight: float = 1.0)

Configuration for APINN training.

Attributes:

Name Type Description
temperature float

Softmax temperature for gating. Lower values give sharper (more discrete) weights, higher values give smoother (more uniform) weights.

gating_hidden_dims list[int]

Hidden dimensions for the gating network

continuity_weight float

Weight for interface continuity loss
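The blending u(x) = Σᵢ gᵢ(x) uᵢ(x) and the effect of temperature can be sketched as a temperature-scaled softmax over gating logits. The sketch assumes that the gating network produces one unnormalized logit per subdomain and that the temperature divides the logits before the softmax, which is the usual convention. The function name gated_blend is hypothetical.

```python
import jax
import jax.numpy as jnp


def gated_blend(logits, outputs, temperature=1.0):
    """Blend subdomain outputs with temperature-scaled softmax gates (sketch).

    logits: (batch, num_subdomains) raw gating-network outputs
    outputs: (batch, num_subdomains, output_dim) subdomain network outputs
    """
    # Softmax over the subdomain axis: gates are positive and sum to 1
    gates = jax.nn.softmax(logits / temperature, axis=-1)
    return jnp.einsum("bs,bso->bo", gates, outputs), gates


logits = jnp.array([[2.0, 0.0]])
outputs = jnp.array([[[1.0], [3.0]]])
u_smooth, g_smooth = gated_blend(logits, outputs, temperature=1.0)
u_sharp, g_sharp = gated_blend(logits, outputs, temperature=0.01)  # ~ picks subdomain 0
u_flat, g_flat = gated_blend(logits, outputs, temperature=100.0)   # ~ equal weights
```

A low temperature pushes the gates toward a hard, one-hot subdomain choice. A high temperature pushes them toward uniform averaging. Both limits stay differentiable, so the gating network can be trained end to end.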

GatingNetwork

GatingNetwork(input_dim: int, num_subdomains: int, hidden_dims: Sequence[int], *, activation: Callable[[Array], Array] = tanh, rngs: Rngs)

Bases: Module

Gating network for subdomain selection.

This network takes spatial coordinates and outputs weights for blending subdomain solutions.

Attributes:

Name Type Description
layers

List of linear layers

activation

Activation function

Parameters:

Name Type Description Default
input_dim int

Input spatial dimension

required
num_subdomains int

Number of subdomains to gate

required
hidden_dims Sequence[int]

Hidden layer dimensions

required
activation Callable[[Array], Array]

Activation function

tanh
rngs Rngs

Random number generators

required

For usage examples and best practices, see the Domain Decomposition PINNs Guide.

Activations

opifex.neural.activations

Activation functions optimized for scientific neural networks.

This module provides a full collection of activation functions specifically optimized for scientific machine learning applications. All functions are fully compatible with Flax NNX patterns and JAX transformations.

Key Features
  • Full Flax NNX compliance with proper type annotations
  • Activation selection by name, with error handling for unknown names
  • Implementations optimized for scientific computing
  • Support for both standard and specialized activation patterns

get_activation

get_activation(name: str | Callable) -> Any

Get activation function by name or return function if already callable.

Parameters:

Name Type Description Default
name str | Callable

Name of the activation function (case-insensitive) or callable function

required

Returns:

Type Description
Any

JAX activation function or callable

Raises:

Type Description
ValueError

If activation function is not found

list_activations

list_activations() -> list[str]

List all available activation functions.

Returns:

Type Description
list[str]

List of activation function names

Examples:

>>> from opifex.neural.activations import list_activations
>>> activations = list_activations()
>>> print(f"Available activations: {', '.join(activations)}")

register_activation

register_activation(name: str, func: Callable) -> None

Register a custom activation function.

Parameters:

Name Type Description Default
name str

Name of the activation function

required
func Callable

The activation function (should accept and return JAX arrays)

required

Examples:

>>> from opifex.neural.activations import get_activation, register_activation
>>> def my_activation(x):
...     return x ** 3
>>> register_activation("cubic", my_activation)
>>> cubic_fn = get_activation("cubic")

mish

mish(x: Array) -> Array

Mish activation function: x * tanh(softplus(x)).

Mish is a smooth, non-monotonic, self-gated activation function that has been reported to perform well in deep networks.

Mathematical definition: f(x) = x * tanh(ln(1 + exp(x)))

Parameters:

Name Type Description Default
x Array

Input array

required

Returns:

Type Description
Array

Output array with Mish activation applied

Note

This implementation uses softplus(x) = ln(1 + exp(x)) for numerical stability.
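For reference, Mish can be written directly from its definition using jax.nn.softplus, which computes ln(1 + exp(x)) in a numerically stable way. This is a standalone sketch, and mish_ref is a hypothetical name. It can be used to check values or gradients against opifex.neural.activations.mish.

```python
import jax
import jax.numpy as jnp


def mish_ref(x):
    # f(x) = x * tanh(softplus(x)), softplus(x) = ln(1 + exp(x))
    return x * jnp.tanh(jax.nn.softplus(x))


x = jnp.array([-1.0, 0.0, 1.0, 100.0])
y = mish_ref(x)  # approx. [-0.3034, 0.0, 0.8651, 100.0]
```

For large positive x, Mish approaches the identity. For large negative x, it decays toward 0.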

snake_activation

snake_activation(x: Array, a: float = 1.0) -> Array

Snake activation function: x + sin²(αx)/α.

Snake activation was proposed for learning periodic functions, so it suits scientific applications that involve periodic patterns.

Mathematical definition: f(x) = x + (1/α) * sin²(αx)

Parameters:

Name Type Description Default
x Array

Input array

required
a float

Frequency parameter α; must be nonzero (default: 1.0)

1.0

Returns:

Type Description
Array

Output array with Snake activation applied

Note

The frequency parameter α controls the oscillation frequency. Higher values create more frequent oscillations.
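The sketch below is a direct transcription of the Snake formula, with the frequency α passed as a, matching the signature above. snake_ref is a hypothetical name. The sketch also checks a useful property: the derivative is 1 + sin(2αx), which lies in [0, 2]. Snake is therefore monotonically non-decreasing despite its periodic term.

```python
import jax
import jax.numpy as jnp


def snake_ref(x, a=1.0):
    # f(x) = x + sin^2(a * x) / a; a must be nonzero
    return x + jnp.sin(a * x) ** 2 / a


y = snake_ref(jnp.array([0.0, jnp.pi / 2]))  # [0.0, pi/2 + 1]

# d/dx f(x) = 1 + sin(2 a x), checked here with autodiff for a = 2
xs = jnp.linspace(-3.0, 3.0, 7)
grads = jax.vmap(jax.grad(lambda t: snake_ref(t, a=2.0)))(xs)
```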

gaussian_activation

gaussian_activation(x: Array, sigma: float = 1.0) -> Array

Gaussian activation function: exp(-x²/(2σ²)).

Gaussian activation can be useful for radial basis function networks and certain scientific applications where localized responses are desired.

Mathematical definition: f(x) = exp(-x²/(2σ²))

Parameters:

Name Type Description Default
x Array

Input array

required
sigma float

Standard deviation parameter (default: 1.0)

1.0

Returns:

Type Description
Array

Output array with Gaussian activation applied

Note

The σ parameter controls the width of the Gaussian. Smaller values create sharper peaks.
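The sketch below is a one-line version of the Gaussian activation, and gaussian_ref is a hypothetical name. The function peaks at 1 for x = 0 and drops to exp(-1/2) ≈ 0.607 at x = ±σ, which shows concretely how σ sets the width.

```python
import jax.numpy as jnp


def gaussian_ref(x, sigma=1.0):
    # f(x) = exp(-x^2 / (2 sigma^2)); smaller sigma gives a narrower peak
    return jnp.exp(-(x**2) / (2.0 * sigma**2))


y_wide = gaussian_ref(jnp.array([0.0, 1.0]), sigma=1.0)    # [1.0, ~0.607]
y_narrow = gaussian_ref(jnp.array([0.0, 1.0]), sigma=0.5)  # [1.0, ~0.135]
```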

normalized_tanh

normalized_tanh(x: Array) -> Array

Normalized tanh activation: 1.7159 * tanh(2x/3).

This is a normalized version of tanh that has unit variance for normalized inputs, which can help with training stability.

Parameters:

Name Type Description Default
x Array

Input array

required

Returns:

Type Description
Array

Output array with normalized tanh applied
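The constants 1.7159 and 2/3 are chosen so that f(±1) ≈ ±1, which keeps unit-scale inputs at roughly unit scale. The output saturates at ±1.7159 rather than ±1. The quick sketch below checks this; normalized_tanh_ref is a hypothetical name.

```python
import jax.numpy as jnp


def normalized_tanh_ref(x):
    # f(x) = 1.7159 * tanh(2x / 3)
    return 1.7159 * jnp.tanh(2.0 * x / 3.0)


y = normalized_tanh_ref(jnp.array([-1.0, 0.0, 1.0]))  # approx. [-1.0, 0.0, 1.0]
y_far = normalized_tanh_ref(jnp.array([50.0]))         # approx. [1.7159]
```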

soft_exponential

soft_exponential(x: Array, alpha: float = 0.0) -> Array

Soft exponential activation function.

This is a parameterized activation that interpolates between different behaviors based on the alpha parameter.

Mathematical definition:
  • If α < 0: f(x) = -ln(1 - α(x + α)) / α
  • If α = 0: f(x) = x
  • If α > 0: f(x) = (exp(αx) - 1) / α + α

Parameters:

Name Type Description Default
x Array

Input array

required
alpha float

Shape parameter α

0.0

Returns:

Type Description
Array

Output array with soft exponential applied
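The three branches can be checked against special cases. α = 1 reduces the formula to exp(x), α = -1 reduces it to ln(x) for x > 0, and α = 0 gives the identity. The sketch below uses the hypothetical name soft_exponential_ref. It branches on a Python float α, so it assumes α is a fixed hyperparameter rather than a traced, trainable array.

```python
import jax.numpy as jnp


def soft_exponential_ref(x, alpha=0.0):
    # Branch on the sign of alpha (a Python float, not a traced array)
    if alpha < 0:
        return -jnp.log(1.0 - alpha * (x + alpha)) / alpha
    if alpha == 0:
        return x
    return (jnp.exp(alpha * x) - 1.0) / alpha + alpha


x = jnp.array([0.5, 1.0, 2.0])
identity = soft_exponential_ref(x, alpha=0.0)  # x
as_exp = soft_exponential_ref(x, alpha=1.0)    # exp(x)
as_log = soft_exponential_ref(x, alpha=-1.0)   # ln(x)
```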

get_derivative_activation

get_derivative_activation(name: str) -> Any

Get the derivative of an activation function.

This is useful for implementations that need explicit derivatives rather than relying on automatic differentiation.

Parameters:

Name Type Description Default
name str

Name of the activation function

required

Returns:

Type Description
Any

Derivative function of the specified activation

Raises:

Type Description
ValueError

If activation name is not recognized or derivative not available
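Explicit derivatives are easy to cross-check against automatic differentiation. The sketch below uses a hypothetical table of closed-form derivatives for tanh and sigmoid, which is not the opifex table. It compares each entry with jax.grad and raises ValueError for names without an entry, mirroring the documented error behavior.

```python
import jax
import jax.numpy as jnp

# Hypothetical closed-form derivatives, keyed by activation name
_DERIVATIVES = {
    "tanh": lambda x: 1.0 - jnp.tanh(x) ** 2,
    "sigmoid": lambda x: jax.nn.sigmoid(x) * (1.0 - jax.nn.sigmoid(x)),
}
_FUNCTIONS = {"tanh": jnp.tanh, "sigmoid": jax.nn.sigmoid}


def derivative_of(name: str):
    if name not in _DERIVATIVES:
        raise ValueError(f"No explicit derivative for {name!r}")
    return _DERIVATIVES[name]


# Largest gap between the closed form and autodiff on a grid of points
xs = jnp.linspace(-3.0, 3.0, 13)
max_errors = {
    name: float(jnp.max(jnp.abs(derivative_of(name)(xs) - jax.vmap(jax.grad(fn))(xs))))
    for name, fn in _FUNCTIONS.items()
}
```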