"""Schema for SQLAlchemy models."""
from typing import Any, Optional, Union
from sqlalchemy import Column, Float, Integer, LargeBinary, String
from sqlalchemy.ext.declarative import declarative_base
from trinity.common.constants import AlgorithmType
from trinity.common.experience import Experience
from trinity.common.models.utils import tokenize_and_mask_messages_hf
Base = declarative_base()
[docs]
class TaskModel(Base): # type: ignore
"""Model for storing tasks in SQLAlchemy."""
__abstract__ = True
__table_args__ = {
"keep_existing": True,
}
id = Column(Integer, primary_key=True, autoincrement=True)
task_desc = Column(String, nullable=True)
workflow_type = Column(String, nullable=True)
reward_type = Column(String, nullable=True)
[docs]
class ExperienceModel(Base): # type: ignore
"""SQLAlchemy model for Experience."""
__abstract__ = True
__table_args__ = {
"keep_existing": True,
}
id = Column(Integer, primary_key=True, autoincrement=True)
serialized_exp = Column(LargeBinary, nullable=True)
prompt = Column(String, nullable=True)
response = Column(String, nullable=True)
reward = Column(Float, nullable=True)
consumed = Column(Integer, default=0)
priority = Column(Float, default=0.0)
[docs]
def to_experience(self) -> Experience:
"""Load the experience from the database."""
return Experience.deserialize(self.serialized_exp)
[docs]
@classmethod
def from_experience(cls, experience: Experience):
"""Save the experience to database."""
return cls(
serialized_exp=experience.serialize(),
reward=experience.reward,
prompt=experience.prompt_text,
response=experience.response_text,
)
[docs]
class SFTDataModel(Base): # type: ignore
"""SQLAlchemy model for SFT data."""
__abstract__ = True
__table_args__ = {
"keep_existing": True,
}
id = Column(Integer, primary_key=True, autoincrement=True)
serialized_exp = Column(LargeBinary, nullable=True)
messages = Column(String, nullable=True)
consumed = Column(Integer, default=0)
[docs]
def to_experience(self) -> Experience:
"""Load the experience from the database."""
return Experience.deserialize(self.serialized_exp)
[docs]
@classmethod
def from_messages(
cls,
messages: list[dict],
tokenizer: Any,
chat_template: Optional[str] = None,
) -> "SFTDataModel":
"""Convert a list of messages into a single instance of SFT data."""
token_ids, action_mask = tokenize_and_mask_messages_hf(
tokenizer=tokenizer,
messages=messages,
chat_template=chat_template,
)
exp = Experience(
tokens=token_ids,
prompt_length=0,
action_mask=action_mask,
info={"response_num": sum([1 if m["role"] == "assistant" else 0 for m in messages])},
)
return cls(
serialized_exp=exp.serialize(),
messages=messages,
)
[docs]
class DPODataModel(Base): # type: ignore
"""SQLAlchemy model for DPO data."""
__abstract__ = True
__table_args__ = {
"keep_existing": True,
}
id = Column(Integer, primary_key=True, autoincrement=True)
serialized_exp = Column(LargeBinary, nullable=True)
chosen = Column(LargeBinary, nullable=True)
rejected = Column(LargeBinary, nullable=True)
consumed = Column(Integer, default=0)
[docs]
def to_experience(self) -> Experience:
"""Load the experience from the database."""
exp = Experience.deserialize(self.serialized_exp)
exp.chosen = Experience.deserialize(self.chosen)
exp.rejected = Experience.deserialize(self.rejected)
return exp
SCHEMA_MAPPING = {
None: TaskModel,
AlgorithmType.SFT: SFTDataModel,
AlgorithmType.PPO: ExperienceModel,
AlgorithmType.GRPO: ExperienceModel,
AlgorithmType.OPMD: ExperienceModel,
AlgorithmType.DPO: DPODataModel,
}
[docs]
def create_dynamic_table(algorithm_type: Union[AlgorithmType | None], table_name: str) -> Any:
"""Create a dynamic table based on the provided algorithm type and table name."""
if algorithm_type not in SCHEMA_MAPPING:
raise ValueError(f"Unknown schema: {algorithm_type}")
base_class = SCHEMA_MAPPING[algorithm_type]
table_attrs = {
"__tablename__": table_name,
}
return type(table_name, (base_class,), table_attrs)