Skip to content

Geometry API Reference

The opifex.geometry package provides tools for defining computational domains using Constructive Solid Geometry (CSG).

Shapes

Base Protocol

opifex.geometry.csg.Shape2D

Bases: Geometry, Protocol

Protocol for 2D geometric shapes.

contains abstractmethod

contains(point: Point2D) -> bool

Check if a point is contained within the shape.

Source code in opifex/geometry/csg.py
@abstractmethod
def contains(self, point: Point2D) -> bool:
    """Report whether ``point`` lies inside the shape.

    Protocol stub: concrete shapes must override it. The stub body is only
    this docstring, so calling it directly returns ``None``.
    """

compute_normal abstractmethod

compute_normal(point: Point2D) -> Point2D

Compute outward normal at a boundary point.

Source code in opifex/geometry/csg.py
@abstractmethod
def compute_normal(self, point: Point2D) -> Point2D:
    """Return the outward normal of the shape at boundary point ``point``.

    Protocol stub: concrete shapes must override it. The stub body is only
    this docstring, so calling it directly returns ``None``.
    """

distance abstractmethod

distance(point: Point2D) -> Float[Array, '']

Compute signed distance to shape boundary.

Source code in opifex/geometry/csg.py
@abstractmethod
def distance(self, point: Point2D) -> Float[jax.Array, ""]:
    """Return the signed distance from ``point`` to the shape boundary.

    Protocol stub: concrete shapes must override it. The stub body is only
    this docstring, so calling it directly returns ``None``.
    """

sample_boundary abstractmethod

sample_boundary(n: int, key: Array) -> Points2D

Sample points on the shape boundary.

Source code in opifex/geometry/csg.py
@abstractmethod
def sample_boundary(self, n: int, key: jax.Array) -> Points2D:
    """Return ``n`` points sampled on the shape boundary using PRNG ``key``.

    Protocol stub: concrete shapes must override it. The stub body is only
    this docstring, so calling it directly returns ``None``.
    """

Basic Shapes

opifex.geometry.csg.Rectangle

Rectangle(center: Point2D, width: float, height: float)

Bases: _EnhancedShapeBase

2D rectangle shape for computational domains.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `center` | `Point2D` | Center point of the rectangle | *required* |
| `width` | `float` | Width of the rectangle (must be positive) | *required* |
| `height` | `float` | Height of the rectangle (must be positive) | *required* |

Source code in opifex/geometry/csg.py
def __init__(self, center: Point2D, width: float, height: float):
    """
    Build an axis-aligned rectangle and cache its bounding coordinates.

    Args:
        center: Center point of the rectangle
        width: Width of the rectangle (must be positive)
        height: Height of the rectangle (must be positive)

    Raises:
        ValueError: If ``width``/``height`` are plain numbers and either is
            not positive. Array-valued sizes (anything with a ``shape``
            attribute) are stored unchanged and not validated, so traced
            values are never forced to concrete Python numbers.
    """
    self.center = jnp.asarray(center)

    array_sized = hasattr(width, "shape") or hasattr(height, "shape")
    if not array_sized:
        if width <= 0 or height <= 0:
            raise ValueError("Width and height must be positive")
        width, height = float(width), float(height)
    self.width = width
    self.height = height

    # Cache the edge coordinates; other methods read them directly.
    half_w = self.width / 2
    half_h = self.height / 2
    cx, cy = self.center[0], self.center[1]
    self.x_min, self.x_max = cx - half_w, cx + half_w
    self.y_min, self.y_max = cy - half_h, cy + half_h

contains

contains(point: Point2D) -> bool

Check if point is inside rectangle (inclusive of boundary).

Source code in opifex/geometry/csg.py
def contains(self, point: Point2D) -> bool:
    """Return True when ``point`` lies in the closed rectangle (edges count as inside)."""
    coords = jnp.asarray(point)
    x, y = coords[0], coords[1]
    within_x = (self.x_min <= x) & (x <= self.x_max)
    within_y = (self.y_min <= y) & (y <= self.y_max)
    return bool(within_x & within_y)

distance

distance(point: Point2D) -> Float[Array, '']

Compute signed distance to rectangle boundary (smooth and differentiable).

Source code in opifex/geometry/csg.py
def distance(self, point: Point2D) -> Float[jax.Array, ""]:
    """Compute the signed distance from ``point`` to the rectangle boundary.

    Negative inside, zero on the boundary, positive outside. The magnitude is
    the Euclidean distance to the nearest edge (the standard box SDF).

    Exact ``max``/``min`` are used on purpose. Smoothing the max (e.g. with
    ``logaddexp``) shifts the zero level set, so the sign would stop matching
    containment, and SDF-based CSG tests rely on that sign. The absolute value
    and the exterior norm use an ``eps``-regularised square root, so
    ``jax.grad`` stays finite at the center and everywhere inside. The result
    is differentiable almost everywhere.
    """
    point = jnp.asarray(point)
    eps = 1e-8

    def smooth_abs(x):
        # |x| with a finite (zero) gradient at x == 0
        return jnp.sqrt(x * x + eps * eps) - eps

    # Per-axis excess over the half-extent: <= 0 when inside along that axis
    q_x = smooth_abs(point[0] - self.center[0]) - self.width / 2
    q_y = smooth_abs(point[1] - self.center[1]) - self.height / 2

    # Exterior part: Euclidean norm of the positive excesses (0 when inside)
    o_x = jnp.maximum(q_x, 0.0)
    o_y = jnp.maximum(q_y, 0.0)
    outside_dist = jnp.sqrt(o_x * o_x + o_y * o_y + eps * eps) - eps

    # Interior part: negative distance to the nearest edge (0 when outside)
    inside_dist = jnp.minimum(jnp.maximum(q_x, q_y), 0.0)

    return jnp.asarray(outside_dist + inside_dist)

sample_boundary

sample_boundary(n: int, key: Array) -> Points2D

Sample points uniformly on rectangle boundary.

Source code in opifex/geometry/csg.py
def sample_boundary(self, n: int, key: jax.Array) -> Points2D:
    """Draw ``n`` points uniformly by arc length on the rectangle outline.

    Each point gets one uniform variate scaled to the perimeter. The variate
    is then walked from the bottom-left corner along the bottom edge, the
    right edge, the top edge (right to left) and finally the left edge
    (top to bottom).
    """
    w, h = self.width, self.height
    s = jax.random.uniform(key, (n,)) * (2 * (w + h))

    # Candidate point on every edge for every variate; the masks pick one.
    bottom = jnp.stack([self.x_min + s, jnp.full_like(s, self.y_min)], axis=1)
    right = jnp.stack([jnp.full_like(s, self.x_max), self.y_min + (s - w)], axis=1)
    top = jnp.stack([self.x_max - (s - w - h), jnp.full_like(s, self.y_max)], axis=1)
    left = jnp.stack([jnp.full_like(s, self.x_min), self.y_max - (s - 2 * w - h)], axis=1)

    on_bottom = (s < w)[:, None]
    on_right = ((s >= w) & (s < w + h))[:, None]
    on_top = ((s >= w + h) & (s < 2 * w + h))[:, None]

    chosen = jnp.where(on_bottom, bottom, jnp.where(on_right, right, jnp.where(on_top, top, left)))
    return jnp.asarray(chosen).reshape(n, 2)

sample_interior

sample_interior(n: int, key: Array) -> Points2D

Sample points uniformly from rectangle interior.

Source code in opifex/geometry/csg.py
def sample_interior(self, n: int, key: jax.Array) -> Points2D:
    """Draw ``n`` points uniformly from the rectangle; returns an ``(n, 2)`` array."""
    kx, ky = jax.random.split(key)
    xs = jax.random.uniform(kx, (n,), minval=self.x_min, maxval=self.x_max)
    ys = jax.random.uniform(ky, (n,), minval=self.y_min, maxval=self.y_max)
    return jnp.column_stack((xs, ys))

compute_normal

compute_normal(point: Point2D) -> Point2D

Compute outward normal at boundary point.

Source code in opifex/geometry/csg.py
def compute_normal(self, point: Point2D) -> Point2D:
    """Return the outward unit normal for a point on the rectangle outline.

    Edges are matched with ``jnp.isclose(..., atol=1e-6)`` in the order left,
    right, bottom, top, and the first match wins. A corner therefore gets the
    normal of its vertical edge. A point on no edge yields the zero vector.
    """
    p = jnp.asarray(point)
    tol = 1e-6
    table = jnp.array(
        [
            [-1.0, 0.0],  # left
            [1.0, 0.0],  # right
            [0.0, -1.0],  # bottom
            [0.0, 1.0],  # top
            [0.0, 0.0],  # not on the boundary
        ]
    )
    hits = jnp.stack(
        [
            jnp.isclose(p[0], self.x_min, atol=tol),
            jnp.isclose(p[0], self.x_max, atol=tol),
            jnp.isclose(p[1], self.y_min, atol=tol),
            jnp.isclose(p[1], self.y_max, atol=tol),
            jnp.array(True),
        ]
    )
    # argmax picks the first True entry, giving the left→right→bottom→top priority
    row = jnp.argmax(hits.astype(jnp.int32))
    return jnp.asarray(table[row]).reshape(2)

opifex.geometry.csg.Circle

Circle(center: Point2D, radius: float)

Bases: _EnhancedShapeBase

2D circle shape for computational domains.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `center` | `Point2D` | Center point of the circle | *required* |
| `radius` | `float` | Radius of the circle (must be positive) | *required* |

Source code in opifex/geometry/csg.py
def __init__(self, center: Point2D, radius: float):
    """
    Build a circle.

    Args:
        center: Center point of the circle
        radius: Radius of the circle (must be positive)

    Raises:
        ValueError: If ``radius`` is a plain number that is not positive.
            Array-valued radii (anything with a ``shape`` attribute) are
            stored unchanged and not validated, so tracers pass through.
    """
    self.center = jnp.asarray(center)
    if hasattr(radius, "shape"):
        self.radius = radius
        return
    if radius <= 0:
        raise ValueError("Radius must be positive")
    self.radius = float(radius)

contains

contains(point: Point2D) -> bool

Check if point is inside circle (inclusive of boundary).

Source code in opifex/geometry/csg.py
def contains(self, point: Point2D) -> bool:
    """Return True when ``point`` lies in the closed disk (boundary included)."""
    offset = jnp.asarray(point) - self.center
    squared_dist = jnp.sum(offset ** 2)
    return bool(squared_dist <= self.radius ** 2)

distance

distance(point: Point2D) -> Float[Array, '']

Compute signed distance to circle boundary (smooth and differentiable).

Source code in opifex/geometry/csg.py
def distance(self, point: Point2D) -> Float[jax.Array, ""]:
    """Signed distance to the circle: negative inside, zero on it, positive outside.

    The Euclidean norm is regularised as ``sqrt(|d|^2 + eps^2) - eps`` so the
    gradient stays finite when ``point`` coincides with the center.
    """
    eps = 1e-8
    delta = jnp.asarray(point) - self.center
    smooth_norm = jnp.sqrt(jnp.sum(delta * delta) + eps * eps) - eps
    return smooth_norm - self.radius

sample_boundary

sample_boundary(n: int, key: Array) -> Points2D

Sample points uniformly on circle boundary.

Source code in opifex/geometry/csg.py
def sample_boundary(self, n: int, key: jax.Array) -> Points2D:
    """Draw ``n`` points uniformly in angle on the circle; returns an ``(n, 2)`` array."""
    theta = jax.random.uniform(key, (n,)) * 2 * jnp.pi
    cx, cy = self.center[0], self.center[1]
    xs = cx + self.radius * jnp.cos(theta)
    ys = cy + self.radius * jnp.sin(theta)
    return jnp.stack([xs, ys], axis=1)

sample_interior

sample_interior(n: int, key: Array) -> Points2D

Sample points uniformly from circle interior.

Source code in opifex/geometry/csg.py
def sample_interior(self, n: int, key: jax.Array) -> Points2D:
    """Draw ``n`` points uniformly over the disk area; returns an ``(n, 2)`` array."""
    k_angle, k_radius = jax.random.split(key)
    theta = jax.random.uniform(k_angle, (n,), maxval=2 * jnp.pi)
    # Taking the sqrt of a uniform variate makes the density uniform in area
    # rather than in radius.
    rho = self.radius * jnp.sqrt(jax.random.uniform(k_radius, (n,)))
    xs = self.center[0] + rho * jnp.cos(theta)
    ys = self.center[1] + rho * jnp.sin(theta)
    return jnp.column_stack((xs, ys))

compute_normal

compute_normal(point: Point2D) -> Point2D

Compute outward normal at boundary point.

Source code in opifex/geometry/csg.py
def compute_normal(self, point: Point2D) -> Point2D:
    """Return the outward unit normal: the direction from the center to ``point``.

    At (or within 1e-12 of) the center the direction is undefined, and
    ``[1, 0]`` is returned instead.
    """
    radial = jnp.asarray(point) - self.center
    length = jnp.linalg.norm(radial)
    fallback = jnp.array([1.0, 0.0])
    return jnp.asarray(jnp.where(length > 1e-12, radial / length, fallback))

opifex.geometry.csg.Polygon

Polygon(vertices: Points2D)

Bases: _EnhancedShapeBase

2D polygon shape defined by vertices.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `vertices` | `Points2D` | Array of vertex coordinates, shape (N, 2) where N >= 3 | *required* |


Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If fewer than 3 vertices are provided |


Source code in opifex/geometry/csg.py
def __init__(self, vertices: Points2D):
    """
    Build a polygon from its vertex list.

    Args:
        vertices: Array of vertex coordinates, shape (N, 2) where N >= 3

    Raises:
        ValueError: If fewer than 3 vertices provided
    """
    verts = jnp.asarray(vertices)
    count = verts.shape[0]
    if count < 3:
        raise ValueError("Polygon must have at least 3 vertices")
    self.vertices = verts
    self.n_vertices = count

contains

contains(point: Point2D) -> bool

Check if point is inside polygon using ray casting algorithm.

Source code in opifex/geometry/csg.py
def contains(self, point: Point2D) -> bool:
    """Even-odd (ray casting) point-in-polygon test.

    A ray is cast from ``point`` in the +x direction, and the point is inside
    when the ray crosses an odd number of edges. Horizontal edges never count,
    because both of their endpoints fall on the same side of the ray.
    """
    p = jnp.asarray(point)
    idx = jnp.arange(self.n_vertices)
    start = self.vertices[idx]
    end = self.vertices[(idx + 1) % self.n_vertices]

    straddles = (start[:, 1] > p[1]) != (end[:, 1] > p[1])
    x_cross = start[:, 0] + (p[1] - start[:, 1]) / (end[:, 1] - start[:, 1]) * (
        end[:, 0] - start[:, 0]
    )
    crossings = jnp.sum(straddles & (p[0] < x_cross))
    return bool(crossings % 2 == 1)

distance

distance(point: Point2D) -> Float[Array, '']

Compute signed distance to polygon boundary (enhanced).

Source code in opifex/geometry/csg.py
def distance(self, point: Point2D) -> Float[jax.Array, ""]:
    """Signed distance from ``point`` to the polygon boundary.

    Negative inside, positive outside. The magnitude is the distance to the
    closest edge segment. The sign comes from an even-odd crossing test.

    Only ``jnp`` operations are used, with no Python ``bool`` conversion, so
    the method works under ``jax.vmap``, ``jax.jit`` and ``jax.grad``.
    ``self.contains`` cannot be used here because it returns ``bool(...)``,
    which raises on traced values. The segment distance uses a regularised
    norm ``sqrt(d^2 + eps^2) - eps`` so gradients stay finite on the boundary.
    Zero-length (duplicate-vertex) edges are treated as points instead of
    dividing by zero.
    """
    point = jnp.asarray(point)
    eps = 1e-8

    idx = jnp.arange(self.n_vertices)
    v1 = self.vertices[idx]
    v2 = self.vertices[(idx + 1) % self.n_vertices]
    edge_vec = v2 - v1
    point_vec = point - v1

    # Projection parameter of the point onto each edge, clamped to the segment
    edge_length_sq = jnp.sum(edge_vec**2, axis=1)
    safe_length_sq = jnp.where(edge_length_sq > 0, edge_length_sq, 1.0)
    t = jnp.clip(jnp.sum(point_vec * edge_vec, axis=1) / safe_length_sq, 0.0, 1.0)
    closest = v1 + t[:, None] * edge_vec
    diff = point - closest
    distances = jnp.sqrt(jnp.sum(diff * diff, axis=1) + eps * eps) - eps
    min_dist = jnp.min(distances)

    # Even-odd crossing test for the +x ray; only straddling edges count and
    # those have a non-zero dy, so the safe divisor never changes the answer.
    straddles = (v1[:, 1] > point[1]) != (v2[:, 1] > point[1])
    dy = v2[:, 1] - v1[:, 1]
    safe_dy = jnp.where(dy == 0, 1.0, dy)
    x_cross = v1[:, 0] + (point[1] - v1[:, 1]) / safe_dy * (v2[:, 0] - v1[:, 0])
    crossings = jnp.sum(straddles & (point[0] < x_cross))
    inside = crossings % 2 == 1

    return jnp.where(inside, -min_dist, min_dist)

sample_boundary

sample_boundary(n: int, key: Array) -> Points2D

Sample points uniformly on polygon boundary.

Source code in opifex/geometry/csg.py
def sample_boundary(self, n: int, key: jax.Array) -> Points2D:
    """Draw ``n`` points uniformly by arc length along the polygon's edges.

    A variate in ``[0, perimeter)`` is mapped to the edge whose cumulative arc
    length interval contains it, then linearly interpolated along that edge.
    """
    segments = jnp.roll(self.vertices, -1, axis=0) - self.vertices
    seg_len = jnp.linalg.norm(segments, axis=1)
    arc = jax.random.uniform(key, (n,)) * jnp.sum(seg_len)

    # Arc length at the start of each edge: [0, l0, l0 + l1, ...]
    starts = jnp.cumsum(jnp.concatenate([jnp.array([0.0]), seg_len]))
    last_edge = self.n_vertices - 1

    def locate(s):
        k = jnp.clip(jnp.searchsorted(starts[1:], s, side="right"), 0, last_edge)
        frac = jnp.clip((s - starts[k]) / seg_len[k], 0.0, 1.0)
        a = self.vertices[k]
        b = self.vertices[(k + 1) % self.n_vertices]
        return a + frac * (b - a)

    return jnp.asarray(jax.vmap(locate)(arc))

sample_interior

sample_interior(n: int, key: Array) -> Points2D

Sample points from polygon interior using rejection sampling.

Source code in opifex/geometry/csg.py
def sample_interior(self, n: int, key: jax.Array) -> Points2D:
    """Sample ``n`` points from the polygon interior by rejection sampling.

    Proposals are drawn uniformly from the polygon's bounding box. A proposal
    is kept when a vectorised even-odd crossing test puts it inside.
    ``jax.vmap(self.contains)`` is not used, because ``contains`` converts its
    result to a Python ``bool``, which raises on traced values. Up to
    ``max_rounds`` batches of ``2 * n`` proposals are drawn, each with a fresh
    key, until ``n`` points have been accepted.

    If fewer than ``n`` points are accepted, the result is padded by repeating
    the last accepted point. If none are accepted (e.g. a zero-area polygon),
    every row is the mean of the vertices.

    Returns:
        Array of shape ``(n, 2)``.
    """
    if n <= 0:
        return jnp.zeros((0, 2))

    idx = jnp.arange(self.n_vertices)
    v1 = self.vertices[idx]
    v2 = self.vertices[(idx + 1) % self.n_vertices]
    min_vals = jnp.min(self.vertices, axis=0)
    max_vals = jnp.max(self.vertices, axis=0)

    dy = v2[:, 1] - v1[:, 1]
    # Only straddling edges matter and they have dy != 0; avoid 0/0 elsewhere.
    safe_dy = jnp.where(dy == 0, 1.0, dy)

    def inside_mask(points):
        """Even-odd crossing test for an (M, 2) batch; returns an (M,) bool mask."""
        px = points[:, 0:1]
        py = points[:, 1:2]
        straddles = (v1[None, :, 1] > py) != (v2[None, :, 1] > py)
        x_cross = v1[:, 0] + (py - v1[:, 1]) / safe_dy * (v2[:, 0] - v1[:, 0])
        crossings = jnp.sum(straddles & (px < x_cross), axis=1)
        return crossings % 2 == 1

    max_rounds = 10
    accepted = []
    n_accepted = 0
    round_key = key
    for _ in range(max_rounds):
        round_key, key_x, key_y = jax.random.split(round_key, 3)
        xs = jax.random.uniform(key_x, (2 * n,), minval=min_vals[0], maxval=max_vals[0])
        ys = jax.random.uniform(key_y, (2 * n,), minval=min_vals[1], maxval=max_vals[1])
        proposals = jnp.stack([xs, ys], axis=1)
        batch = proposals[inside_mask(proposals)]
        accepted.append(batch)
        n_accepted += int(batch.shape[0])
        if n_accepted >= n:
            break

    valid_points = jnp.concatenate(accepted, axis=0)
    if n_accepted >= n:
        return valid_points[:n]
    if n_accepted == 0:
        fill = jnp.mean(self.vertices, axis=0, keepdims=True)
        return jnp.repeat(fill, n, axis=0)
    # Keep the fixed (n, 2) output shape by repeating the last accepted point.
    padding = jnp.repeat(valid_points[-1:], n - n_accepted, axis=0)
    return jnp.concatenate([valid_points, padding], axis=0)

compute_normal

compute_normal(point: Point2D) -> Point2D

Compute outward normal at boundary point.

Source code in opifex/geometry/csg.py
def compute_normal(self, point: Point2D) -> Point2D:
    """Outward unit normal of the polygon edge closest to ``point``.

    The closest edge's direction ``(ex, ey)`` is rotated 90 degrees towards
    the exterior. Which rotation points outward depends on the vertex order,
    so the sign of the shoelace (signed) area decides. Counter-clockwise
    polygons (positive area) use ``(ey, -ex)``; clockwise ones use
    ``(-ey, ex)``. Zero-length edges are skipped when picking the closest edge.
    """
    point = jnp.asarray(point)
    idx = jnp.arange(self.n_vertices)
    v1 = self.vertices[idx]
    v2 = self.vertices[(idx + 1) % self.n_vertices]
    edge_vec = v2 - v1

    # Distance from the point to every edge segment
    edge_length_sq = jnp.sum(edge_vec**2, axis=1)
    degenerate = edge_length_sq <= 0
    safe_length_sq = jnp.where(degenerate, 1.0, edge_length_sq)
    t = jnp.clip(jnp.sum((point - v1) * edge_vec, axis=1) / safe_length_sq, 0.0, 1.0)
    closest = v1 + t[:, None] * edge_vec
    distances = jnp.linalg.norm(point - closest, axis=1)
    # A zero-length edge has no direction, so it must never be selected
    distances = jnp.where(degenerate, jnp.inf, distances)
    closest_edge = jnp.argmin(distances)

    e = edge_vec[closest_edge]
    signed_area = 0.5 * jnp.sum(v1[:, 0] * v2[:, 1] - v2[:, 0] * v1[:, 1])
    normal = jnp.where(
        signed_area > 0,
        jnp.array([e[1], -e[0]]),  # CCW: right-hand side is the exterior
        jnp.array([-e[1], e[0]]),  # CW: left-hand side is the exterior
    )
    return normal / jnp.linalg.norm(normal)

CSG Operations

Classes

opifex.geometry.csg.CSGUnion

CSGUnion(shape_a: Shape2D, shape_b: Shape2D)

Bases: _EnhancedShapeBase

Union of two shapes (A ∪ B) with enhanced algorithms.

Source code in opifex/geometry/csg.py
def __init__(self, shape_a: Shape2D, shape_b: Shape2D):
    """Store the two operands of the union A ∪ B.

    ``_has_sdf`` is True only when both operands have a ``distance``
    attribute. Other methods of this class use it to choose between
    signed-distance code paths and ``contains``-based fallbacks.
    """
    self.shape_a, self.shape_b = shape_a, shape_b
    self._has_sdf = all(hasattr(operand, "distance") for operand in (shape_a, shape_b))

contains

contains(point: Point2D) -> bool

Point is in union if it's in either shape.

Source code in opifex/geometry/csg.py
def contains(self, point: Point2D) -> bool:
    """Membership test for A ∪ B.

    If both operands have signed distances, their union SDF is evaluated, and
    the point counts as inside when that value is <= 0. Otherwise the
    operands' ``contains`` results are combined with ``or``.
    """
    if not self._has_sdf:
        return self.shape_a.contains(point) or self.shape_b.contains(point)
    combined = _SDFOperations.union_sdf(
        self.shape_a.distance(point),
        self.shape_b.distance(point),
    )
    return bool(combined <= 0)

distance

distance(point: Point2D) -> Float[Array, '']

Compute signed distance to union boundary.

Source code in opifex/geometry/csg.py
def distance(self, point: Point2D) -> Float[jax.Array, ""]:
    """Signed distance to the boundary of A ∪ B.

    The operands' signed distances are combined with
    ``_SDFOperations.union_sdf``, and the result is returned as a JAX array.
    """
    combined = _SDFOperations.union_sdf(
        self.shape_a.distance(point),
        self.shape_b.distance(point),
    )
    return jnp.array(combined)

sample_boundary

sample_boundary(n: int, key: Array) -> Points2D

Sample boundary points using enhanced filtering.

Source code in opifex/geometry/csg.py
def sample_boundary(self, n: int, key: jax.Array) -> Points2D:
    """Sample points on the boundary of A ∪ B.

    SDF path (both operands expose ``distance``):
        Candidates are drawn from both operands' boundaries. A candidate is
        kept when the union SDF is within ``tol`` of zero, i.e. it is not
        buried inside the other operand. Up to ``max_rounds`` candidate
        batches are drawn, each with fresh keys, until at least ``n`` points
        survive. ``n`` of them are then chosen without replacement. If the
        rounds still fall short, the remainder is topped up with fresh,
        unfiltered operand boundary samples. This is best effort: those extra
        points may lie inside the other operand.

    Fallback path (an operand lacks ``distance``):
        Half of the points come from each operand's boundary, unfiltered.
    """
    if self._has_sdf:
        tol = 1e-3
        oversample_factor = 2
        max_rounds = 3

        def is_boundary_point(point):
            dist_a = self.shape_a.distance(point)
            dist_b = self.shape_b.distance(point)
            union_dist = _SDFOperations.union_sdf(dist_a, dist_b)
            return jnp.abs(union_dist) < tol

        # Separate keys for candidate rounds, the final subset, and the top-up.
        # The original reused the candidate keys for the top-up, which
        # correlated the extra samples with the first batch.
        key, key_choice, key_fill = jax.random.split(key, 3)
        found = []
        n_found = 0
        for _ in range(max_rounds):
            key, key_a, key_b = jax.random.split(key, 3)
            points_a = self.shape_a.sample_boundary(n * oversample_factor, key_a)
            points_b = self.shape_b.sample_boundary(n * oversample_factor, key_b)
            all_points = jnp.concatenate([points_a, points_b], axis=0)
            boundary_mask = jax.vmap(is_boundary_point)(all_points)
            batch = all_points[boundary_mask]
            found.append(batch)
            n_found += int(batch.shape[0])
            if n_found >= n:
                break
        boundary_points = jnp.concatenate(found, axis=0)

        if n_found >= n:
            indices = jax.random.choice(key_choice, n_found, (n,), replace=False)
            return boundary_points[indices]
        # Not enough filtered points: top up with fresh, unfiltered samples.
        remaining = n - n_found
        key_fill_a, key_fill_b = jax.random.split(key_fill)
        extra_a = self.shape_a.sample_boundary(remaining // 2, key_fill_a)
        extra_b = self.shape_b.sample_boundary(remaining - remaining // 2, key_fill_b)
        return jnp.concatenate([boundary_points, extra_a, extra_b], axis=0)
    # Fallback to original method
    key1, key2 = jax.random.split(key)
    points_a = self.shape_a.sample_boundary(n // 2, key1)
    points_b = self.shape_b.sample_boundary(n - n // 2, key2)
    return jnp.concatenate([points_a, points_b], axis=0)

sample_interior

sample_interior(n: int, key: Array) -> Points2D

Sample points from union interior.

Source code in opifex/geometry/csg.py
def sample_interior(self, n: int, key: jax.Array) -> Points2D:
    """Sample ``n`` points from A ∪ B: ``n // 2`` from A, the rest from B.

    The two interiors are sampled independently, so an overlapping region
    receives points from both operands and is over-represented relative to a
    uniform sample of the union.
    """
    k_a, k_b = jax.random.split(key)
    half = n // 2
    from_a = self.shape_a.sample_interior(half, k_a)
    from_b = self.shape_b.sample_interior(n - half, k_b)
    return jnp.concatenate([from_a, from_b], axis=0)

compute_normal

compute_normal(point: Point2D) -> Point2D

Compute normal (enhanced approach).

Source code in opifex/geometry/csg.py
def compute_normal(self, point: Point2D) -> Point2D:
    """Outward unit normal of A ∪ B at ``point``.

    With signed distances, the normal is the normalised ``jax.grad`` of the
    union SDF. ``[1, 0]`` is returned when the gradient norm is <= 1e-10.
    Without them, the normal comes from shape A when A contains the point,
    otherwise from shape B.
    """
    if not self._has_sdf:
        owner = self.shape_a if self.shape_a.contains(point) else self.shape_b
        return owner.compute_normal(point)

    def union_distance(p):
        return _SDFOperations.union_sdf(self.shape_a.distance(p), self.shape_b.distance(p))

    gradient = jax.grad(union_distance)(point)
    magnitude = jnp.linalg.norm(gradient)
    unit = jnp.where(magnitude > 1e-10, gradient / magnitude, jnp.array([1.0, 0.0]))
    return jnp.asarray(unit).reshape(2)

opifex.geometry.csg.CSGIntersection

CSGIntersection(shape_a: Shape2D, shape_b: Shape2D)

Bases: _EnhancedShapeBase

Intersection of two shapes (A ∩ B) with enhanced algorithms.

Source code in opifex/geometry/csg.py
def __init__(self, shape_a: Shape2D, shape_b: Shape2D):
    """Store the two operands of the intersection A ∩ B.

    ``_has_sdf`` is True only when both operands have a ``distance``
    attribute. Other methods of this class use it to choose between
    signed-distance code paths and ``contains``-based fallbacks.
    """
    self.shape_a = shape_a
    self.shape_b = shape_b
    both_have_distance = hasattr(shape_a, "distance") and hasattr(shape_b, "distance")
    self._has_sdf = both_have_distance

contains

contains(point: Point2D) -> bool

Point is in intersection if it's in both shapes.

Source code in opifex/geometry/csg.py
def contains(self, point: Point2D) -> bool:
    """Membership test for A ∩ B.

    If both operands have signed distances, their intersection SDF is
    evaluated, and the point counts as inside when that value is <= 0.
    Otherwise the operands' ``contains`` results are combined with ``and``.
    """
    if not self._has_sdf:
        return self.shape_a.contains(point) and self.shape_b.contains(point)
    combined = _SDFOperations.intersection_sdf(
        self.shape_a.distance(point),
        self.shape_b.distance(point),
    )
    return bool(combined <= 0)

distance

distance(point: Point2D) -> Float[Array, '']

Compute signed distance to intersection boundary.

Source code in opifex/geometry/csg.py
def distance(self, point: Point2D) -> Float[jax.Array, ""]:
    """Signed distance to the boundary of A ∩ B.

    The operands' signed distances are combined with
    ``_SDFOperations.intersection_sdf``, and the result is returned as a JAX
    array.
    """
    combined = _SDFOperations.intersection_sdf(
        self.shape_a.distance(point),
        self.shape_b.distance(point),
    )
    return jnp.array(combined)

sample_boundary

sample_boundary(n: int, key: Array) -> Points2D

Sample boundary points (enhanced approach).

Source code in opifex/geometry/csg.py
def sample_boundary(self, n: int, key: jax.Array) -> Points2D:
    """Sample points on the boundary of A ∩ B.

    SDF path (both operands expose ``distance``):
        ``3 * n`` candidates are drawn from each operand's boundary. A
        candidate is kept when the intersection SDF is within 1e-3 of zero,
        and ``n`` survivors are chosen without replacement. If fewer survive,
        all survivors are returned. If none survive, a single ``(0, 0)``
        placeholder row is returned.

    Fallback path (an operand lacks ``distance``):
        The boundary of A ∩ B consists of A's boundary inside B and B's
        boundary inside A. Candidates from both are filtered accordingly.
        ``contains`` returns a Python ``bool``, which cannot be traced, so it
        is evaluated eagerly per candidate instead of with ``jax.vmap``. The
        path returns ``n`` survivors chosen without replacement, all survivors
        if there are fewer, or one raw sample of A's boundary if none survive.
    """
    if self._has_sdf:
        # Similar enhanced sampling as union
        key1, key2, key3 = jax.random.split(key, 3)

        oversample_factor = 3
        points_a = self.shape_a.sample_boundary(n * oversample_factor, key1)
        points_b = self.shape_b.sample_boundary(n * oversample_factor, key2)

        def is_intersection_boundary(point):
            dist_a = self.shape_a.distance(point)
            dist_b = self.shape_b.distance(point)
            intersection_dist = _SDFOperations.intersection_sdf(dist_a, dist_b)
            return jnp.abs(intersection_dist) < 1e-3

        all_points = jnp.concatenate([points_a, points_b], axis=0)
        boundary_mask = jax.vmap(is_intersection_boundary)(all_points)
        boundary_points = all_points[boundary_mask]

        if len(boundary_points) >= n:
            indices = jax.random.choice(key3, len(boundary_points), (n,), replace=False)
            return boundary_points[indices]
        if len(boundary_points) > 0:
            return boundary_points
        # Fallback if no intersection boundary found
        return jnp.zeros((1, 2))

    # Non-SDF path: filter each operand's boundary by the other's containment.
    key_a, key_b, key_pick = jax.random.split(key, 3)
    points_a = self.shape_a.sample_boundary(n, key_a)
    points_b = self.shape_b.sample_boundary(n, key_b)
    keep_a = jnp.array([bool(self.shape_b.contains(p)) for p in points_a], dtype=bool)
    keep_b = jnp.array([bool(self.shape_a.contains(p)) for p in points_b], dtype=bool)
    valid_points = jnp.concatenate([points_a[keep_a], points_b[keep_b]], axis=0)

    count = valid_points.shape[0]
    if count >= n:
        if count > n:
            indices = jax.random.choice(key_pick, count, (n,), replace=False)
            valid_points = valid_points[indices]
        return valid_points
    if count > 0:
        return valid_points
    # Last resort: one raw sample from A's boundary (may lie outside B).
    return points_a[:1]

sample_interior

sample_interior(n: int, key: Array) -> Points2D

Sample points from intersection interior.

Source code in opifex/geometry/csg.py
def sample_interior(self, n: int, key: jax.Array) -> Points2D:
    """Sample ``n`` points from A ∩ B by rejection from shape A.

    ``2 * n`` candidates are drawn from A's interior and kept when
    ``shape_b`` contains them. Since the intersection is a subset of A, this
    is efficient when the overlap is large. ``contains`` returns a Python
    ``bool``, which cannot be traced, so it is evaluated eagerly per candidate
    rather than with ``jax.vmap``.

    Returns:
        Array of shape ``(n, 2)``. Shortfalls are padded by repeating the last
        accepted point. If nothing is accepted, all rows are zeros.
    """
    candidates = self.shape_a.sample_interior(n * 2, key)
    mask = jnp.array([bool(self.shape_b.contains(p)) for p in candidates], dtype=bool)
    valid = candidates[mask]

    if valid.shape[0] >= n:
        return valid[:n]
    # Pad with last valid or zeros
    if valid.shape[0] > 0:
        padding = jnp.repeat(valid[-1:], n - valid.shape[0], axis=0)
        return jnp.concatenate([valid, padding], axis=0)
    return jnp.zeros((n, 2))

compute_normal

compute_normal(point: Point2D) -> Point2D

Compute normal (enhanced approach).

Source code in opifex/geometry/csg.py
def compute_normal(self, point: Point2D) -> Point2D:
    """Outward unit normal of A ∩ B at ``point``.

    With signed distances, the normal is the normalised ``jax.grad`` of the
    intersection SDF. ``[1, 0]`` is returned when the gradient norm is
    <= 1e-10. Without them, shape A's normal is used as an approximation.
    """
    if not self._has_sdf:
        return self.shape_a.compute_normal(point)

    def intersection_distance(p):
        return _SDFOperations.intersection_sdf(
            self.shape_a.distance(p), self.shape_b.distance(p)
        )

    gradient = jax.grad(intersection_distance)(point)
    magnitude = jnp.linalg.norm(gradient)
    unit = jnp.where(magnitude > 1e-10, gradient / magnitude, jnp.array([1.0, 0.0]))
    return jnp.asarray(unit).reshape(2)

opifex.geometry.csg.CSGDifference

CSGDifference(shape_a: Shape2D, shape_b: Shape2D)

Bases: _EnhancedShapeBase

Difference of two shapes (A - B) with enhanced algorithms.

Source code in opifex/geometry/csg.py
def __init__(self, shape_a: Shape2D, shape_b: Shape2D):
    """Store the operands of the difference A - B (A with B removed).

    ``_has_sdf`` is True only when both operands have a ``distance``
    attribute. Other methods of this class use it to choose between
    signed-distance code paths and ``contains``-based fallbacks.
    """
    self.shape_a = shape_a
    self.shape_b = shape_b
    self._has_sdf = all(hasattr(s, "distance") for s in (self.shape_a, self.shape_b))

contains

contains(point: Point2D) -> bool

Point is in difference if it's in A but not in B.

Source code in opifex/geometry/csg.py
def contains(self, point: Point2D) -> bool:
    """Membership test for A - B.

    If both operands have signed distances, their difference SDF is
    evaluated, and the point counts as inside when that value is <= 0.
    Otherwise the point must be in A and not in B.
    """
    if not self._has_sdf:
        return self.shape_a.contains(point) and not self.shape_b.contains(point)
    combined = _SDFOperations.difference_sdf(
        self.shape_a.distance(point),
        self.shape_b.distance(point),
    )
    return bool(combined <= 0)

distance

distance(point: Point2D) -> Float[Array, '']

Compute signed distance to difference boundary.

Source code in opifex/geometry/csg.py
def distance(self, point: Point2D) -> Float[jax.Array, ""]:
    """Signed distance to the boundary of A - B.

    The operands' signed distances are combined with
    ``_SDFOperations.difference_sdf``, and the result is returned as a JAX
    array.
    """
    combined = _SDFOperations.difference_sdf(
        self.shape_a.distance(point),
        self.shape_b.distance(point),
    )
    return jnp.array(combined)

sample_boundary

sample_boundary(n: int, key: Array) -> Points2D

Sample points on difference boundary.

Source code in opifex/geometry/csg.py
def sample_boundary(self, n: int, key: jax.Array) -> Points2D:
    """Sample points on the boundary of the difference A - B.

    The boundary of A - B has two parts: A's boundary lying outside B, and
    B's boundary lying inside A. ``2 * n`` candidates are drawn from each
    operand's boundary and filtered with the operands' signed distances.
    Independent keys are used for the two candidate batches, the final
    subsampling and the fallback draw.

    Args:
        n: Number of points requested.
        key: JAX PRNG key.

    Returns:
        ``n`` points chosen without replacement when enough candidates pass
        the filter; otherwise all surviving candidates (possibly fewer than
        ``n``). If none survive, ``n`` unfiltered samples of A's boundary.
    """
    key_a, key_b, key_pick, key_fallback = jax.random.split(key, 4)
    # "On a boundary" tolerance, matching the 1e-3 used by the union and
    # intersection boundary filters in this module. It absorbs float32
    # round-off in sampled coordinates.
    on_tol = 1e-3
    # Margin for "strictly outside B" / "strictly inside A".
    side_tol = 1e-6

    def on_a_outside_b(point):
        on_a = jnp.abs(self.shape_a.distance(point)) < on_tol
        outside_b = self.shape_b.distance(point) > side_tol
        return on_a & outside_b

    def on_b_inside_a(point):
        on_b = jnp.abs(self.shape_b.distance(point)) < on_tol
        inside_a = self.shape_a.distance(point) < -side_tol
        return on_b & inside_a

    candidates_a = self.shape_a.sample_boundary(n * 2, key_a)
    candidates_b = self.shape_b.sample_boundary(n * 2, key_b)
    boundary_points = jnp.concatenate(
        [
            candidates_a[jax.vmap(on_a_outside_b)(candidates_a)],
            candidates_b[jax.vmap(on_b_inside_a)(candidates_b)],
        ],
        axis=0,
    )

    count = boundary_points.shape[0]
    if count >= n:
        indices = jax.random.choice(key_pick, count, (n,), replace=False)
        return boundary_points[indices]
    if count > 0:
        return boundary_points
    # Fallback: nothing survived the filter; return raw samples of A's boundary.
    return self.shape_a.sample_boundary(n, key_fallback)

sample_interior

sample_interior(n: int, key: Array) -> Points2D

Sample points from difference interior (A - B).

Source code in opifex/geometry/csg.py
def sample_interior(self, n: int, key: jax.Array) -> Points2D:
    """Sample ``n`` points from A - B by rejection from shape A.

    ``2 * n`` candidates are drawn from A's interior and kept when
    ``shape_b`` does NOT contain them. ``contains`` returns a Python
    ``bool``, and applying ``not`` to a traced value inside ``jax.vmap``
    raises, so the test is evaluated eagerly per candidate.

    Returns:
        Array of shape ``(n, 2)``. Shortfalls are padded by repeating the last
        accepted point. If nothing is accepted, all rows are zeros.
    """
    candidates = self.shape_a.sample_interior(n * 2, key)
    mask = jnp.array([not self.shape_b.contains(p) for p in candidates], dtype=bool)
    valid = candidates[mask]

    if valid.shape[0] >= n:
        return valid[:n]
    if valid.shape[0] > 0:
        padding = jnp.repeat(valid[-1:], n - valid.shape[0], axis=0)
        return jnp.concatenate([valid, padding], axis=0)
    return jnp.zeros((n, 2))

compute_normal

compute_normal(point: Point2D) -> Point2D

Compute normal from shape A.

Source code in opifex/geometry/csg.py
def compute_normal(self, point: Point2D) -> Point2D:
    """Outward unit normal of A - B at ``point``.

    With signed distances, the normal is the normalised ``jax.grad`` of the
    difference SDF. ``[1, 0]`` is returned when the gradient norm is
    <= 1e-10. Without them, shape A's normal is returned.
    """
    if not self._has_sdf:
        return self.shape_a.compute_normal(point)

    def difference_distance(p):
        return _SDFOperations.difference_sdf(
            self.shape_a.distance(p), self.shape_b.distance(p)
        )

    gradient = jax.grad(difference_distance)(point)
    magnitude = jnp.linalg.norm(gradient)
    unit = jnp.where(magnitude > 1e-10, gradient / magnitude, jnp.array([1.0, 0.0]))
    return jnp.asarray(unit).reshape(2)

Functional API

Create union of two shapes.

Create intersection of two shapes.

Create difference of two shapes.

Boundary Analysis

Compute boundary normal at a point.

Sample points on shape boundary.

Molecular Geometry

opifex.geometry.csg.MolecularGeometry

MolecularGeometry(atomic_symbols: list[str], positions: Array)

3D molecular geometry with atomic coordinates.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `atomic_symbols` | `list[str]` | List of atomic symbols (e.g., ['H', 'H', 'O']) | *required* |
| `positions` | `Array` | Atomic positions in Bohr, shape (N, 3) | *required* |


Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If the number of symbols doesn't match the number of positions |


Source code in opifex/geometry/csg.py
def __init__(self, atomic_symbols: list[str], positions: jax.Array):
    """Build a molecular geometry from symbols and coordinates.

    Args:
        atomic_symbols: List of atomic symbols (e.g., ['H', 'H', 'O'])
        positions: Atomic positions in Bohr, shape (N, 3)

    Raises:
        ValueError: If number of symbols doesn't match number of positions
    """
    coords = jnp.asarray(positions)
    n_symbols = len(atomic_symbols)
    if n_symbols != coords.shape[0]:
        raise ValueError("Number of atomic symbols must match number of positions")

    # The symbol list is stored as given (not copied).
    self.atomic_symbols = atomic_symbols
    self.positions = coords
    self.n_atoms = n_symbols

compute_distances

compute_distances() -> Array

Compute all pairwise interatomic distances.

Source code in opifex/geometry/csg.py
def compute_distances(self) -> jax.Array:
    """Return the ``(N, N)`` matrix of pairwise Euclidean interatomic distances."""
    pos = self.positions
    # Broadcast (N, 1, 3) - (1, N, 3) -> (N, N, 3) displacement vectors
    displacement = jnp.expand_dims(pos, 1) - jnp.expand_dims(pos, 0)
    return jnp.linalg.norm(displacement, axis=-1)

project_to_2d

project_to_2d(plane: str = 'xy') -> Array

Project 3D coordinates to 2D plane.

Source code in opifex/geometry/csg.py
def project_to_2d(self, plane: str = "xy") -> jax.Array:
    """Drop one coordinate axis, keeping the two axes named by ``plane``.

    Raises:
        ValueError: If ``plane`` is not one of ``'xy'``, ``'xz'``, ``'yz'``.
    """
    selectors = (("xy", slice(0, 2)), ("xz", [0, 2]), ("yz", [1, 2]))
    for name, columns in selectors:
        if plane == name:
            return self.positions[:, columns]
    raise ValueError("Plane must be 'xy', 'xz', or 'yz'")

from_molecular_system classmethod

from_molecular_system(molecular_system) -> MolecularGeometry

Create molecular geometry from MolecularSystem.

Source code in opifex/geometry/csg.py
@classmethod
def from_molecular_system(cls, molecular_system) -> MolecularGeometry:
    """Build a geometry from a MolecularSystem-like object.

    Symbols and positions are obtained through the class helpers
    ``_extract_atomic_symbols`` and ``_extract_positions``. A ``None`` from
    either helper is treated as "not found".

    Raises:
        ValueError: If symbols or positions cannot be extracted. The message
            lists the object's public attributes to help diagnose the problem.
    """
    symbols = cls._extract_atomic_symbols(molecular_system)
    coords = cls._extract_positions(molecular_system)
    if symbols is not None and coords is not None:
        return cls(symbols, coords)

    public_attrs = [name for name in dir(molecular_system) if not name.startswith("_")]
    raise ValueError(
        f"Molecular system must have atomic symbols and positions. "
        f"Available attributes: {public_attrs}. "
        f"Found atomic_symbols: {symbols is not None}, "
        f"Found positions: {coords is not None}"
    )