Training Infrastructure Guide¶
Overview¶
The Opifex training framework provides full, production-ready training infrastructure for scientific machine learning models. Built on JAX and FLAX NNX, it supports physics-informed neural networks (PINNs), neural operators, quantum neural networks, and traditional supervised learning with advanced optimization algorithms and physics-aware loss functions.
The training system is designed with modularity and extensibility in mind, featuring component-based architecture, advanced error recovery, and sophisticated metrics collection for scientific computing applications.
Core Training Components¶
Unified Trainer Architecture ⭐ RECOMMENDED¶
from opifex.core.training.trainer import Trainer
from opifex.core.training.config import TrainingConfig
from opifex.core.training.config import QuantumTrainingConfig
from opifex.core.training.physics_configs import ConservationConfig
from opifex.neural.base import StandardMLP
import jax
import jax.numpy as jnp
from flax import nnx
# Create model (requires rngs for NNX initialization)
rngs = nnx.Rngs(jax.random.PRNGKey(0))
model = StandardMLP(
layer_sizes=[2, 50, 50, 1],
activation="tanh",
rngs=rngs
)
# Configure physics-aware training with composable configs
conservation_config = ConservationConfig(
laws=["energy", "momentum"],
energy_tolerance=1e-6,
momentum_tolerance=1e-6,
)
quantum_config = QuantumTrainingConfig(
chemical_accuracy_target=1e-3,
scf_max_iterations=100,
enable_symmetry_enforcement=True,
)
config = TrainingConfig(
num_epochs=1000,
batch_size=256,
learning_rate=1e-3,
validation_frequency=100,
checkpoint_frequency=100,
conservation_config=conservation_config,
quantum_config=quantum_config,
)
# Initialize trainer
trainer = Trainer(model, config)
# Prepare training data
key = jax.random.PRNGKey(42)
x_train = jax.random.uniform(key, (1000, 2), minval=-1.0, maxval=1.0)
y_train = jnp.sin(jnp.pi * x_train[:, 0]) * jnp.cos(jnp.pi * x_train[:, 1])
# Train the model
model, history = trainer.fit(
train_data=(x_train, y_train),
val_data=(x_train[:200], y_train[:200])
)
print(f"Training completed in {len(history['train_losses'])} epochs")
print(f"Final training loss: {history['train_losses'][-1]:.6f}")
Key Advantages:
- Composable: Mix and match physics configurations without modifying trainer code
- Type-Safe: IDE support through complete type hints
- Zero Runtime Overhead: All configuration is resolved at initialization
- Extensible: Add new configs without changing existing code
- Well-Tested: 88 tests covering all functionality
Available Physics Configurations:
- ConservationConfig: Energy, momentum, mass, and symmetry conservation
- MultiScaleConfig: Multi-scale physics with adaptive coupling
- QuantumTrainingConfig: Quantum chemistry and electronic structure
- BoundaryConfig: Boundary condition enforcement
- DFTConfig: Density functional theory workflows
- SCFConfig: Self-consistent field convergence
- MetricsTrackingConfig: Custom metrics tracking
- LoggingConfig: Advanced logging and alerting
- PerformanceConfig: Performance optimization settings
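The composable pattern described above can be illustrated with plain dataclasses. This is a minimal sketch of the idea only, not the Opifex implementation: `ConservationSketch`, `BoundarySketch`, `TrainingSketch`, and `active_physics` are hypothetical names, and the fields are chosen for illustration.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class ConservationSketch:
    """Hypothetical stand-in for a conservation sub-config."""
    laws: Tuple[str, ...] = ("energy",)
    energy_tolerance: float = 1e-6


@dataclass(frozen=True)
class BoundarySketch:
    """Hypothetical stand-in for a boundary sub-config."""
    weight: float = 10.0


@dataclass(frozen=True)
class TrainingSketch:
    """Top-level config; every physics sub-config is optional."""
    num_epochs: int = 100
    learning_rate: float = 1e-3
    conservation: Optional[ConservationSketch] = None
    boundary: Optional[BoundarySketch] = None

    def active_physics(self):
        """Names of the optional sub-configs that were supplied."""
        names = ("conservation", "boundary")
        return [n for n in names if getattr(self, n) is not None]


plain = TrainingSketch()
with_physics = TrainingSketch(
    conservation=ConservationSketch(laws=("energy", "momentum"))
)
```

Because each sub-config defaults to `None`, a trainer written against this shape can check which physics features are active once at initialization and ignore the rest.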
Basic Training Infrastructure¶
The BasicTrainer class provides a complete training framework with physics-informed capabilities:
from opifex.training.basic_trainer import BasicTrainer
from opifex.core.training.config import TrainingConfig
from opifex.neural.base import StandardMLP
import jax
import jax.numpy as jnp
from flax import nnx
# Create a neural network model
rngs = nnx.Rngs(jax.random.PRNGKey(0))
model = StandardMLP(
layer_sizes=[2, 50, 50, 1],
activation="tanh",
rngs=rngs
)
# Configure training parameters
# Note: optimizer, early_stopping_patience, and weight_decay live in sub-configs
# (optimization_config, validation_config), not at the top level of TrainingConfig
config = TrainingConfig(
learning_rate=1e-3,
num_epochs=1000,
batch_size=256,
validation_frequency=100,
checkpoint_frequency=100
)
# Initialize trainer
trainer = BasicTrainer(
model=model,
training_config=config
)
# Prepare training data
key = jax.random.PRNGKey(42)
x_train = jax.random.uniform(key, (1000, 2), minval=-1.0, maxval=1.0)
y_train = jnp.sin(jnp.pi * x_train[:, 0]) * jnp.cos(jnp.pi * x_train[:, 1])
# Train the model
history = trainer.train(
train_data=(x_train, y_train),
validation_data=(x_train[:200], y_train[:200])
)
print(f"Training completed in {len(history.train_losses)} epochs")
print(f"Final training loss: {history.train_losses[-1]:.6f}")
print(f"Final validation loss: {history.val_losses[-1]:.6f}")
Advanced Modular Training Architecture¶
For complex scientific applications, the ModularTrainer provides a component-based architecture:
from opifex.training.basic_trainer import ModularTrainer
from opifex.training.recovery import ErrorRecoveryManager
from opifex.training.components import FlexibleOptimizerFactory
from opifex.training.metrics import AdvancedMetricsCollector
# Configure advanced training components
error_recovery = ErrorRecoveryManager(
config={
"max_retries": 3,
"checkpoint_on_error": True,
"gradient_clip_threshold": 10.0,
"loss_explosion_threshold": 1e6,
"learning_rate": 1e-3
}
)
optimizer_factory = FlexibleOptimizerFactory(
config={
"optimizer_type": "adamw",
"learning_rate": 1e-3,
"weight_decay": 1e-4,
"use_schedule": True,
"schedule_type": "cosine",
"total_steps": 10000
}
)
metrics_collector = AdvancedMetricsCollector()
# Create modular trainer with custom components
modular_trainer = ModularTrainer(
model=model,
config=config,
components={
"error_recovery": error_recovery,
"optimizer_factory": optimizer_factory,
"metrics_collector": metrics_collector
}
)
# Train with advanced error handling and metrics
advanced_history = modular_trainer.train(
train_data=(x_train, y_train),
validation_data=(x_train[:200], y_train[:200])
)
print("Advanced modular training completed with enhanced error recovery")
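The recovery settings above (`max_retries`, `loss_explosion_threshold`) suggest a retry loop around each training step. The following plain-Python sketch shows one way such a loop could behave. It is not the `ErrorRecoveryManager` implementation: `step_with_recovery` and `toy_step` are hypothetical names, and halving the learning rate after a failed step is an assumed policy.

```python
import math


def step_with_recovery(step_fn, learning_rate, max_retries=3,
                       loss_explosion_threshold=1e6):
    """Run step_fn(lr); retry with a halved lr if the loss is non-finite or too large."""
    lr = learning_rate
    for attempt in range(max_retries + 1):
        loss = step_fn(lr)
        if math.isfinite(loss) and loss < loss_explosion_threshold:
            return loss, lr, attempt
        lr *= 0.5  # assumed recovery policy: back off the step size
    raise RuntimeError(f"loss did not recover after {max_retries} retries")


def toy_step(lr):
    """Hypothetical training step that diverges when lr exceeds 3e-4."""
    return float("inf") if lr > 3e-4 else 0.1


loss, final_lr, retries = step_with_recovery(toy_step, learning_rate=1e-3)
```

In this toy run the first two attempts diverge, and the third succeeds at a quarter of the starting learning rate.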
Physics-Informed Neural Networks (PINNs)¶
Basic PINN Training¶
Physics-informed training incorporates physical laws directly into the loss function:
from opifex.core.physics.losses import PhysicsInformedLoss, PhysicsLossConfig
from opifex.core.problems import PDEProblem
from opifex.core.conditions import DirichletBC
from opifex.geometry import Rectangle
# Define a PDE problem (2D Poisson equation)
class PoissonProblem(PDEProblem):
    def __init__(self):
        geometry = Rectangle(
            center=jnp.array([0.5, 0.5]), width=1.0, height=1.0
        )
        boundary_conditions = [
            DirichletBC(boundary="left", value=0.0),
            DirichletBC(boundary="right", value=0.0),
            DirichletBC(boundary="top", value=0.0),
            DirichletBC(boundary="bottom", value=0.0)
        ]
        super().__init__(
            geometry=geometry,
            equation=self.residual,
            boundary_conditions=boundary_conditions
        )
    def residual(self, x, u, u_derivatives):
        """Residual of the Poisson equation ∇²u = f(x, y)."""
        u_xx = u_derivatives["xx"]
        u_yy = u_derivatives["yy"]
        # Source term f(x, y); the exact solution is sin(pi x) sin(pi y)
        x_coord, y_coord = x[..., 0], x[..., 1]
        source = -2 * jnp.pi**2 * jnp.sin(jnp.pi * x_coord) * jnp.sin(jnp.pi * y_coord)
        return u_xx + u_yy - source
# Configure physics-informed loss
physics_config = PhysicsLossConfig(
data_loss_weight=1.0,
physics_loss_weight=1.0,
boundary_loss_weight=10.0,
adaptive_weighting=True
)
physics_loss = PhysicsInformedLoss(
config=physics_config,
equation_type="poisson",
domain_type="rectangular",
)
# Create PINN trainer
pinn_trainer = BasicTrainer(
model=model,
training_config=config,
physics_loss=physics_loss
)
# Generate collocation points for physics loss
poisson_problem = PoissonProblem()
key = jax.random.PRNGKey(123)
# Interior collocation points
x_physics = jax.random.uniform(key, (2000, 2), minval=0.0, maxval=1.0)
# Boundary points
x_boundary = jnp.concatenate([
jnp.column_stack([jnp.zeros(100), jnp.linspace(0, 1, 100)]), # Left
jnp.column_stack([jnp.ones(100), jnp.linspace(0, 1, 100)]), # Right
jnp.column_stack([jnp.linspace(0, 1, 100), jnp.zeros(100)]), # Bottom
jnp.column_stack([jnp.linspace(0, 1, 100), jnp.ones(100)]) # Top
])
u_boundary = jnp.zeros(len(x_boundary))
# Train PINN
pinn_history = pinn_trainer.train(
collocation_points=x_physics,
boundary_data=(x_boundary, u_boundary),
problem=poisson_problem
)
print("PINN training completed")
print(f"Final physics loss: {pinn_history.physics_losses[-1]:.6f}")
print(f"Final boundary loss: {pinn_history.boundary_losses[-1]:.6f}")
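As a sanity check on the problem definition, the residual `u_xx + u_yy - source` should vanish for the manufactured solution u(x, y) = sin(πx) sin(πy), which also satisfies the zero Dirichlet conditions. The sketch below verifies this in plain Python with central finite differences instead of automatic differentiation; `u_exact`, `source`, and `poisson_residual` are illustrative names, not part of Opifex.

```python
import math


def u_exact(x, y):
    """Manufactured solution for the source term used in PoissonProblem."""
    return math.sin(math.pi * x) * math.sin(math.pi * y)


def source(x, y):
    """Right-hand side f(x, y) = -2 pi^2 sin(pi x) sin(pi y)."""
    return -2 * math.pi**2 * math.sin(math.pi * x) * math.sin(math.pi * y)


def poisson_residual(u, x, y, h=1e-3):
    """u_xx + u_yy - f, with second-order central differences for the derivatives."""
    u_xx = (u(x + h, y) - 2 * u(x, y) + u(x - h, y)) / h**2
    u_yy = (u(x, y + h) - 2 * u(x, y) + u(x, y - h)) / h**2
    return u_xx + u_yy - source(x, y)


points = [(0.25, 0.25), (0.5, 0.5), (0.3, 0.8)]
max_residual = max(abs(poisson_residual(u_exact, x, y)) for x, y in points)
boundary_value = abs(u_exact(0.0, 0.4))  # point on the left edge
```

The remaining residual comes from the finite-difference truncation error, so it shrinks as `h` is reduced (until floating-point roundoff takes over).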
Neural Operator Training¶
Fourier Neural Operator (FNO) Training¶
from opifex.neural.operators.fno import FourierNeuralOperator
from opifex.training.basic_trainer import BasicTrainer
from opifex.core.training.config import TrainingConfig
from flax import nnx
# Create FNO model for operator learning
rngs = nnx.Rngs(jax.random.PRNGKey(0))
fno_model = FourierNeuralOperator(
in_channels=2, # Input function channels
out_channels=1, # Output function channels
hidden_channels=64, # Hidden dimension
modes=16, # Fourier modes to keep
num_layers=4, # Number of Fourier layers
rngs=rngs
)
# Generate operator training data (input-output function pairs)
def generate_operator_data(n_samples=1000, resolution=64):
    """Generate training data for operator learning."""
    key = jax.random.PRNGKey(456)
    # Uniform grid on the unit square
    x = jnp.linspace(0, 1, resolution)
    y = jnp.linspace(0, 1, resolution)
    X, Y = jnp.meshgrid(x, y, indexing='ij')
    input_functions = []
    output_functions = []
    for i in range(n_samples):
        # Random input function: sine series with Gaussian coefficients
        key, subkey = jax.random.split(key)
        coeffs = jax.random.normal(subkey, (8, 8))
        input_func = jnp.zeros((resolution, resolution))
        for kx in range(8):
            for ky in range(8):
                input_func += coeffs[kx, ky] * jnp.sin(
                    2 * jnp.pi * kx * X
                ) * jnp.sin(2 * jnp.pi * ky * Y)
        # Corresponding output function (solve PDE)
        output_func = solve_pde_with_input(input_func, X, Y)
        input_functions.append(input_func)
        output_functions.append(output_func)
    return jnp.stack(input_functions), jnp.stack(output_functions)
def solve_pde_with_input(input_func, X, Y):
    """Solve PDE with given input function (simplified)."""
    # This would typically involve a numerical PDE solver
    # For demonstration, we use a simple transformation
    return jnp.fft.fft2(input_func).real
# Generate training data
input_funcs, output_funcs = generate_operator_data(n_samples=500)
# Configure FNO training
# Note: optimizer selection is handled via optimization_config, not a top-level string
fno_config = TrainingConfig(
learning_rate=1e-3,
num_epochs=200,
batch_size=16, # Smaller batch size for function data
validation_frequency=20
)
# Train FNO
fno_trainer = Trainer(model=fno_model, config=fno_config)
fno_model, fno_history = fno_trainer.fit(
train_data=(input_funcs[:400], output_funcs[:400]),
val_data=(input_funcs[400:], output_funcs[400:])
)
print("FNO training completed")
print(f"Final training loss: {fno_history['train_losses'][-1]:.6f}")
DeepONet Training¶
from opifex.neural.operators.deeponet import DeepONet
# Create DeepONet model
rngs = nnx.Rngs(jax.random.PRNGKey(0))
deeponet_model = DeepONet(
branch_sizes=[100, 100, 100], # Branch network architecture
trunk_sizes=[2, 100, 100, 100], # Trunk network architecture (2D input)
rngs=rngs
)
# Generate DeepONet training data
def generate_deeponet_data(n_samples=1000, n_sensors=100):
    """Generate training data for DeepONet."""
    key = jax.random.PRNGKey(789)
    # Sensor locations (fixed)
    sensor_locations = jnp.linspace(0, 1, n_sensors)
    # Query locations (variable); split first so the key is not reused below
    key, query_key = jax.random.split(key)
    query_locations = jax.random.uniform(query_key, (n_samples, 2))
    branch_inputs = []  # Function values at sensors
    trunk_inputs = []   # Query coordinates
    outputs = []        # Function values at query points
    for i in range(n_samples):
        # Random function (degree-4 polynomial)
        key, subkey = jax.random.split(key)
        coeffs = jax.random.normal(subkey, (5,))
        # Function values at sensor locations
        sensor_values = jnp.sum(
            coeffs[:, None] * sensor_locations[None, :]**jnp.arange(5)[:, None],
            axis=0
        )
        # Function value at query location
        query_x, query_y = query_locations[i]
        query_value = jnp.sum(coeffs * query_x**jnp.arange(5)) * jnp.sin(jnp.pi * query_y)
        branch_inputs.append(sensor_values)
        trunk_inputs.append(query_locations[i])
        outputs.append(query_value)
    return (
        jnp.stack(branch_inputs),
        jnp.stack(trunk_inputs),
        jnp.array(outputs)
    )
# Generate DeepONet training data
branch_data, trunk_data, target_data = generate_deeponet_data(n_samples=2000)
# Train DeepONet
deeponet_trainer = BasicTrainer(model=deeponet_model, training_config=fno_config)
deeponet_history = deeponet_trainer.train(
train_data=((branch_data[:1600], trunk_data[:1600]), target_data[:1600]),
validation_data=((branch_data[1600:], trunk_data[1600:]), target_data[1600:])
)
print("DeepONet training completed")
print(f"Final training loss: {deeponet_history.train_losses[-1]:.6f}")
Advanced Optimization Strategies¶
Learning Rate Scheduling¶
import optax
def create_advanced_scheduler(base_lr=1e-3, total_steps=10000):
    """Create a linear-warmup plus cosine-decay learning rate schedule."""
    # Warmup phase
    warmup_steps = int(0.1 * total_steps)
    warmup_schedule = optax.linear_schedule(
        init_value=1e-6,
        end_value=base_lr,
        transition_steps=warmup_steps
    )
    # Cosine decay over the remaining steps (single cycle, no restarts)
    cosine_steps = total_steps - warmup_steps
    cosine_schedule = optax.cosine_decay_schedule(
        init_value=base_lr,
        decay_steps=cosine_steps,
        alpha=0.1  # Minimum learning rate factor
    )
    # Combine schedules
    combined_schedule = optax.join_schedules(
        schedules=[warmup_schedule, cosine_schedule],
        boundaries=[warmup_steps]
    )
    return combined_schedule
# Use advanced scheduling in training
# Note: optimizer type and weight_decay are configured via optimization_config sub-config
advanced_config = TrainingConfig(
learning_rate=1e-3,
num_epochs=100,
batch_size=64
)
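To make the shape of this schedule concrete, the following dependency-free sketch evaluates the same warmup-plus-cosine curve in closed form. `warmup_cosine_lr` is an illustrative name; the arithmetic mirrors what `optax.linear_schedule`, `optax.cosine_decay_schedule`, and `optax.join_schedules` compute for the parameters above, under the assumption that the second schedule starts counting from zero at the boundary.

```python
import math


def warmup_cosine_lr(step, base_lr=1e-3, total_steps=10000,
                     init_lr=1e-6, alpha=0.1, warmup_fraction=0.1):
    """Learning rate at `step` for linear warmup followed by cosine decay."""
    warmup_steps = int(warmup_fraction * total_steps)
    if step < warmup_steps:
        # Linear ramp from init_lr to base_lr
        return init_lr + (base_lr - init_lr) * step / warmup_steps
    decay_steps = total_steps - warmup_steps
    progress = min(step - warmup_steps, decay_steps) / decay_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    # Decays from base_lr down to alpha * base_lr
    return base_lr * ((1.0 - alpha) * cosine + alpha)


lr_start = warmup_cosine_lr(0)
lr_mid_warmup = warmup_cosine_lr(500)
lr_peak = warmup_cosine_lr(1000)
lr_end = warmup_cosine_lr(10000)
```

The rate starts at `1e-6`, peaks at `base_lr` when warmup ends, and settles at `alpha * base_lr` once the decay phase is complete.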
Gradient Clipping and Regularization¶
class RegularizedTrainer(BasicTrainer):
    """Trainer with advanced regularization techniques."""
    def __init__(self, model, config, regularization_config=None, **kwargs):
        super().__init__(model, config, **kwargs)
        self.reg_config = regularization_config or {}
        # Configure gradient clipping
        self.gradient_clip_value = self.reg_config.get("gradient_clip", 1.0)
        # Regularization weights
        self.l1_weight = self.reg_config.get("l1_weight", 0.0)
        self.l2_weight = self.reg_config.get("l2_weight", 1e-4)
        # Stored for later use; not applied in compute_regularization_loss
        self.spectral_norm_weight = self.reg_config.get("spectral_norm", 0.0)
    def compute_regularization_loss(self, params):
        """Compute the L1 and L2 regularization terms."""
        reg_loss = 0.0
        leaves = jax.tree_util.tree_leaves(params)
        # L1 regularization
        if self.l1_weight > 0:
            l1_loss = sum(jnp.sum(jnp.abs(p)) for p in leaves)
            reg_loss += self.l1_weight * l1_loss
        # L2 regularization
        if self.l2_weight > 0:
            l2_loss = sum(jnp.sum(p**2) for p in leaves)
            reg_loss += self.l2_weight * l2_loss
        return reg_loss
# Use regularized training
reg_config = {
"gradient_clip": 1.0,
"l1_weight": 1e-5,
"l2_weight": 1e-4,
"spectral_norm": 1e-3
}
regularized_trainer = RegularizedTrainer(
model=model,
config=config,
regularization_config=reg_config
)
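The `gradient_clip` setting above is stored but its use is not shown. A common interpretation is clipping by global norm: if the norm over all gradient entries exceeds the threshold, every entry is scaled by the same factor. The sketch below shows that arithmetic on nested Python lists; `clip_by_global_norm` here is an illustrative helper, and in a JAX training loop `optax.clip_by_global_norm` provides the equivalent transformation.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale all gradient leaves so that their combined L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for leaf in grads for g in leaf))
    # Leave gradients untouched when they are already within the limit
    scale = min(1.0, max_norm / total_norm) if total_norm > 0 else 1.0
    clipped = [[g * scale for g in leaf] for leaf in grads]
    return clipped, total_norm


grads = [[3.0], [4.0]]  # two parameter "leaves", global norm 5
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for leaf in clipped for g in leaf))
```

Scaling every leaf by one shared factor preserves the gradient direction, which per-element clipping would not.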
Monitoring, Visualization, and Checkpointing¶
Advanced Metrics Collection¶
from opifex.training.metrics import AdvancedMetricsCollector
import matplotlib.pyplot as plt
class ComprehensiveMetricsCollector(AdvancedMetricsCollector):
    """Enhanced metrics collection with physics-aware diagnostics."""
    def __init__(self):
        super().__init__()
        self.physics_metrics = {}
        self.convergence_metrics = {}
        self.gradient_metrics = {}
    def collect_physics_metrics(self, params, batch, model, problem=None):
        """Collect physics-specific metrics."""
        if problem is None:
            return
        # Physics residual statistics
        x_physics = batch[0] if len(batch) > 0 else None
        if x_physics is not None:
            def network_fn(x):
                return model.apply(params, x)
            u_pred = network_fn(x_physics)
            u_derivatives = self._compute_derivatives(network_fn, x_physics)
            residuals = problem.residual(x_physics, u_pred, u_derivatives)
            self.physics_metrics.update({
                "residual_mean": jnp.mean(jnp.abs(residuals)),
                "residual_max": jnp.max(jnp.abs(residuals)),
                "residual_std": jnp.std(residuals)
            })
    def collect_gradient_metrics(self, gradients):
        """Collect gradient-based metrics."""
        leaves = jax.tree_util.tree_leaves(gradients)
        grad_norms = jnp.array([jnp.linalg.norm(g) for g in leaves])
        self.gradient_metrics.update({
            "grad_norm_mean": jnp.mean(grad_norms),
            "grad_norm_max": jnp.max(grad_norms),
            "grad_norm_total": jnp.sqrt(sum(jnp.sum(g**2) for g in leaves))
        })
# Use full metrics
comprehensive_metrics = ComprehensiveMetricsCollector()
Real-Time Visualization¶
class TrainingVisualizer:
    """Real-time training visualization."""
    def __init__(self, update_frequency=10):
        self.update_frequency = update_frequency
        self.fig, self.axes = plt.subplots(2, 2, figsize=(12, 8))
        self.loss_history = {"train": [], "val": [], "physics": [], "boundary": []}
        self.metrics_history = {}
        plt.ion()  # Interactive mode
    def update_plots(self, epoch, current_losses, current_metrics):
        """Update all visualization plots."""
        # Update loss history
        for key, value in current_losses.items():
            if key in self.loss_history:
                self.loss_history[key].append(value)
        # Update metrics history
        for key, value in current_metrics.items():
            if key not in self.metrics_history:
                self.metrics_history[key] = []
            self.metrics_history[key].append(value)
        # _redraw_plots is not shown in this excerpt
        if epoch % self.update_frequency == 0:
            self._redraw_plots(epoch)
# Use visualization during training
visualizer = TrainingVisualizer(update_frequency=5)
Robust Checkpointing System¶
import orbax.checkpoint as ocp
from pathlib import Path
import time
class AdvancedCheckpointManager:
    """Advanced checkpointing with metadata and recovery."""
    def __init__(self, checkpoint_dir, max_to_keep=5):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.max_to_keep = max_to_keep
        # Initialize Orbax checkpoint manager
        self.manager = ocp.CheckpointManager(
            self.checkpoint_dir,
            max_to_keep=max_to_keep,
            item_names=("model_state", "optimizer_state", "metadata")
        )
    def save_checkpoint(self, epoch, model_state, optimizer_state,
                        training_metrics, physics_metrics=None):
        """Save full checkpoint with metadata."""
        # Prepare metadata
        metadata = {
            "epoch": epoch,
            "training_metrics": training_metrics,
            "physics_metrics": physics_metrics or {},
            "timestamp": time.time(),
            "model_info": {
                "architecture": type(model_state).__name__,
                "parameter_count": sum(
                    p.size for p in jax.tree_util.tree_leaves(model_state)
                    if hasattr(p, 'size')
                )
            }
        }
        # Save checkpoint
        checkpoint_data = {
            "model_state": model_state,
            "optimizer_state": optimizer_state,
            "metadata": metadata
        }
        self.manager.save(epoch, checkpoint_data)
        print(f"Checkpoint saved at epoch {epoch}")
    def load_checkpoint(self, epoch=None):
        """Load a checkpoint, defaulting to the latest; return None on failure."""
        try:
            if epoch is None:
                # Load latest checkpoint
                latest_step = self.manager.latest_step()
                if latest_step is None:
                    return None
                epoch = latest_step
            checkpoint_data = self.manager.restore(epoch)
            print(f"Checkpoint loaded from epoch {epoch}")
            return checkpoint_data
        except Exception as e:
            print(f"Failed to load checkpoint: {e}")
            return None
# Use advanced checkpointing
checkpoint_manager = AdvancedCheckpointManager(
checkpoint_dir="./checkpoints/advanced_training",
max_to_keep=10
)
print("Full training infrastructure guide completed")
The sections above cover the core infrastructure for training scientific machine learning models with Opifex. The modular, component-based architecture lets researchers build sophisticated training workflows while keeping the flexibility that scientific applications require.
Advanced Training Techniques¶
Multilevel Training¶
Multilevel training accelerates convergence by training from coarse to fine representations, leveraging multigrid insights for neural network optimization.
from opifex.training.multilevel import CascadeTrainer, MultilevelConfig
# Configure coarse-to-fine training
config = MultilevelConfig(
num_levels=3,
coarsening_factor=0.5,
level_epochs=[100, 200, 500],
)
trainer = CascadeTrainer(
input_dim=2,
output_dim=1,
base_hidden_dims=[64, 64],
config=config,
rngs=nnx.Rngs(0),
)
# Train through hierarchy
while not trainer.is_at_finest():
    model = trainer.get_current_model()
    # ... train current level ...
    trainer.advance_level()
Key Benefits:
- Faster convergence through hierarchical initialization
- Better optimization landscape via progressive capacity
- Natural curriculum from simple to complex representations
For full details on MLP and FNO hierarchies, see the Multilevel Training Guide.
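One way to picture the coarse-to-fine hierarchy is as a sequence of hidden-layer widths that grows toward `base_hidden_dims`. The sketch below computes such a sequence from `num_levels` and `coarsening_factor`. How `CascadeTrainer` actually derives its levels is not stated here, so the rule in `level_widths` (a hypothetical helper) is an assumption made for illustration.

```python
def level_widths(base_hidden_dims, num_levels, coarsening_factor):
    """Hidden widths per level, ordered coarsest to finest (assumed interpretation)."""
    levels = []
    for level in range(num_levels):
        # The finest level (last) uses the base widths unchanged
        shrink = coarsening_factor ** (num_levels - 1 - level)
        levels.append([max(1, int(round(w * shrink))) for w in base_hidden_dims])
    return levels


widths = level_widths([64, 64], num_levels=3, coarsening_factor=0.5)
# Pair each level with its epoch budget from level_epochs
schedule = list(zip(widths, [100, 200, 500]))
```

Under this reading, the configuration above would spend 100 epochs on a 16-wide network, 200 on a 32-wide one, and 500 on the full 64-wide model.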
Adaptive Sampling¶
Adaptive sampling focuses computational resources on high-residual regions, improving training efficiency for PINNs:
from opifex.training.adaptive_sampling import RADSampler, RADConfig
# Configure residual-based sampling
config = RADConfig(
beta=1.0, # Residual exponent
resample_frequency=100, # Steps between resampling
)
sampler = RADSampler(config)
# During training
residuals = compute_pde_residual(model, all_points)
batch = sampler.sample(all_points, residuals, batch_size=256, key=key)
Strategies Available:
- RAD: Samples with probability proportional to residual magnitude
- RAR-D: Progressively adds points near high-residual regions
For detailed algorithms and best practices, see the Adaptive Sampling Guide.
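The RAD strategy described above can be sketched without any dependencies: each candidate point receives a probability proportional to its residual magnitude raised to `beta`, and a batch is drawn from that distribution. `rad_probabilities` and `rad_sample` are illustrative names rather than the `RADSampler` API, and sampling with replacement is an assumption of this sketch.

```python
import random


def rad_probabilities(residuals, beta=1.0):
    """Sampling probabilities proportional to |residual| ** beta."""
    weights = [abs(r) ** beta for r in residuals]
    total = sum(weights)
    if total == 0:
        # All residuals are zero: fall back to uniform sampling
        return [1.0 / len(residuals)] * len(residuals)
    return [w / total for w in weights]


def rad_sample(points, residuals, batch_size, beta=1.0, seed=0):
    """Draw a batch (with replacement) that favors high-residual points."""
    rng = random.Random(seed)
    probs = rad_probabilities(residuals, beta)
    return rng.choices(points, weights=probs, k=batch_size)


points = [0.1, 0.5, 0.9]
residuals = [0.1, -0.2, 0.7]
probs = rad_probabilities(residuals)
batch = rad_sample(points, residuals, batch_size=1000)
```

Larger values of `beta` concentrate the batch more sharply on the worst-fitting regions, while `beta=0` reduces to uniform sampling.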
GradNorm Loss Balancing¶
For multi-task learning with multiple loss terms, GradNorm automatically balances gradient magnitudes:
from opifex.core.physics.gradnorm import GradNormBalancer, GradNormConfig
config = GradNormConfig(
alpha=1.5, # Asymmetry parameter
learning_rate=0.01, # Weight update rate
)
balancer = GradNormBalancer(num_losses=3, config=config, rngs=nnx.Rngs(0))
# Compute weighted loss
losses = jnp.array([pde_loss, bc_loss, data_loss])
weighted_loss = balancer.compute_weighted_loss(losses)
Benefits:
- Prevents any single loss from dominating training
- Encourages uniform convergence across all objectives
- Adapts weights dynamically based on training progress
For the complete algorithm and configuration options, see the GradNorm Guide.
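The weight update behind GradNorm can be shown with scalar arithmetic. In the sketch below, each loss has a weight, an unweighted gradient norm, and a loss ratio (current loss divided by initial loss). Losses that are training slowly get a larger target gradient norm, and the weights move toward those targets before being renormalized to sum to the number of losses. `gradnorm_step` is a hypothetical helper written for illustration and is not the `GradNormBalancer` implementation.

```python
def gradnorm_step(weights, grad_norms, loss_ratios, alpha=1.5, learning_rate=0.01):
    """One GradNorm-style weight update on scalar gradient norms."""
    n = len(weights)
    weighted_norms = [w * g for w, g in zip(weights, grad_norms)]
    mean_norm = sum(weighted_norms) / n
    mean_ratio = sum(loss_ratios) / n
    # Slower-training losses (larger relative ratio) get larger targets
    targets = [mean_norm * (r / mean_ratio) ** alpha for r in loss_ratios]
    updated = []
    for w, g, norm, target in zip(weights, grad_norms, weighted_norms, targets):
        # Gradient of |w * g - target| with respect to w (target held constant)
        sign = (norm > target) - (norm < target)
        updated.append(w - learning_rate * sign * g)
    # Renormalize so the weights sum to the number of losses
    scale = n / sum(updated)
    return [w * scale for w in updated]


new_weights = gradnorm_step(
    weights=[1.0, 1.0, 1.0],
    grad_norms=[2.0, 1.0, 1.0],
    loss_ratios=[0.5, 0.5, 1.0],  # third loss has made the least progress
)
```

Here the third loss has not improved from its initial value, so its weight rises above 1 while the dominant first loss is scaled down.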
See Also¶
- Multilevel Training - Coarse-to-fine training hierarchies
- Adaptive Sampling - RAD and RAR-D strategies
- GradNorm - Multi-task loss balancing
- NTK Analysis - Training diagnostics via spectral analysis
- Second-Order Optimization - L-BFGS and hybrid optimizers