Use in-process Llama cleanup

This commit is contained in:
Thales Maciel 2026-02-24 12:46:11 -03:00
parent 548be49112
commit a83a843e1a
No known key found for this signature in database
GPG key ID: 33112E6833C34679
7 changed files with 235 additions and 116 deletions

View file

@@ -1,12 +1,12 @@
from __future__ import annotations
import argparse
import json
import logging
import sys
import os
import urllib.request
from pathlib import Path
from dataclasses import dataclass
from pathlib import Path
from llama_cpp import Llama # type: ignore[import-not-found]
SYSTEM_PROMPT = (
@@ -22,77 +22,95 @@ SYSTEM_PROMPT = (
" - \"let's ask Bob, I mean Janice, let's ask Janice\" -> \"let's ask Janice\"\n"
)
# Quantized Llama 3.2 3B instruct model, fetched on demand from Hugging Face.
MODEL_NAME = "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
MODEL_URL = (
    "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/"
    "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
)
# Local cache location for the downloaded model weights.
MODEL_DIR = Path.home() / ".cache" / "lel" / "models"
MODEL_PATH = MODEL_DIR / MODEL_NAME
# Language tag injected into the prompt; an empty string disables the hint.
LLM_LANGUAGE = "en"
@dataclass
class AIConfig:
    """Configuration for the HTTP chat-completions backend.

    NOTE(review): this is the pre-commit config — superseded by LLMConfig
    in this change, but kept here as it appears in the diff.
    """

    model: str                        # model identifier sent in the request payload
    base_url: str                     # API root; resolved by _chat_completions_url
    api_key: str                      # bearer token; empty string sends no auth header
    timeout_sec: int                  # socket timeout for the HTTP request
    language_hint: str | None = None  # optional <language> tag for the prompt
    wrap_transcript: bool = True      # wrap input in <transcript>…</transcript>
@dataclass
class LLMConfig:
    """Settings for the in-process llama.cpp model.

    NOTE(review): the @dataclass decorator was missing in the rendered diff,
    so LLMConfig(model_path=...) in build_processor would raise TypeError;
    restored here.
    """

    model_path: Path       # location of the .gguf weights on disk
    n_ctx: int = 4096      # context window size passed to llama.cpp
    verbose: bool = False  # forward llama.cpp's native logging when True
class LlamaProcessor:
    """Clean up a transcript with an in-process llama.cpp model.

    NOTE(review): this span contained diff residue — the removed
    GenericAPIProcessor header and its urllib/JSON request code were
    interleaved with the new implementation. Reconstructed to the
    post-commit in-process version.
    """

    def __init__(self, cfg: LLMConfig):
        """Load the GGUF model referenced by *cfg* into this process."""
        self.cfg = cfg
        self.system = SYSTEM_PROMPT
        # Tidy llama.cpp's native log lines; setdefault keeps any
        # environment overrides the user has already set.
        os.environ.setdefault("LLAMA_CPP_LOG_PREFIX", "llama")
        os.environ.setdefault("LLAMA_CPP_LOG_PREFIX_SEPARATOR", "::")
        self.client = Llama(
            model_path=str(cfg.model_path),
            n_ctx=cfg.n_ctx,
            verbose=cfg.verbose,
        )

    def process(self, text: str) -> str:
        """Return the model's cleaned-up version of *text*.

        Raises:
            RuntimeError: if the completion payload has no usable text
                (propagated from _extract_chat_text).
        """
        user_content = f"<transcript>{text}</transcript>"
        if LLM_LANGUAGE:
            # Prepend the language hint ahead of the transcript wrapper.
            user_content = f"<language>{LLM_LANGUAGE}</language>\n{user_content}"
        response = self.client.create_chat_completion(
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_content},
            ],
            # Deterministic output: transcript cleanup should not be creative.
            temperature=0.0,
        )
        return _extract_chat_text(response)
def build_processor(verbose: bool = False) -> LlamaProcessor:
    """Create a LlamaProcessor, downloading the model first if needed.

    NOTE(review): the rendered diff left two build_processor definitions
    (the removed cfg-based one and the new in-process one); only the
    post-commit variant is kept.

    Args:
        verbose: forward llama.cpp's native logging when True.
    """
    model_path = ensure_model()
    return LlamaProcessor(LLMConfig(model_path=model_path, verbose=verbose))
def _chat_completions_url(base_url: str) -> str:
if not base_url:
return ""
trimmed = base_url.rstrip("/")
if "/v1/" in trimmed:
return trimmed
if trimmed.endswith("/v1"):
return trimmed + "/chat/completions"
return trimmed + "/v1/chat/completions"
def ensure_model() -> Path:
    """Return the local model path, downloading the GGUF file on first use.

    Streams the download into a ``.tmp`` sibling and promotes it with
    ``os.replace`` so an interrupted download never leaves a half-written
    model at MODEL_PATH.

    Raises:
        Exception: whatever urllib/file error occurred; the partial
            temp file is removed best-effort before re-raising.
    """
    if MODEL_PATH.exists():
        return MODEL_PATH
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    # Stage as "<name>.gguf.tmp"; only renamed over MODEL_PATH on success.
    tmp_path = MODEL_PATH.with_suffix(MODEL_PATH.suffix + ".tmp")
    logging.info("downloading model: %s", MODEL_NAME)
    try:
        with urllib.request.urlopen(MODEL_URL) as resp:
            total = resp.getheader("Content-Length")
            total_size = int(total) if total else None
            downloaded = 0
            next_log = 0  # next progress fraction (0.0–1.0) worth logging
            with open(tmp_path, "wb") as handle:
                while True:
                    chunk = resp.read(1024 * 1024)  # stream in 1 MiB chunks
                    if not chunk:
                        break
                    handle.write(chunk)
                    downloaded += len(chunk)
                    if total_size:
                        # Known size: log roughly every 10% of progress.
                        progress = downloaded / total_size
                        if progress >= next_log:
                            logging.info("model download %.0f%%", progress * 100)
                            next_log += 0.1
                    elif downloaded // (50 * 1024 * 1024) > (downloaded - len(chunk)) // (50 * 1024 * 1024):
                        # Unknown size: log each time a 50 MiB boundary is crossed.
                        logging.info("model download %d MB", downloaded // (1024 * 1024))
        # Atomic promote (same filesystem, since tmp sits beside the target).
        os.replace(tmp_path, MODEL_PATH)
    except Exception:
        # Best-effort cleanup of the partial download; keep the original error.
        try:
            if tmp_path.exists():
                tmp_path.unlink()
        except Exception:
            pass
        raise
    return MODEL_PATH
def _extract_chat_text(payload: dict) -> str:
if "choices" in payload and payload["choices"]:
choice = payload["choices"][0]
msg = choice.get("message") or {}
content = msg.get("content") or choice.get("text")
if content is not None:
return str(content).strip()
raise RuntimeError("unexpected response format")