
Speech Recognition

Note: The ModelScope pipeline supports inference and finetuning for all models in the model zoo. Here we take typical models as examples to demonstrate the usage.

Inference

Quick start

Paraformer Model

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
)

rec_result = inference_pipeline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
print(rec_result)

Paraformer-long Model

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
    vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
    #punc_model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
    punc_model='damo/punc_ct-transformer_cn-en-common-vocab471067-large',
)

rec_result = inference_pipeline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav', 
                                batch_size_token=5000, batch_size_token_threshold_s=40, max_single_segment_time=6000)
print(rec_result)

Where,

  • batch_size_token: enables dynamic batching; the total number of tokens in a batch is limited to batch_size_token, where 1 token = 60 ms.

  • batch_size_token_threshold_s: when the duration of an audio segment exceeds this threshold (in seconds), the batch size is set to 1.

  • max_single_segment_time: the maximum length of an audio segment produced by VAD, in ms.
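Since 1 token corresponds to 60 ms, the settings in the example above can be translated into audio durations. The sketch below does this arithmetic with two illustrative helpers; tokens_to_seconds and ms_to_tokens are hypothetical names, not FunASR functions.

```python
TOKEN_MS = 60  # 1 token = 60 ms, as stated above


def tokens_to_seconds(tokens: int) -> float:
    """Audio duration (in seconds) covered by a given number of tokens."""
    return tokens * TOKEN_MS / 1000


def ms_to_tokens(duration_ms: float) -> float:
    """Number of 60 ms tokens in a duration given in milliseconds."""
    return duration_ms / TOKEN_MS


# Values used in the Paraformer-long example above
print(tokens_to_seconds(5000))  # 300.0 -> one batch holds at most about 300 s of audio
print(ms_to_tokens(6000))       # 100.0 -> one VAD segment is at most 100 tokens long
print(ms_to_tokens(40 * 1000))  # the 40 s threshold corresponds to about 667 tokens
```

So with batch_size_token=5000 a batch covers at most about 300 s of audio, and with max_single_segment_time=6000 a single VAD segment is at most 6 s (100 tokens) long.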

Suggestion: GPU memory usage grows with the square of the audio duration, so long audio inputs may cause OOM (out-of-memory) errors. There are three possible scenarios:

  • a) In the initial stage of inference, GPU memory usage depends mainly on batch_size_token. Reducing this value appropriately lowers memory usage.

  • b) In the middle of inference, VAD may produce long audio segments that still cause OOM even though their total number of tokens is smaller than batch_size_token. In this case, reduce batch_size_token_threshold_s, so that segments exceeding the threshold are decoded with the batch size forced to 1.

  • c) Toward the end of inference, VAD may produce long audio segments whose total number of tokens is smaller than batch_size_token but which exceed batch_size_token_threshold_s; even with the batch size forced to 1, they may still cause OOM. In this case, reduce max_single_segment_time to shorten the audio segments generated by VAD.
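The interplay of the three parameters can be illustrated with a simplified batching routine. This is a hypothetical sketch written for this explanation (make_batches is not a FunASR API), assuming VAD segments are given as (start_ms, end_ms) pairs, sorted by duration, packed until the token budget is reached, and that segments longer than batch_size_token_threshold_s are decoded alone; FunASR's actual implementation may differ in detail.

```python
from typing import List, Tuple

TOKEN_MS = 60  # 1 token = 60 ms
Segment = Tuple[int, int]  # (start_ms, end_ms) of one VAD segment


def make_batches(segments_ms: List[Segment],
                 batch_size_token: int = 5000,
                 batch_size_token_threshold_s: float = 40) -> List[List[Segment]]:
    """Hypothetical illustration of token-budget batching with a long-segment threshold."""
    budget_ms = batch_size_token * TOKEN_MS
    threshold_ms = batch_size_token_threshold_s * 1000
    # Shortest segments first, so the longest ones end up in the last batches.
    ordered = sorted(segments_ms, key=lambda seg: seg[1] - seg[0])
    batches: List[List[Segment]] = []
    current: List[Segment] = []
    current_ms = 0
    for seg in ordered:
        dur = seg[1] - seg[0]
        # Close the current batch if this segment would exceed the token budget,
        # or if it is longer than the threshold (long segments run with batch size 1).
        if current and (current_ms + dur > budget_ms or dur > threshold_ms):
            batches.append(current)
            current, current_ms = [], 0
        current.append(seg)
        current_ms += dur
    if current:
        batches.append(current)
    return batches
```

In this sketch, lowering batch_size_token shrinks every batch (scenario a), lowering batch_size_token_threshold_s sends more segments to batch size 1 (scenario b), and only lowering max_single_segment_time, which acts on VAD before batching, shortens the segments themselves (scenario c). Sorting by duration also places the longest segments, and hence the risk of scenario c, at the end.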

Paraformer-online Model

Streaming Decoding

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
import soundfile

inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online',
    model_revision='v1.0.7',
    update_model=False,
    mode='paraformer_streaming'
    )
speech, sample_rate = soundfile.read("example/asr_example.wav")

chunk_size = [0, 10, 5] # [0, 10, 5]: 600 ms, [0, 8, 4]: 480 ms
encoder_chunk_look_back = 4 # number of chunks to look back for encoder self-attention
decoder_chunk_look_back = 1 # number of encoder chunks to look back for decoder cross-attention
param_dict = {"cache": dict(), "is_final": False, "chunk_size": chunk_size,
              "encoder_chunk_look_back": encoder_chunk_look_back, "decoder_chunk_look_back": decoder_chunk_look_back}
chunk_stride = chunk_size[1] * 960 # 960 samples = 60 ms at 16 kHz, so 600 ms (or 480 ms) per chunk
# first chunk, 600 ms
speech_chunk = speech[0:chunk_stride]
rec_result = inference_pipeline(audio_in=speech_chunk, param_dict=param_dict)
print(rec_result)
# next chunk, 600 ms (reuse the same param_dict so the cache carries over;
# set param_dict["is_final"] = True before sending the last chunk)
speech_chunk = speech[chunk_stride:chunk_stride+chunk_stride]
rec_result = inference_pipeline(audio_in=speech_chunk, param_dict=param_dict)
print(rec_result)

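To stream a whole file rather than two chunks, the audio can be walked in chunk_stride steps, flagging the last chunk as final. The sketch below assumes 16 kHz audio (so 960 samples = 60 ms, matching chunk_stride = chunk_size[1] * 960 above); iter_chunks is a hypothetical helper, not part of FunASR.

```python
from typing import Iterator, Sequence, Tuple

SAMPLES_PER_FRAME = 960  # 60 ms at 16 kHz


def iter_chunks(num_samples: int,
                chunk_size: Sequence[int] = (0, 10, 5)) -> Iterator[Tuple[int, int, bool]]:
    """Yield (start, end, is_final) sample ranges covering the whole signal."""
    stride = chunk_size[1] * SAMPLES_PER_FRAME
    for start in range(0, num_samples, stride):
        end = min(start + stride, num_samples)  # the last chunk may be shorter
        yield start, end, end >= num_samples
```

A streaming loop would then set param_dict["is_final"] = is_final and call inference_pipeline(audio_in=speech[start:end], param_dict=param_dict) for each range, reusing the same param_dict between calls.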
Fake Streaming Decoding
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online',
    model_revision='v1.0.7',
    update_model=False,
    mode="paraformer_fake_streaming"
)
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
rec_result = inference_pipeline(audio_in=audio_in)
print(rec_result)

For the full demo code, please refer to demo

Paraformer-contextual Model

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

param_dict = dict()
# param_dict['hotword'] = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/hotword.txt"
param_dict['hotword']="邓郁松 王颖春 王晔君"
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
    param_dict=param_dict)

rec_result = inference_pipeline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_hotword.wav')
print(rec_result)
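In the example above, the hotword string lists several hotwords separated by spaces. When the hotwords come from a list, a small helper can build that string. build_hotword_string is hypothetical; it assumes the space-separated form shown above and therefore rejects entries that themselves contain whitespace.

```python
from typing import Iterable


def build_hotword_string(hotwords: Iterable[str]) -> str:
    """Join hotwords into one space-separated string for param_dict['hotword'].

    Strips surrounding whitespace, drops empty entries and duplicates (keeping the
    first occurrence), and raises ValueError for entries containing whitespace.
    """
    words = []
    for word in hotwords:
        word = word.strip()
        if not word or word in words:
            continue
        if any(ch.isspace() for ch in word):
            raise ValueError(f"hotword contains whitespace: {word!r}")
        words.append(word)
    return " ".join(words)
```

For example, param_dict['hotword'] = build_hotword_string(["邓郁松", "王颖春", "王晔君"]) reproduces the string used above.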

UniASR Model

There are three decoding modes for the UniASR model (fast, normal and offline). For more model details, please refer to docs

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

decoding_model = "fast" # "fast", "normal" or "offline"
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/speech_UniASR_asr_2pass-minnan-16k-common-vocab3825',
    param_dict={"decoding_model": decoding_model})

rec_result = inference_pipeline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
print(rec_result)

The fast and normal decoding modes are fake streaming and can be used to evaluate recognition accuracy. For the full demo code, please refer to demo

Paraformer-Spk

This model returns recognition results that include the speaker of each sentence. Refer to CAM++ for detailed information about the speaker diarization model.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

if __name__ == '__main__':
    audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_speaker_demo.wav'
    output_dir = "./results"
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model='damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn',
        model_revision='v0.0.2',
        vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
        punc_model='damo/punc_ct-transformer_cn-en-common-vocab471067-large',
        output_dir=output_dir,
    )
    rec_result = inference_pipeline(audio_in=audio_in, batch_size_token=5000, batch_size_token_threshold_s=40, max_single_segment_time=6000)
    print(rec_result)

MFCCA Model

For more model details, please refer to docs

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950',
    model_revision='v3.0.0'
)

rec_result = inference_pipeline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
print(rec_result)

API-reference

Define pipeline

  • task: Tasks.auto_speech_recognition

  • model: model name in model zoo, or model path in local disk

  • ngpu: 1 (Default), decode on GPU; set ngpu=0 to decode on CPU

  • ncpu: 1 (Default), the number of threads used for intra-op parallelism on CPU

  • output_dir: None (Default), the output path of results if set

  • batch_size: 1 (Default), batch size when decoding

Infer pipeline

  • audio_in: the input to decode, which could be:

    • wav_path, e.g.: asr_example.wav,

    • pcm_path, e.g.: asr_example.pcm,

    • audio bytes stream, e.g.: bytes data from a microphone

    • audio samples, e.g.: audio, rate = soundfile.read("asr_example_zh.wav"); the dtype is numpy.ndarray or torch.Tensor

    • wav.scp, kaldi style wav list (wav_id \t wav_path), e.g.:

    asr_example1  ./audios/asr_example1.wav
    asr_example2  ./audios/asr_example2.wav

    When the input is wav.scp, output_dir must be set to save the output results

  • audio_fs: the audio sampling rate; only set it when audio_in is PCM audio

  • output_dir: None (Default), the output path of results if set
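For large test sets it is convenient to generate wav.scp rather than write it by hand. A minimal sketch, assuming every *.wav file in one directory should be decoded and that the file name without extension is a unique wav_id; format_wav_scp and write_wav_scp are hypothetical helpers.

```python
from pathlib import Path
from typing import Iterable


def format_wav_scp(wav_paths: Iterable[str]) -> str:
    """Build Kaldi-style wav.scp text: one "wav_id<TAB>wav_path" line per file.

    The file stem (name without extension) is used as wav_id; lines are sorted by path.
    """
    return "".join(f"{Path(p).stem}\t{p}\n" for p in sorted(wav_paths))


def write_wav_scp(wav_dir: str, scp_path: str) -> int:
    """Write wav.scp (absolute paths) for every *.wav in wav_dir; return the entry count."""
    wav_paths = [str(p.resolve()) for p in Path(wav_dir).glob("*.wav")]
    Path(scp_path).write_text(format_wav_scp(wav_paths), encoding="utf-8")
    return len(wav_paths)
```

The resulting wav.scp path can then be passed as audio_in, together with output_dir.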

Inference with multi-thread CPUs or multiple GPUs

FunASR also offers the recipe egs_modelscope/asr/TEMPLATE/infer.sh for decoding with multi-thread CPUs or multiple GPUs.

Settings of infer.sh

  • model: model name in model zoo, or model path in local disk

  • data_dir: the dataset dir must include wav.scp. If ${data_dir}/text also exists, the CER will be computed

  • output_dir: output dir of the recognition results

  • batch_size: 64 (Default), batch size of inference on GPU

  • gpu_inference: true (Default), whether to perform GPU decoding; set false for CPU inference

  • gpuid_list: 0,1 (Default), the GPU IDs used for inference

  • njob: only used for CPU inference (gpu_inference=false), 64 (Default), the number of jobs for CPU decoding

  • checkpoint_dir: only used when inferring with finetuned models, the directory of the finetuned models

  • checkpoint_name: only used when inferring with finetuned models, valid.cer_ctc.ave.pb (Default), the checkpoint used for inference

  • decoding_mode: normal (Default), decoding mode for the UniASR model (fast, normal, offline)

  • hotword_txt: None (Default), hotword file for the contextual Paraformer model (the file name ends with .txt)

Decode with multiple GPUs:

    bash infer.sh \
    --model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
    --data_dir "./data/test" \
    --output_dir "./results" \
    --batch_size 64 \
    --gpu_inference true \
    --gpuid_list "0,1"

Decode with multi-thread CPUs:

    bash infer.sh \
    --model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
    --data_dir "./data/test" \
    --output_dir "./results" \
    --gpu_inference false \
    --njob 64
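When these runs are scripted (for example, sweeping batch_size), the command line can be assembled programmatically. build_infer_cmd is a hypothetical helper that only maps a dict onto the --option value pairs shown above and converts Python booleans to the true/false strings used by gpu_inference; it does not check option names against infer.sh.

```python
import shlex
from typing import Dict, List, Union


def build_infer_cmd(options: Dict[str, Union[str, int, bool]],
                    script: str = "infer.sh") -> List[str]:
    """Turn {"model": ..., "gpu_inference": False, ...} into a bash argument list."""
    cmd = ["bash", script]
    for key, value in options.items():
        if isinstance(value, bool):
            value = "true" if value else "false"  # lowercase, as in the examples above
        cmd += [f"--{key}", str(value)]
    return cmd


if __name__ == "__main__":
    cmd = build_infer_cmd({
        "model": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
        "data_dir": "./data/test",
        "output_dir": "./results",
        "gpu_inference": False,
        "njob": 64,
    })
    print(shlex.join(cmd))  # shell-quoted command, e.g. for logging
```

The list can be passed directly to subprocess.run(cmd, check=True) from the directory containing infer.sh.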

Results

The decoding results can be found in $output_dir/1best_recog/text.cer, which includes recognition results of each sample and the CER metric of the whole test set.

If you decode the SpeechIO test sets, you can apply text normalization with textnorm using stage=3; DETAILS.txt and RESULTS.txt then record the results and the CER after text normalization.
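For reference, CER is the character-level edit distance (substitutions + deletions + insertions) divided by the number of reference characters, accumulated over the whole test set. The sketch below computes it under the assumption that whitespace is ignored and every other character is one token; FunASR's scoring script may tokenize differently (for example, treating English words as units), so its numbers can differ from this sketch.

```python
from typing import Iterable, List, Sequence, Tuple


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance between two token sequences (single-row dynamic programming)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(prev[j] + 1,            # deletion of a reference character
                         cur[j - 1] + 1,         # insertion of a hypothesis character
                         prev[j - 1] + (r != h))  # substitution, or match at no cost
        prev = cur
    return prev[-1]


def chars(text: str) -> List[str]:
    """Split text into characters, ignoring whitespace."""
    return [ch for ch in text if not ch.isspace()]


def corpus_cer(pairs: Iterable[Tuple[str, str]]) -> float:
    """CER over (reference, hypothesis) pairs: total edits / total reference characters."""
    errors = total = 0
    for ref, hyp in pairs:
        ref_chars, hyp_chars = chars(ref), chars(hyp)
        errors += edit_distance(ref_chars, hyp_chars)
        total += len(ref_chars)
    return errors / total if total else 0.0
```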

Finetune with pipeline

Quick start

finetune.py

import os

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer

from funasr.datasets.ms_dataset import MsDataset
from funasr.utils.modelscope_param import modelscope_args


def modelscope_finetune(params):
    if not os.path.exists(params.output_dir):
        os.makedirs(params.output_dir, exist_ok=True)
    # dataset split ["train", "validation"]
    ds_dict = MsDataset.load(params.data_path)
    kwargs = dict(
        model=params.model,
        data_dir=ds_dict,
        dataset_type=params.dataset_type,
        work_dir=params.output_dir,
        batch_bins=params.batch_bins,
        max_epoch=params.max_epoch,
        lr=params.lr,
        mate_params=params.param_dict)
    trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
    trainer.train()


if __name__ == '__main__':
    params = modelscope_args(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", data_path="./data")
    params.output_dir = "./checkpoint"              # directory for saving the model
    params.data_path = "speech_asr_aishell1_trainsets"            # data path
    params.dataset_type = "small"                   # use "small" for small datasets; if the data exceeds 1000 hours, use "large"
    params.batch_bins = 2000                       # batch size: number of fbank feature frames if dataset_type="small", duration in ms if dataset_type="large"
    params.max_epoch = 20                           # maximum number of training epochs
    params.lr = 0.00005                             # learning rate
    init_param = []                                 # initial model path(s); the ModelScope model is loaded for initialization by default, e.g. ["checkpoint/20epoch.pb"]
    freeze_param = []                               # model parameters to freeze, e.g. ["encoder"]
    ignore_init_mismatch = True                     # whether to ignore size mismatches when initializing model parameters
    use_lora = False                                # whether to finetune the model with LoRA
    params.param_dict = {"init_param":init_param, "freeze_param": freeze_param, "ignore_init_mismatch": ignore_init_mismatch}
    if use_lora:
        enable_lora = True
        lora_bias = "all"
        lora_params = {"lora_list":['q','v'], "lora_rank":8, "lora_alpha":16, "lora_dropout":0.1}
        lora_config = {"enable_lora": enable_lora, "lora_bias": lora_bias, "lora_params": lora_params}
        params.param_dict.update(lora_config)

    modelscope_finetune(params)

Run the script in the background, writing the log to log.txt:

python finetune.py &> log.txt &

Finetune with your data

  • Modify the finetuning parameters in finetune.py:

    • output_dir: directory where the finetuned models are saved

    • data_path: the dataset path; a local dataset dir must include the files train/wav.scp, train/text, validation/wav.scp and validation/text

    • dataset_type: set to large for datasets larger than 1000 hours, otherwise set to small

    • batch_bins: batch size. When dataset_type is small, batch_bins is measured in feature frames; when dataset_type is large, batch_bins is measured in ms

    • max_epoch: maximum number of training epochs

    • lr: learning rate

    • init_param: [] (Default), initial model path(s); the ModelScope model is loaded for initialization by default. For example: ["checkpoint/20epoch.pb"]

    • freeze_param: [] (Default), model parameters to freeze. For example: ["encoder"]

    • ignore_init_mismatch: True (Default), ignore size mismatches when loading the pre-trained model

    • use_lora: False (Default), whether to finetune the model with LoRA; for more details, please refer to LORA

  • Training data formats:

cat ./example_data/text
BAC009S0002W0122 而 对 楼 市 成 交 抑 制 作 用 最 大 的 限 购
BAC009S0002W0123 也 成 为 地 方 政 府 的 眼 中 钉
english_example_1 hello world
english_example_2 go swim 去 游 泳

cat ./example_data/wav.scp
BAC009S0002W0122 /mnt/data/wav/train/S0002/BAC009S0002W0122.wav
BAC009S0002W0123 /mnt/data/wav/train/S0002/BAC009S0002W0123.wav
english_example_1 /mnt/data/wav/train/S0002/english_example_1.wav
english_example_2 /mnt/data/wav/train/S0002/english_example_2.wav

  • Then you can run the pipeline to finetune with:

python finetune.py

To finetune with multiple GPUs, run:

CUDA_VISIBLE_DEVICES=1,2 python -m torch.distributed.launch --nproc_per_node 2 finetune.py > log.txt 2>&1
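Since text and wav.scp are matched by utterance ID, it can be worth checking before training that both files list the same IDs. A minimal sketch, assuming each line is an ID followed by whitespace and then the content, as in the examples above; read_kaldi_map and find_id_mismatches are hypothetical helpers.

```python
from typing import Dict, Iterable, List, Tuple


def read_kaldi_map(lines: Iterable[str]) -> Dict[str, str]:
    """Parse lines of the form "utt_id<whitespace>content" into {utt_id: content}."""
    entries: Dict[str, str] = {}
    for line in lines:
        parts = line.strip().split(maxsplit=1)
        if not parts:
            continue  # skip blank lines
        entries[parts[0]] = parts[1] if len(parts) > 1 else ""
    return entries


def find_id_mismatches(text_lines: Iterable[str],
                       scp_lines: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (IDs in text but not in wav.scp, IDs in wav.scp but not in text)."""
    text_ids = set(read_kaldi_map(text_lines))
    scp_ids = set(read_kaldi_map(scp_lines))
    return sorted(text_ids - scp_ids), sorted(scp_ids - text_ids)
```

Open file handles for ./example_data/text and ./example_data/wav.scp can be passed directly, since files iterate line by line; for consistent data both returned lists are empty.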

Inference with your finetuned model

  • Decode with multiple GPUs:

    bash infer.sh \
    --model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
    --data_dir "./data/test" \
    --output_dir "./results" \
    --batch_size 64 \
    --gpu_inference true \
    --gpuid_list "0,1" \
    --checkpoint_dir "./checkpoint" \
    --checkpoint_name "valid.cer_ctc.ave.pb"

  • Decode with multi-thread CPUs:

    bash infer.sh \
    --model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
    --data_dir "./data/test" \
    --output_dir "./results" \
    --gpu_inference false \
    --njob 64 \
    --checkpoint_dir "./checkpoint" \
    --checkpoint_name "valid.cer_ctc.ave.pb"