Problem Definition Guide¶
Overview¶
The Opifex framework provides a unified, extensible interface for defining scientific problems across multiple domains. It supports partial differential equations (PDEs), ordinary differential equations (ODEs), optimization problems, and quantum mechanical calculations, all built on JAX for high-performance computation and automatic differentiation.
The problem definition system is designed with modularity and extensibility in mind, allowing researchers to easily specify complex scientific problems while maintaining compatibility with the entire Opifex ecosystem of neural operators, physics-informed neural networks, and quantum neural networks.
Core Problem Types¶
1. Partial Differential Equations (PDEs)¶
PDEs form the backbone of many scientific simulations. The Opifex framework provides full support for defining and solving PDEs using both traditional numerical methods and neural approaches.
Basic PDE Problem Definition¶
from opifex.core.problems import PDEProblem
from opifex.core.conditions import DirichletBC, NeumannBC, InitialCondition
from opifex.geometry import Rectangle
import jax.numpy as jnp
class HeatEquationProblem(PDEProblem):
    """2D heat equation ∂u/∂t = α∇²u with mixed boundary conditions.

    The domain is the unit square centred at (0.5, 0.5). The left and right
    edges carry Dirichlet values (0 and 1). The top and bottom edges carry a
    zero Neumann flux (insulated).

    Args:
        diffusivity: Thermal diffusivity α, stored under
            ``parameters["diffusivity"]``.
    """

    def __init__(self, diffusivity=0.01):
        # Unit square: width and height 1, centred at (0.5, 0.5)
        geometry = Rectangle(center=jnp.array([0.5, 0.5]), width=1.0, height=1.0)

        boundary_conditions = [
            DirichletBC(boundary="left", value=0.0),
            DirichletBC(boundary="right", value=1.0),
            NeumannBC(boundary="top", value=0.0),
            NeumannBC(boundary="bottom", value=0.0),
        ]

        # Coordinates are read from the trailing axis, so the same callable
        # works for a single point (shape (2,)) and for a batch (shape (N, 2)).
        # Note: this profile is zero on every edge, so at t = 0 it does not
        # match the u = 1 Dirichlet value imposed on the right edge.
        initial_conditions = [
            InitialCondition(
                value=lambda x: jnp.sin(jnp.pi * x[..., 0]) * jnp.sin(jnp.pi * x[..., 1]),
                dimension=1,
                name="u",
            )
        ]

        super().__init__(
            geometry=geometry,
            equation=self._heat_equation,
            boundary_conditions=boundary_conditions,
            initial_conditions=initial_conditions,
            parameters={"diffusivity": diffusivity},
            time_dependent=True,
        )

    def residual(self, x, u, u_derivatives):
        """Return the PDE residual u_t - α(u_xx + u_yy) for physics-informed training.

        ``u_derivatives`` must provide the keys "t", "xx" and "yy".
        """
        alpha = self.parameters["diffusivity"]
        u_t = u_derivatives["t"]
        u_xx = u_derivatives["xx"]
        u_yy = u_derivatives["yy"]
        return u_t - alpha * (u_xx + u_yy)

    def _heat_equation(self, x, u, u_derivatives):
        """Heat equation ∂u/∂t = α∇²u; delegates to :meth:`residual`."""
        return self.residual(x, u, u_derivatives)
# Build the heat problem and inspect its configuration
heat_problem = HeatEquationProblem(diffusivity=0.01)
geometry_summary = heat_problem.get_geometry()
print(f"Geometry: {geometry_summary}")
parameter_summary = heat_problem.get_parameters()
print(f"Parameters: {parameter_summary}")
Advanced PDE Examples¶
Navier-Stokes Equations¶
class NavierStokesProblem(PDEProblem):
    """2D incompressible Navier-Stokes equations in a 2 x 1 channel.

    The solution ``u`` has three trailing components: ``u[..., 0]`` and
    ``u[..., 1]`` are the velocity components and ``u[..., 2]`` is the
    pressure.

    Args:
        reynolds_number: Reynolds number Re, stored under ``parameters["Re"]``.
    """

    def __init__(self, reynolds_number=100):
        geometry = Rectangle(center=jnp.array([1.0, 0.5]), width=2.0, height=1.0)

        # No-slip walls (top/bottom), inlet value on the left, zero-gradient
        # outlet on the right.
        # NOTE(review): each BC takes a single scalar value; presumably it is
        # applied to every output component (including pressure). Confirm.
        boundary_conditions = [
            DirichletBC(boundary="top", value=0.0),
            DirichletBC(boundary="bottom", value=0.0),
            DirichletBC(boundary="left", value=1.0),  # inlet
            NeumannBC(boundary="right", value=0.0),  # outlet
        ]

        super().__init__(
            geometry=geometry,
            equation=self._navier_stokes,
            boundary_conditions=boundary_conditions,
            parameters={"Re": reynolds_number},
            time_dependent=True,
        )

    def residual(self, x, u, u_derivatives):
        """Navier-Stokes residual: ∂u/∂t + u·∇u = -∇p + (1/Re)∇²u, ∇·u = 0.

        ``u_derivatives`` must provide "t", "x", "y", "xx" and "yy", each
        carrying the same three trailing components as ``u``.

        Returns:
            Residuals stacked on the last axis as
            (x-momentum, y-momentum, continuity).
        """
        Re = self.parameters["Re"]
        u_vel, v_vel = u[..., 0], u[..., 1]

        # Velocity derivatives
        u_t = u_derivatives["t"][..., 0]
        v_t = u_derivatives["t"][..., 1]
        u_x, u_y = u_derivatives["x"][..., 0], u_derivatives["y"][..., 0]
        v_x, v_y = u_derivatives["x"][..., 1], u_derivatives["y"][..., 1]
        u_xx, u_yy = u_derivatives["xx"][..., 0], u_derivatives["yy"][..., 0]
        v_xx, v_yy = u_derivatives["xx"][..., 1], u_derivatives["yy"][..., 1]

        # Only the pressure gradient enters the momentum equations
        p_x, p_y = u_derivatives["x"][..., 2], u_derivatives["y"][..., 2]

        momentum_x = u_t + u_vel * u_x + v_vel * u_y + p_x - (1 / Re) * (u_xx + u_yy)
        momentum_y = v_t + u_vel * v_x + v_vel * v_y + p_y - (1 / Re) * (v_xx + v_yy)
        continuity = u_x + v_y

        return jnp.stack([momentum_x, momentum_y, continuity], axis=-1)

    def _navier_stokes(self, x, u, u_derivatives):
        """Equation callable passed to PDEProblem; delegates to :meth:`residual`."""
        return self.residual(x, u, u_derivatives)
Wave Equation with Source Terms¶
class WaveEquationProblem(PDEProblem):
    """2D wave equation with a moving Gaussian source term.

    Solves ∂²u/∂t² = c²∇²u + f(x, y, t) on the square of side 2 centred at
    the origin. The initial state is a Gaussian pulse released from rest.

    Args:
        wave_speed: Propagation speed c, stored under ``parameters["c"]``.
    """

    def __init__(self, wave_speed=1.0):
        # RobinBC is used here but not imported by this example's import
        # block; import it locally so the class is self-contained.
        from opifex.core.conditions import RobinBC

        geometry = Rectangle(center=jnp.array([0.0, 0.0]), width=2.0, height=2.0)

        # Robin condition alpha*u + beta*du/dn = gamma, i.e. u + c*du/dn = 0,
        # on every edge.
        # NOTE(review): this is labelled "absorbing". The classical first-order
        # absorbing condition is du/dt + c*du/dn = 0, which uses du/dt rather
        # than u. Confirm the intended form.
        boundary_conditions = [
            RobinBC(boundary="all", alpha=1.0, beta=wave_speed, gamma=0.0)
        ]

        # Second order in time: prescribe both u and du/dt at t = 0
        initial_conditions = [
            InitialCondition(
                value=lambda x: jnp.exp(-(x[..., 0]**2 + x[..., 1]**2) / 0.1),
                name="u",
                derivative_order=0,
            ),
            InitialCondition(
                value=0.0,
                name="u_t",
                derivative_order=1,
            ),
        ]

        super().__init__(
            geometry=geometry,
            equation=self._wave_equation,
            boundary_conditions=boundary_conditions,
            initial_conditions=initial_conditions,
            parameters={"c": wave_speed},
            # The residual uses u_tt and reads t from x[..., 2]
            time_dependent=True,
        )

    def residual(self, x, u, u_derivatives):
        """Wave equation residual u_tt - c²(u_xx + u_yy) - f(x, y, t).

        ``x[..., 0]``, ``x[..., 1]`` and ``x[..., 2]`` are x, y and t.
        ``u_derivatives`` must provide "tt", "xx" and "yy".
        """
        c = self.parameters["c"]
        u_tt = u_derivatives["tt"]
        u_xx = u_derivatives["xx"]
        u_yy = u_derivatives["yy"]

        # Source term: Gaussian whose centre moves along x at speed 0.5
        x_pos, y_pos, t = x[..., 0], x[..., 1], x[..., 2]
        source = jnp.exp(-((x_pos - 0.5 * t)**2 + y_pos**2) / 0.05)

        return u_tt - c**2 * (u_xx + u_yy) - source

    def _wave_equation(self, x, u, u_derivatives):
        """Equation callable passed to PDEProblem; delegates to :meth:`residual`."""
        return self.residual(x, u, u_derivatives)
2. Ordinary Differential Equations (ODEs)¶
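The framework supports both initial value problems (IVPs) and boundary value problems (BVPs). The examples below are IVPs. Model parameters are passed in through the `parameters` dictionary and read back inside the right-hand-side function.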
Basic ODE Systems¶
from opifex.core.problems import ODEProblem
import jax.numpy as jnp
class LorenzSystem(ODEProblem):
    """Lorenz system, a classic chaotic ODE.

    The state ``(x, y, z)`` starts at ``(1, 1, 1)`` and is integrated over
    ``t`` in ``[0, 20]``:

        dx/dt = σ(y - x),  dy/dt = x(ρ - z) - y,  dz/dt = xy - βz

    Args:
        sigma: σ, stored under ``parameters["sigma"]``.
        rho: ρ, stored under ``parameters["rho"]``.
        beta: β, stored under ``parameters["beta"]``.
    """

    def __init__(self, sigma=10.0, rho=28.0, beta=8.0/3.0):
        lorenz_parameters = dict(sigma=sigma, rho=rho, beta=beta)
        super().__init__(
            time_span=(0.0, 20.0),
            equation=self._lorenz_rhs,
            initial_conditions={"u": jnp.array([1.0, 1.0, 1.0])},
            parameters=lorenz_parameters,
        )

    def rhs(self, t, y):
        """Return d(x, y, z)/dt for state ``y``. ``t`` is unused because the system is autonomous."""
        px, py, pz = y
        params = self.parameters
        sigma, rho, beta = params["sigma"], params["rho"], params["beta"]
        return jnp.array([
            sigma * (py - px),
            px * (rho - pz) - py,
            px * py - beta * pz,
        ])

    def _lorenz_rhs(self, t, y, params):
        """Solver-facing adapter. ``params`` is ignored; ``self.parameters`` is used instead."""
        return self.rhs(t, y)
# Stiff ODE example
class VanDerPolOscillator(ODEProblem):
    """Van der Pol oscillator x'' - μ(1 - x²)x' + x = 0 as a first-order system.

    The state is ``(x, v)`` with ``v = dx/dt``. It starts at ``(2, 0)`` and is
    integrated over ``t`` in ``[0, 20]``. The parameter μ controls the
    nonlinear damping and hence the stiffness.

    Args:
        mu: Damping parameter μ, stored under ``parameters["mu"]``.
    """

    def __init__(self, mu=1.0):
        super().__init__(
            time_span=(0.0, 20.0),
            equation=self._van_der_pol_rhs,
            initial_conditions={"u": jnp.array([2.0, 0.0])},
            parameters={"mu": mu},
        )

    def rhs(self, t, y):
        """Van der Pol: d²x/dt² - μ(1-x²)dx/dt + x = 0, returned as (dx/dt, dv/dt)."""
        x, v = y
        mu = self.parameters["mu"]
        dxdt = v
        dvdt = mu * (1 - x**2) * v - x
        return jnp.array([dxdt, dvdt])

    def _van_der_pol_rhs(self, t, y, params):
        """Equation callable passed to ODEProblem (same signature as LorenzSystem._lorenz_rhs).

        ``params`` is accepted for interface compatibility and ignored; the
        right-hand side reads ``self.parameters``.
        """
        return self.rhs(t, y)
Coupled ODE-PDE Systems¶
class ReactionDiffusionSystem(PDEProblem):
    """Two-species reaction-diffusion system with Schnakenberg kinetics.

        ∂u/∂t = D_u ∇²u + a - u + u²v
        ∂v/∂t = D_v ∇²v + b - u²v

    The domain is the 10 x 10 square centred at (5, 5), with zero-flux
    boundaries.

    Args:
        D_u: Diffusion coefficient of species u.
        D_v: Diffusion coefficient of species v.
        reaction_params: Mapping of kinetic constants merged into
            ``parameters``. It must contain "a" and "b". Defaults to
            ``{"a": 1.0, "b": 3.0, "k": 1.0}``. "k" is kept for backward
            compatibility but is not used by the Schnakenberg terms.
    """

    def __init__(self, D_u=1.0, D_v=0.5, reaction_params=None):
        if reaction_params is None:
            reaction_params = {"a": 1.0, "b": 3.0, "k": 1.0}

        geometry = Rectangle(center=jnp.array([5.0, 5.0]), width=10.0, height=10.0)

        # No-flux boundary conditions
        boundary_conditions = [
            NeumannBC(boundary="all", value=0.0)
        ]

        super().__init__(
            geometry=geometry,
            equation=self._reaction_diffusion,
            boundary_conditions=boundary_conditions,
            parameters={"D_u": D_u, "D_v": D_v, **reaction_params},
            # The residual uses the time derivative u_derivatives["t"]
            time_dependent=True,
        )

    def residual(self, x, u, u_derivatives):
        """Reaction-diffusion residual ∂u/∂t - D∇²u - R(u, v) for both species.

        ``u[..., 0]`` and ``u[..., 1]`` are the two concentrations.
        ``u_derivatives`` must provide "t", "xx" and "yy" with the same two
        trailing components.

        Returns:
            Residuals stacked on the last axis as (species u, species v).
        """
        D_u, D_v = self.parameters["D_u"], self.parameters["D_v"]
        # Only a and b enter Schnakenberg kinetics; "k" is deliberately not read
        a, b = self.parameters["a"], self.parameters["b"]

        u_conc, v_conc = u[..., 0], u[..., 1]
        u_t, v_t = u_derivatives["t"][..., 0], u_derivatives["t"][..., 1]
        u_laplacian = u_derivatives["xx"][..., 0] + u_derivatives["yy"][..., 0]
        v_laplacian = u_derivatives["xx"][..., 1] + u_derivatives["yy"][..., 1]

        # Reaction terms (Schnakenberg kinetics)
        reaction_u = a - u_conc + u_conc**2 * v_conc
        reaction_v = b - u_conc**2 * v_conc

        residual_u = u_t - D_u * u_laplacian - reaction_u
        residual_v = v_t - D_v * v_laplacian - reaction_v
        return jnp.stack([residual_u, residual_v], axis=-1)

    def _reaction_diffusion(self, x, u, u_derivatives):
        """Equation callable passed to PDEProblem; delegates to :meth:`residual`."""
        return self.residual(x, u, u_derivatives)
3. Optimization Problems¶
The framework provides sophisticated optimization problem definitions with support for constraints, multi-objective optimization, and learn-to-optimize applications.
Constrained Optimization¶
from opifex.core.problems import OptimizationProblem
import jax
import jax.numpy as jnp
class ConstrainedQuadraticProblem(OptimizationProblem):
    """Quadratic program with optional linear equality and inequality constraints.

    Minimises f(x) = 0.5 xᵀQx + cᵀx. All constraint functions go into one
    list. The ``n_eq`` equality residuals ``A_eq[i] @ x - b_eq[i]`` come
    first, followed by the ``n_ineq`` inequality residuals
    ``A_ineq[i] @ x - b_ineq[i]``. Both counts are recorded in
    ``parameters``.
    NOTE(review): presumably OptimizationProblem treats the inequality
    residuals as g(x) <= 0. Confirm.

    Args:
        Q: Square (n, n) matrix of the quadratic term.
        c: Length-n vector of the linear term.
        A_eq, b_eq: Optional equality constraint matrix and right-hand side.
        A_ineq, b_ineq: Optional inequality constraint matrix and right-hand side.
        bounds: Optional list of per-variable ``(lower, upper)`` bounds.
            Defaults to ``(-10.0, 10.0)`` for each of the n variables.

    Raises:
        ValueError: If a constraint matrix is given without its right-hand
            side, or their lengths disagree.
    """

    def __init__(self, Q, c, A_eq=None, b_eq=None, A_ineq=None, b_ineq=None,
                 bounds=None):
        dimension = Q.shape[0]

        # Validate up front rather than failing later inside a constraint call
        n_eq = self._check_pair("A_eq", A_eq, "b_eq", b_eq)
        n_ineq = self._check_pair("A_ineq", A_ineq, "b_ineq", b_ineq)

        # Bind i as a default so each lambda captures its own row index
        constraints = [lambda x, i=i: A_eq[i] @ x - b_eq[i] for i in range(n_eq)]
        constraints += [lambda x, i=i: A_ineq[i] @ x - b_ineq[i] for i in range(n_ineq)]

        if bounds is None:
            bounds = [(-10.0, 10.0)] * dimension

        super().__init__(
            dimension=dimension,
            bounds=bounds,
            constraints=constraints,
            parameters={"Q": Q, "c": c, "n_eq": n_eq, "n_ineq": n_ineq},
        )
        self.Q = Q
        self.c = c

    @staticmethod
    def _check_pair(a_name, A, b_name, b):
        """Return the row count of ``A`` (0 if ``A`` is None) after checking ``b`` matches."""
        if A is None:
            return 0
        if b is None:
            raise ValueError(f"{b_name} must be provided when {a_name} is given")
        if len(b) != A.shape[0]:
            raise ValueError(
                f"{b_name} has {len(b)} entries but {a_name} has {A.shape[0]} rows"
            )
        return A.shape[0]

    def objective(self, x):
        """Quadratic objective: f(x) = 0.5 * x^T Q x + c^T x."""
        return 0.5 * x.T @ self.Q @ x + self.c.T @ x
# Multi-objective optimization
class MultiObjectiveProblem(OptimizationProblem):
    """Multi-objective problem scalarised as a weighted sum of objectives.

    Args:
        objectives: Sequence of callables, each mapping ``x`` to a scalar.
        weights: Optional weights, one per objective. Defaults to all ones.
            Converted with ``jnp.asarray``, so plain lists are accepted.
        dimension: Number of decision variables (default 2).
        bounds: Optional list of per-variable ``(lower, upper)`` bounds.
            Defaults to ``(-5.0, 5.0)`` for every variable.
    """

    def __init__(self, objectives, weights=None, dimension=2, bounds=None):
        self.objectives = objectives
        # Explicit None check: `weights or default` calls bool() on the array,
        # which raises for JAX arrays with more than one element.
        if weights is None:
            weights = jnp.ones(len(objectives))
        self.weights = jnp.asarray(weights)

        if bounds is None:
            bounds = [(-5.0, 5.0)] * dimension

        super().__init__(
            dimension=dimension,
            bounds=bounds,
            parameters={"n_objectives": len(objectives)},
        )

    def objective(self, x):
        """Return the weighted sum of all objective values at ``x``."""
        values = jnp.array([obj(x) for obj in self.objectives])
        return jnp.sum(self.weights * values)

    def pareto_objectives(self, x):
        """Return all objective values at ``x`` (unweighted) for Pareto analysis."""
        return jnp.array([obj(x) for obj in self.objectives])
# Example usage: two test objectives


def rosenbrock(x):
    """2-D Rosenbrock function using x[0] and x[1]; its global minimum 0 is at (1, 1)."""
    a, b = x[0], x[1]
    return 100 * (b - a**2)**2 + (1 - a)**2


def sphere(x):
    """Sphere function: sum of squared entries of ``x``."""
    squared = x**2
    return jnp.sum(squared)
objective_list = [rosenbrock, sphere]
objective_weights = jnp.array([0.7, 0.3])
multi_obj = MultiObjectiveProblem(objective_list, weights=objective_weights)
4. Quantum Mechanical Problems¶
The framework includes first-class support for quantum mechanical calculations, including electronic structure problems (density functional theory) and quantum dynamics (the time-dependent Schrödinger equation).
Electronic Structure Problems¶
from opifex.core.problems import QuantumProblem
from opifex.core.quantum.molecular_system import create_molecular_system
class DFTProblem(QuantumProblem):
    """Density functional theory (DFT) problem for a molecular system.

    Configured with the "neural_dft" method, PBE exchange and correlation
    functionals, the def2-TZVP basis set, a "fine" integration grid, and a
    convergence threshold of 1e-8.

    Args:
        atoms: Atom identifiers passed to ``create_molecular_system``.
        positions: Nuclear positions passed to ``create_molecular_system``.
        charge: Total molecular charge.
        multiplicity: Spin multiplicity.
    """

    def __init__(self, atoms, positions, charge=0, multiplicity=1):
        system = create_molecular_system(
            atoms=atoms,
            positions=positions,
            charge=charge,
            multiplicity=multiplicity,
        )
        dft_settings = {
            "exchange_functional": "PBE",
            "correlation_functional": "PBE",
            "basis_set": "def2-TZVP",
            "grid_density": "fine",
        }
        super().__init__(
            molecular_system=system,
            method="neural_dft",
            convergence_threshold=1e-8,
            parameters=dft_settings,
        )

    def compute_energy(self, density=None):
        """Return the total energy T + V_ext + V_H + E_xc + V_nn.

        If ``density`` is None, the density from ``self._scf_density()`` is used.
        """
        rho = self._scf_density() if density is None else density
        kinetic = self._kinetic_energy(rho)                  # T
        electron_nuclear = self._external_potential_energy(rho)  # V_ext
        hartree = self._hartree_energy(rho)                  # V_H (e-e repulsion)
        xc = self._exchange_correlation_energy(rho)          # E_xc
        nuclear = self._nuclear_repulsion_energy()           # V_nn
        return kinetic + electron_nuclear + hartree + xc + nuclear

    def compute_forces(self, density=None):
        """Return nuclear forces as the negative gradient of the energy (via ``jax.grad``).

        ``density`` is forwarded unchanged to ``_energy_at_positions``.
        NOTE(review): when it is None, that helper must handle it. Confirm.
        """
        def energy_at(nuclear_positions):
            return self._energy_at_positions(nuclear_positions, density)

        gradient = jax.grad(energy_at)(self.molecular_system.positions)
        return -gradient
# Quantum dynamics problem
class QuantumDynamicsProblem(QuantumProblem):
    """Time-dependent Schrödinger equation iℏ ∂ψ/∂t = Ĥψ in atomic units (ℏ = 1).

    Args:
        hamiltonian: Hamiltonian matrix, or a callable ``H(t)`` that returns
            the matrix at time ``t`` (time-dependent Hamiltonian).
        initial_wavefunction: Wavefunction at the start of ``time_span``.
        time_span: ``(t_start, t_end)`` integration interval.
    """

    def __init__(self, hamiltonian, initial_wavefunction, time_span=(0.0, 1.0)):
        # The QuantumProblem interface takes a molecular system; a single
        # hydrogen atom at the origin serves as a minimal placeholder.
        molecular_system = create_molecular_system(
            atoms=["H"],
            positions=jnp.array([[0.0, 0.0, 0.0]]),
            charge=0,
        )

        super().__init__(
            molecular_system=molecular_system,
            method="time_dependent_dft",
            parameters={
                "hamiltonian": hamiltonian,
                "initial_wavefunction": initial_wavefunction,
                "time_span": time_span,
            },
        )

    def time_evolution(self, t, psi):
        """Return ∂ψ/∂t = -(i/ℏ) Ĥψ, the right-hand side of iℏ ∂ψ/∂t = Ĥψ."""
        H = self.parameters["hamiltonian"]
        # A callable is evaluated at t (time-dependent Hamiltonian). A fixed
        # matrix is not callable and is used unchanged.
        if callable(H):
            H = H(t)
        hbar = 1.0  # Atomic units
        # `*` and `@` share precedence and associate left: ((-i/ℏ) * H) @ ψ
        return -1j / hbar * H @ psi
Advanced Boundary Conditions¶
Classical Boundary Conditions¶
The Opifex framework provides full support for all standard boundary condition types with advanced features like time-dependence and spatial variation.
Dirichlet Conditions¶
Dirichlet boundary conditions specify function values at boundaries. They are essential for problems where the solution value is known or constrained at domain boundaries.
from opifex.core.conditions import DirichletBC
import jax.numpy as jnp
# Simple constant Dirichlet condition
constant_bc = DirichletBC(
    boundary="left",
    value=1.0,
)

# Time-dependent Dirichlet condition. The right boundary of a rectangle is a
# line of constant x, so the spatial envelope is written in y (x[..., 1]).
# An envelope in x would be constant along this boundary.
time_varying_bc = DirichletBC(
    boundary="right",
    value=lambda x, t: jnp.sin(2 * jnp.pi * t) * jnp.exp(-x[..., 1]**2),
    time_dependent=True,
)

# Spatially-varying Dirichlet condition (callable value). Coordinates are read
# from the trailing axis, so a single point or a batch of points both work.
spatial_bc = DirichletBC(
    boundary="top",
    value=lambda x: x[..., 0]**2 + x[..., 1]**2,
)

print("Dirichlet boundary conditions configured for various scenarios")
Neumann Conditions¶
Neumann boundary conditions specify derivative (flux) values at boundaries, commonly used for heat flux, mass flux, or stress conditions.
from opifex.core.conditions import NeumannBC
# Constant flux condition
constant_flux = NeumannBC(
    boundary="top",
    value=0.1,  # Heat flux
)

# Zero flux (insulation) condition
no_flux = NeumannBC(
    boundary="bottom",
    value=0.0,
)


# Spatially-varying flux (pass a callable as value)
def parabolic_flux(x):
    """Parabolic flux profile along the right boundary.

    The right boundary has constant x, so the profile is a function of the
    tangential coordinate y (``x[..., 1]``). For y in [0, 1] the flux is zero
    at y = 0 and y = 1. Its largest magnitude is -0.025, at y = 0.5.
    """
    y = x[..., 1]
    return -0.1 * y * (1 - y)


varying_flux = NeumannBC(
    boundary="right",
    value=parabolic_flux,
)

print("Neumann boundary conditions configured for flux problems")
Robin Conditions¶
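Robin (mixed) boundary conditions impose a linear combination of the solution value and its normal derivative, $\alpha u + \beta \, \partial u / \partial n = \gamma$. They are commonly used for convective heat transfer and radiation problems.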
from opifex.core.conditions import RobinBC
# Convective heat transfer: h(T - T_ambient) + k(dT/dn) = 0.
# Written as alpha*u + beta*du/dn = gamma, this gives alpha = h, beta = k and
# gamma = h*T_ambient (here h = 1, so gamma is the ambient temperature).
convective_bc = RobinBC(
    boundary="surface",
    alpha=1.0,   # Coefficient of u (temperature)
    beta=0.1,    # Coefficient of ∂u/∂n (heat conduction)
    gamma=20.0,  # External condition (ambient temperature)
)


# Time-varying ambient condition
def ambient_temperature(x, t):
    """Daily ambient-temperature cycle: 20 ± 10 with a 24-hour period in ``t``.

    Uses the ``(x, t)`` signature of time-dependent boundary callables, where
    ``x`` is the coordinate array. The value does not depend on position.
    """
    return 20.0 + 10.0 * jnp.sin(2 * jnp.pi * t / 24.0)  # 24-hour cycle


time_varying_robin = RobinBC(
    boundary="exterior",
    alpha=1.0,
    beta=0.05,
    gamma=ambient_temperature,
    time_dependent=True,
)

print("Robin boundary conditions configured for heat transfer problems")
Periodic Domains¶
For problems with inherent periodicity, use PeriodicCell from opifex.geometry to define the periodic domain geometry rather than periodic boundary conditions. The periodicity is handled at the geometry level:
from opifex.geometry import PeriodicCell
# Periodic unit cell: a 5 x 5 x 5 cube given by its lattice vectors (one per row)
lattice_vectors = 5.0 * jnp.eye(3)
periodic_cell = PeriodicCell(lattice_vectors=lattice_vectors)

# Map points lying outside the cell back inside it
positions = jnp.array([
    [6.0, 2.0, 3.0],
    [-1.0, 7.0, 0.5],
])
wrapped = periodic_cell.wrap_coordinates(positions)

print(f"Unit cell volume: {periodic_cell.volume:.4f}")
Domain Specification and Geometry¶
Geometric Domains¶
The Opifex framework provides sophisticated domain specification capabilities, from simple geometric shapes to complex multi-physics domains.
Basic Geometric Shapes¶
from opifex.geometry import Rectangle, Circle, Polygon
import jax.numpy as jnp
# 2D rectangular domain: centre (1.0, 0.5), 2.0 wide, 1.0 tall
rectangle = Rectangle(center=jnp.array([1.0, 0.5]), width=2.0, height=1.0)

# Circular domain: unit circle at the origin
circle = Circle(center=jnp.array([0.0, 0.0]), radius=1.0)

# Polygonal domain (simplified airfoil). Vertices run from the trailing edge
# over the upper surface to the leading edge, then back along the lower surface.
airfoil_vertices = jnp.array([
    (1.0, 0.0),    # trailing edge
    (0.8, 0.1),    # upper surface
    (0.4, 0.15),
    (0.0, 0.05),   # leading edge
    (0.4, -0.1),   # lower surface
    (0.8, -0.05),
])
airfoil = Polygon(vertices=airfoil_vertices)

print("Basic geometric domains configured")
Complex Geometric Operations¶
from opifex.geometry import union, intersection, difference
# Constructive solid geometry (CSG): build complex domains from primitives.
# Primitive shapes for the slotted annulus:
outer_circle = Circle(center=jnp.array([0.0, 0.0]), radius=2.0)
inner_circle = Circle(center=jnp.array([0.0, 0.0]), radius=0.5)
rectangular_slot = Rectangle(center=jnp.array([0.0, 0.0]), width=0.4, height=6.0)

# Annulus (outer minus inner disc), then cut the slot through it
annular_region = difference(outer_circle, inner_circle)
slotted_annulus = difference(annular_region, rectangular_slot)

# Perforated plate for heat transfer: three equal holes along the x-axis
base_plate = Rectangle(center=jnp.array([0.0, 0.0]), width=4.0, height=2.0)
hole_x_positions = [-1.0, 0.0, 1.0]
holes = [
    Circle(center=jnp.array([hole_x, 0.0]), radius=0.2)
    for hole_x in hole_x_positions
]

perforated_plate = base_plate
for hole_shape in holes:
    perforated_plate = difference(perforated_plate, hole_shape)

print("Complex CSG domains created")
Adaptive and Multi-Resolution Domains¶
class AdaptiveDomain:
    """Point-cloud domain with a simple gradient-based refinement indicator.

    Args:
        base_geometry: 2-D geometry. It must provide ``bounding_box()``,
            returning ``[(x_min, x_max), (y_min, y_max)]``, and
            ``contains(points)``, returning a boolean mask over the points.
        initial_resolution: Number of grid points per axis for the initial mesh.
    """

    def __init__(self, base_geometry, initial_resolution=32):
        self.base_geometry = base_geometry
        self.resolution = initial_resolution
        # One entry is appended per refine_mesh call that flags any location
        self.refinement_levels = []

    def create_initial_mesh(self):
        """Return the points of a uniform resolution x resolution grid that lie inside the geometry."""
        bounds = self.base_geometry.bounding_box()
        x_min, x_max = bounds[0]
        y_min, y_max = bounds[1]

        x = jnp.linspace(x_min, x_max, self.resolution)
        y = jnp.linspace(y_min, y_max, self.resolution)
        X, Y = jnp.meshgrid(x, y, indexing="ij")
        points = jnp.stack([X.flatten(), Y.flatten()], axis=1)  # (resolution**2, 2)

        # Keep only points inside geometry
        inside_mask = self.base_geometry.contains(points)
        return points[inside_mask]

    def refine_mesh(self, solution, error_threshold=1e-3):
        """Flag locations whose gradient magnitude exceeds ``error_threshold``.

        Args:
            solution: Solution values on a regular grid (any number of axes).
            error_threshold: Gradient-magnitude threshold for refinement.

        Returns:
            True if any location was flagged (and a refinement level was
            recorded), otherwise False.
        """
        gradients = jnp.gradient(solution)
        # jnp.gradient returns a single array for 1-D input and a list (one
        # array per axis) otherwise. Stack so the norm runs across axes and
        # gives one magnitude per grid point.
        if not isinstance(gradients, (list, tuple)):
            gradients = [gradients]
        error_indicator = jnp.linalg.norm(jnp.stack(gradients), axis=0)

        refine_mask = error_indicator > error_threshold
        if not jnp.any(refine_mask):
            return False

        self.refinement_levels.append(self._local_refinement(refine_mask))
        return True

    def _local_refinement(self, refine_mask):
        """Default refinement hook: return the grid indices of flagged locations.

        Override to insert new points around the flagged locations.
        """
        return jnp.argwhere(refine_mask)


print("Adaptive domains implemented")
Graph Domains¶
For problems on irregular structures, networks, and discrete systems, the framework supports graph-based domains.
Network Structures¶
from opifex.geometry.topology import GraphTopology
import jax.numpy as jnp
# Create molecular graph domain
def create_molecular_graph_domain(positions, atomic_numbers, cutoff_radius=3.0):
    """Build a molecular GraphTopology from atom positions and atomic numbers.

    Two atoms are connected when their distance is strictly between 0 and
    ``cutoff_radius``. Each connected pair appears once in each direction.

    Args:
        positions: Array (or nested list) of shape (n_atoms, d) with atom
            coordinates. Presumably d = 3 for molecules.
        atomic_numbers: Length-n_atoms array (or list) of atomic numbers.
        cutoff_radius: Distance cutoff for creating an edge.

    Returns:
        GraphTopology with
        - node features ``[atomic number, distance from origin, coordination number]``
        - edges as ``(row, col)`` index pairs
        - edge features ``[distance, displacement vector (d columns), exp(-distance)]``
    """
    positions = jnp.asarray(positions)
    atomic_numbers = jnp.asarray(atomic_numbers)

    # Pairwise distance matrix, shape (n_atoms, n_atoms)
    distances = jnp.linalg.norm(
        positions[:, None, :] - positions[None, :, :], axis=2
    )

    # `distances > 0` drops self-pairs (and any exactly coincident atoms)
    edge_mask = (distances < cutoff_radius) & (distances > 0)
    edge_indices = jnp.where(edge_mask)

    # Node features (atomic properties)
    node_features = jnp.column_stack([
        atomic_numbers.astype(float),               # Atomic number
        jnp.linalg.norm(positions, axis=1),         # Distance from origin
        jnp.sum(edge_mask, axis=1).astype(float),   # Coordination number
    ])

    # Boolean masking visits entries in the same row-major order as jnp.where,
    # so edge_distances[k] belongs to edge (edge_indices[0][k], edge_indices[1][k]).
    edge_distances = distances[edge_mask]
    edge_vectors = positions[edge_indices[1]] - positions[edge_indices[0]]
    edge_features = jnp.column_stack([
        edge_distances[:, None],
        edge_vectors,
        jnp.exp(-edge_distances[:, None]),  # Exponential decay
    ])

    return GraphTopology(
        nodes=node_features,
        edges=jnp.stack(edge_indices, axis=1),
        edge_features=edge_features,
        domain_type="molecular",
    )


print("Graph domains created for molecular systems")
Irregular Connectivity Patterns¶
# Irregular mesh connectivity
class IrregularMeshDomain:
    """Unstructured mesh described by vertex coordinates and element connectivity.

    Args:
        vertices: Vertex coordinates, one row per vertex.
        elements: Iterable of elements, each a sequence of vertex indices.
        boundary_markers: Optional mapping from a boundary name to the vertex
            indices on that boundary.
    """

    def __init__(self, vertices, elements, boundary_markers=None):
        self.vertices = vertices
        self.elements = elements  # Connectivity: vertex indices per element
        self.boundary_markers = boundary_markers or {}

        # Compute mesh properties
        self.adjacency_matrix = self._compute_adjacency()
        self.element_areas = self._compute_element_areas()

    def _compute_adjacency(self):
        """Return the symmetric (float) vertex adjacency matrix.

        Entry (i, j) is 1 when vertices i and j share at least one element.
        """
        from itertools import combinations

        n_vertices = len(self.vertices)
        rows, cols = [], []
        for element in self.elements:
            # Connect every pair of vertices in the element, in both directions
            for v1, v2 in combinations([int(v) for v in element], 2):
                rows.extend((v1, v2))
                cols.extend((v2, v1))

        adjacency = jnp.zeros((n_vertices, n_vertices))
        if rows:
            # A single scatter, instead of two full-array copies per vertex pair
            adjacency = adjacency.at[jnp.array(rows), jnp.array(cols)].set(1)
        return adjacency

    def _compute_element_areas(self):
        """Return the area of each element using the shoelace formula.

        Assumes planar elements whose vertices are listed in order around the
        element (always true for triangles). Only the first two coordinate
        columns are used.
        """
        coords = jnp.asarray(self.vertices)
        areas = []
        for element in self.elements:
            polygon = coords[jnp.asarray([int(v) for v in element], dtype=jnp.int32)]
            xs, ys = polygon[:, 0], polygon[:, 1]
            twice_signed_area = jnp.dot(xs, jnp.roll(ys, -1)) - jnp.dot(ys, jnp.roll(xs, -1))
            areas.append(0.5 * jnp.abs(twice_signed_area))
        return jnp.array(areas)

    def get_boundary_vertices(self, marker=None):
        """Return the vertices on boundary ``marker``, or on all boundaries if None.

        An unknown marker gives an empty list. For ``marker=None`` the result is
        the sorted union of all marked vertex indices as Python ints.
        """
        if marker is not None:
            return self.boundary_markers.get(marker, [])
        # int() makes JAX/NumPy scalars hashable; sorting gives a stable order
        return sorted({
            int(v)
            for marker_vertices in self.boundary_markers.values()
            for v in marker_vertices
        })


print("Irregular connectivity patterns implemented")
Dynamic Graphs¶
# Graph domain whose structure changes over time
class DynamicGraphDomain:
    """Graph domain updated step by step by a sequence of evolution rules.

    Args:
        initial_graph: Graph object; it must provide ``copy()``.
        evolution_rules: Iterable of rule objects with
            ``apply(graph, dt, current_time)`` returning the updated graph.
    """

    def __init__(self, initial_graph, evolution_rules):
        self.current_graph = initial_graph
        self.evolution_rules = evolution_rules
        # Entry 0 is the initial graph; each evolve() call appends one more
        self.time_history = [initial_graph]

    def evolve(self, dt, current_time):
        """Apply every rule, in order, to a copy of the current graph.

        The result becomes the current graph, is appended to the history,
        and is returned.
        """
        snapshot = self.current_graph.copy()
        for evolution_rule in self.evolution_rules:
            snapshot = evolution_rule.apply(snapshot, dt, current_time)

        self.current_graph = snapshot
        self.time_history.append(snapshot)
        return snapshot

    def get_graph_at_time(self, time_index):
        """Return the graph stored at position ``time_index`` of the history."""
        history = self.time_history
        return history[time_index]


print("Dynamic graph domains implemented")
Best Practices and Guidelines¶
Problem Definition Checklist¶
- **Domain Specification**
  - Ensure domain bounds are physically meaningful
  - Check for proper boundary condition coverage
  - Validate initial conditions for time-dependent problems
- **Parameter Validation**
  - Implement parameter bounds checking
  - Use dimensionally consistent units
  - Document the physical meaning of each parameter
- **Numerical Stability**
  - Consider CFL conditions for time-dependent problems
  - Implement adaptive time stepping when needed
  - Use appropriate boundary condition types
- **Testing and Validation**
  - Compare against analytical solutions when available
  - Use the method of manufactured solutions for verification
  - Perform convergence studies
Performance Optimization¶
# Use JAX transformations for performance
from functools import partial


# `problem` is a Python object, not an array, so it must be a static argument;
# otherwise jax.jit raises a TypeError when it tries to trace it.
# NOTE: static arguments are cached by hash/equality (identity, for objects
# without custom __hash__/__eq__). Mutating problem.parameters after the first
# call does not trigger recompilation.
@partial(jax.jit, static_argnums=0)
def optimized_residual_computation(problem, x, u, u_derivatives):
    """JIT-compiled residual computation.

    Returns ``problem.residual(x, u, u_derivatives)``.
    """
    return problem.residual(x, u, u_derivatives)


# Vectorized parameter studies.
# NOTE(review): jax.vmap traces this function, so create_problem_with_params
# and solve_problem (defined elsewhere) must be JAX-traceable.
@jax.vmap
def solve_parameter_sweep(problem_params):
    """Vectorized solution over parameter space (batched along the leading axis)."""
    problem = create_problem_with_params(problem_params)
    return solve_problem(problem)


# Memory-efficient large-scale problems
def chunked_problem_solve(problem, chunk_size=1000, chunk_solver=None):
    """Solve large problems in chunks to manage memory.

    Args:
        problem: Object providing ``generate_domain_points()``.
        chunk_size: Maximum number of points per chunk. Must be positive.
        chunk_solver: Callable ``(problem, chunk) -> solution``. Defaults to
            ``solve_chunk``.

    Returns:
        Per-chunk solutions concatenated along the first axis.

    Raises:
        ValueError: If ``chunk_size`` is not positive or the problem has no
            domain points.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if chunk_solver is None:
        chunk_solver = solve_chunk

    domain_points = problem.generate_domain_points()
    n_points = len(domain_points)
    if n_points == 0:
        raise ValueError("problem.generate_domain_points() returned no points")

    # Step through every point; the last chunk may be shorter than chunk_size.
    # Using n_points // chunk_size chunks would silently drop the remainder.
    solutions = [
        chunk_solver(problem, domain_points[start:start + chunk_size])
        for start in range(0, n_points, chunk_size)
    ]
    return jnp.concatenate(solutions)
This guide covers the foundations of defining and working with scientific problems in the Opifex framework. The unified interface lets the same problem definitions be used with neural networks, traditional solvers, and advanced optimization techniques, while keeping the flexibility needed for modern scientific machine learning research.