diff --git a/Cargo.toml b/Cargo.toml index 3c6b72e..191f3c0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,25 +46,36 @@ httparse = "1.10" base64 = "0.22" url = "2.5" sha2 = "0.10" -tokio-tungstenite = "0.26" +tokio-tungstenite = { version = "0.26.2" } +tungstenite = { version = "0.26.2" } + [target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies] mac_address = "1.1" default_net = { git = "https://github.com/rustdesk-org/default_net" } machine-uid = { git = "https://github.com/rustdesk-org/machine-uid" } [target.'cfg(not(any(target_os = "macos", target_os = "windows")))'.dependencies] -tokio-rustls = { version = "0.26", features = ["logging", "tls12", "ring"], default-features = false } +tokio-rustls = { version = "0.26", features = [ + "logging", + "tls12", + "ring", +], default-features = false } rustls-platform-verifier = "0.5" rustls-pki-types = "1.11" [target.'cfg(any(target_os = "macos", target_os = "windows"))'.dependencies] -tokio-native-tls ="0.3" +tokio-native-tls = "0.3" [build-dependencies] protobuf-codegen = { version = "3.7" } [target.'cfg(target_os = "windows")'.dependencies] -winapi = { version = "0.3", features = ["winuser", "synchapi", "pdh", "memoryapi", "sysinfoapi"] } +winapi = { version = "0.3", features = [ + "winuser", + "synchapi", + "pdh", + "memoryapi", + "sysinfoapi", +] } [target.'cfg(target_os = "macos")'.dependencies] osascript = "0.3" - diff --git a/src/lib.rs b/src/lib.rs index 9475b96..b1770f2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -57,8 +57,10 @@ pub use toml; pub use uuid; pub mod fingerprint; pub use flexi_logger; +pub mod websocket; +pub mod stream; +pub use stream::Stream; -pub type Stream = tcp::FramedStream; pub type SessionID = uuid::Uuid; #[inline] diff --git a/src/proxy.rs b/src/proxy.rs index 34d2c51..e32778f 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -56,7 +56,7 @@ const MAXIMUM_RESPONSE_HEADERS: usize = 16; const DEFINE_TIME_OUT: u64 = 600; pub trait IntoUrl { - + // Besides parsing as a valid `Url`, the `Url` must be a valid // `http::Uri`, in that it makes sense to use in a network request. fn into_url(self) -> Result; diff --git a/src/socket_client.rs b/src/socket_client.rs index 4cb0bf2..60695af 100644 --- a/src/socket_client.rs +++ b/src/socket_client.rs @@ -2,7 +2,7 @@ use crate::{ config::{Config, NetworkType}, tcp::FramedStream, udp::FramedSocket, - ResultType, + websocket, ResultType, Stream, }; use anyhow::Context; use std::net::SocketAddr; @@ -102,7 +102,7 @@ pub async fn connect_tcp< >( target: T, ms_timeout: u64, -) -> ResultType { +) -> ResultType { connect_tcp_local(target, None, ms_timeout).await } @@ -113,19 +113,35 @@ pub async fn connect_tcp_local< target: T, local: Option, ms_timeout: u64, -) -> ResultType { - if let Some(conf) = Config::get_socks() { - return FramedStream::connect(target, local, &conf, ms_timeout).await; - } - if let Some(target) = target.resolve() { - if let Some(local) = local { - if local.is_ipv6() && target.is_ipv4() { - let target = query_nip_io(target).await?; - return FramedStream::new(target, Some(local), ms_timeout).await; +) -> ResultType { + let target_str = target.to_string(); + + if target_str.starts_with("ws://") || target_str.starts_with("wss://") { + Ok(Stream::WebSocket( + websocket::WsFramedStream::new(target_str, local, None, ms_timeout).await?, + )) + } else { + if let Some(conf) = Config::get_socks() { + return Ok(Stream::Tcp( + FramedStream::connect(target, local, &conf, ms_timeout).await?, + )); + } + + if let Some(target_addr) = target.resolve() { + if let Some(local_addr) = local { + if local_addr.is_ipv6() && target_addr.is_ipv4() { + let resolved_target = query_nip_io(target_addr).await?; + return Ok(Stream::Tcp( + FramedStream::new(resolved_target, Some(local_addr), ms_timeout).await?, + )); + } } } + + Ok(Stream::Tcp( + FramedStream::new(target, local, ms_timeout).await?, + )) } - FramedStream::new(target, local, ms_timeout).await } #[inline] diff --git a/src/stream.rs b/src/stream.rs new file mode 100644 index 0000000..fcac31d --- /dev/null +++ b/src/stream.rs @@ -0,0 +1,110 @@ +use crate::{config, tcp, websocket, ResultType}; +use sodiumoxide::crypto::secretbox::Key; +use std::net::SocketAddr; + +// support Websocket and tcp. +pub enum Stream { + WebSocket(websocket::WsFramedStream), + Tcp(tcp::FramedStream), +} + +impl Stream { + #[inline] + pub fn set_send_timeout(&mut self, ms: u64) { + match self { + Stream::WebSocket(s) => s.set_send_timeout(ms), + Stream::Tcp(s) => s.set_send_timeout(ms), + } + } + + #[inline] + pub fn set_raw(&mut self) { + match self { + Stream::WebSocket(s) => s.set_raw(), + Stream::Tcp(s) => s.set_raw(), + } + } + + #[inline] + pub async fn send_bytes(&mut self, bytes: bytes::Bytes) -> ResultType<()> { + match self { + Stream::WebSocket(s) => s.send_bytes(bytes).await, + Stream::Tcp(s) => s.send_bytes(bytes).await, + } + } + + #[inline] + pub async fn send_raw(&mut self, bytes: Vec) -> ResultType<()> { + match self { + Stream::WebSocket(s) => s.send_raw(bytes).await, + Stream::Tcp(s) => s.send_raw(bytes).await, + } + } + + #[inline] + pub fn set_key(&mut self, key: Key) { + match self { + Stream::WebSocket(s) => s.set_key(key), + Stream::Tcp(s) => s.set_key(key), + } + } + + #[inline] + pub fn is_secured(&self) -> bool { + match self { + Stream::WebSocket(s) => s.is_secured(), + Stream::Tcp(s) => s.is_secured(), + } + } + + #[inline] + pub async fn next_timeout( + &mut self, + timeout: u64, + ) -> Option> { + match self { + Stream::WebSocket(s) => s.next_timeout(timeout).await, + Stream::Tcp(s) => s.next_timeout(timeout).await, + } + } + + /// establish connect from websocket + #[inline] + pub async fn connect_websocket( + url: impl AsRef, + local_addr: Option, + proxy_conf: Option<&config::Socks5Server>, + timeout_ms: u64, + ) -> ResultType { + let ws_stream = + websocket::WsFramedStream::new(url, local_addr, proxy_conf, timeout_ms).await?; + log::debug!("WebSocket connection established"); + Ok(Self::WebSocket(ws_stream)) + } + + /// send message + #[inline] + pub async fn send(&mut self, msg: &impl protobuf::Message) -> ResultType<()> { + match self { + Self::WebSocket(ws) => ws.send(msg).await, + Self::Tcp(tcp) => tcp.send(msg).await, + } + } + + /// receive message + #[inline] + pub async fn next(&mut self) -> Option> { + match self { + Self::WebSocket(ws) => ws.next().await, + Self::Tcp(tcp) => tcp.next().await, + } + } + + #[inline] + pub fn local_addr(&self) -> SocketAddr { + match self { + Self::WebSocket(ws) => ws.local_addr(), + Self::Tcp(tcp) => tcp.local_addr(), + } + } +} diff --git a/src/tcp.rs b/src/tcp.rs index c85338d..8852c8d 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -25,7 +25,7 @@ pub trait TcpStreamTrait: AsyncRead + AsyncWrite + Unpin {} pub struct DynTcpStream(pub(crate) Box); #[derive(Clone)] -pub struct Encrypt(Key, u64, u64); +pub struct Encrypt(pub Key, pub u64, pub u64); pub struct FramedStream( pub(crate) Framed, diff --git a/src/websocket.rs b/src/websocket.rs new file mode 100644 index 0000000..25f2f29 --- /dev/null +++ b/src/websocket.rs @@ -0,0 +1,182 @@ +use crate::{ + config::Socks5Server, protobuf::Message, sodiumoxide::crypto::secretbox::Key, tcp::Encrypt, + ResultType, +}; +use bytes::{Bytes, BytesMut}; +use futures::{SinkExt, StreamExt}; +use std::{ + io::{Error, ErrorKind}, + net::SocketAddr, + time::Duration, +}; +use tokio::{net::TcpStream, time::timeout}; +use tokio_tungstenite::{ + connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream, +}; +use tungstenite::client::IntoClientRequest; +use tungstenite::protocol::Role; + +pub struct WsFramedStream { + stream: WebSocketStream>, + addr: SocketAddr, + encrypt: Option, + send_timeout: u64, +} + +impl WsFramedStream { + pub async fn new>( + url: T, + local_addr: Option, + proxy_conf: Option<&Socks5Server>, + ms_timeout: u64, + ) -> ResultType { + let url_str = url.as_ref(); + + // to-do: websocket proxy. + log::info!("{:?}", url_str); + + let request = url_str + .into_client_request() + .map_err(|e| Error::new(ErrorKind::Other, e))?; + + let (stream, _) = + timeout(Duration::from_millis(ms_timeout), connect_async(request)).await??; + + let addr = match stream.get_ref() { + MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?, + _ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()), + }; + + let ws = Self { + stream, + addr, + encrypt: None, + send_timeout: ms_timeout, + }; + + Ok(ws) + } + + #[inline] + pub fn set_raw(&mut self) { + self.encrypt = None; + } + + #[inline] + pub async fn from_tcp_stream(stream: TcpStream, addr: SocketAddr) -> ResultType { + let ws_stream = + WebSocketStream::from_raw_socket(MaybeTlsStream::Plain(stream), Role::Client, None) + .await; + + Ok(Self { + stream: ws_stream, + addr, + encrypt: None, + send_timeout: 0, + }) + } + + #[inline] + pub fn local_addr(&self) -> SocketAddr { + self.addr + } + + #[inline] + pub fn set_send_timeout(&mut self, ms: u64) { + self.send_timeout = ms; + } + + #[inline] + pub fn set_key(&mut self, key: Key) { + self.encrypt = Some(Encrypt::new(key)); + } + + #[inline] + pub fn is_secured(&self) -> bool { + self.encrypt.is_some() + } + + #[inline] + pub async fn send(&mut self, msg: &impl Message) -> ResultType<()> { + self.send_raw(msg.write_to_bytes()?).await + } + + #[inline] + pub async fn send_raw(&mut self, msg: Vec) -> ResultType<()> { + let mut msg = msg; + if let Some(key) = self.encrypt.as_mut() { + msg = key.enc(&msg); + } + self.send_bytes(Bytes::from(msg)).await + } + + pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { + let msg = WsMessage::Binary(bytes); + if self.send_timeout > 0 { + timeout( + Duration::from_millis(self.send_timeout), + self.stream.send(msg), + ) + .await?? + } else { + self.stream.send(msg).await? + }; + Ok(()) + } + + #[inline] + pub async fn next(&mut self) -> Option> { + log::debug!("Waiting for next message"); + + while let Some(msg) = self.stream.next().await { + log::debug!("receive msg: {:?}", msg); + let msg = match msg { + Ok(msg) => msg, + Err(e) => { + log::debug!("{}", e); + return Some(Err(Error::new( + ErrorKind::Other, + format!("WebSocket protocol error: {}", e), + ))); + } + }; + + log::debug!("Received message type: {}", msg.to_string()); + match msg { + WsMessage::Binary(data) => { + log::info!("Received binary data ({} bytes)", data.len()); + let mut bytes = BytesMut::from(&data[..]); + if let Some(key) = self.encrypt.as_mut() { + if let Err(err) = key.dec(&mut bytes) { + return Some(Err(err)); + } + } + return Some(Ok(bytes)); + } + WsMessage::Text(text) => { + log::debug!("Received text message, converting to binary"); + let bytes = BytesMut::from(text.as_bytes()); + return Some(Ok(bytes)); + } + WsMessage::Close(_) => { + log::info!("Received close frame"); + return None; + } + _ => { + log::debug!("Unhandled message type: {}", msg.to_string()); + continue; + } + } + } + + None + } + + #[inline] + pub async fn next_timeout(&mut self, ms: u64) -> Option> { + match timeout(Duration::from_millis(ms), self.next()).await { + Ok(res) => res, + Err(_) => None, + } + } +}