package daemon

import (
	"bufio"
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"banger/internal/api"
	"banger/internal/buildinfo"
	"banger/internal/config"
	"banger/internal/daemon/opstate"
	"banger/internal/model"
	"banger/internal/paths"
	"banger/internal/rpc"
	"banger/internal/store"
	"banger/internal/system"
	"banger/internal/vmdns"
)

// Daemon holds the daemon's runtime state: the resolved on-disk layout and
// config, the store, the RPC/web listeners, the VM DNS server, per-VM locks
// and in-flight create/build operation registries.
type Daemon struct {
	layout paths.Layout
	config model.DaemonConfig
	store  *store.Store
	runner system.CommandRunner
	logger *slog.Logger

	imageOpsMu    sync.Mutex
	createVMMu    sync.Mutex
	createOps     opstate.Registry[*vmCreateOperationState]
	imageBuildOps opstate.Registry[*imageBuildOperationState]
	vmLocks       vmLockSet
	sessions      sessionRegistry
	tapPool       tapPool

	// closing is closed exactly once by Close; Serve and backgroundLoop
	// watch it to stop.
	closing chan struct{}
	once    sync.Once
	pid     int

	listener    net.Listener
	webListener net.Listener
	webServer   *http.Server
	webURL      string
	vmDNS       *vmdns.Server
	vmCaps      []vmCapability

	// Optional hooks, not set by Open. When requestHandler is non-nil it
	// replaces the built-in method dispatch entirely. The others are
	// presumably test seams used elsewhere in the package — TODO confirm.
	imageBuild               func(context.Context, imageBuildSpec) error
	requestHandler           func(context.Context, rpc.Request) rpc.Response
	guestWaitForSSH          func(context.Context, string, string, time.Duration) error
	guestDial                func(context.Context, string, string) (guestSSHClient, error)
	waitForGuestSessionReady func(context.Context, model.VMRecord, model.GuestSession) (model.GuestSession, error)
}

// Open resolves and ensures the on-disk layout, loads the config, creates the
// logger, and opens the store. It then starts the VM DNS server, reconciles
// persisted VM state, sets up VM DNS resolver routing and initializes the tap
// pool.
//
// On failure, every resource Open acquired (store, VM DNS server, resolver
// routing) is released, and a nil *Daemon is returned with the error. On
// success, a goroutine that keeps the tap pool topped up is started.
func Open(ctx context.Context) (d *Daemon, err error) {
	layout, err := paths.Resolve()
	if err != nil {
		return nil, err
	}
	if err := paths.Ensure(layout); err != nil {
		return nil, err
	}
	cfg, err := config.Load(layout)
	if err != nil {
		return nil, err
	}
	logger, normalizedLevel, err := newDaemonLogger(os.Stderr, cfg.LogLevel)
	if err != nil {
		return nil, err
	}
	cfg.LogLevel = normalizedLevel
	db, err := store.Open(layout.DBPath)
	if err != nil {
		return nil, err
	}
	// Close the store if any later stage fails; otherwise the handle leaks.
	defer func() {
		if err != nil {
			_ = db.Close() // best effort: the stage error is what the caller needs
		}
	}()

	// Build into a local rather than the named result: every failure path
	// below does `return nil, err`, which sets d to nil before the deferred
	// cleanups run. Calling d.stopVMDNS() there would nil-deref and panic.
	inst := &Daemon{
		layout:   layout,
		config:   cfg,
		store:    db,
		runner:   system.NewRunner(),
		logger:   logger,
		closing:  make(chan struct{}),
		pid:      os.Getpid(),
		sessions: newSessionRegistry(),
	}
	inst.ensureVMSSHClientConfig()
	inst.logger.Info("daemon opened", "socket", layout.SocketPath, "state_dir", layout.StateDir, "log_level", cfg.LogLevel)
	if err = inst.startVMDNS(vmdns.DefaultListenAddr); err != nil {
		inst.logger.Error("daemon open failed", "stage", "start_vm_dns", "error", err.Error())
		return nil, err
	}
	defer func() {
		if err != nil {
			_ = inst.stopVMDNS()
		}
	}()
	if err = inst.reconcile(ctx); err != nil {
		inst.logger.Error("daemon open failed", "stage", "reconcile", "error", err.Error())
		return nil, err
	}
	inst.ensureVMDNSResolverRouting(ctx)
	// Mirror Close: don't leave resolver routing in place for a DNS server
	// that is being stopped because Open failed.
	defer func() {
		if err != nil {
			_ = inst.clearVMDNSResolverRouting(context.Background())
		}
	}()
	if err = inst.initializeTapPool(ctx); err != nil {
		inst.logger.Error("daemon open failed", "stage", "initialize_tap_pool", "error", err.Error())
		return nil, err
	}
	go inst.ensureTapPool(context.Background())
	return inst, nil
}

// Close shuts the daemon down exactly once. It signals closing, closes the
// RPC listener, web server and web listener, and then tears down resolver
// routing, the VM DNS server, guest session controllers and the store. It
// returns the joined errors of those teardown steps; later calls return nil.
func (d *Daemon) Close() error {
	var err error
	d.once.Do(func() {
		if d.logger != nil {
			d.logger.Info("daemon closing")
		}
		close(d.closing)
		if d.listener != nil {
			_ = d.listener.Close()
		}
		if d.webServer != nil {
			_ = d.webServer.Close()
		}
		if d.webListener != nil {
			_ = d.webListener.Close()
		}
		err = errors.Join(d.clearVMDNSResolverRouting(context.Background()), d.stopVMDNS(), d.closeGuestSessionControllers(), d.store.Close())
	})
	return err
}

// Serve listens on the daemon's Unix socket (mode 0600), starts the web server
// and the background loop, and handles each accepted connection in its own
// goroutine. It returns nil once ctx is canceled or Close is called. It
// returns an error if listening or setup fails, or if Accept fails
// permanently.
func (d *Daemon) Serve(ctx context.Context) error {
	_ = os.Remove(d.layout.SocketPath) // clear a stale socket from a previous run
	listener, err := net.Listen("unix", d.layout.SocketPath)
	if err != nil {
		if d.logger != nil {
			d.logger.Error("daemon listen failed", "socket", d.layout.SocketPath, "error", err.Error())
		}
		return err
	}
	d.listener = listener
	defer listener.Close()
	defer os.Remove(d.layout.SocketPath)

	// Accept only returns when a connection arrives or the listener is
	// closed. Close the listener on ctx cancellation (and on daemon closing,
	// which also covers Close racing ahead of the d.listener assignment).
	// Without this, the ctx.Done() branch below is unreachable while idle.
	serveDone := make(chan struct{})
	defer close(serveDone)
	go func() {
		select {
		case <-ctx.Done():
		case <-d.closing:
		case <-serveDone:
			return
		}
		_ = listener.Close()
	}()

	if err := os.Chmod(d.layout.SocketPath, 0o600); err != nil {
		return err
	}
	if d.logger != nil {
		d.logger.Info("daemon serving", "socket", d.layout.SocketPath, "pid", d.pid)
	}
	if err := d.startWebServer(); err != nil {
		return err
	}
	go d.backgroundLoop()
	for {
		conn, err := listener.Accept()
		if err != nil {
			select {
			case <-ctx.Done():
				return nil
			case <-d.closing:
				return nil
			default:
			}
			// A closed listener is permanent. Retrying it (it is also a
			// net.Error) would spin forever with 100ms sleeps.
			var netErr net.Error
			if !errors.Is(err, net.ErrClosed) && errors.As(err, &netErr) {
				if d.logger != nil {
					d.logger.Warn("daemon accept temporary failure", "error", err.Error())
				}
				time.Sleep(100 * time.Millisecond)
				continue
			}
			if d.logger != nil {
				d.logger.Error("daemon accept failed", "error", err.Error())
			}
			return err
		}
		go d.handleConn(conn)
	}
}

// handleConn serves a single request on conn. It decodes one JSON rpc.Request
// and dispatches it with a context that is canceled if the client disconnects
// (see watchRequestDisconnect). It encodes the response unless that context
// was canceled, and a malformed request gets a "bad_request" error response.
func (d *Daemon) handleConn(conn net.Conn) {
	defer conn.Close()
	reader := bufio.NewReader(conn)
	var req rpc.Request
	if err := json.NewDecoder(reader).Decode(&req); err != nil {
		if d.logger != nil {
			d.logger.Warn("daemon request decode failed", "remote", conn.RemoteAddr().String(), "error", err.Error())
		}
		_ = json.NewEncoder(conn).Encode(rpc.NewError("bad_request", err.Error()))
		return
	}
	reqCtx, cancel := context.WithCancel(context.Background())
	defer cancel()
	stopWatch := d.watchRequestDisconnect(conn, reader, req.Method, cancel)
	defer stopWatch()
	resp := d.dispatch(reqCtx, req)
	if reqCtx.Err() != nil {
		// The client went away; there is nobody to send the response to.
		return
	}
	if err := json.NewEncoder(conn).Encode(resp); err != nil && d.logger != nil {
		d.logger.Warn("daemon response encode failed", "method", req.Method, "remote", conn.RemoteAddr().String(), "error", err.Error())
	}
}

// watchRequestDisconnect reads and discards bytes from reader in the
// background. A read error that happens before the returned stop function is
// called is treated as a client disconnect: it is logged and cancel is
// invoked. Calling stop (idempotent) sets an immediate read deadline so the
// pending read returns and the watcher goroutines exit.
func (d *Daemon) watchRequestDisconnect(conn net.Conn, reader *bufio.Reader, method string, cancel context.CancelFunc) func() {
	if conn == nil || reader == nil {
		return func() {}
	}
	done := make(chan struct{})
	var once sync.Once
	go func() {
		go func() {
			<-done
			if deadlineSetter, ok := conn.(interface{ SetReadDeadline(time.Time) error }); ok {
				_ = deadlineSetter.SetReadDeadline(time.Now())
			}
		}()
		var buf [1]byte
		for {
			_, err := reader.Read(buf[:])
			if err == nil {
				continue
			}
			select {
			case <-done:
				// Read unblocked by stop's deadline, not by a disconnect.
				return
			default:
			}
			if d.logger != nil {
				d.logger.Info("daemon request canceled", "method", method, "remote", conn.RemoteAddr().String(), "error", err.Error())
			}
			cancel()
			return
		}
	}()
	return func() {
		once.Do(func() { close(done) })
	}
}

// dispatch routes req to the handler for req.Method. A version mismatch yields
// "bad_version", undecodable params yield "bad_request", and unknown methods
// yield "unknown_method". When d.requestHandler is set, it handles every
// version-checked request instead.
func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response {
	if req.Version != rpc.Version {
		return rpc.NewError("bad_version", fmt.Sprintf("unsupported version %d", req.Version))
	}
	if d.requestHandler != nil {
		return d.requestHandler(ctx, req)
	}
	switch req.Method {
	case "ping":
		info := buildinfo.Current()
		result, _ := rpc.NewResult(api.PingResult{
			Status:  "ok",
			PID:     d.pid,
			WebURL:  d.webURL,
			Version: info.Version,
			Commit:  info.Commit,
			BuiltAt: info.BuiltAt,
		})
		return result
	case "shutdown":
		// Close asynchronously so this response can still be written.
		go d.Close()
		result, _ := rpc.NewResult(api.ShutdownResult{Status: "stopping"})
		return result
	case "vm.create":
		params, err := rpc.DecodeParams[api.VMCreateParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.CreateVM(ctx, params)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.create.begin":
		params, err := rpc.DecodeParams[api.VMCreateParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		op, err := d.BeginVMCreate(ctx, params)
		return marshalResultOrError(api.VMCreateBeginResult{Operation: op}, err)
	case "vm.create.status":
		params, err := rpc.DecodeParams[api.VMCreateStatusParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		op, err := d.VMCreateStatus(ctx, params.ID)
		return marshalResultOrError(api.VMCreateStatusResult{Operation: op}, err)
	case "vm.create.cancel":
		params, err := rpc.DecodeParams[api.VMCreateStatusParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		err = d.CancelVMCreate(ctx, params.ID)
		return marshalResultOrError(api.Empty{}, err)
	case "vm.list":
		vms, err := d.store.ListVMs(ctx)
		return marshalResultOrError(api.VMListResult{VMs: vms}, err)
	case "vm.show":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.FindVM(ctx, params.IDOrName)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.start":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.StartVM(ctx, params.IDOrName)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.stop":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.StopVM(ctx, params.IDOrName)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.kill":
		params, err := rpc.DecodeParams[api.VMKillParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.KillVM(ctx, params)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.restart":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.RestartVM(ctx, params.IDOrName)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.delete":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.DeleteVM(ctx, params.IDOrName)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.set":
		params, err := rpc.DecodeParams[api.VMSetParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.SetVM(ctx, params)
		return marshalResultOrError(api.VMShowResult{VM: vm}, err)
	case "vm.stats":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, stats, err := d.GetVMStats(ctx, params.IDOrName)
		return marshalResultOrError(api.VMStatsResult{VM: vm, Stats: stats}, err)
	case "vm.logs":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.FindVM(ctx, params.IDOrName)
		if err != nil {
			return rpc.NewError("not_found", err.Error())
		}
		return marshalResultOrError(api.VMLogsResult{LogPath: vm.Runtime.LogPath}, nil)
	case "vm.ssh":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		vm, err := d.TouchVM(ctx, params.IDOrName)
		if err != nil {
			return rpc.NewError("not_found", err.Error())
		}
		if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
			return rpc.NewError("not_running", fmt.Sprintf("vm %s is not running", vm.Name))
		}
		return marshalResultOrError(api.VMSSHResult{Name: vm.Name, GuestIP: vm.Runtime.GuestIP}, nil)
	case "vm.health":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		result, err := d.HealthVM(ctx, params.IDOrName)
		return marshalResultOrError(result, err)
	case "vm.ping":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		result, err := d.PingVM(ctx, params.IDOrName)
		return marshalResultOrError(result, err)
	case "vm.ports":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		result, err := d.PortsVM(ctx, params.IDOrName)
		return marshalResultOrError(result, err)
	case "vm.workspace.prepare":
		params, err := rpc.DecodeParams[api.VMWorkspacePrepareParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		workspace, err := d.PrepareVMWorkspace(ctx, params)
		return marshalResultOrError(api.VMWorkspacePrepareResult{Workspace: workspace}, err)
	case "vm.workspace.export":
		params, err := rpc.DecodeParams[api.WorkspaceExportParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		result, err := d.ExportVMWorkspace(ctx, params)
		return marshalResultOrError(result, err)
	case "guest.session.start":
		params, err := rpc.DecodeParams[api.GuestSessionStartParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		session, err := d.StartGuestSession(ctx, params)
		return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
	case "guest.session.get":
		params, err := rpc.DecodeParams[api.GuestSessionRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		session, err := d.GetGuestSession(ctx, params)
		return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
	case "guest.session.list":
		params, err := rpc.DecodeParams[api.VMRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		sessions, err := d.ListGuestSessions(ctx, params)
		return marshalResultOrError(api.GuestSessionListResult{Sessions: sessions}, err)
	case "guest.session.stop":
		params, err := rpc.DecodeParams[api.GuestSessionRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		session, err := d.StopGuestSession(ctx, params)
		return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
	case "guest.session.kill":
		params, err := rpc.DecodeParams[api.GuestSessionRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		session, err := d.KillGuestSession(ctx, params)
		return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err)
	case "guest.session.logs":
		params, err := rpc.DecodeParams[api.GuestSessionLogsParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		result, err := d.GuestSessionLogs(ctx, params)
		return marshalResultOrError(result, err)
	case "guest.session.attach.begin":
		params, err := rpc.DecodeParams[api.GuestSessionAttachBeginParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		result, err := d.BeginGuestSessionAttach(ctx, params)
		return marshalResultOrError(result, err)
	case "guest.session.send":
		params, err := rpc.DecodeParams[api.GuestSessionSendParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		result, err := d.SendToGuestSession(ctx, params)
		return marshalResultOrError(result, err)
	case "image.list":
		images, err := d.store.ListImages(ctx)
		return marshalResultOrError(api.ImageListResult{Images: images}, err)
	case "image.show":
		params, err := rpc.DecodeParams[api.ImageRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		image, err := d.FindImage(ctx, params.IDOrName)
		return marshalResultOrError(api.ImageShowResult{Image: image}, err)
	case "image.build":
		params, err := rpc.DecodeParams[api.ImageBuildParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		image, err := d.BuildImage(ctx, params)
		return marshalResultOrError(api.ImageShowResult{Image: image}, err)
	case "image.build.begin":
		params, err := rpc.DecodeParams[api.ImageBuildParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		op, err := d.BeginImageBuild(ctx, params)
		return marshalResultOrError(api.ImageBuildBeginResult{Operation: op}, err)
	case "image.build.status":
		params, err := rpc.DecodeParams[api.ImageBuildStatusParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		op, err := d.ImageBuildStatus(ctx, params.ID)
		return marshalResultOrError(api.ImageBuildStatusResult{Operation: op}, err)
	case "image.build.cancel":
		params, err := rpc.DecodeParams[api.ImageBuildStatusParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		err = d.CancelImageBuild(ctx, params.ID)
		return marshalResultOrError(api.Empty{}, err)
	case "image.register":
		params, err := rpc.DecodeParams[api.ImageRegisterParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		image, err := d.RegisterImage(ctx, params)
		return marshalResultOrError(api.ImageShowResult{Image: image}, err)
	case "image.promote":
		params, err := rpc.DecodeParams[api.ImageRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		image, err := d.PromoteImage(ctx, params.IDOrName)
		return marshalResultOrError(api.ImageShowResult{Image: image}, err)
	case "image.delete":
		params, err := rpc.DecodeParams[api.ImageRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		image, err := d.DeleteImage(ctx, params.IDOrName)
		return marshalResultOrError(api.ImageShowResult{Image: image}, err)
	case "kernel.list":
		return marshalResultOrError(d.KernelList(ctx))
	case "kernel.show":
		params, err := rpc.DecodeParams[api.KernelRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		entry, err := d.KernelShow(ctx, params.Name)
		return marshalResultOrError(api.KernelShowResult{Entry: entry}, err)
	case "kernel.delete":
		params, err := rpc.DecodeParams[api.KernelRefParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		err = d.KernelDelete(ctx, params.Name)
		return marshalResultOrError(api.Empty{}, err)
	case "kernel.import":
		params, err := rpc.DecodeParams[api.KernelImportParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		entry, err := d.KernelImport(ctx, params)
		return marshalResultOrError(api.KernelShowResult{Entry: entry}, err)
	case "kernel.pull":
		params, err := rpc.DecodeParams[api.KernelPullParams](req)
		if err != nil {
			return rpc.NewError("bad_request", err.Error())
		}
		entry, err := d.KernelPull(ctx, params)
		return marshalResultOrError(api.KernelShowResult{Entry: entry}, err)
	case "kernel.catalog":
		return marshalResultOrError(d.KernelCatalog(ctx))
	default:
		return rpc.NewError("unknown_method", req.Method)
	}
}

// backgroundLoop runs periodic maintenance until d.closing is closed. Every
// config.StatsPollInterval it polls stats. Every model.DefaultStaleSweepInterval
// it stops stale VMs and prunes VM-create and image-build operations, passing
// a cutoff of now minus 10 minutes. Failures are logged, not fatal.
func (d *Daemon) backgroundLoop() {
	// time.NewTicker panics on a non-positive interval, which would take the
	// whole daemon down from this goroutine. Treat such a value as "stats
	// polling disabled" instead: a nil channel is never ready in a select.
	var statsC <-chan time.Time
	if interval := d.config.StatsPollInterval; interval > 0 {
		statsTicker := time.NewTicker(interval)
		defer statsTicker.Stop()
		statsC = statsTicker.C
	} else if d.logger != nil {
		d.logger.Warn("background stats polling disabled", "interval", interval)
	}
	staleTicker := time.NewTicker(model.DefaultStaleSweepInterval)
	defer staleTicker.Stop()
	for {
		select {
		case <-d.closing:
			return
		case <-statsC:
			if err := d.pollStats(context.Background()); err != nil && d.logger != nil {
				d.logger.Error("background stats poll failed", "error", err.Error())
			}
		case <-staleTicker.C:
			if err := d.stopStaleVMs(context.Background()); err != nil && d.logger != nil {
				d.logger.Error("background stale sweep failed", "error", err.Error())
			}
			d.pruneVMCreateOperations(time.Now().Add(-10 * time.Minute))
			d.pruneImageBuildOperations(time.Now().Add(-10 * time.Minute))
		}
	}
}

// startVMDNS creates a VM DNS server on addr and records it on d.
func (d *Daemon) startVMDNS(addr string) error {
	server, err := vmdns.New(addr, d.logger)
	if err != nil {
		return err
	}
	d.vmDNS = server
	if d.logger != nil {
		d.logger.Info("vm dns serving", "dns_addr", server.Addr())
	}
	return nil
}

// stopVMDNS closes the VM DNS server, if any, and clears it. It is a no-op
// when no server is running, so it is safe to call more than once.
func (d *Daemon) stopVMDNS() error {
	if d.vmDNS == nil {
		return nil
	}
	err := d.vmDNS.Close()
	d.vmDNS = nil
	return err
}

// ensureDefaultImage is currently a no-op that always returns nil.
func (d *Daemon) ensureDefaultImage(ctx context.Context) error {
	_ = ctx
	return nil
}

// reconcile repairs persisted state after a restart. Any VM recorded as
// running whose process is no longer alive gets a best-effort runtime
// cleanup, is marked stopped with its runtime handles cleared, and is saved.
// DNS records are then rebuilt.
func (d *Daemon) reconcile(ctx context.Context) error {
	op := d.beginOperation("daemon.reconcile")
	vms, err := d.store.ListVMs(ctx)
	if err != nil {
		return op.fail(err)
	}
	for _, vm := range vms {
		if err := d.withVMLockByIDErr(ctx, vm.ID, func(vm model.VMRecord) error {
			if vm.State != model.VMStateRunning {
				return nil
			}
			if system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) {
				return nil
			}
			op.stage("stale_vm", vmLogAttrs(vm)...)
			// Best effort: the VM is marked stopped whether or not cleanup succeeds.
			_ = d.cleanupRuntime(ctx, vm, true)
			vm.State = model.VMStateStopped
			vm.Runtime.State = model.VMStateStopped
			clearRuntimeHandles(&vm)
			vm.UpdatedAt = model.Now()
			return d.store.UpsertVM(ctx, vm)
		}); err != nil {
			return op.fail(err, "vm_id", vm.ID)
		}
	}
	if err := d.rebuildDNS(ctx); err != nil {
		return op.fail(err)
	}
	op.done()
	return nil
}

// FindVM resolves idOrName to a VM. It tries store.GetVM first, then a unique
// prefix match on VM ID or name. An empty ref, an ambiguous prefix, or no
// match returns an error.
func (d *Daemon) FindVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
	if idOrName == "" {
		return model.VMRecord{}, errors.New("vm id or name is required")
	}
	if vm, err := d.store.GetVM(ctx, idOrName); err == nil {
		return vm, nil
	}
	vms, err := d.store.ListVMs(ctx)
	if err != nil {
		return model.VMRecord{}, err
	}
	matchCount := 0
	var match model.VMRecord
	for _, vm := range vms {
		if strings.HasPrefix(vm.ID, idOrName) || strings.HasPrefix(vm.Name, idOrName) {
			match = vm
			matchCount++
		}
	}
	if matchCount == 1 {
		return match, nil
	}
	if matchCount > 1 {
		return model.VMRecord{}, fmt.Errorf("multiple VMs match %q", idOrName)
	}
	return model.VMRecord{}, fmt.Errorf("vm %q not found", idOrName)
}

// FindImage resolves idOrName to an image. It tries an exact name, then an
// exact ID, then a unique prefix match on image ID or name. An empty ref, an
// ambiguous prefix, or no match returns an error.
func (d *Daemon) FindImage(ctx context.Context, idOrName string) (model.Image, error) {
	if idOrName == "" {
		return model.Image{}, errors.New("image id or name is required")
	}
	if image, err := d.store.GetImageByName(ctx, idOrName); err == nil {
		return image, nil
	}
	if image, err := d.store.GetImageByID(ctx, idOrName); err == nil {
		return image, nil
	}
	images, err := d.store.ListImages(ctx)
	if err != nil {
		return model.Image{}, err
	}
	matchCount := 0
	var match model.Image
	for _, image := range images {
		if strings.HasPrefix(image.ID, idOrName) || strings.HasPrefix(image.Name, idOrName) {
			match = image
			matchCount++
		}
	}
	if matchCount == 1 {
		return match, nil
	}
	if matchCount > 1 {
		return model.Image{}, fmt.Errorf("multiple images match %q", idOrName)
	}
	return model.Image{}, fmt.Errorf("image %q not found", idOrName)
}

// TouchVM resolves idOrName and, under the VM's lock, applies
// system.TouchNow to the record and saves it. It returns the updated record.
func (d *Daemon) TouchVM(ctx context.Context, idOrName string) (model.VMRecord, error) {
	return d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) {
		system.TouchNow(&vm)
		if err := d.store.UpsertVM(ctx, vm); err != nil {
			return model.VMRecord{}, err
		}
		return vm, nil
	})
}

// withVMLockByRef resolves idOrName via FindVM and then runs fn under that
// VM's lock. The record passed to fn is re-read by ID after the lock is taken.
func (d *Daemon) withVMLockByRef(ctx context.Context, idOrName string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) {
	vm, err := d.FindVM(ctx, idOrName)
	if err != nil {
		return model.VMRecord{}, err
	}
	return d.withVMLockByID(ctx, vm.ID, fn)
}

// withVMLockByID takes the per-VM lock for id, loads the record by ID, and
// returns fn's result while still holding the lock. A blank id is rejected,
// and sql.ErrNoRows is reported as "vm ... not found".
func (d *Daemon) withVMLockByID(ctx context.Context, id string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) {
	if strings.TrimSpace(id) == "" {
		return model.VMRecord{}, errors.New("vm id is required")
	}
	unlock := d.lockVMID(id)
	defer unlock()
	vm, err := d.store.GetVMByID(ctx, id)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return model.VMRecord{}, fmt.Errorf("vm %q not found", id)
		}
		return model.VMRecord{}, err
	}
	return fn(vm)
}

// withVMLockByIDErr is withVMLockByID for callbacks that only return an error.
func (d *Daemon) withVMLockByIDErr(ctx context.Context, id string, fn func(model.VMRecord) error) error {
	_, err := d.withVMLockByID(ctx, id, func(vm model.VMRecord) (model.VMRecord, error) {
		if err := fn(vm); err != nil {
			return model.VMRecord{}, err
		}
		return vm, nil
	})
	return err
}

// lockVMID acquires the per-VM lock for id and returns its unlock function.
func (d *Daemon) lockVMID(id string) func() {
	return d.vmLocks.lock(id)
}

// marshalResultOrError converts a handler result into an rpc.Response. A
// non-nil err becomes an "operation_failed" error. A marshalling failure
// becomes "marshal_failed".
func marshalResultOrError(v any, err error) rpc.Response {
	if err != nil {
		return rpc.NewError("operation_failed", err.Error())
	}
	resp, marshalErr := rpc.NewResult(v)
	if marshalErr != nil {
		return rpc.NewError("marshal_failed", marshalErr.Error())
	}
	return resp
}

// exists reports whether os.Stat succeeds for path. Any stat error, not only
// "not exist", yields false.
func exists(path string) bool {
	_, err := os.Stat(path)
	return err == nil
}