Add AI processing CLI
This commit is contained in:
parent
9fc74ad27b
commit
2b6a565d85
2 changed files with 101 additions and 17 deletions
|
|
@ -1,10 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import urllib.request
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import ollama
|
||||
|
||||
|
||||
def load_system_prompt(path: str | None) -> str:
|
||||
if path:
|
||||
|
|
@ -14,7 +14,6 @@ def load_system_prompt(path: str | None) -> str:
|
|||
|
||||
@dataclass
|
||||
class AIConfig:
|
||||
provider: str
|
||||
model: str
|
||||
temperature: float
|
||||
system_prompt_file: str
|
||||
|
|
@ -23,24 +22,46 @@ class AIConfig:
|
|||
timeout_sec: int
|
||||
|
||||
|
||||
class GenericAPIProcessor:
    """Send text to an OpenAI-style chat-completions HTTP endpoint.

    The endpoint URL comes from ``cfg.base_url``; a bearer token
    (``cfg.api_key``) is attached when present.  Responses are accepted in
    several common JSON shapes: ``{"output": ...}``, ``{"response": ...}``,
    or an OpenAI-style ``{"choices": [...]}`` payload.
    """

    def __init__(self, cfg: AIConfig):
        self.cfg = cfg
        # Resolve the system prompt once at construction time so every
        # process() call reuses the same prompt text.
        self.system = load_system_prompt(cfg.system_prompt_file)

    def process(self, text: str) -> str:
        """POST *text* to the configured endpoint and return the model reply.

        Raises:
            RuntimeError: if the JSON response matches no known shape.
            urllib.error.URLError: on network failure or timeout.
        """
        payload = {
            "model": self.cfg.model,
            "messages": [
                {"role": "system", "content": self.system},
                # NOTE(review): the <transcript> wrapper tags look like part
                # of the prompt contract — kept byte-identical.
                {"role": "user", "content": f"<transcript>{text}</transcript>"},
            ],
            "temperature": self.cfg.temperature,
        }
        data = json.dumps(payload).encode("utf-8")
        req = urllib.request.Request(self.cfg.base_url, data=data, method="POST")
        req.add_header("Content-Type", "application/json")
        if self.cfg.api_key:
            req.add_header("Authorization", f"Bearer {self.cfg.api_key}")

        with urllib.request.urlopen(req, timeout=self.cfg.timeout_sec) as resp:
            body = resp.read()

        out = json.loads(body.decode("utf-8"))
        return self._extract_text(out)

    @staticmethod
    def _extract_text(out: object) -> str:
        """Pull the reply text out of a decoded JSON response.

        Fix: parsing now guarantees that every unrecognized shape raises —
        previously a non-dict response (e.g. a bare JSON list) could fall
        through the ``isinstance`` guard and return ``None`` implicitly,
        breaking the declared ``-> str`` contract of ``process``.
        """
        if isinstance(out, dict):
            if "output" in out:
                return str(out["output"]).strip()
            if "response" in out:
                return str(out["response"]).strip()
            if "choices" in out and out["choices"]:
                choice = out["choices"][0]
                msg = choice.get("message") or {}
                content = msg.get("content") or choice.get("text")
                if content is not None:
                    return str(content).strip()
        raise RuntimeError("unexpected response format")
|
||||
|
||||
|
||||
def build_processor(cfg: AIConfig) -> OllamaProcessor:
    """Create the processor for *cfg*, accepting only the "ollama" provider.

    The provider string is compared case-insensitively, ignoring
    surrounding whitespace.

    Raises:
        ValueError: if ``cfg.provider`` names any other provider.
    """
    if cfg.provider.strip().lower() == "ollama":
        return OllamaProcessor(cfg)
    raise ValueError(f"unsupported ai provider: {cfg.provider}")
|
||||
def build_processor(cfg: AIConfig) -> GenericAPIProcessor:
    """Validate *cfg* and return a :class:`GenericAPIProcessor` for it.

    Raises:
        ValueError: when ``cfg.base_url`` is empty or unset, since the
            generic API backend has no default endpoint to fall back to.
    """
    if cfg.base_url:
        return GenericAPIProcessor(cfg)
    raise ValueError("ai_base_url is required for generic API")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue