# Source code for nnsa.feature_extraction.feature_sets.base

import warnings

import h5py
import numpy as np

from nnsa.feature_extraction.result import ResultBase
from nnsa.utils import check_directory_exists
from nnsa.utils.objects import convert_to_nnsa_class_callable
from nnsa.utils.other import enumerate_label


class FeatureSetResult(ResultBase):
    """
    High-level interface for processing features sets with multiple features, channels, segments.

    Args:
        features (np.ndarray): 3D array with feature values per channel per segment
            (segments, channels, features). A 2D array (segments, features) is also accepted and
            is interpreted as a single channel.
        feature_labels (list): list with labels of the features, corresponding to the last axis
            of `features`.
        algorithm_parameters (nnsa.Parameters): see ResultBase.
        channel_labels (list of str, optional): labels of the channels corresponding to the second
            axis in `features`. If None, default labels will be created.
            Default is None.
        data_info (str, optional): see ResultBase.
        segment_start_times (np.ndarray, optional): see ResultBase.
        segment_end_times (np.ndarray, optional): see ResultBase.
        fs (float, optional): see ResultBase.

    Raises:
        ValueError: if `features` is not 2D or 3D, or if the number of feature labels or channel
            labels does not match the shape of `features`.
    """
    def __init__(self, features, feature_labels, algorithm_parameters, channel_labels=None,
                 data_info=None, segment_start_times=None, segment_end_times=None, fs=None):
        super().__init__(algorithm_parameters=algorithm_parameters, data_info=data_info,
                         segment_start_times=segment_start_times,
                         segment_end_times=segment_end_times, fs=fs)

        # Input check. np.asanyarray accepts array-likes (e.g. nested lists) and passes ndarrays
        # (including subclasses) through unchanged.
        features = np.asanyarray(features)
        if features.ndim == 2:
            # Assume 1 channel, convert to shape (n_segments, 1, n_features).
            features = features[:, np.newaxis, :]
        elif features.ndim != 3:
            raise ValueError('Expected features with 2 or 3 dimensions, got shape {}.'
                             .format(features.shape))
        data_shape = features.shape

        if len(feature_labels) != data_shape[2]:
            raise ValueError('Length of feature_labels ({}) does not correspond to the shape of the data {}.'
                             .format(len(feature_labels), data_shape))

        if channel_labels is None:
            channel_labels = enumerate_label(data_shape[1], label='Channel')
        elif len(channel_labels) != data_shape[1]:
            raise ValueError('Length of channel_labels ({}) does not correspond to the shape of the data {}.'
                             .format(len(channel_labels), data_shape))

        # Store variables that are not already stored by the parent class (ResultBase).
        # Labels are stored as lists so that `_merge` can compare them with `!=` and
        # `_get_feature_indices` can use `list.index`, whatever sequence type was passed in.
        self.features = features
        self.feature_labels = list(feature_labels)
        self.channel_labels = list(channel_labels)

    @property
    def num_segments(self):
        """
        Return the number of segments.

        Returns:
            (int): number of segments (size of the first axis of `self.features`).
        """
        return self.features.shape[0]

    def _extract_epoch(self, mask):
        """
        Extracts the segments for which `mask` is True (inplace).

        Does not return anything.

        Notes:
            Do not merge the time arrays, this happens in self.extract_epoch().
        """
        self.features = self.features[mask, :, :]

    def _get_feature_indices(self, feature_labels):
        """
        Return the indices of the features with feature_labels in the features array.

        Args:
            feature_labels (list or str): list with feature labels or one feature label.

        Returns:
            indices (list): list with indices of the feature labels in the features array
                (positions along the last axis of `self.features`).

        Raises:
            ValueError: if a label is not present in `self.feature_labels`.
        """
        if isinstance(feature_labels, str):
            feature_labels = [feature_labels]
        return [self.feature_labels.index(label) for label in feature_labels]

    def _merge(self, other, index=None):
        """
        See ResultBase.

        Appends the segments of `other` to the segments of `self` (inplace). If `index` is given,
        the segments of `other` are placed starting at segment position `index`:
        - if `index` is smaller than the current number of segments, the segments from `index`
          onward are discarded (with a warning);
        - if `index` is larger, the gap is filled with NaN-valued segments;
        - if `index` equals the current number of segments, nothing is cut or padded.

        Raises:
            ValueError: if the feature labels or channel labels of `self` and `other` differ.
        """
        # Check if the feature labels of self and other are the same.
        if self.feature_labels != other.feature_labels:
            raise ValueError('Cannot merge results with different feature labels.')

        # Check if the channel labels of self and other are the same.
        if self.channel_labels != other.channel_labels:
            raise ValueError('Cannot merge results with different channel labels.')

        if index is not None:
            n_segments, n_channels, n_features = self.features.shape
            if index < n_segments:
                # Cut piece off.
                msg = 'Overwriting data while merging.'
                warnings.warn(msg)
                self.features = self.features[:index, :, :]
            elif index > n_segments:
                # Pad the *segment* axis (axis 0) with NaN segments up to `index`.
                padding = np.full((index - n_segments, n_channels, n_features), fill_value=np.nan)
                self.features = np.concatenate([self.features, padding], axis=0)

        self.features = np.concatenate((self.features, other.features), axis=0)

    @staticmethod
    def _read_from_csv(filepath):
        raise NotImplementedError('Deprecated')

    @staticmethod
    def _decode_labels(labels):
        """Convert labels read from hdf5 attributes to a list of str (bytes are UTF-8 decoded)."""
        # np.bytes_ is a subclass of bytes; some h5py versions return str instead.
        return [label.decode() if isinstance(label, bytes) else str(label) for label in labels]

    @staticmethod
    def _encode_labels(labels):
        """Convert labels to np.bytes_ for storage as hdf5 attributes (str is UTF-8 encoded)."""
        # np.bytes_ replaces the np.string_ alias, which was removed in NumPy 2.0. Encoding str
        # explicitly avoids the implicit ASCII encoding failing on non-ASCII labels.
        return [np.bytes_(label.encode('utf-8') if isinstance(label, str) else label)
                for label in labels]

    @staticmethod
    def _read_from_hdf5(filepath):
        """
        Read result from hdf5 file into a FeatureSetResult (or the subclass stored in the file).

        Args:
            filepath (str): see ResultBase._read_from_hdf5().

        Returns:
            result (nnsa.FeatureSetResult): instance of the class named in the hdf5 header,
                containing the result.
        """
        # Read standard hdf5 header (use the ResultBase method).
        class_name, algorithm_parameters, data_info, segment_start_times, segment_end_times, fs, time_offset =\
            ResultBase._read_hdf5_header(filepath)
        # NOTE(review): `time_offset` is read from the header but not passed on, because __init__
        # does not accept it. Confirm whether it should be restored on the result object.

        # Re-open the file and read the rest of the file.
        with h5py.File(filepath, 'r') as f:
            # Read array data.
            features = f['features'][:]

            # Read non-array data.
            attrs = f['features'].attrs
            feature_labels = FeatureSetResult._decode_labels(attrs['feature_labels'])
            if 'channel_labels' in attrs:
                channel_labels = FeatureSetResult._decode_labels(attrs['channel_labels'])
            else:
                channel_labels = None

        # Create a result object.
        result_class = convert_to_nnsa_class_callable(class_name)
        result = result_class(features=features,
                              feature_labels=feature_labels,
                              algorithm_parameters=algorithm_parameters,
                              channel_labels=channel_labels,
                              data_info=data_info,
                              segment_start_times=segment_start_times,
                              segment_end_times=segment_end_times,
                              fs=fs)

        return result

    def _write_to_csv(self, filepath):
        raise NotImplementedError('Deprecated')

    def _write_to_hdf5(self, filepath):
        """
        Write the contents of the object to an hdf5 file.

        Args:
            filepath (str): see ResultBase._write_to_hdf5().
        """
        check_directory_exists(filepath=filepath)

        # Write standard hdf5 header (use the ResultBase method).
        self._write_hdf5_header(filepath)

        # Append attributes to the hdf5 file.
        with h5py.File(filepath, 'a') as f:
            # Write array data.
            f.create_dataset('features', data=self.features)

            # Write non-array data as attributes to the 'features' dataset.
            # Labels are stored as byte strings for compatibility.
            f['features'].attrs['feature_labels'] = self._encode_labels(self.feature_labels)
            f['features'].attrs['channel_labels'] = self._encode_labels(self.channel_labels)