package whip

import (
	"fmt"
	"net"
	"strings"
	"sync"
	"time"

	"github.com/pion/interceptor"
	"github.com/pion/rtcp"
	"github.com/pion/webrtc/v3"
)

// pionProvider implements WebRTCProvider using pion/webrtc v3.
// It handles full ICE + DTLS-SRTP negotiation and forwards the decrypted
// RTP to the local UDP relay sockets that FFmpeg reads.
type pionProvider struct {
	// mu guards pcs.
	mu sync.Mutex
	// pcs holds the live peer connections, keyed by the value that
	// extractICEUfrag derives from the session's SDP offer.
	pcs map[string]*webrtc.PeerConnection
}

// NewPionProvider returns a WebRTCProvider backed by pion/webrtc v3.
func NewPionProvider() WebRTCProvider {
	return &pionProvider{
		pcs: make(map[string]*webrtc.PeerConnection),
	}
}

// OpenSession implements WebRTCProvider.
//
// It creates a receive-only peer connection (one video and one audio
// transceiver), applies offerSDP, waits for ICE gathering to complete and
// returns the resulting answer SDP. Every RTP packet read from a remote
// track is written to 127.0.0.1:videoPort or 127.0.0.1:audioPort depending
// on the track kind.
//
// The session is torn down when the connection state becomes closed, failed
// or disconnected: the relay sockets are closed, the peer connection is
// removed from the provider and, for failed/disconnected, closed as well.
//
// On error the returned SDP is empty and every resource created so far has
// been released.
func (p *pionProvider) OpenSession(offerSDP string, videoPort, audioPort int) (string, error) {
	m := &webrtc.MediaEngine{}
	if err := m.RegisterDefaultCodecs(); err != nil {
		return "", fmt.Errorf("pion: register codecs: %w", err)
	}
	ir := &interceptor.Registry{}
	if err := webrtc.RegisterDefaultInterceptors(m, ir); err != nil {
		return "", fmt.Errorf("pion: register interceptors: %w", err)
	}
	api := webrtc.NewAPI(webrtc.WithMediaEngine(m), webrtc.WithInterceptorRegistry(ir))

	pc, err := api.NewPeerConnection(webrtc.Configuration{})
	if err != nil {
		return "", fmt.Errorf("pion: new peer connection: %w", err)
	}
	// Until the session is fully wired up, every early return must release
	// the peer connection. Close errors are ignored here because the caller
	// already receives the error that caused the early return.
	established := false
	defer func() {
		if !established {
			_ = pc.Close()
		}
	}()

	recvOnly := webrtc.RTPTransceiverInit{Direction: webrtc.RTPTransceiverDirectionRecvonly}
	if _, err := pc.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, recvOnly); err != nil {
		return "", fmt.Errorf("pion: add video transceiver: %w", err)
	}
	if _, err := pc.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, recvOnly); err != nil {
		return "", fmt.Errorf("pion: add audio transceiver: %w", err)
	}

	offer := webrtc.SessionDescription{Type: webrtc.SDPTypeOffer, SDP: offerSDP}
	if err := pc.SetRemoteDescription(offer); err != nil {
		return "", fmt.Errorf("pion: set remote description: %w", err)
	}
	answer, err := pc.CreateAnswer(nil)
	if err != nil {
		return "", fmt.Errorf("pion: create answer: %w", err)
	}

	// The promise must be obtained before SetLocalDescription starts
	// gathering; the answer is only read once gathering has completed.
	gatherComplete := webrtc.GatheringCompletePromise(pc)
	if err := pc.SetLocalDescription(answer); err != nil {
		return "", fmt.Errorf("pion: set local description: %w", err)
	}
	<-gatherComplete
	finalSDP := pc.LocalDescription().SDP

	videoConn, err := pionDialRelay("video", videoPort)
	if err != nil {
		return "", err
	}
	audioConn, err := pionDialRelay("audio", audioPort)
	if err != nil {
		_ = videoConn.Close()
		return "", err
	}

	pc.OnTrack(func(track *webrtc.TrackRemote, _ *webrtc.RTPReceiver) {
		dst := audioConn
		if track.Kind() == webrtc.RTPCodecTypeVideo {
			dst = videoConn
			// Closing done when this handler returns stops the keyframe
			// requester together with the track's read loop.
			done := make(chan struct{})
			defer close(done)
			go pionRequestKeyframes(pc, track.SSRC(), done)
		}

		buf := make([]byte, 1500)
		for {
			n, _, err := track.Read(buf)
			if err != nil {
				return
			}
			// Forwarding is best-effort: a failed relay write must not end
			// the read loop, so the result is deliberately ignored.
			_, _ = dst.Write(buf[:n])
		}
	})

	ufrag := extractICEUfrag(offerSDP)
	p.mu.Lock()
	p.pcs[ufrag] = pc
	p.mu.Unlock()

	// release runs at most once even though several terminal states can be
	// reported for the same connection.
	var once sync.Once
	release := func() {
		once.Do(func() {
			_ = videoConn.Close()
			_ = audioConn.Close()
			p.mu.Lock()
			// Only drop the entry if it still belongs to this session; a
			// later session may have been registered under the same key.
			if p.pcs[ufrag] == pc {
				delete(p.pcs, ufrag)
			}
			p.mu.Unlock()
		})
	}
	pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
		switch state {
		case webrtc.PeerConnectionStateClosed:
			release()
		case webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateDisconnected:
			release()
			// The relay sockets are gone, so the session cannot carry media
			// any more; close the peer connection instead of leaking it.
			_ = pc.Close()
		}
	})

	established = true
	return finalSDP, nil
}

// pionDialRelay opens a UDP socket connected to 127.0.0.1:port. kind
// ("video" or "audio") is only used to label the returned error.
func pionDialRelay(kind string, port int) (*net.UDPConn, error) {
	addr, err := net.ResolveUDPAddr("udp4", fmt.Sprintf("127.0.0.1:%d", port))
	if err != nil {
		return nil, fmt.Errorf("pion: resolve %s relay addr: %w", kind, err)
	}
	conn, err := net.DialUDP("udp4", nil, addr)
	if err != nil {
		return nil, fmt.Errorf("pion: dial %s relay: %w", kind, err)
	}
	return conn, nil
}

// pionRequestKeyframes sends a PLI for ssrc immediately, so consumers
// (ffprobe, FFmpeg) can determine the video resolution without waiting for
// the next IDR, and then every 2 s to keep keyframes flowing. It returns
// when done is closed or when, after a tick, the peer connection is no
// longer in the connected state.
func pionRequestKeyframes(pc *webrtc.PeerConnection, ssrc webrtc.SSRC, done <-chan struct{}) {
	ticker := time.NewTicker(2 * time.Second)
	defer ticker.Stop()
	for {
		// Best-effort: a PLI that cannot be sent is simply retried on the
		// next tick, so the error is deliberately ignored.
		_ = pc.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{
			MediaSSRC: uint32(ssrc),
		}})

		select {
		case <-done:
			return
		case <-ticker.C:
		}
		if pc.ConnectionState() != webrtc.PeerConnectionStateConnected {
			return
		}
	}
}

// extractICEUfrag returns the value of the first "a=ice-ufrag:" line in sdp.
// Lines may end in "\n" or "\r\n". When no such line exists it falls back to
// the first 32 bytes of sdp, or to all of sdp if it is shorter than that.
func extractICEUfrag(sdp string) string {
	for _, line := range strings.Split(sdp, "\n") {
		line = strings.TrimRight(line, "\r")
		if strings.HasPrefix(line, "a=ice-ufrag:") {
			return strings.TrimPrefix(line, "a=ice-ufrag:")
		}
	}
	if len(sdp) > 32 {
		return sdp[:32]
	}
	return sdp
}