core/whip/rtsp_relay_test.go
Cesar Mendivil c9304b7b63 feat: add RTSP relay support to WHIP server
- Introduced a new RTSP relay server that allows WHIP streams to be served over RTSP.
- Added `RTSPRelayAddr` configuration option to enable the RTSP relay.
- Implemented methods for updating codec parameters and forwarding RTP packets to RTSP clients.
- Enhanced the WHIP server to handle RTSP connections and manage multiple clients.
- Added comprehensive tests for RTSP relay functionality, including OPTIONS, DESCRIBE, SETUP, PLAY, and TEARDOWN requests.
2026-03-17 20:28:06 -07:00

402 lines
12 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package whip
import (
"bufio"
"fmt"
"net"
"strings"
"testing"
"time"
)
// dialRTSP opens a TCP connection to addr, allowing up to 3 seconds for the
// dial, and returns the connection together with a bufio.Reader wrapping it.
// A dial failure aborts the calling test via t.Fatalf.
func dialRTSP(t *testing.T, addr string) (net.Conn, *bufio.Reader) {
	t.Helper()

	const dialTimeout = 3 * time.Second
	c, dialErr := net.DialTimeout("tcp", addr, dialTimeout)
	if dialErr != nil {
		t.Fatalf("RTSP dial %s: %v", addr, dialErr)
	}

	reader := bufio.NewReader(c)
	return c, reader
}
// sendRTSP writes request to conn and reads back one RTSP response through rdr.
//
// It returns the status line with trailing CR/LF removed, and a map of the
// response headers. Header names are lower-cased, names and values are
// whitespace-trimmed, a repeated header keeps the last value seen, and header
// lines without a ':' are skipped. When the response carries a Content-Length
// header, that many body bytes are read and stored under the synthetic key
// "__body__".
//
// The write, each line read and the body read are each bounded by a 3-second
// deadline that is cleared again afterwards. Any I/O error, or a
// Content-Length that is not a non-negative integer, fails the calling test
// via t.Fatalf.
func sendRTSP(t *testing.T, conn net.Conn, rdr *bufio.Reader, request string) (status string, headers map[string]string) {
	t.Helper()

	const ioTimeout = 3 * time.Second

	conn.SetWriteDeadline(time.Now().Add(ioTimeout))
	if _, err := conn.Write([]byte(request)); err != nil {
		t.Fatalf("write RTSP request: %v", err)
	}
	conn.SetWriteDeadline(time.Time{})

	// readLine reads one newline-terminated line under its own deadline and
	// strips the trailing CR/LF. what names the line kind for the failure message.
	readLine := func(what string) string {
		t.Helper()
		conn.SetReadDeadline(time.Now().Add(ioTimeout))
		line, err := rdr.ReadString('\n')
		conn.SetReadDeadline(time.Time{})
		if err != nil {
			t.Fatalf("read RTSP %s: %v", what, err)
		}
		return strings.TrimRight(line, "\r\n")
	}

	headers = make(map[string]string)
	status = readLine("status")

	// Headers run until the first empty line.
	for {
		hdr := readLine("header")
		if hdr == "" {
			break
		}
		idx := strings.IndexByte(hdr, ':')
		if idx < 0 {
			continue
		}
		name := strings.ToLower(strings.TrimSpace(hdr[:idx]))
		headers[name] = strings.TrimSpace(hdr[idx+1:])
	}

	rawLen, hasBody := headers["content-length"]
	if !hasBody {
		return status, headers
	}

	// Validate the length before allocating: an unparsable value used to be
	// silently treated as 0, and a negative one would panic inside make.
	var bodyLen int
	if _, err := fmt.Sscanf(rawLen, "%d", &bodyLen); err != nil {
		t.Fatalf("parse Content-Length %q: %v", rawLen, err)
	}
	if bodyLen < 0 {
		t.Fatalf("negative Content-Length %q", rawLen)
	}

	body := make([]byte, bodyLen)
	conn.SetReadDeadline(time.Now().Add(ioTimeout))
	for got := 0; got < bodyLen; {
		n, err := rdr.Read(body[got:])
		if err != nil {
			t.Fatalf("drain body: %v", err)
		}
		got += n
	}
	conn.SetReadDeadline(time.Time{})
	headers["__body__"] = string(body)

	return status, headers
}
// readInterleaved reads a single interleaved binary frame (RFC 2326 §10.12)
// from rdr: a '$' marker byte, a one-byte channel id, a two-byte big-endian
// length, then that many payload bytes.
//
// It returns the channel id and the payload. The whole read is bounded by one
// 2-second deadline on conn, which is cleared when the function returns. A
// read error or a first byte other than '$' fails the calling test via
// t.Fatalf.
func readInterleaved(t *testing.T, conn net.Conn, rdr *bufio.Reader) (channel int, payload []byte) {
	t.Helper()
	conn.SetReadDeadline(time.Now().Add(2 * time.Second))
	defer conn.SetReadDeadline(time.Time{})

	// nextByte reads one framing byte; what names it in the failure message.
	nextByte := func(what string) byte {
		t.Helper()
		v, err := rdr.ReadByte()
		if err != nil {
			t.Fatalf("readInterleaved: read %s: %v", what, err)
		}
		return v
	}

	if marker := nextByte("'$'"); marker != '$' {
		t.Fatalf("readInterleaved: expected '$', got %02x", marker)
	}
	channel = int(nextByte("channel"))
	sizeHi := nextByte("len hi")
	sizeLo := nextByte("len lo")

	payload = make([]byte, int(sizeHi)<<8|int(sizeLo))
	// Keep reading into the unfilled tail until the payload is complete.
	for unfilled := payload; len(unfilled) > 0; {
		n, err := rdr.Read(unfilled)
		if err != nil {
			t.Fatalf("readInterleaved: read payload: %v", err)
		}
		unfilled = unfilled[n:]
	}
	return channel, payload
}
// freePort asks the OS for an ephemeral TCP port on 127.0.0.1 and returns the
// resulting listen address in "host:port" form. The probing listener is closed
// before the address is handed back, so the port is not reserved for the
// caller. A listen failure aborts the calling test via t.Fatalf.
func freePort(t *testing.T) string {
	t.Helper()

	probe, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("freePort: %v", err)
	}
	defer probe.Close()

	return probe.Addr().String()
}
// TestRTSPRelayOptions starts a relay on a fresh local address, sends one
// OPTIONS request and checks that the status line contains "200" and that the
// Public header mentions OPTIONS, DESCRIBE, SETUP, PLAY and TEARDOWN.
func TestRTSPRelayOptions(t *testing.T) {
	addr := freePort(t)
	relay := newRTSPRelay(addr, nil)
	go relay.ListenAndServe() //nolint:errcheck
	defer relay.Close()
	// Brief pause, presumably so the goroutine above is listening before we dial.
	time.Sleep(30 * time.Millisecond)

	conn, rdr := dialRTSP(t, addr)
	defer conn.Close()

	request := "OPTIONS rtsp://" + addr + "/live/test RTSP/1.0\r\nCSeq: 1\r\n\r\n"
	status, headers := sendRTSP(t, conn, rdr, request)
	if !strings.Contains(status, "200") {
		t.Fatalf("OPTIONS: expected 200, got %q", status)
	}

	advertised := headers["public"]
	required := [...]string{"OPTIONS", "DESCRIBE", "SETUP", "PLAY", "TEARDOWN"}
	for _, method := range required {
		if strings.Contains(advertised, method) {
			continue
		}
		t.Errorf("OPTIONS Public missing %q: %q", method, advertised)
	}
}
// TestRTSPRelayDescribeDefaultCodecs issues OPTIONS then DESCRIBE for a stream
// on which UpdateCodecs was never called, and checks that the reply is typed
// application/sdp and that its body mentions H264, opus and the control
// attributes for track0 and track1. The status lines are not inspected.
func TestRTSPRelayDescribeDefaultCodecs(t *testing.T) {
	addr := freePort(t)
	relay := newRTSPRelay(addr, nil)
	go relay.ListenAndServe() //nolint:errcheck
	defer relay.Close()
	// Brief pause, presumably so the goroutine above is listening before we dial.
	time.Sleep(30 * time.Millisecond)

	conn, rdr := dialRTSP(t, addr)
	defer conn.Close()

	streamURL := "rtsp://" + addr + "/live/mystream"
	sendRTSP(t, conn, rdr, "OPTIONS "+streamURL+" RTSP/1.0\r\nCSeq: 1\r\n\r\n")
	_, headers := sendRTSP(t, conn, rdr,
		"DESCRIBE "+streamURL+" RTSP/1.0\r\nCSeq: 2\r\nAccept: application/sdp\r\n\r\n")

	if contentType := headers["content-type"]; contentType != "application/sdp" {
		t.Fatalf("DESCRIBE: expected content-type application/sdp, got %q", contentType)
	}

	sdp := headers["__body__"]
	expected := []string{"H264", "opus", "a=control:track0", "a=control:track1"}
	for i := range expected {
		if strings.Contains(sdp, expected[i]) {
			continue
		}
		t.Errorf("DESCRIBE SDP missing %q:\n%s", expected[i], sdp)
	}
}
// TestRTSPRelayDescribeUpdatedCodecs registers codec parameters for stream
// "cam1" through UpdateCodecs, then checks that the SDP body returned by
// DESCRIBE contains the fmtp fragment that was supplied.
func TestRTSPRelayDescribeUpdatedCodecs(t *testing.T) {
	addr := freePort(t)
	relay := newRTSPRelay(addr, nil)
	go relay.ListenAndServe() //nolint:errcheck
	defer relay.Close()
	// Brief pause, presumably so the goroutine above is listening before we dial.
	time.Sleep(30 * time.Millisecond)

	params := RTSPCodecParams{
		VideoPayloadType: 96,
		VideoCodec:       "H264",
		VideoClockRate:   90000,
		VideoFmtp:        "sprop-parameter-sets=abc;packetization-mode=1",
		AudioPayloadType: 111,
		AudioCodec:       "opus",
		AudioClockRate:   48000,
		AudioChannels:    2,
	}
	relay.UpdateCodecs("cam1", params)

	conn, rdr := dialRTSP(t, addr)
	defer conn.Close()

	streamURL := "rtsp://" + addr + "/live/cam1"
	sendRTSP(t, conn, rdr, "OPTIONS "+streamURL+" RTSP/1.0\r\nCSeq: 1\r\n\r\n")
	_, headers := sendRTSP(t, conn, rdr,
		"DESCRIBE "+streamURL+" RTSP/1.0\r\nCSeq: 2\r\nAccept: application/sdp\r\n\r\n")

	if sdp := headers["__body__"]; !strings.Contains(sdp, "sprop-parameter-sets=abc") {
		t.Errorf("DESCRIBE SDP should contain updated fmtp:\n%s", sdp)
	}
}
// TestRTSPRelayDataDelivery is the end-to-end test:
//  1. the relay starts and codec parameters are registered for the stream
//  2. an RTSP client connects (OPTIONS -> DESCRIBE -> SETUP video -> SETUP audio -> PLAY)
//  3. WriteVideo and WriteAudio inject fake RTP packets
//  4. the client reads one interleaved frame per packet and checks the channel
//     (0 for video, 2 for audio) and that the payload length equals the
//     injected packet's length
func TestRTSPRelayDataDelivery(t *testing.T) {
	addr := freePort(t)
	relay := newRTSPRelay(addr, nil)
	go relay.ListenAndServe() //nolint:errcheck
	defer relay.Close()
	// Brief pause, presumably so the goroutine above is listening before we dial.
	time.Sleep(30 * time.Millisecond)

	streamName := "obs-test"
	codecs := RTSPCodecParams{
		VideoPayloadType: 96,
		VideoCodec:       "H264",
		VideoClockRate:   90000,
		AudioPayloadType: 111,
		AudioCodec:       "opus",
		AudioClockRate:   48000,
		AudioChannels:    2,
	}
	relay.UpdateCodecs(streamName, codecs)

	base := "rtsp://" + addr + "/live/" + streamName
	conn, rdr := dialRTSP(t, addr)
	defer conn.Close()

	// Each handshake step must produce a status line containing "200".
	// Video is set up on interleaved channels 0-1, audio on 2-3.
	handshake := []struct {
		label   string
		request string
	}{
		{"OPTIONS", fmt.Sprintf("OPTIONS %s RTSP/1.0\r\nCSeq: 1\r\n\r\n", base)},
		{"DESCRIBE", fmt.Sprintf("DESCRIBE %s RTSP/1.0\r\nCSeq: 2\r\nAccept: application/sdp\r\n\r\n", base)},
		{"SETUP video", fmt.Sprintf("SETUP %s/track0 RTSP/1.0\r\nCSeq: 3\r\nTransport: RTP/AVP/TCP;unicast;interleaved=0-1;mode=play\r\n\r\n", base)},
		{"SETUP audio", fmt.Sprintf("SETUP %s/track1 RTSP/1.0\r\nCSeq: 4\r\nTransport: RTP/AVP/TCP;unicast;interleaved=2-3;mode=play\r\n\r\n", base)},
		{"PLAY", fmt.Sprintf("PLAY %s RTSP/1.0\r\nCSeq: 5\r\nSession: ignored\r\n\r\n", base)},
	}
	for _, step := range handshake {
		status, _ := sendRTSP(t, conn, rdr, step.request)
		if !strings.Contains(status, "200") {
			t.Fatalf("%s: %s", step.label, status)
		}
	}

	// Inject one fake RTP packet per media kind and read it back.
	deliveries := []struct {
		kind    string
		channel int
		packet  []byte
		inject  func(pkt []byte)
	}{
		{"video", 0, buildFakeRTP(96, 1000, 90000), func(pkt []byte) { relay.WriteVideo(streamName, pkt) }},
		{"audio", 2, buildFakeRTP(111, 2000, 48000), func(pkt []byte) { relay.WriteAudio(streamName, pkt) }},
	}
	for _, d := range deliveries {
		d.inject(d.packet)
		ch, payload := readInterleaved(t, conn, rdr)
		if ch != d.channel {
			t.Errorf("expected %s on interleaved channel %d, got %d", d.kind, d.channel, ch)
		}
		if len(payload) != len(d.packet) {
			t.Errorf("%s payload length: want %d got %d", d.kind, len(d.packet), len(payload))
		}
	}
}
// TestRTSPRelayMultipleSubscribers connects two clients to the same stream,
// injects one video packet through WriteVideo, and checks that each client
// reads a frame on interleaved channel 0 whose length equals the packet's.
func TestRTSPRelayMultipleSubscribers(t *testing.T) {
	addr := freePort(t)
	relay := newRTSPRelay(addr, nil)
	go relay.ListenAndServe() //nolint:errcheck
	defer relay.Close()
	// Brief pause, presumably so the goroutine above is listening before we dial.
	time.Sleep(30 * time.Millisecond)

	streamName := "multi"
	base := "rtsp://" + addr + "/live/" + streamName

	// subscribe dials the relay and runs OPTIONS, DESCRIBE, SETUP (video only,
	// channels 0-1) and PLAY, failing the test on any status without "200".
	subscribe := func(t *testing.T) (net.Conn, *bufio.Reader) {
		t.Helper()
		c, r := dialRTSP(t, addr)
		handshake := []string{
			fmt.Sprintf("OPTIONS %s RTSP/1.0\r\nCSeq: 1\r\n\r\n", base),
			fmt.Sprintf("DESCRIBE %s RTSP/1.0\r\nCSeq: 2\r\nAccept: application/sdp\r\n\r\n", base),
			fmt.Sprintf("SETUP %s/track0 RTSP/1.0\r\nCSeq: 3\r\nTransport: RTP/AVP/TCP;unicast;interleaved=0-1;mode=play\r\n\r\n", base),
			fmt.Sprintf("PLAY %s RTSP/1.0\r\nCSeq: 4\r\nSession: ignored\r\n\r\n", base),
		}
		for i := range handshake {
			if status, _ := sendRTSP(t, c, r, handshake[i]); !strings.Contains(status, "200") {
				t.Fatalf("setup failed: %s", status)
			}
		}
		return c, r
	}

	firstConn, firstRdr := subscribe(t)
	defer firstConn.Close()
	secondConn, secondRdr := subscribe(t)
	defer secondConn.Close()

	pkt := buildFakeRTP(96, 5000, 90000)
	relay.WriteVideo(streamName, pkt)

	firstCh, firstPayload := readInterleaved(t, firstConn, firstRdr)
	secondCh, secondPayload := readInterleaved(t, secondConn, secondRdr)

	if !(firstCh == 0 && secondCh == 0) {
		t.Errorf("expected channel 0 for both clients, got %d and %d", firstCh, secondCh)
	}
	want := len(pkt)
	if !(len(firstPayload) == want && len(secondPayload) == want) {
		t.Errorf("packet length mismatch: sent %d, got %d and %d", len(pkt), len(firstPayload), len(secondPayload))
	}
}
// TestRTSPRelayTEARDOWN verifies that TEARDOWN is handled gracefully after a
// client has reached the playing state (OPTIONS, DESCRIBE, SETUP, PLAY).
//
// Two outcomes are accepted: the server answers with a status line containing
// "200", or it closes the connection without answering. A read timeout (no
// reply and no close) or a non-200 status fails the test.
func TestRTSPRelayTEARDOWN(t *testing.T) {
	addr := freePort(t)
	relay := newRTSPRelay(addr, nil)
	go relay.ListenAndServe() //nolint:errcheck
	defer relay.Close()
	// Brief pause, presumably so the goroutine above is listening before we dial.
	time.Sleep(30 * time.Millisecond)

	base := "rtsp://" + addr + "/live/tdtest"
	conn, rdr := dialRTSP(t, addr)
	defer conn.Close()

	// step is the zero-based index of the request, not its CSeq.
	for step, req := range []string{
		fmt.Sprintf("OPTIONS %s RTSP/1.0\r\nCSeq: 1\r\n\r\n", base),
		fmt.Sprintf("DESCRIBE %s RTSP/1.0\r\nCSeq: 2\r\nAccept: application/sdp\r\n\r\n", base),
		fmt.Sprintf("SETUP %s/track0 RTSP/1.0\r\nCSeq: 3\r\nTransport: RTP/AVP/TCP;unicast;interleaved=0-1;mode=play\r\n\r\n", base),
		fmt.Sprintf("PLAY %s RTSP/1.0\r\nCSeq: 4\r\nSession: ignored\r\n\r\n", base),
	} {
		status, _ := sendRTSP(t, conn, rdr, req)
		if !strings.Contains(status, "200") {
			t.Fatalf("step %d: %s", step, status)
		}
	}

	// TEARDOWN is sent and answered by hand instead of through sendRTSP:
	// sendRTSP fails the test on any read error, which would reject the
	// acceptable case where the server simply closes the connection.
	const ioTimeout = 3 * time.Second
	conn.SetWriteDeadline(time.Now().Add(ioTimeout))
	_, err := conn.Write([]byte(fmt.Sprintf("TEARDOWN %s RTSP/1.0\r\nCSeq: 5\r\n\r\n", base)))
	conn.SetWriteDeadline(time.Time{})
	if err != nil {
		t.Fatalf("write TEARDOWN: %v", err)
	}

	conn.SetReadDeadline(time.Now().Add(ioTimeout))
	line, err := rdr.ReadString('\n')
	conn.SetReadDeadline(time.Time{})
	if err != nil {
		// A timeout means the server neither replied nor hung up.
		if ne, ok := err.(net.Error); ok && ne.Timeout() {
			t.Fatalf("TEARDOWN: no reply and connection still open: %v", err)
		}
		// Any other read error (EOF, reset) means the server closed the
		// connection after TEARDOWN, which is acceptable.
		t.Logf("TEARDOWN: connection closed without a reply: %v", err)
		return
	}

	if status := strings.TrimRight(line, "\r\n"); !strings.Contains(status, "200") {
		t.Errorf("TEARDOWN: expected 200, got %q", status)
	}
}
// buildFakeRTP returns a 16-byte packet: a 12-byte RTP header followed by the
// fixed 4-byte payload de ad be ef.
//
// pt is masked to its low 7 bits (marker bit left clear), seq and ts are
// written big-endian, and the SSRC is always 1.
func buildFakeRTP(pt uint8, seq uint16, ts uint32) []byte {
	return []byte{
		0x80,      // V=2, P=0, X=0, CC=0
		pt & 0x7f, // M=0, PT
		// Sequence number, high byte first.
		byte(seq >> 8), byte(seq),
		// Timestamp, most significant byte first.
		byte(ts >> 24), byte(ts >> 16), byte(ts >> 8), byte(ts),
		// SSRC = 1
		0, 0, 0, 1,
		// Fake payload.
		0xde, 0xad, 0xbe, 0xef,
	}
}