Some more advanced diarization

This commit is contained in:
Maximilian Giller 2024-11-11 00:43:54 +01:00
parent b71c78c5f2
commit d62195a590

@@ -1,40 +1,67 @@
import torch
from pyannote.audio import Pipeline
import whisper
import librosa
import numpy as np
import matplotlib.pyplot as plt
import librosa.display
from pyannote.core import Segment, Annotation

# Load Whisper model for transcription
whisper_model = whisper.load_model("large")


# Transcribe audio using Whisper
def transcribe_audio(audio_path):
    result = whisper_model.transcribe(audio_path)
    segments = result["segments"]
    return [(segment["start"], segment["end"], segment["text"]) for segment in segments]
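# Note: each entry of result["segments"] is a dict whose "start"/"end" values are in
# seconds and whose "text" holds the transcribed words for that chunk, which is why the
# function returns (start, end, text) tuples.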
# Initialize Pyannote Pipeline for diarization
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
pipeline.to(torch.device("cuda"))
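# Note: pyannote/speaker-diarization-3.1 is a gated model, so from_pretrained generally needs
# the model conditions accepted on Hugging Face and an access token (use_auth_token=...) unless
# the checkpoint is already cached; the previous revision hard-coded such a token, which is
# better kept out of the source (e.g. read from an environment variable).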
# Perform diarization
def perform_diarization(audio_path) -> Annotation:
    diarization = pipeline(audio_path, min_speakers=5, max_speakers=7)

    # dump the diarization output to disk using RTTM format
    with open("diarization.rttm", "w") as rttm:
        diarization.write_rttm(rttm)

    print("Finished diarization")
    return diarization
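# For reference, write_rttm emits one standard RTTM line per speaker turn, roughly
#   SPEAKER <file-id> 1 <onset> <duration> <NA> <NA> <speaker-label> <NA> <NA>
# with times in seconds, so the saved diarization can be inspected without re-running the pipeline.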
# Load audio and perform both transcription and diarization
audio_path = "mid_audio.wav"
transcription_segments = transcribe_audio(audio_path)
diarization: Annotation = perform_diarization(audio_path)
# Print speaker and corresponding text
print("\nSpeaker and Text Segments:")
for segment in transcription_segments:
    start, end, text = segment
    # Attribute the text to the first speaker whose turn overlaps this transcription segment
    for spk_segment, _, speaker_label in diarization.itertracks(yield_label=True):
        if spk_segment.start < end and spk_segment.end > start:
            print(f"Speaker {speaker_label}: {text}")
            break
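# A minimal sketch (not part of this commit): instead of attributing a transcription segment
# to the first overlapping turn, pick the speaker whose turns overlap it the longest. It only
# relies on the itertracks API already used above.
def dominant_speaker(diarization, start, end):
    overlap_by_speaker = {}
    for spk_segment, _, speaker_label in diarization.itertracks(yield_label=True):
        overlap = min(spk_segment.end, end) - max(spk_segment.start, start)
        if overlap > 0:
            overlap_by_speaker[speaker_label] = overlap_by_speaker.get(speaker_label, 0.0) + overlap
    return max(overlap_by_speaker, key=overlap_by_speaker.get) if overlap_by_speaker else None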
# Load audio for plotting
audio, sr = librosa.load(audio_path, sr=None)

# Plot the audio waveform and speaker segments
plt.figure(figsize=(12, 6))
librosa.display.waveshow(audio, sr=sr, alpha=0.5, color="gray")
plt.xlabel("Time (s)")
plt.ylabel("Amplitude")
plt.title("Speaker Diarization with Transcription")

# Plot speaker segments and add transcription text
for segment, _, label in diarization.itertracks(yield_label=True):
    # Get start and end times of each speaker segment
    start, end = segment.start, segment.end
    plt.plot([start, end], [0.9, 0.9], label=f"Speaker {label}")
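# Sketch, assuming the legend is drawn right after this loop: one labelled line is plotted
# per speaker turn, so deduplicating handles by label avoids repeated legend entries.
handles, labels = plt.gca().get_legend_handles_labels()
unique = dict(zip(labels, handles))
plt.legend(unique.values(), unique.keys(), loc="upper right")
plt.show()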