Source code for data_juicer.utils.logger_utils

# Some of the code here is adapted from YOLOX (Megvii-BaseDetection).

# Copyright 2021 Megvii, Base Detection
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#        http://www.apache.org/licenses/LICENSE-2.0
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import inspect
import os
import sys
from io import StringIO

from loguru import logger
from loguru._file_sink import FileSink
from tabulate import tabulate

from data_juicer.utils.file_utils import add_suffix_to_filename


def get_caller_name(depth=0):
    """
    Look up the module name of a frame on the current call stack.

    :param depth: number of extra frames to walk up beyond the direct
        caller; 0 selects the function that called this one.
    :return: the ``__name__`` stored in the globals of the selected frame
    """
    # Following ``f_back`` links by hand is a little faster than building
    # the whole ``inspect.stack()`` list.
    target = inspect.currentframe().f_back
    for _hop in range(depth):
        target = target.f_back
    module_globals = target.f_globals
    return module_globals['__name__']
class StreamToLoguru:
    """File-like object that forwards everything written to it to loguru."""

    def __init__(self, level='INFO', caller_names=('datasets', 'logging')):
        """
        Initialization method.

        :param level: loguru level name used for lines coming from the
            listed caller modules. Default value: "INFO".
        :param caller_names: top-level module names whose writes are
            logged line by line at ``level``. Default value:
            ('datasets', 'logging').
        """
        self.BUFFER_SIZE = 1024 * 1024
        self.buffer = StringIO()
        self.caller_names = caller_names
        self.level = level

    def write(self, buf):
        """
        Store ``buf`` in the internal buffer and pass it on to loguru.

        :param buf: the text being written to the stream
        """
        # NOTE: ``get_caller_name`` and ``logger.opt(depth=2)`` both depend
        # on the call stack, so they have to be called directly from here.
        writer_module = get_caller_name(depth=1)
        top_level_name = writer_module.partition('.')[0]
        self.buffer.write(buf)

        if top_level_name not in self.caller_names:
            # any other writer: emit the text as-is, without formatting
            logger.opt(raw=True).info(buf)
        else:
            # listed module: log each line at the configured level
            for text_line in buf.rstrip().splitlines():
                logger.opt(depth=2).log(self.level, text_line.rstrip())

        # keep the retained text from exceeding BUFFER_SIZE
        self.buffer.truncate(self.BUFFER_SIZE)

    def getvalue(self):
        """Return the text currently held in the internal buffer."""
        return self.buffer.getvalue()

    def flush(self):
        """Flush the internal buffer."""
        self.buffer.flush()
def redirect_sys_output(log_level='INFO'):
    """
    Replace ``sys.stderr`` and ``sys.stdout`` with one loguru-backed stream.

    :param log_level: log level string of loguru handed to the stream.
        Default value: "INFO".
    """
    loguru_stream = StreamToLoguru(level=log_level)
    # both standard streams share the same redirecting object
    sys.stderr = sys.stdout = loguru_stream
def get_log_file_path():
    """
    Get the path to the location of the log file.

    :return: name of the file opened by the first loguru file sink
        encountered among the registered handlers, or None when no file
        sink is registered.
    """
    # NOTE: this relies on loguru private attributes (``_core``, ``_sink``,
    # ``_file``), so it may need adjusting when loguru is upgraded.
    for handler in logger._core.handlers.values():
        sink = handler._sink
        if isinstance(sink, FileSink):
            # previously a bare ``return`` here always produced None
            return sink._file.name
    return None
def setup_logger(save_dir,
                 distributed_rank=0,
                 filename='log.txt',
                 mode='o',
                 level='INFO',
                 redirect=True):
    """
    Setup logger for training and testing.

    The function configures the global loguru ``logger`` in place and does
    nothing when it has already run once in this process.

    :param save_dir: location to save log file
    :param distributed_rank: device rank when multi-gpu environment
    :param filename: log file name to save
    :param mode: log file write mode, `append` or `override`. default is
        `o`, which removes an existing log file before logging starts.
    :param level: log severity level. It's "INFO" in default.
    :param redirect: whether to redirect system output
    :return: None
    """
    global LOGGER_SETUP
    # Read the flag defensively: if it was never initialised at module
    # level, a plain ``if LOGGER_SETUP`` would raise NameError on first use.
    if globals().get('LOGGER_SETUP', False):
        return

    loguru_format = (
        '<green>{time:YYYY-MM-DD HH:mm:ss}</green> | '
        '<level>{level: <8}</level> | '
        '<cyan>{name}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>')

    # drop every handler that is currently registered
    logger.remove()

    save_file = os.path.join(save_dir, filename)
    if mode == 'o' and os.path.exists(save_file):
        os.remove(save_file)

    # only keep logger in rank0 process
    if distributed_rank == 0:
        logger.add(
            sys.stderr,
            format=loguru_format,
            level=level,
            enqueue=True,
        )
        logger.add(save_file)

    # for interest of levels: debug, error, warning
    # Each level gets its own serialized file that holds only the records
    # of exactly that level.
    for level_name in ('DEBUG', 'ERROR', 'WARNING'):
        logger.add(
            add_suffix_to_filename(save_file, f'_{level_name}'),
            level=level_name,
            # bind the name as a default so every filter keeps its own level
            filter=lambda record, wanted=level_name:
            wanted == record['level'].name,
            format=loguru_format,
            enqueue=True,
            serialize=True,
        )

    # redirect stdout/stderr to loguru
    if redirect:
        redirect_sys_output(level)
    LOGGER_SETUP = True
def make_log_summarization(max_show_item=10):
    """
    Log a summary of the warnings and errors recorded in the log files.

    The ``_ERROR`` and ``_WARNING`` companions of the main log file are
    read as JSON lines. Error messages matching the expected pattern are
    grouped by (OP/method, error type, error message) and the most frequent
    groups are rendered as a table in the summary.

    :param max_show_item: maximum number of error groups shown in the
        table. Default value: 10.
    """
    error_pattern = r'^An error occurred in (.*?) when ' \
                    r'processing samples? \"(.*?)\" -- (.*?): (.*?)$'
    log_file = get_log_file_path()
    error_log_file = add_suffix_to_filename(log_file, '_ERROR')
    warning_log_file = add_suffix_to_filename(log_file, '_WARNING')

    import jsonlines as jl
    import regex as re

    # make error summarization
    error_dict = {}
    with jl.open(error_log_file) as reader:
        for error_log in reader:
            raw_message = error_log['record']['message']
            find_res = re.findall(error_pattern, raw_message)
            if not find_res:
                # not an error of the expected form; leave it uncounted
                continue
            # the sample text is captured by the pattern but not used
            op_name, _sample, error_type, error_msg = find_res[0]
            error = (op_name, error_type, error_msg)
            error_dict[error] = error_dict.get(error, 0) + 1
    total_error_count = sum(error_dict.values())

    # make warning summarization
    with jl.open(warning_log_file) as reader:
        warning_count = sum(1 for _ in reader)

    # make summary log
    summary = f'Processing finished with:\n' \
              f'<yellow>Warnings</yellow>: {warning_count}\n' \
              f'<red>Errors</red>: {total_error_count}\n'

    # most frequent errors first, limited to max_show_item entries
    error_items = sorted(error_dict.items(),
                         key=lambda it: it[1],
                         reverse=True)[:max_show_item]

    # convert error items to a table
    if error_items:
        table_header = [
            'OP/Method', 'Error Type', 'Error Message', 'Error Count'
        ]
        error_table = [[op_name, error_type, error_msg, num]
                       for (op_name, error_type, error_msg), num
                       in error_items]
        summary += tabulate(error_table, table_header, tablefmt='fancy_grid')

    summary += f'\nError/Warning details can be found in the log file ' \
               f'[{log_file}] and its related log files.'
    # NOTE(review): the summary is parsed for color markup, so an error
    # message that itself contains "<...>" may be misread as a tag -- TODO
    # confirm against the loguru version in use.
    logger.opt(ansi=True).info(summary)
class HiddenPrints:
    """Define a range that hides the standard output within this range."""

    def __enter__(self):
        """
        Store the original standard output and redirect the standard output
        to null when entering this range.

        :return: this context manager instance
        """
        self._original_stdout = sys.stdout
        # Keep our own reference to the null stream so that exactly this
        # handle is closed on exit, even if ``sys.stdout`` gets rebound
        # inside the range.
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Close the null stream opened on entry and restore the original
        standard output when exiting from this range.

        Exceptions raised inside the range are not suppressed.
        """
        try:
            self._devnull.close()
        finally:
            # restore the original stream even if closing fails
            sys.stdout = self._original_stdout