diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..aea7c80 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.vscode/ +venv/ +tmp_spkrec/ \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 1f62a77..d52990a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -575,6 +575,36 @@ files = [ {file = "kiwisolver-1.4.7.tar.gz", hash = "sha256:9893ff81bd7107f7b685d3017cc6583daadb4fc26e4a888350df530e41980a60"}, ] +[[package]] +name = "llvmlite" +version = "0.43.0" +description = "lightweight wrapper around basic LLVM functionality" +optional = false +python-versions = ">=3.9" +files = [ + {file = "llvmlite-0.43.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a289af9a1687c6cf463478f0fa8e8aa3b6fb813317b0d70bf1ed0759eab6f761"}, + {file = "llvmlite-0.43.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6d4fd101f571a31acb1559ae1af30f30b1dc4b3186669f92ad780e17c81e91bc"}, + {file = "llvmlite-0.43.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d434ec7e2ce3cc8f452d1cd9a28591745de022f931d67be688a737320dfcead"}, + {file = "llvmlite-0.43.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6912a87782acdff6eb8bf01675ed01d60ca1f2551f8176a300a886f09e836a6a"}, + {file = "llvmlite-0.43.0-cp310-cp310-win_amd64.whl", hash = "sha256:14f0e4bf2fd2d9a75a3534111e8ebeb08eda2f33e9bdd6dfa13282afacdde0ed"}, + {file = "llvmlite-0.43.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3e8d0618cb9bfe40ac38a9633f2493d4d4e9fcc2f438d39a4e854f39cc0f5f98"}, + {file = "llvmlite-0.43.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e0a9a1a39d4bf3517f2af9d23d479b4175ead205c592ceeb8b89af48a327ea57"}, + {file = "llvmlite-0.43.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1da416ab53e4f7f3bc8d4eeba36d801cc1894b9fbfbf2022b29b6bad34a7df2"}, + {file = "llvmlite-0.43.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:977525a1e5f4059316b183fb4fd34fa858c9eade31f165427a3977c95e3ee749"}, + {file = "llvmlite-0.43.0-cp311-cp311-win_amd64.whl", hash = "sha256:d5bd550001d26450bd90777736c69d68c487d17bf371438f975229b2b8241a91"}, + {file = "llvmlite-0.43.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f99b600aa7f65235a5a05d0b9a9f31150c390f31261f2a0ba678e26823ec38f7"}, + {file = "llvmlite-0.43.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:35d80d61d0cda2d767f72de99450766250560399edc309da16937b93d3b676e7"}, + {file = "llvmlite-0.43.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eccce86bba940bae0d8d48ed925f21dbb813519169246e2ab292b5092aba121f"}, + {file = "llvmlite-0.43.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df6509e1507ca0760787a199d19439cc887bfd82226f5af746d6977bd9f66844"}, + {file = "llvmlite-0.43.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a2872ee80dcf6b5dbdc838763d26554c2a18aa833d31a2635bff16aafefb9c9"}, + {file = "llvmlite-0.43.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9cd2a7376f7b3367019b664c21f0c61766219faa3b03731113ead75107f3b66c"}, + {file = "llvmlite-0.43.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:18e9953c748b105668487b7c81a3e97b046d8abf95c4ddc0cd3c94f4e4651ae8"}, + {file = "llvmlite-0.43.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74937acd22dc11b33946b67dca7680e6d103d6e90eeaaaf932603bec6fe7b03a"}, + {file = "llvmlite-0.43.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc9efc739cc6ed760f795806f67889923f7274276f0eb45092a1473e40d9b867"}, + {file = "llvmlite-0.43.0-cp39-cp39-win_amd64.whl", hash = "sha256:47e147cdda9037f94b399bf03bfd8a6b6b1f2f90be94a454e3386f006455a9b4"}, + {file = "llvmlite-0.43.0.tar.gz", hash = "sha256:ae2b5b5c3ef67354824fb75517c8db5fbe93bc02cd9671f3c62271626bc041d5"}, +] + [[package]] name = "markupsafe" version = "3.0.2" @@ -1593,4 +1623,4 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "1c96d328a333d968529d4a32218e66005636f5d976cb6adf7f38fb5008b3f3f3" +content-hash = "ba38328158022b27b43b2e6dab25968b4f07ef73483218f9b439ae5e0e1d4142" diff --git a/pyproject.toml b/pyproject.toml index d98691d..d9f9861 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,8 @@ speechbrain = "^1.0.2" matplotlib = "^3.9.2" torch = "^2.5.1" torchaudio = "^2.5.1" +llvmlite = "^0.43.0" +setuptools = "^75.3.0" [build-system] diff --git a/tavern_talk/diarization.py b/tavern_talk/diarization.py index 6456168..be9f48d 100644 --- a/tavern_talk/diarization.py +++ b/tavern_talk/diarization.py @@ -1,11 +1,13 @@ import torch from pyannote.audio import Pipeline import whisper -import librosa -import numpy as np -import matplotlib.pyplot as plt -import librosa.display -from pyannote.core import Segment, Annotation + +AUDIO_FILE = "2024-07-29_audio.wav" +filename = (AUDIO_FILE[::-1].split(".")[1].split("/")[0].split("\\")[0])[::-1] + + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True # Load Whisper model for transcription whisper_model = whisper.load_model("large") @@ -28,7 +30,7 @@ def perform_diarization(audio_path) -> Pipeline: diarization = pipeline(audio_path, min_speakers=5, max_speakers=7) # dump the diarization output to disk using RTTM format - with open("diarization.rttm", "w") as rttm: + with open(f"diarization_{filename}.rttm", "w") as rttm: diarization.write_rttm(rttm) print("Finished diarization") @@ -37,37 +39,23 @@ def perform_diarization(audio_path) -> Pipeline: # Load audio and perform both transcription and diarization -audio_path = "mid_audio.wav" -transcription_segments = transcribe_audio(audio_path) -diarization: Pipeline = perform_diarization(audio_path) +transcription_segments = transcribe_audio(AUDIO_FILE) +diarization: Pipeline = perform_diarization(AUDIO_FILE) # Print speaker and corresponding text print("\nSpeaker and Text Segments:") +diarization_with_text = [] for segment in transcription_segments: start, end, text = segment for spk_segment, _, speaker_label in diarization.itertracks(yield_label=True): if spk_segment.start < end and spk_segment.end > start: - print(f"Speaker {speaker_label}: {text}") + diarization_with_text.append(f"Speaker {speaker_label}: {text}") break -# Load audio for plotting -audio, sr = librosa.load(audio_path, sr=None) -# Plot the audio waveform and speaker segments -plt.figure(figsize=(12, 6)) -librosa.display.waveshow(audio, sr=sr, alpha=0.5, color="gray") -plt.xlabel("Time (s)") -plt.ylabel("Amplitude") -plt.title("Speaker Diarization with Transcription") +with open(f"transcript-diarization_{filename}.txt", "w") as fp: + fp.writelines(f"{l}\n" for l in diarization_with_text) -# Plot speaker segments and add transcription text -for segment, _, label in diarization.itertracks(yield_label=True): - start, end = segment.start, segment.end - plt.plot([start, end], [0.9, 0.9], label=f"Speaker {label}") +print("\n".join(diarization_with_text)) -# Avoid duplicate labels in legend -handles, labels = plt.gca().get_legend_handles_labels() -by_label = dict(zip(labels, handles)) -plt.legend(by_label.values(), by_label.keys(), loc="upper right") - -plt.show() +print("Saved diarization")