jax_backend

class pyhf.tensor.jax_backend.jax_backend(**kwargs)

Bases: object

JAX backend for pyhf

__init__(**kwargs)

Initialize self. See help(type(self)) for accurate signature.

Attributes

default_do_grad
dtypemap
name
precision

Methods

_setup()

Run any global setup for the JAX library.

abs(tensor)
astensor(tensor_in, dtype='float')

Convert to a JAX ndarray.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> tensor
DeviceArray([[1., 2., 3.],
             [4., 5., 6.]], dtype=float64)
>>> type(tensor)
<class 'jax.interpreters.xla._DeviceArray'>

Parameters

tensor_in (Number or Tensor) – Tensor object

Returns

A multi-dimensional, fixed-size homogeneous array.

Return type

jax.interpreters.xla._DeviceArray

boolean_mask(tensor, mask)
clip(tensor_in, min_value, max_value)

Clips (limits) the tensor values to be within a specified min and max.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> a = pyhf.tensorlib.astensor([-2, -1, 0, 1, 2])
>>> pyhf.tensorlib.clip(a, -1, 1)
DeviceArray([-1., -1.,  0.,  1.,  1.], dtype=float64)

Parameters
  • tensor_in (tensor) – The input tensor object

  • min_value (scalar or tensor or None) – The minimum value to be clipped to

  • max_value (scalar or tensor or None) – The maximum value to be clipped to

Returns

A clipped tensor

Return type

JAX ndarray

concatenate(sequence, axis=0)

Join a sequence of arrays along an existing axis.

Parameters
  • sequence – sequence of tensors

  • axis – dimension along which to concatenate

Returns

the concatenated tensor

Return type

JAX ndarray

conditional(predicate, true_callable, false_callable)

Runs one of two callables, depending on the boolean value of the evaluated predicate.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> tensorlib = pyhf.tensorlib
>>> a = tensorlib.astensor([4])
>>> b = tensorlib.astensor([5])
>>> tensorlib.conditional((a < b)[0], lambda: a + b, lambda: a - b)
DeviceArray([9.], dtype=float64)

Parameters
  • predicate (scalar) – The logical condition that determines which callable to evaluate

  • true_callable (callable) – The callable that is evaluated when the predicate evaluates to true

  • false_callable (callable) – The callable that is evaluated when the predicate evaluates to false

Returns

The output of the callable that was evaluated

Return type

JAX ndarray

divide(tensor_in_1, tensor_in_2)
einsum(subscripts, *operands)

Evaluates the Einstein summation convention on the operands.

Using the Einstein summation convention, many common multi-dimensional array operations, such as matrix products, transposes, traces, and outer products, can be expressed concisely. The subscripts string labels the axes of each operand, and any label that does not appear in the output is summed over.

Parameters
  • subscripts – str, specifies the subscripts for summation

  • operands – list of array_like, these are the tensors for the operation

Returns

the calculation based on the Einstein summation convention

Return type

tensor

erf(tensor_in)

The error function of complex argument.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> a = pyhf.tensorlib.astensor([-2., -1., 0., 1., 2.])
>>> pyhf.tensorlib.erf(a)
DeviceArray([-0.99532227, -0.84270079,  0.        ,  0.84270079,
              0.99532227], dtype=float64)

Parameters

tensor_in (tensor) – The input tensor object

Returns

The values of the error function at the given points.

Return type

JAX ndarray

erfinv(tensor_in)

The inverse of the error function of complex argument.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> a = pyhf.tensorlib.astensor([-2., -1., 0., 1., 2.])
>>> pyhf.tensorlib.erfinv(pyhf.tensorlib.erf(a))
DeviceArray([-2., -1.,  0.,  1.,  2.], dtype=float64)

Parameters

tensor_in (tensor) – The input tensor object

Returns

The values of the inverse of the error function at the given points.

Return type

JAX ndarray

exp(tensor_in)
gather(tensor, indices)
isfinite(tensor)
log(tensor_in)
normal(x, mu, sigma)

The probability density function of the Normal distribution with mean mu and standard deviation sigma, evaluated at x.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> pyhf.tensorlib.normal(0.5, 0., 1.)
DeviceArray(0.35206533, dtype=float64)
>>> values = pyhf.tensorlib.astensor([0.5, 2.0])
>>> means = pyhf.tensorlib.astensor([0., 2.3])
>>> sigmas = pyhf.tensorlib.astensor([1., 0.8])
>>> pyhf.tensorlib.normal(values, means, sigmas)
DeviceArray([0.35206533, 0.46481887], dtype=float64)

Parameters
  • x (tensor or float) – The value at which to evaluate the Normal distribution p.d.f.

  • mu (tensor or float) – The mean of the Normal distribution

  • sigma (tensor or float) – The standard deviation of the Normal distribution

Returns

Value of Normal(x|mu, sigma)

Return type

JAX ndarray

normal_cdf(x, mu=0, sigma=1)

The cumulative distribution function of the Normal distribution.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> pyhf.tensorlib.normal_cdf(0.8)
DeviceArray(0.7881446, dtype=float64)
>>> values = pyhf.tensorlib.astensor([0.8, 2.0])
>>> pyhf.tensorlib.normal_cdf(values)
DeviceArray([0.7881446 , 0.97724987], dtype=float64)

Parameters
  • x (tensor or float) – The observed value of the random variable to evaluate the CDF for

  • mu (tensor or float) – The mean of the Normal distribution

  • sigma (tensor or float) – The standard deviation of the Normal distribution

Returns

The CDF

Return type

JAX ndarray

normal_dist(mu, sigma)

The Normal distribution with mean mu and standard deviation sigma.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> means = pyhf.tensorlib.astensor([5, 8])
>>> stds = pyhf.tensorlib.astensor([1, 0.5])
>>> values = pyhf.tensorlib.astensor([4, 9])
>>> normals = pyhf.tensorlib.normal_dist(means, stds)
>>> normals.log_prob(values)
DeviceArray([-1.41893853, -2.22579135], dtype=float64)

Parameters
  • mu (tensor or float) – The mean of the Normal distribution

  • sigma (tensor or float) – The standard deviation of the Normal distribution

Returns

The Normal distribution class

Return type

Normal distribution

normal_logpdf(x, mu, sigma)
ones(shape)
outer(tensor_in_1, tensor_in_2)
poisson(n, lam)

The continuous approximation, using \(n! = \Gamma\left(n+1\right)\), to the probability mass function of the Poisson distribution evaluated at n given the parameter lam.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> pyhf.tensorlib.poisson(5., 6.)
DeviceArray(0.16062314, dtype=float64)
>>> values = pyhf.tensorlib.astensor([5., 9.])
>>> rates = pyhf.tensorlib.astensor([6., 8.])
>>> pyhf.tensorlib.poisson(values, rates)
DeviceArray([0.16062314, 0.12407692], dtype=float64)

Parameters
  • n (tensor or float) – The value at which to evaluate the approximation to the Poisson distribution p.m.f. (the observed number of events)

  • lam (tensor or float) – The mean of the Poisson distribution p.m.f. (the expected number of events)

Returns

Value of the continuous approximation to Poisson(n|lam)

Return type

JAX ndarray

poisson_dist(rate)

The Poisson distribution with rate parameter rate.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> rates = pyhf.tensorlib.astensor([5, 8])
>>> values = pyhf.tensorlib.astensor([4, 9])
>>> poissons = pyhf.tensorlib.poisson_dist(rates)
>>> poissons.log_prob(values)
DeviceArray([-1.74030218, -2.0868536 ], dtype=float64)

Parameters

rate (tensor or float) – The mean of the Poisson distribution (the expected number of events)

Returns

The Poisson distribution class

Return type

Poisson distribution

poisson_logpdf(n, lam)
power(tensor_in_1, tensor_in_2)
product(tensor_in, axis=None)
ravel(tensor)

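Return a flattened, one-dimensional version of the tensor. Unlike NumPy, JAX does not return views, so the result is a copy, although such copies may be optimized away under JIT.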

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> pyhf.tensorlib.ravel(tensor)
DeviceArray([1., 2., 3., 4., 5., 6.], dtype=float64)

Parameters

tensor (Tensor) – Tensor object

Returns

A flattened array.

Return type

jax.interpreters.xla._DeviceArray

reshape(tensor, newshape)
shape(tensor)
simple_broadcast(*args)

Broadcast a sequence of one-dimensional arrays.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> pyhf.tensorlib.simple_broadcast(
...   pyhf.tensorlib.astensor([1]),
...   pyhf.tensorlib.astensor([2, 3, 4]),
...   pyhf.tensorlib.astensor([5, 6, 7]))
[DeviceArray([1., 1., 1.], dtype=float64), DeviceArray([2., 3., 4.], dtype=float64), DeviceArray([5., 6., 7.], dtype=float64)]

Parameters

args (Array of Tensors) – Sequence of arrays

Returns

The sequence broadcast together.

Return type

list of Tensors

sqrt(tensor_in)
stack(sequence, axis=0)
sum(tensor_in, axis=None)
tile(tensor_in, repeats)

Repeat tensor data along each dimension.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> a = pyhf.tensorlib.astensor([[1.0], [2.0]])
>>> pyhf.tensorlib.tile(a, (1, 2))
DeviceArray([[1., 1.],
             [2., 2.]], dtype=float64)

Parameters
  • tensor_in (tensor) – The tensor to be repeated

  • repeats (tuple or tensor) – The multipliers for each dimension

Returns

The tensor with repeated axes

Return type

JAX ndarray

to_numpy(tensor_in)

Convert the JAX tensor to a numpy.ndarray.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> tensor
DeviceArray([[1., 2., 3.],
             [4., 5., 6.]], dtype=float64)
>>> numpy_ndarray = pyhf.tensorlib.to_numpy(tensor)
>>> numpy_ndarray
array([[1., 2., 3.],
       [4., 5., 6.]])
>>> type(numpy_ndarray)
<class 'numpy.ndarray'>

Parameters

tensor_in (tensor) – The input tensor object.

Returns

The tensor converted to a NumPy ndarray.

Return type

numpy.ndarray

tolist(tensor_in)
where(mask, tensor_in_1, tensor_in_2)
zeros(shape)