diff --git a/src/ai_process.py b/src/ai_process.py
new file mode 100644
index 0000000..c2e9225
--- /dev/null
+++ b/src/ai_process.py
@@ -0,0 +1,63 @@
+#!/usr/bin/env python3
+import argparse
+import json
+import logging
+import sys
+from pathlib import Path
+
+from aiprocess import AIConfig, build_processor, load_system_prompt
+from config import load, redacted_dict
+
+
+def _read_text(arg_text: str) -> str:
+    if arg_text:
+        return arg_text
+    return sys.stdin.read()
+
+
+def main() -> int:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config", default="", help="path to config.json")
+    parser.add_argument("text", nargs="?", default="", help="text to process (or stdin)")
+    args = parser.parse_args()
+
+    logging.basicConfig(stream=sys.stderr, level=logging.INFO, format="ai: %(asctime)s %(message)s")
+    cfg = load(args.config)
+
+    logging.info(
+        "config (%s):\n%s",
+        args.config or str(Path.home() / ".config" / "lel" / "config.json"),
+        json.dumps(redacted_dict(cfg), indent=2),
+    )
+
+    if not cfg.ai_enabled:
+        logging.warning("ai_enabled is false; proceeding anyway")
+
+    prompt = load_system_prompt(cfg.ai_system_prompt_file)
+    logging.info("system prompt:\n%s", prompt)
+
+    processor = build_processor(
+        AIConfig(
+            model=cfg.ai_model,
+            temperature=cfg.ai_temperature,
+            system_prompt_file=cfg.ai_system_prompt_file,
+            base_url=cfg.ai_base_url,
+            api_key=cfg.ai_api_key,
+            timeout_sec=cfg.ai_timeout_sec,
+        )
+    )
+
+    text = _read_text(args.text).strip()
+    if not text:
+        logging.error("no input text provided")
+        return 2
+
+    output = processor.process(text)
+    sys.stdout.write(output)
+    if not output.endswith("\n"):
+        sys.stdout.write("\n")
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/src/aiprocess.py b/src/aiprocess.py
index b98fd11..e4af413 100644
--- a/src/aiprocess.py
+++ b/src/aiprocess.py
@@ -1,10 +1,10 @@
 from __future__ import annotations
 
+import json
+import urllib.request
 from dataclasses import dataclass
 from pathlib import Path
 
-import ollama
-
 
 def load_system_prompt(path: str | None) -> str:
     if path:
@@ -14,7 +14,6 @@ def load_system_prompt(path: str | None) -> str:
 
 @dataclass
 class AIConfig:
-    provider: str
     model: str
     temperature: float
     system_prompt_file: str
@@ -23,24 +22,46 @@ class AIConfig:
     timeout_sec: int
 
 
-class OllamaProcessor:
+class GenericAPIProcessor:
     def __init__(self, cfg: AIConfig):
         self.cfg = cfg
         self.system = load_system_prompt(cfg.system_prompt_file)
-        self.client = ollama.Client(host=cfg.base_url)
 
     def process(self, text: str) -> str:
-        resp = self.client.generate(
-            model=self.cfg.model,
-            prompt=text,
-            system=self.system,
-            options={"temperature": self.cfg.temperature},
-        )
-        return (resp.get("response") or "").strip()
+        payload = {
+            "model": self.cfg.model,
+            "messages": [
+                {"role": "system", "content": self.system},
+                {"role": "user", "content": f"{text}"},
+            ],
+            "temperature": self.cfg.temperature,
+        }
+        data = json.dumps(payload).encode("utf-8")
+        req = urllib.request.Request(self.cfg.base_url, data=data, method="POST")
+        req.add_header("Content-Type", "application/json")
+        if self.cfg.api_key:
+            req.add_header("Authorization", f"Bearer {self.cfg.api_key}")
+
+        with urllib.request.urlopen(req, timeout=self.cfg.timeout_sec) as resp:
+            body = resp.read()
+
+        out = json.loads(body.decode("utf-8"))
+
+        if isinstance(out, dict):
+            if "output" in out:
+                return str(out["output"]).strip()
+            if "response" in out:
+                return str(out["response"]).strip()
+            if "choices" in out and out["choices"]:
+                choice = out["choices"][0]
+                msg = choice.get("message") or {}
+                content = msg.get("content") or choice.get("text")
+                if content is not None:
+                    return str(content).strip()
+        raise RuntimeError("unexpected response format")
 
 
-def build_processor(cfg: AIConfig) -> OllamaProcessor:
-    provider = cfg.provider.strip().lower()
-    if provider != "ollama":
-        raise ValueError(f"unsupported ai provider: {cfg.provider}")
-    return OllamaProcessor(cfg)
+def build_processor(cfg: AIConfig) -> GenericAPIProcessor:
+    if not cfg.base_url:
+        raise ValueError("ai_base_url is required for generic API")
+    return GenericAPIProcessor(cfg)