"""
This module contains custom Keras callbacks that might be used for training.
"""
import numpy as np
from tensorflow import keras
from tensorflow.keras import backend as K
class BetaUpdateCallback(keras.callbacks.Callback):
    """
    Callback to update a non trainable parameter `beta` at the beginning of each epoch,
    ramping it linearly from an initial to a final value over a predefined number of epochs.
    Useful for e.g. applying a warmup strategy, increasing the loss of a regularizer as training progresses.

    For the epoch index `epoch` received by `on_epoch_begin`, beta is set to
    ``f * final_value + (1 - f) * initial_value`` with ``f = clip((epoch - wait) / steps, 0, 1)``.
    So beta stays at `initial_value` up to epoch `wait` and reaches `final_value` at epoch
    ``wait + steps``.

    Args:
        wait (int): number of epochs to keep beta at its initial value before starting the ramp.
        initial_value (float): initial value of beta.
        final_value (float): final value of beta.
        steps (int): number of epochs in which to increase beta from initial to final value.
            If ``steps <= 0`` there is no ramp: beta switches straight to `final_value`
            once ``epoch >= wait``.
        verbose (int): verbose level (if 1, prints the updates to beta).

    Example:
        >>> beta_callback = BetaUpdateCallback(initial_value=0, final_value=1/100, steps=5)
        >>> loss = loss_term_1 + beta_callback.beta * loss_term_2
    """

    def __init__(self, wait=0, initial_value=0.0, final_value=1.0, steps=5, verbose=0):
        super().__init__()
        self.initial_value = float(initial_value)
        self.final_value = float(final_value)
        self.steps = int(steps)
        self.wait = int(wait)
        self.verbose = verbose
        # Stored as a backend variable (not a Python float) so that expressions built from
        # `beta`, as in the class example, can observe the in-place updates made by the setter.
        self._beta = K.variable(value=self.initial_value)
        # Mark the variable as non-trainable by setting its private `_trainable` flag directly.
        self._beta._trainable = False

    @property
    def beta(self):
        """The backend variable holding the current value of beta."""
        return self._beta

    @beta.setter
    def beta(self, new_value):
        # Set beta by assigning the value in place in the TF variable.
        K.set_value(self._beta, new_value)

    def _beta_for_epoch(self, epoch):
        """Return the scheduled value of beta for the epoch index `epoch`."""
        if self.steps <= 0:
            # No ramp requested: avoid dividing by zero and switch once the wait is over.
            f_final = 1.0 if epoch >= self.wait else 0.0
        else:
            f_final = float(np.clip((epoch - self.wait) / self.steps, 0, 1))
        return f_final * self.final_value + (1 - f_final) * self.initial_value

    def on_epoch_begin(self, epoch, logs=None):
        """Set beta to its scheduled value for `epoch`."""
        if self.verbose:
            print('On begin epoch {}'.format(epoch))
            # Print the numeric value, not the repr of the variable object.
            print('Beta before update = {}'.format(K.get_value(self.beta)))
        self.beta = self._beta_for_epoch(epoch)
        if self.verbose:
            print('Beta after update = {}'.format(K.get_value(self.beta)))
class LambdaBetaUpdateCallback(keras.callbacks.Callback):
    """
    Callback to update a non trainable parameter `beta` at the end of each epoch,
    by executing a user-supplied function ``fun(epoch, logs)``.

    Args:
        fun (function): a function that takes two inputs, `epoch` and `logs` (exactly as received
            by `on_epoch_end`), and returns a float (the new value for beta).
        initial_value (float): initial value for beta.
        verbose (int): verbose level (if 1, prints the updates to beta).

    Raises:
        TypeError: if `fun` is not callable.

    Example:
        >>> lambda_callback = LambdaBetaUpdateCallback(initial_value=0, fun=lambda epoch, logs: epoch)
    """

    def __init__(self, fun, initial_value=0.0, verbose=0):
        super().__init__()
        # Fail fast here rather than at the end of the first epoch.
        if not callable(fun):
            raise TypeError('`fun` must be callable, got {!r}'.format(fun))
        self.fun = fun
        self.verbose = verbose
        # Stored as a backend variable so expressions built from `beta` can observe the
        # in-place updates made by the setter.
        self._beta = K.variable(value=float(initial_value))
        # Mark the variable as non-trainable by setting its private `_trainable` flag directly.
        self._beta._trainable = False

    @property
    def beta(self):
        """The backend variable holding the current value of beta."""
        return self._beta

    @beta.setter
    def beta(self, new_value):
        # Set beta by assigning the value in place in the TF variable.
        K.set_value(self._beta, new_value)

    def on_epoch_end(self, epoch, logs=None):
        """Set beta to ``fun(epoch, logs)``."""
        if self.verbose:
            print('On end epoch {}'.format(epoch))
            print('Logs:', logs)
            # Print the numeric value, not the repr of the variable object.
            print('Beta before update = {}'.format(K.get_value(self.beta)))
        # Determine beta based on the user-supplied function.
        self.beta = self.fun(epoch, logs)
        if self.verbose:
            print('Beta after update = {}'.format(K.get_value(self.beta)))
class StopOnNanLoss(keras.callbacks.Callback):
    """
    Callback that checks the loss after each training batch and each epoch, and stops
    training if the loss is nan or infinite.

    Args:
        which (str or list): which loss to check, e.g. 'loss', or ['loss', 'val_loss'].
            Defaults to ['loss', 'val_loss']. Keys missing from the logs are ignored.
        raise_error (bool): if True, raises a ValueError at the end of training when a
            non-finite loss was found. If False, stops training and only prints a message.

    Attributes:
        stopped_epoch (int): epoch in which a non-finite loss was detected, -1 if none.
        stopped_batch (int): batch at which a non-finite loss was detected, -1 if none
            (or if it was only detected in the epoch-end logs).
        invalid_loss (tuple or None): ``(name, value)`` of the first non-finite loss found.
    """

    def __init__(self, which=None, raise_error=False):
        super().__init__()
        if which is None:
            which = ['loss', 'val_loss']
        elif isinstance(which, str):
            which = [which]
        self.which = which
        self.raise_error = raise_error
        self.stopped_epoch = -1
        self.stopped_batch = -1
        self.invalid_loss = None
        # Epoch currently running, so a batch-level detection can report its epoch.
        self._current_epoch = 0

    def _find_invalid_loss(self, logs):
        """Return ``(name, value)`` of the first monitored loss in `logs` that is nan/inf, else None."""
        logs = logs or {}
        for name in self.which:
            value = logs.get(name)
            # Losses absent from these logs (e.g. 'val_loss' in batch logs) are skipped.
            if value is not None and not np.isfinite(value):
                return name, value
        return None

    def _stop(self, found, epoch):
        """Record the first offending loss and its epoch, and ask the model to stop."""
        if self.invalid_loss is None:
            self.invalid_loss = found
        self.stopped_epoch = epoch
        self.model.stop_training = True

    def on_train_begin(self, logs=None):
        # Reset the state so the callback can be reused across several `fit` calls.
        self.stopped_epoch = -1
        self.stopped_batch = -1
        self.invalid_loss = None
        self._current_epoch = 0

    def on_epoch_begin(self, epoch, logs=None):
        self._current_epoch = epoch

    def on_train_batch_end(self, batch, logs=None):
        found = self._find_invalid_loss(logs)
        if found is not None:
            self.stopped_batch = batch
            # Also set stopped_epoch, so on_train_end reports/raises even if the
            # epoch-end logs do not show the non-finite value.
            self._stop(found, self._current_epoch)

    def on_epoch_end(self, epoch, logs=None):
        found = self._find_invalid_loss(logs)
        if found is not None:
            self._stop(found, epoch)

    def on_train_end(self, logs=None):
        if self.stopped_epoch < 0:
            return
        name, value = self.invalid_loss if self.invalid_loss is not None else (self.which, logs)
        msg = "{} epoch {}: The loss '{}' is {}. Stopped training.".format(
            self.__class__.__name__, self.stopped_epoch, name, value)
        if self.raise_error:
            raise ValueError(msg)
        print(msg)
class MyEarlyStopping(keras.callbacks.EarlyStopping):
    """
    Early stopping with an extra parameter to start monitoring only after `min_epochs` epochs.

    All positional and keyword arguments other than `min_epochs` are forwarded unchanged
    to `keras.callbacks.EarlyStopping`.

    Args:
        min_epochs (int): the early-stopping check is skipped for every epoch index
            ``epoch <= min_epochs`` (note the non-strict comparison). Defaults to -1,
            which (for non-negative epoch indices) monitors every epoch, like plain
            `EarlyStopping`.
    """

    def __init__(self, *args, min_epochs=-1, **kwargs):
        super().__init__(*args, **kwargs)
        self.min_epochs = min_epochs

    def on_epoch_end(self, epoch, logs=None):
        # Skipped epochs are never passed to EarlyStopping.on_epoch_end, so they
        # cannot trigger a stop.
        if epoch > self.min_epochs:
            super().on_epoch_end(epoch, logs)