# Source listing: 73 lines, 2.3 KiB
import torch
from pyannote.audio import Pipeline
import whisper
import librosa
import numpy as np
import matplotlib.pyplot as plt
import librosa.display
from pyannote.core import Segment, Annotation
# Whisper speech-to-text model ("large" checkpoint), loaded once when the
# script starts and shared by every call to transcribe_audio().
whisper_model = whisper.load_model("large")


def transcribe_audio(audio_path):
    """Transcribe an audio file with the module-level Whisper model.

    Args:
        audio_path: Path to the audio file handed to ``whisper_model.transcribe``.

    Returns:
        A list of ``(start, end, text)`` tuples, one per entry of Whisper's
        ``"segments"`` output, in the order Whisper reports them. ``start`` and
        ``end`` are the segment timestamps as Whisper returns them; ``text`` is
        the raw segment text (not stripped).
    """
    whisper_output = whisper_model.transcribe(audio_path)
    return [
        (seg["start"], seg["end"], seg["text"])
        for seg in whisper_output["segments"]
    ]
# Initialize Pyannote Pipeline for diarization.
# NOTE(review): pyannote/speaker-diarization-3.1 is a gated Hugging Face model;
# from_pretrained may need a use_auth_token=... argument to download it —
# confirm for your environment.
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")


def perform_diarization(
    audio_path,
    min_speakers: int = 5,
    max_speakers: int = 7,
    rttm_path: str = "diarization.rttm",
) -> Annotation:
    """Run speaker diarization on an audio file and save the result as RTTM.

    Args:
        audio_path: Path to the audio file passed to the pyannote pipeline.
        min_speakers: Lower bound on the number of speakers (default 5).
        max_speakers: Upper bound on the number of speakers (default 7).
        rttm_path: File the diarization is written to, in RTTM format
            (default ``"diarization.rttm"``, overwritten on each call).

    Returns:
        The pipeline output: a ``pyannote.core.Annotation`` of speaker turns,
        iterable with ``itertracks(yield_label=True)``.

    Raises:
        RuntimeError: If the module-level ``pipeline`` failed to load.
    """
    if pipeline is None:
        # Pipeline.from_pretrained can return None instead of raising when the
        # download fails (e.g. missing authentication for a gated model);
        # report that clearly instead of "'NoneType' object is not callable".
        raise RuntimeError(
            "pyannote diarization pipeline failed to load; "
            "check Hugging Face access to pyannote/speaker-diarization-3.1"
        )
    diarization = pipeline(
        audio_path, min_speakers=min_speakers, max_speakers=max_speakers
    )
    # Dump the diarization output to disk using RTTM format.
    with open(rttm_path, "w", encoding="utf-8") as rttm:
        diarization.write_rttm(rttm)
    print("Finished diarization")
    return diarization
def _dominant_speaker(start, end, annotation):
    """Return the speaker label whose turns overlap ``(start, end)`` the longest.

    Overlap is summed per speaker across all of that speaker's turns yielded by
    ``annotation.itertracks(yield_label=True)``. Ties go to the speaker seen
    first. Returns ``None`` when no turn overlaps the interval.
    """
    overlap_by_speaker = {}
    for turn, _, label in annotation.itertracks(yield_label=True):
        overlap = min(end, turn.end) - max(start, turn.start)
        if overlap > 0:
            overlap_by_speaker[label] = overlap_by_speaker.get(label, 0.0) + overlap
    if not overlap_by_speaker:
        return None
    return max(overlap_by_speaker, key=overlap_by_speaker.get)


# Load audio and perform both transcription and diarization
audio_path = "mid_audio.wav"
transcription_segments = transcribe_audio(audio_path)
# perform_diarization returns the pipeline's pyannote.core Annotation
# (iterated with itertracks), not a Pipeline object.
diarization: Annotation = perform_diarization(audio_path)

# Print speaker and corresponding text: exactly one line per transcription
# segment, attributed to the speaker with the most overlapping speech.
# Printing once per overlapping turn would repeat a segment's text for every
# turn (or speaker) it spans.
print("\nSpeaker and Text Segments:")
for start, end, text in transcription_segments:
    speaker_label = _dominant_speaker(start, end, diarization)
    if speaker_label is None:
        # No diarized turn overlaps this segment; keep the text rather than
        # silently dropping it.
        speaker_label = "UNKNOWN"
    print(f"Speaker {speaker_label}: {text.strip()}")
# Load audio for plotting (sr=None keeps the file's native sample rate).
audio, sr = librosa.load(audio_path, sr=None)

# Plot the audio waveform and speaker segments
plt.figure(figsize=(12, 6))
librosa.display.waveshow(audio, sr=sr, alpha=0.5, color="gray")
plt.xlabel("Time (s)")
plt.title("Speaker Diarization with Transcription")

# Plot speaker segments as horizontal bars at y=0.9. Each speaker gets one
# fixed color from Matplotlib's default cycle ("C0", "C1", ...). Otherwise
# every plt.plot call would advance the cycle, giving the same speaker a
# different color per turn, and the legend entry would match only one of them.
speaker_colors = {}
for segment, _, label in diarization.itertracks(yield_label=True):
    color = speaker_colors.setdefault(label, f"C{len(speaker_colors)}")
    plt.plot(
        [segment.start, segment.end],
        [0.9, 0.9],
        color=color,
        label=f"Speaker {label}",
    )

# Avoid duplicate labels in legend: keep one handle per speaker label.
handles, labels = plt.gca().get_legend_handles_labels()
by_label = dict(zip(labels, handles))
plt.legend(by_label.values(), by_label.keys(), loc="upper right")