"""
The algorithms implemented in this subpackage are mostly based on `Eidelman
& Gohberg (1999) <https://link.springer.com/article/10.1007%2FBF01300581>`_ and
`Foreman-Mackey et al. (2017) <https://arxiv.org/abs/1703.09710>`_.
"""
from __future__ import annotations
__all__ = [
"DiagQSM",
"StrictLowerTriQSM",
"StrictUpperTriQSM",
"LowerTriQSM",
"UpperTriQSM",
"SquareQSM",
"SymmQSM",
]
import dataclasses
from abc import abstractmethod
from collections.abc import Callable
from functools import wraps
from typing import Any
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.linalg import block_diag
from tinygp.helpers import JAXArray
from tinygp.solvers.quasisep.block import ensure_dense
def handle_matvec_shapes(
    func: Callable[..., JAXArray],
) -> Callable[..., JAXArray]:
    """Decorate a ``matmul``-style method so that it always sees a 2-D operand

    The wrapped method is called with ``x`` flattened to shape ``(n, k)``,
    where ``n`` is the leading dimension of the input and ``k`` is the product
    of all trailing dimensions (``k = 1`` for a vector). Whatever the method
    returns is then reshaped back to the shape of the original input.

    Args:
        func: A method called as ``func(self, x, **kwargs)`` whose result has
            the same number of elements as ``x``.

    Returns:
        A method with the same calling convention that accepts inputs with any
        number of trailing dimensions.
    """
    import math

    @wraps(func)
    def wrapped(self: Any, x: JAXArray, **kwargs: Any) -> JAXArray:
        output_shape = x.shape
        # Spell out the trailing size instead of asking reshape to infer it
        # with ``-1``: the inference is ill-defined (and raises) when the
        # leading dimension is zero.
        trailing = math.prod(output_shape[1:])
        flat = jnp.reshape(x, (output_shape[0], trailing))
        result = func(self, flat, **kwargs)
        return jnp.reshape(result, output_shape)

    return wrapped
class QSM(eqx.Module):
    """The base class for all square quasiseparable matrices

    Subclasses provide :func:`transpose`, :func:`matmul`, and :func:`scale`.
    On top of those, this class supplies the operator overloads shared by all
    QSMs: addition, subtraction, multiplication, and matrix multiplication.
    """

    # This needs to outrank the priority declared by jax's own array types.
    __array_priority__ = 2000

    @abstractmethod
    def transpose(self) -> Any:
        """The matrix transpose as a QSM"""
        raise NotImplementedError

    @abstractmethod
    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
        """The dot product of this matrix with a dense vector or matrix

        Args:
            x (n, ...): A matrix or vector with leading dimension matching this
                matrix.
            parallel: If ``True``, use a parallel associative-scan algorithm
                instead of the default sequential scan.
        """
        raise NotImplementedError

    @abstractmethod
    def scale(self, other: JAXArray) -> QSM:
        """The multiplication of this matrix times a scalar, as a QSM"""
        raise NotImplementedError

    @property
    def T(self) -> Any:
        """Shorthand for :func:`transpose`"""
        return self.transpose()

    def to_dense(self) -> JAXArray:
        """Render this representation to a dense matrix

        The dense form is obtained by multiplying into an identity matrix.
        This implementation is not optimized and should really only ever be
        used for testing purposes.
        """
        identity = jnp.eye(self.shape[0])
        return self.matmul(identity)

    @property
    def shape(self) -> tuple[int, int]:
        """The shape of the matrix"""
        size = self.diag.shape[0]  # type: ignore
        return (size, size)

    def __iter__(self):  # type: ignore
        # Yield the field values in declaration order, which lets callers
        # unpack a QSM like a tuple (e.g. ``p, q, a = lower``).
        return (getattr(self, field.name) for field in dataclasses.fields(self))

    def __sub__(self, other: Any) -> Any:
        negated = -other
        return self.__add__(negated)

    def __add__(self, other: Any) -> Any:
        # Imported lazily inside the method, as are all ``ops`` helpers here.
        from tinygp.solvers.quasisep.ops import elementwise_add

        return elementwise_add(self, other)

    def __mul__(self, other: Any) -> Any:
        # Anything that is not a QSM is treated as a scale factor.
        if not isinstance(other, QSM):
            assert jnp.ndim(other) <= 1
            return self.scale(other)

        from tinygp.solvers.quasisep.ops import elementwise_mul

        return elementwise_mul(self, other)

    def __rmul__(self, other: Any) -> Any:
        assert not isinstance(other, QSM)
        assert jnp.ndim(other) <= 1
        return self.scale(other)

    def __matmul__(self, other: Any) -> Any:
        # QSM @ QSM stays in QSM form; everything else goes through matmul.
        if not isinstance(other, QSM):
            return self.matmul(other)

        from tinygp.solvers.quasisep.ops import qsm_mul

        return qsm_mul(self, other)

    def __rmatmul__(self, other: Any) -> Any:
        assert not isinstance(other, QSM)
        # Uses the identity ``other @ self == (self.T @ other.T).T``.
        transposed_product = self.transpose() @ other.transpose()
        return transposed_product.transpose()
class DiagQSM(QSM):
    """A diagonal quasiseparable matrix

    Args:
        d (n,): The diagonal entries of the matrix as a 1-D array.
    """

    d: JAXArray

    @property
    def shape(self) -> tuple[int, int]:
        """The shape of the matrix, taken from the length of ``d``"""
        size = self.d.shape[0]
        return (size, size)

    def transpose(self) -> DiagQSM:
        """The matrix transpose; a diagonal matrix is returned unchanged"""
        return self

    @handle_matvec_shapes
    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
        """The dot product of this matrix with a dense vector or matrix

        Args:
            x (n, ...): A matrix or vector with leading dimension matching this
                matrix.
            parallel: Accepted for interface compatibility and ignored.
        """
        del parallel
        # Broadcast the diagonal down the columns of ``x``.
        column = self.d[:, None]
        return column * x

    def scale(self, other: JAXArray) -> DiagQSM:
        """The multiplication of this matrix times a scalar, as a QSM"""
        scaled = self.d * other
        return DiagQSM(d=scaled)

    def self_add(self, other: DiagQSM) -> DiagQSM:
        """The sum of two :class:`DiagQSM` matrices"""
        total = self.d + other.d
        return DiagQSM(d=total)

    def self_mul(self, other: DiagQSM) -> DiagQSM:
        """The elementwise product of two :class:`DiagQSM` matrices"""
        product = self.d * other.d
        return DiagQSM(d=product)

    def __neg__(self) -> DiagQSM:
        return DiagQSM(d=-self.d)
class StrictLowerTriQSM(QSM):
    """A strictly lower triangular order ``m`` quasiseparable matrix

    Args:
        p (n, m): The left quasiseparable elements.
        q (n, m): The right quasiseparable elements.
        a (n, m, m): The transition matrices.
    """

    p: JAXArray
    q: JAXArray
    a: JAXArray

    @property
    def shape(self) -> tuple[int, int]:
        """The shape of the matrix, taken from the leading dimension of ``p``"""
        size = self.p.shape[0]
        return (size, size)

    def transpose(self) -> StrictUpperTriQSM:
        """The matrix transpose, reusing the same ``p``, ``q``, and ``a``"""
        return StrictUpperTriQSM(p=self.p, q=self.q, a=self.a)

    @handle_matvec_shapes
    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
        """The dot product of this matrix with a dense vector or matrix

        Args:
            x (n, ...): A matrix or vector with leading dimension matching this
                matrix.
            parallel: If ``True``, dispatch to ``lower_matmul_parallel``
                instead of ``lower_matmul``.
        """
        from tinygp.solvers.quasisep.ops import lower_matmul, lower_matmul_parallel

        if parallel:
            return lower_matmul_parallel(self.p, self.q, self.a, x)
        return lower_matmul(self.p, self.q, self.a, x)

    def scale(self, other: JAXArray) -> StrictLowerTriQSM:
        """The multiplication of this matrix times a scalar, as a QSM"""
        # The factor is absorbed into ``p`` only; ``q`` and ``a`` are reused.
        return StrictLowerTriQSM(p=self.p * other, q=self.q, a=self.a)

    def self_add(self, other: StrictLowerTriQSM) -> StrictLowerTriQSM:
        """The sum of two :class:`StrictLowerTriQSM` matrices

        For every row, the ``p`` and ``q`` elements of the two operands are
        concatenated and the transition matrices are placed on a block
        diagonal.
        """

        def stack_row(
            lhs: StrictLowerTriQSM, rhs: StrictLowerTriQSM
        ) -> StrictLowerTriQSM:
            # Inside the vmap each operand holds the data for a single row.
            lhs_p, lhs_q, lhs_a = lhs
            rhs_p, rhs_q, rhs_a = rhs
            stacked_a = block_diag(ensure_dense(lhs_a), ensure_dense(rhs_a))
            return StrictLowerTriQSM(
                p=jnp.concatenate((lhs_p, rhs_p)),
                q=jnp.concatenate((lhs_q, rhs_q)),
                a=stacked_a,
            )

        return jax.vmap(stack_row)(self, other)

    def self_mul(self, other: StrictLowerTriQSM) -> StrictLowerTriQSM:
        """The elementwise product of two :class:`StrictLowerTriQSM` matrices

        The result pairs every column of this matrix's elements with every
        column of ``other``'s, so its ``p`` and ``q`` have
        ``self.p.shape[1] * other.p.shape[1]`` columns.
        """
        # The densification is vmapped because a batched Block carries 3D
        # block arrays, which block_diag (used by to_dense) cannot handle
        # without unbatching.
        densify = jax.vmap(ensure_dense)
        lhs_a = densify(self.a)
        rhs_a = densify(other.a)

        # Flattened index pairs enumerating all (self, other) combinations.
        lhs_grid, rhs_grid = np.meshgrid(
            np.arange(self.p.shape[1]), np.arange(other.p.shape[1])
        )
        i = lhs_grid.flatten()
        j = rhs_grid.flatten()

        new_p = self.p[:, i] * other.p[:, j]
        new_q = self.q[:, i] * other.q[:, j]
        new_a = lhs_a[:, i[:, None], i[None, :]] * rhs_a[:, j[:, None], j[None, :]]
        return StrictLowerTriQSM(p=new_p, q=new_q, a=new_a)

    def __neg__(self) -> StrictLowerTriQSM:
        # Flipping the sign of ``p`` alone negates the matrix.
        return StrictLowerTriQSM(p=-self.p, q=self.q, a=self.a)
class StrictUpperTriQSM(QSM):
    """A strictly upper triangular order ``m`` quasiseparable matrix

    The notation here is somewhat different from that in `Eidelman & Gohberg
    (1999) <https://link.springer.com/article/10.1007%2FBF01300581>`_, because
    we wanted to map ``StrictLowerTriQSM.transpose() -> StrictUpperTriQSM``
    while retaining the same names for each component. Therefore, our ``p`` is
    their ``h``, and our ``a`` is their ``b.T``.

    Args:
        p (n, m): The right quasiseparable elements.
        q (n, m): The left quasiseparable elements.
        a (n, m, m): The transition matrices.
    """

    p: JAXArray
    q: JAXArray
    a: JAXArray

    @property
    def shape(self) -> tuple[int, int]:
        """The shape of the matrix, taken from the leading dimension of ``p``"""
        size = self.p.shape[0]
        return (size, size)

    def transpose(self) -> StrictLowerTriQSM:
        """The matrix transpose, reusing the same ``p``, ``q``, and ``a``"""
        return StrictLowerTriQSM(p=self.p, q=self.q, a=self.a)

    @handle_matvec_shapes
    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
        """The dot product of this matrix with a dense vector or matrix

        Args:
            x (n, ...): A matrix or vector with leading dimension matching this
                matrix.
            parallel: If ``True``, dispatch to ``upper_matmul_parallel``
                instead of ``upper_matmul``.
        """
        from tinygp.solvers.quasisep.ops import upper_matmul, upper_matmul_parallel

        if parallel:
            return upper_matmul_parallel(self.p, self.q, self.a, x)
        return upper_matmul(self.p, self.q, self.a, x)

    def scale(self, other: JAXArray) -> StrictUpperTriQSM:
        """The multiplication of this matrix times a scalar, as a QSM"""
        # The factor is absorbed into ``q`` only; ``p`` and ``a`` are reused.
        return StrictUpperTriQSM(p=self.p, q=self.q * other, a=self.a)

    def self_add(self, other: StrictUpperTriQSM) -> StrictUpperTriQSM:
        """The sum of two :class:`StrictUpperTriQSM` matrices"""
        # Delegate to the lower triangular implementation via transposes.
        lower_sum = self.transpose().self_add(other.transpose())
        return lower_sum.transpose()

    def self_mul(self, other: StrictUpperTriQSM) -> StrictUpperTriQSM:
        """The elementwise product of two :class:`StrictUpperTriQSM` matrices"""
        # Delegate to the lower triangular implementation via transposes.
        lower_product = self.transpose().self_mul(other.transpose())
        return lower_product.transpose()

    def __neg__(self) -> StrictUpperTriQSM:
        # Flipping the sign of ``p`` alone negates the matrix.
        return StrictUpperTriQSM(p=-self.p, q=self.q, a=self.a)
class LowerTriQSM(QSM):
    """A lower triangular quasiseparable matrix

    Args:
        diag: The diagonal elements.
        lower: The strictly lower triangular elements.
    """

    diag: DiagQSM
    lower: StrictLowerTriQSM

    def transpose(self) -> UpperTriQSM:
        """The matrix transpose, as an :class:`UpperTriQSM`"""
        return UpperTriQSM(diag=self.diag, upper=self.lower.transpose())

    @handle_matvec_shapes
    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
        """The dot product of this matrix with a dense vector or matrix

        Args:
            x (n, ...): A matrix or vector with leading dimension matching this
                matrix.
            parallel: Forwarded to the strictly lower triangular part's
                ``matmul``.
        """
        diag_part = self.diag.matmul(x)
        lower_part = self.lower.matmul(x, parallel=parallel)
        return diag_part + lower_part

    def scale(self, other: JAXArray) -> LowerTriQSM:
        """The multiplication of this matrix times a scalar, as a QSM"""
        return LowerTriQSM(
            diag=self.diag.scale(other),
            lower=self.lower.scale(other),
        )

    def inv(self) -> LowerTriQSM:
        """The inverse of this matrix, as a :class:`LowerTriQSM`"""
        d = self.diag.d
        p, q, a = self.lower.p, self.lower.q, self.lower.a

        inv_d = 1 / d
        inv_d_col = inv_d[:, None]
        new_p = -inv_d_col * p
        new_q = inv_d_col * q
        # One outer product per row, subtracted from that row's transition.
        new_a = a - jax.vmap(jnp.outer)(new_q, p)
        return LowerTriQSM(
            diag=DiagQSM(d=inv_d),
            lower=StrictLowerTriQSM(p=new_p, q=new_q, a=new_a),
        )

    @handle_matvec_shapes
    def solve(self, y: JAXArray, *, parallel: bool = False) -> JAXArray:
        """Solve a linear system with this matrix

        If this matrix is called ``L``, this solves ``L @ x = y`` for ``x``
        given ``y``, using forward substitution.

        Args:
            y (n, ...): A matrix or vector with leading dimension matching this
                matrix.
            parallel: If ``True``, use a parallel associative-scan algorithm.
        """
        from tinygp.solvers.quasisep.ops import lower_solve, lower_solve_parallel

        d = self.diag.d
        p, q, a = self.lower.p, self.lower.q, self.lower.a
        if parallel:
            return lower_solve_parallel(d, p, q, a, y)
        return lower_solve(d, p, q, a, y)

    def __neg__(self) -> LowerTriQSM:
        return LowerTriQSM(diag=-self.diag, lower=-self.lower)
class UpperTriQSM(QSM):
    """An upper triangular quasiseparable matrix

    Args:
        diag: The diagonal elements.
        upper: The strictly upper triangular elements.
    """

    diag: DiagQSM
    upper: StrictUpperTriQSM

    def transpose(self) -> LowerTriQSM:
        """The matrix transpose, as a :class:`LowerTriQSM`"""
        return LowerTriQSM(diag=self.diag, lower=self.upper.transpose())

    @handle_matvec_shapes
    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
        """The dot product of this matrix with a dense vector or matrix

        Args:
            x (n, ...): A matrix or vector with leading dimension matching this
                matrix.
            parallel: Forwarded to the strictly upper triangular part's
                ``matmul``.
        """
        diag_part = self.diag.matmul(x)
        upper_part = self.upper.matmul(x, parallel=parallel)
        return diag_part + upper_part

    def scale(self, other: JAXArray) -> UpperTriQSM:
        """The multiplication of this matrix times a scalar, as a QSM"""
        return UpperTriQSM(
            diag=self.diag.scale(other),
            upper=self.upper.scale(other),
        )

    def inv(self) -> UpperTriQSM:
        """The inverse of this matrix, as an :class:`UpperTriQSM`"""
        # Invert the lower triangular transpose, then transpose back.
        lower_inverse = self.transpose().inv()
        return lower_inverse.transpose()

    @handle_matvec_shapes
    def solve(self, y: JAXArray, *, parallel: bool = False) -> JAXArray:
        """Solve a linear system with this matrix

        If this matrix is called ``U``, this solves ``U @ x = y`` for ``x``
        given ``y``, using backward substitution.

        Args:
            y (n, ...): A matrix or vector with leading dimension matching this
                matrix.
            parallel: If ``True``, use a parallel associative-scan algorithm.
        """
        from tinygp.solvers.quasisep.ops import upper_solve, upper_solve_parallel

        d = self.diag.d
        p, q, a = self.upper.p, self.upper.q, self.upper.a
        if parallel:
            return upper_solve_parallel(d, p, q, a, y)
        return upper_solve(d, p, q, a, y)

    def __neg__(self) -> UpperTriQSM:
        return UpperTriQSM(diag=-self.diag, upper=-self.upper)
class SquareQSM(QSM):
    """A general square order ``(m1, m2)`` quasiseparable matrix

    Args:
        diag: The diagonal elements.
        lower: The strictly lower triangular elements with order ``m1``.
        upper: The strictly upper triangular elements with order ``m2``.
    """

    diag: DiagQSM
    lower: StrictLowerTriQSM
    upper: StrictUpperTriQSM

    def transpose(self) -> SquareQSM:
        """The matrix transpose; the two triangular parts trade places"""
        new_lower = self.upper.transpose()
        new_upper = self.lower.transpose()
        return SquareQSM(diag=self.diag, lower=new_lower, upper=new_upper)

    @handle_matvec_shapes
    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
        """The dot product of this matrix with a dense vector or matrix

        Args:
            x (n, ...): A matrix or vector with leading dimension matching this
                matrix.
            parallel: Forwarded to the ``matmul`` of both triangular parts.
        """
        diag_part = self.diag.matmul(x)
        lower_part = self.lower.matmul(x, parallel=parallel)
        upper_part = self.upper.matmul(x, parallel=parallel)
        return diag_part + lower_part + upper_part

    def scale(self, other: JAXArray) -> SquareQSM:
        """The multiplication of this matrix times a scalar, as a QSM"""
        return SquareQSM(
            diag=self.diag.scale(other),
            lower=self.lower.scale(other),
            upper=self.upper.scale(other),
        )

    def gram(self) -> SymmQSM:
        """The inner product of this matrix with itself

        If this matrix is called ``A``, the Gram matrix is ``A.T @ A``, and
        that's what this method computes. The result is a :class:`SymmQSM`.
        """
        # The product has to be symmetric, but nothing enforces that; keeping
        # only its diagonal and lower part makes it so. It might be possible
        # to make this more efficient, but perhaps jax is clever enough?
        product = self.transpose() @ self
        return SymmQSM(diag=product.diag, lower=product.lower)

    @jax.jit
    def inv(self) -> SquareQSM:
        """The inverse of this matrix

        The elements of the inverse are built with a forward scan followed by
        a reverse scan over the rows.
        """
        d = self.diag.d
        p, q, a = self.lower.p, self.lower.q, self.lower.a
        # The upper part's (p, q, a) are bound to (h, g, b) so that they do
        # not clash with the lower part's names.
        h, g, b = self.upper.p, self.upper.q, self.upper.a

        def forward(f, row):  # type: ignore
            dk, pk, qk, ak, gk, hk, bk = row
            f_h = f @ hk
            f_bt = f @ bk.T
            resid_q = qk - ak @ f_h
            resid_g = gk - pk @ f_bt
            igk = 1 / (dk - pk @ f_h)
            sk = igk * resid_q
            ellk = ak - jnp.outer(sk, pk)
            vk = igk * resid_g
            delk = bk - jnp.outer(vk, hk)
            f_next = ak @ f_bt + igk * jnp.outer(resid_q, resid_g)
            return f_next, (igk, sk, ellk, vk, delk)

        f0 = jnp.zeros_like(jnp.outer(q[0], g[0]))
        _, forward_out = jax.lax.scan(forward, f0, (d, p, q, a, g, h, b))
        ig, s, ell, v, del_ = forward_out

        def backward(z, row):  # type: ignore
            igk, pk, ak, hk, bk, sk, vk = row
            z_s = z @ sk
            z_a = z @ ak
            lk = igk + vk @ z_s
            tk = vk @ z_a - lk * pk
            uk = bk.T @ z_s - lk * hk
            z_next = bk.T @ z_a - jnp.outer(uk + lk * hk, pk) - jnp.outer(hk, tk)
            return z_next, (lk, tk, uk)

        z0 = jnp.zeros_like(jnp.outer(h[-1], p[-1]))
        _, backward_out = jax.lax.scan(
            backward, z0, (ig, p, a, h, b, s, v), reverse=True
        )
        lam, t, u = backward_out

        return SquareQSM(
            diag=DiagQSM(d=lam),
            lower=StrictLowerTriQSM(p=t, q=s, a=ell),
            upper=StrictUpperTriQSM(p=u, q=v, a=del_),
        )

    def __neg__(self) -> SquareQSM:
        return SquareQSM(diag=-self.diag, lower=-self.lower, upper=-self.upper)
class SymmQSM(QSM):
    """A symmetric order ``m`` quasiseparable matrix

    Only the diagonal and the strictly lower triangle are stored; the upper
    triangle is taken to be the transpose of the lower one.

    Args:
        diag: The diagonal elements.
        lower: The strictly lower triangular elements with order ``m``.
    """

    diag: DiagQSM
    lower: StrictLowerTriQSM

    def transpose(self) -> SymmQSM:
        """The matrix transpose; a symmetric matrix is returned unchanged"""
        return self

    @handle_matvec_shapes
    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
        """The dot product of this matrix with a dense vector or matrix

        Args:
            x (n, ...): A matrix or vector with leading dimension matching this
                matrix.
            parallel: Forwarded to the ``matmul`` of the lower part and of its
                transpose.
        """
        diag_part = self.diag.matmul(x)
        lower_part = self.lower.matmul(x, parallel=parallel)
        upper_part = self.lower.transpose().matmul(x, parallel=parallel)
        return diag_part + lower_part + upper_part

    def scale(self, other: JAXArray) -> SymmQSM:
        """The multiplication of this matrix times a scalar, as a QSM"""
        return SymmQSM(
            diag=self.diag.scale(other),
            lower=self.lower.scale(other),
        )

    def inv(self, *, parallel: bool = False) -> SymmQSM:
        """The inverse of this matrix

        Args:
            parallel: If ``True``, use a parallel associative-scan algorithm.
        """
        from tinygp.solvers.quasisep.ops import symm_inv, symm_inv_parallel

        d = self.diag.d
        p, q, a = self.lower.p, self.lower.q, self.lower.a
        if parallel:
            lam, t, s, ell = symm_inv_parallel(d, p, q, a)
        else:
            lam, t, s, ell = symm_inv(d, p, q, a)
        return SymmQSM(
            diag=DiagQSM(d=lam),
            lower=StrictLowerTriQSM(p=t, q=s, a=ell),
        )

    def cholesky(self, *, parallel: bool = False) -> LowerTriQSM:
        """The Cholesky decomposition of this matrix

        If this matrix is called ``A``, this method returns the
        :class:`LowerTriQSM` ``L`` such that ``L @ L.T = A``.

        Args:
            parallel: If ``True``, use a parallel associative-scan algorithm.
        """
        from tinygp.solvers.quasisep.ops import cholesky, cholesky_parallel

        d = self.diag.d
        p, q, a = self.lower.p, self.lower.q, self.lower.a
        if parallel:
            c, w = cholesky_parallel(d, p, q, a)
        else:
            c, w = cholesky(d, p, q, a)
        # The factor reuses this matrix's ``p`` and ``a``; only the diagonal
        # and ``q`` are replaced.
        return LowerTriQSM(
            diag=DiagQSM(d=c),
            lower=StrictLowerTriQSM(p=p, q=w, a=a),
        )

    def __neg__(self) -> SymmQSM:
        return SymmQSM(diag=-self.diag, lower=-self.lower)