import json
import os
import random
from typing import Dict, Optional
from PIL import Image
import data_juicer
from data_juicer.ops.load import load_ops
from data_juicer.utils.cache_utils import DATA_JUICER_ASSETS_CACHE
from data_juicer.utils.constant import Fields
from ..base_op import OPERATORS, TAGGING_OPS, UNFORKABLE, Mapper
from ..op_fusion import LOADED_IMAGES
OP_NAME = "detect_character_attributes_mapper"
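# Register the op in data-juicer's registries; UNFORKABLE marks ops whose loaded
# models cannot be safely forked across worker processes.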
@UNFORKABLE.register_module(OP_NAME)
@TAGGING_OPS.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class DetectCharacterAttributesMapper(Mapper):
"""Takes an image, a caption, and main character names as input to extract the characters' attributes."""
_accelerator = "cuda"
    def __init__(
        self,
        detect_character_locations_mapper_args: Optional[Dict] = None,
        *args,
        **kwargs,
    ):
        """
        Initialization method.

        :param detect_character_locations_mapper_args: Arguments for the fused
            detect_character_locations_mapper, which controls the thresholds used
            to locate the main characters. If None or empty (the default), fixed
            values are used: the default mllm_mapper_args, the default
            image_text_matching_filter_args, yoloe_path="yoloe-11l-seg.pt",
            iou_threshold=0.7, and matching_score_threshold=0.4.
        """
super().__init__(*args, **kwargs)
self.FIXED_ARGS = {}
self.FIXED_ARGS["detect_character_locations_mapper"] = {
"mllm_mapper_args": {
"max_new_tokens": 256,
"temperature": 0.2,
"top_p": None,
"num_beams": 1,
"hf_model": "llava-hf/llava-v1.6-vicuna-7b-hf",
},
"image_text_matching_filter_args": {
"min_score": 0,
"max_score": 1.0,
"hf_blip": "Salesforce/blip-itm-base-coco",
"num_proc": 1,
},
"yoloe_path": "yoloe-11l-seg.pt",
"iou_threshold": 0.7,
"matching_score_threshold": 0.4,
}
self.detect_character_locations_mapper_args = self._prepare_op_args(
"detect_character_locations_mapper", detect_character_locations_mapper_args
)
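        # Wrap the location mapper as a fused op so its sub-ops (MLLM captioner,
        # image-text matching filter, YOLOE detector) are loaded once up front.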
self.fused_op_list = [{"detect_character_locations_mapper": self.detect_character_locations_mapper_args}]
self.fused_ops = load_ops(self.fused_op_list)
        accelerator_methods = {op.accelerator for op in self.fused_ops}
        if "cuda" in accelerator_methods:
            self.accelerator = "cuda"
        # update num_proc with the min num_proc of all fused ops
        self.num_proc = min(op.runtime_np() for op in self.fused_ops) if self.fused_ops else 1
    def _prepare_op_args(self, op_name, args_dict):
        # Work on a copy so the caller's dict is never mutated, and fill in any
        # missing fixed defaults.
        args_dict = dict(args_dict or {})
        for key, value in self.FIXED_ARGS[op_name].items():
            args_dict.setdefault(key, value)
        args_dict["accelerator"] = self.accelerator
        return args_dict
def process_single(self, samples, rank=None):
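        """Extract and verify attribute lists for each main character in one sample."""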
if Fields.meta not in samples:
samples[Fields.meta] = {}
detect_location_dataset = data_juicer.core.NestedDataset.from_list(
[{"main_character_list": samples["main_character_list"], "images": samples["images"]}]
)
character_locations = detect_location_dataset.map(
self.fused_ops[0].process, num_proc=1, with_rank=True
).to_list()
character_locations = character_locations[0][Fields.meta]["main_character_locations_list"]
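        # Each entry pairs a character name with its bounding box, e.g.
        # {"main_character": "horse", "bbox": [x1, y1, x2, y2]}.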
character_to_characteristics = {}
character_to_cls = {}
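        # First pass: query the MLLM for each character's coarse class and extract
        # its caption-described attributes.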
for temp_character in samples["main_character_list"]:
# detect class
prompt = (
'Please classify the character "'
+ temp_character
+ "\" into the following categories: ['object', 'animal', 'person', 'text', 'other']. Only reply with the most fitting single category."
)
mllm_sample = {"text": prompt, "images": samples["images"]}
output_text = self.fused_ops[0].fused_ops[0].process(mllm_sample)["text"][0].split("ASSISTANT:")[-1].strip()
character_to_cls[temp_character] = output_text
# detect feature
prompt = (
'I will provide you with the corresponding description of an image, as follows: "'
+ samples["text"]
+ "\" Please extract all descriptions of the features related to '"
+ temp_character
+ '\' from this text, which may include color, material, action, and other typical features, and compile them into a list of phrase string. Formatted like: ["in a blue shirt", "sitting on a nearby fence", "with flame decals"]. Return only the phrase string list.'
)
mllm_sample = {"text": prompt, "images": samples["images"]}
output_text = self.fused_ops[0].fused_ops[0].process(mllm_sample)["text"][0].split("ASSISTANT:")[-1].strip()
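            # The model is asked for a JSON-style list; fall back to wrapping the
            # raw reply if it cannot be parsed.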
try:
character_to_characteristics[temp_character] = json.loads(output_text)
except json.JSONDecodeError:
character_to_characteristics[temp_character] = [output_text]
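        # Second pass: crop each located character and let the MLLM expand and
        # verify its attribute list against the cropped region.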
image = Image.open(samples["images"][0])
valid_character_in_bbox_dict = {}
for temp_character_with_bbox_idx, temp_character_with_bbox in enumerate(character_locations):
crop_img = image.crop(temp_character_with_bbox["bbox"])
            # The random suffix reduces, but does not eliminate, name collisions
            # between concurrent workers writing to the shared cache.
            cache_img_name = (
                "temp_"
                + str(random.randint(0, 9999))
                + "_"
                + str(temp_character_with_bbox_idx)
                + "_"
                + os.path.basename(samples["images"][0])
            )
            cache_img_path = os.path.join(DATA_JUICER_ASSETS_CACHE, cache_img_name)
crop_img.save(cache_img_path)
            try:
                temp_character_cls = character_to_cls[temp_character_with_bbox["main_character"]]
            except KeyError:
                # The located character is not among the main characters; skip it.
                os.remove(cache_img_path)
                continue
if "object" in temp_character_cls:
prompt = (
"Please analyze the key characteristics of the main object in this image, specifically the '"
+ temp_character_with_bbox["main_character"]
+ "', which may include color, material, shape, and other typical features. Currently identified characteristics include \""
+ str(temp_character_cls)
+ '". Please expand this list and respond in an identically formatted phrase string list.'
)
mllm_sample = {"text": prompt, "images": [cache_img_path]}
output_text = (
self.fused_ops[0].fused_ops[0].process(mllm_sample)["text"][0].split("ASSISTANT:")[-1].strip()
)
elif "animal" in temp_character_cls:
prompt = (
"Please analyze the key characteristics of the primary animal in this image, specifically the '"
+ temp_character_with_bbox["main_character"]
+ "', which may include color, action, and other typical features. Currently identified characteristics include \""
+ str(temp_character_cls)
+ '". Please expand this list and respond in an identically formatted phrase string list.'
)
mllm_sample = {"text": prompt, "images": [cache_img_path]}
output_text = (
self.fused_ops[0].fused_ops[0].process(mllm_sample)["text"][0].split("ASSISTANT:")[-1].strip()
)
elif "person" in temp_character_cls:
prompt = (
"Please analyze the key characteristics of the primary person in this image, specifically the '"
+ temp_character_with_bbox["main_character"]
+ "', which may include clothing, ages, and other typical features. Currently identified characteristics include \""
+ str(temp_character_cls)
+ '". Please expand this list and respond in an identically formatted phrase string list.'
)
mllm_sample = {"text": prompt, "images": [cache_img_path]}
output_text = (
self.fused_ops[0].fused_ops[0].process(mllm_sample)["text"][0].split("ASSISTANT:")[-1].strip()
)
else:
prompt = (
"Please analyze the key characteristics of the primary character in this image, specifically the '"
+ temp_character_with_bbox["main_character"]
+ "'. Currently identified characteristics include \""
+ str(temp_character_cls)
+ '". Please expand this list and respond in an identically formatted phrase string list.'
)
mllm_sample = {"text": prompt, "images": [cache_img_path]}
output_text = (
self.fused_ops[0].fused_ops[0].process(mllm_sample)["text"][0].split("ASSISTANT:")[-1].strip()
)
final_characteristic_list = []
            # Filter step: keep only the attributes the MLLM confirms against the
            # cropped region.
            try:
                characteristic_list = json.loads(output_text)
            except json.JSONDecodeError:
                characteristic_list = output_text
            # Normalize the reply into a list of candidate phrases, whether the
            # model returned a JSON list, a newline-separated string, or a single
            # comma-separated string.
            if not isinstance(characteristic_list, list):
                characteristic_list = output_text.split("\n")
            if len(characteristic_list) == 1:
                characteristic_list = characteristic_list[0].replace("_", " ").split(", ")
            try:
                for temp_characteristic in characteristic_list:
                    prompt = (
                        'Please analyze the main character in this image, specifically the "'
                        + temp_character_with_bbox["main_character"]
                        + '". Is "'
                        + temp_characteristic
                        + "\" one of its features? Only respond with 'yes' if it is a perfect match. Please only respond with 'yes' or 'no'."
                    )
                    mllm_sample = {"text": prompt, "images": [cache_img_path]}
                    output_text = (
                        self.fused_ops[0]
                        .fused_ops[0]
                        .process(mllm_sample)["text"][0]
                        .split("ASSISTANT:")[-1]
                        .strip()
                    )
                    # case-insensitive match: the model may reply "Yes"
                    if "yes" in output_text.lower():
                        final_characteristic_list.append(temp_characteristic)
            except Exception:
                os.remove(cache_img_path)
                continue
            valid_character_in_bbox_dict[temp_character_with_bbox["main_character"]] = {
                "bbox": temp_character_with_bbox["bbox"],
                "final_characteristic_list": final_characteristic_list,
            }
os.remove(cache_img_path)
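        # Assemble the final per-character records stored in the sample's meta.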
new_character_list = []
        for temp_character in samples["main_character_list"]:
            temp_character_json = {"main_character": temp_character}
            if temp_character in valid_character_in_bbox_dict:
                temp_character_json["bbox"] = valid_character_in_bbox_dict[temp_character]["bbox"]
                verified_list = valid_character_in_bbox_dict[temp_character]["final_characteristic_list"]
                # Fall back to the caption-derived attributes if verification kept nothing.
                temp_character_json["characteristic_list"] = (
                    verified_list if verified_list else character_to_characteristics[temp_character]
                )
            else:
                temp_character_json["bbox"] = []
                temp_character_json["characteristic_list"] = character_to_characteristics[temp_character]
            new_character_list.append(temp_character_json)
samples[Fields.meta]["main_character_attributes_list"] = new_character_list
return samples
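

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The sample below is hypothetical:
# the image path, caption, and character names are placeholders, and running
# this loads the MLLM, BLIP, and YOLOE checkpoints named in FIXED_ARGS, which
# requires a CUDA-capable machine.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    op = DetectCharacterAttributesMapper()
    sample = {
        "text": "A man in a blue shirt rides a brown horse.",  # hypothetical caption
        "images": ["/path/to/example.jpg"],  # hypothetical image path
        "main_character_list": ["man", "horse"],
    }
    result = op.process_single(sample)
    print(result[Fields.meta]["main_character_attributes_list"])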