Tutorial: Kernel Transforms

try:
    import tinygp
except ImportError:
    %pip install -q tinygp

try:
    import flax
except ImportError:
    %pip install -q flax
    
try:
    import optax
except ImportError:
    %pip install -q optax

tinygp is designed to make it easy to implement new kernels (see Tutorial: Custom Kernels & Pytree Data for an example), but one particular set of customizations that tinygp supports with a high-level interface is coordinate transforms. The basic idea is that you may want to pass your input coordinates through a linear or non-linear transformation before evaluating one of the standard kernels in that transformed space. This is particularly useful for multivariate inputs where, for example, you may want to account for the different units of, or prior covariances between, the input dimensions.
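As a concrete illustration, here is a minimal sketch that uses tinygp.transforms.Transform to divide each dimension of a 2-D input by its own length scale before evaluating a standard ExpSquared kernel in the rescaled space (the specific length scales are made up for illustration); tinygp also provides helpers such as transforms.Linear and transforms.Subspace for common linear transformations:

import jax.numpy as jnp
from tinygp import kernels, transforms

# Illustrative per-dimension length scales (arbitrary values for this sketch)
length_scales = jnp.array([2.5, 0.3])

# Rescale each input dimension before evaluating the base ExpSquared kernel
kernel = transforms.Transform(
    lambda X: X / length_scales, kernels.ExpSquared()
)

# The transformed kernel is evaluated on the rescaled coordinates;
# here we compute the cross-covariance between two sets of 2-D inputs
value = kernel(jnp.zeros((1, 2)), jnp.ones((1, 2)))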

Example: Deep kernel learning

The Deep Kernel Learning model is an example of a more complicated kernel transform, and since tinygp integrates well with libraries like flax (see Tutorial: Modeling Frameworks), implementing such a model is fairly straightforward. To demonstrate, let's start by sampling a simulated dataset from a step function, the kind of target that a GP with a standard kernel would typically struggle to model:

import numpy as np
import matplotlib.pyplot as plt

random = np.random.default_rng(567)

noise = 0.1

x = np.sort(random.uniform(-1, 1, 100))
y = 2 * (x > 0) - 1 + random.normal(0.0, noise, len(x))
t = np.linspace(-1.5, 1.5, 500)

plt.plot(t, 2 * (t > 0) - 1, "k", lw=1, label="truth")
plt.plot(x, y, ".k", label="data")
plt.xlim(-1.5, 1.5)
plt.ylim(-1.3, 1.3)
plt.xlabel("x")
plt.ylabel("y")
_ = plt.legend()
[Figure: simulated step-function data (points) and the true step function (line)]

Then we will fit these data using a model similar to the one described in Optimization with flax & optax, except that our kernel will include a custom tinygp.transforms.Transform that passes the input coordinates through a (small) neural network before they are fed to a tinygp.kernels.Matern32 kernel. Otherwise, the model and optimization procedure are the same as the ones used in Optimization with flax & optax.

import jax
import optax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import zeros
from tinygp import kernels, transforms, GaussianProcess


# Define a small neural network used to non-linearly transform the input data in our model
class Transformer(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=15)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.relu(x)
        x = nn.Dense(features=1)(x)
        return x


class GPLoss(nn.Module):
    @nn.compact
    def __call__(self, x, y, t):
        # Set up a typical Matern-3/2 kernel
        log_sigma = self.param("log_sigma", zeros, ())
        log_rho = self.param("log_rho", zeros, ())
        log_jitter = self.param("log_jitter", zeros, ())
        base_kernel = jnp.exp(2 * log_sigma) * kernels.Matern32(
            jnp.exp(log_rho)
        )

        # Define a custom transform to pass the input coordinates through our `Transformer`
        # network from above
        transform = Transformer()
        kernel = transforms.Transform(transform, base_kernel)

        # Evaluate and return the GP negative log likelihood as usual, together
        # with the conditional mean and variance at the test points
        gp = GaussianProcess(
            kernel, x[:, None], diag=noise ** 2 + jnp.exp(2 * log_jitter)
        )
        return -gp.condition(y), gp.predict(y, t[:, None], return_var=True)


# Define and train the model
def loss(params):
    return model.apply(params, x, y, t)[0]


model = GPLoss()
params = model.init(jax.random.PRNGKey(1234), x, y, t)
tx = optax.sgd(learning_rate=1e-4)
opt_state = tx.init(params)
loss_grad_fn = jax.jit(jax.value_and_grad(loss))
# Run 1000 steps of gradient descent on the negative log likelihood
for i in range(1000):
    loss_val, grads = loss_grad_fn(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

# Plot the results and compare to the true model
plt.figure()
mu, var = model.apply(params, x, y, t)[1]
plt.plot(t, 2 * (t > 0) - 1, "k", lw=1, label="truth")
plt.plot(x, y, ".k", label="data")
plt.plot(t, mu)
plt.fill_between(
    t, mu + np.sqrt(var), mu - np.sqrt(var), alpha=0.5, label="model"
)
plt.xlim(-1.5, 1.5)
plt.ylim(-1.3, 1.3)
plt.xlabel("x")
plt.ylabel("y")
_ = plt.legend()
[Figure: GP predictive mean and one-sigma band compared to the data and the true step function]