From 1edfd5d62ff00fa7aceb8e0447abac4255fdd771 Mon Sep 17 00:00:00 2001 From: noturum Date: Fri, 1 May 2026 13:01:06 +0000 Subject: [PATCH] Initial commit: audio-chat with fixes - Created AGENTS.md with architecture documentation - Fixed race conditions and async patterns - Added conversation history to LLM prompts - Fixed TTS audio shape handling - Added buffer limits and graceful shutdown - Fixed client.py with file sending support - Removed duplicate requirements - Added .gitignore --- .gitignore | 12 + AGENTS.md | 51 ++++ audio/__init__.py | 0 audio/stream.py | 50 ++++ client.py | 160 +++++++++++++ config.py | 29 +++ engine/__init__.py | 0 engine/llm.py | 61 +++++ engine/stt.py | 49 ++++ engine/tts.py | 53 +++++ main.py | 225 ++++++++++++++++++ requirements.txt | 26 +++ static/index.html | 570 +++++++++++++++++++++++++++++++++++++++++++++ 13 files changed, 1286 insertions(+) create mode 100644 .gitignore create mode 100644 AGENTS.md create mode 100644 audio/__init__.py create mode 100644 audio/stream.py create mode 100644 client.py create mode 100644 config.py create mode 100644 engine/__init__.py create mode 100644 engine/llm.py create mode 100644 engine/stt.py create mode 100644 engine/tts.py create mode 100644 main.py create mode 100644 requirements.txt create mode 100644 static/index.html diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ee90a9c --- /dev/null +++ b/.gitignore @@ -0,0 +1,12 @@ +__pycache__/ +*.pyc +*.pyo +.env +.venv/ +venv/ +*.egg-info/ +dist/ +build/ +*.egg +response_*.wav +*.log diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..e4877f7 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,51 @@ +# Audio Chat + +## Running + +```bash +uvicorn main:app --host 0.0.0.0 --port 8000 +``` + +Client (for testing): +```bash +python client.py +``` + +## Architecture + +Single-process FastAPI server. On each WebSocket connection a new `AudioSession` is created with three engines: + +| Module | Purpose | Model (default) | +|--------|---------|-----------------| +| `engine/stt.py` | Speech-to-text | Systran/faster-whisper-large-v3 | +| `engine/llm.py` | LLM response generation | Qwen/Qwen2.5-7B-Instruct | +| `engine/tts.py` | Text-to-speech | facebook/mms-tts-rus | + +Models are loaded lazily on first use if `initialize()` was not called. STT always runs in Russian (`language="ru"` with VAD). + +## WebSocket Protocol + +| Direction | Format | Meaning | +|-----------|--------|---------| +| Client → Server | `b"A" + PCM data` | Send audio chunk | +| Client → Server | `b"R"` | Reset conversation | +| Server → Client | `b"O" + WAV bytes` | LLM response as audio | +| Server → Client | `"TEXT:"` | Recognized speech | + +Audio format: 16-bit PCM mono, 16 kHz input / 24 kHz output. + +## Configuration + +All settings via `.env` (loaded by `config.py`). Key vars: + +- `DEVICE` — `"cuda"` or `"cpu"` (default `"auto"`) +- `AUDIO_BUFFER_SECONDS` / `CHUNK_SIZE` — silence detection thresholds +- `LLM_MAX_TOKENS` / `LLM_TEMPERATURE` — generation parameters + +## Gotchas + +- No test suite or linting configured. +- Models download on first use; ensure network access to HuggingFace. +- `AudioSession` holds conversation history (last 6 turns) in memory — each WebSocket reconnect resets it. +- Thread pool executor is fixed at 2 workers; concurrent heavy requests will queue. +- TTS pipeline falls back to CPU (`device=-1`) if GPU initialization fails silently. diff --git a/audio/__init__.py b/audio/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/audio/stream.py b/audio/stream.py new file mode 100644 index 0000000..b8ac082 --- /dev/null +++ b/audio/stream.py @@ -0,0 +1,50 @@ +import threading +import time +from config import Config + + +class AudioStreamBuffer: + def __init__(self): + self.config = Config() + self.buffer = b"" + self.lock = threading.Lock() + self.event = threading.Event() + self.running = False + # Limit buffer to 10 seconds of audio to prevent OOM + self.max_buffer_bytes = int(self.config.SAMPLE_RATE * 10 * 2) + + def start(self): + self.running = True + self.buffer = b"" + self.event.clear() + + def add_chunk(self, chunk: bytes): + with self.lock: + self.buffer += chunk + # Evict oldest data if buffer exceeds limit + if len(self.buffer) > self.max_buffer_bytes: + self.buffer = self.buffer[-self.max_buffer_bytes // 2:] + if len(self.buffer) >= self.config.CHUNK_SIZE: + self.event.set() + + def get_ready_chunk(self, timeout: float = 1.0) -> bytes: + if self.event.wait(timeout=timeout): + with self.lock: + chunk = self.buffer + self.buffer = b"" + self.event.clear() + return chunk + return b"" + + def get_full_buffer(self) -> bytes: + with self.lock: + chunk = self.buffer + self.buffer = b"" + return chunk + + def stop(self): + self.running = False + self.event.set() + + def is_running(self): + return self.running diff --git a/client.py b/client.py new file mode 100644 index 0000000..d140fae --- /dev/null +++ b/client.py @@ -0,0 +1,160 @@ +import asyncio +import websockets +import struct +import wave +import numpy as np + +# WebSocket URL +WS_URL = "ws://localhost:8000/ws" + + +async def start_recording(): + """Send start signal (b'S')""" + async with websockets.connect(WS_URL) as ws: + await ws.send(b"S") + + +async def send_audio(ws, audio_data: bytes): + """Send audio data (b'A' + raw PCM)""" + await ws.send(b"A" + audio_data) + + +async def reset_session(ws): + """Reset conversation (b'R')""" + await ws.send(b"R") + + +async def receive_messages(ws): + """Receive TEXT and AUDIO messages""" + while True: + try: + msg = await asyncio.wait_for(ws.recv(), timeout=30.0) + if isinstance(msg, str): + if msg.startswith("TEXT:"): + print(f"[RECognized] {msg[5:]}") + else: + print(f"[Server] {msg}") + elif isinstance(msg, bytes): + if msg[0:1] == b"O": + audio = msg[1:] + print(f"[Audio] Received {len(audio)} bytes") + # Save to file + timestamp = int(asyncio.get_running_loop().time()) + filename = f"response_{timestamp}.wav" + with open(filename, "wb") as f: + with wave.open(f, "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(24000) + wf.writeframes(audio) + print(f"[Audio] Saved to {filename}") + except asyncio.TimeoutError: + break + except Exception as e: + print(f"Error: {e}") + break + + +async def record_and_send(): + """Record audio from microphone and send""" + import pyaudio + + CHUNK = 1024 + FORMAT = pyaudio.paInt16 + CHANNELS = 1 + RATE = 16000 + + p = pyaudio.PyAudio() + stream = p.open(format=FORMAT, channels=CHANNELS, rate=RATE, input=True, frames_per_buffer=CHUNK) + + async with websockets.connect(WS_URL) as ws: + print("Recording... Press Ctrl+C to stop") + try: + while True: + data = stream.read(CHUNK) + await send_audio(ws, data) + except KeyboardInterrupt: + print("\nStopped recording") + finally: + stream.stop_stream() + stream.close() + p.terminate() + + +async def send_audio_file(filepath: str): + """Read and send an audio file to the server.""" + try: + with open(filepath, "rb") as f: + file_data = f.read() + except FileNotFoundError: + print(f"Error: File '{filepath}' not found") + return + + print(f"Reading audio file: {filepath} ({len(file_data)} bytes)") + + async with websockets.connect(WS_URL) as ws: + print("Connected. Sending audio file...") + await ws.send(b"A" + file_data) + print("File sent. Waiting for response...") + + try: + while True: + msg = await asyncio.wait_for(ws.recv(), timeout=60.0) + if isinstance(msg, str): + if msg.startswith("TEXT:"): + print(f"[Recognized] {msg[5:]}") + else: + print(f"[Server] {msg}") + elif isinstance(msg, bytes): + if msg[0:1] == b"O": + audio = msg[1:] + timestamp = int(asyncio.get_running_loop().time()) + filename = f"response_{timestamp}.wav" + with open(filename, "wb") as f: + with wave.open(f, "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(24000) + wf.writeframes(audio) + print(f"[Audio] Saved response to {filename}") + except asyncio.TimeoutError: + print("Timed out waiting for response") + except Exception as e: + print(f"Error: {e}") + + +async def client(): + """Main client loop""" + print("Audio Chat Client") + print("1. Record from microphone") + print("2. Send audio file") + choice = input("Choice (1/2): ") + + if choice == "1": + p = pyaudio.PyAudio() + stream = p.open(format=pyaudio.paInt16, channels=1, rate=16000, input=True, frames_per_buffer=1024) + async with websockets.connect(WS_URL) as ws: + print("Recording... Press Ctrl+C to stop") + try: + receive_task = asyncio.create_task(receive_messages(ws)) + while True: + data = stream.read(1024) + await ws.send(b"A" + data) + except KeyboardInterrupt: + receive_task.cancel() + finally: + stream.stop_stream() + stream.close() + p.terminate() + elif choice == "2": + filepath = input("Enter audio file path: ").strip() + if filepath: + await send_audio_file(filepath) + else: + print("No file path provided") + else: + print("Invalid choice") + + +if __name__ == "__main__": + asyncio.run(client()) diff --git a/config.py b/config.py new file mode 100644 index 0000000..415d211 --- /dev/null +++ b/config.py @@ -0,0 +1,29 @@ +import os +from pathlib import Path +from dotenv import load_dotenv + +env_path = Path(__file__).parent / ".env" +load_dotenv(env_path) + + +class Config: + # Models + STT_MODEL = os.getenv("STT_MODEL", "Systran/faster-whisper-large-v3") + LLM_MODEL = os.getenv("LLM_MODEL", "Qwen/Qwen2.5-7B-Instruct") + TTS_MODEL = os.getenv("TTS_MODEL", "facebook/mms-tts-rus") + + # Audio settings + SAMPLE_RATE = int(os.getenv("SAMPLE_RATE", "16000")) + AUDIO_BUFFER_SECONDS = float(os.getenv("AUDIO_BUFFER_SECONDS", "2")) + CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "1024")) + + # Server + HOST = os.getenv("HOST", "0.0.0.0") + PORT = int(os.getenv("PORT", "8000")) + + # LLM settings + LLM_MAX_TOKENS = int(os.getenv("LLM_MAX_TOKENS", "512")) + LLM_TEMPERATURE = float(os.getenv("LLM_TEMPERATURE", "0.7")) + + # GPU + DEVICE = os.getenv("DEVICE", "auto") diff --git a/engine/__init__.py b/engine/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/engine/llm.py b/engine/llm.py new file mode 100644 index 0000000..bae17c1 --- /dev/null +++ b/engine/llm.py @@ -0,0 +1,61 @@ +from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig +from config import Config +import torch + + +class LLMEngine: + def __init__(self): + self.model = None + self.tokenizer = None + self.config = Config() + + def initialize(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = torch.float16 if device == "cuda" else torch.float32 + + self.tokenizer = AutoTokenizer.from_pretrained( + self.config.LLM_MODEL, + trust_remote_code=True, + ) + + self.model = AutoModelForCausalLM.from_pretrained( + self.config.LLM_MODEL, + torch_dtype=dtype, + device_map="auto", + trust_remote_code=True, + ) + + def generate(self, user_text: str, system_prompt: str = None) -> str: + if not self.model: + self.initialize() + + if system_prompt is None: + system_prompt = "Ты полезный ассистент. Отвечай на русском языке кратко и по делу." + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_text}, + ] + + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + + inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device) + + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + max_new_tokens=self.config.LLM_MAX_TOKENS, + temperature=self.config.LLM_TEMPERATURE, + do_sample=True, + top_p=0.9, + repetition_penalty=1.1, + ) + + generated = outputs[0][inputs["input_ids"].shape[1]:] + response = self.tokenizer.decode(generated, skip_special_tokens=True) + + return response.strip() diff --git a/engine/stt.py b/engine/stt.py new file mode 100644 index 0000000..9f06d35 --- /dev/null +++ b/engine/stt.py @@ -0,0 +1,49 @@ +from faster_whisper import WhisperModel +from config import Config +import io +import numpy as np + + +class STTEngine: + def __init__(self): + self.model = None + self.config = Config() + self._model_size = self._resolve_model_size(self.config.STT_MODEL) + + def _resolve_model_size(self, model_name: str) -> str: + """Extract model size from various naming conventions.""" + # Handle Systran/faster-whisper-* format + if "faster-whisper-" in model_name: + return model_name.split("faster-whisper-")[-1] + # Handle whisper-* format + if model_name.startswith("whisper-"): + return model_name[len("whisper-"):] + # Return as-is for direct model names + return model_name + + def initialize(self): + device = "cuda" if self.config.DEVICE == "auto" else self.config.DEVICE + self.model = WhisperModel( + self._model_size, + device=device, + compute_type="float16" if device == "cuda" else "int8", + download_root=None, + ) + + def transcribe(self, audio_bytes: bytes) -> str: + if not self.model: + self.initialize() + + audio_file = io.BytesIO(audio_bytes) + segments, info = self.model.transcribe( + audio_file, + beam_size=5, + language="ru", + vad_filter=True, + ) + + text = "" + for segment in segments: + text += segment.text + " " + + return text.strip() diff --git a/engine/tts.py b/engine/tts.py new file mode 100644 index 0000000..a8019dd --- /dev/null +++ b/engine/tts.py @@ -0,0 +1,53 @@ +from transformers import pipeline +from config import Config +import numpy as np + + +class TTSEngine: + def __init__(self): + self.tts_pipeline = None + self.config = Config() + + def initialize(self): + try: + self.tts_pipeline = pipeline( + "text-to-speech", + self.config.TTS_MODEL, + device=0 if __import__("torch").cuda.is_available() else -1, + ) + except Exception: + self.tts_pipeline = pipeline( + "text-to-speech", + model=self._tts_model, + device=-1, + ) + self.tts_pipeline.start() + + def synthesize(self, text: str, output_sample_rate: int = 24000) -> np.ndarray: + if not self.tts_pipeline: + self.initialize() + + result = self.tts_pipeline( + text, + generate_kwargs={"task": "tts", "language": "ru"}, + return_tensors=True, + ) + + audio = result["audio"] + # Convert torch tensor to numpy if needed + if hasattr(audio, 'numpy'): + audio = audio.numpy() + elif not isinstance(audio, np.ndarray): + audio = np.asarray(audio) + + # Handle multi-dimensional arrays (batch or stereo) + if audio.ndim > 2: + # Batch dimension - take first item + audio = audio[0] + if audio.ndim == 2: + # Stereo - mix to mono + audio = audio.mean(axis=1) + + audio = audio.astype(np.float32) + + return audio diff --git a/main.py b/main.py new file mode 100644 index 0000000..14c3937 --- /dev/null +++ b/main.py @@ -0,0 +1,225 @@ +import asyncio +import struct +import wave +import io +import numpy as np +import logging +from pathlib import Path +from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request +from fastapi.staticfiles import StaticFiles +from fastapi.responses import FileResponse +from pydantic import BaseModel +from typing import Optional +from contextlib import asynccontextmanager + +from config import Config +from engine.stt import STTEngine +from engine.llm import LLMEngine +from engine.tts import TTSEngine +from audio.stream import AudioStreamBuffer + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +config = Config() + + +class Message(BaseModel): + type: str + text: Optional[str] = None + + +class AudioSession: + def __init__(self): + self.stt = STTEngine() + self.llm = LLMEngine() + self.tts = TTSEngine() + self.audio_buffer = AudioStreamBuffer() + self.conversation_history = [] + self.is_processing = False + self.processing_lock = asyncio.Lock() + + def initialize_engines(self): + logger.info("Loading STT model...") + self.stt.initialize() + logger.info("Loading LLM model...") + self.llm.initialize() + logger.info("Loading TTS model...") + self.tts.initialize() + logger.info("All models loaded!") + + def add_audio_chunk(self, chunk: bytes): + self.audio_buffer.add_chunk(chunk) + + def recognize_speech(self) -> str: + ready_chunk = self.audio_buffer.get_ready_chunk(timeout=2.0) + if not ready_chunk: + return "" + + loop = asyncio.get_running_loop() + text = loop.run_until_complete( + self._transcribe_async(ready_chunk) + ) + return text + + async def _transcribe_async(self, audio_bytes: bytes) -> str: + """Async wrapper for STT transcription.""" + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, self.stt.transcribe, audio_bytes) + + def get_llm_response(self, user_text: str) -> str: + system_prompt = "Ты голосовой ассистент. Отвечай кратко, как в разговорной речи. Не используй списки или форматирование." + + if self.conversation_history: + context = "\n".join(self.conversation_history[-6:]) + full_prompt = f"Предыдущий диалог:\n{context}\n\nПользователь: {user_text}\nТы:" + else: + full_prompt = f"Пользователь: {user_text}\nТы:" + + loop = asyncio.get_running_loop() + response = loop.run_until_complete( + self._generate_async(full_prompt, system_prompt) + ) + loop.close() + + self.conversation_history.append(f"Пользователь: {user_text}") + self.conversation_history.append(f"Ассистент: {response}") + + return response + + async def _generate_async(self, prompt: str, system_prompt: str) -> str: + """Async wrapper for LLM generation.""" + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, self.llm.generate, prompt, system_prompt) + + def synthesize_speech(self, text: str) -> bytes: + loop = asyncio.get_running_loop() + audio = loop.run_until_complete( + self._synthesize_async(text, 24000) + ) + loop.close() + + wav_buffer = io.BytesIO() + with wave.open(wav_buffer, "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(24000) + wf.writeframes(audio.tobytes()) + + return wav_buffer.getvalue() + + async def _synthesize_async(self, text: str, sample_rate: int) -> np.ndarray: + """Async wrapper for TTS synthesis.""" + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, self.tts.synthesize, text, sample_rate) + + def process_conversation(self, user_text: str): + logger.info(f"User said: {user_text}") + + self.audio_buffer.start() + + try: + response = self.get_llm_response(user_text) + logger.info(f"LLM response: {response}") + + audio = self.synthesize_speech(response) + logger.info(f"Generated {len(audio)} bytes of audio") + + return audio + except Exception as e: + logger.error(f"Error processing conversation: {e}") + return None + + def reset(self): + self.conversation_history = [] + self.audio_buffer = AudioStreamBuffer() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + logger.info("Starting audio chat server...") + yield + logger.info("Shutting down...") + + +app = FastAPI(title="Audio Chat Server", version="1.0.0", lifespan=lifespan) + +STATIC_DIR = Path(__file__).parent / "static" + + +@app.get("/") +async def serve_index(): + return FileResponse(STATIC_DIR / "index.html") + + +app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static") + + +async def process_audio_chunk(session: AudioSession, payload: bytes, websocket): + """Handle an audio chunk with proper async handling.""" + session.add_audio_chunk(payload) + + ready = session.audio_buffer.get_ready_chunk(timeout=0.1) + if ready: + try: + text = session.stt.transcribe(ready) + if text: + await websocket.send_text(f"TEXT:{text}") + + async with session.processing_lock: + if not session.is_processing: + session.is_processing = True + try: + audio = await asyncio.to_thread( + session.process_conversation, text + ) + if audio: + try: + await websocket.send_bytes(b"O" + audio) + except Exception as e: + logger.error(f"Failed to send audio: {e}") + finally: + session.is_processing = False + except Exception as e: + logger.error(f"Error processing audio chunk: {e}") + + +@app.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + logger.info("Client connected") + + session = AudioSession() + session.initialize_engines() + + try: + while True: + data = await asyncio.wait_for(websocket.receive_bytes(), timeout=120.0) + + msg_type = data[:1] + payload = data[1:] + + if msg_type == b"A": + await process_audio_chunk(session, payload, websocket) + elif msg_type == b"R": + session.reset() + await websocket.send_text("RESET") + + except WebSocketDisconnect: + logger.info("Client disconnected") + except asyncio.TimeoutError: + logger.info("Client timed out") + except Exception as e: + logger.error(f"WebSocket error: {e}") + finally: + await websocket.close() + + +@app.get("/health") +async def health_check(): + return {"status": "ok"} + + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host=config.HOST, port=config.PORT) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..751fcd3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,26 @@ +# WebSocket server +fastapi==0.115.0 +uvicorn[standard]==0.30.6 +websockets==13.1 + +# Speech-to-Text +faster-whisper==1.0.3 +soundfile==0.12.1 + +# LLM +transformers==4.44.0 +torch==2.4.1 +accelerate==1.0.0 +bitsandbytes==0.44.0 + +# TTS +torchaudio>=2.4.0 + +# Audio processing +numpy==2.1.1 +scipy==1.14.1 + +# Utilities +python-dotenv==1.0.1 +pydantic==2.9.2 +pydantic-settings==2.5.2 diff --git a/static/index.html b/static/index.html new file mode 100644 index 0000000..626cbb5 --- /dev/null +++ b/static/index.html @@ -0,0 +1,570 @@ + + + + + + Голосовой чат + + + +
+

🎙️ Голосовой чат с ИИ

+
+
+ Отключено +
+
+ +
+
+ + +
+ +
+
+
Ассистент
+
Привет! Нажми "Начать запись" и говори со мной.
+
+
+ +
+
+
Генерирую ответ...
+
+ +
+ +
+ +
+ + + +
+
+
+ + + + + +