# Source code for trinity.explorer.proxy.recorder
from typing import List
from sqlalchemy.orm import sessionmaker
from trinity.buffer.schema import init_engine
from trinity.buffer.utils import retry_session
from trinity.common.experience import Experience
from trinity.utils.log import get_logger
# TODO: implement an async version in the future
class HistoryRecorder:
    """Record chat history into the database.

    The recorder owns an engine and a table model class obtained from
    ``init_engine`` (schema type ``"experience"``) plus a session factory
    bound to that engine. All database access goes through
    ``retry_session``.
    """

    def __init__(self, db_url: str, table_name: str):
        """Set up the engine, the table model class and the session factory.

        Args:
            db_url (str): Database URL forwarded to ``init_engine``.
            table_name (str): Table name forwarded to ``init_engine``.
        """
        self.logger = get_logger()
        self.engine, self.table_model_cls = init_engine(
            db_url=db_url,
            table_name=table_name,
            schema_type="experience",
        )
        self.logger.info(f"Init SQL storage at {db_url}")
        # NOTE: despite its name this attribute is a session *factory*
        # (``sessionmaker``), not a live session; it is handed to
        # ``retry_session`` on every call.
        self.session = sessionmaker(bind=self.engine)

    def record_history(self, experiences: List[Experience]) -> None:
        """Save experiences to the database.

        Args:
            experiences (List[Experience]): Experiences to persist. Each one
                is converted with ``table_model_cls.from_experience``.
                An empty list is a no-op and does not open a session.
        """
        if not experiences:
            # Nothing to persist; avoid opening (and committing) an empty session.
            return
        with retry_session(self.session) as db:
            rows = [self.table_model_cls.from_experience(exp) for exp in experiences]
            db.add_all(rows)

    def update_reward(
        self, reward: float, msg_ids: list, run_id: int, task_id: str
    ) -> List[Experience]:
        """Update reward for given response IDs and return the updated experiences.

        Args:
            reward (float): The reward value to be updated.
            msg_ids (list): List of message IDs to update.
            run_id (int): The run ID associated with the experiences.
            task_id (str): The task ID associated with the experiences.

        Returns:
            List[Experience]: List of updated experiences. Empty when
            ``msg_ids`` is empty or when no unconsumed record matches.

        Note:
            Only experiences that have not been consumed (consumed == 0) will be returned.
            For example, if you call this method multiple times with the same msg_ids, only
            the first call will return the updated experiences; subsequent calls will return
            an empty list.
        """
        if not msg_ids:
            # An empty ID list can never match a row; skip the locking query
            # (and the empty ``IN ()`` clause) altogether.
            return []

        with retry_session(self.session) as db:
            # Lock and retrieve records that have not been consumed yet.
            records = (
                db.query(self.table_model_cls)
                .filter(
                    self.table_model_cls.msg_id.in_(msg_ids),
                    self.table_model_cls.consumed == 0,
                )
                .with_for_update()
                .all()
            )
            if not records:
                return []

            # Update records in memory; incrementing ``consumed`` is what makes
            # later calls with the same msg_ids skip these rows.
            for record in records:
                record.reward = reward
                record.run_id = run_id
                record.task_id = task_id
                record.consumed += 1

            # The session commit is handled by the `retry_session` context manager.
            # Convert while still inside the session block so the conversion
            # runs before that commit (presumably required because ORM
            # attributes expire on commit -- confirm against `retry_session`).
            updated_experiences = [record.to_experience() for record in records]
            return updated_experiences