"""WebSocket voice-chat server: client audio -> STT -> LLM -> TTS -> client.

Clients connect to ``/ws`` and send binary frames whose first byte is a tag:
``b"A"`` carries an audio chunk and ``b"R"`` resets the conversation. The
server replies with ``"TEXT:<transcript>"`` text frames, ``b"O" + <wav bytes>``
binary frames containing the synthesized answer, and ``"RESET"`` after a reset.
"""

import asyncio
import io
import logging
import struct
import wave
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional

import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.websockets import WebSocketState
from pydantic import BaseModel

from config import Config
from engine.stt import STTEngine
from engine.llm import LLMEngine
from engine.tts import TTSEngine
from audio.stream import AudioStreamBuffer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

config = Config()


class Message(BaseModel):
    # JSON message schema; currently not used by the handlers in this module.
    type: str
    text: Optional[str] = None


class AudioSession:
    """Per-connection state: the STT/LLM/TTS engines, audio buffer and chat history.

    The synchronous methods (``recognize_speech``, ``get_llm_response``,
    ``synthesize_speech``, ``process_conversation``) call the engines directly
    and block, so they must run off the event loop (``process_audio_chunk``
    uses ``asyncio.to_thread``). The ``_*_async`` helpers are awaitable
    counterparts for use from coroutines.
    """

    # Sample rate requested from the TTS engine and written into the WAV header.
    TTS_SAMPLE_RATE = 24000

    # Number of history lines (i.e. 3 user/assistant exchanges) put in the prompt.
    HISTORY_WINDOW = 6

    def __init__(self):
        self.stt = STTEngine()
        self.llm = LLMEngine()
        self.tts = TTSEngine()
        self.audio_buffer = AudioStreamBuffer()
        # Alternating "Пользователь: ..." / "Ассистент: ..." lines.
        self.conversation_history = []
        self.is_processing = False
        self.processing_lock = asyncio.Lock()

    def initialize_engines(self):
        """Load the STT, LLM and TTS models. Blocking."""
        logger.info("Loading STT model...")
        self.stt.initialize()
        logger.info("Loading LLM model...")
        self.llm.initialize()
        logger.info("Loading TTS model...")
        self.tts.initialize()
        logger.info("All models loaded!")

    def add_audio_chunk(self, chunk: bytes):
        """Append a raw audio chunk to the session's stream buffer."""
        self.audio_buffer.add_chunk(chunk)

    def recognize_speech(self) -> str:
        """Transcribe the next ready chunk from the buffer, or return ``""``.

        Waits up to 2 s for a ready chunk. Blocking; do not call on the event loop.
        """
        ready_chunk = self.audio_buffer.get_ready_chunk(timeout=2.0)
        if not ready_chunk:
            return ""
        # Call the engine directly. Driving a coroutine with run_until_complete()
        # from sync code raised RuntimeError both inside a worker thread (no
        # running loop) and on the loop thread (loop already running).
        return self.stt.transcribe(ready_chunk)

    async def _transcribe_async(self, audio_bytes: bytes) -> str:
        """Awaitable wrapper: run STT transcription in the default executor."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.stt.transcribe, audio_bytes)

    def get_llm_response(self, user_text: str) -> str:
        """Generate the assistant's reply to ``user_text`` and record the exchange.

        The prompt includes the last ``HISTORY_WINDOW`` history lines, if any.
        The user line and the reply are appended to ``conversation_history``
        only after generation succeeds. Blocking.
        """
        # Russian system prompt: "You are a voice assistant. Answer briefly, as in
        # spoken conversation. Do not use lists or formatting."
        system_prompt = "Ты голосовой ассистент. Отвечай кратко, как в разговорной речи. Не используй списки или форматирование."

        if self.conversation_history:
            context = "\n".join(self.conversation_history[-self.HISTORY_WINDOW:])
            full_prompt = f"Предыдущий диалог:\n{context}\n\nПользователь: {user_text}\nТы:"
        else:
            full_prompt = f"Пользователь: {user_text}\nТы:"

        # Direct call: the old run_until_complete()/loop.close() pair failed in
        # worker threads and would have closed the server's own loop otherwise.
        response = self.llm.generate(full_prompt, system_prompt)

        self.conversation_history.append(f"Пользователь: {user_text}")
        self.conversation_history.append(f"Ассистент: {response}")
        return response

    async def _generate_async(self, prompt: str, system_prompt: str) -> str:
        """Awaitable wrapper: run LLM generation in the default executor."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.llm.generate, prompt, system_prompt)

    def synthesize_speech(self, text: str) -> bytes:
        """Synthesize ``text`` and return it as a mono, 16-bit PCM WAV file. Blocking."""
        audio = self.tts.synthesize(text, self.TTS_SAMPLE_RATE)
        pcm = self._to_pcm16(audio)

        wav_buffer = io.BytesIO()
        with wave.open(wav_buffer, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)  # 2 bytes per sample -> frames must be int16
            wf.setframerate(self.TTS_SAMPLE_RATE)
            wf.writeframes(pcm.tobytes())
        return wav_buffer.getvalue()

    @staticmethod
    def _to_pcm16(audio) -> np.ndarray:
        """Return ``audio`` as samples whose byte width matches the 16-bit WAV header.

        Floating-point input is assumed to be normalized to [-1, 1]
        (NOTE(review): confirm against TTSEngine.synthesize) and is clipped and
        scaled to int16; writing it raw would emit 4/8-byte samples into a
        2-byte-per-sample WAV. Non-float input is returned unchanged.
        """
        samples = np.asarray(audio)
        if np.issubdtype(samples.dtype, np.floating):
            samples = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)
        return samples

    async def _synthesize_async(self, text: str, sample_rate: int) -> np.ndarray:
        """Awaitable wrapper: run TTS synthesis in the default executor."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.tts.synthesize, text, sample_rate)

    def process_conversation(self, user_text: str) -> Optional[bytes]:
        """Run one LLM + TTS turn for ``user_text``.

        Returns the WAV bytes of the spoken answer, or ``None`` if any step
        failed (the error is logged). Blocking; intended for a worker thread.
        """
        logger.info(f"User said: {user_text}")
        # NOTE(review): the effect of start() on the buffer is defined in
        # audio.stream and is not visible from this module.
        self.audio_buffer.start()
        try:
            response = self.get_llm_response(user_text)
            logger.info(f"LLM response: {response}")
            audio = self.synthesize_speech(response)
            logger.info(f"Generated {len(audio)} bytes of audio")
            return audio
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception(f"Error processing conversation: {e}")
            return None

    def reset(self):
        """Forget the conversation history and start a fresh audio buffer."""
        self.conversation_history = []
        self.audio_buffer = AudioStreamBuffer()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: log startup and shutdown."""
    logger.info("Starting audio chat server...")
    yield
    logger.info("Shutting down...")


app = FastAPI(title="Audio Chat Server", version="1.0.0", lifespan=lifespan)

STATIC_DIR = Path(__file__).parent / "static"


@app.get("/")
async def serve_index():
    """Serve the browser client's entry page."""
    return FileResponse(STATIC_DIR / "index.html")


app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")


async def process_audio_chunk(session: AudioSession, payload: bytes, websocket):
    """Buffer ``payload``; if a chunk is ready, transcribe it and answer.

    Sends ``"TEXT:<transcript>"`` for a non-empty transcript. Unless a turn is
    already in progress, it then runs LLM + TTS in a worker thread and sends
    ``b"O" + wav`` back. Errors are logged, not raised.
    """
    session.add_audio_chunk(payload)
    # NOTE(review): get_ready_chunk may block up to 0.1 s on the event loop.
    ready = session.audio_buffer.get_ready_chunk(timeout=0.1)
    if not ready:
        return
    try:
        # Transcription runs the STT model; keep it off the event-loop thread so
        # other connections are not stalled while it runs.
        text = await asyncio.to_thread(session.stt.transcribe, ready)
        if not text:
            return
        await websocket.send_text(f"TEXT:{text}")
        async with session.processing_lock:
            if session.is_processing:
                return
            session.is_processing = True
            try:
                audio = await asyncio.to_thread(session.process_conversation, text)
                if audio:
                    try:
                        await websocket.send_bytes(b"O" + audio)
                    except Exception as e:
                        logger.error(f"Failed to send audio: {e}")
            finally:
                session.is_processing = False
    except Exception as e:
        logger.exception(f"Error processing audio chunk: {e}")


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one voice-chat client until it disconnects, errors, or idles for 120 s."""
    await websocket.accept()
    logger.info("Client connected")
    session = AudioSession()
    try:
        # Model loading is slow and blocking; run it in a thread so the event
        # loop keeps serving other clients. Inside the try so failures are
        # logged and the socket is closed.
        await asyncio.to_thread(session.initialize_engines)
        while True:
            data = await asyncio.wait_for(websocket.receive_bytes(), timeout=120.0)
            msg_type = data[:1]
            payload = data[1:]
            if msg_type == b"A":
                await process_audio_chunk(session, payload, websocket)
            elif msg_type == b"R":
                session.reset()
                await websocket.send_text("RESET")
    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except asyncio.TimeoutError:
        logger.info("Client timed out")
    except Exception as e:
        logger.exception(f"WebSocket error: {e}")
    finally:
        # After a client disconnect the connection is already gone, and
        # close() would raise; only close while both sides are still open.
        if (
            websocket.client_state == WebSocketState.CONNECTED
            and websocket.application_state == WebSocketState.CONNECTED
        ):
            try:
                await websocket.close()
            except Exception as e:
                logger.error(f"WebSocket error: {e}")


@app.get("/health")
async def health_check():
    """Liveness probe."""
    return {"status": "ok"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host=config.HOST, port=config.PORT)