// aman/internal/aiprocess/aiprocess.go
package aiprocess
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"strings"
"time"
_ "embed"
)
// Config describes whether and how AI post-processing is performed.
type Config struct {
	Enabled          bool    // when false, New returns (nil, nil) and processing is skipped
	Provider         string  // backend selector: "ollama" or "openai_compat" (case-insensitive)
	Model            string  // model identifier sent to the provider; required when Enabled
	Temperature      float64 // sampling temperature; 0 means "omit and use the provider default"
	SystemPromptFile string  // optional path to a system prompt file; empty uses the embedded default
	BaseURL          string  // provider endpoint; optional for ollama (defaults to localhost), required for openai_compat
	APIKey           string  // optional bearer token, only used by openai_compat
	TimeoutSec       int     // HTTP client timeout in seconds; <= 0 falls back to 20s
}
// Processor rewrites a piece of text through a configured AI backend.
type Processor interface {
	// Process sends input to the backend and returns its (trimmed)
	// response. Cancellation and deadlines are honored via ctx.
	Process(ctx context.Context, input string) (string, error)
}
// New builds a Processor from cfg.
//
// When cfg.Enabled is false it returns (nil, nil): processing is simply
// switched off and callers must treat a nil Processor as "skip".
// Otherwise it validates provider and model, loads the system prompt
// (falling back to the embedded default), and returns an HTTP-backed
// processor whose client timeout defaults to 20s when TimeoutSec <= 0.
func New(cfg Config) (Processor, error) {
	if !cfg.Enabled {
		return nil, nil
	}
	provider := strings.ToLower(strings.TrimSpace(cfg.Provider))
	if provider == "" {
		return nil, errors.New("ai provider is required when enabled")
	}
	// Normalize the model once so providers receive the trimmed value;
	// previously the trimmed model was validated but the raw value sent.
	model := strings.TrimSpace(cfg.Model)
	if model == "" {
		return nil, errors.New("ai model is required when enabled")
	}
	systemPrompt, err := loadSystemPrompt(cfg.SystemPromptFile)
	if err != nil {
		return nil, err
	}
	timeout := time.Duration(cfg.TimeoutSec) * time.Second
	if timeout <= 0 {
		timeout = 20 * time.Second // default when unset or negative
	}
	// Strip surrounding whitespace before removing trailing slashes so
	// joins like base+"/api/generate" never produce "//" or stray spaces.
	base := strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/")
	switch provider {
	case "ollama":
		if base == "" {
			base = "http://localhost:11434" // Ollama's default listen address
		}
		return &ollamaProcessor{
			client:      &http.Client{Timeout: timeout},
			baseURL:     base,
			model:       model,
			temperature: cfg.Temperature,
			system:      systemPrompt,
		}, nil
	case "openai_compat":
		if base == "" {
			return nil, errors.New("ai base_url is required for openai_compat")
		}
		return &openAICompatProcessor{
			client:      &http.Client{Timeout: timeout},
			baseURL:     base,
			apiKey:      cfg.APIKey,
			model:       model,
			temperature: cfg.Temperature,
			system:      systemPrompt,
		}, nil
	default:
		return nil, fmt.Errorf("unknown ai provider %q", provider)
	}
}
// loadSystemPrompt returns the system prompt to use: the contents of
// path when one is given, otherwise the embedded default. The result is
// whitespace-trimmed in both cases.
func loadSystemPrompt(path string) (string, error) {
	if strings.TrimSpace(path) != "" {
		data, err := os.ReadFile(path)
		if err != nil {
			return "", fmt.Errorf("read system prompt file: %w", err)
		}
		return strings.TrimSpace(string(data)), nil
	}
	return strings.TrimSpace(defaultSystemPrompt), nil
}
// defaultSystemPrompt is the fallback system prompt compiled into the
// binary from system_prompt.txt; it is used when Config.SystemPromptFile
// is empty. The //go:embed directive must stay immediately above the var.
//
//go:embed system_prompt.txt
var defaultSystemPrompt string
// ollamaProcessor implements Processor against an Ollama server's
// /api/generate endpoint.
type ollamaProcessor struct {
	client      *http.Client // pre-configured with the request timeout
	baseURL     string       // server root, no trailing slash
	model       string
	temperature float64 // 0 means "do not send an options object"
	system      string  // optional system prompt; empty means none
}
// Process sends input as a single non-streaming /api/generate request to
// the configured Ollama server and returns the trimmed model response.
func (p *ollamaProcessor) Process(ctx context.Context, input string) (string, error) {
	body := ollamaRequest{
		Model:  p.model,
		Prompt: input,
		Stream: false,
	}
	if p.system != "" {
		body.System = p.system
	}
	if p.temperature != 0 {
		body.Options = &ollamaOptions{Temperature: p.temperature}
	}
	encoded, err := json.Marshal(body)
	if err != nil {
		return "", err
	}
	url := p.baseURL + "/api/generate"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(encoded))
	if err != nil {
		return "", err
	}
	req.Header.Set("Content-Type", "application/json")
	resp, err := p.client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	if code := resp.StatusCode; code < 200 || code >= 300 {
		return "", fmt.Errorf("ollama request failed: %s", readErrorBody(resp.Body))
	}
	var decoded ollamaResponse
	if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil {
		return "", err
	}
	return strings.TrimSpace(decoded.Response), nil
}
// ollamaRequest is the JSON payload for Ollama's /api/generate endpoint.
type ollamaRequest struct {
	Model   string         `json:"model"`
	Prompt  string         `json:"prompt"`
	System  string         `json:"system,omitempty"`
	Stream  bool           `json:"stream"` // always false here: one complete response
	Options *ollamaOptions `json:"options,omitempty"`
}
// ollamaOptions carries the subset of Ollama generation options we set.
type ollamaOptions struct {
	Temperature float64 `json:"temperature,omitempty"`
}
// ollamaResponse is the subset of Ollama's /api/generate reply we decode.
type ollamaResponse struct {
	Response string `json:"response"`
}
// openAICompatProcessor implements Processor against any server exposing
// an OpenAI-compatible /v1/chat/completions endpoint.
type openAICompatProcessor struct {
	client      *http.Client // pre-configured with the request timeout
	baseURL     string       // server root, no trailing slash
	apiKey      string       // optional; sent as a Bearer token when non-blank
	model       string
	temperature float64 // 0 is omitted from the request via omitempty
	system      string  // optional system prompt; empty means no system message
}
// Process sends input as a chat-completion request (optional system
// message followed by the user message) and returns the trimmed content
// of the first choice.
func (p *openAICompatProcessor) Process(ctx context.Context, input string) (string, error) {
	var messages []openAIMessage
	if p.system != "" {
		messages = append(messages, openAIMessage{Role: "system", Content: p.system})
	}
	messages = append(messages, openAIMessage{Role: "user", Content: input})
	payload, err := json.Marshal(openAIRequest{
		Model:       p.model,
		Messages:    messages,
		Temperature: p.temperature,
	})
	if err != nil {
		return "", err
	}
	url := p.baseURL + "/v1/chat/completions"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
	if err != nil {
		return "", err
	}
	req.Header.Set("Content-Type", "application/json")
	if strings.TrimSpace(p.apiKey) != "" {
		req.Header.Set("Authorization", "Bearer "+p.apiKey)
	}
	resp, err := p.client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	if code := resp.StatusCode; code < 200 || code >= 300 {
		return "", fmt.Errorf("openai_compat request failed: %s", readErrorBody(resp.Body))
	}
	var decoded openAIResponse
	if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil {
		return "", err
	}
	if len(decoded.Choices) == 0 {
		return "", errors.New("openai_compat response missing choices")
	}
	return strings.TrimSpace(decoded.Choices[0].Message.Content), nil
}
// openAIRequest is the JSON payload for /v1/chat/completions.
type openAIRequest struct {
	Model       string          `json:"model"`
	Messages    []openAIMessage `json:"messages"`
	Temperature float64         `json:"temperature,omitempty"` // 0 is omitted, deferring to the server default
}
// openAIMessage is one chat message ("system" or "user" role here).
type openAIMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
// openAIResponse is the subset of the chat-completions reply we decode.
type openAIResponse struct {
	Choices []openAIChoice `json:"choices"`
}
// openAIChoice is a single completion candidate; only the message is used.
type openAIChoice struct {
	Message openAIMessage `json:"message"`
}
func readErrorBody(r io.Reader) string {
data, err := io.ReadAll(io.LimitReader(r, 64*1024))
if err != nil {
return "unknown error"
}
return strings.TrimSpace(string(data))
}