Files
audio-chat/main.py
noturum 1edfd5d62f Initial commit: audio-chat with fixes
- Created AGENTS.md with architecture documentation
- Fixed race conditions and async patterns
- Added conversation history to LLM prompts
- Fixed TTS audio shape handling
- Added buffer limits and graceful shutdown
- Fixed client.py with file sending support
- Removed duplicate requirements
- Added .gitignore
2026-05-01 13:01:06 +00:00

226 lines
7.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import struct
import wave
import io
import numpy as np
import logging
from pathlib import Path
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Optional
from contextlib import asynccontextmanager
from config import Config
from engine.stt import STTEngine
from engine.llm import LLMEngine
from engine.tts import TTSEngine
from audio.stream import AudioStreamBuffer
# Configure the root logger at import time so every module logs at INFO.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Project configuration object; HOST and PORT are read from it in the
# __main__ guard at the bottom of this file.
config = Config()
class Message(BaseModel):
    # Typed message model: a mandatory ``type`` tag plus an optional text
    # payload.  Documented with comments rather than a docstring so the
    # generated model schema stays exactly as before.

    # Message discriminator (required).
    type: str
    # Free-form text; omitted fields default to None.
    text: Optional[str] = None
class AudioSession:
    """Per-connection state tying together the STT, LLM and TTS engines.

    The synchronous methods (``recognize_speech``, ``get_llm_response``,
    ``synthesize_speech``, ``process_conversation``) call the engines
    directly and block until they finish.  Call them from a worker thread
    (for example via ``asyncio.to_thread``) when inside a coroutine.  The
    ``_*_async`` helpers perform that offloading for coroutine callers.
    """

    # Sample rate (Hz) requested from the TTS engine and written to the WAV
    # header; both must agree, so it lives in one place.
    SAMPLE_RATE = 24000
    # How many of the most recent history lines are replayed in each prompt.
    HISTORY_WINDOW = 6

    def __init__(self):
        """Create the engines, the audio buffer and empty dialogue state."""
        self.stt = STTEngine()
        self.llm = LLMEngine()
        self.tts = TTSEngine()
        self.audio_buffer = AudioStreamBuffer()
        # Alternating "Пользователь: ..." / "Ассистент: ..." lines.
        self.conversation_history = []
        self.is_processing = False
        self.processing_lock = asyncio.Lock()

    def initialize_engines(self):
        """Initialize the STT, LLM and TTS engines, in that order."""
        logger.info("Loading STT model...")
        self.stt.initialize()
        logger.info("Loading LLM model...")
        self.llm.initialize()
        logger.info("Loading TTS model...")
        self.tts.initialize()
        logger.info("All models loaded!")

    def add_audio_chunk(self, chunk: bytes):
        """Forward a raw audio chunk to the stream buffer."""
        self.audio_buffer.add_chunk(chunk)

    def recognize_speech(self) -> str:
        """Transcribe the next ready chunk from the buffer.

        Returns:
            The transcription, or ``""`` when the buffer yields nothing
            for the 2-second timeout passed to ``get_ready_chunk``.
        """
        ready_chunk = self.audio_buffer.get_ready_chunk(timeout=2.0)
        if not ready_chunk:
            return ""
        # Call the engine directly: this is a synchronous method, and
        # run_until_complete() cannot be used on a loop that is already
        # running (nor is there a running loop inside a worker thread).
        return self.stt.transcribe(ready_chunk)

    async def _transcribe_async(self, audio_bytes: bytes) -> str:
        """Async wrapper for STT transcription."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.stt.transcribe, audio_bytes)

    def get_llm_response(self, user_text: str) -> str:
        """Generate the assistant reply for ``user_text`` and record the turn.

        The prompt includes up to ``HISTORY_WINDOW`` of the most recent
        history lines.  After generation both the user line and the
        assistant line are appended to ``conversation_history``.

        Args:
            user_text: What the user said.

        Returns:
            The text produced by the LLM engine.
        """
        system_prompt = "Ты голосовой ассистент. Отвечай кратко, как в разговорной речи. Не используй списки или форматирование."
        if self.conversation_history:
            context = "\n".join(self.conversation_history[-self.HISTORY_WINDOW:])
            full_prompt = f"Предыдущий диалог:\n{context}\n\nПользователь: {user_text}\nТы:"
        else:
            full_prompt = f"Пользователь: {user_text}\nТы:"
        # Blocking call on purpose; see the class docstring.  The previous
        # get_running_loop()/run_until_complete()/close() sequence raised
        # RuntimeError in every calling context and would have tried to
        # close the server's event loop.
        response = self.llm.generate(full_prompt, system_prompt)
        self.conversation_history.append(f"Пользователь: {user_text}")
        self.conversation_history.append(f"Ассистент: {response}")
        return response

    async def _generate_async(self, prompt: str, system_prompt: str) -> str:
        """Async wrapper for LLM generation."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.llm.generate, prompt, system_prompt)

    @staticmethod
    def _to_pcm16(audio) -> np.ndarray:
        """Return ``audio`` as samples suitable for a 16-bit WAV file.

        Floating-point input is clipped and scaled to little-endian int16.
        Any other dtype is passed through unchanged.
        """
        samples = np.asarray(audio)
        if np.issubdtype(samples.dtype, np.floating):
            # assumes float TTS output is normalized to [-1.0, 1.0] — TODO
            # confirm against engine.tts
            return (np.clip(samples, -1.0, 1.0) * 32767.0).astype("<i2")
        return samples

    def synthesize_speech(self, text: str) -> bytes:
        """Synthesize ``text`` and return it as a mono 16-bit WAV file.

        Args:
            text: The text to speak.

        Returns:
            Complete WAV file contents (header + frames) at ``SAMPLE_RATE``.
        """
        # Blocking call on purpose; see the class docstring.
        audio = self.tts.synthesize(text, self.SAMPLE_RATE)
        # The header below declares 2-byte samples, so the frame data must
        # really be int16; raw float bytes would play back as noise.
        pcm = self._to_pcm16(audio)
        wav_buffer = io.BytesIO()
        with wave.open(wav_buffer, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(self.SAMPLE_RATE)
            wf.writeframes(pcm.tobytes())
        return wav_buffer.getvalue()

    async def _synthesize_async(self, text: str, sample_rate: int) -> np.ndarray:
        """Async wrapper for TTS synthesis."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.tts.synthesize, text, sample_rate)

    def process_conversation(self, user_text: str):
        """Run one LLM + TTS turn for ``user_text``.

        Returns:
            WAV bytes of the spoken reply, or ``None`` if any step raised
            (the error is logged with its traceback, not propagated).
        """
        logger.info(f"User said: {user_text}")
        # NOTE(review): start() is defined in audio.stream; presumably it
        # re-arms the buffer for the next utterance — confirm.
        self.audio_buffer.start()
        try:
            response = self.get_llm_response(user_text)
            logger.info(f"LLM response: {response}")
            audio = self.synthesize_speech(response)
            logger.info(f"Generated {len(audio)} bytes of audio")
            return audio
        except Exception as e:
            logger.exception("Error processing conversation: %s", e)
            return None

    def reset(self):
        """Drop the dialogue history and replace the audio buffer."""
        self.conversation_history = []
        self.audio_buffer = AudioStreamBuffer()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: log once on startup and once on shutdown."""
    startup_note, shutdown_note = "Starting audio chat server...", "Shutting down..."
    logger.info(startup_note)
    # Control returns here when the application is shutting down.
    yield
    logger.info(shutdown_note)
# The ASGI application; `lifespan` supplies the startup/shutdown logging.
app = FastAPI(title="Audio Chat Server", version="1.0.0", lifespan=lifespan)
# Directory named "static" next to this file; index.html is served from it
# at "/" and the whole directory is mounted under "/static".
STATIC_DIR = Path(__file__).parent / "static"
@app.get("/")
async def serve_index():
    """Serve ``index.html`` from the static directory at the site root."""
    index_page = STATIC_DIR / "index.html"
    return FileResponse(index_page)
# Expose the contents of STATIC_DIR under the /static URL prefix.
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
async def process_audio_chunk(session: AudioSession, payload: bytes, websocket):
"""Handle an audio chunk with proper async handling."""
session.add_audio_chunk(payload)
ready = session.audio_buffer.get_ready_chunk(timeout=0.1)
if ready:
try:
text = session.stt.transcribe(ready)
if text:
await websocket.send_text(f"TEXT:{text}")
async with session.processing_lock:
if not session.is_processing:
session.is_processing = True
try:
audio = await asyncio.to_thread(
session.process_conversation, text
)
if audio:
try:
await websocket.send_bytes(b"O" + audio)
except Exception as e:
logger.error(f"Failed to send audio: {e}")
finally:
session.is_processing = False
except Exception as e:
logger.error(f"Error processing audio chunk: {e}")
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one audio-chat client over a WebSocket.

    Each incoming binary frame starts with a one-byte message type:

    * ``b"A"`` — the rest is an audio chunk, passed to ``process_audio_chunk``;
    * ``b"R"`` — reset the session; the server answers with text ``RESET``.

    The connection is dropped when no frame arrives for 120 seconds.
    """
    await websocket.accept()
    logger.info("Client connected")
    try:
        # Created inside the try so that a failure while loading models
        # still reaches the cleanup below and closes the socket.
        session = AudioSession()
        # Model loading is blocking work; keep it off the event loop so
        # other connections and /health stay responsive meanwhile.
        await asyncio.to_thread(session.initialize_engines)
        while True:
            data = await asyncio.wait_for(websocket.receive_bytes(), timeout=120.0)
            msg_type = data[:1]
            payload = data[1:]
            if msg_type == b"A":
                await process_audio_chunk(session, payload, websocket)
            elif msg_type == b"R":
                session.reset()
                await websocket.send_text("RESET")
            else:
                logger.warning("Ignoring unknown message type: %r", msg_type)
    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except asyncio.TimeoutError:
        logger.info("Client timed out")
    except Exception as e:
        logger.exception("WebSocket error: %s", e)
    finally:
        try:
            await websocket.close()
        except (RuntimeError, WebSocketDisconnect):
            # The peer already went away (or a close frame was already
            # sent); closing again is an error, not a failure to report.
            pass
@app.get("/health")
async def health_check():
    """Health endpoint: always answers with an "ok" status object."""
    status_payload = {"status": "ok"}
    return status_payload
if __name__ == "__main__":
    # Start uvicorn only when this module is executed as a script.
    import uvicorn

    bind_host, bind_port = config.HOST, config.PORT
    uvicorn.run(app, host=bind_host, port=bind_port)