Torch audio 文档
索引
安装
API 教程
音频数据集
管道教程
训练实用技巧
Conformer RNN-T 语音识别
Emformer RNN-T 语音识别
Conv-TasNet 源分离
HuBERT 预训练与微调(ASR)
实时音视频自动语音识别
Python API 参考文档
Python 原型 API 参考
C++ 原型 API 参考
PyTorch 库
PyTorch
torchaudio
torchtext
torchvision
TorchElastic
TorchServe
在 XLA 设备上使用 PyTorch

音频输入输出

作者: Moto Hira

本教程展示了如何使用 TorchAudio 的基本 I/O API 来检查音频数据,将其加载到 PyTorch 张量中,并保存 PyTorch 张量。

最近版本中对音频 I/O 进行了多项计划/实施的更改。有关这些更改的详细信息,请参阅 Dispatcher 介绍

import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)
2.6.0
2.6.0

准备工作

首先,我们导入模块并下载本教程中使用的音频资源。

在 Google Colab 中运行本教程时,请使用以下命令安装所需的包:

!pip install boto3
import io
import os
import tarfile
import tempfile

import boto3
import matplotlib.pyplot as plt
import requests
from botocore import UNSIGNED
from botocore.config import Config
from IPython.display import Audio
from torchaudio.utils import download_asset

SAMPLE_GSM = download_asset("tutorial-assets/steam-train-whistle-daniel_simon.gsm")
SAMPLE_WAV = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
SAMPLE_WAV_8000 = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042-8000hz.wav")


def _hide_seek(obj):
    class _wrapper:
        def __init__(self, obj):
            self.obj = obj

        def read(self, n):
            return self.obj.read(n)

    return _wrapper(obj)
  0%|          | 0.00/7.99k [00:00<?, ?B/s]
100%|##########| 7.99k/7.99k [00:00<00:00, 14.5MB/s]

  0%|          | 0.00/53.2k [00:00<?, ?B/s]
100%|##########| 53.2k/53.2k [00:00<00:00, 48.4MB/s]

查询音频元数据

函数 torchaudio.info() 用于获取音频的元数据。您可以提供一个路径类对象或文件类对象。

metadata = torchaudio.info(SAMPLE_WAV)
print(metadata)
AudioMetaData(sample_rate=16000, num_frames=54400, num_channels=1, bits_per_sample=16, encoding=PCM_S)

其中

  • sample_rate 是音频的采样率

  • num_channels 是声道数量

  • num_frames 是每个声道的帧数

  • bits_per_sample 是位深度

  • encoding 是采样编码格式

encoding 可以取以下值之一:

注意

  • 对于具有压缩和/或可变比特率的格式(如 MP3),bits_per_sample 可以为 0

  • 对于 GSM-FR 格式,num_frames 可以为 0

metadata = torchaudio.info(SAMPLE_GSM)
print(metadata)
AudioMetaData(sample_rate=8000, num_frames=39680, num_channels=1, bits_per_sample=0, encoding=GSM)

查询类文件对象

torchaudio.info() 适用于类文件对象。

url = "https://download.pytorch.org/torchaudio/tutorial-assets/steam-train-whistle-daniel_simon.wav"
with requests.get(url, stream=True) as response:
    metadata = torchaudio.info(_hide_seek(response.raw))
print(metadata)
AudioMetaData(sample_rate=44100, num_frames=109368, num_channels=2, bits_per_sample=16, encoding=PCM_S)

当传递一个类文件对象时,info 不会读取所有底层数据;相反,它只会从开头读取一部分数据。因此,对于给定的音频格式,它可能无法检索到正确的元数据,包括格式本身。在这种情况下,您可以传递 format 参数来指定音频的格式。

加载音频数据

要加载音频数据,您可以使用 torchaudio.load()

该函数接受类似路径的对象或类似文件的对象作为输入。

返回的值是一个包含波形(Tensor)和采样率(int)的元组。

默认情况下,生成的张量对象的 dtype=torch.float32,其值范围为 [-1.0, 1.0]

有关支持的格式列表,请参阅 torchaudio 文档

waveform, sample_rate = torchaudio.load(SAMPLE_WAV)
def plot_waveform(waveform, sample_rate):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].plot(time_axis, waveform[c], linewidth=1)
        axes[c].grid(True)
        if num_channels > 1:
            axes[c].set_ylabel(f"Channel {c+1}")
    figure.suptitle("waveform")
plot_waveform(waveform, sample_rate)

waveform

def plot_specgram(waveform, sample_rate, title="Spectrogram"):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].specgram(waveform[c], Fs=sample_rate)
        if num_channels > 1:
            axes[c].set_ylabel(f"Channel {c+1}")
    figure.suptitle(title)
plot_specgram(waveform, sample_rate)

Spectrogram

Audio(waveform.numpy()[0], rate=sample_rate)

从类文件对象加载

I/O 函数支持类文件对象。这使得可以从本地文件系统内外的位置获取和解码音频数据。以下示例展示了这一点。

# Load audio data as HTTP request
url = "https://download.pytorch.org/torchaudio/tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
with requests.get(url, stream=True) as response:
    waveform, sample_rate = torchaudio.load(_hide_seek(response.raw))
plot_specgram(waveform, sample_rate, title="HTTP datasource")

HTTP datasource

# Load audio from tar file
tar_path = download_asset("tutorial-assets/VOiCES_devkit.tar.gz")
tar_item = "VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
with tarfile.open(tar_path, mode="r") as tarfile_:
    fileobj = tarfile_.extractfile(tar_item)
    waveform, sample_rate = torchaudio.load(fileobj)
plot_specgram(waveform, sample_rate, title="TAR file")

TAR file

  0%|          | 0.00/110k [00:00<?, ?B/s]
100%|##########| 110k/110k [00:00<00:00, 43.5MB/s]
# Load audio from S3
bucket = "pytorch-tutorial-assets"
key = "VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
client = boto3.client("s3", config=Config(signature_version=UNSIGNED))
response = client.get_object(Bucket=bucket, Key=key)
waveform, sample_rate = torchaudio.load(_hide_seek(response["Body"]))
plot_specgram(waveform, sample_rate, title="From S3")

From S3

切片技巧

提供 num_framesframe_offset 参数可以将解码限制在输入音频的相应片段。

虽然可以通过普通的 Tensor 切片操作(即 waveform[:, frame_offset:frame_offset+num_frames])实现相同的结果,但提供 num_framesframe_offset 参数更为高效。

这是因为该函数在解码完请求的帧后会立即停止数据获取和解码。当音频数据通过网络传输时,这一特性尤为有利,因为一旦获取到所需的数据量,数据传输就会停止。

以下示例展示了这一点。

# Illustration of two different decoding methods.
# The first one will fetch all the data and decode them, while
# the second one will stop fetching data once it completes decoding.
# The resulting waveforms are identical.

frame_offset, num_frames = 16000, 16000  # Fetch and decode the 1 - 2 seconds

url = "https://download.pytorch.org/torchaudio/tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
print("Fetching all the data...")
with requests.get(url, stream=True) as response:
    waveform1, sample_rate1 = torchaudio.load(_hide_seek(response.raw))
    waveform1 = waveform1[:, frame_offset : frame_offset + num_frames]
    print(f" - Fetched {response.raw.tell()} bytes")

print("Fetching until the requested frames are available...")
with requests.get(url, stream=True) as response:
    waveform2, sample_rate2 = torchaudio.load(
        _hide_seek(response.raw), frame_offset=frame_offset, num_frames=num_frames
    )
    print(f" - Fetched {response.raw.tell()} bytes")

print("Checking the resulting waveform ... ", end="")
assert (waveform1 == waveform2).all()
print("matched!")
Fetching all the data...
 * Fetched 108844 bytes
Fetching until the requested frames are available...
 * Fetched 108844 bytes
Checking the resulting waveform ... matched!

保存音频到文件

为了将音频数据保存为常见应用程序可识别的格式,您可以使用 torchaudio.save()

该函数接受路径类对象或文件类对象。

当传递文件类对象时,您还需要提供参数 format,以便函数知道应使用哪种格式。如果是路径类对象,函数将从扩展名推断格式。如果保存到没有扩展名的文件,则需要提供参数 format

当保存 WAV 格式的数据时,float32 Tensor 的默认编码是 32 位浮点 PCM。您可以提供参数 encodingbits_per_sample 来更改此行为。例如,要以 16 位有符号整数 PCM 保存数据,您可以执行以下操作。

以较低的位深度保存数据会减小文件大小,但也会降低精度。

waveform, sample_rate = torchaudio.load(SAMPLE_WAV)
def inspect_file(path):
    print("-" * 10)
    print("Source:", path)
    print("-" * 10)
    print(f" - File size: {os.path.getsize(path)} bytes")
    print(f" - {torchaudio.info(path)}")
    print()

保存时不使用任何编码选项。函数将自动选择适合所提供数据的编码。

with tempfile.TemporaryDirectory() as tempdir:
    path = f"{tempdir}/save_example_default.wav"
    torchaudio.save(path, waveform, sample_rate)
    inspect_file(path)
*---------
Source: /tmp/tmpu65i6inj/save_example_default.wav
*---------
 * File size: 108878 bytes
 * AudioMetaData(sample_rate=16000, num_frames=54400, num_channels=1, bits_per_sample=16, encoding=PCM_S)

另存为 16 位有符号整数线性 PCM 文件,生成的文件占用一半的存储空间,但会损失精度

with tempfile.TemporaryDirectory() as tempdir:
    path = f"{tempdir}/save_example_PCM_S16.wav"
    torchaudio.save(path, waveform, sample_rate, encoding="PCM_S", bits_per_sample=16)
    inspect_file(path)
*---------
Source: /tmp/tmp9xxqlowd/save_example_PCM_S16.wav
*---------
 * File size: 108878 bytes
 * AudioMetaData(sample_rate=16000, num_frames=54400, num_channels=1, bits_per_sample=16, encoding=PCM_S)

torchaudio.save() 也可以处理其他格式。举几个例子:

formats = [
    "flac",
    # "vorbis",
    # "sph",
    # "amb",
    # "amr-nb",
    # "gsm",
]
waveform, sample_rate = torchaudio.load(SAMPLE_WAV_8000)
with tempfile.TemporaryDirectory() as tempdir:
    for format in formats:
        path = f"{tempdir}/save_example.{format}"
        torchaudio.save(path, waveform, sample_rate, format=format)
        inspect_file(path)
*---------
Source: /tmp/tmpjayiu4jz/save_example.flac
*---------
 * File size: 45262 bytes
 * AudioMetaData(sample_rate=8000, num_frames=27200, num_channels=1, bits_per_sample=16, encoding=FLAC)

保存到类文件对象

与其他 I/O 函数类似,您可以将音频保存到类文件对象中。当保存到类文件对象时,参数 format 是必需的。

waveform, sample_rate = torchaudio.load(SAMPLE_WAV)

# Saving to bytes buffer
buffer_ = io.BytesIO()
torchaudio.save(buffer_, waveform, sample_rate, format="wav")

buffer_.seek(0)
print(buffer_.read(16))
b'RIFFF\xa9\x01\x00WAVEfmt '
本页目录