# Training API Documentation

## Overview

The `opifex.training` and `opifex.core.training` packages provide training infrastructure for scientific machine learning. This includes physics-informed neural networks (PINNs), optimization algorithms, and quantum-aware training workflows.
Module Structure:

- `opifex.core.training.trainer`: unified `Trainer` (recommended)
- `opifex.core.training.config`: training configuration classes
- `opifex.core.training.physics_configs`: physics-specific configurations
- `opifex.training.basic_trainer`: core trainer implementations
- `opifex.training.metrics`: metrics tracking and state management
- `opifex.training.recovery`: error recovery and stability handling
- `opifex.training.components`: modular training components
- `opifex.training.utils`: utility functions for safe model operations
## Core Classes

### Trainer ⭐ RECOMMENDED
The unified, composable trainer architecture for all training workflows.
```python
import flax.nnx as nnx

from opifex.core.training.trainer import Trainer
from opifex.core.training.config import TrainingConfig
from opifex.core.training.physics_configs import ConservationConfig, MultiScaleConfig

# Configure physics-aware training
conservation_config = ConservationConfig(
    laws=["energy", "momentum"],
    energy_tolerance=1e-6,
)
multiscale_config = MultiScaleConfig(
    scales=["molecular", "atomic"],
    weights={"molecular": 0.5, "atomic": 0.5},
)
config = TrainingConfig(
    num_epochs=100,
    learning_rate=1e-3,
    conservation_config=conservation_config,
    multiscale_config=multiscale_config,
)


# Define a small model for demonstration
class SimpleModel(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(10, 1, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)


model = SimpleModel(rngs=nnx.Rngs(0))

# Create the trainer and train; train_data and val_data are your own datasets
trainer = Trainer(model, config)
trained_model, history = trainer.train(train_data, val_data)
```
Key Features:
- Composable Architecture: Mix and match physics configurations
- Type-Safe: Full type hints and IDE support
- Zero Runtime Overhead: Configuration at initialization only
- Extensible: Add custom configs without modifying trainer
- Production-Ready: Full testing and error handling
Supported Physics Configurations:

- `ConservationConfig`: energy, momentum, mass, and symmetry conservation
- `MultiScaleConfig`: multi-scale physics with coupling
- `QuantumTrainingConfig`: quantum chemistry and electronic structure
- `BoundaryConfig`: boundary condition enforcement
- `DFTConfig`: density functional theory workflows
- `SCFConfig`: self-consistent field convergence
- `MetricsTrackingConfig`: custom metrics tracking
- `LoggingConfig`: advanced logging and alerting
### BasicTrainer

Standard training workflow with physics-informed capabilities.

```python
from opifex.training.basic_trainer import BasicTrainer
from opifex.core.training.config import TrainingConfig

# model: an nnx.Module; config: a TrainingConfig; train_data/val_data: your datasets
trainer = BasicTrainer(model, config)
trained_model, history = trainer.train(train_data, val_data)
```
Key Features:
- Physics-informed neural network (PINN) training
- Orbax-compatible checkpointing
- JAX Array and automatic differentiation support
- Type-safe with jaxtyping annotations
### ModularTrainer ✅ NEW

Component-based training architecture with production-grade capabilities.

```python
from opifex.training.basic_trainer import ModularTrainer
from opifex.core.training.config import TrainingConfig
from opifex.core.training.components.recovery import ErrorRecoveryManager
from opifex.core.training.components import FlexibleOptimizerFactory

# model, config, and rngs (an nnx.Rngs) are defined as in the examples above
trainer = ModularTrainer(
    model=model,
    config=config,
    rngs=rngs,
    components={
        "error_recovery": ErrorRecoveryManager(),
        "optimizer_factory": FlexibleOptimizerFactory(),
    },
)
```
Key Features:
- Component composition architecture
- Pluggable training components
- Production-grade error handling
- Optimization strategies
- Physics-aware metrics collection
## Components

### ErrorRecoveryManager ✅ NEW

Production-grade error handling with gradient stability and automatic recovery.

```python
from opifex.training.recovery import ErrorRecoveryManager

error_manager = ErrorRecoveryManager(
    config={
        "max_retries": 5,
        "gradient_clip_threshold": 1.0,
        "loss_explosion_threshold": 100.0,
        "checkpoint_on_error": True,
    }
)
```
Features:
- Gradient clipping with automatic threshold adaptation
- NaN detection and recovery mechanisms
- Loss explosion detection and mitigation
- Multiple recovery strategies (gradient clipping, learning rate reduction, parameter reinitialization)
- Full error logging and analytics
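The clipping and detection behaviour listed above can be illustrated with a small, framework-free sketch. It assumes two things: `gradient_clip_threshold` caps the global L2 norm of the gradients, and `loss_explosion_threshold` is compared against the raw loss value. These semantics are assumptions, as are the helper names `clip_by_global_norm` and `check_step`. The code is not the `ErrorRecoveryManager` implementation.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale a flat list of gradient values so their global L2 norm is at most max_norm.

    Returns the (possibly rescaled) gradients and the norm before clipping.
    """
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm or norm == 0.0:
        return list(grads), norm
    scale = max_norm / norm
    return [g * scale for g in grads], norm


def check_step(loss, grads, loss_explosion_threshold=100.0):
    """Classify a training step as 'non_finite', 'explosion', or 'ok' (hypothetical helper)."""
    if not math.isfinite(loss) or not all(math.isfinite(g) for g in grads):
        return "non_finite"  # NaN or inf: a recovery strategy should kick in
    if loss > loss_explosion_threshold:
        return "explosion"  # e.g. reduce the learning rate or roll back to a checkpoint
    return "ok"


clipped, norm = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(norm)  # 5.0
print(check_step(float("nan"), [0.1]), check_step(1e3, [0.1]), check_step(0.5, [0.1]))
```

In JAX code, the same global-norm clipping is available as the Optax gradient transformation `optax.clip_by_global_norm`.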
### FlexibleOptimizerFactory ✅ NEW

Optimizer creation with learning rate scheduling support.

```python
from opifex.core.training.components import FlexibleOptimizerFactory

optimizer_factory = FlexibleOptimizerFactory(
    config={
        "optimizer_type": "adamw",  # "adam", "adamw", "sgd"
        "learning_rate": 1e-3,
        "schedule_type": "cosine",  # "cosine", "exponential", "linear"
        "total_steps": 1000,
        "cosine_alpha": 0.0,
    }
)
```
Supported Optimizers:
- Adam: Adaptive moment estimation
- AdamW: Adam with weight decay
- SGD: Stochastic gradient descent with momentum
Supported Schedules:
- Cosine: Cosine annealing learning rate decay
- Exponential: Exponential decay
- Linear: Linear decay
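For reference, the three schedule types are commonly defined as shown below. This sketch assumes `FlexibleOptimizerFactory` follows the usual conventions. For cosine decay, that means the form used by Optax's `cosine_decay_schedule`, where `cosine_alpha` is the final learning rate as a fraction of the initial one. The `decay_rate` and `end_lr` parameters are hypothetical, and the factory's exact formulas may differ.

```python
import math


def cosine_schedule(step, init_lr, total_steps, alpha=0.0):
    """Cosine annealing from init_lr down to alpha * init_lr over total_steps."""
    progress = min(step, total_steps) / total_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return init_lr * ((1.0 - alpha) * cosine + alpha)


def exponential_schedule(step, init_lr, total_steps, decay_rate=0.1):
    """Exponential decay: the rate is multiplied by decay_rate every total_steps steps."""
    return init_lr * decay_rate ** (step / total_steps)


def linear_schedule(step, init_lr, total_steps, end_lr=0.0):
    """Linear interpolation from init_lr to end_lr, then held constant."""
    progress = min(step, total_steps) / total_steps
    return init_lr + (end_lr - init_lr) * progress


# Learning rate at the start, midpoint, and end of a 1000-step run
for step in (0, 500, 1000):
    print(step, cosine_schedule(step, 1e-3, 1000), linear_schedule(step, 1e-3, 1000))
```

Cosine annealing keeps the rate high early and decays it smoothly near the end. With `cosine_alpha=0.0`, the rate reaches exactly zero at `total_steps`.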
### MetricsCollector ✅ NEW

Physics-aware metrics collection with convergence tracking.

```python
from opifex.training.metrics import AdvancedMetricsCollector

collector = AdvancedMetricsCollector()
collector.start_training()
# model: the network being trained; x: inputs; y_true: reference targets
metrics = collector.collect_physics_metrics(model, x, y_true)
```
Collected Metrics:
- Training loss and validation metrics
- Gradient norms and stability indicators
- Physics-specific metrics (energy conservation, mass conservation)
- Convergence rates and training diagnostics
- Real-time performance analytics
### TrainingComponent ✅ NEW

Base class for creating custom training components.

```python
from opifex.core.training.components import TrainingComponent


class CustomComponent(TrainingComponent):
    def setup(self, model, training_state):
        # Initialize component
        pass

    def step(self, model, training_state):
        # Update component state
        pass
```
Purpose:
- Enables modular component development
- Provides common interface for training components
- Supports pluggable architecture patterns
## Configuration Classes

### TrainingConfig

Training configuration with full parameter control.

```python
from opifex.core.training.config import TrainingConfig

config = TrainingConfig(
    num_epochs=1000,
    batch_size=64,
    learning_rate=1e-3,
    validation_frequency=100,
    checkpoint_frequency=500,
    early_stopping=True,
    patience=50,
)
```
Parameters:

- `num_epochs`: number of training epochs
- `batch_size`: training batch size
- `learning_rate`: learning rate (can be overridden by the optimizer factory)
- `validation_frequency`: how often validation is evaluated
- `checkpoint_frequency`: how often the model is checkpointed
- `early_stopping`: enable early stopping
- `patience`: early stopping patience
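The interaction between `early_stopping` and `patience` can be sketched as a plain loop. The sketch assumes that `patience` counts validation checks without improvement; in `TrainingConfig` it may instead count epochs. The `EarlyStopping` class is a hypothetical helper, not the trainer's implementation.

```python
class EarlyStopping:
    """Track the best validation loss; signal a stop after `patience` checks without improvement."""

    def __init__(self, patience, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta  # minimum decrease that counts as an improvement
        self.best = float("inf")
        self.bad_checks = 0

    def update(self, val_loss):
        """Record one validation result; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


# Simulated validation losses, one per validation check
stopper = EarlyStopping(patience=2)
for check, loss in enumerate([1.0, 0.8, 0.81, 0.79, 0.85, 0.9]):
    if stopper.update(loss):
        print(f"stopping at check {check}, best loss {stopper.best}")
        # prints: stopping at check 5, best loss 0.79
        break
```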
### TrainingState

Enhanced training state with full tracking.

```python
from opifex.training.metrics import TrainingState

# Automatically managed by trainers; `trainer` is an existing trainer instance
state = trainer.training_state
print(f"Current epoch: {state.epoch}")
print(f"Best validation loss: {state.best_val_loss}")
```
Tracked Information:
- Current epoch and step counters
- Best validation metrics
- Model and optimizer states
- Recovery attempt history
- Training diagnostics
## Physics-Informed Training

### PhysicsInformedLoss

Hierarchical multi-physics loss composition with adaptive weighting.

```python
from opifex.core.physics.losses import PhysicsInformedLoss, PhysicsLossConfig

physics_loss = PhysicsInformedLoss(
    config=PhysicsLossConfig(
        physics_weight=1.0,
        boundary_weight=1.0,
        data_weight=1.0,
        adaptive_weighting=True,
    )
)

# Use with an existing BasicTrainer instance
trainer.set_physics_loss(physics_loss)
```
Supported Physics:
- Partial differential equations (PDEs)
- Conservation laws (mass, momentum, energy)
- Quantum mechanical constraints
- Boundary condition enforcement
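With fixed weights, the composite objective configured by `PhysicsLossConfig` reduces to a weighted sum of physics, boundary, and data terms. The sketch below shows that static case on plain lists of residuals. Using the mean-squared residual for each term is an assumption. The `adaptive_weighting=True` scheme is not shown, because its update rule is not documented here.

```python
def mean_squared(values):
    """Mean of squared residuals; 0.0 for an empty term."""
    return sum(v * v for v in values) / len(values) if values else 0.0


def composite_loss(pde_residuals, boundary_residuals, data_residuals,
                   physics_weight=1.0, boundary_weight=1.0, data_weight=1.0):
    """Weighted sum of physics, boundary, and data terms (static weights only)."""
    terms = {
        "physics": physics_weight * mean_squared(pde_residuals),
        "boundary": boundary_weight * mean_squared(boundary_residuals),
        "data": data_weight * mean_squared(data_residuals),
    }
    return sum(terms.values()), terms


# A pure PINN setting: no labelled data, boundary term up-weighted
total, terms = composite_loss([0.1, -0.2], [0.05], [], boundary_weight=10.0)
print(total, terms)
```

Returning the individual terms alongside the total makes it easy to monitor whether one term dominates. That check is the starting point for rebalancing the weights.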
## Usage Examples

### Basic Training Workflow

```python
import jax
import jax.numpy as jnp
import flax.nnx as nnx

from opifex.neural.base import StandardMLP
from opifex.training.basic_trainer import BasicTrainer
from opifex.core.training.config import TrainingConfig

# Create model
model = StandardMLP([1, 32, 32, 1], activation="tanh", rngs=nnx.Rngs(42))

# Configure training
config = TrainingConfig(num_epochs=1000, batch_size=64, learning_rate=1e-3)

# Create trainer and train; train_data and val_data are your own datasets
trainer = BasicTrainer(model, config)
trained_model, history = trainer.train(train_data, val_data)
```
### Modular Training

```python
from opifex.training.basic_trainer import ModularTrainer
from opifex.core.training.config import TrainingConfig
from opifex.core.training.components.recovery import ErrorRecoveryManager
from opifex.core.training.components import FlexibleOptimizerFactory

# Configure components
error_recovery = ErrorRecoveryManager(config={"max_retries": 5, "gradient_clip_threshold": 1.0})
optimizer_factory = FlexibleOptimizerFactory(config={"optimizer_type": "adamw", "schedule_type": "cosine"})

# Create modular trainer (model, config, and rngs as in the examples above)
# Note: AdvancedMetricsCollector is automatically created by ModularTrainer
trainer = ModularTrainer(
    model=model,
    config=config,
    rngs=rngs,
    components={
        "error_recovery": error_recovery,
        "optimizer_factory": optimizer_factory,
    },
)

# Train with error recovery and scheduling enabled
trained_model, history = trainer.train(train_data, val_data)
```
### Physics-Informed Training

```python
import jax
import jax.numpy as jnp

from opifex.training.basic_trainer import BasicTrainer
from opifex.core.physics.losses import PhysicsInformedLoss


# Define the PDE residual for the 1D heat equation u_t = 0.1 * u_xx
def pde_residual(model_fn, x, t):
    # jax.grad needs a scalar output: the model maps a (1, 2) input to a (1, 1) output
    def u(x, t):
        return model_fn(jnp.array([x, t]).reshape(1, -1))[0, 0]

    u_t = jax.grad(u, argnums=1)(x, t)
    u_xx = jax.grad(jax.grad(u, argnums=0), argnums=0)(x, t)
    return u_t - 0.1 * u_xx


# Configure physics loss
physics_loss = PhysicsInformedLoss(pde_residual=pde_residual)

# Set up PINN training (model and config defined as above)
trainer = BasicTrainer(model, config)
trainer.set_physics_loss(physics_loss)

# Train with physics constraints; domain_points, boundary_points, and
# boundary_values are your own collocation and boundary data
trained_model, history = trainer.train(
    train_data=(domain_points, None),  # No target data for domain points
    boundary_data=(boundary_points, boundary_values),
)
```
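A quick way to sanity-check a residual function like `pde_residual` is to evaluate it on a known exact solution. For the heat equation u_t = 0.1 u_xx, the function u(x, t) = exp(-0.1 π² t) sin(π x) is an exact solution, so its residual should be close to zero. The sketch below uses central finite differences in place of `jax.grad`, purely so that it runs without JAX or opifex. `exact_u` and `heat_residual_fd` are illustrative names.

```python
import math

ALPHA = 0.1  # diffusivity, matching u_t = 0.1 * u_xx above


def exact_u(x, t):
    """Exact solution of u_t = ALPHA * u_xx on x in [0, 1] with zero Dirichlet boundaries."""
    return math.exp(-ALPHA * math.pi ** 2 * t) * math.sin(math.pi * x)


def heat_residual_fd(u, x, t, h=1e-4):
    """Heat-equation residual u_t - ALPHA * u_xx using central finite differences."""
    u_t = (u(x, t + h) - u(x, t - h)) / (2 * h)
    u_xx = (u(x + h, t) - 2 * u(x, t) + u(x - h, t)) / h ** 2
    return u_t - ALPHA * u_xx


for x, t in [(0.25, 0.1), (0.5, 0.5), (0.8, 1.0)]:
    print(f"x={x}, t={t}: residual={heat_residual_fd(exact_u, x, t):.2e}")
```

Applying the same check to the trained model's predictions gives a cheap, independent estimate of how well the PDE constraint is being satisfied.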
## Integration

### With Neural Networks

```python
import flax.nnx as nnx

from opifex.neural.base import StandardMLP
from opifex.neural.quantum import QuantumMLP

rngs = nnx.Rngs(0)

# Standard networks
standard_model = StandardMLP([3, 64, 64, 1], activation="swish", rngs=rngs)

# Quantum networks
quantum_model = QuantumMLP(features=[128, 128, 1], n_atoms=3, rngs=rngs)
```
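### With Optimization

```python
from opifex.optimization import MetaOptimizer
from opifex.training.basic_trainer import BasicTrainer

# Use with learn-to-optimize (model and config defined as above)
meta_optimizer = MetaOptimizer()
trainer = BasicTrainer(model, config, meta_optimizer=meta_optimizer)
```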
### With Geometry

```python
from opifex.geometry import ComplexDomain
from opifex.core.conditions import DirichletBC

# Complex domain training
domain = ComplexDomain(boundaries=["left", "right", "top", "bottom"])
boundary_conditions = [DirichletBC(boundary="left", value=0.0)]
```
## Best Practices

### Production Training
- Use ModularTrainer for production workflows with error recovery
- Configure appropriate error recovery strategies for your problem
- Monitor training metrics with MetricsCollector
- Use learning rate scheduling for better convergence
- Enable checkpointing for long training runs
### Physics-Informed Training
- Balance loss weights between physics, boundary, and data terms
- Use adaptive weighting for complex multi-physics problems
- Monitor conservation laws during training
- Validate physics constraints on test data
### Performance Optimization
- Choose appropriate batch sizes for your hardware
- Use JAX transformations (vmap, jit) for efficiency
- Profile training with JAX profiling tools
- Monitor gradient health and stability
## Troubleshooting

### Common Issues
- NaN losses: Enable NaN detection in ErrorRecoveryManager
- Gradient explosions: Use gradient clipping with appropriate thresholds
- Slow convergence: Try different optimizers and learning rate schedules
- Physics constraint violations: Increase physics loss weights or improve residual computation
### Debug Features
- Full logging of training metrics and errors
- Recovery attempt tracking for debugging stability issues
- Gradient norm monitoring for optimization health
- Physics constraint validation for PINN problems
## Multilevel Training

Coarse-to-fine training hierarchies for accelerated convergence.

### Width-Based Hierarchy (MLPs)

#### `opifex.training.multilevel.coarse_to_fine`

##### `CascadeTrainer`

```python
CascadeTrainer(
    input_dim: int,
    output_dim: int,
    base_hidden_dims: Sequence[int],
    config: MultilevelConfig | None = None,
    *,
    activation: Callable[[Array], Array] = tanh,
    rngs: Rngs,
)
```
Cascade trainer for multilevel training.
Trains models from coarse to fine levels, transferring learned parameters between levels.
Attributes:

| Name | Description |
|---|---|
| `config` | Multilevel configuration |
| `hierarchy` | List of models from coarse to fine |
| `current_level` | Current training level |
Example:

```python
import flax.nnx as nnx

from opifex.training.multilevel.coarse_to_fine import CascadeTrainer, MultilevelConfig

trainer = CascadeTrainer(
    input_dim=1,
    output_dim=1,
    base_hidden_dims=[64, 64],
    config=MultilevelConfig(num_levels=3),
    rngs=nnx.Rngs(0),
)
model = trainer.get_current_model()
# Train model...
trainer.advance_level()
finer_model = trainer.get_current_model()
```
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `input_dim` | `int` | Input dimension | required |
| `output_dim` | `int` | Output dimension | required |
| `base_hidden_dims` | `Sequence[int]` | Hidden dimensions for finest level | required |
| `config` | `MultilevelConfig \| None` | Multilevel configuration | `None` |
| `activation` | `Callable[[Array], Array]` | Activation function | `tanh` |
| `rngs` | `Rngs` | Random number generators | required |
###### `get_current_model`
Get model at current level.
Returns:

| Type | Description |
|---|---|
| `MultilevelMLP` | Current level model |
###### `advance_level`

```python
advance_level() -> bool
```
Advance to next finer level.
Transfers learned parameters from current level to next level.
Returns:

| Type | Description |
|---|---|
| `bool` | `True` if advanced successfully, `False` if already at finest level |
##### `MultilevelConfig` (dataclass)

```python
MultilevelConfig(
    num_levels: int = 3,
    coarsening_factor: float = 0.5,
    level_epochs: list[int] = field(default_factory=lambda: [100, 200, 300]),
    warmup_epochs: int = 0,
)
```
Configuration for multilevel training.
Attributes:

| Name | Type | Description |
|---|---|---|
| `num_levels` | `int` | Number of levels in the hierarchy |
| `coarsening_factor` | `float` | Factor to reduce width at each coarser level |
| `level_epochs` | `list[int]` | Number of epochs to train at each level |
| `warmup_epochs` | `int` | Extra epochs at the finest level |
##### `create_network_hierarchy`

```python
create_network_hierarchy(
    input_dim: int,
    output_dim: int,
    base_hidden_dims: Sequence[int],
    num_levels: int,
    coarsening_factor: float = 0.5,
    *,
    activation: Callable[[Array], Array] = tanh,
    rngs: Rngs,
) -> list[MultilevelMLP]
```
Create hierarchy of networks from coarse to fine.
The finest level (highest index) uses the base_hidden_dims. Coarser levels use progressively smaller networks.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `input_dim` | `int` | Input dimension | required |
| `output_dim` | `int` | Output dimension | required |
| `base_hidden_dims` | `Sequence[int]` | Hidden dimensions for the finest level | required |
| `num_levels` | `int` | Number of levels in hierarchy | required |
| `coarsening_factor` | `float` | Factor to reduce width at each coarser level | `0.5` |
| `activation` | `Callable[[Array], Array]` | Activation function | `tanh` |
| `rngs` | `Rngs` | Random number generators | required |
Returns:

| Type | Description |
|---|---|
| `list[MultilevelMLP]` | List of networks from coarsest to finest |
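The documentation states only that coarser levels use progressively smaller networks. One natural reading is that level `i` (0 = coarsest) scales each entry of `base_hidden_dims` by `coarsening_factor ** (num_levels - 1 - i)`. The sketch below computes hidden widths under that assumption. Rounding to the nearest integer and enforcing a minimum width of 1 are further assumptions. `hierarchy_widths` is a hypothetical helper, and the real `create_network_hierarchy` may compute widths differently.

```python
def hierarchy_widths(base_hidden_dims, num_levels, coarsening_factor=0.5):
    """Hidden widths per level, ordered coarsest to finest (finest = base_hidden_dims)."""
    levels = []
    for level in range(num_levels):
        # Each step away from the finest level multiplies widths by coarsening_factor
        scale = coarsening_factor ** (num_levels - 1 - level)
        levels.append([max(1, round(width * scale)) for width in base_hidden_dims])
    return levels


print(hierarchy_widths([64, 64], num_levels=3))  # [[16, 16], [32, 32], [64, 64]]
```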
##### `prolongate`
Transfer (prolongate) parameters from coarse to fine model.
This copies the coarse model parameters to the corresponding subset of the fine model parameters. Additional fine model parameters are left at their initialized values.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `coarse_model` | `MultilevelMLP` | Coarse level model | required |
| `fine_model` | `MultilevelMLP` | Fine level model (will be modified in place) | required |
Returns:

| Type | Description |
|---|---|
| `MultilevelMLP` | Fine model with prolongated parameters |
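This sketch assumes the "corresponding subset" is the leading block of each weight matrix, that is, its first rows and columns. Under that assumption, prolongation for a single dense layer copies the coarse matrix into the top-left corner of the fine one. Nested Python lists stand in for `MultilevelMLP` parameters, and `prolongate_matrix` is a hypothetical helper. For clarity it returns a new matrix rather than modifying the fine one in place.

```python
def prolongate_matrix(coarse, fine):
    """Return a copy of `fine` whose top-left block is overwritten by `coarse`.

    Entries of `fine` outside that block keep their initialized values.
    """
    result = [row[:] for row in fine]
    for i, coarse_row in enumerate(coarse):
        result[i][: len(coarse_row)] = coarse_row
    return result


coarse_w = [[1.0, 2.0],
            [3.0, 4.0]]
fine_w = [[0.1, 0.1, 0.1],
          [0.1, 0.1, 0.1],
          [0.1, 0.1, 0.1]]
print(prolongate_matrix(coarse_w, fine_w))
# [[1.0, 2.0, 0.1], [3.0, 4.0, 0.1], [0.1, 0.1, 0.1]]
```

Because the extra rows and columns keep their initial values, the fine network starts from the coarse solution plus a small perturbation instead of from scratch.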
##### `restrict`
Transfer (restrict) parameters from fine to coarse model.
This copies a subset of the fine model parameters to the coarse model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `fine_model` | `MultilevelMLP` | Fine level model | required |
| `coarse_model` | `MultilevelMLP` | Coarse level model (will be modified in place) | required |
Returns:

| Type | Description |
|---|---|
| `MultilevelMLP` | Coarse model with restricted parameters |
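Restriction can be sketched under the same leading-block assumption as prolongation: the coarse weights are the first rows and columns of the fine matrix. Applying restriction after prolongation therefore recovers the coarse weights exactly. `restrict_matrix` is a hypothetical helper operating on nested lists, not the library function.

```python
def restrict_matrix(fine, coarse_rows, coarse_cols):
    """Return the leading coarse_rows x coarse_cols block of `fine` as a new matrix."""
    return [row[:coarse_cols] for row in fine[:coarse_rows]]


fine_w = [[1.0, 2.0, 0.5],
          [3.0, 4.0, 0.5],
          [0.5, 0.5, 0.5]]
print(restrict_matrix(fine_w, 2, 2))  # [[1.0, 2.0], [3.0, 4.0]]
```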
### Mode-Based Hierarchy (FNOs)

#### `opifex.training.multilevel.multilevel_fno`

##### `MultilevelFNOTrainer`

```python
MultilevelFNOTrainer(
    width: int,
    input_dim: int,
    output_dim: int,
    config: MultilevelFNOConfig | None = None,
    *,
    num_layers: int = 4,
    activation: Callable[[Array], Array] = gelu,
    rngs: Rngs,
)
```
Trainer for multilevel FNO.
Trains FNOs from coarse to fine modes, transferring learned spectral weights between levels.
Attributes:

| Name | Description |
|---|---|
| `config` | Multilevel configuration |
| `hierarchy` | List of FNOs from coarse to fine |
| `current_level` | Current training level |
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `width` | `int` | Hidden channel width | required |
| `input_dim` | `int` | Input channels | required |
| `output_dim` | `int` | Output channels | required |
| `config` | `MultilevelFNOConfig \| None` | Multilevel configuration | `None` |
| `num_layers` | `int` | FNO layers per network | `4` |
| `activation` | `Callable[[Array], Array]` | Activation function | `gelu` |
| `rngs` | `Rngs` | Random number generators | required |
###### `get_current_model`
Get FNO at current level.
Returns:

| Type | Description |
|---|---|
| `SimpleFNO` | Current level FNO |
###### `advance_level`

```python
advance_level() -> bool
```
Advance to next finer level.
Transfers learned weights from current level to next level.
Returns:

| Type | Description |
|---|---|
| `bool` | `True` if advanced successfully, `False` if at finest level |
##### `MultilevelFNOConfig` (dataclass)

```python
MultilevelFNOConfig(
    num_levels: int = 3,
    base_modes: int = 12,
    mode_reduction_factor: int = 2,
    level_epochs: list[int] = field(default_factory=lambda: [50, 100, 150]),
)
```
Configuration for multilevel FNO training.
Attributes:

| Name | Type | Description |
|---|---|---|
| `num_levels` | `int` | Number of levels in the hierarchy |
| `base_modes` | `int` | Number of Fourier modes at finest level |
| `mode_reduction_factor` | `int` | Factor to reduce modes at each coarser level |
| `level_epochs` | `list[int]` | Epochs to train at each level |
##### `create_fno_hierarchy`

```python
create_fno_hierarchy(
    base_modes: int,
    width: int,
    input_dim: int,
    output_dim: int,
    num_levels: int,
    reduction_factor: int = 2,
    *,
    num_layers: int = 4,
    activation: Callable[[Array], Array] = gelu,
    rngs: Rngs,
) -> list[SimpleFNO]
```
Create hierarchy of FNOs from coarse to fine.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `base_modes` | `int` | Modes at finest level | required |
| `width` | `int` | Hidden channel width (same for all levels) | required |
| `input_dim` | `int` | Input channels | required |
| `output_dim` | `int` | Output channels | required |
| `num_levels` | `int` | Number of levels | required |
| `reduction_factor` | `int` | Mode reduction per level | `2` |
| `num_layers` | `int` | FNO layers per network | `4` |
| `activation` | `Callable[[Array], Array]` | Activation function | `gelu` |
| `rngs` | `Rngs` | Random number generators | required |
Returns:

| Type | Description |
|---|---|
| `list[SimpleFNO]` | List of FNOs from coarsest to finest |
##### `create_mode_hierarchy`
Create hierarchy of mode counts from coarse to fine.
The finest level (highest index) uses base_modes. Coarser levels use progressively fewer modes.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `base_modes` | `int` | Number of modes at finest level | required |
| `num_levels` | `int` | Number of levels in hierarchy | required |
| `reduction_factor` | `int` | Factor to reduce modes at each coarser level | `2` |
Returns:

| Type | Description |
|---|---|
| `list[int]` | List of mode counts from coarsest to finest |
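A plausible reading of `create_mode_hierarchy` is that level `i` (0 = coarsest) keeps `base_modes // reduction_factor ** (num_levels - 1 - i)` modes. The sketch below implements that reading. Integer division and a floor of one mode are assumptions, not documented behaviour. With the `MultilevelFNOConfig` defaults it yields 3, 6, and 12 modes.

```python
def mode_hierarchy(base_modes, num_levels, reduction_factor=2):
    """Fourier mode counts per level, ordered coarsest to finest (finest = base_modes)."""
    return [
        max(1, base_modes // reduction_factor ** (num_levels - 1 - level))
        for level in range(num_levels)
    ]


print(mode_hierarchy(12, num_levels=3))  # [3, 6, 12]
```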
##### `prolongate_fno_modes`
Transfer spectral weights from coarse to fine FNO.
Copies the lower-frequency modes from the coarse network to the corresponding modes in the fine network.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `coarse_fno` | `SimpleFNO` | Coarse FNO (fewer modes) | required |
| `fine_fno` | `SimpleFNO` | Fine FNO (more modes, modified in place) | required |
Returns:

| Type | Description |
|---|---|
| `SimpleFNO` | Fine FNO with prolongated weights |
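The following sketch covers a 1D spectral convolution whose weights are stored per retained mode. Prolongation copies the coarse network's modes into the first (lowest-frequency) slots of the fine network and leaves the higher modes at their initial values. One weight tensor is represented as a list indexed by mode that holds complex numbers. This layout and the helper name `prolongate_modes` are assumptions for illustration, not the `SimpleFNO` parameter structure. Multi-dimensional FNOs keep low modes in more than one corner of the spectrum, which this sketch does not cover.

```python
def prolongate_modes(coarse_weights, fine_weights):
    """Copy coarse per-mode weights into the lowest-frequency slots of fine_weights.

    Both arguments are lists indexed by Fourier mode (mode 0 first).
    Returns a new list; modes beyond len(coarse_weights) keep the fine values.
    """
    if len(coarse_weights) > len(fine_weights):
        raise ValueError("coarse network must not have more modes than the fine one")
    return list(coarse_weights) + list(fine_weights[len(coarse_weights):])


coarse = [1 + 1j, 2 - 1j, 0.5j]  # 3 retained modes
fine = [0j] * 6                  # 6 modes, freshly initialized
print(prolongate_modes(coarse, fine))
```

Under the same layout, restriction is the matching truncation `fine_weights[:num_coarse_modes]`.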
##### `restrict_fno_modes`
Transfer spectral weights from fine to coarse FNO.
Copies the lower-frequency modes from the fine network to the coarse network (truncation).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `fine_fno` | `SimpleFNO` | Fine FNO (more modes) | required |
| `coarse_fno` | `SimpleFNO` | Coarse FNO (fewer modes, modified in place) | required |
Returns:

| Type | Description |
|---|---|
| `SimpleFNO` | Coarse FNO with restricted weights |
For usage examples and best practices, see the Multilevel Training Guide.
## Adaptive Sampling

Residual-based sampling strategies for efficient PINN training.

### `opifex.training.adaptive_sampling`

#### `RADSampler`

```python
RADSampler(config: RADConfig | None = None)
```
Residual-based Adaptive Distribution sampler.
This sampler draws collocation points from the domain with probability proportional to the PDE residual magnitude.
Attributes:

| Name | Description |
|---|---|
| `config` | RAD configuration |
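Example:

```python
import jax
import jax.numpy as jnp

from opifex.training.adaptive_sampling import RADSampler

sampler = RADSampler()
domain_points = jnp.linspace(0, 1, 100).reshape(-1, 1)
# compute_pde_residual stands for your own function that evaluates the PDE residual
residuals = compute_pde_residual(model, domain_points)
key = jax.random.key(0)
sampled = sampler.sample(domain_points, residuals, batch_size=32, key=key)
```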
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `RADConfig \| None` | RAD configuration. Uses defaults if `None`. | `None` |
##### `sample`

```python
sample(
    domain_points: Float[Array, ...],
    residuals: Float[Array, ...],
    batch_size: int,
    key: PRNGKeyArray,
) -> Float[Array, ...]
```
Sample collocation points based on residual distribution.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `domain_points` | `Float[Array, ...]` | Candidate points in the domain | required |
| `residuals` | `Float[Array, ...]` | PDE residual magnitudes at each point | required |
| `batch_size` | `int` | Number of points to sample | required |
| `key` | `PRNGKeyArray` | JAX random key | required |
Returns:

| Type | Description |
|---|---|
| `Float[Array, ...]` | Sampled collocation points |
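The sampling step can be sketched with Python's standard library. First convert residuals to probabilities using the formula documented for `RADConfig.beta`, ξ_j = |r_j|^β / Σ_k |r_k|^β. Then draw `batch_size` points. Several details are not documented here: whether `RADSampler.sample` draws with or without replacement, and how `min_probability` and `temperature` enter. This sketch samples with replacement via `random.choices`, floors each probability at `min_probability`, and ignores `temperature`. `rad_sample` is a hypothetical helper.

```python
import random


def rad_sample(domain_points, residuals, batch_size, beta=1.0, min_probability=1e-6, seed=0):
    """Draw batch_size points with probability proportional to |residual| ** beta (with replacement)."""
    weights = [abs(r) ** beta for r in residuals]
    total = sum(weights)
    if total == 0.0:
        probs = [1.0 / len(weights)] * len(weights)  # no residual signal: sample uniformly
    else:
        probs = [w / total for w in weights]
    probs = [max(p, min_probability) for p in probs]  # keep every point reachable
    rng = random.Random(seed)
    return rng.choices(domain_points, weights=probs, k=batch_size)


# 100 points on [0, 1]; the residual is large only for x > 0.9
points = [i / 99 for i in range(100)]
residuals = [10.0 if p > 0.9 else 0.01 for p in points]
batch = rad_sample(points, residuals, batch_size=32)
print(sum(p > 0.9 for p in batch), "of 32 samples fall in the high-residual region")
```

On JAX arrays, the equivalent draw is `jax.random.choice(key, n, shape=(batch_size,), p=probs)`, which selects indices into the candidate points.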
##### `compute_weights`
Compute importance weights for residual-weighted loss.
These weights can be used to weight the loss function instead of resampling the collocation points.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `residuals` | `Float[Array, ...]` | PDE residual magnitudes | required |
Returns:

| Type | Description |
|---|---|
| `Float[Array, ...]` | Importance weights |
#### `RADConfig` (dataclass)

```python
RADConfig(
    beta: float = 1.0,
    resample_frequency: int = 100,
    min_probability: float = 1e-06,
    temperature: float = 1.0,
)
```
Configuration for Residual-based Adaptive Distribution sampling.
Attributes:

| Name | Type | Description |
|---|---|---|
| `beta` | `float` | Exponent for residual weighting. Higher values concentrate sampling more strongly on high-residual regions: ξ_j = \|r_j\|^β / Σ_k \|r_k\|^β |
| `resample_frequency` | `int` | Number of training steps between resampling |
| `min_probability` | `float` | Minimum sampling probability to ensure coverage |
| `temperature` | `float` | Temperature for probability smoothing |
#### `RARDRefiner`

```python
RARDRefiner(
    config: RARDConfig | None = None,
    num_new_points: int | None = None,
    noise_scale: float | None = None,
)
```
Residual-based Adaptive Refinement with Distribution.
This refiner adds new collocation points near regions with high PDE residual, adaptively increasing resolution where needed.
Attributes:

| Name | Description |
|---|---|
| `config` | RAR-D configuration |
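Example:

```python
import jax

from opifex.training.adaptive_sampling import RARDRefiner

refiner = RARDRefiner(num_new_points=20)
# points and residuals come from your training loop; bounds has shape (dim, 2)
key = jax.random.key(0)
refined_points = refiner.refine(points, residuals, bounds, key)
```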
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `RARDConfig \| None` | RAR-D configuration. Uses defaults if `None`. | `None` |
| `num_new_points` | `int \| None` | Override for number of new points | `None` |
| `noise_scale` | `float \| None` | Override for noise scale | `None` |
##### `refine`

```python
refine(
    current_points: Float[Array, ...],
    residuals: Float[Array, ...],
    bounds: Float[Array, 'dim 2'],
    key: PRNGKeyArray,
) -> Float[Array, ...]
```
Add new points near high-residual regions.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `current_points` | `Float[Array, ...]` | Current collocation points | required |
| `residuals` | `Float[Array, ...]` | PDE residual magnitudes at each point | required |
| `bounds` | `Float[Array, 'dim 2']` | Domain bounds, shape `(dim, 2)` with `[min, max]` | required |
| `key` | `PRNGKeyArray` | JAX random key | required |
Returns:

| Type | Description |
|---|---|
| `Float[Array, ...]` | Refined point set including new points |
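Here is a minimal sketch of RAR-D-style refinement on a 1D domain. It picks the highest-residual points, adds Gaussian-perturbed copies near them, and clips the new points to the domain bounds. Several choices are assumptions: using the top-k points as refinement centres, the Gaussian noise model, and the helper name `rard_refine`. The library may instead sample centres from the residual distribution. `num_new_points` and `noise_scale` mirror the `RARDRefiner` overrides, but their exact use in opifex may differ.

```python
import random


def rard_refine(points, residuals, bounds, num_new_points=20, noise_scale=0.01, seed=0):
    """Return points plus num_new_points new 1D points placed near the highest residuals.

    bounds is a (low, high) pair; new points are clipped into [low, high].
    """
    low, high = bounds
    rng = random.Random(seed)
    # Indices of points ordered by descending residual magnitude
    order = sorted(range(len(points)), key=lambda i: abs(residuals[i]), reverse=True)
    centres = [points[i] for i in order[:num_new_points]]
    new_points = []
    for k in range(num_new_points):
        centre = centres[k % len(centres)]  # cycle if there are fewer centres than new points
        candidate = centre + rng.gauss(0.0, noise_scale)
        new_points.append(min(max(candidate, low), high))
    return points + new_points


points = [i / 10 for i in range(11)]
residuals = [0.0] * 9 + [4.0, 5.0]  # large residuals near x = 0.9 and x = 1.0
refined = rard_refine(points, residuals, bounds=(0.0, 1.0), num_new_points=2)
print(len(refined))  # 13
```

Unlike pure resampling, refinement grows the point set, so resolution accumulates where the residual stays high across refinement rounds.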
##### `identify_refinement_regions`
Identify which points are in refinement regions.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `residuals` | `Float[Array, ...]` | PDE residual magnitudes | required |
Returns:

| Type | Description |
|---|---|
| `Float[Array, ...]` | Boolean mask indicating refinement regions |
#### `RARDConfig` (dataclass)

#### `compute_sampling_distribution`

```python
compute_sampling_distribution(
    residuals: Float[Array, ...],
    beta: float = 1.0,
    min_probability: float = 1e-06,
) -> Float[Array, ...]
```
Compute sampling distribution from residual magnitudes.
The sampling probability for each point is proportional to the residual magnitude raised to the power beta: p_j = |r_j|^β / Σ_k |r_k|^β
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `residuals` | `Float[Array, ...]` | PDE residual magnitudes at each collocation point | required |
| `beta` | `float` | Exponent for residual weighting | `1.0` |
| `min_probability` | `float` | Minimum probability to ensure all points have some chance of being sampled | `1e-06` |
Returns:

| Type | Description |
|---|---|
| `Float[Array, ...]` | Sampling probabilities that sum to 1 |
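The formula above translates directly into code. The sketch applies it to a flat list of residuals. How `min_probability` is applied is not specified; this version floors each probability and then renormalizes so the result still sums to 1, which is an assumption. `sampling_distribution` is a hypothetical stand-in for `compute_sampling_distribution`.

```python
def sampling_distribution(residuals, beta=1.0, min_probability=1e-6):
    """p_j = |r_j|**beta / sum_k |r_k|**beta, floored at min_probability and renormalized."""
    powered = [abs(r) ** beta for r in residuals]
    total = sum(powered)
    if total == 0.0:
        return [1.0 / len(residuals)] * len(residuals)  # no residual signal: uniform
    probs = [max(p / total, min_probability) for p in powered]
    norm = sum(probs)
    return [p / norm for p in probs]


print(sampling_distribution([1.0, 3.0]))            # [0.25, 0.75]
print(sampling_distribution([1.0, 3.0], beta=2.0))  # [0.1, 0.9]
```

The two calls show the effect of `beta`. Raising it from 1 to 2 shifts probability mass from 75% to 90% on the point with the larger residual.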
For detailed algorithms and best practices, see the Adaptive Sampling Guide.