"""Transcribe an audio file with Whisper and attribute each segment to a speaker.

The script runs Whisper transcription and pyannote speaker diarization on the
same audio file, prints one "Speaker <label>: <text>" line per transcribed
segment, and plots the waveform with the speaker turns drawn on top of it.
"""
import torch
from pyannote.audio import Pipeline
import whisper
import librosa
import numpy as np
import matplotlib.pyplot as plt
import librosa.display
from pyannote.core import Segment, Annotation

# Load Whisper model for transcription
whisper_model = whisper.load_model("large")


def transcribe_audio(audio_path):
    """Transcribe *audio_path* with the module-level Whisper model.

    Args:
        audio_path: Path of the audio file handed to ``whisper_model.transcribe``.

    Returns:
        A list of ``(start, end, text)`` tuples, one per entry of the
        ``"segments"`` list in the Whisper result.
    """
    result = whisper_model.transcribe(audio_path)
    return [
        (segment["start"], segment["end"], segment["text"])
        for segment in result["segments"]
    ]


# Initialize Pyannote Pipeline for diarization
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
# NOTE(review): pyannote appears to return None (after printing a hint) when
# the gated checkpoint cannot be downloaded -- confirm against the installed
# version. Failing here gives a clearer error than an AttributeError on `.to`.
if pipeline is None:
    raise RuntimeError(
        "Could not load 'pyannote/speaker-diarization-3.1'; check that you "
        "accepted the model's user conditions and are authenticated with "
        "Hugging Face."
    )
# Use the GPU when one is present, otherwise fall back to the CPU instead of
# crashing on machines without CUDA.
pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))


def perform_diarization(
    audio_path,
    min_speakers=5,
    max_speakers=7,
    rttm_path="diarization.rttm",
) -> Annotation:
    """Run speaker diarization on *audio_path* and dump the result as RTTM.

    Args:
        audio_path: Path of the audio file handed to the pyannote pipeline.
        min_speakers: Lower bound on the speaker count passed to the pipeline.
        max_speakers: Upper bound on the speaker count passed to the pipeline.
        rttm_path: File the diarization is written to in RTTM format.

    Returns:
        The diarization produced by the pipeline.
    """
    diarization = pipeline(
        audio_path, min_speakers=min_speakers, max_speakers=max_speakers
    )

    # dump the diarization output to disk using RTTM format
    with open(rttm_path, "w") as rttm:
        diarization.write_rttm(rttm)

    print("Finished diarization")
    return diarization


def _dominant_speaker(start, end, turns):
    """Return the label that overlaps ``[start, end]`` the longest, or None.

    *turns* is an iterable of ``(segment, track, label)`` triples as yielded by
    ``Annotation.itertracks(yield_label=True)``. Overlap is summed per label so
    that a speaker split over several short turns is not beaten by a single
    turn that merely grazes the interval. Ties go to the label seen first.
    """
    overlap_by_label = {}
    for turn, _, label in turns:
        if turn.start < end and turn.end > start:
            overlap = min(end, turn.end) - max(start, turn.start)
            # max(..., 0.0) keeps zero-length transcript segments attributable.
            overlap_by_label[label] = (
                overlap_by_label.get(label, 0.0) + max(overlap, 0.0)
            )
    if not overlap_by_label:
        return None
    return max(overlap_by_label, key=overlap_by_label.get)


# Load audio and perform both transcription and diarization
audio_path = "mid_audio.wav"
transcription_segments = transcribe_audio(audio_path)
diarization: Annotation = perform_diarization(audio_path)

# Materialize the speaker turns once; they are reused for every transcript
# segment and again for the plot.
speaker_turns = list(diarization.itertracks(yield_label=True))

# Print speaker and corresponding text
print("\nSpeaker and Text Segments:")
for start, end, text in transcription_segments:
    speaker_label = _dominant_speaker(start, end, speaker_turns)
    if speaker_label is None:
        # No diarized turn overlaps this segment; keep the text visible
        # rather than silently dropping it.
        speaker_label = "UNKNOWN"
    print(f"Speaker {speaker_label}: {text}")

# Load audio for plotting
audio, sr = librosa.load(audio_path, sr=None)

# Plot the audio waveform and speaker segments
plt.figure(figsize=(12, 6))
librosa.display.waveshow(audio, sr=sr, alpha=0.5, color="gray")
plt.xlabel("Time (s)")
plt.ylabel("Amplitude")
plt.title("Speaker Diarization with Transcription")

# Plot speaker segments, one fixed colour per speaker so the legend entry
# matches every turn of that speaker.
speaker_colors = {}
for segment, _, label in speaker_turns:
    color = speaker_colors.setdefault(label, f"C{len(speaker_colors) % 10}")
    plt.plot(
        [segment.start, segment.end],
        [0.9, 0.9],
        color=color,
        label=f"Speaker {label}",
    )

# Avoid duplicate labels in legend
handles, labels = plt.gca().get_legend_handles_labels()
by_label = dict(zip(labels, handles))
plt.legend(by_label.values(), by_label.keys(), loc="upper right")

plt.show()