import warnings
import h5py
import numpy as np
from nnsa.feature_extraction.result import ResultBase
from nnsa.utils import check_directory_exists
from nnsa.utils.objects import convert_to_nnsa_class_callable
from nnsa.utils.other import enumerate_label
class FeatureSetResult(ResultBase):
    """
    High-level interface for processing feature sets with multiple features, channels and segments.

    Args:
        features (np.ndarray): 3D array with feature values per channel per segment, with shape
            (segments, channels, features). A 2D array is interpreted as a single channel with shape
            (segments, features) and is expanded to (segments, 1, features).
        feature_labels (list): list with labels of the features, corresponding to the last axis of `features`.
        algorithm_parameters (nnsa.Parameters): see ResultBase.
        channel_labels (list of str, optional): labels of the channels corresponding to the second axis in
            `features`. If None, default labels will be created.
            Default is None.
        data_info (str, optional): see ResultBase.
        segment_start_times (np.ndarray, optional): see ResultBase.
        segment_end_times (np.ndarray, optional): see ResultBase.
        fs (float, optional): see ResultBase.

    Raises:
        ValueError: if the length of `feature_labels` or `channel_labels` does not correspond to the
            shape of `features`.
    """
    def __init__(self, features, feature_labels,
                 algorithm_parameters, channel_labels=None, data_info=None,
                 segment_start_times=None, segment_end_times=None, fs=None):
        super().__init__(algorithm_parameters=algorithm_parameters, data_info=data_info,
                         segment_start_times=segment_start_times, segment_end_times=segment_end_times, fs=fs)

        # Input check.
        data_shape = features.shape
        if len(data_shape) == 2:
            # Assume 1 channel, convert to shape (n_segments, 1, n_features).
            features = features[:, np.newaxis, :]
            data_shape = features.shape

        if len(feature_labels) != data_shape[2]:
            raise ValueError('Length of feature_labels ({}) does not correspond to the shape of the data {}.'
                             .format(len(feature_labels), data_shape))

        if channel_labels is None:
            channel_labels = enumerate_label(data_shape[1], label='Channel')
        elif len(channel_labels) != data_shape[1]:
            raise ValueError('Length of channel_labels ({}) does not correspond to the shape of the data {}.'
                             .format(len(channel_labels), data_shape))

        # Store variables that are not already stored by the parent class (ResultBase).
        self.features = features
        self.feature_labels = feature_labels
        self.channel_labels = channel_labels

    @property
    def num_segments(self):
        """
        Return the number of segments.

        Returns:
            (int): number of segments (size of the first axis of `self.features`).
        """
        return self.features.shape[0]

    def _extract_epoch(self, mask):
        """
        Extracts the segments for which `mask` is True (inplace). Does not return anything.

        Args:
            mask: indexer applied to the first (segments) axis of `self.features`
                (presumably a boolean array supplied by self.extract_epoch() - TODO confirm).

        Notes:
            Do not merge the time arrays, this happens in self.extract_epoch().
        """
        self.features = self.features[mask, :, :]

    def _get_feature_indices(self, feature_labels):
        """
        Return the indices of the features with feature_labels in the features array.

        Args:
            feature_labels (list or str): list with feature labels or one feature label.

        Returns:
            indices (list): list with indices of the feature labels in the features array.

        Raises:
            ValueError: if a requested label is not in `self.feature_labels`.
        """
        if isinstance(feature_labels, str):
            feature_labels = [feature_labels]

        return [self.feature_labels.index(label) for label in feature_labels]

    def _merge(self, other, index=None):
        """
        Append the features of `other` to the features of self along the segments axis (inplace).

        See ResultBase.

        Args:
            other (FeatureSetResult): result to append; must have the same feature and channel labels.
            index (int, optional): segment index in self at which the segments of `other` should start.
                If smaller than the current number of segments, the segments of self from `index` onwards
                are discarded (a warning is issued). If larger, the gap is filled with nan segments.
                If None, `other` is appended directly after the last segment of self.
                Default is None.

        Raises:
            ValueError: if the feature labels or channel labels of self and other differ.
        """
        # Check if the feature labels of self and other are the same.
        if self.feature_labels != other.feature_labels:
            raise ValueError('Cannot merge results with different feature labels.')

        # Check if the channel labels of self and other are the same.
        if self.channel_labels != other.channel_labels:
            raise ValueError('Cannot merge results with different channel labels.')

        if index is not None:
            n_segments, n_channels, n_features = self.features.shape
            if index < n_segments:
                # Cut piece off.
                msg = 'Overwriting data while merging.'
                warnings.warn(msg)
                self.features = self.features[:index, :, :]
            elif index > n_segments:
                # Add nan segments to fill the gap. The padding differs from the data only in its number
                # of segments, so it must be concatenated along the segments axis (axis 0).
                padding = np.full((index - n_segments, n_channels, n_features), fill_value=np.nan)
                self.features = np.concatenate((self.features, padding), axis=0)
            # If index == n_segments, other connects seamlessly: nothing to cut or pad.

        self.features = np.concatenate((self.features, other.features), axis=0)

    @staticmethod
    def _read_from_csv(filepath):
        """
        Reading from csv is not supported for this result type.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError('Deprecated')

    @staticmethod
    def _read_from_hdf5(filepath):
        """
        Read result from hdf5 file into a FeatureSetResult (or the result class named in the file header).

        Args:
            filepath (str): see ResultBase._read_from_hdf5().

        Returns:
            result (nnsa.FeatureSetResult): instance of FeatureSetResult containing the result.
        """
        # Read standard hdf5 header (use the ResultBase method).
        # NOTE(review): the time offset is read from the header but not passed on to the result - confirm
        # whether it should be restored on the created object.
        class_name, algorithm_parameters, data_info, segment_start_times, segment_end_times, fs, _time_offset = \
            ResultBase._read_hdf5_header(filepath)

        # Re-open the file and read the rest of the file.
        with h5py.File(filepath, 'r') as f:
            features_dataset = f['features']

            # Read array data.
            features = features_dataset[:]

            # Read non-array data (labels are stored as byte strings, see _write_to_hdf5()).
            attributes = features_dataset.attrs
            feature_labels = [label.decode() for label in attributes['feature_labels']]
            if 'channel_labels' in attributes:
                channel_labels = [label.decode() for label in attributes['channel_labels']]
            else:
                # Let the constructor create default channel labels.
                channel_labels = None

        # Create a result object of the class that is named in the file header.
        result_class = convert_to_nnsa_class_callable(class_name)
        result = result_class(features=features, feature_labels=feature_labels,
                              algorithm_parameters=algorithm_parameters,
                              channel_labels=channel_labels,
                              data_info=data_info,
                              segment_start_times=segment_start_times,
                              segment_end_times=segment_end_times,
                              fs=fs)

        return result

    def _write_to_csv(self, filepath):
        """
        Writing to csv is not supported for this result type.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError('Deprecated')

    def _write_to_hdf5(self, filepath):
        """
        Write the contents of the object to an hdf5 file.

        Args:
            filepath (str): see ResultBase._write_to_hdf5().
        """
        check_directory_exists(filepath=filepath)

        # Write standard hdf5 header (use the ResultBase method).
        self._write_hdf5_header(filepath)

        # Append attributes to the hdf5 file.
        with h5py.File(filepath, 'a') as f:
            # Write array data.
            f.create_dataset('features', data=self.features)

            # Write non-array data as attributes to the 'features' dataset.
            # Convert strings to byte strings as recommended for compatibility. np.bytes_ is used instead of
            # its former alias np.string_, which was removed in NumPy 2.0.
            f['features'].attrs['feature_labels'] = [np.bytes_(label) for label in self.feature_labels]
            f['features'].attrs['channel_labels'] = [np.bytes_(label) for label in self.channel_labels]