first version

itqop 2026-02-19 03:22:04 +03:00
parent 3400a8070a
commit de07a045ce
18 changed files with 712 additions and 0 deletions

5
requirements.txt Normal file

@@ -0,0 +1,5 @@
faster-whisper
pyannote.audio
python-dotenv
pydub
tqdm

4
transcriber.py Normal file

@@ -0,0 +1,4 @@
from transcriber.main import main
if __name__ == "__main__":
main()

0
transcriber/__init__.py Normal file

0
transcriber/asr/__init__.py Normal file

93
transcriber/asr/whisper_engine.py Normal file

@@ -0,0 +1,93 @@
import logging
from dataclasses import dataclass, field
from faster_whisper import WhisperModel
logger = logging.getLogger(__name__)
@dataclass
class WordInfo:
"""Single word with timing and confidence."""
word: str
start: float
end: float
probability: float
@dataclass
class Segment:
"""ASR segment with text, timing, and optional word-level detail."""
start: float
end: float
text: str
words: list[WordInfo] = field(default_factory=list)
class WhisperEngine:
"""Speech recognition engine based on faster-whisper."""
def __init__(self, model_name: str, device: str, compute_type: str):
logger.info("Loading Whisper model: %s on %s (%s)", model_name, device, compute_type)
self._model = WhisperModel(model_name, device=device, compute_type=compute_type)
def transcribe(
self,
audio_path: str,
language: str | None = None,
beam_size: int = 5,
vad_filter: bool = True,
) -> list[Segment]:
"""Transcribe audio file and return list of segments.
Args:
audio_path: Path to WAV file.
language: Language code or None for auto-detection.
beam_size: Beam search width.
vad_filter: Whether to enable VAD filtering.
Returns:
List of transcription segments with word-level timestamps.
"""
logger.info("Transcribing: %s", audio_path)
segments_gen, info = self._model.transcribe(
audio_path,
language=language,
beam_size=beam_size,
word_timestamps=True,
vad_filter=vad_filter,
vad_parameters={"min_silence_duration_ms": 500},
temperature=0.0,
condition_on_previous_text=False,
no_speech_threshold=0.6,
log_prob_threshold=-1.0,
)
logger.info(
"Detected language: %s (%.2f), duration: %.1fs",
info.language, info.language_probability, info.duration,
)
results = []
for seg in segments_gen:
words = [
WordInfo(
word=w.word.strip(),
start=w.start,
end=w.end,
probability=w.probability,
)
for w in (seg.words or [])
]
results.append(Segment(
start=seg.start,
end=seg.end,
text=seg.text.strip(),
words=words,
))
logger.info("Transcription complete: %d segments", len(results))
return results
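
A minimal usage sketch for WhisperEngine as defined above; the WAV path and the "ru" language code are placeholders, and the device/compute settings assume a CUDA-capable machine.

```python
from transcriber.asr.whisper_engine import WhisperEngine

# Hypothetical preprocessed 16 kHz mono WAV; substitute a real path.
engine = WhisperEngine(model_name="large-v3", device="cuda", compute_type="float16")
segments = engine.transcribe("output/meeting_processed.wav", language="ru")

for seg in segments:
    # Word-level timing is available via seg.words because word_timestamps=True.
    print(f"[{seg.start:7.2f} - {seg.end:7.2f}] {seg.text}")
```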

0
transcriber/audio/__init__.py Normal file

82
transcriber/audio/chunking.py Normal file

@@ -0,0 +1,82 @@
import logging
import subprocess
from dataclasses import dataclass
from pathlib import Path
logger = logging.getLogger(__name__)
@dataclass
class ChunkInfo:
"""Metadata for a single audio chunk."""
path: str
start_offset: float
duration: float
def get_audio_duration(wav_path: str) -> float:
"""Get duration of audio file in seconds using ffprobe."""
cmd = [
"ffprobe", "-v", "quiet",
"-show_entries", "format=duration",
"-of", "csv=p=0",
wav_path,
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
if result.returncode != 0:
raise RuntimeError(f"ffprobe failed: {result.stderr[:300]}")
return float(result.stdout.strip())
def chunk_audio(wav_path: str, max_duration_sec: int = 1800) -> list[ChunkInfo]:
"""Split audio into chunks if longer than max_duration_sec.
Args:
wav_path: Path to the preprocessed WAV file.
max_duration_sec: Maximum chunk duration in seconds (default 30 min).
Returns:
List of ChunkInfo with paths and timing metadata.
"""
total_duration = get_audio_duration(wav_path)
logger.info("Audio duration: %.1f sec", total_duration)
if total_duration <= max_duration_sec:
return [ChunkInfo(path=wav_path, start_offset=0.0, duration=total_duration)]
chunks = []
src = Path(wav_path)
chunk_dir = src.parent / "chunks"
chunk_dir.mkdir(exist_ok=True)
offset = 0.0
idx = 0
while offset < total_duration:
chunk_path = str(chunk_dir / f"{src.stem}_chunk{idx:03d}.wav")
remaining = total_duration - offset
duration = min(max_duration_sec, remaining)
cmd = [
"ffmpeg", "-y",
"-ss", str(offset),
"-i", wav_path,
"-t", str(duration),
"-c", "copy",
chunk_path,
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
if result.returncode != 0:
raise RuntimeError(f"Chunk {idx} failed: {result.stderr[:300]}")
chunks.append(ChunkInfo(
path=chunk_path,
start_offset=offset,
duration=duration,
))
logger.info("Chunk %d: %.1fs - %.1fs", idx, offset, offset + duration)
offset += duration
idx += 1
return chunks
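
A small sketch of how chunk_audio is meant to be consumed; the WAV path is hypothetical, and ffmpeg/ffprobe must be available on PATH. The key detail is start_offset, which callers add back to chunk-local timestamps.

```python
from transcriber.audio.chunking import chunk_audio, get_audio_duration

# Hypothetical preprocessed WAV; files at or under max_duration_sec come back
# as a single ChunkInfo pointing at the original file.
wav = "output/meeting_processed.wav"
print(f"total: {get_audio_duration(wav):.1f}s")

for chunk in chunk_audio(wav, max_duration_sec=1800):
    # start_offset is what the pipeline later adds to chunk-local timestamps.
    print(f"{chunk.path}: offset {chunk.start_offset:.0f}s, {chunk.duration:.0f}s long")
```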

54
transcriber/audio/preprocess.py Normal file

@@ -0,0 +1,54 @@
import logging
import subprocess
from pathlib import Path
logger = logging.getLogger(__name__)
SUPPORTED_FORMATS = {".m4a", ".mp3", ".wav", ".aac"}
def preprocess_audio(input_path: str, output_dir: str) -> str:
"""Convert audio to mono 16kHz PCM WAV with normalization and DC offset removal.
Args:
input_path: Path to the source audio file.
output_dir: Directory for the processed file.
Returns:
Path to the processed WAV file.
Raises:
FileNotFoundError: If input file does not exist.
ValueError: If file format is not supported.
RuntimeError: If ffmpeg processing fails.
"""
src = Path(input_path)
if not src.exists():
raise FileNotFoundError(f"Audio file not found: {input_path}")
if src.suffix.lower() not in SUPPORTED_FORMATS:
raise ValueError(
f"Unsupported format: {src.suffix}. Supported: {SUPPORTED_FORMATS}"
)
out = Path(output_dir) / f"{src.stem}_processed.wav"
out.parent.mkdir(parents=True, exist_ok=True)
cmd = [
"ffmpeg", "-y", "-i", str(src),
"-ac", "1",
"-ar", "16000",
"-sample_fmt", "s16",
"-af", "highpass=f=10,loudnorm=I=-16:TP=-1.5:LRA=11",
str(out),
]
logger.info("Preprocessing: %s -> %s", src.name, out.name)
result = subprocess.run(
cmd, capture_output=True, text=True, timeout=600
)
if result.returncode != 0:
raise RuntimeError(f"ffmpeg failed: {result.stderr[:500]}")
logger.info("Preprocessing complete: %s", out.name)
return str(out)
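
A usage sketch for preprocess_audio; "interview.m4a" is a hypothetical source file. Note that ffmpeg is a system dependency (not listed in requirements.txt) and must be on PATH.

```python
from transcriber.audio.preprocess import preprocess_audio

# Converts to mono 16 kHz s16 WAV with a highpass filter and loudness normalization.
wav_path = preprocess_audio("interview.m4a", output_dir="./output")
print(wav_path)  # ./output/interview_processed.wav
```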

38
transcriber/config.py Normal file

@@ -0,0 +1,38 @@
import os
from dataclasses import dataclass, field
from pathlib import Path
from dotenv import load_dotenv
@dataclass
class TranscriberConfig:
"""Configuration for the transcription pipeline."""
input_path: str = ""
output_dir: str = "./output"
model: str = "large-v3"
device: str = "cuda"
compute_type: str = "float16"
language: str = "ru"
beam_size: int = 5
vad: bool = True
max_speakers: int | None = None
min_speakers: int | None = None
formats: list[str] = field(default_factory=lambda: ["txt", "json"])
pause_threshold: float = 1.5
chunk_duration: int = 1800
hf_token: str = ""
def __post_init__(self):
load_dotenv()
if not self.hf_token:
self.hf_token = os.getenv("HF_TOKEN", "")
if not self.hf_token:
raise ValueError(
"HF_TOKEN is required for pyannote diarization. "
"Set it in .env or pass via --hf-token"
)
if self.device == "cpu":
self.compute_type = "int8"
Path(self.output_dir).mkdir(parents=True, exist_ok=True)
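
A sketch of constructing the config programmatically, assuming HF_TOKEN is available in the environment or a .env file (otherwise __post_init__ raises ValueError); the input path is hypothetical.

```python
from transcriber.config import TranscriberConfig

cfg = TranscriberConfig(input_path="interview.m4a", device="cpu")
print(cfg.compute_type)  # "int8" -- forced automatically when device == "cpu"
print(cfg.output_dir)    # "./output", created by __post_init__ if missing
```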

0
transcriber/diarization/__init__.py Normal file

69
transcriber/diarization/pyannote_engine.py Normal file

@@ -0,0 +1,69 @@
import logging
from dataclasses import dataclass
import torch
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.utils.hook import ProgressHook
logger = logging.getLogger(__name__)
@dataclass
class SpeakerTurn:
"""A single speaker turn with timing."""
start: float
end: float
speaker: str
class DiarizationEngine:
"""Speaker diarization engine based on pyannote.audio."""
def __init__(self, hf_token: str, device: str):
logger.info("Loading diarization pipeline on %s", device)
self._pipeline = Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.1",
token=hf_token,
)
self._device = torch.device(device)
self._pipeline.to(self._device)
def diarize(
self,
audio_path: str,
min_speakers: int | None = None,
max_speakers: int | None = None,
) -> list[SpeakerTurn]:
"""Run speaker diarization on audio file.
Args:
audio_path: Path to WAV file.
min_speakers: Minimum expected number of speakers.
max_speakers: Maximum expected number of speakers.
Returns:
List of speaker turns sorted by start time.
"""
logger.info("Diarizing: %s", audio_path)
kwargs = {}
if min_speakers is not None:
kwargs["min_speakers"] = min_speakers
if max_speakers is not None:
kwargs["max_speakers"] = max_speakers
with ProgressHook() as hook:
diarization = self._pipeline(audio_path, hook=hook, **kwargs)
turns = []
for turn, _, speaker in diarization.itertracks(yield_label=True):
turns.append(SpeakerTurn(
start=turn.start,
end=turn.end,
speaker=speaker,
))
speaker_set = {t.speaker for t in turns}
logger.info("Diarization complete: %d turns, %d speakers", len(turns), len(speaker_set))
return turns
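
A usage sketch for DiarizationEngine; it assumes HF_TOKEN is set and that the token has access to the gated pyannote/speaker-diarization-3.1 model. The WAV path and speaker bounds are placeholders.

```python
import os

from transcriber.diarization.pyannote_engine import DiarizationEngine

engine = DiarizationEngine(hf_token=os.environ["HF_TOKEN"], device="cuda")
turns = engine.diarize("output/meeting_processed.wav", min_speakers=2, max_speakers=4)

for turn in turns:
    # Raw pyannote labels (SPEAKER_00, SPEAKER_01, ...) at this stage.
    print(f"{turn.speaker}: {turn.start:.1f}s - {turn.end:.1f}s")
```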

0
transcriber/export/__init__.py Normal file

34
transcriber/export/json_writer.py Normal file

@@ -0,0 +1,34 @@
import json
from pathlib import Path
from transcriber.merge.aligner import MergedSegment
def write_json(segments: list[MergedSegment], output_path: str) -> str:
"""Export merged segments as a structured JSON file.
Args:
segments: List of merged speaker segments.
output_path: Path to the output .json file.
Returns:
Path to the written file.
"""
path = Path(output_path)
path.parent.mkdir(parents=True, exist_ok=True)
data = [
{
"speaker": seg.speaker,
"start": round(seg.start, 2),
"end": round(seg.end, 2),
"text": seg.text,
}
for seg in segments
]
path.write_text(
json.dumps(data, ensure_ascii=False, indent=2),
encoding="utf-8",
)
return str(path)
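
A small sketch of the JSON shape write_json produces, using invented toy segments; the output path is arbitrary.

```python
from transcriber.export.json_writer import write_json
from transcriber.merge.aligner import MergedSegment

segments = [
    MergedSegment(speaker="Speaker 1", start=0.0, end=4.2, text="Hello, shall we start?"),
    MergedSegment(speaker="Speaker 2", start=4.8, end=7.5, text="Yes, let's go."),
]
write_json(segments, "output/demo.json")
# demo.json: [{"speaker": "Speaker 1", "start": 0.0, "end": 4.2, "text": "Hello, shall we start?"}, ...]
```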

26
transcriber/export/txt_writer.py Normal file

@@ -0,0 +1,26 @@
from pathlib import Path
from transcriber.merge.aligner import MergedSegment
def write_txt(segments: list[MergedSegment], output_path: str) -> str:
"""Export merged segments as a readable dialogue text file.
Args:
segments: List of merged speaker segments.
output_path: Path to the output .txt file.
Returns:
Path to the written file.
"""
path = Path(output_path)
path.parent.mkdir(parents=True, exist_ok=True)
lines = []
for seg in segments:
lines.append(f"[{seg.speaker}]")
lines.append(seg.text)
lines.append("")
path.write_text("\n".join(lines), encoding="utf-8")
return str(path)
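
And the corresponding layout from write_txt on the same toy data: a "[Speaker N]" header line, the merged text, then a blank line per segment.

```python
from transcriber.export.txt_writer import write_txt
from transcriber.merge.aligner import MergedSegment

segments = [
    MergedSegment(speaker="Speaker 1", start=0.0, end=4.2, text="Hello, shall we start?"),
    MergedSegment(speaker="Speaker 2", start=4.8, end=7.5, text="Yes, let's go."),
]
write_txt(segments, "output/demo.txt")
# demo.txt:
# [Speaker 1]
# Hello, shall we start?
#
# [Speaker 2]
# Yes, let's go.
```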

70
transcriber/main.py Normal file

@@ -0,0 +1,70 @@
import argparse
import logging
import sys
from transcriber.config import TranscriberConfig
from transcriber.pipeline import TranscriptionPipeline
def parse_args() -> argparse.Namespace:
"""Parse command-line arguments."""
parser = argparse.ArgumentParser(
description="Transcribe audio with speaker diarization",
)
parser.add_argument("input", help="Path to audio file (.m4a, .mp3, .wav, .aac)")
parser.add_argument("--output", default="./output", help="Output directory (default: ./output)")
parser.add_argument("--model", default="large-v3", help="Whisper model name (default: large-v3)")
parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"], help="Device (default: cuda)")
parser.add_argument("--compute-type", default="float16", help="Compute type (default: float16)")
parser.add_argument("--language", default="ru", help="Language code (default: ru)")
parser.add_argument("--beam-size", type=int, default=5, help="Beam search size (default: 5)")
parser.add_argument("--vad", default="on", choices=["on", "off"], help="VAD filter (default: on)")
parser.add_argument("--max-speakers", type=int, default=None, help="Maximum number of speakers")
parser.add_argument("--min-speakers", type=int, default=None, help="Minimum number of speakers")
parser.add_argument("--format", nargs="+", default=["txt", "json"], help="Output formats (default: txt json)")
parser.add_argument("--pause-threshold", type=float, default=1.5, help="Max pause for merging (default: 1.5s)")
parser.add_argument("--chunk-duration", type=int, default=1800, help="Max chunk duration in sec (default: 1800)")
parser.add_argument("--hf-token", default="", help="HuggingFace token (default: from .env)")
parser.add_argument("--verbose", "-v", action="store_true", help="Enable debug logging")
return parser.parse_args()
def main() -> None:
"""Entry point for the transcription CLI."""
args = parse_args()
logging.basicConfig(
level=logging.DEBUG if args.verbose else logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
datefmt="%H:%M:%S",
)
config = TranscriberConfig(
input_path=args.input,
output_dir=args.output,
model=args.model,
device=args.device,
compute_type=args.compute_type,
language=args.language,
beam_size=args.beam_size,
vad=args.vad == "on",
max_speakers=args.max_speakers,
min_speakers=args.min_speakers,
formats=args.format,
pause_threshold=args.pause_threshold,
chunk_duration=args.chunk_duration,
hf_token=args.hf_token,
)
pipeline = TranscriptionPipeline(config)
try:
exported = pipeline.run()
for path in exported:
print(f"Saved: {path}")
except Exception:
logging.exception("Pipeline failed")
sys.exit(1)
if __name__ == "__main__":
main()

0
transcriber/merge/__init__.py Normal file

101
transcriber/merge/aligner.py Normal file

@@ -0,0 +1,101 @@
import logging
from dataclasses import dataclass
from transcriber.asr.whisper_engine import Segment
from transcriber.diarization.pyannote_engine import SpeakerTurn
logger = logging.getLogger(__name__)
@dataclass
class MergedSegment:
"""Final segment with speaker label and merged text."""
speaker: str
start: float
end: float
text: str
def _compute_overlap(seg_start: float, seg_end: float, turn_start: float, turn_end: float) -> float:
"""Compute temporal overlap between two intervals in seconds."""
overlap_start = max(seg_start, turn_start)
overlap_end = min(seg_end, turn_end)
return max(0.0, overlap_end - overlap_start)
def _assign_speaker(segment: Segment, speaker_turns: list[SpeakerTurn]) -> str:
"""Assign speaker to ASR segment by maximum overlap."""
best_speaker = "Unknown"
best_overlap = 0.0
for turn in speaker_turns:
if turn.end < segment.start:
continue
if turn.start > segment.end:
break
overlap = _compute_overlap(segment.start, segment.end, turn.start, turn.end)
if overlap > best_overlap:
best_overlap = overlap
best_speaker = turn.speaker
return best_speaker
def _normalize_speaker_labels(segments: list[MergedSegment]) -> list[MergedSegment]:
"""Replace raw pyannote labels (SPEAKER_00) with sequential Speaker 1, Speaker 2."""
label_map: dict[str, str] = {}
counter = 1
for seg in segments:
if seg.speaker not in label_map and seg.speaker != "Unknown":
label_map[seg.speaker] = f"Speaker {counter}"
counter += 1
for seg in segments:
seg.speaker = label_map.get(seg.speaker, seg.speaker)
return segments
def align_and_merge(
asr_segments: list[Segment],
speaker_turns: list[SpeakerTurn],
pause_threshold: float = 1.5,
) -> list[MergedSegment]:
"""Align ASR segments with speaker turns and merge adjacent same-speaker segments.
Args:
asr_segments: Segments from Whisper ASR.
speaker_turns: Speaker turns from diarization (sorted by start).
pause_threshold: Max pause between segments to merge (seconds).
Returns:
List of merged segments with speaker labels.
"""
if not asr_segments:
return []
aligned = []
for seg in asr_segments:
speaker = _assign_speaker(seg, speaker_turns)
aligned.append(MergedSegment(
speaker=speaker,
start=seg.start,
end=seg.end,
text=seg.text,
))
merged: list[MergedSegment] = [aligned[0]]
for seg in aligned[1:]:
prev = merged[-1]
gap = seg.start - prev.end
if seg.speaker == prev.speaker and gap <= pause_threshold:
prev.end = seg.end
prev.text = f"{prev.text} {seg.text}"
else:
merged.append(seg)
merged = _normalize_speaker_labels(merged)
logger.info("Alignment complete: %d ASR segments -> %d merged", len(asr_segments), len(merged))
return merged
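
A toy end-to-end check of align_and_merge: two same-speaker ASR segments separated by a 0.6 s pause collapse into one merged segment, the third segment goes to the other speaker, and raw labels are renamed to Speaker 1 / Speaker 2. The data is invented purely for illustration.

```python
from transcriber.asr.whisper_engine import Segment
from transcriber.diarization.pyannote_engine import SpeakerTurn
from transcriber.merge.aligner import align_and_merge

asr = [
    Segment(start=0.0, end=2.0, text="Good morning everyone."),
    Segment(start=2.6, end=4.0, text="Let's begin."),
    Segment(start=4.5, end=6.0, text="Thanks, I have one question."),
]
turns = [
    SpeakerTurn(start=0.0, end=4.2, speaker="SPEAKER_00"),
    SpeakerTurn(start=4.2, end=6.5, speaker="SPEAKER_01"),
]

for seg in align_and_merge(asr, turns, pause_threshold=1.5):
    print(seg.speaker, "->", seg.text)
# Speaker 1 -> Good morning everyone. Let's begin.
# Speaker 2 -> Thanks, I have one question.
```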

136
transcriber/pipeline.py Normal file

@@ -0,0 +1,136 @@
import logging
from pathlib import Path
from tqdm import tqdm
from transcriber.asr.whisper_engine import Segment, WhisperEngine
from transcriber.audio.chunking import ChunkInfo, chunk_audio
from transcriber.audio.preprocess import preprocess_audio
from transcriber.config import TranscriberConfig
from transcriber.diarization.pyannote_engine import DiarizationEngine
from transcriber.export.json_writer import write_json
from transcriber.export.txt_writer import write_txt
from transcriber.merge.aligner import MergedSegment, align_and_merge
logger = logging.getLogger(__name__)
EXPORTERS = {
"txt": write_txt,
"json": write_json,
}
class TranscriptionPipeline:
"""Orchestrates the full transcription pipeline: preprocess -> ASR -> diarize -> merge -> export."""
def __init__(self, config: TranscriberConfig):
self._config = config
self._asr: WhisperEngine | None = None
self._diarizer: DiarizationEngine | None = None
def _init_engines(self) -> None:
"""Lazily initialize ASR and diarization engines."""
if self._asr is None:
self._asr = WhisperEngine(
model_name=self._config.model,
device=self._config.device,
compute_type=self._config.compute_type,
)
if self._diarizer is None:
self._diarizer = DiarizationEngine(
hf_token=self._config.hf_token,
device=self._config.device,
)
def _transcribe_chunks(self, chunks: list[ChunkInfo], progress: tqdm) -> list[Segment]:
"""Run ASR on each chunk, adjusting timestamps by chunk offset."""
all_segments: list[Segment] = []
for chunk in chunks:
progress.set_description(f"ASR chunk {chunk.start_offset:.0f}s")
segments = self._asr.transcribe(
audio_path=chunk.path,
language=self._config.language,
beam_size=self._config.beam_size,
vad_filter=self._config.vad,
)
for seg in segments:
seg.start += chunk.start_offset
seg.end += chunk.start_offset
for w in seg.words:
w.start += chunk.start_offset
w.end += chunk.start_offset
all_segments.extend(segments)
progress.update(1)
return all_segments
def _export(self, segments: list[MergedSegment], stem: str) -> list[str]:
"""Export segments to requested formats."""
output_dir = Path(self._config.output_dir)
exported = []
for fmt in self._config.formats:
exporter = EXPORTERS.get(fmt)
if exporter is None:
logger.warning("Unknown export format: %s (skipped)", fmt)
continue
out_path = str(output_dir / f"{stem}.{fmt}")
exporter(segments, out_path)
exported.append(out_path)
logger.info("Exported: %s", out_path)
return exported
def run(self) -> list[str]:
"""Execute the full pipeline and return list of exported file paths.
Returns:
List of paths to exported files.
Raises:
FileNotFoundError: If input file does not exist.
RuntimeError: If any pipeline stage fails.
"""
cfg = self._config
stem = Path(cfg.input_path).stem
total_steps = 7
progress = tqdm(total=total_steps, desc="Pipeline", unit="step")
progress.set_description("Preprocessing")
wav_path = preprocess_audio(cfg.input_path, cfg.output_dir)
progress.update(1)
progress.set_description("Chunking")
chunks = chunk_audio(wav_path, cfg.chunk_duration)
progress.update(1)
progress.set_description("Loading models")
self._init_engines()
progress.update(1)
asr_progress = tqdm(total=len(chunks), desc="ASR", unit="chunk", leave=False)
asr_segments = self._transcribe_chunks(chunks, asr_progress)
asr_progress.close()
progress.update(1)
progress.set_description("Diarizing")
speaker_turns = self._diarizer.diarize(
audio_path=wav_path,
min_speakers=cfg.min_speakers,
max_speakers=cfg.max_speakers,
)
progress.update(1)
progress.set_description("Aligning")
merged = align_and_merge(
asr_segments=asr_segments,
speaker_turns=speaker_turns,
pause_threshold=cfg.pause_threshold,
)
progress.update(1)
progress.set_description("Exporting")
exported = self._export(merged, stem)
progress.update(1)
progress.close()
logger.info("Pipeline complete. Files: %s", exported)
return exported
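
A minimal programmatic entry point, roughly what main() does after parsing CLI arguments; the input file is hypothetical and HF_TOKEN must be set in the environment or a .env file.

```python
from transcriber.config import TranscriberConfig
from transcriber.pipeline import TranscriptionPipeline

config = TranscriberConfig(input_path="interview.m4a", output_dir="./output", language="ru")
pipeline = TranscriptionPipeline(config)

for path in pipeline.run():
    print("Saved:", path)
```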