@OPERATORS.register_module(OP_NAME)
class PythonLambdaMapper(Mapper):
    """Mapper for applying a Python lambda function to data samples.

    This operator allows users to define a custom transformation using a
    Python lambda function given as a string. The lambda function is applied
    to each sample, and the result must be a dictionary. If the `batched`
    parameter is set to True, the lambda function will process a batch of
    samples at once. If no lambda function is provided, the identity function
    is used, which returns the input sample unchanged.

    The lambda string is parsed with ``ast`` and validated to be a single
    lambda expression that can be called with exactly one positional
    argument before it is compiled.

    NOTE(security): the lambda is evaluated with the full set of built-ins
    available, so it is NOT sandboxed. Only pass ``lambda_str`` values that
    come from a trusted source.
    """

    def __init__(self, lambda_str: str = "", batched: bool = False, **kwargs):
        """
        Initialization method.

        :param lambda_str: A string representation of the lambda function to
            be executed on data samples. If empty, the identity function is
            used.
        :param batched: A boolean indicating whether to process input data in
            batches.
        :param kwargs: Additional keyword arguments passed to the parent
            class.
        :raises ValueError: If ``lambda_str`` is non-empty and is not a valid
            one-argument lambda expression.
        """
        # Assigned before the parent initializer runs (original ordering kept;
        # presumably the parent reads this flag -- TODO confirm).
        self._batched_op = bool(batched)
        super().__init__(**kwargs)

        # Parse and validate the lambda function
        if not lambda_str:
            self.lambda_func = lambda sample: sample
        else:
            self.lambda_func = self._create_lambda(lambda_str)

    def _create_lambda(self, lambda_str: str):
        """Parse, validate and compile ``lambda_str`` into a callable.

        :param lambda_str: Source text of a lambda expression.
        :return: The callable produced by evaluating the lambda expression.
        :raises ValueError: If the string cannot be parsed, is not a lambda
            expression, or cannot be called with exactly one positional
            argument. The original error is attached as ``__cause__``.
        """
        try:
            # ``compile``/``ast.parse`` in "eval" mode reject leading
            # whitespace (IndentationError), so strip it first.
            node = ast.parse(lambda_str.strip(), mode="eval")

            # Check if the body of the expression is a lambda
            if not isinstance(node.body, ast.Lambda):
                raise ValueError("Input string must be a valid lambda function.")

            params = node.body.args
            # Positional-only parameters (``lambda x, /: ...``) are stored
            # separately from regular ones; count both.
            positional_count = len(params.posonlyargs) + len(params.args)
            if positional_count != 1:
                raise ValueError("Lambda function must have exactly one argument.")
            # A keyword-only parameter without a default (``None`` entry in
            # ``kw_defaults``) would make the single-argument call fail later.
            if any(default is None for default in params.kw_defaults):
                raise ValueError("Lambda function must have exactly one argument.")

            # Compile the AST to code
            compiled_code = compile(node, "<string>", "eval")

            # NOTE(security): this eval exposes all built-ins to the lambda
            # body; it is not a sandbox. ``lambda_str`` must be trusted.
            func = eval(compiled_code, {"__builtins__": __builtins__})
            return func
        except Exception as e:
            # Keep the historical message format, but chain the cause so the
            # original traceback is preserved.
            raise ValueError(f"Invalid lambda function: {e}") from e

    def _apply_lambda(self, data):
        """Run the lambda on ``data`` and ensure the result is a dictionary.

        :param data: A single sample or a batch of samples.
        :return: The dictionary returned by the lambda function.
        :raises ValueError: If the lambda returns something other than a dict.
        """
        result = self.lambda_func(data)
        if not isinstance(result, dict):
            raise ValueError(
                f"Lambda function must return a dictionary, "
                f"got {type(result).__name__} instead."
            )
        return result

    def process_single(self, sample):
        """Apply the lambda function to one sample and return its dict result."""
        return self._apply_lambda(sample)

    def process_batched(self, samples):
        """Apply the lambda function to a batch of samples and return its dict result."""
        return self._apply_lambda(samples)