from abc import ABC, abstractmethod
from typing import Dict, Tuple
import torch
from trinity.algorithm.utils import masked_mean
from trinity.utils.registry import Registry
# Registry of entropy loss implementations; classes in this module add
# themselves under a string key via `@ENTROPY_LOSS_FN.register_module(...)`.
ENTROPY_LOSS_FN = Registry("entropy_loss_fn")
class EntropyLossFn(ABC):
    """
    Abstract interface for entropy loss functions.

    Concrete subclasses implement `__call__`, which maps an entropy tensor and
    an action mask to a loss tensor plus a dictionary of metrics for logging.
    """

    @abstractmethod
    def __call__(
        self,
        entropy: torch.Tensor,
        action_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Compute the entropy loss.

        Args:
            entropy (`torch.Tensor`): The entropy generated by the policy model.
            action_mask (`torch.Tensor`): The action mask.
            **kwargs: Additional keyword arguments accepted for interface
                compatibility; their use is up to the implementation.

        Returns:
            `torch.Tensor`: The calculated entropy loss.
            `Dict`: The metrics for logging.
        """

    @classmethod
    def default_args(cls) -> Dict:
        """
        Returns:
            `Dict`: The default arguments for the entropy loss function.
        """
        return {"entropy_coef": 0.0}
@ENTROPY_LOSS_FN.register_module("default")
class DefaultEntropyLossFn(EntropyLossFn):
    """
    Basic entropy loss function.

    Reduces the entropy with `masked_mean` over the action mask and scales the
    result by a fixed coefficient.
    """

    def __init__(self, entropy_coef: float):
        """
        Args:
            entropy_coef (`float`): Multiplier applied to the reduced entropy
                to obtain the returned loss.
        """
        self.entropy_coef = entropy_coef

    def __call__(
        self,
        entropy: torch.Tensor,
        action_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Compute the coefficient-scaled entropy loss.

        Args:
            entropy (`torch.Tensor`): The entropy generated by the policy model.
            action_mask (`torch.Tensor`): The action mask.
            **kwargs: Ignored by this implementation.

        Returns:
            `torch.Tensor`: `masked_mean(entropy, action_mask) * entropy_coef`.
            `Dict`: `{"entropy_loss": <float>}` holding the unscaled reduced
            entropy.
        """
        # NOTE(review): `masked_mean` is a project helper; presumably it
        # averages `entropy` over the positions selected by `action_mask` —
        # confirm against `trinity.algorithm.utils`.
        entropy_loss = masked_mean(entropy, action_mask)
        # The logged metric is the value *before* scaling, detached from the
        # graph and converted to a Python float.
        metrics = {"entropy_loss": entropy_loss.detach().item()}
        return entropy_loss * self.entropy_coef, metrics
@ENTROPY_LOSS_FN.register_module("none")
class DummyEntropyLossFn(EntropyLossFn):
    """
    Dummy entropy loss function.

    Always yields a zero loss and no metrics, regardless of its inputs.
    """

    def __init__(self, entropy_coef: float):
        """
        Args:
            entropy_coef (`float`): Stored on the instance but not used by
                `__call__` in this class.
        """
        self.entropy_coef = entropy_coef

    def __call__(
        self,
        entropy: torch.Tensor,
        action_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Return a constant zero loss.

        Args:
            entropy (`torch.Tensor`): Unused.
            action_mask (`torch.Tensor`): Unused.
            **kwargs: Unused.

        Returns:
            `torch.Tensor`: A newly created scalar tensor holding `0.0`.
            `Dict`: An empty metrics dictionary.
        """
        # Built from a literal rather than from `entropy`, so the inputs are
        # never touched on this path.
        return torch.tensor(0.0), {}