from __future__ import annotations
__all__ = ["GaussianProcess"]
from collections.abc import Sequence
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Callable,
NamedTuple,
)
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from tinygp import kernels, means
from tinygp.helpers import JAXArray
from tinygp.kernels.quasisep import Quasisep
from tinygp.noise import Diagonal, Noise
from tinygp.solvers import DirectSolver, QuasisepSolver
from tinygp.solvers.quasisep.core import SymmQSM
from tinygp.solvers.solver import Solver
if TYPE_CHECKING:
from tinygp.numpyro_support import TinyDistribution
class GaussianProcess(eqx.Module):
"""An interface for designing a Gaussian Process regression model
Args:
kernel (Kernel): The kernel function
X (JAXArray): The input coordinates. This can be any PyTree that is
compatible with ``kernel`` where the zeroth dimension is ``N_data``,
the size of the data set.
diag (JAXArray, optional): The value to add to the diagonal of the
covariance matrix, often used to capture measurement uncertainty.
This should be a scalar or have the shape ``(N_data,)``. If not
provided, this will default to the square root of machine epsilon
for the data type being used. This can sometimes be sufficient to
avoid numerical issues, but if you're getting NaNs, try increasing
this value.
noise (Noise, optional): Used to implement more expressive observation
noise models than those supported by just ``diag``. This can be any
object that implements the :class:`tinygp.noise.Noise` protocol. If
this is provided, the ``diag`` parameter will be ignored.
mean (Callable, optional): A callable or constant mean function that
            will be evaluated with ``X`` as input: ``mean(X)``
solver: The solver type to be used to execute the required linear
algebra.
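
    For example, here is a minimal sketch of constructing a model (the
    kernel choice, length scale, and data below are purely illustrative):

    .. code-block:: python

        import jax.numpy as jnp
        from tinygp import GaussianProcess, kernels

        X = jnp.linspace(0.0, 10.0, 50)
        gp = GaussianProcess(kernels.ExpSquared(scale=1.5), X, diag=1e-4)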
"""
num_data: int = eqx.field(static=True)
dtype: np.dtype = eqx.field(static=True)
kernel: kernels.Kernel
X: JAXArray
mean_function: means.MeanBase
mean: JAXArray
noise: Noise
solver: Solver
def __init__(
self,
kernel: kernels.Kernel,
X: JAXArray,
*,
diag: JAXArray | None = None,
noise: Noise | None = None,
mean: means.MeanBase | Callable[[JAXArray], JAXArray] | JAXArray | None = None,
solver: Any | None = None,
mean_value: JAXArray | None = None,
covariance_value: Any | None = None,
**solver_kwargs: Any,
):
self.kernel = kernel
self.X = X
if isinstance(mean, means.MeanBase):
self.mean_function = mean
elif mean is None:
self.mean_function = means.Mean(jnp.zeros(()))
else:
self.mean_function = means.Mean(mean)
if mean_value is None:
mean_value = jax.vmap(self.mean_function)(self.X)
self.num_data = mean_value.shape[0]
self.dtype = mean_value.dtype
self.mean = mean_value
if self.mean.ndim != 1:
            raise ValueError(
                f"Invalid mean shape: expected ndim = 1, got ndim={self.mean.ndim}"
            )
if noise is None:
diag = _default_diag(self.mean) if diag is None else diag
noise = Diagonal(diag=jnp.broadcast_to(diag, self.mean.shape))
self.noise = noise
if solver is None:
if isinstance(covariance_value, SymmQSM) or isinstance(kernel, Quasisep):
solver = QuasisepSolver
else:
solver = DirectSolver
self.solver = solver(
kernel,
self.X,
self.noise,
covariance=covariance_value,
**solver_kwargs,
)
@property
def loc(self) -> JAXArray:
return self.mean
@property
def variance(self) -> JAXArray:
return self.solver.variance()
@property
def covariance(self) -> JAXArray:
return self.solver.covariance()
def log_probability(self, y: JAXArray) -> JAXArray:
"""Compute the log probability of this multivariate normal
Args:
y (JAXArray): The observed data. This should have the shape
``(N_data,)``, where ``N_data`` was the zeroth axis of the ``X``
data provided when instantiating this object.
Returns:
The marginal log probability of this multivariate normal model,
evaluated at ``y``.
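
        For example (a sketch, reusing the ``gp`` and ``X`` from the class
        docstring; the observations ``y`` are illustrative):

        .. code-block:: python

            y = jnp.sin(X)
            log_prob = gp.log_probability(y)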
"""
return self._compute_log_prob(self._get_alpha(y))
def condition(
self,
y: JAXArray,
X_test: JAXArray | None = None,
*,
diag: JAXArray | None = None,
noise: Noise | None = None,
include_mean: bool = True,
kernel: kernels.Kernel | None = None,
) -> ConditionResult:
"""Condition the model on observed data and
Args:
y (JAXArray): The observed data. This should have the shape
``(N_data,)``, where ``N_data`` was the zeroth axis of the ``X``
data provided when instantiating this object.
            X_test (JAXArray, optional): The coordinates where the prediction
                should be evaluated. This should have a data type compatible
                with the ``X`` data provided when instantiating this object. If
                it is not provided, ``X`` will be used by default, so the
                predictions will be made at the original input coordinates.
            diag (JAXArray, optional): Will be passed as the diagonal to the
                conditioned ``GaussianProcess`` object, so this can be used to
                introduce, for example, observational noise into the predicted
                data.
            noise (Noise, optional): A more expressive observation noise model
                to attach to the conditioned ``GaussianProcess``. As in the
                constructor, ``diag`` is ignored when this is provided.
include_mean (bool, optional): If ``True`` (default), the predicted
values will include the mean function evaluated at ``X_test``.
kernel (Kernel, optional): A kernel to optionally specify the
covariance between the observed data and predicted data. See
:ref:`mixture` for an example.
Returns:
A named tuple where the first element ``log_probability`` is the log
marginal probability of the model, and the second element ``gp`` is
the :class:`GaussianProcess` object describing the conditional
distribution evaluated at ``X_test``.
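
        For example (a sketch; the test grid is illustrative):

        .. code-block:: python

            X_test = jnp.linspace(0.0, 10.0, 100)
            log_prob, cond_gp = gp.condition(y, X_test)
            mu = cond_gp.loc        # conditional mean at X_test
            var = cond_gp.variance  # conditional variance at X_test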
"""
# If X_test is provided, we need to check that the tree structure
# matches that of the input data, and that the shapes are all compatible
# (i.e. the dimension of the inputs must match). This is slightly
# convoluted since we need to support arbitrary pytrees.
if X_test is not None:
matches = jax.tree_util.tree_map(
lambda a, b: jnp.ndim(a) == jnp.ndim(b)
and jnp.shape(a)[1:] == jnp.shape(b)[1:],
self.X,
X_test,
)
if not jax.tree_util.tree_reduce(lambda a, b: a and b, matches):
raise ValueError(
"`X_test` must have the same tree structure as the input `X`, "
"and all but the leading dimension must have matching sizes"
)
alpha, log_prob, mean_value = self._condition(y, X_test, include_mean, kernel)
if kernel is None:
kernel = self.kernel
if noise is None:
diag = _default_diag(mean_value) if diag is None else diag
noise = Diagonal(diag=jnp.broadcast_to(diag, mean_value.shape))
covariance_value = self.solver.condition(kernel, X_test, noise)
if X_test is None:
X_test = self.X
        # The conditional GP will also be a GP, with the mean and covariance
        # specified by a :class:`tinygp.means.Conditioned` and a
        # :class:`tinygp.kernels.Conditioned` respectively.
gp = GaussianProcess(
kernels.Conditioned(self.X, self.solver, kernel),
X_test,
noise=noise,
mean=means.Conditioned(
self.X,
alpha,
kernel,
include_mean=include_mean,
mean_function=self.mean_function,
),
mean_value=mean_value,
covariance_value=covariance_value,
)
return ConditionResult(log_prob, gp)
@partial(
jax.jit,
static_argnames=("include_mean", "return_var", "return_cov"),
)
def predict(
self,
y: JAXArray,
X_test: JAXArray | None = None,
*,
kernel: kernels.Kernel | None = None,
include_mean: bool = True,
return_var: bool = False,
return_cov: bool = False,
) -> JAXArray | tuple[JAXArray, JAXArray]:
"""Predict the GP model at new test points conditioned on observed data
Args:
y (JAXArray): The observed data. This should have the shape
``(N_data,)``, where ``N_data`` was the zeroth axis of the ``X``
data provided when instantiating this object.
            X_test (JAXArray, optional): The coordinates where the prediction
                should be evaluated. This should have a data type compatible
                with the ``X`` data provided when instantiating this object. If
                it is not provided, ``X`` will be used by default, so the
                predictions will be made at the original input coordinates.
include_mean (bool, optional): If ``True`` (default), the predicted
values will include the mean function evaluated at ``X_test``.
return_var (bool, optional): If ``True``, the variance of the
predicted values at ``X_test`` will be returned.
return_cov (bool, optional): If ``True``, the covariance of the
predicted values at ``X_test`` will be returned. If
``return_var`` is ``True``, this flag will be ignored.
Returns:
The mean of the predictive model evaluated at ``X_test``, with shape
``(N_test,)`` where ``N_test`` is the zeroth dimension of
``X_test``. If either ``return_var`` or ``return_cov`` is ``True``,
the variance or covariance of the predicted process will also be
returned with shape ``(N_test,)`` or ``(N_test, N_test)``
respectively.
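
        For example (a sketch; this method is a thin wrapper around
        :func:`condition` that returns only the requested moments):

        .. code-block:: python

            mu, var = gp.predict(y, X_test, return_var=True)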
"""
_, cond = self.condition(y, X_test, kernel=kernel, include_mean=include_mean)
if return_var:
return cond.loc, cond.variance
if return_cov:
return cond.loc, cond.covariance
return cond.loc
def sample(
self,
key: jax.random.KeyArray,
shape: Sequence[int] | None = None,
) -> JAXArray:
"""Generate samples from the prior process
Args:
            key: A ``jax`` random number key array.
            shape (tuple, optional): The number and shape of samples to
                generate.
Returns:
            The sampled realizations from the process with shape ``shape +
            (N_data,)``, where ``N_data`` is the zeroth dimension of the ``X``
            coordinates provided when instantiating this process.
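
        For example (a sketch):

        .. code-block:: python

            import jax

            key = jax.random.PRNGKey(0)
            samples = gp.sample(key, shape=(3,))  # shape (3, N_data)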
"""
return self._sample(key, shape)
def numpyro_dist(self, **kwargs: Any) -> TinyDistribution:
"""Get the numpyro MultivariateNormal distribution for this process"""
from tinygp.numpyro_support import TinyDistribution
return TinyDistribution(self, **kwargs)
@partial(jax.jit, static_argnums=(2,))
def _sample(
self,
key: jax.random.KeyArray,
shape: Sequence[int] | None,
) -> JAXArray:
if shape is None:
shape = (self.num_data,)
else:
shape = (self.num_data,) + tuple(shape)
normal_samples = jax.random.normal(key, shape=shape, dtype=self.dtype)
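        # dot_triangular multiplies by a square-root factor L of the
        # covariance (K = L L^T), so the transformed standard normal draws
        # have covariance K; moveaxis puts the data dimension last so that
        # adding the mean broadcasts correctly.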
return self.mean + jnp.moveaxis(
self.solver.dot_triangular(normal_samples), 0, -1
)
@jax.jit
def _compute_log_prob(self, alpha: JAXArray) -> JAXArray:
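        # With alpha = L^-1 (y - loc), the sum of squares below is the
        # quadratic form (y - loc)^T K^-1 (y - loc), and
        # solver.normalization() supplies the 0.5 * log(det(2 pi K))
        # constant of the multivariate normal density.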
loglike = -0.5 * jnp.sum(jnp.square(alpha)) - self.solver.normalization()
return jnp.where(jnp.isfinite(loglike), loglike, -jnp.inf)
@jax.jit
def _get_alpha(self, y: JAXArray) -> JAXArray:
return self.solver.solve_triangular(y - self.loc)
@partial(jax.jit, static_argnums=(3,))
def _condition(
self,
y: JAXArray,
X_test: JAXArray | None,
include_mean: bool,
kernel: kernels.Kernel | None = None,
) -> tuple[JAXArray, JAXArray, JAXArray]:
alpha = self._get_alpha(y)
log_prob = self._compute_log_prob(alpha)
        # So far alpha = L^-1 (y - loc), but below we actually need
        # alpha = K^-1 (y - loc), so we apply one more (transposed)
        # triangular solve
alpha = self.solver.solve_triangular(alpha, transpose=True)
if X_test is None:
X_test = self.X
# In this common case (where we're predicting the GP at the data
# points, using the original kernel), the mean is especially fast to
# compute; so let's use that calculation here.
if kernel is None:
delta = self.noise @ alpha
mean_value = y - delta
if not include_mean:
mean_value -= self.loc
else:
mean_value = kernel.matmul(self.X, y=alpha)
if include_mean:
mean_value += self.loc
else:
if kernel is None:
kernel = self.kernel
mean_value = kernel.matmul(X_test, self.X, alpha)
if include_mean:
mean_value += jax.vmap(self.mean_function)(X_test)
return alpha, log_prob, mean_value
class ConditionResult(NamedTuple):
"""The result of conditioning a :class:`GaussianProcess` on data
This has two entries, ``log_probability`` and ``gp``, that are described
below.
"""
log_probability: JAXArray
"""The log probability of the conditioned model
In other words, this is the marginal likelihood for the kernel parameters,
given the observed data, or the multivariate normal log probability
evaluated at the given data.
"""
gp: GaussianProcess
"""A :class:`GaussianProcess` describing the conditional distribution
This will have a mean and covariance conditioned on the observed data, but
    it is otherwise a fully functional GP that you can sample from or condition
further (although that's probably not going to be very efficient).
"""
def _default_diag(reference: JAXArray) -> JAXArray:
"""Default to adding some amount of jitter to the diagonal, just in case,
we use sqrt(eps) for the dtype of the mean function because that seems to
give sensible results in general.
"""
return jnp.sqrt(jnp.finfo(reference).eps)