Filter llama context warning
This commit is contained in:
parent
42cf10cce3
commit
3ba696fb7d
1 changed file with 26 additions and 2 deletions
|
|
@@ -1,12 +1,15 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ctypes
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import urllib.request
|
import urllib.request
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable, cast
|
||||||
|
|
||||||
from llama_cpp import Llama # type: ignore[import-not-found]
|
from llama_cpp import Llama, llama_cpp as llama_cpp_lib # type: ignore[import-not-found]
|
||||||
|
|
||||||
|
|
||||||
SYSTEM_PROMPT = (
|
SYSTEM_PROMPT = (
|
||||||
|
|
@@ -42,6 +45,11 @@ class LLMConfig:
|
||||||
class LlamaProcessor:
|
class LlamaProcessor:
|
||||||
def __init__(self, cfg: LLMConfig):
|
def __init__(self, cfg: LLMConfig):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
|
if not cfg.verbose:
|
||||||
|
os.environ.setdefault("LLAMA_CPP_LOG_LEVEL", "ERROR")
|
||||||
|
os.environ.setdefault("LLAMA_LOG_LEVEL", "ERROR")
|
||||||
|
self._log_callback = _llama_log_callback_factory(cfg.verbose)
|
||||||
|
llama_cpp_lib.llama_log_set(cast(Any, self._log_callback), ctypes.c_void_p())
|
||||||
os.environ.setdefault("LLAMA_CPP_LOG_PREFIX", "llama")
|
os.environ.setdefault("LLAMA_CPP_LOG_PREFIX", "llama")
|
||||||
os.environ.setdefault("LLAMA_CPP_LOG_PREFIX_SEPARATOR", "::")
|
os.environ.setdefault("LLAMA_CPP_LOG_PREFIX_SEPARATOR", "::")
|
||||||
self.client = Llama(
|
self.client = Llama(
|
||||||
|
|
@@ -106,7 +114,7 @@ def ensure_model() -> Path:
|
||||||
return MODEL_PATH
|
return MODEL_PATH
|
||||||
|
|
||||||
|
|
||||||
def _extract_chat_text(payload: dict) -> str:
|
def _extract_chat_text(payload: Any) -> str:
|
||||||
if "choices" in payload and payload["choices"]:
|
if "choices" in payload and payload["choices"]:
|
||||||
choice = payload["choices"][0]
|
choice = payload["choices"][0]
|
||||||
msg = choice.get("message") or {}
|
msg = choice.get("message") or {}
|
||||||
|
|
@@ -114,3 +122,19 @@ def _extract_chat_text(payload: dict) -> str:
|
||||||
if content is not None:
|
if content is not None:
|
||||||
return str(content).strip()
|
return str(content).strip()
|
||||||
raise RuntimeError("unexpected response format")
|
raise RuntimeError("unexpected response format")
|
||||||
|
|
||||||
|
|
||||||
|
def _llama_log_callback_factory(verbose: bool) -> Callable:
|
||||||
|
callback_t = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p)
|
||||||
|
|
||||||
|
def raw_callback(_level, text, _user_data):
|
||||||
|
message = text.decode("utf-8", errors="ignore") if text else ""
|
||||||
|
if "n_ctx_per_seq" in message:
|
||||||
|
return
|
||||||
|
if not verbose:
|
||||||
|
return
|
||||||
|
sys.stderr.write(f"llama::{message}")
|
||||||
|
if message and not message.endswith("\n"):
|
||||||
|
sys.stderr.write("\n")
|
||||||
|
|
||||||
|
return callback_t(raw_callback)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue