package hostnat import ( "context" "errors" "fmt" "strings" "banger/internal/system" ) type Rule struct { Table string Chain string Args []string } func Ensure(ctx context.Context, runner system.CommandRunner, guestIP, tapDevice string, enable bool) error { uplink, err := DefaultUplink(ctx, runner) if err != nil { return err } rules, err := Rules(guestIP, tapDevice, uplink) if err != nil { return err } if enable { if _, err := runner.RunSudo(ctx, "sysctl", "-w", "net.ipv4.ip_forward=1"); err != nil { return err } for _, rule := range rules { if err := addRule(ctx, runner, rule); err != nil { return err } } return nil } for _, rule := range rules { if err := removeRule(ctx, runner, rule); err != nil { return err } } return nil } func DefaultUplink(ctx context.Context, runner system.CommandRunner) (string, error) { out, err := runner.Run(ctx, "ip", "route", "show", "default") if err != nil { return "", err } return ParseDefaultUplink(string(out)) } func ParseDefaultUplink(output string) (string, error) { for _, line := range strings.Split(output, "\n") { fields := strings.Fields(line) if len(fields) == 0 || fields[0] != "default" { continue } for i := 0; i < len(fields)-1; i++ { if fields[i] == "dev" && fields[i+1] != "" { return fields[i+1], nil } } } return "", errors.New("failed to detect uplink interface") } func Rules(guestIP, tapDevice, uplink string) ([]Rule, error) { guestIP = strings.TrimSpace(guestIP) if guestIP == "" { return nil, errors.New("nat requires a guest IP") } tapDevice = strings.TrimSpace(tapDevice) if tapDevice == "" { return nil, errors.New("nat requires a tap device") } uplink = strings.TrimSpace(uplink) if uplink == "" { return nil, errors.New("nat requires an uplink interface") } guestCIDR := guestIP + "/32" return []Rule{ { Table: "nat", Chain: "POSTROUTING", Args: []string{"-s", guestCIDR, "-o", uplink, "-j", "MASQUERADE"}, }, { Chain: "FORWARD", Args: []string{"-i", tapDevice, "-o", uplink, "-j", "ACCEPT"}, }, { Chain: "FORWARD", Args: []string{"-i", uplink, "-o", tapDevice, "-m", "state", "--state", "RELATED,ESTABLISHED", "-j", "ACCEPT"}, }, }, nil } func RuleArgs(action string, rule Rule) []string { args := make([]string, 0, len(rule.Args)+4) if rule.Table != "" { args = append(args, "-t", rule.Table) } args = append(args, action, rule.Chain) args = append(args, rule.Args...) return args } func AddPlan(rules []Rule) [][]string { plan := make([][]string, 0, len(rules)+1) plan = append(plan, []string{"sysctl", "-w", "net.ipv4.ip_forward=1"}) for _, rule := range rules { plan = append(plan, RuleArgs("-A", rule)) } return plan } func RemovePlan(rules []Rule) [][]string { plan := make([][]string, 0, len(rules)) for _, rule := range rules { plan = append(plan, RuleArgs("-D", rule)) } return plan } func RuleKey(rule Rule) string { return fmt.Sprintf("%s:%s:%s", rule.Table, rule.Chain, strings.Join(rule.Args, " ")) } func addRule(ctx context.Context, runner system.CommandRunner, rule Rule) error { if _, err := runner.RunSudo(ctx, append([]string{"iptables"}, RuleArgs("-C", rule)...)...); err == nil { return nil } _, err := runner.RunSudo(ctx, append([]string{"iptables"}, RuleArgs("-A", rule)...)...) return err } func removeRule(ctx context.Context, runner system.CommandRunner, rule Rule) error { if _, err := runner.RunSudo(ctx, append([]string{"iptables"}, RuleArgs("-C", rule)...)...); err != nil { return nil } _, err := runner.RunSudo(ctx, append([]string{"iptables"}, RuleArgs("-D", rule)...)...) return err }