Source code for orc_bound.core.upper_bound

"""
Residual-shell upper bound on the Wasserstein-1 distance.

Implements the bucket-based residual matching algorithm from the
ORC-bound paper for computing an upper bound on W_1(mu, nu).
"""

from __future__ import annotations

import numpy as np
from typing import Dict, Tuple, NamedTuple


class UpperBoundResult(NamedTuple):
    """Result of the residual-shell W1 upper bound computation.

    Being a ``NamedTuple``, it can be unpacked positionally as
    ``ub, m_r, residual_mass, rbar`` or read field by field.
    """

    ub: float
    """Upper bound on W_1(mu, nu)."""

    m_r: np.ndarray
    """Mass assigned to each distance bucket r."""

    residual_mass: float
    """Remaining unmatched mass R_l."""

    rbar: float
    """Effective distance for the residual mass."""


def residual_shell_upper_bound(
    mu_x: Dict[int, float],
    mu_y: Dict[int, float],
    D: np.ndarray,
    idx: Dict[int, int],
    l: int = 2,
    tol: float = 1e-12,
    rbar_mode: str = "local-max",
) -> UpperBoundResult:
    """Compute the residual-shell upper bound on W_1(mu_x, mu_y).

    Support pairs ``(u, v)`` whose distance ``d = D[idx[u], idx[v]]`` is
    finite and ``<= l`` are grouped into buckets ``r = ceil(d)``. Buckets are
    processed from ``r = 0`` upwards. Within a bucket, pairs are visited in
    the insertion order of ``mu_x`` (outer) and ``mu_y`` (inner), and each
    pair greedily transports ``min(remaining mass of u, remaining mass of v)``.
    Mass left unmatched after the last bucket forms the residual shell and is
    charged at distance ``rbar``.

    Parameters
    ----------
    mu_x : Dict[int, float]
        Measure at node x, as a dict of {node: mass}.
    mu_y : Dict[int, float]
        Measure at node y, as a dict of {node: mass}.
    D : np.ndarray
        Precomputed distance matrix from
        :func:`all_pairs_shortest_path_matrix_cutoff`. Non-finite entries
        mark pairs that are never matched directly.
    idx : Dict[int, int]
        Node-to-index mapping into ``D``.
    l : int, default=2
        Maximum shell distance to consider. Pairs at distance > l are left
        to the residual shell.
    tol : float, default=1e-12
        Pruning threshold. Masses (and transport amounts) not exceeding it
        are ignored.
    rbar_mode : {"local-max", "global"}, default="local-max"
        How to compute the residual distance ``rbar``:

        - ``"local-max"``: max finite distance between nodes that still hold
          more than ``tol`` unmatched mass.
        - ``"global"``: max finite distance anywhere in ``D``.

    Returns
    -------
    UpperBoundResult
        Named tuple ``(ub, m_r, residual_mass, rbar)``, which also unpacks
        like a plain 4-tuple:

        ub : float
            Upper bound on the 1-Wasserstein distance,
            ``sum_r r * m_r[r] + rbar * residual_mass``.
        m_r : np.ndarray
            Mass transported within each bucket r = 0..l.
        residual_mass : float
            Unmatched mass remaining on the ``mu_x`` side.
        rbar : float
            Effective residual distance (0.0 when the residual is <= tol).

    Raises
    ------
    ValueError
        If ``rbar_mode`` is not one of the supported values.
    RuntimeError
        If ``rbar_mode="local-max"`` and no finite residual distances are
        found (cutoff too small).

    Notes
    -----
    Using ``ceil(d)`` as the bucket index guarantees ``r >= d``. For integer
    graph distances this is just ``d``. For non-integer distances it keeps
    the per-pair charge ``r * mass`` from under-estimating the true cost
    ``d * mass``, which would otherwise invalidate the bound.

    Examples
    --------
    >>> import numpy as np
    >>> D = np.array([[0., 1., 2.],
    ...               [1., 0., 1.],
    ...               [2., 1., 0.]], dtype=float)
    >>> idx = {0: 0, 1: 1, 2: 2}
    >>> mu_x = {0: 0.5, 1: 0.5}
    >>> mu_y = {1: 0.5, 2: 0.5}
    >>> res = residual_shell_upper_bound(mu_x, mu_y, D, idx, l=2)
    >>> res.ub
    1.0
    >>> res.m_r.tolist()
    [0.5, 0.0, 0.5]
    >>> ub, m_r, Rl, rbar = res  # unpacks like a plain tuple
    """
    # Reject unknown modes up front instead of silently using "global".
    if rbar_mode not in ("local-max", "global"):
        raise ValueError(
            f"Unknown rbar_mode {rbar_mode!r}; "
            "expected 'local-max' or 'global'."
        )

    # Drop negligible masses; a/b track the still-unmatched mass per node
    # (dicts keep insertion order, which fixes the greedy visiting order).
    a = {u: m for u, m in mu_x.items() if m > tol}
    b = {v: m for v, m in mu_y.items() if m > tol}

    # Hoist the column lookups out of the pair loop.
    y_cols = [(v, idx[v]) for v in b]

    # Build distance buckets: bucket r holds pairs with r-1 < d <= r.
    buckets = [[] for _ in range(l + 1)]
    for u in a:
        iu = idx[u]
        for v, iv in y_cols:
            d = D[iu, iv]
            if np.isfinite(d) and d <= l:
                buckets[int(np.ceil(d))].append((u, v))

    # Greedy residual matching, closest bucket first.
    m_r = np.zeros(l + 1, dtype=np.float64)
    for r, pairs in enumerate(buckets):
        for u, v in pairs:
            delta = min(a[u], b[v])
            if delta > tol:
                a[u] -= delta
                b[v] -= delta
                m_r[r] += delta

    # float() so an empty measure yields 0.0 rather than the int 0.
    residual_mass = float(sum(a.values()))

    if residual_mass <= tol:
        rbar = 0.0
    elif rbar_mode == "local-max":
        residual_rows = [idx[u] for u, m in a.items() if m > tol]
        residual_cols = [idx[v] for v, m in b.items() if m > tol]
        finite_residual_distances = [
            D[i, j]
            for i in residual_rows
            for j in residual_cols
            if np.isfinite(D[i, j])
        ]
        if not finite_residual_distances:
            raise RuntimeError(
                "No finite residual distances found. "
                "The shortest-path cutoff may be too small."
            )
        rbar = float(max(finite_residual_distances))
    else:  # "global"
        rbar = float(np.max(D[np.isfinite(D)]))

    ub = float(np.dot(np.arange(l + 1), m_r) + rbar * residual_mass)
    return UpperBoundResult(ub, m_r, residual_mass, rbar)