import argparse
import html
from typing import List, Optional

import streamlit as st
import streamlit.components.v1 as components
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.orm import sessionmaker
from transformers import AutoTokenizer

from trinity.buffer.schema import init_engine
from trinity.common.config import StorageConfig
from trinity.common.experience import Experience
from trinity.utils.log import get_logger
class SQLExperienceViewer:
    """Read-only, paginated access to experiences stored in a SQL table.

    Args:
        config: Storage configuration. ``config.path`` is used as the database
            URL, ``config.name`` as the table name (and logger suffix), and
            ``config.schema_type`` is forwarded to ``init_engine`` to pick the
            table schema.

    Raises:
        ValueError: If ``config.path`` is empty.
    """

    def __init__(self, config: StorageConfig) -> None:
        self.logger = get_logger(f"sql_{config.name}", in_ray_actor=True)
        if not config.path:
            raise ValueError("`path` is required for SQL storage type.")
        self.engine, self.table_model_cls = init_engine(
            db_url=config.path,
            table_name=config.name,
            schema_type=config.schema_type,
        )
        # Session *factory*, not a session: call ``self.session()`` to open one.
        self.session = sessionmaker(bind=self.engine)

    def get_experiences(self, offset: int, limit: int = 10) -> List[Experience]:
        """Return up to ``limit`` experiences starting at row ``offset``.

        Rows are ordered by the table's primary key. SQL does not guarantee any
        row order for OFFSET/LIMIT without ORDER BY, so without it consecutive
        pages could overlap or skip rows.

        Args:
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.

        Returns:
            The rows converted via ``table_model_cls.to_experience``.
        """
        self.logger.info(f"Viewing experiences from offset {offset} with limit {limit}.")
        primary_key = sa_inspect(self.table_model_cls).primary_key
        with self.session() as session:
            rows = (
                session.query(self.table_model_cls)
                .order_by(*primary_key)
                .offset(offset)
                .limit(limit)
                .all()
            )
            # Convert while the session is still open.
            return [self.table_model_cls.to_experience(row) for row in rows]

    def total_experiences(self) -> int:
        """Return the total number of rows in the experience table."""
        with self.session() as session:
            return session.query(self.table_model_cls).count()
# Page-level config (browser tab title, wide layout); Streamlit expects this
# before any other page output is produced.
st.set_page_config(layout="wide", page_title="Trinity-RFT Experience Visualizer")
def get_color_for_action_mask(action_mask_value: int) -> str:
    """Map an action-mask value to a token background color.

    Returns light green (``#c8e6c9``) when the value compares equal to 1 and
    light red (``#ffcdd2``) for anything else.
    """
    is_action_token = action_mask_value == 1
    return "#c8e6c9" if is_action_token else "#ffcdd2"
[docs]
def render_experience(exp: Experience, exp_index: int, tokenizer):
"""Render a single experience sequence in Streamlit."""
token_ids = exp.tokens
logprobs = exp.logprobs
action_mask = exp.action_mask
prompt_length = exp.prompt_length
prompt_token_ids = token_ids[:prompt_length] # type: ignore [index]
response_token_ids = token_ids[prompt_length:] # type: ignore [index]
# Decode tokens
prompt_text = (
tokenizer.decode(prompt_token_ids)
if hasattr(tokenizer, "decode")
else "".join([str(tid) for tid in prompt_token_ids])
)
response_text = (
tokenizer.decode(response_token_ids)
if hasattr(tokenizer, "decode")
else "".join([str(tid) for tid in response_token_ids])
)
# Get each response token text
response_tokens = []
for tid in response_token_ids:
if hasattr(tokenizer, "decode"):
token_text = tokenizer.decode([tid])
else:
token_text = f"[{tid}]"
response_tokens.append(token_text)
# HTML escape function
def html_escape(text):
return (
text.replace("&", "&")
.replace("<", "<")
.replace(">", ">")
.replace('"', """)
.replace("'", "'")
)
# Build full HTML (with CSS)
html = f"""
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
* {{
margin: 0;
padding: 0;
box-sizing: border-box;
}}
body {{
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', sans-serif;
padding: 10px;
}}
.sequence-container {{
border: 2px solid #e0e0e0;
border-radius: 10px;
padding: 20px;
background-color: #f8f9fa;
height: auto;
}}
.sequence-header {{
font-size: 18px;
font-weight: bold;
margin-bottom: 15px;
color: #333;
}}
.text-section {{
background-color: white;
padding: 15px;
border-radius: 5px;
margin-bottom: 15px;
border-left: 4px solid #4CAF50;
}}
.prompt-section {{
background-color: #e3f2fd;
padding: 10px;
border-radius: 5px;
margin-bottom: 10px;
font-family: 'Courier New', monospace;
white-space: pre-wrap;
word-wrap: break-word;
}}
.response-section {{
background-color: #fff3e0;
padding: 10px;
border-radius: 5px;
font-family: 'Courier New', monospace;
white-space: pre-wrap;
word-wrap: break-word;
}}
.token-container {{
display: flex;
flex-wrap: wrap;
gap: 5px;
padding: 15px;
background-color: white;
border-radius: 5px;
}}
.token-box {{
display: inline-flex;
flex-direction: column;
align-items: center;
padding: 8px 12px;
border-radius: 5px;
border: 1px solid #ddd;
min-width: 60px;
transition: transform 0.2s, box-shadow 0.2s;
}}
.token-box:hover {{
transform: scale(1.5);
box-shadow: 0 4px 8px rgba(0,0,0,0.2);
z-index: 10;
}}
.token-text {{
font-family: 'Courier New', monospace;
font-size: 14px;
font-weight: bold;
margin-bottom: 5px;
text-align: center;
word-break: break-all;
max-width: 100px;
}}
.token-logprob {{
font-size: 11px;
color: #555;
font-family: 'Courier New', monospace;
text-align: center;
}}
.label-text {{
font-weight: bold;
color: #1976d2;
margin-bottom: 5px;
margin-top: 10px;
}}
.section-divider {{
margin: 20px 0;
border-top: 2px dashed #ccc;
}}
</style>
</head>
<body>
<div class="sequence-container">
<div class="sequence-header">Experience {exp_index + 1}</div>
<div class="text-section">
<div class="label-text">📝 Prompt:</div>
<div class="prompt-section">{html_escape(prompt_text)}</div>
<div class="label-text">💬 Response:</div>
<div class="response-section">{html_escape(response_text)}</div>
</div>
<div class="section-divider"></div>
<div class="label-text">🔍 Response Tokens Detail:</div>
<div class="token-container">
"""
# Add each response token
for i, (token_text, logprob, mask) in enumerate(zip(response_tokens, logprobs, action_mask)): # type: ignore [arg-type]
bg_color = get_color_for_action_mask(mask)
# Handle special character display
token_display = token_text.replace(" ", "␣").replace("\n", "↵").replace("\t", "⇥")
token_display = html_escape(token_display)
html += f"""
<div class="token-box" style="background-color: {bg_color};">
<div class="token-text">{token_display}</div>
<div class="token-logprob">{logprob:.3f}</div>
</div>
"""
html += """
</div>
</div>
</body>
</html>
"""
# Use components.html instead of st.markdown
components.html(html, height=1200, scrolling=True)
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line options for the experience visualizer.

    Args:
        argv: Arguments to parse. ``None`` (default) makes argparse read
            ``sys.argv[1:]``; pass an explicit list to parse programmatically.

    Returns:
        Namespace with ``db_url`` and ``table`` (both required) and
        ``tokenizer`` (``None`` when omitted).
    """
    parser = argparse.ArgumentParser(description="Experience Visualizer")
    # The viewer cannot work without a database URL and a table name, so fail
    # here with a usage message rather than later during start-up.
    parser.add_argument(
        "--db-url",
        type=str,
        required=True,
        help="Path to the experience database.",
    )
    parser.add_argument(
        "--table",
        type=str,
        required=True,
        help="Name of the experience table.",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        help="Path to the tokenizer.",
    )
    return parser.parse_args(argv)
@st.cache_resource
def _load_viewer(db_url: str, table: str) -> SQLExperienceViewer:
    """Build the SQL viewer once per (db_url, table) and reuse it across reruns.

    Streamlit re-executes the script on every interaction; without caching each
    rerun would create a fresh database engine.
    """
    config = StorageConfig(
        name=table,
        path=db_url,
        schema_type="experience",
        storage_type="sql",
    )
    return SQLExperienceViewer(config)


@st.cache_resource
def _load_tokenizer(path: str):
    """Load the tokenizer once per path instead of on every rerun."""
    return AutoTokenizer.from_pretrained(path)


def main():
    """Run the Streamlit page that browses stored experiences.

    The main area shows paginated experiences. The sidebar holds page size,
    page selection, and statistics. The current page is kept in
    ``st.session_state.page`` across reruns.
    """
    args = parse_args()
    viewer = _load_viewer(args.db_url, args.table)

    st.title("🎯 Trinity-RFT Experience Visualizer")
    if "page" not in st.session_state:
        st.session_state.page = 1

    with st.expander("ℹ️ Instructions"):
        st.markdown(
            "- **Green background**: action_mask = 1\n"
            "- **Red background**: action_mask = 0\n"
            "- **Top**: Token text (special characters: space=␣, newline=↵, tab=⇥)\n"
            "- **Bottom**: Logprob value of the token\n"
            "- Hover over token to zoom in\n"
        )

    total_seq_num = viewer.total_experiences()

    st.sidebar.header("⚙️ Settings")
    experiences_per_page = st.sidebar.slider(
        "Experiences per page", min_value=1, max_value=20, value=5
    )
    # Ceiling division. Keep at least one page so an empty table shows "1/1".
    total_pages = max(1, (total_seq_num + experiences_per_page - 1) // experiences_per_page)
    # A larger page size can shrink total_pages below the remembered page.
    # number_input rejects a value above max_value, so clamp first.
    st.session_state.page = min(max(1, st.session_state.page), total_pages)

    current_page = st.sidebar.number_input(
        "Select page",
        min_value=1,
        max_value=total_pages,
        step=1,
        value=st.session_state.page,
    )
    if current_page != st.session_state.page:
        st.session_state.page = current_page
        st.rerun()

    st.sidebar.markdown("---")
    st.sidebar.metric("Total experiences", total_seq_num)
    st.sidebar.metric("Total pages", total_pages)
    st.sidebar.metric("Current page", f"{st.session_state.page}/{total_pages}")

    offset = (st.session_state.page - 1) * experiences_per_page
    experiences = viewer.get_experiences(offset, experiences_per_page)
    # --tokenizer is optional. With None, render_experience shows raw token ids.
    tokenizer = _load_tokenizer(args.tokenizer) if args.tokenizer else None

    if experiences:
        for exp_index, exp in enumerate(experiences, start=offset):
            render_experience(exp, exp_index, tokenizer)
    else:
        st.warning("No experience data found")

    # Pagination navigation
    st.markdown("---")
    col1, col2, col3 = st.columns([1, 2, 1])
    with col1:
        if st.session_state.page > 1 and st.button("⬅️ Previous Page"):
            st.session_state.page = st.session_state.page - 1
            st.rerun()
    with col2:
        st.markdown(
            f"<center>Page {st.session_state.page} / {total_pages}</center>", unsafe_allow_html=True
        )
    with col3:
        if st.session_state.page < total_pages and st.button("Next Page ➡️"):
            st.session_state.page = st.session_state.page + 1
            st.rerun()
# Entry point when this file is executed as a script (e.g. launched via `streamlit run`).
if __name__ == "__main__":
    main()