Source code for data_juicer.ops.mapper.python_lambda_mapper

import ast

from ..base_op import OPERATORS, Mapper

OP_NAME = "python_lambda_mapper"


@OPERATORS.register_module(OP_NAME)
class PythonLambdaMapper(Mapper):
    """Mapper for applying a Python lambda function to data samples.

    This operator allows users to define a custom transformation using a
    Python lambda function. The lambda function is applied to each sample,
    and the result must be a dictionary. If the `batched` parameter is set
    to True, the lambda function will process a batch of samples at once.
    If no lambda function is provided, the identity function is used, which
    returns the input sample unchanged. The operator validates that the
    lambda function takes exactly one positional argument before compiling
    it.

    Security note: the lambda is evaluated with full access to Python
    builtins, so it can execute arbitrary code. Only use lambda strings that
    come from trusted configuration.
    """

    def __init__(self, lambda_str: str = "", batched: bool = False, **kwargs):
        """
        Initialization method.

        :param lambda_str: A string representation of the lambda function to
            be executed on data samples. If empty, the identity function is
            used. Surrounding whitespace is ignored.
        :param batched: A boolean indicating whether to process input data in
            batches.
        :param kwargs: Additional keyword arguments passed to the parent
            class.
        """
        # Set before the parent initializer runs; presumably the base class
        # consults it during setup -- NOTE(review): confirm in base_op.
        self._batched_op = bool(batched)
        super().__init__(**kwargs)
        # Parse and validate the lambda function
        if not lambda_str:
            self.lambda_func = lambda sample: sample
        else:
            self.lambda_func = self._create_lambda(lambda_str)

    def _create_lambda(self, lambda_str: str):
        """Parse, validate, and compile ``lambda_str`` into a callable.

        :param lambda_str: Source text of a single lambda expression.
        :return: The compiled lambda function.
        :raises ValueError: If the text does not parse, is not a lambda
            expression, or does not take exactly one positional argument.
        """
        try:
            # Strip surrounding whitespace: parsing in "eval" mode rejects a
            # leading indent, which is common in YAML/JSON config values.
            node = ast.parse(lambda_str.strip(), mode="eval")
            # Check if the body of the expression is a lambda
            if not isinstance(node.body, ast.Lambda):
                raise ValueError(
                    "Input string must be a valid lambda function.")
            params = node.body.args
            # Count positional-only parameters too, so that
            # `lambda x, /: ...` is accepted as a one-argument lambda.
            if len(params.posonlyargs) + len(params.args) != 1:
                raise ValueError(
                    "Lambda function must have exactly one argument.")
            # The lambda is always called with a single positional argument,
            # so a keyword-only parameter without a default (kw_defaults
            # entry of None) could never be supplied; reject it up front
            # instead of failing on the first sample.
            if any(default is None for default in params.kw_defaults):
                raise ValueError(
                    "Lambda function must have exactly one argument.")
            compiled_code = compile(node, "<string>", "eval")
            # NOTE: eval() with full builtins is NOT a sandbox; the lambda
            # can run arbitrary code. lambda_str must come from a trusted
            # source.
            return eval(compiled_code, {"__builtins__": __builtins__})
        except Exception as e:
            raise ValueError(f"Invalid lambda function: {e}") from e

    @staticmethod
    def _ensure_dict_result(result):
        """Return ``result`` unchanged, raising ValueError if not a dict."""
        if not isinstance(result, dict):
            raise ValueError(f"Lambda function must return a dictionary, "
                             f"got {type(result).__name__} instead.")
        return result

    def process_single(self, sample):
        """Apply the lambda to one sample and return the resulting dict.

        :raises ValueError: If the lambda does not return a dict.
        """
        return self._ensure_dict_result(self.lambda_func(sample))

    def process_batched(self, samples):
        """Apply the lambda to a batch of samples and return the result dict.

        :raises ValueError: If the lambda does not return a dict.
        """
        return self._ensure_dict_result(self.lambda_func(samples))