package roothelper import ( "os" "path/filepath" "testing" "banger/internal/daemon/dmsnap" "banger/internal/firecracker" "banger/internal/paths" ) func TestValidateDMDevicePath(t *testing.T) { t.Parallel() for _, tc := range []struct { name string path string ok bool }{ {name: "valid", path: "/dev/mapper/fc-rootfs-test", ok: true}, {name: "wrong_prefix", path: "/dev/mapper/not-banger", ok: false}, {name: "wrong_dir", path: "/tmp/fc-rootfs-test", ok: false}, {name: "relative", path: "fc-rootfs-test", ok: false}, } { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() err := validateDMDevicePath(tc.path) if tc.ok && err != nil { t.Fatalf("validateDMDevicePath(%q) = %v, want nil", tc.path, err) } if !tc.ok && err == nil { t.Fatalf("validateDMDevicePath(%q) succeeded, want error", tc.path) } }) } } func TestValidateFirecrackerPID(t *testing.T) { t.Parallel() if err := validateFirecrackerPID(0); err == nil { t.Fatal("validateFirecrackerPID(0) succeeded, want error") } if err := validateFirecrackerPID(-1); err == nil { t.Fatal("validateFirecrackerPID(-1) succeeded, want error") } // Self pid points at the go test binary, whose cmdline does not // contain "firecracker" — rejection proves the helper would refuse // to kill arbitrary host processes. if err := validateFirecrackerPID(os.Getpid()); err == nil { t.Fatal("validateFirecrackerPID(test pid) succeeded, want error") } // PID 1 is init/systemd on Linux — a juicy target for a compromised // daemon, and definitely not firecracker. Make sure we'd refuse. if err := validateFirecrackerPID(1); err == nil { t.Fatal("validateFirecrackerPID(1) succeeded, want error") } } // TestValidateRootExecutableRejectsSymlink pins the O_NOFOLLOW // guarantee: even if the path string passes a textual check, a symlink // at the leaf is refused before we ever stat the target. func TestValidateRootExecutableRejectsSymlink(t *testing.T) { t.Parallel() dir := t.TempDir() regular := filepath.Join(dir, "real") if err := os.WriteFile(regular, []byte{}, 0o755); err != nil { t.Fatalf("write regular: %v", err) } link := filepath.Join(dir, "link") if err := os.Symlink(regular, link); err != nil { t.Fatalf("symlink: %v", err) } if err := validateRootExecutable(link); err == nil { t.Fatal("validateRootExecutable(symlink) succeeded, want error") } } // TestValidateRootExecutableRejectsNonRootOwned exercises the Fstat // uid check on a file the test user just created: it can't possibly // be uid 0, so the validator must refuse it. This is the regression // guard against the previous os.Stat code path drifting back in. func TestValidateRootExecutableRejectsNonRootOwned(t *testing.T) { t.Parallel() if os.Getuid() == 0 { t.Skip("test runs as root; cannot construct a non-root-owned file in a tempdir we can write") } path := filepath.Join(t.TempDir(), "binary") if err := os.WriteFile(path, []byte{}, 0o755); err != nil { t.Fatalf("write: %v", err) } err := validateRootExecutable(path) if err == nil { t.Fatal("validateRootExecutable(user-owned) succeeded, want error") } if !contains(err.Error(), "root-owned") { t.Fatalf("err = %v, want root-owned rejection", err) } } func TestValidateRootExecutableRejectsGroupWritable(t *testing.T) { t.Parallel() if os.Getuid() == 0 { t.Skip("test runs as root; can't construct a non-root-owned file") } path := filepath.Join(t.TempDir(), "binary") if err := os.WriteFile(path, []byte{}, 0o775); err != nil { t.Fatalf("write: %v", err) } err := validateRootExecutable(path) if err == nil { t.Fatal("validateRootExecutable(group-writable) succeeded, want error") } } // contains is a local substring helper that mirrors strings.Contains // without pulling in the package — kept tiny so the test file's // dependency surface stays close to the thing being tested. func contains(s, sub string) bool { for i := 0; i+len(sub) <= len(s); i++ { if s[i:i+len(sub)] == sub { return true } } return false } func TestValidateLoopDevicePath(t *testing.T) { t.Parallel() for _, tc := range []struct { name string arg string ok bool }{ {name: "loop0", arg: "/dev/loop0", ok: true}, {name: "loop12", arg: "/dev/loop12", ok: true}, {name: "no_index", arg: "/dev/loop", ok: false}, {name: "non_numeric", arg: "/dev/loop-x", ok: false}, {name: "wrong_prefix", arg: "/dev/sda1", ok: false}, {name: "empty", arg: "", ok: false}, } { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() err := validateLoopDevicePath(tc.arg) if tc.ok && err != nil { t.Fatalf("validateLoopDevicePath(%q) = %v, want nil", tc.arg, err) } if !tc.ok && err == nil { t.Fatalf("validateLoopDevicePath(%q) succeeded, want error", tc.arg) } }) } } func TestValidateDMRemoveTarget(t *testing.T) { t.Parallel() for _, tc := range []struct { name string arg string ok bool }{ {name: "dm_name", arg: "fc-rootfs-abc", ok: true}, {name: "dm_device_path", arg: "/dev/mapper/fc-rootfs-abc", ok: true}, {name: "wrong_prefix", arg: "not-banger", ok: false}, {name: "device_wrong_prefix", arg: "/dev/mapper/not-banger", ok: false}, {name: "empty", arg: "", ok: false}, } { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() err := validateDMRemoveTarget(tc.arg) if tc.ok && err != nil { t.Fatalf("validateDMRemoveTarget(%q) = %v, want nil", tc.arg, err) } if !tc.ok && err == nil { t.Fatalf("validateDMRemoveTarget(%q) succeeded, want error", tc.arg) } }) } } func TestValidateDMSnapshotHandles(t *testing.T) { t.Parallel() // Empty handles are tolerated — the dmsnap layer treats every // missing field as a no-op for that step. if err := validateDMSnapshotHandles(dmsnap.Handles{}); err != nil { t.Fatalf("validateDMSnapshotHandles(empty) = %v, want nil", err) } good := dmsnap.Handles{ BaseLoop: "/dev/loop0", COWLoop: "/dev/loop1", DMName: "fc-rootfs-abc", DMDev: "/dev/mapper/fc-rootfs-abc", } if err := validateDMSnapshotHandles(good); err != nil { t.Fatalf("validateDMSnapshotHandles(good) = %v, want nil", err) } for _, tc := range []struct { name string mutate func(dmsnap.Handles) dmsnap.Handles wantErr bool }{ {name: "bad_dm_name", mutate: func(h dmsnap.Handles) dmsnap.Handles { h.DMName = "rogue" return h }, wantErr: true}, {name: "bad_dm_device", mutate: func(h dmsnap.Handles) dmsnap.Handles { h.DMDev = "/dev/sda1" return h }, wantErr: true}, {name: "bad_base_loop", mutate: func(h dmsnap.Handles) dmsnap.Handles { h.BaseLoop = "/dev/sda1" return h }, wantErr: true}, {name: "bad_cow_loop", mutate: func(h dmsnap.Handles) dmsnap.Handles { h.COWLoop = "/etc/shadow" return h }, wantErr: true}, } { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() err := validateDMSnapshotHandles(tc.mutate(good)) if tc.wantErr && err == nil { t.Fatalf("validateDMSnapshotHandles(%s) succeeded, want error", tc.name) } if !tc.wantErr && err != nil { t.Fatalf("validateDMSnapshotHandles(%s) = %v, want nil", tc.name, err) } }) } } func TestValidateLinuxIfaceName(t *testing.T) { t.Parallel() for _, tc := range []struct { name string arg string ok bool }{ {name: "typical_bridge", arg: "br-banger", ok: true}, {name: "uplink", arg: "enp5s0", ok: true}, {name: "max_len", arg: "a234567890abcde", ok: true}, // 15 chars {name: "empty", arg: "", ok: false}, {name: "too_long", arg: "a234567890abcdef", ok: false}, {name: "with_slash", arg: "br/0", ok: false}, {name: "with_space", arg: "br 0", ok: false}, {name: "with_colon", arg: "br:0", ok: false}, {name: "dot", arg: ".", ok: false}, {name: "dotdot", arg: "..", ok: false}, {name: "control_char", arg: "br\x01", ok: false}, } { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() err := validateLinuxIfaceName(tc.arg) if tc.ok && err != nil { t.Fatalf("validateLinuxIfaceName(%q) = %v, want nil", tc.arg, err) } if !tc.ok && err == nil { t.Fatalf("validateLinuxIfaceName(%q) succeeded, want error", tc.arg) } }) } } func TestValidateIPv4(t *testing.T) { t.Parallel() for _, tc := range []struct { name string arg string ok bool }{ {name: "valid", arg: "172.16.0.2", ok: true}, {name: "with_whitespace", arg: " 10.0.0.1 ", ok: true}, {name: "empty", arg: "", ok: false}, {name: "ipv6", arg: "::1", ok: false}, {name: "garbage", arg: "not-an-ip", ok: false}, {name: "with_cidr", arg: "10.0.0.1/24", ok: false}, } { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() err := validateIPv4(tc.arg) if tc.ok && err != nil { t.Fatalf("validateIPv4(%q) = %v, want nil", tc.arg, err) } if !tc.ok && err == nil { t.Fatalf("validateIPv4(%q) succeeded, want error", tc.arg) } }) } } func TestValidateResolverAddr(t *testing.T) { t.Parallel() for _, tc := range []struct { name string arg string ok bool }{ {name: "ipv4", arg: "192.168.1.1", ok: true}, {name: "ipv6", arg: "fe80::1", ok: true}, {name: "empty", arg: "", ok: false}, {name: "garbage", arg: "resolver.example", ok: false}, } { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() err := validateResolverAddr(tc.arg) if tc.ok && err != nil { t.Fatalf("validateResolverAddr(%q) = %v, want nil", tc.arg, err) } if !tc.ok && err == nil { t.Fatalf("validateResolverAddr(%q) succeeded, want error", tc.arg) } }) } } func TestValidateExt4ImagePath(t *testing.T) { t.Parallel() srv := &Server{} stateDir := paths.ResolveSystem().StateDir for _, tc := range []struct { name string arg string ok bool }{ {name: "managed_image", arg: filepath.Join(stateDir, "vms", "abc", "rootfs.ext4"), ok: true}, {name: "managed_dm_device", arg: "/dev/mapper/fc-rootfs-test", ok: true}, {name: "outside_state", arg: "/etc/shadow", ok: false}, {name: "wrong_dm", arg: "/dev/mapper/not-banger", ok: false}, {name: "relative", arg: "rootfs.ext4", ok: false}, {name: "empty", arg: "", ok: false}, } { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() err := srv.validateExt4ImagePath(tc.arg) if tc.ok && err != nil { t.Fatalf("validateExt4ImagePath(%q) = %v, want nil", tc.arg, err) } if !tc.ok && err == nil { t.Fatalf("validateExt4ImagePath(%q) succeeded, want error", tc.arg) } }) } } func TestValidateNotSymlink(t *testing.T) { t.Parallel() dir := t.TempDir() regular := filepath.Join(dir, "real") if err := os.WriteFile(regular, []byte("ok"), 0o600); err != nil { t.Fatalf("write regular: %v", err) } link := filepath.Join(dir, "link") if err := os.Symlink(regular, link); err != nil { t.Fatalf("symlink: %v", err) } if err := validateNotSymlink(regular); err != nil { t.Fatalf("validateNotSymlink(real) = %v, want nil", err) } if err := validateNotSymlink(link); err == nil { t.Fatal("validateNotSymlink(symlink) succeeded, want error") } if err := validateNotSymlink(filepath.Join(dir, "missing")); err == nil { t.Fatal("validateNotSymlink(missing) succeeded, want error") } // Symlink pointing into the system tree is the threat we care about. // A daemon-uid attacker plants this kind of link and hopes the helper // follows it; this test pins the rejection. hostileLink := filepath.Join(dir, "hostile") if err := os.Symlink("/etc/shadow", hostileLink); err != nil { t.Fatalf("symlink: %v", err) } if err := validateNotSymlink(hostileLink); err == nil { t.Fatal("validateNotSymlink(symlink-to-/etc/shadow) succeeded, want error") } } func TestValidateLaunchDrivePathAllowsManagedRootDMDevice(t *testing.T) { t.Parallel() srv := &Server{} if err := srv.validateLaunchDrivePath(firecracker.DriveConfig{ ID: "rootfs", Path: "/dev/mapper/fc-rootfs-test", IsRoot: true, }, "/var/lib/banger"); err != nil { t.Fatalf("validateLaunchDrivePath(root dm) = %v, want nil", err) } if err := srv.validateLaunchDrivePath(firecracker.DriveConfig{ ID: "work", Path: "/dev/mapper/fc-rootfs-test", IsRoot: false, }, "/var/lib/banger"); err == nil { t.Fatal("validateLaunchDrivePath(non-root dm) succeeded, want error") } }