Source code for trinity.common.models.mm_utils
from typing import Any, Dict, Optional
[docs]
def build_multi_modal_inputs(
    prompt: str,
    raw_mm_data: Optional[Dict[str, Any]],
    processor: Any,
    **kwargs,
) -> Dict[str, Any]:
    """
    Preprocess multi-modal data and build multi-modal inputs.

    Adapted from: verl/utils/dataset/rl_dataset.py

    Args:
        prompt: Prompt text. It is handed to ``processor`` wrapped in a
            single-element list.
        raw_mm_data: Mapping that may carry an ``"image"`` and/or ``"video"``
            entry, each an iterable of raw items. ``None``, a missing key, or
            a key whose value is ``None`` all mean "no data of that modality".
        processor: Callable invoked as ``processor(text=[prompt],
            images=..., videos=..., return_tensors="pt")``. Its result must
            support ``pop(key, default)`` and conversion through ``dict()``.
        **kwargs: Accepted for call-site compatibility; currently unused.

    Returns:
        A dict with three keys:

        * ``"prompt"``: the prompt, unchanged.
        * ``"multi_modal_inputs"``: the processor output as a plain ``dict``
          with ``"input_ids"`` and ``"attention_mask"`` removed.
        * ``"multi_modal_data"``: processed images under ``"image"`` and
          processed videos (converted with ``.numpy()``) under ``"video"``,
          for whichever modalities were supplied.

    Raises:
        ValueError: If ``prompt`` is ``None``.
    """
    if prompt is None:
        raise ValueError("Prompt is required for build multi-modal inputs")

    # Tolerate a missing container, mirroring attach_images_to_messages.
    mm_source: Dict[str, Any] = {} if raw_mm_data is None else raw_mm_data
    raw_images = mm_source.get("image")
    raw_videos = mm_source.get("video")

    multi_modal_data: Dict[str, Any] = {}
    images, videos = None, None

    if raw_images is not None:
        # Imported lazily so text-only prompts do not need verl's vision helpers.
        from verl.utils.dataset.vision_utils import process_image

        images = [process_image(image) for image in raw_images]
        multi_modal_data["image"] = images

    if raw_videos is not None:
        from verl.utils.dataset.vision_utils import process_video

        videos = [process_video(video) for video in raw_videos]
        # The processor receives the processed videos as-is; the stored copy
        # is converted with .numpy() (presumably from torch tensors - confirm
        # against verl's process_video).
        multi_modal_data["video"] = [video.numpy() for video in videos]

    model_inputs = processor(text=[prompt], images=images, videos=videos, return_tensors="pt")
    model_inputs.pop("input_ids", None)  # TODO: check
    model_inputs.pop("attention_mask", None)

    # There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature
    multi_modal_inputs = dict(model_inputs)

    return {
        "prompt": prompt,
        "multi_modal_inputs": multi_modal_inputs,
        "multi_modal_data": multi_modal_data,
    }
[docs]
def attach_images_to_messages(messages, raw_mm_data):
    """
    Return a copy of ``messages`` with images appended to the last user turn.

    Each message is shallow-copied, so the input list and its message dicts
    are not modified. When ``raw_mm_data`` is falsy or has no non-empty
    ``"image"`` entry, or when no message has ``role == "user"``, the copies
    are returned unchanged.

    Otherwise the content of the last user message is rebuilt as a list:

    * string content (or string entries of list content) has the
      ``<image>`` and ``<|image_pad|>`` placeholders removed and is
      whitespace-stripped; non-empty results become
      ``{"type": "text", "text": ...}`` items;
    * dict entries of list content are kept as-is;
    * anything else is dropped;
    * one ``{"type": "image", "image": img}`` item per image is appended.
    """
    copied = [dict(message) for message in messages]
    image_list = (raw_mm_data or {}).get("image") or []
    if not image_list:
        return copied

    def _without_placeholders(raw_text):
        # Drop image placeholder tokens, then trim surrounding whitespace.
        return raw_text.replace("<image>", "").replace("<|image_pad|>", "").strip()

    # Only the most recent user message receives the images.
    target = next(
        (message for message in reversed(copied) if message.get("role") == "user"),
        None,
    )
    if target is None:
        return copied

    content = target.get("content", "")
    if isinstance(content, str):
        fragments = [content]
    elif isinstance(content, list):
        fragments = content
    else:
        fragments = []

    parts = []
    for fragment in fragments:
        if isinstance(fragment, str):
            cleaned = _without_placeholders(fragment)
            if cleaned:
                parts.append({"type": "text", "text": cleaned})
        elif isinstance(fragment, dict):
            parts.append(fragment)

    parts.extend({"type": "image", "image": picture} for picture in image_list)
    target["content"] = parts
    return copied