diff --git a/internal/aiprocess/aiprocess.go b/internal/aiprocess/aiprocess.go
new file mode 100644
index 0000000..7725d61
--- /dev/null
+++ b/internal/aiprocess/aiprocess.go
@@ -0,0 +1,259 @@
+package aiprocess
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+	"strings"
+	"time"
+
+	_ "embed"
+)
+
+// Config controls whether and how transcripts are post-processed by an AI model.
+type Config struct {
+	Enabled          bool
+	Provider         string
+	Model            string
+	Temperature      float64 // 0 is treated as "unset": the provider default applies; an explicit 0 cannot currently be requested
+	SystemPromptFile string
+	BaseURL          string
+	APIKey           string
+	TimeoutSec       int
+}
+
+// Processor rewrites a raw transcript into cleaned text.
+type Processor interface {
+	Process(ctx context.Context, input string) (string, error)
+}
+
+// New builds a Processor for cfg.Provider, or (nil, nil) when AI processing is
+// disabled — callers must therefore nil-check the returned Processor.
+func New(cfg Config) (Processor, error) {
+	if !cfg.Enabled {
+		return nil, nil
+	}
+
+	provider := strings.ToLower(strings.TrimSpace(cfg.Provider))
+	if provider == "" {
+		return nil, errors.New("ai provider is required when enabled")
+	}
+	if strings.TrimSpace(cfg.Model) == "" {
+		return nil, errors.New("ai model is required when enabled")
+	}
+
+	systemPrompt, err := loadSystemPrompt(cfg.SystemPromptFile)
+	if err != nil {
+		return nil, err
+	}
+
+	timeout := time.Duration(cfg.TimeoutSec) * time.Second
+	if timeout <= 0 {
+		timeout = 20 * time.Second
+	}
+
+	switch provider {
+	case "ollama":
+		base := strings.TrimRight(cfg.BaseURL, "/")
+		if base == "" {
+			base = "http://localhost:11434"
+		}
+		return &ollamaProcessor{
+			client:      &http.Client{Timeout: timeout},
+			baseURL:     base,
+			model:       cfg.Model,
+			temperature: cfg.Temperature,
+			system:      systemPrompt,
+		}, nil
+	case "openai_compat":
+		base := strings.TrimRight(cfg.BaseURL, "/")
+		if base == "" {
+			return nil, errors.New("ai base_url is required for openai_compat")
+		}
+		return &openAICompatProcessor{
+			client:      &http.Client{Timeout: timeout},
+			baseURL:     base,
+			apiKey:      cfg.APIKey,
+			model:       cfg.Model,
+			temperature: cfg.Temperature,
+			system:      systemPrompt,
+		}, nil
+	default:
+		return nil, fmt.Errorf("unknown ai provider %q", provider)
+	}
+}
+
+// loadSystemPrompt returns the prompt from path, or the embedded default when path is blank.
+func loadSystemPrompt(path string) (string, error) {
+	if strings.TrimSpace(path) == "" {
+		return strings.TrimSpace(defaultSystemPrompt), nil
+	}
+	data, err := os.ReadFile(path)
+	if err != nil {
+		return "", fmt.Errorf("read system prompt file: %w", err)
+	}
+	return strings.TrimSpace(string(data)), nil
+}
+
+// defaultSystemPrompt ships inside the binary so the daemon works with zero config.
+//go:embed system_prompt.txt
+var defaultSystemPrompt string
+
+// ollamaProcessor implements Processor against a local Ollama server.
+type ollamaProcessor struct {
+	client      *http.Client
+	baseURL     string
+	model       string
+	temperature float64
+	system      string
+}
+
+// Process sends input to Ollama's /api/generate endpoint and returns the trimmed response.
+func (p *ollamaProcessor) Process(ctx context.Context, input string) (string, error) {
+	reqBody := ollamaRequest{
+		Model:  p.model,
+		Prompt: input,
+		Stream: false,
+	}
+	if p.system != "" {
+		reqBody.System = p.system
+	}
+	if p.temperature != 0 { // 0 means "unset": Ollama's default temperature applies
+		reqBody.Options = &ollamaOptions{Temperature: p.temperature}
+	}
+
+	payload, err := json.Marshal(reqBody)
+	if err != nil {
+		return "", err
+	}
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/api/generate", bytes.NewReader(payload))
+	if err != nil {
+		return "", err
+	}
+	req.Header.Set("Content-Type", "application/json")
+
+	resp, err := p.client.Do(req)
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+		return "", fmt.Errorf("ollama request failed: %s", readErrorBody(resp.Body))
+	}
+
+	var out ollamaResponse
+	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
+		return "", err
+	}
+	return strings.TrimSpace(out.Response), nil
+}
+
+type ollamaRequest struct {
+	Model   string         `json:"model"`
+	Prompt  string         `json:"prompt"`
+	System  string         `json:"system,omitempty"`
+	Stream  bool           `json:"stream"`
+	Options *ollamaOptions `json:"options,omitempty"`
+}
+
+type ollamaOptions struct {
+	Temperature float64 `json:"temperature,omitempty"`
+}
+
+type ollamaResponse struct {
+	Response string `json:"response"`
+}
+
+// openAICompatProcessor implements Processor against any OpenAI-compatible chat API.
+type openAICompatProcessor struct {
+	client      *http.Client
+	baseURL     string
+	apiKey      string
+	model       string
+	temperature float64
+	system      string
+}
+
+// Process sends input to the OpenAI-compatible /v1/chat/completions endpoint.
+// NOTE(review): baseURL must not already end in "/v1" or the path doubles — confirm against config docs.
+func (p *openAICompatProcessor) Process(ctx context.Context, input string) (string, error) {
+	messages := []openAIMessage{
+		{Role: "user", Content: input},
+	}
+	if p.system != "" {
+		messages = append([]openAIMessage{{Role: "system", Content: p.system}}, messages...)
+	}
+
+	reqBody := openAIRequest{
+		Model:       p.model,
+		Messages:    messages,
+		Temperature: p.temperature, // omitted when 0 (json omitempty): provider default applies
+	}
+
+	payload, err := json.Marshal(reqBody)
+	if err != nil {
+		return "", err
+	}
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/v1/chat/completions", bytes.NewReader(payload))
+	if err != nil {
+		return "", err
+	}
+	req.Header.Set("Content-Type", "application/json")
+	if strings.TrimSpace(p.apiKey) != "" {
+		req.Header.Set("Authorization", "Bearer "+p.apiKey)
+	}
+
+	resp, err := p.client.Do(req)
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+		return "", fmt.Errorf("openai_compat request failed: %s", readErrorBody(resp.Body))
+	}
+
+	var out openAIResponse
+	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
+		return "", err
+	}
+	if len(out.Choices) == 0 {
+		return "", errors.New("openai_compat response missing choices")
+	}
+	return strings.TrimSpace(out.Choices[0].Message.Content), nil
+}
+
+type openAIRequest struct {
+	Model       string          `json:"model"`
+	Messages    []openAIMessage `json:"messages"`
+	Temperature float64         `json:"temperature,omitempty"`
+}
+
+type openAIMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+type openAIResponse struct {
+	Choices []openAIChoice `json:"choices"`
+}
+
+type openAIChoice struct {
+	Message openAIMessage `json:"message"`
+}
+
+// readErrorBody reads at most 64 KiB of an error response for inclusion in error messages.
+func readErrorBody(r io.Reader) string {
+	data, err := io.ReadAll(io.LimitReader(r, 64*1024))
+	if err != nil {
+		return "unknown error"
+	}
+	return strings.TrimSpace(string(data))
+}
diff --git a/internal/aiprocess/system_prompt.txt b/internal/aiprocess/system_prompt.txt
new file mode 100644
index 0000000..5f3722a
--- /dev/null
+++ b/internal/aiprocess/system_prompt.txt
@@ -0,0 +1,16 @@
+You are a deterministic text transcription cleaning engine.
+You transform speech transcripts into clean written text while keeping their meaning.
+
+Follow these rules strictly:
+1. Remove filler words (um, uh, like, okay so).
+2. Resolve self-corrections by keeping ONLY the final version.
+   Examples:
+   - "schedule that for 5 PM, I mean 4 PM" -> "schedule that for 4 PM"
+   - "let's ask Bob, I mean Janice, let's ask Janice" -> "let's ask Janice"
+3. Fix grammar, capitalization, and punctuation.
+4. Do NOT add new content.
+5. Do NOT remove real content.
+6. Do NOT rewrite stylistically.
+7. Preserve meaning exactly.
+
+Return ONLY the cleaned text. No explanations.
diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go
index e934d01..4da3421 100644
--- a/internal/daemon/daemon.go
+++ b/internal/daemon/daemon.go
@@ -218,7 +218,7 @@ func (d *Daemon) stopAndProcess(reason string) {
 	}
 	d.log.Printf("transcript: %s", text)
 
-	if d.ai != nil && d.cfg.AIEnabled {
+	if d.cfg.AIEnabled && d.ai != nil {
 		d.setState(StateProcessing)
 		aiCtx, cancel := context.WithTimeout(context.Background(), time.Duration(d.cfg.AITimeoutSec)*time.Second)
 		cleaned, err := d.ai.Process(aiCtx, text)
diff --git a/internal/ui/icons.go b/internal/ui/icons.go
index ee24b68..aec61c0 100644
--- a/internal/ui/icons.go
+++ b/internal/ui/icons.go
@@ -22,3 +22,7 @@ func IconRecording() []byte {
 func IconTranscribing() []byte {
 	return iconTranscribing
 }
+
+func IconProcessing() []byte {
+	return iconProcessing
+}