import contextlib
import logging
from dataclasses import dataclass

import soundfile as sf
import torch
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.utils.hook import ProgressHook

logger = logging.getLogger(__name__)


@dataclass
class SpeakerTurn:
    """A single speaker turn with timing."""

    # Boundaries copied from the pyannote output segment
    # (presumably seconds, per pyannote's Segment convention — confirm).
    start: float
    end: float
    # Speaker label assigned by the diarization pipeline.
    speaker: str


class DiarizationEngine:
    """Speaker diarization engine based on pyannote.audio.

    The ``pyannote/speaker-diarization-community-1`` pipeline is loaded once
    in ``__init__`` and reused by every :meth:`diarize` call.
    """

    def __init__(self, hf_token: str, device: str):
        """Load the pretrained diarization pipeline onto ``device``.

        Args:
            hf_token: Hugging Face access token used to fetch the model.
            device: torch device spec, e.g. ``"cpu"`` or ``"cuda:0"``.

        Raises:
            RuntimeError: If ``device`` is not a valid torch device string,
                or if the pipeline could not be loaded.
        """
        # Parse the device first so an invalid spec fails before the
        # comparatively slow model download/load.
        self._device = torch.device(device)
        logger.info("Loading diarization pipeline on %s", device)
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-community-1",
            token=hf_token,
        )
        # Some pyannote versions return None (after printing a message)
        # instead of raising when the model cannot be fetched, e.g. for a
        # gated repository or an invalid token.
        if pipeline is None:
            raise RuntimeError(
                "Could not load pyannote/speaker-diarization-community-1; "
                "check the Hugging Face token and model access."
            )
        self._pipeline = pipeline
        self._pipeline.to(self._device)

    @staticmethod
    def _load_audio(audio_path: str) -> dict:
        """Read an audio file into the in-memory dict accepted by pyannote.

        pyannote expects ``waveform`` as a ``(channel, time)`` tensor.
        ``always_2d=True`` makes soundfile return ``(frames, channels)`` for
        both mono and multichannel files, so a transpose gives the required
        layout in every case.
        """
        data, sample_rate = sf.read(audio_path, dtype="float32", always_2d=True)
        waveform = torch.from_numpy(data.T).contiguous()
        return {"waveform": waveform, "sample_rate": sample_rate}

    def diarize(
        self,
        audio_path: str,
        min_speakers: int | None = None,
        max_speakers: int | None = None,
        *,
        show_progress: bool = True,
    ) -> list[SpeakerTurn]:
        """Run speaker diarization on audio file.

        Args:
            audio_path: Path to an audio file readable by soundfile (e.g. WAV).
            min_speakers: Minimum expected number of speakers.
            max_speakers: Maximum expected number of speakers.
            show_progress: Display pyannote's progress bar while the
                pipeline runs.

        Returns:
            List of speaker turns from the pipeline's exclusive
            (non-overlapping) diarization, sorted by start time.
        """
        logger.info("Diarizing: %s", audio_path)
        audio_input = self._load_audio(audio_path)

        # Forward speaker-count hints only when given, so the pipeline's own
        # defaults apply otherwise.
        kwargs: dict[str, int] = {}
        if min_speakers is not None:
            kwargs["min_speakers"] = min_speakers
        if max_speakers is not None:
            kwargs["max_speakers"] = max_speakers

        # nullcontext() yields None, which the pipeline accepts as "no hook".
        progress = ProgressHook() if show_progress else contextlib.nullcontext()
        with progress as hook:
            diarization = self._pipeline(audio_input, hook=hook, **kwargs)

        turns = [
            SpeakerTurn(start=segment.start, end=segment.end, speaker=speaker)
            for segment, speaker in diarization.exclusive_speaker_diarization
        ]
        # Enforce the documented ordering instead of relying on the output's
        # iteration order; the sort is stable, so ties keep pipeline order.
        turns.sort(key=lambda t: t.start)

        speaker_set = {t.speaker for t in turns}
        logger.info(
            "Diarization complete: %d turns, %d speakers",
            len(turns),
            len(speaker_set),
        )
        return turns