Source code for pyhf.optimize.opt_jax

"""JAX Optimizer Backend."""

from .. import get_backend, default_backend
from ..tensor.common import _TensorViewer
from .autodiff import AutoDiffOptimizerMixin
import jax


def _final_objective(pars, data, fixed_vals, model, objective, fixed_idx, variable_idx):
    tensorlib, _ = get_backend()
    tv = _TensorViewer([fixed_idx, variable_idx])
    pars = tensorlib.astensor(pars)
    constrained_pars = tv.stitch([fixed_vals, pars])
    return objective(constrained_pars, data, model)[0]
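
# --- Illustrative sketch (not part of the module) ----------------------------
# ``_final_objective`` only receives the free parameters; ``tv.stitch`` puts
# the fixed values back at their original positions so the model sees the full
# parameter vector. A minimal stand-alone illustration of that stitching idea
# (pure JAX, independent of ``_TensorViewer``; the index choices are made up):
#
#     import jax.numpy as jnp
#
#     fixed_idx, variable_idx = [1], [0, 2]
#     fixed_vals = jnp.array([5.0])
#     free_pars = jnp.array([1.0, 2.0])
#
#     full = jnp.zeros(3)
#     full = full.at[jnp.array(fixed_idx)].set(fixed_vals)
#     full = full.at[jnp.array(variable_idx)].set(free_pars)
#     # full == [1.0, 5.0, 2.0]: fixed and free values interleaved by index
# ------------------------------------------------------------------------------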


_jitted_objective_and_grad = jax.jit(
    jax.value_and_grad(_final_objective), static_argnums=(3, 4, 5, 6)
)
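
# --- Illustrative sketch (not part of the module) ----------------------------
# ``static_argnums=(3, 4, 5, 6)`` marks the model, the objective callable, and
# the two index tuples as static arguments: they must be hashable, are not
# traced, and a change in any of them triggers a re-compilation. A minimal
# stand-alone example of the same ``jit`` + ``value_and_grad`` pattern (the toy
# function ``scale_and_sum`` is made up for this sketch):
#
#     import jax
#     import jax.numpy as jnp
#
#     def scale_and_sum(x, factors):
#         # ``factors`` is a plain Python tuple, passed as a static argument
#         return jnp.sum(x * jnp.array(factors))
#
#     jitted = jax.jit(jax.value_and_grad(scale_and_sum), static_argnums=(1,))
#     value, grad = jitted(jnp.array([1.0, 2.0]), (3.0, 4.0))
#     # value == 11.0 and grad == [3.0, 4.0] (gradient w.r.t. ``x`` only)
# ------------------------------------------------------------------------------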


class jax_optimizer(AutoDiffOptimizerMixin):
    """JAX Optimizer Backend."""

    def setup_minimize(
        self, objective, data, pdf, init_pars, par_bounds, fixed_vals=None
    ):
        """
        Prepare the minimization for the AutoDiff optimizer.

        Args:
            objective: objective function
            data: observed data
            pdf: model
            init_pars: initial parameters
            par_bounds: parameter boundaries
            fixed_vals: fixed parameter values

        """
        tensorlib, _ = get_backend()
        all_idx = default_backend.astensor(range(pdf.config.npars), dtype='int')
        all_init = default_backend.astensor(init_pars)

        fixed_vals = fixed_vals or []
        fixed_values = [x[1] for x in fixed_vals]
        fixed_idx = [x[0] for x in fixed_vals]

        variable_idx = [x for x in all_idx if x not in fixed_idx]
        variable_init = all_init[variable_idx]
        variable_bounds = [par_bounds[i] for i in variable_idx]

        tv = _TensorViewer([fixed_idx, variable_idx])

        data = tensorlib.astensor(data)
        fixed_values_tensor = tensorlib.astensor(fixed_values, dtype='float')

        def func(pars):
            # need to convert to tuples to make the args hashable
            return _jitted_objective_and_grad(
                pars,
                data,
                fixed_values_tensor,
                pdf,
                objective,
                tuple(fixed_idx),
                tuple(variable_idx),
            )

        return tv, fixed_values_tensor, func, variable_init, variable_bounds
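
# --- Illustrative sketch (not part of the module) ----------------------------
# ``setup_minimize`` hands back everything a gradient-based driver needs:
# ``func`` returns ``(objective_value, gradient)`` for the free parameters, and
# ``tv.stitch`` restores the full parameter vector afterwards. The real driver
# is ``AutoDiffOptimizerMixin``; the sketch below only shows how the returned
# pieces *could* be consumed with SciPy, and the surrounding setup (``opt``,
# ``objective``, ``data``, ``pdf``, ``init_pars``, ``par_bounds``,
# ``fixed_vals``) is assumed to exist:
#
#     from scipy.optimize import minimize
#     import pyhf
#
#     tv, fixed_values_tensor, func, variable_init, variable_bounds = (
#         opt.setup_minimize(objective, data, pdf, init_pars, par_bounds,
#                            fixed_vals)
#     )
#     result = minimize(
#         func, variable_init, jac=True, bounds=variable_bounds, method='SLSQP'
#     )
#
#     tensorlib, _ = pyhf.get_backend()
#     best_fit = tv.stitch(
#         [fixed_values_tensor, tensorlib.astensor(result.x)]
#     )
# ------------------------------------------------------------------------------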