mirror of
https://github.com/rustdesk/hbb_common.git
synced 2026-02-16 02:20:43 +00:00
122 lines
3.6 KiB
Rust
122 lines
3.6 KiB
Rust
use std::{collections::HashMap, sync::RwLock};
|
|
|
|
use crate::config::allow_insecure_tls_fallback;
|
|
|
|
/// TLS flavour recorded for a remote endpoint in the per-endpoint TLS cache.
///
/// `PartialEq`/`Eq` are derived so values (and `Option<TlsType>` results) can be
/// compared directly with `==` / `assert_eq!` instead of requiring `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TlsType {
    /// No TLS; this is what `get_cached_tls_type` reports for `ws://` and `http://` URLs.
    Plain,
    /// TLS via the native-tls backend (presumably; inferred from the variant name).
    NativeTls,
    /// TLS via the rustls backend (presumably; inferred from the variant name).
    Rustls,
}
|
|
|
|
lazy_static::lazy_static! {
    /// TLS flavour cached per endpoint; keys are the `host[:port]` strings produced by
    /// `get_domain_and_port_from_url`.
    static ref URL_TLS_TYPE: RwLock<HashMap<String, TlsType>> = RwLock::default();
    /// Cached "accept invalid certificates" flag per endpoint, keyed the same way as
    /// `URL_TLS_TYPE`.
    static ref URL_TLS_DANGER_ACCEPT_INVALID_CERTS: RwLock<HashMap<String, bool>> =
        RwLock::default();
}
|
|
|
|
/// Reports whether `url` starts with an unencrypted scheme (`ws://` or `http://`).
///
/// The check is a case-sensitive prefix match; any other input, including a URL
/// without a scheme, yields `false`.
#[inline]
pub fn is_plain(url: &str) -> bool {
    const PLAIN_PREFIXES: [&str; 2] = ["ws://", "http://"];
    PLAIN_PREFIXES.iter().any(|&prefix| url.starts_with(prefix))
}
|
|
|
|
/// Extracts the `host[:port]` part of a URL (the authority without userinfo).
///
/// Examples:
/// - `"https://example.com/path"` -> `"example.com"`
/// - `"https://example.com:8080/path"` -> `"example.com:8080"`
/// - `"https://user:pass@example.com"` -> `"example.com"`
/// - `"https://example.com/a@b"` -> `"example.com"`
/// - `"example.com/path"` (no scheme) -> `"example.com"`
///
/// The returned slice borrows from `url`; nothing is allocated.
#[inline]
fn get_domain_and_port_from_url(url: &str) -> &str {
    // Characters that terminate the authority component.
    const AUTHORITY_END: &[char] = &['/', '?', '#'];

    // Strip the scheme (e.g. http://, https://, ws://, wss://). A "://" that only
    // shows up after a path/query/fragment delimiter belongs to that component
    // (e.g. "example.com/r?to=https://x") and is not a scheme separator.
    let rest = match url.find("://") {
        Some(pos) if !url[..pos].contains(AUTHORITY_END) => &url[pos + 3..],
        _ => url,
    };

    // Cut the authority off first, so that an '@' appearing later in the path,
    // query or fragment is never mistaken for the userinfo delimiter.
    let authority_end = rest.find(AUTHORITY_END).unwrap_or(rest.len());
    let authority = &rest[..authority_end];

    // If userinfo is present, the host starts after the last '@' of the authority.
    authority
        .rfind('@')
        .map_or(authority, |pos| &authority[pos + 1..])
}
|
|
|
|
#[inline]
|
|
pub fn upsert_tls_cache(url: &str, tls_type: TlsType, danger_accept_invalid_cert: bool) {
|
|
if is_plain(url) {
|
|
return;
|
|
}
|
|
|
|
let domain_port = get_domain_and_port_from_url(url);
|
|
// Use curly braces to ensure the lock is released immediately.
|
|
{
|
|
URL_TLS_TYPE
|
|
.write()
|
|
.unwrap()
|
|
.insert(domain_port.to_string(), tls_type);
|
|
}
|
|
{
|
|
URL_TLS_DANGER_ACCEPT_INVALID_CERTS
|
|
.write()
|
|
.unwrap()
|
|
.insert(domain_port.to_string(), danger_accept_invalid_cert);
|
|
}
|
|
}
|
|
|
|
#[inline]
|
|
pub fn reset_tls_cache() {
|
|
// Use curly braces to ensure the lock is released immediately.
|
|
{
|
|
URL_TLS_TYPE.write().unwrap().clear();
|
|
}
|
|
{
|
|
URL_TLS_DANGER_ACCEPT_INVALID_CERTS.write().unwrap().clear();
|
|
}
|
|
}
|
|
|
|
#[inline]
|
|
pub fn get_cached_tls_type(url: &str) -> Option<TlsType> {
|
|
if is_plain(url) {
|
|
return Some(TlsType::Plain);
|
|
}
|
|
let domain_port = get_domain_and_port_from_url(url);
|
|
URL_TLS_TYPE.read().unwrap().get(domain_port).cloned()
|
|
}
|
|
|
|
#[inline]
|
|
pub fn get_cached_tls_accept_invalid_cert(url: &str) -> Option<bool> {
|
|
if !allow_insecure_tls_fallback() {
|
|
return Some(false);
|
|
}
|
|
|
|
if is_plain(url) {
|
|
return Some(false);
|
|
}
|
|
let domain_port = get_domain_and_port_from_url(url);
|
|
URL_TLS_DANGER_ACCEPT_INVALID_CERTS
|
|
.read()
|
|
.unwrap()
|
|
.get(domain_port)
|
|
.cloned()
|
|
}
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `host[:port]` extraction for URLs with and without scheme, port,
    /// userinfo, path and query.
    #[test]
    fn test_get_domain_and_port_from_url() {
        // (input URL, expected `host[:port]`)
        let cases = [
            ("http://example.com", "example.com"),
            ("https://example.com", "example.com"),
            ("ws://example.com/path", "example.com"),
            ("wss://example.com:8080/path", "example.com:8080"),
            ("https://user:pass@example.com", "example.com"),
            ("https://example.com?query=param", "example.com"),
            ("https://example.com:8443?query=param", "example.com:8443"),
            ("ftp://example.com/resource", "example.com"), // ftp scheme
            ("example.com/path", "example.com"),           // no scheme
            ("example.com:8080/path", "example.com:8080"),
        ];

        for &(url, expected_domain_port) in &cases {
            assert_eq!(get_domain_and_port_from_url(url), expected_domain_port);
        }
    }
}
|