espnet的enh训练任务分析笔记

前言

This is the common recipe for ESPnet2 speech enhancement frontend.
这是 ESPnet2 语音增强前端的通用配方。

本文为笔者学习espnet语音处理包语音增强部分的学习笔记,初稿为claude-3.5-sonnet辅助生成,后续会不断在此基础上更新,加入自己的理解。

以下为espnet工具包相关的网址

ESPnet

espnet installation

enh.sh

enh.sh官方文档

该enh.sh在espnet中的位置:egs2/TEMPLATE/enh1/enh.sh , 13 stages are included.

训练任务流程

  • 选择数据集
  • 选择配置文件(可更改具体参数)
  • 运行脚本

以经典数据集wsj0_2mix为例,wsj0_2mix,在conf目录的tuning子目录中选择配置,我选择的是train_enh_rnn_tf.yaml,该配置用于训练一个基于 RNN 的语音分离模型,其中tf后缀在这个配置文件名中代表 Time-Frequency domain(时频域)

接着运行run.sh,比如:

1
./run.sh --stage 1 --stop_stage 6 --conf conf/tuning/train_enh_rnn_tf.yaml

以下是run.sh的具体内容(因为很短就贴出来)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#!/usr/bin/env bash
# Set bash to 'debug' mode, it will exit on :
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail

min_or_max=min # "min" or "max". This is to determine how the mixtures are generated in local/data.sh.
sample_rate=8k


train_set="tr_${min_or_max}_${sample_rate}"
valid_set="cv_${min_or_max}_${sample_rate}"
test_sets="tt_${min_or_max}_${sample_rate} "

./enh.sh \
--train_set "${train_set}" \
--valid_set "${valid_set}" \
--test_sets "${test_sets}" \
--fs "${sample_rate}" \
--lang en \
--ngpu 1 \
--local_data_opts "--sample_rate ${sample_rate} --min_or_max ${min_or_max}" \
--enh_config conf/tuning/train_enh_dprnn_tasnet.yaml \
"$@"

注意到,调用的enh.sh,实际上就是TEMPLATE中的enh.sh(下文会具体分析)

1) 训练过程监控:

  • 查看日志: tail -f exp/enh_train_*/train.log
  • 查看关键指标:grep "loss:" exp/enh_train_*/train.log

2) 评估模型

1
2
./run.sh --stage 7 --stop_stage 8 \
--conf conf/tuning/train_enh_rnn_tf.yaml

查看评估结果:

  • 结果保存在 exp/enh_train_*/RESULTS.txt
  • 包含SI-SNR、SDR等指标

3) 使用模型

对单个音频进行增强:

1
2
3
4
5
python -m espnet2.bin.enh_inference \
--audio_file /path/to/mixed.wav \
--config exp/enh_train_*/config.yaml \
--model_file exp/enh_train_*/valid.acc.best.pth \
--output_dir ./enhanced

获取增强后的音频:

  • 增强结果保存在 ./enhanced 目录
  • 每个说话人的分离结果单独保存

注意事项

1) 训练中断后继续:

  • 直接运行相同的命令即可
  • ESPnet会自动加载最新的检查点

2) 常见问题:

  • 内存不足: 减小 batch_size
  • 显存不足: 减小 batch_size 或使用梯度累积
  • 训练不收敛: 调整学习率或检查数据预处理

3) 建议:

  • 先用小数据集测试流程
  • 保存好配置文件和日志
  • 定期备份实验结果

配置文件分析

基础训练参数

1
2
3
4
5
6
optim: adam  # 优化器选择:adam优化器
init: xavier_uniform # 参数初始化方式:xavier均匀分布初始化
max_epoch: 100 # 最大训练轮数
batch_type: folded # 批次类型:folded表示按序列长度折叠
batch_size: 8 # 每批次样本数
num_workers: 4 # 数据加载器的并行工作进程数

优化器配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
optim_conf:
lr: 1.0e-03 # 初始学习率
eps: 1.0e-08 # 数值稳定性参数
weight_decay: 1.0e-7 # L2正则化系数

# 早停耐心值:验证集性能多少轮未改善就停止
patience: 10

# 验证集调度器判断标准
val_scheduler_criterion:
- valid
- loss

# 最佳模型保存标准
best_model_criterion:
- - valid
- si_snr # 尺度不变信噪比
- max # 最大化
- - valid
- loss # 损失值
- min # 最小化

# 保存最好的模型数量
keep_nbest_models: 1

# 学习率调度器:当验证集性能不再提升时降低学习率
scheduler: reducelronplateau
scheduler_conf:
mode: min # 监控模式:最小化
factor: 0.7 # 学习率降低因子
patience: 1 # 调度器耐心值

损失函数配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# A list for criterions
# The overlall loss in the multi-task learning will be:
# loss = weight_1 * loss_1 + ... + weight_N * loss_N
# The default `weight` for each sub-loss is 1.0
criterions:
# 第一个损失函数:均方误差(MSE)
- name: mse
conf:
compute_on_mask: True # 在掩码上计算
mask_type: PSM # 相位敏感掩码
wrapper: pit # 用PIT(排列不变训练)包装
wrapper_conf:
weight: 1.0 # 损失权重

# 第二个损失函数:L1损失
- name: l1
conf:
compute_on_mask: False # 在波形上计算
wrapper: pit
wrapper_conf:
weight: 1.0
independent_perm: False # 使用前一个criterion的排列顺序

# 第三个损失函数:SI-SNR损失
- name: si_snr
conf:
eps: 1.0e-7 # 数值稳定性参数
wrapper: pit
wrapper_conf:
weight: 5.0 # 较大权重表示更重视此损失
independent_perm: False

模型架构配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
encoder: stft # STFT编码器配置
encoder_conf:
n_fft: 256 # FFT点数
hop_length: 128 # 帧移

# STFT解码器配置
decoder: stft
decoder_conf:
n_fft: 256
hop_length: 128

# 分离器配置:RNN架构
separator: rnn
separator_conf:
rnn_type: blstm # 双向LSTM
num_spk: 2 # 说话人数量
nonlinear: relu # 激活函数
layer: 3 # RNN层数
unit: 896 # 隐层单元数
dropout: 0.5 # Dropout比率

enh.sh 的分析

Stage 1前的配置介绍

基本设置

  1. bash调试模式设置

    1
    2
    3
    set -e        # 遇到错误就退出
    set -u # 使用未定义变量时报错
    set -o pipefail # 管道中任一命令失败则整个管道失败
  2. 辅助函数

    1
    2
    3
    4
    5
    6
    7
    8
    # 日志函数:打印时间戳和调用位置信息
    log() {
    local fname=${BASH_SOURCE[1]##*/}
    echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
    }

    # 求最小值函数:用于计算并行作业数
    min()

注:让我比较疑惑的一个点是为什么不把日志重定向输出到一个文件?直接echo的话不会很长吗?

必填参数

  1. 数据集相关
  • --train_set: 训练集名称
  • --valid_set: 验证集名称
  • --test_sets: 测试集名称列表

选填参数

  1. 基本配置参数

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    stage=1                # 处理开始的阶段
    stop_stage=10000 # 处理结束的阶段
    skip_data_prep=false # 是否跳过数据准备阶段
    skip_train=false # 是否跳过训练阶段
    skip_eval=false # 是否跳过推理和评估阶段
    skip_packing=true # 是否跳过打包阶段
    skip_upload_hf=true # 是否跳过上传到HuggingFace阶段
    ngpu=1 # GPU数量(0表示使用CPU)
    num_nodes=1 # 节点数量
    nj=32 # 并行作业数
  2. 特征提取相关参数

    1
    2
    3
    4
    5
    feats_type=raw        # 特征类型(raw或fbank_pitch)
    audio_format=flac # 音频格式:wav,flac等
    fs=16k # 采样率
    min_wav_duration=0.1 # 最短音频长度(秒)
    max_wav_duration=20 # 最长音频长度(秒)
  3. 增强模型相关参数

    1
    2
    3
    4
    5
    6
    7
    8
    9
    enh_exp=             # 增强实验目录路径
    enh_tag= # 增强模型训练结果目录的后缀
    enh_config= # 增强模型训练配置
    enh_args= # 增强模型训练的额外参数
    ref_num=2 # 参考信号数量(等于说话人数量)
    inf_num= # 模型输出的推理结果数量
    noise_type_num=1 # 输入音频中的噪声类型数量
    dereverb_ref_num=1 # 去混响参考信号数量
    is_tse_task=false # 是否为目标说话人提取任务
  4. 训练数据相关参数

    1
    2
    3
    use_dereverb_ref=false   # 是否使用去混响参考信号
    use_noise_ref=false # 是否使用噪声参考信号
    variable_num_refs=false # 是否使用可变数量的参考信号
  5. 推理和评估相关参数

    1
    2
    3
    inference_args="--normalize_output_wav true --output_format wav"  # 推理参数
    inference_model=valid.loss.ave.pth # 推理使用的模型文件
    scoring_protocol="STOI SDR SAR SIR SI_SNR" # 评分指标

各Stage功能详细分析

Stage 1: 数据准备

  • 功能:准备训练、验证和测试数据集
  • 执行:调用local/data.sh脚本处理数据
  • 关键代码
    1
    2
    3
    4
    5
    if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    log "Stage 1: Data preparation for data/${train_set}, data/${valid_set}, etc."
    # [Task dependent] 需要为新语料库创建data.sh
    local/data.sh ${local_data_opts}
    fi
  • 重要说明
    • 这个阶段是任务相关的,需要根据具体的语料库创建相应的data.sh脚本
    • local_data_opts参数可以传递给data.sh进行数据处理的定制
  • 运行产出
    • data/${train_set}, data/${valid_set} 等目录下生成:
      • wav.scp:音频文件路径映射
      • utt2spk:话语到说话人映射
      • spk2utt:说话人到话语映射
      • mix.scp:混合音频文件列表
      • ref.scp:参考音频文件列表

注:笔者一开始在找了好久data.sh在哪里,后面发现在具体的数据集中(详见上文训练任务流程)

data.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#!/usr/bin/env bash

# 设置bash的错误处理
set -e # 遇到错误就退出
set -u # 使用未定义变量时报错
set -o pipefail # 管道中任一命令失败则整个管道失败

# 定义日志函数
log() {
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

# 帮助信息
help_message=$(cat << EOF
Usage: $0 [--min_or_max <min/max>] [--sample_rate <8k/16k>]
optional argument:
[--min_or_max]: min (Default), max # 混合方式:最小或最大
[--sample_rate]: 8k (Default), 16k # 采样率:8k或16k
EOF
)

# 导入数据库配置
. ./db.sh

# 设置路径变量
wsj_full_wav=$PWD/data/wsj0/wsj0_wav # WSJ0原始音频路径
wsj_2mix_wav=$PWD/data/wsj0_mix/2speakers # 双说话人混合音频路径
wsj_2mix_scripts=$PWD/data/wsj0_mix/scripts # 混合脚本路径

# 设置文本相关变量
other_text=data/local/other_text/text # 其他文本数据路径
nlsyms=data/nlsyms.txt # 非语言符号文件
min_or_max=min # 默认混合方式为min
sample_rate=8k # 默认采样率为8k

# 解析命令行参数
. utils/parse_options.sh

# 检查WSJ0和WSJ1数据集路径是否存在
if [ ! -e "${WSJ0}" ]; then
log "Fill the value of 'WSJ0' of db.sh"
exit 1
fi
if [ ! -e "${WSJ1}" ]; then
log "Fill the value of 'WSJ1' of db.sh"
exit 1
fi

# 设置数据集名称
train_set="tr_"${min_or_max}_${sample_rate} # 训练集
train_dev="cv_"${min_or_max}_${sample_rate} # 验证集
recog_set="tt_"${min_or_max}_${sample_rate} # 测试集

### WSJ0混合数据处理部分 ###
# 下载混合脚本并创建双说话人混合音频
local/wsj0_create_mixture.sh ${wsj_2mix_scripts} ${WSJ0} ${wsj_full_wav} \
${wsj_2mix_wav} || exit 1;

# 准备WSJ0_2mix数据集
local/wsj0_2mix_data_prep.sh --min-or-max ${min_or_max} --sample-rate ${sample_rate} \
${wsj_2mix_wav}/wav${sample_rate}/${min_or_max} ${wsj_2mix_scripts} ${wsj_full_wav} || exit 1;

### 创建参考音频的.scp文件 ###
# 为每个数据集创建说话人1和说话人2的scp文件
for folder in ${train_set} ${train_dev} ${recog_set}; do
sed -e 's/\/mix\//\/s1\//g' ./data/$folder/wav.scp > ./data/$folder/spk1.scp
sed -e 's/\/mix\//\/s2\//g' ./data/$folder/wav.scp > ./data/$folder/spk2.scp
done

### WSJ语料库处理部分 ###
# 准备WSJ数据
log "local/wsj_data_prep.sh ${WSJ0}/??-{?,??}.? ${WSJ1}/??-{?,??}.?"
local/wsj_data_prep.sh ${WSJ0}/??-{?,??}.? ${WSJ1}/??-{?,??}.?

# 格式化WSJ数据
log "local/wsj_format_data.sh"
local/wsj_format_data.sh

# 创建wsj目录并移动相关数据
log "mkdir -p data/wsj"
mkdir -p data/wsj
log "mv data/{dev_dt_*,local,test_dev*,test_eval*,train_si284} data/wsj"
mv data/{dev_dt_*,local,test_dev*,test_eval*,train_si284} data/wsj

# 准备额外的文本数据
log "Prepare text from lng_modl dir..."
mkdir -p "$(dirname ${other_text})"

# 处理语言模型训练数据
zcat ${WSJ1}/13-32.1/wsj1/doc/lng_modl/lm_train/np_data/{87,88,89}/*.z | \
grep -v "<" | tr "[:lower:]" "[:upper:]" | \
awk '{ printf("wsj1_lng_%07d %s\n",NR,$0) } ' > ${other_text}

# 创建非语言符号文件
log "Create non linguistic symbols: ${nlsyms}"
cut -f 2- data/wsj/train_si284/text | tr " " "\n" | sort | uniq | grep "<" > ${nlsyms}
cat ${nlsyms}

开始
├─ 检查WSJ数据集路径
├─ 生成混合语音数据
├─ 创建说话人分离文件
├─ 准备WSJ原始数据
├─ 处理附加文本
└─ 提取非语言符号
结束

Stage 2: 速度扰动

  • 功能:对训练数据进行速度扰动增强
  • 条件:仅在设置了speed_perturb_factors且不使用去混响参考时执行
  • 处理:对音频进行不同速度的扰动,生成增强数据
  • 运行产出
    • data/${train_set}_sp 目录下生成:
      • 扰动后的音频文件和对应的配置文件
      • 更新的 wav.scp, utt2spk, spk2utt 等文件

Stage 3: 音频格式化

  • 功能:统一处理音频格式
  • 关键代码
    1
    2
    3
    4
    5
    # 格式化wav.scp文件
    scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
    --out-filename "${spk}.scp" \
    --audio-format "${audio_format}" --fs "${fs}" ${_opts} \
    "data/${dset}/${spk}.scp" "${data_feats}${_suf}/${dset}"
  • 处理步骤
    1. 重新创建”wav.scp”文件
    2. 统一音频格式和采样率
    3. 处理多说话人的情况
    4. 支持segments文件的分割处理
  • 运行产出
    • ${data_feats}/${dset} 目录下:
      • 统一格式后的音频文件
      • 更新的 wav.scp 文件
      • 各说话人的 .scp 文件

Stage 4: 数据筛选

  • 功能:移除不符合长度要求的音频数据
  • 关键代码
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    # 计算最小和最大长度(样本数)
    _fs=$(python3 -c "import humanfriendly as h;print(h.parse_size('${fs}'))")
    _min_length=$(python3 -c "print(int(${min_wav_duration} * ${_fs}))")
    _max_length=$(python3 -c "print(int(${max_wav_duration} * ${_fs}))")

    # 根据长度筛选数据
    <"${data_feats}/org/${dset}/utt2num_samples" \
    awk -v min_length="${_min_length}" -v max_length="${_max_length}" \
    '{ if ($2 > min_length && $2 < max_length ) print $0; }' \
    >"${data_feats}/${dset}/utt2num_samples"
  • 处理步骤
    1. 将时间长度转换为样本数
    2. 根据样本数筛选音频
    3. 更新相关的scp文件
  • 运行产出
    • ${data_feats}/${dset} 目录下:
      • 筛选后的 utt2num_samples 文件
      • 更新后的 wav.scp, spk.scp 等文件
      • 移除不符合长度要求的音频条目

Stage 5: 统计收集

  • 功能:收集训练所需的统计信息
  • 关键代码
    1
    2
    3
    4
    5
    6
    7
    ${python} -m ${train_module} \
    --collect_stats true \
    ${_train_data_param} \
    ${_valid_data_param} \
    --train_shape_file "${_logdir}/train.JOB.scp" \
    --valid_shape_file "${_logdir}/valid.JOB.scp" \
    --output_dir "${_logdir}/stats.JOB"
  • 处理步骤
    1. 收集训练和验证数据的统计信息
    2. 生成shape文件
    3. 聚合统计信息
  • 运行产出
    • ${_logdir} 目录下:
      • stats.JOB 目录:包含统计信息
      • train.JOB.scp:训练数据shape信息
      • valid.JOB.scp:验证数据shape信息
      • global_stats:全局统计信息

Stage 6: 模型训练

  • 功能:执行增强模型的训练
  • 关键代码
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    ${python} -m ${train_module} \
    ${_train_data_param} \
    ${_valid_data_param} \
    ${_train_shape_param} \
    ${_valid_shape_param} \
    ${_fold_length_param} \
    --resume true \
    --output_dir "${enh_exp}" \
    ${init_param:+--init_param $init_param} \
    ${_opts} ${enh_args}
  • 处理步骤
    1. 设置训练数据和验证数据
    2. 配置训练参数
    3. 支持断点续训
    4. 可选预训练模型初始化
  • 运行产出
    • ${enh_exp} 目录下:
      • config.yaml:模型配置文件
      • 模型检查点文件(*.pth)
      • trainer.log:训练日志
      • 验证结果和曲线图表

Stage 7: 推理处理

  • 功能:使用训练好的模型进行音频增强
  • 关键代码
    1
    2
    3
    4
    5
    6
    7
    8
    ${python} -m ${infer_module} \
    --ngpu "${_ngpu}" \
    --fs "${fs}" \
    ${_data_param} \
    --key_file "${_logdir}"/keys.JOB.scp \
    --train_config "${enh_exp}"/config.yaml \
    --model_file "${enh_exp}"/"${inference_model}" \
    --output_dir "${_logdir}"/output.JOB
  • 处理步骤
    1. 加载训练好的模型
    2. 对测试集进行推理
    3. 生成增强后的音频
    4. 支持GPU推理
  • 运行产出
    • ${_logdir}/output.JOB 目录下:
      • enhanced.wav:增强后的音频文件
      • keys.JOB.scp:处理的音频键值对
      • 推理日志和结果文件

Stage 8: 评分

  • 功能:评估增强效果
  • 关键代码
    1
    2
    3
    4
    5
    6
    7
    ${python} -m espnet2.bin.enh_scoring \
    --key_file "${_logdir}"/keys.JOB.scp \
    --output_dir "${_logdir}"/output.JOB \
    ${_ref_scp} \
    ${_inf_scp} \
    --ref_channel ${ref_channel} \
    --flexible_numspk ${flexible_numspk}
  • 评估指标
    • STOI: 语音可懂度
    • SDR: 信号失真比
    • SAR: 伪影比
    • SIR: 干扰比
    • SI_SNR: 尺度不变信噪比
  • 运行产出
    • ${_logdir}/output.JOB 目录下:
      • scoring.txt:包含各项评分指标
      • score_stats:详细的评分统计
      • 各指标的得分分布图

Stage 9-10: ASR评估

  • 功能:使用ASR模型评估增强效果
  • 关键代码
    1
    2
    3
    4
    5
    6
    7
    ${python} -m espnet2.bin.asr_inference \
    --ngpu "${_ngpu}" \
    --data_path_and_name_and_type "${_ddir}/wav.scp,speech,${_type}" \
    --key_file "${_logdir}"/keys.JOB.scp \
    --asr_train_config "${asr_exp}"/config.yaml \
    --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
    --output_dir "${_logdir}"/output.JOB
  • 处理步骤
    1. 使用ASR模型解码增强后的音频
    2. 计算字错误率(CER)或词错误率(WER)
    3. 生成评估报告
  • 运行产出
    • ${_logdir}/output.JOB 目录下:
      • asr_inference.txt:ASR解码结果
      • text:识别的文本结果
      • wer.txt/cer.txt:错误率统计

Stage 11: 模型打包

  • 功能:将训练好的模型打包
  • 处理
    • 打包模型文件
    • 打包配置信息
    • 生成发布包
  • 运行产出
    • ${enh_exp}/pack 目录下:
      • model.zip:打包的模型文件
      • config.yaml:配置文件副本
      • README.md:模型说明文档

Stage 12: 上传模型

  • 功能:将模型上传到HuggingFace
  • 条件:当skip_upload_hf=false时执行
  • 处理
    • 准备上传文件
    • 配置HuggingFace仓库
    • 上传模型
  • 运行产出
    • 在 HuggingFace仓库中:
      • 上传的模型文件和配置
      • 模型卡片(model card)
      • 示例代码和使用说明

总结

enh1.sh是一个完整的语音增强处理流程脚本,包含了从数据准备到模型训练、评估的全过程。通过合理配置参数,可以灵活控制处理流程的各个环节。使用时需要特别注意:

  1. 必须提供训练集、验证集和测试集的名称
  2. 根据需求合理设置GPU数量和并行作业数
  3. 可以通过stage和stop_stage控制执行流程
  4. 评估阶段提供了多种评估方式,包括客观指标和ASR评估