refact: tls native-tls fallback rustls-tls

Signed-off-by: fufesou <linlong1266@gmail.com>
This commit is contained in:
fufesou
2025-10-30 15:34:43 +08:00
parent d6dd7ae052
commit 54c4d869ed
7 changed files with 545 additions and 116 deletions

View File

@@ -48,30 +48,24 @@ url = "2.5"
sha2 = "0.10"
whoami = "1.5"
[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
mac_address = "1.1"
default_net = { git = "https://github.com/rustdesk-org/default_net" }
machine-uid = { git = "https://github.com/rustdesk-org/machine-uid" }
[target.'cfg(not(any(target_os = "macos", target_os = "windows")))'.dependencies]
tokio-rustls = { version = "0.26", features = [
"logging",
"tls12",
"ring",
], default-features = false }
tokio-native-tls = "0.3"
tokio-tungstenite = { version = "0.26", features = ["native-tls", "rustls-tls-native-roots", "rustls-tls-webpki-roots"] }
tungstenite = { version = "0.26", features = ["native-tls", "rustls-tls-native-roots", "rustls-tls-webpki-roots"] }
rustls-platform-verifier = "0.6"
rustls-pki-types = "1.11"
tokio-tungstenite = { version = "0.26", features = ["rustls-tls-native-roots", "rustls-tls-webpki-roots"] }
tungstenite = { version = "0.26", features = ["rustls-tls-native-roots", "rustls-tls-webpki-roots"] }
rustls-native-certs = "0.8"
webpki-roots = "1.0"
async-recursion = "1.1"
[target.'cfg(any(target_os = "android", target_os = "ios"))'.dependencies]
rustls-platform-verifier = "0.6"
[target.'cfg(any(target_os = "macos", target_os = "windows"))'.dependencies]
tokio-native-tls = "0.3"
tokio-tungstenite = { version = "0.26", features = ["native-tls"] }
tungstenite = { version = "0.26", features = ["native-tls"] }
[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
mac_address = "1.1"
default_net = { git = "https://github.com/rustdesk-org/default_net" }
machine-uid = { git = "https://github.com/rustdesk-org/machine-uid" }
[build-dependencies]
protobuf-codegen = { version = "3.7" }

View File

@@ -2417,6 +2417,11 @@ pub fn use_ws() -> bool {
option2bool(option, &Config::get_option(option))
}
pub fn allow_insecure_tls_fallback() -> bool {
let option = keys::OPTION_ALLOW_INSECURE_TLS_FALLBACK;
option2bool(option, &Config::get_option(option))
}
pub mod keys {
pub const OPTION_VIEW_ONLY: &str = "view_only";
pub const OPTION_SHOW_MONITORS_TOOLBAR: &str = "show_monitors_toolbar";
@@ -2513,14 +2518,16 @@ pub mod keys {
pub const OPTION_TRACKPAD_SPEED: &str = "trackpad-speed";
pub const OPTION_REGISTER_DEVICE: &str = "register-device";
pub const OPTION_RELAY_SERVER: &str = "relay-server";
pub const OPTION_DISABLE_UDP: &str = "disable-udp";
pub const OPTION_ALLOW_INSECURE_TLS_FALLBACK: &str = "allow-insecure-tls-fallback";
pub const OPTION_SHOW_VIRTUAL_MOUSE: &str = "show-virtual-mouse";
// joystick is the virtual mouse.
// So `OPTION_SHOW_VIRTUAL_MOUSE` should also be set if `OPTION_SHOW_VIRTUAL_JOYSTICK` is set.
pub const OPTION_SHOW_VIRTUAL_JOYSTICK: &str = "show-virtual-joystick";
pub const OPTION_ENABLE_FLUTTER_HTTP_ON_RUST: &str = "enable-flutter-http-on-rust";
// built-in options
pub const OPTION_DISPLAY_NAME: &str = "display-name";
pub const OPTION_DISABLE_UDP: &str = "disable-udp";
pub const OPTION_PRESET_DEVICE_GROUP_NAME: &str = "preset-device-group-name";
pub const OPTION_PRESET_USERNAME: &str = "preset-user-name";
pub const OPTION_PRESET_STRATEGY_NAME: &str = "preset-strategy-name";
@@ -2651,6 +2658,7 @@ pub mod keys {
OPTION_TOUCH_MODE,
OPTION_SHOW_VIRTUAL_MOUSE,
OPTION_SHOW_VIRTUAL_JOYSTICK,
OPTION_ENABLE_FLUTTER_HTTP_ON_RUST,
];
// DEFAULT_SETTINGS, OVERWRITE_SETTINGS
pub const KEYS_SETTINGS: &[&str] = &[
@@ -2703,12 +2711,12 @@ pub mod keys {
OPTION_ENABLE_ANDROID_SOFTWARE_ENCODING_HALF_SCALE,
OPTION_ENABLE_TRUSTED_DEVICES,
OPTION_RELAY_SERVER,
OPTION_DISABLE_UDP,
];
// BUILDIN_SETTINGS
pub const KEYS_BUILDIN_SETTINGS: &[&str] = &[
OPTION_DISPLAY_NAME,
OPTION_DISABLE_UDP,
OPTION_PRESET_DEVICE_GROUP_NAME,
OPTION_PRESET_USERNAME,
OPTION_PRESET_STRATEGY_NAME,

View File

@@ -63,8 +63,9 @@ pub mod websocket;
pub use rustls_platform_verifier;
pub use stream::Stream;
pub use whoami;
#[cfg(not(any(target_os = "macos", target_os = "windows")))]
pub mod tls;
pub mod verifier;
pub use async_recursion;
pub type SessionID = uuid::Uuid;

View File

@@ -3,16 +3,15 @@ use std::{
net::{SocketAddr, ToSocketAddrs},
};
use anyhow::bail;
use async_recursion::async_recursion;
use base64::{engine::general_purpose, Engine};
use httparse::{Error as HttpParseError, Response, EMPTY_HEADER};
use log::info;
use thiserror::Error as ThisError;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufStream};
#[cfg(any(target_os = "windows", target_os = "macos"))]
use tokio_native_tls::{native_tls, TlsConnector, TlsStream};
#[cfg(not(any(target_os = "windows", target_os = "macos")))]
use tokio_rustls::{client::TlsStream, TlsConnector};
use tokio_socks::{tcp::Socks5Stream, IntoTargetAddr};
use tokio_rustls::{client::TlsStream as RustlsTlsStream, TlsConnector as RustlsTlsConnector};
use tokio_socks::{tcp::Socks5Stream, IntoTargetAddr, TargetAddr};
use tokio_util::codec::Framed;
use url::Url;
@@ -20,6 +19,7 @@ use crate::{
bytes_codec::BytesCodec,
config::Socks5Server,
tcp::{DynTcpStream, FramedStream},
tls::{get_cached_tls_accept_invalid_cert, get_cached_tls_type, upsert_tls_cache, TlsType},
ResultType,
};
@@ -45,7 +45,6 @@ pub enum ProxyError {
HttpCode200(u16),
#[error("The proxy address resolution failed: {0}")]
AddressResolutionFailed(String),
#[cfg(any(target_os = "windows", target_os = "macos"))]
#[error("The native tls error: {0}")]
NativeTlsError(#[from] tokio_native_tls::native_tls::Error),
}
@@ -226,7 +225,7 @@ impl ProxyScheme {
Ok(scheme)
}
pub async fn socket_addrs(&self) -> Result<SocketAddr, ProxyError> {
info!("Resolving socket address");
log::trace!("Resolving socket address");
match self {
ProxyScheme::Http { host, .. } => self.resolve_host(host, 80).await,
ProxyScheme::Https { host, .. } => self.resolve_host(host, 443).await,
@@ -356,37 +355,50 @@ impl Proxy {
self
}
async fn new_stream(
&self,
local: SocketAddr,
proxy: SocketAddr,
) -> ResultType<tokio::net::TcpStream> {
let stream = super::timeout(
self.ms_timeout,
crate::tcp::new_socket(local, true)?.connect(proxy),
)
.await??;
stream.set_nodelay(true).ok();
Ok(stream)
}
pub async fn connect<'t, T>(
self,
&self,
target: T,
local_addr: Option<SocketAddr>,
) -> ResultType<FramedStream>
where
T: IntoTargetAddr<'t>,
{
info!("Connect to proxy server");
log::trace!("Connect to proxy server");
let proxy = self.proxy_addrs().await?;
let target_addr = target
.into_target_addr()
.map_err(|e| ProxyError::TargetParseError(e.to_string()))?;
let local = if let Some(addr) = local_addr {
addr
} else {
crate::config::Config::get_any_listen_addr(proxy.is_ipv4())
};
let stream = super::timeout(
self.ms_timeout,
crate::tcp::new_socket(local, true)?.connect(proxy),
)
.await??;
stream.set_nodelay(true).ok();
let stream = self.new_stream(local, proxy).await?;
let addr = stream.local_addr()?;
return match self.intercept {
ProxyScheme::Http { .. } => {
info!("Connect to remote http proxy server: {}", proxy);
log::trace!("Connect to remote http proxy server: {}", proxy);
let stream =
super::timeout(self.ms_timeout, self.http_connect(stream, target)).await??;
super::timeout(self.ms_timeout, self.http_connect(stream, &target_addr))
.await??;
Ok(FramedStream(
Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()),
addr,
@@ -395,24 +407,54 @@ impl Proxy {
))
}
ProxyScheme::Https { .. } => {
info!("Connect to remote https proxy server: {}", proxy);
let stream =
super::timeout(self.ms_timeout, self.https_connect(stream, target)).await??;
log::trace!("Connect to remote https proxy server: {}", proxy);
let url = format!("https://{}", self.intercept.get_host_and_port()?);
let tls_type = get_cached_tls_type(&url);
let danger_accept_invalid_cert = get_cached_tls_accept_invalid_cert(&url);
let stream = match tls_type.unwrap_or(TlsType::NativeTls) {
TlsType::NativeTls => {
self.https_connect_nativetls_wrap_danger(
&url,
local,
proxy,
Some(stream),
&target_addr,
tls_type.is_some(),
danger_accept_invalid_cert,
danger_accept_invalid_cert,
)
.await?
}
TlsType::Rustls => {
self.https_connect_rustls_wrap_danger(
&url,
local,
proxy,
&target_addr,
danger_accept_invalid_cert,
)
.await?
}
_ => {
// Unreachable
crate::bail!("Unreachable, TlsType::Plain in HTTPS proxy");
}
};
Ok(FramedStream(
Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()),
Framed::new(stream, BytesCodec::new()),
addr,
None,
0,
))
}
ProxyScheme::Socks5 { .. } => {
info!("Connect to remote socket5 proxy server: {}", proxy);
log::trace!("Connect to remote socket5 proxy server: {}", proxy);
let stream = if let Some(auth) = self.intercept.maybe_auth() {
super::timeout(
self.ms_timeout,
Socks5Stream::connect_with_password_and_socket(
stream,
target,
target_addr,
&auth.user_name,
&auth.password,
),
@@ -421,7 +463,7 @@ impl Proxy {
} else {
super::timeout(
self.ms_timeout,
Socks5Stream::connect_with_socket(stream, target),
Socks5Stream::connect_with_socket(stream, target_addr),
)
.await??
};
@@ -435,32 +477,133 @@ impl Proxy {
};
}
#[cfg(any(target_os = "windows", target_os = "macos"))]
pub async fn https_connect<'a, Input, T>(
self,
#[async_recursion]
async fn https_connect_nativetls_wrap_danger<'a>(
&self,
url: &str,
local: SocketAddr,
proxy: SocketAddr,
stream: Option<tokio::net::TcpStream>,
target_addr: &TargetAddr<'a>,
is_tls_type_cached: bool,
danger_accept_invalid_cert: Option<bool>,
origin_danger_accept_invalid_cert: Option<bool>,
) -> ResultType<DynTcpStream> {
let stream = stream.unwrap_or(self.new_stream(local, proxy).await?);
match super::timeout(
self.ms_timeout,
self.https_connect_nativetls(
stream,
target_addr,
danger_accept_invalid_cert.unwrap_or(false),
),
)
.await?
{
Ok(s) => {
upsert_tls_cache(
&url,
TlsType::NativeTls,
danger_accept_invalid_cert.unwrap_or(false),
);
Ok(DynTcpStream(Box::new(s)))
}
Err(ProxyError::NativeTlsError(e)) => {
let s = if danger_accept_invalid_cert.is_none() {
log::warn!(
"Falling back to native-tls (accept invalid cert) for HTTPS proxy server."
);
self.https_connect_nativetls_wrap_danger(
&url,
local,
proxy,
None,
target_addr,
is_tls_type_cached,
Some(true),
origin_danger_accept_invalid_cert,
)
.await?
} else if !is_tls_type_cached {
log::warn!("Falling back to rustls for HTTPS proxy server.");
self.https_connect_rustls_wrap_danger(
&url,
local,
proxy,
&target_addr,
origin_danger_accept_invalid_cert,
)
.await?
} else {
log::error!(
"Failed to connect to HTTPS proxy server with native-tls: {:?}.",
e
);
bail!(e)
};
Ok(s)
}
Err(e) => {
log::error!("Failed to connect to HTTPS proxy server: {:?}.", e);
bail!(e)
}
}
}
pub async fn https_connect_nativetls<'a, Input>(
&self,
io: Input,
target: T,
target_addr: &TargetAddr<'a>,
danger_accept_invalid_cert: bool,
) -> Result<BufStream<TlsStream<Input>>, ProxyError>
where
Input: AsyncRead + AsyncWrite + Unpin,
T: IntoTargetAddr<'a>,
{
let tls_connector = TlsConnector::from(native_tls::TlsConnector::new()?);
let mut tls_connector_builder = native_tls::TlsConnector::builder();
if danger_accept_invalid_cert {
tls_connector_builder.danger_accept_invalid_certs(true);
}
let tls_connector = TlsConnector::from(tls_connector_builder.build()?);
let stream = tls_connector
.connect(&self.intercept.get_domain()?, io)
.await?;
self.http_connect(stream, target).await
self.http_connect(stream, target_addr).await
}
#[cfg(not(any(target_os = "windows", target_os = "macos")))]
pub async fn https_connect<'a, Input, T>(
self,
async fn https_connect_rustls_wrap_danger<'a>(
&self,
url: &str,
local: SocketAddr,
proxy: SocketAddr,
target_addr: &TargetAddr<'a>,
danger_accept_invalid_cert: Option<bool>,
) -> ResultType<DynTcpStream> {
let stream = self.new_stream(local, proxy).await?;
let s = super::timeout(
self.ms_timeout,
self.https_connect_rustls(
stream,
&target_addr,
danger_accept_invalid_cert.unwrap_or(false),
),
)
.await??;
upsert_tls_cache(
url,
TlsType::Rustls,
danger_accept_invalid_cert.unwrap_or(false),
);
Ok(DynTcpStream(Box::new(s)))
}
pub async fn https_connect_rustls<'a, Input>(
&self,
io: Input,
target: T,
) -> Result<BufStream<TlsStream<Input>>, ProxyError>
target_addr: &TargetAddr<'a>,
danger_accept_invalid_cert: bool,
) -> Result<BufStream<RustlsTlsStream<Input>>, ProxyError>
where
Input: AsyncRead + AsyncWrite + Unpin,
T: IntoTargetAddr<'a>,
{
use std::convert::TryFrom;
@@ -468,24 +611,23 @@ impl Proxy {
let domain = rustls_pki_types::ServerName::try_from(url_domain.as_str())
.map_err(|e| ProxyError::AddressResolutionFailed(e.to_string()))?
.to_owned();
let client_config = crate::verifier::client_config()
let client_config = crate::verifier::client_config(danger_accept_invalid_cert)
.map_err(|e| ProxyError::IoError(std::io::Error::other(e)))?;
let tls_connector = TlsConnector::from(std::sync::Arc::new(client_config));
let tls_connector = RustlsTlsConnector::from(std::sync::Arc::new(client_config));
let stream = tls_connector.connect(domain, io).await?;
self.http_connect(stream, target).await
self.http_connect(stream, target_addr).await
}
pub async fn http_connect<'a, Input, T>(
self,
pub async fn http_connect<'a, Input>(
&self,
io: Input,
target: T,
target_addr: &TargetAddr<'a>,
) -> Result<BufStream<Input>, ProxyError>
where
Input: AsyncRead + AsyncWrite + Unpin,
T: IntoTargetAddr<'a>,
{
let mut stream = BufStream::new(io);
let (domain, port) = get_domain_and_port(target)?;
let (domain, port) = get_domain_and_port(target_addr)?;
let request = self.make_request(&domain, port);
stream.write_all(request.as_bytes()).await?;
@@ -510,13 +652,10 @@ impl Proxy {
}
}
fn get_domain_and_port<'a, T: IntoTargetAddr<'a>>(target: T) -> Result<(String, u16), ProxyError> {
let target_addr = target
.into_target_addr()
.map_err(|e| ProxyError::TargetParseError(e.to_string()))?;
fn get_domain_and_port<'a>(target_addr: &TargetAddr<'a>) -> Result<(String, u16), ProxyError> {
match target_addr {
tokio_socks::TargetAddr::Ip(addr) => Ok((addr.ip().to_string(), addr.port())),
tokio_socks::TargetAddr::Domain(name, port) => Ok((name.to_string(), port)),
tokio_socks::TargetAddr::Domain(name, port) => Ok((name.to_string(), *port)),
}
}

121
src/tls.rs Normal file
View File

@@ -0,0 +1,121 @@
use std::{collections::HashMap, sync::RwLock};
use crate::config::allow_insecure_tls_fallback;
#[derive(Debug, Clone, Copy)]
pub enum TlsType {
Plain,
NativeTls,
Rustls,
}
lazy_static::lazy_static! {
static ref URL_TLS_TYPE: RwLock<HashMap<String, TlsType>> = RwLock::new(HashMap::new());
static ref URL_TLS_DANGER_ACCEPT_INVALID_CERTS: RwLock<HashMap<String, bool>> = RwLock::new(HashMap::new());
}
#[inline]
pub fn is_plain(url: &str) -> bool {
url.starts_with("ws://") || url.starts_with("http://")
}
// Extract domain from URL.
// e.g., "https://example.com/path" -> "example.com"
// "https://example.com:8080/path" -> "example.com:8080"
// See the tests for more examples.
#[inline]
fn get_domain_and_port_from_url(url: &str) -> &str {
// Remove scheme (e.g., http://, https://, ws://, wss://)
let scheme_end = url.find("://").map(|pos| pos + 3).unwrap_or(0);
let url2 = &url[scheme_end..];
// If userinfo is present, domain is after last '@'
let after_at = match url2.rfind('@') {
Some(pos) => &url2[pos + 1..],
None => url2,
};
// Find the end of domain (before '/' or '?')
let domain_end = after_at.find(&['/', '?'][..]).unwrap_or(after_at.len());
&after_at[..domain_end]
}
#[inline]
pub fn upsert_tls_cache(url: &str, tls_type: TlsType, danger_accept_invalid_cert: bool) {
if is_plain(url) {
return;
}
let domain_port = get_domain_and_port_from_url(url);
// Use curly braces to ensure the lock is released immediately.
{
URL_TLS_TYPE
.write()
.unwrap()
.insert(domain_port.to_string(), tls_type);
}
{
URL_TLS_DANGER_ACCEPT_INVALID_CERTS
.write()
.unwrap()
.insert(domain_port.to_string(), danger_accept_invalid_cert);
}
}
#[inline]
pub fn reset_tls_cache() {
// Use curly braces to ensure the lock is released immediately.
{
URL_TLS_TYPE.write().unwrap().clear();
}
{
URL_TLS_DANGER_ACCEPT_INVALID_CERTS.write().unwrap().clear();
}
}
#[inline]
pub fn get_cached_tls_type(url: &str) -> Option<TlsType> {
if is_plain(url) {
return Some(TlsType::Plain);
}
let domain_port = get_domain_and_port_from_url(url);
URL_TLS_TYPE.read().unwrap().get(domain_port).cloned()
}
#[inline]
pub fn get_cached_tls_accept_invalid_cert(url: &str) -> Option<bool> {
if !allow_insecure_tls_fallback() {
return Some(false);
}
if is_plain(url) {
return Some(false);
}
let domain_port = get_domain_and_port_from_url(url);
URL_TLS_DANGER_ACCEPT_INVALID_CERTS
.read()
.unwrap()
.get(domain_port)
.cloned()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_get_domain_and_port_from_url() {
for (url, expected_domain_port) in vec![
("http://example.com", "example.com"),
("https://example.com", "example.com"),
("ws://example.com/path", "example.com"),
("wss://example.com:8080/path", "example.com:8080"),
("https://user:pass@example.com", "example.com"),
("https://example.com?query=param", "example.com"),
("https://example.com:8443?query=param", "example.com:8443"),
("ftp://example.com/resource", "example.com"), // ftp scheme
("example.com/path", "example.com"), // no scheme
("example.com:8080/path", "example.com:8080"),
] {
let domain_port = get_domain_and_port_from_url(url);
assert_eq!(domain_port, expected_domain_port);
}
}
}

View File

@@ -1,14 +1,65 @@
use crate::ResultType;
#[cfg(any(target_os = "android", target_os = "ios"))]
use rustls_pki_types::{ServerName, UnixTime};
use std::sync::Arc;
use tokio_rustls::rustls::{self, client::WebPkiServerVerifier, ClientConfig};
#[cfg(any(target_os = "android", target_os = "ios"))]
use tokio_rustls::rustls::{
client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
DigitallySignedStruct, Error as TLSError, SignatureScheme,
};
// https://github.com/seanmonstar/reqwest/blob/fd61bc93e6f936454ce0b978c6f282f06eee9287/src/tls.rs#L608
#[derive(Debug)]
pub(crate) struct NoVerifier;
impl ServerCertVerifier for NoVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls_pki_types::CertificateDer,
_intermediates: &[rustls_pki_types::CertificateDer],
_server_name: &ServerName,
_ocsp_response: &[u8],
_now: UnixTime,
) -> Result<ServerCertVerified, TLSError> {
Ok(ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, TLSError> {
Ok(HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, TLSError> {
Ok(HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
vec![
SignatureScheme::RSA_PKCS1_SHA1,
SignatureScheme::ECDSA_SHA1_Legacy,
SignatureScheme::RSA_PKCS1_SHA256,
SignatureScheme::ECDSA_NISTP256_SHA256,
SignatureScheme::RSA_PKCS1_SHA384,
SignatureScheme::ECDSA_NISTP384_SHA384,
SignatureScheme::RSA_PKCS1_SHA512,
SignatureScheme::ECDSA_NISTP521_SHA512,
SignatureScheme::RSA_PSS_SHA256,
SignatureScheme::RSA_PSS_SHA384,
SignatureScheme::RSA_PSS_SHA512,
SignatureScheme::ED25519,
SignatureScheme::ED448,
]
}
}
/// A certificate verifier that tries a primary verifier first,
/// and falls back to a platform verifier if the primary fails.
#[cfg(any(target_os = "android", target_os = "ios"))]
@@ -149,7 +200,15 @@ fn webpki_server_verifier(
Ok(verifier)
}
pub fn client_config() -> ResultType<ClientConfig> {
pub fn client_config(danger_accept_invalid_cert: bool) -> ResultType<ClientConfig> {
if danger_accept_invalid_cert {
client_config_danger()
} else {
client_config_safe()
}
}
pub fn client_config_safe() -> ResultType<ClientConfig> {
// Use the default builder which uses the default protocol versions and crypto provider.
// The with_protocol_versions API has been removed in rustls master branch:
// https://github.com/rustls/rustls/pull/2599
@@ -188,3 +247,11 @@ pub fn client_config() -> ResultType<ClientConfig> {
Ok(config)
}
}
pub fn client_config_danger() -> ResultType<ClientConfig> {
let config = ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoVerifier))
.with_no_client_auth();
Ok(config)
}

View File

@@ -1,22 +1,29 @@
use crate::{
config::keys::OPTION_RELAY_SERVER,
config::{use_ws, Config, Socks5Server, RELAY_PORT, RENDEZVOUS_PORT},
config::{
keys::OPTION_RELAY_SERVER, use_ws, Config, Socks5Server, RELAY_PORT, RENDEZVOUS_PORT,
},
protobuf::Message,
socket_client::split_host_port,
sodiumoxide::crypto::secretbox::Key,
tcp::Encrypt,
tls::{get_cached_tls_accept_invalid_cert, get_cached_tls_type, upsert_tls_cache, TlsType},
ResultType,
};
use anyhow::bail;
use async_recursion::async_recursion;
use bytes::{Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use std::{
io::{Error, ErrorKind},
net::SocketAddr,
sync::Arc,
time::Duration,
};
use tokio::{net::TcpStream, time::timeout};
use tokio_native_tls::native_tls::TlsConnector;
use tokio_tungstenite::{
connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream,
connect_async_tls_with_config, tungstenite::protocol::Message as WsMessage, Connector,
MaybeTlsStream, WebSocketStream,
};
use tungstenite::client::IntoClientRequest;
use tungstenite::protocol::Role;
@@ -29,29 +36,21 @@ pub struct WsFramedStream {
}
impl WsFramedStream {
pub async fn new<T: AsRef<str>>(
url: T,
_local_addr: Option<SocketAddr>,
_proxy_conf: Option<&Socks5Server>,
ms_timeout: u64,
) -> ResultType<Self> {
let url_str = url.as_ref();
// to-do: websocket proxy.
let request = url_str
.into_client_request()
.map_err(|e| Error::new(ErrorKind::Other, e))?;
let stream;
#[cfg(any(target_os = "android", target_os = "ios"))]
{
let is_wss = url_str.starts_with("wss://");
if is_wss {
use std::sync::Arc;
use tokio_tungstenite::{connect_async_tls_with_config, Connector};
let connector = match crate::verifier::client_config() {
#[inline]
fn get_connector(
tls_type: &TlsType,
danger_accept_invalid_certs: bool,
) -> ResultType<Option<Connector>> {
match tls_type {
TlsType::Plain => Ok(Some(Connector::Plain)),
TlsType::NativeTls => {
let connector = TlsConnector::builder()
.danger_accept_invalid_certs(danger_accept_invalid_certs)
.build()?;
Ok(Some(Connector::NativeTls(connector)))
}
TlsType::Rustls => {
let connector = match crate::verifier::client_config(danger_accept_invalid_certs) {
Ok(client_config) => Some(Connector::Rustls(Arc::new(client_config))),
Err(e) => {
log::warn!(
@@ -61,30 +60,130 @@ impl WsFramedStream {
None
}
};
let (s, _) = timeout(
Duration::from_millis(ms_timeout),
connect_async_tls_with_config(request, None, false, connector),
)
.await??;
stream = s;
} else {
let (s, _) =
timeout(Duration::from_millis(ms_timeout), connect_async(request)).await??;
stream = s;
Ok(connector)
}
}
#[cfg(not(any(target_os = "android", target_os = "ios")))]
{
let (s, _) =
timeout(Duration::from_millis(ms_timeout), connect_async(request)).await??;
stream = s;
}
}
async fn connect(
url: &str,
ms_timeout: u64,
) -> ResultType<WebSocketStream<MaybeTlsStream<TcpStream>>> {
// to-do: websocket proxy.
let tls_type = get_cached_tls_type(url);
let is_tls_type_cached = tls_type.is_some();
let tls_type = tls_type.unwrap_or(TlsType::NativeTls);
let danger_accept_invalid_cert = get_cached_tls_accept_invalid_cert(&url);
Self::try_connect(
url,
ms_timeout,
tls_type,
is_tls_type_cached,
danger_accept_invalid_cert,
danger_accept_invalid_cert,
)
.await
}
#[async_recursion]
async fn try_connect(
url: &str,
ms_timeout: u64,
tls_type: TlsType,
is_tls_type_cached: bool,
danger_accept_invalid_cert: Option<bool>,
original_danger_accept_invalid_certs: Option<bool>,
) -> ResultType<WebSocketStream<MaybeTlsStream<TcpStream>>> {
let ws_config = None;
let disable_nagle = false;
let request = url
.into_client_request()
.map_err(|e| Error::new(ErrorKind::Other, e))?;
let connector =
Self::get_connector(&tls_type, danger_accept_invalid_cert.unwrap_or(false))?;
match timeout(
Duration::from_millis(ms_timeout),
connect_async_tls_with_config(request, ws_config, disable_nagle, connector),
)
.await?
{
Ok((ws_stream, _)) => {
upsert_tls_cache(url, tls_type, danger_accept_invalid_cert.unwrap_or(false));
Ok(ws_stream)
}
Err(e) => match (tls_type, is_tls_type_cached, danger_accept_invalid_cert) {
(TlsType::NativeTls, _, None) => {
log::warn!(
"WebSocket connection with native-tls failed, try accept invalid certs: {}, {:?}",
url,
e
);
Self::try_connect(
url,
ms_timeout,
tls_type,
is_tls_type_cached,
Some(true),
original_danger_accept_invalid_certs,
)
.await
}
(TlsType::NativeTls, false, Some(_)) => {
log::warn!(
"WebSocket connection with native-tls failed, try rustls: {}, {:?}",
url,
e
);
Self::try_connect(
url,
ms_timeout,
TlsType::Rustls,
is_tls_type_cached,
original_danger_accept_invalid_certs,
original_danger_accept_invalid_certs,
)
.await
}
(TlsType::Rustls, _, None) => {
log::warn!(
"WebSocket connection with rustls failed, try accept invalid certs: {}, {:?}",
url,
e
);
Self::try_connect(
url,
ms_timeout,
tls_type,
is_tls_type_cached,
Some(true),
original_danger_accept_invalid_certs,
)
.await
}
_ => {
log::error!(
"WebSocket connection failed with tls_type {:?}: {}, {:?}",
tls_type,
url,
e
);
bail!(e)
}
},
}
}
pub async fn new<T: AsRef<str>>(
url: T,
_local_addr: Option<SocketAddr>,
_proxy_conf: Option<&Socks5Server>,
ms_timeout: u64,
) -> ResultType<Self> {
let stream = Self::connect(url.as_ref(), ms_timeout).await?;
let addr = match stream.get_ref() {
MaybeTlsStream::Plain(tcp) => tcp.peer_addr()?,
#[cfg(any(target_os = "macos", target_os = "windows"))]
MaybeTlsStream::NativeTls(tls) => tls.get_ref().get_ref().get_ref().peer_addr()?,
#[cfg(not(any(target_os = "macos", target_os = "windows")))]
MaybeTlsStream::Rustls(tls) => tls.get_ref().0.peer_addr()?,
_ => return Err(Error::new(ErrorKind::Other, "Unsupported stream type").into()),
};