# speech-api/start.py
#
# Repository-viewer residue (not part of the program):
#   129 lines, 4.1 KiB, Python
#   blame timestamp: 2023-12-22 00:43:45 +01:00
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse
from modules import SST
from modules import TTS
from modules import Translate
from modules import Video
import os
import shutil
# (blame timestamp: 2023-12-22 08:50:30 +01:00)
import os
import torch
import argparse
import numpy as np
from scipy.io.wavfile import write
import torchaudio
import utils
from speechSR48k.speechsr import SynthesizerTrn as SpeechSR48
# (blame timestamp: 2023-12-22 00:43:45 +01:00)
# FastAPI application object; route handlers below register themselves on it.
app = FastAPI()

# Working directories, relative to the process working directory.
UPLOAD_FOLDER = "uploads"
OUTPUT_FOLDER = "output"
AUDIO_FOLDER = "audio"
VIDEO_FOLDER = "video"

# Create every working directory up front so the request handler can assume
# they exist.
for _folder in (UPLOAD_FOLDER, OUTPUT_FOLDER, AUDIO_FOLDER, VIDEO_FOLDER):
    os.makedirs(_folder, exist_ok=True)

# Pipeline components, instantiated once at import time and reused by the
# request handler.
sst = SST()
tts = TTS()
translator = Translate()
video_manager = Video()

# NOTE: the super-resolution network is deliberately NOT instantiated here.
# The class is imported under the alias ``SpeechSR48`` (the bare name
# ``SynthesizerTrn`` is not bound in this module) and, as model_load() shows,
# it needs hyper-parameters and a checkpoint; model_load() builds it on demand.

# Fix the random seeds of torch (CPU and CUDA) and NumPy.
seed = 1111
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
def get_param_num(model):
    """Return the total number of elements over all of ``model``'s parameters.

    ``model`` must expose ``parameters()`` yielding objects that provide
    ``numel()``; the individual counts are added up.
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total
def SuperResoltuion(a, hierspeech):
    """Run a speech super-resolution model on one audio file and save a WAV.

    Args:
        a: Options object. ``a.input_speech`` is the path of the audio file
            to process and ``a.output_dir`` the directory that receives the
            result (created if missing). ``a.output_sr`` is read when present
            to choose the sample rate written to the WAV file; it defaults
            to 48000 when absent.
        hierspeech: Callable model. It is invoked on the mono, 16 kHz input
            moved to the CUDA device.

    Returns:
        Path of the written file, ``<a.output_dir>/<input file stem>.wav``.
    """
    speechsr = hierspeech
    os.makedirs(a.output_dir, exist_ok=True)

    # Load the input; only a single channel is supported, so keep the first.
    audio, sample_rate = torchaudio.load(a.input_speech)
    audio = audio[:1, :]

    # The model is fed 16 kHz audio; resample anything else.
    if sample_rate != 16000:
        audio = torchaudio.functional.resample(
            audio, sample_rate, 16000, resampling_method="kaiser_window"
        )

    with torch.no_grad():
        converted_audio = speechsr(audio.unsqueeze(1).cuda())
        converted_audio = converted_audio.squeeze()

        # Peak-normalise to just below int16 full scale. An all-zero output
        # has peak 0; dividing by it would produce NaN samples, so the
        # division is skipped and the output stays silent.
        peak = torch.abs(converted_audio).max()
        if peak > 0:
            converted_audio = converted_audio / peak
        converted_audio = converted_audio * 0.999 * 32767.0
        converted_audio = converted_audio.cpu().numpy().astype('int16')

    # model_load() builds the 48 kHz model only for output_sr == 48000 and a
    # 24 kHz model otherwise, so the WAV header must follow the same rule.
    output_sr = 48000 if getattr(a, "output_sr", 48000) == 48000 else 24000

    file_name = os.path.splitext(os.path.basename(a.input_speech))[0]
    output_file = os.path.join(a.output_dir, "{}.wav".format(file_name))
    write(output_file, output_sr, converted_audio)
    return output_file
def model_load(a):
    """Build the super-resolution network, load its checkpoint, return it in eval mode.

    ``a.output_sr == 48000`` selects the 48 kHz model (``SpeechSR48`` with the
    ``h_sr48`` hyper-parameters and checkpoint ``a.ckpt_sr48``); any other
    value selects the 24 kHz model (``SpeechSR24`` with ``h_sr`` and
    ``a.ckpt_sr``). The network is moved to the CUDA device before the
    checkpoint is loaded.

    NOTE(review): ``h_sr48``, ``h_sr`` and ``SpeechSR24`` are looked up as
    module globals at call time and are not bound in the part of the file
    visible here; confirm they are defined before this function is called.
    """
    wants_48k = a.output_sr == 48000
    if wants_48k:
        hps, network_cls = h_sr48, SpeechSR48
    else:
        # 24000 Hz
        hps, network_cls = h_sr, SpeechSR24

    speechsr = network_cls(
        hps.data.n_mel_channels,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model
    ).cuda()

    checkpoint_path = a.ckpt_sr48 if wants_48k else a.ckpt_sr
    utils.load_checkpoint(checkpoint_path, speechsr, None)
    speechsr.eval()
    return speechsr
def inference_sr(a):
    """Load the model selected by ``a`` and run super-resolution with it.

    Builds the network via ``model_load(a)`` on every call, hands it to
    ``SuperResoltuion`` and returns that function's result (the path of the
    written WAV file).
    """
    model = model_load(a)
    output_path = SuperResoltuion(a, model)
    return output_path
# (blame timestamp: 2023-12-22 00:43:45 +01:00)
@app.post("/process_video/")
async def process_video(video_file: UploadFile = File(...)):
    """Replace the speech in an uploaded video with translated, synthesised speech.

    Steps: store the upload, extract its audio track, transcribe it with
    timing information, translate the text from English ("en") to Russian
    ("ru"), synthesise speech for the translation, run super-resolution on
    it, overwrite the audio of each timed range in the video and save the
    result under ``VIDEO_FOLDER``.

    Args:
        video_file: The uploaded video (multipart form field).

    Returns:
        A ``FileResponse`` streaming ``<name>_processed.mp4`` as ``video/mp4``.

    The upload, audio and output working folders are emptied afterwards,
    whether or not processing succeeded.
    """
    # The filename is supplied by the client: keep only its last path
    # component so it cannot point outside the working folders.
    safe_name = os.path.basename(video_file.filename or "")
    if safe_name in ("", ".", ".."):
        safe_name = "upload"
    stem = os.path.splitext(safe_name)[0]

    video_path = os.path.join(UPLOAD_FOLDER, safe_name)
    audio_output_path = os.path.join(AUDIO_FOLDER, f"{stem}.wav")
    # The processed video lives directly under VIDEO_FOLDER (which is not
    # wiped below), never underneath the path of an audio file.
    output_video_path = os.path.join(VIDEO_FOLDER, f"{stem}_processed.mp4")

    try:
        # Stream the upload to disk in chunks instead of reading it into
        # memory in one piece.
        with open(video_path, "wb") as video:
            shutil.copyfileobj(video_file.file, video)

        await video_manager.extract_audio(video_path, audio_output_path)
        final_result, vad_timing = await sst.process_audio_with_timing(audio_output_path)
        translated_text = await translator.translate_text(final_result, source_lang="en", target_lang="ru")

        text_speaker_tuples = [(translated_text, (1, 4))]
        audio = await tts.batch_text_to_speech(text_speaker_tuples, output_folder=OUTPUT_FOLDER)

        # NOTE(review): inference_sr() reads attributes such as output_sr,
        # input_speech and output_dir from its argument and rebuilds the
        # model on every call; confirm that the value returned by
        # batch_text_to_speech() provides those attributes.
        # NOTE(review): the enhanced file path is not used for muxing below;
        # the loop still reads OUTPUT_FOLDER/output_1.wav. Confirm which
        # file is meant to be used.
        path_audio = inference_sr(audio)

        video_clip = await video_manager.load_video_from_path(video_path)
        replacement_audio = os.path.join(OUTPUT_FOLDER, "output_1.wav")
        # vad_timing is consumed as consecutive (start, end) pairs.
        for start, end in zip(vad_timing[::2], vad_timing[1::2]):
            await video_manager.replace_audio_in_range(video_clip, replacement_audio, start, end)
        await video_manager.save_video(video_clip, output_video_path)
    finally:
        # Always reset the scratch folders, even when a step above failed.
        # Errors during removal are ignored so they cannot mask the original
        # exception.
        # NOTE(review): these folders are shared, so this also removes files
        # of any request being processed concurrently.
        for folder in (UPLOAD_FOLDER, AUDIO_FOLDER, OUTPUT_FOLDER):
            shutil.rmtree(folder, ignore_errors=True)
            os.makedirs(folder, exist_ok=True)

    # Offer the download under the bare file name, without directory parts.
    return FileResponse(
        output_video_path,
        media_type="video/mp4",
        filename=os.path.basename(output_video_path),
    )