Source code for data_juicer.ops.mapper.download_file_mapper

import asyncio
import copy
import os
import os.path as osp
from typing import List, Union
from urllib.parse import urlparse

import aiohttp
from loguru import logger

from data_juicer.utils.file_utils import download_file, is_remote_path

from ..base_op import OPERATORS, Mapper

OP_NAME = "download_file_mapper"


@OPERATORS.register_module(OP_NAME)
class DownloadFileMapper(Mapper):
    """Mapper to download url files to local files or load them into memory."""

    _batched_op = True
    def __init__(
        self,
        download_field: str = None,
        save_dir: str = None,
        save_field: str = None,
        resume_download: bool = False,
        timeout: int = 30,
        max_concurrent: int = 10,
        *args,
        **kwargs,
    ):
        """
        Initialization method.

        :param download_field: The field name holding the URLs to download.
        :param save_dir: The directory to save downloaded files.
        :param save_field: The field name to save the downloaded file content.
        :param resume_download: Whether to resume download. If True, skip
            downloading a file when its target file already exists.
        :param timeout: Timeout in seconds for each download request.
        :param max_concurrent: Maximum concurrent downloads.
        :param args: extra args
        :param kwargs: extra args
        """
        super().__init__(*args, **kwargs)
        self._init_parameters = self.remove_extra_parameters(locals())
        self.download_field = download_field
        self.save_dir = save_dir
        self.save_field = save_field
        self.resume_download = resume_download
        if not (self.save_dir or self.save_field):
            logger.warning(
                "Both `save_dir` and `save_field` are not specified. Use the default `image_bytes` key to "
                "save the downloaded contents."
            )
            self.save_field = self.image_bytes_key

        if self.save_dir:
            os.makedirs(self.save_dir, exist_ok=True)

        self.timeout = timeout
        self.max_concurrent = max_concurrent
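    # A construction sketch (illustrative, not part of the original module):
    # the field name and paths below are assumptions for demonstration. When
    # neither `save_dir` nor `save_field` is given, the mapper falls back to
    # storing bytes under the default `image_bytes` key, as the warning in
    # `__init__` notes.
    #
    #   op = DownloadFileMapper(
    #       download_field="image_urls",   # hypothetical column holding URLs
    #       save_dir="./downloads",        # downloaded files land here
    #       resume_download=True,          # reuse files that already exist
    #       max_concurrent=20,
    #   )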
    def download_files_async(self, urls, return_contents, save_dir=None, **kwargs):
        async def _download_file(
            session: aiohttp.ClientSession,
            semaphore: asyncio.Semaphore,
            idx: int,
            url: str,
            save_dir=None,
            return_content=False,
            **kwargs,
        ) -> tuple:
            try:
                status, response, content, save_path = "success", None, None, None
                # local file
                if not is_remote_path(url):
                    if return_content:
                        with open(url, "rb") as f:
                            content = f.read()
                    if save_dir:
                        save_path = url
                    return idx, save_path, status, response, content

                # nothing to save and no content requested
                if not save_dir and not return_content:
                    return idx, save_path, status, response, content

                if save_dir:
                    filename = os.path.basename(urlparse(url).path)
                    save_path = osp.join(save_dir, filename)
                    # skip already downloaded files
                    if os.path.exists(save_path):
                        if return_content:
                            with open(save_path, "rb") as f:
                                content = f.read()
                        return idx, save_path, status, response, content

                async with semaphore:
                    response, content = await download_file(
                        session, url, save_path, return_content=True, timeout=self.timeout, **kwargs
                    )
            except Exception as e:
                status = "failed"
                response = str(e)
                save_path = None
                content = None

            return idx, save_path, status, response, content

        async def run_downloads(urls, return_contents, save_dir=None, **kwargs):
            semaphore = asyncio.Semaphore(self.max_concurrent)
            async with aiohttp.ClientSession() as session:
                tasks = [
                    _download_file(session, semaphore, idx, url, save_dir, return_contents[idx], **kwargs)
                    for idx, url in enumerate(urls)
                ]
                return await asyncio.gather(*tasks)

        results = asyncio.run(run_downloads(urls, return_contents, save_dir, **kwargs))
        results.sort(key=lambda x: x[0])
        return results
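    # The method above bounds concurrency with a semaphore shared across all
    # tasks and restores input order by sorting on the index carried through
    # each result tuple. A standalone sketch of the same pattern, with
    # hypothetical URLs (`fetch_all` and `fetch` are assumed names, not part
    # of this module):
    #
    #   async def fetch_all(urls, max_concurrent=10):
    #       semaphore = asyncio.Semaphore(max_concurrent)
    #
    #       async def fetch(session, idx, url):
    #           async with semaphore:  # at most `max_concurrent` in flight
    #               async with session.get(url) as resp:
    #                   return idx, await resp.read()
    #
    #       async with aiohttp.ClientSession() as session:
    #           results = await asyncio.gather(
    #               *(fetch(session, i, u) for i, u in enumerate(urls))
    #           )
    #       # restore submission order, as download_files_async does
    #       return [body for _, body in sorted(results, key=lambda r: r[0])]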
    def _flat_urls(self, nested_urls):
        flat_urls = []
        structure_info = []  # saved as (original index, sub index)
        for idx, urls in enumerate(nested_urls):
            if isinstance(urls, list):
                for sub_idx, url in enumerate(urls):
                    flat_urls.append(url)
                    structure_info.append((idx, sub_idx))
            else:
                flat_urls.append(urls)
                structure_info.append((idx, -1))  # -1 means single str element

        return flat_urls, structure_info

    def _create_path_struct(self, nested_urls, keep_failed_url=True) -> list:
        if keep_failed_url:
            reconstructed = copy.deepcopy(nested_urls)
        else:
            reconstructed = []
            for item in nested_urls:
                if isinstance(item, list):
                    reconstructed.append([None] * len(item))
                else:
                    reconstructed.append(None)

        return reconstructed

    def _create_save_field_struct(self, nested_urls, save_field_contents=None) -> list:
        if save_field_contents is None:
            save_field_contents = []
            for item in nested_urls:
                if isinstance(item, list):
                    save_field_contents.append([None] * len(item))
                else:
                    save_field_contents.append(None)
        else:
            # check whether the save_field_contents format is correct and fix it automatically
            for i, item in enumerate(nested_urls):
                if isinstance(item, list):
                    if not save_field_contents[i] or len(save_field_contents[i]) != len(item):
                        save_field_contents[i] = [None] * len(item)

        return save_field_contents
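    # How the flattening round-trips (illustrative values): `_flat_urls`
    # records a (sample index, sub index) pair per URL so that results can be
    # written back into the original nested shape, with -1 marking a bare
    # string entry.
    #
    #   nested = ["https://a/x.png", ["https://a/y.png", "https://a/z.png"]]
    #   flat, info = self._flat_urls(nested)  # called from within the class
    #   # flat == ["https://a/x.png", "https://a/y.png", "https://a/z.png"]
    #   # info == [(0, -1), (1, 0), (1, 1)]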
    def download_nested_urls(self, nested_urls: List[Union[str, List[str]]], save_dir=None, save_field_contents=None):
        flat_urls, structure_info = self._flat_urls(nested_urls)

        if save_field_contents is None:
            # do not save contents, so set all return_contents flags to False
            return_contents = [False] * len(flat_urls)
        else:
            # if the original content is None, set the flag to True to fetch the content;
            # otherwise set it to False to skip reloading it. Handle both single
            # entries and lists so the flags stay aligned with `_flat_urls`.
            return_contents = []
            for item in save_field_contents:
                if isinstance(item, list):
                    return_contents.extend(not c for c in item)
                else:
                    return_contents.append(not item)

        download_results = self.download_files_async(
            flat_urls,
            return_contents,
            save_dir,
        )

        if self.save_dir:
            reconstructed_path = self._create_path_struct(nested_urls)
        else:
            reconstructed_path = None

        failed_info = ""
        for i, (idx, save_path, status, response, content) in enumerate(download_results):
            orig_idx, sub_idx = structure_info[i]
            if status != "success":
                save_path = flat_urls[i]
                failed_info += "\n" + str(response)

            if save_field_contents is not None:
                if return_contents[i]:
                    if sub_idx == -1:
                        save_field_contents[orig_idx] = content
                    else:
                        save_field_contents[orig_idx][sub_idx] = content

            if self.save_dir:
                # TODO: add download stats
                if sub_idx == -1:
                    reconstructed_path[orig_idx] = save_path
                else:
                    reconstructed_path[orig_idx][sub_idx] = save_path

        return save_field_contents, reconstructed_path, failed_info
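    # Shape of the return values (an illustrative sketch, assuming the mapper
    # was constructed with save_dir="./downloads" and every download succeeds;
    # `nested` is the example list from the `_flat_urls` sketch above):
    #
    #   contents, paths, failed = self.download_nested_urls(
    #       nested, save_dir=self.save_dir
    #   )
    #   # contents is None (no save_field_contents was passed)
    #   # paths == ["./downloads/x.png", ["./downloads/y.png", "./downloads/z.png"]]
    #   # failed == ""  (newline-joined error messages when downloads fail)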
    def process_batched(self, samples):
        if self.download_field not in samples or not samples[self.download_field]:
            return samples

        batch_nested_urls = samples[self.download_field]

        if self.save_field:
            if not self.resume_download:
                if self.save_field in samples:
                    raise ValueError(
                        f"{self.save_field} is already in samples. "
                        "If you want to resume download, please set `resume_download=True`."
                    )
                save_field_contents = self._create_save_field_struct(batch_nested_urls)
            else:
                if self.save_field not in samples:
                    save_field_contents = self._create_save_field_struct(batch_nested_urls)
                else:
                    save_field_contents = self._create_save_field_struct(batch_nested_urls, samples[self.save_field])
        else:
            save_field_contents = None

        save_field_contents, reconstructed_path, failed_info = self.download_nested_urls(
            batch_nested_urls, save_dir=self.save_dir, save_field_contents=save_field_contents
        )

        if self.save_dir:
            samples[self.download_field] = reconstructed_path

        if self.save_field:
            samples[self.save_field] = save_field_contents

        if failed_info:
            logger.error(f"Failed files:\n{failed_info}")

        return samples
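# An end-to-end sketch of `process_batched` (the field name, URLs, and direct
# call are illustrative assumptions; in practice the op usually runs inside a
# Data-Juicer pipeline). With `save_dir` set, the URL column is rewritten to
# local paths; with `save_field` set, the raw bytes are stored under that key.
#
#   op = DownloadFileMapper(download_field="image_urls", save_dir="./downloads")
#   samples = {
#       "image_urls": [
#           "https://example.com/a.png",
#           ["https://example.com/b.png", "https://example.com/c.png"],
#       ],
#   }
#   samples = op.process_batched(samples)
#   # samples["image_urls"] now holds local paths under ./downloads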