Mixture of Kernels

try:
    import tinygp
except ImportError:
    %pip install -q tinygp

try:
    import jaxopt
except ImportError:
    %pip install -q jaxopt

It can be useful to model a dataset using a mixture of GPs. For example, the data might contain both systematic effects and a physical signal, each of which can be modeled using a GP. I know of a few examples where this method has been used in time series analysis for the discovery of transiting exoplanets (for example, Aigrain et al. 2016 and Luger et al. 2016), but I’m sure that these aren’t the earliest references. The idea is simple. If your model is the sum of two independent GPs (with covariance matrices \(K_1\) and \(K_2\), respectively), this is equivalent to a single GP whose kernel is the sum of two kernels, one for each component (\(K = K_1 + K_2\)). In this case, assuming a zero mean function, the predictive mean at the observed inputs conditioned on a dataset \(\boldsymbol{y}\) is

\[ \boldsymbol{\mu} = (K_1 + K_2)\,(K_1 + K_2 + N)^{-1} \, \boldsymbol{y} \]

where \(N\) is the (often diagonal) matrix describing the measurement uncertainties. Because the two components are independent, the covariance between component 1 and the observations is just \(K_1\). The predictive mean for component 1 is therefore

\[ \boldsymbol{\mu}_1 = K_1\,(K_1 + K_2 + N)^{-1} \, \boldsymbol{y} \]

and the analogous expression, with \(K_2\) in place of \(K_1\), gives the predictive mean for component 2.
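
The sketch below is a quick numerical check of these expressions in plain NumPy. It uses two arbitrary covariance matrices: a squared-exponential kernel and a periodic kernel. Their hyperparameters are hypothetical and are not the ones used later. The sketch computes \(\boldsymbol{\mu}\), \(\boldsymbol{\mu}_1\) and \(\boldsymbol{\mu}_2\) from a single linear solve. Because both component means share the factor \((K_1 + K_2 + N)^{-1}\,\boldsymbol{y}\), that factor only needs to be computed once. The component means then add up to the full predictive mean.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 50
t = np.sort(rng.uniform(0, 10, n))
dt = t[:, None] - t[None, :]

# Hypothetical component kernels (not the ones used later in this tutorial):
# a smooth squared-exponential trend and a periodic signal.
K1 = 2.0 * np.exp(-0.5 * (dt / 3.0) ** 2)
K2 = 1.0 * np.exp(-10.0 * np.sin(np.pi * dt / 2.0) ** 2)
N = 0.25 * np.eye(n)  # diagonal measurement-noise covariance

y = rng.normal(size=n)  # any data vector works for checking the algebra

# Shared factor alpha = (K1 + K2 + N)^{-1} y, computed once
alpha = np.linalg.solve(K1 + K2 + N, y)

mu = (K1 + K2) @ alpha  # predictive mean of the full model
mu1 = K1 @ alpha        # predictive mean of component 1
mu2 = K2 @ alpha        # predictive mean of component 2

print(np.allclose(mu1 + mu2, mu))  # True
```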

This can be implemented in tinygp using the kernel keyword argument of the condition method. That argument sets the kernel used for the covariance between the observed data and the prediction, while the solve still uses the full covariance matrix. To demonstrate this, let’s start by generating a synthetic dataset. Component 1 is a systematic signal that depends on two input parameters (\(t\) and \(\theta\), following Aigrain et al. 2016). Component 2 is a quasiperiodic oscillation in \(t\) alone and is the target of our analysis.

import jax
import jax.numpy as jnp

import numpy as np
import matplotlib.pyplot as plt

from tinygp import GaussianProcess, kernels, transforms

jax.config.update("jax_enable_x64", True)

random = np.random.default_rng(123)
N = 256
t = np.sort(random.uniform(0, 10, N))
theta = random.uniform(-np.pi, np.pi, N)
X = np.vstack((t, theta)).T


def build_gp(params):
    kernel1 = jnp.exp(params["log_amp1"]) * transforms.Linear(
        jnp.exp(params["log_scale1"]), kernels.Matern32()
    )
    kernel2 = jnp.exp(params["log_amp2"]) * transforms.Subspace(
        0,
        kernels.ExpSquared(jnp.exp(params["log_scale2"]))
        * kernels.ExpSineSquared(
            scale=jnp.exp(params["log_period"]),
            gamma=jnp.exp(params["log_gamma"]),
        ),
    )
    kernel = kernel1 + kernel2
    return GaussianProcess(kernel, X, diag=jnp.exp(params["log_diag"]))


true_params = {
    "log_amp1": np.log(2.0),
    "log_scale1": np.log([2.0, 0.8]),
    "log_amp2": np.log(2.0),
    "log_scale2": np.log(3.5),
    "log_period": np.log(2.0),
    "log_gamma": np.log(10.0),
    "log_diag": np.log(0.5),
}
gp = build_gp(true_params)
y = gp.sample(jax.random.PRNGKey(5678))

plt.plot(t, y, ".k")
plt.ylim(-6.5, 6.5)
plt.xlim(0, 10)
plt.xlabel("t")
plt.ylabel("y");
[Figure: the simulated dataset, y plotted against t.]
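
To make the structure of build_gp more concrete, the sketch below writes the two component kernels out in plain NumPy. It is an illustration of the model, not a drop-in replacement for tinygp, and the function names are hypothetical. It assumes tinygp's conventions as we understand them:

- transforms.Linear with a 1-D scale multiplies each input column by the matching entry.
- Matern32 is \((1 + \sqrt{3}\,r)\,e^{-\sqrt{3}\,r}\) in the transformed coordinates.
- ExpSquared(\(\ell\)) is \(\exp(-r^2 / 2\ell^2)\).
- ExpSineSquared(scale=\(P\), gamma=\(\Gamma\)) is \(\exp(-\Gamma \sin^2(\pi r / P))\).
- transforms.Subspace(0, ...) evaluates its kernel on column 0 (\(t\)) only.

```python
import numpy as np


def matern32_linear(X1, X2, amp, scale):
    # amp * Linear(scale, Matern32()), assuming Linear multiplies each input
    # column by the matching entry of a 1-D scale before taking distances.
    diff = (X1[:, None, :] - X2[None, :, :]) * scale
    r = np.sqrt(np.sum(diff**2, axis=-1))
    arg = np.sqrt(3.0) * r
    return amp * (1.0 + arg) * np.exp(-arg)


def quasiperiodic_on_t(X1, X2, amp, ell, period, gamma):
    # amp * Subspace(0, ExpSquared(ell) * ExpSineSquared(scale=period, gamma=gamma)):
    # only column 0 (t) enters; theta is ignored.
    r = np.abs(X1[:, None, 0] - X2[None, :, 0])
    envelope = np.exp(-0.5 * (r / ell) ** 2)
    periodic = np.exp(-gamma * np.sin(np.pi * r / period) ** 2)
    return amp * envelope * periodic


# A few hypothetical (t, theta) input points
X = np.array([[0.0, 0.0], [0.5, 1.0], [1.0, -1.0]])
scale = np.array([2.0, 0.8])

# Same values as true_params, with the logs undone
K1 = matern32_linear(X, X, amp=2.0, scale=scale)
K2 = quasiperiodic_on_t(X, X, amp=2.0, ell=3.5, period=2.0, gamma=10.0)
K = K1 + K2  # the summed kernel of the full model

# Stretch theta: component 2 is unchanged, component 1 is not
X_theta = X.copy()
X_theta[:, 1] *= 2.0
print(np.allclose(quasiperiodic_on_t(X_theta, X_theta, 2.0, 3.5, 2.0, 10.0), K2))  # True
print(np.allclose(matern32_linear(X_theta, X_theta, 2.0, scale), K1))  # False
```

The check at the end shows the design choice: only component 1 can absorb variations that depend on \(\theta\), which is what lets the model separate the systematics from the oscillation.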

The physical (oscillatory) component is not obvious in this dataset because it is swamped by the systematics. Next, we find the maximum likelihood hyperparameters by numerically minimizing the negative log-likelihood, starting the optimizer from the true parameter values.

import jaxopt


@jax.jit
def loss(params):
    return -build_gp(params).log_probability(y)


solver = jaxopt.ScipyMinimize(fun=loss)
soln = solver.run(true_params)
print("Maximum likelihood parameters:")
print(soln.params)
Maximum likelihood parameters:
{'log_amp1': Array(0.60148188, dtype=float64), 'log_amp2': Array(0.93318559, dtype=float64), 'log_diag': Array(-0.88629962, dtype=float64), 'log_gamma': Array(2.22236561, dtype=float64), 'log_period': Array(0.70370267, dtype=float64), 'log_scale1': Array([ 0.97831013, -0.31883729], dtype=float64), 'log_scale2': Array(1.42164303, dtype=float64)}
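
The loss minimized above is the negative log marginal likelihood of the data under the summed kernel. For a zero-mean GP with covariance \(C = K_1 + K_2 + N\), it is \(\frac{1}{2}\boldsymbol{y}^\top C^{-1}\boldsymbol{y} + \frac{1}{2}\log|C| + \frac{n}{2}\log 2\pi\). The sketch below evaluates that formula in NumPy through a Cholesky factorization, the usual stable way to compute it. The kernels and data are hypothetical stand-ins, and this is not tinygp's log_probability reproduced line for line.

```python
import numpy as np


def gp_negative_log_likelihood(C, y):
    """-log N(y | 0, C), evaluated with a Cholesky factorization."""
    L = np.linalg.cholesky(C)                  # C = L @ L.T
    z = np.linalg.solve(L, y)                  # z = L^{-1} y, so z @ z = y^T C^{-1} y
    log_det = 2.0 * np.sum(np.log(np.diag(L))) # log|C| from the Cholesky diagonal
    n = y.shape[0]
    return 0.5 * (z @ z + log_det + n * np.log(2.0 * np.pi))


# Hypothetical example: summed kernel plus diagonal noise
rng = np.random.default_rng(0)
t = np.sort(rng.uniform(0, 10, 30))
dt = t[:, None] - t[None, :]
K1 = 2.0 * np.exp(-0.5 * (dt / 3.0) ** 2)
K2 = np.exp(-10.0 * np.sin(np.pi * dt / 2.0) ** 2)
C = K1 + K2 + 0.25 * np.eye(t.size)
y = rng.multivariate_normal(np.zeros(t.size), C)

nll = gp_negative_log_likelihood(C, y)

# Cross-check against the direct (less stable) formula
sign, logdet = np.linalg.slogdet(C)
direct = 0.5 * (y @ np.linalg.solve(C, y) + logdet + t.size * np.log(2 * np.pi))
print(np.isclose(nll, direct))  # True
```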

Now let’s use the result from above to compute the predictions for both components, then subtract component 1 from the data to reveal the oscillatory signal.

# Compute the predictive means - note the "kernel" argument
gp = build_gp(soln.params)
mu1 = gp.condition(y, kernel=gp.kernel.kernel1).gp.loc
mu2 = gp.condition(y, kernel=gp.kernel.kernel2).gp.loc

plt.plot(t, y, ".k", mec="none", alpha=0.3)
plt.plot(t, y - mu1, ".k")
plt.plot(t, mu2)

plt.ylim(-6.5, 6.5)
plt.xlim(0, 10)
plt.xlabel("t")
plt.ylabel("y");
[Figure: de-trended data and the component 2 prediction, y plotted against t.]

In this plot, the original dataset is shown as light gray points, and the “de-trended” data, with component 1 removed, is shown as black points. The GP prediction for component 2 is shown as a blue line.
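
The black points scatter about the blue line because what remains after removing both components is an estimate of the measurement noise. Since \(\boldsymbol{\mu}_1 + \boldsymbol{\mu}_2 = (K_1 + K_2)\,(K_1 + K_2 + N)^{-1}\,\boldsymbol{y}\), the residual is \(\boldsymbol{y} - \boldsymbol{\mu}_1 - \boldsymbol{\mu}_2 = N\,(K_1 + K_2 + N)^{-1}\,\boldsymbol{y}\). The NumPy sketch below checks this identity with hypothetical kernels and simulated data, not the fitted model above.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 40
t = np.sort(rng.uniform(0, 10, n))
dt = t[:, None] - t[None, :]

# Hypothetical component kernels and a diagonal noise matrix
K1 = 2.0 * np.exp(-0.5 * (dt / 3.0) ** 2)              # smooth "systematics"
K2 = np.exp(-0.5 * (dt / 3.5) ** 2) * np.exp(
    -10.0 * np.sin(np.pi * dt / 2.0) ** 2
)                                                      # quasiperiodic signal
N = np.diag(rng.uniform(0.1, 0.5, n))                  # per-point noise variances
y = rng.multivariate_normal(np.zeros(n), K1 + K2 + N)

alpha = np.linalg.solve(K1 + K2 + N, y)
mu1, mu2 = K1 @ alpha, K2 @ alpha

detrended = y - mu1          # analogous to the black points
residual = detrended - mu2   # scatter of the de-trended data about mu2
print(np.allclose(residual, N @ alpha))  # True
```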