Source code for trinity.explorer.api.api
import traceback
import httpx
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response
app = FastAPI()
# Forward openAI requests to a model instance
[docs]
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
# Currently, we do not support streaming chat completions
body = await request.json()
url = await request.app.state.service.allocate_model()
try:
async with httpx.AsyncClient(timeout=request.app.state.inference_timeout) as client:
resp = await client.post(f"{url}/v1/chat/completions", json=body)
except Exception:
return Response(
status_code=500,
content=f"Error forwarding request to model at {url}: {traceback.format_exc()}",
)
resp_data = resp.json()
await request.app.state.service.record_experience(resp_data)
return JSONResponse(content=resp_data)
[docs]
@app.get("/v1/models")
async def show_available_models(request: Request):
body = await request.json()
url = await request.app.state.service.allocate_model(increase_count=False)
async with httpx.AsyncClient() as client:
resp = await client.get(f"{url}/v1/models", json=body)
return JSONResponse(content=resp.json())
[docs]
@app.get("/health")
async def health(request: Request) -> Response:
"""Health check."""
return Response(status_code=200)
[docs]
@app.get("/metrics")
async def metrics(request: Request):
"""Get the metrics of the service."""
metrics = request.app.state.service.collect_metrics()
metrics["explore_step_num"] = request.app.state.service.explorer.explore_step_num
return JSONResponse(content=metrics)
[docs]
async def serve_http(app: FastAPI, host: str, port: int = None):
config = uvicorn.Config(app, host=host, port=port)
server = uvicorn.Server(config)
await server.serve()
[docs]
async def run_app(service, listen_address: str, port: int = None) -> FastAPI:
app.state.service = service
app.state.inference_timeout = service.explorer.config.synchronizer.sync_timeout
print(f"API server running on {listen_address}:{port}")
await serve_http(app, listen_address, port)