import logging
from pathlib import Path

from tqdm import tqdm

from transcriber.asr.whisper_engine import Segment, WhisperEngine
from transcriber.audio.chunking import ChunkInfo, chunk_audio
from transcriber.audio.preprocess import preprocess_audio
from transcriber.config import TranscriberConfig
from transcriber.diarization.pyannote_engine import DiarizationEngine
from transcriber.export.json_writer import write_json
from transcriber.export.txt_writer import write_txt
from transcriber.merge.aligner import MergedSegment, align_and_merge

logger = logging.getLogger(__name__)

EXPORTERS = {
    "txt": write_txt,
    "json": write_json,
}


class TranscriptionPipeline:
    """Orchestrates the full transcription pipeline: preprocess -> ASR -> diarize -> merge -> export."""

    def __init__(self, config: TranscriberConfig):
        self._config = config
        self._asr: WhisperEngine | None = None
        self._diarizer: DiarizationEngine | None = None

    def _init_engines(self) -> None:
        """Lazily initialize ASR and diarization engines."""
        if self._asr is None:
            self._asr = WhisperEngine(
                model_name=self._config.model,
                device=self._config.device,
                compute_type=self._config.compute_type,
            )
        if self._diarizer is None:
            self._diarizer = DiarizationEngine(
                hf_token=self._config.hf_token,
                device=self._config.device,
            )

    def _transcribe_chunks(self, chunks: list[ChunkInfo], progress: tqdm) -> list[Segment]:
        """Run ASR on each chunk, adjusting timestamps by chunk offset."""
        all_segments: list[Segment] = []
        for chunk in chunks:
            progress.set_description(f"ASR chunk {chunk.start_offset:.0f}s")
            segments = self._asr.transcribe(
                audio_path=chunk.path,
                language=self._config.language,
                beam_size=self._config.beam_size,
                vad_filter=self._config.vad,
            )
            for seg in segments:
                seg.start += chunk.start_offset
                seg.end += chunk.start_offset
                for w in seg.words:
                    w.start += chunk.start_offset
                    w.end += chunk.start_offset
            all_segments.extend(segments)
            progress.update(1)
        return all_segments

    def _export(self, segments: list[MergedSegment], stem: str) -> list[str]:
        """Export segments to requested formats."""
        output_dir = Path(self._config.output_dir)
        exported: list[str] = []
        for fmt in self._config.formats:
            exporter = EXPORTERS.get(fmt)
            if exporter is None:
                logger.warning("Unknown export format: %s (skipped)", fmt)
                continue
            out_path = str(output_dir / f"{stem}.{fmt}")
            exporter(segments, out_path)
            exported.append(out_path)
            logger.info("Exported: %s", out_path)
        return exported

    def run(self) -> list[str]:
        """Execute the full pipeline and return list of exported file paths.

        Returns:
            List of paths to exported files.

        Raises:
            FileNotFoundError: If input file does not exist.
            RuntimeError: If any pipeline stage fails.
        """
        cfg = self._config
        stem = Path(cfg.input_path).stem
        total_steps = 7
        progress = tqdm(total=total_steps, desc="Pipeline", unit="step")

        progress.set_description("Preprocessing")
        wav_path = preprocess_audio(cfg.input_path, cfg.output_dir)
        progress.update(1)

        progress.set_description("Chunking")
        chunks = chunk_audio(wav_path, cfg.chunk_duration)
        progress.update(1)

        progress.set_description("Loading models")
        self._init_engines()
        progress.update(1)

        asr_progress = tqdm(total=len(chunks), desc="ASR", unit="chunk", leave=False)
        asr_segments = self._transcribe_chunks(chunks, asr_progress)
        asr_progress.close()
        progress.update(1)

        progress.set_description("Diarizing")
        speaker_turns = self._diarizer.diarize(
            audio_path=wav_path,
            min_speakers=cfg.min_speakers,
            max_speakers=cfg.max_speakers,
        )
        progress.update(1)

        progress.set_description("Aligning")
        merged = align_and_merge(
            asr_segments=asr_segments,
            speaker_turns=speaker_turns,
            pause_threshold=cfg.pause_threshold,
        )
        progress.update(1)

        progress.set_description("Exporting")
        exported = self._export(merged, stem)
        progress.update(1)

        progress.close()
        logger.info("Pipeline complete. Files: %s", exported)
        return exported