# Source code for data_juicer.ops.mapper.annotation.annotation_mapper

import time
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Dict, List, Optional

from loguru import logger

from data_juicer.utils.lazy_loader import LazyLoader

from ...base_op import Mapper
from ...mixins import EventDrivenMixin, NotificationMixin

label_studio_sdk = LazyLoader('label_studio_sdk')

# Event names emitted during the annotation life cycle.
# Keys are the symbolic names used in code; values are the event
# identifiers passed to the event-handler registry and used as keys
# of the `notification_events` configuration.
ANNOTATION_EVENTS = dict(
    TASK_CREATED='task_created',
    BATCH_CREATED='batch_created',
    ANNOTATION_COMPLETED='annotation_completed',
    BATCH_ANNOTATION_COMPLETED='batch_annotation_completed',
    ERROR_OCCURRED='error_occurred',
)


class BaseAnnotationMapper(EventDrivenMixin, NotificationMixin, Mapper, ABC):
    """Base class for annotation operations with event-driven capabilities.

    This class provides functionality for creating annotation tasks,
    waiting for annotations to be completed, and handling notification
    events for various stages of the annotation process.

    It supports sending notifications with annotation platform URLs
    included, making it easy for annotators to directly access the
    annotation interface when they receive a notification. This is
    particularly useful for time-sensitive annotation tasks or for
    coordinating with remote annotation teams.

    Notifications can be sent via email, Slack, or DingTalk, and can be
    configured to trigger on various events such as task creation, batch
    creation, annotation completion, and error conditions.

    URL inclusion in notifications can be enabled or disabled using the
    `include_urls_in_notifications` parameter.
    NOTE(review): that parameter is not read anywhere in this class; URLs
    are appended whenever `_get_project_url` / `_get_task_url` return a
    truthy value. Presumably the flag is handled elsewhere - confirm.

    Subclasses must implement `_create_tasks_batch`, `_format_task`,
    `_get_task_annotation`, `_process_annotation_result` and
    `_check_annotation_status`.
    """

    _batched_op = True  # Mark this as a batched operator

    def __init__(self,
                 project_name_prefix: str = 'DataJuicer_Annotation',
                 wait_for_annotations: bool = False,
                 timeout: int = 3600,
                 poll_interval: int = 60,
                 samples_per_task: int = 1,
                 max_tasks_per_batch: int = 100,
                 project_id: Optional[int] = None,
                 notification_config: Optional[Dict] = None,
                 notification_events: Optional[Dict[str, bool]] = None,
                 **kwargs):
        """Initialize the base annotation operation

        Args:
            project_name_prefix: Prefix for the project name; a short
                random suffix is appended to make the name unique
            wait_for_annotations: Whether to wait for annotations to
                complete
            timeout: Maximum time to wait for annotations in seconds
            poll_interval: Time between annotation status checks in
                seconds
            samples_per_task: Number of samples in each annotation task
            max_tasks_per_batch: Number of samples handed to one batch;
                this bounds the number of tasks created per batch
            project_id: ID of existing project (if None, creates new
                project)
            notification_config: Configuration for notifications
                (email, slack)
            notification_events: Events that should trigger
                notifications, keyed by the values of ANNOTATION_EVENTS
            **kwargs: Forwarded to the parent classes
        """
        # Ensure notification_config is passed to kwargs for
        # NotificationMixin
        kwargs['notification_config'] = notification_config or {}

        # Initialize parent classes
        super().__init__(**kwargs)

        # Store configuration
        self.project_name = project_name_prefix + '_' + str(
            uuid.uuid4())[:6]
        self.project_id = project_id
        self.wait_for_annotations = wait_for_annotations
        self.timeout = timeout
        self.poll_interval = poll_interval
        self.samples_per_task = samples_per_task
        self.max_tasks_per_batch = max_tasks_per_batch

        # Configure which events should trigger notifications
        self.notification_events = notification_events or {
            'task_created': False,
            'batch_created': True,
            'annotation_completed': False,
            'batch_annotation_completed': True,
            'error_occurred': True
        }

        # Track task IDs and sample mappings
        self.sample_to_task_id = {}  # Maps sample ID to task ID
        self.task_to_samples = defaultdict(list)  # Task ID -> sample IDs
        self.processed_annotations = set()  # Task IDs seen as completed
        self.batch_counter = 0

        # Set up event handlers
        self._setup_event_handlers()

    def _setup_event_handlers(self):
        """Register the default handler for every annotation event"""
        self.register_event_handler(ANNOTATION_EVENTS['TASK_CREATED'],
                                    self._handle_task_created)
        self.register_event_handler(ANNOTATION_EVENTS['BATCH_CREATED'],
                                    self._handle_batch_created)
        self.register_event_handler(
            ANNOTATION_EVENTS['ANNOTATION_COMPLETED'],
            self._handle_annotation_completed)
        self.register_event_handler(
            ANNOTATION_EVENTS['BATCH_ANNOTATION_COMPLETED'],
            self._handle_batch_annotation_completed)
        self.register_event_handler(ANNOTATION_EVENTS['ERROR_OCCURRED'],
                                    self._handle_error)

    def _handle_task_created(self, data):
        """Handle task created event with notification

        Args:
            data: Event payload; reads `task_id` and `sample_ids`
        """
        task_id = data.get('task_id')
        sample_ids = data.get('sample_ids', [])

        logger.debug(f'Task {task_id} created with {len(sample_ids)} samples')

        # Send notification if configured
        if self.notification_events.get('task_created', False):
            self.send_notification(
                f'Task {task_id} created with {len(sample_ids)} '
                f'samples for annotation in project {self.project_name}',
                subject=f'New Annotation Task Created - {self.project_name}')

    def _handle_batch_created(self, data):
        """Handle batch created event with notification

        Args:
            data: Event payload; reads `batch_id`, `task_count` and
                `sample_count`
        """
        batch_id = data.get('batch_id')
        task_count = data.get('task_count', 0)
        sample_count = data.get('sample_count', 0)

        logger.info(
            f'Batch {batch_id} created with {task_count} tasks containing '
            f'{sample_count} samples')

        # Send notification if configured
        if self.notification_events.get('batch_created', False):
            message = (
                f'Batch {batch_id} created with {task_count} tasks containing '
                f'{sample_count} samples in project {self.project_name}')

            # Append the project URL only when the platform provides one
            project_url = self._get_project_url()
            if project_url:
                message += f'\n\nAnnotation URL: {project_url}'

            self.send_notification(
                message,
                subject=f'New Annotation Batch Created - {self.project_name}')

    def _handle_annotation_completed(self, data):
        """Handle annotation completed event with notification

        Marks the task as processed regardless of whether a notification
        is sent.

        Args:
            data: Event payload; reads `task_id` and `annotation_id`
        """
        task_id = data.get('task_id')
        annotation_id = data.get('annotation_id')

        logger.debug(
            f'Annotation {annotation_id} completed for task {task_id}')

        # Mark this task as processed
        self.processed_annotations.add(task_id)

        # Send notification if configured
        if self.notification_events.get('annotation_completed', False):
            sample_count = len(self.task_to_samples.get(task_id, []))

            message = (
                f'Annotation {annotation_id} completed for task {task_id} '
                f'with {sample_count} samples in project {self.project_name}')

            # Append the task URL only when the platform provides one
            task_url = self._get_task_url(task_id)
            if task_url:
                message += f'\n\nTask URL: {task_url}'

            self.send_notification(
                message, subject=f'Annotation Completed - {self.project_name}')

    def _handle_batch_annotation_completed(self, data):
        """Handle batch annotation completed event with notification

        Args:
            data: Event payload; reads `batch_id`, `task_count` and
                `sample_count`
        """
        batch_id = data.get('batch_id')
        task_count = data.get('task_count', 0)
        sample_count = data.get('sample_count', 0)

        if self.notification_events.get('batch_annotation_completed', False):
            message = (
                f'Batch {batch_id} done with {task_count} tasks containing '
                f'{sample_count} samples in project {self.project_name}')

            # Append the project URL only when the platform provides one
            project_url = self._get_project_url()
            if project_url:
                message += f'\n\nProject URL: {project_url}'

            self.send_notification(
                message,
                subject=f'Batch Annotation Completed - {self.project_name}')

    def _handle_error(self, data):
        """Handle error event with notification

        Args:
            data: Event payload; reads `message` and `task_id`. Batch-level
                errors raised by this class carry `batch_id` instead of
                `task_id`, so the task is reported as 'Unknown task'.
        """
        error_message = data.get('message', 'Unknown error')
        task_id = data.get('task_id', 'Unknown task')

        logger.error(f'Error in annotation task {task_id}: {error_message}')

        # Send notification if configured
        if self.notification_events.get('error_occurred', False):
            self.send_notification(
                f'Error in annotation task {task_id}: {error_message}',
                subject=f'Annotation Error - {self.project_name}',
                notification_type='email'  # Always send errors via email
            )

    @abstractmethod
    def _create_tasks_batch(self, tasks_data: List[Dict],
                            sample_ids: List[Any]) -> List[int]:
        """Create multiple tasks in the annotation platform

        Args:
            tasks_data: List of task data
            sample_ids: Flat list of the sample IDs contained in the tasks

        Returns:
            List[int]: List of created task IDs, in the order of
                `tasks_data`
        """
        pass

    @abstractmethod
    def _format_task(self, samples: List[Dict]) -> Dict:
        """Format samples as an annotation task

        Args:
            samples: List of samples to include in the task

        Returns:
            Dict: Formatted task data
        """
        pass

    @abstractmethod
    def _get_task_annotation(self, task_id: int) -> Optional[Dict]:
        """Get annotation for a task if available"""
        pass

    @abstractmethod
    def _process_annotation_result(self, annotation: Dict,
                                   sample: Dict) -> Dict:
        """Process annotation result and update the sample

        Args:
            annotation: The annotation result from the annotation platform
            sample: The original sample that was annotated

        Returns:
            Dict: The updated sample with annotation results
        """
        pass

    @abstractmethod
    def _check_annotation_status(self, task_ids):
        """Check the status of annotations for the given task IDs.

        Args:
            task_ids: List of task IDs to check

        Returns:
            Tuple[bool, Dict]: (has_changes, completed_tasks_dict)
                - has_changes: If there are new annotations since last
                  check
                - completed_tasks_dict: Dictionary for task IDs to
                  annotations
        """
        pass

    def process_batched(self, samples):
        """Process a batch of samples by creating annotation tasks

        Args:
            samples: Dictionary of samples to process (column-oriented)

        Returns:
            Dict: Processed samples (column-oriented). Column order
                follows the first appearance of each key in the samples.
        """
        keys = list(samples.keys())
        if not keys:
            # No columns at all: nothing to annotate, and indexing
            # keys[0] below would raise IndexError.
            return samples
        num_samples = len(samples[keys[0]])

        # Step 1: Convert to row-oriented format (list of sample dicts)
        sample_list = [{key: samples[key][i]
                        for key in keys}
                       for i in range(num_samples)]

        # Step 2: Retrieve or generate unique IDs for each sample
        # TODO: Use ID from database or metadata if not present
        if 'id' not in keys:
            sample_ids = [str(uuid.uuid4()) for _ in range(num_samples)]
        else:
            sample_ids = [sample['id'] for sample in sample_list]

        # Step 3: Process samples in batches
        for batch_start in range(0, num_samples, self.max_tasks_per_batch):
            batch_end = min(batch_start + self.max_tasks_per_batch,
                            num_samples)

            # Prepare tasks for this batch
            tasks_data = []
            task_sample_ids = []

            for i in range(batch_start, batch_end, self.samples_per_task):
                end_idx = min(i + self.samples_per_task, batch_end)
                batch_samples = sample_list[i:end_idx]
                batch_ids = sample_ids[i:end_idx]

                # Format the samples as a task
                tasks_data.append(self._format_task(batch_samples))
                task_sample_ids.append(batch_ids)

            # Create and process this batch of tasks
            self._create_and_process_batch(tasks_data, task_sample_ids)

        # Step 4: Update samples with annotation results
        processed_count = 0
        for i, sample_id in enumerate(sample_ids):
            if sample_id in self.sample_to_task_id:
                task_id = self.sample_to_task_id[sample_id]

                # If waiting for annotations and they're available,
                # add them
                if (self.wait_for_annotations
                        and task_id in self.processed_annotations):
                    annotation = self._get_task_annotation(task_id)
                    if annotation:
                        # Process the annotation result
                        sample_list[i] = self._process_annotation_result(
                            annotation, sample_list[i])
                        processed_count += 1

        # Step 5: Convert back to column-oriented format.
        # Collect every key present in any sample in first-seen order so
        # the output column order is deterministic (a set union is not).
        if sample_list:
            column_names = list(
                dict.fromkeys(key for sample in sample_list
                              for key in sample))
        else:
            # Zero rows: keep the input columns instead of dropping them
            column_names = keys

        result = {
            key: [sample.get(key) for sample in sample_list]
            for key in column_names
        }

        logger.info(f'Processed {num_samples} samples with {processed_count} '
                    'annotations')
        return result

    def _create_and_process_batch(self, tasks_data, task_sample_ids):
        """Create a batch of tasks and process the results

        Creates the tasks on the platform, records the sample/task
        mappings, fires the task/batch events and, when
        `wait_for_annotations` is set, blocks until the batch is annotated
        or the wait gives up. Any exception is reported through the
        ERROR_OCCURRED event and logged; it is not re-raised.

        Args:
            tasks_data: List of formatted task payloads
            task_sample_ids: One list of sample IDs per task, aligned
                with `tasks_data`
        """
        if not tasks_data:
            return

        # Generate a batch ID
        batch_id = f'batch_{self.batch_counter}_{int(time.time())}'
        self.batch_counter += 1

        # Flatten the sample IDs for this batch
        all_batch_sample_ids = [
            sample_id for ids in task_sample_ids for sample_id in ids
        ]

        try:
            # Create the tasks in the annotation platform
            task_ids = self._create_tasks_batch(tasks_data,
                                                all_batch_sample_ids)

            # Map sample IDs to task IDs
            for task_id, sample_id_list in zip(task_ids, task_sample_ids):
                for sample_id in sample_id_list:
                    self.sample_to_task_id[sample_id] = task_id

                self.task_to_samples[task_id] = sample_id_list

                # Trigger task created event
                self.trigger_event(ANNOTATION_EVENTS['TASK_CREATED'], {
                    'task_id': task_id,
                    'sample_ids': sample_id_list
                })

            # Trigger batch created event
            self.trigger_event(
                ANNOTATION_EVENTS['BATCH_CREATED'], {
                    'batch_id': batch_id,
                    'task_count': len(task_ids),
                    'sample_count': len(all_batch_sample_ids)
                })

            # If waiting for annotations, wait for them directly
            if self.wait_for_annotations:
                try:
                    # Wait for all tasks in this batch to be annotated
                    logger.info(
                        f'Waiting for annotations for batch {batch_id}')
                    completed_tasks = self._wait_for_batch_annotations(
                        task_ids)
                    logger.info(
                        f'Completed {len(completed_tasks)}/{len(task_ids)} '
                        f'annotations for batch {batch_id}')
                except TimeoutError as e:
                    # Trigger error event but continue processing
                    self.trigger_event(
                        ANNOTATION_EVENTS['ERROR_OCCURRED'], {
                            'message':
                            (f'Timeout waiting for batch annotations: {str(e)}'
                             ),
                            'batch_id':
                            batch_id
                        })

                # Trigger batch annotation completed event (fired after
                # the wait ends, whether or not every task was annotated)
                self.trigger_event(
                    ANNOTATION_EVENTS['BATCH_ANNOTATION_COMPLETED'], {
                        'batch_id': batch_id,
                        'task_count': len(task_ids),
                        'sample_count': len(all_batch_sample_ids)
                    })

        except Exception as e:
            # Trigger error event
            self.trigger_event(
                ANNOTATION_EVENTS['ERROR_OCCURRED'], {
                    'message': f'Failed to create batch: {str(e)}',
                    'batch_id': batch_id
                })
            logger.error(f'Error creating batch {batch_id}: {str(e)}')

    def _sleep_until_next_poll(self, deadline):
        """Sleep one poll interval, but never past `deadline`

        Args:
            deadline: Absolute `time.time()` value at which waiting ends
        """
        time.sleep(max(0.0, min(self.poll_interval, deadline - time.time())))

    def _wait_for_batch_annotations(self, task_ids=None):
        """Wait for all tasks in a batch to be annotated.

        Polls `_check_annotation_status` every `poll_interval` seconds
        until every task is annotated or `timeout` seconds have elapsed.
        A timeout is logged as a warning; it does not raise.

        Args:
            task_ids: List of task IDs to wait for

        Returns:
            Dict: Mapping of task IDs to their annotations
        """
        if not self.wait_for_annotations or not task_ids:
            return {}

        logger.info(f'Waiting for annotations for {len(task_ids)} tasks')

        start_time = time.time()
        deadline = start_time + self.timeout
        completed_tasks = {}
        # Tasks already marked as processed do not need to be waited for
        remaining_tasks = set(task_ids) - self.processed_annotations

        # Track the last time we saw a change in annotations
        last_change_time = time.time()

        while time.time() - start_time < self.timeout and remaining_tasks:
            try:
                # Check for new annotations using the platform-specific
                # method
                has_changes, new_completed_tasks = \
                    self._check_annotation_status(list(remaining_tasks))

                # Update our completed tasks
                completed_tasks.update(new_completed_tasks)

                # Update remaining tasks and trigger events for
                # completed tasks
                for task_id in new_completed_tasks:
                    if task_id in remaining_tasks:
                        remaining_tasks.remove(task_id)
                        self.processed_annotations.add(task_id)

                        # Trigger annotation completed event
                        annotation = new_completed_tasks[task_id]
                        self.trigger_event(
                            ANNOTATION_EVENTS['ANNOTATION_COMPLETED'], {
                                'task_id':
                                task_id,
                                'annotation_id':
                                annotation.get('id', 'unknown'),
                                'annotation':
                                annotation,
                                'sample_ids':
                                self.task_to_samples.get(task_id, [])
                            })

                # Note a long quiet period. Only the timer is reset here;
                # no separate "full check" call exists in this class.
                poll_timeout = max(self.poll_interval * 5, 30)
                if (not has_changes
                        and time.time() - last_change_time > poll_timeout):
                    logger.info('No new annotations detected for a while, '
                                'performing full check')
                    last_change_time = time.time()  # Reset the timer
                elif has_changes:
                    last_change_time = time.time()

                # Log progress periodically
                logger.info(
                    f'Completed {len(completed_tasks)}/{len(task_ids)} '
                    f'annotations, {len(remaining_tasks)} remaining')

                # If all tasks are annotated, we're done
                if not remaining_tasks:
                    return completed_tasks

                # Sleep before checking again, capped at the deadline so
                # the wait does not overshoot `timeout`
                self._sleep_until_next_poll(deadline)

            except Exception as e:
                # Polling is best-effort: log and retry until the deadline
                logger.error(f'Error polling for annotations: {e}')
                self._sleep_until_next_poll(deadline)

        # If we get here, we timed out or completed all annotations
        if remaining_tasks:
            logger.warning(
                f'Timed out waiting for {len(remaining_tasks)} annotations')

        return completed_tasks

    def _get_project_url(self):
        """Get URL to the annotation project

        Default hook that returns None; platform-specific subclasses may
        override it to provide a link for notifications.

        Returns:
            str: URL to the annotation project, or None if not available
        """
        return None

    def _get_task_url(self, task_id):
        """Get URL to a specific annotation task

        Default hook that returns None; platform-specific subclasses may
        override it to provide a link for notifications.

        Args:
            task_id: ID of the task

        Returns:
            str: URL to the task, or None if not available
        """
        return None
class LabelStudioAnnotationMapper(BaseAnnotationMapper, ABC):
    """Operation for annotating data using Label Studio

    Still abstract: `_format_task` and `_process_annotation_result` are
    not implemented here and must be supplied by a subclass.
    """

    def __init__(self,
                 api_url: Optional[str] = None,
                 api_key: Optional[str] = None,
                 label_config: Optional[str] = None,
                 **kwargs):
        """Initialize the Label Studio annotation operation

        Args:
            api_url: Base URL for Label Studio API
            api_key: API key for authentication
            label_config: XML configuration for the labeling interface
            **kwargs: Additional parameters passed to
                BaseAnnotationMapper

        Raises:
            ImportError: If label-studio-sdk cannot be imported
            Exception: Whatever the client raises when the project
                cannot be created or fetched
        """
        # Initialize parent classes
        super().__init__(**kwargs)

        # Make sure samples_per_task is 1
        # Label studio only supports 1 sample per task
        if self.samples_per_task != 1:
            logger.warning(
                'Label Studio Annotation Mapper only supports 1 sample '
                'per task, but samples_per_task is set to '
                f'{self.samples_per_task}. Setting samples_per_task to 1.')
            self.samples_per_task = 1

        # Store Label Studio specific configuration
        self.api_url = api_url
        self.api_key = api_key
        self.label_config = label_config

        # Initialize Label Studio client; the SDK module is loaded
        # lazily, so a missing package surfaces here as ImportError
        try:
            self.client = label_studio_sdk.Client(url=self.api_url,
                                                  api_key=self.api_key)
            logger.info(f'Connected to Label Studio at {self.api_url}')
        except ImportError as e:
            logger.error(
                'Failed to import label_studio_sdk. Please install it with: '
                'pip install label-studio-sdk')
            raise ImportError(
                'label-studio-sdk is required for LabelStudioAnnotationMapper'
            ) from e

        # Initialize project
        if self.project_id is None:
            self.project = self.setup_project()
            self.project_id = self.project.id
        else:
            try:
                self.project = self.client.get_project(self.project_id)
                logger.info(
                    f'Using existing project with ID: {self.project_id}')
            except Exception as e:
                logger.error(
                    f'Failed to get project with ID {self.project_id}: {e}')
                raise

    def setup_project(self):
        """Create a new Label Studio project named `self.project_name`

        Returns:
            The project object returned by the Label Studio client

        Raises:
            Exception: Re-raises whatever the client raises on failure
        """
        try:
            # Create new project
            logger.info(
                f'Creating new Label Studio project: {self.project_name}')
            project = self.client.create_project(
                title=self.project_name,
                description='Created by Data Juicer',
                label_config=self.label_config)

            logger.info(
                f'Created new Label Studio project with ID: {project.id}')
            return project
        except Exception as e:
            logger.error(f'Failed to create project: {e}')
            raise

    def _create_tasks_batch(self, tasks_data: List[Dict],
                            sample_ids: List[Any]) -> List[int]:
        """Create multiple tasks in Label Studio

        Args:
            tasks_data: List of task data
            sample_ids: List of sample IDs corresponding to each task
                (not used by this implementation)

        Returns:
            List[int]: Whatever `project.import_tasks` returns, used by
                the caller as the list of created task IDs

        Raises:
            Exception: Re-raises whatever the client raises on failure
        """
        try:
            # Create tasks in the project
            created_tasks = self.project.import_tasks(tasks_data)

            logger.debug(f'Created {len(created_tasks)} tasks in project '
                         f'{self.project_id}')
            return created_tasks
        except Exception as e:
            logger.error(f'Failed to create tasks: {e}')
            raise

    def _check_annotation_status(self, task_ids=None):
        """Check the status of annotations for the given task IDs

        Only tasks that were requested and are not yet recorded in
        `self.processed_annotations` are reported. Newly completed tasks
        are added to `self.processed_annotations`. Errors from the client
        are logged and whatever was collected so far is returned.

        Args:
            task_ids: List of task IDs to check. If None, uses all tasks
                created by this operation

        Returns:
            Tuple[bool, Dict]: (has_changes, completed_tasks_dict)
        """
        # Handle the case where task_ids is not provided
        if task_ids is None:
            task_ids = list(self.task_to_samples.keys())

        if not task_ids:
            return False, {}

        # Initialize tracking variables
        has_changes = False
        completed_tasks = {}

        # Filter out tasks we already know are completed
        remaining_tasks = [
            tid for tid in task_ids if tid not in self.processed_annotations
        ]

        # If all tasks are already processed, return immediately
        if not remaining_tasks:
            return False, {}

        # Set form for O(1) membership tests in the loop below
        pending = set(remaining_tasks)

        logger.debug(
            f'Checking {len(remaining_tasks)} tasks (skipping '
            f'{len(task_ids) - len(remaining_tasks)} already processed)')

        try:
            # Use filters to get completed tasks efficiently
            completed_task_data = self.project.get_tasks(
                filters={
                    'conjunction':
                    'and',
                    'items': [{
                        'filter': 'filter:tasks:completed_at',
                        'operator': 'empty',
                        'value': False,
                        'type': 'Datetime'
                    }]
                })

            # Process newly completed tasks
            for task in completed_task_data:
                task_id = task['id']

                # The query returns every completed task in the project;
                # skip those that were not requested or were already
                # processed so they are not reported as changes.
                if task_id not in pending:
                    continue

                has_changes = True

                # Get the annotations for this task
                if 'annotations' in task and task['annotations']:
                    # Use the latest annotation
                    completed_tasks[task_id] = task['annotations'][-1]
                else:
                    # If task is marked as completed but has no
                    # annotations, fetch the full task to get annotations
                    full_task = self.project.get_task(task_id)
                    if 'annotations' in full_task and full_task[
                            'annotations']:
                        completed_tasks[task_id] = full_task['annotations'][
                            -1]
                    else:
                        # Task is completed but has no annotations
                        # (unusual)
                        logger.warning(
                            f'Task {task_id} is marked as completed but has '
                            'no annotations')
                        completed_tasks[task_id] = {
                            'id': 'no_annotation',
                            'result': []
                        }

                logger.debug(f'Task {task_id} is newly completed')

            # Update our cache of processed annotations
            self.processed_annotations.update(completed_tasks)

        except Exception as e:
            logger.error(f'Error checking annotation status: {e}')

        return has_changes, completed_tasks

    def _get_task_annotation(self, task_id: int) -> Optional[Dict]:
        """Get annotation for a task if available

        Args:
            task_id: ID of the task

        Returns:
            Optional[Dict]: Annotation data, or None if not yet annotated
                or if fetching the task failed (the error is logged)
        """
        try:
            # Get task with annotations
            task = self.project.get_task(task_id)

            # Check if task has annotations
            if task and 'annotations' in task and task['annotations']:
                # NOTE(review): returns the FIRST annotation, while
                # _check_annotation_status uses the last one; confirm
                # which is intended for multiply-annotated tasks.
                return task['annotations'][0]
            return None
        except Exception as e:
            logger.error(f'Failed to get task annotation: {e}')
            return None

    def get_all_annotations(self) -> Dict[int, Dict]:
        """Get all annotations for tasks created by this operation

        Returns:
            Dict[int, Dict]: Dictionary mapping task IDs to annotations;
                tasks without an annotation are omitted
        """
        annotations = {}
        for task_id in list(self.task_to_samples.keys()):
            annotation = self._get_task_annotation(task_id)
            if annotation:
                annotations[task_id] = annotation
        return annotations

    # for pickling
    def __getstate__(self):
        """Control how the object is pickled

        Returns:
            dict: Copy of the instance state without the `client` and
                `project` attributes, which are removed as unpicklable
        """
        state = self.__dict__.copy()
        # Remove unpicklable attributes
        state.pop('client', None)  # Remove Label Studio client
        state.pop('project', None)  # Remove project reference
        return state

    # for unpickling
    def __setstate__(self, state):
        """Control how the object is unpickled

        Restores the attributes and rebuilds the client and project
        objects dropped by `__getstate__`. If the SDK cannot be imported
        the instance is restored without `client` / `project` and a
        warning is logged. Other client errors propagate.
        """
        self.__dict__.update(state)
        # Reconnect to Label Studio if needed
        if hasattr(self, 'api_url') and hasattr(self, 'api_key'):
            try:
                self.client = label_studio_sdk.Client(url=self.api_url,
                                                      api_key=self.api_key)
                if hasattr(self, 'project_id'):
                    self.project = self.client.get_project(self.project_id)
            except ImportError:
                # Best effort: keep the object usable for code paths that
                # do not touch Label Studio, but make the gap visible.
                logger.warning(
                    'label_studio_sdk is not available; restored '
                    'LabelStudioAnnotationMapper without a Label Studio '
                    'client')

    def _get_web_base_url(self):
        """Derive the Label Studio web UI base URL from `api_url`

        Strips trailing slashes and a trailing `/api` path segment only.
        A blanket replace of '/api' would also damage host names such as
        `https://api.example.com`.

        Returns:
            str: Base URL of the web UI, or None when `api_url` or
                `project_id` is missing
        """
        api_url = getattr(self, 'api_url', None)
        if not api_url or not self.project_id:
            return None

        base_url = api_url.rstrip('/')
        if base_url.endswith('/api'):
            base_url = base_url[:-len('/api')]
        return base_url

    def _get_project_url(self):
        """Get URL to the Label Studio project

        Returns:
            str: URL to the Label Studio project, or None if not
                available
        """
        base_url = self._get_web_base_url()
        if base_url is None:
            return None
        return f'{base_url}/projects/{self.project_id}'

    def _get_task_url(self, task_id):
        """Get URL to a specific Label Studio task

        Args:
            task_id: ID of the task

        Returns:
            str: URL to the task in Label Studio, or None if not
                available
        """
        base_url = self._get_web_base_url()
        if base_url is None:
            return None
        return f'{base_url}/projects/{self.project_id}/data?task={task_id}'