# Source code for pyhf.tensor.jax_backend

from jax.config import config

# Switch on JAX's ``jax_enable_x64`` flag. This is done here, before
# ``jax.numpy`` is imported below. The backend class in this module maps its
# default '64b' precision to ``jnp.float64`` / ``jnp.int64``.
config.update('jax_enable_x64', True)

import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy import special
from jax.scipy.stats import norm
import numpy as np
import scipy.stats as osp_stats
import logging

# Module-level logger named after this module.
log = logging.getLogger(__name__)


class _BasicPoisson:
    """Minimal Poisson distribution object exposing ``sample`` and ``log_prob``."""

    def __init__(self, rate):
        """
        Args:
            rate (:obj:`tensor` or :obj:`float`): The mean of the Poisson
                distribution (the expected number of events)
        """
        self.rate = rate

    def sample(self, sample_shape):
        """
        Draw random samples using :func:`scipy.stats.poisson`.

        Args:
            sample_shape (:obj:`tuple` or :obj:`list` of :obj:`int`): The
                leading shape of the sample; the shape of ``rate`` is appended
                to it.

        Returns:
            JAX ndarray: Samples of shape ``sample_shape + shape(rate)``
        """
        # np.shape works for plain Python scalars as well as arrays, whereas
        # ``self.rate.shape`` raises AttributeError for a float rate.
        size = tuple(sample_shape) + np.shape(self.rate)
        # TODO: Support other dtypes
        return jnp.asarray(
            osp_stats.poisson(self.rate).rvs(size=size),
            dtype=jnp.float64,
        )

    def log_prob(self, value):
        """
        Evaluate the log of the Poisson probability at ``value``.

        Args:
            value (:obj:`tensor` or :obj:`float`): The observed value(s)

        Returns:
            JAX ndarray: The result of ``jax_backend.poisson_logpdf(value, rate)``
        """
        tensorlib = jax_backend()
        return tensorlib.poisson_logpdf(value, self.rate)


class _BasicNormal:
    """Minimal Normal distribution object exposing ``sample`` and ``log_prob``."""

    def __init__(self, loc, scale):
        """
        Args:
            loc (:obj:`tensor` or :obj:`float`): The mean of the Normal
                distribution
            scale (:obj:`tensor` or :obj:`float`): The standard deviation of
                the Normal distribution
        """
        self.loc = loc
        self.scale = scale

    def sample(self, sample_shape):
        """
        Draw random samples using :func:`scipy.stats.norm`.

        Args:
            sample_shape (:obj:`tuple` or :obj:`list` of :obj:`int`): The
                leading shape of the sample; the broadcast shape of ``loc``
                and ``scale`` is appended to it.

        Returns:
            JAX ndarray: Samples of shape
            ``sample_shape + broadcast_shape(loc, scale)``
        """
        # Use the broadcast shape of both parameters so that scalar ``loc``
        # (which has no ``.shape`` attribute) and array-valued ``scale`` are
        # handled as well as the array ``loc`` case.
        param_shape = np.broadcast(self.loc, self.scale).shape
        size = tuple(sample_shape) + param_shape
        # TODO: Support other dtypes
        return jnp.asarray(
            osp_stats.norm(self.loc, self.scale).rvs(size=size),
            dtype=jnp.float64,
        )

    def log_prob(self, value):
        """
        Evaluate the log of the Normal probability density at ``value``.

        Args:
            value (:obj:`tensor` or :obj:`float`): The observed value(s)

        Returns:
            JAX ndarray: The result of
            ``jax_backend.normal_logpdf(value, loc, scale)``
        """
        tensorlib = jax_backend()
        return tensorlib.normal_logpdf(value, self.loc, self.scale)


class jax_backend:
    """JAX backend for pyhf"""

    __slots__ = ['name', 'precision', 'dtypemap', 'default_do_grad']

    def __init__(self, **kwargs):
        """
        Configure the backend.

        Args:
            **kwargs: Optional settings. ``precision`` (:obj:`str`) selects
                the dtype map: ``'64b'`` (the default) maps to 64-bit float
                and int types, any other value maps to the 32-bit types.
        """
        self.name = 'jax'
        self.precision = kwargs.get('precision', '64b')
        self.dtypemap = {
            'float': jnp.float64 if self.precision == '64b' else jnp.float32,
            'int': jnp.int64 if self.precision == '64b' else jnp.int32,
            'bool': jnp.bool_,
        }
        self.default_do_grad = True

    def _setup(self):
        """
        Run any global setups for the jax lib.
        """

    def _lookup_dtype(self, dtype):
        """
        Translate a dtype name into the concrete JAX dtype for this backend.

        Args:
            dtype (:obj:`str`): One of ``'float'``, ``'int'`` or ``'bool'``

        Returns:
            The matching entry of ``self.dtypemap``

        Raises:
            KeyError: If ``dtype`` is not a key of ``self.dtypemap``. The
                error is logged before it is re-raised.
        """
        try:
            return self.dtypemap[dtype]
        except KeyError:
            log.error(
                'Invalid dtype: dtype must be float, int, or bool.', exc_info=True
            )
            raise

    def clip(self, tensor_in, min_value, max_value):
        """
        Clips (limits) the tensor values to be within a specified min and max.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> a = pyhf.tensorlib.astensor([-2, -1, 0, 1, 2])
            >>> pyhf.tensorlib.clip(a, -1, 1)
            DeviceArray([-1., -1.,  0.,  1.,  1.], dtype=float64)

        Args:
            tensor_in (:obj:`tensor`): The input tensor object
            min_value (:obj:`scalar` or :obj:`tensor` or :obj:`None`): The minimum value to be clipped to
            max_value (:obj:`scalar` or :obj:`tensor` or :obj:`None`): The maximum value to be clipped to

        Returns:
            JAX ndarray: A clipped `tensor`
        """
        return jnp.clip(tensor_in, min_value, max_value)

    def erf(self, tensor_in):
        """
        The error function.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> a = pyhf.tensorlib.astensor([-2., -1., 0., 1., 2.])
            >>> pyhf.tensorlib.erf(a)
            DeviceArray([-0.99532227, -0.84270079,  0.        ,  0.84270079,
                          0.99532227], dtype=float64)

        Args:
            tensor_in (:obj:`tensor`): The input tensor object

        Returns:
            JAX ndarray: The values of the error function at the given points.
        """
        return special.erf(tensor_in)

    def erfinv(self, tensor_in):
        """
        The inverse of the error function.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> a = pyhf.tensorlib.astensor([-2., -1., 0., 1., 2.])
            >>> pyhf.tensorlib.erfinv(pyhf.tensorlib.erf(a))
            DeviceArray([-2., -1.,  0.,  1.,  2.], dtype=float64)

        Args:
            tensor_in (:obj:`tensor`): The input tensor object

        Returns:
            JAX ndarray: The values of the inverse of the error function at the given points.
        """
        return special.erfinv(tensor_in)

    def tile(self, tensor_in, repeats):
        """
        Repeat tensor data along a specific dimension

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> a = pyhf.tensorlib.astensor([[1.0], [2.0]])
            >>> pyhf.tensorlib.tile(a, (1, 2))
            DeviceArray([[1., 1.],
                         [2., 2.]], dtype=float64)

        Args:
            tensor_in (:obj:`tensor`): The tensor to be repeated
            repeats (:obj:`tensor`): The tuple of multipliers for each dimension

        Returns:
            JAX ndarray: The tensor with repeated axes
        """
        return jnp.tile(tensor_in, repeats)

    def conditional(self, predicate, true_callable, false_callable):
        """
        Runs a callable conditional on the boolean value of the evaluation of a predicate

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> tensorlib = pyhf.tensorlib
            >>> a = tensorlib.astensor([4])
            >>> b = tensorlib.astensor([5])
            >>> tensorlib.conditional((a < b)[0], lambda: a + b, lambda: a - b)
            DeviceArray([9.], dtype=float64)

        Args:
            predicate (:obj:`scalar`): The logical condition that determines which callable to evaluate
            true_callable (:obj:`callable`): The callable that is evaluated when the :code:`predicate` evaluates to :code:`true`
            false_callable (:obj:`callable`): The callable that is evaluated when the :code:`predicate` evaluates to :code:`false`

        Returns:
            JAX ndarray: The output of the callable that was evaluated
        """
        # Only the selected callable is invoked; the other is never called.
        return true_callable() if predicate else false_callable()

    def tolist(self, tensor_in):
        """
        Convert a tensor to a (nested) Python list.

        Args:
            tensor_in (:obj:`tensor` or :obj:`list`): The input tensor object

        Returns:
            :obj:`list`: The contents of ``tensor_in`` as Python objects
        """
        try:
            return np.asarray(tensor_in).tolist()
        except AttributeError:
            # Fall back to returning list input untouched; anything else
            # that failed the conversion is a genuine error.
            if isinstance(tensor_in, list):
                return tensor_in
            raise

    def outer(self, tensor_in_1, tensor_in_2):
        """Return the outer product of two tensors (:func:`jax.numpy.outer`)."""
        return jnp.outer(tensor_in_1, tensor_in_2)

    def gather(self, tensor, indices):
        """Select the elements of ``tensor`` at ``indices`` via indexing."""
        return tensor[indices]

    def boolean_mask(self, tensor, mask):
        """Select the elements of ``tensor`` picked out by ``mask`` via indexing."""
        return tensor[mask]

    def isfinite(self, tensor):
        """Test element-wise for finiteness (:func:`jax.numpy.isfinite`)."""
        return jnp.isfinite(tensor)

    def astensor(self, tensor_in, dtype='float'):
        """
        Convert to a JAX ndarray.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
            >>> tensor
            DeviceArray([[1., 2., 3.],
                         [4., 5., 6.]], dtype=float64)
            >>> type(tensor)
            <class 'jax.interpreters.xla._DeviceArray'>

        Args:
            tensor_in (Number or Tensor): Tensor object
            dtype (:obj:`str`): The dtype name, one of ``'float'``, ``'int'``
                or ``'bool'``

        Returns:
            `jax.interpreters.xla._DeviceArray`: A multi-dimensional, fixed-size homogenous array.

        Raises:
            KeyError: If ``dtype`` is not a supported dtype name
        """
        return jnp.asarray(tensor_in, dtype=self._lookup_dtype(dtype))

    def sum(self, tensor_in, axis=None):
        """Sum the tensor elements over ``axis`` (:func:`jax.numpy.sum`)."""
        return jnp.sum(tensor_in, axis=axis)

    def product(self, tensor_in, axis=None):
        """Multiply the tensor elements over ``axis`` (:func:`jax.numpy.prod`)."""
        return jnp.prod(tensor_in, axis=axis)

    def abs(self, tensor):
        """Return the element-wise absolute value (:func:`jax.numpy.abs`)."""
        return jnp.abs(tensor)

    def ones(self, shape, dtype='float'):
        """
        Return a tensor of the given shape filled with ones.

        Args:
            shape (:obj:`tuple` of :obj:`int`): The shape of the new tensor
            dtype (:obj:`str`): The dtype name, one of ``'float'``, ``'int'``
                or ``'bool'``; resolved through the backend's ``dtypemap`` so
                the configured precision is respected

        Returns:
            JAX ndarray: A tensor of ones

        Raises:
            KeyError: If ``dtype`` is not a supported dtype name
        """
        return jnp.ones(shape, dtype=self._lookup_dtype(dtype))

    def zeros(self, shape, dtype='float'):
        """
        Return a tensor of the given shape filled with zeros.

        Args:
            shape (:obj:`tuple` of :obj:`int`): The shape of the new tensor
            dtype (:obj:`str`): The dtype name, one of ``'float'``, ``'int'``
                or ``'bool'``; resolved through the backend's ``dtypemap`` so
                the configured precision is respected

        Returns:
            JAX ndarray: A tensor of zeros

        Raises:
            KeyError: If ``dtype`` is not a supported dtype name
        """
        return jnp.zeros(shape, dtype=self._lookup_dtype(dtype))

    def power(self, tensor_in_1, tensor_in_2):
        """Raise ``tensor_in_1`` to the power ``tensor_in_2`` element-wise."""
        return jnp.power(tensor_in_1, tensor_in_2)

    def sqrt(self, tensor_in):
        """Return the element-wise square root (:func:`jax.numpy.sqrt`)."""
        return jnp.sqrt(tensor_in)

    def divide(self, tensor_in_1, tensor_in_2):
        """Divide ``tensor_in_1`` by ``tensor_in_2`` element-wise."""
        return jnp.divide(tensor_in_1, tensor_in_2)

    def log(self, tensor_in):
        """Return the element-wise natural logarithm (:func:`jax.numpy.log`)."""
        return jnp.log(tensor_in)

    def exp(self, tensor_in):
        """Return the element-wise exponential (:func:`jax.numpy.exp`)."""
        return jnp.exp(tensor_in)

    def stack(self, sequence, axis=0):
        """Join a sequence of tensors along a new axis (:func:`jax.numpy.stack`)."""
        return jnp.stack(sequence, axis=axis)

    def where(self, mask, tensor_in_1, tensor_in_2):
        """Pick from ``tensor_in_1`` where ``mask`` is true, else ``tensor_in_2``."""
        return jnp.where(mask, tensor_in_1, tensor_in_2)

    def concatenate(self, sequence, axis=0):
        """
        Join a sequence of arrays along an existing axis.

        Args:
            sequence: sequence of tensors
            axis: dimension along which to concatenate

        Returns:
            output: the concatenated tensor
        """
        return jnp.concatenate(sequence, axis=axis)

    def simple_broadcast(self, *args):
        """
        Broadcast a sequence of 1 dimensional arrays.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> pyhf.tensorlib.simple_broadcast(
            ...   pyhf.tensorlib.astensor([1]),
            ...   pyhf.tensorlib.astensor([2, 3, 4]),
            ...   pyhf.tensorlib.astensor([5, 6, 7]))
            [DeviceArray([1., 1., 1.], dtype=float64), DeviceArray([2., 3., 4.], dtype=float64), DeviceArray([5., 6., 7.], dtype=float64)]

        Args:
            args (Array of Tensors): Sequence of arrays

        Returns:
            list of Tensors: The sequence broadcast together.
        """
        return jnp.broadcast_arrays(*args)

    def shape(self, tensor):
        """Return the ``shape`` attribute of ``tensor``."""
        return tensor.shape

    def reshape(self, tensor, newshape):
        """Give ``tensor`` the shape ``newshape`` (:func:`jax.numpy.reshape`)."""
        return jnp.reshape(tensor, newshape)

    def ravel(self, tensor):
        """
        Return a flattened view of the tensor, not a copy.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
            >>> pyhf.tensorlib.ravel(tensor)
            DeviceArray([1., 2., 3., 4., 5., 6.], dtype=float64)

        Args:
            tensor (Tensor): Tensor object

        Returns:
            `jax.interpreters.xla._DeviceArray`: A flattened array.
        """
        return jnp.ravel(tensor)

    def einsum(self, subscripts, *operands):
        """
        Evaluates the Einstein summation convention on the operands.

        Using the Einstein summation convention, many common multi-dimensional
        array operations can be represented in a simple fashion. This function
        provides a way to compute such summations. The best way to understand this
        function is to try the examples below, which show how many common NumPy
        functions can be implemented as calls to einsum.

        Args:
            subscripts: str, specifies the subscripts for summation
            operands: list of array_like, these are the tensors for the operation

        Returns:
            tensor: the calculation based on the Einstein summation convention
        """
        return jnp.einsum(subscripts, *operands)

    def poisson_logpdf(self, n, lam):
        r"""
        The log of the continuous approximation, using
        :math:`n! = \Gamma\left(n+1\right)`, to the probability mass function
        of the Poisson distribution evaluated at :code:`n` given the
        parameter :code:`lam`.

        Args:
            n (:obj:`tensor` or :obj:`float`): The value at which to evaluate the approximation to the Poisson distribution p.m.f.
                                              (the observed number of events)
            lam (:obj:`tensor` or :obj:`float`): The mean of the Poisson distribution p.m.f.
                                                (the expected number of events)

        Returns:
            JAX ndarray: Value of the log of the continuous approximation to Poisson(n|lam)
        """
        n = jnp.asarray(n)
        lam = jnp.asarray(lam)
        # xlogy(n, lam) equals n * log(lam) but evaluates to 0 when n == 0,
        # so the n = 0, lam = 0 corner gives log-probability 0 instead of
        # the nan produced by 0 * log(0).
        return special.xlogy(n, lam) - lam - gammaln(n + 1.0)

    def poisson(self, n, lam):
        r"""
        The continuous approximation, using :math:`n! = \Gamma\left(n+1\right)`,
        to the probability mass function of the Poisson distribution evaluated
        at :code:`n` given the parameter :code:`lam`.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> pyhf.tensorlib.poisson(5., 6.)
            DeviceArray(0.16062314, dtype=float64)
            >>> values = pyhf.tensorlib.astensor([5., 9.])
            >>> rates = pyhf.tensorlib.astensor([6., 8.])
            >>> pyhf.tensorlib.poisson(values, rates)
            DeviceArray([0.16062314, 0.12407692], dtype=float64)

        Args:
            n (:obj:`tensor` or :obj:`float`): The value at which to evaluate the approximation to the Poisson distribution p.m.f.
                                              (the observed number of events)
            lam (:obj:`tensor` or :obj:`float`): The mean of the Poisson distribution p.m.f.
                                                (the expected number of events)

        Returns:
            JAX ndarray: Value of the continuous approximation to Poisson(n|lam)
        """
        # Share the log-space formula (including its n = 0, lam = 0 handling)
        # rather than duplicating it here.
        return jnp.exp(self.poisson_logpdf(n, lam))

    def normal_logpdf(self, x, mu, sigma):
        r"""
        The log of the probability density function of the Normal
        distribution evaluated at :code:`x` given parameters of mean of
        :code:`mu` and standard deviation of :code:`sigma`.

        Args:
            x (:obj:`tensor` or :obj:`float`): The value at which to evaluate the Normal distribution p.d.f.
            mu (:obj:`tensor` or :obj:`float`): The mean of the Normal distribution
            sigma (:obj:`tensor` or :obj:`float`): The standard deviation of the Normal distribution

        Returns:
            JAX ndarray: Value of log(Normal(x|mu, sigma))
        """
        # Hand-written form of norm.logpdf(x, loc=mu, scale=sigma); this is
        # much faster than calling norm.logpdf, see
        # https://codereview.stackexchange.com/questions/69718/fastest-computation-of-n-likelihoods-on-normal-distributions
        root2 = jnp.sqrt(2)
        root2pi = jnp.sqrt(2 * jnp.pi)
        prefactor = -jnp.log(sigma * root2pi)
        summand = -jnp.square(jnp.divide((x - mu), (root2 * sigma)))
        return prefactor + summand

    def normal(self, x, mu, sigma):
        r"""
        The probability density function of the Normal distribution evaluated
        at :code:`x` given parameters of mean of :code:`mu` and standard deviation
        of :code:`sigma`.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> pyhf.tensorlib.normal(0.5, 0., 1.)
            DeviceArray(0.35206533, dtype=float64)
            >>> values = pyhf.tensorlib.astensor([0.5, 2.0])
            >>> means = pyhf.tensorlib.astensor([0., 2.3])
            >>> sigmas = pyhf.tensorlib.astensor([1., 0.8])
            >>> pyhf.tensorlib.normal(values, means, sigmas)
            DeviceArray([0.35206533, 0.46481887], dtype=float64)

        Args:
            x (:obj:`tensor` or :obj:`float`): The value at which to evaluate the Normal distribution p.d.f.
            mu (:obj:`tensor` or :obj:`float`): The mean of the Normal distribution
            sigma (:obj:`tensor` or :obj:`float`): The standard deviation of the Normal distribution

        Returns:
            JAX ndarray: Value of Normal(x|mu, sigma)
        """
        return norm.pdf(x, loc=mu, scale=sigma)

    def normal_cdf(self, x, mu=0, sigma=1):
        """
        The cumulative distribution function for the Normal distribution

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> pyhf.tensorlib.normal_cdf(0.8)
            DeviceArray(0.7881446, dtype=float64)
            >>> values = pyhf.tensorlib.astensor([0.8, 2.0])
            >>> pyhf.tensorlib.normal_cdf(values)
            DeviceArray([0.7881446 , 0.97724987], dtype=float64)

        Args:
            x (:obj:`tensor` or :obj:`float`): The observed value of the random variable to evaluate the CDF for
            mu (:obj:`tensor` or :obj:`float`): The mean of the Normal distribution
            sigma (:obj:`tensor` or :obj:`float`): The standard deviation of the Normal distribution

        Returns:
            JAX ndarray: The CDF
        """
        return norm.cdf(x, loc=mu, scale=sigma)

    def poisson_dist(self, rate):
        r"""
        The Poisson distribution with rate parameter :code:`rate`.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> rates = pyhf.tensorlib.astensor([5, 8])
            >>> values = pyhf.tensorlib.astensor([4, 9])
            >>> poissons = pyhf.tensorlib.poisson_dist(rates)
            >>> poissons.log_prob(values)
            DeviceArray([-1.74030218, -2.0868536 ], dtype=float64)

        Args:
            rate (:obj:`tensor` or :obj:`float`): The mean of the Poisson distribution (the expected number of events)

        Returns:
            Poisson distribution: The Poisson distribution class
        """
        return _BasicPoisson(rate)

    def normal_dist(self, mu, sigma):
        r"""
        The Normal distribution with mean :code:`mu` and standard deviation :code:`sigma`.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> means = pyhf.tensorlib.astensor([5, 8])
            >>> stds = pyhf.tensorlib.astensor([1, 0.5])
            >>> values = pyhf.tensorlib.astensor([4, 9])
            >>> normals = pyhf.tensorlib.normal_dist(means, stds)
            >>> normals.log_prob(values)
            DeviceArray([-1.41893853, -2.22579135], dtype=float64)

        Args:
            mu (:obj:`tensor` or :obj:`float`): The mean of the Normal distribution
            sigma (:obj:`tensor` or :obj:`float`): The standard deviation of the Normal distribution

        Returns:
            Normal distribution: The Normal distribution class
        """
        return _BasicNormal(mu, sigma)

    def to_numpy(self, tensor_in):
        """
        Convert the JAX tensor to a :class:`numpy.ndarray`.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
            >>> tensor
            DeviceArray([[1., 2., 3.],
                         [4., 5., 6.]], dtype=float64)
            >>> numpy_ndarray = pyhf.tensorlib.to_numpy(tensor)
            >>> numpy_ndarray
            array([[1., 2., 3.],
                   [4., 5., 6.]])
            >>> type(numpy_ndarray)
            <class 'numpy.ndarray'>

        Args:
            tensor_in (:obj:`tensor`): The input tensor object.

        Returns:
            :class:`numpy.ndarray`: The tensor converted to a NumPy ``ndarray``.
        """
        # The input's own dtype is carried over to the NumPy array.
        return np.asarray(tensor_in, dtype=tensor_in.dtype)