From d299e4909f8b74db8daef47c14c01719396d89f3 Mon Sep 17 00:00:00 2001 From: YinMo19 Date: Tue, 22 Apr 2025 02:55:17 +0800 Subject: [PATCH] [fix bug] fix twice enc. - fix twice enc. - enable proxy (temp ver.) - rewite websocket::next. --- src/lib.rs | 9 ++-- src/proxy.rs | 2 +- src/websocket.rs | 129 ++++++++++++++++++++++++----------------------- 3 files changed, 71 insertions(+), 69 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 5348977..0b2b176 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -120,7 +120,7 @@ impl Stream { } impl Stream { - /// 从 TCP 流创建(自动判断是否升级为 WebSocket) + /// establish connect from tcp. pub async fn from_tcp( stream: tokio::net::TcpStream, addr: SocketAddr, @@ -135,7 +135,7 @@ impl Stream { } } - /// 创建 WebSocket 客户端连接 + /// establish connect from websocket pub async fn connect_websocket( url: impl AsRef, local_addr: Option, @@ -147,7 +147,7 @@ impl Stream { Ok(Self::WebSocket(ws_stream)) } - /// 发送消息 + /// send message pub async fn send(&mut self, msg: &impl protobuf::Message) -> ResultType<()> { match self { Self::WebSocket(ws) => ws.send(msg).await, @@ -155,7 +155,7 @@ impl Stream { } } - /// 接收消息 + /// receive message pub async fn next(&mut self) -> Option> { match self { Self::WebSocket(ws) => ws.next().await, @@ -163,7 +163,6 @@ impl Stream { } } - // 其他必要的方法... pub fn local_addr(&self) -> SocketAddr { match self { Self::WebSocket(ws) => ws.local_addr(), diff --git a/src/proxy.rs b/src/proxy.rs index 34d2c51..e32778f 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -56,7 +56,7 @@ const MAXIMUM_RESPONSE_HEADERS: usize = 16; const DEFINE_TIME_OUT: u64 = 600; pub trait IntoUrl { - + // Besides parsing as a valid `Url`, the `Url` must be a valid // `http::Uri`, in that it makes sense to use in a network request. fn into_url(self) -> Result; diff --git a/src/websocket.rs b/src/websocket.rs index d44efb9..274791f 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -11,7 +11,6 @@ use std::{ net::SocketAddr, time::Duration, }; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufStream}; use tokio::{net::TcpStream, time::timeout}; use tokio_tungstenite::{ connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream, @@ -26,41 +25,71 @@ pub struct WsFramedStream { addr: SocketAddr, encrypt: Option, send_timeout: u64, - read_buf: BytesMut, + // read_buf: BytesMut, } impl WsFramedStream { pub async fn new>( url: T, local_addr: Option, - _proxy_conf: Option<&Socks5Server>, + proxy_conf: Option<&Socks5Server>, ms_timeout: u64, ) -> ResultType { - let (stream, _) = connect_async(url.as_ref()).await?; + let url_str = url.as_ref(); - // 获取底层TCP流的peer_addr - let addr = match stream.get_ref() { - MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?, - #[cfg(feature = "native-tls")] - MaybeTlsStream::NativeTls(tls) => tls.get_ref().peer_addr()?, - #[cfg(feature = "rustls")] - MaybeTlsStream::Rustls(tls) => tls.get_ref().0.peer_addr()?, - // 处理其他可能的情况 - _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), - }; + if let Some(proxy_conf) = proxy_conf { + // use proxy connect + let url_obj = url::Url::parse(url_str)?; + let host = url_obj + .host_str() + .ok_or_else(|| Error::new(ErrorKind::Other, "Invalid URL: no host"))?; - Ok(Self { - stream, - addr, - encrypt: None, - send_timeout: ms_timeout, - read_buf: BytesMut::new(), - }) + let port = url_obj + .port() + .unwrap_or(if url_obj.scheme() == "wss" { 443 } else { 80 }); + + let socket = + tokio_socks::tcp::Socks5Stream::connect(proxy_conf.proxy.as_str(), (host, port)) + .await?; + + let tcp_stream = socket.into_inner(); + let maybe_tls_stream = MaybeTlsStream::Plain(tcp_stream); + let ws_stream = + WebSocketStream::from_raw_socket(maybe_tls_stream, Role::Client, None).await; + + let addr = match ws_stream.get_ref() { + MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?, + _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), + }; + + Ok(Self { + stream: ws_stream, + addr, + encrypt: None, + send_timeout: ms_timeout, + }) + } else { + let (stream, _) = connect_async(url_str).await?; + + let addr = match stream.get_ref() { + MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?, + #[cfg(feature = "native-tls")] + MaybeTlsStream::NativeTls(tls) => tls.get_ref().peer_addr()?, + #[cfg(feature = "rustls")] + MaybeTlsStream::Rustls(tls) => tls.get_ref().0.peer_addr()?, + _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), + }; + + Ok(Self { + stream, + addr, + encrypt: None, + send_timeout: ms_timeout, + }) + } } - pub fn set_raw(&mut self) { - // WebSocket不需要特殊处理,保持空实现 - } + pub fn set_raw(&mut self) {} pub async fn from_tcp_stream(stream: TcpStream, addr: SocketAddr) -> ResultType { let ws_stream = @@ -72,24 +101,21 @@ impl WsFramedStream { addr, encrypt: None, send_timeout: 0, - read_buf: BytesMut::new(), + // read_buf: BytesMut::new(), }) } pub async fn from(stream: TcpStream, addr: SocketAddr) -> Self { - let ws_stream = WebSocketStream::from_raw_socket( - MaybeTlsStream::Plain(stream), // 包装为MaybeTlsStream - Role::Client, - None, - ) - .await; + let ws_stream = + WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None) + .await; Self { stream: ws_stream, addr, encrypt: None, send_timeout: 0, - read_buf: BytesMut::new(), + // read_buf: BytesMut::new(), } } @@ -125,14 +151,7 @@ impl WsFramedStream { #[inline] pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { - // 转换为Vec时需要处理加密 - let data = if let Some(key) = self.encrypt.as_mut() { - key.enc(&bytes.to_vec()) - } else { - bytes.to_vec() - }; - - let msg = WsMessage::Binary(Bytes::from(data)); + let msg = WsMessage::Binary(Bytes::from(bytes)); if self.send_timeout > 0 { let send_future = self.stream.send(msg); timeout(Duration::from_millis(self.send_timeout), send_future) @@ -151,38 +170,22 @@ impl WsFramedStream { #[inline] pub async fn next(&mut self) -> Option> { loop { - if let Some((frame, _)) = self.read_buf.split_first() { - if let Some(decrypted) = self.try_decrypt() { - return Some(Ok(decrypted)); - } - } - match self.stream.next().await? { Ok(WsMessage::Binary(data)) => { - self.read_buf.extend_from_slice(&data); - if let Some(decrypted) = self.try_decrypt() { - return Some(Ok(decrypted)); + let mut bytes = BytesMut::from(&data[..]); + if let Some(key) = self.encrypt.as_mut() { + if let Err(e) = key.dec(&mut bytes) { + return Some(Err(e)); + } } + return Some(Ok(bytes)); } - Ok(_) => continue, // 忽略非二进制消息 + Ok(_) => continue, Err(e) => return Some(Err(Error::new(ErrorKind::Other, e))), } } } - fn try_decrypt(&mut self) -> Option { - if let Some(key) = self.encrypt.as_mut() { - if let Ok(()) = key.dec(&mut self.read_buf) { - let data = self.read_buf.split(); - return Some(data); - } - } else { - let data = self.read_buf.split(); - return Some(data); - } - None - } - #[inline] pub async fn next_timeout(&mut self, ms: u64) -> Option> { match timeout(Duration::from_millis(ms), self.next()).await {