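"""Align Whisper ASR segments with pyannote speaker turns into speaker-labeled transcript segments."""
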
import logging
from dataclasses import dataclass

from transcriber.asr.whisper_engine import Segment
from transcriber.diarization.pyannote_engine import SpeakerTurn

logger = logging.getLogger(__name__)


@dataclass
class MergedSegment:
    """Final segment with speaker label and merged text."""

    speaker: str
    start: float
    end: float
    text: str


def _compute_overlap(seg_start: float, seg_end: float, turn_start: float, turn_end: float) -> float:
    """Compute the temporal overlap between two intervals in seconds."""
    overlap_start = max(seg_start, turn_start)
    overlap_end = min(seg_end, turn_end)
    return max(0.0, overlap_end - overlap_start)


def _assign_speaker(segment: Segment, speaker_turns: list[SpeakerTurn]) -> str:
    """Assign a speaker to an ASR segment by maximum temporal overlap."""
    best_speaker = "Unknown"
    best_overlap = 0.0

    for turn in speaker_turns:
        # Turns are sorted by start time, so turns that end before the segment
        # can be skipped and the scan can stop once a turn starts after it ends.
        if turn.end < segment.start:
            continue
        if turn.start > segment.end:
            break
        overlap = _compute_overlap(segment.start, segment.end, turn.start, turn.end)
        if overlap > best_overlap:
            best_overlap = overlap
            best_speaker = turn.speaker

    return best_speaker


def _normalize_speaker_labels(segments: list[MergedSegment]) -> list[MergedSegment]:
    """Replace raw pyannote labels (e.g. SPEAKER_00) with sequential Speaker 1, Speaker 2, ..."""
    label_map: dict[str, str] = {}
    counter = 1

    # Assign labels in order of first appearance; "Unknown" keeps its name.
    for seg in segments:
        if seg.speaker not in label_map and seg.speaker != "Unknown":
            label_map[seg.speaker] = f"Speaker {counter}"
            counter += 1

    for seg in segments:
        seg.speaker = label_map.get(seg.speaker, seg.speaker)

    return segments


def align_and_merge(
    asr_segments: list[Segment],
    speaker_turns: list[SpeakerTurn],
    pause_threshold: float = 1.5,
) -> list[MergedSegment]:
    """Align ASR segments with speaker turns and merge adjacent same-speaker segments.

    Args:
        asr_segments: Segments from Whisper ASR.
        speaker_turns: Speaker turns from diarization (sorted by start).
        pause_threshold: Maximum pause between segments to merge (seconds).

    Returns:
        List of merged segments with speaker labels.
    """
    if not asr_segments:
        return []

    aligned: list[MergedSegment] = []
    for seg in asr_segments:
        speaker = _assign_speaker(seg, speaker_turns)
        aligned.append(MergedSegment(
            speaker=speaker,
            start=seg.start,
            end=seg.end,
            text=seg.text,
        ))

    # Merge consecutive segments from the same speaker when the pause between
    # them does not exceed pause_threshold.
    merged: list[MergedSegment] = [aligned[0]]
    for seg in aligned[1:]:
        prev = merged[-1]
        gap = seg.start - prev.end
        if seg.speaker == prev.speaker and gap <= pause_threshold:
            prev.end = seg.end
            prev.text = f"{prev.text} {seg.text}"
        else:
            merged.append(seg)

    merged = _normalize_speaker_labels(merged)
    logger.info("Alignment complete: %d ASR segments -> %d merged", len(asr_segments), len(merged))
    return merged
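# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only). align_and_merge only reads
# .start/.end/.text from ASR segments and .start/.end/.speaker from speaker
# turns, so the hypothetical _DemoSegment/_DemoTurn stand-ins below suffice to
# exercise the merge logic without loading Whisper or pyannote; the real
# Segment and SpeakerTurn types may be constructed differently.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    @dataclass
    class _DemoSegment:  # stand-in for transcriber.asr.whisper_engine.Segment
        start: float
        end: float
        text: str

    @dataclass
    class _DemoTurn:  # stand-in for transcriber.diarization.pyannote_engine.SpeakerTurn
        start: float
        end: float
        speaker: str

    demo_segments = [
        _DemoSegment(0.0, 2.0, "Hello there."),
        _DemoSegment(2.3, 4.0, "How are you?"),
        _DemoSegment(5.0, 7.5, "Fine, thanks."),
    ]
    demo_turns = [
        _DemoTurn(0.0, 4.2, "SPEAKER_00"),
        _DemoTurn(4.8, 8.0, "SPEAKER_01"),
    ]

    # The first two segments share a speaker and sit 0.3 s apart (under the
    # 1.5 s default threshold), so they merge; the third becomes Speaker 2.
    for merged_seg in align_and_merge(demo_segments, demo_turns):
        print(f"[{merged_seg.start:.1f}-{merged_seg.end:.1f}] {merged_seg.speaker}: {merged_seg.text}")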