From 5a78ec42303ca046d808742e911a156636a2432b Mon Sep 17 00:00:00 2001 From: RustDesk <71636191+rustdesk@users.noreply.github.com> Date: Mon, 18 May 2026 18:59:36 +0800 Subject: [PATCH] feat: support WebRTC rendezvous signaling --- protos/rendezvous.proto | 14 ++++ src/webrtc.rs | 146 ++++++++++++++++++++++++++++++---------- 2 files changed, 124 insertions(+), 36 deletions(-) diff --git a/protos/rendezvous.proto b/protos/rendezvous.proto index 5c4dc3d1e..49f39944f 100644 --- a/protos/rendezvous.proto +++ b/protos/rendezvous.proto @@ -28,6 +28,7 @@ message PunchHoleRequest { bool force_relay = 8; int32 upnp_port = 9; bytes socket_addr_v6 = 10; + string webrtc_sdp_offer = 11; } message ControlPermissions { @@ -58,6 +59,8 @@ message PunchHole { int32 upnp_port = 6; bytes socket_addr_v6 = 7; ControlPermissions control_permissions = 8; + string webrtc_sdp_offer = 9; + reserved 10; } message TestNatRequest { @@ -84,6 +87,7 @@ message PunchHoleSent { string version = 5; int32 upnp_port = 6; bytes socket_addr_v6 = 7; + string webrtc_sdp_answer = 8; } message RegisterPk { @@ -129,6 +133,7 @@ message PunchHoleResponse { bool is_udp = 9; int32 upnp_port = 10; bytes socket_addr_v6 = 11; + string webrtc_sdp_answer = 12; } message ConfigUpdate { @@ -161,6 +166,7 @@ message RelayResponse { int32 feedback = 9; bytes socket_addr_v6 = 10; int32 upnp_port = 11; + string webrtc_sdp_answer = 12; } message SoftwareUpdate { string url = 1; } @@ -231,6 +237,13 @@ message HttpProxyResponse { string error = 4; } +message IceCandidate { + string id = 1; + bytes socket_addr = 2; + string session_key = 3; + string candidate = 4; +} + message RendezvousMessage { oneof union { RegisterPeer register_peer = 6; @@ -256,5 +269,6 @@ message RendezvousMessage { HealthCheck hc = 26; HttpProxyRequest http_proxy_request = 27; HttpProxyResponse http_proxy_response = 28; + IceCandidate ice_candidate = 29; } } diff --git a/src/webrtc.rs b/src/webrtc.rs index 8f3c410cc..e6056c401 100644 --- a/src/webrtc.rs +++ b/src/webrtc.rs @@ -1,13 +1,14 @@ use std::collections::HashMap; use std::io::{Error, ErrorKind}; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use std::sync::Arc; +use std::sync::{Arc, Mutex as StdMutex}; use std::time::Duration; use webrtc::api::setting_engine::SettingEngine; use webrtc::api::APIBuilder; use webrtc::data_channel::RTCDataChannel; use webrtc::ice::mdns::MulticastDnsMode; +use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit; use webrtc::ice_transport::ice_server::RTCIceServer; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; @@ -18,8 +19,7 @@ use webrtc::peer_connection::RTCPeerConnection; use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; use base64::Engine; use bytes::{Bytes, BytesMut}; -use tokio::sync::watch; -use tokio::sync::Mutex; +use tokio::sync::{mpsc, watch, Mutex}; use tokio::time::timeout; use url::Url; @@ -28,10 +28,19 @@ use crate::protobuf::Message; use crate::sodiumoxide::crypto::secretbox::Key; use crate::ResultType; +#[derive(Clone, Debug, PartialEq, Eq)] +enum WebRTCConnectionState { + Pending, + Open, + Closed(String), +} + pub struct WebRTCStream { pc: Arc, stream: Arc>>, - state_notify: watch::Receiver, + state_notify: watch::Receiver, + local_ice_rx: Arc>>>, + session_key: String, send_timeout: u64, } @@ -59,6 +68,8 @@ impl Clone for WebRTCStream { pc: self.pc.clone(), stream: self.stream.clone(), state_notify: self.state_notify.clone(), + local_ice_rx: self.local_ice_rx.clone(), + session_key: self.session_key.clone(), send_timeout: self.send_timeout, } } @@ -122,7 +133,7 @@ impl WebRTCStream { if sdp_json.is_empty() { return Ok("".to_string()); } - let sdp = serde_json::from_str::(&sdp_json)?; + let sdp = serde_json::from_str::(sdp_json)?; Self::get_key_for_sdp(&sdp) } @@ -243,16 +254,40 @@ impl WebRTCStream { ..Default::default() }; - let (notify_tx, notify_rx) = watch::channel(false); + let (notify_tx, notify_rx) = watch::channel(WebRTCConnectionState::Pending); + let (ice_tx, ice_rx) = mpsc::unbounded_channel::(); // Create a new RTCPeerConnection let pc = Arc::new(api.new_peer_connection(config).await?); + let local_ice_tx = ice_tx.clone(); + pc.on_ice_candidate(Box::new(move |candidate| { + let local_ice_tx = local_ice_tx.clone(); + Box::pin(async move { + let Some(candidate) = candidate else { + return; + }; + match candidate.to_json() { + Ok(candidate) => match serde_json::to_string(&candidate) { + Ok(candidate_json) => { + let _ = local_ice_tx.send(candidate_json); + } + Err(err) => { + log::warn!("failed to serialize local ICE candidate: {}", err); + } + }, + Err(err) => { + log::warn!("failed to convert local ICE candidate to JSON: {}", err); + } + } + }) + })); + let bootstrap_dc = if start_local_offer { let dc_open_notify = notify_tx.clone(); // Create a data channel with label "bootstrap" let dc = pc.create_data_channel("bootstrap", None).await?; dc.on_open(Box::new(move || { log::debug!("Local data channel bootstrap open."); - let _ = dc_open_notify.send(true); + let _ = dc_open_notify.send(WebRTCConnectionState::Open); Box::pin(async {}) })); dc @@ -277,7 +312,7 @@ impl WebRTCStream { *stream_lock = dc.clone(); drop(stream_lock); dc.on_open(Box::new(move || { - let _ = dc_open_notify2.send(true); + let _ = dc_open_notify2.send(WebRTCConnectionState::Open); Box::pin(async {}) })); }) @@ -297,7 +332,9 @@ impl WebRTCStream { RTCPeerConnectionState::Disconnected | RTCPeerConnectionState::Failed | RTCPeerConnectionState::Closed => { - let _ = on_connection_notify.send(true); + let _ = on_connection_notify.send(WebRTCConnectionState::Closed( + s.to_string(), + )); log::debug!("WebRTC session closing due to disconnected"); let _ = stream_for_close2.lock().await.close().await; log::debug!("WebRTC session stream closed"); @@ -339,9 +376,7 @@ impl WebRTCStream { // process offer/answer if start_local_offer { let sdp = pc.create_offer(None).await?; - let mut gather_complete = pc.gathering_complete_promise().await; pc.set_local_description(sdp.clone()).await?; - let _ = gather_complete.recv().await; log::debug!("local offer:\n{}", sdp.sdp); // get local sdp key @@ -351,9 +386,7 @@ impl WebRTCStream { let sdp = serde_json::from_str::(&remote_offer)?; pc.set_remote_description(sdp.clone()).await?; let answer = pc.create_answer(None).await?; - let mut gather_complete = pc.gathering_complete_promise().await; pc.set_local_description(answer).await?; - let _ = gather_complete.recv().await; log::debug!("remote offer:\n{}", sdp.sdp); // get remote sdp key @@ -371,6 +404,8 @@ impl WebRTCStream { pc, stream, state_notify: notify_rx, + local_ice_rx: Arc::new(StdMutex::new(Some(ice_rx))), + session_key: key.clone(), send_timeout: ms_timeout, }; final_lock.insert(key, webrtc_stream.clone()); @@ -397,6 +432,38 @@ impl WebRTCStream { Ok(()) } + #[inline] + pub fn take_local_ice_rx(&self) -> Option> { + self.local_ice_rx.lock().ok().and_then(|mut rx| rx.take()) + } + + #[inline] + pub async fn add_remote_ice_candidate(&self, candidate_json: &str) -> ResultType<()> { + if candidate_json.is_empty() { + return Ok(()); + } + let candidate = serde_json::from_str::(candidate_json)?; + self.pc.add_ice_candidate(candidate).await?; + Ok(()) + } + + #[inline] + pub fn session_key(&self) -> &str { + &self.session_key + } + + pub async fn wait_connected(&mut self, ms: u64) -> ResultType<()> { + if ms > 0 { + match timeout(Duration::from_millis(ms), self.wait_for_connect_result()).await { + Ok(result) => result?, + Err(_) => return Err(anyhow::anyhow!("WebRTC wait_connected timeout")), + } + } else { + self.wait_for_connect_result().await?; + } + Ok(()) + } + #[inline] pub fn set_raw(&mut self) { // not-supported @@ -435,33 +502,30 @@ impl WebRTCStream { } #[inline] - async fn wait_for_connect_result(&mut self) { - if *self.state_notify.borrow() { - return; + async fn wait_for_connect_result(&mut self) -> ResultType<()> { + loop { + match self.state_notify.borrow().clone() { + WebRTCConnectionState::Open => return Ok(()), + WebRTCConnectionState::Closed(reason) => { + return Err(anyhow::anyhow!("WebRTC connection closed: {}", reason)); + } + WebRTCConnectionState::Pending => {} + } + self.state_notify.changed().await?; } - let _ = self.state_notify.changed().await; } pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { - if self.send_timeout > 0 { - match timeout( - Duration::from_millis(self.send_timeout), - self.wait_for_connect_result(), - ) - .await + if let Err(err) = self.wait_connected(self.send_timeout).await { + self.pc.close().await.ok(); + let kind = if err.to_string().contains("deadline") + || err.to_string().contains("timeout") { - Ok(_) => {} - Err(_) => { - self.pc.close().await.ok(); - return Err(Error::new( - ErrorKind::TimedOut, - "WebRTC send wait for connect timeout", - ) - .into()); - } - } - } else { - self.wait_for_connect_result().await; + ErrorKind::TimedOut + } else { + ErrorKind::Other + }; + return Err(Error::new(kind, err.to_string()).into()); } let stream = self.stream.lock().await.clone(); stream.send(&bytes).await?; @@ -470,7 +534,10 @@ impl WebRTCStream { #[inline] pub async fn next(&mut self) -> Option> { - self.wait_for_connect_result().await; + if let Err(err) = self.wait_for_connect_result().await { + self.pc.close().await.ok(); + return Some(Err(Error::new(ErrorKind::Other, err.to_string()))); + } let stream = self.stream.lock().await.clone(); // TODO reuse buffer? @@ -767,4 +834,11 @@ IHR5cCBzcmZseCByYWRkciAwLjAuMC4wIHJwb3J0IDY0MDA4XHJcbmE9ZW5kLW9mLWNhbmRpZGF0ZXNc "connect to an 'answer' webrtc endpoint should error" ); } + + #[tokio::test] + async fn test_webrtc_wait_connected_timeout() { + let mut stream = WebRTCStream::new("", false, 100).await.unwrap(); + let err = stream.wait_connected(10).await.unwrap_err(); + assert!(err.to_string().contains("timeout")); + } }