diff --git a/src/lib.rs b/src/lib.rs index f9cfd02..5348977 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,14 +58,123 @@ pub use uuid; pub mod fingerprint; pub use flexi_logger; pub mod websocket; +use sodiumoxide::crypto::secretbox::Key; +pub enum Stream { + WebSocket(websocket::WsFramedStream), + Tcp(tcp::FramedStream), +} -#[cfg(feature = "websocket")] -pub type Stream = websocket::WsFramedStream; +impl Stream { + pub fn set_send_timeout(&mut self, ms: u64) { + match self { + Stream::WebSocket(s) => s.set_send_timeout(ms), + Stream::Tcp(s) => s.set_send_timeout(ms), + } + } + pub fn set_raw(&mut self) { + match self { + Stream::WebSocket(s) => s.set_raw(), + Stream::Tcp(s) => s.set_raw(), + } + } -#[cfg(not(feature = "websocket"))] -pub type Stream = tcp::FramedStream; + pub async fn send_bytes(&mut self, bytes: bytes::Bytes) -> ResultType<()> { + match self { + Stream::WebSocket(s) => s.send_bytes(bytes).await, + Stream::Tcp(s) => s.send_bytes(bytes).await, + } + } + + pub async fn send_raw(&mut self, bytes: Vec) -> ResultType<()> { + match self { + Stream::WebSocket(s) => s.send_raw(bytes).await, + Stream::Tcp(s) => s.send_raw(bytes).await, + } + } + + pub fn set_key(&mut self, key: Key) { + match self { + Stream::WebSocket(s) => s.set_key(key), + Stream::Tcp(s) => s.set_key(key), + } + } + + pub fn is_secured(&self) -> bool { + match self { + Stream::WebSocket(s) => s.is_secured(), + Stream::Tcp(s) => s.is_secured(), + } + } + + pub async fn next_timeout( + &mut self, + timeout: u64, + ) -> Option> { + match self { + Stream::WebSocket(s) => s.next_timeout(timeout).await, + Stream::Tcp(s) => s.next_timeout(timeout).await, + } + } +} + +impl Stream { + /// 从 TCP 流创建(自动判断是否升级为 WebSocket) + pub async fn from_tcp( + stream: tokio::net::TcpStream, + addr: SocketAddr, + is_websocket: bool, + ) -> ResultType { + if is_websocket { + Ok(Self::WebSocket( + websocket::WsFramedStream::from(stream, addr).await, + )) + } else { + Ok(Self::Tcp(tcp::FramedStream::from(stream, addr))) + } + } + + /// 创建 WebSocket 客户端连接 + pub async fn connect_websocket( + url: impl AsRef, + local_addr: Option, + proxy_conf: Option<&config::Socks5Server>, + timeout_ms: u64, + ) -> ResultType { + let ws_stream = + websocket::WsFramedStream::new(url, local_addr, proxy_conf, timeout_ms).await?; + Ok(Self::WebSocket(ws_stream)) + } + + /// 发送消息 + pub async fn send(&mut self, msg: &impl protobuf::Message) -> ResultType<()> { + match self { + Self::WebSocket(ws) => ws.send(msg).await, + Self::Tcp(tcp) => tcp.send(msg).await, + } + } + + /// 接收消息 + pub async fn next(&mut self) -> Option> { + match self { + Self::WebSocket(ws) => ws.next().await, + Self::Tcp(tcp) => tcp.next().await, + } + } + + // 其他必要的方法... + pub fn local_addr(&self) -> SocketAddr { + match self { + Self::WebSocket(ws) => ws.local_addr(), + Self::Tcp(tcp) => tcp.local_addr(), + } + } + + pub fn is_websocket(&self) -> bool { + matches!(self, Self::WebSocket(_)) + } +} pub type SessionID = uuid::Uuid; diff --git a/src/socket_client.rs b/src/socket_client.rs index 1c39c01..3b6457a 100644 --- a/src/socket_client.rs +++ b/src/socket_client.rs @@ -2,7 +2,7 @@ use crate::{ config::{Config, NetworkType}, tcp::FramedStream, udp::FramedSocket, - websocket, ResultType, + websocket, ResultType, Stream, }; use anyhow::Context; use std::net::SocketAddr; @@ -113,11 +113,11 @@ pub async fn connect_tcp_local< target: T, local: Option, ms_timeout: u64, -) -> ResultType { +) -> ResultType { #[cfg(feature = "websocket")] { let url = format!("ws://{}", target); - websocket::WsFramedStream::new(url, local, None, ms_timeout).await + Ok(Stream::WebSocket(websocket::WsFramedStream::new(url, local, None, ms_timeout).await?)) } #[cfg(not(feature = "websocket"))] {