guest: tests for fingerprint, shellQuote, tar-entries edge cases, nil receivers
Pure-Go additions (no SSH server fixture): AuthorizedPublicKeyFingerprint, shellQuote escaping, writeTarEntriesArchive error paths (.., ., missing, duplicates, blank entries) and symlink handling, StreamSession/Client nil-receiver safety, WaitForSSH context cancellation. internal/guest coverage 17.8% -> 47.6%. Total 52.1% -> 52.6%. The remaining uncovered paths need a real in-process SSH server; skip.
This commit is contained in:
parent
18bf89eae9
commit
a3cc296523
1 changed files with 293 additions and 0 deletions
293
internal/guest/ssh_more_test.go
Normal file
293
internal/guest/ssh_more_test.go
Normal file
|
|
@ -0,0 +1,293 @@
|
|||
package guest
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func writeTestKey(t *testing.T) string {
|
||||
t.Helper()
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateKey: %v", err)
|
||||
}
|
||||
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||
})
|
||||
keyPath := filepath.Join(t.TempDir(), "id_rsa")
|
||||
if err := os.WriteFile(keyPath, privateKeyPEM, 0o600); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
return keyPath
|
||||
}
|
||||
|
||||
func TestAuthorizedPublicKeyFingerprint(t *testing.T) {
|
||||
t.Parallel()
|
||||
keyPath := writeTestKey(t)
|
||||
|
||||
fp, err := AuthorizedPublicKeyFingerprint(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("AuthorizedPublicKeyFingerprint: %v", err)
|
||||
}
|
||||
if !regexp.MustCompile(`^[0-9a-f]{64}$`).MatchString(fp) {
|
||||
t.Fatalf("fingerprint = %q, want 64 hex chars", fp)
|
||||
}
|
||||
|
||||
fp2, err := AuthorizedPublicKeyFingerprint(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("AuthorizedPublicKeyFingerprint (second): %v", err)
|
||||
}
|
||||
if fp != fp2 {
|
||||
t.Fatalf("fingerprint not deterministic: %q vs %q", fp, fp2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizedPublicKeyFingerprintMissingFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := AuthorizedPublicKeyFingerprint(filepath.Join(t.TempDir(), "nope"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing key file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizedPublicKeyBadPEM(t *testing.T) {
|
||||
t.Parallel()
|
||||
keyPath := filepath.Join(t.TempDir(), "bad")
|
||||
if err := os.WriteFile(keyPath, []byte("not a private key"), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
if _, err := AuthorizedPublicKey(keyPath); err == nil {
|
||||
t.Fatal("expected ParsePrivateKey error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShellQuote(t *testing.T) {
|
||||
t.Parallel()
|
||||
cases := []struct {
|
||||
in, want string
|
||||
}{
|
||||
{"", "''"},
|
||||
{"simple", "'simple'"},
|
||||
{"with space", "'with space'"},
|
||||
{"it's", `'it'"'"'s'`},
|
||||
{"a'b'c", `'a'"'"'b'"'"'c'`},
|
||||
{"/path/to/file", "'/path/to/file'"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
got := shellQuote(tc.in)
|
||||
if got != tc.want {
|
||||
t.Errorf("shellQuote(%q) = %q, want %q", tc.in, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteTarEntriesArchiveRejectsEscape(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
var buf bytes.Buffer
|
||||
err := writeTarEntriesArchive(&buf, dir, []string{"../escape"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for escaping entry")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "escapes source dir") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteTarEntriesArchiveRejectsDot(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
var buf bytes.Buffer
|
||||
for _, bad := range []string{".", ".."} {
|
||||
if err := writeTarEntriesArchive(&buf, dir, []string{bad}); err == nil {
|
||||
t.Errorf("expected error for entry %q", bad)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteTarEntriesArchiveDedupsAndSkipsBlank(t *testing.T) {
|
||||
t.Parallel()
|
||||
sourceDir := filepath.Join(t.TempDir(), "repo")
|
||||
if err := os.MkdirAll(sourceDir, 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(sourceDir, "a.txt"), []byte("A"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := writeTarEntriesArchive(&buf, sourceDir, []string{"a.txt", "a.txt", "", " "}); err != nil {
|
||||
t.Fatalf("writeTarEntriesArchive: %v", err)
|
||||
}
|
||||
|
||||
tr := tar.NewReader(&buf)
|
||||
var names []string
|
||||
for {
|
||||
h, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("tar.Next: %v", err)
|
||||
}
|
||||
names = append(names, h.Name)
|
||||
}
|
||||
if len(names) != 1 || names[0] != "repo/a.txt" {
|
||||
t.Fatalf("names = %v, want [repo/a.txt]", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteTarEntriesArchiveSymlink(t *testing.T) {
|
||||
t.Parallel()
|
||||
sourceDir := filepath.Join(t.TempDir(), "repo")
|
||||
if err := os.MkdirAll(sourceDir, 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(sourceDir, "target.txt"), []byte("T"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
linkPath := filepath.Join(sourceDir, "link")
|
||||
if err := os.Symlink("target.txt", linkPath); err != nil {
|
||||
t.Skipf("symlink unsupported: %v", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := writeTarEntriesArchive(&buf, sourceDir, []string{"link"}); err != nil {
|
||||
t.Fatalf("writeTarEntriesArchive: %v", err)
|
||||
}
|
||||
|
||||
tr := tar.NewReader(&buf)
|
||||
h, err := tr.Next()
|
||||
if err != nil {
|
||||
t.Fatalf("tar.Next: %v", err)
|
||||
}
|
||||
if h.Typeflag != tar.TypeSymlink {
|
||||
t.Fatalf("typeflag = %v, want TypeSymlink", h.Typeflag)
|
||||
}
|
||||
if h.Linkname != "target.txt" {
|
||||
t.Fatalf("linkname = %q, want target.txt", h.Linkname)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteTarEntriesArchiveMissingPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
sourceDir := t.TempDir()
|
||||
var buf bytes.Buffer
|
||||
err := writeTarEntriesArchive(&buf, sourceDir, []string{"missing.txt"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing entry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamSessionNilSafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
var s *StreamSession
|
||||
if s.Stdin() != nil || s.Stdout() != nil || s.Stderr() != nil {
|
||||
t.Fatal("nil StreamSession getters should return nil")
|
||||
}
|
||||
if err := s.Wait(); err != nil {
|
||||
t.Fatalf("nil Wait error: %v", err)
|
||||
}
|
||||
if err := s.Close(); err != nil {
|
||||
t.Fatalf("nil Close error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientNilClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
var c *Client
|
||||
if err := c.Close(); err != nil {
|
||||
t.Fatalf("nil Close error: %v", err)
|
||||
}
|
||||
c2 := &Client{}
|
||||
if err := c2.Close(); err != nil {
|
||||
t.Fatalf("empty Close error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientRunScriptOutputNotConnected(t *testing.T) {
|
||||
t.Parallel()
|
||||
var c *Client
|
||||
if _, err := c.RunScriptOutput(context.Background(), "true"); err == nil {
|
||||
t.Fatal("expected not-connected error")
|
||||
}
|
||||
c2 := &Client{}
|
||||
if _, err := c2.RunScriptOutput(context.Background(), "true"); err == nil {
|
||||
t.Fatal("expected not-connected error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientStartCommandNotConnected(t *testing.T) {
|
||||
t.Parallel()
|
||||
var c *Client
|
||||
if _, err := c.StartCommand(context.Background(), "true"); err == nil {
|
||||
t.Fatal("expected not-connected error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientRunScriptNotConnected(t *testing.T) {
|
||||
t.Parallel()
|
||||
var c *Client
|
||||
if err := c.RunScript(context.Background(), "true", io.Discard); err == nil {
|
||||
t.Fatal("expected not-connected error")
|
||||
}
|
||||
}
|
||||
|
||||
// freeAddr grabs a loopback port by listening briefly, then closing. Next
|
||||
// Dial to it deterministically fails with "connection refused" — no real
|
||||
// server on the far end, no flakiness from random ports being taken.
|
||||
func freeAddr(t *testing.T) string {
|
||||
t.Helper()
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("net.Listen: %v", err)
|
||||
}
|
||||
addr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
t.Fatalf("Close listener: %v", err)
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
func TestWaitForSSHContextCancel(t *testing.T) {
|
||||
t.Parallel()
|
||||
keyPath := writeTestKey(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
err := WaitForSSH(ctx, freeAddr(t), keyPath, 10*time.Millisecond)
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatalf("err = %v, want context.DeadlineExceeded", err)
|
||||
}
|
||||
if elapsed := time.Since(start); elapsed > 2*time.Second {
|
||||
t.Fatalf("took too long: %v", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialReturnsErrorForBadKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
keyPath := filepath.Join(t.TempDir(), "bogus")
|
||||
if err := os.WriteFile(keyPath, []byte("nope"), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile: %v", err)
|
||||
}
|
||||
_, err := Dial(context.Background(), freeAddr(t), keyPath)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for bad key")
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue