SpeechBrain experiments
parent 70d01eb444
commit cc2d6f8210
3 changed files with 1666 additions and 1 deletions
1581 poetry.lock generated
File diff suppressed because it is too large
pyproject.toml
@@ -9,6 +9,10 @@ package-mode = false

[tool.poetry.dependencies]
python = "^3.10"
pydub = "^0.25.1"
speechbrain = "^1.0.2"
matplotlib = "^3.9.2"
torch = "^2.5.1"
torchaudio = "^2.5.1"

[build-system]
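The dependency hunk above brings in the SpeechBrain/PyTorch stack used by the experiment. A minimal sanity check of the resolved environment after `poetry install`, assuming the packages are queried via importlib.metadata (this snippet is illustrative and not part of the commit):

from importlib.metadata import version

# Print the installed version of each package pinned above
for package in ("speechbrain", "matplotlib", "torch", "torchaudio", "pydub"):
    print(f"{package}: {version(package)}")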
82 tavern_talk/diarization.py Normal file
@@ -0,0 +1,82 @@
import torchaudio
import torch
from speechbrain.inference.classifiers import EncoderClassifier
from scipy.cluster.vq import kmeans2
import numpy as np
import matplotlib.pyplot as plt

audio_file = "./tavern_talk/short_transcript.wav"

# Load the speaker encoder model
classifier = EncoderClassifier.from_hparams(
    source="speechbrain/spkrec-xvect-voxceleb", savedir="tmp_spkrec"
)

# Load the ASR model from torchaudio (wav2vec 2.0 fine-tuned with CTC)
asr_bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
asr_model = asr_bundle.get_model()

# Load the audio file; both models expect 16 kHz mono input
signal, fs = torchaudio.load(audio_file)
signal = signal.mean(dim=0, keepdim=True)  # downmix to mono
if fs != asr_bundle.sample_rate:
    signal = torchaudio.functional.resample(signal, fs, asr_bundle.sample_rate)
    fs = asr_bundle.sample_rate

# Segment the audio into 1-second chunks with a 50% overlap for speaker embeddings
window_size = int(fs * 1.0)
overlap = int(fs * 0.5)
segments = []
embeddings = []

for start in range(0, signal.shape[1] - window_size, overlap):
    segment = signal[:, start : start + window_size]
    segments.append((start / fs, (start + window_size) / fs))
    embedding = classifier.encode_batch(segment)
    embeddings.append(embedding.squeeze(0).detach().cpu().numpy())

# Convert embeddings to a 2D numpy array (num_segments x embedding_size)
embeddings = np.vstack(embeddings)

# Perform KMeans clustering on the embeddings; "++" initialization avoids
# empty clusters. Adjust k to the expected number of speakers (a sketch for
# estimating k from a distortion curve follows the file).
centroids, labels = kmeans2(embeddings, k=6, minit="++")

# Output diarization results with speaker labels and timestamps
print("Diarization Results:")
for i, (start, end) in enumerate(segments):
    print(f"{start:.2f}s - {end:.2f}s: Speaker {labels[i]}")

# Perform ASR on the entire audio file; the model returns frame-level CTC logits
with torch.inference_mode():
    emissions, _ = asr_model(signal)

# Greedy CTC decoding: best label per frame, collapse repeats, drop the blank
# token (index 0), and map the word separator '|' to a space
ctc_labels = asr_bundle.get_labels()
frame_indices = torch.unique_consecutive(emissions[0].argmax(dim=-1)).tolist()
asr_text = "".join(ctc_labels[i] for i in frame_indices if i != 0).replace("|", " ").strip()

print("\nTranscription Results:")
print(asr_text)


# Optional: plot audio waveform with per-segment speaker labels
def plot_diarization_with_audio(signal, fs, segments, labels):
    # Plot audio waveform
    plt.figure(figsize=(12, 6))
    time = torch.arange(0, signal.shape[1]) / fs
    plt.subplot(2, 1, 1)
    plt.plot(time.numpy(), signal.t().numpy())
    plt.title("Audio Waveform")
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude")

    # Plot speaker diarization: one horizontal bar per segment at the speaker index
    plt.subplot(2, 1, 2)
    for i, (start, end) in enumerate(segments):
        speaker_label = labels[i]
        plt.plot(
            [start, end],
            [speaker_label, speaker_label],
            label=f"Speaker {speaker_label}",
            linewidth=4,
        )

    plt.xlabel("Time (s)")
    plt.ylabel("Speaker")
    plt.title("Speaker Diarization")
    plt.tight_layout()
    plt.show()


plot_diarization_with_audio(signal, fs, segments, labels)
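The script hard-codes k=6 for kmeans2. A minimal sketch of picking the speaker count from a distortion (elbow) curve instead, assuming the (num_segments x 512) embeddings array built in diarization.py; distortion_curve and the synthetic demo data are illustrative and not part of the commit:

import numpy as np
from scipy.cluster.vq import kmeans, whiten

def distortion_curve(embeddings, k_min=2, k_max=8):
    """Return {k: mean distance to the nearest centroid} for each candidate k."""
    obs = whiten(embeddings)  # scale each embedding dimension to unit variance
    return {k: float(kmeans(obs, k)[1]) for k in range(k_min, k_max + 1)}

# Demo with synthetic 512-dim embeddings drawn around three fake speakers;
# replace `fake` with the real `embeddings` array from the script above.
rng = np.random.default_rng(0)
fake = np.vstack([rng.normal(loc=c, scale=0.1, size=(40, 512)) for c in (0.0, 1.0, 2.0)])
for k, d in distortion_curve(fake).items():
    print(f"k={k}: distortion={d:.3f}")

The elbow of the printed curve (the k after which distortion stops dropping sharply) is a reasonable estimate of the number of speakers.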
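The per-window labels printed by the script overlap by 0.5 s; for readability they can be merged into contiguous speaker turns. A minimal sketch assuming the segments list of (start, end) tuples and the labels array produced above; merge_turns is illustrative and not part of the commit:

def merge_turns(segments, labels):
    """Collapse consecutive windows with the same speaker into (start, end, speaker) turns."""
    turns = []
    for (start, end), speaker in zip(segments, labels):
        if turns and turns[-1][2] == speaker and start <= turns[-1][1]:
            turns[-1] = (turns[-1][0], end, speaker)  # extend the current turn
        else:
            turns.append((start, end, speaker))
    return turns

# Example: three overlapping 1-second windows labelled 0, 0, 1
print(merge_turns([(0.0, 1.0), (0.5, 1.5), (1.0, 2.0)], [0, 0, 1]))
# -> [(0.0, 1.5, 0), (1.0, 2.0, 1)]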