diff --git a/tavern_talk/diarization.py b/tavern_talk/diarization.py index 6919db9..6456168 100644 --- a/tavern_talk/diarization.py +++ b/tavern_talk/diarization.py @@ -1,40 +1,67 @@ -# instantiate the pipeline -from pyannote.audio import Pipeline import torch - -audio_path = "short_transcript.wav" - -pipeline = Pipeline.from_pretrained( - "pyannote/speaker-diarization-3.1", - use_auth_token="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", -) - - -# run the pipeline on an audio file -diarization = pipeline(audio_path, min_speakers=6, max_speakers=7) - -# dump the diarization output to disk using RTTM format -with open("short_transcript.rttm", "w") as rttm: - diarization.write_rttm(rttm) - - -import matplotlib.pyplot as plt +from pyannote.audio import Pipeline +import whisper import librosa +import numpy as np +import matplotlib.pyplot as plt import librosa.display +from pyannote.core import Segment, Annotation -# Load the audio file and compute its waveform +# Load Whisper model for transcription +whisper_model = whisper.load_model("large") + + +# Transcribe audio using Whisper +def transcribe_audio(audio_path): + result = whisper_model.transcribe(audio_path) + segments = result["segments"] + return [(segment["start"], segment["end"], segment["text"]) for segment in segments] + + +# Initialize Pyannote Pipeline for diarization +pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1") +pipeline.to(torch.device("cuda")) + + +# Perform diarization +def perform_diarization(audio_path) -> Pipeline: + diarization = pipeline(audio_path, min_speakers=5, max_speakers=7) + + # dump the diarization output to disk using RTTM format + with open("diarization.rttm", "w") as rttm: + diarization.write_rttm(rttm) + + print("Finished diarization") + + return diarization + + +# Load audio and perform both transcription and diarization +audio_path = "mid_audio.wav" +transcription_segments = transcribe_audio(audio_path) +diarization: Pipeline = perform_diarization(audio_path) + +# Print speaker and corresponding text +print("\nSpeaker and Text Segments:") +for segment in transcription_segments: + start, end, text = segment + for spk_segment, _, speaker_label in diarization.itertracks(yield_label=True): + if spk_segment.start < end and spk_segment.end > start: + print(f"Speaker {speaker_label}: {text}") + break + +# Load audio for plotting audio, sr = librosa.load(audio_path, sr=None) -# Plot the audio waveform -plt.figure(figsize=(10, 6)) +# Plot the audio waveform and speaker segments +plt.figure(figsize=(12, 6)) librosa.display.waveshow(audio, sr=sr, alpha=0.5, color="gray") plt.xlabel("Time (s)") plt.ylabel("Amplitude") -plt.title("Speaker Diarization Results") +plt.title("Speaker Diarization with Transcription") -# Plot speaker segments +# Plot speaker segments and add transcription text for segment, _, label in diarization.itertracks(yield_label=True): - # Get start and end times of each speaker segment start, end = segment.start, segment.end plt.plot([start, end], [0.9, 0.9], label=f"Speaker {label}")