"""
This subpackage provides the tools needed to build expressive observation
processes for ``tinygp`` Gaussian process models. The most commonly used noise
model is :class:`Diagonal`, which adds a constant diagonal matrix to the process
covariance to represent per-observation noise. This subpackage also includes a
:class:`Dense` model for adding a full rank observation model, and
:class:`Banded` to capture noise that can be represented by a banded matrix.
"""
from __future__ import annotations
__all__ = ["Diagonal", "Dense", "Banded"]
from abc import abstractmethod
from typing import TYPE_CHECKING
import equinox as eqx
import jax.numpy as jnp
import numpy as np
from tinygp.helpers import JAXArray
if TYPE_CHECKING:
from tinygp.solvers.quasisep.core import DiagQSM, SymmQSM
class Noise(eqx.Module):
"""An abstract base class defining the noise model protocol"""
__array_priority__ = 2001
@abstractmethod
def diagonal(self) -> JAXArray:
"""The diagonal elements of the noise model as an array"""
raise NotImplementedError
@abstractmethod
def __add__(self, other: JAXArray) -> JAXArray:
raise NotImplementedError
@abstractmethod
def __radd__(self, other: JAXArray) -> JAXArray:
raise NotImplementedError
@abstractmethod
def __matmul__(self, other: JAXArray) -> JAXArray:
raise NotImplementedError
@abstractmethod
def to_qsm(self) -> SymmQSM | DiagQSM:
"""This noise model represented as a quasiseparable matrix"""
raise NotImplementedError
[docs]
class Diagonal(Noise):
"""A diagonal observation noise model
This represents the observation model using per-observation measurement
variances.
Args:
diag: The diagonal elements of the noise model.
"""
diag: JAXArray
def __check_init__(self) -> None:
if jnp.ndim(self.diag) != 1:
raise ValueError(
"The diagonal for the noise model be the same shape as the data; "
"if passing a constant, it should be broadcasted first"
)
[docs]
def diagonal(self) -> JAXArray:
return self.diag
def _add(self, other: JAXArray) -> JAXArray:
return jnp.asarray(other).at[jnp.diag_indices(other.shape[0])].add(self.diag)
def __add__(self, other: JAXArray) -> JAXArray:
return self._add(other)
def __radd__(self, other: JAXArray) -> JAXArray:
return self._add(other)
def __matmul__(self, other: JAXArray) -> JAXArray:
if jnp.ndim(other) == 1:
return self.diag * other
else:
return self.diag[:, None] * other
[docs]
def to_qsm(self) -> DiagQSM:
from tinygp.solvers.quasisep.core import DiagQSM
return DiagQSM(d=self.diag)
[docs]
class Dense(Noise):
"""A full rank observation noise model
.. warning:: This model cannot be used in conjunction with the
:class:`tinygp.solvers.QuasisepSolver` for scalable computations.
Args:
value: The N-by-N full rank observation model.
"""
value: JAXArray
[docs]
def diagonal(self) -> JAXArray:
return jnp.diag(self.value)
def __add__(self, other: JAXArray) -> JAXArray:
return self.value + other
def __radd__(self, other: JAXArray) -> JAXArray:
return other + self.value
def __matmul__(self, other: JAXArray) -> JAXArray:
return self.value @ other
[docs]
def to_qsm(self) -> SymmQSM | DiagQSM:
"""This cannot be compactly represented as a quasiseparable matrix"""
raise NotImplementedError
[docs]
class Banded(Noise):
r"""A banded observation noise model
This model captures noise that can be represented by a small number of
off-diagonal elements in the observation matrix. One practical example of
such an observation model is discussed by `Delisle et al. (2020)
<https://arxiv.org/abs/2004.10678>`_. This matrix is defined by two arrays:
``diag`` and ``off_diags``, with shapes ``(N,)`` and ``(N, J)``
respectively, where ``N`` is the number of data points and ``J`` is the
number of non-zero off-diagonals required.
For example, the following matrix has ``N = 4`` and ``J = 2``:
.. math::
N = \left(\begin{array}{cccc}
n_{11} & n_{12} & n_{13} & 0 \\
n_{12} & n_{22} & n_{23} & n_{24} \\
n_{13} & n_{23} & n_{33} & n_{34} \\
0 & n_{24} & n_{34} & n_{44}
\end{array}\right)
and it would be represented by the following arrays:
.. code-block:: python
diag = [n11, n22, n33, n44]
and
.. code-block:: python
off_diags = [
[n12, n13],
[n23, n24],
[n34, * ],
[ *, * ],
]
Where ``*`` represents an element that can have any arbitrary value, since it
won't ever be accessed.
"""
diag: JAXArray
off_diags: JAXArray
[docs]
def diagonal(self) -> JAXArray:
return self.diag
def _indices(
self,
) -> tuple[tuple[JAXArray, JAXArray], tuple[JAXArray, JAXArray]]:
N, J = jnp.shape(self.off_diags)
sparse_idx_1 = []
sparse_idx_2 = []
dense_idx_1 = []
dense_idx_2 = []
for j in range(J):
sparse_idx_1.append(np.arange(N - j - 1))
sparse_idx_2.append(np.full(N - j - 1, j, dtype=int))
dense_idx_1.append(np.arange(0, N - j - 1))
dense_idx_2.append(np.arange(j + 1, N))
return (
(np.concatenate(sparse_idx_1), np.concatenate(sparse_idx_2)),
(np.concatenate(dense_idx_1), np.concatenate(dense_idx_2)),
)
def _add(self, other: JAXArray) -> JAXArray:
sparse_idx, dense_idx = self._indices()
# Start by adding the diagonal
result = jnp.asarray(other).at[jnp.diag_indices(other.shape[0])].add(self.diag)
# Then the off diagonals, assuming symmetric
return result.at[
(
np.append(dense_idx[0], dense_idx[1]),
np.append(dense_idx[1], dense_idx[0]),
)
].add(
self.off_diags[
(
np.append(sparse_idx[0], sparse_idx[0]),
np.append(sparse_idx[1], sparse_idx[1]),
)
]
)
def __add__(self, other: JAXArray) -> JAXArray:
return self._add(other)
def __radd__(self, other: JAXArray) -> JAXArray:
return self._add(other)
def __matmul__(self, other: JAXArray) -> JAXArray:
return self.to_qsm() @ other
[docs]
def to_qsm(self) -> SymmQSM:
from tinygp.solvers.quasisep import core
N, J = jnp.shape(self.off_diags)
p = jnp.repeat(jnp.eye(1, J), N, axis=0)
q = self.off_diags
a = jnp.repeat(jnp.eye(J, k=1)[None], N, axis=0)
return core.SymmQSM(
diag=core.DiagQSM(d=self.diag),
lower=core.StrictLowerTriQSM(p=p, q=q, a=a),
)