diff --git a/speechSR48k/config.json b/speechSR48k/config.json
new file mode 100644
index 0000000..1f60608
--- /dev/null
+++ b/speechSR48k/config.json
@@ -0,0 +1,48 @@
+
+{
+  "train": {
+    "log_interval": 200,
+    "eval_interval": 10000,
+    "save_interval": 10000,
+    "seed": 1234,
+    "epochs": 20000,
+    "learning_rate": 1e-4,
+    "betas": [0.8, 0.99],
+    "eps": 1e-9,
+    "batch_size": 32,
+    "fp16_run": false,
+    "lr_decay": 0.995,
+    "segment_size": 9600,
+    "init_lr_ratio": 1,
+    "warmup_epochs": 0,
+    "c_mel": 45
+  },
+  "data": {
+    "train_filelist_path": "filelists/train_48k_vctk_trim_bigvgan_sr.txt",
+    "test_filelist_path": "filelists/test_48k_vctk_trim_bigvgan_sr.txt",
+    "text_cleaners": ["english_cleaners2"],
+    "max_wav_value": 32768.0,
+    "sampling_rate": 48000,
+    "filter_length": 1280,
+    "hop_length": 320,
+    "win_length": 1280,
+    "n_mel_channels": 128,
+    "mel_fmin": 0,
+    "mel_fmax": 24000,
+    "add_blank": true,
+    "n_speakers": 0,
+    "cleaned_text": true,
+    "aug_rate": 1.0,
+    "top_db": 20
+  },
+  "model": {
+    "resblock": "0",
+    "resblock_kernel_sizes": [3,7,11],
+    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+    "upsample_rates": [3],
+    "upsample_initial_channel": 32,
+    "upsample_kernel_sizes": [3],
+    "use_spectral_norm": false
+  }
+}
+ 
\ No newline at end of file
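Note on the config: `segment_size // hop_length` (9600 // 320 = 30 frames) is exactly what `start.py` passes to the model below, and the single upsample rate of 3 is what maps 16 kHz input to the 48 kHz `sampling_rate`. A minimal stand-alone loader sketch for reference (plain `json` and `SimpleNamespace` are assumptions here; the project itself presumably uses its VITS-style `utils` helper):

```python
import json
from types import SimpleNamespace

# Load speechSR48k/config.json into attribute-style namespaces so the values
# can be read the same way start.py reads h_sr48 (h.train.*, h.data.*, h.model.*).
with open("speechSR48k/config.json") as f:
    h = json.load(f, object_hook=lambda d: SimpleNamespace(**d))

print(h.data.sampling_rate)                       # 48000 (target rate)
print(h.train.segment_size // h.data.hop_length)  # 30 frames per training segment
print(h.model.upsample_rates)                     # [3] -> 16 kHz * 3 = 48 kHz
```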
diff --git a/speechSR48k/speechsr.py b/speechSR48k/speechsr.py
new file mode 100644
index 0000000..0d2f013
--- /dev/null
+++ b/speechSR48k/speechsr.py
@@ -0,0 +1,252 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+import modules
+
+from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+from commons import init_weights, get_padding
+from torch.cuda.amp import autocast
+import torchaudio
+from einops import rearrange
+
+from alias_free_torch import *
+import activations
+
+class AMPBlock0(torch.nn.Module):
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
+        super(AMPBlock0, self).__init__()
+
+        self.convs1 = nn.ModuleList([
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                               padding=get_padding(kernel_size, dilation[0]))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                               padding=get_padding(kernel_size, dilation[1]))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+                               padding=get_padding(kernel_size, dilation[2]))),
+        ])
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList([
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1))),
+        ])
+        self.convs2.apply(init_weights)
+
+        self.num_layers = len(self.convs1) + len(self.convs2)  # total number of conv layers
+
+        self.activations = nn.ModuleList([
+            Activation1d(
+                activation=activations.SnakeBeta(channels, alpha_logscale=True))
+            for _ in range(self.num_layers)
+        ])
+
+    def forward(self, x):
+        acts1, acts2 = self.activations[::2], self.activations[1::2]
+        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
+            xt = a1(x)
+            xt = c1(xt)
+            xt = a2(xt)
+            xt = c2(xt)
+            x = xt + x
+
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+
+class Generator(torch.nn.Module):
+    def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
+        super(Generator, self).__init__()
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+
+        self.conv_pre = weight_norm(Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3))
+        resblock = AMPBlock0  # the resblock argument is overridden; AMPBlock0 is always used
+
+        self.resblocks = nn.ModuleList()
+        for i in range(1):
+            ch = upsample_initial_channel//(2**(i))
+            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+                self.resblocks.append(resblock(ch, k, d, activation="snakebeta"))
+
+        activation_post = activations.SnakeBeta(ch, alpha_logscale=True)
+        self.activation_post = Activation1d(activation=activation_post)
+
+        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+        if gin_channels != 0:
+            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
+
+    def forward(self, x, g=None):
+        x = self.conv_pre(x)
+        if g is not None:
+            x = x + self.cond(g)
+
+        for i in range(self.num_upsamples):
+            x = F.interpolate(x, int(x.shape[-1] * 3), mode='linear')  # hard-coded 3x temporal upsampling (16 kHz -> 48 kHz)
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i*self.num_kernels+j](x)
+                else:
+                    xs += self.resblocks[i*self.num_kernels+j](x)
+            x = xs / self.num_kernels
+
+        x = self.activation_post(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        print('Removing weight norm...')
+        for l in self.resblocks:
+            l.remove_weight_norm()
+
+class DiscriminatorP(torch.nn.Module):
+    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+        super(DiscriminatorP, self).__init__()
+        self.period = period
+        self.use_spectral_norm = use_spectral_norm
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        self.convs = nn.ModuleList([
+            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
+        ])
+        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x):
+        fmap = []
+
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, modules.LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+class DiscriminatorR(torch.nn.Module):
+    def __init__(self, resolution, use_spectral_norm=False):
+        super(DiscriminatorR, self).__init__()
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+
+        n_fft, hop_length, win_length = resolution
+        self.spec_transform = torchaudio.transforms.Spectrogram(
+            n_fft=n_fft, hop_length=hop_length, win_length=win_length, window_fn=torch.hann_window,
+            normalized=True, center=False, pad_mode=None, power=None)
+
+        self.convs = nn.ModuleList([
+            norm_f(nn.Conv2d(2, 32, (3, 9), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
+            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), dilation=(2,1), padding=(2, 4))),
+            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), dilation=(4,1), padding=(4, 4))),
+            norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
+        ])
+        self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
+
+    def forward(self, y):
+        fmap = []
+
+        x = self.spec_transform(y)  # complex spectrogram, [B, 1, F, T]
+        x = torch.cat([x.real, x.imag], dim=1)  # [B, 2, F, T], real and imaginary parts as channels
+        x = rearrange(x, 'b c w t -> b c t w')
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, modules.LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class MultiPeriodDiscriminator(torch.nn.Module):
+    def __init__(self, use_spectral_norm=False):
+        super(MultiPeriodDiscriminator, self).__init__()
+        periods = [2,3,5,7,11]
+        resolutions = [[4096, 1024, 4096], [2048, 512, 2048], [1024, 256, 1024], [512, 128, 512], [256, 64, 256], [128, 32, 128]]
+
+        discs = [DiscriminatorR(resolutions[i], use_spectral_norm=use_spectral_norm) for i in range(len(resolutions))]
+        discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
+        self.discriminators = nn.ModuleList(discs)
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            y_d_gs.append(y_d_g)
+            fmap_rs.append(fmap_r)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+class SynthesizerTrn(nn.Module):
+    """
+    Synthesizer for Training
+    """
+
+    def __init__(self,
+        spec_channels,
+        segment_size,
+        resblock,
+        resblock_kernel_sizes,
+        resblock_dilation_sizes,
+        upsample_rates,
+        upsample_initial_channel,
+        upsample_kernel_sizes,
+        **kwargs):
+
+        super().__init__()
+        self.spec_channels = spec_channels
+        self.resblock = resblock
+        self.resblock_kernel_sizes = resblock_kernel_sizes
+        self.resblock_dilation_sizes = resblock_dilation_sizes
+        self.upsample_rates = upsample_rates
+        self.upsample_initial_channel = upsample_initial_channel
+        self.upsample_kernel_sizes = upsample_kernel_sizes
+        self.segment_size = segment_size
+
+        self.dec = Generator(1, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes)
+
+    def forward(self, x):
+        y = self.dec(x)
+        return y
+
+    @torch.no_grad()
+    def infer(self, x, max_len=None):
+        o = self.dec(x[:,:,:max_len])
+        return o
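A minimal sanity-check sketch for the module added above, instantiated with the values from `speechSR48k/config.json`; it assumes the repo-local dependencies (`modules`, `commons`, `activations`, `alias_free_torch`) are importable and runs on CPU:

```python
import torch
from speechSR48k.speechsr import SynthesizerTrn

# Arguments mirror speechSR48k/config.json; segment_size is given in frames
# (9600 samples // 320 hop), the same expression start.py uses below.
model = SynthesizerTrn(
    spec_channels=128,
    segment_size=9600 // 320,
    resblock="0",
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[3],
    upsample_initial_channel=32,
    upsample_kernel_sizes=[3],
)
model.eval()

wav_16k = torch.randn(1, 1, 16000)  # one second of 16 kHz audio, shape [B, 1, T]
wav_48k = model.infer(wav_16k)      # Generator interpolates the time axis by 3x
print(wav_48k.shape)                # expected: torch.Size([1, 1, 48000])
```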
diff --git a/start.py b/start.py
index 4805ba4..776e5ca 100644
--- a/start.py
+++ b/start.py
@@ -6,6 +6,14 @@ from modules import Translate
 from modules import Video
 import os
 import shutil
+import torch
+import argparse
+import numpy as np
+from scipy.io.wavfile import write
+import torchaudio
+import utils
+
+from speechSR48k.speechsr import SynthesizerTrn as SpeechSR48
 
 app = FastAPI()
 
@@ -23,6 +31,67 @@ sst = SST()
 tts = TTS()
 translator = Translate()
 video_manager = Video()
+# Load the 48 kHz SR hyperparameters (assumes the VITS-style loader in utils.py).
+h_sr48 = utils.get_hparams_from_file('speechSR48k/config.json')
+
+seed = 1111
+torch.manual_seed(seed)
+torch.cuda.manual_seed(seed)
+np.random.seed(seed)
+
+def get_param_num(model):
+    num_param = sum(param.numel() for param in model.parameters())
+    return num_param
+
+def SuperResolution(a, hierspeech):
+
+    speechsr = hierspeech
+
+    os.makedirs(a.output_dir, exist_ok=True)
+
+    # Prompt load
+    audio, sample_rate = torchaudio.load(a.input_speech)
+
+    # support only single channel
+    audio = audio[:1,:]
+    # Resampling
+    if sample_rate != 16000:
+        audio = torchaudio.functional.resample(audio, sample_rate, 16000, resampling_method="kaiser_window")
+    file_name = os.path.splitext(os.path.basename(a.input_speech))[0]
+    with torch.no_grad():
+        converted_audio = speechsr(audio.unsqueeze(1).cuda())
+        converted_audio = converted_audio.squeeze()
+        converted_audio = converted_audio / (torch.abs(converted_audio).max()) * 0.999 * 32767.0
+        converted_audio = converted_audio.cpu().numpy().astype('int16')
+
+    file_name2 = "{}.wav".format(file_name)
+    output_file = os.path.join(a.output_dir, file_name2)
+
+    write(output_file, 48000, converted_audio)
+    return output_file
+
+
+def model_load(a):
+    if a.output_sr == 48000:
+        speechsr = SpeechSR48(h_sr48.data.n_mel_channels,
+            h_sr48.train.segment_size // h_sr48.data.hop_length,
+            **h_sr48.model).cuda()
+        utils.load_checkpoint(a.ckpt_sr48, speechsr, None)
+        speechsr.eval()
+    else:
+        # 24000 Hz (requires SpeechSR24 and h_sr from the 24 kHz setup, not included in this diff)
+        speechsr = SpeechSR24(h_sr.data.n_mel_channels,
+            h_sr.train.segment_size // h_sr.data.hop_length,
+            **h_sr.model).cuda()
+        utils.load_checkpoint(a.ckpt_sr, speechsr, None)
+        speechsr.eval()
+    return speechsr
+
+def inference_sr(a):
+
+    speechsr = model_load(a)
+    return SuperResolution(a, speechsr)
+
 
 @app.post("/process_video/")
 async def process_video(video_file: UploadFile = File(...)):
@@ -37,10 +106,16 @@ async def process_video(video_file: UploadFile = File(...)):
 
     translated_text = await translator.translate_text(final_result, source_lang="en", target_lang="ru")
 
-    text_speaker_tuples = [(translated_text, 1)]
-    await tts.batch_text_to_speech(text_speaker_tuples, output_folder=OUTPUT_FOLDER)
+    text_speaker_tuples = [(translated_text, (1,4))]
+    audio = await tts.batch_text_to_speech(text_speaker_tuples, output_folder=OUTPUT_FOLDER)
+
+    # Upsample the 16 kHz TTS output to 48 kHz before it goes back into the video.
+    # NOTE: assumes batch_text_to_speech returns the generated wav path; ckpt_sr48 is a placeholder path.
+    sr_args = argparse.Namespace(input_speech=audio, output_dir=OUTPUT_FOLDER,
+                                 output_sr=48000, ckpt_sr48="speechSR48k/G_48k.pth")
+    path_audio = inference_sr(sr_args)
 
     output_video_path = os.path.join(VIDEO_FOLDER, f"{os.path.splitext(video_file.filename)[0]}_processed.mp4")
     video_clip = await video_manager.load_video_from_path(video_path)
     for start, end in zip(vad_timing[::2], vad_timing[1::2]):
-        await video_manager.replace_audio_in_range(video_clip, os.path.join(OUTPUT_FOLDER, "output_1.wav"), start, end)
+        await video_manager.replace_audio_in_range(video_clip, path_audio, start, end)
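For completeness, a hypothetical CLI harness around the functions added to `start.py`; the flag names mirror the attributes read by `model_load()`/`SuperResolution()`, the default checkpoint path is a placeholder, and importing `start` assumes its module-level FastAPI and config setup is acceptable as a side effect:

```python
# Hypothetical standalone entry point for 16 kHz -> 48 kHz speech super-resolution.
import argparse

from start import inference_sr  # pulls in the FastAPI app and SR config as side effects

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Speech super-resolution to 48 kHz")
    parser.add_argument("--input_speech", required=True, help="path to a 16 kHz mono wav")
    parser.add_argument("--output_dir", default="output")
    parser.add_argument("--output_sr", type=int, default=48000, choices=[24000, 48000])
    parser.add_argument("--ckpt_sr48", default="speechSR48k/G_48k.pth",
                        help="placeholder path to a SpeechSR-48k checkpoint")
    args = parser.parse_args()

    print(inference_sr(args))  # prints the path of the written 48 kHz wav
```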