import logging
from pathlib import Path

from tqdm import tqdm

from transcriber.asr.whisper_engine import Segment, WhisperEngine
from transcriber.audio.chunking import ChunkInfo, chunk_audio
from transcriber.audio.preprocess import preprocess_audio
from transcriber.config import TranscriberConfig
from transcriber.diarization.pyannote_engine import DiarizationEngine
from transcriber.export.json_writer import write_json
from transcriber.export.txt_writer import write_txt
from transcriber.merge.aligner import MergedSegment, align_and_merge

logger = logging.getLogger(__name__)

EXPORTERS = {
    "txt": write_txt,
    "json": write_json,
}
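# Each entry in EXPORTERS maps an output format name to a writer callable that
# is invoked as exporter(segments, out_path) in _export() below. A further
# format could be registered the same way, assuming a matching writer module
# existed; the srt example here is hypothetical and not part of this package:
#
#     from transcriber.export.srt_writer import write_srt  # hypothetical
#     EXPORTERS["srt"] = write_srt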


class TranscriptionPipeline:
    """Orchestrates the full transcription pipeline: preprocess -> ASR -> diarize -> merge -> export."""

    def __init__(self, config: TranscriberConfig):
        self._config = config
        self._asr: WhisperEngine | None = None
        self._diarizer: DiarizationEngine | None = None

    def _init_engines(self) -> None:
        """Lazily initialize ASR and diarization engines."""
        if self._asr is None:
            self._asr = WhisperEngine(
                model_name=self._config.model,
                device=self._config.device,
                compute_type=self._config.compute_type,
            )
        if self._diarizer is None:
            self._diarizer = DiarizationEngine(
                hf_token=self._config.hf_token,
                device=self._config.device,
            )

    def _transcribe_chunks(self, chunks: list[ChunkInfo], progress: tqdm) -> list[Segment]:
        """Run ASR on each chunk, adjusting timestamps by chunk offset."""
        all_segments: list[Segment] = []
        for chunk in chunks:
            progress.set_description(f"ASR chunk {chunk.start_offset:.0f}s")
            segments = self._asr.transcribe(
                audio_path=chunk.path,
                language=self._config.language,
                beam_size=self._config.beam_size,
                vad_filter=self._config.vad,
            )
            # Shift segment and word timestamps from chunk-local time to
            # absolute time in the original recording.
            for seg in segments:
                seg.start += chunk.start_offset
                seg.end += chunk.start_offset
                for w in seg.words:
                    w.start += chunk.start_offset
                    w.end += chunk.start_offset
            all_segments.extend(segments)
            progress.update(1)
        return all_segments

    def _export(self, segments: list[MergedSegment], stem: str) -> list[str]:
        """Export segments to requested formats."""
        output_dir = Path(self._config.output_dir)
        exported: list[str] = []
        for fmt in self._config.formats:
            exporter = EXPORTERS.get(fmt)
            if exporter is None:
                logger.warning("Unknown export format: %s (skipped)", fmt)
                continue
            out_path = str(output_dir / f"{stem}.{fmt}")
            exporter(segments, out_path)
            exported.append(out_path)
            logger.info("Exported: %s", out_path)
        return exported

    def run(self) -> list[str]:
        """Execute the full pipeline and return the list of exported file paths.

        Returns:
            List of paths to exported files.

        Raises:
            FileNotFoundError: If the input file does not exist.
            RuntimeError: If any pipeline stage fails.
        """
        cfg = self._config
        stem = Path(cfg.input_path).stem

        # One progress step per stage below: preprocess, chunk, load models,
        # ASR, diarize, align, export.
        total_steps = 7
        progress = tqdm(total=total_steps, desc="Pipeline", unit="step")

        progress.set_description("Preprocessing")
        wav_path = preprocess_audio(cfg.input_path, cfg.output_dir)
        progress.update(1)

        progress.set_description("Chunking")
        chunks = chunk_audio(wav_path, cfg.chunk_duration)
        progress.update(1)

        progress.set_description("Loading models")
        self._init_engines()
        progress.update(1)

        asr_progress = tqdm(total=len(chunks), desc="ASR", unit="chunk", leave=False)
        asr_segments = self._transcribe_chunks(chunks, asr_progress)
        asr_progress.close()
        progress.update(1)

        progress.set_description("Diarizing")
        speaker_turns = self._diarizer.diarize(
            audio_path=wav_path,
            min_speakers=cfg.min_speakers,
            max_speakers=cfg.max_speakers,
        )
        progress.update(1)

        progress.set_description("Aligning")
        merged = align_and_merge(
            asr_segments=asr_segments,
            speaker_turns=speaker_turns,
            pause_threshold=cfg.pause_threshold,
        )
        progress.update(1)

        progress.set_description("Exporting")
        exported = self._export(merged, stem)
        progress.update(1)

        progress.close()
        logger.info("Pipeline complete. Files: %s", exported)
        return exported
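
# Minimal usage sketch (illustration only, not part of the original module):
# how a caller might drive the pipeline. It assumes TranscriberConfig accepts
# the fields referenced above (input_path, output_dir, formats, model, ...) as
# keyword arguments; the exact constructor signature lives in
# transcriber.config and may require additional fields.
#
#     config = TranscriberConfig(
#         input_path="interview.mp3",   # hypothetical input file
#         output_dir="out",
#         formats=["txt", "json"],
#     )
#     pipeline = TranscriptionPipeline(config)
#     exported_files = pipeline.run()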