{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "0",
"metadata": {
"tags": [
"hide-cell"
]
},
"outputs": [],
"source": [
"try:\n",
"    import tinygp\n",
"except ImportError:\n",
"    %pip install -q tinygp\n",
"\n",
"try:\n",
"    import jaxopt\n",
"except ImportError:\n",
"    %pip install -q jaxopt"
]
},
{
"cell_type": "markdown",
"id": "1",
"metadata": {},
"source": [
"(quasisep)=\n",
"\n",
"# Scalable GPs with Quasiseparable Kernels\n",
"\n",
"````{admonition} Warning\n",
":class: warning\n",
"\n",
"The algorithms described in this section are inherently serial, and you will probably see extremely degraded performance if you turn on GPU acceleration.\n",
"````\n",
"\n",
"\n",
"Starting with `v0.2`, `tinygp` includes an experimental pure-`jax` implementation of the algorithms behind the [celerite package](https://celerite.readthedocs.io).\n",
"The [celerite2 package](https://celerite2.readthedocs.io) already had support for `jax`, but since the implementation here in `tinygp` doesn't depend on any extra compiled code, it might be a little easier to get up and running, and it is significantly more flexible.\n",
"Even though it is implemented directly in `jax` rather than in highly optimized C++ code, the `tinygp` implementation has similar performance to the `celerite2` version (see {ref}`benchmarks`).\n",
"\n",
"All this being said, this performance doesn't come for free.\n",
"In particular, this solver can only be used with data whose inputs can be sorted, and with specific types of kernels.\n",
"In practice, this generally means that you'll need 1-D input data (e.g. a time series) and you'll need to build your kernel using the members of the {ref}`api-kernels-quasisep`.\n",
"But, if your problem has this form, you may see several orders of magnitude improvement in the runtime of your model.\n",
"\n",
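"For data that do not arrive in order, a minimal sketch of the preparation step is shown below, assuming the solver expects the inputs in increasing order. The arrays `t_raw`, `y_raw`, and `yerr_raw` are hypothetical stand-ins for unsorted observations.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical unsorted observations (illustrative values only)\n",
"rng = np.random.default_rng(0)\n",
"t_raw = rng.uniform(0, 10, 20)\n",
"y_raw = np.sin(t_raw)\n",
"yerr_raw = np.full_like(t_raw, 0.1)\n",
"\n",
"# Compute one permutation from the inputs and apply it to every array\n",
"order = np.argsort(t_raw)\n",
"t_sorted = t_raw[order]\n",
"y_sorted = y_raw[order]\n",
"yerr_sorted = yerr_raw[order]\n",
"```\n",
"\n",
"Using a single permutation for all three arrays keeps each measurement paired with its own input value and uncertainty.\n",
"\n",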
"As a demonstration, let's use the same sample dataset as we used in the {ref}`modeling` tutorial:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"random = np.random.default_rng(42)\n",
"\n",
"t = np.sort(\n",
"    np.append(\n",
"        random.uniform(0, 3.8, 28),\n",
"        random.uniform(5.5, 10, 18),\n",
"    )\n",
")\n",
"yerr = random.uniform(0.08, 0.22, len(t))\n",
"y = (\n",
"    0.2 * (t - 5)\n",
"    + np.sin(3 * t + 0.1 * (t - 5) ** 2)\n",
"    + yerr * random.normal(size=len(t))\n",
")\n",
"\n",
"true_t = np.linspace(0, 10, 100)\n",
"true_y = 0.2 * (true_t - 5) + np.sin(3 * true_t + 0.1 * (true_t - 5) ** 2)\n",
"\n",
"plt.plot(true_t, true_y, \"k\", lw=1.5, alpha=0.3)\n",
"plt.errorbar(t, y, yerr=yerr, fmt=\".k\", capsize=0)\n",
"plt.xlabel(\"x [day]\")\n",
"plt.ylabel(\"y [ppm]\")\n",
"plt.xlim(0, 10)\n",
"plt.ylim(-2.5, 2.5)\n",
"_ = plt.title(\"simulated data\")"
]
},
{
"cell_type": "markdown",
"id": "3",
"metadata": {},
"source": [
"Then we can set up our scalable GP model.\n",
"This looks (perhaps deceptively) similar to the model setup that we would normally use, but all the kernels that we're using are defined in `tinygp.kernels.quasisep` instead of `tinygp.kernels`.\n",
"These kernels do, however, still support addition, multiplication, and scaling to build expressive models.\n",
"That being said, it's important to point out that the computational cost of these methods scales poorly with the number of kernels that you add or (worse!) multiply."
]
},
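{
"cell_type": "markdown",
"id": "3a",
"metadata": {},
"source": [
"To make the scaling statement above a little more concrete, here is a rough sketch in terms of the dimension $J$ of the internal state that each quasiseparable kernel carries. The symbols $N$, $J$, $J_1$, and $J_2$ are notation introduced here for illustration and are not names from the `tinygp` API; the rules below describe how sums and products of such kernels are usually represented, and should be read as an assumption about the general approach, not a description of implementation details.\n",
"\n",
"$$\n",
"J_{k_1 + k_2} = J_1 + J_2,\n",
"\\qquad\n",
"J_{k_1 \\, k_2} = J_1 \\, J_2\n",
"$$\n",
"\n",
"For $N$ data points, the cost of evaluating the likelihood grows linearly with $N$ but polynomially with $J$, so a product of several multi-dimensional kernels becomes expensive much faster than a sum of the same kernels. Multiplying a kernel by a scalar leaves $J$ unchanged."
]
},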
{
"cell_type": "code",
"execution_count": null,
"id": "4",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"from tinygp import kernels, GaussianProcess\n",
"\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"\n",
"\n",
"def build_gp(params):\n",
"    # Parameters are stored as logarithms so that they stay positive during optimization\n",
"    kernel = kernels.quasisep.SHO(\n",
"        sigma=jnp.exp(params[\"log_sigma1\"]),\n",
"        omega=jnp.exp(params[\"log_omega\"]),\n",
"        quality=jnp.exp(params[\"log_quality\"]),\n",
"    )\n",
"    # Add a Matern-3/2 term scaled by the variance sigma2**2\n",
"    kernel += jnp.exp(2 * params[\"log_sigma2\"]) * kernels.quasisep.Matern32(\n",
"        scale=jnp.exp(params[\"log_scale\"])\n",
"    )\n",
"    return GaussianProcess(\n",
"        kernel,\n",
"        t,\n",
"        # Measurement variance plus a fitted white-noise (jitter) term\n",
"        diag=yerr**2 + jnp.exp(params[\"log_jitter\"]),\n",
"        mean=params[\"mean\"],\n",
"    )\n",
"\n",
"\n",
"@jax.jit\n",
"def loss(params):\n",
"    # Negative log likelihood of the data under the GP\n",
"    gp = build_gp(params)\n",
"    return -gp.log_probability(y)\n",
"\n",
"\n",
"params = {\n",
"    \"mean\": 0.0,\n",
"    \"log_jitter\": 0.0,\n",
"    \"log_sigma1\": 0.0,\n",
"    \"log_omega\": np.log(2 * np.pi),\n",
"    \"log_quality\": 0.0,\n",
"    \"log_sigma2\": 0.0,\n",
"    \"log_scale\": 0.0,\n",
"}\n",
"loss(params)"
]
},
{
"cell_type": "markdown",
"id": "5",
"metadata": {},
"source": [
"Good - we got a value for our loss function.\n",
"We can check that this was actually using the scalable solver defined in {class}`tinygp.solvers.quasisep.solver.QuasisepSolver` by checking the type of the `solver` property of our GP:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6",
"metadata": {},
"outputs": [],
"source": [
"type(build_gp(params).solver)"
]
},
{
"cell_type": "markdown",
"id": "7",
"metadata": {},
"source": [
"Now we can minimize the loss:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8",
"metadata": {},
"outputs": [],
"source": [
"import jaxopt\n",
"\n",
"solver = jaxopt.ScipyMinimize(fun=loss)\n",
"soln = solver.run(jax.tree_util.tree_map(jnp.asarray, params))\n",
"print(f\"Final negative log likelihood: {soln.state.fun_val}\")"
]
},
{
"cell_type": "markdown",
"id": "9",
"metadata": {},
"source": [
"And plot our results:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "10",
"metadata": {},
"outputs": [],
"source": [
"_, cond = build_gp(soln.params).condition(y, true_t)\n",
"\n",
"mu = cond.loc\n",
"std = np.sqrt(cond.variance)\n",
"\n",
"plt.plot(true_t, true_y, \"k\", lw=1.5, alpha=0.3, label=\"truth\")\n",
"plt.errorbar(t, y, yerr=yerr, fmt=\".k\", capsize=0)\n",
"plt.plot(true_t, mu, label=\"max likelihood model\")\n",
"plt.fill_between(true_t, mu + std, mu - std, color=\"C0\", alpha=0.3)\n",
"plt.xlabel(\"x [day]\")\n",
"plt.ylabel(\"y [ppm]\")\n",
"plt.xlim(0, 10)\n",
"plt.ylim(-2.5, 2.5)\n",
"plt.legend()\n",
"_ = plt.title(\"maximum likelihood\")"
]
},
{
"cell_type": "markdown",
"id": "11",
"metadata": {},
"source": [
"This all looks pretty good!\n",
"\n",
"Before closing out this tutorial, here are some technical details to keep in mind when using this solver:\n",
"\n",
"1. This implementation is new, and it hasn't yet been pushed to its limits. If you run into problems, please [open issues or pull requests](https://github.com/dfm/tinygp/issues).\n",
"\n",
"2. The computation of the general conditional model with these kernels is not (yet!) as fast as we might want, and it may be somewhat memory heavy. For very large datasets, it is sometimes sufficient to (a) just compute the conditional at the input points (by omitting the `X_test` parameter in {func}`tinygp.GaussianProcess.condition`), (b) only compute the mean prediction, which should be fast, or (c) only predict at a few test points.\n",
"\n",
"3. For more technical details about these methods, check out the API docs for the {ref}`api-kernels-quasisep` and the {ref}`api-solvers-quasisep`, as well as the links therein.\n",
"\n",
"4. It should be possible to implement more flexible models using this interface than those supported by `celerite` or `celerite2`, so stay tuned for more tutorials!"
]
},
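{
"cell_type": "markdown",
"id": "11a",
"metadata": {},
"source": [
"As a rough illustration of options (a) and (c) in the second point above, the sketch below conditions a small quasiseparable GP first at its own input points and then at a handful of sorted test points. The toy dataset, the `Matern32` kernel settings, and the `*_demo` variable names are assumptions made for this example and are not tuned for any real problem.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"\n",
"from tinygp import GaussianProcess, kernels\n",
"\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"\n",
"# Toy data, for illustration only; the inputs are sorted before building the GP\n",
"rng = np.random.default_rng(0)\n",
"t_demo = np.sort(rng.uniform(0, 10, 50))\n",
"y_demo = np.sin(t_demo) + 0.1 * rng.normal(size=len(t_demo))\n",
"\n",
"gp_demo = GaussianProcess(kernels.quasisep.Matern32(scale=1.5), t_demo, diag=0.01)\n",
"\n",
"# (a) Omit X_test to evaluate the conditional at the input points\n",
"_, cond_inputs = gp_demo.condition(y_demo)\n",
"\n",
"# (c) Predict at only a few (sorted) test points\n",
"t_few = jnp.linspace(0.0, 10.0, 5)\n",
"_, cond_few = gp_demo.condition(y_demo, t_few)\n",
"\n",
"print(cond_inputs.loc.shape, cond_few.loc.shape)\n",
"```\n",
"\n",
"Both calls return the conditioned GP as the second element, so the same `loc` and `variance` properties used for the plot above are available in each case."
]
},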
{
"cell_type": "code",
"execution_count": null,
"id": "12",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}