import logging
from dataclasses import dataclass

from transcriber.asr.whisper_engine import Segment
from transcriber.diarization.pyannote_engine import SpeakerTurn

logger = logging.getLogger(__name__)


@dataclass
class MergedSegment:
    """Final segment with speaker label and merged text.

    Attributes:
        speaker: Speaker label, or "Unknown" when no speaker turn overlaps.
        start: Start time of the segment in seconds.
        end: End time of the segment in seconds.
        text: Segment text; space-joined when several segments are merged.
    """

    speaker: str
    start: float
    end: float
    text: str


def _compute_overlap(
    seg_start: float, seg_end: float, turn_start: float, turn_end: float
) -> float:
    """Compute temporal overlap between two intervals in seconds.

    Returns:
        Length of the intersection of ``[seg_start, seg_end]`` and
        ``[turn_start, turn_end]``; 0.0 when the intervals are disjoint.
    """
    overlap_start = max(seg_start, turn_start)
    overlap_end = min(seg_end, turn_end)
    # A negative difference means the intervals do not intersect.
    return max(0.0, overlap_end - overlap_start)


def _assign_speaker(segment: Segment, speaker_turns: list[SpeakerTurn]) -> str:
    """Assign speaker to ASR segment by maximum overlap.

    Args:
        segment: ASR segment to label.
        speaker_turns: Speaker turns sorted by start time. The scan stops at
            the first turn that starts after the segment ends, so unsorted
            input would cause overlapping turns to be missed.

    Returns:
        Speaker of the turn with the largest positive overlap, or "Unknown"
        when no turn overlaps the segment. On equal overlap the earliest
        turn in ``speaker_turns`` wins.
    """
    best_speaker = "Unknown"
    best_overlap = 0.0
    for turn in speaker_turns:
        # Turn lies entirely before the segment: keep scanning.
        if turn.end < segment.start:
            continue
        # Turns are ordered by start, so no later turn can overlap either.
        if turn.start > segment.end:
            break
        overlap = _compute_overlap(segment.start, segment.end, turn.start, turn.end)
        if overlap > best_overlap:
            best_overlap = overlap
            best_speaker = turn.speaker
    return best_speaker


def _normalize_speaker_labels(segments: list[MergedSegment]) -> list[MergedSegment]:
    """Replace raw pyannote labels (SPEAKER_00) with sequential Speaker 1, Speaker 2.

    Numbers are handed out in order of first appearance in ``segments``.
    The "Unknown" label is left untouched. The segments are modified in
    place and the same list is returned.
    """
    label_map: dict[str, str] = {}
    counter = 1
    for seg in segments:
        if seg.speaker not in label_map and seg.speaker != "Unknown":
            label_map[seg.speaker] = f"Speaker {counter}"
            counter += 1
    for seg in segments:
        # Labels without a mapping (i.e. "Unknown") are kept as they are.
        seg.speaker = label_map.get(seg.speaker, seg.speaker)
    return segments


def align_and_merge(
    asr_segments: list[Segment],
    speaker_turns: list[SpeakerTurn],
    pause_threshold: float = 1.5,
) -> list[MergedSegment]:
    """Align ASR segments with speaker turns and merge adjacent same-speaker segments.

    Args:
        asr_segments: Segments from Whisper ASR.
        speaker_turns: Speaker turns from diarization. They are sorted by
            start time internally, so callers need not pre-sort them.
        pause_threshold: Max pause between segments to merge (seconds).

    Returns:
        List of merged segments with speaker labels.
    """
    if not asr_segments:
        return []

    # _assign_speaker stops scanning at the first turn that starts after the
    # segment, which is only correct for turns ordered by start. The sort is
    # stable, so already-sorted input keeps its exact order.
    ordered_turns = sorted(speaker_turns, key=lambda turn: turn.start)

    aligned = [
        MergedSegment(
            speaker=_assign_speaker(seg, ordered_turns),
            start=seg.start,
            end=seg.end,
            text=seg.text,
        )
        for seg in asr_segments
    ]

    merged: list[MergedSegment] = [aligned[0]]
    for seg in aligned[1:]:
        prev = merged[-1]
        gap = seg.start - prev.end
        if seg.speaker == prev.speaker and gap <= pause_threshold:
            # Never move the end backwards: an overlapping or nested ASR
            # segment may finish before the segment accumulated so far.
            prev.end = max(prev.end, seg.end)
            prev.text = f"{prev.text} {seg.text}"
        else:
            merged.append(seg)

    merged = _normalize_speaker_labels(merged)
    logger.info("Alignment complete: %d ASR segments -> %d merged", len(asr_segments), len(merged))
    return merged