first version

parent 3400a8070a
commit de07a045ce
@@ -0,0 +1,5 @@
faster-whisper
pyannote.audio
python-dotenv
pydub
tqdm
@@ -0,0 +1,4 @@
from transcriber.main import main

if __name__ == "__main__":
    main()
@@ -0,0 +1,93 @@
import logging
from dataclasses import dataclass, field

from faster_whisper import WhisperModel

logger = logging.getLogger(__name__)


@dataclass
class WordInfo:
    """Single word with timing and confidence."""

    word: str
    start: float
    end: float
    probability: float


@dataclass
class Segment:
    """ASR segment with text, timing, and optional word-level detail."""

    start: float
    end: float
    text: str
    words: list[WordInfo] = field(default_factory=list)


class WhisperEngine:
    """Speech recognition engine based on faster-whisper."""

    def __init__(self, model_name: str, device: str, compute_type: str):
        logger.info("Loading Whisper model: %s on %s (%s)", model_name, device, compute_type)
        self._model = WhisperModel(model_name, device=device, compute_type=compute_type)

    def transcribe(
        self,
        audio_path: str,
        language: str | None = None,
        beam_size: int = 5,
        vad_filter: bool = True,
    ) -> list[Segment]:
        """Transcribe audio file and return list of segments.

        Args:
            audio_path: Path to WAV file.
            language: Language code or None for auto-detection.
            beam_size: Beam search width.
            vad_filter: Whether to enable VAD filtering.

        Returns:
            List of transcription segments with word-level timestamps.
        """
        logger.info("Transcribing: %s", audio_path)

        segments_gen, info = self._model.transcribe(
            audio_path,
            language=language,
            beam_size=beam_size,
            word_timestamps=True,
            vad_filter=vad_filter,
            vad_parameters={"min_silence_duration_ms": 500},
            temperature=0.0,
            condition_on_previous_text=False,
            no_speech_threshold=0.6,
            log_prob_threshold=-1.0,
        )

        logger.info(
            "Detected language: %s (%.2f), duration: %.1fs",
            info.language, info.language_probability, info.duration,
        )

        results = []
        for seg in segments_gen:
            words = [
                WordInfo(
                    word=w.word.strip(),
                    start=w.start,
                    end=w.end,
                    probability=w.probability,
                )
                for w in (seg.words or [])
            ]
            results.append(Segment(
                start=seg.start,
                end=seg.end,
                text=seg.text.strip(),
                words=words,
            ))

        logger.info("Transcription complete: %d segments", len(results))
        return results
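For orientation, a minimal sketch of driving the ASR engine on its own; the model size, device, compute type, and file name below are illustrative assumptions, not values taken from this commit:

from transcriber.asr.whisper_engine import WhisperEngine

# Hypothetical standalone run on CPU with a small model; the pipeline itself
# defaults to large-v3 on CUDA via TranscriberConfig.
engine = WhisperEngine(model_name="small", device="cpu", compute_type="int8")
segments = engine.transcribe("meeting_processed.wav", language=None, beam_size=5)
for seg in segments:
    print(f"[{seg.start:7.2f} - {seg.end:7.2f}] {seg.text}")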
@@ -0,0 +1,82 @@
import logging
import subprocess
from dataclasses import dataclass
from pathlib import Path

logger = logging.getLogger(__name__)


@dataclass
class ChunkInfo:
    """Metadata for a single audio chunk."""

    path: str
    start_offset: float
    duration: float


def get_audio_duration(wav_path: str) -> float:
    """Get duration of audio file in seconds using ffprobe."""
    cmd = [
        "ffprobe", "-v", "quiet",
        "-show_entries", "format=duration",
        "-of", "csv=p=0",
        wav_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
    if result.returncode != 0:
        raise RuntimeError(f"ffprobe failed: {result.stderr[:300]}")
    return float(result.stdout.strip())


def chunk_audio(wav_path: str, max_duration_sec: int = 1800) -> list[ChunkInfo]:
    """Split audio into chunks if longer than max_duration_sec.

    Args:
        wav_path: Path to the preprocessed WAV file.
        max_duration_sec: Maximum chunk duration in seconds (default 30 min).

    Returns:
        List of ChunkInfo with paths and timing metadata.
    """
    total_duration = get_audio_duration(wav_path)
    logger.info("Audio duration: %.1f sec", total_duration)

    if total_duration <= max_duration_sec:
        return [ChunkInfo(path=wav_path, start_offset=0.0, duration=total_duration)]

    chunks = []
    src = Path(wav_path)
    chunk_dir = src.parent / "chunks"
    chunk_dir.mkdir(exist_ok=True)

    offset = 0.0
    idx = 0
    while offset < total_duration:
        chunk_path = str(chunk_dir / f"{src.stem}_chunk{idx:03d}.wav")
        remaining = total_duration - offset
        duration = min(max_duration_sec, remaining)

        cmd = [
            "ffmpeg", "-y",
            "-ss", str(offset),
            "-i", wav_path,
            "-t", str(duration),
            "-c", "copy",
            chunk_path,
        ]
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
        if result.returncode != 0:
            raise RuntimeError(f"Chunk {idx} failed: {result.stderr[:300]}")

        chunks.append(ChunkInfo(
            path=chunk_path,
            start_offset=offset,
            duration=duration,
        ))
        logger.info("Chunk %d: %.1fs - %.1fs", idx, offset, offset + duration)

        offset += duration
        idx += 1

    return chunks
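A quick sketch of how these chunks are meant to be consumed downstream; the file name and duration are assumptions for illustration:

from transcriber.audio.chunking import chunk_audio

# A hypothetical 90-minute WAV would yield three 30-minute chunks here.
chunks = chunk_audio("long_recording_processed.wav", max_duration_sec=1800)
for c in chunks:
    # start_offset is what the pipeline later adds back to every ASR timestamp
    # produced from this chunk, so all segments share one global timeline.
    print(c.path, c.start_offset, c.duration)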
@@ -0,0 +1,54 @@
import logging
import subprocess
from pathlib import Path

logger = logging.getLogger(__name__)

SUPPORTED_FORMATS = {".m4a", ".mp3", ".wav", ".aac"}


def preprocess_audio(input_path: str, output_dir: str) -> str:
    """Convert audio to mono 16kHz PCM WAV with normalization and DC offset removal.

    Args:
        input_path: Path to the source audio file.
        output_dir: Directory for the processed file.

    Returns:
        Path to the processed WAV file.

    Raises:
        FileNotFoundError: If input file does not exist.
        ValueError: If file format is not supported.
        RuntimeError: If ffmpeg processing fails.
    """
    src = Path(input_path)
    if not src.exists():
        raise FileNotFoundError(f"Audio file not found: {input_path}")
    if src.suffix.lower() not in SUPPORTED_FORMATS:
        raise ValueError(
            f"Unsupported format: {src.suffix}. Supported: {SUPPORTED_FORMATS}"
        )

    out = Path(output_dir) / f"{src.stem}_processed.wav"
    out.parent.mkdir(parents=True, exist_ok=True)

    cmd = [
        "ffmpeg", "-y", "-i", str(src),
        "-ac", "1",
        "-ar", "16000",
        "-sample_fmt", "s16",
        "-af", "highpass=f=10,loudnorm=I=-16:TP=-1.5:LRA=11",
        str(out),
    ]

    logger.info("Preprocessing: %s -> %s", src.name, out.name)

    result = subprocess.run(
        cmd, capture_output=True, text=True, timeout=600
    )
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg failed: {result.stderr[:500]}")

    logger.info("Preprocessing complete: %s", out.name)
    return str(out)
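A minimal usage sketch, assuming ffmpeg is available on PATH and using an invented input file name:

from transcriber.audio.preprocess import preprocess_audio

# Converts e.g. an .m4a voice memo to mono 16 kHz PCM WAV; the high-pass filter
# handles DC offset removal and loudnorm applies EBU R128 loudness normalization.
wav_path = preprocess_audio("interview.m4a", "./output")
print(wav_path)  # ./output/interview_processed.wav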
@@ -0,0 +1,38 @@
import os
from dataclasses import dataclass, field
from pathlib import Path

from dotenv import load_dotenv


@dataclass
class TranscriberConfig:
    """Configuration for the transcription pipeline."""

    input_path: str = ""
    output_dir: str = "./output"
    model: str = "large-v3"
    device: str = "cuda"
    compute_type: str = "float16"
    language: str = "ru"
    beam_size: int = 5
    vad: bool = True
    max_speakers: int | None = None
    min_speakers: int | None = None
    formats: list[str] = field(default_factory=lambda: ["txt", "json"])
    pause_threshold: float = 1.5
    chunk_duration: int = 1800
    hf_token: str = ""

    def __post_init__(self):
        load_dotenv()
        if not self.hf_token:
            self.hf_token = os.getenv("HF_TOKEN", "")
        if not self.hf_token:
            raise ValueError(
                "HF_TOKEN is required for pyannote diarization. "
                "Set it in .env or pass via --hf-token"
            )
        if self.device == "cpu":
            self.compute_type = "int8"
        Path(self.output_dir).mkdir(parents=True, exist_ok=True)
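A small sketch of building the config directly from Python, bypassing the CLI; it assumes HF_TOKEN is already set in the environment or a local .env file, and the input file name is invented:

from transcriber.config import TranscriberConfig

# __post_init__ loads .env, insists on an HF token, forces compute_type to
# "int8" when device is "cpu", and creates the output directory.
cfg = TranscriberConfig(input_path="interview.m4a", device="cpu")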
@@ -0,0 +1,69 @@
import logging
from dataclasses import dataclass

import torch
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.utils.hook import ProgressHook

logger = logging.getLogger(__name__)


@dataclass
class SpeakerTurn:
    """A single speaker turn with timing."""

    start: float
    end: float
    speaker: str


class DiarizationEngine:
    """Speaker diarization engine based on pyannote.audio."""

    def __init__(self, hf_token: str, device: str):
        logger.info("Loading diarization pipeline on %s", device)
        self._pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            token=hf_token,
        )
        self._device = torch.device(device)
        self._pipeline.to(self._device)

    def diarize(
        self,
        audio_path: str,
        min_speakers: int | None = None,
        max_speakers: int | None = None,
    ) -> list[SpeakerTurn]:
        """Run speaker diarization on audio file.

        Args:
            audio_path: Path to WAV file.
            min_speakers: Minimum expected number of speakers.
            max_speakers: Maximum expected number of speakers.

        Returns:
            List of speaker turns sorted by start time.
        """
        logger.info("Diarizing: %s", audio_path)

        kwargs = {}
        if min_speakers is not None:
            kwargs["min_speakers"] = min_speakers
        if max_speakers is not None:
            kwargs["max_speakers"] = max_speakers

        with ProgressHook() as hook:
            diarization = self._pipeline(audio_path, hook=hook, **kwargs)

        turns = []
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            turns.append(SpeakerTurn(
                start=turn.start,
                end=turn.end,
                speaker=speaker,
            ))

        speaker_set = {t.speaker for t in turns}
        logger.info("Diarization complete: %d turns, %d speakers", len(turns), len(speaker_set))
        return turns
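For illustration, a rough standalone invocation; the device, speaker bounds, and file name are assumptions, and the token must grant access to the gated pyannote/speaker-diarization-3.1 model:

import os

from transcriber.diarization.pyannote_engine import DiarizationEngine

engine = DiarizationEngine(hf_token=os.environ["HF_TOKEN"], device="cpu")
turns = engine.diarize("interview_processed.wav", min_speakers=2, max_speakers=4)
for t in turns[:5]:
    print(f"{t.speaker}: {t.start:.1f}s - {t.end:.1f}s")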
@@ -0,0 +1,34 @@
import json
from pathlib import Path

from transcriber.merge.aligner import MergedSegment


def write_json(segments: list[MergedSegment], output_path: str) -> str:
    """Export merged segments as a structured JSON file.

    Args:
        segments: List of merged speaker segments.
        output_path: Path to the output .json file.

    Returns:
        Path to the written file.
    """
    path = Path(output_path)
    path.parent.mkdir(parents=True, exist_ok=True)

    data = [
        {
            "speaker": seg.speaker,
            "start": round(seg.start, 2),
            "end": round(seg.end, 2),
            "text": seg.text,
        }
        for seg in segments
    ]

    path.write_text(
        json.dumps(data, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    return str(path)
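As a sketch of the resulting file shape, with an invented segment and output path:

from transcriber.export.json_writer import write_json
from transcriber.merge.aligner import MergedSegment

# Writes a JSON array of {"speaker", "start", "end", "text"} objects.
write_json(
    [MergedSegment(speaker="Speaker 1", start=0.0, end=4.0, text="Hello there.")],
    "./output/demo.json",
)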
@@ -0,0 +1,26 @@
from pathlib import Path

from transcriber.merge.aligner import MergedSegment


def write_txt(segments: list[MergedSegment], output_path: str) -> str:
    """Export merged segments as a readable dialogue text file.

    Args:
        segments: List of merged speaker segments.
        output_path: Path to the output .txt file.

    Returns:
        Path to the written file.
    """
    path = Path(output_path)
    path.parent.mkdir(parents=True, exist_ok=True)

    lines = []
    for seg in segments:
        lines.append(f"[{seg.speaker}]")
        lines.append(seg.text)
        lines.append("")

    path.write_text("\n".join(lines), encoding="utf-8")
    return str(path)
@@ -0,0 +1,70 @@
import argparse
import logging
import sys

from transcriber.config import TranscriberConfig
from transcriber.pipeline import TranscriptionPipeline


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Transcribe audio with speaker diarization",
    )
    parser.add_argument("input", help="Path to audio file (.m4a, .mp3, .wav, .aac)")
    parser.add_argument("--output", default="./output", help="Output directory (default: ./output)")
    parser.add_argument("--model", default="large-v3", help="Whisper model name (default: large-v3)")
    parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"], help="Device (default: cuda)")
    parser.add_argument("--compute-type", default="float16", help="Compute type (default: float16)")
    parser.add_argument("--language", default="ru", help="Language code (default: ru)")
    parser.add_argument("--beam-size", type=int, default=5, help="Beam search size (default: 5)")
    parser.add_argument("--vad", default="on", choices=["on", "off"], help="VAD filter (default: on)")
    parser.add_argument("--max-speakers", type=int, default=None, help="Maximum number of speakers")
    parser.add_argument("--min-speakers", type=int, default=None, help="Minimum number of speakers")
    parser.add_argument("--format", nargs="+", default=["txt", "json"], help="Output formats (default: txt json)")
    parser.add_argument("--pause-threshold", type=float, default=1.5, help="Max pause for merging (default: 1.5s)")
    parser.add_argument("--chunk-duration", type=int, default=1800, help="Max chunk duration in sec (default: 1800)")
    parser.add_argument("--hf-token", default="", help="HuggingFace token (default: from .env)")
    parser.add_argument("--verbose", "-v", action="store_true", help="Enable debug logging")
    return parser.parse_args()


def main() -> None:
    """Entry point for the transcription CLI."""
    args = parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%H:%M:%S",
    )

    config = TranscriberConfig(
        input_path=args.input,
        output_dir=args.output,
        model=args.model,
        device=args.device,
        compute_type=args.compute_type,
        language=args.language,
        beam_size=args.beam_size,
        vad=args.vad == "on",
        max_speakers=args.max_speakers,
        min_speakers=args.min_speakers,
        formats=args.format,
        pause_threshold=args.pause_threshold,
        chunk_duration=args.chunk_duration,
        hf_token=args.hf_token,
    )

    pipeline = TranscriptionPipeline(config)
    try:
        exported = pipeline.run()
        for path in exported:
            print(f"Saved: {path}")
    except Exception:
        logging.exception("Pipeline failed")
        sys.exit(1)


if __name__ == "__main__":
    main()
@@ -0,0 +1,101 @@
import logging
from dataclasses import dataclass

from transcriber.asr.whisper_engine import Segment
from transcriber.diarization.pyannote_engine import SpeakerTurn

logger = logging.getLogger(__name__)


@dataclass
class MergedSegment:
    """Final segment with speaker label and merged text."""

    speaker: str
    start: float
    end: float
    text: str


def _compute_overlap(seg_start: float, seg_end: float, turn_start: float, turn_end: float) -> float:
    """Compute temporal overlap between two intervals in seconds."""
    overlap_start = max(seg_start, turn_start)
    overlap_end = min(seg_end, turn_end)
    return max(0.0, overlap_end - overlap_start)


def _assign_speaker(segment: Segment, speaker_turns: list[SpeakerTurn]) -> str:
    """Assign speaker to ASR segment by maximum overlap."""
    best_speaker = "Unknown"
    best_overlap = 0.0

    for turn in speaker_turns:
        if turn.end < segment.start:
            continue
        if turn.start > segment.end:
            break
        overlap = _compute_overlap(segment.start, segment.end, turn.start, turn.end)
        if overlap > best_overlap:
            best_overlap = overlap
            best_speaker = turn.speaker

    return best_speaker


def _normalize_speaker_labels(segments: list[MergedSegment]) -> list[MergedSegment]:
    """Replace raw pyannote labels (SPEAKER_00) with sequential Speaker 1, Speaker 2."""
    label_map: dict[str, str] = {}
    counter = 1

    for seg in segments:
        if seg.speaker not in label_map and seg.speaker != "Unknown":
            label_map[seg.speaker] = f"Speaker {counter}"
            counter += 1

    for seg in segments:
        seg.speaker = label_map.get(seg.speaker, seg.speaker)

    return segments


def align_and_merge(
    asr_segments: list[Segment],
    speaker_turns: list[SpeakerTurn],
    pause_threshold: float = 1.5,
) -> list[MergedSegment]:
    """Align ASR segments with speaker turns and merge adjacent same-speaker segments.

    Args:
        asr_segments: Segments from Whisper ASR.
        speaker_turns: Speaker turns from diarization (sorted by start).
        pause_threshold: Max pause between segments to merge (seconds).

    Returns:
        List of merged segments with speaker labels.
    """
    if not asr_segments:
        return []

    aligned = []
    for seg in asr_segments:
        speaker = _assign_speaker(seg, speaker_turns)
        aligned.append(MergedSegment(
            speaker=speaker,
            start=seg.start,
            end=seg.end,
            text=seg.text,
        ))

    merged: list[MergedSegment] = [aligned[0]]
    for seg in aligned[1:]:
        prev = merged[-1]
        gap = seg.start - prev.end
        if seg.speaker == prev.speaker and gap <= pause_threshold:
            prev.end = seg.end
            prev.text = f"{prev.text} {seg.text}"
        else:
            merged.append(seg)

    merged = _normalize_speaker_labels(merged)
    logger.info("Alignment complete: %d ASR segments -> %d merged", len(asr_segments), len(merged))
    return merged
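To make the overlap assignment and pause-based merging concrete, a small toy example with fabricated timings (none of these values come from a real run):

from transcriber.asr.whisper_engine import Segment
from transcriber.diarization.pyannote_engine import SpeakerTurn
from transcriber.merge.aligner import align_and_merge

segments = [
    Segment(start=0.0, end=2.0, text="Hello there."),
    Segment(start=2.4, end=4.0, text="How are you?"),  # same speaker, 0.4 s gap -> merged
    Segment(start=4.5, end=6.0, text="Fine, thanks."),
]
turns = [
    SpeakerTurn(start=0.0, end=4.1, speaker="SPEAKER_00"),
    SpeakerTurn(start=4.2, end=6.5, speaker="SPEAKER_01"),
]

for m in align_and_merge(segments, turns, pause_threshold=1.5):
    print(m.speaker, f"{m.start:.1f}-{m.end:.1f}", m.text)

# Expected output:
# Speaker 1 0.0-4.0 Hello there. How are you?
# Speaker 2 4.5-6.0 Fine, thanks.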
@@ -0,0 +1,136 @@
import logging
from pathlib import Path

from tqdm import tqdm

from transcriber.asr.whisper_engine import Segment, WhisperEngine
from transcriber.audio.chunking import ChunkInfo, chunk_audio
from transcriber.audio.preprocess import preprocess_audio
from transcriber.config import TranscriberConfig
from transcriber.diarization.pyannote_engine import DiarizationEngine
from transcriber.export.json_writer import write_json
from transcriber.export.txt_writer import write_txt
from transcriber.merge.aligner import MergedSegment, align_and_merge

logger = logging.getLogger(__name__)

EXPORTERS = {
    "txt": write_txt,
    "json": write_json,
}


class TranscriptionPipeline:
    """Orchestrates the full transcription pipeline: preprocess -> ASR -> diarize -> merge -> export."""

    def __init__(self, config: TranscriberConfig):
        self._config = config
        self._asr: WhisperEngine | None = None
        self._diarizer: DiarizationEngine | None = None

    def _init_engines(self) -> None:
        """Lazily initialize ASR and diarization engines."""
        if self._asr is None:
            self._asr = WhisperEngine(
                model_name=self._config.model,
                device=self._config.device,
                compute_type=self._config.compute_type,
            )
        if self._diarizer is None:
            self._diarizer = DiarizationEngine(
                hf_token=self._config.hf_token,
                device=self._config.device,
            )

    def _transcribe_chunks(self, chunks: list[ChunkInfo], progress: tqdm) -> list[Segment]:
        """Run ASR on each chunk, adjusting timestamps by chunk offset."""
        all_segments: list[Segment] = []
        for chunk in chunks:
            progress.set_description(f"ASR chunk {chunk.start_offset:.0f}s")
            segments = self._asr.transcribe(
                audio_path=chunk.path,
                language=self._config.language,
                beam_size=self._config.beam_size,
                vad_filter=self._config.vad,
            )
            for seg in segments:
                seg.start += chunk.start_offset
                seg.end += chunk.start_offset
                for w in seg.words:
                    w.start += chunk.start_offset
                    w.end += chunk.start_offset
            all_segments.extend(segments)
            progress.update(1)
        return all_segments

    def _export(self, segments: list[MergedSegment], stem: str) -> list[str]:
        """Export segments to requested formats."""
        output_dir = Path(self._config.output_dir)
        exported = []
        for fmt in self._config.formats:
            exporter = EXPORTERS.get(fmt)
            if exporter is None:
                logger.warning("Unknown export format: %s (skipped)", fmt)
                continue
            out_path = str(output_dir / f"{stem}.{fmt}")
            exporter(segments, out_path)
            exported.append(out_path)
            logger.info("Exported: %s", out_path)
        return exported

    def run(self) -> list[str]:
        """Execute the full pipeline and return list of exported file paths.

        Returns:
            List of paths to exported files.

        Raises:
            FileNotFoundError: If input file does not exist.
            RuntimeError: If any pipeline stage fails.
        """
        cfg = self._config
        stem = Path(cfg.input_path).stem

        total_steps = 7
        progress = tqdm(total=total_steps, desc="Pipeline", unit="step")

        progress.set_description("Preprocessing")
        wav_path = preprocess_audio(cfg.input_path, cfg.output_dir)
        progress.update(1)

        progress.set_description("Chunking")
        chunks = chunk_audio(wav_path, cfg.chunk_duration)
        progress.update(1)

        progress.set_description("Loading models")
        self._init_engines()
        progress.update(1)

        asr_progress = tqdm(total=len(chunks), desc="ASR", unit="chunk", leave=False)
        asr_segments = self._transcribe_chunks(chunks, asr_progress)
        asr_progress.close()
        progress.update(1)

        progress.set_description("Diarizing")
        speaker_turns = self._diarizer.diarize(
            audio_path=wav_path,
            min_speakers=cfg.min_speakers,
            max_speakers=cfg.max_speakers,
        )
        progress.update(1)

        progress.set_description("Aligning")
        merged = align_and_merge(
            asr_segments=asr_segments,
            speaker_turns=speaker_turns,
            pause_threshold=cfg.pause_threshold,
        )
        progress.update(1)

        progress.set_description("Exporting")
        exported = self._export(merged, stem)
        progress.update(1)

        progress.close()
        logger.info("Pipeline complete. Files: %s", exported)
        return exported
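Finally, a minimal sketch of running the whole pipeline from Python rather than the CLI; the input file name is an assumption and HF_TOKEN is expected in the environment or .env:

from transcriber.config import TranscriberConfig
from transcriber.pipeline import TranscriptionPipeline

# Same path as the CLI: preprocess -> chunk -> ASR -> diarize -> align -> export.
config = TranscriberConfig(input_path="interview.m4a", output_dir="./output", device="cpu")
pipeline = TranscriptionPipeline(config)
print(pipeline.run())  # e.g. ['output/interview.txt', 'output/interview.json']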