feat: support WebRTC rendezvous signaling

This commit is contained in:
RustDesk
2026-05-18 18:59:36 +08:00
parent 9043c15acc
commit 5a78ec4230
2 changed files with 124 additions and 36 deletions
+14
View File
@@ -28,6 +28,7 @@ message PunchHoleRequest {
bool force_relay = 8; bool force_relay = 8;
int32 upnp_port = 9; int32 upnp_port = 9;
bytes socket_addr_v6 = 10; bytes socket_addr_v6 = 10;
string webrtc_sdp_offer = 11;
} }
message ControlPermissions { message ControlPermissions {
@@ -58,6 +59,8 @@ message PunchHole {
int32 upnp_port = 6; int32 upnp_port = 6;
bytes socket_addr_v6 = 7; bytes socket_addr_v6 = 7;
ControlPermissions control_permissions = 8; ControlPermissions control_permissions = 8;
string webrtc_sdp_offer = 9;
reserved 10;
} }
message TestNatRequest { message TestNatRequest {
@@ -84,6 +87,7 @@ message PunchHoleSent {
string version = 5; string version = 5;
int32 upnp_port = 6; int32 upnp_port = 6;
bytes socket_addr_v6 = 7; bytes socket_addr_v6 = 7;
string webrtc_sdp_answer = 8;
} }
message RegisterPk { message RegisterPk {
@@ -129,6 +133,7 @@ message PunchHoleResponse {
bool is_udp = 9; bool is_udp = 9;
int32 upnp_port = 10; int32 upnp_port = 10;
bytes socket_addr_v6 = 11; bytes socket_addr_v6 = 11;
string webrtc_sdp_answer = 12;
} }
message ConfigUpdate { message ConfigUpdate {
@@ -161,6 +166,7 @@ message RelayResponse {
int32 feedback = 9; int32 feedback = 9;
bytes socket_addr_v6 = 10; bytes socket_addr_v6 = 10;
int32 upnp_port = 11; int32 upnp_port = 11;
string webrtc_sdp_answer = 12;
} }
message SoftwareUpdate { string url = 1; } message SoftwareUpdate { string url = 1; }
@@ -231,6 +237,13 @@ message HttpProxyResponse {
string error = 4; string error = 4;
} }
message IceCandidate {
string id = 1;
bytes socket_addr = 2;
string session_key = 3;
string candidate = 4;
}
message RendezvousMessage { message RendezvousMessage {
oneof union { oneof union {
RegisterPeer register_peer = 6; RegisterPeer register_peer = 6;
@@ -256,5 +269,6 @@ message RendezvousMessage {
HealthCheck hc = 26; HealthCheck hc = 26;
HttpProxyRequest http_proxy_request = 27; HttpProxyRequest http_proxy_request = 27;
HttpProxyResponse http_proxy_response = 28; HttpProxyResponse http_proxy_response = 28;
IceCandidate ice_candidate = 29;
} }
} }
+110 -36
View File
@@ -1,13 +1,14 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::io::{Error, ErrorKind}; use std::io::{Error, ErrorKind};
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc; use std::sync::{Arc, Mutex as StdMutex};
use std::time::Duration; use std::time::Duration;
use webrtc::api::setting_engine::SettingEngine; use webrtc::api::setting_engine::SettingEngine;
use webrtc::api::APIBuilder; use webrtc::api::APIBuilder;
use webrtc::data_channel::RTCDataChannel; use webrtc::data_channel::RTCDataChannel;
use webrtc::ice::mdns::MulticastDnsMode; use webrtc::ice::mdns::MulticastDnsMode;
use webrtc::ice_transport::ice_candidate::RTCIceCandidateInit;
use webrtc::ice_transport::ice_server::RTCIceServer; use webrtc::ice_transport::ice_server::RTCIceServer;
use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::configuration::RTCConfiguration;
use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
@@ -18,8 +19,7 @@ use webrtc::peer_connection::RTCPeerConnection;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use base64::Engine; use base64::Engine;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use tokio::sync::watch; use tokio::sync::{mpsc, watch, Mutex};
use tokio::sync::Mutex;
use tokio::time::timeout; use tokio::time::timeout;
use url::Url; use url::Url;
@@ -28,10 +28,19 @@ use crate::protobuf::Message;
use crate::sodiumoxide::crypto::secretbox::Key; use crate::sodiumoxide::crypto::secretbox::Key;
use crate::ResultType; use crate::ResultType;
#[derive(Clone, Debug, PartialEq, Eq)]
enum WebRTCConnectionState {
Pending,
Open,
Closed(String),
}
pub struct WebRTCStream { pub struct WebRTCStream {
pc: Arc<RTCPeerConnection>, pc: Arc<RTCPeerConnection>,
stream: Arc<Mutex<Arc<RTCDataChannel>>>, stream: Arc<Mutex<Arc<RTCDataChannel>>>,
state_notify: watch::Receiver<bool>, state_notify: watch::Receiver<WebRTCConnectionState>,
local_ice_rx: Arc<StdMutex<Option<mpsc::UnboundedReceiver<String>>>>,
session_key: String,
send_timeout: u64, send_timeout: u64,
} }
@@ -59,6 +68,8 @@ impl Clone for WebRTCStream {
pc: self.pc.clone(), pc: self.pc.clone(),
stream: self.stream.clone(), stream: self.stream.clone(),
state_notify: self.state_notify.clone(), state_notify: self.state_notify.clone(),
local_ice_rx: self.local_ice_rx.clone(),
session_key: self.session_key.clone(),
send_timeout: self.send_timeout, send_timeout: self.send_timeout,
} }
} }
@@ -122,7 +133,7 @@ impl WebRTCStream {
if sdp_json.is_empty() { if sdp_json.is_empty() {
return Ok("".to_string()); return Ok("".to_string());
} }
let sdp = serde_json::from_str::<RTCSessionDescription>(&sdp_json)?; let sdp = serde_json::from_str::<RTCSessionDescription>(sdp_json)?;
Self::get_key_for_sdp(&sdp) Self::get_key_for_sdp(&sdp)
} }
@@ -243,16 +254,40 @@ impl WebRTCStream {
..Default::default() ..Default::default()
}; };
let (notify_tx, notify_rx) = watch::channel(false); let (notify_tx, notify_rx) = watch::channel(WebRTCConnectionState::Pending);
let (ice_tx, ice_rx) = mpsc::unbounded_channel::<String>();
// Create a new RTCPeerConnection // Create a new RTCPeerConnection
let pc = Arc::new(api.new_peer_connection(config).await?); let pc = Arc::new(api.new_peer_connection(config).await?);
let local_ice_tx = ice_tx.clone();
pc.on_ice_candidate(Box::new(move |candidate| {
let local_ice_tx = local_ice_tx.clone();
Box::pin(async move {
let Some(candidate) = candidate else {
return;
};
match candidate.to_json() {
Ok(candidate) => match serde_json::to_string(&candidate) {
Ok(candidate_json) => {
let _ = local_ice_tx.send(candidate_json);
}
Err(err) => {
log::warn!("failed to serialize local ICE candidate: {}", err);
}
},
Err(err) => {
log::warn!("failed to convert local ICE candidate to JSON: {}", err);
}
}
})
}));
let bootstrap_dc = if start_local_offer { let bootstrap_dc = if start_local_offer {
let dc_open_notify = notify_tx.clone(); let dc_open_notify = notify_tx.clone();
// Create a data channel with label "bootstrap" // Create a data channel with label "bootstrap"
let dc = pc.create_data_channel("bootstrap", None).await?; let dc = pc.create_data_channel("bootstrap", None).await?;
dc.on_open(Box::new(move || { dc.on_open(Box::new(move || {
log::debug!("Local data channel bootstrap open."); log::debug!("Local data channel bootstrap open.");
let _ = dc_open_notify.send(true); let _ = dc_open_notify.send(WebRTCConnectionState::Open);
Box::pin(async {}) Box::pin(async {})
})); }));
dc dc
@@ -277,7 +312,7 @@ impl WebRTCStream {
*stream_lock = dc.clone(); *stream_lock = dc.clone();
drop(stream_lock); drop(stream_lock);
dc.on_open(Box::new(move || { dc.on_open(Box::new(move || {
let _ = dc_open_notify2.send(true); let _ = dc_open_notify2.send(WebRTCConnectionState::Open);
Box::pin(async {}) Box::pin(async {})
})); }));
}) })
@@ -297,7 +332,9 @@ impl WebRTCStream {
RTCPeerConnectionState::Disconnected RTCPeerConnectionState::Disconnected
| RTCPeerConnectionState::Failed | RTCPeerConnectionState::Failed
| RTCPeerConnectionState::Closed => { | RTCPeerConnectionState::Closed => {
let _ = on_connection_notify.send(true); let _ = on_connection_notify.send(WebRTCConnectionState::Closed(
s.to_string(),
));
log::debug!("WebRTC session closing due to disconnected"); log::debug!("WebRTC session closing due to disconnected");
let _ = stream_for_close2.lock().await.close().await; let _ = stream_for_close2.lock().await.close().await;
log::debug!("WebRTC session stream closed"); log::debug!("WebRTC session stream closed");
@@ -339,9 +376,7 @@ impl WebRTCStream {
// process offer/answer // process offer/answer
if start_local_offer { if start_local_offer {
let sdp = pc.create_offer(None).await?; let sdp = pc.create_offer(None).await?;
let mut gather_complete = pc.gathering_complete_promise().await;
pc.set_local_description(sdp.clone()).await?; pc.set_local_description(sdp.clone()).await?;
let _ = gather_complete.recv().await;
log::debug!("local offer:\n{}", sdp.sdp); log::debug!("local offer:\n{}", sdp.sdp);
// get local sdp key // get local sdp key
@@ -351,9 +386,7 @@ impl WebRTCStream {
let sdp = serde_json::from_str::<RTCSessionDescription>(&remote_offer)?; let sdp = serde_json::from_str::<RTCSessionDescription>(&remote_offer)?;
pc.set_remote_description(sdp.clone()).await?; pc.set_remote_description(sdp.clone()).await?;
let answer = pc.create_answer(None).await?; let answer = pc.create_answer(None).await?;
let mut gather_complete = pc.gathering_complete_promise().await;
pc.set_local_description(answer).await?; pc.set_local_description(answer).await?;
let _ = gather_complete.recv().await;
log::debug!("remote offer:\n{}", sdp.sdp); log::debug!("remote offer:\n{}", sdp.sdp);
// get remote sdp key // get remote sdp key
@@ -371,6 +404,8 @@ impl WebRTCStream {
pc, pc,
stream, stream,
state_notify: notify_rx, state_notify: notify_rx,
local_ice_rx: Arc::new(StdMutex::new(Some(ice_rx))),
session_key: key.clone(),
send_timeout: ms_timeout, send_timeout: ms_timeout,
}; };
final_lock.insert(key, webrtc_stream.clone()); final_lock.insert(key, webrtc_stream.clone());
@@ -397,6 +432,38 @@ impl WebRTCStream {
Ok(()) Ok(())
} }
#[inline]
pub fn take_local_ice_rx(&self) -> Option<mpsc::UnboundedReceiver<String>> {
self.local_ice_rx.lock().ok().and_then(|mut rx| rx.take())
}
#[inline]
pub async fn add_remote_ice_candidate(&self, candidate_json: &str) -> ResultType<()> {
if candidate_json.is_empty() {
return Ok(());
}
let candidate = serde_json::from_str::<RTCIceCandidateInit>(candidate_json)?;
self.pc.add_ice_candidate(candidate).await?;
Ok(())
}
#[inline]
pub fn session_key(&self) -> &str {
&self.session_key
}
pub async fn wait_connected(&mut self, ms: u64) -> ResultType<()> {
if ms > 0 {
match timeout(Duration::from_millis(ms), self.wait_for_connect_result()).await {
Ok(result) => result?,
Err(_) => return Err(anyhow::anyhow!("WebRTC wait_connected timeout")),
}
} else {
self.wait_for_connect_result().await?;
}
Ok(())
}
#[inline] #[inline]
pub fn set_raw(&mut self) { pub fn set_raw(&mut self) {
// not-supported // not-supported
@@ -435,33 +502,30 @@ impl WebRTCStream {
} }
#[inline] #[inline]
async fn wait_for_connect_result(&mut self) { async fn wait_for_connect_result(&mut self) -> ResultType<()> {
if *self.state_notify.borrow() { loop {
return; match self.state_notify.borrow().clone() {
WebRTCConnectionState::Open => return Ok(()),
WebRTCConnectionState::Closed(reason) => {
return Err(anyhow::anyhow!("WebRTC connection closed: {}", reason));
}
WebRTCConnectionState::Pending => {}
}
self.state_notify.changed().await?;
} }
let _ = self.state_notify.changed().await;
} }
pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> {
if self.send_timeout > 0 { if let Err(err) = self.wait_connected(self.send_timeout).await {
match timeout( self.pc.close().await.ok();
Duration::from_millis(self.send_timeout), let kind = if err.to_string().contains("deadline")
self.wait_for_connect_result(), || err.to_string().contains("timeout")
)
.await
{ {
Ok(_) => {} ErrorKind::TimedOut
Err(_) => { } else {
self.pc.close().await.ok(); ErrorKind::Other
return Err(Error::new( };
ErrorKind::TimedOut, return Err(Error::new(kind, err.to_string()).into());
"WebRTC send wait for connect timeout",
)
.into());
}
}
} else {
self.wait_for_connect_result().await;
} }
let stream = self.stream.lock().await.clone(); let stream = self.stream.lock().await.clone();
stream.send(&bytes).await?; stream.send(&bytes).await?;
@@ -470,7 +534,10 @@ impl WebRTCStream {
#[inline] #[inline]
pub async fn next(&mut self) -> Option<Result<BytesMut, Error>> { pub async fn next(&mut self) -> Option<Result<BytesMut, Error>> {
self.wait_for_connect_result().await; if let Err(err) = self.wait_for_connect_result().await {
self.pc.close().await.ok();
return Some(Err(Error::new(ErrorKind::Other, err.to_string())));
}
let stream = self.stream.lock().await.clone(); let stream = self.stream.lock().await.clone();
// TODO reuse buffer? // TODO reuse buffer?
@@ -767,4 +834,11 @@ IHR5cCBzcmZseCByYWRkciAwLjAuMC4wIHJwb3J0IDY0MDA4XHJcbmE9ZW5kLW9mLWNhbmRpZGF0ZXNc
"connect to an 'answer' webrtc endpoint should error" "connect to an 'answer' webrtc endpoint should error"
); );
} }
#[tokio::test]
async fn test_webrtc_wait_connected_timeout() {
let mut stream = WebRTCStream::new("", false, 100).await.unwrap();
let err = stream.wait_connected(10).await.unwrap_err();
assert!(err.to_string().contains("timeout"));
}
} }