guest: tests for fingerprint, shellQuote, tar-entries edge cases, nil receivers

Pure-Go additions (no SSH server fixture): AuthorizedPublicKeyFingerprint,
shellQuote escaping, writeTarEntriesArchive error paths (.., ., missing,
duplicates, blank entries) and symlink handling, StreamSession/Client
nil-receiver safety, WaitForSSH context cancellation.

internal/guest coverage 17.8% -> 47.6%. Total 52.1% -> 52.6%. The
remaining uncovered paths need a real in-process SSH server; skip.
This commit is contained in:
Thales Maciel 2026-04-18 17:47:24 -03:00
parent 18bf89eae9
commit a3cc296523
No known key found for this signature in database
GPG key ID: 33112E6833C34679

View file

@ -0,0 +1,293 @@
package guest
import (
"archive/tar"
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"io"
"net"
"os"
"path/filepath"
"regexp"
"strings"
"testing"
"time"
)
func writeTestKey(t *testing.T) string {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
t.Fatalf("GenerateKey: %v", err)
}
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
})
keyPath := filepath.Join(t.TempDir(), "id_rsa")
if err := os.WriteFile(keyPath, privateKeyPEM, 0o600); err != nil {
t.Fatalf("WriteFile: %v", err)
}
return keyPath
}
func TestAuthorizedPublicKeyFingerprint(t *testing.T) {
t.Parallel()
keyPath := writeTestKey(t)
fp, err := AuthorizedPublicKeyFingerprint(keyPath)
if err != nil {
t.Fatalf("AuthorizedPublicKeyFingerprint: %v", err)
}
if !regexp.MustCompile(`^[0-9a-f]{64}$`).MatchString(fp) {
t.Fatalf("fingerprint = %q, want 64 hex chars", fp)
}
fp2, err := AuthorizedPublicKeyFingerprint(keyPath)
if err != nil {
t.Fatalf("AuthorizedPublicKeyFingerprint (second): %v", err)
}
if fp != fp2 {
t.Fatalf("fingerprint not deterministic: %q vs %q", fp, fp2)
}
}
func TestAuthorizedPublicKeyFingerprintMissingFile(t *testing.T) {
t.Parallel()
_, err := AuthorizedPublicKeyFingerprint(filepath.Join(t.TempDir(), "nope"))
if err == nil {
t.Fatal("expected error for missing key file")
}
}
func TestAuthorizedPublicKeyBadPEM(t *testing.T) {
t.Parallel()
keyPath := filepath.Join(t.TempDir(), "bad")
if err := os.WriteFile(keyPath, []byte("not a private key"), 0o600); err != nil {
t.Fatalf("WriteFile: %v", err)
}
if _, err := AuthorizedPublicKey(keyPath); err == nil {
t.Fatal("expected ParsePrivateKey error")
}
}
func TestShellQuote(t *testing.T) {
t.Parallel()
cases := []struct {
in, want string
}{
{"", "''"},
{"simple", "'simple'"},
{"with space", "'with space'"},
{"it's", `'it'"'"'s'`},
{"a'b'c", `'a'"'"'b'"'"'c'`},
{"/path/to/file", "'/path/to/file'"},
}
for _, tc := range cases {
got := shellQuote(tc.in)
if got != tc.want {
t.Errorf("shellQuote(%q) = %q, want %q", tc.in, got, tc.want)
}
}
}
func TestWriteTarEntriesArchiveRejectsEscape(t *testing.T) {
t.Parallel()
dir := t.TempDir()
var buf bytes.Buffer
err := writeTarEntriesArchive(&buf, dir, []string{"../escape"})
if err == nil {
t.Fatal("expected error for escaping entry")
}
if !strings.Contains(err.Error(), "escapes source dir") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestWriteTarEntriesArchiveRejectsDot(t *testing.T) {
t.Parallel()
dir := t.TempDir()
var buf bytes.Buffer
for _, bad := range []string{".", ".."} {
if err := writeTarEntriesArchive(&buf, dir, []string{bad}); err == nil {
t.Errorf("expected error for entry %q", bad)
}
}
}
func TestWriteTarEntriesArchiveDedupsAndSkipsBlank(t *testing.T) {
t.Parallel()
sourceDir := filepath.Join(t.TempDir(), "repo")
if err := os.MkdirAll(sourceDir, 0o755); err != nil {
t.Fatalf("MkdirAll: %v", err)
}
if err := os.WriteFile(filepath.Join(sourceDir, "a.txt"), []byte("A"), 0o644); err != nil {
t.Fatalf("WriteFile: %v", err)
}
var buf bytes.Buffer
if err := writeTarEntriesArchive(&buf, sourceDir, []string{"a.txt", "a.txt", "", " "}); err != nil {
t.Fatalf("writeTarEntriesArchive: %v", err)
}
tr := tar.NewReader(&buf)
var names []string
for {
h, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
t.Fatalf("tar.Next: %v", err)
}
names = append(names, h.Name)
}
if len(names) != 1 || names[0] != "repo/a.txt" {
t.Fatalf("names = %v, want [repo/a.txt]", names)
}
}
func TestWriteTarEntriesArchiveSymlink(t *testing.T) {
t.Parallel()
sourceDir := filepath.Join(t.TempDir(), "repo")
if err := os.MkdirAll(sourceDir, 0o755); err != nil {
t.Fatalf("MkdirAll: %v", err)
}
if err := os.WriteFile(filepath.Join(sourceDir, "target.txt"), []byte("T"), 0o644); err != nil {
t.Fatalf("WriteFile: %v", err)
}
linkPath := filepath.Join(sourceDir, "link")
if err := os.Symlink("target.txt", linkPath); err != nil {
t.Skipf("symlink unsupported: %v", err)
}
var buf bytes.Buffer
if err := writeTarEntriesArchive(&buf, sourceDir, []string{"link"}); err != nil {
t.Fatalf("writeTarEntriesArchive: %v", err)
}
tr := tar.NewReader(&buf)
h, err := tr.Next()
if err != nil {
t.Fatalf("tar.Next: %v", err)
}
if h.Typeflag != tar.TypeSymlink {
t.Fatalf("typeflag = %v, want TypeSymlink", h.Typeflag)
}
if h.Linkname != "target.txt" {
t.Fatalf("linkname = %q, want target.txt", h.Linkname)
}
}
func TestWriteTarEntriesArchiveMissingPath(t *testing.T) {
t.Parallel()
sourceDir := t.TempDir()
var buf bytes.Buffer
err := writeTarEntriesArchive(&buf, sourceDir, []string{"missing.txt"})
if err == nil {
t.Fatal("expected error for missing entry")
}
}
func TestStreamSessionNilSafe(t *testing.T) {
t.Parallel()
var s *StreamSession
if s.Stdin() != nil || s.Stdout() != nil || s.Stderr() != nil {
t.Fatal("nil StreamSession getters should return nil")
}
if err := s.Wait(); err != nil {
t.Fatalf("nil Wait error: %v", err)
}
if err := s.Close(); err != nil {
t.Fatalf("nil Close error: %v", err)
}
}
func TestClientNilClose(t *testing.T) {
t.Parallel()
var c *Client
if err := c.Close(); err != nil {
t.Fatalf("nil Close error: %v", err)
}
c2 := &Client{}
if err := c2.Close(); err != nil {
t.Fatalf("empty Close error: %v", err)
}
}
func TestClientRunScriptOutputNotConnected(t *testing.T) {
t.Parallel()
var c *Client
if _, err := c.RunScriptOutput(context.Background(), "true"); err == nil {
t.Fatal("expected not-connected error")
}
c2 := &Client{}
if _, err := c2.RunScriptOutput(context.Background(), "true"); err == nil {
t.Fatal("expected not-connected error")
}
}
func TestClientStartCommandNotConnected(t *testing.T) {
t.Parallel()
var c *Client
if _, err := c.StartCommand(context.Background(), "true"); err == nil {
t.Fatal("expected not-connected error")
}
}
func TestClientRunScriptNotConnected(t *testing.T) {
t.Parallel()
var c *Client
if err := c.RunScript(context.Background(), "true", io.Discard); err == nil {
t.Fatal("expected not-connected error")
}
}
// freeAddr grabs a loopback port by listening briefly, then closing. Next
// Dial to it deterministically fails with "connection refused" — no real
// server on the far end, no flakiness from random ports being taken.
func freeAddr(t *testing.T) string {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("net.Listen: %v", err)
}
addr := ln.Addr().String()
if err := ln.Close(); err != nil {
t.Fatalf("Close listener: %v", err)
}
return addr
}
func TestWaitForSSHContextCancel(t *testing.T) {
t.Parallel()
keyPath := writeTestKey(t)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
start := time.Now()
err := WaitForSSH(ctx, freeAddr(t), keyPath, 10*time.Millisecond)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("err = %v, want context.DeadlineExceeded", err)
}
if elapsed := time.Since(start); elapsed > 2*time.Second {
t.Fatalf("took too long: %v", elapsed)
}
}
func TestDialReturnsErrorForBadKey(t *testing.T) {
t.Parallel()
keyPath := filepath.Join(t.TempDir(), "bogus")
if err := os.WriteFile(keyPath, []byte("nope"), 0o600); err != nil {
t.Fatalf("WriteFile: %v", err)
}
_, err := Dial(context.Background(), freeAddr(t), keyPath)
if err == nil {
t.Fatal("expected error for bad key")
}
}