from __future__ import annotations
|
|
|
|
import ctypes
|
|
import logging
|
|
import os
|
|
import sys
|
|
import urllib.request
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any, Callable, cast
|
|
|
|
from llama_cpp import Llama, llama_cpp as llama_cpp_lib # type: ignore[import-not-found]
|
|
|
|
|
|
# Instructions given to the model: rewrite dictated speech into clean prose
# without adding information. Few-shot examples pin the self-correction behavior.
SYSTEM_PROMPT = (
    "You are an amanuensis. Rewrite the user's dictated text into clean, grammatical prose.\n\n"
    "Rules:\n"
    "- Remove filler words (um/uh/like), false starts, and self-corrections.\n"
    "- Keep meaning, facts, and intent.\n"
    "- Preserve greetings and salutations (Hey, Hi, Hey there, Hello).\n"
    "- Prefer concise sentences.\n"
    "- Do not add new info.\n"
    "- Output ONLY the cleaned text, no commentary.\n\n"
    "Examples:\n"
    " - \"Hey, schedule that for 5 PM, I mean 4 PM\" -> \"Hey, schedule that for 4 PM\"\n"
    " - \"let's ask Bob, I mean Janice, let's ask Janice\" -> \"let's ask Janice\"\n"
)

# GGUF model fetched from Hugging Face on first use and cached locally.
MODEL_NAME = "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
MODEL_URL = (
    "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/"
    "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
)
# Cache location: ~/.cache/lel/models/<model file>
MODEL_DIR = Path.home() / ".cache" / "lel" / "models"
MODEL_PATH = MODEL_DIR / MODEL_NAME
|
|
|
|
|
|
class LlamaProcessor:
    """Runs a local llama.cpp model that rewrites dictated text into clean prose."""

    def __init__(self, verbose=False):
        """Ensure the model file exists locally, configure logging, and load it.

        Args:
            verbose: when False, llama.cpp log output is suppressed.
        """
        ensure_model()
        if not verbose:
            # Quiet both log-level env vars; different llama.cpp builds honor
            # one or the other.
            os.environ.setdefault("LLAMA_CPP_LOG_LEVEL", "ERROR")
            os.environ.setdefault("LLAMA_LOG_LEVEL", "ERROR")
        # Keep the ctypes callback alive on self: the native library holds a raw
        # pointer to it, and letting it be garbage-collected would crash.
        self._log_callback = _llama_log_callback_factory(verbose)
        llama_cpp_lib.llama_log_set(cast(Any, self._log_callback), ctypes.c_void_p())
        os.environ.setdefault("LLAMA_CPP_LOG_PREFIX", "llama")
        os.environ.setdefault("LLAMA_CPP_LOG_PREFIX_SEPARATOR", "::")
        self.client = Llama(
            model_path=str(MODEL_PATH),
            n_ctx=4096,
            verbose=verbose,
        )

    def process(self, text: str, lang: str = "en") -> str:
        """Rewrite *text* (a raw transcript) into clean prose.

        Args:
            text: dictated transcript to clean up.
            lang: language tag passed to the model alongside the transcript.

        Returns:
            The cleaned text produced by the model.
        """
        user_content = f"<language>{lang}</language>\n<transcript>{text}</transcript>"
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ]
        # temperature=0.0 keeps the rewrite deterministic.
        response = self.client.create_chat_completion(
            messages=messages,
            temperature=0.0,
        )
        return _extract_chat_text(response)
|
|
|
|
|
|
def ensure_model():
    """Download the GGUF model into the local cache if it is not already there.

    Returns:
        Path to the cached model file (both on cache hit and after a fresh
        download — previously the download path implicitly returned None).

    Raises:
        OSError / urllib.error.URLError: if the download or write fails.
        A partially written temp file is removed before re-raising.
    """
    if MODEL_PATH.exists():
        return MODEL_PATH
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    # Write to a temp file and atomically rename, so a crash mid-download
    # never leaves a truncated file at MODEL_PATH.
    tmp_path = MODEL_PATH.with_suffix(MODEL_PATH.suffix + ".tmp")
    logging.info("downloading model: %s", MODEL_NAME)
    try:
        # timeout guards each blocking socket operation, not the whole transfer.
        with urllib.request.urlopen(MODEL_URL, timeout=30) as resp:
            total = resp.getheader("Content-Length")
            total_size = int(total) if total else None
            downloaded = 0
            next_log = 0.0  # next fractional progress threshold to log
            with open(tmp_path, "wb") as handle:
                # Stream in 1 MiB chunks to keep memory bounded.
                while chunk := resp.read(1024 * 1024):
                    handle.write(chunk)
                    downloaded += len(chunk)
                    if total_size:
                        progress = downloaded / total_size
                        if progress >= next_log:
                            logging.info("model download %.0f%%", progress * 100)
                            next_log += 0.1
                    # No Content-Length: log every 50 MB boundary crossed instead.
                    elif downloaded // (50 * 1024 * 1024) > (downloaded - len(chunk)) // (50 * 1024 * 1024):
                        logging.info("model download %d MB", downloaded // (1024 * 1024))
        os.replace(tmp_path, MODEL_PATH)
        # Bug fix: return the path on the freshly-downloaded path too, matching
        # the cache-hit early return above.
        return MODEL_PATH
    except Exception:
        # Best-effort cleanup of the partial download; always re-raise.
        try:
            if tmp_path.exists():
                tmp_path.unlink()
        except OSError:
            pass
        raise
|
|
|
|
|
|
def _extract_chat_text(payload: Any) -> str:
|
|
if "choices" in payload and payload["choices"]:
|
|
choice = payload["choices"][0]
|
|
msg = choice.get("message") or {}
|
|
content = msg.get("content") or choice.get("text")
|
|
if content is not None:
|
|
return str(content).strip()
|
|
raise RuntimeError("unexpected response format")
|
|
|
|
|
|
def _llama_log_callback_factory(verbose: bool) -> Callable:
|
|
callback_t = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p)
|
|
|
|
def raw_callback(_level, text, _user_data):
|
|
message = text.decode("utf-8", errors="ignore") if text else ""
|
|
if "n_ctx_per_seq" in message:
|
|
return
|
|
if not verbose:
|
|
return
|
|
sys.stderr.write(f"llama::{message}")
|
|
if message and not message.endswith("\n"):
|
|
sys.stderr.write("\n")
|
|
|
|
return callback_t(raw_callback)
|