Manage NAT directly from VM records
Fix the Go control plane NAT path now that runtime state lives in the daemon instead of the old repo-local vm.json files. Add a daemon-native NAT helper that derives uplink, guest IP, and TAP rules directly from VMRecord, applies the existing iptables/sysctl behavior idempotently, and removes the broken nat.sh handoff from vm.go. Cover uplink parsing and rule generation with unit tests. Validated with go test ./... and make build; a live verify.sh --nat run installed host rules but stopped on the same guest SSH-readiness issue seen in the plain smoke test on this host.
This commit is contained in:
parent
2539800f5c
commit
171009b30b
3 changed files with 278 additions and 18 deletions
149
internal/daemon/nat.go
Normal file
149
internal/daemon/nat.go
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"banger/internal/model"
|
||||
"banger/internal/system"
|
||||
)
|
||||
|
||||
type natRule struct {
|
||||
table string
|
||||
chain string
|
||||
args []string
|
||||
}
|
||||
|
||||
func (d *Daemon) ensureNAT(ctx context.Context, vm model.VMRecord, enable bool) error {
|
||||
if err := system.RequireCommands(ctx, "iptables", "sysctl"); err != nil {
|
||||
return err
|
||||
}
|
||||
uplink, err := d.defaultUplink(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rules, err := natRulesForVM(vm, uplink)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if enable {
|
||||
if _, err := d.runner.RunSudo(ctx, "sysctl", "-w", "net.ipv4.ip_forward=1"); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if err := d.addNATRule(ctx, rule); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if err := d.removeNATRule(ctx, rule); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Daemon) defaultUplink(ctx context.Context) (string, error) {
|
||||
out, err := d.runner.Run(ctx, "ip", "route", "show", "default")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return parseDefaultUplink(string(out))
|
||||
}
|
||||
|
||||
func parseDefaultUplink(output string) (string, error) {
|
||||
for _, line := range strings.Split(output, "\n") {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) == 0 || fields[0] != "default" {
|
||||
continue
|
||||
}
|
||||
for i := 0; i < len(fields)-1; i++ {
|
||||
if fields[i] == "dev" && fields[i+1] != "" {
|
||||
return fields[i+1], nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", errors.New("failed to detect uplink interface")
|
||||
}
|
||||
|
||||
func natRulesForVM(vm model.VMRecord, uplink string) ([]natRule, error) {
|
||||
guestIP := strings.TrimSpace(vm.Runtime.GuestIP)
|
||||
if guestIP == "" {
|
||||
return nil, errors.New("nat requires a guest IP")
|
||||
}
|
||||
tap := strings.TrimSpace(vm.Runtime.TapDevice)
|
||||
if tap == "" {
|
||||
return nil, errors.New("nat requires a tap device")
|
||||
}
|
||||
uplink = strings.TrimSpace(uplink)
|
||||
if uplink == "" {
|
||||
return nil, errors.New("nat requires an uplink interface")
|
||||
}
|
||||
guestCIDR := guestIP + "/32"
|
||||
return []natRule{
|
||||
{
|
||||
table: "nat",
|
||||
chain: "POSTROUTING",
|
||||
args: []string{"-s", guestCIDR, "-o", uplink, "-j", "MASQUERADE"},
|
||||
},
|
||||
{
|
||||
chain: "FORWARD",
|
||||
args: []string{"-i", tap, "-o", uplink, "-j", "ACCEPT"},
|
||||
},
|
||||
{
|
||||
chain: "FORWARD",
|
||||
args: []string{"-i", uplink, "-o", tap, "-m", "state", "--state", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func natRuleArgs(action string, rule natRule) []string {
|
||||
args := make([]string, 0, len(rule.args)+4)
|
||||
if rule.table != "" {
|
||||
args = append(args, "-t", rule.table)
|
||||
}
|
||||
args = append(args, action, rule.chain)
|
||||
args = append(args, rule.args...)
|
||||
return args
|
||||
}
|
||||
|
||||
func natAddPlan(rules []natRule) [][]string {
|
||||
plan := make([][]string, 0, len(rules)+1)
|
||||
plan = append(plan, []string{"sysctl", "-w", "net.ipv4.ip_forward=1"})
|
||||
for _, rule := range rules {
|
||||
plan = append(plan, natRuleArgs("-A", rule))
|
||||
}
|
||||
return plan
|
||||
}
|
||||
|
||||
func natRemovePlan(rules []natRule) [][]string {
|
||||
plan := make([][]string, 0, len(rules))
|
||||
for _, rule := range rules {
|
||||
plan = append(plan, natRuleArgs("-D", rule))
|
||||
}
|
||||
return plan
|
||||
}
|
||||
|
||||
func (d *Daemon) addNATRule(ctx context.Context, rule natRule) error {
|
||||
if _, err := d.runner.RunSudo(ctx, append([]string{"iptables"}, natRuleArgs("-C", rule)...)...); err == nil {
|
||||
return nil
|
||||
}
|
||||
_, err := d.runner.RunSudo(ctx, append([]string{"iptables"}, natRuleArgs("-A", rule)...)...)
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *Daemon) removeNATRule(ctx context.Context, rule natRule) error {
|
||||
if _, err := d.runner.RunSudo(ctx, append([]string{"iptables"}, natRuleArgs("-C", rule)...)...); err != nil {
|
||||
return nil
|
||||
}
|
||||
_, err := d.runner.RunSudo(ctx, append([]string{"iptables"}, natRuleArgs("-D", rule)...)...)
|
||||
return err
|
||||
}
|
||||
|
||||
func natRuleKey(rule natRule) string {
|
||||
return fmt.Sprintf("%s:%s:%s", rule.table, rule.chain, strings.Join(rule.args, " "))
|
||||
}
|
||||
129
internal/daemon/nat_test.go
Normal file
129
internal/daemon/nat_test.go
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
package daemon
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"banger/internal/model"
|
||||
)
|
||||
|
||||
func TestParseDefaultUplink(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
output := "default via 192.168.1.1 dev enp5s0 proto dhcp src 192.168.1.40 metric 100\n"
|
||||
uplink, err := parseDefaultUplink(output)
|
||||
if err != nil {
|
||||
t.Fatalf("parseDefaultUplink returned error: %v", err)
|
||||
}
|
||||
if uplink != "enp5s0" {
|
||||
t.Fatalf("uplink = %q, want enp5s0", uplink)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDefaultUplinkFailsWithoutRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if _, err := parseDefaultUplink("10.0.0.0/24 dev br-fc proto kernel scope link src 10.0.0.1\n"); err == nil {
|
||||
t.Fatal("expected parseDefaultUplink to fail without a default route")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNATRulesForVM(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
vm := model.VMRecord{
|
||||
Runtime: model.VMRuntime{
|
||||
GuestIP: "172.16.0.8",
|
||||
TapDevice: "tap-fc-abcd1234",
|
||||
},
|
||||
}
|
||||
rules, err := natRulesForVM(vm, "wlan0")
|
||||
if err != nil {
|
||||
t.Fatalf("natRulesForVM returned error: %v", err)
|
||||
}
|
||||
if len(rules) != 3 {
|
||||
t.Fatalf("rule count = %d, want 3", len(rules))
|
||||
}
|
||||
if got, want := natRuleArgs("-A", rules[0]), []string{"-t", "nat", "-A", "POSTROUTING", "-s", "172.16.0.8/32", "-o", "wlan0", "-j", "MASQUERADE"}; !slices.Equal(got, want) {
|
||||
t.Fatalf("postrouting args = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := natRuleArgs("-A", rules[1]), []string{"-A", "FORWARD", "-i", "tap-fc-abcd1234", "-o", "wlan0", "-j", "ACCEPT"}; !slices.Equal(got, want) {
|
||||
t.Fatalf("forward-out args = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := natRuleArgs("-A", rules[2]), []string{"-A", "FORWARD", "-i", "wlan0", "-o", "tap-fc-abcd1234", "-m", "state", "--state", "RELATED,ESTABLISHED", "-j", "ACCEPT"}; !slices.Equal(got, want) {
|
||||
t.Fatalf("forward-in args = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNATRulesForVMRequiresRuntimeData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
vm model.VMRecord
|
||||
uplink string
|
||||
}{
|
||||
{
|
||||
name: "guest ip",
|
||||
vm: model.VMRecord{
|
||||
Runtime: model.VMRuntime{TapDevice: "tap-fc-abcd1234"},
|
||||
},
|
||||
uplink: "eth0",
|
||||
},
|
||||
{
|
||||
name: "tap",
|
||||
vm: model.VMRecord{
|
||||
Runtime: model.VMRuntime{GuestIP: "172.16.0.8"},
|
||||
},
|
||||
uplink: "eth0",
|
||||
},
|
||||
{
|
||||
name: "uplink",
|
||||
vm: model.VMRecord{
|
||||
Runtime: model.VMRuntime{
|
||||
GuestIP: "172.16.0.8",
|
||||
TapDevice: "tap-fc-abcd1234",
|
||||
},
|
||||
},
|
||||
uplink: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if _, err := natRulesForVM(tt.vm, tt.uplink); err == nil {
|
||||
t.Fatalf("expected natRulesForVM to fail for missing %s", tt.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNATPlans(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rules := []natRule{
|
||||
{table: "nat", chain: "POSTROUTING", args: []string{"-s", "172.16.0.8/32", "-o", "eth0", "-j", "MASQUERADE"}},
|
||||
{chain: "FORWARD", args: []string{"-i", "tap-fc-abcd1234", "-o", "eth0", "-j", "ACCEPT"}},
|
||||
}
|
||||
|
||||
addPlan := natAddPlan(rules)
|
||||
if len(addPlan) != 3 {
|
||||
t.Fatalf("addPlan count = %d, want 3", len(addPlan))
|
||||
}
|
||||
if got, want := addPlan[0], []string{"sysctl", "-w", "net.ipv4.ip_forward=1"}; !slices.Equal(got, want) {
|
||||
t.Fatalf("sysctl command = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := addPlan[1], []string{"-t", "nat", "-A", "POSTROUTING", "-s", "172.16.0.8/32", "-o", "eth0", "-j", "MASQUERADE"}; !slices.Equal(got, want) {
|
||||
t.Fatalf("add NAT command = %v, want %v", got, want)
|
||||
}
|
||||
|
||||
removePlan := natRemovePlan(rules)
|
||||
if len(removePlan) != 2 {
|
||||
t.Fatalf("removePlan count = %d, want 2", len(removePlan))
|
||||
}
|
||||
if got, want := removePlan[0], []string{"-t", "nat", "-D", "POSTROUTING", "-s", "172.16.0.8/32", "-o", "eth0", "-j", "MASQUERADE"}; !slices.Equal(got, want) {
|
||||
t.Fatalf("remove NAT command = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
|
@ -5,7 +5,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
|
@ -686,23 +685,6 @@ func (d *Daemon) removeDNS(ctx context.Context, dnsName string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (d *Daemon) ensureNAT(ctx context.Context, vm model.VMRecord, enable bool) error {
|
||||
if d.config.RepoRoot == "" {
|
||||
return errors.New("repo root not detected")
|
||||
}
|
||||
script := filepath.Join(d.config.RepoRoot, "nat.sh")
|
||||
action := "down"
|
||||
if enable {
|
||||
action = "up"
|
||||
}
|
||||
cmd := exec.CommandContext(ctx, "bash", script, action, vm.ID)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Dir = d.config.RepoRoot
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func (d *Daemon) killVMProcess(ctx context.Context, pid int) error {
|
||||
_, err := d.runner.RunSudo(ctx, "kill", "-KILL", strconv.Itoa(pid))
|
||||
return err
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue