Second-Order Optimization¶
Second-order optimization methods leverage curvature information (Hessian or approximations) for faster convergence, particularly beneficial in the later stages of physics-informed neural network training.
Overview¶
Second-order methods offer significant advantages for PINN training:
- Faster convergence near optima due to curvature information
- Better handling of ill-conditioned loss landscapes
- Reduced sensitivity to learning rate selection
- More effective in smooth regions of the loss landscape
Survey Reference
This implementation follows recommendations from Section 7 of the PINN survey (arXiv:2601.10222v1).
Methods¶
L-BFGS¶
L-BFGS (Limited-memory BFGS) approximates the inverse Hessian using a limited history of gradient differences, making it suitable for large-scale optimization.
from opifex.optimization.second_order import (
create_lbfgs_optimizer,
LBFGSConfig,
LinesearchType,
)
# Configure L-BFGS
config = LBFGSConfig(
memory_size=10, # Gradient pairs to store (typically 3-20)
scale_init_precond=True, # Scale initial preconditioner
linesearch=LinesearchType.ZOOM, # Line search algorithm
max_linesearch_steps=20, # Max line search iterations
max_iterations=100, # Max L-BFGS iterations
tolerance=1e-6, # Convergence tolerance
)
# Create optimizer
optimizer = create_lbfgs_optimizer(config)
Line Search Options:
| Algorithm | Description | Best For |
|---|---|---|
ZOOM |
Strong Wolfe conditions with zoom | General use, guaranteed descent |
BACKTRACKING |
Simple Armijo backtracking | Faster per-step, less robust |
Using L-BFGS with optax:
import jax
import jax.numpy as jnp
import optax
from flax import nnx
# Define loss function
def loss_fn(model, x, y_true):
y_pred = model(x)
return jnp.mean((y_pred - y_true) ** 2)
# L-BFGS requires value_and_grad_fn for line search
def value_and_grad_fn(params, model_template, x, y_true):
def loss(params):
model = model_template.replace(params=params)
return loss_fn(model, x, y_true)
return jax.value_and_grad(loss)(params)
# Training with L-BFGS
optimizer = create_lbfgs_optimizer()
opt_state = optimizer.init(params)
for step in range(num_steps):
loss, grads = value_and_grad_fn(params, model, x, y_true)
updates, opt_state = optimizer.update(
grads, opt_state, params,
value=loss,
grad=grads,
value_fn=lambda p: loss_fn(model.replace(params=p), x, y_true),
)
params = optax.apply_updates(params, updates)
Gauss-Newton¶
Gauss-Newton is effective for nonlinear least-squares problems, approximating the Hessian using only first derivatives.
from opifex.optimization.second_order import (
create_gauss_newton_solver,
GaussNewtonConfig,
)
import optimistix as optx
# Configure solver
config = GaussNewtonConfig(
max_iterations=100,
rtol=1e-6,
atol=1e-6,
)
# Create solver
solver = create_gauss_newton_solver(config)
# Use with optimistix for least-squares
def residual_fn(params, args):
"""Residual function for least-squares."""
model = args['model']
x, y_true = args['x'], args['y_true']
y_pred = model(x)
return y_pred - y_true
# Solve
result = optx.least_squares(
residual_fn,
solver=solver,
y0=initial_params,
args={'model': model, 'x': x, 'y_true': y_true},
)
optimal_params = result.value
Levenberg-Marquardt¶
Levenberg-Marquardt adds damping to Gauss-Newton for improved robustness, especially when far from the optimum.
from opifex.optimization.second_order import (
create_levenberg_marquardt_solver,
GaussNewtonConfig,
)
# Configure with damping parameters
config = GaussNewtonConfig(
damping_factor=1e-3, # Initial damping (λ)
damping_increase_factor=10.0, # Factor to increase on failure
damping_decrease_factor=0.1, # Factor to decrease on success
min_damping=1e-10, # Minimum damping
max_damping=1e10, # Maximum damping
max_iterations=100,
rtol=1e-6,
atol=1e-6,
)
solver = create_levenberg_marquardt_solver(config)
Damping Behavior:
- Large damping: Behaves like gradient descent (robust, slow)
- Small damping: Behaves like Gauss-Newton (fast, less robust)
- Adaptive: Increases damping on failed steps, decreases on success
BFGS¶
Full-memory BFGS for smaller-scale problems where storing the complete inverse Hessian approximation is feasible.
from opifex.optimization.second_order import (
create_bfgs_solver,
GaussNewtonConfig,
)
import optimistix as optx
solver = create_bfgs_solver(
GaussNewtonConfig(rtol=1e-6, atol=1e-6)
)
# Use with optimistix minimise
result = optx.minimise(
loss_fn,
solver=solver,
y0=initial_params,
)
Hybrid Adam to L-BFGS Optimizer¶
The hybrid optimizer combines Adam's robustness in early training with L-BFGS's efficiency for final convergence.
Insight
L-BFGS is more effective in later stages when loss varies smoothly.
from opifex.optimization.second_order import (
HybridOptimizer,
HybridOptimizerConfig,
SwitchCriterion,
LBFGSConfig,
)
import optax
# Configure hybrid optimizer
config = HybridOptimizerConfig(
# Adam phase
first_order_steps=1000, # Steps before considering switch
adam_learning_rate=1e-3,
adam_b1=0.9,
adam_b2=0.999,
# Switching criterion
switch_criterion=SwitchCriterion.LOSS_VARIANCE,
loss_variance_threshold=1e-4,
loss_history_window=50,
# L-BFGS phase
lbfgs_config=LBFGSConfig(
memory_size=10,
max_linesearch_steps=20,
),
)
# Create and use optimizer
optimizer = HybridOptimizer(config)
state = optimizer.init(params)
# Training loop
for step in range(num_steps):
loss, grads = jax.value_and_grad(loss_fn)(params)
updates, state = optimizer.update(
grads, state, params,
loss=loss,
value_fn=lambda p: loss_fn(p),
)
params = optax.apply_updates(params, updates)
# Check current optimizer mode
if state.switched:
print(f"Step {step}: Using L-BFGS")
Switch Criteria:
| Criterion | Description | When to Use |
|---|---|---|
EPOCH |
Switch after fixed steps | Simple, predictable |
LOSS_VARIANCE |
Switch when loss variance drops | Detects smooth regions |
GRADIENT_NORM |
Switch when gradients are small | Near convergence |
RELATIVE_IMPROVEMENT |
Switch when improvement slows | Adaptive to progress |
# Example: Gradient norm-based switching
config = HybridOptimizerConfig(
first_order_steps=500,
switch_criterion=SwitchCriterion.GRADIENT_NORM,
gradient_norm_threshold=1e-3,
)
# Example: Relative improvement-based switching
config = HybridOptimizerConfig(
first_order_steps=500,
switch_criterion=SwitchCriterion.RELATIVE_IMPROVEMENT,
relative_improvement_threshold=1e-4,
)
NNX Integration¶
The second-order optimizers integrate seamlessly with FLAX NNX models.
from flax import nnx
from opifex.optimization.second_order import HybridOptimizer, HybridOptimizerConfig
# Create NNX model
class MyPINN(nnx.Module):
def __init__(self, rngs: nnx.Rngs):
self.layers = nnx.List([
nnx.Linear(2, 64, rngs=rngs),
nnx.Linear(64, 64, rngs=rngs),
nnx.Linear(64, 1, rngs=rngs),
])
def __call__(self, x):
for layer in list(self.layers)[:-1]:
x = nnx.tanh(layer(x))
return list(self.layers)[-1](x)
model = MyPINN(rngs=nnx.Rngs(0))
# Define loss with NNX
def loss_fn(model):
predictions = model(x_collocation)
residual = compute_pde_residual(model, x_collocation)
return jnp.mean(residual ** 2)
# Training with hybrid optimizer
optimizer = HybridOptimizer(HybridOptimizerConfig())
graphdef, state = nnx.split(model)
opt_state = optimizer.init(state)
for step in range(num_steps):
loss, grads = nnx.value_and_grad(loss_fn)(model)
state = nnx.state(model)
updates, opt_state = optimizer.update(
grads, opt_state, state,
loss=loss,
)
# Apply updates
new_state = optax.apply_updates(state, updates)
nnx.update(model, new_state)
Configuration Reference¶
LBFGSConfig¶
@dataclass(frozen=True)
class LBFGSConfig:
memory_size: int = 10 # Gradient pairs to store
scale_init_precond: bool = True # Scale initial preconditioner
linesearch: LinesearchType = LinesearchType.ZOOM
max_linesearch_steps: int = 20
max_iterations: int = 100
tolerance: float = 1e-6
GaussNewtonConfig¶
@dataclass(frozen=True)
class GaussNewtonConfig:
damping_factor: float = 1e-3
damping_increase_factor: float = 10.0
damping_decrease_factor: float = 0.1
min_damping: float = 1e-10
max_damping: float = 1e10
max_iterations: int = 100
rtol: float = 1e-6
atol: float = 1e-6
HybridOptimizerConfig¶
@dataclass(frozen=True)
class HybridOptimizerConfig:
first_order_steps: int = 1000
switch_criterion: SwitchCriterion = SwitchCriterion.LOSS_VARIANCE
loss_variance_threshold: float = 1e-4
loss_history_window: int = 50
gradient_norm_threshold: float = 1e-3
relative_improvement_threshold: float = 1e-4
adam_learning_rate: float = 1e-3
adam_b1: float = 0.9
adam_b2: float = 0.999
lbfgs_config: LBFGSConfig = field(default_factory=LBFGSConfig)
Best Practices¶
When to Use Second-Order Methods¶
| Scenario | Recommended Method |
|---|---|
| General PINN training | Hybrid Adam→L-BFGS |
| Well-conditioned problems | Pure L-BFGS |
| Least-squares formulation | Gauss-Newton or LM |
| Ill-conditioned problems | Levenberg-Marquardt |
| Small models (< 10K params) | Full BFGS |
Tuning L-BFGS Memory Size¶
# Small memory (3-5): Less memory, faster iterations
LBFGSConfig(memory_size=5)
# Medium memory (10-15): Good balance (default)
LBFGSConfig(memory_size=10)
# Large memory (20+): Better approximation, more memory
LBFGSConfig(memory_size=20)
Handling Convergence Issues¶
# If L-BFGS oscillates, increase line search steps
config = LBFGSConfig(
max_linesearch_steps=40, # Increased from 20
)
# If hybrid switch is too early, increase first_order_steps
config = HybridOptimizerConfig(
first_order_steps=2000, # More Adam steps
)
# If convergence stalls, try different criterion
config = HybridOptimizerConfig(
switch_criterion=SwitchCriterion.GRADIENT_NORM,
gradient_norm_threshold=1e-4,
)
Batch Size Considerations¶
Second-order methods typically work best with:
- Full batch: Most accurate gradient/curvature estimates
- Large mini-batch: Good balance of noise and efficiency
- Small mini-batch: May cause instability in L-BFGS
# For stochastic settings, use more conservative L-BFGS
config = LBFGSConfig(
memory_size=5, # Smaller memory
max_linesearch_steps=10, # Fewer line search steps
)
See Also¶
- Optimization Guide - General optimization strategies
- Training Guide - Training procedures
- GradNorm - Multi-task loss balancing
- API Reference - Complete API documentation