trinity.trainer.trainer module

Trainer Class

class trinity.trainer.trainer.Trainer(config: Config)[source]

Bases: object

Consume experiences and train the model on them.

__init__(config: Config) → None[source]
prepare() → None[source]

Prepare the trainer.

async train() → str[source]

Train the model.

async train_step() → bool[source]

Train one step.

Returns:

Whether to continue training.

Return type:

bool

need_sync() → bool[source]

Return whether the model weights should be synced.

sync_weight() → None[source]

Sync the model weights.

save_checkpoint(block_until_saved: bool = False) → None[source]
async shutdown() → None[source]
property train_step_num: int

Get the current training step number.

is_alive() → bool[source]

Check if the trainer is alive.

classmethod get_actor(config: Config)[source]

Get a Ray actor for the trainer.

class trinity.trainer.trainer.TrainEngineWrapper[source]

Bases: ABC

A wrapper class to wrap various training engines.

abstract prepare() → None[source]

Perform any preparation needed before training starts.

abstract property train_step_num: int

Get the current training step number.

abstract train_step(batch: Experiences) → Tuple[bool, Dict][source]

Train one step.

Parameters:

batch (Experiences) – A batch of experiences to train on.

Returns:

A tuple whose first element (bool) indicates whether to continue training and whose second element (Dict) holds the metrics of the training step.

Return type:

Tuple[bool, Dict]

abstract save_checkpoint(block_until_saved: bool = False) → None[source]

Save the checkpoint.

abstract sync_weight() → None[source]

Sync the model weight.

abstract upload_state_dict() → None[source]

Upload the state dict to the Synchronizer.

abstract save_state_dict() → None[source]

Save only the model state dict, for use by the Synchronizer.

trinity.trainer.trainer.get_trainer_wrapper(config: Config) → TrainEngineWrapper[source]

Get a trainer wrapper.