# Source code for nnsa.models.losses

"""
Module containing loss functions for varying (parts of) models.
"""

import numpy as np

from nnsa.utils.mathematics import abs_der

# Public API of this module (controls `from nnsa.models.losses import *`).
__all__ = ['regularization_loss']


def regularization_der(theta, regulizer, lamb):
    """
    Derivative of the regularization loss with respect to theta.

    Gradient counterpart of `regularization_loss`; accepts the same regularizer
    names (case-insensitive), including the 'lasso' and 'ridge' aliases.

    Args:
        theta (np.ndarray or array-like): parameter vector.
        regulizer (str): type of regularizer (case-insensitive). Choose from
            'L1' (or 'lasso'): loss = lamb * sum(abs(theta))
            'L2' (or 'ridge'): loss = lamb * 1/2*sum(theta**2)
        lamb (float): lambda parameter (scale factor for the regularization loss).

    Returns:
        der (np.ndarray): partial derivatives of the regularization loss with
            respect to each element of `theta`. For 'L2' the result has the same
            shape as `theta`.

    Raises:
        ValueError: if `regulizer` is not one of the supported names.
    """
    reg = regulizer.lower()
    if reg in ['l1', 'lasso']:
        # abs_der presumably returns the element-wise (sub)derivative of
        # abs(theta); see nnsa.utils.mathematics.
        der = lamb * abs_der(theta)
    elif reg in ['l2', 'ridge']:
        # d/dtheta [lamb/2 * sum(theta**2)] = lamb * theta.
        # np.asarray lets plain lists work too (list * float would raise).
        der = lamb * np.asarray(theta)
    else:
        raise ValueError('Invalid regulizer="{}". Choose from {}.'
                         .format(regulizer, ['L1', 'L2']))
    return der


def regularization_loss(theta, regulizer, lamb):
    """
    Return the regularization loss of the parameter vector `theta`.

    Args:
        theta (np.ndarray or array-like): parameter vector. Any shape is
            accepted; all elements contribute to the loss.
        regulizer (str): type of regularizer (case-insensitive). Choose from
            'L1' (or 'lasso'): loss = lamb * sum(abs(theta))
            'L2' (or 'ridge'): loss = lamb * 1/2*sum(theta**2)
        lamb (float): lambda parameter (scale factor for the regularization loss).

    Returns:
        reg_loss (float): regularization loss for the given parameter vector `theta`.

    Raises:
        ValueError: if `regulizer` is not one of the supported names.
    """
    reg = regulizer.lower()
    if reg in ['l1', 'lasso']:
        reg_loss = lamb * np.sum(np.abs(theta))
    elif reg in ['l2', 'ridge']:
        # Sum of squares over all elements. np.dot(theta, theta) would compute
        # a matrix product (not a scalar) for 2-D parameter arrays.
        reg_loss = lamb / 2 * np.sum(np.square(theta))
    else:
        raise ValueError('Invalid regulizer="{}". Choose from {}.'
                         .format(regulizer, ['L1', 'L2']))
    return reg_loss