import base64
import datetime
import io
import os
import re
import shutil
from typing import List, Optional, Tuple, Union
import av
import numpy as np
from datasets import Audio, Image
from loguru import logger
from pydantic import PositiveInt
from data_juicer.utils.constant import DEFAULT_PREFIX, Fields
from data_juicer.utils.file_utils import add_suffix_to_filename
from data_juicer.utils.lazy_loader import LazyLoader
cv2 = LazyLoader('cv2', 'cv2')
# suppress most warnings from av
av.logging.set_level(av.logging.PANIC)
# A class to keep special tokens for multimodal information in texts.
# The tokens in this class can be updated by the corresponding arguments in
# the config.
class SpecialTokens(object):
# modality
image = f'<{DEFAULT_PREFIX}image>'
audio = f'<{DEFAULT_PREFIX}audio>'
video = f'<{DEFAULT_PREFIX}video>'
# others
eoc = f'<|{DEFAULT_PREFIX}eoc|>'
AV_STREAM_THREAD_TYPE = 'AUTO'
"""
av stream thread type support "SLICE", "FRAME", "AUTO".
"SLICE": Decode more than one part of a single frame at once
"FRAME": Decode more than one frame at once
"AUTO": Using both "FRAME" and "SLICE"
AUTO is faster when there are no video latency.
"""
def get_special_tokens():
special_token_dict = {
key: value
for key, value in SpecialTokens.__dict__.items()
if not key.startswith('__')
}
return special_token_dict
def remove_special_tokens(text):
for value in get_special_tokens().values():
text = text.replace(value, '').strip()
return text
def remove_non_special_tokens(text):
special_tokens = get_special_tokens().values()
patterns = '|'.join(re.escape(token) for token in special_tokens)
special_tokens_found = re.findall(patterns, text)
text_with_only_special_tokens = ''.join(special_tokens_found)
return text_with_only_special_tokens
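
# Illustrative sketch (not part of the original module): how the token
# helpers above behave on a multimodal text sample.
def _demo_special_token_helpers():
    text = f'{SpecialTokens.image} a photo of a cat {SpecialTokens.eoc}'
    # strips every special token and the surrounding whitespace
    assert remove_special_tokens(text) == 'a photo of a cat'
    # keeps only the special tokens, concatenated in order of appearance
    assert remove_non_special_tokens(text) == \
        SpecialTokens.image + SpecialTokens.eoc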
def load_data_with_context(sample, context, loaded_data_keys, load_func):
"""
    The unified loading function for multimodal data with context caching:
    data already stored in the sample's context field is reused instead of
    being loaded again, and newly loaded data is stored back into the
    context when contexts are enabled.
"""
data = {}
for loaded_data_key in loaded_data_keys:
if context and loaded_data_key in sample[Fields.context]:
# load from context
data[loaded_data_key] = sample[Fields.context][loaded_data_key]
else:
if loaded_data_key not in data:
                # avoid loading the same data for duplicate keys
data_item = load_func(loaded_data_key)
data[loaded_data_key] = data_item
if context:
# store the data into context
sample[Fields.context][loaded_data_key] = data_item
return sample, data
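
# Usage sketch (the 'images' key is hypothetical): load a sample's images
# once and cache them in its context field for downstream OPs to reuse.
def _demo_load_with_context(sample, context=True):
    if context:
        sample.setdefault(Fields.context, {})
    return load_data_with_context(sample,
                                  context,
                                  sample.get('images', []),
                                  load_image)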
# Images
def load_images(paths):
return [load_image(path) for path in paths]
def load_images_byte(paths):
return [load_image_byte(path) for path in paths]
def load_image(path):
img_feature = Image()
img = img_feature.decode_example(img_feature.encode_example(path))
img = img.convert('RGB')
return img
def load_image_byte(path):
with open(path, 'rb') as image_file:
image_data = image_file.read()
return image_data
def image_path_to_base64(image_path):
with open(image_path, 'rb') as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
def image_byte_to_base64(image_byte):
return base64.b64encode(image_byte).decode('utf-8')
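
# Round-trip sketch: both base64 helpers agree for the same file, since one
# reads the bytes itself while the other receives them pre-read.
def _demo_base64_round_trip(image_path):
    b64_from_path = image_path_to_base64(image_path)
    b64_from_bytes = image_byte_to_base64(load_image_byte(image_path))
    assert b64_from_path == b64_from_bytes
    return base64.b64decode(b64_from_path)  # the original bytes back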
def pil_to_opencv(pil_image):
if pil_image.mode != 'RGB':
pil_image = pil_image.convert('RGB')
numpy_image = np.array(pil_image)
# RGB to BGR
opencv_image = numpy_image[:, :, ::-1]
return opencv_image
def detect_faces(image, detector, **extra_kwargs):
img = pil_to_opencv(image)
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
dets = detector.detectMultiScale(gray, **extra_kwargs)
rectified_dets = []
for (x, y, w, h) in dets:
x = max(x, 0)
y = max(y, 0)
w = min(w, image.width - x)
h = min(h, image.height - y)
rectified_dets.append([x, y, w, h])
return rectified_dets
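
# Usage sketch with OpenCV's bundled Haar cascade (assumes the data files
# shipped with opencv-python; the detection parameters are illustrative).
def _demo_detect_faces(pil_image):
    cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_alt.xml'
    detector = cv2.CascadeClassifier(cascade_path)
    # returns [x, y, w, h] boxes clipped to the image boundaries
    return detect_faces(pil_image, detector, scaleFactor=1.1, minNeighbors=3)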
def get_file_size(path):
return os.path.getsize(path)
def iou(box1, box2):
x1_min, y1_min, x1_max, y1_max = box1
x2_min, y2_min, x2_max, y2_max = box2
area1 = (x1_max - x1_min) * (y1_max - y1_min)
area2 = (x2_max - x2_min) * (y2_max - y2_min)
ix_min = max(x1_min, x2_min)
ix_max = min(x1_max, x2_max)
iy_min = max(y1_min, y2_min)
iy_max = min(y1_max, y2_max)
    # clamp each side at zero so that disjoint boxes get zero intersection
    intersection = max(0, ix_max - ix_min) * max(0, iy_max - iy_min)
    union = area1 + area2 - intersection
    return 1.0 * intersection / union if union else 0.0
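
# Worked example: two unit squares overlapping in a 0.5-wide strip intersect
# with area 0.5, so IoU = 0.5 / (1 + 1 - 0.5) = 1/3.
def _demo_iou():
    assert abs(iou((0, 0, 1, 1), (0.5, 0, 1.5, 1)) - 1 / 3) < 1e-9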
def calculate_resized_dimensions(
original_size: Tuple[PositiveInt, PositiveInt],
target_size: Union[PositiveInt, Tuple[PositiveInt, PositiveInt]],
max_length: Optional[int] = None,
divisible: PositiveInt = 1) -> Tuple[int, int]:
"""
Resize dimensions based on specified constraints.
:param original_size: The original dimensions as (height, width).
:param target_size: Desired target size; can be a single integer
(short edge) or a tuple (height, width).
:param max_length: Maximum allowed length for the longer edge.
:param divisible: The number that the dimensions must be divisible by.
:return: Resized dimensions as (height, width).
"""
height, width = original_size
short_edge, long_edge = sorted((width, height))
# Normalize target_size to a tuple
if isinstance(target_size, int):
target_size = (target_size, )
# Initialize new dimensions
if target_size:
if len(target_size) == 1: # Only the smaller edge is specified
new_short_edge = target_size[0]
new_long_edge = int(new_short_edge * long_edge / short_edge)
else: # Both dimensions are specified
new_short_edge = min(target_size)
new_long_edge = max(target_size)
else: # No change
new_short_edge, new_long_edge = short_edge, long_edge
# Enforce maximum length constraint
if max_length is not None and new_long_edge > max_length:
scaling_factor = max_length / new_long_edge
new_short_edge = int(new_short_edge * scaling_factor)
new_long_edge = max_length
    # Restore the original orientation: the result is (height, width), so the
    # short edge comes first only when the height is the shorter one
    resized_dimensions = ((new_short_edge,
                           new_long_edge) if height <= width else
                          (new_long_edge, new_short_edge))
# Ensure final dimensions are divisible by the specified value
resized_dimensions = tuple(
int(dim / divisible) * divisible for dim in resized_dimensions)
return resized_dimensions
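
# Worked examples: dimensions are (height, width), matching the docstring.
def _demo_calculate_resized_dimensions():
    # short edge -> 224; the long edge keeps the aspect ratio (640 * 224/480)
    assert calculate_resized_dimensions((480, 640), 224) == (224, 298)
    # cap the long edge at 256, then floor both edges to multiples of 16
    assert calculate_resized_dimensions(
        (480, 640), 224, max_length=256, divisible=16) == (192, 256)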
# Audios
def load_audios(paths):
return [load_audio(path) for path in paths]
def load_audio(path, sampling_rate=None):
aud_feature = Audio(sampling_rate)
aud = aud_feature.decode_example(aud_feature.encode_example(path))
return aud['array'], aud['sampling_rate']
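
# Usage sketch: decode an audio file into a float array, resampled to an
# illustrative 16 kHz.
def _demo_load_audio(audio_path):
    array, sampling_rate = load_audio(audio_path, sampling_rate=16000)
    return array.shape, sampling_rate  # sampling_rate == 16000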
# Videos
def load_videos(paths):
return [load_video(path) for path in paths]
def load_video(path, mode='r'):
"""
Load a video using its path.
:param path: the path to this video.
    :param mode: the loading mode. It's "r" by default.
    :return: a container object from the PyAV library, which contains all
        streams in this video (video/audio/...) and can be used to decode
        these streams to frames.
"""
if not os.path.exists(path) and 'r' in mode:
raise FileNotFoundError(f'Video [{path}] does not exist!')
container = av.open(path, mode)
return container
def get_video_duration(input_video: Union[str, av.container.InputContainer],
video_stream_index: int = 0):
"""
    Get the video's duration from the container.

    :param input_video: the container object from the PyAV library, which
        contains all streams in this video (video/audio/...) and can be used
        to decode these streams to frames.
    :param video_stream_index: the video stream index to decode,
        0 by default.
    :return: duration of the video in seconds.
"""
if isinstance(input_video, str):
container = load_video(input_video)
elif isinstance(input_video, av.container.InputContainer):
container = input_video
else:
raise ValueError(f'Unsupported type of input_video. Should be one of '
f'[str, av.container.InputContainer], but given '
f'[{type(input_video)}].')
input_video_stream = container.streams.video[video_stream_index]
duration = input_video_stream.duration * input_video_stream.time_base
return float(duration)
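
# Usage sketch: pass the container when it's already open to avoid a second
# open; remember to close containers opened here.
def _demo_video_duration(video_path):
    container = load_video(video_path)
    duration = get_video_duration(container)  # same as passing the path
    close_video(container)
    return duration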
def get_decoded_frames_from_video(
input_video: Union[str, av.container.InputContainer],
video_stream_index: int = 0):
"""
    Get the video's frames from the container.

    :param input_video: the container object from the PyAV library, which
        contains all streams in this video (video/audio/...) and can be used
        to decode these streams to frames.
    :param video_stream_index: the video stream index to decode,
        0 by default.
    :return: an iterator over all frames of the video.
"""
    if isinstance(input_video, str):
        container = load_video(input_video)
    elif isinstance(input_video, av.container.InputContainer):
        container = input_video
    else:
        raise ValueError(f'Unsupported type of input_video. Should be one of '
                         f'[str, av.container.InputContainer], but given '
                         f'[{type(input_video)}].')
stream = container.streams.video[video_stream_index]
# use "AUTO" thread_type for faster decode
stream.thread_type = AV_STREAM_THREAD_TYPE
return container.decode(stream)
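
# Usage sketch: iterate over the decoded frames and convert each one to a
# PIL image with PyAV's VideoFrame.to_image().
def _demo_iterate_frames(video_path):
    container = load_video(video_path)
    images = [frame.to_image()
              for frame in get_decoded_frames_from_video(container)]
    close_video(container)
    return images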
def cut_video_by_seconds(
input_video: Union[str, av.container.InputContainer],
output_video: str,
start_seconds: float,
end_seconds: Optional[float] = None,
):
"""
    Cut out a segment of a video by start/end times in seconds.

    :param input_video: the path to the input video or the video container.
    :param output_video: the path to the output video. If it's empty or None,
        the cut segment is returned as an in-memory buffer instead.
    :param start_seconds: the start time in seconds.
    :param end_seconds: the end time in seconds. If it's None, this function
        cuts the video from start_seconds to the end of the video.
    :return: a boolean flag indicating whether the video was successfully
        cut, or an io.BytesIO buffer holding the cut segment when no output
        path is given.
"""
# open the original video
if isinstance(input_video, str):
container = load_video(input_video)
else:
container = input_video
# create the output video
if output_video:
output_container = load_video(output_video, 'w')
else:
output_buffer = io.BytesIO()
output_container = av.open(output_buffer, mode='w', format='mp4')
# add the video stream into the output video according to input video
input_video_stream = container.streams.video[0]
codec_name = input_video_stream.codec_context.name
fps = input_video_stream.base_rate
output_video_stream = output_container.add_stream(codec_name, rate=fps)
output_video_stream.width = input_video_stream.codec_context.width
output_video_stream.height = input_video_stream.codec_context.height
output_video_stream.pix_fmt = input_video_stream.codec_context.pix_fmt
    # add the audio stream into the output video, using the input audio
    # stream as the template
if len(container.streams.audio) == 0:
input_audio_stream = None
else:
input_audio_stream = container.streams.audio[0]
output_container.add_stream(template=input_audio_stream)
    # seek to the start time; when no stream is specified, the offset must
    # be given in microseconds
container.seek(int(start_seconds * 1000000),
any_frame=False,
backward=True)
    # copy the video and audio streams until the end time
    # NOTICE: for different streams, the times have to be converted into the
    # corresponding stream's time base.
video_at_the_end = False
# compute the start/end pts for video/audio streams
video_start_pts = int(start_seconds / input_video_stream.time_base)
video_end_pts = (end_seconds / input_video_stream.time_base
if end_seconds else input_video_stream.duration)
if input_audio_stream is not None:
audio_start_pts = int(start_seconds / input_audio_stream.time_base)
audio_end_pts = (end_seconds / input_audio_stream.time_base
if end_seconds else input_audio_stream.duration)
    # demux only the streams that actually exist to avoid passing None
    streams_to_demux = [input_video_stream]
    if input_audio_stream is not None:
        streams_to_demux.append(input_audio_stream)
    for packet in container.demux(*streams_to_demux):
if packet.stream.type == 'video':
for frame in packet.decode():
if frame.pts < video_start_pts:
continue
if frame.pts > video_end_pts:
# continue to check until the next P/I frame
                    if frame.pict_type.name in {'P', 'I'}:
video_at_the_end = True
break
continue
frame.pts -= video_start_pts # timestamp alignment
for inter_packet in output_video_stream.encode(frame):
output_container.mux(inter_packet)
elif packet.stream.type == 'audio':
if packet.pts is None or packet.dts is None:
continue
if packet.pts < audio_start_pts or packet.pts > audio_end_pts:
continue
packet.pts -= audio_start_pts
packet.dts -= audio_start_pts
output_container.mux(packet)
if video_at_the_end:
break
# flush all packets
for packet in output_video_stream.encode():
output_container.mux(packet)
    # close the input container (only if we opened it here) and the output
    # container
if isinstance(input_video, str):
close_video(container)
close_video(output_container)
if not output_video:
output_buffer.seek(0)
return output_buffer
if not os.path.exists(output_video):
logger.warning(f'This video could not be successfully cut in '
f'[{start_seconds}, {end_seconds}] seconds. '
f'Please set more accurate parameters.')
return os.path.exists(output_video)
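
# Usage sketch: cut the segment between 5 s and 10 s into a new file whose
# name is derived with the imported add_suffix_to_filename helper.
def _demo_cut_video(video_path):
    output_path = add_suffix_to_filename(video_path, '_cut')
    return cut_video_by_seconds(video_path,
                                output_path,
                                start_seconds=5.0,
                                end_seconds=10.0)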
def process_each_frame(input_video: Union[str, av.container.InputContainer],
output_video: str, frame_func):
"""
Process each frame in video by replacing each frame by
`frame_func(frame)`.
:param input_video: the path to input video or the video container.
:param output_video: the path to output video.
    :param frame_func: a function that takes a frame as input and returns
        the processed frame.
"""
frame_modified = False
# open the original video
if isinstance(input_video, str):
container = load_video(input_video)
else:
container = input_video
# create the output video
output_container = load_video(output_video, 'w')
# add the audio stream into the output video with template of input audio
for input_audio_stream in container.streams.audio:
output_container.add_stream(template=input_audio_stream)
# add the video stream into the output video according to input video
for input_video_stream in container.streams.video:
# search from the beginning
container.seek(0, backward=False, any_frame=True)
codec_name = input_video_stream.codec_context.name
fps = input_video_stream.base_rate
output_video_stream = output_container.add_stream(codec_name, rate=fps)
output_video_stream.pix_fmt = input_video_stream.codec_context.pix_fmt
output_video_stream.width = input_video_stream.codec_context.width
output_video_stream.height = input_video_stream.codec_context.height
for packet in container.demux(input_video_stream):
for frame in packet.decode():
new_frame = frame_func(frame)
if new_frame != frame:
frame_modified = True
# for resize cases
output_video_stream.width = new_frame.width
output_video_stream.height = new_frame.height
for inter_packet in output_video_stream.encode(new_frame):
output_container.mux(inter_packet)
# flush all packets
for packet in output_video_stream.encode():
output_container.mux(packet)
    # close the input container (only if we opened it here) and the output
    # container
if isinstance(input_video, str):
close_video(container)
close_video(output_container)
    if frame_modified:
        return output_video
    else:
        # no frame was modified, so drop the output file and return the
        # original video instead (os.remove is used because the output is a
        # file, not a directory)
        if os.path.exists(output_video):
            os.remove(output_video)
        return (input_video
                if isinstance(input_video, str) else input_video.name)
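
# Sketch of a frame_func: halve the resolution of every frame with PyAV's
# VideoFrame.reformat(); returning the input frame unchanged would mark the
# video as unmodified.
def _demo_halve_resolution(input_video_path, output_video_path):
    def halve(frame):
        # round down to even sizes, which subsampled pixel formats require
        new_w = frame.width // 2 - frame.width // 2 % 2
        new_h = frame.height // 2 - frame.height // 2 % 2
        return frame.reformat(width=new_w, height=new_h)

    return process_each_frame(input_video_path, output_video_path, halve)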
def get_key_frame_seconds(input_video: Union[str,
av.container.InputContainer]):
"""
Get seconds of key frames in the input video.
"""
key_frames = extract_key_frames(input_video)
ts = [float(f.pts * f.time_base) for f in key_frames]
ts.sort()
return ts
# Others
def size_to_bytes(size):
    alphabets_list = [char for char in size if char.isalpha()]
    # keep the decimal point so that sizes like "1.5GB" are parsed correctly
    numbers_list = [char for char in size if char.isdigit() or char == '.']
    if len(numbers_list) == 0:
        raise ValueError(f'Your input `size` does not contain numbers: {size}')
    size_number = float(''.join(numbers_list))
    if len(alphabets_list) == 0:
        # by default, if users do not specify the units, the number will be
        # regarded as in bytes
        return int(size_number)
    suffix = ''.join(alphabets_list).lower()
    # both decimal-style (KB) and binary-style (KiB) suffixes are counted as
    # binary multiples here
    shift_by_suffix = {
        'kb': 10, 'kib': 10,
        'mb': 20, 'mib': 20,
        'gb': 30, 'gib': 30,
        'tb': 40, 'tib': 40,
        'pb': 50, 'pib': 50,
        'eb': 60, 'eib': 60,
        'zb': 70, 'zib': 70,
        'yb': 80, 'yib': 80,
    }
    if suffix not in shift_by_suffix:
        raise ValueError(f'You specified unidentifiable unit: {suffix}, '
                         f'expected in [KB, MB, GB, TB, PB, EB, ZB, YB, '
                         f'KiB, MiB, GiB, TiB, PiB, EiB, ZiB, YiB], '
                         f'(case insensitive, counted by *Bytes*).')
    return int(size_number * (1 << shift_by_suffix[suffix]))
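
# Worked examples: units are case-insensitive and counted as binary
# multiples, so "KB" and "KiB" both mean 1024 bytes.
def _demo_size_to_bytes():
    assert size_to_bytes('5MB') == 5 * 1024 * 1024
    assert size_to_bytes('1.5 GiB') == int(1.5 * (1 << 30))
    assert size_to_bytes('1024') == 1024  # no unit -> bytes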
def insert_texts_after_placeholders(original_string,
placeholders,
new_texts,
delimiter_in_insert_pos=' '):
if len(placeholders) != len(new_texts):
raise ValueError(
'The number of placeholders and new_texts must be equal')
modified_string = original_string
for placeholder, new_text in zip(placeholders, new_texts):
# Find the index of the next occurrence of the placeholder
index = modified_string.find(placeholder)
if index == -1:
raise ValueError(
f"Placeholder '{placeholder}' not found in the string")
# Insert new_text at the found index position
modified_string = \
modified_string[:index + len(placeholder)] + \
delimiter_in_insert_pos + \
new_text + \
delimiter_in_insert_pos + \
modified_string[index + len(placeholder):]
return modified_string
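
# Worked example: each new text is wrapped by the delimiter on both sides
# and spliced in right after its placeholder.
def _demo_insert_after_placeholders():
    out = insert_texts_after_placeholders('<ph1> and <ph2>',
                                          ['<ph1>', '<ph2>'],
                                          ['foo', 'bar'])
    assert out == '<ph1> foo  and <ph2> bar '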
def timecode_string_to_seconds(timecode: str):
"""
    Convert a timecode string to float seconds.

    :param timecode: the input timecode string. Must be in the
        "HH:MM:SS.fff(fff)" format.
"""
# parse the timecode string
dt = datetime.datetime.strptime(timecode, '%H:%M:%S.%f')
    # compute the total time in seconds
    pts = dt.hour * 3600 + dt.minute * 60 + dt.second + dt.microsecond / 1e6
return pts
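
# Worked example: 1 hour, 2 minutes and 3.5 seconds.
def _demo_timecode():
    assert timecode_string_to_seconds('01:02:03.500') == 3723.5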
def parse_string_to_roi(roi_string, roi_type='pixel'):
"""
    Convert a roi string to four numbers x1, y1, x2, y2 that stand for the
    region. When roi_type is 'pixel', (x1, y1) and (x2, y2) are the pixel
    locations of the top-left and bottom-right corners respectively. If
    roi_type is 'ratio', the coordinates are normalized by the width and
    height.

    :param roi_string: the roi string.
    :param roi_type: the roi string type.
    :return: tuple of (x1, y1, x2, y2) if roi_string is valid, else None.
"""
if not roi_string:
return None
pattern = r'^\s*[\[\(]?\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\s*[\]\)]?\s*$' # noqa: E501
match = re.match(pattern, roi_string)
if match:
if roi_type == 'pixel':
return tuple(int(num) for num in match.groups())
elif roi_type == 'ratio':
return tuple(min(1.0, float(num)) for num in match.groups())
else:
logger.warning('The roi_type must be "pixel" or "ratio".')
return None
else:
        logger.warning(
            'The roi_string must be four non-negative numbers in the '
            'format of "x1, y1, x2, y2", "(x1, y1, x2, y2)", or '
            '"[x1, y1, x2, y2]".')
return None
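
# Worked examples: pixel ROIs come back as ints, ratio ROIs as floats
# clipped to 1.0.
def _demo_parse_roi():
    assert parse_string_to_roi('10, 20, 100, 200') == (10, 20, 100, 200)
    assert parse_string_to_roi('[0.1, 0.2, 0.9, 1.5]',
                               roi_type='ratio') == (0.1, 0.2, 0.9, 1.0)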
def close_video(container: av.container.InputContainer):
"""
    Close the video streams and the container to avoid memory leaks.
:param container: the video container.
"""
for video_stream in container.streams.video:
video_stream.close(strict=False)
container.close()