官网提供的都是手动训练流程。我做了一个自动化训练脚本,执行一次即可自动跑完全部训练步骤。
说明:
audio是存放原始音频的位置,auto_train_main是核心自动化代码。
auto_train_main代码:
# -*- coding: utf-8 -*-
import string
import random
import requests
import pymysql
import sys
import os
import shutil
import subprocess
import paramiko
#-1代表错误,0代表警告提示,1代表执行成功
from run import slicer_fn
from run_auto_label import training_model
from basemodel_weitiao import weitiao
def con_mysql():
    """Open a new MySQL connection through ``pymysql.connect``.

    The host, credentials, port and database name are placeholders that must
    be replaced with real values before the script can run.

    Returns:
        The connection object returned by ``pymysql.connect``.
    """
    settings = dict(
        host="xxx",
        user="xxx",
        password="xxx",
        port=xx,  # placeholder name: substitute the real integer port
        db="xx",
        charset="utf8",
    )
    return pymysql.connect(**settings)
# Check whether a task is already running, then register or resume this user's task
def get_task(conn, user_id, audio_url):
    """Decide whether training may start for ``user_id`` and record the task.

    Args:
        conn: Open database connection providing ``cursor()`` and ``commit()``.
        user_id: Identifier of the user requesting training.
        audio_url: URL of the user's source audio.

    Returns:
        A dict with keys ``"code"`` and ``"error_msg"``:
        ``code`` 0 with a warning when any row has ``status=1`` or when the
        user's existing row carries the value 2 in its fourth column;
        otherwise the result of ``update_task`` (existing row) or ``code`` 1
        after inserting a new row with status 1.
    """
    busy_msg = "warn:【0.获取task警告】:目前机器正在被人训练...请稍后在来"
    db_cursor = conn.cursor()

    # A row with status=1 means the machine is occupied (to become a queue later).
    db_cursor.execute("SELECT * FROM kantts_auto_train_task where status=1")
    if db_cursor.fetchall():
        print(busy_msg)
        return {"code": 0, "error_msg": busy_msg}

    db_cursor.execute(
        "SELECT * FROM kantts_auto_train_task where user_id=%s", [user_id]
    )
    rows = db_cursor.fetchall()

    if not rows:
        # No row for this user yet: insert a fresh task marked as running.
        db_cursor.execute(
            "INSERT INTO "
            "kantts_auto_train_task(user_id,audio_url,status) "
            "VALUES(%s,%s,%s)",
            [user_id, audio_url, 1],
        )
        conn.commit()
        return {"code": 1, "error_msg": ""}

    # NOTE(review): the status is read positionally from ``SELECT *``; this
    # presumes it is the table's fourth column - confirm against the schema.
    stored_status = rows[0][3]
    if stored_status == 2:
        return {
            "code": 0,
            "error_msg": "warn:【0.获取task警告】:用户已经训练过模型了",
        }

    # Existing, unfinished task: mark it as running again and continue.
    return update_task(conn, user_id, audio_url, 1, "")
def update_task(conn, user_id, audio_url, status, error_msg):
    """Overwrite the task row of ``user_id`` with a new URL, status and message.

    Args:
        conn: Open database connection providing ``cursor()`` and ``commit()``.
        user_id: User whose row is updated (used in both SET and WHERE).
        audio_url: Audio URL to store.
        status: Status value to store.
        error_msg: Message to store.

    Returns:
        Always ``{"code": 1, "error_msg": ""}``.
    """
    statement = (
        "update kantts_auto_train_task set "
        "user_id = %s ,audio_url = %s, status =%s , error_msg=%s where user_id = %s"
    )
    params = [user_id, audio_url, status, error_msg, user_id]
    db_cursor = conn.cursor()
    db_cursor.execute(statement, params)
    conn.commit()
    return {"code": 1, "error_msg": ""}
# Warnings leave the status untouched and only update the stored message
def update_task_warn(conn, user_id, error_msg):
    """Store a warning message on the task row of ``user_id``.

    Unlike ``update_task`` this statement does not write the status or the
    audio URL columns.

    Args:
        conn: Open database connection providing ``cursor()`` and ``commit()``.
        user_id: User whose row is updated (used in both SET and WHERE).
        error_msg: Warning text to store.

    Returns:
        Always ``{"code": 1, "error_msg": ""}``.
    """
    statement = (
        "update kantts_auto_train_task set "
        "user_id = %s , error_msg=%s where user_id = %s"
    )
    params = [user_id, error_msg, user_id]
    db_cursor = conn.cursor()
    db_cursor.execute(statement, params)
    conn.commit()
    return {"code": 1, "error_msg": ""}
# 获取训练的音频数据
# Remove everything inside a directory (the directory itself is kept)
def deletePathFile(path):
    """Delete every entry found directly under ``path``.

    Regular files and symlinks are unlinked; sub-directories are removed
    recursively. A failure on one entry is printed and the loop continues.

    Args:
        path: Directory whose contents are removed. It must exist, because
            ``os.listdir`` is called on it outside the ``try`` block.
    """
    for entry_name in os.listdir(path):
        target = os.path.join(path, entry_name)
        try:
            removable_as_file = os.path.isfile(target) or os.path.islink(target)
            if removable_as_file:
                os.unlink(target)
            elif os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as err:
            print('Failed to delete %s. Reason: %s' % (target, err))
    # NOTE(review): this line is printed even when some deletions failed above.
    print('Successfully deleted all content from directory %s' % path)
def downloadAudio(audio_url):
    """Download a ``.wav`` file to ``audio/source_audio.wav``.

    Args:
        audio_url: URL of the audio; it must end with ``.wav``.

    Returns:
        ``{"code": 1, "error_msg": ""}`` on success, or
        ``{"code": -1, "error_msg": <reason>}`` when the URL is not a wav, the
        HTTP request fails or returns an error status, or the file cannot be
        written. Failures are reported through the dict rather than raised so
        the caller can store them like every other step result.
    """
    if not audio_url.endswith(".wav"):
        error_msg = "error:【1.获取音频错误】:音频必须为wav"
        print(error_msg)
        return {"code": -1, "error_msg": error_msg}

    audio_name = 'audio/source_audio.wav'
    try:
        # (connect, read) timeouts keep a stalled server from hanging the run;
        # the context manager releases the connection on every path.
        with requests.get(audio_url, stream=True, timeout=(10, 60)) as response:
            # Without this check a 404/500 body would be saved as the "wav".
            response.raise_for_status()
            with open(audio_name, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
    except (requests.RequestException, OSError) as exc:
        error_msg = "error:【1.获取音频错误】:音频下载失败: %s" % exc
        print(error_msg)
        return {"code": -1, "error_msg": error_msg}
    return {"code": 1, "error_msg": ""}
def random_string(length):
    """Return ``length`` random characters drawn from ASCII letters and digits.

    Characters are picked one at a time with ``random.choice``; the result is
    not suitable as a security token.
    """
    alphabet = string.ascii_letters + string.digits
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return ''.join(picked)
def checkRs(conn, task, user_id, audio_url):
    """Stop the script when a step result reports an error or a warning.

    Args:
        conn: Open database connection passed on to the update helpers.
        task: Step result dict with keys ``"code"`` and ``"error_msg"``.
        user_id: User whose task row is updated.
        audio_url: Audio URL written back on the error path.

    For ``code`` -1 the row is updated with status -1 and the message; for
    ``code`` 0 only the message is stored. In both cases ``sys.exit()`` is
    then called. Any other code returns ``None`` and the caller continues.
    """
    code = task["code"]
    if code not in (-1, 0):
        return
    if code == -1:
        # Error: persist status -1 together with the message.
        update_task(conn, user_id, audio_url, -1, task["error_msg"])
    else:
        # Warning: keep the stored status, only record the message.
        update_task_warn(conn, user_id, task["error_msg"])
    sys.exit()
# Audio slicing: prepare the per-user output directory
def create_split_mkdir(user_id, base_dir='/kan_tts/tmp/test_wavs/'):
    """Create an empty directory for the sliced audio of ``user_id``.

    Any existing directory for the user is removed first, so the result is
    always empty. Missing parent directories are created as well.

    Args:
        user_id: User identifier; it is converted with ``str`` and appended to
            ``base_dir``.
        base_dir: Directory that holds the per-user folders.

    Returns:
        The path of the freshly created directory.

    Raises:
        ValueError: If ``user_id`` is empty, because the target would then be
            ``base_dir`` itself and every user's data would be deleted.
    """
    folder_name = str(user_id)
    if not folder_name:
        raise ValueError("user_id must not be empty")
    # join(base_dir, '') only guarantees a trailing separator, so the default
    # yields exactly the same path as plain concatenation.
    # NOTE(review): user_id is not checked for path separators or ".." -
    # confirm it never comes from untrusted input before the rmtree below.
    test_path = os.path.join(base_dir, '') + folder_name
    if os.path.exists(test_path):
        # Drop results left over from an earlier run.
        shutil.rmtree(test_path)
    # makedirs (not mkdir) so a missing base directory does not fail the run.
    os.makedirs(test_path)
    return test_path
# Sync the trained model to the synthesis machine
def scp_file_path(local_path, remote_path):
    """Copy ``local_path`` recursively to the synthesis host with ``scp -r``.

    Args:
        local_path: Local file or directory to copy.
        remote_path: Destination path on host ``mqq@192.168.51.39``.

    Returns:
        ``{"code": 1, "error_msg": ""}`` when scp exits with status 0.

    Raises:
        subprocess.CalledProcessError: If scp exits with a non-zero status.
        OSError: If the ``scp`` executable cannot be started.
    """
    remote_target = "mqq@192.168.51.39:" + remote_path
    # check=True turns a failed copy into an exception; waiting on the raw pid
    # discarded the exit status, so failures went unnoticed.
    subprocess.run(["scp", "-r", local_path, remote_target], check=True)
    return {"code": 1, "error_msg": ""}
def _abort_step(conn, task, user_id, audio_url, error_msg, exc, hint=None):
    """Report a failed pipeline step, store it on the task row and stop.

    Args:
        conn: Open database connection.
        task: Result dict of the last step; it is marked with code -1.
        user_id: User whose task failed.
        audio_url: Audio URL of the task.
        error_msg: Message printed and stored for the user.
        exc: The exception that caused the failure.
        hint: Optional extra line printed after ``error_msg``.
    """
    print(error_msg)
    if hint is not None:
        print(hint)
    # Print the exception before checkRs: checkRs exits the process for
    # code -1, so anything placed after it never runs.
    print(exc)
    task["code"] = -1
    task["error_msg"] = error_msg
    checkRs(conn, task, user_id, audio_url)
    # Fallback in case checkRs returns instead of exiting.
    sys.exit(0)


if __name__ == '__main__':
    user_id = "xxx"
    audio_url = "https://xxx.wav"
    conn = con_mysql()
    try:
        print("接收到的参数是{\"user_id\":%s,\"audio_url\":%s}" % (user_id, audio_url))

        # Stop early when the machine is busy or the user already has a model.
        task = get_task(conn, user_id, audio_url)
        checkRs(conn, task, user_id, audio_url)
        print("开始执行任务.." + user_id)

        # Remove audio left over from a previous run.
        deletePathFile("audio")

        # Step 1: fetch the source audio.
        print("====1.开始获取音频")
        task = downloadAudio(audio_url)
        checkRs(conn, task, user_id, audio_url)
        print("====1.音频处理完成")

        # Step 2: slice everything under "audio" into the user's directory.
        print("====2.开始切分音频")
        test_path = create_split_mkdir(user_id)
        try:
            slicer_fn("audio", test_path)
        except Exception as e:
            _abort_step(conn, task, user_id, audio_url,
                        "error:【2.音频切分错误】,请检查你的音频提交音否正常", e)
        print("====2.完成切分音频")

        # Step 3: label the sliced audio.
        print("====3.开始进行标注")
        try:
            training_model(user_id)
        except Exception as e:
            _abort_step(conn, task, user_id, audio_url,
                        "error:【3.数据标注错误】,请检查你的切分音频路径", e)
        print("====3.标注完成")

        # Step 4: fine-tune the base model on the labelled data.
        print("====4.开始微调训练4000步,预计30分钟")
        dataset_id = "/kan_tts/tmp/output_dir/" + user_id
        pretrain_work_dir = "/kan_tts/tmp/pretrain_work_dir/" + user_id
        try:
            weitiao(dataset_id, pretrain_work_dir)
        except Exception as e:
            _abort_step(conn, task, user_id, audio_url,
                        "error:【4.微调训练错误】,请检查是否音频质量", e)
        print("====4,完成微调")

        # Step 5: copy the trained model to the synthesis machine.
        print("====5,开始往机器同步")
        local_path = "/kan_tts/tmp/pretrain_work_dir/" + user_id
        remote_path = "/pzk/ttsGuaZai/tmp/pretrain_work_dir"
        try:
            # The return value is not stored in ``task`` so the step result
            # dict stays intact; a failure is only detected here if the call
            # raises.
            scp_file_path(local_path, remote_path)
        except Exception as e:
            _abort_step(
                conn, task, user_id, audio_url,
                "error:【5.同步到合成机器错误】,请检查远程目录以及本地目录的predict_dir是否存在此用户模型",
                e,
                hint="检查是否开启ssh免密https://blog.csdn.net/u010044182/article/details/128664248",
            )

        # Step 6: production database configuration (not implemented yet).
        print("====6,配置数据库-至正式服")
        # TODO: configure the audio database
        # TODO: record this training run in the production database
        print("====6,配置数据库完成")
    finally:
        # Release the connection on every exit path, including sys.exit().
        conn.close()
其他代码就是 ModelScope 官网的代码;切分代码请看我之前的博客,里面有。