package hostnat

import (
	"context"
	"errors"
	"fmt"
	"reflect"
	"strings"
	"testing"
)

// call records a single command invocation observed by fakeRunner.
type call struct {
	sudo bool     // true when issued through RunSudo
	name string   // command name; only populated for Run (RunSudo callers put the command in args[0])
	args []string // defensive copy of the arguments as passed
}

// fakeRunner is a scripted command runner for tests. Every invocation is
// appended to calls so tests can assert on the exact commands issued.
// It does no locking, so it is not safe for concurrent use; each test
// builds its own instance.
type fakeRunner struct {
	calls []call

	// runResp maps "name arg0 arg1 ..." (Run, no sudo) to a scripted
	// (stdout, err) pair. Missing entries return an error.
	runResp map[string]callResp

	// sudoMatcher decides the outcome of a RunSudo call. If nil, all
	// RunSudo calls succeed with nil stdout.
	sudoMatcher func(args []string) ([]byte, error)
}

// callResp is a scripted result for a single Run invocation.
type callResp struct {
	out []byte
	err error
}

// Run records the invocation and returns the response scripted under
// "name args..." in runResp, or an error if no response was scripted.
func (r *fakeRunner) Run(ctx context.Context, name string, args ...string) ([]byte, error) {
	r.calls = append(r.calls, call{name: name, args: append([]string(nil), args...)})
	key := name + " " + strings.Join(args, " ")
	if resp, ok := r.runResp[key]; ok {
		return resp.out, resp.err
	}
	return nil, fmt.Errorf("unexpected Run: %s", key)
}

// RunSudo records the invocation and delegates its outcome to sudoMatcher,
// succeeding with nil output when no matcher is configured.
func (r *fakeRunner) RunSudo(ctx context.Context, args ...string) ([]byte, error) {
	r.calls = append(r.calls, call{sudo: true, args: append([]string(nil), args...)})
	if r.sudoMatcher == nil {
		return nil, nil
	}
	return r.sudoMatcher(args)
}

// newEnsureRunner returns a fakeRunner whose "ip route show default"
// reports eth0 as the uplink and whose RunSudo outcomes are decided by
// sudo (nil means every RunSudo call succeeds).
func newEnsureRunner(sudo func(args []string) ([]byte, error)) *fakeRunner {
	return &fakeRunner{
		runResp: map[string]callResp{
			"ip route show default": {out: []byte("default via 10.0.0.1 dev eth0\n")},
		},
		sudoMatcher: sudo,
	}
}

// findSudoIptables returns the first recorded sudo call whose arguments
// contain "iptables" and every token in want. With no want tokens it
// matches any sudo iptables call.
func findSudoIptables(calls []call, want ...string) (call, bool) {
outer:
	for _, c := range calls {
		if !c.sudo || !contains(c.args, "iptables") {
			continue
		}
		for _, w := range want {
			if !contains(c.args, w) {
				continue outer
			}
		}
		return c, true
	}
	return call{}, false
}

func TestDefaultUplink(t *testing.T) {
	t.Parallel()
	r := &fakeRunner{
		runResp: map[string]callResp{
			"ip route show default": {out: []byte("default via 10.0.0.1 dev wlan0 proto dhcp\n")},
		},
	}
	got, err := DefaultUplink(context.Background(), r)
	if err != nil {
		t.Fatalf("DefaultUplink: %v", err)
	}
	if got != "wlan0" {
		t.Fatalf("got %q, want wlan0", got)
	}
}

func TestDefaultUplinkPropagatesRunError(t *testing.T) {
	t.Parallel()
	r := &fakeRunner{} // no runResp → every Run fails
	if _, err := DefaultUplink(context.Background(), r); err == nil {
		t.Fatal("expected error from DefaultUplink when Run fails")
	}
}

func TestRuleKey(t *testing.T) {
	t.Parallel()
	rule := Rule{Table: "nat", Chain: "POSTROUTING", Args: []string{"-s", "172.16.0.5/32"}}
	key := RuleKey(rule)
	for _, part := range []string{"nat", "POSTROUTING", "172.16.0.5/32"} {
		if !strings.Contains(key, part) {
			t.Fatalf("key %q missing expected part %q", key, part)
		}
	}
	// Different args → different key.
	other := Rule{Table: "nat", Chain: "POSTROUTING", Args: []string{"-s", "10.0.0.5/32"}}
	if RuleKey(rule) == RuleKey(other) {
		t.Fatal("RuleKey should differ for different args")
	}
}

func TestEnsureEnableInstallsRules(t *testing.T) {
	t.Parallel()
	r := newEnsureRunner(func(args []string) ([]byte, error) {
		// Fail cleanly rather than panicking on an empty invocation.
		if len(args) == 0 {
			return nil, errors.New("empty sudo invocation")
		}
		// The first sudo call is sysctl; every subsequent call is an
		// `iptables -C ...` probe followed by `iptables -A ...`, because
		// the probe reports the rule is NOT present.
		switch args[0] {
		case "sysctl":
			return nil, nil
		case "iptables":
			// Fail -C (rule absent) so Ensure issues -A.
			if contains(args, "-C") {
				return nil, errors.New("rule absent")
			}
			return nil, nil
		default:
			return nil, fmt.Errorf("unexpected sudo prefix: %v", args)
		}
	})
	if err := Ensure(context.Background(), r, "172.16.0.5", "tap-x", true); err != nil {
		t.Fatalf("Ensure: %v", err)
	}
	// Expect at least: 1 ip route, 1 sysctl, and for 3 rules: -C + -A = 6 iptables calls.
	if len(r.calls) < 8 {
		t.Fatalf("call count = %d, want >= 8; calls=%+v", len(r.calls), r.calls)
	}
	// First call is ip route; second is sysctl.
	if r.calls[0].name != "ip" {
		t.Errorf("calls[0] = %+v, want ip route", r.calls[0])
	}
	if second := r.calls[1]; !second.sudo || len(second.args) == 0 || second.args[0] != "sysctl" {
		t.Errorf("calls[1] = %+v, want sudo sysctl", second)
	}
	// Somewhere we must have an iptables -A POSTROUTING call.
	if _, ok := findSudoIptables(r.calls, "-A", "POSTROUTING"); !ok {
		t.Fatal("no iptables -A POSTROUTING call observed")
	}
}

func TestEnsureEnableSkipsAppendWhenRulePresent(t *testing.T) {
	t.Parallel()
	// Probe succeeds → Ensure should NOT follow up with -A.
	r := newEnsureRunner(func(args []string) ([]byte, error) { return nil, nil })
	if err := Ensure(context.Background(), r, "172.16.0.5", "tap-x", true); err != nil {
		t.Fatalf("Ensure: %v", err)
	}
	// Guard against a vacuous pass: the rules must actually have been probed.
	if _, ok := findSudoIptables(r.calls, "-C"); !ok {
		t.Fatalf("expected iptables -C probes; calls=%+v", r.calls)
	}
	// No -A iptables calls should have been issued.
	if c, ok := findSudoIptables(r.calls, "-A"); ok {
		t.Fatalf("unexpected -A call with probe success: %+v", c)
	}
}

func TestEnsureDisableRemovesRulesWhenPresent(t *testing.T) {
	t.Parallel()
	// Every probe succeeds → rule is present → -D is issued.
	r := newEnsureRunner(func(args []string) ([]byte, error) { return nil, nil })
	if err := Ensure(context.Background(), r, "172.16.0.5", "tap-x", false); err != nil {
		t.Fatalf("Ensure(disable): %v", err)
	}
	if _, ok := findSudoIptables(r.calls, "-D"); !ok {
		t.Fatal("expected at least one iptables -D call")
	}
	// No sysctl on disable path.
	for _, c := range r.calls {
		if c.sudo && len(c.args) > 0 && c.args[0] == "sysctl" {
			t.Fatal("sysctl should not run on disable path")
		}
	}
}

func TestEnsureDisableSkipsRemovalWhenAbsent(t *testing.T) {
	t.Parallel()
	// Every sudo call fails → every probe reports the rule absent.
	r := newEnsureRunner(func(args []string) ([]byte, error) {
		return nil, errors.New("rule not present")
	})
	if err := Ensure(context.Background(), r, "172.16.0.5", "tap-x", false); err != nil {
		t.Fatalf("Ensure(disable, absent): %v", err)
	}
	// Guard against a vacuous pass: presence must actually have been probed.
	if _, ok := findSudoIptables(r.calls); !ok {
		t.Fatalf("expected iptables probes; calls=%+v", r.calls)
	}
	if c, ok := findSudoIptables(r.calls, "-D"); ok {
		t.Fatalf("unexpected -D with absent rule: %+v", c)
	}
}

func TestEnsurePropagatesUplinkError(t *testing.T) {
	t.Parallel()
	r := &fakeRunner{} // no runResp → ip route fails
	if err := Ensure(context.Background(), r, "172.16.0.5", "tap-x", true); err == nil {
		t.Fatal("expected uplink error to propagate")
	}
}

func TestEnsureValidatesInputs(t *testing.T) {
	t.Parallel()
	r := newEnsureRunner(nil)
	if err := Ensure(context.Background(), r, "", "tap-x", true); err == nil {
		t.Fatal("expected error for empty guestIP")
	}
}

func TestRuleArgsWithoutTable(t *testing.T) {
	t.Parallel()
	// Sanity: RuleArgs should only prepend -t when Table is set.
	bare := Rule{Chain: "FORWARD", Args: []string{"-i", "eth0"}}
	got := RuleArgs("-A", bare)
	want := []string{"-A", "FORWARD", "-i", "eth0"}
	if !reflect.DeepEqual(got, want) {
		t.Fatalf("got %v, want %v", got, want)
	}
}

// contains reports whether target is an element of xs.
func contains(xs []string, target string) bool {
	for _, x := range xs {
		if x == target {
			return true
		}
	}
	return false
}