diff --git a/audio.wav b/audio.wav
new file mode 100644
index 0000000..7c42862
Binary files /dev/null and b/audio.wav differ
diff --git a/src/main.py b/src/main.py
index 93c7d5c..853ae7c 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,43 +1,40 @@
 from typer import insert_text_at_cursor, callback_for_keycombination, wait_for_callbacks
-from speech import AudioProcessor
+from speech import WhisperSTT
+from mic import AudioRecorder
 import logging
-
-LISTENING_KEYCOMBINATION: str = "ctrl"
-RELEASE_BTN: str = "ctrl"
+LISTENING_KEYCOMBINATION: str = "CTRL"
+RELEASE_BTN: str = "CTRL"
 
 logging.getLogger().setLevel(logging.DEBUG)
 
 
-def phrase_callback(phrase: str) -> None:
-    if audio_processor.is_listening:
-        return
-
-    insert_text_at_cursor(phrase)
-
 def start_listening():
-    if audio_processor.is_listening:
+    if recorder.is_recording():
         return
 
     logging.info(f'Listening... press [{RELEASE_BTN}] to stop.')
-    audio_processor.start_recording()
+    recorder.start()
 
 
 def stop_listening():
-    if not audio_processor.is_listening:
+    if not recorder.is_recording():
        return
 
+    audio = recorder.stop()
+    recorder.save("audio.wav")
     logging.info("Stopped listening.")
-    audio_processor.stop_recording()
-    insert_text_at_cursor(audio_processor.get_last_text())
+    text = whisper.transcribe(audio)
+    insert_text_at_cursor(text)
 
 
-audio_processor: AudioProcessor = AudioProcessor(model="medium", phrase_callback=phrase_callback)
+recorder: AudioRecorder = AudioRecorder()
+whisper: WhisperSTT = WhisperSTT()
 
 
 callback_for_keycombination(LISTENING_KEYCOMBINATION, start_listening, on_release=False)
 callback_for_keycombination(RELEASE_BTN, stop_listening, on_release=True)
 
 
-logging.info(f'Ready, wait for [{LISTENING_KEYCOMBINATION.upper()}]')
+logging.info(f'Ready, wait for [{LISTENING_KEYCOMBINATION}]')
 
 wait_for_callbacks()
diff --git a/src/mic.py b/src/mic.py
new file mode 100644
index 0000000..33fa57c
--- /dev/null
+++ b/src/mic.py
@@ -0,0 +1,59 @@
+import numpy as np
+import pyaudio
+import wave
+import threading
+
+
+class AudioRecorder:
+    # default to 16 kHz mono, the sample rate Whisper models expect
+    def __init__(self, chunk=1024, format=pyaudio.paInt16, channels=1, rate=16000):
+        self.chunk = chunk
+        self.format = format
+        self.channels = channels
+        self.rate = rate
+        self.frames = []
+        self.recording = False
+        self.thread = None
+
+    def start(self):
+        self.recording = True
+        self.thread = threading.Thread(target=self._record)
+        self.thread.start()
+
+    def stop(self) -> np.ndarray:
+        self.recording = False
+        self.thread.join()
+
+        return self.frames
+
+    def is_recording(self) -> bool:
+        return self.recording
+
+    def get_last_audio(self) -> np.ndarray:
+        return self.frames
+
+    def _record(self):
+        audio = pyaudio.PyAudio()
+        stream = audio.open(
+            format=self.format,
+            channels=self.channels,
+            rate=self.rate,
+            input=True,
+            frames_per_buffer=self.chunk,
+        )
+        self.frames = []
+        while self.recording:
+            data = stream.read(self.chunk)
+            # paInt16 buffers hold 16-bit samples, so decode as int16
+            self.frames.append(np.frombuffer(data, dtype=np.int16))
+        stream.stop_stream()
+        stream.close()
+        audio.terminate()
+        self.frames = np.concatenate(self.frames, axis=0)
+
+    def save(self, filename: str):
+        with wave.open(filename, "wb") as wave_file:
+            wave_file.setnchannels(self.channels)
+            wave_file.setsampwidth(pyaudio.get_sample_size(self.format))
+            wave_file.setframerate(self.rate)
+            wave_file.writeframes(self.frames.tobytes())
diff --git a/src/speech.py b/src/speech.py
index fea7701..4b1759d 100644
--- a/src/speech.py
+++ b/src/speech.py
@@ -1,57 +1,34 @@
-import speech_recognition as sr
+import numpy as np
+import whisper
 import logging
 
 
-class AudioProcessor:
-    def __init__(self, *, language: str = "german", model: str = "base", phrase_callback: callable = None) -> None:
-        self.language: str = language
-        self.model: str = model
+class WhisperSTT:
+    def __init__(self, *, model: str = "base") -> None:
+        self.model = whisper.load_model(model)
         self.last_text: str = None
-        self.phrase_callback: callable = phrase_callback
-        self.is_listening: bool = False
-        self.stop_listening_caller = None
-
-        logging.debug("Found the following microphones:")
-        for index, name in sr.Microphone.list_working_microphones().items():
-            logging.debug(f"Microphone with index {index} and name `{name}` found")
-
-        self.recognizer: sr.Recognizer = sr.Recognizer()
-        self.listener: sr.Microphone = sr.Microphone()
-        with self.listener as source:
-            self.recognizer.adjust_for_ambient_noise(source)  # we only need to calibrate once, before we start listening
 
     def get_last_text(self) -> str:
         return self.last_text
 
-    def is_listening(self) -> bool:
-        return self.is_listening
+    def transcribe(self, audio: np.ndarray) -> str:
+        # Whisper expects float32 PCM in [-1, 1]; convert raw int16 recordings first
+        if audio.dtype == np.int16:
+            audio = audio.astype(np.float32) / 32768.0
+
+        # pad/trim the audio to fit Whisper's 30-second window
+        audio = whisper.pad_or_trim(audio)
 
-    def start_recording(self) -> None:
-        if self.is_listening:
-            logging.warning("Listener is already open")
-            return
+        # make log-Mel spectrogram and move to the same device as the model
+        mel = whisper.log_mel_spectrogram(audio).to(self.model.device)
 
-        self.last_text = ""
-        self.is_listening = True
-
-        self.stop_listening_caller = self.recognizer.listen_in_background(self.listener, self.listening_callback)
-
-    def listening_callback(self, recognizer, audio):
-        new_text = self.recognizer.recognize_whisper(
-            audio, language=self.language, model=self.model
-        )
-
-        if self.last_text is not None and self.last_text != "":
-            self.last_text += " "
-        self.last_text += new_text
-
-        if self.phrase_callback:
-            self.phrase_callback(new_text)
+        # detect the spoken language
+        _, probs = self.model.detect_language(mel)
+        logging.debug(f"Detected language: {max(probs, key=probs.get)}")
 
-    def stop_recording(self) -> None:
-        if not self.is_listening:
-            logging.warning("Listener is already closed")
-            return
-
-        self.stop_listening_caller(wait_for_stop=False)
-        self.is_listening = False
+        # decode the audio
+        options = whisper.DecodingOptions()
+        result = whisper.decode(self.model, mel, options)
+        self.last_text = result.text
+
+        return self.last_text
diff --git a/src/typer_with_xlib.py b/src/typer_with_xlib.py
new file mode 100644
index 0000000..f308001
--- /dev/null
+++ b/src/typer_with_xlib.py
@@ -0,0 +1,70 @@
+import keyboard
+from Xlib.display import Display
+from Xlib import X
+from Xlib.ext import record
+from Xlib.protocol import rq
+
+disp = None
+key_triggers = []
+
+def insert_text_at_cursor(text: str):
+    if text is None or text == "":
+        return
+
+    keyboard.write(text)
+
+# keycode is the numeric X keycode (event.detail), e.g. 37 for left Ctrl on many layouts
+def callback_for_keycombination(keycode: int, callback: callable, *, on_release: bool = False):
+    key_triggers.append({
+        "keycode": keycode,
+        "callback": callback,
+        "on_release": on_release
+    })
+
+def handler(reply):
+    """Called by the record extension whenever X key events are captured."""
+    # ignore replies that carry no event data
+    if reply.category != record.FromServer or not len(reply.data):
+        return
+
+    data = reply.data
+    while len(data):
+        event, data = rq.EventField(None).parse_binary_value(data, disp.display, None, None)
+
+        # the X keycode of the event is available as event.detail
+
+        for trigger in key_triggers:
+            if int(trigger["keycode"]) != event.detail:
+                continue
+
+            if trigger["on_release"] and event.type == X.KeyRelease:
+                trigger["callback"]()
+            elif not trigger["on_release"] and event.type == X.KeyPress:
+                trigger["callback"]()
+
+
+def wait_for_callbacks():
+    # connect to the current display
+    global disp
+    disp = Display()
+
+    # monitor key press and release events; device_events takes an
+    # event-type range (first, last), not an event mask
+    ctx = disp.record_create_context(
+        0,
+        [record.AllClients],
+        [{
+            'core_requests': (0, 0),
+            'core_replies': (0, 0),
+            'ext_requests': (0, 0, 0, 0),
+            'ext_replies': (0, 0, 0, 0),
+            'delivered_events': (0, 0),
+            'device_events': (X.KeyPress, X.KeyRelease),
+            'errors': (0, 0),
+            'client_started': False,
+            'client_died': False,
+        }])
+
+    # record_enable_context blocks here and feeds captured events to handler
+    disp.record_enable_context(ctx, handler)
+    disp.record_free_context(ctx)
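
A minimal end-to-end sketch of the new capture-and-transcribe flow, assuming PyAudio, NumPy, and openai-whisper are installed and a microphone is available. Everything here uses only the classes added above; the fixed five-second sleep is a stand-in for the hotkey-driven start/stop wired up in src/main.py.

    import time

    from mic import AudioRecorder
    from speech import WhisperSTT

    recorder = AudioRecorder()      # 16 kHz mono int16 by default
    stt = WhisperSTT(model="base")  # loads the Whisper "base" checkpoint

    recorder.start()                # capture runs on a background thread
    time.sleep(5)                   # speak into the microphone
    audio = recorder.stop()         # int16 ndarray of the whole take

    recorder.save("audio.wav")      # optional WAV copy of the capture
    print(stt.transcribe(audio))    # normalized, padded, and decoded by WhisperSTT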