package sessionstream

import (
	"bytes"
	"errors"
	"io"
	"testing"
)

// TestWriteReadFrameRoundtrip checks that a frame written with WriteFrame is
// read back by ReadFrame with the same channel and payload, and that nothing
// is left unread in the buffer afterwards.
func TestWriteReadFrameRoundtrip(t *testing.T) {
	cases := []struct {
		name    string
		channel byte
		payload []byte
	}{
		{"stdout_bytes", ChannelStdout, []byte("hello world")},
		{"stderr_bytes", ChannelStderr, []byte{0x00, 0xff, 0x7f}},
		{"empty_payload", ChannelStdin, nil},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			var buf bytes.Buffer
			if err := WriteFrame(&buf, tc.channel, tc.payload); err != nil {
				t.Fatalf("WriteFrame: %v", err)
			}
			ch, got, err := ReadFrame(&buf)
			if err != nil {
				t.Fatalf("ReadFrame: %v", err)
			}
			if ch != tc.channel {
				t.Fatalf("channel = %d, want %d", ch, tc.channel)
			}
			// bytes.Equal treats nil and empty slices as equal, so the
			// empty-payload case needs no special handling.
			if !bytes.Equal(got, tc.payload) {
				t.Fatalf("payload = %q, want %q", got, tc.payload)
			}
			// Whatever WriteFrame produced should form exactly one frame.
			if rest := buf.Len(); rest != 0 {
				t.Fatalf("%d unread bytes left after ReadFrame", rest)
			}
		})
	}
}

// shortWriter is an io.Writer that accepts writes until the cumulative number
// of bytes offered exceeds failAfter; from then on every Write returns
// (0, io.ErrShortWrite).
type shortWriter struct {
	failAfter int
	written   int
}

// Write adds len(p) to the running total and fails once it passes failAfter.
func (s *shortWriter) Write(p []byte) (int, error) {
	s.written += len(p)
	if s.written > s.failAfter {
		return 0, io.ErrShortWrite
	}
	return len(p), nil
}

// TestWriteFrameWriterError checks that WriteFrame surfaces an error from the
// underlying writer instead of swallowing it.
func TestWriteFrameWriterError(t *testing.T) {
	// failAfter is smaller than the 5-byte header (1 channel byte + 4 length
	// bytes) used in TestReadFrameTruncated, so writing the frame must fail.
	w := &shortWriter{failAfter: 2}
	err := WriteFrame(w, ChannelStdout, []byte("payload"))
	if err == nil {
		t.Fatal("expected error from short writer")
	}
}

// TestReadFrameTruncated checks that ReadFrame reports an error, rather than
// returning a partial frame, when its input ends inside the header or inside
// the payload.
func TestReadFrameTruncated(t *testing.T) {
	// Only 2 header bytes are available. A nil error, or an error that is not
	// EOF-related, both fail the test. (The previous condition used && and so
	// never inspected the error kind.)
	_, _, err := ReadFrame(bytes.NewReader([]byte{0x02, 0x00}))
	if err == nil || !(errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF)) {
		t.Fatalf("expected EOF-ish error, got %v", err)
	}

	// Header OK, but payload truncated: the length field declares more bytes
	// than the 2 that follow.
	var buf bytes.Buffer
	buf.Write([]byte{ChannelStdout, 0x00, 0x00, 0x00, 0x05})
	buf.Write([]byte("ab"))
	if _, _, err := ReadFrame(&buf); err == nil {
		t.Fatal("expected truncated payload error")
	}
}

// TestControlRoundtrip checks that a ControlMessage written with WriteControl
// is decoded by ReadNextControl with its Type and ExitCode intact.
func TestControlRoundtrip(t *testing.T) {
	code := 42
	msg := ControlMessage{Type: "exit", ExitCode: &code}
	var buf bytes.Buffer
	if err := WriteControl(&buf, msg); err != nil {
		t.Fatalf("WriteControl: %v", err)
	}
	got, err := ReadNextControl(&buf)
	if err != nil {
		t.Fatalf("ReadNextControl: %v", err)
	}
	if got.Type != "exit" {
		t.Fatalf("type = %q, want exit", got.Type)
	}
	// Report the dereferenced value; printing the *int itself shows an address.
	if got.ExitCode == nil {
		t.Fatal("exit_code = nil, want 42")
	}
	if *got.ExitCode != 42 {
		t.Fatalf("exit_code = %d, want 42", *got.ExitCode)
	}
}

// TestReadControlBadJSON checks that ReadControl rejects malformed JSON.
func TestReadControlBadJSON(t *testing.T) {
	if _, err := ReadControl([]byte("{not json")); err == nil {
		t.Fatal("expected JSON error")
	}
}

// TestReadNextControlWrongChannel checks that ReadNextControl refuses a frame
// that arrived on a data channel instead of the control channel.
func TestReadNextControlWrongChannel(t *testing.T) {
	var buf bytes.Buffer
	if err := WriteFrame(&buf, ChannelStdout, []byte("not a control frame")); err != nil {
		t.Fatalf("WriteFrame: %v", err)
	}
	if _, err := ReadNextControl(&buf); err == nil {
		t.Fatal("expected error for non-control channel")
	}
}

// TestFormatConstant pins the wire-format identifier string.
func TestFormatConstant(t *testing.T) {
	if FormatV1 != "stdio_mux_v1" {
		t.Fatalf("FormatV1 = %q, want stdio_mux_v1", FormatV1)
	}
}