transcribe-interview/transcriber/diarization/pyannote_engine.py

import logging
from dataclasses import dataclass

import torch
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.utils.hook import ProgressHook

logger = logging.getLogger(__name__)


@dataclass
class SpeakerTurn:
    """A single speaker turn with timing."""

    start: float
    end: float
    speaker: str


class DiarizationEngine:
    """Speaker diarization engine based on pyannote.audio."""

    def __init__(self, hf_token: str, device: str):
        logger.info("Loading diarization pipeline on %s", device)
        # Load the pretrained pipeline from the Hugging Face Hub, then move it to the
        # requested device (e.g. "cuda" or "cpu").
        self._pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            token=hf_token,
        )
        self._device = torch.device(device)
        self._pipeline.to(self._device)

    def diarize(
        self,
        audio_path: str,
        min_speakers: int | None = None,
        max_speakers: int | None = None,
    ) -> list[SpeakerTurn]:
        """Run speaker diarization on an audio file.

        Args:
            audio_path: Path to WAV file.
            min_speakers: Minimum expected number of speakers.
            max_speakers: Maximum expected number of speakers.

        Returns:
            List of speaker turns sorted by start time.
        """
        logger.info("Diarizing: %s", audio_path)

        # Only forward the speaker-count bounds the caller actually provided.
        kwargs = {}
        if min_speakers is not None:
            kwargs["min_speakers"] = min_speakers
        if max_speakers is not None:
            kwargs["max_speakers"] = max_speakers

        # ProgressHook displays per-stage progress while the pipeline runs.
        with ProgressHook() as hook:
            diarization = self._pipeline(audio_path, hook=hook, **kwargs)

        # Flatten the pyannote annotation into plain, chronologically ordered turns.
        turns = []
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            turns.append(SpeakerTurn(
                start=turn.start,
                end=turn.end,
                speaker=speaker,
            ))

        speaker_set = {t.speaker for t in turns}
        logger.info(
            "Diarization complete: %d turns, %d speakers",
            len(turns), len(speaker_set),
        )
        return turns
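

# A minimal usage sketch (assumptions, not taken from this file: "hf_..." stands in for a
# real Hugging Face access token and "interview.wav" for a local audio file):
#
#     engine = DiarizationEngine(hf_token="hf_...", device="cuda")
#     turns = engine.diarize("interview.wav", min_speakers=2, max_speakers=4)
#     for t in turns:
#         print(f"{t.start:8.2f} {t.end:8.2f}  {t.speaker}")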