# astra/spatial_index.py
"""ASTRA Core Persistent Spatial Index — KDTree.
Implements a 3D KDTree spatial partitioning structure for O(N log N)
conjunction candidate pair generation, replacing the naive O(N²)
all-pairs search.
Wraps the ultra-fast C++ scipy.spatial.cKDTree.
"""
from __future__ import annotations
import threading
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from astra.models import TrajectoryMap
import numpy as np
from scipy.spatial import cKDTree
class SpatialIndex:
    """High-level persistent spatial index for conjunction screening.

    Wraps :class:`scipy.spatial.cKDTree` for fast spatial queries. Mutations
    via :meth:`insert` only invalidate the tree; it is rebuilt lazily by the
    next query. Public mutators and queries take the instance's re-entrant
    lock.

    The index operates in one of two modes:

    * Point mode (:meth:`insert` / :meth:`rebuild`): each object is a single
      position and :meth:`query_pairs` compares positions directly.
    * Trajectory mode (:meth:`rebuild_for_trajectories`): each object is the
      mean of its trajectory samples (its center) plus the maximum distance of
      any sample from that center (its excursion). :meth:`query_pairs` then
      reports pairs whose bounding spheres come within the threshold. Objects
      with no recorded excursion (e.g. added later via :meth:`insert`) are
      treated as zero-excursion points.

    Positions containing NaN or Inf are never indexed.

    Usage:
        idx = SpatialIndex()
        idx.insert("25544", np.array([6771.0, 0.0, 0.0]))
        pairs = idx.query_pairs(threshold_km=50.0)

    Args:
        half_size_km: Nominal maximum separation (km) for screening workflows.
            Pass the same value to :meth:`query_pairs` when you want a
            consistent radius; it is not applied automatically.
        max_objects_per_node: ``leafsize`` for SciPy's ``cKDTree`` (bucket
            size); values below 1 are clamped to 1 when the tree is built.
    """

    def __init__(self, half_size_km: float = 50000.0, max_objects_per_node: int = 16):
        self.half_size_km = float(half_size_km)
        self.max_objects_per_node = int(max_objects_per_node)
        self._tree: cKDTree | None = None
        # Max distance of any trajectory sample from its center (trajectory mode).
        self._excursions: dict[str, float] = {}
        # Largest value in _excursions; > 0 selects the trajectory-mode query.
        self._max_excursion: float = 0.0
        self._lock = threading.RLock()
        self._positions: dict[str, np.ndarray] = {}
        # Row order of the tree's data; rebuilt together with the tree.
        self._ids: list[str] = []

    def insert(self, obj_id: str, position: np.ndarray) -> None:
        """Insert or update an object's position. Thread-safe.

        The position is copied (as a float array), so later mutation of the
        caller's array cannot corrupt the index. Consistent with
        :meth:`rebuild`, a position containing NaN or Inf is not indexed: any
        previous entry for ``obj_id`` is removed instead, since it would now
        be stale.
        """
        pos = np.array(position, dtype=float)
        with self._lock:
            if np.all(np.isfinite(pos)):
                self._positions[obj_id] = pos
            else:
                self._positions.pop(obj_id, None)
                self._excursions.pop(obj_id, None)
            self._tree = None

    def query_radius(
        self, point: np.ndarray, radius_km: float
    ) -> list[tuple[str, np.ndarray]]:
        """Find all objects within ``radius_km`` of ``point``. Thread-safe.

        Returns:
            ``(obj_id, position)`` tuples (empty if nothing is indexed). The
            positions are copies, so callers cannot corrupt the index by
            mutating them. In trajectory mode the position is the center.
        """
        with self._lock:
            self._ensure_tree()
            if self._tree is None:
                return []
            indices = self._tree.query_ball_point(
                np.asarray(point, dtype=float), r=radius_km
            )
            return [
                (self._ids[i], self._positions[self._ids[i]].copy()) for i in indices
            ]

    def query_pairs(self, threshold_km: float = 50.0) -> list[tuple[str, str]]:
        """Find all pairs of objects within threshold distance. Thread-safe.

        Point mode returns every pair whose positions are at most
        ``threshold_km`` apart. Trajectory mode returns every pair ``(i, j)``
        whose centers satisfy ``|c_i - c_j| <= threshold_km + exc_i + exc_j``,
        i.e. whose bounding spheres may bring them within ``threshold_km``.

        PERF-01: trajectory mode queries the tree once per object with radius
        ``threshold_km + 2 * exc_i`` instead of one global radius inflated by
        the largest excursion. With a global radius, a single HEO object
        (excursion ~20,000 km) would inflate the coarse radius to about
        50 + 2 * 20,000 = 40,050 km for every object. Each candidate is then
        refined against the exact per-pair bound.

        Returns:
            Sorted list of unique ``(id_a, id_b)`` tuples with ``id_a < id_b``.
        """
        with self._lock:
            self._ensure_tree()
            if self._tree is None or len(self._ids) < 2:
                return []
            if self._max_excursion > 0:
                return self._query_pairs_with_excursions(float(threshold_km))
            ids = self._ids
            # Point mode: one C-level all-pairs query at the exact threshold.
            index_pairs = self._tree.query_pairs(r=threshold_km, output_type="set")
            return sorted(
                (ids[i], ids[j]) if ids[i] < ids[j] else (ids[j], ids[i])
                for i, j in index_pairs
            )

    def _query_pairs_with_excursions(
        self, threshold_km: float
    ) -> list[tuple[str, str]]:
        """Trajectory-mode pair search. MUST BE CALLED WITHIN _lock, tree built."""
        ids = self._ids
        # Rows of the tree's data follow self._ids (see _ensure_tree).
        points = np.asarray(self._tree.data)
        # Ids without a recorded excursion (e.g. added via insert()) are points.
        exc = np.array([self._excursions.get(nid, 0.0) for nid in ids], dtype=float)
        found: set[tuple[str, str]] = set()
        for i, id_i in enumerate(ids):
            # Radius threshold + 2*exc_i gives complete, symmetric discovery:
            # if exc_i >= exc_j then threshold + exc_i + exc_j <= threshold +
            # 2*exc_i, so the pair is found here; otherwise it is found when j
            # is queried with its own (larger) radius.
            cand = np.asarray(
                self._tree.query_ball_point(points[i], r=threshold_km + 2.0 * exc[i]),
                dtype=np.intp,
            )
            cand = cand[cand != i]
            if cand.size == 0:
                continue
            # Tight per-pair refinement on center distance, vectorized.
            dist = np.linalg.norm(points[cand] - points[i], axis=1)
            for j in cand[dist <= threshold_km + exc[i] + exc[cand]].tolist():
                id_j = ids[j]
                found.add((id_i, id_j) if id_i < id_j else (id_j, id_i))
        return sorted(found)

    def rebuild_for_trajectories(self, trajectories: TrajectoryMap) -> None:
        """Build a unified spatial index for entire trajectories. Thread-safe.

        Each trajectory is represented by the mean of its samples (its center)
        and the maximum distance of any sample from that center (its
        excursion). This enables one-shot conjunction screening for the whole
        propagation window via :meth:`query_pairs`. All previously indexed
        objects are replaced.

        Each value is expected to be a 2-D array with one row per sample
        (presumably ``(n_samples, 3)`` in km; confirm against ``TrajectoryMap``).
        Trajectories that are empty or contain NaN/Inf are skipped.
        """
        with self._lock:
            new_positions: dict[str, np.ndarray] = {}
            new_excursions: dict[str, float] = {}
            for nid, traj in trajectories.items():
                samples = np.asarray(traj, dtype=float)
                # An empty trajectory has no center; non-finite ones would
                # poison the tree. Skip both (same policy as rebuild()).
                if samples.size == 0 or not np.all(np.isfinite(samples)):
                    continue
                center = samples.mean(axis=0)
                new_positions[nid] = center
                new_excursions[nid] = float(
                    np.max(np.linalg.norm(samples - center, axis=1))
                )
            self._positions = new_positions
            self._excursions = new_excursions
            self._max_excursion = max(new_excursions.values(), default=0.0)
            self._ensure_tree(force=True)

    @property
    def size(self) -> int:
        """Number of objects indexed."""
        return len(self._positions)

    def rebuild(self, positions: dict[str, np.ndarray]) -> None:
        """Rebuild the entire index from a fresh position dictionary. Thread-safe.

        Silently drops any objects whose position contains NaN or Inf. This
        switches the index to point mode: any excursions recorded by
        :meth:`rebuild_for_trajectories` are discarded.
        """
        with self._lock:
            fresh: dict[str, np.ndarray] = {}
            for obj_id, pos in positions.items():
                arr = np.array(pos, dtype=float)  # defensive copy
                if np.all(np.isfinite(arr)):
                    fresh[obj_id] = arr
            self._positions = fresh
            # Stale trajectory excursions would otherwise select the
            # trajectory-mode query and fail on ids they do not contain.
            self._excursions = {}
            self._max_excursion = 0.0
            self._ensure_tree(force=True)

    def _ensure_tree(self, force: bool = False) -> None:
        """(Re)build the KDTree if missing or ``force`` is set. MUST BE CALLED WITHIN _lock."""
        if self._tree is not None and not force:
            return
        self._ids = list(self._positions)
        if not self._ids:
            # Nothing indexed: drop any stale tree so queries see an empty index
            # instead of ids that no longer exist in _positions.
            self._tree = None
            return
        points = np.array([self._positions[nid] for nid in self._ids], dtype=float)
        self._tree = cKDTree(points, leafsize=max(1, self.max_objects_per_node))