Workiiiiiiing
This commit is contained in:
parent
cc320c8376
commit
c54cc70dfe
8 changed files with 239 additions and 173 deletions
12
src/config.py
Normal file
12
src/config.py
Normal file
|
@ -0,0 +1,12 @@
|
|||
import socket
|
||||
|
||||
|
||||
HOST = socket.gethostname()
|
||||
PORT = 31452
|
||||
|
||||
TRIGGER_KEY: str = "ctrl+alt+space"
|
||||
TYPE_WHILE_TRIGGERED: bool = True
|
||||
|
||||
CODE_START = "START"
|
||||
CODE_END = "END"
|
||||
|
25
src/intercom/client.py
Normal file
25
src/intercom/client.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
import socket
|
||||
|
||||
class TcpClient:
|
||||
def __init__(self, host, port):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
|
||||
def connect(self):
|
||||
self.socket.connect((self.host, self.port))
|
||||
|
||||
def send(self, data: bytes):
|
||||
self.socket.sendall(data)
|
||||
|
||||
def receive(self, buffer_size=1024):
|
||||
return self.socket.recv(buffer_size)
|
||||
|
||||
def send_text(self, data: str):
|
||||
self.send(data.encode("utf-8"))
|
||||
|
||||
def receive_text(self, buffer_size=1024):
|
||||
return self.receive(buffer_size).decode("utf-8")
|
||||
|
||||
def close(self):
|
||||
self.socket.close()
|
51
src/intercom/server.py
Normal file
51
src/intercom/server.py
Normal file
|
@ -0,0 +1,51 @@
|
|||
import logging
|
||||
import socket
|
||||
|
||||
|
||||
class TcpServer:
|
||||
def __init__(self, host, port, receive_callback=None):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.receive_callback = receive_callback
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.client = None
|
||||
|
||||
def start(self):
|
||||
self.socket.bind((self.host, self.port))
|
||||
self.socket.listen(1)
|
||||
print(f"Server listening on {self.host}:{self.port}...")
|
||||
|
||||
while True:
|
||||
client_socket, client_address = self.socket.accept()
|
||||
print(f"Client connected from {client_address[0]}:{client_address[1]}")
|
||||
self.handle_client(client_socket)
|
||||
|
||||
def handle_client(self, client_socket):
|
||||
self.client = client_socket
|
||||
try:
|
||||
while True:
|
||||
data = client_socket.recv(1024)
|
||||
if not data:
|
||||
break
|
||||
if self.receive_callback:
|
||||
self.receive_callback(data)
|
||||
|
||||
client_socket.close()
|
||||
|
||||
except ConnectionResetError:
|
||||
pass
|
||||
finally:
|
||||
logging.debug("Client disconnected")
|
||||
client_socket = None
|
||||
|
||||
def send(self, data: bytes):
|
||||
if not self.client:
|
||||
print("No client connected")
|
||||
return
|
||||
self.client.sendall(data)
|
||||
|
||||
def send_text(self, data: str):
|
||||
self.send(data.encode("utf-8"))
|
||||
|
||||
def has_connection(self) -> bool:
|
||||
return self.client is not None
|
83
src/main.py
83
src/main.py
|
@ -1,40 +1,75 @@
|
|||
import time
|
||||
from config import CODE_END, CODE_START, HOST, PORT, TRIGGER_KEY, TYPE_WHILE_TRIGGERED
|
||||
from intercom.client import TcpClient
|
||||
from typer import insert_text_at_cursor, callback_for_keycombination, wait_for_callbacks
|
||||
from speech import WhisperSTT
|
||||
from mic import AudioRecorder
|
||||
import logging
|
||||
|
||||
LISTENING_KEYCOMBINATION: str = "CTRL"
|
||||
RELEASE_BTN: str = "CTRL"
|
||||
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
|
||||
is_triggered = False
|
||||
phrases = []
|
||||
|
||||
def start_listening():
|
||||
if recorder.is_recording():
|
||||
def on_trigger():
|
||||
global is_triggered
|
||||
if is_triggered:
|
||||
is_triggered = False
|
||||
|
||||
logging.debug("Trigger released")
|
||||
client.send_text(CODE_END)
|
||||
|
||||
global phrases
|
||||
if not TYPE_WHILE_TRIGGERED:
|
||||
insert_text_at_cursor(" ".join(phrases))
|
||||
phrases = []
|
||||
|
||||
else:
|
||||
is_triggered = True
|
||||
|
||||
logging.debug("Trigger pressed")
|
||||
client.send_text(CODE_START)
|
||||
|
||||
def on_phrase_receive(phrase: str):
|
||||
logging.debug("Phrase received: " + phrase)
|
||||
if not is_triggered:
|
||||
return
|
||||
|
||||
logging.info(f'Listening... press [{RELEASE_BTN}] to stop.')
|
||||
recorder.start()
|
||||
|
||||
def stop_listening():
|
||||
if not recorder.is_recording():
|
||||
return
|
||||
phrase = phrase.strip()
|
||||
phrases.append(phrase)
|
||||
|
||||
audio = recorder.stop()
|
||||
recorder.save("audio.wav")
|
||||
logging.info("Stopped listening.")
|
||||
|
||||
text = whisper.transcribe(audio)
|
||||
insert_text_at_cursor(text)
|
||||
if TYPE_WHILE_TRIGGERED:
|
||||
text = (" " if len(phrases) > 1 else "") + phrase
|
||||
insert_text_at_cursor(text)
|
||||
|
||||
|
||||
recorder: AudioRecorder = AudioRecorder()
|
||||
whisper: WhisperSTT = WhisperSTT()
|
||||
|
||||
# Start the microphone server as normal user
|
||||
import subprocess
|
||||
|
||||
# Define the command to start the new script
|
||||
cmd = ["python3", "recognition_server.py"]
|
||||
|
||||
# Use the os module to run the command as a non-root user
|
||||
subprocess.Popen(["sudo", "-u", "max"] + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
|
||||
logging.info("Waiting for server to start...")
|
||||
time.sleep(3)
|
||||
|
||||
|
||||
callback_for_keycombination(LISTENING_KEYCOMBINATION, start_listening, on_release=False)
|
||||
callback_for_keycombination(RELEASE_BTN, stop_listening, on_release=True)
|
||||
|
||||
# Run Keyboard client
|
||||
client = TcpClient(HOST, PORT)
|
||||
client.connect()
|
||||
|
||||
logging.info(f'Ready, wait for [{LISTENING_KEYCOMBINATION}]')
|
||||
wait_for_callbacks()
|
||||
callback_for_keycombination(TRIGGER_KEY, on_trigger)
|
||||
logging.info(f"Waiting for trigger key [{TRIGGER_KEY.upper()}] to toggle dictation")
|
||||
|
||||
try:
|
||||
while True:
|
||||
phrase = client.receive_text()
|
||||
on_phrase_receive(phrase)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logging.info("Closing client...")
|
||||
finally:
|
||||
client.close()
|
||||
|
|
58
src/mic.py
58
src/mic.py
|
@ -1,58 +0,0 @@
|
|||
import numpy as np
|
||||
import pyaudio
|
||||
import wave
|
||||
import threading
|
||||
|
||||
|
||||
class AudioRecorder:
|
||||
def __init__(self, chunk=1024, format=pyaudio.paInt16, channels=1, rate=48000):
|
||||
self.chunk = chunk
|
||||
self.format = format
|
||||
self.channels = channels
|
||||
self.rate = rate
|
||||
self.frames = []
|
||||
self.recording = False
|
||||
self.thread = None
|
||||
|
||||
def start(self):
|
||||
self.recording = True
|
||||
self.thread = threading.Thread(target=self._record)
|
||||
self.thread.start()
|
||||
|
||||
def stop(self):
|
||||
self.recording = False
|
||||
self.thread.join()
|
||||
|
||||
return self.frames
|
||||
|
||||
def is_recording(self) -> bool:
|
||||
return self.recording
|
||||
|
||||
def get_last_audio(self) -> np.ndarray:
|
||||
return self.frames
|
||||
|
||||
def _record(self):
|
||||
audio = pyaudio.PyAudio()
|
||||
stream = audio.open(
|
||||
format=self.format,
|
||||
channels=self.channels,
|
||||
rate=self.rate,
|
||||
input=True,
|
||||
frames_per_buffer=self.chunk,
|
||||
)
|
||||
self.frames = []
|
||||
while self.recording:
|
||||
data = stream.read(self.chunk)
|
||||
self.frames.append(np.frombuffer(data, dtype=np.float32))
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
audio.terminate()
|
||||
self.frames = np.concatenate(self.frames, axis=0)
|
||||
|
||||
def save(self, filename: str):
|
||||
waveFile = wave.open(filename, "wb")
|
||||
waveFile.setnchannels(self.channels)
|
||||
waveFile.setsampwidth(pyaudio.get_sample_size(self.format))
|
||||
waveFile.setframerate(self.rate)
|
||||
waveFile.writeframes(b"".join(self.frames))
|
||||
waveFile.close()
|
41
src/recognition_server.py
Normal file
41
src/recognition_server.py
Normal file
|
@ -0,0 +1,41 @@
|
|||
import logging
|
||||
from config import CODE_END, CODE_START, HOST, PORT
|
||||
from intercom.server import TcpServer
|
||||
from speech import AudioProcessor
|
||||
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
|
||||
def phrase_callback(phrase: str):
|
||||
logging.debug("Recognised phrase: " + phrase)
|
||||
|
||||
if not server.has_connection():
|
||||
logging.debug("No client connected to send phrase to, stopping recognition")
|
||||
global audio_processor
|
||||
audio_processor.stop_recording()
|
||||
audio_processor = None
|
||||
return
|
||||
|
||||
logging.debug("Sending phrase to client...")
|
||||
server.send(phrase.encode("utf-8"))
|
||||
|
||||
def receive_callback(data: bytes):
|
||||
msg = data.decode("utf-8")
|
||||
logging.debug("Received: " + msg)
|
||||
|
||||
global audio_processor
|
||||
if msg == CODE_START:
|
||||
audio_processor = AudioProcessor(language="german", model="base", phrase_callback=phrase_callback)
|
||||
audio_processor.start_recording()
|
||||
elif msg == CODE_END:
|
||||
audio_processor.stop_recording()
|
||||
audio_processor = None
|
||||
|
||||
logging.info("Starting server...")
|
||||
|
||||
audio_processor = None
|
||||
|
||||
|
||||
server = TcpServer(host=HOST, port=PORT, receive_callback=receive_callback)
|
||||
server.start()
|
||||
|
||||
logging.info("Closing server...")
|
|
@ -1,33 +1,63 @@
|
|||
from typing import Any
|
||||
from numpy import floating, ndarray
|
||||
import whisper
|
||||
import speech_recognition as sr
|
||||
import logging
|
||||
|
||||
|
||||
class WhisperSTT:
|
||||
def __init__(self, *, model: str = "base") -> None:
|
||||
self.model = whisper.load_model(model)
|
||||
class AudioProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
language: str = "german",
|
||||
model: str = "base",
|
||||
phrase_callback: callable = None,
|
||||
) -> None:
|
||||
self.language: str = language
|
||||
self.model: str = model
|
||||
self.last_text: str = None
|
||||
self.phrase_callback: callable = phrase_callback
|
||||
self.is_listening: bool = False
|
||||
self.stop_listening_caller = None
|
||||
|
||||
self.recognizer: sr.Recognizer = sr.Recognizer()
|
||||
self.listener: sr.Microphone = sr.Microphone()
|
||||
with self.listener as source:
|
||||
self.recognizer.adjust_for_ambient_noise(
|
||||
source
|
||||
) # we only need to calibrate once, before we start listening
|
||||
|
||||
def get_last_text(self) -> str:
|
||||
return self.last_text
|
||||
|
||||
def transcribe(self, audio: ndarray[floating[Any]]) -> str:
|
||||
# load audio and pad/trim it to fit 30 seconds
|
||||
# audio = whisper.load_audio("audio.mp3")
|
||||
audio = whisper.pad_or_trim(audio)
|
||||
def is_listening(self) -> bool:
|
||||
return self.is_listening
|
||||
|
||||
# make log-Mel spectrogram and move to the same device as the model
|
||||
mel = whisper.log_mel_spectrogram(audio).to(self.model.device)
|
||||
def start_recording(self) -> None:
|
||||
if self.is_listening:
|
||||
logging.warning("Listener is already open")
|
||||
return
|
||||
|
||||
# detect the spoken language
|
||||
_, probs = self.model.detect_language(mel)
|
||||
max(probs, key=probs.get)
|
||||
logging.debug(f"Detected language: {max(probs, key=probs.get)}")
|
||||
self.last_text = ""
|
||||
self.is_listening = True
|
||||
|
||||
# decode the audio
|
||||
options = whisper.DecodingOptions()
|
||||
result = whisper.decode(self.model, mel, options)
|
||||
self.last_text = result.text
|
||||
self.stop_listening_caller = self.recognizer.listen_in_background(
|
||||
self.listener, self.listening_callback
|
||||
)
|
||||
|
||||
return self.last_text
|
||||
def listening_callback(self, recognizer, audio):
|
||||
new_text = self.recognizer.recognize_whisper(
|
||||
audio, language=self.language, model=self.model
|
||||
)
|
||||
|
||||
if self.last_text is not None and self.last_text != "":
|
||||
self.last_text += " "
|
||||
self.last_text += new_text
|
||||
|
||||
if self.phrase_callback:
|
||||
self.phrase_callback(new_text)
|
||||
|
||||
def stop_recording(self) -> None:
|
||||
if not self.is_listening:
|
||||
logging.warning("Listener is already closed")
|
||||
return
|
||||
|
||||
self.stop_listening_caller(wait_for_stop=False)
|
||||
self.is_listening = False
|
||||
|
|
|
@ -1,70 +0,0 @@
|
|||
import keyboard
|
||||
from Xlib.display import Display
|
||||
from Xlib import X
|
||||
from Xlib.ext import record
|
||||
from Xlib.protocol import rq
|
||||
|
||||
disp = None
|
||||
key_triggers = []
|
||||
|
||||
def insert_text_at_cursor(text: str):
|
||||
if text is None or text == "":
|
||||
return
|
||||
|
||||
keyboard.write(text)
|
||||
|
||||
def callback_for_keycombination(keycode: int, callback: callable, *, on_release: bool = False):
|
||||
key_triggers.append({
|
||||
"keycode": keycode,
|
||||
"callback": callback,
|
||||
"on_release": on_release
|
||||
})
|
||||
|
||||
def handler(reply):
|
||||
""" This function is called when a xlib event is fired """
|
||||
data = reply.data
|
||||
while len(data):
|
||||
event, data = rq.EventField(None).parse_binary_value(data, disp.display, None, None)
|
||||
|
||||
# KEYCODE IS FOUND USERING event.detail
|
||||
# print(event.detail)
|
||||
|
||||
for trigger in key_triggers:
|
||||
if int(trigger["keycode"]) != event.detail:
|
||||
continue
|
||||
|
||||
if trigger["on_release"] and event.type == X.KeyRelease:
|
||||
trigger["callback"]()
|
||||
elif not trigger["on_release"] and event.type == X.KeyPress:
|
||||
trigger["callback"]()
|
||||
|
||||
|
||||
def wait_for_callbacks():
|
||||
# get current display
|
||||
global disp
|
||||
disp = Display()
|
||||
root = disp.screen().root
|
||||
|
||||
|
||||
# Monitor keypress and button press
|
||||
ctx = disp.record_create_context(
|
||||
0,
|
||||
[record.AllClients],
|
||||
[{
|
||||
'core_requests': (0, 0),
|
||||
'core_replies': (0, 0),
|
||||
'ext_requests': (0, 0, 0, 0),
|
||||
'ext_replies': (0, 0, 0, 0),
|
||||
'delivered_events': (0, 0),
|
||||
'device_events': (X.KeyReleaseMask, X.ButtonReleaseMask),
|
||||
'errors': (0, 0),
|
||||
'client_started': False,
|
||||
'client_died': False,
|
||||
}])
|
||||
disp.record_enable_context(ctx, handler)
|
||||
disp.record_free_context(ctx)
|
||||
|
||||
while True:
|
||||
root.display.next_event()
|
||||
|
||||
|
Loading…
Reference in a new issue