diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..ef65724
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+faster-whisper
+pyannote.audio
+python-dotenv
+pydub
+tqdm
diff --git a/transcriber.py b/transcriber.py
new file mode 100644
index 0000000..9ede4a0
--- /dev/null
+++ b/transcriber.py
@@ -0,0 +1,4 @@
+from transcriber.main import main
+
+if __name__ == "__main__":
+    main()
diff --git a/transcriber/__init__.py b/transcriber/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/transcriber/asr/__init__.py b/transcriber/asr/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/transcriber/asr/whisper_engine.py b/transcriber/asr/whisper_engine.py
new file mode 100644
index 0000000..aa4e56c
--- /dev/null
+++ b/transcriber/asr/whisper_engine.py
@@ -0,0 +1,93 @@
+import logging
+from dataclasses import dataclass, field
+
+from faster_whisper import WhisperModel
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class WordInfo:
+    """Single word with timing and confidence."""
+
+    word: str
+    start: float
+    end: float
+    probability: float
+
+
+@dataclass
+class Segment:
+    """ASR segment with text, timing, and optional word-level detail."""
+
+    start: float
+    end: float
+    text: str
+    words: list[WordInfo] = field(default_factory=list)
+
+
+class WhisperEngine:
+    """Speech recognition engine based on faster-whisper."""
+
+    def __init__(self, model_name: str, device: str, compute_type: str):
+        logger.info("Loading Whisper model: %s on %s (%s)", model_name, device, compute_type)
+        self._model = WhisperModel(model_name, device=device, compute_type=compute_type)
+
+    def transcribe(
+        self,
+        audio_path: str,
+        language: str | None = None,
+        beam_size: int = 5,
+        vad_filter: bool = True,
+    ) -> list[Segment]:
+        """Transcribe audio file and return list of segments.
+
+        Args:
+            audio_path: Path to WAV file.
+            language: Language code or None for auto-detection.
+            beam_size: Beam search width.
+            vad_filter: Whether to enable VAD filtering.
+
+        Returns:
+            List of transcription segments with word-level timestamps.
+        """
+        logger.info("Transcribing: %s", audio_path)
+
+        segments_gen, info = self._model.transcribe(
+            audio_path,
+            language=language,
+            beam_size=beam_size,
+            word_timestamps=True,
+            vad_filter=vad_filter,
+            vad_parameters={"min_silence_duration_ms": 500},
+            temperature=0.0,
+            condition_on_previous_text=False,
+            no_speech_threshold=0.6,
+            log_prob_threshold=-1.0,
+        )
+
+        logger.info(
+            "Detected language: %s (%.2f), duration: %.1fs",
+            info.language, info.language_probability, info.duration,
+        )
+
+        results = []
+        for seg in segments_gen:
+            words = [
+                WordInfo(
+                    word=w.word.strip(),
+                    start=w.start,
+                    end=w.end,
+                    probability=w.probability,
+                )
+                for w in (seg.words or [])
+            ]
+            results.append(Segment(
+                start=seg.start,
+                end=seg.end,
+                text=seg.text.strip(),
+                words=words,
+            ))
+
+        logger.info("Transcription complete: %d segments", len(results))
+        return results
diff --git a/transcriber/audio/__init__.py b/transcriber/audio/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/transcriber/audio/chunking.py b/transcriber/audio/chunking.py
new file mode 100644
index 0000000..22f4a23
--- /dev/null
+++ b/transcriber/audio/chunking.py
@@ -0,0 +1,82 @@
+import logging
+import subprocess
+from dataclasses import dataclass
+from pathlib import Path
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ChunkInfo:
+    """Metadata for a single audio chunk."""
+
+    path: str
+    start_offset: float
+    duration: float
+
+
+def get_audio_duration(wav_path: str) -> float:
+    """Get duration of audio file in seconds using ffprobe."""
+    cmd = [
+        "ffprobe", "-v", "quiet",
+        "-show_entries", "format=duration",
+        "-of", "csv=p=0",
+        wav_path,
+    ]
+    result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
+    if result.returncode != 0:
+        raise RuntimeError(f"ffprobe failed: {result.stderr[:300]}")
+    return float(result.stdout.strip())
+
+
+def chunk_audio(wav_path: str, max_duration_sec: int = 1800) -> list[ChunkInfo]:
+    """Split audio into chunks if longer than max_duration_sec.
+
+    Args:
+        wav_path: Path to the preprocessed WAV file.
+        max_duration_sec: Maximum chunk duration in seconds (default 30 min).
+
+    Returns:
+        List of ChunkInfo with paths and timing metadata.
+    """
+    total_duration = get_audio_duration(wav_path)
+    logger.info("Audio duration: %.1f sec", total_duration)
+
+    if total_duration <= max_duration_sec:
+        return [ChunkInfo(path=wav_path, start_offset=0.0, duration=total_duration)]
+
+    chunks = []
+    src = Path(wav_path)
+    chunk_dir = src.parent / "chunks"
+    chunk_dir.mkdir(exist_ok=True)
+
+    offset = 0.0
+    idx = 0
+    while offset < total_duration:
+        chunk_path = str(chunk_dir / f"{src.stem}_chunk{idx:03d}.wav")
+        remaining = total_duration - offset
+        duration = min(max_duration_sec, remaining)
+
+        cmd = [
+            "ffmpeg", "-y",
+            "-ss", str(offset),
+            "-i", wav_path,
+            "-t", str(duration),
+            "-c", "copy",
+            chunk_path,
+        ]
+        result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
+        if result.returncode != 0:
+            raise RuntimeError(f"Chunk {idx} failed: {result.stderr[:300]}")
+
+        chunks.append(ChunkInfo(
+            path=chunk_path,
+            start_offset=offset,
+            duration=duration,
+        ))
+        logger.info("Chunk %d: %.1fs - %.1fs", idx, offset, offset + duration)
+
+        offset += duration
+        idx += 1
+
+    return chunks
diff --git a/transcriber/audio/preprocess.py b/transcriber/audio/preprocess.py
new file mode 100644
index 0000000..4a5c48e
--- /dev/null
+++ b/transcriber/audio/preprocess.py
@@ -0,0 +1,54 @@
+import logging
+import subprocess
+from pathlib import Path
+
+logger = logging.getLogger(__name__)
+
+SUPPORTED_FORMATS = {".m4a", ".mp3", ".wav", ".aac"}
+
+
+def preprocess_audio(input_path: str, output_dir: str) -> str:
+    """Convert audio to mono 16kHz PCM WAV with normalization and DC offset removal.
+
+    Args:
+        input_path: Path to the source audio file.
+        output_dir: Directory for the processed file.
+
+    Returns:
+        Path to the processed WAV file.
+
+    Raises:
+        FileNotFoundError: If input file does not exist.
+        ValueError: If file format is not supported.
+        RuntimeError: If ffmpeg processing fails.
+    """
+    src = Path(input_path)
+    if not src.exists():
+        raise FileNotFoundError(f"Audio file not found: {input_path}")
+    if src.suffix.lower() not in SUPPORTED_FORMATS:
+        raise ValueError(
+            f"Unsupported format: {src.suffix}. Supported: {SUPPORTED_FORMATS}"
+        )
+
+    out = Path(output_dir) / f"{src.stem}_processed.wav"
+    out.parent.mkdir(parents=True, exist_ok=True)
+
+    cmd = [
+        "ffmpeg", "-y", "-i", str(src),
+        "-ac", "1",
+        "-ar", "16000",
+        "-sample_fmt", "s16",
+        "-af", "highpass=f=10,loudnorm=I=-16:TP=-1.5:LRA=11",
+        str(out),
+    ]
+
+    logger.info("Preprocessing: %s -> %s", src.name, out.name)
+
+    result = subprocess.run(
+        cmd, capture_output=True, text=True, timeout=600
+    )
+    if result.returncode != 0:
+        raise RuntimeError(f"ffmpeg failed: {result.stderr[:500]}")
+
+    logger.info("Preprocessing complete: %s", out.name)
+    return str(out)
diff --git a/transcriber/config.py b/transcriber/config.py
new file mode 100644
index 0000000..d3ea825
--- /dev/null
+++ b/transcriber/config.py
@@ -0,0 +1,38 @@
+import os
+from dataclasses import dataclass, field
+from pathlib import Path
+
+from dotenv import load_dotenv
+
+
+@dataclass
+class TranscriberConfig:
+    """Configuration for the transcription pipeline."""
+
+    input_path: str = ""
+    output_dir: str = "./output"
+    model: str = "large-v3"
+    device: str = "cuda"
+    compute_type: str = "float16"
+    language: str = "ru"
+    beam_size: int = 5
+    vad: bool = True
+    max_speakers: int | None = None
+    min_speakers: int | None = None
+    formats: list[str] = field(default_factory=lambda: ["txt", "json"])
+    pause_threshold: float = 1.5
+    chunk_duration: int = 1800
+    hf_token: str = ""
+
+    def __post_init__(self):
+        load_dotenv()
+        if not self.hf_token:
+            self.hf_token = os.getenv("HF_TOKEN", "")
+        if not self.hf_token:
+            raise ValueError(
+                "HF_TOKEN is required for pyannote diarization. "
+                "Set it in .env or pass via --hf-token"
+            )
+        if self.device == "cpu":
+            self.compute_type = "int8"
+        Path(self.output_dir).mkdir(parents=True, exist_ok=True)
diff --git a/transcriber/diarization/__init__.py b/transcriber/diarization/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/transcriber/diarization/pyannote_engine.py b/transcriber/diarization/pyannote_engine.py
new file mode 100644
index 0000000..7ba07c1
--- /dev/null
+++ b/transcriber/diarization/pyannote_engine.py
@@ -0,0 +1,69 @@
+import logging
+from dataclasses import dataclass
+
+import torch
+from pyannote.audio import Pipeline
+from pyannote.audio.pipelines.utils.hook import ProgressHook
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class SpeakerTurn:
+    """A single speaker turn with timing."""
+
+    start: float
+    end: float
+    speaker: str
+
+
+class DiarizationEngine:
+    """Speaker diarization engine based on pyannote.audio."""
+
+    def __init__(self, hf_token: str, device: str):
+        logger.info("Loading diarization pipeline on %s", device)
+        self._pipeline = Pipeline.from_pretrained(
+            "pyannote/speaker-diarization-3.1",
+            token=hf_token,
+        )
+        self._device = torch.device(device)
+        self._pipeline.to(self._device)
+
+    def diarize(
+        self,
+        audio_path: str,
+        min_speakers: int | None = None,
+        max_speakers: int | None = None,
+    ) -> list[SpeakerTurn]:
+        """Run speaker diarization on audio file.
+
+        Args:
+            audio_path: Path to WAV file.
+            min_speakers: Minimum expected number of speakers.
+            max_speakers: Maximum expected number of speakers.
+
+        Returns:
+            List of speaker turns sorted by start time.
+        """
+        logger.info("Diarizing: %s", audio_path)
+
+        kwargs = {}
+        if min_speakers is not None:
+            kwargs["min_speakers"] = min_speakers
+        if max_speakers is not None:
+            kwargs["max_speakers"] = max_speakers
+
+        with ProgressHook() as hook:
+            diarization = self._pipeline(audio_path, hook=hook, **kwargs)
+
+        turns = []
+        for turn, _, speaker in diarization.itertracks(yield_label=True):
+            turns.append(SpeakerTurn(
+                start=turn.start,
+                end=turn.end,
+                speaker=speaker,
+            ))
+
+        speaker_set = {t.speaker for t in turns}
+        logger.info("Diarization complete: %d turns, %d speakers", len(turns), len(speaker_set))
+        return turns
diff --git a/transcriber/export/__init__.py b/transcriber/export/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/transcriber/export/json_writer.py b/transcriber/export/json_writer.py
new file mode 100644
index 0000000..ff8f05c
--- /dev/null
+++ b/transcriber/export/json_writer.py
@@ -0,0 +1,34 @@
+import json
+from pathlib import Path
+
+from transcriber.merge.aligner import MergedSegment
+
+
+def write_json(segments: list[MergedSegment], output_path: str) -> str:
+    """Export merged segments as a structured JSON file.
+
+    Args:
+        segments: List of merged speaker segments.
+        output_path: Path to the output .json file.
+
+    Returns:
+        Path to the written file.
+    """
+    path = Path(output_path)
+    path.parent.mkdir(parents=True, exist_ok=True)
+
+    data = [
+        {
+            "speaker": seg.speaker,
+            "start": round(seg.start, 2),
+            "end": round(seg.end, 2),
+            "text": seg.text,
+        }
+        for seg in segments
+    ]
+
+    path.write_text(
+        json.dumps(data, ensure_ascii=False, indent=2),
+        encoding="utf-8",
+    )
+    return str(path)
diff --git a/transcriber/export/txt_writer.py b/transcriber/export/txt_writer.py
new file mode 100644
index 0000000..b97d24f
--- /dev/null
+++ b/transcriber/export/txt_writer.py
@@ -0,0 +1,26 @@
+from pathlib import Path
+
+from transcriber.merge.aligner import MergedSegment
+
+
+def write_txt(segments: list[MergedSegment], output_path: str) -> str:
+    """Export merged segments as a readable dialogue text file.
+
+    Args:
+        segments: List of merged speaker segments.
+        output_path: Path to the output .txt file.
+
+    Returns:
+        Path to the written file.
+    """
+    path = Path(output_path)
+    path.parent.mkdir(parents=True, exist_ok=True)
+
+    lines = []
+    for seg in segments:
+        lines.append(f"[{seg.speaker}]")
+        lines.append(seg.text)
+        lines.append("")
+
+    path.write_text("\n".join(lines), encoding="utf-8")
+    return str(path)
diff --git a/transcriber/main.py b/transcriber/main.py
new file mode 100644
index 0000000..7b31054
--- /dev/null
+++ b/transcriber/main.py
@@ -0,0 +1,70 @@
+import argparse
+import logging
+import sys
+
+from transcriber.config import TranscriberConfig
+from transcriber.pipeline import TranscriptionPipeline
+
+
+def parse_args() -> argparse.Namespace:
+    """Parse command-line arguments."""
+    parser = argparse.ArgumentParser(
+        description="Transcribe audio with speaker diarization",
+    )
+    parser.add_argument("input", help="Path to audio file (.m4a, .mp3, .wav, .aac)")
+    parser.add_argument("--output", default="./output", help="Output directory (default: ./output)")
+    parser.add_argument("--model", default="large-v3", help="Whisper model name (default: large-v3)")
+    parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"], help="Device (default: cuda)")
+    parser.add_argument("--compute-type", default="float16", help="Compute type (default: float16)")
+    parser.add_argument("--language", default="ru", help="Language code (default: ru)")
+    parser.add_argument("--beam-size", type=int, default=5, help="Beam search size (default: 5)")
+    parser.add_argument("--vad", default="on", choices=["on", "off"], help="VAD filter (default: on)")
+    parser.add_argument("--max-speakers", type=int, default=None, help="Maximum number of speakers")
+    parser.add_argument("--min-speakers", type=int, default=None, help="Minimum number of speakers")
+    parser.add_argument("--format", nargs="+", default=["txt", "json"], help="Output formats (default: txt json)")
+    parser.add_argument("--pause-threshold", type=float, default=1.5, help="Max pause for merging (default: 1.5s)")
+    parser.add_argument("--chunk-duration", type=int, default=1800, help="Max chunk duration in sec (default: 1800)")
+    parser.add_argument("--hf-token", default="", help="HuggingFace token (default: from .env)")
+    parser.add_argument("--verbose", "-v", action="store_true", help="Enable debug logging")
+    return parser.parse_args()
+
+
+def main() -> None:
+    """Entry point for the transcription CLI."""
+    args = parse_args()
+
+    logging.basicConfig(
+        level=logging.DEBUG if args.verbose else logging.INFO,
+        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
+        datefmt="%H:%M:%S",
+    )
+
+    config = TranscriberConfig(
+        input_path=args.input,
+        output_dir=args.output,
+        model=args.model,
+        device=args.device,
+        compute_type=args.compute_type,
+        language=args.language,
+        beam_size=args.beam_size,
+        vad=args.vad == "on",
+        max_speakers=args.max_speakers,
+        min_speakers=args.min_speakers,
+        formats=args.format,
+        pause_threshold=args.pause_threshold,
+        chunk_duration=args.chunk_duration,
+        hf_token=args.hf_token,
+    )
+
+    pipeline = TranscriptionPipeline(config)
+    try:
+        exported = pipeline.run()
+        for path in exported:
+            print(f"Saved: {path}")
+    except Exception:
+        logging.exception("Pipeline failed")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/transcriber/merge/__init__.py b/transcriber/merge/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/transcriber/merge/aligner.py b/transcriber/merge/aligner.py
new file mode 100644
index 0000000..bbd9c9f
--- /dev/null
+++ b/transcriber/merge/aligner.py
@@ -0,0 +1,101 @@
+import logging
+from dataclasses import dataclass
+
+from transcriber.asr.whisper_engine import Segment
+from transcriber.diarization.pyannote_engine import SpeakerTurn
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class MergedSegment:
+    """Final segment with speaker label and merged text."""
+
+    speaker: str
+    start: float
+    end: float
+    text: str
+
+
+def _compute_overlap(seg_start: float, seg_end: float, turn_start: float, turn_end: float) -> float:
+    """Compute temporal overlap between two intervals in seconds."""
+    overlap_start = max(seg_start, turn_start)
+    overlap_end = min(seg_end, turn_end)
+    return max(0.0, overlap_end - overlap_start)
+
+
+def _assign_speaker(segment: Segment, speaker_turns: list[SpeakerTurn]) -> str:
+    """Assign speaker to ASR segment by maximum overlap."""
+    best_speaker = "Unknown"
+    best_overlap = 0.0
+
+    for turn in speaker_turns:
+        if turn.end < segment.start:
+            continue
+        if turn.start > segment.end:
+            break
+        overlap = _compute_overlap(segment.start, segment.end, turn.start, turn.end)
+        if overlap > best_overlap:
+            best_overlap = overlap
+            best_speaker = turn.speaker
+
+    return best_speaker
+
+
+def _normalize_speaker_labels(segments: list[MergedSegment]) -> list[MergedSegment]:
+    """Replace raw pyannote labels (SPEAKER_00) with sequential Speaker 1, Speaker 2."""
+    label_map: dict[str, str] = {}
+    counter = 1
+
+    for seg in segments:
+        if seg.speaker not in label_map and seg.speaker != "Unknown":
+            label_map[seg.speaker] = f"Speaker {counter}"
+            counter += 1
+
+    for seg in segments:
+        seg.speaker = label_map.get(seg.speaker, seg.speaker)
+
+    return segments
+
+
+def align_and_merge(
+    asr_segments: list[Segment],
+    speaker_turns: list[SpeakerTurn],
+    pause_threshold: float = 1.5,
+) -> list[MergedSegment]:
+    """Align ASR segments with speaker turns and merge adjacent same-speaker segments.
+
+    Args:
+        asr_segments: Segments from Whisper ASR.
+        speaker_turns: Speaker turns from diarization (sorted by start).
+        pause_threshold: Max pause between segments to merge (seconds).
+
+    Returns:
+        List of merged segments with speaker labels.
+    """
+    if not asr_segments:
+        return []
+
+    aligned = []
+    for seg in asr_segments:
+        speaker = _assign_speaker(seg, speaker_turns)
+        aligned.append(MergedSegment(
+            speaker=speaker,
+            start=seg.start,
+            end=seg.end,
+            text=seg.text,
+        ))
+
+    merged: list[MergedSegment] = [aligned[0]]
+    for seg in aligned[1:]:
+        prev = merged[-1]
+        gap = seg.start - prev.end
+        if seg.speaker == prev.speaker and gap <= pause_threshold:
+            prev.end = seg.end
+            prev.text = f"{prev.text} {seg.text}"
+        else:
+            merged.append(seg)
+
+    merged = _normalize_speaker_labels(merged)
+    logger.info("Alignment complete: %d ASR segments -> %d merged", len(asr_segments), len(merged))
+    return merged
diff --git a/transcriber/pipeline.py b/transcriber/pipeline.py
new file mode 100644
index 0000000..2366b26
--- /dev/null
+++ b/transcriber/pipeline.py
@@ -0,0 +1,136 @@
+import logging
+from pathlib import Path
+
+from tqdm import tqdm
+
+from transcriber.asr.whisper_engine import Segment, WhisperEngine
+from transcriber.audio.chunking import ChunkInfo, chunk_audio
+from transcriber.audio.preprocess import preprocess_audio
+from transcriber.config import TranscriberConfig
+from transcriber.diarization.pyannote_engine import DiarizationEngine
+from transcriber.export.json_writer import write_json
+from transcriber.export.txt_writer import write_txt
+from transcriber.merge.aligner import MergedSegment, align_and_merge
+
+logger = logging.getLogger(__name__)
+
+EXPORTERS = {
+    "txt": write_txt,
+    "json": write_json,
+}
+
+
+class TranscriptionPipeline:
+    """Orchestrates the full transcription pipeline: preprocess -> ASR -> diarize -> merge -> export."""
+
+    def __init__(self, config: TranscriberConfig):
+        self._config = config
+        self._asr: WhisperEngine | None = None
+        self._diarizer: DiarizationEngine | None = None
+
+    def _init_engines(self) -> None:
+        """Lazily initialize ASR and diarization engines."""
+        if self._asr is None:
+            self._asr = WhisperEngine(
+                model_name=self._config.model,
+                device=self._config.device,
+                compute_type=self._config.compute_type,
+            )
+        if self._diarizer is None:
+            self._diarizer = DiarizationEngine(
+                hf_token=self._config.hf_token,
+                device=self._config.device,
+            )
+
+    def _transcribe_chunks(self, chunks: list[ChunkInfo], progress: tqdm) -> list[Segment]:
+        """Run ASR on each chunk, adjusting timestamps by chunk offset."""
+        all_segments: list[Segment] = []
+        for chunk in chunks:
+            progress.set_description(f"ASR chunk {chunk.start_offset:.0f}s")
+            segments = self._asr.transcribe(
+                audio_path=chunk.path,
+                language=self._config.language,
+                beam_size=self._config.beam_size,
+                vad_filter=self._config.vad,
+            )
+            for seg in segments:
+                seg.start += chunk.start_offset
+                seg.end += chunk.start_offset
+                for w in seg.words:
+                    w.start += chunk.start_offset
+                    w.end += chunk.start_offset
+            all_segments.extend(segments)
+            progress.update(1)
+        return all_segments
+
+    def _export(self, segments: list[MergedSegment], stem: str) -> list[str]:
+        """Export segments to requested formats."""
+        output_dir = Path(self._config.output_dir)
+        exported = []
+        for fmt in self._config.formats:
+            exporter = EXPORTERS.get(fmt)
+            if exporter is None:
+                logger.warning("Unknown export format: %s (skipped)", fmt)
+                continue
+            out_path = str(output_dir / f"{stem}.{fmt}")
+            exporter(segments, out_path)
+            exported.append(out_path)
+            logger.info("Exported: %s", out_path)
+        return exported
+
+    def run(self) -> list[str]:
+        """Execute the full pipeline and return list of exported file paths.
+
+        Returns:
+            List of paths to exported files.
+
+        Raises:
+            FileNotFoundError: If input file does not exist.
+            RuntimeError: If any pipeline stage fails.
+        """
+        cfg = self._config
+        stem = Path(cfg.input_path).stem
+
+        total_steps = 7
+        progress = tqdm(total=total_steps, desc="Pipeline", unit="step")
+
+        progress.set_description("Preprocessing")
+        wav_path = preprocess_audio(cfg.input_path, cfg.output_dir)
+        progress.update(1)
+
+        progress.set_description("Chunking")
+        chunks = chunk_audio(wav_path, cfg.chunk_duration)
+        progress.update(1)
+
+        progress.set_description("Loading models")
+        self._init_engines()
+        progress.update(1)
+
+        asr_progress = tqdm(total=len(chunks), desc="ASR", unit="chunk", leave=False)
+        asr_segments = self._transcribe_chunks(chunks, asr_progress)
+        asr_progress.close()
+        progress.update(1)
+
+        progress.set_description("Diarizing")
+        speaker_turns = self._diarizer.diarize(
+            audio_path=wav_path,
+            min_speakers=cfg.min_speakers,
+            max_speakers=cfg.max_speakers,
+        )
+        progress.update(1)
+
+        progress.set_description("Aligning")
+        merged = align_and_merge(
+            asr_segments=asr_segments,
+            speaker_turns=speaker_turns,
+            pause_threshold=cfg.pause_threshold,
+        )
+        progress.update(1)
+
+        progress.set_description("Exporting")
+        exported = self._export(merged, stem)
+        progress.update(1)
+
+        progress.close()
+        logger.info("Pipeline complete. Files: %s", exported)
+        return exported
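
The patch ships no usage example, so here is a minimal programmatic sketch of driving the pipeline end to end. It is illustrative only and not part of the diff above: the input filename is hypothetical, HF_TOKEN is assumed to be available via .env or the environment (TranscriberConfig.__post_init__ requires it), and ffmpeg/ffprobe are assumed to be on PATH, since preprocess_audio() and chunk_audio() shell out to them.

    # Illustrative sketch, not included in the patch.
    from transcriber.config import TranscriberConfig
    from transcriber.pipeline import TranscriptionPipeline

    # "meeting.m4a" is a hypothetical input; any .m4a/.mp3/.wav/.aac accepted
    # by preprocess_audio() will do.
    config = TranscriberConfig(
        input_path="meeting.m4a",
        output_dir="./output",
        model="large-v3",
        device="cuda",          # with device="cpu", __post_init__ switches compute_type to "int8"
        language="ru",
        formats=["txt", "json"],
    )

    pipeline = TranscriptionPipeline(config)
    for path in pipeline.run():
        print(f"Saved: {path}")

The equivalent run via the CLI entry point would be roughly: python transcriber.py meeting.m4a --device cuda --format txt json.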
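The merge logic in transcriber/merge/aligner.py is the least obvious part of the patch: each ASR segment is assigned to the diarization turn with maximal temporal overlap, consecutive segments from the same speaker are merged when the gap between them is at most pause_threshold, and raw pyannote labels are then renamed to Speaker 1, Speaker 2, and so on. The following self-contained check (illustrative, not included in the patch; it could be adapted into a pytest test) exercises that behaviour with synthetic segments and turns:

    # Two consecutive Whisper segments from the same diarized speaker, separated
    # by a 1.0 s pause (<= pause_threshold), should collapse into one merged segment.
    from transcriber.asr.whisper_engine import Segment
    from transcriber.diarization.pyannote_engine import SpeakerTurn
    from transcriber.merge.aligner import align_and_merge

    asr_segments = [
        Segment(start=0.0, end=4.0, text="Hello everyone."),
        Segment(start=5.0, end=8.0, text="Let's get started."),
        Segment(start=9.0, end=12.0, text="Thanks, good morning."),
    ]
    speaker_turns = [
        SpeakerTurn(start=0.0, end=8.5, speaker="SPEAKER_00"),
        SpeakerTurn(start=8.5, end=12.5, speaker="SPEAKER_01"),
    ]

    merged = align_and_merge(asr_segments, speaker_turns, pause_threshold=1.5)
    assert len(merged) == 2
    assert merged[0].speaker == "Speaker 1"
    assert merged[0].text == "Hello everyone. Let's get started."
    assert merged[1].speaker == "Speaker 2"

For these two merged segments, write_txt would emit two [Speaker N] blocks separated by blank lines, and write_json a list of {speaker, start, end, text} objects with timestamps rounded to two decimals.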
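Export formats are dispatched through the EXPORTERS dict in transcriber/pipeline.py, so adding a format means writing one function with the (segments, output_path) signature and registering it. The sketch below is a hypothetical SRT writer, not part of the patch, shown only to illustrate that extension point; enabling --format srt would additionally require adding the entry to EXPORTERS in transcriber/pipeline.py.

    # Hypothetical SRT exporter -- illustrative only.
    from pathlib import Path

    from transcriber.merge.aligner import MergedSegment


    def _fmt_ts(seconds: float) -> str:
        """Format seconds as an SRT timestamp (HH:MM:SS,mmm)."""
        ms = int(round(seconds * 1000))
        h, rem = divmod(ms, 3_600_000)
        m, rem = divmod(rem, 60_000)
        s, ms = divmod(rem, 1000)
        return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"


    def write_srt(segments: list[MergedSegment], output_path: str) -> str:
        """Write merged segments as numbered SRT cues with speaker prefixes."""
        path = Path(output_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        blocks = [
            f"{i}\n{_fmt_ts(seg.start)} --> {_fmt_ts(seg.end)}\n{seg.speaker}: {seg.text}\n"
            for i, seg in enumerate(segments, start=1)
        ]
        path.write_text("\n".join(blocks), encoding="utf-8")
        return str(path)

    # To wire this in: EXPORTERS["srt"] = write_srt in transcriber/pipeline.py.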