Some more advanced diarization
This commit is contained in:
parent b71c78c5f2
commit d62195a590
1 changed file with 53 additions and 26 deletions
@@ -1,40 +1,67 @@
-# instantiate the pipeline
-from pyannote.audio import Pipeline
 import torch
-audio_path = "short_transcript.wav"
+from pyannote.audio import Pipeline
+import whisper
 
-pipeline = Pipeline.from_pretrained(
-    "pyannote/speaker-diarization-3.1",
-    use_auth_token="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",
-)
-
-# run the pipeline on an audio file
-diarization = pipeline(audio_path, min_speakers=6, max_speakers=7)
-
-# dump the diarization output to disk using RTTM format
-with open("short_transcript.rttm", "w") as rttm:
-    diarization.write_rttm(rttm)
-
-import matplotlib.pyplot as plt
 import librosa
+import numpy as np
+import matplotlib.pyplot as plt
 import librosa.display
+from pyannote.core import Segment, Annotation
 
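A note on the block removed above: the old script pinned a Hugging Face access token directly in the source (and that token is now public in the history), while the new from_pretrained call relies on locally cached credentials. If a token has to be passed explicitly, reading it from the environment keeps it out of version control; a minimal sketch, assuming the token is exported under the name HF_TOKEN (our choice, not the commit's):

    import os
    from pyannote.audio import Pipeline

    # Read the access token from the environment instead of hardcoding it.
    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token=os.environ["HF_TOKEN"],  # assumes the user exported HF_TOKEN
    )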
-# Load the audio file and compute its waveform
+# Load Whisper model for transcription
+whisper_model = whisper.load_model("large")
 
+# Transcribe audio using Whisper
+def transcribe_audio(audio_path):
+    result = whisper_model.transcribe(audio_path)
+    segments = result["segments"]
+    return [(segment["start"], segment["end"], segment["text"]) for segment in segments]
 
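For orientation: Whisper's transcribe returns a dict whose "segments" list carries onset and offset in seconds plus the recognized text, so transcribe_audio yields (start, end, text) tuples. A usage sketch, with invented timings and text:

    segments = transcribe_audio("mid_audio.wav")
    for start, end, text in segments:
        print(f"[{start:6.2f}s - {end:6.2f}s]{text}")
    # Hypothetical output:
    # [  0.00s -   4.20s] Hello everyone, welcome.
    # [  4.20s -   7.90s] Let's get started.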
+# Initialize Pyannote Pipeline for diarization
+pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
+pipeline.to(torch.device("cuda"))
 
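pipeline.to(torch.device("cuda")) raises on a machine without a working CUDA build of PyTorch. A guarded variant, as a sketch (the CPU fallback is our assumption, not part of the commit):

    # Use the GPU when available, otherwise fall back to slower CPU inference.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline.to(device)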
+# Perform diarization
+def perform_diarization(audio_path) -> Annotation:
+    diarization = pipeline(audio_path, min_speakers=5, max_speakers=7)
+
+    # dump the diarization output to disk using RTTM format
+    with open("diarization.rttm", "w") as rttm:
+        diarization.write_rttm(rttm)
+
+    print("Finished diarization")
+    return diarization
 
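(The commit annotated perform_diarization as returning Pipeline; the pipeline call actually yields a pyannote.core.Annotation, so the hint is corrected above.) Each line that write_rttm emits describes one speaker turn in the standard RTTM layout: type, file id, channel, onset seconds, duration seconds, then the speaker label among <NA> placeholder fields. An invented example line:

    SPEAKER mid_audio 1 12.34 3.21 <NA> <NA> SPEAKER_00 <NA> <NA>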
+# Load audio and perform both transcription and diarization
+audio_path = "mid_audio.wav"
+transcription_segments = transcribe_audio(audio_path)
+diarization: Annotation = perform_diarization(audio_path)
 
+# Print speaker and corresponding text
+print("\nSpeaker and Text Segments:")
+for segment in transcription_segments:
+    start, end, text = segment
+    for spk_segment, _, speaker_label in diarization.itertracks(yield_label=True):
+        if spk_segment.start < end and spk_segment.end > start:
+            print(f"Speaker {speaker_label}: {text}")
+            break
 
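The matching loop above labels each transcribed segment with the first speaker turn that overlaps it, which can pick the wrong speaker when a segment straddles a turn change. A sketch of a largest-overlap variant (the helper name assign_speaker is ours, not the commit's):

    def assign_speaker(start, end, diarization):
        # Prefer the speaker whose turns cover the most of [start, end].
        best_label, best_overlap = None, 0.0
        for turn, _, label in diarization.itertracks(yield_label=True):
            overlap = min(end, turn.end) - max(start, turn.start)
            if overlap > best_overlap:
                best_label, best_overlap = label, overlap
        return best_label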
+# Load audio for plotting
 audio, sr = librosa.load(audio_path, sr=None)
 
-# Plot the audio waveform
-plt.figure(figsize=(10, 6))
+# Plot the audio waveform and speaker segments
+plt.figure(figsize=(12, 6))
 librosa.display.waveshow(audio, sr=sr, alpha=0.5, color="gray")
 plt.xlabel("Time (s)")
 plt.ylabel("Amplitude")
-plt.title("Speaker Diarization Results")
+plt.title("Speaker Diarization with Transcription")
 
-# Plot speaker segments
+# Plot speaker segments and add transcription text
 for segment, _, label in diarization.itertracks(yield_label=True):
-    # Get start and end times of each speaker segment
     start, end = segment.start, segment.end
     plt.plot([start, end], [0.9, 0.9], label=f"Speaker {label}")
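As committed, the script builds the figure but never renders or saves it, and plotting one line per turn creates a duplicate legend entry for every turn of the same speaker. A possible ending, as a sketch (the output file name is our choice):

    # Collapse duplicate legend entries so each speaker appears once, then render.
    handles, labels = plt.gca().get_legend_handles_labels()
    unique = dict(zip(labels, handles))
    plt.legend(unique.values(), unique.keys(), loc="upper right")
    plt.tight_layout()
    plt.savefig("diarization_plot.png")
    plt.show()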