From c156d2eef7a1585dcd34b41745829418bd31a967 Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Thu, 17 Apr 2025 11:42:58 +0800 Subject: [PATCH 01/22] [enhance] websocket --- Cargo.toml | 8 +- src/lib.rs | 8 ++ src/socket_client.rs | 30 +++--- src/websocket.rs | 227 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 261 insertions(+), 12 deletions(-) create mode 100644 src/websocket.rs diff --git a/Cargo.toml b/Cargo.toml index 58f54ac..669c0aa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,13 @@ httparse = "1.5" base64 = "0.22" url = "2.2" sha2 = "0.10" +tokio-tungstenite = {version = "0.26.2", optional = true} +tungstenite = {version = "0.26.2", optional = true} + +[features] +default = ["websocket"] +websocket = ["dep:tokio-tungstenite", "dep:tungstenite"] + [target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies] mac_address = "1.1" @@ -64,4 +71,3 @@ winapi = { version = "0.3", features = ["winuser", "synchapi", "pdh", "memoryapi [target.'cfg(target_os = "macos")'.dependencies] osascript = "0.3" - diff --git a/src/lib.rs b/src/lib.rs index 9475b96..f9cfd02 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -57,8 +57,16 @@ pub use toml; pub use uuid; pub mod fingerprint; pub use flexi_logger; +pub mod websocket; + +#[cfg(feature = "websocket")] +pub type Stream = websocket::WsFramedStream; + + +#[cfg(not(feature = "websocket"))] pub type Stream = tcp::FramedStream; + pub type SessionID = uuid::Uuid; #[inline] diff --git a/src/socket_client.rs b/src/socket_client.rs index 4cb0bf2..1c39c01 100644 --- a/src/socket_client.rs +++ b/src/socket_client.rs @@ -2,7 +2,7 @@ use crate::{ config::{Config, NetworkType}, tcp::FramedStream, udp::FramedSocket, - ResultType, + websocket, ResultType, }; use anyhow::Context; use std::net::SocketAddr; @@ -102,7 +102,7 @@ pub async fn connect_tcp< >( target: T, ms_timeout: u64, -) -> ResultType { +) -> ResultType { connect_tcp_local(target, None, ms_timeout).await } @@ -113,19 +113,27 @@ pub async fn connect_tcp_local< target: T, local: Option, ms_timeout: u64, -) -> ResultType { - if let Some(conf) = Config::get_socks() { - return FramedStream::connect(target, local, &conf, ms_timeout).await; +) -> ResultType { + #[cfg(feature = "websocket")] + { + let url = format!("ws://{}", target); + websocket::WsFramedStream::new(url, local, None, ms_timeout).await } - if let Some(target) = target.resolve() { - if let Some(local) = local { - if local.is_ipv6() && target.is_ipv4() { - let target = query_nip_io(target).await?; - return FramedStream::new(target, Some(local), ms_timeout).await; + #[cfg(not(feature = "websocket"))] + { + if let Some(conf) = Config::get_socks() { + return tcp::FramedStream::connect(target, local, &conf, ms_timeout).await; + } + if let Some(target) = target.resolve() { + if let Some(local) = local { + if local.is_ipv6() && target.is_ipv4() { + let target = query_nip_io(target).await?; + return tcp::FramedStream::new(target, Some(local), ms_timeout).await; + } } } + tcp::FramedStream::new(target, local, ms_timeout).await } - FramedStream::new(target, local, ms_timeout).await } #[inline] diff --git a/src/websocket.rs b/src/websocket.rs new file mode 100644 index 0000000..d44efb9 --- /dev/null +++ b/src/websocket.rs @@ -0,0 +1,227 @@ +use crate::{ + config::Socks5Server, + protobuf::Message, + sodiumoxide::crypto::secretbox::{self, Key, Nonce}, + ResultType, +}; +use bytes::{BufMut, Bytes, BytesMut}; +use futures::{SinkExt, StreamExt}; +use std::{ + io::{Error, ErrorKind}, + net::SocketAddr, + time::Duration, +}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufStream}; +use tokio::{net::TcpStream, time::timeout}; +use tokio_tungstenite::{ + connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream, +}; +use tungstenite::protocol::Role; + +#[derive(Clone)] +pub struct Encrypt(Key, u64, u64); + +pub struct WsFramedStream { + stream: WebSocketStream>, + addr: SocketAddr, + encrypt: Option, + send_timeout: u64, + read_buf: BytesMut, +} + +impl WsFramedStream { + pub async fn new>( + url: T, + local_addr: Option, + _proxy_conf: Option<&Socks5Server>, + ms_timeout: u64, + ) -> ResultType { + let (stream, _) = connect_async(url.as_ref()).await?; + + // 获取底层TCP流的peer_addr + let addr = match stream.get_ref() { + MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?, + #[cfg(feature = "native-tls")] + MaybeTlsStream::NativeTls(tls) => tls.get_ref().peer_addr()?, + #[cfg(feature = "rustls")] + MaybeTlsStream::Rustls(tls) => tls.get_ref().0.peer_addr()?, + // 处理其他可能的情况 + _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), + }; + + Ok(Self { + stream, + addr, + encrypt: None, + send_timeout: ms_timeout, + read_buf: BytesMut::new(), + }) + } + + pub fn set_raw(&mut self) { + // WebSocket不需要特殊处理,保持空实现 + } + + pub async fn from_tcp_stream(stream: TcpStream, addr: SocketAddr) -> ResultType { + let ws_stream = + WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Server, None) + .await; + + Ok(Self { + stream: ws_stream, + addr, + encrypt: None, + send_timeout: 0, + read_buf: BytesMut::new(), + }) + } + + pub async fn from(stream: TcpStream, addr: SocketAddr) -> Self { + let ws_stream = WebSocketStream::from_raw_socket( + MaybeTlsStream::Plain(stream), // 包装为MaybeTlsStream + Role::Client, + None, + ) + .await; + + Self { + stream: ws_stream, + addr, + encrypt: None, + send_timeout: 0, + read_buf: BytesMut::new(), + } + } + + pub fn local_addr(&self) -> SocketAddr { + self.addr + } + + pub fn set_send_timeout(&mut self, ms: u64) { + self.send_timeout = ms; + } + + pub fn set_key(&mut self, key: Key) { + self.encrypt = Some(Encrypt::new(key)); + } + + pub fn is_secured(&self) -> bool { + self.encrypt.is_some() + } + + #[inline] + pub async fn send(&mut self, msg: &impl Message) -> ResultType<()> { + self.send_raw(msg.write_to_bytes()?).await + } + + #[inline] + pub async fn send_raw(&mut self, msg: Vec) -> ResultType<()> { + let mut msg = msg; + if let Some(key) = self.encrypt.as_mut() { + msg = key.enc(&msg); + } + self.send_bytes(bytes::Bytes::from(msg)).await + } + + #[inline] + pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { + // 转换为Vec时需要处理加密 + let data = if let Some(key) = self.encrypt.as_mut() { + key.enc(&bytes.to_vec()) + } else { + bytes.to_vec() + }; + + let msg = WsMessage::Binary(Bytes::from(data)); + if self.send_timeout > 0 { + let send_future = self.stream.send(msg); + timeout(Duration::from_millis(self.send_timeout), send_future) + .await + .map_err(|_| Error::new(ErrorKind::TimedOut, "Send timeout"))? + .map_err(|e| Error::new(ErrorKind::Other, e.to_string()))?; + } else { + self.stream + .send(msg) + .await + .map_err(|e| Error::new(ErrorKind::Other, e.to_string()))?; + } + Ok(()) + } + + #[inline] + pub async fn next(&mut self) -> Option> { + loop { + if let Some((frame, _)) = self.read_buf.split_first() { + if let Some(decrypted) = self.try_decrypt() { + return Some(Ok(decrypted)); + } + } + + match self.stream.next().await? { + Ok(WsMessage::Binary(data)) => { + self.read_buf.extend_from_slice(&data); + if let Some(decrypted) = self.try_decrypt() { + return Some(Ok(decrypted)); + } + } + Ok(_) => continue, // 忽略非二进制消息 + Err(e) => return Some(Err(Error::new(ErrorKind::Other, e))), + } + } + } + + fn try_decrypt(&mut self) -> Option { + if let Some(key) = self.encrypt.as_mut() { + if let Ok(()) = key.dec(&mut self.read_buf) { + let data = self.read_buf.split(); + return Some(data); + } + } else { + let data = self.read_buf.split(); + return Some(data); + } + None + } + + #[inline] + pub async fn next_timeout(&mut self, ms: u64) -> Option> { + match timeout(Duration::from_millis(ms), self.next()).await { + Ok(res) => res, + Err(_) => None, + } + } +} + +impl Encrypt { + pub fn new(key: Key) -> Self { + Self(key, 0, 0) + } + + pub fn dec(&mut self, bytes: &mut BytesMut) -> Result<(), Error> { + if bytes.len() <= 1 { + return Ok(()); + } + self.2 += 1; + let nonce = get_nonce(self.2); + match secretbox::open(bytes, &nonce, &self.0) { + Ok(res) => { + bytes.clear(); + bytes.put_slice(&res); + Ok(()) + } + Err(()) => Err(Error::new(ErrorKind::Other, "decryption error")), + } + } + + pub fn enc(&mut self, data: &[u8]) -> Vec { + self.1 += 1; + let nonce = get_nonce(self.1); + secretbox::seal(data, &nonce, &self.0) + } +} + +fn get_nonce(seqnum: u64) -> Nonce { + let mut nonce = Nonce([0u8; secretbox::NONCEBYTES]); + nonce.0[..std::mem::size_of_val(&seqnum)].copy_from_slice(&seqnum.to_le_bytes()); + nonce +} From 58103659e7e6c8a3ad18d80441b973afae6e8224 Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Fri, 18 Apr 2025 02:53:39 +0800 Subject: [PATCH 02/22] [enhance] add websocket NOT a config feature. --- src/lib.rs | 117 +++++++++++++++++++++++++++++++++++++++++-- src/socket_client.rs | 6 +-- 2 files changed, 116 insertions(+), 7 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index f9cfd02..5348977 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,14 +58,123 @@ pub use uuid; pub mod fingerprint; pub use flexi_logger; pub mod websocket; +use sodiumoxide::crypto::secretbox::Key; +pub enum Stream { + WebSocket(websocket::WsFramedStream), + Tcp(tcp::FramedStream), +} -#[cfg(feature = "websocket")] -pub type Stream = websocket::WsFramedStream; +impl Stream { + pub fn set_send_timeout(&mut self, ms: u64) { + match self { + Stream::WebSocket(s) => s.set_send_timeout(ms), + Stream::Tcp(s) => s.set_send_timeout(ms), + } + } + pub fn set_raw(&mut self) { + match self { + Stream::WebSocket(s) => s.set_raw(), + Stream::Tcp(s) => s.set_raw(), + } + } -#[cfg(not(feature = "websocket"))] -pub type Stream = tcp::FramedStream; + pub async fn send_bytes(&mut self, bytes: bytes::Bytes) -> ResultType<()> { + match self { + Stream::WebSocket(s) => s.send_bytes(bytes).await, + Stream::Tcp(s) => s.send_bytes(bytes).await, + } + } + + pub async fn send_raw(&mut self, bytes: Vec) -> ResultType<()> { + match self { + Stream::WebSocket(s) => s.send_raw(bytes).await, + Stream::Tcp(s) => s.send_raw(bytes).await, + } + } + + pub fn set_key(&mut self, key: Key) { + match self { + Stream::WebSocket(s) => s.set_key(key), + Stream::Tcp(s) => s.set_key(key), + } + } + + pub fn is_secured(&self) -> bool { + match self { + Stream::WebSocket(s) => s.is_secured(), + Stream::Tcp(s) => s.is_secured(), + } + } + + pub async fn next_timeout( + &mut self, + timeout: u64, + ) -> Option> { + match self { + Stream::WebSocket(s) => s.next_timeout(timeout).await, + Stream::Tcp(s) => s.next_timeout(timeout).await, + } + } +} + +impl Stream { + /// 从 TCP 流创建(自动判断是否升级为 WebSocket) + pub async fn from_tcp( + stream: tokio::net::TcpStream, + addr: SocketAddr, + is_websocket: bool, + ) -> ResultType { + if is_websocket { + Ok(Self::WebSocket( + websocket::WsFramedStream::from(stream, addr).await, + )) + } else { + Ok(Self::Tcp(tcp::FramedStream::from(stream, addr))) + } + } + + /// 创建 WebSocket 客户端连接 + pub async fn connect_websocket( + url: impl AsRef, + local_addr: Option, + proxy_conf: Option<&config::Socks5Server>, + timeout_ms: u64, + ) -> ResultType { + let ws_stream = + websocket::WsFramedStream::new(url, local_addr, proxy_conf, timeout_ms).await?; + Ok(Self::WebSocket(ws_stream)) + } + + /// 发送消息 + pub async fn send(&mut self, msg: &impl protobuf::Message) -> ResultType<()> { + match self { + Self::WebSocket(ws) => ws.send(msg).await, + Self::Tcp(tcp) => tcp.send(msg).await, + } + } + + /// 接收消息 + pub async fn next(&mut self) -> Option> { + match self { + Self::WebSocket(ws) => ws.next().await, + Self::Tcp(tcp) => tcp.next().await, + } + } + + // 其他必要的方法... + pub fn local_addr(&self) -> SocketAddr { + match self { + Self::WebSocket(ws) => ws.local_addr(), + Self::Tcp(tcp) => tcp.local_addr(), + } + } + + pub fn is_websocket(&self) -> bool { + matches!(self, Self::WebSocket(_)) + } +} pub type SessionID = uuid::Uuid; diff --git a/src/socket_client.rs b/src/socket_client.rs index 1c39c01..3b6457a 100644 --- a/src/socket_client.rs +++ b/src/socket_client.rs @@ -2,7 +2,7 @@ use crate::{ config::{Config, NetworkType}, tcp::FramedStream, udp::FramedSocket, - websocket, ResultType, + websocket, ResultType, Stream, }; use anyhow::Context; use std::net::SocketAddr; @@ -113,11 +113,11 @@ pub async fn connect_tcp_local< target: T, local: Option, ms_timeout: u64, -) -> ResultType { +) -> ResultType { #[cfg(feature = "websocket")] { let url = format!("ws://{}", target); - websocket::WsFramedStream::new(url, local, None, ms_timeout).await + Ok(Stream::WebSocket(websocket::WsFramedStream::new(url, local, None, ms_timeout).await?)) } #[cfg(not(feature = "websocket"))] { From 2d65c24e4b8a9d83716ac8bbea4052c87ac7c2cc Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Fri, 18 Apr 2025 11:31:46 +0800 Subject: [PATCH 03/22] [enhance] remove cfg select. --- src/socket_client.rs | 44 ++++++++++++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/src/socket_client.rs b/src/socket_client.rs index 3b6457a..b26e8f4 100644 --- a/src/socket_client.rs +++ b/src/socket_client.rs @@ -114,25 +114,41 @@ pub async fn connect_tcp_local< local: Option, ms_timeout: u64, ) -> ResultType { - #[cfg(feature = "websocket")] - { - let url = format!("ws://{}", target); - Ok(Stream::WebSocket(websocket::WsFramedStream::new(url, local, None, ms_timeout).await?)) - } - #[cfg(not(feature = "websocket"))] - { + let target_str = target.to_string(); + + // 根据目标地址协议决定连接方式 + if target_str.starts_with("ws://") || target_str.starts_with("wss://") { + // WebSocket 连接逻辑 + Ok(Stream::WebSocket(websocket::WsFramedStream::new( + target_str, + local, + None, + ms_timeout, + ) + .await?)) + } else { + // TCP 连接逻辑 if let Some(conf) = Config::get_socks() { - return tcp::FramedStream::connect(target, local, &conf, ms_timeout).await; + return Ok(Stream::Tcp( + FramedStream::connect(target, local, &conf, ms_timeout).await?, + )); } - if let Some(target) = target.resolve() { - if let Some(local) = local { - if local.is_ipv6() && target.is_ipv4() { - let target = query_nip_io(target).await?; - return tcp::FramedStream::new(target, Some(local), ms_timeout).await; + + if let Some(target_addr) = target.resolve() { + if let Some(local_addr) = local { + if local_addr.is_ipv6() && target_addr.is_ipv4() { + let resolved_target = query_nip_io(target_addr).await?; + return Ok(Stream::Tcp( + FramedStream::new(resolved_target, Some(local_addr), ms_timeout) + .await?, + )); } } } - tcp::FramedStream::new(target, local, ms_timeout).await + + Ok(Stream::Tcp( + FramedStream::new(target, local, ms_timeout).await?, + )) } } From d299e4909f8b74db8daef47c14c01719396d89f3 Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Tue, 22 Apr 2025 02:55:17 +0800 Subject: [PATCH 04/22] [fix bug] fix twice enc. - fix twice enc. - enable proxy (temp ver.) - rewite websocket::next. --- src/lib.rs | 9 ++-- src/proxy.rs | 2 +- src/websocket.rs | 129 ++++++++++++++++++++++++----------------------- 3 files changed, 71 insertions(+), 69 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 5348977..0b2b176 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -120,7 +120,7 @@ impl Stream { } impl Stream { - /// 从 TCP 流创建(自动判断是否升级为 WebSocket) + /// establish connect from tcp. pub async fn from_tcp( stream: tokio::net::TcpStream, addr: SocketAddr, @@ -135,7 +135,7 @@ impl Stream { } } - /// 创建 WebSocket 客户端连接 + /// establish connect from websocket pub async fn connect_websocket( url: impl AsRef, local_addr: Option, @@ -147,7 +147,7 @@ impl Stream { Ok(Self::WebSocket(ws_stream)) } - /// 发送消息 + /// send message pub async fn send(&mut self, msg: &impl protobuf::Message) -> ResultType<()> { match self { Self::WebSocket(ws) => ws.send(msg).await, @@ -155,7 +155,7 @@ impl Stream { } } - /// 接收消息 + /// receive message pub async fn next(&mut self) -> Option> { match self { Self::WebSocket(ws) => ws.next().await, @@ -163,7 +163,6 @@ impl Stream { } } - // 其他必要的方法... pub fn local_addr(&self) -> SocketAddr { match self { Self::WebSocket(ws) => ws.local_addr(), diff --git a/src/proxy.rs b/src/proxy.rs index 34d2c51..e32778f 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -56,7 +56,7 @@ const MAXIMUM_RESPONSE_HEADERS: usize = 16; const DEFINE_TIME_OUT: u64 = 600; pub trait IntoUrl { - + // Besides parsing as a valid `Url`, the `Url` must be a valid // `http::Uri`, in that it makes sense to use in a network request. fn into_url(self) -> Result; diff --git a/src/websocket.rs b/src/websocket.rs index d44efb9..274791f 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -11,7 +11,6 @@ use std::{ net::SocketAddr, time::Duration, }; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufStream}; use tokio::{net::TcpStream, time::timeout}; use tokio_tungstenite::{ connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream, @@ -26,41 +25,71 @@ pub struct WsFramedStream { addr: SocketAddr, encrypt: Option, send_timeout: u64, - read_buf: BytesMut, + // read_buf: BytesMut, } impl WsFramedStream { pub async fn new>( url: T, local_addr: Option, - _proxy_conf: Option<&Socks5Server>, + proxy_conf: Option<&Socks5Server>, ms_timeout: u64, ) -> ResultType { - let (stream, _) = connect_async(url.as_ref()).await?; + let url_str = url.as_ref(); - // 获取底层TCP流的peer_addr - let addr = match stream.get_ref() { - MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?, - #[cfg(feature = "native-tls")] - MaybeTlsStream::NativeTls(tls) => tls.get_ref().peer_addr()?, - #[cfg(feature = "rustls")] - MaybeTlsStream::Rustls(tls) => tls.get_ref().0.peer_addr()?, - // 处理其他可能的情况 - _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), - }; + if let Some(proxy_conf) = proxy_conf { + // use proxy connect + let url_obj = url::Url::parse(url_str)?; + let host = url_obj + .host_str() + .ok_or_else(|| Error::new(ErrorKind::Other, "Invalid URL: no host"))?; - Ok(Self { - stream, - addr, - encrypt: None, - send_timeout: ms_timeout, - read_buf: BytesMut::new(), - }) + let port = url_obj + .port() + .unwrap_or(if url_obj.scheme() == "wss" { 443 } else { 80 }); + + let socket = + tokio_socks::tcp::Socks5Stream::connect(proxy_conf.proxy.as_str(), (host, port)) + .await?; + + let tcp_stream = socket.into_inner(); + let maybe_tls_stream = MaybeTlsStream::Plain(tcp_stream); + let ws_stream = + WebSocketStream::from_raw_socket(maybe_tls_stream, Role::Client, None).await; + + let addr = match ws_stream.get_ref() { + MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?, + _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), + }; + + Ok(Self { + stream: ws_stream, + addr, + encrypt: None, + send_timeout: ms_timeout, + }) + } else { + let (stream, _) = connect_async(url_str).await?; + + let addr = match stream.get_ref() { + MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?, + #[cfg(feature = "native-tls")] + MaybeTlsStream::NativeTls(tls) => tls.get_ref().peer_addr()?, + #[cfg(feature = "rustls")] + MaybeTlsStream::Rustls(tls) => tls.get_ref().0.peer_addr()?, + _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), + }; + + Ok(Self { + stream, + addr, + encrypt: None, + send_timeout: ms_timeout, + }) + } } - pub fn set_raw(&mut self) { - // WebSocket不需要特殊处理,保持空实现 - } + pub fn set_raw(&mut self) {} pub async fn from_tcp_stream(stream: TcpStream, addr: SocketAddr) -> ResultType { let ws_stream = @@ -72,24 +101,21 @@ impl WsFramedStream { addr, encrypt: None, send_timeout: 0, - read_buf: BytesMut::new(), + // read_buf: BytesMut::new(), }) } pub async fn from(stream: TcpStream, addr: SocketAddr) -> Self { - let ws_stream = WebSocketStream::from_raw_socket( - MaybeTlsStream::Plain(stream), // 包装为MaybeTlsStream - Role::Client, - None, - ) - .await; + let ws_stream = + WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None) + .await; Self { stream: ws_stream, addr, encrypt: None, send_timeout: 0, - read_buf: BytesMut::new(), + // read_buf: BytesMut::new(), } } @@ -125,14 +151,7 @@ impl WsFramedStream { #[inline] pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { - // 转换为Vec时需要处理加密 - let data = if let Some(key) = self.encrypt.as_mut() { - key.enc(&bytes.to_vec()) - } else { - bytes.to_vec() - }; - - let msg = WsMessage::Binary(Bytes::from(data)); + let msg = WsMessage::Binary(Bytes::from(bytes)); if self.send_timeout > 0 { let send_future = self.stream.send(msg); timeout(Duration::from_millis(self.send_timeout), send_future) @@ -151,38 +170,22 @@ impl WsFramedStream { #[inline] pub async fn next(&mut self) -> Option> { loop { - if let Some((frame, _)) = self.read_buf.split_first() { - if let Some(decrypted) = self.try_decrypt() { - return Some(Ok(decrypted)); - } - } - match self.stream.next().await? { Ok(WsMessage::Binary(data)) => { - self.read_buf.extend_from_slice(&data); - if let Some(decrypted) = self.try_decrypt() { - return Some(Ok(decrypted)); + let mut bytes = BytesMut::from(&data[..]); + if let Some(key) = self.encrypt.as_mut() { + if let Err(e) = key.dec(&mut bytes) { + return Some(Err(e)); + } } + return Some(Ok(bytes)); } - Ok(_) => continue, // 忽略非二进制消息 + Ok(_) => continue, Err(e) => return Some(Err(Error::new(ErrorKind::Other, e))), } } } - fn try_decrypt(&mut self) -> Option { - if let Some(key) = self.encrypt.as_mut() { - if let Ok(()) = key.dec(&mut self.read_buf) { - let data = self.read_buf.split(); - return Some(data); - } - } else { - let data = self.read_buf.split(); - return Some(data); - } - None - } - #[inline] pub async fn next_timeout(&mut self, ms: u64) -> Option> { match timeout(Duration::from_millis(ms), self.next()).await { From 6798bf3780a8c69b23489bb54ca076f7179c46ef Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Wed, 23 Apr 2025 08:33:01 +0800 Subject: [PATCH 05/22] [test] tcp success. Websocket test. (temp ver.) --- src/socket_client.rs | 22 +++++++++------------- src/websocket.rs | 13 ++++++++++++- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/src/socket_client.rs b/src/socket_client.rs index b26e8f4..926e70e 100644 --- a/src/socket_client.rs +++ b/src/socket_client.rs @@ -116,18 +116,15 @@ pub async fn connect_tcp_local< ) -> ResultType { let target_str = target.to_string(); - // 根据目标地址协议决定连接方式 - if target_str.starts_with("ws://") || target_str.starts_with("wss://") { - // WebSocket 连接逻辑 - Ok(Stream::WebSocket(websocket::WsFramedStream::new( - target_str, - local, - None, - ms_timeout, - ) - .await?)) + // if target_str.starts_with("ws://") || target_str.starts_with("wss://") { + // Ok(Stream::WebSocket( + // websocket::WsFramedStream::new(target_str, local, None, ms_timeout).await?, + // )) + if true { + Ok(Stream::WebSocket( + websocket::WsFramedStream::new(target_str, local, None, ms_timeout).await?, + )) } else { - // TCP 连接逻辑 if let Some(conf) = Config::get_socks() { return Ok(Stream::Tcp( FramedStream::connect(target, local, &conf, ms_timeout).await?, @@ -139,8 +136,7 @@ pub async fn connect_tcp_local< if local_addr.is_ipv6() && target_addr.is_ipv4() { let resolved_target = query_nip_io(target_addr).await?; return Ok(Stream::Tcp( - FramedStream::new(resolved_target, Some(local_addr), ms_timeout) - .await?, + FramedStream::new(resolved_target, Some(local_addr), ms_timeout).await?, )); } } diff --git a/src/websocket.rs b/src/websocket.rs index 274791f..e560cb6 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -93,7 +93,7 @@ impl WsFramedStream { pub async fn from_tcp_stream(stream: TcpStream, addr: SocketAddr) -> ResultType { let ws_stream = - WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Server, None) + WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None) .await; Ok(Self { @@ -169,6 +169,7 @@ impl WsFramedStream { #[inline] pub async fn next(&mut self) -> Option> { + log::info!("test"); loop { match self.stream.next().await? { Ok(WsMessage::Binary(data)) => { @@ -180,6 +181,16 @@ impl WsFramedStream { } return Some(Ok(bytes)); } + Ok(WsMessage::Ping(ping)) => { + if let Err(e) = self.stream.send(WsMessage::Pong(ping)).await { + return Some(Err(Error::new( + ErrorKind::Other, + format!("Failed to send pong: {}", e), + ))); + } + continue; + } + Ok(WsMessage::Close(_)) => return None, Ok(_) => continue, Err(e) => return Some(Err(Error::new(ErrorKind::Other, e))), } From b1dd3bb9c8f66b27040bcd2ee80e7ae8a3109808 Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Wed, 23 Apr 2025 08:44:02 +0800 Subject: [PATCH 06/22] [add log] add log and add ws scheme --- src/websocket.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/websocket.rs b/src/websocket.rs index e560cb6..094b770 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -69,7 +69,9 @@ impl WsFramedStream { send_timeout: ms_timeout, }) } else { - let (stream, _) = connect_async(url_str).await?; + log::info!("{:?}", url_str); + let ws_url = format!("ws://{}", url_str); + let (stream, _) = connect_async(ws_url).await?; let addr = match stream.get_ref() { MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?, From 4ff800a8be1b0294a6bc803b9cc24ba479d54cc0 Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Wed, 23 Apr 2025 09:00:29 +0800 Subject: [PATCH 07/22] [test] websocket should be determined in runtime. --- src/socket_client.rs | 6 +----- src/websocket.rs | 9 +++------ 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/src/socket_client.rs b/src/socket_client.rs index 926e70e..60695af 100644 --- a/src/socket_client.rs +++ b/src/socket_client.rs @@ -116,11 +116,7 @@ pub async fn connect_tcp_local< ) -> ResultType { let target_str = target.to_string(); - // if target_str.starts_with("ws://") || target_str.starts_with("wss://") { - // Ok(Stream::WebSocket( - // websocket::WsFramedStream::new(target_str, local, None, ms_timeout).await?, - // )) - if true { + if target_str.starts_with("ws://") || target_str.starts_with("wss://") { Ok(Stream::WebSocket( websocket::WsFramedStream::new(target_str, local, None, ms_timeout).await?, )) diff --git a/src/websocket.rs b/src/websocket.rs index 094b770..2790899 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -15,6 +15,7 @@ use tokio::{net::TcpStream, time::timeout}; use tokio_tungstenite::{ connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream, }; +use tungstenite::client::IntoClientRequest; use tungstenite::protocol::Role; #[derive(Clone)] @@ -70,15 +71,11 @@ impl WsFramedStream { }) } else { log::info!("{:?}", url_str); - let ws_url = format!("ws://{}", url_str); - let (stream, _) = connect_async(ws_url).await?; + + let (stream, _) = connect_async(url_str.into_client_request().unwrap()).await?; let addr = match stream.get_ref() { MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?, - #[cfg(feature = "native-tls")] - MaybeTlsStream::NativeTls(tls) => tls.get_ref().peer_addr()?, - #[cfg(feature = "rustls")] - MaybeTlsStream::Rustls(tls) => tls.get_ref().0.peer_addr()?, _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), }; From 0d6948c97b15d9bcaa0467f74fa073718598d870 Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Wed, 23 Apr 2025 09:58:20 +0800 Subject: [PATCH 08/22] [test] fix bug: twice enc. add ping pong log. --- src/websocket.rs | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/websocket.rs b/src/websocket.rs index 2790899..73fe102 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -72,7 +72,18 @@ impl WsFramedStream { } else { log::info!("{:?}", url_str); - let (stream, _) = connect_async(url_str.into_client_request().unwrap()).await?; + let mut request = url_str + .into_client_request() + .map_err(|e| Error::new(ErrorKind::Other, e))?; + + // 添加必要协议头 + // request.headers_mut().insert( + // "Sec-WebSocket-Protocol", + // tungstenite::http::HeaderValue::from_static("rustdesk"), + // ); + + let (stream, _) = + timeout(Duration::from_millis(ms_timeout), connect_async(request)).await??; let addr = match stream.get_ref() { MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?, @@ -141,10 +152,6 @@ impl WsFramedStream { #[inline] pub async fn send_raw(&mut self, msg: Vec) -> ResultType<()> { - let mut msg = msg; - if let Some(key) = self.encrypt.as_mut() { - msg = key.enc(&msg); - } self.send_bytes(bytes::Bytes::from(msg)).await } @@ -189,6 +196,10 @@ impl WsFramedStream { } continue; } + Ok(WsMessage::Pong(_)) => { + log::debug!("Received pong"); + continue; + } Ok(WsMessage::Close(_)) => return None, Ok(_) => continue, Err(e) => return Some(Err(Error::new(ErrorKind::Other, e))), From f31d1ec1b8d837d59108f9cd04f0aae6ec925c49 Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Wed, 23 Apr 2025 10:42:34 +0800 Subject: [PATCH 09/22] [enhance] add ping pong send. --- src/websocket.rs | 161 ++++++++++++++++++++++++++++++++++------------- 1 file changed, 116 insertions(+), 45 deletions(-) diff --git a/src/websocket.rs b/src/websocket.rs index 73fe102..c5ee839 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -5,12 +5,15 @@ use crate::{ ResultType, }; use bytes::{BufMut, Bytes, BytesMut}; +use futures::stream::SplitSink; use futures::{SinkExt, StreamExt}; +use std::sync::Arc; use std::{ io::{Error, ErrorKind}, net::SocketAddr, time::Duration, }; +use tokio::sync::Mutex; use tokio::{net::TcpStream, time::timeout}; use tokio_tungstenite::{ connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream, @@ -22,7 +25,9 @@ use tungstenite::protocol::Role; pub struct Encrypt(Key, u64, u64); pub struct WsFramedStream { - stream: WebSocketStream>, + // stream: WebSocketStream>, + writer: Arc>, WsMessage>>>, + reader: futures::stream::SplitStream>>, addr: SocketAddr, encrypt: Option, send_timeout: u64, @@ -30,6 +35,9 @@ pub struct WsFramedStream { } impl WsFramedStream { + const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); + const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(15); + pub async fn new>( url: T, local_addr: Option, @@ -63,16 +71,22 @@ impl WsFramedStream { _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), }; - Ok(Self { - stream: ws_stream, + let (writer, reader) = ws_stream.split(); + + let mut ws = Self { + writer: Arc::new(Mutex::new(writer)), + reader, addr, encrypt: None, send_timeout: ms_timeout, - }) + }; + + ws.start_heartbeat(); + Ok(ws) } else { log::info!("{:?}", url_str); - let mut request = url_str + let request = url_str .into_client_request() .map_err(|e| Error::new(ErrorKind::Other, e))?; @@ -90,15 +104,36 @@ impl WsFramedStream { _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), }; - Ok(Self { - stream, + let (writer, reader) = stream.split(); + let mut ws = Self { + writer: Arc::new(Mutex::new(writer)), + reader, addr, encrypt: None, send_timeout: ms_timeout, - }) + }; + + ws.start_heartbeat(); + Ok(ws) } } + fn start_heartbeat(&self) { + let writer = Arc::clone(&self.writer); + tokio::spawn(async move { + let mut interval = tokio::time::interval(Self::HEARTBEAT_INTERVAL); + loop { + interval.tick().await; + let mut lock = writer.lock().await; + if let Err(e) = lock.send(WsMessage::Ping(Bytes::new())).await { + log::error!("Failed to send ping: {}", e); + break; + } + drop(lock); // 及时释放锁 + } + }); + } + pub fn set_raw(&mut self) {} pub async fn from_tcp_stream(stream: TcpStream, addr: SocketAddr) -> ResultType { @@ -106,12 +141,14 @@ impl WsFramedStream { WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None) .await; + let (writer, reader) = ws_stream.split(); + Ok(Self { - stream: ws_stream, + writer: Arc::new(Mutex::new(writer)), + reader, addr, encrypt: None, send_timeout: 0, - // read_buf: BytesMut::new(), }) } @@ -120,12 +157,14 @@ impl WsFramedStream { WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None) .await; + let (writer, reader) = ws_stream.split(); + Self { - stream: ws_stream, + writer: Arc::new(Mutex::new(writer)), + reader, addr, encrypt: None, send_timeout: 0, - // read_buf: BytesMut::new(), } } @@ -157,52 +196,84 @@ impl WsFramedStream { #[inline] pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { - let msg = WsMessage::Binary(Bytes::from(bytes)); + let msg = WsMessage::Binary(bytes); + let mut writer = self.writer.lock().await; if self.send_timeout > 0 { - let send_future = self.stream.send(msg); - timeout(Duration::from_millis(self.send_timeout), send_future) - .await - .map_err(|_| Error::new(ErrorKind::TimedOut, "Send timeout"))? - .map_err(|e| Error::new(ErrorKind::Other, e.to_string()))?; + timeout(Duration::from_millis(self.send_timeout), writer.send(msg)).await?? } else { - self.stream - .send(msg) - .await - .map_err(|e| Error::new(ErrorKind::Other, e.to_string()))?; - } + writer.send(msg).await? + }; Ok(()) } #[inline] pub async fn next(&mut self) -> Option> { - log::info!("test"); + log::debug!("Waiting for next message"); + let start = std::time::Instant::now(); + loop { - match self.stream.next().await? { - Ok(WsMessage::Binary(data)) => { - let mut bytes = BytesMut::from(&data[..]); - if let Some(key) = self.encrypt.as_mut() { - if let Err(e) = key.dec(&mut bytes) { - return Some(Err(e)); + match self.reader.next().await { + Some(Ok(msg)) => { + log::debug!("Received message: {:?}", &msg); + match msg { + WsMessage::Binary(data) => { + log::info!("Received binary data ({} bytes)", data.len()); + let mut bytes = BytesMut::from(&data[..]); + if let Some(key) = self.encrypt.as_mut() { + log::debug!("Decrypting data with seq: {}", key.2); + match key.dec(&mut bytes) { + Ok(_) => { + log::debug!("Decryption successful"); + return Some(Ok(bytes)); + } + Err(e) => { + log::error!("Decryption failed: {}", e); + return Some(Err(e)); + } + } + } + return Some(Ok(bytes)); + } + WsMessage::Ping(ping) => { + log::info!("Received ping ({} bytes)", ping.len()); + let mut writer = self.writer.lock().await; + if let Err(e) = writer.send(WsMessage::Pong(ping)).await { + log::error!("Failed to send pong: {}", e); + return Some(Err(Error::new( + ErrorKind::Other, + format!("Failed to send pong: {}", e), + ))); + } + log::debug!("Pong sent"); + } + WsMessage::Pong(_) => { + log::debug!("Received pong"); + } + WsMessage::Close(frame) => { + log::info!("Connection closed: {:?}", frame); + return None; + } + _ => { + log::warn!("Unhandled message :{}", &msg); } } - return Some(Ok(bytes)); } - Ok(WsMessage::Ping(ping)) => { - if let Err(e) = self.stream.send(WsMessage::Pong(ping)).await { - return Some(Err(Error::new( - ErrorKind::Other, - format!("Failed to send pong: {}", e), - ))); - } - continue; + Some(Err(e)) => { + log::error!("WebSocket error: {}", e); + return Some(Err(Error::new( + ErrorKind::Other, + format!("Failed to send pong: {}", e), + ))); } - Ok(WsMessage::Pong(_)) => { - log::debug!("Received pong"); - continue; + None => { + log::info!("Connection closed gracefully"); + return None; } - Ok(WsMessage::Close(_)) => return None, - Ok(_) => continue, - Err(e) => return Some(Err(Error::new(ErrorKind::Other, e))), + } + + if start.elapsed() > Self::HEARTBEAT_TIMEOUT { + log::warn!("No message received within heartbeat timeout"); + return Some(Err(Error::new(ErrorKind::TimedOut, "Heartbeat timeout"))); } } } From 5c6b12c438c9a8340513fa7b74c25766e26fa198 Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Wed, 23 Apr 2025 10:58:28 +0800 Subject: [PATCH 10/22] [enhance] add ping pong send (ver.2) --- src/websocket.rs | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/src/websocket.rs b/src/websocket.rs index c5ee839..a40b719 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -14,6 +14,7 @@ use std::{ time::Duration, }; use tokio::sync::Mutex; +use tokio::time::Instant; use tokio::{net::TcpStream, time::timeout}; use tokio_tungstenite::{ connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream, @@ -34,10 +35,9 @@ pub struct WsFramedStream { // read_buf: BytesMut, } +const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(3); +const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(10); impl WsFramedStream { - const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); - const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(15); - pub async fn new>( url: T, local_addr: Option, @@ -121,15 +121,26 @@ impl WsFramedStream { fn start_heartbeat(&self) { let writer = Arc::clone(&self.writer); tokio::spawn(async move { - let mut interval = tokio::time::interval(Self::HEARTBEAT_INTERVAL); + let mut last_pong = Instant::now(); + let mut interval = tokio::time::interval(HEARTBEAT_INTERVAL); + loop { - interval.tick().await; - let mut lock = writer.lock().await; - if let Err(e) = lock.send(WsMessage::Ping(Bytes::new())).await { - log::error!("Failed to send ping: {}", e); - break; + tokio::select! { + _ = interval.tick() => { + let mut lock = writer.lock().await; + if let Err(e) = lock.send(WsMessage::Ping(Bytes::new())).await { + log::error!("Heartbeat failed: {}", e); + break; + } + log::debug!("Sent ping"); + } + _ = tokio::time::sleep(HEARTBEAT_TIMEOUT) => { + if last_pong.elapsed() > HEARTBEAT_TIMEOUT { + log::error!("Heartbeat timeout"); + break; + } + } } - drop(lock); // 及时释放锁 } }); } @@ -271,7 +282,7 @@ impl WsFramedStream { } } - if start.elapsed() > Self::HEARTBEAT_TIMEOUT { + if start.elapsed() > HEARTBEAT_TIMEOUT { log::warn!("No message received within heartbeat timeout"); return Some(Err(Error::new(ErrorKind::TimedOut, "Heartbeat timeout"))); } From 836dbbc1449bf613c6d3be533e71c1543d9aa883 Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Wed, 23 Apr 2025 11:42:53 +0800 Subject: [PATCH 11/22] [fallback] remove pingpong --- src/websocket.rs | 104 +++++++++++++++-------------------------------- 1 file changed, 32 insertions(+), 72 deletions(-) diff --git a/src/websocket.rs b/src/websocket.rs index a40b719..8ab25c7 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -5,16 +5,12 @@ use crate::{ ResultType, }; use bytes::{BufMut, Bytes, BytesMut}; -use futures::stream::SplitSink; use futures::{SinkExt, StreamExt}; -use std::sync::Arc; use std::{ io::{Error, ErrorKind}, net::SocketAddr, time::Duration, }; -use tokio::sync::Mutex; -use tokio::time::Instant; use tokio::{net::TcpStream, time::timeout}; use tokio_tungstenite::{ connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream, @@ -26,9 +22,7 @@ use tungstenite::protocol::Role; pub struct Encrypt(Key, u64, u64); pub struct WsFramedStream { - // stream: WebSocketStream>, - writer: Arc>, WsMessage>>>, - reader: futures::stream::SplitStream>>, + stream: WebSocketStream>, addr: SocketAddr, encrypt: Option, send_timeout: u64, @@ -71,17 +65,14 @@ impl WsFramedStream { _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), }; - let (writer, reader) = ws_stream.split(); - let mut ws = Self { - writer: Arc::new(Mutex::new(writer)), - reader, + let ws = Self { + stream: ws_stream, addr, encrypt: None, send_timeout: ms_timeout, }; - ws.start_heartbeat(); Ok(ws) } else { log::info!("{:?}", url_str); @@ -104,47 +95,17 @@ impl WsFramedStream { _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), }; - let (writer, reader) = stream.split(); let mut ws = Self { - writer: Arc::new(Mutex::new(writer)), - reader, + stream, addr, encrypt: None, send_timeout: ms_timeout, }; - ws.start_heartbeat(); Ok(ws) } } - fn start_heartbeat(&self) { - let writer = Arc::clone(&self.writer); - tokio::spawn(async move { - let mut last_pong = Instant::now(); - let mut interval = tokio::time::interval(HEARTBEAT_INTERVAL); - - loop { - tokio::select! { - _ = interval.tick() => { - let mut lock = writer.lock().await; - if let Err(e) = lock.send(WsMessage::Ping(Bytes::new())).await { - log::error!("Heartbeat failed: {}", e); - break; - } - log::debug!("Sent ping"); - } - _ = tokio::time::sleep(HEARTBEAT_TIMEOUT) => { - if last_pong.elapsed() > HEARTBEAT_TIMEOUT { - log::error!("Heartbeat timeout"); - break; - } - } - } - } - }); - } - pub fn set_raw(&mut self) {} pub async fn from_tcp_stream(stream: TcpStream, addr: SocketAddr) -> ResultType { @@ -152,11 +113,9 @@ impl WsFramedStream { WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None) .await; - let (writer, reader) = ws_stream.split(); Ok(Self { - writer: Arc::new(Mutex::new(writer)), - reader, + stream: ws_stream, addr, encrypt: None, send_timeout: 0, @@ -168,11 +127,9 @@ impl WsFramedStream { WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None) .await; - let (writer, reader) = ws_stream.split(); Self { - writer: Arc::new(Mutex::new(writer)), - reader, + stream: ws_stream, addr, encrypt: None, send_timeout: 0, @@ -208,11 +165,14 @@ impl WsFramedStream { #[inline] pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { let msg = WsMessage::Binary(bytes); - let mut writer = self.writer.lock().await; if self.send_timeout > 0 { - timeout(Duration::from_millis(self.send_timeout), writer.send(msg)).await?? + timeout( + Duration::from_millis(self.send_timeout), + self.stream.send(msg), + ) + .await?? } else { - writer.send(msg).await? + self.stream.send(msg).await? }; Ok(()) } @@ -223,7 +183,7 @@ impl WsFramedStream { let start = std::time::Instant::now(); loop { - match self.reader.next().await { + match self.stream.next().await { Some(Ok(msg)) => { log::debug!("Received message: {:?}", &msg); match msg { @@ -245,25 +205,25 @@ impl WsFramedStream { } return Some(Ok(bytes)); } - WsMessage::Ping(ping) => { - log::info!("Received ping ({} bytes)", ping.len()); - let mut writer = self.writer.lock().await; - if let Err(e) = writer.send(WsMessage::Pong(ping)).await { - log::error!("Failed to send pong: {}", e); - return Some(Err(Error::new( - ErrorKind::Other, - format!("Failed to send pong: {}", e), - ))); - } - log::debug!("Pong sent"); - } - WsMessage::Pong(_) => { - log::debug!("Received pong"); - } - WsMessage::Close(frame) => { - log::info!("Connection closed: {:?}", frame); - return None; - } + // WsMessage::Ping(ping) => { + // log::info!("Received ping ({} bytes)", ping.len()); + // let mut writer = self.writer.lock().await; + // if let Err(e) = writer.send(WsMessage::Pong(ping)).await { + // log::error!("Failed to send pong: {}", e); + // return Some(Err(Error::new( + // ErrorKind::Other, + // format!("Failed to send pong: {}", e), + // ))); + // } + // log::debug!("Pong sent"); + // } + // WsMessage::Pong(_) => { + // log::debug!("Received pong"); + // } + // WsMessage::Close(frame) => { + // log::info!("Connection closed: {:?}", frame); + // return None; + // } _ => { log::warn!("Unhandled message :{}", &msg); } From 13ffda490d8f9a63d9d8786a52a6cdc7565dca1e Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Wed, 23 Apr 2025 21:36:17 +0800 Subject: [PATCH 12/22] [enhance] remove time ticking. --- src/websocket.rs | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/websocket.rs b/src/websocket.rs index 8ab25c7..9835a19 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -29,8 +29,6 @@ pub struct WsFramedStream { // read_buf: BytesMut, } -const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(3); -const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(10); impl WsFramedStream { pub async fn new>( url: T, @@ -180,7 +178,6 @@ impl WsFramedStream { #[inline] pub async fn next(&mut self) -> Option> { log::debug!("Waiting for next message"); - let start = std::time::Instant::now(); loop { match self.stream.next().await { @@ -242,10 +239,6 @@ impl WsFramedStream { } } - if start.elapsed() > HEARTBEAT_TIMEOUT { - log::warn!("No message received within heartbeat timeout"); - return Some(Err(Error::new(ErrorKind::TimedOut, "Heartbeat timeout"))); - } } } From bac2ffd31ed3040af8cec8fc9729c9291e61c1f8 Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Thu, 24 Apr 2025 00:16:14 +0800 Subject: [PATCH 13/22] [enhance] rewrite websocket next. --- src/websocket.rs | 91 +++++++++++++++++++----------------------------- 1 file changed, 35 insertions(+), 56 deletions(-) diff --git a/src/websocket.rs b/src/websocket.rs index 9835a19..98c26f6 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -63,7 +63,6 @@ impl WsFramedStream { _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), }; - let ws = Self { stream: ws_stream, addr, @@ -111,7 +110,6 @@ impl WsFramedStream { WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None) .await; - Ok(Self { stream: ws_stream, addr, @@ -125,7 +123,6 @@ impl WsFramedStream { WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None) .await; - Self { stream: ws_stream, addr, @@ -179,67 +176,49 @@ impl WsFramedStream { pub async fn next(&mut self) -> Option> { log::debug!("Waiting for next message"); - loop { - match self.stream.next().await { - Some(Ok(msg)) => { - log::debug!("Received message: {:?}", &msg); - match msg { - WsMessage::Binary(data) => { - log::info!("Received binary data ({} bytes)", data.len()); - let mut bytes = BytesMut::from(&data[..]); - if let Some(key) = self.encrypt.as_mut() { - log::debug!("Decrypting data with seq: {}", key.2); - match key.dec(&mut bytes) { - Ok(_) => { - log::debug!("Decryption successful"); - return Some(Ok(bytes)); - } - Err(e) => { - log::error!("Decryption failed: {}", e); - return Some(Err(e)); - } - } - } - return Some(Ok(bytes)); - } - // WsMessage::Ping(ping) => { - // log::info!("Received ping ({} bytes)", ping.len()); - // let mut writer = self.writer.lock().await; - // if let Err(e) = writer.send(WsMessage::Pong(ping)).await { - // log::error!("Failed to send pong: {}", e); - // return Some(Err(Error::new( - // ErrorKind::Other, - // format!("Failed to send pong: {}", e), - // ))); - // } - // log::debug!("Pong sent"); - // } - // WsMessage::Pong(_) => { - // log::debug!("Received pong"); - // } - // WsMessage::Close(frame) => { - // log::info!("Connection closed: {:?}", frame); - // return None; - // } - _ => { - log::warn!("Unhandled message :{}", &msg); - } - } - } - Some(Err(e)) => { - log::error!("WebSocket error: {}", e); + while let Some(msg) = self.stream.next().await { + let msg = match msg { + Ok(msg) => msg, + Err(e) => { return Some(Err(Error::new( ErrorKind::Other, - format!("Failed to send pong: {}", e), + format!("WebSocket protocol error: {}", e), ))); } - None => { - log::info!("Connection closed gracefully"); + }; + + log::debug!("Received message type: {}", msg.to_string()); + match msg { + WsMessage::Binary(data) => { + log::info!("Received binary data ({} bytes)", data.len()); + let mut bytes = BytesMut::from(&data[..]); + if let Some(key) = self.encrypt.as_mut() { + log::debug!("Decrypting data with seq: {}", key.2); + match key.dec(&mut bytes) { + Ok(_) => { + log::debug!("Decryption successful"); + return Some(Ok(bytes)); + } + Err(e) => { + log::error!("Decryption failed: {}", e); + return Some(Err(e)); + } + } + } + return Some(Ok(bytes)); + } + WsMessage::Close(_) => { + log::info!("Received close frame"); return None; } + _ => { + log::debug!("Unhandled message type: {}", msg.to_string()); + continue; + } } - } + + None } #[inline] From 29a322e6e30a934eb88021cb42cbd5af8499b7c7 Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Thu, 24 Apr 2025 00:27:54 +0800 Subject: [PATCH 14/22] [test] temp stop proxy. --- src/websocket.rs | 99 +++++++++++++++++++++++------------------------- 1 file changed, 47 insertions(+), 52 deletions(-) diff --git a/src/websocket.rs b/src/websocket.rs index 98c26f6..dcf1a61 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -38,69 +38,63 @@ impl WsFramedStream { ) -> ResultType { let url_str = url.as_ref(); - if let Some(proxy_conf) = proxy_conf { - // use proxy connect - let url_obj = url::Url::parse(url_str)?; - let host = url_obj - .host_str() - .ok_or_else(|| Error::new(ErrorKind::Other, "Invalid URL: no host"))?; + // if let Some(proxy_conf) = proxy_conf { + // // use proxy connect + // let url_obj = url::Url::parse(url_str)?; + // let host = url_obj + // .host_str() + // .ok_or_else(|| Error::new(ErrorKind::Other, "Invalid URL: no host"))?; - let port = url_obj - .port() - .unwrap_or(if url_obj.scheme() == "wss" { 443 } else { 80 }); + // let port = url_obj + // .port() + // .unwrap_or(if url_obj.scheme() == "wss" { 443 } else { 80 }); - let socket = - tokio_socks::tcp::Socks5Stream::connect(proxy_conf.proxy.as_str(), (host, port)) - .await?; + // let socket = + // tokio_socks::tcp::Socks5Stream::connect(proxy_conf.proxy.as_str(), (host, port)) + // .await?; - let tcp_stream = socket.into_inner(); - let maybe_tls_stream = MaybeTlsStream::Plain(tcp_stream); - let ws_stream = - WebSocketStream::from_raw_socket(maybe_tls_stream, Role::Client, None).await; + // let tcp_stream = socket.into_inner(); + // let maybe_tls_stream = MaybeTlsStream::Plain(tcp_stream); + // let ws_stream = + // WebSocketStream::from_raw_socket(maybe_tls_stream, Role::Client, None).await; - let addr = match ws_stream.get_ref() { - MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?, - _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), - }; + // let addr = match ws_stream.get_ref() { + // MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?, + // _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), + // }; - let ws = Self { - stream: ws_stream, - addr, - encrypt: None, - send_timeout: ms_timeout, - }; + // let ws = Self { + // stream: ws_stream, + // addr, + // encrypt: None, + // send_timeout: ms_timeout, + // }; - Ok(ws) - } else { - log::info!("{:?}", url_str); + // Ok(ws) + // } else { + log::info!("{:?}", url_str); - let request = url_str - .into_client_request() - .map_err(|e| Error::new(ErrorKind::Other, e))?; + let request = url_str + .into_client_request() + .map_err(|e| Error::new(ErrorKind::Other, e))?; - // 添加必要协议头 - // request.headers_mut().insert( - // "Sec-WebSocket-Protocol", - // tungstenite::http::HeaderValue::from_static("rustdesk"), - // ); + let (stream, _) = + timeout(Duration::from_millis(ms_timeout), connect_async(request)).await??; - let (stream, _) = - timeout(Duration::from_millis(ms_timeout), connect_async(request)).await??; + let addr = match stream.get_ref() { + MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?, + _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), + }; - let addr = match stream.get_ref() { - MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?, - _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), - }; + let ws = Self { + stream, + addr, + encrypt: None, + send_timeout: ms_timeout, + }; - let mut ws = Self { - stream, - addr, - encrypt: None, - send_timeout: ms_timeout, - }; - - Ok(ws) - } + Ok(ws) + // } } pub fn set_raw(&mut self) {} @@ -180,6 +174,7 @@ impl WsFramedStream { let msg = match msg { Ok(msg) => msg, Err(e) => { + log::debug!("{}", e); return Some(Err(Error::new( ErrorKind::Other, format!("WebSocket protocol error: {}", e), From 7d5cc2ed470d70a0a26dab6affef91eefacc9a8d Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Thu, 24 Apr 2025 03:11:54 +0800 Subject: [PATCH 15/22] [enhance] test ver. Add text for test. --- src/lib.rs | 22 +--------------------- src/websocket.rs | 18 +++++------------- 2 files changed, 6 insertions(+), 34 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 0b2b176..8316086 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -117,23 +117,6 @@ impl Stream { Stream::Tcp(s) => s.next_timeout(timeout).await, } } -} - -impl Stream { - /// establish connect from tcp. - pub async fn from_tcp( - stream: tokio::net::TcpStream, - addr: SocketAddr, - is_websocket: bool, - ) -> ResultType { - if is_websocket { - Ok(Self::WebSocket( - websocket::WsFramedStream::from(stream, addr).await, - )) - } else { - Ok(Self::Tcp(tcp::FramedStream::from(stream, addr))) - } - } /// establish connect from websocket pub async fn connect_websocket( @@ -144,6 +127,7 @@ impl Stream { ) -> ResultType { let ws_stream = websocket::WsFramedStream::new(url, local_addr, proxy_conf, timeout_ms).await?; + log::debug!("WebSocket connection established"); Ok(Self::WebSocket(ws_stream)) } @@ -169,10 +153,6 @@ impl Stream { Self::Tcp(tcp) => tcp.local_addr(), } } - - pub fn is_websocket(&self) -> bool { - matches!(self, Self::WebSocket(_)) - } } pub type SessionID = uuid::Uuid; diff --git a/src/websocket.rs b/src/websocket.rs index dcf1a61..0f60b15 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -112,19 +112,6 @@ impl WsFramedStream { }) } - pub async fn from(stream: TcpStream, addr: SocketAddr) -> Self { - let ws_stream = - WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None) - .await; - - Self { - stream: ws_stream, - addr, - encrypt: None, - send_timeout: 0, - } - } - pub fn local_addr(&self) -> SocketAddr { self.addr } @@ -202,6 +189,11 @@ impl WsFramedStream { } return Some(Ok(bytes)); } + WsMessage::Text(text) => { + log::debug!("Received text message, converting to binary"); + let bytes = BytesMut::from(text.as_bytes()); + return Some(Ok(bytes)); + } WsMessage::Close(_) => { log::info!("Received close frame"); return None; From e9813ffdd6839de8673ba4367899b34ef5b0d93c Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Thu, 24 Apr 2025 19:36:07 +0800 Subject: [PATCH 16/22] [test] add test debug info. --- src/websocket.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/websocket.rs b/src/websocket.rs index 0f60b15..94ee634 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -135,7 +135,7 @@ impl WsFramedStream { #[inline] pub async fn send_raw(&mut self, msg: Vec) -> ResultType<()> { - self.send_bytes(bytes::Bytes::from(msg)).await + self.send_bytes(Bytes::from(msg)).await } #[inline] @@ -158,6 +158,7 @@ impl WsFramedStream { log::debug!("Waiting for next message"); while let Some(msg) = self.stream.next().await { + log::debug!("receive msg: {:?}", msg); let msg = match msg { Ok(msg) => msg, Err(e) => { From 7110634c095933660692b46370a0d230e87b3f86 Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Thu, 24 Apr 2025 20:53:58 +0800 Subject: [PATCH 17/22] [debug] add debug. --- src/websocket.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/websocket.rs b/src/websocket.rs index 94ee634..22e7646 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -188,6 +188,7 @@ impl WsFramedStream { } } } + log::error!("not encrypt set."); return Some(Ok(bytes)); } WsMessage::Text(text) => { From 608eb5983fe6141cf71aa206e0fedd004d67dd79 Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Thu, 24 Apr 2025 21:25:00 +0800 Subject: [PATCH 18/22] [bug fix] add logic --- src/tcp.rs | 2 +- src/websocket.rs | 51 +++--------------------------------------------- 2 files changed, 4 insertions(+), 49 deletions(-) diff --git a/src/tcp.rs b/src/tcp.rs index c85338d..8852c8d 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -25,7 +25,7 @@ pub trait TcpStreamTrait: AsyncRead + AsyncWrite + Unpin {} pub struct DynTcpStream(pub(crate) Box); #[derive(Clone)] -pub struct Encrypt(Key, u64, u64); +pub struct Encrypt(pub Key, pub u64, pub u64); pub struct FramedStream( pub(crate) Framed, diff --git a/src/websocket.rs b/src/websocket.rs index 22e7646..434849b 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -1,3 +1,4 @@ +use crate::tcp::Encrypt; use crate::{ config::Socks5Server, protobuf::Message, @@ -18,9 +19,6 @@ use tokio_tungstenite::{ use tungstenite::client::IntoClientRequest; use tungstenite::protocol::Role; -#[derive(Clone)] -pub struct Encrypt(Key, u64, u64); - pub struct WsFramedStream { stream: WebSocketStream>, addr: SocketAddr, @@ -176,19 +174,10 @@ impl WsFramedStream { log::info!("Received binary data ({} bytes)", data.len()); let mut bytes = BytesMut::from(&data[..]); if let Some(key) = self.encrypt.as_mut() { - log::debug!("Decrypting data with seq: {}", key.2); - match key.dec(&mut bytes) { - Ok(_) => { - log::debug!("Decryption successful"); - return Some(Ok(bytes)); - } - Err(e) => { - log::error!("Decryption failed: {}", e); - return Some(Err(e)); - } + if let Err(err) = key.dec(&mut bytes) { + return Some(Err(err)); } } - log::error!("not encrypt set."); return Some(Ok(bytes)); } WsMessage::Text(text) => { @@ -218,37 +207,3 @@ impl WsFramedStream { } } } - -impl Encrypt { - pub fn new(key: Key) -> Self { - Self(key, 0, 0) - } - - pub fn dec(&mut self, bytes: &mut BytesMut) -> Result<(), Error> { - if bytes.len() <= 1 { - return Ok(()); - } - self.2 += 1; - let nonce = get_nonce(self.2); - match secretbox::open(bytes, &nonce, &self.0) { - Ok(res) => { - bytes.clear(); - bytes.put_slice(&res); - Ok(()) - } - Err(()) => Err(Error::new(ErrorKind::Other, "decryption error")), - } - } - - pub fn enc(&mut self, data: &[u8]) -> Vec { - self.1 += 1; - let nonce = get_nonce(self.1); - secretbox::seal(data, &nonce, &self.0) - } -} - -fn get_nonce(seqnum: u64) -> Nonce { - let mut nonce = Nonce([0u8; secretbox::NONCEBYTES]); - nonce.0[..std::mem::size_of_val(&seqnum)].copy_from_slice(&seqnum.to_le_bytes()); - nonce -} From 880365cab09dfe1378fc15eed550b383b71139e0 Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Thu, 24 Apr 2025 21:49:39 +0800 Subject: [PATCH 19/22] [bug fix] add enc logic. --- src/websocket.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/websocket.rs b/src/websocket.rs index 434849b..dd794e0 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -95,7 +95,9 @@ impl WsFramedStream { // } } - pub fn set_raw(&mut self) {} + pub fn set_raw(&mut self) { + self.encrypt = None; + } pub async fn from_tcp_stream(stream: TcpStream, addr: SocketAddr) -> ResultType { let ws_stream = @@ -133,6 +135,10 @@ impl WsFramedStream { #[inline] pub async fn send_raw(&mut self, msg: Vec) -> ResultType<()> { + let mut msg = msg; + if let Some(key) = self.encrypt.as_mut() { + msg = key.enc(&msg); + } self.send_bytes(Bytes::from(msg)).await } From 3ef70f0e4d27485a9ca8be445e7497545429ed49 Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Thu, 24 Apr 2025 22:27:16 +0800 Subject: [PATCH 20/22] [clean] test passed. - clean unused. --- src/lib.rs | 1 + src/websocket.rs | 44 +++----------------------------------------- 2 files changed, 4 insertions(+), 41 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 8316086..a5d10de 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -60,6 +60,7 @@ pub use flexi_logger; pub mod websocket; use sodiumoxide::crypto::secretbox::Key; +// support Websocket and tcp. pub enum Stream { WebSocket(websocket::WsFramedStream), Tcp(tcp::FramedStream), diff --git a/src/websocket.rs b/src/websocket.rs index dd794e0..4b2aa37 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -1,11 +1,8 @@ use crate::tcp::Encrypt; use crate::{ - config::Socks5Server, - protobuf::Message, - sodiumoxide::crypto::secretbox::{self, Key, Nonce}, - ResultType, + config::Socks5Server, protobuf::Message, sodiumoxide::crypto::secretbox::Key, ResultType, }; -use bytes::{BufMut, Bytes, BytesMut}; +use bytes::{Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; use std::{ io::{Error, ErrorKind}, @@ -24,7 +21,6 @@ pub struct WsFramedStream { addr: SocketAddr, encrypt: Option, send_timeout: u64, - // read_buf: BytesMut, } impl WsFramedStream { @@ -36,40 +32,7 @@ impl WsFramedStream { ) -> ResultType { let url_str = url.as_ref(); - // if let Some(proxy_conf) = proxy_conf { - // // use proxy connect - // let url_obj = url::Url::parse(url_str)?; - // let host = url_obj - // .host_str() - // .ok_or_else(|| Error::new(ErrorKind::Other, "Invalid URL: no host"))?; - - // let port = url_obj - // .port() - // .unwrap_or(if url_obj.scheme() == "wss" { 443 } else { 80 }); - - // let socket = - // tokio_socks::tcp::Socks5Stream::connect(proxy_conf.proxy.as_str(), (host, port)) - // .await?; - - // let tcp_stream = socket.into_inner(); - // let maybe_tls_stream = MaybeTlsStream::Plain(tcp_stream); - // let ws_stream = - // WebSocketStream::from_raw_socket(maybe_tls_stream, Role::Client, None).await; - - // let addr = match ws_stream.get_ref() { - // MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?, - // _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), - // }; - - // let ws = Self { - // stream: ws_stream, - // addr, - // encrypt: None, - // send_timeout: ms_timeout, - // }; - - // Ok(ws) - // } else { + // to-do: websocket proxy. log::info!("{:?}", url_str); let request = url_str @@ -92,7 +55,6 @@ impl WsFramedStream { }; Ok(ws) - // } } pub fn set_raw(&mut self) { From 6be5600b773b24c5b0de51780987a012ecaf6615 Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Fri, 25 Apr 2025 11:03:16 +0800 Subject: [PATCH 21/22] [enhance] split into `stream.rs`. --- src/lib.rs | 99 +----------------------------------------------- src/stream.rs | 102 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 97 deletions(-) create mode 100644 src/stream.rs diff --git a/src/lib.rs b/src/lib.rs index a5d10de..b1770f2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,103 +58,8 @@ pub use uuid; pub mod fingerprint; pub use flexi_logger; pub mod websocket; -use sodiumoxide::crypto::secretbox::Key; - -// support Websocket and tcp. -pub enum Stream { - WebSocket(websocket::WsFramedStream), - Tcp(tcp::FramedStream), -} - -impl Stream { - pub fn set_send_timeout(&mut self, ms: u64) { - match self { - Stream::WebSocket(s) => s.set_send_timeout(ms), - Stream::Tcp(s) => s.set_send_timeout(ms), - } - } - - pub fn set_raw(&mut self) { - match self { - Stream::WebSocket(s) => s.set_raw(), - Stream::Tcp(s) => s.set_raw(), - } - } - - pub async fn send_bytes(&mut self, bytes: bytes::Bytes) -> ResultType<()> { - match self { - Stream::WebSocket(s) => s.send_bytes(bytes).await, - Stream::Tcp(s) => s.send_bytes(bytes).await, - } - } - - pub async fn send_raw(&mut self, bytes: Vec) -> ResultType<()> { - match self { - Stream::WebSocket(s) => s.send_raw(bytes).await, - Stream::Tcp(s) => s.send_raw(bytes).await, - } - } - - pub fn set_key(&mut self, key: Key) { - match self { - Stream::WebSocket(s) => s.set_key(key), - Stream::Tcp(s) => s.set_key(key), - } - } - - pub fn is_secured(&self) -> bool { - match self { - Stream::WebSocket(s) => s.is_secured(), - Stream::Tcp(s) => s.is_secured(), - } - } - - pub async fn next_timeout( - &mut self, - timeout: u64, - ) -> Option> { - match self { - Stream::WebSocket(s) => s.next_timeout(timeout).await, - Stream::Tcp(s) => s.next_timeout(timeout).await, - } - } - - /// establish connect from websocket - pub async fn connect_websocket( - url: impl AsRef, - local_addr: Option, - proxy_conf: Option<&config::Socks5Server>, - timeout_ms: u64, - ) -> ResultType { - let ws_stream = - websocket::WsFramedStream::new(url, local_addr, proxy_conf, timeout_ms).await?; - log::debug!("WebSocket connection established"); - Ok(Self::WebSocket(ws_stream)) - } - - /// send message - pub async fn send(&mut self, msg: &impl protobuf::Message) -> ResultType<()> { - match self { - Self::WebSocket(ws) => ws.send(msg).await, - Self::Tcp(tcp) => tcp.send(msg).await, - } - } - - /// receive message - pub async fn next(&mut self) -> Option> { - match self { - Self::WebSocket(ws) => ws.next().await, - Self::Tcp(tcp) => tcp.next().await, - } - } - - pub fn local_addr(&self) -> SocketAddr { - match self { - Self::WebSocket(ws) => ws.local_addr(), - Self::Tcp(tcp) => tcp.local_addr(), - } - } -} +pub mod stream; +pub use stream::Stream; pub type SessionID = uuid::Uuid; diff --git a/src/stream.rs b/src/stream.rs new file mode 100644 index 0000000..947fe5a --- /dev/null +++ b/src/stream.rs @@ -0,0 +1,102 @@ +use crate::tcp; +use crate::websocket; +use sodiumoxide::crypto::secretbox::Key; +use crate::config; +use crate::ResultType; +use std::net::SocketAddr; + +// support Websocket and tcp. +pub enum Stream { + WebSocket(websocket::WsFramedStream), + Tcp(tcp::FramedStream), +} + +impl Stream { + pub fn set_send_timeout(&mut self, ms: u64) { + match self { + Stream::WebSocket(s) => s.set_send_timeout(ms), + Stream::Tcp(s) => s.set_send_timeout(ms), + } + } + + pub fn set_raw(&mut self) { + match self { + Stream::WebSocket(s) => s.set_raw(), + Stream::Tcp(s) => s.set_raw(), + } + } + + pub async fn send_bytes(&mut self, bytes: bytes::Bytes) -> ResultType<()> { + match self { + Stream::WebSocket(s) => s.send_bytes(bytes).await, + Stream::Tcp(s) => s.send_bytes(bytes).await, + } + } + + pub async fn send_raw(&mut self, bytes: Vec) -> ResultType<()> { + match self { + Stream::WebSocket(s) => s.send_raw(bytes).await, + Stream::Tcp(s) => s.send_raw(bytes).await, + } + } + + pub fn set_key(&mut self, key: Key) { + match self { + Stream::WebSocket(s) => s.set_key(key), + Stream::Tcp(s) => s.set_key(key), + } + } + + pub fn is_secured(&self) -> bool { + match self { + Stream::WebSocket(s) => s.is_secured(), + Stream::Tcp(s) => s.is_secured(), + } + } + + pub async fn next_timeout( + &mut self, + timeout: u64, + ) -> Option> { + match self { + Stream::WebSocket(s) => s.next_timeout(timeout).await, + Stream::Tcp(s) => s.next_timeout(timeout).await, + } + } + + /// establish connect from websocket + pub async fn connect_websocket( + url: impl AsRef, + local_addr: Option, + proxy_conf: Option<&config::Socks5Server>, + timeout_ms: u64, + ) -> ResultType { + let ws_stream = + websocket::WsFramedStream::new(url, local_addr, proxy_conf, timeout_ms).await?; + log::debug!("WebSocket connection established"); + Ok(Self::WebSocket(ws_stream)) + } + + /// send message + pub async fn send(&mut self, msg: &impl protobuf::Message) -> ResultType<()> { + match self { + Self::WebSocket(ws) => ws.send(msg).await, + Self::Tcp(tcp) => tcp.send(msg).await, + } + } + + /// receive message + pub async fn next(&mut self) -> Option> { + match self { + Self::WebSocket(ws) => ws.next().await, + Self::Tcp(tcp) => tcp.next().await, + } + } + + pub fn local_addr(&self) -> SocketAddr { + match self { + Self::WebSocket(ws) => ws.local_addr(), + Self::Tcp(tcp) => tcp.local_addr(), + } + } +} From d8f907a0d9fdcee4c0e21115fb9fe4cfdc95eaad Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Fri, 25 Apr 2025 11:11:53 +0800 Subject: [PATCH 22/22] [enhance] implement inline func and neat import. --- src/stream.rs | 16 ++++++++++++---- src/websocket.rs | 11 ++++++++--- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/src/stream.rs b/src/stream.rs index 947fe5a..fcac31d 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,8 +1,5 @@ -use crate::tcp; -use crate::websocket; +use crate::{config, tcp, websocket, ResultType}; use sodiumoxide::crypto::secretbox::Key; -use crate::config; -use crate::ResultType; use std::net::SocketAddr; // support Websocket and tcp. @@ -12,6 +9,7 @@ pub enum Stream { } impl Stream { + #[inline] pub fn set_send_timeout(&mut self, ms: u64) { match self { Stream::WebSocket(s) => s.set_send_timeout(ms), @@ -19,6 +17,7 @@ impl Stream { } } + #[inline] pub fn set_raw(&mut self) { match self { Stream::WebSocket(s) => s.set_raw(), @@ -26,6 +25,7 @@ impl Stream { } } + #[inline] pub async fn send_bytes(&mut self, bytes: bytes::Bytes) -> ResultType<()> { match self { Stream::WebSocket(s) => s.send_bytes(bytes).await, @@ -33,6 +33,7 @@ impl Stream { } } + #[inline] pub async fn send_raw(&mut self, bytes: Vec) -> ResultType<()> { match self { Stream::WebSocket(s) => s.send_raw(bytes).await, @@ -40,6 +41,7 @@ impl Stream { } } + #[inline] pub fn set_key(&mut self, key: Key) { match self { Stream::WebSocket(s) => s.set_key(key), @@ -47,6 +49,7 @@ impl Stream { } } + #[inline] pub fn is_secured(&self) -> bool { match self { Stream::WebSocket(s) => s.is_secured(), @@ -54,6 +57,7 @@ impl Stream { } } + #[inline] pub async fn next_timeout( &mut self, timeout: u64, @@ -65,6 +69,7 @@ impl Stream { } /// establish connect from websocket + #[inline] pub async fn connect_websocket( url: impl AsRef, local_addr: Option, @@ -78,6 +83,7 @@ impl Stream { } /// send message + #[inline] pub async fn send(&mut self, msg: &impl protobuf::Message) -> ResultType<()> { match self { Self::WebSocket(ws) => ws.send(msg).await, @@ -86,6 +92,7 @@ impl Stream { } /// receive message + #[inline] pub async fn next(&mut self) -> Option> { match self { Self::WebSocket(ws) => ws.next().await, @@ -93,6 +100,7 @@ impl Stream { } } + #[inline] pub fn local_addr(&self) -> SocketAddr { match self { Self::WebSocket(ws) => ws.local_addr(), diff --git a/src/websocket.rs b/src/websocket.rs index 4b2aa37..25f2f29 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -1,6 +1,6 @@ -use crate::tcp::Encrypt; use crate::{ - config::Socks5Server, protobuf::Message, sodiumoxide::crypto::secretbox::Key, ResultType, + config::Socks5Server, protobuf::Message, sodiumoxide::crypto::secretbox::Key, tcp::Encrypt, + ResultType, }; use bytes::{Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; @@ -57,10 +57,12 @@ impl WsFramedStream { Ok(ws) } + #[inline] pub fn set_raw(&mut self) { self.encrypt = None; } + #[inline] pub async fn from_tcp_stream(stream: TcpStream, addr: SocketAddr) -> ResultType { let ws_stream = WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None) @@ -74,18 +76,22 @@ impl WsFramedStream { }) } + #[inline] pub fn local_addr(&self) -> SocketAddr { self.addr } + #[inline] pub fn set_send_timeout(&mut self, ms: u64) { self.send_timeout = ms; } + #[inline] pub fn set_key(&mut self, key: Key) { self.encrypt = Some(Encrypt::new(key)); } + #[inline] pub fn is_secured(&self) -> bool { self.encrypt.is_some() } @@ -104,7 +110,6 @@ impl WsFramedStream { self.send_bytes(Bytes::from(msg)).await } - #[inline] pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { let msg = WsMessage::Binary(bytes); if self.send_timeout > 0 {