Add AI post-processing prompt
This commit is contained in:
parent
b296491703
commit
cda89923ce
4 changed files with 269 additions and 1 deletions
247
internal/aiprocess/aiprocess.go
Normal file
247
internal/aiprocess/aiprocess.go
Normal file
|
|
@ -0,0 +1,247 @@
|
|||
package aiprocess
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Enabled bool
|
||||
Provider string
|
||||
Model string
|
||||
Temperature float64
|
||||
SystemPromptFile string
|
||||
BaseURL string
|
||||
APIKey string
|
||||
TimeoutSec int
|
||||
}
|
||||
|
||||
type Processor interface {
|
||||
Process(ctx context.Context, input string) (string, error)
|
||||
}
|
||||
|
||||
func New(cfg Config) (Processor, error) {
|
||||
if !cfg.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
provider := strings.ToLower(strings.TrimSpace(cfg.Provider))
|
||||
if provider == "" {
|
||||
return nil, errors.New("ai provider is required when enabled")
|
||||
}
|
||||
if strings.TrimSpace(cfg.Model) == "" {
|
||||
return nil, errors.New("ai model is required when enabled")
|
||||
}
|
||||
|
||||
systemPrompt, err := loadSystemPrompt(cfg.SystemPromptFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
timeout := time.Duration(cfg.TimeoutSec) * time.Second
|
||||
if timeout <= 0 {
|
||||
timeout = 20 * time.Second
|
||||
}
|
||||
|
||||
switch provider {
|
||||
case "ollama":
|
||||
base := strings.TrimRight(cfg.BaseURL, "/")
|
||||
if base == "" {
|
||||
base = "http://localhost:11434"
|
||||
}
|
||||
return &ollamaProcessor{
|
||||
client: &http.Client{Timeout: timeout},
|
||||
baseURL: base,
|
||||
model: cfg.Model,
|
||||
temperature: cfg.Temperature,
|
||||
system: systemPrompt,
|
||||
}, nil
|
||||
case "openai_compat":
|
||||
base := strings.TrimRight(cfg.BaseURL, "/")
|
||||
if base == "" {
|
||||
return nil, errors.New("ai base_url is required for openai_compat")
|
||||
}
|
||||
return &openAICompatProcessor{
|
||||
client: &http.Client{Timeout: timeout},
|
||||
baseURL: base,
|
||||
apiKey: cfg.APIKey,
|
||||
model: cfg.Model,
|
||||
temperature: cfg.Temperature,
|
||||
system: systemPrompt,
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown ai provider %q", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func loadSystemPrompt(path string) (string, error) {
|
||||
if strings.TrimSpace(path) == "" {
|
||||
return strings.TrimSpace(defaultSystemPrompt), nil
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read system prompt file: %w", err)
|
||||
}
|
||||
return strings.TrimSpace(string(data)), nil
|
||||
}
|
||||
|
||||
//go:embed system_prompt.txt
|
||||
var defaultSystemPrompt string
|
||||
|
||||
type ollamaProcessor struct {
|
||||
client *http.Client
|
||||
baseURL string
|
||||
model string
|
||||
temperature float64
|
||||
system string
|
||||
}
|
||||
|
||||
func (p *ollamaProcessor) Process(ctx context.Context, input string) (string, error) {
|
||||
reqBody := ollamaRequest{
|
||||
Model: p.model,
|
||||
Prompt: input,
|
||||
Stream: false,
|
||||
}
|
||||
if p.system != "" {
|
||||
reqBody.System = p.system
|
||||
}
|
||||
if p.temperature != 0 {
|
||||
reqBody.Options = &ollamaOptions{Temperature: p.temperature}
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/api/generate", bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", fmt.Errorf("ollama request failed: %s", readErrorBody(resp.Body))
|
||||
}
|
||||
|
||||
var out ollamaResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(out.Response), nil
|
||||
}
|
||||
|
||||
type ollamaRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
System string `json:"system,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
Options *ollamaOptions `json:"options,omitempty"`
|
||||
}
|
||||
|
||||
type ollamaOptions struct {
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
}
|
||||
|
||||
type ollamaResponse struct {
|
||||
Response string `json:"response"`
|
||||
}
|
||||
|
||||
type openAICompatProcessor struct {
|
||||
client *http.Client
|
||||
baseURL string
|
||||
apiKey string
|
||||
model string
|
||||
temperature float64
|
||||
system string
|
||||
}
|
||||
|
||||
func (p *openAICompatProcessor) Process(ctx context.Context, input string) (string, error) {
|
||||
messages := []openAIMessage{
|
||||
{Role: "user", Content: input},
|
||||
}
|
||||
if p.system != "" {
|
||||
messages = append([]openAIMessage{{Role: "system", Content: p.system}}, messages...)
|
||||
}
|
||||
|
||||
reqBody := openAIRequest{
|
||||
Model: p.model,
|
||||
Messages: messages,
|
||||
Temperature: p.temperature,
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/v1/chat/completions", bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if strings.TrimSpace(p.apiKey) != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
}
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", fmt.Errorf("openai_compat request failed: %s", readErrorBody(resp.Body))
|
||||
}
|
||||
|
||||
var out openAIResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(out.Choices) == 0 {
|
||||
return "", errors.New("openai_compat response missing choices")
|
||||
}
|
||||
return strings.TrimSpace(out.Choices[0].Message.Content), nil
|
||||
}
|
||||
|
||||
type openAIRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []openAIMessage `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
}
|
||||
|
||||
type openAIMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type openAIResponse struct {
|
||||
Choices []openAIChoice `json:"choices"`
|
||||
}
|
||||
|
||||
type openAIChoice struct {
|
||||
Message openAIMessage `json:"message"`
|
||||
}
|
||||
|
||||
func readErrorBody(r io.Reader) string {
|
||||
data, err := io.ReadAll(io.LimitReader(r, 64*1024))
|
||||
if err != nil {
|
||||
return "unknown error"
|
||||
}
|
||||
return strings.TrimSpace(string(data))
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue