70 lines
2.0 KiB
Python
70 lines
2.0 KiB
Python
import logging
|
|
from dataclasses import dataclass
|
|
|
|
import torch
|
|
from pyannote.audio import Pipeline
|
|
from pyannote.audio.pipelines.utils.hook import ProgressHook
|
|
|
|
# Module-scoped logger, named after this module's import path so log records
# can be filtered/configured per module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
@dataclass
class SpeakerTurn:
    """One contiguous stretch of audio attributed to a single speaker.

    Attributes:
        start: Time at which the turn begins, as reported by the diarization
            pipeline (presumably seconds -- confirm against pyannote).
        end: Time at which the turn ends, in the same units as ``start``.
        speaker: Label the pipeline assigned to the speaker of this turn.
    """

    start: float  # turn onset
    end: float  # turn offset
    speaker: str  # pipeline-assigned speaker label
|
|
|
|
|
|
class DiarizationEngine:
    """Speaker diarization engine based on pyannote.audio.

    The ``pyannote/speaker-diarization-community-1`` pipeline is loaded once
    in the constructor and reused for every :meth:`diarize` call.
    """

    def __init__(self, hf_token: str, device: str):
        """Load the diarization pipeline and move it to *device*.

        Args:
            hf_token: Hugging Face access token used to fetch the pipeline.
            device: Torch device specifier, e.g. ``"cpu"`` or ``"cuda"``.

        Raises:
            RuntimeError: If ``Pipeline.from_pretrained`` returns ``None``
                instead of a pipeline.
        """
        logger.info("Loading diarization pipeline on %s", device)
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-community-1",
            token=hf_token,
        )
        # Some pyannote.audio versions print a message and return None
        # (instead of raising) when the download fails, e.g. for a gated
        # model or an invalid token. Fail here with a clear error rather than
        # an opaque "'NoneType' object has no attribute 'to'" below.
        if pipeline is None:
            raise RuntimeError(
                "Failed to load 'pyannote/speaker-diarization-community-1'; "
                "check that the Hugging Face token is valid and that the "
                "model's access conditions have been accepted."
            )
        self._pipeline = pipeline
        self._device = torch.device(device)
        self._pipeline.to(self._device)

    def diarize(
        self,
        audio_path: str,
        min_speakers: int | None = None,
        max_speakers: int | None = None,
        num_speakers: int | None = None,
    ) -> list[SpeakerTurn]:
        """Run speaker diarization on audio file.

        Args:
            audio_path: Path to WAV file.
            min_speakers: Minimum expected number of speakers.
            max_speakers: Maximum expected number of speakers.
            num_speakers: Exact number of speakers, when known. Like the
                other constraints, it is forwarded to the pipeline only when
                not ``None``.

        Returns:
            List of speaker turns sorted by start time.
        """
        logger.info("Diarizing: %s", audio_path)

        # Forward only the constraints the caller actually set, so the
        # pipeline applies its own defaults for the rest.
        constraints = {
            "num_speakers": num_speakers,
            "min_speakers": min_speakers,
            "max_speakers": max_speakers,
        }
        kwargs: dict[str, int] = {
            name: value
            for name, value in constraints.items()
            if value is not None
        }

        with ProgressHook() as hook:
            diarization = self._pipeline(audio_path, hook=hook, **kwargs)

        # NOTE(review): ``exclusive_speaker_diarization`` is the pyannote.audio
        # 4.x output variant meant to keep at most one speaker active at a
        # time -- confirm against the installed pyannote version.
        turns = [
            SpeakerTurn(start=segment.start, end=segment.end, speaker=speaker)
            for segment, speaker in diarization.exclusive_speaker_diarization
        ]
        # The docstring promises start-time order, so enforce it here instead
        # of relying on the pipeline's iteration order. list.sort is stable,
        # so turns sharing a start time keep the pipeline's relative order.
        turns.sort(key=lambda t: t.start)

        speaker_count = len({t.speaker for t in turns})
        logger.info(
            "Diarization complete: %d turns, %d speakers",
            len(turns),
            speaker_count,
        )
        return turns
|