Visualization API Reference
The opifex.visualization package provides visualization tools for scientific computing, including field plotting, animation, and performance analysis.
Overview
The visualization module offers:
- Field Plotting: 2D/3D field visualizations with multiple plotting modes
- Animation: Create physics-based animations of time-dependent solutions
- Performance Visualization: Plot FLOP counts, memory usage, and model complexity
- Spectral Analysis: Visualize frequency-domain representations
- Vector Fields: Streamline and quiver plots for vector data
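All visualization functions accept JAX arrays and return Matplotlib objects (a Figure, or a FuncAnimation for animations), so results can be displayed interactively or saved as publication-quality files.
Field Plotting
plot_2d_field
Plot a 2D scalar field using one of four modes: filled contours, line contours, a pseudocolor mesh, or an image.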
```python
from opifex.visualization import plot_2d_field

def plot_2d_field(
    field: Array,
    coordinates: Optional[Array] = None,
    title: str = "2D Field",
    cmap: str = "viridis",
    show_colorbar: bool = True,
    levels: Optional[int] = None,
    mode: str = "contourf",
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Figure:
    """
    Plot 2D scalar field with multiple visualization modes.

    Args:
        field: 2D array of field values, shape (nx, ny)
        coordinates: Optional coordinate grid, shape (nx, ny, 2)
            If None, uses uniform grid [0, nx] × [0, ny]
        title: Plot title
        cmap: Matplotlib colormap name
        show_colorbar: Whether to display colorbar
        levels: Number of contour levels (for contour/contourf modes)
        mode: Visualization mode:
            - 'contourf': Filled contours (default)
            - 'contour': Line contours
            - 'pcolormesh': Pseudocolor plot
            - 'imshow': Image plot
        ax: Matplotlib axes (creates new if None)
        **kwargs: Additional arguments passed to plotting function

    Returns:
        matplotlib Figure object

    Example:
        >>> import jax.numpy as jnp
        >>> x = jnp.linspace(-1, 1, 100)
        >>> y = jnp.linspace(-1, 1, 100)
        >>> X, Y = jnp.meshgrid(x, y)
        >>> field = jnp.sin(jnp.pi * X) * jnp.cos(jnp.pi * Y)
        >>> fig = plot_2d_field(field, title="Standing Wave")
    """
```
plot_field_evolution
Visualize the temporal evolution of a field as a sequence of subplots.
```python
from opifex.visualization import plot_field_evolution

def plot_field_evolution(
    trajectory: Array,
    times: Optional[Array] = None,
    num_snapshots: int = 6,
    title: str = "Field Evolution",
    cmap: str = "RdBu_r",
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    figsize: Tuple[int, int] = (15, 10)
) -> plt.Figure:
    """
    Plot time evolution of field as subplot grid.

    Args:
        trajectory: Time-dependent field, shape (nt, nx, ny) or (nt, nx)
        times: Time value of each time step, shape (nt,)
            If None, time-step indices are used
        num_snapshots: Number of snapshots to display
        title: Overall figure title
        cmap: Colormap name
        vmin, vmax: Color scale limits (auto if None)
        figsize: Figure size in inches

    Returns:
        matplotlib Figure object

    Example:
        >>> # Visualize PDE solution evolution
        >>> trajectory = burgers_solution  # Shape: (100, 256, 256)
        >>> times = jnp.linspace(0, 1, 100)
        >>> fig = plot_field_evolution(
        ...     trajectory,
        ...     times=times,
        ...     num_snapshots=6,
        ...     title="Burgers Equation Evolution"
        ... )
    """
```
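How `plot_field_evolution` picks its snapshots is not specified here. A reasonable assumption is evenly spaced frames that include the first and last time step, drawn with one color scale shared by all panels so they are directly comparable. The hypothetical helper `select_snapshots` sketches that selection in NumPy:

```python
import numpy as np


def select_snapshots(trajectory, times=None, num_snapshots=6):
    """Hypothetical helper: evenly spaced frames plus a shared color scale."""
    trajectory = np.asarray(trajectory)
    nt = trajectory.shape[0]
    num_snapshots = min(num_snapshots, nt)  # cannot show more frames than exist
    # Evenly spaced frame indices, always including the first and last frame
    idx = np.round(np.linspace(0, nt - 1, num_snapshots)).astype(int)
    t = np.arange(nt) if times is None else np.asarray(times)
    # One color scale across all frames keeps the panels comparable
    vmin, vmax = float(trajectory.min()), float(trajectory.max())
    return trajectory[idx], t[idx], (vmin, vmax)
```

For 100 time steps and `num_snapshots=6` this selects frames 0, 20, 40, 59, 79, and 99. Passing the returned limits as `vmin`/`vmax` pins the color scale explicitly if the automatic limits differ.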
plot_field_comparison
Compare multiple fields side by side (e.g., ground truth vs. prediction).
```python
from opifex.visualization import plot_field_comparison

def plot_field_comparison(
    fields: List[Array],
    titles: List[str],
    suptitle: str = "Field Comparison",
    cmap: str = "viridis",
    show_difference: bool = True,
    figsize: Optional[Tuple[int, int]] = None
) -> plt.Figure:
    """
    Compare multiple 2D fields side by side.

    Args:
        fields: List of 2D arrays to compare
        titles: Title for each field
        suptitle: Overall figure title
        cmap: Colormap name
        show_difference: If True and exactly two fields are given,
            add a plot of their difference
        figsize: Figure size (auto-computed if None)

    Returns:
        matplotlib Figure object

    Example:
        >>> # Compare model prediction with ground truth
        >>> fields = [ground_truth, prediction]
        >>> titles = ["Ground Truth", "Neural Operator Prediction"]
        >>> fig = plot_field_comparison(
        ...     fields, titles,
        ...     suptitle="FNO Performance",
        ...     show_difference=True
        ... )
    """
```
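With `show_difference=True`, natural companions to the plot are the pointwise difference and a scalar error. The sketch below assumes the difference is prediction minus ground truth and reports the relative L2 error, a common metric for neural operators; both choices are assumptions, and `comparison_metrics` is a hypothetical name.

```python
import numpy as np


def comparison_metrics(ground_truth, prediction):
    """Difference field, symmetric color limits, and relative L2 error."""
    gt = np.asarray(ground_truth, dtype=float)
    pred = np.asarray(prediction, dtype=float)
    if gt.shape != pred.shape:
        raise ValueError(f"shape mismatch: {gt.shape} vs {pred.shape}")
    diff = pred - gt
    # Symmetric limits so zero error maps to the colormap midpoint
    lim = float(np.abs(diff).max())
    # ||pred - gt|| / ||gt||, using the Frobenius norm of the 2D arrays
    rel_l2 = float(np.linalg.norm(diff) / np.linalg.norm(gt))
    return diff, (-lim, lim), rel_l2
```

The symmetric limits suit a diverging colormap such as "RdBu_r", whose neutral midpoint then marks zero error.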
plot_vector_field
Visualize 2D vector fields using streamlines or quiver plots.
```python
from opifex.visualization import plot_vector_field

def plot_vector_field(
    u: Array,
    v: Array,
    coordinates: Optional[Tuple[Array, Array]] = None,
    mode: str = "streamplot",
    density: float = 1.0,
    color: Optional[Array] = None,
    title: str = "Vector Field",
    ax: Optional[plt.Axes] = None
) -> plt.Figure:
    """
    Plot 2D vector field.

    Args:
        u: x-component of vector field, shape (nx, ny)
        v: y-component of vector field, shape (nx, ny)
        coordinates: Optional (X, Y) mesh grids
        mode: Visualization mode:
            - 'streamplot': Streamlines (default)
            - 'quiver': Arrow plot
        density: Streamline/arrow density
        color: Optional scalar field for coloring, shape (nx, ny)
        title: Plot title
        ax: Matplotlib axes

    Returns:
        matplotlib Figure object

    Example:
        >>> # Visualize a fluid velocity field (X, Y as in the plot_2d_field example)
        >>> u = jnp.cos(X) * jnp.sin(Y)  # x-velocity
        >>> v = -jnp.sin(X) * jnp.cos(Y)  # y-velocity
        >>> magnitude = jnp.sqrt(u**2 + v**2)
        >>> fig = plot_vector_field(
        ...     u, v,
        ...     mode="streamplot",
        ...     color=magnitude,
        ...     title="Velocity Field"
        ... )
    """
```
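The velocity in the example above, u = cos x sin y and v = -sin x cos y, is a Taylor-Green-type vortex and is divergence-free: du/dx = -sin x sin y and dv/dy = sin x sin y cancel. The sketch below builds that field on a self-contained [0, 2π]² grid and checks the divergence numerically with `np.gradient`; it also returns the 1D axes, since Matplotlib's `streamplot` expects an evenly spaced grid. The helper names are hypothetical.

```python
import numpy as np


def taylor_green_velocity(n=128):
    """Velocity field from the plot_vector_field example on a [0, 2*pi]^2 grid."""
    x = np.linspace(0, 2 * np.pi, n)
    y = np.linspace(0, 2 * np.pi, n)
    X, Y = np.meshgrid(x, y)  # "xy" indexing: x varies along axis 1, y along axis 0
    u = np.cos(X) * np.sin(Y)
    v = -np.sin(X) * np.cos(Y)
    return x, y, u, v


def divergence(u, v, x, y):
    """Finite-difference estimate of du/dx + dv/dy on an "xy"-indexed grid."""
    du_dx = np.gradient(u, x, axis=1, edge_order=2)
    dv_dy = np.gradient(v, y, axis=0, edge_order=2)
    return du_dx + dv_dy
```

A divergence that stays near zero (up to discretization error) confirms the field is a sensible test case for streamlines; `np.sqrt(u**2 + v**2)` then gives the speed to pass as `color`.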
plot_spectral_analysis
Visualize frequency-domain representation of fields.
```python
from opifex.visualization import plot_spectral_analysis

def plot_spectral_analysis(
    field: Array,
    axis: int = -1,
    title: str = "Spectral Analysis",
    show_phase: bool = False,
    log_scale: bool = True
) -> plt.Figure:
    """
    Plot spectral (Fourier) analysis of field.

    Args:
        field: Input field array
        axis: Axis along which to compute FFT
        title: Plot title
        show_phase: Whether to show phase plot
        log_scale: Use logarithmic scale for magnitude

    Returns:
        matplotlib Figure object

    Example:
        >>> # Analyze frequency content of solution
        >>> fig = plot_spectral_analysis(
        ...     solution,
        ...     axis=-1,
        ...     title="Frequency Spectrum",
        ...     log_scale=True
        ... )
    """
```
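Assuming the spectrum is the magnitude (and, with `show_phase`, the phase) of a discrete Fourier transform along `axis`, a real-valued field only needs the one-sided transform. The sketch below uses `np.fft.rfft` and `np.fft.rfftfreq`; the helper name and its grid-spacing parameter `dx` are assumptions, not part of the documented signature.

```python
import numpy as np


def one_sided_spectrum(field, axis=-1, dx=1.0):
    """Frequencies, magnitude, and phase of the real FFT of `field` along `axis`."""
    field = np.asarray(field, dtype=float)
    n = field.shape[axis]
    coeffs = np.fft.rfft(field, axis=axis)  # one-sided: bins 0 .. n // 2
    freqs = np.fft.rfftfreq(n, d=dx)        # cycles per unit of dx
    return freqs, np.abs(coeffs), np.angle(coeffs)
```

For a sum of sines at 5 and 12 cycles per unit, the two largest magnitudes land exactly on those frequencies. With `log_scale=True`, a log axis (for example `ax.semilogy(freqs, magnitude)`) keeps weak high-frequency modes visible next to the dominant ones.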
Animation
create_physics_animation
Create animated visualizations of time-dependent physics simulations.
```python
from opifex.visualization import create_physics_animation

def create_physics_animation(
    trajectory: Array,
    times: Optional[Array] = None,
    interval: int = 50,
    cmap: str = "viridis",
    title: str = "Physics Animation",
    save_path: Optional[str] = None,
    fps: int = 30,
    writer: str = "pillow"
) -> animation.FuncAnimation:
    """
    Create animation of time-dependent field evolution.

    Args:
        trajectory: Time-dependent field, shape (nt, nx, ny) or (nt, nx)
        times: Time values, shape (nt,)
        interval: Delay between frames in milliseconds
        cmap: Colormap name
        title: Animation title
        save_path: If provided, save animation to this path
            Supports .gif, .mp4, .avi formats; .mp4 and .avi need an
            FFmpeg-based writer such as 'ffmpeg'
        fps: Frames per second for saved video
        writer: Animation writer ('pillow', 'ffmpeg', 'imagemagick')

    Returns:
        matplotlib FuncAnimation object

    Example:
        >>> # Create and save animation
        >>> trajectory = burgers_evolution  # Shape: (200, 256, 256)
        >>> times = jnp.linspace(0, 2, 200)
        >>> anim = create_physics_animation(
        ...     trajectory,
        ...     times=times,
        ...     save_path="burgers_evolution.gif",
        ...     fps=30,
        ...     title="Burgers Equation"
        ... )
        >>> # Display in Jupyter (to_html5_video needs FFmpeg; anim.to_jshtml() does not)
        >>> from IPython.display import HTML
        >>> HTML(anim.to_html5_video())
    """
```
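In the example above, `interval` governs on-screen playback while `fps` governs the saved file: 200 frames at `fps=30` give a GIF of nominally about 6.7 s, whereas the default `interval=50` ms plays the same 200 frames in 10 s on screen. The sketch below is a minimal `FuncAnimation` written directly against Matplotlib, not the opifex implementation; `animate_field_sketch` is a hypothetical name, it handles only (nt, ny, nx) trajectories, and it saves with `PillowWriter`, so it targets .gif output.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation


def animate_field_sketch(trajectory, times=None, interval=50, cmap="viridis",
                         save_path=None, fps=30):
    """Hypothetical minimal FuncAnimation over a (nt, ny, nx) trajectory."""
    trajectory = np.asarray(trajectory)
    nt = trajectory.shape[0]
    times = np.arange(nt) if times is None else np.asarray(times)

    fig, ax = plt.subplots()
    # Fix the color scale across frames so a color means the same value throughout
    im = ax.imshow(trajectory[0], cmap=cmap, origin="lower",
                   vmin=float(trajectory.min()), vmax=float(trajectory.max()))
    fig.colorbar(im, ax=ax)

    def update(i):
        im.set_data(trajectory[i])  # swap in frame i; axes and colorbar stay fixed
        ax.set_title(f"t = {float(times[i]):.3f}")
        return (im,)

    anim = animation.FuncAnimation(fig, update, frames=nt, interval=interval)
    if save_path is not None:
        # PillowWriter writes .gif; .mp4/.avi would need animation.FFMpegWriter
        anim.save(save_path, writer=animation.PillowWriter(fps=fps))
    return fig, anim
```

Updating the image data in place, rather than redrawing a new plot per frame, keeps frame rendering cheap for long trajectories.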
Performance Visualization
plot_flops_analysis
Visualize FLOP counts (floating-point operations) for a model or computation.
```python
from opifex.visualization import plot_flops_analysis

def plot_flops_analysis(
    flops_data: Dict[str, int],
    title: str = "FLOPS Analysis",
    log_scale: bool = True,
    show_breakdown: bool = True
) -> plt.Figure:
    """
    Plot FLOP counts for a model or computation.

    Args:
        flops_data: Dictionary mapping operation names to FLOP counts
            Example: {'forward': 10**9, 'backward': 2 * 10**9, 'total': 3 * 10**9}
        title: Plot title
        log_scale: Use logarithmic scale
        show_breakdown: Show breakdown by operation type

    Returns:
        matplotlib Figure object

    Example:
        >>> from opifex.training import FlopsCounter
        >>> counter = FlopsCounter(model)
        >>> flops = counter.count(sample_input)
        >>> fig = plot_flops_analysis(
        ...     flops,
        ...     title="FNO Computational Cost"
        ... )
    """
```
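The counts passed to `plot_flops_analysis` can also come from a back-of-the-envelope estimate. In a dense layer each weight contributes one multiply and one add per sample, so a layer with fan-in m and fan-out n costs about 2·m·n FLOPs. The hypothetical helper below applies that rule to an MLP and, like the docstring example, takes the backward pass as roughly twice the forward pass; bias additions and activations are ignored.

```python
def estimate_mlp_flops(layer_sizes, batch_size=1):
    """Analytic FLOP estimate for a dense MLP (hypothetical helper, not opifex API).

    Counts 2 * fan_in * fan_out FLOPs per sample per layer; biases and
    activations are ignored.
    """
    forward = 0
    for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        forward += 2 * fan_in * fan_out * batch_size
    # Rule of thumb: the backward pass costs roughly twice the forward pass
    backward = 2 * forward
    return {"forward": forward, "backward": backward, "total": forward + backward}
```

For layer sizes [64, 128, 128, 1] this gives 2 × (64·128 + 128·128 + 128·1) = 49,408 forward FLOPs per sample. The values are Python ints, matching the `Dict[str, int]` annotation, so the dictionary can be passed directly as `flops_data`.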
plot_memory_usage
Visualize memory consumption over time or by component.
```python
from opifex.visualization import plot_memory_usage

def plot_memory_usage(
    memory_data: Array,
    timestamps: Optional[Array] = None,
    title: str = "Memory Usage",
    show_peak: bool = True,
    unit: str = "GB"
) -> plt.Figure:
    """
    Plot memory usage over time.

    Args:
        memory_data: Memory usage values
        timestamps: Time points (or iteration numbers)
        title: Plot title
        show_peak: Highlight peak memory usage
        unit: Memory unit ('GB', 'MB', 'KB')

    Returns:
        matplotlib Figure object

    Example:
        >>> # Monitor memory during training
        >>> from opifex.training import MemoryMonitor
        >>> monitor = MemoryMonitor()
        >>> # ... training loop ...
        >>> fig = plot_memory_usage(
        ...     monitor.memory_history,
        ...     timestamps=monitor.timestamps,
        ...     title="Training Memory Profile"
        ... )
    """
```
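`MemoryMonitor` is not documented on this page. As a stand-in, the sketch below records memory after each step with Python's standard-library `tracemalloc` and returns arrays shaped for `memory_data` and `timestamps`, plus the index of the peak. Note that `tracemalloc` only sees host allocations made through Python's memory allocator; it does not see accelerator memory held by JAX on a GPU or TPU. The helper name and the 1024-based units are assumptions.

```python
import tracemalloc

import numpy as np


def record_host_memory(step_fn, num_steps, unit="MB"):
    """Record traced host memory after each call to `step_fn` (hypothetical helper)."""
    scale = {"KB": 1024, "MB": 1024**2, "GB": 1024**3}[unit]
    tracemalloc.start()
    history = []
    try:
        for step in range(num_steps):
            step_fn(step)
            current, _peak = tracemalloc.get_traced_memory()  # bytes in use now
            history.append(current / scale)
    finally:
        tracemalloc.stop()  # stop tracing even if a step raises
    history = np.asarray(history)
    return history, np.arange(num_steps), int(np.argmax(history))
```

The first two return values map onto `memory_data` and `timestamps`; the peak index is what `show_peak` would highlight.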
plot_model_complexity_comparison
Compare complexity metrics across multiple models.
```python
from opifex.visualization import plot_model_complexity_comparison

def plot_model_complexity_comparison(
    models: Dict[str, Any],
    metrics: List[str] = ['params', 'flops', 'memory'],
    normalize: bool = True
) -> plt.Figure:
    """
    Compare computational complexity of multiple models.

    Args:
        models: Dictionary mapping model names to model objects
        metrics: List of metrics to compare:
            - 'params': Number of parameters
            - 'flops': Floating point operations
            - 'memory': Memory footprint
            - 'inference_time': Inference latency
        normalize: Normalize to smallest model

    Returns:
        matplotlib Figure object

    Example:
        >>> models = {
        ...     'FNO-Small': fno_small,
        ...     'FNO-Large': fno_large,
        ...     'DeepONet': deeponet,
        ...     'U-Net': unet
        ... }
        >>> fig = plot_model_complexity_comparison(
        ...     models,
        ...     metrics=['params', 'flops', 'memory']
        ... )
    """
```
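The docstring does not say exactly how `normalize=True` rescales the metrics. One plausible reading, sketched below, divides each metric by its minimum across models, so the smallest model scores 1.0 and the others read as multiples of it. `count_params` assumes parameters are stored as a nested dict of arrays, a simplification of a real model's parameter tree; both helpers are hypothetical.

```python
import numpy as np


def count_params(params):
    """Count scalars in a nested dict of arrays (simplified stand-in for a param tree)."""
    if isinstance(params, dict):
        return sum(count_params(v) for v in params.values())
    return int(np.asarray(params).size)


def normalize_to_smallest(metrics):
    """Divide each metric by its minimum across models.

    `metrics` maps model name -> {metric name: value}. Afterwards the smallest
    model on each metric scores 1.0 and the others read as multiples of it.
    """
    metric_names = list(next(iter(metrics.values())).keys())
    minima = {m: min(vals[m] for vals in metrics.values()) for m in metric_names}
    return {
        model: {m: vals[m] / minima[m] for m in metric_names}
        for model, vals in metrics.items()
    }
```

Normalizing per metric lets quantities with very different units (parameter counts, FLOPs, bytes) share one axis; an alternative reading would divide every metric by the values of the model with the fewest parameters.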
Integration Examples
Complete Workflow Example
```python
import jax
import jax.numpy as jnp
from flax import nnx
from opifex.data.loaders import create_burgers_loader
from opifex.neural.operators.fno import FourierNeuralOperator
from opifex.training import BasicTrainer, TrainingConfig
from opifex.visualization import (
    plot_field_comparison,
    create_physics_animation,
)

# Set up the data loader
train_loader = create_burgers_loader(
    n_samples=1000,
    batch_size=32,
    resolution=256,
    seed=42,
)

# Build and train the model
model = FourierNeuralOperator(
    in_channels=1,
    out_channels=1,
    hidden_channels=32,
    modes=12,
    num_layers=4,
    rngs=nnx.Rngs(42),
)
config = TrainingConfig(num_epochs=100, learning_rate=1e-3)
trainer = BasicTrainer(model, config)
trained_model, history = trainer.train(train_loader)

# Compare a prediction with the ground truth
# `dataset` stands for your held-out test set; it is not defined in this example
test_sample = dataset[0]
prediction = trained_model(test_sample['input'])
ground_truth = test_sample['output'][-1]  # Final time step
fig = plot_field_comparison(
    [ground_truth, prediction],
    titles=['Ground Truth', 'FNO Prediction'],
    show_difference=True,
)
fig.savefig('comparison.pdf')

# Animate a predicted trajectory (assumes the model provides predict_trajectory)
trajectory_pred = trained_model.predict_trajectory(test_sample['input'], steps=100)
anim = create_physics_animation(
    trajectory_pred,
    save_path='prediction.gif',
    title='FNO Prediction',
)
```
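The workflow above relies on `predict_trajectory`. If a model only maps the state at one time to the next, a trajectory can instead be built autoregressively by feeding each prediction back in as the next input. The sketch below does this with a plain Python loop and a toy periodic diffusion step standing in for the trained model (both helpers are hypothetical); under `jit`, `jax.lax.scan` is the usual JAX way to express the same loop. The result has shape (steps, nx, ny), the layout `create_physics_animation` expects.

```python
import numpy as np


def rollout(step_fn, initial_state, steps):
    """Autoregressive rollout: feed each prediction back in as the next input."""
    state = np.asarray(initial_state)
    frames = []
    for _ in range(steps):
        state = step_fn(state)
        frames.append(state)
    return np.stack(frames)  # shape (steps, *initial_state.shape)


def diffusion_step(u, nu=0.1):
    """Toy stand-in for a trained model: one explicit periodic diffusion step."""
    laplacian = (np.roll(u, 1, axis=0) + np.roll(u, -1, axis=0)
                 + np.roll(u, 1, axis=1) + np.roll(u, -1, axis=1) - 4 * u)
    return u + nu * laplacian  # stable for nu <= 0.25 on this 5-point stencil
```

Errors compound over an autoregressive rollout, so comparing late frames against the ground truth (for example with `plot_field_comparison`) is a useful check beyond single-step accuracy.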
See Also
- Data API: Dataset classes and preprocessing
- Training API: Training infrastructure
- Neural API: Neural network architectures
- Examples: Complete usage examples