🔀 Added speechSR48k

This commit is contained in:
itqop 2023-12-22 10:50:30 +03:00
parent bf83b4d33e
commit ddb6023afc
3 changed files with 374 additions and 3 deletions

speechSR48k/config.json (new file, +48 lines)

@@ -0,0 +1,48 @@
{
  "train": {
    "log_interval": 200,
    "eval_interval": 10000,
    "save_interval": 10000,
    "seed": 1234,
    "epochs": 20000,
    "learning_rate": 1e-4,
    "betas": [0.8, 0.99],
    "eps": 1e-9,
    "batch_size": 32,
    "fp16_run": false,
    "lr_decay": 0.995,
    "segment_size": 9600,
    "init_lr_ratio": 1,
    "warmup_epochs": 0,
    "c_mel": 45
  },
  "data": {
    "train_filelist_path": "filelists/train_48k_vctk_trim_bigvgan_sr.txt",
    "test_filelist_path": "filelists/test_48k_vctk_trim_bigvgan_sr.txt",
    "text_cleaners": ["english_cleaners2"],
    "max_wav_value": 32768.0,
    "sampling_rate": 48000,
    "filter_length": 1280,
    "hop_length": 320,
    "win_length": 1280,
    "n_mel_channels": 128,
    "mel_fmin": 0,
    "mel_fmax": 24000,
    "add_blank": true,
    "n_speakers": 0,
    "cleaned_text": true,
    "aug_rate": 1.0,
    "top_db": 20
  },
  "model": {
    "resblock": "0",
    "resblock_kernel_sizes": [3,7,11],
    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
    "upsample_rates": [3],
    "upsample_initial_channel": 32,
    "upsample_kernel_sizes": [3],
    "use_spectral_norm": false
  }
}
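A quick sanity check on how these values fit together: segment_size 9600 samples at hop_length 320 gives 30 frames per training segment, and the single upsample rate of 3 matches the 16 kHz to 48 kHz ratio used at inference. Below is a minimal sketch of reading this file into attribute-style hyperparameters; the load_hparams helper is illustrative only and not part of the repo, which has its own utils module.

import json
from types import SimpleNamespace

def load_hparams(path="speechSR48k/config.json"):
    # Parse the JSON above into nested attribute-style namespaces,
    # so values can be read as h.data.hop_length, h.model.upsample_rates, etc.
    with open(path) as f:
        return json.load(f, object_hook=lambda d: SimpleNamespace(**d))

h = load_hparams()
frames_per_segment = h.train.segment_size // h.data.hop_length   # 9600 // 320 = 30
print(frames_per_segment, h.model.upsample_rates)                # 30 [3]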

speechSR48k/speechsr.py (new file, +252 lines)

@@ -0,0 +1,252 @@
import torch
from torch import nn
from torch.nn import functional as F
import modules
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from commons import init_weights, get_padding
from torch.cuda.amp import autocast
import torchaudio
from einops import rearrange
from alias_free_torch import *
import activations
class AMPBlock0(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
        super(AMPBlock0, self).__init__()
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2]))),
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
        ])
        self.convs2.apply(init_weights)

        self.num_layers = len(self.convs1) + len(self.convs2)  # total number of conv layers

        self.activations = nn.ModuleList([
            Activation1d(
                activation=activations.SnakeBeta(channels, alpha_logscale=True))
            for _ in range(self.num_layers)
        ])

    def forward(self, x):
        acts1, acts2 = self.activations[::2], self.activations[1::2]
        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
            xt = a1(x)
            xt = c1(xt)
            xt = a2(xt)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)

class Generator(torch.nn.Module):
    def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        self.conv_pre = weight_norm(Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3))
        resblock = AMPBlock0

        self.resblocks = nn.ModuleList()
        for i in range(1):
            ch = upsample_initial_channel//(2**(i))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock(ch, k, d, activation="snakebeta"))

        activation_post = activations.SnakeBeta(ch, alpha_logscale=True)
        self.activation_post = Activation1d(activation=activation_post)

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            # upsample the waveform by a fixed factor of 3 via linear interpolation
            # (16 kHz -> 48 kHz for this config), then refine with the AMP resblocks
            x = F.interpolate(x, int(x.shape[-1] * 3), mode='linear')
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x)
            x = xs / self.num_kernels
        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.resblocks:
            l.remove_weight_norm()

class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap

class DiscriminatorR(torch.nn.Module):
    def __init__(self, resolution, use_spectral_norm=False):
        super(DiscriminatorR, self).__init__()

        norm_f = weight_norm if use_spectral_norm == False else spectral_norm

        n_fft, hop_length, win_length = resolution
        self.spec_transform = torchaudio.transforms.Spectrogram(
            n_fft=n_fft, hop_length=hop_length, win_length=win_length, window_fn=torch.hann_window,
            normalized=True, center=False, pad_mode=None, power=None)

        self.convs = nn.ModuleList([
            norm_f(nn.Conv2d(2, 32, (3, 9), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), dilation=(2,1), padding=(2, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), dilation=(4,1), padding=(4, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
        ])
        self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))

    def forward(self, y):
        fmap = []

        x = self.spec_transform(y)              # complex spectrogram [B, 1, Freq, Frames]
        x = torch.cat([x.real, x.imag], dim=1)  # -> [B, 2, Freq, Frames]
        x = rearrange(x, 'b c w t -> b c t w')

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap

class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2,3,5,7,11]
        resolutions = [[4096, 1024, 4096], [2048, 512, 2048], [1024, 256, 1024], [512, 128, 512], [256, 64, 256], [128, 32, 128]]

        discs = [DiscriminatorR(resolutions[i], use_spectral_norm=use_spectral_norm) for i in range(len(resolutions))]
        discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training
    """

    def __init__(self,
                 spec_channels,
                 segment_size,
                 resblock,
                 resblock_kernel_sizes,
                 resblock_dilation_sizes,
                 upsample_rates,
                 upsample_initial_channel,
                 upsample_kernel_sizes,
                 **kwargs):
        super().__init__()
        self.spec_channels = spec_channels
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size

        self.dec = Generator(1, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes)

    def forward(self, x):
        y = self.dec(x)
        return y

    @torch.no_grad()
    def infer(self, x, max_len=None):
        o = self.dec(x[:,:,:max_len])
        return o
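For orientation, a minimal usage sketch of the class above with randomly initialized weights (no checkpoint), using the hyperparameters from speechSR48k/config.json. It assumes the repo-local modules imported at the top of speechsr.py (modules, commons, activations, alias_free_torch) are importable; the tensor names are illustrative.

import torch
from speechSR48k.speechsr import SynthesizerTrn

# Maps a [B, 1, T] 16 kHz waveform to a [B, 1, 3*T] 48 kHz waveform.
model = SynthesizerTrn(
    spec_channels=128,        # n_mel_channels from config.json
    segment_size=30,          # segment_size // hop_length = 9600 // 320
    resblock="0",
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[3],
    upsample_initial_channel=32,
    upsample_kernel_sizes=[3],
).eval()

wav_16k = torch.randn(1, 1, 16000)   # one second of 16 kHz audio
with torch.no_grad():
    wav_48k = model(wav_16k)         # linear 3x upsample + AMP resblocks
print(wav_48k.shape)                 # torch.Size([1, 1, 48000])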

(third changed file: FastAPI app, +74 additions, -3 deletions)

@@ -6,6 +6,15 @@ from modules import Translate
from modules import Video
import os
import shutil
import os
import torch
import argparse
import numpy as np
from scipy.io.wavfile import write
import torchaudio
import utils
from speechSR48k.speechsr import SynthesizerTrn as SpeechSR48
app = FastAPI()
@@ -23,6 +32,66 @@ sst = SST()
tts = TTS()
translator = Translate()
video_manager = Video()
synthesizer = SynthesizerTrn()
seed = 1111
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

def get_param_num(model):
    num_param = sum(param.numel() for param in model.parameters())
    return num_param

def SuperResolution(a, hierspeech):
    speechsr = hierspeech
    os.makedirs(a.output_dir, exist_ok=True)

    # Prompt load
    audio, sample_rate = torchaudio.load(a.input_speech)

    # support only single channel
    audio = audio[:1,:]
    # Resampling to the 16 kHz input rate expected by the SR model
    if sample_rate != 16000:
        audio = torchaudio.functional.resample(audio, sample_rate, 16000, resampling_method="kaiser_window")
    file_name = os.path.splitext(os.path.basename(a.input_speech))[0]

    with torch.no_grad():
        converted_audio = speechsr(audio.unsqueeze(1).cuda())
        converted_audio = converted_audio.squeeze()
        converted_audio = converted_audio / (torch.abs(converted_audio).max()) * 0.999 * 32767.0
        converted_audio = converted_audio.cpu().numpy().astype('int16')

    file_name2 = "{}.wav".format(file_name)
    output_file = os.path.join(a.output_dir, file_name2)
    write(output_file, 48000, converted_audio)
    return output_file

# Note: h_sr48 / h_sr (hparams) and SpeechSR24 are referenced below but not defined
# in this diff; they are expected to be provided elsewhere in the module.
def model_load(a):
    if a.output_sr == 48000:
        speechsr = SpeechSR48(h_sr48.data.n_mel_channels,
                              h_sr48.train.segment_size // h_sr48.data.hop_length,
                              **h_sr48.model).cuda()
        utils.load_checkpoint(a.ckpt_sr48, speechsr, None)
        speechsr.eval()
    else:
        # 24000 Hz
        speechsr = SpeechSR24(h_sr.data.n_mel_channels,
                              h_sr.train.segment_size // h_sr.data.hop_length,
                              **h_sr.model).cuda()
        utils.load_checkpoint(a.ckpt_sr, speechsr, None)
        speechsr.eval()
    return speechsr

def inference_sr(a):
    speechsr = model_load(a)
    return SuperResolution(a, speechsr)
@app.post("/process_video/") @app.post("/process_video/")
async def process_video(video_file: UploadFile = File(...)): async def process_video(video_file: UploadFile = File(...)):
@ -37,10 +106,12 @@ async def process_video(video_file: UploadFile = File(...)):
translated_text = await translator.translate_text(final_result, source_lang="en", target_lang="ru")
- text_speaker_tuples = [(translated_text, 1)]
+ text_speaker_tuples = [(translated_text, (1,4))]
- await tts.batch_text_to_speech(text_speaker_tuples, output_folder=OUTPUT_FOLDER)
+ audio = await tts.batch_text_to_speech(text_speaker_tuples, output_folder=OUTPUT_FOLDER)
- output_video_path = os.path.join(VIDEO_FOLDER, f"{os.path.splitext(video_file.filename)[0]}_processed.mp4")
+ path_audio = inference_sr(audio)
+ output_video_path = os.path.join(path_audio, VIDEO_FOLDER, f"{os.path.splitext(video_file.filename)[0]}_processed.mp4")
video_clip = await video_manager.load_video_from_path(video_path)
for start, end in zip(vad_timing[::2], vad_timing[1::2]):
    await video_manager.replace_audio_in_range(video_clip, os.path.join(OUTPUT_FOLDER, "output_1.wav"), start, end)
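SuperResolution and model_load read attribute-style arguments (a.input_speech, a.output_dir, a.output_sr, a.ckpt_sr48 / a.ckpt_sr), so whatever object the endpoint hands to inference_sr must expose those fields. A minimal standalone sketch is below; the file and checkpoint paths are placeholders, not files shipped with this commit, and it assumes a CUDA device plus the h_sr48 hparams used in model_load.

from types import SimpleNamespace

# Illustrative argument object for inference_sr; field names mirror the
# attributes accessed in SuperResolution/model_load above.
sr_args = SimpleNamespace(
    input_speech="output/output_1.wav",       # 16 kHz TTS output to be upsampled (placeholder)
    output_dir="output/sr48k",                # where the 48 kHz wav is written (placeholder)
    output_sr=48000,                          # selects the SpeechSR48 branch in model_load
    ckpt_sr48="checkpoints/speechsr48k.pth",  # placeholder checkpoint path
)

path_audio = inference_sr(sr_args)            # returns the path of the written 48 kHz wav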