try:
    import tinygp
except ImportError:
    !pip install -q tinygp
    
import jax

# Enable 64-bit floating point in JAX (it defaults to 32-bit). The flag is set
# through ``jax.config`` directly because the ``from jax.config import config``
# import path is not available in newer JAX releases.
jax.config.update("jax_enable_x64", True)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Tutorial: Custom Kernels & Pytree Data

One of the goals of the tinygp interface design was to make the kernel building framework flexible and easily extensible. In this tutorial, we demonstrate this interface with two examples: one simple, and one more complicated. Besides describing this interface, we also show how tinygp can support arbitrary JAX pytrees as input.

Example: Spectral mixture kernel

In this section, we will implement the “spectral mixture kernel” proposed by Gordon Wilson & Adams (2013). It would be possible to implement this using sums of built in kernels, but the interface seems better if we implement a custom kernel and I expect that we’d get somewhat better performance for mixtures with many components.

Now, let’s implement this kernel in a way that tinygp understands. When doing this, you will subclass tinygp.kernels.Kernel and implement the tinygp.kernels.Kernel.evaluate() method. One very important thing to note here is that evaluate will always be called via vmap, so you should write your evaluate method to operate on a single pair of inputs and let vmap handle the broadcasting semantics for you.

import tinygp
import jax
import jax.numpy as jnp


class SpectralMixture(tinygp.kernels.Kernel):
    """Spectral mixture kernel.

    For a lag ``tau = |X1 - X2|`` (one entry per input dimension) this
    evaluates::

        k(tau) = sum_q weight[q]
                 * prod_d exp(-2 pi^2 tau[d]^2 / scale[q]^2)
                          * cos(2 pi freq[q] tau[d])

    Args:
        weight: Weight of each mixture component.
        scale: Scale of each mixture component.
        freq: Frequency of each mixture component.

    Scalar arguments are promoted to length-1 arrays, so a single-component
    mixture can be built from plain floats.
    """

    def __init__(self, weight, scale, freq):
        self.weight = jnp.atleast_1d(weight)
        self.scale = jnp.atleast_1d(scale)
        self.freq = jnp.atleast_1d(freq)

    def evaluate(self, X1, X2):
        """Evaluate the kernel for a single pair of inputs.

        ``X1`` and ``X2`` are expected to be a single scalar or a single
        coordinate vector each (broadcasting over datasets is left to the
        caller).
        """
        # Trailing axis added so the lag broadcasts against the per-component
        # parameters: (n_dim, 1) against (n_component,).
        tau = jnp.atleast_1d(jnp.abs(X1 - X2))[..., None]

        # One factor per (input dimension, mixture component) pair.
        factors = jnp.exp(-2 * jnp.pi ** 2 * tau ** 2 / self.scale ** 2) * jnp.cos(
            2 * jnp.pi * self.freq * tau
        )

        # The product runs over the input dimensions (axis 0) and the weighted
        # sum over the components. Reducing over the last axis instead would
        # multiply the components together before the weights are applied.
        return jnp.sum(self.weight * jnp.prod(factors, axis=0))

Now let’s simulate some data from this model:

import numpy as np
import matplotlib.pyplot as plt


def build_gp(theta):
    """Construct the spectral-mixture Gaussian process for parameters ``theta``.

    ``theta`` is a dictionary with the keys ``"log_weight"``, ``"log_scale"``,
    ``"log_freq"``, ``"log_diag"`` and ``"mean"``. Every ``log_*`` entry is
    exponentiated before use. The process is defined on the module-level
    input coordinates ``t``.
    """
    weight = jnp.exp(theta["log_weight"])
    scale = jnp.exp(theta["log_scale"])
    freq = jnp.exp(theta["log_freq"])
    mixture = SpectralMixture(weight, scale, freq)

    diag = jnp.exp(theta["log_diag"])
    return tinygp.GaussianProcess(mixture, t, diag=diag, mean=theta["mean"])


# Parameters used to simulate the dataset. Entries prefixed with "log_" are
# stored in log space; build_gp exponentiates them.
params = {
    "log_weight": np.log([1.0, 1.0]),
    "log_scale": np.log([10.0, 20.0]),
    "log_freq": np.log([1.0, 1.0 / 3.0]),
    "log_diag": np.log(0.1),
    "mean": 0.0,
}

# 50 sorted input coordinates drawn uniformly from [0, 10) with a fixed seed.
random = np.random.default_rng(546)
t = np.sort(random.uniform(0, 10, 50))
# build_gp reads the module-level ``t`` assigned on the previous line.
true_gp = build_gp(params)
# Draw one realization from the process using a fixed JAX PRNG key.
y = true_gp.sample(jax.random.PRNGKey(123))

plt.plot(t, y, ".k")
plt.ylim(-4.5, 4.5)
plt.title("simulated data")
plt.xlabel("x")
# Assigning to ``_`` keeps the notebook from echoing the returned object.
_ = plt.ylabel("y")
../_images/kernels_4_0.png

One thing to note here is that we’ve used named parameters in a dictionary, instead of an array of parameters as in some of the other examples. This would be awkward (but not impossible) to fit using scipy, so instead we’ll use optax for optimization:

import optax


@jax.jit
@jax.value_and_grad
def loss(theta):
    """Negated ``condition(y)`` of the GP built from ``theta``.

    Because of the ``jax.value_and_grad`` decorator, calling ``loss(theta)``
    returns a ``(value, gradient)`` pair, where the gradient has the same
    dictionary structure as ``theta``.
    """
    gp = build_gp(theta)
    return -gp.condition(y)


# Plain stochastic-gradient-descent optimizer from optax with a fixed step size.
opt = optax.sgd(learning_rate=3e-4)
# Note: optimization starts from the current ``params``, i.e. the same values
# that were used to simulate the data.
opt_state = opt.init(params)
for i in range(1000):
    # ``loss`` returns both the objective value and its gradient.
    loss_val, grads = loss(params)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

# Rebuild the GP with the optimized parameters.
opt_gp = build_gp(params)
# Compare the true and inferred kernels as a function of lag: evaluate each
# kernel between the first grid point (tau = 0) and the whole grid, then take
# the single resulting row.
tau = np.linspace(0, 5, 500)
plt.plot(tau, true_gp.kernel(tau[:1], tau)[0], "--k", label="true kernel")
plt.plot(tau, opt_gp.kernel(tau[:1], tau)[0], label="inferred kernel")
plt.legend()
plt.xlabel(r"$\tau$")
plt.ylabel(r"$k(\tau)$")
_ = plt.xlim(tau.min(), tau.max())
../_images/kernels_6_0.png

Using our optimized model, over-plot the conditional predictions:

# Prediction grid; it extends beyond the [0, 10] range of the training inputs.
x = np.linspace(-2, 12, 500)
plt.plot(t, y, ".k", label="data")
# Predictive mean and variance at ``x``, conditioned on the observations ``y``.
mu, var = opt_gp.predict(y, x, return_var=True)
# Shade the band of one standard deviation around the predictive mean.
plt.fill_between(
    x,
    mu + np.sqrt(var),
    mu - np.sqrt(var),
    color="C0",
    alpha=0.5,
    label="conditional",
)
plt.plot(x, mu, color="C0", lw=2)
plt.xlim(x.min(), x.max())
plt.ylim(-4.5, 4.5)
plt.legend(loc=2)
plt.xlabel("x")
_ = plt.ylabel("y")
../_images/kernels_8_0.png

Example: Latent GPs & their derivatives

We will next build a custom kernel that is commonly used when studying radial velocity observations of exoplanet-hosting stars, as described by Rajpaul et al. (2015). Take a look at that paper for more details about the math, but the tl;dr is that we want to model a set of parallel, but qualitatively different, time series, using a latent Gaussian process. The interesting part of this model is that Rajpaul et al. model the observations as arbitrary linear combinations of the process and its first time derivative, and they work through the derivation of the resulting kernel function.

In this tutorial, we will implement this kernel using tinygp—something that is significantly more annoying to do with other Gaussian process frameworks (trust me, I’ve tried!)—and demonstrate a few key features along the way.

The kernel

The kernel matrix described by Rajpaul et al. (2015) is a block matrix where each element is a linear combination of the latent kernel and its first and second derivatives, where the relevant coefficients depend on the “class” of each pair of observations. This means that our input data needs to include, at each observation, the time t (our input coordinate in this case) and an integer class label, label. As discussed below, we will structure our data in such a way that we can treat each input as being a tuple (t, label). In this case, we will unpack our inputs t, label = X and treat t and label as scalars. Here’s our implementation:

class DerivativeKernel(tinygp.kernels.Kernel):
    """Kernel for linear combinations of a latent process and its derivative.

    Based on Rajpaul et al. (2015). Each input is a ``(t, label)`` pair, where
    ``label`` is an integer index selecting the coefficients for that
    observation's class.

    Args:
        kernel: The kernel function describing the latent process. This can be
            any other ``tinygp`` kernel.
        coeff_prim: Per-class coefficient multiplying the latent process
            itself; one entry for each class of observation.
        coeff_deriv: Per-class coefficient multiplying the time derivative of
            the latent process. This is broadcast against ``coeff_prim``.
    """

    def __init__(self, kernel, coeff_prim, coeff_deriv):
        prim = jnp.asarray(coeff_prim)
        deriv = jnp.asarray(coeff_deriv)
        self.kernel = kernel
        self.coeff_prim, self.coeff_deriv = jnp.broadcast_arrays(prim, deriv)

    def evaluate(self, X1, X2):
        """Evaluate the kernel for one pair of ``(t, label)`` inputs."""
        (t1, label1), (t2, label2) = X1, X2
        latent = self.kernel.evaluate

        # Derivatives of the latent kernel with respect to each argument.
        # For stationary kernels the two first derivatives differ only by a
        # sign, but both are computed explicitly to stay general.
        d_dx1 = jax.grad(latent, argnums=0)
        d_dx2 = jax.grad(latent, argnums=1)
        # Mixed second derivative: differentiate d/dx1 with respect to x2.
        d2_dx1dx2 = jax.grad(d_dx1, argnums=1)

        # Per-class coefficients for the process (a) and its derivative (b).
        a1, b1 = self.coeff_prim[label1], self.coeff_deriv[label1]
        a2, b2 = self.coeff_prim[label2], self.coeff_deriv[label2]

        # The four contributions to the matrix element.
        prim_prim = a1 * a2 * latent(t1, t2)
        prim_deriv = a1 * b2 * d_dx2(t1, t2)
        deriv_prim = b1 * a2 * d_dx1(t1, t2)
        deriv_deriv = b1 * b2 * d2_dx1dx2(t1, t2)
        return prim_prim + prim_deriv + deriv_prim + deriv_deriv

Now that we have this definition, we can plot what the kernel functions look like for different latent processes. Don’t worry too much about the syntax here, but we’re plotting two classes of observations where the first class is just a direct observation of the latent process and the second observes the time derivative.

import numpy as np
import matplotlib.pyplot as plt


def plot_kernel(latent_kernel):
    """Plot the four class-pair slices of a ``DerivativeKernel``.

    The kernel is built with two classes: class 0 has coefficients
    ``(prim, deriv) = (1, 0)`` and class 1 has ``(0, 1)``. For every pair of
    class labels, the kernel is evaluated between a single point at ``t = 0``
    and a grid of time lags, and the resulting curve is drawn on a new figure.

    Args:
        latent_kernel: The kernel of the latent process, passed to
            ``DerivativeKernel``.
    """
    kernel = DerivativeKernel(latent_kernel, [1.0, 0.0], [0.0, 1.0])

    N = 500
    dt = np.linspace(-7.5, 7.5, N)

    def cross_section(label1, label2):
        # Single row of the kernel matrix: the point t = 0 with class
        # ``label1`` against the whole ``dt`` grid with class ``label2``.
        origin = (jnp.zeros(1), jnp.full(1, label1, dtype=int))
        grid = (dt, np.full(N, label2, dtype=int))
        return kernel(origin, grid)[0]

    plt.figure()
    for label1, label2 in [(0, 0), (0, 1), (1, 0), (1, 1)]:
        plt.plot(
            dt,
            cross_section(label1, label2),
            # Raw f-string: "\D" is an invalid escape sequence in a regular
            # string literal. Doubled braces emit literal "{" and "}".
            label=rf"$k[\Delta t^{{({label1},{label2})}}]$",
            lw=1,
        )
    plt.legend()
    plt.xlabel(r"$\Delta t$")
    plt.xlim(dt.min(), dt.max())


# First example: a Matern-5/2 latent kernel.
plot_kernel(tinygp.kernels.Matern52(scale=1.5))
plt.title("Matern-5/2")

# Second example: product of an ExpSquared and an ExpSineSquared kernel.
plot_kernel(
    tinygp.kernels.ExpSquared(scale=2.5)
    * tinygp.kernels.ExpSineSquared(period=2.5, gamma=0.5)
)
_ = plt.title("Quasiperiodic")
../_images/kernels_12_0.png ../_images/kernels_12_1.png

Inference

Given this custom kernel definition, let’s simulate some data and fit for the kernel parameters. In this case, we’ll use the product of a kernels.ExpSquared and a kernels.ExpSineSquared kernel for our latent process. Note that we’re not including an amplitude parameter in our latent process, since that will be captured by the coefficients in our DerivativeKernel defined above. Using this latent process, we simulate an unbalanced dataset with two classes and plot the simulated data below, with class offsets for clarity.

# Latent process: product of an ExpSquared and an ExpSineSquared kernel.
latent_kernel = tinygp.kernels.ExpSquared(
    scale=1.5
) * tinygp.kernels.ExpSineSquared(period=2.5, gamma=0.5)
# Class 0 uses coefficients (prim, deriv) = (1.0, -0.1); class 1 uses (0.5, 0.3).
kernel = DerivativeKernel(latent_kernel, [1.0, 0.5], [-0.1, 0.3])

# Dense grid for the simulated "truth": 200 times labelled class 0 followed by
# 300 times labelled class 1.
random = np.random.default_rng(5678)
t1 = np.sort(random.uniform(0, 10, 200))
label1 = np.zeros_like(t1, dtype=int)
t2 = np.sort(random.uniform(0, 10, 300))
label2 = np.ones_like(t2, dtype=int)
# The GP input is a tuple (times, labels) rather than a single array.
X = (np.append(t1, t2), np.append(label1, label2))

gp = tinygp.GaussianProcess(kernel, X, diag=1e-5)
y = gp.sample(jax.random.PRNGKey(1234))

# Pick the "observed" points: 50 indices into the class-0 block and 15 into the
# class-1 block (offset by len(t1)). Indices are drawn independently, so the
# same point can be selected more than once.
subset = np.append(
    random.integers(len(t1), size=50),
    len(t1) + random.integers(len(t2), size=15),
)
X_obs = (X[0][subset], X[1][subset])
# Add Gaussian noise with standard deviation 0.1 to the selected samples.
y_obs = y[subset] + 0.1 * random.normal(size=len(subset))

# Vertical separation between the two classes in the plot.
offset = 2.5

plt.axhline(0.5 * offset, color="k", lw=1)
plt.axhline(-0.5 * offset, color="k", lw=1)

# The first len(t1) samples belong to class 0 and the rest to class 1.
plt.plot(t1, y[: len(t1)] + 0.5 * offset, label="class 0")
plt.plot(t2, y[len(t1) :] - 0.5 * offset, label="class 1")

# ``offset * (0.5 - label)`` shifts class 0 up and class 1 down by offset / 2.
plt.plot(X_obs[0], y_obs + offset * (0.5 - X_obs[1]), ".k", label="measured")

plt.xlim(0, 10)
plt.ylim(-offset, offset)
plt.xlabel("t")
plt.ylabel("y + offset")
_ = plt.legend(bbox_to_anchor=(1.01, 1), loc="upper left")
../_images/kernels_14_0.png

Then, we fit the simulated data, by optimizing for the maximum likelihood kernel parameters using scipy:

from scipy.optimize import minimize


def build_gp(params):
    """Construct the derivative-kernel GP on ``X_obs`` from a flat parameter array.

    Layout of ``params``:
        0: log of the ExpSquared scale
        1: log of the ExpSineSquared period
        2: ExpSineSquared gamma (used as-is)
        3-4: ``coeff_prim`` for the two classes
        5-6: ``coeff_deriv`` for the two classes
        7: log of the ``diag`` term
    """
    smooth = tinygp.kernels.ExpSquared(scale=jnp.exp(params[0]))
    periodic = tinygp.kernels.ExpSineSquared(
        period=jnp.exp(params[1]), gamma=params[2]
    )
    kernel = DerivativeKernel(smooth * periodic, params[3:5], params[5:7])
    diag = jnp.exp(params[7])
    return tinygp.GaussianProcess(kernel, X_obs, diag=diag)


@jax.jit
@jax.value_and_grad
def loss(params):
    """Negated ``condition(y_obs)`` of the GP built from ``params``.

    Because of the ``jax.value_and_grad`` decorator, calling ``loss(params)``
    returns a ``(value, gradient)`` pair.
    """
    return -build_gp(params).condition(y_obs)


# Initial guess, in the order that build_gp indexes it: log scale, log period,
# gamma, the two coeff_prim values, the two coeff_deriv values, log diag.
init = jnp.array(
    [np.log(1.5), np.log(2.5), 0.5, 1.0, 0.5, -0.1, 0.3, np.log(0.1)]
)
# ``loss`` returns (value, gradient), so index 0 is the objective value.
print(f"Initial negative log likelihood: {loss(init)[0]}")
# jac=True tells scipy that ``loss`` returns the gradient alongside the value.
soln = minimize(loss, init, jac=True)
print(f"Final negative log likelihood: {soln.fun}")
Initial negative log likelihood: 18.547339999959853
Final negative log likelihood: -20.014003748544688

And plot the resulting inference. Of particular note here, even for the sparsely sampled “class 1” dataset, we get robust predictions for the expected process, since we have more finely sampled observations in “class 0”.

# Rebuild the GP at the optimized parameters and predict on the dense grid X,
# conditioned on the noisy observations.
gp = build_gp(soln.x)
mu, var = gp.predict(y_obs, X, return_var=True)

plt.axhline(0.5 * offset, color="k", lw=1)
plt.axhline(-0.5 * offset, color="k", lw=1)

# Simulated truth for both classes, offset vertically as in the data plot.
plt.plot(t1, y[: len(t1)] + 0.5 * offset, "k", label="truth")
plt.plot(t2, y[len(t1) :] - 0.5 * offset, "k")

for c in [0, 1]:
    # Vertical shift for this class and the mask selecting its grid points.
    delta = offset * (0.5 - c)
    m = X[1] == c
    # The predictive-mean line is left disabled; only the band is drawn:
    # plt.plot(X[0][m], delta + mu[m], color=f"C{c}")
    # Band of two standard deviations around the predictive mean.
    plt.fill_between(
        X[0][m],
        delta + mu[m] + 2 * np.sqrt(var[m]),
        delta + mu[m] - 2 * np.sqrt(var[m]),
        color=f"C{c}",
        alpha=0.5,
        label=f"inferred class {c}",
    )

plt.plot(X_obs[0], y_obs + offset * (0.5 - X_obs[1]), ".k", label="measured")

plt.xlim(0, 10)
plt.ylim(-offset, offset)
plt.xlabel("t")
plt.ylabel("y + offset")
_ = plt.legend(bbox_to_anchor=(1.01, 1), loc="upper left")
../_images/kernels_18_0.png