"""Source code for pyhf.workspace."""

import logging
import jsonpatch
from . import exceptions
from . import utils
from .pdf import Model
from .mixins import _ChannelSummaryMixin

# Install a default handler on the root logger at import time; per the stdlib
# documentation this does nothing if the root logger already has handlers.
logging.basicConfig()
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)


class Workspace(_ChannelSummaryMixin, dict):
    """
    A JSON-serializable object that is built from an object that follows the `workspace.json` schema.
    """

    def __init__(self, spec, **config_kwargs):
        """
        Build a workspace from a specification and validate it against a schema.

        Args:
            spec (:obj:`dict`): The workspace specification. Its ``channels`` and
                ``observations`` entries are read directly; ``measurements`` is read
                with an empty-list fallback.
            schema (:obj:`str`): Keyword argument. Name of the schema to validate
                against. Defaults to ``'workspace.json'``.
            version: Keyword argument. Schema version forwarded to the validator.
                Defaults to :obj:`None`.
        """
        super(Workspace, self).__init__(spec, channels=spec['channels'])
        self.schema = config_kwargs.pop('schema', 'workspace.json')
        self.version = config_kwargs.pop('version', None)

        # run jsonschema validation of input specification against the (provided) schema
        log.info("Validating spec against schema: {0:s}".format(self.schema))
        utils.validate(self, self.schema, version=self.version)

        # names are kept in definition order so that a name's position can be
        # used to index into self['measurements']
        self.measurement_names = [
            measurement['name'] for measurement in self.get('measurements', [])
        ]

        # observation name -> observed data
        self.observations = {
            obs['name']: obs['data'] for obs in self['observations']
        }

    def __eq__(self, other):
        """Workspaces are equal when the other object is a Workspace with equal dict content."""
        if not isinstance(other, Workspace):
            return False
        return dict(self) == dict(other)

    def __ne__(self, other):
        """Negation of :meth:`__eq__` (spelled out explicitly rather than relying on the default)."""
        return not self == other

    def __repr__(self):
        """Use the plain object representation instead of dumping the whole dict content."""
        return object.__repr__(self)

    # NB: this is a wrapper function to validate the returned measurement object against the spec
    def get_measurement(self, **config_kwargs):
        """
        Get (or create) a measurement object using the following logic:

          1. if the poi name is given, create a measurement object for that poi
          2. if the measurement name is given, find the measurement for the given name
          3. if the measurement index is given, return the measurement at that index
          4. if there are measurements but none of the above have been specified, return the 0th measurement

        Raises:
            ~pyhf.exceptions.InvalidMeasurement: If the measurement was not found

        Args:
            poi_name (str): The name of the parameter of interest to create a new measurement from
            measurement_name (str): The name of the measurement to use
            measurement_index (int): The index of the measurement to use

        Returns:
            :obj:`dict`: A measurement object adhering to the schema defs.json#/definitions/measurement
        """
        m = self._get_measurement(**config_kwargs)
        utils.validate(m, 'measurement.json', self.version)
        return m

    def _get_measurement(self, **config_kwargs):
        """
        See `Workspace::get_measurement`.
        """
        poi_name = config_kwargs.get('poi_name')
        if poi_name:
            return {
                'name': 'NormalMeasurement',
                'config': {'poi': poi_name, 'parameters': []},
            }

        if self.measurement_names:
            measurement_name = config_kwargs.get('measurement_name')
            if measurement_name:
                if measurement_name not in self.measurement_names:
                    log.debug(
                        'measurements defined:\n\t{0:s}'.format(
                            '\n\t'.join(self.measurement_names)
                        )
                    )
                    raise exceptions.InvalidMeasurement(
                        'no measurement by name \'{0:s}\' was found in the workspace, pick from one of the valid ones above'.format(
                            measurement_name
                        )
                    )
                return self['measurements'][
                    self.measurement_names.index(measurement_name)
                ]

            measurement_index = config_kwargs.get('measurement_index')
            # compare against None: index 0 is falsy, but it is still an explicit
            # choice and must not fall through to the "multiple measurements" warning
            if measurement_index is not None:
                return self['measurements'][measurement_index]

            if len(self.measurement_names) > 1:
                log.warning(
                    'multiple measurements defined. Taking the first measurement.'
                )
            return self['measurements'][0]

        raise exceptions.InvalidMeasurement(
            "A measurement was not given to create the Model."
        )

    def model(self, **config_kwargs):
        """
        Create a model object with/without patches applied.

        The measurement is selected with the same keyword arguments accepted by
        :meth:`get_measurement`. All keyword arguments are also forwarded to the
        model constructor.

        Args:
            patches: A list of JSON patches to apply to the model specification

        Returns:
            ~pyhf.pdf.Model: A model object adhering to the schema model.json
        """
        measurement = self.get_measurement(**config_kwargs)
        log.debug(
            'model being created for measurement {0:s}'.format(measurement['name'])
        )

        patches = config_kwargs.get('patches', [])

        modelspec = {
            'channels': self['channels'],
            'parameters': measurement['config']['parameters'],
        }
        # patches are applied in the order given, each on the result of the previous one
        for patch in patches:
            modelspec = jsonpatch.JsonPatch(patch).apply(modelspec)

        return Model(modelspec, poiname=measurement['config']['poi'], **config_kwargs)

    def data(self, model, with_aux=True):
        """
        Return the data for the supplied model with or without auxiliary data from the model.

        The model is needed as the order of the data depends on the order of the channels in the model.

        Raises:
            KeyError: Invalid or missing channel

        Args:
            model (~pyhf.pdf.Model): A model object adhering to the schema model.json
            with_aux (bool): Whether to include auxiliary data from the model or not

        Returns:
            :obj:`list`: data
        """
        # always build a fresh list so the stored observations are never mutated
        observed_data = []
        try:
            for channel in model.config.channels:
                observed_data.extend(self.observations[channel])
        except KeyError:
            log.error(
                "Invalid channel: the workspace does not have observation data for one of the channels in the model."
            )
            raise
        if with_aux:
            observed_data += model.config.auxdata
        return observed_data

    @staticmethod
    def _as_name_list(names):
        """Normalise a prune argument (None, one name, or a collection of names) to a collection of names."""
        if names is None:
            return []
        # A bare string must be wrapped: `name not in "some_string"` would be a
        # substring test rather than an exact-name test.
        # (str, type(u'')) covers both text types on Python 2 and is just str on Python 3.
        if isinstance(names, (str, type(u''))):
            return [names]
        return names

    def _prune_and_rename(
        self,
        prune_modifiers=None,
        prune_modifier_types=None,
        prune_samples=None,
        prune_channels=None,
        prune_measurements=None,
        rename_modifiers=None,
        rename_samples=None,
        rename_channels=None,
        rename_measurements=None,
    ):
        """
        Return a new, pruned, renamed workspace specification. This will not modify the original workspace.

        Pruning removes pieces of the workspace whose name or type matches the
        user-provided lists. The pruned, renamed workspace must also be a valid
        workspace.

        A workspace is composed of many named components, such as channels and
        samples, as well as types of systematics (e.g. `histosys`). Components
        can be removed (pruned away) filtering on name or be renamed according to
        the provided :obj:`dict` mapping. Additionally, modifiers of specific
        types can be removed (pruned away).

        This function also handles specific peculiarities, such as
        renaming/removing a channel which needs to rename/remove the
        corresponding `observation`.

        Pruning is decided on the original (pre-rename) names.

        Args:
            prune_modifiers: A :obj:`str` or a :obj:`list` of modifiers to prune.
            prune_modifier_types: A :obj:`str` or a :obj:`list` of modifier types to prune.
            prune_samples: A :obj:`str` or a :obj:`list` of samples to prune.
            prune_channels: A :obj:`str` or a :obj:`list` of channels to prune.
            prune_measurements: A :obj:`str` or a :obj:`list` of measurements to prune.
            rename_modifiers: A :obj:`dict` mapping old modifier name to new modifier name.
            rename_samples: A :obj:`dict` mapping old sample name to new sample name.
            rename_channels: A :obj:`dict` mapping old channel name to new channel name.
            rename_measurements: A :obj:`dict` mapping old measurement name to new measurement name.

        Returns:
            ~pyhf.workspace.Workspace: A new workspace object with the specified components removed or renamed
        """
        # None stands in for "nothing to prune/rename" (avoids mutable default arguments)
        prune_modifiers = self._as_name_list(prune_modifiers)
        prune_modifier_types = self._as_name_list(prune_modifier_types)
        prune_samples = self._as_name_list(prune_samples)
        prune_channels = self._as_name_list(prune_channels)
        prune_measurements = self._as_name_list(prune_measurements)
        rename_modifiers = {} if rename_modifiers is None else rename_modifiers
        rename_samples = {} if rename_samples is None else rename_samples
        rename_channels = {} if rename_channels is None else rename_channels
        rename_measurements = (
            {} if rename_measurements is None else rename_measurements
        )

        newspec = {
            'channels': [
                {
                    'name': rename_channels.get(channel['name'], channel['name']),
                    'samples': [
                        {
                            'name': rename_samples.get(sample['name'], sample['name']),
                            'data': sample['data'],
                            'modifiers': [
                                dict(
                                    modifier,
                                    name=rename_modifiers.get(
                                        modifier['name'], modifier['name']
                                    ),
                                )
                                for modifier in sample['modifiers']
                                if modifier['name'] not in prune_modifiers
                                and modifier['type'] not in prune_modifier_types
                            ],
                        }
                        for sample in channel['samples']
                        if sample['name'] not in prune_samples
                    ],
                }
                for channel in self['channels']
                if channel['name'] not in prune_channels
            ],
            'measurements': [
                {
                    'name': rename_measurements.get(
                        measurement['name'], measurement['name']
                    ),
                    'config': {
                        # measurement parameters and the POI follow the modifier
                        # prune/rename rules
                        'parameters': [
                            dict(
                                parameter,
                                name=rename_modifiers.get(
                                    parameter['name'], parameter['name']
                                ),
                            )
                            for parameter in measurement['config']['parameters']
                            if parameter['name'] not in prune_modifiers
                        ],
                        'poi': rename_modifiers.get(
                            measurement['config']['poi'], measurement['config']['poi']
                        ),
                    },
                }
                for measurement in self['measurements']
                if measurement['name'] not in prune_measurements
            ],
            # observations are keyed by channel name, so they follow the channel
            # prune/rename rules
            'observations': [
                dict(
                    observation,
                    name=rename_channels.get(observation['name'], observation['name']),
                )
                for observation in self['observations']
                if observation['name'] not in prune_channels
            ],
            'version': self['version'],
        }
        return Workspace(newspec)

    def prune(
        self,
        modifiers=None,
        modifier_types=None,
        samples=None,
        channels=None,
        measurements=None,
    ):
        """
        Return a new, pruned workspace specification. This will not modify the original workspace.

        The pruned workspace must also be a valid workspace.

        Args:
            modifiers: A :obj:`str` or a :obj:`list` of modifiers to prune.
            modifier_types: A :obj:`str` or a :obj:`list` of modifier types to prune.
            samples: A :obj:`str` or a :obj:`list` of samples to prune.
            channels: A :obj:`str` or a :obj:`list` of channels to prune.
            measurements: A :obj:`str` or a :obj:`list` of measurements to prune.

        Returns:
            ~pyhf.workspace.Workspace: A new workspace object with the specified components removed
        """
        return self._prune_and_rename(
            prune_modifiers=modifiers,
            prune_modifier_types=modifier_types,
            prune_samples=samples,
            prune_channels=channels,
            prune_measurements=measurements,
        )

    def rename(self, modifiers=None, samples=None, channels=None, measurements=None):
        """
        Return a new workspace specification with certain elements renamed. This will not modify the original workspace.

        The renamed workspace must also be a valid workspace.

        Args:
            modifiers: A :obj:`dict` mapping old modifier name to new modifier name.
            samples: A :obj:`dict` mapping old sample name to new sample name.
            channels: A :obj:`dict` mapping old channel name to new channel name.
            measurements: A :obj:`dict` mapping old measurement name to new measurement name.

        Returns:
            ~pyhf.workspace.Workspace: A new workspace object with the specified components renamed
        """
        return self._prune_and_rename(
            rename_modifiers=modifiers,
            rename_samples=samples,
            rename_channels=channels,
            rename_measurements=measurements,
        )

    @classmethod
    def combine(cls, left, right):
        """
        Return a new workspace specification that is the combination of the two workspaces.

        The new workspace must also be a valid workspace. A combination of workspaces is done by combining the set of:

          - channels,
          - observations, and
          - measurements

        between the two workspaces. If the two workspaces have modifiers that
        follow the same naming convention, then correlations across the two
        workspaces may be possible. In particular, the `lumi` modifier will be
        fully-correlated.

        If the two workspaces have the same measurement (with the same POI), those measurements will get merged.

        Raises:
            ~pyhf.exceptions.InvalidWorkspaceOperation: The workspaces have common channel names, incompatible measurements, or incompatible schema versions.

        Args:
            left (~pyhf.workspace.Workspace): A workspace
            right (~pyhf.workspace.Workspace): Another workspace

        Returns:
            ~pyhf.workspace.Workspace: A new combined workspace object
        """
        common_channels = set(left.channels).intersection(right.channels)
        if common_channels:
            raise exceptions.InvalidWorkspaceOperation(
                "Workspaces cannot have any channels in common: {}".format(
                    common_channels
                )
            )

        common_measurements = set(left.measurement_names).intersection(
            right.measurement_names
        )
        # only the measurements whose POI actually differs are reported
        incompatible_poi = [
            m
            for m in common_measurements
            if left.get_measurement(measurement_name=m)['config']['poi']
            != right.get_measurement(measurement_name=m)['config']['poi']
        ]
        if incompatible_poi:
            raise exceptions.InvalidWorkspaceOperation(
                "Workspaces cannot have any measurements with incompatible POI: {}".format(
                    incompatible_poi
                )
            )

        # NOTE(review): `.version` is the value passed to the constructor (None
        # unless given), while the combined spec below copies left['version'] --
        # confirm that comparing the constructor value is what is intended here.
        if left.version != right.version:
            raise exceptions.InvalidWorkspaceOperation(
                "Workspaces of different versions cannot be combined: {} != {}".format(
                    left.version, right.version
                )
            )

        # measurements that exist on one side only are carried over unchanged
        left_measurements = [
            left.get_measurement(measurement_name=m)
            for m in set(left.measurement_names) - set(common_measurements)
        ]
        right_measurements = [
            right.get_measurement(measurement_name=m)
            for m in set(right.measurement_names) - set(common_measurements)
        ]
        # measurements present on both sides share a POI (checked above); their
        # parameter lists are concatenated, left first
        merged_measurements = [
            dict(
                name=m,
                config=dict(
                    poi=left.get_measurement(measurement_name=m)['config']['poi'],
                    parameters=(
                        left.get_measurement(measurement_name=m)['config']['parameters']
                        + right.get_measurement(measurement_name=m)['config'][
                            'parameters'
                        ]
                    ),
                ),
            )
            for m in common_measurements
        ]

        newspec = {
            'channels': left['channels'] + right['channels'],
            'measurements': (
                left_measurements + right_measurements + merged_measurements
            ),
            'observations': left['observations'] + right['observations'],
            'version': left['version'],
        }
        return Workspace(newspec)