
Training API Documentation

Overview

The opifex.training module provides full training infrastructure for scientific machine learning, including physics-informed neural networks, optimization algorithms, and quantum-aware training workflows.

Module Structure:

  • opifex.core.training.trainer - Unified Trainer (Recommended)
  • opifex.core.training.config - Training configuration classes
  • opifex.core.training.physics_configs - Physics-specific configurations
  • opifex.training.basic_trainer - Core trainer implementations
  • opifex.training.metrics - Metrics tracking and state management
  • opifex.training.recovery - Error recovery and stability handling
  • opifex.training.components - Modular training components
  • opifex.training.utils - Utility functions for safe model operations

Core Classes

Trainer (Recommended)

The unified, composable trainer architecture for all training workflows.

import flax.nnx as nnx

from opifex.core.training.trainer import Trainer
from opifex.core.training.config import TrainingConfig
from opifex.core.training.physics_configs import ConservationConfig, MultiScaleConfig

# Configure physics-aware training
conservation_config = ConservationConfig(
    laws=["energy", "momentum"],
    energy_tolerance=1e-6,
)

multiscale_config = MultiScaleConfig(
    scales=["molecular", "atomic"],
    weights={"molecular": 0.5, "atomic": 0.5},
)

config = TrainingConfig(
    num_epochs=100,
    learning_rate=1e-3,
    conservation_config=conservation_config,
    multiscale_config=multiscale_config,
)

# Define a small model for demonstration
class SimpleModel(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(10, 1, rngs=rngs)
    def __call__(self, x):
        return self.linear(x)

model = SimpleModel(rngs=nnx.Rngs(0))
# Create the trainer and train (train_data and val_data are your own datasets)
trainer = Trainer(model, config)
trained_model, history = trainer.train(train_data, val_data)

Key Features:

  • Composable Architecture: Mix and match physics configurations
  • Type-Safe: Full type hints and IDE support
  • Zero Runtime Overhead: Configuration at initialization only
  • Extensible: Add custom configs without modifying trainer
  • Production-Ready: Full testing and error handling

Supported Physics Configurations:

  • ConservationConfig: Energy, momentum, mass, and symmetry conservation
  • MultiScaleConfig: Multi-scale physics with coupling
  • QuantumTrainingConfig: Quantum chemistry and electronic structure
  • BoundaryConfig: Boundary condition enforcement
  • DFTConfig: Density functional theory workflows
  • SCFConfig: Self-consistent field convergence
  • MetricsTrackingConfig: Custom metrics tracking
  • LoggingConfig: Advanced logging and alerting

BasicTrainer

Standard training workflow with physics-informed capabilities.

from opifex.training.basic_trainer import BasicTrainer
from opifex.core.training.config import TrainingConfig

trainer = BasicTrainer(model, config)
trained_model, history = trainer.train(train_data, val_data)

Key Features:

  • Physics-informed neural network (PINN) training
  • Orbax-compatible checkpointing
  • JAX Array and automatic differentiation support
  • Type-safe with jaxtyping annotations

ModularTrainer ✅ NEW

Component-based training architecture with production-grade capabilities.

from opifex.training.basic_trainer import ModularTrainer
from opifex.core.training.config import TrainingConfig
from opifex.core.training.components.recovery import ErrorRecoveryManager
from opifex.core.training.components import FlexibleOptimizerFactory

trainer = ModularTrainer(
    model=model,
    config=config,
    rngs=rngs,
    components={
        "error_recovery": ErrorRecoveryManager(),
        "optimizer_factory": FlexibleOptimizerFactory()
    }
)

Key Features:

  • Component composition architecture
  • Pluggable training components
  • Production-grade error handling
  • Configurable optimizer and learning-rate scheduling strategies
  • Physics-aware metrics collection

Components

ErrorRecoveryManager ✅ NEW

Production-grade error handling with gradient stability and automatic recovery.

from opifex.training.recovery import ErrorRecoveryManager

error_manager = ErrorRecoveryManager(
    config={
        "max_retries": 5,
        "gradient_clip_threshold": 1.0,
        "loss_explosion_threshold": 100.0,
        "checkpoint_on_error": True
    }
)

Features:

  • Gradient clipping with automatic threshold adaptation
  • NaN detection and recovery mechanisms
  • Loss explosion detection and mitigation
  • Multiple recovery strategies (gradient clipping, learning rate reduction, parameter reinitialization)
  • Full error logging and analytics
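The recovery mechanisms listed above rest on a few generic checks. The sketch below shows global-norm gradient clipping, NaN detection, and loss-explosion detection in plain JAX. It is not ErrorRecoveryManager's implementation: `global_norm`, `clip_by_global_norm`, and `check_step` are hypothetical helpers, and their `threshold` and `explosion_threshold` arguments only mirror the `gradient_clip_threshold` and `loss_explosion_threshold` config keys.

```python
import jax
import jax.numpy as jnp


def global_norm(grads):
    """L2 norm over all leaves of a gradient pytree."""
    leaves = jax.tree_util.tree_leaves(grads)
    return jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))


def clip_by_global_norm(grads, threshold):
    """Rescale the whole pytree so that its global norm is at most `threshold`."""
    norm = global_norm(grads)
    scale = jnp.minimum(1.0, threshold / (norm + 1e-12))
    clipped = jax.tree_util.tree_map(lambda g: g * scale, grads)
    return clipped, norm


def check_step(loss, grads, explosion_threshold):
    """Classify a step as "nan", "explosion", or "ok" (eager Python, not jit-compatible)."""
    leaves = [loss] + jax.tree_util.tree_leaves(grads)
    if not all(bool(jnp.all(jnp.isfinite(x))) for x in leaves):
        return "nan"
    if float(loss) > explosion_threshold:
        return "explosion"
    return "ok"


# Hypothetical gradients for a tiny model
grads = {"w": jnp.array([3.0, 4.0]), "b": jnp.array([0.0])}
clipped, norm = clip_by_global_norm(grads, threshold=1.0)  # norm is 5.0, so grads are scaled by 0.2
status = check_step(jnp.array(500.0), grads, explosion_threshold=100.0)
```

A recovery manager would then map each status to a strategy such as clipping harder, reducing the learning rate, or restoring a checkpoint.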

FlexibleOptimizerFactory ✅ NEW

Optimizer creation with learning-rate scheduling support.

from opifex.core.training.components import FlexibleOptimizerFactory

optimizer_factory = FlexibleOptimizerFactory(
    config={
        "optimizer_type": "adamw",  # "adam", "adamw", "sgd"
        "learning_rate": 1e-3,
        "schedule_type": "cosine",  # "cosine", "exponential", "linear"
        "total_steps": 1000,
        "cosine_alpha": 0.0
    }
)

Supported Optimizers:

  • Adam: Adaptive moment estimation
  • AdamW: Adam with weight decay
  • SGD: Stochastic gradient descent with momentum

Supported Schedules:

  • Cosine: Cosine annealing learning rate decay
  • Exponential: Exponential decay
  • Linear: Linear decay
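FlexibleOptimizerFactory's internals are not documented here. The sketch below shows the standard cosine-annealing formula (the same definition optax's `cosine_decay_schedule` uses), assuming `cosine_alpha` is the final learning rate expressed as a fraction of the initial one. `cosine_decay` is a hypothetical helper.

```python
import jax.numpy as jnp


def cosine_decay(step, learning_rate, total_steps, alpha=0.0):
    """Anneal from `learning_rate` to `alpha * learning_rate` over `total_steps`."""
    # Progress is clamped so the rate stays at its floor after total_steps
    progress = jnp.minimum(step, total_steps) / total_steps
    cosine = 0.5 * (1.0 + jnp.cos(jnp.pi * progress))
    return learning_rate * ((1.0 - alpha) * cosine + alpha)


# Learning rate at the start, midpoint, and end of a 1000-step run
schedule = [float(cosine_decay(s, 1e-3, 1000)) for s in (0, 500, 1000)]
```

With `alpha=0.0` the rate decays all the way to zero; a small positive value keeps some learning late in training.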

AdvancedMetricsCollector ✅ NEW

Physics-aware metrics collection with convergence tracking.

from opifex.training.metrics import AdvancedMetricsCollector

collector = AdvancedMetricsCollector()
collector.start_training()
metrics = collector.collect_physics_metrics(model, x, y_true)

Collected Metrics:

  • Training loss and validation metrics
  • Gradient norms and stability indicators
  • Physics-specific metrics (energy conservation, mass conservation)
  • Convergence rates and training diagnostics
  • Real-time performance analytics

TrainingComponent ✅ NEW

Base class for creating custom training components.

from opifex.core.training.components import TrainingComponent

class CustomComponent(TrainingComponent):
    def setup(self, model, training_state):
        # Initialize component
        pass

    def step(self, model, training_state):
        # Update component state
        pass

Purpose:

  • Enables modular component development
  • Provides common interface for training components
  • Supports pluggable architecture patterns

Configuration Classes

TrainingConfig

Training configuration with full parameter control.

from opifex.core.training.config import TrainingConfig

config = TrainingConfig(
    num_epochs=1000,
    batch_size=64,
    learning_rate=1e-3,
    validation_frequency=100,
    checkpoint_frequency=500,
    early_stopping=True,
    patience=50
)

Parameters:

  • num_epochs: Number of training epochs
  • batch_size: Training batch size
  • learning_rate: Learning rate (can be overridden by optimizer factory)
  • validation_frequency: Validation evaluation frequency
  • checkpoint_frequency: Model checkpointing frequency
  • early_stopping: Enable early stopping
  • patience: Early stopping patience
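The exact semantics of `patience` are not specified here (for example, whether it counts epochs or validation checks). A minimal sketch of the usual patience rule, assuming it counts validation checks without improvement; the `EarlyStopping` class and its `min_delta` parameter are hypothetical.

```python
class EarlyStopping:
    """Stop once validation loss has not improved for `patience` consecutive checks."""

    def __init__(self, patience, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta  # minimum decrease that counts as an improvement
        self.best = float("inf")
        self.bad_checks = 0

    def update(self, val_loss):
        """Record one validation result; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


stopper = EarlyStopping(patience=2)
decisions = [stopper.update(loss) for loss in [1.0, 0.8, 0.9, 0.85]]
```

Here the last two checks fail to beat 0.8, so the fourth call signals a stop.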

TrainingState

Training state that tracks progress, best metrics, and recovery history.

from opifex.training.metrics import TrainingState

# Automatically managed by trainers
state = trainer.training_state
print(f"Current epoch: {state.epoch}")
print(f"Best validation loss: {state.best_val_loss}")

Tracked Information:

  • Current epoch and step counters
  • Best validation metrics
  • Model and optimizer states
  • Recovery attempt history
  • Training diagnostics

Physics-Informed Training

PhysicsInformedLoss

Hierarchical multi-physics loss composition with adaptive weighting.

from opifex.core.physics.losses import PhysicsInformedLoss, PhysicsLossConfig

physics_loss = PhysicsInformedLoss(
    config=PhysicsLossConfig(
        physics_weight=1.0,
        boundary_weight=1.0,
        data_weight=1.0,
        adaptive_weighting=True
    )
)

# Use with BasicTrainer
trainer.set_physics_loss(physics_loss)

Supported Physics:

  • Partial differential equations (PDEs)
  • Conservation laws (mass, momentum, energy)
  • Quantum mechanical constraints
  • Boundary condition enforcement
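With fixed weights, combining `physics_weight`, `boundary_weight`, and `data_weight` amounts to a weighted sum of per-term errors. The sketch below is a generic illustration, not PhysicsInformedLoss's implementation: `composite_loss` and its arguments are hypothetical, each term is assumed to be a mean-squared error, and adaptive weighting is not modeled.

```python
import jax.numpy as jnp


def composite_loss(residuals, bc_pred, bc_true, data_pred, data_true,
                   physics_weight=1.0, boundary_weight=1.0, data_weight=1.0):
    """Weighted sum of PDE-residual, boundary, and data mean-squared errors."""
    physics = jnp.mean(residuals ** 2)
    boundary = jnp.mean((bc_pred - bc_true) ** 2)
    data = jnp.mean((data_pred - data_true) ** 2)
    total = physics_weight * physics + boundary_weight * boundary + data_weight * data
    # Return the individual terms too, so their balance can be monitored
    return total, {"physics": physics, "boundary": boundary, "data": data}


total, parts = composite_loss(
    residuals=jnp.array([1.0, -1.0]),
    bc_pred=jnp.array([0.0]), bc_true=jnp.array([2.0]),
    data_pred=jnp.array([1.0, 3.0]), data_true=jnp.array([1.0, 1.0]),
    physics_weight=1.0, boundary_weight=0.5, data_weight=2.0,
)
```

Logging the individual terms makes it easier to see when one term dominates and the weights need rebalancing.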

Usage Examples

Basic Training Workflow

import jax
import jax.numpy as jnp
import flax.nnx as nnx
from opifex.neural.base import StandardMLP
from opifex.training.basic_trainer import BasicTrainer
from opifex.core.training.config import TrainingConfig

# Create model
model = StandardMLP([1, 32, 32, 1], activation="tanh", rngs=nnx.Rngs(42))

# Configure training
config = TrainingConfig(num_epochs=1000, batch_size=64, learning_rate=1e-3)

# Create trainer and train
trainer = BasicTrainer(model, config)
trained_model, history = trainer.train(train_data, val_data)

Modular Training

from opifex.training.basic_trainer import ModularTrainer
from opifex.core.training.config import TrainingConfig
from opifex.core.training.components.recovery import ErrorRecoveryManager
from opifex.core.training.components import FlexibleOptimizerFactory

# Configure components
error_recovery = ErrorRecoveryManager(config={"max_retries": 5, "gradient_clip_threshold": 1.0})
optimizer_factory = FlexibleOptimizerFactory(config={"optimizer_type": "adamw", "schedule_type": "cosine"})

# Create modular trainer
# Note: AdvancedMetricsCollector is automatically created by ModularTrainer
trainer = ModularTrainer(
    model=model,
    config=config,
    rngs=rngs,
    components={
        "error_recovery": error_recovery,
        "optimizer_factory": optimizer_factory
    }
)

# Train with capabilities
trained_model, history = trainer.train(train_data, val_data)

Physics-Informed Training

import jax
import jax.numpy as jnp

from opifex.training.basic_trainer import BasicTrainer
from opifex.core.physics.losses import PhysicsInformedLoss

# Define the PDE residual for the 1D heat equation u_t = 0.1 * u_xx
def pde_residual(model_fn, x, t):
    # Scalar wrapper: assumes model_fn maps a (1, 2) batch to a (1, 1) output,
    # since jax.grad requires a scalar-valued function
    def u(x, t):
        return model_fn(jnp.array([x, t]).reshape(1, -1))[0, 0]

    u_t = jax.grad(u, argnums=1)(x, t)
    u_xx = jax.grad(jax.grad(u, argnums=0), argnums=0)(x, t)
    return u_t - 0.1 * u_xx

# Configure physics loss
physics_loss = PhysicsInformedLoss(pde_residual=pde_residual)

# Set up PINN training
trainer = BasicTrainer(model, config)
trainer.set_physics_loss(physics_loss)

# Train with physics constraints
trained_model, history = trainer.train(
    train_data=(domain_points, None),  # No target data for domain points
    boundary_data=(boundary_points, boundary_values)
)
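A quick sanity check for a heat-equation residual is to evaluate it on a known exact solution, where it should vanish. For u(x, t) = exp(-0.1 π² t) sin(π x), u_t = -0.1 π² u and u_xx = -π² u, so u_t - 0.1 u_xx = 0. In the sketch below, `exact_model_fn` is a hypothetical stand-in for a trained model (assumed to map a (1, 2) batch to a (1, 1) output), and `jax.vmap` evaluates the residual at many collocation points at once.

```python
import jax
import jax.numpy as jnp

KAPPA = 0.1  # diffusivity in u_t = KAPPA * u_xx


def exact_model_fn(inputs):
    """Hypothetical stand-in for a model: maps [[x, t]] of shape (1, 2) to shape (1, 1)."""
    x, t = inputs[:, 0:1], inputs[:, 1:2]
    return jnp.exp(-KAPPA * jnp.pi ** 2 * t) * jnp.sin(jnp.pi * x)


def heat_residual(model_fn, x, t):
    """Pointwise residual u_t - KAPPA * u_xx at scalar (x, t)."""
    def u(x, t):
        return model_fn(jnp.array([x, t]).reshape(1, -1))[0, 0]

    u_t = jax.grad(u, argnums=1)(x, t)
    u_xx = jax.grad(jax.grad(u, argnums=0), argnums=0)(x, t)
    return u_t - KAPPA * u_xx


# Vectorize over paired collocation points (x_i, t_i)
xs = jnp.linspace(0.0, 1.0, 11)
ts = jnp.linspace(0.0, 1.0, 11)
residuals = jax.vmap(lambda x, t: heat_residual(exact_model_fn, x, t))(xs, ts)
```

The model function is closed over in a lambda rather than passed through `vmap`, so only the arrays being mapped are `vmap` arguments.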

Integration

With Neural Networks

import flax.nnx as nnx

from opifex.neural.base import StandardMLP
from opifex.neural.quantum import QuantumMLP

rngs = nnx.Rngs(0)

# Standard networks
standard_model = StandardMLP([3, 64, 64, 1], activation="swish", rngs=rngs)

# Quantum networks
quantum_model = QuantumMLP(features=[128, 128, 1], n_atoms=3, rngs=rngs)

With Optimization

from opifex.optimization import MetaOptimizer
from opifex.training.basic_trainer import BasicTrainer

# Use with learn-to-optimize
meta_optimizer = MetaOptimizer()
trainer = BasicTrainer(model, config, meta_optimizer=meta_optimizer)

With Geometry

from opifex.geometry import ComplexDomain
from opifex.core.conditions import DirichletBC

# Complex domain training
domain = ComplexDomain(boundaries=["left", "right", "top", "bottom"])
boundary_conditions = [DirichletBC(boundary="left", value=0.0)]

Best Practices

Production Training

  1. Use ModularTrainer for production workflows with error recovery
  2. Configure appropriate error recovery strategies for your problem
  3. Monitor training metrics with AdvancedMetricsCollector
  4. Use learning rate scheduling for better convergence
  5. Enable checkpointing for long training runs

Physics-Informed Training

  1. Balance loss weights between physics, boundary, and data terms
  2. Use adaptive weighting for complex multi-physics problems
  3. Monitor conservation laws during training
  4. Validate physics constraints on test data

Performance Optimization

  1. Choose appropriate batch sizes for your hardware
  2. Use JAX transformations (vmap, jit) for efficiency
  3. Profile training with JAX profiling tools
  4. Monitor gradient health and stability

Troubleshooting

Common Issues

  • NaN losses: Enable NaN detection in ErrorRecoveryManager
  • Gradient explosions: Use gradient clipping with appropriate thresholds
  • Slow convergence: Try different optimizers and learning rate schedules
  • Physics constraint violations: Increase physics loss weights or improve residual computation

Debug Features

  • Full logging of training metrics and errors
  • Recovery attempt tracking for debugging stability issues
  • Gradient norm monitoring for optimization health
  • Physics constraint validation for PINN problems

Multilevel Training

Coarse-to-fine training hierarchies for accelerated convergence.

Width-Based Hierarchy (MLPs)

opifex.training.multilevel.coarse_to_fine

CascadeTrainer

CascadeTrainer(input_dim: int, output_dim: int, base_hidden_dims: Sequence[int], config: MultilevelConfig | None = None, *, activation: Callable[[Array], Array] = tanh, rngs: Rngs)

Cascade trainer for multilevel training.

Trains models from coarse to fine levels, transferring learned parameters between levels.

Attributes:

  • config: Multilevel configuration
  • hierarchy: List of models from coarse to fine
  • current_level: Current training level

Example:

trainer = CascadeTrainer(
    input_dim=1,
    output_dim=1,
    base_hidden_dims=[64, 64],
    config=MultilevelConfig(num_levels=3),
    rngs=nnx.Rngs(0),
)
model = trainer.get_current_model()
# Train model...
trainer.advance_level()
finer_model = trainer.get_current_model()

Parameters:

  • input_dim (int, required): Input dimension
  • output_dim (int, required): Output dimension
  • base_hidden_dims (Sequence[int], required): Hidden dimensions for the finest level
  • config (MultilevelConfig | None, default None): Multilevel configuration
  • activation (Callable[[Array], Array], default tanh): Activation function
  • rngs (Rngs, required): Random number generators

get_current_model

get_current_model() -> MultilevelMLP

Get model at current level.

Returns:

  • MultilevelMLP: Current level model

advance_level

advance_level() -> bool

Advance to next finer level.

Transfers learned parameters from current level to next level.

Returns:

  • bool: True if advanced successfully, False if already at the finest level

is_at_finest

is_at_finest() -> bool

Check if at finest level.

Returns:

  • bool: True if at the finest level

get_epochs_for_current_level

get_epochs_for_current_level() -> int

Get number of epochs for current level.

Returns:

  • int: Number of epochs to train at the current level

MultilevelConfig dataclass

MultilevelConfig(num_levels: int = 3, coarsening_factor: float = 0.5, level_epochs: list[int] = field(default_factory=lambda: [100, 200, 300]), warmup_epochs: int = 0)

Configuration for multilevel training.

Attributes:

  • num_levels (int): Number of levels in the hierarchy
  • coarsening_factor (float): Factor to reduce width at each coarser level
  • level_epochs (list[int]): Number of epochs to train at each level
  • warmup_epochs (int): Extra epochs at the finest level

create_network_hierarchy

create_network_hierarchy(input_dim: int, output_dim: int, base_hidden_dims: Sequence[int], num_levels: int, coarsening_factor: float = 0.5, *, activation: Callable[[Array], Array] = tanh, rngs: Rngs) -> list[MultilevelMLP]

Create hierarchy of networks from coarse to fine.

The finest level (highest index) uses the base_hidden_dims. Coarser levels use progressively smaller networks.

Parameters:

  • input_dim (int, required): Input dimension
  • output_dim (int, required): Output dimension
  • base_hidden_dims (Sequence[int], required): Hidden dimensions for the finest level
  • num_levels (int, required): Number of levels in the hierarchy
  • coarsening_factor (float, default 0.5): Factor to reduce width at each coarser level
  • activation (Callable[[Array], Array], default tanh): Activation function
  • rngs (Rngs, required): Random number generators

Returns:

  • list[MultilevelMLP]: List of networks from coarsest to finest
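The exact width rule is not stated beyond "progressively smaller networks". A sketch of one plausible rule, assuming each coarser level multiplies every hidden width by `coarsening_factor` (rounded, with a floor of 1). `hidden_dims_hierarchy` is a hypothetical helper that only computes widths; it does not build MultilevelMLP models.

```python
def hidden_dims_hierarchy(base_hidden_dims, num_levels, coarsening_factor=0.5):
    """Hidden-layer widths for each level, ordered coarsest (index 0) to finest."""
    hierarchy = []
    for level in range(num_levels):
        # The finest level (num_levels - 1) uses scale 1.0, i.e. base_hidden_dims
        scale = coarsening_factor ** (num_levels - 1 - level)
        hierarchy.append([max(1, int(round(w * scale))) for w in base_hidden_dims])
    return hierarchy


widths = hidden_dims_hierarchy([64, 64], num_levels=3)
```

With the default factor of 0.5, each coarser level halves the width, so a three-level hierarchy over [64, 64] trains 16-, 32-, and 64-wide networks in turn.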

prolongate

prolongate(coarse_model: MultilevelMLP, fine_model: MultilevelMLP) -> MultilevelMLP

Transfer (prolongate) parameters from coarse to fine model.

This copies the coarse model parameters to the corresponding subset of the fine model parameters. Additional fine model parameters are left at their initialized values.

Parameters:

  • coarse_model (MultilevelMLP, required): Coarse level model
  • fine_model (MultilevelMLP, required): Fine level model (modified in place)

Returns:

  • MultilevelMLP: Fine model with prolongated parameters

restrict

restrict(fine_model: MultilevelMLP, coarse_model: MultilevelMLP) -> MultilevelMLP

Transfer (restrict) parameters from fine to coarse model.

This copies a subset of the fine model parameters to the coarse model.

Parameters:

  • fine_model (MultilevelMLP, required): Fine level model
  • coarse_model (MultilevelMLP, required): Coarse level model (modified in place)

Returns:

  • MultilevelMLP: Coarse model with restricted parameters
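prolongate and restrict copy parameters between "corresponding subsets" of two networks. A sketch on plain JAX arrays, assuming the subset is the leading block of each weight matrix and bias. `init_mlp`, `prolongate_params`, and `restrict_params` are hypothetical and operate on lists of (W, b) pairs rather than MultilevelMLP objects.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, dims):
    """Random (W, b) pairs for layer sizes `dims`; W has shape (in, out)."""
    params = []
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        key, sub = jax.random.split(key)
        params.append((jax.random.normal(sub, (d_in, d_out)) / jnp.sqrt(d_in), jnp.zeros(d_out)))
    return params


def prolongate_params(coarse, fine):
    """Copy each coarse (W, b) into the leading block of the matching fine (W, b)."""
    out = []
    for (wc, bc), (wf, bf) in zip(coarse, fine):
        wf = wf.at[: wc.shape[0], : wc.shape[1]].set(wc)
        bf = bf.at[: bc.shape[0]].set(bc)
        out.append((wf, bf))
    return out


def restrict_params(fine, coarse):
    """Take the leading block of each fine (W, b) with the coarse layer's shape."""
    return [
        (wf[: wc.shape[0], : wc.shape[1]], bf[: bc.shape[0]])
        for (wf, bf), (wc, bc) in zip(fine, coarse)
    ]


coarse = init_mlp(jax.random.key(0), [1, 16, 16, 1])
fine = prolongate_params(coarse, init_mlp(jax.random.key(1), [1, 32, 32, 1]))
recovered = restrict_params(fine, coarse)  # round trip returns the coarse parameters
```

Because the extra fine-level weights keep their initial values, the prolongated fine network does not reproduce the coarse network's output exactly; it starts from a warm subset of parameters.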

Mode-Based Hierarchy (FNOs)

opifex.training.multilevel.multilevel_fno

MultilevelFNOTrainer

MultilevelFNOTrainer(width: int, input_dim: int, output_dim: int, config: MultilevelFNOConfig | None = None, *, num_layers: int = 4, activation: Callable[[Array], Array] = gelu, rngs: Rngs)

Trainer for multilevel FNO.

Trains FNOs from coarse to fine modes, transferring learned spectral weights between levels.

Attributes:

  • config: Multilevel configuration
  • hierarchy: List of FNOs from coarse to fine
  • current_level: Current training level

Parameters:

  • width (int, required): Hidden channel width
  • input_dim (int, required): Input channels
  • output_dim (int, required): Output channels
  • config (MultilevelFNOConfig | None, default None): Multilevel configuration
  • num_layers (int, default 4): FNO layers per network
  • activation (Callable[[Array], Array], default gelu): Activation function
  • rngs (Rngs, required): Random number generators

get_current_model

get_current_model() -> SimpleFNO

Get FNO at current level.

Returns:

  • SimpleFNO: Current level FNO

advance_level

advance_level() -> bool

Advance to next finer level.

Transfers learned weights from current level to next level.

Returns:

  • bool: True if advanced successfully, False if already at the finest level

is_at_finest

is_at_finest() -> bool

Check if at finest level.

Returns:

  • bool: True if at the finest level

get_epochs_for_current_level

get_epochs_for_current_level() -> int

Get epochs for current level from config.

Returns:

  • int: Number of epochs to train at the current level

MultilevelFNOConfig dataclass

MultilevelFNOConfig(num_levels: int = 3, base_modes: int = 12, mode_reduction_factor: int = 2, level_epochs: list[int] = field(default_factory=lambda: [50, 100, 150]))

Configuration for multilevel FNO training.

Attributes:

  • num_levels (int): Number of levels in the hierarchy
  • base_modes (int): Number of Fourier modes at the finest level
  • mode_reduction_factor (int): Factor to reduce modes at each coarser level
  • level_epochs (list[int]): Epochs to train at each level

create_fno_hierarchy

create_fno_hierarchy(base_modes: int, width: int, input_dim: int, output_dim: int, num_levels: int, reduction_factor: int = 2, *, num_layers: int = 4, activation: Callable[[Array], Array] = gelu, rngs: Rngs) -> list[SimpleFNO]

Create hierarchy of FNOs from coarse to fine.

Parameters:

  • base_modes (int, required): Modes at the finest level
  • width (int, required): Hidden channel width (same for all levels)
  • input_dim (int, required): Input channels
  • output_dim (int, required): Output channels
  • num_levels (int, required): Number of levels
  • reduction_factor (int, default 2): Mode reduction per level
  • num_layers (int, default 4): FNO layers per network
  • activation (Callable[[Array], Array], default gelu): Activation function
  • rngs (Rngs, required): Random number generators

Returns:

  • list[SimpleFNO]: List of FNOs from coarsest to finest

create_mode_hierarchy

create_mode_hierarchy(base_modes: int, num_levels: int, reduction_factor: int = 2) -> list[int]

Create hierarchy of mode counts from coarse to fine.

The finest level (highest index) uses base_modes. Coarser levels use progressively fewer modes.

Parameters:

  • base_modes (int, required): Number of modes at the finest level
  • num_levels (int, required): Number of levels in the hierarchy
  • reduction_factor (int, default 2): Factor to reduce modes at each coarser level

Returns:

  • list[int]: List of mode counts from coarsest to finest
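A sketch of how such a mode hierarchy can be computed, assuming each coarser level divides the mode count by `reduction_factor` using integer division, with a floor of one mode. `mode_hierarchy` is a hypothetical helper, not the library function.

```python
def mode_hierarchy(base_modes, num_levels, reduction_factor=2):
    """Fourier-mode counts per level, ordered coarsest (index 0) to finest."""
    return [
        # The finest level (num_levels - 1) keeps base_modes unchanged
        max(1, base_modes // reduction_factor ** (num_levels - 1 - level))
        for level in range(num_levels)
    ]


modes = mode_hierarchy(12, num_levels=3)
```

With the MultilevelFNOConfig defaults (12 modes, 3 levels, factor 2), this rule gives 3, 6, and 12 modes.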

prolongate_fno_modes

prolongate_fno_modes(coarse_fno: SimpleFNO, fine_fno: SimpleFNO) -> SimpleFNO

Transfer spectral weights from coarse to fine FNO.

Copies the lower-frequency modes from the coarse network to the corresponding modes in the fine network.

Parameters:

  • coarse_fno (SimpleFNO, required): Coarse FNO (fewer modes)
  • fine_fno (SimpleFNO, required): Fine FNO (more modes, modified in place)

Returns:

  • SimpleFNO: Fine FNO with prolongated weights

restrict_fno_modes

restrict_fno_modes(fine_fno: SimpleFNO, coarse_fno: SimpleFNO) -> SimpleFNO

Transfer spectral weights from fine to coarse FNO.

Copies the lower-frequency modes from the fine network to the coarse network (truncation).

Parameters:

  • fine_fno (SimpleFNO, required): Fine FNO (more modes)
  • coarse_fno (SimpleFNO, required): Coarse FNO (fewer modes, modified in place)

Returns:

  • SimpleFNO: Coarse FNO with restricted weights
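In one dimension, transferring spectral weights reduces to slicing along the mode axis. The sketch below assumes complex weights of shape (in_ch, out_ch, modes), with index 0 as the lowest retained frequency. SimpleFNO's actual parameter layout may differ (for example, 2D layers have two mode axes). `random_spectral_weights`, `prolongate_modes`, and `restrict_modes` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def random_spectral_weights(key, in_ch, out_ch, modes):
    """Complex weights of shape (in_ch, out_ch, modes), as in a 1D Fourier layer."""
    k_re, k_im = jax.random.split(key)
    scale = 1.0 / (in_ch * out_ch)
    real = jax.random.normal(k_re, (in_ch, out_ch, modes))
    imag = jax.random.normal(k_im, (in_ch, out_ch, modes))
    return scale * (real + 1j * imag)


def prolongate_modes(coarse_w, fine_w):
    """Copy the coarse layer's low-frequency modes into the fine weights."""
    m = coarse_w.shape[-1]
    return fine_w.at[..., :m].set(coarse_w)


def restrict_modes(fine_w, num_coarse_modes):
    """Truncate fine weights to their lowest `num_coarse_modes` modes."""
    return fine_w[..., :num_coarse_modes]


coarse_w = random_spectral_weights(jax.random.key(0), 8, 8, 6)
fine_init = random_spectral_weights(jax.random.key(1), 8, 8, 12)
fine_w = prolongate_modes(coarse_w, fine_init)  # modes 6..11 keep their initial values
```

Restriction after prolongation returns the coarse weights unchanged, while the fine network's higher modes start from their initial values.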

For usage examples and best practices, see the Multilevel Training Guide.

Adaptive Sampling

Residual-based sampling strategies for efficient PINN training.

opifex.training.adaptive_sampling

RADSampler

RADSampler(config: RADConfig | None = None)

Residual-based Adaptive Distribution sampler.

This sampler draws collocation points from the domain with probability proportional to the PDE residual magnitude.

Attributes:

  • config: RAD configuration

Example:

sampler = RADSampler()
domain_points = jnp.linspace(0, 1, 100).reshape(-1, 1)
residuals = compute_pde_residual(model, domain_points)
key = jax.random.key(0)
sampled = sampler.sample(domain_points, residuals, batch_size=32, key=key)

Parameters:

  • config (RADConfig | None, default None): RAD configuration. Uses defaults if None.

sample

sample(domain_points: Float[Array, ...], residuals: Float[Array, ...], batch_size: int, key: PRNGKeyArray) -> Float[Array, ...]

Sample collocation points based on residual distribution.

Parameters:

  • domain_points (Float[Array, ...], required): Candidate points in the domain
  • residuals (Float[Array, ...], required): PDE residual magnitudes at each point
  • batch_size (int, required): Number of points to sample
  • key (PRNGKeyArray, required): JAX random key

Returns:

  • Float[Array, ...]: Sampled collocation points
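A minimal sketch of residual-proportional sampling with `jax.random.choice`. It assumes sampling without replacement, so `batch_size` must not exceed the number of candidates. `rad_sample` is a hypothetical helper: it is not RADSampler's implementation and ignores the `temperature` field.

```python
import jax
import jax.numpy as jnp


def rad_sample(domain_points, residuals, batch_size, key, beta=1.0, min_probability=1e-6):
    """Draw `batch_size` distinct points with probability proportional to |residual|**beta."""
    weights = jnp.abs(residuals).reshape(-1) ** beta
    probs = weights / jnp.sum(weights)
    # Floor the probabilities so low-residual regions keep some coverage, then renormalize
    probs = jnp.maximum(probs, min_probability)
    probs = probs / jnp.sum(probs)
    idx = jax.random.choice(key, domain_points.shape[0], shape=(batch_size,),
                            replace=False, p=probs)
    return domain_points[idx]


# Toy setup: the residual is nonzero only for x > 0.9
domain_points = jnp.linspace(0.0, 1.0, 100).reshape(-1, 1)
residuals = jnp.where(domain_points[:, 0] > 0.9, 1.0, 0.0)
sampled = rad_sample(domain_points, residuals, batch_size=5,
                     key=jax.random.key(0), min_probability=0.0)
```

Setting `min_probability=0.0` here confines sampling to the high-residual region. A positive floor keeps some points elsewhere in the domain.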

compute_weights

compute_weights(residuals: Float[Array, ...]) -> Float[Array, ...]

Compute importance weights for residual-weighted loss.

These weights can be used to weight the loss function instead of resampling the collocation points.

Parameters:

  • residuals (Float[Array, ...], required): PDE residual magnitudes

Returns:

  • Float[Array, ...]: Importance weights
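One way to use such weights is sketched below. It assumes weights proportional to |r|^β, normalized to mean 1 so the weighted loss stays on the same scale as the plain mean. `importance_weights` and `weighted_residual_loss` are hypothetical helpers. The weights go through `jax.lax.stop_gradient`, so the loss gradient flows only through the squared residuals.

```python
import jax
import jax.numpy as jnp


def importance_weights(residuals, beta=1.0):
    """Per-point loss weights proportional to |r|**beta, normalized to mean 1."""
    w = jnp.abs(residuals) ** beta
    return w * (w.shape[0] / jnp.sum(w))


def weighted_residual_loss(residuals, beta=1.0):
    # Treat the weights as constants so they do not contribute to the gradient
    w = jax.lax.stop_gradient(importance_weights(residuals, beta))
    return jnp.mean(w * residuals ** 2)


residuals = jnp.array([1.0, 1.0, 2.0])
weights = importance_weights(residuals)
loss = weighted_residual_loss(residuals)
```

Reweighting keeps the collocation set fixed, which avoids recompiling for new point sets. Resampling instead concentrates the points themselves where the residual is large.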

RADConfig dataclass

RADConfig(beta: float = 1.0, resample_frequency: int = 100, min_probability: float = 1e-06, temperature: float = 1.0)

Configuration for Residual-based Adaptive Distribution sampling.

Attributes:

  • beta (float): Exponent for residual weighting. Higher values concentrate sampling more strongly on high-residual regions: ξ_j = |r_j|^β / Σ_k |r_k|^β
  • resample_frequency (int): Number of training steps between resampling
  • min_probability (float): Minimum sampling probability to ensure coverage
  • temperature (float): Temperature for probability smoothing

RARDRefiner

RARDRefiner(config: RARDConfig | None = None, num_new_points: int | None = None, noise_scale: float | None = None)

Residual-based Adaptive Refinement with Distribution.

This refiner adds new collocation points near regions with high PDE residual, adaptively increasing resolution where needed.

Attributes:

  • config: RAR-D configuration

Example:

refiner = RARDRefiner(num_new_points=20)
refined_points = refiner.refine(points, residuals, bounds, key)

Parameters:

  • config (RARDConfig | None, default None): RAR-D configuration. Uses defaults if None.
  • num_new_points (int | None, default None): Override for the number of new points
  • noise_scale (float | None, default None): Override for the noise scale

refine

refine(current_points: Float[Array, ...], residuals: Float[Array, ...], bounds: Float[Array, 'dim 2'], key: PRNGKeyArray) -> Float[Array, ...]

Add new points near high-residual regions.

Parameters:

  • current_points (Float[Array, ...], required): Current collocation points
  • residuals (Float[Array, ...], required): PDE residual magnitudes at each point
  • bounds (Float[Array, 'dim 2'], required): Domain bounds, shape (dim, 2) with [min, max] per dimension
  • key (PRNGKeyArray, required): JAX random key

Returns:

  • Float[Array, ...]: Refined point set including new points

identify_refinement_regions

identify_refinement_regions(residuals: Float[Array, ...]) -> Float[Array, ...]

Identify which points are in refinement regions.

Parameters:

  • residuals (Float[Array, ...], required): PDE residual magnitudes

Returns:

  • Float[Array, ...]: Boolean mask indicating refinement regions

RARDConfig dataclass

RARDConfig(num_new_points: int = 10, percentile_threshold: float = 90.0, noise_scale: float = 0.1)

Configuration for RAR-D refinement.

Attributes:

  • num_new_points (int): Number of new points to add per refinement
  • percentile_threshold (float): Only refine near points above this percentile
  • noise_scale (float): Scale of random perturbation for new points
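A minimal sketch of the refinement step described above. It assumes parents are chosen uniformly among points at or above `percentile_threshold`, then perturbed with Gaussian noise whose standard deviation is `noise_scale` times each dimension's domain width, and clipped to `bounds`. `rard_refine` is a hypothetical helper, not RARDRefiner's implementation.

```python
import jax
import jax.numpy as jnp


def rard_refine(points, residuals, bounds, key, num_new_points=10,
                percentile_threshold=90.0, noise_scale=0.1):
    """Append jittered copies of high-residual points, clipped to `bounds` (shape (dim, 2))."""
    r = jnp.abs(residuals).reshape(-1)
    # Refinement region: points whose residual is at or above the percentile
    mask = (r >= jnp.percentile(r, percentile_threshold)).astype(jnp.float32)
    probs = mask / jnp.sum(mask)
    k_parent, k_noise = jax.random.split(key)
    parents = jax.random.choice(k_parent, points.shape[0], shape=(num_new_points,), p=probs)
    # Perturb each parent, scaling the noise by the domain width per dimension
    width = bounds[:, 1] - bounds[:, 0]
    noise = noise_scale * width * jax.random.normal(k_noise, (num_new_points, points.shape[1]))
    new_points = jnp.clip(points[parents] + noise, bounds[:, 0], bounds[:, 1])
    return jnp.concatenate([points, new_points], axis=0)


# Toy setup: the residual grows toward x = 1, so new points cluster there
points = jnp.linspace(0.0, 1.0, 50).reshape(-1, 1)
residuals = points[:, 0] ** 2
bounds = jnp.array([[0.0, 1.0]])
refined = rard_refine(points, residuals, bounds, jax.random.key(0), num_new_points=20)
new = refined[50:]
```

Unlike RAD resampling, refinement keeps every existing point and only grows the set, so the number of collocation points increases with each call.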

compute_sampling_distribution

compute_sampling_distribution(residuals: Float[Array, ...], beta: float = 1.0, min_probability: float = 1e-06) -> Float[Array, ...]

Compute sampling distribution from residual magnitudes.

The sampling probability for each point is proportional to the residual magnitude raised to the power beta: p_j = |r_j|^β / Σ_k |r_k|^β

Parameters:

  • residuals (Float[Array, ...], required): PDE residual magnitudes at each collocation point
  • beta (float, default 1.0): Exponent for residual weighting
  • min_probability (float, default 1e-06): Minimum probability to ensure every point has some chance of being sampled

Returns:

  • Float[Array, ...]: Sampling probabilities that sum to 1
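The formula p_j = |r_j|^β / Σ_k |r_k|^β can be sketched directly. The handling of `min_probability` is an assumption here: probabilities are floored and then renormalized so they still sum to 1. `sampling_distribution` is a hypothetical helper, not the library function.

```python
import jax.numpy as jnp


def sampling_distribution(residuals, beta=1.0, min_probability=1e-6):
    """p_j = |r_j|^beta / sum_k |r_k|^beta, floored at min_probability and renormalized."""
    weights = jnp.abs(residuals) ** beta
    probs = weights / jnp.sum(weights)
    probs = jnp.maximum(probs, min_probability)
    return probs / jnp.sum(probs)


linear = sampling_distribution(jnp.array([1.0, 3.0]))               # beta = 1
squared = sampling_distribution(jnp.array([1.0, -3.0]), beta=2.0)   # sign is ignored
probs_with_zero = sampling_distribution(jnp.array([0.0, 1.0]))      # zero residual still gets mass
```

Raising β from 1 to 2 shifts probability toward the larger residual: [0.25, 0.75] becomes [0.1, 0.9].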

For detailed algorithms and best practices, see the Adaptive Sampling Guide.