Source code for trinity.common.schema

# -*- coding: utf-8 -*-
"""Schema for different types of data."""
from typing import Any, Optional, Type

from sqlalchemy import JSON, Column, DateTime, Float, Integer, LargeBinary, String, Text
from sqlalchemy.ext.declarative import declarative_base

from trinity.common.experience import Experience
from trinity.common.models.utils import tokenize_and_mask_messages_hf

# Declarative base shared by every ORM model in this module; subclassing it
# registers each model's table on the common ``Base.metadata``.
# NOTE(review): ``sqlalchemy.ext.declarative.declarative_base`` is the legacy
# import location (moved to ``sqlalchemy.orm.declarative_base`` in SQLAlchemy
# 1.4/2.0) — consider switching the import when upgrading.
Base: Type = declarative_base()

# TODO: create db engine and all tables in a factory class


class RftDatasetModel(Base):
    """SQLAlchemy model for RftDataset.

    One row per RFT sample: lineage information (where the sample came from),
    the sample content, data-quality scores, and the names of the downstream
    reward function / workflow to use with it.
    """

    __tablename__ = "rft_dataset"

    # lineage
    id = Column(Integer, primary_key=True, autoincrement=True)
    consumed_cnt = Column(Integer, default=0)
    last_modified_date = Column(DateTime, nullable=True)
    from_id = Column(Integer, nullable=True)
    from_model = Column(Text, nullable=True)
    from_recipe = Column(Text, nullable=True)

    # content
    prompt = Column(Text, nullable=True)
    response = Column(Text, nullable=True)
    solution = Column(Text, nullable=True)
    reward = Column(Float, nullable=True)
    chosen = Column(Text, nullable=True)
    rejected = Column(Text, nullable=True)
    label = Column(Text, nullable=True)

    # extra info
    quality_score = Column(Float, default=0.0)
    quality_score_detail = Column(JSON, nullable=True)
    difficulty_score = Column(Float, default=0.0)
    difficulty_score_detail = Column(JSON, nullable=True)
    diversity_score = Column(Float, default=0.0)
    diversity_score_detail = Column(JSON, nullable=True)
    priority = Column(Float, default=0.0)

    # downstream
    reward_fn = Column(Text, nullable=True)
    workflow = Column(Text, nullable=True)

    def to_dict(self) -> dict:
        """Return this row as a ``{column_name: value}`` dict.

        The keys come from the mapped table's columns rather than from
        ``self.__dict__``. The instance ``__dict__`` only contains attributes
        that are currently loaded. After a session commit (which expires
        attributes by default) or on a partially loaded row, it would
        silently omit columns. ``getattr`` goes through the instrumented
        attribute, so expired values are reloaded.

        On a pending instance that has not been flushed yet, columns that
        were never assigned are reported as ``None``. Column defaults are
        applied at flush time, not at construction.
        """
        return {
            column.name: getattr(self, column.name)
            for column in self.__table__.columns
        }
class TaskModel(Base):
    """ORM mapping for rows of the ``task_buffer`` table.

    Each row carries an optional task description plus the workflow type and
    reward type associated with the task, all stored as plain strings.
    """

    # TODO: Add more fields
    __tablename__ = "task_buffer"

    # Surrogate integer key, generated by the database.
    id = Column(Integer, autoincrement=True, primary_key=True)
    task_desc = Column(String, nullable=True)
    workflow_type = Column(String, nullable=True)
    reward_type = Column(String, nullable=True)
class ExperienceModel(Base):
    """SQLAlchemy model for Experience.

    The complete ``Experience`` is stored as the bytes produced by
    ``Experience.serialize()``. Prompt, response, and reward are also copied
    into their own columns (presumably so they can be inspected or queried
    without deserializing — confirm).
    """

    __tablename__ = "experience_buffer"

    id = Column(Integer, primary_key=True, autoincrement=True)
    serialized_exp = Column(LargeBinary, nullable=True)
    prompt = Column(String, nullable=True)
    response = Column(String, nullable=True)
    reward = Column(Float, nullable=True)
    consumed = Column(Integer, default=0)
    priority = Column(Float, default=0.0)

    def to_experience(self) -> Experience:
        """Load the experience from the database.

        Returns:
            The ``Experience`` decoded from ``serialized_exp``.
        """
        return Experience.deserialize(self.serialized_exp)

    @classmethod
    def from_experience(cls, experience: Experience) -> "ExperienceModel":
        """Build a row for saving ``experience`` to the database.

        This is an alternate constructor: it returns an instance of ``cls``,
        so subclasses get their own type back. It is consistent with
        ``SFTDataModel.from_messages``. The row is only constructed here;
        adding it to a session and committing is left to the caller.

        Args:
            experience: The experience to store. Its serialized bytes go into
                ``serialized_exp``, and its reward / prompt / response text
                are copied into the matching columns.

        Returns:
            A new, unsaved model instance.
        """
        return cls(
            serialized_exp=experience.serialize(),
            reward=experience.reward,
            prompt=experience.prompt_text,
            response=experience.response_text,
        )
class SFTDataModel(Base):
    """SQLAlchemy model for SFT data.

    A row holds one serialized ``Experience`` built from a whole chat
    conversation, together with the original messages.
    """

    __tablename__ = "sft_data_buffer"

    id = Column(Integer, primary_key=True, autoincrement=True)
    serialized_exp = Column(LargeBinary, nullable=True)
    # ``from_messages`` stores the raw chat messages (a ``list[dict]``), so this
    # must be a JSON column. A plain ``String`` column passes the Python list
    # straight to the DB-API driver, which cannot bind it, and the flush fails.
    messages = Column(JSON, nullable=True)
    consumed = Column(Integer, default=0)

    def to_experience(self) -> Experience:
        """Load the experience from the database.

        Returns:
            The ``Experience`` decoded from ``serialized_exp``.
        """
        return Experience.deserialize(self.serialized_exp)

    @classmethod
    def from_messages(
        cls,
        messages: list[dict],
        tokenizer: Any,
        chat_template: Optional[str] = None,
    ) -> "SFTDataModel":
        """Convert a list of messages into a single instance of SFT data.

        Args:
            messages: Chat messages. Every message must have a ``"role"`` key,
                otherwise ``KeyError`` is raised.
            tokenizer: Forwarded to ``tokenize_and_mask_messages_hf``
                (presumably a Hugging Face tokenizer — confirm).
            chat_template: Optional chat template, forwarded to
                ``tokenize_and_mask_messages_hf``.

        Returns:
            A new, unsaved instance. It holds the serialized ``Experience`` and
            the original ``messages`` list.
        """
        token_ids, action_mask = tokenize_and_mask_messages_hf(
            tokenizer=tokenizer,
            messages=messages,
            chat_template=chat_template,
        )
        # Number of assistant turns, recorded on the experience as "response_num".
        response_num = sum(1 for message in messages if message["role"] == "assistant")
        exp = Experience(
            tokens=token_ids,
            # NOTE(review): prompt_length is 0 for the whole conversation;
            # presumably action_mask alone selects the trained tokens — confirm.
            prompt_length=0,
            action_mask=action_mask,
            info={"response_num": response_num},
        )
        return cls(
            serialized_exp=exp.serialize(),
            messages=messages,
        )
class DPODataModel(Base):
    """ORM mapping for DPO preference data in ``dpo_data_buffer``.

    Three serialized ``Experience`` blobs are kept per row: the base
    experience, the chosen variant, and the rejected variant.
    """

    __tablename__ = "dpo_data_buffer"

    id = Column(Integer, primary_key=True, autoincrement=True)
    serialized_exp = Column(LargeBinary, nullable=True)
    chosen = Column(LargeBinary, nullable=True)
    rejected = Column(LargeBinary, nullable=True)
    consumed = Column(Integer, default=0)

    def to_experience(self) -> Experience:
        """Rebuild the stored experience with its chosen/rejected pair attached.

        ``serialized_exp`` is decoded into the returned experience. The
        ``chosen`` and ``rejected`` blobs are decoded the same way and set as
        that experience's ``chosen`` and ``rejected`` attributes.
        """
        decode = Experience.deserialize
        result = decode(self.serialized_exp)
        result.chosen, result.rejected = decode(self.chosen), decode(self.rejected)
        return result