# transcribe-interview/transcriber/merge/aligner.py
#
# 102 lines
# 3.0 KiB
# Python

import logging
from dataclasses import dataclass
from transcriber.asr.whisper_engine import Segment
from transcriber.diarization.pyannote_engine import SpeakerTurn
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
@dataclass
class MergedSegment:
    """A speaker-attributed stretch of transcript.

    Attributes:
        speaker: Label of the speaker this text is attributed to.
        start: Time at which the segment begins.
        end: Time at which the segment finishes.
        text: Transcribed text covered by the segment.
    """

    speaker: str
    start: float
    end: float
    text: str
def _compute_overlap(seg_start: float, seg_end: float, turn_start: float, turn_end: float) -> float:
"""Compute temporal overlap between two intervals in seconds."""
overlap_start = max(seg_start, turn_start)
overlap_end = min(seg_end, turn_end)
return max(0.0, overlap_end - overlap_start)
def _assign_speaker(segment: Segment, speaker_turns: list[SpeakerTurn]) -> str:
    """Pick the speaker whose turn shares the most time with ``segment``.

    Turns are scanned in the order given. The scan stops at the first turn
    that starts after the segment ends, so the result is only meaningful
    when ``speaker_turns`` is ordered by start time.

    Returns:
        The ``speaker`` of the turn with the largest strictly positive
        overlap (the earliest such turn wins ties), or ``"Unknown"`` when
        no turn overlaps the segment.
    """
    seg_start, seg_end = segment.start, segment.end
    winner, longest = "Unknown", 0.0

    for candidate in speaker_turns:
        # Turn finished before the segment began: cannot overlap, keep scanning.
        if candidate.end < seg_start:
            continue
        # Turn begins after the segment is over: stop scanning here.
        if candidate.start > seg_end:
            break

        shared = _compute_overlap(seg_start, seg_end, candidate.start, candidate.end)
        if shared > longest:
            winner, longest = candidate.speaker, shared

    return winner
def _normalize_speaker_labels(segments: list[MergedSegment]) -> list[MergedSegment]:
"""Replace raw pyannote labels (SPEAKER_00) with sequential Speaker 1, Speaker 2."""
label_map: dict[str, str] = {}
counter = 1
for seg in segments:
if seg.speaker not in label_map and seg.speaker != "Unknown":
label_map[seg.speaker] = f"Speaker {counter}"
counter += 1
for seg in segments:
seg.speaker = label_map.get(seg.speaker, seg.speaker)
return segments
def align_and_merge(
    asr_segments: list[Segment],
    speaker_turns: list[SpeakerTurn],
    pause_threshold: float = 1.5,
) -> list[MergedSegment]:
    """Align ASR segments with speaker turns and merge adjacent same-speaker segments.

    Every ASR segment is first given a speaker via ``_assign_speaker``.
    Consecutive segments with the same speaker whose gap (next start minus
    previous end) is at most ``pause_threshold`` are then folded into one
    segment, their texts joined with a single space. Finally the speaker
    labels are rewritten by ``_normalize_speaker_labels``.

    Args:
        asr_segments: Segments from Whisper ASR.
        speaker_turns: Speaker turns from diarization (sorted by start).
        pause_threshold: Max pause between segments to merge (seconds).

    Returns:
        List of merged segments with speaker labels. Empty when
        ``asr_segments`` is empty. The input segments are not modified;
        new ``MergedSegment`` objects are built from them.
    """
    if not asr_segments:
        return []

    aligned = [
        MergedSegment(
            speaker=_assign_speaker(seg, speaker_turns),
            start=seg.start,
            end=seg.end,
            text=seg.text,
        )
        for seg in asr_segments
    ]

    merged: list[MergedSegment] = [aligned[0]]
    for seg in aligned[1:]:
        prev = merged[-1]
        gap = seg.start - prev.end
        if seg.speaker == prev.speaker and gap <= pause_threshold:
            # A segment nested inside (or overlapping) the previous one may
            # end earlier than it; never let the merged end move backwards,
            # otherwise the span shrinks and the next gap is measured from
            # the wrong point.
            prev.end = max(prev.end, seg.end)
            prev.text = f"{prev.text} {seg.text}"
        else:
            merged.append(seg)

    merged = _normalize_speaker_labels(merged)
    logger.info("Alignment complete: %d ASR segments -> %d merged", len(asr_segments), len(merged))
    return merged