Some more advanced diarization

Maximilian Giller 2024-11-11 00:43:54 +01:00
parent b71c78c5f2
commit d62195a590


@@ -1,40 +1,67 @@
import torch
from pyannote.audio import Pipeline
import whisper
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt
from pyannote.core import Segment, Annotation

# Load Whisper model for transcription
whisper_model = whisper.load_model("large")
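# Note (assumption, not stated in this commit): the "large" checkpoint needs
# roughly 10 GB of VRAM; "base" or "small" trades accuracy for speed and memory.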
# Transcribe audio using Whisper
def transcribe_audio(audio_path):
    result = whisper_model.transcribe(audio_path)
    segments = result["segments"]
    return [(segment["start"], segment["end"], segment["text"]) for segment in segments]
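# Note (assumption, not part of this commit): transcribe() also accepts
# word_timestamps=True for word-level timing, which would allow finer-grained
# speaker attribution than the segment-level tuples returned above.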
# Initialize Pyannote Pipeline for diarization
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
pipeline.to(torch.device("cuda"))
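# Assumptions: access to the gated pyannote model is configured (e.g. via
# huggingface-cli login or use_auth_token=...), and a CUDA GPU is available;
# falling back to torch.device("cpu") keeps this runnable without one.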
# Perform diarization
def perform_diarization(audio_path) -> Annotation:
    diarization = pipeline(audio_path, min_speakers=5, max_speakers=7)
    # Dump the diarization output to disk using RTTM format
    with open("diarization.rttm", "w") as rttm:
        diarization.write_rttm(rttm)
    print("Finished diarization")
    return diarization
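# For reference: each RTTM speaker line is space-separated, of the form
# SPEAKER <file-id> 1 <onset> <duration> <NA> <NA> <speaker> <NA> <NA>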
# Load audio and perform both transcription and diarization
audio_path = "mid_audio.wav"
transcription_segments = transcribe_audio(audio_path)
diarization: Annotation = perform_diarization(audio_path)
# Print speaker and corresponding text
print("\nSpeaker and Text Segments:")
for start, end, text in transcription_segments:
    # Attribute each transcribed segment to the first overlapping speaker turn
    for spk_segment, _, speaker_label in diarization.itertracks(yield_label=True):
        if spk_segment.start < end and spk_segment.end > start:
            print(f"Speaker {speaker_label}: {text}")
            break
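# A possible refinement (sketch, not in this commit): instead of taking the
# first overlapping turn, accumulate overlap per speaker and attribute the
# segment to whoever covers most of it.
def dominant_speaker(diarization, start, end):
    overlaps = {}
    for spk_segment, _, label in diarization.itertracks(yield_label=True):
        overlap = min(spk_segment.end, end) - max(spk_segment.start, start)
        if overlap > 0:
            overlaps[label] = overlaps.get(label, 0.0) + overlap
    # Speaker with the largest accumulated overlap, or None if no turn overlaps
    return max(overlaps, key=overlaps.get) if overlaps else None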
# Load audio for plotting
audio, sr = librosa.load(audio_path, sr=None)
# Plot the audio waveform and speaker segments
plt.figure(figsize=(12, 6))
librosa.display.waveshow(audio, sr=sr, alpha=0.5, color="gray")
plt.xlabel("Time (s)")
plt.ylabel("Amplitude")
plt.title("Speaker Diarization Results")
plt.title("Speaker Diarization with Transcription")
# Plot speaker segments and add transcription text
for segment, _, label in diarization.itertracks(yield_label=True):
    # Get start and end times of each speaker segment
    start, end = segment.start, segment.end
    plt.plot([start, end], [0.9, 0.9], label=f"Speaker {label}")
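# Assumed finishing touches (the listing ends here in this excerpt): collapse
# duplicate legend entries, since each speaker is plotted once per turn, then
# render the figure.
handles, labels = plt.gca().get_legend_handles_labels()
by_label = dict(zip(labels, handles))
plt.legend(by_label.values(), by_label.keys(), loc="upper right")
plt.tight_layout()
plt.show()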