package daemon import ( "bufio" "context" "database/sql" "encoding/json" "errors" "fmt" "log/slog" "net" "net/http" "os" "strings" "sync" "time" "banger/internal/api" "banger/internal/buildinfo" "banger/internal/config" "banger/internal/daemon/opstate" "banger/internal/imagecat" "banger/internal/imagepull" "banger/internal/model" "banger/internal/paths" "banger/internal/rpc" "banger/internal/store" "banger/internal/system" "banger/internal/vmdns" ) type Daemon struct { layout paths.Layout config model.DaemonConfig store *store.Store runner system.CommandRunner logger *slog.Logger imageOpsMu sync.Mutex createVMMu sync.Mutex createOps opstate.Registry[*vmCreateOperationState] vmLocks vmLockSet sessions sessionRegistry tapPool tapPool closing chan struct{} once sync.Once pid int listener net.Listener webListener net.Listener webServer *http.Server webURL string vmDNS *vmdns.Server vmCaps []vmCapability pullAndFlatten func(ctx context.Context, ref, cacheDir, destDir string) (imagepull.Metadata, error) finalizePulledRootfs func(ctx context.Context, ext4File string, meta imagepull.Metadata) error bundleFetch func(ctx context.Context, destDir string, entry imagecat.CatEntry) (imagecat.Manifest, error) requestHandler func(context.Context, rpc.Request) rpc.Response guestWaitForSSH func(context.Context, string, string, time.Duration) error guestDial func(context.Context, string, string) (guestSSHClient, error) waitForGuestSessionReady func(context.Context, model.VMRecord, model.GuestSession) (model.GuestSession, error) } func Open(ctx context.Context) (d *Daemon, err error) { layout, err := paths.Resolve() if err != nil { return nil, err } if err := paths.Ensure(layout); err != nil { return nil, err } cfg, err := config.Load(layout) if err != nil { return nil, err } logger, normalizedLevel, err := newDaemonLogger(os.Stderr, cfg.LogLevel) if err != nil { return nil, err } cfg.LogLevel = normalizedLevel db, err := store.Open(layout.DBPath) if err != nil { return nil, err } d = &Daemon{ layout: layout, config: cfg, store: db, runner: system.NewRunner(), logger: logger, closing: make(chan struct{}), pid: os.Getpid(), sessions: newSessionRegistry(), } d.ensureVMSSHClientConfig() d.logger.Info("daemon opened", "socket", layout.SocketPath, "state_dir", layout.StateDir, "log_level", cfg.LogLevel) if err = d.startVMDNS(vmdns.DefaultListenAddr); err != nil { d.logger.Error("daemon open failed", "stage", "start_vm_dns", "error", err.Error()) return nil, err } defer func() { if err != nil { _ = d.stopVMDNS() } }() if err = d.reconcile(ctx); err != nil { d.logger.Error("daemon open failed", "stage", "reconcile", "error", err.Error()) return nil, err } d.ensureVMDNSResolverRouting(ctx) if err = d.initializeTapPool(ctx); err != nil { d.logger.Error("daemon open failed", "stage", "initialize_tap_pool", "error", err.Error()) return nil, err } go d.ensureTapPool(context.Background()) return d, nil } func (d *Daemon) Close() error { var err error d.once.Do(func() { if d.logger != nil { d.logger.Info("daemon closing") } close(d.closing) if d.listener != nil { _ = d.listener.Close() } if d.webServer != nil { _ = d.webServer.Close() } if d.webListener != nil { _ = d.webListener.Close() } err = errors.Join(d.clearVMDNSResolverRouting(context.Background()), d.stopVMDNS(), d.closeGuestSessionControllers(), d.store.Close()) }) return err } func (d *Daemon) Serve(ctx context.Context) error { _ = os.Remove(d.layout.SocketPath) listener, err := net.Listen("unix", d.layout.SocketPath) if err != nil { if d.logger != nil { d.logger.Error("daemon listen failed", "socket", d.layout.SocketPath, "error", err.Error()) } return err } d.listener = listener defer listener.Close() defer os.Remove(d.layout.SocketPath) if err := os.Chmod(d.layout.SocketPath, 0o600); err != nil { return err } if d.logger != nil { d.logger.Info("daemon serving", "socket", d.layout.SocketPath, "pid", d.pid) } if err := d.startWebServer(); err != nil { return err } go d.backgroundLoop() for { conn, err := listener.Accept() if err != nil { select { case <-ctx.Done(): return nil case <-d.closing: return nil default: } if _, ok := err.(net.Error); ok { if d.logger != nil { d.logger.Warn("daemon accept temporary failure", "error", err.Error()) } time.Sleep(100 * time.Millisecond) continue } if d.logger != nil { d.logger.Error("daemon accept failed", "error", err.Error()) } return err } go d.handleConn(conn) } } func (d *Daemon) handleConn(conn net.Conn) { defer conn.Close() reader := bufio.NewReader(conn) var req rpc.Request if err := json.NewDecoder(reader).Decode(&req); err != nil { if d.logger != nil { d.logger.Warn("daemon request decode failed", "remote", conn.RemoteAddr().String(), "error", err.Error()) } _ = json.NewEncoder(conn).Encode(rpc.NewError("bad_request", err.Error())) return } reqCtx, cancel := context.WithCancel(context.Background()) defer cancel() stopWatch := d.watchRequestDisconnect(conn, reader, req.Method, cancel) defer stopWatch() resp := d.dispatch(reqCtx, req) if reqCtx.Err() != nil { return } if err := json.NewEncoder(conn).Encode(resp); err != nil && d.logger != nil { d.logger.Warn("daemon response encode failed", "method", req.Method, "remote", conn.RemoteAddr().String(), "error", err.Error()) } } func (d *Daemon) watchRequestDisconnect(conn net.Conn, reader *bufio.Reader, method string, cancel context.CancelFunc) func() { if conn == nil || reader == nil { return func() {} } done := make(chan struct{}) var once sync.Once go func() { go func() { <-done if deadlineSetter, ok := conn.(interface{ SetReadDeadline(time.Time) error }); ok { _ = deadlineSetter.SetReadDeadline(time.Now()) } }() var buf [1]byte for { _, err := reader.Read(buf[:]) if err == nil { continue } select { case <-done: return default: } if d.logger != nil { d.logger.Info("daemon request canceled", "method", method, "remote", conn.RemoteAddr().String(), "error", err.Error()) } cancel() return } }() return func() { once.Do(func() { close(done) }) } } func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response { if req.Version != rpc.Version { return rpc.NewError("bad_version", fmt.Sprintf("unsupported version %d", req.Version)) } if d.requestHandler != nil { return d.requestHandler(ctx, req) } switch req.Method { case "ping": info := buildinfo.Current() result, _ := rpc.NewResult(api.PingResult{ Status: "ok", PID: d.pid, WebURL: d.webURL, Version: info.Version, Commit: info.Commit, BuiltAt: info.BuiltAt, }) return result case "shutdown": go d.Close() result, _ := rpc.NewResult(api.ShutdownResult{Status: "stopping"}) return result case "vm.create": params, err := rpc.DecodeParams[api.VMCreateParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } vm, err := d.CreateVM(ctx, params) return marshalResultOrError(api.VMShowResult{VM: vm}, err) case "vm.create.begin": params, err := rpc.DecodeParams[api.VMCreateParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } op, err := d.BeginVMCreate(ctx, params) return marshalResultOrError(api.VMCreateBeginResult{Operation: op}, err) case "vm.create.status": params, err := rpc.DecodeParams[api.VMCreateStatusParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } op, err := d.VMCreateStatus(ctx, params.ID) return marshalResultOrError(api.VMCreateStatusResult{Operation: op}, err) case "vm.create.cancel": params, err := rpc.DecodeParams[api.VMCreateStatusParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } err = d.CancelVMCreate(ctx, params.ID) return marshalResultOrError(api.Empty{}, err) case "vm.list": vms, err := d.store.ListVMs(ctx) return marshalResultOrError(api.VMListResult{VMs: vms}, err) case "vm.show": params, err := rpc.DecodeParams[api.VMRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } vm, err := d.FindVM(ctx, params.IDOrName) return marshalResultOrError(api.VMShowResult{VM: vm}, err) case "vm.start": params, err := rpc.DecodeParams[api.VMRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } vm, err := d.StartVM(ctx, params.IDOrName) return marshalResultOrError(api.VMShowResult{VM: vm}, err) case "vm.stop": params, err := rpc.DecodeParams[api.VMRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } vm, err := d.StopVM(ctx, params.IDOrName) return marshalResultOrError(api.VMShowResult{VM: vm}, err) case "vm.kill": params, err := rpc.DecodeParams[api.VMKillParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } vm, err := d.KillVM(ctx, params) return marshalResultOrError(api.VMShowResult{VM: vm}, err) case "vm.restart": params, err := rpc.DecodeParams[api.VMRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } vm, err := d.RestartVM(ctx, params.IDOrName) return marshalResultOrError(api.VMShowResult{VM: vm}, err) case "vm.delete": params, err := rpc.DecodeParams[api.VMRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } vm, err := d.DeleteVM(ctx, params.IDOrName) return marshalResultOrError(api.VMShowResult{VM: vm}, err) case "vm.set": params, err := rpc.DecodeParams[api.VMSetParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } vm, err := d.SetVM(ctx, params) return marshalResultOrError(api.VMShowResult{VM: vm}, err) case "vm.stats": params, err := rpc.DecodeParams[api.VMRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } vm, stats, err := d.GetVMStats(ctx, params.IDOrName) return marshalResultOrError(api.VMStatsResult{VM: vm, Stats: stats}, err) case "vm.logs": params, err := rpc.DecodeParams[api.VMRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } vm, err := d.FindVM(ctx, params.IDOrName) if err != nil { return rpc.NewError("not_found", err.Error()) } return marshalResultOrError(api.VMLogsResult{LogPath: vm.Runtime.LogPath}, nil) case "vm.ssh": params, err := rpc.DecodeParams[api.VMRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } vm, err := d.TouchVM(ctx, params.IDOrName) if err != nil { return rpc.NewError("not_found", err.Error()) } if vm.State != model.VMStateRunning || !system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) { return rpc.NewError("not_running", fmt.Sprintf("vm %s is not running", vm.Name)) } return marshalResultOrError(api.VMSSHResult{Name: vm.Name, GuestIP: vm.Runtime.GuestIP}, nil) case "vm.health": params, err := rpc.DecodeParams[api.VMRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } result, err := d.HealthVM(ctx, params.IDOrName) return marshalResultOrError(result, err) case "vm.ping": params, err := rpc.DecodeParams[api.VMRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } result, err := d.PingVM(ctx, params.IDOrName) return marshalResultOrError(result, err) case "vm.ports": params, err := rpc.DecodeParams[api.VMRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } result, err := d.PortsVM(ctx, params.IDOrName) return marshalResultOrError(result, err) case "vm.workspace.prepare": params, err := rpc.DecodeParams[api.VMWorkspacePrepareParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } workspace, err := d.PrepareVMWorkspace(ctx, params) return marshalResultOrError(api.VMWorkspacePrepareResult{Workspace: workspace}, err) case "vm.workspace.export": params, err := rpc.DecodeParams[api.WorkspaceExportParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } result, err := d.ExportVMWorkspace(ctx, params) return marshalResultOrError(result, err) case "guest.session.start": params, err := rpc.DecodeParams[api.GuestSessionStartParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } session, err := d.StartGuestSession(ctx, params) return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err) case "guest.session.get": params, err := rpc.DecodeParams[api.GuestSessionRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } session, err := d.GetGuestSession(ctx, params) return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err) case "guest.session.list": params, err := rpc.DecodeParams[api.VMRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } sessions, err := d.ListGuestSessions(ctx, params) return marshalResultOrError(api.GuestSessionListResult{Sessions: sessions}, err) case "guest.session.stop": params, err := rpc.DecodeParams[api.GuestSessionRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } session, err := d.StopGuestSession(ctx, params) return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err) case "guest.session.kill": params, err := rpc.DecodeParams[api.GuestSessionRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } session, err := d.KillGuestSession(ctx, params) return marshalResultOrError(api.GuestSessionShowResult{Session: session}, err) case "guest.session.logs": params, err := rpc.DecodeParams[api.GuestSessionLogsParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } result, err := d.GuestSessionLogs(ctx, params) return marshalResultOrError(result, err) case "guest.session.attach.begin": params, err := rpc.DecodeParams[api.GuestSessionAttachBeginParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } result, err := d.BeginGuestSessionAttach(ctx, params) return marshalResultOrError(result, err) case "guest.session.send": params, err := rpc.DecodeParams[api.GuestSessionSendParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } result, err := d.SendToGuestSession(ctx, params) return marshalResultOrError(result, err) case "image.list": images, err := d.store.ListImages(ctx) return marshalResultOrError(api.ImageListResult{Images: images}, err) case "image.show": params, err := rpc.DecodeParams[api.ImageRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } image, err := d.FindImage(ctx, params.IDOrName) return marshalResultOrError(api.ImageShowResult{Image: image}, err) case "image.register": params, err := rpc.DecodeParams[api.ImageRegisterParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } image, err := d.RegisterImage(ctx, params) return marshalResultOrError(api.ImageShowResult{Image: image}, err) case "image.promote": params, err := rpc.DecodeParams[api.ImageRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } image, err := d.PromoteImage(ctx, params.IDOrName) return marshalResultOrError(api.ImageShowResult{Image: image}, err) case "image.delete": params, err := rpc.DecodeParams[api.ImageRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } image, err := d.DeleteImage(ctx, params.IDOrName) return marshalResultOrError(api.ImageShowResult{Image: image}, err) case "image.pull": params, err := rpc.DecodeParams[api.ImagePullParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } image, err := d.PullImage(ctx, params) return marshalResultOrError(api.ImageShowResult{Image: image}, err) case "kernel.list": return marshalResultOrError(d.KernelList(ctx)) case "kernel.show": params, err := rpc.DecodeParams[api.KernelRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } entry, err := d.KernelShow(ctx, params.Name) return marshalResultOrError(api.KernelShowResult{Entry: entry}, err) case "kernel.delete": params, err := rpc.DecodeParams[api.KernelRefParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } err = d.KernelDelete(ctx, params.Name) return marshalResultOrError(api.Empty{}, err) case "kernel.import": params, err := rpc.DecodeParams[api.KernelImportParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } entry, err := d.KernelImport(ctx, params) return marshalResultOrError(api.KernelShowResult{Entry: entry}, err) case "kernel.pull": params, err := rpc.DecodeParams[api.KernelPullParams](req) if err != nil { return rpc.NewError("bad_request", err.Error()) } entry, err := d.KernelPull(ctx, params) return marshalResultOrError(api.KernelShowResult{Entry: entry}, err) case "kernel.catalog": return marshalResultOrError(d.KernelCatalog(ctx)) default: return rpc.NewError("unknown_method", req.Method) } } func (d *Daemon) backgroundLoop() { statsTicker := time.NewTicker(d.config.StatsPollInterval) staleTicker := time.NewTicker(model.DefaultStaleSweepInterval) defer statsTicker.Stop() defer staleTicker.Stop() for { select { case <-d.closing: return case <-statsTicker.C: if err := d.pollStats(context.Background()); err != nil && d.logger != nil { d.logger.Error("background stats poll failed", "error", err.Error()) } case <-staleTicker.C: if err := d.stopStaleVMs(context.Background()); err != nil && d.logger != nil { d.logger.Error("background stale sweep failed", "error", err.Error()) } d.pruneVMCreateOperations(time.Now().Add(-10 * time.Minute)) } } } func (d *Daemon) startVMDNS(addr string) error { server, err := vmdns.New(addr, d.logger) if err != nil { return err } d.vmDNS = server if d.logger != nil { d.logger.Info("vm dns serving", "dns_addr", server.Addr()) } return nil } func (d *Daemon) stopVMDNS() error { if d.vmDNS == nil { return nil } err := d.vmDNS.Close() d.vmDNS = nil return err } func (d *Daemon) ensureDefaultImage(ctx context.Context) error { _ = ctx return nil } func (d *Daemon) reconcile(ctx context.Context) error { op := d.beginOperation("daemon.reconcile") vms, err := d.store.ListVMs(ctx) if err != nil { return op.fail(err) } for _, vm := range vms { if err := d.withVMLockByIDErr(ctx, vm.ID, func(vm model.VMRecord) error { if vm.State != model.VMStateRunning { return nil } if system.ProcessRunning(vm.Runtime.PID, vm.Runtime.APISockPath) { return nil } op.stage("stale_vm", vmLogAttrs(vm)...) _ = d.cleanupRuntime(ctx, vm, true) vm.State = model.VMStateStopped vm.Runtime.State = model.VMStateStopped clearRuntimeHandles(&vm) vm.UpdatedAt = model.Now() return d.store.UpsertVM(ctx, vm) }); err != nil { return op.fail(err, "vm_id", vm.ID) } } if err := d.rebuildDNS(ctx); err != nil { return op.fail(err) } op.done() return nil } func (d *Daemon) FindVM(ctx context.Context, idOrName string) (model.VMRecord, error) { if idOrName == "" { return model.VMRecord{}, errors.New("vm id or name is required") } if vm, err := d.store.GetVM(ctx, idOrName); err == nil { return vm, nil } vms, err := d.store.ListVMs(ctx) if err != nil { return model.VMRecord{}, err } matchCount := 0 var match model.VMRecord for _, vm := range vms { if strings.HasPrefix(vm.ID, idOrName) || strings.HasPrefix(vm.Name, idOrName) { match = vm matchCount++ } } if matchCount == 1 { return match, nil } if matchCount > 1 { return model.VMRecord{}, fmt.Errorf("multiple VMs match %q", idOrName) } return model.VMRecord{}, fmt.Errorf("vm %q not found", idOrName) } func (d *Daemon) FindImage(ctx context.Context, idOrName string) (model.Image, error) { if idOrName == "" { return model.Image{}, errors.New("image id or name is required") } if image, err := d.store.GetImageByName(ctx, idOrName); err == nil { return image, nil } if image, err := d.store.GetImageByID(ctx, idOrName); err == nil { return image, nil } images, err := d.store.ListImages(ctx) if err != nil { return model.Image{}, err } matchCount := 0 var match model.Image for _, image := range images { if strings.HasPrefix(image.ID, idOrName) || strings.HasPrefix(image.Name, idOrName) { match = image matchCount++ } } if matchCount == 1 { return match, nil } if matchCount > 1 { return model.Image{}, fmt.Errorf("multiple images match %q", idOrName) } return model.Image{}, fmt.Errorf("image %q not found", idOrName) } func (d *Daemon) TouchVM(ctx context.Context, idOrName string) (model.VMRecord, error) { return d.withVMLockByRef(ctx, idOrName, func(vm model.VMRecord) (model.VMRecord, error) { system.TouchNow(&vm) if err := d.store.UpsertVM(ctx, vm); err != nil { return model.VMRecord{}, err } return vm, nil }) } func (d *Daemon) withVMLockByRef(ctx context.Context, idOrName string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) { vm, err := d.FindVM(ctx, idOrName) if err != nil { return model.VMRecord{}, err } return d.withVMLockByID(ctx, vm.ID, fn) } func (d *Daemon) withVMLockByID(ctx context.Context, id string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) { if strings.TrimSpace(id) == "" { return model.VMRecord{}, errors.New("vm id is required") } unlock := d.lockVMID(id) defer unlock() vm, err := d.store.GetVMByID(ctx, id) if err != nil { if errors.Is(err, sql.ErrNoRows) { return model.VMRecord{}, fmt.Errorf("vm %q not found", id) } return model.VMRecord{}, err } return fn(vm) } func (d *Daemon) withVMLockByIDErr(ctx context.Context, id string, fn func(model.VMRecord) error) error { _, err := d.withVMLockByID(ctx, id, func(vm model.VMRecord) (model.VMRecord, error) { if err := fn(vm); err != nil { return model.VMRecord{}, err } return vm, nil }) return err } func (d *Daemon) lockVMID(id string) func() { return d.vmLocks.lock(id) } func marshalResultOrError(v any, err error) rpc.Response { if err != nil { return rpc.NewError("operation_failed", err.Error()) } resp, marshalErr := rpc.NewResult(v) if marshalErr != nil { return rpc.NewError("marshal_failed", marshalErr.Error()) } return resp } func exists(path string) bool { _, err := os.Stat(path) return err == nil }