
Physics API Reference

The opifex.physics package provides JAX-native physics solvers and numerical methods for scientific computing applications.

Overview

The physics module offers:

  • PDE Solvers: Numerical solvers for common PDEs (Burgers, diffusion-advection, Navier-Stokes, shallow water)
  • Spectral Methods: Fourier-based PDE solvers and analysis tools
  • Numerical Schemes: Finite difference, finite element, spectral methods
  • Conservation Laws: Tools for enforcing physical constraints
  • Quantum Spectral: Quantum chemistry spectral solvers

PDE Solvers

Burgers 2D Solver

opifex.physics.solvers.burgers.Burgers2DSolver

Burgers2DSolver(resolution: int = 64, domain_size: tuple[float, float] = (2.0 * pi, 2.0 * pi), viscosity: float = 0.01, dt_max: float = 0.001)

2D viscous Burgers equation solver using JAX.

Solves the nonlinear 2D Burgers equation using finite difference schemes with adaptive time stepping for numerical stability.

Parameters:

  • resolution (int, default 64): Grid resolution (number of points per dimension).
  • domain_size (tuple[float, float], default (2.0 * pi, 2.0 * pi)): Physical domain size, 2π × 2π by default.
  • viscosity (float, default 0.01): Kinematic viscosity coefficient.
  • dt_max (float, default 0.001): Maximum time step; adaptive stepping uses smaller values when needed for stability.
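To make "finite differences with adaptive time stepping" concrete, here is a minimal standalone sketch in jax.numpy. It assumes a periodic grid, second-order central differences, forward Euler, and a step limited by both an advective (CFL) bound and the explicit-diffusion bound; the helper names (`ddx`, `laplacian`, `stable_dt`, `burgers_step`) are hypothetical, and the scheme actually used inside `Burgers2DSolver` may differ.

```python
import jax
import jax.numpy as jnp


def ddx(f, dx, axis):
    # Second-order central difference on a periodic grid (jnp.roll wraps around)
    return (jnp.roll(f, -1, axis=axis) - jnp.roll(f, 1, axis=axis)) / (2.0 * dx)


def laplacian(f, dx):
    # Five-point Laplacian with periodic boundaries
    neighbors = (jnp.roll(f, -1, axis=0) + jnp.roll(f, 1, axis=0)
                 + jnp.roll(f, -1, axis=1) + jnp.roll(f, 1, axis=1))
    return (neighbors - 4.0 * f) / dx**2


def stable_dt(u, v, dx, nu, dt_max, cfl=0.5):
    # Most restrictive of dt_max, the advective CFL limit, and the explicit diffusion limit
    speed = jnp.maximum(jnp.max(jnp.abs(u)), jnp.max(jnp.abs(v)))
    dt_adv = cfl * dx / (speed + 1e-12)
    dt_diff = 0.25 * dx**2 / nu
    return jnp.minimum(dt_max, jnp.minimum(dt_adv, dt_diff))


@jax.jit
def burgers_step(u, v, dx, nu, dt):
    # u_t + u u_x + v u_y = nu * lap(u), and likewise for v (axis 0 = x, axis 1 = y)
    du = -(u * ddx(u, dx, 0) + v * ddx(u, dx, 1)) + nu * laplacian(u, dx)
    dv = -(u * ddx(v, dx, 0) + v * ddx(v, dx, 1)) + nu * laplacian(v, dx)
    return u + dt * du, v + dt * dv


# Usage: one adaptive step from a vortex-like field on a 64 x 64 grid over [0, 2π)^2
n = 64
dx = 2.0 * jnp.pi / n
x = jnp.arange(n) * dx
X, Y = jnp.meshgrid(x, x, indexing="ij")
u0, v0 = jnp.sin(X) * jnp.cos(Y), -jnp.cos(X) * jnp.sin(Y)
dt = stable_dt(u0, v0, dx, nu=0.01, dt_max=0.001)
u1, v1 = burgers_step(u0, v0, dx, 0.01, dt)
```

With the default viscosity and `dt_max`, the `dt_max` cap is the binding constraint on a 64-point grid; the CFL bound only takes over once velocities become large.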

solve

solve(initial_condition: tuple[Array, Array], time_final: float, save_every: int | None = None) -> tuple[Array, Array, Array]

Solve the 2D Burgers equation from initial condition to final time.

Parameters:

  • initial_condition (tuple[Array, Array], required): Tuple of (u0, v0) initial velocity fields.
  • time_final (float, required): Final time to integrate to.
  • save_every (int | None, default None): Save the solution every N time steps; if None, only the final state is saved.

Returns:

  • tuple[Array, Array, Array]: Tuple of (time_array, u_trajectory, v_trajectory).

create_vortex_initial_condition

create_vortex_initial_condition(strength: float = 1.0, center: tuple[float, float] | None = None) -> tuple[Array, Array]

Create a vortex initial condition.

create_shear_layer_initial_condition

create_shear_layer_initial_condition(shear_strength: float = 1.0) -> tuple[Array, Array]

Create a shear layer initial condition.

Diffusion-Advection Solver

opifex.physics.solvers.diffusion_advection.solve_diffusion_advection_2d

solve_diffusion_advection_2d(initial_condition: Array, diffusion_coeff: float, advection_vel: tuple[float, float], dt: float = 0.001, n_steps: int = 1000, grid_spacing: float = 1.0) -> Array

Solve 2D diffusion-advection equation using finite differences.

The function is JIT-compiled.

Parameters:

  • initial_condition (Array, required): Initial field values.
  • diffusion_coeff (float, required): Diffusion coefficient.
  • advection_vel (tuple[float, float], required): Advection velocities (vx, vy).
  • dt (float, default 0.001): Time step.
  • n_steps (int, default 1000): Number of time steps.
  • grid_spacing (float, default 1.0): Spatial grid spacing.

Returns:

  • Array: Solution at the final time.
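The update this function performs can be sketched as an explicit time loop over c_t = D ∇²c − vx c_x − vy c_y. This is a minimal standalone version, assuming periodic boundaries, central differences, and forward Euler (the boundary treatment and stencil of the real function are not documented here); `diffusion_advection_sketch` is a hypothetical name and `h` plays the role of `grid_spacing`.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("n_steps",))
def diffusion_advection_sketch(c0, diffusion_coeff, advection_vel, dt=1e-3, n_steps=1000, h=1.0):
    # Explicit Euler for c_t = D * lap(c) - vx * c_x - vy * c_y on a periodic grid.
    # Explicit diffusion needs roughly D * dt / h**2 <= 1/4 to stay stable.
    vx, vy = advection_vel

    def step(_, c):
        c_xp, c_xm = jnp.roll(c, -1, axis=0), jnp.roll(c, 1, axis=0)
        c_yp, c_ym = jnp.roll(c, -1, axis=1), jnp.roll(c, 1, axis=1)
        lap = (c_xp + c_xm + c_yp + c_ym - 4.0 * c) / h**2
        c_x = (c_xp - c_xm) / (2.0 * h)
        c_y = (c_yp - c_ym) / (2.0 * h)
        return c + dt * (diffusion_coeff * lap - vx * c_x - vy * c_y)

    return jax.lax.fori_loop(0, n_steps, step, c0)


# Usage: advect and spread a Gaussian blob on a 64 x 64 grid
n = 64
idx = jnp.arange(n, dtype=jnp.float32)
X, Y = jnp.meshgrid(idx, idx, indexing="ij")
c0 = jnp.exp(-((X - 20.0) ** 2 + (Y - 20.0) ** 2) / 20.0)
c_final = diffusion_advection_sketch(c0, 0.1, (1.0, 0.5), dt=0.01, n_steps=200)
```

On a periodic grid every difference term sums to zero over the domain, so the total of `c` is conserved up to rounding, which makes a convenient sanity check.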

Navier-Stokes Solver

opifex.physics.solvers.navier_stokes.solve_navier_stokes_2d

solve_navier_stokes_2d(u0: Array, v0: Array, nu: float, time_range: tuple[float, float] = (0.0, 1.0), time_steps: int = 5, resolution: int = 64) -> tuple[Array, Array]

Solve 2D incompressible Navier-Stokes equations.

Uses a projection method with finite differences on a periodic domain. The domain is [0, 2π] × [0, 2π] with periodic boundary conditions.

Parameters:

  • u0 (Array, required): Initial x-velocity field, shape (resolution, resolution).
  • v0 (Array, required): Initial y-velocity field, shape (resolution, resolution).
  • nu (float, required): Kinematic viscosity (ν = μ/ρ).
  • time_range (tuple[float, float], default (0.0, 1.0)): (start_time, end_time).
  • time_steps (int, default 5): Number of time steps to save.
  • resolution (int, default 64): Grid resolution (should match the shape of u0 and v0).

Returns:

  • tuple[Array, Array]: Tuple of (u_trajectory, v_trajectory), each of shape (time_steps+1, resolution, resolution), including the initial condition.
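The heart of a projection method is removing the divergent part of an intermediate velocity field, which amounts to solving a pressure Poisson equation and subtracting the pressure gradient. The sketch below shows that step on the periodic [0, 2π]² domain, but it swaps in an FFT-based (spectral) projection for simplicity; the library's finite-difference projection will differ in detail. All function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def angular_wavenumbers(n, length):
    # Integer wavenumbers when length = 2π; the Nyquist entry is zeroed,
    # a standard choice that keeps the projected field real-valued
    k = 2.0 * jnp.pi * jnp.fft.fftfreq(n, d=length / n)
    if n % 2 == 0:
        k = k.at[n // 2].set(0.0)
    return k


def project_divergence_free(u, v, length=2.0 * jnp.pi):
    # Leray projection: u_hat <- u_hat - k (k . u_hat) / |k|^2
    k = angular_wavenumbers(u.shape[0], length)
    kx, ky = jnp.meshgrid(k, k, indexing="ij")
    k2 = kx**2 + ky**2
    k2 = jnp.where(k2 == 0.0, 1.0, k2)  # where k = 0, k . u_hat = 0, so any nonzero divisor is safe
    u_hat, v_hat = jnp.fft.fft2(u), jnp.fft.fft2(v)
    k_dot_u = kx * u_hat + ky * v_hat
    u_hat = u_hat - kx * k_dot_u / k2
    v_hat = v_hat - ky * k_dot_u / k2
    return jnp.real(jnp.fft.ifft2(u_hat)), jnp.real(jnp.fft.ifft2(v_hat))


def spectral_divergence(u, v, length=2.0 * jnp.pi):
    # max |k . u_hat|, proportional to the largest Fourier coefficient of div(u)
    k = angular_wavenumbers(u.shape[0], length)
    kx, ky = jnp.meshgrid(k, k, indexing="ij")
    return jnp.max(jnp.abs(kx * jnp.fft.fft2(u) + ky * jnp.fft.fft2(v)))


# Usage: project a random (divergent) velocity field
key_u, key_v = jax.random.split(jax.random.PRNGKey(0))
u_rand = jax.random.normal(key_u, (32, 32))
v_rand = jax.random.normal(key_v, (32, 32))
u_p, v_p = project_divergence_free(u_rand, v_rand)
```

A field that is already divergence-free, such as the Taylor-Green vortex below, passes through the projection unchanged up to rounding.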

Initial Condition Factories

opifex.physics.solvers.navier_stokes.create_taylor_green_vortex

create_taylor_green_vortex(resolution: int, amplitude: float = 1.0) -> tuple[Array, Array]

Create Taylor-Green vortex initial condition.

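The Taylor-Green vortex is an exact, divergence-free solution of the 2D incompressible Navier-Stokes equations whose velocity decays exponentially in time because of viscosity:

$$u = A\, e^{-2\nu t} \sin(x)\cos(y), \qquad v = -A\, e^{-2\nu t} \cos(x)\sin(y).$$

This function returns the field at t = 0, i.e. u = A sin(x) cos(y) and v = −A cos(x) sin(y).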

Parameters:

  • resolution (int, required): Grid resolution.
  • amplitude (float, default 1.0): Velocity amplitude A.

Returns:

  • tuple[Array, Array]: Tuple of (u0, v0) initial velocity fields.
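Because the Taylor-Green vortex has a closed-form decay, it is a convenient reference for checking a solver. A minimal sketch, assuming a uniform periodic grid on [0, 2π)² with `indexing="ij"` (axis 0 = x); `taylor_green` and `central_divergence` are hypothetical helpers, not part of the documented API.

```python
import jax.numpy as jnp


def taylor_green(resolution, amplitude=1.0, nu=0.0, t=0.0):
    # u = A e^{-2 nu t} sin(x) cos(y), v = -A e^{-2 nu t} cos(x) sin(y) on [0, 2π)^2
    x = jnp.arange(resolution) * (2.0 * jnp.pi / resolution)
    X, Y = jnp.meshgrid(x, x, indexing="ij")  # axis 0 = x, axis 1 = y
    decay = amplitude * jnp.exp(-2.0 * nu * t)
    return decay * jnp.sin(X) * jnp.cos(Y), -decay * jnp.cos(X) * jnp.sin(Y)


def central_divergence(u, v, dx):
    # du/dx + dv/dy with periodic central differences
    du_dx = (jnp.roll(u, -1, axis=0) - jnp.roll(u, 1, axis=0)) / (2.0 * dx)
    dv_dy = (jnp.roll(v, -1, axis=1) - jnp.roll(v, 1, axis=1)) / (2.0 * dx)
    return du_dx + dv_dy


u0, v0 = taylor_green(64)
div = central_divergence(u0, v0, 2.0 * jnp.pi / 64)
```

Comparing a simulated field at time t against `taylor_green(resolution, nu=nu, t=t)` gives a direct accuracy measure for the solver.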

opifex.physics.solvers.navier_stokes.create_lid_driven_cavity_ic

create_lid_driven_cavity_ic(resolution: int, lid_velocity: float = 1.0) -> tuple[Array, Array]

Create lid-driven cavity initial condition.

In the lid-driven cavity problem, the top boundary moves with a prescribed velocity while all other boundaries are no-slip. Because the solver uses periodic boundaries, this function only approximates that setup with a smooth velocity profile.

Parameters:

  • resolution (int, required): Grid resolution.
  • lid_velocity (float, default 1.0): Velocity of the lid (top boundary).

Returns:

  • tuple[Array, Array]: Tuple of (u0, v0) initial velocity fields.

opifex.physics.solvers.navier_stokes.create_double_shear_layer

create_double_shear_layer(resolution: int, shear_thickness: float = 0.05, perturbation: float = 0.05) -> tuple[Array, Array]

Create double shear layer initial condition.

A classic test case for 2D turbulence that develops Kelvin-Helmholtz instabilities.

Parameters:

  • resolution (int, required): Grid resolution.
  • shear_thickness (float, default 0.05): Thickness of the shear layers.
  • perturbation (float, default 0.05): Amplitude of the initial perturbation.

Returns:

  • tuple[Array, Array]: Tuple of (u0, v0) initial velocity fields.
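For orientation, one widely used form of the double shear layer puts two tanh layers at y = 0.25 and y = 0.75 of a unit-period domain and adds a small sinusoidal v-perturbation that seeds the Kelvin-Helmholtz roll-up. The sketch below uses that form in normalized coordinates; how `create_double_shear_layer` maps `shear_thickness` and `perturbation` onto the [0, 2π] domain is an assumption here, and the function name is hypothetical.

```python
import jax.numpy as jnp


def double_shear_layer_sketch(resolution, shear_thickness=0.05, perturbation=0.05):
    # Normalized coordinates s in [0, 1); layers at y = 0.25 and y = 0.75
    s = jnp.arange(resolution) / resolution
    X, Y = jnp.meshgrid(s, s, indexing="ij")  # axis 0 = x, axis 1 = y
    u = jnp.where(
        Y <= 0.5,
        jnp.tanh((Y - 0.25) / shear_thickness),
        jnp.tanh((0.75 - Y) / shear_thickness),
    )
    # Sinusoidal cross-stream perturbation that triggers the instability
    v = perturbation * jnp.sin(2.0 * jnp.pi * X)
    return u, v


u0, v0 = double_shear_layer_sketch(64)
```

The two branches meet smoothly at y = 0 and y = 1, so the profile is periodic in y as the solver's boundary conditions require.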

Shallow Water Equations Solver

opifex.physics.solvers.shallow_water.solve_shallow_water_2d

solve_shallow_water_2d(h_initial: Array, u_initial: Array, v_initial: Array, g: float = 9.81, dt: float = 0.001, n_steps: int = 1000, grid_spacing: float = 1.0) -> tuple[Array, Array, Array]

Solve 2D shallow water equations using finite differences.

The function is JIT-compiled.

Parameters:

  • h_initial (Array, required): Initial height field.
  • u_initial (Array, required): Initial u-velocity field.
  • v_initial (Array, required): Initial v-velocity field.
  • g (float, default 9.81): Gravitational acceleration.
  • dt (float, default 0.001): Time step.
  • n_steps (int, default 1000): Number of time steps.
  • grid_spacing (float, default 1.0): Spatial grid spacing.

Returns:

  • tuple[Array, Array, Array]: Tuple of (height, u_velocity, v_velocity) at the final time.
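A single explicit step of the 2D shallow water system can be sketched as follows, assuming periodic boundaries, central differences, mass in flux form, and momentum in advective form. This is an illustration under those assumptions, not the library's discretization; `shallow_water_step` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def ddx(f, h, axis):
    # Periodic central difference along the given axis
    return (jnp.roll(f, -1, axis=axis) - jnp.roll(f, 1, axis=axis)) / (2.0 * h)


@jax.jit
def shallow_water_step(h_field, u, v, g=9.81, dt=1e-3, dx=1.0):
    # Mass in flux form: h_t + (h u)_x + (h v)_y = 0
    dh = -(ddx(h_field * u, dx, 0) + ddx(h_field * v, dx, 1))
    # Momentum in advective form: u_t + u u_x + v u_y + g h_x = 0, likewise for v
    du = -(u * ddx(u, dx, 0) + v * ddx(u, dx, 1) + g * ddx(h_field, dx, 0))
    dv = -(u * ddx(v, dx, 0) + v * ddx(v, dx, 1) + g * ddx(h_field, dx, 1))
    return h_field + dt * dh, u + dt * du, v + dt * dv


# Usage: a small bump on a still lake spreads out as gravity waves
n = 32
idx = jnp.arange(n, dtype=jnp.float32)
X, Y = jnp.meshgrid(idx, idx, indexing="ij")
h = 1.0 + 0.1 * jnp.exp(-((X - 16.0) ** 2 + (Y - 16.0) ** 2) / 10.0)
u = jnp.zeros_like(h)
v = jnp.zeros_like(h)
mass0 = jnp.sum(h)
for _ in range(100):
    h, u, v = shallow_water_step(h, u, v)
```

Writing the mass equation in flux form is what makes total height conserved to rounding on a periodic grid; a flat surface with zero velocity ("lake at rest") stays exactly at rest.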

Conservation Laws

ConservationLaw is an Enum in opifex.core.physics.conservation that defines all supported conservation law types.

opifex.core.physics.conservation.ConservationLaw

Bases: Enum

Enum defining all supported conservation laws.

This is the single source of truth for conservation law names across the entire Opifex framework.

Neural Tangent Kernel (NTK) Analysis

Tools for spectral analysis and training diagnostics via the Neural Tangent Kernel.

NTK Wrapper

opifex.core.physics.ntk.wrapper

Neural Tangent Kernel (NTK) wrapper for Flax NNX models.

This module provides utilities for computing empirical NTKs of Flax NNX models via JAX-native Jacobian contraction.

Key Features
  • Wrap NNX models for NTK computation
  • Compute empirical NTK matrices
  • Jacobian computation utilities
  • Spectral analysis of NTK
References
  • Jacot et al. (2018): Neural Tangent Kernel
  • Survey Section 3: Neural Tangent Kernel Analysis

NTKWrapper

NTKWrapper(model: Module, config: NTKConfig | None = None)

Wrapper for computing NTK with NNX models.

This class provides a convenient interface for NTK computations, caching the model structure and configuration.

Attributes:

  • model: The NNX model.
  • config: NTK configuration.

Example

    model = MyModel(rngs=nnx.Rngs(0))
    wrapper = NTKWrapper(model)
    ntk = wrapper.compute_ntk(x)

Parameters:

  • model (Module, required): Flax NNX model.
  • config (NTKConfig | None, default None): NTK configuration.

compute_ntk

compute_ntk(x1: Float[Array, 'batch1 dim'], x2: Float[Array, 'batch2 dim'] | None = None) -> Float[Array, 'batch1 batch2']

Compute empirical NTK between input points.

Parameters:

  • x1 (Float[Array, 'batch1 dim'], required): First set of input points.
  • x2 (Float[Array, 'batch2 dim'] | None, default None): Second set of input points; x1 is used if None.

Returns:

  • Float[Array, 'batch1 batch2']: NTK matrix.
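The empirical NTK by Jacobian contraction is K(x1, x2) = J(x1) J(x2)ᵀ, where J is the Jacobian of the network output with respect to all parameters. The sketch below computes it in plain JAX for a tiny scalar-output MLP written as a parameter pytree rather than an NNX module (an assumption made to avoid a Flax dependency); it is not `NTKWrapper`'s implementation, and all names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def init_mlp(key, sizes):
    # List of (W, b) pairs with 1/sqrt(fan_in) scaling
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
        for k, m, n in zip(keys, sizes[:-1], sizes[1:])
    ]


def mlp(params, x):
    # Scalar-output MLP: x has shape (batch, dim), output has shape (batch,)
    for W, b in params[:-1]:
        x = jnp.tanh(x @ W + b)
    W, b = params[-1]
    return (x @ W + b)[:, 0]


def empirical_ntk(params, x1, x2=None):
    # K[i, j] = <df(x1_i)/dθ, df(x2_j)/dθ>, computed by contracting full Jacobians
    x2 = x1 if x2 is None else x2
    flat_params, unravel = ravel_pytree(params)
    f = lambda p, x: mlp(unravel(p), x)
    j1 = jax.jacobian(f)(flat_params, x1)  # (batch1, n_params)
    j2 = jax.jacobian(f)(flat_params, x2)  # (batch2, n_params)
    return j1 @ j2.T


params = init_mlp(jax.random.PRNGKey(0), [2, 16, 16, 1])
x = jax.random.normal(jax.random.PRNGKey(1), (8, 2))
K = empirical_ntk(params, x)
```

Materializing full Jacobians costs memory proportional to batch size times parameter count, which is why alternatives such as NTK-vector products (`implementation=2` in `NTKConfig`) exist for larger models.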

compute_eigenvalues

compute_eigenvalues(x: Float[Array, ...]) -> Float[Array, ' batch']

Compute eigenvalues of NTK at given points.

Parameters:

  • x (Float[Array, ...], required): Input points.

Returns:

  • Float[Array, ' batch']: Eigenvalues sorted in descending order.

compute_condition_number

compute_condition_number(x: Float[Array, ...]) -> Float[Array, '']

Compute condition number of NTK.

The condition number is the ratio of the largest to the smallest eigenvalue, κ = λ_max / λ_min. Large condition numbers indicate an ill-conditioned NTK: with a learning rate small enough to keep the largest-eigenvalue mode stable, the smallest-eigenvalue modes converge slowly.

Parameters:

  • x (Float[Array, ...], required): Input points.

Returns:

  • Float[Array, '']: Condition number.
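Given an NTK matrix, the descending eigenvalues and κ = λ_max / λ_min reduce to a symmetric eigensolve. A minimal sketch on a precomputed matrix (the helper names are hypothetical; whether the library regularizes tiny eigenvalues is not documented here):

```python
import jax.numpy as jnp


def ntk_eigenvalues(K):
    # eigvalsh assumes a symmetric matrix and returns ascending eigenvalues; flip to descending
    return jnp.linalg.eigvalsh(K)[::-1]


def condition_number(K):
    # kappa = lambda_max / lambda_min (assumes K is positive definite)
    eigs = ntk_eigenvalues(K)
    return eigs[0] / eigs[-1]


K = jnp.diag(jnp.array([4.0, 1.0, 2.0]))
eigs = ntk_eigenvalues(K)  # [4.0, 2.0, 1.0]
kappa = condition_number(K)  # 4.0
```

In float32, eigenvalues below roughly 1e-7 × λ_max are dominated by rounding, so very large κ values should be read as "at least this ill-conditioned" rather than as exact.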

NTKConfig dataclass

NTKConfig(implementation: int = 1, trace_axes: tuple = (), diagonal_axes: tuple = (), vmap_axes: tuple | None = None)

Configuration for NTK computation.

Attributes:

  • implementation (int): NTK implementation method (1 = Jacobian contraction, 2 = NTK-vector products, 3 = structured derivatives).
  • trace_axes (tuple): Axes to trace over in the NTK computation.
  • diagonal_axes (tuple): Axes along which only the diagonal of the NTK is computed.
  • vmap_axes (tuple | None): Axes to vmap over.

Spectral Analysis

opifex.core.physics.ntk.spectral_analysis

NTK Spectral Analysis for training diagnostics.

This module provides tools for analyzing the spectral properties of the Neural Tangent Kernel, which are fundamental for understanding training dynamics and convergence properties.

Key Features
  • Eigenvalue decomposition of NTK
  • Condition number computation
  • Effective rank estimation
  • Mode-wise convergence analysis
  • Spectral bias detection
References
  • Survey Section 3.2: Mode-wise Error Decay
  • Jacot et al. (2018): Neural Tangent Kernel

NTKSpectralAnalyzer

NTKSpectralAnalyzer(model: Module)

Analyzer for NTK spectral properties.

This class provides a convenient interface for analyzing NTK eigenvalue distributions and tracking them during training.

Attributes:

  • model: The NNX model to analyze.
  • ntk_wrapper: NTK computation wrapper.
  • history (list[NTKDiagnostics]): History of diagnostics during training.

Example

    model = MyModel(rngs=nnx.Rngs(0))
    analyzer = NTKSpectralAnalyzer(model)
    diagnostics = analyzer.analyze(x_train)
    print(f"Condition number: {diagnostics.condition_number}")

Parameters:

  • model (Module, required): Flax NNX model to analyze.

analyze

analyze(x: Float[Array, ...], learning_rate: float = 0.01, track: bool = False) -> NTKDiagnostics

Analyze NTK spectral properties at given points.

Parameters:

  • x (Float[Array, ...], required): Input points at which to analyze the NTK.
  • learning_rate (float, default 0.01): Learning rate used for the convergence-rate computation.
  • track (bool, default False): Whether to store the result in history.

Returns:

  • NTKDiagnostics: Spectral analysis results.

get_condition_number_history

get_condition_number_history() -> Float[Array, ...]

Get history of condition numbers during training.

Returns:

  • Float[Array, ...]: Array of condition numbers from tracked analyses.

get_effective_rank_history

get_effective_rank_history() -> Float[Array, ...]

Get history of effective ranks during training.

Returns:

  • Float[Array, ...]: Array of effective ranks from tracked analyses.

clear_history

clear_history()

Clear the tracking history.

compute_effective_rank

compute_effective_rank(eigenvalues: Float[Array, ...]) -> Float[Array, '']

Compute effective rank from eigenvalue distribution.

The effective rank uses the entropy-based definition

$$\mathrm{erank} = \exp\Big(-\sum_i p_i \log p_i\Big), \qquad p_i = \frac{\lambda_i}{\sum_j \lambda_j},$$

where p is the normalized eigenvalue distribution. This gives a smooth measure of how many "significant" eigenvalues exist.

Parameters:

  • eigenvalues (Float[Array, ...], required): Eigenvalues (should be non-negative).

Returns:

  • Float[Array, '']: Effective rank, a scalar between 1 and len(eigenvalues).
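The entropy-based effective rank takes only a few lines; the one subtlety is treating 0 · log 0 as 0 for zero eigenvalues. A minimal sketch (the name `effective_rank` is hypothetical and negative eigenvalues from rounding are clamped to zero, which is an assumption about how the library handles them):

```python
import jax.numpy as jnp


def effective_rank(eigenvalues):
    # p_i = lambda_i / sum_j lambda_j; erank = exp(-sum_i p_i log p_i), with 0 log 0 := 0
    lam = jnp.maximum(eigenvalues, 0.0)
    p = lam / jnp.sum(lam)
    safe_p = jnp.where(p > 0, p, 1.0)  # log(1) = 0, so zero entries contribute nothing
    entropy = -jnp.sum(p * jnp.log(safe_p))
    return jnp.exp(entropy)


flat_spectrum = effective_rank(jnp.ones(4))  # 4 equal eigenvalues -> 4.0
single_mode = effective_rank(jnp.array([3.0, 0.0, 0.0]))  # one nonzero eigenvalue -> 1.0
decaying = effective_rank(jnp.array([10.0, 1.0, 0.1]))
```

A flat spectrum reaches the maximum len(eigenvalues), and a spectrum dominated by one eigenvalue approaches 1, which matches the stated range.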

compute_mode_convergence_rates

compute_mode_convergence_rates(eigenvalues: Float[Array, ...], learning_rate: float) -> Float[Array, ...]

Compute convergence rates for each eigenmode.

For gradient descent with learning rate α on the linearized model, the error component along eigenmode i decays as (1 − αλ_i)^k after k steps, where λ_i is the i-th NTK eigenvalue.

The per-mode convergence rate is therefore 1 − αλ_i; a smaller |1 − αλ_i| means faster convergence.

From Survey Section 3.2, the error after k steps is

$$e_k = \sum_i c_i (1 - \alpha\lambda_i)^k q_i,$$

where q_i is the i-th eigenvector and c_i is the component of the initial error along q_i.

Parameters:

  • eigenvalues (Float[Array, ...], required): Eigenvalues.
  • learning_rate (float, required): Learning rate α.

Returns:

  • Float[Array, ...]: Per-mode convergence rates (values in [0, 1] when αλ_i ≤ 1; a mode diverges if |1 − αλ_i| > 1).
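The per-mode rates and the error expansion e_k = Σᵢ cᵢ(1 − αλᵢ)^k qᵢ can be evaluated directly once the eigendecomposition is known. A minimal sketch with hypothetical helper names, taking the eigenvectors qᵢ as the columns of a matrix:

```python
import jax.numpy as jnp


def mode_convergence_rates(eigenvalues, learning_rate):
    # Per-step contraction factor of each eigenmode under linearized gradient descent
    return 1.0 - learning_rate * eigenvalues


def predicted_error(coeffs, eigenvectors, eigenvalues, learning_rate, k):
    # e_k = sum_i c_i (1 - alpha * lambda_i)^k q_i, with q_i the columns of eigenvectors
    rates = mode_convergence_rates(eigenvalues, learning_rate)
    return eigenvectors @ (coeffs * rates**k)


eigs = jnp.array([10.0, 1.0, 0.1])
rates = mode_convergence_rates(eigs, 0.05)  # [0.5, 0.95, 0.995]
e2 = predicted_error(jnp.ones(3), jnp.eye(3), eigs, 0.05, 2)  # [0.25, 0.9025, 0.990025]
```

The example shows spectral bias in miniature: the mode with the largest eigenvalue has already lost 75% of its error after two steps, while the smallest has barely moved. Raising α to speed up the slow modes is limited by stability of the fast one (α < 2/λ_max).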

Training Diagnostics

opifex.core.physics.ntk.diagnostics

NTK-based Training Diagnostics and Callbacks.

This module provides tools for diagnosing training dynamics using the Neural Tangent Kernel, including mode-wise error decay prediction and training callbacks for monitoring NTK evolution.

Key Features
  • Mode-wise error decay computation
  • Convergence prediction from NTK eigenvalues
  • Spectral bias detection and monitoring
  • Training callbacks for NTK diagnostics
References
  • Survey Section 3.2: Mode-wise Error Decay
  • e_k = Σᵢ cᵢ(1 - αλᵢ)^k qᵢ

NTKDiagnosticsCallback

NTKDiagnosticsCallback(compute_frequency: int = 100)

Callback for NTK diagnostics during training.

Computes and tracks NTK properties at specified intervals.

Attributes:

  • frequency: How often to compute the NTK (every N steps).
  • history: List of diagnostic dictionaries.

Parameters:

  • compute_frequency (int, default 100): Compute the NTK every N steps.

on_step_end

on_step_end(model: Module, x: Float[Array, ...], step: int) -> None

Called at end of each training step.

Parameters:

  • model (Module, required): Current model state.
  • x (Float[Array, ...], required): Sample inputs for the NTK computation.
  • step (int, required): Current training step.

get_history

get_history() -> list[dict]

Get history of diagnostics.

Returns:

  • list[dict]: List of diagnostic dictionaries.

get_condition_numbers

get_condition_numbers() -> Float[Array, ...]

Get array of condition numbers from history.

Returns:

  • Float[Array, ...]: Condition numbers at each tracked step.
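The callback pattern above (compute only every N steps, append a record, expose the series as an array) can be sketched independently of the library. `ConditionNumberTracker` is a hypothetical stand-in: it takes an arbitrary `ntk_fn(model, x)` callable, and it assumes "every N steps" means steps where `step % N == 0`, which may not match `NTKDiagnosticsCallback` exactly.

```python
import jax.numpy as jnp


class ConditionNumberTracker:
    # Hypothetical stand-in for a frequency-gated NTK diagnostics callback
    def __init__(self, ntk_fn, compute_frequency=100):
        self.ntk_fn = ntk_fn  # callable (model, x) -> symmetric NTK matrix
        self.frequency = compute_frequency
        self.history = []

    def on_step_end(self, model, x, step):
        # Skip the expensive NTK computation except every `frequency` steps
        if step % self.frequency != 0:
            return
        eigs = jnp.linalg.eigvalsh(self.ntk_fn(model, x))  # ascending order
        self.history.append({"step": step, "condition_number": float(eigs[-1] / eigs[0])})

    def get_condition_numbers(self):
        return jnp.array([entry["condition_number"] for entry in self.history])


# Usage with a dummy NTK function that ignores the model
tracker = ConditionNumberTracker(lambda model, x: jnp.diag(x), compute_frequency=100)
for step in range(250):
    tracker.on_step_end(None, jnp.array([1.0, 2.0, 8.0]), step)
```

Gating on the step count matters because each NTK evaluation costs far more than a training step.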

For detailed usage and theoretical background, see the NTK Analysis Guide.

GradNorm Loss Balancing

Multi-task loss balancing through gradient magnitude normalization.

opifex.core.physics.gradnorm

GradNorm: Gradient Normalization for Multi-Task Learning.

This module implements GradNorm, which automatically balances the contribution of different loss terms based on gradient magnitudes.

Key Features
  • Automatic loss weight adaptation
  • Balances training rates across tasks
  • Prevents gradient domination by any single loss

The key insight is that losses with larger gradients tend to dominate training. GradNorm adjusts weights to equalize the gradient contributions.

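From Survey Section 2.2.2:

$$L_{\mathrm{grad}} = \sum_i \left|\, \lVert \gamma_i \nabla_\theta \hat{R}_i \rVert - \bar{G}\, r_i^{\zeta} \right|,$$

where γ_i is the weight on the i-th loss term R̂_i, Ḡ is the average weighted gradient norm, r_i is the relative inverse training rate, and ζ is the asymmetry parameter.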

References
  • Chen et al. (2018): GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks
  • Survey Section 2.2.2: Loss Weighting Strategies

GradNormBalancer

GradNormBalancer(num_losses: int, config: GradNormConfig | None = None, *, rngs: Rngs)

Bases: Module

GradNorm balancer for multi-task learning.

This module maintains learnable weights for each loss component and updates them to balance gradient contributions.

Attributes:

  • config: GradNorm configuration.
  • log_weights: Learnable log-weights (exponentiate to get the actual weights).

Example

    balancer = GradNormBalancer(num_losses=3, rngs=nnx.Rngs(0))
    losses = jnp.array([data_loss, pde_loss, boundary_loss])
    weighted_loss = balancer.compute_weighted_loss(losses)

Parameters:

  • num_losses (int, required): Number of loss components to balance.
  • config (GradNormConfig | None, default None): GradNorm configuration.
  • rngs (Rngs, required): Random number generators.

weights property

weights: Float[Array, ...]

Get current weights (exponentiated and clipped).

compute_weighted_loss

compute_weighted_loss(losses: Float[Array, ...]) -> Float[Array, '']

Compute weighted sum of losses.

Parameters:

  • losses (Float[Array, ...], required): Individual loss values.

Returns:

  • Float[Array, '']: Weighted sum of losses.

compute_gradnorm_loss

compute_gradnorm_loss(grad_norms: Float[Array, ...], losses: Float[Array, ...], initial_losses: Float[Array, ...]) -> Float[Array, '']

Compute GradNorm balancing loss.

The GradNorm loss encourages

$$\lVert w_i \nabla L_i \rVert \approx \bar{G}\, r_i^{\alpha},$$

where Ḡ is the average weighted gradient norm across losses, r_i is the relative inverse training rate, and α is the asymmetry parameter from GradNormConfig.

Parameters:

  • grad_norms (Float[Array, ...], required): Gradient norms for each loss.
  • losses (Float[Array, ...], required): Current loss values.
  • initial_losses (Float[Array, ...], required): Initial loss values.

Returns:

  • Float[Array, '']: GradNorm loss used for the weight updates.
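The GradNorm loss and one weight update can be sketched in a few lines of JAX. This follows Chen et al. (2018) in holding the target Ḡ · r̃ᵢ^α constant and renormalizing the weights to sum to the number of tasks; it assumes `grad_norms` are the unweighted norms ‖∇Lᵢ‖ (the documented API does not say), and all names are hypothetical rather than `GradNormBalancer` internals.

```python
import jax
import jax.numpy as jnp


def relative_inverse_rates(losses, initial_losses):
    # r_i = L_i(t) / L_i(0), normalized so that mean(r) = 1
    r = losses / initial_losses
    return r / jnp.mean(r)


def gradnorm_loss(log_weights, grad_norms, losses, initial_losses, alpha=1.5):
    G = jnp.exp(log_weights) * grad_norms  # weighted gradient norms ||w_i grad L_i||
    target = G.mean() * relative_inverse_rates(losses, initial_losses) ** alpha
    target = jax.lax.stop_gradient(target)  # the target is held constant
    return jnp.sum(jnp.abs(G - target))


def update_log_weights(log_weights, grad_norms, losses, initial_losses,
                       alpha=1.5, learning_rate=0.01, min_weight=0.01, max_weight=100.0):
    # One gradient step on the GradNorm loss with respect to the log-weights
    g = jax.grad(gradnorm_loss)(log_weights, grad_norms, losses, initial_losses, alpha)
    w = jnp.exp(log_weights - learning_rate * g)
    w = w * w.shape[0] / jnp.sum(w)  # renormalize so the weights sum to the number of tasks
    return jnp.log(jnp.clip(w, min_weight, max_weight))


equal = jnp.ones(2)
# Task 0 has not reduced its loss; task 1 has halved it
new_log_w = update_log_weights(jnp.zeros(2), equal, jnp.array([1.0, 0.5]), equal)
w_new = jnp.exp(new_log_w)
```

With equal gradient norms, the update raises the weight of task 0, the task that is training more slowly, and lowers the weight of task 1.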

update_weights

update_weights(grad_norms: Float[Array, ...], losses: Float[Array, ...], initial_losses: Float[Array, ...]) -> None

Update weights based on gradient norms.

This updates the log_weights to minimize the GradNorm loss.

Parameters:

  • grad_norms (Float[Array, ...], required): Gradient norms for each loss.
  • losses (Float[Array, ...], required): Current loss values.
  • initial_losses (Float[Array, ...], required): Initial loss values.

set_initial_losses

set_initial_losses(losses: Float[Array, ...]) -> None

Set initial losses for training rate computation.

Parameters:

  • losses (Float[Array, ...], required): Initial loss values.

get_initial_losses

get_initial_losses() -> Float[Array, ...] | None

Get stored initial losses.

Returns:

  • Float[Array, ...] | None: Initial losses if set, None otherwise.

GradNormConfig dataclass

GradNormConfig(alpha: float = 1.5, learning_rate: float = 0.01, update_frequency: int = 1, min_weight: float = 0.01, max_weight: float = 100.0)

Configuration for GradNorm balancing.

Attributes:

  • alpha (float): Asymmetry parameter (ζ in the Survey Section 2.2.2 formula). Controls how strongly tasks with different training rates are rebalanced: alpha = 0 drives all tasks toward equal gradient norms, while alpha > 0 assigns larger target gradient norms to tasks that are training more slowly.
  • learning_rate (float): Learning rate for the weight updates.
  • update_frequency (int): How often to update the weights (in training steps).
  • min_weight (float): Minimum allowed weight.
  • max_weight (float): Maximum allowed weight.

compute_gradient_norms

compute_gradient_norms(model: Module, loss_fns: Sequence[Callable[[Module], Float[Array, '']]]) -> Float[Array, ...]

Compute gradient norms for each loss function.

Parameters:

  • model (Module, required): The neural network model.
  • loss_fns (Sequence[Callable[[Module], Float[Array, '']]], required): Loss functions, each taking the model and returning a scalar.

Returns:

  • Float[Array, ...]: Array of gradient norms, one per loss.
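Computing one gradient norm per loss means one backward pass per loss function, followed by a global L2 norm over all parameter leaves. The sketch below uses a plain parameter pytree instead of an NNX `Module` (an assumption to keep it dependency-free); `gradient_norms` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def gradient_norms(params, loss_fns):
    # Global L2 norm of d(loss)/d(params) for each loss, over all leaves of the pytree
    norms = []
    for loss_fn in loss_fns:
        grads = jax.grad(loss_fn)(params)
        leaves = jax.tree_util.tree_leaves(grads)
        norms.append(jnp.sqrt(sum(jnp.sum(g**2) for g in leaves)))
    return jnp.stack(norms)


params = {"w": jnp.array([3.0, 4.0])}
loss_fns = [
    lambda p: jnp.sum(p["w"]),  # gradient [1, 1], norm sqrt(2)
    lambda p: 0.5 * jnp.sum(p["w"] ** 2),  # gradient [3, 4], norm 5
]
norms = gradient_norms(params, loss_fns)
```

These per-loss norms are the `grad_norms` input that the balancing step compares against its targets.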

compute_inverse_training_rates

compute_inverse_training_rates(current_losses: Float[Array, ...], initial_losses: Float[Array, ...]) -> Float[Array, ...]

Compute relative inverse training rates.

The inverse training rate for task i is

$$r_i = \frac{L_i(t)}{L_i(0)},$$

normalized so that its mean over tasks is 1:

$$\tilde{r}_i = \frac{r_i}{\operatorname{mean}_j(r_j)}.$$

Parameters:

  • current_losses (Float[Array, ...], required): Current loss values for each task.
  • initial_losses (Float[Array, ...], required): Initial loss values for each task.

Returns:

  • Float[Array, ...]: Relative inverse training rates (mean normalized to 1).
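A worked example of the two formulas, as a minimal sketch (`inverse_training_rates` is a hypothetical name mirroring the documented function):

```python
import jax.numpy as jnp


def inverse_training_rates(current_losses, initial_losses):
    # r_i = L_i(t) / L_i(0), then divide by the mean so the rates average to 1
    r = current_losses / initial_losses
    return r / jnp.mean(r)


rates = inverse_training_rates(jnp.array([0.5, 0.25]), jnp.array([1.0, 1.0]))
# r = [0.5, 0.25] and mean(r) = 0.375, so rates ≈ [1.333, 0.667]
```

Task 0 has reduced its loss less than task 1, so its rate is above 1, marking it as the slower-training task that GradNorm should upweight when alpha > 0.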

For algorithm details and best practices, see the GradNorm Guide.

See Also