Source code for tinygp.solvers.quasisep.solver

from __future__ import annotations

__all__ = ["QuasisepSolver"]

from typing import TYPE_CHECKING, Any

import jax
import jax.numpy as jnp
import numpy as np

from tinygp.helpers import JAXArray
from tinygp.kernels.base import Kernel
from tinygp.noise import Noise
from tinygp.solvers.quasisep.core import LowerTriQSM, SymmQSM
from tinygp.solvers.solver import Solver

[docs] class QuasisepSolver(Solver): """A scalable solver that uses quasiseparable matrices Take a look at the documentation for the :ref:`api-solvers-quasisep`, for more technical details. You generally won't instantiate this object directly but, if you do, you'll probably want to use the :func:`QuasisepSolver.init` method instead of the usual constructor. """ X: JAXArray matrix: SymmQSM factor: LowerTriQSM def __init__( self, kernel: Kernel, X: JAXArray, noise: Noise, *, covariance: Any | None = None, assume_sorted: bool = False, ): """Build a :class:`QuasisepSolver` for a given kernel and coordinates Args: kernel: The kernel function. This must be an instance of a subclass of :class:`tinygp.kernels.quasisep.Quasisep`. X: The input coordinates. noise: The noise model for the process. covariance: Optionally, a pre-computed :class:`tinygp.solvers.quasisep.core.QSM` with the covariance matrix. assume_sorted: If ``True``, assume that the input coordinates are sorted. If ``False``, check that they are sorted and throw an error if they are not. This can introduce a runtime overhead, and you can pass ``assume_sorted=True`` to get the best performance. """ from tinygp.kernels.quasisep import Quasisep if covariance is None: if TYPE_CHECKING: assert isinstance(kernel, Quasisep) if not assume_sorted: jax.debug.callback(_check_sorted, kernel.coord_to_sortable(X)) matrix = kernel.to_symm_qsm(X) matrix += noise.to_qsm() else: if TYPE_CHECKING: assert isinstance(covariance, SymmQSM) matrix = covariance self.X = X self.matrix = matrix self.factor = matrix.cholesky()
[docs] def variance(self) -> JAXArray: return self.matrix.diag.d
[docs] def covariance(self) -> JAXArray: return self.matrix.to_dense()
[docs] def normalization(self) -> JAXArray: return jnp.sum(jnp.log(self.factor.diag.d)) + 0.5 * self.factor.shape[ 0 ] * np.log(2 * np.pi)
[docs] def solve_triangular(self, y: JAXArray, *, transpose: bool = False) -> JAXArray: if transpose: return self.factor.transpose().solve(y) else: return self.factor.solve(y)
[docs] def dot_triangular(self, y: JAXArray) -> JAXArray: return self.factor @ y
[docs] def condition(self, kernel: Kernel, X_test: JAXArray | None, noise: Noise) -> Any: """Compute the covariance matrix for a conditional GP In the case where the prediction is made at the input coordinates with a :class:`tinygp.kernels.quasisep.Quasisep` kernel, this will return the quasiseparable representation of the conditional matrix. Otherwise, it will use scalable methods where possible, but return a dense representation of the covariance, so be careful when predicting at a large number of test points! Args: kernel: The kernel for the covariance between the observed and predicted data. X_test: The coordinates of the predicted points. Defaults to the input coordinates. noise: The noise model for the predicted process. """ from tinygp.kernels.quasisep import Quasisep # We can easily compute the conditional as a QSM in the special case # where we are predicting at the input coordinates and a Quasisep kernel if X_test is None and isinstance(kernel, Quasisep): M = kernel.to_symm_qsm(self.X) delta = (self.factor.inv() @ M).gram() M += noise.to_qsm() return M - delta # Otherwise fall back on the slow method for now :( if X_test is None: Kss = Ks = kernel(self.X, self.X) else: Kss = kernel(X_test, X_test) Ks = kernel(self.X, X_test) A = self.solve_triangular(Ks) return Kss - A.transpose() @ A
def _check_sorted(X: JAXArray) -> None: if np.any(np.diff(X) < 0.0): raise ValueError( "Input coordinates must be sorted in order to use the QuasisepSolver" )