trinity.trainer package#

Subpackages#

Submodules#

Module contents#

class trinity.trainer.Trainer(config: Config)[源代码]#

基类:object

Consume the experience and train the model.

__init__(config: Config) None[源代码]#
classmethod get_actor(config: Config)[源代码]#

Get a Ray actor for the trainer.

async is_alive() bool[源代码]#

Check if the trainer is alive.

need_save() bool[源代码]#

Whether to save the checkpoint.

async need_sync() bool[源代码]#

Whether to sync the model weight.

async prepare() None[源代码]#

Prepare the trainer.

save_checkpoint(block_until_saved: bool = False, save_as_hf: bool = False) Dict[源代码]#
async shutdown() None[源代码]#
async sync_weight() Dict[源代码]#

Sync the model weight.

async train() str[源代码]#

Train the model.

async train_step(exps: List[Experience]) Dict[源代码]#

Train one step.

返回:

Metrics of the training step. Note: the source docstring also lists a bool ("Whether to continue training"), but the declared return type of this method is Dict.

返回类型:

Dict

property train_step_num: int#

Get the current training step number.

class trinity.trainer.TrainEngineWrapper[源代码]#

基类:ABC

A wrapper class to wrap various training engines.

abstractmethod async prepare() None[源代码]#

Do some preparation before training starts.

abstractmethod save_checkpoint(block_until_saved: bool = False, save_as_hf: bool = False) None[源代码]#

Save the checkpoint.

abstractmethod save_state_dict() None[源代码]#

Only save the model state dict for Synchronizer.

abstractmethod sync_weight() None[源代码]#

Sync the model weight.

abstractmethod async train_step(batch_exps: List[Experience]) Dict[源代码]#

Train one step.

参数:

batch_exps (List[Experience]) -- A batch of experiences to train.

返回:

Metrics of the training step.

返回类型:

Dict

abstract property train_step_num: int#

Get the current training step number.

abstractmethod upload_state_dict() None[源代码]#

Upload the state dict to Synchronizer.

trinity.trainer.get_trainer_wrapper(config: Config) TrainEngineWrapper[源代码]#

Get a trainer wrapper.