tavern-talk/tavern_talk/diarization.py

import torch
from pyannote.audio import Pipeline
import whisper
import librosa
import numpy as np
import matplotlib.pyplot as plt
import librosa.display
from pyannote.core import Segment, Annotation
# Load Whisper model for transcription
whisper_model = whisper.load_model("large")
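
# Note: "large" is heavy; smaller checkpoints such as "medium" or "small" trade accuracy
# for speed, and whisper.load_model also accepts a device argument, e.g.:
# whisper_model = whisper.load_model("large", device="cuda" if torch.cuda.is_available() else "cpu")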

# Transcribe audio using Whisper
def transcribe_audio(audio_path):
    result = whisper_model.transcribe(audio_path)
    segments = result["segments"]
    return [(segment["start"], segment["end"], segment["text"]) for segment in segments]
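
# Optional sketch: Whisper can also return word-level timestamps, which would allow
# finer-grained speaker attribution than segment-level overlap matching:
# result = whisper_model.transcribe(audio_path, word_timestamps=True)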

# Initialize Pyannote pipeline for speaker diarization
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
# Use the GPU when available, fall back to CPU otherwise
pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
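
# Note: pyannote/speaker-diarization-3.1 is gated on Hugging Face; if loading fails,
# pass an access token (placeholder value shown):
# pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token="YOUR_HF_TOKEN")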

# Perform diarization
def perform_diarization(audio_path) -> Annotation:
    diarization = pipeline(audio_path, min_speakers=5, max_speakers=7)
    # Dump the diarization output to disk using RTTM format
    with open("diarization.rttm", "w") as rttm:
        diarization.write_rttm(rttm)
    print("Finished diarization")
    return diarization
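
# Optional sketch: the saved RTTM can be reloaded later without re-running the pipeline.
# This assumes pyannote.database (a pyannote.audio dependency) is available; load_rttm
# returns a dict mapping each file URI to an Annotation.
# from pyannote.database.util import load_rttm
# diarization = next(iter(load_rttm("diarization.rttm").values()))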

# Load audio and perform both transcription and diarization
audio_path = "mid_audio.wav"
transcription_segments = transcribe_audio(audio_path)
diarization: Annotation = perform_diarization(audio_path)

# Print speaker and corresponding text for each transcription segment
print("\nSpeaker and Text Segments:")
for start, end, text in transcription_segments:
    for spk_segment, _, speaker_label in diarization.itertracks(yield_label=True):
        if spk_segment.start < end and spk_segment.end > start:
            print(f"Speaker {speaker_label}: {text}")
            break
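
# A possible refinement (sketch, not from the original script): instead of printing the
# first overlapping speaker turn, pick the speaker with the most speech inside each
# transcription segment via Annotation.crop() and Annotation.argmax(). The helper name
# dominant_speaker is hypothetical.
def dominant_speaker(diarization: Annotation, start: float, end: float):
    # Keep only the diarization inside [start, end] and return the label with the
    # largest total duration there (None if no speaker overlaps the segment)
    cropped = diarization.crop(Segment(start, end), mode="intersection")
    return cropped.argmax() if cropped.labels() else None

# Example usage (uncomment to use instead of the loop above):
# for start, end, text in transcription_segments:
#     print(f"Speaker {dominant_speaker(diarization, start, end)}: {text}")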
# Load audio for plotting
audio, sr = librosa.load(audio_path, sr=None)
# Plot the audio waveform and speaker segments
plt.figure(figsize=(12, 6))
librosa.display.waveshow(audio, sr=sr, alpha=0.5, color="gray")
plt.xlabel("Time (s)")
plt.ylabel("Amplitude")
plt.title("Speaker Diarization with Transcription")

# Plot speaker segments on top of the waveform
for segment, _, label in diarization.itertracks(yield_label=True):
    start, end = segment.start, segment.end
    plt.plot([start, end], [0.9, 0.9], label=f"Speaker {label}")

# Avoid duplicate speaker labels in the legend
handles, labels = plt.gca().get_legend_handles_labels()
by_label = dict(zip(labels, handles))
plt.legend(by_label.values(), by_label.keys(), loc="upper right")
plt.show()
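
# Optional: to keep a copy of the figure, call plt.savefig("diarization_plot.png")
# (filename is arbitrary) before plt.show()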