jax_backend
- class pyhf.tensor.jax_backend.jax_backend(**kwargs)[source]
Bases: object
JAX backend for pyhf
Attributes
- name
- precision
- dtypemap
- default_do_grad
- array_subtype
The array content type for jax
- array_type
The array type for jax
Methods
- astensor(tensor_in, dtype='float')[source]
Convert to a JAX ndarray.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> tensor
DeviceArray([[1., 2., 3.], [4., 5., 6.]], dtype=float64)
>>> type(tensor)
<class '...DeviceArray'>
- Parameters
tensor_in (Number or Tensor) – Tensor object
- Returns
A multi-dimensional, fixed-size homogeneous array.
- Return type
jaxlib.xla_extension.DeviceArray
- clip(tensor_in, min_value, max_value)[source]
Clips (limits) the tensor values to be within a specified min and max.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> a = pyhf.tensorlib.astensor([-2, -1, 0, 1, 2])
>>> pyhf.tensorlib.clip(a, -1, 1)
DeviceArray([-1., -1., 0., 1., 1.], dtype=float64)
- concatenate(sequence, axis=0)[source]
Join a sequence of arrays along an existing axis.
- Parameters
sequence – sequence of tensors
axis – dimension along which to concatenate
- Returns
the concatenated tensor
- Return type
output
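No doctest is given for `concatenate` in this section. As a sketch of the same semantics, here is the equivalent operation in NumPy; the JAX backend forwards to `jax.numpy.concatenate`, which follows NumPy's behavior:

```python
import numpy as np

# concatenate joins a sequence of arrays along an existing axis.
a = np.array([[1.0, 2.0], [3.0, 4.0]])
b = np.array([[5.0, 6.0]])

# Stack along the first axis (axis=0): shapes (2, 2) + (1, 2) -> (3, 2)
rows = np.concatenate([a, b], axis=0)
print(rows.shape)  # (3, 2)
```

The same call through the tensorlib (`pyhf.tensorlib.concatenate([a, b], axis=0)`) returns a JAX `DeviceArray` with identical values.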
- conditional(predicate, true_callable, false_callable)[source]
Runs a callable conditional on the boolean value of the evaluation of a predicate
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> tensorlib = pyhf.tensorlib
>>> a = tensorlib.astensor([4])
>>> b = tensorlib.astensor([5])
>>> tensorlib.conditional((a < b)[0], lambda: a + b, lambda: a - b)
DeviceArray([9.], dtype=float64)
- Parameters
predicate (scalar) – The logical condition that determines which callable to evaluate
true_callable (callable) – The callable that is evaluated when the predicate evaluates to true
false_callable (callable) – The callable that is evaluated when the predicate evaluates to false
- Returns
The output of the callable that was evaluated
- Return type
JAX ndarray
- einsum(subscripts, *operands)[source]
Evaluates the Einstein summation convention on the operands.
Using the Einstein summation convention, many common multi-dimensional array operations can be represented in a simple fashion. This function provides a way to compute such summations. The best way to understand this function is to try the examples below, which show how many common NumPy functions can be implemented as calls to einsum.
- Parameters
subscripts – str, specifies the subscripts for summation
operands – list of array_like, these are the tensors for the operation
- Returns
the calculation based on the Einstein summation convention
- Return type
tensor
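As a sketch of the subscript notation, here are a few common operations expressed with NumPy's `einsum`, which the JAX backend's `einsum` mirrors:

```python
import numpy as np

A = np.array([[1.0, 2.0], [3.0, 4.0]])
B = np.array([[5.0, 6.0], [7.0, 8.0]])

trace = np.einsum("ii", A)                # sum over the diagonal
matmul = np.einsum("ij,jk->ik", A, B)     # matrix multiplication
transpose = np.einsum("ij->ji", A)        # transpose
print(trace)  # 5.0
```

Through the tensorlib the call is identical in form, e.g. `pyhf.tensorlib.einsum("ij,jk->ik", A, B)` on tensors created with `astensor`.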
- erf(tensor_in)[source]
The error function of complex argument.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> a = pyhf.tensorlib.astensor([-2., -1., 0., 1., 2.])
>>> pyhf.tensorlib.erf(a)
DeviceArray([-0.99532227, -0.84270079, 0., 0.84270079, 0.99532227], dtype=float64)
- Parameters
tensor_in (tensor) – The input tensor object
- Returns
The values of the error function at the given points.
- Return type
JAX ndarray
- erfinv(tensor_in)[source]
The inverse of the error function of complex argument.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> a = pyhf.tensorlib.astensor([-2., -1., 0., 1., 2.])
>>> pyhf.tensorlib.erfinv(pyhf.tensorlib.erf(a))
DeviceArray([-2., -1., 0., 1., 2.], dtype=float64)
- Parameters
tensor_in (tensor) – The input tensor object
- Returns
The values of the inverse of the error function at the given points.
- Return type
JAX ndarray
- normal(x, mu, sigma)[source]
The probability density function of the Normal distribution evaluated at x, given a mean of mu and a standard deviation of sigma.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> pyhf.tensorlib.normal(0.5, 0., 1.)
DeviceArray(0.35206533, dtype=float64, weak_type=True)
>>> values = pyhf.tensorlib.astensor([0.5, 2.0])
>>> means = pyhf.tensorlib.astensor([0., 2.3])
>>> sigmas = pyhf.tensorlib.astensor([1., 0.8])
>>> pyhf.tensorlib.normal(values, means, sigmas)
DeviceArray([0.35206533, 0.46481887], dtype=float64)
- normal_cdf(x, mu=0, sigma=1)[source]
The cumulative distribution function for the Normal distribution
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> pyhf.tensorlib.normal_cdf(0.8)
DeviceArray(0.7881446, dtype=float64)
>>> values = pyhf.tensorlib.astensor([0.8, 2.0])
>>> pyhf.tensorlib.normal_cdf(values)
DeviceArray([0.7881446, 0.97724987], dtype=float64)
- normal_dist(mu, sigma)[source]
The Normal distribution with mean mu and standard deviation sigma.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> means = pyhf.tensorlib.astensor([5, 8])
>>> stds = pyhf.tensorlib.astensor([1, 0.5])
>>> values = pyhf.tensorlib.astensor([4, 9])
>>> normals = pyhf.tensorlib.normal_dist(means, stds)
>>> normals.log_prob(values)
DeviceArray([-1.41893853, -2.22579135], dtype=float64)
- percentile(tensor_in, q, axis=None, interpolation='linear')[source]
Compute the \(q\)-th percentile of the tensor along the specified axis.
Example
>>> import pyhf
>>> import jax.numpy as jnp
>>> pyhf.set_backend("jax")
>>> a = pyhf.tensorlib.astensor([[10, 7, 4], [3, 2, 1]])
>>> pyhf.tensorlib.percentile(a, 50)
DeviceArray(3.5, dtype=float64)
>>> pyhf.tensorlib.percentile(a, 50, axis=1)
DeviceArray([7., 2.], dtype=float64)
- Parameters
tensor_in (tensor) – The tensor containing the data
q (float or tensor) – The \(q\)-th percentile to compute
axis (number or tensor) – The dimensions along which to compute
interpolation (str) – The interpolation method to use when the desired percentile lies between two data points i < j:
'linear': i + (j - i) * fraction, where fraction is the fractional part of the index surrounded by i and j.
'lower': i.
'higher': j.
'midpoint': (i + j) / 2.
'nearest': i or j, whichever is nearest.
- Returns
The value of the \(q\)-th percentile of the tensor along the specified axis.
- Return type
JAX ndarray
New in version 0.7.0.
- poisson(n, lam)[source]
The continuous approximation, using \(n! = \Gamma\left(n+1\right)\), to the probability mass function of the Poisson distribution evaluated at n given the parameter lam.
Note
Though the p.m.f of the Poisson distribution is not defined for \(\lambda = 0\), the limit as \(\lambda \to 0\) is still defined, which gives a degenerate p.m.f. of
\[\begin{split}\lim_{\lambda \to 0} \,\mathrm{Pois}(n | \lambda) = \left\{\begin{array}{ll} 1, & n = 0,\\ 0, & n > 0 \end{array}\right.\end{split}\]
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> pyhf.tensorlib.poisson(5., 6.)
DeviceArray(0.16062314, dtype=float64, weak_type=True)
>>> values = pyhf.tensorlib.astensor([5., 9.])
>>> rates = pyhf.tensorlib.astensor([6., 8.])
>>> pyhf.tensorlib.poisson(values, rates)
DeviceArray([0.16062314, 0.12407692], dtype=float64)
- Parameters
n (tensor or float) – The value at which to evaluate the approximation of the Poisson distribution p.m.f. (the observed number of events)
lam (tensor or float) – The mean of the Poisson distribution p.m.f. (the expected number of events)
- Returns
Value of the continuous approximation to Poisson(n|lam)
- Return type
JAX ndarray
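The Gamma-function form of the approximation can be checked with a small standalone sketch (a hypothetical helper for illustration, not pyhf's implementation), including the degenerate \(\lambda = 0\) limit from the note:

```python
import math

def poisson_continuous(n, lam):
    # Continuous approximation lam**n * exp(-lam) / Gamma(n + 1),
    # evaluated in log space via lgamma for numerical stability.
    if lam == 0.0:
        return 1.0 if n == 0 else 0.0  # degenerate limit as lambda -> 0
    return math.exp(n * math.log(lam) - lam - math.lgamma(n + 1))

print(round(poisson_continuous(5.0, 6.0), 8))  # ~0.16062314
```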
- poisson_dist(rate)[source]
The Poisson distribution with rate parameter rate.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> rates = pyhf.tensorlib.astensor([5, 8])
>>> values = pyhf.tensorlib.astensor([4, 9])
>>> poissons = pyhf.tensorlib.poisson_dist(rates)
>>> poissons.log_prob(values)
DeviceArray([-1.74030218, -2.0868536 ], dtype=float64)
- Parameters
rate (tensor or float) – The mean of the Poisson distribution (the expected number of events)
- Returns
The Poisson distribution class
- Return type
Poisson distribution
- ravel(tensor)[source]
Return a flattened view of the tensor, not a copy.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> pyhf.tensorlib.ravel(tensor)
DeviceArray([1., 2., 3., 4., 5., 6.], dtype=float64)
- Parameters
tensor (Tensor) – Tensor object
- Returns
A flattened array.
- Return type
jaxlib.xla_extension.DeviceArray
- simple_broadcast(*args)[source]
Broadcast a sequence of 1 dimensional arrays.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> pyhf.tensorlib.simple_broadcast(
...     pyhf.tensorlib.astensor([1]),
...     pyhf.tensorlib.astensor([2, 3, 4]),
...     pyhf.tensorlib.astensor([5, 6, 7]))
[DeviceArray([1., 1., 1.], dtype=float64), DeviceArray([2., 3., 4.], dtype=float64), DeviceArray([5., 6., 7.], dtype=float64)]
- Parameters
args (Array of Tensors) – Sequence of arrays
- Returns
The sequence broadcast together.
- Return type
list of Tensors
- tile(tensor_in, repeats)[source]
Repeat tensor data along a specific dimension
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> a = pyhf.tensorlib.astensor([[1.0], [2.0]])
>>> pyhf.tensorlib.tile(a, (1, 2))
DeviceArray([[1., 1.], [2., 2.]], dtype=float64)
- Parameters
tensor_in (tensor) – The tensor to be repeated
repeats (tensor) – The tuple of multipliers for each dimension
- Returns
The tensor with repeated axes
- Return type
JAX ndarray
- to_numpy(tensor_in)[source]
Convert the JAX tensor to a
numpy.ndarray
.Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> tensor
DeviceArray([[1., 2., 3.], [4., 5., 6.]], dtype=float64)
>>> numpy_ndarray = pyhf.tensorlib.to_numpy(tensor)
>>> numpy_ndarray
array([[1., 2., 3.], [4., 5., 6.]])
>>> type(numpy_ndarray)
<class 'numpy.ndarray'>
- Parameters
tensor_in (tensor) – The input tensor object.
- Returns
The tensor converted to a NumPy ndarray.
- Return type
numpy.ndarray
- transpose(tensor_in)[source]
Transpose the tensor.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> tensor
DeviceArray([[1., 2., 3.], [4., 5., 6.]], dtype=float64)
>>> pyhf.tensorlib.transpose(tensor)
DeviceArray([[1., 4.], [2., 5.], [3., 6.]], dtype=float64)
- Parameters
tensor_in (tensor) – The input tensor object.
- Returns
The transpose of the input tensor.
- Return type
JAX ndarray
New in version 0.7.0.