Kernel Transforms

try:
    import tinygp
except ImportError:
    %pip install -q tinygp

try:
    import flax
except ImportError:
    %pip install -q flax

try:
    import optax
except ImportError:
    %pip install -q optax

tinygp is designed to make it easy to implement new kernels (see Custom Kernels for an example), but one set of customizations that tinygp supports with a high-level interface is coordinate transforms. The basic idea is that you may want to pass your input coordinates through a linear or non-linear transformation before evaluating one of the standard kernels in that transformed space. This is particularly useful for multivariate inputs where, for example, you may want to account for the different units of, or the prior covariances between, the input dimensions.
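To make this concrete, here is a minimal sketch using tinygp.transforms.Linear, which rescales the input coordinates before handing them to a base kernel (the numerical scales below are arbitrary illustrative values):

import jax.numpy as jnp

from tinygp import kernels, transforms

# Rescaling a 1D input by 1/4.5 before an unscaled Matern-3/2 kernel is
# equivalent to a Matern-3/2 kernel with length scale 4.5
kernel0 = kernels.Matern32(4.5)
kernel1 = transforms.Linear(1 / 4.5, kernels.Matern32())

# For multivariate inputs, a vector of inverse length scales rescales each
# input dimension independently before the kernel is evaluated
kernel2d = transforms.Linear(jnp.array([1.0, 0.1]), kernels.ExpSquared())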

Example: Deep kernel learning

The Deep Kernel Learning model is an example of a more complicated kernel transform, and since tinygp integrates well with libraries like flax (see Modeling Frameworks), the implementation of such a model is fairly straightforward. To demonstrate, let's start by sampling a simulated dataset from a step function, a target that a GP with a standard stationary kernel would typically struggle to model:

import matplotlib.pyplot as plt
import numpy as np

random = np.random.default_rng(567)

noise = 0.1

x = np.sort(random.uniform(-1, 1, 100))
y = 2 * (x > 0) - 1 + random.normal(0.0, noise, len(x))
t = np.linspace(-1.5, 1.5, 500)

plt.plot(t, 2 * (t > 0) - 1, "k", lw=1, label="truth")
plt.plot(x, y, ".k", label="data")
plt.xlim(-1.5, 1.5)
plt.ylim(-1.3, 1.3)
plt.xlabel("x")
plt.ylabel("y")
_ = plt.legend()
[Figure: the simulated noisy step-function data with the true function overplotted]

Then we will fit this dataset using a model similar to the one described in Optimization with flax & optax, except that our kernel will include a custom tinygp.transforms.Transform that passes the input coordinates through a (small) neural network before handing them to a tinygp.kernels.Matern32 kernel. The rest of the model and the optimization procedure are unchanged.

We compare the performance of this deep Matern-3/2 kernel (a tinygp.kernels.Matern32 kernel with a custom neural network transform) to that of the same kernel without the transform. The untransformed model doesn't have the capacity to capture our simulated step function, but the transformed model does. In the transformed model, the kernel hyperparameters include the weights of the neural network, and we learn those simultaneously with the length scale and amplitude of the Matern32 kernel.
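Before diving into the full flax implementation, here is a minimal sketch of the tinygp.transforms.Transform mechanism, using a fixed (hypothetical) warp function in place of the learned neural network:

import jax.numpy as jnp

from tinygp import kernels, transforms

# Any callable mapping raw inputs to transformed inputs can be wrapped in a
# Transform; this fixed tanh warp is purely illustrative, and the full
# example below learns the warp with a neural network instead
warp = lambda X: jnp.tanh(3.0 * X)
kernel = transforms.Transform(warp, kernels.Matern32(1.0))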

import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from flax.linen.initializers import zeros

from tinygp import GaussianProcess, kernels, transforms
class Matern32Loss(nn.Module):
    @nn.compact
    def __call__(self, x, y, t):
        # Set up a typical Matern-3/2 kernel
        log_sigma = self.param("log_sigma", zeros, ())
        log_rho = self.param("log_rho", zeros, ())
        log_jitter = self.param("log_jitter", zeros, ())
        base_kernel = jnp.exp(2 * log_sigma) * kernels.Matern32(jnp.exp(log_rho))

        # Evaluate and return the GP negative log likelihood as usual
        gp = GaussianProcess(
            base_kernel, x[:, None], diag=noise**2 + jnp.exp(2 * log_jitter)
        )
        log_prob, gp_cond = gp.condition(y, t[:, None])
        return -log_prob, (gp_cond.loc, gp_cond.variance)
class Transformer(nn.Module):
    """A small neural network used to non-linearly transform the input data"""

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=15)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.relu(x)
        x = nn.Dense(features=1)(x)
        return x


class DeepLoss(nn.Module):
    @nn.compact
    def __call__(self, x, y, t):
        # Set up a typical Matern-3/2 kernel
        log_sigma = self.param("log_sigma", zeros, ())
        log_rho = self.param("log_rho", zeros, ())
        log_jitter = self.param("log_jitter", zeros, ())
        base_kernel = jnp.exp(2 * log_sigma) * kernels.Matern32(jnp.exp(log_rho))

        # Define a custom transform to pass the input coordinates through our
        # `Transformer` network from above
        # Note: with recent version of flax, you can't directly vmap modules,
        # but we can get around that by explicitly constructing the init and
        # apply functions. Ref:
        # https://flax.readthedocs.io/en/latest/advanced_topics/lift.html
        transform = Transformer()
        transform_params = self.param("transform", transform.init, x[:1])
        apply_fn = lambda x: transform.apply(transform_params, x)
        kernel = transforms.Transform(apply_fn, base_kernel)

        # Evaluate and return the GP negative log likelihood as usual with the
        # transformed features
        gp = GaussianProcess(
            kernel, x[:, None], diag=noise**2 + jnp.exp(2 * log_jitter)
        )
        log_prob, gp_cond = gp.condition(y, t[:, None])

        # We return the loss, the conditional mean and variance, and the
        # transformed input parameters
        return (
            -log_prob,
            (gp_cond.loc, gp_cond.variance),
            (apply_fn(x[:, None]), apply_fn(t[:, None])),
        )


# Define and train the model
def loss_func(model):
    def loss(params):
        return model.apply(params, x, y, t)[0]

    return loss


models_list, params_list = [], []
loss_vals = {}
# Plot the results and compare to the true model
fig, ax = plt.subplots(ncols=2, sharey=True, figsize=(9, 3))
for it, (model_name, model) in enumerate(
    zip(
        ["Deep", "Matern32"],
        [DeepLoss(), Matern32Loss()],
    )
):
    loss_vals[it] = []
    params = model.init(jax.random.PRNGKey(1234), x, y, t)
    tx = optax.sgd(learning_rate=1e-4)
    opt_state = tx.init(params)

    loss = loss_func(model)
    loss_grad_fn = jax.jit(jax.value_and_grad(loss))
    for i in range(1000):
        loss_val, grads = loss_grad_fn(params)
        updates, opt_state = tx.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        loss_vals[it].append(loss_val)

    mu, var = model.apply(params, x, y, t)[1]
    ax[it].plot(t, 2 * (t > 0) - 1, "k", lw=1, label="truth")
    ax[it].plot(x, y, ".k", label="data")
    ax[it].plot(t, mu)
    ax[it].fill_between(
        t, mu + np.sqrt(var), mu - np.sqrt(var), alpha=0.5, label="model"
    )
    ax[it].set_xlim(-1.5, 1.5)
    ax[it].set_ylim(-1.3, 1.3)
    ax[it].set_xlabel("x")
    ax[it].set_ylabel("y")
    ax[it].set_title(model_name)
    _ = ax[it].legend()

    models_list.append(model)
    params_list.append(params)
[Figure: conditional predictions from the Deep and Matern32 models compared to the truth and the data]

The untransformed Matern32 model over-smooths the discontinuity and extrapolates poorly, while the Deep model captures the discontinuity reliably and extrapolates well.
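Since each loss module returns the conditional mean and variance for whatever test grid it is given, we can, as a quick sketch, reuse the trained deep model to check its behavior on a wider (hypothetical) grid:

# Predict with the trained deep model on a wider grid than the one above
t_wide = np.linspace(-3, 3, 500)
mu_wide, var_wide = models_list[0].apply(params_list[0], x, y, t_wide)[1]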

We can compare the training loss (negative log likelihood) traces for these two models:

fig = plt.figure()
plt.plot(loss_vals[0], label="Deep")
plt.plot(loss_vals[1], label="Matern32")
plt.ylabel("Loss")
plt.xlabel("Training Iterations")
_ = plt.legend()
[Figure: training loss traces for the Deep and Matern32 models]

To inspect what the transformed model is doing under the hood, we can plot the functional form of the transformation, as well as the transformed values of our input coordinates:

x_transform, t_transform = models_list[0].apply(params_list[0], x, y, t)[2]

fig = plt.figure()
plt.plot(t, t_transform, "k")
plt.xlim(-1.5, 1.5)
plt.ylim(-1.3, 1.3)
plt.xlabel("input data; x")
plt.ylabel("transformed data; x'")

fig, ax = plt.subplots(ncols=2, sharey=True, figsize=(9, 3))
for it, (fig_title, feature_input, x_label) in enumerate(
    zip(["Input Data", "Transformed Data"], [x, x_transform], ["x", "x'"])
):
    ax[it].plot(feature_input, y, ".k")
    ax[it].set_xlim(-1.5, 1.5)
    ax[it].set_ylim(-1.3, 1.3)
    ax[it].set_title(fig_title)
    ax[it].set_xlabel(x_label)
    ax[it].set_ylabel("y")
[Figures: the learned input transformation x → x', and the data plotted against the input vs. transformed coordinates]

The neural network transforms the input coordinates into something resembling a step function (as shown in the figures above) before feeding them to the base kernel, which makes the transformed model much better suited to this dataset than the baseline.
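Finally, if you want to inspect the learned hyperparameters directly, they live in the flax parameter dictionary under the names used in the self.param calls above (a sketch; the nesting follows flax's standard "params" collection):

# Read off the trained kernel hyperparameters from the flax params
deep_params = params_list[0]["params"]
print("sigma:", jnp.exp(deep_params["log_sigma"]))
print("rho:", jnp.exp(deep_params["log_rho"]))
print("extra variance:", jnp.exp(2 * deep_params["log_jitter"]))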