Initial commit: audio-chat with fixes
- Created AGENTS.md with architecture documentation - Fixed race conditions and async patterns - Added conversation history to LLM prompts - Fixed TTS audio shape handling - Added buffer limits and graceful shutdown - Fixed client.py with file sending support - Removed duplicate requirements - Added .gitignore
This commit is contained in:
61
engine/llm.py
Normal file
61
engine/llm.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
|
||||
from config import Config
|
||||
import torch
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.config = Config()
|
||||
|
||||
def initialize(self):
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
dtype = torch.float16 if device == "cuda" else torch.float32
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.config.LLM_MODEL,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
self.config.LLM_MODEL,
|
||||
torch_dtype=dtype,
|
||||
device_map="auto",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
def generate(self, user_text: str, system_prompt: str = None) -> str:
|
||||
if not self.model:
|
||||
self.initialize()
|
||||
|
||||
if system_prompt is None:
|
||||
system_prompt = "Ты полезный ассистент. Отвечай на русском языке кратко и по делу."
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_text},
|
||||
]
|
||||
|
||||
text = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = self.model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=self.config.LLM_MAX_TOKENS,
|
||||
temperature=self.config.LLM_TEMPERATURE,
|
||||
do_sample=True,
|
||||
top_p=0.9,
|
||||
repetition_penalty=1.1,
|
||||
)
|
||||
|
||||
generated = outputs[0][inputs["input_ids"].shape[1]:]
|
||||
response = self.tokenizer.decode(generated, skip_special_tokens=True)
|
||||
|
||||
return response.strip()
|
||||
Reference in New Issue
Block a user