From b71c78c5f2c27cb30416f561be1f7354b637c2a3 Mon Sep 17 00:00:00 2001 From: Maximilian Giller Date: Sun, 10 Nov 2024 23:28:54 +0100 Subject: [PATCH] Some diarization --- tavern_talk/diarization.py | 104 ++++++++++++------------------------- 1 file changed, 34 insertions(+), 70 deletions(-) diff --git a/tavern_talk/diarization.py b/tavern_talk/diarization.py index 122e889..6919db9 100644 --- a/tavern_talk/diarization.py +++ b/tavern_talk/diarization.py @@ -1,82 +1,46 @@ -audio_file = "./tavern_talk/short_transcript.wav" - - -import torchaudio +# instantiate the pipeline +from pyannote.audio import Pipeline import torch -from speechbrain.inference.classifiers import EncoderClassifier -from scipy.cluster.vq import kmeans2 -import numpy as np -import matplotlib.pyplot as plt -# Load the speaker encoder model -classifier = EncoderClassifier.from_hparams( - source="speechbrain/spkrec-xvect-voxceleb", savedir="tmp_spkrec" +audio_path = "short_transcript.wav" + +pipeline = Pipeline.from_pretrained( + "pyannote/speaker-diarization-3.1", + use_auth_token="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", ) -# Load the ASR model from torchaudio -asr_model = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.get_model() -# Define the audio file path -signal, fs = torchaudio.load(audio_file) +# run the pipeline on an audio file +diarization = pipeline(audio_path, min_speakers=6, max_speakers=7) -# Segment the audio into 1-second chunks with a 50% overlap for speaker embeddings -window_size = int(fs * 1.0) -overlap = int(fs * 0.5) -segments = [] -embeddings = [] - -for start in range(0, signal.shape[1] - window_size, overlap): - segment = signal[:, start : start + window_size] - segments.append((start / fs, (start + window_size) / fs)) - embedding = classifier.encode_batch(segment) - embeddings.append(embedding.squeeze(0).detach().cpu().numpy()) - -# Convert embeddings to a 2D numpy array (num_segments x embedding_size) -embeddings = np.vstack(embeddings) - -# Perform KMeans clustering on 2D embeddings -centroids, labels = kmeans2(embeddings, k=6) # Adjust 'k' based on number of speakers - -# Output diarization results with speaker labels and timestamps -print("Diarization Results:") -for i, (start, end) in enumerate(segments): - print(f"{start:.2f}s - {end:.2f}s: Speaker {labels[i]}") - -# Perform ASR on the entire audio file and display the result -with torch.inference_mode(): - asr_transcription = asr_model(signal)[0] # Extract only the transcription result - asr_text = asr_transcription.tolist() - -print("\nTranscription Results:") -print(asr_text) +# dump the diarization output to disk using RTTM format +with open("short_transcript.rttm", "w") as rttm: + diarization.write_rttm(rttm) -# Optional: plot audio waveform with speaker probabilities -def plot_diarization_with_audio(signal, fs, segments, labels): - # Plot audio waveform - plt.figure(figsize=(12, 6)) - time = torch.arange(0, signal.shape[1]) / fs - plt.subplot(2, 1, 1) - plt.plot(time, signal.t().numpy()) - plt.title("Audio Waveform") - plt.xlabel("Time (s)") - plt.ylabel("Amplitude") +import matplotlib.pyplot as plt +import librosa +import librosa.display - # Plot speaker diarization - plt.subplot(2, 1, 2) - for i, (start, end) in enumerate(segments): - speaker_label = labels[i] - plt.plot( - [start, end], - [speaker_label, speaker_label], - label=f"Speaker {speaker_label}", - linewidth=4, - ) +# Load the audio file and compute its waveform +audio, sr = librosa.load(audio_path, sr=None) - plt.xlabel("Time (s)") - plt.ylabel("Speaker") - plt.title("Speaker Diarization with Probability") - plt.show() +# Plot the audio waveform +plt.figure(figsize=(10, 6)) +librosa.display.waveshow(audio, sr=sr, alpha=0.5, color="gray") +plt.xlabel("Time (s)") +plt.ylabel("Amplitude") +plt.title("Speaker Diarization Results") +# Plot speaker segments +for segment, _, label in diarization.itertracks(yield_label=True): + # Get start and end times of each speaker segment + start, end = segment.start, segment.end + plt.plot([start, end], [0.9, 0.9], label=f"Speaker {label}") -plot_diarization_with_audio(signal, fs, segments, labels) +# Avoid duplicate labels in legend +handles, labels = plt.gca().get_legend_handles_labels() +by_label = dict(zip(labels, handles)) +plt.legend(by_label.values(), by_label.keys(), loc="upper right") + +plt.show()