fromtypingimportListimportnumpyasnpfromloguruimportloggerfromdata_juicer.ops.base_opimportOP,OPERATORS,Filter,Mapperfromdata_juicer.ops.loadimportload_opsfromdata_juicer.utils.constantimportFields,InterVarsfromdata_juicer.utils.registryimportRegistry# Type of intermediate vars# textINTER_LINES=Registry(InterVars.lines)INTER_WORDS=Registry(InterVars.words)# imagesLOADED_IMAGES=Registry(InterVars.loaded_images)# audiosLOADED_AUDIOS=Registry(InterVars.loaded_audios)# videosLOADED_VIDEOS=Registry(InterVars.loaded_videos)INTER_SAMPLED_FRAMES=Registry(InterVars.sampled_frames)# allALL_INTER_VARS=[INTER_LINES,INTER_WORDS,LOADED_AUDIOS,LOADED_IMAGES,LOADED_VIDEOS,INTER_SAMPLED_FRAMES]# supported fusion strategiesFUSION_STRATEGIES={"greedy","probe"}
def fuse_operators(ops, probe_res=None):
    """
    Fuse the input ops list and return the fused ops list.

    :param ops: the corresponding list of op objects.
    :param probe_res: the probed speed for each OP from Monitor.
    :return: a list of fused op objects.
    """
    if probe_res is None:
        probe_res = [None] * len(ops)

    result = []
    pending_filters = []  # the current run of consecutive Filter ops

    def _flush_filters():
        # fuse and emit the collected run of filters, if any
        if pending_filters:
            result.extend(fuse_filter_group(pending_filters))
            pending_filters.clear()

    for op, speed_info in zip(ops, probe_res):
        if isinstance(op, Filter):
            # extend the current run of consecutive filters
            pending_filters.append((op, speed_info))
        else:
            # a non-filter op breaks the run: fuse what we collected so far,
            # then pass the non-filter op through unchanged
            _flush_filters()
            result.append(op)
    # the list may end with a trailing run of filters -- fuse it too
    _flush_filters()
    return result
def fuse_filter_group(original_filter_group):
    """
    Fuse single filter group and return the fused filter group.

    :param original_filter_group: the original filter group, including op
        definitions and objects.
    :return: the fused definitions and objects of the input filter group.
    """
    fused_group = []
    group_speed = []
    # bucket the filters by the type of intermediate var they rely on
    buckets = {inter_vars: [] for inter_vars in ALL_INTER_VARS}

    for op, probe_res in original_filter_group:
        for inter_vars in ALL_INTER_VARS:
            if op._name in inter_vars.modules:
                buckets[inter_vars].append((op, probe_res))
                break
        else:
            # filters that use no shared intermediate var are kept as-is;
            # applying them first decreases the number of samples reaching
            # the fused ops
            fused_group.append(op)
            group_speed.append(probe_res["speed"] if probe_res else 0)

    # try to fuse the ops bucketed under each type of intermediate var
    for inter_vars in ALL_INTER_VARS:
        bucket = buckets[inter_vars]
        if not bucket:
            # no op relies on this type of intermediate var
            continue
        if len(bucket) == 1:
            # a lone op cannot be fused -- add it to the group unchanged
            single_op, single_probe = bucket[0]
            fused_group.append(single_op)
            group_speed.append(single_probe["speed"] if single_probe else 0)
            continue
        # more than one op shares this intermediate var: merge them into a
        # single FusedFilter so the var is computed only once
        ops, probe_res_list = zip(*bucket)
        fused_filter_name = "OpFusion:(%s)" % ",".join([op._name for op in ops])
        logger.info(f"Ops are fused into one op {fused_filter_name}.")
        fused_filter = FusedFilter(fused_filter_name, ops)
        fused_filter._op_cfg = {fused_filter_name: [op._op_cfg for op in ops]}
        # aggregate member speeds harmonically (ops run back to back);
        # missing probe results are skipped, and a zero aggregate is kept
        # as 0 to mean "no speed data"
        inv_speed = sum(1.0 / pr["speed"] for pr in probe_res_list if pr)
        fused_speed = 1.0 / inv_speed if inv_speed > 0 else inv_speed
        fused_group.append(fused_filter)
        group_speed.append(fused_speed)

    # reorder according to the probed speed results in group_speed:
    # 'greedy': all speeds are 0, so the stable sort keeps the current order
    # 'probe': OPs are reordered by probed speed in descending order
    ordered = sorted(
        zip(fused_group, group_speed), key=lambda it: it[1], reverse=True
    )
    return [op for op, _ in ordered]
class FusedFilter(Filter):
    """A fused operator for filters."""

    # fused filters always operate on batches (see compute_stats_batched /
    # process_batched below)
    _batched_op = True
[文档]def__init__(self,name:str,fused_filters:List):""" Initialization method. :param fused_filters: a list of filters to be fused. """self._name=namesuper().__init__()self.fused_filters=fused_filters# set accelerator to 'cuda' if there exists any ops whose accelerator# is 'cuda'accelerator_methods=set([op.acceleratorforopinself.fused_filters])if"cuda"inaccelerator_methods:self.accelerator="cuda"# update num_proc with the min num_proc of all fusible filtersself.num_proc=min([op.runtime_np()foropinself.fused_filters])
[文档]defcompute_stats_batched(self,samples,rank=None):importav# context for the intermediate varsnum_samples=len(samples[Fields.stats])samples[Fields.context]=[{}for_inrange(num_samples)]foropinself.fused_filters:# open the context for these fused opsifop.accelerator=="cuda":samples=op.compute_stats_batched(samples,rank=rank,context=True)else:samples=op.compute_stats_batched(samples,context=True)# clean up the contexts after processing# check if there are containers that need to be closedforctxinsamples[Fields.context]:forcontext_keyinctx:ifisinstance(ctx[context_key],av.container.InputContainer):ctx[context_key].streams.video[0].close()ctx[context_key].close()_=samples.pop(Fields.context)returnsamples
[文档]defprocess_batched(self,samples):# Only return True when all filters return Trueres=Noneforopinself.fused_filters:this_res=np.array(list(op.process_batched(samples)))ifresisnotNone:res=np.logical_and(res,this_res)else:res=this_resreturnres
@OPERATORS.register_module("general_fused_op")
class GeneralFusedOP(OP):
    """An explicitly fused operator designed to execute multiple sequential
    operations (OPs) on the same batch, enabling fine-grained control over
    data processing."""

    # this op always operates on batches (see process_batched below)
    _batched_op = True
[文档]def__init__(self,batch_size:int=1,fused_op_list:List=None,*args,**kwargs):super().__init__(*args,**kwargs)self.batch_size=batch_sizeiffused_op_listisNone:fused_op_list=[]self.fused_ops=load_ops(fused_op_list)self._name="GeneralFusedOP:(%s)"%",".join([op._nameforopinself.fused_ops])# set accelerator to 'cuda' if there exists any ops whose accelerator# is 'cuda'accelerator_methods=set([op.acceleratorforopinself.fused_ops])if"cuda"inaccelerator_methods:self.accelerator="cuda"# update num_proc with the min num_proc of all fusible filtersself.num_proc=min([op.runtime_np()foropinself.fused_ops])ifself.fused_opselse1
[文档]defprocess_batched(self,samples,rank=None):foropinself.fused_ops:process_args={"rank":rank}ifop.accelerator=="cuda"else{}ifisinstance(op,Mapper):samples=op.process_batched(samples,**process_args)elifisinstance(op,Filter):samples=op.compute_stats_batched(samples,**process_args)indicators=list(op.process_batched(samples))new_samples={}forkeyinsamples:new_samples[key]=[valforval,indicatorinzip(samples[key],indicators)ifindicator]samples=new_sampleselse:raiseNotImplementedError(f"FusedOP does not support OP {op._name} of type "f"{type(op)} and only supports Mapper and Filter now.")returnsamples
    def run(self, dataset, *, exporter=None, tracer=None):
        """Run all fused ops on the given dataset and return the result.

        :param dataset: the dataset to process; wrapped into a NestedDataset
            if it is not one already.
        :param exporter: unused in this implementation; kept for interface
            compatibility with OP.run.
        :param tracer: unused in this implementation; kept for interface
            compatibility with OP.run.
        """
        # prepare the dataset
        from data_juicer.core.data import NestedDataset
        if not isinstance(dataset, NestedDataset):
            dataset = NestedDataset(dataset)
        # nothing to do without fused ops
        if not self.fused_ops:
            return dataset
        # initialize for different kinds of datasets
        # NOTE(review): the base-class OP.run is invoked for each fused op --
        # presumably it only prepares the dataset (e.g. adds required
        # columns) rather than executing the op itself, since the actual
        # processing happens in the map below; confirm against OP.run.
        for op in self.fused_ops:
            dataset = OP.run(op, dataset)
        new_dataset = dataset.map(
            self.process_batched,
            num_proc=self.num_proc,
            with_rank=self.use_cuda(),
            batch_size=self.batch_size,
            desc=self._name + "_process",
        )
        return new_dataset