funasr.register (1)
funasr.models (39)
bat
branchformer
conformer
conformer_rwkv
e_branchformer
sanm
funasr.utils (15)
datadir_writer
export_utils
install_model_requirements
funasr.bin (6)
compute_audio_cmvn
export
inference
train
train_ds

API Reference

250 entries · auto-generated from source
Click "Source code" to expand it. Links point to the latest code on GitHub.

function

is_npu_available()

funasr.auto.auto_model · View on GitHub ↗

检查NPU是否可用。

📄 Source code
def is_npu_available():
    """Report whether an NPU can be used through the optional ``torch_npu`` package.

    A missing (un-importable) ``torch_npu`` is treated as "no NPU" rather
    than as an error.

    Returns:
        ``False`` if importing ``torch_npu`` (or the availability query)
        raises ``ImportError``; otherwise the value returned by
        ``torch_npu.npu.is_available()``.
    """
    npu_ok = False
    try:
        import torch_npu as npu_backend

        npu_ok = npu_backend.npu.is_available()
    except ImportError:
        # Optional dependency is absent: report that no NPU is available.
        npu_ok = False
    return npu_ok
function

prepare_data_iterator(data_in, input_len, data_type, key)

funasr.auto.auto_model · View on GitHub ↗

No documentation yet.

📄 Source code
def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
    """ """
    data_list = []
    key_list = []
    filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]

    chars = string.ascii_letters + string.digits
    if isinstance(data_in, str):
        if data_in.startswith("http://") or data_in.startswith("https://"):  # url
            data_in = download_from_url(data_in)

    if isinstance(data_in, str) and os.path.exists(
        data_in
    ):  # wav_path; filelist: wav.scp, file.jsonl;text.txt;
        _, file_extension = os.path.splitext(data_in)
        file_extension = file_extension.lower()
        if file_extension in filelist:  # filelist: wav.scp, file.jsonl;text.txt;
            with open(data_in, encoding="utf-8") as fin:
                for line in fin:
                    key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
                    if data_in.endswith(".jsonl"):  # file.jsonl: json.dumps({"source": data})
                        lines = json.loads(line.strip())
                        data = lines["source"]
                        key = lines.get("key", key)
                    else:  # filelist, wav.scp, text.txt: id \t data or data
                        lines = line.strip().split(maxsplit=1)
                        data = lines[1] if len(lines) > 1 else lines[0]
                        key = lines[0] if len(lines) > 1 else key

                    data_list.append(data)
View full source on GitHub →
class

AutoModel

funasr.auto.auto_model · View on GitHub ↗

No documentation yet.

📄 Source code
class AutoModel:

    def __init__(self, **kwargs):
        """Initialize AutoModel with ASR model and optional sub-models.

        Args:
            model (str): Model name (hub alias or full ID) or local path.
            device (str): Device for inference. "cuda:0", "cpu", "mps", "npu:0".
                Falls back to CPU if specified device is unavailable.
            vad_model (str, optional): VAD model for long audio segmentation.
                Enables processing of any-length audio.
            vad_kwargs (dict, optional): VAD config, e.g. {"max_single_segment_time": 60000}.
            punc_model (str, optional): Punctuation restoration model.
                Not needed for Fun-ASR-Nano/SenseVoice/Qwen3-ASR (they output punctuation natively).
            spk_model (str, optional): Speaker model for diarization ("cam++" or full model ID).
                Requires vad_model. For Qwen3-ASR, also requires forced_aligner.
            spk_mode (str, optional): Speaker diarization mode. "punc_segment" (default) or "vad_segment".
            hub (str): Model hub. "ms" (ModelScope, default) or "hf" (HuggingFace).
            ncpu (int): CPU threads (default: 4).
            disable_update (bool): Skip version check on startup.
            disable_pbar (bool): Disable tqdm progress bars.
            **kwargs: Additional model-specific parameters (passed to config.yaml overrides).

        Examples:
            >>> model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc")
            >>> model = AutoModel(model="FunAudioLLM/Fun-ASR-Nano-2512", trust_remote_code=True,
            ...                   remote_code="./model.py", vad_model="fsmn-vad", spk_model="cam++", hub="hf")
        """
        try:
            from funasr.utils.version_checker import check_for_update
View full source on GitHub →

Methods

.build_model(**kwargs) L275

Download model from hub, build all components, and load pretrained weights.

This method handles the full model construction pipeline:

1. Download model files from ModelScope/HuggingFace (if not local)

2. Parse config.yaml to determine model class, tokenizer, frontend

3. Instantiate tokenizer, frontend, and model via the registry

4. Load pretrained weights from model.pt

Args:

  • **kwargs — Must include 'model' (str). All other config.yaml fields can be overridden.

Returns:

  • tuple — (model, kwargs), where model is the instantiated nn.Module and kwargs contains the resolved configuration.
📄 Source
    def build_model(**kwargs):
        """Download model from hub, build all components, and load pretrained weights.

        This method handles the full model construction pipeline:
        1. Download model files from ModelScope/HuggingFace (if not local)
        2. Parse config.yaml to determine model class, tokenizer, frontend
        3. Instantiate tokenizer, frontend, and model via the registry
        4. Load pretrained weights from model.pt

        Args:
            **kwargs: Must include 'model' (str). All other config.yaml fields can be overridden.

        Returns:
            tuple: (model, kwargs) where model is the instantiated nn.Module and
                kwargs contains the resolved configuration.
        """
        assert "model" in kwargs
        if "model_conf" not in kwargs:
            logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
            kwargs = download_model(**kwargs)

        set_all_random_seed(kwargs.get("seed", 0))

        device = kwargs.get("device", "cuda")
        if (
            (device.startswith("cuda") and not torch.cuda.is_available())
            or (device.startswith("xpu") and not torch.xpu.is_available())
            or (device.startswith("mps") and not torch.backends.mps.is_available())
            or (device.startswith("npu") and not is_npu_available())
            or kwargs.get("ngpu", 1) == 0
View full source →
.generate(input, input_len, progress_callback, **cfg) L436

Run speech recognition on input audio.

This is the primary user-facing method. It automatically routes to:

  • inference() if no vad_model is configured (single utterance)
  • inference_with_vad() if vad_model is configured (long audio with segmentation)

Args:

  • input — Audio input. Accepts:
      ◦ File path (str): "audio.wav", "audio.mp3"
      ◦ URL (str): "https://..."
      ◦ numpy array: raw audio samples (float32, 16kHz)
      ◦ list: batch of file paths or arrays
      ◦ bytes: raw audio bytes
  • input_len (tensor, optional) — Length of each input sample.
  • progress_callback (callable, optional) — fn(current, total) called during processing.
  • **cfg — Runtime parameters:
      ◦ cache (dict): State cache for streaming mode. Pass {} for the first call.
      ◦ hotword (str/list): Keywords to boost recognition accuracy.
      ◦ language (str): Language hint ("auto", "zh", "en", "Chinese", etc.)
      ◦ batch_size_s (int): Dynamic batch total duration in seconds.
      ◦ is_final (bool): Last-chunk flag for streaming mode.
      ◦ return_spk_res (bool): Return speaker diarization results.
      ◦ sentence_timestamp (bool): Return sentence-level timestamps.
      ◦ use_itn (bool): Apply inverse text normalization (SenseVoice).

Returns:

list[dict]: Results for each input sample. Common fields:

  • "key" (str): Sample identifier
  • "text" (str): Recognized text
  • "timestamp" (list): [[start_ms, end_ms], ...] per character/word
  • "sentence_info" (list): [{text, start, end, spk, timestamp}, ...] when spk enabled
📄 Source
    def generate(self, input, input_len=None, progress_callback=None, **cfg):
        """Run speech recognition on input audio.

        This is the primary user-facing method. It automatically routes to:
        - inference() if no vad_model is configured (single utterance)
        - inference_with_vad() if vad_model is configured (long audio with segmentation)

        Args:
            input: Audio input. Accepts:
                - File path (str): "audio.wav", "audio.mp3"
                - URL (str): "https://..."
                - numpy array: raw audio samples (float32, 16kHz)
                - list: batch of file paths or arrays
                - bytes: raw audio bytes
            input_len (tensor, optional): Length of each input sample.
            progress_callback (callable, optional): fn(current, total) called during processing.
            **cfg: Runtime parameters:
                - cache (dict): State cache for streaming mode. Pass {} for first call.
                - hotword (str/list): Keywords to boost recognition accuracy.
                - language (str): Language hint ("auto", "zh", "en", "Chinese", etc.)
                - batch_size_s (int): Dynamic batch total duration in seconds.
                - is_final (bool): Last chunk flag for streaming mode.
                - return_spk_res (bool): Return speaker diarization results.
                - sentence_timestamp (bool): Return sentence-level timestamps.
                - use_itn (bool): Apply inverse text normalization (SenseVoice).

        Returns:
            list[dict]: Results for each input sample. Common fields:
                - "key" (str): Sample identifier
                - "text" (str): Recognized text
View full source →
.inference(input, input_len, model, kwargs, key, progress_callback, **cfg) L490

Run model inference on input data (internal method).

Handles batching, timing, and progress reporting. Called by generate() and inference_with_vad(). Typically not called directly by users.

Args:

  • input — Audio data, file path, or text (for punc model).
  • input_len (tensor, optional) — Input lengths for batch.
  • model (nn.Module, optional) — Override model (used for VAD/PUNC/SPK sub-models).
  • kwargs (dict, optional) — Override kwargs (used for sub-model configs).
  • key (list, optional) — Sample identifiers.
  • progress_callback (callable, optional) — Progress reporting function.
  • **cfg — Additional config merged into kwargs.

Returns:

list[dict]: Model inference results.

📄 Source
    def inference(
        self,
        input,
        input_len=None,
        model=None,
        kwargs=None,
        key=None,
        progress_callback=None,
        **cfg,
    ):
        """Run model inference on input data (internal method).

        Handles batching, timing, and progress reporting. Called by generate()
        and inference_with_vad(). Typically not called directly by users.

        Args:
            input: Audio data, file path, or text (for punc model).
            input_len (tensor, optional): Input lengths for batch.
            model (nn.Module, optional): Override model (used for VAD/PUNC/SPK sub-models).
            kwargs (dict, optional): Override kwargs (used for sub-model configs).
            key (list, optional): Sample identifiers.
            progress_callback (callable, optional): Progress reporting function.
            **cfg: Additional config merged into kwargs.

        Returns:
            list[dict]: Model inference results.
        """
        if kwargs is None:
            self._reset_runtime_configs()
        kwargs = self.kwargs if kwargs is None else kwargs
View full source →
.inference_with_vad(input, input_len, **cfg) L592

Run ASR with VAD segmentation, punctuation, and optional speaker diarization.

Pipeline:

1. VAD: Segment audio into speech regions

2. ASR: Recognize each segment (sorted by length for efficient batching)

3. Timestamp merge: Combine per-segment timestamps with VAD offsets

4. Punctuation: Add punctuation to combined text (if punc_model configured)

5. Speaker diarization: Cluster speaker embeddings and assign labels (if spk_model configured)

Args:

  • input — Audio file path, URL, or numpy array.
  • input_len — Not used (kept for interface consistency).
  • **cfg — Runtime parameters (same as generate()).

Returns:

list[dict]: Results with fields: key, text, timestamp, sentence_info, raw_text.

📄 Source
    def inference_with_vad(self, input, input_len=None, **cfg):
        """Run ASR with VAD segmentation, punctuation, and optional speaker diarization.

        Pipeline:
        1. VAD: Segment audio into speech regions
        2. ASR: Recognize each segment (sorted by length for efficient batching)
        3. Timestamp merge: Combine per-segment timestamps with VAD offsets
        4. Punctuation: Add punctuation to combined text (if punc_model configured)
        5. Speaker diarization: Cluster speaker embeddings and assign labels (if spk_model configured)

        Args:
            input: Audio file path, URL, or numpy array.
            input_len: Not used (kept for interface consistency).
            **cfg: Runtime parameters (same as generate()).

        Returns:
            list[dict]: Results with fields: key, text, timestamp, sentence_info, raw_text.
        """
        self._reset_runtime_configs()
        if self.spk_model is not None and "output_timestamp" not in cfg:
            cfg["output_timestamp"] = True
            cfg["return_time_stamps"] = True
        kwargs = self.kwargs
        # step.1: compute the vad model
        deep_update(self.vad_kwargs, cfg)
        beg_vad = time.time()
        res = self.inference(
            input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg
        )
        end_vad = time.time()
View full source →
.export(input, **cfg) L911

Export model to ONNX format.

Creates a deep copy of the model to isolate ONNX operator monkey-patching, then runs torch.onnx.export. The original model remains usable after export.

Args:

  • input — Sample input for tracing (auto-generated if None).
  • **cfg — Export parameters:
      ◦ type (str): Export format, "onnx" (default).
      ◦ quantize (bool): Whether to quantize the model.
      ◦ device (str): Device for export.

Returns:

  • str — Path to the exported model directory.
📄 Source
    def export(self, input=None, **cfg):
        """Export model to ONNX format.

        Creates a deep copy of the model to isolate ONNX operator monkey-patching,
        then runs torch.onnx.export. The original model remains usable after export.

        Args:
            input: Sample input for tracing (auto-generated if None).
            **cfg: Export parameters:
                - type (str): Export format, "onnx" (default).
                - quantize (bool): Whether to quantize the model.
                - device (str): Device for export.

        Returns:
            str: Path to the exported model directory.
        """
        """

        :param input:
        :param type:
        :param quantize:
        :param fallback_num:
        :param calib_num:
        :param opset_version:
        :param cfg:
        :return:
        """

        device = cfg.get("device", "cpu")
        
View full source →
method

AutoModel.build_model(**kwargs)

funasr.auto.auto_model.AutoModel · View on GitHub ↗

Download model from hub, build all components, and load pretrained weights.

This method handles the full model construction pipeline:

1. Download model files from ModelScope/HuggingFace (if not local)

2. Parse config.yaml to determine model class, tokenizer, frontend

3. Instantiate tokenizer, frontend, and model via the registry

4. Load pretrained weights from model.pt

Args:

  • **kwargs — Must include 'model' (str). All other config.yaml fields can be overridden.

Returns:

  • tuple — (model, kwargs), where model is the instantiated nn.Module and kwargs contains the resolved configuration.
📄 Source code
    def build_model(**kwargs):
        """Download model from hub, build all components, and load pretrained weights.

        This method handles the full model construction pipeline:
        1. Download model files from ModelScope/HuggingFace (if not local)
        2. Parse config.yaml to determine model class, tokenizer, frontend
        3. Instantiate tokenizer, frontend, and model via the registry
        4. Load pretrained weights from model.pt

        Args:
            **kwargs: Must include 'model' (str). All other config.yaml fields can be overridden.

        Returns:
            tuple: (model, kwargs) where model is the instantiated nn.Module and
                kwargs contains the resolved configuration.
        """
        assert "model" in kwargs
        if "model_conf" not in kwargs:
            logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
            kwargs = download_model(**kwargs)

        set_all_random_seed(kwargs.get("seed", 0))

        device = kwargs.get("device", "cuda")
        if (
            (device.startswith("cuda") and not torch.cuda.is_available())
            or (device.startswith("xpu") and not torch.xpu.is_available())
            or (device.startswith("mps") and not torch.backends.mps.is_available())
            or (device.startswith("npu") and not is_npu_available())
            or kwargs.get("ngpu", 1) == 0
View full source on GitHub →
method

AutoModel.generate(input, input_len, progress_callback, **cfg)

funasr.auto.auto_model.AutoModel · View on GitHub ↗

Run speech recognition on input audio.

This is the primary user-facing method. It automatically routes to:

  • inference() if no vad_model is configured (single utterance)
  • inference_with_vad() if vad_model is configured (long audio with segmentation)

Args:

  • input — Audio input. Accepts:
      ◦ File path (str): "audio.wav", "audio.mp3"
      ◦ URL (str): "https://..."
      ◦ numpy array: raw audio samples (float32, 16kHz)
      ◦ list: batch of file paths or arrays
      ◦ bytes: raw audio bytes
  • input_len (tensor, optional) — Length of each input sample.
  • progress_callback (callable, optional) — fn(current, total) called during processing.
  • **cfg — Runtime parameters:
      ◦ cache (dict): State cache for streaming mode. Pass {} for the first call.
      ◦ hotword (str/list): Keywords to boost recognition accuracy.
      ◦ language (str): Language hint ("auto", "zh", "en", "Chinese", etc.)
      ◦ batch_size_s (int): Dynamic batch total duration in seconds.
      ◦ is_final (bool): Last-chunk flag for streaming mode.
      ◦ return_spk_res (bool): Return speaker diarization results.
      ◦ sentence_timestamp (bool): Return sentence-level timestamps.
      ◦ use_itn (bool): Apply inverse text normalization (SenseVoice).

Returns:

list[dict]: Results for each input sample. Common fields:

  • "key" (str): Sample identifier
  • "text" (str): Recognized text
  • "timestamp" (list): [[start_ms, end_ms], ...] per character/word
  • "sentence_info" (list): [{text, start, end, spk, timestamp}, ...] when spk enabled
📄 Source code
    def generate(self, input, input_len=None, progress_callback=None, **cfg):
        """Run speech recognition on input audio.

        This is the primary user-facing method. It automatically routes to:
        - inference() if no vad_model is configured (single utterance)
        - inference_with_vad() if vad_model is configured (long audio with segmentation)

        Args:
            input: Audio input. Accepts:
                - File path (str): "audio.wav", "audio.mp3"
                - URL (str): "https://..."
                - numpy array: raw audio samples (float32, 16kHz)
                - list: batch of file paths or arrays
                - bytes: raw audio bytes
            input_len (tensor, optional): Length of each input sample.
            progress_callback (callable, optional): fn(current, total) called during processing.
            **cfg: Runtime parameters:
                - cache (dict): State cache for streaming mode. Pass {} for first call.
                - hotword (str/list): Keywords to boost recognition accuracy.
                - language (str): Language hint ("auto", "zh", "en", "Chinese", etc.)
                - batch_size_s (int): Dynamic batch total duration in seconds.
                - is_final (bool): Last chunk flag for streaming mode.
                - return_spk_res (bool): Return speaker diarization results.
                - sentence_timestamp (bool): Return sentence-level timestamps.
                - use_itn (bool): Apply inverse text normalization (SenseVoice).

        Returns:
            list[dict]: Results for each input sample. Common fields:
                - "key" (str): Sample identifier
                - "text" (str): Recognized text
View full source on GitHub →
method

AutoModel.inference(input, input_len, model, kwargs, key, progress_callback, **cfg)

funasr.auto.auto_model.AutoModel · View on GitHub ↗

Run model inference on input data (internal method).

Handles batching, timing, and progress reporting. Called by generate() and inference_with_vad(). Typically not called directly by users.

Args:

  • input — Audio data, file path, or text (for punc model).
  • input_len (tensor, optional) — Input lengths for batch.
  • model (nn.Module, optional) — Override model (used for VAD/PUNC/SPK sub-models).
  • kwargs (dict, optional) — Override kwargs (used for sub-model configs).
  • key (list, optional) — Sample identifiers.
  • progress_callback (callable, optional) — Progress reporting function.
  • **cfg — Additional config merged into kwargs.

Returns:

list[dict]: Model inference results.

📄 Source code
    def inference(
        self,
        input,
        input_len=None,
        model=None,
        kwargs=None,
        key=None,
        progress_callback=None,
        **cfg,
    ):
        """Run model inference on input data (internal method).

        Handles batching, timing, and progress reporting. Called by generate()
        and inference_with_vad(). Typically not called directly by users.

        Args:
            input: Audio data, file path, or text (for punc model).
            input_len (tensor, optional): Input lengths for batch.
            model (nn.Module, optional): Override model (used for VAD/PUNC/SPK sub-models).
            kwargs (dict, optional): Override kwargs (used for sub-model configs).
            key (list, optional): Sample identifiers.
            progress_callback (callable, optional): Progress reporting function.
            **cfg: Additional config merged into kwargs.

        Returns:
            list[dict]: Model inference results.
        """
        if kwargs is None:
            self._reset_runtime_configs()
        kwargs = self.kwargs if kwargs is None else kwargs
View full source on GitHub →
method

AutoModel.inference_with_vad(input, input_len, **cfg)

funasr.auto.auto_model.AutoModel · View on GitHub ↗

Run ASR with VAD segmentation, punctuation, and optional speaker diarization.

Pipeline:

1. VAD: Segment audio into speech regions

2. ASR: Recognize each segment (sorted by length for efficient batching)

3. Timestamp merge: Combine per-segment timestamps with VAD offsets

4. Punctuation: Add punctuation to combined text (if punc_model configured)

5. Speaker diarization: Cluster speaker embeddings and assign labels (if spk_model configured)

Args:

  • input — Audio file path, URL, or numpy array.
  • input_len — Not used (kept for interface consistency).
  • **cfg — Runtime parameters (same as generate()).

Returns:

list[dict]: Results with fields: key, text, timestamp, sentence_info, raw_text.

📄 Source code
    def inference_with_vad(self, input, input_len=None, **cfg):
        """Run ASR with VAD segmentation, punctuation, and optional speaker diarization.

        Pipeline:
        1. VAD: Segment audio into speech regions
        2. ASR: Recognize each segment (sorted by length for efficient batching)
        3. Timestamp merge: Combine per-segment timestamps with VAD offsets
        4. Punctuation: Add punctuation to combined text (if punc_model configured)
        5. Speaker diarization: Cluster speaker embeddings and assign labels (if spk_model configured)

        Args:
            input: Audio file path, URL, or numpy array.
            input_len: Not used (kept for interface consistency).
            **cfg: Runtime parameters (same as generate()).

        Returns:
            list[dict]: Results with fields: key, text, timestamp, sentence_info, raw_text.
        """
        self._reset_runtime_configs()
        if self.spk_model is not None and "output_timestamp" not in cfg:
            cfg["output_timestamp"] = True
            cfg["return_time_stamps"] = True
        kwargs = self.kwargs
        # step.1: compute the vad model
        deep_update(self.vad_kwargs, cfg)
        beg_vad = time.time()
        res = self.inference(
            input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg
        )
        end_vad = time.time()
View full source on GitHub →
method

AutoModel.export(input, **cfg)

funasr.auto.auto_model.AutoModel · View on GitHub ↗

Export model to ONNX format.

Creates a deep copy of the model to isolate ONNX operator monkey-patching, then runs torch.onnx.export. The original model remains usable after export.

Args:

  • input — Sample input for tracing (auto-generated if None).
  • **cfg — Export parameters:
      ◦ type (str): Export format, "onnx" (default).
      ◦ quantize (bool): Whether to quantize the model.
      ◦ device (str): Device for export.

Returns:

  • str — Path to the exported model directory.
📄 Source code
    def export(self, input=None, **cfg):
        """Export model to ONNX format.

        Creates a deep copy of the model to isolate ONNX operator monkey-patching,
        then runs torch.onnx.export. The original model remains usable after export.

        Args:
            input: Sample input for tracing (auto-generated if None).
            **cfg: Export parameters:
                - type (str): Export format, "onnx" (default).
                - quantize (bool): Whether to quantize the model.
                - device (str): Device for export.

        Returns:
            str: Path to the exported model directory.
        """
        """

        :param input:
        :param type:
        :param quantize:
        :param fallback_num:
        :param calib_num:
        :param opset_version:
        :param cfg:
        :return:
        """

        device = cfg.get("device", "cpu")
        
View full source on GitHub →
class

RegisterTables

funasr.register.register · View on GitHub ↗

Registry system for classes.

📄 Source code
class RegisterTables:
    """Registry system for classes."""

    model_classes = {}
    frontend_classes = {}
    specaug_classes = {}
    normalize_classes = {}
    encoder_classes = {}
    decoder_classes = {}
    joint_network_classes = {}
    predictor_classes = {}
    stride_conv_classes = {}
    tokenizer_classes = {}
    dataloader_classes = {}
    batch_sampler_classes = {}
    dataset_classes = {}
    index_ds_classes = {}

    def print(self, key: str = None) -> None:
        """Print registered classes."""
        print("\ntables: \n")
        fields = vars(self)
        headers = ["register name", "class name", "class location"]
        for classes_key, classes_dict in fields.items():
            if classes_key.endswith("_meta") and (key is None or key in classes_key):
                print(f"-----------    ** {classes_key.replace('_meta', '')} **    --------------")
                metas = []
                for register_key, meta in classes_dict.items():
                    metas.append(meta)
                metas.sort(key=lambda x: x[0])
View full source on GitHub →

Methods

.print(key) L26

Print registered classes.

📄 Source
    def print(self, key: str = None) -> None:
        """Print registered classes."""
        print("\ntables: \n")
        fields = vars(self)
        headers = ["register name", "class name", "class location"]
        for classes_key, classes_dict in fields.items():
            if classes_key.endswith("_meta") and (key is None or key in classes_key):
                print(f"-----------    ** {classes_key.replace('_meta', '')} **    --------------")
                metas = []
                for register_key, meta in classes_dict.items():
                    metas.append(meta)
                metas.sort(key=lambda x: x[0])
                data = [headers] + metas
                col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]

                for row in data:
                    print(
                        "| "
                        + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths))
                        + " |"
                    )
        print("\n")
.register(register_tables_key, key) L49

Decorator to register a class.

📄 Source
    def register(self, register_tables_key: str, key: str = None) -> callable:
        """Decorator to register a class."""

        def decorator(target_class):
            """Decorator.
            
                Args:
                    target_class: TODO.
                """
            if not hasattr(self, register_tables_key):
                setattr(self, register_tables_key, {})
                logging.debug(f"New registry table added: {register_tables_key}")

            registry = getattr(self, register_tables_key)
            registry_key = key if key is not None else target_class.__name__

            if registry_key in registry:
                logging.debug(
                    f"Key {registry_key} already exists in {register_tables_key}, re-register"
                )

            registry[registry_key] = target_class

            register_tables_key_meta = register_tables_key + "_meta"
            if not hasattr(self, register_tables_key_meta):
                setattr(self, register_tables_key_meta, {})
            registry_meta = getattr(self, register_tables_key_meta)

            class_file = inspect.getfile(target_class)
            class_line = inspect.getsourcelines(target_class)[1]
View full source →
method

RegisterTables.print(key)

funasr.register.register.RegisterTables · View on GitHub ↗

Print registered classes.

📄 Source code
    def print(self, key: str = None) -> None:
        """Print registered classes."""
        print("\ntables: \n")
        fields = vars(self)
        headers = ["register name", "class name", "class location"]
        for classes_key, classes_dict in fields.items():
            if classes_key.endswith("_meta") and (key is None or key in classes_key):
                print(f"-----------    ** {classes_key.replace('_meta', '')} **    --------------")
                metas = []
                for register_key, meta in classes_dict.items():
                    metas.append(meta)
                metas.sort(key=lambda x: x[0])
                data = [headers] + metas
                col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]

                for row in data:
                    print(
                        "| "
                        + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths))
                        + " |"
                    )
        print("\n")
method

RegisterTables.register(register_tables_key, key)

funasr.register.register.RegisterTables · View on GitHub ↗

Decorator to register a class.

📄 Source code
    def register(self, register_tables_key: str, key: str = None) -> callable:
        """Decorator to register a class."""

        def decorator(target_class):
            """Decorator.
            
                Args:
                    target_class: TODO.
                """
            if not hasattr(self, register_tables_key):
                setattr(self, register_tables_key, {})
                logging.debug(f"New registry table added: {register_tables_key}")

            registry = getattr(self, register_tables_key)
            registry_key = key if key is not None else target_class.__name__

            if registry_key in registry:
                logging.debug(
                    f"Key {registry_key} already exists in {register_tables_key}, re-register"
                )

            registry[registry_key] = target_class

            register_tables_key_meta = register_tables_key + "_meta"
            if not hasattr(self, register_tables_key_meta):
                setattr(self, register_tables_key_meta, {})
            registry_meta = getattr(self, register_tables_key_meta)

            class_file = inspect.getfile(target_class)
            class_line = inspect.getsourcelines(target_class)[1]
View full source on GitHub →
class

BAT

funasr.models.bat · View on GitHub ↗
BAT (Boundary-Aware Transducer): Low-latency RNN-T model with boundary detection.

Inherits from Transducer. Designed for streaming ASR with reduced latency by predicting token boundaries explicitly.

📄 Source code
class BAT(Transducer):
    """Boundary-Aware Transducer (BAT).

    A low-latency RNN-T variant with boundary detection, intended for streaming
    ASR by predicting token boundaries explicitly. This class defines no members
    of its own; all behavior is inherited unchanged from ``Transducer``.
    """
class

BiCifParaformer

funasr.models.bicif_paraformer · View on GitHub ↗
  • BiCifParaformer — Paraformer with Bidirectional CIF for Timestamp Prediction.

Extends Paraformer with a second CIF predictor that provides accurate character-level timestamp prediction alongside ASR. Uses bidirectional information flow for better alignment between audio frames and text tokens.

Reference:

  • FunASR: A Fundamental End-to-End Speech Recognition Toolkit (https://arxiv.org/abs/2305.11013)
  • Achieving timestamp prediction while recognizing with non-autoregressive end-to-end ASR model (https://arxiv.org/abs/2301.12343)

Output:

{"key": str, "text": str, "timestamp": [[start_ms, end_ms], ...]}

Author: Speech Lab of DAMO Academy, Alibaba Group
📄 Source code
class BiCifParaformer(Paraformer):
    """BiCifParaformer: Paraformer with Bidirectional CIF for Timestamp Prediction.

    Extends Paraformer with a second CIF predictor that provides accurate
    character-level timestamp prediction alongside ASR. Uses bidirectional
    information flow for better alignment between audio frames and text tokens.

    Reference:
        - FunASR: A Fundamental End-to-End Speech Recognition Toolkit (https://arxiv.org/abs/2305.11013)
        - Achieving timestamp prediction while recognizing with non-autoregressive end-to-end ASR model
          (https://arxiv.org/abs/2301.12343)

    Output:
        {"key": str, "text": str, "timestamp": [[start_ms, end_ms], ...]}

    Author: Speech Lab of DAMO Academy, Alibaba Group
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        """Initialize BiCifParaformer.
        
            Args:
                *args: Variable positional arguments.
                **kwargs: Additional keyword arguments.
            """
        super().__init__(*args, **kwargs)
View full source on GitHub →

Methods

.calc_predictor(encoder_out, encoder_out_lens) L162

Calc predictor.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
📄 Source
    def calc_predictor(self, encoder_out, encoder_out_lens):
        """Run the CIF predictor over encoder output without target tokens.

        Args:
            encoder_out: Encoder output tensor; ``encoder_out.size(1)`` is used as
                the padding-mask length.
            encoder_out_lens: Encoder output lengths.

        Returns:
            Tuple ``(pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index)``
            taken from ``self.predictor``; its fifth output is discarded.
        """
        pad_mask = make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))
        # Invert so non-padded frames are True, add a singleton axis at dim 1,
        # and move the mask to the encoder output's device.
        valid_mask = (~pad_mask[:, None, :]).to(encoder_out.device)
        predictor_out = self.predictor(encoder_out, None, valid_mask, ignore_id=self.ignore_id)
        acoustic_embeds, token_length, alphas, peak_index, _ = predictor_out
        return acoustic_embeds, token_length, alphas, peak_index
.calc_predictor_timestamp(encoder_out, encoder_out_lens, token_num) L177

Calc predictor timestamp.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • token_num — TODO.
📄 Source
    def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
        """Compute timestamp alphas and peaks via the predictor's upsampling path.

        Args:
            encoder_out: Encoder output tensor; ``encoder_out.size(1)`` is used as
                the padding-mask length.
            encoder_out_lens: Encoder output lengths.
            token_num: Forwarded unchanged to
                ``self.predictor.get_upsample_timestamp`` (presumably the
                per-utterance token count -- confirm against the predictor).

        Returns:
            Tuple ``(ds_alphas, ds_cif_peak, us_alphas, us_peaks)`` produced by
            ``self.predictor.get_upsample_timestamp``.
        """
        pad_mask = make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))
        # Invert so non-padded frames are True, add a singleton axis at dim 1,
        # and move the mask to the encoder output's device.
        valid_mask = (~pad_mask[:, None, :]).to(encoder_out.device)
        upsampled = self.predictor.get_upsample_timestamp(encoder_out, valid_mask, token_num)
        ds_alphas, ds_cif_peak, us_alphas, us_peaks = upsampled
        return ds_alphas, ds_cif_peak, us_alphas, us_peaks
.forward(speech, speech_lengths, text, text_lengths, **kwargs) L193

Frontend + Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()

        # decoder: CTC branch
View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L271

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        # init beamsearch
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = (
            kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        )
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

View full source →
.export(**kwargs) L430

Export.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def export(self, **kwargs):
        """Rebuild this model for export via ``export_meta.export_rebuild_model``.

        Args:
            **kwargs: Options forwarded to ``export_rebuild_model``.
                ``max_seq_len`` defaults to 512 when the caller omits it.

        Returns:
            Whatever ``export_rebuild_model`` returns.
        """
        from .export_meta import export_rebuild_model

        kwargs.setdefault("max_seq_len", 512)
        return export_rebuild_model(model=self, **kwargs)
method

BiCifParaformer.calc_predictor(encoder_out, encoder_out_lens)

funasr.models.bicif_paraformer.BiCifParaformer · View on GitHub ↗

Calc predictor.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
📄 Source code
    def calc_predictor(self, encoder_out, encoder_out_lens):
        """Run the CIF predictor over encoder output without target tokens.

        Args:
            encoder_out: Encoder output tensor; ``encoder_out.size(1)`` is used as
                the padding-mask length.
            encoder_out_lens: Encoder output lengths.

        Returns:
            Tuple ``(pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index)``
            taken from ``self.predictor``; its fifth output is discarded.
        """
        pad_mask = make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))
        # Invert so non-padded frames are True, add a singleton axis at dim 1,
        # and move the mask to the encoder output's device.
        valid_mask = (~pad_mask[:, None, :]).to(encoder_out.device)
        predictor_out = self.predictor(encoder_out, None, valid_mask, ignore_id=self.ignore_id)
        acoustic_embeds, token_length, alphas, peak_index, _ = predictor_out
        return acoustic_embeds, token_length, alphas, peak_index
method

BiCifParaformer.calc_predictor_timestamp(encoder_out, encoder_out_lens, token_num)

funasr.models.bicif_paraformer.BiCifParaformer · View on GitHub ↗

Calc predictor timestamp.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • token_num — TODO.
📄 Source code
    def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
        """Compute timestamp alphas and peaks via the predictor's upsampling path.

        Args:
            encoder_out: Encoder output tensor; ``encoder_out.size(1)`` is used as
                the padding-mask length.
            encoder_out_lens: Encoder output lengths.
            token_num: Forwarded unchanged to
                ``self.predictor.get_upsample_timestamp`` (presumably the
                per-utterance token count -- confirm against the predictor).

        Returns:
            Tuple ``(ds_alphas, ds_cif_peak, us_alphas, us_peaks)`` produced by
            ``self.predictor.get_upsample_timestamp``.
        """
        pad_mask = make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))
        # Invert so non-padded frames are True, add a singleton axis at dim 1,
        # and move the mask to the encoder output's device.
        valid_mask = (~pad_mask[:, None, :]).to(encoder_out.device)
        upsampled = self.predictor.get_upsample_timestamp(encoder_out, valid_mask, token_num)
        ds_alphas, ds_cif_peak, us_alphas, us_peaks = upsampled
        return ds_alphas, ds_cif_peak, us_alphas, us_peaks
method

BiCifParaformer.forward(speech, speech_lengths, text, text_lengths, **kwargs)

funasr.models.bicif_paraformer.BiCifParaformer · View on GitHub ↗

Frontend + Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()

        # decoder: CTC branch
View full source on GitHub →
method

BiCifParaformer.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.bicif_paraformer.BiCifParaformer · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        # init beamsearch
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = (
            kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        )
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

View full source on GitHub →
method

BiCifParaformer.export(**kwargs)

funasr.models.bicif_paraformer.BiCifParaformer · View on GitHub ↗

Export.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def export(self, **kwargs):
        """Rebuild this model for export via ``export_meta.export_rebuild_model``.

        Args:
            **kwargs: Options forwarded to ``export_rebuild_model``.
                ``max_seq_len`` defaults to 512 when the caller omits it.

        Returns:
            Whatever ``export_rebuild_model`` returns.
        """
        from .export_meta import export_rebuild_model

        kwargs.setdefault("max_seq_len", 512)
        return export_rebuild_model(model=self, **kwargs)
class

Branchformer

funasr.models.branchformer · View on GitHub ↗
  • Branchformer — Parallel branch encoder architecture.

Uses parallel branches of self-attention and convolution that are merged via concatenation. Alternative to Conformer with similar accuracy.

Inherits the Transformer pipeline for training and inference.

📄 Source code
class Branchformer(Transformer):
    """Branchformer: parallel-branch encoder architecture.

    Self-attention and convolution run in parallel branches whose outputs are
    merged by concatenation; offered as an alternative to Conformer. Training
    and inference come from ``Transformer``; this subclass only forwards its
    constructor arguments.
    """

    def __init__(self, *args, **kwargs):
        """Forward every constructor argument unchanged to ``Transformer.__init__``.

        Args:
            *args: Positional arguments passed through to ``Transformer``.
            **kwargs: Keyword arguments passed through to ``Transformer``.
        """
        super().__init__(*args, **kwargs)
class

CAMPPlus

funasr.models.campplus · View on GitHub ↗

CAM++ Speaker Verification Model.

Extracts fixed-dimensional speaker embeddings from variable-length audio. Used for speaker verification and speaker diarization pipelines.

Output: 192-dimensional speaker embedding per utterance.
📄 Source code
class CAMPPlus(torch.nn.Module):
    """CAM++ Speaker Verification Model.

    Extracts fixed-dimensional speaker embeddings from variable-length audio.
    Used for speaker verification and speaker diarization pipelines.

    Output: 192-dimensional speaker embedding per utterance.
    """

    def __init__(
        self,
        feat_dim=80,
        embedding_size=192,
        growth_rate=32,
        bn_size=4,
        init_channels=128,
        config_str="batchnorm-relu",
        memory_efficient=True,
        output_level="segment",
        **kwargs,
    ):
        """Initialize CAMPPlus.
        
            Args:
                feat_dim: Size/dimension parameter.
                embedding_size: Size/dimension parameter.
                growth_rate: TODO.
                bn_size: Size/dimension parameter.
                init_channels: TODO.
                config_str: TODO.
View full source on GitHub →

Methods

.forward(x) L141

Extract speaker embedding from fbank features.

Args:

  • x (Tensor) — Input fbank features, shape (batch, time, feat_dim).

Returns:

  • Tensor — Speaker embedding, shape (batch, embedding_size) for segment level, or (batch, time, channels) for frame level.
📄 Source
    def forward(self, x):
        """Extract a speaker embedding from fbank features.

        Args:
            x (Tensor): Input fbank features, shape (batch, time, feat_dim).

        Returns:
            Tensor: Output of ``self.xvector``. For segment level this is
                (batch, embedding_size). When ``self.output_level == "frame"``,
                axes 1 and 2 are swapped, giving (batch, time, channels).
        """
        # (B, T, F) -> (B, F, T) before the head network.
        features = x.permute(0, 2, 1)
        embedding = self.xvector(self.head(features))
        if self.output_level != "frame":
            return embedding
        return embedding.transpose(1, 2)
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L158

Run speaker embedding extraction on audio input.

Args:

  • data_in — Audio input (file path, numpy array, or list).
  • data_lengths — Not used.
  • key (list) — Sample identifiers.
  • tokenizer — Not used.
  • frontend — Not used.
  • **kwargs — Must include 'device' (str) and optional 'fs' (int, default 16000).

Returns:

  • tuple — (results, meta_data) where results is [{"spk_embedding": Tensor of shape (1, 192)}]
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run speaker embedding extraction on audio input.

        Args:
            data_in: Audio input (file path, numpy array, or list).
            data_lengths: Not used.
            key (list): Sample identifiers.
            tokenizer: Not used.
            frontend: Not used.
            **kwargs: Must include 'device' (str) and optional 'fs' (int, default 16000).

        Returns:
            tuple: (results, meta_data) where results is
                [{"spk_embedding": Tensor of shape (1, 192)}]
        """
        # extract fbank feats
        meta_data = {}
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(
            data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound"
        )
        time2 = time.perf_counter()
View full source →
method

CAMPPlus.forward(x)

funasr.models.campplus.CAMPPlus · View on GitHub ↗

Extract speaker embedding from fbank features.

Args:

  • x (Tensor) — Input fbank features, shape (batch, time, feat_dim).

Returns:

  • Tensor — Speaker embedding, shape (batch, embedding_size) for segment level, or (batch, time, channels) for frame level.
📄 Source code
    def forward(self, x):
        """Extract a speaker embedding from fbank features.

        Args:
            x (Tensor): Input fbank features, shape (batch, time, feat_dim).

        Returns:
            Tensor: Output of ``self.xvector``. For segment level this is
                (batch, embedding_size). When ``self.output_level == "frame"``,
                axes 1 and 2 are swapped, giving (batch, time, channels).
        """
        # (B, T, F) -> (B, F, T) before the head network.
        features = x.permute(0, 2, 1)
        embedding = self.xvector(self.head(features))
        if self.output_level != "frame":
            return embedding
        return embedding.transpose(1, 2)
method

CAMPPlus.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.campplus.CAMPPlus · View on GitHub ↗

Run speaker embedding extraction on audio input.

Args:

  • data_in — Audio input (file path, numpy array, or list).
  • data_lengths — Not used.
  • key (list) — Sample identifiers.
  • tokenizer — Not used.
  • frontend — Not used.
  • **kwargs — Must include 'device' (str) and optional 'fs' (int, default 16000).

Returns:

  • tuple — (results, meta_data) where results is [{"spk_embedding": Tensor of shape (1, 192)}]
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run speaker embedding extraction on audio input.

        Args:
            data_in: Audio input (file path, numpy array, or list).
            data_lengths: Not used.
            key (list): Sample identifiers.
            tokenizer: Not used.
            frontend: Not used.
            **kwargs: Must include 'device' (str) and optional 'fs' (int, default 16000).

        Returns:
            tuple: (results, meta_data) where results is
                [{"spk_embedding": Tensor of shape (1, 192)}]
        """
        # extract fbank feats
        meta_data = {}
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(
            data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound"
        )
        time2 = time.perf_counter()
View full source on GitHub →
class

Conformer

funasr.models.conformer · View on GitHub ↗
  • Conformer — CTC-attention hybrid encoder-decoder model.

Combines convolution and self-attention in the encoder for better local and global context modeling. Inherits the full Transformer pipeline (CTC + attention decoder + beam search).

Output: {"key": str, "text": str}
📄 Source code
class Conformer(Transformer):
    """Conformer: hybrid CTC/attention encoder-decoder model.

    The encoder combines convolution with self-attention to capture both local
    and global context. The rest of the pipeline (CTC branch, attention decoder,
    beam search) is inherited from ``Transformer``; this subclass only forwards
    its constructor arguments.

    Output: {"key": str, "text": str}
    """

    def __init__(self, *args, **kwargs):
        """Forward every constructor argument unchanged to ``Transformer.__init__``.

        Args:
            *args: Positional arguments passed through to ``Transformer``.
            **kwargs: Keyword arguments passed through to ``Transformer``.
        """
        super().__init__(*args, **kwargs)
class

Conformer

funasr.models.conformer_rwkv · View on GitHub ↗
  • CTC-attention hybrid encoder-decoder model.
📄 Source code
class Conformer(Transformer):
    """CTC-attention hybrid encoder-decoder model.

    Thin subclass of ``Transformer``: the constructor forwards all of its
    arguments unchanged, and no other members are defined here.
    """

    def __init__(self, *args, **kwargs):
        """Forward every constructor argument unchanged to ``Transformer.__init__``.

        Args:
            *args: Positional arguments passed through to ``Transformer``.
            **kwargs: Keyword arguments passed through to ``Transformer``.
        """
        super().__init__(*args, **kwargs)
class

ContextualParaformer

funasr.models.contextual_paraformer · View on GitHub ↗
  • ContextualParaformer — Paraformer with hotword/context biasing.

Extends Paraformer with a context encoder that incorporates user-defined hotwords/keywords to boost recognition of domain-specific terms.

Usage: pass hotwords via generate(hotword='term1 term2').
📄 Source code
class ContextualParaformer(Paraformer):
    """ContextualParaformer: Paraformer with hotword/context biasing.

    Extends Paraformer with a context encoder that incorporates user-defined
    hotwords/keywords to boost recognition of domain-specific terms.

    Usage: Pass hotwords via generate(hotword='term1 term2').
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        """Initialize ContextualParaformer.
        
            Args:
                *args: Variable positional arguments.
                **kwargs: Additional keyword arguments.
            """
        super().__init__(*args, **kwargs)

        self.target_buffer_length = kwargs.get("target_buffer_length", -1)
        inner_dim = kwargs.get("inner_dim", 256)
        bias_encoder_type = kwargs.get("bias_encoder_type", "lstm")
        use_decoder_embedding = kwargs.get("use_decoder_embedding", False)
        crit_attn_weight = kwargs.get("crit_attn_weight", 0.0)
        crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0)
        bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)

View full source on GitHub →

Methods

.forward(speech, speech_lengths, text, text_lengths, **kwargs) L96

Frontend + Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        text_lengths = text_lengths.squeeze()
        speech_lengths = speech_lengths.squeeze()

        batch_size = speech.shape[0]

        hotword_pad = kwargs.get("hotword_pad")
        hotword_lengths = kwargs.get("hotword_lengths")
        # dha_pad = kwargs.get("dha_pad")

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_ctc, cer_ctc = None, None

View full source →
.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info) L268

Sampler.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • ys_pad — TODO.
  • ys_pad_lens — Lengths of ys_pad.
  • pre_acoustic_embeds — TODO.
  • contextual_info — TODO.
📄 Source
    def sampler(
        self,
        encoder_out,
        encoder_out_lens,
        ys_pad,
        ys_pad_lens,
        pre_acoustic_embeds,
        contextual_info,
    ):
        """Sampler.
        
            Args:
                encoder_out: Encoder output tensor.
                encoder_out_lens: Encoder output lengths.
                ys_pad: TODO.
                ys_pad_lens: Lengths of ys_pad.
                pre_acoustic_embeds: TODO.
                contextual_info: TODO.
            """
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
            ys_pad.device
        )
        ys_pad = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad)
        with torch.no_grad():
            decoder_outs = self.decoder(
                encoder_out,
View full source →
.cal_decoder_with_predictor(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list, clas_scale) L331

Calc decoder with predictor.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • sematic_embeds — TODO.
  • ys_pad_lens — Lengths of ys_pad.
  • hw_list — TODO.
  • clas_scale — TODO.
📄 Source
    def cal_decoder_with_predictor(
        self,
        encoder_out,
        encoder_out_lens,
        sematic_embeds,
        ys_pad_lens,
        hw_list=None,
        clas_scale=1.0,
    ):
        """Cal decoder with predictor.
        
            Args:
                encoder_out: Encoder output tensor.
                encoder_out_lens: Encoder output lengths.
                sematic_embeds: TODO.
                ys_pad_lens: Lengths of ys_pad.
                hw_list: TODO.
                clas_scale: TODO.
            """
        if hw_list is None:
            hw_list = [torch.Tensor([1]).long().to(encoder_out.device)]  # empty hotword list
            hw_list_pad = pad_list(hw_list, 0)
            if self.use_decoder_embedding:
                hw_embed = self.decoder.embed(hw_list_pad)
            else:
                hw_embed = self.bias_embed(hw_list_pad)
            hw_embed, (h_n, _) = self.bias_encoder(hw_embed)
            hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
        else:
            hw_lengths = [len(i) for i in hw_list]
View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L387

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        # init beamsearch

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = (
            kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        )
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

View full source →
.generate_hotwords_list(hotword_list_or_file, tokenizer, frontend) L534

Generate hotwords list.

Args:

  • hotword_list_or_file — TODO.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
📄 Source
    def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
        """Generate hotwords list.
        
            Args:
                hotword_list_or_file: TODO.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
            """
        def load_seg_dict(seg_dict_file):
            """Load seg dict.
            
                Args:
                    seg_dict_file: TODO.
                """
            seg_dict = {}
            assert isinstance(seg_dict_file, str)
            with open(seg_dict_file, "r", encoding="utf8") as f:
                lines = f.readlines()
                for line in lines:
                    s = line.strip().split()
                    key = s[0]
                    value = s[1:]
                    seg_dict[key] = " ".join(value)
            return seg_dict

        def seg_tokenize(txt, seg_dict):
            """Seg tokenize.
            
                Args:
                    txt: TODO.
View full source →
.export(**kwargs) L659

Export.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def export(self, **kwargs):
        """Rebuild this model for export via ``export_meta.export_rebuild_model``.

        Args:
            **kwargs: Options forwarded to ``export_rebuild_model``.
                ``max_seq_len`` is set to 512 when the caller omits it.

        Returns:
            Whatever ``export_rebuild_model`` returns.
        """
        kwargs.setdefault("max_seq_len", 512)
        from .export_meta import export_rebuild_model

        return export_rebuild_model(model=self, **kwargs)
method

ContextualParaformer.forward(speech, speech_lengths, text, text_lengths, **kwargs)

funasr.models.contextual_paraformer.ContextualParaformer · View on GitHub ↗

Frontend + Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        text_lengths = text_lengths.squeeze()
        speech_lengths = speech_lengths.squeeze()

        batch_size = speech.shape[0]

        hotword_pad = kwargs.get("hotword_pad")
        hotword_lengths = kwargs.get("hotword_lengths")
        # dha_pad = kwargs.get("dha_pad")

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_ctc, cer_ctc = None, None

View full source on GitHub →
method

ContextualParaformer.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info)

funasr.models.contextual_paraformer.ContextualParaformer · View on GitHub ↗

Sampler.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • ys_pad — TODO.
  • ys_pad_lens — Lengths of ys_pad.
  • pre_acoustic_embeds — TODO.
  • contextual_info — TODO.
📄 Source code
    def sampler(
        self,
        encoder_out,
        encoder_out_lens,
        ys_pad,
        ys_pad_lens,
        pre_acoustic_embeds,
        contextual_info,
    ):
        """Sampler.
        
            Args:
                encoder_out: Encoder output tensor.
                encoder_out_lens: Encoder output lengths.
                ys_pad: TODO.
                ys_pad_lens: Lengths of ys_pad.
                pre_acoustic_embeds: TODO.
                contextual_info: TODO.
            """
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
            ys_pad.device
        )
        ys_pad = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad)
        with torch.no_grad():
            decoder_outs = self.decoder(
                encoder_out,
View full source on GitHub →
method

ContextualParaformer.cal_decoder_with_predictor(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list, clas_scale)

funasr.models.contextual_paraformer.ContextualParaformer · View on GitHub ↗

Cal decoder with predictor.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • sematic_embeds — TODO.
  • ys_pad_lens — Lengths of ys_pad.
  • hw_list — TODO.
  • clas_scale — TODO.
📄 Source code
    def cal_decoder_with_predictor(
        self,
        encoder_out,
        encoder_out_lens,
        sematic_embeds,
        ys_pad_lens,
        hw_list=None,
        clas_scale=1.0,
    ):
        """Cal decoder with predictor.
        
            Args:
                encoder_out: Encoder output tensor.
                encoder_out_lens: Encoder output lengths.
                sematic_embeds: TODO.
                ys_pad_lens: Lengths of ys_pad.
                hw_list: TODO.
                clas_scale: TODO.
            """
        if hw_list is None:
            hw_list = [torch.Tensor([1]).long().to(encoder_out.device)]  # empty hotword list
            hw_list_pad = pad_list(hw_list, 0)
            if self.use_decoder_embedding:
                hw_embed = self.decoder.embed(hw_list_pad)
            else:
                hw_embed = self.bias_embed(hw_list_pad)
            hw_embed, (h_n, _) = self.bias_encoder(hw_embed)
            hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
        else:
            hw_lengths = [len(i) for i in hw_list]
View full source on GitHub →
method

ContextualParaformer.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.contextual_paraformer.ContextualParaformer · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        # init beamsearch

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = (
            kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        )
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

View full source on GitHub →
method

ContextualParaformer.generate_hotwords_list(hotword_list_or_file, tokenizer, frontend)

funasr.models.contextual_paraformer.ContextualParaformer · View on GitHub ↗

Generate hotwords list.

Args:

  • hotword_list_or_file — TODO.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
📄 Source code
    def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
        """Generate hotwords list.
        
            Args:
                hotword_list_or_file: TODO.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
            """
        def load_seg_dict(seg_dict_file):
            """Load seg dict.
            
                Args:
                    seg_dict_file: TODO.
                """
            seg_dict = {}
            assert isinstance(seg_dict_file, str)
            with open(seg_dict_file, "r", encoding="utf8") as f:
                lines = f.readlines()
                for line in lines:
                    s = line.strip().split()
                    key = s[0]
                    value = s[1:]
                    seg_dict[key] = " ".join(value)
            return seg_dict

        def seg_tokenize(txt, seg_dict):
            """Seg tokenize.
            
                Args:
                    txt: TODO.
View full source on GitHub →
method

ContextualParaformer.export(**kwargs)

funasr.models.contextual_paraformer.ContextualParaformer · View on GitHub ↗

Export.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def export(
        self,
        **kwargs,
    ):
        """Rebuild this model into its export-ready form.

        ``max_seq_len`` defaults to 512 when the caller does not supply it;
        all keyword arguments are then forwarded to ``export_rebuild_model``
        from the sibling ``export_meta`` module.

        Args:
            **kwargs: Export options forwarded to ``export_rebuild_model``.

        Returns:
            Whatever ``export_rebuild_model`` returns for this model.
        """
        kwargs.setdefault("max_seq_len", 512)
        from .export_meta import export_rebuild_model

        return export_rebuild_model(model=self, **kwargs)
class

CTTransformer

funasr.models.ct_transformer · View on GitHub ↗
CT-Transformer: Punctuation Restoration Model.

Adds punctuation (comma, period, question mark) to unpunctuated text.

Supports Chinese and English. Used as punc_model in the ASR pipeline.

  • Output — {"key": "...", "text": "punctuated text", "punc_array": Tensor}

punc_array encoding: 1=none, 2=comma(,), 3=period(。), 4=question(?)

  • Note — Not needed for Fun-ASR-Nano/SenseVoice/Qwen3-ASR (they output punctuation natively).

Only required for Paraformer models.

  • Author — Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection

https://arxiv.org/pdf/2003.01309.pdf

📄 Source code
class CTTransformer(torch.nn.Module):
    """CT-Transformer: Punctuation Restoration Model.

    Adds punctuation (comma, period, question mark) to unpunctuated text.
    Supports Chinese and English. Used as punc_model in the ASR pipeline.

    Output: {"key": "...", "text": "punctuated text", "punc_array": Tensor}
    punc_array encoding: 1=none, 2=comma(,), 3=period(。), 4=question(?)

    Note: Not needed for Fun-ASR-Nano/SenseVoice/Qwen3-ASR (they output punctuation natively).
    Only required for Paraformer models.

    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf
    """

    def __init__(
        self,
        encoder: str = None,
        encoder_conf: dict = None,
        vocab_size: int = -1,
        punc_list: list = None,
        punc_weight: list = None,
        embed_unit: int = 128,
        att_unit: int = 256,
        dropout_rate: float = 0.5,
        ignore_id: int = -1,
        sos: int = 1,
        eos: int = 2,
View full source on GitHub →

Methods

.punc_forward(text, text_lengths, **kwargs) L113

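Run embedding, encoder, and decoder over a batch of token ids and return the decoder outputs (no loss is computed here).

Args:

  • text (torch.Tensor) — Input token ids. (batch, len)
  • text_lengths (torch.Tensor) — Valid length of each sequence. (batch,)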
📄 Source
    def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor, **kwargs):
        """Run embedding -> encoder -> decoder over a batch of token ids.

        No loss is computed here; only the decoder outputs are returned.

        Args:
            text (torch.Tensor): Input token ids. (batch, len)
            text_lengths (torch.Tensor): Valid length of each sequence. (batch,)
            **kwargs: Extra keyword arguments; ignored.

        Returns:
            Tuple[torch.Tensor, None]: the decoder outputs for every position,
            and ``None`` as a placeholder second element.
        """
        x = self.embed(text)
        h, _, _ = self.encoder(x, text_lengths)
        y = self.decoder(h)
        return y, None
.with_vad() L127

With vad.

📄 Source
    def with_vad(self):
        """Return whether this model consumes VAD indexes.

        Always ``False`` for the plain CT-Transformer; ``nll`` checks this flag
        before entering its VAD-specific branch.
        """
        uses_vad = False
        return uses_vad
.score(y, state, x) L131

Score new token.

Args:

  • y (torch.Tensor) — 1D torch.int64 prefix tokens.
  • state — Scorer state for prefix tokens
  • x (torch.Tensor) — encoder feature that generates ys.

Returns:

tuple[torch.Tensor, Any]: Tuple of torch.float32 scores for the next token (vocab_size) and the next state for ys.

📄 Source
    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        """Score the next token given a single prefix.

        The prefix is given a batch dimension, embedded, and run one step
        through the encoder with the cached ``state``. The last position's
        hidden vector is decoded and turned into log-probabilities.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state (encoder cache) for the prefix tokens.
            x (torch.Tensor): Encoder feature that generates ys; unused here.

        Returns:
            tuple[torch.Tensor, Any]: log-probabilities for the next token
            (vocab_size) and the updated encoder cache.
        """
        prefix = y.unsqueeze(0)  # add a batch dimension of size 1
        hidden, _, next_cache = self.encoder.forward_one_step(
            self.embed(prefix), self._target_mask(prefix), cache=state
        )
        last_logits = self.decoder(hidden[:, -1])
        return last_logits.log_softmax(dim=-1).squeeze(0), next_cache
.batch_score(ys, states, xs) L153

Score new token batch.

Args:

  • ys (torch.Tensor) — torch.int64 prefix tokens (n_batch, ylen).
  • states (List[Any]) — Scorer states for prefix tokens.

  • xs (torch.Tensor) — The encoder feature that generates ys (n_batch, xlen, n_feat).

Returns:

tuple[torch.Tensor, List[Any]]: Tuple of batched scores for the next token with shape `(n_batch, vocab_size)` and the next state list for ys.

📄 Source
    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, vocab_size)`
                and next state list for ys.

        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)
            ]

        # batch decoding
        h, _, states = self.encoder.forward_one_step(
View full source →
.nll(text, punc, text_lengths, punc_lengths, max_length, vad_indexes, vad_indexes_lengths) L192

Compute negative log likelihood(nll)

Normally, this function is called in batchify_nll.

Args:

  • text — (Batch, Length)
  • punc — (Batch, Length)
  • text_lengths — (Batch,)
  • max_length — int (optional; when omitted, sequences are truncated to text_lengths.max())
📄 Source
    def nll(
        self,
        text: torch.Tensor,
        punc: torch.Tensor,
        text_lengths: torch.Tensor,
        punc_lengths: torch.Tensor,
        max_length: Optional[int] = None,
        vad_indexes: Optional[torch.Tensor] = None,
        vad_indexes_lengths: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute negative log likelihood(nll)

        Normally, this function is called in batchify_nll.
        Args:
            text: (Batch, Length)
            punc: (Batch, Length)
            text_lengths: (Batch,)
            max_lengths: int
        """
        batch_size = text.size(0)
        # For data parallel
        if max_length is None:
            text = text[:, : text_lengths.max()]
            punc = punc[:, : text_lengths.max()]
        else:
            text = text[:, :max_length]
            punc = punc[:, :max_length]

        if self.with_vad():
            # Should be VadRealtimeTransformer
View full source →
.forward(text, punc, text_lengths, punc_lengths, vad_indexes, vad_indexes_lengths) L262

Forward pass for training.

Args:

  • text — Input token ids, (Batch, Length).
  • punc — TODO.
  • text_lengths — Length of each text sample.
  • punc_lengths — Lengths of punc.
  • vad_indexes — TODO.
  • vad_indexes_lengths — Lengths of vad_indexes.
📄 Source
    def forward(
        self,
        text: torch.Tensor,
        punc: torch.Tensor,
        text_lengths: torch.Tensor,
        punc_lengths: torch.Tensor,
        vad_indexes: Optional[torch.Tensor] = None,
        vad_indexes_lengths: Optional[torch.Tensor] = None,
    ):
        """Training step: token-averaged negative log-likelihood.

        Args:
            text: Input token ids, (Batch, Length).
            punc: Punctuation labels, (Batch, Length).
            text_lengths: Valid length of each text sequence, (Batch,).
            punc_lengths: Lengths of ``punc``; forwarded to ``nll``.
            vad_indexes: Optional VAD indexes forwarded to ``nll``.
            vad_indexes_lengths: Accepted for interface compatibility; not
                forwarded to ``nll`` here.

        Returns:
            (loss, stats, weight) from ``force_gatherable``. The loss is the
            summed NLL divided by the summed lengths returned by ``nll``.
            ``stats`` holds a detached copy of the loss under ``"loss"``, and
            ``weight`` is that token count.
        """
        token_nll, output_lengths = self.nll(
            text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes
        )
        num_tokens = output_lengths.sum()
        mean_loss = token_nll.sum() / num_tokens

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        gathered_loss, gathered_stats, weight = force_gatherable(
            (mean_loss, {"loss": mean_loss.detach()}, num_tokens), mean_loss.device
        )
        return gathered_loss, gathered_stats, weight
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L290

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        assert len(data_in) == 1
        if not data_in[0] or (isinstance(data_in[0], str) and not data_in[0].strip()):
            meta_data = {"batch_data_time": -1}
            return [{"key": key[0] if key else "", "text": "", "punc_array": None}], meta_data
        text = load_audio_text_image_video(data_in, data_type=kwargs.get("kwargs", "text"))[0]
        vad_indexes = kwargs.get("vad_indexes", None)
        # text = data_in[0]
        # text_lengths = data_lengths[0] if data_lengths is not None else None
        split_size = kwargs.get("split_size", 20)

        tokens = split_words(text, jieba_usr_dict=self.jieba_usr_dict)
View full source →
.export(**kwargs) L471

Export.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def export(self, **kwargs):
        """Rebuild this model into its export-ready form.

        All keyword arguments are forwarded unchanged to
        ``export_rebuild_model`` from the sibling ``export_meta`` module.

        Args:
            **kwargs: Export options forwarded to ``export_rebuild_model``.

        Returns:
            Whatever ``export_rebuild_model`` returns for this model.
        """
        from .export_meta import export_rebuild_model as rebuild_for_export

        return rebuild_for_export(model=self, **kwargs)
method

CTTransformer.punc_forward(text, text_lengths, **kwargs)

funasr.models.ct_transformer.CTTransformer · View on GitHub ↗

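Run embedding, encoder, and decoder over a batch of token ids and return the decoder outputs (no loss is computed here).

Args:

  • text (torch.Tensor) — Input token ids. (batch, len)
  • text_lengths (torch.Tensor) — Valid length of each sequence. (batch,)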
📄 Source code
    def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor, **kwargs):
        """Run embedding -> encoder -> decoder over a batch of token ids.

        No loss is computed here; only the decoder outputs are returned.

        Args:
            text (torch.Tensor): Input token ids. (batch, len)
            text_lengths (torch.Tensor): Valid length of each sequence. (batch,)
            **kwargs: Extra keyword arguments; ignored.

        Returns:
            Tuple[torch.Tensor, None]: the decoder outputs for every position,
            and ``None`` as a placeholder second element.
        """
        x = self.embed(text)
        h, _, _ = self.encoder(x, text_lengths)
        y = self.decoder(h)
        return y, None
method

CTTransformer.with_vad()

funasr.models.ct_transformer.CTTransformer · View on GitHub ↗

With vad.

📄 Source code
    def with_vad(self):
        """Return whether this model consumes VAD indexes.

        Always ``False`` for the plain CT-Transformer; ``nll`` checks this flag
        before entering its VAD-specific branch.
        """
        uses_vad = False
        return uses_vad
method

CTTransformer.score(y, state, x)

funasr.models.ct_transformer.CTTransformer · View on GitHub ↗

Score new token.

Args:

  • y (torch.Tensor) — 1D torch.int64 prefix tokens.
  • state — Scorer state for prefix tokens
  • x (torch.Tensor) — encoder feature that generates ys.

Returns:

tuple[torch.Tensor, Any]: Tuple of torch.float32 scores for the next token (vocab_size) and the next state for ys.

📄 Source code
    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        """Score the next token given a single prefix.

        The prefix is given a batch dimension, embedded, and run one step
        through the encoder with the cached ``state``. The last position's
        hidden vector is decoded and turned into log-probabilities.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state (encoder cache) for the prefix tokens.
            x (torch.Tensor): Encoder feature that generates ys; unused here.

        Returns:
            tuple[torch.Tensor, Any]: log-probabilities for the next token
            (vocab_size) and the updated encoder cache.
        """
        prefix = y.unsqueeze(0)  # add a batch dimension of size 1
        hidden, _, next_cache = self.encoder.forward_one_step(
            self.embed(prefix), self._target_mask(prefix), cache=state
        )
        last_logits = self.decoder(hidden[:, -1])
        return last_logits.log_softmax(dim=-1).squeeze(0), next_cache
method

CTTransformer.batch_score(ys, states, xs)

funasr.models.ct_transformer.CTTransformer · View on GitHub ↗

Score new token batch.

Args:

  • ys (torch.Tensor) — torch.int64 prefix tokens (n_batch, ylen).
  • states (List[Any]) — Scorer states for prefix tokens.

  • xs (torch.Tensor) — The encoder feature that generates ys (n_batch, xlen, n_feat).

Returns:

tuple[torch.Tensor, List[Any]]: Tuple of batched scores for the next token with shape `(n_batch, vocab_size)` and the next state list for ys.

📄 Source code
    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, vocab_size)`
                and next state list for ys.

        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)
            ]

        # batch decoding
        h, _, states = self.encoder.forward_one_step(
View full source on GitHub →
method

CTTransformer.nll(text, punc, text_lengths, punc_lengths, max_length, vad_indexes, vad_indexes_lengths)

funasr.models.ct_transformer.CTTransformer · View on GitHub ↗

Compute negative log likelihood(nll)

Normally, this function is called in batchify_nll.

Args:

  • text — (Batch, Length)
  • punc — (Batch, Length)
  • text_lengths — (Batch,)
  • max_length — int (optional; when omitted, sequences are truncated to text_lengths.max())
📄 Source code
    def nll(
        self,
        text: torch.Tensor,
        punc: torch.Tensor,
        text_lengths: torch.Tensor,
        punc_lengths: torch.Tensor,
        max_length: Optional[int] = None,
        vad_indexes: Optional[torch.Tensor] = None,
        vad_indexes_lengths: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute negative log likelihood(nll)

        Normally, this function is called in batchify_nll.
        Args:
            text: (Batch, Length)
            punc: (Batch, Length)
            text_lengths: (Batch,)
            max_lengths: int
        """
        batch_size = text.size(0)
        # For data parallel
        if max_length is None:
            text = text[:, : text_lengths.max()]
            punc = punc[:, : text_lengths.max()]
        else:
            text = text[:, :max_length]
            punc = punc[:, :max_length]

        if self.with_vad():
            # Should be VadRealtimeTransformer
View full source on GitHub →
method

CTTransformer.forward(text, punc, text_lengths, punc_lengths, vad_indexes, vad_indexes_lengths)

funasr.models.ct_transformer.CTTransformer · View on GitHub ↗

Forward pass for training.

Args:

  • text — Input token ids, (Batch, Length).
  • punc — TODO.
  • text_lengths — Length of each text sample.
  • punc_lengths — Lengths of punc.
  • vad_indexes — TODO.
  • vad_indexes_lengths — Lengths of vad_indexes.
📄 Source code
    def forward(
        self,
        text: torch.Tensor,
        punc: torch.Tensor,
        text_lengths: torch.Tensor,
        punc_lengths: torch.Tensor,
        vad_indexes: Optional[torch.Tensor] = None,
        vad_indexes_lengths: Optional[torch.Tensor] = None,
    ):
        """Training step: token-averaged negative log-likelihood.

        Args:
            text: Input token ids, (Batch, Length).
            punc: Punctuation labels, (Batch, Length).
            text_lengths: Valid length of each text sequence, (Batch,).
            punc_lengths: Lengths of ``punc``; forwarded to ``nll``.
            vad_indexes: Optional VAD indexes forwarded to ``nll``.
            vad_indexes_lengths: Accepted for interface compatibility; not
                forwarded to ``nll`` here.

        Returns:
            (loss, stats, weight) from ``force_gatherable``. The loss is the
            summed NLL divided by the summed lengths returned by ``nll``.
            ``stats`` holds a detached copy of the loss under ``"loss"``, and
            ``weight`` is that token count.
        """
        token_nll, output_lengths = self.nll(
            text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes
        )
        num_tokens = output_lengths.sum()
        mean_loss = token_nll.sum() / num_tokens

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        gathered_loss, gathered_stats, weight = force_gatherable(
            (mean_loss, {"loss": mean_loss.detach()}, num_tokens), mean_loss.device
        )
        return gathered_loss, gathered_stats, weight
method

CTTransformer.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.ct_transformer.CTTransformer · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        assert len(data_in) == 1
        if not data_in[0] or (isinstance(data_in[0], str) and not data_in[0].strip()):
            meta_data = {"batch_data_time": -1}
            return [{"key": key[0] if key else "", "text": "", "punc_array": None}], meta_data
        text = load_audio_text_image_video(data_in, data_type=kwargs.get("kwargs", "text"))[0]
        vad_indexes = kwargs.get("vad_indexes", None)
        # text = data_in[0]
        # text_lengths = data_lengths[0] if data_lengths is not None else None
        split_size = kwargs.get("split_size", 20)

        tokens = split_words(text, jieba_usr_dict=self.jieba_usr_dict)
View full source on GitHub →
method

CTTransformer.export(**kwargs)

funasr.models.ct_transformer.CTTransformer · View on GitHub ↗

Export.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def export(self, **kwargs):
        """Rebuild this model into its export-ready form.

        All keyword arguments are forwarded unchanged to
        ``export_rebuild_model`` from the sibling ``export_meta`` module.

        Args:
            **kwargs: Export options forwarded to ``export_rebuild_model``.

        Returns:
            Whatever ``export_rebuild_model`` returns for this model.
        """
        from .export_meta import export_rebuild_model as rebuild_for_export

        return rebuild_for_export(model=self, **kwargs)
class

CTTransformerStreaming

funasr.models.ct_transformer_streaming · View on GitHub ↗
CT-Transformer Streaming: Online punctuation restoration.

Processes text incrementally with a sliding window, maintaining a cache of previous context for consistent punctuation decisions across chunks.

Used as punc_model in streaming ASR pipelines.

Supports VAD-aware punctuation: uses VAD boundaries to improve sentence segmentation.

  • Reference — https://arxiv.org/pdf/2003.01309.pdf
  • Output — {"key": str, "text": str, "punc_array": Tensor}
  • Author — Speech Lab of DAMO Academy, Alibaba Group
📄 Source code
class CTTransformerStreaming(CTTransformer):
    """CT-Transformer Streaming: Online punctuation restoration.

    Processes text incrementally with a sliding window, maintaining cache
    of previous context for consistent punctuation decisions across chunks.
    Used as punc_model in streaming ASR pipelines.

    Supports VAD-aware punctuation: uses VAD boundaries to improve sentence segmentation.

    Reference: https://arxiv.org/pdf/2003.01309.pdf
    Output: {"key": str, "text": str, "punc_array": Tensor}

    Author: Speech Lab of DAMO Academy, Alibaba Group
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        """Create the streaming punctuation model.

        No streaming-specific state is set up here; every positional and
        keyword argument is handed unchanged to ``CTTransformer.__init__``.

        Args:
            *args: Positional arguments for ``CTTransformer.__init__``.
            **kwargs: Keyword arguments for ``CTTransformer.__init__``.
        """
        super(CTTransformerStreaming, self).__init__(*args, **kwargs)

    def punc_forward(
        self, text: torch.Tensor, text_lengths: torch.Tensor, vad_indexes: torch.Tensor, **kwargs
View full source on GitHub →

Methods

.punc_forward(text, text_lengths, vad_indexes, **kwargs) L61

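Run embedding, VAD-aware encoder, and decoder over a batch of token ids and return the decoder outputs (no loss is computed here).

Args:

  • text (torch.Tensor) — Input token ids. (batch, len)
  • text_lengths (torch.Tensor) — Valid length of each sequence. (batch,)
  • vad_indexes (torch.Tensor) — VAD indexes passed through to the encoder.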
📄 Source
    def punc_forward(
        self, text: torch.Tensor, text_lengths: torch.Tensor, vad_indexes: torch.Tensor, **kwargs
    ):
        """Run embedding -> VAD-aware encoder -> decoder over a batch of token ids.

        No loss is computed here; only the decoder outputs are returned.

        Args:
            text (torch.Tensor): Input token ids. (batch, len)
            text_lengths (torch.Tensor): Valid length of each sequence. (batch,)
            vad_indexes (torch.Tensor): VAD indexes forwarded to the encoder's
                ``vad_indexes`` keyword (format defined by the encoder).
            **kwargs: Extra keyword arguments; ignored.

        Returns:
            Tuple[torch.Tensor, None]: the decoder outputs for every position,
            and ``None`` as a placeholder second element.
        """
        x = self.embed(text)
        h, _, _ = self.encoder(x, text_lengths, vad_indexes=vad_indexes)
        y = self.decoder(h)
        return y, None
.with_vad() L77

With vad.

📄 Source
    def with_vad(self):
        """Return whether this model consumes VAD indexes.

        Always ``True`` for the streaming variant, whose ``punc_forward``
        passes ``vad_indexes`` into the encoder.
        """
        uses_vad = True
        return uses_vad
.inference(data_in, data_lengths, key, tokenizer, frontend, cache, **kwargs) L81

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        cache: dict = None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        assert len(data_in) == 1

        if len(cache) == 0:
            cache["pre_text"] = []
        text = load_audio_text_image_video(data_in, data_type=kwargs.get("kwargs", "text"))[0]
        text = "".join(cache["pre_text"]) + " " + text

View full source →
.export(**kwargs) L221

Export.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def export(self, **kwargs):
        """Rebuild this model into its export-ready form.

        All keyword arguments are forwarded unchanged to
        ``export_rebuild_model`` from the sibling ``export_meta`` module.

        Args:
            **kwargs: Export options forwarded to ``export_rebuild_model``.

        Returns:
            Whatever ``export_rebuild_model`` returns for this model.
        """
        from .export_meta import export_rebuild_model as rebuild_for_export

        return rebuild_for_export(model=self, **kwargs)
method

CTTransformerStreaming.punc_forward(text, text_lengths, vad_indexes, **kwargs)

funasr.models.ct_transformer_streaming.CTTransformerStreaming · View on GitHub ↗

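Run embedding, VAD-aware encoder, and decoder over a batch of token ids and return the decoder outputs (no loss is computed here).

Args:

  • text (torch.Tensor) — Input token ids. (batch, len)
  • text_lengths (torch.Tensor) — Valid length of each sequence. (batch,)
  • vad_indexes (torch.Tensor) — VAD indexes passed through to the encoder.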
📄 Source code
    def punc_forward(
        self, text: torch.Tensor, text_lengths: torch.Tensor, vad_indexes: torch.Tensor, **kwargs
    ):
        """Run embedding -> VAD-aware encoder -> decoder over a batch of token ids.

        No loss is computed here; only the decoder outputs are returned.

        Args:
            text (torch.Tensor): Input token ids. (batch, len)
            text_lengths (torch.Tensor): Valid length of each sequence. (batch,)
            vad_indexes (torch.Tensor): VAD indexes forwarded to the encoder's
                ``vad_indexes`` keyword (format defined by the encoder).
            **kwargs: Extra keyword arguments; ignored.

        Returns:
            Tuple[torch.Tensor, None]: the decoder outputs for every position,
            and ``None`` as a placeholder second element.
        """
        x = self.embed(text)
        h, _, _ = self.encoder(x, text_lengths, vad_indexes=vad_indexes)
        y = self.decoder(h)
        return y, None
method

CTTransformerStreaming.with_vad()

funasr.models.ct_transformer_streaming.CTTransformerStreaming · View on GitHub ↗

With vad.

📄 Source code
    def with_vad(self):
        """Return whether this model consumes VAD indexes.

        Always ``True`` for the streaming variant, whose ``punc_forward``
        passes ``vad_indexes`` into the encoder.
        """
        uses_vad = True
        return uses_vad
method

CTTransformerStreaming.inference(data_in, data_lengths, key, tokenizer, frontend, cache, **kwargs)

funasr.models.ct_transformer_streaming.CTTransformerStreaming · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        cache: dict = None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        assert len(data_in) == 1

        if len(cache) == 0:
            cache["pre_text"] = []
        text = load_audio_text_image_video(data_in, data_type=kwargs.get("kwargs", "text"))[0]
        text = "".join(cache["pre_text"]) + " " + text

View full source on GitHub →
method

CTTransformerStreaming.export(**kwargs)

funasr.models.ct_transformer_streaming.CTTransformerStreaming · View on GitHub ↗

Export.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def export(self, **kwargs):
        """Rebuild this model into its export-ready form.

        All keyword arguments are forwarded unchanged to
        ``export_rebuild_model`` from the sibling ``export_meta`` module.

        Args:
            **kwargs: Export options forwarded to ``export_rebuild_model``.

        Returns:
            Whatever ``export_rebuild_model`` returns for this model.
        """
        from .export_meta import export_rebuild_model as rebuild_for_export

        return rebuild_for_export(model=self, **kwargs)
class

Transformer

funasr.models.ctc · View on GitHub ↗
CTC-attention hybrid Encoder-Decoder model
📄 Source code
class Transformer(nn.Module):
    """CTC-attention hybrid Encoder-Decoder model"""

    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        ctc_conf: dict = None,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        length_normalized_loss: bool = False,
        **kwargs,
    ):

        """Initialize Transformer.
        
            Args:
                specaug: TODO.
                specaug_conf: Configuration dict for specaug.
                normalize: TODO.
                normalize_conf: Configuration dict for normalize.
                encoder: TODO.
View full source on GitHub →

Methods

.forward(speech, speech_lengths, text, text_lengths, **kwargs) L89

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        # import pdb;
        # pdb.set_trace()
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_ctc, cer_ctc = None, None
        stats = dict()

View full source →
.encode(speech, speech_lengths, **kwargs) L134

Frontend + Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • ind — int
📄 Source
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply optional augmentation and normalization, then run the encoder.

        Note that this method is used by asr_inference.py.

        Args:
            speech: Input features, (Batch, Length, ...).
            speech_lengths: Valid length of each sample, (Batch,).
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            ``(encoder_out, encoder_out_lens)`` as unpacked from ``self.encoder``.
        """
        feats, feats_lens = speech, speech_lengths

        # Data augmentation is applied only in training mode.
        augment = self.specaug
        if augment is not None and self.training:
            feats, feats_lens = augment(feats, feats_lens)

        # Feature normalization, e.g. Global-CMVN or Utterance-CMVN.
        norm = self.normalize
        if norm is not None:
            feats, feats_lens = norm(feats, feats_lens)

        # feats: (Batch, Length, Dim) -> out: (Batch, Length2, Dim2)
        out, out_lens = self.encoder(feats, feats_lens)
        return out, out_lens
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L189

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
View full source →
method

Transformer.forward(speech, speech_lengths, text, text_lengths, **kwargs)

funasr.models.ctc.Transformer · View on GitHub ↗

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        # import pdb;
        # pdb.set_trace()
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_ctc, cer_ctc = None, None
        stats = dict()

View full source on GitHub →
method

Transformer.encode(speech, speech_lengths, **kwargs)

funasr.models.ctc.Transformer · View on GitHub ↗

Frontend + Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • ind — int
📄 Source code
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply optional augmentation and normalization, then run the encoder.

        Note that this method is used by asr_inference.py.

        Args:
            speech: Input features, (Batch, Length, ...).
            speech_lengths: Valid length of each sample, (Batch,).
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            ``(encoder_out, encoder_out_lens)`` as unpacked from ``self.encoder``.
        """
        feats, feats_lens = speech, speech_lengths

        # Data augmentation is applied only in training mode.
        augment = self.specaug
        if augment is not None and self.training:
            feats, feats_lens = augment(feats, feats_lens)

        # Feature normalization, e.g. Global-CMVN or Utterance-CMVN.
        norm = self.normalize
        if norm is not None:
            feats, feats_lens = norm(feats, feats_lens)

        # feats: (Batch, Length, Dim) -> out: (Batch, Length2, Dim2)
        out, out_lens = self.encoder(feats, feats_lens)
        return out, out_lens
method

Transformer.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.ctc.Transformer · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
View full source on GitHub →
class

EBranchformer

funasr.models.e_branchformer · View on GitHub ↗
  • E-Branchformer: Enhanced Branchformer with improved merging.

Uses element-wise merging instead of concatenation for parallel branches,

resulting in better parameter efficiency.

Inherits Transformer pipeline.

📄 Source code
class EBranchformer(Transformer):
    """E-Branchformer ASR model.

    The original description is an enhanced Branchformer that merges parallel
    branches element-wise instead of concatenating them, which gives better
    parameter efficiency. This class adds nothing on top of ``Transformer``:
    all constructor arguments are forwarded verbatim and the full pipeline is
    inherited. The branch-merging behaviour presumably comes from the
    configured encoder; confirm this against the config.
    """

    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument to ``Transformer.__init__``.

        Args:
            *args: Positional arguments for ``Transformer``.
            **kwargs: Keyword arguments for ``Transformer``.
        """
        super(EBranchformer, self).__init__(*args, **kwargs)
class

EParaformer

funasr.models.e_paraformer · View on GitHub ↗
  • E-Paraformer: Enhanced Paraformer with streaming support.

Extended Paraformer supporting both offline and streaming modes

through dynamic masking in the encoder. Used for 2-pass decoding

where first pass provides streaming results and second pass refines.

  • Output — {"key": str, "text": str, "timestamp": [[start_ms, end_ms], ...]}
  • Author — Speech Lab of DAMO Academy, Alibaba Group
  • Paraformer — Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition

https://arxiv.org/abs/2206.08317

  • Author — Kun Zou, chinazoukun@gmail.com
  • E-Paraformer: A Faster and Better Parallel Transformer for Non-autoregressive End-to-End Mandarin Speech Recognition

https://www.isca-archive.org/interspeech_2024/zou24_interspeech.pdf

📄 Source code
class EParaformer(torch.nn.Module):
    """E-Paraformer: Enhanced Paraformer with streaming support.

    Extended Paraformer supporting both offline and streaming modes
    through dynamic masking in the encoder. Used for 2-pass decoding
    where first pass provides streaming results and second pass refines.

    Output: {"key": str, "text": str, "timestamp": [[start_ms, end_ms], ...]}

    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    Author: Kun Zou, chinazoukun@gmail.com
    E-Paraformer: A Faster and Better Parallel Transformer for Non-autoregressive End-to-End Mandarin Speech Recognition
    https://www.isca-archive.org/interspeech_2024/zou24_interspeech.pdf
    """

    def __init__(
        self,
        specaug: Optional[str] = None,
        specaug_conf: Optional[Dict] = None,
        normalize: str = None,
        normalize_conf: Optional[Dict] = None,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        decoder: str = None,
        decoder_conf: Optional[Dict] = None,
        ctc: str = None,
        ctc_conf: Optional[Dict] = None,
        predictor: str = None,
View full source on GitHub →

Methods

.forward(speech, speech_lengths, text, text_lengths, **kwargs) L220

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()

        # decoder: CTC branch
View full source →
.encode(speech, speech_lengths, **kwargs) L292

Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • ind — int
📄 Source
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Preprocess features and run the encoder.

        Note that this method is used by asr_inference.py.

        Args:
            speech: Input features, (Batch, Length, ...).
            speech_lengths: Valid length of each sample, (Batch,).
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            ``(encoder_out, encoder_out_lens)``. If the encoder's first output is
            a tuple, only its first element is returned.
        """
        feats, feats_lens = speech, speech_lengths

        # Augmentation and normalization run inside ``autocast(False)``, which
        # presumably disables mixed precision for this stage. Confirm against
        # the autocast import.
        with autocast(False):
            augment = self.specaug
            if augment is not None and self.training:
                feats, feats_lens = augment(feats, feats_lens)

            norm = self.normalize
            if norm is not None:
                feats, feats_lens = norm(feats, feats_lens)

        # The encoder returns three values; the third is discarded here.
        out, out_lens, _ = self.encoder(feats, feats_lens)
        # Keep only the first element when the encoder wraps its output in a tuple.
        if isinstance(out, tuple):
            out = out[0]
        return out, out_lens
.calc_predictor(encoder_out, encoder_out_lens) L321

Run the predictor on the encoder output (with a padding mask derived from the encoder output lengths).

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
📄 Source
    def calc_predictor(self, encoder_out, encoder_out_lens):
        """Run the predictor on the encoder output.

        Args:
            encoder_out: Encoder output. Dimension 1 sets the padding-mask length.
            encoder_out_lens: Valid length of each sequence in ``encoder_out``.

        Returns:
            ``(pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index)``,
            unpacked from ``self.predictor``.
        """
        max_frames = encoder_out.size(1)
        # Padding mask with a singleton axis inserted at position 1.
        padded = make_pad_mask(encoder_out_lens, maxlen=max_frames)[:, None, :]
        # Invert it so True marks valid frames, and move it to the encoder's device.
        valid = (~padded).to(encoder_out.device)
        acoustic_embeds, token_length, alphas, peak_index = self.predictor(
            encoder_out, None, valid, ignore_id=self.ignore_id
        )
        return acoustic_embeds, token_length, alphas, peak_index
.cal_decoder_with_predictor(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens) L337

Run the decoder on the predictor embeddings and return log-probabilities together with the target lengths.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • sematic_embeds — TODO.
  • ys_pad_lens — Lengths of ys_pad.
📄 Source
    def cal_decoder_with_predictor(
        self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
    ):
        """Decode the predictor embeddings and return log-probabilities.

        Args:
            encoder_out: Encoder output tensor.
            encoder_out_lens: Encoder output lengths.
            sematic_embeds: Embeddings passed to the decoder as its third
                argument. The spelling is kept for API compatibility.
            ys_pad_lens: Target lengths. They are passed to the decoder and
                returned unchanged.

        Returns:
            ``(log_probs, ys_pad_lens)``, where ``log_probs`` is the log-softmax
            over the last dimension of the decoder's first output.
        """
        logits = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)[0]
        return torch.log_softmax(logits, dim=-1), ys_pad_lens
.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds) L425

Sampler.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • ys_pad — TODO.
  • ys_pad_lens — Lengths of ys_pad.
  • pre_acoustic_embeds — TODO.
📄 Source
    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):

        """Sampler.
        
            Args:
                encoder_out: Encoder output tensor.
                encoder_out_lens: Encoder output lengths.
                ys_pad: TODO.
                ys_pad_lens: Lengths of ys_pad.
                pre_acoustic_embeds: TODO.
            """
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
            ys_pad.device
        )
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)
        with torch.no_grad():
            decoder_outs = self.decoder(
                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
            )
            decoder_out, _ = decoder_outs[0], decoder_outs[1]
            pred_tokens = decoder_out.argmax(-1)
            nonpad_positions = ys_pad.ne(self.ignore_id)
            seq_lens = (nonpad_positions).sum(1)
            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
            input_mask = torch.ones_like(nonpad_positions)
            bsz, seq_len = ys_pad.size()
View full source →
.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds) L474

Sampler with grad.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • ys_pad — TODO.
  • ys_pad_lens — Lengths of ys_pad.
  • pre_acoustic_embeds — TODO.
📄 Source
    def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):

        """Sampler with grad.
        
            Args:
                encoder_out: Encoder output tensor.
                encoder_out_lens: Encoder output lengths.
                ys_pad: TODO.
                ys_pad_lens: Lengths of ys_pad.
                pre_acoustic_embeds: TODO.
            """
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
            ys_pad.device
        )
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]
        pred_tokens = decoder_out.argmax(-1)
        nonpad_positions = ys_pad.ne(self.ignore_id)
        seq_lens = (nonpad_positions).sum(1)
        same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
        input_mask = torch.ones_like(nonpad_positions)
        bsz, seq_len = ys_pad.size()
        for li in range(bsz):
View full source →
.init_beam_search(**kwargs) L548

Init beam search.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def init_beam_search(
        self,
        **kwargs,
    ):
        """Init beam search.
        
            Args:
                **kwargs: Additional keyword arguments.
            """
        from funasr.models.paraformer.search import BeamSearchPara
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        # 1. Build ASR model
        scorers = {}

        if self.ctc != None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            length_bonus=LengthBonus(len(token_list)),
        )

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        weights = dict(
View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L600

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        # init beamsearch
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = (
            kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        )
        pred_timestamp = kwargs.get("pred_timestamp", False)
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

View full source →
.export(**kwargs) L765

Export.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def export(self, **kwargs):
        """Build the export-ready model(s) for this instance.

        Args:
            **kwargs: Options forwarded to ``export_rebuild_model``.
                ``max_seq_len`` defaults to 512 when the caller omits it.

        Returns:
            Whatever ``export_rebuild_model`` returns for ``model=self``.
        """
        from .export_meta import export_rebuild_model

        kwargs.setdefault("max_seq_len", 512)
        return export_rebuild_model(model=self, **kwargs)
method

EParaformer.forward(speech, speech_lengths, text, text_lengths, **kwargs)

funasr.models.e_paraformer.EParaformer · View on GitHub ↗

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()

        # decoder: CTC branch
View full source on GitHub →
method

EParaformer.encode(speech, speech_lengths, **kwargs)

funasr.models.e_paraformer.EParaformer · View on GitHub ↗

Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • ind — int
📄 Source code
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Preprocess features and run the encoder.

        Note that this method is used by asr_inference.py.

        Args:
            speech: Input features, (Batch, Length, ...).
            speech_lengths: Valid length of each sample, (Batch,).
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            ``(encoder_out, encoder_out_lens)``. If the encoder's first output is
            a tuple, only its first element is returned.
        """
        feats, feats_lens = speech, speech_lengths

        # Augmentation and normalization run inside ``autocast(False)``, which
        # presumably disables mixed precision for this stage. Confirm against
        # the autocast import.
        with autocast(False):
            augment = self.specaug
            if augment is not None and self.training:
                feats, feats_lens = augment(feats, feats_lens)

            norm = self.normalize
            if norm is not None:
                feats, feats_lens = norm(feats, feats_lens)

        # The encoder returns three values; the third is discarded here.
        out, out_lens, _ = self.encoder(feats, feats_lens)
        # Keep only the first element when the encoder wraps its output in a tuple.
        if isinstance(out, tuple):
            out = out[0]
        return out, out_lens
method

EParaformer.calc_predictor(encoder_out, encoder_out_lens)

funasr.models.e_paraformer.EParaformer · View on GitHub ↗

Run the predictor on the encoder output (with a padding mask derived from the encoder output lengths).

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
📄 Source code
    def calc_predictor(self, encoder_out, encoder_out_lens):
        """Run the predictor on the encoder output.

        Args:
            encoder_out: Encoder output. Dimension 1 sets the padding-mask length.
            encoder_out_lens: Valid length of each sequence in ``encoder_out``.

        Returns:
            ``(pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index)``,
            unpacked from ``self.predictor``.
        """
        max_frames = encoder_out.size(1)
        # Padding mask with a singleton axis inserted at position 1.
        padded = make_pad_mask(encoder_out_lens, maxlen=max_frames)[:, None, :]
        # Invert it so True marks valid frames, and move it to the encoder's device.
        valid = (~padded).to(encoder_out.device)
        acoustic_embeds, token_length, alphas, peak_index = self.predictor(
            encoder_out, None, valid, ignore_id=self.ignore_id
        )
        return acoustic_embeds, token_length, alphas, peak_index
method

EParaformer.cal_decoder_with_predictor(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)

funasr.models.e_paraformer.EParaformer · View on GitHub ↗

Run the decoder on the predictor embeddings and return log-probabilities together with the target lengths.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • sematic_embeds — TODO.
  • ys_pad_lens — Lengths of ys_pad.
📄 Source code
    def cal_decoder_with_predictor(
        self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
    ):
        """Decode the predictor embeddings and return log-probabilities.

        Args:
            encoder_out: Encoder output tensor.
            encoder_out_lens: Encoder output lengths.
            sematic_embeds: Embeddings passed to the decoder as its third
                argument. The spelling is kept for API compatibility.
            ys_pad_lens: Target lengths. They are passed to the decoder and
                returned unchanged.

        Returns:
            ``(log_probs, ys_pad_lens)``, where ``log_probs`` is the log-softmax
            over the last dimension of the decoder's first output.
        """
        logits = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)[0]
        return torch.log_softmax(logits, dim=-1), ys_pad_lens
method

EParaformer.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds)

funasr.models.e_paraformer.EParaformer · View on GitHub ↗

Sampler.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • ys_pad — TODO.
  • ys_pad_lens — Lengths of ys_pad.
  • pre_acoustic_embeds — TODO.
📄 Source code
    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):

        """Sampler.
        
            Args:
                encoder_out: Encoder output tensor.
                encoder_out_lens: Encoder output lengths.
                ys_pad: TODO.
                ys_pad_lens: Lengths of ys_pad.
                pre_acoustic_embeds: TODO.
            """
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
            ys_pad.device
        )
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)
        with torch.no_grad():
            decoder_outs = self.decoder(
                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
            )
            decoder_out, _ = decoder_outs[0], decoder_outs[1]
            pred_tokens = decoder_out.argmax(-1)
            nonpad_positions = ys_pad.ne(self.ignore_id)
            seq_lens = (nonpad_positions).sum(1)
            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
            input_mask = torch.ones_like(nonpad_positions)
            bsz, seq_len = ys_pad.size()
View full source on GitHub →
method

EParaformer.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds)

funasr.models.e_paraformer.EParaformer · View on GitHub ↗

Sampler with grad.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • ys_pad — TODO.
  • ys_pad_lens — Lengths of ys_pad.
  • pre_acoustic_embeds — TODO.
📄 Source code
    def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):

        """Sampler with grad.
        
            Args:
                encoder_out: Encoder output tensor.
                encoder_out_lens: Encoder output lengths.
                ys_pad: TODO.
                ys_pad_lens: Lengths of ys_pad.
                pre_acoustic_embeds: TODO.
            """
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
            ys_pad.device
        )
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]
        pred_tokens = decoder_out.argmax(-1)
        nonpad_positions = ys_pad.ne(self.ignore_id)
        seq_lens = (nonpad_positions).sum(1)
        same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
        input_mask = torch.ones_like(nonpad_positions)
        bsz, seq_len = ys_pad.size()
        for li in range(bsz):
View full source on GitHub →
method

EParaformer.init_beam_search(**kwargs)

funasr.models.e_paraformer.EParaformer · View on GitHub ↗

Init beam search.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def init_beam_search(
        self,
        **kwargs,
    ):
        """Init beam search.
        
            Args:
                **kwargs: Additional keyword arguments.
            """
        from funasr.models.paraformer.search import BeamSearchPara
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        # 1. Build ASR model
        scorers = {}

        if self.ctc != None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            length_bonus=LengthBonus(len(token_list)),
        )

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        weights = dict(
View full source on GitHub →
method

EParaformer.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.e_paraformer.EParaformer · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        # init beamsearch
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = (
            kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        )
        pred_timestamp = kwargs.get("pred_timestamp", False)
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

View full source on GitHub →
method

EParaformer.export(**kwargs)

funasr.models.e_paraformer.EParaformer · View on GitHub ↗

Export.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def export(self, **kwargs):
        """Build the export-ready model(s) for this instance.

        Args:
            **kwargs: Options forwarded to ``export_rebuild_model``.
                ``max_seq_len`` defaults to 512 when the caller omits it.

        Returns:
            Whatever ``export_rebuild_model`` returns for ``model=self``.
        """
        from .export_meta import export_rebuild_model

        kwargs.setdefault("max_seq_len", 512)
        return export_rebuild_model(model=self, **kwargs)
class

Emotion2vec

funasr.models.emotion2vec · View on GitHub ↗
  • Author — Ziyang Ma, Zhisheng Zheng, Jiaxin Ye, Jinchao Li, Zhifu Gao, Shiliang Zhang, Xie Chen
  • emotion2vec — Self-Supervised Pre-Training for Speech Emotion Representation

https://arxiv.org/abs/2312.15185

📄 Source code
class Emotion2vec(torch.nn.Module):
    """
    Author: Ziyang Ma, Zhisheng Zheng, Jiaxin Ye, Jinchao Li, Zhifu Gao, Shiliang Zhang, Xie Chen
    emotion2vec: Self-Supervised Pre-Training for Speech Emotion Representation
    https://arxiv.org/abs/2312.15185
    """

    def __init__(self, **kwargs):
        """Initialize Emotion2vec.
        
            Args:
                **kwargs: Additional keyword arguments.
            """
        super().__init__()
        # import pdb; pdb.set_trace()
        cfg = OmegaConf.create(kwargs["model_conf"])
        self.cfg = cfg

        make_layer_norm = partial(
            torch.nn.LayerNorm, eps=cfg.get("norm_eps"), elementwise_affine=cfg.get("norm_affine")
        )

        def make_block(drop_path, dim=None, heads=None):
            """Make block.
            
                Args:
                    drop_path: TODO.
                    dim: TODO.
                    heads: TODO.
                """
View full source on GitHub →

Methods

.forward(source, target, id, mode, padding_mask, mask, features_only, force_remove_masked, remove_extra_tokens, precomputed_mask, **kwargs) L121

Forward pass for training.

Args:

  • source — TODO.
  • target — TODO.
  • id — TODO.
  • mode — TODO.
  • padding_mask — TODO.
  • mask — TODO.
  • features_only — TODO.
  • force_remove_masked — TODO.
  • remove_extra_tokens — TODO.
  • precomputed_mask — TODO.
  • **kwargs — Additional keyword arguments.
📄 Source
    def forward(
        self,
        source,
        target=None,
        id=None,
        mode=None,
        padding_mask=None,
        mask=True,
        features_only=False,
        force_remove_masked=False,
        remove_extra_tokens=True,
        precomputed_mask=None,
        **kwargs,
    ):

        """Forward pass for training.
        
            Args:
                source: TODO.
                target: TODO.
                id: TODO.
                mode: TODO.
                padding_mask: TODO.
                mask: TODO.
                features_only: TODO.
                force_remove_masked: TODO.
                remove_extra_tokens: TODO.
                precomputed_mask: TODO.
                **kwargs: Additional keyword arguments.
            """
View full source →
.extract_features(source, mode, padding_mask, mask, remove_extra_tokens) L212

Extract features: run forward() with features_only=True and return its output.

Args:

  • source — TODO.
  • mode — TODO.
  • padding_mask — TODO.
  • mask — TODO.
  • remove_extra_tokens — TODO.
📄 Source
    def extract_features(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True
    ):
        """Run ``forward`` in features-only mode and return its output unchanged.

        Args:
            source: Model input, passed straight through to ``forward``.
            mode: Forwarded to ``forward`` as ``mode``.
            padding_mask: Forwarded to ``forward`` as ``padding_mask``.
            mask: Forwarded to ``forward`` as ``mask``. Defaults to False here,
                whereas ``forward`` itself defaults ``mask`` to True.
            remove_extra_tokens: Forwarded to ``forward`` as ``remove_extra_tokens``.

        Returns:
            The result of ``self.forward(source, ..., features_only=True)``.
        """
        forward_kwargs = dict(
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )
        return self.forward(source, **forward_kwargs)
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L234

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        # if source_file.endswith('.wav'):
        #     wav, sr = sf.read(source_file)
        #     channel = sf.info(source_file).channels
        #     assert sr == 16e3, "Sample rate should be 16kHz, but got {}in file {}".format(sr, source_file)
        #     assert channel == 1, "Channel should be 1, but got {} in file {}".format(channel, source_file)
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        granularity = kwargs.get("granularity", "utterance")
        extract_embedding = kwargs.get("extract_embedding", True)
        if self.proj is None:
            extract_embedding = True
        meta_data = {}
View full source →
.export(**kwargs) L320

Export the model via export_rebuild_model, forwarding all keyword arguments unchanged.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def export(self, **kwargs):
        """Build the export-ready model(s) for this Emotion2vec instance.

        Args:
            **kwargs: Options forwarded unchanged to ``export_rebuild_model``.

        Returns:
            Whatever ``export_rebuild_model`` returns for this model.
        """
        from .export_meta import export_rebuild_model

        return export_rebuild_model(model=self, **kwargs)
method

Emotion2vec.forward(source, target, id, mode, padding_mask, mask, features_only, force_remove_masked, remove_extra_tokens, precomputed_mask, **kwargs)

funasr.models.emotion2vec.Emotion2vec · View on GitHub ↗

Forward pass for training.

Args:

  • source — TODO.
  • target — TODO.
  • id — TODO.
  • mode — TODO.
  • padding_mask — TODO.
  • mask — TODO.
  • features_only — TODO.
  • force_remove_masked — TODO.
  • remove_extra_tokens — TODO.
  • precomputed_mask — TODO.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def forward(
        self,
        source,
        target=None,
        id=None,
        mode=None,
        padding_mask=None,
        mask=True,
        features_only=False,
        force_remove_masked=False,
        remove_extra_tokens=True,
        precomputed_mask=None,
        **kwargs,
    ):

        """Forward pass for training.
        
            Args:
                source: TODO.
                target: TODO.
                id: TODO.
                mode: TODO.
                padding_mask: TODO.
                mask: TODO.
                features_only: TODO.
                force_remove_masked: TODO.
                remove_extra_tokens: TODO.
                precomputed_mask: TODO.
                **kwargs: Additional keyword arguments.
            """
View full source on GitHub →
method

Emotion2vec.extract_features(source, mode, padding_mask, mask, remove_extra_tokens)

funasr.models.emotion2vec.Emotion2vec · View on GitHub ↗

Extract features: run forward() with features_only=True and return its output.

Args:

  • source — TODO.
  • mode — TODO.
  • padding_mask — TODO.
  • mask — TODO.
  • remove_extra_tokens — TODO.
📄 Source code
    def extract_features(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True
    ):
        """Run ``forward`` in features-only mode and return its output unchanged.

        Args:
            source: Model input, passed straight through to ``forward``.
            mode: Forwarded to ``forward`` as ``mode``.
            padding_mask: Forwarded to ``forward`` as ``padding_mask``.
            mask: Forwarded to ``forward`` as ``mask``. Defaults to False here,
                whereas ``forward`` itself defaults ``mask`` to True.
            remove_extra_tokens: Forwarded to ``forward`` as ``remove_extra_tokens``.

        Returns:
            The result of ``self.forward(source, ..., features_only=True)``.
        """
        forward_kwargs = dict(
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )
        return self.forward(source, **forward_kwargs)
method

Emotion2vec.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.emotion2vec.Emotion2vec · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        # if source_file.endswith('.wav'):
        #     wav, sr = sf.read(source_file)
        #     channel = sf.info(source_file).channels
        #     assert sr == 16e3, "Sample rate should be 16kHz, but got {}in file {}".format(sr, source_file)
        #     assert channel == 1, "Channel should be 1, but got {} in file {}".format(channel, source_file)
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        granularity = kwargs.get("granularity", "utterance")
        extract_embedding = kwargs.get("extract_embedding", True)
        if self.proj is None:
            extract_embedding = True
        meta_data = {}
View full source on GitHub →
method

Emotion2vec.export(**kwargs)

funasr.models.emotion2vec.Emotion2vec · View on GitHub ↗

Export the model via export_rebuild_model, forwarding all keyword arguments unchanged.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def export(self, **kwargs):
        """Build the export-ready model(s) for this Emotion2vec instance.

        Args:
            **kwargs: Options forwarded unchanged to ``export_rebuild_model``.

        Returns:
            Whatever ``export_rebuild_model`` returns for this model.
        """
        from .export_meta import export_rebuild_model

        return export_rebuild_model(model=self, **kwargs)
class

ERes2NetV2SV

funasr.models.eres2net · View on GitHub ↗
ERes2NetV2: Enhanced Res2Net v2 for Speaker Verification.

Improved speaker embedding model based on the Res2Net architecture with
multi-scale feature aggregation. Provides 192-dim speaker embeddings
for speaker verification and diarization.

Better than CAM++ for short-duration audio (< 3s) speaker feature extraction.

Output: {"spk_embedding": Tensor of shape (1, 192)}
📄 Source code
class ERes2NetV2SV(torch.nn.Module):
    """ERes2NetV2: Enhanced Res2Net v2 for Speaker Verification.

    Improved speaker embedding model based on Res2Net architecture with
    multi-scale feature aggregation. Provides 192-dim speaker embeddings
    for speaker verification and diarization.

    Better than CAM++ for short-duration audio (< 3s) speaker feature extraction.

    Output: {"spk_embedding": Tensor of shape (1, 192)}
    """

    def __init__(
        self,
        feat_dim=80,
        embedding_size=192,
        m_channels=64,
        baseWidth=26,
        scale=2,
        expansion=2,
        num_blocks=[3, 4, 6, 3],
        pooling_func="TSTP",
        two_emb_layer=False,
        **kwargs,
    ):
        """Initialize ERes2NetV2SV.
        
            Args:
                feat_dim: Size/dimension parameter.
                embedding_size: Size/dimension parameter.
View full source on GitHub →

Methods

.forward(x) L97

Forward pass: returns self.model(x).

Args:

  • x — Input passed unchanged to self.model.
📄 Source
    def forward(self, x):
        """Run the wrapped network on ``x``.

        Args:
            x: Input passed unchanged to ``self.model``.

        Returns:
            The output of ``self.model(x)``.
        """
        network = self.model
        return network(x)
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L105

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        meta_data = {}
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(
            data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound"
        )
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths, speech_times = extract_feature(audio_sample_list)
        speech = speech.to(device=kwargs["device"])
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
View full source →
method

ERes2NetV2SV.forward(x)

funasr.models.eres2net.ERes2NetV2SV · View on GitHub ↗

Forward pass: returns self.model(x).

Args:

  • x — Input passed unchanged to self.model.
📄 Source code
    def forward(self, x):
        """Run the wrapped network on ``x``.

        Args:
            x: Input passed unchanged to ``self.model``.

        Returns:
            The output of ``self.model(x)``.
        """
        network = self.model
        return network(x)
method

ERes2NetV2SV.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.eres2net.ERes2NetV2SV · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        meta_data = {}
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(
            data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound"
        )
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths, speech_times = extract_feature(audio_sample_list)
        speech = speech.to(device=kwargs["device"])
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
View full source on GitHub →
class

FsmnKWS

funasr.models.fsmn_kws · View on GitHub ↗
FSMN-KWS: Keyword Spotting model using the FSMN architecture.

Detects predefined keywords/wake words in audio streams.

Supports both offline and streaming operation.

Output: {"key": str, "value": detected_keyword_info}
📄 Source code
class FsmnKWS(torch.nn.Module):
    """FSMN-KWS: Keyword Spotting model using FSMN architecture.

    Detects predefined keywords/wake words in audio streams.
    Supports both offline and streaming operation.

    Output: {"key": str, "value": detected_keyword_info}
    """

    def __init__(
        self,
        specaug: Optional[str] = None,
        specaug_conf: Optional[Dict] = None,
        normalize: str = None,
        normalize_conf: Optional[Dict] = None,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        ctc: str = None,
        ctc_conf: Optional[Dict] = None,
        ctc_weight: float = 1.0,
        input_size: int = 360,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        **kwargs,
    ):
        """Initialize FsmnKWS.
        
            Args:
                specaug: TODO.
View full source on GitHub →

Methods

.forward(speech, speech_lengths, text, text_lengths, **kwargs) L104

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        batch_size = speech.shape[0]

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        loss_ctc, cer_ctc = self._calc_ctc_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

        # Collect CTC branch stats
        stats = dict()
        stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
View full source →
.encode(speech, speech_lengths, **kwargs) L146

Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • **kwargs — Additional keyword arguments (accepted but unused).
📄 Source
    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encoder. Note that this method is used by asr_inference.py

        Augmentation (``self.specaug``) runs only while ``self.training`` is
        true. Normalization (``self.normalize``, e.g. Global-CMVN or
        Utterance-CMVN) runs whenever it is configured. Both run with autocast
        disabled. The encoder itself runs under the caller's autocast setting.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            ``(encoder_out, encoder_out_lens)``. If the encoder returns a tuple,
            only its first element is kept. ``encoder_out_lens`` is the
            (possibly augmented/normalized) ``speech_lengths`` passed through
            unchanged, which assumes the encoder preserves the time dimension.
        """
        with autocast(False):
            # Data augmentation (training mode only)
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder; it receives the features only, not the lengths.
        encoder_out = self.encoder(speech)
        encoder_out_lens = speech_lengths

        # Some encoders return (output, ...) tuples; keep only the output.
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        return encoder_out, encoder_out_lens
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L199

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list=None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        keywords = kwargs.get("keywords")
        from funasr.utils.kws_utils import KwsCtcPrefixDecoder
        self.kws_decoder = KwsCtcPrefixDecoder(
            ctc=self.ctc,
            keywords=keywords,
            token_list=tokenizer.token_list,
            seg_dict=tokenizer.seg_dict,
        )

        meta_data = {}
        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
View full source →
method

FsmnKWS.forward(speech, speech_lengths, text, text_lengths, **kwargs)

funasr.models.fsmn_kws.FsmnKWS · View on GitHub ↗

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        batch_size = speech.shape[0]

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        loss_ctc, cer_ctc = self._calc_ctc_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

        # Collect CTC branch stats
        stats = dict()
        stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
View full source on GitHub →
method

FsmnKWS.encode(speech, speech_lengths, **kwargs)

funasr.models.fsmn_kws.FsmnKWS · View on GitHub ↗

Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • **kwargs — Additional keyword arguments (accepted but unused).
📄 Source code
    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encoder. Note that this method is used by asr_inference.py

        Augmentation (``self.specaug``) runs only while ``self.training`` is
        true. Normalization (``self.normalize``, e.g. Global-CMVN or
        Utterance-CMVN) runs whenever it is configured. Both run with autocast
        disabled. The encoder itself runs under the caller's autocast setting.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            ``(encoder_out, encoder_out_lens)``. If the encoder returns a tuple,
            only its first element is kept. ``encoder_out_lens`` is the
            (possibly augmented/normalized) ``speech_lengths`` passed through
            unchanged, which assumes the encoder preserves the time dimension.
        """
        with autocast(False):
            # Data augmentation (training mode only)
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder; it receives the features only, not the lengths.
        encoder_out = self.encoder(speech)
        encoder_out_lens = speech_lengths

        # Some encoders return (output, ...) tuples; keep only the output.
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        return encoder_out, encoder_out_lens
method

FsmnKWS.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.fsmn_kws.FsmnKWS · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list=None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        keywords = kwargs.get("keywords")
        from funasr.utils.kws_utils import KwsCtcPrefixDecoder
        self.kws_decoder = KwsCtcPrefixDecoder(
            ctc=self.ctc,
            keywords=keywords,
            token_list=tokenizer.token_list,
            seg_dict=tokenizer.seg_dict,
        )

        meta_data = {}
        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
View full source on GitHub →
class

FsmnKWSConvert

funasr.models.fsmn_kws · View on GitHub ↗
Author: Speech Lab of DAMO Academy, Alibaba Group

Deep-FSMN for Large Vocabulary Continuous Speech Recognition

https://arxiv.org/abs/1803.05030

📄 Source code
class FsmnKWSConvert(torch.nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(
        self,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        ctc: str = None,
        ctc_conf: Optional[Dict] = None,
        ctc_weight: float = 1.0,
        input_size: int = 360,
        vocab_size: int = -1,
        blank_id: int = 0,
        **kwargs,
    ):
        """Initialize FsmnKWSConvert.
        
            Args:
                encoder: TODO.
                encoder_conf: Configuration dict for encoder.
                ctc: TODO.
                ctc_conf: Configuration dict for ctc.
                ctc_weight: TODO.
                input_size: Size/dimension parameter.
                vocab_size: Size/dimension parameter.
                blank_id: TODO.
View full source on GitHub →

Methods

.to_kaldi_net() L331

Kaldi-format export: returns the result of self.encoder.to_kaldi_net().

📄 Source
    def to_kaldi_net(self):
        """Return the encoder's Kaldi-format network.

        Delegates to ``self.encoder.to_kaldi_net()`` and returns its result
        unchanged. (Body re-indented with spaces only; it previously mixed tabs
        and spaces.)
        """
        return self.encoder.to_kaldi_net()
.to_pytorch_net(kaldi_file) L336

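PyTorch-format conversion: returns the result of self.encoder.to_pytorch_net(kaldi_file).

Args:

  • kaldi_file — Kaldi-format input, forwarded unchanged to self.encoder.to_pytorch_net().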
📄 Source
    def to_pytorch_net(self, kaldi_file):
        """Delegate Kaldi-to-PyTorch conversion to the encoder.

        Body re-indented with spaces only; it previously mixed tabs and spaces.

        Args:
            kaldi_file: Kaldi-format input, forwarded unchanged to
                ``self.encoder.to_pytorch_net()``.

        Returns:
            Whatever ``self.encoder.to_pytorch_net(kaldi_file)`` returns.
        """
        return self.encoder.to_pytorch_net(kaldi_file)
method

FsmnKWSConvert.to_kaldi_net()

funasr.models.fsmn_kws.FsmnKWSConvert · View on GitHub ↗

Kaldi-format export: returns the result of self.encoder.to_kaldi_net().

📄 Source code
    def to_kaldi_net(self):
        """Return the encoder's Kaldi-format network.

        Delegates to ``self.encoder.to_kaldi_net()`` and returns its result
        unchanged. (Body re-indented with spaces only; it previously mixed tabs
        and spaces.)
        """
        return self.encoder.to_kaldi_net()
method

FsmnKWSConvert.to_pytorch_net(kaldi_file)

funasr.models.fsmn_kws.FsmnKWSConvert · View on GitHub ↗

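PyTorch-format conversion: returns the result of self.encoder.to_pytorch_net(kaldi_file).

Args:

  • kaldi_file — Kaldi-format input, forwarded unchanged to self.encoder.to_pytorch_net().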
📄 Source code
    def to_pytorch_net(self, kaldi_file):
        """Delegate Kaldi-to-PyTorch conversion to the encoder.

        Body re-indented with spaces only; it previously mixed tabs and spaces.

        Args:
            kaldi_file: Kaldi-format input, forwarded unchanged to
                ``self.encoder.to_pytorch_net()``.

        Returns:
            Whatever ``self.encoder.to_pytorch_net(kaldi_file)`` returns.
        """
        return self.encoder.to_pytorch_net(kaldi_file)
class

FsmnKWSMT

funasr.models.fsmn_kws_mt · View on GitHub ↗
FSMN-KWS-MT: Multi-Task FSMN Keyword Spotting.

Keyword spotting with multi-task learning: simultaneously
detects keywords and performs filler token classification.
Improves keyword detection robustness through auxiliary tasks.

Output: {"key": str, "value": keyword_detection_result}

Author: Speech Lab of DAMO Academy, Alibaba Group

Deep-FSMN for Large Vocabulary Continuous Speech Recognition

https://arxiv.org/abs/1803.05030

📄 Source code
class FsmnKWSMT(torch.nn.Module):
    """FSMN-KWS-MT: Multi-Task FSMN Keyword Spotting.

    Keyword spotting with multi-task learning: simultaneously
    detects keywords and performs filler token classification.
    Improves keyword detection robustness through auxiliary tasks.

    Output: {"key": str, "value": keyword_detection_result}

    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(
        self,
        specaug: Optional[str] = None,
        specaug_conf: Optional[Dict] = None,
        normalize: str = None,
        normalize_conf: Optional[Dict] = None,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        ctc_conf: Optional[Dict] = None,
        input_size: int = 360,
        vocab_size: list = [],
        ignore_id: int = -1,
        blank_id: int = 0,
        **kwargs,
    ):
        """Initialize FsmnKWSMT.
View full source on GitHub →

Methods

.forward(speech, speech_lengths, text, text_lengths, text2, text2_lengths, **kwargs) L106

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
  • text2 — (Batch, Length)
  • text2_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        text2: torch.Tensor,
        text2_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
                text2: (Batch, Length)
                text2_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        batch_size = speech.shape[0]

        # Encoder
        encoder_out, encoder_out2, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_ctc, cer_ctc = self._calc_ctc_loss(
            encoder_out, encoder_out_lens, text, text_lengths
View full source →
.encode(speech, speech_lengths, **kwargs) L158

Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • ind — int
📄 Source
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder
        encoder_out, encoder_out2 = self.encoder(speech)
        encoder_out_lens = speech_lengths

        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        if isinstance(encoder_out2, tuple):
            encoder_out2 = encoder_out2[0]
View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L242

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list=None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        keywords = kwargs.get("keywords")
        from funasr.utils.kws_utils import KwsCtcPrefixDecoder
        self.kws_decoder = KwsCtcPrefixDecoder(
            ctc=self.ctc,
            keywords=keywords,
            token_list=tokenizer[0].token_list,
            seg_dict=tokenizer[0].seg_dict,
        )
        self.kws_decoder2 = KwsCtcPrefixDecoder(
            ctc=self.ctc2,
            keywords=keywords,
View full source →
method

FsmnKWSMT.forward(speech, speech_lengths, text, text_lengths, text2, text2_lengths, **kwargs)

funasr.models.fsmn_kws_mt.FsmnKWSMT · View on GitHub ↗

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
  • text2 — (Batch, Length)
  • text2_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        text2: torch.Tensor,
        text2_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
                text2: (Batch, Length)
                text2_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        batch_size = speech.shape[0]

        # Encoder
        encoder_out, encoder_out2, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_ctc, cer_ctc = self._calc_ctc_loss(
            encoder_out, encoder_out_lens, text, text_lengths
View full source on GitHub →
method

FsmnKWSMT.encode(speech, speech_lengths, **kwargs)

funasr.models.fsmn_kws_mt.FsmnKWSMT · View on GitHub ↗

Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • ind — int
📄 Source code
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder
        encoder_out, encoder_out2 = self.encoder(speech)
        encoder_out_lens = speech_lengths

        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        if isinstance(encoder_out2, tuple):
            encoder_out2 = encoder_out2[0]
View full source on GitHub →
method

FsmnKWSMT.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.fsmn_kws_mt.FsmnKWSMT · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list=None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        keywords = kwargs.get("keywords")
        from funasr.utils.kws_utils import KwsCtcPrefixDecoder
        self.kws_decoder = KwsCtcPrefixDecoder(
            ctc=self.ctc,
            keywords=keywords,
            token_list=tokenizer[0].token_list,
            seg_dict=tokenizer[0].seg_dict,
        )
        self.kws_decoder2 = KwsCtcPrefixDecoder(
            ctc=self.ctc2,
            keywords=keywords,
View full source on GitHub →
class

FsmnKWSMTConvert

funasr.models.fsmn_kws_mt · View on GitHub ↗
  • Author — Speech Lab of DAMO Academy, Alibaba Group
  • Reference — Deep-FSMN for Large Vocabulary Continuous Speech Recognition

https://arxiv.org/abs/1803.05030

📄 Source code
class FsmnKWSMTConvert(torch.nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(
        self,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        ctc_conf: Optional[Dict] = None,
        ctc_weight: float = 1.0,
        input_size: int = 360,
        blank_id: int = 0,
        **kwargs,
    ):
        """Initialize FsmnKWSMTConvert.
        
            Args:
                encoder: TODO.
                encoder_conf: Configuration dict for encoder.
                ctc_conf: Configuration dict for ctc.
                ctc_weight: TODO.
                input_size: Size/dimension parameter.
                blank_id: TODO.
                **kwargs: Additional keyword arguments.
            """
        super().__init__()

View full source on GitHub →

Methods

.to_kaldi_net() L390

Convert to a Kaldi network (delegates to `self.encoder.to_kaldi_net()`).

📄 Source
    def to_kaldi_net(self):
        """Return the encoder's Kaldi-network conversion.

        Thin delegate: the whole conversion is performed by
        ``self.encoder.to_kaldi_net()`` and its result is returned as-is.
        """
        encoder = self.encoder
        return encoder.to_kaldi_net()
.to_kaldi_net2() L394

Convert to a Kaldi network via `self.encoder.to_kaldi_net2()` (presumably the encoder's second output branch).

📄 Source
    def to_kaldi_net2(self):
        """Return the encoder's second Kaldi-network conversion.

        Thin delegate: the result of ``self.encoder.to_kaldi_net2()`` is
        returned unchanged.
        """
        encoder = self.encoder
        return encoder.to_kaldi_net2()
.to_pytorch_net(kaldi_file) L398

Convert a Kaldi network file to a PyTorch network (delegates to `self.encoder.to_pytorch_net(kaldi_file)`).

Args:

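  • kaldi_file — Kaldi network file, passed unchanged to the encoder's `to_pytorch_net` (its exact type is defined by the encoder).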
📄 Source
    def to_pytorch_net(self, kaldi_file):
        """Convert a Kaldi network back into a PyTorch network via the encoder.

        Args:
            kaldi_file: Kaldi network file, handed unchanged to
                ``self.encoder.to_pytorch_net`` (its exact type is defined by
                the encoder).

        Returns:
            The encoder's conversion result.
        """
        encoder = self.encoder
        converted = encoder.to_pytorch_net(kaldi_file)
        return converted
method

FsmnKWSMTConvert.to_kaldi_net()

funasr.models.fsmn_kws_mt.FsmnKWSMTConvert · View on GitHub ↗

Convert to a Kaldi network (delegates to `self.encoder.to_kaldi_net()`).

📄 Source code
    def to_kaldi_net(self):
        """Return the encoder's Kaldi-network conversion.

        Thin delegate: the whole conversion is performed by
        ``self.encoder.to_kaldi_net()`` and its result is returned as-is.
        """
        encoder = self.encoder
        return encoder.to_kaldi_net()
method

FsmnKWSMTConvert.to_kaldi_net2()

funasr.models.fsmn_kws_mt.FsmnKWSMTConvert · View on GitHub ↗

Convert to a Kaldi network via `self.encoder.to_kaldi_net2()` (presumably the encoder's second output branch).

📄 Source code
    def to_kaldi_net2(self):
        """Return the encoder's second Kaldi-network conversion.

        Thin delegate: the result of ``self.encoder.to_kaldi_net2()`` is
        returned unchanged.
        """
        encoder = self.encoder
        return encoder.to_kaldi_net2()
method

FsmnKWSMTConvert.to_pytorch_net(kaldi_file)

funasr.models.fsmn_kws_mt.FsmnKWSMTConvert · View on GitHub ↗

Convert a Kaldi network file to a PyTorch network (delegates to `self.encoder.to_pytorch_net(kaldi_file)`).

Args:

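  • kaldi_file — Kaldi network file, passed unchanged to the encoder's `to_pytorch_net` (its exact type is defined by the encoder).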
📄 Source code
    def to_pytorch_net(self, kaldi_file):
        """Convert a Kaldi network back into a PyTorch network via the encoder.

        Args:
            kaldi_file: Kaldi network file, handed unchanged to
                ``self.encoder.to_pytorch_net`` (its exact type is defined by
                the encoder).

        Returns:
            The encoder's conversion result.
        """
        encoder = self.encoder
        converted = encoder.to_pytorch_net(kaldi_file)
        return converted
class

FsmnVADStreaming

funasr.models.fsmn_vad_streaming · View on GitHub ↗
FSMN-based Voice Activity Detection (streaming/offline).

Detects speech segments in audio, returning start/end timestamps (milliseconds).

Supports both offline (full audio) and streaming (chunk-by-chunk) modes.

Offline output: [{"key": "...", "value": [[start_ms, end_ms], ...]}]

Streaming output: [[beg, -1]] (start), [[-1, end]] (end), [[beg, end]] (complete), [] (no event)

  • Author — Speech Lab of DAMO Academy, Alibaba Group
  • Reference — Deep-FSMN for Large Vocabulary Continuous Speech Recognition

https://arxiv.org/abs/1803.05030

📄 Source code
class FsmnVADStreaming(nn.Module):
    """FSMN-based Voice Activity Detection (streaming/offline).

    Detects speech segments in audio, returning start/end timestamps (milliseconds).
    Supports both offline (full audio) and streaming (chunk-by-chunk) modes.

    Offline output: [{"key": "...", "value": [[start_ms, end_ms], ...]}]
    Streaming output: [[beg, -1]] (start), [[-1, end]] (end), [[beg, end]] (complete), [] (no event)

    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(
        self,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        vad_post_args: Dict[str, Any] = None,
        **kwargs,
    ):
        """Initialize FsmnVADStreaming.
        
            Args:
                encoder: TODO.
                encoder_conf: Configuration dict for encoder.
                vad_post_args: TODO.
                **kwargs: Additional keyword arguments.
            """
        super().__init__()
View full source on GitHub →

Methods

.ResetDetection(cache) L382

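Reset per-utterance detection state in the cache and drop buffered data that precedes the end of the last emitted segment.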

Args:

  • cache — State cache dict for streaming inference.
📄 Source
    def ResetDetection(self, cache: dict = None):
        """Reset per-utterance detection state held in ``cache``.

        Clears the confirmed start/end frames, the silence counters and the
        per-frame probabilities, puts the state machine back into "start point
        not detected", and resets the window detector. If segments have already
        been emitted, also trims the retained samples, decibels and scores up
        to the end of the last emitted segment.

        Args:
            cache: Streaming state dict; must hold ``"stats"`` and
                ``"windows_detector"``. ``None`` is replaced by an empty dict,
                which then fails with ``KeyError``.
        """
        if cache is None:
            cache = {}
        stats = cache["stats"]
        stats.continous_silence_frame_count = 0
        stats.latest_confirmed_speech_frame = 0
        stats.lastest_confirmed_silence_frame = -1
        stats.confirmed_start_frame = -1
        stats.confirmed_end_frame = -1
        stats.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
        cache["windows_detector"].Reset()
        stats.sil_frame = 0
        stats.frame_probs = []

        if not stats.output_data_buf:
            return

        last_segment = stats.output_data_buf[-1]
        # The last emitted segment must be closed before buffers can be trimmed.
        assert last_segment.contain_seg_end_point == True
        frame_ms = self.vad_opts.frame_in_ms
        drop_frames = int(last_segment.end_ms / frame_ms)
        # Only trim what was not already dropped by a previous reset.
        newly_dropped = drop_frames - stats.last_drop_frames
        stats.last_drop_frames = drop_frames
        samples_per_frame = int(frame_ms * self.vad_opts.sample_rate / 1000)
        stats.data_buf_all = stats.data_buf_all[newly_dropped * samples_per_frame :]
        stats.decibel = stats.decibel[newly_dropped:]
        stats.scores = stats.scores[:, newly_dropped:, :]
.ComputeDecibel(cache) L412

Compute the energy (dB) of each frame in the current waveform chunk and append it to the cache.

Args:

  • cache — State cache dict for streaming inference.
📄 Source
    def ComputeDecibel(self, cache: dict = None) -> None:
        """Append the energy (dB) of every full frame of the current chunk.

        The chunk in ``cache["stats"].waveform`` is appended to the retained
        sample buffer ``data_buf_all``; on the first call ``data_buf`` is also
        initialised to it. The chunk is then framed with a window of
        ``frame_length_ms`` and a hop of ``frame_in_ms``. Each frame's energy,
        ``10 * log10(sum(x ** 2) + 1e-6)``, is appended to
        ``cache["stats"].decibel``.

        Args:
            cache: Streaming state dict; must hold ``"stats"``.
        """
        if cache is None:
            cache = {}
        opts = self.vad_opts
        win_len = int(opts.frame_length_ms * opts.sample_rate / 1000)
        hop_len = int(opts.frame_in_ms * opts.sample_rate / 1000)
        stats = cache["stats"]
        chunk = stats.waveform[0]
        if stats.data_buf_all is not None:
            stats.data_buf_all = torch.cat((stats.data_buf_all, chunk))
        else:
            # First chunk: data_buf and data_buf_all refer to the same tensor.
            stats.data_buf_all = chunk
            stats.data_buf = stats.data_buf_all

        samples = stats.waveform.numpy()
        starts = np.arange(0, samples.shape[1] - win_len + 1, hop_len)
        window_index = starts[:, None] + np.arange(win_len)
        frames = samples[0, window_index]
        energy = np.sum(np.square(frames), axis=1) + 0.000001
        stats.decibel.extend((10 * np.log10(energy)).tolist())
.ComputeScores(feats, cache) L443

Run the encoder on the features and append the resulting frame scores to the cache.

Args:

  • feats — Feature tensor (e.g., fbank), shape (batch, frames, dim).
  • cache — State cache dict for streaming inference.
📄 Source
    def ComputeScores(self, feats: torch.Tensor, cache: dict = None) -> None:
        """Run the encoder on ``feats`` and append its frame scores to the cache.

        Also stores the number of frames just scored in
        ``self.vad_opts.nn_eval_block_size`` and adds it to the running total
        ``cache["stats"].frm_cnt``.

        Args:
            feats: Feature tensor (e.g., fbank), shape (batch, frames, dim).
            cache: Streaming state dict; must hold ``"encoder"`` and ``"stats"``.
        """
        if cache is None:
            cache = {}
        new_scores = self.encoder(feats, cache=cache["encoder"])  # B x T x D
        num_frames = new_scores.shape[1]
        assert num_frames == feats.shape[1], "The shape between feats and scores does not match"
        self.vad_opts.nn_eval_block_size = num_frames
        stats = cache["stats"]
        stats.frm_cnt += num_frames  # running total of scored frames
        stats.scores = (
            new_scores
            if stats.scores is None
            else torch.cat((stats.scores, new_scores), dim=1)
        )
.PopDataBufTillFrame(frame_idx, cache) L463

Advance the start of the pending sample buffer up to the given frame.

Args:

  • frame_idx — Frame index that the buffer start should reach.
  • cache — State cache dict for streaming inference.
📄 Source
    def PopDataBufTillFrame(self, frame_idx: int, cache: dict = None) -> None:  # need check again
        """Advance the start of the pending sample buffer up to ``frame_idx``.

        Each step moves ``cache["stats"].data_buf_start_frame`` forward by one
        frame and re-slices ``data_buf`` out of ``data_buf_all``. The slice
        offset accounts for frames already removed from ``data_buf_all``
        (``last_drop_frames``).

        Args:
            frame_idx: Frame index the buffer start should reach.
            cache: Streaming state dict; must hold ``"stats"``.
        """
        if cache is None:
            cache = {}
        frame_shift_samples = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
        while cache["stats"].data_buf_start_frame < frame_idx:
            if len(cache["stats"].data_buf) < frame_shift_samples:
                # Not enough buffered samples to advance by a frame; nothing in
                # the loop would change, so stop instead of spinning forever.
                break
            cache["stats"].data_buf_start_frame += 1
            cache["stats"].data_buf = cache["stats"].data_buf_all[
                (cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames)
                * frame_shift_samples :
            ]
.PopDataToOutputBuf(start_frm, frm_cnt, first_frm_is_start_point, last_frm_is_end_point, end_point_is_sent_end, cache) L482

Pop buffered data to the output buffer.

Args:

  • start_frm — TODO.
  • frm_cnt — TODO.
  • first_frm_is_start_point — TODO.
  • last_frm_is_end_point — TODO.
  • end_point_is_sent_end — TODO.
  • cache — State cache dict for streaming inference.
📄 Source
    def PopDataToOutputBuf(
        self,
        start_frm: int,
        frm_cnt: int,
        first_frm_is_start_point: bool,
        last_frm_is_end_point: bool,
        end_point_is_sent_end: bool,
        cache: dict = None,
    ) -> None:
        """Popdatatooutputbuf.
        
            Args:
                start_frm: TODO.
                frm_cnt: TODO.
                first_frm_is_start_point: TODO.
                last_frm_is_end_point: TODO.
                end_point_is_sent_end: TODO.
                cache: State cache dict for streaming inference.
            """
        if cache is None:
            cache = {}
        self.PopDataBufTillFrame(start_frm, cache=cache)
        expected_sample_number = int(
            frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000
        )
        if last_frm_is_end_point:
            extra_sample = max(
                0,
                int(
                    self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000
View full source →
.OnSilenceDetected(valid_frame, cache) L552

Handle a frame confirmed as silence.

Args:

  • valid_frame — Index of the frame confirmed as silence.
  • cache — State cache dict for streaming inference.
📄 Source
    def OnSilenceDetected(self, valid_frame: int, cache: dict = None):
        """Record ``valid_frame`` as the latest confirmed silence frame.

        While no speech start point has been detected yet, the pending sample
        buffer is also advanced up to ``valid_frame``.

        Args:
            valid_frame: Index of the frame confirmed as silence.
            cache: Streaming state dict; must hold ``"stats"``.
        """
        if cache is None:
            cache = {}
        stats = cache["stats"]
        stats.lastest_confirmed_silence_frame = valid_frame
        awaiting_start = stats.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected
        if awaiting_start:
            self.PopDataBufTillFrame(valid_frame, cache=cache)
.OnVoiceDetected(valid_frame, cache) L568

Handle a frame confirmed as speech and emit it to the output buffer.

Args:

  • valid_frame — Index of the frame confirmed as speech.
  • cache — State cache dict for streaming inference.
📄 Source
    def OnVoiceDetected(self, valid_frame: int, cache: dict = None) -> None:
        """Record ``valid_frame`` as confirmed speech and emit it as one frame.

        Args:
            valid_frame: Index of the frame confirmed as speech.
            cache: Streaming state dict; must hold ``"stats"``.
        """
        if cache is None:
            cache = {}
        cache["stats"].latest_confirmed_speech_frame = valid_frame
        self.PopDataToOutputBuf(
            valid_frame,
            1,  # frm_cnt
            False,  # first_frm_is_start_point
            False,  # last_frm_is_end_point
            False,  # end_point_is_sent_end
            cache=cache,
        )
.OnVoiceStart(start_frame, fake_result, cache) L580

Handle a detected speech start point.

Args:

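  • start_frame — Frame index to record as the confirmed speech start (only if none is recorded yet).
  • fake_result — If True, record the start frame without emitting data to the output buffer.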
  • cache — State cache dict for streaming inference.
📄 Source
    def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache: dict = None) -> None:
        """Record a speech start point and, for real results, emit its frame.

        If a start frame is already recorded, a warning is printed and the
        existing value is kept. Unless ``fake_result`` is set, and while the
        state machine still awaits a start point, the confirmed start frame is
        pushed to the output buffer.

        Args:
            start_frame: Frame index to record as the confirmed speech start.
            fake_result: If True, only record the frame; emit nothing.
            cache: Streaming state dict; must hold ``"stats"``.
        """
        if cache is None:
            cache = {}
        if self.vad_opts.do_start_point_detection:
            pass  # option is read here but currently has no effect
        stats = cache["stats"]
        if stats.confirmed_start_frame == -1:
            stats.confirmed_start_frame = start_frame
        else:
            print("not reset vad properly\n")

        if fake_result:
            return
        if stats.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
            self.PopDataToOutputBuf(stats.confirmed_start_frame, 1, True, False, False, cache=cache)
.OnVoiceEnd(end_frame, fake_result, is_last_frame, cache) L605

Handle a detected speech end point.

Args:

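  • end_frame — Frame index to record as the confirmed speech end (only if none is recorded yet).
  • fake_result — If True, record the end frame without emitting data to the output buffer.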
  • is_last_frame — Boolean flag for last frame.
  • cache — State cache dict for streaming inference.
📄 Source
    def OnVoiceEnd(
        self, end_frame: int, fake_result: bool, is_last_frame: bool, cache: dict = None
    ) -> None:
        """Close the current speech segment at ``end_frame``.

        First emits every speech frame between the last confirmed speech frame
        and ``end_frame``. It then records the end frame, printing a warning if
        one is already set. For real results it also resets the silence
        counter and emits the end frame. The end-detection counter is always
        incremented.

        Args:
            end_frame: Frame index to record as the confirmed speech end.
            fake_result: If True, only record the frame; emit nothing.
            is_last_frame: Whether the end point is also the end of the input.
            cache: Streaming state dict; must hold ``"stats"``.
        """
        if cache is None:
            cache = {}
        first_pending = cache["stats"].latest_confirmed_speech_frame + 1
        for frame in range(first_pending, end_frame):
            self.OnVoiceDetected(frame, cache=cache)
        if self.vad_opts.do_end_point_detection:
            pass  # option is read here but currently has no effect
        # Look stats up again: the calls above received the cache.
        stats = cache["stats"]
        if stats.confirmed_end_frame == -1:
            stats.confirmed_end_frame = end_frame
        else:
            print("not reset vad properly\n")
        if not fake_result:
            stats.sil_frame = 0
            self.PopDataToOutputBuf(
                stats.confirmed_end_frame, 1, False, True, is_last_frame, cache=cache
            )
        cache["stats"].number_end_time_detected += 1
.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache) L633

Close the current speech segment if this is the final frame.

Args:

  • is_final_frame — Boolean flag for final frame.
  • cur_frm_idx — Frame index used as the segment end when is_final_frame is True.
  • cache — State cache dict for streaming inference.
📄 Source
    def MaybeOnVoiceEndIfLastFrame(
        self, is_final_frame: bool, cur_frm_idx: int, cache: dict = None
    ) -> None:
        """End the speech segment at ``cur_frm_idx`` when it is the final frame.

        On the final frame, calls ``OnVoiceEnd`` with a real result marked as
        the last frame, then moves the state machine to "end point detected".
        Otherwise does nothing.

        Args:
            is_final_frame: Whether ``cur_frm_idx`` is the last frame of input.
            cur_frm_idx: Frame index used as the segment end.
            cache: Streaming state dict; must hold ``"stats"``.
        """
        if cache is None:
            cache = {}
        if not is_final_frame:
            return
        self.OnVoiceEnd(cur_frm_idx, False, True, cache=cache)
        cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
.GetLatency(cache) L649

Get the start-point detection latency in milliseconds.

Args:

  • cache — State cache dict for streaming inference.
📄 Source
    def GetLatency(self, cache: dict = None) -> int:
        """Return the start-point latency in milliseconds.

        Converts ``LatencyFrmNumAtStartPoint`` from frames to milliseconds using
        ``vad_opts.frame_in_ms``, truncating to an int.

        Args:
            cache: Streaming state dict, forwarded to ``LatencyFrmNumAtStartPoint``.
        """
        if cache is None:
            cache = {}
        latency_frames = self.LatencyFrmNumAtStartPoint(cache=cache)
        return int(latency_frames * self.vad_opts.frame_in_ms)
.LatencyFrmNumAtStartPoint(cache) L659

Number of frames of latency at the start point (window size, plus look-back frames when extension is enabled).

Args:

  • cache — State cache dict for streaming inference.
📄 Source
    def LatencyFrmNumAtStartPoint(self, cache: dict = None) -> int:
        """Return the start-point latency, in frames.

        The latency is the window detector's window size. When
        ``vad_opts.do_extend`` is set, it also includes the look-back time
        ``lookback_time_start_point`` converted to whole frames.

        Args:
            cache: Streaming state dict; must hold ``"windows_detector"``.
        """
        if cache is None:
            cache = {}
        window_frames = cache["windows_detector"].GetWinSize()
        if not self.vad_opts.do_extend:
            return window_frames
        window_frames += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
        return window_frames
.GetFrameState(t, cache) L672

Get the speech/silence state of a frame.

Args:

  • t — Frame index into the cached per-frame decibel list.
  • cache — State cache dict for streaming inference.
📄 Source
    def GetFrameState(self, t: int, cache: dict = None):
        """Getframestate.
        
            Args:
                t: TODO.
                cache: State cache dict for streaming inference.
            """
        if cache is None:
            cache = {}
        frame_state = FrameState.kFrameStateInvalid
        if t >= len(cache["stats"].decibel):
            return FrameState.kFrameStateSil
        cur_decibel = cache["stats"].decibel[t]
        cur_snr = cur_decibel - cache["stats"].noise_average_decibel
        # for each frame, calc log posterior probability of each state
        if cur_decibel < self.vad_opts.decibel_thres:
            frame_state = FrameState.kFrameStateSil
            self.DetectOneFrame(frame_state, t, False, cache=cache)
            return frame_state

        sum_score = 0.0
        noise_prob = 0.0
        assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num
        if len(cache["stats"].sil_pdf_ids) > 0:
            assert len(cache["stats"].scores) == 1  # 只支持batch_size = 1的测试
            """
            - Change type of `sum_score` to float. The reason is that `sum_score` is a tensor with single element.
              and `torch.Tensor` is slower `float` when tensor has only one element.
            - Put the iteration of `sil_pdf_ids` inside `sum()` to reduce the overhead of creating a new list.
            - The default `sil_pdf_ids` is [0], the `if` statement is used to reduce the overhead of expression
View full source →
.forward(feats, waveform, cache, is_final, **kwargs) L737

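Process one chunk: compute frame energies and encoder scores, then run frame-level detection (on the final chunk, the last frames are flushed).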

Args:

  • feats — Feature tensor (e.g., fbank), shape (batch, frames, dim).
  • waveform — Raw audio samples for this chunk; stored in the cache and used to compute frame energies.
  • cache — State cache dict for streaming inference.
  • is_final — Whether this is the final chunk in streaming.
  • **kwargs — Additional keyword arguments.
📄 Source
    def forward(
        self,
        feats: torch.Tensor,
        waveform: torch.tensor,
        cache: dict = None,
        is_final: bool = False,
        **kwargs,
    ):
        """Forward pass for training.
        
            Args:
                feats: Feature tensor (e.g., fbank), shape (batch, frames, dim).
                waveform: TODO.
                cache: State cache dict for streaming inference.
                is_final: Whether this is the final chunk in streaming.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        # if len(cache) == 0:
        #     self.AllResetDetection()
        # self.waveform = waveform  # compute decibel for each frame
        cache["stats"].waveform = waveform
        is_streaming_input = kwargs.get("is_streaming_input", True)
        self.ComputeDecibel(cache=cache)
        self.ComputeScores(feats, cache=cache)
        if not is_final:
            self.DetectCommonFrames(cache=cache)
        else:
            self.DetectLastFrames(cache=cache)
View full source →
.init_cache(cache, **kwargs) L820

Init cache.

Args:

  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source
    def init_cache(self, cache: dict = None, **kwargs):
        """Init cache.
        
            Args:
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}

        cache["frontend"] = {}
        cache["prev_samples"] = torch.empty(0)
        cache["encoder"] = {}

        if kwargs.get("max_end_silence_time") is not None:
            # update the max_end_silence_time
            self.vad_opts.max_end_silence_time = kwargs.get("max_end_silence_time")

        windows_detector = WindowDetector(
            self.vad_opts.window_size_ms,
            self.vad_opts.sil_to_speech_time_thres,
            self.vad_opts.speech_to_sil_time_thres,
            self.vad_opts.frame_in_ms,
        )
        windows_detector.Reset()

        stats = Stats(
            sil_pdf_ids=self.vad_opts.sil_pdf_ids,
            max_end_sil_frame_cnt_thresh=self.vad_opts.max_end_silence_time
            - self.vad_opts.speech_to_sil_time_thres,
View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, cache, **kwargs) L856

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        cache: dict = None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        if len(cache) == 0:
            self.init_cache(cache, **kwargs)

        meta_data = {}
        chunk_size = kwargs.get("chunk_size", 60000)  # 50ms
        chunk_stride_samples = int(chunk_size * frontend.fs / 1000)

View full source →
.export(**kwargs) L970

Export.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def export(self, **kwargs):
        """Build the export-ready model(s) for this model.

        Delegates to ``export_rebuild_model`` from the sibling ``export_meta``
        module, passing ``self`` as ``model`` plus every keyword argument.

        Args:
            **kwargs: Forwarded unchanged to ``export_rebuild_model``.

        Returns:
            Whatever ``export_rebuild_model`` returns.
        """
        # Local import: the export helpers are only loaded when export() runs.
        from .export_meta import export_rebuild_model as _rebuild

        return _rebuild(model=self, **kwargs)
.DetectCommonFrames(cache) L982

Detect common frames: run frame-level VAD detection over the most recent block of frames.

Args:

  • cache — State cache dict for streaming inference.
📄 Source
    def DetectCommonFrames(self, cache: dict = None) -> int:
        """Run frame-level VAD detection over the most recent block of frames.

        Returns immediately once the state machine has reached
        ``kVadInStateEndPointDetected``. Otherwise it visits the last
        ``vad_opts.nn_eval_block_size`` frames from oldest to newest. Each frame
        is classified with ``GetFrameState`` and the result is passed to
        ``DetectOneFrame`` with ``is_final_frame=False``.

        Args:
            cache: State cache dict for streaming inference.

        Returns:
            0 in every case.
        """
        if cache is None:
            cache = {}
        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        block_size = self.vad_opts.nn_eval_block_size
        for offset in reversed(range(block_size)):
            # GetFrameState indexes buffers from which ``last_drop_frames`` frames
            # have been trimmed, so it receives the buffer-relative index.
            # DetectOneFrame receives the absolute frame index.
            rel_idx = cache["stats"].frm_cnt - 1 - offset - cache["stats"].last_drop_frames
            state = self.GetFrameState(rel_idx, cache=cache)
            abs_idx = cache["stats"].frm_cnt - 1 - offset
            self.DetectOneFrame(state, abs_idx, False, cache=cache)
        return 0
.DetectLastFrames(cache) L1001

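Detect last frames: run frame-level VAD detection over the final block of frames, marking the newest frame as the final frame.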

Args:

  • cache — State cache dict for streaming inference.
📄 Source
    def DetectLastFrames(self, cache: dict = None) -> int:
        """Run frame-level VAD detection over the final block of frames.

        Returns immediately once the state machine has reached
        ``kVadInStateEndPointDetected``. Otherwise it visits the last
        ``vad_opts.nn_eval_block_size`` frames from oldest to newest and
        classifies each with ``GetFrameState``. Every frame except the newest is
        passed to ``DetectOneFrame`` with ``is_final_frame=False``. The newest
        frame is passed with ``is_final_frame=True``.

        Args:
            cache: State cache dict for streaming inference.

        Returns:
            0 in every case.
        """
        if cache is None:
            cache = {}
        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        block_size = self.vad_opts.nn_eval_block_size
        for offset in reversed(range(block_size)):
            # Buffer-relative index for GetFrameState (buffers are trimmed by
            # ``last_drop_frames``). DetectOneFrame gets the absolute index.
            rel_idx = cache["stats"].frm_cnt - 1 - offset - cache["stats"].last_drop_frames
            state = self.GetFrameState(rel_idx, cache=cache)
            # offset == 0 is the newest frame, which is reported as final.
            is_newest = offset == 0
            self.DetectOneFrame(
                state, cache["stats"].frm_cnt - 1 - offset, is_newest, cache=cache
            )
        return 0
.DetectOneFrame(cur_frm_state, cur_frm_idx, is_final_frame, cache) L1023

Detect one frame: feed a single frame's state through the window detector and update the VAD statistics.

Args:

  • cur_frm_state — FrameState of the current frame.
  • cur_frm_idx — Index of the current frame.
  • is_final_frame — Boolean flag for final frame.
  • cache — State cache dict for streaming inference.
📄 Source
    def DetectOneFrame(
        self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = None
    ) -> None:
        """Detectoneframe.
        
            Args:
                cur_frm_state: TODO.
                cur_frm_idx: TODO.
                is_final_frame: Boolean flag for final frame.
                cache: State cache dict for streaming inference.
            """
        if cache is None:
            cache = {}
        tmp_cur_frm_state = FrameState.kFrameStateInvalid
        if cur_frm_state == FrameState.kFrameStateSpeech:
            if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
                tmp_cur_frm_state = FrameState.kFrameStateSpeech
            else:
                tmp_cur_frm_state = FrameState.kFrameStateSil
        elif cur_frm_state == FrameState.kFrameStateSil:
            tmp_cur_frm_state = FrameState.kFrameStateSil
        state_change = cache["windows_detector"].DetectOneFrame(
            tmp_cur_frm_state, cur_frm_idx, cache=cache
        )
        frm_shift_in_ms = self.vad_opts.frame_in_ms
        if AudioChangeState.kChangeStateSil2Speech == state_change:
            silence_frame_count = cache["stats"].continous_silence_frame_count
            cache["stats"].continous_silence_frame_count = 0
            cache["stats"].pre_end_silence_detected = False
            start_frame = 0
View full source →
method

FsmnVADStreaming.ResetDetection(cache)

funasr.models.fsmn_vad_streaming.FsmnVADStreaming · View on GitHub ↗

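Reset detection: clear the per-segment VAD state and, if segments have been output, drop the buffered audio, decibels and scores up to the last segment's end.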

Args:

  • cache — State cache dict for streaming inference.
📄 Source code
    def ResetDetection(self, cache: dict = None):
        """Reset per-segment VAD state and trim buffers that were already emitted.

        This clears the silence/speech counters and the confirmed start/end
        frames in ``cache["stats"]``. It also moves the state machine back to
        ``kVadInStateStartPointNotDetected`` and resets the window detector.

        When ``output_data_buf`` is non-empty, its last segment must contain an
        end point. The frames up to that segment's ``end_ms`` that have not been
        dropped yet are then removed from ``data_buf_all``, ``decibel`` and
        ``scores``, and ``last_drop_frames`` is updated.

        Args:
            cache: State cache dict for streaming inference.
        """
        if cache is None:
            cache = {}
        stats = cache["stats"]
        stats.continous_silence_frame_count = 0
        stats.latest_confirmed_speech_frame = 0
        stats.lastest_confirmed_silence_frame = -1
        stats.confirmed_start_frame = -1
        stats.confirmed_end_frame = -1
        stats.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
        cache["windows_detector"].Reset()
        stats.sil_frame = 0
        stats.frame_probs = []

        segments = stats.output_data_buf
        if not segments:
            return
        last_segment = segments[-1]
        assert last_segment.contain_seg_end_point == True  # noqa: E712 (strict equality kept)
        frame_ms = self.vad_opts.frame_in_ms
        dropped_total = int(last_segment.end_ms / frame_ms)
        # Only the frames not removed by an earlier reset are sliced off now.
        dropped_now = dropped_total - stats.last_drop_frames
        stats.last_drop_frames = dropped_total
        hop = int(frame_ms * self.vad_opts.sample_rate / 1000)
        stats.data_buf_all = stats.data_buf_all[dropped_now * hop :]
        stats.decibel = stats.decibel[dropped_now:]
        stats.scores = stats.scores[:, dropped_now:, :]
method

FsmnVADStreaming.ComputeDecibel(cache)

funasr.models.fsmn_vad_streaming.FsmnVADStreaming · View on GitHub ↗

Compute decibels: compute the per-frame energy (in dB) of the current chunk and append it to the cache.

Args:

  • cache — State cache dict for streaming inference.
📄 Source code
    def ComputeDecibel(self, cache: dict = None) -> None:
        """Append per-frame energies (in dB) of the current chunk to the cache.

        The chunk is read from ``cache["stats"].waveform`` and only its first
        row is used. That row is appended to ``data_buf_all``. When
        ``data_buf_all`` was still ``None``, both ``data_buf_all`` and
        ``data_buf`` are simply set to that row.

        The chunk is then cut into windows of ``frame_length_ms`` spaced
        ``frame_in_ms`` apart. Each window's energy
        ``10 * log10(sum(x ** 2) + 1e-6)`` is appended to
        ``cache["stats"].decibel``.

        Args:
            cache: State cache dict for streaming inference.
        """
        if cache is None:
            cache = {}
        stats = cache["stats"]
        sr = self.vad_opts.sample_rate
        win = int(self.vad_opts.frame_length_ms * sr / 1000)
        hop = int(self.vad_opts.frame_in_ms * sr / 1000)

        first_row = stats.waveform[0]
        if stats.data_buf_all is None:
            # data_buf starts out as the very same tensor as data_buf_all.
            stats.data_buf_all = first_row
            stats.data_buf = stats.data_buf_all
        else:
            stats.data_buf_all = torch.cat((stats.data_buf_all, first_row))

        samples = stats.waveform.numpy()
        # Start sample of every full window; empty when the chunk is too short.
        starts = np.arange(0, samples.shape[1] - win + 1, hop)
        windows = samples[0, starts[:, np.newaxis] + np.arange(win)]
        energy_db = 10 * np.log10(np.sum(np.square(windows), axis=1) + 0.000001)
        stats.decibel.extend(energy_db.tolist())
method

FsmnVADStreaming.ComputeScores(feats, cache)

funasr.models.fsmn_vad_streaming.FsmnVADStreaming · View on GitHub ↗

Compute scores: run the encoder on the feature chunk and append the resulting scores to the cache.

Args:

  • feats — Feature tensor (e.g., fbank), shape (batch, frames, dim).
  • cache — State cache dict for streaming inference.
📄 Source code
    def ComputeScores(self, feats: torch.Tensor, cache: dict = None) -> None:
        """Score a feature chunk with the encoder and accumulate the result.

        The encoder output (B * T * D, per the encoder's contract) must have the
        same number of frames as ``feats``. That frame count is stored in
        ``vad_opts.nn_eval_block_size`` and added to ``cache["stats"].frm_cnt``.
        The scores are concatenated onto ``cache["stats"].scores`` along the
        frame axis.

        Args:
            feats: Feature tensor (e.g., fbank), shape (batch, frames, dim).
            cache: State cache dict for streaming inference.
        """
        if cache is None:
            cache = {}
        scores = self.encoder(feats, cache=cache["encoder"])
        n_frames = scores.shape[1]
        assert n_frames == feats.shape[1], "The shape between feats and scores does not match"
        # Later detection passes walk exactly this many newest frames.
        self.vad_opts.nn_eval_block_size = n_frames
        stats = cache["stats"]
        stats.frm_cnt += n_frames
        if stats.scores is not None:
            stats.scores = torch.cat((stats.scores, scores), dim=1)
        else:
            stats.scores = scores
method

FsmnVADStreaming.PopDataBufTillFrame(frame_idx, cache)

funasr.models.fsmn_vad_streaming.FsmnVADStreaming · View on GitHub ↗

Pop data buffer till frame: advance the start of the data buffer frame by frame up to frame_idx.

Args:

  • frame_idx — Frame index the data buffer should be advanced to.
  • cache — State cache dict for streaming inference.
📄 Source code
    def PopDataBufTillFrame(self, frame_idx: int, cache: dict = None) -> None:
        """Advance the start of ``data_buf`` one frame at a time up to ``frame_idx``.

        Each step increments ``data_buf_start_frame`` and re-slices ``data_buf``
        from ``data_buf_all`` so that it begins at that frame. The slice offset
        subtracts ``last_drop_frames``, the number of frames already trimmed
        from ``data_buf_all``.

        The loop stops early when ``data_buf`` holds less than one frame of
        samples. Previously it spun forever in that case, because nothing in
        the loop body changed the loop condition.

        Args:
            frame_idx: Frame index (same numbering as ``data_buf_start_frame``)
                that ``data_buf`` should start at.
            cache: State cache dict for streaming inference.
        """
        # The guard must come before the first subscript of ``cache``;
        # it used to sit inside the loop, after ``cache`` was already read.
        if cache is None:
            cache = {}
        stats = cache["stats"]
        frame_samples = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
        while stats.data_buf_start_frame < frame_idx:
            if len(stats.data_buf) < frame_samples:
                # Not enough buffered audio to drop another frame. Without this
                # break the loop would never terminate.
                break
            stats.data_buf_start_frame += 1
            stats.data_buf = stats.data_buf_all[
                (stats.data_buf_start_frame - stats.last_drop_frames) * frame_samples :
            ]
method

FsmnVADStreaming.PopDataToOutputBuf(start_frm, frm_cnt, first_frm_is_start_point, last_frm_is_end_point, end_point_is_sent_end, cache)

funasr.models.fsmn_vad_streaming.FsmnVADStreaming · View on GitHub ↗

Pop data to output buffer.

Args:

  • start_frm — TODO.
  • frm_cnt — TODO.
  • first_frm_is_start_point — TODO.
  • last_frm_is_end_point — TODO.
  • end_point_is_sent_end — TODO.
  • cache — State cache dict for streaming inference.
📄 Source code
    def PopDataToOutputBuf(
        self,
        start_frm: int,
        frm_cnt: int,
        first_frm_is_start_point: bool,
        last_frm_is_end_point: bool,
        end_point_is_sent_end: bool,
        cache: dict = None,
    ) -> None:
        """Popdatatooutputbuf.
        
            Args:
                start_frm: TODO.
                frm_cnt: TODO.
                first_frm_is_start_point: TODO.
                last_frm_is_end_point: TODO.
                end_point_is_sent_end: TODO.
                cache: State cache dict for streaming inference.
            """
        if cache is None:
            cache = {}
        self.PopDataBufTillFrame(start_frm, cache=cache)
        expected_sample_number = int(
            frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000
        )
        if last_frm_is_end_point:
            extra_sample = max(
                0,
                int(
                    self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000
View full source on GitHub →
method

FsmnVADStreaming.OnSilenceDetected(valid_frame, cache)

funasr.models.fsmn_vad_streaming.FsmnVADStreaming · View on GitHub ↗

Handle a confirmed silence frame.

Args:

  • valid_frame — TODO.
  • cache — State cache dict for streaming inference.
📄 Source code
    def OnSilenceDetected(self, valid_frame: int, cache: dict = None):
        """Record ``valid_frame`` as the latest confirmed silence frame.

        While no start point has been detected yet, the data buffer is also
        advanced up to ``valid_frame`` via ``PopDataBufTillFrame``.

        Args:
            valid_frame: Index of the frame confirmed as silence.
            cache: State cache dict for streaming inference.
        """
        if cache is None:
            cache = {}
        stats = cache["stats"]
        stats.lastest_confirmed_silence_frame = valid_frame
        waiting_for_start = (
            stats.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected
        )
        if waiting_for_start:
            self.PopDataBufTillFrame(valid_frame, cache=cache)
method

FsmnVADStreaming.OnVoiceDetected(valid_frame, cache)

funasr.models.fsmn_vad_streaming.FsmnVADStreaming · View on GitHub ↗

Handle a confirmed speech frame.

Args:

  • valid_frame — TODO.
  • cache — State cache dict for streaming inference.
📄 Source code
    def OnVoiceDetected(self, valid_frame: int, cache: dict = None) -> None:
        """Record ``valid_frame`` as speech and move that one frame to the output.

        Args:
            valid_frame: Index of the frame confirmed as speech.
            cache: State cache dict for streaming inference.
        """
        if cache is None:
            cache = {}
        cache["stats"].latest_confirmed_speech_frame = valid_frame
        # A single mid-segment frame: neither a start point, an end point,
        # nor a sentence end.
        self.PopDataToOutputBuf(
            start_frm=valid_frame,
            frm_cnt=1,
            first_frm_is_start_point=False,
            last_frm_is_end_point=False,
            end_point_is_sent_end=False,
            cache=cache,
        )
method

FsmnVADStreaming.OnVoiceStart(start_frame, fake_result, cache)

funasr.models.fsmn_vad_streaming.FsmnVADStreaming · View on GitHub ↗

Handle the start of a speech segment.

Args:

  • start_frame — TODO.
  • fake_result — TODO.
  • cache — State cache dict for streaming inference.
📄 Source code
    def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache: dict = None) -> None:
        """Confirm the start frame of a speech segment.

        If no start frame is confirmed yet, ``start_frame`` becomes the
        confirmed start. Otherwise a warning is printed and the existing value
        is kept. Unless ``fake_result`` is set, and only while the state machine
        is still waiting for a start point, the confirmed start frame is pushed
        to the output buffer as a start point.

        Args:
            start_frame: Candidate start frame of the segment.
            fake_result: When truthy, do not emit any output data.
            cache: State cache dict for streaming inference.
        """
        if cache is None:
            cache = {}
        stats = cache["stats"]
        if self.vad_opts.do_start_point_detection:
            pass  # no-op branch, kept as in the original implementation
        if stats.confirmed_start_frame == -1:
            stats.confirmed_start_frame = start_frame
        else:
            print("not reset vad properly\n")

        if fake_result:
            return
        if stats.vad_state_machine != VadStateMachine.kVadInStateStartPointNotDetected:
            return
        self.PopDataToOutputBuf(stats.confirmed_start_frame, 1, True, False, False, cache=cache)
method

FsmnVADStreaming.OnVoiceEnd(end_frame, fake_result, is_last_frame, cache)

funasr.models.fsmn_vad_streaming.FsmnVADStreaming · View on GitHub ↗

Onvoiceend.

Args:

  • end_frame — TODO.
  • fake_result — TODO.
  • is_last_frame — Boolean flag for last frame.
  • cache — State cache dict for streaming inference.
📄 Source code
    def OnVoiceEnd(
        self, end_frame: int, fake_result: bool, is_last_frame: bool, cache: dict = None
    ) -> None:
        """Close the current speech segment at ``end_frame``.

        Every frame after the latest confirmed speech frame and before
        ``end_frame`` is first reported through ``OnVoiceDetected``. Then
        ``end_frame`` becomes the confirmed end frame; if one was already set, a
        warning is printed and the old value is kept.

        Unless ``fake_result`` is truthy, ``sil_frame`` is reset and the end
        frame is pushed to the output buffer as an end point, with
        ``is_last_frame`` as the sentence-end flag. The end-detection counter is
        always incremented.

        Args:
            end_frame: Frame at which the segment ends.
            fake_result: When truthy, do not emit the end-point output.
            is_last_frame: Boolean flag for last frame.
            cache: State cache dict for streaming inference.
        """
        if cache is None:
            cache = {}
        stats = cache["stats"]
        first_pending = stats.latest_confirmed_speech_frame + 1
        for frame in range(first_pending, end_frame):
            self.OnVoiceDetected(frame, cache=cache)

        if self.vad_opts.do_end_point_detection:
            pass  # no-op branch, kept as in the original implementation
        if stats.confirmed_end_frame == -1:
            stats.confirmed_end_frame = end_frame
        else:
            print("not reset vad properly\n")

        if not fake_result:
            stats.sil_frame = 0
            self.PopDataToOutputBuf(
                stats.confirmed_end_frame, 1, False, True, is_last_frame, cache=cache
            )
        stats.number_end_time_detected += 1
method

FsmnVADStreaming.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache)

funasr.models.fsmn_vad_streaming.FsmnVADStreaming · View on GitHub ↗

End the current speech segment if this is the final frame.

Args:

  • is_final_frame — Boolean flag for final frame.
  • cur_frm_idx — TODO.
  • cache — State cache dict for streaming inference.
📄 Source code
    def MaybeOnVoiceEndIfLastFrame(
        self, is_final_frame: bool, cur_frm_idx: int, cache: dict = None
    ) -> None:
        """Close the segment at ``cur_frm_idx`` when this is the final frame.

        On the final frame, ``OnVoiceEnd`` is called as a real (non-fake)
        last-frame end and the state machine moves to
        ``kVadInStateEndPointDetected``. Otherwise nothing happens.

        Args:
            is_final_frame: Boolean flag for final frame.
            cur_frm_idx: Index of the current frame.
            cache: State cache dict for streaming inference.
        """
        if cache is None:
            cache = {}
        if not is_final_frame:
            return
        self.OnVoiceEnd(cur_frm_idx, False, True, cache=cache)
        cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
method

FsmnVADStreaming.GetLatency(cache)

funasr.models.fsmn_vad_streaming.FsmnVADStreaming · View on GitHub ↗

Get the start-point latency in milliseconds.

Args:

  • cache — State cache dict for streaming inference.
📄 Source code
    def GetLatency(self, cache: dict = None) -> int:
        """Return the start-point latency in milliseconds.

        Converts the frame count from ``LatencyFrmNumAtStartPoint`` to
        milliseconds using ``vad_opts.frame_in_ms``, truncated to an int.

        Args:
            cache: State cache dict for streaming inference.

        Returns:
            Latency in milliseconds.
        """
        if cache is None:
            cache = {}
        latency_frames = self.LatencyFrmNumAtStartPoint(cache=cache)
        latency_ms = latency_frames * self.vad_opts.frame_in_ms
        return int(latency_ms)
method

FsmnVADStreaming.LatencyFrmNumAtStartPoint(cache)

funasr.models.fsmn_vad_streaming.FsmnVADStreaming · View on GitHub ↗

Return the start-point latency as a number of frames.

Args:

  • cache — State cache dict for streaming inference.
📄 Source code
    def LatencyFrmNumAtStartPoint(self, cache: dict = None) -> int:
        """Return the start-point latency as a number of frames.

        The base latency is the window detector's window size. When
        ``vad_opts.do_extend`` is truthy, the look-back time
        (``lookback_time_start_point`` divided by ``frame_in_ms``, truncated) is
        added on top.

        Args:
            cache: State cache dict for streaming inference.

        Returns:
            Latency in frames.
        """
        if cache is None:
            cache = {}
        window_frames = cache["windows_detector"].GetWinSize()
        if not self.vad_opts.do_extend:
            return window_frames
        lookback_frames = int(
            self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms
        )
        return window_frames + lookback_frames
method

FsmnVADStreaming.GetFrameState(t, cache)

funasr.models.fsmn_vad_streaming.FsmnVADStreaming · View on GitHub ↗

Get the state (speech or silence) of frame t.

Args:

  • t — TODO.
  • cache — State cache dict for streaming inference.
📄 Source code
    def GetFrameState(self, t: int, cache: dict = None):
        """Getframestate.
        
            Args:
                t: TODO.
                cache: State cache dict for streaming inference.
            """
        if cache is None:
            cache = {}
        frame_state = FrameState.kFrameStateInvalid
        if t >= len(cache["stats"].decibel):
            return FrameState.kFrameStateSil
        cur_decibel = cache["stats"].decibel[t]
        cur_snr = cur_decibel - cache["stats"].noise_average_decibel
        # for each frame, calc log posterior probability of each state
        if cur_decibel < self.vad_opts.decibel_thres:
            frame_state = FrameState.kFrameStateSil
            self.DetectOneFrame(frame_state, t, False, cache=cache)
            return frame_state

        sum_score = 0.0
        noise_prob = 0.0
        assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num
        if len(cache["stats"].sil_pdf_ids) > 0:
            assert len(cache["stats"].scores) == 1  # 只支持batch_size = 1的测试
            """
            - Change type of `sum_score` to float. The reason is that `sum_score` is a tensor with single element.
              and `torch.Tensor` is slower `float` when tensor has only one element.
            - Put the iteration of `sil_pdf_ids` inside `sum()` to reduce the overhead of creating a new list.
            - The default `sil_pdf_ids` is [0], the `if` statement is used to reduce the overhead of expression
View full source on GitHub →
method

FsmnVADStreaming.forward(feats, waveform, cache, is_final, **kwargs)

funasr.models.fsmn_vad_streaming.FsmnVADStreaming · View on GitHub ↗

Process one streaming chunk: compute per-frame decibels and encoder scores, then run frame-level VAD detection.

Args:

  • feats — Feature tensor (e.g., fbank), shape (batch, frames, dim).
  • waveform — TODO.
  • cache — State cache dict for streaming inference.
  • is_final — Whether this is the final chunk in streaming.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def forward(
        self,
        feats: torch.Tensor,
        waveform: torch.tensor,
        cache: dict = None,
        is_final: bool = False,
        **kwargs,
    ):
        """Forward pass for training.
        
            Args:
                feats: Feature tensor (e.g., fbank), shape (batch, frames, dim).
                waveform: TODO.
                cache: State cache dict for streaming inference.
                is_final: Whether this is the final chunk in streaming.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        # if len(cache) == 0:
        #     self.AllResetDetection()
        # self.waveform = waveform  # compute decibel for each frame
        cache["stats"].waveform = waveform
        is_streaming_input = kwargs.get("is_streaming_input", True)
        self.ComputeDecibel(cache=cache)
        self.ComputeScores(feats, cache=cache)
        if not is_final:
            self.DetectCommonFrames(cache=cache)
        else:
            self.DetectLastFrames(cache=cache)
View full source on GitHub →
method

FsmnVADStreaming.init_cache(cache, **kwargs)

funasr.models.fsmn_vad_streaming.FsmnVADStreaming · View on GitHub ↗

Init cache.

Args:

  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def init_cache(self, cache: dict = None, **kwargs):
        """Init cache.
        
            Args:
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}

        cache["frontend"] = {}
        cache["prev_samples"] = torch.empty(0)
        cache["encoder"] = {}

        if kwargs.get("max_end_silence_time") is not None:
            # update the max_end_silence_time
            self.vad_opts.max_end_silence_time = kwargs.get("max_end_silence_time")

        windows_detector = WindowDetector(
            self.vad_opts.window_size_ms,
            self.vad_opts.sil_to_speech_time_thres,
            self.vad_opts.speech_to_sil_time_thres,
            self.vad_opts.frame_in_ms,
        )
        windows_detector.Reset()

        stats = Stats(
            sil_pdf_ids=self.vad_opts.sil_pdf_ids,
            max_end_sil_frame_cnt_thresh=self.vad_opts.max_end_silence_time
            - self.vad_opts.speech_to_sil_time_thres,
View full source on GitHub →
method

FsmnVADStreaming.inference(data_in, data_lengths, key, tokenizer, frontend, cache, **kwargs)

funasr.models.fsmn_vad_streaming.FsmnVADStreaming · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        cache: dict = None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        if len(cache) == 0:
            self.init_cache(cache, **kwargs)

        meta_data = {}
        chunk_size = kwargs.get("chunk_size", 60000)  # 50ms
        chunk_stride_samples = int(chunk_size * frontend.fs / 1000)

View full source on GitHub →
method

FsmnVADStreaming.export(**kwargs)

funasr.models.fsmn_vad_streaming.FsmnVADStreaming · View on GitHub ↗

Export.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def export(self, **kwargs):
        """Build the export-ready model(s) for this streaming VAD model.

        Delegates to ``export_rebuild_model`` from the sibling ``export_meta``
        module, passing ``self`` as ``model`` plus every keyword argument.

        Args:
            **kwargs: Forwarded unchanged to ``export_rebuild_model``.

        Returns:
            Whatever ``export_rebuild_model`` returns.
        """
        # Local import: the export helpers are only loaded when export() runs.
        from .export_meta import export_rebuild_model as _rebuild

        return _rebuild(model=self, **kwargs)
method

FsmnVADStreaming.DetectCommonFrames(cache)

funasr.models.fsmn_vad_streaming.FsmnVADStreaming · View on GitHub ↗

Detect common frames: run frame-level VAD detection over the most recent block of frames.

Args:

  • cache — State cache dict for streaming inference.
📄 Source code
    def DetectCommonFrames(self, cache: dict = None) -> int:
        """Run frame-level VAD detection over the most recent block of frames.

        Returns immediately once the state machine has reached
        ``kVadInStateEndPointDetected``. Otherwise it visits the last
        ``vad_opts.nn_eval_block_size`` frames from oldest to newest. Each frame
        is classified with ``GetFrameState`` and the result is passed to
        ``DetectOneFrame`` with ``is_final_frame=False``.

        Args:
            cache: State cache dict for streaming inference.

        Returns:
            0 in every case.
        """
        if cache is None:
            cache = {}
        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        block_size = self.vad_opts.nn_eval_block_size
        for offset in reversed(range(block_size)):
            # GetFrameState indexes buffers from which ``last_drop_frames`` frames
            # have been trimmed, so it receives the buffer-relative index.
            # DetectOneFrame receives the absolute frame index.
            rel_idx = cache["stats"].frm_cnt - 1 - offset - cache["stats"].last_drop_frames
            state = self.GetFrameState(rel_idx, cache=cache)
            abs_idx = cache["stats"].frm_cnt - 1 - offset
            self.DetectOneFrame(state, abs_idx, False, cache=cache)
        return 0
method

FsmnVADStreaming.DetectLastFrames(cache)

funasr.models.fsmn_vad_streaming.FsmnVADStreaming · View on GitHub ↗

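Detect last frames: run frame-level VAD detection over the final block of frames, marking the newest frame as the final frame.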

Args:

  • cache — State cache dict for streaming inference.
📄 Source code
    def DetectLastFrames(self, cache: dict = None) -> int:
        """Run frame-level VAD detection over the final block of frames.

        Returns immediately once the state machine has reached
        ``kVadInStateEndPointDetected``. Otherwise it visits the last
        ``vad_opts.nn_eval_block_size`` frames from oldest to newest and
        classifies each with ``GetFrameState``. Every frame except the newest is
        passed to ``DetectOneFrame`` with ``is_final_frame=False``. The newest
        frame is passed with ``is_final_frame=True``.

        Args:
            cache: State cache dict for streaming inference.

        Returns:
            0 in every case.
        """
        if cache is None:
            cache = {}
        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        block_size = self.vad_opts.nn_eval_block_size
        for offset in reversed(range(block_size)):
            # Buffer-relative index for GetFrameState (buffers are trimmed by
            # ``last_drop_frames``). DetectOneFrame gets the absolute index.
            rel_idx = cache["stats"].frm_cnt - 1 - offset - cache["stats"].last_drop_frames
            state = self.GetFrameState(rel_idx, cache=cache)
            # offset == 0 is the newest frame, which is reported as final.
            is_newest = offset == 0
            self.DetectOneFrame(
                state, cache["stats"].frm_cnt - 1 - offset, is_newest, cache=cache
            )
        return 0
method

FsmnVADStreaming.DetectOneFrame(cur_frm_state, cur_frm_idx, is_final_frame, cache)

funasr.models.fsmn_vad_streaming.FsmnVADStreaming · View on GitHub ↗

Detect one frame: feed a single frame's state through the window detector and update the VAD statistics.

Args:

  • cur_frm_state — FrameState of the current frame.
  • cur_frm_idx — Index of the current frame.
  • is_final_frame — Boolean flag for final frame.
  • cache — State cache dict for streaming inference.
📄 Source code
    def DetectOneFrame(
        self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = None
    ) -> None:
        """Detectoneframe.
        
            Args:
                cur_frm_state: TODO.
                cur_frm_idx: TODO.
                is_final_frame: Boolean flag for final frame.
                cache: State cache dict for streaming inference.
            """
        if cache is None:
            cache = {}
        tmp_cur_frm_state = FrameState.kFrameStateInvalid
        if cur_frm_state == FrameState.kFrameStateSpeech:
            if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
                tmp_cur_frm_state = FrameState.kFrameStateSpeech
            else:
                tmp_cur_frm_state = FrameState.kFrameStateSil
        elif cur_frm_state == FrameState.kFrameStateSil:
            tmp_cur_frm_state = FrameState.kFrameStateSil
        state_change = cache["windows_detector"].DetectOneFrame(
            tmp_cur_frm_state, cur_frm_idx, cache=cache
        )
        frm_shift_in_ms = self.vad_opts.frame_in_ms
        if AudioChangeState.kChangeStateSil2Speech == state_change:
            silence_frame_count = cache["stats"].continous_silence_frame_count
            cache["stats"].continous_silence_frame_count = 0
            cache["stats"].pre_end_silence_detected = False
            start_frame = 0
View full source on GitHub →
class

FunASRNano

funasr.models.fun_asr_nano · View on GitHub ↗
Fun-ASR-Nano: End-to-End ASR Large Model.

Trained on tens of millions of hours of real speech data.

Supports 31 languages including Chinese dialects and regional accents.

Features:

  • Character-level timestamps (via CTC forced alignment)
  • Hotword customization
  • Speaker diarization (when combined with spk_model)
  • Lyrics and rap recognition
  • Streaming chunk-by-chunk inference (demo2.py)
  • Output — {"key": ..., "text": ..., "timestamps": [{"token", "start_time", "end_time"}, ...], "ctc_timestamps": [...]}
  • Note — Outputs punctuation natively — punc_model is NOT needed.
  • Requirements — pip install tiktoken huggingface_hub
📄 Source code
class FunASRNano(nn.Module):
    """Fun-ASR-Nano: End-to-End ASR Large Model.

    Trained on tens of millions of hours of real speech data.
    Supports 31 languages including Chinese dialects and regional accents.

    Features:
    - Character-level timestamps (via CTC forced alignment)
    - Hotword customization
    - Speaker diarization (when combined with spk_model)
    - Lyrics and rap recognition
    - Streaming chunk-by-chunk inference (demo2.py)

    Output: {"key": ..., "text": ..., "timestamps": [{"token", "start_time", "end_time"}, ...],
             "ctc_timestamps": [...]}

    Note: Outputs punctuation natively — punc_model is NOT needed.

    Requirements: pip install tiktoken huggingface_hub
    """

    def __init__(
        self,
        audio_encoder: str = None,
        audio_encoder_conf: dict = None,
        audio_adaptor: str = None,
        audio_adaptor_conf: dict = None,
        llm: str = None,
        llm_conf: dict = None,
        input_size: int = 80,
View full source on GitHub →

Methods

.forward(speech, speech_lengths, input_ids, attention_mask, labels_ids, fbank_beg, fbank_mask, **kwargs) L194

Forward pass for training.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • input_ids — Token IDs of shape (batch, token_num); negative IDs are overwritten with 0 in place before embedding.
  • attention_mask — TODO.
  • labels_ids — TODO.
  • fbank_beg — TODO.
  • fbank_mask — TODO.
  • **kwargs — Additional keyword arguments.
📄 Source
    def forward(
        self,
        speech: torch.Tensor = None,
        speech_lengths: torch.Tensor = None,
        input_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        labels_ids: torch.Tensor = None,
        fbank_beg: torch.Tensor = None,
        fbank_mask: torch.Tensor = None,
        **kwargs,
    ):
        """Forward pass for training.
        
            Args:
                speech: Speech audio tensor, shape (batch, time).
                speech_lengths: Length of each speech sample.
                input_ids: TODO.
                attention_mask: TODO.
                labels_ids: TODO.
                fbank_beg: TODO.
                fbank_mask: TODO.
                **kwargs: Additional keyword arguments.
            """
        batch_size, token_num = input_ids.shape
        stats = {}
        input_ids[input_ids < 0] = 0
        inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
        if speech is not None:
            if len(speech_lengths.size()) > 1:
                speech_lengths = speech_lengths[:, 0]
View full source →
.forward_export(speech, speech_lengths, **kwargs) L317

Forward export.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • **kwargs — Additional keyword arguments.
📄 Source
    def forward_export(self, speech, speech_lengths, **kwargs):
        """Export-time forward: run the audio encoder, then the audio adaptor.

        Args:
            speech: Input batch handed to ``self.audio_encoder``.
            speech_lengths: Per-sample lengths handed to ``self.audio_encoder``.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            Tuple ``(encoder_out, encoder_out_lens)`` produced by ``self.audio_adaptor``.
        """
        hidden, hidden_lens = self.audio_encoder(speech, speech_lengths)
        # The adaptor consumes the encoder output together with its lengths.
        adapted, adapted_lens = self.audio_adaptor(hidden, hidden_lens)
        return adapted, adapted_lens
.encode(speech, speech_lengths) L329

Encode.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
📄 Source
    def encode(self, speech, speech_lengths):
        """Run only the audio encoder on a batch.

        Args:
            speech: Input batch handed to ``self.audio_encoder``.
            speech_lengths: Per-sample lengths handed alongside ``speech``.

        Returns:
            Tuple ``(encoder_out, encoder_out_lens)`` from ``self.audio_encoder``.
        """
        # audio encoder
        enc_out, enc_lens = self.audio_encoder(speech, speech_lengths)
        return enc_out, enc_lens
.data_template(data) L341

Data template.

Args:

  • data — List of chat messages (dicts with "role" and "content"; user messages may also carry "audio").
📄 Source
    def data_template(self, data):
        """Split a ChatML-style message list into per-role content lists.

        Args:
            data: Iterable of message dicts with ``"role"`` and ``"content"``
                keys; a ``"user"`` message may also carry an ``"audio"`` entry.

        Returns:
            Dict with keys ``"system"``, ``"user"`` and ``"assistant"``. A user
            message that has audio is stored as a ``[content, audio]`` pair, and
            the system list is repeated once per user turn.
        """
        system_msgs, user_msgs, assistant_msgs = [], [], []
        for message in data:
            role, content = message["role"], message["content"]
            if role == "system":
                system_msgs.append(content)
            elif role == "user":
                # Audio, when present, travels alongside the text prompt.
                user_msgs.append([content, message["audio"]] if "audio" in message else content)
            elif role == "assistant":
                assistant_msgs.append(content)

        return {
            "system": system_msgs * len(user_msgs),
            "user": user_msgs,
            "assistant": assistant_msgs,
        }
.data_load_speech(contents, tokenizer, frontend, meta_data, **kwargs) L371

Data load speech.

Args:

  • contents — TODO.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • meta_data — TODO.
  • **kwargs — Additional keyword arguments.
📄 Source
    def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):
        """Data load speech.
        
            Args:
                contents: TODO.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                meta_data: TODO.
                **kwargs: Additional keyword arguments.
            """
        system = contents["system"]
        user = contents["user"]
        assistant = contents["assistant"]
        pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
        do_think = True
        sys_prompt = True
        if "dataset_conf" in kwargs:
            do_think = kwargs["dataset_conf"].get("do_think", True)
            sys_prompt = kwargs["dataset_conf"].get("sys_prompt", True)

        input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = (
            [],
            [],
            [],
            [],
            [],
            [],
            [],
        )
        input_source_ids = []
View full source →
.inference_prepare(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L533

Inference prepare.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference_prepare(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Inference prepare.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        meta_data = {}

        if len(data_in) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        contents = self.data_template(data_in[0])
        output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
        batch = to_device(output, kwargs["device"])

        # audio encoder
        speech = batch["speech"]
View full source →
.get_prompt(hotwords, language, itn) L632

Get prompt.

Args:

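  • hotwords — List of hotword strings; when non-empty they are joined with ", " into a hotword section of the prompt.
  • language — Target language name inserted into the prompt; None leaves it unspecified.
  • itn — When False, the prompt asks the model not to normalize the text.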
📄 Source
    def get_prompt(self, hotwords: list[str], language: str = None, itn: bool = True):
        """Build the transcription instruction that precedes the speech tokens.

        Args:
            hotwords: Hotwords to bias recognition toward. A single string is
                treated as one hotword; ``None`` or an empty value means none.
            language: Target language name inserted into the instruction;
                ``None`` leaves the language unspecified.
            itn: When False, the instruction asks the model to skip text
                normalization.

        Returns:
            The prompt string, always terminated with ":".
        """
        # Normalize input: a bare string would otherwise be iterated character
        # by character by ", ".join(), and None would crash on the emptiness check.
        if hotwords is None:
            hotwords = []
        elif isinstance(hotwords, str):
            hotwords = [hotwords] if hotwords else []

        prompt = ""
        if hotwords:
            hotword_text = ", ".join(hotwords)
            prompt = "请结合上下文信息,更加准确地完成语音转写任务。如果没有相关信息,我们会留空。\n\n\n**上下文信息:**\n\n\n"
            prompt += f"热词列表:[{hotword_text}]\n"

        prompt += "语音转写" if language is None else f"语音转写成{language}"
        if not itn:
            prompt += ",不进行文本规整"
        return prompt + ":"
.generate_chatml(prompt, data) L654

Generate chatml.

Args:

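  • prompt — Instruction text placed before the speech markers.
  • data — Audio input: a str (embedded between the speech markers) or a torch.Tensor (attached under the "audio" key).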
📄 Source
    def generate_chatml(self, prompt: str, data: Union[str, torch.Tensor]):
        """Wrap a prompt and one audio input into a ChatML-style message list.

        Args:
            prompt: Instruction text placed before the speech markers.
            data: Either a ``str`` (embedded in the user message between the
                speech markers) or a ``torch.Tensor`` (attached to the user
                message under the ``"audio"`` key, with a placeholder in the text).

        Returns:
            A three-message list: system, user, and a placeholder assistant turn.

        Raises:
            TypeError: If ``data`` is neither a ``str`` nor a ``torch.Tensor``.
        """
        if isinstance(data, str):
            user_turn = {"role": "user", "content": f"{prompt}<|startofspeech|>!{data}<|endofspeech|>"}
        elif isinstance(data, torch.Tensor):
            user_turn = {
                "role": "user",
                "content": f"{prompt}<|startofspeech|>!!<|endofspeech|>",
                "audio": data,
            }
        else:
            # Returning None here would only surface later as an opaque error
            # when the message list is iterated; fail fast with a clear message.
            raise TypeError(
                f"generate_chatml expects str or torch.Tensor input, got {type(data).__name__}"
            )
        return [
            {"role": "system", "content": "You are a helpful assistant."},
            user_turn,
            {"role": "assistant", "content": "null"},
        ]
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L678

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        prompt = self.get_prompt(
            kwargs.get("hotwords", []), kwargs.get("language", None), kwargs.get("itn", True)
        )
        data_in = [self.generate_chatml(prompt, data) for data in data_in]

        if key is None:
            key = []
            for _ in data_in:
                chars = string.ascii_letters + string.digits
                key.append("rand_key_" + "".join(random.choice(chars) for _ in range(13)))

View full source →
.inference_llm(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L717

Inference llm.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference_llm(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Inference llm.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
            data_in, data_lengths, key, tokenizer, frontend, **kwargs
        )

        ctc_results = []
        if self.ctc_decoder is not None:
            encoder_out = meta_data["encoder_out"]
            encoder_out_lens = meta_data["encoder_out_lens"]
            decoder_out, decoder_out_lens = self.ctc_decoder(encoder_out, encoder_out_lens)
            ctc_logits = self.ctc.log_softmax(decoder_out)

View full source →
.from_pretrained(model, **kwargs) L856

From pretrained.

Args:

  • model — Model name or path passed to AutoModel.build_model.
  • **kwargs — Additional keyword arguments.
📄 Source
    def from_pretrained(model: str = None, **kwargs):
        """Build a model through ``funasr.AutoModel.build_model``.

        ``trust_remote_code=True`` is always forwarded to the builder.

        Args:
            model: Model name or path handed to ``AutoModel.build_model``.
            **kwargs: Extra options forwarded unchanged to ``AutoModel.build_model``.

        Returns:
            The ``(model, kwargs)`` pair returned by ``AutoModel.build_model``.
        """
        from funasr import AutoModel

        built_model, built_kwargs = AutoModel.build_model(
            model=model, trust_remote_code=True, **kwargs
        )
        return built_model, built_kwargs
method

FunASRNano.forward(speech, speech_lengths, input_ids, attention_mask, labels_ids, fbank_beg, fbank_mask, **kwargs)

funasr.models.fun_asr_nano.FunASRNano · View on GitHub ↗

Forward pass for training.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • input_ids — Token IDs of shape (batch, token_num); negative IDs are overwritten with 0 in place before embedding.
  • attention_mask — TODO.
  • labels_ids — TODO.
  • fbank_beg — TODO.
  • fbank_mask — TODO.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def forward(
        self,
        speech: torch.Tensor = None,
        speech_lengths: torch.Tensor = None,
        input_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        labels_ids: torch.Tensor = None,
        fbank_beg: torch.Tensor = None,
        fbank_mask: torch.Tensor = None,
        **kwargs,
    ):
        """Forward pass for training.
        
            Args:
                speech: Speech audio tensor, shape (batch, time).
                speech_lengths: Length of each speech sample.
                input_ids: TODO.
                attention_mask: TODO.
                labels_ids: TODO.
                fbank_beg: TODO.
                fbank_mask: TODO.
                **kwargs: Additional keyword arguments.
            """
        batch_size, token_num = input_ids.shape
        stats = {}
        input_ids[input_ids < 0] = 0
        inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
        if speech is not None:
            if len(speech_lengths.size()) > 1:
                speech_lengths = speech_lengths[:, 0]
View full source on GitHub →
method

FunASRNano.forward_export(speech, speech_lengths, **kwargs)

funasr.models.fun_asr_nano.FunASRNano · View on GitHub ↗

Forward export.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def forward_export(self, speech, speech_lengths, **kwargs):
        """Export-time forward: run the audio encoder, then the audio adaptor.

        Args:
            speech: Input batch handed to ``self.audio_encoder``.
            speech_lengths: Per-sample lengths handed to ``self.audio_encoder``.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            Tuple ``(encoder_out, encoder_out_lens)`` produced by ``self.audio_adaptor``.
        """
        hidden, hidden_lens = self.audio_encoder(speech, speech_lengths)
        # The adaptor consumes the encoder output together with its lengths.
        adapted, adapted_lens = self.audio_adaptor(hidden, hidden_lens)
        return adapted, adapted_lens
method

FunASRNano.encode(speech, speech_lengths)

funasr.models.fun_asr_nano.FunASRNano · View on GitHub ↗

Encode.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
📄 Source code
    def encode(self, speech, speech_lengths):
        """Run only the audio encoder on a batch.

        Args:
            speech: Input batch handed to ``self.audio_encoder``.
            speech_lengths: Per-sample lengths handed alongside ``speech``.

        Returns:
            Tuple ``(encoder_out, encoder_out_lens)`` from ``self.audio_encoder``.
        """
        # audio encoder
        enc_out, enc_lens = self.audio_encoder(speech, speech_lengths)
        return enc_out, enc_lens
method

FunASRNano.data_template(data)

funasr.models.fun_asr_nano.FunASRNano · View on GitHub ↗

Data template.

Args:

  • data — List of chat messages (dicts with "role" and "content"; user messages may also carry "audio").
📄 Source code
    def data_template(self, data):
        """Split a ChatML-style message list into per-role content lists.

        Args:
            data: Iterable of message dicts with ``"role"`` and ``"content"``
                keys; a ``"user"`` message may also carry an ``"audio"`` entry.

        Returns:
            Dict with keys ``"system"``, ``"user"`` and ``"assistant"``. A user
            message that has audio is stored as a ``[content, audio]`` pair, and
            the system list is repeated once per user turn.
        """
        system_msgs, user_msgs, assistant_msgs = [], [], []
        for message in data:
            role, content = message["role"], message["content"]
            if role == "system":
                system_msgs.append(content)
            elif role == "user":
                # Audio, when present, travels alongside the text prompt.
                user_msgs.append([content, message["audio"]] if "audio" in message else content)
            elif role == "assistant":
                assistant_msgs.append(content)

        return {
            "system": system_msgs * len(user_msgs),
            "user": user_msgs,
            "assistant": assistant_msgs,
        }
method

FunASRNano.data_load_speech(contents, tokenizer, frontend, meta_data, **kwargs)

funasr.models.fun_asr_nano.FunASRNano · View on GitHub ↗

Data load speech.

Args:

  • contents — TODO.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • meta_data — TODO.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):
        """Data load speech.
        
            Args:
                contents: TODO.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                meta_data: TODO.
                **kwargs: Additional keyword arguments.
            """
        system = contents["system"]
        user = contents["user"]
        assistant = contents["assistant"]
        pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
        do_think = True
        sys_prompt = True
        if "dataset_conf" in kwargs:
            do_think = kwargs["dataset_conf"].get("do_think", True)
            sys_prompt = kwargs["dataset_conf"].get("sys_prompt", True)

        input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = (
            [],
            [],
            [],
            [],
            [],
            [],
            [],
        )
        input_source_ids = []
View full source on GitHub →
method

FunASRNano.inference_prepare(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.fun_asr_nano.FunASRNano · View on GitHub ↗

Inference prepare.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference_prepare(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Inference prepare.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        meta_data = {}

        if len(data_in) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        contents = self.data_template(data_in[0])
        output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
        batch = to_device(output, kwargs["device"])

        # audio encoder
        speech = batch["speech"]
View full source on GitHub →
method

FunASRNano.get_prompt(hotwords, language, itn)

funasr.models.fun_asr_nano.FunASRNano · View on GitHub ↗

Get prompt.

Args:

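  • hotwords — List of hotword strings; when non-empty they are joined with ", " into a hotword section of the prompt.
  • language — Target language name inserted into the prompt; None leaves it unspecified.
  • itn — When False, the prompt asks the model not to normalize the text.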
📄 Source code
    def get_prompt(self, hotwords: list[str], language: str = None, itn: bool = True):
        """Build the transcription instruction that precedes the speech tokens.

        Args:
            hotwords: Hotwords to bias recognition toward. A single string is
                treated as one hotword; ``None`` or an empty value means none.
            language: Target language name inserted into the instruction;
                ``None`` leaves the language unspecified.
            itn: When False, the instruction asks the model to skip text
                normalization.

        Returns:
            The prompt string, always terminated with ":".
        """
        # Normalize input: a bare string would otherwise be iterated character
        # by character by ", ".join(), and None would crash on the emptiness check.
        if hotwords is None:
            hotwords = []
        elif isinstance(hotwords, str):
            hotwords = [hotwords] if hotwords else []

        prompt = ""
        if hotwords:
            hotword_text = ", ".join(hotwords)
            prompt = "请结合上下文信息,更加准确地完成语音转写任务。如果没有相关信息,我们会留空。\n\n\n**上下文信息:**\n\n\n"
            prompt += f"热词列表:[{hotword_text}]\n"

        prompt += "语音转写" if language is None else f"语音转写成{language}"
        if not itn:
            prompt += ",不进行文本规整"
        return prompt + ":"
method

FunASRNano.generate_chatml(prompt, data)

funasr.models.fun_asr_nano.FunASRNano · View on GitHub ↗

Generate chatml.

Args:

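  • prompt — Instruction text placed before the speech markers.
  • data — Audio input: a str (embedded between the speech markers) or a torch.Tensor (attached under the "audio" key).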
📄 Source code
    def generate_chatml(self, prompt: str, data: Union[str, torch.Tensor]):
        """Wrap a prompt and one audio input into a ChatML-style message list.

        Args:
            prompt: Instruction text placed before the speech markers.
            data: Either a ``str`` (embedded in the user message between the
                speech markers) or a ``torch.Tensor`` (attached to the user
                message under the ``"audio"`` key, with a placeholder in the text).

        Returns:
            A three-message list: system, user, and a placeholder assistant turn.

        Raises:
            TypeError: If ``data`` is neither a ``str`` nor a ``torch.Tensor``.
        """
        if isinstance(data, str):
            user_turn = {"role": "user", "content": f"{prompt}<|startofspeech|>!{data}<|endofspeech|>"}
        elif isinstance(data, torch.Tensor):
            user_turn = {
                "role": "user",
                "content": f"{prompt}<|startofspeech|>!!<|endofspeech|>",
                "audio": data,
            }
        else:
            # Returning None here would only surface later as an opaque error
            # when the message list is iterated; fail fast with a clear message.
            raise TypeError(
                f"generate_chatml expects str or torch.Tensor input, got {type(data).__name__}"
            )
        return [
            {"role": "system", "content": "You are a helpful assistant."},
            user_turn,
            {"role": "assistant", "content": "null"},
        ]
method

FunASRNano.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.fun_asr_nano.FunASRNano · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        prompt = self.get_prompt(
            kwargs.get("hotwords", []), kwargs.get("language", None), kwargs.get("itn", True)
        )
        data_in = [self.generate_chatml(prompt, data) for data in data_in]

        if key is None:
            key = []
            for _ in data_in:
                chars = string.ascii_letters + string.digits
                key.append("rand_key_" + "".join(random.choice(chars) for _ in range(13)))

View full source on GitHub →
method

FunASRNano.inference_llm(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.fun_asr_nano.FunASRNano · View on GitHub ↗

Inference llm.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference_llm(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Inference llm.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
            data_in, data_lengths, key, tokenizer, frontend, **kwargs
        )

        ctc_results = []
        if self.ctc_decoder is not None:
            encoder_out = meta_data["encoder_out"]
            encoder_out_lens = meta_data["encoder_out_lens"]
            decoder_out, decoder_out_lens = self.ctc_decoder(encoder_out, encoder_out_lens)
            ctc_logits = self.ctc.log_softmax(decoder_out)

View full source on GitHub →
method

FunASRNano.from_pretrained(model, **kwargs)

funasr.models.fun_asr_nano.FunASRNano · View on GitHub ↗

From pretrained.

Args:

  • model — Model name or path passed to AutoModel.build_model.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def from_pretrained(model: str = None, **kwargs):
        """Build a model through ``funasr.AutoModel.build_model``.

        ``trust_remote_code=True`` is always forwarded to the builder.

        Args:
            model: Model name or path handed to ``AutoModel.build_model``.
            **kwargs: Extra options forwarded unchanged to ``AutoModel.build_model``.

        Returns:
            The ``(model, kwargs)`` pair returned by ``AutoModel.build_model``.
        """
        from funasr import AutoModel

        built_model, built_kwargs = AutoModel.build_model(
            model=model, trust_remote_code=True, **kwargs
        )
        return built_model, built_kwargs
class

GLMASR

funasr.models.glm_asr · View on GitHub ↗

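Inference-only wrapper around a Hugging Face Transformers GLM-ASR checkpoint (default: zai-org/GLM-ASR-Nano-2512); calling forward() raises NotImplementedError.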

📄 Source code
class GLMASR(nn.Module):

    def __init__(self, **kwargs):
        """Initialize GLMASR.
        
            Args:
                **kwargs: Additional keyword arguments.
            """
        super().__init__()
        model_path = kwargs.get("model_path", kwargs.get("model", "zai-org/GLM-ASR-Nano-2512"))
        device = kwargs.get("device", "cuda:0")
        dtype = kwargs.get("dtype", "bf16")
        hub = kwargs.get("hub", "ms")
        self._max_new_tokens = kwargs.get("max_new_tokens", 512)

        self._dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
        self._device = device
        self._torch_dtype = self._dtype_map.get(dtype, torch.bfloat16)
        self._placeholder = nn.Parameter(torch.empty(0))

        model_path = self._resolve_model_path(model_path, hub, kwargs)
        self.model_path = model_path

        from transformers import AutoModel as HFAutoModel
        from transformers import AutoProcessor

        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        self.glm_model = HFAutoModel.from_pretrained(
            model_path,
            dtype=self._torch_dtype,
View full source on GitHub →

Methods

.forward(**kwargs) L73

Forward pass for training.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def forward(self, **kwargs):
        """Reject training-time calls; GLMASR is inference-only.

        Args:
            **kwargs: Ignored.

        Raises:
            NotImplementedError: Always.
        """
        message = "GLMASR only supports inference mode"
        raise NotImplementedError(message)
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L81

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        meta_data = {}
        time1 = time.perf_counter()

        prompt = kwargs.get("prompt", "Please transcribe this audio into text")

        if isinstance(data_in, (list, tuple)):
            audio_list = list(data_in)
        elif isinstance(data_in, str):
            audio_list = [data_in]
        else:
            audio_list = [data_in]
View full source →
method

GLMASR.forward(**kwargs)

funasr.models.glm_asr.GLMASR · View on GitHub ↗

Forward pass for training.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def forward(self, **kwargs):
        """Reject training-time calls; GLMASR is inference-only.

        Args:
            **kwargs: Ignored.

        Raises:
            NotImplementedError: Always.
        """
        message = "GLMASR only supports inference mode"
        raise NotImplementedError(message)
method

GLMASR.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.glm_asr.GLMASR · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        meta_data = {}
        time1 = time.perf_counter()

        prompt = kwargs.get("prompt", "Please transcribe this audio into text")

        if isinstance(data_in, (list, tuple)):
            audio_list = list(data_in)
        elif isinstance(data_in, str):
            audio_list = [data_in]
        else:
            audio_list = [data_in]
View full source on GitHub →
class

LCBNet

funasr.models.lcbnet · View on GitHub ↗
  • LCBNet — Lightweight Convolutional Block Network for ASR.

Efficient model design using depthwise separable convolutions for low-resource deployment scenarios.

Inherits Paraformer pipeline.

📄 Source code
class LCBNet(nn.Module):
    """LCBNet: Lightweight Convolutional Block Network for ASR.

    Efficient model design using depthwise separable convolutions
    for low-resource deployment scenarios.

    Inherits Paraformer pipeline.
    """

    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        decoder: str = None,
        decoder_conf: dict = None,
        text_encoder: str = None,
        text_encoder_conf: dict = None,
        bias_predictor: str = None,
        bias_predictor_conf: dict = None,
        fusion_encoder: str = None,
        fusion_encoder_conf: dict = None,
        ctc: str = None,
        ctc_conf: dict = None,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        select_num: int = 2,
View full source on GitHub →

Methods

.forward(speech, speech_lengths, text, text_lengths, **kwargs) L208

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """

        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

View full source →
.encode(speech, speech_lengths, **kwargs) L302

Frontend + Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • ind — int
📄 Source
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)
        # Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ctc=self.ctc)
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
View full source →
.init_beam_search(**kwargs) L400

Initialize beam search.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def init_beam_search(
        self,
        **kwargs,
    ):
        """Init beam search.
        
            Args:
                **kwargs: Additional keyword arguments.
            """
        from funasr.models.transformer.search import BeamSearch
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        # 1. Build ASR model
        scorers = {}

        if self.ctc != None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            decoder=self.decoder,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L450

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        # init beamsearch
        if self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        meta_data = {}
View full source →
method

LCBNet.forward(speech, speech_lengths, text, text_lengths, **kwargs)

funasr.models.lcbnet.LCBNet · View on GitHub ↗

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """

        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

View full source on GitHub →
method

LCBNet.encode(speech, speech_lengths, **kwargs)

funasr.models.lcbnet.LCBNet · View on GitHub ↗

Frontend + Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • ind — int
📄 Source code
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)
        # Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ctc=self.ctc)
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
View full source on GitHub →
method

LCBNet.init_beam_search(**kwargs)

funasr.models.lcbnet.LCBNet · View on GitHub ↗

Initialize beam search.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def init_beam_search(
        self,
        **kwargs,
    ):
        """Init beam search.
        
            Args:
                **kwargs: Additional keyword arguments.
            """
        from funasr.models.transformer.search import BeamSearch
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        # 1. Build ASR model
        scorers = {}

        if self.ctc != None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            decoder=self.decoder,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

View full source on GitHub →
method

LCBNet.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.lcbnet.LCBNet · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        # init beamsearch
        if self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        meta_data = {}
View full source on GitHub →
class

LLMASR

funasr.models.llm_asr · View on GitHub ↗
LLM-ASR: Large Language Model based Speech Recognition.

Combines audio encoder with LLM decoder for speech-to-text.

Output: {"key": str, "text": str}
📄 Source code
class LLMASR(nn.Module):
    """LLM-ASR: Large Language Model based Speech Recognition.

    Combines audio encoder with LLM decoder for speech-to-text.
    Output: {"key": str, "text": str}
    """

    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        audio_encoder: str = None,
        audio_encoder_conf: dict = None,
        audio_adaptor: str = None,
        audio_adaptor_conf: dict = None,
        decoder: str = None,
        decoder_conf: dict = None,
        ctc: str = None,
        ctc_conf: dict = None,
        ctc_weight: float = 0.5,
        llm: str = None,
        llm_conf: dict = None,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
View full source on GitHub →

Methods

.forward(speech, speech_lengths, text, text_lengths, input_ids, attention_mask, labels_ids, label_mask, audio_mask, **kwargs) L188

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels_ids: torch.Tensor,
        label_mask: torch.Tensor,
        audio_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # audio encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

View full source →
.encode(speech, speech_lengths, **kwargs) L258

Encode.

Args:

  • speech — Speech feature tensor (3-D); its last two axes are swapped (permute(0, 2, 1)) before it is passed to the audio encoder.
  • speech_lengths — Length of each speech sample.
  • **kwargs — Additional keyword arguments.
📄 Source
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ):
        """Run the audio encoder on a batch of speech features.

        Args:
            speech: 3-D feature tensor. Its last two axes are swapped
                (``permute(0, 2, 1)``) before it is fed to ``self.audio_encoder``.
                Presumably (batch, frames, feature_dim), matching ``forward``'s
                ``(Batch, Length, ...)`` -- TODO confirm.
            speech_lengths: Length of each speech sample. Only used as the
                returned lengths when the encoder does not report its own.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Tuple ``(encoder_out, encoder_out_lens)``. If the encoder returns a
            list/tuple, its first two items are used; otherwise its output is
            paired with the unchanged ``speech_lengths``.
        """
        # The audio encoder is called with the last two axes swapped.
        speech = speech.permute(0, 2, 1)
        res = self.audio_encoder(speech)
        if isinstance(res, (list, tuple)):
            encoder_out, encoder_out_lens = res[0], res[1]
        else:
            # NOTE(review): if the encoder subsamples in time, these input
            # lengths will not match encoder_out -- confirm with the encoder.
            encoder_out, encoder_out_lens = res, speech_lengths
        return encoder_out, encoder_out_lens
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L279

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        prompt = kwargs.get("prompt", "Transcribe speech to text.")

        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
View full source →
method

LLMASR.forward(speech, speech_lengths, text, text_lengths, input_ids, attention_mask, labels_ids, label_mask, audio_mask, **kwargs)

funasr.models.llm_asr.LLMASR · View on GitHub ↗

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels_ids: torch.Tensor,
        label_mask: torch.Tensor,
        audio_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # audio encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

View full source on GitHub →
method

LLMASR.encode(speech, speech_lengths, **kwargs)

funasr.models.llm_asr.LLMASR · View on GitHub ↗

Encode.

Args:

  • speech — Speech feature tensor (3-D); its last two axes are swapped (permute(0, 2, 1)) before it is passed to the audio encoder.
  • speech_lengths — Length of each speech sample.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ):
        """Run the audio encoder on a batch of speech features.

        Args:
            speech: 3-D feature tensor. Its last two axes are swapped
                (``permute(0, 2, 1)``) before it is fed to ``self.audio_encoder``.
                Presumably (batch, frames, feature_dim), matching ``forward``'s
                ``(Batch, Length, ...)`` -- TODO confirm.
            speech_lengths: Length of each speech sample. Only used as the
                returned lengths when the encoder does not report its own.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Tuple ``(encoder_out, encoder_out_lens)``. If the encoder returns a
            list/tuple, its first two items are used; otherwise its output is
            paired with the unchanged ``speech_lengths``.
        """
        # The audio encoder is called with the last two axes swapped.
        speech = speech.permute(0, 2, 1)
        res = self.audio_encoder(speech)
        if isinstance(res, (list, tuple)):
            encoder_out, encoder_out_lens = res[0], res[1]
        else:
            # NOTE(review): if the encoder subsamples in time, these input
            # lengths will not match encoder_out -- confirm with the encoder.
            encoder_out, encoder_out_lens = res, speech_lengths
        return encoder_out, encoder_out_lens
method

LLMASR.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.llm_asr.LLMASR · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        prompt = kwargs.get("prompt", "Transcribe speech to text.")

        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
View full source on GitHub →
class

LLMASR2

funasr.models.llm_asr · View on GitHub ↗

No documentation yet.

📄 Source code
class LLMASR2(nn.Module):
    """ """

    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        audio_encoder: str = None,
        audio_encoder_conf: dict = None,
        audio_adaptor: str = None,
        audio_adaptor_conf: dict = None,
        decoder: str = None,
        decoder_conf: dict = None,
        ctc: str = None,
        ctc_conf: dict = None,
        ctc_weight: float = 0.5,
        llm: str = None,
        llm_conf: dict = None,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
View full source on GitHub →

Methods

.forward(speech, speech_lengths, input_ids, attention_mask, labels_ids, fbank_beg, fbank_mask, **kwargs) L561

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
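  • text — (Batch, Length) (listed in the docstring, but not a parameter of this method)
  • text_lengths — (Batch,) (listed in the docstring, but not a parameter of this method)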
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels_ids: torch.Tensor,
        fbank_beg: torch.Tensor,
        fbank_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        # import pdb;
        # pdb.set_trace()
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size, frames, _ = speech.shape

        with torch.cuda.amp.autocast(enabled=False):
            # audio encoder
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

            # audio_adaptor
View full source →
.encode(speech, speech_lengths) L654

Encode.

Args:

  • speech — Speech feature tensor (3-D); its last two axes are swapped (permute(0, 2, 1)) before it is passed to the audio encoder.
  • speech_lengths — Length of each speech sample.
📄 Source
    def encode(self, speech, speech_lengths):
        """Run the audio encoder on a batch of speech features.

        Args:
            speech: 3-D feature tensor; ``forward`` unpacks it as
                ``(batch_size, frames, _)``. Its last two axes are swapped
                (``permute(0, 2, 1)``) before it is passed to ``self.audio_encoder``.
            speech_lengths: Length of each speech sample, forwarded to the encoder.

        Returns:
            The ``(encoder_out, encoder_out_lens)`` pair produced by the encoder.
        """
        # audio encoder: receives the features with the last two axes swapped
        encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)

        return encoder_out, encoder_out_lens
.data_template(data) L666

Group chat-style messages by role.

Args:

  • data — Iterable of chat messages, each a dict with "role" and "content" keys.
📄 Source
    def data_template(self, data):
        """Group chat-style messages by role.

        Args:
            data: Iterable of message dicts, each with "role" and "content"
                keys. Messages whose role is not "system", "user" or
                "assistant" are skipped.

        Returns:
            dict with "system", "user" and "assistant" lists of message
            contents in input order. The "system" list is repeated
            ``len(user)`` times, so a single system prompt lines up with
            every user turn when the three lists are zipped.
        """
        grouped = {"system": [], "user": [], "assistant": []}
        for message in data:
            role, content = message["role"], message["content"]
            if role in ("system", "user", "assistant"):
                grouped[role].append(content)

        grouped["system"] = grouped["system"] * len(grouped["user"])
        return grouped
.data_load_speech(contents, tokenizer, frontend, meta_data, **kwargs) L693

Data load speech.

Args:

  • contents — Dict with "system", "user" and "assistant" lists, as returned by data_template.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • meta_data — Metadata dictionary; inference passes its own meta_data dict here.
  • **kwargs — Additional keyword arguments.
📄 Source
    def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):

        """Data load speech.
        
            Args:
                contents: TODO.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                meta_data: TODO.
                **kwargs: Additional keyword arguments.
            """
        system = contents["system"]
        user = contents["user"]
        assistant = contents["assistant"]
        pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
        input_ids, labels, source_ids, target_ids, fbank, fbank_lens, fbank_mask, fbank_beg = (
            [],
            [],
            [],
            [],
            [],
            [],
            [],
            [],
        )

        for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):

            source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"

View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L820

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        meta_data = {}
        prompt = kwargs.get("prompt", None)

        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        contents = self.data_template(data_in[0])
        output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
        batch = to_device(output, kwargs["device"])

View full source →
method

LLMASR2.forward(speech, speech_lengths, input_ids, attention_mask, labels_ids, fbank_beg, fbank_mask, **kwargs)

funasr.models.llm_asr.LLMASR2 · View on GitHub ↗

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
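  • text — (Batch, Length) (listed in the docstring, but not a parameter of this method)
  • text_lengths — (Batch,) (listed in the docstring, but not a parameter of this method)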
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels_ids: torch.Tensor,
        fbank_beg: torch.Tensor,
        fbank_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        # import pdb;
        # pdb.set_trace()
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size, frames, _ = speech.shape

        with torch.cuda.amp.autocast(enabled=False):
            # audio encoder
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

            # audio_adaptor
View full source on GitHub →
method

LLMASR2.encode(speech, speech_lengths)

funasr.models.llm_asr.LLMASR2 · View on GitHub ↗

Encode.

Args:

  • speech — Speech feature tensor (3-D); its last two axes are swapped (permute(0, 2, 1)) before it is passed to the audio encoder.
  • speech_lengths — Length of each speech sample.
📄 Source code
    def encode(self, speech, speech_lengths):
        """Run the audio encoder on a batch of speech features.

        Args:
            speech: 3-D feature tensor; ``forward`` unpacks it as
                ``(batch_size, frames, _)``. Its last two axes are swapped
                (``permute(0, 2, 1)``) before it is passed to ``self.audio_encoder``.
            speech_lengths: Length of each speech sample, forwarded to the encoder.

        Returns:
            The ``(encoder_out, encoder_out_lens)`` pair produced by the encoder.
        """
        # audio encoder: receives the features with the last two axes swapped
        encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)

        return encoder_out, encoder_out_lens
method

LLMASR2.data_template(data)

funasr.models.llm_asr.LLMASR2 · View on GitHub ↗

Group chat-style messages by role.

Args:

  • data — Iterable of chat messages, each a dict with "role" and "content" keys.
📄 Source code
    def data_template(self, data):
        """Group chat-style messages by role.

        Args:
            data: Iterable of message dicts, each with "role" and "content"
                keys. Messages whose role is not "system", "user" or
                "assistant" are skipped.

        Returns:
            dict with "system", "user" and "assistant" lists of message
            contents in input order. The "system" list is repeated
            ``len(user)`` times, so a single system prompt lines up with
            every user turn when the three lists are zipped.
        """
        grouped = {"system": [], "user": [], "assistant": []}
        for message in data:
            role, content = message["role"], message["content"]
            if role in ("system", "user", "assistant"):
                grouped[role].append(content)

        grouped["system"] = grouped["system"] * len(grouped["user"])
        return grouped
method

LLMASR2.data_load_speech(contents, tokenizer, frontend, meta_data, **kwargs)

funasr.models.llm_asr.LLMASR2 · View on GitHub ↗

Data load speech.

Args:

  • contents — Dict with "system", "user" and "assistant" lists, as returned by data_template.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • meta_data — Metadata dictionary; inference passes its own meta_data dict here.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):

        """Data load speech.
        
            Args:
                contents: TODO.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                meta_data: TODO.
                **kwargs: Additional keyword arguments.
            """
        system = contents["system"]
        user = contents["user"]
        assistant = contents["assistant"]
        pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
        input_ids, labels, source_ids, target_ids, fbank, fbank_lens, fbank_mask, fbank_beg = (
            [],
            [],
            [],
            [],
            [],
            [],
            [],
            [],
        )

        for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):

            source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"

View full source on GitHub →
method

LLMASR2.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.llm_asr.LLMASR2 · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        meta_data = {}
        prompt = kwargs.get("prompt", None)

        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        contents = self.data_template(data_in[0])
        output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
        batch = to_device(output, kwargs["device"])

View full source on GitHub →
class

LLMASR3

funasr.models.llm_asr · View on GitHub ↗

No documentation yet.

📄 Source code
class LLMASR3(LLMASR2):
    """LLMASR2 variant that overrides ``encode``.

    The override visible here feeds ``speech`` straight into
    ``self.audio_encoder`` without any reshaping; everything else is
    inherited from ``LLMASR2``.
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        """Initialize LLMASR3.

        All arguments are forwarded unchanged to ``LLMASR2.__init__``.

        Args:
            *args: Positional arguments forwarded to the parent constructor.
            **kwargs: Keyword arguments forwarded to the parent constructor.
        """
        super().__init__(*args, **kwargs)

    def encode(self, speech, speech_lengths):
        """Run the audio encoder on a batch of speech input.

        ``speech`` is handed to ``self.audio_encoder`` as-is. NOTE(review): its
        expected layout is whatever the configured audio encoder accepts; it is
        not established here — confirm against the encoder config.

        Args:
            speech: Batched speech input tensor.
            speech_lengths: Valid length of each sample in ``speech``.

        Returns:
            tuple: ``(encoder_out, encoder_out_lens)`` as returned by
            ``self.audio_encoder``.
        """
        # audio encoder
        encoder_out, encoder_out_lens = self.audio_encoder(speech, speech_lengths)
        return encoder_out, encoder_out_lens

Methods

.encode(speech, speech_lengths) L964

Encode.

Args:

  • speech — Speech input tensor, passed unchanged to the audio encoder (its expected layout depends on the configured encoder).
  • speech_lengths — Length of each speech sample.
📄 Source
    def encode(self, speech, speech_lengths):
        """Run the audio encoder on a batch of speech input.

        ``speech`` is handed to ``self.audio_encoder`` as-is. NOTE(review): its
        expected layout is whatever the configured audio encoder accepts; it is
        not established here — confirm against the encoder config.

        Args:
            speech: Batched speech input tensor.
            speech_lengths: Valid length of each sample in ``speech``.

        Returns:
            tuple: ``(encoder_out, encoder_out_lens)`` as returned by
            ``self.audio_encoder``.
        """
        # audio encoder
        encoder_out, encoder_out_lens = self.audio_encoder(speech, speech_lengths)
        return encoder_out, encoder_out_lens
method

LLMASR3.encode(speech, speech_lengths)

funasr.models.llm_asr.LLMASR3 · View on GitHub ↗

Encode.

Args:

  • speech — Speech input tensor, passed unchanged to the audio encoder (its expected layout depends on the configured encoder).
  • speech_lengths — Length of each speech sample.
📄 Source code
    def encode(self, speech, speech_lengths):
        """Run the audio encoder on a batch of speech input.

        ``speech`` is handed to ``self.audio_encoder`` as-is. NOTE(review): its
        expected layout is whatever the configured audio encoder accepts; it is
        not established here — confirm against the encoder config.

        Args:
            speech: Batched speech input tensor.
            speech_lengths: Valid length of each sample in ``speech``.

        Returns:
            tuple: ``(encoder_out, encoder_out_lens)`` as returned by
            ``self.audio_encoder``.
        """
        # audio encoder
        encoder_out, encoder_out_lens = self.audio_encoder(speech, speech_lengths)
        return encoder_out, encoder_out_lens
class

LLMASR4

funasr.models.llm_asr · View on GitHub ↗

No documentation yet.

📄 Source code
class LLMASR4(nn.Module):
    """ """

    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        audio_encoder: str = None,
        audio_encoder_conf: dict = None,
        audio_adaptor: str = None,
        audio_adaptor_conf: dict = None,
        decoder: str = None,
        decoder_conf: dict = None,
        ctc: str = None,
        ctc_conf: dict = None,
        ctc_weight: float = 0.5,
        llm: str = None,
        llm_conf: dict = None,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
View full source on GitHub →

Methods

.forward(speech, speech_lengths, input_ids, attention_mask, labels_ids, fbank_beg, fbank_mask, **kwargs) L1135

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
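  • input_ids — (Batch, token_num)
  • attention_mask, labels_ids, fbank_beg, fbank_mask — not yet documented (the source docstring lists text / text_lengths, which are not parameters of this method).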
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels_ids: torch.Tensor,
        fbank_beg: torch.Tensor,
        fbank_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        # import pdb
        #
        # pdb.set_trace()
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size_speech, frames, _ = speech.shape
        batch_size, token_num = input_ids.shape

        with torch.cuda.amp.autocast(enabled=False):
            # audio encoder
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
View full source →
.encode(speech, speech_lengths) L1246

Encode.

Args:

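  • speech — Speech feature tensor of shape (batch, frames, feature_dim); it is transposed to (batch, feature_dim, frames) before being passed to the audio encoder.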
  • speech_lengths — Length of each speech sample.
📄 Source
    def encode(self, speech, speech_lengths):
        """Run the audio encoder on a batch of speech features.

        Args:
            speech: Feature tensor of shape ``(batch, frames, feature_dim)``
                (``forward`` unpacks ``speech.shape`` into three dims). It is
                transposed to ``(batch, feature_dim, frames)`` before being
                passed to ``self.audio_encoder``.
            speech_lengths: Length of each sample, presumably counted along
                the frame axis.

        Returns:
            tuple: ``(encoder_out, encoder_out_lens)`` as returned by
            ``self.audio_encoder``.
        """
        # audio encoder; swap the frame and feature axes — the encoder
        # presumably expects feature-major input (confirm against its config).
        encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)

        return encoder_out, encoder_out_lens
.data_template(data) L1258

Split a chat-style message list into per-role content lists.

Args:

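  • data — List of message dicts, each with "role" and "content" keys; roles other than system, user, and assistant are ignored.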
📄 Source
    def data_template(self, data):
        """Split a chat-style message list into per-role content lists.

        Args:
            data: Iterable of message dicts, each with ``"role"`` and
                ``"content"`` keys. Messages whose role is not ``"system"``,
                ``"user"`` or ``"assistant"`` are ignored.

        Returns:
            dict: ``{"system": [...], "user": [...], "assistant": [...]}``,
            each list holding the ``"content"`` values for that role in input
            order. The system list is the collected system contents repeated
            ``len(user)`` times.

        Raises:
            KeyError: If a message lacks a ``"role"`` or ``"content"`` key.
        """
        system, user, assistant = [], [], []
        for item in data:
            role = item["role"]
            content = item["content"]
            if role == "system":
                system.append(content)
            elif role == "user":
                user.append(content)
            elif role == "assistant":
                assistant.append(content)

        # Repeat the system prompt(s) once per user turn so the three lists can
        # be zipped turn-by-turn (as data_load_speech does). NOTE: with no
        # system message this is an empty list, so such a zip yields no turns.
        system = system * len(user)

        contents = {
            "system": system,
            "user": user,
            "assistant": assistant,
        }

        return contents
.data_load_speech(contents, tokenizer, frontend, meta_data, **kwargs) L1285

Data load speech.

Args:

  • contents — Dict with "system", "user", and "assistant" lists, as returned by data_template.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • meta_data — Metadata dict; inference_prepare passes a fresh empty dict.
  • **kwargs — Additional keyword arguments.
📄 Source
    def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):

        """Data load speech.
        
            Args:
                contents: TODO.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                meta_data: TODO.
                **kwargs: Additional keyword arguments.
            """
        system = contents["system"]
        user = contents["user"]
        assistant = contents["assistant"]
        pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")

        input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = (
            [],
            [],
            [],
            [],
            [],
            [],
            [],
        )
        input_source_ids = []
        for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
            if i >= kwargs.get("multiturn_num_max", 5):
                break
            if len(input_ids) > kwargs.get("max_token_length", 1500):
View full source →
.inference_prepare(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L1433

Inference prepare.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference_prepare(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Inference prepare.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        meta_data = {}
        prompt = kwargs.get("prompt", None)

        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        contents = self.data_template(data_in[0])
        output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
        batch = to_device(output, kwargs["device"])

View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L1525

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
            data_in, data_lengths, key, tokenizer, frontend, **kwargs
        )

        llm_dtype = kwargs.get("llm_dtype", "fp32")
        if llm_dtype == "fp32":
            llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
            llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype

        with torch.cuda.amp.autocast(
View full source →
method

LLMASR4.forward(speech, speech_lengths, input_ids, attention_mask, labels_ids, fbank_beg, fbank_mask, **kwargs)

funasr.models.llm_asr.LLMASR4 · View on GitHub ↗

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
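  • input_ids — (Batch, token_num)
  • attention_mask, labels_ids, fbank_beg, fbank_mask — not yet documented (the source docstring lists text / text_lengths, which are not parameters of this method).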
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels_ids: torch.Tensor,
        fbank_beg: torch.Tensor,
        fbank_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        # import pdb
        #
        # pdb.set_trace()
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size_speech, frames, _ = speech.shape
        batch_size, token_num = input_ids.shape

        with torch.cuda.amp.autocast(enabled=False):
            # audio encoder
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
View full source on GitHub →
method

LLMASR4.encode(speech, speech_lengths)

funasr.models.llm_asr.LLMASR4 · View on GitHub ↗

Encode.

Args:

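  • speech — Speech feature tensor of shape (batch, frames, feature_dim); it is transposed to (batch, feature_dim, frames) before being passed to the audio encoder.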
  • speech_lengths — Length of each speech sample.
📄 Source code
    def encode(self, speech, speech_lengths):
        """Run the audio encoder on a batch of speech features.

        Args:
            speech: Feature tensor of shape ``(batch, frames, feature_dim)``
                (``forward`` unpacks ``speech.shape`` into three dims). It is
                transposed to ``(batch, feature_dim, frames)`` before being
                passed to ``self.audio_encoder``.
            speech_lengths: Length of each sample, presumably counted along
                the frame axis.

        Returns:
            tuple: ``(encoder_out, encoder_out_lens)`` as returned by
            ``self.audio_encoder``.
        """
        # audio encoder; swap the frame and feature axes — the encoder
        # presumably expects feature-major input (confirm against its config).
        encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)

        return encoder_out, encoder_out_lens
method

LLMASR4.data_template(data)

funasr.models.llm_asr.LLMASR4 · View on GitHub ↗

Split a chat-style message list into per-role content lists.

Args:

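  • data — List of message dicts, each with "role" and "content" keys; roles other than system, user, and assistant are ignored.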
📄 Source code
    def data_template(self, data):
        """Split a chat-style message list into per-role content lists.

        Args:
            data: Iterable of message dicts, each with ``"role"`` and
                ``"content"`` keys. Messages whose role is not ``"system"``,
                ``"user"`` or ``"assistant"`` are ignored.

        Returns:
            dict: ``{"system": [...], "user": [...], "assistant": [...]}``,
            each list holding the ``"content"`` values for that role in input
            order. The system list is the collected system contents repeated
            ``len(user)`` times.

        Raises:
            KeyError: If a message lacks a ``"role"`` or ``"content"`` key.
        """
        system, user, assistant = [], [], []
        for item in data:
            role = item["role"]
            content = item["content"]
            if role == "system":
                system.append(content)
            elif role == "user":
                user.append(content)
            elif role == "assistant":
                assistant.append(content)

        # Repeat the system prompt(s) once per user turn so the three lists can
        # be zipped turn-by-turn (as data_load_speech does). NOTE: with no
        # system message this is an empty list, so such a zip yields no turns.
        system = system * len(user)

        contents = {
            "system": system,
            "user": user,
            "assistant": assistant,
        }

        return contents
method

LLMASR4.data_load_speech(contents, tokenizer, frontend, meta_data, **kwargs)

funasr.models.llm_asr.LLMASR4 · View on GitHub ↗

Data load speech.

Args:

  • contents — Dict with "system", "user", and "assistant" lists, as returned by data_template.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • meta_data — Metadata dict; inference_prepare passes a fresh empty dict.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):

        """Data load speech.
        
            Args:
                contents: TODO.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                meta_data: TODO.
                **kwargs: Additional keyword arguments.
            """
        system = contents["system"]
        user = contents["user"]
        assistant = contents["assistant"]
        pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")

        input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = (
            [],
            [],
            [],
            [],
            [],
            [],
            [],
        )
        input_source_ids = []
        for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
            if i >= kwargs.get("multiturn_num_max", 5):
                break
            if len(input_ids) > kwargs.get("max_token_length", 1500):
View full source on GitHub →
method

LLMASR4.inference_prepare(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.llm_asr.LLMASR4 · View on GitHub ↗

Inference prepare.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference_prepare(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Inference prepare.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        meta_data = {}
        prompt = kwargs.get("prompt", None)

        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        contents = self.data_template(data_in[0])
        output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
        batch = to_device(output, kwargs["device"])

View full source on GitHub →
method

LLMASR4.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.llm_asr.LLMASR4 · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
            data_in, data_lengths, key, tokenizer, frontend, **kwargs
        )

        llm_dtype = kwargs.get("llm_dtype", "fp32")
        if llm_dtype == "fp32":
            llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
            llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype

        with torch.cuda.amp.autocast(
View full source on GitHub →
class

LLMASRNAR

funasr.models.llm_asr_nar · View on GitHub ↗

No documentation yet.

📄 Source code
class LLMASRNAR(nn.Module):
    """ """

    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        decoder: str = None,
        decoder_conf: dict = None,
        ctc: str = None,
        ctc_conf: dict = None,
        ctc_weight: float = 0.5,
        llm: str = None,
        llm_conf: dict = None,
        adaptor: str = None,
        adaptor_conf: dict = None,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
View full source on GitHub →

Methods

.forward(speech, speech_lengths, text, text_lengths, input_ids, attention_mask, labels_ids, label_mask, audio_mask, **kwargs) L182

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels_ids: torch.Tensor,
        label_mask: torch.Tensor,
        audio_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # audio encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, audio_mask=audio_mask)

View full source →
.encode(speech, speech_lengths, **kwargs) L253

Encode.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • **kwargs — Additional keyword arguments.
📄 Source
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ):

        """Encode.
        
            Args:
                speech: Speech audio tensor, shape (batch, time).
                speech_lengths: Length of each speech sample.
                **kwargs: Additional keyword arguments.
            """
        audio_mask = kwargs.get("audio_mask", None)
        audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
        text_token_int = kwargs.get("text_token_int", None)
        if audio_token_lengths is None:
            audio_token_lengths = torch.tensor([len(text_token_int)], dtype=torch.int64)

        batch = {"speech": speech, "speech_lengths": speech_lengths}
        enc, enc_lens = self.audio_encoder.encode(**batch)
        with autocast(False):
            enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
            pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(
                enc,
                mask=enc_mask,
                target_label_length=audio_token_lengths,
            )

View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L285

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        prompt = kwargs.get("prompt", "Transcribe speech to text.")

        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
View full source →
method

LLMASRNAR.forward(speech, speech_lengths, text, text_lengths, input_ids, attention_mask, labels_ids, label_mask, audio_mask, **kwargs)

funasr.models.llm_asr_nar.LLMASRNAR · View on GitHub ↗

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels_ids: torch.Tensor,
        label_mask: torch.Tensor,
        audio_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # audio encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, audio_mask=audio_mask)

View full source on GitHub →
method

LLMASRNAR.encode(speech, speech_lengths, **kwargs)

funasr.models.llm_asr_nar.LLMASRNAR · View on GitHub ↗

Encode.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ):

        """Encode.
        
            Args:
                speech: Speech audio tensor, shape (batch, time).
                speech_lengths: Length of each speech sample.
                **kwargs: Additional keyword arguments.
            """
        audio_mask = kwargs.get("audio_mask", None)
        audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
        text_token_int = kwargs.get("text_token_int", None)
        if audio_token_lengths is None:
            audio_token_lengths = torch.tensor([len(text_token_int)], dtype=torch.int64)

        batch = {"speech": speech, "speech_lengths": speech_lengths}
        enc, enc_lens = self.audio_encoder.encode(**batch)
        with autocast(False):
            enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
            pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(
                enc,
                mask=enc_mask,
                target_label_length=audio_token_lengths,
            )

View full source on GitHub →
method

LLMASRNAR.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.llm_asr_nar.LLMASRNAR · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        prompt = kwargs.get("prompt", "Transcribe speech to text.")

        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
View full source on GitHub →
class

LLMASRNARPrompt

funasr.models.llm_asr_nar · View on GitHub ↗

No documentation yet.

📄 Source code
class LLMASRNARPrompt(nn.Module):
    """ """

    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        decoder: str = None,
        decoder_conf: dict = None,
        ctc: str = None,
        ctc_conf: dict = None,
        ctc_weight: float = 0.0,
        llm: str = None,
        llm_conf: dict = None,
        adaptor: str = None,
        adaptor_conf: dict = None,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        predictor_weight: int = 1.0,
        report_cer: bool = True,
View full source on GitHub →

Methods

.forward(speech, speech_lengths, text, text_lengths, input_ids, attention_mask, labels_ids, label_mask, audio_mask, **kwargs) L588

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels_ids: torch.Tensor,
        label_mask: torch.Tensor,
        audio_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        # import pdb;
        # pdb.set_trace()
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        stats = {}
View full source →
.encode(speech, speech_lengths, **kwargs) L691

Encode.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • **kwargs — Additional keyword arguments.
📄 Source
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ):

        """Encode.
        
            Args:
                speech: Speech audio tensor, shape (batch, time).
                speech_lengths: Length of each speech sample.
                **kwargs: Additional keyword arguments.
            """
        audio_mask = kwargs.get("audio_mask", None)
        audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
        text_token_int = kwargs.get("text_token_int", None)
        if audio_token_lengths is None and text_token_int is not None:
            audio_token_lengths = torch.tensor([len(text_token_int)], dtype=torch.int64)

        batch = {"speech": speech, "speech_lengths": speech_lengths}
        enc, enc_lens = self.audio_encoder.encode(**batch)
        with autocast(False):
            enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
            pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(
                enc,
                mask=enc_mask,
                target_label_length=audio_token_lengths,
            )
            loss_pre = 0.0
View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L753

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        prompt = kwargs.get("prompt", "Transcribe speech to text.")

        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
View full source →
method

LLMASRNARPrompt.forward(speech, speech_lengths, text, text_lengths, input_ids, attention_mask, labels_ids, label_mask, audio_mask, **kwargs)

funasr.models.llm_asr_nar.LLMASRNARPrompt · View on GitHub ↗

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels_ids: torch.Tensor,
        label_mask: torch.Tensor,
        audio_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        # import pdb;
        # pdb.set_trace()
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        stats = {}
View full source on GitHub →
method

LLMASRNARPrompt.encode(speech, speech_lengths, **kwargs)

funasr.models.llm_asr_nar.LLMASRNARPrompt · View on GitHub ↗

Encode.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ):

        """Encode.
        
            Args:
                speech: Speech audio tensor, shape (batch, time).
                speech_lengths: Length of each speech sample.
                **kwargs: Additional keyword arguments.
            """
        audio_mask = kwargs.get("audio_mask", None)
        audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
        text_token_int = kwargs.get("text_token_int", None)
        if audio_token_lengths is None and text_token_int is not None:
            audio_token_lengths = torch.tensor([len(text_token_int)], dtype=torch.int64)

        batch = {"speech": speech, "speech_lengths": speech_lengths}
        enc, enc_lens = self.audio_encoder.encode(**batch)
        with autocast(False):
            enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
            pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(
                enc,
                mask=enc_mask,
                target_label_length=audio_token_lengths,
            )
            loss_pre = 0.0
View full source on GitHub →
method

LLMASRNARPrompt.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.llm_asr_nar.LLMASRNARPrompt · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        prompt = kwargs.get("prompt", "Transcribe speech to text.")

        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
View full source on GitHub →
class

MonotonicAligner

funasr.models.monotonic_aligner · View on GitHub ↗
MonotonicAligner: Forced alignment model for timestamp prediction.

Given audio and text, computes character-level time alignments using a monotonic attention constraint.

  • Usage: model.generate(input=(audio_path, text_path), data_type=("sound", "text"))
  • Output: {"key": str, "timestamp": [[start_ms, end_ms], ...], "text": str}
📄 Source code
class MonotonicAligner(torch.nn.Module):
    """MonotonicAligner: Forced alignment model for timestamp prediction.

    Given audio and text, computes character-level time alignments
    using monotonic attention constraint.

    Usage: model.generate(input=(audio_path, text_path), data_type=("sound", "text"))
    Output: {"key": str, "timestamp": [[start_ms, end_ms], ...], "text": str}
    """

    def __init__(
        self,
        input_size: int = 80,
        specaug: Optional[str] = None,
        specaug_conf: Optional[Dict] = None,
        normalize: str = None,
        normalize_conf: Optional[Dict] = None,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        predictor: str = None,
        predictor_conf: Optional[Dict] = None,
        predictor_bias: int = 0,
        length_normalized_loss: bool = False,
        **kwargs,
    ):
        """Initialize MonotonicAligner.
        
            Args:
                input_size: Size/dimension parameter.
                specaug: TODO.
View full source on GitHub →

Methods

.forward(speech, speech_lengths, text, text_lengths) L86

Frontend + Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        # for data-parallel
        text = text[:, : text_lengths.max()]
        speech = speech[:, : speech_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
View full source →
.calc_predictor_timestamp(encoder_out, encoder_out_lens, token_num) L137

Calc predictor timestamp.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
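  • token_num — Forwarded unchanged to predictor.get_upsample_timestamp (presumably the number of output tokens per utterance).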
📄 Source
    def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
        """Run the predictor's timestamp routine on encoder output.

        Builds a frame mask from ``encoder_out_lens`` and passes it, together
        with ``encoder_out`` and ``token_num``, to
        ``self.predictor.get_upsample_timestamp``.

        Args:
            encoder_out: Encoder output tensor. Its dimension 1 is used as the
                mask length (``maxlen=encoder_out.size(1)``), so it is presumably
                laid out as (batch, time, dim) -- TODO confirm.
            encoder_out_lens: Per-sequence valid lengths of ``encoder_out``.
            token_num: Forwarded unchanged to ``get_upsample_timestamp``;
                presumably the number of output tokens per utterance -- TODO confirm.

        Returns:
            The 4-tuple ``(ds_alphas, ds_cif_peak, us_alphas, us_peaks)``
            returned by ``self.predictor.get_upsample_timestamp``.
        """
        # NOTE(review): make_pad_mask presumably flags padded frames, so the
        # negation marks valid frames; [:, None, :] inserts a singleton axis 1.
        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(
            encoder_out, encoder_out_mask, token_num
        )
        return ds_alphas, ds_cif_peak, us_alphas, us_peaks
.encode(speech, speech_lengths, **kwargs) L153

Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • ind — int (not an explicit parameter; it could only arrive via **kwargs, which this method does not read)
📄 Source
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encoder. Note that this method is used by asr_inference.py

        Applies the optional ``self.specaug`` augmentation (training mode only)
        and ``self.normalize`` feature normalization with autocast disabled,
        then runs ``self.encoder``.

        Args:
            speech: Input features, (Batch, Length, ...).
            speech_lengths: Valid length of each item, (Batch, ).
            **kwargs: Extra keyword arguments (e.g. ``ind``); not read by this
                method.

        Returns:
            Tuple ``(encoder_out, encoder_out_lens)``. If the encoder's first
            output is a tuple, only its first element is returned.
        """
        # Preprocessing runs with autocast disabled, presumably to keep the
        # feature pipeline in full precision.
        with autocast(False):

            # Data augmentation (only when configured and in training mode)
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder; its third return value is discarded.
        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        # If the first output is a tuple, keep only its first element.
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        return encoder_out, encoder_out_lens
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L182

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        meta_data = {}
        # extract fbank feats
        time1 = time.perf_counter()
        audio_list, text_token_int_list = load_audio_text_image_video(
            data_in,
            fs=frontend.fs,
            audio_fs=kwargs.get("fs", 16000),
            data_type=kwargs.get("data_type", "sound"),
            tokenizer=tokenizer,
        )
        time2 = time.perf_counter()
View full source →
method

MonotonicAligner.forward(speech, speech_lengths, text, text_lengths)

funasr.models.monotonic_aligner.MonotonicAligner · View on GitHub ↗

Frontend + Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        # for data-parallel
        text = text[:, : text_lengths.max()]
        speech = speech[:, : speech_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
View full source on GitHub →
method

MonotonicAligner.calc_predictor_timestamp(encoder_out, encoder_out_lens, token_num)

funasr.models.monotonic_aligner.MonotonicAligner · View on GitHub ↗

Calc predictor timestamp.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
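  • token_num — Forwarded unchanged to predictor.get_upsample_timestamp (presumably the number of output tokens per utterance).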
📄 Source code
    def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
        """Run the predictor's timestamp routine on encoder output.

        Builds a frame mask from ``encoder_out_lens`` and passes it, together
        with ``encoder_out`` and ``token_num``, to
        ``self.predictor.get_upsample_timestamp``.

        Args:
            encoder_out: Encoder output tensor. Its dimension 1 is used as the
                mask length (``maxlen=encoder_out.size(1)``), so it is presumably
                laid out as (batch, time, dim) -- TODO confirm.
            encoder_out_lens: Per-sequence valid lengths of ``encoder_out``.
            token_num: Forwarded unchanged to ``get_upsample_timestamp``;
                presumably the number of output tokens per utterance -- TODO confirm.

        Returns:
            The 4-tuple ``(ds_alphas, ds_cif_peak, us_alphas, us_peaks)``
            returned by ``self.predictor.get_upsample_timestamp``.
        """
        # NOTE(review): make_pad_mask presumably flags padded frames, so the
        # negation marks valid frames; [:, None, :] inserts a singleton axis 1.
        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(
            encoder_out, encoder_out_mask, token_num
        )
        return ds_alphas, ds_cif_peak, us_alphas, us_peaks
method

MonotonicAligner.encode(speech, speech_lengths, **kwargs)

funasr.models.monotonic_aligner.MonotonicAligner · View on GitHub ↗

Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • ind — int (not an explicit parameter; it could only arrive via **kwargs, which this method does not read)
📄 Source code
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encoder. Note that this method is used by asr_inference.py

        Applies the optional ``self.specaug`` augmentation (training mode only)
        and ``self.normalize`` feature normalization with autocast disabled,
        then runs ``self.encoder``.

        Args:
            speech: Input features, (Batch, Length, ...).
            speech_lengths: Valid length of each item, (Batch, ).
            **kwargs: Extra keyword arguments (e.g. ``ind``); not read by this
                method.

        Returns:
            Tuple ``(encoder_out, encoder_out_lens)``. If the encoder's first
            output is a tuple, only its first element is returned.
        """
        # Preprocessing runs with autocast disabled, presumably to keep the
        # feature pipeline in full precision.
        with autocast(False):

            # Data augmentation (only when configured and in training mode)
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder; its third return value is discarded.
        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        # If the first output is a tuple, keep only its first element.
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        return encoder_out, encoder_out_lens
method

MonotonicAligner.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.monotonic_aligner.MonotonicAligner · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        meta_data = {}
        # extract fbank feats
        time1 = time.perf_counter()
        audio_list, text_token_int_list = load_audio_text_image_video(
            data_in,
            fs=frontend.fs,
            audio_fs=kwargs.get("fs", 16000),
            data_type=kwargs.get("data_type", "sound"),
            tokenizer=tokenizer,
        )
        time2 = time.perf_counter()
View full source on GitHub →
class

Paraformer

funasr.models.paraformer · View on GitHub ↗
Paraformer: Non-autoregressive End-to-End ASR Model.

High-accuracy speech recognition for Chinese/English. The production workhorse.

Features:

  • Non-autoregressive (parallel decoding, fast inference)
  • Character-level timestamps via CIF predictor
  • Streaming and offline modes
  • Hotword customization
  • Speaker diarization (with spk_model)
  • ONNX export support
Output: {"key": "...", "text": "recognized text", "timestamp": [[start_ms, end_ms], ...]}

Note: Requires punc_model="ct-punc" for punctuation (unlike Fun-ASR-Nano/SenseVoice, which output punctuation natively).

Author: Speech Lab of DAMO Academy, Alibaba Group

Paper: "Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition", https://arxiv.org/abs/2206.08317

📄 Source code
class Paraformer(torch.nn.Module):
    """Paraformer: Non-autoregressive End-to-End ASR Model.

    High-accuracy speech recognition for Chinese/English. The production workhorse.

    Features:
    - Non-autoregressive (parallel decoding, fast inference)
    - Character-level timestamps via CIF predictor
    - Streaming and offline modes
    - Hotword customization
    - Speaker diarization (with spk_model)
    - ONNX export support

    Output: {"key": "...", "text": "recognized text", "timestamp": [[start_ms, end_ms], ...]}

    Note: Requires punc_model="ct-punc" for punctuation (unlike Fun-ASR-Nano/SenseVoice
    which output punctuation natively).

    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        specaug: Optional[str] = None,
        specaug_conf: Optional[Dict] = None,
        normalize: str = None,
        normalize_conf: Optional[Dict] = None,
        encoder: str = None,
View full source on GitHub →

Methods

.forward(speech, speech_lengths, text, text_lengths, **kwargs) L215

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()

        # decoder: CTC branch
View full source →
.encode(speech, speech_lengths, **kwargs) L286

Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • ind — int (not an explicit parameter; it could only arrive via **kwargs, which this method does not read)
📄 Source
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encoder. Note that this method is used by asr_inference.py

        Applies the optional ``self.specaug`` augmentation (training mode only)
        and ``self.normalize`` feature normalization with autocast disabled,
        then runs ``self.encoder``.

        Args:
            speech: Input features, (Batch, Length, ...).
            speech_lengths: Valid length of each item, (Batch, ).
            **kwargs: Extra keyword arguments (e.g. ``ind``); not read by this
                method.

        Returns:
            Tuple ``(encoder_out, encoder_out_lens)``. If the encoder's first
            output is a tuple, only its first element is returned.
        """
        # Preprocessing runs with autocast disabled, presumably to keep the
        # feature pipeline in full precision.
        with autocast(False):

            # Data augmentation (only when configured and in training mode)
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder; its third return value is discarded.
        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        # If the first output is a tuple, keep only its first element.
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        return encoder_out, encoder_out_lens
.calc_predictor(encoder_out, encoder_out_lens) L315

Calc predictor.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
📄 Source
    def calc_predictor(self, encoder_out, encoder_out_lens):

        """Run the predictor on encoder output.

        Args:
            encoder_out: Encoder output tensor. Its dimension 1 is used as the
                mask length, so it is presumably (batch, time, dim) -- TODO confirm.
            encoder_out_lens: Per-sequence valid lengths of ``encoder_out``.

        Returns:
            ``(pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index)``
            exactly as returned by ``self.predictor``.
        """
        # NOTE(review): make_pad_mask presumably flags padded frames, so the
        # negation marks valid frames; [:, None, :] inserts a singleton axis 1.
        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        # The second positional argument is None (no target supplied); presumably
        # the predictor then estimates the token count itself -- confirm in predictor.
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(
            encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id
        )
        return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
.cal_decoder_with_predictor(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens) L331

Cal decoder with predictor.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
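  • sematic_embeds — Decoder input embeddings, passed to self.decoder as its third argument (presumably the predictor's acoustic embeddings).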
  • ys_pad_lens — Lengths of ys_pad.
📄 Source
    def cal_decoder_with_predictor(
        self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
    ):

        """Run the decoder and return log-probabilities.

        Args:
            encoder_out: Encoder output tensor.
            encoder_out_lens: Encoder output lengths.
            sematic_embeds: Decoder input embeddings (parameter name kept as-is
                for API compatibility); presumably the predictor's acoustic
                embeddings -- TODO confirm against callers.
            ys_pad_lens: Lengths of ys_pad; passed to the decoder and returned
                unchanged.

        Returns:
            ``(decoder_out, ys_pad_lens)``, where ``decoder_out`` is the
            decoder's first output after ``log_softmax`` over the last dimension.
        """
        decoder_outs = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)
        # Only the decoder's first output is used.
        decoder_out = decoder_outs[0]
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens
.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds) L408

Sampler.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
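  • ys_pad — Padded target token IDs, shape (batch, seq_len); positions equal to ignore_id are padding.
  • ys_pad_lens — Lengths of ys_pad.
  • pre_acoustic_embeds — Predictor acoustic embeddings; fed to the decoder under torch.no_grad() to obtain predicted tokens that are compared against ys_pad.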
📄 Source
    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):

        """Sampler.
        
            Args:
                encoder_out: Encoder output tensor.
                encoder_out_lens: Encoder output lengths.
                ys_pad: TODO.
                ys_pad_lens: Lengths of ys_pad.
                pre_acoustic_embeds: TODO.
            """
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
            ys_pad.device
        )
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)
        with torch.no_grad():
            decoder_outs = self.decoder(
                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
            )
            decoder_out, _ = decoder_outs[0], decoder_outs[1]
            pred_tokens = decoder_out.argmax(-1)
            nonpad_positions = ys_pad.ne(self.ignore_id)
            seq_lens = (nonpad_positions).sum(1)
            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
            input_mask = torch.ones_like(nonpad_positions)
            bsz, seq_len = ys_pad.size()
View full source →
.init_beam_search(**kwargs) L482

Init beam search.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def init_beam_search(
        self,
        **kwargs,
    ):
        """Init beam search.
        
            Args:
                **kwargs: Additional keyword arguments.
            """
        from funasr.models.paraformer.search import BeamSearchPara
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        # 1. Build ASR model
        scorers = {}

        if self.ctc != None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            length_bonus=LengthBonus(len(token_list)),
        )

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        weights = dict(
View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L534

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        # init beamsearch
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = (
            kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        )
        pred_timestamp = kwargs.get("pred_timestamp", False)
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

View full source →
.export(**kwargs) L699

Export.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def export(self, **kwargs):
        """Build the exportable form of this model.

        Delegates to ``export_rebuild_model`` from the sibling ``export_meta``
        module, passing this model as ``model``.

        Args:
            **kwargs: Options forwarded to ``export_rebuild_model``. If
                ``max_seq_len`` is absent it is filled in as 512.

        Returns:
            Whatever ``export_rebuild_model`` returns.
        """
        from .export_meta import export_rebuild_model

        # Only fills the key when the caller did not supply it (any value,
        # including None, is left untouched).
        kwargs.setdefault("max_seq_len", 512)
        return export_rebuild_model(model=self, **kwargs)
method

Paraformer.forward(speech, speech_lengths, text, text_lengths, **kwargs)

funasr.models.paraformer.Paraformer · View on GitHub ↗

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()

        # decoder: CTC branch
View full source on GitHub →
method

Paraformer.encode(speech, speech_lengths, **kwargs)

funasr.models.paraformer.Paraformer · View on GitHub ↗

Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
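  • ind — int (listed for historical reasons; it can only arrive through **kwargs and is ignored by this implementation)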
📄 Source code
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Optionally augment and normalize the features, then run the encoder.

        Note that this method is used by asr_inference.py.

        Args:
            speech: Input features, (Batch, Length, ...).
            speech_lengths: Valid length of each item, (Batch,).
            **kwargs: Accepted for interface compatibility (streaming callers
                pass ``ind=...``); ignored here.

        Returns:
            ``(encoder_out, encoder_out_lens)``. When the encoder's first
            output is itself a tuple, only its first element is returned.
        """
        # Feature preprocessing runs with autocast explicitly disabled.
        with autocast(False):
            augment = self.specaug is not None and self.training
            if augment:
                # Data augmentation is applied only in training mode.
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            if self.normalize is not None:
                # Feature normalization, e.g. Global-CMVN / Utterance-CMVN.
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # The encoder returns three values; the third is not needed here.
        enc_out, enc_out_lens, _unused = self.encoder(speech, speech_lengths)
        if isinstance(enc_out, tuple):
            enc_out = enc_out[0]
        return enc_out, enc_out_lens
method

Paraformer.calc_predictor(encoder_out, encoder_out_lens)

funasr.models.paraformer.Paraformer · View on GitHub ↗

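Compute the predictor outputs from the encoder output. Returns pre_acoustic_embeds, pre_token_length, alphas, and pre_peak_index.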

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
📄 Source code
    def calc_predictor(self, encoder_out, encoder_out_lens):
        """Run the predictor over a full (non-streaming) encoder output.

        Args:
            encoder_out: Encoder output tensor; its dimension 1 is used as
                ``maxlen`` when building the padding mask.
            encoder_out_lens: Valid length of each sequence in ``encoder_out``.

        Returns:
            Tuple ``(pre_acoustic_embeds, pre_token_length, alphas,
            pre_peak_index)`` as produced by ``self.predictor``.
        """
        num_frames = encoder_out.size(1)
        padding = make_pad_mask(encoder_out_lens, maxlen=num_frames)
        # make_pad_mask presumably flags padded positions; inverting gives a
        # mask of valid frames. A singleton axis is inserted at position 1.
        valid_mask = (~padding[:, None, :]).to(encoder_out.device)
        acoustic_embeds, token_length, alphas, peak_index = self.predictor(
            encoder_out, None, valid_mask, ignore_id=self.ignore_id
        )
        return acoustic_embeds, token_length, alphas, peak_index
method

Paraformer.cal_decoder_with_predictor(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)

funasr.models.paraformer.Paraformer · View on GitHub ↗

Run the decoder on the predictor's embeddings and return the log-softmax of its first output together with ys_pad_lens.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • sematic_embeds — TODO.
  • ys_pad_lens — Lengths of ys_pad.
📄 Source code
    def cal_decoder_with_predictor(
        self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
    ):
        """Decode the predictor's embeddings into log-probabilities.

        Args:
            encoder_out: Encoder output tensor.
            encoder_out_lens: Encoder output lengths.
            sematic_embeds: Embeddings handed to the decoder as its third
                positional argument (name kept as spelled in the API).
            ys_pad_lens: Lengths of ys_pad; returned unchanged.

        Returns:
            ``(log_probs, ys_pad_lens)``, where ``log_probs`` is the
            log-softmax over the last dimension of the decoder's first output.
        """
        outputs = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)
        log_probs = torch.log_softmax(outputs[0], dim=-1)
        return log_probs, ys_pad_lens
method

Paraformer.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds)

funasr.models.paraformer.Paraformer · View on GitHub ↗

Sampler.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • ys_pad — TODO.
  • ys_pad_lens — Lengths of ys_pad.
  • pre_acoustic_embeds — TODO.
📄 Source code
    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):

        """Sampler.
        
            Args:
                encoder_out: Encoder output tensor.
                encoder_out_lens: Encoder output lengths.
                ys_pad: TODO.
                ys_pad_lens: Lengths of ys_pad.
                pre_acoustic_embeds: TODO.
            """
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
            ys_pad.device
        )
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)
        with torch.no_grad():
            decoder_outs = self.decoder(
                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
            )
            decoder_out, _ = decoder_outs[0], decoder_outs[1]
            pred_tokens = decoder_out.argmax(-1)
            nonpad_positions = ys_pad.ne(self.ignore_id)
            seq_lens = (nonpad_positions).sum(1)
            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
            input_mask = torch.ones_like(nonpad_positions)
            bsz, seq_len = ys_pad.size()
View full source on GitHub →
method

Paraformer.init_beam_search(**kwargs)

funasr.models.paraformer.Paraformer · View on GitHub ↗

Init beam search.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def init_beam_search(
        self,
        **kwargs,
    ):
        """Init beam search.
        
            Args:
                **kwargs: Additional keyword arguments.
            """
        from funasr.models.paraformer.search import BeamSearchPara
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        # 1. Build ASR model
        scorers = {}

        if self.ctc != None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            length_bonus=LengthBonus(len(token_list)),
        )

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        weights = dict(
View full source on GitHub →
method

Paraformer.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.paraformer.Paraformer · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        # init beamsearch
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = (
            kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        )
        pred_timestamp = kwargs.get("pred_timestamp", False)
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

View full source on GitHub →
method

Paraformer.export(**kwargs)

funasr.models.paraformer.Paraformer · View on GitHub ↗

Export the model by delegating to `export_rebuild_model`; `max_seq_len` defaults to 512 when not supplied.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def export(self, **kwargs):
        """Build the exportable form of this model.

        Delegates to ``export_rebuild_model`` from the sibling ``export_meta``
        module, passing this model as ``model``.

        Args:
            **kwargs: Options forwarded to ``export_rebuild_model``. If
                ``max_seq_len`` is absent it is filled in as 512.

        Returns:
            Whatever ``export_rebuild_model`` returns.
        """
        from .export_meta import export_rebuild_model

        # Only fills the key when the caller did not supply it (any value,
        # including None, is left untouched).
        kwargs.setdefault("max_seq_len", 512)
        return export_rebuild_model(model=self, **kwargs)
class

ParaformerStreaming

funasr.models.paraformer_streaming · View on GitHub ↗
  • ParaformerStreaming — Streaming (online) version of Paraformer.

Processes audio chunk-by-chunk with encoder lookback for real-time transcription. Uses a cache mechanism to maintain state across chunks.

  • Usage — generate(input=chunk, cache=cache, is_final=bool, chunk_size=[0,10,5])
  • Output — {"key": str, "text": str} (partial results per chunk)
📄 Source code
class ParaformerStreaming(Paraformer):
    """ParaformerStreaming: Streaming (online) version of Paraformer.

    Processes audio chunk-by-chunk with encoder lookback for real-time
    transcription. Uses cache mechanism to maintain state across chunks.

    Usage: generate(input=chunk, cache=cache, is_final=bool, chunk_size=[0,10,5])
    Output: {"key": str, "text": str} (partial results per chunk)
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):

        """Initialize ParaformerStreaming.
        
            Args:
                *args: Variable positional arguments.
                **kwargs: Additional keyword arguments.
            """
        super().__init__(*args, **kwargs)

        self.sampling_ratio = kwargs.get("sampling_ratio", 0.2)

        self.scama_mask = None
        if (
            hasattr(self.encoder, "overlap_chunk_cls")
            and self.encoder.overlap_chunk_cls is not None
View full source on GitHub →

Methods

.forward(speech, speech_lengths, text, text_lengths, **kwargs) L83

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        decoding_ind = kwargs.get("decoding_ind")
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # Encoder
        if hasattr(self.encoder, "overlap_chunk_cls"):
            ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        else:
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

View full source →
.encode_chunk(speech, speech_lengths, cache, **kwargs) L165

Encode chunk.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source
    def encode_chunk(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        cache: dict = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode chunk.
        
            Args:
                speech: Speech audio tensor, shape (batch, time).
                speech_lengths: Length of each speech sample.
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):

            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
View full source →
.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask) L336

Sampler.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • ys_pad — TODO.
  • ys_pad_lens — Lengths of ys_pad.
  • pre_acoustic_embeds — TODO.
  • chunk_mask — TODO.
📄 Source
    def sampler(
        self,
        encoder_out,
        encoder_out_lens,
        ys_pad,
        ys_pad_lens,
        pre_acoustic_embeds,
        chunk_mask=None,
    ):

        """Sampler.
        
            Args:
                encoder_out: Encoder output tensor.
                encoder_out_lens: Encoder output lengths.
                ys_pad: TODO.
                ys_pad_lens: Lengths of ys_pad.
                pre_acoustic_embeds: TODO.
                chunk_mask: TODO.
            """
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
            ys_pad.device
        )
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)
        with torch.no_grad():
            decoder_outs = self.decoder(
View full source →
.calc_predictor(encoder_out, encoder_out_lens) L392

Calc predictor.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
📄 Source
    def calc_predictor(self, encoder_out, encoder_out_lens):

        """Calc predictor.
        
            Args:
                encoder_out: Encoder output tensor.
                encoder_out_lens: Encoder output lengths.
            """
        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        mask_chunk_predictor = None
        if self.encoder.overlap_chunk_cls is not None:
            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            encoder_out = encoder_out * mask_shfit_chunk
        pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(
            encoder_out,
            None,
            encoder_out_mask,
            ignore_id=self.ignore_id,
            mask_chunk_predictor=mask_chunk_predictor,
            target_label_length=None,
        )
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
            pre_alphas,
View full source →
.calc_predictor_chunk(encoder_out, encoder_out_lens, cache, **kwargs) L460

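Run the predictor's streaming `forward_chunk` on one encoder chunk, using the encoder state stored in `cache["encoder"]`.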

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source
    def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs):
        """Run the predictor on one streaming chunk of encoder output.

        Args:
            encoder_out: Encoder output for the current chunk.
            encoder_out_lens: Encoder output lengths (not used here; kept for
                signature parity with ``calc_predictor``).
            cache: Streaming state dict. Must contain an ``"encoder"`` entry
                (``init_cache`` stores one there); that entry is handed to
                ``self.predictor.forward_chunk``.
            **kwargs: ``is_final`` (default False) is forwarded to the
                predictor to mark the last chunk of a stream.

        Returns:
            The result of ``self.predictor.forward_chunk``.

        Raises:
            ValueError: If ``cache`` is None.
            KeyError: If ``cache`` has no ``"encoder"`` entry.
        """
        if cache is None:
            # The predictor needs per-stream encoder state and there is no
            # meaningful default. Fail with a clear message instead of
            # "'NoneType' object is not subscriptable".
            raise ValueError(
                "calc_predictor_chunk requires a cache dict with an 'encoder' entry "
                "(see init_cache)"
            )
        is_final = kwargs.get("is_final", False)
        return self.predictor.forward_chunk(encoder_out, cache["encoder"], is_final=is_final)
.cal_decoder_with_predictor(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens) L473

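Run the decoder (with `self.scama_mask`) on the predictor's embeddings and return the log-softmax of its first output together with ys_pad_lens.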

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • sematic_embeds — TODO.
  • ys_pad_lens — Lengths of ys_pad.
📄 Source
    def cal_decoder_with_predictor(
        self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
    ):
        """Decode the predictor's embeddings into log-probabilities.

        Unlike the non-streaming variant, ``self.scama_mask`` is passed to the
        decoder as a fifth positional argument.

        Args:
            encoder_out: Encoder output tensor.
            encoder_out_lens: Encoder output lengths.
            sematic_embeds: Embeddings handed to the decoder as its third
                positional argument (name kept as spelled in the API).
            ys_pad_lens: Lengths of ys_pad; returned unchanged.

        Returns:
            ``(log_probs, ys_pad_lens)``, where ``log_probs`` is the
            log-softmax over the last dimension of the decoder's first output.
        """
        outputs = self.decoder(
            encoder_out,
            encoder_out_lens,
            sematic_embeds,
            ys_pad_lens,
            self.scama_mask,
        )
        return torch.log_softmax(outputs[0], dim=-1), ys_pad_lens
.cal_decoder_with_predictor_chunk(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, cache) L491

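Run the decoder's streaming `forward_chunk` on one chunk, using the decoder state in `cache["decoder"]`, and return the log-softmax scores together with ys_pad_lens.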

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • sematic_embeds — TODO.
  • ys_pad_lens — Lengths of ys_pad.
  • cache — State cache dict for streaming inference.
📄 Source
    def cal_decoder_with_predictor_chunk(
        self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, cache=None
    ):
        """Decode one streaming chunk and return log-probabilities.

        Args:
            encoder_out: Encoder output for the current chunk.
            encoder_out_lens: Encoder output lengths (not used here; kept for
                signature parity with ``cal_decoder_with_predictor``).
            sematic_embeds: Embeddings handed to ``decoder.forward_chunk``
                (name kept as spelled in the API).
            ys_pad_lens: Lengths of ys_pad; returned unchanged.
            cache: Streaming state dict. Must contain a ``"decoder"`` entry
                (presumably populated by ``init_cache``), which is handed to
                ``self.decoder.forward_chunk``.

        Returns:
            ``(log_probs, ys_pad_lens)``, where ``log_probs`` is the
            log-softmax over the last dimension of ``forward_chunk``'s output.

        Raises:
            ValueError: If ``cache`` is None.
            KeyError: If ``cache`` has no ``"decoder"`` entry.
        """
        if cache is None:
            # The decoder needs per-stream state and there is no meaningful
            # default. Fail with a clear message instead of
            # "'NoneType' object is not subscriptable".
            raise ValueError(
                "cal_decoder_with_predictor_chunk requires a cache dict with a "
                "'decoder' entry (see init_cache)"
            )
        decoder_out = self.decoder.forward_chunk(encoder_out, sematic_embeds, cache["decoder"])
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens
.init_cache(cache, **kwargs) L508

Init cache.

Args:

  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source
    def init_cache(self, cache: dict = None, **kwargs):
        """Init cache.
        
            Args:
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        chunk_size = kwargs.get("chunk_size", [0, 10, 5])
        encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
        decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
        batch_size = 1

        enc_output_size = kwargs["encoder_conf"]["output_size"]
        feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
        cache_encoder = {
            "start_idx": 0,
            "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
            "cif_alphas": torch.zeros((batch_size, 1)),
            "chunk_size": chunk_size,
            "encoder_chunk_look_back": encoder_chunk_look_back,
            "last_chunk": False,
            "opt": None,
            "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
            "tail_chunk": False,
        }
        cache["encoder"] = cache_encoder

        cache_decoder = {
View full source →
.generate_chunk(speech, speech_lengths, key, tokenizer, frontend, **kwargs) L549

Generate chunk.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def generate_chunk(
        self,
        speech,
        speech_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Generate chunk.
        
            Args:
                speech: Speech audio tensor, shape (batch, time).
                speech_lengths: Length of each speech sample.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        cache = kwargs.get("cache", {})
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Encoder
        encoder_out, encoder_out_lens = self.encode_chunk(
            speech, speech_lengths, cache=cache, is_final=kwargs.get("is_final", False)
        )
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, cache, **kwargs) L647

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        cache: dict = None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}

        # init beamsearch
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = (
            kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        )
        if self.beam_search is None and (is_use_lm or is_use_ctc):
View full source →
.export(**kwargs) L762

Export.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def export(self, **kwargs):
        """Build the exportable form of this streaming model.

        Delegates to ``export_rebuild_model`` from the sibling ``export_meta``
        module, passing this model as ``model``. No option defaults are
        injected here; ``kwargs`` are forwarded as given.

        Args:
            **kwargs: Options forwarded unchanged to ``export_rebuild_model``.

        Returns:
            Whatever ``export_rebuild_model`` returns.
        """
        from .export_meta import export_rebuild_model

        return export_rebuild_model(model=self, **kwargs)
method

ParaformerStreaming.forward(speech, speech_lengths, text, text_lengths, **kwargs)

funasr.models.paraformer_streaming.ParaformerStreaming · View on GitHub ↗

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        decoding_ind = kwargs.get("decoding_ind")
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # Encoder
        if hasattr(self.encoder, "overlap_chunk_cls"):
            ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        else:
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

View full source on GitHub →
method

ParaformerStreaming.encode_chunk(speech, speech_lengths, cache, **kwargs)

funasr.models.paraformer_streaming.ParaformerStreaming · View on GitHub ↗

Encode chunk.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def encode_chunk(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        cache: dict = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode chunk.
        
            Args:
                speech: Speech audio tensor, shape (batch, time).
                speech_lengths: Length of each speech sample.
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):

            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
View full source on GitHub →
method

ParaformerStreaming.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask)

funasr.models.paraformer_streaming.ParaformerStreaming · View on GitHub ↗

Sampler.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • ys_pad — TODO.
  • ys_pad_lens — Lengths of ys_pad.
  • pre_acoustic_embeds — TODO.
  • chunk_mask — TODO.
📄 Source code
    def sampler(
        self,
        encoder_out,
        encoder_out_lens,
        ys_pad,
        ys_pad_lens,
        pre_acoustic_embeds,
        chunk_mask=None,
    ):

        """Sampler.
        
            Args:
                encoder_out: Encoder output tensor.
                encoder_out_lens: Encoder output lengths.
                ys_pad: TODO.
                ys_pad_lens: Lengths of ys_pad.
                pre_acoustic_embeds: TODO.
                chunk_mask: TODO.
            """
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
            ys_pad.device
        )
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)
        with torch.no_grad():
            decoder_outs = self.decoder(
View full source on GitHub →
method

ParaformerStreaming.calc_predictor(encoder_out, encoder_out_lens)

funasr.models.paraformer_streaming.ParaformerStreaming · View on GitHub ↗

Calc predictor.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
📄 Source code
    def calc_predictor(self, encoder_out, encoder_out_lens):

        """Calc predictor.
        
            Args:
                encoder_out: Encoder output tensor.
                encoder_out_lens: Encoder output lengths.
            """
        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        mask_chunk_predictor = None
        if self.encoder.overlap_chunk_cls is not None:
            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            encoder_out = encoder_out * mask_shfit_chunk
        pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(
            encoder_out,
            None,
            encoder_out_mask,
            ignore_id=self.ignore_id,
            mask_chunk_predictor=mask_chunk_predictor,
            target_label_length=None,
        )
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
            pre_alphas,
View full source on GitHub →
method

ParaformerStreaming.calc_predictor_chunk(encoder_out, encoder_out_lens, cache, **kwargs)

funasr.models.paraformer_streaming.ParaformerStreaming · View on GitHub ↗

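Run the predictor's streaming `forward_chunk` on one encoder chunk, using the encoder state stored in `cache["encoder"]`.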

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs):
        """Run the predictor on one streaming chunk of encoder output.

        Args:
            encoder_out: Encoder output for the current chunk.
            encoder_out_lens: Encoder output lengths (not used here; kept for
                signature parity with ``calc_predictor``).
            cache: Streaming state dict. Must contain an ``"encoder"`` entry
                (``init_cache`` stores one there); that entry is handed to
                ``self.predictor.forward_chunk``.
            **kwargs: ``is_final`` (default False) is forwarded to the
                predictor to mark the last chunk of a stream.

        Returns:
            The result of ``self.predictor.forward_chunk``.

        Raises:
            ValueError: If ``cache`` is None.
            KeyError: If ``cache`` has no ``"encoder"`` entry.
        """
        if cache is None:
            # The predictor needs per-stream encoder state and there is no
            # meaningful default. Fail with a clear message instead of
            # "'NoneType' object is not subscriptable".
            raise ValueError(
                "calc_predictor_chunk requires a cache dict with an 'encoder' entry "
                "(see init_cache)"
            )
        is_final = kwargs.get("is_final", False)
        return self.predictor.forward_chunk(encoder_out, cache["encoder"], is_final=is_final)
method

ParaformerStreaming.cal_decoder_with_predictor(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)

funasr.models.paraformer_streaming.ParaformerStreaming · View on GitHub ↗

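Run the decoder (with `self.scama_mask`) on the predictor's embeddings and return the log-softmax of its first output together with ys_pad_lens.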

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • sematic_embeds — TODO.
  • ys_pad_lens — Lengths of ys_pad.
📄 Source code
    def cal_decoder_with_predictor(
        self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
    ):
        """Decode the predictor's embeddings into log-probabilities.

        Unlike the non-streaming variant, ``self.scama_mask`` is passed to the
        decoder as a fifth positional argument.

        Args:
            encoder_out: Encoder output tensor.
            encoder_out_lens: Encoder output lengths.
            sematic_embeds: Embeddings handed to the decoder as its third
                positional argument (name kept as spelled in the API).
            ys_pad_lens: Lengths of ys_pad; returned unchanged.

        Returns:
            ``(log_probs, ys_pad_lens)``, where ``log_probs`` is the
            log-softmax over the last dimension of the decoder's first output.
        """
        outputs = self.decoder(
            encoder_out,
            encoder_out_lens,
            sematic_embeds,
            ys_pad_lens,
            self.scama_mask,
        )
        return torch.log_softmax(outputs[0], dim=-1), ys_pad_lens
method

ParaformerStreaming.cal_decoder_with_predictor_chunk(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, cache)

funasr.models.paraformer_streaming.ParaformerStreaming · View on GitHub ↗

Run the decoder on one streaming chunk using the cached decoder state; returns log-softmax scores together with `ys_pad_lens`.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • sematic_embeds — Semantic embeddings passed to the decoder.
  • ys_pad_lens — Lengths of ys_pad.
  • cache — State cache dict for streaming inference.
📄 Source code
    def cal_decoder_with_predictor_chunk(
        self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, cache=None
    ):
        """Decode one streaming chunk using the decoder's cached state.

        Args:
            encoder_out: Encoder output for the current chunk.
            encoder_out_lens: Not used by this method.
            sematic_embeds: Semantic embeddings passed to
                ``self.decoder.forward_chunk``.
            ys_pad_lens: Returned unchanged alongside the scores.
            cache: Streaming state dict. Its ``"decoder"`` entry is passed to
                ``self.decoder.forward_chunk``, so a dict must be supplied
                despite the ``None`` default.

        Returns:
            Tuple ``(log_probs, ys_pad_lens)``. ``log_probs`` is the log-softmax,
            over the last dimension, of the chunk decoder output.
        """
        decoder_state = cache["decoder"]
        raw_out = self.decoder.forward_chunk(encoder_out, sematic_embeds, decoder_state)
        return torch.log_softmax(raw_out, dim=-1), ys_pad_lens
method

ParaformerStreaming.init_cache(cache, **kwargs)

funasr.models.paraformer_streaming.ParaformerStreaming · View on GitHub ↗

Initialize the streaming state cache used by chunked inference.

Args:

  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def init_cache(self, cache: dict = None, **kwargs):
        """Init cache.
        
            Args:
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        chunk_size = kwargs.get("chunk_size", [0, 10, 5])
        encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
        decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
        batch_size = 1

        enc_output_size = kwargs["encoder_conf"]["output_size"]
        feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
        cache_encoder = {
            "start_idx": 0,
            "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
            "cif_alphas": torch.zeros((batch_size, 1)),
            "chunk_size": chunk_size,
            "encoder_chunk_look_back": encoder_chunk_look_back,
            "last_chunk": False,
            "opt": None,
            "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
            "tail_chunk": False,
        }
        cache["encoder"] = cache_encoder

        cache_decoder = {
View full source on GitHub →
method

ParaformerStreaming.generate_chunk(speech, speech_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.paraformer_streaming.ParaformerStreaming · View on GitHub ↗

Process one streaming chunk of audio, using and updating the state held in `cache`.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def generate_chunk(
        self,
        speech,
        speech_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Generate chunk.
        
            Args:
                speech: Speech audio tensor, shape (batch, time).
                speech_lengths: Length of each speech sample.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        cache = kwargs.get("cache", {})
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Encoder
        encoder_out, encoder_out_lens = self.encode_chunk(
            speech, speech_lengths, cache=cache, is_final=kwargs.get("is_final", False)
        )
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

View full source on GitHub →
method

ParaformerStreaming.inference(data_in, data_lengths, key, tokenizer, frontend, cache, **kwargs)

funasr.models.paraformer_streaming.ParaformerStreaming · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        cache: dict = None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}

        # init beamsearch
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = (
            kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        )
        if self.beam_search is None and (is_use_lm or is_use_ctc):
View full source on GitHub →
method

ParaformerStreaming.export(**kwargs)

funasr.models.paraformer_streaming.ParaformerStreaming · View on GitHub ↗

Build the export-ready model(s) via `export_rebuild_model`.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def export(self, **kwargs):
        """Build the export-ready model(s) for this streaming network.

        Args:
            **kwargs: Options forwarded verbatim to ``export_rebuild_model``.

        Returns:
            The value produced by ``export_rebuild_model``.
        """
        # Local import, presumably so the export helpers load only when exporting.
        from .export_meta import export_rebuild_model

        return export_rebuild_model(model=self, **kwargs)
class

Paraformer

funasr.models.paraformer_v2_community · View on GitHub ↗
  • Author: Speech Lab of DAMO Academy, Alibaba Group
  • Paper: Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition

https://arxiv.org/abs/2206.08317

📄 Source code
class Paraformer(torch.nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        specaug: Optional[str] = None,
        specaug_conf: Optional[Dict] = None,
        normalize: str = None,
        normalize_conf: Optional[Dict] = None,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        decoder: str = None,
        decoder_conf: Optional[Dict] = None,
        ctc: str = None,
        ctc_conf: Optional[Dict] = None,
        ctc_weight: float = 0.5,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        # report_cer: bool = True,
        # report_wer: bool = True,
View full source on GitHub →

Methods

.forward(speech, speech_lengths, text, text_lengths, **kwargs) L182

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()

        # decoder: CTC branch
View full source →
.encode(speech, speech_lengths, **kwargs) L250

Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • ind — int (not a named parameter; any extra keyword arguments are accepted but unused)
📄 Source
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Preprocess features and run the encoder. Used by asr_inference.py.

        Data augmentation (``self.specaug``, training mode only) and feature
        normalization (``self.normalize``, e.g. Global-CMVN / Utterance-CMVN) run
        with autocast disabled. The encoder itself runs under whatever autocast
        setting the caller has active.

        Args:
            speech: Input features, (Batch, Length, ...).
            speech_lengths: Valid length of each item, (Batch,).
            **kwargs: Ignored.

        Returns:
            Tuple ``(encoder_out, encoder_out_lens)``. When the encoder's first
            output is itself a tuple, only its first element is returned.
        """
        with autocast(False):
            if self.training and self.specaug is not None:
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        encoded, encoded_lens, _ = self.encoder(speech, speech_lengths)
        encoded = encoded[0] if isinstance(encoded, tuple) else encoded
        return encoded, encoded_lens
.map_alignment_to_target_index(align_path, blank_id) L366

Robustly map CTC alignment path (Token IDs) to Target Indices.

Logic:

Detect boundaries where a new token segment begins.

A segment starts if the current frame is a Token AND it is different from the previous frame

(considering CTC topology where repeats are separated by blanks or are distinct tokens).

Example:

  • Text — [A, B]
  • Align Path — [A, A, _, B, B]
  • Output — [0, 0, -1, 1, 1]
📄 Source
    def map_alignment_to_target_index(self, align_path, blank_id):
        """
        Robustly map CTC alignment path (Token IDs) to Target Indices.
        
        Logic:
            Detect boundaries where a new token segment begins.
            A segment starts if the current frame is a Token AND it is different from the previous frame 
            (considering CTC topology where repeats are separated by blanks or are distinct tokens).
        
        Example:
            Text: [A, B]
            Align Path: [A, A, _, B, B]
            Output:     [0, 0, -1, 1, 1]
        """
        # 1. Identify where the path is NOT blank
        is_token = align_path != blank_id
        
        # 2. Identify transitions
        prev_path = torch.roll(align_path, 1)
        # Handle the very first frame: if it's a token, it must be the start of segment 0.
        prev_path[0] = blank_id # force mismatch for the first element
        
        # A new segment starts if: It's a token AND (it differs from prev OR prev was blank)
        # Note: If align_path[i] == align_path[i-1] (and not blank), it's the same segment.
        new_segment_start = is_token & (align_path != prev_path)
        
        # 3. Cumulative sum to assign indices (1..U)
        segment_ids = torch.cumsum(new_segment_start.long(), dim=0) - 1
        
        # 4. Mask out blank positions with -1
View full source →
.force_align(ctc_probs, y, blank_id) L399

ctc forced alignment.

Args:

torch.Tensor ctc_probs: CTC log-probability sequence, 2-D tensor (T, D)

torch.Tensor y: id sequence tensor 1d tensor (L)

int blank_id: blank symbol index

Returns:

  • torch.Tensor — alignment result
📄 Source
    def force_align(self, ctc_probs: torch.Tensor, y: torch.Tensor, blank_id=0) -> torch.Tensor:
        """CTC forced alignment of a single utterance.

        Args:
            ctc_probs: CTC output scores for one utterance, 2-D tensor (T, D).
                Passed to ``torchaudio.functional.forced_align``, which expects
                log-probabilities.
            y: Target token ids, 1-D tensor (L).
            blank_id: Index of the CTC blank symbol.

        Returns:
            torch.Tensor: Per-frame alignment labels for the utterance (batch
            dimension removed), on CPU. The previous ``-> list`` annotation was
            wrong: this method has always returned a tensor.
        """
        # forced_align operates on batched inputs: add a batch dim of 1 and move to CPU.
        ctc_probs = ctc_probs[None].cpu()
        y = y[None].cpu()
        alignments, _ = torchaudio.functional.forced_align(ctc_probs, y, blank=blank_id)
        # Strip the batch dimension that was added above.
        return alignments[0]
.average_repeats_training(ctc_probs, target_idx_path, target_len) L414

Aggregates frames belonging to the same target index using scatter_add.

Args:

  • ctc_probs — [T, V]
  • target_idx_path — [T], values in [-1, 0, ... U-1]
  • target_len — U

Returns:

  • compressed — [U, V]
📄 Source
    def average_repeats_training(self, ctc_probs, target_idx_path, target_len):
        """
        Aggregates frames belonging to the same target index using scatter_add.
        
        Args:
            ctc_probs: [T, V]
            target_idx_path: [T], values in [-1, 0, ... U-1]
            target_len: U
        Returns:
            compressed: [U, V]
        """
        U = target_len
        V = ctc_probs.size(1)
        
        compressed = torch.zeros((U, V), device=ctc_probs.device, dtype=ctc_probs.dtype)
        counts = torch.zeros((U, 1), device=ctc_probs.device, dtype=ctc_probs.dtype)
        
        # Filter valid frames (non-blank)
        mask = target_idx_path != -1
        valid_indices = target_idx_path[mask] # [T_valid]
        valid_probs = ctc_probs[mask]         # [T_valid, V]
        
        if valid_indices.numel() == 0:
            return compressed
            
        # Scatter Add Probs
        index_expanded = valid_indices.unsqueeze(1).repeat(1, V)
        compressed.scatter_add_(0, index_expanded, valid_probs)
        
        # Scatter Add Counts
View full source →
.average_repeats_inference(ctc_probs, greedy_path) L451

Returns:

  • merged_probs — [U', V]
  • timestamps — List[Tuple[int, int]] -> [(start_frame, end_frame), ...]
📄 Source
    def average_repeats_inference(self, ctc_probs, greedy_path):
        """
        Returns:
            merged_probs: [U', V]
            timestamps: List[Tuple[int, int]] -> [(start_frame, end_frame), ...]
        """
        if greedy_path.numel() == 0:
            return torch.zeros((0, ctc_probs.size(1)), device=ctc_probs.device)

        # Find consecutive segments in the greedy path
        unique_tokens, counts = torch.unique_consecutive(greedy_path, return_counts=True)
        
        # Compute start and end indices for each segment
        end_indices = torch.cumsum(counts, dim=0)
        start_indices = torch.cat([torch.tensor([0], device=counts.device), end_indices[:-1]])
        
        merged_probs = []
    
        for i, token in enumerate(unique_tokens):
            if token != self.blank_id:
                start = start_indices[i].item()
                end = end_indices[i].item()
                
                # Extract and average probabilities for the decoder
                avg_prob = ctc_probs[start:end].mean(dim=0)
                merged_probs.append(avg_prob)

        
        if not merged_probs:
            return torch.zeros((0, ctc_probs.size(1)), device=ctc_probs.device)
View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L484

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is not None:
                speech_lengths = speech_lengths.squeeze(-1)
            else:
View full source →
.export(**kwargs) L592

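Build the export-ready model(s) via `export_rebuild_model`, defaulting `max_seq_len` to 512 when not given.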

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def export(self, **kwargs):
        """Build the export-ready model(s) for this network.

        Args:
            **kwargs: Options forwarded to ``export_rebuild_model``. When the
                caller does not supply ``max_seq_len``, it defaults to 512.

        Returns:
            The value produced by ``export_rebuild_model``.
        """
        from .export_meta import export_rebuild_model

        kwargs.setdefault("max_seq_len", 512)
        return export_rebuild_model(model=self, **kwargs)
method

Paraformer.forward(speech, speech_lengths, text, text_lengths, **kwargs)

funasr.models.paraformer_v2_community.Paraformer · View on GitHub ↗

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()

        # decoder: CTC branch
View full source on GitHub →
method

Paraformer.encode(speech, speech_lengths, **kwargs)

funasr.models.paraformer_v2_community.Paraformer · View on GitHub ↗

Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • ind — int (not a named parameter; any extra keyword arguments are accepted but unused)
📄 Source code
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Preprocess features and run the encoder. Used by asr_inference.py.

        Data augmentation (``self.specaug``, training mode only) and feature
        normalization (``self.normalize``, e.g. Global-CMVN / Utterance-CMVN) run
        with autocast disabled. The encoder itself runs under whatever autocast
        setting the caller has active.

        Args:
            speech: Input features, (Batch, Length, ...).
            speech_lengths: Valid length of each item, (Batch,).
            **kwargs: Ignored.

        Returns:
            Tuple ``(encoder_out, encoder_out_lens)``. When the encoder's first
            output is itself a tuple, only its first element is returned.
        """
        with autocast(False):
            if self.training and self.specaug is not None:
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        encoded, encoded_lens, _ = self.encoder(speech, speech_lengths)
        encoded = encoded[0] if isinstance(encoded, tuple) else encoded
        return encoded, encoded_lens
method

Paraformer.map_alignment_to_target_index(align_path, blank_id)

funasr.models.paraformer_v2_community.Paraformer · View on GitHub ↗

Robustly map CTC alignment path (Token IDs) to Target Indices.

Logic:

Detect boundaries where a new token segment begins.

A segment starts if the current frame is a Token AND it is different from the previous frame

(considering CTC topology where repeats are separated by blanks or are distinct tokens).

Example:

  • Text — [A, B]
  • Align Path — [A, A, _, B, B]
  • Output — [0, 0, -1, 1, 1]
📄 Source code
    def map_alignment_to_target_index(self, align_path, blank_id):
        """
        Robustly map CTC alignment path (Token IDs) to Target Indices.
        
        Logic:
            Detect boundaries where a new token segment begins.
            A segment starts if the current frame is a Token AND it is different from the previous frame 
            (considering CTC topology where repeats are separated by blanks or are distinct tokens).
        
        Example:
            Text: [A, B]
            Align Path: [A, A, _, B, B]
            Output:     [0, 0, -1, 1, 1]
        """
        # 1. Identify where the path is NOT blank
        is_token = align_path != blank_id
        
        # 2. Identify transitions
        prev_path = torch.roll(align_path, 1)
        # Handle the very first frame: if it's a token, it must be the start of segment 0.
        prev_path[0] = blank_id # force mismatch for the first element
        
        # A new segment starts if: It's a token AND (it differs from prev OR prev was blank)
        # Note: If align_path[i] == align_path[i-1] (and not blank), it's the same segment.
        new_segment_start = is_token & (align_path != prev_path)
        
        # 3. Cumulative sum to assign indices (1..U)
        segment_ids = torch.cumsum(new_segment_start.long(), dim=0) - 1
        
        # 4. Mask out blank positions with -1
View full source on GitHub →
method

Paraformer.force_align(ctc_probs, y, blank_id)

funasr.models.paraformer_v2_community.Paraformer · View on GitHub ↗

ctc forced alignment.

Args:

torch.Tensor ctc_probs: CTC log-probability sequence, 2-D tensor (T, D)

torch.Tensor y: id sequence tensor 1d tensor (L)

int blank_id: blank symbol index

Returns:

  • torch.Tensor — alignment result
📄 Source code
    def force_align(self, ctc_probs: torch.Tensor, y: torch.Tensor, blank_id=0) -> torch.Tensor:
        """CTC forced alignment of a single utterance.

        Args:
            ctc_probs: CTC output scores for one utterance, 2-D tensor (T, D).
                Passed to ``torchaudio.functional.forced_align``, which expects
                log-probabilities.
            y: Target token ids, 1-D tensor (L).
            blank_id: Index of the CTC blank symbol.

        Returns:
            torch.Tensor: Per-frame alignment labels for the utterance (batch
            dimension removed), on CPU. The previous ``-> list`` annotation was
            wrong: this method has always returned a tensor.
        """
        # forced_align operates on batched inputs: add a batch dim of 1 and move to CPU.
        ctc_probs = ctc_probs[None].cpu()
        y = y[None].cpu()
        alignments, _ = torchaudio.functional.forced_align(ctc_probs, y, blank=blank_id)
        # Strip the batch dimension that was added above.
        return alignments[0]
method

Paraformer.average_repeats_training(ctc_probs, target_idx_path, target_len)

funasr.models.paraformer_v2_community.Paraformer · View on GitHub ↗

Aggregates frames belonging to the same target index using scatter_add.

Args:

  • ctc_probs — [T, V]
  • target_idx_path — [T], values in [-1, 0, ... U-1]
  • target_len — U

Returns:

  • compressed — [U, V]
📄 Source code
    def average_repeats_training(self, ctc_probs, target_idx_path, target_len):
        """
        Aggregates frames belonging to the same target index using scatter_add.
        
        Args:
            ctc_probs: [T, V]
            target_idx_path: [T], values in [-1, 0, ... U-1]
            target_len: U
        Returns:
            compressed: [U, V]
        """
        U = target_len
        V = ctc_probs.size(1)
        
        compressed = torch.zeros((U, V), device=ctc_probs.device, dtype=ctc_probs.dtype)
        counts = torch.zeros((U, 1), device=ctc_probs.device, dtype=ctc_probs.dtype)
        
        # Filter valid frames (non-blank)
        mask = target_idx_path != -1
        valid_indices = target_idx_path[mask] # [T_valid]
        valid_probs = ctc_probs[mask]         # [T_valid, V]
        
        if valid_indices.numel() == 0:
            return compressed
            
        # Scatter Add Probs
        index_expanded = valid_indices.unsqueeze(1).repeat(1, V)
        compressed.scatter_add_(0, index_expanded, valid_probs)
        
        # Scatter Add Counts
View full source on GitHub →
method

Paraformer.average_repeats_inference(ctc_probs, greedy_path)

funasr.models.paraformer_v2_community.Paraformer · View on GitHub ↗

Returns:

  • merged_probs — [U', V]
  • timestamps — List[Tuple[int, int]] -> [(start_frame, end_frame), ...]
📄 Source code
    def average_repeats_inference(self, ctc_probs, greedy_path):
        """
        Returns:
            merged_probs: [U', V]
            timestamps: List[Tuple[int, int]] -> [(start_frame, end_frame), ...]
        """
        if greedy_path.numel() == 0:
            return torch.zeros((0, ctc_probs.size(1)), device=ctc_probs.device)

        # Find consecutive segments in the greedy path
        unique_tokens, counts = torch.unique_consecutive(greedy_path, return_counts=True)
        
        # Compute start and end indices for each segment
        end_indices = torch.cumsum(counts, dim=0)
        start_indices = torch.cat([torch.tensor([0], device=counts.device), end_indices[:-1]])
        
        merged_probs = []
    
        for i, token in enumerate(unique_tokens):
            if token != self.blank_id:
                start = start_indices[i].item()
                end = end_indices[i].item()
                
                # Extract and average probabilities for the decoder
                avg_prob = ctc_probs[start:end].mean(dim=0)
                merged_probs.append(avg_prob)

        
        if not merged_probs:
            return torch.zeros((0, ctc_probs.size(1)), device=ctc_probs.device)
View full source on GitHub →
method

Paraformer.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.paraformer_v2_community.Paraformer · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is not None:
                speech_lengths = speech_lengths.squeeze(-1)
            else:
View full source on GitHub →
method

Paraformer.export(**kwargs)

funasr.models.paraformer_v2_community.Paraformer · View on GitHub ↗

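Build the export-ready model(s) via `export_rebuild_model`, defaulting `max_seq_len` to 512 when not given.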

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def export(self, **kwargs):
        """Build the export-ready model(s) for this network.

        Args:
            **kwargs: Options forwarded to ``export_rebuild_model``. When the
                caller does not supply ``max_seq_len``, it defaults to 512.

        Returns:
            The value produced by ``export_rebuild_model``.
        """
        from .export_meta import export_rebuild_model

        kwargs.setdefault("max_seq_len", 512)
        return export_rebuild_model(model=self, **kwargs)
class

Qwen3ASR

funasr.models.qwen3_asr · View on GitHub ↗
Qwen3-ASR: Large Language Model based ASR supporting 52 languages.

Wraps the qwen-asr package's Qwen3ASRModel for use within FunASR's AutoModel interface.

Supports auto language detection, contextual recognition, and optional forced alignment

for character-level timestamps.

Requirements:

pip install qwen-asr

Models:

  • Qwen/Qwen3-ASR-0.6B (lighter, ~4GB GPU memory)
  • Qwen/Qwen3-ASR-1.7B (more accurate, ~8GB GPU memory)
📄 Source code
class Qwen3ASR(nn.Module):
    """Qwen3-ASR: Large Language Model based ASR supporting 52 languages.

    Wraps the qwen-asr package's Qwen3ASRModel for use within FunASR's AutoModel interface.
    Supports auto language detection, contextual recognition, and optional forced alignment
    for character-level timestamps.

    Requirements:
        pip install qwen-asr

    Models:
        - Qwen/Qwen3-ASR-0.6B (lighter, ~4GB GPU memory)
        - Qwen/Qwen3-ASR-1.7B (more accurate, ~8GB GPU memory)
    """

    def __init__(self, **kwargs):
        """Initialize Qwen3ASR.
        
            Args:
                **kwargs: Additional keyword arguments.
            """
        super().__init__()
        model_path = kwargs.get("model_path", kwargs.get("model", "Qwen/Qwen3-ASR-1.7B"))
        device = kwargs.get("device", "cuda:0")
        dtype = kwargs.get("dtype", "bf16")
        hub = kwargs.get("hub", "ms")
        max_new_tokens = kwargs.get("max_new_tokens", 512)
        max_inference_batch_size = kwargs.get("max_inference_batch_size", 32)
        forced_aligner = kwargs.get("forced_aligner", None)
        forced_aligner_kwargs = kwargs.get("forced_aligner_kwargs", None)
View full source on GitHub →

Methods

.forward(**kwargs) L105

Forward pass for training.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def forward(self, **kwargs):
        """Reject training-mode calls; this wrapper only implements inference.

        Args:
            **kwargs: Ignored.

        Raises:
            NotImplementedError: Always.
        """
        message = "Qwen3ASR only supports inference mode"
        raise NotImplementedError(message)
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L113

Run Qwen3-ASR speech recognition.

Args:

  • data_in — Audio input. Accepts:
  • list of file paths/URLs
  • list of (numpy_array, sample_rate) tuples
  • single numpy array or torch Tensor
  • data_lengths — Not used.
  • key (list) — Sample identifiers.
  • tokenizer — Not used (Qwen3-ASR has internal tokenizer).
  • frontend — Not used (Qwen3-ASR has internal audio processing).
  • **kwargs — Runtime parameters:
  • language (str): Language hint (e.g. "Chinese", "English") or None for auto-detect.
  • return_time_stamps (bool): Return character-level timestamps (requires forced_aligner).
  • output_timestamp (bool): Same as return_time_stamps (for pipeline compatibility).
  • context (str): Context prompt for contextual recognition.

Returns:

  • tuple — (results, meta_data) where results is list of dicts:
  • "key" (str): Sample ID
  • "text" (str): Recognized text (with punctuation)
  • "language" (str): Detected language (if available)
  • "timestamp" (list): [[start_ms, end_ms], ...] (if timestamps enabled)
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run Qwen3-ASR speech recognition.

        Args:
            data_in: Audio input. Accepts:
                - list of file paths/URLs
                - list of (numpy_array, sample_rate) tuples
                - single numpy array or torch Tensor
            data_lengths: Not used.
            key (list): Sample identifiers.
            tokenizer: Not used (Qwen3-ASR has internal tokenizer).
            frontend: Not used (Qwen3-ASR has internal audio processing).
            **kwargs: Runtime parameters:
                - language (str): Language hint (e.g. "Chinese", "English") or None for auto-detect.
                - return_time_stamps (bool): Return character-level timestamps (requires forced_aligner).
                - output_timestamp (bool): Same as return_time_stamps (for pipeline compatibility).
                - context (str): Context prompt for contextual recognition.

        Returns:
            tuple: (results, meta_data) where results is list of dicts:
                - "key" (str): Sample ID
                - "text" (str): Recognized text (with punctuation)
View full source →
method

Qwen3ASR.forward(**kwargs)

funasr.models.qwen3_asr.Qwen3ASR · View on GitHub ↗

Forward pass for training.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def forward(self, **kwargs):
        """Reject training-mode calls; this wrapper only implements inference.

        Args:
            **kwargs: Ignored.

        Raises:
            NotImplementedError: Always.
        """
        message = "Qwen3ASR only supports inference mode"
        raise NotImplementedError(message)
method

Qwen3ASR.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.qwen3_asr.Qwen3ASR · View on GitHub ↗

Run Qwen3-ASR speech recognition.

Args:

  • data_in — Audio input. Accepts:
  • list of file paths/URLs
  • list of (numpy_array, sample_rate) tuples
  • single numpy array or torch Tensor
  • data_lengths — Not used.
  • key (list) — Sample identifiers.
  • tokenizer — Not used (Qwen3-ASR has internal tokenizer).
  • frontend — Not used (Qwen3-ASR has internal audio processing).
  • **kwargs — Runtime parameters:
  • language (str): Language hint (e.g. "Chinese", "English") or None for auto-detect.
  • return_time_stamps (bool): Return character-level timestamps (requires forced_aligner).
  • output_timestamp (bool): Same as return_time_stamps (for pipeline compatibility).
  • context (str): Context prompt for contextual recognition.

Returns:

  • tuple — (results, meta_data) where results is list of dicts:
  • "key" (str): Sample ID
  • "text" (str): Recognized text (with punctuation)
  • "language" (str): Detected language (if available)
  • "timestamp" (list): [[start_ms, end_ms], ...] (if timestamps enabled)
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run Qwen3-ASR speech recognition.

        Args:
            data_in: Audio input. Accepts:
                - list of file paths/URLs
                - list of (numpy_array, sample_rate) tuples
                - single numpy array or torch Tensor
            data_lengths: Not used.
            key (list): Sample identifiers.
            tokenizer: Not used (Qwen3-ASR has internal tokenizer).
            frontend: Not used (Qwen3-ASR has internal audio processing).
            **kwargs: Runtime parameters:
                - language (str): Language hint (e.g. "Chinese", "English") or None for auto-detect.
                - return_time_stamps (bool): Return character-level timestamps (requires forced_aligner).
                - output_timestamp (bool): Same as return_time_stamps (for pipeline compatibility).
                - context (str): Context prompt for contextual recognition.

        Returns:
            tuple: (results, meta_data) where results is list of dicts:
                - "key" (str): Sample ID
                - "text" (str): Recognized text (with punctuation)
View full source on GitHub →
class

QwenAudioWarp

funasr.models.qwen_audio · View on GitHub ↗
  • Qwen-Audio: Advancing Universal Audio Understanding via Unified Large-Scale Audio-Language Models

https://arxiv.org/abs/2311.07919

Modified from https://github.com/QwenLM/Qwen-Audio

📄 Source code
class QwenAudioWarp(nn.Module):
    """
    Qwen-Audio: Advancing Universal Audio Understanding via Unified Large-Scale Audio-Language Models
    https://arxiv.org/abs/2311.07919
    Modified from https://github.com/QwenLM/Qwen-Audio
    """

    def __init__(self, *args, **kwargs):
        """Load the Qwen-Audio causal LM and its tokenizer.

        Args:
            *args: Accepted for interface compatibility; not used.
            **kwargs: Configuration. Keys read here:
                - model_path (str): Model id or local directory passed to
                  ``from_pretrained``; defaults to ``"QwenAudio"``.
        """
        super().__init__()
        # The former ``GenerationConfig`` import was never used and has been dropped.
        from transformers import AutoModelForCausalLM, AutoTokenizer

        model_or_path = kwargs.get("model_path", "QwenAudio")
        # device_map="cpu" places the loaded weights on the CPU.
        # NOTE(review): trust_remote_code=True executes Python code shipped with
        # the checkpoint; only point model_path at trusted sources.
        model = AutoModelForCausalLM.from_pretrained(
            model_or_path, device_map="cpu", trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(model_or_path, trust_remote_code=True)

        self.model = model
        self.tokenizer = tokenizer

    def forward(
        self,
    ):
View full source on GitHub →

Methods

.forward() L49

Forward pass for training.

📄 Source
    def forward(self):
        """Placeholder training forward pass: does nothing and returns ``None``."""
        return None
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L55

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        meta_data = {}
        # meta_data["batch_data_time"] = -1
        prompt = kwargs.get(
            "prompt", "<|startoftranscription|><|en|><|transcribe|><|en|><|notimestamps|><|wo_itn|>"
        )
        query = f"<audio>{data_in[0]}</audio>{prompt}"
        audio_info = self.tokenizer.process_audio(query)
        inputs = self.tokenizer(query, return_tensors="pt", audio_info=audio_info)
View full source →
method

QwenAudioWarp.forward()

funasr.models.qwen_audio.QwenAudioWarp · View on GitHub ↗

Forward pass for training.

📄 Source code
    def forward(self):
        """Placeholder training forward pass: does nothing and returns ``None``."""
        return None
method

QwenAudioWarp.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.qwen_audio.QwenAudioWarp · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        meta_data = {}
        # meta_data["batch_data_time"] = -1
        prompt = kwargs.get(
            "prompt", "<|startoftranscription|><|en|><|transcribe|><|en|><|notimestamps|><|wo_itn|>"
        )
        query = f"<audio>{data_in[0]}</audio>{prompt}"
        audio_info = self.tokenizer.process_audio(query)
        inputs = self.tokenizer(query, return_tensors="pt", audio_info=audio_info)
View full source on GitHub →
class

QwenAudioChatWarp

funasr.models.qwen_audio · View on GitHub ↗
  • QwenAudioChat — Qwen Audio Chat model wrapper.

Interactive audio chat using the Qwen-Audio-Chat model.

Supports multi-turn conversation about audio content.

📄 Source code
class QwenAudioChatWarp(nn.Module):
    """QwenAudioChat: Qwen Audio Chat model wrapper.

    Interactive audio chat using the Qwen-Audio-Chat model.
    Supports multi-turn conversation about audio content.
    """

    def __init__(self, *args, **kwargs):
        """Load the Qwen-Audio-Chat causal LM and its tokenizer.

        Based on "Qwen-Audio: Advancing Universal Audio Understanding via
        Unified Large-Scale Audio-Language Models"
        (https://arxiv.org/abs/2311.07919); modified from
        https://github.com/QwenLM/Qwen-Audio

        Args:
            *args: Accepted for interface compatibility; not used.
            **kwargs: Configuration. Keys read here:
                - model_path (str): Model id or local directory passed to
                  ``from_pretrained``; defaults to ``"QwenAudio"``.
                - bf16 (bool): Forwarded to ``from_pretrained``; default False.
                - fp16 (bool): Forwarded to ``from_pretrained``; default False.
        """
        super().__init__()
        # The former ``GenerationConfig`` import was never used and has been dropped.
        from transformers import AutoModelForCausalLM, AutoTokenizer

        model_or_path = kwargs.get("model_path", "QwenAudio")
        bf16 = kwargs.get("bf16", False)
        fp16 = kwargs.get("fp16", False)
        # device_map="cpu" places the loaded weights on the CPU.
        # NOTE(review): trust_remote_code=True executes Python code shipped with
        # the checkpoint; only point model_path at trusted sources.
        model = AutoModelForCausalLM.from_pretrained(
            model_or_path, device_map="cpu", bf16=bf16, fp16=fp16, trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(model_or_path, trust_remote_code=True)

        self.model = model
        self.tokenizer = tokenizer

    def forward(
        self,
View full source on GitHub →

Methods

.forward() L132

Forward pass for training.

📄 Source
    def forward(self):
        """Placeholder training forward pass: does nothing and returns ``None``."""
        return None
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L138

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        meta_data = {}

        prompt = kwargs.get("prompt", "what does the person say?")
        cache = kwargs.get("cache", {})
        history = cache.get("history", None)
        if data_in[0] is not None:
            # 1st dialogue turn
            query = self.tokenizer.from_list_format(
View full source →
method

QwenAudioChatWarp.forward()

funasr.models.qwen_audio.QwenAudioChatWarp · View on GitHub ↗

Forward pass for training.

📄 Source code
    def forward(self):
        """Placeholder training forward pass: does nothing and returns ``None``."""
        return None
method

QwenAudioChatWarp.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.qwen_audio.QwenAudioChatWarp · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        meta_data = {}

        prompt = kwargs.get("prompt", "what does the person say?")
        cache = kwargs.get("cache", {})
        history = cache.get("history", None)
        if data_in[0] is not None:
            # 1st dialogue turn
            query = self.tokenizer.from_list_format(
View full source on GitHub →
class

SANM

funasr.models.sanm · View on GitHub ↗
  • Authors: Zhifu Gao, Shiliang Zhang, Ming Lei, Ian McLoughlin
  • SAN-M: Memory Equipped Self-Attention for End-to-End Speech Recognition

https://arxiv.org/abs/2006.01713

📄 Source code
class SANM(Transformer):
    """SAN-M speech recognition model.

    Authors: Zhifu Gao, Shiliang Zhang, Ming Lei, Ian McLoughlin.
    Paper: "SAN-M: Memory Equipped Self-Attention for End-to-End Speech
    Recognition", https://arxiv.org/abs/2006.01713

    Construction is delegated entirely to ``Transformer``.
    """

    def __init__(self, *args, **kwargs):
        """Build the model by forwarding every argument to ``Transformer.__init__``.

        Args:
            *args: Positional arguments for ``Transformer``.
            **kwargs: Keyword arguments for ``Transformer``.
        """
        super(SANM, self).__init__(*args, **kwargs)
class

SanmKWS

funasr.models.sanm_kws · View on GitHub ↗
  • SANM-KWS: Self-Attention Neural Memory based Keyword Spotting.

Advanced keyword spotting using a self-attention mechanism for better context modeling of keyword patterns.

  • Output — {"key": str, "value": detected_keyword_info}
📄 Source code
class SanmKWS(torch.nn.Module):
    """SANM-KWS: Self-Attention Neural Memory based Keyword Spotting.

    Advanced keyword spotting using self-attention mechanism
    for better context modeling of keyword patterns.

    Output: {"key": str, "value": detected_keyword_info}
    """

    def __init__(
        self,
        specaug: Optional[str] = None,
        specaug_conf: Optional[Dict] = None,
        normalize: str = None,
        normalize_conf: Optional[Dict] = None,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        ctc: str = None,
        ctc_conf: Optional[Dict] = None,
        ctc_weight: float = 1.0,
        input_size: int = 360,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        **kwargs,
    ):

        """Initialize SanmKWS.
View full source on GitHub →

Methods

.forward(speech, speech_lengths, text, text_lengths, **kwargs) L112

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        batch_size = speech.shape[0]

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # decoder: CTC branch
        loss_ctc, cer_ctc = self._calc_ctc_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

        # Collect CTC branch stats
View full source →
.encode(speech, speech_lengths, **kwargs) L156

Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • ind — int
📄 Source
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply optional SpecAug/normalization, then run the encoder.

        Args:
            speech: Input features, (Batch, Length, ...).
            speech_lengths: Valid length of each utterance, (Batch,).
            **kwargs: Accepted for interface compatibility; nothing is read
                from it here.

        Returns:
            ``(encoder_out, encoder_out_lens)``. If the encoder's first output
            is itself a tuple, only its first element is returned.
        """
        # Feature-level preprocessing runs with autocast disabled.
        with autocast(False):
            # SpecAug is applied only while the module is in training mode.
            augment = self.specaug is not None and self.training
            if augment:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # e.g. Global-CMVN or Utterance-CMVN.
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        encoded, encoded_lens, _ = self.encoder(speech, speech_lengths)
        if isinstance(encoded, tuple):
            encoded = encoded[0]
        return encoded, encoded_lens
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L209

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        keywords = kwargs.get("keywords")
        from funasr.utils.kws_utils import KwsCtcPrefixDecoder
        self.kws_decoder = KwsCtcPrefixDecoder(
            ctc=self.ctc,
            keywords=keywords,
            token_list=tokenizer.token_list,
            seg_dict=tokenizer.seg_dict,
        )

        meta_data = {}
        if (
View full source →
.export(**kwargs) L300

Export.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def export(self, **kwargs):
        """Rebuild this model into its export form.

        When the caller does not pass ``max_seq_len``, a default of 512 is
        filled in. All keyword arguments are then forwarded to
        ``export_meta.export_rebuild_model`` together with ``model=self``.

        Args:
            **kwargs: Export options forwarded to ``export_rebuild_model``.

        Returns:
            The value returned by ``export_rebuild_model``.
        """
        from .export_meta import export_rebuild_model

        kwargs.setdefault("max_seq_len", 512)
        return export_rebuild_model(model=self, **kwargs)
method

SanmKWS.forward(speech, speech_lengths, text, text_lengths, **kwargs)

funasr.models.sanm_kws.SanmKWS · View on GitHub ↗

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        batch_size = speech.shape[0]

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # decoder: CTC branch
        loss_ctc, cer_ctc = self._calc_ctc_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

        # Collect CTC branch stats
View full source on GitHub →
method

SanmKWS.encode(speech, speech_lengths, **kwargs)

funasr.models.sanm_kws.SanmKWS · View on GitHub ↗

Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • ind — int
📄 Source code
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply optional SpecAug/normalization, then run the encoder.

        Args:
            speech: Input features, (Batch, Length, ...).
            speech_lengths: Valid length of each utterance, (Batch,).
            **kwargs: Accepted for interface compatibility; nothing is read
                from it here.

        Returns:
            ``(encoder_out, encoder_out_lens)``. If the encoder's first output
            is itself a tuple, only its first element is returned.
        """
        # Feature-level preprocessing runs with autocast disabled.
        with autocast(False):
            # SpecAug is applied only while the module is in training mode.
            augment = self.specaug is not None and self.training
            if augment:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # e.g. Global-CMVN or Utterance-CMVN.
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        encoded, encoded_lens, _ = self.encoder(speech, speech_lengths)
        if isinstance(encoded, tuple):
            encoded = encoded[0]
        return encoded, encoded_lens
method

SanmKWS.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.sanm_kws.SanmKWS · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        keywords = kwargs.get("keywords")
        from funasr.utils.kws_utils import KwsCtcPrefixDecoder
        self.kws_decoder = KwsCtcPrefixDecoder(
            ctc=self.ctc,
            keywords=keywords,
            token_list=tokenizer.token_list,
            seg_dict=tokenizer.seg_dict,
        )

        meta_data = {}
        if (
View full source on GitHub →
method

SanmKWS.export(**kwargs)

funasr.models.sanm_kws.SanmKWS · View on GitHub ↗

Export.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def export(self, **kwargs):
        """Rebuild this model into its export form.

        When the caller does not pass ``max_seq_len``, a default of 512 is
        filled in. All keyword arguments are then forwarded to
        ``export_meta.export_rebuild_model`` together with ``model=self``.

        Args:
            **kwargs: Export options forwarded to ``export_rebuild_model``.

        Returns:
            The value returned by ``export_rebuild_model``.
        """
        from .export_meta import export_rebuild_model

        kwargs.setdefault("max_seq_len", 512)
        return export_rebuild_model(model=self, **kwargs)
class

SanmKWSStreaming

funasr.models.sanm_kws_streaming · View on GitHub ↗
  • Author — Speech Lab of DAMO Academy, Alibaba Group
  • Paraformer — Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition

https://arxiv.org/abs/2206.08317

📄 Source code
class SanmKWSStreaming(SanmKWS):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        """Initialize SanmKWSStreaming.
        
            Args:
                *args: Variable positional arguments.
                **kwargs: Additional keyword arguments.
            """
        super().__init__(*args, **kwargs)

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
View full source on GitHub →

Methods

.forward(speech, speech_lengths, text, text_lengths, **kwargs) L63

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        decoding_ind = kwargs.get("decoding_ind")
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # Encoder
        if hasattr(self.encoder, "overlap_chunk_cls"):
            ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        else:
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

View full source →
.encode_chunk(speech, speech_lengths, cache, **kwargs) L119

Encode chunk.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source
    def encode_chunk(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        cache: dict = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode chunk.
        
            Args:
                speech: Speech audio tensor, shape (batch, time).
                speech_lengths: Length of each speech sample.
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
View full source →
.init_cache(cache, **kwargs) L161

Init cache.

Args:

  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source
    def init_cache(self, cache: dict = None, **kwargs):
        """Init cache.
        
            Args:
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        chunk_size = kwargs.get("chunk_size", [0, 10, 5])
        encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
        decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
        batch_size = 1

        enc_output_size = kwargs["encoder_conf"]["output_size"]
        feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
        cache_encoder = {
            "start_idx": 0,
            "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
            "cif_alphas": torch.zeros((batch_size, 1)),
            "encoder_out": None,
            "encoder_out_lens": None,
            "chunk_size": chunk_size,
            "encoder_chunk_look_back": encoder_chunk_look_back,
            "last_chunk": False,
            "opt": None,
            "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
            "tail_chunk": False,
        }
        cache["encoder"] = cache_encoder
View full source →
.generate_chunk(speech, speech_lengths, key, tokenizer, frontend, **kwargs) L204

Generate chunk.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def generate_chunk(
        self,
        speech,
        speech_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Generate chunk.
        
            Args:
                speech: Speech audio tensor, shape (batch, time).
                speech_lengths: Length of each speech sample.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        cache = kwargs.get("cache", {})
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Encoder
        is_final = kwargs.get("is_final", False)
        encoder_out, encoder_out_lens = self.encode_chunk(
            speech, speech_lengths, cache=cache, is_final=is_final
        )
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, cache, **kwargs) L287

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        cache: dict = None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        keywords = kwargs.get("keywords")
        from funasr.utils.kws_utils import KwsCtcPrefixDecoder
        self.kws_decoder = KwsCtcPrefixDecoder(
            ctc=self.ctc,
            keywords=keywords,
            token_list=tokenizer.token_list,
            seg_dict=tokenizer.seg_dict,
View full source →
.export(**kwargs) L490

Export.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def export(self, **kwargs):
        """Build the export-ready (rebuilt) model(s) for this streaming KWS model.

        All keywords are forwarded unchanged to ``export_rebuild_model``.

        NOTE(review): unlike ``SanmKWS.export``, this override does not inject a
        default ``max_seq_len``. Confirm this is intentional for the streaming
        export path.

        Args:
            **kwargs: Options forwarded to ``export_rebuild_model``.

        Returns:
            Whatever ``export_rebuild_model`` returns.
        """
        from .export_meta import export_rebuild_model

        return export_rebuild_model(model=self, **kwargs)
method

SanmKWSStreaming.forward(speech, speech_lengths, text, text_lengths, **kwargs)

funasr.models.sanm_kws_streaming.SanmKWSStreaming · View on GitHub ↗

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        decoding_ind = kwargs.get("decoding_ind")
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # Encoder
        if hasattr(self.encoder, "overlap_chunk_cls"):
            ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        else:
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

View full source on GitHub →
method

SanmKWSStreaming.encode_chunk(speech, speech_lengths, cache, **kwargs)

funasr.models.sanm_kws_streaming.SanmKWSStreaming · View on GitHub ↗

Encode chunk.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def encode_chunk(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        cache: dict = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode chunk.
        
            Args:
                speech: Speech audio tensor, shape (batch, time).
                speech_lengths: Length of each speech sample.
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
View full source on GitHub →
method

SanmKWSStreaming.init_cache(cache, **kwargs)

funasr.models.sanm_kws_streaming.SanmKWSStreaming · View on GitHub ↗

Init cache.

Args:

  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def init_cache(self, cache: dict = None, **kwargs):
        """Init cache.
        
            Args:
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        chunk_size = kwargs.get("chunk_size", [0, 10, 5])
        encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
        decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
        batch_size = 1

        enc_output_size = kwargs["encoder_conf"]["output_size"]
        feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
        cache_encoder = {
            "start_idx": 0,
            "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
            "cif_alphas": torch.zeros((batch_size, 1)),
            "encoder_out": None,
            "encoder_out_lens": None,
            "chunk_size": chunk_size,
            "encoder_chunk_look_back": encoder_chunk_look_back,
            "last_chunk": False,
            "opt": None,
            "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
            "tail_chunk": False,
        }
        cache["encoder"] = cache_encoder
View full source on GitHub →
method

SanmKWSStreaming.generate_chunk(speech, speech_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.sanm_kws_streaming.SanmKWSStreaming · View on GitHub ↗

Generate chunk.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def generate_chunk(
        self,
        speech,
        speech_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Generate chunk.
        
            Args:
                speech: Speech audio tensor, shape (batch, time).
                speech_lengths: Length of each speech sample.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        cache = kwargs.get("cache", {})
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Encoder
        is_final = kwargs.get("is_final", False)
        encoder_out, encoder_out_lens = self.encode_chunk(
            speech, speech_lengths, cache=cache, is_final=is_final
        )
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
View full source on GitHub →
method

SanmKWSStreaming.inference(data_in, data_lengths, key, tokenizer, frontend, cache, **kwargs)

funasr.models.sanm_kws_streaming.SanmKWSStreaming · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        cache: dict = None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        keywords = kwargs.get("keywords")
        from funasr.utils.kws_utils import KwsCtcPrefixDecoder
        self.kws_decoder = KwsCtcPrefixDecoder(
            ctc=self.ctc,
            keywords=keywords,
            token_list=tokenizer.token_list,
            seg_dict=tokenizer.seg_dict,
View full source on GitHub →
method

SanmKWSStreaming.export(**kwargs)

funasr.models.sanm_kws_streaming.SanmKWSStreaming · View on GitHub ↗

Export.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def export(self, **kwargs):
        """Build the export-ready (rebuilt) model(s) for this streaming KWS model.

        All keywords are forwarded unchanged to ``export_rebuild_model``.

        NOTE(review): unlike ``SanmKWS.export``, this override does not inject a
        default ``max_seq_len``. Confirm this is intentional for the streaming
        export path.

        Args:
            **kwargs: Options forwarded to ``export_rebuild_model``.

        Returns:
            Whatever ``export_rebuild_model`` returns.
        """
        from .export_meta import export_rebuild_model

        return export_rebuild_model(model=self, **kwargs)
class

SCAMA

funasr.models.scama · View on GitHub ↗
  • SCAMA — Streaming Chunk-Aware Multi-head Attention ASR.

Streaming ASR using chunk-based encoder with configurable latency.

Supports 2-pass decoding for accuracy refinement.

  • Output — {"key": str, "text": str}
  • Author — Shiliang Zhang, Zhifu Gao, Haoneng Luo, Ming Lei, Jie Gao, Zhijie Yan, Lei Xie
  • SCAMA — Streaming chunk-aware multihead attention for online end-to-end speech recognition

https://arxiv.org/abs/2006.01712

📄 Source code
class SCAMA(nn.Module):
    """SCAMA: Streaming Chunk-Aware Multi-head Attention ASR.

    Streaming ASR using chunk-based encoder with configurable latency.
    Supports 2-pass decoding for accuracy refinement.

    Output: {"key": str, "text": str}

    Author: Shiliang Zhang, Zhifu Gao, Haoneng Luo, Ming Lei, Jie Gao, Zhijie Yan, Lei Xie
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    https://arxiv.org/abs/2006.01712
    """

    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        decoder: str = None,
        decoder_conf: dict = None,
        ctc: str = None,
        ctc_conf: dict = None,
        ctc_weight: float = 0.5,
        predictor: str = None,
        predictor_conf: dict = None,
        predictor_bias: int = 0,
        predictor_weight: float = 0.0,
View full source on GitHub →

Methods

.forward(speech, speech_lengths, text, text_lengths, **kwargs) L200

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """

        decoding_ind = kwargs.get("decoding_ind")
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # Encoder
        ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)

        loss_ctc, cer_ctc = None, None
        loss_pre = None
View full source →
.encode(speech, speech_lengths, **kwargs) L277

Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • ind — int (accepted through **kwargs, but not forwarded to the encoder by this method)
📄 Source
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply feature augmentation/normalization, then run the encoder.

        Note that this method is used by asr_inference.py.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            **kwargs: Extra options. ``ind`` (int) is accepted here because
                ``forward`` passes it, but it is NOT forwarded to
                ``self.encoder``. NOTE(review): confirm that ignoring ``ind``
                is intended.

        Returns:
            Tuple of (encoder_out, encoder_out_lens).
        """
        feats, feats_lens = speech, speech_lengths

        # Feature pre-processing runs with mixed precision disabled.
        with autocast(False):
            # SpecAugment is only applied in training mode.
            apply_specaug = self.training and self.specaug is not None
            if apply_specaug:
                feats, feats_lens = self.specaug(feats, feats_lens)

            # Feature normalization, e.g. Global-CMVN or Utterance-CMVN.
            if self.normalize is not None:
                feats, feats_lens = self.normalize(feats, feats_lens)

        # The encoder returns three values; the third one is not needed here.
        enc_out, enc_out_lens, _ = self.encoder(feats, feats_lens)

        # Some encoders wrap their output in a tuple; keep the first element.
        if isinstance(enc_out, tuple):
            enc_out = enc_out[0]

        return enc_out, enc_out_lens
.encode_chunk(speech, speech_lengths, cache, **kwargs) L306

Encode chunk.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source
    def encode_chunk(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        cache: dict = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode chunk.
        
            Args:
                speech: Speech audio tensor, shape (batch, time).
                speech_lengths: Length of each speech sample.
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):

            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
View full source →
.calc_predictor_chunk(encoder_out, encoder_out_lens, cache, **kwargs) L348

Run the predictor on one streaming chunk, using the encoder state stored in `cache["encoder"]`.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source
    def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs):
        """Run the predictor on a single streaming chunk.

        Args:
            encoder_out: Encoder output for the current chunk.
            encoder_out_lens: Encoder output lengths. Not used by this method;
                kept for interface compatibility.
            cache: Streaming state dict. ``cache["encoder"]`` is handed to the
                predictor, so a dict containing the ``"encoder"`` key is
                required even though the parameter defaults to ``None``.
            **kwargs: ``is_final`` (bool, default ``False``) marks the last
                chunk of the stream.

        Returns:
            Whatever ``self.predictor.forward_chunk`` returns.
        """
        final_chunk = kwargs.get("is_final", False)
        encoder_state = cache["encoder"]
        return self.predictor.forward_chunk(encoder_out, encoder_state, is_final=final_chunk)
.calc_predictor_mask(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) L462

Calc predictor mask.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • ys_pad — Optional padded target sequences (defaults to None).
  • ys_pad_lens — Lengths of ys_pad.
📄 Source
    def calc_predictor_mask(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor = None,
        ys_pad_lens: torch.Tensor = None,
    ):
        # ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        # ys_in_lens = ys_pad_lens + 1
        """Calc predictor mask.
        
            Args:
                encoder_out: Encoder output tensor.
                encoder_out_lens: Encoder output lengths.
                ys_pad: TODO.
                ys_pad_lens: Lengths of ys_pad.
            """
        ys_out_pad, ys_in_lens = None, None

        encoder_out_mask = sequence_mask(
            encoder_out_lens,
            maxlen=encoder_out.size(1),
            dtype=encoder_out.dtype,
            device=encoder_out.device,
        )[:, None, :]
        mask_chunk_predictor = None

        mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
            None, device=encoder_out.device, batch_size=encoder_out.size(0)
        )
View full source →
.init_beam_search(**kwargs) L543

Init beam search.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def init_beam_search(
        self,
        **kwargs,
    ):

        """Init beam search.
        
            Args:
                **kwargs: Additional keyword arguments.
            """
        from funasr.models.scama.beam_search import BeamSearchScamaStreaming

        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        # 1. Build ASR model
        scorers = {}

        if self.ctc != None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            decoder=self.decoder,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
View full source →
.generate_chunk(speech, speech_lengths, key, tokenizer, frontend, **kwargs) L599

Generate chunk.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def generate_chunk(
        self,
        speech,
        speech_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Generate chunk.
        
            Args:
                speech: Speech audio tensor, shape (batch, time).
                speech_lengths: Length of each speech sample.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        cache = kwargs.get("cache", {})
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Encoder
        encoder_out, encoder_out_lens = self.encode_chunk(
            speech, speech_lengths, cache=cache, is_final=kwargs.get("is_final", False)
        )
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        if "running_hyps" not in cache:
View full source →
.init_cache(cache, **kwargs) L694

Init cache.

Args:

  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source
    def init_cache(self, cache: dict = None, **kwargs):
        """Init cache.
        
            Args:
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        device = kwargs.get("device", "cuda")

        chunk_size = kwargs.get("chunk_size", [0, 10, 5])
        encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
        decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
        batch_size = 1

        enc_output_size = kwargs["encoder_conf"]["output_size"]
        feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]

        cache_encoder = {
            "start_idx": 0,
            "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)).to(device=device),
            "cif_alphas": torch.zeros((batch_size, 1)).to(device=device),
            "chunk_size": chunk_size,
            "encoder_chunk_look_back": encoder_chunk_look_back,
            "last_chunk": False,
            "opt": None,
            "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)).to(
                device=device
            ),
View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, cache, **kwargs) L741

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        cache: dict = None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}

        # init beamsearch
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = (
            kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        )

View full source →
method

SCAMA.forward(speech, speech_lengths, text, text_lengths, **kwargs)

funasr.models.scama.SCAMA · View on GitHub ↗

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """

        decoding_ind = kwargs.get("decoding_ind")
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # Encoder
        ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)

        loss_ctc, cer_ctc = None, None
        loss_pre = None
View full source on GitHub →
method

SCAMA.encode(speech, speech_lengths, **kwargs)

funasr.models.scama.SCAMA · View on GitHub ↗

Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • ind — int (accepted through **kwargs, but not forwarded to the encoder by this method)
📄 Source code
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply feature augmentation/normalization, then run the encoder.

        Note that this method is used by asr_inference.py.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            **kwargs: Extra options. ``ind`` (int) is accepted here because
                ``forward`` passes it, but it is NOT forwarded to
                ``self.encoder``. NOTE(review): confirm that ignoring ``ind``
                is intended.

        Returns:
            Tuple of (encoder_out, encoder_out_lens).
        """
        feats, feats_lens = speech, speech_lengths

        # Feature pre-processing runs with mixed precision disabled.
        with autocast(False):
            # SpecAugment is only applied in training mode.
            apply_specaug = self.training and self.specaug is not None
            if apply_specaug:
                feats, feats_lens = self.specaug(feats, feats_lens)

            # Feature normalization, e.g. Global-CMVN or Utterance-CMVN.
            if self.normalize is not None:
                feats, feats_lens = self.normalize(feats, feats_lens)

        # The encoder returns three values; the third one is not needed here.
        enc_out, enc_out_lens, _ = self.encoder(feats, feats_lens)

        # Some encoders wrap their output in a tuple; keep the first element.
        if isinstance(enc_out, tuple):
            enc_out = enc_out[0]

        return enc_out, enc_out_lens
method

SCAMA.encode_chunk(speech, speech_lengths, cache, **kwargs)

funasr.models.scama.SCAMA · View on GitHub ↗

Encode chunk.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def encode_chunk(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        cache: dict = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode chunk.
        
            Args:
                speech: Speech audio tensor, shape (batch, time).
                speech_lengths: Length of each speech sample.
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):

            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
View full source on GitHub →
method

SCAMA.calc_predictor_chunk(encoder_out, encoder_out_lens, cache, **kwargs)

funasr.models.scama.SCAMA · View on GitHub ↗

Calc predictor chunk.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs):
        """Run the predictor on one streaming chunk of encoder output.

        Args:
            encoder_out: Encoder output tensor for the current chunk.
            encoder_out_lens: Encoder output lengths (not read here).
            cache: Streaming state dict. Required despite the ``None`` default:
                its ``"encoder"`` entry is passed to ``predictor.forward_chunk``.
            **kwargs: ``is_final`` (default False) flags the last chunk; other
                keys are ignored.

        Returns:
            Whatever ``self.predictor.forward_chunk`` returns.

        Raises:
            ValueError: If ``cache`` is None or has no ``"encoder"`` entry.
        """
        # Fail with an actionable message instead of an opaque
        # "'NoneType' object is not subscriptable" / KeyError.
        if cache is None or "encoder" not in cache:
            raise ValueError(
                "calc_predictor_chunk requires a streaming cache dict with an 'encoder' entry"
            )
        is_final = kwargs.get("is_final", False)

        return self.predictor.forward_chunk(encoder_out, cache["encoder"], is_final=is_final)
method

SCAMA.calc_predictor_mask(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

funasr.models.scama.SCAMA · View on GitHub ↗

Calc predictor mask.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • ys_pad — TODO.
  • ys_pad_lens — Lengths of ys_pad.
📄 Source code
    def calc_predictor_mask(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor = None,
        ys_pad_lens: torch.Tensor = None,
    ):
        # ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        # ys_in_lens = ys_pad_lens + 1
        """Calc predictor mask.
        
            Args:
                encoder_out: Encoder output tensor.
                encoder_out_lens: Encoder output lengths.
                ys_pad: TODO.
                ys_pad_lens: Lengths of ys_pad.
            """
        ys_out_pad, ys_in_lens = None, None

        encoder_out_mask = sequence_mask(
            encoder_out_lens,
            maxlen=encoder_out.size(1),
            dtype=encoder_out.dtype,
            device=encoder_out.device,
        )[:, None, :]
        mask_chunk_predictor = None

        mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
            None, device=encoder_out.device, batch_size=encoder_out.size(0)
        )
View full source on GitHub →
method

SCAMA.init_beam_search(**kwargs)

funasr.models.scama.SCAMA · View on GitHub ↗

Init beam search.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def init_beam_search(
        self,
        **kwargs,
    ):

        """Init beam search.
        
            Args:
                **kwargs: Additional keyword arguments.
            """
        from funasr.models.scama.beam_search import BeamSearchScamaStreaming

        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        # 1. Build ASR model
        scorers = {}

        if self.ctc != None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            decoder=self.decoder,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
View full source on GitHub →
method

SCAMA.generate_chunk(speech, speech_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.scama.SCAMA · View on GitHub ↗

Generate chunk.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def generate_chunk(
        self,
        speech,
        speech_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Generate chunk.
        
            Args:
                speech: Speech audio tensor, shape (batch, time).
                speech_lengths: Length of each speech sample.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        cache = kwargs.get("cache", {})
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Encoder
        encoder_out, encoder_out_lens = self.encode_chunk(
            speech, speech_lengths, cache=cache, is_final=kwargs.get("is_final", False)
        )
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        if "running_hyps" not in cache:
View full source on GitHub →
method

SCAMA.init_cache(cache, **kwargs)

funasr.models.scama.SCAMA · View on GitHub ↗

Init cache.

Args:

  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def init_cache(self, cache: dict = None, **kwargs):
        """Init cache.
        
            Args:
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        device = kwargs.get("device", "cuda")

        chunk_size = kwargs.get("chunk_size", [0, 10, 5])
        encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
        decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
        batch_size = 1

        enc_output_size = kwargs["encoder_conf"]["output_size"]
        feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]

        cache_encoder = {
            "start_idx": 0,
            "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)).to(device=device),
            "cif_alphas": torch.zeros((batch_size, 1)).to(device=device),
            "chunk_size": chunk_size,
            "encoder_chunk_look_back": encoder_chunk_look_back,
            "last_chunk": False,
            "opt": None,
            "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)).to(
                device=device
            ),
View full source on GitHub →
method

SCAMA.inference(data_in, data_lengths, key, tokenizer, frontend, cache, **kwargs)

funasr.models.scama.SCAMA · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        cache: dict = None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}

        # init beamsearch
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = (
            kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        )

View full source on GitHub →
class

SeacoParaformer

funasr.models.seaco_paraformer · View on GitHub ↗
SeACo-Paraformer: Semantic-Aware Contextual Paraformer.

The recommended Chinese ASR model. Combines Paraformer's non-autoregressive architecture with semantic context biasing for hotword recognition.

Registered as 'paraformer-zh' alias.

  • Output — {"key": str, "text": str, "timestamp": [[start_ms, end_ms], ...]}
📄 Source code
class SeacoParaformer(BiCifParaformer, Paraformer):
    """SeACo-Paraformer: Semantic-Aware Contextual Paraformer.

    The recommended Chinese ASR model. Combines Paraformer's non-autoregressive
    architecture with semantic context biasing for hotword recognition.

    Registered as 'paraformer-zh' alias.
    Output: {"key": str, "text": str, "timestamp": [[start_ms, end_ms], ...]}
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        """Initialize SeacoParaformer.
        
            Args:
                *args: Variable positional arguments.
                **kwargs: Additional keyword arguments.
            """
        super().__init__(*args, **kwargs)

        self.inner_dim = kwargs.get("inner_dim", 256)
        self.bias_encoder_type = kwargs.get("bias_encoder_type", "lstm")
        bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)
        bias_encoder_bid = kwargs.get("bias_encoder_bid", False)
        seaco_lsm_weight = kwargs.get("seaco_lsm_weight", 0.0)
        seaco_length_normalized_loss = kwargs.get("seaco_length_normalized_loss", True)

View full source on GitHub →

Methods

.forward(speech, speech_lengths, text, text_lengths, **kwargs) L122

Frontend + Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        # Check that batch_size is unified
        assert (
            speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)

        hotword_pad = kwargs.get("hotword_pad")
        hotword_lengths = kwargs.get("hotword_lengths")
        seaco_label_pad = kwargs.get("seaco_label_pad")
        if len(hotword_lengths.size()) > 1:
            hotword_lengths = hotword_lengths[:, 0]
View full source →
.calc_predictor(encoder_out, encoder_out_lens) L202

Calc predictor.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
📄 Source
    def calc_predictor(self, encoder_out, encoder_out_lens):
        """Run ``self.predictor`` on the encoder output and keep its first four results.

        Args:
            encoder_out: Encoder output tensor; dimension 1 is the time axis
                used as the mask width.
            encoder_out_lens: Valid length of each sequence in ``encoder_out``.

        Returns:
            The first four elements of the predictor's output.
        """
        num_frames = encoder_out.size(1)
        # make_pad_mask is expected to flag padded positions; inverting it keeps
        # the valid frames, and the inserted axis presumably yields
        # (Batch, 1, num_frames) — TODO confirm against make_pad_mask.
        padding = make_pad_mask(encoder_out_lens, maxlen=num_frames)
        valid_mask = (~padding[:, None, :]).to(encoder_out.device)
        # No targets are supplied to the predictor (second argument is None).
        outputs = self.predictor(encoder_out, None, valid_mask, ignore_id=self.ignore_id)
        return outputs[:4]
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L422

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        # init beamsearch
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = (
            kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        )
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)
        meta_data = {}
View full source →
.generate_hotwords_list(hotword_list_or_file, tokenizer, frontend) L583

Generate hotwords list.

Args:

  • hotword_list_or_file — TODO.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
📄 Source
    def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):

        """Generate hotwords list.
        
            Args:
                hotword_list_or_file: TODO.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
            """
        def seg_tokenize(txt, seg_dict):
            """Seg tokenize.
            
                Args:
                    txt: TODO.
                    seg_dict: TODO.
                """
            pattern = re.compile(r"^[\u4E00-\u9FA50-9]+$")
            out_txt = ""
            for word in txt:
                word = word.lower()
                if word in seg_dict:
                    out_txt += seg_dict[word] + " "
                else:
                    if pattern.match(word):
                        for char in word:
                            if char in seg_dict:
                                out_txt += seg_dict[char] + " "
                            else:
                                out_txt += "<unk>" + " "
                    else:
View full source →
.export(**kwargs) L692

Export.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def export(
        self,
        **kwargs,
    ):
        """Rebuild this model for export via ``export_meta.export_rebuild_model``.

        Args:
            **kwargs: Forwarded to ``export_rebuild_model``; ``max_seq_len``
                defaults to 512 when the caller does not supply it.

        Returns:
            Whatever ``export_rebuild_model`` returns.
        """
        from .export_meta import export_rebuild_model

        kwargs.setdefault("max_seq_len", 512)
        return export_rebuild_model(model=self, **kwargs)
method

SeacoParaformer.forward(speech, speech_lengths, text, text_lengths, **kwargs)

funasr.models.seaco_paraformer.SeacoParaformer · View on GitHub ↗

Frontend + Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        # Check that batch_size is unified
        assert (
            speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)

        hotword_pad = kwargs.get("hotword_pad")
        hotword_lengths = kwargs.get("hotword_lengths")
        seaco_label_pad = kwargs.get("seaco_label_pad")
        if len(hotword_lengths.size()) > 1:
            hotword_lengths = hotword_lengths[:, 0]
View full source on GitHub →
method

SeacoParaformer.calc_predictor(encoder_out, encoder_out_lens)

funasr.models.seaco_paraformer.SeacoParaformer · View on GitHub ↗

Calc predictor.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
📄 Source code
    def calc_predictor(self, encoder_out, encoder_out_lens):
        """Run ``self.predictor`` on the encoder output and keep its first four results.

        Args:
            encoder_out: Encoder output tensor; dimension 1 is the time axis
                used as the mask width.
            encoder_out_lens: Valid length of each sequence in ``encoder_out``.

        Returns:
            The first four elements of the predictor's output.
        """
        num_frames = encoder_out.size(1)
        # make_pad_mask is expected to flag padded positions; inverting it keeps
        # the valid frames, and the inserted axis presumably yields
        # (Batch, 1, num_frames) — TODO confirm against make_pad_mask.
        padding = make_pad_mask(encoder_out_lens, maxlen=num_frames)
        valid_mask = (~padding[:, None, :]).to(encoder_out.device)
        # No targets are supplied to the predictor (second argument is None).
        outputs = self.predictor(encoder_out, None, valid_mask, ignore_id=self.ignore_id)
        return outputs[:4]
method

SeacoParaformer.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.seaco_paraformer.SeacoParaformer · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        # init beamsearch
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = (
            kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        )
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)
        meta_data = {}
View full source on GitHub →
method

SeacoParaformer.generate_hotwords_list(hotword_list_or_file, tokenizer, frontend)

funasr.models.seaco_paraformer.SeacoParaformer · View on GitHub ↗

Generate hotwords list.

Args:

  • hotword_list_or_file — TODO.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
📄 Source code
    def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):

        """Generate hotwords list.
        
            Args:
                hotword_list_or_file: TODO.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
            """
        def seg_tokenize(txt, seg_dict):
            """Seg tokenize.
            
                Args:
                    txt: TODO.
                    seg_dict: TODO.
                """
            pattern = re.compile(r"^[\u4E00-\u9FA50-9]+$")
            out_txt = ""
            for word in txt:
                word = word.lower()
                if word in seg_dict:
                    out_txt += seg_dict[word] + " "
                else:
                    if pattern.match(word):
                        for char in word:
                            if char in seg_dict:
                                out_txt += seg_dict[char] + " "
                            else:
                                out_txt += "<unk>" + " "
                    else:
View full source on GitHub →
method

SeacoParaformer.export(**kwargs)

funasr.models.seaco_paraformer.SeacoParaformer · View on GitHub ↗

Export.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def export(
        self,
        **kwargs,
    ):
        """Rebuild this model for export via ``export_meta.export_rebuild_model``.

        Args:
            **kwargs: Forwarded to ``export_rebuild_model``; ``max_seq_len``
                defaults to 512 when the caller does not supply it.

        Returns:
            Whatever ``export_rebuild_model`` returns.
        """
        from .export_meta import export_rebuild_model

        kwargs.setdefault("max_seq_len", 512)
        return export_rebuild_model(model=self, **kwargs)
function

load_seg_dict(seg_dict_file)

funasr.models.seaco_paraformer · View on GitHub ↗

Load seg dict.

Args:

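  • seg_dict_file — Path to the segmentation dictionary file (UTF-8). Each line holds a key followed by zero or more whitespace-separated units, which are joined with single spaces to form the stored value.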
📄 Source code
def load_seg_dict(seg_dict_file):
    """Load a segmentation dictionary from a UTF-8 text file.

    Each non-blank line has the form ``key unit1 unit2 ...``. The units are
    joined with single spaces to form the value stored under ``key``, and a
    line holding only a key maps it to ``""``. Blank or whitespace-only lines
    are skipped. If a key appears more than once, the last occurrence wins.

    Args:
        seg_dict_file: Path to the dictionary file (``str`` or ``os.PathLike``).

    Returns:
        dict mapping each key to its space-joined segmentation string.

    Raises:
        TypeError: If ``seg_dict_file`` is not a ``str`` or path-like object.
    """
    import os

    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not isinstance(seg_dict_file, (str, os.PathLike)):
        raise TypeError(
            f"seg_dict_file must be a str or os.PathLike, got {type(seg_dict_file).__name__}"
        )
    seg_dict = {}
    with open(seg_dict_file, "r", encoding="utf8") as f:
        # Stream line by line rather than materializing the whole file.
        for line in f:
            fields = line.split()
            if not fields:
                # Blank lines used to raise IndexError on fields[0].
                continue
            key, *units = fields
            seg_dict[key] = " ".join(units)
    return seg_dict
function

sequence_mask(lengths, maxlen, dtype, device)

funasr.models.sense_voice · View on GitHub ↗

Sequence mask: entry [..., j] of the result is 1 (True) when j < lengths[...], and 0 otherwise.

Args:

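  • lengths — Tensor of sequence lengths.
  • maxlen — Width of the mask (last dimension); defaults to lengths.max().
  • dtype — dtype of the returned mask (default torch.float32).
  • device — Target device ("cuda:0", "cpu", etc.); when omitted the mask stays on lengths.device.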
📄 Source code
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    """Build a mask marking the valid (non-padded) positions of each sequence.

    Entry ``[..., j]`` of the result is 1 (``True`` for ``torch.bool``) when
    ``j < lengths[...]`` and 0 otherwise.

    Args:
        lengths: Tensor of sequence lengths, e.g. shape (Batch,).
        maxlen: Width of the mask (last dimension). Defaults to ``lengths.max()``.
        dtype: dtype of the returned mask. Defaults to ``torch.float32``.
        device: Optional target device ("cuda:0", "cpu", etc.). When None the
            mask stays on ``lengths.device``.

    Returns:
        Tensor of shape ``lengths.shape + (maxlen,)`` with dtype ``dtype``.
    """
    if maxlen is None:
        maxlen = lengths.max()
    # Allocate the position index directly on lengths' device instead of
    # building it on the CPU and copying it over on every call.
    positions = torch.arange(0, maxlen, 1, device=lengths.device)
    # Broadcast (..., 1) lengths against (maxlen,) positions.
    mask = (positions < lengths.unsqueeze(-1)).detach()
    mask = mask.type(dtype)

    return mask.to(device) if device is not None else mask
class

SenseVoiceEncoderSmall

funasr.models.sense_voice · View on GitHub ↗
  • Author — Speech Lab of DAMO Academy, Alibaba Group
  • SCAMA — Streaming chunk-aware multihead attention for online end-to-end speech recognition

https://arxiv.org/abs/2006.01713

📄 Source code
class SenseVoiceEncoderSmall(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        tp_blocks: int = 0,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        stochastic_depth_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
        **kwargs,
View full source on GitHub →

Methods

.output_size() L619

Output size.

📄 Source
    def output_size(self) -> int:
        """Return the output dimension recorded in ``self._output_size``."""
        dim = self._output_size
        return dim
.forward(xs_pad, ilens) L623

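Run the encoder: build a padding mask from ilens, scale xs_pad by sqrt(output_size) and embed it, then apply the encoder layer stacks.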

📄 Source
    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
    ):
        """Embed positions in tensor."""
        maxlen = xs_pad.shape[1]
        masks = sequence_mask(ilens, maxlen=maxlen, device=ilens.device)[:, None, :]

        xs_pad *= self.output_size() ** 0.5

        xs_pad = self.embed(xs_pad)

        # forward encoder1
        for layer_idx, encoder_layer in enumerate(self.encoders0):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        for layer_idx, encoder_layer in enumerate(self.encoders):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        xs_pad = self.after_norm(xs_pad)

        # forward encoder2
        olens = masks.squeeze(1).sum(1).int()

        for layer_idx, encoder_layer in enumerate(self.tp_encoders):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
View full source →
method

SenseVoiceEncoderSmall.output_size()

funasr.models.sense_voice.SenseVoiceEncoderSmall · View on GitHub ↗

Output size.

📄 Source code
    def output_size(self) -> int:
        """Return the output dimension recorded in ``self._output_size``."""
        dim = self._output_size
        return dim
method

SenseVoiceEncoderSmall.forward(xs_pad, ilens)

funasr.models.sense_voice.SenseVoiceEncoderSmall · View on GitHub ↗

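Run the encoder: build a padding mask from ilens, scale xs_pad by sqrt(output_size) and embed it, then apply the encoder layer stacks.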

📄 Source code
    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
    ):
        """Embed positions in tensor."""
        maxlen = xs_pad.shape[1]
        masks = sequence_mask(ilens, maxlen=maxlen, device=ilens.device)[:, None, :]

        xs_pad *= self.output_size() ** 0.5

        xs_pad = self.embed(xs_pad)

        # forward encoder1
        for layer_idx, encoder_layer in enumerate(self.encoders0):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        for layer_idx, encoder_layer in enumerate(self.encoders):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        xs_pad = self.after_norm(xs_pad)

        # forward encoder2
        olens = masks.squeeze(1).sum(1).int()

        for layer_idx, encoder_layer in enumerate(self.tp_encoders):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
View full source on GitHub →
class

SenseVoiceSmall

funasr.models.sense_voice · View on GitHub ↗
CTC-attention hybrid Encoder-Decoder model
📄 Source code
class SenseVoiceSmall(nn.Module):
    """CTC-attention hybrid Encoder-Decoder model"""

    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        ctc_conf: dict = None,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        length_normalized_loss: bool = False,
        **kwargs,
    ):

        """Initialize SenseVoiceSmall.
        
            Args:
                specaug: TODO.
                specaug_conf: Configuration dict for specaug.
                normalize: TODO.
                normalize_conf: Configuration dict for normalize.
                encoder: TODO.
View full source on GitHub →

Methods

.from_pretrained(model, **kwargs) L754

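Build a model with `AutoModel.build_model` (passing `trust_remote_code=True`) and return the `(model, kwargs)` pair it produces.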

Args:

  • model — Model name (annotated as `str`), forwarded to `AutoModel.build_model`.
  • **kwargs — Additional keyword arguments.
📄 Source
    def from_pretrained(model: str = None, **kwargs):
        """Build a model from a pretrained model name.

        Args:
            model: Model name (annotated as ``str``) forwarded to
                ``AutoModel.build_model``.
            **kwargs: Additional keyword arguments forwarded to
                ``AutoModel.build_model``. ``trust_remote_code`` defaults to
                ``True`` but may be overridden by the caller.

        Returns:
            The ``(model, kwargs)`` tuple returned by ``AutoModel.build_model``.
        """
        from funasr import AutoModel

        # Passing ``trust_remote_code=True`` alongside ``**kwargs`` raised
        # ``TypeError: got multiple values for keyword argument`` whenever the
        # caller also supplied ``trust_remote_code``; treat it as a default.
        # NOTE(review): presumably this flag permits loading model-supplied
        # code; callers handling untrusted models may want to pass False.
        kwargs.setdefault("trust_remote_code", True)

        model, kwargs = AutoModel.build_model(model=model, **kwargs)

        return model, kwargs
.forward(speech, speech_lengths, text, text_lengths, **kwargs) L767

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ):
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        # import pdb;
        # pdb.set_trace()
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text)

        loss_ctc, cer_ctc = None, None
        loss_rich, acc_rich = None, None
        stats = dict()
View full source →
.encode(speech, speech_lengths, text, **kwargs) L817

Frontend + Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
📄 Source
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        **kwargs,
    ):
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """

        # Data augmentation
        if self.specaug is not None and self.training:
            speech, speech_lengths = self.specaug(speech, speech_lengths)

        # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
        if self.normalize is not None:
            speech, speech_lengths = self.normalize(speech, speech_lengths)

        lids = torch.LongTensor(
            [
                [
                    (
                        self.lid_int_dict[int(lid)]
                        if torch.rand(1) > 0.2 and int(lid) in self.lid_int_dict
                        else 0
                    )
View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L918

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = ["wav_file_tmp_name"],
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = speech.shape[1]
        else:
View full source →
.post(timestamp) L1082

Post-process token-level timestamps into words with start/end times in integer milliseconds.

Args:

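  • timestamp — Sequence of `(token, start, end)` triples; `start` and `end` are multiplied by 1000 and truncated to `int` (presumably seconds → milliseconds).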
📄 Source
    def post(self, timestamp):
        """Post.
        
            Args:
                timestamp: TODO.
            """
        timestamp_new = []
        words_new = []
        prev_word = None
        for i, t in enumerate(timestamp):
            word, start, end = t
            start = int(start * 1000)
            end = int(end * 1000)
            if word == "▁":
                continue
            if i == 0:
                # timestamp_new.append([word, start, end])
                timestamp_new.append([start, end])
                words_new.append(word)
            elif word.startswith("▁"):
                word = word[1:]
                timestamp_new.append([start, end])
                words_new.append(word)
            elif prev_word is not None and prev_word.isalpha() and prev_word.isascii() and word.isalpha() and word.isascii():
                word = prev_word + word
                timestamp_new[-1][1] = end
                words_new[-1] = word
            else:
                # timestamp_new[-1][0] += word
                timestamp_new.append([start, end])
View full source →
.export(**kwargs) L1116

Export.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def export(self, **kwargs):
        """Export the model via ``export_rebuild_model``.

        Args:
            **kwargs: Keyword arguments forwarded to ``export_rebuild_model``.
                ``max_seq_len`` defaults to 512 when not supplied.

        Returns:
            Whatever ``export_rebuild_model`` returns for this model.
        """
        from .export_meta import export_rebuild_model

        kwargs.setdefault("max_seq_len", 512)
        # The former trailing ``return results, meta_data`` was unreachable and
        # referenced names undefined in this method; it has been removed.
        return export_rebuild_model(model=self, **kwargs)
method

SenseVoiceSmall.from_pretrained(model, **kwargs)

funasr.models.sense_voice.SenseVoiceSmall · View on GitHub ↗

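Build a model with `AutoModel.build_model` (passing `trust_remote_code=True`) and return the `(model, kwargs)` pair it produces.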

Args:

  • model — Model name (annotated as `str`), forwarded to `AutoModel.build_model`.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def from_pretrained(model: str = None, **kwargs):
        """Build a model from a pretrained model name.

        Args:
            model: Model name (annotated as ``str``) forwarded to
                ``AutoModel.build_model``.
            **kwargs: Additional keyword arguments forwarded to
                ``AutoModel.build_model``. ``trust_remote_code`` defaults to
                ``True`` but may be overridden by the caller.

        Returns:
            The ``(model, kwargs)`` tuple returned by ``AutoModel.build_model``.
        """
        from funasr import AutoModel

        # Passing ``trust_remote_code=True`` alongside ``**kwargs`` raised
        # ``TypeError: got multiple values for keyword argument`` whenever the
        # caller also supplied ``trust_remote_code``; treat it as a default.
        # NOTE(review): presumably this flag permits loading model-supplied
        # code; callers handling untrusted models may want to pass False.
        kwargs.setdefault("trust_remote_code", True)

        model, kwargs = AutoModel.build_model(model=model, **kwargs)

        return model, kwargs
method

SenseVoiceSmall.forward(speech, speech_lengths, text, text_lengths, **kwargs)

funasr.models.sense_voice.SenseVoiceSmall · View on GitHub ↗

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ):
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        # import pdb;
        # pdb.set_trace()
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text)

        loss_ctc, cer_ctc = None, None
        loss_rich, acc_rich = None, None
        stats = dict()
View full source on GitHub →
method

SenseVoiceSmall.encode(speech, speech_lengths, text, **kwargs)

funasr.models.sense_voice.SenseVoiceSmall · View on GitHub ↗

Frontend + Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
📄 Source code
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        **kwargs,
    ):
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """

        # Data augmentation
        if self.specaug is not None and self.training:
            speech, speech_lengths = self.specaug(speech, speech_lengths)

        # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
        if self.normalize is not None:
            speech, speech_lengths = self.normalize(speech, speech_lengths)

        lids = torch.LongTensor(
            [
                [
                    (
                        self.lid_int_dict[int(lid)]
                        if torch.rand(1) > 0.2 and int(lid) in self.lid_int_dict
                        else 0
                    )
View full source on GitHub →
method

SenseVoiceSmall.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.sense_voice.SenseVoiceSmall · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = ["wav_file_tmp_name"],
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = speech.shape[1]
        else:
View full source on GitHub →
method

SenseVoiceSmall.post(timestamp)

funasr.models.sense_voice.SenseVoiceSmall · View on GitHub ↗

Post-process token-level timestamps into words with start/end times in integer milliseconds.

Args:

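  • timestamp — Sequence of `(token, start, end)` triples; `start` and `end` are multiplied by 1000 and truncated to `int` (presumably seconds → milliseconds).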
📄 Source code
    def post(self, timestamp):
        """Post.
        
            Args:
                timestamp: TODO.
            """
        timestamp_new = []
        words_new = []
        prev_word = None
        for i, t in enumerate(timestamp):
            word, start, end = t
            start = int(start * 1000)
            end = int(end * 1000)
            if word == "▁":
                continue
            if i == 0:
                # timestamp_new.append([word, start, end])
                timestamp_new.append([start, end])
                words_new.append(word)
            elif word.startswith("▁"):
                word = word[1:]
                timestamp_new.append([start, end])
                words_new.append(word)
            elif prev_word is not None and prev_word.isalpha() and prev_word.isascii() and word.isalpha() and word.isascii():
                word = prev_word + word
                timestamp_new[-1][1] = end
                words_new[-1] = word
            else:
                # timestamp_new[-1][0] += word
                timestamp_new.append([start, end])
View full source on GitHub →
method

SenseVoiceSmall.export(**kwargs)

funasr.models.sense_voice.SenseVoiceSmall · View on GitHub ↗

Export the model through `export_rebuild_model`, defaulting `max_seq_len` to 512 when it is not supplied.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def export(self, **kwargs):
        """Export the model via ``export_rebuild_model``.

        Args:
            **kwargs: Keyword arguments forwarded to ``export_rebuild_model``.
                ``max_seq_len`` defaults to 512 when not supplied.

        Returns:
            Whatever ``export_rebuild_model`` returns for this model.
        """
        from .export_meta import export_rebuild_model

        kwargs.setdefault("max_seq_len", 512)
        # The former trailing ``return results, meta_data`` was unreachable and
        # referenced names undefined in this method; it has been removed.
        return export_rebuild_model(model=self, **kwargs)
class

Transducer

funasr.models.transducer · View on GitHub ↗
  • Transducer (RNN-T) — Streaming ASR using encoder-predictor-joint architecture.

Combines encoder (audio frames), predictor (text history), and joint network.

Supports beam search decoding. Suitable for low-latency streaming applications.

  • Output — {"key": str, "text": str}
📄 Source code
class Transducer(torch.nn.Module):
    """Transducer (RNN-T): Streaming ASR using encoder-predictor-joint architecture.

    Combines encoder (audio frames), predictor (text history), and joint network.
    Supports beam search decoding. Suitable for low-latency streaming applications.

    Output: {"key": str, "text": str}
    """

    def __init__(
        self,
        frontend: Optional[str] = None,
        frontend_conf: Optional[Dict] = None,
        specaug: Optional[str] = None,
        specaug_conf: Optional[Dict] = None,
        normalize: str = None,
        normalize_conf: Optional[Dict] = None,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        decoder: str = None,
        decoder_conf: Optional[Dict] = None,
        joint_network: str = None,
        joint_network_conf: Optional[Dict] = None,
        transducer_weight: float = 1.0,
        fastemit_lambda: float = 0.0,
        auxiliary_ctc_weight: float = 0.0,
        auxiliary_ctc_dropout_rate: float = 0.0,
        auxiliary_lm_loss_weight: float = 0.0,
        auxiliary_lm_loss_smoothing: float = 0.0,
        input_size: int = 80,
View full source on GitHub →

Methods

.forward(speech, speech_lengths, text, text_lengths, **kwargs) L190

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if (
            hasattr(self.encoder, "overlap_chunk_cls")
            and self.encoder.overlap_chunk_cls is not None
        ):
            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
                encoder_out, encoder_out_lens, chunk_outs=None
            )
View full source →
.encode(speech, speech_lengths, **kwargs) L276

Frontend + Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • **kwargs — Additional keyword arguments.
📄 Source
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):

            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
View full source →
.init_beam_search(**kwargs) L455

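Initialize a `BeamSearchTransducer` over the decoder and joint network (default beam size 2), with a CTC prefix scorer when `self.ctc` is set and a length bonus sized by `token_list`.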

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def init_beam_search(
        self,
        **kwargs,
    ):

        # 1. Build ASR model
        """Init beam search.
        
            Args:
                **kwargs: Additional keyword arguments.
            """
        scorers = {}

        if self.ctc != None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            length_bonus=LengthBonus(len(token_list)),
        )

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        beam_search = BeamSearchTransducer(
            self.decoder,
            self.joint_network,
            kwargs.get("beam_size", 2),
View full source →
.inference(data_in, data_lengths, key, tokenizer, **kwargs) L493

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in: list,
        data_lengths: list = None,
        key: list = None,
        tokenizer=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                **kwargs: Additional keyword arguments.
            """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        # init beamsearch
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = (
            kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        )
        # if self.beam_search is None and (is_use_lm or is_use_ctc):
        logging.info("enable beam_search")
        self.init_beam_search(**kwargs)
        self.nbest = kwargs.get("nbest", 1)
View full source →
method

Transducer.forward(speech, speech_lengths, text, text_lengths, **kwargs)

funasr.models.transducer.Transducer · View on GitHub ↗

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if (
            hasattr(self.encoder, "overlap_chunk_cls")
            and self.encoder.overlap_chunk_cls is not None
        ):
            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
                encoder_out, encoder_out_lens, chunk_outs=None
            )
View full source on GitHub →
method

Transducer.encode(speech, speech_lengths, **kwargs)

funasr.models.transducer.Transducer · View on GitHub ↗

Frontend + Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • **kwargs — Additional keyword arguments.
📄 Source code
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):

            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
View full source on GitHub →
method

Transducer.init_beam_search(**kwargs)

funasr.models.transducer.Transducer · View on GitHub ↗

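Initialize a `BeamSearchTransducer` over the decoder and joint network (default beam size 2), with a CTC prefix scorer when `self.ctc` is set and a length bonus sized by `token_list`.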

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def init_beam_search(
        self,
        **kwargs,
    ):

        # 1. Build ASR model
        """Init beam search.
        
            Args:
                **kwargs: Additional keyword arguments.
            """
        scorers = {}

        if self.ctc != None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            length_bonus=LengthBonus(len(token_list)),
        )

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        beam_search = BeamSearchTransducer(
            self.decoder,
            self.joint_network,
            kwargs.get("beam_size", 2),
View full source on GitHub →
method

Transducer.inference(data_in, data_lengths, key, tokenizer, **kwargs)

funasr.models.transducer.Transducer · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in: list,
        data_lengths: list = None,
        key: list = None,
        tokenizer=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                **kwargs: Additional keyword arguments.
            """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        # init beamsearch
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = (
            kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        )
        # if self.beam_search is None and (is_use_lm or is_use_ctc):
        logging.info("enable beam_search")
        self.init_beam_search(**kwargs)
        self.nbest = kwargs.get("nbest", 1)
View full source on GitHub →
class

Transformer

funasr.models.transformer · View on GitHub ↗
  • Transformer — Base encoder-decoder ASR model.

Standard CTC-attention hybrid architecture with:

  • Encoder (self-attention + position encoding)
  • CTC branch for auxiliary loss
  • Attention decoder for sequence generation
  • Beam search with LM fusion

Base class for Conformer, Branchformer, etc.

  • Output — {"key": str, "text": str}
📄 Source code
class Transformer(nn.Module):
    """Transformer: Base encoder-decoder ASR model.

    Standard CTC-attention hybrid architecture with:
    - Encoder (self-attention + position encoding)
    - CTC branch for auxiliary loss
    - Attention decoder for sequence generation
    - Beam search with LM fusion

    Base class for Conformer, Branchformer, etc.
    Output: {"key": str, "text": str}
    """

    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        decoder: str = None,
        decoder_conf: dict = None,
        ctc: str = None,
        ctc_conf: dict = None,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
View full source on GitHub →

Methods

.forward(speech, speech_lengths, text, text_lengths, **kwargs) L173

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
View full source →
.encode(speech, speech_lengths, **kwargs) L266

Frontend + Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • ind — int
📄 Source
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):

            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ctc=self.ctc)
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        intermediate_outs = None
View full source →
.init_beam_search(**kwargs) L368

Init beam search.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def init_beam_search(
        self,
        **kwargs,
    ):
        """Init beam search.
        
            Args:
                **kwargs: Additional keyword arguments.
            """
        from funasr.models.transformer.search import BeamSearch
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        # 1. Build ASR model
        scorers = {}

        if self.ctc != None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            decoder=self.decoder,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L418

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        # init beamsearch
        if self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        meta_data = {}
View full source →
method

Transformer.forward(speech, speech_lengths, text, text_lengths, **kwargs)

funasr.models.transformer.Transformer · View on GitHub ↗

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
View full source on GitHub →
method

Transformer.encode(speech, speech_lengths, **kwargs)

funasr.models.transformer.Transformer · View on GitHub ↗

Frontend + Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • ind — int
📄 Source code
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):

            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ctc=self.ctc)
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        intermediate_outs = None
View full source on GitHub →
method

Transformer.init_beam_search(**kwargs)

funasr.models.transformer.Transformer · View on GitHub ↗

Init beam search.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def init_beam_search(
        self,
        **kwargs,
    ):
        """Init beam search.
        
            Args:
                **kwargs: Additional keyword arguments.
            """
        from funasr.models.transformer.search import BeamSearch
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        # 1. Build ASR model
        scorers = {}

        if self.ctc != None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            decoder=self.decoder,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

View full source on GitHub →
method

Transformer.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.transformer.Transformer · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        # init beamsearch
        if self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        meta_data = {}
View full source on GitHub →
class

UniASR

funasr.models.uniasr · View on GitHub ↗
  • UniASR — Unified Streaming and Non-streaming ASR model.

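Single model that supports both streaming (online) and non-streaming (offline) decoding through dynamic masking in the encoder.

Builds on the Paraformer pipeline. The class itself subclasses torch.nn.Module directly; it does not inherit from a Paraformer class.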

  • Output — {"key": str, "text": str}
📄 Source code
class UniASR(torch.nn.Module):
    """UniASR: Unified Streaming and Non-streaming ASR model.

    Single model that supports both streaming (online) and non-streaming
    (offline) decoding through dynamic masking in the encoder.

    Inherits Paraformer pipeline.
    Output: {"key": str, "text": str}
    """

    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        encoder2: str = None,
        encoder2_conf: dict = None,
        decoder: str = None,
        decoder_conf: dict = None,
        decoder2: str = None,
        decoder2_conf: dict = None,
        predictor: str = None,
        predictor_conf: dict = None,
        predictor_bias: int = 0,
        predictor_weight: float = 0.0,
        predictor2: str = None,
        predictor2_conf: dict = None,
View full source on GitHub →

Methods

.forward(speech, speech_lengths, text, text_lengths, **kwargs) L235

Frontend + Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss
        Args:
                        speech: (Batch, Length, ...)
                        speech_lengths: (Batch, )
                        text: (Batch, Length)
                        text_lengths: (Batch,)
        """
        decoding_ind = kwargs.get("decoding_ind", None)
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
        # 1. Encoder
        if self.enable_maas_finetune:
            with torch.no_grad():
                speech_raw, encoder_out, encoder_out_lens = self.encode(
                    speech, speech_lengths, ind=ind
                )
View full source →
.collect_feats(speech, speech_lengths, text, text_lengths) L354

Collect feats.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • text — Text tensor or string input.
  • text_lengths — Length of each text sample.
📄 Source
    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Collect features for statistics computation.

        If ``self.extract_feats_in_collect_stats`` is truthy, features are
        computed with ``self._extract_feats``. Otherwise a warning is logged
        and the input ``speech`` / ``speech_lengths`` are returned unchanged
        as "dummy" features.

        Args:
            speech: Input speech tensor. NOTE(review): could be raw audio
                (batch, time) or features (batch, frames, dim); confirm
                against callers.
            speech_lengths: Length of each speech sample.
            text: Target text tensor. Not used by this method.
            text_lengths: Length of each text sample. Not used by this method.

        Returns:
            Dict with keys ``"feats"`` and ``"feats_lengths"``.
        """
        # NOTE(review): `extract_feats_in_collect_stats` and `_extract_feats`
        # are not visible in this excerpt; confirm both exist on UniASR,
        # otherwise this raises AttributeError.
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is falsy:
            # the inputs are passed through as-is.
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}
.encode(speech, speech_lengths, **kwargs) L381

Feature augmentation/normalization + Encoder (no frontend is applied in this method). Note that this method is used by asr_inference.py.

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
📄 Source
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ):
        """Feature augmentation/normalization followed by the first encoder.

        Applies SpecAug (only when ``self.specaug`` is set and the module is
        in training mode) and feature normalization (when ``self.normalize``
        is set). It then keeps a copy of the resulting features and runs
        ``self.encoder``. No frontend is applied here.
        Note that this method is used by asr_inference.py.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            **kwargs: ``ind`` (default 0) is forwarded to ``self.encoder``.
                In ``forward`` it is obtained from
                ``self.encoder.overlap_chunk_cls.random_choice(...)``;
                presumably it selects the chunk configuration (confirm).

        Returns:
            Tuple ``(speech_raw, encoder_out, encoder_out_lens)``.
            ``speech_raw`` is a copy of the features after SpecAug and
            normalization, i.e. the encoder input.
        """
        ind = kwargs.get("ind", 0)
        # Preprocessing runs with autocast disabled (presumably to keep it in
        # full precision).
        with autocast(False):
            # Data augmentation (training mode only)
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Copy of the encoder input, returned to the caller (presumably used
        # later as the residual input of encode2; confirm). The
        # `.to(speech.device)` is a no-op after `clone()`.
        speech_raw = speech.clone().to(speech.device)

        # 4. Forward encoder; the third return value is discarded.
        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ind=ind)
        # Some encoders return a tuple; keep only its first element.
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        return speech_raw, encoder_out, encoder_out_lens
.encode2(encoder_out, encoder_out_lens, speech, speech_lengths, **kwargs) L411

Frontend + Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
📄 Source
    def encode2(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ):
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                        speech: (Batch, Length, ...)
                        speech_lengths: (Batch, )
        """

        ind = kwargs.get("ind", 0)
        encoder_out_rm, encoder_out_lens_rm = self.encoder.overlap_chunk_cls.remove_chunk(
            encoder_out,
            encoder_out_lens,
            chunk_outs=None,
        )
        # residual_input
        encoder_out = torch.cat((speech, encoder_out_rm), dim=-1)
        encoder_out_lens = encoder_out_lens_rm
        if self.stride_conv is not None:
            speech, speech_lengths = self.stride_conv(encoder_out, encoder_out_lens)
        if not self.encoder1_encoder2_joint_training:
            speech = speech.detach()
            speech_lengths = speech_lengths.detach()
        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
View full source →
.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) L449

Compute the negative log-likelihood (NLL) from the Transformer decoder.

Normally, this function is called in batchify_nll.

Args:

  • encoder_out — (Batch, Length, Dim)
  • encoder_out_lens — (Batch,)
  • ys_pad — (Batch, Length)
  • ys_pad_lens — (Batch,)
📄 Source
    def nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Compute negative log likelihood(nll) from transformer-decoder
        Normally, this function is called in batchify_nll.
        Args:
                        encoder_out: (Batch, Length, Dim)
                        encoder_out_lens: (Batch,)
                        ys_pad: (Batch, Length)
                        ys_pad_lens: (Batch,)
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )  # [batch, seqlen, dim]
        batch_size = decoder_out.size(0)
        decoder_num_class = decoder_out.size(2)
        # nll: negative log-likelihood
        nll = torch.nn.functional.cross_entropy(
            decoder_out.view(-1, decoder_num_class),
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction="none",
View full source →
.batchify_nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, batch_size) L485

Compute the negative log-likelihood (NLL) from the Transformer decoder.

To avoid out-of-memory (OOM) errors, this function splits the input into mini-batches, calls nll on each one, and combines and returns the results.

Args:

  • encoder_out — (Batch, Length, Dim)
  • encoder_out_lens — (Batch,)
  • ys_pad — (Batch, Length)
  • ys_pad_lens — (Batch,)
  • batch_size — int, number of samples per mini-batch when computing the NLL; decrease it to avoid OOM, or increase it to use more GPU memory.
📄 Source
    def batchify_nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        batch_size: int = 100,
    ):
        """Compute negative log likelihood(nll) from transformer-decoder
        To avoid OOM, this fuction seperate the input into batches.
        Then call nll for each batch and combine and return results.
        Args:
                        encoder_out: (Batch, Length, Dim)
                        encoder_out_lens: (Batch,)
                        ys_pad: (Batch, Length)
                        ys_pad_lens: (Batch,)
                        batch_size: int, samples each batch contain when computing nll,
                                                                        you may change this to avoid OOM or increase
                                                                        GPU memory usage
        """
        total_num = encoder_out.size(0)
        if total_num <= batch_size:
            nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
        else:
            nll = []
            start_idx = 0
            while True:
                end_idx = min(start_idx + batch_size, total_num)
                batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
                batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
View full source →
.calc_predictor_mask(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) L787

Calc predictor mask.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • ys_pad — Padded target sequences, (Batch, Length); optional, defaults to None.
  • ys_pad_lens — Lengths of ys_pad.
📄 Source
    def calc_predictor_mask(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor = None,
        ys_pad_lens: torch.Tensor = None,
    ):
        # ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        # ys_in_lens = ys_pad_lens + 1
        """Calc predictor mask.
        
            Args:
                encoder_out: Encoder output tensor.
                encoder_out_lens: Encoder output lengths.
                ys_pad: TODO.
                ys_pad_lens: Lengths of ys_pad.
            """
        ys_out_pad, ys_in_lens = None, None

        encoder_out_mask = sequence_mask(
            encoder_out_lens,
            maxlen=encoder_out.size(1),
            dtype=encoder_out.dtype,
            device=encoder_out.device,
        )[:, None, :]
        mask_chunk_predictor = None
        if self.encoder.overlap_chunk_cls is not None:
            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
View full source →
.calc_predictor_mask2(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) L877

Calc predictor mask2.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • ys_pad — Padded target sequences, (Batch, Length); optional, defaults to None.
  • ys_pad_lens — Lengths of ys_pad.
📄 Source
    def calc_predictor_mask2(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor = None,
        ys_pad_lens: torch.Tensor = None,
    ):
        # ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        # ys_in_lens = ys_pad_lens + 1
        """Calc predictor mask2.
        
            Args:
                encoder_out: Encoder output tensor.
                encoder_out_lens: Encoder output lengths.
                ys_pad: TODO.
                ys_pad_lens: Lengths of ys_pad.
            """
        ys_out_pad, ys_in_lens = None, None

        encoder_out_mask = sequence_mask(
            encoder_out_lens,
            maxlen=encoder_out.size(1),
            dtype=encoder_out.dtype,
            device=encoder_out.device,
        )[:, None, :]
        mask_chunk_predictor = None
        if self.encoder2.overlap_chunk_cls is not None:
            mask_chunk_predictor = self.encoder2.overlap_chunk_cls.get_mask_chunk_predictor(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
View full source →
.init_beam_search(**kwargs) L967

Init beam search.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def init_beam_search(
        self,
        **kwargs,
    ):
        """Init beam search.
        
            Args:
                **kwargs: Additional keyword arguments.
            """
        from funasr.models.uniasr.beam_search import BeamSearchScama
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        decoding_mode = kwargs.get("decoding_mode", "model1")
        if decoding_mode == "model1":
            decoder = self.decoder
        else:
            decoder = self.decoder2
        # 1. Build ASR model
        scorers = {}

        if self.ctc != None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            decoder=decoder,
            length_bonus=LengthBonus(len(token_list)),
        )

View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L1022

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        decoding_model = kwargs.get("decoding_model", "normal")
        token_num_relax = kwargs.get("token_num_relax", 5)
        if decoding_model == "fast":
            decoding_ind = 0
            decoding_mode = "model1"
        elif decoding_model == "offline":
            decoding_ind = 1
            decoding_mode = "model2"
        else:
            decoding_ind = 0
View full source →
method

UniASR.forward(speech, speech_lengths, text, text_lengths, **kwargs)

funasr.models.uniasr.UniASR · View on GitHub ↗

Frontend + Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss
        Args:
                        speech: (Batch, Length, ...)
                        speech_lengths: (Batch, )
                        text: (Batch, Length)
                        text_lengths: (Batch,)
        """
        decoding_ind = kwargs.get("decoding_ind", None)
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
        # 1. Encoder
        if self.enable_maas_finetune:
            with torch.no_grad():
                speech_raw, encoder_out, encoder_out_lens = self.encode(
                    speech, speech_lengths, ind=ind
                )
View full source on GitHub →
method

UniASR.collect_feats(speech, speech_lengths, text, text_lengths)

funasr.models.uniasr.UniASR · View on GitHub ↗

Collect feats.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • text — Text tensor or string input.
  • text_lengths — Length of each text sample.
📄 Source code
    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Collect features for statistics computation.

        If ``self.extract_feats_in_collect_stats`` is truthy, features are
        computed with ``self._extract_feats``. Otherwise a warning is logged
        and the input ``speech`` / ``speech_lengths`` are returned unchanged
        as "dummy" features.

        Args:
            speech: Input speech tensor. NOTE(review): could be raw audio
                (batch, time) or features (batch, frames, dim); confirm
                against callers.
            speech_lengths: Length of each speech sample.
            text: Target text tensor. Not used by this method.
            text_lengths: Length of each text sample. Not used by this method.

        Returns:
            Dict with keys ``"feats"`` and ``"feats_lengths"``.
        """
        # NOTE(review): `extract_feats_in_collect_stats` and `_extract_feats`
        # are not visible in this excerpt; confirm both exist on UniASR,
        # otherwise this raises AttributeError.
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is falsy:
            # the inputs are passed through as-is.
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}
method

UniASR.encode(speech, speech_lengths, **kwargs)

funasr.models.uniasr.UniASR · View on GitHub ↗

Feature augmentation/normalization + Encoder (no frontend is applied in this method). Note that this method is used by asr_inference.py.

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
📄 Source code
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ):
        """Feature augmentation/normalization followed by the first encoder.

        Applies SpecAug (only when ``self.specaug`` is set and the module is
        in training mode) and feature normalization (when ``self.normalize``
        is set). It then keeps a copy of the resulting features and runs
        ``self.encoder``. No frontend is applied here.
        Note that this method is used by asr_inference.py.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            **kwargs: ``ind`` (default 0) is forwarded to ``self.encoder``.
                In ``forward`` it is obtained from
                ``self.encoder.overlap_chunk_cls.random_choice(...)``;
                presumably it selects the chunk configuration (confirm).

        Returns:
            Tuple ``(speech_raw, encoder_out, encoder_out_lens)``.
            ``speech_raw`` is a copy of the features after SpecAug and
            normalization, i.e. the encoder input.
        """
        ind = kwargs.get("ind", 0)
        # Preprocessing runs with autocast disabled (presumably to keep it in
        # full precision).
        with autocast(False):
            # Data augmentation (training mode only)
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Copy of the encoder input, returned to the caller (presumably used
        # later as the residual input of encode2; confirm). The
        # `.to(speech.device)` is a no-op after `clone()`.
        speech_raw = speech.clone().to(speech.device)

        # 4. Forward encoder; the third return value is discarded.
        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ind=ind)
        # Some encoders return a tuple; keep only its first element.
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        return speech_raw, encoder_out, encoder_out_lens
method

UniASR.encode2(encoder_out, encoder_out_lens, speech, speech_lengths, **kwargs)

funasr.models.uniasr.UniASR · View on GitHub ↗

Frontend + Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
📄 Source code
    def encode2(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ):
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                        speech: (Batch, Length, ...)
                        speech_lengths: (Batch, )
        """

        ind = kwargs.get("ind", 0)
        encoder_out_rm, encoder_out_lens_rm = self.encoder.overlap_chunk_cls.remove_chunk(
            encoder_out,
            encoder_out_lens,
            chunk_outs=None,
        )
        # residual_input
        encoder_out = torch.cat((speech, encoder_out_rm), dim=-1)
        encoder_out_lens = encoder_out_lens_rm
        if self.stride_conv is not None:
            speech, speech_lengths = self.stride_conv(encoder_out, encoder_out_lens)
        if not self.encoder1_encoder2_joint_training:
            speech = speech.detach()
            speech_lengths = speech_lengths.detach()
        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
View full source on GitHub →
method

UniASR.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

funasr.models.uniasr.UniASR · View on GitHub ↗

Compute the negative log-likelihood (NLL) from the transformer decoder.

Normally, this function is called by batchify_nll.

Args:

  • encoder_out — (Batch, Length, Dim)
  • encoder_out_lens — (Batch,)
  • ys_pad — (Batch, Length)
  • ys_pad_lens — (Batch,)
📄 Source code
    def nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Compute negative log likelihood(nll) from transformer-decoder
        Normally, this function is called in batchify_nll.
        Args:
                        encoder_out: (Batch, Length, Dim)
                        encoder_out_lens: (Batch,)
                        ys_pad: (Batch, Length)
                        ys_pad_lens: (Batch,)
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )  # [batch, seqlen, dim]
        batch_size = decoder_out.size(0)
        decoder_num_class = decoder_out.size(2)
        # nll: negative log-likelihood
        nll = torch.nn.functional.cross_entropy(
            decoder_out.view(-1, decoder_num_class),
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction="none",
View full source on GitHub →
method

UniASR.batchify_nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, batch_size)

funasr.models.uniasr.UniASR · View on GitHub ↗

Compute the negative log-likelihood (NLL) from the transformer decoder.

To avoid OOM, this function splits the input into batches.

It then calls nll on each batch and combines and returns the results.

Args:

  • encoder_out — (Batch, Length, Dim)
  • encoder_out_lens — (Batch,)
  • ys_pad — (Batch, Length)
  • ys_pad_lens — (Batch,)
  • batch_size — int, number of samples per batch when computing nll; change it to avoid OOM or to make fuller use of GPU memory.
📄 Source code
    def batchify_nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        batch_size: int = 100,
    ):
        """Compute negative log likelihood(nll) from transformer-decoder
        To avoid OOM, this fuction seperate the input into batches.
        Then call nll for each batch and combine and return results.
        Args:
                        encoder_out: (Batch, Length, Dim)
                        encoder_out_lens: (Batch,)
                        ys_pad: (Batch, Length)
                        ys_pad_lens: (Batch,)
                        batch_size: int, samples each batch contain when computing nll,
                                                                        you may change this to avoid OOM or increase
                                                                        GPU memory usage
        """
        total_num = encoder_out.size(0)
        if total_num <= batch_size:
            nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
        else:
            nll = []
            start_idx = 0
            while True:
                end_idx = min(start_idx + batch_size, total_num)
                batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
                batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
View full source on GitHub →
method

UniASR.calc_predictor_mask(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

funasr.models.uniasr.UniASR · View on GitHub ↗

Calc predictor mask.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • ys_pad — TODO.
  • ys_pad_lens — Lengths of ys_pad.
📄 Source code
    def calc_predictor_mask(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor = None,
        ys_pad_lens: torch.Tensor = None,
    ):
        # ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        # ys_in_lens = ys_pad_lens + 1
        """Calc predictor mask.
        
            Args:
                encoder_out: Encoder output tensor.
                encoder_out_lens: Encoder output lengths.
                ys_pad: TODO.
                ys_pad_lens: Lengths of ys_pad.
            """
        ys_out_pad, ys_in_lens = None, None

        encoder_out_mask = sequence_mask(
            encoder_out_lens,
            maxlen=encoder_out.size(1),
            dtype=encoder_out.dtype,
            device=encoder_out.device,
        )[:, None, :]
        mask_chunk_predictor = None
        if self.encoder.overlap_chunk_cls is not None:
            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
View full source on GitHub →
method

UniASR.calc_predictor_mask2(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

funasr.models.uniasr.UniASR · View on GitHub ↗

Calc predictor mask2.

Args:

  • encoder_out — Encoder output tensor.
  • encoder_out_lens — Encoder output lengths.
  • ys_pad — TODO.
  • ys_pad_lens — Lengths of ys_pad.
📄 Source code
    def calc_predictor_mask2(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor = None,
        ys_pad_lens: torch.Tensor = None,
    ):
        # ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        # ys_in_lens = ys_pad_lens + 1
        """Calc predictor mask2.
        
            Args:
                encoder_out: Encoder output tensor.
                encoder_out_lens: Encoder output lengths.
                ys_pad: TODO.
                ys_pad_lens: Lengths of ys_pad.
            """
        ys_out_pad, ys_in_lens = None, None

        encoder_out_mask = sequence_mask(
            encoder_out_lens,
            maxlen=encoder_out.size(1),
            dtype=encoder_out.dtype,
            device=encoder_out.device,
        )[:, None, :]
        mask_chunk_predictor = None
        if self.encoder2.overlap_chunk_cls is not None:
            mask_chunk_predictor = self.encoder2.overlap_chunk_cls.get_mask_chunk_predictor(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
View full source on GitHub →
method

UniASR.init_beam_search(**kwargs)

funasr.models.uniasr.UniASR · View on GitHub ↗

Init beam search.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def init_beam_search(
        self,
        **kwargs,
    ):
        """Init beam search.
        
            Args:
                **kwargs: Additional keyword arguments.
            """
        from funasr.models.uniasr.beam_search import BeamSearchScama
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        decoding_mode = kwargs.get("decoding_mode", "model1")
        if decoding_mode == "model1":
            decoder = self.decoder
        else:
            decoder = self.decoder2
        # 1. Build ASR model
        scorers = {}

        if self.ctc != None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            decoder=decoder,
            length_bonus=LengthBonus(len(token_list)),
        )

View full source on GitHub →
method

UniASR.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.uniasr.UniASR · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        decoding_model = kwargs.get("decoding_model", "normal")
        token_num_relax = kwargs.get("token_num_relax", 5)
        if decoding_model == "fast":
            decoding_ind = 0
            decoding_mode = "model1"
        elif decoding_model == "offline":
            decoding_ind = 1
            decoding_mode = "model2"
        else:
            decoding_ind = 0
View full source on GitHub →
class

WhisperWarp

funasr.models.whisper · View on GitHub ↗
Whisper: OpenAI Whisper model integration.

Wraps Whisper for multilingual speech recognition and translation within FunASR's AutoModel interface.

Supports: whisper-tiny through whisper-large-v3-turbo.

Output: {"key": str, "text": str}
📄 Source code
class WhisperWarp(nn.Module):
    """Whisper: OpenAI Whisper model integration.

    Wraps Whisper for multilingual speech recognition and translation
    within FunASR's AutoModel interface.

    Supports: whisper-tiny through whisper-large-v3-turbo.
    Output: {"key": str, "text": str}
    """

    def __init__(self, *args, **kwargs):
        """Initialize WhisperWarp.
        
            Args:
                *args: Variable positional arguments.
                **kwargs: Additional keyword arguments.
            """
        super().__init__()
        hub = kwargs.get("hub", "funasr")
        if hub == "openai":
            model_or_path = kwargs.get("model_path", "Whisper-large-v3")
            if model_or_path.startswith("Whisper-"):
                model_or_path = model_or_path.replace("Whisper-", "")
            model = whisper.load_model(model_or_path)
        else:
            dims = kwargs.get("dims", {})
            dims = whisper.model.ModelDimensions(**dims)
            model = whisper.model.Whisper(dims=dims)

        self.model = model
View full source on GitHub →

Methods

.forward() L66

Forward pass for training. Not implemented in this wrapper: the method body is a no-op and returns None.

📄 Source
    def forward(self):
        """Placeholder training forward pass.

        This wrapper does not implement training: the method performs no
        computation and simply returns ``None``.
        """
        return None
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L72

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        if frontend is None and not hasattr(self, "frontend"):
            frontend_class = tables.frontend_classes.get("WhisperFrontend")
            frontend = frontend_class(
                n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
            )
            self.frontend = frontend
        else:
            frontend = frontend if frontend is not None else self.frontend
View full source →
method

WhisperWarp.forward()

funasr.models.whisper.WhisperWarp · View on GitHub ↗

Forward pass for training. Not implemented in this wrapper: the method body is a no-op and returns None.

📄 Source code
    def forward(self):
        """Placeholder training forward pass.

        This wrapper does not implement training: the method performs no
        computation and simply returns ``None``.
        """
        return None
method

WhisperWarp.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.whisper.WhisperWarp · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        if frontend is None and not hasattr(self, "frontend"):
            frontend_class = tables.frontend_classes.get("WhisperFrontend")
            frontend = frontend_class(
                n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
            )
            self.frontend = frontend
        else:
            frontend = frontend if frontend is not None else self.frontend
View full source on GitHub →
class

OpenAIWhisperModel

funasr.models.whisper_lid · View on GitHub ↗
Whisper-LID: Language Identification using the OpenAI Whisper encoder.

Detects the spoken language from audio input. Supports 99 languages.

Uses Whisper encoder features with a classification head.

Output: {"key": str, "text": str (language code)}
📄 Source code
class OpenAIWhisperModel(nn.Module):
    """Whisper-LID: Language Identification using OpenAI Whisper encoder.

    Detects spoken language from audio input. Supports 99 languages.
    Uses Whisper encoder features with a classification head.

    Output: {"key": str, "text": str (language code)}
    """

    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        decoder: str = None,
        decoder_conf: dict = None,
        ctc: str = None,
        ctc_conf: dict = None,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
View full source on GitHub →

Methods

.forward(speech, speech_lengths, text, text_lengths, **kwargs) L164

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
View full source →
.encode(speech, speech_lengths, **kwargs) L257

Frontend + Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • ind — int
📄 Source
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):

            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ctc=self.ctc)
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        intermediate_outs = None
View full source →
.init_beam_search(**kwargs) L359

Init beam search.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source
    def init_beam_search(
        self,
        **kwargs,
    ):
        """Init beam search.
        
            Args:
                **kwargs: Additional keyword arguments.
            """
        from funasr.models.transformer.search import BeamSearch
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        # 1. Build ASR model
        scorers = {}

        if self.ctc != None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            decoder=self.decoder,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L409

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        # init beamsearch
        if self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        meta_data = {}
View full source →
method

OpenAIWhisperModel.forward(speech, speech_lengths, text, text_lengths, **kwargs)

funasr.models.whisper_lid.OpenAIWhisperModel · View on GitHub ↗

Encoder + Decoder + Calc loss

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • text — (Batch, Length)
  • text_lengths — (Batch,)
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
View full source on GitHub →
method

OpenAIWhisperModel.encode(speech, speech_lengths, **kwargs)

funasr.models.whisper_lid.OpenAIWhisperModel · View on GitHub ↗

Frontend + Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
  • ind — int
📄 Source code
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):

            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ctc=self.ctc)
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        intermediate_outs = None
View full source on GitHub →
method

OpenAIWhisperModel.init_beam_search(**kwargs)

funasr.models.whisper_lid.OpenAIWhisperModel · View on GitHub ↗

Init beam search.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
    def init_beam_search(
        self,
        **kwargs,
    ):
        """Init beam search.
        
            Args:
                **kwargs: Additional keyword arguments.
            """
        from funasr.models.transformer.search import BeamSearch
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        # 1. Build ASR model
        scorers = {}

        if self.ctc != None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            decoder=self.decoder,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

View full source on GitHub →
method

OpenAIWhisperModel.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.whisper_lid.OpenAIWhisperModel · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        # init beamsearch
        if self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        meta_data = {}
View full source on GitHub →
class

OpenAIWhisperLIDModel

funasr.models.whisper_lid · View on GitHub ↗

Language identification (LID) model based on a WhisperEncoder and an EResNet.

📄 Source code
class OpenAIWhisperLIDModel(nn.Module):
    """WhisperEncoder and EResNet based LID Model"""

    def __init__(
        self,
        vocab_size: int,
        specaug: str = None,
        specaug_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        lid_predictor: str = None,
        lid_predictor_conf: dict = None,
        proj_dim: int = None,
        clip_frames: int = None,
        random_clip: bool = False,
        **kwargs,
    ):
        """Initialize OpenAIWhisperLIDModel.
        
            Args:
                vocab_size: Size/dimension parameter.
                specaug: TODO.
                specaug_conf: Configuration dict for specaug.
                encoder: TODO.
                encoder_conf: Configuration dict for encoder.
                lid_predictor: TODO.
                lid_predictor_conf: Configuration dict for lid_predictor.
                proj_dim: Size/dimension parameter.
                clip_frames: TODO.
                random_clip: TODO.
View full source on GitHub →

Methods

.forward(speech, speech_lengths, lid, lid_lengths) L587

Forward pass for training.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • lid — TODO.
  • lid_lengths — Lengths of lid.
📄 Source
    def forward(
        self,
        speech: torch.Tensor,  # may be padding
        speech_lengths: torch.Tensor,  # actual length
        lid: torch.Tensor,  # lid label, (batch_size, 1)
        lid_lengths: torch.Tensor,
    ):
        """Forward pass for training.
        
            Args:
                speech: Speech audio tensor, shape (batch, time).
                speech_lengths: Length of each speech sample.
                lid: TODO.
                lid_lengths: Lengths of lid.
            """
        assert lid.shape[1] == 1
        batch_size = speech.shape[0]
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # re-generate encoder_out
        if self.clip_frames is None:
            reduced_encoder_out = (
                torch.zeros(batch_size, encoder_out_lens.max(), encoder_out.shape[-1])
                .to(encoder_out.dtype)
                .to(encoder_out.device)
            )
            for i, enc_length in enumerate(encoder_out_lens):
                reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length]
        else:
            reduced_encoder_out = (
View full source →
.encode(speech, speech_lengths) L655

Frontend + Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
📄 Source
    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        with autocast(False):

            # Data augmentation
            if self.specaug is not None and self.training:
                speech = speech.permute(0, 2, 1)
                # suit for whisper padding
                padded_speech_lengths = torch.ones_like(speech_lengths) * speech.shape[1]
                speech, padded_speech_lengths = self.specaug(speech, padded_speech_lengths)
                speech = speech.permute(0, 2, 1)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ctc=self.ctc)
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        intermediate_outs = None
View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L694

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
View full source →
method

OpenAIWhisperLIDModel.forward(speech, speech_lengths, lid, lid_lengths)

funasr.models.whisper_lid.OpenAIWhisperLIDModel · View on GitHub ↗

Forward pass for training.

Args:

  • speech — Speech audio tensor, shape (batch, time).
  • speech_lengths — Length of each speech sample.
  • lid — TODO.
  • lid_lengths — Lengths of lid.
📄 Source code
    def forward(
        self,
        speech: torch.Tensor,  # may be padding
        speech_lengths: torch.Tensor,  # actual length
        lid: torch.Tensor,  # lid label, (batch_size, 1)
        lid_lengths: torch.Tensor,
    ):
        """Forward pass for training.
        
            Args:
                speech: Speech audio tensor, shape (batch, time).
                speech_lengths: Length of each speech sample.
                lid: TODO.
                lid_lengths: Lengths of lid.
            """
        assert lid.shape[1] == 1
        batch_size = speech.shape[0]
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # re-generate encoder_out
        if self.clip_frames is None:
            reduced_encoder_out = (
                torch.zeros(batch_size, encoder_out_lens.max(), encoder_out.shape[-1])
                .to(encoder_out.dtype)
                .to(encoder_out.device)
            )
            for i, enc_length in enumerate(encoder_out_lens):
                reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length]
        else:
            reduced_encoder_out = (
View full source on GitHub →
method

OpenAIWhisperLIDModel.encode(speech, speech_lengths)

funasr.models.whisper_lid.OpenAIWhisperLIDModel · View on GitHub ↗

Frontend + Encoder. Note that this method is used by asr_inference.py

Args:

  • speech — (Batch, Length, ...)
  • speech_lengths — (Batch, )
📄 Source code
    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        with autocast(False):

            # Data augmentation
            if self.specaug is not None and self.training:
                speech = speech.permute(0, 2, 1)
                # suit for whisper padding
                padded_speech_lengths = torch.ones_like(speech_lengths) * speech.shape[1]
                speech, padded_speech_lengths = self.specaug(speech, padded_speech_lengths)
                speech = speech.permute(0, 2, 1)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ctc=self.ctc)
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        intermediate_outs = None
View full source on GitHub →
method

OpenAIWhisperLIDModel.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)

funasr.models.whisper_lid.OpenAIWhisperLIDModel · View on GitHub ↗

Run inference on input data.

Args:

  • data_in — Input data (audio samples, file paths, or text).
  • data_lengths — Lengths of each input sample in the batch.
  • key — Sample identifiers.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • frontend — Audio frontend for feature extraction.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):

        """Run inference on input data.
        
            Args:
                data_in: Input data (audio samples, file paths, or text).
                data_lengths: Lengths of each input sample in the batch.
                key: Sample identifiers.
                tokenizer: Tokenizer instance for text encoding/decoding.
                frontend: Audio frontend for feature extraction.
                **kwargs: Additional keyword arguments.
            """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
View full source on GitHub →
class

DefaultFrontend

funasr.frontends.default · View on GitHub ↗

Conventional frontend structure for ASR.

  • Pipeline: Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
📄 Source code
class DefaultFrontend(nn.Module):
    """Conventional frontend structure for ASR.
    Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        n_mels: int = 80,
        fmin: int = None,
        fmax: int = None,
        htk: bool = False,
        frontend_conf: Optional[dict] = None,
        apply_stft: bool = True,
        use_channel: int = None,
        **kwargs,
    ):
        """Initialize DefaultFrontend.
        
            Args:
                fs: TODO.
                n_fft: TODO.
                win_length: TODO.
View full source on GitHub →

Methods

.output_size() L106

Output size.

📄 Source
    def output_size(self) -> int:
        """Return the feature dimension emitted by this frontend (``self.n_mels``)."""
        n_mels = self.n_mels
        return n_mels
.forward(input, input_lengths) L110

Forward pass for training.

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
📄 Source
    def forward(
        self, input: torch.Tensor, input_lengths: Union[torch.Tensor, list]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass for training.
        
            Args:
                input: Input audio/text data.
                input_lengths: Lengths of input.
            """
        if isinstance(input_lengths, list):
            input_lengths = torch.tensor(input_lengths)
        if input.dtype == torch.float64:
            input = input.float()
        # 1. Domain-conversion: e.g. Stft: time -> time-freq
        if self.stft is not None:
            input_stft, feats_lens = self._compute_stft(input, input_lengths)
        else:
            input_stft = ComplexTensor(input[..., 0], input[..., 1])
            feats_lens = input_lengths
        # 2. [Option] Speech enhancement
        if self.frontend is not None:
            assert isinstance(input_stft, ComplexTensor), type(input_stft)
            # input_stft: (Batch, Length, [Channel], Freq)
            input_stft, _, mask = self.frontend(input_stft, feats_lens)

        # 3. [Multi channel case]: Select a channel
        if input_stft.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                if self.use_channel is not None:
View full source →
method

DefaultFrontend.output_size()

funasr.frontends.default.DefaultFrontend · View on GitHub ↗

Output size.

📄 Source code
    def output_size(self) -> int:
        """Return the feature dimension emitted by this frontend (``self.n_mels``)."""
        n_mels = self.n_mels
        return n_mels
method

DefaultFrontend.forward(input, input_lengths)

funasr.frontends.default.DefaultFrontend · View on GitHub ↗

Forward pass for training.

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
📄 Source code
    def forward(
        self, input: torch.Tensor, input_lengths: Union[torch.Tensor, list]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass for training.
        
            Args:
                input: Input audio/text data.
                input_lengths: Lengths of input.
            """
        if isinstance(input_lengths, list):
            input_lengths = torch.tensor(input_lengths)
        if input.dtype == torch.float64:
            input = input.float()
        # 1. Domain-conversion: e.g. Stft: time -> time-freq
        if self.stft is not None:
            input_stft, feats_lens = self._compute_stft(input, input_lengths)
        else:
            input_stft = ComplexTensor(input[..., 0], input[..., 1])
            feats_lens = input_lengths
        # 2. [Option] Speech enhancement
        if self.frontend is not None:
            assert isinstance(input_stft, ComplexTensor), type(input_stft)
            # input_stft: (Batch, Length, [Channel], Freq)
            input_stft, _, mask = self.frontend(input_stft, feats_lens)

        # 3. [Multi channel case]: Select a channel
        if input_stft.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                if self.use_channel is not None:
View full source on GitHub →
class

MultiChannelFrontend

funasr.frontends.default · View on GitHub ↗

Conventional frontend structure for ASR.

  • Pipeline: Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
📄 Source code
class MultiChannelFrontend(nn.Module):
    """Conventional frontend structure for ASR.
    Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = None,
        frame_length: int = None,
        frame_shift: int = None,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        n_mels: int = 80,
        fmin: int = None,
        fmax: int = None,
        htk: bool = False,
        frontend_conf: Optional[dict] = None,
        apply_stft: bool = True,
        use_channel: int = None,
        lfr_m: int = 1,
        lfr_n: int = 1,
        cmvn_file: str = None,
        mc: bool = True,
    ):
        """Initialize MultiChannelFrontend.
View full source on GitHub →

Methods

.output_size() L290

Output size.

📄 Source
    def output_size(self) -> int:
        """Return the feature dimension emitted by this frontend (``self.n_mels``)."""
        n_mels = self.n_mels
        return n_mels
.forward(input, input_lengths) L294

Forward pass for training.

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
📄 Source
    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. Domain-conversion: e.g. Stft: time -> time-freq
        """Forward pass for training.
        
            Args:
                input: Input audio/text data.
                input_lengths: Lengths of input.
            """
        if self.stft is not None:
            input_stft, feats_lens = self._compute_stft(input, input_lengths)
        else:
            input_stft = ComplexTensor(input[..., 0], input[..., 1])
            feats_lens = input_lengths
        # 2. [Option] Speech enhancement
        if self.frontend is not None:
            assert isinstance(input_stft, ComplexTensor), type(input_stft)
            # input_stft: (Batch, Length, [Channel], Freq)
            input_stft, _, mask = self.frontend(input_stft, feats_lens)

        # 3. [Multi channel case]: Select a channel(sa_asr)
        if input_stft.dim() == 4 and not self.mc:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                if self.use_channel is not None:
                    input_stft = input_stft[:, :, self.use_channel, :]

                else:
                    # Select 1ch randomly
View full source →
method

MultiChannelFrontend.output_size()

funasr.frontends.default.MultiChannelFrontend · View on GitHub ↗

Output size.

📄 Source code
    def output_size(self) -> int:
        """Return the feature dimension emitted by this frontend (``self.n_mels``)."""
        n_mels = self.n_mels
        return n_mels
method

MultiChannelFrontend.forward(input, input_lengths)

funasr.frontends.default.MultiChannelFrontend · View on GitHub ↗

Forward pass for training.

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
📄 Source code
    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. Domain-conversion: e.g. Stft: time -> time-freq
        """Forward pass for training.
        
            Args:
                input: Input audio/text data.
                input_lengths: Lengths of input.
            """
        if self.stft is not None:
            input_stft, feats_lens = self._compute_stft(input, input_lengths)
        else:
            input_stft = ComplexTensor(input[..., 0], input[..., 1])
            feats_lens = input_lengths
        # 2. [Option] Speech enhancement
        if self.frontend is not None:
            assert isinstance(input_stft, ComplexTensor), type(input_stft)
            # input_stft: (Batch, Length, [Channel], Freq)
            input_stft, _, mask = self.frontend(input_stft, feats_lens)

        # 3. [Multi channel case]: Select a channel(sa_asr)
        if input_stft.dim() == 4 and not self.mc:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                if self.use_channel is not None:
                    input_stft = input_stft[:, :, self.use_channel, :]

                else:
                    # Select 1ch randomly
View full source on GitHub →
function

transform(Y, dtype)

funasr.frontends.eend_ola_feature · View on GitHub ↗

Transform.

Args:

  • Y — TODO.
  • dtype — TODO.
📄 Source code
def transform(Y, dtype=np.float32):
    """Convert an STFT matrix into mean-normalized log-mel features.

    The magnitude of ``Y`` is squared and projected onto a 23-band mel
    filterbank built for an 8000 Hz sample rate. The result is floored at
    1e-10, converted with ``log10``, and the mean over axis 0 is subtracted
    from every row.

    Args:
        Y: 2-D STFT matrix whose second axis holds the ``n_fft // 2 + 1``
            frequency bins; the FFT size is inferred from ``Y.shape[1]``.
            May be complex (``np.abs`` is applied first).
        dtype: NumPy dtype of the returned array.

    Returns:
        Array of shape ``(Y.shape[0], 23)`` cast to ``dtype``.
    """
    Y = np.abs(Y)
    n_fft = 2 * (Y.shape[1] - 1)
    sr = 8000
    n_mels = 23
    # librosa >= 0.10 makes these arguments keyword-only; passing them by
    # keyword also works on older librosa releases.
    mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels)
    Y = np.dot(Y**2, mel_basis.T)
    # Floor before the log to avoid log10(0) = -inf on silent bins.
    Y = np.log10(np.maximum(Y, 1e-10))
    mean = np.mean(Y, axis=0)
    Y = Y - mean
    return Y.astype(dtype)
function

subsample(Y, T, subsampling)

funasr.frontends.eend_ola_feature · View on GitHub ↗

Subsample.

Args:

  • Y — TODO.
  • T — TODO.
  • subsampling — TODO.
📄 Source code
def subsample(Y, T, subsampling=1):
    """Keep every ``subsampling``-th element of ``Y`` and ``T``.

    Both sequences are sliced with the same step along their first axis,
    starting from index 0.

    Args:
        Y: Sliceable sequence (e.g. a feature matrix).
        T: Sliceable sequence aligned with ``Y`` (e.g. labels).
        subsampling: Step between kept elements.

    Returns:
        Tuple ``(Y[::subsampling], T[::subsampling])``.
    """
    every_nth = slice(None, None, subsampling)
    return Y[every_nth], T[every_nth]
function

splice(Y, context_size)

funasr.frontends.eend_ola_feature · View on GitHub ↗

Splice.

Args:

  • Y — TODO.
  • context_size — Size/dimension parameter.
📄 Source code
def splice(Y, context_size=0):
    """Concatenate each row of ``Y`` with its ``context_size`` neighbours.

    ``Y`` is zero-padded with ``context_size`` rows at the top and bottom.
    Output row ``t`` is the flattened window of ``2 * context_size + 1``
    padded rows starting at padded row ``t``. The result is a read-only
    strided view (no data copy beyond the padding).

    Args:
        Y: 2-D array.
        context_size: Number of neighbouring rows taken on each side.

    Returns:
        Read-only array of shape
        ``(Y.shape[0], Y.shape[1] * (2 * context_size + 1))``.
    """
    window = 2 * context_size + 1
    n_rows, n_cols = Y.shape
    padded = np.ascontiguousarray(
        np.pad(Y, [(context_size, context_size), (0, 0)], "constant")
    )
    return np.lib.stride_tricks.as_strided(
        padded,
        shape=(n_rows, n_cols * window),
        # Step one padded row per output row, one element per column.
        strides=(Y.itemsize * n_cols, Y.itemsize),
        writeable=False,
    )
function

stft(data, frame_size, frame_shift)

funasr.frontends.eend_ola_feature · View on GitHub ↗

Stft.

Args:

  • data — TODO.
  • frame_size — Size/dimension parameter.
  • frame_shift — TODO.
📄 Source code
def stft(data, frame_size=1024, frame_shift=256):
    """Compute a librosa STFT and return it frame-major.

    The FFT size is ``frame_size`` rounded up to the next power of two.
    When ``len(data)`` is an exact multiple of ``frame_shift``, the final
    frame is dropped.

    Args:
        data: 1-D audio signal passed to ``librosa.stft``.
        frame_size: Window length in samples.
        frame_shift: Hop length in samples.

    Returns:
        The transposed ``librosa.stft`` output (frames along axis 0).
    """
    # Smallest power of two that is >= frame_size.
    fft_size = 1 << (frame_size - 1).bit_length()
    spectrum = librosa.stft(
        data, n_fft=fft_size, win_length=frame_size, hop_length=frame_shift
    ).T
    if len(data) % frame_shift == 0:
        spectrum = spectrum[:-1]
    return spectrum
class

FusedFrontends

funasr.frontends.fused · View on GitHub ↗

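Fuses the features of several frontends into a single representation. Each frontend's output is projected to `proj_dim` dimensions; `linear_projection` is currently the only supported alignment method.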

📄 Source code
class FusedFrontends(nn.Module):
    def __init__(self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000):

        """Initialize FusedFrontends.
        
            Args:
                frontends: TODO.
                align_method: TODO.
                proj_dim: Size/dimension parameter.
                fs: TODO.
            """
        super().__init__()
        self.align_method = align_method  # fusing method : linear_projection only for now
        self.proj_dim = proj_dim  # dim of the projection done on each frontend
        self.frontends = []  # list of the frontends to combine

        for i, frontend in enumerate(frontends):
            frontend_type = frontend["frontend_type"]
            if frontend_type == "default":
                n_mels, fs, n_fft, win_length, hop_length = (
                    frontend.get("n_mels", 80),
                    fs,
                    frontend.get("n_fft", 512),
                    frontend.get("win_length"),
                    frontend.get("hop_length", 128),
                )
                window, center, normalized, onesided = (
                    frontend.get("window", "hann"),
                    frontend.get("center", True),
                    frontend.get("normalized", False),
View full source on GitHub →

Methods

.output_size() L106

Output size.

📄 Source
    def output_size(self) -> int:
        """Return the fused feature dimension: number of frontends times ``self.proj_dim``."""
        num_frontends = len(self.frontends)
        return self.proj_dim * num_frontends
.forward(input, input_lengths) L110

Forward pass for training.

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
📄 Source
    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        # step 0 : get all frontends features
        """Forward pass for training.
        
            Args:
                input: Input audio/text data.
                input_lengths: Lengths of input.
            """
        self.feats = []
        for frontend in self.frontends:
            with torch.no_grad():
                input_feats, feats_lens = frontend.forward(input, input_lengths)
            self.feats.append([input_feats, feats_lens])

        if self.align_method == "linear_projection":  # TODO(Dan): to add other align methods

            # first step : projections
            self.feats_proj = []
            for i, frontend in enumerate(self.frontends):
                input_feats = self.feats[i][0]
                self.feats_proj.append(self.projection_layers[i](input_feats))

            # 2nd step : reshape
            self.feats_reshaped = []
            for i, frontend in enumerate(self.frontends):
                input_feats_proj = self.feats_proj[i]
                bs, nf, dim = input_feats_proj.shape
View full source →
method

FusedFrontends.output_size()

funasr.frontends.fused.FusedFrontends · View on GitHub ↗

Output size.

📄 Source code
    def output_size(self) -> int:
        """Return the fused feature dimension: number of frontends times ``self.proj_dim``."""
        num_frontends = len(self.frontends)
        return self.proj_dim * num_frontends
method

FusedFrontends.forward(input, input_lengths)

funasr.frontends.fused.FusedFrontends · View on GitHub ↗

Forward pass for training.

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
📄 Source code
    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        # step 0 : get all frontends features
        """Forward pass for training.
        
            Args:
                input: Input audio/text data.
                input_lengths: Lengths of input.
            """
        self.feats = []
        for frontend in self.frontends:
            with torch.no_grad():
                input_feats, feats_lens = frontend.forward(input, input_lengths)
            self.feats.append([input_feats, feats_lens])

        if self.align_method == "linear_projection":  # TODO(Dan): to add other align methods

            # first step : projections
            self.feats_proj = []
            for i, frontend in enumerate(self.frontends):
                input_feats = self.feats[i][0]
                self.feats_proj.append(self.projection_layers[i](input_feats))

            # 2nd step : reshape
            self.feats_reshaped = []
            for i, frontend in enumerate(self.frontends):
                input_feats_proj = self.feats_proj[i]
                bs, nf, dim = input_feats_proj.shape
View full source on GitHub →
function

base_s3prl_setup(args)

funasr.frontends.s3prl · View on GitHub ↗

Base s3prl setup.

Args:

  • args — TODO.
📄 Source code
def base_s3prl_setup(args):
    """Fill in default S3PRL options on ``args`` without overriding existing ones.

    Any attribute already present on ``args`` is kept; missing ones are set
    to the defaults listed below. ``args`` is modified in place.

    Args:
        args: Attribute container (e.g. a namespace) holding S3PRL options.

    Returns:
        The same ``args`` object.
    """
    defaults = (
        ("upstream_feature_selection", None),
        ("upstream_model_config", None),
        ("upstream_refresh", False),
        ("upstream_ckpt", None),
        ("init_ckpt", None),
        ("verbose", False),
        ("tile_factor", 1),
    )
    for option, fallback in defaults:
        setattr(args, option, getattr(args, option, fallback))
    return args
class

S3prlFrontend

funasr.frontends.s3prl · View on GitHub ↗

Speech Pretrained Representation frontend structure for ASR.

📄 Source code
class S3prlFrontend(nn.Module):
    """Speech Pretrained Representation frontend structure for ASR."""

    def __init__(
        self,
        fs: Union[int, str] = 16000,
        frontend_conf: Optional[dict] = None,
        download_dir: str = None,
        multilayer_feature: bool = False,
    ):
        """Initialize S3prlFrontend.
        
            Args:
                fs: TODO.
                frontend_conf: Configuration dict for frontend.
                download_dir: TODO.
                multilayer_feature: TODO.
            """
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)

        if download_dir is not None:
            torch.hub.set_dir(download_dir)

        self.multilayer_feature = multilayer_feature
        self.upstream, self.featurizer = self._get_upstream(frontend_conf)
        self.pretrained_params = copy.deepcopy(self.upstream.state_dict())
        self.output_dim = self.featurizer.output_dim
        self.frontend_type = "s3prl"
View full source on GitHub →

Methods

.output_size() L127

Output size.

📄 Source
    def output_size(self) -> int:
        """Return the feature dimension, i.e. the featurizer's ``output_dim`` stored at construction."""
        dim = self.output_dim
        return dim
.forward(input, input_lengths) L131

Forward pass for training.

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
📄 Source
    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute S3PRL features for a padded batch of waveforms.

        Args:
            input: Padded batch of waveforms; row ``i`` is cut to
                ``input_lengths[i]`` samples before it reaches the upstream.
            input_lengths: Number of valid samples in each row of ``input``.

        Returns:
            ``(input_feats, feats_lens)``: the featurizer outputs padded with
            ``0.0`` by ``pad_list``, and a ``torch.long`` tensor holding the
            size of dim 0 of each per-utterance feature tensor.
        """
        wav_list = []
        for idx, padded_wav in enumerate(input):
            wav_list.append(padded_wav[: input_lengths[idx]])

        # The upstream model runs in eval mode without gradient tracking;
        # only the featurizer call below sits outside torch.no_grad().
        self.upstream.eval()
        with torch.no_grad():
            hidden = self.upstream(wav_list)
        hidden = self.featurizer(wav_list, hidden)

        if self.args.tile_factor != 1:
            hidden = self._tile_representations(hidden)

        padded_feats = pad_list(hidden, 0.0)
        lengths = torch.tensor([h.shape[0] for h in hidden], dtype=torch.long)

        # Drop the unpadded feature list early to release (CUDA) memory.
        del hidden

        return padded_feats, lengths
.reload_pretrained_parameters() L157

Reload pretrained parameters.

📄 Source
    def reload_pretrained_parameters(self):
        """Restore the upstream weights from ``self.pretrained_params``.

        ``self.pretrained_params`` is the deep copy of the upstream state dict
        taken in ``__init__``; this loads it back and logs a message.
        """
        snapshot = self.pretrained_params
        self.upstream.load_state_dict(snapshot)
        logging.info("Pretrained S3PRL frontend model parameters reloaded!")
method

S3prlFrontend.output_size()

funasr.frontends.s3prl.S3prlFrontend · View on GitHub ↗

Output size.

📄 Source code
    def output_size(self) -> int:
        """Return the feature dimension, i.e. the featurizer's ``output_dim`` stored at construction."""
        dim = self.output_dim
        return dim
method

S3prlFrontend.forward(input, input_lengths)

funasr.frontends.s3prl.S3prlFrontend · View on GitHub ↗

Forward pass for training.

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
📄 Source code
    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute S3PRL features for a padded batch of waveforms.

        Args:
            input: Padded batch of waveforms; row ``i`` is cut to
                ``input_lengths[i]`` samples before it reaches the upstream.
            input_lengths: Number of valid samples in each row of ``input``.

        Returns:
            ``(input_feats, feats_lens)``: the featurizer outputs padded with
            ``0.0`` by ``pad_list``, and a ``torch.long`` tensor holding the
            size of dim 0 of each per-utterance feature tensor.
        """
        wav_list = []
        for idx, padded_wav in enumerate(input):
            wav_list.append(padded_wav[: input_lengths[idx]])

        # The upstream model runs in eval mode without gradient tracking;
        # only the featurizer call below sits outside torch.no_grad().
        self.upstream.eval()
        with torch.no_grad():
            hidden = self.upstream(wav_list)
        hidden = self.featurizer(wav_list, hidden)

        if self.args.tile_factor != 1:
            hidden = self._tile_representations(hidden)

        padded_feats = pad_list(hidden, 0.0)
        lengths = torch.tensor([h.shape[0] for h in hidden], dtype=torch.long)

        # Drop the unpadded feature list early to release (CUDA) memory.
        del hidden

        return padded_feats, lengths
method

S3prlFrontend.reload_pretrained_parameters()

funasr.frontends.s3prl.S3prlFrontend · View on GitHub ↗

Reload pretrained parameters.

📄 Source code
    def reload_pretrained_parameters(self):
        """Restore the upstream weights from ``self.pretrained_params``.

        ``self.pretrained_params`` is the deep copy of the upstream state dict
        taken in ``__init__``; this loads it back and logs a message.
        """
        snapshot = self.pretrained_params
        self.upstream.load_state_dict(snapshot)
        logging.info("Pretrained S3PRL frontend model parameters reloaded!")
function

load_cmvn(cmvn_file)

funasr.frontends.wav_frontend · View on GitHub ↗

Load cmvn.

Args:

  • cmvn_file — TODO.
📄 Source code
def load_cmvn(cmvn_file):
    """Load CMVN shift/scale vectors from a text file.

    The file is scanned for lines whose first token is ``<AddShift>`` or
    ``<Rescale>``. The line after such a marker must start with
    ``<LearnRateCoef>``. Its tokens from index 3 up to, but excluding, the
    last token are taken as the shift vector (``<AddShift>``) or the scale
    vector (``<Rescale>``). Blank lines and markers without a following
    ``<LearnRateCoef>`` line are ignored.

    Args:
        cmvn_file: Path to the CMVN text file (read as UTF-8).

    Returns:
        A ``torch.float32`` tensor whose row 0 holds the shift values and
        row 1 the scale values (shape ``(2, dim)`` when both vectors have
        ``dim`` entries).
    """
    with open(cmvn_file, "r", encoding="utf-8") as f:
        lines = f.readlines()
    means_list = []
    scales_list = []
    for i, line in enumerate(lines):
        line_item = line.split()
        # Blank lines have no tokens; skip them instead of raising IndexError.
        if not line_item or line_item[0] not in ("<AddShift>", "<Rescale>"):
            continue
        # A marker on the very last line has no value line to read.
        if i + 1 >= len(lines):
            continue
        value_item = lines[i + 1].split()
        if not value_item or value_item[0] != "<LearnRateCoef>":
            continue
        # Drop the "<LearnRateCoef> <coef> [" prefix and the trailing "]".
        values = list(value_item[3 : (len(value_item) - 1)])
        if line_item[0] == "<AddShift>":
            means_list = values
        else:
            scales_list = values
    means = np.array(means_list).astype(np.float32)
    scales = np.array(scales_list).astype(np.float32)
    cmvn = np.array([means, scales])
    cmvn = torch.as_tensor(cmvn, dtype=torch.float32)
    return cmvn
function

apply_cmvn(inputs, cmvn)

funasr.frontends.wav_frontend · View on GitHub ↗

Apply CMVN with mvn data

📄 Source code
def apply_cmvn(inputs, cmvn):  # noqa
    """
    Apply CMVN with mvn data: ``(inputs + cmvn[0]) * cmvn[1]``.

    Args:
        inputs: Feature tensor whose last dimension is the feature dimension.
            It is not modified.
        cmvn: Tensor with the shifts in row 0 and the scales in row 1 (the
            layout produced by ``load_cmvn``); only the first
            ``inputs.shape[-1]`` columns are used.

    Returns:
        The normalized features as a new float32 tensor.
    """
    device = inputs.device
    dim = inputs.shape[-1]

    # (1, dim) rows broadcast over every leading dimension of ``inputs``.
    means = cmvn[0:1, :dim].to(device)
    scales = cmvn[1:2, :dim].to(device)
    # Out-of-place so a caller passing a view of a larger batch does not have
    # that batch overwritten.
    return ((inputs + means) * scales).type(torch.float32)
function

apply_lfr(inputs, lfr_m, lfr_n)

funasr.frontends.wav_frontend · View on GitHub ↗

Apply lfr.

Args:

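  • inputs — Feature matrix indexed as (frames, feat_dim).
  • lfr_m — Number of consecutive frames stacked into each output frame.
  • lfr_n — Stride, in input frames, between consecutive output frames.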
📄 Source code
def apply_lfr(inputs, lfr_m, lfr_n):
    """Apply low frame rate (LFR) frame stacking.

    The first frame is repeated ``(lfr_m - 1) // 2`` times on the left. Then
    every ``lfr_n``-th row starts a window of ``lfr_m`` consecutive rows, and
    each window is flattened into one output row. The last frame is repeated on
    the right as needed so that the final window is complete.

    Args:
        inputs: 2-D feature matrix indexed as (frames, feat_dim).
        lfr_m: Number of consecutive frames stacked into each output frame.
        lfr_n: Stride, in input frames, between consecutive output frames.

    Returns:
        A new float32 tensor of shape (ceil(frames / lfr_n), lfr_m * feat_dim).
    """
    T = inputs.shape[0]
    feat_dim = inputs.shape[-1]
    T_lfr = int(np.ceil(T / lfr_n))
    if T_lfr == 0:
        # No frames at all: the left padding below would index inputs[0].
        return inputs.new_zeros((0, lfr_m * feat_dim), dtype=torch.float32)

    left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
    inputs = torch.vstack((left_padding, inputs))

    # Exact number of extra rows needed so the last window, which starts at row
    # (T_lfr - 1) * lfr_n, has lfr_m rows available.
    num_padding = (T_lfr - 1) * lfr_n + lfr_m - inputs.shape[0]
    if num_padding > 0:
        inputs = torch.vstack([inputs] + [inputs[-1:]] * num_padding)

    # ``inputs`` is a freshly stacked contiguous tensor, so a row stride of
    # lfr_n * feat_dim elements walks window starts, and each window covers
    # lfr_m * feat_dim contiguous elements.
    LFR_outputs = inputs.as_strided((T_lfr, lfr_m * feat_dim), (lfr_n * feat_dim, 1))
    return LFR_outputs.clone().type(torch.float32)
class

WavFrontend

funasr.frontends.wav_frontend · View on GitHub ↗

Conventional frontend structure for ASR.

📄 Source code
class WavFrontend(nn.Module):
    """Conventional frontend structure for ASR."""

    def __init__(
        self,
        cmvn_file: str = None,
        fs: int = 16000,
        window: str = "hamming",
        n_mels: int = 80,
        frame_length: int = 25,
        frame_shift: int = 10,
        filter_length_min: int = -1,
        filter_length_max: int = -1,
        lfr_m: int = 1,
        lfr_n: int = 1,
        dither: float = 1.0,
        snip_edges: bool = True,
        upsacle_samples: bool = True,
        **kwargs,
    ):
        """Initialize WavFrontend.
        
            Args:
                cmvn_file: TODO.
                fs: TODO.
                window: TODO.
                n_mels: TODO.
                frame_length: TODO.
                frame_shift: TODO.
                filter_length_min: TODO.
View full source on GitHub →

Methods

.output_size() L145

Output size.

📄 Source
    def output_size(self) -> int:
        """Return the feature dimension of each output frame.

        Each output frame stacks ``lfr_m`` frames of ``n_mels`` filterbank bins.
        """
        bins_per_frame = self.n_mels
        stacked_frames = self.lfr_m
        return bins_per_frame * stacked_frames
.forward(input, input_lengths, **kwargs) L149

Forward pass for training.

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
  • **kwargs — Additional keyword arguments.
📄 Source
    def forward(
        self,
        input: torch.Tensor,
        input_lengths,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass for training.
        
            Args:
                input: Input audio/text data.
                input_lengths: Lengths of input.
                **kwargs: Additional keyword arguments.
            """
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            waveform_length = input_lengths[i]
            waveform = input[i][:waveform_length]
            if self.upsacle_samples:
                waveform = waveform * (1 << 15)
            waveform = waveform.unsqueeze(0)
            mat = kaldi.fbank(
                waveform,
                num_mel_bins=self.n_mels,
                frame_length=min(self.frame_length,waveform_length/self.fs*1000),
                frame_shift=self.frame_shift,
                dither=self.dither,
                energy_floor=0.0,
                window_type=self.window,
View full source →
.forward_fbank(input, input_lengths) L198

Forward fbank.

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
📄 Source
    def forward_fbank(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward fbank.
        
            Args:
                input: Input audio/text data.
                input_lengths: Lengths of input.
            """
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            waveform_length = input_lengths[i]
            waveform = input[i][:waveform_length]
            waveform = waveform * (1 << 15)
            waveform = waveform.unsqueeze(0)
            mat = kaldi.fbank(
                waveform,
                num_mel_bins=self.n_mels,
                frame_length=self.frame_length,
                frame_shift=self.frame_shift,
                dither=self.dither,
                energy_floor=0.0,
                window_type=self.window,
                sample_frequency=self.fs,
            )

            feat_length = mat.size(0)
            feats.append(mat)
View full source →
.forward_lfr_cmvn(input, input_lengths) L234

Forward lfr cmvn.

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
📄 Source
    def forward_lfr_cmvn(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply LFR stacking and CMVN to a padded batch of features.

        Args:
            input: Padded feature batch, indexed as ``input[i, :input_lengths[i], :]``
                (so presumably (batch, frames, dim)). It is not modified.
            input_lengths: Number of valid frames for each batch item.

        Returns:
            Tuple of the processed features re-padded batch-first with zeros,
            and a tensor of per-item output frame counts.
        """
        feats = []
        feats_lens = []
        for i in range(input.size(0)):
            mat = input[i, : input_lengths[i], :]
            if self.lfr_m != 1 or self.lfr_n != 1:
                # apply_lfr returns a fresh (cloned) tensor.
                mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
            elif self.cmvn is not None:
                # ``mat`` is still a view into ``input``; copy it so that an
                # in-place CMVN cannot overwrite the caller's batch.
                mat = mat.clone()
            if self.cmvn is not None:
                mat = apply_cmvn(mat, self.cmvn)
            feats.append(mat)
            feats_lens.append(mat.size(0))

        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
        return feats_pad, feats_lens
method

WavFrontend.output_size()

funasr.frontends.wav_frontend.WavFrontend · View on GitHub ↗

Output size.

📄 Source code
    def output_size(self) -> int:
        """Return the feature dimension of each output frame.

        Each output frame stacks ``lfr_m`` frames of ``n_mels`` filterbank bins.
        """
        bins_per_frame = self.n_mels
        stacked_frames = self.lfr_m
        return bins_per_frame * stacked_frames
method

WavFrontend.forward(input, input_lengths, **kwargs)

funasr.frontends.wav_frontend.WavFrontend · View on GitHub ↗

Forward pass for training.

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def forward(
        self,
        input: torch.Tensor,
        input_lengths,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass for training.
        
            Args:
                input: Input audio/text data.
                input_lengths: Lengths of input.
                **kwargs: Additional keyword arguments.
            """
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            waveform_length = input_lengths[i]
            waveform = input[i][:waveform_length]
            if self.upsacle_samples:
                waveform = waveform * (1 << 15)
            waveform = waveform.unsqueeze(0)
            mat = kaldi.fbank(
                waveform,
                num_mel_bins=self.n_mels,
                frame_length=min(self.frame_length,waveform_length/self.fs*1000),
                frame_shift=self.frame_shift,
                dither=self.dither,
                energy_floor=0.0,
                window_type=self.window,
View full source on GitHub →
method

WavFrontend.forward_fbank(input, input_lengths)

funasr.frontends.wav_frontend.WavFrontend · View on GitHub ↗

Forward fbank.

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
📄 Source code
    def forward_fbank(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward fbank.
        
            Args:
                input: Input audio/text data.
                input_lengths: Lengths of input.
            """
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            waveform_length = input_lengths[i]
            waveform = input[i][:waveform_length]
            waveform = waveform * (1 << 15)
            waveform = waveform.unsqueeze(0)
            mat = kaldi.fbank(
                waveform,
                num_mel_bins=self.n_mels,
                frame_length=self.frame_length,
                frame_shift=self.frame_shift,
                dither=self.dither,
                energy_floor=0.0,
                window_type=self.window,
                sample_frequency=self.fs,
            )

            feat_length = mat.size(0)
            feats.append(mat)
View full source on GitHub →
method

WavFrontend.forward_lfr_cmvn(input, input_lengths)

funasr.frontends.wav_frontend.WavFrontend · View on GitHub ↗

Forward lfr cmvn.

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
📄 Source code
    def forward_lfr_cmvn(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply LFR stacking and CMVN to a padded batch of features.

        Args:
            input: Padded feature batch, indexed as ``input[i, :input_lengths[i], :]``
                (so presumably (batch, frames, dim)). It is not modified.
            input_lengths: Number of valid frames for each batch item.

        Returns:
            Tuple of the processed features re-padded batch-first with zeros,
            and a tensor of per-item output frame counts.
        """
        feats = []
        feats_lens = []
        for i in range(input.size(0)):
            mat = input[i, : input_lengths[i], :]
            if self.lfr_m != 1 or self.lfr_n != 1:
                # apply_lfr returns a fresh (cloned) tensor.
                mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
            elif self.cmvn is not None:
                # ``mat`` is still a view into ``input``; copy it so that an
                # in-place CMVN cannot overwrite the caller's batch.
                mat = mat.clone()
            if self.cmvn is not None:
                mat = apply_cmvn(mat, self.cmvn)
            feats.append(mat)
            feats_lens.append(mat.size(0))

        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
        return feats_pad, feats_lens
class

WavFrontendOnline

funasr.frontends.wav_frontend · View on GitHub ↗

Conventional frontend structure for streaming ASR/VAD.

📄 Source code
class WavFrontendOnline(nn.Module):
    """Conventional frontend structure for streaming ASR/VAD."""

    def __init__(
        self,
        cmvn_file: str = None,
        fs: int = 16000,
        window: str = "hamming",
        n_mels: int = 80,
        frame_length: int = 25,
        frame_shift: int = 10,
        filter_length_min: int = -1,
        filter_length_max: int = -1,
        lfr_m: int = 1,
        lfr_n: int = 1,
        dither: float = 1.0,
        snip_edges: bool = True,
        upsacle_samples: bool = True,
        **kwargs,
    ):
        """Initialize WavFrontendOnline.
        
            Args:
                cmvn_file: TODO.
                fs: TODO.
                window: TODO.
                n_mels: TODO.
                frame_length: TODO.
                frame_shift: TODO.
                filter_length_min: TODO.
View full source on GitHub →

Methods

.output_size() L324

Output size.

📄 Source
    def output_size(self) -> int:
        """Return the feature dimension of each output frame.

        Each output frame stacks ``lfr_m`` frames of ``n_mels`` filterbank bins.
        """
        bins_per_frame = self.n_mels
        stacked_frames = self.lfr_m
        return bins_per_frame * stacked_frames
.apply_cmvn(inputs, cmvn) L329

Apply CMVN: add the per-dimension shifts in `cmvn[0]`, then multiply by the scales in `cmvn[1]`.

📄 Source
    def apply_cmvn(inputs: torch.Tensor, cmvn: torch.Tensor) -> torch.Tensor:
        """
        Apply CMVN with mvn data: ``(inputs + cmvn[0]) * cmvn[1]``.

        Args:
            inputs: Feature tensor whose last dimension is the feature dimension.
                It is not modified.
            cmvn: Shifts in row 0 and scales in row 1, as a torch tensor or a
                numpy array; only the first ``inputs.shape[-1]`` columns are used.

        Returns:
            The normalized features as a new float32 tensor.
        """
        dim = inputs.shape[-1]
        # Bring the (1, dim) rows to the input's device and dtype; broadcasting
        # then applies them to every frame, so no explicit tiling is needed.
        means = torch.as_tensor(cmvn[0:1, :dim]).to(device=inputs.device, dtype=inputs.dtype)
        scales = torch.as_tensor(cmvn[1:2, :dim]).to(device=inputs.device, dtype=inputs.dtype)
        # Out-of-place so the caller's tensor (possibly a view into a larger
        # batch) is left unchanged.
        return ((inputs + means) * scales).type(torch.float32)
.apply_lfr(inputs, lfr_m, lfr_n, is_final) L346

Apply low frame rate (LFR) frame stacking to one streaming chunk.

📄 Source
    def apply_lfr(
        inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        Apply lfr with data
        """

        LFR_inputs = []
        # inputs = torch.vstack((inputs_lfr_cache, inputs))
        T = inputs.shape[0]  # include the right context
        T_lfr = int(
            np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
        )  # minus the right context: (lfr_m - 1) // 2
        splice_idx = T_lfr
        feat_dim = inputs.shape[-1]
        ori_inputs = inputs
        strides = (lfr_n * feat_dim, 1)
        sizes = (T_lfr, lfr_m * feat_dim)
        last_idx = (T - lfr_m) // lfr_n + 1
        num_padding = lfr_m - (T - last_idx * lfr_n)
        if is_final:
            if num_padding > 0:
                num_padding = (2 * lfr_m - 2 * T + (T_lfr - 1 + last_idx) * lfr_n) / 2 * (T_lfr - last_idx)
                inputs = torch.vstack([inputs] + [inputs[-1:]] * int(num_padding))
        else:
            if num_padding > 0:
                sizes = (last_idx, lfr_m * feat_dim)
                splice_idx = last_idx
        splice_idx = min(T - 1, splice_idx * lfr_n)
        LFR_outputs = inputs[:splice_idx].as_strided(sizes, strides)
View full source →
.compute_frame_num(sample_length, frame_sample_length, frame_shift_sample_length) L380

Compute frame num.

Args:

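  • sample_length — Total number of available samples.
  • frame_sample_length — Window length of one frame, in samples.
  • frame_shift_sample_length — Hop between consecutive frame starts, in samples.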
📄 Source
    def compute_frame_num(
        sample_length: int, frame_sample_length: int, frame_shift_sample_length: int
    ) -> int:
        """Count how many complete frames fit into ``sample_length`` samples.

        Args:
            sample_length: Total number of available samples.
            frame_sample_length: Window length of one frame, in samples.
            frame_shift_sample_length: Hop between consecutive frame starts, in samples.

        Returns:
            The number of whole frames, or 0 when not even one frame fits.
        """
        span = sample_length - frame_sample_length
        count = int(span / frame_shift_sample_length + 1)
        if span < 0 or count < 1:
            return 0
        return count
.forward_fbank(input, input_lengths, cache, **kwargs) L393

Forward fbank.

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source
    def forward_fbank(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor,
        cache: dict = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward fbank.
        
            Args:
                input: Input audio/text data.
                input_lengths: Lengths of input.
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        batch_size = input.size(0)

        input = torch.cat((cache["input_cache"], input), dim=1)
        frame_num = self.compute_frame_num(
            input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length
        )
        # update self.in_cache
        cache["input_cache"] = input[
            :, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) :
        ]
        waveforms = torch.empty(0)
        feats_pad = torch.empty(0)
        feats_lens = torch.empty(0)
View full source →
.forward_lfr_cmvn(input, input_lengths, is_final, cache, **kwargs) L462

Forward lfr cmvn.

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
  • is_final — Whether this is the final chunk in streaming.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source
    def forward_lfr_cmvn(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor,
        is_final: bool = False,
        cache: dict = None,
        **kwargs,
    ):
        """Forward lfr cmvn.
        
            Args:
                input: Input audio/text data.
                input_lengths: Lengths of input.
                is_final: Whether this is the final chunk in streaming.
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        lfr_splice_frame_idxs = []
        for i in range(batch_size):
            mat = input[i, : input_lengths[i], :]
            if self.lfr_m != 1 or self.lfr_n != 1:
                # update self.lfr_splice_cache in self.apply_lfr
                # mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, self.lfr_splice_cache[i],
                mat, cache["lfr_splice_cache"][i], lfr_splice_frame_idx = self.apply_lfr(
                    mat, self.lfr_m, self.lfr_n, is_final
View full source →
.forward(input, input_lengths, **kwargs) L505

Extract features online for one streaming chunk (batch size must be 1).

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
  • **kwargs — Additional keyword arguments.
📄 Source
    def forward(self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs):
        """Forward pass for training.
        
            Args:
                input: Input audio/text data.
                input_lengths: Lengths of input.
                **kwargs: Additional keyword arguments.
            """
        is_final = kwargs.get("is_final", False)
        cache = kwargs.get("cache", {})
        if len(cache) == 0:
            self.init_cache(cache)

        batch_size = input.shape[0]
        assert (
            batch_size == 1
        ), "we support to extract feature online only when the batch size is equal to 1 now"

        waveforms, feats, feats_lengths = self.forward_fbank(
            input, input_lengths, cache=cache
        )  # input shape: B T D

        if feats.shape[0]:

            cache["waveforms"] = torch.cat((cache["reserve_waveforms"], waveforms), dim=1)

            if not cache["lfr_splice_cache"]:  # 初始化splice_cache
                for i in range(batch_size):
                    cache["lfr_splice_cache"].append(
                        feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1)
View full source →
.init_cache(cache) L594

Init cache.

Args:

  • cache — State cache dict for streaming inference.
📄 Source
    def init_cache(self, cache: dict = None):
        """Reset the streaming state held in ``cache``.

        The given dict is updated in place (a new dict is created when ``cache``
        is None): the waveform and input caches become empty tensors, the LFR
        splice cache an empty list, and the waveform/fbank slots ``None``.

        Args:
            cache: State cache dict for streaming inference.

        Returns:
            The (re)initialized cache dict.
        """
        state = {} if cache is None else cache
        state.update(
            {
                "reserve_waveforms": torch.empty(0),
                "input_cache": torch.empty(0),
                "lfr_splice_cache": [],
                "waveforms": None,
                "fbanks": None,
                "fbanks_lens": None,
            }
        )
        return state
method

WavFrontendOnline.output_size()

funasr.frontends.wav_frontend.WavFrontendOnline · View on GitHub ↗

Output size.

📄 Source code
    def output_size(self) -> int:
        """Return the feature dimension of each output frame.

        Each output frame stacks ``lfr_m`` frames of ``n_mels`` filterbank bins.
        """
        bins_per_frame = self.n_mels
        stacked_frames = self.lfr_m
        return bins_per_frame * stacked_frames
method

WavFrontendOnline.apply_cmvn(inputs, cmvn)

funasr.frontends.wav_frontend.WavFrontendOnline · View on GitHub ↗

Apply CMVN: add the per-dimension shifts in `cmvn[0]`, then multiply by the scales in `cmvn[1]`.

📄 Source code
    def apply_cmvn(inputs: torch.Tensor, cmvn: torch.Tensor) -> torch.Tensor:
        """
        Apply CMVN with mvn data: ``(inputs + cmvn[0]) * cmvn[1]``.

        Args:
            inputs: Feature tensor whose last dimension is the feature dimension.
                It is not modified.
            cmvn: Shifts in row 0 and scales in row 1, as a torch tensor or a
                numpy array; only the first ``inputs.shape[-1]`` columns are used.

        Returns:
            The normalized features as a new float32 tensor.
        """
        dim = inputs.shape[-1]
        # Bring the (1, dim) rows to the input's device and dtype; broadcasting
        # then applies them to every frame, so no explicit tiling is needed.
        means = torch.as_tensor(cmvn[0:1, :dim]).to(device=inputs.device, dtype=inputs.dtype)
        scales = torch.as_tensor(cmvn[1:2, :dim]).to(device=inputs.device, dtype=inputs.dtype)
        # Out-of-place so the caller's tensor (possibly a view into a larger
        # batch) is left unchanged.
        return ((inputs + means) * scales).type(torch.float32)
method

WavFrontendOnline.apply_lfr(inputs, lfr_m, lfr_n, is_final)

funasr.frontends.wav_frontend.WavFrontendOnline · View on GitHub ↗

Apply low frame rate (LFR) frame stacking to one streaming chunk.

📄 Source code
    def apply_lfr(
        inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        Apply lfr with data
        """

        LFR_inputs = []
        # inputs = torch.vstack((inputs_lfr_cache, inputs))
        T = inputs.shape[0]  # include the right context
        T_lfr = int(
            np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
        )  # minus the right context: (lfr_m - 1) // 2
        splice_idx = T_lfr
        feat_dim = inputs.shape[-1]
        ori_inputs = inputs
        strides = (lfr_n * feat_dim, 1)
        sizes = (T_lfr, lfr_m * feat_dim)
        last_idx = (T - lfr_m) // lfr_n + 1
        num_padding = lfr_m - (T - last_idx * lfr_n)
        if is_final:
            if num_padding > 0:
                num_padding = (2 * lfr_m - 2 * T + (T_lfr - 1 + last_idx) * lfr_n) / 2 * (T_lfr - last_idx)
                inputs = torch.vstack([inputs] + [inputs[-1:]] * int(num_padding))
        else:
            if num_padding > 0:
                sizes = (last_idx, lfr_m * feat_dim)
                splice_idx = last_idx
        splice_idx = min(T - 1, splice_idx * lfr_n)
        LFR_outputs = inputs[:splice_idx].as_strided(sizes, strides)
View full source on GitHub →
method

WavFrontendOnline.compute_frame_num(sample_length, frame_sample_length, frame_shift_sample_length)

funasr.frontends.wav_frontend.WavFrontendOnline · View on GitHub ↗

Compute frame num.

Args:

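  • sample_length — Total number of available samples.
  • frame_sample_length — Window length of one frame, in samples.
  • frame_shift_sample_length — Hop between consecutive frame starts, in samples.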
📄 Source code
    def compute_frame_num(
        sample_length: int, frame_sample_length: int, frame_shift_sample_length: int
    ) -> int:
        """Count how many complete frames fit into ``sample_length`` samples.

        Args:
            sample_length: Total number of available samples.
            frame_sample_length: Window length of one frame, in samples.
            frame_shift_sample_length: Hop between consecutive frame starts, in samples.

        Returns:
            The number of whole frames, or 0 when not even one frame fits.
        """
        span = sample_length - frame_sample_length
        count = int(span / frame_shift_sample_length + 1)
        if span < 0 or count < 1:
            return 0
        return count
method

WavFrontendOnline.forward_fbank(input, input_lengths, cache, **kwargs)

funasr.frontends.wav_frontend.WavFrontendOnline · View on GitHub ↗

Forward fbank.

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def forward_fbank(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor,
        cache: dict = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward fbank.
        
            Args:
                input: Input audio/text data.
                input_lengths: Lengths of input.
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        batch_size = input.size(0)

        input = torch.cat((cache["input_cache"], input), dim=1)
        frame_num = self.compute_frame_num(
            input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length
        )
        # update self.in_cache
        cache["input_cache"] = input[
            :, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) :
        ]
        waveforms = torch.empty(0)
        feats_pad = torch.empty(0)
        feats_lens = torch.empty(0)
View full source on GitHub →
method

WavFrontendOnline.forward_lfr_cmvn(input, input_lengths, is_final, cache, **kwargs)

funasr.frontends.wav_frontend.WavFrontendOnline · View on GitHub ↗

Forward lfr cmvn.

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
  • is_final — Whether this is the final chunk in streaming.
  • cache — State cache dict for streaming inference.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def forward_lfr_cmvn(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor,
        is_final: bool = False,
        cache: dict = None,
        **kwargs,
    ):
        """Forward lfr cmvn.
        
            Args:
                input: Input audio/text data.
                input_lengths: Lengths of input.
                is_final: Whether this is the final chunk in streaming.
                cache: State cache dict for streaming inference.
                **kwargs: Additional keyword arguments.
            """
        if cache is None:
            cache = {}
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        lfr_splice_frame_idxs = []
        for i in range(batch_size):
            mat = input[i, : input_lengths[i], :]
            if self.lfr_m != 1 or self.lfr_n != 1:
                # update self.lfr_splice_cache in self.apply_lfr
                # mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, self.lfr_splice_cache[i],
                mat, cache["lfr_splice_cache"][i], lfr_splice_frame_idx = self.apply_lfr(
                    mat, self.lfr_m, self.lfr_n, is_final
View full source on GitHub →
method

WavFrontendOnline.forward(input, input_lengths, **kwargs)

funasr.frontends.wav_frontend.WavFrontendOnline · View on GitHub ↗

Extract features online for one streaming chunk (batch size must be 1).

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def forward(self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs):
        """Forward pass for training.
        
            Args:
                input: Input audio/text data.
                input_lengths: Lengths of input.
                **kwargs: Additional keyword arguments.
            """
        is_final = kwargs.get("is_final", False)
        cache = kwargs.get("cache", {})
        if len(cache) == 0:
            self.init_cache(cache)

        batch_size = input.shape[0]
        assert (
            batch_size == 1
        ), "we support to extract feature online only when the batch size is equal to 1 now"

        waveforms, feats, feats_lengths = self.forward_fbank(
            input, input_lengths, cache=cache
        )  # input shape: B T D

        if feats.shape[0]:

            cache["waveforms"] = torch.cat((cache["reserve_waveforms"], waveforms), dim=1)

            if not cache["lfr_splice_cache"]:  # 初始化splice_cache
                for i in range(batch_size):
                    cache["lfr_splice_cache"].append(
                        feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1)
View full source on GitHub →
method

WavFrontendOnline.init_cache(cache)

funasr.frontends.wav_frontend.WavFrontendOnline · View on GitHub ↗

Init cache.

Args:

  • cache — State cache dict for streaming inference.
📄 Source code
    def init_cache(self, cache: dict = None):
        """Reset the streaming state held in ``cache``.

        The given dict is updated in place (a new dict is created when ``cache``
        is None): the waveform and input caches become empty tensors, the LFR
        splice cache an empty list, and the waveform/fbank slots ``None``.

        Args:
            cache: State cache dict for streaming inference.

        Returns:
            The (re)initialized cache dict.
        """
        state = {} if cache is None else cache
        state.update(
            {
                "reserve_waveforms": torch.empty(0),
                "input_cache": torch.empty(0),
                "lfr_splice_cache": [],
                "waveforms": None,
                "fbanks": None,
                "fbanks_lens": None,
            }
        )
        return state
class

WavFrontendMel23

funasr.frontends.wav_frontend · View on GitHub ↗

Conventional frontend structure for ASR.

📄 Source code
class WavFrontendMel23(nn.Module):
    """Conventional frontend structure for ASR."""

    def __init__(
        self,
        fs: int = 16000,
        frame_length: int = 25,
        frame_shift: int = 10,
        lfr_m: int = 1,
        lfr_n: int = 1,
        **kwargs,
    ):
        """Initialize WavFrontendMel23.
        
            Args:
                fs: TODO.
                frame_length: TODO.
                frame_shift: TODO.
                lfr_m: TODO.
                lfr_n: TODO.
                **kwargs: Additional keyword arguments.
            """
        super().__init__()
        self.fs = fs
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.n_mels = 23

View full source on GitHub →

Methods

.output_size() L641

Output size.

📄 Source
    def output_size(self) -> int:
        """Return the feature dimension of each output frame.

        Frames are spliced with a context of ``lfr_m``, giving ``2 * lfr_m + 1``
        frames of ``n_mels`` bins each.
        """
        spliced_frames = 2 * self.lfr_m + 1
        return self.n_mels * spliced_frames
.forward(input, input_lengths) L645

Forward pass for training.

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
📄 Source
    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute spliced features for a batch of waveforms.

        Each waveform is trimmed to its length and passed through
        ``eend_ola_feature.stft``, ``transform`` and ``splice`` (with
        ``context_size=self.lfr_m``). Every ``self.lfr_n``-th frame is then kept.

        Args:
            input: Batch of waveforms, indexed as ``input[i][:input_lengths[i]]``.
                It may live on any device and may require grad.
            input_lengths: Number of valid samples for each waveform.

        Returns:
            Tuple of the zero-padded, batch-first features (moved to ``input``'s
            device) and a tensor of per-item frame counts.
        """
        feats = []
        feats_lens = []
        for i in range(input.size(0)):
            waveform = input[i][: input_lengths[i]]
            # ``Tensor.numpy()`` only works for CPU tensors that do not require grad.
            waveform = waveform.detach().cpu().numpy()
            mat = eend_ola_feature.stft(waveform, self.frame_length, self.frame_shift)
            mat = eend_ola_feature.transform(mat)
            mat = eend_ola_feature.splice(mat, context_size=self.lfr_m)
            mat = torch.from_numpy(mat[:: self.lfr_n])
            feats.append(mat)
            feats_lens.append(mat.size(0))

        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
        return feats_pad.to(input.device), feats_lens
method

WavFrontendMel23.output_size()

funasr.frontends.wav_frontend.WavFrontendMel23 · View on GitHub ↗

Output size.

📄 Source code
    def output_size(self) -> int:
        """Return the feature dimension of each output frame.

        Frames are spliced with a context of ``lfr_m``, giving ``2 * lfr_m + 1``
        frames of ``n_mels`` bins each.
        """
        spliced_frames = 2 * self.lfr_m + 1
        return self.n_mels * spliced_frames
method

WavFrontendMel23.forward(input, input_lengths)

funasr.frontends.wav_frontend.WavFrontendMel23 · View on GitHub ↗

Forward pass for training.

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
📄 Source code
    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute spliced features for a batch of waveforms.

        Each waveform is trimmed to its length and passed through
        ``eend_ola_feature.stft``, ``transform`` and ``splice`` (with
        ``context_size=self.lfr_m``). Every ``self.lfr_n``-th frame is then kept.

        Args:
            input: Batch of waveforms, indexed as ``input[i][:input_lengths[i]]``.
                It may live on any device and may require grad.
            input_lengths: Number of valid samples for each waveform.

        Returns:
            Tuple of the zero-padded, batch-first features (moved to ``input``'s
            device) and a tensor of per-item frame counts.
        """
        feats = []
        feats_lens = []
        for i in range(input.size(0)):
            waveform = input[i][: input_lengths[i]]
            # ``Tensor.numpy()`` only works for CPU tensors that do not require grad.
            waveform = waveform.detach().cpu().numpy()
            mat = eend_ola_feature.stft(waveform, self.frame_length, self.frame_shift)
            mat = eend_ola_feature.transform(mat)
            mat = eend_ola_feature.splice(mat, context_size=self.lfr_m)
            mat = torch.from_numpy(mat[:: self.lfr_n])
            feats.append(mat)
            feats_lens.append(mat.size(0))

        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
        return feats_pad.to(input.device), feats_lens
class

WhisperFrontend

funasr.frontends.whisper_frontend · View on GitHub ↗

Speech Representation Using Encoder Outputs from OpenAI's Whisper Model:

  • URL — https://github.com/openai/whisper
📄 Source code
class WhisperFrontend(nn.Module):
    """Speech Representation Using Encoder Outputs from OpenAI's Whisper Model:

    URL: https://github.com/openai/whisper
    """

    def __init__(
        self,
        fs: int = 16000,
        whisper_model: str = None,
        do_pad_trim: bool = True,
        n_mels: int = 80,
        permute: bool = False,
        **kwargs,
    ):
        """Initialize WhisperFrontend.
        
            Args:
                fs: TODO.
                whisper_model: Whisper Model instance.
                do_pad_trim: TODO.
                n_mels: TODO.
                permute: TODO.
                **kwargs: Additional keyword arguments.
            """
        super().__init__()
        assert fs == 16000
        self.fs = fs
        import whisper
        from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES
View full source on GitHub →

Methods

.output_size() L67

Output size.

📄 Source
    def output_size(self) -> int:
        """Return the feature dimension, i.e. the number of mel bins (``self.n_mels``)."""
        mel_bins = self.n_mels
        return mel_bins
.log_mel_spectrogram(audio, ilens) L71

Log mel spectrogram.

Args:

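  • audio — Waveform tensor passed to torch.stft.
  • ilens — Optional input lengths; output lengths are computed as ilens // hop_length (None when omitted).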
📄 Source
    def log_mel_spectrogram(
        self,
        audio: torch.Tensor,
        ilens: torch.Tensor = None,
    ) -> torch.Tensor:
        """Log mel spectrogram.
        
            Args:
                audio: TODO.
                ilens: TODO.
            """
        window = torch.hann_window(self.win_length).to(audio.device)
        stft = torch.stft(audio, self.n_fft, self.hop_length, window=window, return_complex=True)

        # whisper deletes the last frame by default (Shih-Lun)
        magnitudes = stft[..., :-1].abs() ** 2
        if self.filters_path is not None:
            filters = self.mel_filters(audio.device, self.n_mels, self.filters_path)
        else:
            filters = self.mel_filters(audio.device, self.n_mels)
        mel_spec = filters @ magnitudes

        log_spec = torch.clamp(mel_spec, min=1e-10).log10()

        if ilens is not None:
            olens = ilens // self.hop_length
        else:
            olens = None

        log_spec = torch.maximum(
View full source →
.forward(input, input_lengths, **kwargs) L108

Forward pass for training.

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
  • **kwargs — Additional keyword arguments.
📄 Source
    def forward(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass for training.
        
            Args:
                input: Input audio/text data.
                input_lengths: Lengths of input.
                **kwargs: Additional keyword arguments.
            """
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        input = input.to(torch.float32)
        for i in range(batch_size):
            if self.do_pad_trim:
                feat = self.pad_or_trim(input[i], self.pad_samples)
            else:
                feat = input[i]
            feat, feat_len = self.log_mel_spectrogram(feat[None, :], input_lengths[0])
            feats.append(feat[0])
            feats_lens.append(feat_len)
        feats_lens = torch.as_tensor(feats_lens)

        if batch_size == 1:
            feats_pad = feats[0][None, :, :]
        else:
View full source →
method

WhisperFrontend.output_size()

funasr.frontends.whisper_frontend.WhisperFrontend · View on GitHub ↗

Output size.

📄 Source code
    def output_size(self) -> int:
        """Return the feature dimension, i.e. the number of mel bins (``self.n_mels``)."""
        mel_bins = self.n_mels
        return mel_bins
method

WhisperFrontend.log_mel_spectrogram(audio, ilens)

funasr.frontends.whisper_frontend.WhisperFrontend · View on GitHub ↗

Log mel spectrogram.

Args:

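  • audio — Waveform tensor passed to torch.stft.
  • ilens — Optional input lengths; output lengths are computed as ilens // hop_length (None when omitted).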
📄 Source code
    def log_mel_spectrogram(
        self,
        audio: torch.Tensor,
        ilens: torch.Tensor = None,
    ) -> torch.Tensor:
        """Log mel spectrogram.
        
            Args:
                audio: TODO.
                ilens: TODO.
            """
        window = torch.hann_window(self.win_length).to(audio.device)
        stft = torch.stft(audio, self.n_fft, self.hop_length, window=window, return_complex=True)

        # whisper deletes the last frame by default (Shih-Lun)
        magnitudes = stft[..., :-1].abs() ** 2
        if self.filters_path is not None:
            filters = self.mel_filters(audio.device, self.n_mels, self.filters_path)
        else:
            filters = self.mel_filters(audio.device, self.n_mels)
        mel_spec = filters @ magnitudes

        log_spec = torch.clamp(mel_spec, min=1e-10).log10()

        if ilens is not None:
            olens = ilens // self.hop_length
        else:
            olens = None

        log_spec = torch.maximum(
View full source on GitHub →
method

WhisperFrontend.forward(input, input_lengths, **kwargs)

funasr.frontends.whisper_frontend.WhisperFrontend · View on GitHub ↗

Forward pass for training.

Args:

  • input — Input audio/text data.
  • input_lengths — Lengths of input.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def forward(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass for training.
        
            Args:
                input: Input audio/text data.
                input_lengths: Lengths of input.
                **kwargs: Additional keyword arguments.
            """
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        input = input.to(torch.float32)
        for i in range(batch_size):
            if self.do_pad_trim:
                feat = self.pad_or_trim(input[i], self.pad_samples)
            else:
                feat = input[i]
            feat, feat_len = self.log_mel_spectrogram(feat[None, :], input_lengths[0])
            feats.append(feat[0])
            feats_lens.append(feat_len)
        feats_lens = torch.as_tensor(feats_lens)

        if batch_size == 1:
            feats_pad = feats[0][None, :, :]
        else:
View full source on GitHub →
class

SlidingWindow

funasr.frontends.windowing · View on GitHub ↗

Sliding Window.

Provides a sliding window over a batched continuous raw audio tensor.

Optionally, provides padding (Currently not implemented).

Combine this module with a pre-encoder compatible with raw audio data,

for example Sinc convolutions.

Known issues:

Output length is calculated incorrectly if audio shorter than win_length.

  • WARNING — trailing values are discarded - padding not implemented yet.

There is currently no additional window function applied to input values.

📄 Source code
class SlidingWindow(nn.Module):
    """Sliding Window.
    Provides a sliding window over a batched continuous raw audio tensor.
    Optionally, provides padding (Currently not implemented).
    Combine this module with a pre-encoder compatible with raw audio data,
    for example Sinc convolutions.
    Known issues:
    Output length is calculated incorrectly if audio shorter than win_length.
    WARNING: trailing values are discarded - padding not implemented yet.
    There is currently no additional window function applied to input values.
    """

    def __init__(
        self,
        win_length: int = 400,
        hop_length: int = 160,
        channels: int = 1,
        padding: int = None,
        fs=None,
    ):
        """Initialize.
        Args:
            win_length: Length of frame.
            hop_length: Relative starting point of next frame.
            channels: Number of input channels.
            padding: Padding (placeholder, currently not implemented).
            fs:  Sampling rate (placeholder for compatibility, not used).
        """
        super().__init__()
        self.fs = fs
View full source on GitHub →

Methods

.forward(input, input_lengths) L47

Apply a sliding window on the input.

Args:

  • input — Input (B, T, C*D) or (B, T*C*D), with D=C=1.
  • input_lengths — Input lengths within batch.

Returns:

  • Tensor — Output with dimensions (B, T, C, D), with D=win_length.
  • Tensor — Output lengths within batch.
📄 Source
    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply a sliding window on the input.

        Args:
            input: Input (B, T, C*D) or (B, T*C*D), with D=C=1.
            input_lengths: Input lengths within batch.

        Returns:
            Tensor: Output with dimensions (B, T, C, D), with D=win_length.
            Tensor: Output lengths within batch, i.e. the number of complete
                windows per utterance. Utterances shorter than win_length get
                length 0.
        """
        input_size = input.size()
        B = input_size[0]
        T = input_size[1]
        C = self.channels
        D = self.win_length
        # (B, T, C) --> (T, B, C)
        continuous = input.view(B, T, C).permute(1, 0, 2)
        # Windows of D samples every hop_length samples; trailing samples that
        # do not fill a whole window are dropped (no padding).
        windowed = continuous.unfold(0, D, self.hop_length)
        # (T, B, C, D) --> (B, T, C, D)
        output = windowed.permute(1, 0, 2, 3).contiguous()
        # After unfold(), windowed lengths change. Without the clamp, floor
        # division gives negative lengths for inputs shorter than win_length.
        output_lengths = (input_lengths - self.win_length) // self.hop_length + 1
        output_lengths = output_lengths.clamp(min=0)
        return output, output_lengths
.output_size() L72

Return output length of feature dimension D, i.e. the window length.

📄 Source
    def output_size(self) -> int:
        """Return the size of the last output dimension D, which equals ``win_length``."""
        window_size = self.win_length
        return window_size
method

SlidingWindow.forward(input, input_lengths)

funasr.frontends.windowing.SlidingWindow · View on GitHub ↗

Apply a sliding window on the input.

Args:

  • input — Input (B, T, C*D) or (B, T*C*D), with D=C=1.
  • input_lengths — Input lengths within batch.

Returns:

  • Tensor — Output with dimensions (B, T, C, D), with D=win_length.
  • Tensor — Output lengths within batch.
📄 Source code
    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply a sliding window on the input.

        Args:
            input: Input (B, T, C*D) or (B, T*C*D), with D=C=1.
            input_lengths: Input lengths within batch.

        Returns:
            Tensor: Output with dimensions (B, T, C, D), with D=win_length.
            Tensor: Output lengths within batch, i.e. the number of complete
                windows per utterance. Utterances shorter than win_length get
                length 0.
        """
        input_size = input.size()
        B = input_size[0]
        T = input_size[1]
        C = self.channels
        D = self.win_length
        # (B, T, C) --> (T, B, C)
        continuous = input.view(B, T, C).permute(1, 0, 2)
        # Windows of D samples every hop_length samples; trailing samples that
        # do not fill a whole window are dropped (no padding).
        windowed = continuous.unfold(0, D, self.hop_length)
        # (T, B, C, D) --> (B, T, C, D)
        output = windowed.permute(1, 0, 2, 3).contiguous()
        # After unfold(), windowed lengths change. Without the clamp, floor
        # division gives negative lengths for inputs shorter than win_length.
        output_lengths = (input_lengths - self.win_length) // self.hop_length + 1
        output_lengths = output_lengths.clamp(min=0)
        return output, output_lengths
method

SlidingWindow.output_size()

funasr.frontends.windowing.SlidingWindow · View on GitHub ↗

Return output length of feature dimension D, i.e. the window length.

📄 Source code
    def output_size(self) -> int:
        """Return the size of the last output dimension D, which equals ``win_length``."""
        window_size = self.win_length
        return window_size
function

build_tokenizer(token_type, bpemodel, non_linguistic_symbols, remove_non_linguistic_symbols, space_symbol, delimiter, g2p_type)

funasr.tokenizer.build_tokenizer · View on GitHub ↗

A helper function to instantiate Tokenizer

📄 Source code
def build_tokenizer(
    token_type: str,
    bpemodel: Union[Path, str, Iterable[str]] = None,
    non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
    remove_non_linguistic_symbols: bool = False,
    space_symbol: str = "<space>",
    delimiter: str = None,
    g2p_type: str = None,
) -> AbsTokenizer:
    """A helper function to instantiate Tokenizer"""
    if token_type == "bpe":
        if bpemodel is None:
            raise ValueError('bpemodel is required if token_type = "bpe"')

        if remove_non_linguistic_symbols:
            raise RuntimeError(
                "remove_non_linguistic_symbols is not implemented for token_type=bpe"
            )
        return SentencepiecesTokenizer(bpemodel)

    elif token_type == "word":
        if remove_non_linguistic_symbols and non_linguistic_symbols is not None:
            return WordTokenizer(
                delimiter=delimiter,
                non_linguistic_symbols=non_linguistic_symbols,
                remove_non_linguistic_symbols=True,
            )
        else:
            return WordTokenizer(delimiter=delimiter)

View full source on GitHub →
class

CharTokenizer

funasr.tokenizer.char_tokenizer · View on GitHub ↗

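Character-level tokenizer. text2tokens splits text into single characters and drops spaces. Configured non-linguistic symbols are kept as whole tokens (or removed). When a segmentation dictionary (seg_dict) is set, space-separated words are mapped through seg_tokenize instead.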

📄 Source code
class CharTokenizer(BaseTokenizer):
    def __init__(
        self,
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        space_symbol: str = "<space>",
        remove_non_linguistic_symbols: bool = False,
        split_with_space: bool = False,
        seg_dict: str = None,
        **kwargs,
    ):
        """Initialize CharTokenizer.
        
            Args:
                non_linguistic_symbols: TODO.
                space_symbol: TODO.
                remove_non_linguistic_symbols: TODO.
                split_with_space: TODO.
                seg_dict: TODO.
                **kwargs: Additional keyword arguments.
            """
        super().__init__(**kwargs)
        self.space_symbol = space_symbol
        if non_linguistic_symbols is None:
            self.non_linguistic_symbols = set()
        elif isinstance(non_linguistic_symbols, (Path, str)):
            non_linguistic_symbols = Path(non_linguistic_symbols)
            try:
                with non_linguistic_symbols.open("r", encoding="utf-8") as f:
                    self.non_linguistic_symbols = set(line.rstrip() for line in f)
            except FileNotFoundError:
View full source on GitHub →

Methods

.text2tokens(line) L63

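Split a line into tokens: single characters (spaces dropped), with non-linguistic symbols kept as whole tokens or removed; uses seg_tokenize when seg_dict is set.

Args:

  • line — Input text.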
📄 Source
    def text2tokens(self, line: Union[str, list]) -> List[str]:
        """Split a line of text into tokens.

        If ``self.seg_dict`` is set, the stripped line is split on single
        spaces and the words are mapped through ``seg_tokenize``. Otherwise the
        line is split into single characters:
        - a prefix matching one of ``self.non_linguistic_symbols`` becomes one
          token, or is dropped when ``self.remove_non_linguistic_symbols`` is
          true;
        - spaces are dropped.

        Args:
            line: Input text.

        Returns:
            List of token strings.
        """
        if self.seg_dict is not None:
            tokens = line.strip().split(" ")
            return seg_tokenize(tokens, self.seg_dict)

        tokens = []
        pos = 0
        end = len(line)
        # Walk the line with an index instead of re-slicing the remaining text
        # after every token, which made tokenization quadratic in line length.
        while pos < end:
            for w in self.non_linguistic_symbols:
                # Skip empty symbols (e.g. from a blank line in the symbols
                # file): they match everywhere, consume nothing, and used to
                # loop forever.
                if w and line.startswith(w, pos):
                    if not self.remove_non_linguistic_symbols:
                        tokens.append(w)
                    pos += len(w)
                    break
            else:
                t = line[pos]
                pos += 1
                if t != " ":
                    tokens.append(t)
        return tokens
.tokens2text(tokens) L92

Join tokens into a string, mapping space_symbol back to a space.

Args:

  • tokens — Token strings to join.
📄 Source
    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Join tokens into a string, turning ``self.space_symbol`` back into a space.

        Args:
            tokens: Token strings, e.g. as produced by ``text2tokens``.

        Returns:
            The concatenated text.
        """
        space = self.space_symbol
        return "".join(" " if tok == space else tok for tok in tokens)
method

CharTokenizer.text2tokens(line)

funasr.tokenizer.char_tokenizer.CharTokenizer · View on GitHub ↗

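Split a line into tokens: single characters (spaces dropped), with non-linguistic symbols kept as whole tokens or removed; uses seg_tokenize when seg_dict is set.

Args:

  • line — Input text.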
📄 Source code
    def text2tokens(self, line: Union[str, list]) -> List[str]:
        """Split a line of text into tokens.

        If ``self.seg_dict`` is set, the stripped line is split on single
        spaces and the words are mapped through ``seg_tokenize``. Otherwise the
        line is split into single characters:
        - a prefix matching one of ``self.non_linguistic_symbols`` becomes one
          token, or is dropped when ``self.remove_non_linguistic_symbols`` is
          true;
        - spaces are dropped.

        Args:
            line: Input text.

        Returns:
            List of token strings.
        """
        if self.seg_dict is not None:
            tokens = line.strip().split(" ")
            return seg_tokenize(tokens, self.seg_dict)

        tokens = []
        pos = 0
        end = len(line)
        # Walk the line with an index instead of re-slicing the remaining text
        # after every token, which made tokenization quadratic in line length.
        while pos < end:
            for w in self.non_linguistic_symbols:
                # Skip empty symbols (e.g. from a blank line in the symbols
                # file): they match everywhere, consume nothing, and used to
                # loop forever.
                if w and line.startswith(w, pos):
                    if not self.remove_non_linguistic_symbols:
                        tokens.append(w)
                    pos += len(w)
                    break
            else:
                t = line[pos]
                pos += 1
                if t != " ":
                    tokens.append(t)
        return tokens
method

CharTokenizer.tokens2text(tokens)

funasr.tokenizer.char_tokenizer.CharTokenizer · View on GitHub ↗

Join tokens into a string, mapping space_symbol back to a space.

Args:

  • tokens — Token strings to join.
📄 Source code
    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Join tokens into a string, turning ``self.space_symbol`` back into a space.

        Args:
            tokens: Token strings, e.g. as produced by ``text2tokens``.

        Returns:
            The concatenated text.
        """
        space = self.space_symbol
        return "".join(" " if tok == space else tok for tok in tokens)
function

load_seg_dict(seg_dict_file)

funasr.tokenizer.char_tokenizer · View on GitHub ↗

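Load a segmentation dictionary: each line "word unit1 unit2 ..." maps word to the space-joined units.

Args:

  • seg_dict_file — Path to the dictionary file (a str).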
📄 Source code
def load_seg_dict(seg_dict_file):
    """Load a segmentation dictionary from a whitespace-separated text file.

    Each non-blank line has the form ``word unit1 unit2 ...``. The result maps
    ``word`` to the string ``"unit1 unit2 ..."``: the units joined by single
    spaces, or the empty string if the line holds only the word.

    Args:
        seg_dict_file: Path to the dictionary file, as a ``str``.

    Returns:
        dict mapping each word to its space-joined units. A later line with the
        same word overrides an earlier one.
    """
    assert isinstance(seg_dict_file, str)
    seg_dict = {}
    with open(seg_dict_file, "r", encoding="utf8") as f:
        for line in f:
            fields = line.split()
            # Blank or whitespace-only lines (e.g. an extra empty line at the
            # end of the file) used to raise IndexError on fields[0].
            if not fields:
                continue
            seg_dict[fields[0]] = " ".join(fields[1:])
    return seg_dict
function

seg_tokenize(txt, seg_dict)

funasr.tokenizer.char_tokenizer · View on GitHub ↗

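Map a list of words to sub-word units using a segmentation dictionary; unknown words or characters become "<unk>".

Args:

  • txt — Iterable of words.
  • seg_dict — Mapping from a word or character to a space-separated string of units.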
📄 Source code
def seg_tokenize(txt, seg_dict):
    """Map words to sub-word units using a segmentation dictionary.

    Each word is lower-cased, then handled as follows:
    - if it is a key of ``seg_dict``, its entry is used;
    - otherwise, if its first character is a CJK ideograph (U+4E00-U+9FA5), an
      ASCII letter or a digit, each character is looked up on its own, with
      ``"<unk>"`` for characters missing from ``seg_dict``;
    - any other word becomes a single ``"<unk>"``.

    Empty words are skipped.

    Args:
        txt: Iterable of words, e.g. ``line.strip().split(" ")``.
        seg_dict: Mapping from a word or character to a space-separated string
            of units, as returned by ``load_seg_dict``.

    Returns:
        Flat list of unit strings.
    """
    pattern = re.compile(r"([\u4E00-\u9FA5A-Za-z0-9])")
    units = []
    for word in txt:
        # Consecutive spaces in ``line.split(" ")`` yield empty strings; they
        # are not words and previously each produced a spurious "<unk>".
        if not word:
            continue
        word = word.lower()
        if word in seg_dict:
            units.extend(seg_dict[word].split())
        elif pattern.match(word):
            # Only the first character is tested; every character is then
            # looked up individually.
            for char in word:
                if char in seg_dict:
                    units.extend(seg_dict[char].split())
                else:
                    units.append("<unk>")
        else:
            units.append("<unk>")
    return units
class

TextCleaner

funasr.tokenizer.cleaner · View on GitHub ↗

Text cleaner.

Examples:

>>> cleaner = TextCleaner("tacotron")

>>> cleaner("(Hello-World); & jr. & dr.")

'HELLO WORLD, AND JUNIOR AND DOCTOR'

📄 Source code
class TextCleaner:
    """Text cleaner.

    Examples:
        >>> cleaner = TextCleaner("tacotron")
        >>> cleaner("(Hello-World);   &  jr. & dr.")
        'HELLO WORLD, AND JUNIOR AND DOCTOR'

    """

    def __init__(self, cleaner_types: Collection[str] = None):

        """Initialize TextCleaner.
        
            Args:
                cleaner_types: TODO.
            """
        if cleaner_types is None:
            self.cleaner_types = []
        elif isinstance(cleaner_types, str):
            self.cleaner_types = [cleaner_types]
        else:
            self.cleaner_types = list(cleaner_types)

    def __call__(self, text: str) -> str:
        """Internal: call  .
        
            Args:
                text: Text tensor or string input.
            """
View full source on GitHub →
function

HuggingfaceTokenizer(init_param_path, **kwargs)

funasr.tokenizer.hf_tokenizer · View on GitHub ↗

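Load a Hugging Face tokenizer via AutoTokenizer.from_pretrained.

Args:

  • init_param_path — Model identifier or local directory passed to AutoTokenizer.from_pretrained.
  • **kwargs — Accepted for interface compatibility; not used.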
📄 Source code
def HuggingfaceTokenizer(init_param_path, **kwargs):
    """Build a Hugging Face tokenizer with ``AutoTokenizer.from_pretrained``.

    Args:
        init_param_path: Model identifier or local directory passed to
            ``AutoTokenizer.from_pretrained``.
        **kwargs: Accepted for interface compatibility; not used.

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained``.

    Raises:
        ImportError: If ``transformers`` cannot be imported.
    """
    try:
        from transformers import AutoTokenizer
    except Exception as err:
        message = (
            "HuggingfaceTokenizer requires 'transformers'. "
            "Please install it with: pip install -U transformers"
        )
        raise ImportError(message) from err
    return AutoTokenizer.from_pretrained(init_param_path)
class

KoreanCleaner

funasr.tokenizer.korean_cleaner · View on GitHub ↗

No documentation yet.

📄 Source code
class KoreanCleaner:
    @classmethod
    def _normalize_numbers(cls, text):
        """Internal: normalize numbers.
        
            Args:
                text: Text tensor or string input.
            """
        number_to_kor = {
            "0": "영",
            "1": "일",
            "2": "이",
            "3": "삼",
            "4": "사",
            "5": "오",
            "6": "육",
            "7": "칠",
            "8": "팔",
            "9": "구",
        }
        new_text = "".join(
            number_to_kor[char] if char in number_to_kor.keys() else char for char in text
        )
        return new_text

    @classmethod
    def _normalize_english_text(cls, text):
        """Internal: normalize english text.
        
            Args:
View full source on GitHub →

Methods

.normalize_text(cls, text) L75

Normalize text.

Args:

  • text — Text tensor or string input.
📄 Source
    def normalize_text(cls, text):
        """Normalize text in three stages.

        The stages, in order:
        1. Strip surrounding whitespace.
        2. Apply ``cls._normalize_numbers``.
        3. Apply ``cls._normalize_english_text``.

        Args:
            text: Input string.

        Returns:
            The normalized string.
        """
        stripped = text.strip()
        with_numbers = cls._normalize_numbers(stripped)
        return cls._normalize_english_text(with_numbers)
method

KoreanCleaner.normalize_text(cls, text)

funasr.tokenizer.korean_cleaner.KoreanCleaner · View on GitHub ↗

Normalize text.

Args:

  • text — Text tensor or string input.
📄 Source code
    def normalize_text(cls, text):
        """Normalize text in three stages.

        The stages, in order:
        1. Strip surrounding whitespace.
        2. Apply ``cls._normalize_numbers``.
        3. Apply ``cls._normalize_english_text``.

        Args:
            text: Input string.

        Returns:
            The normalized string.
        """
        stripped = text.strip()
        with_numbers = cls._normalize_numbers(stripped)
        return cls._normalize_english_text(with_numbers)
function

split_by_space(text)

funasr.tokenizer.phoneme_tokenizer · View on GitHub ↗

Split text on single spaces; each run of three spaces yields a literal " " token.

Args:

  • text — Input string.
📄 Source code
def split_by_space(text) -> List[str]:
    """Split text on single spaces, treating three consecutive spaces as a space token.

    Each run of three spaces is first rewritten to `` <space> `` so that,
    after splitting on ``" "``, it yields one token that is then turned back
    into a literal ``" "``.

    Args:
        text: Input string.

    Returns:
        List of tokens.
    """
    if "   " not in text:
        return text.split(" ")
    marked = text.replace("   ", " <space> ")
    return [piece.replace("<space>", " ") for piece in marked.split(" ")]
function

pyopenjtalk_g2p(text)

funasr.tokenizer.phoneme_tokenizer · View on GitHub ↗

Pyopenjtalk g2p.

Args:

  • text — Text tensor or string input.
📄 Source code
def pyopenjtalk_g2p(text) -> List[str]:
    """Convert text to phonemes with ``pyopenjtalk.g2p``.

    Args:
        text: Input string.

    Returns:
        Phoneme strings, obtained by splitting the space-separated output of
        ``pyopenjtalk.g2p(text, kana=False)``.
    """
    import pyopenjtalk

    return pyopenjtalk.g2p(text, kana=False).split(" ")
function

pyopenjtalk_g2p_accent(text)

funasr.tokenizer.phoneme_tokenizer · View on GitHub ↗

Pyopenjtalk g2p accent.

Args:

  • text — Text tensor or string input.
📄 Source code
def pyopenjtalk_g2p_accent(text) -> List[str]:
    """Convert text to phonemes interleaved with accent fields via pyopenjtalk.

    Each label in ``pyopenjtalk.run_frontend(text)[1]`` is parsed with one
    regex. If the regex matches exactly once, three items are emitted in this
    order: the field between ``-`` and ``+``, the digits after ``/F:...._``,
    and the ``/A:`` value. Labels that do not match exactly once are skipped.

    Args:
        text: Input string.

    Returns:
        Flat list of phoneme and accent strings.
    """
    import pyopenjtalk
    import re

    label_re = re.compile(r"\-(.*?)\+.*?\/A:([0-9\-]+).*?\/F:.*?_([0-9]+)")
    phones = []
    for label in pyopenjtalk.run_frontend(text)[1]:
        found = label_re.findall(label)
        if len(found) != 1:
            continue
        phoneme, a_field, f_field = found[0]
        phones.extend([phoneme, f_field, a_field])
    return phones
function

pyopenjtalk_g2p_accent_with_pause(text)

funasr.tokenizer.phoneme_tokenizer · View on GitHub ↗

Pyopenjtalk g2p accent with pause.

Args:

  • text — Text tensor or string input.
📄 Source code
def pyopenjtalk_g2p_accent_with_pause(text) -> List[str]:
    """Like ``pyopenjtalk_g2p_accent``, but also emits ``"pau"`` for pause labels.

    A label whose field between the first ``-`` and the following ``+`` is
    ``"pau"`` contributes the single item ``"pau"``. Any other label is parsed
    with one regex. If the regex matches exactly once, three items are emitted:
    the field between ``-`` and ``+``, the digits after ``/F:...._``, and the
    ``/A:`` value.

    Args:
        text: Input string.

    Returns:
        Flat list of phoneme, accent and pause strings.
    """
    import pyopenjtalk
    import re

    label_re = re.compile(r"\-(.*?)\+.*?\/A:([0-9\-]+).*?\/F:.*?_([0-9]+)")
    phones = []
    for label in pyopenjtalk.run_frontend(text)[1]:
        current = label.split("-")[1].split("+")[0]
        if current == "pau":
            phones.append("pau")
            continue
        found = label_re.findall(label)
        if len(found) != 1:
            continue
        phoneme, a_field, f_field = found[0]
        phones.extend([phoneme, f_field, a_field])
    return phones
function

pyopenjtalk_g2p_kana(text)

funasr.tokenizer.phoneme_tokenizer · View on GitHub ↗

Pyopenjtalk g2p kana.

Args:

  • text — Text tensor or string input.
📄 Source code
def pyopenjtalk_g2p_kana(text) -> List[str]:
    """Convert text to a list of kana characters via ``pyopenjtalk.g2p(..., kana=True)``.

    Args:
        text: Input string.

    Returns:
        The characters of the kana string, one per list element.
    """
    import pyopenjtalk

    kana_string = pyopenjtalk.g2p(text, kana=True)
    return [ch for ch in kana_string]
function

pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels)

funasr.tokenizer.phoneme_tokenizer · View on GitHub ↗

Extract phoneme + prosody symbol sequence from input full-context labels.

The algorithm is based on `Prosodic features control by symbols as input of sequence-to-sequence acoustic modeling for neural TTS`_ with some r9y9's tweaks.

Args:

  • text (str) — Input text.
  • drop_unvoiced_vowels (bool) — whether to drop unvoiced vowels.

Returns:

List[str]: List of phoneme + prosody symbols.

Examples:

>>> from funasr.tokenizer.phoneme_tokenizer import pyopenjtalk_g2p_prosody

>>> pyopenjtalk_g2p_prosody("こんにちは。")

['^', 'k', 'o', '[', 'N', 'n', 'i', 'ch', 'i', 'w', 'a', '$']

.. _`Prosodic features control by symbols as input of sequence-to-sequence acoustic modeling for neural TTS`: https://doi.org/10.1587/transinf.2020EDP7104

📄 Source code
def pyopenjtalk_g2p_prosody(text: str, drop_unvoiced_vowels: bool = True) -> List[str]:
    """Extract phoneme + prosoody symbol sequence from input full-context labels.

    The algorithm is based on `Prosodic features control by symbols as input of
    sequence-to-sequence acoustic modeling for neural TTS`_ with some r9y9's tweaks.

    Args:
        text (str): Input text.
        drop_unvoiced_vowels (bool): whether to drop unvoiced vowels.

    Returns:
        List[str]: List of phoneme + prosody symbols.

    Examples:
        >>> from funasr.tokenizer.phoneme_tokenizer import pyopenjtalk_g2p_prosody
        >>> pyopenjtalk_g2p_prosody("こんにちは。")
        ['^', 'k', 'o', '[', 'N', 'n', 'i', 'ch', 'i', 'w', 'a', '$']

    .. _`Prosodic features control by symbols as input of sequence-to-sequence acoustic
        modeling for neural TTS`: https://doi.org/10.1587/transinf.2020EDP7104

    """
    import pyopenjtalk

    labels = pyopenjtalk.run_frontend(text)[1]
    N = len(labels)

    phones = []
    for n in range(N):
        lab_curr = labels[n]
View full source on GitHub →
function

pypinyin_g2p(text)

funasr.tokenizer.phoneme_tokenizer · View on GitHub ↗

Pypinyin g2p.

Args:

  • text — Text tensor or string input.
📄 Source code
def pypinyin_g2p(text) -> List[str]:
    """Convert text to pinyin syllables with ``Style.TONE3`` (tone number appended).

    Args:
        text: Input string.

    Returns:
        One string per syllable: the first candidate returned by ``pypinyin.pinyin``.
    """
    from pypinyin import Style, pinyin

    syllables = pinyin(text, style=Style.TONE3)
    return [candidates[0] for candidates in syllables]
function

pypinyin_g2p_phone(text)

funasr.tokenizer.phoneme_tokenizer · View on GitHub ↗

Pypinyin g2p phone.

Args:

  • text — Text tensor or string input.
📄 Source code
def pypinyin_g2p_phone(text) -> List[str]:
    """Convert text to pinyin initials and finals (``Style.TONE3``).

    Each syllable's first candidate is split with pypinyin's strict
    ``get_initials`` / ``get_finals``. The initial is emitted before the
    final, and empty parts are omitted.

    Args:
        text: Input string.

    Returns:
        Flat list of initial/final strings.
    """
    from pypinyin import Style, pinyin
    from pypinyin.style._utils import get_finals, get_initials

    phones = []
    for candidates in pinyin(text, style=Style.TONE3):
        syllable = candidates[0]
        parts = (
            get_initials(syllable, strict=True),
            get_finals(syllable, strict=True),
        )
        for part in parts:
            if len(part) != 0:
                phones.append(part)
    return phones
class

G2p_en

funasr.tokenizer.phoneme_tokenizer · View on GitHub ↗

On behalf of g2p_en.G2p.

g2p_en.G2p isn't picklable, so it can't be copied to other processes via the multiprocessing module.

As a workaround, g2p_en.G2p is instantiated upon calling this class.

📄 Source code
class G2p_en:
    """On behalf of g2p_en.G2p.

    g2p_en.G2p isn't pickalable and it can't be copied to the other processes
    via multiprocessing module.
    As a workaround, g2p_en.G2p is instantiated upon calling this class.

    """

    def __init__(self, no_space: bool = False):
        """Initialize G2p_en.
        
            Args:
                no_space: TODO.
            """
        self.no_space = no_space
        self.g2p = None

    def __call__(self, text) -> List[str]:
        """Internal: call  .
        
            Args:
                text: Text tensor or string input.
            """
        if self.g2p is None:
            self.g2p = g2p_en.G2p()

        phones = self.g2p(text)
        if self.no_space:
            # remove space which represents word serapater
View full source on GitHub →
class

G2pk

funasr.tokenizer.phoneme_tokenizer · View on GitHub ↗

On behalf of g2pk.G2p.

g2pk.G2p isn't picklable, so it can't be copied to other processes via the multiprocessing module.

As a workaround, g2pk.G2p is instantiated upon calling this class.

📄 Source code
class G2pk:
    """On behalf of g2pk.G2p.

    g2pk.G2p isn't pickalable and it can't be copied to the other processes
    via multiprocessing module.
    As a workaround, g2pk.G2p is instantiated upon calling this class.

    """

    def __init__(self, descritive=False, group_vowels=False, to_syl=False, no_space=False):
        """Initialize G2pk.
        
            Args:
                descritive: TODO.
                group_vowels: TODO.
                to_syl: TODO.
                no_space: TODO.
            """
        self.descritive = descritive
        self.group_vowels = group_vowels
        self.to_syl = to_syl
        self.no_space = no_space
        self.g2p = None

    def __call__(self, text) -> List[str]:
        """Internal: call  .
        
            Args:
                text: Text tensor or string input.
            """
View full source on GitHub →
class

Jaso

funasr.tokenizer.phoneme_tokenizer · View on GitHub ↗

No documentation yet.

📄 Source code
class Jaso:
    PUNC = "!'(),-.:;?"
    SPACE = " "

    JAMO_LEADS = "".join([chr(_) for _ in range(0x1100, 0x1113)])
    JAMO_VOWELS = "".join([chr(_) for _ in range(0x1161, 0x1176)])
    JAMO_TAILS = "".join([chr(_) for _ in range(0x11A8, 0x11C3)])

    VALID_CHARS = JAMO_LEADS + JAMO_VOWELS + JAMO_TAILS + PUNC + SPACE

    def __init__(self, space_symbol=" ", no_space=False):
        """Initialize Jaso.
        
            Args:
                space_symbol: TODO.
                no_space: TODO.
            """
        self.space_symbol = space_symbol
        self.no_space = no_space

    def _text_to_jaso(self, line: str) -> List[str]:
        """Internal: text to jaso.
        
            Args:
                line: TODO.
            """
        jasos = list(jamo.hangul_to_jamo(line))
        return jasos

    def _remove_non_korean_characters(self, tokens):
View full source on GitHub →
class

Phonemizer

funasr.tokenizer.phoneme_tokenizer · View on GitHub ↗

Phonemizer module for various languages.

This is wrapper module of https://github.com/bootphon/phonemizer.

You can define various g2p modules by specifying options for phonemizer.

See available options:

https://github.com/bootphon/phonemizer/blob/master/phonemizer/phonemize.py#L32

📄 Source code
class Phonemizer:
    """Phonemizer module for various languages.

    This is wrapper module of https://github.com/bootphon/phonemizer.
    You can define various g2p modules by specifying options for phonemizer.

    See available options:
        https://github.com/bootphon/phonemizer/blob/master/phonemizer/phonemize.py#L32

    """

    def __init__(
        self,
        backend,
        word_separator: Optional[str] = None,
        syllable_separator: Optional[str] = None,
        phone_separator: Optional[str] = " ",
        strip=False,
        split_by_single_token: bool = False,
        **phonemizer_kwargs,
    ):
        # delayed import
        """Initialize Phonemizer.
        
            Args:
                backend: TODO.
                word_separator: TODO.
                syllable_separator: TODO.
                phone_separator: TODO.
                strip: TODO.
View full source on GitHub →
class

PhonemeTokenizer

funasr.tokenizer.phoneme_tokenizer · View on GitHub ↗

No documentation yet.

📄 Source code
class PhonemeTokenizer(AbsTokenizer):
    def __init__(
        self,
        g2p_type: Union[None, str],
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        space_symbol: str = "<space>",
        remove_non_linguistic_symbols: bool = False,
    ):
        """Initialize PhonemeTokenizer.
        
            Args:
                g2p_type: TODO.
                non_linguistic_symbols: TODO.
                space_symbol: TODO.
                remove_non_linguistic_symbols: TODO.
            """
        if g2p_type is None:
            self.g2p = split_by_space
        elif g2p_type == "g2p_en":
            self.g2p = G2p_en(no_space=False)
        elif g2p_type == "g2p_en_no_space":
            self.g2p = G2p_en(no_space=True)
        elif g2p_type == "pyopenjtalk":
            self.g2p = pyopenjtalk_g2p
        elif g2p_type == "pyopenjtalk_kana":
            self.g2p = pyopenjtalk_g2p_kana
        elif g2p_type == "pyopenjtalk_accent":
            self.g2p = pyopenjtalk_g2p_accent
        elif g2p_type == "pyopenjtalk_accent_with_pause":
            self.g2p = pyopenjtalk_g2p_accent_with_pause
View full source on GitHub →

Methods

.text2tokens(line) L614

Text2tokens.

Args:

  • line — TODO.
📄 Source
    def text2tokens(self, line: str) -> List[str]:
        """Convert a text line into phoneme tokens.

        Entries of ``self.non_linguistic_symbols`` are matched greedily at each
        position. When ``self.remove_non_linguistic_symbols`` is true they are
        dropped before the remaining text is passed to ``self.g2p``.

        Args:
            line: Input text.

        Returns:
            The token list produced by ``self.g2p``.
        """
        if not self.remove_non_linguistic_symbols:
            # Keeping every matched symbol reassembles the exact input, so the
            # scan would be a no-op; hand the line straight to g2p.
            return self.g2p(line)

        # Empty entries (e.g. a blank line in a symbol file) would match at every
        # position without consuming input and loop forever, so ignore them.
        symbols = [w for w in self.non_linguistic_symbols if w]
        kept = []
        pos = 0
        length = len(line)
        while pos < length:
            for w in symbols:
                if line.startswith(w, pos):
                    pos += len(w)  # drop the symbol
                    break
            else:
                kept.append(line[pos])
                pos += 1
        return self.g2p("".join(kept))
.tokens2text(tokens) L637

Tokens2text.

Args:

  • tokens — TODO.
📄 Source
    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Concatenate phoneme tokens with no separator.

        Phoneme tokenization is not invertible, so this is only a best-effort
        reconstruction of the original text.

        Args:
            tokens: Phoneme tokens.

        Returns:
            The tokens joined into one string.
        """
        separator = ""
        return separator.join(tokens)
method

PhonemeTokenizer.text2tokens(line)

funasr.tokenizer.phoneme_tokenizer.PhonemeTokenizer · View on GitHub ↗

Text2tokens.

Args:

  • line — TODO.
📄 Source code
    def text2tokens(self, line: str) -> List[str]:
        """Convert a text line into phoneme tokens.

        Entries of ``self.non_linguistic_symbols`` are matched greedily at each
        position. When ``self.remove_non_linguistic_symbols`` is true they are
        dropped before the remaining text is passed to ``self.g2p``.

        Args:
            line: Input text.

        Returns:
            The token list produced by ``self.g2p``.
        """
        if not self.remove_non_linguistic_symbols:
            # Keeping every matched symbol reassembles the exact input, so the
            # scan would be a no-op; hand the line straight to g2p.
            return self.g2p(line)

        # Empty entries (e.g. a blank line in a symbol file) would match at every
        # position without consuming input and loop forever, so ignore them.
        symbols = [w for w in self.non_linguistic_symbols if w]
        kept = []
        pos = 0
        length = len(line)
        while pos < length:
            for w in symbols:
                if line.startswith(w, pos):
                    pos += len(w)  # drop the symbol
                    break
            else:
                kept.append(line[pos])
                pos += 1
        return self.g2p("".join(kept))
method

PhonemeTokenizer.tokens2text(tokens)

funasr.tokenizer.phoneme_tokenizer.PhonemeTokenizer · View on GitHub ↗

Tokens2text.

Args:

  • tokens — TODO.
📄 Source code
    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Concatenate phoneme tokens with no separator.

        Phoneme tokenization is not invertible, so this is only a best-effort
        reconstruction of the original text.

        Args:
            tokens: Phoneme tokens.

        Returns:
            The tokens joined into one string.
        """
        separator = ""
        return separator.join(tokens)
class

SentencepiecesTokenizer

funasr.tokenizer.sentencepiece_tokenizer · View on GitHub ↗

No documentation yet.

📄 Source code
class SentencepiecesTokenizer(BaseTokenizer):
    def __init__(self, bpemodel: Union[Path, str], **kwargs):
        """Initialize SentencepiecesTokenizer.
        
            Args:
                bpemodel: TODO.
                **kwargs: Additional keyword arguments.
            """
        super().__init__(**kwargs)
        self.bpemodel = str(bpemodel)
        # NOTE(kamo):
        # Don't build SentencePieceProcessor in __init__()
        # because it's not picklable and it may cause following error,
        # "TypeError: can't pickle SwigPyObject objects",
        # when giving it as argument of "multiprocessing.Process()".
        self.sp = None
        self._build_sentence_piece_processor()

    def __repr__(self):
        """Internal: repr  ."""
        return f'{self.__class__.__name__}(model="{self.bpemodel}")'

    def _build_sentence_piece_processor(self):
        # Build SentencePieceProcessor lazily.
        """Internal: build sentence piece processor."""
        if self.sp is None:
            self.sp = spm.SentencePieceProcessor()
            self.sp.load(self.bpemodel)

    def text2tokens(self, line: str) -> List[str]:
View full source on GitHub →

Methods

.text2tokens(line) L42

Text2tokens.

Args:

  • line — TODO.
📄 Source
    def text2tokens(self, line: str) -> List[str]:
        """Split ``line`` into SentencePiece pieces.

        Args:
            line: Text to segment.

        Returns:
            The pieces returned by ``EncodeAsPieces``.
        """
        self._build_sentence_piece_processor()  # creates self.sp if still None
        processor = self.sp
        return processor.EncodeAsPieces(line)
.tokens2text(tokens) L51

Tokens2text.

Args:

  • tokens — TODO.
📄 Source
    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Join SentencePiece pieces back into text.

        Args:
            tokens: Pieces to decode; materialized into a list first.

        Returns:
            The text returned by ``DecodePieces``.
        """
        self._build_sentence_piece_processor()  # creates self.sp if still None
        pieces = list(tokens)
        return self.sp.DecodePieces(pieces)
.encode(line, **kwargs) L60

Encode.

Args:

  • line — TODO.
  • **kwargs — Additional keyword arguments.
📄 Source
    def encode(self, line: str, **kwargs) -> List[int]:
        """Encode ``line`` into SentencePiece ids.

        Args:
            line: Text to encode.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            The ids returned by ``EncodeAsIds``.
        """
        self._build_sentence_piece_processor()  # creates self.sp if still None
        processor = self.sp
        return processor.EncodeAsIds(line)
.decode(line, **kwargs) L70

Decode.

Args:

  • line — TODO.
  • **kwargs — Additional keyword arguments.
📄 Source
    def decode(self, line: List[int], **kwargs):
        """Decode SentencePiece ids back into text.

        Args:
            line: Ids to decode.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            The value returned by ``DecodeIds``.
        """
        self._build_sentence_piece_processor()  # creates self.sp if still None
        processor = self.sp
        return processor.DecodeIds(line)
.get_vocab_size() L80

Get vocab size.

📄 Source
    def get_vocab_size(self):
        """Return the number of pieces in the SentencePiece vocabulary.

        Like the other accessors, this first builds the processor when
        ``self.sp`` is still ``None`` instead of failing with an
        ``AttributeError`` on ``None``.
        """
        self._build_sentence_piece_processor()
        return self.sp.GetPieceSize()
.ids2tokens(*args, **kwargs) L84

Ids2tokens.

Args:

  • *args — Variable positional arguments.
  • **kwargs — Additional keyword arguments.
📄 Source
    def ids2tokens(self, *args, **kwargs):
        """Forward all arguments to :meth:`decode`.

        Note: despite the name, this returns whatever ``decode`` returns
        (the decoded text), not a list of pieces.

        Args:
            *args: Positional arguments for ``decode``.
            **kwargs: Keyword arguments for ``decode``.
        """
        decoder = self.decode
        return decoder(*args, **kwargs)
.tokens2ids(*args, **kwargs) L93

Tokens2ids.

Args:

  • *args — Variable positional arguments.
  • **kwargs — Additional keyword arguments.
📄 Source
    def tokens2ids(self, *args, **kwargs):
        """Forward all arguments to :meth:`encode`.

        Note: ``encode`` takes a text line, so the first argument is treated
        as text rather than a list of pieces.

        Args:
            *args: Positional arguments for ``encode``.
            **kwargs: Keyword arguments for ``encode``.
        """
        encoder = self.encode
        return encoder(*args, **kwargs)
method

SentencepiecesTokenizer.text2tokens(line)

funasr.tokenizer.sentencepiece_tokenizer.SentencepiecesTokenizer · View on GitHub ↗

Text2tokens.

Args:

  • line — TODO.
📄 Source code
    def text2tokens(self, line: str) -> List[str]:
        """Split ``line`` into SentencePiece pieces.

        Args:
            line: Text to segment.

        Returns:
            The pieces returned by ``EncodeAsPieces``.
        """
        self._build_sentence_piece_processor()  # creates self.sp if still None
        processor = self.sp
        return processor.EncodeAsPieces(line)
method

SentencepiecesTokenizer.tokens2text(tokens)

funasr.tokenizer.sentencepiece_tokenizer.SentencepiecesTokenizer · View on GitHub ↗

Tokens2text.

Args:

  • tokens — TODO.
📄 Source code
    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Join SentencePiece pieces back into text.

        Args:
            tokens: Pieces to decode; materialized into a list first.

        Returns:
            The text returned by ``DecodePieces``.
        """
        self._build_sentence_piece_processor()  # creates self.sp if still None
        pieces = list(tokens)
        return self.sp.DecodePieces(pieces)
method

SentencepiecesTokenizer.encode(line, **kwargs)

funasr.tokenizer.sentencepiece_tokenizer.SentencepiecesTokenizer · View on GitHub ↗

Encode.

Args:

  • line — TODO.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def encode(self, line: str, **kwargs) -> List[int]:
        """Encode ``line`` into SentencePiece ids.

        Args:
            line: Text to encode.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            The ids returned by ``EncodeAsIds``.
        """
        self._build_sentence_piece_processor()  # creates self.sp if still None
        processor = self.sp
        return processor.EncodeAsIds(line)
method

SentencepiecesTokenizer.decode(line, **kwargs)

funasr.tokenizer.sentencepiece_tokenizer.SentencepiecesTokenizer · View on GitHub ↗

Decode.

Args:

  • line — TODO.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def decode(self, line: List[int], **kwargs):
        """Decode SentencePiece ids back into text.

        Args:
            line: Ids to decode.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            The value returned by ``DecodeIds``.
        """
        self._build_sentence_piece_processor()  # creates self.sp if still None
        processor = self.sp
        return processor.DecodeIds(line)
method

SentencepiecesTokenizer.get_vocab_size()

funasr.tokenizer.sentencepiece_tokenizer.SentencepiecesTokenizer · View on GitHub ↗

Get vocab size.

📄 Source code
    def get_vocab_size(self):
        """Return the number of pieces in the SentencePiece vocabulary.

        Like the other accessors, this first builds the processor when
        ``self.sp`` is still ``None`` instead of failing with an
        ``AttributeError`` on ``None``.
        """
        self._build_sentence_piece_processor()
        return self.sp.GetPieceSize()
method

SentencepiecesTokenizer.ids2tokens(*args, **kwargs)

funasr.tokenizer.sentencepiece_tokenizer.SentencepiecesTokenizer · View on GitHub ↗

Ids2tokens.

Args:

  • *args — Variable positional arguments.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def ids2tokens(self, *args, **kwargs):
        """Forward all arguments to :meth:`decode`.

        Note: despite the name, this returns whatever ``decode`` returns
        (the decoded text), not a list of pieces.

        Args:
            *args: Positional arguments for ``decode``.
            **kwargs: Keyword arguments for ``decode``.
        """
        decoder = self.decode
        return decoder(*args, **kwargs)
method

SentencepiecesTokenizer.tokens2ids(*args, **kwargs)

funasr.tokenizer.sentencepiece_tokenizer.SentencepiecesTokenizer · View on GitHub ↗

Tokens2ids.

Args:

  • *args — Variable positional arguments.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def tokens2ids(self, *args, **kwargs):
        """Forward all arguments to :meth:`encode`.

        Note: ``encode`` takes a text line, so the first argument is treated
        as text rather than a list of pieces.

        Args:
            *args: Positional arguments for ``encode``.
            **kwargs: Keyword arguments for ``encode``.
        """
        encoder = self.encode
        return encoder(*args, **kwargs)
class

TokenIDConverter

funasr.tokenizer.token_id_converter · View on GitHub ↗

No documentation yet.

📄 Source code
class TokenIDConverter:
    def __init__(
        self,
        token_list: Union[Path, str, Iterable[str]],
        unk_symbol: str = "<unk>",
    ):

        """Initialize TokenIDConverter.
        
            Args:
                token_list: TODO.
                unk_symbol: TODO.
            """
        if isinstance(token_list, (Path, str)):
            token_list = Path(token_list)
            self.token_list_repr = str(token_list)
            self.token_list: List[str] = []

            with token_list.open("r", encoding="utf-8") as f:
                for idx, line in enumerate(f):
                    line = line.rstrip()
                    self.token_list.append(line)

        else:
            self.token_list: List[str] = list(token_list)
            self.token_list_repr = ""
            for i, t in enumerate(self.token_list):
                if i == 3:
                    break
                self.token_list_repr += f"{t}, "
View full source on GitHub →

Methods

.get_num_vocabulary_size() L53

Get num vocabulary size.

📄 Source
    def get_num_vocabulary_size(self) -> int:
        """Return how many entries ``self.token_list`` holds."""
        vocabulary = self.token_list
        return len(vocabulary)
.ids2tokens(integers) L57

Ids2tokens.

Args:

  • integers — TODO.
📄 Source
    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        """Map token ids to their strings via ``self.token_list``.

        Args:
            integers: Token ids; a numpy array must be one-dimensional.

        Returns:
            The token string for each id, in order.

        Raises:
            ValueError: If ``integers`` is a numpy array whose ``ndim`` is not 1.
        """
        is_array = isinstance(integers, np.ndarray)
        if is_array and integers.ndim != 1:
            raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
        vocabulary = self.token_list
        return [vocabulary[index] for index in integers]
.tokens2ids(tokens) L67

Tokens2ids.

Args:

  • tokens — TODO.
📄 Source
    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        """Map token strings to ids via ``self.token2id``.

        Tokens missing from the mapping are replaced by ``self.unk_id``.

        Args:
            tokens: Token strings.

        Returns:
            One id per input token, in order.
        """
        ids = []
        for token in tokens:
            ids.append(self.token2id.get(token, self.unk_id))
        return ids
method

TokenIDConverter.get_num_vocabulary_size()

funasr.tokenizer.token_id_converter.TokenIDConverter · View on GitHub ↗

Get num vocabulary size.

📄 Source code
    def get_num_vocabulary_size(self) -> int:
        """Return how many entries ``self.token_list`` holds."""
        vocabulary = self.token_list
        return len(vocabulary)
method

TokenIDConverter.ids2tokens(integers)

funasr.tokenizer.token_id_converter.TokenIDConverter · View on GitHub ↗

Ids2tokens.

Args:

  • integers — TODO.
📄 Source code
    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        """Map token ids to their strings via ``self.token_list``.

        Args:
            integers: Token ids; a numpy array must be one-dimensional.

        Returns:
            The token string for each id, in order.

        Raises:
            ValueError: If ``integers`` is a numpy array whose ``ndim`` is not 1.
        """
        is_array = isinstance(integers, np.ndarray)
        if is_array and integers.ndim != 1:
            raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
        vocabulary = self.token_list
        return [vocabulary[index] for index in integers]
method

TokenIDConverter.tokens2ids(tokens)

funasr.tokenizer.token_id_converter.TokenIDConverter · View on GitHub ↗

Tokens2ids.

Args:

  • tokens — TODO.
📄 Source code
    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        """Map token strings to ids via ``self.token2id``.

        Tokens missing from the mapping are replaced by ``self.unk_id``.

        Args:
            tokens: Token strings.

        Returns:
            One id per input token, in order.
        """
        ids = []
        for token in tokens:
            ids.append(self.token2id.get(token, self.unk_id))
        return ids
function

WhisperTokenizer(**kwargs)

funasr.tokenizer.whisper_tokenizer · View on GitHub ↗

Create an OpenAI Whisper tokenizer via `whisper.tokenizer.get_tokenizer`.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
def WhisperTokenizer(**kwargs):
    """Create an OpenAI Whisper tokenizer.

    Keyword Args:
        language: Passed to ``get_tokenizer`` (default ``None``).
        task: Passed to ``get_tokenizer`` (default ``"transcribe"``).
        is_multilingual: Passed as ``multilingual`` (default ``True``).
        num_languages: Passed to ``get_tokenizer`` (default ``99``).

    Returns:
        The tokenizer returned by ``whisper.tokenizer.get_tokenizer``.

    Raises:
        ImportError: If the ``openai-whisper`` package is not installed.
    """
    try:
        from whisper.tokenizer import get_tokenizer
    except ImportError as err:
        # Without the package, get_tokenizer would be undefined below and the
        # call would fail with a confusing NameError; fail here with a clear hint.
        raise ImportError(
            "Notice: If you want to use whisper, please `pip install -U openai-whisper`"
        ) from err

    language = kwargs.get("language", None)
    task = kwargs.get("task", "transcribe")
    is_multilingual = kwargs.get("is_multilingual", True)
    num_languages = kwargs.get("num_languages", 99)
    tokenizer = get_tokenizer(
        multilingual=is_multilingual,
        num_languages=num_languages,
        language=language,
        task=task,
    )

    return tokenizer
function

SenseVoiceTokenizer(**kwargs)

funasr.tokenizer.whisper_tokenizer · View on GitHub ↗

Create a SenseVoice tokenizer via the bundled `whisper_lib` `get_tokenizer`.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
def SenseVoiceTokenizer(**kwargs):
    """Create a SenseVoice tokenizer.

    Keyword Args:
        language: Passed to ``get_tokenizer`` (default ``None``).
        task: Passed to ``get_tokenizer`` (default ``None``).
        is_multilingual: Passed as ``multilingual`` (default ``True``).
        num_languages: Passed to ``get_tokenizer`` (default ``8749``).
        vocab_path: Passed to ``get_tokenizer`` (default ``None``).

    Returns:
        The tokenizer returned by the SenseVoice ``get_tokenizer``.
    """
    from funasr.models.sense_voice.whisper_lib.tokenizer import get_tokenizer

    options = {
        "multilingual": kwargs.get("is_multilingual", True),
        "num_languages": kwargs.get("num_languages", 8749),
        "language": kwargs.get("language", None),
        "task": kwargs.get("task", None),
        "vocab_path": kwargs.get("vocab_path", None),
    }
    return get_tokenizer(**options)
class

WordTokenizer

funasr.tokenizer.word_tokenizer · View on GitHub ↗

No documentation yet.

📄 Source code
class WordTokenizer(AbsTokenizer):
    def __init__(
        self,
        delimiter: str = None,
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        remove_non_linguistic_symbols: bool = False,
    ):
        """Initialize WordTokenizer.
        
            Args:
                delimiter: TODO.
                non_linguistic_symbols: TODO.
                remove_non_linguistic_symbols: TODO.
            """
        self.delimiter = delimiter

        if not remove_non_linguistic_symbols and non_linguistic_symbols is not None:
            warnings.warn(
                "non_linguistic_symbols is only used " "when remove_non_linguistic_symbols = True"
            )

        if non_linguistic_symbols is None:
            self.non_linguistic_symbols = set()
        elif isinstance(non_linguistic_symbols, (Path, str)):
            non_linguistic_symbols = Path(non_linguistic_symbols)
            try:
                with non_linguistic_symbols.open("r", encoding="utf-8") as f:
                    self.non_linguistic_symbols = set(line.rstrip() for line in f)
            except FileNotFoundError:
                warnings.warn(f"{non_linguistic_symbols} doesn't exist.")
View full source on GitHub →

Methods

.text2tokens(line) L50

Text2tokens.

Args:

  • line — TODO.
📄 Source
    def text2tokens(self, line: str) -> List[str]:
        """Split ``line`` on ``self.delimiter`` (any whitespace when it is None).

        When ``self.remove_non_linguistic_symbols`` is true, pieces contained in
        ``self.non_linguistic_symbols`` are dropped.

        Args:
            line: Text to split.

        Returns:
            The remaining pieces, in order.
        """
        return [
            piece
            for piece in line.split(self.delimiter)
            if not (self.remove_non_linguistic_symbols and piece in self.non_linguistic_symbols)
        ]
.tokens2text(tokens) L63

Tokens2text.

Args:

  • tokens — TODO.
📄 Source
    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Join tokens with ``self.delimiter``, or a single space when it is None.

        Args:
            tokens: Tokens to join.

        Returns:
            The joined string.
        """
        joiner = " " if self.delimiter is None else self.delimiter
        return joiner.join(tokens)
method

WordTokenizer.text2tokens(line)

funasr.tokenizer.word_tokenizer.WordTokenizer · View on GitHub ↗

Text2tokens.

Args:

  • line — TODO.
📄 Source code
    def text2tokens(self, line: str) -> List[str]:
        """Split ``line`` on ``self.delimiter`` (any whitespace when it is None).

        When ``self.remove_non_linguistic_symbols`` is true, pieces contained in
        ``self.non_linguistic_symbols`` are dropped.

        Args:
            line: Text to split.

        Returns:
            The remaining pieces, in order.
        """
        return [
            piece
            for piece in line.split(self.delimiter)
            if not (self.remove_non_linguistic_symbols and piece in self.non_linguistic_symbols)
        ]
method

WordTokenizer.tokens2text(tokens)

funasr.tokenizer.word_tokenizer.WordTokenizer · View on GitHub ↗

Tokens2text.

Args:

  • tokens — TODO.
📄 Source code
    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Join tokens with ``self.delimiter``, or a single space when it is None.

        Args:
            tokens: Tokens to join.

        Returns:
            The joined string.
        """
        joiner = " " if self.delimiter is None else self.delimiter
        return joiner.join(tokens)
class

thread_wrapper

funasr.utils.compute_det_ctc · View on GitHub ↗

No documentation yet.

📄 Source code
class thread_wrapper(threading.Thread):
    """Thread that runs ``func(*args)`` and keeps its return value.

    ``result`` starts as an empty list and is replaced by the return value of
    ``func`` once :meth:`run` completes.
    """

    def __init__(self, func, args=()):
        """Create the thread without starting it.

        Args:
            func: Callable executed in the thread.
            args: Positional arguments passed to ``func``.
        """
        super().__init__()
        self.func, self.args = func, args
        self.result = []

    def run(self):
        """Execute ``func(*args)`` and store its return value in ``result``."""
        outcome = self.func(*self.args)
        self.result = outcome

    def get_result(self):
        """Return the stored result, or None if no ``result`` attribute exists."""
        return getattr(self, "result", None)

Methods

.run() L27

Run.

📄 Source
    def run(self):
        """Execute ``func(*args)`` and store its return value in ``result``."""
        outcome = self.func(*self.args)
        self.result = outcome
.get_result() L31

Get result.

📄 Source
    def get_result(self):
        """Return the stored result, or None if no ``result`` attribute exists."""
        return getattr(self, "result", None)
method

thread_wrapper.run()

funasr.utils.compute_det_ctc.thread_wrapper · View on GitHub ↗

Run.

📄 Source code
    def run(self):
        """Execute ``func(*args)`` and store its return value in ``result``."""
        outcome = self.func(*self.args)
        self.result = outcome
method

thread_wrapper.get_result()

funasr.utils.compute_det_ctc.thread_wrapper · View on GitHub ↗

Get result.

📄 Source code
    def get_result(self):
        """Return the stored result, or None if no ``result`` attribute exists."""
        return getattr(self, "result", None)
function

space_mixed_label(input_str)

funasr.utils.compute_det_ctc · View on GitHub ↗

Space mixed label.

Args:

  • input_str — TODO.
📄 Source code
def space_mixed_label(input_str):
    """Return the tokens of a mixed-script label separated by single spaces.

    Tokens come from ``split_mixed_label``; whitespace around the joined
    result is stripped.

    Args:
        input_str: Label text to tokenize.
    """
    pieces = split_mixed_label(input_str)
    joined = ' '.join(pieces)
    return joined.strip()
function

read_lists(list_file)

funasr.utils.compute_det_ctc · View on GitHub ↗

Read lists.

Args:

  • list_file — TODO.
📄 Source code
def read_lists(list_file):
    """Read a UTF-8 text file and return its non-blank lines, stripped.

    Args:
        list_file: Path of the file to read.

    Returns:
        A list of stripped lines, skipping lines that are empty after stripping.
    """
    with open(list_file, 'r', encoding='utf8') as fin:
        stripped = (raw.strip() for raw in fin)
        return [entry for entry in stripped if entry != '']
function

make_pair(wav_lists, trans_lists)

funasr.utils.compute_det_ctc · View on GitHub ↗

Make pair.

Args:

  • wav_lists — TODO.
  • trans_lists — TODO.
📄 Source code
def make_pair(wav_lists, trans_lists):
    """Make pair.
    
        Args:
            wav_lists: TODO.
            trans_lists: TODO.
        """
    logging.info('make pair for wav-trans list')

    trans_table = {}
    for line in trans_lists:
        arr = line.strip().replace('\t', ' ').split()
        if len(arr) < 2:
            logging.debug('invalid line in trans file: {}'.format(
                line.strip()))
            continue

        trans_table[arr[0]] = line.replace(arr[0],'').strip()

    lists = []
    for line in wav_lists:
        arr = line.strip().replace('\t', ' ').split()
        if len(arr) == 2 and arr[0] in trans_table:
            lists.append(
                dict(key=arr[0],
                     txt=trans_table[arr[0]],
                     wav=arr[1],
                     sample_rate=16000))
        else:
            logging.debug("can't find corresponding trans for key: {}".format(
View full source on GitHub →
function

count_duration(tid, data_lists)

funasr.utils.compute_det_ctc · View on GitHub ↗

Count duration.

Args:

  • tid — TODO.
  • data_lists — TODO.
📄 Source code
def count_duration(tid, data_lists):
    """Count duration.
    
        Args:
            tid: TODO.
            data_lists: TODO.
        """
    results = []

    for obj in data_lists:
        assert 'key' in obj
        assert 'wav' in obj
        assert 'txt' in obj
        key = obj['key']
        wav_file = obj['wav']
        txt = obj['txt']

        try:
            rate, waveform = kaldiio.load_mat(wav_file)
            waveform = torch.tensor(waveform, dtype=torch.float32)
            waveform = waveform.unsqueeze(0)
            frames = len(waveform[0])
            duration = frames / float(rate)
        except:
            logging.info(f'load file failed: {wav_file}')
            duration = 0.0

        obj['duration'] = duration
        results.append(obj)

View full source on GitHub →
function

load_data_and_score(keywords_list, data_file, trans_file, score_file)

funasr.utils.compute_det_ctc · View on GitHub ↗

Load data and score.

Args:

  • keywords_list — TODO.
  • data_file — TODO.
  • trans_file — TODO.
  • score_file — TODO.
📄 Source code
def load_data_and_score(keywords_list, data_file, trans_file, score_file):
    # score_table: {uttid: [keywordlist]}
    """Load data and score.
    
        Args:
            keywords_list: TODO.
            data_file: TODO.
            trans_file: TODO.
            score_file: TODO.
        """
    score_table = {}
    with open(score_file, 'r', encoding='utf8') as fin:
        # read score file and store in table
        for line in fin:
            arr = line.strip().split()
            key = arr[0]
            is_detected = arr[1]
            if is_detected == 'detected':
                if key not in score_table:
                    score_table.update(
                        {key: {
                            'kw': space_mixed_label(arr[2]),
                            'confi': float(arr[3])
                        }})
            else:
                if key not in score_table:
                    score_table.update({key: {'kw': 'unknown', 'confi': -1.0}})

    wav_lists = read_lists(data_file)
    trans_lists = read_lists(trans_file)
View full source on GitHub →
class

DatadirWriter

funasr.utils.datadir_writer · View on GitHub ↗

Writer class to create a Kaldi-like data directory.

Examples:

    >>> with DatadirWriter("output") as writer:
    ...     # output/sub.txt is created here
    ...     subwriter = writer["sub.txt"]
    ...     # Write "uttidA some/where/a.wav"
    ...     subwriter["uttidA"] = "some/where/a.wav"
    ...     subwriter["uttidB"] = "some/where/b.wav"

📄 Source code
class DatadirWriter:
    """Writer class to create kaldi like data directory.

    Examples:
        >>> with DatadirWriter("output") as writer:
        ...     # output/sub.txt is created here
        ...     subwriter = writer["sub.txt"]
        ...     # Write "uttidA some/where/a.wav"
        ...     subwriter["uttidA"] = "some/where/a.wav"
        ...     subwriter["uttidB"] = "some/where/b.wav"

    """

    def __init__(self, p: Union[Path, str]):
        """Initialize DatadirWriter.
        
            Args:
                p: TODO.
            """
        self.path = Path(p)
        self.chilidren = {}
        self.fd = None
        self.has_children = False
        self.keys = set()

    def __enter__(self):
        """Internal: enter  ."""
        return self

    def __getitem__(self, key: str) -> "DatadirWriter":
View full source on GitHub →

Methods

.close() L82

Close.

📄 Source
    def close(self):
        """Close this writer.

        A writer with children closes each child in insertion order and warns
        when two consecutive children recorded different key sets. A leaf
        writer closes its file handle, if one was opened.
        """
        if not self.has_children:
            if self.fd is not None:
                self.fd.close()
            return

        previous = None
        for child in self.chilidren.values():
            child.close()
            if previous is not None and previous.keys != child.keys:
                warnings.warn(
                    f"Ids are mismatching between " f"{previous.path} and {child.path}"
                )
            previous = child
method

DatadirWriter.close()

funasr.utils.datadir_writer.DatadirWriter · View on GitHub ↗

Close.

📄 Source code
    def close(self):
        """Close this writer.

        A writer with children closes each child in insertion order and warns
        when two consecutive children recorded different key sets. A leaf
        writer closes its file handle, if one was opened.
        """
        if not self.has_children:
            if self.fd is not None:
                self.fd.close()
            return

        previous = None
        for child in self.chilidren.values():
            child.close()
            if previous is not None and previous.keys != child.keys:
                warnings.warn(
                    f"Ids are mismatching between " f"{previous.path} and {child.path}"
                )
            previous = child
function

load_module_from_path(file_path)

funasr.utils.dynamic_import · View on GitHub ↗

Dynamically load a module from the given file path.

:param file_path: Absolute path of the module file.

:return: The loaded module.

📄 Source code
def load_module_from_path(file_path):
    """
    Dynamically load a module from the given file path.

    :param file_path: Absolute path of the module file.
    :return: The loaded module.
    :raises ImportError: If no import spec/loader can be created for the path
        (for example, the file does not have a recognised Python suffix).
    """
    base_name = file_path.split("/")[-1]
    # Strip only a trailing ".py"; str.replace would also delete ".py"
    # appearing elsewhere in the file name.
    module_name = base_name[: -len(".py")] if base_name.endswith(".py") else base_name
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module from path: {file_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
function

import_module_from_path(file_path)

funasr.utils.dynamic_import · View on GitHub ↗

Import module from path.

Args:

  • file_path — TODO.
📄 Source code
def import_module_from_path(file_path: str):
    """Import a Python file as a module by its base name.

    If ``file_path`` starts with ``"http"`` the file is downloaded first via
    ``download_from_url``. The file's directory (``"./"`` when empty) is added
    to ``sys.path`` once, and the module is imported by its base name. Import
    failures are printed rather than raised.

    Args:
        file_path: Local path or URL of a ``.py`` file.
    """
    if file_path.startswith("http"):
        from funasr.download.file import download_from_url

        file_path = download_from_url(file_path)

    file_dir = os.path.dirname(file_path)
    base_name = file_path.split("/")[-1]
    # Strip only a trailing ".py" (str.replace would remove it anywhere).
    module_name = base_name[: -len(".py")] if base_name.endswith(".py") else base_name
    if len(file_dir) < 1:
        file_dir = "./"
    # Avoid growing sys.path with a duplicate entry on every call.
    if file_dir not in sys.path:
        sys.path.append(file_dir)
    try:
        importlib.import_module(module_name)
        print(f"Loading remote code successfully: {file_path}")
    except Exception as e:
        print(f"Loading remote code failed: {file_path}, {e}")
function

export(model, data_in, quantize, opset_version, type, **kwargs)

funasr.utils.export_utils · View on GitHub ↗

Export.

Args:

  • model — Model instance or model name.
  • data_in — Input data (audio samples, file paths, or text).
  • quantize — TODO.
  • opset_version — TODO.
  • type — TODO.
  • **kwargs — Additional keyword arguments.
📄 Source code
def export(
    model, data_in=None, quantize: bool = False, opset_version: int = 14, type="onnx", **kwargs
):
    """Export.
    
        Args:
            model: Model instance or model name.
            data_in: Input data (audio samples, file paths, or text).
            quantize: TODO.
            opset_version: TODO.
            type: TODO.
            **kwargs: Additional keyword arguments.
        """
    model_scripts = model.export(**kwargs)
    export_dir = kwargs.get("output_dir", os.path.dirname(kwargs.get("init_param")))
    os.makedirs(export_dir, exist_ok=True)

    if not isinstance(model_scripts, (list, tuple)):
        model_scripts = (model_scripts,)
    for m in model_scripts:
        m.eval()
        if type == "onnx":
            _onnx(
                m,
                data_in=data_in,
                quantize=quantize,
                opset_version=opset_version,
                export_dir=export_dir,
                **kwargs,
            )
View full source on GitHub →
function

install_requirements(requirements_path)

funasr.utils.install_model_requirements · View on GitHub ↗

Install requirements.

Args:

  • requirements_path — TODO.
📄 Source code
def install_requirements(requirements_path):
    """Install the packages listed in a requirements file.

    If the first ``pip_install_r`` call raises, it is retried once. If the
    retry also raises, the error is printed and ``False`` is returned instead
    of propagating.

    Args:
        requirements_path: Path of the requirements file passed to ``pip install -r``.

    Returns:
        True if the installer exited with status 0, False otherwise.
    """
    try:
        result = pip_install_r(requirements_path)
    except Exception:
        # One retry, as before; a second failure is reported via the return value.
        try:
            result = pip_install_r(requirements_path)
        except Exception as err:
            print("fail to install model requirements! ")
            print("error", err)
            return False

    if result.returncode == 0:
        print("install model requirements successfully")
        return True
    print("fail to install model requirements! ")
    print("error", result.stderr)
    return False
function

pip_install_r(requirements_path)

funasr.utils.install_model_requirements · View on GitHub ↗

Pip install r.

Args:

  • requirements_path — TODO.
📄 Source code
def pip_install_r(requirements_path):
    """Run ``pip install -r <requirements_path>`` in a subprocess.

    A ``pip`` executable on PATH is preferred; otherwise ``uv pip`` is used.

    Args:
        requirements_path: Path to a pip requirements file.

    Returns:
        subprocess.CompletedProcess: The finished process, with ``stdout`` and
        ``stderr`` captured as text.

    Raises:
        RuntimeError: If neither ``pip`` nor ``uv`` is found on PATH.
    """
    if shutil.which("pip"):
        installer = ["pip"]
    elif shutil.which("uv"):
        installer = ["uv", "pip"]
    else:
        raise RuntimeError("pip not found, failed to install model requirements")
    return subprocess.run(
        [*installer, "install", "-r", requirements_path],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
function

split_mixed_label(input_str)

funasr.utils.kws_utils · View on GitHub ↗

Split mixed label.

Args:

  • input_str — TODO.
📄 Source code
def split_mixed_label(input_str):
    """Split a mixed Latin/CJK label into tokens.

    The input is lower-cased first. A run of ASCII letters and the characters
    ``!?,<>()'`` forms one token; any other character becomes a
    single-character token. Spaces around each later token are dropped, but a
    space at the very start of ``input_str`` is returned as a ``' '`` token.

    Args:
        input_str: Label text to split.

    Returns:
        list[str]: Tokens in order of appearance.
    """
    word_pattern = re.compile(r'[A-Za-z!?,<>()\']+')
    tokens = []
    rest = input_str.lower()
    while rest:
        m = word_pattern.match(rest)
        token = m.group(0) if m else rest[0]
        tokens.append(token)
        # token is always a prefix of rest, so slicing removes exactly it.
        rest = rest[len(token):].strip(' ')
    return tokens
function

query_token_set(txt, symbol_table, lexicon_table)

funasr.utils.kws_utils · View on GitHub ↗

Query token set.

Args:

  • txt — TODO.
  • symbol_table — TODO.
  • lexicon_table — TODO.
📄 Source code
def query_token_set(txt, symbol_table, lexicon_table):
    """Query token set.
    
        Args:
            txt: TODO.
            symbol_table: TODO.
            lexicon_table: TODO.
        """
    tokens_str = tuple()
    tokens_idx = tuple()

    if txt in symbol_table:
        tokens_str = tokens_str + (txt, )
        tokens_idx = tokens_idx + (symbol_table[txt], )
        return tokens_str, tokens_idx

    parts = split_mixed_label(txt)
    for part in parts:
        if part == '!sil' or part == '(sil)' or part == '<sil>':
            tokens_str = tokens_str + ('!sil', )
        elif part == '<blank>' or part == '<blank>':
            tokens_str = tokens_str + ('<blank>', )
        elif part == '(noise)' or part == 'noise)' or part == '(noise' or part == '<noise>':
            tokens_str = tokens_str + ('<unk>', )
        elif part in symbol_table:
            tokens_str = tokens_str + (part, )
        elif part in lexicon_table:
            for ch in lexicon_table[part]:
                tokens_str = tokens_str + (ch, )
        else:
View full source on GitHub →
class

KwsCtcPrefixDecoder

funasr.utils.kws_utils · View on GitHub ↗

Decoder interface wrapper for CTCPrefixDecode.

📄 Source code
class KwsCtcPrefixDecoder():
    """Decoder interface wrapper for CTCPrefixDecode."""

    def __init__(
        self,
        ctc: torch.nn.Module,
        keywords: str,
        token_list: list,
        seg_dict: dict,
    ):
        """Initialize class.

        Args:
            ctc (torch.nn.Module): The CTC implementation.
                For example, :class:`espnet.nets.pytorch_backend.ctc.CTC`

        """
        self.ctc = ctc
        self.token_list = token_list

        token_table = {}
        for token in token_list:
            token_table[token] = token_list.index(token)

        self.keywords_idxset = {0}
        self.keywords_token = {}
        self.keywords_str = keywords
        keywords_list = self.keywords_str.strip().replace(' ', '').split(',')
        for keyword in keywords_list:
            strs, indexs = query_token_set(keyword, token_table, seg_dict)
View full source on GitHub →

Methods

.beam_search(logits, logits_lengths, keywords_tokenset, score_beam_size, path_beam_size) L125

CTC prefix beam search inner implementation

Args:

  • logits (torch.Tensor) — (1, max_len, vocab_size)
  • logits_lengths (torch.Tensor) — (1, )
  • keywords_tokenset (set) — token set for filtering score
  • score_beam_size (int) — beam size for score
  • path_beam_size (int) — beam size for path

Returns:

List[List[int]]: nbest results

📄 Source
    def beam_search(
        self,
        logits: torch.Tensor,
        logits_lengths: torch.Tensor,
        keywords_tokenset: set = None,
        score_beam_size: int = 3,
        path_beam_size: int = 20,
    ) -> Tuple[List[List[int]], torch.Tensor]:
        """ CTC prefix beam search inner implementation

        Args:
            logits (torch.Tensor): (1, max_len, vocab_size)
            logits_lengths (torch.Tensor): (1, )
            keywords_tokenset (set): token set for filtering score
            score_beam_size (int): beam size for score
            path_beam_size (int): beam size for path

        Returns:
            List[List[int]]: nbest results
        """

        maxlen = logits.size(0)
        ctc_probs = logits
        cur_hyps = [(tuple(), (1.0, 0.0, []))]

        # CTC beam search step by step
        for t in range(0, maxlen):
            probs = ctc_probs[t]  # (vocab_size,)
            # key: prefix, value (pb, pnb), default value(-inf, -inf)
            next_hyps = defaultdict(lambda: (0.0, 0.0, []))
View full source →
.is_sublist(main_list, check_list) L232

Is sublist.

Args:

  • main_list — TODO.
  • check_list — TODO.
📄 Source
    def is_sublist(self, main_list, check_list):
        """Find where ``check_list`` first occurs contiguously inside ``main_list``.

        Args:
            main_list: Sequence to search in.
            check_list: Sequence to look for.

        Returns:
            int: Start index of the first contiguous occurrence of
            ``check_list`` in ``main_list``, or -1 if there is none. An empty
            ``check_list`` matches at index 0.
        """
        n, m = len(main_list), len(check_list)
        if m == 0:
            # Previously raised IndexError (check_list[0]) for a non-empty main_list.
            return 0
        for i in range(n - m + 1):
            # Element-wise comparison, so a tuple and a list with equal items
            # match; the old equal-length shortcut used ``==``, which is False
            # for tuple vs list.
            if all(main_list[i + j] == check_list[j] for j in range(m)):
                return i
        return -1
.decode(x) L295

Decode encoder output: apply the CTC softmax and pass the per-frame result to `_decode_inside`.

Args:

  • x (torch.Tensor) — The encoded feature tensor
  • Returns — decode result
📄 Source
    def decode(self, x: torch.Tensor):
        """Decode one utterance's encoder output.

        Applies ``self.ctc.softmax`` to ``x`` (with a temporary batch axis),
        then passes the per-frame result and its length to
        ``self._decode_inside``.

        Args:
            x (torch.Tensor): Encoder output for a single utterance; presumably
                (time, dim) -- TODO confirm against the caller.

        Returns:
            The result of ``self._decode_inside`` for these frames.
        """
        # Add a batch axis for the CTC head, then drop it and move to CPU.
        frame_probs = self.ctc.softmax(x.unsqueeze(0)).detach().squeeze(0).cpu()
        # frame_probs is (time, vocab): beam_search iterates over size(0) as time
        # and indexes each frame as a (vocab_size,) vector, so the sequence
        # length is size(0). The previous size(1) gave the vocabulary size.
        xlen = torch.tensor([frame_probs.size(0)])

        return self._decode_inside(frame_probs, xlen)
method

KwsCtcPrefixDecoder.beam_search(logits, logits_lengths, keywords_tokenset, score_beam_size, path_beam_size)

funasr.utils.kws_utils.KwsCtcPrefixDecoder · View on GitHub ↗

CTC prefix beam search inner implementation

Args:

  • logits (torch.Tensor) — (1, max_len, vocab_size)
  • logits_lengths (torch.Tensor) — (1, )
  • keywords_tokenset (set) — token set for filtering score
  • score_beam_size (int) — beam size for score
  • path_beam_size (int) — beam size for path

Returns:

List[List[int]]: nbest results

📄 Source code
    def beam_search(
        self,
        logits: torch.Tensor,
        logits_lengths: torch.Tensor,
        keywords_tokenset: set = None,
        score_beam_size: int = 3,
        path_beam_size: int = 20,
    ) -> Tuple[List[List[int]], torch.Tensor]:
        """ CTC prefix beam search inner implementation

        Args:
            logits (torch.Tensor): (1, max_len, vocab_size)
            logits_lengths (torch.Tensor): (1, )
            keywords_tokenset (set): token set for filtering score
            score_beam_size (int): beam size for score
            path_beam_size (int): beam size for path

        Returns:
            List[List[int]]: nbest results
        """

        maxlen = logits.size(0)
        ctc_probs = logits
        cur_hyps = [(tuple(), (1.0, 0.0, []))]

        # CTC beam search step by step
        for t in range(0, maxlen):
            probs = ctc_probs[t]  # (vocab_size,)
            # key: prefix, value (pb, pnb), default value(-inf, -inf)
            next_hyps = defaultdict(lambda: (0.0, 0.0, []))
View full source on GitHub →
method

KwsCtcPrefixDecoder.is_sublist(main_list, check_list)

funasr.utils.kws_utils.KwsCtcPrefixDecoder · View on GitHub ↗

Is sublist.

Args:

  • main_list — TODO.
  • check_list — TODO.
📄 Source code
    def is_sublist(self, main_list, check_list):
        """Find where ``check_list`` first occurs contiguously inside ``main_list``.

        Args:
            main_list: Sequence to search in.
            check_list: Sequence to look for.

        Returns:
            int: Start index of the first contiguous occurrence of
            ``check_list`` in ``main_list``, or -1 if there is none. An empty
            ``check_list`` matches at index 0.
        """
        n, m = len(main_list), len(check_list)
        if m == 0:
            # Previously raised IndexError (check_list[0]) for a non-empty main_list.
            return 0
        for i in range(n - m + 1):
            # Element-wise comparison, so a tuple and a list with equal items
            # match; the old equal-length shortcut used ``==``, which is False
            # for tuple vs list.
            if all(main_list[i + j] == check_list[j] for j in range(m)):
                return i
        return -1
method

KwsCtcPrefixDecoder.decode(x)

funasr.utils.kws_utils.KwsCtcPrefixDecoder · View on GitHub ↗

Decode encoder output: apply the CTC softmax and pass the per-frame result to `_decode_inside`.

Args:

  • x (torch.Tensor) — The encoded feature tensor
  • Returns — decode result
📄 Source code
    def decode(self, x: torch.Tensor):
        """Decode one utterance's encoder output.

        Applies ``self.ctc.softmax`` to ``x`` (with a temporary batch axis),
        then passes the per-frame result and its length to
        ``self._decode_inside``.

        Args:
            x (torch.Tensor): Encoder output for a single utterance; presumably
                (time, dim) -- TODO confirm against the caller.

        Returns:
            The result of ``self._decode_inside`` for these frames.
        """
        # Add a batch axis for the CTC head, then drop it and move to CPU.
        frame_probs = self.ctc.softmax(x.unsqueeze(0)).detach().squeeze(0).cpu()
        # frame_probs is (time, vocab): beam_search iterates over size(0) as time
        # and indexes each frame as a (vocab_size,) vector, so the sequence
        # length is size(0). The previous size(1) gave the vocabulary size.
        xlen = torch.tensor([frame_probs.size(0)])

        return self._decode_inside(frame_probs, xlen)
function

is_ffmpeg_installed()

funasr.utils.load_utils · View on GitHub ↗

Is ffmpeg installed.

📄 Source code
def is_ffmpeg_installed():
    """Return True if an ``ffmpeg`` executable runs and reports its version.

    Runs ``ffmpeg -version`` and looks for the ``"ffmpeg version"`` banner in
    its combined stdout/stderr.

    Returns:
        bool: True if the banner is found; False if the command exits
        non-zero, cannot be launched, or prints something else.
    """
    try:
        output = subprocess.check_output(["ffmpeg", "-version"], stderr=subprocess.STDOUT)
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError (not on PATH) as well as
        # PermissionError and other launch failures, which used to propagate.
        return False
    # errors="replace": a stray non-UTF-8 byte in the banner must not raise.
    return "ffmpeg version" in output.decode("utf-8", errors="replace")
function

load_audio_text_image_video(data_or_path_or_list, fs, audio_fs, data_type, tokenizer, **kwargs)

funasr.utils.load_utils · View on GitHub ↗

Load audio/text/image/video data from various input formats.

Args:

  • data_or_path_or_list — File path, URL, numpy array, torch Tensor, bytes, or list.
  • fs (int) — Target sample rate (default 16000).
  • audio_fs (int) — Source audio sample rate.
  • data_type (str) — Input type ("sound", "text", "fbank").

Returns:

torch.Tensor or list: Loaded and resampled audio tensor(s).

📄 Source code
def load_audio_text_image_video(
    data_or_path_or_list,
    fs: int = 16000,
    audio_fs: int = 16000,
    data_type="sound",
    tokenizer=None,
    **kwargs,
):
    """Load audio/text/image/video data from various input formats.

    Args:
        data_or_path_or_list: File path, URL, numpy array, torch Tensor, bytes, or list.
        fs (int): Target sample rate (default 16000).
        audio_fs (int): Source audio sample rate.
        data_type (str): Input type ("sound", "text", "fbank").

    Returns:
        torch.Tensor or list: Loaded and resampled audio tensor(s).
    """
    if isinstance(data_or_path_or_list, (list, tuple)):
        if data_type is not None and isinstance(data_type, (list, tuple)):
            data_types = [data_type] * len(data_or_path_or_list)
            data_or_path_or_list_ret = [[] for d in data_type]
            for i, (data_type_i, data_or_path_or_list_i) in enumerate(
                zip(data_types, data_or_path_or_list)
            ):
                for j, (data_type_j, data_or_path_or_list_j) in enumerate(
                    zip(data_type_i, data_or_path_or_list_i)
                ):
                    data_or_path_or_list_j = load_audio_text_image_video(
View full source on GitHub →
function

load_bytes(input)

funasr.utils.load_utils · View on GitHub ↗

Convert audio bytes to numpy array.

Args:

  • input (bytes) — Raw audio bytes.

Returns:

  • numpy.ndarray — Decoded audio samples.
📄 Source code
def load_bytes(input):
    """Convert audio bytes to numpy array.

    Args:
        input (bytes): Raw audio bytes.

    Returns:
        numpy.ndarray: Decoded audio samples.
    """
    # Only run the (expensive) frame-rate validation when the payload is an
    # actual audio container (WAV, MP3, OGG, …).  Raw PCM buffers have no
    # recognisable header and would cause pydub to spend ~200 ms before
    # raising an exception that is then silently swallowed anyway.
    if _is_audio_container(input):
        try:
            input = validate_frame_rate(input)
        except Exception:
            pass
    middle_data = np.frombuffer(input, dtype=np.int16)
    middle_data = np.asarray(middle_data)
    if middle_data.dtype.kind not in "iu":
        raise TypeError("'middle_data' must be an array of integers")
    dtype = np.dtype("float32")
    if dtype.kind != "f":
        raise TypeError("'dtype' must be a floating point type")

    i = np.iinfo(middle_data.dtype)
    abs_max = 2 ** (i.bits - 1)
    offset = i.min + abs_max
    array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
View full source on GitHub →
function

validate_frame_rate(input, fs)

funasr.utils.load_utils · View on GitHub ↗

Validate frame rate.

Args:

  • input — Input audio/text data.
  • fs — TODO.
📄 Source code
def validate_frame_rate(
    input,
    fs: int = 16000,
):

    # 将文件读取为字节流
    """Validate frame rate.
    
        Args:
            input: Input audio/text data.
            fs: TODO.
        """
    byte_data = BytesIO(input)

    # 使用 pydub 加载音频
    try:
        audio = AudioSegment.from_file(byte_data)
    except:
        raise RuntimeError(
            "You are decoding the pcm data, please install pydub first. via `pip install pydub`."
        )

    # 确保采样率为 16000 Hz
    if audio.frame_rate != fs:
        audio = audio.set_frame_rate(fs)

        # 将重新采样后的音频导出为字节流
        output = BytesIO()
        audio.export(output, format="wav")
        output.seek(0)
View full source on GitHub →
function

extract_fbank(data, data_len, data_type, frontend, **kwargs)

funasr.utils.load_utils · View on GitHub ↗

Extract filter-bank features from audio data.

Args:

  • data — Audio samples (list of numpy arrays or tensors).
  • data_len — Lengths of each sample.
  • data_type (str) — Input type ("sound", "fbank").
  • frontend — Frontend instance for feature extraction.

Returns:

  • tuple — (features_tensor, feature_lengths, feature_times)
📄 Source code
def extract_fbank(data, data_len=None, data_type: str = "sound", frontend=None, **kwargs):
    """Extract filter-bank features from audio data.

    Args:
        data: Audio samples (list of numpy arrays or tensors).
        data_len: Lengths of each sample.
        data_type (str): Input type ("sound", "fbank").
        frontend: Frontend instance for feature extraction.

    Returns:
        tuple: (features_tensor, feature_lengths, feature_times)
    """
    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data)
        if len(data.shape) < 2:
            data = data[None, :]  # data: [batch, N]
        elif data.shape[0] > 1:
            data = data.mean(dim=0, keepdim=True)  # convert stereo/multi-channel to mono
        data_len = [data.shape[1]] if data_len is None else data_len
    elif isinstance(data, torch.Tensor):
        if len(data.shape) < 2:
            data = data[None, :]  # data: [batch, N]
        elif data.shape[0] > 1:
            data = data.mean(dim=0, keepdim=True)  # convert stereo/multi-channel to mono
        data_len = [data.shape[1]] if data_len is None else data_len
    elif isinstance(data, (list, tuple)):
        data_list, data_len = [], []
        for data_i in data:
            if isinstance(data_i, np.ndarray):
                data_i = torch.from_numpy(data_i)
View full source on GitHub →
function

statistic_model_parameters(model, prefix)

funasr.utils.misc · View on GitHub ↗

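Count the parameter elements in a model's `state_dict`. Entries named `num_batches_tracked` are skipped, and counting can be restricted to keys that start with `prefix`.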

Args:

  • model — Model instance or model name.
  • prefix — TODO.
📄 Source code
def statistic_model_parameters(model, prefix=None):
    """Count the scalar elements stored in ``model.state_dict()``.

    Entries whose key contains ``num_batches_tracked`` are skipped.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        prefix: If given, only keys starting with this string are counted.

    Returns:
        int: Total ``numel()`` over the selected entries.
    """
    state = model.state_dict()
    total = 0
    for name, tensor in state.items():
        if "num_batches_tracked" in name:
            continue
        if prefix is not None and not name.startswith(prefix):
            continue
        total += tensor.numel()
    return total
function

int2vec(x, vec_dim, dtype)

funasr.utils.misc · View on GitHub ↗

Convert an integer into a binary vector of at least `vec_dim` bits, least-significant bit first.

Args:

  • x — TODO.
  • vec_dim — Size/dimension parameter.
  • dtype — TODO.
📄 Source code
def int2vec(x, vec_dim=8, dtype=np.int32):
    """Convert an integer into a binary vector, least-significant bit first.

    Args:
        x: Integer to encode.
        vec_dim: Minimum number of bits; the binary string is zero-padded to
            this width (larger values produce longer vectors).
        dtype: NumPy dtype of the returned array.

    Returns:
        numpy.ndarray: 1 where a bit is set, 0 otherwise.
    """
    bits = format(x, "0{}b".format(vec_dim))
    # Reverse so index 0 holds the least-significant bit.
    return (np.array(list(reversed(bits))) == "1").astype(dtype)
function

seq2arr(seq, vec_dim)

funasr.utils.misc · View on GitHub ↗

Stack the binary vectors of a sequence of integers into a 2-D array, one row per element.

Args:

  • seq — TODO.
  • vec_dim — Size/dimension parameter.
📄 Source code
def seq2arr(seq, vec_dim=8):
    """Stack the binary encodings of a sequence of integers into a 2-D array.

    Each element is converted with ``int()`` and encoded by ``int2vec``
    (least-significant bit first), giving one row per element. All values
    must fit in ``vec_dim`` bits; otherwise the rows differ in length and
    stacking fails.

    Args:
        seq: Iterable of values convertible with ``int()``.
        vec_dim: Number of bits per row.

    Returns:
        numpy.ndarray: Array of shape ``(len(seq), vec_dim)``.
    """
    # np.vstack replaces np.row_stack, a deprecated alias since NumPy 2.0.
    return np.vstack([int2vec(int(x), vec_dim) for x in seq])
function

load_scp_as_dict(scp_path, value_type, kv_sep)

funasr.utils.misc · View on GitHub ↗

Load scp as dict.

Args:

  • scp_path — TODO.
  • value_type — TODO.
  • kv_sep — TODO.
📄 Source code
def load_scp_as_dict(scp_path, value_type="str", kv_sep=" "):
    """Load a ``key<kv_sep>value`` text file into an ordered dict.

    Each line is stripped and split at the first occurrence of ``kv_sep``.
    Blank lines are skipped. A line without ``kv_sep`` is stored as a key with
    an empty value.

    Args:
        scp_path: Path to a UTF-8 text file.
        value_type: ``"list"`` to split each value on single spaces; any other
            value keeps it as a string.
        kv_sep: Separator between key and value (may be several characters).

    Returns:
        OrderedDict: Keys in file order; a repeated key keeps its last value.
    """
    ret_dict = OrderedDict()
    with io.open(scp_path, "r", encoding="utf-8") as f:
        for one_line in f:
            one_line = one_line.strip()
            if not one_line:
                # Blank lines used to be stored under the key "".
                continue
            # partition(): a line without kv_sep becomes (line, "") instead of
            # the old find() == -1 path, which chopped the key's last character
            # and used the whole line as the value. Multi-character separators
            # are now removed completely.
            key, _, value = one_line.partition(kv_sep)
            if value_type == "list":
                value = value.split(" ")
            ret_dict[key] = value
    return ret_dict
function

load_scp_as_list(scp_path, value_type, kv_sep)

funasr.utils.misc · View on GitHub ↗

Load scp as list.

Args:

  • scp_path — TODO.
  • value_type — TODO.
  • kv_sep — TODO.
📄 Source code
def load_scp_as_list(scp_path, value_type="str", kv_sep=" "):
    """Load a ``key<kv_sep>value`` text file into a list of ``(key, value)`` pairs.

    Each line is stripped and split at the first occurrence of ``kv_sep``.
    Blank lines are skipped. A line without ``kv_sep`` yields a key with an
    empty value. Duplicate keys are all kept, in file order.

    Args:
        scp_path: Path to a UTF-8 text file.
        value_type: ``"list"`` to split each value on single spaces; any other
            value keeps it as a string.
        kv_sep: Separator between key and value (may be several characters).

    Returns:
        list[tuple]: ``(key, value)`` pairs in file order.
    """
    ret_list = []
    with io.open(scp_path, "r", encoding="utf8") as f:
        for one_line in f:
            one_line = one_line.strip()
            if not one_line:
                # Blank lines used to produce a ("", "") entry.
                continue
            # partition(): a line without kv_sep becomes (line, "") instead of
            # the old find() == -1 path, which chopped the key's last character
            # and used the whole line as the value. Multi-character separators
            # are now removed completely.
            key, _, value = one_line.partition(kv_sep)
            if value_type == "list":
                value = value.split(" ")
            ret_list.append((key, value))
    return ret_list
function

deep_update(original, update)

funasr.utils.misc · View on GitHub ↗

Recursively merge update dict into original dict (in-place).

For nested dicts, merges recursively. For other types, overwrites.

Args:

  • original (dict) — Target dict to be updated in-place.
  • update (dict) — Source dict with new values.
📄 Source code
def deep_update(original, update):
    """Recursively merge ``update`` into ``original`` in place.

    For every ``key, value`` in ``update``:

    * if ``value`` is a non-empty ``dict`` and ``original[key]`` exists and is
      itself a mutable mapping, the two are merged recursively;
    * otherwise ``original[key]`` is replaced by ``value``. This includes an
      empty ``dict``, which acts as an explicit reset, and a non-mapping
      ``original[key]`` such as None.

    Args:
        original: Mapping updated in place.
        update (dict): Source of new values.

    Returns:
        None.
    """
    from collections.abc import MutableMapping

    for key, value in update.items():
        if (
            isinstance(value, dict)
            and value
            and key in original
            and isinstance(original[key], MutableMapping)
        ):
            deep_update(original[key], value)
        else:
            # Also covers original[key] being None or a scalar, which previously
            # crashed when the function recursed into a non-mapping.
            original[key] = value
function

prepare_model_dir(**kwargs)

funasr.utils.misc · View on GitHub ↗

Prepare model dir.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
def prepare_model_dir(**kwargs):
    """Create the output directory and record the run configuration in it.

    The steps are:

    * save ``kwargs`` to ``<output_dir>/config.yaml`` via ``OmegaConf.save``;
    * log the kwargs and the saved path;
    * if ``model_path`` is given and contains ``configuration.json``, copy that
      file into ``output_dir``.

    Args:
        **kwargs: Run configuration. ``output_dir`` (default ``"./"``) and
            ``model_path`` (optional) are read here.
    """
    out_dir = kwargs.get("output_dir", "./")
    os.makedirs(out_dir, exist_ok=True)

    yaml_file = os.path.join(out_dir, "config.yaml")
    OmegaConf.save(config=kwargs, f=yaml_file)
    logging.info(f"kwargs: {kwargs}")
    logging.info("config.yaml is saved to: %s", yaml_file)

    model_path = kwargs.get("model_path", None)
    if model_path is None:
        return
    src_json = os.path.join(model_path, "configuration.json")
    if os.path.exists(src_json):
        shutil.copy(src_json, os.path.join(out_dir, "configuration.json"))
function

extract_filename_without_extension(file_path)

funasr.utils.misc · View on GitHub ↗

从给定的文件路径中提取文件名(不包含路径和扩展名)

:param file_path: 完整的文件路径

:return: 文件名(不含路径和扩展名)

📄 Source code
def extract_filename_without_extension(file_path):
    """
    Return the last path component of a file path with its final extension removed.
    :param file_path: full file path
    :return: file name without directory and without the last extension
    """
    # basename() drops the directory part; splitext() splits off the last extension.
    stem, _ext = os.path.splitext(os.path.basename(file_path))
    return stem
function

smart_remove(path)

funasr.utils.misc · View on GitHub ↗

Intelligently removes files, empty directories, and non-empty directories recursively.

📄 Source code
def smart_remove(path):
    """Delete ``path`` whether it is a file, an empty directory, or a directory tree.

    Prints a message describing what was done. A missing path is reported
    and otherwise ignored. An existing path that is neither a file nor a
    directory is left untouched.
    """
    if not os.path.exists(path):
        print(f"{path} does not exist.")
        return

    if os.path.isfile(path):
        os.remove(path)
        print(f"File {path} has been deleted.")
        return

    if not os.path.isdir(path):
        return

    try:
        # rmdir only succeeds on an empty directory.
        os.rmdir(path)
        print(f"Empty directory {path} has been deleted.")
    except OSError:
        # Not empty: remove the whole tree.
        shutil.rmtree(path)
        print(f"Non-empty directory {path} has been recursively deleted.")
function

isChinese(ch)

funasr.utils.postprocess_utils · View on GitHub ↗

Check whether a character is a CJK unified ideograph (U+4E00–U+9FFF), an ASCII digit, or "@".

Args:

  • ch — TODO.
📄 Source code
def isChinese(ch: str):
    """Return True if ``ch`` is a CJK unified ideograph, an ASCII digit, or ``"@"``.

    The checks are plain string comparisons against U+4E00..U+9FFF and
    U+0030..U+0039. For multi-character input they compare lexicographically.

    Args:
        ch: String to test (normally a single character).
    """
    return ("\u4e00" <= ch <= "\u9fff") or ("\u0030" <= ch <= "\u0039") or ch == "@"
function

isAllChinese(word)

funasr.utils.postprocess_utils · View on GitHub ↗

Check whether every token, after removing spaces and special tokens, passes `isChinese`.

Args:

  • word — TODO.
📄 Source code
def isAllChinese(word: Union[List[Any], str]):
    """Return True if every token of ``word`` passes ``isChinese`` after cleanup.

    Each item of ``word`` (a list of tokens, or the characters of a string)
    has spaces and the markers ``</s>``, ``<s>``, ``<unk>`` and ``<OOV>``
    removed, in that order, before the check. An empty ``word`` gives False.

    Args:
        word: Token list or string to test.
    """
    strip_marks = (" ", "</s>", "<s>", "<unk>", "<OOV>")
    cleaned = []
    for token in word:
        for mark in strip_marks:
            token = token.replace(mark, "")
        cleaned.append(token)

    if not cleaned:
        return False
    return all(isChinese(token) for token in cleaned)
function

isAllAlpha(word)

funasr.utils.postprocess_utils · View on GitHub ↗

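Check whether every token, after removing spaces and special tokens, is non-Chinese alphabetic text or an apostrophe.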

Args:

  • word — TODO.
📄 Source code
def isAllAlpha(word: Union[List[Any], str]):
    """Return True if every token of ``word`` is non-Chinese alphabetic text or ``"'"``.

    Each item of ``word`` (a list of tokens, or the characters of a string)
    has spaces and the markers ``</s>``, ``<s>``, ``<unk>`` and ``<OOV>``
    removed, in that order, before the check. An empty ``word`` gives False.

    Args:
        word: Token list or string to test.
    """
    strip_marks = (" ", "</s>", "<s>", "<unk>", "<OOV>")
    cleaned = []
    for token in word:
        for mark in strip_marks:
            token = token.replace(mark, "")
        cleaned.append(token)

    if not cleaned:
        return False

    for token in cleaned:
        if token.isalpha():
            # str.isalpha() is also True for CJK characters; reject those.
            if isChinese(token):
                return False
        elif token != "'":
            return False
    return True
function

abbr_dispose(words, time_stamp)

funasr.utils.postprocess_utils · View on GitHub ↗

Abbr dispose.

Args:

  • words — TODO.
  • time_stamp — TODO.
📄 Source code
def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]:
    """Abbr dispose.
    
        Args:
            words: TODO.
            time_stamp: TODO.
        """
    words_size = len(words)
    word_lists = []
    abbr_begin = []
    abbr_end = []
    last_num = -1
    ts_lists = []
    ts_nums = []
    ts_index = 0
    for num in range(words_size):
        if num <= last_num:
            continue

        if len(words[num]) == 1 and words[num].encode("utf-8").isalpha():
            if (
                num + 1 < words_size
                and words[num + 1] == " "
                and num + 2 < words_size
                and len(words[num + 2]) == 1
                and words[num + 2].encode("utf-8").isalpha()
            ):
                # found the begin of abbr
                abbr_begin.append(num)
                num += 2
View full source on GitHub →
function

sentence_postprocess(words, time_stamp)

funasr.utils.postprocess_utils · View on GitHub ↗

Sentence postprocess.

Args:

  • words — TODO.
  • time_stamp — TODO.
📄 Source code
def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
    """Sentence postprocess.
    
        Args:
            words: TODO.
            time_stamp: TODO.
        """
    middle_lists = []
    word_lists = []
    word_item = ""
    ts_lists = []

    # wash words lists
    for i in words:
        word = ""
        if isinstance(i, str):
            word = i
        else:
            word = i.decode("utf-8")

        if word in ["<s>", "</s>", "<unk>", "<OOV>"]:
            continue
        else:
            middle_lists.append(word)

    # all chinese characters
    if isAllChinese(middle_lists):
        for i, ch in enumerate(middle_lists):
            word_lists.append(ch.replace(" ", ""))
        if time_stamp is not None:
View full source on GitHub →
function

sentence_postprocess_sentencepiece(words)

funasr.utils.postprocess_utils · View on GitHub ↗

Sentence postprocess sentencepiece.

Args:

  • words — TODO.
📄 Source code
def sentence_postprocess_sentencepiece(words):
    """Sentence postprocess sentencepiece.
    
        Args:
            words: TODO.
        """
    middle_lists = []
    word_lists = []
    word_item = ""

    # wash words lists
    for i in words:
        word = ""
        if isinstance(i, str):
            word = i
        else:
            word = i.decode("utf-8")

        if word in ["<s>", "</s>", "<unk>", "<OOV>"]:
            continue
        else:
            middle_lists.append(word)

    # all alpha characters
    for i, ch in enumerate(middle_lists):
        word = ""
        if "\u2581" in ch and i == 0:
            word_item = ""
            word = ch.replace("\u2581", "")
            word_item += word
View full source on GitHub →
function

format_str_v2(s)

funasr.utils.postprocess_utils · View on GitHub ↗

Format str v2.

Args:

  • s — TODO.
📄 Source code
def format_str_v2(s):
    """Move special tokens in a transcript to their canonical positions.

    The steps are:

    * count and remove every token listed in ``emoji_dict``;
    * prepend the replacement of every ``event_dict`` token that was present;
    * append the replacement of the most frequent ``emo_dict`` token, with
      ``"<|NEUTRAL|>"`` as the starting choice;
    * remove spaces next to any symbol in ``emo_set`` or ``event_set``.

    Args:
        s: Transcript text containing special tokens.

    Returns:
        str: The rewritten, stripped text.
    """
    counts = {}
    # Count each tag right before removing it (order matters for overlapping tags).
    for tag in emoji_dict:
        counts[tag] = s.count(tag)
        s = s.replace(tag, "")

    best_emo = "<|NEUTRAL|>"
    for candidate in emo_dict:
        if counts[candidate] > counts[best_emo]:
            best_emo = candidate

    for evt in event_dict:
        if counts[evt] > 0:
            s = event_dict[evt] + s
    s += emo_dict[best_emo]

    for symbol in emo_set.union(event_set):
        s = s.replace(" " + symbol, symbol).replace(symbol + " ", symbol)
    return s.strip()
function

rich_transcription_postprocess(s)

funasr.utils.postprocess_utils · View on GitHub ↗

Rich transcription postprocess.

Args:

  • s — TODO.
📄 Source code
def rich_transcription_postprocess(s):
    """Rich transcription postprocess.
    
        Args:
            s: TODO.
        """
    def get_emo(s):
        """Get emo.
        
            Args:
                s: TODO.
            """
        return s[-1] if s[-1] in emo_set else None

    def get_event(s):
        """Get event.
        
            Args:
                s: TODO.
            """
        return s[0] if s[0] in event_set else None

    s = s.replace("<|nospeech|><|Event_UNK|>", "❓")
    for lang in lang_dict:
        s = s.replace(lang, "<|lang|>")
    s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]
    new_s = " " + s_list[0]
    cur_ent_event = get_event(new_s)
    for i in range(1, len(s_list)):
        if len(s_list[i]) == 0:
View full source on GitHub →
function

check_audio_list(audio)

funasr.utils.speaker_utils · View on GitHub ↗

Check audio list.

Args:

  • audio — TODO.
📄 Source code
def check_audio_list(audio: list, fs: int = 16000):
    """Validate a list of timestamped audio segments and return their total duration.

    Args:
        audio: Sequence of ``(start_sec, end_sec, samples)`` entries, where
            ``samples`` is a ``numpy.ndarray``. Each segment must start no
            earlier than the previous one ends.
        fs: Sample rate used to convert timestamps to sample counts. The
            default 16000 is the value that was previously hard-coded.

    Returns:
        Sum of ``end_sec - start_sec`` over all segments.

    Raises:
        AssertionError: If any of these hold:

            * a segment ends before it starts;
            * its data is not an ndarray;
            * its sample count disagrees with its timestamps;
            * it starts before the previous segment ends.
    """
    audio_dur = 0
    prev_end = None
    for seg in audio:
        start, end, samples = seg[0], seg[1], seg[2]
        assert end >= start, "modelscope error: Wrong time stamps."
        assert isinstance(samples, np.ndarray), "modelscope error: Wrong data type."
        assert (
            int(end * fs) - int(start * fs) == samples.shape[0]
        ), "modelscope error: audio data in list is inconsistent with time length."
        if prev_end is not None:
            assert start >= prev_end, "modelscope error: Wrong time stamps."
        prev_end = end
        audio_dur += end - start
    return audio_dur
function

sv_preprocess(inputs)

funasr.utils.speaker_utils · View on GitHub ↗

Sv preprocess.

Args:

  • inputs — TODO.
📄 Source code
def sv_preprocess(inputs: Union[np.ndarray, list]):
    """Convert speaker-verification inputs into 1-D float32 torch tensors.

    Each element of ``inputs`` may be one of:

    * a ``str`` path: read with ``File.read`` and decoded with ``sf.load``;
      if the decoded data is 2-D, only the first column is kept;
    * a 1-D ``numpy.ndarray``: int16/int32/int64 arrays are divided by 2**15
      and cast to float32; any other dtype is cast to float32.

    Args:
        inputs: Sequence of paths and/or 1-D arrays.

    Returns:
        list[torch.Tensor]: One tensor per input, in input order.

    Raises:
        AssertionError: If an ndarray input is not 1-D.
        ValueError: For any other input type.
    """

    def _from_path(path):
        # Read the raw bytes, decode, and keep the first channel of 2-D data.
        raw = File.read(path)
        wav, _sr = sf.load(io.BytesIO(raw), dtype="float32")
        if len(wav.shape) == 2:
            wav = wav[:, 0]
        return torch.from_numpy(wav)

    def _from_array(arr):
        # Scale integer PCM by 2**15; cast everything else to float32.
        assert len(arr.shape) == 1, "modelscope error: Input array should be [N, T]"
        if arr.dtype in ["int16", "int32", "int64"]:
            arr = (arr / (1 << 15)).astype("float32")
        else:
            arr = arr.astype("float32")
        return torch.from_numpy(arr)

    output = []
    for item in inputs:
        if isinstance(item, str):
            output.append(_from_path(item))
        elif isinstance(item, np.ndarray):
            output.append(_from_array(item))
        else:
            raise ValueError(
                "modelscope error: The input type is restricted to audio address and nump array."
            )
    return output
function

sv_chunk(vad_segments, fs)

funasr.utils.speaker_utils · View on GitHub ↗

Sv chunk.

Args:

  • vad_segments — TODO.
  • fs — TODO.
📄 Source code
def sv_chunk(vad_segments: list, fs=16000) -> list:
    """Sv chunk.
    
        Args:
            vad_segments: TODO.
            fs: TODO.
        """
    config = {
        "seg_dur": 1.5,
        "seg_shift": 0.75,
    }

    def seg_chunk(seg_data):
        """Seg chunk.
        
            Args:
                seg_data: TODO.
            """
        seg_st = seg_data[0]
        data = seg_data[2]
        chunk_len = int(config["seg_dur"] * fs)
        chunk_shift = int(config["seg_shift"] * fs)
        last_chunk_ed = 0
        seg_res = []
        for chunk_st in range(0, data.shape[0], chunk_shift):
            chunk_ed = min(chunk_st + chunk_len, data.shape[0])
            if chunk_ed <= last_chunk_ed:
                break
            last_chunk_ed = chunk_ed
            chunk_st = max(0, chunk_ed - chunk_len)
View full source on GitHub →
function

extract_feature(audio)

funasr.utils.speaker_utils · View on GitHub ↗

Extract feature.

Args:

  • audio — TODO.
📄 Source code
def extract_feature(audio):
    """Compute mean-normalized 80-bin Kaldi fbank features for a batch of waveforms.

    Args:
        audio: Iterable of 1-D waveform tensors. Each one is given a leading
            channel dimension before ``Kaldi.fbank`` is called. Presumably
            ``Kaldi`` is ``torchaudio.compliance.kaldi`` — TODO confirm.

    Returns:
        Tensor with one leading batch entry per waveform. ``torch.cat``
        requires every waveform to produce the same number of frames.
    """
    batch = []
    for waveform in audio:
        fbank = Kaldi.fbank(waveform.unsqueeze(0), num_mel_bins=80)
        # Per-utterance mean normalization along dim 0 of the fbank output.
        centered = fbank - fbank.mean(dim=0, keepdim=True)
        batch.append(centered.unsqueeze(0))
    return torch.cat(batch)
function

postprocess(segments, vad_segments, labels, embeddings)

funasr.utils.speaker_utils · View on GitHub ↗

Postprocess.

Args:

  • segments — TODO.
  • vad_segments — TODO.
  • labels — TODO.
  • embeddings — TODO.
📄 Source code
def postprocess(
    segments: list, vad_segments: list, labels: np.ndarray, embeddings: np.ndarray
) -> list:
    """Postprocess.
    
        Args:
            segments: TODO.
            vad_segments: TODO.
            labels: TODO.
            embeddings: TODO.
        """
    assert len(segments) == len(labels)
    labels = correct_labels(labels)
    distribute_res = []
    for i in range(len(segments)):
        distribute_res.append([segments[i][0], segments[i][1], labels[i]])
    # merge the same speakers chronologically
    distribute_res = merge_seque(distribute_res)

    # accquire speaker center
    spk_embs = []
    for i in range(labels.max() + 1):
        spk_emb = embeddings[labels == i].mean(0)
        spk_embs.append(spk_emb)
    spk_embs = np.stack(spk_embs)

    def is_overlapped(t1, t2):
        """Is overlapped.
        
            Args:
View full source on GitHub →
function

correct_labels(labels)

funasr.utils.speaker_utils · View on GitHub ↗

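Relabel speaker IDs as consecutive integers 0, 1, 2, … in order of first appearance.

Args:

  • labels — Iterable of hashable labels (e.g. cluster IDs).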
📄 Source code
def correct_labels(labels):
    """Renumber labels as 0, 1, 2, ... in order of first appearance.

    Args:
        labels: Iterable of hashable labels (e.g. a 1-D array of cluster IDs).

    Returns:
        np.ndarray of the same length as ``labels``. The first distinct label
        maps to 0, the next previously unseen label maps to 1, and so on.
    """
    first_seen = {}
    # setdefault evaluates len(first_seen) before inserting, so a new label
    # receives the next unused index.
    remapped = [first_seen.setdefault(label, len(first_seen)) for label in labels]
    return np.array(remapped)
function

merge_seque(distribute_res)

funasr.utils.speaker_utils · View on GitHub ↗

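Merge consecutive segments that share a speaker label and touch or overlap in time.

Args:

  • distribute_res — List of [start, end, label] segments, expected in chronological order.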
📄 Source code
def merge_seque(distribute_res):
    """Merge adjacent segments that belong to the same speaker.

    The input is walked in order. A segment is folded into the previous output
    segment when both carry the same label and it starts no later than that
    segment ends; its end then becomes the merged end. Otherwise it starts a
    new output segment.

    Args:
        distribute_res: Sequence of ``[start, end, label]`` entries, expected
            in chronological order.

    Returns:
        list of ``[start, end, label]`` lists. Entries are copies, so the
        caller's segments are never modified in place. An empty input
        returns ``[]`` instead of raising IndexError.
    """
    if len(distribute_res) == 0:
        return []
    # Copy so extending a merged run does not mutate the caller's lists.
    merged = [list(distribute_res[0])]
    for seg in distribute_res[1:]:
        current = merged[-1]
        if seg[2] != current[2] or seg[0] > current[1]:
            merged.append(list(seg))
        else:
            # Same speaker, contiguous or overlapping: extend the current run.
            current[1] = seg[1]
    return merged
function

smooth(res, mindur)

funasr.utils.speaker_utils · View on GitHub ↗

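Reassign segments shorter than mindur to a neighbouring speaker, then merge same-speaker runs.

Args:

  • res — List of mutable [start, end, label] segments in chronological order; rounded and relabeled in place.
  • mindur — Minimum segment duration (same unit as start/end) below which a segment is relabeled.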
📄 Source code
def smooth(res, mindur=1):
    """Absorb short segments into a neighbouring speaker, then merge runs.

    Each segment's start and end are rounded to 2 decimals in place. A segment
    shorter than ``mindur`` is relabeled as follows:
    - The first segment takes the next segment's label.
    - The last segment takes the previous segment's label.
    - Any other segment takes the label of the neighbour with the smaller
      time gap; ties go to the previous one.

    Labels are reassigned left to right, so an earlier reassignment can
    propagate. Same-speaker runs are finally merged with ``merge_seque``.

    Args:
        res: List of mutable ``[start, end, label]`` entries in chronological
            order; modified in place.
        mindur: Minimum duration (same unit as start/end) below which a
            segment is relabeled.

    Returns:
        The merged segment list. An empty input is returned unchanged. A
        lone segment keeps its own label, since it has no neighbour; the
        original code raised IndexError in that case.
    """
    if len(res) == 0:
        return res
    last = len(res) - 1
    for i in range(len(res)):
        res[i][0] = round(res[i][0], 2)
        res[i][1] = round(res[i][1], 2)
        is_short = res[i][1] - res[i][0] < mindur
        if not is_short or last == 0:
            continue
        if i == 0:
            res[i][2] = res[i + 1][2]
        elif i == last:
            res[i][2] = res[i - 1][2]
        elif res[i][0] - res[i - 1][1] <= res[i + 1][0] - res[i][1]:
            # Gap to the previous segment is no larger than gap to the next.
            res[i][2] = res[i - 1][2]
        else:
            res[i][2] = res[i + 1][2]
    # merge the speakers
    return merge_seque(res)
function

distribute_spk(sentence_list, sd_time_list)

funasr.utils.speaker_utils · View on GitHub ↗

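Assign each sentence the speaker whose diarization segment overlaps it the most (speaker 0 if none overlaps).

Args:

  • sentence_list — List of sentence dicts with a "ts_list" of [start, end] timestamps; each dict gets an "spk" key added in place.
  • sd_time_list — List of (start, end, speaker) diarization segments; start/end are multiplied by 1000 to match the sentence timestamps.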
📄 Source code
def distribute_spk(sentence_list, sd_time_list):
    """Label every sentence with the speaker whose segment overlaps it most.

    Args:
        sentence_list: Sentence dicts. The sentence span runs from
            ``ts_list[0][0]`` to ``ts_list[-1][1]``.
        sd_time_list: ``(start, end, speaker)`` diarization segments. Start
            and end are scaled by 1000 before being compared with the
            sentence span.

    Returns:
        The same dict objects, in order, each with an added ``"spk"`` key.
        The key defaults to 0 when no segment overlaps the sentence.
    """
    labeled = []
    for sentence in sentence_list:
        sent_start = sentence["ts_list"][0][0]
        sent_end = sentence["ts_list"][-1][1]
        best_spk, best_overlap = 0, 0
        for seg_start, seg_end, seg_spk in sd_time_list:
            seg_start, seg_end = seg_start * 1000, seg_end * 1000
            overlap = max(min(sent_end, seg_end) - max(sent_start, seg_start), 0)
            # Strictly greater: on ties the earliest segment wins.
            if overlap > best_overlap:
                best_spk, best_overlap = seg_spk, overlap
        sentence["spk"] = best_spk
        labeled.append(sentence)
    return labeled
function

cif_wo_hidden(alphas, threshold)

funasr.utils.timestamp_tools · View on GitHub ↗

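Run the CIF integrate-and-fire accumulator over alphas (without hidden states) and return the accumulated weight at every frame.

Args:

  • alphas — Tensor of shape (batch, time) with per-frame weights.
  • threshold — Firing threshold; when the running weight reaches it, the threshold is subtracted and the remainder carries over.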
📄 Source code
def cif_wo_hidden(alphas, threshold):
    """Run the CIF integrate-and-fire accumulator and record it per frame.

    At each frame ``t`` the running weight grows by ``alphas[:, t]`` and that
    pre-reset value is recorded. Wherever the running weight has reached
    ``threshold`` it "fires": ``threshold`` is subtracted and the remainder
    carries into the next frame.

    Args:
        alphas: Tensor of shape (batch, time) with per-frame weights.
        threshold: Firing threshold (scalar).

    Returns:
        Tensor of shape (batch, time) holding the accumulated weight at each
        frame, on the same device and with the same dtype as ``alphas``.
    """
    batch_size, len_time = alphas.size()
    if len_time == 0:
        # torch.stack cannot stack an empty list.
        return alphas.new_zeros((batch_size, 0))
    # Accumulate in the input dtype. A fixed float32 buffer would silently
    # downcast float64 alphas.
    integrate = alphas.new_zeros((batch_size,))
    list_fires = []
    for t in range(len_time):
        # Out-of-place add: each recorded tensor stays independent of later updates.
        integrate = integrate + alphas[:, t]
        list_fires.append(integrate)
        fire_place = integrate >= threshold
        integrate = torch.where(fire_place, integrate - threshold, integrate)
    return torch.stack(list_fires, 1)
function

ts_prediction_lfr6_standard(us_alphas, us_peaks, char_list, vad_offset, force_time_shift, sil_in_str, upsample_rate)

funasr.utils.timestamp_tools · View on GitHub ↗

Ts prediction lfr6 standard.

Args:

  • us_alphas — TODO.
  • us_peaks — TODO.
  • char_list — TODO.
  • vad_offset — TODO.
  • force_time_shift — TODO.
  • sil_in_str — TODO.
  • upsample_rate — TODO.
📄 Source code
def ts_prediction_lfr6_standard(
    us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True, upsample_rate=3,
):
    """Ts prediction lfr6 standard.
    
        Args:
            us_alphas: TODO.
            us_peaks: TODO.
            char_list: TODO.
            vad_offset: TODO.
            force_time_shift: TODO.
            sil_in_str: TODO.
            upsample_rate: TODO.
        """
    if not len(char_list):
        return "", []
    START_END_THRESHOLD = 5
    MAX_TOKEN_DURATION = 12  #  3 times upsampled
    TIME_RATE=10.0 * 6 / 1000 / upsample_rate
    if len(us_alphas.shape) == 2:
        alphas, peaks = us_alphas[0], us_peaks[0]  # support inference batch_size=1 only
    else:
        alphas, peaks = us_alphas, us_peaks
    if char_list[-1] == "</s>":
        char_list = char_list[:-1]
    fire_place = (
        torch.where(peaks >= 1.0 - 1e-4)[0].cpu().numpy() + force_time_shift
    )  # total offset
    if len(fire_place) != len(char_list) + 1:
        alphas /= alphas.sum() / (len(char_list) + 1)
View full source on GitHub →
function

timestamp_sentence(punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text)

funasr.utils.timestamp_tools · View on GitHub ↗

Split recognized text into sentences using punctuation, with timestamps.

Args:

  • punc_id_list (Tensor/list) — Punctuation IDs from CT-Transformer. Values: 1=none, 2=comma, 3=period, 4=question.
  • timestamp_postprocessed (list) — Per-character timestamps [[start_ms, end_ms], ...].
  • text_postprocessed (str) — Space-separated recognized text.
  • return_raw_text (bool) — Include raw_text in output.

Returns:

list[dict]: Sentences with keys: text, start, end, timestamp, [raw_text].

📄 Source code
def timestamp_sentence(
    punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text=False
):

    """Split recognized text into sentences using punctuation, with timestamps.

    Args:
        punc_id_list (Tensor/list): Punctuation IDs from CT-Transformer.
            Values: 1=none, 2=comma, 3=period, 4=question.
        timestamp_postprocessed (list): Per-character timestamps [[start_ms, end_ms], ...].
        text_postprocessed (str): Space-separated recognized text.
        return_raw_text (bool): Include raw_text in output.

    Returns:
        list[dict]: Sentences with keys: text, start, end, timestamp, [raw_text].
    """
    punc_list = [",", "。", "?", "、"]
    res = []
    if text_postprocessed is None:
        return res
    if timestamp_postprocessed is None:
        return res
    if len(timestamp_postprocessed) == 0:
        return res
    if len(text_postprocessed) == 0:
        return res

    if punc_id_list is None or len(punc_id_list) == 0:
        res.append(
            {
View full source on GitHub →
function

timestamp_sentence_en(punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text)

funasr.utils.timestamp_tools · View on GitHub ↗

Timestamp sentence en.

Args:

  • punc_id_list — TODO.
  • timestamp_postprocessed — TODO.
  • text_postprocessed — TODO.
  • return_raw_text — TODO.
📄 Source code
def timestamp_sentence_en(
    punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text=False
):
    """Timestamp sentence en.
    
        Args:
            punc_id_list: TODO.
            timestamp_postprocessed: TODO.
            text_postprocessed: TODO.
            return_raw_text: TODO.
        """
    punc_list = [",", ".", "?", ","]
    res = []
    if text_postprocessed is None:
        return res
    if timestamp_postprocessed is None:
        return res
    if len(timestamp_postprocessed) == 0:
        return res
    if len(text_postprocessed) == 0:
        return res

    if punc_id_list is None or len(punc_id_list) == 0:
        res.append(
            {
                "text": text_postprocessed.split(),
                "start": timestamp_postprocessed[0][0],
                "end": timestamp_postprocessed[-1][1],
                "timestamp": timestamp_postprocessed,
            }
View full source on GitHub →
class

MakePadMask

funasr.utils.torch_function · View on GitHub ↗

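Module that builds padding masks by indexing a precomputed max_seq_len × max_seq_len triangular table (1 - tri when flip=True, tri otherwise). Its forward docstring notes this form can be exported to ONNX.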

📄 Source code
class MakePadMask(nn.Module):
    def __init__(self, max_seq_len=512, flip=True):
        """Initialize MakePadMask.
        
            Args:
                max_seq_len: TODO.
                flip: TODO.
            """
        super().__init__()
        if flip:
            self.mask_pad = torch.Tensor(1 - np.tri(max_seq_len)).type(torch.bool)
        else:
            self.mask_pad = torch.Tensor(np.tri(max_seq_len)).type(torch.bool)

    def forward(self, lengths, xs=None, length_dim=-1, maxlen=None):
        """Make mask tensor containing indices of padded part.
        This implementation creates the same mask tensor with original make_pad_mask,
        which can be converted into onnx format.
        Dimension length of xs should be 2 or 3.
        """
        if length_dim == 0:
            raise ValueError("length_dim cannot be 0: {}".format(length_dim))

        if xs is not None and len(xs.shape) == 3:
            if length_dim == 1:
                lengths = lengths.unsqueeze(1).expand(*xs.transpose(1, 2).shape[:2])
            else:
                lengths = lengths.unsqueeze(1).expand(*xs.shape[:2])

        if maxlen is not None:
View full source on GitHub →

Methods

.forward(lengths, xs, length_dim, maxlen) L23

Make mask tensor containing indices of padded part.

This implementation creates the same mask tensor with original make_pad_mask,

which can be converted into onnx format.

Dimension length of xs should be 2 or 3.

📄 Source
    def forward(self, lengths, xs=None, length_dim=-1, maxlen=None):
        """Make mask tensor containing indices of padded part.

        This implementation creates the same mask tensor with original make_pad_mask,
        which can be converted into onnx format.
        Dimension length of xs should be 2 or 3.

        The mask is looked up from ``self.mask_pad``, a precomputed
        ``max_seq_len x max_seq_len`` triangular table. The row at index
        ``length - 1`` is selected and truncated to ``m`` columns:
        - With ``flip=True`` (the default) the table is ``1 - tri``, so the
          row is 1 at positions ``>= length``, i.e. the padding.
        - With ``flip=False`` the table is ``tri``, so the row is 1 at
          positions ``< length``.

        Args:
            lengths: Integer tensor of sequence lengths (presumably shape
                (batch,) — TODO confirm). It is used as an index into
                ``self.mask_pad``, so values are expected in ``[1, max_seq_len]``.
            xs: Optional reference tensor. When it is 3-D, ``lengths`` is
                broadcast over its first two dims, after swapping dims 1 and 2
                when ``length_dim == 1``. Its last dim sets the mask width when
                ``maxlen`` is not given.
            length_dim: Must not be 0. When 1, the mask is returned with dims
                1 and 2 transposed.
            maxlen: Optional explicit mask width; takes precedence over ``xs``.

        Returns:
            float32 mask tensor.

        Raises:
            ValueError: If ``length_dim`` is 0.
        """
        if length_dim == 0:
            raise ValueError("length_dim cannot be 0: {}".format(length_dim))

        # Give every (batch, position) pair of a 3-D xs its own length so the
        # table lookup below yields one mask row per pair.
        if xs is not None and len(xs.shape) == 3:
            if length_dim == 1:
                lengths = lengths.unsqueeze(1).expand(*xs.transpose(1, 2).shape[:2])
            else:
                lengths = lengths.unsqueeze(1).expand(*xs.shape[:2])

        # Mask width: explicit maxlen, else xs's last dim, else the longest length.
        if maxlen is not None:
            m = maxlen
        elif xs is not None:
            m = xs.shape[-1]
        else:
            m = torch.max(lengths)

        # NOTE(review): lengths == 0 indexes row -1 (the table's last row)
        # rather than an all-padding row — confirm zero lengths cannot occur.
        # NOTE(review): mask_pad is assigned as a plain tensor attribute in
        # __init__ (not register_buffer), so Module.to()/cuda() will not move
        # it; confirm lengths live on the same device.
        mask = self.mask_pad[lengths - 1][..., :m].type(torch.float32)

        # NOTE(review): for length_dim == 1, m still comes from xs.shape[-1];
        # confirm this is intended when xs.shape[1] != xs.shape[-1]. The
        # transpose also assumes a 3-D mask.
        if length_dim == 1:
            return mask.transpose(1, 2)
        else:
            return mask
method

MakePadMask.forward(lengths, xs, length_dim, maxlen)

funasr.utils.torch_function.MakePadMask · View on GitHub ↗

Make mask tensor containing indices of padded part.

This implementation creates the same mask tensor with original make_pad_mask,

which can be converted into onnx format.

Dimension length of xs should be 2 or 3.

📄 Source code
    def forward(self, lengths, xs=None, length_dim=-1, maxlen=None):
        """Make mask tensor containing indices of padded part.

        This implementation creates the same mask tensor with original make_pad_mask,
        which can be converted into onnx format.
        Dimension length of xs should be 2 or 3.

        The mask is looked up from ``self.mask_pad``, a precomputed
        ``max_seq_len x max_seq_len`` triangular table. The row at index
        ``length - 1`` is selected and truncated to ``m`` columns:
        - With ``flip=True`` (the default) the table is ``1 - tri``, so the
          row is 1 at positions ``>= length``, i.e. the padding.
        - With ``flip=False`` the table is ``tri``, so the row is 1 at
          positions ``< length``.

        Args:
            lengths: Integer tensor of sequence lengths (presumably shape
                (batch,) — TODO confirm). It is used as an index into
                ``self.mask_pad``, so values are expected in ``[1, max_seq_len]``.
            xs: Optional reference tensor. When it is 3-D, ``lengths`` is
                broadcast over its first two dims, after swapping dims 1 and 2
                when ``length_dim == 1``. Its last dim sets the mask width when
                ``maxlen`` is not given.
            length_dim: Must not be 0. When 1, the mask is returned with dims
                1 and 2 transposed.
            maxlen: Optional explicit mask width; takes precedence over ``xs``.

        Returns:
            float32 mask tensor.

        Raises:
            ValueError: If ``length_dim`` is 0.
        """
        if length_dim == 0:
            raise ValueError("length_dim cannot be 0: {}".format(length_dim))

        # Give every (batch, position) pair of a 3-D xs its own length so the
        # table lookup below yields one mask row per pair.
        if xs is not None and len(xs.shape) == 3:
            if length_dim == 1:
                lengths = lengths.unsqueeze(1).expand(*xs.transpose(1, 2).shape[:2])
            else:
                lengths = lengths.unsqueeze(1).expand(*xs.shape[:2])

        # Mask width: explicit maxlen, else xs's last dim, else the longest length.
        if maxlen is not None:
            m = maxlen
        elif xs is not None:
            m = xs.shape[-1]
        else:
            m = torch.max(lengths)

        # NOTE(review): lengths == 0 indexes row -1 (the table's last row)
        # rather than an all-padding row — confirm zero lengths cannot occur.
        # NOTE(review): mask_pad is assigned as a plain tensor attribute in
        # __init__ (not register_buffer), so Module.to()/cuda() will not move
        # it; confirm lengths live on the same device.
        mask = self.mask_pad[lengths - 1][..., :m].type(torch.float32)

        # NOTE(review): for length_dim == 1, m still comes from xs.shape[-1];
        # confirm this is intended when xs.shape[1] != xs.shape[-1]. The
        # transpose also assumes a 3-D mask.
        if length_dim == 1:
            return mask.transpose(1, 2)
        else:
            return mask
class

sequence_mask

funasr.utils.torch_function · View on GitHub ↗

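Module that turns a tensor of sequence lengths into a mask that is 1 before each length and 0 after. The constructor arguments are accepted but unused.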

📄 Source code
class sequence_mask(nn.Module):
    """Module that turns per-sequence lengths into a 0/1 validity mask."""

    def __init__(self, max_seq_len=512, flip=True):
        """Create the module.

        Args:
            max_seq_len: Unused; mirrors MakePadMask's constructor signature.
            flip: Unused; mirrors MakePadMask's constructor signature.
        """
        super().__init__()

    def forward(self, lengths, max_seq_len=None, dtype=torch.float32, device=None):
        """Return a mask that is 1 at positions before each length and 0 after.

        Args:
            lengths: Integer tensor of sequence lengths.
            max_seq_len: Mask width; defaults to ``lengths.max()``.
            dtype: dtype the boolean mask is cast to.
            device: Target device ("cuda:0", "cpu", etc.). When None, the mask
                stays on ``lengths.device``.

        Returns:
            Tensor of shape ``lengths.shape + (max_seq_len,)``.
        """
        width = lengths.max() if max_seq_len is None else max_seq_len
        positions = torch.arange(0, width, 1).to(lengths.device)
        # (width,) broadcast against (..., 1): position j is valid iff j < length.
        mask = (positions < lengths.unsqueeze(-1)).type(dtype)
        if device is not None:
            mask = mask.to(device)
        return mask

Methods

.forward(lengths, max_seq_len, dtype, device) L63

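Build a mask that is 1 at positions before each sequence's length and 0 afterwards.

Args:

  • lengths — Integer tensor of sequence lengths.
  • max_seq_len — Mask width; defaults to lengths.max().
  • dtype — dtype the boolean mask is cast to (default torch.float32).
  • device — Target device ("cuda:0", "cpu", etc.); when omitted the mask stays on lengths.device.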
📄 Source
    def forward(self, lengths, max_seq_len=None, dtype=torch.float32, device=None):
        """Build a mask that is 1 at positions before each length and 0 after.

        Args:
            lengths: Integer tensor of sequence lengths.
            max_seq_len: Width of the mask; defaults to ``lengths.max()``.
            dtype: dtype the boolean mask is cast to.
            device: Target device ("cuda:0", "cpu", etc.). When None, the mask
                stays on ``lengths.device``.

        Returns:
            Tensor of shape ``lengths.shape + (max_seq_len,)``.
        """
        if max_seq_len is None:
            max_seq_len = lengths.max()
        row_vector = torch.arange(0, max_seq_len, 1).to(lengths.device)
        matrix = torch.unsqueeze(lengths, dim=-1)
        # (max_seq_len,) broadcast against (..., 1): position j is valid iff j < length.
        mask = row_vector < matrix

        return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
method

sequence_mask.forward(lengths, max_seq_len, dtype, device)

funasr.utils.torch_function.sequence_mask · View on GitHub ↗

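Build a mask that is 1 at positions before each sequence's length and 0 afterwards.

Args:

  • lengths — Integer tensor of sequence lengths.
  • max_seq_len — Mask width; defaults to lengths.max().
  • dtype — dtype the boolean mask is cast to (default torch.float32).
  • device — Target device ("cuda:0", "cpu", etc.); when omitted the mask stays on lengths.device.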
📄 Source code
    def forward(self, lengths, max_seq_len=None, dtype=torch.float32, device=None):
        """Build a mask that is 1 at positions before each length and 0 after.

        Args:
            lengths: Integer tensor of sequence lengths.
            max_seq_len: Width of the mask; defaults to ``lengths.max()``.
            dtype: dtype the boolean mask is cast to.
            device: Target device ("cuda:0", "cpu", etc.). When None, the mask
                stays on ``lengths.device``.

        Returns:
            Tensor of shape ``lengths.shape + (max_seq_len,)``.
        """
        if max_seq_len is None:
            max_seq_len = lengths.max()
        row_vector = torch.arange(0, max_seq_len, 1).to(lengths.device)
        matrix = torch.unsqueeze(lengths, dim=-1)
        # (max_seq_len,) broadcast against (..., 1): position j is valid iff j < length.
        mask = row_vector < matrix

        return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
function

normalize(input, p, dim, out)

funasr.utils.torch_function · View on GitHub ↗

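Divide a tensor by its L_p norm along one dimension (no epsilon, so zero-norm slices yield NaN/inf).

Args:

  • input — Tensor to normalize.
  • p — Exponent of the norm (default 2.0).
  • dim — Dimension along which the norm is computed (default 1).
  • out — Optional output tensor passed to torch.div.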
📄 Source code
def normalize(
    input: torch.Tensor, p: float = 2.0, dim: int = 1, out: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Scale ``input`` to unit L_p norm along ``dim``.

    No epsilon is added, so slices whose norm is zero produce NaN/inf.

    Args:
        input: Tensor to normalize.
        p: Exponent of the norm.
        dim: Dimension along which the norm is computed.
        out: Optional output tensor handed to ``torch.div``.

    Returns:
        ``input`` divided elementwise by its norm broadcast back to
        ``input``'s shape (the ``out`` tensor when one is given).
    """
    norm = input.norm(p, dim, keepdim=True).expand_as(input)
    if out is not None:
        return torch.div(input, norm, out=out)
    return input / norm
function

subsequent_mask(size)

funasr.utils.torch_function · View on GitHub ↗

Return a size × size lower-triangular matrix of ones (entry [i, j] is 1 when j <= i).

Args:

  • size — Number of positions (an int).
📄 Source code
def subsequent_mask(size: int) -> torch.Tensor:
    """Return a lower-triangular ``size x size`` matrix of ones.

    Entry ``[i, j]`` is 1 when ``j <= i`` and 0 otherwise. This is the usual
    "subsequent" (causal) mask shape for decoder self-attention; confirm the
    expected convention at call sites.

    Args:
        size: Number of positions. It is used as a dimension, so an int is
            expected.

    Returns:
        Tensor of shape ``(size, size)`` in torch's default floating dtype.
    """
    return torch.ones(size, size).tril()
function

MakePadMask_test()

funasr.utils.torch_function · View on GitHub ↗

Smoke test for MakePadMask: prints the mask for a single sequence of length 10.

📄 Source code
def MakePadMask_test():
    """Smoke-test MakePadMask by printing the mask for one length-10 sequence."""
    mask_fn = MakePadMask()
    lengths = torch.tensor([10], dtype=torch.long)
    print(mask_fn(lengths))
function

str2bool(value)

funasr.utils.type_utils · View on GitHub ↗

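Convert a truth-value string (y/yes/t/true/on/1 or n/no/f/false/off/0, case-insensitive) to a bool; raises ValueError otherwise.

Args:

  • value — String to convert.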
📄 Source code
def str2bool(value: str) -> bool:
    """Convert a truth-value string to a bool (suitable as an argparse ``type=``).

    Accepts the same case-insensitive spellings as ``distutils.util.strtobool``:
    "y", "yes", "t", "true", "on", "1" map to True and "n", "no", "f",
    "false", "off", "0" map to False. The check is implemented inline because
    ``distutils`` was removed from the standard library in Python 3.12.

    Args:
        value: String to convert.

    Returns:
        The parsed boolean.

    Raises:
        ValueError: If ``value`` is not one of the recognised spellings.
    """
    lowered = value.lower()
    if lowered in ("y", "yes", "t", "true", "on", "1"):
        return True
    if lowered in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError("invalid truth value %r" % (lowered,))
function

remove_parenthesis(value)

funasr.utils.type_utils · View on GitHub ↗

Strip surrounding whitespace, then remove one enclosing pair of () or [] if present.

Args:

  • value — String to clean.
📄 Source code
def remove_parenthesis(value: str):
    """Strip whitespace, then drop one enclosing ``()`` or ``[]`` pair.

    Args:
        value: String to clean.

    Returns:
        The stripped string. If it starts and ends with a matching round or
        square bracket, the outer two characters are removed.
    """
    text = value.strip()
    for opener, closer in (("(", ")"), ("[", "]")):
        if text.startswith(opener) and text.endswith(closer):
            return text[1:-1]
    return text
function

remove_quotes(value)

funasr.utils.type_utils · View on GitHub ↗

Strip surrounding whitespace, then remove one enclosing pair of double or single quotes if present.

Args:

  • value — String to clean.
📄 Source code
def remove_quotes(value: str):
    """Strip whitespace, then drop one enclosing pair of matching quotes.

    Args:
        value: String to clean.

    Returns:
        The stripped string. If it starts and ends with the same quote
        character (double, then single, is checked), the outer two
        characters are removed.
    """
    text = value.strip()
    for quote in ('"', "'"):
        if text.startswith(quote) and text.endswith(quote):
            return text[1:-1]
    return text
function

int_or_none(value)

funasr.utils.type_utils · View on GitHub ↗

int_or_none.

Examples:

>>> import argparse

>>> parser = argparse.ArgumentParser()

>>> _ = parser.add_argument('--foo', type=int_or_none)

>>> parser.parse_args(['--foo', '456'])

Namespace(foo=456)

>>> parser.parse_args(['--foo', 'none'])

Namespace(foo=None)

>>> parser.parse_args(['--foo', 'null'])

Namespace(foo=None)

>>> parser.parse_args(['--foo', 'nil'])

Namespace(foo=None)

📄 Source code
def int_or_none(value: str) -> Optional[int]:
    """Parse an int, mapping "none"/"null"/"nil" to None (an argparse ``type=`` helper).

    The null check ignores case and surrounding whitespace.

    Examples:
        >>> int_or_none('456')
        456
        >>> int_or_none(' None ') is None
        True
        >>> int_or_none('nil') is None
        True
    """
    if value.strip().lower() not in ("none", "null", "nil"):
        return int(value)
    return None
function

float_or_none(value)

funasr.utils.type_utils · View on GitHub ↗

float_or_none.

Examples:

>>> import argparse

>>> parser = argparse.ArgumentParser()

>>> _ = parser.add_argument('--foo', type=float_or_none)

>>> parser.parse_args(['--foo', '4.5'])

Namespace(foo=4.5)

>>> parser.parse_args(['--foo', 'none'])

Namespace(foo=None)

>>> parser.parse_args(['--foo', 'null'])

Namespace(foo=None)

>>> parser.parse_args(['--foo', 'nil'])

Namespace(foo=None)

📄 Source code
def float_or_none(value: str) -> Optional[float]:
    """Parse a float, mapping "none"/"null"/"nil" to None (an argparse ``type=`` helper).

    The null check ignores case and surrounding whitespace.

    Examples:
        >>> float_or_none('4.5')
        4.5
        >>> float_or_none('NULL') is None
        True
        >>> float_or_none('nil') is None
        True
    """
    if value.strip().lower() not in ("none", "null", "nil"):
        return float(value)
    return None
function

humanfriendly_parse_size_or_none(value)

funasr.utils.type_utils · View on GitHub ↗

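Parse a human-readable size string with humanfriendly.parse_size, or return None for "none"/"null"/"nil".

Args:

  • value — Size string such as "10MB", or a null marker (case-insensitive, surrounding whitespace ignored).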
📄 Source code
def humanfriendly_parse_size_or_none(value) -> Optional[float]:
    """Parse a human-readable size, mapping "none"/"null"/"nil" to None.

    Args:
        value: Size string passed to ``humanfriendly.parse_size``. A null
            marker (any case, surrounding whitespace ignored) returns None.

    Returns:
        None for a null marker, otherwise ``humanfriendly.parse_size(value)``.
    """
    if value.strip().lower() not in ("none", "null", "nil"):
        return humanfriendly.parse_size(value)
    return None
function

str_or_int(value)

funasr.utils.type_utils · View on GitHub ↗

Return int(value) when value parses as an integer, otherwise return value unchanged.

Args:

  • value — String to convert.
📄 Source code
def str_or_int(value: str) -> Union[str, int]:
    """Return ``int(value)`` if it parses as an integer, else ``value`` unchanged.

    Args:
        value: String to convert.

    Returns:
        The parsed int, or the original ``value`` when ``int()`` raises
        ValueError.
    """
    try:
        parsed = int(value)
    except ValueError:
        parsed = value
    return parsed
function

str_or_none(value)

funasr.utils.type_utils · View on GitHub ↗

str_or_none.

Examples:

>>> import argparse

>>> parser = argparse.ArgumentParser()

>>> _ = parser.add_argument('--foo', type=str_or_none)

>>> parser.parse_args(['--foo', 'aaa'])

Namespace(foo='aaa')

>>> parser.parse_args(['--foo', 'none'])

Namespace(foo=None)

>>> parser.parse_args(['--foo', 'null'])

Namespace(foo=None)

>>> parser.parse_args(['--foo', 'nil'])

Namespace(foo=None)

📄 Source code
def str_or_none(value: str) -> Optional[str]:
    """Return ``value`` unchanged, or None when it is "none"/"null"/"nil".

    The null check ignores case and surrounding whitespace; non-null values
    are returned as given, without stripping.

    Examples:
        >>> str_or_none('aaa')
        'aaa'
        >>> str_or_none('None') is None
        True
        >>> str_or_none(' nil ') is None
        True
    """
    if value.strip().lower() not in ("none", "null", "nil"):
        return value
    return None
function

str2pair_str(value)

funasr.utils.type_utils · View on GitHub ↗

str2pair_str.

Examples:

>>> import argparse

>>> str2pair_str('abc,def ')

('abc', 'def')

>>> parser = argparse.ArgumentParser()

>>> _ = parser.add_argument('--foo', type=str2pair_str)

>>> parser.parse_args(['--foo', 'abc,def'])

Namespace(foo=('abc', 'def'))

📄 Source code
def str2pair_str(value: str) -> Tuple[str, str]:
    """Parse ``"a,b"`` into a ``(a, b)`` tuple of strings.

    Surrounding whitespace and one enclosing ``()``/``[]`` pair are removed
    first. Each item is then stripped of whitespace and one pair of enclosing
    quotes. A ValueError is raised unless the text contains exactly one comma.

    Examples:
        >>> str2pair_str('abc,def ')
        ('abc', 'def')
        >>> str2pair_str("['abc', 'def']")
        ('abc', 'def')
    """
    # configargparse workaround: values read from a YAML file arrive shaped
    # like a Python list, e.g. ['a', 'b'], so brackets and per-item quotes are
    # stripped here.
    first, second = remove_parenthesis(value).split(",")
    return remove_quotes(first), remove_quotes(second)
function

str2triple_str(value)

funasr.utils.type_utils · View on GitHub ↗

str2triple_str.

Examples:

>>> str2triple_str('abc,def ,ghi')

('abc', 'def', 'ghi')

📄 Source code
def str2triple_str(value: str) -> Tuple[str, str, str]:
    """Parse ``"a,b,c"`` into an ``(a, b, c)`` tuple of strings.

    Surrounding whitespace and one enclosing ``()``/``[]`` pair are removed
    first. Each item is then stripped of whitespace and one pair of enclosing
    quotes. A ValueError is raised unless the text contains exactly two commas.

    Examples:
        >>> str2triple_str('abc,def ,ghi')
        ('abc', 'def', 'ghi')
    """
    # configargparse workaround: values read from a YAML file arrive shaped
    # like a Python list, e.g. ['a', 'b', 'c'], so brackets and per-item
    # quotes are stripped here.
    first, second, third = remove_parenthesis(value).split(",")
    return remove_quotes(first), remove_quotes(second), remove_quotes(third)
function

slice_padding_fbank(speech, speech_lengths, vad_segments)

funasr.utils.vad_utils · View on GitHub ↗

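Cut VAD segments out of the first utterance in speech and zero-pad them into one batch.

Args:

  • speech — Speech audio tensor, shape (batch, time); only speech[0] is used.
  • speech_lengths — Length of each speech sample; speech_lengths[0] caps every segment's end.
  • vad_segments — Sequence of items whose first element is [start, end]; times are multiplied by 16 to get sample indices (presumably milliseconds at 16 kHz).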
📄 Source code
def slice_padding_fbank(speech, speech_lengths, vad_segments):
    """Cut VAD segments from ``speech[0]`` and zero-pad them into one batch.

    Args:
        speech: Speech audio tensor, shape (batch, time). Only the first row
            is sliced.
        speech_lengths: Per-sample lengths. ``speech_lengths[0]`` caps each
            segment's end index.
        vad_segments: Items whose first element is ``[start, end]``. Times are
            scaled by 16 to get sample indices (presumably ms at 16 kHz —
            TODO confirm).

    Returns:
        ``(padded, lengths)``: padded is ``(num_segments, max_len)`` with zero
        padding; lengths is an int32 tensor of each segment's sample count.
    """
    pieces = []
    piece_lengths = []
    for segment in vad_segments:
        start, end = segment[0][0], segment[0][1]
        beg_idx = int(start * 16)
        end_idx = min(int(end * 16), speech_lengths[0])
        pieces.append(speech[0, beg_idx:end_idx])
        piece_lengths.append(end_idx - beg_idx)
    padded = pad_sequence(pieces, batch_first=True, padding_value=0.0)
    return padded, torch.Tensor(piece_lengths).int()
function

slice_padding_audio_samples(speech, speech_lengths, vad_segments)

funasr.utils.vad_utils · View on GitHub ↗

Slice audio into VAD segments (the slices are returned as-is; no padding is applied).

Args:

  • speech (Tensor) — Full audio tensor.
  • speech_lengths (int) — Total audio length.
  • vad_segments (list) — List of (segment_info, original_index) tuples, where segment_info is [start_ms, end_ms].

Returns:

  • tuple — (speech_list, speech_lengths_list): the sliced segments of speech (same type as speech) and their lengths in samples.
📄 Source code
def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
    """Cut VAD segments out of a 1-D audio signal (no padding is applied).

    Args:
        speech: Full audio signal; any sliceable sequence (e.g. a 1-D Tensor).
        speech_lengths (int): Total audio length; caps each segment's end.
        vad_segments (list): ``(segment_info, original_index)`` items, where
            ``segment_info`` is ``[start_ms, end_ms]``. Times are scaled by
            16 to get sample indices.

    Returns:
        tuple: ``(speech_list, speech_lengths_list)``. These are the slices of
        ``speech`` (same type as ``speech``) and their lengths in samples.
    """
    bounds = [
        (int(seg[0][0] * 16), min(int(seg[0][1] * 16), speech_lengths))
        for seg in vad_segments
    ]
    speech_list = [speech[beg:end] for beg, end in bounds]
    speech_lengths_list = [end - beg for beg, end in bounds]
    return speech_list, speech_lengths_list
function

merge_vad(vad_result, max_length, min_length)

funasr.utils.vad_utils · View on GitHub ↗

Merge short VAD segments to reduce fragmentation.

Args:

  • vad_result (list) — VAD segments [[start_ms, end_ms], ...].
  • max_length (int) — Maximum merged segment length in ms (default 15000).
  • min_length (int) — Minimum segment length; shorter ones get merged (default 0).

Returns:

  • list — Merged VAD segments [[start_ms, end_ms], ...].
📄 Source code
def merge_vad(vad_result, max_length=15000, min_length=0):

    """Merge short VAD segments to reduce fragmentation.

    Args:
        vad_result (list): VAD segments [[start_ms, end_ms], ...].
        max_length (int): Maximum merged segment length in ms (default 15000).
        min_length (int): Minimum segment length; shorter ones get merged (default 0).

    Returns:
        list: Merged VAD segments [[start_ms, end_ms], ...].
    """
    new_result = []
    if len(vad_result) <= 1:
        return vad_result
    time_step = [t[0] for t in vad_result] + [t[1] for t in vad_result]
    time_step = sorted(list(set(time_step)))
    if len(time_step) == 0:
        return []
    bg = 0
    for i in range(len(time_step) - 1):
        time = time_step[i]
        if time_step[i + 1] - bg < max_length:
            continue
        if time - bg > min_length:
            new_result.append([bg, time])
        # if time - bg < max_length * 1.5:
        #     new_result.append([bg, time])
        # else:
        #     split_num = int(time - bg) // max_length + 1
View full source on GitHub →
function

get_pypi_version(package_name)

funasr.utils.version_checker · View on GitHub ↗

Fetch the latest released version of a package from the PyPI JSON API.

Args:

  • package_name — Distribution name on PyPI (e.g. "funasr").
📄 Source code
def get_pypi_version(package_name, timeout=10):
    """Fetch the latest released version of a package from PyPI.

    Args:
        package_name: Distribution name on PyPI (e.g. "funasr").
        timeout: Seconds to wait for the PyPI JSON API before giving up.
            Without a timeout, ``requests.get`` can block indefinitely on a
            stalled connection.

    Returns:
        ``version.parse`` applied to ``info.version`` from the PyPI JSON
        response.

    Raises:
        Exception: If PyPI answers with a non-200 status.
        requests.RequestException: On connection errors or timeout.
    """
    import requests

    url = f"https://pypi.org/pypi/{package_name}/json"
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        raise Exception("Failed to retrieve version information from PyPI.")
    data = response.json()
    return version.parse(data["info"]["version"])
function

check_for_update(disable)

funasr.utils.version_checker · View on GitHub ↗

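Print the installed funasr version and, unless disabled, compare it with the latest PyPI release.

Args:

  • disable — When True, only print the installed version and skip the network lookup.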
📄 Source code
def check_for_update(disable=False):
    """Print the installed funasr version and, unless disabled, compare it with PyPI.

    The PyPI lookup is best-effort. If it fails (offline machine, firewall,
    PyPI error), a short notice is printed instead of propagating the error,
    so an update check does not abort the caller.

    Args:
        disable: When True, only print the installed version and skip the
            network lookup.
    """
    current_version = version.parse(__version__)
    print(f"funasr version: {current_version}.")
    if disable:
        return

    print(
        "Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel"
    )
    try:
        pypi_version = get_pypi_version("funasr")
    except Exception as exc:  # best-effort network check; never fatal
        print(f"Skip checking for funasr update: {exc}")
        return

    if current_version < pypi_version:
        print(f"New version is available: {pypi_version}.")
        print('Please use the command "pip install -U funasr" to upgrade.')
    else:
        print(f"You are using the latest version of funasr-{current_version}")
function

main_hydra(kwargs)

funasr.bin.compute_audio_cmvn · View on GitHub ↗

Main hydra.

Args:

  • kwargs — Additional keyword arguments.
📄 Source code
def main_hydra(kwargs: DictConfig):
    """Hydra entry point for CMVN computation.

    Optionally breaks into pdb, downloads the model config if it is missing,
    then runs ``main``.

    Args:
        kwargs: Hydra config. Must contain "model". When "model_conf" is
            absent, ``download_model`` is called first, using the hub named
            by "hub" (default "ms") and "is_training" (default True).
    """
    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()

    assert "model" in kwargs
    if "model_conf" in kwargs:
        main(**kwargs)
        return

    logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
    kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
    main(**kwargs)
function

main(**kwargs)

funasr.bin.compute_audio_cmvn · View on GitHub ↗

Main.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
def main(**kwargs):
    """Main.
    
        Args:
            **kwargs: Additional keyword arguments.
        """
    print(kwargs)
    # set random seed
    # tables.print()
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)

    tokenizer = kwargs.get("tokenizer", None)

    # build frontend if frontend is none None
    frontend = kwargs.get("frontend", None)
    if frontend is not None:
        frontend_class = tables.frontend_classes.get(frontend)
        frontend = frontend_class(**kwargs["frontend_conf"])
        kwargs["frontend"] = frontend
        kwargs["input_size"] = frontend.output_size()

    # dataset
    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
    dataset_train = dataset_class(
        kwargs.get("train_data_set_list"),
        frontend=frontend,
        tokenizer=None,
View full source on GitHub →
function

main_hydra(cfg)

funasr.bin.export · View on GitHub ↗

Main hydra.

Args:

  • cfg — Configuration overrides.
📄 Source code
def main_hydra(cfg: DictConfig):
    """Main hydra.
    
        Args:
            cfg: Configuration overrides.
        """
    def to_plain_list(cfg_item):
        """To plain list.
        
            Args:
                cfg_item: TODO.
            """
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        elif isinstance(cfg_item, DictConfig):
            return {k: to_plain_list(v) for k, v in cfg_item.items()}
        else:
            return cfg_item

    kwargs = to_plain_list(cfg)

    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()

    if "device" not in kwargs:
        kwargs["device"] = "cpu"
    model = AutoModel(**kwargs)

View full source on GitHub →
function

main_hydra(cfg)

funasr.bin.inference · View on GitHub ↗

Main hydra.

Args:

  • cfg — Configuration overrides.
📄 Source code
def main_hydra(cfg: DictConfig):
    """Hydra entry point for inference.

    Converts the Hydra config into plain Python containers, builds an ``AutoModel``
    from it, runs ``generate`` on the configured ``input`` and prints the result.

    Args:
        cfg: Hydra configuration. Every entry is forwarded to ``AutoModel``;
            ``cfg["input"]`` is what ``generate`` receives, and ``debug=True``
            drops into ``pdb`` before the model is built.
    """

    def to_plain_list(cfg_item):
        """Recursively turn OmegaConf containers into built-in Python objects.

        ``ListConfig`` becomes a list (interpolations resolved), ``DictConfig``
        becomes a dict with converted values, anything else is returned unchanged.
        """
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        if isinstance(cfg_item, DictConfig):
            return {name: to_plain_list(node) for name, node in cfg_item.items()}
        return cfg_item

    plain_cfg = to_plain_list(cfg)

    if plain_cfg.get("debug", False):
        import pdb

        pdb.set_trace()

    model = AutoModel(**plain_cfg)
    print(model.generate(input=plain_cfg["input"]))
function

field2slice(field)

funasr.bin.tokenize_text · View on GitHub ↗

Convert field string to slice

Note that field string accepts 1-based integer.

Examples:

>>> field2slice("1-")

slice(0, None, None)

>>> field2slice("1-3")

slice(0, 3, None)

>>> field2slice("-3")

slice(None, 3, None)

📄 Source code
def field2slice(field: Optional[str]) -> slice:
    """Convert field string to slice

    Note that field string accepts 1-based integer.

    Examples:
        >>> field2slice("1-")
        slice(0, None, None)
        >>> field2slice("1-3")
        slice(0, 3, None)
        >>> field2slice("-3")
        slice(None, 3, None)
    """
    field = field.strip()
    try:
        if "-" in field:
            # e.g. "2-" or "2-5" or "-7"
            s1, s2 = field.split("-", maxsplit=1)
            if s1.strip() == "":
                s1 = None
            else:
                s1 = int(s1)
                if s1 == 0:
                    raise ValueError("1-based string")
            if s2.strip() == "":
                s2 = None
            else:
                s2 = int(s2)
        else:
            # e.g. "2"
View full source on GitHub →
function

tokenize(input, output, field, delimiter, token_type, space_symbol, non_linguistic_symbols, bpemodel, log_level, write_vocabulary, vocabulary_size, remove_non_linguistic_symbols, cutoff, add_symbol, cleaner, g2p)

funasr.bin.tokenize_text · View on GitHub ↗

Tokenize.

Args:

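  • input — Input text file; "-" reads from sys.stdin.
  • output — Output text file; "-" writes to sys.stdout.
  • field — Target columns of the input text as a 1-based range, e.g. "2-".
  • delimiter — The delimiter (default: None).
  • token_type — Token type: "char" (default), "bpe", "word" or "phn".
  • space_symbol — Symbol that represents a space (default: "<space>").
  • non_linguistic_symbols — TODO.
  • bpemodel — TODO.
  • log_level — Logging verbosity: CRITICAL, ERROR, WARNING, INFO (default), DEBUG or NOTSET.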
  • write_vocabulary — TODO.
  • vocabulary_size — Size/dimension parameter.
  • remove_non_linguistic_symbols — TODO.
  • cutoff — TODO.
  • add_symbol — TODO.
  • cleaner — TODO.
  • g2p — TODO.
📄 Source code
def tokenize(
    input: str,
    output: str,
    field: Optional[str],
    delimiter: Optional[str],
    token_type: str,
    space_symbol: str,
    non_linguistic_symbols: Optional[str],
    bpemodel: Optional[str],
    log_level: str,
    write_vocabulary: bool,
    vocabulary_size: int,
    remove_non_linguistic_symbols: bool,
    cutoff: int,
    add_symbol: List[str],
    cleaner: Optional[str],
    g2p: Optional[str],
):

    """Tokenize.
    
        Args:
            input: Input audio/text data.
            output: TODO.
            field: TODO.
            delimiter: TODO.
            token_type: TODO.
            space_symbol: TODO.
            non_linguistic_symbols: TODO.
            bpemodel: TODO.
View full source on GitHub →
function

get_parser()

funasr.bin.tokenize_text · View on GitHub ↗

Get parser.

📄 Source code
def get_parser() -> argparse.ArgumentParser:
    """Get parser."""
    parser = argparse.ArgumentParser(
        description="Tokenize texts",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument("--input", "-i", required=True, help="Input text. - indicates sys.stdin")
    parser.add_argument("--output", "-o", required=True, help="Output text. - indicates sys.stdout")
    parser.add_argument(
        "--field",
        "-f",
        help="The target columns of the input text as 1-based integer. e.g 2-",
    )
    parser.add_argument(
        "--token_type",
        "-t",
        default="char",
        choices=["char", "bpe", "word", "phn"],
        help="Token type",
    )
    parser.add_argument("--delimiter", "-d", default=None, help="The delimiter")
    parser.add_argument("--space_symbol", default="<space>", help="The space symbol")
View full source on GitHub →
function

main(cmd)

funasr.bin.tokenize_text · View on GitHub ↗

Main.

Args:

  • cmd — Optional list of command-line argument strings; when None, argparse parses sys.argv.
📄 Source code
def main(cmd=None):
    """Command-line entry point for text tokenization.

    Prints ``get_commandline_args()`` to stderr, parses the arguments with
    ``get_parser()`` and forwards every parsed option to ``tokenize``.

    Args:
        cmd: Optional list of argument strings. When None, argparse falls back
            to ``sys.argv``.
    """
    print(get_commandline_args(), file=sys.stderr)
    namespace = get_parser().parse_args(cmd)
    tokenize(**vars(namespace))
function

main_hydra(kwargs)

funasr.bin.train · View on GitHub ↗

Main hydra.

Args:

  • kwargs — Additional keyword arguments.
📄 Source code
def main_hydra(kwargs: DictConfig):
    """Hydra entry point for training: resolve the model (downloading it if needed), then call ``main``.

    Args:
        kwargs: Hydra configuration. It must contain ``model``. If ``model_conf`` is
            missing, the model is fetched through ``download_model`` from the hub named
            by ``hub`` (default ``"ms"``), and the kwargs it returns are used instead.
            ``debug=True`` drops into ``pdb`` before anything else runs.
    """
    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()

    assert "model" in kwargs
    if "model_conf" not in kwargs:
        logging.info("download models from model hub: %s", kwargs.get("hub", "ms"))
        # Merge ``is_training`` into a single mapping instead of passing it next to
        # ``**kwargs``: when the config already holds ``is_training`` the latter raises
        # "TypeError: got multiple values for keyword argument 'is_training'".
        download_kwargs = dict(kwargs)
        download_kwargs["is_training"] = kwargs.get("is_training", True)
        kwargs = download_model(**download_kwargs)

    main(**kwargs)
function

main(**kwargs)

funasr.bin.train · View on GitHub ↗

Main.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
def main(**kwargs):

    # set random seed
    """Main.
    
        Args:
            **kwargs: Additional keyword arguments.
        """
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
    # open tf32
    torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if local_rank == 0:
        tables.print()
    # Check if we are using DDP or FSDP
    use_ddp = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
    use_fsdp = kwargs.get("use_fsdp", False)
    # use_ddp = False if use_fsdp else use_fsdp
    if use_ddp or use_fsdp:
        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://")
        torch.cuda.set_device(local_rank)

    logging.info("Build model, frontend, tokenizer")
    device = kwargs.get("device", "cuda")
    kwargs["device"] = "cpu"
    model = AutoModel(**kwargs)
View full source on GitHub →
function

main_hydra(kwargs)

funasr.bin.train_ds · View on GitHub ↗

Main hydra.

Args:

  • kwargs — Additional keyword arguments.
📄 Source code
def main_hydra(kwargs: DictConfig):
    """Hydra entry point for (DeepSpeed-capable) training: resolve the model, then call ``main``.

    Args:
        kwargs: Hydra configuration. It must contain ``model``. If ``model_conf`` is
            missing, the model is fetched through ``download_model`` from the hub named
            by ``hub`` (default ``"ms"``), and the kwargs it returns are used instead.
            ``debug=True`` drops into ``pdb`` before anything else runs.
    """
    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()

    assert "model" in kwargs
    if "model_conf" not in kwargs:
        logging.info("download models from model hub: %s", kwargs.get("hub", "ms"))
        # Merge ``is_training`` into a single mapping instead of passing it next to
        # ``**kwargs``: when the config already holds ``is_training`` the latter raises
        # "TypeError: got multiple values for keyword argument 'is_training'".
        download_kwargs = dict(kwargs)
        download_kwargs["is_training"] = kwargs.get("is_training", True)
        kwargs = download_model(**download_kwargs)

    main(**kwargs)
function

main(**kwargs)

funasr.bin.train_ds · View on GitHub ↗

Main.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
def main(**kwargs):

    # set random seed
    """Main.
    
        Args:
            **kwargs: Additional keyword arguments.
        """
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
    # open tf32
    torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)

    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    if local_rank == 0:
        tables.print()

    use_ddp = world_size > 1
    use_fsdp = kwargs.get("use_fsdp", False)
    use_deepspeed = kwargs.get("use_deepspeed", False)
    if use_deepspeed:
        logging.info(f"use_deepspeed: {use_deepspeed}")
        deepspeed.init_distributed(dist_backend=kwargs.get("backend", "nccl"))
    elif use_ddp or use_fsdp:
        logging.info(f"use_ddp: {use_ddp}, use_fsdp: {use_fsdp}")
View full source on GitHub →
function

download_dataset()

funasr.download.download_dataset_from_hub · View on GitHub ↗

Download dataset.

📄 Source code
def download_dataset():
    """Placeholder for a generic dataset download; not implemented, returns None."""
    return None
function

download_dataset_from_ms(**kwargs)

funasr.download.download_dataset_from_hub · View on GitHub ↗

Download dataset from ms.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
def download_dataset_from_ms(**kwargs):
    """Load a dataset through ModelScope's ``MsDataset.load`` and return it.

    Args:
        **kwargs: Optional overrides:
            dataset_name: ModelScope dataset id
                (default ``"speech_asr/speech_asr_aishell1_trainsets"``).
            subset_name: Dataset subset (default ``"default"``).
            split: Split to load (default ``"train"``).
            data_dump_dir: Forwarded to ``MsDataset.load`` as ``cache_dir``
                (default None).

    Returns:
        Whatever ``MsDataset.load`` returns for the requested split.
    """
    from modelscope.msdatasets import MsDataset

    dataset_name = kwargs.get("dataset_name", "speech_asr/speech_asr_aishell1_trainsets")
    subset_name = kwargs.get("subset_name", "default")
    split = kwargs.get("split", "train")
    data_dump_dir = kwargs.get("data_dump_dir", None)
    ds = MsDataset.load(
        dataset_name=dataset_name, subset_name=subset_name, split=split, cache_dir=data_dump_dir
    )
    # The loaded dataset used to be discarded; hand it back so callers can use it.
    return ds
function

download_model(**kwargs)

funasr.download.download_model_from_hub · View on GitHub ↗

Download model from hub and parse its configuration.

Resolves model name aliases, downloads from ModelScope or HuggingFace,

reads config.yaml and configuration.json, and returns complete kwargs

for model instantiation.

Args:

  • **kwargs — Must include 'model' (str). Optional: 'hub', 'model_revision', 'is_training', etc.

Returns:

  • dict — Complete kwargs with resolved paths, model class name, and config.
📄 Source code
def download_model(**kwargs):

    """Download model from hub and parse its configuration.

    Resolves model name aliases, downloads from ModelScope or HuggingFace,
    reads config.yaml and configuration.json, and returns complete kwargs
    for model instantiation.

    Args:
        **kwargs: Must include 'model' (str). Optional: 'hub', 'model_revision',
            'is_training', etc.

    Returns:
        dict: Complete kwargs with resolved paths, model class name, and config.
    """
    hub = kwargs.get("hub", "ms")
    if hub == "ms" or hub == "modelscope":
        kwargs = download_from_ms(**kwargs)
    elif hub == "hf" or hub == "huggingface":
        kwargs = download_from_hf(**kwargs)
    elif hub == "openai":
        model_or_path = kwargs.get("model")
        if os.path.exists(model_or_path):
            # local path
            kwargs["model_path"] = model_or_path
            kwargs["model"] = "WhisperWarp"
        else:
            # model name
            if model_or_path in name_maps_openai:
                model_or_path = name_maps_openai[model_or_path]
View full source on GitHub →
function

download_from_ms(**kwargs)

funasr.download.download_model_from_hub · View on GitHub ↗

Download from ms.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
def download_from_ms(**kwargs):
    """Download from ms.
    
        Args:
            **kwargs: Additional keyword arguments.
        """
    model_or_path = kwargs.get("model")
    if model_or_path in name_maps_ms:
        model_or_path = name_maps_ms[model_or_path]
    model_revision = kwargs.get("model_revision", "master")
    if not os.path.exists(model_or_path) and "model_path" not in kwargs:
        try:
            model_or_path = get_or_download_model_dir(
                model_or_path,
                model_revision,
                is_training=kwargs.get("is_training"),
                check_latest=kwargs.get("check_latest", True),
            )
        except Exception as e:
            print(f"Download: {model_or_path} failed!: {e}")

    kwargs["model_path"] = model_or_path if "model_path" not in kwargs else kwargs["model_path"]
    model_or_path = kwargs["model_path"]
    if os.path.exists(os.path.join(model_or_path, "configuration.json")):
        with open(os.path.join(model_or_path, "configuration.json"), "r", encoding="utf-8") as f:
            conf_json = json.load(f)

            cfg = {}
            if "file_path_metas" in conf_json:
                add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
View full source on GitHub →
function

download_from_hf(**kwargs)

funasr.download.download_model_from_hub · View on GitHub ↗

Download from hf.

Args:

  • **kwargs — Additional keyword arguments.
📄 Source code
def download_from_hf(**kwargs):
    """Download from hf.
    
        Args:
            **kwargs: Additional keyword arguments.
        """
    model_or_path = kwargs.get("model")
    if model_or_path in name_maps_hf:
        model_or_path = name_maps_hf[model_or_path]
    model_revision = kwargs.get("model_revision", "master")
    if not os.path.exists(model_or_path) and "model_path" not in kwargs:
        try:
            model_or_path = get_or_download_model_dir_hf(
                model_or_path,
                model_revision,
                is_training=kwargs.get("is_training"),
                check_latest=kwargs.get("check_latest", True),
            )
        except Exception as e:
            print(f"Download: {model_or_path} failed!: {e}")

    kwargs["model_path"] = model_or_path if "model_path" not in kwargs else kwargs["model_path"]

    if os.path.exists(os.path.join(model_or_path, "configuration.json")):
        with open(os.path.join(model_or_path, "configuration.json"), "r", encoding="utf-8") as f:
            conf_json = json.load(f)

            cfg = {}
            if "file_path_metas" in conf_json:
                add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
View full source on GitHub →
function

add_file_root_path(model_or_path, file_path_metas, cfg)

funasr.download.download_model_from_hub · View on GitHub ↗

Add file root path.

Args:

  • model_or_path — TODO.
  • file_path_metas — TODO.
  • cfg — Configuration overrides.
📄 Source code
def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg={}):

    """Add file root path.
    
        Args:
            model_or_path: TODO.
            file_path_metas: TODO.
            cfg: Configuration overrides.
        """
    if isinstance(file_path_metas, dict):
        if isinstance(cfg, list):
            cfg.append({})

        for k, v in file_path_metas.items():
            if isinstance(v, str):
                p = os.path.join(model_or_path, v)
                if os.path.exists(p):
                    if isinstance(cfg, dict):
                        cfg[k] = p
                    elif isinstance(cfg, list):
                        # if len(cfg) == 0:
                        # cfg.append({})
                        cfg[-1][k] = p

            elif isinstance(v, dict):
                if isinstance(cfg, dict):
                    if k not in cfg:
                        cfg[k] = {}
                    add_file_root_path(model_or_path, v, cfg[k])
                # elif isinstance(cfg, list):
View full source on GitHub →
function

get_or_download_model_dir(model, model_revision, is_training, check_latest)

funasr.download.download_model_from_hub · View on GitHub ↗

Get local model directory or download model if necessary.

Args:

  • model (str) — model id or path to local model directory.
  • model_revision (str, optional) — model version number.

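  • is_training (bool, optional) — Selects the ModelScope invocation key sent in the user agent (local trainer if True, pipeline otherwise).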

📄 Source code
def get_or_download_model_dir(
    model,
    model_revision=None,
    is_training=False,
    check_latest=True,
):
    """Get local model directory or download model if necessary.

    Args:
        model (str): model id or path to local model directory.
        model_revision  (str, optional): model version number.
        :param is_training:
    """
    from modelscope.hub.check_model import check_local_model_is_latest
    from modelscope.hub.snapshot_download import snapshot_download

    from modelscope.utils.constant import Invoke, ThirdParty

    key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE

    if os.path.exists(model) and check_latest:
        model_cache_dir = model if os.path.isdir(model) else os.path.dirname(model)
        try:
            check_local_model_is_latest(
                model_cache_dir, user_agent={Invoke.KEY: key, ThirdParty.KEY: "funasr"}
            )
        except:
            print("could not check the latest version")
    else:
        model_cache_dir = snapshot_download(
View full source on GitHub →
function

get_or_download_model_dir_hf(model, model_revision, is_training, check_latest)

funasr.download.download_model_from_hub · View on GitHub ↗

Get local model directory or download model if necessary.

Args:

  • model (str) — model id or path to local model directory.
  • model_revision (str, optional) — model version number.

  • is_training (bool, optional) — Currently unused by the Hugging Face download path.

📄 Source code
def get_or_download_model_dir_hf(
    model,
    model_revision=None,
    is_training=False,
    check_latest=True,
):
    """Download a model snapshot from the Hugging Face Hub and return its local directory.

    Args:
        model (str): Repository id passed to ``huggingface_hub.snapshot_download``.
        model_revision (str, optional): Accepted for parity with the ModelScope
            variant; currently ignored, so the repository's default revision is used.
        is_training (bool, optional): Currently ignored.
        check_latest (bool, optional): Currently ignored.

    Returns:
        str: Local directory of the downloaded snapshot.
    """
    from huggingface_hub import snapshot_download

    return snapshot_download(model)
function

download_from_url(url)

funasr.download.file · View on GitHub ↗

Download from url.

Args:

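  • url — URL of the file to download; it must include a scheme such as ``https://``, otherwise the download fails.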
📄 Source code
def download_from_url(url):
    """Download ``url`` into a fresh temporary directory and return the local file path.

    Args:
        url: URL of the file to download. It must carry a scheme (e.g. ``https://``);
            the bytes are fetched with ``HTTPStorage``.

    Returns:
        str: Path of the downloaded file. The file name is the last component of the
        URL *path* (query string and fragment are not part of it), or
        ``"downloaded_file"`` when the URL path has no final component.

    Raises:
        AssertionError: If ``url`` has no scheme, so nothing was downloaded.
    """
    result = urlparse(url)
    file_path = None
    if result.scheme:
        data = HTTPStorage().read(url)  # raw bytes of the response body
        # mkdtemp creates a unique directory that persists until removed.
        # TemporaryDirectory().name deleted its directory as soon as the unreferenced
        # object was garbage-collected, so only a reusable name was left behind.
        work_dir = tempfile.mkdtemp()
        # Derive the name from the path component so "?token=..." does not end up in
        # the file name and extension; fall back when the URL ends with "/".
        file_name = os.path.basename(result.path) or "downloaded_file"
        file_path = os.path.join(work_dir, file_name)
        with open(file_path, "wb") as fb:
            fb.write(data)
    assert file_path is not None, f"failed to download: {url}"
    return file_path
class

Storage

funasr.download.file · View on GitHub ↗

Abstract class of storage.

All backends need to implement two apis: ``read()`` and ``read_text()``.

``read()`` reads the file as a byte stream and ``read_text()`` reads

the file as texts.

📄 Source code
class Storage(metaclass=ABCMeta):
    """Abstract class of storage.

    All backends need to implement two apis: ``read()`` and ``read_text()``.
    ``read()`` reads the file as a byte stream and ``read_text()`` reads
    the file as texts.
    """

    @abstractmethod
    def read(self, filepath: str):
        """Read.
        
            Args:
                filepath: TODO.
            """
        pass

    @abstractmethod
    def read_text(self, filepath: str):
        """Read text.
        
            Args:
                filepath: TODO.
            """
        pass

    @abstractmethod
    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write.
        
View full source on GitHub →

Methods

.read(filepath) L45

Read.

Args:

  • filepath — Location of the file to read.
📄 Source
    def read(self, filepath: str):
        """Read the whole file at ``filepath`` as a byte stream.

        Interface hook with no implementation here; concrete storage backends
        supply the real behavior.

        Args:
            filepath: Location of the file to read.
        """
        ...
.read_text(filepath) L54

Read text.

Args:

  • filepath — Location of the file to read.
📄 Source
    def read_text(self, filepath: str):
        """Read the whole file at ``filepath`` as text.

        Interface hook with no implementation here; concrete storage backends
        supply the real behavior.

        Args:
            filepath: Location of the file to read.
        """
        ...
.write(obj, filepath) L63

Write.

Args:

  • obj — Bytes to write.
  • filepath — Destination path.
📄 Source
    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Store the bytes ``obj`` at ``filepath``.

        Interface hook with no implementation here; concrete storage backends
        supply the real behavior.

        Args:
            obj: Bytes to write.
            filepath: Destination path.
        """
        ...
.write_text(obj, filepath, encoding) L73

Write text.

Args:

  • obj — Text to write.
  • filepath — Destination path.
  • encoding — Text encoding to use (default: 'utf-8').
📄 Source
    def write_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
        """Store the string ``obj`` at ``filepath`` using ``encoding``.

        Interface hook with no implementation here; concrete storage backends
        supply the real behavior.

        Args:
            obj: Text to write.
            filepath: Destination path.
            encoding: Text encoding to use (default ``"utf-8"``).
        """
        ...
method

Storage.read(filepath)

funasr.download.file.Storage · View on GitHub ↗

Read.

Args:

  • filepath — Location of the file to read.
📄 Source code
    def read(self, filepath: str):
        """Read the whole file at ``filepath`` as a byte stream.

        Interface hook with no implementation here; concrete storage backends
        supply the real behavior.

        Args:
            filepath: Location of the file to read.
        """
        ...
method

Storage.read_text(filepath)

funasr.download.file.Storage · View on GitHub ↗

Read text.

Args:

  • filepath — Location of the file to read.
📄 Source code
    def read_text(self, filepath: str):
        """Read the whole file at ``filepath`` as text.

        Interface hook with no implementation here; concrete storage backends
        supply the real behavior.

        Args:
            filepath: Location of the file to read.
        """
        ...
method

Storage.write(obj, filepath)

funasr.download.file.Storage · View on GitHub ↗

Write.

Args:

  • obj — Bytes to write.
  • filepath — Destination path.
📄 Source code
    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Store the bytes ``obj`` at ``filepath``.

        Interface hook with no implementation here; concrete storage backends
        supply the real behavior.

        Args:
            obj: Bytes to write.
            filepath: Destination path.
        """
        ...
method

Storage.write_text(obj, filepath, encoding)

funasr.download.file.Storage · View on GitHub ↗

Write text.

Args:

  • obj — Text to write.
  • filepath — Destination path.
  • encoding — Text encoding to use (default: 'utf-8').
📄 Source code
    def write_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
        """Store the string ``obj`` at ``filepath`` using ``encoding``.

        Interface hook with no implementation here; concrete storage backends
        supply the real behavior.

        Args:
            obj: Text to write.
            filepath: Destination path.
            encoding: Text encoding to use (default ``"utf-8"``).
        """
        ...
class

LocalStorage

funasr.download.file · View on GitHub ↗

Local hard disk storage

📄 Source code
class LocalStorage(Storage):
    """Local hard disk storage"""

    def read(self, filepath: Union[str, Path]) -> bytes:
        """Read data from a given ``filepath`` with 'rb' mode.

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            bytes: Expected bytes object.
        """
        with open(filepath, "rb") as f:
            content = f.read()
        return content

    def read_text(self, filepath: Union[str, Path], encoding: str = "utf-8") -> str:
        """Read data from a given ``filepath`` with 'r' mode.

        Args:
            filepath (str or Path): Path to read data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.

        Returns:
            str: Expected text reading from ``filepath``.
        """
        with open(filepath, "r", encoding=encoding) as f:
            value_buf = f.read()
        return value_buf
View full source on GitHub →

Methods

.read(filepath) L87

Read data from a given ``filepath`` with 'rb' mode.

Args:

  • filepath (str or Path) — Path to read data.

Returns:

  • bytes — Expected bytes object.
📄 Source
    def read(self, filepath: Union[str, Path]) -> bytes:
        """Return the raw contents of ``filepath``, opened in ``'rb'`` mode.

        Args:
            filepath (str or Path): Path of the file to read.

        Returns:
            bytes: The complete file contents.
        """
        with open(filepath, "rb") as stream:
            return stream.read()
.read_text(filepath, encoding) L100

Read data from a given ``filepath`` with 'r' mode.

Args:

  • filepath (str or Path) — Path to read data.
  • encoding (str) — The encoding format used to open the ``filepath``. Default: 'utf-8'.

Returns:

  • str — Expected text reading from ``filepath``.
📄 Source
    def read_text(self, filepath: Union[str, Path], encoding: str = "utf-8") -> str:
        """Return the contents of ``filepath`` decoded as text (``'r'`` mode).

        Args:
            filepath (str or Path): Path of the file to read.
            encoding (str): Codec used to decode the file. Default: ``'utf-8'``.

        Returns:
            str: The complete decoded file contents.
        """
        with open(filepath, "r", encoding=encoding) as stream:
            return stream.read()
.write(obj, filepath) L115

Write data to a given ``filepath`` with 'wb' mode.

Note:

``write`` will create a directory if the directory of ``filepath``

does not exist.

Args:

  • obj (bytes) — Data to be written.
  • filepath (str or Path) — Path to write data.
📄 Source
    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write the bytes ``obj`` to ``filepath`` in ``'wb'`` mode.

        Any missing parent directories of ``filepath`` are created first.

        Args:
            obj (bytes): Data to write.
            filepath (str or Path): Destination path; an existing file is overwritten.
        """
        parent_dir = os.path.dirname(filepath)
        missing_parent = bool(parent_dir) and not os.path.exists(parent_dir)
        if missing_parent:
            os.makedirs(parent_dir, exist_ok=True)

        with open(filepath, mode="wb") as stream:
            stream.write(obj)
.write_text(obj, filepath, encoding) L133

Write data to a given ``filepath`` with 'w' mode.

Note:

``write_text`` will create a directory if the directory of

``filepath`` does not exist.

Args:

  • obj (str) — Data to be written.
  • filepath (str or Path) — Path to write data.
  • encoding (str) — The encoding format used to open the ``filepath``. Default: 'utf-8'.
📄 Source
    def write_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
        """Write the string ``obj`` to ``filepath`` in ``'w'`` mode.

        Any missing parent directories of ``filepath`` are created first.

        Args:
            obj (str): Text to write.
            filepath (str or Path): Destination path; an existing file is overwritten.
            encoding (str): Codec used to encode the text. Default: ``'utf-8'``.
        """
        parent_dir = os.path.dirname(filepath)
        missing_parent = bool(parent_dir) and not os.path.exists(parent_dir)
        if missing_parent:
            os.makedirs(parent_dir, exist_ok=True)

        with open(filepath, mode="w", encoding=encoding) as stream:
            stream.write(obj)
.as_local_path(filepath) L154

Only for unified API and do nothing.

📄 Source
    def as_local_path(self, filepath: Union[str, Path]) -> Generator[Union[str, Path], None, None]:
        """Yield ``filepath`` unchanged; a local file needs no download step.

        Exists only so local storage exposes the same API as the remote backends.
        """
        local_path = filepath
        yield local_path
method

LocalStorage.read(filepath)

funasr.download.file.LocalStorage · View on GitHub ↗

Read data from a given ``filepath`` with 'rb' mode.

Args:

  • filepath (str or Path) — Path to read data.

Returns:

  • bytes — Expected bytes object.
📄 Source code
    def read(self, filepath: Union[str, Path]) -> bytes:
        """Return the raw contents of ``filepath``, opened in ``'rb'`` mode.

        Args:
            filepath (str or Path): Path of the file to read.

        Returns:
            bytes: The complete file contents.
        """
        with open(filepath, "rb") as stream:
            return stream.read()
method

LocalStorage.read_text(filepath, encoding)

funasr.download.file.LocalStorage · View on GitHub ↗

Read data from a given ``filepath`` with 'r' mode.

Args:

  • filepath (str or Path) — Path to read data.
  • encoding (str) — The encoding format used to open the ``filepath``. Default: 'utf-8'.

Returns:

  • str — Expected text reading from ``filepath``.
📄 Source code
    def read_text(self, filepath: Union[str, Path], encoding: str = "utf-8") -> str:
        """Return the contents of ``filepath`` decoded as text (``'r'`` mode).

        Args:
            filepath (str or Path): Path of the file to read.
            encoding (str): Codec used to decode the file. Default: ``'utf-8'``.

        Returns:
            str: The complete decoded file contents.
        """
        with open(filepath, "r", encoding=encoding) as stream:
            return stream.read()
method

LocalStorage.write(obj, filepath)

funasr.download.file.LocalStorage · View on GitHub ↗

Write data to a given ``filepath`` with 'wb' mode.

Note:

``write`` will create a directory if the directory of ``filepath``

does not exist.

Args:

  • obj (bytes) — Data to be written.
  • filepath (str or Path) — Path to write data.
📄 Source code
    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write the bytes ``obj`` to ``filepath`` in ``'wb'`` mode.

        Any missing parent directories of ``filepath`` are created first.

        Args:
            obj (bytes): Data to write.
            filepath (str or Path): Destination path; an existing file is overwritten.
        """
        parent_dir = os.path.dirname(filepath)
        missing_parent = bool(parent_dir) and not os.path.exists(parent_dir)
        if missing_parent:
            os.makedirs(parent_dir, exist_ok=True)

        with open(filepath, mode="wb") as stream:
            stream.write(obj)
method

LocalStorage.write_text(obj, filepath, encoding)

funasr.download.file.LocalStorage · View on GitHub ↗

Write data to a given ``filepath`` with 'w' mode.

Note:

``write_text`` will create a directory if the directory of

``filepath`` does not exist.

Args:

  • obj (str) — Data to be written.
  • filepath (str or Path) — Path to write data.
  • encoding (str) — The encoding format used to open the ``filepath``. Default: 'utf-8'.
📄 Source code
    def write_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
        """Write the string ``obj`` to ``filepath`` in ``'w'`` mode.

        Any missing parent directories of ``filepath`` are created first.

        Args:
            obj (str): Text to write.
            filepath (str or Path): Destination path; an existing file is overwritten.
            encoding (str): Codec used to encode the text. Default: ``'utf-8'``.
        """
        parent_dir = os.path.dirname(filepath)
        missing_parent = bool(parent_dir) and not os.path.exists(parent_dir)
        if missing_parent:
            os.makedirs(parent_dir, exist_ok=True)

        with open(filepath, mode="w", encoding=encoding) as stream:
            stream.write(obj)
method

LocalStorage.as_local_path(filepath)

funasr.download.file.LocalStorage · View on GitHub ↗

Only for unified API and do nothing.

📄 Source code
    def as_local_path(self, filepath: Union[str, Path]) -> Generator[Union[str, Path], None, None]:
        """Yield ``filepath`` unchanged; a local file needs no download step.

        Exists only so local storage exposes the same API as the remote backends.
        """
        local_path = filepath
        yield local_path
class

HTTPStorage

funasr.download.file · View on GitHub ↗

HTTP and HTTPS storage.

📄 Source code
class HTTPStorage(Storage):
    """HTTP and HTTPS storage."""

    def read(self, url):
        """Download ``url`` with an HTTP GET and return the body as bytes.

        Args:
            url: HTTP or HTTPS URL to fetch.

        Returns:
            bytes: The raw response content.

        Raises:
            requests.HTTPError: Raised by ``raise_for_status`` for an error
                status code.
        """
        # TODO @wenmeng.zwm add progress bar if file is too large
        response = requests.get(url)
        response.raise_for_status()
        return response.content

    def read_text(self, url, encoding=None):
        """Download ``url`` with an HTTP GET and return the body decoded as text.

        Args:
            url: HTTP or HTTPS URL to fetch.
            encoding (str, optional): Encoding used to decode the body. When
                ``None`` (the default), the encoding chosen by ``requests`` for
                the response is used. Accepting this argument keeps the
                signature in line with the other storages'
                ``read_text(filepath, encoding)``.

        Returns:
            str: The decoded response body.

        Raises:
            requests.HTTPError: Raised by ``raise_for_status`` for an error
                status code.
        """
        response = requests.get(url)
        response.raise_for_status()
        if encoding is not None:
            # Setting ``encoding`` makes ``response.text`` decode with it.
            response.encoding = encoding
        return response.text

    @contextlib.contextmanager
    def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download a file from ``filepath``.

        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
        can be called with ``with`` statement, and when exists from the
View full source on GitHub →

Methods

.read(url) L162

Read.

Args:

  • url — HTTP or HTTPS URL to fetch.
📄 Source
    def read(self, url):
        """Download ``url`` with an HTTP GET and return the body as bytes.

        Args:
            url: HTTP or HTTPS URL to fetch.

        Returns:
            bytes: The raw response content.

        Raises:
            requests.HTTPError: Raised by ``raise_for_status`` for an error
                status code.
        """
        # TODO @wenmeng.zwm add progress bar if file is too large
        response = requests.get(url)
        response.raise_for_status()
        return response.content
.read_text(url) L173

Read text.

Args:

  • url — HTTP or HTTPS URL to fetch.
📄 Source
    def read_text(self, url, encoding=None):
        """Download ``url`` with an HTTP GET and return the body decoded as text.

        Args:
            url: HTTP or HTTPS URL to fetch.
            encoding (str, optional): Encoding used to decode the body. When
                ``None`` (the default), the encoding chosen by ``requests`` for
                the response is used. Accepting this argument keeps the
                signature in line with the other storages'
                ``read_text(filepath, encoding)``.

        Returns:
            str: The decoded response body.

        Raises:
            requests.HTTPError: Raised by ``raise_for_status`` for an error
                status code.
        """
        response = requests.get(url)
        response.raise_for_status()
        if encoding is not None:
            # Setting ``encoding`` makes ``response.text`` decode with it.
            response.encoding = encoding
        return response.text
.as_local_path(filepath) L184

Download a file from ``filepath``.

``as_local_path`` is decorated by :meth:`contextlib.contextmanager`, so it can be used in a ``with`` statement; when the ``with`` statement exits, the temporary file is removed.

Args:

  • filepath (str) — Download a file from ``filepath``.

Examples:

>>> storage = HTTPStorage()
>>> # After exiting the ``with`` clause,
>>> # the path will be removed
>>> with storage.as_local_path('http://path/to/file') as path:
...     pass  # do something with ``path`` here

📄 Source
    def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download ``filepath`` to a temporary local file and yield its path.

        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`, so it
        can be used in a ``with`` statement. The temporary file is removed when
        the ``with`` statement exits, whether normally or with an exception.

        Args:
            filepath (str): URL to download; its bytes are fetched with ``self.read``.

        Yields:
            str: Path of the temporary local copy.

        Examples:
            >>> storage = HTTPStorage()
            >>> # After exiting the ``with`` clause,
            >>> # the path will be removed
            >>> with storage.as_local_path('http://path/to/file') as path:
            ...     pass  # do something with ``path`` here
        """
        # Create the temp file before ``try``: if creation fails there is nothing
        # to clean up, and ``finally`` never refers to an unbound name.
        tmp = tempfile.NamedTemporaryFile(delete=False)
        try:
            try:
                tmp.write(self.read(filepath))
            finally:
                # Close even when ``self.read`` raises so the handle is not leaked
                # (open files typically cannot be removed on Windows).
                tmp.close()
            yield tmp.name
        finally:
            os.remove(tmp.name)
.write(obj, url) L209

Write.

Args:

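  • obj — Data that would be written (unused; HTTP storage is read-only).
  • url — Target URL (unused; the method always raises ``NotImplementedError``).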
📄 Source
    def write(self, obj: bytes, url: Union[str, Path]) -> None:
        """Reject binary writes: HTTP storage is read-only.

        Args:
            obj (bytes): Ignored.
            url (str or Path): Ignored.

        Raises:
            NotImplementedError: Always.
        """
        message = "write is not supported by HTTP Storage"
        raise NotImplementedError(message)
.write_text(obj, url, encoding) L218

Write text.

Args:

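  • obj — Text that would be written (unused; HTTP storage is read-only).
  • url — Target URL (unused).
  • encoding — Text encoding (unused; the method always raises ``NotImplementedError``).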
📄 Source
    def write_text(self, obj: str, url: Union[str, Path], encoding: str = "utf-8") -> None:
        """Reject text writes: HTTP storage is read-only.

        Args:
            obj (str): Ignored.
            url (str or Path): Ignored.
            encoding (str): Ignored.

        Raises:
            NotImplementedError: Always.
        """
        message = "write_text is not supported by HTTP Storage"
        raise NotImplementedError(message)
method

HTTPStorage.read(url)

funasr.download.file.HTTPStorage · View on GitHub ↗

Read.

Args:

  • url — HTTP or HTTPS URL to fetch.
📄 Source code
    def read(self, url):
        """Download ``url`` with an HTTP GET and return the body as bytes.

        Args:
            url: HTTP or HTTPS URL to fetch.

        Returns:
            bytes: The raw response content.

        Raises:
            requests.HTTPError: Raised by ``raise_for_status`` for an error
                status code.
        """
        # TODO @wenmeng.zwm add progress bar if file is too large
        response = requests.get(url)
        response.raise_for_status()
        return response.content
method

HTTPStorage.read_text(url)

funasr.download.file.HTTPStorage · View on GitHub ↗

Read text.

Args:

  • url — HTTP or HTTPS URL to fetch.
📄 Source code
    def read_text(self, url, encoding=None):
        """Download ``url`` with an HTTP GET and return the body decoded as text.

        Args:
            url: HTTP or HTTPS URL to fetch.
            encoding (str, optional): Encoding used to decode the body. When
                ``None`` (the default), the encoding chosen by ``requests`` for
                the response is used. Accepting this argument keeps the
                signature in line with the other storages'
                ``read_text(filepath, encoding)``.

        Returns:
            str: The decoded response body.

        Raises:
            requests.HTTPError: Raised by ``raise_for_status`` for an error
                status code.
        """
        response = requests.get(url)
        response.raise_for_status()
        if encoding is not None:
            # Setting ``encoding`` makes ``response.text`` decode with it.
            response.encoding = encoding
        return response.text
method

HTTPStorage.as_local_path(filepath)

funasr.download.file.HTTPStorage · View on GitHub ↗

Download a file from ``filepath``.

``as_local_path`` is decorated by :meth:`contextlib.contextmanager`, so it can be used in a ``with`` statement; when the ``with`` statement exits, the temporary file is removed.

Args:

  • filepath (str) — Download a file from ``filepath``.

Examples:

>>> storage = HTTPStorage()
>>> # After exiting the ``with`` clause,
>>> # the path will be removed
>>> with storage.as_local_path('http://path/to/file') as path:
...     pass  # do something with ``path`` here

📄 Source code
    def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download ``filepath`` to a temporary local file and yield its path.

        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`, so it
        can be used in a ``with`` statement. The temporary file is removed when
        the ``with`` statement exits, whether normally or with an exception.

        Args:
            filepath (str): URL to download; its bytes are fetched with ``self.read``.

        Yields:
            str: Path of the temporary local copy.

        Examples:
            >>> storage = HTTPStorage()
            >>> # After exiting the ``with`` clause,
            >>> # the path will be removed
            >>> with storage.as_local_path('http://path/to/file') as path:
            ...     pass  # do something with ``path`` here
        """
        # Create the temp file before ``try``: if creation fails there is nothing
        # to clean up, and ``finally`` never refers to an unbound name.
        tmp = tempfile.NamedTemporaryFile(delete=False)
        try:
            try:
                tmp.write(self.read(filepath))
            finally:
                # Close even when ``self.read`` raises so the handle is not leaked
                # (open files typically cannot be removed on Windows).
                tmp.close()
            yield tmp.name
        finally:
            os.remove(tmp.name)
method

HTTPStorage.write(obj, url)

funasr.download.file.HTTPStorage · View on GitHub ↗

Write.

Args:

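  • obj — Data that would be written (unused; HTTP storage is read-only).
  • url — Target URL (unused; the method always raises ``NotImplementedError``).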
📄 Source code
    def write(self, obj: bytes, url: Union[str, Path]) -> None:
        """Reject binary writes: HTTP storage is read-only.

        Args:
            obj (bytes): Ignored.
            url (str or Path): Ignored.

        Raises:
            NotImplementedError: Always.
        """
        message = "write is not supported by HTTP Storage"
        raise NotImplementedError(message)
method

HTTPStorage.write_text(obj, url, encoding)

funasr.download.file.HTTPStorage · View on GitHub ↗

Write text.

Args:

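  • obj — Text that would be written (unused; HTTP storage is read-only).
  • url — Target URL (unused).
  • encoding — Text encoding (unused; the method always raises ``NotImplementedError``).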
📄 Source code
    def write_text(self, obj: str, url: Union[str, Path], encoding: str = "utf-8") -> None:
        """Reject text writes: HTTP storage is read-only.

        Args:
            obj (str): Ignored.
            url (str or Path): Ignored.
            encoding (str): Ignored.

        Raises:
            NotImplementedError: Always.
        """
        message = "write_text is not supported by HTTP Storage"
        raise NotImplementedError(message)
class

OSSStorage

funasr.download.file · View on GitHub ↗

OSS storage.

📄 Source code
class OSSStorage(Storage):
    """OSS storage."""

    def __init__(self, oss_config_file=None):
        """Construct an OSS storage backend (not implemented yet).

        Args:
            oss_config_file: Intended source of OSS settings; the original note
                says settings would be read from a config file or an environment
                variable. Currently unused.

        Raises:
            NotImplementedError: Always.
        """
        # read from config file or env var
        reason = "OSSStorage.__init__ to be implemented in the future"
        raise NotImplementedError(reason)

    def read(self, filepath):
        """Read raw bytes from OSS (not implemented yet).

        Args:
            filepath: Location of the OSS object.

        Raises:
            NotImplementedError: Always.
        """
        reason = "OSSStorage.read to be implemented in the future"
        raise NotImplementedError(reason)

    def read_text(self, filepath, encoding="utf-8"):
        """Read decoded text from OSS (not implemented yet).

        Args:
            filepath: Location of the OSS object.
            encoding: Text encoding that would be used to decode the data.

        Raises:
            NotImplementedError: Always.
        """
        reason = "OSSStorage.read_text to be implemented in the future"
        raise NotImplementedError(reason)

    @contextlib.contextmanager
View full source on GitHub →

Methods

.read(filepath) L241

Read.

Args:

  • filepath — Location of the OSS object to read (reading from OSS is not implemented yet).
📄 Source
    def read(self, filepath):
        """Read raw bytes from OSS (not implemented yet).

        Args:
            filepath: Location of the OSS object.

        Raises:
            NotImplementedError: Always.
        """
        reason = "OSSStorage.read to be implemented in the future"
        raise NotImplementedError(reason)
.read_text(filepath, encoding) L249

Read text.

Args:

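  • filepath — Location of the OSS object to read (reading from OSS is not implemented yet).
  • encoding — Text encoding that would be used to decode the data. Default: 'utf-8'.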
📄 Source
    def read_text(self, filepath, encoding="utf-8"):
        """Read decoded text from OSS (not implemented yet).

        Args:
            filepath: Location of the OSS object.
            encoding: Text encoding that would be used to decode the data.

        Raises:
            NotImplementedError: Always.
        """
        reason = "OSSStorage.read_text to be implemented in the future"
        raise NotImplementedError(reason)
.as_local_path(filepath) L259

Download a file from ``filepath``.

``as_local_path`` is decorated by :meth:`contextlib.contextmanager`, so it can be used in a ``with`` statement; when the ``with`` statement exits, the temporary file is removed.

Args:

  • filepath (str) — Download a file from ``filepath``.

Examples:

>>> storage = OSSStorage()
>>> # After exiting the ``with`` clause,
>>> # the path will be removed
>>> with storage.as_local_path('oss://path/to/file') as path:
...     pass  # do something with ``path`` here

📄 Source
    def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Copy the object at ``filepath`` to a temporary local file and yield its path.

        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`, so it
        can be used in a ``with`` statement. The temporary file is removed when
        the ``with`` statement exits, whether normally or with an exception.

        Args:
            filepath (str): Object location; its bytes are fetched with ``self.read``.

        Yields:
            str: Path of the temporary local copy.

        Examples:
            >>> storage = OSSStorage()
            >>> # After exiting the ``with`` clause,
            >>> # the path will be removed
            >>> with storage.as_local_path('oss://path/to/file') as path:
            ...     pass  # do something with ``path`` here
        """
        # Create the temp file before ``try``: if creation fails there is nothing
        # to clean up, and ``finally`` never refers to an unbound name.
        tmp = tempfile.NamedTemporaryFile(delete=False)
        try:
            try:
                tmp.write(self.read(filepath))
            finally:
                # Close even when ``self.read`` raises so the handle is not leaked
                # (open files typically cannot be removed on Windows).
                tmp.close()
            yield tmp.name
        finally:
            os.remove(tmp.name)
.write(obj, filepath) L284

Write.

Args:

  • obj — Data to write (writing to OSS is not implemented yet).
  • filepath — Location of the target OSS object.
📄 Source
    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write raw bytes to OSS (not implemented yet).

        Args:
            obj (bytes): Data that would be written.
            filepath (str or Path): Location of the target OSS object.

        Raises:
            NotImplementedError: Always.
        """
        reason = "OSSStorage.write to be implemented in the future"
        raise NotImplementedError(reason)
.write_text(obj, filepath, encoding) L293

Write text.

Args:

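  • obj — Text to write (writing to OSS is not implemented yet).
  • filepath — Location of the target OSS object.
  • encoding — Text encoding that would be used. Default: 'utf-8'.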
📄 Source
    def write_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
        """Write text to OSS (not implemented yet).

        Args:
            obj (str): Text that would be written.
            filepath (str or Path): Location of the target OSS object.
            encoding (str): Text encoding that would be used.

        Raises:
            NotImplementedError: Always.
        """
        reason = "OSSStorage.write_text to be implemented in the future"
        raise NotImplementedError(reason)
method

OSSStorage.read(filepath)

funasr.download.file.OSSStorage · View on GitHub ↗

Read.

Args:

  • filepath — Location of the OSS object to read (reading from OSS is not implemented yet).
📄 Source code
    def read(self, filepath):
        """Read raw bytes from OSS (not implemented yet).

        Args:
            filepath: Location of the OSS object.

        Raises:
            NotImplementedError: Always.
        """
        reason = "OSSStorage.read to be implemented in the future"
        raise NotImplementedError(reason)
method

OSSStorage.read_text(filepath, encoding)

funasr.download.file.OSSStorage · View on GitHub ↗

Read text.

Args:

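  • filepath — Location of the OSS object to read (reading from OSS is not implemented yet).
  • encoding — Text encoding that would be used to decode the data. Default: 'utf-8'.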
📄 Source code
    def read_text(self, filepath, encoding="utf-8"):
        """Read decoded text from OSS (not implemented yet).

        Args:
            filepath: Location of the OSS object.
            encoding: Text encoding that would be used to decode the data.

        Raises:
            NotImplementedError: Always.
        """
        reason = "OSSStorage.read_text to be implemented in the future"
        raise NotImplementedError(reason)
method

OSSStorage.as_local_path(filepath)

funasr.download.file.OSSStorage · View on GitHub ↗

Download a file from ``filepath``.

``as_local_path`` is decorated by :meth:`contextlib.contextmanager`, so it can be used in a ``with`` statement; when the ``with`` statement exits, the temporary file is removed.

Args:

  • filepath (str) — Download a file from ``filepath``.

Examples:

>>> storage = OSSStorage()
>>> # After exiting the ``with`` clause,
>>> # the path will be removed
>>> with storage.as_local_path('oss://path/to/file') as path:
...     pass  # do something with ``path`` here

📄 Source code
    def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Copy the object at ``filepath`` to a temporary local file and yield its path.

        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`, so it
        can be used in a ``with`` statement. The temporary file is removed when
        the ``with`` statement exits, whether normally or with an exception.

        Args:
            filepath (str): Object location; its bytes are fetched with ``self.read``.

        Yields:
            str: Path of the temporary local copy.

        Examples:
            >>> storage = OSSStorage()
            >>> # After exiting the ``with`` clause,
            >>> # the path will be removed
            >>> with storage.as_local_path('oss://path/to/file') as path:
            ...     pass  # do something with ``path`` here
        """
        # Create the temp file before ``try``: if creation fails there is nothing
        # to clean up, and ``finally`` never refers to an unbound name.
        tmp = tempfile.NamedTemporaryFile(delete=False)
        try:
            try:
                tmp.write(self.read(filepath))
            finally:
                # Close even when ``self.read`` raises so the handle is not leaked
                # (open files typically cannot be removed on Windows).
                tmp.close()
            yield tmp.name
        finally:
            os.remove(tmp.name)
method

OSSStorage.write(obj, filepath)

funasr.download.file.OSSStorage · View on GitHub ↗

Write.

Args:

  • obj — Data to write (writing to OSS is not implemented yet).
  • filepath — Location of the target OSS object.
📄 Source code
    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write raw bytes to OSS (not implemented yet).

        Args:
            obj (bytes): Data that would be written.
            filepath (str or Path): Location of the target OSS object.

        Raises:
            NotImplementedError: Always.
        """
        reason = "OSSStorage.write to be implemented in the future"
        raise NotImplementedError(reason)
method

OSSStorage.write_text(obj, filepath, encoding)

funasr.download.file.OSSStorage · View on GitHub ↗

Write text.

Args:

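  • obj — Text to write (writing to OSS is not implemented yet).
  • filepath — Location of the target OSS object.
  • encoding — Text encoding that would be used. Default: 'utf-8'.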
📄 Source code
    def write_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
        """Write text to OSS (not implemented yet).

        Args:
            obj (str): Text that would be written.
            filepath (str or Path): Location of the target OSS object.
            encoding (str): Text encoding that would be used.

        Raises:
            NotImplementedError: Always.
        """
        reason = "OSSStorage.write_text to be implemented in the future"
        raise NotImplementedError(reason)
class

File

funasr.download.file · View on GitHub ↗

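Facade that selects a storage backend from the URI scheme (``oss``, ``http`` or ``https``; a URI without ``://`` is treated as a local path) and forwards read/write calls to it.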

📄 Source code
class File(object):
    _prefix_to_storage: dict = {
        "oss": OSSStorage,
        "http": HTTPStorage,
        "https": HTTPStorage,
        "local": LocalStorage,
    }

    @staticmethod
    def _get_storage(uri):
        """Internal: get storage.
        
            Args:
                uri: TODO.
            """
        assert isinstance(uri, str), f"uri should be str type, but got {type(uri)}"

        if "://" not in uri:
            # local path
            storage_type = "local"
        else:
            prefix, _ = uri.split("://")
            storage_type = prefix

        assert storage_type in File._prefix_to_storage, (
            f"Unsupported uri {uri}, valid prefixs: " f"{list(File._prefix_to_storage.keys())}"
        )

        if storage_type not in G_STORAGES:
            G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()
View full source on GitHub →

Methods

.read(uri) L341

Read data from a given ``filepath`` with 'rb' mode.

Args:

  • uri (str) — Local path or URI to read data from.

Returns:

  • bytes — Expected bytes object.
📄 Source
    def read(uri: str) -> bytes:
        """Return the raw bytes stored at ``uri``.

        ``File._get_storage`` picks the backend from the URI scheme, and that
        backend's ``read`` method does the actual work.

        Args:
            uri (str): Local path or URI to read.

        Returns:
            bytes: Content of the resource.
        """
        return File._get_storage(uri).read(uri)
.read_text(uri, encoding) L354

Read data from a given ``filepath`` with 'r' mode.

Args:

  • uri (str) — Local path or URI to read data from.
  • encoding (str) — The encoding format used to read the data. Default: 'utf-8'.

Returns:

  • str — Expected text reading from ``filepath``.
📄 Source
    def read_text(uri: Union[str, Path], encoding: str = "utf-8") -> str:
        """Read the resource at ``uri`` and return it as text.

        Args:
            uri (str): Local path or URI to read. ``File._get_storage`` asserts
                that it is a ``str``.
            encoding (str): Encoding used to decode the data. Default: 'utf-8'.

        Returns:
            str: Text read from ``uri``.

        Raises:
            TypeError: May be raised when a non-default ``encoding`` is requested
                from a backend whose ``read_text`` takes no ``encoding`` argument.
        """
        storage = File._get_storage(uri)
        if encoding == "utf-8":
            # Keep the path-only call for the default so backends whose
            # ``read_text`` accepts only the path (HTTPStorage) keep working.
            return storage.read_text(uri)
        # Honour an explicitly requested encoding instead of dropping it.
        return storage.read_text(uri, encoding=encoding)
.write(obj, uri) L369

Write data to a given ``filepath`` with 'wb' mode.

Note:

``write`` will create a directory if the directory of ``filepath``

does not exist.

Args:

  • obj (bytes) — Data to be written.
  • uri (str) — Local path or URI to write data to.
📄 Source
    def write(obj: bytes, uri: Union[str, Path]) -> None:
        """Write ``obj`` to ``uri`` in binary mode through the matching backend.

        ``File._get_storage`` picks the backend from the URI scheme; the local
        backend creates a missing parent directory before writing.

        Args:
            obj (bytes): Data to write.
            uri: Destination path or URI (must be a ``str``;
                ``File._get_storage`` asserts this).
        """
        backend = File._get_storage(uri)
        return backend.write(obj, uri)
.write_text(obj, uri, encoding) L384

Write data to a given ``filepath`` with 'w' mode.

Note:

``write_text`` will create a directory if the directory of

``filepath`` does not exist.

Args:

  • obj (str) — Data to be written.
  • uri (str) — Local path or URI to write data to.
  • encoding (str) — The encoding format used to write the data. Default: 'utf-8'.
📄 Source
    def write_text(obj: str, uri: str, encoding: str = "utf-8") -> None:
        """Write the string ``obj`` to ``uri`` in text mode.

        ``File._get_storage`` picks the backend from the URI scheme; the local
        backend creates a missing parent directory before writing.

        Args:
            obj (str): Text to write.
            uri (str): Destination path or URI.
            encoding (str): Encoding used to write the text. Default: 'utf-8'.
        """
        storage = File._get_storage(uri)
        # Every backend's ``write_text`` accepts ``encoding``, so forward it
        # rather than silently falling back to the backend's default.
        return storage.write_text(obj, uri, encoding=encoding)
.as_local_path(uri) L401

Yields a local path for ``uri`` by delegating to the matching storage backend's ``as_local_path`` context manager.

📄 Source
    def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
        """Yield a local filesystem path for ``uri``.

        Delegates to the matching backend's ``as_local_path`` context manager and
        yields whatever path it produces; the backend's cleanup runs when this
        generator resumes after the yield.
        """
        backend = File._get_storage(uri)
        with backend.as_local_path(uri) as path:
            yield path
method

File.read(uri)

funasr.download.file.File · View on GitHub ↗

Read data from a given ``filepath`` with 'rb' mode.

Args:

  • uri (str) — Local path or URI to read data from.

Returns:

  • bytes — Expected bytes object.
📄 Source code
    def read(uri: str) -> bytes:
        """Return the raw bytes stored at ``uri``.

        ``File._get_storage`` picks the backend from the URI scheme, and that
        backend's ``read`` method does the actual work.

        Args:
            uri (str): Local path or URI to read.

        Returns:
            bytes: Content of the resource.
        """
        return File._get_storage(uri).read(uri)
method

File.read_text(uri, encoding)

funasr.download.file.File · View on GitHub ↗

Read data from a given ``filepath`` with 'r' mode.

Args:

  • uri (str) — Local path or URI to read data from.
  • encoding (str) — The encoding format used to read the data. Default: 'utf-8'.

Returns:

  • str — Expected text reading from ``filepath``.
📄 Source code
    def read_text(uri: Union[str, Path], encoding: str = "utf-8") -> str:
        """Read the resource at ``uri`` and return it as text.

        Args:
            uri (str): Local path or URI to read. ``File._get_storage`` asserts
                that it is a ``str``.
            encoding (str): Encoding used to decode the data. Default: 'utf-8'.

        Returns:
            str: Text read from ``uri``.

        Raises:
            TypeError: May be raised when a non-default ``encoding`` is requested
                from a backend whose ``read_text`` takes no ``encoding`` argument.
        """
        storage = File._get_storage(uri)
        if encoding == "utf-8":
            # Keep the path-only call for the default so backends whose
            # ``read_text`` accepts only the path (HTTPStorage) keep working.
            return storage.read_text(uri)
        # Honour an explicitly requested encoding instead of dropping it.
        return storage.read_text(uri, encoding=encoding)
method

File.write(obj, uri)

funasr.download.file.File · View on GitHub ↗

Write data to a given ``filepath`` with 'wb' mode.

Note:

``write`` will create a directory if the directory of ``filepath``

does not exist.

Args:

  • obj (bytes) — Data to be written.
  • uri (str) — Local path or URI to write data to.
📄 Source code
    def write(obj: bytes, uri: Union[str, Path]) -> None:
        """Write ``obj`` to ``uri`` in binary mode through the matching backend.

        ``File._get_storage`` picks the backend from the URI scheme; the local
        backend creates a missing parent directory before writing.

        Args:
            obj (bytes): Data to write.
            uri: Destination path or URI (must be a ``str``;
                ``File._get_storage`` asserts this).
        """
        backend = File._get_storage(uri)
        return backend.write(obj, uri)
method

File.write_text(obj, uri, encoding)

funasr.download.file.File · View on GitHub ↗

Write data to a given ``filepath`` with 'w' mode.

Note:

``write_text`` will create a directory if the directory of

``filepath`` does not exist.

Args:

  • obj (str) — Data to be written.
  • uri (str) — Local path or URI to write data to.
  • encoding (str) — The encoding format used to write the data. Default: 'utf-8'.
📄 Source code
    def write_text(obj: str, uri: str, encoding: str = "utf-8") -> None:
        """Write the string ``obj`` to ``uri`` in text mode.

        ``File._get_storage`` picks the backend from the URI scheme; the local
        backend creates a missing parent directory before writing.

        Args:
            obj (str): Text to write.
            uri (str): Destination path or URI.
            encoding (str): Encoding used to write the text. Default: 'utf-8'.
        """
        storage = File._get_storage(uri)
        # Every backend's ``write_text`` accepts ``encoding``, so forward it
        # rather than silently falling back to the backend's default.
        return storage.write_text(obj, uri, encoding=encoding)
method

File.as_local_path(uri)

funasr.download.file.File · View on GitHub ↗

Yields a local path for ``uri`` by delegating to the matching storage backend's ``as_local_path`` context manager.

📄 Source code
    def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
        """Yield a local filesystem path for ``uri``.

        Delegates to the matching backend's ``as_local_path`` context manager and
        yields whatever path it produces; the backend's cleanup runs when this
        generator resumes after the yield.
        """
        backend = File._get_storage(uri)
        with backend.as_local_path(uri) as path:
            yield path
function

main()

funasr.download.runtime_sdk_download_tool · View on GitHub ↗

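Command-line entry point: parses the arguments, downloads ``--model-name`` from ModelScope when it is not an existing local path, and exports the model when ``--export`` is true.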

📄 Source code
def main():
    """Main."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, required=True)
    parser.add_argument("--export-dir", type=str, required=True)
    parser.add_argument("--export", type=str2bool, default=True, help="whether to export model")
    parser.add_argument("--type", type=str, default="onnx", help='["onnx", "torchscript", "bladedisc"]')
    parser.add_argument("--device", type=str, default="cpu", help='["cpu", "cuda"]')
    parser.add_argument("--quantize", type=str2bool, default=False, help="export quantized model")
    parser.add_argument("--fallback-num", type=int, default=0, help="amp fallback number")
    parser.add_argument("--audio_in", type=str, default=None, help='["wav", "wav.scp"]')
    parser.add_argument("--model_revision", type=str, default=None, help="model_revision")
    parser.add_argument("--calib_num", type=int, default=200, help="calib max num")
    args = parser.parse_args()

    model_dir = args.model_name
    output_dir = args.model_name
    if not Path(args.model_name).exists():
        from modelscope.hub.snapshot_download import snapshot_download

        try:
            model_dir = snapshot_download(
                args.model_name, cache_dir=args.export_dir, revision=args.model_revision
            )
            output_dir = os.path.join(args.export_dir, args.model_name)
        except:
            raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
                model_dir
            )
    if args.export:
View full source on GitHub →
function

add_gradient_noise(model, iteration, duration, eta, scale_factor)

funasr.train_utils.add_gradient_noise · View on GitHub ↗

Adds noise from a standard normal distribution to the gradients.

The standard deviation (`sigma`) is controlled

by the three hyper-parameters below.

`sigma` goes to zero (no noise) with more iterations.

Args:

  • model — Model.
  • iteration — Number of iterations.
  • duration — {100, 1000}: Number of iterations per step of the `sigma` schedule.
  • eta — {0.01, 0.3, 1.0}: The magnitude of `sigma`.
  • scale_factor — {0.55}: The scale of `sigma`.
📄 Source code
def add_gradient_noise(
    model: torch.nn.Module,
    iteration: int,
    duration: float = 100,
    eta: float = 1.0,
    scale_factor: float = 0.55,
):
    """Add Gaussian noise to every existing gradient of ``model`` in place.

    The standard deviation of the noise is
    ``sigma = eta / ((iteration // duration) + 1) ** scale_factor``,
    so it goes to zero (no noise) as ``iteration`` grows.

    Args:
        model: Model whose ``param.grad`` tensors are perturbed. Parameters whose
            ``grad`` is ``None`` are left untouched.
        iteration: Current iteration count.
        duration: Number of iterations per step of the ``sigma`` schedule
            (typical values: 100, 1000).
        eta: Magnitude of ``sigma`` (typical values: 0.01, 0.3, 1.0).
        scale_factor: Decay exponent of ``sigma`` (typical value: 0.55).
    """
    interval = (iteration // duration) + 1
    sigma = eta / interval**scale_factor
    for param in model.parameters():
        grad = param.grad
        if grad is None:
            continue
        # Sample directly on the gradient's device and dtype instead of sampling
        # on the CPU and copying to the device for every parameter.
        grad += sigma * torch.randn_like(grad)
function

average_checkpoints(output_dir, last_n, **kwargs)

funasr.train_utils.average_nbest_models · View on GitHub ↗

Average the last 'last_n' checkpoints' model state_dicts.

If a tensor is of type torch.int, perform sum instead of average.

📄 Source code
def average_checkpoints(output_dir: str, last_n: int = 5, **kwargs):
    """
    Average the last 'last_n' checkpoints' model state_dicts.
    If a tensor is of type torch.int, perform sum instead of average.
    """
    checkpoint_paths = _get_checkpoint_paths(output_dir, last_n, **kwargs)
    print(f"average_checkpoints: {checkpoint_paths}")
    state_dicts = []

    # Load state_dicts from checkpoints
    for path in checkpoint_paths:
        if os.path.isfile(path):
            state_dicts.append(torch.load(path, map_location="cpu")["state_dict"])
        else:
            print(f"Checkpoint file {path} not found.")

    # Check if we have any state_dicts to average
    if len(state_dicts) < 1:
        print("No checkpoints found for averaging.")
        return

    # Average or sum weights
    avg_state_dict = OrderedDict()
    for key in state_dicts[0].keys():
        tensors = [state_dict[key].cpu() for state_dict in state_dicts]
        # Check the type of the tensor
        if str(tensors[0].dtype).startswith("torch.int"):
            # Perform sum for integer tensors
            summed_tensor = sum(tensors)
            avg_state_dict[key] = summed_tensor
View full source on GitHub →
function

to_device(data, device, dtype, non_blocking, copy)

funasr.train_utils.device_funcs · View on GitHub ↗

Change the device of object recursively

📄 Source code
def to_device(data, device=None, dtype=None, non_blocking=False, copy=False):
    """Change the device (and optionally dtype) of tensors inside ``data`` recursively.

    Dicts, dataclass instances, namedtuples, lists and tuples are traversed and
    rebuilt with their original type. NumPy arrays are converted with
    ``torch.from_numpy`` before being moved. Any other object is returned
    unchanged.

    Args:
        data: Object to convert.
        device: Target device, passed to ``Tensor.to``.
        dtype: Target dtype, passed to ``Tensor.to``.
        non_blocking: Forwarded to ``Tensor.to``.
        copy: Forwarded to ``Tensor.to``.

    Returns:
        An object with the same structure whose tensors have been converted.
    """
    if isinstance(data, dict):
        return {k: to_device(v, device, dtype, non_blocking, copy) for k, v in data.items()}
    elif dataclasses.is_dataclass(data) and not isinstance(data, type):
        # Walk the fields directly: ``dataclasses.astuple`` would recurse into
        # nested dataclasses (turning them into plain tuples) and deep-copy values.
        # Keyword construction also supports keyword-only fields; fields with
        # init=False are left to the dataclass's own defaults.
        return type(data)(
            **{
                field.name: to_device(getattr(data, field.name), device, dtype, non_blocking, copy)
                for field in dataclasses.fields(data)
                if field.init
            }
        )
    # maybe namedtuple. I don't know the correct way to judge namedtuple.
    elif isinstance(data, tuple) and type(data) is not tuple:
        return type(data)(*[to_device(o, device, dtype, non_blocking, copy) for o in data])
    elif isinstance(data, (list, tuple)):
        return type(data)(to_device(v, device, dtype, non_blocking, copy) for v in data)
    elif isinstance(data, np.ndarray):
        return to_device(torch.from_numpy(data), device, dtype, non_blocking, copy)
    elif isinstance(data, torch.Tensor):
        return data.to(device, dtype, non_blocking, copy)
    else:
        return data
function

force_gatherable(data, device)

funasr.train_utils.device_funcs · View on GitHub ↗

Change object to gatherable in torch.nn.DataParallel recursively

The difference from to_device() is changing to torch.Tensor if float or int

value is found.

The restriction to the returned value in DataParallel:

The object must be

  • torch.cuda.Tensor
  • 1 or more dimension. 0-dimension-tensor sends warning.

or a list, tuple, dict.

📄 Source code
def force_gatherable(data, device):
    """Change object to gatherable in torch.nn.DataParallel recursively

    The difference from to_device() is changing to torch.Tensor if float or int
    value is found.

    The restriction to the returned value in DataParallel:
        The object must be
        - torch.cuda.Tensor
        - 1 or more dimension. 0-dimension-tensor sends warning.
        or a list, tuple, dict.

    """
    if isinstance(data, dict):
        return {k: force_gatherable(v, device) for k, v in data.items()}
    # DataParallel can't handle NamedTuple well
    elif isinstance(data, tuple) and type(data) is not tuple:
        return type(data)(*[force_gatherable(o, device) for o in data])
    elif isinstance(data, (list, tuple, set)):
        return type(data)(force_gatherable(v, device) for v in data)
    elif isinstance(data, np.ndarray):
        return force_gatherable(torch.from_numpy(data), device)
    elif isinstance(data, torch.Tensor):
        if data.dim() == 0:
            # To 1-dim array
            data = data[None]
        return data.to(device)
    elif isinstance(data, float):
        return torch.tensor([data], dtype=torch.float, device=device)
    elif isinstance(data, int):
View full source on GitHub →
class

ForwardAdaptor

funasr.train_utils.forward_adaptor · View on GitHub ↗

Wrapped module to parallelize specified method

torch.nn.DataParallel parallelizes only "forward()", so a method with any other name cannot be parallelized unless the module is wrapped as this class does.

Examples:

>>> class A(torch.nn.Module):
...     def foo(self, x):
...         ...
>>> model = A()
>>> model = ForwardAdaptor(model, "foo")
>>> model = torch.nn.DataParallel(model, device_ids=[0, 1])
>>> x = torch.randn(2, 10)
>>> model(x)

📄 Source code
class ForwardAdaptor(torch.nn.Module):
    """Wrapped module to parallelize specified method

    torch.nn.DataParallel parallelizes only "forward()"
    and, maybe, the method having the other name can't be applied
    except for wrapping the module just like this class.

    Examples:
        >>> class A(torch.nn.Module):
        ...     def foo(self, x):
        ...         ...
        >>> model = A()
        >>> model = ForwardAdaptor(model, "foo")
        >>> model = torch.nn.DataParallel(model, device_ids=[0, 1])
        >>> x = torch.randn(2, 10)
        >>> model(x)
    """

    def __init__(self, module: torch.nn.Module, name: str):
        """Initialize ForwardAdaptor.
        
            Args:
                module: TODO.
                name: TODO.
            """
        super().__init__()
        self.module = module
        self.name = name
        if not hasattr(module, name):
            raise ValueError(f"{module} doesn't have {name}")
View full source on GitHub →

Methods

.forward(*args, **kwargs) L35

Calls the wrapped module's method named ``self.name`` with the given arguments and returns its result.

Args:

  • *args — Variable positional arguments.
  • **kwargs — Additional keyword arguments.
📄 Source
    def forward(self, *args, **kwargs):
        """Dispatch the call to ``self.module``'s method named ``self.name``.

        Args:
            *args: Positional arguments forwarded unchanged.
            **kwargs: Keyword arguments forwarded unchanged.

        Returns:
            Whatever the wrapped method returns.
        """
        target = getattr(self.module, self.name)
        result = target(*args, **kwargs)
        return result
method

ForwardAdaptor.forward(*args, **kwargs)

funasr.train_utils.forward_adaptor.ForwardAdaptor · View on GitHub ↗

Calls the wrapped module's method named ``self.name`` with the given arguments and returns its result.

Args:

  • *args — Variable positional arguments.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def forward(self, *args, **kwargs):
        """Dispatch the call to ``self.module``'s method named ``self.name``.

        Args:
            *args: Positional arguments forwarded unchanged.
            **kwargs: Keyword arguments forwarded unchanged.

        Returns:
            Whatever the wrapped method returns.
        """
        target = getattr(self.module, self.name)
        result = target(*args, **kwargs)
        return result
function

initialize(model, init)

funasr.train_utils.initialize · View on GitHub ↗

Initialize weights of a neural network module.

Parameters are initialized using the given method or distribution.

Custom initialization routines can be implemented into submodules

as function `espnet_initialization_fn` within the custom module.

Args:

  • model — Target.
  • init — Method of initialization.
📄 Source code
def initialize(model: torch.nn.Module, init: str):
    """Initialize weights of a neural network module.

    Parameters are initialized using the given method or distribution.

    Custom initialization routines can be implemented into submodules
    as function `espnet_initialization_fn` within the custom module.

    Args:
        model: Target.
        init: Method of initialization.
    """

    # weight init
    for p in model.parameters():
        if p.dim() > 1:
            if init == "xavier_uniform":
                torch.nn.init.xavier_uniform_(p.data)
            elif init == "xavier_normal":
                torch.nn.init.xavier_normal_(p.data)
            elif init == "kaiming_uniform":
                torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
            elif init == "kaiming_normal":
                torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
            else:
                raise ValueError("Unknown initialization: " + init)
    # bias init
    for p in model.parameters():
        if p.dim() == 1:
            p.data.zero_()
View full source on GitHub →
function

load_pretrained_model(path, model, ignore_init_mismatch, map_location, oss_bucket, scope_map, excludes, **kwargs)

funasr.train_utils.load_pretrained_model · View on GitHub ↗

Load a model state and set it to the model.

Args:

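  • path — Checkpoint to load: a local file path, or an object key read through `oss_bucket.get_object(path)` when `oss_bucket` is given.
  • model — Module whose `state_dict()` receives the loaded weights.
  • map_location — Forwarded to `torch.load`.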

Examples:

📄 Source code
def load_pretrained_model(
    path: str,
    model: torch.nn.Module,
    ignore_init_mismatch: bool = True,
    map_location: str = "cpu",
    oss_bucket=None,
    scope_map=[],
    excludes=None,
    **kwargs,
):
    """Load a model state and set it to the model.

    Args:
            init_param: <file_path>:<src_key>:<dst_key>:<exclude_Keys>

    Examples:

    """

    obj = model
    dst_state = obj.state_dict()

    logging.info(f"ckpt: {path}")

    if oss_bucket is None:
        ori_state = torch.load(path, map_location=map_location)
    else:
        buffer = BytesIO(oss_bucket.get_object(path).read())
        ori_state = torch.load(buffer, map_location=map_location)

View full source on GitHub →
function

get_human_readable_count(number)

funasr.train_utils.model_summary · View on GitHub ↗

Return human_readable_count

Originated from:

https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/memory.py

Abbreviates an integer number with K, M, B, T for thousands, millions,

billions and trillions, respectively.

Examples:

>>> get_human_readable_count(123)

'123 '

>>> get_human_readable_count(1234) # (one thousand)

'1 K'

>>> get_human_readable_count(2e6) # (two million)

'2 M'

>>> get_human_readable_count(3e9) # (three billion)

'3 B'

>>> get_human_readable_count(4e12) # (four trillion)

'4 T'

>>> get_human_readable_count(5e15) # (more than trillion)

'5,000 T'

Args:

  • number — a positive integer number

Return:

A string formatted according to the pattern described above.

📄 Source code
def get_human_readable_count(number: int) -> str:
    """Return human_readable_count

    Originated from:
    https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/memory.py

    Abbreviates an integer number with K, M, B, T for thousands, millions,
    billions and trillions, respectively.
    Examples:
        >>> get_human_readable_count(123)
        '123  '
        >>> get_human_readable_count(1234)  # (one thousand)
        '1 K'
        >>> get_human_readable_count(2e6)   # (two million)
        '2 M'
        >>> get_human_readable_count(3e9)   # (three billion)
        '3 B'
        >>> get_human_readable_count(4e12)  # (four trillion)
        '4 T'
        >>> get_human_readable_count(5e15)  # (more than trillion)
        '5,000 T'
    Args:
        number: a positive integer number
    Return:
        A string formatted according to the pattern described above.
    """
    assert number >= 0
    labels = [" ", "K", "M", "B", "T"]
    num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
    num_groups = int(np.ceil(num_digits / 3))
View full source on GitHub →
function

to_bytes(dtype)

funasr.train_utils.model_summary · View on GitHub ↗

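Return the number of bytes per element of `dtype` (e.g. `torch.float16` → 2).

Args:

  • dtype — Data type to measure, typically a `torch.dtype`.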
📄 Source code
def to_bytes(dtype) -> int:
    """Return the size in bytes of a single element of ``dtype``.

    ``dtype`` is typically a ``torch.dtype`` (e.g. ``torch.float16`` -> 2).
    Any object exposing an integer ``itemsize`` (such as a ``numpy.dtype``)
    is also accepted, as is anything whose string form ends in its bit width
    (e.g. ``"float32"`` -> 4).

    Args:
        dtype: The data type to measure.

    Returns:
        Number of bytes per element.

    Raises:
        ValueError: If the element size cannot be determined from ``dtype``.
    """
    import re

    # Prefer the authoritative element size when the dtype exposes one
    # (numpy dtypes, recent torch dtypes). The isinstance check skips
    # class-level descriptors such as ``numpy.float32.itemsize``.
    itemsize = getattr(dtype, "itemsize", None)
    if isinstance(itemsize, int):
        return itemsize

    # Fall back to the trailing bit width in the name, e.g. "torch.float16" -> 16.
    # All trailing digits are taken, not just the last two characters, so
    # 1-digit widths ("int8") and 3-digit widths ("complex128") parse correctly.
    match = re.search(r"(\d+)$", str(dtype))
    if match is None:
        raise ValueError(f"Cannot determine the byte size of dtype {dtype!r}")
    return int(match.group(1)) // 8
function

model_summary(model)

funasr.train_utils.model_summary · View on GitHub ↗

Model summary.

Args:

  • model — The `torch.nn.Module` instance to summarize.
📄 Source code
def model_summary(model: torch.nn.Module) -> str:
    """Model summary.
    
        Args:
            model: Model instance or model name.
        """
    message = "Model structure:\n"
    message += str(model)

    tot_params, num_params = 0, 0
    for name, param in model.named_parameters():
        print(
            "name: {}, dtype: {}, device: {}, trainable: {}, shape: {}, numel: {}".format(
                name, param.dtype, param.device, param.requires_grad, param.shape, param.numel()
            )
        )
        tot_params += param.numel()
        if param.requires_grad:
            num_params += param.numel()

    percent_trainable = "{:.1f}".format(num_params * 100.0 / tot_params)
    tot_params = get_human_readable_count(tot_params)
    num_params = get_human_readable_count(num_params)
    message += "\n\nModel summary:\n"
    message += f"    Class Name: {model.__class__.__name__}\n"
    message += f"    Total Number of model parameters: {tot_params}\n"
    message += f"    Number of trainable parameters: {num_params} ({percent_trainable}%)\n"

    dtype = next(iter(model.parameters())).dtype
    message += f"    Type: {dtype}"
View full source on GitHub →
function

recursive_sum(obj, weight, distributed)

funasr.train_utils.recursive_op · View on GitHub ↗

Recursive sum.

Args:

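  • obj — A tensor, `None`, or an arbitrarily nested list/tuple/dict of those; every tensor must have the same size as `weight`.
  • weight — 1-D tensor of weights, cast to each tensor's dtype before multiplying.
  • distributed — If true, each weighted sum is all-reduced (SUM) across processes.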
📄 Source code
def recursive_sum(obj, weight: torch.Tensor, distributed: bool = False):
    """Weighted sum over every tensor in a (possibly nested) container.

    Each tensor leaf ``t`` becomes the scalar ``(t * weight).sum()``, with
    ``weight`` first cast to ``t``'s dtype. Lists/tuples keep their type,
    dicts keep their keys, and ``None`` leaves pass through unchanged.

    Args:
        obj: A tensor, ``None``, or a list/tuple/dict nesting of those.
            Every tensor must have exactly the same size as ``weight``.
        weight: 1-D tensor of weights.
        distributed: If True, each scalar sum is all-reduced (SUM) in place
            across processes.

    Returns:
        A structure mirroring ``obj`` with each tensor replaced by its
        weighted-sum scalar.

    Raises:
        ValueError: If a leaf is not a tensor, a list/tuple/dict, or ``None``.
    """
    assert weight.dim() == 1, weight.size()
    if obj is None:
        return None
    if isinstance(obj, dict):
        return {key: recursive_sum(val, weight, distributed) for key, val in obj.items()}
    if isinstance(obj, (tuple, list)):
        return type(obj)(recursive_sum(item, weight, distributed) for item in obj)
    if not isinstance(obj, torch.Tensor):
        raise ValueError(type(obj))

    assert obj.size() == weight.size(), (obj.size(), weight.size())
    total = (obj * weight.type(obj.dtype)).sum()
    if distributed:
        torch.distributed.all_reduce(total, op=ReduceOp.SUM)
    return total
function

recursive_divide(a, b)

funasr.train_utils.recursive_op · View on GitHub ↗

Recursive divide.

Args:

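  • a — A tensor, `None`, or an arbitrarily nested list/tuple/dict of those; every tensor must have the same size as `b`.
  • b — Divisor tensor, cast to each tensor's dtype before dividing.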
📄 Source code
def recursive_divide(a, b: torch.Tensor):
    """Divide every tensor in a (possibly nested) container by ``b``.

    Lists/tuples keep their type, dicts keep their keys, and ``None`` leaves
    pass through unchanged. ``b`` is cast to each leaf's dtype first.

    Args:
        a: A tensor, ``None``, or a list/tuple/dict nesting of those.
            Every tensor must have exactly the same size as ``b``.
        b: Divisor tensor.

    Returns:
        A structure mirroring ``a`` with each tensor divided by ``b``.

    Raises:
        ValueError: If a leaf is not a tensor, a list/tuple/dict, or ``None``.
    """
    if a is None:
        return None
    if isinstance(a, dict):
        return {key: recursive_divide(val, b) for key, val in a.items()}
    if isinstance(a, (tuple, list)):
        return type(a)(recursive_divide(item, b) for item in a)
    if not isinstance(a, torch.Tensor):
        raise ValueError(type(a))

    assert a.size() == b.size(), (a.size(), b.size())
    return a / b.type(a.dtype)
function

recursive_average(obj, weight, distributed)

funasr.train_utils.recursive_op · View on GitHub ↗

Recursive average.

Args:

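  • obj — A tensor, `None`, or an arbitrarily nested list/tuple/dict of those; every tensor must have the same size as `weight`.
  • weight — 1-D tensor of weights.
  • distributed — If true, both the weighted sums and the total weight are all-reduced (SUM) across processes.

Returns the tuple `(weighted_average, total_weight)`.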
📄 Source code
def recursive_average(obj, weight: torch.Tensor, distributed: bool = False):
    """Weighted average of every tensor in a (possibly nested) container.

    Each tensor leaf ``t`` becomes ``(t * weight).sum() / weight.sum()``.
    When ``distributed`` is set, both numerator and denominator are summed
    across processes, which gives a global weighted mean.

    Args:
        obj: A tensor, ``None``, or a list/tuple/dict nesting of those.
            Every tensor must have the same size as ``weight``.
        weight: 1-D tensor of weights.
        distributed: If True, all-reduce the sums and the total weight.

    Returns:
        Tuple ``(averaged, total_weight)``. ``averaged`` mirrors ``obj``'s
        structure, and ``total_weight`` is the (possibly all-reduced) sum
        of ``weight``.
    """
    weighted_sums = recursive_sum(obj, weight, distributed)
    total_weight = weight.sum()
    if distributed:
        torch.distributed.all_reduce(total_weight, op=ReduceOp.SUM)
    # Dividing by the total weight is equivalent to averaging with weights
    # normalized to sum to 1.
    return recursive_divide(weighted_sums, total_weight), total_weight
function

set_all_random_seed(seed)

funasr.train_utils.set_all_random_seed · View on GitHub ↗

Set all random seed.

Args:

  • seed — Seed passed to Python's `random`, NumPy's global RNG, and `torch.random.manual_seed`.
📄 Source code
def set_all_random_seed(seed: int):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed unchanged to ``random.seed``, ``np.random.seed``
            and ``torch.random.manual_seed``, in that order.
    """
    for seed_fn in (random.seed, np.random.seed, torch.random.manual_seed):
        seed_fn(seed)
function

maybe_autocast(enabled)

funasr.train_utils.trainer · View on GitHub ↗

Maybe autocast.

Args:

  • enabled — If true, the body runs inside `autocast()`; otherwise no autocast context is applied.
📄 Source code
def maybe_autocast(enabled):
    """Yield once, optionally inside an ``autocast()`` context.

    NOTE(review): this is a generator meant to back a context manager; no
    ``@contextmanager`` decorator is visible in this listing. Confirm that the
    real module applies one.

    Args:
        enabled: When truthy, the ``yield`` happens inside ``autocast()``;
            otherwise it happens with no autocast context.
    """
    if not enabled:
        yield
        return
    with autocast():
        yield
class

Trainer

funasr.train_utils.trainer · View on GitHub ↗

A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,

and optionally resuming from a saved checkpoint.

Attributes:

  • max_epoch (int) — Maximum number of epochs for training.
  • model (torch.nn.Module) — The model to be trained.
  • optim (torch.optim.Optimizer) — The optimizer to use for training.
  • scheduler (torch.optim.lr_scheduler._LRScheduler) — The learning rate scheduler.
  • dataloader_train (torch.utils.data.DataLoader) — DataLoader for the training dataset.
  • dataloader_val (torch.utils.data.DataLoader) — DataLoader for the validation dataset.
  • output_dir (str) — Directory where model checkpoints will be saved.
  • resume (str, optional) — Path to a checkpoint to resume training from.
📄 Source code
class Trainer:
    """
    A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
    and optionally resuming from a saved checkpoint.

    Attributes:
        max_epoch (int): Maximum number of epochs for training.
        model (torch.nn.Module): The model to be trained.
        optim (torch.optim.Optimizer): The optimizer to use for training.
        scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
        dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
        dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
        output_dir (str): Directory where model checkpoints will be saved.
        resume (str, optional): Path to a checkpoint to resume training from.
    """

    def __init__(
        self,
        local_rank,
        use_ddp: bool = False,
        use_fsdp: bool = False,
        use_fp16: bool = False,
        output_dir: str = "./",
        **kwargs,
    ):
        """
        Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.

        Args:
            model (torch.nn.Module): The model to be trained.
View full source on GitHub →

Methods

.save_checkpoint(epoch, step, model, optim, scheduler, scaler, step_in_epoch, **kwargs) L143

Saves a checkpoint containing the model's state, the optimizer's state,

and the scheduler's state at the end of the given epoch. This method is

intended to be called at the end of each epoch to save the training progress.

Args:

  • epoch (int) — The epoch number at which the checkpoint is being saved.
📄 Source
    def save_checkpoint(
        self,
        epoch,
        step=None,
        model=None,
        optim=None,
        scheduler=None,
        scaler=None,
        step_in_epoch=None,
        **kwargs,
    ):
        """
        Saves a checkpoint containing the model's state, the optimizer's state,
        and the scheduler's state at the end of the given epoch. This method is
        intended to be called at the end of each epoch to save the training progress.

        Args:
            epoch (int): The epoch number at which the checkpoint is being saved.
        """

        step_in_epoch = None if step is None else step_in_epoch
        if self.rank == 0:
            logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
            # self.step_or_epoch += 1
            state = {
                "epoch": epoch,
                "step": step,
                "total_step": self.batch_total,
                "state_dict": model.state_dict(),
                "optimizer": optim.state_dict(),
View full source →
.resume_checkpoint(model, optim, scheduler, scaler) L260

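Resumes training from the checkpoint `<output_dir>/model.pt` when `self.resume` is set.

Loads the model's state, the optimizer's state, and the scheduler's state.

Args:

  • model, optim, scheduler, scaler — Objects whose state is restored from the checkpoint. The checkpoint path is not a parameter; it is fixed to `<output_dir>/model.pt`.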
📄 Source
    def resume_checkpoint(
        self,
        model=None,
        optim=None,
        scheduler=None,
        scaler=None,
    ):
        """
        Resumes training from a checkpoint at the given file path.
        Loads the model's state, the optimizer's state, and the scheduler's state.

        Args:
            resume_path (str): The file path to the checkpoint to resume from.
        """
        if self.resume:
            ckpt = os.path.join(self.output_dir, "model.pt")
            if os.path.isfile(ckpt):
                checkpoint = torch.load(ckpt, map_location="cpu")
                self.start_epoch = checkpoint["epoch"]
                # self.model.load_state_dict(checkpoint['state_dict'])
                src_state = checkpoint["state_dict"]
                dst_state = model.state_dict()
                for k in dst_state.keys():
                    if not k.startswith("module.") and "module." + k in src_state.keys():
                        k_ddp = "module." + k
                    elif k.startswith("module.") and "module." + k not in src_state.keys():
                        k_ddp = k.replace("module.", "", 1)
                    else:
                        k_ddp = k

View full source →
.train_epoch(model, optim, scheduler, scaler, dataloader_train, dataloader_val, epoch, writer, **kwargs) L340

Defines the training process for a single epoch with gradient accumulation.

Args:

  • epoch (int) — The current epoch number.
📄 Source
    def train_epoch(
        self,
        model=None,
        optim=None,
        scheduler=None,
        scaler=None,
        dataloader_train=None,
        dataloader_val=None,
        epoch=None,
        writer=None,
        **kwargs,
    ):
        """
        Defines the training process for a single epoch with gradient accumulation.
        Args:
            epoch (int): The current epoch number.
        """
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
        model.train()

        # Set the number of steps for gradient accumulation
        accum_grad = self.accum_grad
        # Initialize the gradient accumulation
        optim.zero_grad()
        speed_stats = {}

        iterator_stop = torch.tensor(0).to(self.device)

View full source →
.validate_epoch(model, dataloader_val, epoch, writer, **kwargs) L537

Defines the validation process for a single epoch.

Should be implemented with the actual model validation steps.

Args:

  • epoch (int) — The current epoch number.
📄 Source
    def validate_epoch(
        self,
        model=None,
        dataloader_val=None,
        epoch=None,
        writer=None,
        **kwargs,
    ):
        """
        Defines the validation process for a single epoch.
        Should be implemented with the actual model validation steps.

        Args:
            epoch (int): The current epoch number.
        """
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
        model.eval()

        with torch.no_grad():

            speed_stats = {}
            time5 = time.perf_counter()
            iterator_stop = torch.tensor(0).to(self.device)
            dataloader_val.batch_sampler.set_epoch(epoch)
            for batch_idx, batch in enumerate(dataloader_val):
                if self.use_ddp or self.use_fsdp:
                    dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
                    if iterator_stop > 0:
View full source →
.log(epoch, batch_idx, step_in_epoch, batch_num_epoch, lr, loss, speed_stats, stats, writer, tag, data_split_i, data_split_num, log_step, **kwargs) L651

Log.

Args:

  • epoch — TODO.
  • batch_idx — TODO.
  • step_in_epoch — TODO.
  • batch_num_epoch — TODO.
  • lr — TODO.
  • loss — TODO.
  • speed_stats — TODO.
  • stats — TODO.
  • writer — TODO.
  • tag — TODO.
  • data_split_i — TODO.
  • data_split_num — TODO.
  • log_step — TODO.
  • **kwargs — Additional keyword arguments.
📄 Source
    def log(
        self,
        epoch=0,
        batch_idx=0,
        step_in_epoch=0,
        batch_num_epoch=-1,
        lr=0.0,
        loss=0.0,
        speed_stats=None,
        stats=None,
        writer=None,
        tag="train",
        data_split_i=0,
        data_split_num=1,
        log_step=None,
        **kwargs,
    ):

        """Log.
        
            Args:
                epoch: TODO.
                batch_idx: TODO.
                step_in_epoch: TODO.
                batch_num_epoch: TODO.
                lr: TODO.
                loss: TODO.
                speed_stats: TODO.
                stats: TODO.
                writer: TODO.
View full source →
.close(writer) L744

Close.

Args:

  • writer — Optional writer whose `close()` is called when it is not `None`.
📄 Source
    def close(self, writer=None):
        """Close the writer and, for distributed runs, tear down the process group.

        Args:
            writer: Optional writer; its ``close()`` is called when not ``None``.
        """
        distributed = self.use_ddp or self.use_fsdp
        if distributed:
            # Wait for every rank to reach this point before tearing anything down.
            dist.barrier()
        if writer is not None:
            writer.close()
        if distributed:
            torch.distributed.destroy_process_group()
method

Trainer.save_checkpoint(epoch, step, model, optim, scheduler, scaler, step_in_epoch, **kwargs)

funasr.train_utils.trainer.Trainer · View on GitHub ↗

Saves a checkpoint containing the model's state, the optimizer's state,

and the scheduler's state at the end of the given epoch. This method is

intended to be called at the end of each epoch to save the training progress.

Args:

  • epoch (int) — The epoch number at which the checkpoint is being saved.
📄 Source code
    def save_checkpoint(
        self,
        epoch,
        step=None,
        model=None,
        optim=None,
        scheduler=None,
        scaler=None,
        step_in_epoch=None,
        **kwargs,
    ):
        """
        Saves a checkpoint containing the model's state, the optimizer's state,
        and the scheduler's state at the end of the given epoch. This method is
        intended to be called at the end of each epoch to save the training progress.

        Args:
            epoch (int): The epoch number at which the checkpoint is being saved.
        """

        step_in_epoch = None if step is None else step_in_epoch
        if self.rank == 0:
            logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
            # self.step_or_epoch += 1
            state = {
                "epoch": epoch,
                "step": step,
                "total_step": self.batch_total,
                "state_dict": model.state_dict(),
                "optimizer": optim.state_dict(),
View full source on GitHub →
method

Trainer.resume_checkpoint(model, optim, scheduler, scaler)

funasr.train_utils.trainer.Trainer · View on GitHub ↗

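Resumes training from the checkpoint `<output_dir>/model.pt` when `self.resume` is set.

Loads the model's state, the optimizer's state, and the scheduler's state.

Args:

  • model, optim, scheduler, scaler — Objects whose state is restored from the checkpoint. The checkpoint path is not a parameter; it is fixed to `<output_dir>/model.pt`.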
📄 Source code
    def resume_checkpoint(
        self,
        model=None,
        optim=None,
        scheduler=None,
        scaler=None,
    ):
        """
        Resumes training from a checkpoint at the given file path.
        Loads the model's state, the optimizer's state, and the scheduler's state.

        Args:
            resume_path (str): The file path to the checkpoint to resume from.
        """
        if self.resume:
            ckpt = os.path.join(self.output_dir, "model.pt")
            if os.path.isfile(ckpt):
                checkpoint = torch.load(ckpt, map_location="cpu")
                self.start_epoch = checkpoint["epoch"]
                # self.model.load_state_dict(checkpoint['state_dict'])
                src_state = checkpoint["state_dict"]
                dst_state = model.state_dict()
                for k in dst_state.keys():
                    if not k.startswith("module.") and "module." + k in src_state.keys():
                        k_ddp = "module." + k
                    elif k.startswith("module.") and "module." + k not in src_state.keys():
                        k_ddp = k.replace("module.", "", 1)
                    else:
                        k_ddp = k

View full source on GitHub →
method

Trainer.train_epoch(model, optim, scheduler, scaler, dataloader_train, dataloader_val, epoch, writer, **kwargs)

funasr.train_utils.trainer.Trainer · View on GitHub ↗

Defines the training process for a single epoch with gradient accumulation.

Args:

  • epoch (int) — The current epoch number.
📄 Source code
    def train_epoch(
        self,
        model=None,
        optim=None,
        scheduler=None,
        scaler=None,
        dataloader_train=None,
        dataloader_val=None,
        epoch=None,
        writer=None,
        **kwargs,
    ):
        """
        Defines the training process for a single epoch with gradient accumulation.
        Args:
            epoch (int): The current epoch number.
        """
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
        model.train()

        # Set the number of steps for gradient accumulation
        accum_grad = self.accum_grad
        # Initialize the gradient accumulation
        optim.zero_grad()
        speed_stats = {}

        iterator_stop = torch.tensor(0).to(self.device)

View full source on GitHub →
method

Trainer.validate_epoch(model, dataloader_val, epoch, writer, **kwargs)

funasr.train_utils.trainer.Trainer · View on GitHub ↗

Defines the validation process for a single epoch.

Should be implemented with the actual model validation steps.

Args:

  • epoch (int) — The current epoch number.
📄 Source code
    def validate_epoch(
        self,
        model=None,
        dataloader_val=None,
        epoch=None,
        writer=None,
        **kwargs,
    ):
        """
        Defines the validation process for a single epoch.
        Should be implemented with the actual model validation steps.

        Args:
            epoch (int): The current epoch number.
        """
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
        model.eval()

        with torch.no_grad():

            speed_stats = {}
            time5 = time.perf_counter()
            iterator_stop = torch.tensor(0).to(self.device)
            dataloader_val.batch_sampler.set_epoch(epoch)
            for batch_idx, batch in enumerate(dataloader_val):
                if self.use_ddp or self.use_fsdp:
                    dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
                    if iterator_stop > 0:
View full source on GitHub →
method

Trainer.log(epoch, batch_idx, step_in_epoch, batch_num_epoch, lr, loss, speed_stats, stats, writer, tag, data_split_i, data_split_num, log_step, **kwargs)

funasr.train_utils.trainer.Trainer · View on GitHub ↗

Log.

Args:

  • epoch — TODO.
  • batch_idx — TODO.
  • step_in_epoch — TODO.
  • batch_num_epoch — TODO.
  • lr — TODO.
  • loss — TODO.
  • speed_stats — TODO.
  • stats — TODO.
  • writer — TODO.
  • tag — TODO.
  • data_split_i — TODO.
  • data_split_num — TODO.
  • log_step — TODO.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def log(
        self,
        epoch=0,
        batch_idx=0,
        step_in_epoch=0,
        batch_num_epoch=-1,
        lr=0.0,
        loss=0.0,
        speed_stats=None,
        stats=None,
        writer=None,
        tag="train",
        data_split_i=0,
        data_split_num=1,
        log_step=None,
        **kwargs,
    ):

        """Log.
        
            Args:
                epoch: TODO.
                batch_idx: TODO.
                step_in_epoch: TODO.
                batch_num_epoch: TODO.
                lr: TODO.
                loss: TODO.
                speed_stats: TODO.
                stats: TODO.
                writer: TODO.
View full source on GitHub →
method

Trainer.close(writer)

funasr.train_utils.trainer.Trainer · View on GitHub ↗

Close.

Args:

  • writer — Optional writer whose `close()` is called when it is not `None`.
📄 Source code
    def close(self, writer=None):
        """Close the writer and, for distributed runs, tear down the process group.

        Args:
            writer: Optional writer; its ``close()`` is called when not ``None``.
        """
        distributed = self.use_ddp or self.use_fsdp
        if distributed:
            # Wait for every rank to reach this point before tearing anything down.
            dist.barrier()
        if writer is not None:
            writer.close()
        if distributed:
            torch.distributed.destroy_process_group()
function

maybe_autocast(dtype, use_deepspeed)

funasr.train_utils.trainer_ds · View on GitHub ↗

Maybe autocast.

Args:

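  • dtype — Autocast dtype. Without DeepSpeed, autocast is applied only when it is `torch.float16` or `torch.bfloat16`.
  • use_deepspeed — If true, always enter `torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)`.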
📄 Source code
def maybe_autocast(dtype=None, use_deepspeed=False):
    """Yield once, optionally inside a mixed-precision autocast context.

    NOTE(review): this is a generator meant to back a context manager; no
    ``@contextmanager`` decorator is visible in this listing. Confirm that the
    real module applies one.

    Args:
        dtype: Autocast dtype. Without DeepSpeed, autocast is entered only
            for ``torch.float16`` or ``torch.bfloat16``.
        use_deepspeed: If True, always enter
            ``torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)``.
    """
    if use_deepspeed:
        amp_ctx = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
    elif dtype in (torch.float16, torch.bfloat16):
        amp_ctx = autocast(enabled=True, dtype=dtype)
    else:
        amp_ctx = None

    if amp_ctx is None:
        yield
        return
    with amp_ctx:
        yield
class

Trainer

funasr.train_utils.trainer_ds · View on GitHub ↗

A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,

and optionally resuming from a saved checkpoint.

Attributes:

  • max_epoch (int) — Maximum number of epochs for training.
  • model (torch.nn.Module) — The model to be trained.
  • optim (torch.optim.Optimizer) — The optimizer to use for training.
  • scheduler (torch.optim.lr_scheduler._LRScheduler) — The learning rate scheduler.
  • dataloader_train (torch.utils.data.DataLoader) — DataLoader for the training dataset.
  • dataloader_val (torch.utils.data.DataLoader) — DataLoader for the validation dataset.
  • output_dir (str) — Directory where model checkpoints will be saved.
  • resume (str, optional) — Path to a checkpoint to resume training from.
📄 Source code
class Trainer:
    """
    A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
    and optionally resuming from a saved checkpoint.

    Attributes:
        max_epoch (int): Maximum number of epochs for training.
        model (torch.nn.Module): The model to be trained.
        optim (torch.optim.Optimizer): The optimizer to use for training.
        scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
        dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
        dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
        output_dir (str): Directory where model checkpoints will be saved.
        resume (str, optional): Path to a checkpoint to resume training from.
    """

    def __init__(
        self,
        rank=0,
        local_rank=0,
        world_size=1,
        use_ddp: bool = False,
        use_fsdp: bool = False,
        use_fp16: bool = False,
        use_bf16: bool = False,
        use_deepspeed: bool = False,
        output_dir: str = "./",
        **kwargs,
    ):
        """
View full source on GitHub →

Methods

.save_checkpoint(epoch, step, model, optim, scheduler, scaler, step_in_epoch, **kwargs) L171

Saves a checkpoint containing the model's state, the optimizer's state, and the scheduler's state at the end of the given epoch. This method is intended to be called at the end of each epoch to save the training progress.

Args:

  • epoch (int) — The epoch number at which the checkpoint is being saved.
📄 Source
    def save_checkpoint(
        self,
        epoch,
        step=None,
        model=None,
        optim=None,
        scheduler=None,
        scaler=None,
        step_in_epoch=None,
        **kwargs,
    ):
        """
        Saves a checkpoint containing the model's state, the optimizer's state,
        and the scheduler's state at the end of the given epoch. This method is
        intended to be called at the end of each epoch to save the training progress.

        Args:
            epoch (int): The epoch number at which the checkpoint is being saved.
        """
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        step_in_epoch = None if step is None else step_in_epoch
        if self.use_deepspeed:

            logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
            # self.step_or_epoch += 1
            state = {
                "epoch": epoch,
                # "state_dict": model.state_dict(),
                # "optimizer": optim.state_dict(),
View full source →
.resume_checkpoint(model, optim, scheduler, scaler) L414

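Resumes training from a saved checkpoint when resuming is enabled (self.resume). The method takes no path argument: the checkpoint is looked up in the trainer's output directory (with DeepSpeed: output_dir/model.pt).

Loads the model's state, the optimizer's state, and the scheduler's state.

Args:

  • model — Model whose state is restored (with DeepSpeed, its load_checkpoint method is used).
  • optim — Optimizer whose state is restored.
  • scheduler — Learning-rate scheduler whose state is restored.
  • scaler — Optional gradient scaler.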
📄 Source
    def resume_checkpoint(
        self,
        model=None,
        optim=None,
        scheduler=None,
        scaler=None,
    ):
        """
        Resumes training from a checkpoint at the given file path.
        Loads the model's state, the optimizer's state, and the scheduler's state.

        Args:
            resume_path (str): The file path to the checkpoint to resume from.
        """
        if self.resume:

            if self.use_deepspeed:
                ckpt = os.path.join(self.output_dir, "model.pt")
                if os.path.exists(ckpt):
                    _, checkpoint = model.load_checkpoint(self.output_dir, "model.pt")
                    self.start_epoch = checkpoint["epoch"]
                    self.saved_ckpts = checkpoint["saved_ckpts"]
                    self.val_acc_step_or_epoch = (
                        checkpoint["val_acc_step_or_epoch"]
                        if "val_acc_step_or_epoch" in checkpoint
                        else {}
                    )
                    self.val_loss_step_or_epoch = (
                        checkpoint["val_loss_step_or_epoch"]
                        if "val_loss_step_or_epoch" in checkpoint
View full source →
.train_epoch(model, optim, scheduler, scaler, dataloader_train, dataloader_val, epoch, **kwargs) L554

Defines the training process for a single epoch with gradient accumulation.

Args:

  • epoch (int) — The current epoch number.
📄 Source
    def train_epoch(
        self,
        model=None,
        optim=None,
        scheduler=None,
        scaler=None,
        dataloader_train=None,
        dataloader_val=None,
        epoch=None,
        **kwargs,
    ):
        """
        Defines the training process for a single epoch with gradient accumulation.
        Args:
            epoch (int): The current epoch number.
        """
        if self.use_ddp or self.use_fsdp or self.use_deepspeed:
            dist.barrier()
        logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
        model.train()

        # Set the number of steps for gradient accumulation
        accum_grad = self.accum_grad
        # Initialize the gradient accumulation
        optim.zero_grad()
        speed_stats = {}

        iterator_stop = torch.tensor(0).to(self.device)

        dataloader_train.batch_sampler.set_epoch(epoch)
View full source →
.forward_step(model, batch, loss_dict) L676

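Run the model's forward pass (inside maybe_autocast) and store its outputs in loss_dict.

Args:

  • model — Model instance; called as model(**batch) and expected to return a (loss, stats, weight) triple.
  • batch — Dict of keyword arguments passed to the model.
  • loss_dict — Dict updated in place with "loss", "stats" (None-valued entries dropped) and "weight".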
📄 Source
    def forward_step(self, model, batch, loss_dict=None):
        """Run the model's forward pass and record its outputs.

        The call is wrapped in ``maybe_autocast`` so mixed precision follows
        ``self.dtype`` / ``self.use_deepspeed``.

        Args:
            model: Model instance; called as ``model(**batch)`` and expected to
                return a ``(loss, stats, weight)`` triple.
            batch: Dict of keyword arguments forwarded to the model.
            loss_dict: Dict updated in place with the ``"loss"``, ``"stats"``
                and ``"weight"`` entries. A fresh dict is used when omitted.

        Returns:
            The updated ``loss_dict``.
        """
        if loss_dict is None:
            # Never share a default dict between calls: it is mutated below.
            loss_dict = {}

        with maybe_autocast(dtype=self.dtype, use_deepspeed=self.use_deepspeed):
            retval = model(**batch)

        loss, stats, weight = retval
        # Drop statistics the model did not compute.
        stats = {k: v for k, v in stats.items() if v is not None}

        loss_dict["loss"] = loss
        loss_dict["stats"] = stats
        loss_dict["weight"] = weight
        return loss_dict
.backward_step(model, scaler, loss_dict) L694

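Backpropagate the loss stored in loss_dict.

Args:

  • model — Model instance; with DeepSpeed, model.backward(loss) is called on the loss as-is.
  • scaler — Optional gradient scaler; when given, backward runs on scaler.scale(loss).
  • loss_dict — Dict holding the loss under the "loss" key. Without DeepSpeed, the loss is divided by accum_grad before backward.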
📄 Source
    def backward_step(self, model, scaler, loss_dict=None):
        """Backpropagate the loss stored in ``loss_dict``.

        Args:
            model: Model instance. With DeepSpeed, ``model.backward(loss)`` is
                called on the loss as-is (not divided by ``accum_grad``).
            scaler: Optional gradient scaler; when truthy, backward runs on
                ``scaler.scale(loss)``.
            loss_dict: Dict holding the loss tensor under ``"loss"``.

        Raises:
            KeyError: If no ``"loss"`` entry is available (including when
                ``loss_dict`` is omitted).
        """
        if loss_dict is None:
            loss_dict = {}
        loss = loss_dict["loss"]

        if self.use_deepspeed:
            model.backward(loss)
            return

        # Gradients accumulate over accum_grad batches before each optimizer
        # step, so each batch contributes 1/accum_grad of the loss.
        loss = loss / self.accum_grad
        if scaler:
            scaler.scale(loss).backward()
        else:
            loss.backward()
.update_step(model, optim, scheduler, scaler, loss_dict) L713

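Apply an optimizer update, with optional gradient clipping.

Args:

  • model — Model instance; with DeepSpeed, model.step() performs the update, otherwise its parameters are clipped when grad_clip > 0.
  • optim — Optimizer; when the clipped gradient norm is not finite, its gradients are zeroed and the update is skipped.
  • scheduler — Learning-rate scheduler.
  • scaler — Optional gradient scaler.
  • loss_dict — Dict that must contain "batch_idx"; without DeepSpeed, an update happens only every accum_grad batches.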
📄 Source
    def update_step(self, model, optim, scheduler, scaler, loss_dict=None):
        """Update step.
        
            Args:
                model: Model instance or model name.
                optim: TODO.
                scheduler: TODO.
                scaler: TODO.
                loss_dict: TODO.
            """
        batch_idx = loss_dict["batch_idx"]
        if self.use_deepspeed:
            model.step()
        else:
            if (batch_idx + 1) % self.accum_grad == 0:
                # Perform gradient clipping if it is set
                if self.grad_clip > 0:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(),
                        max_norm=self.grad_clip,
                        norm_type=self.grad_clip_type,
                    )
                    if not torch.isfinite(grad_norm):
                        logging.warning(
                            f"The grad norm is {grad_norm}. Skipping updating the model."
                        )
                        optim.zero_grad()  # Reset gradients
                        return

                # Execute an optimization step (update model parameters)
View full source →
.validate_epoch(model, dataloader_val, epoch, writer, **kwargs) L754

Runs validation for a single epoch: the model is switched to eval mode and evaluated under torch.no_grad().

Args:

  • epoch (int) — The current epoch number.
📄 Source
    def validate_epoch(
        self,
        model=None,
        dataloader_val=None,
        epoch=None,
        writer=None,
        **kwargs,
    ):
        """
        Defines the validation process for a single epoch.
        Should be implemented with the actual model validation steps.

        Args:
            epoch (int): The current epoch number.
        """
        self.val_loss_avg = 0.0
        self.val_acc_avg  = 0.0

        if self.use_ddp or self.use_fsdp or self.use_deepspeed:
            dist.barrier()
        logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
        model.eval()

        with torch.no_grad():

            speed_stats = {}
            time_beg = time.perf_counter()
            time5 = time_beg

            dataloader_val.batch_sampler.set_epoch(epoch)
View full source →
.log(loss_dict, tag, **kwargs) L848

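Log training/validation statistics every log_interval batches.

Args:

  • loss_dict — Dict of values to log; must contain "loss", "epoch", "batch_idx", "step_in_epoch", "batch_total", "batch_num_epoch", "lr", "speed_stats", "stats", "data_split_i" and "data_split_num", and may contain "log_step".
  • tag — Label for the logged values (default "train").
  • **kwargs — Additional keyword arguments.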
📄 Source
    def log(
        self,
        loss_dict: dict = None,
        tag="train",
        **kwargs,
    ):
        """Log.
        
            Args:
                loss_dict: TODO.
                tag: TODO.
                **kwargs: Additional keyword arguments.
            """
        loss = loss_dict["loss"].detach().cpu().item()
        epoch = loss_dict["epoch"]
        batch_idx = loss_dict["batch_idx"]
        step_in_epoch = loss_dict["step_in_epoch"]
        batch_total = loss_dict["batch_total"]
        batch_num_epoch = loss_dict["batch_num_epoch"]
        lr = loss_dict["lr"]

        speed_stats = loss_dict["speed_stats"]
        stats = loss_dict["stats"]
        data_split_i = loss_dict["data_split_i"]
        data_split_num = loss_dict["data_split_num"]
        log_step = loss_dict.get("log_step", None)

        if (batch_idx + 1) % self.log_interval == 0:
            batch_idx = log_step if log_step is not None else batch_idx
            gpu_info = (
View full source →
.close(writer) L929

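Close the trainer. When DDP or FSDP is enabled, all ranks first meet at a barrier. The writer (if given) is then closed, and finally the distributed process group is destroyed.

Args:

  • writer — Optional object with a close() method (e.g. a log/summary writer); skipped when None.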
📄 Source
    def close(self, writer=None):
        """Release training resources at the end of a run.

        When DDP or FSDP is active, all ranks first meet at a barrier; the
        optional ``writer`` is then closed, and finally the distributed
        process group is destroyed.

        Args:
            writer: Optional object exposing ``close()`` (for example a
                summary writer). Nothing is closed when it is ``None``.
        """
        distributed = self.use_ddp or self.use_fsdp
        if distributed:
            # Make sure every rank has finished before tearing anything down.
            dist.barrier()

        if writer is not None:
            writer.close()

        if distributed:
            torch.distributed.destroy_process_group()
.warp_model(model, **kwargs) L945

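Prepare the model for the configured training backend (the name presumably means "wrap"). With DeepSpeed, rank 0 first logs the estimated ZeRO-2 and ZeRO-3 model-state memory needs.

Args:

  • model — Model instance to prepare.
  • **kwargs — Additional keyword arguments.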
📄 Source
    def warp_model(self, model, **kwargs):

        """Warp model.
        
            Args:
                model: Model instance or model name.
                **kwargs: Additional keyword arguments.
            """
        if self.use_deepspeed:
            from deepspeed.runtime.zero.stage_1_and_2 import (
                estimate_zero2_model_states_mem_needs_all_live,
            )
            from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live
            from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

            local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
            world_size = int(os.environ.get("WORLD_SIZE", 1))

            # NOTE(xcsong): look in detail how the memory estimator API works:
            #   https://deepspeed.readthedocs.io/en/latest/memory.html#discussion
            if int(os.environ.get("RANK", 0)) == 0:
                logging.info("Estimating model states memory needs (zero2)...")
                estimate_zero2_model_states_mem_needs_all_live(
                    model,
                    num_gpus_per_node=local_world_size,
                    num_nodes=world_size // local_world_size,
                )
                logging.info("Estimating model states memory needs (zero3)...")
                estimate_zero3_model_states_mem_needs_all_live(
                    model,
View full source →
.warp_optim_scheduler(model, **kwargs) L997

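Build the optimizer and the learning-rate scheduler from the configuration.

Args:

  • model — Model whose parameters are optimized.
  • **kwargs — Configuration: "optim" (default "adam") with "optim_conf", and "scheduler" (default "warmuplr") with "scheduler_conf".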
📄 Source
    def warp_optim_scheduler(self, model, **kwargs):
        """Warp optim scheduler.
        
            Args:
                model: Model instance or model name.
                **kwargs: Additional keyword arguments.
            """
        from funasr.optimizers import optim_classes
        from funasr.schedulers import scheduler_classes
        from omegaconf import OmegaConf, DictConfig
        import json

        # optim
        logging.info("Build optim")
        optim = kwargs.get("optim", "adam")
        assert optim in optim_classes
        optim_class = optim_classes.get(optim)
        optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))

        # scheduler
        logging.info("Build scheduler")
        scheduler = kwargs.get("scheduler", "warmuplr")
        assert scheduler in scheduler_classes
        scheduler_class = scheduler_classes.get(scheduler)
        scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))

        if self.use_deepspeed:
            import deepspeed

            args = OmegaConf.create({"deepspeed_config": self.deepspeed_config})
View full source →
method

Trainer.save_checkpoint(epoch, step, model, optim, scheduler, scaler, step_in_epoch, **kwargs)

funasr.train_utils.trainer_ds.Trainer · View on GitHub ↗

Saves a checkpoint containing the model's state, the optimizer's state, and the scheduler's state at the end of the given epoch. This method is intended to be called at the end of each epoch to save the training progress.

Args:

  • epoch (int) — The epoch number at which the checkpoint is being saved.
📄 Source code
    def save_checkpoint(
        self,
        epoch,
        step=None,
        model=None,
        optim=None,
        scheduler=None,
        scaler=None,
        step_in_epoch=None,
        **kwargs,
    ):
        """
        Saves a checkpoint containing the model's state, the optimizer's state,
        and the scheduler's state at the end of the given epoch. This method is
        intended to be called at the end of each epoch to save the training progress.

        Args:
            epoch (int): The epoch number at which the checkpoint is being saved.
        """
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        step_in_epoch = None if step is None else step_in_epoch
        if self.use_deepspeed:

            logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
            # self.step_or_epoch += 1
            state = {
                "epoch": epoch,
                # "state_dict": model.state_dict(),
                # "optimizer": optim.state_dict(),
View full source on GitHub →
method

Trainer.resume_checkpoint(model, optim, scheduler, scaler)

funasr.train_utils.trainer_ds.Trainer · View on GitHub ↗

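Resumes training from a saved checkpoint when resuming is enabled (self.resume). The method takes no path argument: the checkpoint is looked up in the trainer's output directory (with DeepSpeed: output_dir/model.pt).

Loads the model's state, the optimizer's state, and the scheduler's state.

Args:

  • model — Model whose state is restored (with DeepSpeed, its load_checkpoint method is used).
  • optim — Optimizer whose state is restored.
  • scheduler — Learning-rate scheduler whose state is restored.
  • scaler — Optional gradient scaler.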
📄 Source code
    def resume_checkpoint(
        self,
        model=None,
        optim=None,
        scheduler=None,
        scaler=None,
    ):
        """
        Resumes training from a checkpoint at the given file path.
        Loads the model's state, the optimizer's state, and the scheduler's state.

        Args:
            resume_path (str): The file path to the checkpoint to resume from.
        """
        if self.resume:

            if self.use_deepspeed:
                ckpt = os.path.join(self.output_dir, "model.pt")
                if os.path.exists(ckpt):
                    _, checkpoint = model.load_checkpoint(self.output_dir, "model.pt")
                    self.start_epoch = checkpoint["epoch"]
                    self.saved_ckpts = checkpoint["saved_ckpts"]
                    self.val_acc_step_or_epoch = (
                        checkpoint["val_acc_step_or_epoch"]
                        if "val_acc_step_or_epoch" in checkpoint
                        else {}
                    )
                    self.val_loss_step_or_epoch = (
                        checkpoint["val_loss_step_or_epoch"]
                        if "val_loss_step_or_epoch" in checkpoint
View full source on GitHub →
method

Trainer.train_epoch(model, optim, scheduler, scaler, dataloader_train, dataloader_val, epoch, **kwargs)

funasr.train_utils.trainer_ds.Trainer · View on GitHub ↗

Defines the training process for a single epoch with gradient accumulation.

Args:

  • epoch (int) — The current epoch number.
📄 Source code
    def train_epoch(
        self,
        model=None,
        optim=None,
        scheduler=None,
        scaler=None,
        dataloader_train=None,
        dataloader_val=None,
        epoch=None,
        **kwargs,
    ):
        """
        Defines the training process for a single epoch with gradient accumulation.
        Args:
            epoch (int): The current epoch number.
        """
        if self.use_ddp or self.use_fsdp or self.use_deepspeed:
            dist.barrier()
        logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
        model.train()

        # Set the number of steps for gradient accumulation
        accum_grad = self.accum_grad
        # Initialize the gradient accumulation
        optim.zero_grad()
        speed_stats = {}

        iterator_stop = torch.tensor(0).to(self.device)

        dataloader_train.batch_sampler.set_epoch(epoch)
View full source on GitHub →
method

Trainer.forward_step(model, batch, loss_dict)

funasr.train_utils.trainer_ds.Trainer · View on GitHub ↗

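Run the model's forward pass (inside maybe_autocast) and store its outputs in loss_dict.

Args:

  • model — Model instance; called as model(**batch) and expected to return a (loss, stats, weight) triple.
  • batch — Dict of keyword arguments passed to the model.
  • loss_dict — Dict updated in place with "loss", "stats" (None-valued entries dropped) and "weight".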
📄 Source code
    def forward_step(self, model, batch, loss_dict=None):
        """Run the model's forward pass and record its outputs.

        The call is wrapped in ``maybe_autocast`` so mixed precision follows
        ``self.dtype`` / ``self.use_deepspeed``.

        Args:
            model: Model instance; called as ``model(**batch)`` and expected to
                return a ``(loss, stats, weight)`` triple.
            batch: Dict of keyword arguments forwarded to the model.
            loss_dict: Dict updated in place with the ``"loss"``, ``"stats"``
                and ``"weight"`` entries. A fresh dict is used when omitted.

        Returns:
            The updated ``loss_dict``.
        """
        if loss_dict is None:
            # Never share a default dict between calls: it is mutated below.
            loss_dict = {}

        with maybe_autocast(dtype=self.dtype, use_deepspeed=self.use_deepspeed):
            retval = model(**batch)

        loss, stats, weight = retval
        # Drop statistics the model did not compute.
        stats = {k: v for k, v in stats.items() if v is not None}

        loss_dict["loss"] = loss
        loss_dict["stats"] = stats
        loss_dict["weight"] = weight
        return loss_dict
method

Trainer.backward_step(model, scaler, loss_dict)

funasr.train_utils.trainer_ds.Trainer · View on GitHub ↗

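Backpropagate the loss stored in loss_dict.

Args:

  • model — Model instance; with DeepSpeed, model.backward(loss) is called on the loss as-is.
  • scaler — Optional gradient scaler; when given, backward runs on scaler.scale(loss).
  • loss_dict — Dict holding the loss under the "loss" key. Without DeepSpeed, the loss is divided by accum_grad before backward.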
📄 Source code
    def backward_step(self, model, scaler, loss_dict=None):
        """Backpropagate the loss stored in ``loss_dict``.

        Args:
            model: Model instance. With DeepSpeed, ``model.backward(loss)`` is
                called on the loss as-is (not divided by ``accum_grad``).
            scaler: Optional gradient scaler; when truthy, backward runs on
                ``scaler.scale(loss)``.
            loss_dict: Dict holding the loss tensor under ``"loss"``.

        Raises:
            KeyError: If no ``"loss"`` entry is available (including when
                ``loss_dict`` is omitted).
        """
        if loss_dict is None:
            loss_dict = {}
        loss = loss_dict["loss"]

        if self.use_deepspeed:
            model.backward(loss)
            return

        # Gradients accumulate over accum_grad batches before each optimizer
        # step, so each batch contributes 1/accum_grad of the loss.
        loss = loss / self.accum_grad
        if scaler:
            scaler.scale(loss).backward()
        else:
            loss.backward()
method

Trainer.update_step(model, optim, scheduler, scaler, loss_dict)

funasr.train_utils.trainer_ds.Trainer · View on GitHub ↗

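Apply an optimizer update, with optional gradient clipping.

Args:

  • model — Model instance; with DeepSpeed, model.step() performs the update, otherwise its parameters are clipped when grad_clip > 0.
  • optim — Optimizer; when the clipped gradient norm is not finite, its gradients are zeroed and the update is skipped.
  • scheduler — Learning-rate scheduler.
  • scaler — Optional gradient scaler.
  • loss_dict — Dict that must contain "batch_idx"; without DeepSpeed, an update happens only every accum_grad batches.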
📄 Source code
    def update_step(self, model, optim, scheduler, scaler, loss_dict=None):
        """Update step.
        
            Args:
                model: Model instance or model name.
                optim: TODO.
                scheduler: TODO.
                scaler: TODO.
                loss_dict: TODO.
            """
        batch_idx = loss_dict["batch_idx"]
        if self.use_deepspeed:
            model.step()
        else:
            if (batch_idx + 1) % self.accum_grad == 0:
                # Perform gradient clipping if it is set
                if self.grad_clip > 0:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(),
                        max_norm=self.grad_clip,
                        norm_type=self.grad_clip_type,
                    )
                    if not torch.isfinite(grad_norm):
                        logging.warning(
                            f"The grad norm is {grad_norm}. Skipping updating the model."
                        )
                        optim.zero_grad()  # Reset gradients
                        return

                # Execute an optimization step (update model parameters)
View full source on GitHub →
method

Trainer.validate_epoch(model, dataloader_val, epoch, writer, **kwargs)

funasr.train_utils.trainer_ds.Trainer · View on GitHub ↗

Runs validation for a single epoch: the model is switched to eval mode and evaluated under torch.no_grad().

Args:

  • epoch (int) — The current epoch number.
📄 Source code
    def validate_epoch(
        self,
        model=None,
        dataloader_val=None,
        epoch=None,
        writer=None,
        **kwargs,
    ):
        """
        Defines the validation process for a single epoch.
        Should be implemented with the actual model validation steps.

        Args:
            epoch (int): The current epoch number.
        """
        self.val_loss_avg = 0.0
        self.val_acc_avg  = 0.0

        if self.use_ddp or self.use_fsdp or self.use_deepspeed:
            dist.barrier()
        logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
        model.eval()

        with torch.no_grad():

            speed_stats = {}
            time_beg = time.perf_counter()
            time5 = time_beg

            dataloader_val.batch_sampler.set_epoch(epoch)
View full source on GitHub →
method

Trainer.log(loss_dict, tag, **kwargs)

funasr.train_utils.trainer_ds.Trainer · View on GitHub ↗

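Log training/validation statistics every log_interval batches.

Args:

  • loss_dict — Dict of values to log; must contain "loss", "epoch", "batch_idx", "step_in_epoch", "batch_total", "batch_num_epoch", "lr", "speed_stats", "stats", "data_split_i" and "data_split_num", and may contain "log_step".
  • tag — Label for the logged values (default "train").
  • **kwargs — Additional keyword arguments.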
📄 Source code
    def log(
        self,
        loss_dict: dict = None,
        tag="train",
        **kwargs,
    ):
        """Log.
        
            Args:
                loss_dict: TODO.
                tag: TODO.
                **kwargs: Additional keyword arguments.
            """
        loss = loss_dict["loss"].detach().cpu().item()
        epoch = loss_dict["epoch"]
        batch_idx = loss_dict["batch_idx"]
        step_in_epoch = loss_dict["step_in_epoch"]
        batch_total = loss_dict["batch_total"]
        batch_num_epoch = loss_dict["batch_num_epoch"]
        lr = loss_dict["lr"]

        speed_stats = loss_dict["speed_stats"]
        stats = loss_dict["stats"]
        data_split_i = loss_dict["data_split_i"]
        data_split_num = loss_dict["data_split_num"]
        log_step = loss_dict.get("log_step", None)

        if (batch_idx + 1) % self.log_interval == 0:
            batch_idx = log_step if log_step is not None else batch_idx
            gpu_info = (
View full source on GitHub →
method

Trainer.close(writer)

funasr.train_utils.trainer_ds.Trainer · View on GitHub ↗

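Close the trainer. When DDP or FSDP is enabled, all ranks first meet at a barrier. The writer (if given) is then closed, and finally the distributed process group is destroyed.

Args:

  • writer — Optional object with a close() method (e.g. a log/summary writer); skipped when None.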
📄 Source code
    def close(self, writer=None):
        """Release training resources at the end of a run.

        When DDP or FSDP is active, all ranks first meet at a barrier; the
        optional ``writer`` is then closed, and finally the distributed
        process group is destroyed.

        Args:
            writer: Optional object exposing ``close()`` (for example a
                summary writer). Nothing is closed when it is ``None``.
        """
        distributed = self.use_ddp or self.use_fsdp
        if distributed:
            # Make sure every rank has finished before tearing anything down.
            dist.barrier()

        if writer is not None:
            writer.close()

        if distributed:
            torch.distributed.destroy_process_group()
method

Trainer.warp_model(model, **kwargs)

funasr.train_utils.trainer_ds.Trainer · View on GitHub ↗

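Prepare the model for the configured training backend (the name presumably means "wrap"). With DeepSpeed, rank 0 first logs the estimated ZeRO-2 and ZeRO-3 model-state memory needs.

Args:

  • model — Model instance to prepare.
  • **kwargs — Additional keyword arguments.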
📄 Source code
    def warp_model(self, model, **kwargs):

        """Warp model.
        
            Args:
                model: Model instance or model name.
                **kwargs: Additional keyword arguments.
            """
        if self.use_deepspeed:
            from deepspeed.runtime.zero.stage_1_and_2 import (
                estimate_zero2_model_states_mem_needs_all_live,
            )
            from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live
            from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

            local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
            world_size = int(os.environ.get("WORLD_SIZE", 1))

            # NOTE(xcsong): look in detail how the memory estimator API works:
            #   https://deepspeed.readthedocs.io/en/latest/memory.html#discussion
            if int(os.environ.get("RANK", 0)) == 0:
                logging.info("Estimating model states memory needs (zero2)...")
                estimate_zero2_model_states_mem_needs_all_live(
                    model,
                    num_gpus_per_node=local_world_size,
                    num_nodes=world_size // local_world_size,
                )
                logging.info("Estimating model states memory needs (zero3)...")
                estimate_zero3_model_states_mem_needs_all_live(
                    model,
View full source on GitHub →
method

Trainer.warp_optim_scheduler(model, **kwargs)

funasr.train_utils.trainer_ds.Trainer · View on GitHub ↗

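Build the optimizer and the learning-rate scheduler from the configuration.

Args:

  • model — Model whose parameters are optimized.
  • **kwargs — Configuration: "optim" (default "adam") with "optim_conf", and "scheduler" (default "warmuplr") with "scheduler_conf".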
📄 Source code
    def warp_optim_scheduler(self, model, **kwargs):
        """Warp optim scheduler.
        
            Args:
                model: Model instance or model name.
                **kwargs: Additional keyword arguments.
            """
        from funasr.optimizers import optim_classes
        from funasr.schedulers import scheduler_classes
        from omegaconf import OmegaConf, DictConfig
        import json

        # optim
        logging.info("Build optim")
        optim = kwargs.get("optim", "adam")
        assert optim in optim_classes
        optim_class = optim_classes.get(optim)
        optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))

        # scheduler
        logging.info("Build scheduler")
        scheduler = kwargs.get("scheduler", "warmuplr")
        assert scheduler in scheduler_classes
        scheduler_class = scheduler_classes.get(scheduler)
        scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))

        if self.use_deepspeed:
            import deepspeed

            args = OmegaConf.create({"deepspeed_config": self.deepspeed_config})
View full source on GitHub →
class

AudioDataset

funasr.datasets.datasets · View on GitHub ↗

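Torch Dataset for audio training data. It is backed by an index dataset (selected via index_ds) and provides a collator that pads tensor fields into batches.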

📄 Source code
class AudioDataset(torch.utils.data.Dataset):
    """
    AudioDataset
    """

    def __init__(
        self,
        path,
        index_ds: str = None,
        frontend=None,
        tokenizer=None,
        is_training: bool = True,
        int_pad_value: int = -1,
        float_pad_value: float = 0.0,
        **kwargs,
    ):
        """Initialize AudioDataset.
        
            Args:
                path: TODO.
                index_ds: TODO.
                frontend: Audio frontend for feature extraction.
                tokenizer: Tokenizer instance for text encoding/decoding.
                is_training: Boolean flag for training.
                int_pad_value: TODO.
                float_pad_value: TODO.
                **kwargs: Additional keyword arguments.
            """
        super().__init__()
        index_ds_class = tables.index_ds_classes.get(index_ds)
View full source on GitHub →

Methods

.get_source_len(index) L67

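Return the source length of the sample at the given index, as reported by the index dataset.

Args:

  • index — Position of the sample in the index dataset.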
📄 Source
    def get_source_len(self, index):
        """Return the source-side length of one sample.

        Args:
            index: Position of the sample in ``self.index_ds``.

        Returns:
            The value ``self.index_ds.get_source_len`` reports for that entry.
        """
        ds = self.index_ds
        entry = ds[index]
        return ds.get_source_len(entry)
.get_target_len(index) L76

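Return the target length of the sample at the given index, as reported by the index dataset.

Args:

  • index — Position of the sample in the index dataset.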
📄 Source
    def get_target_len(self, index):
        """Return the target-side length of one sample.

        Args:
            index: Position of the sample in ``self.index_ds``.

        Returns:
            The value ``self.index_ds.get_target_len`` reports for that entry.
        """
        ds = self.index_ds
        entry = ds[index]
        return ds.get_target_len(entry)
.collator(samples) L127

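Collate a list of sample dicts into one batch dict. Values are grouped per key. Tensor groups are padded with pad_sequence (batch_first=True), using int_pad_value for int32/int64 tensors and float_pad_value otherwise. Non-tensor groups stay lists.

Args:

  • samples — List of per-sample dicts.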
📄 Source
    def collator(self, samples: list = None):
        """Merge a list of per-sample dicts into a single batch dict.

        Values are grouped by key in first-seen order. Every group whose first
        element is a ``torch.Tensor`` is padded into one tensor with
        ``pad_sequence(batch_first=True)``: int32/int64 tensors are padded with
        ``self.int_pad_value``, all other dtypes with ``self.float_pad_value``.
        Non-tensor groups stay plain lists.

        Args:
            samples: List of sample dicts.

        Returns:
            Dict mapping each key to a padded tensor or a list of raw values.
        """
        batch = {}
        for sample in samples:
            for name, value in sample.items():
                batch.setdefault(name, []).append(value)

        for name in batch:
            column = batch[name]
            head = column[0]
            if not isinstance(head, torch.Tensor):
                continue
            is_int = head.dtype == torch.int64 or head.dtype == torch.int32
            fill = self.int_pad_value if is_int else self.float_pad_value
            batch[name] = torch.nn.utils.rnn.pad_sequence(
                column, batch_first=True, padding_value=fill
            )
        return batch
method

AudioDataset.get_source_len(index)

funasr.datasets.datasets.AudioDataset · View on GitHub ↗

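Return the source length of the sample at the given index, as reported by the index dataset.

Args:

  • index — Position of the sample in the index dataset.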
📄 Source code
    def get_source_len(self, index):
        """Return the source-side length of one sample.

        Args:
            index: Position of the sample in ``self.index_ds``.

        Returns:
            The value ``self.index_ds.get_source_len`` reports for that entry.
        """
        ds = self.index_ds
        entry = ds[index]
        return ds.get_source_len(entry)
method

AudioDataset.get_target_len(index)

funasr.datasets.datasets.AudioDataset · View on GitHub ↗

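Return the target length of the sample at the given index, as reported by the index dataset.

Args:

  • index — Position of the sample in the index dataset.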
📄 Source code
    def get_target_len(self, index):
        """Return the target-side length of one sample.

        Args:
            index: Position of the sample in ``self.index_ds``.

        Returns:
            The value ``self.index_ds.get_target_len`` reports for that entry.
        """
        ds = self.index_ds
        entry = ds[index]
        return ds.get_target_len(entry)
method

AudioDataset.collator(samples)

funasr.datasets.datasets.AudioDataset · View on GitHub ↗

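Collate a list of sample dicts into one batch dict. Values are grouped per key. Tensor groups are padded with pad_sequence (batch_first=True), using int_pad_value for int32/int64 tensors and float_pad_value otherwise. Non-tensor groups stay lists.

Args:

  • samples — List of per-sample dicts.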
📄 Source code
    def collator(self, samples: list = None):
        """Merge a list of per-sample dicts into a single batch dict.

        Values are grouped by key in first-seen order. Every group whose first
        element is a ``torch.Tensor`` is padded into one tensor with
        ``pad_sequence(batch_first=True)``: int32/int64 tensors are padded with
        ``self.int_pad_value``, all other dtypes with ``self.float_pad_value``.
        Non-tensor groups stay plain lists.

        Args:
            samples: List of sample dicts.

        Returns:
            Dict mapping each key to a padded tensor or a list of raw values.
        """
        batch = {}
        for sample in samples:
            for name, value in sample.items():
                batch.setdefault(name, []).append(value)

        for name in batch:
            column = batch[name]
            head = column[0]
            if not isinstance(head, torch.Tensor):
                continue
            is_int = head.dtype == torch.int64 or head.dtype == torch.int32
            fill = self.int_pad_value if is_int else self.float_pad_value
            batch[name] = torch.nn.utils.rnn.pad_sequence(
                column, batch_first=True, padding_value=fill
            )
        return batch
class

AudioDatasetHotword

funasr.datasets.datasets · View on GitHub ↗

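Variant of AudioDataset used for fine-tuning contextual_paraformer and seaco_paraformer (hotword-aware samples, with an extra seaco_id option).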

📄 Source code
class AudioDatasetHotword(AudioDataset):
    # for finetuning contextual_paraformer and seaco_paraformer
    def __init__(
        self,
        *args,
        seaco_id: bool = 0,
        **kwargs,
    ):
        """Initialize AudioDatasetHotword.
        
            Args:
                *args: Variable positional arguments.
                **kwargs: Additional keyword arguments.
            """
        super().__init__(*args, **kwargs)
        self.seaco_id = seaco_id

    def __getitem__(self, index):
        """Internal: getitem  .
        
            Args:
                index: TODO.
            """
        item = self.index_ds[index]
        # import pdb;
        # pdb.set_trace()
        source = item["source"]
        data_src = load_audio_text_image_video(source, fs=self.fs)
        if self.preprocessor_speech:
            data_src = self.preprocessor_speech(data_src, fs=self.fs)
View full source on GitHub →

Methods

.collator(samples) L268

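Collate a list of sample dicts into a batch, padding tensor fields; `seaco_id` and `hotword_indx` are handled separately from the other keys.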

Args:

  • samples — TODO.
📄 Source
    def collator(self, samples: list = None):
        """Collator.
        
            Args:
                samples: TODO.
            """
        outputs = {}
        hotword_indxs = []
        seaco_id = samples[0]["seaco_id"]
        for sample in samples:
            for key in sample.keys():
                if key == "seaco_id":
                    continue
                elif key == "hotword_indx":
                    hotword_indxs.append(sample[key])
                else:
                    if key not in outputs:
                        outputs[key] = []
                    outputs[key].append(sample[key])

        for key, data_list in outputs.items():
            if isinstance(data_list[0], torch.Tensor):
                if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:
                    pad_value = self.int_pad_value
                else:
                    pad_value = self.float_pad_value
                outputs[key] = torch.nn.utils.rnn.pad_sequence(
                    data_list, batch_first=True, padding_value=pad_value
                )

View full source →
method

AudioDatasetHotword.collator(samples)

funasr.datasets.datasets.AudioDatasetHotword · View on GitHub ↗

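Collate a list of sample dicts into a batch, padding tensor fields; `seaco_id` and `hotword_indx` are handled separately from the other keys.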

Args:

  • samples — TODO.
📄 Source code
    def collator(self, samples: list = None):
        """Collator.
        
            Args:
                samples: TODO.
            """
        outputs = {}
        hotword_indxs = []
        seaco_id = samples[0]["seaco_id"]
        for sample in samples:
            for key in sample.keys():
                if key == "seaco_id":
                    continue
                elif key == "hotword_indx":
                    hotword_indxs.append(sample[key])
                else:
                    if key not in outputs:
                        outputs[key] = []
                    outputs[key].append(sample[key])

        for key, data_list in outputs.items():
            if isinstance(data_list[0], torch.Tensor):
                if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:
                    pad_value = self.int_pad_value
                else:
                    pad_value = self.float_pad_value
                outputs[key] = torch.nn.utils.rnn.pad_sequence(
                    data_list, batch_first=True, padding_value=pad_value
                )

View full source on GitHub →
function

EspnetStyleBatchSampler_fn(dataset, **kwargs)

funasr.datasets.espnet_samplers · View on GitHub ↗

Build DataLoader keyword arguments (`batch_sampler`, `num_workers`, `pin_memory`) around an `EspnetStyleBatchSampler`.

Args:

  • dataset — TODO.
  • **kwargs — Additional keyword arguments.
📄 Source code
def EspnetStyleBatchSampler_fn(dataset, **kwargs):
    """Build DataLoader keyword arguments around an ``EspnetStyleBatchSampler``.

    Args:
        dataset: Dataset handed to the sampler.
        **kwargs: Forwarded unchanged to ``EspnetStyleBatchSampler``. The
            ``num_workers`` (default 4) and ``pin_memory`` (default True) keys are
            also read here.

    Returns:
        dict with ``batch_sampler``, ``num_workers`` and ``pin_memory`` entries.
    """
    return {
        "batch_sampler": EspnetStyleBatchSampler(dataset, **kwargs),
        "num_workers": kwargs.get("num_workers", 4),
        "pin_memory": kwargs.get("pin_memory", True),
    }
class

EspnetStyleBatchSampler

funasr.datasets.espnet_samplers · View on GitHub ↗

No documentation yet.

📄 Source code
class EspnetStyleBatchSampler(DistributedSampler):
    def __init__(
        self,
        dataset,
        batch_size,
        batch_type="token",
        rank=None,
        num_replicas=None,
        rank_split=False,
        shuffle=True,
        drop_last=False,
        is_training: bool = True,
        sort_size: int = 1024,
        start_step: int = 0,
        **kwargs,
    ):

        """Initialize EspnetStyleBatchSampler.
        
            Args:
                dataset: TODO.
                batch_size: Number of samples per batch.
                batch_type: TODO.
                rank: TODO.
                num_replicas: TODO.
                rank_split: TODO.
                shuffle: TODO.
                drop_last: TODO.
                is_training: Boolean flag for training.
                sort_size: Size/dimension parameter.
View full source on GitHub →

Methods

.set_epoch(epoch) L187

Set epoch.

Args:

  • epoch — TODO.
📄 Source
    def set_epoch(self, epoch):
        """Record the current epoch, which is used for shuffling.

        Args:
            epoch: Epoch number, stored as ``self.epoch``.
        """
        self.epoch = epoch
method

EspnetStyleBatchSampler.set_epoch(epoch)

funasr.datasets.espnet_samplers.EspnetStyleBatchSampler · View on GitHub ↗

Set epoch.

Args:

  • epoch — TODO.
📄 Source code
    def set_epoch(self, epoch):
        """Record the current epoch, which is used for shuffling.

        Args:
            epoch: Epoch number, stored as ``self.epoch``.
        """
        self.epoch = epoch
class

IndexDSJsonlRankFull

funasr.datasets.index_ds · View on GitHub ↗

No documentation yet.

📄 Source code
class IndexDSJsonlRankFull(torch.utils.data.Dataset):

    def __init__(self, path: str, **kwargs):
        """Initialize IndexDSJsonlRankFull.
        
            Args:
                path: TODO.
                **kwargs: Additional keyword arguments.
            """
        super().__init__()
        self.max_source_length = kwargs.get("max_source_length", 2048)
        self.min_source_length = kwargs.get("min_source_length", 0)
        self.max_target_length = kwargs.get("max_target_length", 2048)
        self.min_target_length = kwargs.get("min_target_length", 0)
        self.max_token_length = kwargs.get("max_token_length", 2200)

        is_training = kwargs.get("is_training", True)
        if not (path.endswith(".jsonl") or path.endswith(".json")):
            # jsonl list file
            data_split_num = kwargs.get("data_split_num", 1)
            data_split_i = kwargs.get("data_split_i", 0)

            if not is_training:
                data_split_num = 1
                data_split_i = 0
            with open(path, encoding="utf-8") as fin:
                file_list_all = fin.readlines()

                num_per_slice = (len(file_list_all) - 1) // data_split_num + 1  # 16
                file_list = file_list_all[
View full source on GitHub →

Methods

.get_source_len(data_dict) L158

Get source len.

Args:

  • data_dict — TODO.
📄 Source
    def get_source_len(self, data_dict):
        """Return the ``"source_len"`` entry of an index record.

        Args:
            data_dict: One index entry (a dict).

        Returns:
            ``data_dict["source_len"]``, or 1 when the key is absent.
        """
        if "source_len" in data_dict:
            return data_dict["source_len"]
        return 1
.get_target_len(data_dict) L166

Get target len.

Args:

  • data_dict — TODO.
📄 Source
    def get_target_len(self, data_dict):
        """Return the ``"target_len"`` entry of an index record.

        Args:
            data_dict: One index entry (a dict).

        Returns:
            ``data_dict["target_len"]``, or 0 when the key is absent.
        """
        if "target_len" in data_dict:
            return data_dict["target_len"]
        return 0
method

IndexDSJsonlRankFull.get_source_len(data_dict)

funasr.datasets.index_ds.IndexDSJsonlRankFull · View on GitHub ↗

Get source len.

Args:

  • data_dict — TODO.
📄 Source code
    def get_source_len(self, data_dict):
        """Return the ``"source_len"`` entry of an index record.

        Args:
            data_dict: One index entry (a dict).

        Returns:
            ``data_dict["source_len"]``, or 1 when the key is absent.
        """
        if "source_len" in data_dict:
            return data_dict["source_len"]
        return 1
method

IndexDSJsonlRankFull.get_target_len(data_dict)

funasr.datasets.index_ds.IndexDSJsonlRankFull · View on GitHub ↗

Get target len.

Args:

  • data_dict — TODO.
📄 Source code
    def get_target_len(self, data_dict):
        """Return the ``"target_len"`` entry of an index record.

        Args:
            data_dict: One index entry (a dict).

        Returns:
            ``data_dict["target_len"]``, or 0 when the key is absent.
        """
        if "target_len" in data_dict:
            return data_dict["target_len"]
        return 0
function

gen_scp_from_jsonl(jsonl_file, data_type_list, wav_scp_file, text_file)

funasr.datasets.jsonl2scp · View on GitHub ↗

Write a `wav.scp` file and a text file from a JSONL dataset list.

Args:

  • jsonl_file — TODO.
  • data_type_list — TODO.
  • wav_scp_file — TODO.
  • text_file — TODO.
📄 Source code
def gen_scp_from_jsonl(jsonl_file, data_type_list, wav_scp_file, text_file):

    """Gen scp from jsonl.
    
        Args:
            jsonl_file: TODO.
            data_type_list: TODO.
            wav_scp_file: TODO.
            text_file: TODO.
        """
    wav_f = open(wav_scp_file, "w")
    text_f = open(text_file, "w")
    with open(jsonl_file, encoding="utf-8") as fin:
        for line in fin:
            data = json.loads(line.strip())

            prompt = data.get("prompt", "<ASR>")
            source = data[data_type_list[0]]
            target = data[data_type_list[1]]
            source_len = data.get("source_len", 1)
            target_len = data.get("target_len", 0)
            if "aishell" in source:
                target = target.replace(" ", "")
            key = data["key"]
            wav_f.write(f"{key}\t{source}\n")
            wav_f.flush()
            text_f.write(f"{key}\t{target}\n")
            text_f.flush()

    wav_f.close()
View full source on GitHub →
function

main_hydra(cfg)

funasr.datasets.jsonl2scp · View on GitHub ↗

Hydra entry point: convert a JSONL dataset list into `wav.scp` and text files.

Args:

  • cfg — Configuration overrides.
📄 Source code
def main_hydra(cfg: DictConfig):
    """Hydra entry point: convert a JSONL dataset list into scp/text files.

    The following config keys are read (the defaults are developer-local paths):

    * ``scp_file_list``: the two output paths (wav scp file, text file). A string
      value is evaluated into a sequence.
    * ``data_type_list``: the JSONL field names to use as source and target.
    * ``jsonl_file_in``: the input JSONL file.

    Args:
        cfg: Hydra/OmegaConf configuration.
    """
    options = OmegaConf.to_container(cfg, resolve=True)
    print(options)

    default_outputs = (
        "/Users/zhifu/funasr1.0/test_local/wav.scp",
        "/Users/zhifu/funasr1.0/test_local/text.txt",
    )
    scp_file_list = options.get("scp_file_list", default_outputs)
    if isinstance(scp_file_list, str):
        # NOTE(review): eval() executes arbitrary code taken from the command line.
        # If only list/tuple literals are expected, ast.literal_eval is the safe choice.
        scp_file_list = eval(scp_file_list)
    data_type_list = options.get("data_type_list", ("source", "target"))
    jsonl_file = options.get(
        "jsonl_file_in", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl"
    )
    wav_scp_file, text_file = scp_file_list[0], scp_file_list[1:]
    gen_scp_from_jsonl(jsonl_file, data_type_list, wav_scp_file, *text_file)
class

SpeechPreprocessSpeedPerturb

funasr.datasets.preprocessor · View on GitHub ↗

No documentation yet.

📄 Source code
class SpeechPreprocessSpeedPerturb(nn.Module):
    def __init__(self, speed_perturb: list = None, **kwargs):
        """Initialize SpeechPreprocessSpeedPerturb.
        
            Args:
                speed_perturb: TODO.
                **kwargs: Additional keyword arguments.
            """
        super().__init__()
        self.speed_perturb = speed_perturb

    def forward(self, waveform, fs, **kwargs):
        """Forward pass for training.
        
            Args:
                waveform: TODO.
                fs: TODO.
                **kwargs: Additional keyword arguments.
            """
        if self.speed_perturb is None:
            return waveform
        speed = random.choice(self.speed_perturb)
        if speed != 1.0:
            if not isinstance(waveform, torch.Tensor):
                waveform = torch.tensor(waveform)
            waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
                waveform.view(1, -1), fs, [["speed", str(speed)], ["rate", str(fs)]]
            )
            waveform = waveform.view(-1)

View full source on GitHub →

Methods

.forward(waveform, fs, **kwargs) L30

Randomly apply speed perturbation to a waveform.

Args:

  • waveform — TODO.
  • fs — TODO.
  • **kwargs — Additional keyword arguments.
📄 Source
    def forward(self, waveform, fs, **kwargs):
        """Randomly speed-perturb ``waveform``.

        One factor is drawn with ``random.choice`` from ``self.speed_perturb``.
        For any factor other than 1.0:

        * the waveform is converted to a ``torch.Tensor`` if needed;
        * it is passed through the sox ``speed`` effect, then ``rate`` to resample
          back to ``fs``;
        * it is flattened to 1-D.

        Otherwise the input is returned untouched.

        Args:
            waveform: Audio samples. Anything ``torch.tensor`` accepts is allowed
                when it is not already a tensor. It is reshaped with
                ``view(1, -1)``, so presumably mono (TODO confirm).
            fs: Sampling rate in Hz.
            **kwargs: Ignored.

        Returns:
            The perturbed waveform as a 1-D tensor, or the unchanged input.
        """
        # An empty factor list means "no perturbation", just like None;
        # random.choice([]) would otherwise raise IndexError.
        if self.speed_perturb is None or len(self.speed_perturb) == 0:
            return waveform
        speed = random.choice(self.speed_perturb)
        if speed != 1.0:
            if not isinstance(waveform, torch.Tensor):
                waveform = torch.tensor(waveform)
            waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
                waveform.view(1, -1), fs, [["speed", str(speed)], ["rate", str(fs)]]
            )
            waveform = waveform.view(-1)

        return waveform
method

SpeechPreprocessSpeedPerturb.forward(waveform, fs, **kwargs)

funasr.datasets.preprocessor.SpeechPreprocessSpeedPerturb · View on GitHub ↗

Randomly apply speed perturbation to a waveform.

Args:

  • waveform — TODO.
  • fs — TODO.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def forward(self, waveform, fs, **kwargs):
        """Randomly speed-perturb ``waveform``.

        One factor is drawn with ``random.choice`` from ``self.speed_perturb``.
        For any factor other than 1.0:

        * the waveform is converted to a ``torch.Tensor`` if needed;
        * it is passed through the sox ``speed`` effect, then ``rate`` to resample
          back to ``fs``;
        * it is flattened to 1-D.

        Otherwise the input is returned untouched.

        Args:
            waveform: Audio samples. Anything ``torch.tensor`` accepts is allowed
                when it is not already a tensor. It is reshaped with
                ``view(1, -1)``, so presumably mono (TODO confirm).
            fs: Sampling rate in Hz.
            **kwargs: Ignored.

        Returns:
            The perturbed waveform as a 1-D tensor, or the unchanged input.
        """
        # An empty factor list means "no perturbation", just like None;
        # random.choice([]) would otherwise raise IndexError.
        if self.speed_perturb is None or len(self.speed_perturb) == 0:
            return waveform
        speed = random.choice(self.speed_perturb)
        if speed != 1.0:
            if not isinstance(waveform, torch.Tensor):
                waveform = torch.tensor(waveform)
            waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
                waveform.view(1, -1), fs, [["speed", str(speed)], ["rate", str(fs)]]
            )
            waveform = waveform.view(-1)

        return waveform
class

TextPreprocessSegDict

funasr.datasets.preprocessor · View on GitHub ↗

No documentation yet.

📄 Source code
class TextPreprocessSegDict(nn.Module):
    """Text preprocessor that runs its input through a ``TextCleaner``.

    ``seg_dict`` and ``split_with_space`` are accepted but not used by this class.
    NOTE(review): confirm whether segmentation with ``seg_dict`` was intended.
    """

    def __init__(
        self,
        seg_dict: str = None,
        text_cleaner: Collection[str] = None,
        split_with_space: bool = False,
        **kwargs
    ):
        """Create the preprocessor.

        Args:
            seg_dict: Unused.
            text_cleaner: Cleaner configuration passed to ``TextCleaner``.
            split_with_space: Unused.
            **kwargs: Ignored.
        """
        super().__init__()
        self.text_cleaner = TextCleaner(text_cleaner)

    def forward(self, text, **kwargs):
        """Return ``text`` after applying ``self.text_cleaner``.

        Args:
            text: Input text.
            **kwargs: Ignored.
        """
        cleaned = self.text_cleaner(text)
        return cleaned

Methods

.forward(text, **kwargs) L73

Clean the input text with the configured `TextCleaner`.

Args:

  • text — Text tensor or string input.
  • **kwargs — Additional keyword arguments.
📄 Source
    def forward(self, text, **kwargs):
        """Return ``text`` after applying ``self.text_cleaner``.

        Args:
            text: Input text.
            **kwargs: Ignored.
        """
        cleaned = self.text_cleaner(text)
        return cleaned
method

TextPreprocessSegDict.forward(text, **kwargs)

funasr.datasets.preprocessor.TextPreprocessSegDict · View on GitHub ↗

Clean the input text with the configured `TextCleaner`.

Args:

  • text — Text tensor or string input.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def forward(self, text, **kwargs):
        """Return ``text`` after applying ``self.text_cleaner``.

        Args:
            text: Input text.
            **kwargs: Ignored.
        """
        cleaned = self.text_cleaner(text)
        return cleaned
function

CustomDistributedBatchSampler_fn(dataset, **kwargs)

funasr.datasets.samplers · View on GitHub ↗

Build DataLoader keyword arguments using a batch sampler chosen by `batch_type` and `sort_size`.

Args:

  • dataset — TODO.
  • **kwargs — Additional keyword arguments.
📄 Source code
def CustomDistributedBatchSampler_fn(dataset, **kwargs):
    """Build DataLoader keyword arguments with a batch sampler picked from ``kwargs``.

    Sampler selection:

    * ``batch_type == "example"`` (the default): ``CustomDistributedBatchSampler``
    * otherwise, ``sort_size > 0``: ``CustomDistributedBufferDynamicBatchSampler``
    * otherwise: ``CustomDistributedDynamicBatchSampler``

    Args:
        dataset: Dataset handed to the sampler.
        **kwargs: Forwarded unchanged to the chosen sampler. ``num_workers``
            (default 4) and ``pin_memory`` (default True) are also read here.

    Returns:
        dict with ``batch_sampler``, ``num_workers`` and ``pin_memory`` entries.
    """
    if kwargs.get("batch_type", "example") == "example":
        sampler_cls = CustomDistributedBatchSampler
    elif kwargs.get("sort_size", -1) > 0:
        sampler_cls = CustomDistributedBufferDynamicBatchSampler
    else:
        sampler_cls = CustomDistributedDynamicBatchSampler

    return {
        "batch_sampler": sampler_cls(dataset, **kwargs),
        "num_workers": kwargs.get("num_workers", 4),
        "pin_memory": kwargs.get("pin_memory", True),
    }
class

CustomDistributedBatchSampler

funasr.datasets.samplers · View on GitHub ↗

No documentation yet.

📄 Source code
class CustomDistributedBatchSampler(Sampler):
    def __init__(
        self,
        dataset,
        batch_size,
        num_replicas=None,
        rank=None,
        shuffle=True,
        drop_last=False,
        is_training: bool = True,
        **kwargs,
    ):

        """Initialize CustomDistributedBatchSampler.
        
            Args:
                dataset: TODO.
                batch_size: Number of samples per batch.
                num_replicas: TODO.
                rank: TODO.
                shuffle: TODO.
                drop_last: TODO.
                is_training: Boolean flag for training.
                **kwargs: Additional keyword arguments.
            """
        try:
            rank = dist.get_rank()
            num_replicas = dist.get_world_size()
        except:
            rank = 0
View full source on GitHub →

Methods

.set_epoch(epoch) L148

Set epoch.

Args:

  • epoch — TODO.
📄 Source
    def set_epoch(self, epoch):
        """Record the current epoch on the sampler.

        Args:
            epoch: Epoch number, stored as ``self.epoch``.
        """
        self.epoch = epoch
method

CustomDistributedBatchSampler.set_epoch(epoch)

funasr.datasets.samplers.CustomDistributedBatchSampler · View on GitHub ↗

Set epoch.

Args:

  • epoch — TODO.
📄 Source code
    def set_epoch(self, epoch):
        """Record the current epoch on the sampler.

        Args:
            epoch: Epoch number, stored as ``self.epoch``.
        """
        self.epoch = epoch
class

CustomDistributedBufferBatchSampler

funasr.datasets.samplers · View on GitHub ↗

No documentation yet.

📄 Source code
class CustomDistributedBufferBatchSampler(Sampler):
    def __init__(
        self,
        dataset,
        batch_size,
        num_replicas=None,
        rank=None,
        shuffle=True,
        drop_last=False,
        is_training: bool = True,
        sort_size: int = 1024,
        **kwargs,
    ):

        """Initialize CustomDistributedBufferBatchSampler.
        
            Args:
                dataset: TODO.
                batch_size: Number of samples per batch.
                num_replicas: TODO.
                rank: TODO.
                shuffle: TODO.
                drop_last: TODO.
                is_training: Boolean flag for training.
                sort_size: Size/dimension parameter.
                **kwargs: Additional keyword arguments.
            """
        try:
            rank = dist.get_rank()
            num_replicas = dist.get_world_size()
View full source on GitHub →

Methods

.set_epoch(epoch) L284

Set epoch.

Args:

  • epoch — TODO.
📄 Source
    def set_epoch(self, epoch):
        """Record the current epoch on the sampler.

        Args:
            epoch: Epoch number, stored as ``self.epoch``.
        """
        self.epoch = epoch
method

CustomDistributedBufferBatchSampler.set_epoch(epoch)

funasr.datasets.samplers.CustomDistributedBufferBatchSampler · View on GitHub ↗

Set epoch.

Args:

  • epoch — TODO.
📄 Source code
    def set_epoch(self, epoch):
        """Record the current epoch on the sampler.

        Args:
            epoch: Epoch number, stored as ``self.epoch``.
        """
        self.epoch = epoch
class

CustomDistributedDynamicBatchSampler

funasr.datasets.samplers · View on GitHub ↗

No documentation yet.

📄 Source code
class CustomDistributedDynamicBatchSampler(DistributedSampler):
    def __init__(
        self,
        dataset,
        batch_size,
        num_replicas=None,
        rank=None,
        shuffle=True,
        drop_last=False,
        is_training: bool = True,
        **kwargs,
    ):

        """Initialize CustomDistributedDynamicBatchSampler.
        
            Args:
                dataset: TODO.
                batch_size: Number of samples per batch.
                num_replicas: TODO.
                rank: TODO.
                shuffle: TODO.
                drop_last: TODO.
                is_training: Boolean flag for training.
                **kwargs: Additional keyword arguments.
            """
        try:
            rank = dist.get_rank()
            num_replicas = dist.get_world_size()
        except:
            rank = 0
View full source on GitHub →

Methods

.set_epoch(epoch) L384

Set epoch.

Args:

  • epoch — TODO.
📄 Source
    def set_epoch(self, epoch):
        """Record the current epoch on the sampler.

        Args:
            epoch: Epoch number, stored as ``self.epoch``.
        """
        self.epoch = epoch
method

CustomDistributedDynamicBatchSampler.set_epoch(epoch)

funasr.datasets.samplers.CustomDistributedDynamicBatchSampler · View on GitHub ↗

Set epoch.

Args:

  • epoch — TODO.
📄 Source code
    def set_epoch(self, epoch):
        """Record the current epoch on the sampler.

        Args:
            epoch: Epoch number, stored as ``self.epoch``.
        """
        self.epoch = epoch
class

CustomDistributedBufferDynamicBatchSampler

funasr.datasets.samplers · View on GitHub ↗

No documentation yet.

📄 Source code
class CustomDistributedBufferDynamicBatchSampler(DistributedSampler):
    def __init__(
        self,
        dataset,
        batch_size,
        batch_type="token",
        num_replicas=None,
        rank=None,
        rank_split=False,
        shuffle=True,
        drop_last=False,
        is_training: bool = True,
        sort_size: int = 1024,
        start_step: int = 0,
        **kwargs,
    ):

        """Initialize CustomDistributedBufferDynamicBatchSampler.
        
            Args:
                dataset: TODO.
                batch_size: Number of samples per batch.
                batch_type: TODO.
                num_replicas: TODO.
                rank: TODO.
                rank_split: TODO.
                shuffle: TODO.
                drop_last: TODO.
                is_training: Boolean flag for training.
                sort_size: Size/dimension parameter.
View full source on GitHub →

Methods

.set_epoch(epoch) L526

Set epoch.

Args:

  • epoch — TODO.
📄 Source
    def set_epoch(self, epoch):
        """Record the current epoch on the sampler.

        Args:
            epoch: Epoch number, stored as ``self.epoch``.
        """
        self.epoch = epoch
method

CustomDistributedBufferDynamicBatchSampler.set_epoch(epoch)

funasr.datasets.samplers.CustomDistributedBufferDynamicBatchSampler · View on GitHub ↗

Set epoch.

Args:

  • epoch — TODO.
📄 Source code
    def set_epoch(self, epoch):
        """Record the current epoch on the sampler.

        Args:
            epoch: Epoch number, stored as ``self.epoch``.
        """
        self.epoch = epoch
class

DistributedSamplerWarp

funasr.datasets.samplers · View on GitHub ↗

No documentation yet.

📄 Source code
class DistributedSamplerWarp(BatchSampler):
    def __init__(
        self, dataset, batch_size, num_replicas=None, rank=None, shuffle=True, drop_last=False
    ):
        """Initialize DistributedSamplerWarp.
        
            Args:
                dataset: TODO.
                batch_size: Number of samples per batch.
                num_replicas: TODO.
                rank: TODO.
                shuffle: TODO.
                drop_last: TODO.
            """
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()

        self.dataset = dataset
        self.batch_size = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.drop_last = drop_last

View full source on GitHub →

Methods

.set_epoch(epoch) L582

Set epoch.

Args:

  • epoch — TODO.
📄 Source
    def set_epoch(self, epoch):
        """Record the current epoch on the sampler.

        Args:
            epoch: Epoch number, stored as ``self.epoch``.
        """
        self.epoch = epoch
method

DistributedSamplerWarp.set_epoch(epoch)

funasr.datasets.samplers.DistributedSamplerWarp · View on GitHub ↗

Set epoch.

Args:

  • epoch — TODO.
📄 Source code
    def set_epoch(self, epoch):
        """Record the current epoch on the sampler.

        Args:
            epoch: Epoch number, stored as ``self.epoch``.
        """
        self.epoch = epoch
function

gen_jsonl_from_wav_text_list(path, data_type_list, jsonl_file_out, **kwargs)

funasr.datasets.scp2jsonl · View on GitHub ↗

Gen jsonl from wav text list.

Args:

  • path — TODO.
  • data_type_list — TODO.
  • jsonl_file_out — TODO.
  • **kwargs — Additional keyword arguments.
📄 Source code
def gen_jsonl_from_wav_text_list(
    path, data_type_list=("source", "target"), jsonl_file_out: str = None, **kwargs
):
    """Gen jsonl from wav text list.
    
        Args:
            path: TODO.
            data_type_list: TODO.
            jsonl_file_out: TODO.
            **kwargs: Additional keyword arguments.
        """
    try:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    except:
        rank = 0
        world_size = 1

    cpu_cores = os.cpu_count() or 1
    print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}")
    if rank == 0:
        json_dict = {}
        for data_type, data_file in zip(data_type_list, path):
            json_dict[data_type] = {}
            with open(data_file, "r") as f:

                data_file_lists = f.readlines()
                lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
                task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
                # import pdb;pdb.set_trace()
View full source on GitHub →
function

parse_context_length(data_list, data_type, id)

funasr.datasets.scp2jsonl · View on GitHub ↗

Compute per-key lengths for scp/text lines: audio (`source`) in 10 ms frames, text in whitespace tokens or characters.

Args:

  • data_list — TODO.
  • data_type — TODO.
  • id — TODO.
📄 Source code
def parse_context_length(data_list: list, data_type: str, id=0):
    """Compute a length entry for every ``key value`` line of an scp/text list.

    Each line is split on its first whitespace run into ``key`` and ``value``.
    ``value`` is "" when there is no second field. Blank lines are skipped.

    * ``data_type == "source"``: ``value`` is an audio path. It is loaded with
      ``librosa.load(..., sr=16000)`` and its length is expressed in 10 ms frames
      (``samples * 1000 / 16000 / 10``). Missing files are reported and skipped.
    * any other ``data_type``: the length is the number of whitespace-separated
      tokens when ``value`` contains a space, otherwise its number of characters.

    Args:
        data_list: Raw lines read from an scp/text file.
        data_type: Field name, e.g. "source" or "target".
        id: Worker index shown in the progress-bar description.

    Returns:
        dict mapping key -> {data_type: value, f"{data_type}_len": length}.
    """
    res = {}
    # Context manager guarantees the bar is closed even if loading audio raises.
    with tqdm(total=len(data_list), dynamic_ncols=True) as pbar:
        pbar.set_description(f"cpu: {id}")  # constant, so set it once
        for line in data_list:
            pbar.update(1)
            fields = line.strip().split(maxsplit=1)
            if not fields:
                # Blank line (e.g. trailing newline); previously raised IndexError.
                continue
            key = fields[0]
            value = fields[1].strip() if len(fields) > 1 else ""
            if data_type == "source":
                if not os.path.exists(value):
                    print("source file not found: {}".format(value))
                    continue
                waveform, _ = librosa.load(value, sr=16000)
                sample_num = len(waveform)
                context_len = int(sample_num * 1000 / 16000 / 10)
            else:
                context_len = len(value.split()) if " " in value else len(value)
            res[key] = {data_type: value, f"{data_type}_len": context_len}
    return res
function

main_hydra(cfg)

funasr.datasets.scp2jsonl · View on GitHub ↗

Hydra entry point: convert wav.scp/text files into a JSONL dataset list.

Args:

  • cfg — Configuration overrides.
📄 Source code
def main_hydra(cfg: DictConfig):
    """Hydra entry point: convert scp/text files into a JSONL dataset list.

    The following config keys are read (the defaults are developer-local paths):

    * ``scp_file_list``: the input files, paired with ``data_type_list``. A string
      value is evaluated into a sequence.
    * ``data_type_list``: the field names, default ("source", "target").
    * ``jsonl_file_out``: the output JSONL path.

    Args:
        cfg: Hydra/OmegaConf configuration.
    """
    options = OmegaConf.to_container(cfg, resolve=True)
    print(options)

    default_inputs = (
        "/Users/zhifu/funasr1.0/test_local/wav.scp",
        "/Users/zhifu/funasr1.0/test_local/text.txt",
    )
    scp_file_list = options.get("scp_file_list", default_inputs)
    if isinstance(scp_file_list, str):
        # NOTE(review): eval() executes arbitrary code taken from the command line.
        # If only list/tuple literals are expected, ast.literal_eval is the safe choice.
        scp_file_list = eval(scp_file_list)
    data_type_list = options.get("data_type_list", ("source", "target"))
    jsonl_file_out = options.get(
        "jsonl_file_out", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl"
    )
    gen_jsonl_from_wav_text_list(
        scp_file_list,
        data_type_list=data_type_list,
        jsonl_file_out=jsonl_file_out,
    )
function

gen_jsonl_from_wav_text_list(path, data_type_list, jsonl_file_out, **kwargs)

funasr.datasets.scp2len · View on GitHub ↗

Gen jsonl from wav text list.

Args:

  • path — TODO.
  • data_type_list — TODO.
  • jsonl_file_out — TODO.
  • **kwargs — Additional keyword arguments.
📄 Source code
def gen_jsonl_from_wav_text_list(
    path, data_type_list=("source",), jsonl_file_out: str = None, **kwargs
):
    """Gen jsonl from wav text list.
    
        Args:
            path: TODO.
            data_type_list: TODO.
            jsonl_file_out: TODO.
            **kwargs: Additional keyword arguments.
        """
    try:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    except:
        rank = 0
        world_size = 1

    cpu_cores = os.cpu_count() or 1
    print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}")
    if rank == 0:
        json_dict = {}
        # for data_type, data_file in zip(data_type_list, path):
        data_type = data_type_list[0]
        data_file = path
        json_dict[data_type] = {}
        with open(data_file, "r") as f:

            data_file_lists = f.readlines()
            print("")
View full source on GitHub →
function

parse_context_length(data_list, data_type, id)

funasr.datasets.scp2len · View on GitHub ↗

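Compute per-key lengths for scp lines: existing audio paths in 10 ms frames, anything else in whitespace tokens or characters.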

Args:

  • data_list — TODO.
  • data_type — TODO.
  • id — TODO.
📄 Source code
def parse_context_length(data_list: list, data_type: str, id=0):
    """Compute a length entry for every ``key value`` line of an scp list.

    Each line is split on its first whitespace run into ``key`` and ``value``.
    ``value`` is "" when there is no second field. Blank lines are skipped.

    * If ``value`` is an existing path, it is loaded with
      ``librosa.load(..., sr=16000)`` and its length is expressed in 10 ms frames
      (``samples / 16000 * 1000 / 10``).
    * Otherwise the length is the number of whitespace-separated tokens when
      ``value`` contains a space, else its number of characters.

    Args:
        data_list: Raw lines read from an scp file.
        data_type: Field name used for the output keys.
        id: Worker index shown in the progress-bar description.

    Returns:
        dict mapping key -> {data_type: value, f"{data_type}_len": length}.
    """
    res = {}
    # Context manager guarantees the bar is closed even if loading audio raises.
    with tqdm(total=len(data_list), dynamic_ncols=True) as pbar:
        pbar.set_description(f"cpu: {id}")  # constant, so set it once
        for line in data_list:
            pbar.update(1)
            fields = line.strip().split(maxsplit=1)
            if not fields:
                # Blank line (e.g. trailing newline); previously raised IndexError.
                continue
            key = fields[0]
            value = fields[1].strip() if len(fields) > 1 else ""
            if os.path.exists(value):
                waveform, _ = librosa.load(value, sr=16000)
                sample_num = len(waveform)
                context_len = int(sample_num / 16000 * 1000 / 10)
            else:
                context_len = len(value.split()) if " " in value else len(value)
            res[key] = {data_type: value, f"{data_type}_len": context_len}
    return res
function

main_hydra(cfg)

funasr.datasets.scp2len · View on GitHub ↗

Hydra entry point: compute per-utterance lengths for a wav.scp file.

Args:

  • cfg — Configuration overrides.
📄 Source code
def main_hydra(cfg: DictConfig):
    """Hydra entry point: compute per-utterance lengths for one scp file.

    The following config keys are read (the defaults are developer-local paths):

    * ``scp_file_list``: a single input scp path.
    * ``data_type_list``: default ("source",).
    * ``jsonl_file_out``: the output file.

    Args:
        cfg: Hydra/OmegaConf configuration.
    """
    options = OmegaConf.to_container(cfg, resolve=True)
    print(options)

    scp_file = options.get(
        "scp_file_list", "/Users/zhifu/funasr1.0/data/list/train_wav.scp"
    )
    data_type_list = options.get("data_type_list", ("source",))
    out_file = options.get(
        "jsonl_file_out", "/Users/zhifu/funasr1.0/data/list/wav_len.txt"
    )
    gen_jsonl_from_wav_text_list(
        scp_file,
        data_type_list=data_type_list,
        jsonl_file_out=out_file,
    )
function

gen_jsonl_from_wav_text_list(path, data_type_list, jsonl_file_out, model_dir, **kwargs)

funasr.datasets.sensevoice2jsonl · View on GitHub ↗

Gen jsonl from wav text list.

Args:

  • path — TODO.
  • data_type_list — TODO.
  • jsonl_file_out — TODO.
  • model_dir — TODO.
  • **kwargs — Additional keyword arguments.
📄 Source code
def gen_jsonl_from_wav_text_list(
    path, data_type_list=("source", "target"), jsonl_file_out: str = None, model_dir: str = "iic/SenseVoiceSmall", **kwargs
):
    """Gen jsonl from wav text list.
    
        Args:
            path: TODO.
            data_type_list: TODO.
            jsonl_file_out: TODO.
            model_dir: TODO.
            **kwargs: Additional keyword arguments.
        """
    try:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    except:
        rank = 0
        world_size = 1

    cpu_cores = os.cpu_count() or 1
    print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}")
    if rank == 0:
        json_dict = {}
        for data_type, data_file in zip(data_type_list, path):
            json_dict[data_type] = {}
            with open(data_file, "r") as f:

                data_file_lists = f.readlines()
                lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
                task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
View full source on GitHub →
function

contains_punctuation(s)

funasr.datasets.sensevoice2jsonl · View on GitHub ↗

Return True if the string contains any punctuation character.

Args:

  • s — TODO.
📄 Source code
# ASCII punctuation plus common CJK / full-width punctuation marks.
# Built once at import time; a frozenset makes each character lookup O(1).
_PUNCTUATIONS = frozenset(
    string.punctuation
    + "，。、；：？！“”‘’（）【】《》〈〉「」『』〔〕［］｛｝～·…—–"
)


def contains_punctuation(s):
    """Return True if *s* contains at least one punctuation character.

    Punctuation means ASCII ``string.punctuation`` plus common CJK and
    full-width marks (e.g. ``，。、；：？！“”‘’【】《》``).

    The previous literal contained ``""''`` inside single quotes.  That
    silently split it into two concatenated strings and dropped the intended
    quote marks.  It also lacked the full-width ``，；：？！``, so typical
    Chinese text such as ``"你好，世界"`` was not detected.

    Args:
        s: String to inspect.

    Returns:
        bool: True if any character of *s* is punctuation.
    """
    return not _PUNCTUATIONS.isdisjoint(s)
function

parse_context_length(data_list, data_type, id)

funasr.datasets.sensevoice2jsonl · View on GitHub ↗

Parse context length.

Args:

  • data_list — TODO.
  • data_type — TODO.
  • id — TODO.
📄 Source code
def parse_context_length(data_list: list, data_type: str, id=0):
    """Parse context length.
    
        Args:
            data_list: TODO.
            data_type: TODO.
            id: TODO.
        """
    pbar = tqdm(total=len(data_list), dynamic_ncols=True)
    res = {}
    for i, line in enumerate(data_list):
        pbar.update(1)
        pbar.set_description(f"cpu: {id}")
        lines = line.strip().split(maxsplit=1)
        key = lines[0]
        line = lines[1] if len(lines) > 1 else ""
        line = line.strip()
        if os.path.exists(line):
            waveform, _ = librosa.load(line, sr=16000)
            sample_num = len(waveform)
            context_len = int(sample_num / 16000 * 1000 / 10)
        else:
            context_len = len(line.split()) if " " in line else len(line)
        if data_type == "source":
            res[key] = {data_type: line, f"{data_type}_len": context_len}
        elif data_type == "target":
            punc = contains_punctuation(line)
            if punc:
                with_or_wo_itn = "<|withitn|>"
            else:
View full source on GitHub →
function

main_hydra(cfg)

funasr.datasets.sensevoice2jsonl · View on GitHub ↗

Main hydra.

Args:

  • cfg — Configuration overrides.
📄 Source code
def main_hydra(cfg: DictConfig):
    """Hydra entry point: convert scp/text files into a SenseVoice jsonl.

    Args:
        cfg: Hydra/OmegaConf config.  Recognized keys:
            ``scp_file_list`` (sequence of input files, or its Python-literal
            string form such as ``"('wav.scp', 'text.txt')"``),
            ``data_type_list`` (default ``("source", "target")``),
            ``jsonl_file_out`` and ``model_dir``
            (default ``"iic/SenseVoiceSmall"``).  Missing path keys fall back
            to hard-coded defaults.
    """
    import ast

    kwargs = OmegaConf.to_container(cfg, resolve=True)
    print(kwargs)

    scp_file_list = kwargs.get(
        "scp_file_list",
        ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"),
    )
    if isinstance(scp_file_list, str):
        # The command line may pass the list as a string literal.
        # literal_eval only accepts Python literals; eval() would execute
        # arbitrary code taken from the config/CLI.
        scp_file_list = ast.literal_eval(scp_file_list)
    data_type_list = kwargs.get("data_type_list", ("source", "target"))
    jsonl_file_out = kwargs.get(
        "jsonl_file_out", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl"
    )
    model_dir = kwargs.get("model_dir", "iic/SenseVoiceSmall")
    gen_jsonl_from_wav_text_list(
        scp_file_list, data_type_list=data_type_list, jsonl_file_out=jsonl_file_out, model_dir=model_dir
    )
function

gen_scp_from_jsonl(jsonl_file, jsonl_file_out, ncpu)

funasr.datasets.update_jsonl · View on GitHub ↗

Gen scp from jsonl.

Args:

  • jsonl_file — TODO.
  • jsonl_file_out — TODO.
  • ncpu — TODO.
📄 Source code
def gen_scp_from_jsonl(jsonl_file, jsonl_file_out, ncpu):
    """Gen scp from jsonl.
    
        Args:
            jsonl_file: TODO.
            jsonl_file_out: TODO.
            ncpu: TODO.
        """
    jsonl_file_out_f = open(jsonl_file_out, "w")
    with open(jsonl_file, encoding="utf-8") as fin:
        lines = fin.readlines()

        num_total = len(lines)
        if ncpu > 1:
            # 使用ThreadPoolExecutor限制并发线程数
            with ThreadPoolExecutor(max_workers=ncpu) as executor:
                # 提交任务到线程池
                futures = {executor.submit(update_data, lines, i) for i in tqdm(range(num_total))}

                # 等待所有任务完成,这会阻塞直到所有提交的任务完成
                for future in concurrent.futures.as_completed(futures):
                    # 这里可以添加额外的逻辑来处理完成的任务,但在这个例子中我们只是等待
                    pass
        else:
            for i in range(num_total):
                update_data(lines, i)
        logging.info("All audio durations have been processed.")

        for line in lines:

View full source on GitHub →
function

update_data(lines, i)

funasr.datasets.update_jsonl · View on GitHub ↗

Update data.

Args:

  • lines — TODO.
  • i — TODO.
📄 Source code
def update_data(lines, i):
    """Refresh the audio path and ``source_len`` of one jsonl line, in place.

    ``lines[i]`` is parsed as JSON.  Every occurrence of ``"/cpfs01"`` in its
    ``"source"`` path is replaced with ``"/cpfs_speech/data"``.

    If the rewritten path exists:
    - the audio is loaded with librosa at 16 kHz;
    - ``source_len`` is set to its duration in 10 ms frames;
    - ``source`` is set to the rewritten path;
    - ``lines[i]`` is replaced by the re-serialized JSON, without a trailing
      newline.

    Otherwise ``lines[i]`` is left untouched.  The old ``source_len`` is no
    longer read, so lines without that field no longer raise KeyError.

    Args:
        lines: Mutable list of jsonl strings; modified in place.
        i: Index of the line to update.
    """
    data = json.loads(lines[i].strip())

    # NOTE(review): hard-coded storage-mount migration.  str.replace rewrites
    # every occurrence, not only a leading prefix.
    wav_path = data["source"].replace("/cpfs01", "/cpfs_speech/data")
    if not os.path.exists(wav_path):
        return

    waveform, _ = librosa.load(wav_path, sr=16000)
    # samples -> seconds -> milliseconds -> 10 ms frames
    data["source_len"] = int(len(waveform) / 16000 * 1000 / 10)
    data["source"] = wav_path
    lines[i] = json.dumps(data, ensure_ascii=False)
function

update_wav_len(jsonl_file_list_in, jsonl_file_out_dir, ncpu)

funasr.datasets.update_jsonl · View on GitHub ↗

Update wav len.

Args:

  • jsonl_file_list_in — TODO.
  • jsonl_file_out_dir — TODO.
  • ncpu — TODO.
📄 Source code
def update_wav_len(jsonl_file_list_in, jsonl_file_out_dir, ncpu=1):
    """Run ``gen_scp_from_jsonl`` for every jsonl path listed in a file.

    Each output is written to ``jsonl_file_out_dir`` under the same file name
    as its input.  Blank or whitespace-only lines in the list file are
    skipped; previously they produced an empty input path.

    Args:
        jsonl_file_list_in: Text file with one jsonl path per line.
        jsonl_file_out_dir: Output directory; created if missing.
        ncpu: Worker count forwarded to ``gen_scp_from_jsonl``.
    """
    os.makedirs(jsonl_file_out_dir, exist_ok=True)
    # Read the list first so the file is closed before the long processing loop.
    with open(jsonl_file_list_in, "r", encoding="utf-8") as f:
        jsonl_files = [line.strip() for line in f if line.strip()]

    for i, jsonl in enumerate(jsonl_files):
        jsonl_file_out = os.path.join(jsonl_file_out_dir, os.path.basename(jsonl))
        logging.info(f"{i}/{len(jsonl_files)}, jsonl: {jsonl}, {jsonl_file_out}")
        gen_scp_from_jsonl(jsonl, jsonl_file_out, ncpu)
function

main_hydra(cfg)

funasr.datasets.update_jsonl · View on GitHub ↗

Main hydra.

Args:

  • cfg — Configuration overrides.
📄 Source code
def main_hydra(cfg: DictConfig):
    """Hydra entry point that calls ``update_wav_len`` with configured paths.

    Args:
        cfg: Hydra/OmegaConf config.  Recognized keys:
            ``jsonl_file_list_in`` (file listing input jsonl paths),
            ``jsonl_file_out_dir`` (output directory) and
            ``ncpu`` (worker count, default 1).  Missing path keys fall back
            to hard-coded defaults.
    """
    conf = OmegaConf.to_container(cfg, resolve=True)
    logging.info(conf)

    # NOTE(review): the fallback paths point at one developer's machine.
    list_path = conf.get("jsonl_file_list_in", "/Users/zhifu/funasr1.0/data/list/data_jsonl.list")
    out_dir = conf.get("jsonl_file_out_dir", "/Users/zhifu/funasr1.0/data_tmp")
    workers = conf.get("ncpu", 1)
    update_wav_len(list_path, out_dir, workers)
function

DataloaderMapStyle(frontend, tokenizer, **kwargs)

funasr.datasets.dataloader_entry · View on GitHub ↗

Build map-style training and validation datasets and their dataloaders.

Args:

  • frontend — Audio frontend for feature extraction.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • **kwargs — Additional keyword arguments.
📄 Source code
def DataloaderMapStyle(frontend=None, tokenizer=None, **kwargs):
    # dataset
    """Dataloadermapstyle.
    
        Args:
            frontend: Audio frontend for feature extraction.
            tokenizer: Tokenizer instance for text encoding/decoding.
            **kwargs: Additional keyword arguments.
        """
    logging.info("Build dataloader")
    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
    dataset_tr = dataset_class(
        kwargs.get("train_data_set_list"),
        frontend=frontend,
        tokenizer=tokenizer,
        is_training=True,
        **kwargs.get("dataset_conf"),
    )
    dataset_val = dataset_class(
        kwargs.get("valid_data_set_list"),
        frontend=frontend,
        tokenizer=tokenizer,
        is_training=False,
        **kwargs.get("dataset_conf"),
    )

    # dataloader
    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "BatchSampler")
    batch_sampler_val = None
    if batch_sampler is not None:
View full source on GitHub →
class

DataloaderMapStyle

funasr.datasets.dataloader_entry · View on GitHub ↗

No documentation yet.

📄 Source code
class DataloaderMapStyle:
    def __init__(self, frontend=None, tokenizer=None, **kwargs):
        # dataset
        """Initialize DataloaderMapStyle.
        
            Args:
                frontend: Audio frontend for feature extraction.
                tokenizer: Tokenizer instance for text encoding/decoding.
                **kwargs: Additional keyword arguments.
            """
        logging.info("Build dataloader")

        dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
        dataset_tr = None
        # split dataset
        self.data_split_num = kwargs["dataset_conf"].get("data_split_num", 1)
        if self.data_split_num == 1:
            dataset_tr = dataset_class(
                kwargs.get("train_data_set_list"),
                frontend=frontend,
                tokenizer=tokenizer,
                is_training=True,
                **kwargs.get("dataset_conf"),
            )
        dataset_val = dataset_class(
            kwargs.get("valid_data_set_list"),
            frontend=frontend,
            tokenizer=tokenizer,
            is_training=False,
            **kwargs.get("dataset_conf"),
View full source on GitHub →

Methods

.build_iter(epoch, data_split_i, start_step, **kwargs) L96

Build iter.

Args:

  • epoch — TODO.
  • data_split_i — TODO.
  • start_step — TODO.
  • **kwargs — Additional keyword arguments.
📄 Source
    def build_iter(self, epoch=0, data_split_i=0, start_step=0, **kwargs):

        # reload dataset slice
        """Build iter.
        
            Args:
                epoch: TODO.
                data_split_i: TODO.
                start_step: TODO.
                **kwargs: Additional keyword arguments.
            """
        if self.data_split_num > 1:
            del self.dataset_tr
            self.dataset_tr = self.dataset_class(
                self.kwargs.get("train_data_set_list"),
                frontend=self.frontend,
                tokenizer=self.tokenizer,
                is_training=True,
                **self.kwargs.get("dataset_conf"),
                data_split_i=data_split_i,
            )

        # dataloader
        batch_sampler = self.kwargs["dataset_conf"].get("batch_sampler", "BatchSampler")
        batch_sampler_val = None
        if batch_sampler is not None:
            batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
            batch_sampler = batch_sampler_class(
                self.dataset_tr, start_step=start_step, **self.kwargs.get("dataset_conf")
            )
View full source →
method

DataloaderMapStyle.build_iter(epoch, data_split_i, start_step, **kwargs)

funasr.datasets.dataloader_entry.DataloaderMapStyle · View on GitHub ↗

Build iter.

Args:

  • epoch — TODO.
  • data_split_i — TODO.
  • start_step — TODO.
  • **kwargs — Additional keyword arguments.
📄 Source code
    def build_iter(self, epoch=0, data_split_i=0, start_step=0, **kwargs):

        # reload dataset slice
        """Build iter.
        
            Args:
                epoch: TODO.
                data_split_i: TODO.
                start_step: TODO.
                **kwargs: Additional keyword arguments.
            """
        if self.data_split_num > 1:
            del self.dataset_tr
            self.dataset_tr = self.dataset_class(
                self.kwargs.get("train_data_set_list"),
                frontend=self.frontend,
                tokenizer=self.tokenizer,
                is_training=True,
                **self.kwargs.get("dataset_conf"),
                data_split_i=data_split_i,
            )

        # dataloader
        batch_sampler = self.kwargs["dataset_conf"].get("batch_sampler", "BatchSampler")
        batch_sampler_val = None
        if batch_sampler is not None:
            batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
            batch_sampler = batch_sampler_class(
                self.dataset_tr, start_step=start_step, **self.kwargs.get("dataset_conf")
            )
View full source on GitHub →
function

DataloaderIterable(frontend, tokenizer, **kwargs)

funasr.datasets.dataloader_entry · View on GitHub ↗

Build the training and validation datasets for iterable-style loading.

Args:

  • frontend — Audio frontend for feature extraction.
  • tokenizer — Tokenizer instance for text encoding/decoding.
  • **kwargs — Additional keyword arguments.
📄 Source code
def DataloaderIterable(frontend=None, tokenizer=None, **kwargs):
    """Build the training and validation datasets for iterable-style loading.

    The dataset class is looked up by name in ``tables.dataset_classes``
    (``kwargs["dataset"]``, default ``"LargeDataset"``).  It is instantiated
    twice:
    - on ``train_data_set_list`` with ``is_training=True``;
    - on ``valid_data_set_list`` with ``is_training=False``.

    Both instances receive the same ``frontend``, ``tokenizer`` and
    ``dataset_conf`` options.

    Args:
        frontend: Audio frontend for feature extraction.
        tokenizer: Tokenizer instance for text encoding/decoding.
        **kwargs: Build options.  Read here: ``dataset``,
            ``train_data_set_list``, ``valid_data_set_list`` and
            ``dataset_conf`` (a missing/None ``dataset_conf`` is treated as
            empty).

    Returns:
        tuple: ``(dataset_tr, dataset_val)``.

    Raises:
        ValueError: If the dataset name is not registered.
    """
    logging.info("Build dataloader")
    dataset_name = kwargs.get("dataset", "LargeDataset")
    dataset_class = tables.dataset_classes.get(dataset_name)
    if dataset_class is None:
        # Previously surfaced as an opaque "'NoneType' object is not callable".
        raise ValueError(f"Unknown dataset class: {dataset_name}")
    # Tolerate an absent dataset_conf instead of failing on ``**None``.
    dataset_conf = kwargs.get("dataset_conf") or {}

    def _build(data_set_list, is_training):
        # Train and validation datasets differ only in list and mode.
        return dataset_class(
            data_set_list,
            frontend=frontend,
            tokenizer=tokenizer,
            is_training=is_training,
            **dataset_conf,
        )

    dataset_tr = _build(kwargs.get("train_data_set_list"), True)
    dataset_val = _build(kwargs.get("valid_data_set_list"), False)
    return dataset_tr, dataset_val
class

LabelSmoothingLoss

funasr.losses.label_smoothing_loss · View on GitHub ↗
Label-smoothing loss.

:param int size: the number of class

:param int padding_idx: ignored class id

:param float smoothing: smoothing rate (0.0 means the conventional CE)

:param bool normalize_length: normalize loss by sequence length if True

:param torch.nn.Module criterion: loss function to be smoothed

📄 Source code
class LabelSmoothingLoss(nn.Module):
    """Label-smoothing loss.

    :param int size: the number of class
    :param int padding_idx: ignored class id
    :param float smoothing: smoothing rate (0.0 means the conventional CE)
    :param bool normalize_length: normalize loss by sequence length if True
    :param torch.nn.Module criterion: loss function to be smoothed
    """

    def __init__(
        self,
        size,
        padding_idx,
        smoothing,
        normalize_length=False,
        criterion=nn.KLDivLoss(reduction="none"),
    ):
        """Construct an LabelSmoothingLoss object."""
        super(LabelSmoothingLoss, self).__init__()
        self.criterion = criterion
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None
        self.normalize_length = normalize_length

    def forward(self, x, target):
        """Compute loss between x and target.
View full source on GitHub →

Methods

.forward(x, target) L42

Compute loss between x and target.

:param torch.Tensor x: prediction (batch, seqlen, class)

:param torch.Tensor target:

target signal masked with self.padding_id (batch, seqlen)

:return: scalar float value

:rtype torch.Tensor

📄 Source
    def forward(self, x, target):
        """Compute the label-smoothed KL loss between x and target.

        :param torch.Tensor x: prediction scores (batch, seqlen, class);
            log_softmax is applied here over the class axis
        :param torch.Tensor target:
            target signal masked with self.padding_id (batch, seqlen)
        :return: scalar float value.  With ``normalize_length`` the summed
            loss is divided by the number of non-padding tokens (at least 1,
            so an all-padding batch yields 0 instead of NaN); otherwise it is
            divided by the batch size.
        :rtype torch.Tensor
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.contiguous().view(-1, self.size)
        target = target.contiguous().view(-1)
        with torch.no_grad():
            # Smoothed target distribution: `confidence` on the true class,
            # smoothing / (size - 1) on every other class.
            true_dist = x.clone()
            true_dist.fill_(self.smoothing / (self.size - 1))
            ignore = target == self.padding_idx  # (batch * seqlen,)
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid -1 index
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        # Guard against 0 / 0 when every position is padding.
        denom = max(total, 1) if self.normalize_length else batch_size
        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
method

LabelSmoothingLoss.forward(x, target)

funasr.losses.label_smoothing_loss.LabelSmoothingLoss · View on GitHub ↗

Compute loss between x and target.

:param torch.Tensor x: prediction (batch, seqlen, class)

:param torch.Tensor target:

target signal masked with self.padding_id (batch, seqlen)

:return: scalar float value

:rtype torch.Tensor

📄 Source code
    def forward(self, x, target):
        """Compute the label-smoothed KL loss between x and target.

        :param torch.Tensor x: prediction scores (batch, seqlen, class);
            log_softmax is applied here over the class axis
        :param torch.Tensor target:
            target signal masked with self.padding_id (batch, seqlen)
        :return: scalar float value.  With ``normalize_length`` the summed
            loss is divided by the number of non-padding tokens (at least 1,
            so an all-padding batch yields 0 instead of NaN); otherwise it is
            divided by the batch size.
        :rtype torch.Tensor
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.contiguous().view(-1, self.size)
        target = target.contiguous().view(-1)
        with torch.no_grad():
            # Smoothed target distribution: `confidence` on the true class,
            # smoothing / (size - 1) on every other class.
            true_dist = x.clone()
            true_dist.fill_(self.smoothing / (self.size - 1))
            ignore = target == self.padding_idx  # (batch * seqlen,)
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid -1 index
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        # Guard against 0 / 0 when every position is padding.
        denom = max(total, 1) if self.normalize_length else batch_size
        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
class

SequenceBinaryCrossEntropy

funasr.losses.label_smoothing_loss · View on GitHub ↗

No documentation yet.

📄 Source code
class SequenceBinaryCrossEntropy(nn.Module):
    """Element-wise binary cross-entropy over padded sequences.

    Positions beyond each sequence's length are masked out before summing.

    Args:
        normalize_length: If True, divide the summed loss by the number of
            non-padded time steps; otherwise divide by the batch size.
        criterion: Element-wise loss module.  Defaults to a fresh
            ``nn.BCEWithLogitsLoss(reduction="none")`` for each instance.
    """

    def __init__(self, normalize_length=False, criterion=None):
        super().__init__()
        self.normalize_length = normalize_length
        # Build the default here: a module instance used as a default
        # argument is created once and shared by every loss object.
        if criterion is None:
            criterion = nn.BCEWithLogitsLoss(reduction="none")
        self.criterion = criterion

    def forward(self, pred, label, lengths):
        """Return the padding-masked, summed loss divided by the denominator.

        Args:
            pred: Predictions passed to ``criterion`` (logits for the default
                criterion).  Dim 1 is the time axis; presumably shaped
                (batch, time, dim) — TODO confirm.
            label: Targets, shaped as ``criterion`` expects (same shape as
                ``pred`` for the default).
            lengths: Valid length of each sequence in the batch.
        """
        pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
        loss = self.criterion(pred, label)
        if self.normalize_length:
            # Clamp so an all-padding batch yields 0 instead of NaN.
            denom = (~pad_mask).sum().clamp(min=1)
        else:
            denom = pred.shape[0]
        return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom

Methods

.forward(pred, label, lengths) L79

Forward pass for training.

Args:

  • pred — TODO.
  • label — TODO.
  • lengths — TODO.
📄 Source
    def forward(self, pred, label, lengths):
        """Return the padding-masked, summed loss divided by the denominator.

        The denominator is the number of non-padded time steps when
        ``normalize_length`` is set (clamped to at least 1, so an all-padding
        batch yields 0 instead of NaN), otherwise the batch size.

        Args:
            pred: Predictions passed to ``self.criterion``.  Dim 1 is the time
                axis; presumably shaped (batch, time, dim) — TODO confirm.
            label: Targets, shaped as ``self.criterion`` expects.
            lengths: Valid length of each sequence in the batch.
        """
        pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
        loss = self.criterion(pred, label)
        if self.normalize_length:
            denom = (~pad_mask).sum().clamp(min=1)
        else:
            denom = pred.shape[0]
        return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom
method

SequenceBinaryCrossEntropy.forward(pred, label, lengths)

funasr.losses.label_smoothing_loss.SequenceBinaryCrossEntropy · View on GitHub ↗

Forward pass for training.

Args:

  • pred — TODO.
  • label — TODO.
  • lengths — TODO.
📄 Source code
    def forward(self, pred, label, lengths):
        """Return the padding-masked, summed loss divided by the denominator.

        The denominator is the number of non-padded time steps when
        ``normalize_length`` is set (clamped to at least 1, so an all-padding
        batch yields 0 instead of NaN), otherwise the batch size.

        Args:
            pred: Predictions passed to ``self.criterion``.  Dim 1 is the time
                axis; presumably shaped (batch, time, dim) — TODO confirm.
            label: Targets, shaped as ``self.criterion`` expects.
            lengths: Valid length of each sequence in the batch.
        """
        pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
        loss = self.criterion(pred, label)
        if self.normalize_length:
            denom = (~pad_mask).sum().clamp(min=1)
        else:
            denom = pred.shape[0]
        return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom
class

NllLoss

funasr.losses.label_smoothing_loss · View on GitHub ↗

Nll loss.

:param int size: the number of class

:param int padding_idx: ignored class id

:param bool normalize_length: normalize loss by sequence length if True

:param torch.nn.Module criterion: loss function

📄 Source code
class NllLoss(nn.Module):
    """Nll loss.

    :param int size: the number of class
    :param int padding_idx: ignored class id
    :param bool normalize_length: normalize loss by sequence length if True
    :param torch.nn.Module criterion: loss function
    """

    def __init__(
        self,
        size,
        padding_idx,
        normalize_length=False,
        criterion=nn.NLLLoss(reduction="none"),
    ):
        """Construct an NllLoss object."""
        super(NllLoss, self).__init__()
        self.criterion = criterion
        self.padding_idx = padding_idx
        self.size = size
        self.true_dist = None
        self.normalize_length = normalize_length

    def forward(self, x, target):
        """Compute loss between x and target.

        :param torch.Tensor x: prediction (batch, seqlen, class)
        :param torch.Tensor target:
            target signal masked with self.padding_id (batch, seqlen)
View full source on GitHub →

Methods

.forward(x, target) L117

Compute loss between x and target.

:param torch.Tensor x: prediction (batch, seqlen, class)

:param torch.Tensor target:

target signal masked with self.padding_id (batch, seqlen)

:return: scalar float value

:rtype torch.Tensor

📄 Source
    def forward(self, x, target):
        """Compute the padding-masked negative log-likelihood of target.

        :param torch.Tensor x: prediction (batch, seqlen, class), passed as-is
            to ``self.criterion`` (log-probabilities for ``nn.NLLLoss``)
        :param torch.Tensor target:
            target signal masked with self.padding_id (batch, seqlen)
        :return: scalar float value.  With ``normalize_length`` the summed
            loss is divided by the number of non-padding tokens (at least 1,
            so an all-padding batch yields 0 instead of NaN); otherwise it is
            divided by the batch size.
        :rtype torch.Tensor
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        # contiguous() so non-contiguous inputs (e.g. transposed) can be viewed,
        # matching LabelSmoothingLoss.forward.
        x = x.contiguous().view(-1, self.size)
        target = target.contiguous().view(-1)
        with torch.no_grad():
            ignore = target == self.padding_idx  # (batch * seqlen,)
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid -1 index
        nll = self.criterion(x, target)
        # Guard against 0 / 0 when every position is padding.
        denom = max(total, 1) if self.normalize_length else batch_size
        return nll.masked_fill(ignore, 0).sum() / denom
method

NllLoss.forward(x, target)

funasr.losses.label_smoothing_loss.NllLoss · View on GitHub ↗

Compute loss between x and target.

:param torch.Tensor x: prediction (batch, seqlen, class)

:param torch.Tensor target:

target signal masked with self.padding_id (batch, seqlen)

:return: scalar float value

:rtype torch.Tensor

📄 Source code
    def forward(self, x, target):
        """Compute the padding-masked negative log-likelihood of target.

        :param torch.Tensor x: prediction (batch, seqlen, class), passed as-is
            to ``self.criterion`` (log-probabilities for ``nn.NLLLoss``)
        :param torch.Tensor target:
            target signal masked with self.padding_id (batch, seqlen)
        :return: scalar float value.  With ``normalize_length`` the summed
            loss is divided by the number of non-padding tokens (at least 1,
            so an all-padding batch yields 0 instead of NaN); otherwise it is
            divided by the batch size.
        :rtype torch.Tensor
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        # contiguous() so non-contiguous inputs (e.g. transposed) can be viewed,
        # matching LabelSmoothingLoss.forward.
        x = x.contiguous().view(-1, self.size)
        target = target.contiguous().view(-1)
        with torch.no_grad():
            ignore = target == self.padding_idx  # (batch * seqlen,)
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid -1 index
        nll = self.criterion(x, target)
        # Guard against 0 / 0 when every position is padding.
        denom = max(total, 1) if self.normalize_length else batch_size
        return nll.masked_fill(ignore, 0).sum() / denom
class

CustomLambdaLR

funasr.schedulers.lambdalr_cus · View on GitHub ↗

No documentation yet.

📄 Source code
class CustomLambdaLR(_LRScheduler):
    def __init__(
        self,
        optimizer,
        warmup_steps: int = 25000,
        total_steps: int = 500000,
        last_epoch=-1,
        verbose=False,
    ):
        """Initialize CustomLambdaLR.
        
            Args:
                optimizer: TODO.
                warmup_steps: TODO.
                total_steps: TODO.
                last_epoch: TODO.
                verbose: TODO.
            """
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):

        """Get lr."""
        step = self.last_epoch + 1
        if step < self.warmup_steps:
            lr_scale = step / self.warmup_steps
        else:
            lr_scale = max(
View full source on GitHub →

Methods

.get_lr() L41

Get lr.

📄 Source
    def get_lr(self):
        """Return per-group learning rates for linear warmup then linear decay.

        The scale rises linearly from 0 towards 1 over ``warmup_steps``.  It
        then falls linearly to 0 at ``total_steps`` and stays at 0 afterwards.
        """
        step = self.last_epoch + 1
        if step < self.warmup_steps:
            lr_scale = step / self.warmup_steps
        else:
            # At least 1: when total_steps <= warmup_steps this prevents a
            # ZeroDivisionError and stops a negative span from pushing the
            # scale above 1.
            decay_steps = max(self.total_steps - self.warmup_steps, 1)
            lr_scale = max(0.0, 1 - (step - self.warmup_steps) / decay_steps)
        return [base_lr * lr_scale for base_lr in self.base_lrs]
method

CustomLambdaLR.get_lr()

funasr.schedulers.lambdalr_cus.CustomLambdaLR · View on GitHub ↗

Get lr.

📄 Source code
    def get_lr(self):
        """Return per-group learning rates for linear warmup then linear decay.

        The scale rises linearly from 0 towards 1 over ``warmup_steps``.  It
        then falls linearly to 0 at ``total_steps`` and stays at 0 afterwards.
        """
        step = self.last_epoch + 1
        if step < self.warmup_steps:
            lr_scale = step / self.warmup_steps
        else:
            # At least 1: when total_steps <= warmup_steps this prevents a
            # ZeroDivisionError and stops a negative span from pushing the
            # scale above 1.
            decay_steps = max(self.total_steps - self.warmup_steps, 1)
            lr_scale = max(0.0, 1 - (step - self.warmup_steps) / decay_steps)
        return [base_lr * lr_scale for base_lr in self.base_lrs]
class

NoamLR

funasr.schedulers.noam_lr · View on GitHub ↗

The LR scheduler proposed by Noam

Ref:

"Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf

  • FIXME(kamo) — PyTorch doesn't provide _LRScheduler as public class,

thus the behaviour isn't guaranteed at forward PyTorch version.

  • NOTE(kamo) — The "model_size" in original implementation is derived from

the model, but in this implementation, this parameter is a constant value.

You need to change it if the model is changed.

📄 Source code
class NoamLR(_LRScheduler, AbsBatchStepScheduler):
    """The LR scheduler proposed by Noam

    Ref:
        "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf

    FIXME(kamo): PyTorch doesn't provide _LRScheduler as public class,
     thus the behaviour isn't guaranteed at forward PyTorch version.

    NOTE(kamo): The "model_size" in original implementation is derived from
     the model, but in this implementation, this parameter is a constant value.
     You need to change it if the model is changed.

    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        model_size: Union[int, float] = 320,
        warmup_steps: Union[int, float] = 25000,
        last_epoch: int = -1,
    ):
        """Initialize NoamLR.
        
            Args:
                optimizer: TODO.
                model_size: Size/dimension parameter.
                warmup_steps: TODO.
                last_epoch: TODO.
            """
View full source on GitHub →

Methods

.lr_for_WarmupLR(lr) L56

Lr for warmuplr.

Args:

  • lr — TODO.
📄 Source
    def lr_for_WarmupLR(self, lr: float) -> float:
        """Scale ``lr`` by 1/sqrt(model_size) and by 1/sqrt(warmup_steps).

        Args:
            lr: Learning rate to convert.

        Returns:
            ``lr / model_size**0.5 / warmup_steps**0.5``.
        """
        model_root = self.model_size**0.5
        warmup_root = self.warmup_steps**0.5
        return lr / model_root / warmup_root
.get_lr() L71

Get lr.

📄 Source
    def get_lr(self):
        """Return the Noam learning rate for every parameter group.

        Each group gets ``base_lr * model_size**-0.5 * min(step**-0.5,
        step * warmup_steps**-1.5)``, where ``step = last_epoch + 1``.
        """
        step = self.last_epoch + 1
        lrs = []
        for base_lr in self.base_lrs:
            schedule = min(step**-0.5, step * self.warmup_steps**-1.5)
            lrs.append(base_lr * self.model_size**-0.5 * schedule)
        return lrs
method

NoamLR.lr_for_WarmupLR(lr)

funasr.schedulers.noam_lr.NoamLR · View on GitHub ↗

Lr for warmuplr.

Args:

  • lr — TODO.
📄 Source code
    def lr_for_WarmupLR(self, lr: float) -> float:
        """Scale ``lr`` by 1/sqrt(model_size) and by 1/sqrt(warmup_steps).

        Args:
            lr: Learning rate to convert.

        Returns:
            ``lr / model_size**0.5 / warmup_steps**0.5``.
        """
        model_root = self.model_size**0.5
        warmup_root = self.warmup_steps**0.5
        return lr / model_root / warmup_root
method

NoamLR.get_lr()

funasr.schedulers.noam_lr.NoamLR · View on GitHub ↗

Get lr.

📄 Source code
    def get_lr(self):
        """Return the Noam learning rate for every parameter group.

        Each group gets ``base_lr * model_size**-0.5 * min(step**-0.5,
        step * warmup_steps**-1.5)``, where ``step = last_epoch + 1``.
        """
        step = self.last_epoch + 1
        lrs = []
        for base_lr in self.base_lrs:
            schedule = min(step**-0.5, step * self.warmup_steps**-1.5)
            lrs.append(base_lr * self.model_size**-0.5 * schedule)
        return lrs
class

TriStageLR

funasr.schedulers.tri_stage_scheduler · View on GitHub ↗

No documentation yet.

📄 Source code
class TriStageLR(_LRScheduler, AbsBatchStepScheduler):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        last_epoch: int = -1,
        phase_ratio: Optional[List[float]] = None,
        init_lr_scale: float = 0.01,
        final_lr_scale: float = 0.01,
    ):
        """Store the tri-stage hyper-parameters without starting the scheduler.

        ``_LRScheduler.__init__`` is deliberately not called here. It runs at
        the end of ``init_tri_stage_scheudler``, once the total number of
        updates is known and the schedule fields have been computed.

        Args:
            optimizer: Optimizer whose parameter-group learning rates are
                driven. Its ``defaults["lr"]`` becomes the peak learning rate.
            last_epoch: Index of the last update; later forwarded to
                ``_LRScheduler.__init__``.
            phase_ratio: Fractions of ``max_update`` spent in the warmup,
                hold and decay stages (three values summing to 1).
            init_lr_scale: Warmup starting lr as a fraction of the peak lr.
            final_lr_scale: Final lr as a fraction of the peak lr.
        """
        self.optimizer = optimizer
        self.last_epoch = last_epoch
        # Schedule shape; consumed later by init_tri_stage_scheudler().
        self.phase_ratio = phase_ratio
        self.init_lr_scale = init_lr_scale
        self.final_lr_scale = final_lr_scale
        # The optimizer's configured default lr serves as the schedule's peak value.
        self.optimizer_lr = optimizer.defaults["lr"]

    def init_tri_stage_scheudler(self, max_update):
        """Init tri stage scheudler.
        
            Args:
                max_update: TODO.
View full source on GitHub →

Methods

.init_tri_stage_scheudler(max_update) L40

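Build the tri-stage schedule for a run of max_update updates and finish initializing the underlying _LRScheduler.

Args:

  • max_update — total number of optimizer updates; split into warmup, hold and decay stages according to phase_ratio.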
📄 Source
    def init_tri_stage_scheudler(self, max_update):
        """Build the tri-stage schedule and finish ``_LRScheduler`` initialization.

        ``self.phase_ratio`` splits the ``max_update`` updates into three
        stages, which ``step_update`` then applies:

        * warmup: linear ramp from ``init_lr`` to ``peak_lr``;
        * hold: constant ``peak_lr``;
        * decay: exponential decay from ``peak_lr``, reaching ``final_lr``
          after ``decay_steps`` updates.

        ``peak_lr`` is the optimizer default lr captured in ``__init__``.
        ``init_lr`` and ``final_lr`` are that value scaled by
        ``init_lr_scale`` and ``final_lr_scale`` respectively.

        Args:
            max_update: Total number of optimizer updates; must be positive.

        Raises:
            AssertionError: If ``max_update`` is not positive, or if
                ``phase_ratio`` is missing, does not hold exactly three
                values, or does not sum to 1 (up to floating-point rounding).
        """
        self.max_update = max_update
        self.peak_lr = self.optimizer_lr
        self.init_lr = self.init_lr_scale * self.optimizer_lr
        self.final_lr = self.final_lr_scale * self.optimizer_lr

        assert self.max_update > 0
        # Check the shape before summing, so that a missing phase_ratio
        # (the __init__ default is None) fails with a clear message rather
        # than a TypeError from sum(None).
        assert self.phase_ratio is not None and len(self.phase_ratio) == 3, (
            "phase_ratio must contain exactly 3 values"
        )
        # Exact float equality rejects valid ratios: e.g. 0.7 + 0.2 + 0.1
        # evaluates to 0.9999999999999999 when added left to right.
        assert math.isclose(sum(self.phase_ratio), 1.0), "phase ratios must add up to 1"
        self.warmup_steps = int(self.max_update * self.phase_ratio[0])
        self.hold_steps = int(self.max_update * self.phase_ratio[1])
        self.decay_steps = int(self.max_update * self.phase_ratio[2])

        # A zero-length stage has no slope, so both divisions are guarded.
        # Previously only warmup was guarded, and e.g. phase_ratio=[0.5, 0.5, 0.0]
        # raised ZeroDivisionError.
        self.warmup_rate = (
            (self.peak_lr - self.init_lr) / self.warmup_steps if self.warmup_steps != 0 else 0
        )
        self.decay_factor = (
            -math.log(self.final_lr_scale) / self.decay_steps if self.decay_steps != 0 else 0
        )

        # Initial learning rate.
        self.lr = self.init_lr

        # Every field above must be set BEFORE _LRScheduler.__init__ runs,
        # because that constructor itself invokes step() (and thus get_lr()).
        self.set_optimizer_lr(self.lr)
        super().__init__(self.optimizer, self.last_epoch)
.step_update(num_updates) L96

Recompute the learning rate for the given update count, store it in self.lr and apply it to every optimizer parameter group.

📄 Source
    def step_update(self, num_updates):
        """Recompute the learning rate for ``num_updates`` and apply it.

        ``self._decide_stage`` supplies the stage and the offset within it:

        * stage 0 (warmup): ``init_lr + warmup_rate * offset``;
        * stage 1 (hold): ``peak_lr``;
        * stage 2 (decay): ``peak_lr * exp(-decay_factor * offset)``;
        * stage 3: ``final_lr``.

        The result is stored in ``self.lr`` and written to every optimizer
        parameter group via ``set_optimizer_lr``.

        Args:
            num_updates: Update count used to pick the stage (``get_lr``
                passes ``last_epoch + 1``).

        Raises:
            ValueError: If ``_decide_stage`` returns a stage other than 0-3.
        """
        stage, offset = self._decide_stage(num_updates)
        if stage == 0:
            new_lr = self.init_lr + self.warmup_rate * offset
        elif stage == 1:
            new_lr = self.peak_lr
        elif stage == 2:
            new_lr = self.peak_lr * math.exp(-self.decay_factor * offset)
        elif stage == 3:
            new_lr = self.final_lr
        else:
            raise ValueError("Undefined stage")
        self.lr = new_lr
        self.set_optimizer_lr(self.lr)
.set_optimizer_lr(lr) L111

Set the learning rate of every optimizer parameter group.

Args:

  • lr — learning rate to assign to each param_group["lr"].
📄 Source
    def set_optimizer_lr(self, lr):
        """Write ``lr`` into every parameter group of the wrapped optimizer.

        Args:
            lr: Learning rate assigned to each ``param_group["lr"]``.
        """
        groups = self.optimizer.param_groups
        for group in groups:
            group["lr"] = lr
.get_lr() L120

Advance the schedule to update number last_epoch + 1 and return the resulting learning rate.

📄 Source
    def get_lr(self):
        """Advance the tri-stage schedule and return one lr per parameter group.

        ``last_epoch + 1`` is used as the update count for ``step_update``,
        which also writes the new lr into the optimizer.

        Returns:
            The current learning rate repeated once per optimizer parameter
            group. ``set_optimizer_lr`` gives every group the same value.
        """
        step_num = self.last_epoch + 1
        self.step_update(step_num)
        # torch's scheduler contract expects one value per param group.
        # A single-element list mismatched multi-group optimizers and was
        # inconsistent with the NoamLR/WarmupLR get_lr methods.
        return [self.lr] * len(self.optimizer.param_groups)
method

TriStageLR.init_tri_stage_scheudler(max_update)

funasr.schedulers.tri_stage_scheduler.TriStageLR · View on GitHub ↗

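Build the tri-stage schedule for a run of max_update updates and finish initializing the underlying _LRScheduler.

Args:

  • max_update — total number of optimizer updates; split into warmup, hold and decay stages according to phase_ratio.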
📄 Source code
    def init_tri_stage_scheudler(self, max_update):
        """Build the tri-stage schedule and finish ``_LRScheduler`` initialization.

        ``self.phase_ratio`` splits the ``max_update`` updates into three
        stages, which ``step_update`` then applies:

        * warmup: linear ramp from ``init_lr`` to ``peak_lr``;
        * hold: constant ``peak_lr``;
        * decay: exponential decay from ``peak_lr``, reaching ``final_lr``
          after ``decay_steps`` updates.

        ``peak_lr`` is the optimizer default lr captured in ``__init__``.
        ``init_lr`` and ``final_lr`` are that value scaled by
        ``init_lr_scale`` and ``final_lr_scale`` respectively.

        Args:
            max_update: Total number of optimizer updates; must be positive.

        Raises:
            AssertionError: If ``max_update`` is not positive, or if
                ``phase_ratio`` is missing, does not hold exactly three
                values, or does not sum to 1 (up to floating-point rounding).
        """
        self.max_update = max_update
        self.peak_lr = self.optimizer_lr
        self.init_lr = self.init_lr_scale * self.optimizer_lr
        self.final_lr = self.final_lr_scale * self.optimizer_lr

        assert self.max_update > 0
        # Check the shape before summing, so that a missing phase_ratio
        # (the __init__ default is None) fails with a clear message rather
        # than a TypeError from sum(None).
        assert self.phase_ratio is not None and len(self.phase_ratio) == 3, (
            "phase_ratio must contain exactly 3 values"
        )
        # Exact float equality rejects valid ratios: e.g. 0.7 + 0.2 + 0.1
        # evaluates to 0.9999999999999999 when added left to right.
        assert math.isclose(sum(self.phase_ratio), 1.0), "phase ratios must add up to 1"
        self.warmup_steps = int(self.max_update * self.phase_ratio[0])
        self.hold_steps = int(self.max_update * self.phase_ratio[1])
        self.decay_steps = int(self.max_update * self.phase_ratio[2])

        # A zero-length stage has no slope, so both divisions are guarded.
        # Previously only warmup was guarded, and e.g. phase_ratio=[0.5, 0.5, 0.0]
        # raised ZeroDivisionError.
        self.warmup_rate = (
            (self.peak_lr - self.init_lr) / self.warmup_steps if self.warmup_steps != 0 else 0
        )
        self.decay_factor = (
            -math.log(self.final_lr_scale) / self.decay_steps if self.decay_steps != 0 else 0
        )

        # Initial learning rate.
        self.lr = self.init_lr

        # Every field above must be set BEFORE _LRScheduler.__init__ runs,
        # because that constructor itself invokes step() (and thus get_lr()).
        self.set_optimizer_lr(self.lr)
        super().__init__(self.optimizer, self.last_epoch)
method

TriStageLR.step_update(num_updates)

funasr.schedulers.tri_stage_scheduler.TriStageLR · View on GitHub ↗

Recompute the learning rate for the given update count, store it in self.lr and apply it to every optimizer parameter group.

📄 Source code
    def step_update(self, num_updates):
        """Set ``self.lr`` from the current stage and push it to the optimizer.

        Stage formulas (stage and in-stage offset come from ``_decide_stage``):
        0 → ``init_lr + warmup_rate * offset``, 1 → ``peak_lr``,
        2 → ``peak_lr * exp(-decay_factor * offset)``, 3 → ``final_lr``.

        Args:
            num_updates: Update count used to pick the stage (``get_lr``
                passes ``last_epoch + 1``).

        Raises:
            ValueError: If the stage is not one of 0, 1, 2 or 3. In that case
                ``self.lr`` and the optimizer are left untouched.
        """
        stage, offset = self._decide_stage(num_updates)
        if stage not in (0, 1, 2, 3):
            raise ValueError("Undefined stage")

        if stage == 0:
            new_lr = self.init_lr + self.warmup_rate * offset
        elif stage == 1:
            new_lr = self.peak_lr
        elif stage == 2:
            new_lr = self.peak_lr * math.exp(-self.decay_factor * offset)
        else:
            new_lr = self.final_lr

        self.lr = new_lr
        self.set_optimizer_lr(self.lr)
method

TriStageLR.set_optimizer_lr(lr)

funasr.schedulers.tri_stage_scheduler.TriStageLR · View on GitHub ↗

Set the learning rate of every optimizer parameter group.

Args:

  • lr — learning rate to assign to each param_group["lr"].
📄 Source code
    def set_optimizer_lr(self, lr):
        """Apply a single learning rate to all optimizer parameter groups.

        Args:
            lr: Value stored under the ``"lr"`` key of each parameter group.
        """
        for pg in self.optimizer.param_groups:
            # Overwrite in place; other group settings are left untouched.
            pg["lr"] = lr
method

TriStageLR.get_lr()

funasr.schedulers.tri_stage_scheduler.TriStageLR · View on GitHub ↗

Advance the schedule to update number last_epoch + 1 and return the resulting learning rate.

📄 Source code
    def get_lr(self):
        """Advance the tri-stage schedule and return one lr per parameter group.

        ``last_epoch + 1`` is used as the update count for ``step_update``,
        which also writes the new lr into the optimizer.

        Returns:
            The current learning rate repeated once per optimizer parameter
            group. ``set_optimizer_lr`` gives every group the same value.
        """
        step_num = self.last_epoch + 1
        self.step_update(step_num)
        # torch's scheduler contract expects one value per param group.
        # A single-element list mismatched multi-group optimizers and was
        # inconsistent with the NoamLR/WarmupLR get_lr methods.
        return [self.lr] * len(self.optimizer.param_groups)
class

WarmupLR

funasr.schedulers.warmup_lr · View on GitHub ↗

The WarmupLR scheduler

This scheduler is almost the same as the NoamLR scheduler, except for the following difference:

NoamLR:

lr = optimizer.lr * model_size ** -0.5 * min(step ** -0.5, step * warmup_step ** -1.5)

WarmupLR:

lr = optimizer.lr * warmup_step ** 0.5 * min(step ** -0.5, step * warmup_step ** -1.5)

Note that the maximum lr equals optimizer.lr in this scheduler; it is reached at step = warmup_step.

📄 Source code
class WarmupLR(_LRScheduler, AbsBatchStepScheduler):
    """The WarmupLR scheduler

    This scheduler is almost same as NoamLR Scheduler except for following difference:

    NoamLR:
        lr = optimizer.lr * model_size ** -0.5
             * min(step ** -0.5, step * warmup_step ** -1.5)
    WarmupLR:
        lr = optimizer.lr * warmup_step ** 0.5
             * min(step ** -0.5, step * warmup_step ** -1.5)

    Note that the maximum lr equals to optimizer.lr in this scheduler.

    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: Union[int, float] = 25000,
        last_epoch: int = -1,
    ):
        """Initialize WarmupLR.
        
            Args:
                optimizer: TODO.
                warmup_steps: TODO.
                last_epoch: TODO.
            """
        self.warmup_steps = warmup_steps
View full source on GitHub →

Methods

.get_lr() L50

Return the WarmupLR learning rate of every parameter group at step last_epoch + 1.

📄 Source
    def get_lr(self):
        """Return the WarmupLR learning rate of every parameter group.

        For step ``s = last_epoch + 1`` each base lr is multiplied by
        ``warmup_steps ** 0.5 * min(s ** -0.5, s * warmup_steps ** -1.5)``.
        The multiplier rises linearly to 1 at ``s == warmup_steps`` and then
        decays as ``s ** -0.5``.

        Returns:
            A list with one learning rate per entry of ``self.base_lrs``.
        """
        step = self.last_epoch + 1
        result = []
        for base_lr in self.base_lrs:
            inv_sqrt = step**-0.5
            ramp = step * self.warmup_steps**-1.5
            result.append(base_lr * self.warmup_steps**0.5 * min(inv_sqrt, ramp))
        return result
method

WarmupLR.get_lr()

funasr.schedulers.warmup_lr.WarmupLR · View on GitHub ↗

Return the WarmupLR learning rate of every parameter group at step last_epoch + 1.

📄 Source code
    def get_lr(self):
        """Return the WarmupLR learning rate for each parameter group.

        Each base lr is scaled by
        ``warmup_steps ** 0.5 * min(s ** -0.5, s * warmup_steps ** -1.5)``
        with ``s = last_epoch + 1``. The peak of the schedule is the base lr itself.

        Returns:
            One learning rate per entry of ``self.base_lrs``.
        """
        current_step = self.last_epoch + 1

        def _warmup(base_lr):
            # Linear-ramp and inverse-sqrt terms; the smaller one wins.
            ramp = current_step * self.warmup_steps**-1.5
            inv_sqrt = current_step**-0.5
            return base_lr * self.warmup_steps**0.5 * min(inv_sqrt, ramp)

        return [_warmup(base_lr) for base_lr in self.base_lrs]