from jax.config import config
config.update('jax_enable_x64', True)
import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy import special
from jax.scipy.stats import norm
import numpy as np
import scipy.stats as osp_stats
import logging
log = logging.getLogger(__name__)
class _BasicPoisson(object):
    """
    Minimal Poisson distribution object exposing ``sample`` and ``log_prob``.

    Args:
        rate (:obj:`tensor` or :obj:`float`): The mean of the Poisson
         distribution (the expected number of events)
    """

    def __init__(self, rate):
        self.rate = rate

    def sample(self, sample_shape):
        """
        Draw random samples with SciPy and return them as a JAX array.

        Args:
            sample_shape (:obj:`tuple` or :obj:`list`): The leading shape of
             the sample; the shape of ``rate`` is appended to it.

        Returns:
            JAX ndarray: Samples of shape ``sample_shape + shape(rate)``
        """
        # np.shape also works for Python scalars and lists, which have no
        # ``.shape`` attribute; a scalar rate simply contributes ``()``.
        size = tuple(sample_shape) + np.shape(self.rate)
        # TODO: Support other dtypes
        return jnp.asarray(
            osp_stats.poisson(self.rate).rvs(size=size),
            dtype=jnp.float64,
        )

    def log_prob(self, value):
        """
        Evaluate the log of the Poisson p.m.f. approximation at ``value``.

        Args:
            value (:obj:`tensor` or :obj:`float`): The value(s) at which to
             evaluate the distribution

        Returns:
            JAX ndarray: The result of ``jax_backend.poisson_logpdf``
        """
        tensorlib = jax_backend()
        return tensorlib.poisson_logpdf(value, self.rate)
class _BasicNormal(object):
    """
    Minimal Normal distribution object exposing ``sample`` and ``log_prob``.

    Args:
        loc (:obj:`tensor` or :obj:`float`): The mean of the Normal distribution
        scale (:obj:`tensor` or :obj:`float`): The standard deviation of the
         Normal distribution
    """

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def sample(self, sample_shape):
        """
        Draw random samples with SciPy and return them as a JAX array.

        Args:
            sample_shape (:obj:`tuple` or :obj:`list`): The leading shape of
             the sample; the broadcast shape of ``loc`` and ``scale`` is
             appended to it.

        Returns:
            JAX ndarray: Samples of shape ``sample_shape + broadcast_shape``
        """
        # Use the broadcast shape of both parameters so that Python scalars
        # (which have no ``.shape``) and a ``scale`` with more dimensions than
        # ``loc`` are handled; for equal shapes this is just ``loc.shape``.
        event_shape = np.broadcast(
            np.asarray(self.loc), np.asarray(self.scale)
        ).shape
        size = tuple(sample_shape) + event_shape
        # TODO: Support other dtypes
        return jnp.asarray(
            osp_stats.norm(self.loc, self.scale).rvs(size=size),
            dtype=jnp.float64,
        )

    def log_prob(self, value):
        """
        Evaluate the log of the Normal p.d.f. at ``value``.

        Args:
            value (:obj:`tensor` or :obj:`float`): The value(s) at which to
             evaluate the distribution

        Returns:
            JAX ndarray: The result of ``jax_backend.normal_logpdf``
        """
        tensorlib = jax_backend()
        return tensorlib.normal_logpdf(value, self.loc, self.scale)
class jax_backend(object):
    """JAX backend for pyhf"""

    __slots__ = ['name', 'precision', 'dtypemap', 'default_do_grad']

    def __init__(self, **kwargs):
        """
        Configure the backend.

        Args:
            **kwargs: Optional settings. ``precision`` (``'64b'`` by default)
             selects 64-bit dtypes when equal to ``'64b'`` and 32-bit dtypes
             otherwise.
        """
        self.name = 'jax'
        self.precision = kwargs.get('precision', '64b')
        self.dtypemap = {
            'float': jnp.float64 if self.precision == '64b' else jnp.float32,
            'int': jnp.int64 if self.precision == '64b' else jnp.int32,
            'bool': jnp.bool_,
        }
        self.default_do_grad = True

    def _setup(self):
        """
        Run any global setups for the jax lib.
        """

    def _resolve_dtype(self, dtype):
        """
        Translate a dtype name into the backend dtype from ``dtypemap``.

        Args:
            dtype (:obj:`str`): One of ``'float'``, ``'int'`` or ``'bool'``

        Returns:
            The JAX dtype registered under ``dtype``

        Raises:
            KeyError: If ``dtype`` is not a key of ``dtypemap``
        """
        try:
            return self.dtypemap[dtype]
        except KeyError:
            log.error('Invalid dtype: dtype must be float, int, or bool.')
            raise

    def clip(self, tensor_in, min_value, max_value):
        """
        Clips (limits) the tensor values to be within a specified min and max.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> a = pyhf.tensorlib.astensor([-2, -1, 0, 1, 2])
            >>> pyhf.tensorlib.clip(a, -1, 1)
            DeviceArray([-1., -1.,  0.,  1.,  1.], dtype=float64)

        Args:
            tensor_in (:obj:`tensor`): The input tensor object
            min_value (:obj:`scalar` or :obj:`tensor` or :obj:`None`): The minimum value to be clipped to
            max_value (:obj:`scalar` or :obj:`tensor` or :obj:`None`): The maximum value to be clipped to

        Returns:
            JAX ndarray: A clipped `tensor`
        """
        return jnp.clip(tensor_in, min_value, max_value)

    def erf(self, tensor_in):
        """
        The error function.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> a = pyhf.tensorlib.astensor([-2., -1., 0., 1., 2.])
            >>> pyhf.tensorlib.erf(a)
            DeviceArray([-0.99532227, -0.84270079,  0.        ,  0.84270079,
                          0.99532227], dtype=float64)

        Args:
            tensor_in (:obj:`tensor`): The input tensor object

        Returns:
            JAX ndarray: The values of the error function at the given points.
        """
        return special.erf(tensor_in)

    def erfinv(self, tensor_in):
        """
        The inverse of the error function.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> a = pyhf.tensorlib.astensor([-2., -1., 0., 1., 2.])
            >>> pyhf.tensorlib.erfinv(pyhf.tensorlib.erf(a))
            DeviceArray([-2., -1.,  0.,  1.,  2.], dtype=float64)

        Args:
            tensor_in (:obj:`tensor`): The input tensor object

        Returns:
            JAX ndarray: The values of the inverse of the error function at the given points.
        """
        return special.erfinv(tensor_in)

    def tile(self, tensor_in, repeats):
        """
        Repeat tensor data along a specific dimension

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> a = pyhf.tensorlib.astensor([[1.0], [2.0]])
            >>> pyhf.tensorlib.tile(a, (1, 2))
            DeviceArray([[1., 1.],
                         [2., 2.]], dtype=float64)

        Args:
            tensor_in (:obj:`tensor`): The tensor to be repeated
            repeats (:obj:`tensor`): The tuple of multipliers for each dimension

        Returns:
            JAX ndarray: The tensor with repeated axes
        """
        return jnp.tile(tensor_in, repeats)

    def conditional(self, predicate, true_callable, false_callable):
        """
        Runs a callable conditional on the boolean value of the evaluation of a predicate

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> tensorlib = pyhf.tensorlib
            >>> a = tensorlib.astensor([4])
            >>> b = tensorlib.astensor([5])
            >>> tensorlib.conditional((a < b)[0], lambda: a + b, lambda: a - b)
            DeviceArray([9.], dtype=float64)

        Args:
            predicate (:obj:`scalar`): The logical condition that determines which callable to evaluate
            true_callable (:obj:`callable`): The callable that is evaluated when the :code:`predicate` evaluates to :code:`true`
            false_callable (:obj:`callable`): The callable that is evaluated when the :code:`predicate` evaluates to :code:`false`

        Returns:
            JAX ndarray: The output of the callable that was evaluated
        """
        # Only the selected callable is invoked; the other is never called.
        return true_callable() if predicate else false_callable()

    def tolist(self, tensor_in):
        """
        Convert a tensor to a (nested) Python list.

        Args:
            tensor_in (:obj:`tensor` or :obj:`list`): The object to convert

        Returns:
            list: The contents of ``tensor_in`` as Python objects
        """
        try:
            return np.asarray(tensor_in).tolist()
        except AttributeError:
            # Defensive fallback: hand a plain list back unchanged, re-raise
            # for anything else.
            if isinstance(tensor_in, list):
                return tensor_in
            raise

    def outer(self, tensor_in_1, tensor_in_2):
        """Return the outer product of the two input tensors."""
        return jnp.outer(tensor_in_1, tensor_in_2)

    def gather(self, tensor, indices):
        """Return the elements of ``tensor`` selected by indexing with ``indices``."""
        return tensor[indices]

    def boolean_mask(self, tensor, mask):
        """Return the elements of ``tensor`` selected by indexing with ``mask``."""
        return tensor[mask]

    def isfinite(self, tensor):
        """Return a boolean tensor marking which entries of ``tensor`` are finite."""
        return jnp.isfinite(tensor)

    def astensor(self, tensor_in, dtype='float'):
        """
        Convert to a JAX ndarray.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
            >>> tensor
            DeviceArray([[1., 2., 3.],
                         [4., 5., 6.]], dtype=float64)
            >>> type(tensor)
            <class 'jax.interpreters.xla._DeviceArray'>

        Args:
            tensor_in (Number or Tensor): Tensor object
            dtype (:obj:`str`): The key of ``dtypemap`` (``'float'``,
             ``'int'`` or ``'bool'``) selecting the dtype of the result

        Returns:
            `jax.interpreters.xla._DeviceArray`: A multi-dimensional, fixed-size homogeneous array.

        Raises:
            KeyError: If ``dtype`` is not a key of ``dtypemap``
        """
        return jnp.asarray(tensor_in, dtype=self._resolve_dtype(dtype))

    def sum(self, tensor_in, axis=None):
        """Sum the entries of ``tensor_in`` over ``axis`` (all axes if ``None``)."""
        return jnp.sum(tensor_in, axis=axis)

    def product(self, tensor_in, axis=None):
        """Multiply the entries of ``tensor_in`` over ``axis`` (all axes if ``None``)."""
        return jnp.prod(tensor_in, axis=axis)

    def abs(self, tensor):
        """Return the absolute value of ``tensor``."""
        return jnp.abs(tensor)

    def ones(self, shape, dtype='float'):
        """
        Create a tensor of the given shape filled with ones.

        Args:
            shape (:obj:`tuple` or :obj:`int`): The shape of the new tensor
            dtype (:obj:`str`): The key of ``dtypemap`` (``'float'``,
             ``'int'`` or ``'bool'``) selecting the dtype of the result, so
             that the backend's configured precision is respected

        Returns:
            JAX ndarray: A tensor of ones

        Raises:
            KeyError: If ``dtype`` is not a key of ``dtypemap``
        """
        return jnp.ones(shape, dtype=self._resolve_dtype(dtype))

    def zeros(self, shape, dtype='float'):
        """
        Create a tensor of the given shape filled with zeros.

        Args:
            shape (:obj:`tuple` or :obj:`int`): The shape of the new tensor
            dtype (:obj:`str`): The key of ``dtypemap`` (``'float'``,
             ``'int'`` or ``'bool'``) selecting the dtype of the result, so
             that the backend's configured precision is respected

        Returns:
            JAX ndarray: A tensor of zeros

        Raises:
            KeyError: If ``dtype`` is not a key of ``dtypemap``
        """
        return jnp.zeros(shape, dtype=self._resolve_dtype(dtype))

    def power(self, tensor_in_1, tensor_in_2):
        """Raise ``tensor_in_1`` to the power ``tensor_in_2``."""
        return jnp.power(tensor_in_1, tensor_in_2)

    def sqrt(self, tensor_in):
        """Return the square root of ``tensor_in``."""
        return jnp.sqrt(tensor_in)

    def divide(self, tensor_in_1, tensor_in_2):
        """Divide ``tensor_in_1`` by ``tensor_in_2``."""
        return jnp.divide(tensor_in_1, tensor_in_2)

    def log(self, tensor_in):
        """Return the natural logarithm of ``tensor_in``."""
        return jnp.log(tensor_in)

    def exp(self, tensor_in):
        """Return the exponential of ``tensor_in``."""
        return jnp.exp(tensor_in)

    def stack(self, sequence, axis=0):
        """Join a sequence of tensors along a new axis."""
        return jnp.stack(sequence, axis=axis)

    def where(self, mask, tensor_in_1, tensor_in_2):
        """Select from ``tensor_in_1`` where ``mask`` is true, else from ``tensor_in_2``."""
        return jnp.where(mask, tensor_in_1, tensor_in_2)

    def concatenate(self, sequence, axis=0):
        """
        Join a sequence of arrays along an existing axis.

        Args:
            sequence: sequence of tensors
            axis: dimension along which to concatenate

        Returns:
            output: the concatenated tensor
        """
        return jnp.concatenate(sequence, axis=axis)

    def simple_broadcast(self, *args):
        """
        Broadcast a sequence of 1 dimensional arrays.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> pyhf.tensorlib.simple_broadcast(
            ...   pyhf.tensorlib.astensor([1]),
            ...   pyhf.tensorlib.astensor([2, 3, 4]),
            ...   pyhf.tensorlib.astensor([5, 6, 7]))
            [DeviceArray([1., 1., 1.], dtype=float64), DeviceArray([2., 3., 4.], dtype=float64), DeviceArray([5., 6., 7.], dtype=float64)]

        Args:
            args (Array of Tensors): Sequence of arrays

        Returns:
            list of Tensors: The sequence broadcast together.
        """
        return jnp.broadcast_arrays(*args)

    def shape(self, tensor):
        """Return the ``shape`` attribute of ``tensor``."""
        return tensor.shape

    def reshape(self, tensor, newshape):
        """Return ``tensor`` reshaped to ``newshape``."""
        return jnp.reshape(tensor, newshape)

    def einsum(self, subscripts, *operands):
        """
        Evaluates the Einstein summation convention on the operands.

        Using the Einstein summation convention, many common multi-dimensional
        array operations can be represented in a simple fashion. This function
        provides a way to compute such summations.

        Args:
            subscripts: str, specifies the subscripts for summation
            operands: list of array_like, these are the tensors for the operation

        Returns:
            tensor: the calculation based on the Einstein summation convention
        """
        return jnp.einsum(subscripts, *operands)

    def poisson_logpdf(self, n, lam):
        r"""
        The log of the continuous approximation, using
        :math:`n! = \Gamma\left(n+1\right)`, to the probability mass function
        of the Poisson distribution evaluated at :code:`n` given the parameter
        :code:`lam`.

        Args:
            n (:obj:`tensor` or :obj:`float`): The value at which to evaluate the approximation
             (the observed number of events)
            lam (:obj:`tensor` or :obj:`float`): The mean of the Poisson distribution p.m.f.
             (the expected number of events)

        Returns:
            JAX ndarray: Value of the log of the continuous approximation to Poisson(n|lam)
        """
        n = jnp.asarray(n)
        lam = jnp.asarray(lam)
        # xlogy returns 0 when n == 0, so n == 0 with lam == 0 gives 0 rather
        # than the nan produced by 0 * log(0).
        return special.xlogy(n, lam) - lam - gammaln(n + 1.0)

    def poisson(self, n, lam):
        r"""
        The continuous approximation, using :math:`n! = \Gamma\left(n+1\right)`,
        to the probability mass function of the Poisson distribution evaluated
        at :code:`n` given the parameter :code:`lam`.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> pyhf.tensorlib.poisson(5., 6.)
            DeviceArray(0.16062314, dtype=float64)
            >>> values = pyhf.tensorlib.astensor([5., 9.])
            >>> rates = pyhf.tensorlib.astensor([6., 8.])
            >>> pyhf.tensorlib.poisson(values, rates)
            DeviceArray([0.16062314, 0.12407692], dtype=float64)

        Args:
            n (:obj:`tensor` or :obj:`float`): The value at which to evaluate the approximation to the Poisson distribution p.m.f.
                                              (the observed number of events)
            lam (:obj:`tensor` or :obj:`float`): The mean of the Poisson distribution p.m.f.
                                                (the expected number of events)

        Returns:
            JAX ndarray: Value of the continuous approximation to Poisson(n|lam)
        """
        # Exponentiate the log form so both methods share one formula.
        return jnp.exp(self.poisson_logpdf(n, lam))

    def normal_logpdf(self, x, mu, sigma):
        """
        The log of the probability density function of the Normal distribution
        evaluated at :code:`x` given parameters of mean of :code:`mu` and
        standard deviation of :code:`sigma`.

        Args:
            x (:obj:`tensor` or :obj:`float`): The value at which to evaluate the Normal distribution p.d.f.
            mu (:obj:`tensor` or :obj:`float`): The mean of the Normal distribution
            sigma (:obj:`tensor` or :obj:`float`): The standard deviation of the Normal distribution

        Returns:
            JAX ndarray: Value of log Normal(x|mu, sigma)
        """
        # this is much faster than
        # norm.logpdf(x, loc=mu, scale=sigma)
        # https://codereview.stackexchange.com/questions/69718/fastest-computation-of-n-likelihoods-on-normal-distributions
        root2 = jnp.sqrt(2)
        root2pi = jnp.sqrt(2 * jnp.pi)
        # log(1 / (sigma * sqrt(2 pi)))
        prefactor = -jnp.log(sigma * root2pi)
        # -(x - mu)^2 / (2 sigma^2)
        summand = -jnp.square(jnp.divide((x - mu), (root2 * sigma)))
        return prefactor + summand

    def normal(self, x, mu, sigma):
        r"""
        The probability density function of the Normal distribution evaluated
        at :code:`x` given parameters of mean of :code:`mu` and standard deviation
        of :code:`sigma`.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> pyhf.tensorlib.normal(0.5, 0., 1.)
            DeviceArray(0.35206533, dtype=float64)
            >>> values = pyhf.tensorlib.astensor([0.5, 2.0])
            >>> means = pyhf.tensorlib.astensor([0., 2.3])
            >>> sigmas = pyhf.tensorlib.astensor([1., 0.8])
            >>> pyhf.tensorlib.normal(values, means, sigmas)
            DeviceArray([0.35206533, 0.46481887], dtype=float64)

        Args:
            x (:obj:`tensor` or :obj:`float`): The value at which to evaluate the Normal distribution p.d.f.
            mu (:obj:`tensor` or :obj:`float`): The mean of the Normal distribution
            sigma (:obj:`tensor` or :obj:`float`): The standard deviation of the Normal distribution

        Returns:
            JAX ndarray: Value of Normal(x|mu, sigma)
        """
        return norm.pdf(x, loc=mu, scale=sigma)

    def normal_cdf(self, x, mu=0, sigma=1):
        """
        The cumulative distribution function for the Normal distribution

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> pyhf.tensorlib.normal_cdf(0.8)
            DeviceArray(0.7881446, dtype=float64)
            >>> values = pyhf.tensorlib.astensor([0.8, 2.0])
            >>> pyhf.tensorlib.normal_cdf(values)
            DeviceArray([0.7881446 , 0.97724987], dtype=float64)

        Args:
            x (:obj:`tensor` or :obj:`float`): The observed value of the random variable to evaluate the CDF for
            mu (:obj:`tensor` or :obj:`float`): The mean of the Normal distribution
            sigma (:obj:`tensor` or :obj:`float`): The standard deviation of the Normal distribution

        Returns:
            JAX ndarray: The CDF
        """
        return norm.cdf(x, loc=mu, scale=sigma)

    def poisson_dist(self, rate):
        r"""
        The Poisson distribution with rate parameter :code:`rate`.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> rates = pyhf.tensorlib.astensor([5, 8])
            >>> values = pyhf.tensorlib.astensor([4, 9])
            >>> poissons = pyhf.tensorlib.poisson_dist(rates)
            >>> poissons.log_prob(values)
            DeviceArray([-1.74030218, -2.0868536 ], dtype=float64)

        Args:
            rate (:obj:`tensor` or :obj:`float`): The mean of the Poisson distribution (the expected number of events)

        Returns:
            Poisson distribution: The Poisson distribution class
        """
        return _BasicPoisson(rate)

    def normal_dist(self, mu, sigma):
        r"""
        The Normal distribution with mean :code:`mu` and standard deviation :code:`sigma`.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> means = pyhf.tensorlib.astensor([5, 8])
            >>> stds = pyhf.tensorlib.astensor([1, 0.5])
            >>> values = pyhf.tensorlib.astensor([4, 9])
            >>> normals = pyhf.tensorlib.normal_dist(means, stds)
            >>> normals.log_prob(values)
            DeviceArray([-1.41893853, -2.22579135], dtype=float64)

        Args:
            mu (:obj:`tensor` or :obj:`float`): The mean of the Normal distribution
            sigma (:obj:`tensor` or :obj:`float`): The standard deviation of the Normal distribution

        Returns:
            Normal distribution: The Normal distribution class
        """
        return _BasicNormal(mu, sigma)