Add AI processing CLI
This commit is contained in:
parent
9fc74ad27b
commit
2b6a565d85
2 changed files with 101 additions and 17 deletions
63
src/ai_process.py
Normal file
63
src/ai_process.py
Normal file
|
|
@ -0,0 +1,63 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from aiprocess import AIConfig, build_processor, load_system_prompt
|
||||||
|
from config import load, redacted_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _read_text(arg_text: str) -> str:
|
||||||
|
if arg_text:
|
||||||
|
return arg_text
|
||||||
|
return sys.stdin.read()
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
    """CLI entry point: read input text, run it through the AI processor,
    and write the result to stdout.

    Returns:
        0 on success, 2 when no input text was supplied.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--config", default="", help="path to config.json")
    arg_parser.add_argument("text", nargs="?", default="", help="text to process (or stdin)")
    opts = arg_parser.parse_args()

    # Log to stderr so stdout stays clean for the processed output.
    logging.basicConfig(stream=sys.stderr, level=logging.INFO, format="ai: %(asctime)s %(message)s")

    cfg = load(opts.config)
    default_path = str(Path.home() / ".config" / "lel" / "config.json")
    # redacted_dict strips secrets before the config is logged.
    logging.info(
        "config (%s):\n%s",
        opts.config or default_path,
        json.dumps(redacted_dict(cfg), indent=2),
    )

    if not cfg.ai_enabled:
        logging.warning("ai_enabled is false; proceeding anyway")

    system_prompt = load_system_prompt(cfg.ai_system_prompt_file)
    logging.info("system prompt:\n%s", system_prompt)

    ai_cfg = AIConfig(
        model=cfg.ai_model,
        temperature=cfg.ai_temperature,
        system_prompt_file=cfg.ai_system_prompt_file,
        base_url=cfg.ai_base_url,
        api_key=cfg.ai_api_key,
        timeout_sec=cfg.ai_timeout_sec,
    )
    processor = build_processor(ai_cfg)

    input_text = _read_text(opts.text).strip()
    if not input_text:
        logging.error("no input text provided")
        return 2

    result = processor.process(input_text)
    # Append a newline only when the model output lacks one.
    sys.stdout.write(result if result.endswith("\n") else result + "\n")
    return 0
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: propagate main()'s integer result as the process
# exit code via SystemExit.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import urllib.request
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import ollama
|
|
||||||
|
|
||||||
|
|
||||||
def load_system_prompt(path: str | None) -> str:
|
def load_system_prompt(path: str | None) -> str:
|
||||||
if path:
|
if path:
|
||||||
|
|
@ -14,7 +14,6 @@ def load_system_prompt(path: str | None) -> str:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AIConfig:
|
class AIConfig:
|
||||||
provider: str
|
|
||||||
model: str
|
model: str
|
||||||
temperature: float
|
temperature: float
|
||||||
system_prompt_file: str
|
system_prompt_file: str
|
||||||
|
|
@ -23,24 +22,46 @@ class AIConfig:
|
||||||
timeout_sec: int
|
timeout_sec: int
|
||||||
|
|
||||||
|
|
||||||
class GenericAPIProcessor:
    """Processes text by POSTing a chat-style JSON payload to an HTTP API.

    The endpoint URL is taken verbatim from ``cfg.base_url``; the request
    is made with the standard-library ``urllib`` (no SDK dependency).
    The response parser accepts several common reply shapes — see
    :meth:`process`.
    """

    def __init__(self, cfg: AIConfig):
        self.cfg = cfg
        # System prompt is read from disk once, at construction time.
        self.system = load_system_prompt(cfg.system_prompt_file)

    def process(self, text: str) -> str:
        """Send *text* to the configured API and return the model's reply.

        Raises:
            RuntimeError: if the JSON response matches none of the
                recognized shapes.
            urllib.error.URLError / HTTPError: network or HTTP failures
                propagate to the caller unchanged.
        """
        payload = {
            "model": self.cfg.model,
            "messages": [
                {"role": "system", "content": self.system},
                # User input is wrapped in <transcript> tags so the system
                # prompt can refer to it unambiguously.
                {"role": "user", "content": f"<transcript>{text}</transcript>"},
            ],
            "temperature": self.cfg.temperature,
        }
        data = json.dumps(payload).encode("utf-8")
        req = urllib.request.Request(self.cfg.base_url, data=data, method="POST")
        req.add_header("Content-Type", "application/json")
        # Bearer auth is optional; omitted entirely when no key is set.
        if self.cfg.api_key:
            req.add_header("Authorization", f"Bearer {self.cfg.api_key}")

        with urllib.request.urlopen(req, timeout=self.cfg.timeout_sec) as resp:
            body = resp.read()

        # NOTE(review): assumes the response body is UTF-8 JSON — confirm
        # against the target API.
        out = json.loads(body.decode("utf-8"))

        # Accept three reply shapes, in priority order:
        #   {"output": ...}              — simple wrapper APIs
        #   {"response": ...}            — Ollama-style generate replies
        #   {"choices": [{"message": {"content": ...}} | {"text": ...}]}
        #                                — OpenAI-style chat completions
        if isinstance(out, dict):
            if "output" in out:
                return str(out["output"]).strip()
            if "response" in out:
                return str(out["response"]).strip()
            if "choices" in out and out["choices"]:
                choice = out["choices"][0]
                msg = choice.get("message") or {}
                content = msg.get("content") or choice.get("text")
                if content is not None:
                    return str(content).strip()
        raise RuntimeError("unexpected response format")
|
||||||
|
|
||||||
|
|
||||||
def build_processor(cfg: AIConfig) -> GenericAPIProcessor:
    """Validate the config and return a ready-to-use processor.

    Raises:
        ValueError: if no API endpoint URL is configured.
    """
    if cfg.base_url:
        return GenericAPIProcessor(cfg)
    raise ValueError("ai_base_url is required for generic API")
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue