first version
parent 3400a8070a
commit de07a045ce

@@ -0,0 +1,5 @@
faster-whisper
pyannote.audio
python-dotenv
pydub
tqdm

@@ -0,0 +1,4 @@
from transcriber.main import main

if __name__ == "__main__":
    main()

@@ -0,0 +1,93 @@
import logging
from dataclasses import dataclass, field

from faster_whisper import WhisperModel

logger = logging.getLogger(__name__)


@dataclass
class WordInfo:
    """Single word with timing and confidence."""

    word: str
    start: float
    end: float
    probability: float


@dataclass
class Segment:
    """ASR segment with text, timing, and optional word-level detail."""

    start: float
    end: float
    text: str
    words: list[WordInfo] = field(default_factory=list)


class WhisperEngine:
    """Speech recognition engine based on faster-whisper."""

    def __init__(self, model_name: str, device: str, compute_type: str):
        logger.info("Loading Whisper model: %s on %s (%s)", model_name, device, compute_type)
        self._model = WhisperModel(model_name, device=device, compute_type=compute_type)

    def transcribe(
        self,
        audio_path: str,
        language: str | None = None,
        beam_size: int = 5,
        vad_filter: bool = True,
    ) -> list[Segment]:
        """Transcribe audio file and return list of segments.

        Args:
            audio_path: Path to WAV file.
            language: Language code or None for auto-detection.
            beam_size: Beam search width.
            vad_filter: Whether to enable VAD filtering.

        Returns:
            List of transcription segments with word-level timestamps.
        """
        logger.info("Transcribing: %s", audio_path)

        segments_gen, info = self._model.transcribe(
            audio_path,
            language=language,
            beam_size=beam_size,
            word_timestamps=True,
            vad_filter=vad_filter,
            vad_parameters={"min_silence_duration_ms": 500},
            temperature=0.0,
            condition_on_previous_text=False,
            no_speech_threshold=0.6,
            log_prob_threshold=-1.0,
        )

        logger.info(
            "Detected language: %s (%.2f), duration: %.1fs",
            info.language, info.language_probability, info.duration,
        )

        results = []
        for seg in segments_gen:
            words = [
                WordInfo(
                    word=w.word.strip(),
                    start=w.start,
                    end=w.end,
                    probability=w.probability,
                )
                for w in (seg.words or [])
            ]
            results.append(Segment(
                start=seg.start,
                end=seg.end,
                text=seg.text.strip(),
                words=words,
            ))

        logger.info("Transcription complete: %d segments", len(results))
        return results

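A minimal sketch of driving WhisperEngine on its own, outside the pipeline; the model name, device, and audio path below are placeholders, not values taken from this commit:

# Hypothetical standalone usage of WhisperEngine (model, device, and path are placeholders).
from transcriber.asr.whisper_engine import WhisperEngine

engine = WhisperEngine(model_name="large-v3", device="cpu", compute_type="int8")
segments = engine.transcribe("meeting_processed.wav", language="ru")
for seg in segments:
    print(f"[{seg.start:.2f}-{seg.end:.2f}] {seg.text}")
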
@@ -0,0 +1,82 @@
import logging
import subprocess
from dataclasses import dataclass
from pathlib import Path

logger = logging.getLogger(__name__)


@dataclass
class ChunkInfo:
    """Metadata for a single audio chunk."""

    path: str
    start_offset: float
    duration: float


def get_audio_duration(wav_path: str) -> float:
    """Get duration of audio file in seconds using ffprobe."""
    cmd = [
        "ffprobe", "-v", "quiet",
        "-show_entries", "format=duration",
        "-of", "csv=p=0",
        wav_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
    if result.returncode != 0:
        raise RuntimeError(f"ffprobe failed: {result.stderr[:300]}")
    return float(result.stdout.strip())


def chunk_audio(wav_path: str, max_duration_sec: int = 1800) -> list[ChunkInfo]:
    """Split audio into chunks if longer than max_duration_sec.

    Args:
        wav_path: Path to the preprocessed WAV file.
        max_duration_sec: Maximum chunk duration in seconds (default 30 min).

    Returns:
        List of ChunkInfo with paths and timing metadata.
    """
    total_duration = get_audio_duration(wav_path)
    logger.info("Audio duration: %.1f sec", total_duration)

    if total_duration <= max_duration_sec:
        return [ChunkInfo(path=wav_path, start_offset=0.0, duration=total_duration)]

    chunks = []
    src = Path(wav_path)
    chunk_dir = src.parent / "chunks"
    chunk_dir.mkdir(exist_ok=True)

    offset = 0.0
    idx = 0
    while offset < total_duration:
        chunk_path = str(chunk_dir / f"{src.stem}_chunk{idx:03d}.wav")
        remaining = total_duration - offset
        duration = min(max_duration_sec, remaining)

        cmd = [
            "ffmpeg", "-y",
            "-ss", str(offset),
            "-i", wav_path,
            "-t", str(duration),
            "-c", "copy",
            chunk_path,
        ]
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
        if result.returncode != 0:
            raise RuntimeError(f"Chunk {idx} failed: {result.stderr[:300]}")

        chunks.append(ChunkInfo(
            path=chunk_path,
            start_offset=offset,
            duration=duration,
        ))
        logger.info("Chunk %d: %.1fs - %.1fs", idx, offset, offset + duration)

        offset += duration
        idx += 1

    return chunks

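A quick sketch of what chunk_audio returns for a long recording, assuming an hour-long preprocessed file and the default 1800 s limit; the file name and printed values are illustrative only:

# Illustrative only: a 3600 s file would split into two 1800 s chunks.
from transcriber.audio.chunking import chunk_audio

chunks = chunk_audio("long_meeting_processed.wav")  # default max_duration_sec=1800
for c in chunks:
    print(c.path, c.start_offset, c.duration)
# e.g. chunks/long_meeting_processed_chunk000.wav 0.0 1800.0
#      chunks/long_meeting_processed_chunk001.wav 1800.0 1800.0
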
@@ -0,0 +1,54 @@
import logging
import subprocess
from pathlib import Path

logger = logging.getLogger(__name__)

SUPPORTED_FORMATS = {".m4a", ".mp3", ".wav", ".aac"}


def preprocess_audio(input_path: str, output_dir: str) -> str:
    """Convert audio to mono 16kHz PCM WAV with normalization and DC offset removal.

    Args:
        input_path: Path to the source audio file.
        output_dir: Directory for the processed file.

    Returns:
        Path to the processed WAV file.

    Raises:
        FileNotFoundError: If input file does not exist.
        ValueError: If file format is not supported.
        RuntimeError: If ffmpeg processing fails.
    """
    src = Path(input_path)
    if not src.exists():
        raise FileNotFoundError(f"Audio file not found: {input_path}")
    if src.suffix.lower() not in SUPPORTED_FORMATS:
        raise ValueError(
            f"Unsupported format: {src.suffix}. Supported: {SUPPORTED_FORMATS}"
        )

    out = Path(output_dir) / f"{src.stem}_processed.wav"
    out.parent.mkdir(parents=True, exist_ok=True)

    cmd = [
        "ffmpeg", "-y", "-i", str(src),
        "-ac", "1",
        "-ar", "16000",
        "-sample_fmt", "s16",
        "-af", "highpass=f=10,loudnorm=I=-16:TP=-1.5:LRA=11",
        str(out),
    ]

    logger.info("Preprocessing: %s -> %s", src.name, out.name)

    result = subprocess.run(
        cmd, capture_output=True, text=True, timeout=600
    )
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg failed: {result.stderr[:500]}")

    logger.info("Preprocessing complete: %s", out.name)
    return str(out)

@@ -0,0 +1,38 @@
import os
from dataclasses import dataclass, field
from pathlib import Path

from dotenv import load_dotenv


@dataclass
class TranscriberConfig:
    """Configuration for the transcription pipeline."""

    input_path: str = ""
    output_dir: str = "./output"
    model: str = "large-v3"
    device: str = "cuda"
    compute_type: str = "float16"
    language: str = "ru"
    beam_size: int = 5
    vad: bool = True
    max_speakers: int | None = None
    min_speakers: int | None = None
    formats: list[str] = field(default_factory=lambda: ["txt", "json"])
    pause_threshold: float = 1.5
    chunk_duration: int = 1800
    hf_token: str = ""

    def __post_init__(self):
        load_dotenv()
        if not self.hf_token:
            self.hf_token = os.getenv("HF_TOKEN", "")
        if not self.hf_token:
            raise ValueError(
                "HF_TOKEN is required for pyannote diarization. "
                "Set it in .env or pass via --hf-token"
            )
        if self.device == "cpu":
            self.compute_type = "int8"
        Path(self.output_dir).mkdir(parents=True, exist_ok=True)

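TranscriberConfig pulls HF_TOKEN from the environment via load_dotenv() when it is not passed explicitly, so a .env file next to the project is enough. A minimal sketch (the token value is a placeholder, not a real credential):

# .env (placeholder value; a real HuggingFace access token is required for pyannote)
HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx
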
@@ -0,0 +1,69 @@
import logging
from dataclasses import dataclass

import torch
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.utils.hook import ProgressHook

logger = logging.getLogger(__name__)


@dataclass
class SpeakerTurn:
    """A single speaker turn with timing."""

    start: float
    end: float
    speaker: str


class DiarizationEngine:
    """Speaker diarization engine based on pyannote.audio."""

    def __init__(self, hf_token: str, device: str):
        logger.info("Loading diarization pipeline on %s", device)
        self._pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            token=hf_token,
        )
        self._device = torch.device(device)
        self._pipeline.to(self._device)

    def diarize(
        self,
        audio_path: str,
        min_speakers: int | None = None,
        max_speakers: int | None = None,
    ) -> list[SpeakerTurn]:
        """Run speaker diarization on audio file.

        Args:
            audio_path: Path to WAV file.
            min_speakers: Minimum expected number of speakers.
            max_speakers: Maximum expected number of speakers.

        Returns:
            List of speaker turns sorted by start time.
        """
        logger.info("Diarizing: %s", audio_path)

        kwargs = {}
        if min_speakers is not None:
            kwargs["min_speakers"] = min_speakers
        if max_speakers is not None:
            kwargs["max_speakers"] = max_speakers

        with ProgressHook() as hook:
            diarization = self._pipeline(audio_path, hook=hook, **kwargs)

        turns = []
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            turns.append(SpeakerTurn(
                start=turn.start,
                end=turn.end,
                speaker=speaker,
            ))

        speaker_set = {t.speaker for t in turns}
        logger.info("Diarization complete: %d turns, %d speakers", len(turns), len(speaker_set))
        return turns

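A small sketch of calling DiarizationEngine directly, assuming a valid HF_TOKEN and an already preprocessed WAV file; the token, path, and speaker bounds below are placeholders:

# Hypothetical standalone diarization run (token and path are placeholders).
from transcriber.diarization.pyannote_engine import DiarizationEngine

diarizer = DiarizationEngine(hf_token="hf_xxx", device="cpu")
turns = diarizer.diarize("meeting_processed.wav", min_speakers=2, max_speakers=4)
for t in turns:
    print(f"{t.speaker}: {t.start:.1f}s - {t.end:.1f}s")
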
@@ -0,0 +1,34 @@
import json
from pathlib import Path

from transcriber.merge.aligner import MergedSegment


def write_json(segments: list[MergedSegment], output_path: str) -> str:
    """Export merged segments as a structured JSON file.

    Args:
        segments: List of merged speaker segments.
        output_path: Path to the output .json file.

    Returns:
        Path to the written file.
    """
    path = Path(output_path)
    path.parent.mkdir(parents=True, exist_ok=True)

    data = [
        {
            "speaker": seg.speaker,
            "start": round(seg.start, 2),
            "end": round(seg.end, 2),
            "text": seg.text,
        }
        for seg in segments
    ]

    path.write_text(
        json.dumps(data, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    return str(path)

@@ -0,0 +1,26 @@
from pathlib import Path

from transcriber.merge.aligner import MergedSegment


def write_txt(segments: list[MergedSegment], output_path: str) -> str:
    """Export merged segments as a readable dialogue text file.

    Args:
        segments: List of merged speaker segments.
        output_path: Path to the output .txt file.

    Returns:
        Path to the written file.
    """
    path = Path(output_path)
    path.parent.mkdir(parents=True, exist_ok=True)

    lines = []
    for seg in segments:
        lines.append(f"[{seg.speaker}]")
        lines.append(seg.text)
        lines.append("")

    path.write_text("\n".join(lines), encoding="utf-8")
    return str(path)

@@ -0,0 +1,70 @@
import argparse
import logging
import sys

from transcriber.config import TranscriberConfig
from transcriber.pipeline import TranscriptionPipeline


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Transcribe audio with speaker diarization",
    )
    parser.add_argument("input", help="Path to audio file (.m4a, .mp3, .wav, .aac)")
    parser.add_argument("--output", default="./output", help="Output directory (default: ./output)")
    parser.add_argument("--model", default="large-v3", help="Whisper model name (default: large-v3)")
    parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"], help="Device (default: cuda)")
    parser.add_argument("--compute-type", default="float16", help="Compute type (default: float16)")
    parser.add_argument("--language", default="ru", help="Language code (default: ru)")
    parser.add_argument("--beam-size", type=int, default=5, help="Beam search size (default: 5)")
    parser.add_argument("--vad", default="on", choices=["on", "off"], help="VAD filter (default: on)")
    parser.add_argument("--max-speakers", type=int, default=None, help="Maximum number of speakers")
    parser.add_argument("--min-speakers", type=int, default=None, help="Minimum number of speakers")
    parser.add_argument("--format", nargs="+", default=["txt", "json"], help="Output formats (default: txt json)")
    parser.add_argument("--pause-threshold", type=float, default=1.5, help="Max pause for merging (default: 1.5s)")
    parser.add_argument("--chunk-duration", type=int, default=1800, help="Max chunk duration in sec (default: 1800)")
    parser.add_argument("--hf-token", default="", help="HuggingFace token (default: from .env)")
    parser.add_argument("--verbose", "-v", action="store_true", help="Enable debug logging")
    return parser.parse_args()


def main() -> None:
    """Entry point for the transcription CLI."""
    args = parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%H:%M:%S",
    )

    config = TranscriberConfig(
        input_path=args.input,
        output_dir=args.output,
        model=args.model,
        device=args.device,
        compute_type=args.compute_type,
        language=args.language,
        beam_size=args.beam_size,
        vad=args.vad == "on",
        max_speakers=args.max_speakers,
        min_speakers=args.min_speakers,
        formats=args.format,
        pause_threshold=args.pause_threshold,
        chunk_duration=args.chunk_duration,
        hf_token=args.hf_token,
    )

    pipeline = TranscriptionPipeline(config)
    try:
        exported = pipeline.run()
        for path in exported:
            print(f"Saved: {path}")
    except Exception:
        logging.exception("Pipeline failed")
        sys.exit(1)


if __name__ == "__main__":
    main()

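Assuming the 4-line stub near the top of this commit is installed as the package's __main__ module (or an equivalent run script), a typical invocation would look roughly like this; the audio file name is a placeholder:

# Hypothetical CLI invocation; flags match parse_args above.
python -m transcriber meeting.m4a --output ./output --device cuda --language ru --max-speakers 4 --format txt json
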
@@ -0,0 +1,101 @@
import logging
from dataclasses import dataclass

from transcriber.asr.whisper_engine import Segment
from transcriber.diarization.pyannote_engine import SpeakerTurn

logger = logging.getLogger(__name__)


@dataclass
class MergedSegment:
    """Final segment with speaker label and merged text."""

    speaker: str
    start: float
    end: float
    text: str


def _compute_overlap(seg_start: float, seg_end: float, turn_start: float, turn_end: float) -> float:
    """Compute temporal overlap between two intervals in seconds."""
    overlap_start = max(seg_start, turn_start)
    overlap_end = min(seg_end, turn_end)
    return max(0.0, overlap_end - overlap_start)


def _assign_speaker(segment: Segment, speaker_turns: list[SpeakerTurn]) -> str:
    """Assign speaker to ASR segment by maximum overlap."""
    best_speaker = "Unknown"
    best_overlap = 0.0

    for turn in speaker_turns:
        if turn.end < segment.start:
            continue
        if turn.start > segment.end:
            break
        overlap = _compute_overlap(segment.start, segment.end, turn.start, turn.end)
        if overlap > best_overlap:
            best_overlap = overlap
            best_speaker = turn.speaker

    return best_speaker


def _normalize_speaker_labels(segments: list[MergedSegment]) -> list[MergedSegment]:
    """Replace raw pyannote labels (SPEAKER_00) with sequential Speaker 1, Speaker 2."""
    label_map: dict[str, str] = {}
    counter = 1

    for seg in segments:
        if seg.speaker not in label_map and seg.speaker != "Unknown":
            label_map[seg.speaker] = f"Speaker {counter}"
            counter += 1

    for seg in segments:
        seg.speaker = label_map.get(seg.speaker, seg.speaker)

    return segments


def align_and_merge(
    asr_segments: list[Segment],
    speaker_turns: list[SpeakerTurn],
    pause_threshold: float = 1.5,
) -> list[MergedSegment]:
    """Align ASR segments with speaker turns and merge adjacent same-speaker segments.

    Args:
        asr_segments: Segments from Whisper ASR.
        speaker_turns: Speaker turns from diarization (sorted by start).
        pause_threshold: Max pause between segments to merge (seconds).

    Returns:
        List of merged segments with speaker labels.
    """
    if not asr_segments:
        return []

    aligned = []
    for seg in asr_segments:
        speaker = _assign_speaker(seg, speaker_turns)
        aligned.append(MergedSegment(
            speaker=speaker,
            start=seg.start,
            end=seg.end,
            text=seg.text,
        ))

    merged: list[MergedSegment] = [aligned[0]]
    for seg in aligned[1:]:
        prev = merged[-1]
        gap = seg.start - prev.end
        if seg.speaker == prev.speaker and gap <= pause_threshold:
            prev.end = seg.end
            prev.text = f"{prev.text} {seg.text}"
        else:
            merged.append(seg)

    merged = _normalize_speaker_labels(merged)
    logger.info("Alignment complete: %d ASR segments -> %d merged", len(asr_segments), len(merged))
    return merged

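A toy example of the merge behaviour with the default 1.5 s pause threshold; the segments and turns below are made up purely to illustrate overlap-based speaker assignment and same-speaker merging:

# Illustrative only: two adjacent ASR segments from the same speaker get merged.
from transcriber.asr.whisper_engine import Segment
from transcriber.diarization.pyannote_engine import SpeakerTurn
from transcriber.merge.aligner import align_and_merge

asr = [
    Segment(start=0.0, end=2.0, text="Hello everyone,"),
    Segment(start=2.5, end=4.0, text="let's get started."),
    Segment(start=6.0, end=7.5, text="Thanks, hi."),
]
turns = [
    SpeakerTurn(start=0.0, end=4.5, speaker="SPEAKER_00"),
    SpeakerTurn(start=5.5, end=8.0, speaker="SPEAKER_01"),
]
merged = align_and_merge(asr, turns, pause_threshold=1.5)
for seg in merged:
    print(seg.speaker, seg.text)
# Expected: "Speaker 1 Hello everyone, let's get started." then "Speaker 2 Thanks, hi."
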
@@ -0,0 +1,136 @@
import logging
from pathlib import Path

from tqdm import tqdm

from transcriber.asr.whisper_engine import Segment, WhisperEngine
from transcriber.audio.chunking import ChunkInfo, chunk_audio
from transcriber.audio.preprocess import preprocess_audio
from transcriber.config import TranscriberConfig
from transcriber.diarization.pyannote_engine import DiarizationEngine
from transcriber.export.json_writer import write_json
from transcriber.export.txt_writer import write_txt
from transcriber.merge.aligner import MergedSegment, align_and_merge

logger = logging.getLogger(__name__)

EXPORTERS = {
    "txt": write_txt,
    "json": write_json,
}


class TranscriptionPipeline:
    """Orchestrates the full transcription pipeline: preprocess -> ASR -> diarize -> merge -> export."""

    def __init__(self, config: TranscriberConfig):
        self._config = config
        self._asr: WhisperEngine | None = None
        self._diarizer: DiarizationEngine | None = None

    def _init_engines(self) -> None:
        """Lazily initialize ASR and diarization engines."""
        if self._asr is None:
            self._asr = WhisperEngine(
                model_name=self._config.model,
                device=self._config.device,
                compute_type=self._config.compute_type,
            )
        if self._diarizer is None:
            self._diarizer = DiarizationEngine(
                hf_token=self._config.hf_token,
                device=self._config.device,
            )

    def _transcribe_chunks(self, chunks: list[ChunkInfo], progress: tqdm) -> list[Segment]:
        """Run ASR on each chunk, adjusting timestamps by chunk offset."""
        all_segments: list[Segment] = []
        for chunk in chunks:
            progress.set_description(f"ASR chunk {chunk.start_offset:.0f}s")
            segments = self._asr.transcribe(
                audio_path=chunk.path,
                language=self._config.language,
                beam_size=self._config.beam_size,
                vad_filter=self._config.vad,
            )
            for seg in segments:
                seg.start += chunk.start_offset
                seg.end += chunk.start_offset
                for w in seg.words:
                    w.start += chunk.start_offset
                    w.end += chunk.start_offset
            all_segments.extend(segments)
            progress.update(1)
        return all_segments

    def _export(self, segments: list[MergedSegment], stem: str) -> list[str]:
        """Export segments to requested formats."""
        output_dir = Path(self._config.output_dir)
        exported = []
        for fmt in self._config.formats:
            exporter = EXPORTERS.get(fmt)
            if exporter is None:
                logger.warning("Unknown export format: %s (skipped)", fmt)
                continue
            out_path = str(output_dir / f"{stem}.{fmt}")
            exporter(segments, out_path)
            exported.append(out_path)
            logger.info("Exported: %s", out_path)
        return exported

    def run(self) -> list[str]:
        """Execute the full pipeline and return list of exported file paths.

        Returns:
            List of paths to exported files.

        Raises:
            FileNotFoundError: If input file does not exist.
            RuntimeError: If any pipeline stage fails.
        """
        cfg = self._config
        stem = Path(cfg.input_path).stem

        total_steps = 7
        progress = tqdm(total=total_steps, desc="Pipeline", unit="step")

        progress.set_description("Preprocessing")
        wav_path = preprocess_audio(cfg.input_path, cfg.output_dir)
        progress.update(1)

        progress.set_description("Chunking")
        chunks = chunk_audio(wav_path, cfg.chunk_duration)
        progress.update(1)

        progress.set_description("Loading models")
        self._init_engines()
        progress.update(1)

        asr_progress = tqdm(total=len(chunks), desc="ASR", unit="chunk", leave=False)
        asr_segments = self._transcribe_chunks(chunks, asr_progress)
        asr_progress.close()
        progress.update(1)

        progress.set_description("Diarizing")
        speaker_turns = self._diarizer.diarize(
            audio_path=wav_path,
            min_speakers=cfg.min_speakers,
            max_speakers=cfg.max_speakers,
        )
        progress.update(1)

        progress.set_description("Aligning")
        merged = align_and_merge(
            asr_segments=asr_segments,
            speaker_turns=speaker_turns,
            pause_threshold=cfg.pause_threshold,
        )
        progress.update(1)

        progress.set_description("Exporting")
        exported = self._export(merged, stem)
        progress.update(1)

        progress.close()
        logger.info("Pipeline complete. Files: %s", exported)
        return exported
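
The CLI above is a thin wrapper around this class; a hedged sketch of using it programmatically (the input path is a placeholder, and HF_TOKEN must still be resolvable via .env or the hf_token field):

# Hypothetical programmatic use of the pipeline; input path is a placeholder.
from transcriber.config import TranscriberConfig
from transcriber.pipeline import TranscriptionPipeline

config = TranscriberConfig(input_path="meeting.m4a", output_dir="./output", device="cpu")
pipeline = TranscriptionPipeline(config)
for path in pipeline.run():
    print("Saved:", path)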