
Multilevel Training

Multilevel training leverages multigrid insights to accelerate neural network convergence by training from coarse to fine representations. This approach captures low-frequency features quickly on coarse networks, then refines high-frequency details on finer networks.

Overview

Compared with training the finest network from scratch, multilevel training aims to provide:

  • Faster convergence through hierarchical initialization
  • Better optimization landscape via coarse-to-fine progression
  • Reduced overfitting risk from progressive capacity
  • Natural curriculum from simple to complex representations

Survey Reference

This implementation follows the methodology described in Section 8.2 of the PINN survey (arXiv:2601.10222v1).

Width-Based Hierarchy (MLPs)

For standard MLPs, the hierarchy is based on network width (number of neurons per layer).

CascadeTrainer

The CascadeTrainer provides a generic framework for multilevel training, supporting any model hierarchy and optimizer.

from flax import nnx
import optax
from opifex.training.multilevel import (
    CascadeTrainer,
    create_network_hierarchy,
    prolongate,
    MultilevelAdam,
)

# 1. Create hierarchy (List of models from coarse to fine)
hierarchy = create_network_hierarchy(
    input_dim=2,
    output_dim=1,
    base_hidden_dims=[64, 64],
    num_levels=3,
    coarsening_factor=0.5,
    rngs=nnx.Rngs(0),
)

# 2. Create Multilevel Optimizer
# MultilevelAdam automatically handles state resizing during level transitions
optimizer = MultilevelAdam(learning_rate=1e-3)

# 3. Create Trainer
trainer = CascadeTrainer(
    hierarchy=hierarchy,
    optimizer=optimizer,
    prolongate_fn=prolongate,
)

# Current model (Level 0 - Coarsest)
model = trainer.get_current_model()

Multilevel Optimization

Standard optimizers (like Adam) maintain state (momentum, variance) that corresponds to specific parameters. When moving from a coarse model to a fine model, this state must be prolongated to match the new parameter shapes.

Opifex provides MultilevelAdam, a specialized optimizer that wraps optax.adam and handles this transition automatically.

# During training
optimizer.update(model, grads)

# When advancing level:
# trainer.advance_level() automatically calls:
# optimizer.resize_state(new_model, transition_fn)
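
The state resizing described above can be pictured with plain arrays. The following is a minimal sketch, not the MultilevelAdam implementation: it assumes the coarse moment estimates are kept in the leading block of the fine array and that entries for new parameters start at zero. The helper name resize_moment is hypothetical, and NumPy is used only to keep the illustration short.

```python
import numpy as np

def resize_moment(coarse_moment, fine_shape):
    """Zero-pad an Adam moment array so it matches the fine parameter shape."""
    # Pad only at the end of each axis, so coarse values keep their indices.
    pad_width = [(0, f - c) for c, f in zip(coarse_moment.shape, fine_shape)]
    return np.pad(coarse_moment, pad_width)

# First-moment (momentum) estimate for a coarse kernel of shape (in=2, out=16)
mu_coarse = np.ones((2, 16))

# After the transition the kernel has shape (2, 32); the moment must match it
mu_fine = resize_moment(mu_coarse, (2, 32))
```

Starting new entries at zero mirrors how Adam initializes its moments for fresh parameters; other choices are possible.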

Training Loop

# Iterate until finest level is completed
while True:
    model = trainer.get_current_model()
    level = trainer.current_level

    print(f"Training Level {level}")

    # Train for some epochs
    for epoch in range(100):
        grads = nnx.grad(loss_fn)(model, batch)
        # Update model and optimizer state
        optimizer.update(model, grads)

    # Advance to next level (automatically prolongates model and optimizer state)
    # Returns False if already at finest level
    if not trainer.advance_level():
        break
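
The loop above relies on advance_level() returning False once the finest level has been reached. A minimal sketch of that contract, using a hypothetical stand-in class (LevelCursor is not part of opifex and does no training or prolongation):

```python
class LevelCursor:
    """Hypothetical stand-in that mimics only the level-advance contract."""

    def __init__(self, num_levels):
        self.num_levels = num_levels
        self.current_level = 0  # level 0 is the coarsest

    def advance_level(self):
        # Report False when already at the finest level, otherwise move up.
        if self.current_level >= self.num_levels - 1:
            return False
        self.current_level += 1
        return True

cursor = LevelCursor(num_levels=3)
visited = []
while True:
    visited.append(cursor.current_level)  # training would happen here
    if not cursor.advance_level():
        break
```

Every level, including the finest, is visited exactly once before the loop exits.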

Transfer Operators

Transfer operators move parameters between hierarchy levels.

from opifex.training.multilevel import prolongate, restrict

# Prolongate: coarse -> fine (copy and pad)
fine_model = prolongate(coarse_model, fine_model)

# Restrict: fine -> coarse (truncate)
coarse_model = restrict(fine_model, coarse_model)

Prolongation: Copies coarse parameters to corresponding fine parameters, leaving additional fine parameters at initialization.

Restriction: Extracts a subset of fine parameters for the coarse model.
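
For a single weight array, the two operators can be sketched as follows. This assumes the coarse parameters occupy the leading block of the fine array, which is one plausible reading of "copy and pad" and "truncate"; the function names are hypothetical and operate on NumPy arrays rather than on models.

```python
import numpy as np

def prolongate_array(coarse, fine):
    """Copy coarse into the leading block of fine; other entries keep fine's values."""
    out = fine.copy()
    block = tuple(slice(0, n) for n in coarse.shape)
    out[block] = coarse
    return out

def restrict_array(fine, coarse_shape):
    """Truncate fine to its leading block of shape coarse_shape."""
    block = tuple(slice(0, n) for n in coarse_shape)
    return fine[block].copy()

coarse = np.arange(6, dtype=np.float32).reshape(2, 3)
fine_init = np.full((2, 6), -1.0, dtype=np.float32)  # stands in for fresh initialization

prolonged = prolongate_array(coarse, fine_init)
restricted = restrict_array(prolonged, coarse.shape)
```

Under this assumption, restricting a freshly prolongated array returns the original coarse values.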

Creating Custom Hierarchies

You can use create_network_hierarchy or manually create a list of models.

from opifex.training.multilevel import create_network_hierarchy

hierarchy = create_network_hierarchy(
    input_dim=2,
    output_dim=1,
    base_hidden_dims=[128, 128],
    num_levels=4,
    coarsening_factor=0.5,
    activation=nnx.gelu,  # Custom activation
    rngs=nnx.Rngs(0),
)

# hierarchy[0]: smallest network (coarsest)
# hierarchy[-1]: largest network (finest)
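
To see what widths such a hierarchy might contain, here is a small sketch. It assumes the finest level uses base_hidden_dims unchanged and that each coarser level multiplies the widths by coarsening_factor with rounding; the exact rule used by create_network_hierarchy is not stated here, so treat level_widths as a hypothetical helper.

```python
def level_widths(base_hidden_dims, num_levels, coarsening_factor):
    """Return hidden widths per level, coarsest first (illustrative rule only)."""
    levels = []
    for level in range(num_levels):
        # The last level is the finest and keeps the base widths; every step
        # toward level 0 applies the coarsening factor once more.
        steps_from_finest = num_levels - 1 - level
        scale = coarsening_factor ** steps_from_finest
        levels.append([max(1, round(w * scale)) for w in base_hidden_dims])
    return levels

widths = level_widths([128, 128], num_levels=4, coarsening_factor=0.5)
```

With these arguments the widths double from one level to the next, ending at the base widths.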

Mode-Based Hierarchy (FNOs)

For Fourier Neural Operators, the hierarchy is based on the number of Fourier modes retained.

FNO Training Example

from opifex.training.multilevel import (
    create_fno_hierarchy,
    prolongate_fno_modes,
)

# 1. Create FNO hierarchy
fno_hierarchy = create_fno_hierarchy(
    base_modes=16,
    width=64,
    num_levels=3,
    reduction_factor=2,
    rngs=nnx.Rngs(0),
    # ... other args
)

# 2. Use generic CascadeTrainer with FNO-specific transfer
trainer = CascadeTrainer(
    hierarchy=fno_hierarchy,
    optimizer=MultilevelAdam(1e-3),
    prolongate_fn=prolongate_fno_modes,
)
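
The idea behind mode prolongation can be illustrated on a one-dimensional spectral multiplier. The sketch below assumes that newly added high-frequency weights start at zero, in which case the fine operator initially reproduces the coarse one; whether prolongate_fno_modes initializes new modes this way is an assumption. Both helper functions are hypothetical and use NumPy's FFT.

```python
import numpy as np

def spectral_conv_1d(x, weights):
    """Scale the lowest len(weights) Fourier modes of x by weights; drop the rest."""
    x_hat = np.fft.rfft(x)
    out_hat = np.zeros_like(x_hat)
    k = len(weights)
    out_hat[:k] = x_hat[:k] * weights
    return np.fft.irfft(out_hat, n=len(x))

def pad_modes(coarse_weights, fine_modes):
    """Zero-pad spectral weights from len(coarse_weights) modes up to fine_modes."""
    fine = np.zeros(fine_modes, dtype=coarse_weights.dtype)
    fine[: len(coarse_weights)] = coarse_weights
    return fine

rng = np.random.default_rng(0)
x = rng.standard_normal(64)
w_coarse = rng.standard_normal(4) + 1j * rng.standard_normal(4)  # 4 retained modes
w_fine = pad_modes(w_coarse, 8)                                  # 8 retained modes

y_coarse = spectral_conv_1d(x, w_coarse)
y_fine = spectral_conv_1d(x, w_fine)
```

Because the extra modes carry zero weight, y_fine matches y_coarse until training at the finer level updates them.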

Best Practices

Choosing Number of Levels

Problem Complexity          Recommended Levels
Simple (smooth solutions)   2-3
Moderate                    3-4
Complex (multi-scale)       4-5

# Simple problem: few levels, aggressive coarsening
config = MultilevelConfig(
    num_levels=2,
    coarsening_factor=0.5,
)

# Complex problem: more levels, gradual refinement
config = MultilevelConfig(
    num_levels=5,
    coarsening_factor=0.7,  # Less aggressive
)
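
One way to judge a choice of levels and coarsening factor is to compare parameter counts across levels. A minimal sketch for a 2-input, 1-output MLP, assuming hidden widths that halve at each coarser level (the widths below are illustrative, not read from the library):

```python
def mlp_param_count(input_dim, hidden_dims, output_dim):
    """Count weights and biases of a fully connected MLP."""
    dims = [input_dim, *hidden_dims, output_dim]
    return sum(n_in * n_out + n_out for n_in, n_out in zip(dims[:-1], dims[1:]))

# Assumed widths for three levels with base [64, 64] and factor 0.5
widths_per_level = [[16, 16], [32, 32], [64, 64]]
counts = [mlp_param_count(2, widths, 1) for widths in widths_per_level]
```

Here the coarsest network has fewer than a tenth of the parameters of the finest one, so epochs spent at coarse levels are comparatively cheap.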

Epoch Distribution

Allocate more epochs to finer levels, which have more parameters and must resolve higher-frequency detail:

# Standard: increasing epochs
config = MultilevelConfig(
    level_epochs=[50, 100, 200],
)

# Fast warmup: emphasize fine level
config = MultilevelConfig(
    level_epochs=[20, 50, 100],
    warmup_epochs=100,  # Extra at finest
)
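
If you prefer to derive the per-level epochs from a total budget, a geometric split is one simple option. This helper is hypothetical (it is not part of MultilevelConfig) and assumes each level should get growth times the epochs of the level below it.

```python
def geometric_epochs(total_epochs, num_levels, growth=2.0):
    """Split total_epochs across levels so each level gets growth x the previous."""
    weights = [growth ** level for level in range(num_levels)]
    scale = total_epochs / sum(weights)
    epochs = [int(w * scale) for w in weights]
    # Assign any rounding remainder to the finest level.
    epochs[-1] += total_epochs - sum(epochs)
    return epochs

level_epochs = geometric_epochs(total_epochs=350, num_levels=3)
```

With a budget of 350 epochs and three levels, this reproduces the increasing schedule shown above.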

Combining with Other Techniques

With Adaptive Sampling:

import jax
from opifex.training.adaptive_sampling import RADSampler

sampler = RADSampler()
key = jax.random.key(0)

while True:
    model = trainer.get_current_model()

    for epoch in range(100):
        # Compute residuals
        residuals = compute_residual(model, all_points)

        # Adaptive sampling: split the key so each epoch draws a fresh batch
        key, subkey = jax.random.split(key)
        batch = sampler.sample(all_points, residuals, batch_size, subkey)

        # Training step
        loss, grads = nnx.value_and_grad(loss_fn)(model, batch)
        # ...

    if not trainer.advance_level():
        break

With GradNorm:

from opifex.core.physics.gradnorm import GradNormBalancer

balancer = GradNormBalancer(num_losses=3, rngs=nnx.Rngs(0))

while True:
    model = trainer.get_current_model()

    # Reset balancer for each level
    balancer._initial_losses = None

    for epoch in range(100):
        losses = compute_losses(model)

        if epoch == 0:
            balancer.set_initial_losses(losses)

        weighted_loss = balancer.compute_weighted_loss(losses)
        # ...

    if not trainer.advance_level():
        break

With Second-Order Optimization:

from opifex.optimization.second_order import (
    HybridOptimizer,
    HybridOptimizerConfig,
)

while True:
    model = trainer.get_current_model()
    level = trainer.current_level
    is_finest = level == len(trainer.hierarchy) - 1

    # Use Adam at coarse levels, hybrid at finest
    if is_finest:
        opt = HybridOptimizer(HybridOptimizerConfig())
    else:
        opt = optax.adam(1e-3)

    # ... training ...
    if not trainer.advance_level():
        break

Monitoring Progress

# Track loss at each level
level_losses = []

while True:
    model = trainer.get_current_model()
    level = trainer.current_level

    # Training
    losses = []
    for epoch in range(100):
        loss, grads = nnx.value_and_grad(loss_fn)(model)
        losses.append(float(loss))
        # ... update ...

    level_losses.append({
        'level': level,
        'final_loss': losses[-1],
        'improvement': losses[0] / losses[-1],
    })

    if not trainer.advance_level():
        break

# Analyze progression
for info in level_losses:
    print(f"Level {info['level']}: loss={info['final_loss']:.4e}, "
          f"improvement={info['improvement']:.1f}x")

Complete Training Example

import jax
import jax.numpy as jnp
import optax
from flax import nnx
from opifex.training.multilevel import (
    CascadeTrainer,
    MultilevelAdam,
    create_network_hierarchy,
    prolongate,
)

# Problem setup
def pde_residual(model, x):
    """Compute PDE residual for Poisson equation."""
    def u_scalar(xi):
        return model(xi.reshape(1, -1)).squeeze()

    laplacian = jax.vmap(lambda xi: jnp.trace(jax.hessian(u_scalar)(xi)))(x)
    f = jnp.sin(jnp.pi * x[:, 0]) * jnp.sin(jnp.pi * x[:, 1])
    return laplacian + f

def loss_fn(model, x_interior, x_boundary):
    residual = pde_residual(model, x_interior)
    pde_loss = jnp.mean(residual ** 2)

    boundary_pred = model(x_boundary)
    bc_loss = jnp.mean(boundary_pred ** 2)

    return pde_loss + 10.0 * bc_loss

# Create hierarchy and trainer
hierarchy = create_network_hierarchy(
    input_dim=2,
    output_dim=1,
    base_hidden_dims=[64, 64],
    num_levels=3,
    coarsening_factor=0.5,
    rngs=nnx.Rngs(42),
)

optimizer = MultilevelAdam(learning_rate=1e-3)

trainer = CascadeTrainer(
    hierarchy=hierarchy,
    optimizer=optimizer,
    prolongate_fn=prolongate,
)

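# Training data
x_interior = jax.random.uniform(jax.random.key(0), (1000, 2))

def generate_boundary_points(n_points, key):
    """Illustrative helper (not part of opifex): sample the edges of [0, 1]^2."""
    t = jax.random.uniform(key, (n_points // 4,))
    zeros, ones = jnp.zeros_like(t), jnp.ones_like(t)
    return jnp.concatenate([
        jnp.stack([t, zeros], axis=1),  # bottom edge, y = 0
        jnp.stack([t, ones], axis=1),   # top edge, y = 1
        jnp.stack([zeros, t], axis=1),  # left edge, x = 0
        jnp.stack([ones, t], axis=1),   # right edge, x = 1
    ])

x_boundary = generate_boundary_points(100, jax.random.key(1))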

# Multilevel training loop
level_epochs = [100, 200, 500]
while True:
    model = trainer.get_current_model()
    level = trainer.current_level
    epochs = level_epochs[level]

    print(f"\n--- Level {level} ---")
    for epoch in range(epochs):
        loss, grads = nnx.value_and_grad(
            lambda m: loss_fn(m, x_interior, x_boundary)
        )(model)

        optimizer.update(model, grads)

        if epoch % 50 == 0:
            print(f"Epoch {epoch}: loss = {loss:.4e}")

    if not trainer.advance_level():
        break

# Final model
final_model = trainer.get_current_model()
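
The residual laplacian + f used above corresponds to the Poisson problem -Δu = f on the unit square with u = 0 on the boundary, for which u(x, y) = sin(πx) sin(πy) / (2π²) is the exact solution. The sketch below checks this with a finite-difference Laplacian using only the standard library; it is a sanity check on the problem setup, not part of the training code, and the helper names are illustrative.

```python
import math

def u_exact(x, y):
    """Closed-form solution of -Laplace(u) = sin(pi x) sin(pi y) with zero boundary values."""
    return math.sin(math.pi * x) * math.sin(math.pi * y) / (2 * math.pi ** 2)

def source(x, y):
    """Right-hand side f used in pde_residual."""
    return math.sin(math.pi * x) * math.sin(math.pi * y)

def laplacian_fd(u, x, y, h=1e-3):
    """Second-order central-difference approximation of the Laplacian of u."""
    return (u(x + h, y) + u(x - h, y) + u(x, y + h) + u(x, y - h) - 4 * u(x, y)) / h ** 2

# Same sign convention as pde_residual: laplacian + f should vanish
residual = laplacian_fd(u_exact, 0.3, 0.7) + source(0.3, 0.7)
boundary_value = u_exact(0.0, 0.5)
```

Comparing the trained model against u_exact on a grid gives a direct error measure in addition to the training loss.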

See Also