Source code for trinity.buffer.viewer

import argparse
from typing import List

import streamlit as st
import streamlit.components.v1 as components
from sqlalchemy.orm import sessionmaker
from transformers import AutoTokenizer

from trinity.buffer.schema import init_engine
from trinity.common.config import StorageConfig
from trinity.common.experience import Experience
from trinity.utils.log import get_logger


class SQLExperienceViewer:
    """Read-only, paginated access to experiences stored in a SQL table.

    The engine and the table model class are obtained from ``init_engine``;
    every query runs in its own short-lived session created from
    ``self.session``.
    """

    def __init__(self, config: StorageConfig) -> None:
        """Create the engine and session factory described by ``config``.

        Args:
            config: Storage configuration. ``config.path`` is used as the
                database URL, ``config.name`` as the table name and
                ``config.schema_type`` selects the table schema.

        Raises:
            ValueError: If ``config.path`` is empty.
        """
        self.logger = get_logger(f"sql_{config.name}", in_ray_actor=True)
        if not config.path:
            raise ValueError("`path` is required for SQL storage type.")
        self.engine, self.table_model_cls = init_engine(
            db_url=config.path,
            table_name=config.name,
            schema_type=config.schema_type,
        )
        self.session = sessionmaker(bind=self.engine)

    def get_experiences(self, offset: int, limit: int = 10) -> List[Experience]:
        """Return one page of experiences.

        Rows are ordered by the table's primary key. Without an explicit
        ordering the database is free to return rows in any order, so
        OFFSET/LIMIT pages could overlap or skip rows between two queries.

        Args:
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.

        Returns:
            The rows of the requested page, each converted with the table
            model's ``to_experience``.
        """
        # Function-scope import keeps this block self-contained; the file
        # already depends on SQLAlchemy.
        from sqlalchemy import inspect as sa_inspect

        self.logger.info(f"Viewing experiences from offset {offset} with limit {limit}.")
        primary_key_columns = sa_inspect(self.table_model_cls).primary_key
        with self.session() as session:
            rows = (
                session.query(self.table_model_cls)
                .order_by(*primary_key_columns)
                .offset(offset)
                .limit(limit)
                .all()
            )
            # Convert while the session is still open so that the rows are
            # not detached when their attributes are read.
            return [self.table_model_cls.to_experience(row) for row in rows]

    def total_experiences(self) -> int:
        """Return the number of rows currently stored in the table."""
        with self.session() as session:
            return session.query(self.table_model_cls).count()
# Page-level Streamlit settings. This call sits at module level, so it runs
# when the script is loaded, before any of the functions below are invoked.
st.set_page_config(
    layout="wide",
    page_title="Trinity-RFT Experience Visualizer",
)
def get_color_for_action_mask(action_mask_value: int) -> str:
    """Pick the background color of a token box from its action-mask entry.

    Args:
        action_mask_value: Action-mask entry of the token.

    Returns:
        ``"#c8e6c9"`` (green) when the entry equals 1, ``"#ffcdd2"`` (red)
        for any other value.
    """
    is_action_token = action_mask_value == 1
    return "#c8e6c9" if is_action_token else "#ffcdd2"
# Visible stand-ins for whitespace, so tokens that consist only of a space,
# newline or tab are still readable in the token grid.
_WHITESPACE_GLYPHS = str.maketrans({" ": "␣", "\n": "↵", "\t": "⇥"})

# Neutral background for a token whose action-mask entry is not available.
_UNKNOWN_MASK_COLOR = "#eeeeee"

# Stylesheet of the per-experience HTML document. Kept as a plain string so
# the CSS braces do not have to be doubled inside an f-string.
_EXPERIENCE_CSS = """
    * {
        margin: 0;
        padding: 0;
        box-sizing: border-box;
    }
    body {
        font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', sans-serif;
        padding: 10px;
    }
    .sequence-container {
        border: 2px solid #e0e0e0;
        border-radius: 10px;
        padding: 20px;
        background-color: #f8f9fa;
        height: auto;
    }
    .sequence-header {
        font-size: 18px;
        font-weight: bold;
        margin-bottom: 15px;
        color: #333;
    }
    .text-section {
        background-color: white;
        padding: 15px;
        border-radius: 5px;
        margin-bottom: 15px;
        border-left: 4px solid #4CAF50;
    }
    .prompt-section {
        background-color: #e3f2fd;
        padding: 10px;
        border-radius: 5px;
        margin-bottom: 10px;
        font-family: 'Courier New', monospace;
        white-space: pre-wrap;
        word-wrap: break-word;
    }
    .response-section {
        background-color: #fff3e0;
        padding: 10px;
        border-radius: 5px;
        font-family: 'Courier New', monospace;
        white-space: pre-wrap;
        word-wrap: break-word;
    }
    .token-container {
        display: flex;
        flex-wrap: wrap;
        gap: 5px;
        padding: 15px;
        background-color: white;
        border-radius: 5px;
    }
    .token-box {
        display: inline-flex;
        flex-direction: column;
        align-items: center;
        padding: 8px 12px;
        border-radius: 5px;
        border: 1px solid #ddd;
        min-width: 60px;
        transition: transform 0.2s, box-shadow 0.2s;
    }
    .token-box:hover {
        transform: scale(1.5);
        box-shadow: 0 4px 8px rgba(0,0,0,0.2);
        z-index: 10;
    }
    .token-text {
        font-family: 'Courier New', monospace;
        font-size: 14px;
        font-weight: bold;
        margin-bottom: 5px;
        text-align: center;
        word-break: break-all;
        max-width: 100px;
    }
    .token-logprob {
        font-size: 11px;
        color: #555;
        font-family: 'Courier New', monospace;
        text-align: center;
    }
    .label-text {
        font-weight: bold;
        color: #1976d2;
        margin-bottom: 5px;
        margin-top: 10px;
    }
    .section-divider {
        margin: 20px 0;
        border-top: 2px dashed #ccc;
    }
"""


def _decode_ids(tokenizer, token_ids, can_decode: bool) -> str:
    """Decode ``token_ids`` to text, or concatenate the raw ids when the tokenizer cannot decode."""
    if can_decode:
        return tokenizer.decode(token_ids)
    return "".join(str(tid) for tid in token_ids)


def _token_box_html(display_text: str, logprob, mask) -> str:
    """Return the HTML of one token box; ``display_text`` must already be HTML-escaped."""
    bg_color = _UNKNOWN_MASK_COLOR if mask is None else get_color_for_action_mask(mask)
    logprob_text = "N/A" if logprob is None else f"{logprob:.3f}"
    return f"""
                <div class="token-box" style="background-color: {bg_color};">
                    <div class="token-text">{display_text}</div>
                    <div class="token-logprob">{logprob_text}</div>
                </div>
    """


def render_experience(exp: Experience, exp_index: int, tokenizer):
    """Render a single experience sequence in Streamlit.

    The experience is shown as one embedded HTML document: the decoded
    prompt, the decoded response, and one box per response token carrying
    the token text, its logprob and a background color derived from its
    action-mask entry.

    Args:
        exp: Experience to show. ``exp.tokens`` is split at
            ``exp.prompt_length`` into prompt and response.
            NOTE(review): ``exp.logprobs`` and ``exp.action_mask`` are
            paired element-wise with the response tokens, i.e. assumed to
            be response-aligned; confirm against ``Experience``.
        exp_index: Zero-based index of the experience; displayed one-based.
        tokenizer: Object used to turn token ids into text. When it has no
            ``decode`` attribute, raw ids are displayed instead.

    Missing data is tolerated rather than raising ``TypeError``: an
    experience without tokens produces a warning, a missing logprob is
    shown as ``N/A`` and a missing action mask gives a neutral grey box.
    """
    # Function-scope import: standard-library escaping replaces the previous
    # hand-written replace chain.
    from html import escape

    token_ids = exp.tokens
    if token_ids is None:
        st.warning(f"Experience {exp_index + 1} has no tokens to display.")
        return

    logprobs = exp.logprobs
    action_mask = exp.action_mask
    prompt_length = exp.prompt_length

    prompt_token_ids = token_ids[:prompt_length]  # type: ignore [index]
    response_token_ids = token_ids[prompt_length:]  # type: ignore [index]

    # The capability check is loop-invariant, so do it once.
    can_decode = hasattr(tokenizer, "decode")
    prompt_text = _decode_ids(tokenizer, prompt_token_ids, can_decode)
    response_text = _decode_ids(tokenizer, response_token_ids, can_decode)

    # Text of every response token on its own.
    if can_decode:
        response_tokens = [tokenizer.decode([tid]) for tid in response_token_ids]
    else:
        response_tokens = [f"[{tid}]" for tid in response_token_ids]

    # Substitute placeholders for absent per-token data so that every token
    # is still rendered. When both sequences are present, pairing stops at
    # the shortest one, as before.
    token_count = len(response_tokens)
    logprob_seq = [None] * token_count if logprobs is None else logprobs
    mask_seq = [None] * token_count if action_mask is None else action_mask

    page_head = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="UTF-8">
        <style>{_EXPERIENCE_CSS}</style>
    </head>
    <body>
        <div class="sequence-container">
            <div class="sequence-header">Experience {exp_index + 1}</div>

            <div class="text-section">
                <div class="label-text">📝 Prompt:</div>
                <div class="prompt-section">{escape(prompt_text)}</div>

                <div class="label-text">💬 Response:</div>
                <div class="response-section">{escape(response_text)}</div>
            </div>

            <div class="section-divider"></div>

            <div class="label-text">🔍 Response Tokens Detail:</div>
            <div class="token-container">
    """

    # Make whitespace visible first, then escape, so token text can never
    # inject markup into the page.
    token_boxes = [
        _token_box_html(escape(token_text.translate(_WHITESPACE_GLYPHS)), logprob, mask)
        for token_text, logprob, mask in zip(response_tokens, logprob_seq, mask_seq)  # type: ignore [arg-type]
    ]

    page_tail = """
            </div>
        </div>
    </body>
    </html>
    """

    # Join once instead of growing the string inside the loop.
    page = "".join([page_head, *token_boxes, page_tail])

    # Use components.html instead of st.markdown
    components.html(page, height=1200, scrolling=True)
def parse_args():
    """Parse the command-line options of the experience visualizer.

    All three options are mandatory. Previously an omitted option was
    silently parsed as ``None`` and only failed later, far from the cause;
    now argparse reports the missing option and exits with status 2.

    Returns:
        argparse.Namespace: with attributes ``db_url``, ``table`` and
        ``tokenizer``.
    """
    parser = argparse.ArgumentParser(description="Experience Visualizer")
    parser.add_argument(
        "--db-url",
        type=str,
        required=True,
        help="Path to the experience database.",
    )
    parser.add_argument(
        "--table",
        type=str,
        required=True,
        help="Name of the experience table.",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        required=True,
        help="Path to the tokenizer.",
    )
    return parser.parse_args()
@st.cache_resource
def _load_tokenizer(tokenizer_path: str):
    """Load the tokenizer once and reuse it across Streamlit reruns."""
    return AutoTokenizer.from_pretrained(tokenizer_path)


def main():
    """Run the Streamlit experience visualizer.

    Parses the command line, opens the experience table, and renders one
    page of experiences with sidebar settings and previous/next buttons.
    The current page number is kept in ``st.session_state.page``.
    """
    args = parse_args()

    # Initialize SQLExperienceViewer
    config = StorageConfig(
        name=args.table,
        path=args.db_url,
        schema_type="experience",
        storage_type="sql",
    )
    viewer = SQLExperienceViewer(config)

    st.title("🎯 Trinity-RFT Experience Visualizer")

    if "page" not in st.session_state:
        st.session_state.page = 1

    # Add instructions
    with st.expander("ℹ️ Instructions"):
        st.markdown(
            """
        - **Green background**: action_mask = 1
        - **Red background**: action_mask = 0
        - **Top**: Token text (special characters: space=␣, newline=↵, tab=⇥)
        - **Bottom**: Logprob value of the token
        - Hover over token to zoom in
        """
        )

    # Get total sequence number
    total_seq_num = viewer.total_experiences()

    # Sidebar configuration
    st.sidebar.header("⚙️ Settings")

    # Pagination settings
    experiences_per_page = st.sidebar.slider(
        "Experiences per page", min_value=1, max_value=20, value=5
    )

    # Calculate total pages (ceiling division)
    total_pages = (total_seq_num + experiences_per_page - 1) // experiences_per_page
    last_page = max(1, total_pages)

    # Raising "Experiences per page" (or a shrinking table) can leave the
    # remembered page beyond the last one; number_input rejects a value
    # above max_value, so clamp before building the widget.
    st.session_state.page = min(max(1, st.session_state.page), last_page)

    # Page selection (sidebar)
    current_page = st.sidebar.number_input(
        "Select page",
        min_value=1,
        max_value=last_page,
        step=1,
        value=st.session_state.page,
    )
    if current_page != st.session_state.page:
        st.session_state.page = current_page
        st.rerun()

    # Show statistics
    st.sidebar.markdown("---")
    st.sidebar.metric("Total experiences", total_seq_num)
    st.sidebar.metric("Total pages", total_pages)
    st.sidebar.metric("Current page", f"{st.session_state.page}/{total_pages}")

    # Calculate offset
    offset = (st.session_state.page - 1) * experiences_per_page

    # Get experiences for current page
    experiences = viewer.get_experiences(offset, experiences_per_page)

    # Render experiences
    if experiences:
        # Streamlit re-executes this function on every interaction; the
        # cached loader avoids re-reading the tokenizer each time.
        tokenizer = _load_tokenizer(args.tokenizer)
        for i, exp in enumerate(experiences):
            render_experience(exp, offset + i, tokenizer)
    else:
        st.warning("No experience data found")

    # Pagination navigation
    st.markdown("---")
    col1, col2, col3 = st.columns([1, 2, 1])

    with col1:
        if st.session_state.page > 1:
            if st.button("⬅️ Previous Page"):
                st.session_state.page = st.session_state.page - 1
                st.rerun()

    with col2:
        st.markdown(
            f"<center>Page {st.session_state.page} / {total_pages}</center>", unsafe_allow_html=True
        )

    with col3:
        if st.session_state.page < total_pages:
            if st.button("Next Page ➡️"):
                st.session_state.page = st.session_state.page + 1
                st.rerun()
# Script entry point: start the visualizer only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()