# Source code for data_juicer.ops.mapper.video_captioning_from_summarizer_mapper

import copy
from typing import Dict, Optional

from pydantic import PositiveInt

from data_juicer.utils.constant import Fields
from data_juicer.utils.lazy_loader import AUTOINSTALL
from data_juicer.utils.mm_utils import SpecialTokens, remove_special_tokens
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Mapper

NAME = 'video_captioning_from_summarizer_mapper'


@OPERATORS.register_module(NAME)
class VideoCaptioningFromSummarizerMapper(Mapper):
    """
    Mapper to generate video captions by summarizing several kinds of
    generated texts (captions from video/audio/frames, tags from audio/frames,
    ...)
    """

    _accelerator = 'cuda'
    _batched_op = True

    def __init__(self,
                 hf_summarizer: Optional[str] = None,
                 trust_remote_code: bool = False,
                 consider_video_caption_from_video: bool = True,
                 consider_video_caption_from_audio: bool = True,
                 consider_video_caption_from_frames: bool = True,
                 consider_video_tags_from_audio: bool = True,
                 consider_video_tags_from_frames: bool = True,
                 vid_cap_from_vid_args: Optional[Dict] = None,
                 vid_cap_from_frm_args: Optional[Dict] = None,
                 vid_tag_from_aud_args: Optional[Dict] = None,
                 vid_tag_from_frm_args: Optional[Dict] = None,
                 keep_tag_num: PositiveInt = 5,
                 keep_original_sample: bool = True,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param hf_summarizer: the summarizer model used to summarize texts
            generated by other methods. Defaults to
            'mrm8488/flan-t5-large-finetuned-openai-summarize_from_feedback'
            when not given.
        :param trust_remote_code: passed through to `prepare_model` when
            loading the summarizer.
        :param consider_video_caption_from_video: whether to consider the
            video caption generated from video directly in the summarization
            process. Default: True.
        :param consider_video_caption_from_audio: whether to consider the
            video caption generated from audio streams in the video in the
            summarization process. Default: True. This captioner has no
            args dict and is built only from the fixed args.
        :param consider_video_caption_from_frames: whether to consider the
            video caption generated from sampled frames from the video in the
            summarization process. Default: True.
        :param consider_video_tags_from_audio: whether to consider the video
            tags generated from audio streams in the video in the
            summarization process. Default: True.
        :param consider_video_tags_from_frames: whether to consider the video
            tags generated from sampled frames from the video in the
            summarization process. Default: True.
        :param vid_cap_from_vid_args: the arg dict for video captioning from
            video directly, whose keys are the arg names and whose values are
            the arg values. Default: None.
        :param vid_cap_from_frm_args: the arg dict for video captioning from
            sampled frames from the video, whose keys are the arg names and
            whose values are the arg values. Default: None.
        :param vid_tag_from_aud_args: the arg dict for video tagging from
            audio streams in the video, whose keys are the arg names and whose
            values are the arg values. Default: None.
        :param vid_tag_from_frm_args: the arg dict for video tagging from
            sampled frames from the video, whose keys are the arg names and
            whose values are the arg values. Default: None.
        :param keep_tag_num: max number N of tags from sampled frames to keep
            per video. Too many tags might have a negative influence on the
            summarized text, so only the first N tags of each video's frame
            tag list are kept. Default: 5.
        :param keep_original_sample: whether to keep the original sample. If
            it's set to False, there will be only summarized captions in the
            final datasets and the original captions will be removed. It's
            True by default.
        :param args: extra args
        :param kwargs: extra args
        """
        super().__init__(*args, **kwargs)
        AUTOINSTALL.check([
            'torch',
            'transformers',
            'transformers_stream_generator',
            'einops',
            'accelerate',
            'tiktoken',  # by audio caption
            'torchaudio',  # by audio tag
        ])

        self.keep_original_sample = keep_original_sample
        self.extra_args = kwargs

        # prepare summarizer
        self._hf_summarizer = hf_summarizer if hf_summarizer else 'mrm8488/flan-t5-large-finetuned-openai-summarize_from_feedback'  # noqa: E501
        self.model_key = prepare_model(
            model_type='huggingface',
            pretrained_model_name_or_path=self._hf_summarizer,
            trust_remote_code=trust_remote_code)

        # prepare input texts ops
        if vid_cap_from_vid_args is None:
            vid_cap_from_vid_args = {}
        if vid_cap_from_frm_args is None:
            vid_cap_from_frm_args = {}
        if vid_tag_from_aud_args is None:
            vid_tag_from_aud_args = {}
        if vid_tag_from_frm_args is None:
            vid_tag_from_frm_args = {}
        # args forced onto every sub-op: they override user-given values in
        # _prepare_op_args so each sub-op yields exactly one new sample
        self.FIXED_ARGS = {
            'caption_num': 1,
            'keep_candidate_mode': 'random_any',
            'keep_original_sample': False,
        }
        self.cap_op_list = []
        self.tag_op_list = []
        # sub-ops are imported lazily so that disabled ones are never loaded
        if consider_video_caption_from_video:
            from .video_captioning_from_video_mapper import \
                VideoCaptioningFromVideoMapper
            self.cap_op_list.append(
                VideoCaptioningFromVideoMapper(**self._prepare_op_args(
                    VideoCaptioningFromVideoMapper, vid_cap_from_vid_args)))
        if consider_video_caption_from_audio:
            from .video_captioning_from_audio_mapper import \
                VideoCaptioningFromAudioMapper
            self.cap_op_list.append(
                VideoCaptioningFromAudioMapper(**self._prepare_op_args(
                    VideoCaptioningFromAudioMapper, {})))
        if consider_video_caption_from_frames:
            from .video_captioning_from_frames_mapper import \
                VideoCaptioningFromFramesMapper
            self.cap_op_list.append(
                VideoCaptioningFromFramesMapper(**self._prepare_op_args(
                    VideoCaptioningFromFramesMapper, vid_cap_from_frm_args)))
        if consider_video_tags_from_audio:
            from .video_tagging_from_audio_mapper import \
                VideoTaggingFromAudioMapper
            self.tag_op_list.append(
                VideoTaggingFromAudioMapper(**self._prepare_op_args(
                    VideoTaggingFromAudioMapper, vid_tag_from_aud_args)))
        if consider_video_tags_from_frames:
            from .video_tagging_from_frames_mapper import \
                VideoTaggingFromFramesMapper
            self.tag_op_list.append(
                VideoTaggingFromFramesMapper(**self._prepare_op_args(
                    VideoTaggingFromFramesMapper, vid_tag_from_frm_args)))

        self.keep_tag_num = keep_tag_num

    def _prepare_op_args(self, op_class, args_dict):
        """
        Build the constructor kwargs for a sub-op.

        FIXED_ARGS are merged over `args_dict` (FIXED_ARGS win). Names not
        found in `op_class.__init__.__code__.co_varnames` are dropped, and
        `accelerator` is set to this op's accelerator. `args_dict` itself is
        left untouched, so a dict supplied by the caller is not mutated.

        Note: `co_varnames` also lists local variables of `__init__`, so this
        filter is permissive rather than an exact parameter check.

        :param op_class: the sub-op class to be instantiated.
        :param args_dict: user-given args for that sub-op.
        :return: a new dict of kwargs for `op_class`.
        """
        accepted_names = set(op_class.__init__.__code__.co_varnames)
        merged_args = {**args_dict, **self.FIXED_ARGS}
        op_args = {
            name: value
            for name, value in merged_args.items() if name in accepted_names
        }
        op_args['accelerator'] = self.accelerator
        return op_args

    def _process_single_sample(self, sample, rank=None):
        """
        Summarize the texts generated by the sub-ops for one sample.

        The sample text is split on `SpecialTokens.eoc`. Whitespace-only
        chunks are dropped, and chunks without video tokens are kept
        verbatim. Each chunk with k video tokens consumes the next k videos.
        It is replaced by those k video tokens followed by the summarizer
        output for the tags and captions generated for these videos.

        :param sample: a single sample dict.
        :param rank: device rank used to fetch the summarizer.
        :return: [] if the sample has no video or no sub-op is enabled,
            otherwise a one-element list holding the new captioned sample.
        """
        # there is no video in this sample
        if self.video_key not in sample or not sample[self.video_key]:
            return []

        # there is no activated ops
        if not self.cap_op_list and not self.tag_op_list:
            return []

        # get paths of all video(s)
        loaded_video_keys = sample[self.video_key]

        # get models
        model, tokenizer = get_model(self.model_key, rank, self.use_cuda())

        captioned_sample = copy.deepcopy(sample)
        # generate for each video chunk by chunk
        captioned_chunks = []
        offset = 0
        for chunk in sample[self.text_key].split(SpecialTokens.eoc):
            # skip empty chunks
            if not chunk.strip():
                continue

            vid_count = chunk.count(SpecialTokens.video)
            if vid_count == 0:
                # no video here: keep the chunk and re-add its eoc token
                captioned_chunks.append(f'{chunk}{SpecialTokens.eoc}')
                continue

            # make a temporary sample with only this chunk and its videos
            temp_sample = {
                self.text_key: chunk,
                self.video_key: loaded_video_keys[offset:offset + vid_count],
            }

            captioned_text_list = []
            # tag ops: each tag field is consumed at most once. temp_sample
            # is threaded through the taggers, so a later tagger's result
            # may still carry an earlier tagger's field. Re-reading it would
            # duplicate those tags.
            consumed_tag_fields = set()
            for op in self.tag_op_list:
                temp_sample = op.process(temp_sample, rank=rank)
                if (Fields.video_audio_tags in temp_sample
                        and Fields.video_audio_tags not in consumed_tag_fields):
                    consumed_tag_fields.add(Fields.video_audio_tags)
                    captioned_text_list.extend(
                        temp_sample[Fields.video_audio_tags])
                if (Fields.video_frame_tags in temp_sample
                        and Fields.video_frame_tags not in consumed_tag_fields):
                    consumed_tag_fields.add(Fields.video_frame_tags)
                    for tag_list in temp_sample[Fields.video_frame_tags]:
                        # keep only the first N tags of each video (slice,
                        # not index: indexing would pick a single tag)
                        captioned_text_list.extend(
                            tag_list[:self.keep_tag_num])

            # cap ops
            for op in self.cap_op_list:
                generated = op._process_single_sample(temp_sample, rank=rank)
                captioned_text_list.append(
                    remove_special_tokens(generated[0]['text']))

            # summarization
            all_texts = ', '.join(captioned_text_list)
            input_ids = tokenizer(all_texts, return_tensors='pt').input_ids.to(
                model.device)
            outputs = model.generate(input_ids, max_new_tokens=128)
            summarized_text = tokenizer.decode(outputs[0],
                                               skip_special_tokens=True)

            offset += vid_count
            captioned_text = f'{SpecialTokens.video * vid_count} ' \
                             f'{summarized_text}'
            # add special tokens
            captioned_chunks.append(f'{captioned_text}{SpecialTokens.eoc}')

        captioned_sample[self.text_key] = ''.join(captioned_chunks)
        return [captioned_sample]

    def process_batched(self, samples, rank=None):
        """
        Process a batch given as a dict of columns.

        Each original row is kept only when `keep_original_sample` is True.
        The rows generated by `_process_single_sample` are appended after it.

        :param samples: a batch as a mapping from column name to list.
        :param rank: device rank passed to the per-sample processing.
        :return: the output batch as a mapping from column name to list.
            All columns are empty when no row is kept or generated.
        """
        # reconstruct samples from "dict of lists" to "list of dicts"
        num_rows = len(samples[self.text_key])
        reconstructed_samples = [{key: samples[key][i]
                                  for key in samples}
                                 for i in range(num_rows)]

        samples_after_split = []
        # do split for each sample within the batch
        for ori_sample in reconstructed_samples:
            if self.keep_original_sample:
                samples_after_split.append(ori_sample)
            samples_after_split.extend(
                self._process_single_sample(ori_sample, rank=rank))

        # nothing kept or generated: return empty columns instead of failing
        # on samples_after_split[0]
        if not samples_after_split:
            return {key: [] for key in samples}

        # reconstruct samples from "list of dicts" to "dict of lists"
        keys = samples_after_split[0].keys()
        return {key: [s[key] for s in samples_after_split] for key in keys}