# Recent changes:
# - Created AGENTS.md with architecture documentation
# - Fixed race conditions and async patterns
# - Added conversation history to LLM prompts
# - Fixed TTS audio shape handling
# - Added buffer limits and graceful shutdown
# - Fixed client.py with file sending support
# - Removed duplicate requirements
# - Added .gitignore
# File listing: 226 lines, 7.1 KiB, Python
import asyncio
import io
import logging
import struct
import wave
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional

import numpy as np
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.websockets import WebSocketState
from pydantic import BaseModel

from audio.stream import AudioStreamBuffer
from config import Config
from engine.llm import LLMEngine
from engine.stt import STTEngine
from engine.tts import TTSEngine

# Configure root logging at INFO level, then grab this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Application settings object; HOST and PORT are read by the __main__ entry point.
config = Config()

class Message(BaseModel):
    """Client message model: a ``type`` tag with an optional ``text`` payload."""

    # Message kind tag (field name kept as-is; it is part of the model's schema).
    type: str
    # Optional textual payload; absent messages default to None.
    text: Optional[str] = None

class AudioSession:
    """Per-connection state for one voice-chat client: STT -> LLM -> TTS.

    Each session owns its own engine instances, an ``AudioStreamBuffer`` for
    incoming audio, and a rolling text history that is replayed into the LLM
    prompt.

    The synchronous pipeline methods (``recognize_speech``,
    ``get_llm_response``, ``synthesize_speech``, ``process_conversation``)
    call the engines directly and block the calling thread, so they should be
    run off the event loop (e.g. with ``asyncio.to_thread``). From inside a
    coroutine, the ``_transcribe_async`` / ``_generate_async`` /
    ``_synthesize_async`` wrappers do the same work in the default executor.
    """

    # Sample rate requested from the TTS engine and written into the WAV header.
    TTS_SAMPLE_RATE = 24000
    # Number of history entries (one per user turn, one per reply) put in the prompt.
    HISTORY_CONTEXT_ENTRIES = 6
    # Cap on stored history. Only the newest HISTORY_CONTEXT_ENTRIES are ever
    # used, so older entries are dropped to keep long sessions from growing
    # without bound.
    MAX_HISTORY_ENTRIES = 50
    # System prompt (Russian): "You are a voice assistant. Answer briefly, as in
    # conversational speech. Do not use lists or formatting."
    SYSTEM_PROMPT = (
        "Ты голосовой ассистент. Отвечай кратко, как в разговорной речи. "
        "Не используй списки или форматирование."
    )

    def __init__(self):
        self.stt = STTEngine()
        self.llm = LLMEngine()
        self.tts = TTSEngine()
        self.audio_buffer = AudioStreamBuffer()
        # Alternating "user: ..." / "assistant: ..." lines, oldest first.
        self.conversation_history = []
        # Guarded by processing_lock; set while a reply is being generated.
        self.is_processing = False
        self.processing_lock = asyncio.Lock()

    def initialize_engines(self):
        """Load the STT, LLM and TTS models (blocking; can be slow)."""
        logger.info("Loading STT model...")
        self.stt.initialize()
        logger.info("Loading LLM model...")
        self.llm.initialize()
        logger.info("Loading TTS model...")
        self.tts.initialize()
        logger.info("All models loaded!")

    def add_audio_chunk(self, chunk: bytes):
        """Append raw incoming audio bytes to the session buffer."""
        self.audio_buffer.add_chunk(chunk)

    def recognize_speech(self) -> str:
        """Transcribe the next ready utterance, or return "" if none is ready.

        Waits up to 2 seconds for the buffer to produce a chunk. Blocking.
        """
        ready_chunk = self.audio_buffer.get_ready_chunk(timeout=2.0)
        if not ready_chunk:
            return ""
        # Call the engine directly. A synchronous method cannot drive the
        # event loop: get_running_loop().run_until_complete() raises both in
        # worker threads (no running loop) and on the loop thread (already
        # running).
        return self.stt.transcribe(ready_chunk)

    async def _transcribe_async(self, audio_bytes: bytes) -> str:
        """Async wrapper for STT transcription."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.stt.transcribe, audio_bytes)

    def _build_prompt(self, user_text: str) -> str:
        """Render the recent history plus the new user turn as the LLM prompt."""
        recent = self.conversation_history[-self.HISTORY_CONTEXT_ENTRIES:]
        if recent:
            context = "\n".join(recent)
            # "Previous dialogue: ... User: <text> You:"
            return f"Предыдущий диалог:\n{context}\n\nПользователь: {user_text}\nТы:"
        # "User: <text> You:"
        return f"Пользователь: {user_text}\nТы:"

    def _remember(self, user_text: str, response: str) -> None:
        """Append one exchange to the history, dropping entries past the cap."""
        history = self.conversation_history
        history.append(f"Пользователь: {user_text}")
        history.append(f"Ассистент: {response}")
        overflow = len(history) - self.MAX_HISTORY_ENTRIES
        if overflow > 0:
            del history[:overflow]

    def get_llm_response(self, user_text: str) -> str:
        """Generate the assistant's reply to ``user_text`` and record the exchange.

        Blocking: calls the LLM engine directly on the current thread.
        The exchange is added to the history only if generation succeeds.
        """
        full_prompt = self._build_prompt(user_text)
        response = self.llm.generate(full_prompt, self.SYSTEM_PROMPT)
        self._remember(user_text, response)
        return response

    async def _generate_async(self, prompt: str, system_prompt: str) -> str:
        """Async wrapper for LLM generation."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.llm.generate, prompt, system_prompt)

    @staticmethod
    def _to_pcm16(audio) -> bytes:
        """Convert TTS output to little-endian 16-bit PCM bytes.

        Floating-point samples are assumed to be normalised to [-1.0, 1.0]
        (the usual convention for float audio) and are clipped and scaled to
        the int16 range; integer samples are cast to int16. Any array shape
        is flattened, since the WAV is written as mono.
        """
        samples = np.asarray(audio)
        if np.issubdtype(samples.dtype, np.floating):
            # Writing raw float bytes into a 16-bit WAV produces noise.
            samples = np.clip(samples, -1.0, 1.0) * 32767.0
        return samples.reshape(-1).astype("<i2").tobytes()

    def synthesize_speech(self, text: str) -> bytes:
        """Speak ``text`` with the TTS engine and return a mono 16-bit WAV file.

        Blocking: calls the TTS engine directly on the current thread.
        """
        audio = self.tts.synthesize(text, self.TTS_SAMPLE_RATE)
        pcm = self._to_pcm16(audio)

        wav_buffer = io.BytesIO()
        with wave.open(wav_buffer, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(self.TTS_SAMPLE_RATE)
            wf.writeframes(pcm)
        return wav_buffer.getvalue()

    async def _synthesize_async(self, text: str, sample_rate: int) -> np.ndarray:
        """Async wrapper for TTS synthesis."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.tts.synthesize, text, sample_rate)

    def process_conversation(self, user_text: str) -> Optional[bytes]:
        """Run LLM + TTS for one user utterance.

        Blocking; run it in a worker thread.

        Returns:
            The spoken reply as WAV bytes, or ``None`` if the LLM or TTS step
            raised (the error is logged with its traceback).
        """
        logger.info("User said: %s", user_text)

        # NOTE(review): the effect of start() depends on AudioStreamBuffer —
        # confirm it is meant to be called before every reply.
        self.audio_buffer.start()

        try:
            response = self.get_llm_response(user_text)
            logger.info("LLM response: %s", response)

            audio = self.synthesize_speech(response)
            logger.info("Generated %d bytes of audio", len(audio))
            return audio
        except Exception as e:
            logger.exception("Error processing conversation: %s", e)
            return None

    def reset(self):
        """Forget the conversation and discard any buffered audio."""
        self.conversation_history = []
        self.audio_buffer = AudioStreamBuffer()

@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    Logs one message at start-up and one at shutdown; it creates and
    releases no resources itself.
    """
    logger.info("Starting audio chat server...")
    yield  # the application serves requests while suspended here
    logger.info("Shutting down...")

app = FastAPI(
    title="Audio Chat Server",
    version="1.0.0",
    lifespan=lifespan,
)

# Front-end assets live in a "static" directory next to this file.
STATIC_DIR = Path(__file__).parent.joinpath("static")

@app.get("/")
async def serve_index():
    """Serve the front-end entry page, static/index.html."""
    index_page = STATIC_DIR / "index.html"
    return FileResponse(index_page)

# Expose every file under STATIC_DIR at the /static URL prefix.
app.mount(
    "/static",
    StaticFiles(directory=str(STATIC_DIR)),
    name="static",
)

async def process_audio_chunk(session: "AudioSession", payload: bytes, websocket):
    """Buffer one audio chunk and, once an utterance is ready, answer it.

    Steps:
      1. Append ``payload`` to the session's audio buffer.
      2. Ask the buffer for a ready utterance (timeout 0.1 s).
      3. Transcribe it in a worker thread and send ``TEXT:<text>`` to the client.
      4. Run LLM + TTS (``session.process_conversation``) in a worker thread
         and send the resulting audio as a binary frame prefixed with ``b"O"``.

    Whitespace-only transcripts are ignored. Errors are logged, not raised,
    so a single bad chunk does not end the connection.
    """
    session.add_audio_chunk(payload)

    # NOTE(review): get_ready_chunk takes a timeout, so it may block the event
    # loop for up to 0.1 s — confirm against AudioStreamBuffer.
    ready = session.audio_buffer.get_ready_chunk(timeout=0.1)
    if not ready:
        return

    try:
        # STT is blocking model inference; running it on the event loop would
        # stall every other connection.
        text = await asyncio.to_thread(session.stt.transcribe, ready)
        text = (text or "").strip()
        if not text:
            return

        await websocket.send_text(f"TEXT:{text}")

        async with session.processing_lock:
            if session.is_processing:
                # A reply is already being produced; drop this utterance.
                return
            session.is_processing = True
            try:
                audio = await asyncio.to_thread(session.process_conversation, text)
                if audio:
                    try:
                        await websocket.send_bytes(b"O" + audio)
                    except Exception as e:
                        logger.exception("Failed to send audio: %s", e)
            finally:
                session.is_processing = False
    except Exception as e:
        logger.exception("Error processing audio chunk: %s", e)

async def _close_quietly(websocket: WebSocket) -> None:
    """Close ``websocket`` unless either side has already closed it."""
    if (
        websocket.client_state == WebSocketState.DISCONNECTED
        or websocket.application_state == WebSocketState.DISCONNECTED
    ):
        return
    try:
        await websocket.close()
    except RuntimeError as e:
        # Starlette raises RuntimeError when sending on a socket that closed
        # concurrently; nothing is left to clean up.
        logger.debug("WebSocket already closed: %s", e)


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one voice-chat client over a WebSocket.

    Binary client frames carry a one-byte type tag:
      ``b"A" + audio`` -- audio chunk, handled by ``process_audio_chunk``;
      ``b"R"``         -- reset the conversation; the server replies "RESET".
    Frames with any other tag are logged and ignored. The connection is
    closed after 120 s without an incoming frame.
    """
    await websocket.accept()
    logger.info("Client connected")

    try:
        session = AudioSession()
        # Model loading is slow and blocking; do it in a worker thread so the
        # event loop (and every other connection) keeps running.
        await asyncio.to_thread(session.initialize_engines)

        while True:
            data = await asyncio.wait_for(websocket.receive_bytes(), timeout=120.0)

            msg_type, payload = data[:1], data[1:]

            if msg_type == b"A":
                await process_audio_chunk(session, payload, websocket)
            elif msg_type == b"R":
                session.reset()
                await websocket.send_text("RESET")
            else:
                logger.warning("Ignoring message with unknown type %r", msg_type)

    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except asyncio.TimeoutError:
        logger.info("Client timed out")
    except Exception as e:
        logger.exception("WebSocket error: %s", e)
    finally:
        # A plain close() after a client-initiated disconnect raises
        # RuntimeError, so only close a socket that is still open.
        await _close_quietly(websocket)

@app.get("/health")
async def health_check():
    """Liveness probe; always answers ``{"status": "ok"}``."""
    status_payload = {"status": "ok"}
    return status_payload

if __name__ == "__main__":
    # Run the server directly with uvicorn, bound to the configured address.
    import uvicorn

    server_options = {"host": config.HOST, "port": config.PORT}
    uvicorn.run(app, **server_options)