from __future__ import annotations
import ctypes
import logging
import os
import sys
import urllib.request
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, cast
from llama_cpp import Llama, llama_cpp as llama_cpp_lib # type: ignore[import-not-found]
# System prompt sent with every chat completion: instructs the model to act as
# a transcription cleaner (remove fillers/false starts, preserve meaning and
# greetings, output only the rewritten text). Few-shot examples anchor the
# self-correction behavior.
SYSTEM_PROMPT = (
"You are an amanuensis. Rewrite the user's dictated text into clean, grammatical prose.\n\n"
"Rules:\n"
"- Remove filler words (um/uh/like), false starts, and self-corrections.\n"
"- Keep meaning, facts, and intent.\n"
"- Preserve greetings and salutations (Hey, Hi, Hey there, Hello).\n"
"- Prefer concise sentences.\n"
"- Do not add new info.\n"
"- Output ONLY the cleaned text, no commentary.\n\n"
"Examples:\n"
" - \"Hey, schedule that for 5 PM, I mean 4 PM\" -> \"Hey, schedule that for 4 PM\"\n"
" - \"let's ask Bob, I mean Janice, let's ask Janice\" -> \"let's ask Janice\"\n"
)
# Quantized Llama 3.2 3B Instruct model fetched from Hugging Face on first run.
MODEL_NAME = "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
MODEL_URL = (
"https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/"
"Llama-3.2-3B-Instruct-Q4_K_M.gguf"
)
# Model cache location under the user's home directory (~/.cache/lel/models).
MODEL_DIR = Path.home() / ".cache" / "lel" / "models"
MODEL_PATH = MODEL_DIR / MODEL_NAME
class LlamaProcessor:
    """Local llama.cpp-backed processor that rewrites dictated text into clean prose."""

    def __init__(self, verbose: bool = False) -> None:
        """Download the model if missing, configure logging, and load the model.

        Args:
            verbose: When False, llama.cpp log output is suppressed via both
                environment variables and a filtering log callback.
        """
        ensure_model()
        if not verbose:
            # Both spellings are set: different llama.cpp builds honor
            # different variable names.
            os.environ.setdefault("LLAMA_CPP_LOG_LEVEL", "ERROR")
            os.environ.setdefault("LLAMA_LOG_LEVEL", "ERROR")
        # Keep the callback referenced on self: llama_log_set stores a raw C
        # pointer, so without a Python-side reference the CFUNCTYPE object
        # would be garbage-collected and later log calls would crash.
        self._log_callback = _llama_log_callback_factory(verbose)
        llama_cpp_lib.llama_log_set(cast(Any, self._log_callback), ctypes.c_void_p())
        os.environ.setdefault("LLAMA_CPP_LOG_PREFIX", "llama")
        os.environ.setdefault("LLAMA_CPP_LOG_PREFIX_SEPARATOR", "::")
        self.client = Llama(
            model_path=str(MODEL_PATH),
            n_ctx=4096,
            verbose=verbose,
        )

    def process(self, text: str, lang: str = "en") -> str:
        """Return *text* cleaned up by the model.

        Args:
            text: The raw dictated text.
            lang: Language hint, prepended to the user message.

        Returns:
            The model's cleaned-up rewrite of *text*.
        """
        # Single interpolation replaces the original redundant
        # f"{text}" -> f"{lang}\n{...}" two-step.
        user_content = f"{lang}\n{text}"
        response = self.client.create_chat_completion(
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_content},
            ],
            temperature=0.0,  # deterministic output: no sampling variety
        )
        return _extract_chat_text(response)
def ensure_model() -> Path:
    """Ensure the GGUF model file exists locally, downloading it on first use.

    The download streams into a sibling ``.tmp`` file and is atomically
    renamed into place, so an interrupted download never leaves a partial
    model at MODEL_PATH.

    Returns:
        Path to the model file (fix: the original returned None after a
        fresh download but MODEL_PATH on a cache hit).

    Raises:
        Whatever network or filesystem error the download hits; the partial
        ``.tmp`` file is removed (best-effort) before re-raising.
    """
    if MODEL_PATH.exists():
        return MODEL_PATH
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    tmp_path = MODEL_PATH.with_suffix(MODEL_PATH.suffix + ".tmp")
    logging.info("downloading model: %s", MODEL_NAME)
    try:
        # timeout guards against a silently hung connection; streaming reads
        # below are unaffected as long as data keeps arriving.
        with urllib.request.urlopen(MODEL_URL, timeout=60) as resp:
            total = resp.getheader("Content-Length")
            total_size = int(total) if total else None
            downloaded = 0
            next_log = 0.0
            with open(tmp_path, "wb") as handle:
                while True:
                    chunk = resp.read(1024 * 1024)
                    if not chunk:
                        break
                    handle.write(chunk)
                    downloaded += len(chunk)
                    if total_size:
                        progress = downloaded / total_size
                        if progress >= next_log:
                            logging.info("model download %.0f%%", progress * 100)
                            # Advance past current progress (fix: a single
                            # += 0.1 lagged behind when one chunk jumped
                            # several deciles, causing bursty stale logs).
                            while next_log <= progress:
                                next_log += 0.1
                    elif downloaded // (50 * 1024 * 1024) > (downloaded - len(chunk)) // (50 * 1024 * 1024):
                        # No Content-Length header: log every 50 MB instead.
                        logging.info("model download %d MB", downloaded // (1024 * 1024))
        # Atomic rename: MODEL_PATH only ever appears fully written.
        os.replace(tmp_path, MODEL_PATH)
    except Exception:
        # Best-effort cleanup of the partial download, then re-raise.
        try:
            if tmp_path.exists():
                tmp_path.unlink()
        except Exception:
            pass
        raise
    return MODEL_PATH
def _extract_chat_text(payload: Any) -> str:
if "choices" in payload and payload["choices"]:
choice = payload["choices"][0]
msg = choice.get("message") or {}
content = msg.get("content") or choice.get("text")
if content is not None:
return str(content).strip()
raise RuntimeError("unexpected response format")
def _llama_log_callback_factory(verbose: bool) -> Callable:
callback_t = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p)
def raw_callback(_level, text, _user_data):
message = text.decode("utf-8", errors="ignore") if text else ""
if "n_ctx_per_seq" in message:
return
if not verbose:
return
sys.stderr.write(f"llama::{message}")
if message and not message.endswith("\n"):
sys.stderr.write("\n")
return callback_t(raw_callback)