core/whip/rtsp_relay.go
Cesar Mendivil c9304b7b63 feat: add RTSP relay support to WHIP server
- Introduced a new RTSP relay server that allows WHIP streams to be served over RTSP.
- Added `RTSPRelayAddr` configuration option to enable the RTSP relay.
- Implemented methods for updating codec parameters and forwarding RTP packets to RTSP clients.
- Enhanced the WHIP server to handle RTSP connections and manage multiple clients.
- Added comprehensive tests for RTSP relay functionality, including OPTIONS, DESCRIBE, SETUP, PLAY, and TEARDOWN requests.
2026-03-17 20:28:06 -07:00

502 lines
14 KiB
Go

// Package whip provides a minimal RFC 2326 RTSP relay server.
//
// The RTSP relay accepts TCP connections from FFmpeg (or any RTSP consumer)
// and forwards RTP packets from active WHIP publishers using RTP/AVP/TCP
// interleaved mode (RFC 2326 §10.12).
//
// Each stream is served at rtsp://<host>:<port>/live/<name>.
// Multiple consumers may connect to the same stream simultaneously.
// Only RTP/AVP/TCP transport (interleaved binary) is supported.
package whip
import (
"bufio"
"fmt"
"io"
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/datarhei/core/v16/log"
)
// RTSPCodecParams holds codec parameters for a single WHIP stream.
// These are used to build the RTSP DESCRIBE response (SDP).
type RTSPCodecParams struct {
VideoPayloadType int
VideoCodec string // e.g. "H264"
VideoClockRate int
VideoFmtp string // e.g. "sprop-parameter-sets=...;packetization-mode=1"
AudioPayloadType int
AudioCodec string // e.g. "opus"
AudioClockRate int
AudioChannels int
}
// rtspRelay is a minimal TCP RTSP server that relays WHIP streams to multiple
// FFmpeg consumers using RTP/AVP/TCP interleaved mode (RFC 2326 §10.12).
type rtspRelay struct {
addr string
ln net.Listener
streams sync.Map // name -> *rtspStreamEntry
logger log.Logger
done chan struct{}
wg sync.WaitGroup
}
// rtspStreamEntry tracks per-stream codec params and connected consumers.
type rtspStreamEntry struct {
name string
mu sync.RWMutex
codecs RTSPCodecParams
clients map[string]*rtspRelayClient
}
// rtspRelayClient represents one connected RTSP consumer (e.g. an FFmpeg process).
type rtspRelayClient struct {
id string
conn net.Conn
videoTrack int // interleaved channel index for video RTP
audioTrack int // interleaved channel index for audio RTP
mu sync.Mutex
closed bool
}
// newRTSPRelay creates a new relay server bound to addr. Call ListenAndServe to start.
func newRTSPRelay(addr string, logger log.Logger) *rtspRelay {
if logger == nil {
logger = log.New("")
}
return &rtspRelay{
addr: addr,
logger: logger,
done: make(chan struct{}),
}
}
// ListenAndServe starts the RTSP relay and blocks until Close is called.
func (r *rtspRelay) ListenAndServe() error {
ln, err := net.Listen("tcp", r.addr)
if err != nil {
return fmt.Errorf("rtsp relay: listen %q: %w", r.addr, err)
}
r.ln = ln
r.logger.Info().WithField("addr", r.addr).Log("RTSP relay started")
for {
conn, err := ln.Accept()
if err != nil {
select {
case <-r.done:
return nil
default:
r.logger.Error().WithField("error", err).Log("RTSP relay: accept error")
continue
}
}
r.wg.Add(1)
go func() {
defer r.wg.Done()
r.handleConn(conn)
}()
}
}
// Close shuts down the relay and waits for all connections to finish.
func (r *rtspRelay) Close() {
select {
case <-r.done:
default:
close(r.done)
}
if r.ln != nil {
r.ln.Close()
}
r.wg.Wait()
}
// UpdateCodecs updates the codec parameters used in DESCRIBE responses for name.
func (r *rtspRelay) UpdateCodecs(name string, params RTSPCodecParams) {
e := r.getOrCreateStream(name)
e.mu.Lock()
e.codecs = params
e.mu.Unlock()
}
// WriteVideo forwards a video RTP packet to all consumers watching name.
func (r *rtspRelay) WriteVideo(name string, pkt []byte) {
v, ok := r.streams.Load(name)
if !ok {
return
}
e := v.(*rtspStreamEntry)
e.mu.RLock()
clients := make([]*rtspRelayClient, 0, len(e.clients))
for _, c := range e.clients {
clients = append(clients, c)
}
e.mu.RUnlock()
for _, c := range clients {
c.writeInterleaved(c.videoTrack, pkt)
}
}
// WriteAudio forwards an audio RTP packet to all consumers watching name.
func (r *rtspRelay) WriteAudio(name string, pkt []byte) {
v, ok := r.streams.Load(name)
if !ok {
return
}
e := v.(*rtspStreamEntry)
e.mu.RLock()
clients := make([]*rtspRelayClient, 0, len(e.clients))
for _, c := range e.clients {
clients = append(clients, c)
}
e.mu.RUnlock()
for _, c := range clients {
c.writeInterleaved(c.audioTrack, pkt)
}
}
// getOrCreateStream returns the named stream entry, creating it with sensible defaults if needed.
func (r *rtspRelay) getOrCreateStream(name string) *rtspStreamEntry {
e := &rtspStreamEntry{
name: name,
clients: make(map[string]*rtspRelayClient),
codecs: RTSPCodecParams{
VideoPayloadType: 96,
VideoCodec: "H264",
VideoClockRate: 90000,
AudioPayloadType: 111,
AudioCodec: "opus",
AudioClockRate: 48000,
AudioChannels: 2,
},
}
actual, _ := r.streams.LoadOrStore(name, e)
return actual.(*rtspStreamEntry)
}
// writeInterleaved sends one RFC 2326 §10.12 interleaved binary packet over the TCP connection.
// If the write times out (slow consumer), the connection is closed to avoid head-of-line blocking.
func (c *rtspRelayClient) writeInterleaved(channel int, data []byte) {
if len(data) > 65535 {
return
}
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return
}
hdr := [4]byte{'$', byte(channel), byte(len(data) >> 8), byte(len(data))}
c.conn.SetWriteDeadline(time.Now().Add(150 * time.Millisecond))
if _, err := c.conn.Write(hdr[:]); err != nil {
c.closed = true
c.conn.Close()
return
}
if _, err := c.conn.Write(data); err != nil {
c.closed = true
c.conn.Close()
return
}
c.conn.SetWriteDeadline(time.Time{})
}
// handleConn processes one RTSP TCP connection through OPTIONS → DESCRIBE → SETUP → PLAY.
func (r *rtspRelay) handleConn(conn net.Conn) {
defer conn.Close()
rdr := bufio.NewReader(conn)
sessionID := fmt.Sprintf("%016x", time.Now().UnixNano())
var entry *rtspStreamEntry
var client *rtspRelayClient
videoInterleaved := 0
audioInterleaved := 2
for {
conn.SetReadDeadline(time.Now().Add(60 * time.Second))
method, reqURL, headers, err := rtspReadRequest(rdr)
conn.SetReadDeadline(time.Time{})
if err != nil {
return
}
cseq := headers["cseq"]
switch method {
case "OPTIONS":
rtspWriteResponse(conn, 200, "OK", cseq, map[string]string{
"Public": "OPTIONS, DESCRIBE, SETUP, PLAY, TEARDOWN",
}, "")
case "DESCRIBE":
name := rtspExtractName(reqURL)
if name == "" {
rtspWriteResponse(conn, 404, "Not Found", cseq, nil, "")
return
}
entry = r.getOrCreateStream(name)
sdp := entry.buildSDP()
rtspWriteResponse(conn, 200, "OK", cseq, map[string]string{
"Content-Type": "application/sdp",
"Content-Base": reqURL + "/",
"Content-Length": strconv.Itoa(len(sdp)),
}, sdp)
case "SETUP":
transport := headers["transport"]
if !strings.Contains(strings.ToLower(transport), "tcp") {
// Only TCP interleaved supported; decline UDP politely.
rtspWriteResponse(conn, 461, "Unsupported Transport", cseq, nil, "")
return
}
isAudio := strings.HasSuffix(strings.TrimRight(reqURL, "/"), "track1")
ch := rtspParseInterleaved(transport)
if isAudio {
if ch >= 0 {
audioInterleaved = ch
}
rtspWriteResponse(conn, 200, "OK", cseq, map[string]string{
"Session": sessionID + ";timeout=60",
"Transport": fmt.Sprintf("RTP/AVP/TCP;unicast;interleaved=%d-%d;mode=play", audioInterleaved, audioInterleaved+1),
}, "")
} else {
if ch >= 0 {
videoInterleaved = ch
}
rtspWriteResponse(conn, 200, "OK", cseq, map[string]string{
"Session": sessionID + ";timeout=60",
"Transport": fmt.Sprintf("RTP/AVP/TCP;unicast;interleaved=%d-%d;mode=play", videoInterleaved, videoInterleaved+1),
}, "")
}
case "PLAY":
if entry == nil {
rtspWriteResponse(conn, 455, "Method Not Valid in This State", cseq, nil, "")
return
}
client = &rtspRelayClient{
id: sessionID,
conn: conn,
videoTrack: videoInterleaved,
audioTrack: audioInterleaved,
}
entry.mu.Lock()
entry.clients[sessionID] = client
entry.mu.Unlock()
rtspWriteResponse(conn, 200, "OK", cseq, map[string]string{
"Session": sessionID,
}, "")
r.logger.Info().
WithField("stream", entry.name).
WithField("session", sessionID).
Log("RTSP client connected")
// Stay connected: discard interleaved data, handle TEARDOWN.
for {
conn.SetReadDeadline(time.Now().Add(60 * time.Second))
b, readErr := rdr.ReadByte()
conn.SetReadDeadline(time.Time{})
if readErr != nil {
break
}
if b == '$' {
// Discard interleaved binary data sent by the client (e.g. RTCP RR).
var meta [3]byte
if _, err := io.ReadFull(rdr, meta[:]); err != nil {
break
}
dataLen := int(meta[1])<<8 | int(meta[2])
if dataLen > 0 {
if _, err := io.CopyN(io.Discard, rdr, int64(dataLen)); err != nil {
break
}
}
continue
}
rdr.UnreadByte()
m, _, innerHdrs, err := rtspReadRequest(rdr)
if err != nil || m == "TEARDOWN" {
if m == "TEARDOWN" {
rtspWriteResponse(conn, 200, "OK", innerHdrs["cseq"], map[string]string{
"Session": sessionID,
}, "")
}
break
}
}
entry.mu.Lock()
delete(entry.clients, sessionID)
entry.mu.Unlock()
r.logger.Info().
WithField("stream", entry.name).
WithField("session", sessionID).
Log("RTSP client disconnected")
return
case "TEARDOWN":
if entry != nil {
entry.mu.Lock()
delete(entry.clients, sessionID)
entry.mu.Unlock()
}
rtspWriteResponse(conn, 200, "OK", cseq, map[string]string{
"Session": sessionID,
}, "")
return
default:
rtspWriteResponse(conn, 405, "Method Not Allowed", cseq, nil, "")
}
}
}
// buildSDP returns an SDP description for this stream suitable for RTSP DESCRIBE.
func (e *rtspStreamEntry) buildSDP() string {
e.mu.RLock()
c := e.codecs
e.mu.RUnlock()
var sb strings.Builder
sb.WriteString("v=0\r\n")
sb.WriteString("o=- 1 1 IN IP4 127.0.0.1\r\n")
sb.WriteString("s=WHIP Live Stream\r\n")
sb.WriteString("c=IN IP4 127.0.0.1\r\n")
sb.WriteString("t=0 0\r\n")
sb.WriteString(fmt.Sprintf("m=video 0 RTP/AVP %d\r\n", c.VideoPayloadType))
sb.WriteString(fmt.Sprintf("a=rtpmap:%d %s/%d\r\n", c.VideoPayloadType, c.VideoCodec, c.VideoClockRate))
if c.VideoFmtp != "" {
sb.WriteString(fmt.Sprintf("a=fmtp:%d %s\r\n", c.VideoPayloadType, c.VideoFmtp))
} else if strings.EqualFold(c.VideoCodec, "H264") {
sb.WriteString(fmt.Sprintf("a=fmtp:%d packetization-mode=1\r\n", c.VideoPayloadType))
}
sb.WriteString("a=control:track0\r\n")
sb.WriteString(fmt.Sprintf("m=audio 0 RTP/AVP %d\r\n", c.AudioPayloadType))
if c.AudioChannels > 1 {
sb.WriteString(fmt.Sprintf("a=rtpmap:%d %s/%d/%d\r\n", c.AudioPayloadType, c.AudioCodec, c.AudioClockRate, c.AudioChannels))
} else {
sb.WriteString(fmt.Sprintf("a=rtpmap:%d %s/%d\r\n", c.AudioPayloadType, c.AudioCodec, c.AudioClockRate))
}
sb.WriteString("a=control:track1\r\n")
return sb.String()
}
// rtspReadRequest reads one RTSP request (request line + headers) from rdr.
// Returns method, URL, lowercase-keyed headers map, and any error.
func rtspReadRequest(rdr *bufio.Reader) (method, reqURL string, headers map[string]string, err error) {
headers = make(map[string]string)
// First line: "METHOD URL RTSP/1.0"
line, err := rdr.ReadString('\n')
if err != nil {
return
}
line = strings.TrimRight(line, "\r\n")
parts := strings.SplitN(line, " ", 3)
if len(parts) < 2 {
err = fmt.Errorf("invalid RTSP request line: %q", line)
return
}
method = strings.ToUpper(parts[0])
reqURL = parts[1]
// Headers until blank line.
for {
line, err = rdr.ReadString('\n')
if err != nil {
return
}
line = strings.TrimRight(line, "\r\n")
if line == "" {
break
}
idx := strings.IndexByte(line, ':')
if idx < 0 {
continue
}
key := strings.ToLower(strings.TrimSpace(line[:idx]))
val := strings.TrimSpace(line[idx+1:])
headers[key] = val
}
return
}
// rtspWriteResponse writes an RTSP/1.0 response to conn.
func rtspWriteResponse(conn net.Conn, code int, reason, cseq string, extra map[string]string, body string) {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("RTSP/1.0 %d %s\r\n", code, reason))
if cseq != "" {
sb.WriteString("CSeq: " + cseq + "\r\n")
}
sb.WriteString("Server: datarhei-core-whip/1.0\r\n")
sb.WriteString(fmt.Sprintf("Date: %s\r\n", time.Now().UTC().Format("Mon, 02 Jan 2006 15:04:05 GMT")))
for k, v := range extra {
sb.WriteString(k + ": " + v + "\r\n")
}
sb.WriteString("\r\n")
if body != "" {
sb.WriteString(body)
}
conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
conn.Write([]byte(sb.String())) //nolint:errcheck
conn.SetWriteDeadline(time.Time{})
}
// rtspExtractName extracts the stream name from an RTSP URL like
// rtsp://127.0.0.1:8554/live/mystream or rtsp://host/live/mystream/track0.
func rtspExtractName(rtspURL string) string {
u := rtspURL
// Strip scheme + host.
if idx := strings.Index(u, "://"); idx >= 0 {
u = u[idx+3:]
}
if slash := strings.Index(u, "/"); slash >= 0 {
u = u[slash+1:]
}
// Expect "live/<name>[/trackN]"
if strings.HasPrefix(u, "live/") {
u = u[5:]
}
// Strip any trailing "/trackN" suffix.
if idx := strings.LastIndex(u, "/"); idx >= 0 {
tail := u[idx+1:]
if strings.HasPrefix(tail, "track") {
u = u[:idx]
}
}
// Strip query string.
if idx := strings.IndexByte(u, '?'); idx >= 0 {
u = u[:idx]
}
return u
}
// rtspParseInterleaved parses the "interleaved=X-Y" field from an RTSP Transport
// header and returns X (the RTP channel index), or -1 if not present.
func rtspParseInterleaved(transport string) int {
lower := strings.ToLower(transport)
idx := strings.Index(lower, "interleaved=")
if idx < 0 {
return -1
}
rest := transport[idx+12:]
parts := strings.SplitN(rest, "-", 2)
n, err := strconv.Atoi(strings.TrimSpace(parts[0]))
if err != nil {
return -1
}
return n
}