tavern-talk/tavern_talk/diarization.py

47 lines
1.3 KiB
Python
Raw Normal View History

2024-11-10 23:28:54 +01:00
# instantiate the pipeline
from pyannote.audio import Pipeline
2024-11-10 09:27:57 +01:00
import torch
2024-11-10 23:28:54 +01:00
# Input recording; passed to the diarization pipeline and later loaded again
# with librosa to draw the waveform.
audio_path = "short_transcript.wav"
2024-11-10 09:27:57 +01:00
2024-11-10 23:28:54 +01:00
# Load the pretrained diarization pipeline from the Hugging Face Hub.
#
# SECURITY: the access token must never be committed to source control. It is
# read from the HF_TOKEN environment variable instead; when the variable is
# unset, None is passed so the hub client can fall back to locally cached
# credentials (e.g. from `huggingface-cli login`).
# NOTE(review): the token that was previously hard-coded here is exposed in
# the repository history and should be revoked.
import os

pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=os.environ.get("HF_TOKEN"),
)

# NOTE(review): pyannote.audio 3.x appears to return None (rather than raise)
# when the gated model cannot be downloaded - confirm against the installed
# version. Failing here gives a clearer error than a later
# "'NoneType' object is not callable".
if pipeline is None:
    raise RuntimeError(
        "Could not load 'pyannote/speaker-diarization-3.1'. Set HF_TOKEN to a "
        "valid Hugging Face access token and accept the model's user "
        "conditions on the Hub."
    )
2024-11-10 09:27:57 +01:00
2024-11-10 23:28:54 +01:00
# Diarize the recording, telling the pipeline to expect between 6 and 7
# distinct speakers.
speaker_count_bounds = {"min_speakers": 6, "max_speakers": 7}
diarization = pipeline(audio_path, **speaker_count_bounds)
2024-11-10 09:27:57 +01:00
2024-11-10 23:28:54 +01:00
# Persist the diarization result to disk as an RTTM file.
rttm_path = "short_transcript.rttm"
with open(rttm_path, "w") as rttm_file:
    diarization.write_rttm(rttm_file)
2024-11-10 09:27:57 +01:00
2024-11-10 23:28:54 +01:00
import matplotlib.pyplot as plt
import librosa
import librosa.display
# Load the audio file; sr=None is passed so librosa returns the sampling rate
# it reports for the file, which is then handed to waveshow.
audio, sr = librosa.load(audio_path, sr=None)

# Draw the waveform as a translucent gray backdrop.
plt.figure(figsize=(10, 6))
librosa.display.waveshow(audio, sr=sr, alpha=0.5, color="gray")
plt.xlabel("Time (s)")
plt.ylabel("Amplitude")
plt.title("Speaker Diarization Results")

# Overlay every speaker turn as a horizontal bar at y=0.9.
#
# Each speaker gets ONE colour, assigned the first time the speaker is seen.
# Without an explicit colour, every plt.plot call takes the next colour of the
# default cycle, so turns of the same speaker end up in different colours and
# the legend no longer identifies who is speaking.
speaker_colors = {}   # speaker label -> matplotlib "CN" colour spec
legend_handles = {}   # legend text -> first line drawn for that speaker
for segment, _, label in diarization.itertracks(yield_label=True):
    # len() is evaluated before insertion, so the first speaker gets "C0".
    color = speaker_colors.setdefault(label, f"C{len(speaker_colors) % 10}")
    (line,) = plt.plot(
        [segment.start, segment.end],
        [0.9, 0.9],
        color=color,
        label=f"Speaker {label}",
    )
    # Keep a single legend entry per speaker.
    legend_handles.setdefault(f"Speaker {label}", line)

plt.legend(
    list(legend_handles.values()),
    list(legend_handles.keys()),
    loc="upper right",
)
plt.show()