package daemon import ( "bufio" "context" "database/sql" "encoding/json" "errors" "fmt" "log/slog" "net" "os" "path/filepath" "strings" "sync" "time" "banger/internal/api" "banger/internal/config" "banger/internal/model" "banger/internal/paths" "banger/internal/rpc" "banger/internal/store" "banger/internal/system" "banger/internal/vmdns" ) type Daemon struct { layout paths.Layout config model.DaemonConfig store *store.Store runner system.CommandRunner logger *slog.Logger mu sync.Mutex closing chan struct{} once sync.Once pid int listener net.Listener vmDNS *vmdns.Server imageBuild func(context.Context, imageBuildSpec) error requestHandler func(context.Context, rpc.Request) rpc.Response } func Open(ctx context.Context) (d *Daemon, err error) { layout, err := paths.Resolve() if err != nil { return nil, err } if err := paths.Ensure(layout); err != nil { return nil, err } cfg, err := config.Load(layout) if err != nil { return nil, err } logger, normalizedLevel, err := newDaemonLogger(os.Stderr, cfg.LogLevel) if err != nil { return nil, err } cfg.LogLevel = normalizedLevel db, err := store.Open(layout.DBPath) if err != nil { return nil, err } d = &Daemon{ layout: layout, config: cfg, store: db, runner: system.NewRunner(), logger: logger, closing: make(chan struct{}), pid: os.Getpid(), } d.logger.Info("daemon opened", "socket", layout.SocketPath, "state_dir", layout.StateDir, "runtime_dir", cfg.RuntimeDir, "log_level", cfg.LogLevel) if err = d.startVMDNS(vmdns.DefaultListenAddr); err != nil { d.logger.Error("daemon open failed", "stage", "start_vm_dns", "error", err.Error()) return nil, err } defer func() { if err != nil { _ = d.stopVMDNS() } }() if err = d.ensureDefaultImage(ctx); err != nil { d.logger.Error("daemon open failed", "stage", "ensure_default_image", "error", err.Error()) return nil, err } if err = d.reconcile(ctx); err != nil { d.logger.Error("daemon open failed", "stage", "reconcile", "error", err.Error()) return nil, err } return d, nil } func (d *Daemon) Close() error { var err error d.once.Do(func() { if d.logger != nil { d.logger.Info("daemon closing") } close(d.closing) if d.listener != nil { _ = d.listener.Close() } err = errors.Join(d.stopVMDNS(), d.store.Close()) }) return err } func (d *Daemon) Serve(ctx context.Context) error { _ = os.Remove(d.layout.SocketPath) listener, err := net.Listen("unix", d.layout.SocketPath) if err != nil { if d.logger != nil { d.logger.Error("daemon listen failed", "socket", d.layout.SocketPath, "error", err.Error()) } return err } d.listener = listener defer listener.Close() defer os.Remove(d.layout.SocketPath) if err := os.Chmod(d.layout.SocketPath, 0o600); err != nil { return err } if d.logger != nil { d.logger.Info("daemon serving", "socket", d.layout.SocketPath, "pid", d.pid) } go d.backgroundLoop() for { conn, err := listener.Accept() if err != nil { select { case <-ctx.Done(): return nil case <-d.closing: return nil default: } if ne, ok := err.(net.Error); ok && ne.Temporary() { if d.logger != nil { d.logger.Warn("daemon accept temporary failure", "error", err.Error()) } time.Sleep(100 * time.Millisecond) continue } if d.logger != nil { d.logger.Error("daemon accept failed", "error", err.Error()) } return err } go d.handleConn(conn) } } func (d *Daemon) handleConn(conn net.Conn) { defer conn.Close() reader := bufio.NewReader(conn) var req rpc.Request if err := json.NewDecoder(reader).Decode(&req); err != nil { if d.logger != nil { d.logger.Warn("daemon request decode failed", "remote", conn.RemoteAddr().String(), "error", err.Error()) } _ = json.NewEncoder(conn).Encode(rpc.NewError("bad_request", err.Error())) return } reqCtx, cancel := context.WithCancel(context.Background()) defer cancel() stopWatch := d.watchRequestDisconnect(conn, reader, req.Method, cancel) defer stopWatch() resp := d.dispatch(reqCtx, req) if reqCtx.Err() != nil { return } if err := json.NewEncoder(conn).Encode(resp); err != nil && d.logger != nil { d.logger.Warn("daemon response encode failed", "method", req.Method, "remote", conn.RemoteAddr().String(), "error", err.Error()) } } func (d *Daemon) watchRequestDisconnect(conn net.Conn, reader *bufio.Reader, method string, cancel context.CancelFunc) func() { if conn == nil || reader == nil { return func() {} } done := make(chan struct{}) var once sync.Once go func() { go func() { <-done if deadlineSetter, ok := conn.(interface{ SetReadDeadline(time.Time) error }); ok { _ = deadlineSetter.SetReadDeadline(time.Now()) } }() var buf [1]byte for { _, err := reader.Read(buf[:]) if err == nil { continue } select { case <-done: return default: } if d.logger != nil { d.logger.Info("daemon request canceled", "method", method, "remote", conn.RemoteAddr().String(), "error", err.Error()) } cancel() return } }() return func() { once.Do(func() { close(done) }) } } func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response { if req.Version != rpc.Version { return rpc.NewError("bad_version", fmt.Sprintf("unsupported version %d", req.Version)) } if d.requestHandler != nil { return d.requestHandler(ctx, req) } switch req.Method { case "ping": result, _ := rpc.NewResult(api.PingResult{Status: "ok", PID: d.pid}) return result case "shutdown": go d.Close() result, _ := rpc.NewResult(api.ShutdownResult{Status: "stopping"}) return result case "vm.create": params, err := rpc.DecodeParams[api.VMCreateParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } vm, err := d.CreateVM(ctx, params) return marshalResultOrError(api.VMShowResult{VM: vm}, err) case "vm.list": vms, err := d.store.ListVMs(ctx) return marshalResultOrError(api.VMListResult{VMs: vms}, err) case "vm.show": params, err := rpc.DecodeParams[api.VMRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } vm, err := d.FindVM(ctx, params.IDOrName) return marshalResultOrError(api.VMShowResult{VM: vm}, err) case "vm.start": params, err := rpc.DecodeParams[api.VMRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } vm, err := d.StartVM(ctx, params.IDOrName) return marshalResultOrError(api.VMShowResult{VM: vm}, err) case "vm.stop": params, err := rpc.DecodeParams[api.VMRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } vm, err := d.StopVM(ctx, params.IDOrName) return marshalResultOrError(api.VMShowResult{VM: vm}, err) case "vm.kill": params, err := rpc.DecodeParams[api.VMKillParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } vm, err := d.KillVM(ctx, params) return marshalResultOrError(api.VMShowResult{VM: vm}, err) case "vm.restart": params, err := rpc.DecodeParams[api.VMRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } vm, err := d.RestartVM(ctx, params.IDOrName) return marshalResultOrError(api.VMShowResult{VM: vm}, err) case "vm.delete": params, err := rpc.DecodeParams[api.VMRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } vm, err := d.DeleteVM(ctx, params.IDOrName) return marshalResultOrError(api.VMShowResult{VM: vm}, err) case "vm.set": params, err := rpc.DecodeParams[api.VMSetParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } vm, err := d.SetVM(ctx, params) return marshalResultOrError(api.VMShowResult{VM: vm}, err) case "vm.stats": params, err := rpc.DecodeParams[api.VMRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } vm, stats, err := d.GetVMStats(ctx, params.IDOrName) return marshalResultOrError(api.VMStatsResult{VM: vm, Stats: stats}, err) case "vm.logs": params, err := rpc.DecodeParams[api.VMRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } vm, err := d.FindVM(ctx, params.IDOrName) if err != nil { return rpc.NewError("not_found", err.Error()) } return marshalResultOrError(api.VMLogsResult{LogPath: vm.Runtime.LogPath}, nil) case "vm.ssh": params, err := rpc.DecodeParams[api.VMRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } vm, err := d.TouchVM(ctx, params.IDOrName) if err != nil { return rpc.NewError("not_found", err.Error()) } if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) { return rpc.NewError("not_running", fmt.Sprintf("vm %s is not running", vm.Name)) } return marshalResultOrError(api.VMSSHResult{Name: vm.Name, GuestIP: vm.Runtime.GuestIP}, nil) case "image.list": images, err := d.store.ListImages(ctx) return marshalResultOrError(api.ImageListResult{Images: images}, err) case "image.show": params, err := rpc.DecodeParams[api.ImageRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } image, err := d.FindImage(ctx, params.IDOrName) return marshalResultOrError(api.ImageShowResult{Image: image}, err) case "image.build": params, err := rpc.DecodeParams[api.ImageBuildParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } image, err := d.BuildImage(ctx, params) return marshalResultOrError(api.ImageShowResult{Image: image}, err) case "image.delete": params, err := rpc.DecodeParams[api.ImageRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } image, err := d.DeleteImage(ctx, params.IDOrName) return marshalResultOrError(api.ImageShowResult{Image: image}, err) default: return rpc.NewError("unknown_method", req.Method) } } func (d *Daemon) backgroundLoop() { statsTicker := time.NewTicker(d.config.StatsPollInterval) staleTicker := time.NewTicker(model.DefaultStaleSweepInterval) defer statsTicker.Stop() defer staleTicker.Stop() for { select { case <-d.closing: return case <-statsTicker.C: if err := d.pollStats(context.Background()); err != nil && d.logger != nil { d.logger.Error("background stats poll failed", "error", err.Error()) } case <-staleTicker.C: if err := d.stopStaleVMs(context.Background()); err != nil && d.logger != nil { d.logger.Error("background stale sweep failed", "error", err.Error()) } } } } func (d *Daemon) startVMDNS(addr string) error { server, err := vmdns.New(addr, d.logger) if err != nil { return err } d.vmDNS = server if d.logger != nil { d.logger.Info("vm dns serving", "dns_addr", server.Addr()) } return nil } func (d *Daemon) stopVMDNS() error { if d.vmDNS == nil { return nil } err := d.vmDNS.Close() d.vmDNS = nil return err } func (d *Daemon) ensureDefaultImage(ctx context.Context) error { if d.config.DefaultImageName == "" { return nil } desired, ok := d.desiredDefaultImage() if !ok { if d.logger != nil { d.logger.Debug("default image skipped", "image_name", d.config.DefaultImageName, "rootfs_path", d.config.DefaultRootfs, "kernel_path", d.config.DefaultKernel) } return nil } image, err := d.store.GetImageByName(ctx, d.config.DefaultImageName) switch { case err == nil: if image.Managed { if d.logger != nil { d.logger.Debug("managed default image left untouched", append(imageLogAttrs(image), "managed", image.Managed)...) } return nil } if defaultImageMatches(image, desired) { if d.logger != nil { d.logger.Debug("default image already current", imageLogAttrs(image)...) } return nil } updated := desired updated.ID = image.ID updated.CreatedAt = image.CreatedAt updated.UpdatedAt = model.Now() if err := d.store.UpsertImage(ctx, updated); err != nil { return err } if d.logger != nil { d.logger.Info("default image reconciled", append(imageLogAttrs(updated), "previous_rootfs_path", image.RootfsPath, "previous_kernel_path", image.KernelPath)...) } return nil case errors.Is(err, sql.ErrNoRows): id, err := model.NewID() if err != nil { return err } now := model.Now() desired.ID = id desired.CreatedAt = now desired.UpdatedAt = now if err := d.store.UpsertImage(ctx, desired); err != nil { return err } if d.logger != nil { d.logger.Info("default image registered", append(imageLogAttrs(desired), "managed", desired.Managed)...) } return nil default: return err } } func (d *Daemon) desiredDefaultImage() (model.Image, bool) { rootfs := d.config.DefaultRootfs kernel := d.config.DefaultKernel if !exists(rootfs) || !exists(kernel) { return model.Image{}, false } return model.Image{ Name: d.config.DefaultImageName, Managed: false, ArtifactDir: "", RootfsPath: rootfs, KernelPath: kernel, InitrdPath: d.config.DefaultInitrd, ModulesDir: d.config.DefaultModulesDir, PackagesPath: d.config.DefaultPackagesFile, Docker: strings.Contains(filepath.Base(rootfs), "docker"), }, true } func defaultImageMatches(current, desired model.Image) bool { return current.Name == desired.Name && current.Managed == desired.Managed && current.ArtifactDir == desired.ArtifactDir && current.RootfsPath == desired.RootfsPath && current.KernelPath == desired.KernelPath && current.InitrdPath == desired.InitrdPath && current.ModulesDir == desired.ModulesDir && current.PackagesPath == desired.PackagesPath && current.Docker == desired.Docker } func (d *Daemon) reconcile(ctx context.Context) error { op := d.beginOperation("daemon.reconcile") vms, err := d.store.ListVMs(ctx) if err != nil { return op.fail(err) } for _, vm := range vms { if vm.State != model.VMStateRunning { continue } if system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) { continue } op.stage("stale_vm", vmLogAttrs(vm)...) _ = d.cleanupRuntime(ctx, vm, true) vm.State = model.VMStateStopped vm.Runtime.State = model.VMStateStopped vm.Runtime.PID = 0 vm.Runtime.TapDevice = "" vm.Runtime.APISockPath = "" vm.Runtime.BaseLoop = "" vm.Runtime.COWLoop = "" vm.Runtime.DMName = "" vm.Runtime.DMDev = "" vm.UpdatedAt = model.Now() if err := d.store.UpsertVM(ctx, vm); err != nil { return op.fail(err, vmLogAttrs(vm)...) } } if err := d.rebuildDNS(ctx); err != nil { return op.fail(err) } op.done() return nil } func (d *Daemon) FindVM(ctx context.Context, idOrName string) (model.VMRecord, error) { if idOrName == "" { return model.VMRecord{}, errors.New("vm id or name is required") } if vm, err := d.store.GetVM(ctx, idOrName); err == nil { return vm, nil } vms, err := d.store.ListVMs(ctx) if err != nil { return model.VMRecord{}, err } matchCount := 0 var match model.VMRecord for _, vm := range vms { if strings.HasPrefix(vm.ID, idOrName) || strings.HasPrefix(vm.Name, idOrName) { match = vm matchCount++ } } if matchCount == 1 { return match, nil } if matchCount > 1 { return model.VMRecord{}, fmt.Errorf("multiple VMs match %q", idOrName) } return model.VMRecord{}, fmt.Errorf("vm %q not found", idOrName) } func (d *Daemon) FindImage(ctx context.Context, idOrName string) (model.Image, error) { if idOrName == "" { return model.Image{}, errors.New("image id or name is required") } if image, err := d.store.GetImageByName(ctx, idOrName); err == nil { return image, nil } if image, err := d.store.GetImageByID(ctx, idOrName); err == nil { return image, nil } images, err := d.store.ListImages(ctx) if err != nil { return model.Image{}, err } matchCount := 0 var match model.Image for _, image := range images { if strings.HasPrefix(image.ID, idOrName) || strings.HasPrefix(image.Name, idOrName) { match = image matchCount++ } } if matchCount == 1 { return match, nil } if matchCount > 1 { return model.Image{}, fmt.Errorf("multiple images match %q", idOrName) } return model.Image{}, fmt.Errorf("image %q not found", idOrName) } func (d *Daemon) TouchVM(ctx context.Context, idOrName string) (model.VMRecord, error) { d.mu.Lock() defer d.mu.Unlock() vm, err := d.FindVM(ctx, idOrName) if err != nil { return model.VMRecord{}, err } system.TouchNow(&vm) if err := d.store.UpsertVM(ctx, vm); err != nil { return model.VMRecord{}, err } return vm, nil } func marshalResultOrError(v any, err error) rpc.Response { if err != nil { return rpc.NewError("operation_failed", err.Error()) } resp, marshalErr := rpc.NewResult(v) if marshalErr != nil { return rpc.NewError("marshal_failed", marshalErr.Error()) } return resp } func exists(path string) bool { _, err := os.Stat(path) return err == nil }