Spectral Normalization for Neural Operators
| Metadata | Value |
|---|---|
| Level | Intermediate |
| Runtime | ~5 min (CPU) |
| Prerequisites | JAX, Flax NNX, Linear Algebra basics |
| Format | Python + Jupyter |
Overview
Spectral normalization controls the Lipschitz constant of a neural network layer by dividing its weight matrix by its spectral norm (largest singular value). This matters for PDE-solving neural operators, whose stability and convergence guarantees depend on bounded operator norms.
This example demonstrates spectral normalized linear layers, convolutions, and attention mechanisms. It also includes a stability analysis comparing regular and spectral normalized networks, adaptive bounds, details of the power iteration algorithm, and performance benchmarks.
What You'll Learn
- Apply `SpectralLinear` and `SpectralNormalizedConv` for stable neural operator layers
- Use `SpectralMultiHeadAttention` for normalized attention mechanisms
- Configure `AdaptiveSpectralNorm` with fixed and learnable bounds
- Analyze Lipschitz constants to verify stability improvements
- Build complete spectral normalized neural operators with `create_spectral_neural_operator`
Files
- Python Script: `examples/layers/spectral_normalization_example.py`
- Jupyter Notebook: `examples/layers/spectral_normalization_example.ipynb`
Quick Start
Core Concepts
Why Spectral Normalization?
Neural operators learn mappings between function spaces. Without normalization, weight matrices can have arbitrarily large spectral norms, leading to:
- Exploding gradients during training
- Unstable predictions that amplify input perturbations
- Poor generalization due to lack of Lipschitz control
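Spectral normalization divides each weight matrix W by its spectral norm sigma(W), so the normalized matrix has spectral norm 1 (or a configurable bound). The normalized linear map can then stretch an input perturbation by at most that factor, ||W x - W y|| / sigma(W) <= ||x - y||, which keeps the layer's sensitivity to input changes under control.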
```mermaid
graph TB
    subgraph Without["Without Spectral Norm"]
        A1["Input x"] --> B1["W (unbounded)"]
        B1 --> C1["Output Wx<br/>(can explode)"]
    end
    subgraph With["With Spectral Norm"]
        A2["Input x"] --> B2["W/sigma(W)"]
        B2 --> C2["Output<br/>(bounded response)"]
    end
    style Without fill:#fce4ec
    style With fill:#e3f2fd
```
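A quick numerical check of this bound in plain JAX. It uses the exact spectral norm from `jnp.linalg.norm(W, ord=2)` instead of a library layer, and the matrix size and scale are arbitrary:

```python
import jax
import jax.numpy as jnp

key_w, key_x, key_y = jax.random.split(jax.random.PRNGKey(0), 3)

# An arbitrary, unnormalized weight matrix: 5 outputs, 10 inputs.
W = 3.0 * jax.random.normal(key_w, (5, 10))

# Exact spectral norm = largest singular value (matrix 2-norm).
sigma = jnp.linalg.norm(W, ord=2)
W_sn = W / sigma  # spectral norm of W_sn is 1 by construction

x = jax.random.normal(key_x, (10,))
y = jax.random.normal(key_y, (10,))

# How much each map stretches the perturbation x - y.
gain_raw = jnp.linalg.norm(W @ x - W @ y) / jnp.linalg.norm(x - y)
gain_sn = jnp.linalg.norm(W_sn @ x - W_sn @ y) / jnp.linalg.norm(x - y)

print(f"sigma(W)        = {float(sigma):.3f}")
print(f"raw gain        = {float(gain_raw):.3f}  (<= sigma(W))")
print(f"normalized gain = {float(gain_sn):.3f}  (<= 1)")
```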
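Power Iteration Algorithm
Computing the exact spectral norm requires an SVD, which costs O(mn·min(m, n)) for an m×n matrix (O(n^3) when square). Power iteration is much cheaper: each iteration alternately applies W^T and W to a vector, costing two matrix-vector products (O(mn)), and the estimate converges to the largest singular value at a rate governed by the ratio of the two largest singular values. The trade-off between iterations and accuracy: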
| Iterations | Accuracy | Cost |
|---|---|---|
| 1 | Approximate | Minimal overhead |
| 3-5 | Good for training | Standard choice |
| 10+ | Near-exact | Higher overhead |
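A minimal sketch of power iteration in plain JAX, for illustration only: the function `power_iteration` and its signature are hypothetical, not the `PowerIteration` API listed in the API reference. It estimates the largest singular value of an m×n matrix and compares the result with `jnp.linalg.svd`:

```python
import jax
import jax.numpy as jnp


def power_iteration(W, u, num_iters):
    """Estimate sigma_max of W (shape m x n) from a start vector u (length m).

    Each iteration costs two matrix-vector products. Requires num_iters >= 1.
    """
    for _ in range(num_iters):
        v = W.T @ u                      # right singular vector estimate
        v = v / jnp.linalg.norm(v)
        u = W @ v                        # left singular vector estimate
        u = u / jnp.linalg.norm(u)
    return u @ (W @ v), u                # u^T W v equals ||W v|| here


key_w, key_u = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(key_w, (128, 64))
u0 = jax.random.normal(key_u, (128,))

true_sigma = jnp.linalg.svd(W, compute_uv=False)[0]  # exact, O(mn * min(m, n))
for k in (1, 3, 5, 10):
    est, _ = power_iteration(W, u0, k)
    print(f"{k:2d} iterations: {float(est):.4f} (error {float(true_sigma - est):.4f})")
```

Because the estimate equals ||W v|| for a unit vector v, it cannot exceed the true spectral norm; it approaches the SVD value from below as the number of iterations grows.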
Implementation
Step 1: Basic Spectral Normalization Layers
Replace standard layers with spectral normalized variants:
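```python
import jax
from flax import nnx

from opifex.neural.operators.specialized.spectral_normalization import (
    SpectralLinear,
    SpectralNormalizedConv,
)

rngs = nnx.Rngs(0)
x = jax.random.normal(jax.random.PRNGKey(0), (8, 10))             # (batch, features)
x_img = jax.random.normal(jax.random.PRNGKey(1), (4, 32, 32, 3))  # (batch, H, W, channels)

# Spectral normalized linear layer: 10 -> 5 features
spectral_linear = SpectralLinear(10, 5, power_iterations=5, rngs=rngs)
y = spectral_linear(x, training=True)

# Spectral normalized convolution: 3 -> 16 channels, 3x3 kernel
spectral_conv = SpectralNormalizedConv(3, 16, kernel_size=3, power_iterations=3, rngs=rngs)
y_conv = spectral_conv(x_img, training=True)
```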
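The per-call work of a spectral normalized linear layer can be sketched functionally in plain JAX. This follows the common SN-GAN recipe (a persistent left-singular-vector estimate `u`, refined by power iteration on each training call, with `u` and `v` treated as constants for gradients); it is an assumption about the technique, not opifex's `SpectralLinear` source. The names `spectral_normalize` and `spectral_linear_apply`, the parameter dict, and the `(out_features, in_features)` weight layout are hypothetical:

```python
import jax
import jax.numpy as jnp


def spectral_normalize(W, u, num_iters=1):
    """Return (W / sigma_hat, updated u, sigma_hat); W is (m, n), u has length m."""
    for _ in range(num_iters):  # requires num_iters >= 1
        v = W.T @ u
        v = v / jnp.linalg.norm(v)
        u = W @ v
        u = u / jnp.linalg.norm(u)
    # Treat the singular-vector estimates as constants for gradients (SN-GAN style).
    u = jax.lax.stop_gradient(u)
    v = jax.lax.stop_gradient(v)
    sigma = u @ (W @ v)
    return W / sigma, u, sigma


def spectral_linear_apply(params, u, x, num_iters=1):
    """Compute x @ W_sn^T + b, with params["W"] of shape (out_features, in_features)."""
    W_sn, u_new, sigma = spectral_normalize(params["W"], u, num_iters)
    return x @ W_sn.T + params["b"], u_new, sigma


key_w, key_u, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
params = {"W": jax.random.normal(key_w, (5, 10)), "b": jnp.zeros(5)}
u = jax.random.normal(key_u, (5,))      # persistent state carried across calls
x = jax.random.normal(key_x, (8, 10))   # batch of 8 inputs with 10 features

for step in range(3):
    # Each training call refines u by one power-iteration step.
    y, u, sigma = spectral_linear_apply(params, u, x, num_iters=1)
print(y.shape)  # (8, 5)
```

Carrying `u` across calls is what typically makes one to a few iterations per step enough: the estimate keeps improving over the course of training instead of restarting from a random vector each time.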
Terminal Output:
```text
BASIC SPECTRAL NORMALIZATION LAYERS
==================================================
Linear Layer Comparison:
Regular Linear: (8, 10) -> (8, 5)
Spectral Linear: (8, 10) -> (8, 5)
Regular kernel spectral norm: 2.018
Spectral normalized estimate: 1.434
Convolution Layer Comparison:
Regular Conv: (4, 32, 32, 3) -> (4, 32, 32, 16)
Spectral Conv: (4, 32, 32, 3) -> (4, 32, 32, 16)
```
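For convolutions, a common convention (used, for example, by PyTorch's `torch.nn.utils.spectral_norm`) is to flatten the kernel into a 2-D matrix and normalize that matrix. The sketch below assumes a Flax-style `(kh, kw, in_channels, out_channels)` kernel layout; the helper `conv_kernel_spectral_norm` is hypothetical, and whether `SpectralNormalizedConv` uses this reshaping is not stated on this page:

```python
import jax
import jax.numpy as jnp


def conv_kernel_spectral_norm(kernel, num_iters=30, seed=0):
    """Estimate sigma of a conv kernel with layout (kh, kw, in_ch, out_ch).

    The kernel is flattened to a (kh * kw * in_ch, out_ch) matrix, whose
    singular values match the (out_ch, -1) reshaping used by PyTorch.
    """
    kh, kw, cin, cout = kernel.shape
    W = kernel.reshape(kh * kw * cin, cout)
    u = jax.random.normal(jax.random.PRNGKey(seed), (W.shape[0],))
    for _ in range(num_iters):  # requires num_iters >= 1
        v = W.T @ u
        v = v / jnp.linalg.norm(v)
        u = W @ v
        u = u / jnp.linalg.norm(u)
    return u @ (W @ v)


kernel = jax.random.normal(jax.random.PRNGKey(1), (3, 3, 3, 16))  # 3x3 kernel, 3 -> 16 channels
sigma = conv_kernel_spectral_norm(kernel)
kernel_sn = kernel / sigma  # same shape, flattened spectral norm close to 1
```

The norm of the flattened kernel is a proxy: the true operator norm of the convolution also depends on padding, stride, and input size, and generally differs from it.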
Step 2: Spectral Normalized Attention
Multi-head attention with spectral normalization for stable sequence processing:
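```python
import jax
from flax import nnx

from opifex.neural.operators.specialized.spectral_normalization import SpectralMultiHeadAttention

rngs = nnx.Rngs(0)
x = jax.random.normal(jax.random.PRNGKey(0), (2, 32, 64))  # (batch, sequence, features)

spectral_attention = SpectralMultiHeadAttention(
    num_heads=8,
    in_features=64,  # head dimension = 64 / 8 = 8
    power_iterations=3,
    rngs=rngs,
)
output = spectral_attention(x, training=True)  # (2, 32, 64)
```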
Terminal Output:
```text
SPECTRAL NORMALIZED ATTENTION
==================================================
Attention configuration:
Number of heads: 8
Feature dimension: 64
Head dimension: 8
Processing sequence: (2, 32, 64)
Output shape: (2, 32, 64)
```
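One way to build spectral normalized attention, sketched in plain JAX under the assumption that the Q, K, V, and output projections are each normalized. The functions `sn` and `spectral_mha` are hypothetical, and exact norms replace power iteration for brevity. With 64 features and 8 heads, each head has dimension 8, matching the output above:

```python
import jax
import jax.numpy as jnp


def sn(W):
    """Divide W by its exact spectral norm (a real layer would use power iteration)."""
    return W / jnp.linalg.norm(W, ord=2)


def spectral_mha(params, x, num_heads):
    """Self-attention with spectrally normalized Q/K/V/output projections.

    x: (batch, seq, features); each params[name] is (features, features).
    """
    b, t, d = x.shape
    head_dim = d // num_heads

    def split_heads(h):  # (b, t, d) -> (b, heads, t, head_dim)
        return h.reshape(b, t, num_heads, head_dim).transpose(0, 2, 1, 3)

    q = split_heads(x @ sn(params["q"]))
    k = split_heads(x @ sn(params["k"]))
    v = split_heads(x @ sn(params["v"]))

    scores = q @ k.transpose(0, 1, 3, 2) / jnp.sqrt(head_dim)
    weights = jax.nn.softmax(scores, axis=-1)
    out = (weights @ v).transpose(0, 2, 1, 3).reshape(b, t, d)  # merge heads
    return out @ sn(params["o"])


keys = jax.random.split(jax.random.PRNGKey(0), 5)
params = {name: jax.random.normal(k, (64, 64)) for name, k in zip("qkvo", keys[:4])}
x = jax.random.normal(keys[4], (2, 32, 64))
y = spectral_mha(params, x, num_heads=8)
print(y.shape)  # (2, 32, 64)
```

Normalizing the projections bounds only the linear parts of the layer. Standard dot-product attention is not globally Lipschitz even with normalized projections, so this controls rather than strictly bounds the layer's sensitivity.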
Step 3: Adaptive Spectral Normalization
`AdaptiveSpectralNorm` wraps an existing layer and lets you tune its spectral bound per layer, optionally making the bound learnable:
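```python
from flax import nnx

from opifex.neural.operators.specialized.spectral_normalization import AdaptiveSpectralNorm

rngs = nnx.Rngs(0)
base_linear = nnx.Linear(64, 64, rngs=rngs)  # layer to wrap (assumed nnx.Linear, illustrative sizes)

adaptive_layer = AdaptiveSpectralNorm(
    base_linear,
    power_iterations=5,
    initial_bound=1.0,
    learnable_bound=True,  # Bound adjusts during training
    rngs=rngs,
)
```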
Terminal Output:
```text
ADAPTIVE SPECTRAL NORMALIZATION
==================================================
Fixed bound (1.0):
Initial bound: 1.0
Learnable: False
Learnable bound:
Initial bound: 1.0
Learnable: True
```
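A minimal sketch of a fixed versus learnable bound in plain JAX. It assumes the bound multiplies the normalized weight and is kept positive through a softplus parametrization; both that parametrization and the names `bounded_weight`, `inverse_softplus`, and `loss` are hypothetical, not taken from `AdaptiveSpectralNorm`:

```python
import jax
import jax.numpy as jnp


def bounded_weight(W, raw_bound):
    """Rescale W so its spectral norm equals softplus(raw_bound) > 0."""
    bound = jax.nn.softplus(raw_bound)
    return bound * W / jnp.linalg.norm(W, ord=2)


def inverse_softplus(y):
    """raw value whose softplus is y (used to set the initial bound)."""
    return jnp.log(jnp.expm1(y))


def loss(raw_bound, W, x):
    # Toy objective that rewards larger outputs, so the bound gets pushed up.
    return -jnp.mean((x @ bounded_weight(W, raw_bound)) ** 2)


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(key_w, (16, 16))
x = jax.random.normal(key_x, (4, 16))
raw_bound = inverse_softplus(1.0)  # initial bound = 1.0

# Learnable bound: differentiate with respect to raw_bound and take one step.
g = jax.grad(loss)(raw_bound, W, x)
raw_bound_new = raw_bound - 0.1 * g
print(float(jax.nn.softplus(raw_bound_new)))  # bound after one step (> 1.0 here)
```

With a fixed bound, `raw_bound` is simply never updated; with a learnable bound, it is optimized alongside the other parameters.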
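Step 4: Power Iteration Algorithm
The core algorithm for efficient spectral norm estimation, tested against exact SVD computation. Every estimate approaches the SVD value from below. The ~350 ms timings on the first run for each new matrix shape most likely include JIT compilation; the remaining runs take a few milliseconds.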
Terminal Output:
```text
POWER ITERATION ALGORITHM
==================================================
Matrix: Identity (shape: (4, 4))
True spectral norm (SVD): 1.000000
1 iterations: 1.000000 (error: 0.000000, time: 344.39 ms)
3 iterations: 1.000000 (error: 0.000000, time: 3.84 ms)
5 iterations: 1.000000 (error: 0.000000, time: 3.66 ms)
10 iterations: 1.000000 (error: 0.000000, time: 5.34 ms)
Matrix: Diagonal (shape: (4, 4))
True spectral norm (SVD): 3.000000
1 iterations: 2.682737 (error: 0.317263, time: 2.15 ms)
3 iterations: 2.988353 (error: 0.011647, time: 4.91 ms)
5 iterations: 2.999541 (error: 0.000459, time: 2.64 ms)
10 iterations: 3.000000 (error: 0.000000, time: 3.78 ms)
Matrix: Random (shape: (6, 4))
True spectral norm (SVD): 3.189911
1 iterations: 2.980349 (error: 0.209562, time: 363.63 ms)
3 iterations: 3.117129 (error: 0.072781, time: 3.76 ms)
5 iterations: 3.170034 (error: 0.019876, time: 3.06 ms)
10 iterations: 3.189305 (error: 0.000606, time: 4.89 ms)
Matrix: Large Random (shape: (128, 64))
True spectral norm (SVD): 18.911154
1 iterations: 14.052646 (error: 4.858508, time: 356.26 ms)
3 iterations: 17.750385 (error: 1.160769, time: 2.42 ms)
5 iterations: 18.417088 (error: 0.494066, time: 2.52 ms)
10 iterations: 18.828447 (error: 0.082706, time: 3.36 ms)
```
Step 5: Stability Analysis
Comparing Lipschitz constants between regular and spectral normalized networks:
Terminal Output:
```text
STABILITY ANALYSIS & LIPSCHITZ CONTROL
==================================================
Network configurations:
Regular: Linear layers with standard weights
Spectral: SpectralLinear layers with spectral normalization
Lipschitz constant estimation:
Regular network:
Mean Lipschitz: 0.396 +/- 0.158
Max Lipschitz: 0.927
Spectral normalized network:
Mean Lipschitz: 0.066 +/- 0.030
Max Lipschitz: 0.168
```
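The page does not show how Step 5 estimates Lipschitz constants. One common approach, sketched here in plain JAX under that assumption, takes the largest ratio ||f(x) - f(y)|| / ||x - y|| over random nearby input pairs, which is a lower bound on the true constant; the product of per-layer spectral norms gives a matching upper bound when activations are 1-Lipschitz (such as tanh). The MLP sizes and the helpers `init_mlp`, `mlp`, and `empirical_lipschitz` are illustrative:

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    keys = jax.random.split(key, len(sizes) - 1)
    return [jax.random.normal(k, (m, n)) / jnp.sqrt(m)
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]


def mlp(weights, x):
    for W in weights[:-1]:
        x = jnp.tanh(x @ W)  # tanh is 1-Lipschitz
    return x @ weights[-1]


def empirical_lipschitz(weights, key, num_pairs=256, dim=16, delta=1e-2):
    """Max over random nearby pairs of ||f(x) - f(y)|| / ||x - y|| (a lower bound)."""
    kx, kd = jax.random.split(key)
    x = jax.random.normal(kx, (num_pairs, dim))
    y = x + delta * jax.random.normal(kd, (num_pairs, dim))
    num = jnp.linalg.norm(mlp(weights, x) - mlp(weights, y), axis=-1)
    den = jnp.linalg.norm(x - y, axis=-1)
    return float((num / den).max())


sizes = (16, 64, 64, 8)
regular = init_mlp(jax.random.PRNGKey(0), sizes)
spectral = [W / jnp.linalg.norm(W, ord=2) for W in regular]  # each layer has norm 1

# Upper bound for the regular network: product of per-layer spectral norms.
bound_regular = float(jnp.prod(jnp.array([jnp.linalg.norm(W, ord=2) for W in regular])))

key = jax.random.PRNGKey(1)
print("regular :", empirical_lipschitz(regular, key), "<=", bound_regular)
print("spectral:", empirical_lipschitz(spectral, key), "<= 1.0")
```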
Step 6: Complete Spectral Neural Operators
Building full architectures with varying sizes:
```python
from flax import nnx

from opifex.neural.operators.fno.spectral import create_spectral_neural_operator

rngs = nnx.Rngs(0)
model = create_spectral_neural_operator(
    input_dim=64,
    output_dim=64,
    hidden_dims=(128, 128, 64),
    num_heads=8,
    power_iterations=3,
    rngs=rngs,
)
```
Terminal Output:
```text
COMPLETE SPECTRAL NEURAL OPERATORS
==================================================
Creating Small FNO-style:
Input/Output dims: 32 -> 32
Hidden layers: (64, 64)
Attention heads: 4
Creation time: 387.87 ms
Creating Medium PDE solver:
Input/Output dims: 64 -> 64
Hidden layers: (128, 128, 64)
Attention heads: 8
Creation time: 633.09 ms
Creating Large Multi-scale:
Input/Output dims: 128 -> 64
Hidden layers: (256, 192, 128, 96)
Attention heads: 16
Creation time: 959.89 ms
Testing forward passes:
Small FNO-style: (4, 32) -> (4, 32) (698.53 ms)
Medium PDE solver: (4, 64) -> (4, 64) (414.97 ms)
Large Multi-scale: (4, 128) -> (4, 64) (2268.00 ms)
```
Results Summary
| Component | Benefit | Overhead |
|---|---|---|
| SpectralLinear | Bounded Lipschitz constant | ~10-30% |
| SpectralNormalizedConv | Stable spatial processing | ~15-25% |
| SpectralMultiHeadAttention | Stable attention weights | ~10-20% |
| AdaptiveSpectralNorm | Flexible per-layer control | ~10-20% |
| PowerIteration | Cheap norm estimation (two matrix-vector products per iteration) | Configurable |
Key Takeaways
- Spectral normalization controls Lipschitz constants for stable training
- Power iteration estimates the spectral norm cheaply, at the cost of two matrix-vector products per iteration
- Adaptive bounds allow layer-specific flexibility
- JAX transformations (JIT, grad, vmap, Hessian) work seamlessly
- Modest overhead (~10-30%) for significant stability improvements
- Particularly beneficial for PDE-solving neural operators
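As a rough check that JAX transformations compose with spectral normalization, here is a plain-JAX sketch (not the opifex layers) that applies `jax.jit`, `jax.grad`, `jax.vmap`, and `jax.hessian` to a power-iteration-normalized linear map. The function `sn_linear` is hypothetical:

```python
import jax
import jax.numpy as jnp


def sn_linear(W, u, x, num_iters=3):
    """Map x (length m) to x @ (W / sigma_hat), with sigma_hat from power iteration."""
    for _ in range(num_iters):
        v = W.T @ u
        v = v / jnp.linalg.norm(v)
        u = W @ v
        u = u / jnp.linalg.norm(u)
    u, v = jax.lax.stop_gradient(u), jax.lax.stop_gradient(v)
    return x @ (W / (u @ (W @ v)))


key_w, key_u, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
W = jax.random.normal(key_w, (8, 4))
u = jax.random.normal(key_u, (8,))
x = jax.random.normal(key_x, (8,))

fast = jax.jit(sn_linear, static_argnames="num_iters")
y = fast(W, u, x)                                                # jit: shape (4,)
gW = jax.grad(lambda W: jnp.sum(sn_linear(W, u, x)))(W)         # grad w.r.t. weights: (8, 4)
xs = jnp.stack([x, 2 * x])
ys = jax.vmap(sn_linear, in_axes=(None, None, 0))(W, u, xs)      # vmap over inputs: (2, 4)
H = jax.hessian(lambda x: jnp.sum(sn_linear(W, u, x) ** 2))(x)  # Hessian w.r.t. input: (8, 8)
```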
Next Steps
Experiments to Try
- FNO with spectral norm: Apply `SpectralLinear` to FNO spectral layers for stable Darcy flow training
- PINN stability: Compare PINN training convergence with and without spectral normalization
- Adaptive bounds: Use learnable bounds with `AdaptiveSpectralNorm` for multi-scale architectures
Related Examples
| Example | Level | What You'll Learn |
|---|---|---|
| Grid Embeddings | Beginner | Spatial coordinate injection |
| FNO Darcy Full | Intermediate | Apply spectral layers in training |
| Neural Operator Benchmark | Advanced | Cross-architecture comparison |
API Reference
- `SpectralLinear` - Spectral normalized linear layer
- `SpectralNormalizedConv` - Spectral normalized convolution
- `SpectralMultiHeadAttention` - Spectral normalized attention
- `AdaptiveSpectralNorm` - Adaptive spectral bounds
- `PowerIteration` - Spectral norm estimation algorithm
- `create_spectral_neural_operator` - Complete architecture factory
- `spectral_norm_summary` - Model spectral norm analysis
Troubleshooting
Spectral norm not converging
Symptom: Power iteration estimates vary significantly between forward passes.
Cause: Too few power iterations for the matrix size.
Solution: Increase `power_iterations` (3-5 is typical; use 10 or more for large matrices).
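To see why too few iterations give unstable estimates, the sketch below runs power iteration from 32 random starting vectors on a matrix with known singular values (5, 3, 2, 1, 0.5, 0.1) and reports how far apart the estimates are. It is plain JAX; `estimate_sigma` and `spread` are hypothetical helpers, and the matrix is constructed for illustration:

```python
import jax
import jax.numpy as jnp


def estimate_sigma(W, u, num_iters):
    """Power-iteration estimate of the largest singular value of W (num_iters >= 1)."""
    for _ in range(num_iters):
        v = W.T @ u
        v = v / jnp.linalg.norm(v)
        u = W @ v
        u = u / jnp.linalg.norm(u)
    return u @ (W @ v)


# Build W = Q1 diag(s) Q2^T with known singular values s (largest is 5.0).
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q1, _ = jnp.linalg.qr(jax.random.normal(k1, (6, 6)))
q2, _ = jnp.linalg.qr(jax.random.normal(k2, (6, 6)))
s = jnp.array([5.0, 3.0, 2.0, 1.0, 0.5, 0.1])
W = q1 @ jnp.diag(s) @ q2.T

starts = jax.random.normal(k3, (32, 6))  # 32 random starting vectors


def spread(num_iters):
    """Max minus min of the estimates across all starting vectors."""
    est = jax.vmap(lambda u: estimate_sigma(W, u, num_iters))(starts)
    return float(est.max() - est.min())


for k in (1, 3, 10, 30):
    print(f"{k:2d} iterations: spread across starts = {spread(k):.6f}")
```

The spread shrinks as iterations increase because every start converges to the same top singular value; how many iterations that takes depends on the ratio between the two largest singular values.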
Training loss oscillating
Symptom: Loss oscillates when using spectral normalization.
Cause: Spectral bound may be too restrictive, limiting model capacity.
Solution: Use `AdaptiveSpectralNorm` with a larger initial bound:
```python
adaptive = AdaptiveSpectralNorm(
    base_layer,
    initial_bound=2.0,  # Allow larger spectral norm
    learnable_bound=True,  # Let the model find the right bound
    rngs=rngs,
)
```
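`training` parameter confusion
Symptom: Different results with `training=True` vs `training=False`.
Cause: With `training=True`, each call runs power iteration and updates the layer's internal state, so the spectral norm estimate (and therefore the output) can change from call to call. With `training=False`, the layer uses the cached spectral norm instead.
Solution: Pass `training=True` inside the training loop and `training=False` for evaluation and inference.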
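A minimal plain-Python/JAX sketch of this train/inference split, assuming the layer stores a singular-vector estimate `u` and a cached `sigma`. The class `SpectralState` is hypothetical and does not reflect how opifex stores its state; it also chooses to refresh the cache on the first call if no estimate exists yet:

```python
import jax
import jax.numpy as jnp


class SpectralState:
    """Hypothetical spectral normalized linear map with cached power-iteration state."""

    def __init__(self, W, key):
        self.W = W                                        # shape (m, n)
        u = jax.random.normal(key, (W.shape[0],))
        self.u = u / jnp.linalg.norm(u)
        self.sigma = None                                 # cached estimate

    def __call__(self, x, training):
        if training or self.sigma is None:
            # Training: one power-iteration step refines u and the cached sigma.
            v = self.W.T @ self.u
            v = v / jnp.linalg.norm(v)
            u = self.W @ v
            self.u = u / jnp.linalg.norm(u)
            self.sigma = self.u @ (self.W @ v)
        # Inference (training=False): reuse the cached sigma unchanged.
        return x @ (self.W / self.sigma)


layer = SpectralState(jax.random.normal(jax.random.PRNGKey(0), (16, 8)), jax.random.PRNGKey(1))
x = jnp.ones((4, 16))
y_train_1 = layer(x, training=True)
y_train_2 = layer(x, training=True)   # sigma refined again, so the output may change
y_eval_1 = layer(x, training=False)
y_eval_2 = layer(x, training=False)   # identical to y_eval_1
```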