Fix #324 to remove unsafe

This commit is contained in:
rustdesk 2023-12-01 11:32:07 +08:00
parent 5133af1863
commit 1142cf105b
2 changed files with 75 additions and 85 deletions

View File

@ -25,6 +25,7 @@ use std::{
io::prelude::*, io::prelude::*,
io::Error, io::Error,
net::SocketAddr, net::SocketAddr,
sync::atomic::{AtomicUsize, Ordering},
}; };
type Usage = (usize, usize, usize, usize); type Usage = (usize, usize, usize, usize);
@ -36,11 +37,11 @@ lazy_static::lazy_static! {
static ref BLOCKLIST: RwLock<HashSet<String>> = Default::default(); static ref BLOCKLIST: RwLock<HashSet<String>> = Default::default();
} }
static mut DOWNGRADE_THRESHOLD: f64 = 0.66; static DOWNGRADE_THRESHOLD_100: AtomicUsize = AtomicUsize::new(66); // 0.66
static mut DOWNGRADE_START_CHECK: usize = 1_800_000; // in ms static DOWNGRADE_START_CHECK: AtomicUsize = AtomicUsize::new(1_800_000); // in ms
static mut LIMIT_SPEED: usize = 4 * 1024 * 1024; // in bit/s static LIMIT_SPEED: AtomicUsize = AtomicUsize::new(4 * 1024 * 1024); // in bit/s
static mut TOTAL_BANDWIDTH: usize = 1024 * 1024 * 1024; // in bit/s static TOTAL_BANDWIDTH: AtomicUsize = AtomicUsize::new(1024 * 1024 * 1024); // in bit/s
static mut SINGLE_BANDWIDTH: usize = 16 * 1024 * 1024; // in bit/s static SINGLE_BANDWIDTH: AtomicUsize = AtomicUsize::new(16 * 1024 * 1024); // in bit/s
const BLACKLIST_FILE: &str = "blacklist.txt"; const BLACKLIST_FILE: &str = "blacklist.txt";
const BLOCKLIST_FILE: &str = "blocklist.txt"; const BLOCKLIST_FILE: &str = "blocklist.txt";
@ -99,57 +100,53 @@ fn check_params() {
.map(|x| x.parse::<f64>().unwrap_or(0.)) .map(|x| x.parse::<f64>().unwrap_or(0.))
.unwrap_or(0.); .unwrap_or(0.);
if tmp > 0. { if tmp > 0. {
unsafe { DOWNGRADE_THRESHOLD_100.store((tmp * 100.) as _, Ordering::SeqCst);
DOWNGRADE_THRESHOLD = tmp;
}
} }
unsafe { log::info!("DOWNGRADE_THRESHOLD: {}", DOWNGRADE_THRESHOLD) }; log::info!(
"DOWNGRADE_THRESHOLD: {}",
DOWNGRADE_THRESHOLD_100.load(Ordering::SeqCst) as f64 / 100.
);
let tmp = std::env::var("DOWNGRADE_START_CHECK") let tmp = std::env::var("DOWNGRADE_START_CHECK")
.map(|x| x.parse::<usize>().unwrap_or(0)) .map(|x| x.parse::<usize>().unwrap_or(0))
.unwrap_or(0); .unwrap_or(0);
if tmp > 0 { if tmp > 0 {
unsafe { DOWNGRADE_START_CHECK.store(tmp * 1000, Ordering::SeqCst);
DOWNGRADE_START_CHECK = tmp * 1000;
}
} }
unsafe { log::info!("DOWNGRADE_START_CHECK: {}s", DOWNGRADE_START_CHECK / 1000) }; log::info!(
"DOWNGRADE_START_CHECK: {}s",
DOWNGRADE_START_CHECK.load(Ordering::SeqCst) / 1000
);
let tmp = std::env::var("LIMIT_SPEED") let tmp = std::env::var("LIMIT_SPEED")
.map(|x| x.parse::<f64>().unwrap_or(0.)) .map(|x| x.parse::<f64>().unwrap_or(0.))
.unwrap_or(0.); .unwrap_or(0.);
if tmp > 0. { if tmp > 0. {
unsafe { LIMIT_SPEED.store((tmp * 1024. * 1024.) as usize, Ordering::SeqCst);
LIMIT_SPEED = (tmp * 1024. * 1024.) as usize;
}
} }
unsafe { log::info!("LIMIT_SPEED: {}Mb/s", LIMIT_SPEED as f64 / 1024. / 1024.) }; log::info!(
"LIMIT_SPEED: {}Mb/s",
LIMIT_SPEED.load(Ordering::SeqCst) as f64 / 1024. / 1024.
);
let tmp = std::env::var("TOTAL_BANDWIDTH") let tmp = std::env::var("TOTAL_BANDWIDTH")
.map(|x| x.parse::<f64>().unwrap_or(0.)) .map(|x| x.parse::<f64>().unwrap_or(0.))
.unwrap_or(0.); .unwrap_or(0.);
if tmp > 0. { if tmp > 0. {
unsafe { TOTAL_BANDWIDTH.store((tmp * 1024. * 1024.) as usize, Ordering::SeqCst);
TOTAL_BANDWIDTH = (tmp * 1024. * 1024.) as usize;
}
} }
unsafe {
log::info!( log::info!(
"TOTAL_BANDWIDTH: {}Mb/s", "TOTAL_BANDWIDTH: {}Mb/s",
TOTAL_BANDWIDTH as f64 / 1024. / 1024. TOTAL_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024.
) );
};
let tmp = std::env::var("SINGLE_BANDWIDTH") let tmp = std::env::var("SINGLE_BANDWIDTH")
.map(|x| x.parse::<f64>().unwrap_or(0.)) .map(|x| x.parse::<f64>().unwrap_or(0.))
.unwrap_or(0.); .unwrap_or(0.);
if tmp > 0. { if tmp > 0. {
unsafe { SINGLE_BANDWIDTH.store((tmp * 1024. * 1024.) as usize, Ordering::SeqCst);
SINGLE_BANDWIDTH = (tmp * 1024. * 1024.) as usize;
}
} }
unsafe { log::info!(
log::info!( "SINGLE_BANDWIDTH: {}Mb/s",
"SINGLE_BANDWIDTH: {}Mb/s", SINGLE_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024.
SINGLE_BANDWIDTH as f64 / 1024. / 1024. )
)
};
} }
async fn check_cmd(cmd: &str, limiter: Limiter) -> String { async fn check_cmd(cmd: &str, limiter: Limiter) -> String {
@ -233,76 +230,68 @@ async fn check_cmd(cmd: &str, limiter: Limiter) -> String {
if let Some(v) = fds.next() { if let Some(v) = fds.next() {
if let Ok(v) = v.parse::<f64>() { if let Ok(v) = v.parse::<f64>() {
if v > 0. { if v > 0. {
unsafe { DOWNGRADE_THRESHOLD_100.store((v * 100.) as _, Ordering::SeqCst);
DOWNGRADE_THRESHOLD = v;
}
} }
} }
} else { } else {
unsafe { res = format!(
res = format!("{DOWNGRADE_THRESHOLD}\n"); "{}\n",
} DOWNGRADE_THRESHOLD_100.load(Ordering::SeqCst) as f64 / 100.
);
} }
} }
Some("downgrade-start-check" | "t") => { Some("downgrade-start-check" | "t") => {
if let Some(v) = fds.next() { if let Some(v) = fds.next() {
if let Ok(v) = v.parse::<usize>() { if let Ok(v) = v.parse::<usize>() {
if v > 0 { if v > 0 {
unsafe { DOWNGRADE_START_CHECK.store(v * 1000, Ordering::SeqCst);
DOWNGRADE_START_CHECK = v * 1000;
}
} }
} }
} else { } else {
unsafe { res = format!("{}s\n", DOWNGRADE_START_CHECK.load(Ordering::SeqCst) / 1000);
res = format!("{}s\n", DOWNGRADE_START_CHECK / 1000);
}
} }
} }
Some("limit-speed" | "ls") => { Some("limit-speed" | "ls") => {
if let Some(v) = fds.next() { if let Some(v) = fds.next() {
if let Ok(v) = v.parse::<f64>() { if let Ok(v) = v.parse::<f64>() {
if v > 0. { if v > 0. {
unsafe { LIMIT_SPEED.store((v * 1024. * 1024.) as _, Ordering::SeqCst);
LIMIT_SPEED = (v * 1024. * 1024.) as _;
}
} }
} }
} else { } else {
unsafe { res = format!(
res = format!("{}Mb/s\n", LIMIT_SPEED as f64 / 1024. / 1024.); "{}Mb/s\n",
} LIMIT_SPEED.load(Ordering::SeqCst) as f64 / 1024. / 1024.
);
} }
} }
Some("total-bandwidth" | "tb") => { Some("total-bandwidth" | "tb") => {
if let Some(v) = fds.next() { if let Some(v) = fds.next() {
if let Ok(v) = v.parse::<f64>() { if let Ok(v) = v.parse::<f64>() {
if v > 0. { if v > 0. {
unsafe { TOTAL_BANDWIDTH.store((v * 1024. * 1024.) as _, Ordering::SeqCst);
TOTAL_BANDWIDTH = (v * 1024. * 1024.) as _; limiter.set_speed_limit(TOTAL_BANDWIDTH.load(Ordering::SeqCst) as _);
limiter.set_speed_limit(TOTAL_BANDWIDTH as _);
}
} }
} }
} else { } else {
unsafe { res = format!(
res = format!("{}Mb/s\n", TOTAL_BANDWIDTH as f64 / 1024. / 1024.); "{}Mb/s\n",
} TOTAL_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024.
);
} }
} }
Some("single-bandwidth" | "sb") => { Some("single-bandwidth" | "sb") => {
if let Some(v) = fds.next() { if let Some(v) = fds.next() {
if let Ok(v) = v.parse::<f64>() { if let Ok(v) = v.parse::<f64>() {
if v > 0. { if v > 0. {
unsafe { SINGLE_BANDWIDTH.store((v * 1024. * 1024.) as _, Ordering::SeqCst);
SINGLE_BANDWIDTH = (v * 1024. * 1024.) as _;
}
} }
} }
} else { } else {
unsafe { res = format!(
res = format!("{}Mb/s\n", SINGLE_BANDWIDTH as f64 / 1024. / 1024.); "{}Mb/s\n",
} SINGLE_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024.
);
} }
} }
Some("usage" | "u") => { Some("usage" | "u") => {
@ -336,7 +325,7 @@ async fn check_cmd(cmd: &str, limiter: Limiter) -> String {
async fn io_loop(listener: TcpListener, listener2: TcpListener, key: &str) { async fn io_loop(listener: TcpListener, listener2: TcpListener, key: &str) {
check_params(); check_params();
let limiter = <Limiter>::new(unsafe { TOTAL_BANDWIDTH as _ }); let limiter = <Limiter>::new(TOTAL_BANDWIDTH.load(Ordering::SeqCst) as _);
loop { loop {
tokio::select! { tokio::select! {
res = listener.accept() => { res = listener.accept() => {
@ -475,10 +464,11 @@ async fn relay(
let mut highest_s = 0; let mut highest_s = 0;
let mut downgrade: bool = false; let mut downgrade: bool = false;
let mut blacked: bool = false; let mut blacked: bool = false;
let limiter = <Limiter>::new(unsafe { SINGLE_BANDWIDTH as _ }); let sb = SINGLE_BANDWIDTH.load(Ordering::SeqCst) as f64;
let blacklist_limiter = <Limiter>::new(unsafe { LIMIT_SPEED as _ }); let limiter = <Limiter>::new(sb);
let blacklist_limiter = <Limiter>::new(LIMIT_SPEED.load(Ordering::SeqCst) as _);
let downgrade_threshold = let downgrade_threshold =
(unsafe { SINGLE_BANDWIDTH as f64 * DOWNGRADE_THRESHOLD } / 1000.) as usize; // in bit/ms (sb * DOWNGRADE_THRESHOLD_100.load(Ordering::SeqCst) as f64 / 100. / 1000.) as usize; // in bit/ms
let mut timer = interval(Duration::from_secs(3)); let mut timer = interval(Duration::from_secs(3));
let mut last_recv_time = std::time::Instant::now(); let mut last_recv_time = std::time::Instant::now();
loop { loop {
@ -546,7 +536,7 @@ async fn relay(
(elapsed as _, total as _, highest_s as _, speed as _), (elapsed as _, total as _, highest_s as _, speed as _),
); );
total_s = 0; total_s = 0;
if elapsed > unsafe { DOWNGRADE_START_CHECK } if elapsed > DOWNGRADE_START_CHECK.load(Ordering::SeqCst)
&& !downgrade && !downgrade
&& total > elapsed * downgrade_threshold && total > elapsed * downgrade_threshold
{ {

View File

@ -35,6 +35,7 @@ use sodiumoxide::crypto::sign;
use std::{ use std::{
collections::HashMap, collections::HashMap,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
sync::atomic::{AtomicBool, AtomicUsize, Ordering},
sync::Arc, sync::Arc,
time::Instant, time::Instant,
}; };
@ -55,10 +56,10 @@ enum Sink {
} }
type Sender = mpsc::UnboundedSender<Data>; type Sender = mpsc::UnboundedSender<Data>;
type Receiver = mpsc::UnboundedReceiver<Data>; type Receiver = mpsc::UnboundedReceiver<Data>;
static mut ROTATION_RELAY_SERVER: usize = 0; static ROTATION_RELAY_SERVER: AtomicUsize = AtomicUsize::new(0);
type RelayServers = Vec<String>; type RelayServers = Vec<String>;
static CHECK_RELAY_TIMEOUT: u64 = 3_000; static CHECK_RELAY_TIMEOUT: u64 = 3_000;
static mut ALWAYS_USE_RELAY: bool = false; static ALWAYS_USE_RELAY: AtomicBool = AtomicBool::new(false);
#[derive(Clone)] #[derive(Clone)]
struct Inner { struct Inner {
@ -147,13 +148,11 @@ impl RendezvousServer {
.to_uppercase() .to_uppercase()
== "Y" == "Y"
{ {
unsafe { ALWAYS_USE_RELAY.store(true, Ordering::SeqCst);
ALWAYS_USE_RELAY = true;
}
} }
log::info!( log::info!(
"ALWAYS_USE_RELAY={}", "ALWAYS_USE_RELAY={}",
if unsafe { ALWAYS_USE_RELAY } { if ALWAYS_USE_RELAY.load(Ordering::SeqCst) {
"Y" "Y"
} else { } else {
"N" "N"
@ -711,7 +710,7 @@ impl RendezvousServer {
let peer_is_lan = self.is_lan(peer_addr); let peer_is_lan = self.is_lan(peer_addr);
let is_lan = self.is_lan(addr); let is_lan = self.is_lan(addr);
let mut relay_server = self.get_relay_server(addr.ip(), peer_addr.ip()); let mut relay_server = self.get_relay_server(addr.ip(), peer_addr.ip());
if unsafe { ALWAYS_USE_RELAY } || (peer_is_lan ^ is_lan) { if ALWAYS_USE_RELAY.load(Ordering::SeqCst) || (peer_is_lan ^ is_lan) {
if peer_is_lan { if peer_is_lan {
// https://github.com/rustdesk/rustdesk-server/issues/24 // https://github.com/rustdesk/rustdesk-server/issues/24
relay_server = self.inner.local_ip.clone() relay_server = self.inner.local_ip.clone()
@ -905,10 +904,7 @@ impl RendezvousServer {
} else if self.relay_servers.len() == 1 { } else if self.relay_servers.len() == 1 {
return self.relay_servers[0].clone(); return self.relay_servers[0].clone();
} }
let i = unsafe { let i = ROTATION_RELAY_SERVER.fetch_add(1, Ordering::SeqCst) % self.relay_servers.len();
ROTATION_RELAY_SERVER += 1;
ROTATION_RELAY_SERVER % self.relay_servers.len()
};
self.relay_servers[i].clone() self.relay_servers[i].clone()
} }
@ -1027,13 +1023,17 @@ impl RendezvousServer {
Some("always-use-relay" | "aur") => { Some("always-use-relay" | "aur") => {
if let Some(rs) = fds.next() { if let Some(rs) = fds.next() {
if rs.to_uppercase() == "Y" { if rs.to_uppercase() == "Y" {
unsafe { ALWAYS_USE_RELAY = true }; ALWAYS_USE_RELAY.store(true, Ordering::SeqCst);
} else { } else {
unsafe { ALWAYS_USE_RELAY = false }; ALWAYS_USE_RELAY.store(false, Ordering::SeqCst);
} }
self.tx.send(Data::RelayServers0(rs.to_owned())).ok(); self.tx.send(Data::RelayServers0(rs.to_owned())).ok();
} else { } else {
let _ = writeln!(res, "ALWAYS_USE_RELAY: {:?}", unsafe { ALWAYS_USE_RELAY }); let _ = writeln!(
res,
"ALWAYS_USE_RELAY: {:?}",
ALWAYS_USE_RELAY.load(Ordering::SeqCst)
);
} }
} }
Some("test-geo" | "tg") => { Some("test-geo" | "tg") => {