"""Transcription pipeline orchestration for transcribe-interview.

Module: transcriber/pipeline.py
"""
import logging
from pathlib import Path
from tqdm import tqdm
from transcriber.asr.whisper_engine import Segment, WhisperEngine
from transcriber.audio.chunking import ChunkInfo, chunk_audio
from transcriber.audio.preprocess import preprocess_audio
from transcriber.config import TranscriberConfig
from transcriber.diarization.pyannote_engine import DiarizationEngine
from transcriber.export.json_writer import write_json
from transcriber.export.txt_writer import write_txt
from transcriber.merge.aligner import MergedSegment, align_and_merge
logger = logging.getLogger(__name__)

# Registry of supported output formats: format name -> writer callable.
# Each writer takes (segments, output_path) and serializes the transcript.
EXPORTERS = dict(
    txt=write_txt,
    json=write_json,
)
class TranscriptionPipeline:
    """Orchestrates the full transcription pipeline: preprocess -> ASR -> diarize -> merge -> export."""

    def __init__(self, config: TranscriberConfig):
        self._config = config
        # Engines are created lazily in _init_engines() so constructing the
        # pipeline stays cheap; model loading is the expensive part.
        self._asr: WhisperEngine | None = None
        self._diarizer: DiarizationEngine | None = None

    def _init_engines(self) -> None:
        """Lazily initialize ASR and diarization engines (idempotent)."""
        if self._asr is None:
            self._asr = WhisperEngine(
                model_name=self._config.model,
                device=self._config.device,
                compute_type=self._config.compute_type,
            )
        if self._diarizer is None:
            self._diarizer = DiarizationEngine(
                hf_token=self._config.hf_token,
                device=self._config.device,
            )

    def _transcribe_chunks(self, chunks: list[ChunkInfo], progress: tqdm) -> list[Segment]:
        """Run ASR on each chunk, adjusting timestamps by chunk offset.

        Args:
            chunks: Chunk descriptors produced by chunk_audio().
            progress: Progress bar advanced once per completed chunk.

        Returns:
            All segments across chunks, with segment and word timestamps
            shifted by each chunk's start_offset so they are absolute
            times in the original recording.

        Raises:
            RuntimeError: If _init_engines() has not been called yet.
        """
        if self._asr is None:
            # Fail with a clear message instead of an AttributeError; also
            # narrows the Optional type for static checkers.
            raise RuntimeError("ASR engine not initialized; call _init_engines() first")
        all_segments: list[Segment] = []
        for chunk in chunks:
            progress.set_description(f"ASR chunk {chunk.start_offset:.0f}s")
            segments = self._asr.transcribe(
                audio_path=chunk.path,
                language=self._config.language,
                beam_size=self._config.beam_size,
                vad_filter=self._config.vad,
            )
            # Chunk-local timestamps -> absolute timestamps in the source audio.
            for seg in segments:
                seg.start += chunk.start_offset
                seg.end += chunk.start_offset
                for w in seg.words:
                    w.start += chunk.start_offset
                    w.end += chunk.start_offset
            all_segments.extend(segments)
            progress.update(1)
        return all_segments

    def _export(self, segments: list[MergedSegment], stem: str) -> list[str]:
        """Export segments to requested formats.

        Unknown format names are logged and skipped rather than raising,
        so one bad entry in config.formats does not lose the other outputs.

        Args:
            segments: Speaker-merged transcript segments.
            stem: Base filename (without extension) for output files.

        Returns:
            Paths of the files actually written.
        """
        output_dir = Path(self._config.output_dir)
        exported: list[str] = []
        for fmt in self._config.formats:
            exporter = EXPORTERS.get(fmt)
            if exporter is None:
                logger.warning("Unknown export format: %s (skipped)", fmt)
                continue
            out_path = str(output_dir / f"{stem}.{fmt}")
            exporter(segments, out_path)
            exported.append(out_path)
            logger.info("Exported: %s", out_path)
        return exported

    def run(self) -> list[str]:
        """Execute the full pipeline and return list of exported file paths.

        Returns:
            List of paths to exported files.

        Raises:
            FileNotFoundError: If input file does not exist.
            RuntimeError: If any pipeline stage fails.
        """
        cfg = self._config
        input_path = Path(cfg.input_path)
        # Fail fast with the documented exception instead of relying on a
        # downstream stage to notice the missing file.
        if not input_path.exists():
            raise FileNotFoundError(f"Input file not found: {cfg.input_path}")
        stem = input_path.stem
        total_steps = 7  # preprocess, chunk, load models, ASR, diarize, align, export
        progress = tqdm(total=total_steps, desc="Pipeline", unit="step")
        try:
            progress.set_description("Preprocessing")
            wav_path = preprocess_audio(cfg.input_path, cfg.output_dir)
            progress.update(1)

            progress.set_description("Chunking")
            chunks = chunk_audio(wav_path, cfg.chunk_duration)
            progress.update(1)

            progress.set_description("Loading models")
            self._init_engines()
            progress.update(1)

            asr_progress = tqdm(total=len(chunks), desc="ASR", unit="chunk", leave=False)
            try:
                asr_segments = self._transcribe_chunks(chunks, asr_progress)
            finally:
                # Close the nested bar even if a chunk fails mid-transcription.
                asr_progress.close()
            progress.update(1)

            progress.set_description("Diarizing")
            speaker_turns = self._diarizer.diarize(
                audio_path=wav_path,
                min_speakers=cfg.min_speakers,
                max_speakers=cfg.max_speakers,
            )
            progress.update(1)

            progress.set_description("Aligning")
            merged = align_and_merge(
                asr_segments=asr_segments,
                speaker_turns=speaker_turns,
                pause_threshold=cfg.pause_threshold,
            )
            progress.update(1)

            progress.set_description("Exporting")
            exported = self._export(merged, stem)
            progress.update(1)
        finally:
            # Release the outer bar on both success and failure so a raised
            # exception does not leave a dangling tqdm instance.
            progress.close()
        logger.info("Pipeline complete. Files: %s", exported)
        return exported