package guest import ( "archive/tar" "bytes" "context" "crypto/rand" "crypto/rsa" "crypto/x509" "encoding/pem" "errors" "io" "net" "os" "path/filepath" "regexp" "strings" "testing" "time" ) func writeTestKey(t *testing.T) string { t.Helper() privateKey, err := rsa.GenerateKey(rand.Reader, 1024) if err != nil { t.Fatalf("GenerateKey: %v", err) } privateKeyPEM := pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey), }) keyPath := filepath.Join(t.TempDir(), "id_rsa") if err := os.WriteFile(keyPath, privateKeyPEM, 0o600); err != nil { t.Fatalf("WriteFile: %v", err) } return keyPath } func TestAuthorizedPublicKeyFingerprint(t *testing.T) { t.Parallel() keyPath := writeTestKey(t) fp, err := AuthorizedPublicKeyFingerprint(keyPath) if err != nil { t.Fatalf("AuthorizedPublicKeyFingerprint: %v", err) } if !regexp.MustCompile(`^[0-9a-f]{64}$`).MatchString(fp) { t.Fatalf("fingerprint = %q, want 64 hex chars", fp) } fp2, err := AuthorizedPublicKeyFingerprint(keyPath) if err != nil { t.Fatalf("AuthorizedPublicKeyFingerprint (second): %v", err) } if fp != fp2 { t.Fatalf("fingerprint not deterministic: %q vs %q", fp, fp2) } } func TestAuthorizedPublicKeyFingerprintMissingFile(t *testing.T) { t.Parallel() _, err := AuthorizedPublicKeyFingerprint(filepath.Join(t.TempDir(), "nope")) if err == nil { t.Fatal("expected error for missing key file") } } func TestAuthorizedPublicKeyBadPEM(t *testing.T) { t.Parallel() keyPath := filepath.Join(t.TempDir(), "bad") if err := os.WriteFile(keyPath, []byte("not a private key"), 0o600); err != nil { t.Fatalf("WriteFile: %v", err) } if _, err := AuthorizedPublicKey(keyPath); err == nil { t.Fatal("expected ParsePrivateKey error") } } func TestShellQuote(t *testing.T) { t.Parallel() cases := []struct { in, want string }{ {"", "''"}, {"simple", "'simple'"}, {"with space", "'with space'"}, {"it's", `'it'"'"'s'`}, {"a'b'c", `'a'"'"'b'"'"'c'`}, {"/path/to/file", "'/path/to/file'"}, } for _, tc := range cases { got := shellQuote(tc.in) if got != tc.want { t.Errorf("shellQuote(%q) = %q, want %q", tc.in, got, tc.want) } } } func TestWriteTarEntriesArchiveRejectsEscape(t *testing.T) { t.Parallel() dir := t.TempDir() var buf bytes.Buffer err := writeTarEntriesArchive(&buf, dir, []string{"../escape"}) if err == nil { t.Fatal("expected error for escaping entry") } if !strings.Contains(err.Error(), "escapes source dir") { t.Fatalf("unexpected error: %v", err) } } func TestWriteTarEntriesArchiveRejectsDot(t *testing.T) { t.Parallel() dir := t.TempDir() var buf bytes.Buffer for _, bad := range []string{".", ".."} { if err := writeTarEntriesArchive(&buf, dir, []string{bad}); err == nil { t.Errorf("expected error for entry %q", bad) } } } func TestWriteTarEntriesArchiveDedupsAndSkipsBlank(t *testing.T) { t.Parallel() sourceDir := filepath.Join(t.TempDir(), "repo") if err := os.MkdirAll(sourceDir, 0o755); err != nil { t.Fatalf("MkdirAll: %v", err) } if err := os.WriteFile(filepath.Join(sourceDir, "a.txt"), []byte("A"), 0o644); err != nil { t.Fatalf("WriteFile: %v", err) } var buf bytes.Buffer if err := writeTarEntriesArchive(&buf, sourceDir, []string{"a.txt", "a.txt", "", " "}); err != nil { t.Fatalf("writeTarEntriesArchive: %v", err) } tr := tar.NewReader(&buf) var names []string for { h, err := tr.Next() if err == io.EOF { break } if err != nil { t.Fatalf("tar.Next: %v", err) } names = append(names, h.Name) } if len(names) != 1 || names[0] != "repo/a.txt" { t.Fatalf("names = %v, want [repo/a.txt]", names) } } func TestWriteTarEntriesArchiveSymlink(t *testing.T) { t.Parallel() sourceDir := filepath.Join(t.TempDir(), "repo") if err := os.MkdirAll(sourceDir, 0o755); err != nil { t.Fatalf("MkdirAll: %v", err) } if err := os.WriteFile(filepath.Join(sourceDir, "target.txt"), []byte("T"), 0o644); err != nil { t.Fatalf("WriteFile: %v", err) } linkPath := filepath.Join(sourceDir, "link") if err := os.Symlink("target.txt", linkPath); err != nil { t.Skipf("symlink unsupported: %v", err) } var buf bytes.Buffer if err := writeTarEntriesArchive(&buf, sourceDir, []string{"link"}); err != nil { t.Fatalf("writeTarEntriesArchive: %v", err) } tr := tar.NewReader(&buf) h, err := tr.Next() if err != nil { t.Fatalf("tar.Next: %v", err) } if h.Typeflag != tar.TypeSymlink { t.Fatalf("typeflag = %v, want TypeSymlink", h.Typeflag) } if h.Linkname != "target.txt" { t.Fatalf("linkname = %q, want target.txt", h.Linkname) } } func TestWriteTarEntriesArchiveMissingPath(t *testing.T) { t.Parallel() sourceDir := t.TempDir() var buf bytes.Buffer err := writeTarEntriesArchive(&buf, sourceDir, []string{"missing.txt"}) if err == nil { t.Fatal("expected error for missing entry") } } func TestStreamSessionNilSafe(t *testing.T) { t.Parallel() var s *StreamSession if s.Stdin() != nil || s.Stdout() != nil || s.Stderr() != nil { t.Fatal("nil StreamSession getters should return nil") } if err := s.Wait(); err != nil { t.Fatalf("nil Wait error: %v", err) } if err := s.Close(); err != nil { t.Fatalf("nil Close error: %v", err) } } func TestClientNilClose(t *testing.T) { t.Parallel() var c *Client if err := c.Close(); err != nil { t.Fatalf("nil Close error: %v", err) } c2 := &Client{} if err := c2.Close(); err != nil { t.Fatalf("empty Close error: %v", err) } } func TestClientRunScriptOutputNotConnected(t *testing.T) { t.Parallel() var c *Client if _, err := c.RunScriptOutput(context.Background(), "true"); err == nil { t.Fatal("expected not-connected error") } c2 := &Client{} if _, err := c2.RunScriptOutput(context.Background(), "true"); err == nil { t.Fatal("expected not-connected error") } } func TestClientStartCommandNotConnected(t *testing.T) { t.Parallel() var c *Client if _, err := c.StartCommand(context.Background(), "true"); err == nil { t.Fatal("expected not-connected error") } } func TestClientRunScriptNotConnected(t *testing.T) { t.Parallel() var c *Client if err := c.RunScript(context.Background(), "true", io.Discard); err == nil { t.Fatal("expected not-connected error") } } // freeAddr grabs a loopback port by listening briefly, then closing. Next // Dial to it deterministically fails with "connection refused" — no real // server on the far end, no flakiness from random ports being taken. func freeAddr(t *testing.T) string { t.Helper() ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("net.Listen: %v", err) } addr := ln.Addr().String() if err := ln.Close(); err != nil { t.Fatalf("Close listener: %v", err) } return addr } func TestWaitForSSHContextCancel(t *testing.T) { t.Parallel() keyPath := writeTestKey(t) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() start := time.Now() err := WaitForSSH(ctx, freeAddr(t), keyPath, "", 10*time.Millisecond) if !errors.Is(err, context.DeadlineExceeded) { t.Fatalf("err = %v, want context.DeadlineExceeded", err) } if elapsed := time.Since(start); elapsed > 2*time.Second { t.Fatalf("took too long: %v", elapsed) } } func TestDialReturnsErrorForBadKey(t *testing.T) { t.Parallel() keyPath := filepath.Join(t.TempDir(), "bogus") if err := os.WriteFile(keyPath, []byte("nope"), 0o600); err != nil { t.Fatalf("WriteFile: %v", err) } _, err := Dial(context.Background(), freeAddr(t), keyPath, "") if err == nil { t.Fatal("expected error for bad key") } }