Add AI post-processing prompt
This commit is contained in:
parent
b296491703
commit
cda89923ce
4 changed files with 269 additions and 1 deletions
247
internal/aiprocess/aiprocess.go
Normal file
247
internal/aiprocess/aiprocess.go
Normal file
|
|
@ -0,0 +1,247 @@
|
|||
package aiprocess
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
// Config controls the optional AI post-processing step.
type Config struct {
	// Enabled turns AI post-processing on; when false, New returns (nil, nil).
	Enabled bool
	// Provider selects the backend: "ollama" or "openai_compat"
	// (matched case-insensitively by New).
	Provider string
	// Model is the model name sent to the provider. Required when Enabled.
	Model string
	// Temperature is the sampling temperature. A value of 0 is treated as
	// "unset" and omitted from the request, leaving the provider default.
	Temperature float64
	// SystemPromptFile is an optional path to a custom system prompt file;
	// empty means the embedded default prompt is used.
	SystemPromptFile string
	// BaseURL is the provider endpoint. Optional for "ollama" (defaults to
	// http://localhost:11434); required for "openai_compat". A trailing
	// slash is stripped.
	BaseURL string
	// APIKey, when non-blank, is sent as a Bearer token (openai_compat only).
	APIKey string
	// TimeoutSec is the HTTP client timeout in seconds; values <= 0 fall
	// back to 20 seconds.
	TimeoutSec int
}
|
||||
|
||||
// Processor applies AI post-processing to a piece of text.
type Processor interface {
	// Process sends input to the backend and returns the processed text,
	// or an error if the request fails.
	Process(ctx context.Context, input string) (string, error)
}
|
||||
|
||||
func New(cfg Config) (Processor, error) {
|
||||
if !cfg.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
provider := strings.ToLower(strings.TrimSpace(cfg.Provider))
|
||||
if provider == "" {
|
||||
return nil, errors.New("ai provider is required when enabled")
|
||||
}
|
||||
if strings.TrimSpace(cfg.Model) == "" {
|
||||
return nil, errors.New("ai model is required when enabled")
|
||||
}
|
||||
|
||||
systemPrompt, err := loadSystemPrompt(cfg.SystemPromptFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
timeout := time.Duration(cfg.TimeoutSec) * time.Second
|
||||
if timeout <= 0 {
|
||||
timeout = 20 * time.Second
|
||||
}
|
||||
|
||||
switch provider {
|
||||
case "ollama":
|
||||
base := strings.TrimRight(cfg.BaseURL, "/")
|
||||
if base == "" {
|
||||
base = "http://localhost:11434"
|
||||
}
|
||||
return &ollamaProcessor{
|
||||
client: &http.Client{Timeout: timeout},
|
||||
baseURL: base,
|
||||
model: cfg.Model,
|
||||
temperature: cfg.Temperature,
|
||||
system: systemPrompt,
|
||||
}, nil
|
||||
case "openai_compat":
|
||||
base := strings.TrimRight(cfg.BaseURL, "/")
|
||||
if base == "" {
|
||||
return nil, errors.New("ai base_url is required for openai_compat")
|
||||
}
|
||||
return &openAICompatProcessor{
|
||||
client: &http.Client{Timeout: timeout},
|
||||
baseURL: base,
|
||||
apiKey: cfg.APIKey,
|
||||
model: cfg.Model,
|
||||
temperature: cfg.Temperature,
|
||||
system: systemPrompt,
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown ai provider %q", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func loadSystemPrompt(path string) (string, error) {
|
||||
if strings.TrimSpace(path) == "" {
|
||||
return strings.TrimSpace(defaultSystemPrompt), nil
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read system prompt file: %w", err)
|
||||
}
|
||||
return strings.TrimSpace(string(data)), nil
|
||||
}
|
||||
|
||||
// defaultSystemPrompt is the embedded fallback prompt used by
// loadSystemPrompt when no SystemPromptFile is configured.
//
//go:embed system_prompt.txt
var defaultSystemPrompt string
|
||||
|
||||
// ollamaProcessor implements Processor against the Ollama HTTP API
// (POST <baseURL>/api/generate).
type ollamaProcessor struct {
	client      *http.Client // shared client with the configured timeout
	baseURL     string       // endpoint base, no trailing slash
	model       string
	temperature float64 // 0 means "use provider default" (omitted)
	system      string  // system prompt; empty means none sent
}
|
||||
|
||||
func (p *ollamaProcessor) Process(ctx context.Context, input string) (string, error) {
|
||||
reqBody := ollamaRequest{
|
||||
Model: p.model,
|
||||
Prompt: input,
|
||||
Stream: false,
|
||||
}
|
||||
if p.system != "" {
|
||||
reqBody.System = p.system
|
||||
}
|
||||
if p.temperature != 0 {
|
||||
reqBody.Options = &ollamaOptions{Temperature: p.temperature}
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/api/generate", bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", fmt.Errorf("ollama request failed: %s", readErrorBody(resp.Body))
|
||||
}
|
||||
|
||||
var out ollamaResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(out.Response), nil
|
||||
}
|
||||
|
||||
// ollamaRequest is the JSON body for Ollama's /api/generate endpoint.
type ollamaRequest struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt"`
	System string `json:"system,omitempty"`
	// Stream is always false here so the response arrives as one JSON object.
	Stream  bool           `json:"stream"`
	Options *ollamaOptions `json:"options,omitempty"`
}

// ollamaOptions carries optional sampling parameters; only set when
// a non-zero temperature is configured.
type ollamaOptions struct {
	Temperature float64 `json:"temperature,omitempty"`
}

// ollamaResponse is the subset of the /api/generate reply we read.
type ollamaResponse struct {
	Response string `json:"response"`
}
|
||||
|
||||
// openAICompatProcessor implements Processor against an OpenAI-compatible
// chat-completions API (POST <baseURL>/v1/chat/completions).
type openAICompatProcessor struct {
	client      *http.Client // shared client with the configured timeout
	baseURL     string       // endpoint base, no trailing slash
	apiKey      string       // sent as Bearer token when non-blank
	model       string
	temperature float64 // 0 is omitted from the request (omitempty)
	system      string  // system message; empty means none sent
}
|
||||
|
||||
func (p *openAICompatProcessor) Process(ctx context.Context, input string) (string, error) {
|
||||
messages := []openAIMessage{
|
||||
{Role: "user", Content: input},
|
||||
}
|
||||
if p.system != "" {
|
||||
messages = append([]openAIMessage{{Role: "system", Content: p.system}}, messages...)
|
||||
}
|
||||
|
||||
reqBody := openAIRequest{
|
||||
Model: p.model,
|
||||
Messages: messages,
|
||||
Temperature: p.temperature,
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/v1/chat/completions", bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if strings.TrimSpace(p.apiKey) != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
}
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", fmt.Errorf("openai_compat request failed: %s", readErrorBody(resp.Body))
|
||||
}
|
||||
|
||||
var out openAIResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(out.Choices) == 0 {
|
||||
return "", errors.New("openai_compat response missing choices")
|
||||
}
|
||||
return strings.TrimSpace(out.Choices[0].Message.Content), nil
|
||||
}
|
||||
|
||||
// openAIRequest is the JSON body for the /v1/chat/completions endpoint.
type openAIRequest struct {
	Model    string          `json:"model"`
	Messages []openAIMessage `json:"messages"`
	// Temperature 0 is omitted, leaving the provider's default.
	Temperature float64 `json:"temperature,omitempty"`
}

// openAIMessage is a single chat message ("system" or "user" role here;
// also reused for the assistant message in the response).
type openAIMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// openAIResponse is the subset of the chat-completions reply we read.
type openAIResponse struct {
	Choices []openAIChoice `json:"choices"`
}

// openAIChoice wraps one completion choice; only the message is used.
type openAIChoice struct {
	Message openAIMessage `json:"message"`
}
|
||||
|
||||
func readErrorBody(r io.Reader) string {
|
||||
data, err := io.ReadAll(io.LimitReader(r, 64*1024))
|
||||
if err != nil {
|
||||
return "unknown error"
|
||||
}
|
||||
return strings.TrimSpace(string(data))
|
||||
}
|
||||
16
internal/aiprocess/system_prompt.txt
Normal file
16
internal/aiprocess/system_prompt.txt
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
You are a deterministic text transcription cleaning engine.
|
||||
You transform speech transcripts into clean written text while keeping their meaning.
|
||||
|
||||
Follow these rules strictly:
|
||||
1. Remove filler words (um, uh, like, okay so).
|
||||
2. Resolve self-corrections by keeping ONLY the final version.
|
||||
Examples:
|
||||
- "schedule that for 5 PM, I mean 4 PM" -> "schedule that for 4 PM"
|
||||
- "let's ask Bob, I mean Janice, let's ask Janice" -> "let's ask Janice"
|
||||
3. Fix grammar, capitalization, and punctuation.
|
||||
4. Do NOT add new content.
|
||||
5. Do NOT remove real content.
|
||||
6. Do NOT rewrite stylistically.
|
||||
7. Preserve meaning exactly.
|
||||
|
||||
Return ONLY the cleaned text. No explanations.
|
||||
|
|
@ -218,7 +218,8 @@ func (d *Daemon) stopAndProcess(reason string) {
|
|||
}
|
||||
d.log.Printf("transcript: %s", text)
|
||||
|
||||
if d.ai != nil && d.cfg.AIEnabled {
|
||||
if d.cfg.AIEnabled && d.ai != nil {
|
||||
d.log.Printf("ai enabled")
|
||||
d.setState(StateProcessing)
|
||||
aiCtx, cancel := context.WithTimeout(context.Background(), time.Duration(d.cfg.AITimeoutSec)*time.Second)
|
||||
cleaned, err := d.ai.Process(aiCtx, text)
|
||||
|
|
|
|||
|
|
@ -22,3 +22,7 @@ func IconRecording() []byte {
|
|||
// IconTranscribing returns the icon bytes for the transcribing state.
func IconTranscribing() []byte {
	return iconTranscribing
}
|
||||
|
||||
// IconProcessing returns the icon bytes for the processing state.
func IconProcessing() []byte {
	return iconProcessing
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue