Source code for trinity.explorer.proxy.app
import traceback
from contextlib import asynccontextmanager
from typing import Optional

import httpx
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, Response

# Shared HTTP client, created and closed by the FastAPI lifespan handler.
http_client: Optional[httpx.AsyncClient] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Create the shared HTTP client at startup and close it at shutdown.
    global http_client
    http_client = httpx.AsyncClient(
        timeout=httpx.Timeout(300.0, connect=10.0),
        limits=httpx.Limits(max_keepalive_connections=20, max_connections=100),
    )
    yield
    await http_client.aclose()

app = FastAPI(lifespan=lifespan)


# Forward OpenAI requests to a model instance
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
# Currently, we do not support streaming chat completions
try:
request_data = await request.json()
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")
forward_headers = {
key: value
for key, value in request.headers.items()
if key.lower() not in ["host", "content-length", "transfer-encoding"]
}
# for experience data recording, we need to return token ids and logprobs
request_data["return_token_ids"] = True
request_data["logprobs"] = True
# temperature must be set from config, ignore user's input
request_data["temperature"] = request.app.state.temperature
url, model_version = await request.app.state.service.allocate_model()
try:
async with httpx.AsyncClient(timeout=request.app.state.inference_timeout) as client:
resp = await client.post(
f"{url}/v1/chat/completions", json=request_data, headers=forward_headers
)
except Exception:
return Response(
status_code=500,
content=f"Error forwarding request to model at {url}: {traceback.format_exc()}",
)
resp_data = resp.json()
await request.app.state.service.record_experience(resp_data, model_version)
return JSONResponse(content=resp_data)
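
# Example (illustrative sketch, not part of the module): a client calls this
# proxy as if it were an OpenAI-compatible server; the proxy injects
# return_token_ids, logprobs, and the configured temperature before forwarding.
# The address and model name below are assumptions made for the sketch.
#
#     import httpx
#
#     resp = httpx.post(
#         "http://localhost:8010/v1/chat/completions",
#         json={
#             "model": "my-model",
#             "messages": [{"role": "user", "content": "Hello!"}],
#         },
#         timeout=300.0,
#     )
#     print(resp.json()["choices"][0]["message"]["content"])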
@app.get("/v1/models")
async def show_available_models(request: Request):
if hasattr(request.app.state, "models"):
return JSONResponse(content=request.app.state.models)
url, _ = await request.app.state.service.allocate_model(increase_count=False)
async with httpx.AsyncClient() as client:
print(f"Fetching models from {url}/v1/models")
resp = await client.get(f"{url}/v1/models")
request.app.state.models = resp.json()
return JSONResponse(content=resp.json())
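
# Example (illustrative sketch): the model list is fetched once from a backend
# and then cached on app.state, so repeated calls are cheap. The proxy address
# and the OpenAI-style response shape ("data"/"id") are assumptions.
#
#     import httpx
#
#     models = httpx.get("http://localhost:8010/v1/models").json()
#     print([m["id"] for m in models.get("data", [])])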
@app.get("/health")
async def health(request: Request) -> Response:
"""Health check."""
return Response(status_code=200)
@app.get("/metrics")
async def metrics(request: Request):
"""Get the metrics of the service."""
metrics = request.app.state.service.collect_metrics()
metrics["explore_step_num"] = request.app.state.service.explorer.explore_step_num
return JSONResponse(content=metrics)
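
# Example (illustrative sketch): /health returns an empty 200 response, and
# /metrics returns the service metrics plus the current explore step. The
# address is an assumption; the exact metric keys depend on the service.
#
#     import httpx
#
#     assert httpx.get("http://localhost:8010/health").status_code == 200
#     print(httpx.get("http://localhost:8010/metrics").json())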
@app.post("/feedback")
async def feedback(request: Request):
"""Receive feedback for the current session."""
body = await request.json()
reward = body.get("reward")
msg_ids = body.get("msg_ids")
task_id = body.get("task_id")
run_id = body.get("run_id", 0)
if msg_ids is None or reward is None:
return JSONResponse(status_code=400, content={"error": "msg_ids and reward are required"})
if not isinstance(msg_ids, list) or not isinstance(reward, (int, float)):
return JSONResponse(
status_code=400, content={"error": "msg_ids must be a list and reward must be a number"}
)
await request.app.state.service.record_feedback(
reward=reward, msg_ids=msg_ids, task_id=task_id, run_id=run_id
)
return JSONResponse(content={"status": "success"})
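
# Example (illustrative sketch): report a scalar reward for previously
# recorded messages. The msg_ids, task_id, and address below are placeholders;
# in practice the ids would identify messages the service has already seen.
#
#     import httpx
#
#     httpx.post(
#         "http://localhost:8010/feedback",
#         json={
#             "reward": 1.0,
#             "msg_ids": ["msg-123", "msg-456"],
#             "task_id": "task-1",
#             "run_id": 0,
#         },
#     )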
@app.post("/commit")
async def commit(request: Request):
"""Commit the current experiences."""
await request.app.state.service.submit_experiences()
return JSONResponse(content={"status": "success"})
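
# Example (illustrative sketch): once feedback has been recorded, a bare POST
# to /commit submits the buffered experiences. The address is an assumption.
#
#     import httpx
#
#     httpx.post("http://localhost:8010/commit")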

async def serve_http(app: FastAPI, host: str, port: int) -> None:
    """Serve the FastAPI app with uvicorn."""
    config = uvicorn.Config(app, host=host, port=port)
    server = uvicorn.Server(config)
    await server.serve()

async def run_app(service, listen_address: str, port: int) -> None:
    """Attach the service to the app state and start the HTTP server."""
    app.state.service = service
    app.state.temperature = service.explorer.config.model.temperature
    app.state.inference_timeout = service.explorer.config.synchronizer.sync_timeout
    await serve_http(app, listen_address, port)
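
# Example (illustrative sketch): run_app expects a service object exposing
# allocate_model, record_experience, record_feedback, submit_experiences,
# collect_metrics, and an `explorer` carrying the config. build_service below
# is a hypothetical factory, and the address and port are assumptions.
#
#     import asyncio
#
#     async def main():
#         service = build_service()  # hypothetical
#         await run_app(service, listen_address="0.0.0.0", port=8010)
#
#     asyncio.run(main())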