"""Rewrite dictated text into clean prose with a locally cached Llama model.

The GGUF model file is downloaded on first use into ``MODEL_DIR`` and loaded
through the ``llama_cpp`` bindings.
"""

from __future__ import annotations

import ctypes
import logging
import os
import sys
import urllib.request
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, cast

from llama_cpp import Llama, llama_cpp as llama_cpp_lib  # type: ignore[import-not-found]

SYSTEM_PROMPT = (
    "You are an amanuensis. Rewrite the user's dictated text into clean, grammatical prose.\n\n"
    "Rules:\n"
    "- Remove filler words (um/uh/like), false starts, and self-corrections.\n"
    "- Keep meaning, facts, and intent.\n"
    "- Preserve greetings and salutations (Hey, Hi, Hey there, Hello).\n"
    "- Prefer concise sentences.\n"
    "- Do not add new info.\n"
    "- Output ONLY the cleaned text, no commentary.\n\n"
    "Examples:\n"
    " - \"Hey, schedule that for 5 PM, I mean 4 PM\" -> \"Hey, schedule that for 4 PM\"\n"
    " - \"let's ask Bob, I mean Janice, let's ask Janice\" -> \"let's ask Janice\"\n"
)

MODEL_NAME = "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
MODEL_URL = (
    "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/"
    "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
)
MODEL_DIR = Path.home() / ".cache" / "lel" / "models"
MODEL_PATH = MODEL_DIR / MODEL_NAME

# Download tuning knobs used by ensure_model().
_DOWNLOAD_CHUNK_BYTES = 1024 * 1024
_DOWNLOAD_TIMEOUT_S = 60.0  # per blocking socket operation, not for the whole transfer
_PROGRESS_STEP_BYTES = 50 * 1024 * 1024  # log cadence when the total size is unknown


class LlamaProcessor:
    """Clean up dictated text by sending it through a local Llama chat model."""

    def __init__(self, verbose: bool = False) -> None:
        """Make sure the model file exists, install log filtering, load the model.

        Args:
            verbose: When true, llama log lines are echoed to stderr and the
                ``Llama`` client is created with ``verbose=True``.
        """
        ensure_model()
        if not verbose:
            # setdefault leaves any value the user already exported untouched.
            os.environ.setdefault("LLAMA_CPP_LOG_LEVEL", "ERROR")
            os.environ.setdefault("LLAMA_LOG_LEVEL", "ERROR")
        # The ctypes callback object is stored on the instance so that it stays
        # referenced after being handed to the native library.
        self._log_callback = _llama_log_callback_factory(verbose)
        llama_cpp_lib.llama_log_set(cast(Any, self._log_callback), ctypes.c_void_p())
        os.environ.setdefault("LLAMA_CPP_LOG_PREFIX", "llama")
        os.environ.setdefault("LLAMA_CPP_LOG_PREFIX_SEPARATOR", "::")
        self.client = Llama(
            model_path=str(MODEL_PATH),
            n_ctx=4096,
            verbose=verbose,
        )

    def process(self, text: str, lang: str = "en") -> str:
        """Return the model's cleaned-up rewrite of ``text``.

        The user message sent to the model is ``lang`` on its own first line
        followed by ``text``.

        Raises:
            RuntimeError: If the completion payload has no usable text.
        """
        user_content = f"{lang}\n{text}"
        response = self.client.create_chat_completion(
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_content},
            ],
            temperature=0.0,
        )
        return _extract_chat_text(response)


def ensure_model() -> Path:
    """Download the model to ``MODEL_PATH`` unless it is already there.

    The payload is streamed into a ``.tmp`` sibling file and moved into place
    with ``os.replace`` only after the byte count has been verified, so an
    interrupted or truncated transfer never ends up at ``MODEL_PATH``.

    Returns:
        ``MODEL_PATH``, both when the file was already cached and after a
        fresh download.

    Raises:
        OSError: If the transfer yields no data or fewer/more bytes than the
            ``Content-Length`` header announced (network errors raised by
            ``urllib`` propagate unchanged).
    """
    if MODEL_PATH.exists():
        return MODEL_PATH

    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    tmp_path = MODEL_PATH.with_suffix(MODEL_PATH.suffix + ".tmp")
    logging.info("downloading model: %s", MODEL_NAME)
    try:
        with urllib.request.urlopen(MODEL_URL, timeout=_DOWNLOAD_TIMEOUT_S) as resp:
            total = resp.getheader("Content-Length")
            total_size = int(total) if total else None
            downloaded = 0
            next_decile = 0  # next 10% step that still has to be reported
            with open(tmp_path, "wb") as handle:
                for chunk in iter(lambda: resp.read(_DOWNLOAD_CHUNK_BYTES), b""):
                    handle.write(chunk)
                    downloaded += len(chunk)
                    if total_size:
                        # Integer deciles avoid float drift and log at most
                        # once per 10% even if one chunk crosses several steps.
                        decile = downloaded * 10 // total_size
                        if decile >= next_decile:
                            logging.info(
                                "model download %.0f%%", downloaded / total_size * 100
                            )
                            next_decile = decile + 1
                    elif (
                        downloaded // _PROGRESS_STEP_BYTES
                        > (downloaded - len(chunk)) // _PROGRESS_STEP_BYTES
                    ):
                        logging.info("model download %d MB", downloaded // (1024 * 1024))

        # A dropped connection makes read() return b"" early without raising,
        # so the size has to be checked explicitly before publishing the file.
        if downloaded == 0 or (total_size is not None and downloaded != total_size):
            expected = "unknown" if total_size is None else str(total_size)
            raise OSError(
                f"incomplete model download: got {downloaded} bytes, expected {expected}"
            )
        os.replace(tmp_path, MODEL_PATH)
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a partial file behind.
        # Cleanup is best-effort; the original error is what gets re-raised.
        try:
            if tmp_path.exists():
                tmp_path.unlink()
        except OSError:
            logging.debug("could not remove partial download: %s", tmp_path)
        raise
    return MODEL_PATH


def _extract_chat_text(payload: Any) -> str:
    """Pull the stripped text of the first choice out of a completion payload.

    Uses ``choices[0]["message"]["content"]`` and falls back to
    ``choices[0]["text"]`` when the former is missing or empty.

    Raises:
        RuntimeError: If there are no choices or neither field holds text.
    """
    if "choices" in payload and payload["choices"]:
        choice = payload["choices"][0]
        msg = choice.get("message") or {}
        content = msg.get("content") or choice.get("text")
        if content is not None:
            return str(content).strip()
    raise RuntimeError("unexpected response format")


def _llama_log_callback_factory(verbose: bool) -> Callable:
    """Build a ctypes log callback suitable for ``llama_log_set``.

    The callback drops every message containing ``n_ctx_per_seq``, drops
    everything when ``verbose`` is false, and otherwise writes the message to
    stderr prefixed with ``llama::`` and terminated by a newline.
    """
    callback_t = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p)

    def raw_callback(_level, text, _user_data):
        message = text.decode("utf-8", errors="ignore") if text else ""
        if "n_ctx_per_seq" in message:
            return
        if not verbose:
            return
        sys.stderr.write(f"llama::{message}")
        if message and not message.endswith("\n"):
            sys.stderr.write("\n")

    return callback_t(raw_callback)