Source code for trinity.algorithm.entropy_loss_fn.entropy_loss_fn

from abc import ABC, abstractmethod
from typing import Dict, Tuple

import torch

from trinity.algorithm.utils import masked_mean
from trinity.utils.registry import Registry

# Registry (named "entropy_loss_fn") for entropy loss function classes.
# Implementations in this module add themselves under a string key with the
# `@ENTROPY_LOSS_FN.register_module(<key>)` class decorator.
ENTROPY_LOSS_FN = Registry("entropy_loss_fn")


class EntropyLossFn(ABC):
    """Abstract interface for entropy loss functions.

    Concrete subclasses implement ``__call__`` to turn an entropy tensor and
    an action mask into a loss tensor plus a dictionary of logging metrics.
    """

    @abstractmethod
    def __call__(
        self,
        entropy: torch.Tensor,
        action_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        """Compute the entropy loss.

        Args:
            entropy (`torch.Tensor`): The entropy generated by the policy model.
            action_mask (`torch.Tensor`): The action mask.
            **kwargs: Extra keyword arguments; their use is left to subclasses.

        Returns:
            `torch.Tensor`: The calculated entropy loss.
            `Dict`: The metrics for logging.
        """

    @classmethod
    def default_args(cls) -> Dict:
        """Return the default constructor arguments.

        Returns:
            `Dict`: A freshly built mapping with ``entropy_coef`` set to ``0.0``.
        """
        defaults = dict(entropy_coef=0.0)
        return defaults
@ENTROPY_LOSS_FN.register_module("default")
class DefaultEntropyLossFn(EntropyLossFn):
    """Basic entropy loss: the masked mean of the entropy times a coefficient.

    Registered in ``ENTROPY_LOSS_FN`` under the key ``"default"``.
    """

    def __init__(self, entropy_coef: float):
        """Store the coefficient used to scale the entropy term.

        Args:
            entropy_coef (`float`): Multiplier applied to the masked-mean entropy.
        """
        self.entropy_coef = entropy_coef

    def __call__(
        self,
        entropy: torch.Tensor,
        action_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        """Compute the scaled entropy loss and its logging metric.

        Args:
            entropy (`torch.Tensor`): The entropy generated by the policy model.
            action_mask (`torch.Tensor`): The action mask handed to ``masked_mean``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            `torch.Tensor`: ``masked_mean(entropy, action_mask) * entropy_coef``.
            `Dict`: ``{"entropy_loss": <float>}`` holding the *unscaled* mean.
        """
        # NOTE(review): presumably averages `entropy` over the positions
        # selected by `action_mask` -- confirm against trinity.algorithm.utils.
        mean_entropy = masked_mean(entropy, action_mask)
        scaled_loss = mean_entropy * self.entropy_coef
        # The metric reports the value before scaling, detached from the
        # graph and converted to a plain Python number.
        metrics = {"entropy_loss": mean_entropy.detach().item()}
        return scaled_loss, metrics
@ENTROPY_LOSS_FN.register_module("none")
class DummyEntropyLossFn(EntropyLossFn):
    """Placeholder entropy loss that always yields zero.

    Registered in ``ENTROPY_LOSS_FN`` under the key ``"none"``.
    """

    def __init__(self, entropy_coef: float):
        """Store the coefficient; ``__call__`` does not read it.

        Args:
            entropy_coef (`float`): Kept on the instance only.
        """
        self.entropy_coef = entropy_coef

    def __call__(
        self,
        entropy: torch.Tensor,
        action_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        """Return a constant zero loss and no metrics.

        Args:
            entropy (`torch.Tensor`): Ignored.
            action_mask (`torch.Tensor`): Ignored.
            **kwargs: Ignored.

        Returns:
            `torch.Tensor`: A new scalar tensor created by ``torch.tensor(0.0)``.
            `Dict`: An empty dictionary.
        """
        # The inputs are intentionally never touched, so the result does not
        # depend on their values.
        zero_loss = torch.tensor(0.0)
        no_metrics: Dict = {}
        return zero_loss, no_metrics