mirror of
https://github.com/rustdesk/hbb_common.git
synced 2026-02-16 02:20:43 +00:00
refact: tls native-tls fallback rustls-tls
Signed-off-by: fufesou <linlong1266@gmail.com>
This commit is contained in:
24
Cargo.toml
24
Cargo.toml
@@ -48,30 +48,24 @@ url = "2.5"
|
||||
sha2 = "0.10"
|
||||
whoami = "1.5"
|
||||
|
||||
[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
|
||||
mac_address = "1.1"
|
||||
default_net = { git = "https://github.com/rustdesk-org/default_net" }
|
||||
machine-uid = { git = "https://github.com/rustdesk-org/machine-uid" }
|
||||
|
||||
[target.'cfg(not(any(target_os = "macos", target_os = "windows")))'.dependencies]
|
||||
tokio-rustls = { version = "0.26", features = [
|
||||
"logging",
|
||||
"tls12",
|
||||
"ring",
|
||||
], default-features = false }
|
||||
tokio-native-tls = "0.3"
|
||||
tokio-tungstenite = { version = "0.26", features = ["native-tls", "rustls-tls-native-roots", "rustls-tls-webpki-roots"] }
|
||||
tungstenite = { version = "0.26", features = ["native-tls", "rustls-tls-native-roots", "rustls-tls-webpki-roots"] }
|
||||
rustls-platform-verifier = "0.6"
|
||||
rustls-pki-types = "1.11"
|
||||
tokio-tungstenite = { version = "0.26", features = ["rustls-tls-native-roots", "rustls-tls-webpki-roots"] }
|
||||
tungstenite = { version = "0.26", features = ["rustls-tls-native-roots", "rustls-tls-webpki-roots"] }
|
||||
rustls-native-certs = "0.8"
|
||||
webpki-roots = "1.0"
|
||||
async-recursion = "1.1"
|
||||
|
||||
[target.'cfg(any(target_os = "android", target_os = "ios"))'.dependencies]
|
||||
rustls-platform-verifier = "0.6"
|
||||
|
||||
[target.'cfg(any(target_os = "macos", target_os = "windows"))'.dependencies]
|
||||
tokio-native-tls = "0.3"
|
||||
tokio-tungstenite = { version = "0.26", features = ["native-tls"] }
|
||||
tungstenite = { version = "0.26", features = ["native-tls"] }
|
||||
[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
|
||||
mac_address = "1.1"
|
||||
default_net = { git = "https://github.com/rustdesk-org/default_net" }
|
||||
machine-uid = { git = "https://github.com/rustdesk-org/machine-uid" }
|
||||
|
||||
[build-dependencies]
|
||||
protobuf-codegen = { version = "3.7" }
|
||||
|
||||
@@ -2417,6 +2417,11 @@ pub fn use_ws() -> bool {
|
||||
option2bool(option, &Config::get_option(option))
|
||||
}
|
||||
|
||||
pub fn allow_insecure_tls_fallback() -> bool {
|
||||
let option = keys::OPTION_ALLOW_INSECURE_TLS_FALLBACK;
|
||||
option2bool(option, &Config::get_option(option))
|
||||
}
|
||||
|
||||
pub mod keys {
|
||||
pub const OPTION_VIEW_ONLY: &str = "view_only";
|
||||
pub const OPTION_SHOW_MONITORS_TOOLBAR: &str = "show_monitors_toolbar";
|
||||
@@ -2513,14 +2518,16 @@ pub mod keys {
|
||||
pub const OPTION_TRACKPAD_SPEED: &str = "trackpad-speed";
|
||||
pub const OPTION_REGISTER_DEVICE: &str = "register-device";
|
||||
pub const OPTION_RELAY_SERVER: &str = "relay-server";
|
||||
pub const OPTION_DISABLE_UDP: &str = "disable-udp";
|
||||
pub const OPTION_ALLOW_INSECURE_TLS_FALLBACK: &str = "allow-insecure-tls-fallback";
|
||||
pub const OPTION_SHOW_VIRTUAL_MOUSE: &str = "show-virtual-mouse";
|
||||
// joystick is the virtual mouse.
|
||||
// So `OPTION_SHOW_VIRTUAL_MOUSE` should also be set if `OPTION_SHOW_VIRTUAL_JOYSTICK` is set.
|
||||
pub const OPTION_SHOW_VIRTUAL_JOYSTICK: &str = "show-virtual-joystick";
|
||||
pub const OPTION_ENABLE_FLUTTER_HTTP_ON_RUST: &str = "enable-flutter-http-on-rust";
|
||||
|
||||
// built-in options
|
||||
pub const OPTION_DISPLAY_NAME: &str = "display-name";
|
||||
pub const OPTION_DISABLE_UDP: &str = "disable-udp";
|
||||
pub const OPTION_PRESET_DEVICE_GROUP_NAME: &str = "preset-device-group-name";
|
||||
pub const OPTION_PRESET_USERNAME: &str = "preset-user-name";
|
||||
pub const OPTION_PRESET_STRATEGY_NAME: &str = "preset-strategy-name";
|
||||
@@ -2651,6 +2658,7 @@ pub mod keys {
|
||||
OPTION_TOUCH_MODE,
|
||||
OPTION_SHOW_VIRTUAL_MOUSE,
|
||||
OPTION_SHOW_VIRTUAL_JOYSTICK,
|
||||
OPTION_ENABLE_FLUTTER_HTTP_ON_RUST,
|
||||
];
|
||||
// DEFAULT_SETTINGS, OVERWRITE_SETTINGS
|
||||
pub const KEYS_SETTINGS: &[&str] = &[
|
||||
@@ -2703,12 +2711,12 @@ pub mod keys {
|
||||
OPTION_ENABLE_ANDROID_SOFTWARE_ENCODING_HALF_SCALE,
|
||||
OPTION_ENABLE_TRUSTED_DEVICES,
|
||||
OPTION_RELAY_SERVER,
|
||||
OPTION_DISABLE_UDP,
|
||||
];
|
||||
|
||||
// BUILDIN_SETTINGS
|
||||
pub const KEYS_BUILDIN_SETTINGS: &[&str] = &[
|
||||
OPTION_DISPLAY_NAME,
|
||||
OPTION_DISABLE_UDP,
|
||||
OPTION_PRESET_DEVICE_GROUP_NAME,
|
||||
OPTION_PRESET_USERNAME,
|
||||
OPTION_PRESET_STRATEGY_NAME,
|
||||
|
||||
@@ -63,8 +63,9 @@ pub mod websocket;
|
||||
pub use rustls_platform_verifier;
|
||||
pub use stream::Stream;
|
||||
pub use whoami;
|
||||
#[cfg(not(any(target_os = "macos", target_os = "windows")))]
|
||||
pub mod tls;
|
||||
pub mod verifier;
|
||||
pub use async_recursion;
|
||||
|
||||
pub type SessionID = uuid::Uuid;
|
||||
|
||||
|
||||
241
src/proxy.rs
241
src/proxy.rs
@@ -3,16 +3,15 @@ use std::{
|
||||
net::{SocketAddr, ToSocketAddrs},
|
||||
};
|
||||
|
||||
use anyhow::bail;
|
||||
use async_recursion::async_recursion;
|
||||
use base64::{engine::general_purpose, Engine};
|
||||
use httparse::{Error as HttpParseError, Response, EMPTY_HEADER};
|
||||
use log::info;
|
||||
use thiserror::Error as ThisError;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufStream};
|
||||
#[cfg(any(target_os = "windows", target_os = "macos"))]
|
||||
use tokio_native_tls::{native_tls, TlsConnector, TlsStream};
|
||||
#[cfg(not(any(target_os = "windows", target_os = "macos")))]
|
||||
use tokio_rustls::{client::TlsStream, TlsConnector};
|
||||
use tokio_socks::{tcp::Socks5Stream, IntoTargetAddr};
|
||||
use tokio_rustls::{client::TlsStream as RustlsTlsStream, TlsConnector as RustlsTlsConnector};
|
||||
use tokio_socks::{tcp::Socks5Stream, IntoTargetAddr, TargetAddr};
|
||||
use tokio_util::codec::Framed;
|
||||
use url::Url;
|
||||
|
||||
@@ -20,6 +19,7 @@ use crate::{
|
||||
bytes_codec::BytesCodec,
|
||||
config::Socks5Server,
|
||||
tcp::{DynTcpStream, FramedStream},
|
||||
tls::{get_cached_tls_accept_invalid_cert, get_cached_tls_type, upsert_tls_cache, TlsType},
|
||||
ResultType,
|
||||
};
|
||||
|
||||
@@ -45,7 +45,6 @@ pub enum ProxyError {
|
||||
HttpCode200(u16),
|
||||
#[error("The proxy address resolution failed: {0}")]
|
||||
AddressResolutionFailed(String),
|
||||
#[cfg(any(target_os = "windows", target_os = "macos"))]
|
||||
#[error("The native tls error: {0}")]
|
||||
NativeTlsError(#[from] tokio_native_tls::native_tls::Error),
|
||||
}
|
||||
@@ -226,7 +225,7 @@ impl ProxyScheme {
|
||||
Ok(scheme)
|
||||
}
|
||||
pub async fn socket_addrs(&self) -> Result<SocketAddr, ProxyError> {
|
||||
info!("Resolving socket address");
|
||||
log::trace!("Resolving socket address");
|
||||
match self {
|
||||
ProxyScheme::Http { host, .. } => self.resolve_host(host, 80).await,
|
||||
ProxyScheme::Https { host, .. } => self.resolve_host(host, 443).await,
|
||||
@@ -356,37 +355,50 @@ impl Proxy {
|
||||
self
|
||||
}
|
||||
|
||||
async fn new_stream(
|
||||
&self,
|
||||
local: SocketAddr,
|
||||
proxy: SocketAddr,
|
||||
) -> ResultType<tokio::net::TcpStream> {
|
||||
let stream = super::timeout(
|
||||
self.ms_timeout,
|
||||
crate::tcp::new_socket(local, true)?.connect(proxy),
|
||||
)
|
||||
.await??;
|
||||
stream.set_nodelay(true).ok();
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
pub async fn connect<'t, T>(
|
||||
self,
|
||||
&self,
|
||||
target: T,
|
||||
local_addr: Option<SocketAddr>,
|
||||
) -> ResultType<FramedStream>
|
||||
where
|
||||
T: IntoTargetAddr<'t>,
|
||||
{
|
||||
info!("Connect to proxy server");
|
||||
log::trace!("Connect to proxy server");
|
||||
let proxy = self.proxy_addrs().await?;
|
||||
|
||||
let target_addr = target
|
||||
.into_target_addr()
|
||||
.map_err(|e| ProxyError::TargetParseError(e.to_string()))?;
|
||||
|
||||
let local = if let Some(addr) = local_addr {
|
||||
addr
|
||||
} else {
|
||||
crate::config::Config::get_any_listen_addr(proxy.is_ipv4())
|
||||
};
|
||||
|
||||
let stream = super::timeout(
|
||||
self.ms_timeout,
|
||||
crate::tcp::new_socket(local, true)?.connect(proxy),
|
||||
)
|
||||
.await??;
|
||||
stream.set_nodelay(true).ok();
|
||||
|
||||
let stream = self.new_stream(local, proxy).await?;
|
||||
let addr = stream.local_addr()?;
|
||||
|
||||
return match self.intercept {
|
||||
ProxyScheme::Http { .. } => {
|
||||
info!("Connect to remote http proxy server: {}", proxy);
|
||||
log::trace!("Connect to remote http proxy server: {}", proxy);
|
||||
let stream =
|
||||
super::timeout(self.ms_timeout, self.http_connect(stream, target)).await??;
|
||||
super::timeout(self.ms_timeout, self.http_connect(stream, &target_addr))
|
||||
.await??;
|
||||
Ok(FramedStream(
|
||||
Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()),
|
||||
addr,
|
||||
@@ -395,24 +407,54 @@ impl Proxy {
|
||||
))
|
||||
}
|
||||
ProxyScheme::Https { .. } => {
|
||||
info!("Connect to remote https proxy server: {}", proxy);
|
||||
let stream =
|
||||
super::timeout(self.ms_timeout, self.https_connect(stream, target)).await??;
|
||||
log::trace!("Connect to remote https proxy server: {}", proxy);
|
||||
let url = format!("https://{}", self.intercept.get_host_and_port()?);
|
||||
let tls_type = get_cached_tls_type(&url);
|
||||
let danger_accept_invalid_cert = get_cached_tls_accept_invalid_cert(&url);
|
||||
let stream = match tls_type.unwrap_or(TlsType::NativeTls) {
|
||||
TlsType::NativeTls => {
|
||||
self.https_connect_nativetls_wrap_danger(
|
||||
&url,
|
||||
local,
|
||||
proxy,
|
||||
Some(stream),
|
||||
&target_addr,
|
||||
tls_type.is_some(),
|
||||
danger_accept_invalid_cert,
|
||||
danger_accept_invalid_cert,
|
||||
)
|
||||
.await?
|
||||
}
|
||||
TlsType::Rustls => {
|
||||
self.https_connect_rustls_wrap_danger(
|
||||
&url,
|
||||
local,
|
||||
proxy,
|
||||
&target_addr,
|
||||
danger_accept_invalid_cert,
|
||||
)
|
||||
.await?
|
||||
}
|
||||
_ => {
|
||||
// Unreachable
|
||||
crate::bail!("Unreachable, TlsType::Plain in HTTPS proxy");
|
||||
}
|
||||
};
|
||||
Ok(FramedStream(
|
||||
Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()),
|
||||
Framed::new(stream, BytesCodec::new()),
|
||||
addr,
|
||||
None,
|
||||
0,
|
||||
))
|
||||
}
|
||||
ProxyScheme::Socks5 { .. } => {
|
||||
info!("Connect to remote socket5 proxy server: {}", proxy);
|
||||
log::trace!("Connect to remote socket5 proxy server: {}", proxy);
|
||||
let stream = if let Some(auth) = self.intercept.maybe_auth() {
|
||||
super::timeout(
|
||||
self.ms_timeout,
|
||||
Socks5Stream::connect_with_password_and_socket(
|
||||
stream,
|
||||
target,
|
||||
target_addr,
|
||||
&auth.user_name,
|
||||
&auth.password,
|
||||
),
|
||||
@@ -421,7 +463,7 @@ impl Proxy {
|
||||
} else {
|
||||
super::timeout(
|
||||
self.ms_timeout,
|
||||
Socks5Stream::connect_with_socket(stream, target),
|
||||
Socks5Stream::connect_with_socket(stream, target_addr),
|
||||
)
|
||||
.await??
|
||||
};
|
||||
@@ -435,32 +477,133 @@ impl Proxy {
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(any(target_os = "windows", target_os = "macos"))]
|
||||
pub async fn https_connect<'a, Input, T>(
|
||||
self,
|
||||
#[async_recursion]
|
||||
async fn https_connect_nativetls_wrap_danger<'a>(
|
||||
&self,
|
||||
url: &str,
|
||||
local: SocketAddr,
|
||||
proxy: SocketAddr,
|
||||
stream: Option<tokio::net::TcpStream>,
|
||||
target_addr: &TargetAddr<'a>,
|
||||
is_tls_type_cached: bool,
|
||||
danger_accept_invalid_cert: Option<bool>,
|
||||
origin_danger_accept_invalid_cert: Option<bool>,
|
||||
) -> ResultType<DynTcpStream> {
|
||||
let stream = stream.unwrap_or(self.new_stream(local, proxy).await?);
|
||||
match super::timeout(
|
||||
self.ms_timeout,
|
||||
self.https_connect_nativetls(
|
||||
stream,
|
||||
target_addr,
|
||||
danger_accept_invalid_cert.unwrap_or(false),
|
||||
),
|
||||
)
|
||||
.await?
|
||||
{
|
||||
Ok(s) => {
|
||||
upsert_tls_cache(
|
||||
&url,
|
||||
TlsType::NativeTls,
|
||||
danger_accept_invalid_cert.unwrap_or(false),
|
||||
);
|
||||
Ok(DynTcpStream(Box::new(s)))
|
||||
}
|
||||
Err(ProxyError::NativeTlsError(e)) => {
|
||||
let s = if danger_accept_invalid_cert.is_none() {
|
||||
log::warn!(
|
||||
"Falling back to native-tls (accept invalid cert) for HTTPS proxy server."
|
||||
);
|
||||
self.https_connect_nativetls_wrap_danger(
|
||||
&url,
|
||||
local,
|
||||
proxy,
|
||||
None,
|
||||
target_addr,
|
||||
is_tls_type_cached,
|
||||
Some(true),
|
||||
origin_danger_accept_invalid_cert,
|
||||
)
|
||||
.await?
|
||||
} else if !is_tls_type_cached {
|
||||
log::warn!("Falling back to rustls for HTTPS proxy server.");
|
||||
self.https_connect_rustls_wrap_danger(
|
||||
&url,
|
||||
local,
|
||||
proxy,
|
||||
&target_addr,
|
||||
origin_danger_accept_invalid_cert,
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
log::error!(
|
||||
"Failed to connect to HTTPS proxy server with native-tls: {:?}.",
|
||||
e
|
||||
);
|
||||
bail!(e)
|
||||
};
|
||||
Ok(s)
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Failed to connect to HTTPS proxy server: {:?}.", e);
|
||||
bail!(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn https_connect_nativetls<'a, Input>(
|
||||
&self,
|
||||
io: Input,
|
||||
target: T,
|
||||
target_addr: &TargetAddr<'a>,
|
||||
danger_accept_invalid_cert: bool,
|
||||
) -> Result<BufStream<TlsStream<Input>>, ProxyError>
|
||||
where
|
||||
Input: AsyncRead + AsyncWrite + Unpin,
|
||||
T: IntoTargetAddr<'a>,
|
||||
{
|
||||
let tls_connector = TlsConnector::from(native_tls::TlsConnector::new()?);
|
||||
let mut tls_connector_builder = native_tls::TlsConnector::builder();
|
||||
if danger_accept_invalid_cert {
|
||||
tls_connector_builder.danger_accept_invalid_certs(true);
|
||||
}
|
||||
let tls_connector = TlsConnector::from(tls_connector_builder.build()?);
|
||||
let stream = tls_connector
|
||||
.connect(&self.intercept.get_domain()?, io)
|
||||
.await?;
|
||||
self.http_connect(stream, target).await
|
||||
self.http_connect(stream, target_addr).await
|
||||
}
|
||||
|
||||
#[cfg(not(any(target_os = "windows", target_os = "macos")))]
|
||||
pub async fn https_connect<'a, Input, T>(
|
||||
self,
|
||||
async fn https_connect_rustls_wrap_danger<'a>(
|
||||
&self,
|
||||
url: &str,
|
||||
local: SocketAddr,
|
||||
proxy: SocketAddr,
|
||||
target_addr: &TargetAddr<'a>,
|
||||
danger_accept_invalid_cert: Option<bool>,
|
||||
) -> ResultType<DynTcpStream> {
|
||||
let stream = self.new_stream(local, proxy).await?;
|
||||
let s = super::timeout(
|
||||
self.ms_timeout,
|
||||
self.https_connect_rustls(
|
||||
stream,
|
||||
&target_addr,
|
||||
danger_accept_invalid_cert.unwrap_or(false),
|
||||
),
|
||||
)
|
||||
.await??;
|
||||
upsert_tls_cache(
|
||||
url,
|
||||
TlsType::Rustls,
|
||||
danger_accept_invalid_cert.unwrap_or(false),
|
||||
);
|
||||
Ok(DynTcpStream(Box::new(s)))
|
||||
}
|
||||
|
||||
pub async fn https_connect_rustls<'a, Input>(
|
||||
&self,
|
||||
io: Input,
|
||||
target: T,
|
||||
) -> Result<BufStream<TlsStream<Input>>, ProxyError>
|
||||
target_addr: &TargetAddr<'a>,
|
||||
danger_accept_invalid_cert: bool,
|
||||
) -> Result<BufStream<RustlsTlsStream<Input>>, ProxyError>
|
||||
where
|
||||
Input: AsyncRead + AsyncWrite + Unpin,
|
||||
T: IntoTargetAddr<'a>,
|
||||
{
|
||||
use std::convert::TryFrom;
|
||||
|
||||
@@ -468,24 +611,23 @@ impl Proxy {
|
||||
let domain = rustls_pki_types::ServerName::try_from(url_domain.as_str())
|
||||
.map_err(|e| ProxyError::AddressResolutionFailed(e.to_string()))?
|
||||
.to_owned();
|
||||
let client_config = crate::verifier::client_config()
|
||||
let client_config = crate::verifier::client_config(danger_accept_invalid_cert)
|
||||
.map_err(|e| ProxyError::IoError(std::io::Error::other(e)))?;
|
||||
let tls_connector = TlsConnector::from(std::sync::Arc::new(client_config));
|
||||
let tls_connector = RustlsTlsConnector::from(std::sync::Arc::new(client_config));
|
||||
let stream = tls_connector.connect(domain, io).await?;
|
||||
self.http_connect(stream, target).await
|
||||
self.http_connect(stream, target_addr).await
|
||||
}
|
||||
|
||||
pub async fn http_connect<'a, Input, T>(
|
||||
self,
|
||||
pub async fn http_connect<'a, Input>(
|
||||
&self,
|
||||
io: Input,
|
||||
target: T,
|
||||
target_addr: &TargetAddr<'a>,
|
||||
) -> Result<BufStream<Input>, ProxyError>
|
||||
where
|
||||
Input: AsyncRead + AsyncWrite + Unpin,
|
||||
T: IntoTargetAddr<'a>,
|
||||
{
|
||||
let mut stream = BufStream::new(io);
|
||||
let (domain, port) = get_domain_and_port(target)?;
|
||||
let (domain, port) = get_domain_and_port(target_addr)?;
|
||||
|
||||
let request = self.make_request(&domain, port);
|
||||
stream.write_all(request.as_bytes()).await?;
|
||||
@@ -510,13 +652,10 @@ impl Proxy {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_domain_and_port<'a, T: IntoTargetAddr<'a>>(target: T) -> Result<(String, u16), ProxyError> {
|
||||
let target_addr = target
|
||||
.into_target_addr()
|
||||
.map_err(|e| ProxyError::TargetParseError(e.to_string()))?;
|
||||
fn get_domain_and_port<'a>(target_addr: &TargetAddr<'a>) -> Result<(String, u16), ProxyError> {
|
||||
match target_addr {
|
||||
tokio_socks::TargetAddr::Ip(addr) => Ok((addr.ip().to_string(), addr.port())),
|
||||
tokio_socks::TargetAddr::Domain(name, port) => Ok((name.to_string(), port)),
|
||||
tokio_socks::TargetAddr::Domain(name, port) => Ok((name.to_string(), *port)),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
121
src/tls.rs
Normal file
121
src/tls.rs
Normal file
@@ -0,0 +1,121 @@
|
||||
use std::{collections::HashMap, sync::RwLock};
|
||||
|
||||
use crate::config::allow_insecure_tls_fallback;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum TlsType {
|
||||
Plain,
|
||||
NativeTls,
|
||||
Rustls,
|
||||
}
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref URL_TLS_TYPE: RwLock<HashMap<String, TlsType>> = RwLock::new(HashMap::new());
|
||||
static ref URL_TLS_DANGER_ACCEPT_INVALID_CERTS: RwLock<HashMap<String, bool>> = RwLock::new(HashMap::new());
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn is_plain(url: &str) -> bool {
|
||||
url.starts_with("ws://") || url.starts_with("http://")
|
||||
}
|
||||
|
||||
// Extract domain from URL.
|
||||
// e.g., "https://example.com/path" -> "example.com"
|
||||
// "https://example.com:8080/path" -> "example.com:8080"
|
||||
// See the tests for more examples.
|
||||
#[inline]
|
||||
fn get_domain_and_port_from_url(url: &str) -> &str {
|
||||
// Remove scheme (e.g., http://, https://, ws://, wss://)
|
||||
let scheme_end = url.find("://").map(|pos| pos + 3).unwrap_or(0);
|
||||
let url2 = &url[scheme_end..];
|
||||
// If userinfo is present, domain is after last '@'
|
||||
let after_at = match url2.rfind('@') {
|
||||
Some(pos) => &url2[pos + 1..],
|
||||
None => url2,
|
||||
};
|
||||
// Find the end of domain (before '/' or '?')
|
||||
let domain_end = after_at.find(&['/', '?'][..]).unwrap_or(after_at.len());
|
||||
&after_at[..domain_end]
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn upsert_tls_cache(url: &str, tls_type: TlsType, danger_accept_invalid_cert: bool) {
|
||||
if is_plain(url) {
|
||||
return;
|
||||
}
|
||||
|
||||
let domain_port = get_domain_and_port_from_url(url);
|
||||
// Use curly braces to ensure the lock is released immediately.
|
||||
{
|
||||
URL_TLS_TYPE
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(domain_port.to_string(), tls_type);
|
||||
}
|
||||
{
|
||||
URL_TLS_DANGER_ACCEPT_INVALID_CERTS
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(domain_port.to_string(), danger_accept_invalid_cert);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn reset_tls_cache() {
|
||||
// Use curly braces to ensure the lock is released immediately.
|
||||
{
|
||||
URL_TLS_TYPE.write().unwrap().clear();
|
||||
}
|
||||
{
|
||||
URL_TLS_DANGER_ACCEPT_INVALID_CERTS.write().unwrap().clear();
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get_cached_tls_type(url: &str) -> Option<TlsType> {
|
||||
if is_plain(url) {
|
||||
return Some(TlsType::Plain);
|
||||
}
|
||||
let domain_port = get_domain_and_port_from_url(url);
|
||||
URL_TLS_TYPE.read().unwrap().get(domain_port).cloned()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get_cached_tls_accept_invalid_cert(url: &str) -> Option<bool> {
|
||||
if !allow_insecure_tls_fallback() {
|
||||
return Some(false);
|
||||
}
|
||||
|
||||
if is_plain(url) {
|
||||
return Some(false);
|
||||
}
|
||||
let domain_port = get_domain_and_port_from_url(url);
|
||||
URL_TLS_DANGER_ACCEPT_INVALID_CERTS
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(domain_port)
|
||||
.cloned()
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_get_domain_and_port_from_url() {
|
||||
for (url, expected_domain_port) in vec![
|
||||
("http://example.com", "example.com"),
|
||||
("https://example.com", "example.com"),
|
||||
("ws://example.com/path", "example.com"),
|
||||
("wss://example.com:8080/path", "example.com:8080"),
|
||||
("https://user:pass@example.com", "example.com"),
|
||||
("https://example.com?query=param", "example.com"),
|
||||
("https://example.com:8443?query=param", "example.com:8443"),
|
||||
("ftp://example.com/resource", "example.com"), // ftp scheme
|
||||
("example.com/path", "example.com"), // no scheme
|
||||
("example.com:8080/path", "example.com:8080"),
|
||||
] {
|
||||
let domain_port = get_domain_and_port_from_url(url);
|
||||
assert_eq!(domain_port, expected_domain_port);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,14 +1,65 @@
|
||||
use crate::ResultType;
|
||||
#[cfg(any(target_os = "android", target_os = "ios"))]
|
||||
use rustls_pki_types::{ServerName, UnixTime};
|
||||
use std::sync::Arc;
|
||||
use tokio_rustls::rustls::{self, client::WebPkiServerVerifier, ClientConfig};
|
||||
#[cfg(any(target_os = "android", target_os = "ios"))]
|
||||
use tokio_rustls::rustls::{
|
||||
client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
|
||||
DigitallySignedStruct, Error as TLSError, SignatureScheme,
|
||||
};
|
||||
|
||||
// https://github.com/seanmonstar/reqwest/blob/fd61bc93e6f936454ce0b978c6f282f06eee9287/src/tls.rs#L608
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct NoVerifier;
|
||||
|
||||
impl ServerCertVerifier for NoVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &rustls_pki_types::CertificateDer,
|
||||
_intermediates: &[rustls_pki_types::CertificateDer],
|
||||
_server_name: &ServerName,
|
||||
_ocsp_response: &[u8],
|
||||
_now: UnixTime,
|
||||
) -> Result<ServerCertVerified, TLSError> {
|
||||
Ok(ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls_pki_types::CertificateDer,
|
||||
_dss: &DigitallySignedStruct,
|
||||
) -> Result<HandshakeSignatureValid, TLSError> {
|
||||
Ok(HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls_pki_types::CertificateDer,
|
||||
_dss: &DigitallySignedStruct,
|
||||
) -> Result<HandshakeSignatureValid, TLSError> {
|
||||
Ok(HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
|
||||
vec![
|
||||
SignatureScheme::RSA_PKCS1_SHA1,
|
||||
SignatureScheme::ECDSA_SHA1_Legacy,
|
||||
SignatureScheme::RSA_PKCS1_SHA256,
|
||||
SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
SignatureScheme::RSA_PKCS1_SHA384,
|
||||
SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
SignatureScheme::RSA_PKCS1_SHA512,
|
||||
SignatureScheme::ECDSA_NISTP521_SHA512,
|
||||
SignatureScheme::RSA_PSS_SHA256,
|
||||
SignatureScheme::RSA_PSS_SHA384,
|
||||
SignatureScheme::RSA_PSS_SHA512,
|
||||
SignatureScheme::ED25519,
|
||||
SignatureScheme::ED448,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
/// A certificate verifier that tries a primary verifier first,
|
||||
/// and falls back to a platform verifier if the primary fails.
|
||||
#[cfg(any(target_os = "android", target_os = "ios"))]
|
||||
@@ -149,7 +200,15 @@ fn webpki_server_verifier(
|
||||
Ok(verifier)
|
||||
}
|
||||
|
||||
pub fn client_config() -> ResultType<ClientConfig> {
|
||||
pub fn client_config(danger_accept_invalid_cert: bool) -> ResultType<ClientConfig> {
|
||||
if danger_accept_invalid_cert {
|
||||
client_config_danger()
|
||||
} else {
|
||||
client_config_safe()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn client_config_safe() -> ResultType<ClientConfig> {
|
||||
// Use the default builder which uses the default protocol versions and crypto provider.
|
||||
// The with_protocol_versions API has been removed in rustls master branch:
|
||||
// https://github.com/rustls/rustls/pull/2599
|
||||
@@ -188,3 +247,11 @@ pub fn client_config() -> ResultType<ClientConfig> {
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn client_config_danger() -> ResultType<ClientConfig> {
|
||||
let config = ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(NoVerifier))
|
||||
.with_no_client_auth();
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
187
src/websocket.rs
187
src/websocket.rs
@@ -1,22 +1,29 @@
|
||||
use crate::{
|
||||
config::keys::OPTION_RELAY_SERVER,
|
||||
config::{use_ws, Config, Socks5Server, RELAY_PORT, RENDEZVOUS_PORT},
|
||||
config::{
|
||||
keys::OPTION_RELAY_SERVER, use_ws, Config, Socks5Server, RELAY_PORT, RENDEZVOUS_PORT,
|
||||
},
|
||||
protobuf::Message,
|
||||
socket_client::split_host_port,
|
||||
sodiumoxide::crypto::secretbox::Key,
|
||||
tcp::Encrypt,
|
||||
tls::{get_cached_tls_accept_invalid_cert, get_cached_tls_type, upsert_tls_cache, TlsType},
|
||||
ResultType,
|
||||
};
|
||||
use anyhow::bail;
|
||||
use async_recursion::async_recursion;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use std::{
|
||||
io::{Error, ErrorKind},
|
||||
net::SocketAddr,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::{net::TcpStream, time::timeout};
|
||||
use tokio_native_tls::native_tls::TlsConnector;
|
||||
use tokio_tungstenite::{
|
||||
connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream,
|
||||
connect_async_tls_with_config, tungstenite::protocol::Message as WsMessage, Connector,
|
||||
MaybeTlsStream, WebSocketStream,
|
||||
};
|
||||
use tungstenite::client::IntoClientRequest;
|
||||
use tungstenite::protocol::Role;
|
||||
@@ -29,29 +36,21 @@ pub struct WsFramedStream {
|
||||
}
|
||||
|
||||
impl WsFramedStream {
|
||||
pub async fn new<T: AsRef<str>>(
|
||||
url: T,
|
||||
_local_addr: Option<SocketAddr>,
|
||||
_proxy_conf: Option<&Socks5Server>,
|
||||
ms_timeout: u64,
|
||||
) -> ResultType<Self> {
|
||||
let url_str = url.as_ref();
|
||||
|
||||
// to-do: websocket proxy.
|
||||
|
||||
let request = url_str
|
||||
.into_client_request()
|
||||
.map_err(|e| Error::new(ErrorKind::Other, e))?;
|
||||
|
||||
let stream;
|
||||
#[cfg(any(target_os = "android", target_os = "ios"))]
|
||||
{
|
||||
let is_wss = url_str.starts_with("wss://");
|
||||
if is_wss {
|
||||
use std::sync::Arc;
|
||||
use tokio_tungstenite::{connect_async_tls_with_config, Connector};
|
||||
|
||||
let connector = match crate::verifier::client_config() {
|
||||
#[inline]
|
||||
fn get_connector(
|
||||
tls_type: &TlsType,
|
||||
danger_accept_invalid_certs: bool,
|
||||
) -> ResultType<Option<Connector>> {
|
||||
match tls_type {
|
||||
TlsType::Plain => Ok(Some(Connector::Plain)),
|
||||
TlsType::NativeTls => {
|
||||
let connector = TlsConnector::builder()
|
||||
.danger_accept_invalid_certs(danger_accept_invalid_certs)
|
||||
.build()?;
|
||||
Ok(Some(Connector::NativeTls(connector)))
|
||||
}
|
||||
TlsType::Rustls => {
|
||||
let connector = match crate::verifier::client_config(danger_accept_invalid_certs) {
|
||||
Ok(client_config) => Some(Connector::Rustls(Arc::new(client_config))),
|
||||
Err(e) => {
|
||||
log::warn!(
|
||||
@@ -61,30 +60,130 @@ impl WsFramedStream {
|
||||
None
|
||||
}
|
||||
};
|
||||
let (s, _) = timeout(
|
||||
Duration::from_millis(ms_timeout),
|
||||
connect_async_tls_with_config(request, None, false, connector),
|
||||
)
|
||||
.await??;
|
||||
stream = s;
|
||||
} else {
|
||||
let (s, _) =
|
||||
timeout(Duration::from_millis(ms_timeout), connect_async(request)).await??;
|
||||
stream = s;
|
||||
Ok(connector)
|
||||
}
|
||||
}
|
||||
#[cfg(not(any(target_os = "android", target_os = "ios")))]
|
||||
{
|
||||
let (s, _) =
|
||||
timeout(Duration::from_millis(ms_timeout), connect_async(request)).await??;
|
||||
stream = s;
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect(
|
||||
url: &str,
|
||||
ms_timeout: u64,
|
||||
) -> ResultType<WebSocketStream<MaybeTlsStream<TcpStream>>> {
|
||||
// to-do: websocket proxy.
|
||||
|
||||
let tls_type = get_cached_tls_type(url);
|
||||
let is_tls_type_cached = tls_type.is_some();
|
||||
let tls_type = tls_type.unwrap_or(TlsType::NativeTls);
|
||||
let danger_accept_invalid_cert = get_cached_tls_accept_invalid_cert(&url);
|
||||
Self::try_connect(
|
||||
url,
|
||||
ms_timeout,
|
||||
tls_type,
|
||||
is_tls_type_cached,
|
||||
danger_accept_invalid_cert,
|
||||
danger_accept_invalid_cert,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[async_recursion]
|
||||
async fn try_connect(
|
||||
url: &str,
|
||||
ms_timeout: u64,
|
||||
tls_type: TlsType,
|
||||
is_tls_type_cached: bool,
|
||||
danger_accept_invalid_cert: Option<bool>,
|
||||
original_danger_accept_invalid_certs: Option<bool>,
|
||||
) -> ResultType<WebSocketStream<MaybeTlsStream<TcpStream>>> {
|
||||
let ws_config = None;
|
||||
let disable_nagle = false;
|
||||
let request = url
|
||||
.into_client_request()
|
||||
.map_err(|e| Error::new(ErrorKind::Other, e))?;
|
||||
let connector =
|
||||
Self::get_connector(&tls_type, danger_accept_invalid_cert.unwrap_or(false))?;
|
||||
match timeout(
|
||||
Duration::from_millis(ms_timeout),
|
||||
connect_async_tls_with_config(request, ws_config, disable_nagle, connector),
|
||||
)
|
||||
.await?
|
||||
{
|
||||
Ok((ws_stream, _)) => {
|
||||
upsert_tls_cache(url, tls_type, danger_accept_invalid_cert.unwrap_or(false));
|
||||
Ok(ws_stream)
|
||||
}
|
||||
Err(e) => match (tls_type, is_tls_type_cached, danger_accept_invalid_cert) {
|
||||
(TlsType::NativeTls, _, None) => {
|
||||
log::warn!(
|
||||
"WebSocket connection with native-tls failed, try accept invalid certs: {}, {:?}",
|
||||
url,
|
||||
e
|
||||
);
|
||||
Self::try_connect(
|
||||
url,
|
||||
ms_timeout,
|
||||
tls_type,
|
||||
is_tls_type_cached,
|
||||
Some(true),
|
||||
original_danger_accept_invalid_certs,
|
||||
)
|
||||
.await
|
||||
}
|
||||
(TlsType::NativeTls, false, Some(_)) => {
|
||||
log::warn!(
|
||||
"WebSocket connection with native-tls failed, try rustls: {}, {:?}",
|
||||
url,
|
||||
e
|
||||
);
|
||||
Self::try_connect(
|
||||
url,
|
||||
ms_timeout,
|
||||
TlsType::Rustls,
|
||||
is_tls_type_cached,
|
||||
original_danger_accept_invalid_certs,
|
||||
original_danger_accept_invalid_certs,
|
||||
)
|
||||
.await
|
||||
}
|
||||
(TlsType::Rustls, _, None) => {
|
||||
log::warn!(
|
||||
"WebSocket connection with rustls failed, try accept invalid certs: {}, {:?}",
|
||||
url,
|
||||
e
|
||||
);
|
||||
Self::try_connect(
|
||||
url,
|
||||
ms_timeout,
|
||||
tls_type,
|
||||
is_tls_type_cached,
|
||||
Some(true),
|
||||
original_danger_accept_invalid_certs,
|
||||
)
|
||||
.await
|
||||
}
|
||||
_ => {
|
||||
log::error!(
|
||||
"WebSocket connection failed with tls_type {:?}: {}, {:?}",
|
||||
tls_type,
|
||||
url,
|
||||
e
|
||||
);
|
||||
bail!(e)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn new<T: AsRef<str>>(
|
||||
url: T,
|
||||
_local_addr: Option<SocketAddr>,
|
||||
_proxy_conf: Option<&Socks5Server>,
|
||||
ms_timeout: u64,
|
||||
) -> ResultType<Self> {
|
||||
let stream = Self::connect(url.as_ref(), ms_timeout).await?;
|
||||
let addr = match stream.get_ref() {
|
||||
MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?,
|
||||
#[cfg(any(target_os = "macos", target_os = "windows"))]
|
||||
MaybeTlsStream::NativeTls(tls) => tls.get_ref().get_ref().get_ref().peer_addr()?,
|
||||
#[cfg(not(any(target_os = "macos", target_os = "windows")))]
|
||||
MaybeTlsStream::Rustls(tls) => tls.get_ref().0.peer_addr()?,
|
||||
_ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()),
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user