Source code for tinygp.solvers.direct
from __future__ import annotations
__all__ = ["DirectSolver"]
from typing import Any
import jax.numpy as jnp
import numpy as np
from jax.scipy import linalg
from tinygp import kernels
from tinygp.helpers import JAXArray
from tinygp.noise import Noise
from tinygp.solvers.solver import Solver
[docs]
class DirectSolver(Solver):
"""A direct solver that uses ``jax``'s built in Cholesky factorization
You generally won't instantiate this object directly but, if you do, you'll
probably want to use the :func:`DirectSolver.init` method instead of the
usual constructor.
"""
X: JAXArray
variance_value: JAXArray
covariance_value: JAXArray
scale_tril: JAXArray
def __init__(
self,
kernel: kernels.Kernel,
X: JAXArray,
noise: Noise,
*,
covariance: Any | None = None,
):
"""Build a :class:`DirectSolver` for a given kernel and coordinates
Args:
kernel: The kernel function.
X: The input coordinates.
noise: The noise model for the process.
covariance: Optionally, a pre-computed array with the covariance
matrix. This should be equal to the result of calling ``kernel``
and adding ``diag``, but that is not checked.
"""
self.X = X
self.variance_value = kernel(X) + noise.diagonal()
if covariance is None:
covariance = kernel(X, X) + noise
self.covariance_value = covariance
self.scale_tril = linalg.cholesky(covariance, lower=True)
[docs]
def variance(self) -> JAXArray:
return self.variance_value
[docs]
def covariance(self) -> JAXArray:
return self.covariance_value
[docs]
def normalization(self) -> JAXArray:
return jnp.sum(
jnp.log(jnp.diag(self.scale_tril))
) + 0.5 * self.scale_tril.shape[0] * np.log(2 * np.pi)
[docs]
def solve_triangular(self, y: JAXArray, *, transpose: bool = False) -> JAXArray:
if transpose:
return linalg.solve_triangular(self.scale_tril, y, lower=True, trans=1)
else:
return linalg.solve_triangular(self.scale_tril, y, lower=True)
[docs]
def dot_triangular(self, y: JAXArray) -> JAXArray:
return jnp.einsum("ij,j...->i...", self.scale_tril, y)
[docs]
def condition(
self, kernel: kernels.Kernel, X_test: JAXArray | None, noise: Noise
) -> Any:
"""Compute the covariance matrix for a conditional GP
Args:
kernel: The kernel for the covariance between the observed and
predicted data.
X_test: The coordinates of the predicted points. Defaults to the
input coordinates.
noise: The noise model for the predicted process.
"""
if X_test is None:
Ks = kernel(self.X, self.X)
Kss = Ks + noise
else:
Ks = kernel(self.X, X_test)
Kss = kernel(X_test, X_test) + noise
A = self.solve_triangular(Ks)
return Kss - A.transpose() @ A