250 entries · auto-generated from source
Click source code to expand. Links point to latest GitHub code.
is_npu_available() — Check whether an NPU is available.
def is_npu_available():
    """Report whether an NPU can be used through the optional ``torch_npu`` package.

    Returns:
        The result of ``torch_npu.npu.is_available()``, or ``False`` when
        ``torch_npu`` cannot be imported.
    """
    try:
        import torch_npu as npu_backend

        usable = npu_backend.npu.is_available()
    except ImportError:
        # torch_npu is an optional dependency; its absence means "no NPU".
        usable = False
    return usable
prepare_data_iterator(data_in, input_len, data_type, key) — No documentation yet.
def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
""" """
data_list = []
key_list = []
filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
chars = string.ascii_letters + string.digits
if isinstance(data_in, str):
if data_in.startswith("http://") or data_in.startswith("https://"): # url
data_in = download_from_url(data_in)
if isinstance(data_in, str) and os.path.exists(
data_in
): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
_, file_extension = os.path.splitext(data_in)
file_extension = file_extension.lower()
if file_extension in filelist: # filelist: wav.scp, file.jsonl;text.txt;
with open(data_in, encoding="utf-8") as fin:
for line in fin:
key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
if data_in.endswith(".jsonl"): # file.jsonl: json.dumps({"source": data})
lines = json.loads(line.strip())
data = lines["source"]
key = lines.get("key", key)
else: # filelist, wav.scp, text.txt: id \t data or data
lines = line.strip().split(maxsplit=1)
data = lines[1] if len(lines) > 1 else lines[0]
key = lines[0] if len(lines) > 1 else key
data_list.append(data)
View full source on GitHub →
AutoModel — No documentation yet.
class AutoModel:
def __init__(self, **kwargs):
"""Initialize AutoModel with ASR model and optional sub-models.
Args:
model (str): Model name (hub alias or full ID) or local path.
device (str): Device for inference. "cuda:0", "cpu", "mps", "npu:0".
Falls back to CPU if specified device is unavailable.
vad_model (str, optional): VAD model for long audio segmentation.
Enables processing of any-length audio.
vad_kwargs (dict, optional): VAD config, e.g. {"max_single_segment_time": 60000}.
punc_model (str, optional): Punctuation restoration model.
Not needed for Fun-ASR-Nano/SenseVoice/Qwen3-ASR (they output punctuation natively).
spk_model (str, optional): Speaker model for diarization ("cam++" or full model ID).
Requires vad_model. For Qwen3-ASR, also requires forced_aligner.
spk_mode (str, optional): Speaker diarization mode. "punc_segment" (default) or "vad_segment".
hub (str): Model hub. "ms" (ModelScope, default) or "hf" (HuggingFace).
ncpu (int): CPU threads (default: 4).
disable_update (bool): Skip version check on startup.
disable_pbar (bool): Disable tqdm progress bars.
**kwargs: Additional model-specific parameters (passed to config.yaml overrides).
Examples:
>>> model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc")
>>> model = AutoModel(model="FunAudioLLM/Fun-ASR-Nano-2512", trust_remote_code=True,
... remote_code="./model.py", vad_model="fsmn-vad", spk_model="cam++", hub="hf")
"""
try:
from funasr.utils.version_checker import check_for_update
View full source on GitHub →
.build_model(**kwargs) (L275) — Download model from hub, build all components, and load pretrained weights.
This method handles the full model construction pipeline:
1. Download model files from ModelScope/HuggingFace (if not local)
2. Parse config.yaml to determine model class, tokenizer, frontend
3. Instantiate tokenizer, frontend, and model via the registry
4. Load pretrained weights from model.pt
**kwargs — Must include 'model' (str). All other config.yaml fields can be overridden.tuple — (model, kwargs) where model is the instantiated nn.Module and def build_model(**kwargs):
"""Download model from hub, build all components, and load pretrained weights.
This method handles the full model construction pipeline:
1. Download model files from ModelScope/HuggingFace (if not local)
2. Parse config.yaml to determine model class, tokenizer, frontend
3. Instantiate tokenizer, frontend, and model via the registry
4. Load pretrained weights from model.pt
Args:
**kwargs: Must include 'model' (str). All other config.yaml fields can be overridden.
Returns:
tuple: (model, kwargs) where model is the instantiated nn.Module and
kwargs contains the resolved configuration.
"""
assert "model" in kwargs
if "model_conf" not in kwargs:
logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
kwargs = download_model(**kwargs)
set_all_random_seed(kwargs.get("seed", 0))
device = kwargs.get("device", "cuda")
if (
(device.startswith("cuda") and not torch.cuda.is_available())
or (device.startswith("xpu") and not torch.xpu.is_available())
or (device.startswith("mps") and not torch.backends.mps.is_available())
or (device.startswith("npu") and not is_npu_available())
or kwargs.get("ngpu", 1) == 0
View full source →
.generate(input, input_len, progress_callback, **cfg) (L436) — Run speech recognition on input audio.
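This is the primary user-facing method. It automatically routes to inference() when no vad_model is configured (single utterance), or to inference_with_vad() when a vad_model is configured (long audio with segmentation).
input — Audio input. Accepts a file path, URL, numpy array, list (batch), or raw bytes.
input_len (tensor, optional) — Length of each input sample.
progress_callback (callable, optional) — fn(current, total) called during processing.
**cfg — Runtime parameters (e.g. cache, hotword, language, batch_size_s, is_final).
Returns: list[dict] — Results for each input sample. Common fields include "key" and "text".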
def generate(self, input, input_len=None, progress_callback=None, **cfg):
"""Run speech recognition on input audio.
This is the primary user-facing method. It automatically routes to:
- inference() if no vad_model is configured (single utterance)
- inference_with_vad() if vad_model is configured (long audio with segmentation)
Args:
input: Audio input. Accepts:
- File path (str): "audio.wav", "audio.mp3"
- URL (str): "https://..."
- numpy array: raw audio samples (float32, 16kHz)
- list: batch of file paths or arrays
- bytes: raw audio bytes
input_len (tensor, optional): Length of each input sample.
progress_callback (callable, optional): fn(current, total) called during processing.
**cfg: Runtime parameters:
- cache (dict): State cache for streaming mode. Pass {} for first call.
- hotword (str/list): Keywords to boost recognition accuracy.
- language (str): Language hint ("auto", "zh", "en", "Chinese", etc.)
- batch_size_s (int): Dynamic batch total duration in seconds.
- is_final (bool): Last chunk flag for streaming mode.
- return_spk_res (bool): Return speaker diarization results.
- sentence_timestamp (bool): Return sentence-level timestamps.
- use_itn (bool): Apply inverse text normalization (SenseVoice).
Returns:
list[dict]: Results for each input sample. Common fields:
- "key" (str): Sample identifier
- "text" (str): Recognized text
View full source →
.inference(input, input_len, model, kwargs, key, progress_callback, **cfg) (L490) — Run model inference on input data (internal method).
Handles batching, timing, and progress reporting. Called by generate()
and inference_with_vad(). Typically not called directly by users.
input — Audio data, file path, or text (for punc model).
input_len (tensor, optional) — Input lengths for batch.
model (nn.Module, optional) — Override model (used for VAD/PUNC/SPK sub-models).
kwargs (dict, optional) — Override kwargs (used for sub-model configs).
key (list, optional) — Sample identifiers.
progress_callback (callable, optional) — Progress reporting function.
**cfg — Additional config merged into kwargs.
Returns: list[dict] — Model inference results.
def inference(
self,
input,
input_len=None,
model=None,
kwargs=None,
key=None,
progress_callback=None,
**cfg,
):
"""Run model inference on input data (internal method).
Handles batching, timing, and progress reporting. Called by generate()
and inference_with_vad(). Typically not called directly by users.
Args:
input: Audio data, file path, or text (for punc model).
input_len (tensor, optional): Input lengths for batch.
model (nn.Module, optional): Override model (used for VAD/PUNC/SPK sub-models).
kwargs (dict, optional): Override kwargs (used for sub-model configs).
key (list, optional): Sample identifiers.
progress_callback (callable, optional): Progress reporting function.
**cfg: Additional config merged into kwargs.
Returns:
list[dict]: Model inference results.
"""
if kwargs is None:
self._reset_runtime_configs()
kwargs = self.kwargs if kwargs is None else kwargs
View full source →
.inference_with_vad(input, input_len, **cfg) (L592) — Run ASR with VAD segmentation, punctuation, and optional speaker diarization.
Pipeline:
1. VAD: Segment audio into speech regions
2. ASR: Recognize each segment (sorted by length for efficient batching)
3. Timestamp merge: Combine per-segment timestamps with VAD offsets
4. Punctuation: Add punctuation to combined text (if punc_model configured)
5. Speaker diarization: Cluster speaker embeddings and assign labels (if spk_model configured)
input — Audio file path, URL, or numpy array.
input_len — Not used (kept for interface consistency).
**cfg — Runtime parameters (same as generate()).
Returns: list[dict] — Results with fields: key, text, timestamp, sentence_info, raw_text.
def inference_with_vad(self, input, input_len=None, **cfg):
"""Run ASR with VAD segmentation, punctuation, and optional speaker diarization.
Pipeline:
1. VAD: Segment audio into speech regions
2. ASR: Recognize each segment (sorted by length for efficient batching)
3. Timestamp merge: Combine per-segment timestamps with VAD offsets
4. Punctuation: Add punctuation to combined text (if punc_model configured)
5. Speaker diarization: Cluster speaker embeddings and assign labels (if spk_model configured)
Args:
input: Audio file path, URL, or numpy array.
input_len: Not used (kept for interface consistency).
**cfg: Runtime parameters (same as generate()).
Returns:
list[dict]: Results with fields: key, text, timestamp, sentence_info, raw_text.
"""
self._reset_runtime_configs()
if self.spk_model is not None and "output_timestamp" not in cfg:
cfg["output_timestamp"] = True
cfg["return_time_stamps"] = True
kwargs = self.kwargs
# step.1: compute the vad model
deep_update(self.vad_kwargs, cfg)
beg_vad = time.time()
res = self.inference(
input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg
)
end_vad = time.time()
View full source →
.export(input, **cfg) (L911) — Export model to ONNX format.
Creates a deep copy of the model to isolate ONNX operator monkey-patching,
then runs torch.onnx.export. The original model remains usable after export.
input — Sample input for tracing (auto-generated if None).**cfg — Export parameters:str — Path to the exported model directory. def export(self, input=None, **cfg):
"""Export model to ONNX format.
Creates a deep copy of the model to isolate ONNX operator monkey-patching,
then runs torch.onnx.export. The original model remains usable after export.
Args:
input: Sample input for tracing (auto-generated if None).
**cfg: Export parameters:
- type (str): Export format, "onnx" (default).
- quantize (bool): Whether to quantize the model.
- device (str): Device for export.
Returns:
str: Path to the exported model directory.
"""
"""
:param input:
:param type:
:param quantize:
:param fallback_num:
:param calib_num:
:param opset_version:
:param cfg:
:return:
"""
device = cfg.get("device", "cpu")
View full source →
AutoModel.build_model(**kwargs) — Download model from hub, build all components, and load pretrained weights.
This method handles the full model construction pipeline:
1. Download model files from ModelScope/HuggingFace (if not local)
2. Parse config.yaml to determine model class, tokenizer, frontend
3. Instantiate tokenizer, frontend, and model via the registry
4. Load pretrained weights from model.pt
**kwargs — Must include 'model' (str). All other config.yaml fields can be overridden.tuple — (model, kwargs) where model is the instantiated nn.Module and def build_model(**kwargs):
"""Download model from hub, build all components, and load pretrained weights.
This method handles the full model construction pipeline:
1. Download model files from ModelScope/HuggingFace (if not local)
2. Parse config.yaml to determine model class, tokenizer, frontend
3. Instantiate tokenizer, frontend, and model via the registry
4. Load pretrained weights from model.pt
Args:
**kwargs: Must include 'model' (str). All other config.yaml fields can be overridden.
Returns:
tuple: (model, kwargs) where model is the instantiated nn.Module and
kwargs contains the resolved configuration.
"""
assert "model" in kwargs
if "model_conf" not in kwargs:
logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
kwargs = download_model(**kwargs)
set_all_random_seed(kwargs.get("seed", 0))
device = kwargs.get("device", "cuda")
if (
(device.startswith("cuda") and not torch.cuda.is_available())
or (device.startswith("xpu") and not torch.xpu.is_available())
or (device.startswith("mps") and not torch.backends.mps.is_available())
or (device.startswith("npu") and not is_npu_available())
or kwargs.get("ngpu", 1) == 0
View full source on GitHub →
AutoModel.generate(input, input_len, progress_callback, **cfg) — Run speech recognition on input audio.
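This is the primary user-facing method. It automatically routes to inference() when no vad_model is configured (single utterance), or to inference_with_vad() when a vad_model is configured (long audio with segmentation).
input — Audio input. Accepts a file path, URL, numpy array, list (batch), or raw bytes.
input_len (tensor, optional) — Length of each input sample.
progress_callback (callable, optional) — fn(current, total) called during processing.
**cfg — Runtime parameters (e.g. cache, hotword, language, batch_size_s, is_final).
Returns: list[dict] — Results for each input sample. Common fields include "key" and "text".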
def generate(self, input, input_len=None, progress_callback=None, **cfg):
"""Run speech recognition on input audio.
This is the primary user-facing method. It automatically routes to:
- inference() if no vad_model is configured (single utterance)
- inference_with_vad() if vad_model is configured (long audio with segmentation)
Args:
input: Audio input. Accepts:
- File path (str): "audio.wav", "audio.mp3"
- URL (str): "https://..."
- numpy array: raw audio samples (float32, 16kHz)
- list: batch of file paths or arrays
- bytes: raw audio bytes
input_len (tensor, optional): Length of each input sample.
progress_callback (callable, optional): fn(current, total) called during processing.
**cfg: Runtime parameters:
- cache (dict): State cache for streaming mode. Pass {} for first call.
- hotword (str/list): Keywords to boost recognition accuracy.
- language (str): Language hint ("auto", "zh", "en", "Chinese", etc.)
- batch_size_s (int): Dynamic batch total duration in seconds.
- is_final (bool): Last chunk flag for streaming mode.
- return_spk_res (bool): Return speaker diarization results.
- sentence_timestamp (bool): Return sentence-level timestamps.
- use_itn (bool): Apply inverse text normalization (SenseVoice).
Returns:
list[dict]: Results for each input sample. Common fields:
- "key" (str): Sample identifier
- "text" (str): Recognized text
View full source on GitHub →
AutoModel.inference(input, input_len, model, kwargs, key, progress_callback, **cfg) — Run model inference on input data (internal method).
Handles batching, timing, and progress reporting. Called by generate()
and inference_with_vad(). Typically not called directly by users.
input — Audio data, file path, or text (for punc model).
input_len (tensor, optional) — Input lengths for batch.
model (nn.Module, optional) — Override model (used for VAD/PUNC/SPK sub-models).
kwargs (dict, optional) — Override kwargs (used for sub-model configs).
key (list, optional) — Sample identifiers.
progress_callback (callable, optional) — Progress reporting function.
**cfg — Additional config merged into kwargs.
Returns: list[dict] — Model inference results.
def inference(
self,
input,
input_len=None,
model=None,
kwargs=None,
key=None,
progress_callback=None,
**cfg,
):
"""Run model inference on input data (internal method).
Handles batching, timing, and progress reporting. Called by generate()
and inference_with_vad(). Typically not called directly by users.
Args:
input: Audio data, file path, or text (for punc model).
input_len (tensor, optional): Input lengths for batch.
model (nn.Module, optional): Override model (used for VAD/PUNC/SPK sub-models).
kwargs (dict, optional): Override kwargs (used for sub-model configs).
key (list, optional): Sample identifiers.
progress_callback (callable, optional): Progress reporting function.
**cfg: Additional config merged into kwargs.
Returns:
list[dict]: Model inference results.
"""
if kwargs is None:
self._reset_runtime_configs()
kwargs = self.kwargs if kwargs is None else kwargs
View full source on GitHub →
AutoModel.inference_with_vad(input, input_len, **cfg) — Run ASR with VAD segmentation, punctuation, and optional speaker diarization.
Pipeline:
1. VAD: Segment audio into speech regions
2. ASR: Recognize each segment (sorted by length for efficient batching)
3. Timestamp merge: Combine per-segment timestamps with VAD offsets
4. Punctuation: Add punctuation to combined text (if punc_model configured)
5. Speaker diarization: Cluster speaker embeddings and assign labels (if spk_model configured)
input — Audio file path, URL, or numpy array.
input_len — Not used (kept for interface consistency).
**cfg — Runtime parameters (same as generate()).
Returns: list[dict] — Results with fields: key, text, timestamp, sentence_info, raw_text.
def inference_with_vad(self, input, input_len=None, **cfg):
"""Run ASR with VAD segmentation, punctuation, and optional speaker diarization.
Pipeline:
1. VAD: Segment audio into speech regions
2. ASR: Recognize each segment (sorted by length for efficient batching)
3. Timestamp merge: Combine per-segment timestamps with VAD offsets
4. Punctuation: Add punctuation to combined text (if punc_model configured)
5. Speaker diarization: Cluster speaker embeddings and assign labels (if spk_model configured)
Args:
input: Audio file path, URL, or numpy array.
input_len: Not used (kept for interface consistency).
**cfg: Runtime parameters (same as generate()).
Returns:
list[dict]: Results with fields: key, text, timestamp, sentence_info, raw_text.
"""
self._reset_runtime_configs()
if self.spk_model is not None and "output_timestamp" not in cfg:
cfg["output_timestamp"] = True
cfg["return_time_stamps"] = True
kwargs = self.kwargs
# step.1: compute the vad model
deep_update(self.vad_kwargs, cfg)
beg_vad = time.time()
res = self.inference(
input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg
)
end_vad = time.time()
View full source on GitHub →
AutoModel.export(input, **cfg) — Export model to ONNX format.
Creates a deep copy of the model to isolate ONNX operator monkey-patching,
then runs torch.onnx.export. The original model remains usable after export.
input — Sample input for tracing (auto-generated if None).**cfg — Export parameters:str — Path to the exported model directory. def export(self, input=None, **cfg):
"""Export model to ONNX format.
Creates a deep copy of the model to isolate ONNX operator monkey-patching,
then runs torch.onnx.export. The original model remains usable after export.
Args:
input: Sample input for tracing (auto-generated if None).
**cfg: Export parameters:
- type (str): Export format, "onnx" (default).
- quantize (bool): Whether to quantize the model.
- device (str): Device for export.
Returns:
str: Path to the exported model directory.
"""
"""
:param input:
:param type:
:param quantize:
:param fallback_num:
:param calib_num:
:param opset_version:
:param cfg:
:return:
"""
device = cfg.get("device", "cpu")
View full source on GitHub →
RegisterTables — Registry system for classes.
class RegisterTables:
"""Registry system for classes."""
model_classes = {}
frontend_classes = {}
specaug_classes = {}
normalize_classes = {}
encoder_classes = {}
decoder_classes = {}
joint_network_classes = {}
predictor_classes = {}
stride_conv_classes = {}
tokenizer_classes = {}
dataloader_classes = {}
batch_sampler_classes = {}
dataset_classes = {}
index_ds_classes = {}
def print(self, key: str = None) -> None:
"""Print registered classes."""
print("\ntables: \n")
fields = vars(self)
headers = ["register name", "class name", "class location"]
for classes_key, classes_dict in fields.items():
if classes_key.endswith("_meta") and (key is None or key in classes_key):
print(f"----------- ** {classes_key.replace('_meta', '')} ** --------------")
metas = []
for register_key, meta in classes_dict.items():
metas.append(meta)
metas.sort(key=lambda x: x[0])
View full source on GitHub →
.print(key) (L26) — Print registered classes.
def print(self, key: str = None) -> None:
    """Print every registry metadata table as an ASCII table.

    Scans this object's instance attributes whose names end in ``_meta``
    (the tables that ``register`` creates with ``setattr``) and prints one
    table per registry. Rows are sorted by their first column.

    Args:
        key: Optional filter. When given, only tables whose attribute name
            contains this substring are printed.
    """
    # This method is itself named ``print``; call the builtin explicitly so
    # the intent is unambiguous and the body also works if bound at module
    # level (where a bare ``print`` would recurse into this function).
    import builtins

    suffix = "_meta"
    headers = ["register name", "class name", "class location"]
    builtins.print("\ntables: \n")
    for classes_key, classes_dict in vars(self).items():
        if not classes_key.endswith(suffix):
            continue
        if key is not None and key not in classes_key:
            continue
        # Strip only the trailing suffix; str.replace would also drop any
        # "_meta" that appears earlier in the table name.
        table_name = classes_key[: -len(suffix)]
        builtins.print(f"----------- ** {table_name} ** --------------")
        metas = sorted(classes_dict.values(), key=lambda meta: meta[0])
        data = [headers] + metas
        col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]
        for row in data:
            builtins.print(
                "| "
                + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths))
                + " |"
            )
        builtins.print("\n")
.register(register_tables_key, key) (L49) — Decorator to register a class.
def register(self, register_tables_key: str, key: str = None) -> callable:
"""Decorator to register a class."""
def decorator(target_class):
"""Decorator.
Args:
target_class: TODO.
"""
if not hasattr(self, register_tables_key):
setattr(self, register_tables_key, {})
logging.debug(f"New registry table added: {register_tables_key}")
registry = getattr(self, register_tables_key)
registry_key = key if key is not None else target_class.__name__
if registry_key in registry:
logging.debug(
f"Key {registry_key} already exists in {register_tables_key}, re-register"
)
registry[registry_key] = target_class
register_tables_key_meta = register_tables_key + "_meta"
if not hasattr(self, register_tables_key_meta):
setattr(self, register_tables_key_meta, {})
registry_meta = getattr(self, register_tables_key_meta)
class_file = inspect.getfile(target_class)
class_line = inspect.getsourcelines(target_class)[1]
View full source →
RegisterTables.print(key) — Print registered classes.
def print(self, key: str = None) -> None:
    """Print every registry metadata table as an ASCII table.

    Scans this object's instance attributes whose names end in ``_meta``
    (the tables that ``register`` creates with ``setattr``) and prints one
    table per registry. Rows are sorted by their first column.

    Args:
        key: Optional filter. When given, only tables whose attribute name
            contains this substring are printed.
    """
    # This method is itself named ``print``; call the builtin explicitly so
    # the intent is unambiguous and the body also works if bound at module
    # level (where a bare ``print`` would recurse into this function).
    import builtins

    suffix = "_meta"
    headers = ["register name", "class name", "class location"]
    builtins.print("\ntables: \n")
    for classes_key, classes_dict in vars(self).items():
        if not classes_key.endswith(suffix):
            continue
        if key is not None and key not in classes_key:
            continue
        # Strip only the trailing suffix; str.replace would also drop any
        # "_meta" that appears earlier in the table name.
        table_name = classes_key[: -len(suffix)]
        builtins.print(f"----------- ** {table_name} ** --------------")
        metas = sorted(classes_dict.values(), key=lambda meta: meta[0])
        data = [headers] + metas
        col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]
        for row in data:
            builtins.print(
                "| "
                + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths))
                + " |"
            )
        builtins.print("\n")
RegisterTables.register(register_tables_key, key) — Decorator to register a class.
def register(self, register_tables_key: str, key: str = None) -> callable:
"""Decorator to register a class."""
def decorator(target_class):
"""Decorator.
Args:
target_class: TODO.
"""
if not hasattr(self, register_tables_key):
setattr(self, register_tables_key, {})
logging.debug(f"New registry table added: {register_tables_key}")
registry = getattr(self, register_tables_key)
registry_key = key if key is not None else target_class.__name__
if registry_key in registry:
logging.debug(
f"Key {registry_key} already exists in {register_tables_key}, re-register"
)
registry[registry_key] = target_class
register_tables_key_meta = register_tables_key + "_meta"
if not hasattr(self, register_tables_key_meta):
setattr(self, register_tables_key_meta, {})
registry_meta = getattr(self, register_tables_key_meta)
class_file = inspect.getfile(target_class)
class_line = inspect.getsourcelines(target_class)[1]
View full source on GitHub →
BAT (Boundary-Aware Transducer) — Low-latency RNN-T model with boundary detection. Inherits from Transducer. Designed for streaming ASR with reduced latency
by predicting token boundaries explicitly.
class BAT(Transducer):
    """BAT (Boundary-Aware Transducer) model.

    Described as a low-latency RNN-T model with boundary detection, intended
    for streaming ASR with reduced latency by predicting token boundaries
    explicitly. This subclass defines no members of its own, so all behavior
    is inherited unchanged from ``Transducer``.
    """
BiCifParaformer — Paraformer with Bidirectional CIF for Timestamp Prediction.
Extends Paraformer with a second CIF predictor that provides accurate
character-level timestamp prediction alongside ASR. Uses bidirectional information flow for better alignment between audio frames and text tokens.
Reference:
- FunASR: A Fundamental End-to-End Speech Recognition Toolkit (https://arxiv.org/abs/2305.11013)
- Achieving timestamp prediction while recognizing with non-autoregressive end-to-end ASR model (https://arxiv.org/abs/2301.12343)
Output: {"key": str, "text": str, "timestamp": [[start_ms, end_ms], ...]}
Author — Speech Lab of DAMO Academy, Alibaba Groupclass BiCifParaformer(Paraformer):
"""BiCifParaformer: Paraformer with Bidirectional CIF for Timestamp Prediction.
Extends Paraformer with a second CIF predictor that provides accurate
character-level timestamp prediction alongside ASR. Uses bidirectional
information flow for better alignment between audio frames and text tokens.
Reference:
- FunASR: A Fundamental End-to-End Speech Recognition Toolkit (https://arxiv.org/abs/2305.11013)
- Achieving timestamp prediction while recognizing with non-autoregressive end-to-end ASR model
(https://arxiv.org/abs/2301.12343)
Output:
{"key": str, "text": str, "timestamp": [[start_ms, end_ms], ...]}
Author: Speech Lab of DAMO Academy, Alibaba Group
"""
def __init__(
self,
*args,
**kwargs,
):
"""Initialize BiCifParaformer.
Args:
*args: Variable positional arguments.
**kwargs: Additional keyword arguments.
"""
super().__init__(*args, **kwargs)
View full source on GitHub →
.calc_predictor(encoder_out, encoder_out_lens) (L162) — Calc predictor.
# encoder_out — Encoder output tensor. encoder_out_lens — Encoder output lengths.
def calc_predictor(self, encoder_out, encoder_out_lens):
    """Run the CIF predictor over the encoder output.

    Args:
        encoder_out: Encoder output tensor; ``encoder_out.size(1)`` is used as
            the mask length (presumably the time axis — TODO confirm).
        encoder_out_lens: Per-sequence valid lengths of ``encoder_out``.

    Returns:
        Tuple ``(pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index)``
        taken from the first four values produced by ``self.predictor``; its
        fifth output is discarded.
    """
    pad_mask = make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))
    # Insert a singleton axis at dim 1, invert the pad mask, and move it to the
    # encoder's device before handing it to the predictor.
    frame_mask = (~pad_mask[:, None, :]).to(encoder_out.device)
    embeds, token_length, alphas, peak_index, _unused_length = self.predictor(
        encoder_out, None, frame_mask, ignore_id=self.ignore_id
    )
    return embeds, token_length, alphas, peak_index
.calc_predictor_timestamp(encoder_out, encoder_out_lens, token_num) (L177) — Calc predictor timestamp.
# encoder_out — Encoder output tensor. encoder_out_lens — Encoder output lengths. token_num — TODO.
def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
    """Compute the predictor's upsampled timestamp outputs.

    Args:
        encoder_out: Encoder output tensor; ``encoder_out.size(1)`` sets the
            mask length.
        encoder_out_lens: Per-sequence valid lengths of ``encoder_out``.
        token_num: Forwarded unchanged to
            ``self.predictor.get_upsample_timestamp`` (presumably the token
            count per utterance — TODO confirm).

    Returns:
        Tuple ``(ds_alphas, ds_cif_peak, us_alphas, us_peaks)``: the four
        values returned by ``self.predictor.get_upsample_timestamp``.
    """
    pad_mask = make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))
    frame_mask = (~pad_mask[:, None, :]).to(encoder_out.device)
    upsample_out = self.predictor.get_upsample_timestamp(encoder_out, frame_mask, token_num)
    ds_alphas, ds_cif_peak, us_alphas, us_peaks = upsample_out
    return ds_alphas, ds_cif_peak, us_alphas, us_peaks
.forward(speech, speech_lengths, text, text_lengths, **kwargs) (L193) — Frontend + Encoder + Decoder + Calc loss.
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
loss_ctc, cer_ctc = None, None
loss_pre = None
stats = dict()
# decoder: CTC branch
View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) (L271) — Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
# init beamsearch
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = (
kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
)
if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
View full source →
.export(**kwargs) (L430) — Export.
# **kwargs — Additional keyword arguments.
def export(self, **kwargs):
    """Build the export variant(s) of this model.

    Args:
        **kwargs: Forwarded to ``export_rebuild_model``; ``max_seq_len``
            defaults to 512 when the caller does not supply it.

    Returns:
        Whatever ``export_rebuild_model(model=self, **kwargs)`` returns.
    """
    from .export_meta import export_rebuild_model

    kwargs.setdefault("max_seq_len", 512)
    return export_rebuild_model(model=self, **kwargs)
BiCifParaformer.calc_predictor(encoder_out, encoder_out_lens) — Calc predictor.
# encoder_out — Encoder output tensor. encoder_out_lens — Encoder output lengths.
def calc_predictor(self, encoder_out, encoder_out_lens):
    """Run the CIF predictor over the encoder output.

    Args:
        encoder_out: Encoder output tensor; ``encoder_out.size(1)`` is used as
            the mask length (presumably the time axis — TODO confirm).
        encoder_out_lens: Per-sequence valid lengths of ``encoder_out``.

    Returns:
        Tuple ``(pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index)``
        taken from the first four values produced by ``self.predictor``; its
        fifth output is discarded.
    """
    pad_mask = make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))
    # Insert a singleton axis at dim 1, invert the pad mask, and move it to the
    # encoder's device before handing it to the predictor.
    frame_mask = (~pad_mask[:, None, :]).to(encoder_out.device)
    embeds, token_length, alphas, peak_index, _unused_length = self.predictor(
        encoder_out, None, frame_mask, ignore_id=self.ignore_id
    )
    return embeds, token_length, alphas, peak_index
BiCifParaformer.calc_predictor_timestamp(encoder_out, encoder_out_lens, token_num) — Calc predictor timestamp.
# encoder_out — Encoder output tensor. encoder_out_lens — Encoder output lengths. token_num — TODO.
def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
    """Compute the predictor's upsampled timestamp outputs.

    Args:
        encoder_out: Encoder output tensor; ``encoder_out.size(1)`` sets the
            mask length.
        encoder_out_lens: Per-sequence valid lengths of ``encoder_out``.
        token_num: Forwarded unchanged to
            ``self.predictor.get_upsample_timestamp`` (presumably the token
            count per utterance — TODO confirm).

    Returns:
        Tuple ``(ds_alphas, ds_cif_peak, us_alphas, us_peaks)``: the four
        values returned by ``self.predictor.get_upsample_timestamp``.
    """
    pad_mask = make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))
    frame_mask = (~pad_mask[:, None, :]).to(encoder_out.device)
    upsample_out = self.predictor.get_upsample_timestamp(encoder_out, frame_mask, token_num)
    ds_alphas, ds_cif_peak, us_alphas, us_peaks = upsample_out
    return ds_alphas, ds_cif_peak, us_alphas, us_peaks
BiCifParaformer.forward(speech, speech_lengths, text, text_lengths, **kwargs) — Frontend + Encoder + Decoder + Calc loss.
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
loss_ctc, cer_ctc = None, None
loss_pre = None
stats = dict()
# decoder: CTC branch
View full source on GitHub →
BiCifParaformer.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) — Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
# init beamsearch
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = (
kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
)
if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
View full source on GitHub →
BiCifParaformer.export(**kwargs) — Export.
# **kwargs — Additional keyword arguments.
def export(self, **kwargs):
    """Build the export variant(s) of this model.

    Args:
        **kwargs: Forwarded to ``export_rebuild_model``; ``max_seq_len``
            defaults to 512 when the caller does not supply it.

    Returns:
        Whatever ``export_rebuild_model(model=self, **kwargs)`` returns.
    """
    from .export_meta import export_rebuild_model

    kwargs.setdefault("max_seq_len", 512)
    return export_rebuild_model(model=self, **kwargs)
Branchformer — Parallel branch encoder architecture. Uses parallel branches of self-attention and convolution that are
merged via concatenation. Alternative to Conformer with similar accuracy.
Inherits Transformer pipeline for training and inference.
class Branchformer(Transformer):
    """Branchformer encoder-decoder model.

    The encoder runs self-attention and convolution in parallel branches and
    merges them via concatenation; it is an alternative to Conformer with
    similar accuracy. Training and inference are inherited from ``Transformer``.
    """

    def __init__(self, *args, **kwargs):
        """Build the model; every argument is forwarded to ``Transformer``.

        Args:
            *args: Positional arguments passed straight to ``Transformer.__init__``.
            **kwargs: Keyword arguments passed straight to ``Transformer.__init__``.
        """
        super(Branchformer, self).__init__(*args, **kwargs)
CAM++ Speaker Verification Model.
Extracts fixed-dimensional speaker embeddings from variable-length audio.
Used for speaker verification and speaker diarization pipelines.
Output — 192-dimensional speaker embedding per utterance.class CAMPPlus(torch.nn.Module):
"""CAM++ Speaker Verification Model.
Extracts fixed-dimensional speaker embeddings from variable-length audio.
Used for speaker verification and speaker diarization pipelines.
Output: 192-dimensional speaker embedding per utterance.
"""
def __init__(
self,
feat_dim=80,
embedding_size=192,
growth_rate=32,
bn_size=4,
init_channels=128,
config_str="batchnorm-relu",
memory_efficient=True,
output_level="segment",
**kwargs,
):
"""Initialize CAMPPlus.
Args:
feat_dim: Size/dimension parameter.
embedding_size: Size/dimension parameter.
growth_rate: TODO.
bn_size: Size/dimension parameter.
init_channels: TODO.
config_str: TODO.
View full source on GitHub → .forward(x) (L141): Extract speaker embedding from fbank features.
x (Tensor) — Input fbank features, shape (batch, time, feat_dim).Tensor — Speaker embedding, shape (batch, embedding_size) for segment level, def forward(self, x):
"""Extract speaker embedding from fbank features.
The input is permuted to (batch, feat_dim, time), passed through ``self.head``
and then ``self.xvector``. When ``self.output_level == "frame"`` the last two
axes of the result are swapped before returning.
Args:
x (Tensor): Input fbank features, shape (batch, time, feat_dim).
Returns:
Tensor: Speaker embedding, shape (batch, embedding_size) for segment level,
or (batch, time, channels) for frame level.
"""
# Feed ``self.head`` a feature-major tensor.
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = self.head(x)
x = self.xvector(x)
# Frame-level output: swap axes 1 and 2 (presumably back to time-major — confirm).
if self.output_level == "frame":
x = x.transpose(1, 2)
return x
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) (L158): Run speaker embedding extraction on audio input.
data_in — Audio input (file path, numpy array, or list).data_lengths — Not used.key (list) — Sample identifiers.tokenizer — Not used.frontend — Not used.**kwargs — Must include 'device' (str) and optional 'fs' (int, default 16000).tuple — (results, meta_data) where results is def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run speaker embedding extraction on audio input.
Args:
data_in: Audio input (file path, numpy array, or list).
data_lengths: Not used.
key (list): Sample identifiers.
tokenizer: Not used.
frontend: Not used.
**kwargs: Must include 'device' (str) and optional 'fs' (int, default 16000).
Returns:
tuple: (results, meta_data) where results is
[{"spk_embedding": Tensor of shape (1, 192)}]
"""
# extract fbank feats
meta_data = {}
time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(
data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound"
)
time2 = time.perf_counter()
View full source → CAMPPlus.forward(x): Extract speaker embedding from fbank features.
x (Tensor) — Input fbank features, shape (batch, time, feat_dim).Tensor — Speaker embedding, shape (batch, embedding_size) for segment level, def forward(self, x):
"""Extract speaker embedding from fbank features.
The input is permuted to (batch, feat_dim, time), passed through ``self.head``
and then ``self.xvector``. When ``self.output_level == "frame"`` the last two
axes of the result are swapped before returning.
Args:
x (Tensor): Input fbank features, shape (batch, time, feat_dim).
Returns:
Tensor: Speaker embedding, shape (batch, embedding_size) for segment level,
or (batch, time, channels) for frame level.
"""
# Feed ``self.head`` a feature-major tensor.
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = self.head(x)
x = self.xvector(x)
# Frame-level output: swap axes 1 and 2 (presumably back to time-major — confirm).
if self.output_level == "frame":
x = x.transpose(1, 2)
return x
CAMPPlus.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs): Run speaker embedding extraction on audio input.
data_in — Audio input (file path, numpy array, or list).data_lengths — Not used.key (list) — Sample identifiers.tokenizer — Not used.frontend — Not used.**kwargs — Must include 'device' (str) and optional 'fs' (int, default 16000).tuple — (results, meta_data) where results is def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run speaker embedding extraction on audio input.
Args:
data_in: Audio input (file path, numpy array, or list).
data_lengths: Not used.
key (list): Sample identifiers.
tokenizer: Not used.
frontend: Not used.
**kwargs: Must include 'device' (str) and optional 'fs' (int, default 16000).
Returns:
tuple: (results, meta_data) where results is
[{"spk_embedding": Tensor of shape (1, 192)}]
"""
# extract fbank feats
meta_data = {}
time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(
data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound"
)
time2 = time.perf_counter()
View full source on GitHub → Conformer — CTC-attention hybrid encoder-decoder model. Combines convolution and self-attention in the encoder for better
local and global context modeling. Inherits full Transformer pipeline
(CTC + attention decoder + beam search).
Output — {"key": str, "text": str}class Conformer(Transformer):
"""Conformer: CTC-attention hybrid encoder-decoder model.
Combines convolution and self-attention in the encoder for better
local and global context modeling. Inherits full Transformer pipeline
(CTC + attention decoder + beam search).
Output: {"key": str, "text": str}
"""
def __init__(
self,
*args,
**kwargs,
):
"""Initialize Conformer.
All arguments are forwarded unchanged to ``Transformer.__init__``; this
constructor adds no Conformer-specific state of its own.
Args:
*args: Positional arguments passed through to ``Transformer``.
**kwargs: Keyword arguments passed through to ``Transformer``.
"""
super().__init__(*args, **kwargs)
CTC — attention hybrid Encoder-Decoder modelclass Conformer(Transformer):
"""CTC-attention hybrid encoder-decoder model.
Thin subclass of ``Transformer``: this definition adds no attributes or
methods beyond its constructor, so training and inference come from
``Transformer``.
"""
def __init__(
self,
*args,
**kwargs,
):
"""Initialize Conformer.
All arguments are forwarded unchanged to ``Transformer.__init__``.
Args:
*args: Positional arguments passed through to ``Transformer``.
**kwargs: Keyword arguments passed through to ``Transformer``.
"""
super().__init__(*args, **kwargs)
ContextualParaformer — Paraformer with hotword/context biasing. Extends Paraformer with a context encoder that incorporates user-defined
hotwords/keywords to boost recognition of domain-specific terms.
Usage — Pass hotwords via generate(hotword='term1 term2').class ContextualParaformer(Paraformer):
"""ContextualParaformer: Paraformer with hotword/context biasing.
Extends Paraformer with a context encoder that incorporates user-defined
hotwords/keywords to boost recognition of domain-specific terms.
Usage: Pass hotwords via generate(hotword='term1 term2').
"""
def __init__(
self,
*args,
**kwargs,
):
"""Initialize ContextualParaformer.
Args:
*args: Variable positional arguments.
**kwargs: Additional keyword arguments.
"""
super().__init__(*args, **kwargs)
self.target_buffer_length = kwargs.get("target_buffer_length", -1)
inner_dim = kwargs.get("inner_dim", 256)
bias_encoder_type = kwargs.get("bias_encoder_type", "lstm")
use_decoder_embedding = kwargs.get("use_decoder_embedding", False)
crit_attn_weight = kwargs.get("crit_attn_weight", 0.0)
crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0)
bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)
View full source on GitHub → .forward(speech, speech_lengths, text, text_lengths, **kwargs) (L96): Frontend + Encoder + Decoder + loss calculation
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
text_lengths = text_lengths.squeeze()
speech_lengths = speech_lengths.squeeze()
batch_size = speech.shape[0]
hotword_pad = kwargs.get("hotword_pad")
hotword_lengths = kwargs.get("hotword_lengths")
# dha_pad = kwargs.get("dha_pad")
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
loss_ctc, cer_ctc = None, None
View full source → .sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info) (L268): Sampler.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.ys_pad — TODO.ys_pad_lens — Lengths of ys_pad.pre_acoustic_embeds — TODO.contextual_info — TODO. def sampler(
self,
encoder_out,
encoder_out_lens,
ys_pad,
ys_pad_lens,
pre_acoustic_embeds,
contextual_info,
):
"""Sampler.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
ys_pad: TODO.
ys_pad_lens: Lengths of ys_pad.
pre_acoustic_embeds: TODO.
contextual_info: TODO.
"""
tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
ys_pad.device
)
ys_pad = ys_pad * tgt_mask[:, :, 0]
if self.share_embedding:
ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
else:
ys_pad_embed = self.decoder.embed(ys_pad)
with torch.no_grad():
decoder_outs = self.decoder(
encoder_out,
View full source → .cal_decoder_with_predictor(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list, clas_scale) (L331): Calculate decoder outputs with the predictor.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.sematic_embeds — TODO.ys_pad_lens — Lengths of ys_pad.hw_list — TODO.clas_scale — TODO. def cal_decoder_with_predictor(
self,
encoder_out,
encoder_out_lens,
sematic_embeds,
ys_pad_lens,
hw_list=None,
clas_scale=1.0,
):
"""Cal decoder with predictor.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
sematic_embeds: TODO.
ys_pad_lens: Lengths of ys_pad.
hw_list: TODO.
clas_scale: TODO.
"""
if hw_list is None:
hw_list = [torch.Tensor([1]).long().to(encoder_out.device)] # empty hotword list
hw_list_pad = pad_list(hw_list, 0)
if self.use_decoder_embedding:
hw_embed = self.decoder.embed(hw_list_pad)
else:
hw_embed = self.bias_embed(hw_list_pad)
hw_embed, (h_n, _) = self.bias_encoder(hw_embed)
hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
else:
hw_lengths = [len(i) for i in hw_list]
View full source → .inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) (L387): Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
# init beamsearch
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = (
kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
)
if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
View full source → .generate_hotwords_list(hotword_list_or_file, tokenizer, frontend) (L534): Generate hotwords list.
hotword_list_or_file — TODO.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction. def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
"""Generate hotwords list.
Args:
hotword_list_or_file: TODO.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
"""
def load_seg_dict(seg_dict_file):
"""Load seg dict.
Args:
seg_dict_file: TODO.
"""
seg_dict = {}
assert isinstance(seg_dict_file, str)
with open(seg_dict_file, "r", encoding="utf8") as f:
lines = f.readlines()
for line in lines:
s = line.strip().split()
key = s[0]
value = s[1:]
seg_dict[key] = " ".join(value)
return seg_dict
def seg_tokenize(txt, seg_dict):
"""Seg tokenize.
Args:
txt: TODO.
View full source → .export(**kwargs) (L659): Export the model.
**kwargs — Additional keyword arguments. def export(
self,
**kwargs,
):
"""Export the model by delegating to ``export_rebuild_model``.
Args:
**kwargs: Additional keyword arguments, forwarded to ``export_rebuild_model``.
If ``max_seq_len`` is not supplied it defaults to 512.
Returns:
The value returned by ``export_rebuild_model`` (presumably the rebuilt
model(s) ready for export — confirm in export_meta).
"""
# Default the maximum sequence length used for export when the caller omits it.
if "max_seq_len" not in kwargs:
kwargs["max_seq_len"] = 512
# Local import: only resolved when export is actually requested.
from .export_meta import export_rebuild_model
models = export_rebuild_model(model=self, **kwargs)
return models
ContextualParaformer.forward(speech, speech_lengths, text, text_lengths, **kwargs): Frontend + Encoder + Decoder + loss calculation
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
text_lengths = text_lengths.squeeze()
speech_lengths = speech_lengths.squeeze()
batch_size = speech.shape[0]
hotword_pad = kwargs.get("hotword_pad")
hotword_lengths = kwargs.get("hotword_lengths")
# dha_pad = kwargs.get("dha_pad")
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
loss_ctc, cer_ctc = None, None
View full source on GitHub → ContextualParaformer.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info): Sampler.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.ys_pad — TODO.ys_pad_lens — Lengths of ys_pad.pre_acoustic_embeds — TODO.contextual_info — TODO. def sampler(
self,
encoder_out,
encoder_out_lens,
ys_pad,
ys_pad_lens,
pre_acoustic_embeds,
contextual_info,
):
"""Sampler.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
ys_pad: TODO.
ys_pad_lens: Lengths of ys_pad.
pre_acoustic_embeds: TODO.
contextual_info: TODO.
"""
tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
ys_pad.device
)
ys_pad = ys_pad * tgt_mask[:, :, 0]
if self.share_embedding:
ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
else:
ys_pad_embed = self.decoder.embed(ys_pad)
with torch.no_grad():
decoder_outs = self.decoder(
encoder_out,
View full source on GitHub → ContextualParaformer.cal_decoder_with_predictor(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list, clas_scale): Calculate decoder outputs with the predictor.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.sematic_embeds — TODO.ys_pad_lens — Lengths of ys_pad.hw_list — TODO.clas_scale — TODO. def cal_decoder_with_predictor(
self,
encoder_out,
encoder_out_lens,
sematic_embeds,
ys_pad_lens,
hw_list=None,
clas_scale=1.0,
):
"""Cal decoder with predictor.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
sematic_embeds: TODO.
ys_pad_lens: Lengths of ys_pad.
hw_list: TODO.
clas_scale: TODO.
"""
if hw_list is None:
hw_list = [torch.Tensor([1]).long().to(encoder_out.device)] # empty hotword list
hw_list_pad = pad_list(hw_list, 0)
if self.use_decoder_embedding:
hw_embed = self.decoder.embed(hw_list_pad)
else:
hw_embed = self.bias_embed(hw_list_pad)
hw_embed, (h_n, _) = self.bias_encoder(hw_embed)
hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
else:
hw_lengths = [len(i) for i in hw_list]
View full source on GitHub → ContextualParaformer.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs): Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
# init beamsearch
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = (
kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
)
if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
View full source on GitHub → ContextualParaformer.generate_hotwords_list(hotword_list_or_file, tokenizer, frontend): Generate hotwords list.
hotword_list_or_file — TODO.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction. def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
"""Generate hotwords list.
Args:
hotword_list_or_file: TODO.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
"""
def load_seg_dict(seg_dict_file):
"""Load seg dict.
Args:
seg_dict_file: TODO.
"""
seg_dict = {}
assert isinstance(seg_dict_file, str)
with open(seg_dict_file, "r", encoding="utf8") as f:
lines = f.readlines()
for line in lines:
s = line.strip().split()
key = s[0]
value = s[1:]
seg_dict[key] = " ".join(value)
return seg_dict
def seg_tokenize(txt, seg_dict):
"""Seg tokenize.
Args:
txt: TODO.
View full source on GitHub → ContextualParaformer.export(**kwargs): Export the model.
**kwargs — Additional keyword arguments. def export(
self,
**kwargs,
):
"""Export the model by delegating to ``export_rebuild_model``.
Args:
**kwargs: Additional keyword arguments, forwarded to ``export_rebuild_model``.
If ``max_seq_len`` is not supplied it defaults to 512.
Returns:
The value returned by ``export_rebuild_model`` (presumably the rebuilt
model(s) ready for export — confirm in export_meta).
"""
# Default the maximum sequence length used for export when the caller omits it.
if "max_seq_len" not in kwargs:
kwargs["max_seq_len"] = 512
# Local import: only resolved when export is actually requested.
from .export_meta import export_rebuild_model
models = export_rebuild_model(model=self, **kwargs)
return models
CT-Transformer: Punctuation Restoration Model. Adds punctuation (comma, period, question mark) to unpunctuated text.
Supports Chinese and English. Used as punc_model in the ASR pipeline.
Output — {"key": "...", "text": "punctuated text", "punc_array": Tensor}. punc_array encoding: 1=none, 2=comma(,), 3=period(。), 4=question(?)
Note — Not needed for Fun-ASR-Nano/SenseVoice/Qwen3-ASR (they output punctuation natively). Only required for Paraformer models.
Author — Speech Lab of DAMO Academy, Alibaba Group. Paper: CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection, https://arxiv.org/pdf/2003.01309.pdf
class CTTransformer(torch.nn.Module):
"""CT-Transformer: Punctuation Restoration Model.
Adds punctuation (comma, period, question mark) to unpunctuated text.
Supports Chinese and English. Used as punc_model in the ASR pipeline.
Output: {"key": "...", "text": "punctuated text", "punc_array": Tensor}
punc_array encoding: 1=none, 2=comma(,), 3=period(。), 4=question(?)
Note: Not needed for Fun-ASR-Nano/SenseVoice/Qwen3-ASR (they output punctuation natively).
Only required for Paraformer models.
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
https://arxiv.org/pdf/2003.01309.pdf
"""
def __init__(
self,
encoder: str = None,
encoder_conf: dict = None,
vocab_size: int = -1,
punc_list: list = None,
punc_weight: list = None,
embed_unit: int = 128,
att_unit: int = 256,
dropout_rate: float = 0.5,
ignore_id: int = -1,
sos: int = 1,
eos: int = 2,
View full source on GitHub → .punc_forward(text, text_lengths, **kwargs) (L113): Run the encoder and decoder on input token ids and return the decoder output (no loss is computed).
input (torch.Tensor) — Input ids. (batch, len)hidden (torch.Tensor) — Target ids. (batch, len) def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor, **kwargs):
"""Run the punctuation network on a batch of token ids.
No loss is computed here: ``text`` is embedded, passed through the
encoder together with ``text_lengths``, and the decoder output is returned.
Args:
text (torch.Tensor): Input token ids, presumably (batch, len).
text_lengths (torch.Tensor): Valid length of each sequence; passed to the encoder.
**kwargs: Accepted for interface compatibility; not used here.
Returns:
tuple: ``(y, None)`` where ``y`` is the decoder output (presumably
per-token punctuation logits — confirm against the decoder).
"""
x = self.embed(text)
# NOTE: no explicit mask is built; the encoder receives text_lengths instead.
# mask = self._target_mask(input)
h, _, _ = self.encoder(x, text_lengths)
y = self.decoder(h)
return y, None
.score(y, state, x) (L131): Score new token.
y (torch.Tensor) — 1D torch.int64 prefix tokens. state — Scorer state for prefix tokens. x (torch.Tensor) — Encoder feature that generates ys. Returns tuple[torch.Tensor, Any]: Tuple of
torch.float32 scores for next token (vocab_size)
and next state for ys
def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
    """Score every vocabulary entry as the next token after prefix ``y``.

    Args:
        y (torch.Tensor): 1D torch.int64 prefix tokens.
        state: Scorer state for the prefix; passed to the encoder as its cache.
        x (torch.Tensor): Encoder feature that generates ys; not read here.

    Returns:
        tuple[torch.Tensor, Any]: torch.float32 log-probabilities for the next
        token (vocab_size) and the updated cache to pass back as ``state``.
    """
    batched_prefix = y.unsqueeze(0)  # add a leading batch axis of size 1
    token_embeddings = self.embed(batched_prefix)
    prefix_mask = self._target_mask(batched_prefix)
    hidden, _, next_cache = self.encoder.forward_one_step(
        token_embeddings, prefix_mask, cache=state
    )
    # Only the newest position is needed to predict the following token.
    last_step_scores = self.decoder(hidden[:, -1])
    return last_step_scores.log_softmax(dim=-1).squeeze(0), next_cache
.batch_score(ys, states, xs) (L153): Score new token batch.
ys (torch.Tensor) — torch.int64 prefix tokens (n_batch, ylen). states (List[Any]) — Scorer states for prefix tokens. xs (torch.Tensor):
The encoder feature that generates ys (n_batch, xlen, n_feat).
tuple[torch.Tensor, List[Any]]: Tuple of
batched scores for next token with shape of `(n_batch, vocab_size)`
and next state list for ys.
def batch_score(
self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
) -> Tuple[torch.Tensor, List[Any]]:
"""Score new token batch.
Args:
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
states (List[Any]): Scorer states for prefix tokens.
xs (torch.Tensor):
The encoder feature that generates ys (n_batch, xlen, n_feat).
Returns:
tuple[torch.Tensor, List[Any]]: Tuple of
batchfied scores for next token with shape of `(n_batch, vocab_size)`
and next state list for ys.
"""
# merge states
n_batch = len(ys)
n_layers = len(self.encoder.encoders)
if states[0] is None:
batch_state = None
else:
# transpose state of [batch, layer] into [layer, batch]
batch_state = [
torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)
]
# batch decoding
h, _, states = self.encoder.forward_one_step(
View full source → .nll(text, punc, text_lengths, punc_lengths, max_length, vad_indexes, vad_indexes_lengths) (L192): Compute negative log likelihood (NLL).
Normally, this function is called in batchify_nll.
text — (Batch, Length)punc — (Batch, Length)text_lengths — (Batch,)max_lengths — int def nll(
self,
text: torch.Tensor,
punc: torch.Tensor,
text_lengths: torch.Tensor,
punc_lengths: torch.Tensor,
max_length: Optional[int] = None,
vad_indexes: Optional[torch.Tensor] = None,
vad_indexes_lengths: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute negative log likelihood(nll)
Normally, this function is called in batchify_nll.
Args:
text: (Batch, Length)
punc: (Batch, Length)
text_lengths: (Batch,)
max_lengths: int
"""
batch_size = text.size(0)
# For data parallel
if max_length is None:
text = text[:, : text_lengths.max()]
punc = punc[:, : text_lengths.max()]
else:
text = text[:, :max_length]
punc = punc[:, :max_length]
if self.with_vad():
# Should be VadRealtimeTransformer
View full source → .forward(text, punc, text_lengths, punc_lengths, vad_indexes, vad_indexes_lengths) (L262): Forward pass for training.
text — Text tensor or string input.punc — TODO.text_lengths — Length of each text sample.punc_lengths — Lengths of punc.vad_indexes — TODO.vad_indexes_lengths — Lengths of vad_indexes. def forward(
self,
text: torch.Tensor,
punc: torch.Tensor,
text_lengths: torch.Tensor,
punc_lengths: torch.Tensor,
vad_indexes: Optional[torch.Tensor] = None,
vad_indexes_lengths: Optional[torch.Tensor] = None,
):
"""Compute the token-averaged punctuation training loss.
Args:
text (torch.Tensor): Input token ids, (Batch, Length).
punc (torch.Tensor): Target punctuation ids, (Batch, Length).
text_lengths (torch.Tensor): Valid length of each text sequence, (Batch,).
punc_lengths (torch.Tensor): Valid length of each punctuation sequence.
vad_indexes (torch.Tensor, optional): Forwarded to ``self.nll``.
vad_indexes_lengths (torch.Tensor, optional): Accepted but not forwarded
to ``self.nll``; unused by this method.
Returns:
tuple: ``(loss, stats, weight)`` as produced by ``force_gatherable``, where
``loss`` is the summed NLL divided by the total token count, ``stats``
holds the detached loss, and ``weight`` is that token count.
"""
nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
ntokens = y_lengths.sum()
# NOTE(review): if y_lengths sums to 0 this divides by zero (NaN/inf loss) — confirm batches are never empty.
loss = nll.sum() / ntokens
stats = dict(loss=loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
return loss, stats, weight
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) (L290): Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
assert len(data_in) == 1
if not data_in[0] or (isinstance(data_in[0], str) and not data_in[0].strip()):
meta_data = {"batch_data_time": -1}
return [{"key": key[0] if key else "", "text": "", "punc_array": None}], meta_data
text = load_audio_text_image_video(data_in, data_type=kwargs.get("kwargs", "text"))[0]
vad_indexes = kwargs.get("vad_indexes", None)
# text = data_in[0]
# text_lengths = data_lengths[0] if data_lengths is not None else None
split_size = kwargs.get("split_size", 20)
tokens = split_words(text, jieba_usr_dict=self.jieba_usr_dict)
View full source → .export(**kwargs) (L471): Export the model.
**kwargs — Additional keyword arguments. def export(self, **kwargs):
"""Export the model by delegating to ``export_rebuild_model``.
Args:
**kwargs: Additional keyword arguments, forwarded unchanged to
``export_rebuild_model`` (no ``max_seq_len`` default is injected here).
Returns:
The value returned by ``export_rebuild_model`` (presumably the rebuilt
model(s) ready for export — confirm in export_meta).
"""
# Local import: only resolved when export is actually requested.
from .export_meta import export_rebuild_model
models = export_rebuild_model(model=self, **kwargs)
return models
CTTransformer.punc_forward(text, text_lengths, **kwargs): Run the encoder and decoder on input token ids and return the decoder output (no loss is computed).
input (torch.Tensor) — Input ids. (batch, len)hidden (torch.Tensor) — Target ids. (batch, len) def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor, **kwargs):
"""Run the punctuation network on a batch of token ids.
No loss is computed here: ``text`` is embedded, passed through the
encoder together with ``text_lengths``, and the decoder output is returned.
Args:
text (torch.Tensor): Input token ids, presumably (batch, len).
text_lengths (torch.Tensor): Valid length of each sequence; passed to the encoder.
**kwargs: Accepted for interface compatibility; not used here.
Returns:
tuple: ``(y, None)`` where ``y`` is the decoder output (presumably
per-token punctuation logits — confirm against the decoder).
"""
x = self.embed(text)
# NOTE: no explicit mask is built; the encoder receives text_lengths instead.
# mask = self._target_mask(input)
h, _, _ = self.encoder(x, text_lengths)
y = self.decoder(h)
return y, None
CTTransformer.with_vad(): Return whether the model uses VAD indexes (always False here).
def with_vad(self):
    """Tell callers whether VAD boundary indexes feed into this model.

    Returns:
        bool: ``False`` — this model does not consume VAD indexes.
    """
    uses_vad_boundaries = False
    return uses_vad_boundaries
CTTransformer.score(y, state, x): Score new token.
y (torch.Tensor) — 1D torch.int64 prefix tokens. state — Scorer state for prefix tokens. x (torch.Tensor) — Encoder feature that generates ys. Returns tuple[torch.Tensor, Any]: Tuple of
torch.float32 scores for next token (vocab_size)
and next state for ys
def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
    """Score every vocabulary entry as the next token after prefix ``y``.

    Args:
        y (torch.Tensor): 1D torch.int64 prefix tokens.
        state: Scorer state for the prefix; passed to the encoder as its cache.
        x (torch.Tensor): Encoder feature that generates ys; not read here.

    Returns:
        tuple[torch.Tensor, Any]: torch.float32 log-probabilities for the next
        token (vocab_size) and the updated cache to pass back as ``state``.
    """
    batched_prefix = y.unsqueeze(0)  # add a leading batch axis of size 1
    token_embeddings = self.embed(batched_prefix)
    prefix_mask = self._target_mask(batched_prefix)
    hidden, _, next_cache = self.encoder.forward_one_step(
        token_embeddings, prefix_mask, cache=state
    )
    # Only the newest position is needed to predict the following token.
    last_step_scores = self.decoder(hidden[:, -1])
    return last_step_scores.log_softmax(dim=-1).squeeze(0), next_cache
CTTransformer.batch_score(ys, states, xs): Score new token batch.
ys (torch.Tensor) — torch.int64 prefix tokens (n_batch, ylen). states (List[Any]) — Scorer states for prefix tokens. xs (torch.Tensor):
The encoder feature that generates ys (n_batch, xlen, n_feat).
tuple[torch.Tensor, List[Any]]: Tuple of
batched scores for next token with shape of `(n_batch, vocab_size)`
and next state list for ys.
def batch_score(
self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
) -> Tuple[torch.Tensor, List[Any]]:
"""Score new token batch.
Args:
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
states (List[Any]): Scorer states for prefix tokens.
xs (torch.Tensor):
The encoder feature that generates ys (n_batch, xlen, n_feat).
Returns:
tuple[torch.Tensor, List[Any]]: Tuple of
batchfied scores for next token with shape of `(n_batch, vocab_size)`
and next state list for ys.
"""
# merge states
n_batch = len(ys)
n_layers = len(self.encoder.encoders)
if states[0] is None:
batch_state = None
else:
# transpose state of [batch, layer] into [layer, batch]
batch_state = [
torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)
]
# batch decoding
h, _, states = self.encoder.forward_one_step(
View full source on GitHub → CTTransformer.nll(text, punc, text_lengths, punc_lengths, max_length, vad_indexes, vad_indexes_lengths): Compute negative log likelihood (NLL).
Normally, this function is called in batchify_nll.
text — (Batch, Length)punc — (Batch, Length)text_lengths — (Batch,)max_lengths — int def nll(
self,
text: torch.Tensor,
punc: torch.Tensor,
text_lengths: torch.Tensor,
punc_lengths: torch.Tensor,
max_length: Optional[int] = None,
vad_indexes: Optional[torch.Tensor] = None,
vad_indexes_lengths: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute negative log likelihood(nll)
Normally, this function is called in batchify_nll.
Args:
text: (Batch, Length)
punc: (Batch, Length)
text_lengths: (Batch,)
max_lengths: int
"""
batch_size = text.size(0)
# For data parallel
if max_length is None:
text = text[:, : text_lengths.max()]
punc = punc[:, : text_lengths.max()]
else:
text = text[:, :max_length]
punc = punc[:, :max_length]
if self.with_vad():
# Should be VadRealtimeTransformer
View full source on GitHub → CTTransformer.forward(text, punc, text_lengths, punc_lengths, vad_indexes, vad_indexes_lengths): Forward pass for training.
text — Text tensor or string input.punc — TODO.text_lengths — Length of each text sample.punc_lengths — Lengths of punc.vad_indexes — TODO.vad_indexes_lengths — Lengths of vad_indexes. def forward(
self,
text: torch.Tensor,
punc: torch.Tensor,
text_lengths: torch.Tensor,
punc_lengths: torch.Tensor,
vad_indexes: Optional[torch.Tensor] = None,
vad_indexes_lengths: Optional[torch.Tensor] = None,
):
"""Compute the token-averaged punctuation training loss.
Args:
text (torch.Tensor): Input token ids, (Batch, Length).
punc (torch.Tensor): Target punctuation ids, (Batch, Length).
text_lengths (torch.Tensor): Valid length of each text sequence, (Batch,).
punc_lengths (torch.Tensor): Valid length of each punctuation sequence.
vad_indexes (torch.Tensor, optional): Forwarded to ``self.nll``.
vad_indexes_lengths (torch.Tensor, optional): Accepted but not forwarded
to ``self.nll``; unused by this method.
Returns:
tuple: ``(loss, stats, weight)`` as produced by ``force_gatherable``, where
``loss`` is the summed NLL divided by the total token count, ``stats``
holds the detached loss, and ``weight`` is that token count.
"""
nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
ntokens = y_lengths.sum()
# NOTE(review): if y_lengths sums to 0 this divides by zero (NaN/inf loss) — confirm batches are never empty.
loss = nll.sum() / ntokens
stats = dict(loss=loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
return loss, stats, weight
CTTransformer.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs): Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
assert len(data_in) == 1
if not data_in[0] or (isinstance(data_in[0], str) and not data_in[0].strip()):
meta_data = {"batch_data_time": -1}
return [{"key": key[0] if key else "", "text": "", "punc_array": None}], meta_data
text = load_audio_text_image_video(data_in, data_type=kwargs.get("kwargs", "text"))[0]
vad_indexes = kwargs.get("vad_indexes", None)
# text = data_in[0]
# text_lengths = data_lengths[0] if data_lengths is not None else None
split_size = kwargs.get("split_size", 20)
tokens = split_words(text, jieba_usr_dict=self.jieba_usr_dict)
View full source on GitHub → CTTransformer.export(**kwargs): Export the model.
**kwargs — Additional keyword arguments. def export(self, **kwargs):
"""Export the model by delegating to ``export_rebuild_model``.
Args:
**kwargs: Additional keyword arguments, forwarded unchanged to
``export_rebuild_model`` (no ``max_seq_len`` default is injected here).
Returns:
The value returned by ``export_rebuild_model`` (presumably the rebuilt
model(s) ready for export — confirm in export_meta).
"""
# Local import: only resolved when export is actually requested.
from .export_meta import export_rebuild_model
models = export_rebuild_model(model=self, **kwargs)
return models
CT-Transformer Streaming: Online punctuation restoration. Processes text incrementally with a sliding window, maintaining cache
of previous context for consistent punctuation decisions across chunks.
Used as punc_model in streaming ASR pipelines.
Supports VAD-aware punctuation: uses VAD boundaries to improve sentence segmentation.
Reference — https://arxiv.org/pdf/2003.01309.pdfOutput — {"key": str, "text": str, "punc_array": Tensor}Author — Speech Lab of DAMO Academy, Alibaba Groupclass CTTransformerStreaming(CTTransformer):
"""CT-Transformer Streaming: Online punctuation restoration.
Processes text incrementally with a sliding window, maintaining cache
of previous context for consistent punctuation decisions across chunks.
Used as punc_model in streaming ASR pipelines.
Supports VAD-aware punctuation: uses VAD boundaries to improve sentence segmentation.
Reference: https://arxiv.org/pdf/2003.01309.pdf
Output: {"key": str, "text": str, "punc_array": Tensor}
Author: Speech Lab of DAMO Academy, Alibaba Group
"""
def __init__(self, *args, **kwargs):
    """Construct the streaming CT-Transformer.

    This subclass keeps no constructor state of its own: every positional
    and keyword argument is handed to the parent ``CTTransformer``
    constructor exactly as received.

    Args:
        *args: Positional arguments for ``CTTransformer.__init__``.
        **kwargs: Keyword arguments for ``CTTransformer.__init__``.
    """
    super().__init__(*args, **kwargs)
def punc_forward(
self, text: torch.Tensor, text_lengths: torch.Tensor, vad_indexes: torch.Tensor, **kwargs
View full source on GitHub →.punc_forward(text, text_lengths, vad_indexes, **kwargs) L61 Run the embedding, encoder and decoder over token ids and return the decoder output (no loss is computed).
input (torch.Tensor) — Input ids. (batch, len)hidden (torch.Tensor) — Target ids. (batch, len) def punc_forward(
self, text: torch.Tensor, text_lengths: torch.Tensor, vad_indexes: torch.Tensor, **kwargs
):
"""Run embedding -> encoder -> decoder over token ids and return the decoder output.
No loss is computed here.
Args:
text (torch.Tensor): Token ids fed to ``self.embed``; presumably
(batch, len) -- confirm.
text_lengths (torch.Tensor): Valid length of each sequence in ``text``.
vad_indexes (torch.Tensor): VAD boundary information, forwarded to the
encoder as ``vad_indexes=``.
**kwargs: Ignored by this method.
Returns:
tuple: ``(y, None)``, where ``y`` is ``self.decoder``'s output for the
encoder hidden states (presumably per-token punctuation scores). The
second element is always ``None``.
"""
x = self.embed(text)
# NOTE(review): leftover from an older masking path; no mask is built here.
# mask = self._target_mask(input)
# The encoder returns a 3-tuple; only the hidden states are used.
h, _, _ = self.encoder(x, text_lengths, vad_indexes=vad_indexes)
y = self.decoder(h)
return y, None
.inference(data_in, data_lengths, key, tokenizer, frontend, cache, **kwargs) L81Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
cache: dict = None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
assert len(data_in) == 1
if len(cache) == 0:
cache["pre_text"] = []
text = load_audio_text_image_video(data_in, data_type=kwargs.get("kwargs", "text"))[0]
text = "".join(cache["pre_text"]) + " " + text
View full source →.export(**kwargs) L221Export.
**kwargs — Additional keyword arguments. def export(self, **kwargs):
"""Prepare this model for export by delegating to ``export_rebuild_model``.
Args:
**kwargs: Forwarded unchanged to ``export_rebuild_model`` (``self`` is
passed separately as ``model``).
Returns:
Whatever ``export_rebuild_model`` returns (bound to ``models``; presumably
one or more rebuilt modules -- confirm in export_meta).
"""
# Function-local import: export helpers are only loaded when export() is called.
from .export_meta import export_rebuild_model
models = export_rebuild_model(model=self, **kwargs)
return models
CTTransformerStreaming.punc_forward(text, text_lengths, vad_indexes, **kwargs) Run the embedding, encoder and decoder over token ids and return the decoder output (no loss is computed).
input (torch.Tensor) — Input ids. (batch, len)hidden (torch.Tensor) — Target ids. (batch, len) def punc_forward(
self, text: torch.Tensor, text_lengths: torch.Tensor, vad_indexes: torch.Tensor, **kwargs
):
"""Run embedding -> encoder -> decoder over token ids and return the decoder output.
No loss is computed here.
Args:
text (torch.Tensor): Token ids fed to ``self.embed``; presumably
(batch, len) -- confirm.
text_lengths (torch.Tensor): Valid length of each sequence in ``text``.
vad_indexes (torch.Tensor): VAD boundary information, forwarded to the
encoder as ``vad_indexes=``.
**kwargs: Ignored by this method.
Returns:
tuple: ``(y, None)``, where ``y`` is ``self.decoder``'s output for the
encoder hidden states (presumably per-token punctuation scores). The
second element is always ``None``.
"""
x = self.embed(text)
# NOTE(review): leftover from an older masking path; no mask is built here.
# mask = self._target_mask(input)
# The encoder returns a 3-tuple; only the hidden states are used.
h, _, _ = self.encoder(x, text_lengths, vad_indexes=vad_indexes)
y = self.decoder(h)
return y, None
CTTransformerStreaming.with_vad() Returns True: this model is VAD-aware.
def with_vad(self):
    """Report that this model is VAD-aware.

    Its ``punc_forward`` accepts ``vad_indexes``, so callers can use this
    flag to decide whether to supply VAD boundary information.

    Returns:
        bool: Always ``True``.
    """
    uses_vad_boundaries = True
    return uses_vad_boundaries
CTTransformerStreaming.inference(data_in, data_lengths, key, tokenizer, frontend, cache, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
cache: dict = None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
assert len(data_in) == 1
if len(cache) == 0:
cache["pre_text"] = []
text = load_audio_text_image_video(data_in, data_type=kwargs.get("kwargs", "text"))[0]
text = "".join(cache["pre_text"]) + " " + text
View full source on GitHub →CTTransformerStreaming.export(**kwargs)Export.
**kwargs — Additional keyword arguments. def export(self, **kwargs):
"""Prepare this model for export by delegating to ``export_rebuild_model``.
Args:
**kwargs: Forwarded unchanged to ``export_rebuild_model`` (``self`` is
passed separately as ``model``).
Returns:
Whatever ``export_rebuild_model`` returns (bound to ``models``; presumably
one or more rebuilt modules -- confirm in export_meta).
"""
# Function-local import: export helpers are only loaded when export() is called.
from .export_meta import export_rebuild_model
models = export_rebuild_model(model=self, **kwargs)
return models
CTC — attention hybrid Encoder-Decoder modelclass Transformer(nn.Module):
"""CTC-attention hybrid Encoder-Decoder model"""
def __init__(
self,
specaug: str = None,
specaug_conf: dict = None,
normalize: str = None,
normalize_conf: dict = None,
encoder: str = None,
encoder_conf: dict = None,
ctc_conf: dict = None,
input_size: int = 80,
vocab_size: int = -1,
ignore_id: int = -1,
blank_id: int = 0,
sos: int = 1,
eos: int = 2,
length_normalized_loss: bool = False,
**kwargs,
):
"""Initialize Transformer.
Args:
specaug: TODO.
specaug_conf: Configuration dict for specaug.
normalize: TODO.
normalize_conf: Configuration dict for normalize.
encoder: TODO.
View full source on GitHub →.forward(speech, speech_lengths, text, text_lengths, **kwargs) L89Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
# import pdb;
# pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
loss_ctc, cer_ctc = None, None
stats = dict()
View full source →.encode(speech, speech_lengths, **kwargs) L134Frontend + Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Augment (training only), normalize, then run the encoder.
No frontend is applied here; ``speech`` is taken as input features.
Note that this method is used by asr_inference.py
Args:
speech: Input features, (Batch, Length, ...).
speech_lengths: Valid length of each item, (Batch, ).
**kwargs: Ignored by this method.
Returns:
Tuple of (encoder_out, encoder_out_lens) as returned by ``self.encoder``.
"""
# Data augmentation -- applied only while self.training is set
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
return encoder_out, encoder_out_lens
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L189Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
meta_data = {}
if (
isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
): # fbank
speech, speech_lengths = data_in, data_lengths
if len(speech.shape) < 3:
speech = speech[None, :, :]
View full source →Transformer.forward(speech, speech_lengths, text, text_lengths, **kwargs)Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
# import pdb;
# pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
loss_ctc, cer_ctc = None, None
stats = dict()
View full source on GitHub →Transformer.encode(speech, speech_lengths, **kwargs)Frontend + Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Augment (training only), normalize, then run the encoder.
No frontend is applied here; ``speech`` is taken as input features.
Note that this method is used by asr_inference.py
Args:
speech: Input features, (Batch, Length, ...).
speech_lengths: Valid length of each item, (Batch, ).
**kwargs: Ignored by this method.
Returns:
Tuple of (encoder_out, encoder_out_lens) as returned by ``self.encoder``.
"""
# Data augmentation -- applied only while self.training is set
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
return encoder_out, encoder_out_lens
Transformer.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
meta_data = {}
if (
isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
): # fbank
speech, speech_lengths = data_in, data_lengths
if len(speech.shape) < 3:
speech = speech[None, :, :]
View full source on GitHub → E-Branchformer: Enhanced Branchformer with improved merging. Uses element-wise merging instead of concatenation for parallel branches,
resulting in better parameter efficiency.
Inherits Transformer pipeline.
class EBranchformer(Transformer):
    """E-Branchformer ASR model.

    Reference: "E-Branchformer: Branchformer with Enhanced merging for
    speech recognition" (Kim et al., 2022).

    This class contributes no code of its own: construction, ``forward``,
    ``encode`` and ``inference`` are all inherited from ``Transformer``.
    The architecture-specific layers are therefore expected to come from
    the encoder selected through the ``encoder`` / ``encoder_conf``
    constructor arguments -- NOTE(review): confirm against the model config.
    """

    def __init__(self, *args, **kwargs):
        """Hand every constructor argument unchanged to ``Transformer``.

        Args:
            *args: Positional arguments for ``Transformer.__init__``.
            **kwargs: Keyword arguments for ``Transformer.__init__``.
        """
        super().__init__(*args, **kwargs)
E-Paraformer: Enhanced Paraformer with streaming support. Extended Paraformer supporting both offline and streaming modes
through dynamic masking in the encoder. Used for 2-pass decoding
where first pass provides streaming results and second pass refines.
Output — {"key": str, "text": str, "timestamp": [[start_ms, end_ms], ...]}. Author — Speech Lab of DAMO Academy, Alibaba Group. Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition (https://arxiv.org/abs/2206.08317).
Author — Kun Zou (chinazoukun@gmail.com). E-Paraformer: A Faster and Better Parallel Transformer for Non-autoregressive End-to-End Mandarin Speech Recognition (https://www.isca-archive.org/interspeech_2024/zou24_interspeech.pdf).
class EParaformer(torch.nn.Module):
"""E-Paraformer: Enhanced Paraformer with streaming support.
Extended Paraformer supporting both offline and streaming modes
through dynamic masking in the encoder. Used for 2-pass decoding
where first pass provides streaming results and second pass refines.
Output: {"key": str, "text": str, "timestamp": [[start_ms, end_ms], ...]}
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
Author: Kun Zou, chinazoukun@gmail.com
E-Paraformer: A Faster and Better Parallel Transformer for Non-autoregressive End-to-End Mandarin Speech Recognition
https://www.isca-archive.org/interspeech_2024/zou24_interspeech.pdf
"""
def __init__(
self,
specaug: Optional[str] = None,
specaug_conf: Optional[Dict] = None,
normalize: str = None,
normalize_conf: Optional[Dict] = None,
encoder: str = None,
encoder_conf: Optional[Dict] = None,
decoder: str = None,
decoder_conf: Optional[Dict] = None,
ctc: str = None,
ctc_conf: Optional[Dict] = None,
predictor: str = None,
View full source on GitHub →.forward(speech, speech_lengths, text, text_lengths, **kwargs) L220Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
loss_ctc, cer_ctc = None, None
loss_pre = None
stats = dict()
# decoder: CTC branch
View full source →.encode(speech, speech_lengths, **kwargs) L292Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Augment (training only), normalize, then run the encoder.
No frontend is applied here; ``speech`` is taken as input features.
Note that this method is used by asr_inference.py
Args:
speech: Input features, (Batch, Length, ...).
speech_lengths: Valid length of each item, (Batch, ).
**kwargs: Ignored by this method.
Returns:
Tuple of (encoder_out, encoder_out_lens). The encoder's third return
value is discarded.
"""
# autocast is disabled: the steps inside this block run without mixed precision.
with autocast(False):
# Data augmentation -- applied only while self.training is set
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder (returns a 3-tuple; the last element is unused here)
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
# If the first encoder output is itself a tuple, keep only its first element.
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
return encoder_out, encoder_out_lens
.calc_predictor(encoder_out, encoder_out_lens) L321Calc predictor.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths. def calc_predictor(self, encoder_out, encoder_out_lens):
"""Run the predictor on encoder output, without target tokens.
Args:
encoder_out: Encoder output tensor; dim 1 is used as the time axis
(``maxlen=encoder_out.size(1)``).
encoder_out_lens: Valid length of each sequence in ``encoder_out``.
Returns:
tuple: ``(pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index)``,
exactly as returned by ``self.predictor``.
"""
# Valid-position mask: invert the padding mask and insert a singleton axis at
# dim 1 (presumably (batch, 1, time) -- confirm make_pad_mask's output shape).
encoder_out_mask = (
~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
).to(encoder_out.device)
# No targets are supplied: the predictor's second argument is None.
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(
encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id
)
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
.cal_decoder_with_predictor(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens) L337Cal decoder with predictor.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.sematic_embeds — TODO.ys_pad_lens — Lengths of ys_pad. def cal_decoder_with_predictor(
self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
):
"""Run the decoder and return log-probabilities over its last dimension.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
sematic_embeds: (sic, "semantic") decoder input embeddings -- presumably
the predictor's acoustic embeddings; confirm with callers.
ys_pad_lens: Lengths of the decoder input sequences; returned unchanged.
Returns:
tuple: ``(decoder_out, ys_pad_lens)``, where ``decoder_out`` is the
log-softmax (over ``dim=-1``) of the decoder's first output.
"""
decoder_outs = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)
# Only the first decoder output is used.
decoder_out = decoder_outs[0]
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out, ys_pad_lens
.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds) L425Sampler.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.ys_pad — TODO.ys_pad_lens — Lengths of ys_pad.pre_acoustic_embeds — TODO. def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
"""Sampler.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
ys_pad: TODO.
ys_pad_lens: Lengths of ys_pad.
pre_acoustic_embeds: TODO.
"""
tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
ys_pad.device
)
ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
if self.share_embedding:
ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
else:
ys_pad_embed = self.decoder.embed(ys_pad_masked)
with torch.no_grad():
decoder_outs = self.decoder(
encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
)
decoder_out, _ = decoder_outs[0], decoder_outs[1]
pred_tokens = decoder_out.argmax(-1)
nonpad_positions = ys_pad.ne(self.ignore_id)
seq_lens = (nonpad_positions).sum(1)
same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
input_mask = torch.ones_like(nonpad_positions)
bsz, seq_len = ys_pad.size()
View full source →.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds) L474Sampler with grad.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.ys_pad — TODO.ys_pad_lens — Lengths of ys_pad.pre_acoustic_embeds — TODO. def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
"""Sampler with grad.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
ys_pad: TODO.
ys_pad_lens: Lengths of ys_pad.
pre_acoustic_embeds: TODO.
"""
tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
ys_pad.device
)
ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
if self.share_embedding:
ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
else:
ys_pad_embed = self.decoder.embed(ys_pad_masked)
decoder_outs = self.decoder(
encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
)
decoder_out, _ = decoder_outs[0], decoder_outs[1]
pred_tokens = decoder_out.argmax(-1)
nonpad_positions = ys_pad.ne(self.ignore_id)
seq_lens = (nonpad_positions).sum(1)
same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
input_mask = torch.ones_like(nonpad_positions)
bsz, seq_len = ys_pad.size()
for li in range(bsz):
View full source →.init_beam_search(**kwargs) L548Init beam search.
**kwargs — Additional keyword arguments. def init_beam_search(
self,
**kwargs,
):
"""Init beam search.
Args:
**kwargs: Additional keyword arguments.
"""
from funasr.models.paraformer.search import BeamSearchPara
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.models.transformer.scorers.length_bonus import LengthBonus
# 1. Build ASR model
scorers = {}
if self.ctc != None:
ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
scorers.update(ctc=ctc)
token_list = kwargs.get("token_list")
scorers.update(
length_bonus=LengthBonus(len(token_list)),
)
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
weights = dict(
View full source →.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L600Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
# init beamsearch
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = (
kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
)
pred_timestamp = kwargs.get("pred_timestamp", False)
if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
View full source →.export(**kwargs) L765Export.
**kwargs — Additional keyword arguments. def export(self, **kwargs):
"""Prepare this model for export by delegating to ``export_rebuild_model``.
Args:
**kwargs: Forwarded to ``export_rebuild_model`` (``self`` is passed
separately as ``model``). ``max_seq_len`` defaults to 512 when the
caller does not supply it.
Returns:
Whatever ``export_rebuild_model`` returns (bound to ``models``; presumably
one or more rebuilt modules -- confirm in export_meta).
"""
# Function-local import: export helpers are only loaded when export() is called.
from .export_meta import export_rebuild_model
# Default max_seq_len to 512 only if the caller did not pass one.
if "max_seq_len" not in kwargs:
kwargs["max_seq_len"] = 512
models = export_rebuild_model(model=self, **kwargs)
return models
EParaformer.forward(speech, speech_lengths, text, text_lengths, **kwargs)Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
loss_ctc, cer_ctc = None, None
loss_pre = None
stats = dict()
# decoder: CTC branch
View full source on GitHub →EParaformer.encode(speech, speech_lengths, **kwargs)Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Augment (training only), normalize, then run the encoder.
No frontend is applied here; ``speech`` is taken as input features.
Note that this method is used by asr_inference.py
Args:
speech: Input features, (Batch, Length, ...).
speech_lengths: Valid length of each item, (Batch, ).
**kwargs: Ignored by this method.
Returns:
Tuple of (encoder_out, encoder_out_lens). The encoder's third return
value is discarded.
"""
# autocast is disabled: the steps inside this block run without mixed precision.
with autocast(False):
# Data augmentation -- applied only while self.training is set
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder (returns a 3-tuple; the last element is unused here)
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
# If the first encoder output is itself a tuple, keep only its first element.
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
return encoder_out, encoder_out_lens
EParaformer.calc_predictor(encoder_out, encoder_out_lens)Calc predictor.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths. def calc_predictor(self, encoder_out, encoder_out_lens):
"""Run the predictor on encoder output, without target tokens.
Args:
encoder_out: Encoder output tensor; dim 1 is used as the time axis
(``maxlen=encoder_out.size(1)``).
encoder_out_lens: Valid length of each sequence in ``encoder_out``.
Returns:
tuple: ``(pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index)``,
exactly as returned by ``self.predictor``.
"""
# Valid-position mask: invert the padding mask and insert a singleton axis at
# dim 1 (presumably (batch, 1, time) -- confirm make_pad_mask's output shape).
encoder_out_mask = (
~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
).to(encoder_out.device)
# No targets are supplied: the predictor's second argument is None.
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(
encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id
)
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
EParaformer.cal_decoder_with_predictor(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)Cal decoder with predictor.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.sematic_embeds — TODO.ys_pad_lens — Lengths of ys_pad. def cal_decoder_with_predictor(
self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
):
"""Run the decoder and return log-probabilities over its last dimension.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
sematic_embeds: (sic, "semantic") decoder input embeddings -- presumably
the predictor's acoustic embeddings; confirm with callers.
ys_pad_lens: Lengths of the decoder input sequences; returned unchanged.
Returns:
tuple: ``(decoder_out, ys_pad_lens)``, where ``decoder_out`` is the
log-softmax (over ``dim=-1``) of the decoder's first output.
"""
decoder_outs = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)
# Only the first decoder output is used.
decoder_out = decoder_outs[0]
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out, ys_pad_lens
EParaformer.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds)Sampler.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.ys_pad — TODO.ys_pad_lens — Lengths of ys_pad.pre_acoustic_embeds — TODO. def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
"""Sampler.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
ys_pad: TODO.
ys_pad_lens: Lengths of ys_pad.
pre_acoustic_embeds: TODO.
"""
tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
ys_pad.device
)
ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
if self.share_embedding:
ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
else:
ys_pad_embed = self.decoder.embed(ys_pad_masked)
with torch.no_grad():
decoder_outs = self.decoder(
encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
)
decoder_out, _ = decoder_outs[0], decoder_outs[1]
pred_tokens = decoder_out.argmax(-1)
nonpad_positions = ys_pad.ne(self.ignore_id)
seq_lens = (nonpad_positions).sum(1)
same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
input_mask = torch.ones_like(nonpad_positions)
bsz, seq_len = ys_pad.size()
View full source on GitHub →EParaformer.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds)Sampler with grad.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.ys_pad — TODO.ys_pad_lens — Lengths of ys_pad.pre_acoustic_embeds — TODO. def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
"""Sampler with grad.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
ys_pad: TODO.
ys_pad_lens: Lengths of ys_pad.
pre_acoustic_embeds: TODO.
"""
tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
ys_pad.device
)
ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
if self.share_embedding:
ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
else:
ys_pad_embed = self.decoder.embed(ys_pad_masked)
decoder_outs = self.decoder(
encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
)
decoder_out, _ = decoder_outs[0], decoder_outs[1]
pred_tokens = decoder_out.argmax(-1)
nonpad_positions = ys_pad.ne(self.ignore_id)
seq_lens = (nonpad_positions).sum(1)
same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
input_mask = torch.ones_like(nonpad_positions)
bsz, seq_len = ys_pad.size()
for li in range(bsz):
View full source on GitHub →EParaformer.init_beam_search(**kwargs)Init beam search.
**kwargs — Additional keyword arguments. def init_beam_search(
self,
**kwargs,
):
"""Init beam search.
Args:
**kwargs: Additional keyword arguments.
"""
from funasr.models.paraformer.search import BeamSearchPara
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.models.transformer.scorers.length_bonus import LengthBonus
# 1. Build ASR model
scorers = {}
if self.ctc != None:
ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
scorers.update(ctc=ctc)
token_list = kwargs.get("token_list")
scorers.update(
length_bonus=LengthBonus(len(token_list)),
)
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
weights = dict(
View full source on GitHub →EParaformer.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
# init beamsearch
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = (
kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
)
pred_timestamp = kwargs.get("pred_timestamp", False)
if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
View full source on GitHub →EParaformer.export(**kwargs)Export.
**kwargs — Additional keyword arguments. def export(self, **kwargs):
"""Prepare this model for export by delegating to ``export_rebuild_model``.
Args:
**kwargs: Forwarded to ``export_rebuild_model`` (``self`` is passed
separately as ``model``). ``max_seq_len`` defaults to 512 when the
caller does not supply it.
Returns:
Whatever ``export_rebuild_model`` returns (bound to ``models``; presumably
one or more rebuilt modules -- confirm in export_meta).
"""
# Function-local import: export helpers are only loaded when export() is called.
from .export_meta import export_rebuild_model
# Default max_seq_len to 512 only if the caller did not pass one.
if "max_seq_len" not in kwargs:
kwargs["max_seq_len"] = 512
models = export_rebuild_model(model=self, **kwargs)
return models
Author — Ziyang Ma, Zhisheng Zheng, Jiaxin Ye, Jinchao Li, Zhifu Gao, Shiliang Zhang, Xie Chen. emotion2vec: Self-Supervised Pre-Training for Speech Emotion Representation (https://arxiv.org/abs/2312.15185).
class Emotion2vec(torch.nn.Module):
"""
Author: Ziyang Ma, Zhisheng Zheng, Jiaxin Ye, Jinchao Li, Zhifu Gao, Shiliang Zhang, Xie Chen
emotion2vec: Self-Supervised Pre-Training for Speech Emotion Representation
https://arxiv.org/abs/2312.15185
"""
def __init__(self, **kwargs):
"""Initialize Emotion2vec.
Args:
**kwargs: Additional keyword arguments.
"""
super().__init__()
# import pdb; pdb.set_trace()
cfg = OmegaConf.create(kwargs["model_conf"])
self.cfg = cfg
make_layer_norm = partial(
torch.nn.LayerNorm, eps=cfg.get("norm_eps"), elementwise_affine=cfg.get("norm_affine")
)
def make_block(drop_path, dim=None, heads=None):
"""Make block.
Args:
drop_path: TODO.
dim: TODO.
heads: TODO.
"""
View full source on GitHub →.forward(source, target, id, mode, padding_mask, mask, features_only, force_remove_masked, remove_extra_tokens, precomputed_mask, **kwargs) L121Forward pass for training.
source — TODO.target — TODO.id — TODO.mode — TODO.padding_mask — TODO.mask — TODO.features_only — TODO.force_remove_masked — TODO.remove_extra_tokens — TODO.precomputed_mask — TODO.**kwargs — Additional keyword arguments. def forward(
self,
source,
target=None,
id=None,
mode=None,
padding_mask=None,
mask=True,
features_only=False,
force_remove_masked=False,
remove_extra_tokens=True,
precomputed_mask=None,
**kwargs,
):
"""Forward pass for training.
Args:
source: TODO.
target: TODO.
id: TODO.
mode: TODO.
padding_mask: TODO.
mask: TODO.
features_only: TODO.
force_remove_masked: TODO.
remove_extra_tokens: TODO.
precomputed_mask: TODO.
**kwargs: Additional keyword arguments.
"""
View full source →.extract_features(source, mode, padding_mask, mask, remove_extra_tokens) L212Extract features.
def extract_features(
    self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True
):
    """Run ``self.forward`` in features-only mode and return its result.

    This is a thin convenience wrapper: it forwards its arguments to
    ``self.forward`` and always sets ``features_only=True``. Unlike
    ``forward`` (whose ``mask`` default is True), masking is off by default here.

    Args:
        source: Model input passed positionally to ``forward`` (presumably
            a waveform batch -- TODO confirm against ``forward``).
        mode: Passed through to ``forward``.
        padding_mask: Passed through to ``forward``.
        mask: Passed through to ``forward``; defaults to False.
        remove_extra_tokens: Passed through to ``forward``.

    Returns:
        Whatever ``self.forward`` returns for ``features_only=True``.
    """
    forward_kwargs = dict(
        mode=mode,
        padding_mask=padding_mask,
        mask=mask,
        features_only=True,
        remove_extra_tokens=remove_extra_tokens,
    )
    return self.forward(source, **forward_kwargs)
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L234Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
# if source_file.endswith('.wav'):
# wav, sr = sf.read(source_file)
# channel = sf.info(source_file).channels
# assert sr == 16e3, "Sample rate should be 16kHz, but got {}in file {}".format(sr, source_file)
# assert channel == 1, "Channel should be 1, but got {} in file {}".format(channel, source_file)
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
granularity = kwargs.get("granularity", "utterance")
extract_embedding = kwargs.get("extract_embedding", True)
if self.proj is None:
extract_embedding = True
meta_data = {}
View full source →.export(**kwargs) L320Export.
def export(self, **kwargs):
    """Prepare this model for export via ``export_rebuild_model``.

    The helper is imported at call time from the sibling ``export_meta``
    module and invoked with ``model=self`` plus every keyword argument given.

    Args:
        **kwargs: Export options forwarded verbatim to ``export_rebuild_model``.

    Returns:
        Whatever ``export_rebuild_model`` returns (presumably the rebuilt
        model(s) to export -- TODO confirm in export_meta).
    """
    from .export_meta import export_rebuild_model

    return export_rebuild_model(model=self, **kwargs)
Emotion2vec.forward(source, target, id, mode, padding_mask, mask, features_only, force_remove_masked, remove_extra_tokens, precomputed_mask, **kwargs)Forward pass for training.
source — TODO.target — TODO.id — TODO.mode — TODO.padding_mask — TODO.mask — TODO.features_only — TODO.force_remove_masked — TODO.remove_extra_tokens — TODO.precomputed_mask — TODO.**kwargs — Additional keyword arguments. def forward(
self,
source,
target=None,
id=None,
mode=None,
padding_mask=None,
mask=True,
features_only=False,
force_remove_masked=False,
remove_extra_tokens=True,
precomputed_mask=None,
**kwargs,
):
"""Forward pass for training.
Args:
source: TODO.
target: TODO.
id: TODO.
mode: TODO.
padding_mask: TODO.
mask: TODO.
features_only: TODO.
force_remove_masked: TODO.
remove_extra_tokens: TODO.
precomputed_mask: TODO.
**kwargs: Additional keyword arguments.
"""
View full source on GitHub →Emotion2vec.extract_features(source, mode, padding_mask, mask, remove_extra_tokens)Extract features.
def extract_features(
    self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True
):
    """Run ``self.forward`` in features-only mode and return its result.

    This is a thin convenience wrapper: it forwards its arguments to
    ``self.forward`` and always sets ``features_only=True``. Unlike
    ``forward`` (whose ``mask`` default is True), masking is off by default here.

    Args:
        source: Model input passed positionally to ``forward`` (presumably
            a waveform batch -- TODO confirm against ``forward``).
        mode: Passed through to ``forward``.
        padding_mask: Passed through to ``forward``.
        mask: Passed through to ``forward``; defaults to False.
        remove_extra_tokens: Passed through to ``forward``.

    Returns:
        Whatever ``self.forward`` returns for ``features_only=True``.
    """
    forward_kwargs = dict(
        mode=mode,
        padding_mask=padding_mask,
        mask=mask,
        features_only=True,
        remove_extra_tokens=remove_extra_tokens,
    )
    return self.forward(source, **forward_kwargs)
Emotion2vec.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
# if source_file.endswith('.wav'):
# wav, sr = sf.read(source_file)
# channel = sf.info(source_file).channels
# assert sr == 16e3, "Sample rate should be 16kHz, but got {}in file {}".format(sr, source_file)
# assert channel == 1, "Channel should be 1, but got {} in file {}".format(channel, source_file)
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
granularity = kwargs.get("granularity", "utterance")
extract_embedding = kwargs.get("extract_embedding", True)
if self.proj is None:
extract_embedding = True
meta_data = {}
View full source on GitHub →Emotion2vec.export(**kwargs)Export.
def export(self, **kwargs):
    """Prepare this model for export via ``export_rebuild_model``.

    The helper is imported at call time from the sibling ``export_meta``
    module and invoked with ``model=self`` plus every keyword argument given.

    Args:
        **kwargs: Export options forwarded verbatim to ``export_rebuild_model``.

    Returns:
        Whatever ``export_rebuild_model`` returns (presumably the rebuilt
        model(s) to export -- TODO confirm in export_meta).
    """
    from .export_meta import export_rebuild_model

    return export_rebuild_model(model=self, **kwargs)
ERes2NetV2: Enhanced Res2Net v2 for Speaker Verification. Improved speaker embedding model based on Res2Net architecture with
multi-scale feature aggregation. Provides 192-dim speaker embeddings for speaker verification and diarization.
Better than CAM++ for short-duration audio (< 3s) speaker feature extraction.
Output — {"spk_embedding": Tensor of shape (1, 192)}class ERes2NetV2SV(torch.nn.Module):
"""ERes2NetV2: Enhanced Res2Net v2 for Speaker Verification.
Improved speaker embedding model based on Res2Net architecture with
multi-scale feature aggregation. Provides 192-dim speaker embeddings
for speaker verification and diarization.
Better than CAM++ for short-duration audio (< 3s) speaker feature extraction.
Output: {"spk_embedding": Tensor of shape (1, 192)}
"""
def __init__(
self,
feat_dim=80,
embedding_size=192,
m_channels=64,
baseWidth=26,
scale=2,
expansion=2,
num_blocks=[3, 4, 6, 3],
pooling_func="TSTP",
two_emb_layer=False,
**kwargs,
):
"""Initialize ERes2NetV2SV.
Args:
feat_dim: Size/dimension parameter.
embedding_size: Size/dimension parameter.
View full source on GitHub →.forward(x) L97Forward pass for training.
def forward(self, x):
    """Apply the wrapped backbone ``self.model`` to ``x``.

    Args:
        x: Input handed unchanged to ``self.model`` (presumably a feature
            batch -- TODO confirm against the backbone).

    Returns:
        The backbone's output, unmodified.
    """
    backbone = self.model
    return backbone(x)
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L105Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
meta_data = {}
time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(
data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound"
)
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths, speech_times = extract_feature(audio_sample_list)
speech = speech.to(device=kwargs["device"])
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
View full source →ERes2NetV2SV.forward(x)Forward pass for training.
def forward(self, x):
    """Apply the wrapped backbone ``self.model`` to ``x``.

    Args:
        x: Input handed unchanged to ``self.model`` (presumably a feature
            batch -- TODO confirm against the backbone).

    Returns:
        The backbone's output, unmodified.
    """
    backbone = self.model
    return backbone(x)
ERes2NetV2SV.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
meta_data = {}
time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(
data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound"
)
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths, speech_times = extract_feature(audio_sample_list)
speech = speech.to(device=kwargs["device"])
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
View full source on GitHub → FSMN-KWS: Keyword Spotting model using FSMN architecture. Detects predefined keywords/wake words in audio streams.
Supports both offline and streaming operation.
Output — {"key": str, "value": detected_keyword_info}class FsmnKWS(torch.nn.Module):
"""FSMN-KWS: Keyword Spotting model using FSMN architecture.
Detects predefined keywords/wake words in audio streams.
Supports both offline and streaming operation.
Output: {"key": str, "value": detected_keyword_info}
"""
def __init__(
self,
specaug: Optional[str] = None,
specaug_conf: Optional[Dict] = None,
normalize: str = None,
normalize_conf: Optional[Dict] = None,
encoder: str = None,
encoder_conf: Optional[Dict] = None,
ctc: str = None,
ctc_conf: Optional[Dict] = None,
ctc_weight: float = 1.0,
input_size: int = 360,
vocab_size: int = -1,
ignore_id: int = -1,
blank_id: int = 0,
**kwargs,
):
"""Initialize FsmnKWS.
Args:
specaug: TODO.
View full source on GitHub →.forward(speech, speech_lengths, text, text_lengths, **kwargs) L104Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
loss_ctc, cer_ctc = self._calc_ctc_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# Collect CTC branch stats
stats = dict()
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
View full source →.encode(speech, speech_lengths, **kwargs) L146Encoder. Note that this method is used by asr_inference.py
def encode(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Augment/normalize the input features and run the encoder.

    Note (from the original docstring): this method is used by asr_inference.py.

    Args:
        speech: Input features, (Batch, Length, ...).
        speech_lengths: Valid length of each sequence, (Batch,).
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        ``(encoder_out, encoder_out_lens)``. If the encoder returns a tuple,
        only its first element is kept. ``encoder_out_lens`` is
        ``speech_lengths`` after the optional specaug/normalize steps; the
        encoder's own output length is not recomputed.
    """
    # Autocast is explicitly disabled, so this path runs in full precision.
    with autocast(False):
        # Data augmentation only while training.
        if self.specaug is not None and self.training:
            speech, speech_lengths = self.specaug(speech, speech_lengths)
        # Feature normalization, e.g. Global-CMVN or Utterance-CMVN.
        if self.normalize is not None:
            speech, speech_lengths = self.normalize(speech, speech_lengths)
        encoder_out = self.encoder(speech)
        encoder_out_lens = speech_lengths
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
    return encoder_out, encoder_out_lens
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L199Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list=None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
keywords = kwargs.get("keywords")
from funasr.utils.kws_utils import KwsCtcPrefixDecoder
self.kws_decoder = KwsCtcPrefixDecoder(
ctc=self.ctc,
keywords=keywords,
token_list=tokenizer.token_list,
seg_dict=tokenizer.seg_dict,
)
meta_data = {}
if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
View full source →FsmnKWS.forward(speech, speech_lengths, text, text_lengths, **kwargs)Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
loss_ctc, cer_ctc = self._calc_ctc_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# Collect CTC branch stats
stats = dict()
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
View full source on GitHub →FsmnKWS.encode(speech, speech_lengths, **kwargs)Encoder. Note that this method is used by asr_inference.py
def encode(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Augment/normalize the input features and run the encoder.

    Note (from the original docstring): this method is used by asr_inference.py.

    Args:
        speech: Input features, (Batch, Length, ...).
        speech_lengths: Valid length of each sequence, (Batch,).
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        ``(encoder_out, encoder_out_lens)``. If the encoder returns a tuple,
        only its first element is kept. ``encoder_out_lens`` is
        ``speech_lengths`` after the optional specaug/normalize steps; the
        encoder's own output length is not recomputed.
    """
    # Autocast is explicitly disabled, so this path runs in full precision.
    with autocast(False):
        # Data augmentation only while training.
        if self.specaug is not None and self.training:
            speech, speech_lengths = self.specaug(speech, speech_lengths)
        # Feature normalization, e.g. Global-CMVN or Utterance-CMVN.
        if self.normalize is not None:
            speech, speech_lengths = self.normalize(speech, speech_lengths)
        encoder_out = self.encoder(speech)
        encoder_out_lens = speech_lengths
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
    return encoder_out, encoder_out_lens
FsmnKWS.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list=None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
keywords = kwargs.get("keywords")
from funasr.utils.kws_utils import KwsCtcPrefixDecoder
self.kws_decoder = KwsCtcPrefixDecoder(
ctc=self.ctc,
keywords=keywords,
token_list=tokenizer.token_list,
seg_dict=tokenizer.seg_dict,
)
meta_data = {}
if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
View full source on GitHub → Author: Speech Lab of DAMO Academy, Alibaba Group. Deep-FSMN for Large Vocabulary Continuous Speech Recognition. https://arxiv.org/abs/1803.05030
class FsmnKWSConvert(torch.nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(
self,
encoder: str = None,
encoder_conf: Optional[Dict] = None,
ctc: str = None,
ctc_conf: Optional[Dict] = None,
ctc_weight: float = 1.0,
input_size: int = 360,
vocab_size: int = -1,
blank_id: int = 0,
**kwargs,
):
"""Initialize FsmnKWSConvert.
Args:
encoder: TODO.
encoder_conf: Configuration dict for encoder.
ctc: TODO.
ctc_conf: Configuration dict for ctc.
ctc_weight: TODO.
input_size: Size/dimension parameter.
vocab_size: Size/dimension parameter.
blank_id: TODO.
View full source on GitHub →.to_kaldi_net() L331To kaldi net.
def to_kaldi_net(self):
    """Delegate to ``self.encoder.to_kaldi_net()`` and return its result unchanged."""
    encoder = self.encoder
    return encoder.to_kaldi_net()
.to_pytorch_net(kaldi_file) L336To pytorch net.
def to_pytorch_net(self, kaldi_file):
    """Delegate to ``self.encoder.to_pytorch_net(kaldi_file)`` and return its result.

    Args:
        kaldi_file: Passed through unchanged to the encoder (presumably a
            path to a Kaldi-format model file -- TODO confirm in the encoder).
    """
    convert = self.encoder.to_pytorch_net
    return convert(kaldi_file)
FsmnKWSConvert.to_kaldi_net()To kaldi net.
def to_kaldi_net(self):
    """Delegate to ``self.encoder.to_kaldi_net()`` and return its result unchanged."""
    encoder = self.encoder
    return encoder.to_kaldi_net()
FsmnKWSConvert.to_pytorch_net(kaldi_file)To pytorch net.
def to_pytorch_net(self, kaldi_file):
    """Delegate to ``self.encoder.to_pytorch_net(kaldi_file)`` and return its result.

    Args:
        kaldi_file: Passed through unchanged to the encoder (presumably a
            path to a Kaldi-format model file -- TODO confirm in the encoder).
    """
    convert = self.encoder.to_pytorch_net
    return convert(kaldi_file)
FSMN-KWS-MT: Multi-Task FSMN Keyword Spotting. Keyword spotting with multi-task learning: simultaneously
detects keywords and performs filler token classification.
Improves keyword detection robustness through auxiliary tasks.
Output: {"key": str, "value": keyword_detection_result}. Author: Speech Lab of DAMO Academy, Alibaba Group. Deep-FSMN for Large Vocabulary Continuous Speech Recognition. https://arxiv.org/abs/1803.05030
class FsmnKWSMT(torch.nn.Module):
"""FSMN-KWS-MT: Multi-Task FSMN Keyword Spotting.
Keyword spotting with multi-task learning: simultaneously
detects keywords and performs filler token classification.
Improves keyword detection robustness through auxiliary tasks.
Output: {"key": str, "value": keyword_detection_result}
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(
self,
specaug: Optional[str] = None,
specaug_conf: Optional[Dict] = None,
normalize: str = None,
normalize_conf: Optional[Dict] = None,
encoder: str = None,
encoder_conf: Optional[Dict] = None,
ctc_conf: Optional[Dict] = None,
input_size: int = 360,
vocab_size: list = [],
ignore_id: int = -1,
blank_id: int = 0,
**kwargs,
):
"""Initialize FsmnKWSMT.
View full source on GitHub →.forward(speech, speech_lengths, text, text_lengths, text2, text2_lengths, **kwargs) L106Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,)text2 — (Batch, Length)text2_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
text2: torch.Tensor,
text2_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
text2: (Batch, Length)
text2_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# Encoder
encoder_out, encoder_out2, encoder_out_lens = self.encode(speech, speech_lengths)
loss_ctc, cer_ctc = self._calc_ctc_loss(
encoder_out, encoder_out_lens, text, text_lengths
View full source →.encode(speech, speech_lengths, **kwargs) L158Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
ind: int
"""
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
encoder_out, encoder_out2 = self.encoder(speech)
encoder_out_lens = speech_lengths
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
if isinstance(encoder_out2, tuple):
encoder_out2 = encoder_out2[0]
View full source →.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L242Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list=None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
keywords = kwargs.get("keywords")
from funasr.utils.kws_utils import KwsCtcPrefixDecoder
self.kws_decoder = KwsCtcPrefixDecoder(
ctc=self.ctc,
keywords=keywords,
token_list=tokenizer[0].token_list,
seg_dict=tokenizer[0].seg_dict,
)
self.kws_decoder2 = KwsCtcPrefixDecoder(
ctc=self.ctc2,
keywords=keywords,
View full source →FsmnKWSMT.forward(speech, speech_lengths, text, text_lengths, text2, text2_lengths, **kwargs)Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,)text2 — (Batch, Length)text2_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
text2: torch.Tensor,
text2_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
text2: (Batch, Length)
text2_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# Encoder
encoder_out, encoder_out2, encoder_out_lens = self.encode(speech, speech_lengths)
loss_ctc, cer_ctc = self._calc_ctc_loss(
encoder_out, encoder_out_lens, text, text_lengths
View full source on GitHub →FsmnKWSMT.encode(speech, speech_lengths, **kwargs)Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
ind: int
"""
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
encoder_out, encoder_out2 = self.encoder(speech)
encoder_out_lens = speech_lengths
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
if isinstance(encoder_out2, tuple):
encoder_out2 = encoder_out2[0]
View full source on GitHub →FsmnKWSMT.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list=None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
keywords = kwargs.get("keywords")
from funasr.utils.kws_utils import KwsCtcPrefixDecoder
self.kws_decoder = KwsCtcPrefixDecoder(
ctc=self.ctc,
keywords=keywords,
token_list=tokenizer[0].token_list,
seg_dict=tokenizer[0].seg_dict,
)
self.kws_decoder2 = KwsCtcPrefixDecoder(
ctc=self.ctc2,
keywords=keywords,
View full source on GitHub → Author: Speech Lab of DAMO Academy, Alibaba Group. Deep-FSMN for Large Vocabulary Continuous Speech Recognition. https://arxiv.org/abs/1803.05030
class FsmnKWSMTConvert(torch.nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(
self,
encoder: str = None,
encoder_conf: Optional[Dict] = None,
ctc_conf: Optional[Dict] = None,
ctc_weight: float = 1.0,
input_size: int = 360,
blank_id: int = 0,
**kwargs,
):
"""Initialize FsmnKWSMTConvert.
Args:
encoder: TODO.
encoder_conf: Configuration dict for encoder.
ctc_conf: Configuration dict for ctc.
ctc_weight: TODO.
input_size: Size/dimension parameter.
blank_id: TODO.
**kwargs: Additional keyword arguments.
"""
super().__init__()
View full source on GitHub →.to_kaldi_net() L390To kaldi net.
def to_kaldi_net(self):
    """Delegate to ``self.encoder.to_kaldi_net()`` and return its result unchanged."""
    encoder = self.encoder
    return encoder.to_kaldi_net()
.to_kaldi_net2() L394To kaldi net2.
def to_kaldi_net2(self):
    """Delegate to ``self.encoder.to_kaldi_net2()`` and return its result unchanged.

    Presumably the export for the encoder's second (multi-task) output head;
    TODO confirm in the encoder.
    """
    encoder = self.encoder
    return encoder.to_kaldi_net2()
.to_pytorch_net(kaldi_file) L398To pytorch net.
def to_pytorch_net(self, kaldi_file):
    """Delegate to ``self.encoder.to_pytorch_net(kaldi_file)`` and return its result.

    Args:
        kaldi_file: Passed through unchanged to the encoder (presumably a
            path to a Kaldi-format model file -- TODO confirm in the encoder).
    """
    convert = self.encoder.to_pytorch_net
    return convert(kaldi_file)
FsmnKWSMTConvert.to_kaldi_net()To kaldi net.
def to_kaldi_net(self):
    """Delegate to ``self.encoder.to_kaldi_net()`` and return its result unchanged."""
    encoder = self.encoder
    return encoder.to_kaldi_net()
FsmnKWSMTConvert.to_kaldi_net2()To kaldi net2.
def to_kaldi_net2(self):
    """Delegate to ``self.encoder.to_kaldi_net2()`` and return its result unchanged.

    Presumably the export for the encoder's second (multi-task) output head;
    TODO confirm in the encoder.
    """
    encoder = self.encoder
    return encoder.to_kaldi_net2()
FsmnKWSMTConvert.to_pytorch_net(kaldi_file)To pytorch net.
def to_pytorch_net(self, kaldi_file):
    """Delegate to ``self.encoder.to_pytorch_net(kaldi_file)`` and return its result.

    Args:
        kaldi_file: Passed through unchanged to the encoder (presumably a
            path to a Kaldi-format model file -- TODO confirm in the encoder).
    """
    convert = self.encoder.to_pytorch_net
    return convert(kaldi_file)
FSMN-based Voice Activity Detection (streaming/offline). Detects speech segments in audio, returning start/end timestamps (milliseconds).
Supports both offline (full audio) and streaming (chunk-by-chunk) modes.
Offline output: [{"key": "...", "value": [[start_ms, end_ms], ...]}]
Streaming output: [[beg, -1]] (start), [[-1, end]] (end), [[beg, end]] (complete), [] (no event)
Author: Speech Lab of DAMO Academy, Alibaba Group. Deep-FSMN for Large Vocabulary Continuous Speech Recognition. https://arxiv.org/abs/1803.05030
class FsmnVADStreaming(nn.Module):
"""FSMN-based Voice Activity Detection (streaming/offline).
Detects speech segments in audio, returning start/end timestamps (milliseconds).
Supports both offline (full audio) and streaming (chunk-by-chunk) modes.
Offline output: [{"key": "...", "value": [[start_ms, end_ms], ...]}]
Streaming output: [[beg, -1]] (start), [[-1, end]] (end), [[beg, end]] (complete), [] (no event)
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(
self,
encoder: str = None,
encoder_conf: Optional[Dict] = None,
vad_post_args: Dict[str, Any] = None,
**kwargs,
):
"""Initialize FsmnVADStreaming.
Args:
encoder: TODO.
encoder_conf: Configuration dict for encoder.
vad_post_args: TODO.
**kwargs: Additional keyword arguments.
"""
super().__init__()
View full source on GitHub →.ResetDetection(cache) L382Resetdetection.
def ResetDetection(self, cache: dict = None):
    """Reset the per-segment detection state stored in ``cache``.

    Clears the silence counters, the confirmed start/end markers and the frame
    probabilities. Returns the state machine to "start point not detected" and
    resets the window detector.

    If a segment has already been emitted (``output_data_buf`` is non-empty),
    the sample, decibel and score buffers are also trimmed. After trimming
    they start at the end of the last emitted segment.

    Args:
        cache: Streaming state dict with ``"stats"`` and ``"windows_detector"``
            entries (presumably populated by ``init_cache`` — confirm).
    """
    cache = {} if cache is None else cache
    stats = cache["stats"]
    stats.continous_silence_frame_count = 0
    stats.latest_confirmed_speech_frame = 0
    stats.lastest_confirmed_silence_frame = -1
    stats.confirmed_start_frame = -1
    stats.confirmed_end_frame = -1
    stats.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
    cache["windows_detector"].Reset()
    stats.sil_frame = 0
    stats.frame_probs = []

    if not stats.output_data_buf:
        return
    last_segment = stats.output_data_buf[-1]
    # The last emitted segment must already be closed (contain its end point).
    assert last_segment.contain_seg_end_point == True  # noqa: E712
    drop_frames = int(last_segment.end_ms / self.vad_opts.frame_in_ms)
    # Frames removed by an earlier reset are already gone from the buffers,
    # so only the difference has to be sliced off now.
    newly_dropped = drop_frames - stats.last_drop_frames
    stats.last_drop_frames = drop_frames
    samples_per_frame = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
    stats.data_buf_all = stats.data_buf_all[newly_dropped * samples_per_frame :]
    stats.decibel = stats.decibel[newly_dropped:]
    stats.scores = stats.scores[:, newly_dropped:, :]
.ComputeDecibel(cache) L412Computedecibel.
def ComputeDecibel(self, cache: dict = None) -> None:
    """Append the per-frame log energy (dB) of the current chunk to the cache.

    Row 0 of ``stats.waveform`` is also appended to the running sample buffer
    ``stats.data_buf_all``. On the first call that buffer becomes the chunk
    itself, and ``stats.data_buf`` is pointed at it.

    Frames span ``frame_length_ms`` and advance by ``frame_in_ms``. Each
    frame's value is ``10 * log10(sum(x ** 2) + 1e-6)``.

    Args:
        cache: Streaming state dict whose ``"stats"`` entry holds ``waveform``
            (a 2-D tensor, row 0 used — presumably shaped (1, samples)),
            ``data_buf_all``, ``data_buf`` and the ``decibel`` list.
    """
    cache = {} if cache is None else cache
    stats = cache["stats"]
    opts = self.vad_opts
    frame_len = int(opts.frame_length_ms * opts.sample_rate / 1000)
    frame_shift = int(opts.frame_in_ms * opts.sample_rate / 1000)

    chunk = stats.waveform[0]
    if stats.data_buf_all is not None:
        stats.data_buf_all = torch.cat((stats.data_buf_all, chunk))
    else:
        # First chunk: data_buf aliases the whole buffer.
        stats.data_buf_all = chunk
        stats.data_buf = stats.data_buf_all

    samples = stats.waveform.numpy()
    starts = np.arange(0, samples.shape[1] - frame_len + 1, frame_shift)
    # (num_frames, frame_len) matrix of sample indices, one row per frame.
    frame_index = starts[:, np.newaxis] + np.arange(frame_len)
    energy = np.sum(np.square(samples[0, frame_index]), axis=1)
    stats.decibel.extend((10 * np.log10(energy + 0.000001)).tolist())
.ComputeScores(feats, cache) L443Computescores.
def ComputeScores(self, feats: torch.Tensor, cache: dict = None) -> None:
    """Run the encoder on a feature chunk and accumulate its frame scores.

    Args:
        feats: Feature tensor (e.g. fbank) whose dim 1 is the frame axis;
            presumably (batch, frames, dim).
        cache: Streaming state dict. ``cache["encoder"]`` is passed to the
            encoder as its own cache; ``cache["stats"]`` receives the scores.

    Side effects:
        Sets ``self.vad_opts.nn_eval_block_size`` to this chunk's frame count,
        adds that count to ``stats.frm_cnt``, and concatenates the scores onto
        ``stats.scores`` along dim 1.
    """
    cache = {} if cache is None else cache
    chunk_scores = self.encoder(feats, cache=cache["encoder"])  # B * T * D
    num_frames = chunk_scores.shape[1]
    assert num_frames == feats.shape[1], "The shape between feats and scores does not match"
    # Stored on the shared options: the frame-detection loops iterate over
    # this many trailing frames.
    self.vad_opts.nn_eval_block_size = num_frames
    stats = cache["stats"]
    stats.frm_cnt += num_frames
    if stats.scores is not None:
        stats.scores = torch.cat((stats.scores, chunk_scores), dim=1)
    else:
        stats.scores = chunk_scores
.PopDataBufTillFrame(frame_idx, cache) L463Popdatabuftillframe.
def PopDataBufTillFrame(self, frame_idx: int, cache: dict = None) -> None:  # need check again
    """Advance the start of the sample buffer up to frame ``frame_idx``.

    ``stats.data_buf`` is re-sliced from ``stats.data_buf_all`` one frame
    shift at a time, and ``stats.data_buf_start_frame`` is incremented for
    each frame dropped. Frame numbers are absolute. Subtracting
    ``stats.last_drop_frames`` turns them into offsets within
    ``data_buf_all``, whose head may already have been trimmed.

    Stops early, leaving ``data_buf_start_frame < frame_idx``, when fewer
    than one frame shift of samples remain in ``data_buf``. In that case the
    loop used to spin forever.

    Args:
        frame_idx: Absolute frame index the buffer should start at.
        cache: Streaming state dict holding ``"stats"``.
    """
    if cache is None:
        cache = {}
    stats = cache["stats"]
    frame_shift = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
    while stats.data_buf_start_frame < frame_idx:
        if len(stats.data_buf) < frame_shift:
            # Not enough buffered audio to drop another frame. Nothing inside
            # this loop could change that, so stop instead of hanging.
            break
        stats.data_buf_start_frame += 1
        stats.data_buf = stats.data_buf_all[
            (stats.data_buf_start_frame - stats.last_drop_frames) * frame_shift :
        ]
.PopDataToOutputBuf(start_frm, frm_cnt, first_frm_is_start_point, last_frm_is_end_point, end_point_is_sent_end, cache) L482Popdatatooutputbuf.
start_frm — TODO.frm_cnt — TODO.first_frm_is_start_point — TODO.last_frm_is_end_point — TODO.end_point_is_sent_end — TODO.cache — State cache dict for streaming inference. def PopDataToOutputBuf(
self,
start_frm: int,
frm_cnt: int,
first_frm_is_start_point: bool,
last_frm_is_end_point: bool,
end_point_is_sent_end: bool,
cache: dict = None,
) -> None:
"""Popdatatooutputbuf.
Args:
start_frm: TODO.
frm_cnt: TODO.
first_frm_is_start_point: TODO.
last_frm_is_end_point: TODO.
end_point_is_sent_end: TODO.
cache: State cache dict for streaming inference.
"""
if cache is None:
cache = {}
self.PopDataBufTillFrame(start_frm, cache=cache)
expected_sample_number = int(
frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000
)
if last_frm_is_end_point:
extra_sample = max(
0,
int(
self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000
View full source →.OnSilenceDetected(valid_frame, cache) L552Onsilencedetected.
def OnSilenceDetected(self, valid_frame: int, cache: dict = None):
    """Record ``valid_frame`` as the latest confirmed silence frame.

    While no speech start point has been detected yet, the sample buffer is
    also advanced up to ``valid_frame`` via ``PopDataBufTillFrame``
    (presumably so that leading silence is not retained).

    Args:
        valid_frame: Absolute index of the confirmed silence frame.
        cache: Streaming state dict holding ``"stats"``.
    """
    cache = {} if cache is None else cache
    stats = cache["stats"]
    # Attribute name is spelled "lastest" in the stats object.
    stats.lastest_confirmed_silence_frame = valid_frame
    waiting_for_start = (
        stats.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected
    )
    if waiting_for_start:
        self.PopDataBufTillFrame(valid_frame, cache=cache)
.OnVoiceDetected(valid_frame, cache) L568Onvoicedetected.
def OnVoiceDetected(self, valid_frame: int, cache: dict = None) -> None:
    """Record ``valid_frame`` as confirmed speech and emit that frame.

    Args:
        valid_frame: Absolute index of the confirmed speech frame.
        cache: Streaming state dict holding ``"stats"``.
    """
    cache = {} if cache is None else cache
    cache["stats"].latest_confirmed_speech_frame = valid_frame
    self.PopDataToOutputBuf(
        valid_frame,
        1,  # frm_cnt: a single frame
        False,  # first_frm_is_start_point
        False,  # last_frm_is_end_point
        False,  # end_point_is_sent_end
        cache=cache,
    )
.OnVoiceStart(start_frame, fake_result, cache) L580Onvoicestart.
def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache: dict = None) -> None:
    """Confirm a speech start point.

    If ``stats.confirmed_start_frame`` is still -1 it is set to
    ``start_frame``; otherwise a warning is printed and the existing value is
    kept. Afterwards, unless ``fake_result`` is true and only while the state
    machine is in "start point not detected", ``stats.confirmed_start_frame``
    is pushed to the output buffer as a segment start point.

    Args:
        start_frame: Absolute index of the detected start frame.
        fake_result: When true, only the marker is updated; nothing is emitted.
        cache: Streaming state dict holding ``"stats"``.
    """
    cache = {} if cache is None else cache
    stats = cache["stats"]
    if stats.confirmed_start_frame == -1:
        stats.confirmed_start_frame = start_frame
    else:
        print("not reset vad properly\n")
    if not fake_result and (
        stats.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected
    ):
        self.PopDataToOutputBuf(stats.confirmed_start_frame, 1, True, False, False, cache=cache)
.OnVoiceEnd(end_frame, fake_result, is_last_frame, cache) L605Onvoiceend.
def OnVoiceEnd(
    self, end_frame: int, fake_result: bool, is_last_frame: bool, cache: dict = None
) -> None:
    """Confirm a speech end point.

    1. Every frame after ``stats.latest_confirmed_speech_frame`` and before
       ``end_frame`` is emitted as speech via ``OnVoiceDetected``.
    2. ``stats.confirmed_end_frame`` is set to ``end_frame`` if it is still -1;
       otherwise a warning is printed and the old value is kept.
    3. Unless ``fake_result`` is true, ``stats.sil_frame`` is reset and
       ``stats.confirmed_end_frame`` is pushed to the output buffer as a
       segment end point, with ``is_last_frame`` as ``end_point_is_sent_end``.
    4. ``stats.number_end_time_detected`` is incremented in every case.

    Args:
        end_frame: Absolute index of the detected end frame.
        fake_result: When true, no end point is emitted.
        is_last_frame: Whether this end point is the end of the stream.
        cache: Streaming state dict holding ``"stats"``.
    """
    cache = {} if cache is None else cache
    stats = cache["stats"]
    # Fill in speech frames not yet reported before the end point.
    for frame in range(stats.latest_confirmed_speech_frame + 1, end_frame):
        self.OnVoiceDetected(frame, cache=cache)
    if stats.confirmed_end_frame == -1:
        stats.confirmed_end_frame = end_frame
    else:
        print("not reset vad properly\n")
    if not fake_result:
        stats.sil_frame = 0
        self.PopDataToOutputBuf(
            stats.confirmed_end_frame, 1, False, True, is_last_frame, cache=cache
        )
    stats.number_end_time_detected += 1
.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache) L633Maybeonvoiceendiflastframe.
def MaybeOnVoiceEndIfLastFrame(
    self, is_final_frame: bool, cur_frm_idx: int, cache: dict = None
) -> None:
    """Close the open speech segment when the final frame is reached.

    Does nothing unless ``is_final_frame`` is true. When it is true, calls
    ``OnVoiceEnd(cur_frm_idx, fake_result=False, is_last_frame=True)`` and
    moves the state machine to "end point detected".

    Args:
        is_final_frame: Whether ``cur_frm_idx`` is the last frame of the stream.
        cur_frm_idx: Absolute index of the current frame.
        cache: Streaming state dict holding ``"stats"``.
    """
    cache = {} if cache is None else cache
    if not is_final_frame:
        return
    self.OnVoiceEnd(cur_frm_idx, False, True, cache=cache)
    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
.GetLatency(cache) L649Getlatency.
def GetLatency(self, cache: dict = None) -> int:
    """Return the start-point detection latency in milliseconds.

    Args:
        cache: Streaming state dict, forwarded to ``LatencyFrmNumAtStartPoint``.

    Returns:
        ``LatencyFrmNumAtStartPoint(cache)`` frames times
        ``vad_opts.frame_in_ms``, truncated to ``int``.
    """
    cache = {} if cache is None else cache
    latency_frames = self.LatencyFrmNumAtStartPoint(cache=cache)
    return int(latency_frames * self.vad_opts.frame_in_ms)
.LatencyFrmNumAtStartPoint(cache) L659Latencyfrmnumatstartpoint.
def LatencyFrmNumAtStartPoint(self, cache: dict = None) -> int:
    """Return the start-point detection latency in frames.

    The latency is the window detector's ``GetWinSize()``. When
    ``vad_opts.do_extend`` is set, ``lookback_time_start_point / frame_in_ms``
    frames (truncated) are added.

    Args:
        cache: Streaming state dict holding ``"windows_detector"``.
    """
    cache = {} if cache is None else cache
    window_frames = cache["windows_detector"].GetWinSize()
    opts = self.vad_opts
    if not opts.do_extend:
        return window_frames
    return window_frames + int(opts.lookback_time_start_point / opts.frame_in_ms)
.GetFrameState(t, cache) L672Getframestate.
t — TODO.cache — State cache dict for streaming inference. def GetFrameState(self, t: int, cache: dict = None):
"""Getframestate.
Args:
t: TODO.
cache: State cache dict for streaming inference.
"""
if cache is None:
cache = {}
frame_state = FrameState.kFrameStateInvalid
if t >= len(cache["stats"].decibel):
return FrameState.kFrameStateSil
cur_decibel = cache["stats"].decibel[t]
cur_snr = cur_decibel - cache["stats"].noise_average_decibel
# for each frame, calc log posterior probability of each state
if cur_decibel < self.vad_opts.decibel_thres:
frame_state = FrameState.kFrameStateSil
self.DetectOneFrame(frame_state, t, False, cache=cache)
return frame_state
sum_score = 0.0
noise_prob = 0.0
assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num
if len(cache["stats"].sil_pdf_ids) > 0:
assert len(cache["stats"].scores) == 1 # 只支持batch_size = 1的测试
"""
- Change type of `sum_score` to float. The reason is that `sum_score` is a tensor with single element.
and `torch.Tensor` is slower `float` when tensor has only one element.
- Put the iteration of `sil_pdf_ids` inside `sum()` to reduce the overhead of creating a new list.
- The default `sil_pdf_ids` is [0], the `if` statement is used to reduce the overhead of expression
View full source →.forward(feats, waveform, cache, is_final, **kwargs) L737Forward pass for training.
feats — Feature tensor (e.g., fbank), shape (batch, frames, dim).waveform — TODO.cache — State cache dict for streaming inference.is_final — Whether this is the final chunk in streaming.**kwargs — Additional keyword arguments. def forward(
self,
feats: torch.Tensor,
waveform: torch.tensor,
cache: dict = None,
is_final: bool = False,
**kwargs,
):
"""Forward pass for training.
Args:
feats: Feature tensor (e.g., fbank), shape (batch, frames, dim).
waveform: TODO.
cache: State cache dict for streaming inference.
is_final: Whether this is the final chunk in streaming.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
# if len(cache) == 0:
# self.AllResetDetection()
# self.waveform = waveform # compute decibel for each frame
cache["stats"].waveform = waveform
is_streaming_input = kwargs.get("is_streaming_input", True)
self.ComputeDecibel(cache=cache)
self.ComputeScores(feats, cache=cache)
if not is_final:
self.DetectCommonFrames(cache=cache)
else:
self.DetectLastFrames(cache=cache)
View full source →.init_cache(cache, **kwargs) L820Init cache.
cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def init_cache(self, cache: dict = None, **kwargs):
"""Init cache.
Args:
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
cache["frontend"] = {}
cache["prev_samples"] = torch.empty(0)
cache["encoder"] = {}
if kwargs.get("max_end_silence_time") is not None:
# update the max_end_silence_time
self.vad_opts.max_end_silence_time = kwargs.get("max_end_silence_time")
windows_detector = WindowDetector(
self.vad_opts.window_size_ms,
self.vad_opts.sil_to_speech_time_thres,
self.vad_opts.speech_to_sil_time_thres,
self.vad_opts.frame_in_ms,
)
windows_detector.Reset()
stats = Stats(
sil_pdf_ids=self.vad_opts.sil_pdf_ids,
max_end_sil_frame_cnt_thresh=self.vad_opts.max_end_silence_time
- self.vad_opts.speech_to_sil_time_thres,
View full source →.inference(data_in, data_lengths, key, tokenizer, frontend, cache, **kwargs) L856Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
cache: dict = None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
if len(cache) == 0:
self.init_cache(cache, **kwargs)
meta_data = {}
chunk_size = kwargs.get("chunk_size", 60000) # 50ms
chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
View full source →.export(**kwargs) L970Export.
def export(self, **kwargs):
    """Build export-ready model(s) for this VAD.

    Delegates to ``export_rebuild_model`` from the sibling ``export_meta``
    module. The import is done lazily inside the method, presumably to keep
    export-only dependencies out of normal inference.

    Args:
        **kwargs: Forwarded unchanged to ``export_rebuild_model``.

    Returns:
        Whatever ``export_rebuild_model`` returns.
    """
    from .export_meta import export_rebuild_model

    return export_rebuild_model(model=self, **kwargs)
.DetectCommonFrames(cache) L982Detectcommonframes.
def DetectCommonFrames(self, cache: dict = None) -> int:
    """Run frame-level detection over the newest chunk of a non-final call.

    Does nothing once the state machine is in "end point detected".
    Otherwise it visits the last ``vad_opts.nn_eval_block_size`` frames in
    chronological order. For each frame:

    - its state is looked up with an index relative to the trimmed buffers
      (absolute index minus ``last_drop_frames``);
    - the state is then fed to ``DetectOneFrame`` with the absolute index,
      never flagged as final.

    Args:
        cache: Streaming state dict holding ``"stats"``.

    Returns:
        Always 0.
    """
    cache = {} if cache is None else cache
    stats = cache["stats"]
    if stats.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
        return 0
    for back in reversed(range(self.vad_opts.nn_eval_block_size)):
        abs_idx = stats.frm_cnt - 1 - back
        # last_drop_frames is re-read each time: detection may trim buffers.
        state = self.GetFrameState(abs_idx - stats.last_drop_frames, cache=cache)
        self.DetectOneFrame(state, abs_idx, False, cache=cache)
    return 0
.DetectLastFrames(cache) L1001Detectlastframes.
def DetectLastFrames(self, cache: dict = None) -> int:
    """Run frame-level detection over the newest chunk of the final call.

    Same traversal as ``DetectCommonFrames``, except that the very last frame
    (absolute index ``frm_cnt - 1``) is passed to ``DetectOneFrame`` with
    ``is_final_frame=True``. Does nothing once the state machine is in
    "end point detected".

    Args:
        cache: Streaming state dict holding ``"stats"``.

    Returns:
        Always 0.
    """
    cache = {} if cache is None else cache
    stats = cache["stats"]
    if stats.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
        return 0
    for back in reversed(range(self.vad_opts.nn_eval_block_size)):
        abs_idx = stats.frm_cnt - 1 - back
        state = self.GetFrameState(abs_idx - stats.last_drop_frames, cache=cache)
        # Only the newest frame (back == 0) is flagged as the final frame.
        self.DetectOneFrame(state, abs_idx, back == 0, cache=cache)
    return 0
.DetectOneFrame(cur_frm_state, cur_frm_idx, is_final_frame, cache) L1023Detectoneframe.
cur_frm_state — TODO.cur_frm_idx — TODO.is_final_frame — Boolean flag for final frame.cache — State cache dict for streaming inference. def DetectOneFrame(
self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = None
) -> None:
"""Detectoneframe.
Args:
cur_frm_state: TODO.
cur_frm_idx: TODO.
is_final_frame: Boolean flag for final frame.
cache: State cache dict for streaming inference.
"""
if cache is None:
cache = {}
tmp_cur_frm_state = FrameState.kFrameStateInvalid
if cur_frm_state == FrameState.kFrameStateSpeech:
if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
tmp_cur_frm_state = FrameState.kFrameStateSpeech
else:
tmp_cur_frm_state = FrameState.kFrameStateSil
elif cur_frm_state == FrameState.kFrameStateSil:
tmp_cur_frm_state = FrameState.kFrameStateSil
state_change = cache["windows_detector"].DetectOneFrame(
tmp_cur_frm_state, cur_frm_idx, cache=cache
)
frm_shift_in_ms = self.vad_opts.frame_in_ms
if AudioChangeState.kChangeStateSil2Speech == state_change:
silence_frame_count = cache["stats"].continous_silence_frame_count
cache["stats"].continous_silence_frame_count = 0
cache["stats"].pre_end_silence_detected = False
start_frame = 0
View full source →FsmnVADStreaming.ResetDetection(cache)Resetdetection.
def ResetDetection(self, cache: dict = None):
    """Reset the per-segment detection state stored in ``cache``.

    Clears the silence counters, the confirmed start/end markers and the frame
    probabilities. Returns the state machine to "start point not detected" and
    resets the window detector.

    If a segment has already been emitted (``output_data_buf`` is non-empty),
    the sample, decibel and score buffers are also trimmed. After trimming
    they start at the end of the last emitted segment.

    Args:
        cache: Streaming state dict with ``"stats"`` and ``"windows_detector"``
            entries (presumably populated by ``init_cache`` — confirm).
    """
    cache = {} if cache is None else cache
    stats = cache["stats"]
    stats.continous_silence_frame_count = 0
    stats.latest_confirmed_speech_frame = 0
    stats.lastest_confirmed_silence_frame = -1
    stats.confirmed_start_frame = -1
    stats.confirmed_end_frame = -1
    stats.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
    cache["windows_detector"].Reset()
    stats.sil_frame = 0
    stats.frame_probs = []

    if not stats.output_data_buf:
        return
    last_segment = stats.output_data_buf[-1]
    # The last emitted segment must already be closed (contain its end point).
    assert last_segment.contain_seg_end_point == True  # noqa: E712
    drop_frames = int(last_segment.end_ms / self.vad_opts.frame_in_ms)
    # Frames removed by an earlier reset are already gone from the buffers,
    # so only the difference has to be sliced off now.
    newly_dropped = drop_frames - stats.last_drop_frames
    stats.last_drop_frames = drop_frames
    samples_per_frame = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
    stats.data_buf_all = stats.data_buf_all[newly_dropped * samples_per_frame :]
    stats.decibel = stats.decibel[newly_dropped:]
    stats.scores = stats.scores[:, newly_dropped:, :]
FsmnVADStreaming.ComputeDecibel(cache)Computedecibel.
def ComputeDecibel(self, cache: dict = None) -> None:
    """Append the per-frame log energy (dB) of the current chunk to the cache.

    Row 0 of ``stats.waveform`` is also appended to the running sample buffer
    ``stats.data_buf_all``. On the first call that buffer becomes the chunk
    itself, and ``stats.data_buf`` is pointed at it.

    Frames span ``frame_length_ms`` and advance by ``frame_in_ms``. Each
    frame's value is ``10 * log10(sum(x ** 2) + 1e-6)``.

    Args:
        cache: Streaming state dict whose ``"stats"`` entry holds ``waveform``
            (a 2-D tensor, row 0 used — presumably shaped (1, samples)),
            ``data_buf_all``, ``data_buf`` and the ``decibel`` list.
    """
    cache = {} if cache is None else cache
    stats = cache["stats"]
    opts = self.vad_opts
    frame_len = int(opts.frame_length_ms * opts.sample_rate / 1000)
    frame_shift = int(opts.frame_in_ms * opts.sample_rate / 1000)

    chunk = stats.waveform[0]
    if stats.data_buf_all is not None:
        stats.data_buf_all = torch.cat((stats.data_buf_all, chunk))
    else:
        # First chunk: data_buf aliases the whole buffer.
        stats.data_buf_all = chunk
        stats.data_buf = stats.data_buf_all

    samples = stats.waveform.numpy()
    starts = np.arange(0, samples.shape[1] - frame_len + 1, frame_shift)
    # (num_frames, frame_len) matrix of sample indices, one row per frame.
    frame_index = starts[:, np.newaxis] + np.arange(frame_len)
    energy = np.sum(np.square(samples[0, frame_index]), axis=1)
    stats.decibel.extend((10 * np.log10(energy + 0.000001)).tolist())
FsmnVADStreaming.ComputeScores(feats, cache)Computescores.
def ComputeScores(self, feats: torch.Tensor, cache: dict = None) -> None:
    """Run the encoder on a feature chunk and accumulate its frame scores.

    Args:
        feats: Feature tensor (e.g. fbank) whose dim 1 is the frame axis;
            presumably (batch, frames, dim).
        cache: Streaming state dict. ``cache["encoder"]`` is passed to the
            encoder as its own cache; ``cache["stats"]`` receives the scores.

    Side effects:
        Sets ``self.vad_opts.nn_eval_block_size`` to this chunk's frame count,
        adds that count to ``stats.frm_cnt``, and concatenates the scores onto
        ``stats.scores`` along dim 1.
    """
    cache = {} if cache is None else cache
    chunk_scores = self.encoder(feats, cache=cache["encoder"])  # B * T * D
    num_frames = chunk_scores.shape[1]
    assert num_frames == feats.shape[1], "The shape between feats and scores does not match"
    # Stored on the shared options: the frame-detection loops iterate over
    # this many trailing frames.
    self.vad_opts.nn_eval_block_size = num_frames
    stats = cache["stats"]
    stats.frm_cnt += num_frames
    if stats.scores is not None:
        stats.scores = torch.cat((stats.scores, chunk_scores), dim=1)
    else:
        stats.scores = chunk_scores
FsmnVADStreaming.PopDataBufTillFrame(frame_idx, cache)Popdatabuftillframe.
def PopDataBufTillFrame(self, frame_idx: int, cache: dict = None) -> None:  # need check again
    """Advance the start of the sample buffer up to frame ``frame_idx``.

    ``stats.data_buf`` is re-sliced from ``stats.data_buf_all`` one frame
    shift at a time, and ``stats.data_buf_start_frame`` is incremented for
    each frame dropped. Frame numbers are absolute. Subtracting
    ``stats.last_drop_frames`` turns them into offsets within
    ``data_buf_all``, whose head may already have been trimmed.

    Stops early, leaving ``data_buf_start_frame < frame_idx``, when fewer
    than one frame shift of samples remain in ``data_buf``. In that case the
    loop used to spin forever.

    Args:
        frame_idx: Absolute frame index the buffer should start at.
        cache: Streaming state dict holding ``"stats"``.
    """
    if cache is None:
        cache = {}
    stats = cache["stats"]
    frame_shift = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
    while stats.data_buf_start_frame < frame_idx:
        if len(stats.data_buf) < frame_shift:
            # Not enough buffered audio to drop another frame. Nothing inside
            # this loop could change that, so stop instead of hanging.
            break
        stats.data_buf_start_frame += 1
        stats.data_buf = stats.data_buf_all[
            (stats.data_buf_start_frame - stats.last_drop_frames) * frame_shift :
        ]
FsmnVADStreaming.PopDataToOutputBuf(start_frm, frm_cnt, first_frm_is_start_point, last_frm_is_end_point, end_point_is_sent_end, cache)Popdatatooutputbuf.
start_frm — TODO.frm_cnt — TODO.first_frm_is_start_point — TODO.last_frm_is_end_point — TODO.end_point_is_sent_end — TODO.cache — State cache dict for streaming inference. def PopDataToOutputBuf(
self,
start_frm: int,
frm_cnt: int,
first_frm_is_start_point: bool,
last_frm_is_end_point: bool,
end_point_is_sent_end: bool,
cache: dict = None,
) -> None:
"""Popdatatooutputbuf.
Args:
start_frm: TODO.
frm_cnt: TODO.
first_frm_is_start_point: TODO.
last_frm_is_end_point: TODO.
end_point_is_sent_end: TODO.
cache: State cache dict for streaming inference.
"""
if cache is None:
cache = {}
self.PopDataBufTillFrame(start_frm, cache=cache)
expected_sample_number = int(
frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000
)
if last_frm_is_end_point:
extra_sample = max(
0,
int(
self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000
View full source on GitHub →FsmnVADStreaming.OnSilenceDetected(valid_frame, cache)Onsilencedetected.
def OnSilenceDetected(self, valid_frame: int, cache: dict = None):
    """Record ``valid_frame`` as the latest confirmed silence frame.

    While no speech start point has been detected yet, the sample buffer is
    also advanced up to ``valid_frame`` via ``PopDataBufTillFrame``
    (presumably so that leading silence is not retained).

    Args:
        valid_frame: Absolute index of the confirmed silence frame.
        cache: Streaming state dict holding ``"stats"``.
    """
    cache = {} if cache is None else cache
    stats = cache["stats"]
    # Attribute name is spelled "lastest" in the stats object.
    stats.lastest_confirmed_silence_frame = valid_frame
    waiting_for_start = (
        stats.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected
    )
    if waiting_for_start:
        self.PopDataBufTillFrame(valid_frame, cache=cache)
FsmnVADStreaming.OnVoiceDetected(valid_frame, cache)Onvoicedetected.
def OnVoiceDetected(self, valid_frame: int, cache: dict = None) -> None:
    """Record ``valid_frame`` as confirmed speech and emit that frame.

    Args:
        valid_frame: Absolute index of the confirmed speech frame.
        cache: Streaming state dict holding ``"stats"``.
    """
    cache = {} if cache is None else cache
    cache["stats"].latest_confirmed_speech_frame = valid_frame
    self.PopDataToOutputBuf(
        valid_frame,
        1,  # frm_cnt: a single frame
        False,  # first_frm_is_start_point
        False,  # last_frm_is_end_point
        False,  # end_point_is_sent_end
        cache=cache,
    )
FsmnVADStreaming.OnVoiceStart(start_frame, fake_result, cache)Onvoicestart.
def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache: dict = None) -> None:
    """Confirm a speech start point.

    If ``stats.confirmed_start_frame`` is still -1 it is set to
    ``start_frame``; otherwise a warning is printed and the existing value is
    kept. Afterwards, unless ``fake_result`` is true and only while the state
    machine is in "start point not detected", ``stats.confirmed_start_frame``
    is pushed to the output buffer as a segment start point.

    Args:
        start_frame: Absolute index of the detected start frame.
        fake_result: When true, only the marker is updated; nothing is emitted.
        cache: Streaming state dict holding ``"stats"``.
    """
    cache = {} if cache is None else cache
    stats = cache["stats"]
    if stats.confirmed_start_frame == -1:
        stats.confirmed_start_frame = start_frame
    else:
        print("not reset vad properly\n")
    if not fake_result and (
        stats.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected
    ):
        self.PopDataToOutputBuf(stats.confirmed_start_frame, 1, True, False, False, cache=cache)
FsmnVADStreaming.OnVoiceEnd(end_frame, fake_result, is_last_frame, cache)Onvoiceend.
def OnVoiceEnd(
    self, end_frame: int, fake_result: bool, is_last_frame: bool, cache: dict = None
) -> None:
    """Confirm a speech end point.

    1. Every frame after ``stats.latest_confirmed_speech_frame`` and before
       ``end_frame`` is emitted as speech via ``OnVoiceDetected``.
    2. ``stats.confirmed_end_frame`` is set to ``end_frame`` if it is still -1;
       otherwise a warning is printed and the old value is kept.
    3. Unless ``fake_result`` is true, ``stats.sil_frame`` is reset and
       ``stats.confirmed_end_frame`` is pushed to the output buffer as a
       segment end point, with ``is_last_frame`` as ``end_point_is_sent_end``.
    4. ``stats.number_end_time_detected`` is incremented in every case.

    Args:
        end_frame: Absolute index of the detected end frame.
        fake_result: When true, no end point is emitted.
        is_last_frame: Whether this end point is the end of the stream.
        cache: Streaming state dict holding ``"stats"``.
    """
    cache = {} if cache is None else cache
    stats = cache["stats"]
    # Fill in speech frames not yet reported before the end point.
    for frame in range(stats.latest_confirmed_speech_frame + 1, end_frame):
        self.OnVoiceDetected(frame, cache=cache)
    if stats.confirmed_end_frame == -1:
        stats.confirmed_end_frame = end_frame
    else:
        print("not reset vad properly\n")
    if not fake_result:
        stats.sil_frame = 0
        self.PopDataToOutputBuf(
            stats.confirmed_end_frame, 1, False, True, is_last_frame, cache=cache
        )
    stats.number_end_time_detected += 1
FsmnVADStreaming.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache)Maybeonvoiceendiflastframe.
def MaybeOnVoiceEndIfLastFrame(
    self, is_final_frame: bool, cur_frm_idx: int, cache: dict = None
) -> None:
    """Close the open speech segment when the final frame is reached.

    Does nothing unless ``is_final_frame`` is true. When it is true, calls
    ``OnVoiceEnd(cur_frm_idx, fake_result=False, is_last_frame=True)`` and
    moves the state machine to "end point detected".

    Args:
        is_final_frame: Whether ``cur_frm_idx`` is the last frame of the stream.
        cur_frm_idx: Absolute index of the current frame.
        cache: Streaming state dict holding ``"stats"``.
    """
    cache = {} if cache is None else cache
    if not is_final_frame:
        return
    self.OnVoiceEnd(cur_frm_idx, False, True, cache=cache)
    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
FsmnVADStreaming.GetLatency(cache)Getlatency.
def GetLatency(self, cache: dict = None) -> int:
    """Return the start-point detection latency in milliseconds.

    Args:
        cache: Streaming state dict, forwarded to ``LatencyFrmNumAtStartPoint``.

    Returns:
        ``LatencyFrmNumAtStartPoint(cache)`` frames times
        ``vad_opts.frame_in_ms``, truncated to ``int``.
    """
    cache = {} if cache is None else cache
    latency_frames = self.LatencyFrmNumAtStartPoint(cache=cache)
    return int(latency_frames * self.vad_opts.frame_in_ms)
FsmnVADStreaming.LatencyFrmNumAtStartPoint(cache)Latencyfrmnumatstartpoint.
def LatencyFrmNumAtStartPoint(self, cache: dict = None) -> int:
    """Return the start-point detection latency in frames.

    The latency is the window detector's ``GetWinSize()``. When
    ``vad_opts.do_extend`` is set, ``lookback_time_start_point / frame_in_ms``
    frames (truncated) are added.

    Args:
        cache: Streaming state dict holding ``"windows_detector"``.
    """
    cache = {} if cache is None else cache
    window_frames = cache["windows_detector"].GetWinSize()
    opts = self.vad_opts
    if not opts.do_extend:
        return window_frames
    return window_frames + int(opts.lookback_time_start_point / opts.frame_in_ms)
FsmnVADStreaming.GetFrameState(t, cache)Getframestate.
t — TODO.cache — State cache dict for streaming inference. def GetFrameState(self, t: int, cache: dict = None):
"""Getframestate.
Args:
t: TODO.
cache: State cache dict for streaming inference.
"""
if cache is None:
cache = {}
frame_state = FrameState.kFrameStateInvalid
if t >= len(cache["stats"].decibel):
return FrameState.kFrameStateSil
cur_decibel = cache["stats"].decibel[t]
cur_snr = cur_decibel - cache["stats"].noise_average_decibel
# for each frame, calc log posterior probability of each state
if cur_decibel < self.vad_opts.decibel_thres:
frame_state = FrameState.kFrameStateSil
self.DetectOneFrame(frame_state, t, False, cache=cache)
return frame_state
sum_score = 0.0
noise_prob = 0.0
assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num
if len(cache["stats"].sil_pdf_ids) > 0:
assert len(cache["stats"].scores) == 1 # 只支持batch_size = 1的测试
"""
- Change type of `sum_score` to float. The reason is that `sum_score` is a tensor with single element.
and `torch.Tensor` is slower `float` when tensor has only one element.
- Put the iteration of `sil_pdf_ids` inside `sum()` to reduce the overhead of creating a new list.
- The default `sil_pdf_ids` is [0], the `if` statement is used to reduce the overhead of expression
View full source on GitHub →FsmnVADStreaming.forward(feats, waveform, cache, is_final, **kwargs)Forward pass for training.
feats — Feature tensor (e.g., fbank), shape (batch, frames, dim).waveform — TODO.cache — State cache dict for streaming inference.is_final — Whether this is the final chunk in streaming.**kwargs — Additional keyword arguments. def forward(
self,
feats: torch.Tensor,
waveform: torch.tensor,
cache: dict = None,
is_final: bool = False,
**kwargs,
):
"""Forward pass for training.
Args:
feats: Feature tensor (e.g., fbank), shape (batch, frames, dim).
waveform: TODO.
cache: State cache dict for streaming inference.
is_final: Whether this is the final chunk in streaming.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
# if len(cache) == 0:
# self.AllResetDetection()
# self.waveform = waveform # compute decibel for each frame
cache["stats"].waveform = waveform
is_streaming_input = kwargs.get("is_streaming_input", True)
self.ComputeDecibel(cache=cache)
self.ComputeScores(feats, cache=cache)
if not is_final:
self.DetectCommonFrames(cache=cache)
else:
self.DetectLastFrames(cache=cache)
View full source on GitHub →FsmnVADStreaming.init_cache(cache, **kwargs)Init cache.
cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def init_cache(self, cache: dict = None, **kwargs):
"""Init cache.
Args:
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
cache["frontend"] = {}
cache["prev_samples"] = torch.empty(0)
cache["encoder"] = {}
if kwargs.get("max_end_silence_time") is not None:
# update the max_end_silence_time
self.vad_opts.max_end_silence_time = kwargs.get("max_end_silence_time")
windows_detector = WindowDetector(
self.vad_opts.window_size_ms,
self.vad_opts.sil_to_speech_time_thres,
self.vad_opts.speech_to_sil_time_thres,
self.vad_opts.frame_in_ms,
)
windows_detector.Reset()
stats = Stats(
sil_pdf_ids=self.vad_opts.sil_pdf_ids,
max_end_sil_frame_cnt_thresh=self.vad_opts.max_end_silence_time
- self.vad_opts.speech_to_sil_time_thres,
View full source on GitHub →FsmnVADStreaming.inference(data_in, data_lengths, key, tokenizer, frontend, cache, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
cache: dict = None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
if len(cache) == 0:
self.init_cache(cache, **kwargs)
meta_data = {}
chunk_size = kwargs.get("chunk_size", 60000) # 50ms
chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
View full source on GitHub →FsmnVADStreaming.export(**kwargs)Export.
def export(self, **kwargs):
    """Build export-ready model(s) for this VAD.

    Delegates to ``export_rebuild_model`` from the sibling ``export_meta``
    module. The import is done lazily inside the method, presumably to keep
    export-only dependencies out of normal inference.

    Args:
        **kwargs: Forwarded unchanged to ``export_rebuild_model``.

    Returns:
        Whatever ``export_rebuild_model`` returns.
    """
    from .export_meta import export_rebuild_model

    return export_rebuild_model(model=self, **kwargs)
FsmnVADStreaming.DetectCommonFrames(cache)Detect common frames.
# cache — State cache dict for streaming inference (must already contain "stats").
def DetectCommonFrames(self, cache: dict = None) -> int:
    """Run per-frame detection over the most recent evaluation block.

    Walks the last ``vad_opts.nn_eval_block_size`` frames, oldest first. For
    each one it looks up the frame state with ``GetFrameState`` and feeds it
    to ``DetectOneFrame`` as a non-final frame. It does nothing once the state
    machine has reached ``kVadInStateEndPointDetected``.

    Args:
        cache: Streaming state cache. It must already hold a ``"stats"``
            entry, as set up by ``init_cache``.

    Returns:
        Always ``0``.

    Raises:
        KeyError: If ``cache`` is ``None`` or has no ``"stats"`` entry.
    """
    # The former ``cache = {}`` fallback could only ever fail on the
    # ``cache["stats"]`` lookup; fail up front with an actionable message
    # (same exception type as before).
    if cache is None or "stats" not in cache:
        raise KeyError("cache has no 'stats' entry; initialise it with init_cache() first")
    if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
        return 0
    for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
        # GetFrameState receives the index offset by last_drop_frames;
        # DetectOneFrame receives the un-offset index.
        frame_state = self.GetFrameState(
            cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache
        )
        self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
    return 0
FsmnVADStreaming.DetectLastFrames(cache)Detect last frames.
# cache — State cache dict for streaming inference (must already contain "stats").
def DetectLastFrames(self, cache: dict = None) -> int:
    """Flush the final evaluation block of frames at the end of a stream.

    Walks the last ``vad_opts.nn_eval_block_size`` frames, oldest first. For
    each one it looks up the frame state with ``GetFrameState`` and feeds it
    to ``DetectOneFrame``. Only the newest frame (``i == 0``) is flagged with
    ``is_final_frame=True``. It does nothing once the state machine has
    reached ``kVadInStateEndPointDetected``.

    Args:
        cache: Streaming state cache. It must already hold a ``"stats"``
            entry, as set up by ``init_cache``.

    Returns:
        Always ``0``.

    Raises:
        KeyError: If ``cache`` is ``None`` or has no ``"stats"`` entry.
    """
    # The former ``cache = {}`` fallback could only ever fail on the
    # ``cache["stats"]`` lookup; fail up front with an actionable message
    # (same exception type as before).
    if cache is None or "stats" not in cache:
        raise KeyError("cache has no 'stats' entry; initialise it with init_cache() first")
    if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
        return 0
    for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
        # GetFrameState receives the index offset by last_drop_frames;
        # DetectOneFrame receives the un-offset index.
        frame_state = self.GetFrameState(
            cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache
        )
        # For i == 0 the index frm_cnt - 1 - i equals frm_cnt - 1, so the two
        # former branches differed only in the final-frame flag.
        self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, i == 0, cache=cache)
    return 0
FsmnVADStreaming.DetectOneFrame(cur_frm_state, cur_frm_idx, is_final_frame, cache)Detect one frame.
cur_frm_state — TODO.cur_frm_idx — TODO.is_final_frame — Boolean flag for final frame.cache — State cache dict for streaming inference. def DetectOneFrame(
self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = None
) -> None:
"""Detectoneframe.
Args:
cur_frm_state: TODO.
cur_frm_idx: TODO.
is_final_frame: Boolean flag for final frame.
cache: State cache dict for streaming inference.
"""
if cache is None:
cache = {}
tmp_cur_frm_state = FrameState.kFrameStateInvalid
if cur_frm_state == FrameState.kFrameStateSpeech:
if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
tmp_cur_frm_state = FrameState.kFrameStateSpeech
else:
tmp_cur_frm_state = FrameState.kFrameStateSil
elif cur_frm_state == FrameState.kFrameStateSil:
tmp_cur_frm_state = FrameState.kFrameStateSil
state_change = cache["windows_detector"].DetectOneFrame(
tmp_cur_frm_state, cur_frm_idx, cache=cache
)
frm_shift_in_ms = self.vad_opts.frame_in_ms
if AudioChangeState.kChangeStateSil2Speech == state_change:
silence_frame_count = cache["stats"].continous_silence_frame_count
cache["stats"].continous_silence_frame_count = 0
cache["stats"].pre_end_silence_detected = False
start_frame = 0
View full source on GitHub →Fun-ASR-Nano: End-to-End ASR Large Model. Trained on tens of millions of hours of real speech data.
Supports 31 languages including Chinese dialects and regional accents.
Output — {"key": ..., "text": ..., "timestamps": [{"token", "start_time", "end_time"}, ...],Note — Outputs punctuation natively — punc_model is NOT needed.Requirements — pip install tiktoken huggingface_hubclass FunASRNano(nn.Module):
"""Fun-ASR-Nano: End-to-End ASR Large Model.
Trained on tens of millions of hours of real speech data.
Supports 31 languages including Chinese dialects and regional accents.
Features:
- Character-level timestamps (via CTC forced alignment)
- Hotword customization
- Speaker diarization (when combined with spk_model)
- Lyrics and rap recognition
- Streaming chunk-by-chunk inference (demo2.py)
Output: {"key": ..., "text": ..., "timestamps": [{"token", "start_time", "end_time"}, ...],
"ctc_timestamps": [...]}
Note: Outputs punctuation natively — punc_model is NOT needed.
Requirements: pip install tiktoken huggingface_hub
"""
def __init__(
self,
audio_encoder: str = None,
audio_encoder_conf: dict = None,
audio_adaptor: str = None,
audio_adaptor_conf: dict = None,
llm: str = None,
llm_conf: dict = None,
input_size: int = 80,
View full source on GitHub →.forward(speech, speech_lengths, input_ids, attention_mask, labels_ids, fbank_beg, fbank_mask, **kwargs) L194Forward pass for training.
speech — Speech audio tensor, shape (batch, time).speech_lengths — Length of each speech sample.input_ids — TODO.attention_mask — TODO.labels_ids — TODO.fbank_beg — TODO.fbank_mask — TODO.**kwargs — Additional keyword arguments. def forward(
self,
speech: torch.Tensor = None,
speech_lengths: torch.Tensor = None,
input_ids: torch.Tensor = None,
attention_mask: torch.Tensor = None,
labels_ids: torch.Tensor = None,
fbank_beg: torch.Tensor = None,
fbank_mask: torch.Tensor = None,
**kwargs,
):
"""Forward pass for training.
Args:
speech: Speech audio tensor, shape (batch, time).
speech_lengths: Length of each speech sample.
input_ids: TODO.
attention_mask: TODO.
labels_ids: TODO.
fbank_beg: TODO.
fbank_mask: TODO.
**kwargs: Additional keyword arguments.
"""
batch_size, token_num = input_ids.shape
stats = {}
input_ids[input_ids < 0] = 0
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
if speech is not None:
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
View full source →.forward_export(speech, speech_lengths, **kwargs) L317Forward export.
# speech — Speech audio tensor, shape (batch, time). speech_lengths — Length of each speech sample. **kwargs — Additional keyword arguments.
def forward_export(self, speech, speech_lengths, **kwargs):
    """Run the encoder followed by the adaptor (the export graph).

    Args:
        speech: Batched speech input. It is documented as shape
            (batch, time); TODO confirm whether the encoder expects raw
            audio or features.
        speech_lengths: Valid length of each sample in ``speech``.
        **kwargs: Accepted for interface compatibility; unused.

    Returns:
        Tuple ``(encoder_out, encoder_out_lens)`` produced by
        ``self.audio_adaptor``.
    """
    encoded, encoded_lens = self.audio_encoder(speech, speech_lengths)
    adapted, adapted_lens = self.audio_adaptor(encoded, encoded_lens)
    return adapted, adapted_lens
.encode(speech, speech_lengths) L329Encode.
# speech — Speech audio tensor, shape (batch, time). speech_lengths — Length of each speech sample.
def encode(self, speech, speech_lengths):
    """Run only the audio encoder (no adaptor).

    Args:
        speech: Batched speech input. It is documented as shape
            (batch, time); TODO confirm whether the encoder expects raw
            audio or features.
        speech_lengths: Valid length of each sample in ``speech``.

    Returns:
        Tuple ``(encoder_out, encoder_out_lens)`` from ``self.audio_encoder``.
    """
    # Audio encoder only; unlike forward_export, the adaptor is not applied.
    out, out_lens = self.audio_encoder(speech, speech_lengths)
    return out, out_lens
.data_template(data) L341Data template.
# data — Chat-style list of message dicts with "role"/"content" (and optional "audio") keys.
def data_template(self, data):
    """Group chat messages by role.

    Args:
        data: Iterable of message dicts. Each must have ``"role"`` and
            ``"content"`` keys. ``"user"`` messages may also carry
            ``"audio"``. Messages with any other role are ignored.

    Returns:
        Dict with ``"system"``, ``"user"`` and ``"assistant"`` lists.
        A user message carrying audio becomes a ``[content, audio]`` pair.
        The system list is repeated ``len(user)`` times, so it is empty when
        there are no user messages.
    """
    grouped = {"system": [], "user": [], "assistant": []}
    for message in data:
        role = message["role"]
        content = message["content"]
        if role == "user" and "audio" in message:
            content = [content, message["audio"]]
        if role in ("system", "user", "assistant"):
            grouped[role].append(content)
    # Repeat the whole system list once per user turn.
    grouped["system"] = grouped["system"] * len(grouped["user"])
    return grouped
.data_load_speech(contents, tokenizer, frontend, meta_data, **kwargs) L371Data load speech.
contents — TODO.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.meta_data — TODO.**kwargs — Additional keyword arguments. def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):
"""Data load speech.
Args:
contents: TODO.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
meta_data: TODO.
**kwargs: Additional keyword arguments.
"""
system = contents["system"]
user = contents["user"]
assistant = contents["assistant"]
pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
do_think = True
sys_prompt = True
if "dataset_conf" in kwargs:
do_think = kwargs["dataset_conf"].get("do_think", True)
sys_prompt = kwargs["dataset_conf"].get("sys_prompt", True)
input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = (
[],
[],
[],
[],
[],
[],
[],
)
input_source_ids = []
View full source →.inference_prepare(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L533Inference prepare.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference_prepare(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Inference prepare.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
meta_data = {}
if len(data_in) > 1:
raise NotImplementedError("batch decoding is not implemented")
contents = self.data_template(data_in[0])
output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
batch = to_device(output, kwargs["device"])
# audio encoder
speech = batch["speech"]
View full source →.get_prompt(hotwords, language, itn) L632Get prompt.
# hotwords — Hotwords added as prompt context. language — Target language name. itn — Whether text normalization is applied.
def get_prompt(self, hotwords: list[str], language: str = None, itn: bool = True):
    """Build the transcription instruction that precedes the audio in the chat prompt.

    Args:
        hotwords: Hotwords added to a context section of the prompt.
            ``None``, an empty sequence or an empty string means no hotword
            section. A non-empty bare ``str`` is treated as a single hotword.
        language: Target language name, producing "语音转写成{language}".
            ``None`` produces the generic "语音转写".
        itn: When ``False``, appends ",不进行文本规整" (no text normalization).

    Returns:
        The prompt text, always ending with ":".
    """
    if isinstance(hotwords, str):
        # ", ".join("abc") would give "a, b, c"; keep a plain string whole.
        hotwords = [hotwords] if hotwords else []
    if hotwords:  # also tolerates None (len(None) used to raise TypeError)
        joined_hotwords = ", ".join(hotwords)
        prompt = "请结合上下文信息,更加准确地完成语音转写任务。如果没有相关信息,我们会留空。\n\n\n**上下文信息:**\n\n\n"
        prompt += f"热词列表:[{joined_hotwords}]\n"
    else:
        prompt = ""
    if language is None:
        prompt += "语音转写"
    else:
        prompt += f"语音转写成{language}"
    if not itn:
        prompt += ",不进行文本规整"
    return prompt + ":"
.generate_chatml(prompt, data) L654Generate chatml.
# prompt — Instruction text placed before the speech markers. data — Audio given as a str (embedded in the text) or a torch.Tensor (attached under "audio").
def generate_chatml(self, prompt: str, data: Union[str, torch.Tensor]):
    """Wrap one audio input into a three-turn chat (system, user, assistant).

    Args:
        prompt: Instruction text (e.g. from ``get_prompt``) placed before the
            ``<|startofspeech|>...<|endofspeech|>`` markers in the user turn.
        data: Either a ``str``, which is embedded between the markers as
            ``!{data}`` (presumably a path or URL; TODO confirm), or a
            ``torch.Tensor``, which is attached under the user turn's
            ``"audio"`` key while the markers enclose ``!!``.

    Returns:
        List of three message dicts. The assistant turn's content is the
        literal placeholder ``"null"``.

    Raises:
        TypeError: If ``data`` is neither a ``str`` nor a ``torch.Tensor``.
    """
    if isinstance(data, str):
        user_turn = {"role": "user", "content": f"{prompt}<|startofspeech|>!{data}<|endofspeech|>"}
    elif isinstance(data, torch.Tensor):
        user_turn = {
            "role": "user",
            "content": f"{prompt}<|startofspeech|>!!<|endofspeech|>",
            "audio": data,
        }
    else:
        # Previously this fell through and returned None, which only failed
        # later and far from the cause.
        raise TypeError(
            f"generate_chatml expects data to be a str or torch.Tensor, got {type(data).__name__}"
        )
    return [
        {"role": "system", "content": "You are a helpful assistant."},
        user_turn,
        {"role": "assistant", "content": "null"},
    ]
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L678Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
prompt = self.get_prompt(
kwargs.get("hotwords", []), kwargs.get("language", None), kwargs.get("itn", True)
)
data_in = [self.generate_chatml(prompt, data) for data in data_in]
if key is None:
key = []
for _ in data_in:
chars = string.ascii_letters + string.digits
key.append("rand_key_" + "".join(random.choice(chars) for _ in range(13)))
View full source →.inference_llm(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L717Inference llm.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference_llm(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Inference llm.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
data_in, data_lengths, key, tokenizer, frontend, **kwargs
)
ctc_results = []
if self.ctc_decoder is not None:
encoder_out = meta_data["encoder_out"]
encoder_out_lens = meta_data["encoder_out_lens"]
decoder_out, decoder_out_lens = self.ctc_decoder(encoder_out, encoder_out_lens)
ctc_logits = self.ctc.log_softmax(decoder_out)
View full source →.from_pretrained(model, **kwargs) L856From pretrained.
# model — Model name or local path passed to AutoModel.build_model. **kwargs — Additional keyword arguments.
def from_pretrained(model: str = None, **kwargs):
    """Build a model through ``funasr.AutoModel.build_model``.

    ``trust_remote_code=True`` is always passed. Supplying it again in
    ``kwargs`` therefore raises ``TypeError`` (duplicate keyword).

    Args:
        model: Model identifier forwarded as ``model=`` to ``build_model``.
        **kwargs: Forwarded unchanged to ``build_model``.

    Returns:
        Tuple ``(model, kwargs)`` as unpacked from ``build_model``'s result.
    """
    from funasr import AutoModel

    built_model, build_kwargs = AutoModel.build_model(
        model=model, trust_remote_code=True, **kwargs
    )
    return built_model, build_kwargs
FunASRNano.forward(speech, speech_lengths, input_ids, attention_mask, labels_ids, fbank_beg, fbank_mask, **kwargs)Forward pass for training.
speech — Speech audio tensor, shape (batch, time).speech_lengths — Length of each speech sample.input_ids — TODO.attention_mask — TODO.labels_ids — TODO.fbank_beg — TODO.fbank_mask — TODO.**kwargs — Additional keyword arguments. def forward(
self,
speech: torch.Tensor = None,
speech_lengths: torch.Tensor = None,
input_ids: torch.Tensor = None,
attention_mask: torch.Tensor = None,
labels_ids: torch.Tensor = None,
fbank_beg: torch.Tensor = None,
fbank_mask: torch.Tensor = None,
**kwargs,
):
"""Forward pass for training.
Args:
speech: Speech audio tensor, shape (batch, time).
speech_lengths: Length of each speech sample.
input_ids: TODO.
attention_mask: TODO.
labels_ids: TODO.
fbank_beg: TODO.
fbank_mask: TODO.
**kwargs: Additional keyword arguments.
"""
batch_size, token_num = input_ids.shape
stats = {}
input_ids[input_ids < 0] = 0
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
if speech is not None:
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
View full source on GitHub →FunASRNano.forward_export(speech, speech_lengths, **kwargs)Forward export.
# speech — Speech audio tensor, shape (batch, time). speech_lengths — Length of each speech sample. **kwargs — Additional keyword arguments.
def forward_export(self, speech, speech_lengths, **kwargs):
    """Run the encoder followed by the adaptor (the export graph).

    Args:
        speech: Batched speech input. It is documented as shape
            (batch, time); TODO confirm whether the encoder expects raw
            audio or features.
        speech_lengths: Valid length of each sample in ``speech``.
        **kwargs: Accepted for interface compatibility; unused.

    Returns:
        Tuple ``(encoder_out, encoder_out_lens)`` produced by
        ``self.audio_adaptor``.
    """
    encoded, encoded_lens = self.audio_encoder(speech, speech_lengths)
    adapted, adapted_lens = self.audio_adaptor(encoded, encoded_lens)
    return adapted, adapted_lens
FunASRNano.encode(speech, speech_lengths)Encode.
# speech — Speech audio tensor, shape (batch, time). speech_lengths — Length of each speech sample.
def encode(self, speech, speech_lengths):
    """Run only the audio encoder (no adaptor).

    Args:
        speech: Batched speech input. It is documented as shape
            (batch, time); TODO confirm whether the encoder expects raw
            audio or features.
        speech_lengths: Valid length of each sample in ``speech``.

    Returns:
        Tuple ``(encoder_out, encoder_out_lens)`` from ``self.audio_encoder``.
    """
    # Audio encoder only; unlike forward_export, the adaptor is not applied.
    out, out_lens = self.audio_encoder(speech, speech_lengths)
    return out, out_lens
FunASRNano.data_template(data)Data template.
# data — Chat-style list of message dicts with "role"/"content" (and optional "audio") keys.
def data_template(self, data):
    """Group chat messages by role.

    Args:
        data: Iterable of message dicts. Each must have ``"role"`` and
            ``"content"`` keys. ``"user"`` messages may also carry
            ``"audio"``. Messages with any other role are ignored.

    Returns:
        Dict with ``"system"``, ``"user"`` and ``"assistant"`` lists.
        A user message carrying audio becomes a ``[content, audio]`` pair.
        The system list is repeated ``len(user)`` times, so it is empty when
        there are no user messages.
    """
    grouped = {"system": [], "user": [], "assistant": []}
    for message in data:
        role = message["role"]
        content = message["content"]
        if role == "user" and "audio" in message:
            content = [content, message["audio"]]
        if role in ("system", "user", "assistant"):
            grouped[role].append(content)
    # Repeat the whole system list once per user turn.
    grouped["system"] = grouped["system"] * len(grouped["user"])
    return grouped
FunASRNano.data_load_speech(contents, tokenizer, frontend, meta_data, **kwargs)Data load speech.
contents — TODO.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.meta_data — TODO.**kwargs — Additional keyword arguments. def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):
"""Data load speech.
Args:
contents: TODO.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
meta_data: TODO.
**kwargs: Additional keyword arguments.
"""
system = contents["system"]
user = contents["user"]
assistant = contents["assistant"]
pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
do_think = True
sys_prompt = True
if "dataset_conf" in kwargs:
do_think = kwargs["dataset_conf"].get("do_think", True)
sys_prompt = kwargs["dataset_conf"].get("sys_prompt", True)
input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = (
[],
[],
[],
[],
[],
[],
[],
)
input_source_ids = []
View full source on GitHub →FunASRNano.inference_prepare(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Inference prepare.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference_prepare(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Inference prepare.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
meta_data = {}
if len(data_in) > 1:
raise NotImplementedError("batch decoding is not implemented")
contents = self.data_template(data_in[0])
output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
batch = to_device(output, kwargs["device"])
# audio encoder
speech = batch["speech"]
View full source on GitHub →FunASRNano.get_prompt(hotwords, language, itn)Get prompt.
# hotwords — Hotwords added as prompt context. language — Target language name. itn — Whether text normalization is applied.
def get_prompt(self, hotwords: list[str], language: str = None, itn: bool = True):
    """Build the transcription instruction that precedes the audio in the chat prompt.

    Args:
        hotwords: Hotwords added to a context section of the prompt.
            ``None``, an empty sequence or an empty string means no hotword
            section. A non-empty bare ``str`` is treated as a single hotword.
        language: Target language name, producing "语音转写成{language}".
            ``None`` produces the generic "语音转写".
        itn: When ``False``, appends ",不进行文本规整" (no text normalization).

    Returns:
        The prompt text, always ending with ":".
    """
    if isinstance(hotwords, str):
        # ", ".join("abc") would give "a, b, c"; keep a plain string whole.
        hotwords = [hotwords] if hotwords else []
    if hotwords:  # also tolerates None (len(None) used to raise TypeError)
        joined_hotwords = ", ".join(hotwords)
        prompt = "请结合上下文信息,更加准确地完成语音转写任务。如果没有相关信息,我们会留空。\n\n\n**上下文信息:**\n\n\n"
        prompt += f"热词列表:[{joined_hotwords}]\n"
    else:
        prompt = ""
    if language is None:
        prompt += "语音转写"
    else:
        prompt += f"语音转写成{language}"
    if not itn:
        prompt += ",不进行文本规整"
    return prompt + ":"
FunASRNano.generate_chatml(prompt, data)Generate chatml.
# prompt — Instruction text placed before the speech markers. data — Audio given as a str (embedded in the text) or a torch.Tensor (attached under "audio").
def generate_chatml(self, prompt: str, data: Union[str, torch.Tensor]):
    """Wrap one audio input into a three-turn chat (system, user, assistant).

    Args:
        prompt: Instruction text (e.g. from ``get_prompt``) placed before the
            ``<|startofspeech|>...<|endofspeech|>`` markers in the user turn.
        data: Either a ``str``, which is embedded between the markers as
            ``!{data}`` (presumably a path or URL; TODO confirm), or a
            ``torch.Tensor``, which is attached under the user turn's
            ``"audio"`` key while the markers enclose ``!!``.

    Returns:
        List of three message dicts. The assistant turn's content is the
        literal placeholder ``"null"``.

    Raises:
        TypeError: If ``data`` is neither a ``str`` nor a ``torch.Tensor``.
    """
    if isinstance(data, str):
        user_turn = {"role": "user", "content": f"{prompt}<|startofspeech|>!{data}<|endofspeech|>"}
    elif isinstance(data, torch.Tensor):
        user_turn = {
            "role": "user",
            "content": f"{prompt}<|startofspeech|>!!<|endofspeech|>",
            "audio": data,
        }
    else:
        # Previously this fell through and returned None, which only failed
        # later and far from the cause.
        raise TypeError(
            f"generate_chatml expects data to be a str or torch.Tensor, got {type(data).__name__}"
        )
    return [
        {"role": "system", "content": "You are a helpful assistant."},
        user_turn,
        {"role": "assistant", "content": "null"},
    ]
FunASRNano.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
prompt = self.get_prompt(
kwargs.get("hotwords", []), kwargs.get("language", None), kwargs.get("itn", True)
)
data_in = [self.generate_chatml(prompt, data) for data in data_in]
if key is None:
key = []
for _ in data_in:
chars = string.ascii_letters + string.digits
key.append("rand_key_" + "".join(random.choice(chars) for _ in range(13)))
View full source on GitHub →FunASRNano.inference_llm(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Inference llm.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference_llm(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Inference llm.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
data_in, data_lengths, key, tokenizer, frontend, **kwargs
)
ctc_results = []
if self.ctc_decoder is not None:
encoder_out = meta_data["encoder_out"]
encoder_out_lens = meta_data["encoder_out_lens"]
decoder_out, decoder_out_lens = self.ctc_decoder(encoder_out, encoder_out_lens)
ctc_logits = self.ctc.log_softmax(decoder_out)
View full source on GitHub →FunASRNano.from_pretrained(model, **kwargs)From pretrained.
# model — Model name or local path passed to AutoModel.build_model. **kwargs — Additional keyword arguments.
def from_pretrained(model: str = None, **kwargs):
    """Build a model through ``funasr.AutoModel.build_model``.

    ``trust_remote_code=True`` is always passed. Supplying it again in
    ``kwargs`` therefore raises ``TypeError`` (duplicate keyword).

    Args:
        model: Model identifier forwarded as ``model=`` to ``build_model``.
        **kwargs: Forwarded unchanged to ``build_model``.

    Returns:
        Tuple ``(model, kwargs)`` as unpacked from ``build_model``'s result.
    """
    from funasr import AutoModel

    built_model, build_kwargs = AutoModel.build_model(
        model=model, trust_remote_code=True, **kwargs
    )
    return built_model, build_kwargs
No documentation yet.
class GLMASR(nn.Module):
def __init__(self, **kwargs):
"""Initialize GLMASR.
Args:
**kwargs: Additional keyword arguments.
"""
super().__init__()
model_path = kwargs.get("model_path", kwargs.get("model", "zai-org/GLM-ASR-Nano-2512"))
device = kwargs.get("device", "cuda:0")
dtype = kwargs.get("dtype", "bf16")
hub = kwargs.get("hub", "ms")
self._max_new_tokens = kwargs.get("max_new_tokens", 512)
self._dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
self._device = device
self._torch_dtype = self._dtype_map.get(dtype, torch.bfloat16)
self._placeholder = nn.Parameter(torch.empty(0))
model_path = self._resolve_model_path(model_path, hub, kwargs)
self.model_path = model_path
from transformers import AutoModel as HFAutoModel
from transformers import AutoProcessor
self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
self.glm_model = HFAutoModel.from_pretrained(
model_path,
dtype=self._torch_dtype,
View full source on GitHub →.forward(**kwargs) L73Forward pass for training.
# **kwargs — Additional keyword arguments (ignored).
def forward(self, **kwargs):
    """Training forward pass, which GLMASR does not support.

    GLMASR is inference-only, so every call fails immediately regardless of
    the arguments given.

    Args:
        **kwargs: Ignored.

    Raises:
        NotImplementedError: Always.
    """
    reason = "GLMASR only supports inference mode"
    raise NotImplementedError(reason)
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L81Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
meta_data = {}
time1 = time.perf_counter()
prompt = kwargs.get("prompt", "Please transcribe this audio into text")
if isinstance(data_in, (list, tuple)):
audio_list = list(data_in)
elif isinstance(data_in, str):
audio_list = [data_in]
else:
audio_list = [data_in]
View full source →GLMASR.forward(**kwargs)Forward pass for training.
# **kwargs — Additional keyword arguments (ignored).
def forward(self, **kwargs):
    """Training forward pass, which GLMASR does not support.

    GLMASR is inference-only, so every call fails immediately regardless of
    the arguments given.

    Args:
        **kwargs: Ignored.

    Raises:
        NotImplementedError: Always.
    """
    reason = "GLMASR only supports inference mode"
    raise NotImplementedError(reason)
GLMASR.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
meta_data = {}
time1 = time.perf_counter()
prompt = kwargs.get("prompt", "Please transcribe this audio into text")
if isinstance(data_in, (list, tuple)):
audio_list = list(data_in)
elif isinstance(data_in, str):
audio_list = [data_in]
else:
audio_list = [data_in]
View full source on GitHub →LCBNet: Lightweight Convolutional Block Network for ASR. Efficient model design using depthwise separable convolutions
for low-resource deployment scenarios.
Inherits Paraformer pipeline.
class LCBNet(nn.Module):
"""LCBNet: Lightweight Convolutional Block Network for ASR.
Efficient model design using depthwise separable convolutions
for low-resource deployment scenarios.
Inherits Paraformer pipeline.
"""
def __init__(
self,
specaug: str = None,
specaug_conf: dict = None,
normalize: str = None,
normalize_conf: dict = None,
encoder: str = None,
encoder_conf: dict = None,
decoder: str = None,
decoder_conf: dict = None,
text_encoder: str = None,
text_encoder_conf: dict = None,
bias_predictor: str = None,
bias_predictor_conf: dict = None,
fusion_encoder: str = None,
fusion_encoder_conf: dict = None,
ctc: str = None,
ctc_conf: dict = None,
ctc_weight: float = 0.5,
interctc_weight: float = 0.0,
select_num: int = 2,
View full source on GitHub →.forward(speech, speech_lengths, text, text_lengths, **kwargs) L208Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
View full source →.encode(speech, speech_lengths, **kwargs) L302Frontend + Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
ind: int
"""
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.encoder.interctc_use_conditioning:
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ctc=self.ctc)
else:
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
View full source →.init_beam_search(**kwargs) L400Init beam search.
**kwargs — Additional keyword arguments. def init_beam_search(
self,
**kwargs,
):
"""Init beam search.
Args:
**kwargs: Additional keyword arguments.
"""
from funasr.models.transformer.search import BeamSearch
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.models.transformer.scorers.length_bonus import LengthBonus
# 1. Build ASR model
scorers = {}
if self.ctc != None:
ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
scorers.update(ctc=ctc)
token_list = kwargs.get("token_list")
scorers.update(
decoder=self.decoder,
length_bonus=LengthBonus(len(token_list)),
)
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
View full source →.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L450Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
# init beamsearch
if self.beam_search is None:
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
meta_data = {}
View full source →LCBNet.forward(speech, speech_lengths, text, text_lengths, **kwargs)Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
View full source on GitHub →LCBNet.encode(speech, speech_lengths, **kwargs)Frontend + Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
ind: int
"""
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.encoder.interctc_use_conditioning:
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ctc=self.ctc)
else:
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
View full source on GitHub →LCBNet.init_beam_search(**kwargs)Init beam search.
**kwargs — Additional keyword arguments. def init_beam_search(
self,
**kwargs,
):
"""Init beam search.
Args:
**kwargs: Additional keyword arguments.
"""
from funasr.models.transformer.search import BeamSearch
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.models.transformer.scorers.length_bonus import LengthBonus
# 1. Build ASR model
scorers = {}
if self.ctc != None:
ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
scorers.update(ctc=ctc)
token_list = kwargs.get("token_list")
scorers.update(
decoder=self.decoder,
length_bonus=LengthBonus(len(token_list)),
)
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
View full source on GitHub →LCBNet.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
# init beamsearch
if self.beam_search is None:
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
meta_data = {}
View full source on GitHub →LLM-ASR: Large Language Model based Speech Recognition. Combines an audio encoder with an LLM decoder for speech-to-text.
Output — {"key": str, "text": str}class LLMASR(nn.Module):
"""LLM-ASR: Large Language Model based Speech Recognition.
Combines audio encoder with LLM decoder for speech-to-text.
Output: {"key": str, "text": str}
"""
def __init__(
self,
specaug: str = None,
specaug_conf: dict = None,
normalize: str = None,
normalize_conf: dict = None,
audio_encoder: str = None,
audio_encoder_conf: dict = None,
audio_adaptor: str = None,
audio_adaptor_conf: dict = None,
decoder: str = None,
decoder_conf: dict = None,
ctc: str = None,
ctc_conf: dict = None,
ctc_weight: float = 0.5,
llm: str = None,
llm_conf: dict = None,
input_size: int = 80,
vocab_size: int = -1,
ignore_id: int = -1,
blank_id: int = 0,
sos: int = 1,
eos: int = 2,
View full source on GitHub →.forward(speech, speech_lengths, text, text_lengths, input_ids, attention_mask, labels_ids, label_mask, audio_mask, **kwargs) L188Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels_ids: torch.Tensor,
label_mask: torch.Tensor,
audio_mask: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# audio encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
View full source →.encode(speech, speech_lengths, **kwargs) L258Encode.
def encode(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    **kwargs,
):
    """Run the audio encoder on a batch of input features.

    Args:
        speech: 3-D tensor, (batch, length, dim) per the ``forward`` docstring
            ("(Batch, Length, ...)"); ``dim`` is presumably the feature
            dimension — TODO confirm. It is transposed with
            ``permute(0, 2, 1)`` (swapping the last two axes) before being
            passed to ``self.audio_encoder``, so it must be exactly 3-D.
        speech_lengths: Length of each speech sample, shape (batch,).
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        Tuple ``(encoder_out, encoder_out_lens)``. If the audio encoder returns
        a list/tuple, its first two items are used; otherwise its output is
        paired with the unchanged ``speech_lengths``.
    """
    speech = speech.permute(0, 2, 1)
    res = self.audio_encoder(speech)
    if isinstance(res, (list, tuple)):
        encoder_out, encoder_out_lens = res[0], res[1]
    else:
        # The encoder output carries no lengths; fall back to the input lengths.
        encoder_out, encoder_out_lens = res, speech_lengths
    return encoder_out, encoder_out_lens
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L279Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
prompt = kwargs.get("prompt", "Transcribe speech to text.")
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
meta_data = {}
if (
isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
): # fbank
speech, speech_lengths = data_in, data_lengths
View full source →LLMASR.forward(speech, speech_lengths, text, text_lengths, input_ids, attention_mask, labels_ids, label_mask, audio_mask, **kwargs)Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels_ids: torch.Tensor,
label_mask: torch.Tensor,
audio_mask: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# audio encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
View full source on GitHub →LLMASR.encode(speech, speech_lengths, **kwargs)Encode.
def encode(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    **kwargs,
):
    """Run the audio encoder on a batch of input features.

    Args:
        speech: 3-D tensor, (batch, length, dim) per the ``forward`` docstring
            ("(Batch, Length, ...)"); ``dim`` is presumably the feature
            dimension — TODO confirm. It is transposed with
            ``permute(0, 2, 1)`` (swapping the last two axes) before being
            passed to ``self.audio_encoder``, so it must be exactly 3-D.
        speech_lengths: Length of each speech sample, shape (batch,).
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        Tuple ``(encoder_out, encoder_out_lens)``. If the audio encoder returns
        a list/tuple, its first two items are used; otherwise its output is
        paired with the unchanged ``speech_lengths``.
    """
    speech = speech.permute(0, 2, 1)
    res = self.audio_encoder(speech)
    if isinstance(res, (list, tuple)):
        encoder_out, encoder_out_lens = res[0], res[1]
    else:
        # The encoder output carries no lengths; fall back to the input lengths.
        encoder_out, encoder_out_lens = res, speech_lengths
    return encoder_out, encoder_out_lens
LLMASR.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
prompt = kwargs.get("prompt", "Transcribe speech to text.")
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
meta_data = {}
if (
isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
): # fbank
speech, speech_lengths = data_in, data_lengths
View full source on GitHub →No documentation yet.
class LLMASR2(nn.Module):
""" """
def __init__(
self,
specaug: str = None,
specaug_conf: dict = None,
normalize: str = None,
normalize_conf: dict = None,
audio_encoder: str = None,
audio_encoder_conf: dict = None,
audio_adaptor: str = None,
audio_adaptor_conf: dict = None,
decoder: str = None,
decoder_conf: dict = None,
ctc: str = None,
ctc_conf: dict = None,
ctc_weight: float = 0.5,
llm: str = None,
llm_conf: dict = None,
input_size: int = 80,
vocab_size: int = -1,
ignore_id: int = -1,
blank_id: int = 0,
sos: int = 1,
eos: int = 2,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
report_cer: bool = True,
report_wer: bool = True,
View full source on GitHub →.forward(speech, speech_lengths, input_ids, attention_mask, labels_ids, fbank_beg, fbank_mask, **kwargs) L561Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels_ids: torch.Tensor,
fbank_beg: torch.Tensor,
fbank_mask: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
# import pdb;
# pdb.set_trace()
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size, frames, _ = speech.shape
with torch.cuda.amp.autocast(enabled=False):
# audio encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# audio_adaptor
View full source →.encode(speech, speech_lengths) L654Encode.
def encode(self, speech, speech_lengths):
    """Run the audio encoder on a batch of input features.

    Args:
        speech: 3-D tensor; ``forward`` unpacks its shape as
            ``(batch, frames, _)``. It is transposed with ``permute(0, 2, 1)``
            (swapping the last two axes) before being passed to
            ``self.audio_encoder``.
        speech_lengths: Length of each speech sample, passed through to the
            audio encoder unchanged.

    Returns:
        The ``(encoder_out, encoder_out_lens)`` pair returned by the audio
        encoder.
    """
    encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
    return encoder_out, encoder_out_lens
.data_template(data) L666Data template.
def data_template(self, data):
    """Group chat-style messages by role.

    Args:
        data: Iterable of message dicts, each with ``"role"`` and ``"content"``
            keys. Messages whose role is not ``"system"``, ``"user"`` or
            ``"assistant"`` are ignored (but must still have ``"content"``).

    Returns:
        dict with keys ``"system"``, ``"user"`` and ``"assistant"``, each a list
        of message contents in input order. The collected ``"system"`` list is
        repeated ``len(user)`` times (e.g. one system prompt and two user turns
        yield two copies of that prompt), so it lines up turn-by-turn when
        zipped with the user/assistant lists.
    """
    contents = {"system": [], "user": [], "assistant": []}
    for item in data:
        role = item["role"]
        content = item["content"]
        if role in contents:
            contents[role].append(content)
    contents["system"] = contents["system"] * len(contents["user"])
    return contents
.data_load_speech(contents, tokenizer, frontend, meta_data, **kwargs) L693Data load speech.
contents — TODO.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.meta_data — TODO.**kwargs — Additional keyword arguments. def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):
"""Data load speech.
Args:
contents: TODO.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
meta_data: TODO.
**kwargs: Additional keyword arguments.
"""
system = contents["system"]
user = contents["user"]
assistant = contents["assistant"]
pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
input_ids, labels, source_ids, target_ids, fbank, fbank_lens, fbank_mask, fbank_beg = (
[],
[],
[],
[],
[],
[],
[],
[],
)
for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
View full source →.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L820Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
meta_data = {}
prompt = kwargs.get("prompt", None)
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
contents = self.data_template(data_in[0])
output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
batch = to_device(output, kwargs["device"])
View full source →LLMASR2.forward(speech, speech_lengths, input_ids, attention_mask, labels_ids, fbank_beg, fbank_mask, **kwargs)Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels_ids: torch.Tensor,
fbank_beg: torch.Tensor,
fbank_mask: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
# import pdb;
# pdb.set_trace()
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size, frames, _ = speech.shape
with torch.cuda.amp.autocast(enabled=False):
# audio encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# audio_adaptor
View full source on GitHub →LLMASR2.encode(speech, speech_lengths)Encode.
def encode(self, speech, speech_lengths):
    """Run the audio encoder on a batch of input features.

    Args:
        speech: 3-D tensor; ``forward`` unpacks its shape as
            ``(batch, frames, _)``. It is transposed with ``permute(0, 2, 1)``
            (swapping the last two axes) before being passed to
            ``self.audio_encoder``.
        speech_lengths: Length of each speech sample, passed through to the
            audio encoder unchanged.

    Returns:
        The ``(encoder_out, encoder_out_lens)`` pair returned by the audio
        encoder.
    """
    encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
    return encoder_out, encoder_out_lens
LLMASR2.data_template(data)Data template.
def data_template(self, data):
    """Group chat-style messages by role.

    Args:
        data: Iterable of message dicts, each with ``"role"`` and ``"content"``
            keys. Messages whose role is not ``"system"``, ``"user"`` or
            ``"assistant"`` are ignored (but must still have ``"content"``).

    Returns:
        dict with keys ``"system"``, ``"user"`` and ``"assistant"``, each a list
        of message contents in input order. The collected ``"system"`` list is
        repeated ``len(user)`` times (e.g. one system prompt and two user turns
        yield two copies of that prompt), so it lines up turn-by-turn when
        zipped with the user/assistant lists.
    """
    contents = {"system": [], "user": [], "assistant": []}
    for item in data:
        role = item["role"]
        content = item["content"]
        if role in contents:
            contents[role].append(content)
    contents["system"] = contents["system"] * len(contents["user"])
    return contents
LLMASR2.data_load_speech(contents, tokenizer, frontend, meta_data, **kwargs)Data load speech.
contents — TODO.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.meta_data — TODO.**kwargs — Additional keyword arguments. def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):
"""Data load speech.
Args:
contents: TODO.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
meta_data: TODO.
**kwargs: Additional keyword arguments.
"""
system = contents["system"]
user = contents["user"]
assistant = contents["assistant"]
pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
input_ids, labels, source_ids, target_ids, fbank, fbank_lens, fbank_mask, fbank_beg = (
[],
[],
[],
[],
[],
[],
[],
[],
)
for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
View full source on GitHub →LLMASR2.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
meta_data = {}
prompt = kwargs.get("prompt", None)
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
contents = self.data_template(data_in[0])
output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
batch = to_device(output, kwargs["device"])
View full source on GitHub →No documentation yet.
class LLMASR3(LLMASR2):
    """LLMASR2 variant that feeds ``speech`` to the audio encoder untransposed.

    The only override is ``encode``: where ``LLMASR2.encode`` applies
    ``speech.permute(0, 2, 1)`` before calling ``self.audio_encoder``, this
    class passes ``speech`` through as-is. Everything else is inherited.
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        """Initialize LLMASR3.

        All arguments are forwarded unchanged to ``LLMASR2.__init__``.

        Args:
            *args: Positional arguments for ``LLMASR2.__init__``.
            **kwargs: Keyword arguments for ``LLMASR2.__init__``.
        """
        super().__init__(*args, **kwargs)

    def encode(self, speech, speech_lengths):
        """Run the audio encoder on ``speech`` without transposing it.

        Args:
            speech: Batch of input features. The inherited ``forward`` unpacks
                its shape as ``(batch, frames, _)``, so it is 3-D there; it is
                passed to ``self.audio_encoder`` unchanged.
            speech_lengths: Length of each speech sample, passed through to the
                audio encoder unchanged.

        Returns:
            The ``(encoder_out, encoder_out_lens)`` pair returned by the audio
            encoder.
        """
        encoder_out, encoder_out_lens = self.audio_encoder(speech, speech_lengths)
        return encoder_out, encoder_out_lens
.encode(speech, speech_lengths) L964Encode.
def encode(self, speech, speech_lengths):
    """Run the audio encoder on ``speech`` without transposing it.

    Args:
        speech: Batch of input features. The inherited ``LLMASR2.forward``
            unpacks its shape as ``(batch, frames, _)``, so it is 3-D there;
            unlike ``LLMASR2.encode``, it is passed to ``self.audio_encoder``
            unchanged (no ``permute``).
        speech_lengths: Length of each speech sample, passed through to the
            audio encoder unchanged.

    Returns:
        The ``(encoder_out, encoder_out_lens)`` pair returned by the audio
        encoder.
    """
    encoder_out, encoder_out_lens = self.audio_encoder(speech, speech_lengths)
    return encoder_out, encoder_out_lens
LLMASR3.encode(speech, speech_lengths)Encode.
def encode(self, speech, speech_lengths):
    """Run the audio encoder on ``speech`` without transposing it.

    Args:
        speech: Batch of input features. The inherited ``LLMASR2.forward``
            unpacks its shape as ``(batch, frames, _)``, so it is 3-D there;
            unlike ``LLMASR2.encode``, it is passed to ``self.audio_encoder``
            unchanged (no ``permute``).
        speech_lengths: Length of each speech sample, passed through to the
            audio encoder unchanged.

    Returns:
        The ``(encoder_out, encoder_out_lens)`` pair returned by the audio
        encoder.
    """
    encoder_out, encoder_out_lens = self.audio_encoder(speech, speech_lengths)
    return encoder_out, encoder_out_lens
No documentation yet.
class LLMASR4(nn.Module):
""" """
def __init__(
self,
specaug: str = None,
specaug_conf: dict = None,
normalize: str = None,
normalize_conf: dict = None,
audio_encoder: str = None,
audio_encoder_conf: dict = None,
audio_adaptor: str = None,
audio_adaptor_conf: dict = None,
decoder: str = None,
decoder_conf: dict = None,
ctc: str = None,
ctc_conf: dict = None,
ctc_weight: float = 0.5,
llm: str = None,
llm_conf: dict = None,
input_size: int = 80,
vocab_size: int = -1,
ignore_id: int = -1,
blank_id: int = 0,
sos: int = 1,
eos: int = 2,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
report_cer: bool = True,
report_wer: bool = True,
View full source on GitHub →.forward(speech, speech_lengths, input_ids, attention_mask, labels_ids, fbank_beg, fbank_mask, **kwargs) L1135Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels_ids: torch.Tensor,
fbank_beg: torch.Tensor,
fbank_mask: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
# import pdb
#
# pdb.set_trace()
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size_speech, frames, _ = speech.shape
batch_size, token_num = input_ids.shape
with torch.cuda.amp.autocast(enabled=False):
# audio encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
View full source →.encode(speech, speech_lengths) L1246Encode.
def encode(self, speech, speech_lengths):
    """Run the audio encoder on a batch of input features.

    Args:
        speech: 3-D tensor; ``forward`` unpacks its shape as
            ``(batch_size_speech, frames, _)``. It is transposed with
            ``permute(0, 2, 1)`` (swapping the last two axes) before being
            passed to ``self.audio_encoder``.
        speech_lengths: Length of each speech sample, passed through to the
            audio encoder unchanged.

    Returns:
        The ``(encoder_out, encoder_out_lens)`` pair returned by the audio
        encoder.
    """
    encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
    return encoder_out, encoder_out_lens
.data_template(data) L1258Data template.
def data_template(self, data):
    """Group chat-style messages by role.

    Args:
        data: Iterable of message dicts, each with ``"role"`` and ``"content"``
            keys. Messages whose role is not ``"system"``, ``"user"`` or
            ``"assistant"`` are ignored (but must still have ``"content"``).

    Returns:
        dict with keys ``"system"``, ``"user"`` and ``"assistant"``, each a list
        of message contents in input order. The collected ``"system"`` list is
        repeated ``len(user)`` times (e.g. one system prompt and two user turns
        yield two copies of that prompt), so it lines up turn-by-turn when
        zipped with the user/assistant lists.
    """
    contents = {"system": [], "user": [], "assistant": []}
    for item in data:
        role = item["role"]
        content = item["content"]
        if role in contents:
            contents[role].append(content)
    contents["system"] = contents["system"] * len(contents["user"])
    return contents
.data_load_speech(contents, tokenizer, frontend, meta_data, **kwargs) L1285Data load speech.
contents — TODO.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.meta_data — TODO.**kwargs — Additional keyword arguments. def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):
"""Data load speech.
Args:
contents: TODO.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
meta_data: TODO.
**kwargs: Additional keyword arguments.
"""
system = contents["system"]
user = contents["user"]
assistant = contents["assistant"]
pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = (
[],
[],
[],
[],
[],
[],
[],
)
input_source_ids = []
for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
if i >= kwargs.get("multiturn_num_max", 5):
break
if len(input_ids) > kwargs.get("max_token_length", 1500):
View full source →.inference_prepare(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L1433Inference prepare.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference_prepare(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Inference prepare.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
meta_data = {}
prompt = kwargs.get("prompt", None)
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
contents = self.data_template(data_in[0])
output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
batch = to_device(output, kwargs["device"])
View full source →.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L1525Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
data_in, data_lengths, key, tokenizer, frontend, **kwargs
)
llm_dtype = kwargs.get("llm_dtype", "fp32")
if llm_dtype == "fp32":
llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
with torch.cuda.amp.autocast(
View full source →LLMASR4.forward(speech, speech_lengths, input_ids, attention_mask, labels_ids, fbank_beg, fbank_mask, **kwargs)Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels_ids: torch.Tensor,
fbank_beg: torch.Tensor,
fbank_mask: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
# import pdb
#
# pdb.set_trace()
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size_speech, frames, _ = speech.shape
batch_size, token_num = input_ids.shape
with torch.cuda.amp.autocast(enabled=False):
# audio encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
View full source on GitHub →LLMASR4.encode(speech, speech_lengths)Encode.
def encode(self, speech, speech_lengths):
    """Run the audio encoder on a batch of input features.

    Args:
        speech: 3-D tensor; ``forward`` unpacks its shape as
            ``(batch_size_speech, frames, _)``. It is transposed with
            ``permute(0, 2, 1)`` (swapping the last two axes) before being
            passed to ``self.audio_encoder``.
        speech_lengths: Length of each speech sample, passed through to the
            audio encoder unchanged.

    Returns:
        The ``(encoder_out, encoder_out_lens)`` pair returned by the audio
        encoder.
    """
    encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
    return encoder_out, encoder_out_lens
LLMASR4.data_template(data)Data template.
def data_template(self, data):
    """Group chat-style messages by role.

    Args:
        data: Iterable of message dicts, each with ``"role"`` and ``"content"``
            keys. Messages whose role is not ``"system"``, ``"user"`` or
            ``"assistant"`` are ignored (but must still have ``"content"``).

    Returns:
        dict with keys ``"system"``, ``"user"`` and ``"assistant"``, each a list
        of message contents in input order. The collected ``"system"`` list is
        repeated ``len(user)`` times (e.g. one system prompt and two user turns
        yield two copies of that prompt), so it lines up turn-by-turn when
        zipped with the user/assistant lists.
    """
    contents = {"system": [], "user": [], "assistant": []}
    for item in data:
        role = item["role"]
        content = item["content"]
        if role in contents:
            contents[role].append(content)
    contents["system"] = contents["system"] * len(contents["user"])
    return contents
LLMASR4.data_load_speech(contents, tokenizer, frontend, meta_data, **kwargs)Data load speech.
contents — TODO.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.meta_data — TODO.**kwargs — Additional keyword arguments. def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):
"""Data load speech.
Args:
contents: TODO.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
meta_data: TODO.
**kwargs: Additional keyword arguments.
"""
system = contents["system"]
user = contents["user"]
assistant = contents["assistant"]
pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = (
[],
[],
[],
[],
[],
[],
[],
)
input_source_ids = []
for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
if i >= kwargs.get("multiturn_num_max", 5):
break
if len(input_ids) > kwargs.get("max_token_length", 1500):
View full source on GitHub →LLMASR4.inference_prepare(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Inference prepare.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference_prepare(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Inference prepare.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
meta_data = {}
prompt = kwargs.get("prompt", None)
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
contents = self.data_template(data_in[0])
output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
batch = to_device(output, kwargs["device"])
View full source on GitHub →LLMASR4.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
data_in, data_lengths, key, tokenizer, frontend, **kwargs
)
llm_dtype = kwargs.get("llm_dtype", "fp32")
if llm_dtype == "fp32":
llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
with torch.cuda.amp.autocast(
View full source on GitHub →No documentation yet.
class LLMASRNAR(nn.Module):
""" """
def __init__(
self,
specaug: str = None,
specaug_conf: dict = None,
normalize: str = None,
normalize_conf: dict = None,
encoder: str = None,
encoder_conf: dict = None,
decoder: str = None,
decoder_conf: dict = None,
ctc: str = None,
ctc_conf: dict = None,
ctc_weight: float = 0.5,
llm: str = None,
llm_conf: dict = None,
adaptor: str = None,
adaptor_conf: dict = None,
input_size: int = 80,
vocab_size: int = -1,
ignore_id: int = -1,
blank_id: int = 0,
sos: int = 1,
eos: int = 2,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
report_cer: bool = True,
report_wer: bool = True,
View full source on GitHub →.forward(speech, speech_lengths, text, text_lengths, input_ids, attention_mask, labels_ids, label_mask, audio_mask, **kwargs) L182Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels_ids: torch.Tensor,
label_mask: torch.Tensor,
audio_mask: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# audio encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, audio_mask=audio_mask)
View full source →.encode(speech, speech_lengths, **kwargs) L253Encode.
speech — Speech audio tensor, shape (batch, time).speech_lengths — Length of each speech sample.**kwargs — Additional keyword arguments. def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
):
"""Encode.
Args:
speech: Speech audio tensor, shape (batch, time).
speech_lengths: Length of each speech sample.
**kwargs: Additional keyword arguments.
"""
audio_mask = kwargs.get("audio_mask", None)
audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
text_token_int = kwargs.get("text_token_int", None)
if audio_token_lengths is None:
audio_token_lengths = torch.tensor([len(text_token_int)], dtype=torch.int64)
batch = {"speech": speech, "speech_lengths": speech_lengths}
enc, enc_lens = self.audio_encoder.encode(**batch)
with autocast(False):
enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(
enc,
mask=enc_mask,
target_label_length=audio_token_lengths,
)
View full source →.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L285Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
prompt = kwargs.get("prompt", "Transcribe speech to text.")
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
meta_data = {}
if (
isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
): # fbank
speech, speech_lengths = data_in, data_lengths
View full source →LLMASRNAR.forward(speech, speech_lengths, text, text_lengths, input_ids, attention_mask, labels_ids, label_mask, audio_mask, **kwargs)Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels_ids: torch.Tensor,
label_mask: torch.Tensor,
audio_mask: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# audio encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, audio_mask=audio_mask)
View full source on GitHub →LLMASRNAR.encode(speech, speech_lengths, **kwargs)Encode.
speech — Speech audio tensor, shape (batch, time).speech_lengths — Length of each speech sample.**kwargs — Additional keyword arguments. def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
):
"""Encode.
Args:
speech: Speech audio tensor, shape (batch, time).
speech_lengths: Length of each speech sample.
**kwargs: Additional keyword arguments.
"""
audio_mask = kwargs.get("audio_mask", None)
audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
text_token_int = kwargs.get("text_token_int", None)
if audio_token_lengths is None:
audio_token_lengths = torch.tensor([len(text_token_int)], dtype=torch.int64)
batch = {"speech": speech, "speech_lengths": speech_lengths}
enc, enc_lens = self.audio_encoder.encode(**batch)
with autocast(False):
enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(
enc,
mask=enc_mask,
target_label_length=audio_token_lengths,
)
View full source on GitHub →LLMASRNAR.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
prompt = kwargs.get("prompt", "Transcribe speech to text.")
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
meta_data = {}
if (
isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
): # fbank
speech, speech_lengths = data_in, data_lengths
View full source on GitHub →No documentation yet.
class LLMASRNARPrompt(nn.Module):
""" """
def __init__(
self,
specaug: str = None,
specaug_conf: dict = None,
normalize: str = None,
normalize_conf: dict = None,
encoder: str = None,
encoder_conf: dict = None,
decoder: str = None,
decoder_conf: dict = None,
ctc: str = None,
ctc_conf: dict = None,
ctc_weight: float = 0.0,
llm: str = None,
llm_conf: dict = None,
adaptor: str = None,
adaptor_conf: dict = None,
input_size: int = 80,
vocab_size: int = -1,
ignore_id: int = -1,
blank_id: int = 0,
sos: int = 1,
eos: int = 2,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
predictor_weight: int = 1.0,
report_cer: bool = True,
View full source on GitHub →.forward(speech, speech_lengths, text, text_lengths, input_ids, attention_mask, labels_ids, label_mask, audio_mask, **kwargs) L588Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels_ids: torch.Tensor,
label_mask: torch.Tensor,
audio_mask: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
# import pdb;
# pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
stats = {}
View full source →.encode(speech, speech_lengths, **kwargs) L691Encode.
speech — Speech audio tensor, shape (batch, time).speech_lengths — Length of each speech sample.**kwargs — Additional keyword arguments. def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
):
"""Encode.
Args:
speech: Speech audio tensor, shape (batch, time).
speech_lengths: Length of each speech sample.
**kwargs: Additional keyword arguments.
"""
audio_mask = kwargs.get("audio_mask", None)
audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
text_token_int = kwargs.get("text_token_int", None)
if audio_token_lengths is None and text_token_int is not None:
audio_token_lengths = torch.tensor([len(text_token_int)], dtype=torch.int64)
batch = {"speech": speech, "speech_lengths": speech_lengths}
enc, enc_lens = self.audio_encoder.encode(**batch)
with autocast(False):
enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(
enc,
mask=enc_mask,
target_label_length=audio_token_lengths,
)
loss_pre = 0.0
View full source →.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L753Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
prompt = kwargs.get("prompt", "Transcribe speech to text.")
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
meta_data = {}
if (
isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
): # fbank
speech, speech_lengths = data_in, data_lengths
View full source →LLMASRNARPrompt.forward(speech, speech_lengths, text, text_lengths, input_ids, attention_mask, labels_ids, label_mask, audio_mask, **kwargs)Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels_ids: torch.Tensor,
label_mask: torch.Tensor,
audio_mask: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
# import pdb;
# pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
stats = {}
View full source on GitHub →LLMASRNARPrompt.encode(speech, speech_lengths, **kwargs)Encode.
speech — Speech audio tensor, shape (batch, time).speech_lengths — Length of each speech sample.**kwargs — Additional keyword arguments. def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
):
"""Encode.
Args:
speech: Speech audio tensor, shape (batch, time).
speech_lengths: Length of each speech sample.
**kwargs: Additional keyword arguments.
"""
audio_mask = kwargs.get("audio_mask", None)
audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
text_token_int = kwargs.get("text_token_int", None)
if audio_token_lengths is None and text_token_int is not None:
audio_token_lengths = torch.tensor([len(text_token_int)], dtype=torch.int64)
batch = {"speech": speech, "speech_lengths": speech_lengths}
enc, enc_lens = self.audio_encoder.encode(**batch)
with autocast(False):
enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(
enc,
mask=enc_mask,
target_label_length=audio_token_lengths,
)
loss_pre = 0.0
View full source on GitHub →LLMASRNARPrompt.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
prompt = kwargs.get("prompt", "Transcribe speech to text.")
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
meta_data = {}
if (
isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
): # fbank
speech, speech_lengths = data_in, data_lengths
View full source on GitHub →MonotonicAligner — Forced alignment model for timestamp prediction. Given audio and text, it computes character-level time alignments
using a monotonic attention constraint.
Usage — model.generate(input=(audio_path, text_path), data_type=("sound", "text"))Output — {"key": str, "timestamp": [[start_ms, end_ms], ...], "text": str}class MonotonicAligner(torch.nn.Module):
"""MonotonicAligner: Forced alignment model for timestamp prediction.
Given audio and text, computes character-level time alignments
using monotonic attention constraint.
Usage: model.generate(input=(audio_path, text_path), data_type=("sound", "text"))
Output: {"key": str, "timestamp": [[start_ms, end_ms], ...], "text": str}
"""
def __init__(
self,
input_size: int = 80,
specaug: Optional[str] = None,
specaug_conf: Optional[Dict] = None,
normalize: str = None,
normalize_conf: Optional[Dict] = None,
encoder: str = None,
encoder_conf: Optional[Dict] = None,
predictor: str = None,
predictor_conf: Optional[Dict] = None,
predictor_bias: int = 0,
length_normalized_loss: bool = False,
**kwargs,
):
"""Initialize MonotonicAligner.
Args:
input_size: Size/dimension parameter.
specaug: TODO.
View full source on GitHub →.forward(speech, speech_lengths, text, text_lengths) L86Frontend + Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
batch_size = speech.shape[0]
# for data-parallel
text = text[:, : text_lengths.max()]
speech = speech[:, : speech_lengths.max()]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
encoder_out_mask = (
~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
).to(encoder_out.device)
View full source →.calc_predictor_timestamp(encoder_out, encoder_out_lens, token_num) L137Calc predictor timestamp.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.token_num — TODO.    def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
        """Calc predictor timestamp.

        Builds a frame mask from ``encoder_out_lens`` and delegates to
        ``self.predictor.get_upsample_timestamp``.

        Args:
            encoder_out: Encoder output tensor. ``encoder_out.size(1)`` is used as
                the mask length (presumably (batch, frames, dim) — TODO confirm).
            encoder_out_lens: Encoder output lengths (presumably one per batch item).
            token_num: Passed through unchanged to
                ``self.predictor.get_upsample_timestamp``. Presumably the number of
                tokens per utterance — TODO confirm against the predictor.

        Returns:
            The 4-tuple ``(ds_alphas, ds_cif_peak, us_alphas, us_peaks)`` exactly as
            returned by ``self.predictor.get_upsample_timestamp``.
        """
        # Inverted padding mask with a singleton axis inserted at dim 1
        # (assumes make_pad_mask marks padded positions as True — TODO confirm).
        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(
            encoder_out, encoder_out_mask, token_num
        )
        return ds_alphas, ds_cif_peak, us_alphas, us_peaks
.encode(speech, speech_lengths, **kwargs) L153Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encoder. Note that this method is used by asr_inference.py

        Optionally augments the features with ``self.specaug`` (only when it is
        set and the module is in training mode). Optionally normalizes them with
        ``self.normalize`` (when set). Then runs ``self.encoder``.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            **kwargs: Accepted for call compatibility; not used by this method.

        Returns:
            Tuple ``(encoder_out, encoder_out_lens)``. If the encoder's first output
            is itself a tuple, only that tuple's first element is returned.
        """
        # Augmentation and normalization run with autocast disabled; the encoder
        # call below is outside this scope.
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)
        # Forward encoder (its third output is discarded)
        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        # If the first output is a tuple, keep only its first element.
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        return encoder_out, encoder_out_lens
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L182Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
meta_data = {}
# extract fbank feats
time1 = time.perf_counter()
audio_list, text_token_int_list = load_audio_text_image_video(
data_in,
fs=frontend.fs,
audio_fs=kwargs.get("fs", 16000),
data_type=kwargs.get("data_type", "sound"),
tokenizer=tokenizer,
)
time2 = time.perf_counter()
View full source →MonotonicAligner.forward(speech, speech_lengths, text, text_lengths)Frontend + Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
batch_size = speech.shape[0]
# for data-parallel
text = text[:, : text_lengths.max()]
speech = speech[:, : speech_lengths.max()]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
encoder_out_mask = (
~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
).to(encoder_out.device)
View full source on GitHub →MonotonicAligner.calc_predictor_timestamp(encoder_out, encoder_out_lens, token_num)Calc predictor timestamp.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.token_num — TODO.    def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
        """Calc predictor timestamp.

        Builds a frame mask from ``encoder_out_lens`` and delegates to
        ``self.predictor.get_upsample_timestamp``.

        Args:
            encoder_out: Encoder output tensor. ``encoder_out.size(1)`` is used as
                the mask length (presumably (batch, frames, dim) — TODO confirm).
            encoder_out_lens: Encoder output lengths (presumably one per batch item).
            token_num: Passed through unchanged to
                ``self.predictor.get_upsample_timestamp``. Presumably the number of
                tokens per utterance — TODO confirm against the predictor.

        Returns:
            The 4-tuple ``(ds_alphas, ds_cif_peak, us_alphas, us_peaks)`` exactly as
            returned by ``self.predictor.get_upsample_timestamp``.
        """
        # Inverted padding mask with a singleton axis inserted at dim 1
        # (assumes make_pad_mask marks padded positions as True — TODO confirm).
        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(
            encoder_out, encoder_out_mask, token_num
        )
        return ds_alphas, ds_cif_peak, us_alphas, us_peaks
MonotonicAligner.encode(speech, speech_lengths, **kwargs)Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encoder. Note that this method is used by asr_inference.py

        Optionally augments the features with ``self.specaug`` (only when it is
        set and the module is in training mode). Optionally normalizes them with
        ``self.normalize`` (when set). Then runs ``self.encoder``.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            **kwargs: Accepted for call compatibility; not used by this method.

        Returns:
            Tuple ``(encoder_out, encoder_out_lens)``. If the encoder's first output
            is itself a tuple, only that tuple's first element is returned.
        """
        # Augmentation and normalization run with autocast disabled; the encoder
        # call below is outside this scope.
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)
        # Forward encoder (its third output is discarded)
        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        # If the first output is a tuple, keep only its first element.
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        return encoder_out, encoder_out_lens
MonotonicAligner.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
meta_data = {}
# extract fbank feats
time1 = time.perf_counter()
audio_list, text_token_int_list = load_audio_text_image_video(
data_in,
fs=frontend.fs,
audio_fs=kwargs.get("fs", 16000),
data_type=kwargs.get("data_type", "sound"),
tokenizer=tokenizer,
)
time2 = time.perf_counter()
View full source on GitHub →Paraformer — Non-autoregressive End-to-End ASR Model. High-accuracy speech recognition for Chinese/English. The production workhorse. Output — {"key": "...", "text": "recognized text", "timestamp": [[start_ms, end_ms], ...]}. Note — Requires punc_model="ct-punc" for punctuation (unlike Fun-ASR-Nano/SenseVoice, which output punctuation natively).
Author — Speech Lab of DAMO Academy, Alibaba Group. Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition, https://arxiv.org/abs/2206.08317
class Paraformer(torch.nn.Module):
"""Paraformer: Non-autoregressive End-to-End ASR Model.
High-accuracy speech recognition for Chinese/English. The production workhorse.
Features:
- Non-autoregressive (parallel decoding, fast inference)
- Character-level timestamps via CIF predictor
- Streaming and offline modes
- Hotword customization
- Speaker diarization (with spk_model)
- ONNX export support
Output: {"key": "...", "text": "recognized text", "timestamp": [[start_ms, end_ms], ...]}
Note: Requires punc_model="ct-punc" for punctuation (unlike Fun-ASR-Nano/SenseVoice
which output punctuation natively).
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
def __init__(
self,
specaug: Optional[str] = None,
specaug_conf: Optional[Dict] = None,
normalize: str = None,
normalize_conf: Optional[Dict] = None,
encoder: str = None,
View full source on GitHub →.forward(speech, speech_lengths, text, text_lengths, **kwargs) L215Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
loss_ctc, cer_ctc = None, None
loss_pre = None
stats = dict()
# decoder: CTC branch
View full source →.encode(speech, speech_lengths, **kwargs) L286Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encoder. Note that this method is used by asr_inference.py

        Optionally augments the features with ``self.specaug`` (only when it is
        set and the module is in training mode). Optionally normalizes them with
        ``self.normalize`` (when set). Then runs ``self.encoder``.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            **kwargs: Accepted for call compatibility; not used by this method.

        Returns:
            Tuple ``(encoder_out, encoder_out_lens)``. If the encoder's first output
            is itself a tuple, only that tuple's first element is returned.
        """
        # Augmentation and normalization run with autocast disabled; the encoder
        # call below is outside this scope.
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)
        # Forward encoder (its third output is discarded)
        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        # If the first output is a tuple, keep only its first element.
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        return encoder_out, encoder_out_lens
.calc_predictor(encoder_out, encoder_out_lens) L315Calc predictor.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.    def calc_predictor(self, encoder_out, encoder_out_lens):
        """Calc predictor.

        Builds a frame mask from ``encoder_out_lens``. Calls ``self.predictor``
        with ``None`` as its second positional argument and
        ``ignore_id=self.ignore_id``.

        Args:
            encoder_out: Encoder output tensor. ``encoder_out.size(1)`` is used as
                the mask length.
            encoder_out_lens: Encoder output lengths.

        Returns:
            The 4-tuple ``(pre_acoustic_embeds, pre_token_length, alphas,
            pre_peak_index)`` as returned by ``self.predictor``.
        """
        # Inverted padding mask with a singleton axis inserted at dim 1
        # (assumes make_pad_mask marks padded positions as True — TODO confirm).
        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        # NOTE(review): the second argument is None here. Presumably this means no
        # target sequence is supplied, so the predictor estimates token counts itself —
        # TODO confirm against the predictor's signature.
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(
            encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id
        )
        return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
.cal_decoder_with_predictor(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens) L331Cal decoder with predictor.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.sematic_embeds — TODO.ys_pad_lens — Lengths of ys_pad.    def cal_decoder_with_predictor(
        self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
    ):
        """Cal decoder with predictor.

        Runs ``self.decoder`` and converts its first output to log-probabilities.

        Args:
            encoder_out: Encoder output tensor.
            encoder_out_lens: Encoder output lengths.
            sematic_embeds: Embeddings passed as the decoder's third positional
                argument. Presumably the predictor's acoustic embeddings — TODO
                confirm. NOTE(review): "sematic" looks like a typo for "semantic";
                renaming it would break keyword-argument callers.
            ys_pad_lens: Lengths of ys_pad. Passed to the decoder and returned
                unchanged.

        Returns:
            Tuple ``(decoder_out, ys_pad_lens)``. ``decoder_out`` is the decoder's
            first output after ``torch.log_softmax`` over the last dimension.
        """
        decoder_outs = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)
        # Only the decoder's first output is used.
        decoder_out = decoder_outs[0]
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens
.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds) L408Sampler.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.ys_pad — TODO.ys_pad_lens — Lengths of ys_pad.pre_acoustic_embeds — TODO. def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
"""Sampler.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
ys_pad: TODO.
ys_pad_lens: Lengths of ys_pad.
pre_acoustic_embeds: TODO.
"""
tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
ys_pad.device
)
ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
if self.share_embedding:
ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
else:
ys_pad_embed = self.decoder.embed(ys_pad_masked)
with torch.no_grad():
decoder_outs = self.decoder(
encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
)
decoder_out, _ = decoder_outs[0], decoder_outs[1]
pred_tokens = decoder_out.argmax(-1)
nonpad_positions = ys_pad.ne(self.ignore_id)
seq_lens = (nonpad_positions).sum(1)
same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
input_mask = torch.ones_like(nonpad_positions)
bsz, seq_len = ys_pad.size()
View full source →.init_beam_search(**kwargs) L482Init beam search.
**kwargs — Additional keyword arguments. def init_beam_search(
self,
**kwargs,
):
"""Init beam search.
Args:
**kwargs: Additional keyword arguments.
"""
from funasr.models.paraformer.search import BeamSearchPara
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.models.transformer.scorers.length_bonus import LengthBonus
# 1. Build ASR model
scorers = {}
if self.ctc != None:
ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
scorers.update(ctc=ctc)
token_list = kwargs.get("token_list")
scorers.update(
length_bonus=LengthBonus(len(token_list)),
)
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
weights = dict(
View full source →.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L534Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
# init beamsearch
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = (
kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
)
pred_timestamp = kwargs.get("pred_timestamp", False)
if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
View full source →.export(**kwargs) L699Export.
**kwargs — Additional keyword arguments.    def export(self, **kwargs):
        """Export.

        Rebuilds this model for export via ``export_rebuild_model``.

        Args:
            **kwargs: Forwarded to ``export_rebuild_model``. ``max_seq_len``
                defaults to 512 when not supplied.

        Returns:
            Whatever ``export_rebuild_model(model=self, **kwargs)`` returns.
        """
        # Function-local import (presumably so export code is only loaded on demand).
        from .export_meta import export_rebuild_model
        # Default sequence length for export when the caller does not set one.
        if "max_seq_len" not in kwargs:
            kwargs["max_seq_len"] = 512
        models = export_rebuild_model(model=self, **kwargs)
        return models
Paraformer.forward(speech, speech_lengths, text, text_lengths, **kwargs)Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
loss_ctc, cer_ctc = None, None
loss_pre = None
stats = dict()
# decoder: CTC branch
View full source on GitHub →Paraformer.encode(speech, speech_lengths, **kwargs)Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encoder. Note that this method is used by asr_inference.py
Apply the optional feature augmentation (training only) and feature
normalization, then run self.encoder.
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
**kwargs: Accepted for interface compatibility but never read by this
method; in particular an ``ind`` (chunk index) keyword is ignored here.
Returns:
Tuple (encoder_out, encoder_out_lens) from self.encoder.
"""
with autocast(False):
# Data augmentation; only applied when the module is in training mode
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder; the encoder's third return value is discarded
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
# If the encoder's first output is itself a tuple, keep only its first element
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
return encoder_out, encoder_out_lens
Paraformer.calc_predictor(encoder_out, encoder_out_lens)Calc predictor.
# encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.
def calc_predictor(self, encoder_out, encoder_out_lens):
    """Run ``self.predictor`` over a padded batch of encoder outputs.

    Args:
        encoder_out: Encoder output tensor; ``encoder_out.size(1)`` is taken as
            the padded time length.
        encoder_out_lens: Valid length of each sequence in ``encoder_out``.

    Returns:
        Tuple ``(pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index)``
        as returned by ``self.predictor``.
    """
    padded_len = encoder_out.size(1)
    pad_mask = make_pad_mask(encoder_out_lens, maxlen=padded_len)
    # Insert a singleton axis at dim 1 and invert the mask, so the predictor gets
    # True where make_pad_mask is False (presumably the valid frames).
    valid_mask = (~pad_mask[:, None, :]).to(encoder_out.device)
    acoustic_embeds, token_length, alphas, peak_index = self.predictor(
        encoder_out, None, valid_mask, ignore_id=self.ignore_id
    )
    return acoustic_embeds, token_length, alphas, peak_index
Paraformer.cal_decoder_with_predictor(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)Cal decoder with predictor.
# encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.sematic_embeds — TODO.ys_pad_lens — Lengths of ys_pad.
def cal_decoder_with_predictor(
    self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
):
    """Decode the given embeddings and return log-probabilities.

    Args:
        encoder_out: Encoder output tensor.
        encoder_out_lens: Encoder output lengths.
        sematic_embeds: Embeddings fed to the decoder (presumably the
            predictor's acoustic embeddings — confirm at call sites).
        ys_pad_lens: Target lengths; returned unchanged.

    Returns:
        ``(log_probs, ys_pad_lens)`` where ``log_probs`` is ``log_softmax``
        over the last axis of the decoder's first output.
    """
    logits = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)[0]
    return torch.log_softmax(logits, dim=-1), ys_pad_lens
Paraformer.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds)Sampler.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.ys_pad — TODO.ys_pad_lens — Lengths of ys_pad.pre_acoustic_embeds — TODO. def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
"""Sampler.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
ys_pad: TODO.
ys_pad_lens: Lengths of ys_pad.
pre_acoustic_embeds: TODO.
"""
tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
ys_pad.device
)
ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
if self.share_embedding:
ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
else:
ys_pad_embed = self.decoder.embed(ys_pad_masked)
with torch.no_grad():
decoder_outs = self.decoder(
encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
)
decoder_out, _ = decoder_outs[0], decoder_outs[1]
pred_tokens = decoder_out.argmax(-1)
nonpad_positions = ys_pad.ne(self.ignore_id)
seq_lens = (nonpad_positions).sum(1)
same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
input_mask = torch.ones_like(nonpad_positions)
bsz, seq_len = ys_pad.size()
View full source on GitHub →Paraformer.init_beam_search(**kwargs)Init beam search.
**kwargs — Additional keyword arguments. def init_beam_search(
self,
**kwargs,
):
"""Init beam search.
Args:
**kwargs: Additional keyword arguments.
"""
from funasr.models.paraformer.search import BeamSearchPara
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.models.transformer.scorers.length_bonus import LengthBonus
# 1. Build ASR model
scorers = {}
if self.ctc != None:
ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
scorers.update(ctc=ctc)
token_list = kwargs.get("token_list")
scorers.update(
length_bonus=LengthBonus(len(token_list)),
)
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
weights = dict(
View full source on GitHub →Paraformer.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
# init beamsearch
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = (
kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
)
pred_timestamp = kwargs.get("pred_timestamp", False)
if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
View full source on GitHub →Paraformer.export(**kwargs)Export.
# **kwargs — Additional keyword arguments.
def export(self, **kwargs):
    """Rebuild this model into its export form.

    A default ``max_seq_len`` of 512 is filled in when the caller does not
    pass one; every keyword argument is then forwarded to
    ``export_rebuild_model``.

    Args:
        **kwargs: Export options forwarded to ``export_rebuild_model``.

    Returns:
        Whatever ``export_rebuild_model(model=self, **kwargs)`` returns.
    """
    from .export_meta import export_rebuild_model

    kwargs.setdefault("max_seq_len", 512)
    return export_rebuild_model(model=self, **kwargs)
ParaformerStreaming — Streaming (online) version of Paraformer. Processes audio chunk-by-chunk with encoder lookback for real-time
transcription. Uses a cache mechanism to maintain state across chunks.
Usage — generate(input=chunk, cache=cache, is_final=bool, chunk_size=[0,10,5])Output — {"key": str, "text": str} (partial results per chunk)class ParaformerStreaming(Paraformer):
"""ParaformerStreaming: Streaming (online) version of Paraformer.
Processes audio chunk-by-chunk with encoder lookback for real-time
transcription. Uses cache mechanism to maintain state across chunks.
Usage: generate(input=chunk, cache=cache, is_final=bool, chunk_size=[0,10,5])
Output: {"key": str, "text": str} (partial results per chunk)
"""
def __init__(
self,
*args,
**kwargs,
):
"""Initialize ParaformerStreaming.
Args:
*args: Variable positional arguments.
**kwargs: Additional keyword arguments.
"""
super().__init__(*args, **kwargs)
self.sampling_ratio = kwargs.get("sampling_ratio", 0.2)
self.scama_mask = None
if (
hasattr(self.encoder, "overlap_chunk_cls")
and self.encoder.overlap_chunk_cls is not None
View full source on GitHub →.forward(speech, speech_lengths, text, text_lengths, **kwargs) L83Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
decoding_ind = kwargs.get("decoding_ind")
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# Encoder
if hasattr(self.encoder, "overlap_chunk_cls"):
ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
else:
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
View full source →.encode_chunk(speech, speech_lengths, cache, **kwargs) L165Encode chunk.
speech — Speech audio tensor, shape (batch, time).speech_lengths — Length of each speech sample.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def encode_chunk(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
cache: dict = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encode chunk.
Args:
speech: Speech audio tensor, shape (batch, time).
speech_lengths: Length of each speech sample.
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
ind: int
"""
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
View full source →.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask) L336Sampler.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.ys_pad — TODO.ys_pad_lens — Lengths of ys_pad.pre_acoustic_embeds — TODO.chunk_mask — TODO. def sampler(
self,
encoder_out,
encoder_out_lens,
ys_pad,
ys_pad_lens,
pre_acoustic_embeds,
chunk_mask=None,
):
"""Sampler.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
ys_pad: TODO.
ys_pad_lens: Lengths of ys_pad.
pre_acoustic_embeds: TODO.
chunk_mask: TODO.
"""
tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
ys_pad.device
)
ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
if self.share_embedding:
ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
else:
ys_pad_embed = self.decoder.embed(ys_pad_masked)
with torch.no_grad():
decoder_outs = self.decoder(
View full source →.calc_predictor(encoder_out, encoder_out_lens) L392Calc predictor.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths. def calc_predictor(self, encoder_out, encoder_out_lens):
"""Calc predictor.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
"""
encoder_out_mask = (
~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
).to(encoder_out.device)
mask_chunk_predictor = None
if self.encoder.overlap_chunk_cls is not None:
mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
None, device=encoder_out.device, batch_size=encoder_out.size(0)
)
mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
None, device=encoder_out.device, batch_size=encoder_out.size(0)
)
encoder_out = encoder_out * mask_shfit_chunk
pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(
encoder_out,
None,
encoder_out_mask,
ignore_id=self.ignore_id,
mask_chunk_predictor=mask_chunk_predictor,
target_label_length=None,
)
predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
pre_alphas,
View full source →.calc_predictor_chunk(encoder_out, encoder_out_lens, cache, **kwargs) L460Calc predictor chunk.
# encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments.
def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs):
    """Run the predictor on one streaming chunk.

    Args:
        encoder_out: Encoder output tensor for the current chunk.
        encoder_out_lens: Not used; kept for signature parity with
            ``calc_predictor``.
        cache: Streaming state dict; ``cache["encoder"]`` is passed to
            ``self.predictor.forward_chunk``. Required despite the ``None``
            default.
        **kwargs: Only ``is_final`` (default ``False``) is read and forwarded.

    Returns:
        Whatever ``self.predictor.forward_chunk`` returns.

    Raises:
        ValueError: If ``cache`` is ``None``.
    """
    if cache is None:
        # Fail with a clear message instead of "'NoneType' object is not subscriptable".
        raise ValueError("calc_predictor_chunk requires a streaming cache dict")
    is_final = kwargs.get("is_final", False)
    return self.predictor.forward_chunk(encoder_out, cache["encoder"], is_final=is_final)
.cal_decoder_with_predictor(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens) L473Cal decoder with predictor.
# encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.sematic_embeds — TODO.ys_pad_lens — Lengths of ys_pad.
def cal_decoder_with_predictor(
    self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
):
    """Decode the given embeddings (with ``self.scama_mask``) into log-probabilities.

    Args:
        encoder_out: Encoder output tensor.
        encoder_out_lens: Encoder output lengths.
        sematic_embeds: Embeddings fed to the decoder (presumably the
            predictor's acoustic embeddings — confirm at call sites).
        ys_pad_lens: Target lengths; returned unchanged.

    Returns:
        ``(log_probs, ys_pad_lens)`` where ``log_probs`` is ``log_softmax``
        over the last axis of the decoder's first output.
    """
    decoder_result = self.decoder(
        encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
    )
    logits = decoder_result[0]
    return torch.log_softmax(logits, dim=-1), ys_pad_lens
.cal_decoder_with_predictor_chunk(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, cache) L491Cal decoder with predictor chunk.
# encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.sematic_embeds — TODO.ys_pad_lens — Lengths of ys_pad.cache — State cache dict for streaming inference.
def cal_decoder_with_predictor_chunk(
    self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, cache=None
):
    """Run one streaming decoder step and return log-probabilities.

    Args:
        encoder_out: Encoder output tensor for the current chunk.
        encoder_out_lens: Not used; kept for signature parity with
            ``cal_decoder_with_predictor``.
        sematic_embeds: Embeddings fed to the decoder for this chunk.
        ys_pad_lens: Returned unchanged.
        cache: Streaming state dict; ``cache["decoder"]`` is passed to
            ``self.decoder.forward_chunk``. Required despite the ``None``
            default.

    Returns:
        ``(log_probs, ys_pad_lens)`` where ``log_probs`` is ``log_softmax``
        over the last axis of the decoder output.

    Raises:
        ValueError: If ``cache`` is ``None``.
    """
    if cache is None:
        # Fail with a clear message instead of "'NoneType' object is not subscriptable".
        raise ValueError("cal_decoder_with_predictor_chunk requires a streaming cache dict")
    decoder_out = self.decoder.forward_chunk(encoder_out, sematic_embeds, cache["decoder"])
    return torch.log_softmax(decoder_out, dim=-1), ys_pad_lens
.init_cache(cache, **kwargs) L508Init cache.
cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def init_cache(self, cache: dict = None, **kwargs):
"""Init cache.
Args:
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
chunk_size = kwargs.get("chunk_size", [0, 10, 5])
encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
batch_size = 1
enc_output_size = kwargs["encoder_conf"]["output_size"]
feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
cache_encoder = {
"start_idx": 0,
"cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
"cif_alphas": torch.zeros((batch_size, 1)),
"chunk_size": chunk_size,
"encoder_chunk_look_back": encoder_chunk_look_back,
"last_chunk": False,
"opt": None,
"feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
"tail_chunk": False,
}
cache["encoder"] = cache_encoder
cache_decoder = {
View full source →.generate_chunk(speech, speech_lengths, key, tokenizer, frontend, **kwargs) L549Generate chunk.
speech — Speech audio tensor, shape (batch, time).speech_lengths — Length of each speech sample.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def generate_chunk(
self,
speech,
speech_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Generate chunk.
Args:
speech: Speech audio tensor, shape (batch, time).
speech_lengths: Length of each speech sample.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
cache = kwargs.get("cache", {})
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
# Encoder
encoder_out, encoder_out_lens = self.encode_chunk(
speech, speech_lengths, cache=cache, is_final=kwargs.get("is_final", False)
)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
View full source →.inference(data_in, data_lengths, key, tokenizer, frontend, cache, **kwargs) L647Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
cache: dict = None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
# init beamsearch
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = (
kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
)
if self.beam_search is None and (is_use_lm or is_use_ctc):
View full source →.export(**kwargs) L762Export.
# **kwargs — Additional keyword arguments.
def export(self, **kwargs):
    """Rebuild this streaming model into its export form.

    Args:
        **kwargs: Export options forwarded unchanged to ``export_rebuild_model``.

    Returns:
        Whatever ``export_rebuild_model(model=self, **kwargs)`` returns.
    """
    from .export_meta import export_rebuild_model

    return export_rebuild_model(model=self, **kwargs)
ParaformerStreaming.forward(speech, speech_lengths, text, text_lengths, **kwargs)Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
decoding_ind = kwargs.get("decoding_ind")
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# Encoder
if hasattr(self.encoder, "overlap_chunk_cls"):
ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
else:
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
View full source on GitHub →ParaformerStreaming.encode_chunk(speech, speech_lengths, cache, **kwargs)Encode chunk.
speech — Speech audio tensor, shape (batch, time).speech_lengths — Length of each speech sample.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def encode_chunk(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
cache: dict = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encode chunk.
Args:
speech: Speech audio tensor, shape (batch, time).
speech_lengths: Length of each speech sample.
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
ind: int
"""
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
View full source on GitHub →ParaformerStreaming.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask)Sampler.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.ys_pad — TODO.ys_pad_lens — Lengths of ys_pad.pre_acoustic_embeds — TODO.chunk_mask — TODO. def sampler(
self,
encoder_out,
encoder_out_lens,
ys_pad,
ys_pad_lens,
pre_acoustic_embeds,
chunk_mask=None,
):
"""Sampler.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
ys_pad: TODO.
ys_pad_lens: Lengths of ys_pad.
pre_acoustic_embeds: TODO.
chunk_mask: TODO.
"""
tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
ys_pad.device
)
ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
if self.share_embedding:
ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
else:
ys_pad_embed = self.decoder.embed(ys_pad_masked)
with torch.no_grad():
decoder_outs = self.decoder(
View full source on GitHub →ParaformerStreaming.calc_predictor(encoder_out, encoder_out_lens)Calc predictor.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths. def calc_predictor(self, encoder_out, encoder_out_lens):
"""Calc predictor.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
"""
encoder_out_mask = (
~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
).to(encoder_out.device)
mask_chunk_predictor = None
if self.encoder.overlap_chunk_cls is not None:
mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
None, device=encoder_out.device, batch_size=encoder_out.size(0)
)
mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
None, device=encoder_out.device, batch_size=encoder_out.size(0)
)
encoder_out = encoder_out * mask_shfit_chunk
pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(
encoder_out,
None,
encoder_out_mask,
ignore_id=self.ignore_id,
mask_chunk_predictor=mask_chunk_predictor,
target_label_length=None,
)
predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
pre_alphas,
View full source on GitHub →ParaformerStreaming.calc_predictor_chunk(encoder_out, encoder_out_lens, cache, **kwargs)Calc predictor chunk.
# encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments.
def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs):
    """Run the predictor on one streaming chunk.

    Args:
        encoder_out: Encoder output tensor for the current chunk.
        encoder_out_lens: Not used; kept for signature parity with
            ``calc_predictor``.
        cache: Streaming state dict; ``cache["encoder"]`` is passed to
            ``self.predictor.forward_chunk``. Required despite the ``None``
            default.
        **kwargs: Only ``is_final`` (default ``False``) is read and forwarded.

    Returns:
        Whatever ``self.predictor.forward_chunk`` returns.

    Raises:
        ValueError: If ``cache`` is ``None``.
    """
    if cache is None:
        # Fail with a clear message instead of "'NoneType' object is not subscriptable".
        raise ValueError("calc_predictor_chunk requires a streaming cache dict")
    is_final = kwargs.get("is_final", False)
    return self.predictor.forward_chunk(encoder_out, cache["encoder"], is_final=is_final)
ParaformerStreaming.cal_decoder_with_predictor(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)Cal decoder with predictor.
# encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.sematic_embeds — TODO.ys_pad_lens — Lengths of ys_pad.
def cal_decoder_with_predictor(
    self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
):
    """Decode the given embeddings (with ``self.scama_mask``) into log-probabilities.

    Args:
        encoder_out: Encoder output tensor.
        encoder_out_lens: Encoder output lengths.
        sematic_embeds: Embeddings fed to the decoder (presumably the
            predictor's acoustic embeddings — confirm at call sites).
        ys_pad_lens: Target lengths; returned unchanged.

    Returns:
        ``(log_probs, ys_pad_lens)`` where ``log_probs`` is ``log_softmax``
        over the last axis of the decoder's first output.
    """
    decoder_result = self.decoder(
        encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
    )
    logits = decoder_result[0]
    return torch.log_softmax(logits, dim=-1), ys_pad_lens
ParaformerStreaming.cal_decoder_with_predictor_chunk(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, cache)Cal decoder with predictor chunk.
# encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.sematic_embeds — TODO.ys_pad_lens — Lengths of ys_pad.cache — State cache dict for streaming inference.
def cal_decoder_with_predictor_chunk(
    self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, cache=None
):
    """Run one streaming decoder step and return log-probabilities.

    Args:
        encoder_out: Encoder output tensor for the current chunk.
        encoder_out_lens: Not used; kept for signature parity with
            ``cal_decoder_with_predictor``.
        sematic_embeds: Embeddings fed to the decoder for this chunk.
        ys_pad_lens: Returned unchanged.
        cache: Streaming state dict; ``cache["decoder"]`` is passed to
            ``self.decoder.forward_chunk``. Required despite the ``None``
            default.

    Returns:
        ``(log_probs, ys_pad_lens)`` where ``log_probs`` is ``log_softmax``
        over the last axis of the decoder output.

    Raises:
        ValueError: If ``cache`` is ``None``.
    """
    if cache is None:
        # Fail with a clear message instead of "'NoneType' object is not subscriptable".
        raise ValueError("cal_decoder_with_predictor_chunk requires a streaming cache dict")
    decoder_out = self.decoder.forward_chunk(encoder_out, sematic_embeds, cache["decoder"])
    return torch.log_softmax(decoder_out, dim=-1), ys_pad_lens
ParaformerStreaming.init_cache(cache, **kwargs)Init cache.
cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def init_cache(self, cache: dict = None, **kwargs):
"""Init cache.
Args:
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
chunk_size = kwargs.get("chunk_size", [0, 10, 5])
encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
batch_size = 1
enc_output_size = kwargs["encoder_conf"]["output_size"]
feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
cache_encoder = {
"start_idx": 0,
"cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
"cif_alphas": torch.zeros((batch_size, 1)),
"chunk_size": chunk_size,
"encoder_chunk_look_back": encoder_chunk_look_back,
"last_chunk": False,
"opt": None,
"feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
"tail_chunk": False,
}
cache["encoder"] = cache_encoder
cache_decoder = {
View full source on GitHub →ParaformerStreaming.generate_chunk(speech, speech_lengths, key, tokenizer, frontend, **kwargs)Generate chunk.
speech — Speech audio tensor, shape (batch, time).speech_lengths — Length of each speech sample.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def generate_chunk(
self,
speech,
speech_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Generate chunk.
Args:
speech: Speech audio tensor, shape (batch, time).
speech_lengths: Length of each speech sample.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
cache = kwargs.get("cache", {})
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
# Encoder
encoder_out, encoder_out_lens = self.encode_chunk(
speech, speech_lengths, cache=cache, is_final=kwargs.get("is_final", False)
)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
View full source on GitHub →ParaformerStreaming.inference(data_in, data_lengths, key, tokenizer, frontend, cache, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
cache: dict = None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
# init beamsearch
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = (
kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
)
if self.beam_search is None and (is_use_lm or is_use_ctc):
View full source on GitHub →ParaformerStreaming.export(**kwargs)Export.
# **kwargs — Additional keyword arguments.
def export(self, **kwargs):
    """Rebuild this streaming model into its export form.

    Args:
        **kwargs: Export options forwarded unchanged to ``export_rebuild_model``.

    Returns:
        Whatever ``export_rebuild_model(model=self, **kwargs)`` returns.
    """
    from .export_meta import export_rebuild_model

    return export_rebuild_model(model=self, **kwargs)
Author — Speech Lab of DAMO Academy, Alibaba Group. Paraformer — Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition (https://arxiv.org/abs/2206.08317)
class Paraformer(torch.nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
def __init__(
self,
specaug: Optional[str] = None,
specaug_conf: Optional[Dict] = None,
normalize: str = None,
normalize_conf: Optional[Dict] = None,
encoder: str = None,
encoder_conf: Optional[Dict] = None,
decoder: str = None,
decoder_conf: Optional[Dict] = None,
ctc: str = None,
ctc_conf: Optional[Dict] = None,
ctc_weight: float = 0.5,
input_size: int = 80,
vocab_size: int = -1,
ignore_id: int = -1,
blank_id: int = 0,
sos: int = 1,
eos: int = 2,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
# report_cer: bool = True,
# report_wer: bool = True,
View full source on GitHub →.forward(speech, speech_lengths, text, text_lengths, **kwargs) L182Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
loss_ctc, cer_ctc = None, None
loss_pre = None
stats = dict()
# decoder: CTC branch
View full source →.encode(speech, speech_lengths, **kwargs) L250Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encoder. Note that this method is used by asr_inference.py
Apply the optional feature augmentation (training only) and feature
normalization, then run self.encoder.
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
**kwargs: Accepted for interface compatibility but never read by this
method; in particular an ``ind`` (chunk index) keyword is ignored here.
Returns:
Tuple (encoder_out, encoder_out_lens) from self.encoder.
"""
with autocast(False):
# Data augmentation; only applied when the module is in training mode
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder; the encoder's third return value is discarded
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
# If the encoder's first output is itself a tuple, keep only its first element
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
return encoder_out, encoder_out_lens
.map_alignment_to_target_index(align_path, blank_id) L366Robustly map CTC alignment path (Token IDs) to Target Indices.
Logic:
Detect boundaries where a new token segment begins.
A segment starts if the current frame is a Token AND it is different from the previous frame
(considering CTC topology where repeats are separated by blanks or are distinct tokens).
Example:
Text — [A, B]; Align Path — [A, A, _, B, B]
Output — [0, 0, -1, 1, 1] def map_alignment_to_target_index(self, align_path, blank_id):
"""
Robustly map CTC alignment path (Token IDs) to Target Indices.
Logic:
Detect boundaries where a new token segment begins.
A segment starts if the current frame is a Token AND it is different from the previous frame
(considering CTC topology where repeats are separated by blanks or are distinct tokens).
Example:
Text: [A, B]
Align Path: [A, A, _, B, B]
Output: [0, 0, -1, 1, 1]
"""
# 1. Identify where the path is NOT blank
is_token = align_path != blank_id
# 2. Identify transitions
prev_path = torch.roll(align_path, 1)
# Handle the very first frame: if it's a token, it must be the start of segment 0.
prev_path[0] = blank_id # force mismatch for the first element
# A new segment starts if: It's a token AND (it differs from prev OR prev was blank)
# Note: If align_path[i] == align_path[i-1] (and not blank), it's the same segment.
new_segment_start = is_token & (align_path != prev_path)
# 3. Cumulative sum to assign indices (1..U)
segment_ids = torch.cumsum(new_segment_start.long(), dim=0) - 1
# 4. Mask out blank positions with -1
View full source →.force_align(ctc_probs, y, blank_id) L399ctc forced alignment.
ctc_probs (torch.Tensor) — per-frame CTC emission scores, 2-D tensor (T, D); torchaudio's forced_align expects log-probabilities here
y (torch.Tensor) — target token-id sequence, 1-D tensor (L)
blank_id (int) — index of the blank symbol
# Returns: torch.Tensor — alignment result
def force_align(self, ctc_probs: torch.Tensor, y: torch.Tensor, blank_id=0) -> torch.Tensor:
    """CTC forced alignment of a single utterance.

    Thin wrapper around ``torchaudio.functional.forced_align``. It adds the
    batch axis torchaudio expects and runs the alignment on CPU.

    Args:
        ctc_probs (torch.Tensor): Frame-level CTC log-probabilities, 2d tensor (T, D).
            torchaudio's ``forced_align`` expects log-probabilities.
        y (torch.Tensor): Target token-id sequence, 1d tensor (L).
        blank_id (int): Blank symbol index.

    Returns:
        torch.Tensor: Alignment result, 1d tensor (T) holding one label per frame.
    """
    # torchaudio aligns batches of shape (B, T, D) / (B, L): add a batch axis
    # of size 1 and move both inputs to CPU before the call.
    ctc_probs = ctc_probs[None].cpu()
    y = y[None].cpu()
    alignments, _ = torchaudio.functional.forced_align(ctc_probs, y, blank=blank_id)
    # Drop the batch axis again; the per-frame scores are not needed here.
    return alignments[0]
.average_repeats_training(ctc_probs, target_idx_path, target_len) L414Aggregates frames belonging to the same target index using scatter_add.
ctc_probs — [T, V]target_idx_path — [T], values in [-1, 0, ... U-1]target_len — Ucompressed — [U, V] def average_repeats_training(self, ctc_probs, target_idx_path, target_len):
"""
Aggregates frames belonging to the same target index using scatter_add.
Args:
ctc_probs: [T, V]
target_idx_path: [T], values in [-1, 0, ... U-1]
target_len: U
Returns:
compressed: [U, V]
"""
U = target_len
V = ctc_probs.size(1)
compressed = torch.zeros((U, V), device=ctc_probs.device, dtype=ctc_probs.dtype)
counts = torch.zeros((U, 1), device=ctc_probs.device, dtype=ctc_probs.dtype)
# Filter valid frames (non-blank)
mask = target_idx_path != -1
valid_indices = target_idx_path[mask] # [T_valid]
valid_probs = ctc_probs[mask] # [T_valid, V]
if valid_indices.numel() == 0:
return compressed
# Scatter Add Probs
index_expanded = valid_indices.unsqueeze(1).repeat(1, V)
compressed.scatter_add_(0, index_expanded, valid_probs)
# Scatter Add Counts
View full source →.average_repeats_inference(ctc_probs, greedy_path) L451merged_probs — [U', V]timestamps — List[Tuple[int, int]] -> [(start_frame, end_frame), ...] def average_repeats_inference(self, ctc_probs, greedy_path):
"""
Returns:
merged_probs: [U', V]
timestamps: List[Tuple[int, int]] -> [(start_frame, end_frame), ...]
"""
if greedy_path.numel() == 0:
return torch.zeros((0, ctc_probs.size(1)), device=ctc_probs.device)
# Find consecutive segments in the greedy path
unique_tokens, counts = torch.unique_consecutive(greedy_path, return_counts=True)
# Compute start and end indices for each segment
end_indices = torch.cumsum(counts, dim=0)
start_indices = torch.cat([torch.tensor([0], device=counts.device), end_indices[:-1]])
merged_probs = []
for i, token in enumerate(unique_tokens):
if token != self.blank_id:
start = start_indices[i].item()
end = end_indices[i].item()
# Extract and average probabilities for the decoder
avg_prob = ctc_probs[start:end].mean(dim=0)
merged_probs.append(avg_prob)
if not merged_probs:
return torch.zeros((0, ctc_probs.size(1)), device=ctc_probs.device)
View full source →.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L484Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
meta_data = {}
if (
isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
): # fbank
speech, speech_lengths = data_in, data_lengths
if len(speech.shape) < 3:
speech = speech[None, :, :]
if speech_lengths is not None:
speech_lengths = speech_lengths.squeeze(-1)
else:
View full source →.export(**kwargs) L592Export.
# **kwargs — Additional keyword arguments.
def export(self, **kwargs):
    """Build the export model(s) for this model via ``export_rebuild_model``.

    Args:
        **kwargs: Options forwarded to ``export_rebuild_model``. ``max_seq_len``
            is set to 512 when the caller does not supply it.

    Returns:
        Whatever ``export_rebuild_model`` returns for this model.
    """
    from .export_meta import export_rebuild_model

    kwargs.setdefault("max_seq_len", 512)
    return export_rebuild_model(model=self, **kwargs)
Paraformer.forward(speech, speech_lengths, text, text_lengths, **kwargs)Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
loss_ctc, cer_ctc = None, None
loss_pre = None
stats = dict()
# decoder: CTC branch
View full source on GitHub →Paraformer.encode(speech, speech_lengths, **kwargs)Encoder. Note that this method is used by asr_inference.py
# encode(): optional SpecAugment + normalization, then self.encoder; returns (encoder_out, encoder_out_lens).
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
**kwargs: Extra keyword arguments; not read by this method (an ``ind``
argument, if passed, is ignored).
Returns:
(encoder_out, encoder_out_lens) from ``self.encoder``; when the encoder's
first output is a tuple, only its first element is returned.
"""
with autocast(False):
# Data augmentation (only while self.training is True)
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
# NOTE(review): indentation is lost in this rendering; upstream, this call
# presumably sits outside the autocast(False) block — confirm against the source.
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
# Some encoders return a tuple as their first output; keep only its first element.
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
return encoder_out, encoder_out_lens
Paraformer.map_alignment_to_target_index(align_path, blank_id)Robustly map CTC alignment path (Token IDs) to Target Indices.
Logic:
Detect boundaries where a new token segment begins.
A segment starts if the current frame is a Token AND it is different from the previous frame
(considering CTC topology where repeats are separated by blanks or are distinct tokens).
Example:
Text — [A, B]; Align Path — [A, A, _, B, B]
Output — [0, 0, -1, 1, 1] def map_alignment_to_target_index(self, align_path, blank_id):
"""
Robustly map CTC alignment path (Token IDs) to Target Indices.
Logic:
Detect boundaries where a new token segment begins.
A segment starts if the current frame is a Token AND it is different from the previous frame
(considering CTC topology where repeats are separated by blanks or are distinct tokens).
Example:
Text: [A, B]
Align Path: [A, A, _, B, B]
Output: [0, 0, -1, 1, 1]
"""
# 1. Identify where the path is NOT blank
is_token = align_path != blank_id
# 2. Identify transitions
prev_path = torch.roll(align_path, 1)
# Handle the very first frame: if it's a token, it must be the start of segment 0.
prev_path[0] = blank_id # force mismatch for the first element
# A new segment starts if: It's a token AND (it differs from prev OR prev was blank)
# Note: If align_path[i] == align_path[i-1] (and not blank), it's the same segment.
new_segment_start = is_token & (align_path != prev_path)
# 3. Cumulative sum to assign indices (1..U)
segment_ids = torch.cumsum(new_segment_start.long(), dim=0) - 1
# 4. Mask out blank positions with -1
View full source on GitHub →Paraformer.force_align(ctc_probs, y, blank_id)ctc forced alignment.
torch.Tensor ctc_probs: CTC log-probability sequence, 2d tensor (T, D)
torch.Tensor y: id sequence tensor 1d tensor (L)
int blank_id: blank symbol index
# Returns: torch.Tensor — alignment result
def force_align(self, ctc_probs: torch.Tensor, y: torch.Tensor, blank_id=0) -> torch.Tensor:
    """CTC forced alignment of a single utterance.

    Thin wrapper around ``torchaudio.functional.forced_align``. It adds the
    batch axis torchaudio expects and runs the alignment on CPU.

    Args:
        ctc_probs (torch.Tensor): Frame-level CTC log-probabilities, 2d tensor (T, D).
            torchaudio's ``forced_align`` expects log-probabilities.
        y (torch.Tensor): Target token-id sequence, 1d tensor (L).
        blank_id (int): Blank symbol index.

    Returns:
        torch.Tensor: Alignment result, 1d tensor (T) holding one label per frame.
    """
    # torchaudio aligns batches of shape (B, T, D) / (B, L): add a batch axis
    # of size 1 and move both inputs to CPU before the call.
    ctc_probs = ctc_probs[None].cpu()
    y = y[None].cpu()
    alignments, _ = torchaudio.functional.forced_align(ctc_probs, y, blank=blank_id)
    # Drop the batch axis again; the per-frame scores are not needed here.
    return alignments[0]
Paraformer.average_repeats_training(ctc_probs, target_idx_path, target_len)Aggregates frames belonging to the same target index using scatter_add.
ctc_probs — [T, V]target_idx_path — [T], values in [-1, 0, ... U-1]target_len — Ucompressed — [U, V] def average_repeats_training(self, ctc_probs, target_idx_path, target_len):
"""
Aggregates frames belonging to the same target index using scatter_add.
Args:
ctc_probs: [T, V]
target_idx_path: [T], values in [-1, 0, ... U-1]
target_len: U
Returns:
compressed: [U, V]
"""
U = target_len
V = ctc_probs.size(1)
compressed = torch.zeros((U, V), device=ctc_probs.device, dtype=ctc_probs.dtype)
counts = torch.zeros((U, 1), device=ctc_probs.device, dtype=ctc_probs.dtype)
# Filter valid frames (non-blank)
mask = target_idx_path != -1
valid_indices = target_idx_path[mask] # [T_valid]
valid_probs = ctc_probs[mask] # [T_valid, V]
if valid_indices.numel() == 0:
return compressed
# Scatter Add Probs
index_expanded = valid_indices.unsqueeze(1).repeat(1, V)
compressed.scatter_add_(0, index_expanded, valid_probs)
# Scatter Add Counts
View full source on GitHub →Paraformer.average_repeats_inference(ctc_probs, greedy_path)merged_probs — [U', V]timestamps — List[Tuple[int, int]] -> [(start_frame, end_frame), ...] def average_repeats_inference(self, ctc_probs, greedy_path):
"""
Returns:
merged_probs: [U', V]
timestamps: List[Tuple[int, int]] -> [(start_frame, end_frame), ...]
"""
if greedy_path.numel() == 0:
return torch.zeros((0, ctc_probs.size(1)), device=ctc_probs.device)
# Find consecutive segments in the greedy path
unique_tokens, counts = torch.unique_consecutive(greedy_path, return_counts=True)
# Compute start and end indices for each segment
end_indices = torch.cumsum(counts, dim=0)
start_indices = torch.cat([torch.tensor([0], device=counts.device), end_indices[:-1]])
merged_probs = []
for i, token in enumerate(unique_tokens):
if token != self.blank_id:
start = start_indices[i].item()
end = end_indices[i].item()
# Extract and average probabilities for the decoder
avg_prob = ctc_probs[start:end].mean(dim=0)
merged_probs.append(avg_prob)
if not merged_probs:
return torch.zeros((0, ctc_probs.size(1)), device=ctc_probs.device)
View full source on GitHub →Paraformer.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
meta_data = {}
if (
isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
): # fbank
speech, speech_lengths = data_in, data_lengths
if len(speech.shape) < 3:
speech = speech[None, :, :]
if speech_lengths is not None:
speech_lengths = speech_lengths.squeeze(-1)
else:
View full source on GitHub →Paraformer.export(**kwargs)Export.
# **kwargs — Additional keyword arguments.
def export(self, **kwargs):
    """Build the export model(s) for this model via ``export_rebuild_model``.

    Args:
        **kwargs: Options forwarded to ``export_rebuild_model``. ``max_seq_len``
            is set to 512 when the caller does not supply it.

    Returns:
        Whatever ``export_rebuild_model`` returns for this model.
    """
    from .export_meta import export_rebuild_model

    kwargs.setdefault("max_seq_len", 512)
    return export_rebuild_model(model=self, **kwargs)
Qwen3-ASR: Large Language Model based ASR supporting 52 languages. Wraps the qwen-asr package's Qwen3ASRModel for use within FunASR's AutoModel interface.
Supports auto language detection, contextual recognition, and optional forced alignment
for character-level timestamps.
pip install qwen-asr
class Qwen3ASR(nn.Module):
"""Qwen3-ASR: Large Language Model based ASR supporting 52 languages.
Wraps the qwen-asr package's Qwen3ASRModel for use within FunASR's AutoModel interface.
Supports auto language detection, contextual recognition, and optional forced alignment
for character-level timestamps.
Requirements:
pip install qwen-asr
Models:
- Qwen/Qwen3-ASR-0.6B (lighter, ~4GB GPU memory)
- Qwen/Qwen3-ASR-1.7B (more accurate, ~8GB GPU memory)
"""
def __init__(self, **kwargs):
"""Initialize Qwen3ASR.
Args:
**kwargs: Additional keyword arguments.
"""
super().__init__()
model_path = kwargs.get("model_path", kwargs.get("model", "Qwen/Qwen3-ASR-1.7B"))
device = kwargs.get("device", "cuda:0")
dtype = kwargs.get("dtype", "bf16")
hub = kwargs.get("hub", "ms")
max_new_tokens = kwargs.get("max_new_tokens", 512)
max_inference_batch_size = kwargs.get("max_inference_batch_size", 32)
forced_aligner = kwargs.get("forced_aligner", None)
forced_aligner_kwargs = kwargs.get("forced_aligner_kwargs", None)
View full source on GitHub →.forward(**kwargs) L105Forward pass for training.
# **kwargs — Additional keyword arguments.
def forward(self, **kwargs):
    """Training forward pass; unsupported because Qwen3ASR only runs inference.

    Args:
        **kwargs: Ignored.

    Raises:
        NotImplementedError: Always.
    """
    reason = "Qwen3ASR only supports inference mode"
    raise NotImplementedError(reason)
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L113Run Qwen3-ASR speech recognition.
data_in — Audio input. Accepts:data_lengths — Not used.key (list) — Sample identifiers.tokenizer — Not used (Qwen3-ASR has internal tokenizer).frontend — Not used (Qwen3-ASR has internal audio processing).**kwargs — Runtime parameters:tuple — (results, meta_data) where results is list of dicts: def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run Qwen3-ASR speech recognition.
Args:
data_in: Audio input. Accepts:
- list of file paths/URLs
- list of (numpy_array, sample_rate) tuples
- single numpy array or torch Tensor
data_lengths: Not used.
key (list): Sample identifiers.
tokenizer: Not used (Qwen3-ASR has internal tokenizer).
frontend: Not used (Qwen3-ASR has internal audio processing).
**kwargs: Runtime parameters:
- language (str): Language hint (e.g. "Chinese", "English") or None for auto-detect.
- return_time_stamps (bool): Return character-level timestamps (requires forced_aligner).
- output_timestamp (bool): Same as return_time_stamps (for pipeline compatibility).
- context (str): Context prompt for contextual recognition.
Returns:
tuple: (results, meta_data) where results is list of dicts:
- "key" (str): Sample ID
- "text" (str): Recognized text (with punctuation)
View full source →Qwen3ASR.forward(**kwargs)Forward pass for training.
# **kwargs — Additional keyword arguments.
def forward(self, **kwargs):
    """Training forward pass; unsupported because Qwen3ASR only runs inference.

    Args:
        **kwargs: Ignored.

    Raises:
        NotImplementedError: Always.
    """
    reason = "Qwen3ASR only supports inference mode"
    raise NotImplementedError(reason)
Qwen3ASR.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run Qwen3-ASR speech recognition.
data_in — Audio input. Accepts:data_lengths — Not used.key (list) — Sample identifiers.tokenizer — Not used (Qwen3-ASR has internal tokenizer).frontend — Not used (Qwen3-ASR has internal audio processing).**kwargs — Runtime parameters:tuple — (results, meta_data) where results is list of dicts: def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run Qwen3-ASR speech recognition.
Args:
data_in: Audio input. Accepts:
- list of file paths/URLs
- list of (numpy_array, sample_rate) tuples
- single numpy array or torch Tensor
data_lengths: Not used.
key (list): Sample identifiers.
tokenizer: Not used (Qwen3-ASR has internal tokenizer).
frontend: Not used (Qwen3-ASR has internal audio processing).
**kwargs: Runtime parameters:
- language (str): Language hint (e.g. "Chinese", "English") or None for auto-detect.
- return_time_stamps (bool): Return character-level timestamps (requires forced_aligner).
- output_timestamp (bool): Same as return_time_stamps (for pipeline compatibility).
- context (str): Context prompt for contextual recognition.
Returns:
tuple: (results, meta_data) where results is list of dicts:
- "key" (str): Sample ID
- "text" (str): Recognized text (with punctuation)
View full source on GitHub → Qwen-Audio: Advancing Universal Audio Understanding via Unified Large-Scale Audio-Language Models https://arxiv.org/abs/2311.07919
Modified from https://github.com/QwenLM/Qwen-Audio
class QwenAudioWarp(nn.Module):
"""
Qwen-Audio: Advancing Universal Audio Understanding via Unified Large-Scale Audio-Language Models
https://arxiv.org/abs/2311.07919
Modified from https://github.com/QwenLM/Qwen-Audio
"""
def __init__(self, *args, **kwargs):
"""Initialize QwenAudioWarp.
Args:
*args: Variable positional arguments.
**kwargs: Additional keyword arguments.
"""
super().__init__()
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
model_or_path = kwargs.get("model_path", "QwenAudio")
model = AutoModelForCausalLM.from_pretrained(
model_or_path, device_map="cpu", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(model_or_path, trust_remote_code=True)
self.model = model
self.tokenizer = tokenizer
def forward(
self,
):
View full source on GitHub →.forward() L49Forward pass for training.
def forward(self):
    """Training forward hook; a placeholder that performs no computation.

    This wrapper does not implement training.

    Returns:
        None.
    """
    return None
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L55Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
meta_data = {}
# meta_data["batch_data_time"] = -1
prompt = kwargs.get(
"prompt", "<|startoftranscription|><|en|><|transcribe|><|en|><|notimestamps|><|wo_itn|>"
)
query = f"<audio>{data_in[0]}</audio>{prompt}"
audio_info = self.tokenizer.process_audio(query)
inputs = self.tokenizer(query, return_tensors="pt", audio_info=audio_info)
View full source →QwenAudioWarp.forward()Forward pass for training.
def forward(self):
    """Training forward hook; a placeholder that performs no computation.

    This wrapper does not implement training.

    Returns:
        None.
    """
    return None
QwenAudioWarp.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
meta_data = {}
# meta_data["batch_data_time"] = -1
prompt = kwargs.get(
"prompt", "<|startoftranscription|><|en|><|transcribe|><|en|><|notimestamps|><|wo_itn|>"
)
query = f"<audio>{data_in[0]}</audio>{prompt}"
audio_info = self.tokenizer.process_audio(query)
inputs = self.tokenizer(query, return_tensors="pt", audio_info=audio_info)
View full source on GitHub → QwenAudioChat: Qwen Audio Chat model wrapper. Interactive audio chat using the Qwen-Audio-Chat model.
Supports multi-turn conversation about audio content.
class QwenAudioChatWarp(nn.Module):
"""QwenAudioChat: Qwen Audio Chat model wrapper.
Interactive audio chat using the Qwen-Audio-Chat model.
Supports multi-turn conversation about audio content.
"""
def __init__(self, *args, **kwargs):
"""
Qwen-Audio: Advancing Universal Audio Understanding via Unified Large-Scale Audio-Language Models
https://arxiv.org/abs/2311.07919
Modified from https://github.com/QwenLM/Qwen-Audio
"""
super().__init__()
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
model_or_path = kwargs.get("model_path", "QwenAudio")
bf16 = kwargs.get("bf16", False)
fp16 = kwargs.get("fp16", False)
model = AutoModelForCausalLM.from_pretrained(
model_or_path, device_map="cpu", bf16=bf16, fp16=fp16, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(model_or_path, trust_remote_code=True)
self.model = model
self.tokenizer = tokenizer
def forward(
self,
View full source on GitHub →.forward() L132Forward pass for training.
def forward(self):
    """Training forward hook; a placeholder that performs no computation.

    This wrapper does not implement training.

    Returns:
        None.
    """
    return None
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L138Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
meta_data = {}
prompt = kwargs.get("prompt", "what does the person say?")
cache = kwargs.get("cache", {})
history = cache.get("history", None)
if data_in[0] is not None:
# 1st dialogue turn
query = self.tokenizer.from_list_format(
View full source →QwenAudioChatWarp.forward()Forward pass for training.
def forward(self):
    """Training forward hook; a placeholder that performs no computation.

    This wrapper does not implement training.

    Returns:
        None.
    """
    return None
QwenAudioChatWarp.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
meta_data = {}
prompt = kwargs.get("prompt", "what does the person say?")
cache = kwargs.get("cache", {})
history = cache.get("history", None)
if data_in[0] is not None:
# 1st dialogue turn
query = self.tokenizer.from_list_format(
View full source on GitHub → Author: Zhifu Gao, Shiliang Zhang, Ming Lei, Ian McLoughlin. SAN-M: Memory Equipped Self-Attention for End-to-End Speech Recognition https://arxiv.org/abs/2006.01713
class SANM(Transformer):
    """SAN-M: memory-equipped self-attention for end-to-end speech recognition.

    Author: Zhifu Gao, Shiliang Zhang, Ming Lei, Ian McLoughlin
    Paper: https://arxiv.org/abs/2006.01713

    The class defines nothing beyond a pass-through ``__init__``; all
    behavior comes from ``Transformer``.
    """

    def __init__(self, *args, **kwargs):
        """Create a SANM model.

        Args:
            *args: Positional arguments forwarded unchanged to ``Transformer.__init__``.
            **kwargs: Keyword arguments forwarded unchanged to ``Transformer.__init__``.
        """
        super().__init__(*args, **kwargs)
SANM-KWS: Self-Attention Neural Memory based Keyword Spotting. Advanced keyword spotting using a self-attention mechanism
for better context modeling of keyword patterns.
Output — {"key": str, "value": detected_keyword_info}class SanmKWS(torch.nn.Module):
"""SANM-KWS: Self-Attention Neural Memory based Keyword Spotting.
Advanced keyword spotting using self-attention mechanism
for better context modeling of keyword patterns.
Output: {"key": str, "value": detected_keyword_info}
"""
def __init__(
self,
specaug: Optional[str] = None,
specaug_conf: Optional[Dict] = None,
normalize: str = None,
normalize_conf: Optional[Dict] = None,
encoder: str = None,
encoder_conf: Optional[Dict] = None,
ctc: str = None,
ctc_conf: Optional[Dict] = None,
ctc_weight: float = 1.0,
input_size: int = 360,
vocab_size: int = -1,
ignore_id: int = -1,
blank_id: int = 0,
sos: int = 1,
eos: int = 2,
**kwargs,
):
"""Initialize SanmKWS.
View full source on GitHub →.forward(speech, speech_lengths, text, text_lengths, **kwargs) L112Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# decoder: CTC branch
loss_ctc, cer_ctc = self._calc_ctc_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# Collect CTC branch stats
View full source →.encode(speech, speech_lengths, **kwargs) L156Encoder. Note that this method is used by asr_inference.py
# encode(): optional SpecAugment + normalization, then self.encoder; returns (encoder_out, encoder_out_lens).
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
**kwargs: Extra keyword arguments; not read by this method (an ``ind``
argument, if passed, is ignored).
Returns:
(encoder_out, encoder_out_lens) from ``self.encoder``; when the encoder's
first output is a tuple, only its first element is returned.
"""
with autocast(False):
# Data augmentation (only while self.training is True)
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
# NOTE(review): indentation is lost in this rendering; upstream, this call
# presumably sits outside the autocast(False) block — confirm against the source.
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
# Some encoders return a tuple as their first output; keep only its first element.
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
return encoder_out, encoder_out_lens
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L209Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
keywords = kwargs.get("keywords")
from funasr.utils.kws_utils import KwsCtcPrefixDecoder
self.kws_decoder = KwsCtcPrefixDecoder(
ctc=self.ctc,
keywords=keywords,
token_list=tokenizer.token_list,
seg_dict=tokenizer.seg_dict,
)
meta_data = {}
if (
View full source →.export(**kwargs) L300Export.
# **kwargs — Additional keyword arguments.
def export(self, **kwargs):
    """Build the export model(s) for this model via ``export_rebuild_model``.

    Args:
        **kwargs: Options forwarded to ``export_rebuild_model``. ``max_seq_len``
            is set to 512 when the caller does not supply it.

    Returns:
        Whatever ``export_rebuild_model`` returns for this model.
    """
    from .export_meta import export_rebuild_model

    kwargs.setdefault("max_seq_len", 512)
    return export_rebuild_model(model=self, **kwargs)
SanmKWS.forward(speech, speech_lengths, text, text_lengths, **kwargs)Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# decoder: CTC branch
loss_ctc, cer_ctc = self._calc_ctc_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# Collect CTC branch stats
View full source on GitHub →SanmKWS.encode(speech, speech_lengths, **kwargs)Encoder. Note that this method is used by asr_inference.py
# encode(): optional SpecAugment + normalization, then self.encoder; returns (encoder_out, encoder_out_lens).
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
**kwargs: Extra keyword arguments; not read by this method (an ``ind``
argument, if passed, is ignored).
Returns:
(encoder_out, encoder_out_lens) from ``self.encoder``; when the encoder's
first output is a tuple, only its first element is returned.
"""
with autocast(False):
# Data augmentation (only while self.training is True)
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
# NOTE(review): indentation is lost in this rendering; upstream, this call
# presumably sits outside the autocast(False) block — confirm against the source.
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
# Some encoders return a tuple as their first output; keep only its first element.
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
return encoder_out, encoder_out_lens
SanmKWS.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
keywords = kwargs.get("keywords")
from funasr.utils.kws_utils import KwsCtcPrefixDecoder
self.kws_decoder = KwsCtcPrefixDecoder(
ctc=self.ctc,
keywords=keywords,
token_list=tokenizer.token_list,
seg_dict=tokenizer.seg_dict,
)
meta_data = {}
if (
View full source on GitHub →SanmKWS.export(**kwargs)Export.
# **kwargs — Additional keyword arguments.
def export(self, **kwargs):
    """Build the export model(s) for this model via ``export_rebuild_model``.

    Args:
        **kwargs: Options forwarded to ``export_rebuild_model``. ``max_seq_len``
            is set to 512 when the caller does not supply it.

    Returns:
        Whatever ``export_rebuild_model`` returns for this model.
    """
    from .export_meta import export_rebuild_model

    kwargs.setdefault("max_seq_len", 512)
    return export_rebuild_model(model=self, **kwargs)
Author — Speech Lab of DAMO Academy, Alibaba Group
Paraformer — Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition: https://arxiv.org/abs/2206.08317
class SanmKWSStreaming(SanmKWS):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
def __init__(self, *args, **kwargs):
    """Construct the streaming variant of the SanmKWS model.

    Every positional and keyword argument is handed unchanged to the parent
    class initializer; this initializer sets up nothing else itself.
    """
    super().__init__(*args, **kwargs)
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
View full source on GitHub →.forward(speech, speech_lengths, text, text_lengths, **kwargs) L63Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
decoding_ind = kwargs.get("decoding_ind")
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# Encoder
if hasattr(self.encoder, "overlap_chunk_cls"):
ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
else:
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
View full source →.encode_chunk(speech, speech_lengths, cache, **kwargs) L119Encode chunk.
speech — Speech audio tensor, shape (batch, time).speech_lengths — Length of each speech sample.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def encode_chunk(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
cache: dict = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encode chunk.
Args:
speech: Speech audio tensor, shape (batch, time).
speech_lengths: Length of each speech sample.
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
ind: int
"""
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
View full source →.init_cache(cache, **kwargs) L161Init cache.
cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def init_cache(self, cache: dict = None, **kwargs):
"""Init cache.
Args:
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
chunk_size = kwargs.get("chunk_size", [0, 10, 5])
encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
batch_size = 1
enc_output_size = kwargs["encoder_conf"]["output_size"]
feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
cache_encoder = {
"start_idx": 0,
"cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
"cif_alphas": torch.zeros((batch_size, 1)),
"encoder_out": None,
"encoder_out_lens": None,
"chunk_size": chunk_size,
"encoder_chunk_look_back": encoder_chunk_look_back,
"last_chunk": False,
"opt": None,
"feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
"tail_chunk": False,
}
cache["encoder"] = cache_encoder
View full source →.generate_chunk(speech, speech_lengths, key, tokenizer, frontend, **kwargs) L204Generate chunk.
speech — Speech audio tensor, shape (batch, time).speech_lengths — Length of each speech sample.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def generate_chunk(
self,
speech,
speech_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Generate chunk.
Args:
speech: Speech audio tensor, shape (batch, time).
speech_lengths: Length of each speech sample.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
cache = kwargs.get("cache", {})
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
# Encoder
is_final = kwargs.get("is_final", False)
encoder_out, encoder_out_lens = self.encode_chunk(
speech, speech_lengths, cache=cache, is_final=is_final
)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
View full source →.inference(data_in, data_lengths, key, tokenizer, frontend, cache, **kwargs) L287Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
cache: dict = None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
keywords = kwargs.get("keywords")
from funasr.utils.kws_utils import KwsCtcPrefixDecoder
self.kws_decoder = KwsCtcPrefixDecoder(
ctc=self.ctc,
keywords=keywords,
token_list=tokenizer.token_list,
seg_dict=tokenizer.seg_dict,
View full source →.export(**kwargs) L490Export.
**kwargs — Additional keyword arguments. def export(self, **kwargs):
"""Export this model by delegating to ``export_rebuild_model``.
Args:
**kwargs: Keyword arguments forwarded unchanged to ``export_rebuild_model``.
Returns:
The value returned by ``export_rebuild_model(model=self, **kwargs)``.
"""
# NOTE(review): unlike SanmKWS.export, no default ``max_seq_len`` is injected
# here; confirm ``export_rebuild_model`` handles its absence.
from .export_meta import export_rebuild_model
models = export_rebuild_model(model=self, **kwargs)
return models
SanmKWSStreaming.forward(speech, speech_lengths, text, text_lengths, **kwargs)Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
decoding_ind = kwargs.get("decoding_ind")
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# Encoder
if hasattr(self.encoder, "overlap_chunk_cls"):
ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
else:
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
View full source on GitHub →SanmKWSStreaming.encode_chunk(speech, speech_lengths, cache, **kwargs)Encode chunk.
speech — Speech audio tensor, shape (batch, time).speech_lengths — Length of each speech sample.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def encode_chunk(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
cache: dict = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encode chunk.
Args:
speech: Speech audio tensor, shape (batch, time).
speech_lengths: Length of each speech sample.
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
ind: int
"""
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
View full source on GitHub →SanmKWSStreaming.init_cache(cache, **kwargs)Init cache.
cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def init_cache(self, cache: dict = None, **kwargs):
"""Init cache.
Args:
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
chunk_size = kwargs.get("chunk_size", [0, 10, 5])
encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
batch_size = 1
enc_output_size = kwargs["encoder_conf"]["output_size"]
feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
cache_encoder = {
"start_idx": 0,
"cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
"cif_alphas": torch.zeros((batch_size, 1)),
"encoder_out": None,
"encoder_out_lens": None,
"chunk_size": chunk_size,
"encoder_chunk_look_back": encoder_chunk_look_back,
"last_chunk": False,
"opt": None,
"feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
"tail_chunk": False,
}
cache["encoder"] = cache_encoder
View full source on GitHub →SanmKWSStreaming.generate_chunk(speech, speech_lengths, key, tokenizer, frontend, **kwargs)Generate chunk.
speech — Speech audio tensor, shape (batch, time).speech_lengths — Length of each speech sample.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def generate_chunk(
self,
speech,
speech_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Generate chunk.
Args:
speech: Speech audio tensor, shape (batch, time).
speech_lengths: Length of each speech sample.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
cache = kwargs.get("cache", {})
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
# Encoder
is_final = kwargs.get("is_final", False)
encoder_out, encoder_out_lens = self.encode_chunk(
speech, speech_lengths, cache=cache, is_final=is_final
)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
View full source on GitHub →SanmKWSStreaming.inference(data_in, data_lengths, key, tokenizer, frontend, cache, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
cache: dict = None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
keywords = kwargs.get("keywords")
from funasr.utils.kws_utils import KwsCtcPrefixDecoder
self.kws_decoder = KwsCtcPrefixDecoder(
ctc=self.ctc,
keywords=keywords,
token_list=tokenizer.token_list,
seg_dict=tokenizer.seg_dict,
View full source on GitHub →SanmKWSStreaming.export(**kwargs)Export.
**kwargs — Additional keyword arguments. def export(self, **kwargs):
"""Export this model by delegating to ``export_rebuild_model``.
Args:
**kwargs: Keyword arguments forwarded unchanged to ``export_rebuild_model``.
Returns:
The value returned by ``export_rebuild_model(model=self, **kwargs)``.
"""
# NOTE(review): unlike SanmKWS.export, no default ``max_seq_len`` is injected
# here; confirm ``export_rebuild_model`` handles its absence.
from .export_meta import export_rebuild_model
models = export_rebuild_model(model=self, **kwargs)
return models
SCAMA — Streaming Chunk-Aware Multi-head Attention ASR. Streaming ASR using a chunk-based encoder with configurable latency.
Supports 2-pass decoding for accuracy refinement.
Output — {"key": str, "text": str}
Author — Shiliang Zhang, Zhifu Gao, Haoneng Luo, Ming Lei, Jie Gao, Zhijie Yan, Lei Xie
SCAMA — Streaming chunk-aware multihead attention for online end-to-end speech recognition: https://arxiv.org/abs/2006.01712
class SCAMA(nn.Module):
"""SCAMA: Streaming Chunk-Aware Multi-head Attention ASR.
Streaming ASR using chunk-based encoder with configurable latency.
Supports 2-pass decoding for accuracy refinement.
Output: {"key": str, "text": str}
Author: Shiliang Zhang, Zhifu Gao, Haoneng Luo, Ming Lei, Jie Gao, Zhijie Yan, Lei Xie
SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
https://arxiv.org/abs/2006.01712
"""
def __init__(
self,
specaug: str = None,
specaug_conf: dict = None,
normalize: str = None,
normalize_conf: dict = None,
encoder: str = None,
encoder_conf: dict = None,
decoder: str = None,
decoder_conf: dict = None,
ctc: str = None,
ctc_conf: dict = None,
ctc_weight: float = 0.5,
predictor: str = None,
predictor_conf: dict = None,
predictor_bias: int = 0,
predictor_weight: float = 0.0,
View full source on GitHub →.forward(speech, speech_lengths, text, text_lengths, **kwargs) L200Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
decoding_ind = kwargs.get("decoding_ind")
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# Encoder
ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
loss_ctc, cer_ctc = None, None
loss_pre = None
View full source →.encode(speech, speech_lengths, **kwargs) L277Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encoder. Note that this method is used by asr_inference.py
Applies ``self.specaug`` (training mode only) and ``self.normalize`` when
they are configured, then runs ``self.encoder``.
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
**kwargs: Accepted for call compatibility. ``forward`` passes ``ind``,
but this body does not read any keyword argument.
Returns:
``(encoder_out, encoder_out_lens)``. The encoder's third return value is
discarded, and a tuple ``encoder_out`` is reduced to its first element.
"""
with autocast(False):
# Data augmentation: only when a specaug module is set and in training mode.
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
# If the encoder output is itself a tuple, keep only its first element.
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
return encoder_out, encoder_out_lens
.encode_chunk(speech, speech_lengths, cache, **kwargs) L306Encode chunk.
speech — Speech audio tensor, shape (batch, time).speech_lengths — Length of each speech sample.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def encode_chunk(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
cache: dict = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encode chunk.
Args:
speech: Speech audio tensor, shape (batch, time).
speech_lengths: Length of each speech sample.
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
ind: int
"""
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
View full source →.calc_predictor_chunk(encoder_out, encoder_out_lens, cache, **kwargs) L348Calc predictor chunk.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs):
"""Run the predictor on a streaming chunk via ``self.predictor.forward_chunk``.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths (not read by this body).
cache: State cache dict for streaming inference; its ``"encoder"`` entry
is handed to the predictor.
**kwargs: Only ``is_final`` (default ``False``) is read and forwarded.
Returns:
Whatever ``self.predictor.forward_chunk`` returns.
"""
is_final = kwargs.get("is_final", False)
# NOTE(review): ``cache`` defaults to None yet is indexed unconditionally,
# so omitting it raises TypeError rather than a descriptive error.
return self.predictor.forward_chunk(encoder_out, cache["encoder"], is_final=is_final)
.calc_predictor_mask(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) L462Calc predictor mask.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.ys_pad — TODO.ys_pad_lens — Lengths of ys_pad. def calc_predictor_mask(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor = None,
ys_pad_lens: torch.Tensor = None,
):
# ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
# ys_in_lens = ys_pad_lens + 1
"""Calc predictor mask.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
ys_pad: TODO.
ys_pad_lens: Lengths of ys_pad.
"""
ys_out_pad, ys_in_lens = None, None
encoder_out_mask = sequence_mask(
encoder_out_lens,
maxlen=encoder_out.size(1),
dtype=encoder_out.dtype,
device=encoder_out.device,
)[:, None, :]
mask_chunk_predictor = None
mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
None, device=encoder_out.device, batch_size=encoder_out.size(0)
)
View full source →.init_beam_search(**kwargs) L543Init beam search.
**kwargs — Additional keyword arguments. def init_beam_search(
self,
**kwargs,
):
"""Init beam search.
Args:
**kwargs: Additional keyword arguments.
"""
from funasr.models.scama.beam_search import BeamSearchScamaStreaming
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.models.transformer.scorers.length_bonus import LengthBonus
# 1. Build ASR model
scorers = {}
if self.ctc != None:
ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
scorers.update(ctc=ctc)
token_list = kwargs.get("token_list")
scorers.update(
decoder=self.decoder,
length_bonus=LengthBonus(len(token_list)),
)
# 3. Build ngram model
# ngram is not supported now
ngram = None
View full source →.generate_chunk(speech, speech_lengths, key, tokenizer, frontend, **kwargs) L599Generate chunk.
speech — Speech audio tensor, shape (batch, time).speech_lengths — Length of each speech sample.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def generate_chunk(
self,
speech,
speech_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Generate chunk.
Args:
speech: Speech audio tensor, shape (batch, time).
speech_lengths: Length of each speech sample.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
cache = kwargs.get("cache", {})
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
# Encoder
encoder_out, encoder_out_lens = self.encode_chunk(
speech, speech_lengths, cache=cache, is_final=kwargs.get("is_final", False)
)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
if "running_hyps" not in cache:
View full source →.init_cache(cache, **kwargs) L694Init cache.
cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def init_cache(self, cache: dict = None, **kwargs):
"""Init cache.
Args:
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
device = kwargs.get("device", "cuda")
chunk_size = kwargs.get("chunk_size", [0, 10, 5])
encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
batch_size = 1
enc_output_size = kwargs["encoder_conf"]["output_size"]
feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
cache_encoder = {
"start_idx": 0,
"cif_hidden": torch.zeros((batch_size, 1, enc_output_size)).to(device=device),
"cif_alphas": torch.zeros((batch_size, 1)).to(device=device),
"chunk_size": chunk_size,
"encoder_chunk_look_back": encoder_chunk_look_back,
"last_chunk": False,
"opt": None,
"feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)).to(
device=device
),
View full source →.inference(data_in, data_lengths, key, tokenizer, frontend, cache, **kwargs) L741Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
cache: dict = None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
# init beamsearch
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = (
kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
)
View full source →SCAMA.forward(speech, speech_lengths, text, text_lengths, **kwargs)Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
decoding_ind = kwargs.get("decoding_ind")
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# Encoder
ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
loss_ctc, cer_ctc = None, None
loss_pre = None
View full source on GitHub →SCAMA.encode(speech, speech_lengths, **kwargs)Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encoder. Note that this method is used by asr_inference.py
Applies ``self.specaug`` (training mode only) and ``self.normalize`` when
they are configured, then runs ``self.encoder``.
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
**kwargs: Accepted for call compatibility. ``forward`` passes ``ind``,
but this body does not read any keyword argument.
Returns:
``(encoder_out, encoder_out_lens)``. The encoder's third return value is
discarded, and a tuple ``encoder_out`` is reduced to its first element.
"""
with autocast(False):
# Data augmentation: only when a specaug module is set and in training mode.
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
# If the encoder output is itself a tuple, keep only its first element.
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
return encoder_out, encoder_out_lens
SCAMA.encode_chunk(speech, speech_lengths, cache, **kwargs)Encode chunk.
speech — Speech audio tensor, shape (batch, time).speech_lengths — Length of each speech sample.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def encode_chunk(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
cache: dict = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encode chunk.
Args:
speech: Speech audio tensor, shape (batch, time).
speech_lengths: Length of each speech sample.
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
ind: int
"""
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
View full source on GitHub →SCAMA.calc_predictor_chunk(encoder_out, encoder_out_lens, cache, **kwargs)Calc predictor chunk.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs):
"""Run the predictor on a streaming chunk via ``self.predictor.forward_chunk``.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths (not read by this body).
cache: State cache dict for streaming inference; its ``"encoder"`` entry
is handed to the predictor.
**kwargs: Only ``is_final`` (default ``False``) is read and forwarded.
Returns:
Whatever ``self.predictor.forward_chunk`` returns.
"""
is_final = kwargs.get("is_final", False)
# NOTE(review): ``cache`` defaults to None yet is indexed unconditionally,
# so omitting it raises TypeError rather than a descriptive error.
return self.predictor.forward_chunk(encoder_out, cache["encoder"], is_final=is_final)
SCAMA.calc_predictor_mask(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)Calc predictor mask.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.ys_pad — TODO.ys_pad_lens — Lengths of ys_pad. def calc_predictor_mask(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor = None,
ys_pad_lens: torch.Tensor = None,
):
# ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
# ys_in_lens = ys_pad_lens + 1
"""Calc predictor mask.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
ys_pad: TODO.
ys_pad_lens: Lengths of ys_pad.
"""
ys_out_pad, ys_in_lens = None, None
encoder_out_mask = sequence_mask(
encoder_out_lens,
maxlen=encoder_out.size(1),
dtype=encoder_out.dtype,
device=encoder_out.device,
)[:, None, :]
mask_chunk_predictor = None
mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
None, device=encoder_out.device, batch_size=encoder_out.size(0)
)
View full source on GitHub →SCAMA.init_beam_search(**kwargs)Init beam search.
**kwargs — Additional keyword arguments. def init_beam_search(
self,
**kwargs,
):
"""Init beam search.
Args:
**kwargs: Additional keyword arguments.
"""
from funasr.models.scama.beam_search import BeamSearchScamaStreaming
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.models.transformer.scorers.length_bonus import LengthBonus
# 1. Build ASR model
scorers = {}
if self.ctc != None:
ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
scorers.update(ctc=ctc)
token_list = kwargs.get("token_list")
scorers.update(
decoder=self.decoder,
length_bonus=LengthBonus(len(token_list)),
)
# 3. Build ngram model
# ngram is not supported now
ngram = None
View full source on GitHub →SCAMA.generate_chunk(speech, speech_lengths, key, tokenizer, frontend, **kwargs)Generate chunk.
speech — Speech audio tensor, shape (batch, time).speech_lengths — Length of each speech sample.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def generate_chunk(
self,
speech,
speech_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Generate chunk.
Args:
speech: Speech audio tensor, shape (batch, time).
speech_lengths: Length of each speech sample.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
cache = kwargs.get("cache", {})
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
# Encoder
encoder_out, encoder_out_lens = self.encode_chunk(
speech, speech_lengths, cache=cache, is_final=kwargs.get("is_final", False)
)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
if "running_hyps" not in cache:
View full source on GitHub →SCAMA.init_cache(cache, **kwargs)Init cache.
cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def init_cache(self, cache: dict = None, **kwargs):
"""Init cache.
Args:
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
device = kwargs.get("device", "cuda")
chunk_size = kwargs.get("chunk_size", [0, 10, 5])
encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
batch_size = 1
enc_output_size = kwargs["encoder_conf"]["output_size"]
feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
cache_encoder = {
"start_idx": 0,
"cif_hidden": torch.zeros((batch_size, 1, enc_output_size)).to(device=device),
"cif_alphas": torch.zeros((batch_size, 1)).to(device=device),
"chunk_size": chunk_size,
"encoder_chunk_look_back": encoder_chunk_look_back,
"last_chunk": False,
"opt": None,
"feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)).to(
device=device
),
View full source on GitHub →SCAMA.inference(data_in, data_lengths, key, tokenizer, frontend, cache, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
cache: dict = None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
# init beamsearch
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = (
kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
)
View full source on GitHub →
SeACo-Paraformer — Semantic-Aware Contextual Paraformer. The recommended Chinese ASR model. Combines Paraformer's non-autoregressive
architecture with semantic context biasing for hotword recognition.
Registered as the 'paraformer-zh' alias.
Output — {"key": str, "text": str, "timestamp": [[start_ms, end_ms], ...]}class SeacoParaformer(BiCifParaformer, Paraformer):
"""SeACo-Paraformer: Semantic-Aware Contextual Paraformer.
The recommended Chinese ASR model. Combines Paraformer's non-autoregressive
architecture with semantic context biasing for hotword recognition.
Registered as 'paraformer-zh' alias.
Output: {"key": str, "text": str, "timestamp": [[start_ms, end_ms], ...]}
"""
def __init__(
self,
*args,
**kwargs,
):
"""Initialize SeacoParaformer.
Args:
*args: Variable positional arguments.
**kwargs: Additional keyword arguments.
"""
super().__init__(*args, **kwargs)
self.inner_dim = kwargs.get("inner_dim", 256)
self.bias_encoder_type = kwargs.get("bias_encoder_type", "lstm")
bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)
bias_encoder_bid = kwargs.get("bias_encoder_bid", False)
seaco_lsm_weight = kwargs.get("seaco_lsm_weight", 0.0)
seaco_length_normalized_loss = kwargs.get("seaco_length_normalized_loss", True)
View full source on GitHub →.forward(speech, speech_lengths, text, text_lengths, **kwargs) L122Frontend + Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
# Check that batch_size is unified
assert (
speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
hotword_pad = kwargs.get("hotword_pad")
hotword_lengths = kwargs.get("hotword_lengths")
seaco_label_pad = kwargs.get("seaco_label_pad")
if len(hotword_lengths.size()) > 1:
hotword_lengths = hotword_lengths[:, 0]
View full source →.calc_predictor(encoder_out, encoder_out_lens) L202Calc predictor.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths. def calc_predictor(self, encoder_out, encoder_out_lens):
"""Calc predictor.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
"""
encoder_out_mask = (
~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
).to(encoder_out.device)
predictor_outs = self.predictor(
encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id
)
return predictor_outs[:4]
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L422Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
# init beamsearch
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = (
kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
)
if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
meta_data = {}
View full source →.generate_hotwords_list(hotword_list_or_file, tokenizer, frontend) L583Generate hotwords list.
hotword_list_or_file — TODO.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction. def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
"""Generate hotwords list.
Args:
hotword_list_or_file: TODO.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
"""
def seg_tokenize(txt, seg_dict):
"""Seg tokenize.
Args:
txt: TODO.
seg_dict: TODO.
"""
pattern = re.compile(r"^[\u4E00-\u9FA50-9]+$")
out_txt = ""
for word in txt:
word = word.lower()
if word in seg_dict:
out_txt += seg_dict[word] + " "
else:
if pattern.match(word):
for char in word:
if char in seg_dict:
out_txt += seg_dict[char] + " "
else:
out_txt += "<unk>" + " "
else:
View full source →.export(**kwargs) L692Export.
**kwargs — Additional keyword arguments. def export(
self,
**kwargs,
):
"""Export.
Args:
**kwargs: Additional keyword arguments.
"""
if "max_seq_len" not in kwargs:
kwargs["max_seq_len"] = 512
from .export_meta import export_rebuild_model
models = export_rebuild_model(model=self, **kwargs)
return models
SeacoParaformer.forward(speech, speech_lengths, text, text_lengths, **kwargs)Frontend + Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
# Check that batch_size is unified
assert (
speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
hotword_pad = kwargs.get("hotword_pad")
hotword_lengths = kwargs.get("hotword_lengths")
seaco_label_pad = kwargs.get("seaco_label_pad")
if len(hotword_lengths.size()) > 1:
hotword_lengths = hotword_lengths[:, 0]
View full source on GitHub →SeacoParaformer.calc_predictor(encoder_out, encoder_out_lens)Calc predictor.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths. def calc_predictor(self, encoder_out, encoder_out_lens):
"""Calc predictor.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
"""
encoder_out_mask = (
~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
).to(encoder_out.device)
predictor_outs = self.predictor(
encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id
)
return predictor_outs[:4]
SeacoParaformer.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
# init beamsearch
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = (
kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
)
if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
meta_data = {}
View full source on GitHub →SeacoParaformer.generate_hotwords_list(hotword_list_or_file, tokenizer, frontend)Generate hotwords list.
hotword_list_or_file — TODO.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction. def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
"""Generate hotwords list.
Args:
hotword_list_or_file: TODO.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
"""
def seg_tokenize(txt, seg_dict):
"""Seg tokenize.
Args:
txt: TODO.
seg_dict: TODO.
"""
pattern = re.compile(r"^[\u4E00-\u9FA50-9]+$")
out_txt = ""
for word in txt:
word = word.lower()
if word in seg_dict:
out_txt += seg_dict[word] + " "
else:
if pattern.match(word):
for char in word:
if char in seg_dict:
out_txt += seg_dict[char] + " "
else:
out_txt += "<unk>" + " "
else:
View full source on GitHub →SeacoParaformer.export(**kwargs)Export.
**kwargs — Additional keyword arguments. def export(
self,
**kwargs,
):
"""Export.
Args:
**kwargs: Additional keyword arguments.
"""
if "max_seq_len" not in kwargs:
kwargs["max_seq_len"] = 512
from .export_meta import export_rebuild_model
models = export_rebuild_model(model=self, **kwargs)
return models
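load_seg_dict(seg_dict_file)Load a segmentation dictionary from a text file.
seg_dict_file — Path (str) to a UTF-8 text file in which each line is a key followed by its whitespace-separated segmentation units; the units are stored joined by single spaces.def load_seg_dict(seg_dict_file):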
"""Load seg dict.
Args:
seg_dict_file: TODO.
"""
seg_dict = {}
assert isinstance(seg_dict_file, str)
with open(seg_dict_file, "r", encoding="utf8") as f:
lines = f.readlines()
for line in lines:
s = line.strip().split()
key = s[0]
value = s[1:]
seg_dict[key] = " ".join(value)
return seg_dictsequence_mask(lengths, maxlen, dtype, device)Sequence mask.
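lengths — Tensor of sequence lengths.maxlen — Width of the mask; defaults to lengths.max().dtype — Output dtype of the mask (default torch.float32).device — Target device ("cuda:0", "cpu", etc.); if None, the mask stays on lengths.device.def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):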
"""Sequence mask.
Args:
lengths: TODO.
maxlen: TODO.
dtype: TODO.
device: Target device ("cuda:0", "cpu", etc.).
"""
if maxlen is None:
maxlen = lengths.max()
row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
matrix = torch.unsqueeze(lengths, dim=-1)
mask = row_vector < matrix
mask = mask.detach()
return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
Author — Speech Lab of DAMO Academy, Alibaba Group. SCAMA — Streaming chunk-aware multihead attention for online end-to-end speech recognition: https://arxiv.org/abs/2006.01713
class SenseVoiceEncoderSmall(nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
https://arxiv.org/abs/2006.01713
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
tp_blocks: int = 0,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
stochastic_depth_rate: float = 0.0,
input_layer: Optional[str] = "conv2d",
pos_enc_class=SinusoidalPositionEncoder,
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 1,
padding_idx: int = -1,
kernel_size: int = 11,
sanm_shfit: int = 0,
selfattention_layer_type: str = "sanm",
**kwargs,
View full source on GitHub →.output_size() L619Output size.
def output_size(self) -> int:
"""Output size."""
return self._output_size
.forward(xs_pad, ilens) L623Embed positions in tensor.
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
):
"""Embed positions in tensor."""
maxlen = xs_pad.shape[1]
masks = sequence_mask(ilens, maxlen=maxlen, device=ilens.device)[:, None, :]
xs_pad *= self.output_size() ** 0.5
xs_pad = self.embed(xs_pad)
# forward encoder1
for layer_idx, encoder_layer in enumerate(self.encoders0):
encoder_outs = encoder_layer(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
for layer_idx, encoder_layer in enumerate(self.encoders):
encoder_outs = encoder_layer(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
xs_pad = self.after_norm(xs_pad)
# forward encoder2
olens = masks.squeeze(1).sum(1).int()
for layer_idx, encoder_layer in enumerate(self.tp_encoders):
encoder_outs = encoder_layer(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
View full source →SenseVoiceEncoderSmall.output_size()Output size.
def output_size(self) -> int:
"""Output size."""
return self._output_size
SenseVoiceEncoderSmall.forward(xs_pad, ilens)Embed positions in tensor.
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
):
"""Embed positions in tensor."""
maxlen = xs_pad.shape[1]
masks = sequence_mask(ilens, maxlen=maxlen, device=ilens.device)[:, None, :]
xs_pad *= self.output_size() ** 0.5
xs_pad = self.embed(xs_pad)
# forward encoder1
for layer_idx, encoder_layer in enumerate(self.encoders0):
encoder_outs = encoder_layer(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
for layer_idx, encoder_layer in enumerate(self.encoders):
encoder_outs = encoder_layer(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
xs_pad = self.after_norm(xs_pad)
# forward encoder2
olens = masks.squeeze(1).sum(1).int()
for layer_idx, encoder_layer in enumerate(self.tp_encoders):
encoder_outs = encoder_layer(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
View full source on GitHub →CTC-attention hybrid Encoder-Decoder modelclass SenseVoiceSmall(nn.Module):
"""CTC-attention hybrid Encoder-Decoder model"""
def __init__(
self,
specaug: str = None,
specaug_conf: dict = None,
normalize: str = None,
normalize_conf: dict = None,
encoder: str = None,
encoder_conf: dict = None,
ctc_conf: dict = None,
input_size: int = 80,
vocab_size: int = -1,
ignore_id: int = -1,
blank_id: int = 0,
sos: int = 1,
eos: int = 2,
length_normalized_loss: bool = False,
**kwargs,
):
"""Initialize SenseVoiceSmall.
Args:
specaug: TODO.
specaug_conf: Configuration dict for specaug.
normalize: TODO.
normalize_conf: Configuration dict for normalize.
encoder: TODO.
View full source on GitHub →.from_pretrained(model, **kwargs) L754From pretrained.
model — Model instance or model name.**kwargs — Additional keyword arguments. def from_pretrained(model: str = None, **kwargs):
"""From pretrained.
Args:
model: Model instance or model name.
**kwargs: Additional keyword arguments.
"""
from funasr import AutoModel
model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs)
return model, kwargs
.forward(speech, speech_lengths, text, text_lengths, **kwargs) L767Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
):
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
# import pdb;
# pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text)
loss_ctc, cer_ctc = None, None
loss_rich, acc_rich = None, None
stats = dict()
View full source →.encode(speech, speech_lengths, text, **kwargs) L817Frontend + Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
**kwargs,
):
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
ind: int
"""
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
lids = torch.LongTensor(
[
[
(
self.lid_int_dict[int(lid)]
if torch.rand(1) > 0.2 and int(lid) in self.lid_int_dict
else 0
)
View full source →.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L918Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = ["wav_file_tmp_name"],
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
meta_data = {}
if (
isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
): # fbank
speech, speech_lengths = data_in, data_lengths
if len(speech.shape) < 3:
speech = speech[None, :, :]
if speech_lengths is None:
speech_lengths = speech.shape[1]
else:
View full source →.post(timestamp) L1082Post.
timestamp — TODO. def post(self, timestamp):
"""Post.
Args:
timestamp: TODO.
"""
timestamp_new = []
words_new = []
prev_word = None
for i, t in enumerate(timestamp):
word, start, end = t
start = int(start * 1000)
end = int(end * 1000)
if word == "▁":
continue
if i == 0:
# timestamp_new.append([word, start, end])
timestamp_new.append([start, end])
words_new.append(word)
elif word.startswith("▁"):
word = word[1:]
timestamp_new.append([start, end])
words_new.append(word)
elif prev_word is not None and prev_word.isalpha() and prev_word.isascii() and word.isalpha() and word.isascii():
word = prev_word + word
timestamp_new[-1][1] = end
words_new[-1] = word
else:
# timestamp_new[-1][0] += word
timestamp_new.append([start, end])
View full source →.export(**kwargs) L1116Export.
**kwargs — Additional keyword arguments. def export(self, **kwargs):
"""Export.
Args:
**kwargs: Additional keyword arguments.
"""
from .export_meta import export_rebuild_model
if "max_seq_len" not in kwargs:
kwargs["max_seq_len"] = 512
models = export_rebuild_model(model=self, **kwargs)
return models
return results, meta_data
SenseVoiceSmall.from_pretrained(model, **kwargs)From pretrained.
model — Model instance or model name.**kwargs — Additional keyword arguments. def from_pretrained(model: str = None, **kwargs):
"""From pretrained.
Args:
model: Model instance or model name.
**kwargs: Additional keyword arguments.
"""
from funasr import AutoModel
model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs)
return model, kwargs
SenseVoiceSmall.forward(speech, speech_lengths, text, text_lengths, **kwargs)Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
):
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
# import pdb;
# pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text)
loss_ctc, cer_ctc = None, None
loss_rich, acc_rich = None, None
stats = dict()
View full source on GitHub →SenseVoiceSmall.encode(speech, speech_lengths, text, **kwargs)Frontend + Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
**kwargs,
):
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
ind: int
"""
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
lids = torch.LongTensor(
[
[
(
self.lid_int_dict[int(lid)]
if torch.rand(1) > 0.2 and int(lid) in self.lid_int_dict
else 0
)
View full source on GitHub →SenseVoiceSmall.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = ["wav_file_tmp_name"],
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
meta_data = {}
if (
isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
): # fbank
speech, speech_lengths = data_in, data_lengths
if len(speech.shape) < 3:
speech = speech[None, :, :]
if speech_lengths is None:
speech_lengths = speech.shape[1]
else:
View full source on GitHub →SenseVoiceSmall.post(timestamp)Post.
timestamp — TODO. def post(self, timestamp):
"""Post.
Args:
timestamp: TODO.
"""
timestamp_new = []
words_new = []
prev_word = None
for i, t in enumerate(timestamp):
word, start, end = t
start = int(start * 1000)
end = int(end * 1000)
if word == "▁":
continue
if i == 0:
# timestamp_new.append([word, start, end])
timestamp_new.append([start, end])
words_new.append(word)
elif word.startswith("▁"):
word = word[1:]
timestamp_new.append([start, end])
words_new.append(word)
elif prev_word is not None and prev_word.isalpha() and prev_word.isascii() and word.isalpha() and word.isascii():
word = prev_word + word
timestamp_new[-1][1] = end
words_new[-1] = word
else:
# timestamp_new[-1][0] += word
timestamp_new.append([start, end])
View full source on GitHub →SenseVoiceSmall.export(**kwargs)Export.
**kwargs — Additional keyword arguments. def export(self, **kwargs):
"""Export.
Args:
**kwargs: Additional keyword arguments.
"""
from .export_meta import export_rebuild_model
if "max_seq_len" not in kwargs:
kwargs["max_seq_len"] = 512
models = export_rebuild_model(model=self, **kwargs)
return models
return results, meta_data
Transducer (RNN-T) — Streaming ASR using encoder-predictor-joint architecture. Combines encoder (audio frames), predictor (text history), and joint network.
Supports beam search decoding. Suitable for low-latency streaming applications.
Output — {"key": str, "text": str}class Transducer(torch.nn.Module):
"""Transducer (RNN-T): Streaming ASR using encoder-predictor-joint architecture.
Combines encoder (audio frames), predictor (text history), and joint network.
Supports beam search decoding. Suitable for low-latency streaming applications.
Output: {"key": str, "text": str}
"""
def __init__(
self,
frontend: Optional[str] = None,
frontend_conf: Optional[Dict] = None,
specaug: Optional[str] = None,
specaug_conf: Optional[Dict] = None,
normalize: str = None,
normalize_conf: Optional[Dict] = None,
encoder: str = None,
encoder_conf: Optional[Dict] = None,
decoder: str = None,
decoder_conf: Optional[Dict] = None,
joint_network: str = None,
joint_network_conf: Optional[Dict] = None,
transducer_weight: float = 1.0,
fastemit_lambda: float = 0.0,
auxiliary_ctc_weight: float = 0.0,
auxiliary_ctc_dropout_rate: float = 0.0,
auxiliary_lm_loss_weight: float = 0.0,
auxiliary_lm_loss_smoothing: float = 0.0,
input_size: int = 80,
View full source on GitHub →.forward(speech, speech_lengths, text, text_lengths, **kwargs) L190Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
if (
hasattr(self.encoder, "overlap_chunk_cls")
and self.encoder.overlap_chunk_cls is not None
):
encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
encoder_out, encoder_out_lens, chunk_outs=None
)
View full source →.encode(speech, speech_lengths, **kwargs) L276Frontend + Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
ind: int
"""
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
View full source →.init_beam_search(**kwargs) L455Init beam search.
**kwargs — Additional keyword arguments. def init_beam_search(
self,
**kwargs,
):
# 1. Build ASR model
"""Init beam search.
Args:
**kwargs: Additional keyword arguments.
"""
scorers = {}
if self.ctc != None:
ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
scorers.update(ctc=ctc)
token_list = kwargs.get("token_list")
scorers.update(
length_bonus=LengthBonus(len(token_list)),
)
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
beam_search = BeamSearchTransducer(
self.decoder,
self.joint_network,
kwargs.get("beam_size", 2),
View full source →.inference(data_in, data_lengths, key, tokenizer, **kwargs) L493Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.**kwargs — Additional keyword arguments. def inference(
self,
data_in: list,
data_lengths: list = None,
key: list = None,
tokenizer=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
**kwargs: Additional keyword arguments.
"""
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
# init beamsearch
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = (
kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
)
# if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
View full source →Transducer.forward(speech, speech_lengths, text, text_lengths, **kwargs)Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
if (
hasattr(self.encoder, "overlap_chunk_cls")
and self.encoder.overlap_chunk_cls is not None
):
encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
encoder_out, encoder_out_lens, chunk_outs=None
)
View full source on GitHub →Transducer.encode(speech, speech_lengths, **kwargs)Frontend + Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
ind: int
"""
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
View full source on GitHub →Transducer.init_beam_search(**kwargs)Init beam search.
**kwargs — Additional keyword arguments. def init_beam_search(
self,
**kwargs,
):
# 1. Build ASR model
"""Init beam search.
Args:
**kwargs: Additional keyword arguments.
"""
scorers = {}
if self.ctc != None:
ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
scorers.update(ctc=ctc)
token_list = kwargs.get("token_list")
scorers.update(
length_bonus=LengthBonus(len(token_list)),
)
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
beam_search = BeamSearchTransducer(
self.decoder,
self.joint_network,
kwargs.get("beam_size", 2),
View full source on GitHub →Transducer.inference(data_in, data_lengths, key, tokenizer, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.**kwargs — Additional keyword arguments. def inference(
self,
data_in: list,
data_lengths: list = None,
key: list = None,
tokenizer=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
**kwargs: Additional keyword arguments.
"""
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
# init beamsearch
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = (
kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
)
# if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
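View full source on GitHub →Transformer — Base encoder-decoder ASR model. Standard CTC-attention hybrid architecture with: an encoder (self-attention + position encoding), a CTC branch for auxiliary loss, an attention decoder for sequence generation, and beam search with LM fusion.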
Base class for Conformer, Branchformer, etc.
Output — {"key": str, "text": str}class Transformer(nn.Module):
"""Transformer: Base encoder-decoder ASR model.
Standard CTC-attention hybrid architecture with:
- Encoder (self-attention + position encoding)
- CTC branch for auxiliary loss
- Attention decoder for sequence generation
- Beam search with LM fusion
Base class for Conformer, Branchformer, etc.
Output: {"key": str, "text": str}
"""
def __init__(
self,
specaug: str = None,
specaug_conf: dict = None,
normalize: str = None,
normalize_conf: dict = None,
encoder: str = None,
encoder_conf: dict = None,
decoder: str = None,
decoder_conf: dict = None,
ctc: str = None,
ctc_conf: dict = None,
ctc_weight: float = 0.5,
interctc_weight: float = 0.0,
input_size: int = 80,
vocab_size: int = -1,
ignore_id: int = -1,
View full source on GitHub →
.forward(speech, speech_lengths, text, text_lengths, **kwargs) (L173) — Encoder + decoder + loss calculation.
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
loss_att, acc_att, cer_att, wer_att = None, None, None, None
View full source →
.encode(speech, speech_lengths, **kwargs) (L266) — Frontend + encoder. Note that this method is used by asr_inference.py.
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
ind: int
"""
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.encoder.interctc_use_conditioning:
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ctc=self.ctc)
else:
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
intermediate_outs = None
View full source →
.init_beam_search(**kwargs) (L368) — Initialize beam search.
**kwargs — Additional keyword arguments. def init_beam_search(
self,
**kwargs,
):
"""Init beam search.
Args:
**kwargs: Additional keyword arguments.
"""
from funasr.models.transformer.search import BeamSearch
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.models.transformer.scorers.length_bonus import LengthBonus
# 1. Build ASR model
scorers = {}
if self.ctc != None:
ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
scorers.update(ctc=ctc)
token_list = kwargs.get("token_list")
scorers.update(
decoder=self.decoder,
length_bonus=LengthBonus(len(token_list)),
)
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) (L418) — Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
# init beamsearch
if self.beam_search is None:
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
meta_data = {}
View full source →
Transformer.forward(speech, speech_lengths, text, text_lengths, **kwargs) — Encoder + decoder + loss calculation.
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
loss_att, acc_att, cer_att, wer_att = None, None, None, None
View full source on GitHub →
Transformer.encode(speech, speech_lengths, **kwargs) — Frontend + encoder. Note that this method is used by asr_inference.py.
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
ind: int
"""
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.encoder.interctc_use_conditioning:
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ctc=self.ctc)
else:
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
intermediate_outs = None
View full source on GitHub →
Transformer.init_beam_search(**kwargs) — Initialize beam search.
**kwargs — Additional keyword arguments. def init_beam_search(
self,
**kwargs,
):
"""Init beam search.
Args:
**kwargs: Additional keyword arguments.
"""
from funasr.models.transformer.search import BeamSearch
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.models.transformer.scorers.length_bonus import LengthBonus
# 1. Build ASR model
scorers = {}
if self.ctc != None:
ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
scorers.update(ctc=ctc)
token_list = kwargs.get("token_list")
scorers.update(
decoder=self.decoder,
length_bonus=LengthBonus(len(token_list)),
)
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
View full source on GitHub →
Transformer.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) — Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
# init beamsearch
if self.beam_search is None:
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
meta_data = {}
View full source on GitHub →
UniASR — Unified streaming and non-streaming ASR model. A single model that supports both streaming (online) and non-streaming (offline) decoding through dynamic masking in the encoder. Builds on the Paraformer pipeline.
Output — {"key": str, "text": str}class UniASR(torch.nn.Module):
"""UniASR: Unified Streaming and Non-streaming ASR model.
Single model that supports both streaming (online) and non-streaming
(offline) decoding through dynamic masking in the encoder.
Inherits Paraformer pipeline.
Output: {"key": str, "text": str}
"""
def __init__(
self,
specaug: str = None,
specaug_conf: dict = None,
normalize: str = None,
normalize_conf: dict = None,
encoder: str = None,
encoder_conf: dict = None,
encoder2: str = None,
encoder2_conf: dict = None,
decoder: str = None,
decoder_conf: dict = None,
decoder2: str = None,
decoder2_conf: dict = None,
predictor: str = None,
predictor_conf: dict = None,
predictor_bias: int = 0,
predictor_weight: float = 0.0,
predictor2: str = None,
predictor2_conf: dict = None,
View full source on GitHub →
.forward(speech, speech_lengths, text, text_lengths, **kwargs) (L235) — Frontend + encoder + decoder + loss calculation.
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
decoding_ind = kwargs.get("decoding_ind", None)
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
# 1. Encoder
if self.enable_maas_finetune:
with torch.no_grad():
speech_raw, encoder_out, encoder_out_lens = self.encode(
speech, speech_lengths, ind=ind
)
View full source →
.collect_feats(speech, speech_lengths, text, text_lengths) (L354) — Collect features.
speech — Speech audio tensor, shape (batch, time).speech_lengths — Length of each speech sample.text — Text tensor or string input.text_lengths — Length of each text sample. def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Dict[str, torch.Tensor]:
"""Return the features to be used for statistics collection.
Args:
speech: Input speech batch. NOTE(review): documented here as (batch, time)
but as (Batch, Length, ...) in encode() — confirm the expected shape.
speech_lengths: Length of each speech sample.
text: Not used by this method.
text_lengths: Not used by this method.
Returns:
Dict with keys "feats" and "feats_lengths". If
self.extract_feats_in_collect_stats is truthy, the values come from
self._extract_feats(speech, speech_lengths); otherwise a warning is
logged and speech / speech_lengths are returned unchanged as dummy stats.
"""
if self.extract_feats_in_collect_stats:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
else:
# Generate dummy stats if extract_feats_in_collect_stats is False:
# the raw inputs are passed through without any feature extraction.
logging.warning(
"Generating dummy stats for feats and feats_lengths, "
"because encoder_conf.extract_feats_in_collect_stats is "
f"{self.extract_feats_in_collect_stats}"
)
feats, feats_lengths = speech, speech_lengths
return {"feats": feats, "feats_lengths": feats_lengths}
.encode(speech, speech_lengths, **kwargs) (L381) — Frontend + encoder. Note that this method is used by asr_inference.py.
speech — (Batch, Length, ...)speech_lengths — (Batch, ) def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
):
"""Augment/normalize the input features and run self.encoder.
No frontend is called inside this method: it applies optional SpecAugment
(only when self.training), optional normalization, then self.encoder
(the second encoder is handled separately by encode2).
Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
**kwargs: "ind" (default 0) is forwarded to self.encoder as ind=...;
presumably it selects the overlap-chunk configuration chosen by
overlap_chunk_cls.random_choice in forward() — TODO confirm.
Returns:
(speech_raw, encoder_out, encoder_out_lens). speech_raw is a clone of the
features after augmentation/normalization, taken before the encoder runs.
If the encoder's first output is a tuple, only its first element is kept
as encoder_out.
"""
ind = kwargs.get("ind", 0)
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Keep a copy of the encoder input. NOTE(review): presumably this is what
# forward() later passes to encode2 as its residual `speech` input — confirm.
speech_raw = speech.clone().to(speech.device)
# 4. Forward encoder
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ind=ind)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
return speech_raw, encoder_out, encoder_out_lens
.encode2(encoder_out, encoder_out_lens, speech, speech_lengths, **kwargs) (L411) — Frontend + encoder. Note that this method is used by asr_inference.py.
speech — (Batch, Length, ...)speech_lengths — (Batch, ) def encode2(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
):
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
ind = kwargs.get("ind", 0)
encoder_out_rm, encoder_out_lens_rm = self.encoder.overlap_chunk_cls.remove_chunk(
encoder_out,
encoder_out_lens,
chunk_outs=None,
)
# residual_input
encoder_out = torch.cat((speech, encoder_out_rm), dim=-1)
encoder_out_lens = encoder_out_lens_rm
if self.stride_conv is not None:
speech, speech_lengths = self.stride_conv(encoder_out, encoder_out_lens)
if not self.encoder1_encoder2_joint_training:
speech = speech.detach()
speech_lengths = speech_lengths.detach()
# 4. Forward encoder
# feats: (Batch, Length, Dim)
View full source →
.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) (L449) — Compute the negative log-likelihood (NLL) from the transformer decoder. Normally, this function is called by batchify_nll.
encoder_out — (Batch, Length, Dim)encoder_out_lens — (Batch,)ys_pad — (Batch, Length)ys_pad_lens — (Batch,) def nll(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
"""Compute negative log likelihood(nll) from transformer-decoder
Normally, this function is called in batchify_nll.
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
ys_pad: (Batch, Length)
ys_pad_lens: (Batch,)
"""
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
decoder_out, _ = self.decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
) # [batch, seqlen, dim]
batch_size = decoder_out.size(0)
decoder_num_class = decoder_out.size(2)
# nll: negative log-likelihood
nll = torch.nn.functional.cross_entropy(
decoder_out.view(-1, decoder_num_class),
ys_out_pad.view(-1),
ignore_index=self.ignore_id,
reduction="none",
View full source →
.batchify_nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, batch_size) (L485) — Compute the negative log-likelihood (NLL) from the transformer decoder. To avoid running out of memory (OOM), this function splits the input into batches, calls nll on each batch, and combines the results.
encoder_out — (Batch, Length, Dim)encoder_out_lens — (Batch,)ys_pad — (Batch, Length)ys_pad_lens — (Batch,)batch_size — int, samples each batch contain when computing nll, def batchify_nll(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
batch_size: int = 100,
):
"""Compute negative log likelihood(nll) from transformer-decoder
To avoid OOM, this fuction seperate the input into batches.
Then call nll for each batch and combine and return results.
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
ys_pad: (Batch, Length)
ys_pad_lens: (Batch,)
batch_size: int, samples each batch contain when computing nll,
you may change this to avoid OOM or increase
GPU memory usage
"""
total_num = encoder_out.size(0)
if total_num <= batch_size:
nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
else:
nll = []
start_idx = 0
while True:
end_idx = min(start_idx + batch_size, total_num)
batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
View full source →
.calc_predictor_mask(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) (L787) — Calculate the predictor mask.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.ys_pad — TODO.ys_pad_lens — Lengths of ys_pad. def calc_predictor_mask(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor = None,
ys_pad_lens: torch.Tensor = None,
):
# ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
# ys_in_lens = ys_pad_lens + 1
"""Calc predictor mask.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
ys_pad: TODO.
ys_pad_lens: Lengths of ys_pad.
"""
ys_out_pad, ys_in_lens = None, None
encoder_out_mask = sequence_mask(
encoder_out_lens,
maxlen=encoder_out.size(1),
dtype=encoder_out.dtype,
device=encoder_out.device,
)[:, None, :]
mask_chunk_predictor = None
if self.encoder.overlap_chunk_cls is not None:
mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
None, device=encoder_out.device, batch_size=encoder_out.size(0)
)
View full source →
.calc_predictor_mask2(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) (L877) — Calculate the predictor mask for the second encoder.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.ys_pad — TODO.ys_pad_lens — Lengths of ys_pad. def calc_predictor_mask2(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor = None,
ys_pad_lens: torch.Tensor = None,
):
# ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
# ys_in_lens = ys_pad_lens + 1
"""Calc predictor mask2.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
ys_pad: TODO.
ys_pad_lens: Lengths of ys_pad.
"""
ys_out_pad, ys_in_lens = None, None
encoder_out_mask = sequence_mask(
encoder_out_lens,
maxlen=encoder_out.size(1),
dtype=encoder_out.dtype,
device=encoder_out.device,
)[:, None, :]
mask_chunk_predictor = None
if self.encoder2.overlap_chunk_cls is not None:
mask_chunk_predictor = self.encoder2.overlap_chunk_cls.get_mask_chunk_predictor(
None, device=encoder_out.device, batch_size=encoder_out.size(0)
)
View full source →
.init_beam_search(**kwargs) (L967) — Initialize beam search.
**kwargs — Additional keyword arguments. def init_beam_search(
self,
**kwargs,
):
"""Init beam search.
Args:
**kwargs: Additional keyword arguments.
"""
from funasr.models.uniasr.beam_search import BeamSearchScama
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.models.transformer.scorers.length_bonus import LengthBonus
decoding_mode = kwargs.get("decoding_mode", "model1")
if decoding_mode == "model1":
decoder = self.decoder
else:
decoder = self.decoder2
# 1. Build ASR model
scorers = {}
if self.ctc != None:
ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
scorers.update(ctc=ctc)
token_list = kwargs.get("token_list")
scorers.update(
decoder=decoder,
length_bonus=LengthBonus(len(token_list)),
)
View full source →
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) (L1022) — Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
decoding_model = kwargs.get("decoding_model", "normal")
token_num_relax = kwargs.get("token_num_relax", 5)
if decoding_model == "fast":
decoding_ind = 0
decoding_mode = "model1"
elif decoding_model == "offline":
decoding_ind = 1
decoding_mode = "model2"
else:
decoding_ind = 0
View full source →
UniASR.forward(speech, speech_lengths, text, text_lengths, **kwargs) — Frontend + encoder + decoder + loss calculation.
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
decoding_ind = kwargs.get("decoding_ind", None)
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
# 1. Encoder
if self.enable_maas_finetune:
with torch.no_grad():
speech_raw, encoder_out, encoder_out_lens = self.encode(
speech, speech_lengths, ind=ind
)
View full source on GitHub →
UniASR.collect_feats(speech, speech_lengths, text, text_lengths) — Collect features.
speech — Speech audio tensor, shape (batch, time).speech_lengths — Length of each speech sample.text — Text tensor or string input.text_lengths — Length of each text sample. def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Dict[str, torch.Tensor]:
"""Return the features to be used for statistics collection.
Args:
speech: Input speech batch. NOTE(review): documented here as (batch, time)
but as (Batch, Length, ...) in encode() — confirm the expected shape.
speech_lengths: Length of each speech sample.
text: Not used by this method.
text_lengths: Not used by this method.
Returns:
Dict with keys "feats" and "feats_lengths". If
self.extract_feats_in_collect_stats is truthy, the values come from
self._extract_feats(speech, speech_lengths); otherwise a warning is
logged and speech / speech_lengths are returned unchanged as dummy stats.
"""
if self.extract_feats_in_collect_stats:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
else:
# Generate dummy stats if extract_feats_in_collect_stats is False:
# the raw inputs are passed through without any feature extraction.
logging.warning(
"Generating dummy stats for feats and feats_lengths, "
"because encoder_conf.extract_feats_in_collect_stats is "
f"{self.extract_feats_in_collect_stats}"
)
feats, feats_lengths = speech, speech_lengths
return {"feats": feats, "feats_lengths": feats_lengths}
UniASR.encode(speech, speech_lengths, **kwargs) — Frontend + encoder. Note that this method is used by asr_inference.py.
speech — (Batch, Length, ...)speech_lengths — (Batch, ) def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
):
"""Augment/normalize the input features and run self.encoder.
No frontend is called inside this method: it applies optional SpecAugment
(only when self.training), optional normalization, then self.encoder
(the second encoder is handled separately by encode2).
Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
**kwargs: "ind" (default 0) is forwarded to self.encoder as ind=...;
presumably it selects the overlap-chunk configuration chosen by
overlap_chunk_cls.random_choice in forward() — TODO confirm.
Returns:
(speech_raw, encoder_out, encoder_out_lens). speech_raw is a clone of the
features after augmentation/normalization, taken before the encoder runs.
If the encoder's first output is a tuple, only its first element is kept
as encoder_out.
"""
ind = kwargs.get("ind", 0)
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Keep a copy of the encoder input. NOTE(review): presumably this is what
# forward() later passes to encode2 as its residual `speech` input — confirm.
speech_raw = speech.clone().to(speech.device)
# 4. Forward encoder
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ind=ind)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
return speech_raw, encoder_out, encoder_out_lens
UniASR.encode2(encoder_out, encoder_out_lens, speech, speech_lengths, **kwargs) — Frontend + encoder. Note that this method is used by asr_inference.py.
speech — (Batch, Length, ...)speech_lengths — (Batch, ) def encode2(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
):
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
ind = kwargs.get("ind", 0)
encoder_out_rm, encoder_out_lens_rm = self.encoder.overlap_chunk_cls.remove_chunk(
encoder_out,
encoder_out_lens,
chunk_outs=None,
)
# residual_input
encoder_out = torch.cat((speech, encoder_out_rm), dim=-1)
encoder_out_lens = encoder_out_lens_rm
if self.stride_conv is not None:
speech, speech_lengths = self.stride_conv(encoder_out, encoder_out_lens)
if not self.encoder1_encoder2_joint_training:
speech = speech.detach()
speech_lengths = speech_lengths.detach()
# 4. Forward encoder
# feats: (Batch, Length, Dim)
View full source on GitHub →
UniASR.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) — Compute the negative log-likelihood (NLL) from the transformer decoder. Normally, this function is called by batchify_nll.
encoder_out — (Batch, Length, Dim)encoder_out_lens — (Batch,)ys_pad — (Batch, Length)ys_pad_lens — (Batch,) def nll(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
"""Compute negative log likelihood(nll) from transformer-decoder
Normally, this function is called in batchify_nll.
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
ys_pad: (Batch, Length)
ys_pad_lens: (Batch,)
"""
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
decoder_out, _ = self.decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
) # [batch, seqlen, dim]
batch_size = decoder_out.size(0)
decoder_num_class = decoder_out.size(2)
# nll: negative log-likelihood
nll = torch.nn.functional.cross_entropy(
decoder_out.view(-1, decoder_num_class),
ys_out_pad.view(-1),
ignore_index=self.ignore_id,
reduction="none",
View full source on GitHub →
UniASR.batchify_nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, batch_size) — Compute the negative log-likelihood (NLL) from the transformer decoder. To avoid running out of memory (OOM), this function splits the input into batches, calls nll on each batch, and combines the results.
encoder_out — (Batch, Length, Dim)encoder_out_lens — (Batch,)ys_pad — (Batch, Length)ys_pad_lens — (Batch,)batch_size — int, samples each batch contain when computing nll, def batchify_nll(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
batch_size: int = 100,
):
"""Compute negative log likelihood(nll) from transformer-decoder
To avoid OOM, this fuction seperate the input into batches.
Then call nll for each batch and combine and return results.
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
ys_pad: (Batch, Length)
ys_pad_lens: (Batch,)
batch_size: int, samples each batch contain when computing nll,
you may change this to avoid OOM or increase
GPU memory usage
"""
total_num = encoder_out.size(0)
if total_num <= batch_size:
nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
else:
nll = []
start_idx = 0
while True:
end_idx = min(start_idx + batch_size, total_num)
batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
View full source on GitHub →
UniASR.calc_predictor_mask(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) — Calculate the predictor mask.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.ys_pad — TODO.ys_pad_lens — Lengths of ys_pad. def calc_predictor_mask(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor = None,
ys_pad_lens: torch.Tensor = None,
):
# ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
# ys_in_lens = ys_pad_lens + 1
"""Calc predictor mask.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
ys_pad: TODO.
ys_pad_lens: Lengths of ys_pad.
"""
ys_out_pad, ys_in_lens = None, None
encoder_out_mask = sequence_mask(
encoder_out_lens,
maxlen=encoder_out.size(1),
dtype=encoder_out.dtype,
device=encoder_out.device,
)[:, None, :]
mask_chunk_predictor = None
if self.encoder.overlap_chunk_cls is not None:
mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
None, device=encoder_out.device, batch_size=encoder_out.size(0)
)
View full source on GitHub →
UniASR.calc_predictor_mask2(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) — Calculate the predictor mask for the second encoder.
encoder_out — Encoder output tensor.encoder_out_lens — Encoder output lengths.ys_pad — TODO.ys_pad_lens — Lengths of ys_pad. def calc_predictor_mask2(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor = None,
ys_pad_lens: torch.Tensor = None,
):
# ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
# ys_in_lens = ys_pad_lens + 1
"""Calc predictor mask2.
Args:
encoder_out: Encoder output tensor.
encoder_out_lens: Encoder output lengths.
ys_pad: TODO.
ys_pad_lens: Lengths of ys_pad.
"""
ys_out_pad, ys_in_lens = None, None
encoder_out_mask = sequence_mask(
encoder_out_lens,
maxlen=encoder_out.size(1),
dtype=encoder_out.dtype,
device=encoder_out.device,
)[:, None, :]
mask_chunk_predictor = None
if self.encoder2.overlap_chunk_cls is not None:
mask_chunk_predictor = self.encoder2.overlap_chunk_cls.get_mask_chunk_predictor(
None, device=encoder_out.device, batch_size=encoder_out.size(0)
)
View full source on GitHub →
UniASR.init_beam_search(**kwargs) — Initialize beam search.
**kwargs — Additional keyword arguments. def init_beam_search(
self,
**kwargs,
):
"""Init beam search.
Args:
**kwargs: Additional keyword arguments.
"""
from funasr.models.uniasr.beam_search import BeamSearchScama
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.models.transformer.scorers.length_bonus import LengthBonus
decoding_mode = kwargs.get("decoding_mode", "model1")
if decoding_mode == "model1":
decoder = self.decoder
else:
decoder = self.decoder2
# 1. Build ASR model
scorers = {}
if self.ctc != None:
ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
scorers.update(ctc=ctc)
token_list = kwargs.get("token_list")
scorers.update(
decoder=decoder,
length_bonus=LengthBonus(len(token_list)),
)
View full source on GitHub →
UniASR.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) — Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
decoding_model = kwargs.get("decoding_model", "normal")
token_num_relax = kwargs.get("token_num_relax", 5)
if decoding_model == "fast":
decoding_ind = 0
decoding_mode = "model1"
elif decoding_model == "offline":
decoding_ind = 1
decoding_mode = "model2"
else:
decoding_ind = 0
View full source on GitHub →
Whisper — OpenAI Whisper model integration. Wraps Whisper for multilingual speech recognition and translation within FunASR's AutoModel interface.
Supports — whisper-tiny through whisper-large-v3-turbo.Output — {"key": str, "text": str}class WhisperWarp(nn.Module):
"""Whisper: OpenAI Whisper model integration.
Wraps Whisper for multilingual speech recognition and translation
within FunASR's AutoModel interface.
Supports: whisper-tiny through whisper-large-v3-turbo.
Output: {"key": str, "text": str}
"""
def __init__(self, *args, **kwargs):
"""Initialize WhisperWarp.
Args:
*args: Variable positional arguments.
**kwargs: Additional keyword arguments.
"""
super().__init__()
hub = kwargs.get("hub", "funasr")
if hub == "openai":
model_or_path = kwargs.get("model_path", "Whisper-large-v3")
if model_or_path.startswith("Whisper-"):
model_or_path = model_or_path.replace("Whisper-", "")
model = whisper.load_model(model_or_path)
else:
dims = kwargs.get("dims", {})
dims = whisper.model.ModelDimensions(**dims)
model = whisper.model.Whisper(dims=dims)
self.model = model
View full source on GitHub →
.forward() (L66) — Training forward pass (not implemented; the body is a no-op).
def forward(self):
    """Training forward pass (no-op).

    No computation is performed here; calling it simply returns ``None``.
    """
    return None
.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L72Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
if frontend is None and not hasattr(self, "frontend"):
frontend_class = tables.frontend_classes.get("WhisperFrontend")
frontend = frontend_class(
n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
)
self.frontend = frontend
else:
frontend = frontend if frontend is not None else self.frontend
View full source →WhisperWarp.forward()Forward pass for training.
def forward(self):
    """Training forward pass (no-op).

    No computation is performed here; calling it simply returns ``None``.
    """
    return None
WhisperWarp.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
if frontend is None and not hasattr(self, "frontend"):
frontend_class = tables.frontend_classes.get("WhisperFrontend")
frontend = frontend_class(
n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
)
self.frontend = frontend
else:
frontend = frontend if frontend is not None else self.frontend
View full source on GitHub →Whisper-LID: Language Identification using OpenAI Whisper encoder. Detects spoken language from audio input. Supports 99 languages.
Uses Whisper encoder features with a classification head.
Output — {"key": str, "text": str (language code)}class OpenAIWhisperModel(nn.Module):
"""Whisper-LID: Language Identification using OpenAI Whisper encoder.
Detects spoken language from audio input. Supports 99 languages.
Uses Whisper encoder features with a classification head.
Output: {"key": str, "text": str (language code)}
"""
def __init__(
self,
specaug: str = None,
specaug_conf: dict = None,
normalize: str = None,
normalize_conf: dict = None,
encoder: str = None,
encoder_conf: dict = None,
decoder: str = None,
decoder_conf: dict = None,
ctc: str = None,
ctc_conf: dict = None,
ctc_weight: float = 0.5,
interctc_weight: float = 0.0,
input_size: int = 80,
vocab_size: int = -1,
ignore_id: int = -1,
blank_id: int = 0,
sos: int = 1,
eos: int = 2,
lsm_weight: float = 0.0,
View full source on GitHub →.forward(speech, speech_lengths, text, text_lengths, **kwargs) L164Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
loss_att, acc_att, cer_att, wer_att = None, None, None, None
View full source →.encode(speech, speech_lengths, **kwargs) L257Frontend + Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
ind: int
"""
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.encoder.interctc_use_conditioning:
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ctc=self.ctc)
else:
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
intermediate_outs = None
View full source →.init_beam_search(**kwargs) L359Init beam search.
**kwargs — Additional keyword arguments. def init_beam_search(
self,
**kwargs,
):
"""Init beam search.
Args:
**kwargs: Additional keyword arguments.
"""
from funasr.models.transformer.search import BeamSearch
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.models.transformer.scorers.length_bonus import LengthBonus
# 1. Build ASR model
scorers = {}
if self.ctc != None:
ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
scorers.update(ctc=ctc)
token_list = kwargs.get("token_list")
scorers.update(
decoder=self.decoder,
length_bonus=LengthBonus(len(token_list)),
)
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
View full source →.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L409Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
# init beamsearch
if self.beam_search is None:
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
meta_data = {}
View full source →OpenAIWhisperModel.forward(speech, speech_lengths, text, text_lengths, **kwargs)Encoder + Decoder + Calc loss
speech — (Batch, Length, ...)speech_lengths — (Batch, )text — (Batch, Length)text_lengths — (Batch,) def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
loss_att, acc_att, cer_att, wer_att = None, None, None, None
View full source on GitHub →OpenAIWhisperModel.encode(speech, speech_lengths, **kwargs)Frontend + Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, )ind — int def encode(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
ind: int
"""
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech, speech_lengths = self.specaug(speech, speech_lengths)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.encoder.interctc_use_conditioning:
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ctc=self.ctc)
else:
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
intermediate_outs = None
View full source on GitHub →OpenAIWhisperModel.init_beam_search(**kwargs)Init beam search.
**kwargs — Additional keyword arguments. def init_beam_search(
self,
**kwargs,
):
"""Init beam search.
Args:
**kwargs: Additional keyword arguments.
"""
from funasr.models.transformer.search import BeamSearch
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.models.transformer.scorers.length_bonus import LengthBonus
# 1. Build ASR model
scorers = {}
if self.ctc != None:
ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
scorers.update(ctc=ctc)
token_list = kwargs.get("token_list")
scorers.update(
decoder=self.decoder,
length_bonus=LengthBonus(len(token_list)),
)
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
View full source on GitHub →OpenAIWhisperModel.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
# init beamsearch
if self.beam_search is None:
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
meta_data = {}
View full source on GitHub →WhisperEncoder and EResNet based LID Model
class OpenAIWhisperLIDModel(nn.Module):
"""WhisperEncoder and EResNet based LID Model"""
def __init__(
self,
vocab_size: int,
specaug: str = None,
specaug_conf: dict = None,
encoder: str = None,
encoder_conf: dict = None,
lid_predictor: str = None,
lid_predictor_conf: dict = None,
proj_dim: int = None,
clip_frames: int = None,
random_clip: bool = False,
**kwargs,
):
"""Initialize OpenAIWhisperLIDModel.
Args:
vocab_size: Size/dimension parameter.
specaug: TODO.
specaug_conf: Configuration dict for specaug.
encoder: TODO.
encoder_conf: Configuration dict for encoder.
lid_predictor: TODO.
lid_predictor_conf: Configuration dict for lid_predictor.
proj_dim: Size/dimension parameter.
clip_frames: TODO.
random_clip: TODO.
View full source on GitHub →.forward(speech, speech_lengths, lid, lid_lengths) L587Forward pass for training.
speech — Speech audio tensor, shape (batch, time).speech_lengths — Length of each speech sample.lid — TODO.lid_lengths — Lengths of lid. def forward(
self,
speech: torch.Tensor, # may be padding
speech_lengths: torch.Tensor, # actual length
lid: torch.Tensor, # lid label, (batch_size, 1)
lid_lengths: torch.Tensor,
):
"""Forward pass for training.
Args:
speech: Speech audio tensor, shape (batch, time).
speech_lengths: Length of each speech sample.
lid: TODO.
lid_lengths: Lengths of lid.
"""
assert lid.shape[1] == 1
batch_size = speech.shape[0]
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# re-generate encoder_out
if self.clip_frames is None:
reduced_encoder_out = (
torch.zeros(batch_size, encoder_out_lens.max(), encoder_out.shape[-1])
.to(encoder_out.dtype)
.to(encoder_out.device)
)
for i, enc_length in enumerate(encoder_out_lens):
reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length]
else:
reduced_encoder_out = (
View full source →.encode(speech, speech_lengths) L655Frontend + Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, ) def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech = speech.permute(0, 2, 1)
# suit for whisper padding
padded_speech_lengths = torch.ones_like(speech_lengths) * speech.shape[1]
speech, padded_speech_lengths = self.specaug(speech, padded_speech_lengths)
speech = speech.permute(0, 2, 1)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.encoder.interctc_use_conditioning:
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ctc=self.ctc)
else:
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
intermediate_outs = None
View full source →.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs) L694Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
meta_data = {}
if (
isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
): # fbank
speech, speech_lengths = data_in, data_lengths
if len(speech.shape) < 3:
speech = speech[None, :, :]
View full source →OpenAIWhisperLIDModel.forward(speech, speech_lengths, lid, lid_lengths)Forward pass for training.
speech — Speech audio tensor, shape (batch, time).speech_lengths — Length of each speech sample.lid — TODO.lid_lengths — Lengths of lid. def forward(
self,
speech: torch.Tensor, # may be padding
speech_lengths: torch.Tensor, # actual length
lid: torch.Tensor, # lid label, (batch_size, 1)
lid_lengths: torch.Tensor,
):
"""Forward pass for training.
Args:
speech: Speech audio tensor, shape (batch, time).
speech_lengths: Length of each speech sample.
lid: TODO.
lid_lengths: Lengths of lid.
"""
assert lid.shape[1] == 1
batch_size = speech.shape[0]
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# re-generate encoder_out
if self.clip_frames is None:
reduced_encoder_out = (
torch.zeros(batch_size, encoder_out_lens.max(), encoder_out.shape[-1])
.to(encoder_out.dtype)
.to(encoder_out.device)
)
for i, enc_length in enumerate(encoder_out_lens):
reduced_encoder_out[i, :enc_length] = encoder_out[i, :enc_length]
else:
reduced_encoder_out = (
View full source on GitHub →OpenAIWhisperLIDModel.encode(speech, speech_lengths)Frontend + Encoder. Note that this method is used by asr_inference.py
speech — (Batch, Length, ...)speech_lengths — (Batch, ) def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
with autocast(False):
# Data augmentation
if self.specaug is not None and self.training:
speech = speech.permute(0, 2, 1)
# suit for whisper padding
padded_speech_lengths = torch.ones_like(speech_lengths) * speech.shape[1]
speech, padded_speech_lengths = self.specaug(speech, padded_speech_lengths)
speech = speech.permute(0, 2, 1)
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
speech, speech_lengths = self.normalize(speech, speech_lengths)
# Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.encoder.interctc_use_conditioning:
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ctc=self.ctc)
else:
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
intermediate_outs = None
View full source on GitHub →OpenAIWhisperLIDModel.inference(data_in, data_lengths, key, tokenizer, frontend, **kwargs)Run inference on input data.
data_in — Input data (audio samples, file paths, or text).data_lengths — Lengths of each input sample in the batch.key — Sample identifiers.tokenizer — Tokenizer instance for text encoding/decoding.frontend — Audio frontend for feature extraction.**kwargs — Additional keyword arguments. def inference(
self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
"""Run inference on input data.
Args:
data_in: Input data (audio samples, file paths, or text).
data_lengths: Lengths of each input sample in the batch.
key: Sample identifiers.
tokenizer: Tokenizer instance for text encoding/decoding.
frontend: Audio frontend for feature extraction.
**kwargs: Additional keyword arguments.
"""
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
meta_data = {}
if (
isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
): # fbank
speech, speech_lengths = data_in, data_lengths
if len(speech.shape) < 3:
speech = speech[None, :, :]
View full source on GitHub →Conventional frontend structure for ASR.
Stft — > WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVNclass DefaultFrontend(nn.Module):
"""Conventional frontend structure for ASR.
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
"""
def __init__(
self,
fs: int = 16000,
n_fft: int = 512,
win_length: int = None,
hop_length: int = 128,
window: Optional[str] = "hann",
center: bool = True,
normalized: bool = False,
onesided: bool = True,
n_mels: int = 80,
fmin: int = None,
fmax: int = None,
htk: bool = False,
frontend_conf: Optional[dict] = None,
apply_stft: bool = True,
use_channel: int = None,
**kwargs,
):
"""Initialize DefaultFrontend.
Args:
fs: TODO.
n_fft: TODO.
win_length: TODO.
View full source on GitHub →.output_size() L106Output size.
def output_size(self) -> int:
    """Return the output feature dimension, i.e. the configured ``n_mels``."""
    feature_dim = self.n_mels
    return feature_dim
.forward(input, input_lengths) L110Forward pass for training.
input — Input audio/text data.input_lengths — Lengths of input. def forward(
self, input: torch.Tensor, input_lengths: Union[torch.Tensor, list]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass for training.
Args:
input: Input audio/text data.
input_lengths: Lengths of input.
"""
if isinstance(input_lengths, list):
input_lengths = torch.tensor(input_lengths)
if input.dtype == torch.float64:
input = input.float()
# 1. Domain-conversion: e.g. Stft: time -> time-freq
if self.stft is not None:
input_stft, feats_lens = self._compute_stft(input, input_lengths)
else:
input_stft = ComplexTensor(input[..., 0], input[..., 1])
feats_lens = input_lengths
# 2. [Option] Speech enhancement
if self.frontend is not None:
assert isinstance(input_stft, ComplexTensor), type(input_stft)
# input_stft: (Batch, Length, [Channel], Freq)
input_stft, _, mask = self.frontend(input_stft, feats_lens)
# 3. [Multi channel case]: Select a channel
if input_stft.dim() == 4:
# h: (B, T, C, F) -> h: (B, T, F)
if self.training:
if self.use_channel is not None:
View full source →DefaultFrontend.output_size()Output size.
def output_size(self) -> int:
    """Return the output feature dimension, i.e. the configured ``n_mels``."""
    feature_dim = self.n_mels
    return feature_dim
DefaultFrontend.forward(input, input_lengths)Forward pass for training.
input — Input audio/text data.input_lengths — Lengths of input. def forward(
self, input: torch.Tensor, input_lengths: Union[torch.Tensor, list]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass for training.
Args:
input: Input audio/text data.
input_lengths: Lengths of input.
"""
if isinstance(input_lengths, list):
input_lengths = torch.tensor(input_lengths)
if input.dtype == torch.float64:
input = input.float()
# 1. Domain-conversion: e.g. Stft: time -> time-freq
if self.stft is not None:
input_stft, feats_lens = self._compute_stft(input, input_lengths)
else:
input_stft = ComplexTensor(input[..., 0], input[..., 1])
feats_lens = input_lengths
# 2. [Option] Speech enhancement
if self.frontend is not None:
assert isinstance(input_stft, ComplexTensor), type(input_stft)
# input_stft: (Batch, Length, [Channel], Freq)
input_stft, _, mask = self.frontend(input_stft, feats_lens)
# 3. [Multi channel case]: Select a channel
if input_stft.dim() == 4:
# h: (B, T, C, F) -> h: (B, T, F)
if self.training:
if self.use_channel is not None:
View full source on GitHub →Conventional frontend structure for ASR.
Stft — > WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVNclass MultiChannelFrontend(nn.Module):
"""Conventional frontend structure for ASR.
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
"""
def __init__(
self,
fs: int = 16000,
n_fft: int = 512,
win_length: int = None,
hop_length: int = None,
frame_length: int = None,
frame_shift: int = None,
window: Optional[str] = "hann",
center: bool = True,
normalized: bool = False,
onesided: bool = True,
n_mels: int = 80,
fmin: int = None,
fmax: int = None,
htk: bool = False,
frontend_conf: Optional[dict] = None,
apply_stft: bool = True,
use_channel: int = None,
lfr_m: int = 1,
lfr_n: int = 1,
cmvn_file: str = None,
mc: bool = True,
):
"""Initialize MultiChannelFrontend.
View full source on GitHub →.output_size() L290Output size.
def output_size(self) -> int:
    """Return the output feature dimension, i.e. the configured ``n_mels``."""
    feature_dim = self.n_mels
    return feature_dim
.forward(input, input_lengths) L294Forward pass for training.
input — Input audio/text data.input_lengths — Lengths of input. def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Domain-conversion: e.g. Stft: time -> time-freq
"""Forward pass for training.
Args:
input: Input audio/text data.
input_lengths: Lengths of input.
"""
if self.stft is not None:
input_stft, feats_lens = self._compute_stft(input, input_lengths)
else:
input_stft = ComplexTensor(input[..., 0], input[..., 1])
feats_lens = input_lengths
# 2. [Option] Speech enhancement
if self.frontend is not None:
assert isinstance(input_stft, ComplexTensor), type(input_stft)
# input_stft: (Batch, Length, [Channel], Freq)
input_stft, _, mask = self.frontend(input_stft, feats_lens)
# 3. [Multi channel case]: Select a channel(sa_asr)
if input_stft.dim() == 4 and not self.mc:
# h: (B, T, C, F) -> h: (B, T, F)
if self.training:
if self.use_channel is not None:
input_stft = input_stft[:, :, self.use_channel, :]
else:
# Select 1ch randomly
View full source →MultiChannelFrontend.output_size()Output size.
def output_size(self) -> int:
    """Return the output feature dimension, i.e. the configured ``n_mels``."""
    feature_dim = self.n_mels
    return feature_dim
MultiChannelFrontend.forward(input, input_lengths)Forward pass for training.
input — Input audio/text data.input_lengths — Lengths of input. def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Domain-conversion: e.g. Stft: time -> time-freq
"""Forward pass for training.
Args:
input: Input audio/text data.
input_lengths: Lengths of input.
"""
if self.stft is not None:
input_stft, feats_lens = self._compute_stft(input, input_lengths)
else:
input_stft = ComplexTensor(input[..., 0], input[..., 1])
feats_lens = input_lengths
# 2. [Option] Speech enhancement
if self.frontend is not None:
assert isinstance(input_stft, ComplexTensor), type(input_stft)
# input_stft: (Batch, Length, [Channel], Freq)
input_stft, _, mask = self.frontend(input_stft, feats_lens)
# 3. [Multi channel case]: Select a channel(sa_asr)
if input_stft.dim() == 4 and not self.mc:
# h: (B, T, C, F) -> h: (B, T, F)
if self.training:
if self.use_channel is not None:
input_stft = input_stft[:, :, self.use_channel, :]
else:
# Select 1ch randomly
View full source on GitHub →transform(Y, dtype)Transform.
def transform(Y, dtype=np.float32, sr=8000, n_mels=23):
    """Convert an STFT matrix into mean-normalized log-mel features.

    Args:
        Y: STFT matrix with frames along axis 0 and one-sided frequency bins
            along axis 1 (the FFT size is recovered as ``2 * (bins - 1)``).
            Complex input is allowed; only its magnitude is used.
        dtype: dtype of the returned array.
        sr: Sample rate (Hz) used to build the mel filterbank. Defaults to
            8000, the value this function previously hard-coded.
        n_mels: Number of mel bands. Defaults to 23, the value this function
            previously hard-coded.

    Returns:
        Array of shape ``(Y.shape[0], n_mels)``: log10 mel-band power with
        each band's mean across frames subtracted, cast to ``dtype``.
    """
    magnitude = np.abs(Y)
    n_fft = 2 * (magnitude.shape[1] - 1)
    # librosa >= 0.10 makes sr/n_fft/n_mels keyword-only; passing them
    # positionally raises TypeError. Keywords work on older releases too.
    mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels)
    # Project the power spectrum onto the mel filterbank.
    mel_power = np.dot(magnitude**2, mel_basis.T)
    # Floor before the log so silent bins do not produce -inf.
    log_mel = np.log10(np.maximum(mel_power, 1e-10))
    # Subtract each band's mean over all frames (mean normalization).
    log_mel = log_mel - np.mean(log_mel, axis=0)
    return log_mel.astype(dtype)
subsample(Y, T, subsampling)Subsample.
def subsample(Y, T, subsampling=1):
    """Keep every ``subsampling``-th row of ``Y`` and of ``T``.

    Args:
        Y: Sequence/array sliced along its first axis (presumably frame-level
            features — confirm with callers).
        T: Sequence/array sliced the same way (presumably frame-aligned
            labels — confirm with callers).
        subsampling: Step between kept rows; 1 keeps everything.

    Returns:
        Tuple ``(Y_subsampled, T_subsampled)``.
    """
    every_nth = slice(None, None, subsampling)
    return Y[every_nth], T[every_nth]
splice(Y, context_size)Splice.
def splice(Y, context_size=0):
    """Concatenate each row of ``Y`` with its ``context_size`` neighbours.

    Rows beyond either edge are treated as zeros.

    Args:
        Y: 2-D array of shape ``(frames, dims)``.
        context_size: Number of neighbouring rows taken on each side.

    Returns:
        Read-only array of shape ``(frames, dims * (2 * context_size + 1))``.
        It is a strided view over a zero-padded copy of ``Y``, not a new
        buffer per row.
    """
    padded = np.ascontiguousarray(
        np.pad(Y, [(context_size, context_size), (0, 0)], "constant")
    )
    n_frames = Y.shape[0]
    n_dims = Y.shape[1]
    window = 2 * context_size + 1
    # Row t starts at padded row t and spans `window` consecutive padded rows,
    # so consecutive output rows overlap in memory (hence writeable=False).
    return np.lib.stride_tricks.as_strided(
        padded,
        (n_frames, n_dims * window),
        (Y.itemsize * n_dims, Y.itemsize),
        writeable=False,
    )
stft(data, frame_size, frame_shift)Stft.
def stft(data, frame_size=1024, frame_shift=256):
    """Compute a frame-major STFT of ``data`` with librosa.

    Args:
        data: 1-D signal passed to ``librosa.stft``.
        frame_size: Analysis window length in samples.
        frame_shift: Hop length in samples.

    Returns:
        ``librosa.stft(...)`` transposed to ``(frames, bins)``. When
        ``len(data)`` is an exact multiple of ``frame_shift``, the last frame
        is dropped.
    """
    # Smallest power of two that is >= frame_size.
    n_fft = 1 << (frame_size - 1).bit_length()
    drop_last = len(data) % frame_shift == 0
    frames = librosa.stft(
        data, n_fft=n_fft, win_length=frame_size, hop_length=frame_shift
    ).T
    if drop_last:
        return frames[:-1]
    return frames
No documentation yet.
class FusedFrontends(nn.Module):
def __init__(self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000):
"""Initialize FusedFrontends.
Args:
frontends: TODO.
align_method: TODO.
proj_dim: Size/dimension parameter.
fs: TODO.
"""
super().__init__()
self.align_method = align_method # fusing method : linear_projection only for now
self.proj_dim = proj_dim # dim of the projection done on each frontend
self.frontends = [] # list of the frontends to combine
for i, frontend in enumerate(frontends):
frontend_type = frontend["frontend_type"]
if frontend_type == "default":
n_mels, fs, n_fft, win_length, hop_length = (
frontend.get("n_mels", 80),
fs,
frontend.get("n_fft", 512),
frontend.get("win_length"),
frontend.get("hop_length", 128),
)
window, center, normalized, onesided = (
frontend.get("window", "hann"),
frontend.get("center", True),
frontend.get("normalized", False),
View full source on GitHub →.output_size() L106Output size.
def output_size(self) -> int:
    """Return the fused feature dimension: one ``proj_dim`` slice per frontend."""
    return self.proj_dim * len(self.frontends)
.forward(input, input_lengths) L110Forward pass for training.
input — Input audio/text data.input_lengths — Lengths of input. def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# step 0 : get all frontends features
"""Forward pass for training.
Args:
input: Input audio/text data.
input_lengths: Lengths of input.
"""
self.feats = []
for frontend in self.frontends:
with torch.no_grad():
input_feats, feats_lens = frontend.forward(input, input_lengths)
self.feats.append([input_feats, feats_lens])
if self.align_method == "linear_projection": # TODO(Dan): to add other align methods
# first step : projections
self.feats_proj = []
for i, frontend in enumerate(self.frontends):
input_feats = self.feats[i][0]
self.feats_proj.append(self.projection_layers[i](input_feats))
# 2nd step : reshape
self.feats_reshaped = []
for i, frontend in enumerate(self.frontends):
input_feats_proj = self.feats_proj[i]
bs, nf, dim = input_feats_proj.shape
View full source →FusedFrontends.output_size()Output size.
def output_size(self) -> int:
    """Return the fused feature dimension: one ``proj_dim`` slice per frontend."""
    return self.proj_dim * len(self.frontends)
FusedFrontends.forward(input, input_lengths)Forward pass for training.
input — Input audio/text data.input_lengths — Lengths of input. def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# step 0 : get all frontends features
"""Forward pass for training.
Args:
input: Input audio/text data.
input_lengths: Lengths of input.
"""
self.feats = []
for frontend in self.frontends:
with torch.no_grad():
input_feats, feats_lens = frontend.forward(input, input_lengths)
self.feats.append([input_feats, feats_lens])
if self.align_method == "linear_projection": # TODO(Dan): to add other align methods
# first step : projections
self.feats_proj = []
for i, frontend in enumerate(self.frontends):
input_feats = self.feats[i][0]
self.feats_proj.append(self.projection_layers[i](input_feats))
# 2nd step : reshape
self.feats_reshaped = []
for i, frontend in enumerate(self.frontends):
input_feats_proj = self.feats_proj[i]
bs, nf, dim = input_feats_proj.shape
View full source on GitHub →base_s3prl_setup(args)Base s3prl setup.
def base_s3prl_setup(args):
    """Fill in default S3PRL upstream options on ``args``.

    Attributes already present on ``args`` keep their values; missing ones
    are set to the defaults below. ``args`` is modified in place and returned.

    Args:
        args: Attribute-style configuration object (e.g. a namespace).

    Returns:
        The same ``args`` object.
    """
    defaults = (
        ("upstream_feature_selection", None),
        ("upstream_model_config", None),
        ("upstream_refresh", False),
        ("upstream_ckpt", None),
        ("init_ckpt", None),
        ("verbose", False),
        ("tile_factor", 1),
    )
    for attr_name, fallback in defaults:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
    return args
Speech Pretrained Representation frontend structure for ASR.
class S3prlFrontend(nn.Module):
"""Speech Pretrained Representation frontend structure for ASR."""
def __init__(
self,
fs: Union[int, str] = 16000,
frontend_conf: Optional[dict] = None,
download_dir: str = None,
multilayer_feature: bool = False,
):
"""Initialize S3prlFrontend.
Args:
fs: TODO.
frontend_conf: Configuration dict for frontend.
download_dir: TODO.
multilayer_feature: TODO.
"""
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)
if download_dir is not None:
torch.hub.set_dir(download_dir)
self.multilayer_feature = multilayer_feature
self.upstream, self.featurizer = self._get_upstream(frontend_conf)
self.pretrained_params = copy.deepcopy(self.upstream.state_dict())
self.output_dim = self.featurizer.output_dim
self.frontend_type = "s3prl"
View full source on GitHub →.output_size() L127Output size.
def output_size(self) -> int:
    """Return the feature dimension produced by this frontend (``self.output_dim``)."""
    dim = self.output_dim
    return dim
.forward(input, input_lengths) L131Forward pass for training.
input — Input audio/text data.input_lengths — Lengths of input. def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Extract S3PRL features for a padded batch.
Args:
input: Padded batch; row i is trimmed to its first input_lengths[i]
entries before being fed to the upstream model (presumably raw
waveforms shaped (B, T_max) — TODO confirm).
input_lengths: Number of valid entries in each row of ``input``.
Returns:
Tuple (input_feats, feats_lens): the per-utterance features batched by
``pad_list(feats, 0.0)``, and a long tensor holding ``f.shape[0]`` for
each utterance's features.
"""
# Drop padding so the upstream model only sees valid samples.
wavs = [wav[: input_lengths[i]] for i, wav in enumerate(input)]
# The upstream model is always run frozen: eval mode and no autograd graph.
self.upstream.eval()
with torch.no_grad():
feats = self.upstream(wavs)
# NOTE(review): indentation is lost in this rendering, so it is unclear whether
# the featurizer call is inside torch.no_grad(); presumably it is outside so
# the featurizer stays trainable — confirm against the full source.
feats = self.featurizer(wavs, feats)
# self.args is not assigned in the visible __init__; presumably set by
# _get_upstream — TODO confirm.
if self.args.tile_factor != 1:
feats = self._tile_representations(feats)
input_feats = pad_list(feats, 0.0)
feats_lens = torch.tensor([f.shape[0] for f in feats], dtype=torch.long)
# Saving CUDA Memory
del feats
return input_feats, feats_lens
.reload_pretrained_parameters() L157Reload pretrained parameters.
def reload_pretrained_parameters(self):
    """Restore the upstream model's weights from ``self.pretrained_params`` and log it."""
    snapshot = self.pretrained_params
    self.upstream.load_state_dict(snapshot)
    logging.info("Pretrained S3PRL frontend model parameters reloaded!")
S3prlFrontend.output_size()Output size.
def output_size(self) -> int:
    """Return the feature dimension produced by this frontend (``self.output_dim``)."""
    dim = self.output_dim
    return dim
S3prlFrontend.forward(input, input_lengths)Forward pass for training.
input — Input audio/text data.input_lengths — Lengths of input. def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Extract S3PRL features for a padded batch.
Args:
input: Padded batch; row i is trimmed to its first input_lengths[i]
entries before being fed to the upstream model (presumably raw
waveforms shaped (B, T_max) — TODO confirm).
input_lengths: Number of valid entries in each row of ``input``.
Returns:
Tuple (input_feats, feats_lens): the per-utterance features batched by
``pad_list(feats, 0.0)``, and a long tensor holding ``f.shape[0]`` for
each utterance's features.
"""
# Drop padding so the upstream model only sees valid samples.
wavs = [wav[: input_lengths[i]] for i, wav in enumerate(input)]
# The upstream model is always run frozen: eval mode and no autograd graph.
self.upstream.eval()
with torch.no_grad():
feats = self.upstream(wavs)
# NOTE(review): indentation is lost in this rendering, so it is unclear whether
# the featurizer call is inside torch.no_grad(); presumably it is outside so
# the featurizer stays trainable — confirm against the full source.
feats = self.featurizer(wavs, feats)
# self.args is not assigned in the visible __init__; presumably set by
# _get_upstream — TODO confirm.
if self.args.tile_factor != 1:
feats = self._tile_representations(feats)
input_feats = pad_list(feats, 0.0)
feats_lens = torch.tensor([f.shape[0] for f in feats], dtype=torch.long)
# Saving CUDA Memory
del feats
return input_feats, feats_lens
S3prlFrontend.reload_pretrained_parameters()Reload pretrained parameters.
def reload_pretrained_parameters(self):
    """Restore the upstream model's weights from ``self.pretrained_params`` and log it."""
    snapshot = self.pretrained_params
    self.upstream.load_state_dict(snapshot)
    logging.info("Pretrained S3PRL frontend model parameters reloaded!")
load_cmvn(cmvn_file)Load cmvn.
def load_cmvn(cmvn_file):
    """Load CMVN statistics from a Kaldi-nnet-style text file.

    The file is scanned for ``<AddShift>`` and ``<Rescale>`` marker lines. For
    each marker, the following line must start with ``<LearnRateCoef>``. Its
    tokens ``[3:-1]`` are taken as the per-dimension values: the first three
    tokens and the last one are dropped, presumably
    ``<LearnRateCoef> <coef> [`` and ``]``. If a marker occurs more than once,
    the last occurrence wins.

    Blank lines are skipped, and a marker on the final line (with nothing
    after it) is ignored rather than raising ``IndexError``.

    Args:
        cmvn_file: Path of the text file, read as UTF-8.

    Returns:
        float32 tensor ``[shifts, scales]``. Its shape is ``(2, D)`` when
        both components are present with ``D`` values each.
    """
    with open(cmvn_file, "r", encoding="utf-8") as f:
        lines = f.readlines()
    means_list = []
    scales_list = []
    for i, line in enumerate(lines):
        line_item = line.split()
        # Blank lines (e.g. a trailing newline) would make line_item[0] fail.
        if not line_item or line_item[0] not in ("<AddShift>", "<Rescale>"):
            continue
        # The values live on the next line; guard against a marker at EOF.
        if i + 1 >= len(lines):
            continue
        value_item = lines[i + 1].split()
        if not value_item or value_item[0] != "<LearnRateCoef>":
            continue
        values = value_item[3:-1]
        if line_item[0] == "<AddShift>":
            means_list = values
        else:
            scales_list = values
    means = np.array(means_list).astype(np.float32)
    scales = np.array(scales_list).astype(np.float32)
    cmvn = np.array([means, scales])
    cmvn = torch.as_tensor(cmvn, dtype=torch.float32)
    return cmvn
apply_cmvn(inputs, cmvn)Apply CMVN with mvn data
def apply_cmvn(inputs, cmvn):  # noqa
    """Apply CMVN statistics to a feature tensor.

    Computes ``(inputs + cmvn[0, :D]) * cmvn[1, :D]``, where ``D`` is the size
    of the last dimension of ``inputs``. Row 0 of ``cmvn`` is an additive
    shift and row 1 a multiplicative scale.

    Args:
        inputs: Feature tensor whose last dimension is the feature axis,
            typically ``(frames, D)``.
        cmvn: Tensor with at least ``D`` columns; shifts in row 0, scales in row 1.

    Returns:
        A new float32 tensor on ``inputs.device``; ``inputs`` is not modified.
    """
    device = inputs.device
    dim = inputs.shape[-1]
    means = cmvn[0:1, :dim].to(device)
    scales = cmvn[1:2, :dim].to(device)
    # Out-of-place on purpose: callers may pass a view into a larger batch
    # tensor, and an in-place update would silently modify that batch.
    return ((inputs + means) * scales).type(torch.float32)
apply_lfr(inputs, lfr_m, lfr_n)Apply lfr.
inputs — TODO.lfr_m — TODO.lfr_n — TODO.def apply_lfr(inputs, lfr_m, lfr_n):
"""Apply LFR frame stacking to a 2-D feature matrix.
Each output row is ``lfr_m`` consecutive input frames flattened into one
vector, and consecutive output rows start ``lfr_n`` frames apart. The first
frame is repeated ``(lfr_m - 1) // 2`` times on the left. When trailing rows
would run past the data, the last frame is repeated on the right.
Args:
inputs: Feature tensor read as (frames, feat_dim): ``shape[0]`` is the
frame count and ``shape[-1]`` the feature size.
lfr_m: Number of frames stacked into each output row.
lfr_n: Frame step between consecutive output rows.
Returns:
A new float32 tensor of shape (ceil(frames / lfr_n), lfr_m * feat_dim).
"""
# NOTE(review): unused.
LFR_inputs = []
T = inputs.shape[0]
# Output row count is derived from the original (unpadded) frame count.
T_lfr = int(np.ceil(T / lfr_n))
left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
inputs = torch.vstack((left_padding, inputs))
# T now also counts the left-padding frames.
T = T + (lfr_m - 1) // 2
feat_dim = inputs.shape[-1]
# Strided view: row r starts at frame r * lfr_n and spans lfr_m frames. This
# relies on `inputs` being contiguous, which the torch.vstack result is.
strides = (lfr_n * feat_dim, 1)
sizes = (T_lfr, lfr_m * feat_dim)
# For T >= lfr_m, last_idx rows fit entirely inside the data. num_padding is
# the frame deficit of row `last_idx`; positive means that row needs padding.
last_idx = (T - lfr_m) // lfr_n + 1
num_padding = lfr_m - (T - last_idx * lfr_n)
if num_padding > 0:
# Arithmetic-series sum of the per-row deficits for rows last_idx..T_lfr-1.
# This can exceed the padding strictly needed (the last row's deficit);
# the surplus frames are never read by as_strided.
num_padding = (2 * lfr_m - 2 * T + (T_lfr - 1 + last_idx) * lfr_n) / 2 * (T_lfr - last_idx)
inputs = torch.vstack([inputs] + [inputs[-1:]] * int(num_padding))
LFR_outputs = inputs.as_strided(sizes, strides)
# clone() copies the overlapping strided view into its own storage.
return LFR_outputs.clone().type(torch.float32)
Conventional frontend structure for ASR.
class WavFrontend(nn.Module):
"""Conventional frontend structure for ASR."""
def __init__(
self,
cmvn_file: str = None,
fs: int = 16000,
window: str = "hamming",
n_mels: int = 80,
frame_length: int = 25,
frame_shift: int = 10,
filter_length_min: int = -1,
filter_length_max: int = -1,
lfr_m: int = 1,
lfr_n: int = 1,
dither: float = 1.0,
snip_edges: bool = True,
upsacle_samples: bool = True,
**kwargs,
):
"""Initialize WavFrontend.
Args:
cmvn_file: TODO.
fs: TODO.
window: TODO.
n_mels: TODO.
frame_length: TODO.
frame_shift: TODO.
filter_length_min: TODO.
View full source on GitHub →.output_size() L145Output size.
def output_size(self) -> int:
    """Return the feature size after LFR stacking: ``n_mels * lfr_m``."""
    mel_bins, stacked_frames = self.n_mels, self.lfr_m
    return mel_bins * stacked_frames
.forward(input, input_lengths, **kwargs) L149Forward pass for training.
input — Input audio/text data.input_lengths — Lengths of input.**kwargs — Additional keyword arguments. def forward(
self,
input: torch.Tensor,
input_lengths,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass for training.
Args:
input: Input audio/text data.
input_lengths: Lengths of input.
**kwargs: Additional keyword arguments.
"""
batch_size = input.size(0)
feats = []
feats_lens = []
for i in range(batch_size):
waveform_length = input_lengths[i]
waveform = input[i][:waveform_length]
if self.upsacle_samples:
waveform = waveform * (1 << 15)
waveform = waveform.unsqueeze(0)
mat = kaldi.fbank(
waveform,
num_mel_bins=self.n_mels,
frame_length=min(self.frame_length,waveform_length/self.fs*1000),
frame_shift=self.frame_shift,
dither=self.dither,
energy_floor=0.0,
window_type=self.window,
View full source →.forward_fbank(input, input_lengths) L198Forward fbank.
input — Input audio/text data.input_lengths — Lengths of input. def forward_fbank(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward fbank.
Args:
input: Input audio/text data.
input_lengths: Lengths of input.
"""
batch_size = input.size(0)
feats = []
feats_lens = []
for i in range(batch_size):
waveform_length = input_lengths[i]
waveform = input[i][:waveform_length]
waveform = waveform * (1 << 15)
waveform = waveform.unsqueeze(0)
mat = kaldi.fbank(
waveform,
num_mel_bins=self.n_mels,
frame_length=self.frame_length,
frame_shift=self.frame_shift,
dither=self.dither,
energy_floor=0.0,
window_type=self.window,
sample_frequency=self.fs,
)
feat_length = mat.size(0)
feats.append(mat)
View full source →.forward_lfr_cmvn(input, input_lengths) L234Forward lfr cmvn.
def forward_lfr_cmvn(
    self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply LFR stacking and CMVN to a padded batch of feature matrices.

    Each utterance is trimmed to its valid length. It is stacked with
    ``apply_lfr`` when ``lfr_m`` or ``lfr_n`` is not 1, and normalized with
    ``apply_cmvn`` when ``self.cmvn`` is set. The results are then re-padded
    with zeros.

    Args:
        input: Padded features indexed as ``input[i, :length, :]``.
        input_lengths: Valid frame count of each utterance.

    Returns:
        Tuple ``(feats_pad, feats_lens)``: batch-first zero-padded features
        and a tensor of per-utterance output frame counts.
    """
    outputs = []
    for idx in range(input.size(0)):
        utt = input[idx, : input_lengths[idx], :]
        if self.lfr_m != 1 or self.lfr_n != 1:
            utt = apply_lfr(utt, self.lfr_m, self.lfr_n)
        if self.cmvn is not None:
            utt = apply_cmvn(utt, self.cmvn)
        outputs.append(utt)
    lengths = torch.as_tensor([feat.size(0) for feat in outputs])
    padded = pad_sequence(outputs, batch_first=True, padding_value=0.0)
    return padded, lengths
WavFrontend.output_size()Output size.
def output_size(self) -> int:
    """Return the feature size after LFR stacking: ``n_mels * lfr_m``."""
    mel_bins, stacked_frames = self.n_mels, self.lfr_m
    return mel_bins * stacked_frames
WavFrontend.forward(input, input_lengths, **kwargs)Forward pass for training.
input — Input audio/text data.input_lengths — Lengths of input.**kwargs — Additional keyword arguments. def forward(
self,
input: torch.Tensor,
input_lengths,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass for training.
Args:
input: Input audio/text data.
input_lengths: Lengths of input.
**kwargs: Additional keyword arguments.
"""
batch_size = input.size(0)
feats = []
feats_lens = []
for i in range(batch_size):
waveform_length = input_lengths[i]
waveform = input[i][:waveform_length]
if self.upsacle_samples:
waveform = waveform * (1 << 15)
waveform = waveform.unsqueeze(0)
mat = kaldi.fbank(
waveform,
num_mel_bins=self.n_mels,
frame_length=min(self.frame_length,waveform_length/self.fs*1000),
frame_shift=self.frame_shift,
dither=self.dither,
energy_floor=0.0,
window_type=self.window,
View full source on GitHub →WavFrontend.forward_fbank(input, input_lengths)Forward fbank.
input — Input audio/text data.input_lengths — Lengths of input. def forward_fbank(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward fbank.
Args:
input: Input audio/text data.
input_lengths: Lengths of input.
"""
batch_size = input.size(0)
feats = []
feats_lens = []
for i in range(batch_size):
waveform_length = input_lengths[i]
waveform = input[i][:waveform_length]
waveform = waveform * (1 << 15)
waveform = waveform.unsqueeze(0)
mat = kaldi.fbank(
waveform,
num_mel_bins=self.n_mels,
frame_length=self.frame_length,
frame_shift=self.frame_shift,
dither=self.dither,
energy_floor=0.0,
window_type=self.window,
sample_frequency=self.fs,
)
feat_length = mat.size(0)
feats.append(mat)
View full source on GitHub →WavFrontend.forward_lfr_cmvn(input, input_lengths)Forward lfr cmvn.
def forward_lfr_cmvn(
    self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply LFR stacking and CMVN to a padded batch of feature matrices.

    Each utterance is trimmed to its valid length. It is stacked with
    ``apply_lfr`` when ``lfr_m`` or ``lfr_n`` is not 1, and normalized with
    ``apply_cmvn`` when ``self.cmvn`` is set. The results are then re-padded
    with zeros.

    Args:
        input: Padded features indexed as ``input[i, :length, :]``.
        input_lengths: Valid frame count of each utterance.

    Returns:
        Tuple ``(feats_pad, feats_lens)``: batch-first zero-padded features
        and a tensor of per-utterance output frame counts.
    """
    outputs = []
    for idx in range(input.size(0)):
        utt = input[idx, : input_lengths[idx], :]
        if self.lfr_m != 1 or self.lfr_n != 1:
            utt = apply_lfr(utt, self.lfr_m, self.lfr_n)
        if self.cmvn is not None:
            utt = apply_cmvn(utt, self.cmvn)
        outputs.append(utt)
    lengths = torch.as_tensor([feat.size(0) for feat in outputs])
    padded = pad_sequence(outputs, batch_first=True, padding_value=0.0)
    return padded, lengths
Conventional frontend structure for streaming ASR/VAD.
class WavFrontendOnline(nn.Module):
"""Conventional frontend structure for streaming ASR/VAD."""
def __init__(
self,
cmvn_file: str = None,
fs: int = 16000,
window: str = "hamming",
n_mels: int = 80,
frame_length: int = 25,
frame_shift: int = 10,
filter_length_min: int = -1,
filter_length_max: int = -1,
lfr_m: int = 1,
lfr_n: int = 1,
dither: float = 1.0,
snip_edges: bool = True,
upsacle_samples: bool = True,
**kwargs,
):
"""Initialize WavFrontendOnline.
Args:
cmvn_file: TODO.
fs: TODO.
window: TODO.
n_mels: TODO.
frame_length: TODO.
frame_shift: TODO.
filter_length_min: TODO.
View full source on GitHub →.output_size() L324Output size.
def output_size(self) -> int:
    """Return the feature size after LFR stacking: ``n_mels * lfr_m``."""
    mel_bins, stacked_frames = self.n_mels, self.lfr_m
    return mel_bins * stacked_frames
.apply_cmvn(inputs, cmvn) L329Apply CMVN with mvn data
def apply_cmvn(inputs: torch.Tensor, cmvn: torch.Tensor) -> torch.Tensor:
    """Apply CMVN statistics to a feature tensor.

    Computes ``(inputs + cmvn[0, :D]) * cmvn[1, :D]``, where ``D`` is the size
    of the last dimension of ``inputs``. Row 0 of ``cmvn`` is an additive
    shift and row 1 a multiplicative scale. The statistics are cast to the
    dtype and device of ``inputs`` and broadcast over the frame axis.
    ``cmvn`` may be a torch tensor (on any device) or a NumPy array.

    Args:
        inputs: Feature tensor whose last dimension is the feature axis,
            typically ``(frames, D)``.
        cmvn: Array-like with at least ``D`` columns; shifts in row 0, scales in row 1.

    Returns:
        A new float32 tensor; ``inputs`` is not modified.
    """
    dim = inputs.shape[-1]
    # Cast the stats to inputs' dtype/device so the arithmetic happens in the
    # same precision as before.
    shift = torch.as_tensor(cmvn[0:1, :dim], dtype=inputs.dtype, device=inputs.device)
    scale = torch.as_tensor(cmvn[1:2, :dim], dtype=inputs.dtype, device=inputs.device)
    # Broadcasting replaces per-frame np.tile copies (which also needed a CPU
    # cmvn). Out-of-place, so a caller's view into a batch is left untouched.
    return ((inputs + shift) * scale).type(torch.float32)
.apply_lfr(inputs, lfr_m, lfr_n, is_final) L346Apply lfr with data
def apply_lfr(
inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, int]:
"""
Apply lfr with data
"""
LFR_inputs = []
# inputs = torch.vstack((inputs_lfr_cache, inputs))
T = inputs.shape[0] # include the right context
T_lfr = int(
np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
) # minus the right context: (lfr_m - 1) // 2
splice_idx = T_lfr
feat_dim = inputs.shape[-1]
ori_inputs = inputs
strides = (lfr_n * feat_dim, 1)
sizes = (T_lfr, lfr_m * feat_dim)
last_idx = (T - lfr_m) // lfr_n + 1
num_padding = lfr_m - (T - last_idx * lfr_n)
if is_final:
if num_padding > 0:
num_padding = (2 * lfr_m - 2 * T + (T_lfr - 1 + last_idx) * lfr_n) / 2 * (T_lfr - last_idx)
inputs = torch.vstack([inputs] + [inputs[-1:]] * int(num_padding))
else:
if num_padding > 0:
sizes = (last_idx, lfr_m * feat_dim)
splice_idx = last_idx
splice_idx = min(T - 1, splice_idx * lfr_n)
LFR_outputs = inputs[:splice_idx].as_strided(sizes, strides)
View full source →.compute_frame_num(sample_length, frame_sample_length, frame_shift_sample_length) L380Compute frame num.
def compute_frame_num(
    sample_length: int, frame_sample_length: int, frame_shift_sample_length: int
) -> int:
    """Count the complete frames that fit in a buffer of samples.

    Args:
        sample_length: Number of samples available.
        frame_sample_length: Samples per frame.
        frame_shift_sample_length: Samples between consecutive frame starts.

    Returns:
        ``int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)``
        when that is at least 1 and the buffer holds a full frame, otherwise 0.
    """
    n_frames = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
    has_full_frame = sample_length >= frame_sample_length
    if n_frames < 1 or not has_full_frame:
        return 0
    return n_frames
.forward_fbank(input, input_lengths, cache, **kwargs) L393Forward fbank.
input — Input audio/text data.input_lengths — Lengths of input.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def forward_fbank(
self,
input: torch.Tensor,
input_lengths: torch.Tensor,
cache: dict = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Forward fbank.
Args:
input: Input audio/text data.
input_lengths: Lengths of input.
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
batch_size = input.size(0)
input = torch.cat((cache["input_cache"], input), dim=1)
frame_num = self.compute_frame_num(
input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length
)
# update self.in_cache
cache["input_cache"] = input[
:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) :
]
waveforms = torch.empty(0)
feats_pad = torch.empty(0)
feats_lens = torch.empty(0)
View full source →.forward_lfr_cmvn(input, input_lengths, is_final, cache, **kwargs) L462Forward lfr cmvn.
input — Input audio/text data.input_lengths — Lengths of input.is_final — Whether this is the final chunk in streaming.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def forward_lfr_cmvn(
self,
input: torch.Tensor,
input_lengths: torch.Tensor,
is_final: bool = False,
cache: dict = None,
**kwargs,
):
"""Forward lfr cmvn.
Args:
input: Input audio/text data.
input_lengths: Lengths of input.
is_final: Whether this is the final chunk in streaming.
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
batch_size = input.size(0)
feats = []
feats_lens = []
lfr_splice_frame_idxs = []
for i in range(batch_size):
mat = input[i, : input_lengths[i], :]
if self.lfr_m != 1 or self.lfr_n != 1:
# update self.lfr_splice_cache in self.apply_lfr
# mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, self.lfr_splice_cache[i],
mat, cache["lfr_splice_cache"][i], lfr_splice_frame_idx = self.apply_lfr(
mat, self.lfr_m, self.lfr_n, is_final
View full source →.forward(input, input_lengths, **kwargs) L505Forward pass for training.
input — Input audio/text data.input_lengths — Lengths of input.**kwargs — Additional keyword arguments. def forward(self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs):
"""Forward pass for training.
Args:
input: Input audio/text data.
input_lengths: Lengths of input.
**kwargs: Additional keyword arguments.
"""
is_final = kwargs.get("is_final", False)
cache = kwargs.get("cache", {})
if len(cache) == 0:
self.init_cache(cache)
batch_size = input.shape[0]
assert (
batch_size == 1
), "we support to extract feature online only when the batch size is equal to 1 now"
waveforms, feats, feats_lengths = self.forward_fbank(
input, input_lengths, cache=cache
) # input shape: B T D
if feats.shape[0]:
cache["waveforms"] = torch.cat((cache["reserve_waveforms"], waveforms), dim=1)
if not cache["lfr_splice_cache"]: # 初始化splice_cache
for i in range(batch_size):
cache["lfr_splice_cache"].append(
feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1)
View full source →.init_cache(cache) L594Init cache.
def init_cache(self, cache: dict = None):
    """Reset (or create) the streaming-state dict used by this frontend.

    A passed-in dict is updated in place, so callers holding a reference see
    the fresh state. Keys not listed here are left alone. When ``cache`` is
    None a new dict is created.

    Args:
        cache: State cache dict for streaming inference, or None.

    Returns:
        The initialized dict (the same object when one was passed in).
    """
    state = {} if cache is None else cache
    fresh_entries = (
        ("reserve_waveforms", torch.empty(0)),
        ("input_cache", torch.empty(0)),
        ("lfr_splice_cache", []),
        ("waveforms", None),
        ("fbanks", None),
        ("fbanks_lens", None),
    )
    for key, value in fresh_entries:
        state[key] = value
    return state
WavFrontendOnline.output_size()Output size.
def output_size(self) -> int:
    """Return the feature size after LFR stacking: ``n_mels * lfr_m``."""
    mel_bins, stacked_frames = self.n_mels, self.lfr_m
    return mel_bins * stacked_frames
WavFrontendOnline.apply_cmvn(inputs, cmvn)Apply CMVN with mvn data
def apply_cmvn(inputs: torch.Tensor, cmvn: torch.Tensor) -> torch.Tensor:
    """Apply CMVN statistics to a feature tensor.

    Computes ``(inputs + cmvn[0, :D]) * cmvn[1, :D]``, where ``D`` is the size
    of the last dimension of ``inputs``. Row 0 of ``cmvn`` is an additive
    shift and row 1 a multiplicative scale. The statistics are cast to the
    dtype and device of ``inputs`` and broadcast over the frame axis.
    ``cmvn`` may be a torch tensor (on any device) or a NumPy array.

    Args:
        inputs: Feature tensor whose last dimension is the feature axis,
            typically ``(frames, D)``.
        cmvn: Array-like with at least ``D`` columns; shifts in row 0, scales in row 1.

    Returns:
        A new float32 tensor; ``inputs`` is not modified.
    """
    dim = inputs.shape[-1]
    # Cast the stats to inputs' dtype/device so the arithmetic happens in the
    # same precision as before.
    shift = torch.as_tensor(cmvn[0:1, :dim], dtype=inputs.dtype, device=inputs.device)
    scale = torch.as_tensor(cmvn[1:2, :dim], dtype=inputs.dtype, device=inputs.device)
    # Broadcasting replaces per-frame np.tile copies (which also needed a CPU
    # cmvn). Out-of-place, so a caller's view into a batch is left untouched.
    return ((inputs + shift) * scale).type(torch.float32)
WavFrontendOnline.apply_lfr(inputs, lfr_m, lfr_n, is_final)Apply lfr with data
def apply_lfr(
inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, int]:
"""
Apply lfr with data
"""
LFR_inputs = []
# inputs = torch.vstack((inputs_lfr_cache, inputs))
T = inputs.shape[0] # include the right context
T_lfr = int(
np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
) # minus the right context: (lfr_m - 1) // 2
splice_idx = T_lfr
feat_dim = inputs.shape[-1]
ori_inputs = inputs
strides = (lfr_n * feat_dim, 1)
sizes = (T_lfr, lfr_m * feat_dim)
last_idx = (T - lfr_m) // lfr_n + 1
num_padding = lfr_m - (T - last_idx * lfr_n)
if is_final:
if num_padding > 0:
num_padding = (2 * lfr_m - 2 * T + (T_lfr - 1 + last_idx) * lfr_n) / 2 * (T_lfr - last_idx)
inputs = torch.vstack([inputs] + [inputs[-1:]] * int(num_padding))
else:
if num_padding > 0:
sizes = (last_idx, lfr_m * feat_dim)
splice_idx = last_idx
splice_idx = min(T - 1, splice_idx * lfr_n)
LFR_outputs = inputs[:splice_idx].as_strided(sizes, strides)
View full source on GitHub →WavFrontendOnline.compute_frame_num(sample_length, frame_sample_length, frame_shift_sample_length)Compute frame num.
def compute_frame_num(
    sample_length: int, frame_sample_length: int, frame_shift_sample_length: int
) -> int:
    """Count the complete frames that fit in a buffer of samples.

    Args:
        sample_length: Number of samples available.
        frame_sample_length: Samples per frame.
        frame_shift_sample_length: Samples between consecutive frame starts.

    Returns:
        ``int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)``
        when that is at least 1 and the buffer holds a full frame, otherwise 0.
    """
    n_frames = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
    has_full_frame = sample_length >= frame_sample_length
    if n_frames < 1 or not has_full_frame:
        return 0
    return n_frames
WavFrontendOnline.forward_fbank(input, input_lengths, cache, **kwargs)Forward fbank.
input — Input audio/text data.input_lengths — Lengths of input.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def forward_fbank(
self,
input: torch.Tensor,
input_lengths: torch.Tensor,
cache: dict = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Forward fbank.
Args:
input: Input audio/text data.
input_lengths: Lengths of input.
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
batch_size = input.size(0)
input = torch.cat((cache["input_cache"], input), dim=1)
frame_num = self.compute_frame_num(
input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length
)
# update self.in_cache
cache["input_cache"] = input[
:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) :
]
waveforms = torch.empty(0)
feats_pad = torch.empty(0)
feats_lens = torch.empty(0)
View full source on GitHub →WavFrontendOnline.forward_lfr_cmvn(input, input_lengths, is_final, cache, **kwargs)Forward lfr cmvn.
input — Input audio/text data.input_lengths — Lengths of input.is_final — Whether this is the final chunk in streaming.cache — State cache dict for streaming inference.**kwargs — Additional keyword arguments. def forward_lfr_cmvn(
self,
input: torch.Tensor,
input_lengths: torch.Tensor,
is_final: bool = False,
cache: dict = None,
**kwargs,
):
"""Forward lfr cmvn.
Args:
input: Input audio/text data.
input_lengths: Lengths of input.
is_final: Whether this is the final chunk in streaming.
cache: State cache dict for streaming inference.
**kwargs: Additional keyword arguments.
"""
if cache is None:
cache = {}
batch_size = input.size(0)
feats = []
feats_lens = []
lfr_splice_frame_idxs = []
for i in range(batch_size):
mat = input[i, : input_lengths[i], :]
if self.lfr_m != 1 or self.lfr_n != 1:
# update self.lfr_splice_cache in self.apply_lfr
# mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, self.lfr_splice_cache[i],
mat, cache["lfr_splice_cache"][i], lfr_splice_frame_idx = self.apply_lfr(
mat, self.lfr_m, self.lfr_n, is_final
View full source on GitHub →WavFrontendOnline.forward(input, input_lengths, **kwargs)Forward pass for training.
input — Input audio/text data.input_lengths — Lengths of input.**kwargs — Additional keyword arguments. def forward(self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs):
"""Forward pass for training.
Args:
input: Input audio/text data.
input_lengths: Lengths of input.
**kwargs: Additional keyword arguments.
"""
is_final = kwargs.get("is_final", False)
cache = kwargs.get("cache", {})
if len(cache) == 0:
self.init_cache(cache)
batch_size = input.shape[0]
assert (
batch_size == 1
), "we support to extract feature online only when the batch size is equal to 1 now"
waveforms, feats, feats_lengths = self.forward_fbank(
input, input_lengths, cache=cache
) # input shape: B T D
if feats.shape[0]:
cache["waveforms"] = torch.cat((cache["reserve_waveforms"], waveforms), dim=1)
if not cache["lfr_splice_cache"]: # 初始化splice_cache
for i in range(batch_size):
cache["lfr_splice_cache"].append(
feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1)
View full source on GitHub →WavFrontendOnline.init_cache(cache)Init cache.
def init_cache(self, cache: dict = None):
    """Reset (or create) the streaming-state dict used by this frontend.

    A passed-in dict is updated in place, so callers holding a reference see
    the fresh state. Keys not listed here are left alone. When ``cache`` is
    None a new dict is created.

    Args:
        cache: State cache dict for streaming inference, or None.

    Returns:
        The initialized dict (the same object when one was passed in).
    """
    state = {} if cache is None else cache
    fresh_entries = (
        ("reserve_waveforms", torch.empty(0)),
        ("input_cache", torch.empty(0)),
        ("lfr_splice_cache", []),
        ("waveforms", None),
        ("fbanks", None),
        ("fbanks_lens", None),
    )
    for key, value in fresh_entries:
        state[key] = value
    return state
Conventional frontend structure for ASR.
class WavFrontendMel23(nn.Module):
"""Conventional frontend structure for ASR."""
def __init__(
self,
fs: int = 16000,
frame_length: int = 25,
frame_shift: int = 10,
lfr_m: int = 1,
lfr_n: int = 1,
**kwargs,
):
"""Initialize WavFrontendMel23.
Args:
fs: TODO.
frame_length: TODO.
frame_shift: TODO.
lfr_m: TODO.
lfr_n: TODO.
**kwargs: Additional keyword arguments.
"""
super().__init__()
self.fs = fs
self.frame_length = frame_length
self.frame_shift = frame_shift
self.lfr_m = lfr_m
self.lfr_n = lfr_n
self.n_mels = 23
View full source on GitHub →.output_size() L641Output size.
def output_size(self) -> int:
    """Return the spliced feature size: ``n_mels * (2 * lfr_m + 1)``.

    Presumably ``lfr_m`` context frames on each side plus the centre frame,
    matching ``splice(..., context_size=lfr_m)`` in ``forward``.
    """
    context_frames = 2 * self.lfr_m + 1
    return self.n_mels * context_frames
.forward(input, input_lengths) L645Forward pass for training.
def forward(
    self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute spliced, frame-subsampled features for a padded waveform batch.

    For each utterance, the valid samples ``input[i][:input_lengths[i]]`` are
    converted to NumPy and passed through ``eend_ola_feature.stft``, then
    ``transform``, then ``splice(context_size=lfr_m)``. Every ``lfr_n``-th
    resulting frame is kept.

    Args:
        input: Zero-padded waveforms, presumably ``(B, T_max)`` — TODO confirm.
        input_lengths: Number of valid samples in each row of ``input``.

    Returns:
        Tuple ``(feats_pad, feats_lens)``: batch-first zero-padded features
        (CPU tensors, built with ``torch.from_numpy``) and a tensor of
        per-utterance frame counts.
    """
    feats = []
    feats_lens = []
    for i in range(input.size(0)):
        waveform = input[i][: input_lengths[i]]
        # Tensor.numpy() only works on CPU tensors without grad tracking;
        # detach + cpu lets GPU / autograd-tracked batches through as well
        # (both are no-ops for a plain CPU tensor).
        waveform = waveform.detach().cpu().numpy()
        mat = eend_ola_feature.stft(waveform, self.frame_length, self.frame_shift)
        mat = eend_ola_feature.transform(mat)
        mat = eend_ola_feature.splice(mat, context_size=self.lfr_m)
        mat = mat[:: self.lfr_n]
        mat = torch.from_numpy(mat)
        feats.append(mat)
        feats_lens.append(mat.size(0))
    feats_lens = torch.as_tensor(feats_lens)
    feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
    return feats_pad, feats_lens
WavFrontendMel23.output_size()Output size.
def output_size(self) -> int:
    """Return the spliced feature size: ``n_mels * (2 * lfr_m + 1)``.

    Presumably ``lfr_m`` context frames on each side plus the centre frame,
    matching ``splice(..., context_size=lfr_m)`` in ``forward``.
    """
    context_frames = 2 * self.lfr_m + 1
    return self.n_mels * context_frames
WavFrontendMel23.forward(input, input_lengths)Forward pass for training.
def forward(
    self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute spliced, frame-subsampled features for a padded waveform batch.

    For each utterance, the valid samples ``input[i][:input_lengths[i]]`` are
    converted to NumPy and passed through ``eend_ola_feature.stft``, then
    ``transform``, then ``splice(context_size=lfr_m)``. Every ``lfr_n``-th
    resulting frame is kept.

    Args:
        input: Zero-padded waveforms, presumably ``(B, T_max)`` — TODO confirm.
        input_lengths: Number of valid samples in each row of ``input``.

    Returns:
        Tuple ``(feats_pad, feats_lens)``: batch-first zero-padded features
        (CPU tensors, built with ``torch.from_numpy``) and a tensor of
        per-utterance frame counts.
    """
    feats = []
    feats_lens = []
    for i in range(input.size(0)):
        waveform = input[i][: input_lengths[i]]
        # Tensor.numpy() only works on CPU tensors without grad tracking;
        # detach + cpu lets GPU / autograd-tracked batches through as well
        # (both are no-ops for a plain CPU tensor).
        waveform = waveform.detach().cpu().numpy()
        mat = eend_ola_feature.stft(waveform, self.frame_length, self.frame_shift)
        mat = eend_ola_feature.transform(mat)
        mat = eend_ola_feature.splice(mat, context_size=self.lfr_m)
        mat = mat[:: self.lfr_n]
        mat = torch.from_numpy(mat)
        feats.append(mat)
        feats_lens.append(mat.size(0))
    feats_lens = torch.as_tensor(feats_lens)
    feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
    return feats_pad, feats_lens
Speech Representation Using Encoder Outputs from OpenAI's Whisper Model:
URL — https://github.com/openai/whisperclass WhisperFrontend(nn.Module):
"""Speech Representation Using Encoder Outputs from OpenAI's Whisper Model:
URL: https://github.com/openai/whisper
"""
def __init__(
self,
fs: int = 16000,
whisper_model: str = None,
do_pad_trim: bool = True,
n_mels: int = 80,
permute: bool = False,
**kwargs,
):
"""Initialize WhisperFrontend.
Args:
fs: TODO.
whisper_model: Whisper Model instance.
do_pad_trim: TODO.
n_mels: TODO.
permute: TODO.
**kwargs: Additional keyword arguments.
"""
super().__init__()
assert fs == 16000
self.fs = fs
import whisper
from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES
View full source on GitHub →.output_size() L67Output size.
def output_size(self) -> int:
    """Return the feature dimension, i.e. the number of mel bins (``n_mels``)."""
    n_bins = self.n_mels
    return n_bins
.log_mel_spectrogram(audio, ilens) L71Log mel spectrogram.
audio — TODO.ilens — TODO. def log_mel_spectrogram(
self,
audio: torch.Tensor,
ilens: torch.Tensor = None,
) -> torch.Tensor:
"""Log mel spectrogram.
Args:
audio: TODO.
ilens: TODO.
"""
window = torch.hann_window(self.win_length).to(audio.device)
stft = torch.stft(audio, self.n_fft, self.hop_length, window=window, return_complex=True)
# whisper deletes the last frame by default (Shih-Lun)
magnitudes = stft[..., :-1].abs() ** 2
if self.filters_path is not None:
filters = self.mel_filters(audio.device, self.n_mels, self.filters_path)
else:
filters = self.mel_filters(audio.device, self.n_mels)
mel_spec = filters @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
if ilens is not None:
olens = ilens // self.hop_length
else:
olens = None
log_spec = torch.maximum(
View full source →.forward(input, input_lengths, **kwargs) L108Forward pass for training.
input — Input audio/text data.input_lengths — Lengths of input.**kwargs — Additional keyword arguments. def forward(
self,
input: torch.Tensor,
input_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass for training.
Args:
input: Input audio/text data.
input_lengths: Lengths of input.
**kwargs: Additional keyword arguments.
"""
batch_size = input.size(0)
feats = []
feats_lens = []
input = input.to(torch.float32)
for i in range(batch_size):
if self.do_pad_trim:
feat = self.pad_or_trim(input[i], self.pad_samples)
else:
feat = input[i]
feat, feat_len = self.log_mel_spectrogram(feat[None, :], input_lengths[0])
feats.append(feat[0])
feats_lens.append(feat_len)
feats_lens = torch.as_tensor(feats_lens)
if batch_size == 1:
feats_pad = feats[0][None, :, :]
else:
View full source →WhisperFrontend.output_size()Output size.
def output_size(self) -> int:
    """Return the feature dimension, i.e. the number of mel bins (``n_mels``)."""
    n_bins = self.n_mels
    return n_bins
WhisperFrontend.log_mel_spectrogram(audio, ilens)Log mel spectrogram.
audio — TODO.ilens — TODO. def log_mel_spectrogram(
self,
audio: torch.Tensor,
ilens: torch.Tensor = None,
) -> torch.Tensor:
"""Log mel spectrogram.
Args:
audio: TODO.
ilens: TODO.
"""
window = torch.hann_window(self.win_length).to(audio.device)
stft = torch.stft(audio, self.n_fft, self.hop_length, window=window, return_complex=True)
# whisper deletes the last frame by default (Shih-Lun)
magnitudes = stft[..., :-1].abs() ** 2
if self.filters_path is not None:
filters = self.mel_filters(audio.device, self.n_mels, self.filters_path)
else:
filters = self.mel_filters(audio.device, self.n_mels)
mel_spec = filters @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
if ilens is not None:
olens = ilens // self.hop_length
else:
olens = None
log_spec = torch.maximum(
View full source on GitHub →WhisperFrontend.forward(input, input_lengths, **kwargs)Forward pass for training.
input — Input audio/text data.input_lengths — Lengths of input.**kwargs — Additional keyword arguments. def forward(
self,
input: torch.Tensor,
input_lengths: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass for training.
Args:
input: Input audio/text data.
input_lengths: Lengths of input.
**kwargs: Additional keyword arguments.
"""
batch_size = input.size(0)
feats = []
feats_lens = []
input = input.to(torch.float32)
for i in range(batch_size):
if self.do_pad_trim:
feat = self.pad_or_trim(input[i], self.pad_samples)
else:
feat = input[i]
feat, feat_len = self.log_mel_spectrogram(feat[None, :], input_lengths[0])
feats.append(feat[0])
feats_lens.append(feat_len)
feats_lens = torch.as_tensor(feats_lens)
if batch_size == 1:
feats_pad = feats[0][None, :, :]
else:
View full source on GitHub →Sliding Window.
Provides a sliding window over a batched continuous raw audio tensor.
Optionally, provides padding (Currently not implemented).
Combine this module with a pre-encoder compatible with raw audio data,
for example Sinc convolutions.
Known issues:
Output length is calculated incorrectly if the audio is shorter than win_length.
WARNING — trailing values are discarded; padding is not implemented yet. There is currently no additional window function applied to input values.
class SlidingWindow(nn.Module):
"""Sliding Window.
Provides a sliding window over a batched continuous raw audio tensor.
Optionally, provides padding (Currently not implemented).
Combine this module with a pre-encoder compatible with raw audio data,
for example Sinc convolutions.
Known issues:
Output length is calculated incorrectly if audio shorter than win_length.
WARNING: trailing values are discarded - padding not implemented yet.
There is currently no additional window function applied to input values.
"""
def __init__(
self,
win_length: int = 400,
hop_length: int = 160,
channels: int = 1,
padding: int = None,
fs=None,
):
"""Initialize.
Args:
win_length: Length of frame.
hop_length: Relative starting point of next frame.
channels: Number of input channels.
padding: Padding (placeholder, currently not implemented).
fs: Sampling rate (placeholder for compatibility, not used).
"""
super().__init__()
self.fs = fs
View full source on GitHub →.forward(input, input_lengths) L47Apply a sliding window on the input.
def forward(
    self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply a sliding window on the input.

    Args:
        input: Input (B, T, C*D) or (B, T*C*D), with D=C=1.
        input_lengths: Input lengths within batch.

    Returns:
        Tensor: Output with dimensions (B, T', C, D), with D=win_length.
        Tensor: Number of complete windows per utterance. Utterances shorter
            than ``win_length`` get 0 instead of a negative length.
    """
    batch, time = input.size(0), input.size(1)
    # (B, T, C) --> (T, B, C) so unfold() slides along the time axis.
    continuous = input.view(batch, time, self.channels).permute(1, 0, 2)
    # (T', B, C, D): each window holds win_length consecutive samples.
    windowed = continuous.unfold(0, self.win_length, self.hop_length)
    # (T', B, C, D) --> (B, T', C, D)
    output = windowed.permute(1, 0, 2, 3).contiguous()
    # Complete windows per utterance; clamp so inputs shorter than one
    # window report 0 frames rather than a negative length.
    output_lengths = ((input_lengths - self.win_length) // self.hop_length + 1).clamp(min=0)
    return output, output_lengths
.output_size() L72Return output length of feature dimension D, i.e. the window length.
def output_size(self) -> int:
    """Report the size of the feature dimension D produced by ``forward``.

    Returns:
        int: The configured window length (``self.win_length``).
    """
    window_length = self.win_length
    return window_length
SlidingWindow.forward(input, input_lengths)Apply a sliding window on the input.
def forward(
    self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply a sliding window on the input.

    Args:
        input: Input (B, T, C*D) or (B, T*C*D), with D=C=1.
        input_lengths: Input lengths within batch.

    Returns:
        Tensor: Output with dimensions (B, T, C, D), with D=win_length.
        Tensor: Output lengths within batch. Utterances shorter than
            ``win_length`` get a length of 0 instead of a negative value.
    """
    input_size = input.size()
    B = input_size[0]
    T = input_size[1]
    C = self.channels
    D = self.win_length
    # (B, T, C) --> (T, B, C) so unfold() slides along the time axis.
    continuous = input.view(B, T, C).permute(1, 0, 2)
    windowed = continuous.unfold(0, D, self.hop_length)
    # (T, B, C, D) --> (B, T, C, D)
    output = windowed.permute(1, 0, 2, 3).contiguous()
    # Number of complete windows per utterance. Clamp so that padded
    # utterances shorter than one window report 0 frames rather than a
    # negative count.
    output_lengths = (input_lengths - self.win_length) // self.hop_length + 1
    output_lengths = output_lengths.clamp(min=0)
    return output, output_lengths
SlidingWindow.output_size()Return output length of feature dimension D, i.e. the window length.
def output_size(self) -> int:
    """Report the size of the feature dimension D produced by ``forward``.

    Returns:
        int: The configured window length (``self.win_length``).
    """
    window_length = self.win_length
    return window_length
build_tokenizer(token_type, bpemodel, non_linguistic_symbols, remove_non_linguistic_symbols, space_symbol, delimiter, g2p_type)A helper function to instantiate Tokenizer
def build_tokenizer(
token_type: str,
bpemodel: Union[Path, str, Iterable[str]] = None,
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
remove_non_linguistic_symbols: bool = False,
space_symbol: str = "<space>",
delimiter: str = None,
g2p_type: str = None,
) -> AbsTokenizer:
"""A helper function to instantiate Tokenizer"""
if token_type == "bpe":
if bpemodel is None:
raise ValueError('bpemodel is required if token_type = "bpe"')
if remove_non_linguistic_symbols:
raise RuntimeError(
"remove_non_linguistic_symbols is not implemented for token_type=bpe"
)
return SentencepiecesTokenizer(bpemodel)
elif token_type == "word":
if remove_non_linguistic_symbols and non_linguistic_symbols is not None:
return WordTokenizer(
delimiter=delimiter,
non_linguistic_symbols=non_linguistic_symbols,
remove_non_linguistic_symbols=True,
)
else:
return WordTokenizer(delimiter=delimiter)
View full source on GitHub →No documentation yet.
class CharTokenizer(BaseTokenizer):
def __init__(
self,
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
space_symbol: str = "<space>",
remove_non_linguistic_symbols: bool = False,
split_with_space: bool = False,
seg_dict: str = None,
**kwargs,
):
"""Initialize CharTokenizer.
Args:
non_linguistic_symbols: TODO.
space_symbol: TODO.
remove_non_linguistic_symbols: TODO.
split_with_space: TODO.
seg_dict: TODO.
**kwargs: Additional keyword arguments.
"""
super().__init__(**kwargs)
self.space_symbol = space_symbol
if non_linguistic_symbols is None:
self.non_linguistic_symbols = set()
elif isinstance(non_linguistic_symbols, (Path, str)):
non_linguistic_symbols = Path(non_linguistic_symbols)
try:
with non_linguistic_symbols.open("r", encoding="utf-8") as f:
self.non_linguistic_symbols = set(line.rstrip() for line in f)
except FileNotFoundError:
View full source on GitHub →.text2tokens(line) L63Text2tokens.
def text2tokens(self, line: Union[str, list]) -> List[str]:
    """Split ``line`` into character-level tokens.

    When ``self.seg_dict`` is set, the stripped line is split on single
    spaces and mapped through ``seg_tokenize``. Otherwise the line is
    scanned left to right:
      * an entry of ``self.non_linguistic_symbols`` found at the current
        position becomes one token (or is dropped when
        ``self.remove_non_linguistic_symbols`` is true);
      * space characters are dropped;
      * every other character becomes its own token.

    Args:
        line: Input text.

    Returns:
        List[str]: The tokens.
    """
    if self.seg_dict is not None:
        words = line.strip().split(" ")
        return seg_tokenize(words, self.seg_dict)

    # Try longer symbols first so overlapping symbols (e.g. "<s" and
    # "<sil>") resolve deterministically instead of depending on set
    # iteration order. Empty symbols are skipped: they never advance the
    # scan and would loop forever.
    symbols = sorted(
        (w for w in self.non_linguistic_symbols if w), key=len, reverse=True
    )
    tokens = []
    pos = 0
    end = len(line)
    # Advance an index instead of re-slicing ``line`` (re-slicing is quadratic).
    while pos < end:
        for w in symbols:
            if line.startswith(w, pos):
                if not self.remove_non_linguistic_symbols:
                    tokens.append(w)
                pos += len(w)
                break
        else:
            t = line[pos]
            pos += 1
            if t != " ":
                tokens.append(t)
    return tokens
.tokens2text(tokens) L92Tokens2text.
def tokens2text(self, tokens: Iterable[str]) -> str:
    """Join character tokens back into a string.

    Tokens equal to ``self.space_symbol`` are rendered as a single space;
    all other tokens are concatenated unchanged.

    Args:
        tokens: Tokens to join.

    Returns:
        str: The reconstructed text.
    """
    space = self.space_symbol
    pieces = []
    for token in tokens:
        pieces.append(" " if token == space else token)
    return "".join(pieces)
CharTokenizer.text2tokens(line)Text2tokens.
def text2tokens(self, line: Union[str, list]) -> List[str]:
    """Split ``line`` into character-level tokens.

    When ``self.seg_dict`` is set, the stripped line is split on single
    spaces and mapped through ``seg_tokenize``. Otherwise the line is
    scanned left to right:
      * an entry of ``self.non_linguistic_symbols`` found at the current
        position becomes one token (or is dropped when
        ``self.remove_non_linguistic_symbols`` is true);
      * space characters are dropped;
      * every other character becomes its own token.

    Args:
        line: Input text.

    Returns:
        List[str]: The tokens.
    """
    if self.seg_dict is not None:
        words = line.strip().split(" ")
        return seg_tokenize(words, self.seg_dict)

    # Try longer symbols first so overlapping symbols (e.g. "<s" and
    # "<sil>") resolve deterministically instead of depending on set
    # iteration order. Empty symbols are skipped: they never advance the
    # scan and would loop forever.
    symbols = sorted(
        (w for w in self.non_linguistic_symbols if w), key=len, reverse=True
    )
    tokens = []
    pos = 0
    end = len(line)
    # Advance an index instead of re-slicing ``line`` (re-slicing is quadratic).
    while pos < end:
        for w in symbols:
            if line.startswith(w, pos):
                if not self.remove_non_linguistic_symbols:
                    tokens.append(w)
                pos += len(w)
                break
        else:
            t = line[pos]
            pos += 1
            if t != " ":
                tokens.append(t)
    return tokens
CharTokenizer.tokens2text(tokens)Tokens2text.
def tokens2text(self, tokens: Iterable[str]) -> str:
    """Join character tokens back into a string.

    Tokens equal to ``self.space_symbol`` are rendered as a single space;
    all other tokens are concatenated unchanged.

    Args:
        tokens: Tokens to join.

    Returns:
        str: The reconstructed text.
    """
    space = self.space_symbol
    pieces = []
    for token in tokens:
        pieces.append(" " if token == space else token)
    return "".join(pieces)
load_seg_dict(seg_dict_file)Load seg dict.
def load_seg_dict(seg_dict_file):
    """Load a segmentation dictionary from a text file.

    Each non-blank line has the form ``<word> <unit1> <unit2> ...``. The
    word maps to its units joined by single spaces; a line with only a word
    maps it to ``""``.

    Args:
        seg_dict_file (str): Path to the UTF-8 dictionary file.

    Returns:
        dict: Mapping from word to its space-joined units. Later lines
        override earlier lines for the same word.
    """
    assert isinstance(seg_dict_file, str)
    seg_dict = {}
    with open(seg_dict_file, "r", encoding="utf8") as f:
        for line in f:
            s = line.strip().split()
            # Skip blank lines (e.g. a trailing empty line); they previously
            # raised IndexError on s[0].
            if not s:
                continue
            seg_dict[s[0]] = " ".join(s[1:])
    return seg_dict
seg_tokenize(txt, seg_dict)Seg tokenize.
def seg_tokenize(txt, seg_dict):
    """Map words to segmentation units using ``seg_dict``.

    Each word is lower-cased and looked up in ``seg_dict``. A missing word
    whose first character is a CJK ideograph (U+4E00-U+9FA5), an ASCII
    letter or a digit is looked up character by character, with ``"<unk>"``
    for unknown characters. Any other missing word becomes one ``"<unk>"``.

    Args:
        txt: Iterable of words.
        seg_dict: Mapping from word/character to space-joined units.

    Returns:
        list: Flat list of units (whitespace-split).
    """
    first_char_ok = re.compile(r"([\u4E00-\u9FA5A-Za-z0-9])")
    pieces = []
    for raw_word in txt:
        word = raw_word.lower()
        if word in seg_dict:
            pieces.append(seg_dict[word])
            continue
        if not first_char_ok.match(word):
            pieces.append("<unk>")
            continue
        pieces.extend(seg_dict[ch] if ch in seg_dict else "<unk>" for ch in word)
    return " ".join(pieces).split()
Text cleaner.
>>> cleaner = TextCleaner("tacotron")
>>> cleaner("(Hello-World); & jr. & dr.")
'HELLO WORLD, AND JUNIOR AND DOCTOR'
class TextCleaner:
"""Text cleaner.
Examples:
>>> cleaner = TextCleaner("tacotron")
>>> cleaner("(Hello-World); & jr. & dr.")
'HELLO WORLD, AND JUNIOR AND DOCTOR'
"""
def __init__(self, cleaner_types: Collection[str] = None):
"""Initialize TextCleaner.
Args:
cleaner_types: TODO.
"""
if cleaner_types is None:
self.cleaner_types = []
elif isinstance(cleaner_types, str):
self.cleaner_types = [cleaner_types]
else:
self.cleaner_types = list(cleaner_types)
def __call__(self, text: str) -> str:
"""Internal: call .
Args:
text: Text tensor or string input.
"""
View full source on GitHub →HuggingfaceTokenizer(init_param_path, **kwargs)Huggingfacetokenizer.
def HuggingfaceTokenizer(init_param_path, **kwargs):
    """Build a Hugging Face ``AutoTokenizer`` from ``init_param_path``.

    Args:
        init_param_path: Name or path passed to
            ``AutoTokenizer.from_pretrained``.
        **kwargs: Accepted for interface compatibility; not used.

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained``.

    Raises:
        ImportError: If ``transformers`` cannot be imported.
    """
    try:
        from transformers import AutoTokenizer
    except Exception as import_error:
        message = (
            "HuggingfaceTokenizer requires 'transformers'. "
            "Please install it with: pip install -U transformers"
        )
        raise ImportError(message) from import_error
    return AutoTokenizer.from_pretrained(init_param_path)
No documentation yet.
class KoreanCleaner:
@classmethod
def _normalize_numbers(cls, text):
"""Internal: normalize numbers.
Args:
text: Text tensor or string input.
"""
number_to_kor = {
"0": "영",
"1": "일",
"2": "이",
"3": "삼",
"4": "사",
"5": "오",
"6": "육",
"7": "칠",
"8": "팔",
"9": "구",
}
new_text = "".join(
number_to_kor[char] if char in number_to_kor.keys() else char for char in text
)
return new_text
@classmethod
def _normalize_english_text(cls, text):
"""Internal: normalize english text.
Args:
View full source on GitHub →.normalize_text(cls, text) L75Normalize text.
def normalize_text(cls, text):
    """Normalize Korean text.

    The text is stripped. Digits are then rewritten via
    ``cls._normalize_numbers``, and finally ``cls._normalize_english_text``
    is applied.

    Args:
        text: Input string.

    Returns:
        str: The normalized text.
    """
    stripped = text.strip()
    spelled = cls._normalize_numbers(stripped)
    return cls._normalize_english_text(spelled)
KoreanCleaner.normalize_text(cls, text)Normalize text.
def normalize_text(cls, text):
    """Normalize Korean text.

    The text is stripped. Digits are then rewritten via
    ``cls._normalize_numbers``, and finally ``cls._normalize_english_text``
    is applied.

    Args:
        text: Input string.

    Returns:
        str: The normalized text.
    """
    stripped = text.strip()
    spelled = cls._normalize_numbers(stripped)
    return cls._normalize_english_text(spelled)
split_by_space(text)Split by space.
def split_by_space(text) -> List[str]:
    """Split ``text`` on single spaces.

    If the marker literal tested below occurs in ``text``, each occurrence
    is first rewritten to ``" <space> "``. After splitting on single
    spaces, every ``"<space>"`` piece is turned back into a literal ``" "``
    token.

    Args:
        text: Input string.

    Returns:
        List[str]: The pieces of ``text`` split on single spaces.
    """
    # NOTE(review): the marker literal in the condition/replace below appears
    # as a single space in this rendering. If the original used a run of
    # several spaces, the whitespace may have been collapsed when the page
    # was generated — confirm against the upstream source.
    if " " in text:
        text = text.replace(" ", " <space> ")
        return [c.replace("<space>", " ") for c in text.split(" ")]
    else:
        return text.split(" ")
pyopenjtalk_g2p(text)Pyopenjtalk g2p.
def pyopenjtalk_g2p(text) -> List[str]:
    """Convert Japanese text to phonemes with pyopenjtalk.

    Args:
        text: Input Japanese text.

    Returns:
        List[str]: The space-separated output of
        ``pyopenjtalk.g2p(text, kana=False)``, split on single spaces.
    """
    import pyopenjtalk

    # pyopenjtalk.g2p returns a single space-separated string.
    phone_string = pyopenjtalk.g2p(text, kana=False)
    return phone_string.split(" ")
pyopenjtalk_g2p_accent(text)Pyopenjtalk g2p accent.
def pyopenjtalk_g2p_accent(text) -> List[str]:
    """Extract phonemes plus accent fields from pyopenjtalk labels.

    For every full-context label returned by ``pyopenjtalk.run_frontend``,
    the current phoneme, the ``/A:`` field and the trailing ``/F:`` number
    are captured. Labels that do not match exactly once are skipped.

    Args:
        text: Input Japanese text.

    Returns:
        List[str]: Flattened ``[phoneme, F-field, A-field]`` triples.
    """
    import pyopenjtalk
    import re

    label_pattern = r"\-(.*?)\+.*?\/A:([0-9\-]+).*?\/F:.*?_([0-9]+)"
    phones = []
    for label in pyopenjtalk.run_frontend(text)[1]:
        matches = re.findall(label_pattern, label)
        if len(matches) != 1:
            continue
        phoneme, a_field, f_field = matches[0]
        phones.extend([phoneme, f_field, a_field])
    return phones
pyopenjtalk_g2p_accent_with_pause(text)Pyopenjtalk g2p accent with pause.
def pyopenjtalk_g2p_accent_with_pause(text) -> List[str]:
    """Like ``pyopenjtalk_g2p_accent`` but keep pauses.

    A label whose current phoneme (between the first ``-`` and the next
    ``+``) is ``"pau"`` emits a single ``"pau"`` token. Other labels
    contribute ``[phoneme, F-field, A-field]`` when the pattern matches
    exactly once.

    Args:
        text: Input Japanese text.

    Returns:
        List[str]: Phoneme/accent tokens with ``"pau"`` markers.
    """
    import pyopenjtalk
    import re

    label_pattern = r"\-(.*?)\+.*?\/A:([0-9\-]+).*?\/F:.*?_([0-9]+)"
    phones = []
    for label in pyopenjtalk.run_frontend(text)[1]:
        current = label.split("-")[1].split("+")[0]
        if current == "pau":
            phones.append("pau")
            continue
        matches = re.findall(label_pattern, label)
        if len(matches) == 1:
            phoneme, a_field, f_field = matches[0]
            phones.extend([phoneme, f_field, a_field])
    return phones
pyopenjtalk_g2p_kana(text)Pyopenjtalk g2p kana.
def pyopenjtalk_g2p_kana(text) -> List[str]:
    """Convert Japanese text to a kana string and split it into characters.

    Args:
        text: Input Japanese text.

    Returns:
        List[str]: One token per character of
        ``pyopenjtalk.g2p(text, kana=True)``.
    """
    import pyopenjtalk

    kana_string = pyopenjtalk.g2p(text, kana=True)
    return [ch for ch in kana_string]
pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels) Extract phoneme + prosody symbol sequence from input full-context labels.
The algorithm is based on `Prosodic features control by symbols as input of
sequence-to-sequence acoustic modeling for neural TTS`_ with some r9y9's tweaks. text (str) — Input text. drop_unvoiced_vowels (bool) — whether to drop unvoiced vowels. Returns List[str]: List of phoneme + prosody symbols.
>>> from funasr.tokenizer.phoneme_tokenizer import pyopenjtalk_g2p_prosody
>>> pyopenjtalk_g2p_prosody("こんにちは。")
['^', 'k', 'o', '[', 'N', 'n', 'i', 'ch', 'i', 'w', 'a', '$']
.. _`Prosodic features control by symbols as input of sequence-to-sequence acoustic
modeling for neural TTS`: https://doi.org/10.1587/transinf.2020EDP7104
def pyopenjtalk_g2p_prosody(text: str, drop_unvoiced_vowels: bool = True) -> List[str]:
"""Extract phoneme + prosoody symbol sequence from input full-context labels.
The algorithm is based on `Prosodic features control by symbols as input of
sequence-to-sequence acoustic modeling for neural TTS`_ with some r9y9's tweaks.
Args:
text (str): Input text.
drop_unvoiced_vowels (bool): whether to drop unvoiced vowels.
Returns:
List[str]: List of phoneme + prosody symbols.
Examples:
>>> from funasr.tokenizer.phoneme_tokenizer import pyopenjtalk_g2p_prosody
>>> pyopenjtalk_g2p_prosody("こんにちは。")
['^', 'k', 'o', '[', 'N', 'n', 'i', 'ch', 'i', 'w', 'a', '$']
.. _`Prosodic features control by symbols as input of sequence-to-sequence acoustic
modeling for neural TTS`: https://doi.org/10.1587/transinf.2020EDP7104
"""
import pyopenjtalk
labels = pyopenjtalk.run_frontend(text)[1]
N = len(labels)
phones = []
for n in range(N):
lab_curr = labels[n]
View full source on GitHub →pypinyin_g2p(text)Pypinyin g2p.
def pypinyin_g2p(text) -> List[str]:
    """Convert Chinese text to TONE3-style pinyin syllables.

    Args:
        text: Input Chinese text.

    Returns:
        List[str]: The first entry of each item returned by
        ``pypinyin.pinyin(text, style=Style.TONE3)``.
    """
    from pypinyin import Style, pinyin

    syllables = pinyin(text, style=Style.TONE3)
    return [candidates[0] for candidates in syllables]
pypinyin_g2p_phone(text)Pypinyin g2p phone.
def pypinyin_g2p_phone(text) -> List[str]:
    """Convert Chinese text to pinyin initials and finals.

    Each TONE3 syllable is split into its initial and final (strict mode).
    Empty parts are dropped.

    Args:
        text: Input Chinese text.

    Returns:
        List[str]: Initials and finals in order.
    """
    from pypinyin import Style, pinyin
    from pypinyin.style._utils import get_finals, get_initials

    phones = []
    for candidates in pinyin(text, style=Style.TONE3):
        syllable = candidates[0]
        parts = (
            get_initials(syllable, strict=True),
            get_finals(syllable, strict=True),
        )
        for part in parts:
            if len(part) != 0:
                phones.append(part)
    return phones
On behalf of g2p_en.G2p.
g2p_en.G2p isn't picklable, so it can't be copied to other processes
via the multiprocessing module.
As a workaround, g2p_en.G2p is instantiated lazily, on the first call to this object.
class G2p_en:
"""On behalf of g2p_en.G2p.
g2p_en.G2p isn't pickalable and it can't be copied to the other processes
via multiprocessing module.
As a workaround, g2p_en.G2p is instantiated upon calling this class.
"""
def __init__(self, no_space: bool = False):
"""Initialize G2p_en.
Args:
no_space: TODO.
"""
self.no_space = no_space
self.g2p = None
def __call__(self, text) -> List[str]:
"""Internal: call .
Args:
text: Text tensor or string input.
"""
if self.g2p is None:
self.g2p = g2p_en.G2p()
phones = self.g2p(text)
if self.no_space:
# remove space which represents word serapater
View full source on GitHub →On behalf of g2pk.G2p.
g2pk.G2p isn't picklable, so it can't be copied to other processes
via the multiprocessing module.
As a workaround, g2pk.G2p is instantiated when this object is called.
class G2pk:
"""On behalf of g2pk.G2p.
g2pk.G2p isn't pickalable and it can't be copied to the other processes
via multiprocessing module.
As a workaround, g2pk.G2p is instantiated upon calling this class.
"""
def __init__(self, descritive=False, group_vowels=False, to_syl=False, no_space=False):
"""Initialize G2pk.
Args:
descritive: TODO.
group_vowels: TODO.
to_syl: TODO.
no_space: TODO.
"""
self.descritive = descritive
self.group_vowels = group_vowels
self.to_syl = to_syl
self.no_space = no_space
self.g2p = None
def __call__(self, text) -> List[str]:
"""Internal: call .
Args:
text: Text tensor or string input.
"""
View full source on GitHub →No documentation yet.
class Jaso:
PUNC = "!'(),-.:;?"
SPACE = " "
JAMO_LEADS = "".join([chr(_) for _ in range(0x1100, 0x1113)])
JAMO_VOWELS = "".join([chr(_) for _ in range(0x1161, 0x1176)])
JAMO_TAILS = "".join([chr(_) for _ in range(0x11A8, 0x11C3)])
VALID_CHARS = JAMO_LEADS + JAMO_VOWELS + JAMO_TAILS + PUNC + SPACE
def __init__(self, space_symbol=" ", no_space=False):
"""Initialize Jaso.
Args:
space_symbol: TODO.
no_space: TODO.
"""
self.space_symbol = space_symbol
self.no_space = no_space
def _text_to_jaso(self, line: str) -> List[str]:
"""Internal: text to jaso.
Args:
line: TODO.
"""
jasos = list(jamo.hangul_to_jamo(line))
return jasos
def _remove_non_korean_characters(self, tokens):
View full source on GitHub →Phonemizer module for various languages.
This is a wrapper module for https://github.com/bootphon/phonemizer.
You can define various g2p modules by specifying options for phonemizer.
See available options:
https://github.com/bootphon/phonemizer/blob/master/phonemizer/phonemize.py#L32
class Phonemizer:
"""Phonemizer module for various languages.
This is wrapper module of https://github.com/bootphon/phonemizer.
You can define various g2p modules by specifying options for phonemizer.
See available options:
https://github.com/bootphon/phonemizer/blob/master/phonemizer/phonemize.py#L32
"""
def __init__(
self,
backend,
word_separator: Optional[str] = None,
syllable_separator: Optional[str] = None,
phone_separator: Optional[str] = " ",
strip=False,
split_by_single_token: bool = False,
**phonemizer_kwargs,
):
# delayed import
"""Initialize Phonemizer.
Args:
backend: TODO.
word_separator: TODO.
syllable_separator: TODO.
phone_separator: TODO.
strip: TODO.
View full source on GitHub →No documentation yet.
class PhonemeTokenizer(AbsTokenizer):
def __init__(
self,
g2p_type: Union[None, str],
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
space_symbol: str = "<space>",
remove_non_linguistic_symbols: bool = False,
):
"""Initialize PhonemeTokenizer.
Args:
g2p_type: TODO.
non_linguistic_symbols: TODO.
space_symbol: TODO.
remove_non_linguistic_symbols: TODO.
"""
if g2p_type is None:
self.g2p = split_by_space
elif g2p_type == "g2p_en":
self.g2p = G2p_en(no_space=False)
elif g2p_type == "g2p_en_no_space":
self.g2p = G2p_en(no_space=True)
elif g2p_type == "pyopenjtalk":
self.g2p = pyopenjtalk_g2p
elif g2p_type == "pyopenjtalk_kana":
self.g2p = pyopenjtalk_g2p_kana
elif g2p_type == "pyopenjtalk_accent":
self.g2p = pyopenjtalk_g2p_accent
elif g2p_type == "pyopenjtalk_accent_with_pause":
self.g2p = pyopenjtalk_g2p_accent_with_pause
View full source on GitHub →.text2tokens(line) L614Text2tokens.
def text2tokens(self, line: str) -> List[str]:
    """Convert ``line`` to phoneme tokens with ``self.g2p``.

    When ``self.remove_non_linguistic_symbols`` is true, every occurrence of
    an entry of ``self.non_linguistic_symbols`` is first deleted from the
    text. Otherwise the text is passed to ``self.g2p`` unchanged.

    Args:
        line: Input text.

    Returns:
        List[str]: The output of ``self.g2p``.
    """
    # Longest symbols first: overlapping symbols (e.g. "<s" and "<sil>")
    # then resolve deterministically instead of by set iteration order.
    # Empty symbols are skipped because they would never advance the scan.
    symbols = sorted(
        (w for w in self.non_linguistic_symbols if w), key=len, reverse=True
    )
    if self.remove_non_linguistic_symbols and symbols:
        kept = []
        pos = 0
        end = len(line)
        while pos < end:
            for w in symbols:
                if line.startswith(w, pos):
                    pos += len(w)
                    break
            else:
                kept.append(line[pos])
                pos += 1
        line = "".join(kept)
    return self.g2p(line)
.tokens2text(tokens) L637Tokens2text.
def tokens2text(self, tokens: Iterable[str]) -> str:
    """Concatenate phoneme tokens without a separator.

    Phoneme tokenization is not invertible, so this is only a plain join.

    Args:
        tokens: Phoneme tokens.

    Returns:
        str: The tokens concatenated.
    """
    separator = ""
    return separator.join(tokens)
PhonemeTokenizer.text2tokens(line)Text2tokens.
def text2tokens(self, line: str) -> List[str]:
    """Convert ``line`` to phoneme tokens with ``self.g2p``.

    When ``self.remove_non_linguistic_symbols`` is true, every occurrence of
    an entry of ``self.non_linguistic_symbols`` is first deleted from the
    text. Otherwise the text is passed to ``self.g2p`` unchanged.

    Args:
        line: Input text.

    Returns:
        List[str]: The output of ``self.g2p``.
    """
    # Longest symbols first: overlapping symbols (e.g. "<s" and "<sil>")
    # then resolve deterministically instead of by set iteration order.
    # Empty symbols are skipped because they would never advance the scan.
    symbols = sorted(
        (w for w in self.non_linguistic_symbols if w), key=len, reverse=True
    )
    if self.remove_non_linguistic_symbols and symbols:
        kept = []
        pos = 0
        end = len(line)
        while pos < end:
            for w in symbols:
                if line.startswith(w, pos):
                    pos += len(w)
                    break
            else:
                kept.append(line[pos])
                pos += 1
        line = "".join(kept)
    return self.g2p(line)
PhonemeTokenizer.tokens2text(tokens)Tokens2text.
def tokens2text(self, tokens: Iterable[str]) -> str:
    """Concatenate phoneme tokens without a separator.

    Phoneme tokenization is not invertible, so this is only a plain join.

    Args:
        tokens: Phoneme tokens.

    Returns:
        str: The tokens concatenated.
    """
    separator = ""
    return separator.join(tokens)
No documentation yet.
class SentencepiecesTokenizer(BaseTokenizer):
def __init__(self, bpemodel: Union[Path, str], **kwargs):
"""Initialize SentencepiecesTokenizer.
Args:
bpemodel: TODO.
**kwargs: Additional keyword arguments.
"""
super().__init__(**kwargs)
self.bpemodel = str(bpemodel)
# NOTE(kamo):
# Don't build SentencePieceProcessor in __init__()
# because it's not picklable and it may cause following error,
# "TypeError: can't pickle SwigPyObject objects",
# when giving it as argument of "multiprocessing.Process()".
self.sp = None
self._build_sentence_piece_processor()
def __repr__(self):
"""Internal: repr ."""
return f'{self.__class__.__name__}(model="{self.bpemodel}")'
def _build_sentence_piece_processor(self):
# Build SentencePieceProcessor lazily.
"""Internal: build sentence piece processor."""
if self.sp is None:
self.sp = spm.SentencePieceProcessor()
self.sp.load(self.bpemodel)
def text2tokens(self, line: str) -> List[str]:
View full source on GitHub →.text2tokens(line) L42Text2tokens.
def text2tokens(self, line: str) -> List[str]:
    """Encode ``line`` into SentencePiece piece strings.

    The processor is built on demand before encoding.

    Args:
        line: Input text.

    Returns:
        List[str]: The pieces from ``self.sp.EncodeAsPieces``.
    """
    self._build_sentence_piece_processor()
    processor = self.sp
    pieces = processor.EncodeAsPieces(line)
    return pieces
.tokens2text(tokens) L51Tokens2text.
def tokens2text(self, tokens: Iterable[str]) -> str:
    """Decode SentencePiece pieces back into text.

    The processor is built on demand; ``tokens`` is materialized into a
    list before being passed to ``self.sp.DecodePieces``.

    Args:
        tokens: SentencePiece pieces.

    Returns:
        str: The decoded text.
    """
    self._build_sentence_piece_processor()
    piece_list = list(tokens)
    return self.sp.DecodePieces(piece_list)
.encode(line, **kwargs) L60Encode.
def encode(self, line: str, **kwargs) -> List[int]:
    """Encode ``line`` into SentencePiece ids.

    Args:
        line: Input text.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        List[int]: Ids from ``self.sp.EncodeAsIds``.
    """
    self._build_sentence_piece_processor()
    processor = self.sp
    return processor.EncodeAsIds(line)
.decode(line, **kwargs) L70Decode.
def decode(self, line: List[int], **kwargs):
    """Decode SentencePiece ids back into text.

    Args:
        line: Sequence of ids.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        The result of ``self.sp.DecodeIds``.
    """
    self._build_sentence_piece_processor()
    processor = self.sp
    return processor.DecodeIds(line)
.get_vocab_size() L80Get vocab size.
def get_vocab_size(self):
    """Return the number of pieces in the SentencePiece vocabulary.

    Like every other method of this tokenizer, the processor is built on
    demand first, so this also works when ``self.sp`` is still ``None``.
    """
    self._build_sentence_piece_processor()
    return self.sp.GetPieceSize()
.ids2tokens(*args, **kwargs) L84Ids2tokens.
def ids2tokens(self, *args, **kwargs):
    """Alias for ``self.decode``: forward all arguments and return its result."""
    decode = self.decode
    return decode(*args, **kwargs)
.tokens2ids(*args, **kwargs) L93Tokens2ids.
def tokens2ids(self, *args, **kwargs):
    """Alias for ``self.encode``: forward all arguments and return its result."""
    encode = self.encode
    return encode(*args, **kwargs)
SentencepiecesTokenizer.text2tokens(line)Text2tokens.
def text2tokens(self, line: str) -> List[str]:
    """Encode ``line`` into SentencePiece piece strings.

    The processor is built on demand before encoding.

    Args:
        line: Input text.

    Returns:
        List[str]: The pieces from ``self.sp.EncodeAsPieces``.
    """
    self._build_sentence_piece_processor()
    processor = self.sp
    pieces = processor.EncodeAsPieces(line)
    return pieces
SentencepiecesTokenizer.tokens2text(tokens)Tokens2text.
def tokens2text(self, tokens: Iterable[str]) -> str:
    """Decode SentencePiece pieces back into text.

    The processor is built on demand; ``tokens`` is materialized into a
    list before being passed to ``self.sp.DecodePieces``.

    Args:
        tokens: SentencePiece pieces.

    Returns:
        str: The decoded text.
    """
    self._build_sentence_piece_processor()
    piece_list = list(tokens)
    return self.sp.DecodePieces(piece_list)
SentencepiecesTokenizer.encode(line, **kwargs)Encode.
def encode(self, line: str, **kwargs) -> List[int]:
    """Encode ``line`` into SentencePiece ids.

    Args:
        line: Input text.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        List[int]: Ids from ``self.sp.EncodeAsIds``.
    """
    self._build_sentence_piece_processor()
    processor = self.sp
    return processor.EncodeAsIds(line)
SentencepiecesTokenizer.decode(line, **kwargs)Decode.
def decode(self, line: List[int], **kwargs):
    """Decode SentencePiece ids back into text.

    Args:
        line: Sequence of ids.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        The result of ``self.sp.DecodeIds``.
    """
    self._build_sentence_piece_processor()
    processor = self.sp
    return processor.DecodeIds(line)
SentencepiecesTokenizer.get_vocab_size()Get vocab size.
def get_vocab_size(self):
    """Return the number of pieces in the SentencePiece vocabulary.

    Like every other method of this tokenizer, the processor is built on
    demand first, so this also works when ``self.sp`` is still ``None``.
    """
    self._build_sentence_piece_processor()
    return self.sp.GetPieceSize()
SentencepiecesTokenizer.ids2tokens(*args, **kwargs)Ids2tokens.
def ids2tokens(self, *args, **kwargs):
    """Alias for ``self.decode``: forward all arguments and return its result."""
    decode = self.decode
    return decode(*args, **kwargs)
SentencepiecesTokenizer.tokens2ids(*args, **kwargs)Tokens2ids.
def tokens2ids(self, *args, **kwargs):
    """Alias for ``self.encode``: forward all arguments and return its result."""
    encode = self.encode
    return encode(*args, **kwargs)
No documentation yet.
class TokenIDConverter:
def __init__(
self,
token_list: Union[Path, str, Iterable[str]],
unk_symbol: str = "<unk>",
):
"""Initialize TokenIDConverter.
Args:
token_list: TODO.
unk_symbol: TODO.
"""
if isinstance(token_list, (Path, str)):
token_list = Path(token_list)
self.token_list_repr = str(token_list)
self.token_list: List[str] = []
with token_list.open("r", encoding="utf-8") as f:
for idx, line in enumerate(f):
line = line.rstrip()
self.token_list.append(line)
else:
self.token_list: List[str] = list(token_list)
self.token_list_repr = ""
for i, t in enumerate(self.token_list):
if i == 3:
break
self.token_list_repr += f"{t}, "
View full source on GitHub →.get_num_vocabulary_size() L53Get num vocabulary size.
def get_num_vocabulary_size(self) -> int:
    """Return how many entries ``self.token_list`` holds."""
    vocabulary = self.token_list
    return len(vocabulary)
.ids2tokens(integers) L57Ids2tokens.
def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
    """Map token ids to token strings via ``self.token_list``.

    Args:
        integers: Token ids; a NumPy array must be one-dimensional.

    Returns:
        List[str]: ``self.token_list[i]`` for each id, in order.

    Raises:
        ValueError: If ``integers`` is an ndarray with ``ndim != 1``.
    """
    if isinstance(integers, np.ndarray):
        if integers.ndim != 1:
            raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
    tokens = []
    for index in integers:
        tokens.append(self.token_list[index])
    return tokens
.tokens2ids(tokens) L67Tokens2ids.
def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
    """Map token strings to ids, falling back to ``self.unk_id``.

    Args:
        tokens: Token strings.

    Returns:
        List[int]: ``self.token2id[token]`` or ``self.unk_id`` when missing.
    """
    ids = []
    for token in tokens:
        ids.append(self.token2id.get(token, self.unk_id))
    return ids
TokenIDConverter.get_num_vocabulary_size()Get num vocabulary size.
def get_num_vocabulary_size(self) -> int:
    """Return how many entries ``self.token_list`` holds."""
    vocabulary = self.token_list
    return len(vocabulary)
TokenIDConverter.ids2tokens(integers)Ids2tokens.
def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
    """Map token ids to token strings via ``self.token_list``.

    Args:
        integers: Token ids; a NumPy array must be one-dimensional.

    Returns:
        List[str]: ``self.token_list[i]`` for each id, in order.

    Raises:
        ValueError: If ``integers`` is an ndarray with ``ndim != 1``.
    """
    if isinstance(integers, np.ndarray):
        if integers.ndim != 1:
            raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
    tokens = []
    for index in integers:
        tokens.append(self.token_list[index])
    return tokens
TokenIDConverter.tokens2ids(tokens)Tokens2ids.
def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
    """Map token strings to ids, falling back to ``self.unk_id``.

    Args:
        tokens: Token strings.

    Returns:
        List[int]: ``self.token2id[token]`` or ``self.unk_id`` when missing.
    """
    ids = []
    for token in tokens:
        ids.append(self.token2id.get(token, self.unk_id))
    return ids
WhisperTokenizer(**kwargs)Whispertokenizer.
def WhisperTokenizer(**kwargs):
    """Build an OpenAI Whisper tokenizer.

    Args:
        **kwargs: Optional settings:
            language (str, optional): Passed as ``language``; default ``None``.
            task (str): Passed as ``task``; default ``"transcribe"``.
            is_multilingual (bool): Passed as ``multilingual``; default ``True``.
            num_languages (int): Passed as ``num_languages``; default ``99``.

    Returns:
        The tokenizer returned by ``whisper.tokenizer.get_tokenizer``.

    Raises:
        ImportError: If ``whisper.tokenizer`` cannot be imported.
    """
    try:
        from whisper.tokenizer import get_tokenizer
    except Exception as e:
        # Previously the failure was only printed and execution continued,
        # which then crashed with a confusing NameError on ``get_tokenizer``.
        raise ImportError(
            "WhisperTokenizer requires 'openai-whisper'. "
            "Please install it with: pip install -U openai-whisper"
        ) from e

    language = kwargs.get("language", None)
    task = kwargs.get("task", "transcribe")
    is_multilingual = kwargs.get("is_multilingual", True)
    num_languages = kwargs.get("num_languages", 99)
    tokenizer = get_tokenizer(
        multilingual=is_multilingual,
        num_languages=num_languages,
        language=language,
        task=task,
    )
    return tokenizer
SenseVoiceTokenizer(**kwargs)Sensevoicetokenizer.
def SenseVoiceTokenizer(**kwargs):
    """Build the SenseVoice tokenizer from FunASR's bundled Whisper library.

    Args:
        **kwargs: Optional settings:
            language: default ``None``.
            task: default ``None``.
            is_multilingual: passed as ``multilingual``; default ``True``.
            num_languages: default ``8749``.
            vocab_path: default ``None``.

    Returns:
        The tokenizer returned by
        ``funasr.models.sense_voice.whisper_lib.tokenizer.get_tokenizer``.
    """
    from funasr.models.sense_voice.whisper_lib.tokenizer import get_tokenizer

    return get_tokenizer(
        multilingual=kwargs.get("is_multilingual", True),
        num_languages=kwargs.get("num_languages", 8749),
        language=kwargs.get("language", None),
        task=kwargs.get("task", None),
        vocab_path=kwargs.get("vocab_path", None),
    )
No documentation yet.
class WordTokenizer(AbsTokenizer):
def __init__(
self,
delimiter: str = None,
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
remove_non_linguistic_symbols: bool = False,
):
"""Initialize WordTokenizer.
Args:
delimiter: TODO.
non_linguistic_symbols: TODO.
remove_non_linguistic_symbols: TODO.
"""
self.delimiter = delimiter
if not remove_non_linguistic_symbols and non_linguistic_symbols is not None:
warnings.warn(
"non_linguistic_symbols is only used " "when remove_non_linguistic_symbols = True"
)
if non_linguistic_symbols is None:
self.non_linguistic_symbols = set()
elif isinstance(non_linguistic_symbols, (Path, str)):
non_linguistic_symbols = Path(non_linguistic_symbols)
try:
with non_linguistic_symbols.open("r", encoding="utf-8") as f:
self.non_linguistic_symbols = set(line.rstrip() for line in f)
except FileNotFoundError:
warnings.warn(f"{non_linguistic_symbols} doesn't exist.")
View full source on GitHub →.text2tokens(line) L50Text2tokens.
def text2tokens(self, line: str) -> List[str]:
    """Split ``line`` into words on ``self.delimiter``.

    ``str.split`` semantics apply: a ``None`` delimiter splits on runs of
    whitespace. When ``self.remove_non_linguistic_symbols`` is true, words
    found in ``self.non_linguistic_symbols`` are dropped.

    Args:
        line: Input text.

    Returns:
        List[str]: The words.
    """
    return [
        piece
        for piece in line.split(self.delimiter)
        if not (
            self.remove_non_linguistic_symbols
            and piece in self.non_linguistic_symbols
        )
    ]
.tokens2text(tokens) L63Tokens2text.
def tokens2text(self, tokens: Iterable[str]) -> str:
    """Join words with ``self.delimiter`` (a single space when it is ``None``).

    Args:
        tokens: Words to join.

    Returns:
        str: The joined text.
    """
    joiner = " " if self.delimiter is None else self.delimiter
    return joiner.join(tokens)
WordTokenizer.text2tokens(line)Text2tokens.
def text2tokens(self, line: str) -> List[str]:
    """Split ``line`` into words on ``self.delimiter``.

    ``str.split`` semantics apply: a ``None`` delimiter splits on runs of
    whitespace. When ``self.remove_non_linguistic_symbols`` is true, words
    found in ``self.non_linguistic_symbols`` are dropped.

    Args:
        line: Input text.

    Returns:
        List[str]: The words.
    """
    return [
        piece
        for piece in line.split(self.delimiter)
        if not (
            self.remove_non_linguistic_symbols
            and piece in self.non_linguistic_symbols
        )
    ]
WordTokenizer.tokens2text(tokens)Tokens2text.
def tokens2text(self, tokens: Iterable[str]) -> str:
    """Join words with ``self.delimiter`` (a single space when it is ``None``).

    Args:
        tokens: Words to join.

    Returns:
        str: The joined text.
    """
    joiner = " " if self.delimiter is None else self.delimiter
    return joiner.join(tokens)
No documentation yet.
class thread_wrapper(threading.Thread):
    """Thread that runs ``func(*args)`` and keeps its return value.

    Call ``start()`` and then ``join()``; afterwards ``get_result()`` returns
    what ``func`` returned. Until ``run`` has completed, the result is ``[]``.
    If ``func`` raises, the result is left unchanged.
    """

    def __init__(self, func, args=()):
        """Initialize thread_wrapper.

        Args:
            func: Callable to execute in the thread.
            args: Positional arguments passed to ``func``.
        """
        super().__init__()
        self.func = func
        self.args = args
        self.result = []  # placeholder until run() stores func's return value

    def run(self):
        """Call ``func(*args)`` and store its return value in ``self.result``."""
        self.result = self.func(*self.args)

    def get_result(self):
        """Return the value stored by ``run`` (``[]`` until ``run`` finishes)."""
        # ``self.result`` is always set in ``__init__``; the former try/except
        # around this plain attribute read could never trigger.
        return self.result
.get_result() L31Get result.
def get_result(self):
    """Return the value stored by ``run`` (``[]`` until ``run`` finishes)."""
    # ``self.result`` is initialized in ``__init__``. The former try/except
    # around a plain attribute read could never trigger and only risked
    # masking errors as ``None``.
    return self.result
thread_wrapper.run()Run.
def run(self):
    """Thread body: call ``self.func(*self.args)`` and store the return value in ``self.result``."""
    func, call_args = self.func, self.args
    self.result = func(*call_args)
thread_wrapper.get_result()Get result.
def get_result(self):
    """Return the value stored by ``run`` (``[]`` until ``run`` finishes)."""
    # ``self.result`` is initialized in ``__init__``. The former try/except
    # around a plain attribute read could never trigger and only risked
    # masking errors as ``None``.
    return self.result
space_mixed_label(input_str)Space mixed label.
def space_mixed_label(input_str):
    """Return the pieces from ``split_mixed_label(input_str)`` joined by single spaces.

    Args:
        input_str: Label string to split.

    Returns:
        str: Space-joined pieces with surrounding whitespace stripped.
    """
    pieces = split_mixed_label(input_str)
    return " ".join(f"{sub}" for sub in pieces).strip()
read_lists(list_file)Read lists.
def read_lists(list_file):
    """Read the non-blank lines of a UTF-8 text file.

    Args:
        list_file: Path to the file.

    Returns:
        list[str]: Each non-blank line with surrounding whitespace stripped.
    """
    with open(list_file, 'r', encoding='utf8') as handle:
        stripped = (row.strip() for row in handle)
        return [row for row in stripped if row != '']
make_pair(wav_lists, trans_lists)Make pair.
wav_lists — TODO.trans_lists — TODO.def make_pair(wav_lists, trans_lists):
"""Make pair.
Args:
wav_lists: TODO.
trans_lists: TODO.
"""
logging.info('make pair for wav-trans list')
trans_table = {}
for line in trans_lists:
arr = line.strip().replace('\t', ' ').split()
if len(arr) < 2:
logging.debug('invalid line in trans file: {}'.format(
line.strip()))
continue
trans_table[arr[0]] = line.replace(arr[0],'').strip()
lists = []
for line in wav_lists:
arr = line.strip().replace('\t', ' ').split()
if len(arr) == 2 and arr[0] in trans_table:
lists.append(
dict(key=arr[0],
txt=trans_table[arr[0]],
wav=arr[1],
sample_rate=16000))
else:
logging.debug("can't find corresponding trans for key: {}".format(
View full source on GitHub →count_duration(tid, data_lists)Count duration.
tid — TODO.data_lists — TODO.def count_duration(tid, data_lists):
"""Count duration.
Args:
tid: TODO.
data_lists: TODO.
"""
results = []
for obj in data_lists:
assert 'key' in obj
assert 'wav' in obj
assert 'txt' in obj
key = obj['key']
wav_file = obj['wav']
txt = obj['txt']
try:
rate, waveform = kaldiio.load_mat(wav_file)
waveform = torch.tensor(waveform, dtype=torch.float32)
waveform = waveform.unsqueeze(0)
frames = len(waveform[0])
duration = frames / float(rate)
except:
logging.info(f'load file failed: {wav_file}')
duration = 0.0
obj['duration'] = duration
results.append(obj)
View full source on GitHub →load_data_and_score(keywords_list, data_file, trans_file, score_file)Load data and score.
keywords_list — TODO.data_file — TODO.trans_file — TODO.score_file — TODO.def load_data_and_score(keywords_list, data_file, trans_file, score_file):
# score_table: {uttid: [keywordlist]}
"""Load data and score.
Args:
keywords_list: TODO.
data_file: TODO.
trans_file: TODO.
score_file: TODO.
"""
score_table = {}
with open(score_file, 'r', encoding='utf8') as fin:
# read score file and store in table
for line in fin:
arr = line.strip().split()
key = arr[0]
is_detected = arr[1]
if is_detected == 'detected':
if key not in score_table:
score_table.update(
{key: {
'kw': space_mixed_label(arr[2]),
'confi': float(arr[3])
}})
else:
if key not in score_table:
score_table.update({key: {'kw': 'unknown', 'confi': -1.0}})
wav_lists = read_lists(data_file)
trans_lists = read_lists(trans_file)
View full source on GitHub →Writer class to create kaldi like data directory.
>>> with DatadirWriter("output") as writer:
... # output/sub.txt is created here
... subwriter = writer["sub.txt"]
... # Write "uttidA some/where/a.wav"
... subwriter["uttidA"] = "some/where/a.wav"
... subwriter["uttidB"] = "some/where/b.wav"
class DatadirWriter:
"""Writer class to create kaldi like data directory.
Examples:
>>> with DatadirWriter("output") as writer:
... # output/sub.txt is created here
... subwriter = writer["sub.txt"]
... # Write "uttidA some/where/a.wav"
... subwriter["uttidA"] = "some/where/a.wav"
... subwriter["uttidB"] = "some/where/b.wav"
"""
def __init__(self, p: Union[Path, str]):
    """Create a writer rooted at ``p``.

    Args:
        p: Output location, normalised to ``pathlib.Path``. Per the class
            docstring example this is the output directory for a top-level
            writer (children are obtained via ``writer["sub.txt"]``); the
            ``__getitem__`` body is not visible here — confirm for children.
    """
    self.path = Path(p)
    # No file handle is opened eagerly.
    self.fd = None
    # Child writers by name. The misspelling is load-bearing: other methods
    # (e.g. ``close``) read ``self.chilidren``.
    self.has_children, self.chilidren = False, {}
    # Ids recorded by this writer; ``close`` compares these across children.
    # Presumably filled when entries are assigned — confirm.
    self.keys = set()
def __enter__(self):
    """Enter a ``with`` block; the managed object is the writer itself."""
    writer = self
    return writer
def __getitem__(self, key: str) -> "DatadirWriter":
View full source on GitHub →.close() L82Close.
def close(self):
    """Close the writer.

    A writer without children closes its file handle if one is open. A
    writer with children closes each child (recursively) in insertion order
    and warns when two consecutively-closed children hold different ``keys``
    sets, i.e. their ids do not line up.
    """
    if not self.has_children:
        if self.fd is not None:
            self.fd.close()
        return

    prev_child = None
    for child in self.chilidren.values():
        child.close()
        # Checking neighbours suffices: if every adjacent pair matches, all
        # children share one key set.
        mismatch = prev_child is not None and prev_child.keys != child.keys
        if mismatch:
            warnings.warn(
                f"Ids are mismatching between " f"{prev_child.path} and {child.path}"
            )
        prev_child = child
DatadirWriter.close()Close.
def close(self):
    """Close the writer.

    A writer without children closes its file handle if one is open. A
    writer with children closes each child (recursively) in insertion order
    and warns when two consecutively-closed children hold different ``keys``
    sets, i.e. their ids do not line up.
    """
    if not self.has_children:
        if self.fd is not None:
            self.fd.close()
        return

    prev_child = None
    for child in self.chilidren.values():
        child.close()
        # Checking neighbours suffices: if every adjacent pair matches, all
        # children share one key set.
        mismatch = prev_child is not None and prev_child.keys != child.keys
        if mismatch:
            warnings.warn(
                f"Ids are mismatching between " f"{prev_child.path} and {child.path}"
            )
        prev_child = child
load_module_from_path(file_path)从给定的文件路径动态加载模块。
:param file_path: 模块文件的绝对路径。
:return: 加载的模块
def load_module_from_path(file_path):
    """Dynamically load a Python module from a file path.

    The module is executed but not registered in ``sys.modules``.

    Args:
        file_path: Path to the module file (normally ``*.py``). Either a
            ``str`` or a path-like object.

    Returns:
        module: The freshly executed module object.

    Raises:
        ImportError: If no import loader recognises the file, e.g. because its
            suffix is not a Python source/extension suffix.
    """
    import os

    # basename/splitext are separator-aware and strip only the final
    # extension. The old ``split("/")`` + ``replace(".py", "")`` broke on
    # Windows paths and on names containing ".py" elsewhere.
    module_name = os.path.splitext(os.path.basename(file_path))[0]
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None for unrecognised suffixes;
        # fail with a clear error instead of an AttributeError on None.
        raise ImportError(f"cannot load a module from {file_path!r}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
import_module_from_path(file_path)Import module from path.
def import_module_from_path(file_path: str):
    """Import a Python module given its file path or an ``http(s)`` URL.

    A remote path is first downloaded with ``download_from_url``. The file's
    directory is appended to ``sys.path`` (only once) and the module is
    imported by its base name, so it is registered in ``sys.modules``.

    Args:
        file_path: Local path or URL of a ``.py`` file.

    Returns:
        The imported module, or ``None`` if importing failed. The failure is
        printed, preserving the original best-effort behaviour.
    """
    if file_path.startswith("http"):
        from funasr.download.file import download_from_url

        file_path = download_from_url(file_path)
    file_dir = os.path.dirname(file_path) or "./"
    # Separator-aware and strips only the final extension.
    module_name = os.path.splitext(os.path.basename(file_path))[0]
    # Avoid growing sys.path with duplicates on repeated calls.
    if file_dir not in sys.path:
        sys.path.append(file_dir)
    try:
        module = importlib.import_module(module_name)
    except Exception as e:
        print(f"Loading remote code failed: {file_path}, {e}")
        return None
    print(f"Loading remote code successfully: {file_path}")
    return module
export(model, data_in, quantize, opset_version, type, **kwargs)Export.
model — Model instance or model name.data_in — Input data (audio samples, file paths, or text).quantize — TODO.opset_version — TODO.type — TODO.**kwargs — Additional keyword arguments.def export(
model, data_in=None, quantize: bool = False, opset_version: int = 14, type="onnx", **kwargs
):
"""Export.
Args:
model: Model instance or model name.
data_in: Input data (audio samples, file paths, or text).
quantize: TODO.
opset_version: TODO.
type: TODO.
**kwargs: Additional keyword arguments.
"""
model_scripts = model.export(**kwargs)
export_dir = kwargs.get("output_dir", os.path.dirname(kwargs.get("init_param")))
os.makedirs(export_dir, exist_ok=True)
if not isinstance(model_scripts, (list, tuple)):
model_scripts = (model_scripts,)
for m in model_scripts:
m.eval()
if type == "onnx":
_onnx(
m,
data_in=data_in,
quantize=quantize,
opset_version=opset_version,
export_dir=export_dir,
**kwargs,
)
View full source on GitHub →install_requirements(requirements_path)Install requirements.
def install_requirements(requirements_path):
    """Install a model's extra Python requirements via ``pip_install_r``.

    Args:
        requirements_path: Path to a ``requirements.txt``-style file.

    Returns:
        bool: ``True`` if the installer exited with status 0, else ``False``.
        Progress and errors are printed rather than raised.
    """
    try:
        result = pip_install_r(requirements_path)
    except Exception as e:
        # pip_install_r raises when no installer is on PATH or the process
        # cannot be spawned. Re-running the identical call (as the previous
        # code did) cannot succeed and just re-raised uncaught.
        print("fail to install model requirements! ")
        print("error", e)
        return False
    if result.returncode == 0:
        print("install model requirements successfully")
        return True
    print("fail to install model requirements! ")
    print("error", result.stderr)
    return False
pip_install_r(requirements_path)Pip install r.
def pip_install_r(requirements_path):
    """Run ``pip install -r <requirements_path>`` in a subprocess.

    A ``pip`` executable on ``PATH`` is preferred; otherwise ``uv pip`` is
    used.

    Args:
        requirements_path: Path to a requirements file.

    Returns:
        subprocess.CompletedProcess: The finished process, with ``stdout`` and
        ``stderr`` captured as text. A non-zero exit does not raise.

    Raises:
        RuntimeError: If neither ``pip`` nor ``uv`` is found on ``PATH``.
    """
    if shutil.which("pip") is not None:
        installer = ["pip"]
    elif shutil.which("uv") is not None:
        installer = ["uv", "pip"]
    else:
        raise RuntimeError("pip not found, failed to install model requirements")
    return subprocess.run(
        [*installer, "install", "-r", requirements_path],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
split_mixed_label(input_str)Split mixed label.
def split_mixed_label(input_str):
    """Split a mixed Chinese/English label into tokens.

    The input is lower-cased. At each step, a maximal leading run of ASCII
    letters and the characters ``! ? , < > ( ) '`` forms one token; otherwise
    the single next character is a token. Spaces left between tokens are
    stripped.

    Args:
        input_str: Label text, e.g. ``"Hi 小云"``.

    Returns:
        list[str]: The tokens, e.g. ``["hi", "小", "云"]``.
    """
    remaining = input_str.lower()
    pieces = []
    while remaining:
        m = re.match(r'[A-Za-z!?,<>()\']+', remaining)
        piece = m.group(0) if m is not None else remaining[0:1]
        pieces.append(piece)
        remaining = remaining.replace(piece, '', 1).strip(' ')
    return pieces
query_token_set(txt, symbol_table, lexicon_table)Query token set.
txt — TODO.symbol_table — TODO.lexicon_table — TODO.def query_token_set(txt, symbol_table, lexicon_table):
"""Query token set.
Args:
txt: TODO.
symbol_table: TODO.
lexicon_table: TODO.
"""
tokens_str = tuple()
tokens_idx = tuple()
if txt in symbol_table:
tokens_str = tokens_str + (txt, )
tokens_idx = tokens_idx + (symbol_table[txt], )
return tokens_str, tokens_idx
parts = split_mixed_label(txt)
for part in parts:
if part == '!sil' or part == '(sil)' or part == '<sil>':
tokens_str = tokens_str + ('!sil', )
elif part == '<blank>' or part == '<blank>':
tokens_str = tokens_str + ('<blank>', )
elif part == '(noise)' or part == 'noise)' or part == '(noise' or part == '<noise>':
tokens_str = tokens_str + ('<unk>', )
elif part in symbol_table:
tokens_str = tokens_str + (part, )
elif part in lexicon_table:
for ch in lexicon_table[part]:
tokens_str = tokens_str + (ch, )
else:
View full source on GitHub →Decoder interface wrapper for CTCPrefixDecode.
class KwsCtcPrefixDecoder():
"""Decoder interface wrapper for CTCPrefixDecode."""
def __init__(
self,
ctc: torch.nn.Module,
keywords: str,
token_list: list,
seg_dict: dict,
):
"""Initialize class.
Args:
ctc (torch.nn.Module): The CTC implementation.
For example, :class:`espnet.nets.pytorch_backend.ctc.CTC`
"""
self.ctc = ctc
self.token_list = token_list
token_table = {}
for token in token_list:
token_table[token] = token_list.index(token)
self.keywords_idxset = {0}
self.keywords_token = {}
self.keywords_str = keywords
keywords_list = self.keywords_str.strip().replace(' ', '').split(',')
for keyword in keywords_list:
strs, indexs = query_token_set(keyword, token_table, seg_dict)
View full source on GitHub →.beam_search(logits, logits_lengths, keywords_tokenset, score_beam_size, path_beam_size) L125CTC prefix beam search inner implementation
logits (torch.Tensor) — (1, max_len, vocab_size)logits_lengths (torch.Tensor) — (1, )keywords_tokenset (set) — token set for filtering scorescore_beam_size (int) — beam size for scorepath_beam_size (int) — beam size for pathList[List[int]]: nbest results
def beam_search(
self,
logits: torch.Tensor,
logits_lengths: torch.Tensor,
keywords_tokenset: set = None,
score_beam_size: int = 3,
path_beam_size: int = 20,
) -> Tuple[List[List[int]], torch.Tensor]:
""" CTC prefix beam search inner implementation
Args:
logits (torch.Tensor): (1, max_len, vocab_size)
logits_lengths (torch.Tensor): (1, )
keywords_tokenset (set): token set for filtering score
score_beam_size (int): beam size for score
path_beam_size (int): beam size for path
Returns:
List[List[int]]: nbest results
"""
maxlen = logits.size(0)
ctc_probs = logits
cur_hyps = [(tuple(), (1.0, 0.0, []))]
# CTC beam search step by step
for t in range(0, maxlen):
probs = ctc_probs[t] # (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps = defaultdict(lambda: (0.0, 0.0, []))
View full source →.is_sublist(main_list, check_list) L232Is sublist.
def is_sublist(self, main_list, check_list):
    """Find ``check_list`` as a contiguous run inside ``main_list``.

    Elements are compared one by one, so a tuple and a list with the same
    items match (the old whole-sequence ``==`` failed in that case).

    Args:
        main_list: Sequence to search (presumably a decoded token-id path).
        check_list: Sequence to look for (presumably a keyword's token ids).

    Returns:
        int: Start index of the first occurrence, or ``-1`` if absent. An
        empty ``check_list`` matches at index 0 instead of raising IndexError.
    """
    main_len, sub_len = len(main_list), len(check_list)
    if sub_len == 0:
        return 0
    # range() is empty when main_list is shorter, which yields -1.
    for start in range(main_len - sub_len + 1):
        if all(main_list[start + k] == check_list[k] for k in range(sub_len)):
            return start
    return -1
.decode(x) L295Get an initial state for decoding.
def decode(self, x: torch.Tensor):
    """Run CTC keyword decoding on one utterance's encoder output.

    Args:
        x (torch.Tensor): Encoder output for a single utterance. Presumably
            ``(T, D)``, since a batch axis is added here — confirm.

    Returns:
        Whatever ``self._decode_inside`` returns for the posteriors.
    """
    batched = x.unsqueeze(0)
    posteriors = self.ctc.softmax(batched).detach().squeeze(0).cpu()
    # NOTE(review): for a 2-D ``x`` this is dim 1 of the squeezed output,
    # i.e. the vocabulary size rather than the frame count. Kept as-is
    # because how ``_decode_inside`` uses it is not visible here — confirm.
    lengths = torch.tensor([posteriors.size(1)])
    return self._decode_inside(posteriors, lengths)
KwsCtcPrefixDecoder.beam_search(logits, logits_lengths, keywords_tokenset, score_beam_size, path_beam_size)CTC prefix beam search inner implementation
logits (torch.Tensor) — (1, max_len, vocab_size)logits_lengths (torch.Tensor) — (1, )keywords_tokenset (set) — token set for filtering scorescore_beam_size (int) — beam size for scorepath_beam_size (int) — beam size for pathList[List[int]]: nbest results
def beam_search(
self,
logits: torch.Tensor,
logits_lengths: torch.Tensor,
keywords_tokenset: set = None,
score_beam_size: int = 3,
path_beam_size: int = 20,
) -> Tuple[List[List[int]], torch.Tensor]:
""" CTC prefix beam search inner implementation
Args:
logits (torch.Tensor): (1, max_len, vocab_size)
logits_lengths (torch.Tensor): (1, )
keywords_tokenset (set): token set for filtering score
score_beam_size (int): beam size for score
path_beam_size (int): beam size for path
Returns:
List[List[int]]: nbest results
"""
maxlen = logits.size(0)
ctc_probs = logits
cur_hyps = [(tuple(), (1.0, 0.0, []))]
# CTC beam search step by step
for t in range(0, maxlen):
probs = ctc_probs[t] # (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps = defaultdict(lambda: (0.0, 0.0, []))
View full source on GitHub →KwsCtcPrefixDecoder.is_sublist(main_list, check_list)Is sublist.
def is_sublist(self, main_list, check_list):
    """Find ``check_list`` as a contiguous run inside ``main_list``.

    Elements are compared one by one, so a tuple and a list with the same
    items match (the old whole-sequence ``==`` failed in that case).

    Args:
        main_list: Sequence to search (presumably a decoded token-id path).
        check_list: Sequence to look for (presumably a keyword's token ids).

    Returns:
        int: Start index of the first occurrence, or ``-1`` if absent. An
        empty ``check_list`` matches at index 0 instead of raising IndexError.
    """
    main_len, sub_len = len(main_list), len(check_list)
    if sub_len == 0:
        return 0
    # range() is empty when main_list is shorter, which yields -1.
    for start in range(main_len - sub_len + 1):
        if all(main_list[start + k] == check_list[k] for k in range(sub_len)):
            return start
    return -1
KwsCtcPrefixDecoder.decode(x)Get an initial state for decoding.
def decode(self, x: torch.Tensor):
    """Run CTC keyword decoding on one utterance's encoder output.

    Args:
        x (torch.Tensor): Encoder output for a single utterance. Presumably
            ``(T, D)``, since a batch axis is added here — confirm.

    Returns:
        Whatever ``self._decode_inside`` returns for the posteriors.
    """
    batched = x.unsqueeze(0)
    posteriors = self.ctc.softmax(batched).detach().squeeze(0).cpu()
    # NOTE(review): for a 2-D ``x`` this is dim 1 of the squeezed output,
    # i.e. the vocabulary size rather than the frame count. Kept as-is
    # because how ``_decode_inside`` uses it is not visible here — confirm.
    lengths = torch.tensor([posteriors.size(1)])
    return self._decode_inside(posteriors, lengths)
is_ffmpeg_installed()Is ffmpeg installed.
def is_ffmpeg_installed():
    """Return True if an ``ffmpeg`` binary can be run and reports its version.

    Runs ``ffmpeg -version`` and looks for ``"ffmpeg version"`` in the
    combined stdout/stderr. Any failure to run it yields ``False`` instead of
    an exception: a missing binary, a non-executable file, or a non-zero exit.
    """
    try:
        output = subprocess.check_output(["ffmpeg", "-version"], stderr=subprocess.STDOUT)
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError (not on PATH) and PermissionError
        # (present but not executable); the latter used to propagate.
        return False
    # Tolerate non-UTF-8 bytes in the banner instead of raising.
    return "ffmpeg version" in output.decode("utf-8", errors="replace")
load_audio_text_image_video(data_or_path_or_list, fs, audio_fs, data_type, tokenizer, **kwargs)Load audio/text/image/video data from various input formats.
data_or_path_or_list — File path, URL, numpy array, torch Tensor, bytes, or list.fs (int) — Target sample rate (default 16000).audio_fs (int) — Source audio sample rate.data_type (str) — Input type ("sound", "text", "fbank").torch.Tensor or list: Loaded and resampled audio tensor(s).
def load_audio_text_image_video(
data_or_path_or_list,
fs: int = 16000,
audio_fs: int = 16000,
data_type="sound",
tokenizer=None,
**kwargs,
):
"""Load audio/text/image/video data from various input formats.
Args:
data_or_path_or_list: File path, URL, numpy array, torch Tensor, bytes, or list.
fs (int): Target sample rate (default 16000).
audio_fs (int): Source audio sample rate.
data_type (str): Input type ("sound", "text", "fbank").
Returns:
torch.Tensor or list: Loaded and resampled audio tensor(s).
"""
if isinstance(data_or_path_or_list, (list, tuple)):
if data_type is not None and isinstance(data_type, (list, tuple)):
data_types = [data_type] * len(data_or_path_or_list)
data_or_path_or_list_ret = [[] for d in data_type]
for i, (data_type_i, data_or_path_or_list_i) in enumerate(
zip(data_types, data_or_path_or_list)
):
for j, (data_type_j, data_or_path_or_list_j) in enumerate(
zip(data_type_i, data_or_path_or_list_i)
):
data_or_path_or_list_j = load_audio_text_image_video(
View full source on GitHub →load_bytes(input)Convert audio bytes to numpy array.
input (bytes) — Raw audio bytes.numpy.ndarray — Decoded audio samples.def load_bytes(input):
"""Convert audio bytes to numpy array.
Args:
input (bytes): Raw audio bytes.
Returns:
numpy.ndarray: Decoded audio samples.
"""
# Only run the (expensive) frame-rate validation when the payload is an
# actual audio container (WAV, MP3, OGG, …). Raw PCM buffers have no
# recognisable header and would cause pydub to spend ~200 ms before
# raising an exception that is then silently swallowed anyway.
if _is_audio_container(input):
try:
input = validate_frame_rate(input)
except Exception:
pass
middle_data = np.frombuffer(input, dtype=np.int16)
middle_data = np.asarray(middle_data)
if middle_data.dtype.kind not in "iu":
raise TypeError("'middle_data' must be an array of integers")
dtype = np.dtype("float32")
if dtype.kind != "f":
raise TypeError("'dtype' must be a floating point type")
i = np.iinfo(middle_data.dtype)
abs_max = 2 ** (i.bits - 1)
offset = i.min + abs_max
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
View full source on GitHub →validate_frame_rate(input, fs)Validate frame rate.
input — Input audio/text data.fs — TODO.def validate_frame_rate(
input,
fs: int = 16000,
):
# 将文件读取为字节流
"""Validate frame rate.
Args:
input: Input audio/text data.
fs: TODO.
"""
byte_data = BytesIO(input)
# 使用 pydub 加载音频
try:
audio = AudioSegment.from_file(byte_data)
except:
raise RuntimeError(
"You are decoding the pcm data, please install pydub first. via `pip install pydub`."
)
# 确保采样率为 16000 Hz
if audio.frame_rate != fs:
audio = audio.set_frame_rate(fs)
# 将重新采样后的音频导出为字节流
output = BytesIO()
audio.export(output, format="wav")
output.seek(0)
View full source on GitHub →extract_fbank(data, data_len, data_type, frontend, **kwargs)Extract filter-bank features from audio data.
data — Audio samples (list of numpy arrays or tensors).data_len — Lengths of each sample.data_type (str) — Input type ("sound", "fbank").frontend — Frontend instance for feature extraction.tuple — (features_tensor, feature_lengths, feature_times)def extract_fbank(data, data_len=None, data_type: str = "sound", frontend=None, **kwargs):
"""Extract filter-bank features from audio data.
Args:
data: Audio samples (list of numpy arrays or tensors).
data_len: Lengths of each sample.
data_type (str): Input type ("sound", "fbank").
frontend: Frontend instance for feature extraction.
Returns:
tuple: (features_tensor, feature_lengths, feature_times)
"""
if isinstance(data, np.ndarray):
data = torch.from_numpy(data)
if len(data.shape) < 2:
data = data[None, :] # data: [batch, N]
elif data.shape[0] > 1:
data = data.mean(dim=0, keepdim=True) # convert stereo/multi-channel to mono
data_len = [data.shape[1]] if data_len is None else data_len
elif isinstance(data, torch.Tensor):
if len(data.shape) < 2:
data = data[None, :] # data: [batch, N]
elif data.shape[0] > 1:
data = data.mean(dim=0, keepdim=True) # convert stereo/multi-channel to mono
data_len = [data.shape[1]] if data_len is None else data_len
elif isinstance(data, (list, tuple)):
data_list, data_len = [], []
for data_i in data:
if isinstance(data_i, np.ndarray):
data_i = torch.from_numpy(data_i)
View full source on GitHub →statistic_model_parameters(model, prefix)Statistic model parameters.
def statistic_model_parameters(model, prefix=None):
    """Count the scalar elements stored in ``model.state_dict()``.

    Every state-dict entry counts, including registered buffers. Entries
    whose key contains ``"num_batches_tracked"`` are skipped.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        prefix: If given, only keys starting with this string are counted.

    Returns:
        int: Total number of elements.
    """
    state = model.state_dict()
    return sum(
        tensor.numel()
        for name, tensor in state.items()
        if "num_batches_tracked" not in name
        and (prefix is None or name.startswith(prefix))
    )
int2vec(x, vec_dim, dtype)Int2vec.
def int2vec(x, vec_dim=8, dtype=np.int32):
    """Expand an integer into its binary digits, lowest bit first.

    The binary string is zero-padded to at least ``vec_dim`` digits but never
    truncated, so larger values yield longer vectors.

    Args:
        x: Integer to expand.
        vec_dim: Minimum number of bits.
        dtype: Output dtype.

    Returns:
        np.ndarray: 1-D array of 0/1 values in little-endian bit order.
    """
    bits = ("{:0" + str(vec_dim) + "b}").format(x)
    little_endian = [digit == "1" for digit in reversed(bits)]
    return np.array(little_endian).astype(dtype)
seq2arr(seq, vec_dim)Seq2arr.
def seq2arr(seq, vec_dim=8):
    """Stack the little-endian bit vectors of a sequence of integers.

    Each element is converted with ``int(x)`` and expanded by ``int2vec``
    into ``vec_dim`` bits (lowest bit first). All rows must have equal
    length for stacking; ``int2vec`` pads but does not truncate.

    Args:
        seq: Iterable of integer-like values.
        vec_dim: Bits per row.

    Returns:
        np.ndarray: One row per element of ``seq``. An empty ``seq`` yields an
        empty ``(0, vec_dim)`` int32 array instead of raising.
    """
    rows = [int2vec(int(x), vec_dim) for x in seq]
    if not rows:
        return np.zeros((0, vec_dim), dtype=np.int32)
    # np.row_stack is a deprecated alias of np.vstack (NumPy 2.0).
    return np.vstack(rows)
load_scp_as_dict(scp_path, value_type, kv_sep)Load scp as dict.
def load_scp_as_dict(scp_path, value_type="str", kv_sep=" "):
    """Load a Kaldi-style ``scp`` file into an ordered ``{key: value}`` dict.

    Each non-blank line (after ``strip()``) is split at the first occurrence
    of ``kv_sep``: the part before is the key and the rest is the value. A
    line without the separator maps to an empty value. Previously ``find``
    returned -1 there and corrupted both the key and the value. Separators
    longer than one character are handled correctly. Later duplicate keys
    overwrite earlier ones.

    Args:
        scp_path: Path to the UTF-8 text file.
        value_type: ``"list"`` splits each value on single spaces; any other
            value keeps it as a string.
        kv_sep: Separator between key and value.

    Returns:
        OrderedDict: Keys in file order.
    """
    ret_dict = OrderedDict()
    with io.open(scp_path, "r", encoding="utf-8") as f:
        for one_line in f:
            one_line = one_line.strip()
            if not one_line:
                continue  # blank lines used to produce a bogus "" key
            key, _, value = one_line.partition(kv_sep)
            if value_type == "list":
                value = value.split(" ")
            ret_dict[key] = value
    return ret_dict
load_scp_as_list(scp_path, value_type, kv_sep)Load scp as list.
def load_scp_as_list(scp_path, value_type="str", kv_sep=" "):
    """Load a Kaldi-style ``scp`` file as an ordered list of ``(key, value)`` pairs.

    Duplicate keys are all kept. Each non-blank line (after ``strip()``) is
    split at the first occurrence of ``kv_sep``. A line without the
    separator yields an empty value; previously ``find`` returned -1 there
    and corrupted both fields. Multi-character separators are handled
    correctly.

    Args:
        scp_path: Path to the UTF-8 text file.
        value_type: ``"list"`` splits each value on single spaces; any other
            value keeps it as a string.
        kv_sep: Separator between key and value.

    Returns:
        list[tuple]: ``(key, value)`` pairs in file order.
    """
    items = []
    with io.open(scp_path, "r", encoding="utf8") as f:
        for one_line in f:
            one_line = one_line.strip()
            if not one_line:
                continue  # blank lines used to produce a bogus ("", "") entry
            key, _, value = one_line.partition(kv_sep)
            if value_type == "list":
                value = value.split(" ")
            items.append((key, value))
    return items
deep_update(original, update)Recursively merge update dict into original dict (in-place).
For nested dicts, merges recursively. For other types, overwrites.
def deep_update(original, update):
    """Recursively merge ``update`` into ``original`` in place.

    For each ``key, value`` in ``update``:

    * a non-empty ``dict`` value whose counterpart ``original[key]`` is a
      mutable mapping is merged recursively;
    * anything else overwrites ``original[key]``. This covers scalars,
      lists, empty dicts, and a dict replacing a missing or non-mapping
      entry. The last case used to crash when recursing into e.g. ``None``.

    Args:
        original: Target mapping, modified in place. Any ``MutableMapping``
            works, not only ``dict``.
        update: Source dict with new values.
    """
    from collections.abc import MutableMapping

    for key, value in update.items():
        if (
            isinstance(value, dict)
            and len(value) > 0
            and key in original
            and isinstance(original[key], MutableMapping)
        ):
            deep_update(original[key], value)
        else:
            original[key] = value
prepare_model_dir(**kwargs)Prepare model dir.
def prepare_model_dir(**kwargs):
    """Create the output directory and snapshot the run configuration into it.

    All keyword arguments are saved as ``config.yaml`` via ``OmegaConf.save``
    in ``output_dir`` (default ``"./"``) and logged. If ``model_path`` is
    given and contains ``configuration.json``, that file is copied into the
    output directory as well.

    Args:
        **kwargs: Run configuration. ``output_dir`` and ``model_path`` are
            read here.
    """
    output_dir = kwargs.get("output_dir", "./")
    os.makedirs(output_dir, exist_ok=True)
    yaml_file = os.path.join(output_dir, "config.yaml")
    OmegaConf.save(config=kwargs, f=yaml_file)
    logging.info(f"kwargs: {kwargs}")
    logging.info("config.yaml is saved to: %s", yaml_file)

    model_path = kwargs.get("model_path", None)
    if model_path is None:
        return
    config_json = os.path.join(model_path, "configuration.json")
    if os.path.exists(config_json):
        shutil.copy(config_json, os.path.join(output_dir, "configuration.json"))
extract_filename_without_extension(file_path)从给定的文件路径中提取文件名(不包含路径和扩展名)
:param file_path: 完整的文件路径
:return: 文件名(不含路径和扩展名)
def extract_filename_without_extension(file_path):
    """Return the file name of ``file_path`` without directories or its last extension.

    :param file_path: Full path to a file.
    :return: The base name with only the final extension removed, e.g.
        ``"/a/b/model.tar.gz"`` -> ``"model.tar"``.
    """
    base_name = os.path.basename(file_path)
    stem, _extension = os.path.splitext(base_name)
    return stem
smart_remove(path)Intelligently removes files, empty directories, and non-empty directories recursively.
def smart_remove(path):
    """Delete ``path``, whether it is a file, an empty directory, or a directory tree.

    A message describing the action is printed. A non-existent path only
    prints a notice.
    """
    if not os.path.exists(path):
        print(f"{path} does not exist.")
        return
    if os.path.isfile(path):
        os.remove(path)
        print(f"File {path} has been deleted.")
        return
    if not os.path.isdir(path):
        return
    try:
        os.rmdir(path)  # only succeeds for an empty directory
    except OSError:
        shutil.rmtree(path)
        print(f"Non-empty directory {path} has been recursively deleted.")
    else:
        print(f"Empty directory {path} has been deleted.")
isChinese(ch)Ischinese.
def isChinese(ch: str):
    """Return True if ``ch`` counts as a "Chinese" unit for text post-processing.

    Accepted are the CJK Unified Ideographs range U+4E00..U+9FFF, ASCII
    digits, and ``"@"``. The range checks are lexicographic comparisons of
    the whole string, so for multi-character input this is not a
    per-character test (e.g. ``"9a"`` is rejected while ``"12"`` passes).

    Args:
        ch: Character or token to test.
    """
    in_cjk = "\u4e00" <= ch <= "\u9fff"
    in_digits = "\u0030" <= ch <= "\u0039"
    return bool(in_cjk or in_digits or ch == "@")
isAllChinese(word)Isallchinese.
def isAllChinese(word: Union[List[Any], str]):
    """Return True if every token of ``word`` passes ``isChinese`` after cleanup.

    Spaces and the special tokens ``</s>``, ``<s>``, ``<unk>`` and ``<OOV>``
    are removed from each item first. A token that becomes empty therefore
    fails the check. An empty ``word`` returns False.

    Args:
        word: List of tokens, or a string (iterated character by character).
    """
    # Removal order matters for overlapping tags; keep it fixed.
    removals = (" ", "</s>", "<s>", "<unk>", "<OOV>")
    cleaned = []
    for item in word:
        for tag in removals:
            item = item.replace(tag, "")
        cleaned.append(item)
    if not cleaned:
        return False
    return all(isChinese(item) is not False for item in cleaned)
isAllAlpha(word)Isallalpha.
def isAllAlpha(word: Union[List[Any], str]):
    """Return True if every cleaned token of ``word`` is alphabetic and non-Chinese.

    Spaces and the special tokens ``</s>``, ``<s>``, ``<unk>`` and ``<OOV>``
    are removed from each item first. A token passes if it is ``"'"``, or if
    ``str.isalpha()`` holds and ``isChinese`` does not. An empty ``word``
    returns False.

    Args:
        word: List of tokens, or a string (iterated character by character).
    """
    # Removal order matters for overlapping tags; keep it fixed.
    removals = (" ", "</s>", "<s>", "<unk>", "<OOV>")
    cleaned = []
    for item in word:
        for tag in removals:
            item = item.replace(tag, "")
        cleaned.append(item)
    if not cleaned:
        return False
    for item in cleaned:
        alpha = item.isalpha()
        if not alpha and item != "'":
            return False
        if alpha and isChinese(item) is True:
            return False
    return True
abbr_dispose(words, time_stamp)Abbr dispose.
words — TODO.time_stamp — TODO.def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]:
"""Abbr dispose.
Args:
words: TODO.
time_stamp: TODO.
"""
words_size = len(words)
word_lists = []
abbr_begin = []
abbr_end = []
last_num = -1
ts_lists = []
ts_nums = []
ts_index = 0
for num in range(words_size):
if num <= last_num:
continue
if len(words[num]) == 1 and words[num].encode("utf-8").isalpha():
if (
num + 1 < words_size
and words[num + 1] == " "
and num + 2 < words_size
and len(words[num + 2]) == 1
and words[num + 2].encode("utf-8").isalpha()
):
# found the begin of abbr
abbr_begin.append(num)
num += 2
View full source on GitHub →sentence_postprocess(words, time_stamp)Sentence postprocess.
words — TODO.time_stamp — TODO.def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
"""Sentence postprocess.
Args:
words: TODO.
time_stamp: TODO.
"""
middle_lists = []
word_lists = []
word_item = ""
ts_lists = []
# wash words lists
for i in words:
word = ""
if isinstance(i, str):
word = i
else:
word = i.decode("utf-8")
if word in ["<s>", "</s>", "<unk>", "<OOV>"]:
continue
else:
middle_lists.append(word)
# all chinese characters
if isAllChinese(middle_lists):
for i, ch in enumerate(middle_lists):
word_lists.append(ch.replace(" ", ""))
if time_stamp is not None:
View full source on GitHub →sentence_postprocess_sentencepiece(words)Sentence postprocess sentencepiece.
words — TODO.def sentence_postprocess_sentencepiece(words):
"""Sentence postprocess sentencepiece.
Args:
words: TODO.
"""
middle_lists = []
word_lists = []
word_item = ""
# wash words lists
for i in words:
word = ""
if isinstance(i, str):
word = i
else:
word = i.decode("utf-8")
if word in ["<s>", "</s>", "<unk>", "<OOV>"]:
continue
else:
middle_lists.append(word)
# all alpha characters
for i, ch in enumerate(middle_lists):
word = ""
if "\u2581" in ch and i == 0:
word_item = ""
word = ch.replace("\u2581", "")
word_item += word
View full source on GitHub →format_str_v2(s)Format str v2.
def format_str_v2(s):
    """Replace special tags in ``s`` with a dominant emotion emoji and event glyphs.

    Each tag in ``emoji_dict`` is counted and then stripped, one tag at a
    time in dictionary order. The emotion from ``emo_dict`` with the highest
    count wins; ties keep the earlier candidate, and the starting default is
    ``"<|NEUTRAL|>"``, presumably a key of ``emoji_dict``. Every event tag
    that occurred is prepended via ``event_dict``, and the winning emotion's
    glyph is appended. Spaces adjacent to any emotion/event glyph are then
    removed.

    Args:
        s: Raw transcription containing ``<|...|>`` tags.

    Returns:
        str: The formatted, stripped string.
    """
    counts = {}
    for tag in emoji_dict:
        counts[tag] = s.count(tag)
        s = s.replace(tag, "")

    emo = "<|NEUTRAL|>"
    for candidate in emo_dict:
        if counts[candidate] > counts[emo]:
            emo = candidate

    for tag in event_dict:
        if counts[tag] > 0:
            s = event_dict[tag] + s
    s = s + emo_dict[emo]

    for glyph in emo_set.union(event_set):
        s = s.replace(" " + glyph, glyph).replace(glyph + " ", glyph)
    return s.strip()
rich_transcription_postprocess(s)Rich transcription postprocess.
s — TODO.def rich_transcription_postprocess(s):
"""Rich transcription postprocess.
Args:
s: TODO.
"""
def get_emo(s):
"""Get emo.
Args:
s: TODO.
"""
return s[-1] if s[-1] in emo_set else None
def get_event(s):
"""Get event.
Args:
s: TODO.
"""
return s[0] if s[0] in event_set else None
s = s.replace("<|nospeech|><|Event_UNK|>", "❓")
for lang in lang_dict:
s = s.replace(lang, "<|lang|>")
s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]
new_s = " " + s_list[0]
cur_ent_event = get_event(new_s)
for i in range(1, len(s_list)):
if len(s_list[i]) == 0:
View full source on GitHub →check_audio_list(audio)Check audio list.
def check_audio_list(audio: list, fs: int = 16000):
    """Validate timed audio segments and return their total duration.

    Args:
        audio: Time-ordered ``[start_sec, end_sec, samples]`` entries, where
            ``samples`` is an ``np.ndarray`` whose ``shape[0]`` is its sample
            count.
        fs: Sample rate used to convert seconds to sample counts. Defaults to
            16000, the value previously hard-coded; now consistent with
            ``sv_chunk``'s ``fs`` parameter.

    Returns:
        Total ``end - start`` over all segments, in seconds (``0`` if empty).

    Raises:
        AssertionError: On a segment ending before it starts, a non-ndarray
            payload, a sample count that disagrees with the timestamps, or a
            segment starting before the previous one ends.
    """
    audio_dur = 0
    for i, seg in enumerate(audio):
        assert seg[1] >= seg[0], "modelscope error: Wrong time stamps."
        assert isinstance(seg[2], np.ndarray), "modelscope error: Wrong data type."
        assert (
            int(seg[1] * fs) - int(seg[0] * fs) == seg[2].shape[0]
        ), "modelscope error: audio data in list is inconsistent with time length."
        if i > 0:
            # Segments must not overlap.
            assert seg[0] >= audio[i - 1][1], "modelscope error: Wrong time stamps."
        audio_dur += seg[1] - seg[0]
    return audio_dur
sv_preprocess(inputs)Sv preprocess.
def sv_preprocess(inputs: Union[np.ndarray, list]):
    """Convert speaker-verification inputs to 1-D float32 tensors.

    Args:
        inputs: Sequence whose items are either audio paths/URLs (read via
            ``File.read`` and decoded with ``sf.load``; only the first channel
            of 2-D data is kept) or 1-D ``np.ndarray`` samples. Integer arrays
            are divided by 2**15 and all arrays are cast to float32.
            NOTE(review): the 2**15 scale is also applied to int32/int64 data;
            confirm that is intended.

    Returns:
        list[torch.Tensor]: One tensor per input item.

    Raises:
        AssertionError: If an array input is not 1-D.
        ValueError: For any other input type.
    """
    output = []
    for item in inputs:
        if isinstance(item, str):
            file_bytes = File.read(item)
            data, fs = sf.load(io.BytesIO(file_bytes), dtype="float32")
            if len(data.shape) == 2:
                data = data[:, 0]
            data = torch.from_numpy(data).unsqueeze(0)
            data = data.squeeze(0)
        elif isinstance(item, np.ndarray):
            assert len(item.shape) == 1, "modelscope error: Input array should be [N, T]"
            is_int = item.dtype in ["int16", "int32", "int64"]
            scaled = item / (1 << 15) if is_int else item
            data = torch.from_numpy(scaled.astype("float32"))
        else:
            raise ValueError(
                "modelscope error: The input type is restricted to audio address and nump array."
            )
        output.append(data)
    return output
sv_chunk(vad_segments, fs)Sv chunk.
vad_segments — TODO.fs — TODO.def sv_chunk(vad_segments: list, fs=16000) -> list:
"""Sv chunk.
Args:
vad_segments: TODO.
fs: TODO.
"""
config = {
"seg_dur": 1.5,
"seg_shift": 0.75,
}
def seg_chunk(seg_data):
"""Seg chunk.
Args:
seg_data: TODO.
"""
seg_st = seg_data[0]
data = seg_data[2]
chunk_len = int(config["seg_dur"] * fs)
chunk_shift = int(config["seg_shift"] * fs)
last_chunk_ed = 0
seg_res = []
for chunk_st in range(0, data.shape[0], chunk_shift):
chunk_ed = min(chunk_st + chunk_len, data.shape[0])
if chunk_ed <= last_chunk_ed:
break
last_chunk_ed = chunk_ed
chunk_st = max(0, chunk_ed - chunk_len)
View full source on GitHub →extract_feature(audio)Extract feature.
def extract_feature(audio):
    """Compute mean-normalised 80-bin Kaldi fbank features for each waveform.

    Args:
        audio: Iterable of 1-D waveform tensors (a batch axis is added
            before ``Kaldi.fbank``).

    Returns:
        torch.Tensor: Per-utterance features concatenated along a new leading
        axis. Every waveform must yield the same frame count for
        ``torch.cat`` to succeed.
    """
    def _normalised_fbank(wav):
        # Subtract the per-utterance mean over frames (dim 0).
        feat = Kaldi.fbank(wav.unsqueeze(0), num_mel_bins=80)
        return feat - feat.mean(dim=0, keepdim=True)

    return torch.cat([_normalised_fbank(wav).unsqueeze(0) for wav in audio])
postprocess(segments, vad_segments, labels, embeddings)Postprocess.
segments — TODO.vad_segments — TODO.labels — TODO.embeddings — TODO.def postprocess(
segments: list, vad_segments: list, labels: np.ndarray, embeddings: np.ndarray
) -> list:
"""Postprocess.
Args:
segments: TODO.
vad_segments: TODO.
labels: TODO.
embeddings: TODO.
"""
assert len(segments) == len(labels)
labels = correct_labels(labels)
distribute_res = []
for i in range(len(segments)):
distribute_res.append([segments[i][0], segments[i][1], labels[i]])
# merge the same speakers chronologically
distribute_res = merge_seque(distribute_res)
# accquire speaker center
spk_embs = []
for i in range(labels.max() + 1):
spk_emb = embeddings[labels == i].mean(0)
spk_embs.append(spk_emb)
spk_embs = np.stack(spk_embs)
def is_overlapped(t1, t2):
"""Is overlapped.
Args:
View full source on GitHub →correct_labels(labels)Correct labels.
def correct_labels(labels):
    """Relabel cluster ids to 0, 1, 2, ... in order of first appearance.

    Args:
        labels: Iterable of hashable cluster ids.

    Returns:
        np.ndarray: The renumbered labels, same length as the input.
    """
    mapping = {}
    for lab in labels:
        mapping.setdefault(lab, len(mapping))
    return np.array([mapping[lab] for lab in labels])
merge_seque(distribute_res)Merge seque.
def merge_seque(distribute_res):
    """Merge consecutive same-speaker segments that touch or overlap.

    Args:
        distribute_res: Chronologically ordered ``[start, end, label]`` lists.

    Returns:
        list: The merged segments (``[]`` for empty input, which used to
        raise IndexError). The returned lists are the input lists themselves;
        merged end times are written into them in place.
    """
    if not distribute_res:
        return []
    res = [distribute_res[0]]
    for seg in distribute_res[1:]:
        last = res[-1]
        if seg[2] != last[2] or seg[0] > last[1]:
            res.append(seg)
        else:
            # Extend, never shrink: a segment nested inside the current one
            # must not pull its end time back.
            last[1] = max(last[1], seg[1])
    return res
smooth(res, mindur)Smooth.
def smooth(res, mindur=1):
    """Reassign short segments to a neighbouring speaker, then merge.

    Start and end times are rounded to two decimals in place. A segment
    shorter than ``mindur`` (same units as the times, presumably seconds)
    takes the label of the nearer neighbour, measured by gap size; ties go
    to the previous one. The first and last segments use their only
    neighbour. A lone segment has no neighbour and keeps its label;
    previously it raised IndexError.

    Args:
        res: Chronological ``[start, end, label]`` lists, modified in place.
        mindur: Minimum duration that keeps its own label.

    Returns:
        list: The result of ``merge_seque`` on the relabelled segments, or
        ``res`` unchanged when it is empty.
    """
    if not res:
        return res
    last_idx = len(res) - 1
    for i, seg in enumerate(res):
        seg[0] = round(seg[0], 2)
        seg[1] = round(seg[1], 2)
        if seg[1] - seg[0] >= mindur or last_idx == 0:
            continue
        if i == 0:
            seg[2] = res[i + 1][2]
        elif i == last_idx:
            seg[2] = res[i - 1][2]
        elif seg[0] - res[i - 1][1] <= res[i + 1][0] - seg[1]:
            seg[2] = res[i - 1][2]
        else:
            seg[2] = res[i + 1][2]
    # merge the speakers
    return merge_seque(res)
distribute_spk(sentence_list, sd_time_list)Distribute spk.
def distribute_spk(sentence_list, sd_time_list):
    """Assign each sentence the speaker whose diarization turn overlaps it most.

    The sentence span runs from ``ts_list[0][0]`` to ``ts_list[-1][1]``.
    Diarization times are multiplied by 1000 before comparison, presumably
    converting seconds to the sentence's milliseconds. With no positive
    overlap the speaker defaults to 0; ties keep the earliest turn.

    Args:
        sentence_list: Dicts with a ``"ts_list"`` of ``[start, end]`` pairs.
            Each dict gets an ``"spk"`` key written in place.
        sd_time_list: ``(start, end, speaker)`` triples.

    Returns:
        list: The same dict objects, in input order.
    """
    labelled = []
    for sentence in sentence_list:
        sent_st = sentence["ts_list"][0][0]
        sent_ed = sentence["ts_list"][-1][1]
        best_spk, best_overlap = 0, 0
        for turn_st, turn_ed, spk in sd_time_list:
            overlap = max(min(sent_ed, turn_ed * 1000) - max(sent_st, turn_st * 1000), 0)
            if overlap > best_overlap:
                best_spk, best_overlap = spk, overlap
        sentence["spk"] = best_spk
        labelled.append(sentence)
    return labelled
cif_wo_hidden(alphas, threshold)Cif wo hidden.
def cif_wo_hidden(alphas, threshold):
    """Run the CIF integrate-and-fire accumulator and return its running totals.

    For each time step the weight ``alphas[:, t]`` is added to a per-batch
    accumulator and the accumulated value is recorded. Wherever it reaches
    ``threshold``, ``threshold`` is then subtracted (a "fire"). No hidden
    states are integrated, only the scalar weights.

    Args:
        alphas (torch.Tensor): ``(batch, time)`` weights.
        threshold (float): Firing threshold.

    Returns:
        torch.Tensor: ``(batch, time)`` accumulated values recorded before
        each step's threshold subtraction.
    """
    batch_size, num_frames = alphas.size()
    accum = torch.zeros([batch_size], device=alphas.device)
    history = []
    for t in range(num_frames):
        # In-place add keeps the accumulator's dtype.
        accum += alphas[:, t]
        history.append(accum)
        fired = accum >= threshold
        # torch.where allocates a new tensor, so the recorded entry is never
        # modified by later in-place updates.
        accum = torch.where(
            fired,
            accum - torch.ones([batch_size], device=alphas.device) * threshold,
            accum,
        )
    return torch.stack(history, 1)
ts_prediction_lfr6_standard(us_alphas, us_peaks, char_list, vad_offset, force_time_shift, sil_in_str, upsample_rate)Ts prediction lfr6 standard.
us_alphas — TODO.us_peaks — TODO.char_list — TODO.vad_offset — TODO.force_time_shift — TODO.sil_in_str — TODO.upsample_rate — TODO.def ts_prediction_lfr6_standard(
us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True, upsample_rate=3,
):
"""Ts prediction lfr6 standard.
Args:
us_alphas: TODO.
us_peaks: TODO.
char_list: TODO.
vad_offset: TODO.
force_time_shift: TODO.
sil_in_str: TODO.
upsample_rate: TODO.
"""
if not len(char_list):
return "", []
START_END_THRESHOLD = 5
MAX_TOKEN_DURATION = 12 # 3 times upsampled
TIME_RATE=10.0 * 6 / 1000 / upsample_rate
if len(us_alphas.shape) == 2:
alphas, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only
else:
alphas, peaks = us_alphas, us_peaks
if char_list[-1] == "</s>":
char_list = char_list[:-1]
fire_place = (
torch.where(peaks >= 1.0 - 1e-4)[0].cpu().numpy() + force_time_shift
) # total offset
if len(fire_place) != len(char_list) + 1:
alphas /= alphas.sum() / (len(char_list) + 1)
View full source on GitHub →timestamp_sentence(punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text)Split recognized text into sentences using punctuation, with timestamps.
punc_id_list (Tensor/list) — Punctuation IDs from CT-Transformer.Values — 1=none, 2=comma, 3=period, 4=question.timestamp_postprocessed (list) — Per-character timestamps [[start_ms, end_ms], ...].text_postprocessed (str) — Space-separated recognized text.return_raw_text (bool) — Include raw_text in output.list[dict]: Sentences with keys: text, start, end, timestamp, [raw_text].
def timestamp_sentence(
punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text=False
):
"""Split recognized text into sentences using punctuation, with timestamps.
Args:
punc_id_list (Tensor/list): Punctuation IDs from CT-Transformer.
Values: 1=none, 2=comma, 3=period, 4=question.
timestamp_postprocessed (list): Per-character timestamps [[start_ms, end_ms], ...].
text_postprocessed (str): Space-separated recognized text.
return_raw_text (bool): Include raw_text in output.
Returns:
list[dict]: Sentences with keys: text, start, end, timestamp, [raw_text].
"""
punc_list = [",", "。", "?", "、"]
res = []
if text_postprocessed is None:
return res
if timestamp_postprocessed is None:
return res
if len(timestamp_postprocessed) == 0:
return res
if len(text_postprocessed) == 0:
return res
if punc_id_list is None or len(punc_id_list) == 0:
res.append(
{
View full source on GitHub →timestamp_sentence_en(punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text)Timestamp sentence en.
punc_id_list — TODO.timestamp_postprocessed — TODO.text_postprocessed — TODO.return_raw_text — TODO.def timestamp_sentence_en(
punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text=False
):
"""Timestamp sentence en.
Args:
punc_id_list: TODO.
timestamp_postprocessed: TODO.
text_postprocessed: TODO.
return_raw_text: TODO.
"""
punc_list = [",", ".", "?", ","]
res = []
if text_postprocessed is None:
return res
if timestamp_postprocessed is None:
return res
if len(timestamp_postprocessed) == 0:
return res
if len(text_postprocessed) == 0:
return res
if punc_id_list is None or len(punc_id_list) == 0:
res.append(
{
"text": text_postprocessed.split(),
"start": timestamp_postprocessed[0][0],
"end": timestamp_postprocessed[-1][1],
"timestamp": timestamp_postprocessed,
}
View full source on GitHub →No documentation yet.
class MakePadMask(nn.Module):
def __init__(self, max_seq_len=512, flip=True):
"""Initialize MakePadMask.
Args:
max_seq_len: TODO.
flip: TODO.
"""
super().__init__()
if flip:
self.mask_pad = torch.Tensor(1 - np.tri(max_seq_len)).type(torch.bool)
else:
self.mask_pad = torch.Tensor(np.tri(max_seq_len)).type(torch.bool)
def forward(self, lengths, xs=None, length_dim=-1, maxlen=None):
"""Make mask tensor containing indices of padded part.
This implementation creates the same mask tensor with original make_pad_mask,
which can be converted into onnx format.
Dimension length of xs should be 2 or 3.
"""
if length_dim == 0:
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
if xs is not None and len(xs.shape) == 3:
if length_dim == 1:
lengths = lengths.unsqueeze(1).expand(*xs.transpose(1, 2).shape[:2])
else:
lengths = lengths.unsqueeze(1).expand(*xs.shape[:2])
if maxlen is not None:
View full source on GitHub →.forward(lengths, xs, length_dim, maxlen) L23Make mask tensor containing indices of padded part.
This implementation creates the same mask tensor with original make_pad_mask,
which can be converted into onnx format.
Dimension length of xs should be 2 or 3.
def forward(self, lengths, xs=None, length_dim=-1, maxlen=None):
    """Build a float mask that marks the padded part of each sequence.

    Gives the same result as the classic ``make_pad_mask`` helper. The mask is
    produced by indexing the precomputed ``self.mask_pad`` table instead of
    comparing against an ``arange``, so the graph can be exported to ONNX.

    Args:
        lengths: Integer tensor of sequence lengths. It is used as row indices
            ``lengths - 1`` into ``self.mask_pad``.
        xs: Optional reference tensor; only 2-D or 3-D tensors are supported.
            When 3-D, ``lengths`` is broadcast over its first two dimensions
            (after swapping dims 1 and 2 when ``length_dim == 1``).
        length_dim: Dimension treated as the length axis; 0 is rejected.
        maxlen: Optional explicit width of the mask's last dimension. Otherwise
            the width is ``xs.shape[-1]`` or ``max(lengths)``.

    Returns:
        A float32 tensor built from rows of ``self.mask_pad``. With the table
        built by ``flip=True`` (``1 - tri``), 1.0 marks padded positions.
        Dims 1 and 2 are swapped when ``length_dim == 1``.

    Raises:
        ValueError: If ``length_dim`` is 0.
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if xs is not None and len(xs.shape) == 3:
        # Broadcast one length per batch item across the second axis of xs.
        ref_shape = xs.transpose(1, 2).shape if length_dim == 1 else xs.shape
        lengths = lengths.unsqueeze(1).expand(*ref_shape[:2])

    if maxlen is not None:
        width = maxlen
    elif xs is not None:
        width = xs.shape[-1]
    else:
        width = torch.max(lengths)

    # Row (L - 1) of the lookup table is the mask for a sequence of length L.
    table_rows = self.mask_pad[lengths - 1]
    mask = table_rows[..., :width].type(torch.float32)
    return mask.transpose(1, 2) if length_dim == 1 else mask
MakePadMask.forward(lengths, xs, length_dim, maxlen)Make mask tensor containing indices of padded part.
This implementation creates the same mask tensor with original make_pad_mask,
which can be converted into onnx format.
Dimension length of xs should be 2 or 3.
def forward(self, lengths, xs=None, length_dim=-1, maxlen=None):
    """Build a float mask that marks the padded part of each sequence.

    Gives the same result as the classic ``make_pad_mask`` helper. The mask is
    produced by indexing the precomputed ``self.mask_pad`` table instead of
    comparing against an ``arange``, so the graph can be exported to ONNX.

    Args:
        lengths: Integer tensor of sequence lengths. It is used as row indices
            ``lengths - 1`` into ``self.mask_pad``.
        xs: Optional reference tensor; only 2-D or 3-D tensors are supported.
            When 3-D, ``lengths`` is broadcast over its first two dimensions
            (after swapping dims 1 and 2 when ``length_dim == 1``).
        length_dim: Dimension treated as the length axis; 0 is rejected.
        maxlen: Optional explicit width of the mask's last dimension. Otherwise
            the width is ``xs.shape[-1]`` or ``max(lengths)``.

    Returns:
        A float32 tensor built from rows of ``self.mask_pad``. With the table
        built by ``flip=True`` (``1 - tri``), 1.0 marks padded positions.
        Dims 1 and 2 are swapped when ``length_dim == 1``.

    Raises:
        ValueError: If ``length_dim`` is 0.
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if xs is not None and len(xs.shape) == 3:
        # Broadcast one length per batch item across the second axis of xs.
        ref_shape = xs.transpose(1, 2).shape if length_dim == 1 else xs.shape
        lengths = lengths.unsqueeze(1).expand(*ref_shape[:2])

    if maxlen is not None:
        width = maxlen
    elif xs is not None:
        width = xs.shape[-1]
    else:
        width = torch.max(lengths)

    # Row (L - 1) of the lookup table is the mask for a sequence of length L.
    table_rows = self.mask_pad[lengths - 1]
    mask = table_rows[..., :width].type(torch.float32)
    return mask.transpose(1, 2) if length_dim == 1 else mask
No documentation yet.
class sequence_mask(nn.Module):
    """Length mask module: entry ``[..., j]`` is 1 when ``j < length`` and 0 otherwise."""

    def __init__(self, max_seq_len=512, flip=True):
        """Create the module.

        Args:
            max_seq_len: Accepted for signature compatibility; unused.
            flip: Accepted for signature compatibility; unused.
        """
        super().__init__()

    def forward(self, lengths, max_seq_len=None, dtype=torch.float32, device=None):
        """Return a mask whose last axis marks the valid (non-padded) positions.

        Args:
            lengths: Integer tensor of sequence lengths.
            max_seq_len: Width of the mask's last axis; defaults to ``lengths.max()``.
            dtype: Dtype the boolean comparison is cast to (default float32).
            device: Optional device ("cuda:0", "cpu", ...) to move the result to.

        Returns:
            Tensor of shape ``lengths.shape + (max_seq_len,)``.
        """
        width = lengths.max() if max_seq_len is None else max_seq_len
        positions = torch.arange(0, width, 1).to(lengths.device)
        valid = positions < lengths.unsqueeze(-1)
        valid = valid.type(dtype)
        if device is not None:
            valid = valid.to(device)
        return valid
.forward(lengths, max_seq_len, dtype, device) L63Forward pass for training.
def forward(self, lengths, max_seq_len=None, dtype=torch.float32, device=None):
    """Return a mask whose last axis marks the valid (non-padded) positions.

    Entry ``[..., j]`` is 1 when ``j < length`` and 0 otherwise.

    Args:
        lengths: Integer tensor of sequence lengths.
        max_seq_len: Width of the mask's last axis; defaults to ``lengths.max()``.
        dtype: Dtype the boolean comparison is cast to (default float32).
        device: Optional target device ("cuda:0", "cpu", etc.) for the result.

    Returns:
        Tensor of shape ``lengths.shape + (max_seq_len,)``.
    """
    width = lengths.max() if max_seq_len is None else max_seq_len
    positions = torch.arange(0, width, 1).to(lengths.device)
    valid = positions < lengths.unsqueeze(-1)
    valid = valid.type(dtype)
    if device is not None:
        valid = valid.to(device)
    return valid
sequence_mask.forward(lengths, max_seq_len, dtype, device)Forward pass for training.
def forward(self, lengths, max_seq_len=None, dtype=torch.float32, device=None):
    """Return a mask whose last axis marks the valid (non-padded) positions.

    Entry ``[..., j]`` is 1 when ``j < length`` and 0 otherwise.

    Args:
        lengths: Integer tensor of sequence lengths.
        max_seq_len: Width of the mask's last axis; defaults to ``lengths.max()``.
        dtype: Dtype the boolean comparison is cast to (default float32).
        device: Optional target device ("cuda:0", "cpu", etc.) for the result.

    Returns:
        Tensor of shape ``lengths.shape + (max_seq_len,)``.
    """
    width = lengths.max() if max_seq_len is None else max_seq_len
    positions = torch.arange(0, width, 1).to(lengths.device)
    valid = positions < lengths.unsqueeze(-1)
    valid = valid.type(dtype)
    if device is not None:
        valid = valid.to(device)
    return valid
normalize(input, p, dim, out)Normalize.
def normalize(
    input: torch.Tensor, p: float = 2.0, dim: int = 1, out: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Divide ``input`` by its ``p``-norm taken along ``dim``.

    No epsilon is added, so slices whose norm is zero produce NaN/inf values.

    Args:
        input: Tensor to normalize.
        p: Exponent of the norm (2.0 = Euclidean).
        dim: Dimension along which the norm is computed.
        out: Optional tensor that receives the result via ``torch.div(..., out=out)``.

    Returns:
        The normalized tensor (``out`` itself when it is given).
    """
    # The norm is kept as a size-1 axis and expanded back to input's shape.
    denom = input.norm(p, dim, keepdim=True).expand_as(input)
    if out is not None:
        return torch.div(input, denom, out=out)
    return input / denom
subsequent_mask(size)Subsequent mask.
def subsequent_mask(size: int, *, dtype=None, device=None):
    """Return a ``(size, size)`` lower-triangular matrix of ones.

    Row ``i`` has ones in columns ``0..i`` and zeros after, which is the usual
    shape of a causal ("no peeking ahead") mask.

    Args:
        size: Side length of the square mask.
        dtype: Optional dtype; ``None`` keeps torch's default floating dtype.
        device: Optional device on which to allocate the mask.

    Returns:
        torch.Tensor: The lower-triangular mask.
    """
    return torch.ones(size, size, dtype=dtype, device=device).tril()
MakePadMask_test()Makepadmask test.
def MakePadMask_test():
    """Smoke test: print the MakePadMask output for one sequence of length 10."""
    lengths = torch.tensor([10]).type(torch.long)
    pad_mask = MakePadMask()
    print(pad_mask(lengths))
str2bool(value)Str2bool.
def str2bool(value: str) -> bool:
    """Convert a truth-value string to ``bool`` (usable as an argparse ``type=``).

    Accepts the same spellings as the former ``distutils.util.strtobool``,
    case-insensitively. It is implemented inline because distutils was removed
    in Python 3.12.

    Args:
        value: "y", "yes", "t", "true", "on", "1" (True) or
            "n", "no", "f", "false", "off", "0" (False), in any letter case.

    Returns:
        The parsed boolean.

    Raises:
        ValueError: If ``value`` is not one of the accepted spellings.
    """
    normalized = value.lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return True
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError("invalid truth value %r" % (normalized,))
remove_parenthesis(value)Remove parenthesis.
def remove_parenthesis(value: str):
    """Strip whitespace, then drop one enclosing ``(...)`` or ``[...]`` pair if present.

    Parentheses are checked before square brackets, and at most one pair is removed.

    Args:
        value: String that may be wrapped in parentheses or brackets.

    Returns:
        The stripped string without its outermost bracket pair.
    """
    stripped = value.strip()
    for opener, closer in (("(", ")"), ("[", "]")):
        if stripped.startswith(opener) and stripped.endswith(closer):
            return stripped[1:-1]
    return stripped
remove_quotes(value)Remove quotes.
def remove_quotes(value: str):
    """Strip whitespace, then drop one enclosing pair of double or single quotes.

    Double quotes are checked before single quotes, and at most one pair is removed.

    Args:
        value: String that may be wrapped in quotes.

    Returns:
        The stripped string without its outermost matching quote pair.
    """
    stripped = value.strip()
    for quote in ('"', "'"):
        if stripped.startswith(quote) and stripped.endswith(quote):
            return stripped[1:-1]
    return stripped
int_or_none(value)int_or_none.
>>> import argparse
>>> parser = argparse.ArgumentParser()
>>> _ = parser.add_argument('--foo', type=int_or_none)
>>> parser.parse_args(['--foo', '456'])
Namespace(foo=456)
>>> parser.parse_args(['--foo', 'none'])
Namespace(foo=None)
>>> parser.parse_args(['--foo', 'null'])
Namespace(foo=None)
>>> parser.parse_args(['--foo', 'nil'])
Namespace(foo=None)
def int_or_none(value: str) -> Optional[int]:
    """Parse ``value`` as an int, mapping "none"/"null"/"nil" to ``None``.

    The null keywords are matched case-insensitively after stripping whitespace.
    Intended as an ``argparse`` ``type=`` callable.

    Examples:
        >>> import argparse
        >>> parser = argparse.ArgumentParser()
        >>> _ = parser.add_argument('--foo', type=int_or_none)
        >>> parser.parse_args(['--foo', '456'])
        Namespace(foo=456)
        >>> parser.parse_args(['--foo', 'none'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'null'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'nil'])
        Namespace(foo=None)
    """
    is_null = value.strip().lower() in ("none", "null", "nil")
    return None if is_null else int(value)
float_or_none(value)float_or_none.
>>> import argparse
>>> parser = argparse.ArgumentParser()
>>> _ = parser.add_argument('--foo', type=float_or_none)
>>> parser.parse_args(['--foo', '4.5'])
Namespace(foo=4.5)
>>> parser.parse_args(['--foo', 'none'])
Namespace(foo=None)
>>> parser.parse_args(['--foo', 'null'])
Namespace(foo=None)
>>> parser.parse_args(['--foo', 'nil'])
Namespace(foo=None)
def float_or_none(value: str) -> Optional[float]:
    """Parse ``value`` as a float, mapping "none"/"null"/"nil" to ``None``.

    The null keywords are matched case-insensitively after stripping whitespace.
    Intended as an ``argparse`` ``type=`` callable.

    Examples:
        >>> import argparse
        >>> parser = argparse.ArgumentParser()
        >>> _ = parser.add_argument('--foo', type=float_or_none)
        >>> parser.parse_args(['--foo', '4.5'])
        Namespace(foo=4.5)
        >>> parser.parse_args(['--foo', 'none'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'null'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'nil'])
        Namespace(foo=None)
    """
    is_null = value.strip().lower() in ("none", "null", "nil")
    return None if is_null else float(value)
humanfriendly_parse_size_or_none(value)Humanfriendly parse size or none.
def humanfriendly_parse_size_or_none(value) -> Optional[float]:
    """Parse a human-readable size with ``humanfriendly.parse_size``, or return ``None``.

    "none"/"null"/"nil" (case-insensitive, surrounding whitespace ignored) map
    to ``None``. Any other string is handed to ``humanfriendly.parse_size``.

    Args:
        value: Size string such as "10k", or one of the null keywords.
    """
    if value.strip().lower() not in ("none", "null", "nil"):
        return humanfriendly.parse_size(value)
    return None
str_or_int(value)Str or int.
def str_or_int(value: str) -> Union[str, int]:
    """Return ``int(value)`` when it parses as an integer, else ``value`` unchanged.

    Args:
        value: Candidate string.
    """
    try:
        parsed = int(value)
    except ValueError:
        return value
    return parsed
str_or_none(value)str_or_none.
>>> import argparse
>>> parser = argparse.ArgumentParser()
>>> _ = parser.add_argument('--foo', type=str_or_none)
>>> parser.parse_args(['--foo', 'aaa'])
Namespace(foo='aaa')
>>> parser.parse_args(['--foo', 'none'])
Namespace(foo=None)
>>> parser.parse_args(['--foo', 'null'])
Namespace(foo=None)
>>> parser.parse_args(['--foo', 'nil'])
Namespace(foo=None)
def str_or_none(value: str) -> Optional[str]:
    """Return ``value`` unchanged, or ``None`` for "none"/"null"/"nil".

    The null keywords are matched case-insensitively after stripping whitespace.
    Other values are returned as-is, without stripping.

    Examples:
        >>> import argparse
        >>> parser = argparse.ArgumentParser()
        >>> _ = parser.add_argument('--foo', type=str_or_none)
        >>> parser.parse_args(['--foo', 'aaa'])
        Namespace(foo='aaa')
        >>> parser.parse_args(['--foo', 'none'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'null'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'nil'])
        Namespace(foo=None)
    """
    if value.strip().lower() not in ("none", "null", "nil"):
        return value
    return None
str2pair_str(value)str2pair_str.
>>> import argparse
>>> str2pair_str('abc,def ')
('abc', 'def')
>>> parser = argparse.ArgumentParser()
>>> _ = parser.add_argument('--foo', type=str2pair_str)
>>> parser.parse_args(['--foo', 'abc,def'])
Namespace(foo=('abc', 'def'))
def str2pair_str(value: str) -> Tuple[str, str]:
    """Split a "a,b" string (optionally wrapped in brackets) into a 2-tuple of strings.

    Raises ``ValueError`` when the string does not contain exactly one comma.

    Examples:
        >>> import argparse
        >>> str2pair_str('abc,def ')
        ('abc', 'def')
        >>> parser = argparse.ArgumentParser()
        >>> _ = parser.add_argument('--foo', type=str2pair_str)
        >>> parser.parse_args(['--foo', 'abc,def'])
        Namespace(foo=('abc', 'def'))
    """
    first, second = remove_parenthesis(value).split(",")
    # Workaround for configargparse: when the values come from a YAML file,
    # the string handed to type() looks like a Python list, e.g. ['a', 'b', 'c'],
    # so each item may still carry quotes that must be stripped.
    return remove_quotes(first), remove_quotes(second)
str2triple_str(value)str2triple_str.
>>> str2triple_str('abc,def ,ghi')
('abc', 'def', 'ghi')
def str2triple_str(value: str) -> Tuple[str, str, str]:
    """Split a "a,b,c" string (optionally wrapped in brackets) into a 3-tuple of strings.

    Raises ``ValueError`` when the string does not contain exactly two commas.

    Examples:
        >>> str2triple_str('abc,def ,ghi')
        ('abc', 'def', 'ghi')
    """
    head, middle, tail = remove_parenthesis(value).split(",")
    # Workaround for configargparse: when the values come from a YAML file,
    # the string handed to type() looks like a Python list, e.g. ['a', 'b', 'c'],
    # so each item may still carry quotes that must be stripped.
    return remove_quotes(head), remove_quotes(middle), remove_quotes(tail)
slice_padding_fbank(speech, speech_lengths, vad_segments)Slice padding fbank.
def slice_padding_fbank(speech, speech_lengths, vad_segments):
    """Cut VAD segments out of the first utterance and pad them into one batch.

    Segment boundaries are in milliseconds and are turned into sample indices
    by multiplying by 16 (i.e. 16 samples per ms, presumably 16 kHz audio).
    The end index is clamped to ``speech_lengths[0]``.

    Args:
        speech: Audio tensor, shape (batch, time); only row 0 is sliced.
        speech_lengths: Lengths per utterance; element 0 bounds the slices.
        vad_segments: Iterable whose items have ``item[0] == [start_ms, end_ms]``.

    Returns:
        tuple: (zero-padded batch of segments with ``batch_first=True``,
        int32 tensor of segment lengths).
    """
    chunks = []
    chunk_lengths = []
    for segment in vad_segments:
        start = int(segment[0][0] * 16)
        stop = min(int(segment[0][1] * 16), speech_lengths[0])
        chunks.append(speech[0, start:stop])
        chunk_lengths.append(stop - start)
    padded = pad_sequence(chunks, batch_first=True, padding_value=0.0)
    return padded, torch.Tensor(chunk_lengths).int()
slice_padding_audio_samples(speech, speech_lengths, vad_segments)Slice audio into VAD segments with proper padding.
def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
    """Slice a 1-D audio signal into the given VAD segments.

    Segment boundaries are in milliseconds and are turned into sample indices
    by multiplying by 16 (i.e. 16 samples per ms, presumably 16 kHz audio).
    The end index is clamped to ``speech_lengths``. No padding is applied.

    Args:
        speech (Tensor): Full audio signal (any sliceable 1-D sequence).
        speech_lengths (int): Total audio length in samples.
        vad_segments (list): List of (segment_info, original_index) tuples,
            where segment_info is [start_ms, end_ms].

    Returns:
        tuple: (speech_list, speech_lengths_list) - the slices of ``speech`` and
        their lengths in samples.
    """
    pieces = []
    piece_lengths = []
    for segment in vad_segments:
        span = segment[0]
        start = int(span[0] * 16)
        stop = min(int(span[1] * 16), speech_lengths)
        pieces.append(speech[start:stop])
        piece_lengths.append(stop - start)
    return pieces, piece_lengths
merge_vad(vad_result, max_length, min_length)Merge short VAD segments to reduce fragmentation.
vad_result (list) — VAD segments [[start_ms, end_ms], ...].max_length (int) — Maximum merged segment length in ms (default 15000).min_length (int) — Minimum segment length; shorter ones get merged (default 0).list — Merged VAD segments [[start_ms, end_ms], ...].def merge_vad(vad_result, max_length=15000, min_length=0):
"""Merge short VAD segments to reduce fragmentation.
Args:
vad_result (list): VAD segments [[start_ms, end_ms], ...].
max_length (int): Maximum merged segment length in ms (default 15000).
min_length (int): Minimum segment length; shorter ones get merged (default 0).
Returns:
list: Merged VAD segments [[start_ms, end_ms], ...].
"""
new_result = []
if len(vad_result) <= 1:
return vad_result
time_step = [t[0] for t in vad_result] + [t[1] for t in vad_result]
time_step = sorted(list(set(time_step)))
if len(time_step) == 0:
return []
bg = 0
for i in range(len(time_step) - 1):
time = time_step[i]
if time_step[i + 1] - bg < max_length:
continue
if time - bg > min_length:
new_result.append([bg, time])
# if time - bg < max_length * 1.5:
# new_result.append([bg, time])
# else:
# split_num = int(time - bg) // max_length + 1
View full source on GitHub →get_pypi_version(package_name)Get pypi version.
def get_pypi_version(package_name, timeout=10):
    """Return the latest released version of ``package_name`` on PyPI.

    Args:
        package_name: Distribution name on PyPI (e.g. "funasr").
        timeout: Seconds to wait for PyPI before giving up. Without a timeout,
            ``requests.get`` may block indefinitely on a stalled connection.

    Returns:
        ``version.parse`` of the ``info.version`` field in the PyPI JSON API reply.

    Raises:
        Exception: If PyPI answers with a non-200 status.
        requests.RequestException: On connection errors or timeout.
    """
    import requests

    url = f"https://pypi.org/pypi/{package_name}/json"
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        raise Exception("Failed to retrieve version information from PyPI.")
    data = response.json()
    return version.parse(data["info"]["version"])
check_for_update(disable)Check for update.
def check_for_update(disable=False):
    """Print the installed funasr version and, unless disabled, compare it with PyPI.

    The PyPI lookup is advisory only. If it fails (no network, PyPI error,
    timeout), a notice is printed and the function returns normally, so the
    caller is not aborted by an unreachable package index.

    Args:
        disable: When True, only print the installed version and skip the PyPI query.
    """
    current_version = version.parse(__version__)
    print(f"funasr version: {current_version}.")
    if disable:
        return
    print(
        "Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel"
    )
    try:
        pypi_version = get_pypi_version("funasr")
    except Exception as err:  # best-effort: an update check must not break the caller
        print(f"Could not check for funasr updates: {err}")
        return
    if current_version < pypi_version:
        print(f"New version is available: {pypi_version}.")
        print('Please use the command "pip install -U funasr" to upgrade.')
    else:
        print(f"You are using the latest version of funasr-{current_version}")
main_hydra(kwargs)Main hydra.
def main_hydra(kwargs: DictConfig):
    """Hydra entry point: resolve model files, then hand off to ``main``.

    If ``debug`` is truthy, execution first stops in ``pdb``. When the config
    lacks ``model_conf``, the model is fetched with ``download_model`` from the
    hub named by ``hub`` (default "ms").

    Args:
        kwargs: Hydra config; must contain ``model``.
    """
    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()
    assert "model" in kwargs
    needs_download = "model_conf" not in kwargs
    if needs_download:
        hub = kwargs.get("hub", "ms")
        logging.info("download models from model hub: {}".format(hub))
        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
    main(**kwargs)
main(**kwargs)Main.
**kwargs — Additional keyword arguments.def main(**kwargs):
"""Main.
Args:
**kwargs: Additional keyword arguments.
"""
print(kwargs)
# set random seed
# tables.print()
set_all_random_seed(kwargs.get("seed", 0))
torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
tokenizer = kwargs.get("tokenizer", None)
# build frontend if frontend is none None
frontend = kwargs.get("frontend", None)
if frontend is not None:
frontend_class = tables.frontend_classes.get(frontend)
frontend = frontend_class(**kwargs["frontend_conf"])
kwargs["frontend"] = frontend
kwargs["input_size"] = frontend.output_size()
# dataset
dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
dataset_train = dataset_class(
kwargs.get("train_data_set_list"),
frontend=frontend,
tokenizer=None,
View full source on GitHub →main_hydra(cfg)Main hydra.
cfg — Configuration overrides.def main_hydra(cfg: DictConfig):
"""Main hydra.
Args:
cfg: Configuration overrides.
"""
def to_plain_list(cfg_item):
"""To plain list.
Args:
cfg_item: TODO.
"""
if isinstance(cfg_item, ListConfig):
return OmegaConf.to_container(cfg_item, resolve=True)
elif isinstance(cfg_item, DictConfig):
return {k: to_plain_list(v) for k, v in cfg_item.items()}
else:
return cfg_item
kwargs = to_plain_list(cfg)
if kwargs.get("debug", False):
import pdb
pdb.set_trace()
if "device" not in kwargs:
kwargs["device"] = "cpu"
model = AutoModel(**kwargs)
View full source on GitHub →main_hydra(cfg)Main hydra.
def main_hydra(cfg: DictConfig):
    """Run ``AutoModel`` inference from a Hydra/OmegaConf config and print the result.

    Args:
        cfg: Hydra config; must contain ``input`` plus whatever ``AutoModel`` needs.
            If ``debug`` is truthy, execution first stops in ``pdb``.
    """

    def _plain(node):
        """Recursively convert OmegaConf containers into plain lists/dicts."""
        if isinstance(node, ListConfig):
            return OmegaConf.to_container(node, resolve=True)
        if isinstance(node, DictConfig):
            return {key: _plain(val) for key, val in node.items()}
        return node

    kwargs = _plain(cfg)
    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()
    model = AutoModel(**kwargs)
    print(model.generate(input=kwargs["input"]))
field2slice(field)Convert field string to slice
Note that field string accepts 1-based integer.
>>> field2slice("1-")
slice(0, None, None)
>>> field2slice("1-3")
slice(0, 3, None)
>>> field2slice("-3")
slice(None, 3, None)
def field2slice(field: Optional[str]) -> slice:
"""Convert field string to slice
Note that field string accepts 1-based integer.
Examples:
>>> field2slice("1-")
slice(0, None, None)
>>> field2slice("1-3")
slice(0, 3, None)
>>> field2slice("-3")
slice(None, 3, None)
"""
field = field.strip()
try:
if "-" in field:
# e.g. "2-" or "2-5" or "-7"
s1, s2 = field.split("-", maxsplit=1)
if s1.strip() == "":
s1 = None
else:
s1 = int(s1)
if s1 == 0:
raise ValueError("1-based string")
if s2.strip() == "":
s2 = None
else:
s2 = int(s2)
else:
# e.g. "2"
View full source on GitHub →tokenize(input, output, field, delimiter, token_type, space_symbol, non_linguistic_symbols, bpemodel, log_level, write_vocabulary, vocabulary_size, remove_non_linguistic_symbols, cutoff, add_symbol, cleaner, g2p)Tokenize.
input — Input audio/text data.output — TODO.field — TODO.delimiter — TODO.token_type — TODO.space_symbol — TODO.non_linguistic_symbols — TODO.bpemodel — TODO.log_level — TODO.write_vocabulary — TODO.vocabulary_size — Size/dimension parameter.remove_non_linguistic_symbols — TODO.cutoff — TODO.add_symbol — TODO.cleaner — TODO.g2p — TODO.def tokenize(
input: str,
output: str,
field: Optional[str],
delimiter: Optional[str],
token_type: str,
space_symbol: str,
non_linguistic_symbols: Optional[str],
bpemodel: Optional[str],
log_level: str,
write_vocabulary: bool,
vocabulary_size: int,
remove_non_linguistic_symbols: bool,
cutoff: int,
add_symbol: List[str],
cleaner: Optional[str],
g2p: Optional[str],
):
"""Tokenize.
Args:
input: Input audio/text data.
output: TODO.
field: TODO.
delimiter: TODO.
token_type: TODO.
space_symbol: TODO.
non_linguistic_symbols: TODO.
bpemodel: TODO.
View full source on GitHub →get_parser()Get parser.
def get_parser() -> argparse.ArgumentParser:
"""Get parser."""
parser = argparse.ArgumentParser(
description="Tokenize texts",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--input", "-i", required=True, help="Input text. - indicates sys.stdin")
parser.add_argument("--output", "-o", required=True, help="Output text. - indicates sys.stdout")
parser.add_argument(
"--field",
"-f",
help="The target columns of the input text as 1-based integer. e.g 2-",
)
parser.add_argument(
"--token_type",
"-t",
default="char",
choices=["char", "bpe", "word", "phn"],
help="Token type",
)
parser.add_argument("--delimiter", "-d", default=None, help="The delimiter")
parser.add_argument("--space_symbol", default="<space>", help="The space symbol")
View full source on GitHub →main(cmd)Main.
def main(cmd=None):
    """Command-line entry point: parse arguments and run ``tokenize`` with them.

    Args:
        cmd: Optional list of argument strings; ``None`` lets argparse read
            ``sys.argv[1:]``.
    """
    print(get_commandline_args(), file=sys.stderr)
    args = get_parser().parse_args(cmd)
    tokenize(**vars(args))
main_hydra(kwargs)Main hydra.
def main_hydra(kwargs: DictConfig):
    """Hydra entry point: resolve model files, then hand off to ``main``.

    If ``debug`` is truthy, execution first stops in ``pdb``. When the config
    lacks ``model_conf``, the model is fetched with ``download_model`` from the
    hub named by ``hub`` (default "ms").

    Args:
        kwargs: Hydra config; must contain ``model``.
    """
    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()
    assert "model" in kwargs
    needs_download = "model_conf" not in kwargs
    if needs_download:
        hub = kwargs.get("hub", "ms")
        logging.info("download models from model hub: {}".format(hub))
        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
    main(**kwargs)
main(**kwargs)Main.
**kwargs — Additional keyword arguments.def main(**kwargs):
# set random seed
"""Main.
Args:
**kwargs: Additional keyword arguments.
"""
set_all_random_seed(kwargs.get("seed", 0))
torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
# open tf32
torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
if local_rank == 0:
tables.print()
# Check if we are using DDP or FSDP
use_ddp = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
use_fsdp = kwargs.get("use_fsdp", False)
# use_ddp = False if use_fsdp else use_fsdp
if use_ddp or use_fsdp:
dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://")
torch.cuda.set_device(local_rank)
logging.info("Build model, frontend, tokenizer")
device = kwargs.get("device", "cuda")
kwargs["device"] = "cpu"
model = AutoModel(**kwargs)
View full source on GitHub →main_hydra(kwargs)Main hydra.
def main_hydra(kwargs: DictConfig):
    """Hydra entry point: resolve model files, then hand off to ``main``.

    If ``debug`` is truthy, execution first stops in ``pdb``. When the config
    lacks ``model_conf``, the model is fetched with ``download_model`` from the
    hub named by ``hub`` (default "ms").

    Args:
        kwargs: Hydra config; must contain ``model``.
    """
    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()
    assert "model" in kwargs
    needs_download = "model_conf" not in kwargs
    if needs_download:
        hub = kwargs.get("hub", "ms")
        logging.info("download models from model hub: {}".format(hub))
        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
    main(**kwargs)
main(**kwargs)Main.
**kwargs — Additional keyword arguments.def main(**kwargs):
# set random seed
"""Main.
Args:
**kwargs: Additional keyword arguments.
"""
set_all_random_seed(kwargs.get("seed", 0))
torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
# open tf32
torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
if local_rank == 0:
tables.print()
use_ddp = world_size > 1
use_fsdp = kwargs.get("use_fsdp", False)
use_deepspeed = kwargs.get("use_deepspeed", False)
if use_deepspeed:
logging.info(f"use_deepspeed: {use_deepspeed}")
deepspeed.init_distributed(dist_backend=kwargs.get("backend", "nccl"))
elif use_ddp or use_fsdp:
logging.info(f"use_ddp: {use_ddp}, use_fsdp: {use_fsdp}")
View full source on GitHub →download_dataset()Download dataset.
def download_dataset():
    """Placeholder for dataset download; currently a no-op that returns None."""
download_dataset_from_ms(**kwargs)Download dataset from ms.
def download_dataset_from_ms(**kwargs):
    """Load a dataset from ModelScope with ``MsDataset.load`` and return it.

    Args:
        **kwargs: Optional overrides:
            dataset_name: ModelScope dataset id
                (default "speech_asr/speech_asr_aishell1_trainsets").
            subset_name: Subset to load (default "default").
            split: Split to load (default "train").
            data_dump_dir: Passed to ``MsDataset.load`` as ``cache_dir`` (default None).

    Returns:
        The object returned by ``MsDataset.load``.
    """
    from modelscope.msdatasets import MsDataset

    dataset_name = kwargs.get("dataset_name", "speech_asr/speech_asr_aishell1_trainsets")
    subset_name = kwargs.get("subset_name", "default")
    split = kwargs.get("split", "train")
    data_dump_dir = kwargs.get("data_dump_dir", None)
    ds = MsDataset.load(
        dataset_name=dataset_name, subset_name=subset_name, split=split, cache_dir=data_dump_dir
    )
    return ds
download_model(**kwargs)Download model from hub and parse its configuration.
Resolves model name aliases, downloads from ModelScope or HuggingFace,
reads config.yaml and configuration.json, and returns complete kwargs
for model instantiation.
**kwargs — Must include 'model' (str). Optional: 'hub', 'model_revision',dict — Complete kwargs with resolved paths, model class name, and config.def download_model(**kwargs):
"""Download model from hub and parse its configuration.
Resolves model name aliases, downloads from ModelScope or HuggingFace,
reads config.yaml and configuration.json, and returns complete kwargs
for model instantiation.
Args:
**kwargs: Must include 'model' (str). Optional: 'hub', 'model_revision',
'is_training', etc.
Returns:
dict: Complete kwargs with resolved paths, model class name, and config.
"""
hub = kwargs.get("hub", "ms")
if hub == "ms" or hub == "modelscope":
kwargs = download_from_ms(**kwargs)
elif hub == "hf" or hub == "huggingface":
kwargs = download_from_hf(**kwargs)
elif hub == "openai":
model_or_path = kwargs.get("model")
if os.path.exists(model_or_path):
# local path
kwargs["model_path"] = model_or_path
kwargs["model"] = "WhisperWarp"
else:
# model name
if model_or_path in name_maps_openai:
model_or_path = name_maps_openai[model_or_path]
View full source on GitHub →download_from_ms(**kwargs)Download from ms.
**kwargs — Additional keyword arguments.def download_from_ms(**kwargs):
"""Download from ms.
Args:
**kwargs: Additional keyword arguments.
"""
model_or_path = kwargs.get("model")
if model_or_path in name_maps_ms:
model_or_path = name_maps_ms[model_or_path]
model_revision = kwargs.get("model_revision", "master")
if not os.path.exists(model_or_path) and "model_path" not in kwargs:
try:
model_or_path = get_or_download_model_dir(
model_or_path,
model_revision,
is_training=kwargs.get("is_training"),
check_latest=kwargs.get("check_latest", True),
)
except Exception as e:
print(f"Download: {model_or_path} failed!: {e}")
kwargs["model_path"] = model_or_path if "model_path" not in kwargs else kwargs["model_path"]
model_or_path = kwargs["model_path"]
if os.path.exists(os.path.join(model_or_path, "configuration.json")):
with open(os.path.join(model_or_path, "configuration.json"), "r", encoding="utf-8") as f:
conf_json = json.load(f)
cfg = {}
if "file_path_metas" in conf_json:
add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
View full source on GitHub →download_from_hf(**kwargs)Download from hf.
**kwargs — Additional keyword arguments.def download_from_hf(**kwargs):
"""Download from hf.
Args:
**kwargs: Additional keyword arguments.
"""
model_or_path = kwargs.get("model")
if model_or_path in name_maps_hf:
model_or_path = name_maps_hf[model_or_path]
model_revision = kwargs.get("model_revision", "master")
if not os.path.exists(model_or_path) and "model_path" not in kwargs:
try:
model_or_path = get_or_download_model_dir_hf(
model_or_path,
model_revision,
is_training=kwargs.get("is_training"),
check_latest=kwargs.get("check_latest", True),
)
except Exception as e:
print(f"Download: {model_or_path} failed!: {e}")
kwargs["model_path"] = model_or_path if "model_path" not in kwargs else kwargs["model_path"]
if os.path.exists(os.path.join(model_or_path, "configuration.json")):
with open(os.path.join(model_or_path, "configuration.json"), "r", encoding="utf-8") as f:
conf_json = json.load(f)
cfg = {}
if "file_path_metas" in conf_json:
add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
View full source on GitHub →add_file_root_path(model_or_path, file_path_metas, cfg)Add file root path.
model_or_path — TODO.file_path_metas — TODO.cfg — Configuration overrides.def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg={}):
"""Add file root path.
Args:
model_or_path: TODO.
file_path_metas: TODO.
cfg: Configuration overrides.
"""
if isinstance(file_path_metas, dict):
if isinstance(cfg, list):
cfg.append({})
for k, v in file_path_metas.items():
if isinstance(v, str):
p = os.path.join(model_or_path, v)
if os.path.exists(p):
if isinstance(cfg, dict):
cfg[k] = p
elif isinstance(cfg, list):
# if len(cfg) == 0:
# cfg.append({})
cfg[-1][k] = p
elif isinstance(v, dict):
if isinstance(cfg, dict):
if k not in cfg:
cfg[k] = {}
add_file_root_path(model_or_path, v, cfg[k])
# elif isinstance(cfg, list):
View full source on GitHub →get_or_download_model_dir(model, model_revision, is_training, check_latest)Get local model directory or download model if necessary.
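model (str) — model id or path to local model directory. model_revision (str, optional) — model version number. is_training (bool, optional) — selects the ModelScope invocation key (local trainer vs. pipeline) reported when checking or downloading the model.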
def get_or_download_model_dir(
model,
model_revision=None,
is_training=False,
check_latest=True,
):
"""Get local model directory or download model if necessary.
Args:
model (str): model id or path to local model directory.
model_revision (str, optional): model version number.
:param is_training:
"""
from modelscope.hub.check_model import check_local_model_is_latest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.constant import Invoke, ThirdParty
key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE
if os.path.exists(model) and check_latest:
model_cache_dir = model if os.path.isdir(model) else os.path.dirname(model)
try:
check_local_model_is_latest(
model_cache_dir, user_agent={Invoke.KEY: key, ThirdParty.KEY: "funasr"}
)
except:
print("could not check the latest version")
else:
model_cache_dir = snapshot_download(
View full source on GitHub →get_or_download_model_dir_hf(model, model_revision, is_training, check_latest)Get local model directory or download model if necessary.
model (str) — model id or path to local model directory.model_revision (str, optional) — model version number.:param is_training:
def get_or_download_model_dir_hf(
    model,
    model_revision=None,
    is_training=False,
    check_latest=True,
):
    """Return a local directory for ``model``, downloading it from the Hugging Face Hub if needed.

    Args:
        model (str): Hugging Face repo id, or a path to an existing local model
            directory (or to a file inside one).
        model_revision (str, optional): Branch, tag or commit to download.
            ``None`` uses the Hub's default revision.
        is_training (bool): Accepted for signature parity with
            ``get_or_download_model_dir``; unused here.
        check_latest (bool): Accepted for signature parity with
            ``get_or_download_model_dir``; unused here.

    Returns:
        str: Path of the local model directory.
    """
    # Mirror get_or_download_model_dir: an existing local path is used as-is
    # instead of being sent to the Hub as a repo id.
    if os.path.exists(model):
        return model if os.path.isdir(model) else os.path.dirname(model)

    # Imported lazily so huggingface_hub is only needed when downloading.
    from huggingface_hub import snapshot_download

    # Forward the requested revision; it used to be silently ignored.
    return snapshot_download(model, revision=model_revision)
download_from_url(url)Download from url.
def download_from_url(url):
    """Download ``url`` into a fresh temporary directory and return the local file path.

    Args:
        url (str): URL with a scheme (e.g. ``http://`` or ``https://``). The
            content is fetched with ``HTTPStorage.read``.

    Returns:
        str: Path of the downloaded file. The file name is the last component
        of the URL *path*, so query strings and fragments are excluded.

    Raises:
        AssertionError: If ``url`` has no scheme.
    """
    result = urlparse(url)
    file_path = None
    if result.scheme:
        storage = HTTPStorage()
        data = storage.read(url)  # raw bytes
        # mkdtemp creates a persistent, uniquely named directory. The former
        # ``tempfile.TemporaryDirectory().name`` dropped the owning object at
        # once, so its finalizer deleted the directory that was then recreated.
        work_dir = tempfile.mkdtemp()
        # basename(url) kept "?query"/"#fragment" text in the name, which
        # corrupts the file extension that callers inspect.
        file_name = os.path.basename(result.path) or "downloaded_file"
        file_path = os.path.join(work_dir, file_name)
        with open(file_path, "wb") as fb:
            fb.write(data)
    assert file_path is not None, f"failed to download: {url}"
    return file_path
Abstract class of storage.
All backends need to implement two apis: ``read()`` and ``read_text()``.
``read()`` reads the file as a byte stream and ``read_text()`` reads
the file as texts.
class Storage(metaclass=ABCMeta):
"""Abstract class of storage.
All backends need to implement two apis: ``read()`` and ``read_text()``.
``read()`` reads the file as a byte stream and ``read_text()`` reads
the file as texts.
"""
@abstractmethod
def read(self, filepath: str):
"""Read.
Args:
filepath: TODO.
"""
pass
@abstractmethod
def read_text(self, filepath: str):
"""Read text.
Args:
filepath: TODO.
"""
pass
@abstractmethod
def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
"""Write.
View full source on GitHub →.read(filepath) L45Read.
filepath — TODO. def read(self, filepath: str):
"""Read the content of ``filepath``.
Abstract interface: this base implementation only ``pass``es and so
returns ``None``; backends such as LocalStorage return the content as bytes.
Args:
filepath: Path or URI of the file to read.
"""
pass
.read_text(filepath) L54Read text.
filepath — TODO. def read_text(self, filepath: str):
"""Read the content of ``filepath`` as text.
Abstract interface: this base implementation only ``pass``es and so
returns ``None``; backends such as LocalStorage return a decoded ``str``.
Args:
filepath: Path or URI of the file to read.
"""
pass
.write(obj, filepath) L63Write.
obj — TODO.filepath — TODO. def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
"""Write the bytes ``obj`` to ``filepath``.
Abstract interface: this base implementation does nothing.
Args:
obj: Bytes to write.
filepath: Destination path or URI.
"""
pass
.write_text(obj, filepath, encoding) L73Write text.
obj — TODO.filepath — TODO.encoding — TODO. def write_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
"""Write the string ``obj`` to ``filepath`` as text.
Abstract interface: this base implementation does nothing.
Args:
obj: Text to write.
filepath: Destination path or URI.
encoding: Text encoding used when writing. Default: 'utf-8'.
"""
pass
Storage.read(filepath)Read.
filepath — TODO. def read(self, filepath: str):
"""Read the content of ``filepath``.
Abstract interface: this base implementation only ``pass``es and so
returns ``None``; backends such as LocalStorage return the content as bytes.
Args:
filepath: Path or URI of the file to read.
"""
pass
Storage.read_text(filepath)Read text.
filepath — TODO. def read_text(self, filepath: str):
"""Read the content of ``filepath`` as text.
Abstract interface: this base implementation only ``pass``es and so
returns ``None``; backends such as LocalStorage return a decoded ``str``.
Args:
filepath: Path or URI of the file to read.
"""
pass
Storage.write(obj, filepath)Write.
obj — TODO.filepath — TODO. def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
"""Write the bytes ``obj`` to ``filepath``.
Abstract interface: this base implementation does nothing.
Args:
obj: Bytes to write.
filepath: Destination path or URI.
"""
pass
Storage.write_text(obj, filepath, encoding)Write text.
obj — TODO.filepath — TODO.encoding — TODO. def write_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
"""Write the string ``obj`` to ``filepath`` as text.
Abstract interface: this base implementation does nothing.
Args:
obj: Text to write.
filepath: Destination path or URI.
encoding: Text encoding used when writing. Default: 'utf-8'.
"""
pass
Local hard disk storage
class LocalStorage(Storage):
"""Local hard disk storage"""
def read(self, filepath: Union[str, Path]) -> bytes:
"""Read data from a given ``filepath`` with 'rb' mode.
Args:
filepath (str or Path): Path to read data.
Returns:
bytes: Expected bytes object.
"""
with open(filepath, "rb") as f:
content = f.read()
return content
def read_text(self, filepath: Union[str, Path], encoding: str = "utf-8") -> str:
"""Read data from a given ``filepath`` with 'r' mode.
Args:
filepath (str or Path): Path to read data.
encoding (str): The encoding format used to open the ``filepath``.
Default: 'utf-8'.
Returns:
str: Expected text reading from ``filepath``.
"""
with open(filepath, "r", encoding=encoding) as f:
value_buf = f.read()
return value_buf
View full source on GitHub →.read(filepath) L87Read data from a given ``filepath`` with 'rb' mode.
filepath (str or Path) — Path to read data.bytes — Expected bytes object. def read(self, filepath: Union[str, Path]) -> bytes:
"""Read the whole local file at ``filepath`` in binary ('rb') mode.
Args:
filepath (str or Path): Path of the local file to read.
Returns:
bytes: The complete file content.
Raises:
OSError: If the file cannot be opened (e.g. FileNotFoundError).
"""
with open(filepath, "rb") as f:
content = f.read()
return content
.read_text(filepath, encoding) L100Read data from a given ``filepath`` with 'r' mode.
filepath (str or Path) — Path to read data.encoding (str) — The encoding format used to open the ``filepath``.Default — 'utf-8'.str — Expected text reading from ``filepath``. def read_text(self, filepath: Union[str, Path], encoding: str = "utf-8") -> str:
"""Read the whole local file at ``filepath`` in text ('r') mode.
Args:
filepath (str or Path): Path of the local file to read.
encoding (str): Encoding used to decode the file. Default: 'utf-8'.
Returns:
str: The complete decoded file content.
Raises:
OSError: If the file cannot be opened (e.g. FileNotFoundError).
"""
with open(filepath, "r", encoding=encoding) as f:
value_buf = f.read()
return value_buf
.write(obj, filepath) L115Write data to a given ``filepath`` with 'wb' mode.
``write`` will create a directory if the directory of ``filepath``
does not exist.
obj (bytes) — Data to be written.filepath (str or Path) — Path to write data. def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
"""Write ``obj`` to ``filepath`` in binary ('wb') mode, replacing any existing content.
Note:
Missing parent directories of ``filepath`` are created first.
Args:
obj (bytes): Data to be written.
filepath (str or Path): Path to write data.
"""
# dirname is empty for a bare file name (current directory), so a
# directory is only created when the path actually names one.
dirname = os.path.dirname(filepath)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname, exist_ok=True)
with open(filepath, "wb") as f:
f.write(obj)
.write_text(obj, filepath, encoding) L133Write data to a given ``filepath`` with 'w' mode.
``write_text`` will create a directory if the directory of
``filepath`` does not exist.
obj (str) — Data to be written.filepath (str or Path) — Path to write data.encoding (str) — The encoding format used to open the ``filepath``.Default — 'utf-8'. def write_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
"""Write ``obj`` to ``filepath`` in text ('w') mode, replacing any existing content.
Note:
Missing parent directories of ``filepath`` are created first.
Args:
obj (str): Text to be written.
filepath (str or Path): Path to write data.
encoding (str): Encoding used to encode the text. Default: 'utf-8'.
"""
# dirname is empty for a bare file name (current directory), so a
# directory is only created when the path actually names one.
dirname = os.path.dirname(filepath)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname, exist_ok=True)
with open(filepath, "w", encoding=encoding) as f:
f.write(obj)
.as_local_path(filepath) L154Only for unified API and do nothing.
def as_local_path(self, filepath: Union[str, Path]) -> Generator[Union[str, Path], None, None]:
    """Yield ``filepath`` unchanged.

    Local files already live on disk, so nothing is downloaded or cleaned
    up; this exists only so every storage backend offers the same API.
    """
    already_local = filepath
    yield already_local
LocalStorage.read(filepath)Read data from a given ``filepath`` with 'rb' mode.
filepath (str or Path) — Path to read data.bytes — Expected bytes object. def read(self, filepath: Union[str, Path]) -> bytes:
"""Read the whole local file at ``filepath`` in binary ('rb') mode.
Args:
filepath (str or Path): Path of the local file to read.
Returns:
bytes: The complete file content.
Raises:
OSError: If the file cannot be opened (e.g. FileNotFoundError).
"""
with open(filepath, "rb") as f:
content = f.read()
return content
LocalStorage.read_text(filepath, encoding)Read data from a given ``filepath`` with 'r' mode.
filepath (str or Path) — Path to read data.encoding (str) — The encoding format used to open the ``filepath``.Default — 'utf-8'.str — Expected text reading from ``filepath``. def read_text(self, filepath: Union[str, Path], encoding: str = "utf-8") -> str:
"""Read the whole local file at ``filepath`` in text ('r') mode.
Args:
filepath (str or Path): Path of the local file to read.
encoding (str): Encoding used to decode the file. Default: 'utf-8'.
Returns:
str: The complete decoded file content.
Raises:
OSError: If the file cannot be opened (e.g. FileNotFoundError).
"""
with open(filepath, "r", encoding=encoding) as f:
value_buf = f.read()
return value_buf
LocalStorage.write(obj, filepath)Write data to a given ``filepath`` with 'wb' mode.
``write`` will create a directory if the directory of ``filepath``
does not exist.
obj (bytes) — Data to be written.filepath (str or Path) — Path to write data. def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
"""Write ``obj`` to ``filepath`` in binary ('wb') mode, replacing any existing content.
Note:
Missing parent directories of ``filepath`` are created first.
Args:
obj (bytes): Data to be written.
filepath (str or Path): Path to write data.
"""
# dirname is empty for a bare file name (current directory), so a
# directory is only created when the path actually names one.
dirname = os.path.dirname(filepath)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname, exist_ok=True)
with open(filepath, "wb") as f:
f.write(obj)
LocalStorage.write_text(obj, filepath, encoding)Write data to a given ``filepath`` with 'w' mode.
``write_text`` will create a directory if the directory of
``filepath`` does not exist.
obj (str) — Data to be written.filepath (str or Path) — Path to write data.encoding (str) — The encoding format used to open the ``filepath``.Default — 'utf-8'. def write_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
"""Write ``obj`` to ``filepath`` in text ('w') mode, replacing any existing content.
Note:
Missing parent directories of ``filepath`` are created first.
Args:
obj (str): Text to be written.
filepath (str or Path): Path to write data.
encoding (str): Encoding used to encode the text. Default: 'utf-8'.
"""
# dirname is empty for a bare file name (current directory), so a
# directory is only created when the path actually names one.
dirname = os.path.dirname(filepath)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname, exist_ok=True)
with open(filepath, "w", encoding=encoding) as f:
f.write(obj)
LocalStorage.as_local_path(filepath)Only for unified API and do nothing.
def as_local_path(self, filepath: Union[str, Path]) -> Generator[Union[str, Path], None, None]:
    """Yield ``filepath`` unchanged.

    Local files already live on disk, so nothing is downloaded or cleaned
    up; this exists only so every storage backend offers the same API.
    """
    already_local = filepath
    yield already_local
HTTP and HTTPS storage.
class HTTPStorage(Storage):
"""HTTP and HTTPS storage."""
def read(self, url):
# TODO @wenmeng.zwm add progress bar if file is too large
"""Read.
Args:
url: TODO.
"""
r = requests.get(url)
r.raise_for_status()
return r.content
def read_text(self, url):
"""Read text.
Args:
url: TODO.
"""
r = requests.get(url)
r.raise_for_status()
return r.text
@contextlib.contextmanager
def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
"""Download a file from ``filepath``.
``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
can be called with ``with`` statement, and when exists from the
View full source on GitHub →.read(url) L162Read.
def read(self, url, timeout=(30, 300)):
    # TODO @wenmeng.zwm add progress bar if file is too large
    """Fetch ``url`` with an HTTP GET and return the response body as bytes.

    Args:
        url (str): HTTP(S) URL to fetch.
        timeout (float or tuple, optional): Timeout forwarded to
            ``requests.get``: a ``(connect, read)`` pair in seconds, or
            ``None`` to wait indefinitely.

    Returns:
        bytes: The response body.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    # requests applies no timeout by default, so a stalled server would
    # otherwise block this call forever.
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return response.content
.read_text(url) L173Read text.
def read_text(self, url, timeout=(30, 300)):
    """Fetch ``url`` with an HTTP GET and return the response body as text.

    Args:
        url (str): HTTP(S) URL to fetch.
        timeout (float or tuple, optional): Timeout forwarded to
            ``requests.get``: a ``(connect, read)`` pair in seconds, or
            ``None`` to wait indefinitely.

    Returns:
        str: The response body as decoded by ``requests`` (``Response.text``).

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    # requests applies no timeout by default, so a stalled server would
    # otherwise block this call forever.
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return response.text
.as_local_path(filepath) L184Download a file from ``filepath``.
``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
can be used in a ``with`` statement, and when exiting the
``with`` statement, the temporary path will be released.
filepath (str) — Download a file from ``filepath``.>>> storage = HTTPStorage()
>>> # After exiting the ``with`` clause,
>>> # the path will be removed
>>> with storage.as_local_path('http://path/to/file') as path:
... # do something here
def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
    """Download ``filepath`` into a temporary local file and yield its path.

    ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
    can be used in a ``with`` statement; when the ``with`` block exits, the
    temporary file is removed.

    Args:
        filepath (str): URL of the file to download.

    Yields:
        str: Path of the temporary local copy.

    Examples:
        >>> storage = HTTPStorage()
        >>> # After exiting the ``with`` clause,
        >>> # the path will be removed
        >>> with storage.as_local_path('http://path/to/file') as path:
        ...     # do something here
    """
    # Create the file before entering ``try`` so ``finally`` never refers to
    # an unbound name if the creation itself fails.
    f = tempfile.NamedTemporaryFile(delete=False)
    try:
        try:
            f.write(self.read(filepath))
        finally:
            # Close even when the download fails: an open handle can make
            # os.remove fail (e.g. on Windows).
            f.close()
        yield f.name
    finally:
        os.remove(f.name)
.write(obj, url) L209Write.
obj — TODO.url — TODO. def write(self, obj: bytes, url: Union[str, Path]) -> None:
"""Not supported: HTTP storage is read-only.
Args:
obj: Bytes that would be written (ignored).
url: Destination URL (ignored).
Raises:
NotImplementedError: Always.
"""
raise NotImplementedError("write is not supported by HTTP Storage")
.write_text(obj, url, encoding) L218Write text.
obj — TODO.url — TODO.encoding — TODO. def write_text(self, obj: str, url: Union[str, Path], encoding: str = "utf-8") -> None:
"""Not supported: HTTP storage is read-only.
Args:
obj: Text that would be written (ignored).
url: Destination URL (ignored).
encoding: Text encoding (ignored).
Raises:
NotImplementedError: Always.
"""
raise NotImplementedError("write_text is not supported by HTTP Storage")
HTTPStorage.read(url)Read.
def read(self, url, timeout=(30, 300)):
    # TODO @wenmeng.zwm add progress bar if file is too large
    """Fetch ``url`` with an HTTP GET and return the response body as bytes.

    Args:
        url (str): HTTP(S) URL to fetch.
        timeout (float or tuple, optional): Timeout forwarded to
            ``requests.get``: a ``(connect, read)`` pair in seconds, or
            ``None`` to wait indefinitely.

    Returns:
        bytes: The response body.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    # requests applies no timeout by default, so a stalled server would
    # otherwise block this call forever.
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return response.content
HTTPStorage.read_text(url)Read text.
def read_text(self, url, timeout=(30, 300)):
    """Fetch ``url`` with an HTTP GET and return the response body as text.

    Args:
        url (str): HTTP(S) URL to fetch.
        timeout (float or tuple, optional): Timeout forwarded to
            ``requests.get``: a ``(connect, read)`` pair in seconds, or
            ``None`` to wait indefinitely.

    Returns:
        str: The response body as decoded by ``requests`` (``Response.text``).

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    # requests applies no timeout by default, so a stalled server would
    # otherwise block this call forever.
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return response.text
HTTPStorage.as_local_path(filepath)Download a file from ``filepath``.
``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
can be used in a ``with`` statement, and when exiting the
``with`` statement, the temporary path will be released.
filepath (str) — Download a file from ``filepath``.>>> storage = HTTPStorage()
>>> # After exiting the ``with`` clause,
>>> # the path will be removed
>>> with storage.as_local_path('http://path/to/file') as path:
... # do something here
def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
    """Download ``filepath`` into a temporary local file and yield its path.

    ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
    can be used in a ``with`` statement; when the ``with`` block exits, the
    temporary file is removed.

    Args:
        filepath (str): URL of the file to download.

    Yields:
        str: Path of the temporary local copy.

    Examples:
        >>> storage = HTTPStorage()
        >>> # After exiting the ``with`` clause,
        >>> # the path will be removed
        >>> with storage.as_local_path('http://path/to/file') as path:
        ...     # do something here
    """
    # Create the file before entering ``try`` so ``finally`` never refers to
    # an unbound name if the creation itself fails.
    f = tempfile.NamedTemporaryFile(delete=False)
    try:
        try:
            f.write(self.read(filepath))
        finally:
            # Close even when the download fails: an open handle can make
            # os.remove fail (e.g. on Windows).
            f.close()
        yield f.name
    finally:
        os.remove(f.name)
HTTPStorage.write(obj, url)Write.
obj — TODO.url — TODO. def write(self, obj: bytes, url: Union[str, Path]) -> None:
"""Not supported: HTTP storage is read-only.
Args:
obj: Bytes that would be written (ignored).
url: Destination URL (ignored).
Raises:
NotImplementedError: Always.
"""
raise NotImplementedError("write is not supported by HTTP Storage")
HTTPStorage.write_text(obj, url, encoding)Write text.
obj — TODO.url — TODO.encoding — TODO. def write_text(self, obj: str, url: Union[str, Path], encoding: str = "utf-8") -> None:
"""Not supported: HTTP storage is read-only.
Args:
obj: Text that would be written (ignored).
url: Destination URL (ignored).
encoding: Text encoding (ignored).
Raises:
NotImplementedError: Always.
"""
raise NotImplementedError("write_text is not supported by HTTP Storage")
OSS storage.
class OSSStorage(Storage):
"""OSS storage."""
def __init__(self, oss_config_file=None):
# read from config file or env var
"""Initialize OSSStorage.
Args:
oss_config_file: TODO.
"""
raise NotImplementedError("OSSStorage.__init__ to be implemented in the future")
def read(self, filepath):
"""Read.
Args:
filepath: TODO.
"""
raise NotImplementedError("OSSStorage.read to be implemented in the future")
def read_text(self, filepath, encoding="utf-8"):
"""Read text.
Args:
filepath: TODO.
encoding: TODO.
"""
raise NotImplementedError("OSSStorage.read_text to be implemented in the future")
@contextlib.contextmanager
View full source on GitHub →.read(filepath) L241Read.
filepath — TODO. def read(self, filepath):
"""Read ``filepath`` from OSS storage (not implemented yet).
Args:
filepath: OSS path of the file to read.
Raises:
NotImplementedError: Always.
"""
raise NotImplementedError("OSSStorage.read to be implemented in the future")
.read_text(filepath, encoding) L249Read text.
filepath — TODO.encoding — TODO. def read_text(self, filepath, encoding="utf-8"):
"""Read ``filepath`` from OSS storage as text (not implemented yet).
Args:
filepath: OSS path of the file to read.
encoding: Text encoding. Default: 'utf-8'.
Raises:
NotImplementedError: Always.
"""
raise NotImplementedError("OSSStorage.read_text to be implemented in the future")
.as_local_path(filepath) L259Download a file from ``filepath``.
``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
can be used in a ``with`` statement, and when exiting the
``with`` statement, the temporary path will be released.
filepath (str) — Download a file from ``filepath``.>>> storage = OSSStorage()
>>> # After exiting the ``with`` clause,
>>> # the path will be removed
>>> with storage.as_local_path('http://path/to/file') as path:
... # do something here
def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
    """Download ``filepath`` into a temporary local file and yield its path.

    ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
    can be used in a ``with`` statement; when the ``with`` block exits, the
    temporary file is removed.

    Args:
        filepath (str): Path of the file to download.

    Yields:
        str: Path of the temporary local copy.

    Examples:
        >>> storage = OSSStorage()
        >>> # After exiting the ``with`` clause,
        >>> # the path will be removed
        >>> with storage.as_local_path('http://path/to/file') as path:
        ...     # do something here
    """
    # Create the file before entering ``try`` so ``finally`` never refers to
    # an unbound name if the creation itself fails.
    f = tempfile.NamedTemporaryFile(delete=False)
    try:
        try:
            f.write(self.read(filepath))
        finally:
            # Close even when the read fails (OSSStorage.read currently always
            # raises): an open handle can make os.remove fail on Windows.
            f.close()
        yield f.name
    finally:
        os.remove(f.name)
.write(obj, filepath) L284Write.
obj — TODO.filepath — TODO. def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
"""Write bytes to OSS storage (not implemented yet).
Args:
obj: Bytes to write.
filepath: Destination OSS path.
Raises:
NotImplementedError: Always.
"""
raise NotImplementedError("OSSStorage.write to be implemented in the future")
.write_text(obj, filepath, encoding) L293Write text.
obj — TODO.filepath — TODO.encoding — TODO. def write_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
"""Write text to OSS storage (not implemented yet).
Args:
obj: Text to write.
filepath: Destination OSS path.
encoding: Text encoding. Default: 'utf-8'.
Raises:
NotImplementedError: Always.
"""
raise NotImplementedError("OSSStorage.write_text to be implemented in the future")
OSSStorage.read(filepath)Read.
filepath — TODO. def read(self, filepath):
"""Read ``filepath`` from OSS storage (not implemented yet).
Args:
filepath: OSS path of the file to read.
Raises:
NotImplementedError: Always.
"""
raise NotImplementedError("OSSStorage.read to be implemented in the future")
OSSStorage.read_text(filepath, encoding)Read text.
filepath — TODO.encoding — TODO. def read_text(self, filepath, encoding="utf-8"):
"""Read ``filepath`` from OSS storage as text (not implemented yet).
Args:
filepath: OSS path of the file to read.
encoding: Text encoding. Default: 'utf-8'.
Raises:
NotImplementedError: Always.
"""
raise NotImplementedError("OSSStorage.read_text to be implemented in the future")
OSSStorage.as_local_path(filepath)Download a file from ``filepath``.
``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
can be used in a ``with`` statement, and when exiting the
``with`` statement, the temporary path will be released.
filepath (str) — Download a file from ``filepath``.>>> storage = OSSStorage()
>>> # After exiting the ``with`` clause,
>>> # the path will be removed
>>> with storage.as_local_path('http://path/to/file') as path:
... # do something here
def as_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
    """Download ``filepath`` into a temporary local file and yield its path.

    ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
    can be used in a ``with`` statement; when the ``with`` block exits, the
    temporary file is removed.

    Args:
        filepath (str): Path of the file to download.

    Yields:
        str: Path of the temporary local copy.

    Examples:
        >>> storage = OSSStorage()
        >>> # After exiting the ``with`` clause,
        >>> # the path will be removed
        >>> with storage.as_local_path('http://path/to/file') as path:
        ...     # do something here
    """
    # Create the file before entering ``try`` so ``finally`` never refers to
    # an unbound name if the creation itself fails.
    f = tempfile.NamedTemporaryFile(delete=False)
    try:
        try:
            f.write(self.read(filepath))
        finally:
            # Close even when the read fails (OSSStorage.read currently always
            # raises): an open handle can make os.remove fail on Windows.
            f.close()
        yield f.name
    finally:
        os.remove(f.name)
OSSStorage.write(obj, filepath)Write.
obj — TODO.filepath — TODO. def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
"""Write bytes to OSS storage (not implemented yet).
Args:
obj: Bytes to write.
filepath: Destination OSS path.
Raises:
NotImplementedError: Always.
"""
raise NotImplementedError("OSSStorage.write to be implemented in the future")
OSSStorage.write_text(obj, filepath, encoding)Write text.
obj — TODO.filepath — TODO.encoding — TODO. def write_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
"""Write text to OSS storage (not implemented yet).
Args:
obj: Text to write.
filepath: Destination OSS path.
encoding: Text encoding. Default: 'utf-8'.
Raises:
NotImplementedError: Always.
"""
raise NotImplementedError("OSSStorage.write_text to be implemented in the future")
No documentation yet.
class File(object):
_prefix_to_storage: dict = {
"oss": OSSStorage,
"http": HTTPStorage,
"https": HTTPStorage,
"local": LocalStorage,
}
@staticmethod
def _get_storage(uri):
"""Internal: get storage.
Args:
uri: TODO.
"""
assert isinstance(uri, str), f"uri should be str type, but got {type(uri)}"
if "://" not in uri:
# local path
storage_type = "local"
else:
prefix, _ = uri.split("://")
storage_type = prefix
assert storage_type in File._prefix_to_storage, (
f"Unsupported uri {uri}, valid prefixs: " f"{list(File._prefix_to_storage.keys())}"
)
if storage_type not in G_STORAGES:
G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()
View full source on GitHub →.read(uri) L341Read data from a given ``filepath`` with 'rb' mode.
filepath (str or Path) — Path to read data.bytes — Expected bytes object. def read(uri: str) -> bytes:
"""Read the content of ``uri`` in binary mode via the matching storage backend.
The backend is chosen by ``File._get_storage`` from the ``scheme://``
prefix of ``uri`` (``oss``, ``http``, ``https``); a string without
``://`` is treated as a local path.
Args:
uri (str): Local path or URI to read.
Returns:
bytes: The content returned by the backend's ``read``.
"""
storage = File._get_storage(uri)
return storage.read(uri)
.read_text(uri, encoding) L354Read data from a given ``filepath`` with 'r' mode.
def read_text(uri: Union[str, Path], encoding: str = "utf-8") -> str:
    """Read the content of ``uri`` as text via the matching storage backend.

    The backend is chosen by ``File._get_storage`` from the ``scheme://``
    prefix of ``uri``; a string without ``://`` is treated as a local path.

    Args:
        uri (str or Path): Local path or URI to read (``File._get_storage``
            currently accepts only ``str``).
        encoding (str): Encoding used to decode the content. Default: 'utf-8'.

    Returns:
        str: The decoded text returned by the backend's ``read_text``.
    """
    storage = File._get_storage(uri)
    if encoding == "utf-8":
        # Keep the plain call for the default: HTTPStorage.read_text takes no
        # ``encoding`` argument, and LocalStorage already defaults to utf-8.
        return storage.read_text(uri)
    # A non-default encoding used to be silently dropped; forward it.
    return storage.read_text(uri, encoding=encoding)
.write(obj, uri) L369Write data to a given ``filepath`` with 'wb' mode.
``write`` will create a directory if the directory of ``filepath``
does not exist.
obj (bytes) — Data to be written.filepath (str or Path) — Path to write data. def write(obj: bytes, uri: Union[str, Path]) -> None:
"""Write ``obj`` to ``uri`` via the storage backend selected for ``uri``.
Note:
Behavior depends on the backend: LocalStorage creates missing parent
directories, while HTTPStorage and OSSStorage raise NotImplementedError.
Args:
obj (bytes): Data to be written.
uri (str or Path): Local path or URI to write to. NOTE(review): the
annotation allows Path, but ``File._get_storage`` asserts a ``str``.
"""
storage = File._get_storage(uri)
return storage.write(obj, uri)
.write_text(obj, uri, encoding) L384Write data to a given ``filepath`` with 'w' mode.
``write_text`` will create a directory if the directory of
``filepath`` does not exist.
def write_text(obj: str, uri: str, encoding: str = "utf-8") -> None:
    """Write the text ``obj`` to ``uri`` via the storage backend selected for ``uri``.

    Note:
        Behavior depends on the backend: LocalStorage creates missing parent
        directories, while HTTPStorage and OSSStorage raise NotImplementedError.

    Args:
        obj (str): Text to be written.
        uri (str): Local path or URI to write to.
        encoding (str): Encoding used to encode the text. Default: 'utf-8'.
    """
    storage = File._get_storage(uri)
    # Every backend's write_text accepts ``encoding``; it was previously dropped.
    return storage.write_text(obj, uri, encoding=encoding)
.as_local_path(uri) L401Only for unified API and do nothing.
def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
    """Yield a local filesystem path for ``uri``.

    Delegates to the ``as_local_path`` context manager of the storage backend
    selected for ``uri``. Any temporary copy that backend makes (HTTPStorage
    downloads to a temp file, for example) is released when this generator
    is closed.
    """
    backend = File._get_storage(uri)
    with backend.as_local_path(uri) as path_on_disk:
        yield path_on_disk
File.read(uri)Read data from a given ``filepath`` with 'rb' mode.
filepath (str or Path) — Path to read data.bytes — Expected bytes object. def read(uri: str) -> bytes:
"""Read the content of ``uri`` in binary mode via the matching storage backend.
The backend is chosen by ``File._get_storage`` from the ``scheme://``
prefix of ``uri`` (``oss``, ``http``, ``https``); a string without
``://`` is treated as a local path.
Args:
uri (str): Local path or URI to read.
Returns:
bytes: The content returned by the backend's ``read``.
"""
storage = File._get_storage(uri)
return storage.read(uri)
File.read_text(uri, encoding)Read data from a given ``filepath`` with 'r' mode.
def read_text(uri: Union[str, Path], encoding: str = "utf-8") -> str:
    """Read the content of ``uri`` as text via the matching storage backend.

    The backend is chosen by ``File._get_storage`` from the ``scheme://``
    prefix of ``uri``; a string without ``://`` is treated as a local path.

    Args:
        uri (str or Path): Local path or URI to read (``File._get_storage``
            currently accepts only ``str``).
        encoding (str): Encoding used to decode the content. Default: 'utf-8'.

    Returns:
        str: The decoded text returned by the backend's ``read_text``.
    """
    storage = File._get_storage(uri)
    if encoding == "utf-8":
        # Keep the plain call for the default: HTTPStorage.read_text takes no
        # ``encoding`` argument, and LocalStorage already defaults to utf-8.
        return storage.read_text(uri)
    # A non-default encoding used to be silently dropped; forward it.
    return storage.read_text(uri, encoding=encoding)
File.write(obj, uri)Write data to a given ``filepath`` with 'wb' mode.
``write`` will create a directory if the directory of ``filepath``
does not exist.
obj (bytes) — Data to be written.filepath (str or Path) — Path to write data. def write(obj: bytes, uri: Union[str, Path]) -> None:
"""Write ``obj`` to ``uri`` via the storage backend selected for ``uri``.
Note:
Behavior depends on the backend: LocalStorage creates missing parent
directories, while HTTPStorage and OSSStorage raise NotImplementedError.
Args:
obj (bytes): Data to be written.
uri (str or Path): Local path or URI to write to. NOTE(review): the
annotation allows Path, but ``File._get_storage`` asserts a ``str``.
"""
storage = File._get_storage(uri)
return storage.write(obj, uri)
File.write_text(obj, uri, encoding)Write data to a given ``filepath`` with 'w' mode.
``write_text`` will create a directory if the directory of
``filepath`` does not exist.
def write_text(obj: str, uri: str, encoding: str = "utf-8") -> None:
    """Write the text ``obj`` to ``uri`` via the storage backend selected for ``uri``.

    Note:
        Behavior depends on the backend: LocalStorage creates missing parent
        directories, while HTTPStorage and OSSStorage raise NotImplementedError.

    Args:
        obj (str): Text to be written.
        uri (str): Local path or URI to write to.
        encoding (str): Encoding used to encode the text. Default: 'utf-8'.
    """
    storage = File._get_storage(uri)
    # Every backend's write_text accepts ``encoding``; it was previously dropped.
    return storage.write_text(obj, uri, encoding=encoding)
File.as_local_path(uri)Only for unified API and do nothing.
def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
    """Yield a local filesystem path for ``uri``.

    Delegates to the ``as_local_path`` context manager of the storage backend
    selected for ``uri``. Any temporary copy that backend makes (HTTPStorage
    downloads to a temp file, for example) is released when this generator
    is closed.
    """
    backend = File._get_storage(uri)
    with backend.as_local_path(uri) as path_on_disk:
        yield path_on_disk
main()Main.
def main():
"""Main."""
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, required=True)
parser.add_argument("--export-dir", type=str, required=True)
parser.add_argument("--export", type=str2bool, default=True, help="whether to export model")
parser.add_argument("--type", type=str, default="onnx", help='["onnx", "torchscript", "bladedisc"]')
parser.add_argument("--device", type=str, default="cpu", help='["cpu", "cuda"]')
parser.add_argument("--quantize", type=str2bool, default=False, help="export quantized model")
parser.add_argument("--fallback-num", type=int, default=0, help="amp fallback number")
parser.add_argument("--audio_in", type=str, default=None, help='["wav", "wav.scp"]')
parser.add_argument("--model_revision", type=str, default=None, help="model_revision")
parser.add_argument("--calib_num", type=int, default=200, help="calib max num")
args = parser.parse_args()
model_dir = args.model_name
output_dir = args.model_name
if not Path(args.model_name).exists():
from modelscope.hub.snapshot_download import snapshot_download
try:
model_dir = snapshot_download(
args.model_name, cache_dir=args.export_dir, revision=args.model_revision
)
output_dir = os.path.join(args.export_dir, args.model_name)
except:
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
model_dir
)
if args.export:
View full source on GitHub →add_gradient_noise(model, iteration, duration, eta, scale_factor)Adds noise from a standard normal distribution to the gradients.
The standard deviation (`sigma`) is controlled
by the three hyper-parameters below.
`sigma` goes to zero (no noise) with more iterations.
def add_gradient_noise(
    model: torch.nn.Module,
    iteration: int,
    duration: float = 100,
    eta: float = 1.0,
    scale_factor: float = 0.55,
):
    """Add Gaussian noise to every parameter gradient of ``model``, in place.

    The noise standard deviation is ``sigma = eta / interval ** scale_factor``
    with ``interval = iteration // duration + 1``, so ``sigma`` shrinks toward
    zero (no noise) as the iteration count grows.

    Args:
        model: Module whose ``.grad`` tensors are perturbed. Parameters
            without a gradient are skipped.
        iteration: Current iteration number.
        duration: Number of iterations per step of the ``sigma`` schedule
            (e.g. 100 or 1000).
        eta: Base magnitude of ``sigma`` (e.g. 0.01, 0.3, 1.0).
        scale_factor: Decay exponent of ``sigma`` (e.g. 0.55).
    """
    interval = (iteration // duration) + 1
    sigma = eta / interval**scale_factor
    for param in model.parameters():
        if param.grad is not None:
            # Sample directly on the gradient's device and dtype instead of
            # allocating on the CPU and copying to the device on every call.
            param.grad += sigma * torch.randn_like(param.grad)
average_checkpoints(output_dir, last_n, **kwargs)Average the last 'last_n' checkpoints' model state_dicts.
If a tensor is of type torch.int, perform sum instead of average.
def average_checkpoints(output_dir: str, last_n: int = 5, **kwargs):
"""
Average the last 'last_n' checkpoints' model state_dicts.
If a tensor is of type torch.int, perform sum instead of average.
"""
checkpoint_paths = _get_checkpoint_paths(output_dir, last_n, **kwargs)
print(f"average_checkpoints: {checkpoint_paths}")
state_dicts = []
# Load state_dicts from checkpoints
for path in checkpoint_paths:
if os.path.isfile(path):
state_dicts.append(torch.load(path, map_location="cpu")["state_dict"])
else:
print(f"Checkpoint file {path} not found.")
# Check if we have any state_dicts to average
if len(state_dicts) < 1:
print("No checkpoints found for averaging.")
return
# Average or sum weights
avg_state_dict = OrderedDict()
for key in state_dicts[0].keys():
tensors = [state_dict[key].cpu() for state_dict in state_dicts]
# Check the type of the tensor
if str(tensors[0].dtype).startswith("torch.int"):
# Perform sum for integer tensors
summed_tensor = sum(tensors)
avg_state_dict[key] = summed_tensor
View full source on GitHub →to_device(data, device, dtype, non_blocking, copy)Change the device of object recursively
def to_device(data, device=None, dtype=None, non_blocking=False, copy=False):
    """Change the device (and optionally the dtype) of an object recursively.

    Containers are rebuilt with converted contents: dicts (returned as a plain
    ``dict``), dataclass instances, namedtuples, lists and tuples.
    ``numpy.ndarray`` values are converted with ``torch.from_numpy`` and then
    moved; ``torch.Tensor`` values are moved with ``Tensor.to``. Any other
    object is returned unchanged.

    Args:
        data: Object to convert.
        device: Target device passed to ``Tensor.to``.
        dtype: Target dtype passed to ``Tensor.to``.
        non_blocking: Forwarded to ``Tensor.to``.
        copy: Forwarded to ``Tensor.to``.

    Returns:
        An object with the same container structure as ``data``.
    """
    if isinstance(data, dict):
        return {k: to_device(v, device, dtype, non_blocking, copy) for k, v in data.items()}
    elif dataclasses.is_dataclass(data) and not isinstance(data, type):
        # Convert each field shallowly. ``dataclasses.astuple`` would turn
        # nested dataclasses into plain tuples and deep-copy every other field
        # value (tensors included) before the move.
        return type(data)(
            *[
                to_device(getattr(data, field.name), device, dtype, non_blocking, copy)
                for field in dataclasses.fields(data)
            ]
        )
    # maybe namedtuple. I don't know the correct way to judge namedtuple.
    elif isinstance(data, tuple) and type(data) is not tuple:
        return type(data)(*[to_device(o, device, dtype, non_blocking, copy) for o in data])
    elif isinstance(data, (list, tuple)):
        return type(data)(to_device(v, device, dtype, non_blocking, copy) for v in data)
    elif isinstance(data, np.ndarray):
        return to_device(torch.from_numpy(data), device, dtype, non_blocking, copy)
    elif isinstance(data, torch.Tensor):
        return data.to(device, dtype, non_blocking, copy)
    else:
        return data
force_gatherable(data, device)Change object to gatherable in torch.nn.DataParallel recursively
The difference from to_device() is changing to torch.Tensor if float or int
value is found.
The restriction to the returned value in DataParallel:
The object must be a torch.cuda.Tensor with 1 or more dimensions
(a 0-dimension tensor sends a warning),
or a list, tuple, dict.
def force_gatherable(data, device):
"""Change object to gatherable in torch.nn.DataParallel recursively
The difference from to_device() is changing to torch.Tensor if float or int
value is found.
The restriction to the returned value in DataParallel:
The object must be
- torch.cuda.Tensor
- 1 or more dimension. 0-dimension-tensor sends warning.
or a list, tuple, dict.
"""
if isinstance(data, dict):
return {k: force_gatherable(v, device) for k, v in data.items()}
# DataParallel can't handle NamedTuple well
elif isinstance(data, tuple) and type(data) is not tuple:
return type(data)(*[force_gatherable(o, device) for o in data])
elif isinstance(data, (list, tuple, set)):
return type(data)(force_gatherable(v, device) for v in data)
elif isinstance(data, np.ndarray):
return force_gatherable(torch.from_numpy(data), device)
elif isinstance(data, torch.Tensor):
if data.dim() == 0:
# To 1-dim array
data = data[None]
return data.to(device)
elif isinstance(data, float):
return torch.tensor([data], dtype=torch.float, device=device)
elif isinstance(data, int):
View full source on GitHub →Wrapped module to parallelize specified method
torch.nn.DataParallel parallelizes only "forward()"
and, maybe, the method having the other name can't be applied
except for wrapping the module just like this class.
>>> class A(torch.nn.Module):
... def foo(self, x):
... ...
>>> model = A()
>>> model = ForwardAdaptor(model, "foo")
>>> model = torch.nn.DataParallel(model, device_ids=[0, 1])
>>> x = torch.randn(2, 10)
>>> model(x)
class ForwardAdaptor(torch.nn.Module):
"""Wrapped module to parallelize specified method
torch.nn.DataParallel parallelizes only "forward()"
and, maybe, the method having the other name can't be applied
except for wrapping the module just like this class.
Examples:
>>> class A(torch.nn.Module):
... def foo(self, x):
... ...
>>> model = A()
>>> model = ForwardAdaptor(model, "foo")
>>> model = torch.nn.DataParallel(model, device_ids=[0, 1])
>>> x = torch.randn(2, 10)
>>> model(x)
"""
def __init__(self, module: torch.nn.Module, name: str):
"""Initialize ForwardAdaptor.
Args:
module: TODO.
name: TODO.
"""
super().__init__()
self.module = module
self.name = name
if not hasattr(module, name):
raise ValueError(f"{module} doesn't have {name}")
View full source on GitHub →.forward(*args, **kwargs) L35Forward pass for training.
# *args — Variable positional arguments.**kwargs — Additional keyword arguments.
def forward(self, *args, **kwargs):
    """Dispatch the call to the adapted method of the wrapped module.

    The attribute named ``self.name`` is looked up on ``self.module`` at call
    time and invoked. Wrappers that only call ``forward()`` (such as
    ``torch.nn.DataParallel``) therefore end up running that method.

    Args:
        *args: Positional arguments forwarded unchanged to the method.
        **kwargs: Keyword arguments forwarded unchanged to the method.

    Returns:
        Whatever the adapted method returns.
    """
    return getattr(self.module, self.name)(*args, **kwargs)
ForwardAdaptor.forward(*args, **kwargs)Forward pass for training.
# *args — Variable positional arguments.**kwargs — Additional keyword arguments.
def forward(self, *args, **kwargs):
    """Dispatch the call to the adapted method of the wrapped module.

    The attribute named ``self.name`` is looked up on ``self.module`` at call
    time and invoked. Wrappers that only call ``forward()`` (such as
    ``torch.nn.DataParallel``) therefore end up running that method.

    Args:
        *args: Positional arguments forwarded unchanged to the method.
        **kwargs: Keyword arguments forwarded unchanged to the method.

    Returns:
        Whatever the adapted method returns.
    """
    return getattr(self.module, self.name)(*args, **kwargs)
initialize(model, init)Initialize weights of a neural network module.
Parameters are initialized using the given method or distribution.
Custom initialization routines can be implemented into submodules
as function `espnet_initialization_fn` within the custom module.
model — Target.init — Method of initialization.def initialize(model: torch.nn.Module, init: str):
"""Initialize weights of a neural network module.
Parameters are initialized using the given method or distribution.
Custom initialization routines can be implemented into submodules
as function `espnet_initialization_fn` within the custom module.
Args:
model: Target.
init: Method of initialization.
"""
# weight init
for p in model.parameters():
if p.dim() > 1:
if init == "xavier_uniform":
torch.nn.init.xavier_uniform_(p.data)
elif init == "xavier_normal":
torch.nn.init.xavier_normal_(p.data)
elif init == "kaiming_uniform":
torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
elif init == "kaiming_normal":
torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
else:
raise ValueError("Unknown initialization: " + init)
# bias init
for p in model.parameters():
if p.dim() == 1:
p.data.zero_()
View full source on GitHub →load_pretrained_model(path, model, ignore_init_mismatch, map_location, oss_bucket, scope_map, excludes, **kwargs)Load a model state and set it to the model.
init_param — <file_path>:<src_key>:<dst_key>:<exclude_Keys>def load_pretrained_model(
path: str,
model: torch.nn.Module,
ignore_init_mismatch: bool = True,
map_location: str = "cpu",
oss_bucket=None,
scope_map=[],
excludes=None,
**kwargs,
):
"""Load a model state and set it to the model.
Args:
init_param: <file_path>:<src_key>:<dst_key>:<exclude_Keys>
Examples:
"""
obj = model
dst_state = obj.state_dict()
logging.info(f"ckpt: {path}")
if oss_bucket is None:
ori_state = torch.load(path, map_location=map_location)
else:
buffer = BytesIO(oss_bucket.get_object(path).read())
ori_state = torch.load(buffer, map_location=map_location)
View full source on GitHub →get_human_readable_count(number)Return human_readable_count
Originated from:
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/memory.py
Abbreviates an integer number with K, M, B, T for thousands, millions,
billions and trillions, respectively.
>>> get_human_readable_count(123)
'123 '
>>> get_human_readable_count(1234) # (one thousand)
'1 K'
>>> get_human_readable_count(2e6) # (two million)
'2 M'
>>> get_human_readable_count(3e9) # (three billion)
'3 B'
>>> get_human_readable_count(4e12) # (four trillion)
'4 T'
>>> get_human_readable_count(5e15) # (more than trillion)
'5,000 T'
number — a positive integer numberReturn:
A string formatted according to the pattern described above.
def get_human_readable_count(number: int) -> str:
"""Return human_readable_count
Originated from:
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/memory.py
Abbreviates an integer number with K, M, B, T for thousands, millions,
billions and trillions, respectively.
Examples:
>>> get_human_readable_count(123)
'123 '
>>> get_human_readable_count(1234) # (one thousand)
'1 K'
>>> get_human_readable_count(2e6) # (two million)
'2 M'
>>> get_human_readable_count(3e9) # (three billion)
'3 B'
>>> get_human_readable_count(4e12) # (four trillion)
'4 T'
>>> get_human_readable_count(5e15) # (more than trillion)
'5,000 T'
Args:
number: a positive integer number
Return:
A string formatted according to the pattern described above.
"""
assert number >= 0
labels = [" ", "K", "M", "B", "T"]
num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
num_groups = int(np.ceil(num_digits / 3))
View full source on GitHub →to_bytes(dtype)To bytes.
# dtype — TODO.
def to_bytes(dtype) -> int:
    """Return the size in bytes of one element of ``dtype``.

    For a ``torch.dtype`` the size is taken from torch itself. This is correct
    for every dtype, e.g. ``torch.float16`` -> 2, ``torch.int8`` -> 1,
    ``torch.bool`` -> 1 and ``torch.complex128`` -> 16. Any other object falls
    back to the legacy behaviour: the bit width is parsed from the last two
    characters of ``str(dtype)`` (e.g. ``"float32"`` -> 4).

    Args:
        dtype: A ``torch.dtype``, or an object whose string form ends in a
            two-digit bit width.

    Returns:
        Number of bytes per element.

    Raises:
        ValueError: If ``dtype`` is not a ``torch.dtype`` and its string form
            does not end in two digits.
    """
    if isinstance(dtype, torch.dtype):
        # Parsing str(dtype) breaks for 8-bit and bool dtypes ("torch.int8"
        # ends in "t8") and for complex128 ("28" // 8 == 3), so ask torch.
        return torch.empty((), dtype=dtype).element_size()
    return int(str(dtype)[-2:]) // 8
model_summary(model)Model summary.
model — Model instance or model name.def model_summary(model: torch.nn.Module) -> str:
"""Model summary.
Args:
model: Model instance or model name.
"""
message = "Model structure:\n"
message += str(model)
tot_params, num_params = 0, 0
for name, param in model.named_parameters():
print(
"name: {}, dtype: {}, device: {}, trainable: {}, shape: {}, numel: {}".format(
name, param.dtype, param.device, param.requires_grad, param.shape, param.numel()
)
)
tot_params += param.numel()
if param.requires_grad:
num_params += param.numel()
percent_trainable = "{:.1f}".format(num_params * 100.0 / tot_params)
tot_params = get_human_readable_count(tot_params)
num_params = get_human_readable_count(num_params)
message += "\n\nModel summary:\n"
message += f" Class Name: {model.__class__.__name__}\n"
message += f" Total Number of model parameters: {tot_params}\n"
message += f" Number of trainable parameters: {num_params} ({percent_trainable}%)\n"
dtype = next(iter(model.parameters())).dtype
message += f" Type: {dtype}"
View full source on GitHub →recursive_sum(obj, weight, distributed)Recursive sum.
# obj — TODO.weight — TODO.distributed — TODO.
def recursive_sum(obj, weight: torch.Tensor, distributed: bool = False):
    """Compute a weighted sum of every tensor inside a (nested) structure.

    Each tensor leaf must have the same size as ``weight``. The leaf is
    multiplied element-wise by ``weight`` (cast to the leaf's dtype) and
    summed to a scalar tensor. Tuples and lists are rebuilt with their own
    type, dicts are rebuilt as ``dict``, and ``None`` leaves stay ``None``.

    Args:
        obj: Tensor, ``None``, or a tuple/list/dict nesting of those.
        weight: 1-D weight tensor.
        distributed: If true, each reduced leaf is additionally
            ``all_reduce``-summed across processes.

    Returns:
        The same structure with each tensor replaced by its weighted sum.

    Raises:
        ValueError: For leaves of any other type.
    """
    assert weight.dim() == 1, weight.size()

    def _recurse(item):
        return recursive_sum(item, weight, distributed)

    if obj is None:
        return None
    if isinstance(obj, dict):
        return {key: _recurse(val) for key, val in obj.items()}
    if isinstance(obj, (tuple, list)):
        return type(obj)(_recurse(val) for val in obj)
    if not isinstance(obj, torch.Tensor):
        raise ValueError(type(obj))

    assert obj.size() == weight.size(), (obj.size(), weight.size())
    total = (obj * weight.type(obj.dtype)).sum()
    if distributed:
        torch.distributed.all_reduce(total, op=ReduceOp.SUM)
    return total
recursive_divide(a, b)Recursive divide.
# a — TODO.b — TODO.
def recursive_divide(a, b: torch.Tensor):
    """Divide every tensor inside a (nested) structure by ``b``.

    Each tensor leaf must have the same size as ``b``. ``b`` is cast to the
    leaf's dtype before dividing. Tuples and lists keep their type, dicts are
    rebuilt as ``dict``, and ``None`` leaves stay ``None``.

    Args:
        a: Tensor, ``None``, or a tuple/list/dict nesting of those.
        b: Divisor tensor.

    Returns:
        The same structure with each tensor divided by ``b``.

    Raises:
        ValueError: For leaves of any other type.
    """
    if a is None:
        return None
    if isinstance(a, dict):
        return {key: recursive_divide(val, b) for key, val in a.items()}
    if isinstance(a, (tuple, list)):
        return type(a)(recursive_divide(val, b) for val in a)
    if not isinstance(a, torch.Tensor):
        raise ValueError(type(a))
    assert a.size() == b.size(), (a.size(), b.size())
    return a / b.type(a.dtype)
recursive_average(obj, weight, distributed)Recursive average.
# obj — TODO.weight — TODO.distributed — TODO.
def recursive_average(obj, weight: torch.Tensor, distributed: bool = False):
    """Compute a weighted average of every tensor inside a (nested) structure.

    The block first computes the weighted sum via ``recursive_sum``. It then
    divides each result by the total weight via ``recursive_divide``.

    Args:
        obj: Tensor, ``None``, or a tuple/list/dict nesting of those.
        weight: 1-D weight tensor.
        distributed: If true, both the weighted sums and the total weight are
            ``all_reduce``-summed across processes before dividing.

    Returns:
        Tuple ``(averaged, total_weight)``. ``averaged`` mirrors ``obj``'s
        structure and ``total_weight`` is the (possibly all-reduced) sum of
        ``weight``.
    """
    weighted_sum = recursive_sum(obj, weight, distributed)
    total_weight = weight.sum()
    if distributed:
        torch.distributed.all_reduce(total_weight, op=ReduceOp.SUM)
    # Dividing by the summed weight normalizes the weights to sum to 1.
    return recursive_divide(weighted_sum, total_weight), total_weight
set_all_random_seed(seed)Set all random seed.
# seed — TODO.
def set_all_random_seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and torch's default RNG.

    Args:
        seed: Seed value applied to all three generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.random.manual_seed):
        seed_fn(seed)
maybe_autocast(enabled)Maybe autocast.
# enabled — TODO.
def maybe_autocast(enabled):
    """Generator body for an optional ``autocast()`` region.

    The generator yields exactly once. When ``enabled`` is truthy, that yield
    happens inside ``autocast()``. Otherwise no context is entered.

    NOTE(review): presumably wrapped with ``contextlib.contextmanager`` in the
    full source (the decorator is not shown in this listing); confirm.

    Args:
        enabled: Whether to enter ``autocast()`` around the yield.
    """
    if not enabled:
        yield
        return
    with autocast():
        yield
A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
and optionally resuming from a saved checkpoint.
Attributes:
max_epoch (int) — Maximum number of epochs for training.model (torch.nn.Module) — The model to be trained.optim (torch.optim.Optimizer) — The optimizer to use for training.scheduler (torch.optim.lr_scheduler._LRScheduler) — The learning rate scheduler.dataloader_train (torch.utils.data.DataLoader) — DataLoader for the training dataset.dataloader_val (torch.utils.data.DataLoader) — DataLoader for the validation dataset.output_dir (str) — Directory where model checkpoints will be saved.resume (str, optional) — Path to a checkpoint to resume training from.class Trainer:
"""
A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
and optionally resuming from a saved checkpoint.
Attributes:
max_epoch (int): Maximum number of epochs for training.
model (torch.nn.Module): The model to be trained.
optim (torch.optim.Optimizer): The optimizer to use for training.
scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
output_dir (str): Directory where model checkpoints will be saved.
resume (str, optional): Path to a checkpoint to resume training from.
"""
def __init__(
self,
local_rank,
use_ddp: bool = False,
use_fsdp: bool = False,
use_fp16: bool = False,
output_dir: str = "./",
**kwargs,
):
"""
Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
Args:
model (torch.nn.Module): The model to be trained.
View full source on GitHub →.save_checkpoint(epoch, step, model, optim, scheduler, scaler, step_in_epoch, **kwargs) L143Saves a checkpoint containing the model's state, the optimizer's state,
and the scheduler's state at the end of the given epoch. This method is
intended to be called at the end of each epoch to save the training progress.
epoch (int) — The epoch number at which the checkpoint is being saved. def save_checkpoint(
self,
epoch,
step=None,
model=None,
optim=None,
scheduler=None,
scaler=None,
step_in_epoch=None,
**kwargs,
):
"""
Saves a checkpoint containing the model's state, the optimizer's state,
and the scheduler's state at the end of the given epoch. This method is
intended to be called at the end of each epoch to save the training progress.
Args:
epoch (int): The epoch number at which the checkpoint is being saved.
"""
step_in_epoch = None if step is None else step_in_epoch
if self.rank == 0:
logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
# self.step_or_epoch += 1
state = {
"epoch": epoch,
"step": step,
"total_step": self.batch_total,
"state_dict": model.state_dict(),
"optimizer": optim.state_dict(),
View full source →.resume_checkpoint(model, optim, scheduler, scaler) L260Resumes training from a checkpoint at the given file path.
Loads the model's state, the optimizer's state, and the scheduler's state.
resume_path (str) — The file path to the checkpoint to resume from. def resume_checkpoint(
self,
model=None,
optim=None,
scheduler=None,
scaler=None,
):
"""
Resumes training from a checkpoint at the given file path.
Loads the model's state, the optimizer's state, and the scheduler's state.
Args:
resume_path (str): The file path to the checkpoint to resume from.
"""
if self.resume:
ckpt = os.path.join(self.output_dir, "model.pt")
if os.path.isfile(ckpt):
checkpoint = torch.load(ckpt, map_location="cpu")
self.start_epoch = checkpoint["epoch"]
# self.model.load_state_dict(checkpoint['state_dict'])
src_state = checkpoint["state_dict"]
dst_state = model.state_dict()
for k in dst_state.keys():
if not k.startswith("module.") and "module." + k in src_state.keys():
k_ddp = "module." + k
elif k.startswith("module.") and "module." + k not in src_state.keys():
k_ddp = k.replace("module.", "", 1)
else:
k_ddp = k
View full source →.train_epoch(model, optim, scheduler, scaler, dataloader_train, dataloader_val, epoch, writer, **kwargs) L340Defines the training process for a single epoch with gradient accumulation.
epoch (int) — The current epoch number. def train_epoch(
self,
model=None,
optim=None,
scheduler=None,
scaler=None,
dataloader_train=None,
dataloader_val=None,
epoch=None,
writer=None,
**kwargs,
):
"""
Defines the training process for a single epoch with gradient accumulation.
Args:
epoch (int): The current epoch number.
"""
if self.use_ddp or self.use_fsdp:
dist.barrier()
logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
model.train()
# Set the number of steps for gradient accumulation
accum_grad = self.accum_grad
# Initialize the gradient accumulation
optim.zero_grad()
speed_stats = {}
iterator_stop = torch.tensor(0).to(self.device)
View full source →.validate_epoch(model, dataloader_val, epoch, writer, **kwargs) L537Defines the validation process for a single epoch.
Should be implemented with the actual model validation steps.
epoch (int) — The current epoch number. def validate_epoch(
self,
model=None,
dataloader_val=None,
epoch=None,
writer=None,
**kwargs,
):
"""
Defines the validation process for a single epoch.
Should be implemented with the actual model validation steps.
Args:
epoch (int): The current epoch number.
"""
if self.use_ddp or self.use_fsdp:
dist.barrier()
logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
model.eval()
with torch.no_grad():
speed_stats = {}
time5 = time.perf_counter()
iterator_stop = torch.tensor(0).to(self.device)
dataloader_val.batch_sampler.set_epoch(epoch)
for batch_idx, batch in enumerate(dataloader_val):
if self.use_ddp or self.use_fsdp:
dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
if iterator_stop > 0:
View full source →.log(epoch, batch_idx, step_in_epoch, batch_num_epoch, lr, loss, speed_stats, stats, writer, tag, data_split_i, data_split_num, log_step, **kwargs) L651Log.
epoch — TODO.batch_idx — TODO.step_in_epoch — TODO.batch_num_epoch — TODO.lr — TODO.loss — TODO.speed_stats — TODO.stats — TODO.writer — TODO.tag — TODO.data_split_i — TODO.data_split_num — TODO.log_step — TODO.**kwargs — Additional keyword arguments. def log(
self,
epoch=0,
batch_idx=0,
step_in_epoch=0,
batch_num_epoch=-1,
lr=0.0,
loss=0.0,
speed_stats=None,
stats=None,
writer=None,
tag="train",
data_split_i=0,
data_split_num=1,
log_step=None,
**kwargs,
):
"""Log.
Args:
epoch: TODO.
batch_idx: TODO.
step_in_epoch: TODO.
batch_num_epoch: TODO.
lr: TODO.
loss: TODO.
speed_stats: TODO.
stats: TODO.
writer: TODO.
View full source →.close(writer) L744Close.
# writer — TODO.
def close(self, writer=None):
    """Shut down at the end of training.

    In DDP/FSDP mode the method waits on ``dist.barrier()`` first and destroys
    the process group last. A given ``writer`` is closed in between.

    Args:
        writer: Optional object with a ``close()`` method (e.g. a summary
            writer); skipped when ``None``.
    """
    distributed = self.use_ddp or self.use_fsdp
    if distributed:
        # Synchronize all ranks before tearing anything down.
        dist.barrier()
    if writer is not None:
        writer.close()
    if distributed:
        torch.distributed.destroy_process_group()
Trainer.save_checkpoint(epoch, step, model, optim, scheduler, scaler, step_in_epoch, **kwargs)Saves a checkpoint containing the model's state, the optimizer's state,
and the scheduler's state at the end of the given epoch. This method is
intended to be called at the end of each epoch to save the training progress.
epoch (int) — The epoch number at which the checkpoint is being saved. def save_checkpoint(
self,
epoch,
step=None,
model=None,
optim=None,
scheduler=None,
scaler=None,
step_in_epoch=None,
**kwargs,
):
"""
Saves a checkpoint containing the model's state, the optimizer's state,
and the scheduler's state at the end of the given epoch. This method is
intended to be called at the end of each epoch to save the training progress.
Args:
epoch (int): The epoch number at which the checkpoint is being saved.
"""
step_in_epoch = None if step is None else step_in_epoch
if self.rank == 0:
logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
# self.step_or_epoch += 1
state = {
"epoch": epoch,
"step": step,
"total_step": self.batch_total,
"state_dict": model.state_dict(),
"optimizer": optim.state_dict(),
View full source on GitHub →Trainer.resume_checkpoint(model, optim, scheduler, scaler)Resumes training from a checkpoint at the given file path.
Loads the model's state, the optimizer's state, and the scheduler's state.
resume_path (str) — The file path to the checkpoint to resume from. def resume_checkpoint(
self,
model=None,
optim=None,
scheduler=None,
scaler=None,
):
"""
Resumes training from a checkpoint at the given file path.
Loads the model's state, the optimizer's state, and the scheduler's state.
Args:
resume_path (str): The file path to the checkpoint to resume from.
"""
if self.resume:
ckpt = os.path.join(self.output_dir, "model.pt")
if os.path.isfile(ckpt):
checkpoint = torch.load(ckpt, map_location="cpu")
self.start_epoch = checkpoint["epoch"]
# self.model.load_state_dict(checkpoint['state_dict'])
src_state = checkpoint["state_dict"]
dst_state = model.state_dict()
for k in dst_state.keys():
if not k.startswith("module.") and "module." + k in src_state.keys():
k_ddp = "module." + k
elif k.startswith("module.") and "module." + k not in src_state.keys():
k_ddp = k.replace("module.", "", 1)
else:
k_ddp = k
View full source on GitHub →Trainer.train_epoch(model, optim, scheduler, scaler, dataloader_train, dataloader_val, epoch, writer, **kwargs)Defines the training process for a single epoch with gradient accumulation.
epoch (int) — The current epoch number. def train_epoch(
self,
model=None,
optim=None,
scheduler=None,
scaler=None,
dataloader_train=None,
dataloader_val=None,
epoch=None,
writer=None,
**kwargs,
):
"""
Defines the training process for a single epoch with gradient accumulation.
Args:
epoch (int): The current epoch number.
"""
if self.use_ddp or self.use_fsdp:
dist.barrier()
logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
model.train()
# Set the number of steps for gradient accumulation
accum_grad = self.accum_grad
# Initialize the gradient accumulation
optim.zero_grad()
speed_stats = {}
iterator_stop = torch.tensor(0).to(self.device)
View full source on GitHub →Trainer.validate_epoch(model, dataloader_val, epoch, writer, **kwargs)Defines the validation process for a single epoch.
Should be implemented with the actual model validation steps.
epoch (int) — The current epoch number. def validate_epoch(
self,
model=None,
dataloader_val=None,
epoch=None,
writer=None,
**kwargs,
):
"""
Defines the validation process for a single epoch.
Should be implemented with the actual model validation steps.
Args:
epoch (int): The current epoch number.
"""
if self.use_ddp or self.use_fsdp:
dist.barrier()
logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
model.eval()
with torch.no_grad():
speed_stats = {}
time5 = time.perf_counter()
iterator_stop = torch.tensor(0).to(self.device)
dataloader_val.batch_sampler.set_epoch(epoch)
for batch_idx, batch in enumerate(dataloader_val):
if self.use_ddp or self.use_fsdp:
dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
if iterator_stop > 0:
View full source on GitHub →Trainer.log(epoch, batch_idx, step_in_epoch, batch_num_epoch, lr, loss, speed_stats, stats, writer, tag, data_split_i, data_split_num, log_step, **kwargs)Log.
epoch — TODO.batch_idx — TODO.step_in_epoch — TODO.batch_num_epoch — TODO.lr — TODO.loss — TODO.speed_stats — TODO.stats — TODO.writer — TODO.tag — TODO.data_split_i — TODO.data_split_num — TODO.log_step — TODO.**kwargs — Additional keyword arguments. def log(
self,
epoch=0,
batch_idx=0,
step_in_epoch=0,
batch_num_epoch=-1,
lr=0.0,
loss=0.0,
speed_stats=None,
stats=None,
writer=None,
tag="train",
data_split_i=0,
data_split_num=1,
log_step=None,
**kwargs,
):
"""Log.
Args:
epoch: TODO.
batch_idx: TODO.
step_in_epoch: TODO.
batch_num_epoch: TODO.
lr: TODO.
loss: TODO.
speed_stats: TODO.
stats: TODO.
writer: TODO.
View full source on GitHub →Trainer.close(writer)Close.
# writer — TODO.
def close(self, writer=None):
    """Shut down at the end of training.

    In DDP/FSDP mode the method waits on ``dist.barrier()`` first and destroys
    the process group last. A given ``writer`` is closed in between.

    Args:
        writer: Optional object with a ``close()`` method (e.g. a summary
            writer); skipped when ``None``.
    """
    distributed = self.use_ddp or self.use_fsdp
    if distributed:
        # Synchronize all ranks before tearing anything down.
        dist.barrier()
    if writer is not None:
        writer.close()
    if distributed:
        torch.distributed.destroy_process_group()
maybe_autocast(dtype, use_deepspeed)Maybe autocast.
# dtype — TODO.use_deepspeed — TODO.
def maybe_autocast(dtype=None, use_deepspeed=False):
    """Generator body for an optional mixed-precision region.

    The generator yields exactly once. Which context wraps the yield depends
    on the arguments:

    - ``use_deepspeed``: ``torch.cuda.amp.autocast(enabled=True, dtype=dtype,
      cache_enabled=False)``.
    - otherwise, ``dtype`` is ``torch.float16`` or ``torch.bfloat16``:
      ``autocast(enabled=True, dtype=dtype)``.
    - otherwise: no context.

    NOTE(review): presumably wrapped with ``contextlib.contextmanager`` in the
    full source (the decorator is not shown in this listing); confirm.

    Args:
        dtype: Requested autocast dtype, or ``None``.
        use_deepspeed: Whether the DeepSpeed-specific autocast is used.
    """
    if use_deepspeed:
        ctx = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
    elif dtype in (torch.float16, torch.bfloat16):
        ctx = autocast(enabled=True, dtype=dtype)
    else:
        yield
        return
    with ctx:
        yield
A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
and optionally resuming from a saved checkpoint.
Attributes:
max_epoch (int) — Maximum number of epochs for training.model (torch.nn.Module) — The model to be trained.optim (torch.optim.Optimizer) — The optimizer to use for training.scheduler (torch.optim.lr_scheduler._LRScheduler) — The learning rate scheduler.dataloader_train (torch.utils.data.DataLoader) — DataLoader for the training dataset.dataloader_val (torch.utils.data.DataLoader) — DataLoader for the validation dataset.output_dir (str) — Directory where model checkpoints will be saved.resume (str, optional) — Path to a checkpoint to resume training from.class Trainer:
"""
A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
and optionally resuming from a saved checkpoint.
Attributes:
max_epoch (int): Maximum number of epochs for training.
model (torch.nn.Module): The model to be trained.
optim (torch.optim.Optimizer): The optimizer to use for training.
scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
output_dir (str): Directory where model checkpoints will be saved.
resume (str, optional): Path to a checkpoint to resume training from.
"""
def __init__(
self,
rank=0,
local_rank=0,
world_size=1,
use_ddp: bool = False,
use_fsdp: bool = False,
use_fp16: bool = False,
use_bf16: bool = False,
use_deepspeed: bool = False,
output_dir: str = "./",
**kwargs,
):
"""
View full source on GitHub →.save_checkpoint(epoch, step, model, optim, scheduler, scaler, step_in_epoch, **kwargs) L171Saves a checkpoint containing the model's state, the optimizer's state,
and the scheduler's state at the end of the given epoch. This method is
intended to be called at the end of each epoch to save the training progress.
epoch (int) — The epoch number at which the checkpoint is being saved. def save_checkpoint(
self,
epoch,
step=None,
model=None,
optim=None,
scheduler=None,
scaler=None,
step_in_epoch=None,
**kwargs,
):
"""
Saves a checkpoint containing the model's state, the optimizer's state,
and the scheduler's state at the end of the given epoch. This method is
intended to be called at the end of each epoch to save the training progress.
Args:
epoch (int): The epoch number at which the checkpoint is being saved.
"""
if self.use_ddp or self.use_fsdp:
dist.barrier()
step_in_epoch = None if step is None else step_in_epoch
if self.use_deepspeed:
logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
# self.step_or_epoch += 1
state = {
"epoch": epoch,
# "state_dict": model.state_dict(),
# "optimizer": optim.state_dict(),
View full source →.resume_checkpoint(model, optim, scheduler, scaler) L414Resumes training from a checkpoint at the given file path.
Loads the model's state, the optimizer's state, and the scheduler's state.
resume_path (str) — The file path to the checkpoint to resume from. def resume_checkpoint(
self,
model=None,
optim=None,
scheduler=None,
scaler=None,
):
"""
Resumes training from a checkpoint at the given file path.
Loads the model's state, the optimizer's state, and the scheduler's state.
Args:
resume_path (str): The file path to the checkpoint to resume from.
"""
if self.resume:
if self.use_deepspeed:
ckpt = os.path.join(self.output_dir, "model.pt")
if os.path.exists(ckpt):
_, checkpoint = model.load_checkpoint(self.output_dir, "model.pt")
self.start_epoch = checkpoint["epoch"]
self.saved_ckpts = checkpoint["saved_ckpts"]
self.val_acc_step_or_epoch = (
checkpoint["val_acc_step_or_epoch"]
if "val_acc_step_or_epoch" in checkpoint
else {}
)
self.val_loss_step_or_epoch = (
checkpoint["val_loss_step_or_epoch"]
if "val_loss_step_or_epoch" in checkpoint
View full source →.train_epoch(model, optim, scheduler, scaler, dataloader_train, dataloader_val, epoch, **kwargs) L554Defines the training process for a single epoch with gradient accumulation.
epoch (int) — The current epoch number. def train_epoch(
self,
model=None,
optim=None,
scheduler=None,
scaler=None,
dataloader_train=None,
dataloader_val=None,
epoch=None,
**kwargs,
):
"""
Defines the training process for a single epoch with gradient accumulation.
Args:
epoch (int): The current epoch number.
"""
if self.use_ddp or self.use_fsdp or self.use_deepspeed:
dist.barrier()
logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
model.train()
# Set the number of steps for gradient accumulation
accum_grad = self.accum_grad
# Initialize the gradient accumulation
optim.zero_grad()
speed_stats = {}
iterator_stop = torch.tensor(0).to(self.device)
dataloader_train.batch_sampler.set_epoch(epoch)
View full source →.forward_step(model, batch, loss_dict) L676Forward step.
# model — Model instance or model name.batch — TODO.loss_dict — TODO.
def forward_step(self, model, batch, loss_dict=None):
    """Run the model on one batch and store its outputs in ``loss_dict``.

    The model is called as ``model(**batch)`` inside
    ``maybe_autocast(dtype=self.dtype, use_deepspeed=self.use_deepspeed)``.
    It must return a ``(loss, stats, weight)`` triple.

    Args:
        model: Callable model.
        batch: Mapping of keyword arguments for the model.
        loss_dict: Dict that receives ``"loss"``, ``"stats"`` (with ``None``
            values dropped) and ``"weight"``. A fresh dict is created when
            omitted.

    Returns:
        The updated ``loss_dict``. Callers that pass their own dict can
        ignore the return value, as before.
    """
    # A mutable ``{}`` default would be one dict shared by every call that
    # omits the argument, so results would persist across calls.
    if loss_dict is None:
        loss_dict = {}
    with maybe_autocast(dtype=self.dtype, use_deepspeed=self.use_deepspeed):
        retval = model(**batch)
    loss, stats, weight = retval
    stats = {k: v for k, v in stats.items() if v is not None}
    loss_dict["loss"] = loss
    loss_dict["stats"] = stats
    loss_dict["weight"] = weight
    return loss_dict
.backward_step(model, scaler, loss_dict) L694Backward step.
# model — Model instance or model name.scaler — TODO.loss_dict — TODO.
def backward_step(self, model, scaler, loss_dict=None):
    """Back-propagate the loss stored under ``loss_dict["loss"]``.

    - DeepSpeed mode: the loss is passed unchanged to ``model.backward``.
    - Otherwise: the loss is divided by ``self.accum_grad`` and
      back-propagated. When ``scaler`` is truthy it goes through
      ``scaler.scale(loss).backward()``.

    Args:
        model: Model (a DeepSpeed engine when ``self.use_deepspeed``).
        scaler: Optional gradient scaler exposing ``scale()``; falsy to skip.
        loss_dict: Dict holding the loss under ``"loss"``.

    Raises:
        KeyError: If ``loss_dict`` is missing or has no ``"loss"`` entry.
    """
    # ``None`` instead of a shared mutable ``{}`` default; a missing dict still
    # raises KeyError on the lookup below, as before.
    if loss_dict is None:
        loss_dict = {}
    loss = loss_dict["loss"]
    if self.use_deepspeed:
        model.backward(loss)
        return
    # Scale down so gradients accumulated over accum_grad batches average out.
    loss = loss / self.accum_grad
    if scaler:
        scaler.scale(loss).backward()
    else:
        loss.backward()
.update_step(model, optim, scheduler, scaler, loss_dict) L713Update step.
model — Model instance or model name.optim — TODO.scheduler — TODO.scaler — TODO.loss_dict — TODO. def update_step(self, model, optim, scheduler, scaler, loss_dict=None):
"""Update step.
Args:
model: Model instance or model name.
optim: TODO.
scheduler: TODO.
scaler: TODO.
loss_dict: TODO.
"""
batch_idx = loss_dict["batch_idx"]
if self.use_deepspeed:
model.step()
else:
if (batch_idx + 1) % self.accum_grad == 0:
# Perform gradient clipping if it is set
if self.grad_clip > 0:
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(),
max_norm=self.grad_clip,
norm_type=self.grad_clip_type,
)
if not torch.isfinite(grad_norm):
logging.warning(
f"The grad norm is {grad_norm}. Skipping updating the model."
)
optim.zero_grad() # Reset gradients
return
# Execute an optimization step (update model parameters)
View full source →.validate_epoch(model, dataloader_val, epoch, writer, **kwargs) L754Defines the validation process for a single epoch.
Should be implemented with the actual model validation steps.
epoch (int) — The current epoch number. def validate_epoch(
self,
model=None,
dataloader_val=None,
epoch=None,
writer=None,
**kwargs,
):
"""
Defines the validation process for a single epoch.
Should be implemented with the actual model validation steps.
Args:
epoch (int): The current epoch number.
"""
self.val_loss_avg = 0.0
self.val_acc_avg = 0.0
if self.use_ddp or self.use_fsdp or self.use_deepspeed:
dist.barrier()
logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
model.eval()
with torch.no_grad():
speed_stats = {}
time_beg = time.perf_counter()
time5 = time_beg
dataloader_val.batch_sampler.set_epoch(epoch)
View full source →.log(loss_dict, tag, **kwargs) L848Log.
loss_dict — TODO.tag — TODO.**kwargs — Additional keyword arguments. def log(
self,
loss_dict: dict = None,
tag="train",
**kwargs,
):
"""Log.
Args:
loss_dict: TODO.
tag: TODO.
**kwargs: Additional keyword arguments.
"""
loss = loss_dict["loss"].detach().cpu().item()
epoch = loss_dict["epoch"]
batch_idx = loss_dict["batch_idx"]
step_in_epoch = loss_dict["step_in_epoch"]
batch_total = loss_dict["batch_total"]
batch_num_epoch = loss_dict["batch_num_epoch"]
lr = loss_dict["lr"]
speed_stats = loss_dict["speed_stats"]
stats = loss_dict["stats"]
data_split_i = loss_dict["data_split_i"]
data_split_num = loss_dict["data_split_num"]
log_step = loss_dict.get("log_step", None)
if (batch_idx + 1) % self.log_interval == 0:
batch_idx = log_step if log_step is not None else batch_idx
gpu_info = (
View full source →.close(writer) L929Close.
# writer — TODO.
def close(self, writer=None):
    """Shut down at the end of training.

    In DDP/FSDP mode the method waits on ``dist.barrier()`` first and destroys
    the process group last. A given ``writer`` is closed in between.

    Args:
        writer: Optional object with a ``close()`` method (e.g. a summary
            writer); skipped when ``None``.
    """
    distributed = self.use_ddp or self.use_fsdp
    if distributed:
        # Synchronize all ranks before tearing anything down.
        dist.barrier()
    if writer is not None:
        writer.close()
    if distributed:
        torch.distributed.destroy_process_group()
.warp_model(model, **kwargs) L945Warp model.
model — Model instance or model name.**kwargs — Additional keyword arguments. def warp_model(self, model, **kwargs):
"""Warp model.
Args:
model: Model instance or model name.
**kwargs: Additional keyword arguments.
"""
if self.use_deepspeed:
from deepspeed.runtime.zero.stage_1_and_2 import (
estimate_zero2_model_states_mem_needs_all_live,
)
from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
world_size = int(os.environ.get("WORLD_SIZE", 1))
# NOTE(xcsong): look in detail how the memory estimator API works:
# https://deepspeed.readthedocs.io/en/latest/memory.html#discussion
if int(os.environ.get("RANK", 0)) == 0:
logging.info("Estimating model states memory needs (zero2)...")
estimate_zero2_model_states_mem_needs_all_live(
model,
num_gpus_per_node=local_world_size,
num_nodes=world_size // local_world_size,
)
logging.info("Estimating model states memory needs (zero3)...")
estimate_zero3_model_states_mem_needs_all_live(
model,
View full source →.warp_optim_scheduler(model, **kwargs) L997Warp optim scheduler.
model — Model instance or model name.**kwargs — Additional keyword arguments. def warp_optim_scheduler(self, model, **kwargs):
"""Warp optim scheduler.
Args:
model: Model instance or model name.
**kwargs: Additional keyword arguments.
"""
from funasr.optimizers import optim_classes
from funasr.schedulers import scheduler_classes
from omegaconf import OmegaConf, DictConfig
import json
# optim
logging.info("Build optim")
optim = kwargs.get("optim", "adam")
assert optim in optim_classes
optim_class = optim_classes.get(optim)
optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
# scheduler
logging.info("Build scheduler")
scheduler = kwargs.get("scheduler", "warmuplr")
assert scheduler in scheduler_classes
scheduler_class = scheduler_classes.get(scheduler)
scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
if self.use_deepspeed:
import deepspeed
args = OmegaConf.create({"deepspeed_config": self.deepspeed_config})
View full source →Trainer.save_checkpoint(epoch, step, model, optim, scheduler, scaler, step_in_epoch, **kwargs)Saves a checkpoint containing the model's state, the optimizer's state,
and the scheduler's state at the end of the given epoch. This method is
intended to be called at the end of each epoch to save the training progress.
epoch (int) — The epoch number at which the checkpoint is being saved. def save_checkpoint(
self,
epoch,
step=None,
model=None,
optim=None,
scheduler=None,
scaler=None,
step_in_epoch=None,
**kwargs,
):
"""
Saves a checkpoint containing the model's state, the optimizer's state,
and the scheduler's state at the end of the given epoch. This method is
intended to be called at the end of each epoch to save the training progress.
Args:
epoch (int): The epoch number at which the checkpoint is being saved.
"""
if self.use_ddp or self.use_fsdp:
dist.barrier()
step_in_epoch = None if step is None else step_in_epoch
if self.use_deepspeed:
logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
# self.step_or_epoch += 1
state = {
"epoch": epoch,
# "state_dict": model.state_dict(),
# "optimizer": optim.state_dict(),
View full source on GitHub →Trainer.resume_checkpoint(model, optim, scheduler, scaler)Resumes training from a checkpoint at the given file path.
Loads the model's state, the optimizer's state, and the scheduler's state.
resume_path (str) — The file path to the checkpoint to resume from. def resume_checkpoint(
self,
model=None,
optim=None,
scheduler=None,
scaler=None,
):
"""
Resumes training from a checkpoint at the given file path.
Loads the model's state, the optimizer's state, and the scheduler's state.
Args:
resume_path (str): The file path to the checkpoint to resume from.
"""
if self.resume:
if self.use_deepspeed:
ckpt = os.path.join(self.output_dir, "model.pt")
if os.path.exists(ckpt):
_, checkpoint = model.load_checkpoint(self.output_dir, "model.pt")
self.start_epoch = checkpoint["epoch"]
self.saved_ckpts = checkpoint["saved_ckpts"]
self.val_acc_step_or_epoch = (
checkpoint["val_acc_step_or_epoch"]
if "val_acc_step_or_epoch" in checkpoint
else {}
)
self.val_loss_step_or_epoch = (
checkpoint["val_loss_step_or_epoch"]
if "val_loss_step_or_epoch" in checkpoint
View full source on GitHub →Trainer.train_epoch(model, optim, scheduler, scaler, dataloader_train, dataloader_val, epoch, **kwargs)Defines the training process for a single epoch with gradient accumulation.
epoch (int) — The current epoch number. def train_epoch(
self,
model=None,
optim=None,
scheduler=None,
scaler=None,
dataloader_train=None,
dataloader_val=None,
epoch=None,
**kwargs,
):
"""
Defines the training process for a single epoch with gradient accumulation.
Args:
epoch (int): The current epoch number.
"""
if self.use_ddp or self.use_fsdp or self.use_deepspeed:
dist.barrier()
logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
model.train()
# Set the number of steps for gradient accumulation
accum_grad = self.accum_grad
# Initialize the gradient accumulation
optim.zero_grad()
speed_stats = {}
iterator_stop = torch.tensor(0).to(self.device)
dataloader_train.batch_sampler.set_epoch(epoch)
View full source on GitHub →Trainer.forward_step(model, batch, loss_dict)Forward step.
# model — Model instance or model name. batch — TODO. loss_dict — TODO.
def forward_step(self, model, batch, loss_dict=None):
    """Run the model forward on one batch and record the outputs.

    The forward call runs inside ``maybe_autocast`` configured with
    ``self.dtype`` and ``self.use_deepspeed``. The model must return a 3-tuple
    ``(loss, stats, weight)``. Entries of ``stats`` whose value is ``None``
    are dropped.

    Args:
        model: Callable model, invoked as ``model(**batch)``.
        batch: Mapping of keyword arguments for the model.
        loss_dict: Dict updated in place with ``"loss"``, ``"stats"`` and
            ``"weight"``. A fresh dict is created when omitted.

    Returns:
        The updated ``loss_dict``. Callers that rely only on in-place
        mutation are unaffected.
    """
    # A ``{}`` default would be one dict shared by every call. Create a new
    # one per call instead.
    if loss_dict is None:
        loss_dict = {}
    with maybe_autocast(dtype=self.dtype, use_deepspeed=self.use_deepspeed):
        retval = model(**batch)
    loss, stats, weight = retval
    loss_dict["loss"] = loss
    loss_dict["stats"] = {k: v for k, v in stats.items() if v is not None}
    loss_dict["weight"] = weight
    return loss_dict
Trainer.backward_step(model, scaler, loss_dict) — Backward step.
# model — Model instance or model name. scaler — TODO. loss_dict — TODO.
def backward_step(self, model, scaler, loss_dict=None):
    """Back-propagate the loss stored under ``loss_dict["loss"]``.

    With DeepSpeed, the model's own ``backward`` is called on the raw loss.
    Otherwise the loss is divided by ``self.accum_grad`` and back-propagated.
    The division matches the optimizer stepping every ``accum_grad`` batches.
    When ``scaler`` is truthy, back-propagation goes through
    ``scaler.scale(loss)``.

    Args:
        model: The model; only used on the DeepSpeed path.
        scaler: Optional gradient scaler exposing ``scale(loss)``.
        loss_dict: Dict holding the ``"loss"`` tensor.

    Raises:
        KeyError: If ``loss_dict`` has no ``"loss"`` entry, including when it
            is omitted.
    """
    # Avoid a shared mutable default; an empty dict still raises KeyError.
    if loss_dict is None:
        loss_dict = {}
    loss = loss_dict["loss"]
    if self.use_deepspeed:
        # NOTE(review): scaling/accumulation is presumably handled by the
        # DeepSpeed engine itself, since the raw loss is passed here.
        model.backward(loss)
        return
    loss = loss / self.accum_grad
    if scaler:
        scaler.scale(loss).backward()
    else:
        loss.backward()
Trainer.update_step(model, optim, scheduler, scaler, loss_dict) — Update step.
model — Model instance or model name.optim — TODO.scheduler — TODO.scaler — TODO.loss_dict — TODO. def update_step(self, model, optim, scheduler, scaler, loss_dict=None):
"""Update step.
Args:
model: Model instance or model name.
optim: TODO.
scheduler: TODO.
scaler: TODO.
loss_dict: TODO.
"""
batch_idx = loss_dict["batch_idx"]
if self.use_deepspeed:
model.step()
else:
if (batch_idx + 1) % self.accum_grad == 0:
# Perform gradient clipping if it is set
if self.grad_clip > 0:
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(),
max_norm=self.grad_clip,
norm_type=self.grad_clip_type,
)
if not torch.isfinite(grad_norm):
logging.warning(
f"The grad norm is {grad_norm}. Skipping updating the model."
)
optim.zero_grad() # Reset gradients
return
# Execute an optimization step (update model parameters)
View full source on GitHub →Trainer.validate_epoch(model, dataloader_val, epoch, writer, **kwargs)Defines the validation process for a single epoch.
Should be implemented with the actual model validation steps.
epoch (int) — The current epoch number. def validate_epoch(
self,
model=None,
dataloader_val=None,
epoch=None,
writer=None,
**kwargs,
):
"""
Defines the validation process for a single epoch.
Should be implemented with the actual model validation steps.
Args:
epoch (int): The current epoch number.
"""
self.val_loss_avg = 0.0
self.val_acc_avg = 0.0
if self.use_ddp or self.use_fsdp or self.use_deepspeed:
dist.barrier()
logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
model.eval()
with torch.no_grad():
speed_stats = {}
time_beg = time.perf_counter()
time5 = time_beg
dataloader_val.batch_sampler.set_epoch(epoch)
View full source on GitHub →Trainer.log(loss_dict, tag, **kwargs)Log.
loss_dict — TODO.tag — TODO.**kwargs — Additional keyword arguments. def log(
self,
loss_dict: dict = None,
tag="train",
**kwargs,
):
"""Log.
Args:
loss_dict: TODO.
tag: TODO.
**kwargs: Additional keyword arguments.
"""
loss = loss_dict["loss"].detach().cpu().item()
epoch = loss_dict["epoch"]
batch_idx = loss_dict["batch_idx"]
step_in_epoch = loss_dict["step_in_epoch"]
batch_total = loss_dict["batch_total"]
batch_num_epoch = loss_dict["batch_num_epoch"]
lr = loss_dict["lr"]
speed_stats = loss_dict["speed_stats"]
stats = loss_dict["stats"]
data_split_i = loss_dict["data_split_i"]
data_split_num = loss_dict["data_split_num"]
log_step = loss_dict.get("log_step", None)
if (batch_idx + 1) % self.log_interval == 0:
batch_idx = log_step if log_step is not None else batch_idx
gpu_info = (
View full source on GitHub →Trainer.close(writer)Close.
# writer — TODO.
def close(self, writer=None):
    """Finish a training run and release its resources.

    When DDP or FSDP is active, every rank waits at a barrier first. Then the
    optional ``writer`` is closed. Finally the default process group is
    destroyed. No distributed call is made for any other configuration.

    Args:
        writer: Any object with a ``close()`` method, or ``None`` to skip it.
    """
    distributed = self.use_ddp or self.use_fsdp
    if distributed:
        # Synchronise ranks before tearing anything down.
        dist.barrier()
    if writer is not None:
        writer.close()
    if distributed:
        torch.distributed.destroy_process_group()
Trainer.warp_model(model, **kwargs) — Wrap model.
model — Model instance or model name.**kwargs — Additional keyword arguments. def warp_model(self, model, **kwargs):
"""Warp model.
Args:
model: Model instance or model name.
**kwargs: Additional keyword arguments.
"""
if self.use_deepspeed:
from deepspeed.runtime.zero.stage_1_and_2 import (
estimate_zero2_model_states_mem_needs_all_live,
)
from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
world_size = int(os.environ.get("WORLD_SIZE", 1))
# NOTE(xcsong): look in detail how the memory estimator API works:
# https://deepspeed.readthedocs.io/en/latest/memory.html#discussion
if int(os.environ.get("RANK", 0)) == 0:
logging.info("Estimating model states memory needs (zero2)...")
estimate_zero2_model_states_mem_needs_all_live(
model,
num_gpus_per_node=local_world_size,
num_nodes=world_size // local_world_size,
)
logging.info("Estimating model states memory needs (zero3)...")
estimate_zero3_model_states_mem_needs_all_live(
model,
View full source on GitHub →Trainer.warp_optim_scheduler(model, **kwargs)Warp optim scheduler.
model — Model instance or model name.**kwargs — Additional keyword arguments. def warp_optim_scheduler(self, model, **kwargs):
"""Warp optim scheduler.
Args:
model: Model instance or model name.
**kwargs: Additional keyword arguments.
"""
from funasr.optimizers import optim_classes
from funasr.schedulers import scheduler_classes
from omegaconf import OmegaConf, DictConfig
import json
# optim
logging.info("Build optim")
optim = kwargs.get("optim", "adam")
assert optim in optim_classes
optim_class = optim_classes.get(optim)
optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
# scheduler
logging.info("Build scheduler")
scheduler = kwargs.get("scheduler", "warmuplr")
assert scheduler in scheduler_classes
scheduler_class = scheduler_classes.get(scheduler)
scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
if self.use_deepspeed:
import deepspeed
args = OmegaConf.create({"deepspeed_config": self.deepspeed_config})
View full source on GitHub →AudioDataset
class AudioDataset(torch.utils.data.Dataset):
"""
AudioDataset
"""
def __init__(
self,
path,
index_ds: str = None,
frontend=None,
tokenizer=None,
is_training: bool = True,
int_pad_value: int = -1,
float_pad_value: float = 0.0,
**kwargs,
):
"""Initialize AudioDataset.
Args:
path: TODO.
index_ds: TODO.
frontend: Audio frontend for feature extraction.
tokenizer: Tokenizer instance for text encoding/decoding.
is_training: Boolean flag for training.
int_pad_value: TODO.
float_pad_value: TODO.
**kwargs: Additional keyword arguments.
"""
super().__init__()
index_ds_class = tables.index_ds_classes.get(index_ds)
View full source on GitHub →.get_source_len(index) L67Get source len.
# index — TODO.
def get_source_len(self, index):
    """Return the source length of the sample at position ``index``.

    The raw item is fetched from ``self.index_ds``. Measuring it is delegated
    to ``self.index_ds.get_source_len``.
    """
    lookup = self.index_ds
    record = lookup[index]
    return lookup.get_source_len(record)
.get_target_len(index) L76Get target len.
# index — TODO.
def get_target_len(self, index):
    """Return the target length of the sample at position ``index``.

    The raw item is fetched from ``self.index_ds``. Measuring it is delegated
    to ``self.index_ds.get_target_len``.
    """
    lookup = self.index_ds
    record = lookup[index]
    return lookup.get_target_len(record)
.collator(samples) L127Collator.
# samples — TODO.
def collator(self, samples: list = None):
    """Merge a list of per-sample dicts into one batch dict.

    Values are grouped by key in first-seen order. A group is padded when its
    first value is a ``torch.Tensor``: it is right-padded into one
    ``batch_first`` tensor with ``torch.nn.utils.rnn.pad_sequence``.
    int32/int64 groups use ``self.int_pad_value``; every other dtype uses
    ``self.float_pad_value``. Non-tensor groups stay plain lists.

    Args:
        samples: Iterable of per-sample dicts.

    Returns:
        dict mapping each key to a padded tensor or a list of values.
    """
    batch = {}
    for sample in samples:
        for name in sample.keys():
            batch.setdefault(name, []).append(sample[name])

    integer_dtypes = (torch.int64, torch.int32)
    for name, values in batch.items():
        head = values[0]
        if not isinstance(head, torch.Tensor):
            continue
        pad = self.int_pad_value if head.dtype in integer_dtypes else self.float_pad_value
        batch[name] = torch.nn.utils.rnn.pad_sequence(
            values, batch_first=True, padding_value=pad
        )
    return batch
AudioDataset.get_source_len(index)Get source len.
# index — TODO.
def get_source_len(self, index):
    """Return the source length of the sample at position ``index``.

    The raw item is fetched from ``self.index_ds``. Measuring it is delegated
    to ``self.index_ds.get_source_len``.
    """
    lookup = self.index_ds
    record = lookup[index]
    return lookup.get_source_len(record)
AudioDataset.get_target_len(index)Get target len.
# index — TODO.
def get_target_len(self, index):
    """Return the target length of the sample at position ``index``.

    The raw item is fetched from ``self.index_ds``. Measuring it is delegated
    to ``self.index_ds.get_target_len``.
    """
    lookup = self.index_ds
    record = lookup[index]
    return lookup.get_target_len(record)
AudioDataset.collator(samples)Collator.
# samples — TODO.
def collator(self, samples: list = None):
    """Merge a list of per-sample dicts into one batch dict.

    Values are grouped by key in first-seen order. A group is padded when its
    first value is a ``torch.Tensor``: it is right-padded into one
    ``batch_first`` tensor with ``torch.nn.utils.rnn.pad_sequence``.
    int32/int64 groups use ``self.int_pad_value``; every other dtype uses
    ``self.float_pad_value``. Non-tensor groups stay plain lists.

    Args:
        samples: Iterable of per-sample dicts.

    Returns:
        dict mapping each key to a padded tensor or a list of values.
    """
    batch = {}
    for sample in samples:
        for name in sample.keys():
            batch.setdefault(name, []).append(sample[name])

    integer_dtypes = (torch.int64, torch.int32)
    for name, values in batch.items():
        head = values[0]
        if not isinstance(head, torch.Tensor):
            continue
        pad = self.int_pad_value if head.dtype in integer_dtypes else self.float_pad_value
        batch[name] = torch.nn.utils.rnn.pad_sequence(
            values, batch_first=True, padding_value=pad
        )
    return batch
No documentation yet.
class AudioDatasetHotword(AudioDataset):
# for finetuning contextual_paraformer and seaco_paraformer
def __init__(
self,
*args,
seaco_id: bool = 0,
**kwargs,
):
"""Initialize AudioDatasetHotword.
Args:
*args: Variable positional arguments.
**kwargs: Additional keyword arguments.
"""
super().__init__(*args, **kwargs)
self.seaco_id = seaco_id
def __getitem__(self, index):
"""Internal: getitem .
Args:
index: TODO.
"""
item = self.index_ds[index]
# import pdb;
# pdb.set_trace()
source = item["source"]
data_src = load_audio_text_image_video(source, fs=self.fs)
if self.preprocessor_speech:
data_src = self.preprocessor_speech(data_src, fs=self.fs)
View full source on GitHub →.collator(samples) L268Collator.
samples — TODO. def collator(self, samples: list = None):
"""Collator.
Args:
samples: TODO.
"""
outputs = {}
hotword_indxs = []
seaco_id = samples[0]["seaco_id"]
for sample in samples:
for key in sample.keys():
if key == "seaco_id":
continue
elif key == "hotword_indx":
hotword_indxs.append(sample[key])
else:
if key not in outputs:
outputs[key] = []
outputs[key].append(sample[key])
for key, data_list in outputs.items():
if isinstance(data_list[0], torch.Tensor):
if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:
pad_value = self.int_pad_value
else:
pad_value = self.float_pad_value
outputs[key] = torch.nn.utils.rnn.pad_sequence(
data_list, batch_first=True, padding_value=pad_value
)
View full source →AudioDatasetHotword.collator(samples)Collator.
samples — TODO. def collator(self, samples: list = None):
"""Collator.
Args:
samples: TODO.
"""
outputs = {}
hotword_indxs = []
seaco_id = samples[0]["seaco_id"]
for sample in samples:
for key in sample.keys():
if key == "seaco_id":
continue
elif key == "hotword_indx":
hotword_indxs.append(sample[key])
else:
if key not in outputs:
outputs[key] = []
outputs[key].append(sample[key])
for key, data_list in outputs.items():
if isinstance(data_list[0], torch.Tensor):
if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:
pad_value = self.int_pad_value
else:
pad_value = self.float_pad_value
outputs[key] = torch.nn.utils.rnn.pad_sequence(
data_list, batch_first=True, padding_value=pad_value
)
View full source on GitHub →EspnetStyleBatchSampler_fn(dataset, **kwargs)Espnetstylebatchsampler fn.
# dataset — TODO. **kwargs — Additional keyword arguments.
def EspnetStyleBatchSampler_fn(dataset, **kwargs):
    """Build DataLoader keyword arguments around an ``EspnetStyleBatchSampler``.

    All ``kwargs`` are forwarded to the sampler constructor. ``num_workers``
    and ``pin_memory`` are also read from ``kwargs``, defaulting to 4 and True.

    Returns:
        dict with keys ``batch_sampler``, ``num_workers`` and ``pin_memory``.
    """
    return {
        "batch_sampler": EspnetStyleBatchSampler(dataset, **kwargs),
        "num_workers": kwargs.get("num_workers", 4),
        "pin_memory": kwargs.get("pin_memory", True),
    }
No documentation yet.
class EspnetStyleBatchSampler(DistributedSampler):
def __init__(
self,
dataset,
batch_size,
batch_type="token",
rank=None,
num_replicas=None,
rank_split=False,
shuffle=True,
drop_last=False,
is_training: bool = True,
sort_size: int = 1024,
start_step: int = 0,
**kwargs,
):
"""Initialize EspnetStyleBatchSampler.
Args:
dataset: TODO.
batch_size: Number of samples per batch.
batch_type: TODO.
rank: TODO.
num_replicas: TODO.
rank_split: TODO.
shuffle: TODO.
drop_last: TODO.
is_training: Boolean flag for training.
sort_size: Size/dimension parameter.
View full source on GitHub →.set_epoch(epoch) L187Set epoch.
# epoch — TODO.
def set_epoch(self, epoch):
    """Record the current epoch on the sampler as ``self.epoch``.

    Per the original author's note, the epoch is used for shuffling.
    """
    setattr(self, "epoch", epoch)
EspnetStyleBatchSampler.set_epoch(epoch)Set epoch.
# epoch — TODO.
def set_epoch(self, epoch):
    """Record the current epoch on the sampler as ``self.epoch``.

    Per the original author's note, the epoch is used for shuffling.
    """
    setattr(self, "epoch", epoch)
No documentation yet.
class IndexDSJsonlRankFull(torch.utils.data.Dataset):
def __init__(self, path: str, **kwargs):
"""Initialize IndexDSJsonlRankFull.
Args:
path: TODO.
**kwargs: Additional keyword arguments.
"""
super().__init__()
self.max_source_length = kwargs.get("max_source_length", 2048)
self.min_source_length = kwargs.get("min_source_length", 0)
self.max_target_length = kwargs.get("max_target_length", 2048)
self.min_target_length = kwargs.get("min_target_length", 0)
self.max_token_length = kwargs.get("max_token_length", 2200)
is_training = kwargs.get("is_training", True)
if not (path.endswith(".jsonl") or path.endswith(".json")):
# jsonl list file
data_split_num = kwargs.get("data_split_num", 1)
data_split_i = kwargs.get("data_split_i", 0)
if not is_training:
data_split_num = 1
data_split_i = 0
with open(path, encoding="utf-8") as fin:
file_list_all = fin.readlines()
num_per_slice = (len(file_list_all) - 1) // data_split_num + 1 # 16
file_list = file_list_all[
View full source on GitHub →.get_source_len(data_dict) L158Get source len.
# data_dict — TODO.
def get_source_len(self, data_dict):
    """Return ``data_dict["source_len"]``, or 1 when that key is missing."""
    fallback = 1
    return data_dict.get("source_len", fallback)
.get_target_len(data_dict) L166Get target len.
# data_dict — TODO.
def get_target_len(self, data_dict):
    """Return ``data_dict["target_len"]``, or 0 when that key is missing."""
    fallback = 0
    return data_dict.get("target_len", fallback)
IndexDSJsonlRankFull.get_source_len(data_dict)Get source len.
# data_dict — TODO.
def get_source_len(self, data_dict):
    """Return ``data_dict["source_len"]``, or 1 when that key is missing."""
    fallback = 1
    return data_dict.get("source_len", fallback)
IndexDSJsonlRankFull.get_target_len(data_dict)Get target len.
# data_dict — TODO.
def get_target_len(self, data_dict):
    """Return ``data_dict["target_len"]``, or 0 when that key is missing."""
    fallback = 0
    return data_dict.get("target_len", fallback)
gen_scp_from_jsonl(jsonl_file, data_type_list, wav_scp_file, text_file)Gen scp from jsonl.
jsonl_file — TODO.data_type_list — TODO.wav_scp_file — TODO.text_file — TODO.def gen_scp_from_jsonl(jsonl_file, data_type_list, wav_scp_file, text_file):
"""Gen scp from jsonl.
Args:
jsonl_file: TODO.
data_type_list: TODO.
wav_scp_file: TODO.
text_file: TODO.
"""
wav_f = open(wav_scp_file, "w")
text_f = open(text_file, "w")
with open(jsonl_file, encoding="utf-8") as fin:
for line in fin:
data = json.loads(line.strip())
prompt = data.get("prompt", "<ASR>")
source = data[data_type_list[0]]
target = data[data_type_list[1]]
source_len = data.get("source_len", 1)
target_len = data.get("target_len", 0)
if "aishell" in source:
target = target.replace(" ", "")
key = data["key"]
wav_f.write(f"{key}\t{source}\n")
wav_f.flush()
text_f.write(f"{key}\t{target}\n")
text_f.flush()
wav_f.close()
View full source on GitHub →main_hydra(cfg)Main hydra.
# cfg — Configuration overrides.
def main_hydra(cfg: DictConfig):
    """Write ``wav.scp``/``text`` files from a JSONL manifest.

    Reads ``scp_file_list``, ``data_type_list`` and ``jsonl_file_in`` from
    ``cfg``, falling back to defaults. The values are passed to
    ``gen_scp_from_jsonl``.

    Args:
        cfg: Configuration overrides; resolved into a plain container first.
    """
    import ast

    kwargs = OmegaConf.to_container(cfg, resolve=True)
    print(kwargs)
    # NOTE(review): the defaults below are developer-local paths.
    scp_file_list = kwargs.get(
        "scp_file_list",
        ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"),
    )
    if isinstance(scp_file_list, str):
        # The string may come from an external override. Parse it as a
        # Python literal such as "('wav.scp', 'text.txt')" instead of
        # eval-ing arbitrary code.
        scp_file_list = ast.literal_eval(scp_file_list)
    data_type_list = kwargs.get("data_type_list", ("source", "target"))
    jsonl_file = kwargs.get(
        "jsonl_file_in", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl"
    )
    gen_scp_from_jsonl(jsonl_file, data_type_list, *scp_file_list)
No documentation yet.
class SpeechPreprocessSpeedPerturb(nn.Module):
def __init__(self, speed_perturb: list = None, **kwargs):
"""Initialize SpeechPreprocessSpeedPerturb.
Args:
speed_perturb: TODO.
**kwargs: Additional keyword arguments.
"""
super().__init__()
self.speed_perturb = speed_perturb
def forward(self, waveform, fs, **kwargs):
"""Forward pass for training.
Args:
waveform: TODO.
fs: TODO.
**kwargs: Additional keyword arguments.
"""
if self.speed_perturb is None:
return waveform
speed = random.choice(self.speed_perturb)
if speed != 1.0:
if not isinstance(waveform, torch.Tensor):
waveform = torch.tensor(waveform)
waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
waveform.view(1, -1), fs, [["speed", str(speed)], ["rate", str(fs)]]
)
waveform = waveform.view(-1)
View full source on GitHub →.forward(waveform, fs, **kwargs) L30Forward pass for training.
# waveform — TODO. fs — TODO. **kwargs — Additional keyword arguments.
def forward(self, waveform, fs, **kwargs):
    """Apply random speed perturbation to a waveform.

    One factor is drawn with ``random.choice`` from ``self.speed_perturb``.
    For any factor other than 1.0, the signal goes through the sox effects
    ``speed <factor>`` then ``rate <fs>``. The ``rate`` step presumably
    resamples back to ``fs``. The result is flattened to 1-D.

    Args:
        waveform: Audio samples; converted with ``torch.tensor`` when not
            already a tensor (only on the perturbation path).
        fs: Sample rate passed to sox.
        **kwargs: Ignored.

    Returns:
        The perturbed waveform. The input is returned unchanged when
        perturbation is disabled (``speed_perturb`` is ``None`` or empty) or
        the drawn factor is 1.0.
    """
    # An empty list would make random.choice raise IndexError, so treat it
    # like None. len() is used instead of truthiness so array-likes work.
    if self.speed_perturb is None or len(self.speed_perturb) == 0:
        return waveform
    speed = random.choice(self.speed_perturb)
    if speed != 1.0:
        if not isinstance(waveform, torch.Tensor):
            waveform = torch.tensor(waveform)
        # reshape (unlike view) also accepts non-contiguous tensors.
        waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
            waveform.reshape(1, -1), fs, [["speed", str(speed)], ["rate", str(fs)]]
        )
        waveform = waveform.reshape(-1)
    return waveform
SpeechPreprocessSpeedPerturb.forward(waveform, fs, **kwargs)Forward pass for training.
# waveform — TODO. fs — TODO. **kwargs — Additional keyword arguments.
def forward(self, waveform, fs, **kwargs):
    """Apply random speed perturbation to a waveform.

    One factor is drawn with ``random.choice`` from ``self.speed_perturb``.
    For any factor other than 1.0, the signal goes through the sox effects
    ``speed <factor>`` then ``rate <fs>``. The ``rate`` step presumably
    resamples back to ``fs``. The result is flattened to 1-D.

    Args:
        waveform: Audio samples; converted with ``torch.tensor`` when not
            already a tensor (only on the perturbation path).
        fs: Sample rate passed to sox.
        **kwargs: Ignored.

    Returns:
        The perturbed waveform. The input is returned unchanged when
        perturbation is disabled (``speed_perturb`` is ``None`` or empty) or
        the drawn factor is 1.0.
    """
    # An empty list would make random.choice raise IndexError, so treat it
    # like None. len() is used instead of truthiness so array-likes work.
    if self.speed_perturb is None or len(self.speed_perturb) == 0:
        return waveform
    speed = random.choice(self.speed_perturb)
    if speed != 1.0:
        if not isinstance(waveform, torch.Tensor):
            waveform = torch.tensor(waveform)
        # reshape (unlike view) also accepts non-contiguous tensors.
        waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
            waveform.reshape(1, -1), fs, [["speed", str(speed)], ["rate", str(fs)]]
        )
        waveform = waveform.reshape(-1)
    return waveform
No documentation yet.
class TextPreprocessSegDict(nn.Module):
    """Text preprocessor that runs its input through a ``TextCleaner``.

    NOTE(review): ``seg_dict`` and ``split_with_space`` are accepted by the
    constructor but not used anywhere in this class. Confirm whether
    dictionary-based segmentation is expected to happen here.
    """

    def __init__(
        self,
        seg_dict: str = None,
        text_cleaner: Collection[str] = None,
        split_with_space: bool = False,
        **kwargs
    ):
        """Create the underlying cleaner.

        Args:
            seg_dict: Accepted but currently unused.
            text_cleaner: Cleaner specification forwarded to ``TextCleaner``.
            split_with_space: Accepted but currently unused.
            **kwargs: Ignored.
        """
        super().__init__()
        self.text_cleaner = TextCleaner(text_cleaner)

    def forward(self, text, **kwargs):
        """Return ``text`` after passing it through ``self.text_cleaner``."""
        return self.text_cleaner(text)
.forward(text, **kwargs) L73Forward pass for training.
# text — Text tensor or string input. **kwargs — Additional keyword arguments.
def forward(self, text, **kwargs):
    """Clean ``text`` with ``self.text_cleaner`` and return the result.

    Args:
        text: Value handed straight to the cleaner.
        **kwargs: Ignored.
    """
    return self.text_cleaner(text)
TextPreprocessSegDict.forward(text, **kwargs)Forward pass for training.
# text — Text tensor or string input. **kwargs — Additional keyword arguments.
def forward(self, text, **kwargs):
    """Clean ``text`` with ``self.text_cleaner`` and return the result.

    Args:
        text: Value handed straight to the cleaner.
        **kwargs: Ignored.
    """
    return self.text_cleaner(text)
CustomDistributedBatchSampler_fn(dataset, **kwargs)Customdistributedbatchsampler fn.
# dataset — TODO. **kwargs — Additional keyword arguments.
def CustomDistributedBatchSampler_fn(dataset, **kwargs):
    """Pick a batch sampler from ``kwargs`` and build DataLoader arguments.

    Sampler selection:
      * ``batch_type == "example"`` (the default): ``CustomDistributedBatchSampler``.
      * otherwise, ``sort_size > 0``: ``CustomDistributedBufferDynamicBatchSampler``.
      * otherwise: ``CustomDistributedDynamicBatchSampler``.

    All ``kwargs`` are forwarded to the chosen sampler. ``num_workers`` and
    ``pin_memory`` default to 4 and True.

    Returns:
        dict with keys ``batch_sampler``, ``num_workers`` and ``pin_memory``.
    """
    if kwargs.get("batch_type", "example") == "example":
        sampler_cls = CustomDistributedBatchSampler
    elif kwargs.get("sort_size", -1) > 0:
        sampler_cls = CustomDistributedBufferDynamicBatchSampler
    else:
        sampler_cls = CustomDistributedDynamicBatchSampler
    return {
        "batch_sampler": sampler_cls(dataset, **kwargs),
        "num_workers": kwargs.get("num_workers", 4),
        "pin_memory": kwargs.get("pin_memory", True),
    }
No documentation yet.
class CustomDistributedBatchSampler(Sampler):
def __init__(
self,
dataset,
batch_size,
num_replicas=None,
rank=None,
shuffle=True,
drop_last=False,
is_training: bool = True,
**kwargs,
):
"""Initialize CustomDistributedBatchSampler.
Args:
dataset: TODO.
batch_size: Number of samples per batch.
num_replicas: TODO.
rank: TODO.
shuffle: TODO.
drop_last: TODO.
is_training: Boolean flag for training.
**kwargs: Additional keyword arguments.
"""
try:
rank = dist.get_rank()
num_replicas = dist.get_world_size()
except:
rank = 0
View full source on GitHub →.set_epoch(epoch) L148Set epoch.
# epoch — TODO.
def set_epoch(self, epoch):
    """Record the current epoch on the sampler as ``self.epoch``."""
    setattr(self, "epoch", epoch)
CustomDistributedBatchSampler.set_epoch(epoch)Set epoch.
# epoch — TODO.
def set_epoch(self, epoch):
    """Record the current epoch on the sampler as ``self.epoch``."""
    setattr(self, "epoch", epoch)
No documentation yet.
class CustomDistributedBufferBatchSampler(Sampler):
def __init__(
self,
dataset,
batch_size,
num_replicas=None,
rank=None,
shuffle=True,
drop_last=False,
is_training: bool = True,
sort_size: int = 1024,
**kwargs,
):
"""Initialize CustomDistributedBufferBatchSampler.
Args:
dataset: TODO.
batch_size: Number of samples per batch.
num_replicas: TODO.
rank: TODO.
shuffle: TODO.
drop_last: TODO.
is_training: Boolean flag for training.
sort_size: Size/dimension parameter.
**kwargs: Additional keyword arguments.
"""
try:
rank = dist.get_rank()
num_replicas = dist.get_world_size()
View full source on GitHub →.set_epoch(epoch) L284Set epoch.
# epoch — TODO.
def set_epoch(self, epoch):
    """Record the current epoch on the sampler as ``self.epoch``."""
    setattr(self, "epoch", epoch)
CustomDistributedBufferBatchSampler.set_epoch(epoch)Set epoch.
# epoch — TODO.
def set_epoch(self, epoch):
    """Record the current epoch on the sampler as ``self.epoch``."""
    setattr(self, "epoch", epoch)
No documentation yet.
class CustomDistributedDynamicBatchSampler(DistributedSampler):
def __init__(
self,
dataset,
batch_size,
num_replicas=None,
rank=None,
shuffle=True,
drop_last=False,
is_training: bool = True,
**kwargs,
):
"""Initialize CustomDistributedDynamicBatchSampler.
Args:
dataset: TODO.
batch_size: Number of samples per batch.
num_replicas: TODO.
rank: TODO.
shuffle: TODO.
drop_last: TODO.
is_training: Boolean flag for training.
**kwargs: Additional keyword arguments.
"""
try:
rank = dist.get_rank()
num_replicas = dist.get_world_size()
except:
rank = 0
View full source on GitHub →.set_epoch(epoch) L384Set epoch.
# epoch — TODO.
def set_epoch(self, epoch):
    """Record the current epoch on the sampler as ``self.epoch``."""
    setattr(self, "epoch", epoch)
CustomDistributedDynamicBatchSampler.set_epoch(epoch)Set epoch.
# epoch — TODO.
def set_epoch(self, epoch):
    """Record the current epoch on the sampler as ``self.epoch``."""
    setattr(self, "epoch", epoch)
No documentation yet.
class CustomDistributedBufferDynamicBatchSampler(DistributedSampler):
def __init__(
self,
dataset,
batch_size,
batch_type="token",
num_replicas=None,
rank=None,
rank_split=False,
shuffle=True,
drop_last=False,
is_training: bool = True,
sort_size: int = 1024,
start_step: int = 0,
**kwargs,
):
"""Initialize CustomDistributedBufferDynamicBatchSampler.
Args:
dataset: TODO.
batch_size: Number of samples per batch.
batch_type: TODO.
num_replicas: TODO.
rank: TODO.
rank_split: TODO.
shuffle: TODO.
drop_last: TODO.
is_training: Boolean flag for training.
sort_size: Size/dimension parameter.
View full source on GitHub →.set_epoch(epoch) L526Set epoch.
# epoch — TODO.
def set_epoch(self, epoch):
    """Record the current epoch on the sampler as ``self.epoch``."""
    setattr(self, "epoch", epoch)
CustomDistributedBufferDynamicBatchSampler.set_epoch(epoch)Set epoch.
# epoch — TODO.
def set_epoch(self, epoch):
    """Record the current epoch on the sampler as ``self.epoch``."""
    setattr(self, "epoch", epoch)
No documentation yet.
class DistributedSamplerWarp(BatchSampler):
def __init__(
self, dataset, batch_size, num_replicas=None, rank=None, shuffle=True, drop_last=False
):
"""Initialize DistributedSamplerWarp.
Args:
dataset: TODO.
batch_size: Number of samples per batch.
num_replicas: TODO.
rank: TODO.
shuffle: TODO.
drop_last: TODO.
"""
if num_replicas is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = torch.distributed.get_world_size()
if rank is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = torch.distributed.get_rank()
self.dataset = dataset
self.batch_size = batch_size
self.num_replicas = num_replicas
self.rank = rank
self.shuffle = shuffle
self.drop_last = drop_last
View full source on GitHub →.set_epoch(epoch) L582Set epoch.
epoch — TODO. def set_epoch(self, epoch):
        """Set the epoch for this batch sampler.

        Args:
            epoch: Epoch number; stored unchanged as ``self.epoch``.
                NOTE(review): presumably read by ``__iter__`` to vary the
                shuffle order per epoch -- the iterator is not visible here,
                confirm against the class body.
        """
        self.epoch = epoch
DistributedSamplerWarp.set_epoch(epoch)Set epoch.
epoch — TODO. def set_epoch(self, epoch):
        """Set the epoch for this batch sampler.

        Args:
            epoch: Epoch number; stored unchanged as ``self.epoch``.
                NOTE(review): presumably read by ``__iter__`` to vary the
                shuffle order per epoch -- the iterator is not visible here,
                confirm against the class body.
        """
        self.epoch = epoch
gen_jsonl_from_wav_text_list(path, data_type_list, jsonl_file_out, **kwargs)Gen jsonl from wav text list.
path — TODO.data_type_list — TODO.jsonl_file_out — TODO.**kwargs — Additional keyword arguments.def gen_jsonl_from_wav_text_list(
path, data_type_list=("source", "target"), jsonl_file_out: str = None, **kwargs
):
"""Gen jsonl from wav text list.
Args:
path: TODO.
data_type_list: TODO.
jsonl_file_out: TODO.
**kwargs: Additional keyword arguments.
"""
try:
rank = dist.get_rank()
world_size = dist.get_world_size()
except:
rank = 0
world_size = 1
cpu_cores = os.cpu_count() or 1
print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}")
if rank == 0:
json_dict = {}
for data_type, data_file in zip(data_type_list, path):
json_dict[data_type] = {}
with open(data_file, "r") as f:
data_file_lists = f.readlines()
lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
# import pdb;pdb.set_trace()
View full source on GitHub →parse_context_length(data_list, data_type, id)Parse context length.
def parse_context_length(data_list: list, data_type: str, id=0):
    """Measure the length of every ``<key> <value>`` line in ``data_list``.

    For ``data_type == "source"`` the value is an audio path. Its length is
    the number of 10 ms frames of the audio resampled to 16 kHz. Entries whose
    file does not exist are reported and skipped.

    For any other ``data_type`` the value is text. Its length is the number
    of whitespace-separated tokens, or the number of characters when the
    text contains no space.

    Args:
        data_list: Raw lines of the form ``"<key> <value>"``; blank lines are
            skipped.
        data_type: Field name the value is stored under (e.g. ``"source"``).
        id: Worker index; only used in the progress-bar description.

    Returns:
        dict: ``{key: {data_type: value, f"{data_type}_len": length}}``.
    """
    res = {}
    # Context manager guarantees the progress bar is closed when done.
    with tqdm(total=len(data_list), dynamic_ncols=True) as pbar:
        pbar.set_description(f"cpu: {id}")
        for raw_line in data_list:
            pbar.update(1)
            parts = raw_line.strip().split(maxsplit=1)
            if not parts:
                # Blank line: there is no key (parts[0] would raise IndexError).
                continue
            key = parts[0]
            value = parts[1].strip() if len(parts) > 1 else ""
            if data_type == "source":
                if os.path.exists(value):
                    waveform, _ = librosa.load(value, sr=16000)
                    sample_num = len(waveform)
                    # samples -> milliseconds -> number of 10 ms frames
                    context_len = int(sample_num * 1000 / 16000 / 10)
                else:
                    print("source file not found: {}".format(value))
                    continue
            else:
                context_len = len(value.split()) if " " in value else len(value)
            res[key] = {data_type: value, f"{data_type}_len": context_len}
    return res
main_hydra(cfg)Main hydra.
def main_hydra(cfg: DictConfig):
    """Convert ``wav.scp`` / ``text.txt`` lists into a single JSONL file.

    Reads ``scp_file_list``, ``data_type_list`` and ``jsonl_file_out`` from
    the resolved config. Each falls back to the author's local test path if
    missing. The values are passed to ``gen_jsonl_from_wav_text_list``.

    Args:
        cfg: Hydra/OmegaConf configuration (command-line overrides).
    """
    import ast  # only needed to parse a list given as a string

    kwargs = OmegaConf.to_container(cfg, resolve=True)
    print(kwargs)
    scp_file_list = kwargs.get(
        "scp_file_list",
        ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"),
    )
    if isinstance(scp_file_list, str):
        # The list may arrive as a literal such as "('wav.scp', 'text.txt')".
        # literal_eval parses Python literals without executing arbitrary
        # code, unlike eval() on user-supplied command-line input.
        scp_file_list = ast.literal_eval(scp_file_list)
    data_type_list = kwargs.get("data_type_list", ("source", "target"))
    jsonl_file_out = kwargs.get(
        "jsonl_file_out", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl"
    )
    gen_jsonl_from_wav_text_list(
        scp_file_list, data_type_list=data_type_list, jsonl_file_out=jsonl_file_out
    )
gen_jsonl_from_wav_text_list(path, data_type_list, jsonl_file_out, **kwargs)Gen jsonl from wav text list.
path — TODO.data_type_list — TODO.jsonl_file_out — TODO.**kwargs — Additional keyword arguments.def gen_jsonl_from_wav_text_list(
path, data_type_list=("source",), jsonl_file_out: str = None, **kwargs
):
"""Gen jsonl from wav text list.
Args:
path: TODO.
data_type_list: TODO.
jsonl_file_out: TODO.
**kwargs: Additional keyword arguments.
"""
try:
rank = dist.get_rank()
world_size = dist.get_world_size()
except:
rank = 0
world_size = 1
cpu_cores = os.cpu_count() or 1
print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}")
if rank == 0:
json_dict = {}
# for data_type, data_file in zip(data_type_list, path):
data_type = data_type_list[0]
data_file = path
json_dict[data_type] = {}
with open(data_file, "r") as f:
data_file_lists = f.readlines()
print("")
View full source on GitHub →parse_context_length(data_list, data_type, id)Parse context length.
def parse_context_length(data_list: list, data_type: str, id=0):
    """Measure the length of every ``<key> <value>`` line in ``data_list``.

    If the value is an existing file path, it is loaded as audio resampled to
    16 kHz and its length is the number of 10 ms frames. Otherwise it is
    treated as text, and its length is the number of whitespace-separated
    tokens, or the number of characters when it contains no space.

    Args:
        data_list: Raw lines of the form ``"<key> <value>"``; blank lines are
            skipped.
        data_type: Field name the value is stored under (e.g. ``"source"``).
        id: Worker index; only used in the progress-bar description.

    Returns:
        dict: ``{key: {data_type: value, f"{data_type}_len": length}}``.
    """
    res = {}
    # Context manager guarantees the progress bar is closed when done.
    with tqdm(total=len(data_list), dynamic_ncols=True) as pbar:
        pbar.set_description(f"cpu: {id}")
        for raw_line in data_list:
            pbar.update(1)
            parts = raw_line.strip().split(maxsplit=1)
            if not parts:
                # Blank line: there is no key (parts[0] would raise IndexError).
                continue
            key = parts[0]
            value = parts[1].strip() if len(parts) > 1 else ""
            if os.path.exists(value):
                waveform, _ = librosa.load(value, sr=16000)
                sample_num = len(waveform)
                # samples -> milliseconds -> number of 10 ms frames
                context_len = int(sample_num / 16000 * 1000 / 10)
            else:
                context_len = len(value.split()) if " " in value else len(value)
            res[key] = {data_type: value, f"{data_type}_len": context_len}
    return res
main_hydra(cfg)Main hydra.
cfg — Configuration overrides.def main_hydra(cfg: DictConfig):
    """Compute audio lengths for a single scp file via the Hydra config.

    Reads ``scp_file_list``, ``data_type_list`` and ``jsonl_file_out`` from
    the resolved config and passes them to ``gen_jsonl_from_wav_text_list``.
    If missing, they fall back to the author's local paths and ``("source",)``;
    the default output is named ``wav_len.txt``.

    Args:
        cfg: Hydra/OmegaConf configuration (command-line overrides).
    """
    kwargs = OmegaConf.to_container(cfg, resolve=True)
    print(kwargs)
    # A single scp path (not a list), so the string is used as-is.
    scp_file_list = kwargs.get("scp_file_list", "/Users/zhifu/funasr1.0/data/list/train_wav.scp")
    # if isinstance(scp_file_list, str):
    #     scp_file_list = eval(scp_file_list)
    data_type_list = kwargs.get("data_type_list", ("source",))
    jsonl_file_out = kwargs.get("jsonl_file_out", "/Users/zhifu/funasr1.0/data/list/wav_len.txt")
    gen_jsonl_from_wav_text_list(
        scp_file_list, data_type_list=data_type_list, jsonl_file_out=jsonl_file_out
    )
gen_jsonl_from_wav_text_list(path, data_type_list, jsonl_file_out, model_dir, **kwargs)Gen jsonl from wav text list.
path — TODO.data_type_list — TODO.jsonl_file_out — TODO.model_dir — TODO.**kwargs — Additional keyword arguments.def gen_jsonl_from_wav_text_list(
path, data_type_list=("source", "target"), jsonl_file_out: str = None, model_dir: str = "iic/SenseVoiceSmall", **kwargs
):
"""Gen jsonl from wav text list.
Args:
path: TODO.
data_type_list: TODO.
jsonl_file_out: TODO.
model_dir: TODO.
**kwargs: Additional keyword arguments.
"""
try:
rank = dist.get_rank()
world_size = dist.get_world_size()
except:
rank = 0
world_size = 1
cpu_cores = os.cpu_count() or 1
print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}")
if rank == 0:
json_dict = {}
for data_type, data_file in zip(data_type_list, path):
json_dict[data_type] = {}
with open(data_file, "r") as f:
data_file_lists = f.readlines()
lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
View full source on GitHub →contains_punctuation(s)Contains punctuation.
def contains_punctuation(s):
    """Return True if ``s`` contains at least one punctuation character.

    Recognizes ASCII punctuation (``string.punctuation``) as well as common
    CJK / full-width punctuation, e.g. ``，。、；：？！“”‘’（）【】《》``.
    Without the full-width forms, text such as ``"你好，世界"`` would not be
    detected.

    Args:
        s: Text to inspect.

    Returns:
        bool: True if any character of ``s`` is punctuation.
    """
    punctuations = set(
        string.punctuation
        + "，。、；：？！“”‘’（）【】《》〈〉「」『』〔〕［］｛｝～·…—–"
    )
    return any(char in punctuations for char in s)
parse_context_length(data_list, data_type, id)Parse context length.
data_list — TODO.data_type — TODO.id — TODO.def parse_context_length(data_list: list, data_type: str, id=0):
"""Parse context length.
Args:
data_list: TODO.
data_type: TODO.
id: TODO.
"""
pbar = tqdm(total=len(data_list), dynamic_ncols=True)
res = {}
for i, line in enumerate(data_list):
pbar.update(1)
pbar.set_description(f"cpu: {id}")
lines = line.strip().split(maxsplit=1)
key = lines[0]
line = lines[1] if len(lines) > 1 else ""
line = line.strip()
if os.path.exists(line):
waveform, _ = librosa.load(line, sr=16000)
sample_num = len(waveform)
context_len = int(sample_num / 16000 * 1000 / 10)
else:
context_len = len(line.split()) if " " in line else len(line)
if data_type == "source":
res[key] = {data_type: line, f"{data_type}_len": context_len}
elif data_type == "target":
punc = contains_punctuation(line)
if punc:
with_or_wo_itn = "<|withitn|>"
else:
View full source on GitHub →main_hydra(cfg)Main hydra.
def main_hydra(cfg: DictConfig):
    """Convert wav/text lists into a JSONL file via the Hydra config.

    Reads ``scp_file_list``, ``data_type_list``, ``jsonl_file_out`` and
    ``model_dir`` from the resolved config and passes them to
    ``gen_jsonl_from_wav_text_list``. If missing, they fall back to the
    author's local test paths and ``"iic/SenseVoiceSmall"``.

    Args:
        cfg: Hydra/OmegaConf configuration (command-line overrides).
    """
    import ast  # only needed to parse a list given as a string

    kwargs = OmegaConf.to_container(cfg, resolve=True)
    print(kwargs)
    scp_file_list = kwargs.get(
        "scp_file_list",
        ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"),
    )
    if isinstance(scp_file_list, str):
        # The list may arrive as a literal such as "('wav.scp', 'text.txt')".
        # literal_eval parses Python literals without executing arbitrary
        # code, unlike eval() on user-supplied command-line input.
        scp_file_list = ast.literal_eval(scp_file_list)
    data_type_list = kwargs.get("data_type_list", ("source", "target"))
    jsonl_file_out = kwargs.get(
        "jsonl_file_out", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl"
    )
    model_dir = kwargs.get("model_dir", "iic/SenseVoiceSmall")
    gen_jsonl_from_wav_text_list(
        scp_file_list, data_type_list=data_type_list, jsonl_file_out=jsonl_file_out, model_dir=model_dir
    )
gen_scp_from_jsonl(jsonl_file, jsonl_file_out, ncpu)Gen scp from jsonl.
jsonl_file — TODO.jsonl_file_out — TODO.ncpu — TODO.def gen_scp_from_jsonl(jsonl_file, jsonl_file_out, ncpu):
"""Gen scp from jsonl.
Args:
jsonl_file: TODO.
jsonl_file_out: TODO.
ncpu: TODO.
"""
jsonl_file_out_f = open(jsonl_file_out, "w")
with open(jsonl_file, encoding="utf-8") as fin:
lines = fin.readlines()
num_total = len(lines)
if ncpu > 1:
# 使用ThreadPoolExecutor限制并发线程数
with ThreadPoolExecutor(max_workers=ncpu) as executor:
# 提交任务到线程池
futures = {executor.submit(update_data, lines, i) for i in tqdm(range(num_total))}
# 等待所有任务完成,这会阻塞直到所有提交的任务完成
for future in concurrent.futures.as_completed(futures):
# 这里可以添加额外的逻辑来处理完成的任务,但在这个例子中我们只是等待
pass
else:
for i in range(num_total):
update_data(lines, i)
logging.info("All audio durations have been processed.")
for line in lines:
View full source on GitHub →update_data(lines, i)Update data.
def update_data(lines, i):
    """Refresh the audio path and length of one JSONL record, in place.

    Steps:

    1. ``lines[i]`` is parsed as JSON.
    2. Its ``source`` path is remapped from ``/cpfs01`` to
       ``/cpfs_speech/data``.
    3. If the remapped file exists, ``source`` is set to it and
       ``source_len`` is recomputed as the number of 10 ms frames of the
       audio resampled to 16 kHz.
    4. The record is re-serialized (non-ASCII kept as-is) back into
       ``lines[i]``.

    Args:
        lines: Mutable list of JSONL strings; element ``i`` is replaced.
        i: Index of the record to update.
    """
    data = json.loads(lines[i].strip())
    wav_path = data["source"].replace("/cpfs01", "/cpfs_speech/data")
    if os.path.exists(wav_path):
        waveform, _ = librosa.load(wav_path, sr=16000)
        sample_num = len(waveform)
        # samples -> milliseconds -> number of 10 ms frames. The previous
        # "source_len" is not read: it was unused and raised KeyError for
        # records that lack it.
        data["source_len"] = int(sample_num / 16000 * 1000 / 10)
        data["source"] = wav_path
    lines[i] = json.dumps(data, ensure_ascii=False)
update_wav_len(jsonl_file_list_in, jsonl_file_out_dir, ncpu)Update wav len.
def update_wav_len(jsonl_file_list_in, jsonl_file_out_dir, ncpu=1):
    """Recompute audio lengths for every JSONL file named in a list file.

    Each non-blank line of ``jsonl_file_list_in`` is a JSONL path. The
    rewritten file goes to ``jsonl_file_out_dir`` under the same basename,
    via ``gen_scp_from_jsonl``.

    Args:
        jsonl_file_list_in: Text file listing one JSONL path per line.
        jsonl_file_out_dir: Output directory; created if missing.
        ncpu: Worker count forwarded to ``gen_scp_from_jsonl``.
    """
    os.makedirs(jsonl_file_out_dir, exist_ok=True)
    with open(jsonl_file_list_in, "r", encoding="utf-8") as f:
        data_file_lists = f.readlines()
    for i, jsonl in enumerate(data_file_lists):
        jsonl_path = jsonl.strip()
        if not jsonl_path:
            # Blank line: would otherwise turn into open("") downstream.
            continue
        filename_with_extension = os.path.basename(jsonl_path)
        jsonl_file_out = os.path.join(jsonl_file_out_dir, filename_with_extension)
        logging.info(f"{i}/{len(data_file_lists)}, jsonl: {jsonl}, {jsonl_file_out}")
        gen_scp_from_jsonl(jsonl_path, jsonl_file_out, ncpu)
main_hydra(cfg)Main hydra.
cfg — Configuration overrides.def main_hydra(cfg: DictConfig):
    """Recompute audio lengths for a list of JSONL files via the Hydra config.

    Reads ``jsonl_file_list_in`` (a text file listing JSONL paths),
    ``jsonl_file_out_dir`` and ``ncpu`` from the resolved config and passes
    them to ``update_wav_len``. If missing, they fall back to the author's
    local paths and a single worker.

    Args:
        cfg: Hydra/OmegaConf configuration (command-line overrides).
    """
    kwargs = OmegaConf.to_container(cfg, resolve=True)
    logging.info(kwargs)
    jsonl_file_list_in = kwargs.get(
        "jsonl_file_list_in", "/Users/zhifu/funasr1.0/data/list/data_jsonl.list"
    )
    jsonl_file_out_dir = kwargs.get("jsonl_file_out_dir", "/Users/zhifu/funasr1.0/data_tmp")
    ncpu = kwargs.get("ncpu", 1)
    update_wav_len(jsonl_file_list_in, jsonl_file_out_dir, ncpu)
DataloaderMapStyle(frontend, tokenizer, **kwargs)Dataloadermapstyle.
frontend — Audio frontend for feature extraction.tokenizer — Tokenizer instance for text encoding/decoding.**kwargs — Additional keyword arguments.def DataloaderMapStyle(frontend=None, tokenizer=None, **kwargs):
# dataset
"""Dataloadermapstyle.
Args:
frontend: Audio frontend for feature extraction.
tokenizer: Tokenizer instance for text encoding/decoding.
**kwargs: Additional keyword arguments.
"""
logging.info("Build dataloader")
dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
dataset_tr = dataset_class(
kwargs.get("train_data_set_list"),
frontend=frontend,
tokenizer=tokenizer,
is_training=True,
**kwargs.get("dataset_conf"),
)
dataset_val = dataset_class(
kwargs.get("valid_data_set_list"),
frontend=frontend,
tokenizer=tokenizer,
is_training=False,
**kwargs.get("dataset_conf"),
)
# dataloader
batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "BatchSampler")
batch_sampler_val = None
if batch_sampler is not None:
View full source on GitHub →No documentation yet.
class DataloaderMapStyle:
def __init__(self, frontend=None, tokenizer=None, **kwargs):
# dataset
"""Initialize DataloaderMapStyle.
Args:
frontend: Audio frontend for feature extraction.
tokenizer: Tokenizer instance for text encoding/decoding.
**kwargs: Additional keyword arguments.
"""
logging.info("Build dataloader")
dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
dataset_tr = None
# split dataset
self.data_split_num = kwargs["dataset_conf"].get("data_split_num", 1)
if self.data_split_num == 1:
dataset_tr = dataset_class(
kwargs.get("train_data_set_list"),
frontend=frontend,
tokenizer=tokenizer,
is_training=True,
**kwargs.get("dataset_conf"),
)
dataset_val = dataset_class(
kwargs.get("valid_data_set_list"),
frontend=frontend,
tokenizer=tokenizer,
is_training=False,
**kwargs.get("dataset_conf"),
View full source on GitHub →.build_iter(epoch, data_split_i, start_step, **kwargs) L96Build iter.
epoch — TODO.data_split_i — TODO.start_step — TODO.**kwargs — Additional keyword arguments. def build_iter(self, epoch=0, data_split_i=0, start_step=0, **kwargs):
# reload dataset slice
"""Build iter.
Args:
epoch: TODO.
data_split_i: TODO.
start_step: TODO.
**kwargs: Additional keyword arguments.
"""
if self.data_split_num > 1:
del self.dataset_tr
self.dataset_tr = self.dataset_class(
self.kwargs.get("train_data_set_list"),
frontend=self.frontend,
tokenizer=self.tokenizer,
is_training=True,
**self.kwargs.get("dataset_conf"),
data_split_i=data_split_i,
)
# dataloader
batch_sampler = self.kwargs["dataset_conf"].get("batch_sampler", "BatchSampler")
batch_sampler_val = None
if batch_sampler is not None:
batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
batch_sampler = batch_sampler_class(
self.dataset_tr, start_step=start_step, **self.kwargs.get("dataset_conf")
)
View full source →DataloaderMapStyle.build_iter(epoch, data_split_i, start_step, **kwargs)Build iter.
epoch — TODO.data_split_i — TODO.start_step — TODO.**kwargs — Additional keyword arguments. def build_iter(self, epoch=0, data_split_i=0, start_step=0, **kwargs):
# reload dataset slice
"""Build iter.
Args:
epoch: TODO.
data_split_i: TODO.
start_step: TODO.
**kwargs: Additional keyword arguments.
"""
if self.data_split_num > 1:
del self.dataset_tr
self.dataset_tr = self.dataset_class(
self.kwargs.get("train_data_set_list"),
frontend=self.frontend,
tokenizer=self.tokenizer,
is_training=True,
**self.kwargs.get("dataset_conf"),
data_split_i=data_split_i,
)
# dataloader
batch_sampler = self.kwargs["dataset_conf"].get("batch_sampler", "BatchSampler")
batch_sampler_val = None
if batch_sampler is not None:
batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
batch_sampler = batch_sampler_class(
self.dataset_tr, start_step=start_step, **self.kwargs.get("dataset_conf")
)
View full source on GitHub →DataloaderIterable(frontend, tokenizer, **kwargs)Dataloaderiterable.
frontend — Audio frontend for feature extraction.tokenizer — Tokenizer instance for text encoding/decoding.**kwargs — Additional keyword arguments.def DataloaderIterable(frontend=None, tokenizer=None, **kwargs):
    """Build the training and validation datasets for iterable-style loading.

    Looks up the dataset class named by ``kwargs["dataset"]`` (default
    ``"LargeDataset"``) in ``tables.dataset_classes`` and instantiates it
    twice:

    - on ``train_data_set_list`` with ``is_training=True``;
    - on ``valid_data_set_list`` with ``is_training=False``.

    No DataLoader is constructed here; the datasets themselves are returned.

    Args:
        frontend: Audio frontend for feature extraction; forwarded to the dataset.
        tokenizer: Tokenizer for text encoding/decoding; forwarded to the dataset.
        **kwargs: Training config. Reads ``dataset``, ``train_data_set_list``,
            ``valid_data_set_list`` and ``dataset_conf``. ``dataset_conf`` is
            unpacked into the dataset constructor, so it must be a mapping.

    Returns:
        tuple: ``(dataset_tr, dataset_val)``.
    """
    logging.info("Build dataloader")
    # NOTE(review): presumably a dict-like registry; an unregistered name
    # yields None here and the call below then fails with TypeError.
    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "LargeDataset"))
    dataset_tr = dataset_class(
        kwargs.get("train_data_set_list"),
        frontend=frontend,
        tokenizer=tokenizer,
        is_training=True,
        **kwargs.get("dataset_conf"),
    )
    dataset_val = dataset_class(
        kwargs.get("valid_data_set_list"),
        frontend=frontend,
        tokenizer=tokenizer,
        is_training=False,
        **kwargs.get("dataset_conf"),
    )
    return dataset_tr, dataset_val
Label-smoothing loss. :param int size: the number of classes
:param int padding_idx: ignored class id
:param float smoothing: smoothing rate (0.0 means the conventional CE)
:param bool normalize_length: normalize loss by sequence length if True
:param torch.nn.Module criterion: loss function to be smoothed
class LabelSmoothingLoss(nn.Module):
"""Label-smoothing loss.
:param int size: the number of class
:param int padding_idx: ignored class id
:param float smoothing: smoothing rate (0.0 means the conventional CE)
:param bool normalize_length: normalize loss by sequence length if True
:param torch.nn.Module criterion: loss function to be smoothed
"""
def __init__(
self,
size,
padding_idx,
smoothing,
normalize_length=False,
criterion=nn.KLDivLoss(reduction="none"),
):
"""Construct an LabelSmoothingLoss object."""
super(LabelSmoothingLoss, self).__init__()
self.criterion = criterion
self.padding_idx = padding_idx
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
self.size = size
self.true_dist = None
self.normalize_length = normalize_length
def forward(self, x, target):
"""Compute loss between x and target.
View full source on GitHub →.forward(x, target) L42Compute loss between x and target.
:param torch.Tensor x: prediction (batch, seqlen, class)
:param torch.Tensor target:
target signal masked with self.padding_idx (batch, seqlen)
:return: scalar float value
:rtype torch.Tensor
def forward(self, x, target):
"""Compute loss between x and target.
:param torch.Tensor x: prediction (batch, seqlen, class)
:param torch.Tensor target:
target signal masked with self.padding_id (batch, seqlen)
:return: scalar float value
:rtype torch.Tensor
"""
assert x.size(2) == self.size
batch_size = x.size(0)
x = x.contiguous().view(-1, self.size)
target = target.contiguous().view(-1)
with torch.no_grad():
true_dist = x.clone()
true_dist.fill_(self.smoothing / (self.size - 1))
ignore = target == self.padding_idx # (B,)
total = len(target) - ignore.sum().item()
target = target.masked_fill(ignore, 0) # avoid -1 index
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
denom = total if self.normalize_length else batch_size
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
LabelSmoothingLoss.forward(x, target)Compute loss between x and target.
:param torch.Tensor x: prediction (batch, seqlen, class)
:param torch.Tensor target:
target signal masked with self.padding_idx (batch, seqlen)
:return: scalar float value
:rtype torch.Tensor
def forward(self, x, target):
"""Compute loss between x and target.
:param torch.Tensor x: prediction (batch, seqlen, class)
:param torch.Tensor target:
target signal masked with self.padding_id (batch, seqlen)
:return: scalar float value
:rtype torch.Tensor
"""
assert x.size(2) == self.size
batch_size = x.size(0)
x = x.contiguous().view(-1, self.size)
target = target.contiguous().view(-1)
with torch.no_grad():
true_dist = x.clone()
true_dist.fill_(self.smoothing / (self.size - 1))
ignore = target == self.padding_idx # (B,)
total = len(target) - ignore.sum().item()
target = target.masked_fill(ignore, 0) # avoid -1 index
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
denom = total if self.normalize_length else batch_size
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
No documentation yet.
class SequenceBinaryCrossEntropy(nn.Module):
    """Binary cross-entropy over padded sequences, ignoring padded frames."""

    def __init__(self, normalize_length=False, criterion=None):
        """Initialize SequenceBinaryCrossEntropy.

        Args:
            normalize_length: If True, divide the summed loss by the number of
                valid (non-padded) frames; otherwise divide by the batch size.
            criterion: Element-wise loss module. Defaults to a fresh
                ``nn.BCEWithLogitsLoss(reduction="none")`` for each instance.
                A module default argument would be one object shared by
                every instance.
        """
        super().__init__()
        if criterion is None:
            criterion = nn.BCEWithLogitsLoss(reduction="none")
        self.normalize_length = normalize_length
        self.criterion = criterion

    def forward(self, pred, label, lengths):
        """Compute the masked sequence loss.

        Args:
            pred: Logits. The mask is broadcast with ``unsqueeze(-1)``, so
                ``pred`` is expected to be 3-D ``(batch, time, dim)``.
            label: Targets; for the default BCE criterion, same shape as ``pred``.
            lengths: Valid length of each sequence. Frames beyond it are
                masked via ``make_pad_mask`` (presumably True at padded
                positions -- not visible here).

        Returns:
            Scalar tensor. It is the masked loss sum divided by the number of
            valid frames (``normalize_length=True``, at least 1 so an
            all-padding batch gives 0 rather than NaN), otherwise by the
            batch size.
        """
        pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
        loss = self.criterion(pred, label)
        if self.normalize_length:
            denom = (~pad_mask).sum().clamp(min=1)
        else:
            denom = pred.shape[0]
        return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom
.forward(pred, label, lengths) L79Forward pass for training.
pred — TODO.label — TODO.lengths — TODO. def forward(self, pred, label, lengths):
"""Forward pass for training.
Args:
pred: TODO.
label: TODO.
lengths: TODO.
"""
pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
loss = self.criterion(pred, label)
denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom
SequenceBinaryCrossEntropy.forward(pred, label, lengths)Forward pass for training.
pred — TODO.label — TODO.lengths — TODO. def forward(self, pred, label, lengths):
"""Forward pass for training.
Args:
pred: TODO.
label: TODO.
lengths: TODO.
"""
pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
loss = self.criterion(pred, label)
denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom
Nll loss.
:param int size: the number of classes
:param int padding_idx: ignored class id
:param bool normalize_length: normalize loss by sequence length if True
:param torch.nn.Module criterion: loss function
class NllLoss(nn.Module):
"""Nll loss.
:param int size: the number of class
:param int padding_idx: ignored class id
:param bool normalize_length: normalize loss by sequence length if True
:param torch.nn.Module criterion: loss function
"""
def __init__(
self,
size,
padding_idx,
normalize_length=False,
criterion=nn.NLLLoss(reduction="none"),
):
"""Construct an NllLoss object."""
super(NllLoss, self).__init__()
self.criterion = criterion
self.padding_idx = padding_idx
self.size = size
self.true_dist = None
self.normalize_length = normalize_length
def forward(self, x, target):
"""Compute loss between x and target.
:param torch.Tensor x: prediction (batch, seqlen, class)
:param torch.Tensor target:
target signal masked with self.padding_id (batch, seqlen)
View full source on GitHub →.forward(x, target) L117Compute loss between x and target.
:param torch.Tensor x: prediction (batch, seqlen, class)
:param torch.Tensor target:
target signal masked with self.padding_idx (batch, seqlen)
:return: scalar float value
:rtype torch.Tensor
def forward(self, x, target):
"""Compute loss between x and target.
:param torch.Tensor x: prediction (batch, seqlen, class)
:param torch.Tensor target:
target signal masked with self.padding_id (batch, seqlen)
:return: scalar float value
:rtype torch.Tensor
"""
assert x.size(2) == self.size
batch_size = x.size(0)
x = x.view(-1, self.size)
target = target.view(-1)
with torch.no_grad():
ignore = target == self.padding_idx # (B,)
total = len(target) - ignore.sum().item()
target = target.masked_fill(ignore, 0) # avoid -1 index
kl = self.criterion(x, target)
denom = total if self.normalize_length else batch_size
return kl.masked_fill(ignore, 0).sum() / denom
NllLoss.forward(x, target)Compute loss between x and target.
:param torch.Tensor x: prediction (batch, seqlen, class)
:param torch.Tensor target:
target signal masked with self.padding_idx (batch, seqlen)
:return: scalar float value
:rtype torch.Tensor
def forward(self, x, target):
"""Compute loss between x and target.
:param torch.Tensor x: prediction (batch, seqlen, class)
:param torch.Tensor target:
target signal masked with self.padding_id (batch, seqlen)
:return: scalar float value
:rtype torch.Tensor
"""
assert x.size(2) == self.size
batch_size = x.size(0)
x = x.view(-1, self.size)
target = target.view(-1)
with torch.no_grad():
ignore = target == self.padding_idx # (B,)
total = len(target) - ignore.sum().item()
target = target.masked_fill(ignore, 0) # avoid -1 index
kl = self.criterion(x, target)
denom = total if self.normalize_length else batch_size
return kl.masked_fill(ignore, 0).sum() / denom
No documentation yet.
class CustomLambdaLR(_LRScheduler):
def __init__(
self,
optimizer,
warmup_steps: int = 25000,
total_steps: int = 500000,
last_epoch=-1,
verbose=False,
):
"""Initialize CustomLambdaLR.
Args:
optimizer: TODO.
warmup_steps: TODO.
total_steps: TODO.
last_epoch: TODO.
verbose: TODO.
"""
self.warmup_steps = warmup_steps
self.total_steps = total_steps
super().__init__(optimizer, last_epoch, verbose)
def get_lr(self):
"""Get lr."""
step = self.last_epoch + 1
if step < self.warmup_steps:
lr_scale = step / self.warmup_steps
else:
lr_scale = max(
View full source on GitHub →.get_lr() L41Get lr.
def get_lr(self):
"""Get lr."""
step = self.last_epoch + 1
if step < self.warmup_steps:
lr_scale = step / self.warmup_steps
else:
lr_scale = max(
0.0, 1 - (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
)
return [base_lr * lr_scale for base_lr in self.base_lrs]
CustomLambdaLR.get_lr()Get lr.
def get_lr(self):
"""Get lr."""
step = self.last_epoch + 1
if step < self.warmup_steps:
lr_scale = step / self.warmup_steps
else:
lr_scale = max(
0.0, 1 - (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
)
return [base_lr * lr_scale for base_lr in self.base_lrs]
The LR scheduler proposed by Noam
Ref:
"Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
FIXME(kamo) — PyTorch doesn't provide _LRScheduler as a public class, thus the behaviour isn't guaranteed in future PyTorch versions.
NOTE(kamo) — The "model_size" in the original implementation is derived from the model, but in this implementation, this parameter is a constant value.
You need to change it if the model is changed.
class NoamLR(_LRScheduler, AbsBatchStepScheduler):
"""The LR scheduler proposed by Noam
Ref:
"Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
FIXME(kamo): PyTorch doesn't provide _LRScheduler as public class,
thus the behaviour isn't guaranteed at forward PyTorch version.
NOTE(kamo): The "model_size" in original implementation is derived from
the model, but in this implementation, this parameter is a constant value.
You need to change it if the model is changed.
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
model_size: Union[int, float] = 320,
warmup_steps: Union[int, float] = 25000,
last_epoch: int = -1,
):
"""Initialize NoamLR.
Args:
optimizer: TODO.
model_size: Size/dimension parameter.
warmup_steps: TODO.
last_epoch: TODO.
"""
View full source on GitHub →.lr_for_WarmupLR(lr) L56Lr for warmuplr.
lr — TODO. def lr_for_WarmupLR(self, lr: float) -> float:
        """Convert a NoamLR base learning rate into the equivalent WarmupLR one.

        Returns ``lr / sqrt(model_size) / sqrt(warmup_steps)``. This is the
        peak that NoamLR's ``get_lr`` formula reaches at
        ``step == warmup_steps``. WarmupLR's formula is the same apart from
        the factor ``warmup_steps**0.5`` in place of ``model_size**-0.5``.
        So giving WarmupLR this value, with the same ``warmup_steps``,
        reproduces this NoamLR schedule.

        Args:
            lr: Base learning rate configured for NoamLR.

        Returns:
            float: Equivalent WarmupLR learning rate.
        """
        return lr / self.model_size**0.5 / self.warmup_steps**0.5
.get_lr() L71Get lr.
def get_lr(self):
"""Get lr."""
step_num = self.last_epoch + 1
return [
lr * self.model_size**-0.5 * min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
for lr in self.base_lrs
]
NoamLR.lr_for_WarmupLR(lr)Lr for warmuplr.
lr — TODO. def lr_for_WarmupLR(self, lr: float) -> float:
        """Convert a NoamLR base learning rate into the equivalent WarmupLR one.

        Returns ``lr / sqrt(model_size) / sqrt(warmup_steps)``. This is the
        peak that NoamLR's ``get_lr`` formula reaches at
        ``step == warmup_steps``. WarmupLR's formula is the same apart from
        the factor ``warmup_steps**0.5`` in place of ``model_size**-0.5``.
        So giving WarmupLR this value, with the same ``warmup_steps``,
        reproduces this NoamLR schedule.

        Args:
            lr: Base learning rate configured for NoamLR.

        Returns:
            float: Equivalent WarmupLR learning rate.
        """
        return lr / self.model_size**0.5 / self.warmup_steps**0.5
NoamLR.get_lr()Get lr.
def get_lr(self):
"""Get lr."""
step_num = self.last_epoch + 1
return [
lr * self.model_size**-0.5 * min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
for lr in self.base_lrs
]
No documentation yet.
class TriStageLR(_LRScheduler, AbsBatchStepScheduler):
def __init__(
self,
optimizer: torch.optim.Optimizer,
last_epoch: int = -1,
phase_ratio: Optional[List[float]] = None,
init_lr_scale: float = 0.01,
final_lr_scale: float = 0.01,
):
"""Initialize TriStageLR.
Args:
optimizer: TODO.
last_epoch: TODO.
phase_ratio: TODO.
init_lr_scale: TODO.
final_lr_scale: TODO.
"""
self.optimizer = optimizer
self.last_epoch = last_epoch
self.phase_ratio = phase_ratio
self.init_lr_scale = init_lr_scale
self.final_lr_scale = final_lr_scale
self.optimizer_lr = self.optimizer.defaults["lr"]
def init_tri_stage_scheudler(self, max_update):
"""Init tri stage scheudler.
Args:
max_update: TODO.
View full source on GitHub →.init_tri_stage_scheudler(max_update) L40Init tri stage scheudler.
max_update — TODO. def init_tri_stage_scheudler(self, max_update):
"""Init tri stage scheudler.
Args:
max_update: TODO.
"""
self.max_update = max_update
self.peak_lr = self.optimizer_lr
self.init_lr = self.init_lr_scale * self.optimizer_lr
self.final_lr = self.final_lr_scale * self.optimizer_lr
assert self.max_update > 0
assert sum(self.phase_ratio) == 1, "phase ratios must add up to 1"
assert len(self.phase_ratio) == 3
self.warmup_steps = int(self.max_update * self.phase_ratio[0])
self.hold_steps = int(self.max_update * self.phase_ratio[1])
self.decay_steps = int(self.max_update * self.phase_ratio[2])
self.warmup_rate = (
(self.peak_lr - self.init_lr) / self.warmup_steps if self.warmup_steps != 0 else 0
)
self.decay_factor = -math.log(self.final_lr_scale) / self.decay_steps
# initial learning rate
self.lr = self.init_lr
# __init__() must be invoked before setting field
# because step() is also invoked in __init__()
self.set_optimizer_lr(self.lr)
super().__init__(self.optimizer, self.last_epoch)
.step_update(num_updates) L96Update the learning rate after each update.
def step_update(self, num_updates):
"""Update the learning rate after each update."""
stage, steps_in_stage = self._decide_stage(num_updates)
if stage == 0:
self.lr = self.init_lr + self.warmup_rate * steps_in_stage
elif stage == 1:
self.lr = self.peak_lr
elif stage == 2:
self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage)
elif stage == 3:
self.lr = self.final_lr
else:
raise ValueError("Undefined stage")
self.set_optimizer_lr(self.lr)
.set_optimizer_lr(lr) L111Set optimizer lr.
lr — TODO. def set_optimizer_lr(self, lr):
        """Write ``lr`` into every parameter group of ``self.optimizer``.

        Args:
            lr: Learning rate applied to all param groups. It is one value,
                so groups that started with different lrs end up equal.
        """
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr
.get_lr() L120Get lr.
def get_lr(self):
"""Get lr."""
step_num = self.last_epoch + 1
self.step_update(step_num)
return [self.lr]
TriStageLR.init_tri_stage_scheudler(max_update)Init tri stage scheudler.
max_update — TODO. def init_tri_stage_scheudler(self, max_update):
"""Init tri stage scheudler.
Args:
max_update: TODO.
"""
self.max_update = max_update
self.peak_lr = self.optimizer_lr
self.init_lr = self.init_lr_scale * self.optimizer_lr
self.final_lr = self.final_lr_scale * self.optimizer_lr
assert self.max_update > 0
assert sum(self.phase_ratio) == 1, "phase ratios must add up to 1"
assert len(self.phase_ratio) == 3
self.warmup_steps = int(self.max_update * self.phase_ratio[0])
self.hold_steps = int(self.max_update * self.phase_ratio[1])
self.decay_steps = int(self.max_update * self.phase_ratio[2])
self.warmup_rate = (
(self.peak_lr - self.init_lr) / self.warmup_steps if self.warmup_steps != 0 else 0
)
self.decay_factor = -math.log(self.final_lr_scale) / self.decay_steps
# initial learning rate
self.lr = self.init_lr
# __init__() must be invoked before setting field
# because step() is also invoked in __init__()
self.set_optimizer_lr(self.lr)
super().__init__(self.optimizer, self.last_epoch)
TriStageLR.step_update(num_updates)Update the learning rate after each update.
def step_update(self, num_updates):
    """Recompute ``self.lr`` for update ``num_updates`` and apply it to the optimizer.

    ``self._decide_stage`` maps the update count to ``(stage, offset)``:

    * stage 0 -- linear ramp: ``init_lr + warmup_rate * offset``
    * stage 1 -- constant ``peak_lr``
    * stage 2 -- exponential decay: ``peak_lr * exp(-decay_factor * offset)``
    * stage 3 -- constant ``final_lr``

    Args:
        num_updates: Global update count handed to ``self._decide_stage``.

    Raises:
        ValueError: If ``self._decide_stage`` reports a stage other than 0-3;
            ``self.lr`` is left unchanged in that case.
    """
    stage, offset = self._decide_stage(num_updates)
    # Lambdas keep evaluation lazy: only the selected stage's formula runs.
    schedule = {
        0: lambda: self.init_lr + self.warmup_rate * offset,
        1: lambda: self.peak_lr,
        2: lambda: self.peak_lr * math.exp(-self.decay_factor * offset),
        3: lambda: self.final_lr,
    }
    rule = schedule.get(stage)
    if rule is None:
        raise ValueError("Undefined stage")
    self.lr = rule()
    self.set_optimizer_lr(self.lr)
TriStageLR.set_optimizer_lr(lr) · Set the learning rate of every optimizer parameter group.
def set_optimizer_lr(self, lr):
    """Write ``lr`` into every parameter group of ``self.optimizer``.

    Args:
        lr: Learning rate to assign; the same value is used for all groups.
            Other per-group settings are left untouched.
    """
    groups = self.optimizer.param_groups
    for group in groups:
        group["lr"] = lr
TriStageLR.get_lr() · Get the learning rate for the next step.
def get_lr(self):
    """Return the learning rate for the step following ``self.last_epoch``.

    Delegates to ``step_update`` (which also writes the value into the
    optimizer's parameter groups) and wraps the result in a list.

    Returns:
        A single-element list containing the updated ``self.lr``.
    """
    self.step_update(self.last_epoch + 1)
    return [self.lr]
The WarmupLR scheduler
This scheduler is almost the same as the NoamLR scheduler, except for the following difference:
NoamLR:
lr = optimizer.lr * model_size ** -0.5
* min(step ** -0.5, step * warmup_step ** -1.5)
WarmupLR:
lr = optimizer.lr * warmup_step ** 0.5
* min(step ** -0.5, step * warmup_step ** -1.5)
Note that the maximum lr equals optimizer.lr in this scheduler.
class WarmupLR(_LRScheduler, AbsBatchStepScheduler):
    """The WarmupLR scheduler.

    This scheduler is almost the same as the NoamLR scheduler, except for
    the following difference:

    NoamLR:
        lr = optimizer.lr * model_size ** -0.5
             * min(step ** -0.5, step * warmup_step ** -1.5)
    WarmupLR:
        lr = optimizer.lr * warmup_step ** 0.5
             * min(step ** -0.5, step * warmup_step ** -1.5)

    Note that the maximum lr equals optimizer.lr in this scheduler; it is
    reached when step == warmup_step.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: Union[int, float] = 25000,
        last_epoch: int = -1,
    ):
        """Initialize WarmupLR.

        Args:
            optimizer: Optimizer whose learning rate is scheduled.
            warmup_steps: Length of the warm-up phase in steps (default
                25000). Per the formula above, the learning rate rises until
                this step, where it equals optimizer.lr, and decays as
                step ** -0.5 afterwards.
            last_epoch: Index of the last step already taken.
                NOTE(review): presumably forwarded to ``_LRScheduler``
                (PyTorch convention: -1 starts from scratch) in the part of
                this constructor not shown in this excerpt -- confirm.
        """
        self.warmup_steps = warmup_steps
        # NOTE(review): this excerpt ends here; the remainder of __init__
        # (e.g. the base-class constructor call) is not shown.
View full source on GitHub → · .get_lr() · L50 · Get the learning rate for the next step.
def get_lr(self):
    """Compute the WarmupLR learning rate for the upcoming step.

    With ``s = last_epoch + 1`` and ``w = warmup_steps``, each base lr is
    scaled by ``w ** 0.5 * min(s ** -0.5, s * w ** -1.5)``. That is a linear
    warm-up reaching exactly the base lr at ``s == w``, followed by
    inverse-square-root decay.

    ``w == 0`` would evaluate ``0 ** -1.5`` and raise ZeroDivisionError.
    That case skips warm-up and uses plain inverse-square-root decay,
    ``lr * s ** -0.5``, which starts at the base lr on step 1.

    Returns:
        One learning rate per entry of ``self.base_lrs``.
    """
    step_num = self.last_epoch + 1
    if self.warmup_steps == 0:
        return [lr * step_num**-0.5 for lr in self.base_lrs]
    return [
        lr * self.warmup_steps**0.5 * min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
        for lr in self.base_lrs
    ]
WarmupLR.get_lr() · Get the learning rate for the next step.
def get_lr(self):
    """Compute the WarmupLR learning rate for the upcoming step.

    With ``s = last_epoch + 1`` and ``w = warmup_steps``, each base lr is
    scaled by ``w ** 0.5 * min(s ** -0.5, s * w ** -1.5)``. That is a linear
    warm-up reaching exactly the base lr at ``s == w``, followed by
    inverse-square-root decay.

    ``w == 0`` would evaluate ``0 ** -1.5`` and raise ZeroDivisionError.
    That case skips warm-up and uses plain inverse-square-root decay,
    ``lr * s ** -0.5``, which starts at the base lr on step 1.

    Returns:
        One learning rate per entry of ``self.base_lrs``.
    """
    step_num = self.last_epoch + 1
    if self.warmup_steps == 0:
        return [lr * step_num**-0.5 for lr in self.base_lrs]
    return [
        lr * self.warmup_steps**0.5 * min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
        for lr in self.base_lrs
    ]