"""
The kernels implemented in this subpackage are used with the
:class:`tinygp.solvers.QuasisepSolver` to allow scalable GP computations by
exploiting quasiseparable structure in the relevant matrices (see
:ref:`api-solvers-quasisep` for more technical details). For now, these methods
are experimental, so you may find the documentation patchy in places. You are
encouraged to `open issues or pull requests
<https://github.com/dfm/tinygp/issues>`_ as you find gaps.
"""
from __future__ import annotations
# Names exported by ``from <this module> import *``: the abstract base class,
# the composition helpers (Wrapper, Sum, Product, Scale), and the concrete
# kernels defined below.
__all__ = [
"Quasisep",
"Wrapper",
"Sum",
"Product",
"Scale",
"Celerite",
"SHO",
"Exp",
"Matern32",
"Matern52",
"Cosine",
"CARMA",
]
from abc import abstractmethod
from typing import Any
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
from tinygp.helpers import JAXArray
from tinygp.kernels.base import Kernel
from tinygp.solvers.quasisep.core import DiagQSM, StrictLowerTriQSM, SymmQSM
from tinygp.solvers.quasisep.general import GeneralQSM
[docs]
class Quasisep(Kernel):
    """The parent class shared by every quasiseparable kernel

    Rather than supplying the ``p``, ``q``, and ``a`` elements of a
    :class:`tinygp.solvers.quasisep.core.StrictLowerTriQSM` directly,
    subclasses implement ``h`` (:func:`observation_model`), ``Pinf``
    (:func:`stationary_covariance`), and ``A`` (:func:`transition_matrix`),
    which are related to those elements by:

    - ``q = h``,
    - ``p = h.T @ Pinf @ A``, and
    - ``a = A``.

    This notation is borrowed from state space models for stochastic
    differential equations. So far it seems like a good way to specify these
    models, but the details are subject to change in future versions of
    ``tinygp``.
    """

    @abstractmethod
    def design_matrix(self) -> JAXArray:
        """The design matrix for the process"""
        raise NotImplementedError

    @abstractmethod
    def stationary_covariance(self) -> JAXArray:
        """The stationary covariance of the process"""
        raise NotImplementedError

    @abstractmethod
    def observation_model(self, X: JAXArray) -> JAXArray:
        """The observation model for the process"""
        raise NotImplementedError

    @abstractmethod
    def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        """The transition matrix between two coordinates"""
        raise NotImplementedError

    def coord_to_sortable(self, X: JAXArray) -> JAXArray:
        """Map a coordinate onto the 1-D value used for ordering

        The default implementation returns ``X`` unchanged. Kernels whose
        inputs are structured (e.g. multivariate inputs) can override this to
        unwrap that structure appropriately.
        """
        return X

    def to_symm_qsm(self, X: JAXArray) -> SymmQSM:
        """The symmetric quasiseparable representation of this kernel"""

        def shift_down(leaf: JAXArray) -> JAXArray:
            # Repeat the first entry and drop the last, so that entry ``n``
            # holds the coordinate that precedes ``X[n]``.
            return jnp.append(leaf[0], leaf[:-1])

        previous = jax.tree_util.tree_map(shift_down, X)
        transitions = jax.vmap(self.transition_matrix)(previous, X)
        obs = jax.vmap(self.observation_model)(X)
        weighted = obs @ self.stationary_covariance()

        # The diagonal must be computed before the transition is folded in.
        variance = jnp.sum(weighted * obs, axis=1)
        lower = StrictLowerTriQSM(
            p=jax.vmap(jnp.dot)(weighted, transitions),
            q=obs,
            a=transitions,
        )
        return SymmQSM(diag=DiagQSM(d=variance), lower=lower)

    def to_general_qsm(self, X1: JAXArray, X2: JAXArray) -> GeneralQSM:
        """The generalized quasiseparable representation of this kernel"""
        tree_map = jax.tree_util.tree_map
        transition = jax.vmap(self.transition_matrix)
        observe = jax.vmap(self.observation_model)
        as_sortable = jax.vmap(self.coord_to_sortable)

        # For each entry of X1, the index of the last entry of X2 that does
        # not sort after it (-1 when every entry of X2 sorts after it).
        idx = jnp.searchsorted(as_sortable(X2), as_sortable(X1), side="right") - 1

        shifted = tree_map(lambda x: jnp.append(x[0], x[:-1]), X2)
        Pinf = self.stationary_covariance()
        a = transition(shifted, X2)

        h1 = observe(X1)
        h2 = observe(X2)
        last = h2.shape[0] - 1

        def gather(where: JAXArray) -> JAXArray:
            return tree_map(lambda x: jnp.asarray(x)[where], X2)

        # Lower part: propagate from the bracketing X2 point up to X1.
        below = gather(jnp.clip(idx, 0, last))
        pl = jax.vmap(jnp.dot)(h1 @ Pinf, transition(below, X1))

        # Upper part: propagate from X1 up to the next X2 point.
        above = gather(jnp.clip(idx + 1, 0, last))
        qu = jax.vmap(jnp.dot)(transition(X1, above), h1)

        return GeneralQSM(pl=pl, ql=h2, pu=h2 @ Pinf, qu=qu, a=a, idx=idx)

    def matmul(
        self,
        X1: JAXArray,
        X2: JAXArray | None = None,
        y: JAXArray | None = None,
    ) -> JAXArray:
        """Multiply the kernel matrix into ``y`` using the QSM representation

        When called with two arguments, the second one is interpreted as
        ``y`` and the symmetric matrix evaluated at ``X1`` is used; with three
        arguments the general matrix between ``X1`` and ``X2`` is used.
        """
        if y is None:
            assert X2 is not None
            X2, y = None, X2

        if X2 is None:
            matrix = self.to_symm_qsm(X1)
        else:
            matrix = self.to_general_qsm(X1, X2)
        return matrix @ y

    def __add__(self, other: Kernel | JAXArray) -> Kernel:
        if isinstance(other, Quasisep):
            return Sum(self, other)
        raise ValueError(
            "Quasisep kernels can only be added to other Quasisep kernels"
        )

    def __radd__(self, other: Any) -> Kernel:
        # ``sum`` starts from the integer 0, which lands here first
        if other == 0:
            return self
        if isinstance(other, Quasisep):
            return Sum(other, self)
        raise ValueError(
            "Quasisep kernels can only be added to other Quasisep kernels"
        )

    def __mul__(self, other: Kernel | JAXArray) -> Kernel:
        if isinstance(other, Quasisep):
            return Product(self, other)
        if not isinstance(other, Kernel) and jnp.ndim(other) == 0:
            return Scale(kernel=self, scale=other)
        raise ValueError(
            "Quasisep kernels can only be multiplied by scalars and other "
            "Quasisep kernels"
        )

    def __rmul__(self, other: Any) -> Kernel:
        if isinstance(other, Quasisep):
            return Product(other, self)
        if not isinstance(other, Kernel) and jnp.ndim(other) == 0:
            return Scale(kernel=self, scale=other)
        raise ValueError(
            "Quasisep kernels can only be multiplied by scalars and other "
            "Quasisep kernels"
        )

    def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        """The kernel evaluated via the quasiseparable representation"""
        Pinf = self.stationary_covariance()
        h1 = self.observation_model(X1)
        h2 = self.observation_model(X2)

        # Both orderings are computed; ``where`` keeps the one whose
        # transition runs from the smaller to the larger coordinate.
        x1_first = h2 @ Pinf @ self.transition_matrix(X1, X2) @ h1
        x2_first = h1 @ Pinf @ self.transition_matrix(X2, X1) @ h2
        in_order = self.coord_to_sortable(X1) < self.coord_to_sortable(X2)
        return jnp.where(in_order, x1_first, x2_first)

    def evaluate_diag(self, X: JAXArray) -> JAXArray:
        """The variance at ``X``, which needs no transition matrix"""
        obs = self.observation_model(X)
        Pinf = self.stationary_covariance()
        return obs @ Pinf @ obs
[docs]
class Wrapper(Quasisep):
    """A base class for wrapping kernels with some custom implementations

    Every method forwards to the wrapped ``kernel``. The coordinate-dependent
    methods first pass their inputs through :func:`coord_to_sortable`.
    """

    kernel: Quasisep

    def coord_to_sortable(self, X: JAXArray) -> JAXArray:
        inner = self.kernel
        return inner.coord_to_sortable(X)

    def design_matrix(self) -> JAXArray:
        inner = self.kernel
        return inner.design_matrix()

    def stationary_covariance(self) -> JAXArray:
        inner = self.kernel
        return inner.stationary_covariance()

    def observation_model(self, X: JAXArray) -> JAXArray:
        coord = self.coord_to_sortable(X)
        return self.kernel.observation_model(coord)

    def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        start = self.coord_to_sortable(X1)
        stop = self.coord_to_sortable(X2)
        return self.kernel.transition_matrix(start, stop)
[docs]
class Sum(Quasisep):
    """A helper to represent the sum of two quasiseparable kernels

    The matrices of the two terms are stacked block-diagonally and their
    observation vectors are concatenated.
    """

    kernel1: Quasisep
    kernel2: Quasisep

    def coord_to_sortable(self, X: JAXArray) -> JAXArray:
        """We assume that both kernels use the same coordinates"""
        return self.kernel1.coord_to_sortable(X)

    def design_matrix(self) -> JAXArray:
        first = self.kernel1.design_matrix()
        second = self.kernel2.design_matrix()
        return jsp.linalg.block_diag(first, second)

    def stationary_covariance(self) -> JAXArray:
        first = self.kernel1.stationary_covariance()
        second = self.kernel2.stationary_covariance()
        return jsp.linalg.block_diag(first, second)

    def observation_model(self, X: JAXArray) -> JAXArray:
        first = self.kernel1.observation_model(X)
        second = self.kernel2.observation_model(X)
        return jnp.concatenate((first, second))

    def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        first = self.kernel1.transition_matrix(X1, X2)
        second = self.kernel2.transition_matrix(X1, X2)
        return jsp.linalg.block_diag(first, second)
[docs]
class Product(Quasisep):
    """A helper to represent the product of two quasiseparable kernels

    Each quantity is built by combining the corresponding quantities of the
    two factors with ``_prod_helper``.
    """

    kernel1: Quasisep
    kernel2: Quasisep

    def coord_to_sortable(self, X: JAXArray) -> JAXArray:
        """We assume that both kernels use the same coordinates"""
        return self.kernel1.coord_to_sortable(X)

    def design_matrix(self) -> JAXArray:
        F1 = self.kernel1.design_matrix()
        F2 = self.kernel2.design_matrix()
        eye1 = jnp.eye(F1.shape[0])
        eye2 = jnp.eye(F2.shape[0])
        # Each factor's design matrix is paired with an identity of the other
        # factor's size, and the two contributions are summed.
        return _prod_helper(F1, eye2) + _prod_helper(eye1, F2)

    def stationary_covariance(self) -> JAXArray:
        first = self.kernel1.stationary_covariance()
        second = self.kernel2.stationary_covariance()
        return _prod_helper(first, second)

    def observation_model(self, X: JAXArray) -> JAXArray:
        first = self.kernel1.observation_model(X)
        second = self.kernel2.observation_model(X)
        return _prod_helper(first, second)

    def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        first = self.kernel1.transition_matrix(X1, X2)
        second = self.kernel2.transition_matrix(X1, X2)
        return _prod_helper(first, second)
[docs]
class Scale(Wrapper):
    """The product of a scalar and a quasiseparable kernel

    Only the stationary covariance is overridden: it is the wrapped kernel's
    stationary covariance multiplied by ``scale``.
    """

    scale: JAXArray | float

    def stationary_covariance(self) -> JAXArray:
        unscaled = self.kernel.stationary_covariance()
        return self.scale * unscaled
[docs]
class Celerite(Quasisep):
    r"""The baseline kernel from the ``celerite`` package

    This form of the kernel was introduced by `Foreman-Mackey et al. (2017)
    <https://arxiv.org/abs/1703.09710>`_, and implemented in the `celerite
    <https://celerite.readthedocs.io>`_ package. It shouldn't generally be used
    on its own, and other kernels described in this subpackage should generally
    be preferred.

    This kernel takes the form:

    .. math::

        k(\tau)=\exp(-c\,\tau)\,\left[a\,\cos(d\,\tau)+b\,\sin(d\,\tau)\right]

    for :math:`\tau = |x_i - x_j|`.

    In order to be positive definite, the parameters of this kernel must satisfy
    :math:`a\,c - b\,d > 0`, and you will see NaNs if you use parameters that
    don't satisfy this relationship.
    """

    a: JAXArray | float
    b: JAXArray | float
    c: JAXArray | float
    d: JAXArray | float

    def design_matrix(self) -> JAXArray:
        decay, freq = self.c, self.d
        return jnp.array([[-decay, -freq], [freq, -decay]])

    def stationary_covariance(self) -> JAXArray:
        decay, freq = self.c, self.d
        off_diag = -decay / freq
        corner = 1 + 2 * jnp.square(decay) / jnp.square(freq)
        return jnp.array([[1, off_diag], [off_diag, corner]])

    def observation_model(self, X: JAXArray) -> JAXArray:
        # The observation vector does not depend on the coordinate.
        del X
        a, b, c, d = self.a, self.b, self.c, self.d
        d_sq = jnp.square(d)
        norm_sq = jnp.square(c) + d_sq
        second_sq = d_sq * (a * c - b * d) / (2 * c * norm_sq)
        second = jnp.sqrt(second_sq)
        first = (c * second - jnp.sqrt(a * d_sq - norm_sq * second_sq)) / d
        return jnp.array([first, second])

    def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        lag = X2 - X1
        phase = self.d * lag
        cos = jnp.cos(phase)
        sin = jnp.sin(phase)
        rotation = jnp.array([[cos, -sin], [sin, cos]]).T
        return jnp.exp(-self.c * lag) * rotation
[docs]
class SHO(Quasisep):
    r"""The damped, driven simple harmonic oscillator kernel

    This form of the kernel was introduced by `Foreman-Mackey et al. (2017)
    <https://arxiv.org/abs/1703.09710>`_, and it takes the form:

    .. math::

        k(\tau) = \sigma^2\,\exp\left(-\frac{\omega\,\tau}{2\,Q}\right)
        \left\{\begin{array}{ll}
            1 + \omega\,\tau & \mbox{for } Q = 1/2 \\
            \cosh(f\,\omega\,\tau/2\,Q) + \sinh(f\,\omega\,\tau/2\,Q)/f
                & \mbox{for } Q < 1/2 \\
            \cos(g\,\omega\,\tau/2\,Q) + \sin(g\,\omega\,\tau/2\,Q)/g
                & \mbox{for } Q > 1/2
        \end{array}\right.

    for :math:`\tau = |x_i - x_j|`, :math:`f = \sqrt{1 - 4\,Q^2}`, and
    :math:`g = \sqrt{4\,Q^2 - 1}`.

    Args:
        omega: The parameter :math:`\omega`.
        quality: The parameter :math:`Q`.
        sigma (optional): The parameter :math:`\sigma`. Defaults to a value of
            1. Specifying the explicit value here provides a slight performance
            boost compared to independently multiplying the kernel with a
            prefactor.
    """

    omega: JAXArray | float
    quality: JAXArray | float
    sigma: JAXArray | float = eqx.field(default_factory=lambda: jnp.ones(()))

    def design_matrix(self) -> JAXArray:
        top = [0, 1]
        bottom = [-jnp.square(self.omega), -self.omega / self.quality]
        return jnp.array([top, bottom])

    def stationary_covariance(self) -> JAXArray:
        entries = jnp.array([1, jnp.square(self.omega)])
        return jnp.diag(entries)

    def observation_model(self, X: JAXArray) -> JAXArray:
        # The observation vector does not depend on the coordinate.
        del X
        return jnp.array([self.sigma, 0])

    def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        w = self.omega
        q = self.quality

        def assemble(lag: JAXArray, f: JAXArray, s: JAXArray, c: JAXArray) -> JAXArray:
            # Shared matrix layout of the two non-critical regimes; ``s`` and
            # ``c`` are the (hyperbolic) sine and cosine of the scaled lag.
            return jnp.exp(-0.5 * w * lag / q) * jnp.array(
                [
                    [c + s / f, -2 * q * w * s / f],
                    [2 * q * s / (w * f), c - s / f],
                ]
            )

        def at_critical(lag: JAXArray) -> JAXArray:
            return jnp.exp(-w * lag) * jnp.array(
                [[1 + w * lag, -jnp.square(w) * lag], [lag, 1 - w * lag]]
            )

        def oscillating(lag: JAXArray) -> JAXArray:
            # Q > 1/2: the clamp keeps the square root's argument non-negative
            f = jnp.sqrt(jnp.maximum(4 * jnp.square(q) - 1, 0))
            arg = 0.5 * f * w * lag / q
            return assemble(lag, f, jnp.sin(arg), jnp.cos(arg))

        def decaying(lag: JAXArray) -> JAXArray:
            # Q < 1/2: the clamp keeps the square root's argument non-negative
            f = jnp.sqrt(jnp.maximum(1 - 4 * jnp.square(q), 0))
            arg = 0.5 * f * w * lag / q
            return assemble(lag, f, jnp.sinh(arg), jnp.cosh(arg))

        def off_critical(lag: JAXArray) -> JAXArray:
            return jax.lax.cond(q > 0.5, oscillating, decaying, lag)

        return jax.lax.cond(
            jnp.allclose(q, 0.5), at_critical, off_critical, X2 - X1
        )
[docs]
class Exp(Quasisep):
    r"""A scalable implementation of :class:`tinygp.kernels.stationary.Exp`

    This kernel takes the form:

    .. math::

        k(\tau)=\sigma^2\,\exp\left(-\frac{\tau}{\ell}\right)

    for :math:`\tau = |x_i - x_j|`.

    Args:
        scale: The parameter :math:`\ell`.
        sigma (optional): The parameter :math:`\sigma`. Defaults to a value of
            1. Specifying the explicit value here provides a slight performance
            boost compared to independently multiplying the kernel with a
            prefactor.
    """

    scale: JAXArray | float
    sigma: JAXArray | float = eqx.field(default_factory=lambda: jnp.ones(()))

    def design_matrix(self) -> JAXArray:
        rate = -1 / self.scale
        return jnp.array([[rate]])

    def stationary_covariance(self) -> JAXArray:
        return jnp.ones((1, 1))

    def observation_model(self, X: JAXArray) -> JAXArray:
        # The observation vector does not depend on the coordinate.
        del X
        return jnp.array([self.sigma])

    def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        # Two leading axes are added so the result is a (1, 1) matrix for a
        # scalar lag.
        lag = (X2 - X1)[None, None]
        return jnp.exp(-lag / self.scale)
[docs]
class Matern32(Quasisep):
    r"""A scalable implementation of :class:`tinygp.kernels.stationary.Matern32`

    This kernel takes the form:

    .. math::

        k(\tau)=\sigma^2\,\left(1+f\,\tau\right)\,\exp(-f\,\tau)

    for :math:`\tau = |x_i - x_j|` and :math:`f = \sqrt{3} / \ell`.

    Args:
        scale: The parameter :math:`\ell`.
        sigma (optional): The parameter :math:`\sigma`. Defaults to a value of
            1. Specifying the explicit value here provides a slight performance
            boost compared to independently multiplying the kernel with a
            prefactor.
    """

    scale: JAXArray | float
    sigma: JAXArray | float = eqx.field(default_factory=lambda: jnp.ones(()))

    def noise(self) -> JAXArray:
        """Return ``4 * f**3`` for ``f = sqrt(3) / scale``"""
        rate = np.sqrt(3) / self.scale
        return 4 * rate**3

    def design_matrix(self) -> JAXArray:
        rate = np.sqrt(3) / self.scale
        top = [0, 1]
        bottom = [-jnp.square(rate), -2 * rate]
        return jnp.array([top, bottom])

    def stationary_covariance(self) -> JAXArray:
        entries = jnp.array([1, 3 / jnp.square(self.scale)])
        return jnp.diag(entries)

    def observation_model(self, X: JAXArray) -> JAXArray:
        # The observation vector does not depend on the coordinate.
        del X
        return jnp.array([self.sigma, 0])

    def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        lag = X2 - X1
        rate = np.sqrt(3) / self.scale
        envelope = jnp.exp(-rate * lag)
        return envelope * jnp.array(
            [
                [1 + rate * lag, -jnp.square(rate) * lag],
                [lag, 1 - rate * lag],
            ]
        )
[docs]
class Matern52(Quasisep):
    r"""A scalable implementation of :class:`tinygp.kernels.stationary.Matern52`

    This kernel takes the form:

    .. math::

        k(\tau)=\sigma^2\,\left(1+f\,\tau + \frac{f^2\,\tau^2}{3}\right)
            \,\exp(-f\,\tau)

    for :math:`\tau = |x_i - x_j|` and :math:`f = \sqrt{5} / \ell`.

    Args:
        scale: The parameter :math:`\ell`.
        sigma (optional): The parameter :math:`\sigma`. Defaults to a value of
            1. Specifying the explicit value here provides a slight performance
            boost compared to independently multiplying the kernel with a
            prefactor.
    """

    scale: JAXArray | float
    sigma: JAXArray | float = eqx.field(default_factory=lambda: jnp.ones(()))

    def design_matrix(self) -> JAXArray:
        f = np.sqrt(5) / self.scale
        f2 = jnp.square(f)
        last_row = [-f2 * f, -3 * f2, -3 * f]
        return jnp.array([[0, 1, 0], [0, 0, 1], last_row])

    def stationary_covariance(self) -> JAXArray:
        f = np.sqrt(5) / self.scale
        f2 = jnp.square(f)
        third = f2 / 3
        return jnp.array(
            [
                [1, 0, -third],
                [0, third, 0],
                [-third, 0, jnp.square(f2)],
            ]
        )

    def observation_model(self, X: JAXArray) -> JAXArray:
        # The observation vector does not depend on the coordinate.
        del X
        return jnp.array([self.sigma, 0, 0])

    def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        dt = X2 - X1
        f = np.sqrt(5) / self.scale
        f2 = jnp.square(f)
        d2 = jnp.square(dt)

        row0 = [
            0.5 * f2 * d2 + f * dt + 1,
            -0.5 * f * f2 * d2,
            0.5 * f2 * f * dt * (f * dt - 2),
        ]
        row1 = [
            dt * (f * dt + 1),
            -f2 * d2 + f * dt + 1,
            f2 * dt * (f * dt - 3),
        ]
        row2 = [
            0.5 * d2,
            0.5 * dt * (2 - f * dt),
            0.5 * f2 * d2 - 2 * f * dt + 1,
        ]
        return jnp.exp(-f * dt) * jnp.array([row0, row1, row2])
[docs]
class Cosine(Quasisep):
    r"""A scalable implementation of :class:`tinygp.kernels.stationary.Cosine`

    This kernel takes the form:

    .. math::

        k(\tau)=\sigma^2\,\cos(-2\,\pi\,\tau/\ell)

    for :math:`\tau = |x_i - x_j|`.

    Args:
        scale: The parameter :math:`\ell`.
        sigma (optional): The parameter :math:`\sigma`. Defaults to a value of
            1. Specifying the explicit value here provides a slight performance
            boost compared to independently multiplying the kernel with a
            prefactor.
    """

    scale: JAXArray | float
    sigma: JAXArray | float = eqx.field(default_factory=lambda: jnp.ones(()))

    def design_matrix(self) -> JAXArray:
        freq = 2 * np.pi / self.scale
        return jnp.array([[0, -freq], [freq, 0]])

    def stationary_covariance(self) -> JAXArray:
        return jnp.eye(2)

    def observation_model(self, X: JAXArray) -> JAXArray:
        # The observation vector does not depend on the coordinate.
        del X
        return jnp.array([self.sigma, 0])

    def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        freq = 2 * np.pi / self.scale
        phase = freq * (X2 - X1)
        cos = jnp.cos(phase)
        sin = jnp.sin(phase)
        return jnp.array([[cos, sin], [-sin, cos]])
def _prod_helper(a1: JAXArray, a2: JAXArray) -> JAXArray:
i, j = np.meshgrid(np.arange(a1.shape[0]), np.arange(a2.shape[0]))
i = i.flatten()
j = j.flatten()
if a1.ndim == 1:
return a1[i] * a2[j]
elif a1.ndim == 2:
return a1[i[:, None], i[None, :]] * a2[j[:, None], j[None, :]]
else:
raise NotImplementedError
[docs]
class CARMA(Quasisep):
    r"""A continuous-time autoregressive moving average (CARMA) process kernel

    This process has the power spectrum density (PSD)

    .. math::

        P(\omega) = \sigma^2\,\frac{|\sum_{q} \beta_q\,(i\,\omega)^q|^2}{|\sum_{p}
            \alpha_p\,(i\,\omega)^p|^2}

    defined following Equation 1 in `Kelly et al. (2014)
    <https://arxiv.org/abs/1402.5978>`_, where :math:`\alpha_p` and :math:`\beta_0`
    are set to 1. In this implementation, we absorb :math:`\sigma` into the
    definition of :math:`\beta` parameters. That is :math:`\beta_{new}` =
    :math:`\beta * \sigma`.

    .. note::
        To construct a stationary CARMA kernel/process, the roots of the
        characteristic polynomials for Equation 1 in `Kelly et al. (2014)` must
        have negative real parts. This condition can be met automatically by
        requiring positive input parameters when instantiating the kernel using
        the :func:`init` method for CARMA(1,0), CARMA(2,0), and CARMA(2,1)
        models or by requiring positive input parameters when instantiating the
        kernel using the :func:`from_quads` method.

    .. note:: Implementation details

        The logic behind this implementation is simple---finding the correct
        combination of real/complex exponential kernels that resembles the
        autocovariance function of the CARMA model. Note that the order also
        matters. This task is achieved using the `acvf` method. Then the rest
        is copied from the `Exp` and `Celerite` kernel.

        Given the requirement of negative roots for stationarity, the
        `from_quads` method is implemented to facilitate constructing
        stationary higher-order CARMA models beyond CARMA(2,1). The inputs for
        `from_quads` are the coefficients of the quadratic equations factorized
        out of the full characteristic polynomial. `poly2quads` is used to
        factorize a polynomial into a product of said quadratic equations, and
        `quads2poly` is used for the reverse process.

        One last trick is the use of `_real_mask`, `_complex_mask`, and
        `_complex_select`, which are arrays of 0s and 1s. They are implemented
        to avoid control flows. More specifically, some intermediate quantities
        are computed regardless, but are only used if there is a matching real
        or complex exponential kernel for the specific CARMA kernel.

    Args:
        alpha: The parameter :math:`\alpha` in the definition above, excluding
            :math:`\alpha_p`. This should be an array of length `p`.
        beta: The product of parameters :math:`\beta` and parameter :math:`\sigma`
            in the definition above. This should be an array of length `q+1`,
            where `q+1 <= p`.
    """

    alpha: JAXArray
    beta: JAXArray
    sigma: JAXArray
    arroots: JAXArray
    acf: JAXArray
    _real_mask: JAXArray
    _complex_mask: JAXArray
    _complex_select: JAXArray
    obsmodel: JAXArray

    def __init__(self, alpha: Any, beta: Any):
        alpha = jnp.atleast_1d(jnp.asarray(alpha))
        beta = jnp.atleast_1d(jnp.asarray(beta))
        assert alpha.ndim == 1
        assert beta.ndim == 1
        assert beta.shape[0] <= alpha.shape[0]
        # sigma is absorbed into beta, so it is fixed to one here
        sigma = jnp.ones(())

        # Eqn. 4 in Kelly+14 gives the ACVF, i.e. the correct combination of
        # real/complex exponential kernels
        roots = carma_roots(jnp.append(alpha, 1.0))
        acvf = carma_acvf(roots, alpha, beta * sigma)

        # A root counts as real when its imaginary part is within a small
        # multiple of machine epsilon of zero
        imag_tol = 10 * jnp.finfo(roots.imag.dtype).eps
        real_mask = jnp.abs(roots.imag) < imag_tol
        complex_mask = ~real_mask
        # Running count of complex roots (zero at real roots); its parity
        # flags every other complex root
        complex_count = jnp.cumsum(complex_mask) * complex_mask
        complex_select = complex_mask * complex_count % 2

        # Observation model for the real exponential terms
        om_real = jnp.sqrt(jnp.abs(acvf.real))

        # Observation model for the complex terms, following ``Celerite``.
        # Denominators are replaced by one at real roots, where the result is
        # discarded anyway.
        a = 2 * acvf.real
        b = 2 * acvf.imag
        c = -roots.real
        d = -roots.imag
        d2 = jnp.square(d)
        s2 = jnp.square(c) + d2
        h2_sq = d2 * (a * c - b * d) / jnp.where(real_mask, 1.0, 2 * c * s2)
        h2 = jnp.sqrt(h2_sq)
        h1 = (c * h2 - jnp.sqrt(a * d2 - s2 * h2_sq)) / jnp.where(real_mask, 1.0, d)
        om_complex = jnp.array([h1, h2])

        # For complex roots, every conjugate pair matches one full celerite
        # term, so every other entry from om_complex is used (same logic as
        # for _complex_select).
        # NOTE(review): ``om_complex`` is stacked as [h1, h2], so its ravel
        # lists all h1 entries before all h2 entries and the stride-2 slice is
        # [h1[0], h1[2], ..., h2[...], ...]. Confirm this ordering is the
        # intended one for models with more than two roots.
        self.obsmodel = jnp.where(real_mask, om_real, jnp.ravel(om_complex)[::2])

        self.alpha = alpha
        self.beta = beta
        self.sigma = sigma
        self.arroots = roots
        self.acf = acvf
        self._real_mask = real_mask
        self._complex_mask = complex_mask
        self._complex_select = complex_select

    @classmethod
    def init(cls, alpha: JAXArray, beta: JAXArray) -> CARMA:
        """Construct a CARMA kernel; equivalent to ``CARMA(alpha, beta)``"""
        return cls(alpha, beta)

    @classmethod
    def from_quads(
        cls, alpha_quads: JAXArray, beta_quads: JAXArray, beta_mult: JAXArray
    ) -> CARMA:
        r"""Construct a CARMA kernel using the roots of its characteristic polynomials

        The roots can be parameterized as the 0th and 1st order coefficients of a set
        of quadratic equations (2nd order coefficient equals 1). The product of
        those quadratic equations gives the characteristic polynomials of CARMA.
        The input of this method are said coefficients of the quadratic equations.
        See Equation 30 in `Kelly et al. (2014) <https://arxiv.org/abs/1402.5978>`_.
        for more detail.

        Args:
            alpha_quads: Coefficients of the auto-regressive (AR) quadratic
                equations corresponding to the :math:`\alpha` parameters. This should
                be an array of length `p`.
            beta_quads: Coefficients of the moving-average (MA) quadratic
                equations corresponding to the :math:`\beta` parameters. This should
                be an array of length `q`.
            beta_mult: A multiplier of the MA coefficients, equivalent to
                :math:`\beta_q`---the last entry of the :math:`\beta` parameters input
                to the :func:`init` method.
        """
        alpha_quads = jnp.atleast_1d(alpha_quads)
        beta_quads = jnp.atleast_1d(beta_quads)
        beta_mult = jnp.atleast_1d(beta_mult)

        # The AR polynomial has unit leading coefficient, which is expanded
        # and then dropped again because ``alpha`` excludes it.
        ar_poly = carma_quads2poly(jnp.append(alpha_quads, jnp.array([1.0])))
        ma_poly = carma_quads2poly(jnp.append(beta_quads, beta_mult))
        return cls(ar_poly[:-1], ma_poly)

    def design_matrix(self) -> JAXArray:
        re = self.arroots.real
        im = self.arroots.imag
        # diagonal entries for the real and the complex exponential components
        real_part = jnp.diag(re * self._real_mask)
        complex_part = jnp.diag(re * self._complex_mask)
        # first super-diagonal, mirrored with opposite sign below the diagonal
        upper = jnp.diag((im * self._complex_select)[:-1], k=1)
        return real_part + complex_part - upper.T + upper

    def stationary_covariance(self) -> JAXArray:
        size = self.acf.shape[0]
        # for real exponential components: the sign of the ACVF coefficient
        signs = jnp.where(self.acf.real > 0, jnp.ones(size), -jnp.ones(size))
        base = jnp.diag(signs)

        # for complex exponential components; the denominator is one at real
        # roots so the division is always well defined
        denom = jnp.where(self._real_mask, 1.0, self.arroots.imag)
        c_over_d = self.arroots.real / denom
        shifted_select = jnp.roll(self._complex_select, 1)
        extra = jnp.diag(
            2 * jnp.square(c_over_d * shifted_select * self._complex_mask)
        )

        # upper triangular entries, mirrored below the diagonal
        upper = jnp.diag((-c_over_d * self._complex_select)[:-1], k=1)
        return base + extra + upper + upper.T

    def observation_model(self, X: JAXArray) -> JAXArray:
        # Precomputed in ``__init__``; independent of the coordinate.
        del X
        return self.obsmodel

    def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        lag = X2 - X1
        c = -self.arroots.real
        d = -self.arroots.imag
        envelope = jnp.exp(-c * lag)
        sin = jnp.sin(d * lag)
        cos = jnp.cos(d * lag)

        real_part = jnp.diag(envelope * self._real_mask)
        complex_part = jnp.diag(envelope * cos * self._complex_mask)
        upper = jnp.diag((envelope * sin * self._complex_select)[:-1], k=1)
        return real_part + complex_part - upper.T + upper
@jax.jit
def carma_roots(poly_coeffs: JAXArray) -> JAXArray:
    """Find the roots of a polynomial, ordered by increasing real part

    Args:
        poly_coeffs: Coefficients of the polynomial. The first entry
            corresponds to the lowest order term.

    Returns:
        The roots, sorted by their real parts.
    """
    # ``jnp.roots`` expects the highest order coefficient first
    highest_first = poly_coeffs[::-1]
    found = jnp.roots(highest_first, strip_zeros=False)
    order = jnp.argsort(found.real)
    return found[order]
@jax.jit
def carma_quads2poly(quads_coeffs: JAXArray) -> JAXArray:
    """Expand a product of quadratic equations into a polynomial

    Args:
        quads_coeffs: The 0th and 1st order coefficients of the quadratic
            equations. The last entry is a multiplier, which corresponds
            to the coefficient of the highest order term in the output full
            polynomial. When the number of remaining entries is odd, the
            second-to-last entry is the 0th order coefficient of one extra
            linear factor.

    Returns:
        Coefficients of the full polynomial. The first entry corresponds to
        the lowest order term.
    """
    # Everything derived from the shape is a plain Python int, so the
    # branching and looping below are resolved at trace time.
    size = quads_coeffs.shape[0] - 1
    has_linear = size % 2 == 1
    n_pairs = size // 2
    mult_f = quads_coeffs[-1:]  # The coeff of highest order term in the output

    # Coefficients are accumulated highest order first (what ``jnp.convolve``
    # needs) and reversed at the end.
    if has_linear:
        poly = jnp.array([1.0, quads_coeffs[-2]])
    else:
        poly = jnp.array([1.0])

    # ``range`` keeps the indices static under ``jit``; iterating over
    # ``jnp.arange`` would turn every lookup into a traced, dynamic index.
    for pair in range(n_pairs):
        quadratic = jnp.array(
            [1.0, quads_coeffs[2 * pair + 1], quads_coeffs[2 * pair]]
        )
        poly = jnp.convolve(poly, quadratic)

    # the returned is low->high following Kelly+14
    return poly[::-1] * mult_f
def carma_poly2quads(poly_coeffs: JAXArray) -> JAXArray:
    """Factorize a polynomial into a product of quadratic equations

    Roots are grouped into consecutive, non-overlapping pairs: first the
    roots with a non-zero imaginary part, then the real ones. This relies on
    ``carma_roots`` returning the two members of a conjugate pair next to each
    other, which holds when no other complex root shares their real part.

    Args:
        poly_coeffs: Coefficients of the input characteristic polynomial. The
            first entry corresponds to the lowest order term.

    Returns:
        The 0th and 1st order coefficients of the quadratic equations. The last
        entry is a multiplier, which corresponds to the coefficient of the highest
        order term in the full polynomial. For a polynomial of odd order, the
        entry before the multiplier is the 0th order coefficient of the
        remaining linear factor.
    """
    mult_f = poly_coeffs[-1]
    roots = carma_roots(poly_coeffs / mult_f)
    odd = bool(len(roots) & 0x1)

    rootsComp = roots[roots.imag != 0]
    rootsReal = roots[roots.imag == 0]

    quads = jnp.empty(0)
    for group in (rootsComp, rootsReal):
        for pair in range(len(group) // 2):
            # Step by two so each root is used in exactly one quadratic;
            # pairing ``i`` with ``i + 1`` would reuse roots and skip the last.
            root1 = group[2 * pair]
            root2 = group[2 * pair + 1]
            # (x - r1)(x - r2) = x^2 - (r1 + r2) x + r1 r2
            quads = jnp.append(quads, (root1 * root2).real)
            quads = jnp.append(quads, -(root1.real + root2.real))

    if odd:
        # The unpaired real root contributes the linear factor (x - r)
        quads = jnp.append(quads, -rootsReal[-1].real)

    return jnp.append(quads, jnp.array(mult_f))
def carma_acvf(arroots: JAXArray, arparam: JAXArray, maparam: JAXArray) -> JAXArray:
    r"""Compute the coefficients of the autocovariance function (ACVF)

    Args:
        arroots: The roots of the autoregressive characteristic polynomial.
        arparam: :math:`\alpha` parameters
        maparam: :math:`\beta` parameters

    Returns:
        ACVF coefficients, each entry corresponds to one root.
    """
    from jax._src import dtypes  # type: ignore

    arparam = jnp.atleast_1d(arparam)
    maparam = jnp.atleast_1d(maparam)
    n_roots = arparam.shape[0]
    ma_order = maparam.shape[0] - 1
    work_dtype = dtypes.to_complex_dtype(arparam.dtype)

    # Factor out the leading MA coefficient so that beta_0 is normalized to 1;
    # it is restored as ``scale**2`` in the final expression.
    scale = maparam[0]
    unit_ma = maparam / scale

    # MA polynomial evaluated at each root and at its negative
    ma_at_root = jnp.zeros(n_roots, dtype=work_dtype)
    ma_at_neg_root = jnp.zeros(n_roots, dtype=work_dtype)
    neg_roots = jnp.negative(arroots)
    for order in range(ma_order + 1):
        coeff = unit_ma[order]
        ma_at_root = ma_at_root + coeff * jnp.power(arroots, order)
        ma_at_neg_root = ma_at_neg_root + coeff * jnp.power(neg_roots, order)

    # Denominator: -2 Re(r_k) times a product over all the other roots.
    # Rolling the root array by 1..p-1 lines up every other root with r_k.
    denom = -2 * arroots.real + jnp.zeros_like(arroots) * 1j
    for shift in range(1, n_roots):
        other = jnp.roll(arroots, shift)
        denom = denom * ((other - arroots) * (jnp.conj(other) + arroots))

    return scale**2 * ma_at_root * ma_at_neg_root / denom