Neural Network API Reference
The opifex.neural package provides the building blocks for scientific machine learning models, built on top of Flax NNX.
Base Architectures
Standard MLP
opifex.neural.base.StandardMLP
StandardMLP(layer_sizes: list[int], activation: str = 'gelu', dropout_rate: float = 0.0, use_bias: bool = True, apply_final_dropout: bool = False, *, dtype: Any | None = None, param_dtype: Any = float32, rngs: Rngs, kernel_init: Callable = xavier_uniform(), bias_init: Callable = zeros)
Bases: Module
Multi-layer perceptron implemented with Flax NNX.
Follows Flax NNX best practices, including:
- Proper RNG handling with a keyword-only rngs parameter
- Modern activation functions (GELU by default, configurable)
- Dropout with deterministic (training vs. inference) control
- Custom initialization strategies following NNX patterns
- Automatic differentiation with JAX
- Performance-optimized state management
Attributes:
| Name | Type | Description |
|---|---|---|
| layer_sizes | | List of layer sizes including input and output dimensions |
| activation | | Name of the activation function to use |
| dropout_rate | | Dropout probability (0.0 means no dropout) |
| use_bias | | Whether to include bias terms in linear layers |
| apply_final_dropout | | Whether to apply dropout after the final layer |
| layers | | Sequence of linear transformation layers |
| activation_fn | | The actual activation function |
| dropout | Dropout \| None | Dropout layer (None if dropout_rate is 0) |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| layer_sizes | list[int] | List of layer sizes, e.g., [input_dim, hidden1, hidden2, output_dim] | required |
| activation | str | Activation function name ('gelu', 'tanh', 'relu', 'sigmoid', 'silu'); 'gelu' is the default for modern neural networks | 'gelu' |
| dropout_rate | float | Dropout probability for regularization (0.0 = no dropout) | 0.0 |
| use_bias | bool | Whether to use bias in linear projections | True |
| apply_final_dropout | bool | Whether to apply dropout after the final layer (useful for some transformer-style architectures) | False |
| dtype | Any \| None | Computation dtype for NNX linear layers | None |
| param_dtype | Any | Parameter storage dtype for NNX linear layers | float32 |
| rngs | Rngs | Flax NNX random number generator state (keyword-only) | required |
| kernel_init | Callable | Kernel initialization function | xavier_uniform() |
| bias_init | Callable | Bias initialization function | zeros |
Source code in opifex/neural/base.py
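As a rough guide to how layer_sizes translates into parameters, the minimal sketch below assumes one linear layer per consecutive pair of sizes (so [input_dim, hidden1, hidden2, output_dim] yields three layers). It computes the kernel shapes, the parameter count with and without bias, and the bound sqrt(6 / (fan_in + fan_out)) used by Glorot/Xavier uniform initialization, the default kernel_init. The helper names are hypothetical and not part of opifex.

```python
import math

# Hypothetical helpers for illustration only; they are not part of opifex.


def linear_layer_shapes(layer_sizes: list[int]) -> list[tuple[int, int]]:
    """(fan_in, fan_out) per layer, assuming one linear layer per consecutive pair."""
    if len(layer_sizes) < 2:
        raise ValueError("layer_sizes needs at least input and output dimensions")
    return list(zip(layer_sizes[:-1], layer_sizes[1:]))


def count_parameters(layer_sizes: list[int], use_bias: bool = True) -> int:
    """Number of kernel entries plus (optionally) bias entries."""
    return sum(
        fan_in * fan_out + (fan_out if use_bias else 0)
        for fan_in, fan_out in linear_layer_shapes(layer_sizes)
    )


def xavier_uniform_limit(fan_in: int, fan_out: int) -> float:
    """Glorot/Xavier uniform draws weights from U(-limit, limit) with this limit."""
    return math.sqrt(6.0 / (fan_in + fan_out))


sizes = [2, 64, 64, 1]  # e.g. (x, t) -> u for a 1D time-dependent PDE
print(linear_layer_shapes(sizes))               # [(2, 64), (64, 64), (64, 1)]
print(count_parameters(sizes))                  # 4417
print(count_parameters(sizes, use_bias=False))  # 4288
print(round(xavier_uniform_limit(64, 64), 4))   # 0.2165
```

Dropping the bias removes exactly fan_out parameters per layer, which accounts for the difference of 129 (64 + 64 + 1) above.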
Quantum MLP
opifex.neural.base.QuantumMLP
QuantumMLP(layer_sizes: list[int], activation: str = 'tanh', enforce_symmetry: bool = True, dropout_rate: float = 0.0, use_bias: bool = True, apply_final_dropout: bool = False, symmetry_type: str = 'permutation', *, rngs: Rngs, kernel_init: Callable = xavier_uniform(), bias_init: Callable = zeros)
Bases: Module
Quantum-aware multi-layer perceptron for molecular and quantum systems.
Follows Flax NNX best practices while adding quantum-specific features:
- Proper RNG handling with a keyword-only rngs parameter
- Symmetry enforcement for molecular systems
- Specialized initialization for quantum properties
- Physics-informed constraints with numerical stability
- Dropout with deterministic (training vs. inference) control
- Quantum-specific energy and force computation methods
Attributes:
| Name | Type | Description |
|---|---|---|
| layer_sizes | | List of layer sizes including input and output dimensions |
| activation | | Activation function name |
| enforce_symmetry | | Whether to enforce permutation symmetry |
| dropout_rate | | Dropout probability for regularization |
| use_bias | | Whether to use bias in linear layers |
| apply_final_dropout | | Whether to apply dropout after the final layer |
| layers | | Sequence of linear layers |
| activation_fn | | Activation function |
| dropout | Dropout \| None | Dropout layer (if dropout_rate > 0) |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| layer_sizes | list[int] | List of layer sizes for the network architecture | required |
| activation | str | Activation function name ('gelu', 'tanh', 'relu', 'sigmoid', 'silu'); 'tanh' is the default for quantum neural networks | 'tanh' |
| enforce_symmetry | bool | Whether to enforce molecular symmetries | True |
| dropout_rate | float | Dropout probability for regularization (0.0 = no dropout) | 0.0 |
| use_bias | bool | Whether to use bias in linear projections | True |
| apply_final_dropout | bool | Whether to apply dropout after the final layer (useful for quantum transformer-style architectures) | False |
| symmetry_type | str | Type of symmetry to enforce ('permutation', 'rotation', 'both') | 'permutation' |
| rngs | Rngs | Flax NNX random number generator state (keyword-only) | required |
| kernel_init | Callable | Kernel initialization function (quantum-aware) | xavier_uniform() |
| bias_init | Callable | Bias initialization function (quantum-aware) | zeros |
Source code in opifex/neural/base.py
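Setting enforce_symmetry=True with symmetry_type='permutation' asks for predictions that do not change when identical atoms are relabelled. The minimal sketch below shows one common way to obtain that property: sum a shared function over all atom pairs so that only interatomic distances matter. This construction is an assumption for illustration, not necessarily how QuantumMLP enforces symmetry, and pair_energy is a hypothetical stand-in for a learned network.

```python
import itertools
import math
import random


def pair_energy(r: float) -> float:
    """Hypothetical stand-in for a learned pair term (a Morse-like curve)."""
    return (1.0 - math.exp(-(r - 1.0))) ** 2


def invariant_energy(positions: list[tuple[float, float, float]]) -> float:
    """Sum one shared pair term over all unordered atom pairs.

    The sum does not depend on atom order (permutation symmetry), and because
    only interatomic distances enter, it is also unchanged by rigid rotations
    and translations.
    """
    return sum(
        pair_energy(math.dist(a, b))
        for a, b in itertools.combinations(positions, 2)
    )


rng = random.Random(0)
atoms = [tuple(rng.uniform(-1.0, 1.0) for _ in range(3)) for _ in range(5)]
shuffled = atoms[:]
rng.shuffle(shuffled)
rotated = [(-y, x, z) for x, y, z in atoms]  # 90-degree rotation about the z axis

print(math.isclose(invariant_energy(atoms), invariant_energy(shuffled)))  # True
print(math.isclose(invariant_energy(atoms), invariant_energy(rotated)))   # True
```

Because only distances enter, the same construction also gives rotational invariance; an MLP applied directly to raw Cartesian coordinates has neither property unless it is symmetrized in some such way.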
compute_energy
Compute energy for given atomic positions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| positions | Array | Atomic positions array of shape (batch, n_atoms, 3), (n_atoms, 3), or flattened (n_atoms*3,) | required |
| deterministic | bool | Whether to use deterministic mode (True for inference) | True |
Returns:
| Type | Description |
|---|---|
| Array | Energy array with shape (batch_size, 1) for consistency |
Source code in opifex/neural/base.py
compute_forces
Compute forces as negative gradient of energy.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| positions | Array | Atomic positions array of shape (batch, n_atoms, 3), (n_atoms, 3), or flattened (n_atoms*3,) | required |
| deterministic | bool | Whether to use deterministic mode (True for inference) | True |
Returns:
| Type | Description |
|---|---|
| Array | Forces array of shape (batch, n_atoms, 3), (n_atoms, 3), or (n_atoms*3,), matching the input shape |
Source code in opifex/neural/base.py
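The relation F = -dE/dx can be checked numerically. This minimal sketch uses a hypothetical harmonic pair energy whose gradient is known in closed form, and compares the hand-derived forces with central finite differences on the flattened (n_atoms*3,) layout accepted above. It only illustrates the definition; QuantumMLP presumably obtains the gradient of its learned energy through JAX automatic differentiation rather than finite differences.

```python
import itertools
import math


def split_atoms(flat: list[float]) -> list[list[float]]:
    """Reshape a flattened (n_atoms*3,) vector into n_atoms rows of 3."""
    return [flat[i:i + 3] for i in range(0, len(flat), 3)]


def energy(flat: list[float], k: float = 1.0, r0: float = 1.0) -> float:
    """Hypothetical harmonic pair energy E = sum_{i<j} k/2 * (r_ij - r0)^2."""
    return sum(
        0.5 * k * (math.dist(a, b) - r0) ** 2
        for a, b in itertools.combinations(split_atoms(flat), 2)
    )


def analytic_forces(flat: list[float], k: float = 1.0, r0: float = 1.0) -> list[float]:
    """F = -dE/dx, derived by hand for the harmonic pair energy."""
    atoms = split_atoms(flat)
    forces = [[0.0, 0.0, 0.0] for _ in atoms]
    for i, j in itertools.combinations(range(len(atoms)), 2):
        r = math.dist(atoms[i], atoms[j])
        coeff = -k * (r - r0) / r  # -(dE/dr) / r
        for d in range(3):
            diff = atoms[i][d] - atoms[j][d]
            # Equal and opposite contributions on the two atoms of the pair.
            forces[i][d] += coeff * diff
            forces[j][d] -= coeff * diff
    return [f for row in forces for f in row]  # flatten back to (n_atoms*3,)


def finite_difference_forces(flat: list[float], h: float = 1e-6) -> list[float]:
    """Central differences: F_k = -(E(x + h e_k) - E(x - h e_k)) / (2h)."""
    forces = []
    for idx in range(len(flat)):
        plus, minus = list(flat), list(flat)
        plus[idx] += h
        minus[idx] -= h
        forces.append(-(energy(plus) - energy(minus)) / (2.0 * h))
    return forces


positions = [0.0, 0.0, 0.0, 1.3, 0.0, 0.0, 0.2, 0.9, 0.1]  # 3 atoms, flattened
fa = analytic_forces(positions)
fd = finite_difference_forces(positions)
print(max(abs(a - b) for a, b in zip(fa, fd)) < 1e-6)  # True
```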
compute_energy_and_forces
Efficiently compute both energy and forces.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| positions | Array | Atomic positions array of shape (n_atoms, 3) or flattened (n_atoms*3,) | required |
| deterministic | bool | Whether to use deterministic mode (True for inference) | True |
Returns:
| Type | Description |
|---|---|
| tuple[Array, Array] | Tuple of (energy, forces), where energy is a scalar and forces has shape (n_atoms, 3) or (n_atoms*3,), matching the input |
Source code in opifex/neural/base.py
Neural Quantum
opifex.neural.quantum
Neural quantum chemistry modules for scientific machine learning.
NeuralDFT
NeuralDFT(*, grid_size: int = 1000, convergence_threshold: float = 1e-08, max_scf_iterations: int = 100, xc_functional_type: str = 'neural', mixing_strategy: str = 'neural', use_neural_scf: bool = True, chemical_accuracy_target: float = 0.043, enable_high_precision: bool = True, rngs: Rngs)
Bases: Module
Neural Density Functional Theory Framework.
Integrates neural exchange-correlation functionals with neural-enhanced SCF solvers for efficient DFT calculations. Designed for chemical accuracy in quantum molecular systems with proper handling of high-precision calculations and quantum constraints.
Fully compliant with modern Flax NNX patterns.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| grid_size | int | Size of the electron density grid | 1000 |
| convergence_threshold | float | SCF convergence threshold in Hartree | 1e-08 |
| max_scf_iterations | int | Maximum number of SCF iterations | 100 |
| xc_functional_type | str | Type of XC functional ("neural", "lda", "pbe") | 'neural' |
| mixing_strategy | str | Density mixing strategy ("neural", "diis", "simple") | 'neural' |
| use_neural_scf | bool | Whether to use neural SCF solver enhancements | True |
| chemical_accuracy_target | float | Target accuracy in Hartree (for reference, chemical accuracy, 1 kcal/mol, is about 1.6e-3 Hartree or about 0.043 eV) | 0.043 |
| enable_high_precision | bool | Whether to use float64 for critical calculations | True |
| rngs | Rngs | Random number generators (keyword-only) | required |
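Chemical accuracy is conventionally defined as 1 kcal/mol. Because chemical_accuracy_target is documented in Hartree with a default of 0.043, a quick unit check can help when choosing a value: 1 kcal/mol is about 1.6e-3 Hartree but about 0.043 eV. The snippet below performs the conversion with standard constants; which unit NeuralDFT applies internally is best confirmed against its source.

```python
# Standard conversion factors (derived from CODATA 2018 values).
HARTREE_TO_KCAL_PER_MOL = 627.5094740631
HARTREE_TO_EV = 27.211386245988


def kcal_per_mol_to_hartree(value: float) -> float:
    return value / HARTREE_TO_KCAL_PER_MOL


def kcal_per_mol_to_ev(value: float) -> float:
    return kcal_per_mol_to_hartree(value) * HARTREE_TO_EV


chemical_accuracy = 1.0  # kcal/mol
print(f"{kcal_per_mol_to_hartree(chemical_accuracy):.2e} Hartree")  # 1.59e-03 Hartree
print(f"{kcal_per_mol_to_ev(chemical_accuracy):.4f} eV")            # 0.0434 eV
```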
compute_energy
compute_energy(molecular_system: Any, *, density: Array | None = None, deterministic: bool = True) -> DFTResult
Compute total energy using neural DFT with enhanced precision.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| molecular_system | Any | Molecular system to compute | required |
| density | Array \| None | Optional initial density guess | None |
| deterministic | bool | Whether to use deterministic mode | True |
Returns:
| Type | Description |
|---|---|
| DFTResult | DFT calculation result with precision diagnostics |
NeuralSCFSolver
NeuralSCFSolver(convergence_threshold: float = 1e-08, max_iterations: int = 100, mixing_strategy: str = 'neural', grid_size: int = 1000, chemical_accuracy_target: float = 1e-06, *, rngs: Rngs)
Bases: Module
Neural-enhanced self-consistent field (SCF) solver with full convergence analysis.
Implements neural acceleration of SCF convergence through:
1. Intelligent density mixing using neural networks
2. Advanced convergence prediction with chemical accuracy assessment
3. Stability monitoring and adaptive recovery mechanisms
4. High-precision numerical methods for quantum accuracy
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| convergence_threshold | float | Energy convergence threshold | 1e-08 |
| max_iterations | int | Maximum number of SCF iterations | 100 |
| mixing_strategy | str | Density mixing strategy ("neural" or "linear") | 'neural' |
| grid_size | int | Size of density grid for molecular calculations | 1000 |
| chemical_accuracy_target | float | Target accuracy for chemical predictions | 1e-06 |
| rngs | Rngs | Random number generators for neural components | required |
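To make convergence_threshold, max_iterations, and a non-neural mixing strategy concrete, the sketch below runs a toy self-consistent loop with simple linear mixing, rho_next = (1 - alpha) * rho + alpha * F(rho). toy_density_map is a hypothetical stand-in for building the Hamiltonian from rho, solving it, and rebuilding rho. The fixed alpha is the part a learned mixer would presumably replace when mixing_strategy='neural'. Unlike the energy threshold documented above, this toy version monitors the largest density change.

```python
import math


def toy_density_map(density: list[float]) -> list[float]:
    """Hypothetical stand-in for one SCF update (build H from rho, solve, rebuild rho).

    It conserves the total "electron count" sum(density).
    """
    total = sum(density)
    weights = [math.exp(-0.5 * d) for d in density]
    norm = sum(weights)
    return [total * w / norm for w in weights]


def scf_linear_mixing(density, alpha=0.5, threshold=1e-8, max_iterations=100):
    """Iterate rho <- (1 - alpha) * rho + alpha * F(rho) until the update is small."""
    for iteration in range(1, max_iterations + 1):
        new_density = toy_density_map(density)
        residual = max(abs(n - o) for n, o in zip(new_density, density))
        density = [(1 - alpha) * o + alpha * n for o, n in zip(density, new_density)]
        if residual < threshold:
            return density, iteration, True
    return density, max_iterations, False


rho, n_iter, converged = scf_linear_mixing([2.0, 0.5, 1.0, 0.5])
print(converged, n_iter, [round(r, 6) for r in rho])
```

For this toy map the unique fixed point is the uniform density, so the loop should settle at 1.0 per grid point while conserving the total of 4.0.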
solve_scf
solve_scf(molecular_system: MolecularSystem, initial_density: Array, hamiltonian_fn: Callable | None = None, *, deterministic: bool = False) -> SCFResult
Solve SCF equations with neural acceleration and full analysis.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| molecular_system | MolecularSystem | Molecular system to solve | required |
| initial_density | Array | Initial electron density guess | required |
| hamiltonian_fn | Callable \| None | Custom Hamiltonian function (optional) | None |
| deterministic | bool | Whether to use deterministic computation | False |
Returns:
| Type | Description |
|---|---|
| SCFResult | Full SCF result with convergence analysis |
predict_convergence_iterations
predict_convergence_iterations(molecular_system: MolecularSystem, initial_density: Array, *, deterministic: bool = False) -> int
Predict number of iterations required for convergence.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| molecular_system | MolecularSystem | Molecular system to analyze | required |
| initial_density | Array | Initial density guess | required |
| deterministic | bool | Whether to use deterministic computation | False |
Returns:
| Type | Description |
|---|---|
| int | Predicted number of iterations for convergence |
NeuralXCFunctional
NeuralXCFunctional(hidden_sizes: Sequence[int] = (128, 128, 64), activation: Callable = gelu, use_attention: bool = True, num_attention_heads: int = 8, use_advanced_features: bool = True, dropout_rate: float = 0.0, *, rngs: Rngs)
Bases: Module
Neural exchange-correlation functional for DFT calculations.
Implements a modern neural XC functional with attention mechanisms for capturing non-local correlations, enhanced physics constraints, and chemical accuracy optimization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_sizes | Sequence[int] | Sequence of hidden layer sizes | (128, 128, 64) |
| activation | Callable | Activation function to use | gelu |
| use_attention | bool | Whether to use an attention mechanism for non-local correlations | True |
| num_attention_heads | int | Number of attention heads | 8 |
| use_advanced_features | bool | Whether to include advanced physics features | True |
| dropout_rate | float | Dropout rate for regularization | 0.0 |
| rngs | Rngs | Random number generators | required |
compute_functional_derivative
compute_functional_derivative(density: Array, gradients: Array, *, deterministic: bool = False) -> Array
Compute functional derivative of XC energy with respect to density.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| density | Array | Electron density | required |
| gradients | Array | Density gradients | required |
| deterministic | bool | Whether to use deterministic computation | False |
Returns:
| Type | Description |
|---|---|
| Array | Functional derivative ∂E_xc/∂ρ with enhanced numerical stability |
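For intuition about what compute_functional_derivative returns, consider the local density approximation (LDA) exchange energy E_x = -(3/4)(3/π)^(1/3) ∫ ρ^(4/3) dr, whose functional derivative is v_x = δE_x/δρ = -(3/π)^(1/3) ρ^(1/3). On a grid with quadrature weights w_i, E_x ≈ Σ_i w_i e(ρ_i), so (∂E_x/∂ρ_i) / w_i recovers v_x(ρ_i). The minimal sketch below checks this with finite differences. The analytic LDA formula stands in for the neural functional, and the density values and weights are made up.

```python
import math

C_X = 0.75 * (3.0 / math.pi) ** (1.0 / 3.0)  # LDA (Dirac/Slater) exchange constant


def lda_exchange_energy(density: list[float], weights: list[float]) -> float:
    """Discretized E_x = -C_x * sum_i w_i * rho_i^(4/3)."""
    return -C_X * sum(w * rho ** (4.0 / 3.0) for rho, w in zip(density, weights))


def lda_exchange_potential(density: list[float]) -> list[float]:
    """Analytic functional derivative v_x = -(3/pi)^(1/3) * rho^(1/3)."""
    return [-(3.0 / math.pi) ** (1.0 / 3.0) * rho ** (1.0 / 3.0) for rho in density]


def numerical_derivative(density: list[float], weights: list[float], h: float = 1e-6) -> list[float]:
    """(dE/drho_i) / w_i by central differences; dividing by w_i removes the grid weight."""
    result = []
    for i, w in enumerate(weights):
        plus, minus = list(density), list(density)
        plus[i] += h
        minus[i] -= h
        d_energy = lda_exchange_energy(plus, weights) - lda_exchange_energy(minus, weights)
        result.append(d_energy / (2.0 * h) / w)
    return result


rho = [0.05, 0.2, 0.8, 1.5]  # made-up density values (atomic units)
w = [0.1, 0.1, 0.2, 0.2]     # made-up quadrature weights
analytic = lda_exchange_potential(rho)
numeric = numerical_derivative(rho, w)
print(max(abs(a - n) for a, n in zip(analytic, numeric)) < 1e-6)  # True
```

The division by w_i is the step that turns a gradient with respect to grid values into a pointwise potential, which is what makes the result independent of how finely the grid is resolved.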
assess_chemical_accuracy
assess_chemical_accuracy(density: Array, gradients: Array, reference_energy: Array | None = None, *, deterministic: bool = False) -> dict[str, float]
Assess chemical accuracy of XC functional predictions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| density | Array | Electron density | required |
| gradients | Array | Density gradients | required |
| reference_energy | Array \| None | Reference XC energy for comparison (optional) | None |
| deterministic | bool | Whether to use deterministic computation | False |
Returns:
| Type | Description |
|---|---|
| dict[str, float] | Dictionary containing accuracy metrics |
Neural Operators
opifex.neural.operators
Opifex Neural Operators: Full Operator Learning Library
This module provides a broad collection of neural operators for scientific machine learning, including many of the major variants from the neuraloperator repository as well as additional architectures.
The library includes:
- Fourier Neural Operators (FNO, TFNO, U-FNO, SFNO, Local FNO, AM-FNO)
- Deep Operator Networks (DeepONet and variants)
- Specialized operators (GINO, MGNO, UQNO, LNO, WNO, GNO)
- Physics-informed operators (PINO)
- Graph-based operators
- Uncertainty quantification operators
All operators are built with JAX/FLAX NNX for high performance and support automatic differentiation, just-in-time compilation, and multi-device parallelization.
AdaptiveDeepONet
AdaptiveDeepONet(branch_input_dim: int, trunk_input_dim: int, base_latent_dim: int, *, num_resolution_levels: int = 3, adaptive_latent_scaling: bool = True, use_residual_connections: bool = True, activation: str = 'tanh', rngs: Rngs)
Bases: Module
Adaptive DeepONet with dynamic architecture adjustment.
This variant can adapt its architecture based on problem complexity and provides multiple resolution levels for different accuracy requirements.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| branch_input_dim | int | Branch network input dimension | required |
| trunk_input_dim | int | Trunk network input dimension | required |
| base_latent_dim | int | Base latent dimension (scaled for different levels) | required |
| num_resolution_levels | int | Number of resolution levels | 3 |
| adaptive_latent_scaling | bool | Whether to scale latent dimensions adaptively | True |
| use_residual_connections | bool | Whether to use residual connections | True |
| activation | str | Activation function name | 'tanh' |
| rngs | Rngs | Random number generators | required |
DeepONet
DeepONet(branch_sizes: list[int], trunk_sizes: list[int], *, activation: str = 'gelu', output_activation: str | None = None, use_bias: bool = True, rngs: Rngs)
Bases: Module
Deep Operator Network for learning function-to-function mappings.
DeepONet learns to approximate nonlinear operators G that map functions to functions: G: u → G(u), where u and G(u) are functions.
The architecture consists of:
- Branch network: processes the input function u evaluated at sensors
- Trunk network: processes the evaluation locations y
- Dot product combination of the branch and trunk outputs
Fully compliant with modern Flax NNX patterns.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| branch_sizes | list[int] | Layer sizes for the branch network: [input_sensors, hidden1, hidden2, ..., output_dim] | required |
| trunk_sizes | list[int] | Layer sizes for the trunk network: [location_dim, hidden1, hidden2, ..., output_dim]. Its output_dim must match the branch output_dim | required |
| activation | str | Activation function name for hidden layers | 'gelu' |
| output_activation | str \| None | Optional activation for the final output (None means no activation on the output) | None |
| use_bias | bool | Whether to use bias in linear layers | True |
| rngs | Rngs | Random number generators (keyword-only) | required |
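The dot-product combination can be written as G(u)(y) ≈ Σ_k b_k(u) t_k(y), where b = branch(u(x_1), ..., u(x_m)) and t = trunk(y) share the same output dimension p (some formulations also add a scalar bias). The sketch below uses single tanh layers with fixed random weights in place of trained branch and trunk MLPs, purely to show the data flow and why the last entries of branch_sizes and trunk_sizes must agree. It is not the opifex implementation.

```python
import math
import random


def random_tanh_layer(in_dim: int, out_dim: int, rng: random.Random):
    """A single tanh layer with fixed random weights (stand-in for a trained MLP)."""
    weights = [[rng.gauss(0.0, 1.0 / math.sqrt(in_dim)) for _ in range(in_dim)]
               for _ in range(out_dim)]
    return lambda v: [math.tanh(sum(w * x for w, x in zip(row, v))) for row in weights]


def make_deeponet(num_sensors: int, location_dim: int, p: int, seed: int = 0):
    rng = random.Random(seed)
    branch = random_tanh_layer(num_sensors, p, rng)  # u(x_1..x_m) -> b in R^p
    trunk = random_tanh_layer(location_dim, p, rng)  # y -> t in R^p

    def operator(u_at_sensors: list[float], y: list[float]) -> float:
        b, t = branch(u_at_sensors), trunk(y)
        if len(b) != len(t):
            raise ValueError("branch and trunk output dimensions must match")
        return sum(bk * tk for bk, tk in zip(b, t))  # G(u)(y) = <b, t>

    return operator


sensors = [i / 9 for i in range(10)]          # m = 10 fixed sensor locations
u = [math.sin(math.pi * x) for x in sensors]  # input function sampled at the sensors
G = make_deeponet(num_sensors=10, location_dim=1, p=8)
print([round(G(u, [y]), 3) for y in (0.25, 0.5, 0.75)])  # G(u) at three query points
```

The branch output b depends only on u, so it can be computed once per input function and reused for any number of query points y.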
FourierEnhancedDeepONet
FourierEnhancedDeepONet(branch_sizes: list[int], trunk_sizes: list[int], *, fourier_modes: int = 16, use_spectral_branch: bool = True, use_spectral_trunk: bool = False, activation: str = 'tanh', rngs: Rngs)
Bases: Module
Fourier-Enhanced DeepONet combining spectral and operator learning.
This variant integrates Fourier Neural Operator concepts into DeepONet architecture for improved performance on problems with spectral structure.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| branch_sizes | list[int] | Branch network layer sizes [input, hidden..., output] | required |
| trunk_sizes | list[int] | Trunk network layer sizes [input, hidden..., output] | required |
| fourier_modes | int | Number of Fourier modes for spectral layers | 16 |
| use_spectral_branch | bool | Whether to use spectral convolution in the branch | True |
| use_spectral_trunk | bool | Whether to use spectral convolution in the trunk | False |
| activation | str | Activation function name | 'tanh' |
| rngs | Rngs | Random number generators | required |
MultiPhysicsDeepONet
MultiPhysicsDeepONet(branch_input_dim: int, trunk_input_dim: int, branch_hidden_dims: list[int], trunk_hidden_dims: list[int], latent_dim: int, *, num_physics_systems: int = 1, use_attention: bool = True, attention_heads: int = 8, physics_constraints: list[str] | None = None, sensor_optimization: bool = False, num_sensors: int | None = None, activation: Callable[[Array], Array] = tanh, rngs: Rngs)
Bases: Module
Enhanced DeepONet with multi-physics support and attention mechanisms.
Extends the basic DeepONet architecture with physics-aware attention, multi-physics coupling, and sensor optimization for improved operator learning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| branch_input_dim | int | Branch network input dimension | required |
| trunk_input_dim | int | Trunk network input dimension | required |
| branch_hidden_dims | list[int] | Branch network hidden dimensions | required |
| trunk_hidden_dims | list[int] | Trunk network hidden dimensions | required |
| latent_dim | int | Latent dimension for the inner product | required |
| num_physics_systems | int | Number of physics systems to handle | 1 |
| use_attention | bool | Whether to use physics-aware attention | True |
| attention_heads | int | Number of attention heads | 8 |
| physics_constraints | list[str] \| None | List of physics constraints to enforce | None |
| sensor_optimization | bool | Whether to use sensor optimization | False |
| num_sensors | int \| None | Number of sensors (required if sensor_optimization=True) | None |
| activation | Callable[[Array], Array] | Activation function | tanh |
| rngs | Rngs | Random number generators | required |
AmortizedFourierNeuralOperator
AmortizedFourierNeuralOperator(in_channels: int, out_channels: int, hidden_channels: int = 32, modes: Sequence[int] = (16, 16), num_layers: int = 4, kernel_hidden_dim: int = 128, kernel_layers: int = 3, max_frequency: float = 10.0, activation: Callable = gelu, use_layer_norm: bool = False, use_kernel_regularization: bool = True, *, rngs: Rngs)
Bases: Module
Amortized Fourier Neural Operator with neural kernel parameterization.
AmortizedSpectralConvolution
AmortizedSpectralConvolution(in_channels: int, out_channels: int, modes: Sequence[int], kernel_hidden_dim: int = 128, kernel_layers: int = 3, max_frequency: float = 10.0, use_kernel_regularization: bool = True, *, rngs: Rngs)
Bases: Module
Amortized spectral convolution with neural kernel parameterization.
KernelNetwork
KernelNetwork(freq_dim: int, output_dim: int, hidden_dim: int = 128, num_layers: int = 3, activation: Callable = gelu, use_frequency_encoding: bool = True, max_frequency: float = 10.0, *, rngs: Rngs)
Bases: Module
Neural network to parameterize Fourier kernels.
FourierLayer
FourierLayer(in_channels: int, out_channels: int, modes: int, *, activation: Callable[[Array], Array] = gelu, spatial_dims: int = 2, rngs: Rngs)
Bases: Module
Fourier layer combining spectral convolution with activation.
This layer performs:
1. FFT to transform the input to the spectral domain
2. Spectral convolution
3. IFFT to transform back to the spatial domain
4. Linear transformation and activation with a residual connection
Fully compliant with modern Flax NNX patterns.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels | required |
| out_channels | int | Number of output channels | required |
| modes | int | Number of Fourier modes | required |
| activation | Callable[[Array], Array] | Activation function | gelu |
| spatial_dims | int | Number of spatial dimensions (1, 2, or 3). Controls which spectral weights are allocated, so no unused parameters are created | 2 |
| rngs | Rngs | Random number generators (keyword-only) | required |
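The four steps listed for FourierLayer can be traced in one dimension with a naive discrete Fourier transform. This minimal sketch assumes the original FNO recipe: keep the lowest modes frequencies, multiply each kept mode by its own complex (out_channels × in_channels) matrix, transform back, then add a pointwise linear skip path and apply GELU. The weights are random placeholders, and a practical implementation would use an FFT such as jnp.fft.rfft and jnp.fft.irfft instead of this O(N·modes) loop.

```python
import cmath
import math
import random


def dft_low_modes(signal: list[float], modes: int) -> list[complex]:
    """First `modes` DFT coefficients of a real signal (naive O(N * modes) loop)."""
    n = len(signal)
    return [
        sum(x * cmath.exp(-2j * math.pi * k * t / n) for t, x in enumerate(signal))
        for k in range(modes)
    ]


def inverse_dft_low_modes(coeffs: list[complex], n: int) -> list[float]:
    """Real signal rebuilt from the kept low modes; all higher modes are treated as zero."""
    return [
        (coeffs[0].real
         + 2.0 * sum((coeffs[k] * cmath.exp(2j * math.pi * k * t / n)).real
                     for k in range(1, len(coeffs)))) / n
        for t in range(n)
    ]


def gelu(x: float) -> float:
    """Exact (erf-based) GELU."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def fourier_layer(x, spectral_w, skip_w, modes):
    """x: [in][N] reals; spectral_w: [modes][out][in] complex; skip_w: [out][in] reals."""
    n = len(x[0])
    if not 1 <= modes <= n // 2:
        raise ValueError("need 1 <= modes <= N // 2 (stay below the Nyquist frequency)")
    n_in, n_out = len(x), len(skip_w)
    x_hat = [dft_low_modes(channel, modes) for channel in x]  # 1. FFT
    # 2. Spectral convolution: mix channels independently at each kept mode.
    out_hat = [[sum(spectral_w[k][o][i] * x_hat[i][k] for i in range(n_in))
                for k in range(modes)] for o in range(n_out)]
    spectral = [inverse_dft_low_modes(c, n) for c in out_hat]  # 3. inverse FFT
    # 4. Pointwise linear skip path, then the activation.
    return [[gelu(spectral[o][t] + sum(skip_w[o][i] * x[i][t] for i in range(n_in)))
             for t in range(n)] for o in range(n_out)]


rng = random.Random(0)
n_in, n_out, n, modes = 2, 3, 16, 4
x = [[math.sin(2 * math.pi * (c + 1) * t / n) for t in range(n)] for c in range(n_in)]
spectral_w = [[[complex(rng.gauss(0, 0.5), rng.gauss(0, 0.5)) for _ in range(n_in)]
               for _ in range(n_out)] for _ in range(modes)]
skip_w = [[rng.gauss(0, 0.5) for _ in range(n_in)] for _ in range(n_out)]
y = fourier_layer(x, spectral_w, skip_w, modes)
print(len(y), len(y[0]))  # 3 16
```

Only the lowest modes frequencies survive step 3, so any content above that cut-off reaches the output solely through the pointwise skip path.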
FourierNeuralOperator
FourierNeuralOperator(in_channels: int, out_channels: int, hidden_channels: int, modes: int, num_layers: int, *, activation: Callable[[Array], Array] = gelu, factorization_type: str | None = None, factorization_rank: int | None = None, use_mixed_precision: bool = False, domain_padding: int = 0, spatial_dims: int = 2, rngs: Rngs)
Bases: Module
Fourier Neural Operator for learning solution operators of PDEs.
Implements the complete FNO architecture with optional tensor factorization and mixed precision training capabilities. Fully compliant with modern Flax NNX patterns.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels | required |
| out_channels | int | Number of output channels | required |
| hidden_channels | int | Number of hidden channels | required |
| modes | int | Number of Fourier modes | required |
| num_layers | int | Number of Fourier layers | required |
| activation | Callable[[Array], Array] | Activation function | gelu |
| factorization_type | str \| None | Optional tensor factorization ('tucker', 'cp') | None |
| factorization_rank | int \| None | Rank for tensor factorization | None |
| use_mixed_precision | bool | Whether to use mixed precision | False |
| domain_padding | int | Number of grid points (pixels) by which to pad the spatial dimensions; padding reduces Gibbs-phenomenon artifacts on non-periodic problems (e.g., 2 for Darcy flow) | 0 |
| spatial_dims | int | Number of spatial dimensions (1, 2, or 3); determines which spectral weights are allocated per layer | 2 |
| rngs | Rngs | Random number generators (keyword-only) | required |
FactorizedFourierLayer
FactorizedFourierLayer(in_channels: int, out_channels: int, modes: int, factorization_type: str, factorization_rank: int, *, activation: Callable[[Array], Array] = gelu, rngs: Rngs)
Bases: Module
Fourier layer with tensor factorization for parameter reduction.
Implements Tucker or CP factorization of the spectral convolution weights to achieve significant parameter reduction (up to 95%) while maintaining performance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels | required |
| out_channels | int | Number of output channels | required |
| modes | int | Number of Fourier modes | required |
| factorization_type | str | Type of factorization ("tucker" or "cp") | required |
| factorization_rank | int | Rank for factorization | required |
| activation | Callable[[Array], Array] | Activation function | gelu |
| rngs | Rngs | Random number generators | required |
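To see where the parameter savings come from, compare a dense spectral weight tensor with its CP and Tucker factorizations. This sketch assumes a 2D layer whose weights form a tensor of shape (in_channels, out_channels, modes, modes), counts each complex entry once, and uses the textbook counts: CP with rank R stores one (dim × R) factor matrix per tensor mode, and Tucker adds an R^4 core. The opifex layer may allocate additional weights (for example, separate blocks for positive and negative frequencies), so treat the numbers as an order-of-magnitude guide.

```python
import math


def dense_params(dims: tuple[int, ...]) -> int:
    """Entries in the full (unfactorized) weight tensor."""
    return math.prod(dims)


def cp_params(dims: tuple[int, ...], rank: int) -> int:
    """CP: one (dim x rank) factor matrix per tensor mode."""
    return rank * sum(dims)


def tucker_params(dims: tuple[int, ...], rank: int) -> int:
    """Tucker: a rank^d core plus one (dim x rank) factor matrix per mode."""
    return rank ** len(dims) + rank * sum(dims)


dims = (64, 64, 16, 16)  # (in_channels, out_channels, modes, modes): assumed 2D layer
full = dense_params(dims)
for name, count in [("cp", cp_params(dims, 16)), ("tucker", tucker_params(dims, 16))]:
    print(f"{name}: {count} params, {100 * (1 - count / full):.2f}% fewer than {full}")
```

With these assumed sizes, rank-16 CP stores 2,560 values and rank-16 Tucker 68,096, versus 1,048,576 for the dense tensor.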
LocalFourierLayer
LocalFourierLayer(in_channels: int, out_channels: int, modes: Sequence[int], kernel_size: int = 3, activation: Callable = gelu, mixing_weight: float = 0.5, *, rngs: Rngs)
Bases: Module
Fourier layer with local convolution for capturing short-range interactions.
Combines global spectral convolution with local spatial convolution for full feature extraction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels | required |
| out_channels | int | Number of output channels | required |
| modes | Sequence[int] | Fourier modes for spectral convolution | required |
| kernel_size | int | Kernel size for local convolution | 3 |
| activation | Callable | Activation function | gelu |
| mixing_weight | float | Weight for combining spectral and local branches | 0.5 |
| rngs | Rngs | Random number generator state | required |
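One plausible reading of mixing_weight, consistent with get_mixing_analysis reporting it as a scalar spectral weight, is a convex combination y = w·global(x) + (1 - w)·local(x). The minimal sketch below assumes exactly that. A fixed projection onto the lowest Fourier modes stands in for the learned spectral branch, and a fixed 3-tap circular kernel (kernel_size=3) stands in for the local branch; the real layer learns both branches and also has an activation.

```python
import cmath
import math


def low_pass(signal: list[float], modes: int) -> list[float]:
    """Global-branch stand-in: project onto the lowest `modes` Fourier modes."""
    n = len(signal)
    if not 1 <= modes <= n // 2:
        raise ValueError("need 1 <= modes <= len(signal) // 2")
    coeffs = [sum(x * cmath.exp(-2j * math.pi * k * t / n) for t, x in enumerate(signal))
              for k in range(modes)]
    return [
        (coeffs[0].real
         + 2.0 * sum((coeffs[k] * cmath.exp(2j * math.pi * k * t / n)).real
                     for k in range(1, modes))) / n
        for t in range(n)
    ]


def local_conv(signal: list[float], kernel: list[float]) -> list[float]:
    """Local-branch stand-in: circular cross-correlation with an odd-length kernel."""
    n, half = len(signal), len(kernel) // 2
    return [sum(w * signal[(t + j - half) % n] for j, w in enumerate(kernel))
            for t in range(n)]


def mixed_layer(signal, modes, kernel, mixing_weight=0.5):
    """y = w * global(x) + (1 - w) * local(x), an assumed convex combination."""
    global_part = low_pass(signal, modes)
    local_part = local_conv(signal, kernel)
    return [mixing_weight * g + (1.0 - mixing_weight) * loc
            for g, loc in zip(global_part, local_part)]


# A smooth wave plus a sharp spike: the low-pass branch keeps the wave and smears
# the spike, while the 3-tap kernel keeps the spike localized.
x = [math.sin(2 * math.pi * t / 16) + (1.0 if t == 5 else 0.0) for t in range(16)]
y = mixed_layer(x, modes=4, kernel=[0.25, 0.5, 0.25], mixing_weight=0.5)
print([round(v, 3) for v in y])
```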
get_mixing_analysis
get_mixing_analysis(x: Array) -> tuple[Array, Array, Array]
Analyze global vs local contributions for this layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor (batch, in_channels, *spatial) | required |
Returns:
| Type | Description |
|---|---|
| tuple[Array, Array, Array] | Tuple of (global_features, local_features, mixing_weights), where mixing_weights is a scalar array holding the spectral weight |
LocalFourierNeuralOperator
LocalFourierNeuralOperator(in_channels: int, out_channels: int, hidden_channels: int, modes: Sequence[int], num_layers: int = 4, kernel_size: int = 3, use_adaptive_mixing: bool = True, use_residual_connections: bool = True, activation: Callable = gelu, *, rngs: Rngs)
Bases: Module
Local Fourier Neural Operator combining global and local operations.
This operator is designed for problems that require both:
- Long-range dependencies (captured by Fourier operations)
- Local features and fine details (captured by convolutions)
Examples include:
- Turbulent flows with both large-scale structures and small eddies
- Wave propagation with local scattering and global modes
- Multi-physics problems with different characteristic scales
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels | required |
| out_channels | int | Number of output channels | required |
| hidden_channels | int | Hidden layer width | required |
| modes | Sequence[int] | Fourier modes for global operations | required |
| num_layers | int | Number of Local Fourier layers | 4 |
| kernel_size | int | Kernel size for local convolutions | 3 |
| use_adaptive_mixing | bool | Whether to use adaptive feature mixing | True |
| use_residual_connections | bool | Whether to use residual connections | True |
| activation | Callable | Activation function | gelu |
| rngs | Rngs | Random number generator state | required |
MultiScaleFourierNeuralOperator
MultiScaleFourierNeuralOperator(in_channels: int, out_channels: int, hidden_channels: int, modes_per_scale: list[int], num_layers_per_scale: list[int], *, spatial_dims: int = 2, activation: Callable[[Array], Array] = gelu, use_cross_scale_attention: bool = True, attention_heads: int = 8, dropout_rate: float = 0.0, use_gradient_checkpointing: bool = True, rngs: Rngs)
Bases: Module
Multi-Scale Fourier Neural Operator for hierarchical resolution handling.
This operator learns operators across multiple scales simultaneously, enabling efficient handling of multi-scale physics problems like turbulence, multi-phase flows, and hierarchical material structures.
Features:
- Hierarchical spectral convolutions at different resolution levels
- Adaptive scale selection based on input characteristics
- Cross-scale information exchange through attention mechanisms
- Memory-efficient implementation with gradient checkpointing
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels | required |
| out_channels | int | Number of output channels | required |
| hidden_channels | int | Hidden channel dimension | required |
| modes_per_scale | list[int] | List of Fourier modes for each scale | required |
| num_layers_per_scale | list[int] | List of layer counts for each scale | required |
| spatial_dims | int | Number of spatial dimensions (1 or 2) | 2 |
| activation | Callable[[Array], Array] | Activation function | gelu |
| use_cross_scale_attention | bool | Whether to use cross-scale attention | True |
| attention_heads | int | Number of attention heads | 8 |
| dropout_rate | float | Dropout rate for regularization | 0.0 |
| use_gradient_checkpointing | bool | Whether to use gradient checkpointing | True |
| rngs | Rngs | Random number generators | required |
SphericalFourierNeuralOperator
SphericalFourierNeuralOperator(in_channels: int, out_channels: int, hidden_channels: int, lmax: int, mmax: int | None = None, num_layers: int = 4, activation: Callable = gelu, use_real_sht: bool = False, *, rngs: Rngs)
Bases: Module
Spherical Fourier Neural Operator for data on spherical domains.
Uses spherical harmonic transforms instead of regular FFTs, which makes it well suited to:
- Global atmospheric modeling
- Ocean circulation
- Planetary science
- Any data naturally defined on a sphere
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels | required |
| out_channels | int | Number of output channels | required |
| hidden_channels | int | Hidden layer width | required |
| lmax | int | Maximum spherical harmonic degree | required |
| mmax | int \| None | Maximum azimuthal order (if None, uses lmax) | None |
| num_layers | int | Number of SFNO layers | 4 |
| activation | Callable | Activation function | gelu |
| use_real_sht | bool | Whether to use real spherical harmonics | False |
| rngs | Rngs | Random number generator state | required |
get_spherical_modes
Get spherical harmonic coefficients for analysis.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor on the sphere | required |
Returns:
| Type | Description |
|---|---|
| Array | Spherical harmonic coefficients |
compute_power_spectrum
Compute spherical harmonic power spectrum.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor on the sphere | required |
Returns:
| Type | Description |
|---|---|
| Array | Power spectrum as a function of spherical harmonic degree l |
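A common definition of the angular power spectrum sums the squared magnitudes of the spherical harmonic coefficients a_lm at each degree l, P(l) = Σ_m |a_lm|² over -l ≤ m ≤ l. Some conventions divide by (2l + 1) to get the mean power per mode. The sketch below applies that definition to a small dictionary of made-up coefficients. The normalization used by compute_power_spectrum is not documented here, so the normalize flag is an assumption that covers both conventions.

```python
def angular_power_spectrum(
    coeffs: dict[tuple[int, int], complex], lmax: int, normalize: bool = False
) -> list[float]:
    """P(l) = sum over m of |a_lm|^2, optionally divided by (2l + 1)."""
    spectrum = [0.0] * (lmax + 1)
    for (degree, order), a in coeffs.items():
        if not (0 <= degree <= lmax and abs(order) <= degree):
            raise ValueError(f"invalid (l, m) = ({degree}, {order})")
        spectrum[degree] += abs(a) ** 2
    if normalize:
        spectrum = [p / (2 * degree + 1) for degree, p in enumerate(spectrum)]
    return spectrum


# Made-up coefficients a_lm for a field band-limited at lmax = 2.
coeffs = {(0, 0): 1.0, (1, -1): 0.5j, (1, 0): 0.5, (1, 1): -0.5j, (2, 2): 0.1 + 0.1j}
print(angular_power_spectrum(coeffs, lmax=2))
print(angular_power_spectrum(coeffs, lmax=2, normalize=True))
```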
SphericalHarmonicConvolution
SphericalHarmonicConvolution(in_channels: int, out_channels: int, lmax: int, mmax: int | None = None, *, rngs: Rngs)
Bases: Module
Spherical harmonic convolution for spherical domains.
Operates in spherical harmonic space analogous to how standard FNO operates in Fourier space, but adapted for spherical geometry.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels | required |
| out_channels | int | Number of output channels | required |
| lmax | int | Maximum spherical harmonic degree (controls resolution) | required |
| mmax | int \| None | Maximum azimuthal order (if None, uses lmax) | None |
| rngs | Rngs | Random number generator state | required |
TensorizedFourierNeuralOperator
TensorizedFourierNeuralOperator(in_channels: int, out_channels: int, hidden_channels: int = 64, modes: Sequence[int] = (16, 16), num_layers: int = 4, factorization: Literal['tucker', 'cp', 'tt'] = 'tucker', rank: float = 0.1, *, rngs: Rngs)
Bases: Module
Simplified Tensorized FNO with stable implementations.
TensorizedSpectralConvolution
UFNODecoderBlock
UFNODecoderBlock(in_channels: int, skip_channels: int, out_channels: int, modes: Sequence[int], upsample_factor: int = 2, activation: Callable = gelu, *, rngs: Rngs)
Bases: Module
U-FNO decoder block with standardized tensor operations.
Performs: upsampling + skip fusion + spectral convolution
UFNOEncoderBlock
UFNOEncoderBlock(in_channels: int, out_channels: int, modes: Sequence[int], downsample_factor: int = 2, activation: Callable = gelu, *, rngs: Rngs)
Bases: Module
U-FNO encoder block with standardized tensor operations.
Performs spectral convolution, a skip connection, and downsampling.
UFourierNeuralOperator
UFourierNeuralOperator(in_channels: int, out_channels: int, hidden_channels: int, modes: Sequence[int], num_levels: int = 3, downsample_factor: int = 2, activation: Callable = gelu, *, rngs: Rngs)
Bases: Module
U-Net-style Fourier Neural Operator with a standardized encoder-decoder architecture.
Features:
- Consistent tensor dimension handling
- Standardized spectral operations
- Encoder-decoder structure
- Consistent channel management throughout
GraphNeuralOperator
GraphNeuralOperator(node_dim: int, hidden_dim: int, num_layers: int, *, edge_dim: int = 0, activation: Callable[[Array], Array] = gelu, rngs: Rngs)
Bases: Module
Graph Neural Operator for learning operators on irregular domains.
Implements message-passing neural networks with geometric awareness for learning operators on graph-structured data. Suitable for irregular meshes, molecular systems, and other graph-based scientific computing applications.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node_dim | int | Dimension of node features | required |
| hidden_dim | int | Hidden dimension for message passing | required |
| num_layers | int | Number of message passing layers | required |
| edge_dim | int | Dimension of edge features (0 for no edge features) | 0 |
| activation | Callable[[Array], Array] | Activation function | gelu |
| rngs | Rngs | Random number generators | required |
MessagePassingLayer
MessagePassingLayer(node_dim: int, edge_dim: int, hidden_dim: int, *, activation: Callable[[Array], Array] = gelu, rngs: Rngs)
Bases: Module
Message passing layer for graph neural networks.
Implements the message passing paradigm:
1. Compute messages between connected nodes.
2. Aggregate messages at each node.
3. Update node features based on the aggregated messages.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| node_dim | int | Dimension of node features | required |
| edge_dim | int | Dimension of edge features | required |
| hidden_dim | int | Hidden dimension for message computation | required |
| activation | Callable[[Array], Array] | Activation function | gelu |
| rngs | Rngs | Random number generators | required |
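The three steps above can be sketched in NumPy for a single layer. Several choices here are illustrative assumptions, not the internals of `MessagePassingLayer`: the weight matrices `w_msg` and `w_upd`, sum aggregation, the residual update, and `tanh` in place of GELU.

```python
import numpy as np


def message_passing_step(h, edge_index, edge_attr, w_msg, w_upd):
    """One message passing round on a directed graph.

    h: (num_nodes, node_dim) node features
    edge_index: (2, num_edges) integer array of (sender, receiver) pairs
    edge_attr: (num_edges, edge_dim) edge features
    w_msg: (2 * node_dim + edge_dim, hidden_dim); w_upd: (node_dim + hidden_dim, node_dim)
    """
    senders, receivers = edge_index
    # 1. Compute a message per edge from both endpoints and the edge features.
    msg_inputs = np.concatenate([h[senders], h[receivers], edge_attr], axis=1)
    messages = np.tanh(msg_inputs @ w_msg)
    # 2. Sum-aggregate incoming messages at each receiving node.
    aggregated = np.zeros((h.shape[0], messages.shape[1]))
    np.add.at(aggregated, receivers, messages)
    # 3. Update node features from the old features and the aggregate (residual form).
    return h + np.tanh(np.concatenate([h, aggregated], axis=1) @ w_upd)


rng = np.random.default_rng(0)
node_dim, edge_dim, hidden_dim = 3, 2, 5
h = rng.standard_normal((4, node_dim))
edge_index = np.array([[0, 1, 2], [1, 2, 3]])  # chain 0 -> 1 -> 2 -> 3
edge_attr = rng.standard_normal((3, edge_dim))
w_msg = rng.standard_normal((2 * node_dim + edge_dim, hidden_dim))
w_upd = rng.standard_normal((node_dim + hidden_dim, node_dim))
h_new = message_passing_step(h, edge_index, edge_attr, w_msg, w_upd)
```

`np.add.at` is used rather than `aggregated[receivers] += messages` because it accumulates correctly when several edges share a receiver.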
PhysicsAwareAttention
PhysicsAwareAttention(embed_dim: int, num_heads: int, *, physics_constraints: list[str] | None = None, dropout_rate: float = 0.0, rngs: Rngs)
Bases: Module
Physics-aware attention mechanism with constraint enforcement.
Integrates physics constraints into the attention mechanism to ensure physically meaningful attention patterns.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| embed_dim | int | Embedding dimension | required |
| num_heads | int | Number of attention heads | required |
| physics_constraints | list[str] \| None | List of physics constraints to enforce | None |
| dropout_rate | float | Dropout rate for attention weights | 0.0 |
| rngs | Rngs | Random number generators | required |
PhysicsCrossAttention
PhysicsCrossAttention(embed_dim: int, num_heads: int, physics_constraints: list[str], num_physics_systems: int, *, conservation_weight: float = 0.1, adaptive_weighting: bool = True, cross_system_coupling: bool = True, dropout_rate: float = 0.0, rngs: Rngs)
Bases: Module
Physics cross-attention mechanism for enhanced multi-physics coupling.
Implements cross-attention between different physics systems, with conservation-law enforcement and adaptive weighting based on physics constraints.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| embed_dim | int | Embedding dimension | required |
| num_heads | int | Number of attention heads | required |
| physics_constraints | list[str] | List of physics constraints to enforce | required |
| num_physics_systems | int | Number of different physics systems | required |
| conservation_weight | float | Weight for conservation law enforcement | 0.1 |
| adaptive_weighting | bool | Whether to use adaptive constraint weighting | True |
| cross_system_coupling | bool | Whether to enable cross-system coupling | True |
| dropout_rate | float | Dropout rate for attention weights | 0.0 |
| rngs | Rngs | Random number generators | required |
forward_with_conservation
forward_with_conservation(x: Array, *, physics_info: Array | None = None, training: bool = False) -> tuple[Array, Array]
Forward pass with conservation loss computation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor | required |
| physics_info | Array \| None | Physics constraint information | None |
| training | bool | Whether in training mode | False |
Returns:
| Type | Description |
|---|---|
| tuple[Array, Array] | Tuple of (output, conservation_loss) |
PhysicsInformedOperator
PhysicsInformedOperator(layer_sizes: list[int], physics_type: str = 'pde', *, activation: str = 'gelu', physics_weight: float = 1.0, data_weight: float = 1.0, use_bias: bool = True, rngs: Rngs)
Bases: Module
Physics-Informed Neural Operator with embedded physical constraints.
This operator combines standard neural operator architectures with physics-based constraints and differential operators to ensure physically consistent solutions.
Fully compliant with modern Flax NNX patterns.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| layer_sizes | list[int] | Layer sizes for the neural network [input_dim, hidden1, hidden2, ..., output_dim] | required |
| physics_type | str | Type of physics constraint ('pde', 'conservation', 'symmetry') | 'pde' |
| activation | str | Activation function name | 'gelu' |
| physics_weight | float | Weight for the physics loss component | 1.0 |
| data_weight | float | Weight for the data loss component | 1.0 |
| use_bias | bool | Whether to use bias in linear layers | True |
| rngs | Rngs | Random number generators (keyword-only) | required |
compute_physics_loss
compute_total_loss
compute_total_loss(coordinates: Array, target_solution: Array | None = None, *, deterministic: bool = True) -> dict[str, Array]
Compute the total loss, combining data and physics components.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| coordinates | Array | Space-time coordinates | required |
| target_solution | Array \| None | Target solution (optional, for supervised learning) | None |
| deterministic | bool | Whether to use deterministic mode | True |
Returns:
| Type | Description |
|---|---|
| dict[str, Array] | Dictionary containing the individual loss components and the total loss |
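The sketch below shows how a data term and a physics term can be weighted into one loss dictionary, as `compute_total_loss` describes. Several pieces are hypothetical choices for illustration: the ODE u'' + u = 0, the finite-difference residual, and the dictionary keys. The operator's actual physics loss depends on `physics_type`.

```python
import numpy as np


def total_loss(u, x, target=None, *, physics_weight=1.0, data_weight=1.0):
    """Weighted data + physics loss for a candidate solution u sampled on grid x."""
    dx = x[1] - x[0]
    # Physics residual of the (hypothetical) ODE u'' + u = 0, central differences.
    u_xx = (u[2:] - 2.0 * u[1:-1] + u[:-2]) / dx**2
    physics = np.mean((u_xx + u[1:-1]) ** 2)
    # Data misfit only in the supervised case, mirroring the optional target_solution.
    data = np.mean((u - target) ** 2) if target is not None else 0.0
    return {
        "data_loss": data,
        "physics_loss": physics,
        "total_loss": data_weight * data + physics_weight * physics,
    }


x = np.linspace(0.0, np.pi, 201)
exact = total_loss(np.sin(x), x, target=np.sin(x))  # satisfies the ODE
wrong = total_loss(x, x)                            # u = x violates it
```

The physics term can be evaluated without any target data, which is what lets physics-informed training work in the unsupervised case.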
GeometryAttention
GeometryAttention(feature_dim: int, geometry_dim: int, num_heads: int = 8, use_distance_attention: bool = True, *, rngs: Rngs)
Bases: Module
Geometry-aware attention mechanism.
Computes attention weights based on both feature similarity and geometric relationships, with proper dimension handling.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| feature_dim | int | Dimension of feature vectors | required |
| geometry_dim | int | Dimension of geometry embeddings | required |
| num_heads | int | Number of attention heads | 8 |
| use_distance_attention | bool | Whether to include distance-based attention | True |
| rngs | Rngs | Random number generator state | required |
GeometryEncoder
GeometryEncoder(coord_dim: int, hidden_dim: int, output_dim: int, use_positional_encoding: bool = True, max_position: float = 10000.0, *, rngs: Rngs)
Bases: Module
Encoder for geometric coordinates with positional encoding.
Transforms coordinate information into rich geometric embeddings suitable for neural operator processing.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| coord_dim | int | Dimension of input coordinates | required |
| hidden_dim | int | Hidden layer dimension | required |
| output_dim | int | Output embedding dimension | required |
| use_positional_encoding | bool | Whether to use sinusoidal positional encoding | True |
| max_position | float | Maximum position for encoding | 10000.0 |
| rngs | Rngs | Random number generator state | required |
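The docstring names sinusoidal positional encoding but does not spell out the formula. The NumPy sketch below assumes the Transformer-style scheme, applied independently to each coordinate, with `max_position` as the frequency base. `num_frequencies` is a hypothetical parameter and does not appear in `GeometryEncoder`'s signature.

```python
import numpy as np


def sinusoidal_encoding(coords, num_frequencies, max_position=10000.0):
    """Map (num_points, coord_dim) coordinates to sin/cos features.

    Output shape: (num_points, coord_dim * 2 * num_frequencies).
    """
    # Geometrically spaced frequencies from 1 down to max_position ** (-(F - 1) / F).
    freqs = max_position ** (-np.arange(num_frequencies) / num_frequencies)
    angles = coords[:, :, None] * freqs  # (N, coord_dim, F)
    features = np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)
    return features.reshape(coords.shape[0], -1)


points = np.array([[0.0, 0.0], [0.5, 1.0], [2.0, -1.0]])
encoded = sinusoidal_encoding(points, num_frequencies=4)
```

Each coordinate therefore becomes a vector of sines and cosines across scales. This makes fine and coarse spatial variation easier for downstream layers to resolve than the raw coordinate alone.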
GeometryInformedNeuralOperator
GeometryInformedNeuralOperator(in_channels: int, out_channels: int, hidden_channels: int = 64, modes: Sequence[int] = (16, 16), num_layers: int = 4, geometry_dim: int = 32, coord_dim: int = 2, use_geometry_attention: bool = True, use_spectral_conv: bool = True, *, rngs: Rngs)
Bases: Module
Complete Geometry-Informed Neural Operator.
Advanced neural operator that incorporates geometric information throughout the network for improved performance on spatially complex problems.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels | required |
| out_channels | int | Number of output channels | required |
| hidden_channels | int | Hidden channel dimension | 64 |
| modes | Sequence[int] | Fourier modes for spectral convolution | (16, 16) |
| num_layers | int | Number of GINO blocks | 4 |
| geometry_dim | int | Dimension of geometry embeddings | 32 |
| coord_dim | int | Coordinate dimension | 2 |
| use_geometry_attention | bool | Whether to use geometry attention | True |
| use_spectral_conv | bool | Whether to use spectral convolution | True |
| rngs | Rngs | Random number generator state | required |
GINOBlock
GINOBlock(in_channels: int, out_channels: int, modes: Sequence[int], geometry_dim: int, coord_dim: int = 2, use_geometry_attention: bool = True, use_spectral_conv: bool = True, activation: Callable[[Array], Array] = gelu, *, rngs: Rngs)
Bases: Module
Single GINO block with spectral convolution and geometry attention.
Combines spectral convolutions with geometry-aware processing for enhanced spatial understanding.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels | required |
| out_channels | int | Number of output channels | required |
| modes | Sequence[int] | Fourier modes for spectral convolution | required |
| geometry_dim | int | Dimension of geometry embeddings | required |
| coord_dim | int | Dimension of coordinates | 2 |
| use_geometry_attention | bool | Whether to use geometry attention | True |
| use_spectral_conv | bool | Whether to use spectral convolution | True |
| activation | Callable[[Array], Array] | Activation function | gelu |
| rngs | Rngs | Random number generator state | required |
LatentNeuralOperator
LatentNeuralOperator(in_channels: int, out_channels: int, latent_dim: int, num_latent_tokens: int, *, num_attention_heads: int = 8, num_encoder_layers: int = 4, num_decoder_layers: int = 4, physics_constraints: list[str] | None = None, dropout_rate: float = 0.0, activation: Callable[[Array], Array] = gelu, rngs: Rngs)
Bases: Module
Latent Neural Operator with attention-based latent representations.
This operator learns compact latent representations of function spaces using attention mechanisms, enabling efficient learning of complex operator mappings with reduced computational overhead.
Features:
- Learnable latent space for function representation
- Multi-head attention for function-to-latent and latent-to-function mappings
- Physics-aware attention constraints
- Efficient inference through latent space operations
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels | required |
| out_channels | int | Number of output channels | required |
| latent_dim | int | Dimension of latent space | required |
| num_latent_tokens | int | Number of latent tokens | required |
| num_attention_heads | int | Number of attention heads | 8 |
| num_encoder_layers | int | Number of encoder layers | 4 |
| num_decoder_layers | int | Number of decoder layers | 4 |
| physics_constraints | list[str] \| None | List of physics constraints | None |
| dropout_rate | float | Dropout rate | 0.0 |
| activation | Callable[[Array], Array] | Activation function | gelu |
| rngs | Rngs | Random number generators | required |
MGNOLayer
MGNOLayer(channels: int, max_multipole_order: int = 4, use_local_messages: bool = True, dropout_rate: float = 0.1, *, rngs: Rngs)
Bases: Module
MGNO layer with numerically stable, robust message passing.
Combines multipole expansion with local graph neural network operations to handle both long-range and short-range interactions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| channels | int | Number of feature channels | required |
| max_multipole_order | int | Maximum multipole expansion order | 4 |
| use_local_messages | bool | Whether to use local message passing | True |
| dropout_rate | float | Dropout rate for regularization | 0.1 |
| rngs | Rngs | Random number generator state | required |
MultipoleExpansion
MultipoleExpansion(channels: int, max_order: int = 4, epsilon: float = 1e-08, stabilization_factor: float = 0.1, *, rngs: Rngs)
Bases: Module
Numerically stable multipole expansion layer.
Computes multipole moments with proper numerical stability to prevent overflow and NaN generation in hierarchical computations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| channels | int | Number of feature channels | required |
| max_order | int | Maximum multipole order | 4 |
| epsilon | float | Small constant for numerical stability | 1e-08 |
| stabilization_factor | float | Factor for moment normalization | 0.1 |
| rngs | Rngs | Random number generator state | required |
MultipoleGraphNeuralOperator
MultipoleGraphNeuralOperator(in_features: int, out_features: int, hidden_features: int = 64, num_layers: int = 3, max_degree: int = 4, use_local_messages: bool = True, dropout_rate: float = 0.1, *, rngs: Rngs)
Bases: Module
Complete Multipole Graph Neural Operator with numerical stability.
Neural operator for systems with long-range interactions such as molecular dynamics, N-body simulations, and plasma physics.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Number of input feature channels | required |
| out_features | int | Number of output feature channels | required |
| hidden_features | int | Hidden layer width | 64 |
| num_layers | int | Number of MGNO layers | 3 |
| max_degree | int | Maximum multipole expansion order | 4 |
| use_local_messages | bool | Whether to use local message passing | True |
| dropout_rate | float | Dropout rate for regularization | 0.1 |
| rngs | Rngs | Random number generator state | required |
OperatorNetwork
Bases: Module
Unified interface for different operator network types.
This class provides a common interface for different neural operator architectures (FNO, DeepONet, etc.) to enable easy experimentation and comparison.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator_type | str | Type of operator ('fno', 'deeponet', 'fourier_deeponet', 'adaptive_deeponet', etc.) | required |
| config | dict[str, Any] | Configuration dictionary for the operator | required |
| rngs | Rngs | Random number generators | required |
BayesianLinear
Bases: Module
Bayesian linear layer with weight uncertainty.
Implements a variational Bayesian linear layer in which weights are distributions rather than point estimates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Number of input features | required |
| out_features | int | Number of output features | required |
| prior_std | float | Standard deviation of the weight prior | 1.0 |
| rngs | Rngs | Random number generator state | required |
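Variational Bayesian layers are commonly trained with the reparameterization trick plus a KL penalty against the prior. The NumPy sketch below assumes a factorized Gaussian posterior with mean `w_mu` and an unconstrained `w_rho` (sigma = softplus(rho)), and a zero-mean Gaussian prior with standard deviation `prior_std`. The names `w_mu`, `w_rho`, `sample_forward`, and `kl_to_prior` are hypothetical and may not match `BayesianLinear`'s internals.

```python
import numpy as np


def softplus(x):
    return np.log1p(np.exp(x))


def sample_forward(x, w_mu, w_rho, rng):
    # Reparameterization: w = mu + sigma * eps with eps ~ N(0, 1). In an autodiff
    # framework this lets gradients reach mu and rho through the sampled weights.
    sigma = softplus(w_rho)
    w = w_mu + sigma * rng.standard_normal(w_mu.shape)
    return x @ w


def kl_to_prior(w_mu, w_rho, prior_std=1.0):
    # Closed-form KL(N(mu, sigma^2) || N(0, prior_std^2)), summed over all weights.
    sigma = softplus(w_rho)
    return np.sum(
        np.log(prior_std / sigma)
        + (sigma**2 + w_mu**2) / (2.0 * prior_std**2)
        - 0.5
    )


rng = np.random.default_rng(0)
w_mu = np.zeros((2, 3))
w_rho = np.full((2, 3), np.log(np.e - 1.0))  # softplus(rho) = 1.0
x = np.ones((5, 2))
y1 = sample_forward(x, w_mu, w_rho, rng)
y2 = sample_forward(x, w_mu, w_rho, rng)  # a fresh weight sample gives a different output
```

The KL term vanishes when the posterior equals the prior and grows as the weight means move away from zero.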
BayesianSpectralConvolution
BayesianSpectralConvolution(in_channels: int, out_channels: int, modes: Sequence[int], prior_std: float = 1.0, *, rngs: Rngs)
Bases: Module
Bayesian spectral convolution with proper shape handling.
Implements spectral convolution in the Fourier domain with Bayesian weights for uncertainty quantification.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels | required |
| out_channels | int | Number of output channels | required |
| modes | Sequence[int] | Fourier modes for each spatial dimension | required |
| prior_std | float | Standard deviation of the weight prior | 1.0 |
| rngs | Rngs | Random number generator state | required |
UncertaintyQuantificationNeuralOperator
UncertaintyQuantificationNeuralOperator(in_channels: int, out_channels: int, hidden_channels: int = 64, modes: Sequence[int] = (16, 16), num_layers: int = 4, use_epistemic: bool = True, use_aleatoric: bool = True, ensemble_size: int = 10, *, rngs: Rngs)
Bases: Module
Complete Uncertainty Quantification Neural Operator.
Neural operator with built-in uncertainty quantification for safety-critical applications and robust predictions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels | required |
| out_channels | int | Number of output channels | required |
| hidden_channels | int | Hidden layer width | 64 |
| modes | Sequence[int] | Fourier modes for spectral convolution | (16, 16) |
| num_layers | int | Number of UQNO layers | 4 |
| use_epistemic | bool | Whether to use epistemic uncertainty | True |
| use_aleatoric | bool | Whether to use aleatoric uncertainty | True |
| ensemble_size | int | Number of samples drawn for Monte Carlo sampling | 10 |
| rngs | Rngs | Random number generator state | required |
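When `use_epistemic` and `use_aleatoric` are both enabled, a common way to combine the two sources over `ensemble_size` Monte Carlo samples is the law of total variance. This sketch assumes that each sample yields a predicted mean and an aleatoric variance. It illustrates the decomposition only, not UQNO's exact computation.

```python
import numpy as np


def decompose_uncertainty(sample_means, sample_variances):
    """Combine Monte Carlo samples via the law of total variance.

    sample_means, sample_variances: arrays of shape (ensemble_size, ...) holding
    each sampled network's predicted mean and predicted (aleatoric) variance.
    """
    predictive_mean = sample_means.mean(axis=0)
    epistemic = sample_means.var(axis=0)       # disagreement between samples
    aleatoric = sample_variances.mean(axis=0)  # average predicted data noise
    return predictive_mean, epistemic, aleatoric, epistemic + aleatoric


means = np.array([[1.0], [3.0]])  # two Monte Carlo samples, one output
variances = np.array([[0.5], [1.5]])
mean, epistemic, aleatoric, total = decompose_uncertainty(means, variances)
```

Epistemic uncertainty shrinks with more training data, while aleatoric uncertainty reflects noise that more data cannot remove. Keeping the two terms separate tells you which one dominates.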
UQNOLayer
UQNOLayer(in_channels: int, out_channels: int, modes: Sequence[int], use_skip_connection: bool = True, *, rngs: Rngs)
Bases: Module
UQNO layer with proper shape handling for skip connections.
Combines Bayesian spectral convolution with local operations and proper channel dimension handling.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels | required |
| out_channels | int | Number of output channels | required |
| modes | Sequence[int] | Fourier modes for spectral convolution | required |
| use_skip_connection | bool | Whether to use skip connections | True |
| rngs | Rngs | Random number generator state | required |
WaveletNeuralOperator
WaveletNeuralOperator(in_channels: int, out_channels: int, hidden_channels: int, num_levels: int, *, wavelet_type: str = 'db4', mode: str = 'symmetric', activation: Callable[[Array], Array] = gelu, use_learnable_wavelets: bool = False, rngs: Rngs)
Bases: Module
Wavelet Neural Operator for multi-scale wavelet-based learning.
This operator uses wavelet transforms to capture multi-scale features of the input functions, enabling efficient learning of operators with multi-scale characteristics such as turbulence and material heterogeneity.
Features:
- Discrete Wavelet Transform (DWT) for multi-scale decomposition
- Learnable processing of wavelet coefficients
- Multi-resolution reconstruction
- Adaptive wavelet basis selection
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels | required |
| out_channels | int | Number of output channels | required |
| hidden_channels | int | Hidden channel dimension | required |
| num_levels | int | Number of wavelet decomposition levels | required |
| wavelet_type | str | Type of wavelet (e.g., 'db4', 'haar') | 'db4' |
| mode | str | Boundary (signal extension) mode | 'symmetric' |
| activation | Callable[[Array], Array] | Activation function | gelu |
| use_learnable_wavelets | bool | Whether to use learnable wavelet bases | False |
| rngs | Rngs | Random number generators | required |
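The multi-level decomposition that `num_levels` controls can be illustrated in NumPy with the Haar wavelet (the `'haar'` option). The default `'db4'` uses longer filters and needs the boundary handling that `mode` selects. The sketch sidesteps this by requiring a signal length divisible by 2 ** num_levels. It shows the transform only, not WNO's learned layers.

```python
import numpy as np


def haar_dwt(x, num_levels):
    """Multi-level orthonormal Haar DWT of a 1D signal.

    len(x) must be divisible by 2 ** num_levels. Returns the coarsest
    approximation and the detail bands from finest to coarsest.
    """
    approx = np.asarray(x, dtype=float)
    details = []
    for _ in range(num_levels):
        even, odd = approx[0::2], approx[1::2]
        details.append((even - odd) / np.sqrt(2.0))  # high-pass: local differences
        approx = (even + odd) / np.sqrt(2.0)         # low-pass: local averages
    return approx, details


def haar_idwt(approx, details):
    """Invert haar_dwt by undoing one level at a time, coarsest first."""
    for detail in reversed(details):
        out = np.empty(2 * approx.size)
        out[0::2] = (approx + detail) / np.sqrt(2.0)
        out[1::2] = (approx - detail) / np.sqrt(2.0)
        approx = out
    return approx


signal = np.arange(8.0)
coarse, bands = haar_dwt(signal, num_levels=3)
reconstructed = haar_idwt(coarse, bands)
```

In a wavelet neural operator, learned transformations would act on the coefficient bands between the forward and inverse transforms. Because the transform is orthonormal and exactly invertible, no information is lost at the decomposition step itself.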
create_high_frequency_amfno
create_high_frequency_amfno(in_channels: int, out_channels: int, modes: Sequence[int] = (128, 128), **kwargs) -> AmortizedFourierNeuralOperator
Create AM-FNO optimized for high-frequency problems.
create_shock_amfno
create_shock_amfno(in_channels: int = 3, out_channels: int = 3, modes: Sequence[int] = (96, 96), **kwargs) -> AmortizedFourierNeuralOperator
Create AM-FNO for problems with shocks/discontinuities.
create_wave_amfno
create_wave_amfno(in_channels: int = 2, out_channels: int = 2, modes: Sequence[int] = (64, 64), **kwargs) -> AmortizedFourierNeuralOperator
Create AM-FNO for wave propagation problems.
create_multiphysics_local_fno
create_multiphysics_local_fno(in_channels: int = 5, out_channels: int = 5, modes: Sequence[int] = (24, 24), **kwargs) -> LocalFourierNeuralOperator
Create Local FNO for multi-physics problems.
create_turbulence_local_fno
create_turbulence_local_fno(in_channels: int = 3, out_channels: int = 3, modes: Sequence[int] = (32, 32), **kwargs) -> LocalFourierNeuralOperator
Create Local FNO optimized for turbulent flow modeling.
create_wave_local_fno
create_wave_local_fno(in_channels: int = 2, out_channels: int = 2, modes: Sequence[int] = (64, 64), **kwargs) -> LocalFourierNeuralOperator
Create Local FNO for wave propagation with scattering.
create_climate_sfno
create_climate_sfno(in_channels: int = 5, out_channels: int = 5, lmax: int = 32, **kwargs) -> SphericalFourierNeuralOperator
Create SFNO optimized for global climate modeling.
create_ocean_sfno
create_ocean_sfno(in_channels: int = 4, out_channels: int = 4, lmax: int = 48, **kwargs) -> SphericalFourierNeuralOperator
Create SFNO for global ocean circulation modeling.
create_planetary_sfno
create_planetary_sfno(in_channels: int = 3, out_channels: int = 3, lmax: int = 16, **kwargs) -> SphericalFourierNeuralOperator
Create SFNO for planetary-scale phenomena.
create_weather_sfno
create_weather_sfno(in_channels: int = 7, out_channels: int = 7, lmax: int = 64, **kwargs) -> SphericalFourierNeuralOperator
Create SFNO for high-resolution weather prediction.
create_cp_fno
create_cp_fno(in_channels: int, out_channels: int, hidden_channels: int = 64, modes: Sequence[int] = (16, 16), rank: float = 0.1, num_layers: int = 4, *, rngs: Rngs) -> TensorizedFourierNeuralOperator
Create CP factorized FNO.
create_tt_fno
create_tt_fno(in_channels: int, out_channels: int, hidden_channels: int = 64, modes: Sequence[int] = (16, 16), rank: float = 0.1, num_layers: int = 4, *, rngs: Rngs) -> TensorizedFourierNeuralOperator
Create Tensor Train factorized FNO.
create_tucker_fno
create_tucker_fno(in_channels: int, out_channels: int, hidden_channels: int = 64, modes: Sequence[int] = (16, 16), rank: float = 0.1, num_layers: int = 4, *, rngs: Rngs) -> TensorizedFourierNeuralOperator
Create Tucker factorized FNO.
create_deep_ufno
create_deep_ufno(in_channels: int, out_channels: int, hidden_channels: int = 32, modes: Sequence[int] = (32, 32), **kwargs) -> UFourierNeuralOperator
Create deep U-FNO (5 levels) for complex multi-scale problems.
create_shallow_ufno
create_shallow_ufno(in_channels: int, out_channels: int, hidden_channels: int = 64, modes: Sequence[int] = (16, 16), **kwargs) -> UFourierNeuralOperator
Create shallow U-FNO (2 levels) for simple multi-scale problems.
create_turbulence_ufno
create_turbulence_ufno(in_channels: int = 4, out_channels: int = 3, **kwargs) -> UFourierNeuralOperator
Create U-FNO optimized for turbulent flow modeling.
create_3d_gino
create_3d_gino(in_channels: int, out_channels: int, *, rngs: Rngs) -> GeometryInformedNeuralOperator
Create GINO optimized for 3D problems.
create_adaptive_mesh_gino
create_adaptive_mesh_gino(in_channels: int, out_channels: int, *, rngs: Rngs) -> GeometryInformedNeuralOperator
Create GINO for adaptive mesh refinement.
create_cad_gino
create_cad_gino(in_channels: int, out_channels: int, *, rngs: Rngs) -> GeometryInformedNeuralOperator
Create GINO optimized for CAD geometries.
create_multiscale_gino
create_multiscale_gino(in_channels: int, out_channels: int, *, rngs: Rngs) -> GeometryInformedNeuralOperator
Create GINO for multiscale problems.
create_molecular_mgno
create_molecular_mgno(in_features: int, out_features: int, *, rngs: Rngs) -> MultipoleGraphNeuralOperator
Create MGNO optimized for molecular dynamics simulations.
create_nbody_mgno
create_nbody_mgno(in_features: int, out_features: int, *, rngs: Rngs) -> MultipoleGraphNeuralOperator
Create MGNO for N-body gravitational simulations.
create_plasma_mgno
create_plasma_mgno(in_features: int, out_features: int, *, rngs: Rngs) -> MultipoleGraphNeuralOperator
Create MGNO for plasma physics simulations.
create_bayesian_inverse_uqno
create_bayesian_inverse_uqno(in_channels: int, out_channels: int, *, rngs: Rngs) -> UncertaintyQuantificationNeuralOperator
Create UQNO for Bayesian inverse problems.
create_robust_design_uqno
create_robust_design_uqno(in_channels: int, out_channels: int, *, rngs: Rngs) -> UncertaintyQuantificationNeuralOperator
Create UQNO for robust engineering design.
create_safety_critical_uqno
create_safety_critical_uqno(in_channels: int, out_channels: int, *, rngs: Rngs) -> UncertaintyQuantificationNeuralOperator
Create UQNO for safety-critical applications.
create_operator
Factory function to create any operator by name.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator_type | str | Type of operator to create | required |
| `**kwargs` | Any | Arguments for operator initialization | {} |
Returns:
| Type | Description |
|---|---|
| Any | Initialized operator instance |
Raises:
| Type | Description |
|---|---|
| ValueError | If operator_type is not recognized |
Example
>>> # Create a Tensorized FNO
>>> tfno = create_operator("TFNO",
...                        in_channels=3, out_channels=1,
...                        hidden_channels=64, modes=(16, 16),
...                        factorization="tucker", rank=0.1,
...                        rngs=rngs)
recommend_operator
Recommend the best operator for a specific application.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| application | str | Application domain | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dictionary with recommendations |
Example
>>> rec = recommend_operator("turbulent_flow")
>>> print(f"Recommended: {rec['primary']}")
>>> print(f"Reason: {rec['reason']}")
list_operators
Bayesian Networks
opifex.neural.bayesian
Bayesian neural network components with uncertainty quantification.
BlackJAXIntegration
BlackJAXIntegration(base_model: Module, sampler_type: str = 'nuts', num_warmup: int = 1000, num_samples: int = 1000, step_size: float = 0.001, *, rngs: Rngs)
Bases: Module
BlackJAX MCMC sampling integration for Bayesian neural networks.
Provides MCMC sampling for full Bayesian inference over neural network parameters, supporting multiple sampling algorithms (NUTS, HMC, MALA).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| base_model | Module | Neural network model for Bayesian inference | required |
| sampler_type | str | MCMC sampler type ('nuts', 'hmc', 'mala') | 'nuts' |
| num_warmup | int | Number of warmup steps for sampler adaptation | 1000 |
| num_samples | int | Number of posterior samples to generate | 1000 |
| step_size | float | Initial step size for MCMC sampling | 0.001 |
| rngs | Rngs | Random number generators | required |
sample_posterior
Sample from the posterior distribution using MCMC.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x_data | Array | Input training data | required |
| y_data | Array | Target training data | required |
| rngs | Rngs \| None | Random number generators | None |
Returns:
| Type | Description |
|---|---|
| Array | Posterior samples as an array of shape (num_samples, num_params) |
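For intuition about what `sample_posterior` produces, here is a plain NumPy sketch of MALA, one of the listed samplers, targeting a standard normal density. It does not use BlackJAX, which provides its own MALA, HMC, and NUTS kernels. It also omits step-size adaptation; discarding an initial burn-in stands in for `num_warmup`. The function name and arguments are illustrative.

```python
import numpy as np


def mala(log_prob, grad_log_prob, x0, step_size, num_samples, rng):
    """Metropolis-adjusted Langevin algorithm (MALA)."""
    x = np.asarray(x0, dtype=float)

    def log_q(x_to, x_from):
        # Log proposal density (up to a constant) of the Langevin move x_from -> x_to.
        mean = x_from + step_size * grad_log_prob(x_from)
        return -np.sum((x_to - mean) ** 2) / (4.0 * step_size)

    samples = []
    for _ in range(num_samples):
        noise = rng.standard_normal(x.shape)
        proposal = x + step_size * grad_log_prob(x) + np.sqrt(2.0 * step_size) * noise
        # Metropolis-Hastings correction for the asymmetric proposal.
        log_accept = (
            log_prob(proposal) + log_q(x, proposal) - log_prob(x) - log_q(proposal, x)
        )
        if np.log(rng.uniform()) < log_accept:
            x = proposal
        samples.append(x.copy())
    return np.array(samples)


# Toy target: a 1D standard normal "posterior".
samples = mala(
    log_prob=lambda w: -0.5 * np.sum(w**2),
    grad_log_prob=lambda w: -w,
    x0=np.zeros(1),
    step_size=0.5,
    num_samples=5000,
    rng=np.random.default_rng(0),
)
posterior = samples[500:]  # discard burn-in in place of warmup adaptation
```

For a real network, `log_prob` would be the log prior plus log likelihood of the flattened parameters. The returned array then has the (num_samples, num_params) shape given above.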
posterior_predictive
compute_posterior_statistics
integrate_with_variational_framework
integrate_with_variational_framework(variational_framework: AmortizedVariationalFramework, x_data: Array, y_data: Array, *, rngs: Rngs) -> dict[str, Any]
Integrate BlackJAX sampling with the variational framework.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| variational_framework | AmortizedVariationalFramework | Variational framework instance | required |
| x_data | Array | Training input data | required |
| y_data | Array | Training target data | required |
| rngs | Rngs | Random number generators | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dictionary with MCMC samples and a comparison with the variational results |
CalibrationTools
Bases: Module
Tools for assessing and improving uncertainty calibration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rngs | Rngs | Random number generators | required |
assess_calibration
assess_calibration(predictions: Array, uncertainties: Array, true_values: Array, num_bins: int = 10) -> dict[str, float | dict[str, Array]]
Assess the calibration quality of uncertainty estimates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| predictions | Array | Model predictions | required |
| uncertainties | Array | Predicted uncertainties | required |
| true_values | Array | Ground truth values | required |
| num_bins | int | Number of bins for the reliability diagram | 10 |
Returns:
| Type | Description |
|---|---|
| dict[str, float \| dict[str, Array]] | Dictionary with calibration metrics |
compute_reliability_diagram
compute_reliability_diagram(confidences: Array, accuracies: Array, num_bins: int = 10) -> dict[str, Array]
Compute reliability diagram data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| confidences | Array | Predicted confidence values | required |
| accuracies | Array | Binary accuracy indicators | required |
| num_bins | int | Number of bins for the diagram | 10 |
Returns:
| Type | Description |
|---|---|
| dict[str, Array] | Dictionary with binned confidence and accuracy data |
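The inputs to `compute_reliability_diagram` (confidence values plus binary accuracy indicators) fit the usual reliability-diagram construction. The same bins also yield the expected calibration error (ECE). The NumPy sketch below assumes equal-width bins; its returned keys are hypothetical and may not match the library's dictionary.

```python
import numpy as np


def reliability_diagram(confidences, accuracies, num_bins=10):
    """Bin by confidence and compare mean confidence with accuracy per bin."""
    # Equal-width bins [k/num_bins, (k+1)/num_bins); a confidence of 1.0 joins the last bin.
    idx = np.minimum((confidences * num_bins).astype(int), num_bins - 1)
    counts = np.bincount(idx, minlength=num_bins)
    safe = np.maximum(counts, 1)  # avoid division by zero for empty bins
    bin_conf = np.bincount(idx, weights=confidences, minlength=num_bins) / safe
    bin_acc = np.bincount(idx, weights=accuracies, minlength=num_bins) / safe
    # Expected calibration error: count-weighted mean |confidence - accuracy|.
    ece = np.sum(counts * np.abs(bin_conf - bin_acc)) / counts.sum()
    return {"bin_confidence": bin_conf, "bin_accuracy": bin_acc,
            "bin_counts": counts, "ece": float(ece)}


confidences = np.array([0.95, 0.95, 0.55, 0.55])
accuracies = np.array([1.0, 1.0, 1.0, 0.0])
diagram = reliability_diagram(confidences, accuracies)
```

Plotting `bin_accuracy` against `bin_confidence` gives the diagram itself. Points below the diagonal indicate overconfidence.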
platt_scaling
Apply Platt scaling for probability calibration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| logits | Array | Training logits for fitting the scaling parameters | required |
| labels | Array | Training labels | required |
| validation_logits | Array | Validation logits to calibrate | required |
Returns:
| Type | Description |
|---|---|
| tuple[float, float] | Tuple of (slope, intercept) scaling parameters |
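The (slope, intercept) pair returned above matches the standard Platt model p = sigmoid(a * z + b). The NumPy sketch below fits a and b by plain gradient descent on the log loss. Library implementations often use Newton or L-BFGS instead, and the synthetic data here are illustrative only.

```python
import numpy as np


def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))


def fit_platt(logits, labels, lr=0.5, num_steps=3000):
    """Fit p = sigmoid(a * z + b) by gradient descent on the mean log loss."""
    a, b = 1.0, 0.0
    for _ in range(num_steps):
        err = sigmoid(a * logits + b) - labels  # d(loss)/d(a*z + b)
        a -= lr * np.mean(err * logits)
        b -= lr * np.mean(err)
    return a, b


rng = np.random.default_rng(0)
train_logits = rng.uniform(-4.0, 4.0, size=5000)
# Synthetic labels whose true calibration map is sigmoid(0.5 * z - 1).
train_labels = (rng.uniform(size=5000) < sigmoid(0.5 * train_logits - 1.0)).astype(float)
slope, intercept = fit_platt(train_logits, train_labels)
val_probs = sigmoid(slope * np.array([-2.0, 0.0, 2.0]) + intercept)
```

The fitted map is monotonic, so it changes probabilities without changing how the scores are ranked.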
isotonic_regression_calibration
ConformalPrediction
ConformalPrediction(alpha: float = 0.1, *, rngs: Rngs)
Bases: Module
Conformal prediction for calibrated uncertainty intervals.
Provides prediction intervals with finite-sample coverage guarantees based on conformal prediction theory.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| alpha | float | Miscoverage level (1 - alpha is the target coverage) | 0.1 |
| rngs | Rngs | Random number generators | required |
calibrate
predict_intervals
compute_coverage
Compute the empirical coverage of prediction intervals.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| lower_bounds | Array | Lower bounds of prediction intervals | required |
| upper_bounds | Array | Upper bounds of prediction intervals | required |
| true_values | Array | True values | required |
Returns:
| Type | Description |
|---|---|
| float | Empirical coverage rate |
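A minimal split-conformal sketch of the calibrate, predict-intervals, and coverage steps, assuming absolute residuals as the nonconformity score. The standalone functions below echo the method names for readability; they are not the `ConformalPrediction` API. The model (predict y = x) and the data are hypothetical.

```python
import numpy as np


def conformal_threshold(cal_residuals, alpha=0.1):
    """Split-conformal quantile of absolute residuals on a calibration set."""
    n = cal_residuals.size
    # Finite-sample corrected rank ceil((n + 1) * (1 - alpha)). If it exceeds n the
    # exact interval is unbounded; this sketch caps it at the largest residual.
    k = min(int(np.ceil((n + 1) * (1 - alpha))), n)
    return np.sort(cal_residuals)[k - 1]


def predict_intervals(predictions, threshold):
    return predictions - threshold, predictions + threshold


def compute_coverage(lower_bounds, upper_bounds, true_values):
    inside = (true_values >= lower_bounds) & (true_values <= upper_bounds)
    return float(np.mean(inside))


rng = np.random.default_rng(0)
# Hypothetical model: predicts y = x; the data carry unit Gaussian noise.
x_cal, x_test = rng.uniform(0, 1, 1000), rng.uniform(0, 1, 5000)
y_cal = x_cal + rng.standard_normal(1000)
y_test = x_test + rng.standard_normal(5000)
q = conformal_threshold(np.abs(y_cal - x_cal), alpha=0.1)
lower, upper = predict_intervals(x_test, q)
coverage = compute_coverage(lower, upper, y_test)
```

With alpha = 0.1 the empirical coverage lands near 0.9. The guarantee is marginal, meaning it holds on average over calibration and test draws, and it needs only exchangeability, not a correct model.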
IsotonicRegression
IsotonicRegression(n_bins: int = 100, *, rngs: Rngs)
Bases: Module
Isotonic regression for calibration.
Non-parametric calibration method that learns a monotonic mapping from confidence scores to calibrated probabilities.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_bins | int | Number of bins for isotonic regression | 100 |
| rngs | Rngs | Random number generators | required |
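Isotonic calibration is usually fit with the pool-adjacent-violators algorithm (PAVA). The sketch below runs PAVA directly on the sorted scores. Whether `IsotonicRegression` first groups scores into `n_bins` bins is not documented here, so binning is left out.

```python
import numpy as np


def pava(y, weights=None):
    """Pool-adjacent-violators: least-squares non-decreasing fit to y (in order)."""
    y = np.asarray(y, dtype=float)
    w = np.ones_like(y) if weights is None else np.asarray(weights, dtype=float)
    blocks = []  # each block: [mean value, total weight, number of points]
    for yi, wi in zip(y, w):
        blocks.append([yi, wi, 1])
        # Merge with the previous block while monotonicity is violated.
        while len(blocks) > 1 and blocks[-2][0] > blocks[-1][0]:
            v2, w2, n2 = blocks.pop()
            v1, w1, n1 = blocks.pop()
            blocks.append([(v1 * w1 + v2 * w2) / (w1 + w2), w1 + w2, n1 + n2])
    return np.repeat([b[0] for b in blocks], [b[2] for b in blocks])


# Calibration use: sort by confidence, then fit 0/1 correctness labels.
confidence = np.array([0.8, 0.2, 0.6, 0.4])
correct = np.array([1.0, 0.0, 0.0, 1.0])
order = np.argsort(confidence)
calibrated = pava(correct[order])  # calibrated probability for each sorted score
```

Each merged block becomes a flat step in the calibration curve. This is why isotonic calibration needs more data than Platt scaling but makes no assumption about the curve's shape.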
PlattScaling
Bases: Module
Platt scaling for probabilistic calibration.
Applies a sigmoid to affinely rescaled logits to improve the calibration of binary classifiers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rngs | Rngs | Random number generators | required |
fit
TemperatureScaling
¶
TemperatureScaling(physics_constraints: Sequence[str] = (), adaptive: bool = False, learning_rate: float = 0.01, constraint_strength: float = 1.0, *, rngs: Rngs)
Bases: Module
Temperature scaling for uncertainty calibration.
Applies learnable temperature scaling to improve calibration of probabilistic predictions while respecting physics constraints.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| physics_constraints | Sequence[str] | List of physics constraints to enforce | () |
| adaptive | bool | Whether to use adaptive temperature learning | False |
| learning_rate | float | Learning rate for temperature optimization | 0.01 |
| constraint_strength | float | Strength of physics constraint enforcement | 1.0 |
| rngs | Rngs | Random number generators | required |
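Classic temperature scaling divides the logits by a single scalar T > 0 before the softmax. It then picks the T that minimizes negative log-likelihood on held-out data. The sketch below shows only that core idea:

- It ignores physics_constraints and adaptive.
- It uses a grid search instead of the gradient updates that learning_rate suggests.
- All function names are hypothetical.

```python
import math


def softmax(logits: list[float], temperature: float = 1.0) -> list[float]:
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def nll(logit_rows: list[list[float]], labels: list[int], temperature: float) -> float:
    # Mean negative log-likelihood of the true class after scaling.
    return -sum(
        math.log(softmax(row, temperature)[y]) for row, y in zip(logit_rows, labels)
    ) / len(labels)


def fit_temperature(logit_rows, labels, grid=None) -> float:
    # Grid search over T > 0 stands in for gradient-based optimization.
    grid = grid or [0.25 * k for k in range(1, 41)]  # 0.25, 0.5, ..., 10.0
    return min(grid, key=lambda t: nll(logit_rows, labels, t))
```

Scaling by T never changes the arg-max, so accuracy is preserved and only the confidence is recalibrated. Overconfident logits push the fitted T above 1.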
apply_physics_aware_calibration
Apply physics-aware temperature scaling with constraint enforcement.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| predictions | Array | Model predictions to calibrate | required |
| inputs | Array | Input data for constraint evaluation | required |
Returns:
| Type | Description |
|---|---|
| tuple[Array, float] | Tuple of (calibrated_predictions, physics_constraint_penalty) |
optimize_temperature
optimize_temperature_with_physics_constraints
optimize_temperature_with_physics_constraints(predictions: Array, targets: Array, inputs: Array) -> float
Optimize temperature parameter with physics constraint awareness.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| predictions | Array | Model predictions | required |
| targets | Array | Target values | required |
| inputs | Array | Input data for constraint evaluation | required |
Returns:
| Type | Description |
|---|---|
| float | Optimized temperature value |
adaptive_temperature_scaling
Apply adaptive temperature scaling based on uncertainty quality.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| predictions | Array | Model predictions | required |
| uncertainties | Array | Predicted uncertainties | required |
| true_values | Array | Ground truth values | required |
Returns:
| Type | Description |
|---|---|
| Array | Adaptively calibrated temperatures |
ConformalConfig
dataclass
ConformalConfig(alpha: float = 0.1)
Configuration for conformal prediction.
Attributes:
| Name | Type | Description |
|---|---|---|
| alpha | float | Miscoverage level. The target coverage probability is 1 - alpha. Must be in the open interval (0, 1). |
ConformalPredictor
ConformalPredictor(model: Module, config: ConformalConfig | None = None)
Split conformal prediction for calibrated prediction intervals.
Wraps any point predictor (PINN, neural operator, etc.) and provides calibrated prediction intervals without distributional assumptions.
The predictor must be calibrated on a held-out calibration set before prediction intervals can be computed.
Attributes:
| Name | Type | Description |
|---|---|---|
| model | | The wrapped NNX module used for point predictions. |
| config | | Conformal prediction configuration. |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | Any Flax NNX module that maps inputs to predictions. Must implement | required |
| config | ConformalConfig \| None | Conformal prediction configuration. If | None |
calibrate
Compute nonconformity scores on a calibration set.
Runs the wrapped model on x_cal, computes absolute residuals against y_cal, and stores the conformal quantile.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x_cal | Array | Calibration inputs with shape | required |
| y_cal | Array | Calibration targets with shape | required |
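The plain-Python sketch below shows what calibrate computes and how the stored quantile becomes an interval. It assumes two things:

- the absolute-residual score described above;
- the standard finite-sample rank ⌈(n + 1)(1 − α)⌉.

The helpers `conformal_quantile`, `calibrate`, `predict_with_intervals` and `interval_coverage` are hypothetical stand-ins, not the opifex API.

```python
import math


def conformal_quantile(scores: list[float], alpha: float) -> float:
    """Finite-sample conformal quantile of calibration scores."""
    n = len(scores)
    # Rank ceil((n + 1) * (1 - alpha)), capped at n so the index stays valid.
    # (With too few calibration points the exact guarantee would need an
    # infinite interval; capping is a simplification.)
    k = min(math.ceil((n + 1) * (1 - alpha)), n)
    return sorted(scores)[k - 1]


def calibrate(preds_cal: list[float], y_cal: list[float], alpha: float = 0.1) -> float:
    # Nonconformity score: absolute residual |y - f(x)|.
    scores = [abs(y - p) for p, y in zip(preds_cal, y_cal)]
    return conformal_quantile(scores, alpha)


def predict_with_intervals(preds: list[float], q: float):
    # Symmetric interval of half-width q around each point prediction.
    lower = [p - q for p in preds]
    upper = [p + q for p in preds]
    return preds, lower, upper


def interval_coverage(lower, upper, y_true) -> float:
    # Empirical coverage: fraction of targets inside [lower, upper].
    hits = sum(lo <= y <= hi for lo, hi, y in zip(lower, upper, y_true))
    return hits / len(y_true)
```

Suppose the calibration and test points are exchangeable and ⌈(n + 1)(1 − α)⌉ ≤ n. Then intervals built this way cover a new target with probability at least 1 − alpha. This is the guarantee that ConformalConfig.alpha controls.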
predict_with_intervals
Return point predictions with calibrated prediction intervals.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input array with shape | required |
Returns:
| Type | Description |
|---|---|
| Array | A tuple of |
| Array | array has the same shape as the model output. |
Raises:
| Type | Description |
|---|---|
| RuntimeError | If |
ConservationLawPriors
ConservationLawPriors(conservation_laws: Sequence[str] = ('energy', 'momentum', 'mass'), uncertainty_scale: float = 0.1, prior_strength: float = 1.0, adaptive_weighting: bool = True, *, rngs: Rngs)
Bases: Module
Conservation law priors for uncertainty estimation.
This class implements physics-aware priors that incorporate conservation laws directly into uncertainty quantification, enabling physically consistent uncertainty estimates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| conservation_laws | Sequence[str] | List of conservation laws to enforce | ('energy', 'momentum', 'mass') |
| uncertainty_scale | float | Scale factor for uncertainty estimates | 0.1 |
| prior_strength | float | Strength of physics constraints in prior | 1.0 |
| adaptive_weighting | bool | Whether to use adaptive constraint weighting | True |
| rngs | Rngs | Random number generators | required |
compute_physics_aware_uncertainty
compute_physics_aware_uncertainty(predictions: Array, model_uncertainty: Array, physics_state: Array) -> Array
Compute physics-aware uncertainty estimates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| predictions | Array | Model predictions | required |
| model_uncertainty | Array | Basic model uncertainty | required |
| physics_state | Array | Physical state variables for constraint evaluation | required |
Returns:
| Type | Description |
|---|---|
| Array | Physics-aware uncertainty estimates |
sample_physics_constrained_params
DomainSpecificPriors
DomainSpecificPriors(domain: str = 'quantum_chemistry', parameter_ranges: dict[str, tuple[float, float]] | None = None, distribution_types: dict[str, str] | None = None, correlation_structure: str = 'independent', *, rngs: Rngs)
Bases: Module
Domain-specific prior distributions for scientific computing.
Provides specialized priors for different scientific domains including quantum mechanics, molecular dynamics, fluid dynamics, and materials science.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| domain | str | Scientific domain (quantum_chemistry, molecular_dynamics, etc.) | 'quantum_chemistry' |
| parameter_ranges | dict[str, tuple[float, float]] \| None | Custom parameter ranges for specific parameters | None |
| distribution_types | dict[str, str] \| None | Distribution types for each parameter | None |
| correlation_structure | str | Correlation structure between parameters | 'independent' |
| rngs | Rngs | Random number generators | required |
HierarchicalBayesianFramework
HierarchicalBayesianFramework(hierarchy_levels: int = 3, level_dimensions: Sequence[int] = (64, 32, 16), uncertainty_propagation: str = 'multiplicative', correlation_structure: str = 'exchangeable', *, rngs: Rngs)
Bases: Module
Hierarchical Bayesian framework for multi-level uncertainty estimation.
Implements hierarchical models that can capture uncertainty at multiple scales and levels, suitable for complex scientific computing applications.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hierarchy_levels | int | Number of hierarchy levels | 3 |
| level_dimensions | Sequence[int] | Dimensions for each hierarchy level | (64, 32, 16) |
| uncertainty_propagation | str | How uncertainty propagates between levels | 'multiplicative' |
| correlation_structure | str | Correlation structure between levels | 'exchangeable' |
| rngs | Rngs | Random number generators | required |
PhysicsAwareUncertaintyPropagation
PhysicsAwareUncertaintyPropagation(conservation_laws: Sequence[str] = ('energy', 'momentum'), constraint_tolerance: float = 1e-06, uncertainty_inflation: float = 1.1, correlation_aware: bool = True, *, rngs: Rngs)
Bases: Module
Physics-aware uncertainty propagation for scientific computing.
Propagates uncertainty through physics-informed models while respecting conservation laws and physical constraints.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| conservation_laws | Sequence[str] | Conservation laws to respect during propagation | ('energy', 'momentum') |
| constraint_tolerance | float | Tolerance for constraint violations | 1e-06 |
| uncertainty_inflation | float | Factor to inflate uncertainty for safety | 1.1 |
| correlation_aware | bool | Whether to account for parameter correlations | True |
| rngs | Rngs | Random number generators | required |
propagate_with_physics_constraints
propagate_with_physics_constraints(input_uncertainty: Array, model_jacobian: Array, physics_state: Array) -> Array
Propagate uncertainty while respecting physics constraints.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_uncertainty | Array | Input uncertainty estimates | required |
| model_jacobian | Array | Jacobian of the model with respect to its inputs | required |
| physics_state | Array | Current physics state for constraint evaluation | required |
Returns:
| Type | Description |
|---|---|
| Array | Physics-constrained uncertainty propagation |
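To first order, an input covariance Σ_in passes through a model with Jacobian J as J Σ_in Jᵀ. The sketch below shows only that linear step, scaled by an inflation factor in the spirit of uncertainty_inflation. It makes no attempt to model the conservation-law constraints. The text above does not say whether the class inflates variances or standard deviations; this sketch inflates the covariance. Plain Python; the function names are hypothetical.

```python
def matmul(a: list[list[float]], b: list[list[float]]) -> list[list[float]]:
    # Dense matrix product for small nested-list matrices.
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def propagate_covariance(jacobian, input_cov, inflation: float = 1.1):
    """First-order propagation: Sigma_out = inflation * J Sigma_in J^T."""
    jt = [list(row) for row in zip(*jacobian)]  # transpose of J
    out = matmul(matmul(jacobian, input_cov), jt)
    return [[inflation * v for v in row] for row in out]
```

This linearization is exact for linear models. For nonlinear models it holds only while the input uncertainty is small compared with the curvature of the model.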
compute_physics_informed_confidence
compute_physics_informed_confidence(predictions: Array, uncertainties: Array, physics_state: Array) -> Array
Compute physics-informed confidence intervals.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| predictions | Array | Model predictions | required |
| uncertainties | Array | Uncertainty estimates | required |
| physics_state | Array | Physics state for constraint evaluation | required |
Returns:
| Type | Description |
|---|---|
| Array | Physics-informed confidence measures |
uncertainty_aware_constraint_projection
uncertainty_aware_constraint_projection(parameters: Array, uncertainties: Array) -> tuple[Array, Array]
Project parameters to satisfy constraints while accounting for uncertainty.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| parameters | Array | Parameter values to project | required |
| uncertainties | Array | Parameter uncertainties | required |
Returns:
| Type | Description |
|---|---|
| tuple[Array, Array] | Tuple of (projected_parameters, adjusted_uncertainties) |
PhysicsInformedPriors
PhysicsInformedPriors(conservation_laws: Sequence[str] = (), boundary_conditions: Sequence[str] = (), constraint_weights: Array | None = None, penalty_weight: float = 1.0, *, rngs: Rngs)
Bases: Module
Physics-informed prior constraints for Bayesian models.
Enforces conservation laws, boundary conditions, and other physical constraints through learnable constraint weights and penalty functions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| conservation_laws | Sequence[str] | List of conservation laws to enforce | () |
| boundary_conditions | Sequence[str] | List of boundary conditions to enforce | () |
| constraint_weights | Array \| None | Optional custom weights for constraints | None |
| penalty_weight | float | Weight for constraint violation penalties | 1.0 |
| rngs | Rngs | Random number generators | required |
AdvancedAleatoricUncertainty
Advanced aleatoric uncertainty estimation methods.
AdvancedEpistemicUncertainty
Advanced epistemic uncertainty estimation methods.
AdvancedUncertaintyAggregator
Advanced uncertainty aggregation with multiple sources and weighting.
weighted_uncertainty_aggregation
staticmethod
weighted_uncertainty_aggregation(uncertainty_sources: list[Float[Array, 'batch output']], weights: Float[Array, 'sources'] | None = None, aggregation_method: str = 'weighted_variance') -> Float[Array, 'batch output']
Aggregate uncertainties from multiple sources with optional weighting.
adaptive_weighting
staticmethod
adaptive_weighting(uncertainty_sources: list[Float[Array, 'batch output']], reliability_scores: list[Float[Array, 'batch']] | None = None, adaptation_method: str = 'reliability_based') -> Float[Array, 'sources batch']
Compute adaptive weights for uncertainty sources based on reliability.
AleatoricUncertainty
Aleatoric (data) uncertainty estimation.
homoscedastic_uncertainty
staticmethod
homoscedastic_uncertainty(_predictions: Float[Array, 'batch output'], log_variance: Float[Array, 'batch output']) -> Float[Array, 'batch output']
Compute homoscedastic (constant) aleatoric uncertainty.
heteroscedastic_uncertainty
staticmethod
heteroscedastic_uncertainty(input_dependent_variance: Float[Array, 'batch output']) -> Float[Array, 'batch output']
Compute heteroscedastic (input-dependent) aleatoric uncertainty.
predictive_variance
staticmethod
predictive_variance(predictions: Float[Array, 'samples batch output'], individual_variances: Float[Array, 'samples batch output']) -> Float[Array, 'batch output']
Compute total predictive variance including aleatoric component.
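Per output element, a common way to combine these components is the law of total variance. Total variance is the mean of the per-sample aleatoric variances plus the variance of the per-sample means. The plain-Python sketch below assumes that decomposition; predictive_variance may weight or combine the terms differently. It also shows the exp(log_variance) mapping that a log-variance parameter usually implies for homoscedastic_uncertainty. Only one scalar output is handled, and the names are hypothetical.

```python
import math
from statistics import fmean, pvariance


def variance_from_log(log_variance: float) -> float:
    # Predicting log(sigma^2) keeps the variance positive after exp().
    return math.exp(log_variance)


def total_predictive_variance(means: list[float], variances: list[float]) -> float:
    """Law of total variance over S stochastic forward passes.

    means[s] and variances[s] are the predicted mean and aleatoric variance
    of the same output element from forward pass s.
    """
    aleatoric = fmean(variances)  # E_s[sigma_s^2]
    epistemic = pvariance(means)  # Var_s[mu_s] (population variance)
    return aleatoric + epistemic
```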
noise_estimation
staticmethod
noise_estimation(residuals: Float[Array, 'batch output'], predictions: Float[Array, 'batch output']) -> Float[Array, 'batch output']
Estimate aleatoric uncertainty from residuals.
CalibrationAssessment
Enhanced uncertainty calibration assessment tools.
expected_calibration_error
staticmethod
expected_calibration_error(confidences: Float[Array, 'n_samples'], accuracies: Float[Array, 'n_samples'], n_bins: int = 10) -> float
Compute Expected Calibration Error (ECE).
maximum_calibration_error
staticmethod
maximum_calibration_error(confidences: Float[Array, 'n_samples'], accuracies: Float[Array, 'n_samples'], n_bins: int = 10) -> float
Compute Maximum Calibration Error (MCE).
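ECE and MCE are usually computed by binning predictions by confidence. For bins B_b over N samples:

- ECE = Σ_b (|B_b| / N) · |acc(B_b) − conf(B_b)|
- MCE = max_b |acc(B_b) − conf(B_b)|

The sketch below uses equal-width bins on [0, 1] and returns (ECE, MCE). Implementations differ on bin edges, so read it as an illustration of the formulas rather than a reproduction of CalibrationAssessment. The function name is hypothetical.

```python
def calibration_errors(confidences: list[float], accuracies: list[float], n_bins: int = 10):
    """Return (ECE, MCE) using equal-width confidence bins on [0, 1]."""
    n = len(confidences)
    bins: list[list[tuple[float, float]]] = [[] for _ in range(n_bins)]
    for c, a in zip(confidences, accuracies):
        # Confidence 1.0 falls into the last bin rather than out of range.
        idx = min(int(c * n_bins), n_bins - 1)
        bins[idx].append((c, a))
    ece, mce = 0.0, 0.0
    for b in bins:
        if not b:
            continue  # empty bins contribute nothing
        avg_conf = sum(c for c, _ in b) / len(b)
        avg_acc = sum(a for _, a in b) / len(b)
        gap = abs(avg_acc - avg_conf)
        ece += len(b) / n * gap  # weighted by bin population
        mce = max(mce, gap)      # worst-case bin
    return ece, mce
```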
reliability_diagram_data
staticmethod
reliability_diagram_data(confidences: Float[Array, 'n_samples'], accuracies: Float[Array, 'n_samples'], n_bins: int = 10) -> dict[str, Array]
Compute reliability diagram data for visualization.
assess_calibration
assess_calibration(confidences: Float[Array, 'n_samples'], accuracies: Float[Array, 'n_samples'], n_bins: int = 10) -> CalibrationMetrics
Assess overall calibration with multiple metrics.
CalibrationMetrics
dataclass
CalibrationMetrics(expected_calibration_error: float, maximum_calibration_error: float, reliability_diagram: dict[str, Array], confidence_histogram: Array, accuracy_histogram: Array)
Uncertainty calibration assessment metrics.
DistributionalAleatoricUncertainty
Distributional modeling of aleatoric uncertainty.
sample_gaussian
sample_gaussian(mean: Float[Array, 'batch output'], log_std: Float[Array, 'batch output'], num_samples: int) -> Float[Array, 'samples batch output']
Sample from Gaussian distributional output.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mean | Float[Array, 'batch output'] | Mean predictions | required |
| log_std | Float[Array, 'batch output'] | Log standard deviation predictions | required |
| num_samples | int | Number of samples to draw | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'samples batch output'] | Samples from the distributional output |
compute_gaussian_uncertainty
compute_gaussian_uncertainty(mean: Float[Array, 'batch output'], log_std: Float[Array, 'batch output']) -> Float[Array, 'batch output']
Compute uncertainty from Gaussian distributional parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mean | Float[Array, 'batch output'] | Mean predictions | required |
| log_std | Float[Array, 'batch output'] | Log standard deviation predictions | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'batch output'] | Aleatoric uncertainty (variance) |
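sample_gaussian and compute_gaussian_uncertainty rest on two identities:

- the reparameterization x = μ + exp(log σ)·ε with ε ~ N(0, 1);
- σ² = exp(2 log σ).

The scalar, plain-Python sketch below implements both. A JAX version would draw ε with jax.random.normal instead of Python's random module. The names and the seed argument are hypothetical.

```python
import math
import random


def sample_gaussian(mean: float, log_std: float, num_samples: int, seed: int = 0) -> list[float]:
    rng = random.Random(seed)
    std = math.exp(log_std)
    # Reparameterization: x = mu + sigma * eps, eps ~ N(0, 1).
    return [mean + std * rng.gauss(0.0, 1.0) for _ in range(num_samples)]


def gaussian_variance(log_std: float) -> float:
    # sigma^2 = exp(log_std)^2 = exp(2 * log_std)
    return math.exp(2.0 * log_std)
```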
compute_mixture_uncertainty
compute_mixture_uncertainty(mixture_weights: Float[Array, 'batch components'], means: Float[Array, 'batch components output'], log_stds: Float[Array, 'batch components output']) -> Float[Array, 'batch output']
Compute uncertainty from mixture of Gaussians.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mixture_weights | Float[Array, 'batch components'] | Mixture component weights | required |
| means | Float[Array, 'batch components output'] | Component means | required |
| log_stds | Float[Array, 'batch components output'] | Component log standard deviations | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'batch output'] | Total uncertainty from mixture model |
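For a Gaussian mixture with weights w_k, means μ_k and standard deviations σ_k, the total variance follows from the law of total variance: Var = Σ_k w_k (σ_k² + μ_k²) − (Σ_k w_k μ_k)². The scalar sketch below implements that formula. Normalizing the weights is a defensive assumption, and the function name is hypothetical.

```python
import math


def mixture_variance(weights: list[float], means: list[float], log_stds: list[float]) -> float:
    total_w = sum(weights)
    w = [wk / total_w for wk in weights]  # normalize in case weights don't sum to 1
    mixture_mean = sum(wk * mk for wk, mk in zip(w, means))
    # E[x^2] = sum_k w_k * (sigma_k^2 + mu_k^2)
    second_moment = sum(
        wk * (math.exp(2.0 * lk) + mk**2) for wk, mk, lk in zip(w, means, log_stds)
    )
    return second_moment - mixture_mean**2
```

The first term collects the within-component (aleatoric) spread, and the second subtracts the squared overall mean. As a result, well-separated components raise the variance even when each component is narrow.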
EnhancedUncertaintyComponents
dataclass
EnhancedUncertaintyComponents(epistemic_ensemble: Float[Array, 'batch output'], aleatoric_distributional: Float[Array, 'batch output'], total_uncertainty: Float[Array, 'batch output'], uncertainty_breakdown: dict[str, Float[Array, 'batch output']], epistemic_dropout: Float[Array, 'batch output'] | None = None)
Enhanced uncertainty components with multiple sources.
EnhancedUncertaintyQuantifier
EnhancedUncertaintyQuantifier(ensemble_size: int = 5, distributional_output: bool = True, multi_source_aggregation: bool = True, confidence_level: float = 0.95)
Enhanced uncertainty quantifier with multiple decomposition methods.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| ensemble_size | int | Number of models in ensemble | 5 |
| distributional_output | bool | Whether to use distributional outputs | True |
| multi_source_aggregation | bool | Whether to aggregate multiple uncertainty sources | True |
| confidence_level | float | Confidence level for intervals | 0.95 |
enhanced_decompose_uncertainty
enhanced_decompose_uncertainty(ensemble_predictions: Float[Array, 'models batch output'], distributional_std: Float[Array, 'batch output'] | None = None, inputs: Float[Array, 'batch input_dim'] | None = None, dropout_predictions: Float[Array, 'samples batch output'] | None = None) -> EnhancedUncertaintyComponents
Enhanced uncertainty decomposition with multiple sources.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| ensemble_predictions | Float[Array, 'models batch output'] | Predictions from ensemble models | required |
| distributional_std | Float[Array, 'batch output'] \| None | Standard deviation from distributional output | None |
| inputs | Float[Array, 'batch input_dim'] \| None | Input data for context-dependent uncertainty | None |
| dropout_predictions | Float[Array, 'samples batch output'] \| None | Predictions with dropout for additional epistemic uncertainty | None |
Returns:
| Type | Description |
|---|---|
| EnhancedUncertaintyComponents | Enhanced uncertainty components with detailed breakdown |
EnsembleEpistemicUncertainty
EnsembleEpistemicUncertainty(num_models: int)
Ensemble-based epistemic uncertainty estimation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
num_models
|
int
|
Number of models in the ensemble |
required |
add_model
¶
add_model(model: Any) -> None
Add a model to the ensemble.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
model
|
Any
|
Neural network model to add to ensemble |
required |
aggregate_predictions
¶
aggregate_predictions(ensemble_predictions: Float[Array, 'models batch output'], method: str = 'mean') -> Float[Array, 'batch output']
Aggregate predictions from ensemble models.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| ensemble_predictions | Float[Array, 'models batch output'] | Predictions from all ensemble models | required |
| method | str | Aggregation method ("mean", "median", "weighted_mean") | 'mean' |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'batch output'] | Aggregated predictions |
compute_epistemic_uncertainty
compute_epistemic_uncertainty(ensemble_predictions: Float[Array, 'models batch output']) -> Float[Array, 'batch output']
Compute epistemic uncertainty from ensemble predictions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| ensemble_predictions | Float[Array, 'models batch output'] | Predictions from all ensemble models | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'batch output'] | Epistemic uncertainty (variance across models) |
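The ensemble methods above reduce the models axis for each output element. The plain-Python sketch below works on one element: a list holding one prediction per ensemble member. Three choices are assumptions rather than behavior documented here:

- the weights argument for "weighted_mean", which the documented signature does not show;
- the use of population variance, not sample variance;
- the function names.

```python
from statistics import fmean, median, pvariance


def aggregate_predictions(preds: list[float], method: str = "mean", weights=None) -> float:
    """Combine one output element's predictions from M ensemble members."""
    if method == "mean":
        return fmean(preds)
    if method == "median":
        return median(preds)  # robust to a single outlying member
    if method == "weighted_mean":
        # Hypothetical: caller-supplied non-negative weights.
        if weights is None:
            raise ValueError("weighted_mean needs weights")
        return sum(w * p for w, p in zip(weights, preds)) / sum(weights)
    raise ValueError(f"unknown method: {method}")


def epistemic_variance(preds: list[float]) -> float:
    # Disagreement between members is read as epistemic (model) uncertainty.
    return pvariance(preds)
```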
compute_prediction_disagreement
compute_prediction_disagreement(ensemble_predictions: Float[Array, 'models batch output']) -> Float[Array, 'batch output']
Compute prediction disagreement metric.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| ensemble_predictions | Float[Array, 'models batch output'] | Predictions from all ensemble models | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'batch output'] | Disagreement metric (pairwise prediction variance) |
EpistemicUncertainty
Epistemic (model) uncertainty estimation.
compute_variance
staticmethod
Compute epistemic uncertainty as variance across model samples.
compute_entropy
staticmethod
compute_entropy(predictions: Float[Array, 'samples batch classes']) -> Float[Array, 'batch classes']
Compute predictive entropy for classification tasks.
compute_mutual_information
staticmethod
compute_mutual_information(predictions: Float[Array, 'samples batch classes']) -> Float[Array, 'batch classes']
Compute mutual information between predictions and model parameters.
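For one input with S sampled class distributions p_s, the predictive entropy is H[p̄], where p̄ is the average over samples. The mutual information is H[p̄] − E_s H[p_s], also known as the BALD score. The documented return shape keeps a classes axis, which suggests per-class terms; this sketch sums over classes instead, which is an assumption. Plain Python; the names are hypothetical.

```python
import math


def entropy(p: list[float]) -> float:
    # Shannon entropy in nats; 0 * log 0 is taken as 0.
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def predictive_entropy(samples: list[list[float]]) -> float:
    n_classes = len(samples[0])
    mean_p = [sum(s[c] for s in samples) / len(samples) for c in range(n_classes)]
    return entropy(mean_p)


def mutual_information(samples: list[list[float]]) -> float:
    # Total uncertainty minus the average per-sample (aleatoric) entropy.
    expected_entropy = sum(entropy(s) for s in samples) / len(samples)
    return predictive_entropy(samples) - expected_entropy
```

Confident but disagreeing samples give a high mutual information, which signals epistemic uncertainty. Samples that agree on a flat distribution give zero mutual information but a high predictive entropy.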
compute_variance_of_expected
staticmethod
compute_variance_of_expected(predictions: Float[Array, 'samples batch output']) -> Float[Array, 'batch output']
Compute variance of expected predictions (pure epistemic uncertainty).
MultiSourceUncertaintyAggregator
Aggregation of uncertainty from multiple sources.
aggregate_uncertainties
aggregate_uncertainties(epistemic_sources: list[Float[Array, 'batch output']], aleatoric_sources: list[Float[Array, 'batch output']], method: str = 'variance_sum', epistemic_weights: Array | None = None, aleatoric_weights: Array | None = None) -> Float[Array, 'batch output']
Aggregate uncertainties from multiple sources.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| epistemic_sources | list[Float[Array, 'batch output']] | List of epistemic uncertainty estimates | required |
| aleatoric_sources | list[Float[Array, 'batch output']] | List of aleatoric uncertainty estimates | required |
| method | str | Aggregation method ("variance_sum", "weighted_sum", "max") | 'variance_sum' |
| epistemic_weights | Array \| None | Weights for epistemic sources | None |
| aleatoric_weights | Array \| None | Weights for aleatoric sources | None |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'batch output'] | Total aggregated uncertainty |
compute_uncertainty_breakdown
compute_uncertainty_breakdown(epistemic_sources: list[Float[Array, 'batch output']], aleatoric_sources: list[Float[Array, 'batch output']], source_names: list[str] | None = None) -> dict[str, Float[Array, 'batch output']]
Compute detailed uncertainty breakdown by source.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| epistemic_sources | list[Float[Array, 'batch output']] | List of epistemic uncertainty estimates | required |
| aleatoric_sources | list[Float[Array, 'batch output']] | List of aleatoric uncertainty estimates | required |
| source_names | list[str] \| None | Names for uncertainty sources | None |
Returns:
| Type | Description |
|---|---|
| dict[str, Float[Array, 'batch output']] | Dictionary mapping source names to uncertainty values |
UncertaintyComponents
dataclass
UncertaintyComponents(epistemic: Float[Array, ...], aleatoric: Float[Array, ...], total: Float[Array, ...])
Decomposed uncertainty components.
UncertaintyIntegrationResults
dataclass
UncertaintyIntegrationResults(predictions: Float[Array, 'batch output'], uncertainty_components: UncertaintyComponents, calibration_metrics: CalibrationMetrics, confidence_intervals: tuple[Float[Array, 'batch output'], Float[Array, 'batch output']], prediction_intervals: tuple[Float[Array, 'batch output'], Float[Array, 'batch output']])
Results from uncertainty propagation through model pipeline.
UncertaintyQuantifier
Enhanced uncertainty quantification interface with integration capabilities.
decompose_uncertainty
decompose_uncertainty(predictions: Float[Array, 'samples batch output'], aleatoric_variance: Float[Array, 'samples batch output'] | None = None) -> UncertaintyComponents
Decompose total uncertainty into epistemic and aleatoric components.
enhanced_uncertainty_decomposition
enhanced_uncertainty_decomposition(predictions: Float[Array, 'samples batch output'], true_values: Float[Array, 'batch output'] | None = None, inputs: Float[Array, 'batch input_dim'] | None = None) -> UncertaintyComponents
Enhanced uncertainty decomposition with additional context.
compute_confidence_intervals
compute_confidence_intervals(predictions: Float[Array, 'samples batch output'], confidence_level: float | None = None) -> tuple[Float[Array, 'batch output'], Float[Array, 'batch output']]
Compute confidence intervals from prediction samples.
compute_prediction_intervals
compute_prediction_intervals(mean_predictions: Float[Array, 'batch output'], total_variance: Float[Array, 'batch output'], confidence_level: float | None = None) -> tuple[Float[Array, 'batch output'], Float[Array, 'batch output']]
Compute prediction intervals using Gaussian assumption.
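compute_confidence_intervals works from prediction samples, while compute_prediction_intervals works from a mean and a variance under a Gaussian assumption. The plain-Python sketch below shows both ideas for one output element:

- an empirical percentile interval over the samples;
- μ ± z·σ with z = Φ⁻¹(0.5 + level / 2).

The rounding rule for percentile indices is an assumption (libraries usually interpolate), and the function names are hypothetical.

```python
import math
from statistics import NormalDist


def empirical_interval(samples: list[float], confidence_level: float = 0.95):
    # Percentile interval of Monte Carlo samples, using rounded sample ranks.
    xs = sorted(samples)
    tail = (1.0 - confidence_level) / 2.0
    lo_idx = round(tail * (len(xs) - 1))
    hi_idx = round((1.0 - tail) * (len(xs) - 1))
    return xs[lo_idx], xs[hi_idx]


def gaussian_interval(mean: float, total_variance: float, confidence_level: float = 0.95):
    # Two-sided z-score, e.g. about 1.96 for a 95% interval.
    z = NormalDist().inv_cdf(0.5 + confidence_level / 2.0)
    half_width = z * math.sqrt(total_variance)
    return mean - half_width, mean + half_width
```

The Gaussian interval is only as good as the normality assumption. The empirical interval avoids that assumption but needs enough samples to resolve the tails.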
propagate_uncertainty
propagate_uncertainty(predictions: Float[Array, 'samples batch output'], inputs: Float[Array, 'batch input_dim'], true_values: Float[Array, 'batch output'] | None = None) -> UncertaintyIntegrationResults
Propagate uncertainty through the entire prediction pipeline.
AmortizedVariationalFramework
AmortizedVariationalFramework(base_model: Module, prior_config: PriorConfig, variational_config: VariationalConfig, *, rngs: Rngs)
Bases: Module
Variational framework with amortized uncertainty estimation.
This framework combines a base neural network model with variational Bayesian inference capabilities, enabling uncertainty quantification through amortized variational inference.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| base_model | Module | Base neural network model to augment with uncertainty. | required |
| prior_config | PriorConfig | Configuration for physics-informed priors. | required |
| variational_config | VariationalConfig | Configuration for variational inference. | required |
| rngs | Rngs | Random number generator state. | required |
predict_with_uncertainty
predict_with_uncertainty(x: Float[Array, 'batch input_dim'], num_samples: int | None = None, *, rngs: Rngs) -> tuple[Float[Array, 'batch output_dim'], Float[Array, 'batch output_dim']]
Forward pass with uncertainty quantification.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Float[Array, 'batch input_dim'] | Input tensor of shape (batch_size, input_dim). | required |
| num_samples | int \| None | Number of Monte Carlo samples for uncertainty estimation. | None |
| rngs | Rngs | Random number generator state. | required |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, 'batch output_dim'], Float[Array, 'batch output_dim']] | Tuple of (mean_prediction, uncertainty), both of shape (batch_size, output_dim). |
compute_elbo
compute_elbo(x: Float[Array, 'batch input_dim'], y: Float[Array, 'batch output_dim'], num_samples: int | None = None, *, rngs: Rngs) -> Float[Array, '']
Compute the Evidence Lower Bound (ELBO).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Float[Array, 'batch input_dim'] | Input tensor of shape (batch_size, input_dim). | required |
| y | Float[Array, 'batch output_dim'] | Target tensor of shape (batch_size, output_dim). | required |
| num_samples | int \| None | Number of Monte Carlo samples for ELBO estimation. | None |
| rngs | Rngs | Random number generator state. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | ELBO scalar value (higher is better). |
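Take a mean-field Gaussian posterior q(w) = N(μ, σ²) and a prior N(0, s²). The ELBO is then E_q[log p(y | x, w)] − β · KL(q ‖ p), and the KL term has the closed form log(s/σ) + (σ² + μ²)/(2s²) − 1/2 per parameter. The sketch below rests on several assumptions; all of them are illustrative, not the AmortizedVariationalFramework implementation:

- a hypothetical one-weight model y = w·x with Gaussian noise;
- β read as VariationalConfig.kl_weight;
- s read as PriorConfig.prior_scale;
- all function names.

```python
import math
import random


def kl_to_gaussian_prior(mu: list[float], log_sigma: list[float], prior_scale: float = 1.0) -> float:
    # Closed-form KL(N(mu, sigma^2) || N(0, prior_scale^2)), summed over parameters.
    kl = 0.0
    for m, ls in zip(mu, log_sigma):
        var = math.exp(2.0 * ls)
        kl += math.log(prior_scale) - ls + (var + m * m) / (2.0 * prior_scale**2) - 0.5
    return kl


def gaussian_log_lik(y: float, y_hat: float, noise_std: float) -> float:
    return -0.5 * math.log(2.0 * math.pi * noise_std**2) - (y - y_hat) ** 2 / (2.0 * noise_std**2)


def elbo(xs, ys, mu, log_sigma, num_samples=10, kl_weight=1.0, noise_std=0.1, seed=0):
    """Monte Carlo ELBO for a hypothetical one-weight linear model y = w * x."""
    rng = random.Random(seed)
    expected_ll = 0.0
    for _ in range(num_samples):
        # Reparameterized weight sample: w = mu + sigma * eps.
        w = mu[0] + math.exp(log_sigma[0]) * rng.gauss(0.0, 1.0)
        expected_ll += sum(gaussian_log_lik(y, w * x, noise_std) for x, y in zip(xs, ys))
    expected_ll /= num_samples
    return expected_ll - kl_weight * kl_to_gaussian_prior(mu, log_sigma)
```

Maximizing this quantity trades data fit, the expected log-likelihood, against staying close to the prior through the KL term.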
sample_predictive_distribution
sample_predictive_distribution(x: Float[Array, 'batch input_dim'], num_samples: int | None = None, *, rngs: Rngs) -> Float[Array, 'samples batch output_dim']
Sample from predictive distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Float[Array, 'batch input_dim'] | Input tensor of shape (batch_size, input_dim). | required |
| num_samples | int \| None | Number of predictive samples to generate. | None |
| rngs | Rngs | Random number generator state. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'samples batch output_dim'] | Predictive samples of shape (num_samples, batch_size, output_dim). |
MeanFieldGaussian
MeanFieldGaussian(num_params: int, *, rngs: Rngs)
Bases: Module
Mean-field Gaussian variational posterior.
This class implements a factorized Gaussian posterior distribution for variational inference in neural networks.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_params | int | Number of parameters in the posterior. | required |
| rngs | Rngs | Random number generator state. | required |
sample
sample(num_samples: int, *, rngs: Rngs) -> Float[Array, 'samples params']
Sample from variational posterior.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_samples | int | Number of samples to draw. | required |
| rngs | Rngs | Random number generator state. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'samples params'] | Array of shape (num_samples, num_params) containing parameter samples. |
log_prob
Compute log probability of samples.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| samples | Float[Array, 'samples params'] | Parameter samples of shape (num_samples, num_params). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'samples'] | Log probabilities for each sample, of shape (num_samples,). |
kl_divergence
PriorConfig
dataclass
PriorConfig(conservation_laws: Sequence[str] = (), boundary_conditions: Sequence[str] = (), physics_constraints: Sequence[str] = (), prior_scale: float = 1.0)
Configuration for physics-informed priors.
Attributes:
| Name | Type | Description |
|---|---|---|
| conservation_laws | Sequence[str] | List of conservation laws to enforce (e.g., ['energy', 'momentum']). |
| boundary_conditions | Sequence[str] | List of boundary conditions to incorporate. |
| physics_constraints | Sequence[str] | List of physics constraints to respect. |
| prior_scale | float | Scale parameter for the prior distribution. |
UncertaintyEncoder
Bases: Module
Neural network for amortized uncertainty estimation.
This encoder network predicts the parameters of the variational posterior directly from input data, enabling amortized variational inference.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_dim | int | Dimensionality of input features. | required |
| hidden_dims | Sequence[int] | Sequence of hidden layer dimensions. | required |
| output_dim | int | Dimensionality of output (typically 2 * num_params for mean and log_std). | required |
| rngs | Rngs | Random number generator state. | required |
VariationalConfig
dataclass
VariationalConfig(input_dim: int, hidden_dims: Sequence[int] = (64, 32), num_samples: int = 10, kl_weight: float = 1.0, temperature: float = 1.0)
Configuration for variational inference.
Attributes:
| Name | Type | Description |
|---|---|---|
| input_dim | int | Dimensionality of input features. |
| hidden_dims | Sequence[int] | Hidden layer dimensions for the encoder. |
| num_samples | int | Number of samples to draw during inference. |
| kl_weight | float | Weight for the KL divergence term in the ELBO. |
| temperature | float | Temperature parameter for the variational distribution. |
Domain Decomposition PINNs¶
Domain decomposition methods for physics-informed neural networks, enabling efficient training on complex geometries.
Base Classes¶
opifex.neural.pinns.domain_decomposition.base
Base classes for Domain Decomposition PINNs.
This module provides the foundational classes for domain decomposition approaches to physics-informed neural networks.
Key Classes
- Subdomain: Represents a subdomain region in the computational domain
- Interface: Represents the interface between adjacent subdomains
- DomainDecompositionPINN: Abstract base class for DD-PINN variants
Design Principles
- Each subdomain has its own neural network
- Interfaces enforce continuity and flux matching
- Window functions provide smooth blending (for FBPINN variants)
References
- Survey Section 8.3: Domain Decomposition Methods
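In finite-basis PINN (FBPINN) style methods, each subdomain network is multiplied by a smooth window. The windows are then normalized so that they sum to one at every point, forming a partition of unity. The 1D plain-Python sketch below follows that scheme. The sigmoid-product window shape, the sharpness parameter and the function names are assumptions, not the opifex implementation.

```python
import math


def sigmoid(t: float) -> float:
    # Numerically stable logistic function.
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def window(x: float, lo: float, hi: float, sharpness: float = 0.05) -> float:
    # Smooth bump: close to 1 inside [lo, hi], decaying to 0 outside.
    return sigmoid((x - lo) / sharpness) * sigmoid((hi - x) / sharpness)


def blend(x, subdomains, local_solutions, sharpness=0.05):
    """Partition-of-unity blend of per-subdomain solutions at point x."""
    weights = [window(x, lo, hi, sharpness) for lo, hi in subdomains]
    total = sum(weights)
    # Dividing by the total makes the weights sum to one at x.
    return sum(w / total * u(x) for w, u in zip(weights, local_solutions))
```

Because the normalized weights sum to one, a constant field represented identically by every subdomain is reproduced exactly. In the overlaps, the blend moves smoothly from one local network to the next.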
Subdomain
dataclass
Representation of a subdomain in the computational domain.
A subdomain is an axis-aligned rectangular region (a box in higher dimensions) defined by its lower and upper bounds in each spatial dimension.
Attributes:
| Name | Type | Description |
|---|---|---|
| id | int | Unique identifier for this subdomain |
| bounds | Float[Array, 'dim 2'] | Array of shape (dim, 2) with [min, max] for each dimension |
| overlap | float | Optional overlap with neighboring subdomains (for Schwarz methods) |
volume
property
Compute the volume (area in 2D, length in 1D) of the subdomain.
contains
Check if a point is inside this subdomain.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x
|
Float[Array, ' dim']
|
Point coordinates of shape (dim,) |
required |
Returns:
| Type | Description |
|---|---|
Array
|
Scalar boolean array indicating whether the point lies inside the subdomain |
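For a rectangular subdomain described by a `(dim, 2)` bounds array, `volume` and `contains` reduce to a few array operations. The sketch below uses hypothetical free functions rather than the `Subdomain` methods. It assumes the box is closed, meaning points on the boundary count as inside.

```python
import jax.numpy as jnp


def subdomain_volume(bounds):
    """Product of side lengths; bounds has shape (dim, 2) with [min, max] rows."""
    return jnp.prod(bounds[:, 1] - bounds[:, 0])


def subdomain_contains(bounds, x):
    """True if x (shape (dim,)) lies in the closed box defined by bounds."""
    return jnp.all((x >= bounds[:, 0]) & (x <= bounds[:, 1]))
```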
Interface
dataclass
¶
Interface(subdomain_ids: tuple[int, int], points: Float[Array, 'num_points dim'], normal: Float[Array, ' dim'])
Representation of an interface between two subdomains.
The interface stores sample points for enforcing continuity conditions between adjacent subdomains.
Attributes:
| Name | Type | Description |
|---|---|---|
subdomain_ids |
tuple[int, int]
|
Tuple of (left_id, right_id) for adjacent subdomains |
points |
Float[Array, 'num_points dim']
|
Sample points on the interface, shape (num_points, dim) |
normal |
Float[Array, ' dim']
|
Outward normal vector from first subdomain, shape (dim,) |
DomainDecompositionPINN
¶
DomainDecompositionPINN(input_dim: int, output_dim: int, subdomains: Sequence[Subdomain], interfaces: Sequence[Interface], hidden_dims: Sequence[int], *, activation: Callable[[Array], Array] = tanh, rngs: Rngs)
Bases: Module
Base class for Domain Decomposition PINNs.
This class provides the infrastructure for training separate networks on subdomains with interface coupling conditions.
Attributes:
| Name | Type | Description |
|---|---|---|
input_dim |
Input spatial dimension |
|
output_dim |
Output dimension (solution fields) |
|
subdomains |
List of subdomain definitions |
|
interfaces |
List of interface definitions |
|
networks |
List of subdomain networks |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
input_dim
|
int
|
Input spatial dimension |
required |
output_dim
|
int
|
Output dimension |
required |
subdomains
|
Sequence[Subdomain]
|
List of subdomain definitions |
required |
interfaces
|
Sequence[Interface]
|
List of interface definitions |
required |
hidden_dims
|
Sequence[int]
|
Hidden layer dimensions (shared across subdomains) |
required |
activation
|
Callable[[Array], Array]
|
Activation function |
tanh
|
rngs
|
Rngs
|
Random number generators |
required |
get_subdomain_outputs
¶
get_subdomain_outputs(x: Float[Array, ...]) -> list[Float[Array, 'batch out']]
Get outputs from all subdomain networks.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x
|
Float[Array, ...]
|
Input coordinates |
required |
Returns:
| Type | Description |
|---|---|
list[Float[Array, 'batch out']]
|
List of outputs from each subdomain network |
compute_interface_residual
¶
Compute interface continuity residual.
Enforces u_left = u_right at interface points.
Returns:
| Type | Description |
|---|---|
Float[Array, '']
|
Scalar interface residual (MSE of discontinuity) |
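The continuity residual for a single interface is the mean squared jump between the two neighboring networks at the interface sample points. Here is a minimal sketch, assuming each network maps `(num_points, dim)` to `(num_points, out)`. The function name is hypothetical, and how the library aggregates over several interfaces (sum or mean) is not specified here.

```python
import jax.numpy as jnp


def interface_continuity_residual(net_left, net_right, points):
    """Mean squared jump u_left - u_right over the interface sample points."""
    jump = net_left(points) - net_right(points)  # (num_points, out)
    return jnp.mean(jump**2)
```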
compute_flux_residual
¶
compute_flux_residual(derivative_fn: Callable[[Module, Float[Array, ...]], Float[Array, ...]]) -> Float[Array, '']
Compute interface flux continuity residual.
Enforces (du/dn)_left = (du/dn)_right at interface points.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
derivative_fn
|
Callable[[Module, Float[Array, ...]], Float[Array, ...]]
|
Function that takes a subdomain network and a batch of points and returns the gradient of the network output at those points |
required |
Returns:
| Type | Description |
|---|---|
Float[Array, '']
|
Scalar flux residual |
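The normal derivative du/dn is the gradient of the solution dotted with the interface normal. The sketch below computes it with `jax.grad` and `jax.vmap`. It assumes a scalar-output network that takes one point of shape `(dim,)`, which is a simplification of the documented `derivative_fn(network, points)` callable. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def normal_derivative(net, points, normal):
    """du/dn at each point; net maps a single point of shape (dim,) to a scalar."""
    grads = jax.vmap(jax.grad(net))(points)  # (num_points, dim)
    return grads @ normal                     # (num_points,)


def flux_residual(net_left, net_right, points, normal):
    """Mean squared mismatch between (du/dn)_left and (du/dn)_right."""
    jump = normal_derivative(net_left, points, normal) - normal_derivative(net_right, points, normal)
    return jnp.mean(jump**2)
```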
SubdomainNetwork
¶
SubdomainNetwork(input_dim: int, output_dim: int, hidden_dims: Sequence[int], *, activation: Callable[[Array], Array] = tanh, rngs: Rngs)
Bases: Module
Neural network for a single subdomain.
A simple MLP that processes inputs for a specific subdomain.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
input_dim
|
int
|
Input dimension |
required |
output_dim
|
int
|
Output dimension |
required |
hidden_dims
|
Sequence[int]
|
List of hidden layer dimensions |
required |
activation
|
Callable[[Array], Array]
|
Activation function |
tanh
|
rngs
|
Rngs
|
Random number generators |
required |
uniform_partition
¶
uniform_partition(bounds: Float[Array, 'dim 2'], num_partitions: tuple[int, ...], interface_points: int = 10) -> tuple[list[Subdomain], list[Interface]]
Create uniform partition of a rectangular domain.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
bounds
|
Float[Array, 'dim 2']
|
Domain bounds, shape (dim, 2) with [min, max] for each dimension |
required |
num_partitions
|
tuple[int, ...]
|
Number of partitions in each dimension |
required |
interface_points
|
int
|
Number of sample points per interface |
10
|
Returns:
| Type | Description |
|---|---|
tuple[list[Subdomain], list[Interface]]
|
Tuple of (subdomains, interfaces) |
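A uniform partition splits the box into a regular grid of cells and creates one interface per pair of face-sharing cells. The sketch below does this in plain JAX. `uniform_partition_sketch` and its return format are hypothetical: it returns bound arrays and `(ids, points, normal)` tuples instead of `Subdomain`/`Interface` objects. Sampling interface points uniformly at random on each shared face is also an assumption, not necessarily what the library does.

```python
import itertools

import jax
import jax.numpy as jnp


def uniform_partition_sketch(bounds, num_partitions, interface_points=10, seed=0):
    """Split a box into a grid of equal cells and list the shared faces."""
    dim = bounds.shape[0]
    edges = [jnp.linspace(bounds[d, 0], bounds[d, 1], n + 1) for d, n in enumerate(num_partitions)]
    grid = list(itertools.product(*[range(n) for n in num_partitions]))
    index = {g: i for i, g in enumerate(grid)}
    # Each cell's bounds: (dim, 2) array of [min, max] per dimension.
    cells = [jnp.stack([edges[d][g[d] : g[d] + 2] for d in range(dim)]) for g in grid]

    key = jax.random.PRNGKey(seed)
    interfaces = []
    for g in grid:
        for d in range(dim):
            if g[d] + 1 >= num_partitions[d]:
                continue  # no neighbor in the +d direction
            nbr = g[:d] + (g[d] + 1,) + g[d + 1 :]
            lo, hi = cells[index[g]][:, 0], cells[index[g]][:, 1]
            key, sub = jax.random.split(key)
            # Sample inside the cell, then pin coordinate d to the shared face.
            pts = lo + (hi - lo) * jax.random.uniform(sub, (interface_points, dim))
            pts = pts.at[:, d].set(hi[d])
            normal = jnp.zeros(dim).at[d].set(1.0)  # points from cell g into nbr
            interfaces.append(((index[g], index[nbr]), pts, normal))
    return cells, interfaces
```

The normal points out of the first cell in each pair, which matches the `Interface.normal` convention described above.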
XPINN (Extended PINN)¶
opifex.neural.pinns.domain_decomposition.xpinn
¶
Extended Physics-Informed Neural Network (XPINN).
XPINN extends the PINN framework to handle domain decomposition with explicit interface conditions for continuity and flux matching.
Key Features
- Separate networks for each subdomain
- Interface continuity conditions (u_left = u_right)
- Flux continuity conditions (du/dn_left = du/dn_right)
- Weighted loss combination for interface enforcement
References
- Jagtap & Karniadakis (2020): Extended Physics-Informed Neural Networks
- Survey Section 8.3.1: XPINNs
- GitHub: https://github.com/AmeyaJagtap/XPINNs
XPINN
¶
XPINN(input_dim: int, output_dim: int, subdomains: Sequence[Subdomain], interfaces: Sequence[Interface], hidden_dims: Sequence[int], *, config: XPINNConfig | None = None, activation: Callable[[Array], Array] = tanh, rngs: Rngs)
Bases: DomainDecompositionPINN
Extended Physics-Informed Neural Network.
XPINN decomposes the computational domain into non-overlapping subdomains, training a separate neural network for each subdomain. Interface conditions enforce solution continuity and flux matching between adjacent subdomains.
The total loss includes
- Data loss (if available)
- PDE residual loss (per subdomain)
- Interface continuity loss: ||u_left - u_right||²
- Interface flux loss: ||∂u/∂n_left - ∂u/∂n_right||²
Attributes:
| Name | Type | Description |
|---|---|---|
config |
XPINN configuration with loss weights |
|
input_dim |
Spatial dimension |
|
output_dim |
Solution dimension |
|
subdomains |
List of subdomain definitions |
|
interfaces |
List of interface definitions |
|
networks |
List of subdomain networks |
Example
```python
import jax.numpy as jnp
from flax import nnx

from opifex.neural.pinns.domain_decomposition.base import Interface, Subdomain
from opifex.neural.pinns.domain_decomposition.xpinn import XPINN

subdomains = [
    Subdomain(id=0, bounds=jnp.array([[0.0, 0.5]])),
    Subdomain(id=1, bounds=jnp.array([[0.5, 1.0]])),
]
interfaces = [
    Interface(subdomain_ids=(0, 1), points=jnp.array([[0.5]]), normal=jnp.array([1.0])),
]
model = XPINN(
    input_dim=1, output_dim=1,
    subdomains=subdomains, interfaces=interfaces,
    hidden_dims=[32, 32], rngs=nnx.Rngs(0),
)
```
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
input_dim
|
int
|
Spatial dimension |
required |
output_dim
|
int
|
Solution dimension |
required |
subdomains
|
Sequence[Subdomain]
|
List of subdomain definitions |
required |
interfaces
|
Sequence[Interface]
|
List of interface definitions |
required |
hidden_dims
|
Sequence[int]
|
Hidden layer dimensions for subdomain networks |
required |
config
|
XPINNConfig | None
|
XPINN configuration. Uses defaults if None. |
None
|
activation
|
Callable[[Array], Array]
|
Activation function |
tanh
|
rngs
|
Rngs
|
Random number generators |
required |
compute_continuity_loss
¶
Compute interface continuity loss.
Delegates to the base-class method compute_interface_residual.
Returns:
| Type | Description |
|---|---|
Float[Array, '']
|
Scalar continuity loss (MSE of discontinuity) |
compute_flux_loss
¶
Compute interface flux continuity loss.
Enforces ∂u/∂n_left = ∂u/∂n_right at all interface points, where n is the interface normal direction.
Returns:
| Type | Description |
|---|---|
Float[Array, '']
|
Scalar flux loss (MSE of flux discontinuity) |
compute_interface_loss
¶
Compute total weighted interface loss.
Combines continuity and flux losses with configured weights.
Returns:
| Type | Description |
|---|---|
Float[Array, '']
|
Scalar total interface loss |
compute_subdomain_residual
¶
compute_subdomain_residual(subdomain_id: int, residual_fn: Callable[[Callable[[Float[Array, ...]], Float[Array, 'batch out']], Float[Array, ...]], Float[Array, ' batch']], collocation_points: Float[Array, ...]) -> Float[Array, '']
Compute PDE residual for a specific subdomain.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
subdomain_id
|
int
|
ID of the subdomain |
required |
residual_fn
|
Callable[[Callable[[Float[Array, ...]], Float[Array, 'batch out']], Float[Array, ...]], Float[Array, ' batch']]
|
Function that computes PDE residual given network and points |
required |
collocation_points
|
Float[Array, ...]
|
Points at which to evaluate the residual |
required |
Returns:
| Type | Description |
|---|---|
Float[Array, '']
|
Scalar residual loss for this subdomain |
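A `residual_fn` receives a network callable and a batch of points and returns one residual per point. The sketch below shows that shape for an illustrative 1D Poisson problem, -u''(x) = π² sin(πx). The PDE and the name `poisson_residual_fn` are hypothetical. Second derivatives are taken with nested `jax.grad` on a scalar wrapper around the network.

```python
import jax
import jax.numpy as jnp


def poisson_residual_fn(net, x):
    """Residual of -u''(x) = pi^2 sin(pi x); net maps (batch, 1) -> (batch, 1)."""

    def u_scalar(xi):
        # Evaluate the network at one scalar coordinate.
        return net(xi.reshape(1, 1))[0, 0]

    u_xx = jax.vmap(jax.grad(jax.grad(u_scalar)))(x[:, 0])
    source = jnp.pi**2 * jnp.sin(jnp.pi * x[:, 0])
    return -u_xx - source  # shape (batch,)
```

The exact solution u(x) = sin(πx) gives residuals that are zero up to floating-point error.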
compute_total_residual
¶
compute_total_residual(residual_fn: Callable[[Callable[[Float[Array, ...]], Float[Array, 'batch out']], Float[Array, ...]], Float[Array, ' batch']], collocation_points_per_subdomain: Sequence[Float[Array, ...]]) -> Float[Array, '']
Compute total PDE residual across all subdomains.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
residual_fn
|
Callable[[Callable[[Float[Array, ...]], Float[Array, 'batch out']], Float[Array, ...]], Float[Array, ' batch']]
|
Function that computes PDE residual |
required |
collocation_points_per_subdomain
|
Sequence[Float[Array, ...]]
|
Collocation points for each subdomain |
required |
Returns:
| Type | Description |
|---|---|
Float[Array, '']
|
Scalar total residual loss |
XPINNConfig
dataclass
¶
XPINNConfig(continuity_weight: float = 1.0, flux_weight: float = 1.0, residual_weight: float = 1.0, average_residual_weight: float = 0.0)
Configuration for XPINN training.
Attributes:
| Name | Type | Description |
|---|---|---|
continuity_weight |
float
|
Weight for interface continuity loss (u_left = u_right) |
flux_weight |
float
|
Weight for interface flux continuity loss (du/dn matching) |
residual_weight |
float
|
Weight for PDE residual loss in each subdomain |
average_residual_weight |
float
|
Weight for residual averaging at interfaces |
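The configured weights scale the individual loss terms before they are summed. Here is a rough sketch of that combination, assuming plain weighted sums. `LossWeights` is a hypothetical stand-in for `XPINNConfig`. The data term is left unweighted, and `average_residual_weight` is omitted because this page does not document its exact role.

```python
from dataclasses import dataclass


@dataclass
class LossWeights:
    """Illustrative stand-in for the first three XPINNConfig fields."""

    continuity_weight: float = 1.0
    flux_weight: float = 1.0
    residual_weight: float = 1.0


def xpinn_objective(residual_loss, continuity_loss, flux_loss, weights, data_loss=0.0):
    # Interface part: weighted continuity plus weighted flux mismatch.
    interface_loss = weights.continuity_weight * continuity_loss + weights.flux_weight * flux_loss
    return data_loss + weights.residual_weight * residual_loss + interface_loss
```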
FBPINN (Finite Basis PINN)¶
opifex.neural.pinns.domain_decomposition.fbpinn
¶
Finite Basis Physics-Informed Neural Network (FBPINN).
FBPINN uses smooth window functions to create a partition of unity, enabling smooth blending of subdomain solutions without explicit interface conditions.
Key Features
- Smooth window functions (cosine, Gaussian)
- Partition of unity through normalization
- No explicit interface conditions needed
- Naturally handles overlapping subdomains
References
- Moseley et al. (2023): Finite Basis Physics-Informed Neural Networks
- Survey Section 8.3.2: FBPINNs
- GitHub: https://github.com/benmoseley/FBPINNs
FBPINN
¶
FBPINN(input_dim: int, output_dim: int, subdomains: Sequence[Subdomain], interfaces: Sequence, hidden_dims: Sequence[int], *, config: FBPINNConfig | None = None, activation: Callable[[Array], Array] = tanh, rngs: Rngs)
Bases: DomainDecompositionPINN
Finite Basis Physics-Informed Neural Network.
FBPINN decomposes the computational domain into overlapping subdomains, using smooth window functions to blend subdomain network outputs. This creates a partition of unity that ensures smooth global solutions.
The output is computed as
u(x) = [Σᵢ wᵢ(x) uᵢ(x)] / [Σⱼ wⱼ(x)]
where wᵢ(x) is the window function for subdomain i and uᵢ(x) is the network output for subdomain i.
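Given precomputed window values and subdomain outputs, the normalized blend is a weighted average along the subdomain axis. The sketch below uses a hypothetical function name. The `eps` guard for points where no window is active is an assumption, not documented behavior.

```python
import jax.numpy as jnp


def blend_partition_of_unity(weights, outputs, eps=1e-12):
    """u(x) = sum_i w_i(x) u_i(x) / sum_j w_j(x).

    weights: (batch, num_subdomains) window values; outputs: (batch, num_subdomains, out).
    """
    norm = jnp.sum(weights, axis=1, keepdims=True)       # (batch, 1)
    normalized = weights / jnp.maximum(norm, eps)        # each row sums to 1 where any window is active
    return jnp.einsum("bs,bso->bo", normalized, outputs)  # (batch, out)
```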
Attributes:
| Name | Type | Description |
|---|---|---|
config |
FBPINN configuration |
|
windows |
List of window functions for each subdomain |
Example
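```python
import jax.numpy as jnp
from flax import nnx

from opifex.neural.pinns.domain_decomposition.base import Subdomain
from opifex.neural.pinns.domain_decomposition.fbpinn import FBPINN

# Overlapping subdomains: [0.0, 0.6] and [0.4, 1.0]
subdomains = [
    Subdomain(id=0, bounds=jnp.array([[0.0, 0.6]])),
    Subdomain(id=1, bounds=jnp.array([[0.4, 1.0]])),
]
model = FBPINN(
    input_dim=1, output_dim=1,
    subdomains=subdomains, interfaces=[],
    hidden_dims=[32, 32], rngs=nnx.Rngs(0),
)
```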
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
input_dim
|
int
|
Spatial dimension |
required |
output_dim
|
int
|
Solution dimension |
required |
subdomains
|
Sequence[Subdomain]
|
List of subdomain definitions (should overlap) |
required |
interfaces
|
Sequence
|
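List of interface definitions. FBPINN blends subdomains with window functions instead of explicit interface conditions, so an empty list can be passed (as in the example above) |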
required |
hidden_dims
|
Sequence[int]
|
Hidden layer dimensions for subdomain networks |
required |
config
|
FBPINNConfig | None
|
FBPINN configuration. Uses defaults if None. |
None
|
activation
|
Callable[[Array], Array]
|
Activation function |
tanh
|
rngs
|
Rngs
|
Random number generators |
required |
compute_window_weights
¶
Compute window weights for all subdomains.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x
|
Float[Array, ...]
|
Input coordinates |
required |
Returns:
| Type | Description |
|---|---|
Float[Array, 'batch num_subdomains']
|
Window weights, shape (batch, num_subdomains) |
FBPINNConfig
dataclass
¶
FBPINNConfig(window_type: Literal['cosine', 'gaussian'] = 'cosine', normalize_windows: bool = True, overlap_factor: float = 0.2, gaussian_sigma: float = 0.25)
Configuration for FBPINN training.
Attributes:
| Name | Type | Description |
|---|---|---|
window_type |
Literal['cosine', 'gaussian']
|
Type of window function ("cosine" or "gaussian") |
normalize_windows |
bool
|
Whether to normalize window weights to sum to 1 |
overlap_factor |
float
|
Factor controlling subdomain overlap (for auto-partitioning) |
gaussian_sigma |
float
|
Sigma parameter for Gaussian windows |
WindowFunction
¶
WindowFunction(subdomain: Subdomain)
Bases: ABC
Abstract base class for window functions.
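Window functions define the influence region of each subdomain network. They should be smooth, concentrated on their subdomain, and form a partition of unity once normalized. The cosine window has compact support within the subdomain; the Gaussian window decays quickly away from the subdomain center but never reaches exactly zero.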
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
subdomain
|
Subdomain
|
The subdomain this window is associated with |
required |
CosineWindow
¶
CosineWindow(subdomain: Subdomain)
Bases: WindowFunction
Cosine-based window function.
w(x) = 0.5 * (1 + cos(π * r)) for r < 1, else 0
where r is the normalized distance from the subdomain center, scaled by the subdomain half-width.
This creates a smooth bump function that is 1 at the center and 0 at the boundary.
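The cosine bump can be evaluated directly from the subdomain bounds. In 1D, r = |x - center| / half_width. For more than one dimension, the sketch below takes r as the Euclidean norm of the per-dimension scaled offset. That choice is an assumption, and the function name is hypothetical.

```python
import jax.numpy as jnp


def cosine_window(bounds, x):
    """w(x) = 0.5 * (1 + cos(pi * r)) for r < 1, else 0.

    bounds: (dim, 2) [min, max] rows; x: (batch, dim).
    """
    center = 0.5 * (bounds[:, 0] + bounds[:, 1])
    half_width = 0.5 * (bounds[:, 1] - bounds[:, 0])
    # Assumption for dim > 1: Euclidean norm of the scaled offset.
    r = jnp.linalg.norm((x - center) / half_width, axis=-1)
    return jnp.where(r < 1.0, 0.5 * (1.0 + jnp.cos(jnp.pi * r)), 0.0)
```

On [0, 1], the window equals 1 at x = 0.5, 0.5 at x = 0.25, and 0 at the boundary and beyond.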
GaussianWindow
¶
Bases: WindowFunction
Gaussian-based window function.
w(x) = exp(-||x - center||² / (2 * σ²))
where σ controls the width of the Gaussian.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
subdomain
|
Subdomain
|
The subdomain this window is associated with |
required |
sigma
|
float
|
Standard deviation of the Gaussian (relative to subdomain size) |
0.25
|
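Because `sigma` is relative to the subdomain size, it has to be converted to an absolute width before evaluating the Gaussian. The sketch below multiplies it by each dimension's width. That scaling rule is an assumption, and the function name is hypothetical.

```python
import jax.numpy as jnp


def gaussian_window(bounds, x, sigma=0.25):
    """w(x) = exp(-||x - center||^2 / (2 * sigma^2)), with sigma scaled per dimension.

    bounds: (dim, 2) [min, max] rows; x: (batch, dim).
    """
    center = 0.5 * (bounds[:, 0] + bounds[:, 1])
    width = bounds[:, 1] - bounds[:, 0]
    z = (x - center) / (sigma * width)  # assumed: absolute sigma = sigma * width
    return jnp.exp(-0.5 * jnp.sum(z**2, axis=-1))
```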
CPINN (Conservative PINN)¶
opifex.neural.pinns.domain_decomposition.cpinn
¶
Conservative Physics-Informed Neural Network (cPINN).
cPINN extends XPINN with explicit flux conservation at interfaces, enforcing strong conservation properties required for conservation laws.
Key Features
- Explicit flux computation at interfaces
- Strong conservation enforcement
- Weighted combination of continuity and flux losses
References
- Jagtap et al. (2020): Conservative physics-informed neural networks
- Survey Section 8.3.2: Conservative PINNs
CPINN
¶
CPINN(input_dim: int, output_dim: int, subdomains: Sequence[Subdomain], interfaces: Sequence[Interface], hidden_dims: Sequence[int], *, config: CPINNConfig | None = None, activation: Callable[[Array], Array] = tanh, rngs: Rngs)
Bases: DomainDecompositionPINN
Conservative Physics-Informed Neural Network.
cPINN enforces strong conservation at subdomain interfaces by explicitly computing and matching fluxes across boundaries.
The total interface loss includes
- Continuity loss: ||u_left - u_right||²
- Flux conservation loss: ||F_left · n - F_right · n||²
where F = ∇u is the flux (gradient) of the solution.
Attributes:
| Name | Type | Description |
|---|---|---|
config |
cPINN configuration with loss weights |
|
input_dim |
Spatial dimension |
|
output_dim |
Solution dimension |
|
subdomains |
List of subdomain definitions |
|
interfaces |
List of interface definitions |
|
networks |
List of subdomain networks |
Example
```python
import jax.numpy as jnp
from flax import nnx

from opifex.neural.pinns.domain_decomposition.base import Interface, Subdomain
from opifex.neural.pinns.domain_decomposition.cpinn import CPINN

subdomains = [
    Subdomain(id=0, bounds=jnp.array([[0.0, 0.5]])),
    Subdomain(id=1, bounds=jnp.array([[0.5, 1.0]])),
]
interfaces = [
    Interface(subdomain_ids=(0, 1), points=jnp.array([[0.5]]), normal=jnp.array([1.0])),
]
model = CPINN(
    input_dim=1, output_dim=1,
    subdomains=subdomains, interfaces=interfaces,
    hidden_dims=[32, 32], rngs=nnx.Rngs(0),
)
```
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
input_dim
|
int
|
Spatial dimension |
required |
output_dim
|
int
|
Solution dimension |
required |
subdomains
|
Sequence[Subdomain]
|
List of subdomain definitions |
required |
interfaces
|
Sequence[Interface]
|
List of interface definitions |
required |
hidden_dims
|
Sequence[int]
|
Hidden layer dimensions for subdomain networks |
required |
config
|
CPINNConfig | None
|
cPINN configuration. Uses defaults if None. |
None
|
activation
|
Callable[[Array], Array]
|
Activation function |
tanh
|
rngs
|
Rngs
|
Random number generators |
required |
compute_continuity_loss
¶
Compute interface continuity loss.
Delegates to the base-class method compute_interface_residual.
Returns:
| Type | Description |
|---|---|
Float[Array, '']
|
Scalar continuity loss (MSE of discontinuity) |
compute_flux_conservation_loss
¶
Compute flux conservation loss at interfaces.
Enforces F_left · n = F_right · n at all interface points, where F = ∇u is the flux.
Returns:
| Type | Description |
|---|---|
Float[Array, '']
|
Scalar flux conservation loss |
compute_interface_loss
¶
Compute total weighted interface loss.
Combines continuity and flux conservation losses with configured weights.
Returns:
| Type | Description |
|---|---|
Float[Array, '']
|
Scalar total interface loss |
APINN (Augmented PINN)¶
opifex.neural.pinns.domain_decomposition.apinn
¶
Augmented Physics-Informed Neural Network (APINN).
APINN uses a learnable gating network to smoothly blend subdomain solutions, allowing the model to learn optimal subdomain selection.
Key Features
- Learnable gating network for subdomain weighting
- Temperature-controlled softmax for soft/hard selection
- Differentiable blending for end-to-end training
References
- Survey Section 8.3.3: Augmented PINNs
APINN
¶
APINN(input_dim: int, output_dim: int, subdomains: Sequence[Subdomain], interfaces: Sequence[Interface], hidden_dims: Sequence[int], *, config: APINNConfig | None = None, activation: Callable[[Array], Array] = tanh, rngs: Rngs)
Bases: DomainDecompositionPINN
Augmented Physics-Informed Neural Network.
APINN uses a learnable gating network to determine how to blend solutions from different subdomains. Unlike FBPINN, which uses fixed window functions, APINN learns the blending.
The output is computed as
u(x) = Σᵢ gᵢ(x) * uᵢ(x)
where gᵢ(x) are the learned gating weights (sum to 1) and uᵢ(x) are the subdomain network outputs.
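Gating weights that sum to 1 come from a temperature-scaled softmax over the gating logits. The blended output is then a weighted sum along the subdomain axis. In the sketch below, `apinn_blend` is a hypothetical helper that takes precomputed logits and subdomain outputs.

```python
import jax
import jax.numpy as jnp


def apinn_blend(gating_logits, outputs, temperature=1.0):
    """u(x) = sum_i g_i(x) u_i(x) with g = softmax(logits / temperature).

    gating_logits: (batch, num_subdomains); outputs: (batch, num_subdomains, out).
    """
    g = jax.nn.softmax(gating_logits / temperature, axis=-1)  # rows sum to 1
    return jnp.einsum("bs,bso->bo", g, outputs), g
```

A low temperature pushes each row of `g` toward a one-hot selection. A high temperature pushes it toward uniform weights.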
Attributes:
| Name | Type | Description |
|---|---|---|
config |
APINN configuration |
|
gating_network |
Network that produces blending weights |
|
input_dim |
Spatial dimension |
|
output_dim |
Solution dimension |
|
subdomains |
List of subdomain definitions |
|
interfaces |
List of interface definitions |
|
networks |
List of subdomain networks |
Example
```python
import jax.numpy as jnp
from flax import nnx

from opifex.neural.pinns.domain_decomposition.apinn import APINN
from opifex.neural.pinns.domain_decomposition.base import Interface, Subdomain

subdomains = [
    Subdomain(id=0, bounds=jnp.array([[0.0, 0.5]])),
    Subdomain(id=1, bounds=jnp.array([[0.5, 1.0]])),
]
interfaces = [
    Interface(subdomain_ids=(0, 1), points=jnp.array([[0.5]]), normal=jnp.array([1.0])),
]
model = APINN(
    input_dim=1, output_dim=1,
    subdomains=subdomains, interfaces=interfaces,
    hidden_dims=[32, 32], rngs=nnx.Rngs(0),
)
```
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
input_dim
|
int
|
Spatial dimension |
required |
output_dim
|
int
|
Solution dimension |
required |
subdomains
|
Sequence[Subdomain]
|
List of subdomain definitions |
required |
interfaces
|
Sequence[Interface]
|
List of interface definitions |
required |
hidden_dims
|
Sequence[int]
|
Hidden layer dimensions for subdomain networks |
required |
config
|
APINNConfig | None
|
APINN configuration. Uses defaults if None. |
None
|
activation
|
Callable[[Array], Array]
|
Activation function |
tanh
|
rngs
|
Rngs
|
Random number generators |
required |
get_gating_weights
¶
Get gating weights for given points.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x
|
Float[Array, 'batch dim']
|
Input coordinates |
required |
Returns:
| Type | Description |
|---|---|
Float[Array, 'batch num_subdomains']
|
Gating weights for each subdomain |
compute_interface_loss
¶
Compute weighted interface continuity loss.
Delegates the continuity computation to the base-class method compute_interface_residual and applies the configured continuity weight.
Returns:
| Type | Description |
|---|---|
Float[Array, '']
|
Scalar interface loss |
APINNConfig
dataclass
¶
APINNConfig(temperature: float = 1.0, gating_hidden_dims: list[int] = field(default_factory=lambda: [16, 16]), continuity_weight: float = 1.0)
Configuration for APINN training.
Attributes:
| Name | Type | Description |
|---|---|---|
temperature |
float
|
Softmax temperature for gating. Lower values give sharper (more discrete) weights, higher values give smoother (more uniform) weights. |
gating_hidden_dims |
list[int]
|
Hidden dimensions for the gating network |
continuity_weight |
float
|
Weight for interface continuity loss |
GatingNetwork
¶
GatingNetwork(input_dim: int, num_subdomains: int, hidden_dims: Sequence[int], *, activation: Callable[[Array], Array] = tanh, rngs: Rngs)
Bases: Module
Gating network for subdomain selection.
This network takes spatial coordinates and outputs weights for blending subdomain solutions.
Attributes:
| Name | Type | Description |
|---|---|---|
layers |
List of linear layers |
|
activation |
Activation function |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
input_dim
|
int
|
Input spatial dimension |
required |
num_subdomains
|
int
|
Number of subdomains to gate |
required |
hidden_dims
|
Sequence[int]
|
Hidden layer dimensions |
required |
activation
|
Callable[[Array], Array]
|
Activation function |
tanh
|
rngs
|
Rngs
|
Random number generators |
required |
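A gating network is a small MLP that maps coordinates to one logit per subdomain. The sketch below implements this in plain JAX with explicit parameter lists instead of Flax NNX modules. The function names, the Glorot-style initialization, and the choice to keep the last layer linear (leaving the softmax to the caller) are assumptions, not the library's implementation.

```python
import jax
import jax.numpy as jnp


def init_gating_params(key, input_dim, num_subdomains, hidden_dims):
    """List of (W, b) pairs for input -> hidden_dims... -> num_subdomains."""
    sizes = [input_dim, *hidden_dims, num_subdomains]
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        # Glorot-style scaling; the library's initializer may differ.
        w = jax.random.normal(sub, (n_in, n_out)) * jnp.sqrt(2.0 / (n_in + n_out))
        params.append((w, jnp.zeros(n_out)))
    return params


def gating_logits(params, x, activation=jnp.tanh):
    """Hidden layers use `activation`; the last layer is linear (one logit per subdomain)."""
    h = x
    for w, b in params[:-1]:
        h = activation(h @ w + b)
    w, b = params[-1]
    return h @ w + b
```

Applying `jax.nn.softmax(logits, axis=-1)` to the result gives blending weights that sum to 1 for each point.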
For usage examples and best practices, see the Domain Decomposition PINNs Guide.
Activations¶
opifex.neural.activations
¶
Activation functions optimized for scientific neural networks.
This module provides a full collection of activation functions specifically optimized for scientific machine learning applications. All functions are fully compatible with Flax NNX patterns and JAX transformations.
Key Features
- Flax NNX compatibility with full type annotations
- Activation lookup by name, with error handling for unknown names
- Implementations tuned for scientific computing
- Support for both standard and specialized activation patterns
get_activation
¶
Get activation function by name or return function if already callable.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
name
|
str | Callable
|
Name of the activation function (case-insensitive) or callable function |
required |
Returns:
| Type | Description |
|---|---|
Any
|
JAX activation function or callable |
Raises:
| Type | Description |
|---|---|
ValueError
|
If activation function is not found |
list_activations
¶
register_activation
¶
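`get_activation` resolves names case-insensitively, passes callables through unchanged, and raises `ValueError` for unknown names. The sketch below shows that lookup pattern, together with the kind of registry that `list_activations` and `register_activation` suggest. The registry contents and the names `register` and `lookup` are hypothetical and do not reflect opifex's actual table.

```python
from typing import Callable

import jax
import jax.numpy as jnp

# Hypothetical registry; not opifex's actual activation table.
_ACTIVATIONS: dict[str, Callable] = {
    "gelu": jax.nn.gelu,
    "relu": jax.nn.relu,
    "tanh": jnp.tanh,
    "sigmoid": jax.nn.sigmoid,
    "silu": jax.nn.silu,
}


def register(name: str, fn: Callable) -> None:
    """Store fn under a lower-cased name so lookups are case-insensitive."""
    _ACTIVATIONS[name.lower()] = fn


def lookup(name_or_fn):
    """Return callables unchanged; resolve strings case-insensitively."""
    if callable(name_or_fn):
        return name_or_fn
    key = name_or_fn.lower()
    if key not in _ACTIVATIONS:
        raise ValueError(f"Unknown activation {name_or_fn!r}; available: {sorted(_ACTIVATIONS)}")
    return _ACTIVATIONS[key]
```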
mish
¶
Mish activation function: x * tanh(softplus(x)).
Mish is a smooth, non-monotonic, self-gated activation function that has been reported to perform well in deep networks.
Mathematical definition: f(x) = x * tanh(ln(1 + exp(x)))
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x
|
Array
|
Input array |
required |
Returns:
| Type | Description |
|---|---|
Array
|
Output array with Mish activation applied |
Note
This implementation computes ln(1 + exp(x)) through softplus for numerical stability, which avoids overflowing exp(x) for large x.
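The formula translates directly to JAX. `jax.nn.softplus` is computed as a log-add-exp, so it stays finite for large inputs. This is a sketch of the formula, not the opifex source.

```python
import jax
import jax.numpy as jnp


def mish(x):
    # x * tanh(softplus(x)), with softplus(x) = ln(1 + exp(x)) computed stably
    return x * jnp.tanh(jax.nn.softplus(x))
```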
snake_activation
¶
Snake activation function: x + sin²(αx)/α.
Snake was designed to help neural networks represent periodic functions, which makes it a candidate for scientific applications with periodic patterns.
Mathematical definition: f(x) = x + (1/α) * sin²(αx), where α is the parameter `a` below.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x
|
Array
|
Input array |
required |
a
|
float
|
Frequency parameter (default: 1.0) |
1.0
|
Returns:
| Type | Description |
|---|---|
Array
|
Output array with Snake activation applied |
Note
The frequency parameter α controls the oscillation frequency. Higher values create more frequent oscillations.
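A direct transcription of the formula, with the parameter named `a` as in the table above. This is a sketch of the formula, not the opifex source.

```python
import jax.numpy as jnp


def snake(x, a=1.0):
    # x + sin^2(a * x) / a; `a` is the frequency (alpha in the formula)
    return x + jnp.sin(a * x) ** 2 / a
```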
gaussian_activation
¶
Gaussian activation function: exp(-x²/(2σ²)).
Gaussian activation can be useful for radial basis function networks and certain scientific applications where localized responses are desired.
Mathematical definition: f(x) = exp(-x²/(2σ²))
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x
|
Array
|
Input array |
required |
sigma
|
float
|
Standard deviation parameter (default: 1.0) |
1.0
|
Returns:
| Type | Description |
|---|---|
Array
|
Output array with Gaussian activation applied |
Note
The σ parameter controls the width of the Gaussian. Smaller values create sharper peaks.
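A direct transcription of the formula follows. It shows that the response peaks at 1 for x = 0 and falls off faster when σ is smaller. This is a sketch of the formula, not the opifex source.

```python
import jax.numpy as jnp


def gaussian_activation(x, sigma=1.0):
    # exp(-x^2 / (2 sigma^2)); peak value 1 at x = 0
    return jnp.exp(-(x**2) / (2.0 * sigma**2))
```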
normalized_tanh
¶
Normalized tanh activation: 1.7159 * tanh(2x/3).
This is a scaled version of tanh whose constants are chosen so that f(±1) ≈ ±1. This keeps the output variance close to 1 for inputs with unit variance, which can help with training stability.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x
|
Array
|
Input array |
required |
Returns:
| Type | Description |
|---|---|
Array
|
Output array with normalized tanh applied |
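The scaling is easy to check numerically: with these constants the function maps ±1 to approximately ±1. This is a sketch of the formula, not the opifex source.

```python
import jax.numpy as jnp


def normalized_tanh(x):
    # 1.7159 * tanh(2x / 3); f(1) ≈ 1 and f(-1) ≈ -1
    return 1.7159 * jnp.tanh(2.0 * x / 3.0)
```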
soft_exponential
¶
Soft exponential activation function.
This is a parameterized activation that interpolates continuously between logarithmic (α < 0), linear (α = 0), and exponential (α > 0) behavior.
Mathematical definition:
- If α < 0: f(x) = -ln(1 - α(x + α)) / α
- If α = 0: f(x) = x
- If α > 0: f(x) = (exp(αx) - 1) / α + α
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x
|
Array
|
Input array |
required |
alpha
|
float
|
Shape parameter |
0.0
|
Returns:
| Type | Description |
|---|---|
Array
|
Output array with soft exponential applied |
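The three cases map to a branch on `alpha`. The sketch below assumes `alpha` is a static Python float, so the branch is resolved outside of JAX tracing. A useful property follows from the formulas: the α < 0 branch inverts the α > 0 branch with the opposite sign, so applying f with α and then with -α recovers x.

```python
import jax.numpy as jnp


def soft_exponential(x, alpha=0.0):
    # alpha is assumed to be a static Python float (not traced).
    if alpha < 0:
        return -jnp.log(1.0 - alpha * (x + alpha)) / alpha  # logarithmic regime
    if alpha == 0:
        return x  # identity
    return (jnp.exp(alpha * x) - 1.0) / alpha + alpha  # exponential regime
```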
get_derivative_activation
¶
Get the derivative of an activation function.
This is useful for implementations that need explicit derivatives rather than relying on automatic differentiation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
name
|
str
|
Name of the activation function |
required |
Returns:
| Type | Description |
|---|---|
Any
|
Derivative function of the specified activation |
Raises:
| Type | Description |
|---|---|
ValueError
|
If activation name is not recognized or derivative not available |
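Closed-form derivatives should agree with automatic differentiation. The sketch below defines explicit derivatives for tanh and sigmoid and cross-checks them against `jax.grad`. The function names are hypothetical and are not the functions returned by `get_derivative_activation`.

```python
import jax
import jax.numpy as jnp


def tanh_derivative(x):
    # d/dx tanh(x) = 1 - tanh(x)^2
    return 1.0 - jnp.tanh(x) ** 2


def sigmoid_derivative(x):
    # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))
    s = jax.nn.sigmoid(x)
    return s * (1.0 - s)


x = jnp.linspace(-3.0, 3.0, 7)
# Cross-check the closed forms against automatic differentiation.
tanh_autodiff = jax.vmap(jax.grad(jnp.tanh))(x)
sigmoid_autodiff = jax.vmap(jax.grad(jax.nn.sigmoid))(x)
```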