support send_timeout

This commit is contained in:
lc
2025-11-14 16:02:51 +08:00
parent 47dc73de1e
commit 5dcfea1ee4
2 changed files with 81 additions and 48 deletions

View File

@@ -8,10 +8,10 @@ mod webrtc_dummy;
use crate::webrtc_dummy::WebRTCStream; use crate::webrtc_dummy::WebRTCStream;
use std::io::Write; use std::io::Write;
use bytes::Bytes;
use clap::{Arg, Command};
use anyhow::Result; use anyhow::Result;
use bytes::Bytes;
use clap::{Arg, Command};
use tokio::time::Duration; use tokio::time::Duration;
use webrtc::peer_connection::math_rand_alpha; use webrtc::peer_connection::math_rand_alpha;
@@ -75,7 +75,10 @@ async fn main() -> Result<()> {
// Wait for the answer to be pasted // Wait for the answer to be pasted
println!( println!(
"Start new terminal run: \n{} \ncopy remote endpoint and paste here", "Start new terminal run: \n{} \ncopy remote endpoint and paste here",
format!("cargo r --features webrtc --example webrtc -- --offer {}", local_endpoint) format!(
"cargo r --features webrtc --example webrtc -- --offer {}",
local_endpoint
)
); );
// readline blocking // readline blocking
let line = std::io::stdin() let line = std::io::stdin()
@@ -84,7 +87,10 @@ async fn main() -> Result<()> {
.ok_or_else(|| anyhow::anyhow!("No input received"))??; .ok_or_else(|| anyhow::anyhow!("No input received"))??;
webrtc_stream.set_remote_endpoint(&line).await?; webrtc_stream.set_remote_endpoint(&line).await?;
} else { } else {
println!("Copy local endpoint and paste to the other peer: \n{}", local_endpoint); println!(
"Copy local endpoint and paste to the other peer: \n{}",
local_endpoint
);
} }
let s1 = webrtc_stream.clone(); let s1 = webrtc_stream.clone();

View File

@@ -1,31 +1,29 @@
use std::sync::Arc;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::io::{Error, ErrorKind};
use std::time::Duration;
use std::collections::HashMap; use std::collections::HashMap;
use std::io::{Error, ErrorKind};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use webrtc::api::APIBuilder;
use webrtc::api::setting_engine::SettingEngine; use webrtc::api::setting_engine::SettingEngine;
use webrtc::api::APIBuilder;
use webrtc::data_channel::RTCDataChannel; use webrtc::data_channel::RTCDataChannel;
use webrtc::ice::mdns::MulticastDnsMode;
use webrtc::ice_transport::ice_server::RTCIceServer; use webrtc::ice_transport::ice_server::RTCIceServer;
use webrtc::peer_connection::RTCPeerConnection;
use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::configuration::RTCConfiguration;
use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use webrtc::ice::mdns::MulticastDnsMode; use webrtc::peer_connection::RTCPeerConnection;
use crate::{
protobuf::Message,
sodiumoxide::crypto::secretbox::Key,
ResultType,
};
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use base64::Engine;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use tokio::time::timeout;
use tokio::sync::watch; use tokio::sync::watch;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio::time::timeout;
use crate::protobuf::Message;
use crate::sodiumoxide::crypto::secretbox::Key;
use crate::ResultType;
pub struct WebRTCStream { pub struct WebRTCStream {
pc: Arc<RTCPeerConnection>, pc: Arc<RTCPeerConnection>,
@@ -53,29 +51,36 @@ impl Clone for WebRTCStream {
} }
impl WebRTCStream { impl WebRTCStream {
#[inline]
pub fn get_remote_offer(endpoint: &str) -> ResultType<String> { fn get_remote_offer(endpoint: &str) -> ResultType<String> {
// Ensure the endpoint starts with the "webrtc://" prefix // Ensure the endpoint starts with the "webrtc://" prefix
if !endpoint.starts_with("webrtc://") { if !endpoint.starts_with("webrtc://") {
return Err(Error::new(ErrorKind::InvalidInput, "Invalid WebRTC endpoint format").into()); return Err(
Error::new(ErrorKind::InvalidInput, "Invalid WebRTC endpoint format").into(),
);
} }
// Extract the Base64-encoded SDP part // Extract the Base64-encoded SDP part
let encoded_sdp = &endpoint["webrtc://".len()..]; let encoded_sdp = &endpoint["webrtc://".len()..];
// Decode the Base64 string // Decode the Base64 string
let decoded_bytes = BASE64_STANDARD.decode(encoded_sdp).map_err(|_| let decoded_bytes = BASE64_STANDARD
Error::new(ErrorKind::InvalidInput, "Failed to decode Base64 SDP") .decode(encoded_sdp)
)?; .map_err(|_| Error::new(ErrorKind::InvalidInput, "Failed to decode Base64 SDP"))?;
Ok(String::from_utf8(decoded_bytes).map_err(|_| { Ok(String::from_utf8(decoded_bytes).map_err(|_| {
Error::new(ErrorKind::InvalidInput, "Failed to convert decoded bytes to UTF-8") Error::new(
ErrorKind::InvalidInput,
"Failed to convert decoded bytes to UTF-8",
)
})?) })?)
} }
pub fn sdp_to_endpoint(sdp: &str) -> String { #[inline]
fn sdp_to_endpoint(sdp: &str) -> String {
let encoded_sdp = BASE64_STANDARD.encode(sdp); let encoded_sdp = BASE64_STANDARD.encode(sdp);
format!("webrtc://{}", encoded_sdp) format!("webrtc://{}", encoded_sdp)
} }
#[inline]
async fn get_key_for_peer(pc: &Arc<RTCPeerConnection>) -> String { async fn get_key_for_peer(pc: &Arc<RTCPeerConnection>) -> String {
if let Some(local_desc) = pc.local_description().await { if let Some(local_desc) = pc.local_description().await {
if local_desc.sdp_type != webrtc::peer_connection::sdp::sdp_type::RTCSdpType::Offer { if local_desc.sdp_type != webrtc::peer_connection::sdp::sdp_type::RTCSdpType::Offer {
@@ -89,10 +94,7 @@ impl WebRTCStream {
"".into() "".into()
} }
pub async fn new( pub async fn new(remote_endpoint: &str, ms_timeout: u64) -> ResultType<Self> {
remote_endpoint: &str,
ms_timeout: u64,
) -> ResultType<Self> {
log::debug!("New webrtc stream to endpoint: {}", remote_endpoint); log::debug!("New webrtc stream to endpoint: {}", remote_endpoint);
let remote_offer = if remote_endpoint.is_empty() { let remote_offer = if remote_endpoint.is_empty() {
"".into() "".into()
@@ -115,9 +117,7 @@ impl WebRTCStream {
s.set_ice_multicast_dns_mode(MulticastDnsMode::Disabled); s.set_ice_multicast_dns_mode(MulticastDnsMode::Disabled);
// Create the API object // Create the API object
let api = APIBuilder::new() let api = APIBuilder::new().with_setting_engine(s).build();
.with_setting_engine(s)
.build();
// Prepare the configuration // Prepare the configuration
let config = RTCConfiguration { let config = RTCConfiguration {
@@ -158,16 +158,25 @@ impl WebRTCStream {
let pc_for_close2 = pc_for_close.clone(); let pc_for_close2 = pc_for_close.clone();
Box::pin(async move { Box::pin(async move {
log::debug!("Peer connection state : {}", s); log::debug!("Peer connection state : {}", s);
if s == RTCPeerConnectionState::Disconnected { match s {
let _ = on_connection_notify2.send(true); RTCPeerConnectionState::Disconnected
log::debug!("WebRTC session closing due to disconnected"); | RTCPeerConnectionState::Failed
let _ = stream_for_close2.lock().await.close().await; | RTCPeerConnectionState::Closed => {
log::debug!("WebRTC session stream closed"); let _ = on_connection_notify2.send(true);
} else if s == RTCPeerConnectionState::Failed || s == RTCPeerConnectionState::Closed { log::debug!("WebRTC session closing due to disconnected");
let mut lock = SESSIONS.lock().await; let _ = stream_for_close2.lock().await.close().await;
let key = WebRTCStream::get_key_for_peer(&pc_for_close2).await; log::debug!("WebRTC session stream closed");
log::debug!("WebRTC session removing key from cache: {}", key);
lock.remove(&key); let mut lock = SESSIONS.lock().await;
let key = WebRTCStream::get_key_for_peer(&pc_for_close2).await;
lock.remove(&key);
log::debug!(
"WebRTC session removed key from cache: {} current len: {}",
key,
lock.len()
);
}
_ => {}
} }
}) })
})); }));
@@ -287,7 +296,26 @@ impl WebRTCStream {
} }
pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> {
self.wait_for_connect_result().await; if self.send_timeout > 0 {
match timeout(
Duration::from_millis(self.send_timeout),
self.wait_for_connect_result(),
)
.await
{
Ok(_) => {}
Err(_) => {
self.pc.close().await.ok();
return Err(Error::new(
ErrorKind::TimedOut,
"WebRTC send wait for connect timeout",
)
.into());
}
}
} else {
self.wait_for_connect_result().await;
}
let stream = self.stream.lock().await.clone(); let stream = self.stream.lock().await.clone();
stream.send(&bytes).await?; stream.send(&bytes).await?;
Ok(()) Ok(())
@@ -339,6 +367,5 @@ pub fn is_webrtc_endpoint(endpoint: &str) -> bool {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
#[test] #[test]
fn test_dc() { fn test_dc() {}
}
} }