import copy
from typing import Dict, Optional
from pydantic import PositiveInt
from data_juicer.utils.constant import Fields
from data_juicer.utils.lazy_loader import AUTOINSTALL
from data_juicer.utils.mm_utils import SpecialTokens, remove_special_tokens
from data_juicer.utils.model_utils import get_model, prepare_model
from ..base_op import OPERATORS, Mapper
NAME = 'video_captioning_from_summarizer_mapper'
@OPERATORS.register_module(NAME)
class VideoCaptioningFromSummarizerMapper(Mapper):
"""
Mapper to generate video captions by summarizing several kinds of generated
texts (captions from video/audio/frames, tags from audio/frames, ...)
"""
_accelerator = 'cuda'
_batched_op = True
def __init__(self,
                 hf_summarizer: Optional[str] = None,
trust_remote_code: bool = False,
consider_video_caption_from_video: bool = True,
consider_video_caption_from_audio: bool = True,
consider_video_caption_from_frames: bool = True,
consider_video_tags_from_audio: bool = True,
consider_video_tags_from_frames: bool = True,
vid_cap_from_vid_args: Optional[Dict] = None,
vid_cap_from_frm_args: Optional[Dict] = None,
vid_tag_from_aud_args: Optional[Dict] = None,
vid_tag_from_frm_args: Optional[Dict] = None,
keep_tag_num: PositiveInt = 5,
keep_original_sample: bool = True,
*args,
**kwargs):
"""
Initialization method.
        :param hf_summarizer: the summarizer model used to summarize texts
            generated by other methods.
        :param trust_remote_code: whether to trust the remote code of the
            summarizer model from Hugging Face.
        :param consider_video_caption_from_video: whether to consider the video
            caption generated directly from the video in the summarization
            process. Default: True.
        :param consider_video_caption_from_audio: whether to consider the video
            caption generated from the audio streams in the video in the
            summarization process. Default: True.
        :param consider_video_caption_from_frames: whether to consider the
            video caption generated from sampled frames of the video in the
            summarization process. Default: True.
        :param consider_video_tags_from_audio: whether to consider the video
            tags generated from the audio streams in the video in the
            summarization process. Default: True.
        :param consider_video_tags_from_frames: whether to consider the video
            tags generated from sampled frames of the video in the
            summarization process. Default: True.
        :param vid_cap_from_vid_args: the arg dict for captioning the video
            directly, whose keys are the arg names and values are the arg
            values. Default: None.
        :param vid_cap_from_frm_args: the arg dict for captioning from sampled
            video frames, whose keys are the arg names and values are the arg
            values. Default: None.
        :param vid_tag_from_aud_args: the arg dict for tagging from the audio
            streams in the video, whose keys are the arg names and values are
            the arg values. Default: None.
        :param vid_tag_from_frm_args: the arg dict for tagging from sampled
            video frames, whose keys are the arg names and values are the arg
            values. Default: None.
        :param keep_tag_num: max number N of tags from sampled frames to keep.
            Too many tags might negatively influence the summarized text, so
            only the N most frequent tags are kept. Default: 5.
        :param keep_original_sample: whether to keep the original sample. If
            set to False, only the summarized captions are kept in the final
            dataset and the original captions are removed. Default: True.
:param args: extra args
:param kwargs: extra args
"""
kwargs.setdefault('mem_required', '40GB')
super().__init__(*args, **kwargs)
AUTOINSTALL.check([
'torch',
'transformers',
'transformers_stream_generator',
'einops',
'accelerate',
'tiktoken', # by audio caption
'torchaudio', # by audio tag
])
self.keep_original_sample = keep_original_sample
self.extra_args = kwargs
# prepare summarizer
self._hf_summarizer = hf_summarizer if hf_summarizer else 'mrm8488/flan-t5-large-finetuned-openai-summarize_from_feedback' # noqa: E501
self.model_key = prepare_model(
model_type='huggingface',
pretrained_model_name_or_path=self._hf_summarizer,
trust_remote_code=trust_remote_code)
# prepare input texts ops
if vid_cap_from_vid_args is None:
vid_cap_from_vid_args = {}
if vid_cap_from_frm_args is None:
vid_cap_from_frm_args = {}
if vid_tag_from_aud_args is None:
vid_tag_from_aud_args = {}
if vid_tag_from_frm_args is None:
vid_tag_from_frm_args = {}
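        # args forced onto every helper op (later filtered to those each op
        # accepts): produce a single caption candidate and drop the original
        # sample, so only the generated texts are kept for summarization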
self.FIXED_ARGS = {
'caption_num': 1,
'keep_candidate_mode': 'random_any',
'keep_original_sample': False,
}
self.cap_op_list = []
self.tag_op_list = []
if consider_video_caption_from_video:
from .video_captioning_from_video_mapper import \
VideoCaptioningFromVideoMapper
self.cap_op_list.append(
VideoCaptioningFromVideoMapper(**self._prepare_op_args(
VideoCaptioningFromVideoMapper, vid_cap_from_vid_args)))
if consider_video_caption_from_audio:
from .video_captioning_from_audio_mapper import \
VideoCaptioningFromAudioMapper
self.cap_op_list.append(
VideoCaptioningFromAudioMapper(**self._prepare_op_args(
VideoCaptioningFromAudioMapper, {})))
if consider_video_caption_from_frames:
from .video_captioning_from_frames_mapper import \
VideoCaptioningFromFramesMapper
self.cap_op_list.append(
VideoCaptioningFromFramesMapper(**self._prepare_op_args(
VideoCaptioningFromFramesMapper, vid_cap_from_frm_args)))
if consider_video_tags_from_audio:
from .video_tagging_from_audio_mapper import \
VideoTaggingFromAudioMapper
self.tag_op_list.append(
VideoTaggingFromAudioMapper(**self._prepare_op_args(
VideoTaggingFromAudioMapper, vid_tag_from_aud_args)))
if consider_video_tags_from_frames:
from .video_tagging_from_frames_mapper import \
VideoTaggingFromFramesMapper
self.tag_op_list.append(
VideoTaggingFromFramesMapper(**self._prepare_op_args(
VideoTaggingFromFramesMapper, vid_tag_from_frm_args)))
self.keep_tag_num = keep_tag_num
def _prepare_op_args(self, op_class, args_dict):
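        # merge the fixed args into args_dict, drop any keys that op_class's
        # __init__ does not accept, and align the accelerator with this op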
required_args = set(op_class.__init__.__code__.co_varnames)
args_dict.update(self.FIXED_ARGS)
temp_args = copy.deepcopy(args_dict)
for key in temp_args:
if key not in required_args:
args_dict.pop(key)
args_dict['accelerator'] = self.accelerator
return args_dict
def _process_single_sample(self, sample, rank=None):
# there is no video in this sample
if self.video_key not in sample or not sample[self.video_key]:
return []
        # there are no activated ops
if len(self.cap_op_list) == 0 and len(self.tag_op_list) == 0:
return []
# get paths of all video(s)
loaded_video_keys = sample[self.video_key]
# get models
model, tokenizer = get_model(self.model_key, rank, self.use_cuda())
captioned_sample = copy.deepcopy(sample)
# generate for each video chunk by chunk
captioned_texts = ''
offset = 0
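        # offset tracks how many entries of loaded_video_keys have already
        # been consumed by previous chunks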
for chunk in sample[self.text_key].split(SpecialTokens.eoc):
# skip empty chunks
if not chunk.strip():
continue
vid_count = chunk.count(SpecialTokens.video)
if vid_count == 0:
# add special tokens
captioned_texts += f'{chunk}{SpecialTokens.eoc}'
continue
# make a temporary sample
temp_sample = {
self.text_key: chunk,
self.video_key: loaded_video_keys[offset:offset + vid_count],
}
captioned_text_list = []
# tag ops
for op in self.tag_op_list:
temp_sample = op.process(temp_sample, rank=rank)
if Fields.video_audio_tags in temp_sample:
captioned_text_list.extend(
temp_sample[Fields.video_audio_tags])
if Fields.video_frame_tags in temp_sample:
for tag_list in temp_sample[Fields.video_frame_tags]:
                        captioned_text_list.extend(
                            tag_list[:self.keep_tag_num])
# cap ops
for op in self.cap_op_list:
captioned_text_list.append(
remove_special_tokens(
op._process_single_sample(temp_sample,
rank=rank)[0]['text']))
# summarization
all_texts = ', '.join(captioned_text_list)
input_ids = tokenizer(all_texts, return_tensors='pt').input_ids.to(
model.device)
outputs = model.generate(input_ids, max_new_tokens=128)
summarized_text = tokenizer.decode(outputs[0],
skip_special_tokens=True)
offset += vid_count
captioned_text = f'{SpecialTokens.video * vid_count} ' \
f'{summarized_text}'
# add special tokens
captioned_texts += f'{captioned_text}{SpecialTokens.eoc}'
captioned_sample[self.text_key] = captioned_texts
return [captioned_sample]
def process_batched(self, samples, rank=None):
# reconstruct samples from "dict of lists" to "list of dicts"
reconstructed_samples = []
for i in range(len(samples[self.text_key])):
reconstructed_samples.append(
{key: samples[key][i]
for key in samples})
samples_after_split = []
# do split for each sample within the batch
for ori_sample in reconstructed_samples:
if self.keep_original_sample:
samples_after_split.append(ori_sample)
generated_samples = self._process_single_sample(ori_sample,
rank=rank)
if len(generated_samples) != 0:
samples_after_split.extend(generated_samples)
# reconstruct samples from "list of dicts" to "dict of lists"
keys = samples_after_split[0].keys()
res_samples = {}
for key in keys:
res_samples[key] = [s[key] for s in samples_after_split]
return res_samples
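

# A minimal usage sketch: it shows how `process_batched` consumes a batch laid
# out as a dict of lists, with one video special token per video path and each
# chunk terminated by the end-of-chunk token. The video path below is
# hypothetical, and running this requires the summarizer and the enabled
# helper models to be available locally or downloadable.
if __name__ == '__main__':
    op = VideoCaptioningFromSummarizerMapper(keep_original_sample=False)
    batch = {
        op.text_key: [
            f'{SpecialTokens.video} a cat chases a ball{SpecialTokens.eoc}'
        ],
        op.video_key: [['path/to/video1.mp4']],
    }
    res = op.process_batched(batch)
    print(res[op.text_key])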