Source code for tinygp.means
"""
In ``tinygp``, the Gaussian process mean function can be defined using any
callable object, but this submodule includes two helper classes for defining
means. When defining your own mean function, it's important to remember that
your callable should accept as input a single input coordinate (i.e. not a
*vector* of coordinates), and return the scalar value of the mean at that
coordinate. ``tinygp`` will handle all the relevant ``vmap``-ing and
broadcasting.
"""
from __future__ import annotations
__all__ = ["Mean", "Conditioned"]
from abc import abstractmethod
from typing import Callable
import equinox as eqx
import jax
from tinygp.helpers import JAXArray
from tinygp.kernels.base import Kernel
class MeanBase(eqx.Module):
    """Abstract base class for ``tinygp`` mean functions.

    Concrete subclasses must override :meth:`__call__`, which maps a single
    input coordinate to the value of the mean at that coordinate.
    """

    @abstractmethod
    def __call__(self, X: JAXArray) -> JAXArray:
        """Return the mean evaluated at the single coordinate ``X``."""
        raise NotImplementedError()
[docs]
class Mean(MeanBase):
    """A wrapper for the GP mean which supports a constant value or a callable.

    In ``tinygp``, a mean function can be any callable which takes as input a
    single coordinate and returns the scalar mean at that location.

    Args:
        value: Either a *scalar* constant, or a callable with the correct
            signature.
    """

    # Constant mean; stays ``None`` when a callable was supplied instead.
    value: JAXArray | None = None
    # Callable mean, declared as a static equinox field; stays ``None`` when a
    # constant was supplied instead.
    func: Callable[[JAXArray], JAXArray] | None = eqx.field(default=None, static=True)

    def __init__(self, value: JAXArray | Callable[[JAXArray], JAXArray]):
        # Only one of the two fields is assigned; the other keeps its default.
        if not callable(value):
            self.value = value
            return
        self.func = value

    def __call__(self, X: JAXArray) -> JAXArray:
        # A stored constant takes priority and is returned regardless of X.
        constant = self.value
        if constant is not None:
            return constant
        assert self.func is not None
        return self.func(X)
[docs]
class Conditioned(MeanBase):
    r"""The mean of a process conditioned on observed data.

    Evaluated at a single coordinate ``X``, this returns
    :math:`\sum_i k(X, X_i)\,\alpha_i` over the conditioning coordinates
    :math:`X_i`, plus ``mean_function(X)`` when ``include_mean`` is ``True``.

    Args:
        X: The coordinates of the data the process was conditioned on.
        alpha: The weight vector applied to the kernel evaluations between
            the test coordinate and ``X``. To obtain the usual GP predictive
            mean this must be :math:`K^{-1}\,(y - m)`, where :math:`K` is the
            covariance matrix of the observations under the base process,
            :math:`y` is the observed data and :math:`m` is the base mean at
            ``X`` (zero if no mean is used). Note that this is *not*
            :math:`L^{-1}\,y` for a Cholesky factor :math:`L`; the caller is
            responsible for computing it.
        kernel: The predictive kernel; this will generally be the kernel used
            by the original process.
        include_mean: If ``True``, the predicted values will include the mean
            function evaluated at the test coordinate.
        mean_function: The mean function of the base process. Used only if
            ``include_mean`` is ``True``; when it is ``None`` no mean term is
            added.
    """

    # Conditioning coordinates, one per row along the leading axis.
    X: JAXArray
    # Weights multiplying the kernel vector (see class docstring).
    alpha: JAXArray
    kernel: Kernel
    include_mean: bool
    mean_function: MeanBase | None = None

    def __call__(self, X: JAXArray) -> JAXArray:
        # Map over the conditioning points (axis 0 of self.X) while holding
        # the single test coordinate fixed: Ks[i] = k(X, self.X[i]).
        Ks = jax.vmap(self.kernel.evaluate, in_axes=(None, 0), out_axes=0)(X, self.X)
        mu = Ks @ self.alpha
        if self.include_mean and self.mean_function is not None:
            # JAX arrays are immutable, so this rebinds ``mu`` rather than
            # mutating it in place.
            mu += self.mean_function(X)
        return mu