Architecture Guide

import jax
import jax.numpy as jnp

Overview

Opifex follows a modular, extensible architecture built on JAX for high-performance scientific computing.

Core Modules

Problems (opifex.core.problems)

  • Abstract problem definitions
  • PDE, ODE, and optimization problems
  • Boundary conditions and domains

Neural Networks (opifex.neural)

  • Specialized architectures (PINNs, Neural Operators)
  • Bayesian networks for uncertainty quantification
  • Quantum chemistry networks

Training (opifex.training)

  • Physics-informed training
  • Multi-objective optimization
  • Adaptive learning strategies

Geometry (opifex.geometry)

  • Domain representations
  • Mesh generation
  • Coordinate transformations

Design Principles

Functional Programming

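# Pure functions with JAX transformations
# Note: jax.jit traces its arguments, so `model` and `batch` must be pytrees.
# `compute_pde_residuals` stands for a problem-specific helper defined elsewhere.
@jax.jit
def physics_loss(model, batch):
    predictions = model(batch.inputs)
    residuals = compute_pde_residuals(predictions, batch)
    # Mean squared PDE residual over the batch
    return jnp.mean(residuals**2)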
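
To make the pattern concrete, here is a minimal self-contained sketch that does not use any Opifex API. It assumes a toy ODE, du/dt + u = 0, and a tiny network whose parameters live in a plain dict; the names `init_params`, `u`, and `ode_loss` are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def init_params(key, hidden=16):
    """Parameters of a one-hidden-layer network, stored as a dict (a pytree)."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (1, hidden)) * 0.5,
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden, 1)) * 0.5,
        "b2": jnp.zeros(1),
    }

def u(params, t):
    """Network output u(t) for a scalar time t."""
    h = jnp.tanh(t * params["w1"][0] + params["b1"])
    return (h @ params["w2"] + params["b2"])[0]

@jax.jit
def ode_loss(params, ts):
    """Mean squared residual of du/dt + u = 0 at collocation points ts."""
    # Differentiate the network with respect to its input, one point at a time.
    du_dt = jax.vmap(jax.grad(u, argnums=1), in_axes=(None, 0))(params, ts)
    u_vals = jax.vmap(u, in_axes=(None, 0))(params, ts)
    return jnp.mean((du_dt + u_vals) ** 2)

params = init_params(jax.random.PRNGKey(0))
ts = jnp.linspace(0.0, 1.0, 32)

# Because ode_loss is pure, jit, vmap, and grad compose freely.
loss, grads = jax.value_and_grad(ode_loss)(params, ts)
```

The loss depends only on its arguments, so the same function can be jitted, differentiated with respect to the parameters, and vectorized over collocation points without any changes.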

Immutable Data Structures

# Using dataclasses for configuration
from dataclasses import dataclass

@dataclass(frozen=True)
class TrainingConfig:
    learning_rate: float = 1e-3
    batch_size: int = 32
    epochs: int = 1000
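
As a short sketch of what freezing buys in practice, the snippet below repeats the `TrainingConfig` definition and derives a variant with the standard library's `dataclasses.replace` instead of mutating the original. The variable names `base`, `fine_tune`, and `mutated` are illustrative only.

```python
from dataclasses import FrozenInstanceError, dataclass, replace

@dataclass(frozen=True)
class TrainingConfig:
    learning_rate: float = 1e-3
    batch_size: int = 32
    epochs: int = 1000

base = TrainingConfig()

# Derive a new config; `base` itself is left untouched.
fine_tune = replace(base, learning_rate=1e-4, epochs=200)

# Direct assignment on a frozen instance raises FrozenInstanceError.
try:
    base.learning_rate = 1.0
    mutated = True
except FrozenInstanceError:
    mutated = False
```

Frozen dataclasses are also hashable by default, which matters with JAX because arguments marked static for `jax.jit` must be hashable.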

Composability

# Composable components
problem = PDEProblem(...)
model = PINN(...)
trainer = Trainer(model, config)
model, history = trainer.fit(train_data)

Module Dependencies

graph TD
    A[opifex.core] --> B[opifex.neural]
    A --> C[opifex.training]
    A --> D[opifex.geometry]
    B --> C
    D --> B
    C --> E[opifex.optimization]
    F[opifex.benchmarking] --> A
    F --> B
    F --> C

Extension Points

Custom Neural Networks

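from flax import nnx

class CustomPINN(nnx.Module):
    """Custom physics-informed architecture."""

    def __init__(self, in_dim, hidden, out_dim, *, rngs: nnx.Rngs):
        self.hidden = nnx.Linear(in_dim, hidden, rngs=rngs)
        self.out = nnx.Linear(hidden, out_dim, rngs=rngs)

    def __call__(self, x):
        # Illustrative body: one tanh hidden layer; replace with your architecture
        return self.out(jnp.tanh(self.hidden(x)))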

Custom Physics Losses

def custom_physics_loss(predictions, inputs):
    """Problem-specific physics constraints."""
    # Implementation
    return loss
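
One possible way to fill in this template, shown as a sketch: for a harmonic oscillator with unit mass and unit stiffness (an assumption made here for illustration), total energy should stay constant along a trajectory, so drift in energy can serve as a physics penalty. The function name `energy_conservation_loss` and the array layout are hypothetical.

```python
import jax.numpy as jnp

def energy_conservation_loss(predictions, inputs):
    """Penalize drift in total energy along a predicted trajectory.

    predictions: array of shape (n, 2) holding position and velocity.
    inputs: array of shape (n,) of time points (unused here; kept so the
        signature matches the template above).
    """
    x, v = predictions[:, 0], predictions[:, 1]
    energy = 0.5 * v**2 + 0.5 * x**2  # unit mass and unit stiffness assumed
    # Compare every point's energy with the energy at the first point.
    return jnp.mean((energy - energy[0]) ** 2)

t = jnp.linspace(0.0, 2.0 * jnp.pi, 50)

# The exact solution conserves energy, so its loss is close to zero.
exact = jnp.stack([jnp.cos(t), -jnp.sin(t)], axis=1)

# An artificially damped trajectory loses energy and is penalized.
damped = exact * jnp.exp(-0.1 * t)[:, None]
```

Calling `energy_conservation_loss(exact, t)` gives a value near zero, while the damped trajectory yields a clearly larger loss, which is the behavior a physics constraint of this kind should have.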

Custom Optimizers

class CustomOptimizer:
    """Domain-specific optimization algorithm."""

    def step(self, params, grads):
        # Implementation
        return updated_params
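
As a minimal sketch of the `step(params, grads)` interface, the class below implements plain gradient descent over an arbitrary parameter pytree. It is a stand-in, not a domain-specific algorithm; the names `GradientDescent` and `loss_fn` are hypothetical.

```python
import jax
import jax.numpy as jnp

class GradientDescent:
    """Plain gradient descent exposing the step(params, grads) interface."""

    def __init__(self, learning_rate=0.1):
        self.learning_rate = learning_rate

    def step(self, params, grads):
        # Apply p <- p - lr * g to every leaf of the parameter pytree.
        return jax.tree_util.tree_map(
            lambda p, g: p - self.learning_rate * g, params, grads
        )

def loss_fn(params):
    """Quadratic bowl with its minimum at w = 3."""
    return jnp.sum((params["w"] - 3.0) ** 2)

optimizer = GradientDescent(learning_rate=0.1)
params = {"w": jnp.zeros(2)}
for _ in range(100):
    params = optimizer.step(params, jax.grad(loss_fn)(params))
```

Because `step` returns new parameters rather than modifying them in place, it fits the functional style described earlier and works with any nested dict, list, or tuple of arrays.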

Performance Considerations

Memory Layout

  • Row-major (C-order) arrays by default, as in NumPy; XLA chooses the on-device layout
  • Contiguous memory access patterns
  • Efficient tensor operations

Compilation Strategy

  • JIT compilation for hot paths
  • Static shapes for XLA optimization
  • Minimal Python overhead
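
The sketch below illustrates why static shapes matter. `jax.jit` retraces a function whenever it sees a new input shape, so padding variable-length inputs to one fixed length keeps a single compiled version. The helper `pad_to`, the function `masked_mean_square`, and the counter `trace_count` are hypothetical names used only for this example.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def masked_mean_square(x, mask):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    return jnp.sum(mask * x**2) / jnp.sum(mask)

def pad_to(x, size):
    """Zero-pad a 1-D array to a fixed length and return it with a validity mask."""
    n = x.shape[0]
    padded = jnp.zeros(size, x.dtype).at[:n].set(x)
    mask = (jnp.arange(size) < n).astype(x.dtype)
    return padded, mask

# Three inputs of different lengths, all padded to length 8 before the jitted call.
results = [masked_mean_square(*pad_to(jnp.ones(n) * 2.0, 8)) for n in (3, 5, 8)]
```

All three calls share the shape `(8,)`, so the function is traced once. Passing the unpadded arrays directly would trigger one trace per distinct length.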

Scaling

  • Single-device optimization
  • Multi-device parallelism (pmap)
  • Distributed training (experimental)
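
A minimal sketch of the `pmap` pattern, assuming nothing about how Opifex wires it up internally: the leading axis of the input is split across local devices, each device computes a local statistic, and `jax.lax.pmean` averages it across devices. The function name `per_device_mean_square` is hypothetical. The snippet also runs on a single device, where the leading axis has length 1.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

def per_device_mean_square(x):
    # Each device reduces its own shard, then the results are averaged across devices.
    local = jnp.mean(x**2)
    return jax.lax.pmean(local, axis_name="devices")

parallel_fn = jax.pmap(per_device_mean_square, axis_name="devices")

# The leading axis must equal the number of devices; each row is one device's shard.
batch = jnp.ones((n_devices, 4)) * 3.0
out = parallel_fn(batch)
```

The result has one entry per device, and every entry holds the same cross-device average.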