refact: tls native-tls fallback rustls-tls

Signed-off-by: fufesou <linlong1266@gmail.com>
This commit is contained in:
fufesou
2025-10-30 15:34:43 +08:00
parent d6dd7ae052
commit 54c4d869ed
7 changed files with 545 additions and 116 deletions

View File

@@ -3,16 +3,15 @@ use std::{
net::{SocketAddr, ToSocketAddrs},
};
use anyhow::bail;
use async_recursion::async_recursion;
use base64::{engine::general_purpose, Engine};
use httparse::{Error as HttpParseError, Response, EMPTY_HEADER};
use log::info;
use thiserror::Error as ThisError;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufStream};
#[cfg(any(target_os = "windows", target_os = "macos"))]
use tokio_native_tls::{native_tls, TlsConnector, TlsStream};
#[cfg(not(any(target_os = "windows", target_os = "macos")))]
use tokio_rustls::{client::TlsStream, TlsConnector};
use tokio_socks::{tcp::Socks5Stream, IntoTargetAddr};
use tokio_rustls::{client::TlsStream as RustlsTlsStream, TlsConnector as RustlsTlsConnector};
use tokio_socks::{tcp::Socks5Stream, IntoTargetAddr, TargetAddr};
use tokio_util::codec::Framed;
use url::Url;
@@ -20,6 +19,7 @@ use crate::{
bytes_codec::BytesCodec,
config::Socks5Server,
tcp::{DynTcpStream, FramedStream},
tls::{get_cached_tls_accept_invalid_cert, get_cached_tls_type, upsert_tls_cache, TlsType},
ResultType,
};
@@ -45,7 +45,6 @@ pub enum ProxyError {
HttpCode200(u16),
#[error("The proxy address resolution failed: {0}")]
AddressResolutionFailed(String),
#[cfg(any(target_os = "windows", target_os = "macos"))]
#[error("The native tls error: {0}")]
NativeTlsError(#[from] tokio_native_tls::native_tls::Error),
}
@@ -226,7 +225,7 @@ impl ProxyScheme {
Ok(scheme)
}
pub async fn socket_addrs(&self) -> Result<SocketAddr, ProxyError> {
info!("Resolving socket address");
log::trace!("Resolving socket address");
match self {
ProxyScheme::Http { host, .. } => self.resolve_host(host, 80).await,
ProxyScheme::Https { host, .. } => self.resolve_host(host, 443).await,
@@ -356,37 +355,50 @@ impl Proxy {
self
}
async fn new_stream(
&self,
local: SocketAddr,
proxy: SocketAddr,
) -> ResultType<tokio::net::TcpStream> {
let stream = super::timeout(
self.ms_timeout,
crate::tcp::new_socket(local, true)?.connect(proxy),
)
.await??;
stream.set_nodelay(true).ok();
Ok(stream)
}
pub async fn connect<'t, T>(
self,
&self,
target: T,
local_addr: Option<SocketAddr>,
) -> ResultType<FramedStream>
where
T: IntoTargetAddr<'t>,
{
info!("Connect to proxy server");
log::trace!("Connect to proxy server");
let proxy = self.proxy_addrs().await?;
let target_addr = target
.into_target_addr()
.map_err(|e| ProxyError::TargetParseError(e.to_string()))?;
let local = if let Some(addr) = local_addr {
addr
} else {
crate::config::Config::get_any_listen_addr(proxy.is_ipv4())
};
let stream = super::timeout(
self.ms_timeout,
crate::tcp::new_socket(local, true)?.connect(proxy),
)
.await??;
stream.set_nodelay(true).ok();
let stream = self.new_stream(local, proxy).await?;
let addr = stream.local_addr()?;
return match self.intercept {
ProxyScheme::Http { .. } => {
info!("Connect to remote http proxy server: {}", proxy);
log::trace!("Connect to remote http proxy server: {}", proxy);
let stream =
super::timeout(self.ms_timeout, self.http_connect(stream, target)).await??;
super::timeout(self.ms_timeout, self.http_connect(stream, &target_addr))
.await??;
Ok(FramedStream(
Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()),
addr,
@@ -395,24 +407,54 @@ impl Proxy {
))
}
ProxyScheme::Https { .. } => {
info!("Connect to remote https proxy server: {}", proxy);
let stream =
super::timeout(self.ms_timeout, self.https_connect(stream, target)).await??;
log::trace!("Connect to remote https proxy server: {}", proxy);
let url = format!("https://{}", self.intercept.get_host_and_port()?);
let tls_type = get_cached_tls_type(&url);
let danger_accept_invalid_cert = get_cached_tls_accept_invalid_cert(&url);
let stream = match tls_type.unwrap_or(TlsType::NativeTls) {
TlsType::NativeTls => {
self.https_connect_nativetls_wrap_danger(
&url,
local,
proxy,
Some(stream),
&target_addr,
tls_type.is_some(),
danger_accept_invalid_cert,
danger_accept_invalid_cert,
)
.await?
}
TlsType::Rustls => {
self.https_connect_rustls_wrap_danger(
&url,
local,
proxy,
&target_addr,
danger_accept_invalid_cert,
)
.await?
}
_ => {
// Unreachable
crate::bail!("Unreachable, TlsType::Plain in HTTPS proxy");
}
};
Ok(FramedStream(
Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()),
Framed::new(stream, BytesCodec::new()),
addr,
None,
0,
))
}
ProxyScheme::Socks5 { .. } => {
info!("Connect to remote socket5 proxy server: {}", proxy);
log::trace!("Connect to remote socket5 proxy server: {}", proxy);
let stream = if let Some(auth) = self.intercept.maybe_auth() {
super::timeout(
self.ms_timeout,
Socks5Stream::connect_with_password_and_socket(
stream,
target,
target_addr,
&auth.user_name,
&auth.password,
),
@@ -421,7 +463,7 @@ impl Proxy {
} else {
super::timeout(
self.ms_timeout,
Socks5Stream::connect_with_socket(stream, target),
Socks5Stream::connect_with_socket(stream, target_addr),
)
.await??
};
@@ -435,32 +477,133 @@ impl Proxy {
};
}
#[cfg(any(target_os = "windows", target_os = "macos"))]
pub async fn https_connect<'a, Input, T>(
self,
#[async_recursion]
async fn https_connect_nativetls_wrap_danger<'a>(
&self,
url: &str,
local: SocketAddr,
proxy: SocketAddr,
stream: Option<tokio::net::TcpStream>,
target_addr: &TargetAddr<'a>,
is_tls_type_cached: bool,
danger_accept_invalid_cert: Option<bool>,
origin_danger_accept_invalid_cert: Option<bool>,
) -> ResultType<DynTcpStream> {
let stream = stream.unwrap_or(self.new_stream(local, proxy).await?);
match super::timeout(
self.ms_timeout,
self.https_connect_nativetls(
stream,
target_addr,
danger_accept_invalid_cert.unwrap_or(false),
),
)
.await?
{
Ok(s) => {
upsert_tls_cache(
&url,
TlsType::NativeTls,
danger_accept_invalid_cert.unwrap_or(false),
);
Ok(DynTcpStream(Box::new(s)))
}
Err(ProxyError::NativeTlsError(e)) => {
let s = if danger_accept_invalid_cert.is_none() {
log::warn!(
"Falling back to native-tls (accept invalid cert) for HTTPS proxy server."
);
self.https_connect_nativetls_wrap_danger(
&url,
local,
proxy,
None,
target_addr,
is_tls_type_cached,
Some(true),
origin_danger_accept_invalid_cert,
)
.await?
} else if !is_tls_type_cached {
log::warn!("Falling back to rustls for HTTPS proxy server.");
self.https_connect_rustls_wrap_danger(
&url,
local,
proxy,
&target_addr,
origin_danger_accept_invalid_cert,
)
.await?
} else {
log::error!(
"Failed to connect to HTTPS proxy server with native-tls: {:?}.",
e
);
bail!(e)
};
Ok(s)
}
Err(e) => {
log::error!("Failed to connect to HTTPS proxy server: {:?}.", e);
bail!(e)
}
}
}
pub async fn https_connect_nativetls<'a, Input>(
&self,
io: Input,
target: T,
target_addr: &TargetAddr<'a>,
danger_accept_invalid_cert: bool,
) -> Result<BufStream<TlsStream<Input>>, ProxyError>
where
Input: AsyncRead + AsyncWrite + Unpin,
T: IntoTargetAddr<'a>,
{
let tls_connector = TlsConnector::from(native_tls::TlsConnector::new()?);
let mut tls_connector_builder = native_tls::TlsConnector::builder();
if danger_accept_invalid_cert {
tls_connector_builder.danger_accept_invalid_certs(true);
}
let tls_connector = TlsConnector::from(tls_connector_builder.build()?);
let stream = tls_connector
.connect(&self.intercept.get_domain()?, io)
.await?;
self.http_connect(stream, target).await
self.http_connect(stream, target_addr).await
}
#[cfg(not(any(target_os = "windows", target_os = "macos")))]
pub async fn https_connect<'a, Input, T>(
self,
async fn https_connect_rustls_wrap_danger<'a>(
&self,
url: &str,
local: SocketAddr,
proxy: SocketAddr,
target_addr: &TargetAddr<'a>,
danger_accept_invalid_cert: Option<bool>,
) -> ResultType<DynTcpStream> {
let stream = self.new_stream(local, proxy).await?;
let s = super::timeout(
self.ms_timeout,
self.https_connect_rustls(
stream,
&target_addr,
danger_accept_invalid_cert.unwrap_or(false),
),
)
.await??;
upsert_tls_cache(
url,
TlsType::Rustls,
danger_accept_invalid_cert.unwrap_or(false),
);
Ok(DynTcpStream(Box::new(s)))
}
pub async fn https_connect_rustls<'a, Input>(
&self,
io: Input,
target: T,
) -> Result<BufStream<TlsStream<Input>>, ProxyError>
target_addr: &TargetAddr<'a>,
danger_accept_invalid_cert: bool,
) -> Result<BufStream<RustlsTlsStream<Input>>, ProxyError>
where
Input: AsyncRead + AsyncWrite + Unpin,
T: IntoTargetAddr<'a>,
{
use std::convert::TryFrom;
@@ -468,24 +611,23 @@ impl Proxy {
let domain = rustls_pki_types::ServerName::try_from(url_domain.as_str())
.map_err(|e| ProxyError::AddressResolutionFailed(e.to_string()))?
.to_owned();
let client_config = crate::verifier::client_config()
let client_config = crate::verifier::client_config(danger_accept_invalid_cert)
.map_err(|e| ProxyError::IoError(std::io::Error::other(e)))?;
let tls_connector = TlsConnector::from(std::sync::Arc::new(client_config));
let tls_connector = RustlsTlsConnector::from(std::sync::Arc::new(client_config));
let stream = tls_connector.connect(domain, io).await?;
self.http_connect(stream, target).await
self.http_connect(stream, target_addr).await
}
pub async fn http_connect<'a, Input, T>(
self,
pub async fn http_connect<'a, Input>(
&self,
io: Input,
target: T,
target_addr: &TargetAddr<'a>,
) -> Result<BufStream<Input>, ProxyError>
where
Input: AsyncRead + AsyncWrite + Unpin,
T: IntoTargetAddr<'a>,
{
let mut stream = BufStream::new(io);
let (domain, port) = get_domain_and_port(target)?;
let (domain, port) = get_domain_and_port(target_addr)?;
let request = self.make_request(&domain, port);
stream.write_all(request.as_bytes()).await?;
@@ -510,13 +652,10 @@ impl Proxy {
}
}
fn get_domain_and_port<'a, T: IntoTargetAddr<'a>>(target: T) -> Result<(String, u16), ProxyError> {
let target_addr = target
.into_target_addr()
.map_err(|e| ProxyError::TargetParseError(e.to_string()))?;
fn get_domain_and_port<'a>(target_addr: &TargetAddr<'a>) -> Result<(String, u16), ProxyError> {
match target_addr {
tokio_socks::TargetAddr::Ip(addr) => Ok((addr.ip().to_string(), addr.port())),
tokio_socks::TargetAddr::Domain(name, port) => Ok((name.to_string(), port)),
tokio_socks::TargetAddr::Domain(name, port) => Ok((name.to_string(), *port)),
}
}