GaussianProcess
- class tinygp.GaussianProcess(kernel: kernels.Kernel, X: JAXArray, *, diag: JAXArray | None = None, noise: Noise | None = None, mean: means.MeanBase | Callable[[JAXArray], JAXArray] | JAXArray | None = None, solver: Any | None = None, mean_value: JAXArray | None = None, covariance_value: Any | None = None, **solver_kwargs: Any)
Bases: Module
An interface for designing a Gaussian Process regression model.
- Parameters:
kernel (Kernel) – The kernel function.
X (JAXArray) – The input coordinates. This can be any PyTree that is compatible with kernel, where the zeroth dimension is N_data, the size of the data set.
diag (JAXArray, optional) – The value to add to the diagonal of the covariance matrix, often used to capture measurement uncertainty. This should be a scalar or have the shape (N_data,). If not provided, this will default to the square root of machine epsilon for the data type being used. This can sometimes be sufficient to avoid numerical issues, but if you're getting NaNs, try increasing this value.
noise (Noise, optional) – Used to implement more expressive observation noise models than those supported by diag alone. This can be any object that implements the tinygp.noise.Noise protocol. If this is provided, the diag parameter will be ignored.
mean (Callable, optional) – A callable or constant mean function that will be evaluated with X as input, i.e. mean(X).
solver – The solver type to be used to execute the required linear algebra.
- condition(y: JAXArray, X_test: JAXArray | None = None, *, diag: JAXArray | None = None, noise: Noise | None = None, include_mean: bool = True, kernel: kernels.Kernel | None = None) -> ConditionResult
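Condition the model on observed data and return the conditional process evaluated at X_test, together with the log marginal probability of the data.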
- Parameters:
y (JAXArray) – The observed data. This should have the shape (N_data,), where N_data was the zeroth axis of the X data provided when instantiating this object.
X_test (JAXArray, optional) – The coordinates where the prediction should be evaluated. This should have a data type compatible with the X data provided when instantiating this object. If it is not provided, X will be used by default, so the predictions will be made at the input coordinates of the observed data.
diag (JAXArray, optional) – Will be passed as the diagonal to the conditioned GaussianProcess object, so this can be used to introduce, for example, observational noise to predicted data.
include_mean (bool, optional) – If True (default), the predicted values will include the mean function evaluated at X_test.
kernel (Kernel, optional) – A kernel to optionally specify the covariance between the observed data and predicted data. See Mixture of Kernels for an example.
- Returns:
A named tuple where the first element, log_probability, is the log marginal probability of the model, and the second element, gp, is the GaussianProcess object describing the conditional distribution evaluated at X_test.
- log_probability(y: tinygp.helpers.JAXArray) -> tinygp.helpers.JAXArray
Compute the log probability of the observed data under this multivariate normal distribution.
- Parameters:
y (JAXArray) – The observed data. This should have the shape (N_data,), where N_data was the zeroth axis of the X data provided when instantiating this object.
- Returns:
The marginal log probability of this multivariate normal model, evaluated at y.
- numpyro_dist(**kwargs: Any) -> TinyDistribution
Get the numpyro MultivariateNormal distribution for this process.
- predict(y: JAXArray, X_test: JAXArray | None = None, *, kernel: kernels.Kernel | None = None, include_mean: bool = True, return_var: bool = False, return_cov: bool = False) -> JAXArray | tuple[JAXArray, JAXArray]
Predict the GP model at new test points conditioned on observed data.
- Parameters:
y (JAXArray) – The observed data. This should have the shape (N_data,), where N_data was the zeroth axis of the X data provided when instantiating this object.
X_test (JAXArray, optional) – The coordinates where the prediction should be evaluated. This should have a data type compatible with the X data provided when instantiating this object. If it is not provided, X will be used by default, so the predictions will be made at the input coordinates of the observed data.
include_mean (bool, optional) – If True (default), the predicted values will include the mean function evaluated at X_test.
return_var (bool, optional) – If True, the variance of the predicted values at X_test will be returned.
return_cov (bool, optional) – If True, the covariance of the predicted values at X_test will be returned. If return_var is True, this flag will be ignored.
- Returns:
The mean of the predictive model evaluated at X_test, with shape (N_test,), where N_test is the zeroth dimension of X_test. If either return_var or return_cov is True, the variance or covariance of the predicted process will also be returned, with shape (N_test,) or (N_test, N_test) respectively.
- sample(key: jax.random.KeyArray, shape: Sequence[int] | None = None) -> JAXArray
Generate samples from the prior process.
- Parameters:
key – A jax random number key array.
shape (tuple, optional) – The number and shape of samples to generate.
- Returns:
The sampled realizations from the process with shape (N_data,) + shape, where N_data is the zeroth dimension of the X coordinates provided when instantiating this process.