[fix bug] fix twice enc.

- fix twice enc.
- enable proxy (temp ver.)
- rewite websocket::next.
This commit is contained in:
YinMo19 2025-04-22 02:55:17 +08:00
parent 2d65c24e4b
commit d299e4909f
3 changed files with 71 additions and 69 deletions

View File

@ -120,7 +120,7 @@ impl Stream {
}
impl Stream {
/// 从 TCP 流创建(自动判断是否升级为 WebSocket
/// establish connect from tcp.
pub async fn from_tcp(
stream: tokio::net::TcpStream,
addr: SocketAddr,
@ -135,7 +135,7 @@ impl Stream {
}
}
/// 创建 WebSocket 客户端连接
/// establish connect from websocket
pub async fn connect_websocket(
url: impl AsRef<str>,
local_addr: Option<SocketAddr>,
@ -147,7 +147,7 @@ impl Stream {
Ok(Self::WebSocket(ws_stream))
}
/// 发送消息
/// send message
pub async fn send(&mut self, msg: &impl protobuf::Message) -> ResultType<()> {
match self {
Self::WebSocket(ws) => ws.send(msg).await,
@ -155,7 +155,7 @@ impl Stream {
}
}
/// 接收消息
/// receive message
pub async fn next(&mut self) -> Option<Result<bytes::BytesMut, std::io::Error>> {
match self {
Self::WebSocket(ws) => ws.next().await,
@ -163,7 +163,6 @@ impl Stream {
}
}
// 其他必要的方法...
pub fn local_addr(&self) -> SocketAddr {
match self {
Self::WebSocket(ws) => ws.local_addr(),

View File

@ -56,7 +56,7 @@ const MAXIMUM_RESPONSE_HEADERS: usize = 16;
const DEFINE_TIME_OUT: u64 = 600;
pub trait IntoUrl {
// Besides parsing as a valid `Url`, the `Url` must be a valid
// `http::Uri`, in that it makes sense to use in a network request.
fn into_url(self) -> Result<Url, ProxyError>;

View File

@ -11,7 +11,6 @@ use std::{
net::SocketAddr,
time::Duration,
};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufStream};
use tokio::{net::TcpStream, time::timeout};
use tokio_tungstenite::{
connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream,
@ -26,41 +25,71 @@ pub struct WsFramedStream {
addr: SocketAddr,
encrypt: Option<Encrypt>,
send_timeout: u64,
read_buf: BytesMut,
// read_buf: BytesMut,
}
impl WsFramedStream {
pub async fn new<T: AsRef<str>>(
url: T,
local_addr: Option<SocketAddr>,
_proxy_conf: Option<&Socks5Server>,
proxy_conf: Option<&Socks5Server>,
ms_timeout: u64,
) -> ResultType<Self> {
let (stream, _) = connect_async(url.as_ref()).await?;
let url_str = url.as_ref();
// 获取底层TCP流的peer_addr
let addr = match stream.get_ref() {
MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?,
#[cfg(feature = "native-tls")]
MaybeTlsStream::NativeTls(tls) => tls.get_ref().peer_addr()?,
#[cfg(feature = "rustls")]
MaybeTlsStream::Rustls(tls) => tls.get_ref().0.peer_addr()?,
// 处理其他可能的情况
_ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()),
};
if let Some(proxy_conf) = proxy_conf {
// use proxy connect
let url_obj = url::Url::parse(url_str)?;
let host = url_obj
.host_str()
.ok_or_else(|| Error::new(ErrorKind::Other, "Invalid URL: no host"))?;
Ok(Self {
stream,
addr,
encrypt: None,
send_timeout: ms_timeout,
read_buf: BytesMut::new(),
})
let port = url_obj
.port()
.unwrap_or(if url_obj.scheme() == "wss" { 443 } else { 80 });
let socket =
tokio_socks::tcp::Socks5Stream::connect(proxy_conf.proxy.as_str(), (host, port))
.await?;
let tcp_stream = socket.into_inner();
let maybe_tls_stream = MaybeTlsStream::Plain(tcp_stream);
let ws_stream =
WebSocketStream::from_raw_socket(maybe_tls_stream, Role::Client, None).await;
let addr = match ws_stream.get_ref() {
MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?,
_ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()),
};
Ok(Self {
stream: ws_stream,
addr,
encrypt: None,
send_timeout: ms_timeout,
})
} else {
let (stream, _) = connect_async(url_str).await?;
let addr = match stream.get_ref() {
MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?,
#[cfg(feature = "native-tls")]
MaybeTlsStream::NativeTls(tls) => tls.get_ref().peer_addr()?,
#[cfg(feature = "rustls")]
MaybeTlsStream::Rustls(tls) => tls.get_ref().0.peer_addr()?,
_ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()),
};
Ok(Self {
stream,
addr,
encrypt: None,
send_timeout: ms_timeout,
})
}
}
pub fn set_raw(&mut self) {
// WebSocket不需要特殊处理保持空实现
}
pub fn set_raw(&mut self) {}
pub async fn from_tcp_stream(stream: TcpStream, addr: SocketAddr) -> ResultType<Self> {
let ws_stream =
@ -72,24 +101,21 @@ impl WsFramedStream {
addr,
encrypt: None,
send_timeout: 0,
read_buf: BytesMut::new(),
// read_buf: BytesMut::new(),
})
}
pub async fn from(stream: TcpStream, addr: SocketAddr) -> Self {
let ws_stream = WebSocketStream::from_raw_socket(
MaybeTlsStream::Plain(stream), // 包装为MaybeTlsStream
Role::Client,
None,
)
.await;
let ws_stream =
WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None)
.await;
Self {
stream: ws_stream,
addr,
encrypt: None,
send_timeout: 0,
read_buf: BytesMut::new(),
// read_buf: BytesMut::new(),
}
}
@ -125,14 +151,7 @@ impl WsFramedStream {
#[inline]
pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> {
// 转换为Vec<u8>时需要处理加密
let data = if let Some(key) = self.encrypt.as_mut() {
key.enc(&bytes.to_vec())
} else {
bytes.to_vec()
};
let msg = WsMessage::Binary(Bytes::from(data));
let msg = WsMessage::Binary(Bytes::from(bytes));
if self.send_timeout > 0 {
let send_future = self.stream.send(msg);
timeout(Duration::from_millis(self.send_timeout), send_future)
@ -151,38 +170,22 @@ impl WsFramedStream {
#[inline]
pub async fn next(&mut self) -> Option<Result<BytesMut, Error>> {
loop {
if let Some((frame, _)) = self.read_buf.split_first() {
if let Some(decrypted) = self.try_decrypt() {
return Some(Ok(decrypted));
}
}
match self.stream.next().await? {
Ok(WsMessage::Binary(data)) => {
self.read_buf.extend_from_slice(&data);
if let Some(decrypted) = self.try_decrypt() {
return Some(Ok(decrypted));
let mut bytes = BytesMut::from(&data[..]);
if let Some(key) = self.encrypt.as_mut() {
if let Err(e) = key.dec(&mut bytes) {
return Some(Err(e));
}
}
return Some(Ok(bytes));
}
Ok(_) => continue, // 忽略非二进制消息
Ok(_) => continue,
Err(e) => return Some(Err(Error::new(ErrorKind::Other, e))),
}
}
}
fn try_decrypt(&mut self) -> Option<BytesMut> {
if let Some(key) = self.encrypt.as_mut() {
if let Ok(()) = key.dec(&mut self.read_buf) {
let data = self.read_buf.split();
return Some(data);
}
} else {
let data = self.read_buf.split();
return Some(data);
}
None
}
#[inline]
pub async fn next_timeout(&mut self, ms: u64) -> Option<Result<BytesMut, Error>> {
match timeout(Duration::from_millis(ms), self.next()).await {