[fix bug] fix twice enc.

- fix twice enc.
- enable proxy (temp ver.)
- rewite websocket::next.
This commit is contained in:
YinMo19 2025-04-22 02:55:17 +08:00
parent 2d65c24e4b
commit d299e4909f
3 changed files with 71 additions and 69 deletions

View File

@ -120,7 +120,7 @@ impl Stream {
} }
impl Stream { impl Stream {
/// 从 TCP 流创建(自动判断是否升级为 WebSocket /// establish connect from tcp.
pub async fn from_tcp( pub async fn from_tcp(
stream: tokio::net::TcpStream, stream: tokio::net::TcpStream,
addr: SocketAddr, addr: SocketAddr,
@ -135,7 +135,7 @@ impl Stream {
} }
} }
/// 创建 WebSocket 客户端连接 /// establish connect from websocket
pub async fn connect_websocket( pub async fn connect_websocket(
url: impl AsRef<str>, url: impl AsRef<str>,
local_addr: Option<SocketAddr>, local_addr: Option<SocketAddr>,
@ -147,7 +147,7 @@ impl Stream {
Ok(Self::WebSocket(ws_stream)) Ok(Self::WebSocket(ws_stream))
} }
/// 发送消息 /// send message
pub async fn send(&mut self, msg: &impl protobuf::Message) -> ResultType<()> { pub async fn send(&mut self, msg: &impl protobuf::Message) -> ResultType<()> {
match self { match self {
Self::WebSocket(ws) => ws.send(msg).await, Self::WebSocket(ws) => ws.send(msg).await,
@ -155,7 +155,7 @@ impl Stream {
} }
} }
/// 接收消息 /// receive message
pub async fn next(&mut self) -> Option<Result<bytes::BytesMut, std::io::Error>> { pub async fn next(&mut self) -> Option<Result<bytes::BytesMut, std::io::Error>> {
match self { match self {
Self::WebSocket(ws) => ws.next().await, Self::WebSocket(ws) => ws.next().await,
@ -163,7 +163,6 @@ impl Stream {
} }
} }
// 其他必要的方法...
pub fn local_addr(&self) -> SocketAddr { pub fn local_addr(&self) -> SocketAddr {
match self { match self {
Self::WebSocket(ws) => ws.local_addr(), Self::WebSocket(ws) => ws.local_addr(),

View File

@ -11,7 +11,6 @@ use std::{
net::SocketAddr, net::SocketAddr,
time::Duration, time::Duration,
}; };
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufStream};
use tokio::{net::TcpStream, time::timeout}; use tokio::{net::TcpStream, time::timeout};
use tokio_tungstenite::{ use tokio_tungstenite::{
connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream, connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream,
@ -26,26 +25,58 @@ pub struct WsFramedStream {
addr: SocketAddr, addr: SocketAddr,
encrypt: Option<Encrypt>, encrypt: Option<Encrypt>,
send_timeout: u64, send_timeout: u64,
read_buf: BytesMut, // read_buf: BytesMut,
} }
impl WsFramedStream { impl WsFramedStream {
pub async fn new<T: AsRef<str>>( pub async fn new<T: AsRef<str>>(
url: T, url: T,
local_addr: Option<SocketAddr>, local_addr: Option<SocketAddr>,
_proxy_conf: Option<&Socks5Server>, proxy_conf: Option<&Socks5Server>,
ms_timeout: u64, ms_timeout: u64,
) -> ResultType<Self> { ) -> ResultType<Self> {
let (stream, _) = connect_async(url.as_ref()).await?; let url_str = url.as_ref();
if let Some(proxy_conf) = proxy_conf {
// use proxy connect
let url_obj = url::Url::parse(url_str)?;
let host = url_obj
.host_str()
.ok_or_else(|| Error::new(ErrorKind::Other, "Invalid URL: no host"))?;
let port = url_obj
.port()
.unwrap_or(if url_obj.scheme() == "wss" { 443 } else { 80 });
let socket =
tokio_socks::tcp::Socks5Stream::connect(proxy_conf.proxy.as_str(), (host, port))
.await?;
let tcp_stream = socket.into_inner();
let maybe_tls_stream = MaybeTlsStream::Plain(tcp_stream);
let ws_stream =
WebSocketStream::from_raw_socket(maybe_tls_stream, Role::Client, None).await;
let addr = match ws_stream.get_ref() {
MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?,
_ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()),
};
Ok(Self {
stream: ws_stream,
addr,
encrypt: None,
send_timeout: ms_timeout,
})
} else {
let (stream, _) = connect_async(url_str).await?;
// 获取底层TCP流的peer_addr
let addr = match stream.get_ref() { let addr = match stream.get_ref() {
MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?, MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?,
#[cfg(feature = "native-tls")] #[cfg(feature = "native-tls")]
MaybeTlsStream::NativeTls(tls) => tls.get_ref().peer_addr()?, MaybeTlsStream::NativeTls(tls) => tls.get_ref().peer_addr()?,
#[cfg(feature = "rustls")] #[cfg(feature = "rustls")]
MaybeTlsStream::Rustls(tls) => tls.get_ref().0.peer_addr()?, MaybeTlsStream::Rustls(tls) => tls.get_ref().0.peer_addr()?,
// 处理其他可能的情况
_ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()),
}; };
@ -54,14 +85,12 @@ impl WsFramedStream {
addr, addr,
encrypt: None, encrypt: None,
send_timeout: ms_timeout, send_timeout: ms_timeout,
read_buf: BytesMut::new(),
}) })
} }
pub fn set_raw(&mut self) {
// WebSocket不需要特殊处理保持空实现
} }
pub fn set_raw(&mut self) {}
pub async fn from_tcp_stream(stream: TcpStream, addr: SocketAddr) -> ResultType<Self> { pub async fn from_tcp_stream(stream: TcpStream, addr: SocketAddr) -> ResultType<Self> {
let ws_stream = let ws_stream =
WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Server, None) WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Server, None)
@ -72,16 +101,13 @@ impl WsFramedStream {
addr, addr,
encrypt: None, encrypt: None,
send_timeout: 0, send_timeout: 0,
read_buf: BytesMut::new(), // read_buf: BytesMut::new(),
}) })
} }
pub async fn from(stream: TcpStream, addr: SocketAddr) -> Self { pub async fn from(stream: TcpStream, addr: SocketAddr) -> Self {
let ws_stream = WebSocketStream::from_raw_socket( let ws_stream =
MaybeTlsStream::Plain(stream), // 包装为MaybeTlsStream WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None)
Role::Client,
None,
)
.await; .await;
Self { Self {
@ -89,7 +115,7 @@ impl WsFramedStream {
addr, addr,
encrypt: None, encrypt: None,
send_timeout: 0, send_timeout: 0,
read_buf: BytesMut::new(), // read_buf: BytesMut::new(),
} }
} }
@ -125,14 +151,7 @@ impl WsFramedStream {
#[inline] #[inline]
pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> {
// 转换为Vec<u8>时需要处理加密 let msg = WsMessage::Binary(Bytes::from(bytes));
let data = if let Some(key) = self.encrypt.as_mut() {
key.enc(&bytes.to_vec())
} else {
bytes.to_vec()
};
let msg = WsMessage::Binary(Bytes::from(data));
if self.send_timeout > 0 { if self.send_timeout > 0 {
let send_future = self.stream.send(msg); let send_future = self.stream.send(msg);
timeout(Duration::from_millis(self.send_timeout), send_future) timeout(Duration::from_millis(self.send_timeout), send_future)
@ -151,38 +170,22 @@ impl WsFramedStream {
#[inline] #[inline]
pub async fn next(&mut self) -> Option<Result<BytesMut, Error>> { pub async fn next(&mut self) -> Option<Result<BytesMut, Error>> {
loop { loop {
if let Some((frame, _)) = self.read_buf.split_first() {
if let Some(decrypted) = self.try_decrypt() {
return Some(Ok(decrypted));
}
}
match self.stream.next().await? { match self.stream.next().await? {
Ok(WsMessage::Binary(data)) => { Ok(WsMessage::Binary(data)) => {
self.read_buf.extend_from_slice(&data); let mut bytes = BytesMut::from(&data[..]);
if let Some(decrypted) = self.try_decrypt() { if let Some(key) = self.encrypt.as_mut() {
return Some(Ok(decrypted)); if let Err(e) = key.dec(&mut bytes) {
return Some(Err(e));
} }
} }
Ok(_) => continue, // 忽略非二进制消息 return Some(Ok(bytes));
}
Ok(_) => continue,
Err(e) => return Some(Err(Error::new(ErrorKind::Other, e))), Err(e) => return Some(Err(Error::new(ErrorKind::Other, e))),
} }
} }
} }
fn try_decrypt(&mut self) -> Option<BytesMut> {
if let Some(key) = self.encrypt.as_mut() {
if let Ok(()) = key.dec(&mut self.read_buf) {
let data = self.read_buf.split();
return Some(data);
}
} else {
let data = self.read_buf.split();
return Some(data);
}
None
}
#[inline] #[inline]
pub async fn next_timeout(&mut self, ms: u64) -> Option<Result<BytesMut, Error>> { pub async fn next_timeout(&mut self, ms: u64) -> Option<Result<BytesMut, Error>> {
match timeout(Duration::from_millis(ms), self.next()).await { match timeout(Duration::from_millis(ms), self.next()).await {