Torch audio 文档
索引
安装
API 教程
音频数据集
管道教程
训练实用技巧
Conformer RNN-T 语音识别
Emformer RNN-T 语音识别
Conv-TasNet 源分离
HuBERT 预训练与微调(ASR)
实时音视频自动语音识别
Python API 参考文档
Python 原型 API 参考
C++ 原型 API 参考
PyTorch 库
PyTorch
torchaudio
torchtext
torchvision
TorchElastic
TorchServe
在 XLA 设备上使用 PyTorch

使用 Emformer RNN-T 进行设备端语音识别

作者: Moto Hira, Jeff Hwang

本教程展示了如何使用 Emformer RNN-T 和流式 API 对流式设备输入(例如笔记本电脑的麦克风)进行语音识别。

本教程需要 FFmpeg 库。有关详细信息,请参阅 FFmpeg 依赖

本教程已在 MacBook Pro 和搭载 Windows 10 的 Dynabook 上测试通过。

本教程无法在 Google Colab 上运行,因为运行此教程的服务器没有可供您使用的麦克风。

1. 概述

我们使用流式 API 从音频设备(麦克风)逐块获取音频,然后使用 Emformer RNN-T 进行推理。

关于流式 API 和 Emformer RNN-T 的基本用法,请参考 StreamReader 基本用法使用 Emformer RNN-T 进行在线 ASR

2. 检查支持的设备

首先,我们需要检查 Streaming API 可以访问的设备,并确定需要传递给 StreamReader() 类的参数(srcformat)。

我们使用 ffmpeg 命令来完成此操作。ffmpeg 抽象了底层硬件实现的差异,但 format 的预期值因操作系统而异,并且每个 format 定义了不同的 src 语法。

支持的 format 值和 src 语法的详细信息可以在 https://ffmpeg.org/ffmpeg-devices.html 中找到。

对于 macOS,以下命令将列出可用的设备。

$ ffmpeg -f avfoundation -list_devices true -i dummy
...
[AVFoundation indev @ 0x126e049d0] AVFoundation video devices:
[AVFoundation indev @ 0x126e049d0] [0] FaceTime HD Camera
[AVFoundation indev @ 0x126e049d0] [1] Capture screen 0
[AVFoundation indev @ 0x126e049d0] AVFoundation audio devices:
[AVFoundation indev @ 0x126e049d0] [0] ZoomAudioDevice
[AVFoundation indev @ 0x126e049d0] [1] MacBook Pro Microphone

我们将为 Streaming API 使用以下值。

StreamReader(
    src = ":1",  # no video, audio from device 1, "MacBook Pro Microphone"
    format = "avfoundation",
)

对于 Windows 系统,dshow 设备应该可以正常工作。

> ffmpeg -f dshow -list_devices true -i dummy
...
[dshow @ 000001adcabb02c0] DirectShow video devices (some may be both video and audio devices)
[dshow @ 000001adcabb02c0]  "TOSHIBA Web Camera - FHD"
[dshow @ 000001adcabb02c0]     Alternative name "@device_pnp_\\?\usb#vid_10f1&pid_1a42&mi_00#7&27d916e6&0&0000#{65e8773d-8f56-11d0-a3b9-00a0c9223196}\global"
[dshow @ 000001adcabb02c0] DirectShow audio devices
[dshow @ 000001adcabb02c0]  "... (Realtek High Definition Audio)"
[dshow @ 000001adcabb02c0]     Alternative name "@device_cm_{33D9A762-90C8-11D0-BD43-00A0C911CE86}\wave_{BF2B8AE1-10B8-4CA4-A0DC-D02E18A56177}"

在上述情况下,可以使用以下值从麦克风进行流式传输。

StreamReader(
    src = "audio=@device_cm_{33D9A762-90C8-11D0-BD43-00A0C911CE86}\wave_{BF2B8AE1-10B8-4CA4-A0DC-D02E18A56177}",
    format = "dshow",
)

3. 数据采集

从麦克风输入流式传输音频需要正确安排数据采集的时间。如果未能做到这一点,可能会导致数据流中出现不连续的情况。

因此,我们将在子进程中运行数据采集。

首先,我们创建一个辅助函数,该函数封装了在子进程中执行的整个过程。

该函数初始化流式传输 API,采集数据并将其放入队列中,主进程会监视该队列。

import torch
import torchaudio


# The data acquisition process will stop after this number of steps.
# This eliminates the need of process synchronization and makes this
# tutorial simple.
NUM_ITER = 100


def stream(q, format, src, segment_length, sample_rate):
    from torchaudio.io import StreamReader

    print("Building StreamReader...")
    streamer = StreamReader(src, format=format)
    streamer.add_basic_audio_stream(frames_per_chunk=segment_length, sample_rate=sample_rate)

    print(streamer.get_src_stream_info(0))
    print(streamer.get_out_stream_info(0))

    print("Streaming...")
    print()
    stream_iterator = streamer.stream(timeout=-1, backoff=1.0)
    for _ in range(NUM_ITER):
        (chunk,) = next(stream_iterator)
        q.put(chunk)

与非设备流式传输的显著区别在于,我们为 stream 方法提供了 timeoutbackoff 参数。

在获取数据时,如果获取请求的速率高于硬件准备数据的速率,底层实现会报告特定的错误代码,并期望客户端代码进行重试。

精确的计时是实现流畅流式传输的关键。在重试之前,将这种错误从底层实现一直报告到 Python 层会带来不必要的开销。因此,重试行为在 C++ 层实现,timeoutbackoff 参数允许客户端代码控制该行为。

有关 timeoutbackoff 参数的详细信息,请参阅 stream() 方法的文档。

backoff 的合适值取决于系统配置。判断 backoff 值是否合适的一种方法是将获取的音频片段保存为连续的音频文件并试听。如果 backoff 值过大,数据流将不连续,生成的音频听起来会加速。如果 backoff 值过小或为零,音频流虽然正常,但数据采集过程会进入忙等待状态,从而增加 CPU 的消耗。

4. 构建推理管道

下一步是创建推理所需的组件。

该过程与使用 Emformer RNN-T 进行在线 ASR相同。

class Pipeline:
"""Build inference pipeline from RNNTBundle.

    Args:
        bundle (torchaudio.pipelines.RNNTBundle): Bundle object
        beam_width (int): Beam size of beam search decoder.
    """

    def __init__(self, bundle: torchaudio.pipelines.RNNTBundle, beam_width: int = 10):
        self.bundle = bundle
        self.feature_extractor = bundle.get_streaming_feature_extractor()
        self.decoder = bundle.get_decoder()
        self.token_processor = bundle.get_token_processor()

        self.beam_width = beam_width

        self.state = None
        self.hypotheses = None

    def infer(self, segment: torch.Tensor) -> str:
"""Perform streaming inference"""
        features, length = self.feature_extractor(segment)
        self.hypotheses, self.state = self.decoder.infer(
            features, length, self.beam_width, state=self.state, hypothesis=self.hypotheses
        )
        transcript = self.token_processor(self.hypotheses[0][0], lstrip=False)
        return transcript
class ContextCacher:
"""Cache the end of input data and prepend the next input data with it.

    Args:
        segment_length (int): The size of main segment.
            If the incoming segment is shorter, then the segment is padded.
        context_length (int): The size of the context, cached and appended.
    """

    def __init__(self, segment_length: int, context_length: int):
        self.segment_length = segment_length
        self.context_length = context_length
        self.context = torch.zeros([context_length])

    def __call__(self, chunk: torch.Tensor):
        if chunk.size(0) < self.segment_length:
            chunk = torch.nn.functional.pad(chunk, (0, self.segment_length - chunk.size(0)))
        chunk_with_context = torch.cat((self.context, chunk))
        self.context = chunk[-self.context_length :]
        return chunk_with_context

5. 主流程

主进程的执行流程如下:

  1. 初始化推理管道。

  2. 启动数据采集子进程。

  3. 运行推理。

  4. 清理

由于数据采集子进程将通过“spawn”方法启动,全局作用域中的所有代码也会在子进程中执行。

我们希望仅在主进程中实例化管道,因此将其放入一个函数中,并在 __name__ == "__main__" 的保护下调用它。

def main(device, src, bundle):
    print(torch.__version__)
    print(torchaudio.__version__)

    print("Building pipeline...")
    pipeline = Pipeline(bundle)

    sample_rate = bundle.sample_rate
    segment_length = bundle.segment_length * bundle.hop_length
    context_length = bundle.right_context_length * bundle.hop_length

    print(f"Sample rate: {sample_rate}")
    print(f"Main segment: {segment_length} frames ({segment_length/sample_rate} seconds)")
    print(f"Right context: {context_length} frames ({context_length/sample_rate} seconds)")

    cacher = ContextCacher(segment_length, context_length)

    @torch.inference_mode()
    def infer():
        for _ in range(NUM_ITER):
            chunk = q.get()
            segment = cacher(chunk[:, 0])
            transcript = pipeline.infer(segment)
            print(transcript, end="\r", flush=True)

    import torch.multiprocessing as mp

    ctx = mp.get_context("spawn")
    q = ctx.Queue()
    p = ctx.Process(target=stream, args=(q, device, src, segment_length, sample_rate))
    p.start()
    infer()
    p.join()


if __name__ == "__main__":
    main(
        device="avfoundation",
        src=":1",
        bundle=torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH,
    )
Building pipeline...
Sample rate: 16000
Main segment: 2560 frames (0.16 seconds)
Right context: 640 frames (0.04 seconds)
Building StreamReader...
SourceAudioStream(media_type='audio', codec='pcm_f32le', codec_long_name='PCM 32-bit floating point little-endian', format='flt', bit_rate=1536000, sample_rate=48000.0, num_channels=1)
OutputStream(source_index=0, filter_description='aresample=16000,aformat=sample_fmts=fltp')
Streaming...

hello world

标签: torchaudio.io

本页目录