Source code for data_juicer.ops.mapper.download_file_mapper

import asyncio
import copy
import os
import os.path as osp
from typing import List, Union
from urllib.parse import urlparse

import aiohttp
from loguru import logger

from data_juicer.utils.file_utils import download_file, is_remote_path

from ..base_op import OPERATORS, Mapper

OP_NAME = "download_file_mapper"


@OPERATORS.register_module(OP_NAME)
class DownloadFileMapper(Mapper):
    """Mapper to download url files to local files."""

    _batched_op = True

    def __init__(
        self,
        save_dir: str = None,
        download_field: str = None,
        timeout: int = 30,
        max_concurrent: int = 10,
        *args,
        **kwargs,
    ):
        """
        Initialization method.

        :param save_dir: The directory to save downloaded files.
        :param download_field: The field name to get the url to download.
        :param timeout: The timeout in seconds for each download request.
        :param max_concurrent: Maximum concurrent downloads.
        :param args: extra args
        :param kwargs: extra args
        """
        super().__init__(*args, **kwargs)
        self._init_parameters = self.remove_extra_parameters(locals())
        self.download_field = download_field
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)
        self.timeout = timeout
        self.max_concurrent = max_concurrent

    def download_files_async(self, urls, save_dir, **kwargs):
        async def _download_file(
            session: aiohttp.ClientSession,
            semaphore: asyncio.Semaphore,
            idx: int,
            url: str,
            save_dir,
            **kwargs,
        ) -> tuple:
            try:
                filename = os.path.basename(urlparse(url).path)
                save_path = osp.join(save_dir, filename)
                status = "success"
                if os.path.exists(save_path):
                    # skip downloading if the target file already exists
                    return idx, save_path, status, None
                async with semaphore:
                    response = await download_file(session, url, save_path, timeout=self.timeout, **kwargs)
            except Exception as e:
                status = "failed"
                response = str(e)
                save_path = None

            return idx, save_path, status, response

        async def run_downloads(urls, save_dir, **kwargs):
            semaphore = asyncio.Semaphore(self.max_concurrent)
            async with aiohttp.ClientSession() as session:
                tasks = [
                    _download_file(session, semaphore, idx, url, save_dir, **kwargs)
                    for idx, url in enumerate(urls)
                ]
                return await asyncio.gather(*tasks)

        results = asyncio.run(run_downloads(urls, save_dir, **kwargs))
        results.sort(key=lambda x: x[0])
        return results

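The helper above bounds concurrency with a single asyncio.Semaphore shared across all tasks and reuses one aiohttp.ClientSession for every request; asyncio.gather already returns results in task-creation order, so the final sort by idx is a defensive guard. A minimal standalone sketch of the same pattern, assuming placeholder URLs and a 30-second total timeout (neither taken from this operator):

import asyncio

import aiohttp


async def _fetch(session: aiohttp.ClientSession, sem: asyncio.Semaphore, url: str) -> int:
    async with sem:  # at most max_concurrent requests in flight
        async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
            await resp.read()
            return resp.status


async def _run(urls, max_concurrent=10):
    sem = asyncio.Semaphore(max_concurrent)
    async with aiohttp.ClientSession() as session:
        # gather preserves the order in which the coroutines were passed in
        return await asyncio.gather(*(_fetch(session, sem, u) for u in urls))


statuses = asyncio.run(_run(["https://example.com/a.png", "https://example.com/b.png"]))
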
    def download_nested_urls(self, nested_urls: List[Union[str, List[str]]], save_dir: str):
        flat_urls = []
        structure_info = []  # saved as (original index, sub index)
        for idx, urls in enumerate(nested_urls):
            if isinstance(urls, list):
                for sub_idx, url in enumerate(urls):
                    if is_remote_path(url):
                        flat_urls.append(url)
                        structure_info.append((idx, sub_idx))
            else:
                if is_remote_path(urls):
                    flat_urls.append(urls)
                    structure_info.append((idx, -1))  # -1 means single str element

        download_results = self.download_files_async(
            flat_urls,
            save_dir,
        )

        keep_failed_url = True
        if keep_failed_url:
            reconstructed = copy.deepcopy(nested_urls)
        else:
            reconstructed = []
            for item in nested_urls:
                if isinstance(item, list):
                    reconstructed.append([None] * len(item))
                else:
                    reconstructed.append(None)

        failed_info = ""
        for i, (idx, save_path, status, response) in enumerate(download_results):
            orig_idx, sub_idx = structure_info[i]
            if status != "success":
                # keep the original URL in place of the failed download
                save_path = flat_urls[i]
                failed_info += "\n" + str(response)
                # TODO: add download stats
            if sub_idx == -1:
                reconstructed[orig_idx] = save_path
            else:
                reconstructed[orig_idx][sub_idx] = save_path

        return reconstructed, failed_info

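A hypothetical walk-through of the flattening bookkeeping, assuming is_remote_path treats http(s) URLs as remote and plain local paths as not (all paths below are made up):

nested_urls = ["https://host/a.png", ["https://host/b.png", "/local/c.png"]]
# flat_urls      -> ["https://host/a.png", "https://host/b.png"]
# structure_info -> [(0, -1), (1, 0)]  # (sample index, sub index); -1 marks a bare string
# on success, reconstructed -> ["{save_dir}/a.png", ["{save_dir}/b.png", "/local/c.png"]]
# on failure, the original URL stays in place and the error is appended to failed_info
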
    def process_batched(self, samples):
        if self.download_field not in samples or not samples[self.download_field]:
            return samples

        batch_nested_urls = samples[self.download_field]
        reconstructed, failed_info = self.download_nested_urls(batch_nested_urls, self.save_dir)
        samples[self.download_field] = reconstructed

        if len(failed_info):
            logger.error(f"Failed files:\n{failed_info}")

        return samples
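
A minimal usage sketch, assuming the package is installed; the field name "image_urls", the save directory, and the URLs are hypothetical. process_batched expects a column-oriented batch, i.e. a dict mapping field names to per-sample values:

from data_juicer.ops.mapper.download_file_mapper import DownloadFileMapper

op = DownloadFileMapper(
    save_dir="./downloads",       # created on init if missing
    download_field="image_urls",  # hypothetical field name
    timeout=30,
    max_concurrent=10,
)

samples = {
    "image_urls": [
        "https://example.com/a.png",                                 # a single URL per sample...
        ["https://example.com/b.png", "https://example.com/c.png"],  # ...or a list of URLs
    ],
}
samples = op.process_batched(samples)
# samples["image_urls"] now holds local paths; failed URLs are kept
# unchanged and reported via logger.error.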