247 lines
5.6 KiB
Go
247 lines
5.6 KiB
Go
package aiprocess
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
_ "embed"
|
|
)
|
|
|
|
// Config describes the AI post-processing backend that New should build.
//
// Only Enabled is consulted when AI processing is switched off; every other
// field is read by New when Enabled is true.
type Config struct {
	// Enabled turns AI processing on. When false, New returns (nil, nil).
	Enabled bool

	// Provider selects the backend: "ollama" or "openai_compat". It is
	// matched case-insensitively after trimming whitespace, and is required
	// when Enabled is true.
	Provider string

	// Model is the model name sent with each request. Required when Enabled
	// is true.
	Model string

	// Temperature is the sampling temperature. Zero is treated as "not set":
	// it is left out of the request and the server's default applies.
	Temperature float64

	// SystemPromptFile is a path to a file whose trimmed contents become the
	// system prompt. When blank, the built-in default prompt is used.
	SystemPromptFile string

	// BaseURL is the server root; trailing slashes are stripped. For
	// "ollama" it defaults to http://localhost:11434; for "openai_compat" it
	// is required.
	BaseURL string

	// APIKey, when non-blank, is sent as a Bearer token. Only the
	// "openai_compat" provider uses it.
	APIKey string

	// TimeoutSec is the per-request HTTP timeout in seconds; values <= 0
	// select 20 seconds.
	TimeoutSec int
}
|
|
|
|
// Processor turns a piece of input text into a model-generated response.
type Processor interface {
	// Process submits input to the backing model and returns the text it
	// produced. The implementations in this package attach ctx to the
	// outgoing HTTP request, so cancelling ctx aborts the call.
	Process(ctx context.Context, input string) (string, error)
}
|
|
|
|
func New(cfg Config) (Processor, error) {
|
|
if !cfg.Enabled {
|
|
return nil, nil
|
|
}
|
|
|
|
provider := strings.ToLower(strings.TrimSpace(cfg.Provider))
|
|
if provider == "" {
|
|
return nil, errors.New("ai provider is required when enabled")
|
|
}
|
|
if strings.TrimSpace(cfg.Model) == "" {
|
|
return nil, errors.New("ai model is required when enabled")
|
|
}
|
|
|
|
systemPrompt, err := loadSystemPrompt(cfg.SystemPromptFile)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
timeout := time.Duration(cfg.TimeoutSec) * time.Second
|
|
if timeout <= 0 {
|
|
timeout = 20 * time.Second
|
|
}
|
|
|
|
switch provider {
|
|
case "ollama":
|
|
base := strings.TrimRight(cfg.BaseURL, "/")
|
|
if base == "" {
|
|
base = "http://localhost:11434"
|
|
}
|
|
return &ollamaProcessor{
|
|
client: &http.Client{Timeout: timeout},
|
|
baseURL: base,
|
|
model: cfg.Model,
|
|
temperature: cfg.Temperature,
|
|
system: systemPrompt,
|
|
}, nil
|
|
case "openai_compat":
|
|
base := strings.TrimRight(cfg.BaseURL, "/")
|
|
if base == "" {
|
|
return nil, errors.New("ai base_url is required for openai_compat")
|
|
}
|
|
return &openAICompatProcessor{
|
|
client: &http.Client{Timeout: timeout},
|
|
baseURL: base,
|
|
apiKey: cfg.APIKey,
|
|
model: cfg.Model,
|
|
temperature: cfg.Temperature,
|
|
system: systemPrompt,
|
|
}, nil
|
|
default:
|
|
return nil, fmt.Errorf("unknown ai provider %q", provider)
|
|
}
|
|
}
|
|
|
|
func loadSystemPrompt(path string) (string, error) {
|
|
if strings.TrimSpace(path) == "" {
|
|
return strings.TrimSpace(defaultSystemPrompt), nil
|
|
}
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return "", fmt.Errorf("read system prompt file: %w", err)
|
|
}
|
|
return strings.TrimSpace(string(data)), nil
|
|
}
|
|
|
|
// defaultSystemPrompt holds the contents of system_prompt.txt, embedded into
// the binary at build time. loadSystemPrompt falls back to it (trimmed) when
// no system prompt file path is configured.
//
//go:embed system_prompt.txt
var defaultSystemPrompt string
|
|
|
|
// ollamaProcessor is the Processor backed by an Ollama server's
// /api/generate endpoint.
type ollamaProcessor struct {
	client      *http.Client // carries the request timeout chosen in New
	baseURL     string       // server root, trailing slashes removed
	model       string       // model name sent with every request
	temperature float64      // 0 leaves the server's default in effect
	system      string       // system prompt; empty sends none
}
|
|
|
|
func (p *ollamaProcessor) Process(ctx context.Context, input string) (string, error) {
|
|
reqBody := ollamaRequest{
|
|
Model: p.model,
|
|
Prompt: input,
|
|
Stream: false,
|
|
}
|
|
if p.system != "" {
|
|
reqBody.System = p.system
|
|
}
|
|
if p.temperature != 0 {
|
|
reqBody.Options = &ollamaOptions{Temperature: p.temperature}
|
|
}
|
|
|
|
payload, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/api/generate", bytes.NewReader(payload))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := p.client.Do(req)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return "", fmt.Errorf("ollama request failed: %s", readErrorBody(resp.Body))
|
|
}
|
|
|
|
var out ollamaResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
|
return "", err
|
|
}
|
|
return strings.TrimSpace(out.Response), nil
|
|
}
|
|
|
|
// JSON wire types for Ollama's POST /api/generate endpoint.
type (
	// ollamaRequest is the request body. Field order here is the order the
	// keys appear in the encoded JSON.
	ollamaRequest struct {
		Model   string         `json:"model"`
		Prompt  string         `json:"prompt"`
		System  string         `json:"system,omitempty"`
		Stream  bool           `json:"stream"`
		Options *ollamaOptions `json:"options,omitempty"`
	}

	// ollamaOptions carries optional sampling parameters. A nil
	// *ollamaOptions omits the "options" object entirely.
	ollamaOptions struct {
		Temperature float64 `json:"temperature,omitempty"`
	}

	// ollamaResponse captures the only reply field that is read: the
	// generated text.
	ollamaResponse struct {
		Response string `json:"response"`
	}
)
|
|
|
|
// openAICompatProcessor is the Processor backed by any server exposing an
// OpenAI-style /v1/chat/completions endpoint.
type openAICompatProcessor struct {
	client      *http.Client // carries the request timeout chosen in New
	baseURL     string       // server root, trailing slashes removed
	apiKey      string       // Bearer token; blank sends no Authorization
	model       string       // model name sent with every request
	temperature float64      // 0 is omitted from the request body
	system      string       // system prompt; empty sends no system message
}
|
|
|
|
func (p *openAICompatProcessor) Process(ctx context.Context, input string) (string, error) {
|
|
messages := []openAIMessage{
|
|
{Role: "user", Content: input},
|
|
}
|
|
if p.system != "" {
|
|
messages = append([]openAIMessage{{Role: "system", Content: p.system}}, messages...)
|
|
}
|
|
|
|
reqBody := openAIRequest{
|
|
Model: p.model,
|
|
Messages: messages,
|
|
Temperature: p.temperature,
|
|
}
|
|
|
|
payload, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/v1/chat/completions", bytes.NewReader(payload))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
if strings.TrimSpace(p.apiKey) != "" {
|
|
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
|
}
|
|
|
|
resp, err := p.client.Do(req)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return "", fmt.Errorf("openai_compat request failed: %s", readErrorBody(resp.Body))
|
|
}
|
|
|
|
var out openAIResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
|
return "", err
|
|
}
|
|
if len(out.Choices) == 0 {
|
|
return "", errors.New("openai_compat response missing choices")
|
|
}
|
|
return strings.TrimSpace(out.Choices[0].Message.Content), nil
|
|
}
|
|
|
|
// JSON wire types for an OpenAI-style chat completions endpoint.
type (
	// openAIRequest is the request body. A zero Temperature is omitted so
	// the server's default applies.
	openAIRequest struct {
		Model       string          `json:"model"`
		Messages    []openAIMessage `json:"messages"`
		Temperature float64         `json:"temperature,omitempty"`
	}

	// openAIMessage is one chat turn, used both in requests and in replies.
	openAIMessage struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	// openAIResponse captures the only reply field that is read: the list
	// of generated choices.
	openAIResponse struct {
		Choices []openAIChoice `json:"choices"`
	}

	// openAIChoice wraps a single generated message.
	openAIChoice struct {
		Message openAIMessage `json:"message"`
	}
)
|
|
|
|
// readErrorBody returns a short, human-readable description of an HTTP error
// response body, for inclusion in an error message.
//
// At most 64 KiB are read, so an oversized or endless body cannot exhaust
// memory, and surrounding whitespace is trimmed. The function never returns
// an empty string:
//   - an empty body yields "empty response body";
//   - if reading fails part-way, whatever text was read before the failure is
//     returned;
//   - "unknown error" is returned only when the read fails and no text was
//     obtained.
func readErrorBody(r io.Reader) string {
	const maxErrorBody = 64 * 1024
	data, err := io.ReadAll(io.LimitReader(r, maxErrorBody))
	text := strings.TrimSpace(string(data))
	switch {
	case text != "":
		return text
	case err != nil:
		return "unknown error"
	default:
		return "empty response body"
	}
}
|