Geometry & Computational Domains Guide¶
Overview¶
The Opifex geometry framework provides full geometric modeling capabilities for scientific machine learning applications. Built on JAX for high-performance computation, it supports 2D/3D domain handling, constructive solid geometry (CSG) operations, Lie groups, Riemannian manifolds, graph neural networks, and molecular geometry modeling.
This system is designed to handle complex geometric problems in scientific computing, from simple rectangular domains to advanced manifold-based neural operators and molecular systems with quantum mechanical constraints.
Core Geometric Primitives¶
2D Basic Shapes¶
The framework provides fundamental 2D shapes with full geometric operations:
import jax
import jax.numpy as jnp
from opifex.geometry import Rectangle, Circle, Polygon
# Rectangle with center and dimensions
rect = Rectangle(
center=jnp.array([0.0, 0.0]),
width=2.0,
height=1.5
)
# Circle with center and radius
circle = Circle(
center=jnp.array([1.0, 0.5]),
radius=0.8
)
# Polygon from vertices (counterclockwise ordering)
vertices = jnp.array([
[-1.0, -1.0],
[1.0, -1.0],
[0.5, 1.0],
[-0.5, 1.0]
])
polygon = Polygon(vertices=vertices)
# Basic geometric properties
rect_area = rect.width * rect.height
circle_area = jnp.pi * circle.radius**2
print(f"Rectangle area: {rect_area:.4f}")
print(f"Circle area: {circle_area:.4f}")
Point Containment and Distance Functions¶
All shapes support efficient point containment testing and signed distance functions:
# Test points for containment
test_points = jnp.array([
[0.0, 0.0], # Center of rectangle
[1.0, 0.5], # Center of circle
[2.0, 2.0], # Outside both
[0.5, 0.25] # Potential intersection
])
# Point containment (vectorized)
rect_contains = rect.contains(test_points)
circle_contains = circle.contains(test_points)
# Signed distance functions
rect_distances = jnp.array([rect.distance(pt) for pt in test_points])
circle_distances = jnp.array([circle.distance(pt) for pt in test_points])
print("Point containment and distances:")
for i, point in enumerate(test_points):
print(f"Point {point}:")
print(f" Rectangle: contains={rect_contains[i]}, distance={rect_distances[i]:.3f}")
print(f" Circle: contains={circle_contains[i]}, distance={circle_distances[i]:.3f}")
Boundary Sampling and Normal Computation¶
# Sample points on shape boundaries
key = jax.random.PRNGKey(42)
rect_boundary = rect.sample_boundary(n_points=50, key=key)
circle_boundary = circle.sample_boundary(n_points=50, key=key)
# Compute outward normals at boundary points
rect_normals = jnp.array([rect.compute_normal(pt) for pt in rect_boundary])
circle_normals = jnp.array([circle.compute_normal(pt) for pt in circle_boundary])
print(f"Sampled {len(rect_boundary)} rectangle boundary points")
print(f"Sampled {len(circle_boundary)} circle boundary points")
print(f"Normal vectors computed for boundary analysis")
Constructive Solid Geometry (CSG)¶
CSG operations enable complex shape creation through boolean operations:
Basic CSG Operations¶
from opifex.geometry import union, intersection, difference
# Create base shapes
base_rect = Rectangle(center=jnp.array([0.0, 0.0]), width=2.0, height=2.0)
cutout_circle = Circle(center=jnp.array([0.5, 0.5]), radius=0.6)
# Boolean operations
union_shape = union(base_rect, cutout_circle) # A ∪ B
intersection_shape = intersection(base_rect, cutout_circle) # A ∩ B
difference_shape = difference(base_rect, cutout_circle) # A - B
# Test complex shape properties
test_point = jnp.array([0.3, 0.3])
print(f"Point {test_point} containment:")
print(f" Union: {union_shape.contains(test_point)}")
print(f" Intersection: {intersection_shape.contains(test_point)}")
print(f" Difference: {difference_shape.contains(test_point)}")
Advanced CSG Compositions¶
# Create complex geometries through composition
outer_boundary = Circle(center=jnp.array([0.0, 0.0]), radius=2.0)
inner_hole = Circle(center=jnp.array([0.0, 0.0]), radius=0.8)
rectangular_slot = Rectangle(center=jnp.array([0.0, 0.0]), width=0.4, height=3.0)
# Annular region with rectangular slot
annular_region = difference(outer_boundary, inner_hole)
slotted_annulus = difference(annular_region, rectangular_slot)
# Multi-hole geometry
holes = [
Circle(center=jnp.array([0.8, 0.8]), radius=0.2),
Circle(center=jnp.array([-0.8, 0.8]), radius=0.2),
Circle(center=jnp.array([0.8, -0.8]), radius=0.2),
Circle(center=jnp.array([-0.8, -0.8]), radius=0.2)
]
multi_hole_plate = base_rect
for hole in holes:
multi_hole_plate = difference(multi_hole_plate, hole)
print("Complex CSG geometries created successfully")
Smooth CSG with SDF Operations¶
The framework uses signed distance functions (SDFs) for smooth, differentiable CSG operations:
from opifex.geometry import union, intersection, difference
# Use the public CSG functions for smooth SDF-based operations
def smooth_union_distance(point, shape1, shape2):
"""Union via the public CSG API (SDF-based internally)."""
combined = union(shape1, shape2)
return combined.distance(point)
def smooth_intersection_distance(point, shape1, shape2):
"""Intersection via the public CSG API (SDF-based internally)."""
combined = intersection(shape1, shape2)
return combined.distance(point)
# Example: Blending between shapes
blend_point = jnp.array([0.5, 0.0])
smooth_dist = smooth_union_distance(blend_point, base_rect, cutout_circle)
print(f"Union distance at {blend_point}: {smooth_dist:.4f}")
Molecular Geometry and 3D Systems¶
Molecular System Definition¶
from opifex.geometry import MolecularGeometry
# Define a water molecule (H2O) in atomic units
water_positions = jnp.array([
[0.0000, 0.0000, 0.1173], # Oxygen
[0.0000, 0.7572, -0.4692], # Hydrogen 1
[0.0000, -0.7572, -0.4692] # Hydrogen 2
])
water_molecule = MolecularGeometry(
atomic_symbols=["O", "H", "H"],
positions=water_positions
)
print(f"Water molecule properties:")
print(f" Number of atoms: {water_molecule.n_atoms}")
print(f" Pairwise distances: {water_molecule.compute_distances()}")
Periodic Systems and Crystal Structures¶
from opifex.geometry import PeriodicCell
# Define a cubic unit cell
lattice_vectors = jnp.array([
[5.0, 0.0, 0.0], # a vector
[0.0, 5.0, 0.0], # b vector
[0.0, 0.0, 5.0] # c vector
])
# Create periodic cell (takes only lattice vectors)
unit_cell = PeriodicCell(lattice_vectors=lattice_vectors)
# Wrap coordinates into the unit cell
positions = jnp.array([
[0.0, 0.0, 0.0], # Atom at origin
[2.5, 2.5, 2.5] # Atom at body center
])
wrapped = unit_cell.wrap_coordinates(positions)
print(f"Unit cell volume: {unit_cell.volume:.4f}")
print(f"Wrapped positions: {wrapped}")
Molecular Exclusion Domains¶
from opifex.geometry import create_computational_domain_with_molecular_exclusion
# Use the MolecularGeometry directly for exclusion domain
box_size = 10.0 # Atomic units
computational_domain = create_computational_domain_with_molecular_exclusion(
molecular_geometry=water_molecule,
box_dimensions=jnp.array([box_size, box_size, box_size]),
exclusion_radius=2.0, # Exclude within 2 a.u. of atoms
buffer_zone=1.0 # Additional buffer for numerical stability
)
print("Molecular exclusion domain created for quantum calculations")
Advanced Manifolds and Differential Geometry¶
Riemannian Manifolds¶
from opifex.geometry.manifolds import SphericalManifold, TangentSpace
# Create spherical manifold for geometric deep learning
sphere_manifold = SphericalManifold(dim=2) # 2-sphere (surface of 3D ball)
# Sample points on the manifold
key = jax.random.PRNGKey(123)
manifold_points = sphere_manifold.sample_uniform(n_points=100, key=key)
# Compute tangent spaces at sampled points
tangent_spaces = [
TangentSpace(manifold=sphere_manifold, base_point=point)
for point in manifold_points[:5] # First 5 points
]
# Manifold operations
def parallel_transport_vector(manifold, vector, start_point, end_point):
"""Parallel transport a vector along the manifold."""
# Simplified parallel transport for sphere
# In practice, this would use proper Riemannian geometry
return vector - jnp.dot(vector, end_point) * end_point
# Example: Transport vectors between points
start_point = manifold_points[0]
end_point = manifold_points[1]
tangent_vector = jnp.array([0.1, 0.2, 0.0]) # Tangent at start_point
transported_vector = parallel_transport_vector(
sphere_manifold, tangent_vector, start_point, end_point
)
print(f"Manifold points sampled: {len(manifold_points)}")
print(f"Tangent spaces computed: {len(tangent_spaces)}")
print(f"Vector transport completed")
Lie Groups and Algebraic Structures¶
from opifex.geometry.algebra import SO3Group, SE3Group
# Special Orthogonal Group SO(3) - 3D rotations
so3_group = SO3Group()
# Generate random rotation matrices
rotation_key = jax.random.PRNGKey(456)
random_rotations = so3_group.sample_uniform(n_samples=10, key=rotation_key)
# Compose rotations (group operation)
R1 = random_rotations[0]
R2 = random_rotations[1]
composed_rotation = so3_group.compose(R1, R2)
# Compute group inverse
R1_inverse = so3_group.inverse(R1)
# Verify group properties
identity_check = so3_group.compose(R1, R1_inverse)
print(f"Group identity verification (should be close to I):")
print(f"Max deviation from identity: {jnp.max(jnp.abs(identity_check - jnp.eye(3))):.6f}")
# Special Euclidean Group SE(3) - 3D rigid transformations
se3_group = SE3Group()
# Create transformation matrices (rotation + translation)
translation = jnp.array([1.0, 2.0, 3.0])
transformation = se3_group.from_rotation_translation(R1, translation)
# Apply transformation to points
points_3d = jnp.array([
[0.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0]
])
transformed_points = se3_group.apply_transformation(transformation, points_3d)
print(f"Applied SE(3) transformation to {len(points_3d)} points")
Graph Neural Networks and Topology¶
Graph Structures for Scientific Computing¶
from opifex.geometry.topology import GraphTopology
# Create graph from molecular structure
def create_molecular_graph(positions, atomic_numbers, cutoff_radius=3.0):
"""Create molecular graph with distance-based edges."""
n_atoms = len(positions)
# Compute pairwise distances
distances = jnp.linalg.norm(
positions[:, None, :] - positions[None, :, :], axis=2
)
# Create edges for atoms within cutoff
edge_mask = (distances < cutoff_radius) & (distances > 0)
edge_indices = jnp.where(edge_mask)
# Edge features (distances and relative positions)
edge_distances = distances[edge_mask]
edge_vectors = (
positions[edge_indices[1]] - positions[edge_indices[0]]
)
return GraphTopology(
nodes=atomic_numbers.astype(float), # Node features: atomic numbers
edges=jnp.stack(edge_indices, axis=1), # Edge connectivity
edge_features=jnp.column_stack([
edge_distances[:, None],
edge_vectors
])
)
# Create molecular graph for water
atomic_numbers = jnp.array([8, 1, 1]) # O, H, H
molecular_graph = create_molecular_graph(
water_positions, atomic_numbers, cutoff_radius=2.0
)
print(f"Molecular graph created:")
print(f" Nodes: {molecular_graph.nodes.shape}")
print(f" Edges: {molecular_graph.edges.shape}")
print(f" Edge features: {molecular_graph.edge_features.shape}")
Graph Neural Operators¶
from opifex.neural.operators.graph import GraphNeuralOperator
from flax import nnx
rngs = nnx.Rngs(jax.random.PRNGKey(0))
# Create graph neural operator for molecular property prediction
graph_operator = GraphNeuralOperator(
node_dim=molecular_graph.nodes.shape[-1],
hidden_dim=64,
num_layers=3,
edge_dim=molecular_graph.edge_features.shape[-1],
rngs=rngs
)
print("Graph neural operator initialized for molecular ML")
Topological Spaces and Simplicial Complexes¶
from opifex.geometry.topology import SimplicialComplex, TopologicalSpace
# Create simplicial complex for topological data analysis
vertices = jnp.array([
[0.0, 0.0], [1.0, 0.0], [0.5, 1.0], # Triangle vertices
[1.5, 0.5], [2.0, 1.0] # Additional vertices
])
# Define simplices (0-simplices: vertices, 1-simplices: edges, 2-simplices: faces)
simplices = {
0: jnp.arange(len(vertices)), # All vertices
1: jnp.array([[0, 1], [1, 2], [2, 0], [1, 3], [3, 4]]), # Edges
2: jnp.array([[0, 1, 2]]) # Triangle face
}
simplicial_complex = SimplicialComplex(
vertices=vertices,
simplices=simplices
)
# Compute topological properties
betti_numbers = simplicial_complex.compute_betti_numbers()
euler_characteristic = simplicial_complex.euler_characteristic()
print(f"Topological analysis:")
print(f" Betti numbers: {betti_numbers}")
print(f" Euler characteristic: {euler_characteristic}")
Domain Discretization and Mesh Generation¶
Structured Grid Generation¶
def create_structured_grid(domain_bounds, resolution):
"""Create structured Cartesian grid."""
x_min, x_max = domain_bounds[0]
y_min, y_max = domain_bounds[1]
x = jnp.linspace(x_min, x_max, resolution[0])
y = jnp.linspace(y_min, y_max, resolution[1])
X, Y = jnp.meshgrid(x, y, indexing='ij')
grid_points = jnp.stack([X.flatten(), Y.flatten()], axis=1)
return grid_points, (X, Y)
# Create grid for rectangular domain
domain_bounds = [(-1.0, 1.0), (-1.0, 1.0)]
resolution = [64, 64]
grid_points, (X, Y) = create_structured_grid(domain_bounds, resolution)
print(f"Structured grid created: {grid_points.shape[0]} points")
Adaptive Mesh Refinement¶
def adaptive_refinement(geometry, initial_resolution=32, max_levels=3):
"""Adaptive mesh refinement based on geometry complexity."""
def refinement_criterion(points):
"""Refine near boundaries and complex regions."""
distances = jnp.array([geometry.distance(pt) for pt in points])
return jnp.abs(distances) < 0.1 # Refine near boundaries
# Start with coarse grid
current_points, _ = create_structured_grid(
[(-2.0, 2.0), (-2.0, 2.0)], [initial_resolution, initial_resolution]
)
refined_points = []
for level in range(max_levels):
# Identify points needing refinement
refine_mask = refinement_criterion(current_points)
# Keep non-refined points
refined_points.extend(current_points[~refine_mask])
# Refine marked regions
if jnp.any(refine_mask):
refine_centers = current_points[refine_mask]
# Add finer points around each center
for center in refine_centers:
local_spacing = 2.0 / (initial_resolution * (2 ** (level + 1)))
local_points = center + local_spacing * jnp.array([
[-0.5, -0.5], [0.5, -0.5], [-0.5, 0.5], [0.5, 0.5]
])
refined_points.extend(local_points)
return jnp.array(refined_points)
# Apply adaptive refinement to complex geometry
refined_mesh = adaptive_refinement(slotted_annulus, initial_resolution=16, max_levels=2)
print(f"Adaptive mesh created: {len(refined_mesh)} points")
Unstructured Mesh Generation¶
def delaunay_triangulation_2d(points):
"""Simple Delaunay triangulation for 2D points."""
# This is a simplified version - in practice, use scipy.spatial.Delaunay
# or specialized mesh generation libraries
from scipy.spatial import Delaunay
import numpy as np
# Convert JAX arrays to numpy for scipy
points_np = np.array(points)
tri = Delaunay(points_np)
# Convert back to JAX arrays
triangles = jnp.array(tri.simplices)
return triangles
def generate_boundary_conforming_mesh(geometry, target_edge_length=0.1):
"""Generate mesh that conforms to geometry boundaries."""
# Sample boundary points
key = jax.random.PRNGKey(789)
boundary_points = geometry.sample_boundary(
n_points=int(2 * jnp.pi / target_edge_length), key=key
)
# Add interior points
bbox_min = jnp.min(boundary_points, axis=0) - 0.5
bbox_max = jnp.max(boundary_points, axis=0) + 0.5
# Generate candidate interior points
n_interior = 1000
interior_candidates = jax.random.uniform(
key, (n_interior, 2), minval=bbox_min, maxval=bbox_max
)
# Keep only points inside geometry
inside_mask = geometry.contains(interior_candidates)
interior_points = interior_candidates[inside_mask]
# Combine boundary and interior points
all_points = jnp.vstack([boundary_points, interior_points])
# Generate triangulation
triangles = delaunay_triangulation_2d(all_points)
return all_points, triangles
# Generate mesh for complex geometry
mesh_points, mesh_triangles = generate_boundary_conforming_mesh(
slotted_annulus, target_edge_length=0.05
)
print(f"Unstructured mesh generated:")
print(f" Vertices: {len(mesh_points)}")
print(f" Triangles: {len(mesh_triangles)}")
Coordinate Systems and Transformations¶
Coordinate System Transformations¶
def cartesian_to_polar(x, y):
"""Convert Cartesian to polar coordinates."""
r = jnp.sqrt(x**2 + y**2)
theta = jnp.arctan2(y, x)
return r, theta
def polar_to_cartesian(r, theta):
"""Convert polar to Cartesian coordinates."""
x = r * jnp.cos(theta)
y = r * jnp.sin(theta)
return x, y
def cartesian_to_spherical(x, y, z):
"""Convert Cartesian to spherical coordinates."""
r = jnp.sqrt(x**2 + y**2 + z**2)
theta = jnp.arccos(z / (r + 1e-10)) # Polar angle
phi = jnp.arctan2(y, x) # Azimuthal angle
return r, theta, phi
# Example coordinate transformations
cartesian_points = jnp.array([
[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]
])
polar_coords = jnp.array([
cartesian_to_polar(pt[0], pt[1]) for pt in cartesian_points
])
print("Coordinate transformations:")
for i, (cart, polar) in enumerate(zip(cartesian_points, polar_coords)):
print(f" Point {i}: ({cart[0]:.1f}, {cart[1]:.1f}) → (r={polar[0]:.3f}, θ={polar[1]:.3f})")
Geometric Transformations¶
def create_transformation_matrix_2d(translation, rotation_angle, scale):
"""Create 2D transformation matrix."""
cos_theta = jnp.cos(rotation_angle)
sin_theta = jnp.sin(rotation_angle)
# Homogeneous transformation matrix
T = jnp.array([
[scale[0] * cos_theta, -scale[0] * sin_theta, translation[0]],
[scale[1] * sin_theta, scale[1] * cos_theta, translation[1]],
[0.0, 0.0, 1.0]
])
return T
def apply_transformation_2d(points, transformation_matrix):
"""Apply 2D transformation to points."""
# Convert to homogeneous coordinates
homogeneous_points = jnp.column_stack([points, jnp.ones(len(points))])
# Apply transformation
transformed_homogeneous = homogeneous_points @ transformation_matrix.T
# Convert back to Cartesian coordinates
return transformed_homogeneous[:, :2]
# Example: Transform a square
square_vertices = jnp.array([
[-0.5, -0.5], [0.5, -0.5], [0.5, 0.5], [-0.5, 0.5]
])
# Create transformation: translate, rotate 45°, scale by 2
transform_matrix = create_transformation_matrix_2d(
translation=jnp.array([1.0, 1.0]),
rotation_angle=jnp.pi / 4,
scale=jnp.array([2.0, 2.0])
)
transformed_square = apply_transformation_2d(square_vertices, transform_matrix)
print("Geometric transformation applied:")
print(f"Original square vertices: {square_vertices.shape}")
print(f"Transformed square vertices: {transformed_square.shape}")
Integration with Physics Problems¶
Domain Definition for PDE Problems¶
from opifex.core.problems import PDEProblem
from opifex.core.conditions import DirichletBC, NeumannBC
class ComplexDomainPDEProblem(PDEProblem):
"""PDE problem on complex geometric domain."""
def __init__(self, geometry, physics_parameters):
# Define boundary conditions based on geometry
boundary_conditions = [
DirichletBC(boundary="wall", value=1.0),
NeumannBC(boundary="symmetry", value=0.0)
]
super().__init__(
geometry=geometry,
equation=self._heat_equation_with_geometry,
boundary_conditions=boundary_conditions,
parameters=physics_parameters
)
def _heat_equation_with_geometry(self, x, u, u_derivatives):
"""Heat equation with geometry-dependent source term."""
alpha = self.parameters["diffusivity"]
u_t = u_derivatives["t"]
u_xx = u_derivatives["xx"]
u_yy = u_derivatives["yy"]
# Geometry-dependent source term
distance_to_boundary = self.geometry.distance(x[..., :2])
source_term = jnp.exp(-distance_to_boundary**2)
return u_t - alpha * (u_xx + u_yy) - source_term
def generate_collocation_points(self, n_points, key):
"""Generate physics-informed collocation points."""
# Sample points inside the geometry
bbox_min = jnp.array([-2.0, -2.0])
bbox_max = jnp.array([2.0, 2.0])
candidates = jax.random.uniform(
key, (n_points * 3, 2), minval=bbox_min, maxval=bbox_max
)
# Keep only points inside geometry
inside_mask = self.geometry.contains(candidates)
interior_points = candidates[inside_mask][:n_points]
return interior_points
# Create PDE problem with complex geometry
complex_pde = ComplexDomainPDEProblem(
geometry=slotted_annulus,
physics_parameters={"diffusivity": 0.01}
)
# Generate collocation points for PINN training
key = jax.random.PRNGKey(999)
collocation_points = complex_pde.generate_collocation_points(1000, key)
print(f"Complex domain PDE problem created")
print(f"Generated {len(collocation_points)} collocation points")
Molecular Geometry for Quantum Problems¶
from opifex.core.problems import QuantumProblem
class MolecularQuantumProblem(QuantumProblem):
"""Quantum problem with molecular geometry constraints."""
def __init__(self, molecular_system, computational_domain):
self.molecular_system = molecular_system
self.computational_domain = computational_domain
super().__init__(
molecular_system=molecular_system,
method="neural_dft",
parameters={
"computational_domain": computational_domain,
"basis_cutoff": 10.0, # Atomic units
"grid_spacing": 0.1
}
)
def generate_grid_points(self, spacing=0.1):
"""Generate computational grid excluding molecular regions."""
# Create regular grid in computational domain
bounds = self.computational_domain.bounds
x = jnp.arange(bounds[0][0], bounds[0][1], spacing)
y = jnp.arange(bounds[1][0], bounds[1][1], spacing)
z = jnp.arange(bounds[2][0], bounds[2][1], spacing)
X, Y, Z = jnp.meshgrid(x, y, z, indexing='ij')
grid_points = jnp.stack([X.flatten(), Y.flatten(), Z.flatten()], axis=1)
# Exclude points too close to nuclei
valid_points = []
for point in grid_points:
min_distance = jnp.min(jnp.linalg.norm(
point - self.molecular_system.positions, axis=1
))
if min_distance > 0.5: # Minimum distance in atomic units
valid_points.append(point)
return jnp.array(valid_points)
# Create quantum problem with molecular geometry
quantum_problem = MolecularQuantumProblem(
molecular_system=water_molecule,
computational_domain=computational_domain
)
grid_points = quantum_problem.generate_grid_points(spacing=0.2)
print(f"Quantum grid generated: {len(grid_points)} points")
Performance Optimization and Best Practices¶
JAX Optimization Techniques¶
# JIT compilation for geometric operations
@jax.jit
def batch_distance_computation(geometry, points):
"""JIT-compiled batch distance computation."""
return jnp.array([geometry.distance(pt) for pt in points])
@jax.jit
def batch_containment_test(geometry, points):
"""JIT-compiled batch containment testing."""
return geometry.contains(points)
# Vectorized operations for performance
@jax.vmap
def vectorized_normal_computation(geometry, points):
"""Vectorized normal computation."""
return geometry.compute_normal(points)
# Example usage with performance timing
import time
large_point_set = jax.random.uniform(
jax.random.PRNGKey(1000), (10000, 2), minval=-2.0, maxval=2.0
)
# Time JIT-compiled operations
start_time = time.time()
distances = batch_distance_computation(circle, large_point_set)
jit_time = time.time() - start_time
print(f"JIT-compiled distance computation: {jit_time:.4f}s for {len(large_point_set)} points")
Memory-Efficient Geometry Operations¶
def chunked_geometry_operations(geometry, points, chunk_size=1000):
"""Process large point sets in chunks to manage memory."""
n_points = len(points)
n_chunks = (n_points + chunk_size - 1) // chunk_size
results = []
for i in range(n_chunks):
start_idx = i * chunk_size
end_idx = min((i + 1) * chunk_size, n_points)
chunk = points[start_idx:end_idx]
# Process chunk
chunk_distances = batch_distance_computation(geometry, chunk)
results.append(chunk_distances)
return jnp.concatenate(results)
# Process very large point set efficiently
very_large_points = jax.random.uniform(
jax.random.PRNGKey(1001), (50000, 2), minval=-3.0, maxval=3.0
)
chunked_distances = chunked_geometry_operations(
slotted_annulus, very_large_points, chunk_size=5000
)
print(f"Processed {len(very_large_points)} points in chunks")
print(f"Memory-efficient computation completed")
Geometry Caching and Precomputation¶
class CachedGeometry:
"""Geometry wrapper with caching for expensive operations."""
def __init__(self, base_geometry):
self.base_geometry = base_geometry
self._distance_cache = {}
self._normal_cache = {}
def distance(self, point):
"""Cached distance computation."""
point_key = tuple(point.tolist())
if point_key not in self._distance_cache:
self._distance_cache[point_key] = self.base_geometry.distance(point)
return self._distance_cache[point_key]
def compute_normal(self, point):
"""Cached normal computation."""
point_key = tuple(point.tolist())
if point_key not in self._normal_cache:
self._normal_cache[point_key] = self.base_geometry.compute_normal(point)
return self._normal_cache[point_key]
def clear_cache(self):
"""Clear all cached results."""
self._distance_cache.clear()
self._normal_cache.clear()
# Use cached geometry for repeated operations
cached_geometry = CachedGeometry(slotted_annulus)
# Repeated queries will be faster
test_point = jnp.array([0.5, 0.5])
for _ in range(100):
distance = cached_geometry.distance(test_point) # Cached after first call
print("Geometry caching implemented for performance optimization")
This full geometry guide provides the foundation for working with complex geometric problems in scientific machine learning. The unified framework supports everything from simple 2D domains to advanced manifold-based neural operators and quantum molecular systems, all optimized for high-performance computation with JAX.