# Source code for nnsa.keras.callbacks

"""
This module contains custom Keras callbacks that might be used for training.
"""
import numpy as np
from tensorflow import keras
from tensorflow.keras import backend as K


class BetaUpdateCallback(keras.callbacks.Callback):
    """
    Callback to update a non trainable parameter `beta` at the beginning of each epoch, increasing it
    linearly from an initial to a final value in a predefined number of steps (epochs).

    Useful for e.g. applying a warmup strategy, increasing the loss of a regularizer as training
    progresses.

    Args:
        wait (int): number of epochs during which beta is kept at `initial_value` before it starts
            increasing.
        initial_value (float): initial value of beta.
        final_value (float): final value of beta.
        steps (int): number of epochs in which to increase beta from initial to final value.
            If `steps` <= 0, beta switches to `final_value` as soon as the `wait` epochs are over.
        verbose (int): verbose level (if 1, prints the updates to beta).

    Example:
        >>> beta_callback = BetaUpdateCallback(initial_value=0, final_value=1/100, steps=5)
        >>> loss = loss_term_1 + beta_callback.beta * loss_term_2
    """
    def __init__(self, wait=0, initial_value=0.0, final_value=1.0, steps=5, verbose=0):
        super().__init__()
        self.initial_value = float(initial_value)
        self.final_value = float(final_value)
        self.steps = int(steps)
        self.wait = int(wait)
        self.verbose = verbose

        self._beta = K.variable(value=self.initial_value)
        # K.variable has no `trainable` argument, so mark the variable non-trainable afterwards.
        self._beta._trainable = False

    @property
    def beta(self):
        """The variable holding the current value of beta (use this object in the loss)."""
        return self._beta

    @beta.setter
    def beta(self, new_value):
        # Set beta by assigning the value to the existing TF variable (the object is not replaced).
        K.set_value(self._beta, new_value)

    def _fraction_final(self, epoch):
        """Return the interpolation weight in [0, 1] of `final_value` for the given epoch."""
        elapsed = epoch - self.wait
        if self.steps <= 0:
            # No ramp (and avoid division by zero): jump to the final value after the wait period.
            return 1.0 if elapsed >= 0 else 0.0
        return float(np.clip(elapsed / self.steps, 0, 1))

    def on_epoch_begin(self, epoch, logs=None):
        if self.verbose:
            print('On begin epoch {}'.format(epoch))
            print('Beta before update = {}'.format(K.get_value(self.beta)))

        # Determine beta based on epoch number (linear interpolation between initial and final).
        f_final = self._fraction_final(epoch)
        new_value = f_final * self.final_value + (1 - f_final) * self.initial_value
        self.beta = new_value

        if self.verbose:
            print('Beta after update = {}'.format(K.get_value(self.beta)))
class LambdaBetaUpdateCallback(keras.callbacks.Callback):
    """
    Callback to update a non trainable parameter `beta` at the end of each epoch, by executing a
    function: fun(epoch, logs).

    Args:
        fun (function): a function that takes two inputs, epoch and logs (as received by
            on_epoch_end), and returns a float (the new value for beta).
        initial_value (float): initial value for beta.
        verbose (int): verbose level (if 1, prints the updates to beta).

    Raises:
        TypeError: if `fun` is not callable.

    Example:
        >>> lambda_callback = LambdaBetaUpdateCallback(initial_value=0, fun=lambda epoch, logs: epoch)
    """
    def __init__(self, fun, initial_value=0.0, verbose=0):
        super().__init__()
        # Fail fast instead of only at the end of the first epoch.
        if not callable(fun):
            raise TypeError('`fun` must be callable, got {!r}.'.format(fun))
        self.fun = fun
        self.verbose = verbose

        self._beta = K.variable(value=float(initial_value))
        # K.variable has no `trainable` argument, so mark the variable non-trainable afterwards.
        self._beta._trainable = False

    @property
    def beta(self):
        """The variable holding the current value of beta (use this object in the loss)."""
        return self._beta

    @beta.setter
    def beta(self, new_value):
        # Set beta by assigning the value to the existing TF variable (the object is not replaced).
        K.set_value(self._beta, new_value)

    def on_epoch_end(self, epoch, logs=None):
        if self.verbose:
            print('On end epoch {}'.format(epoch))
            print('Logs:', logs)
            print('Beta before update = {}'.format(K.get_value(self.beta)))

        # Determine beta based on the user-supplied function.
        new_value = self.fun(epoch, logs)
        self.beta = new_value

        if self.verbose:
            print('Beta after update = {}'.format(K.get_value(self.beta)))
class StopOnNanLoss(keras.callbacks.Callback):
    """
    Callback that checks the loss after each training batch and after each epoch, and stops
    training if the loss is nan or inf.

    Args:
        which (str or list): which loss(es) to check for nan/inf, e.g. 'loss', or
            ['loss', 'val_loss']. Defaults to ['loss', 'val_loss']. Keys that are not present in
            the logs (e.g. 'val_loss' in batch logs) are ignored.
        raise_error (bool): if True, raises a ValueError at the end of training when a nan/inf
            loss was encountered. If False, stops training without raising error and prints a
            message instead.

    Attributes:
        stopped_epoch (int): epoch at which a nan/inf loss was first detected (-1 if none).
        stopped_batch (int): batch at which a nan/inf loss was first detected (-1 if the
            detection happened at epoch level or not at all).
    """
    def __init__(self, which=None, raise_error=False):
        super().__init__()
        if which is None:
            which = ['loss', 'val_loss']
        elif isinstance(which, str):
            which = [which]
        self.which = list(which)
        self.raise_error = raise_error
        self._reset()

    def _reset(self):
        """Clear all detection state."""
        self.stopped_epoch = -1
        self.stopped_batch = -1
        # Epoch currently running, so batch-level detections can report their epoch.
        self._current_epoch = -1
        # Monitored log entries that were non-finite at the first detection ({key: value}).
        self._bad_losses = None

    def _find_nonfinite(self, logs):
        """Return a dict with the monitored entries of `logs` that are nan or inf."""
        logs = logs or {}
        bad = {}
        for key in self.which:
            value = logs.get(key)
            if value is None:
                continue
            if not np.all(np.isfinite(value)):
                bad[key] = value
        return bad

    def _stop(self, bad, epoch, batch=-1):
        """Record the first detection and ask Keras to stop training."""
        if self._bad_losses is None:
            self._bad_losses = bad
            self.stopped_epoch = epoch
            self.stopped_batch = batch
        self.model.stop_training = True

    def on_train_begin(self, logs=None):
        self._reset()

    def on_epoch_begin(self, epoch, logs=None):
        self._current_epoch = epoch

    def on_train_batch_end(self, batch, logs=None):
        bad = self._find_nonfinite(logs)
        if bad:
            self._stop(bad, epoch=self._current_epoch, batch=batch)

    def on_epoch_end(self, epoch, logs=None):
        bad = self._find_nonfinite(logs)
        if bad:
            self._stop(bad, epoch=epoch)

    def on_train_end(self, logs=None):
        # Report based on whether a detection happened (batch- or epoch-level), not on the epoch
        # index alone, so a batch-triggered stop is never silently ignored.
        if self._bad_losses is None:
            return
        where = 'epoch {}'.format(self.stopped_epoch)
        if self.stopped_batch >= 0:
            where += ', batch {}'.format(self.stopped_batch)
        msg = "{} {}: The loss is {}. Stopped training.".format(
            self.__class__.__name__, where, self._bad_losses)
        if self.raise_error:
            raise ValueError(msg)
        print(msg)
class MyEarlyStopping(keras.callbacks.EarlyStopping):
    """
    EarlyStopping variant that ignores the first epochs of training.

    All positional and keyword arguments except `min_epochs` are forwarded unchanged to
    `keras.callbacks.EarlyStopping`.

    Args:
        min_epochs (int): epochs whose (0-based) index is less than or equal to `min_epochs` are
            skipped: the parent's `on_epoch_end` is not called for them, so they have no effect
            on early stopping. The default of -1 skips nothing (every epoch is monitored).
    """
    def __init__(self, *args, min_epochs=-1, **kwargs):
        super().__init__(*args, **kwargs)
        self.min_epochs = min_epochs

    def on_epoch_end(self, epoch, logs=None):
        # Warm-up phase: leave the early-stopping state untouched.
        if epoch <= self.min_epochs:
            return
        super().on_epoch_end(epoch, logs)