package aiprocess import ( "bytes" "context" "encoding/json" "errors" "fmt" "io" "net/http" "os" "strings" "time" _ "embed" ) type Config struct { Enabled bool Provider string Model string Temperature float64 SystemPromptFile string BaseURL string APIKey string TimeoutSec int } type Processor interface { Process(ctx context.Context, input string) (string, error) } func New(cfg Config) (Processor, error) { if !cfg.Enabled { return nil, nil } provider := strings.ToLower(strings.TrimSpace(cfg.Provider)) if provider == "" { return nil, errors.New("ai provider is required when enabled") } if strings.TrimSpace(cfg.Model) == "" { return nil, errors.New("ai model is required when enabled") } systemPrompt, err := loadSystemPrompt(cfg.SystemPromptFile) if err != nil { return nil, err } timeout := time.Duration(cfg.TimeoutSec) * time.Second if timeout <= 0 { timeout = 20 * time.Second } switch provider { case "ollama": base := strings.TrimRight(cfg.BaseURL, "/") if base == "" { base = "http://localhost:11434" } return &ollamaProcessor{ client: &http.Client{Timeout: timeout}, baseURL: base, model: cfg.Model, temperature: cfg.Temperature, system: systemPrompt, }, nil case "openai_compat": base := strings.TrimRight(cfg.BaseURL, "/") if base == "" { return nil, errors.New("ai base_url is required for openai_compat") } return &openAICompatProcessor{ client: &http.Client{Timeout: timeout}, baseURL: base, apiKey: cfg.APIKey, model: cfg.Model, temperature: cfg.Temperature, system: systemPrompt, }, nil default: return nil, fmt.Errorf("unknown ai provider %q", provider) } } func loadSystemPrompt(path string) (string, error) { if strings.TrimSpace(path) == "" { return strings.TrimSpace(defaultSystemPrompt), nil } data, err := os.ReadFile(path) if err != nil { return "", fmt.Errorf("read system prompt file: %w", err) } return strings.TrimSpace(string(data)), nil } //go:embed system_prompt.txt var defaultSystemPrompt string type ollamaProcessor struct { client *http.Client baseURL string model string temperature float64 system string } func (p *ollamaProcessor) Process(ctx context.Context, input string) (string, error) { reqBody := ollamaRequest{ Model: p.model, Prompt: input, Stream: false, } if p.system != "" { reqBody.System = p.system } if p.temperature != 0 { reqBody.Options = &ollamaOptions{Temperature: p.temperature} } payload, err := json.Marshal(reqBody) if err != nil { return "", err } req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/api/generate", bytes.NewReader(payload)) if err != nil { return "", err } req.Header.Set("Content-Type", "application/json") resp, err := p.client.Do(req) if err != nil { return "", err } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { return "", fmt.Errorf("ollama request failed: %s", readErrorBody(resp.Body)) } var out ollamaResponse if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { return "", err } return strings.TrimSpace(out.Response), nil } type ollamaRequest struct { Model string `json:"model"` Prompt string `json:"prompt"` System string `json:"system,omitempty"` Stream bool `json:"stream"` Options *ollamaOptions `json:"options,omitempty"` } type ollamaOptions struct { Temperature float64 `json:"temperature,omitempty"` } type ollamaResponse struct { Response string `json:"response"` } type openAICompatProcessor struct { client *http.Client baseURL string apiKey string model string temperature float64 system string } func (p *openAICompatProcessor) Process(ctx context.Context, input string) (string, error) { messages := []openAIMessage{ {Role: "user", Content: input}, } if p.system != "" { messages = append([]openAIMessage{{Role: "system", Content: p.system}}, messages...) } reqBody := openAIRequest{ Model: p.model, Messages: messages, Temperature: p.temperature, } payload, err := json.Marshal(reqBody) if err != nil { return "", err } req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/v1/chat/completions", bytes.NewReader(payload)) if err != nil { return "", err } req.Header.Set("Content-Type", "application/json") if strings.TrimSpace(p.apiKey) != "" { req.Header.Set("Authorization", "Bearer "+p.apiKey) } resp, err := p.client.Do(req) if err != nil { return "", err } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { return "", fmt.Errorf("openai_compat request failed: %s", readErrorBody(resp.Body)) } var out openAIResponse if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { return "", err } if len(out.Choices) == 0 { return "", errors.New("openai_compat response missing choices") } return strings.TrimSpace(out.Choices[0].Message.Content), nil } type openAIRequest struct { Model string `json:"model"` Messages []openAIMessage `json:"messages"` Temperature float64 `json:"temperature,omitempty"` } type openAIMessage struct { Role string `json:"role"` Content string `json:"content"` } type openAIResponse struct { Choices []openAIChoice `json:"choices"` } type openAIChoice struct { Message openAIMessage `json:"message"` } func readErrorBody(r io.Reader) string { data, err := io.ReadAll(io.LimitReader(r, 64*1024)) if err != nil { return "unknown error" } return strings.TrimSpace(string(data)) }