Some diarization
parent cc2d6f8210
commit b71c78c5f2
1 changed file with 34 additions and 70 deletions
@@ -1,82 +1,46 @@
-audio_file = "./tavern_talk/short_transcript.wav"
-
-import torchaudio
-import torch
-from speechbrain.inference.classifiers import EncoderClassifier
-from scipy.cluster.vq import kmeans2
-import numpy as np
-import matplotlib.pyplot as plt
-
-# Load the speaker encoder model
-classifier = EncoderClassifier.from_hparams(
-    source="speechbrain/spkrec-xvect-voxceleb", savedir="tmp_spkrec"
-)
-
-# Load the ASR model from torchaudio
-asr_model = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.get_model()
-
-# Define the audio file path
-signal, fs = torchaudio.load(audio_file)
-
-# Segment the audio into 1-second chunks with a 50% overlap for speaker embeddings
-window_size = int(fs * 1.0)
-overlap = int(fs * 0.5)
-segments = []
-embeddings = []
-
-for start in range(0, signal.shape[1] - window_size, overlap):
-    segment = signal[:, start : start + window_size]
-    segments.append((start / fs, (start + window_size) / fs))
-    embedding = classifier.encode_batch(segment)
-    embeddings.append(embedding.squeeze(0).detach().cpu().numpy())
-
-# Convert embeddings to a 2D numpy array (num_segments x embedding_size)
-embeddings = np.vstack(embeddings)
-
-# Perform KMeans clustering on 2D embeddings
-centroids, labels = kmeans2(embeddings, k=6)  # Adjust 'k' based on number of speakers
-
-# Output diarization results with speaker labels and timestamps
-print("Diarization Results:")
-for i, (start, end) in enumerate(segments):
-    print(f"{start:.2f}s - {end:.2f}s: Speaker {labels[i]}")
-
-# Perform ASR on the entire audio file and display the result
-with torch.inference_mode():
-    asr_transcription = asr_model(signal)[0]  # Extract only the transcription result
-    asr_text = asr_transcription.tolist()
-
-print("\nTranscription Results:")
-print(asr_text)
-
-
-# Optional: plot audio waveform with speaker probabilities
-def plot_diarization_with_audio(signal, fs, segments, labels):
-    # Plot audio waveform
-    plt.figure(figsize=(12, 6))
-    time = torch.arange(0, signal.shape[1]) / fs
-    plt.subplot(2, 1, 1)
-    plt.plot(time, signal.t().numpy())
-    plt.title("Audio Waveform")
-    plt.xlabel("Time (s)")
-    plt.ylabel("Amplitude")
-
-    # Plot speaker diarization
-    plt.subplot(2, 1, 2)
-    for i, (start, end) in enumerate(segments):
-        speaker_label = labels[i]
-        plt.plot(
-            [start, end],
-            [speaker_label, speaker_label],
-            label=f"Speaker {speaker_label}",
-            linewidth=4,
-        )
-
-    plt.xlabel("Time (s)")
-    plt.ylabel("Speaker")
-    plt.title("Speaker Diarization with Probability")
-    plt.show()
-
-
-plot_diarization_with_audio(signal, fs, segments, labels)
+# instantiate the pipeline
+from pyannote.audio import Pipeline
+
+import torch
+
+audio_path = "short_transcript.wav"
+
+pipeline = Pipeline.from_pretrained(
+    "pyannote/speaker-diarization-3.1",
+    use_auth_token="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",
+)
+
+# run the pipeline on an audio file
+diarization = pipeline(audio_path, min_speakers=6, max_speakers=7)
+
+# dump the diarization output to disk using RTTM format
+with open("short_transcript.rttm", "w") as rttm:
+    diarization.write_rttm(rttm)
+
+import matplotlib.pyplot as plt
+import librosa
+import librosa.display
+
+# Load the audio file and compute its waveform
+audio, sr = librosa.load(audio_path, sr=None)
+
+# Plot the audio waveform
+plt.figure(figsize=(10, 6))
+librosa.display.waveshow(audio, sr=sr, alpha=0.5, color="gray")
+plt.xlabel("Time (s)")
+plt.ylabel("Amplitude")
+plt.title("Speaker Diarization Results")
+
+# Plot speaker segments
+for segment, _, label in diarization.itertracks(yield_label=True):
+    # Get start and end times of each speaker segment
+    start, end = segment.start, segment.end
+    plt.plot([start, end], [0.9, 0.9], label=f"Speaker {label}")
+
+# Avoid duplicate labels in legend
+handles, labels = plt.gca().get_legend_handles_labels()
+by_label = dict(zip(labels, handles))
+plt.legend(by_label.values(), by_label.keys(), loc="upper right")
+
+plt.show()
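
Review note: the added Pipeline.from_pretrained call hard-codes a Hugging Face access token, which is now part of the repository history and should be revoked. A minimal sketch of the same call with the token read from the environment instead (HF_TOKEN is an assumed variable name, not something defined elsewhere in this change):

import os
from pyannote.audio import Pipeline

# Same pipeline load as in the added code, but the token comes from an
# environment variable (assumed name: HF_TOKEN) instead of being committed.
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=os.environ["HF_TOKEN"],
)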
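
Usage note: the new script only writes the diarization to short_transcript.rttm. If the speaker turns are needed back in Python (for example to line them up with a transcript), the standard RTTM SPEAKER records can be parsed directly; a sketch with a hypothetical helper read_rttm, assuming the file written by the added code:

# Each RTTM SPEAKER record is:
# SPEAKER <file-id> <channel> <onset> <duration> <NA> <NA> <speaker-label> <NA> <NA>
# with onset and duration given in seconds.
def read_rttm(path):
    turns = []
    with open(path) as f:
        for line in f:
            fields = line.split()
            if fields and fields[0] == "SPEAKER":
                onset, duration, speaker = float(fields[3]), float(fields[4]), fields[7]
                turns.append((onset, onset + duration, speaker))
    return turns

for start, end, speaker in read_rttm("short_transcript.rttm"):
    print(f"{start:.2f}s - {end:.2f}s: {speaker}")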
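
Aside on the removed ASR block: torchaudio's WAV2VEC2_ASR_BASE_960H model returns frame-level CTC emissions, so the old print(asr_text) would have shown raw scores rather than words. A minimal greedy-decode sketch, assuming 16 kHz mono input and the bundle's label set (where "-" is the CTC blank and "|" marks word boundaries):

import torch
import torchaudio

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
asr_model = bundle.get_model()
labels = bundle.get_labels()

signal, fs = torchaudio.load("short_transcript.wav")
signal = torchaudio.functional.resample(signal, fs, bundle.sample_rate)

with torch.inference_mode():
    emission, _ = asr_model(signal)  # (batch, frames, num_labels) logits over the label set

# Greedy CTC decode: best label per frame, collapse repeats, drop blanks,
# then turn the word-delimiter token into spaces.
indices = torch.unique_consecutive(emission[0].argmax(dim=-1)).tolist()
text = "".join(labels[i] for i in indices if labels[i] != "-").replace("|", " ")
print(text.strip())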