Use in-process Llama cleanup
parent 548be49112
commit a83a843e1a
7 changed files with 235 additions and 116 deletions

src/aiprocess.py (146 changed lines)

@@ -1,12 +1,12 @@
from __future__ import annotations

import argparse
import json
import logging
import sys
import os
import urllib.request
from pathlib import Path
from dataclasses import dataclass
from pathlib import Path

from llama_cpp import Llama  # type: ignore[import-not-found]


SYSTEM_PROMPT = (

@@ -22,77 +22,95 @@ SYSTEM_PROMPT = (
    " - \"let's ask Bob, I mean Janice, let's ask Janice\" -> \"let's ask Janice\"\n"
)

MODEL_NAME = "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
MODEL_URL = (
    "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/"
    "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
)
MODEL_DIR = Path.home() / ".cache" / "lel" / "models"
MODEL_PATH = MODEL_DIR / MODEL_NAME
LLM_LANGUAGE = "en"


@dataclass
class AIConfig:
    model: str
    base_url: str
    api_key: str
    timeout_sec: int
    language_hint: str | None = None
    wrap_transcript: bool = True
class LLMConfig:
    model_path: Path
    n_ctx: int = 4096
    verbose: bool = False


class GenericAPIProcessor:
    def __init__(self, cfg: AIConfig):
class LlamaProcessor:
    def __init__(self, cfg: LLMConfig):
        self.cfg = cfg
        self.system = SYSTEM_PROMPT
        os.environ.setdefault("LLAMA_CPP_LOG_PREFIX", "llama")
        os.environ.setdefault("LLAMA_CPP_LOG_PREFIX_SEPARATOR", "::")
        self.client = Llama(
            model_path=str(cfg.model_path),
            n_ctx=cfg.n_ctx,
            verbose=cfg.verbose,
        )

    def process(self, text: str) -> str:
        language = self.cfg.language_hint or ""
        if self.cfg.wrap_transcript:
            user_content = f"<transcript>{text}</transcript>"
        else:
            user_content = text
        if language:
            user_content = f"<language>{language}</language>\n{user_content}"
        payload = {
            "model": self.cfg.model,
            "messages": [
                {"role": "system", "content": self.system},
        user_content = f"<transcript>{text}</transcript>"
        if LLM_LANGUAGE:
            user_content = f"<language>{LLM_LANGUAGE}</language>\n{user_content}"
        response = self.client.create_chat_completion(
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_content},
            ],
            "temperature": 0.0,
        }
        data = json.dumps(payload).encode("utf-8")
        url = _chat_completions_url(self.cfg.base_url)
        req = urllib.request.Request(url, data=data, method="POST")
        req.add_header("Content-Type", "application/json")
        if self.cfg.api_key:
            req.add_header("Authorization", f"Bearer {self.cfg.api_key}")

        with urllib.request.urlopen(req, timeout=self.cfg.timeout_sec) as resp:
            body = resp.read()

        out = json.loads(body.decode("utf-8"))

        if isinstance(out, dict):
            if "output" in out:
                return str(out["output"]).strip()
            if "response" in out:
                return str(out["response"]).strip()
            if "choices" in out and out["choices"]:
                choice = out["choices"][0]
                msg = choice.get("message") or {}
                content = msg.get("content") or choice.get("text")
                if content is not None:
                    return str(content).strip()
        raise RuntimeError("unexpected response format")
            temperature=0.0,
        )
        return _extract_chat_text(response)


def build_processor(cfg: AIConfig) -> GenericAPIProcessor:
    if not cfg.base_url:
        raise ValueError("ai_base_url is required for generic API")
    return GenericAPIProcessor(cfg)
def build_processor(verbose: bool = False) -> LlamaProcessor:
    model_path = ensure_model()
    return LlamaProcessor(LLMConfig(model_path=model_path, verbose=verbose))


def _chat_completions_url(base_url: str) -> str:
    if not base_url:
        return ""
    trimmed = base_url.rstrip("/")
    if "/v1/" in trimmed:
        return trimmed
    if trimmed.endswith("/v1"):
        return trimmed + "/chat/completions"
    return trimmed + "/v1/chat/completions"
def ensure_model() -> Path:
    if MODEL_PATH.exists():
        return MODEL_PATH
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    tmp_path = MODEL_PATH.with_suffix(MODEL_PATH.suffix + ".tmp")
    logging.info("downloading model: %s", MODEL_NAME)
    try:
        with urllib.request.urlopen(MODEL_URL) as resp:
            total = resp.getheader("Content-Length")
            total_size = int(total) if total else None
            downloaded = 0
            next_log = 0
            with open(tmp_path, "wb") as handle:
                while True:
                    chunk = resp.read(1024 * 1024)
                    if not chunk:
                        break
                    handle.write(chunk)
                    downloaded += len(chunk)
                    if total_size:
                        progress = downloaded / total_size
                        if progress >= next_log:
                            logging.info("model download %.0f%%", progress * 100)
                            next_log += 0.1
                    elif downloaded // (50 * 1024 * 1024) > (downloaded - len(chunk)) // (50 * 1024 * 1024):
                        logging.info("model download %d MB", downloaded // (1024 * 1024))
        os.replace(tmp_path, MODEL_PATH)
    except Exception:
        try:
            if tmp_path.exists():
                tmp_path.unlink()
        except Exception:
            pass
        raise
    return MODEL_PATH


def _extract_chat_text(payload: dict) -> str:
    if "choices" in payload and payload["choices"]:
        choice = payload["choices"][0]
        msg = choice.get("message") or {}
        content = msg.get("content") or choice.get("text")
        if content is not None:
            return str(content).strip()
    raise RuntimeError("unexpected response format")
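
For orientation, a minimal sketch of how the new in-process path added above would be exercised; the sample transcript and the logging setup are illustrative assumptions, not part of the commit:

# Illustrative sketch only: drives build_processor()/LlamaProcessor as added above.
# The transcript string and logging setup are assumptions for this example.
import logging

from aiprocess import build_processor

logging.basicConfig(level=logging.INFO)

# The first call triggers ensure_model(), which downloads the GGUF into ~/.cache/lel/models.
processor = build_processor(verbose=False)
cleaned = processor.process("so um let's ask Bob, I mean Janice, let's ask Janice")
print(cleaned)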

@@ -9,13 +9,6 @@ class Config:
    recording: dict = field(default_factory=lambda: {"input": ""})
    stt: dict = field(default_factory=lambda: {"model": "base", "device": "cpu"})
    injection: dict = field(default_factory=lambda: {"backend": "clipboard"})
    ai_cleanup: dict = field(
        default_factory=lambda: {
            "model": "llama3.2:3b",
            "base_url": "http://localhost:11434",
            "api_key": "",
        }
    )


def default_path() -> Path:

@@ -27,7 +20,7 @@ def load(path: str | None) -> Config:
    p = Path(path) if path else default_path()
    if p.exists():
        data = json.loads(p.read_text(encoding="utf-8"))
        if any(k in data for k in ("daemon", "recording", "stt", "injection", "ai_cleanup")):
        if any(k in data for k in ("daemon", "recording", "stt", "injection")):
            for k, v in data.items():
                if hasattr(cfg, k):
                    setattr(cfg, k, v)

@@ -37,9 +30,6 @@ def load(path: str | None) -> Config:
        cfg.stt["model"] = data.get("whisper_model", cfg.stt["model"])
        cfg.stt["device"] = data.get("whisper_device", cfg.stt["device"])
        cfg.injection["backend"] = data.get("injection_backend", cfg.injection["backend"])
        cfg.ai_cleanup["model"] = data.get("ai_model", cfg.ai_cleanup["model"])
        cfg.ai_cleanup["base_url"] = data.get("ai_base_url", cfg.ai_cleanup["base_url"])
        cfg.ai_cleanup["api_key"] = data.get("ai_api_key", cfg.ai_cleanup["api_key"])

    if not isinstance(cfg.daemon, dict):
        cfg.daemon = {"hotkey": "Cmd+m"}

@@ -49,22 +39,12 @@ def load(path: str | None) -> Config:
        cfg.stt = {"model": "base", "device": "cpu"}
    if not isinstance(cfg.injection, dict):
        cfg.injection = {"backend": "clipboard"}
    if not isinstance(cfg.ai_cleanup, dict):
        cfg.ai_cleanup = {
            "model": "llama3.2:3b",
            "base_url": "http://localhost:11434",
            "api_key": "",
        }
    validate(cfg)
    return cfg


def redacted_dict(cfg: Config) -> dict:
    d = cfg.__dict__.copy()
    if isinstance(d.get("ai_cleanup"), dict):
        d["ai_cleanup"] = d["ai_cleanup"].copy()
        d["ai_cleanup"]["api_key"] = ""
    return d
    return cfg.__dict__.copy()


def validate(cfg: Config) -> None:
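
As a quick check of the slimmed-down config module (no more ai_cleanup section or API key redaction), a small sketch, assuming the module is imported as `config` the way src/leld.py does; the missing path argument is deliberate so load() falls back to its defaults:

# Illustrative sketch only: loads the config (now without ai_cleanup) and dumps it,
# mirroring the startup log line in src/leld.py.
import json

from config import load, redacted_dict

cfg = load(None)  # no explicit path: falls back to default_path()
print(json.dumps(redacted_dict(cfg), indent=2))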

src/leld.py (40 changed lines)

@@ -15,7 +15,7 @@ from faster_whisper import WhisperModel

from config import Config, load, redacted_dict
from recorder import start_recording, stop_recording
from aiprocess import AIConfig, build_processor
from aiprocess import build_processor
from inject import inject
from x11_hotkey import listen

@@ -51,7 +51,7 @@ def _compute_type(device: str) -> str:


class Daemon:
    def __init__(self, cfg: Config):
    def __init__(self, cfg: Config, *, llama_verbose: bool = False):
        self.cfg = cfg
        self.lock = threading.Lock()
        self.state = State.IDLE

@@ -63,6 +63,7 @@ class Daemon:
            device=cfg.stt.get("device", "cpu"),
            compute_type=_compute_type(cfg.stt.get("device", "cpu")),
        )
        self.ai_processor = build_processor(verbose=llama_verbose)
        self.indicator = None
        self.status_icon = None
        if AppIndicator3 is not None:

@@ -183,25 +184,13 @@ class Daemon:

        logging.info("stt: %s", text)

        ai_model = (self.cfg.ai_cleanup.get("model") or "").strip()
        ai_base_url = (self.cfg.ai_cleanup.get("base_url") or "").strip()
        if ai_model and ai_base_url:
            self.set_state(State.PROCESSING)
            logging.info("ai processing started")
            try:
                processor = build_processor(
                    AIConfig(
                        model=ai_model,
                        base_url=ai_base_url,
                        api_key=self.cfg.ai_cleanup.get("api_key", ""),
                        timeout_sec=25,
                        language_hint="en",
                    )
                )
                ai_input = text
                text = processor.process(ai_input) or text
            except Exception as exc:
                logging.error("ai process failed: %s", exc)
        self.set_state(State.PROCESSING)
        logging.info("ai processing started")
        try:
            ai_input = text
            text = self.ai_processor.process(ai_input) or text
        except Exception as exc:
            logging.error("ai process failed: %s", exc)

        logging.info("processed: %s", text)

@@ -286,6 +275,7 @@ def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="", help="path to config.json")
    parser.add_argument("--dry-run", action="store_true", help="log hotkey only")
    parser.add_argument("-v", "--verbose", action="store_true", help="enable verbose logs")
    args = parser.parse_args()

    logging.basicConfig(

@@ -299,7 +289,13 @@ def main():
    logging.info("ready (hotkey: %s)", cfg.daemon.get("hotkey", ""))
    logging.info("config (%s):\n%s", args.config or str(Path.home() / ".config" / "lel" / "config.json"), json.dumps(redacted_dict(cfg), indent=2))

    daemon = Daemon(cfg)
    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)
    try:
        daemon = Daemon(cfg, llama_verbose=args.verbose)
    except Exception as exc:
        logging.error("startup failed: %s", exc)
        raise SystemExit(1)

    def handle_signal(_sig, _frame):
        logging.info("signal received, shutting down")
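
A usage note, assuming the daemon is launched directly with the Python interpreter (the actual entry point is not shown in this diff): the new -v flag raises the root logger to DEBUG and is forwarded to the Llama backend as llama_verbose, and because the processor is now built once in Daemon.__init__, the first start performs the model download up front and exits with status 1 if startup fails.

python src/leld.py --config ~/.config/lel/config.json -v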