rustdesk-server/src/rendezvous_server.rs
2022-05-13 09:56:44 +08:00

1210 lines
44 KiB
Rust

use crate::common::*;
use crate::peer::*;
use hbb_common::{
allow_err,
bytes::{Bytes, BytesMut},
bytes_codec::BytesCodec,
futures::future::join_all,
futures_util::{
sink::SinkExt,
stream::{SplitSink, StreamExt},
},
log,
protobuf::{Message as _, MessageField},
rendezvous_proto::{
register_pk_response::Result::{TOO_FREQUENT, UUID_MISMATCH},
*,
},
tcp::{new_listener, FramedStream},
timeout,
tokio::{
self,
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
sync::{mpsc, Mutex},
time::{interval, Duration},
},
tokio_util::codec::Framed,
udp::FramedSocket,
AddrMangle, ResultType,
};
use sodiumoxide::crypto::sign;
use std::{
collections::HashMap,
net::{IpAddr, Ipv4Addr, SocketAddr},
sync::Arc,
time::Instant,
};
const ADDR_127: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
#[derive(Clone, Debug)]
enum Data {
Msg(RendezvousMessage, SocketAddr),
RelayServers0(String),
RelayServers(RelayServers),
}
const REG_TIMEOUT: i32 = 30_000;
type TcpStreamSink = SplitSink<Framed<TcpStream, BytesCodec>, Bytes>;
type WsSink = SplitSink<tokio_tungstenite::WebSocketStream<TcpStream>, tungstenite::Message>;
enum Sink {
TcpStream(TcpStreamSink),
Ws(WsSink),
}
type Sender = mpsc::UnboundedSender<Data>;
type Receiver = mpsc::UnboundedReceiver<Data>;
static mut ROTATION_RELAY_SERVER: usize = 0;
type RelayServers = Vec<String>;
static CHECK_RELAY_TIMEOUT: u64 = 3_000;
static mut ALWAYS_USE_RELAY: bool = false;
#[derive(Clone)]
pub struct RendezvousServer {
tcp_punch: Arc<Mutex<HashMap<SocketAddr, Sink>>>,
pm: PeerMap,
tx: Sender,
relay_servers: Arc<RelayServers>,
relay_servers0: Arc<RelayServers>,
serial: i32,
rendezvous_servers: Arc<Vec<String>>,
version: String,
software_url: String,
sk: Option<sign::SecretKey>,
}
enum LoopFailure {
UdpSocket,
Listener3,
Listener2,
Listener,
}
impl RendezvousServer {
#[tokio::main(flavor = "multi_thread")]
pub async fn start(
port: i32,
serial: i32,
key: &str,
rmem: usize,
) -> ResultType<()> {
let addr = format!("0.0.0.0:{}", port);
let addr2 = format!("0.0.0.0:{}", port - 1);
let addr3 = format!("0.0.0.0:{}", port + 2);
let pm = PeerMap::new().await?;
log::info!("serial={}", serial);
let rendezvous_servers = get_servers(&get_arg("rendezvous-servers"), "rendezvous-servers");
log::info!("Listening on tcp/udp {}", addr);
log::info!("Listening on tcp {}, extra port for NAT test", addr2);
log::info!("Listening on websocket {}", addr3);
let mut socket = FramedSocket::new_with_buf_size(&addr, rmem).await?;
let (tx, mut rx) = mpsc::unbounded_channel::<Data>();
let software_url = get_arg("software-url");
let version = hbb_common::get_version_from_url(&software_url);
if !version.is_empty() {
log::info!("software_url: {}, version: {}", software_url, version);
}
let mut rs = Self {
tcp_punch: Arc::new(Mutex::new(HashMap::new())),
pm,
tx: tx.clone(),
relay_servers: Default::default(),
relay_servers0: Default::default(),
serial,
rendezvous_servers: Arc::new(rendezvous_servers),
version,
software_url,
sk: None,
};
let key = rs.get_server_sk(key);
std::env::set_var("PORT_FOR_API", port.to_string());
rs.parse_relay_servers(&get_arg("relay-servers"));
let pm = rs.pm.clone();
let mut listener = new_listener(&addr, false).await?;
let mut listener2 = new_listener(&addr2, false).await?;
let mut listener3 = new_listener(&addr3, false).await?;
let test_addr = std::env::var("TEST_HBBS").unwrap_or_default();
if std::env::var("ALWAYS_USE_RELAY")
.unwrap_or_default()
.to_uppercase()
== "Y"
{
unsafe {
ALWAYS_USE_RELAY = true;
}
}
log::info!(
"ALWAYS_USE_RELAY={}",
if unsafe { ALWAYS_USE_RELAY } {
"Y"
} else {
"N"
}
);
if test_addr.to_lowercase() != "no" {
let test_addr = (if test_addr.is_empty() {
addr.replace("0.0.0.0", "127.0.0.1")
} else {
test_addr
})
.parse::<SocketAddr>()?;
tokio::spawn(async move {
allow_err!(test_hbbs(test_addr).await);
});
};
loop {
log::info!("Start");
match rs
.io_loop(
&mut rx,
&mut listener,
&mut listener2,
&mut listener3,
&mut socket,
&key,
)
.await
{
LoopFailure::UdpSocket => {
drop(socket);
socket = FramedSocket::new_with_buf_size(&addr, rmem).await?;
}
LoopFailure::Listener => {
drop(listener);
listener = new_listener(&addr, false).await?;
}
LoopFailure::Listener2 => {
drop(listener2);
listener2 = new_listener(&addr2, false).await?;
}
LoopFailure::Listener3 => {
drop(listener3);
listener3 = new_listener(&addr3, false).await?;
}
}
}
}
async fn io_loop(
&mut self,
rx: &mut Receiver,
listener: &mut TcpListener,
listener2: &mut TcpListener,
listener3: &mut TcpListener,
socket: &mut FramedSocket,
key: &str,
) -> LoopFailure {
let mut timer_check_relay = interval(Duration::from_millis(CHECK_RELAY_TIMEOUT));
loop {
tokio::select! {
_ = timer_check_relay.tick() => {
if self.relay_servers0.len() > 1 {
let rs = self.relay_servers0.clone();
let tx = self.tx.clone();
tokio::spawn(async move {
check_relay_servers(rs, tx).await;
});
}
}
Some(data) = rx.recv() => {
match data {
Data::Msg(msg, addr) => { allow_err!(socket.send(&msg, addr).await); }
Data::RelayServers0(rs) => { self.parse_relay_servers(&rs); }
Data::RelayServers(rs) => { self.relay_servers = Arc::new(rs); }
}
}
res = socket.next() => {
match res {
Some(Ok((bytes, addr))) => {
if let Err(err) = self.handle_udp(&bytes, addr.into(), socket, key).await {
log::error!("udp failure: {}", err);
return LoopFailure::UdpSocket;
}
}
Some(Err(err)) => {
log::error!("udp failure: {}", err);
return LoopFailure::UdpSocket;
}
None => {
// unreachable!() ?
}
}
}
res = listener2.accept() => {
match res {
Ok((stream, addr)) => {
stream.set_nodelay(true).ok();
self.handle_listener2(stream, addr).await;
}
Err(err) => {
log::error!("listener2.accept failed: {}", err);
return LoopFailure::Listener2;
}
}
}
res = listener3.accept() => {
match res {
Ok((stream, addr)) => {
stream.set_nodelay(true).ok();
self.handle_listener(stream, addr, key, true).await;
}
Err(err) => {
log::error!("listener3.accept failed: {}", err);
return LoopFailure::Listener3;
}
}
}
res = listener.accept() => {
match res {
Ok((stream, addr)) => {
stream.set_nodelay(true).ok();
self.handle_listener(stream, addr, key, false).await;
}
Err(err) => {
log::error!("listener.accept failed: {}", err);
return LoopFailure::Listener;
}
}
}
}
}
}
#[inline]
async fn handle_udp(
&mut self,
bytes: &BytesMut,
addr: SocketAddr,
socket: &mut FramedSocket,
key: &str,
) -> ResultType<()> {
if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) {
match msg_in.union {
Some(rendezvous_message::Union::register_peer(rp)) => {
// B registered
if rp.id.len() > 0 {
log::trace!("New peer registered: {:?} {:?}", &rp.id, &addr);
self.update_addr(rp.id, addr, socket).await?;
if self.serial > rp.serial {
let mut msg_out = RendezvousMessage::new();
msg_out.set_configure_update(ConfigUpdate {
serial: self.serial,
rendezvous_servers: (*self.rendezvous_servers).clone(),
..Default::default()
});
socket.send(&msg_out, addr).await?;
}
}
}
Some(rendezvous_message::Union::register_pk(rk)) => {
if rk.uuid.is_empty() || rk.pk.is_empty() {
return Ok(());
}
let id = rk.id;
let ip = addr.ip().to_string();
if id.len() < 6 {
return send_rk_res(socket, addr, UUID_MISMATCH).await;
} else if !self.check_ip_blocker(&ip, &id).await {
return send_rk_res(socket, addr, TOO_FREQUENT).await;
}
let peer = self.pm.get_or(&id).await;
let (changed, ip_changed) = {
let peer = peer.read().await;
if peer.uuid.is_empty() {
(true, false)
} else {
if peer.uuid == rk.uuid {
if peer.info.ip != ip && peer.pk != rk.pk {
log::warn!(
"Peer {} ip/pk mismatch: {}/{:?} vs {}/{:?}",
id,
ip,
rk.pk,
peer.info.ip,
peer.pk,
);
drop(peer);
return send_rk_res(socket, addr, UUID_MISMATCH).await;
}
} else {
log::warn!(
"Peer {} uuid mismatch: {:?} vs {:?}",
id,
rk.uuid,
peer.uuid
);
drop(peer);
return send_rk_res(socket, addr, UUID_MISMATCH).await;
}
let ip_changed = peer.info.ip != ip;
(
peer.uuid != rk.uuid || peer.pk != rk.pk || ip_changed,
ip_changed,
)
}
};
let mut req_pk = peer.read().await.reg_pk;
if req_pk.1.elapsed().as_secs() > 6 {
req_pk.0 = 0;
} else if req_pk.0 > 2 {
return send_rk_res(socket, addr, TOO_FREQUENT).await;
}
req_pk.0 += 1;
req_pk.1 = Instant::now();
peer.write().await.reg_pk = req_pk;
if ip_changed {
let mut lock = IP_CHANGES.lock().await;
if let Some((tm, ips)) = lock.get_mut(&id) {
if tm.elapsed().as_secs() > IP_CHANGE_DUR {
*tm = Instant::now();
ips.clear();
ips.insert(ip.clone(), 1);
} else {
if let Some(v) = ips.get_mut(&ip) {
*v += 1;
} else {
ips.insert(ip.clone(), 1);
}
}
} else {
lock.insert(
id.clone(),
(Instant::now(), HashMap::from([(ip.clone(), 1)])),
);
}
}
if changed {
self.pm.update_pk(id, peer, addr, rk.uuid, rk.pk, ip).await;
}
let mut msg_out = RendezvousMessage::new();
msg_out.set_register_pk_response(RegisterPkResponse {
result: register_pk_response::Result::OK.into(),
..Default::default()
});
socket.send(&msg_out, addr).await?
}
Some(rendezvous_message::Union::punch_hole_request(ph)) => {
if self.pm.is_in_memory(&ph.id).await {
self.handle_udp_punch_hole_request(addr, ph, key).await?;
} else {
// not in memory, fetch from db with spawn in case blocking me
let mut me = self.clone();
let key = key.to_owned();
tokio::spawn(async move {
allow_err!(me.handle_udp_punch_hole_request(addr, ph, &key).await);
});
}
}
Some(rendezvous_message::Union::punch_hole_sent(phs)) => {
self.handle_hole_sent(phs, addr, Some(socket)).await?;
}
Some(rendezvous_message::Union::local_addr(la)) => {
self.handle_local_addr(la, addr, Some(socket)).await?;
}
Some(rendezvous_message::Union::configure_update(mut cu)) => {
if addr.ip() == ADDR_127 && cu.serial > self.serial {
self.serial = cu.serial;
self.rendezvous_servers = Arc::new(
cu.rendezvous_servers
.drain(..)
.filter(|x| {
!x.is_empty()
&& test_if_valid_server(x, "rendezvous-server").is_ok()
})
.collect(),
);
log::info!(
"configure updated: serial={} rendezvous-servers={:?}",
self.serial,
self.rendezvous_servers
);
}
}
Some(rendezvous_message::Union::software_update(su)) => {
if !self.version.is_empty() && su.url != self.version {
let mut msg_out = RendezvousMessage::new();
msg_out.set_software_update(SoftwareUpdate {
url: self.software_url.clone(),
..Default::default()
});
socket.send(&msg_out, addr).await?;
}
}
_ => {}
}
}
Ok(())
}
#[inline]
async fn handle_tcp(
&mut self,
bytes: &[u8],
sink: &mut Option<Sink>,
addr: SocketAddr,
key: &str,
ws: bool,
) -> bool {
if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) {
match msg_in.union {
Some(rendezvous_message::Union::punch_hole_request(ph)) => {
// there maybe several attempt, so sink can be none
if let Some(sink) = sink.take() {
self.tcp_punch.lock().await.insert(addr, sink);
}
allow_err!(self.handle_tcp_punch_hole_request(addr, ph, &key, ws).await);
return true;
}
Some(rendezvous_message::Union::request_relay(mut rf)) => {
// there maybe several attempt, so sink can be none
if let Some(sink) = sink.take() {
self.tcp_punch.lock().await.insert(addr, sink);
}
if let Some(peer) = self.pm.get_in_memory(&rf.id).await {
let mut msg_out = RendezvousMessage::new();
rf.socket_addr = AddrMangle::encode(addr);
msg_out.set_request_relay(rf);
let peer_addr = peer.read().await.socket_addr;
self.tx.send(Data::Msg(msg_out, peer_addr)).ok();
}
return true;
}
Some(rendezvous_message::Union::relay_response(mut rr)) => {
let addr_b = AddrMangle::decode(&rr.socket_addr);
rr.socket_addr = Default::default();
let id = rr.get_id();
if !id.is_empty() {
let pk = self.get_pk(&rr.version, id.to_owned()).await;
rr.set_pk(pk);
}
let mut msg_out = RendezvousMessage::new();
msg_out.set_relay_response(rr);
allow_err!(self.send_to_tcp_sync(msg_out, addr_b).await);
}
Some(rendezvous_message::Union::punch_hole_sent(phs)) => {
allow_err!(self.handle_hole_sent(phs, addr, None).await);
}
Some(rendezvous_message::Union::local_addr(la)) => {
allow_err!(self.handle_local_addr(la, addr, None).await);
}
Some(rendezvous_message::Union::test_nat_request(tar)) => {
let mut msg_out = RendezvousMessage::new();
let mut res = TestNatResponse {
port: addr.port() as _,
..Default::default()
};
if self.serial > tar.serial {
let mut cu = ConfigUpdate::new();
cu.serial = self.serial;
cu.rendezvous_servers = (*self.rendezvous_servers).clone();
res.cu = MessageField::from_option(Some(cu));
}
msg_out.set_test_nat_response(res);
Self::send_to_sink(sink, msg_out).await;
}
Some(rendezvous_message::Union::register_pk(_rk)) => {
let res = register_pk_response::Result::NOT_SUPPORT;
let mut msg_out = RendezvousMessage::new();
msg_out.set_register_pk_response(RegisterPkResponse {
result: res.into(),
..Default::default()
});
Self::send_to_sink(sink, msg_out).await;
}
_ => {}
}
}
false
}
#[inline]
async fn update_addr(
&mut self,
id: String,
socket_addr: SocketAddr,
socket: &mut FramedSocket,
) -> ResultType<()> {
let (request_pk, ip_change) = if let Some(old) = self.pm.get_in_memory(&id).await {
let mut old = old.write().await;
let ip = socket_addr.ip();
let ip_change = if old.socket_addr.port() != 0 {
ip != old.socket_addr.ip()
} else {
ip.to_string() != old.info.ip
} && ip != ADDR_127;
let request_pk = old.pk.is_empty() || ip_change;
if !request_pk {
old.socket_addr = socket_addr;
old.last_reg_time = Instant::now();
}
let ip_change = if ip_change && old.reg_pk.0 <= 2 {
Some(if old.socket_addr.port() == 0 {
old.info.ip.clone()
} else {
old.socket_addr.to_string()
})
} else {
None
};
(request_pk, ip_change)
} else {
(true, None)
};
if let Some(old) = ip_change {
log::info!("IP change of {} from {} to {}", id, old, socket_addr);
}
let mut msg_out = RendezvousMessage::new();
msg_out.set_register_peer_response(RegisterPeerResponse {
request_pk,
..Default::default()
});
socket.send(&msg_out, socket_addr).await
}
#[inline]
async fn handle_hole_sent<'a>(
&mut self,
phs: PunchHoleSent,
addr: SocketAddr,
socket: Option<&'a mut FramedSocket>,
) -> ResultType<()> {
// punch hole sent from B, tell A that B is ready to be connected
let addr_a = AddrMangle::decode(&phs.socket_addr);
log::debug!(
"{} punch hole response to {:?} from {:?}",
if socket.is_none() { "TCP" } else { "UDP" },
&addr_a,
&addr
);
let mut msg_out = RendezvousMessage::new();
let mut p = PunchHoleResponse {
socket_addr: AddrMangle::encode(addr),
pk: self.get_pk(&phs.version, phs.id).await,
relay_server: phs.relay_server.clone(),
..Default::default()
};
if let Ok(t) = phs.nat_type.enum_value() {
p.set_nat_type(t);
}
msg_out.set_punch_hole_response(p);
if let Some(socket) = socket {
socket.send(&msg_out, addr_a).await?;
} else {
self.send_to_tcp(msg_out, addr_a).await;
}
Ok(())
}
#[inline]
async fn handle_local_addr<'a>(
&mut self,
la: LocalAddr,
addr: SocketAddr,
socket: Option<&'a mut FramedSocket>,
) -> ResultType<()> {
// relay local addrs of B to A
let addr_a = AddrMangle::decode(&la.socket_addr);
log::debug!(
"{} local addrs response to {:?} from {:?}",
if socket.is_none() { "TCP" } else { "UDP" },
&addr_a,
&addr
);
let mut msg_out = RendezvousMessage::new();
let mut p = PunchHoleResponse {
socket_addr: la.local_addr.clone(),
pk: self.get_pk(&la.version, la.id).await,
relay_server: la.relay_server,
..Default::default()
};
p.set_is_local(true);
msg_out.set_punch_hole_response(p);
if let Some(socket) = socket {
socket.send(&msg_out, addr_a).await?;
} else {
self.send_to_tcp(msg_out, addr_a).await;
}
Ok(())
}
#[inline]
async fn handle_punch_hole_request(
&mut self,
addr: SocketAddr,
ph: PunchHoleRequest,
key: &str,
ws: bool,
) -> ResultType<(RendezvousMessage, Option<SocketAddr>)> {
if !key.is_empty() && ph.licence_key != key {
let mut msg_out = RendezvousMessage::new();
msg_out.set_punch_hole_response(PunchHoleResponse {
failure: punch_hole_response::Failure::LICENSE_MISMATCH.into(),
..Default::default()
});
return Ok((msg_out, None));
}
let id = ph.id;
// punch hole request from A, relay to B,
// check if in same intranet first,
// fetch local addrs if in same intranet.
// because punch hole won't work if in the same intranet,
// all routers will drop such self-connections.
if let Some(peer) = self.pm.get(&id).await {
let (elapsed, peer_addr) = {
let r = peer.read().await;
(r.last_reg_time.elapsed().as_millis() as i32, r.socket_addr)
};
if elapsed >= REG_TIMEOUT {
let mut msg_out = RendezvousMessage::new();
msg_out.set_punch_hole_response(PunchHoleResponse {
failure: punch_hole_response::Failure::OFFLINE.into(),
..Default::default()
});
return Ok((msg_out, None));
}
let mut msg_out = RendezvousMessage::new();
if unsafe { ALWAYS_USE_RELAY } {
let relay_server = self.get_relay_server(addr.ip(), peer_addr.ip());
if !relay_server.is_empty() {
msg_out.set_request_relay(RequestRelay {
relay_server,
..Default::default()
});
return Ok((msg_out, Some(peer_addr)));
}
}
let same_intranet = !ws
&& match peer_addr {
SocketAddr::V4(a) => match addr {
SocketAddr::V4(b) => a.ip() == b.ip(),
_ => false,
},
SocketAddr::V6(a) => match addr {
SocketAddr::V6(b) => a.ip() == b.ip(),
_ => false,
},
};
let socket_addr = AddrMangle::encode(addr);
let relay_server = self.get_relay_server(addr.ip(), peer_addr.ip());
if same_intranet {
log::debug!(
"Fetch local addr {:?} {:?} request from {:?}",
id,
peer_addr,
addr
);
msg_out.set_fetch_local_addr(FetchLocalAddr {
socket_addr,
relay_server,
..Default::default()
});
} else {
log::debug!(
"Punch hole {:?} {:?} request from {:?}",
id,
peer_addr,
addr
);
msg_out.set_punch_hole(PunchHole {
socket_addr,
nat_type: ph.nat_type,
relay_server,
..Default::default()
});
}
return Ok((msg_out, Some(peer_addr)));
} else {
let mut msg_out = RendezvousMessage::new();
msg_out.set_punch_hole_response(PunchHoleResponse {
failure: punch_hole_response::Failure::ID_NOT_EXIST.into(),
..Default::default()
});
return Ok((msg_out, None));
}
}
#[inline]
async fn send_to_tcp(&mut self, msg: RendezvousMessage, addr: SocketAddr) {
let mut tcp = self.tcp_punch.lock().await.remove(&addr);
tokio::spawn(async move {
Self::send_to_sink(&mut tcp, msg).await;
});
}
#[inline]
async fn send_to_sink(sink: &mut Option<Sink>, msg: RendezvousMessage) {
if let Some(sink) = sink.as_mut() {
if let Ok(bytes) = msg.write_to_bytes() {
match sink {
Sink::TcpStream(s) => {
allow_err!(s.send(Bytes::from(bytes)).await);
}
Sink::Ws(ws) => {
allow_err!(ws.send(tungstenite::Message::Binary(bytes)).await);
}
}
}
}
}
#[inline]
async fn send_to_tcp_sync(
&mut self,
msg: RendezvousMessage,
addr: SocketAddr,
) -> ResultType<()> {
let mut sink = self.tcp_punch.lock().await.remove(&addr);
Self::send_to_sink(&mut sink, msg).await;
Ok(())
}
#[inline]
async fn handle_tcp_punch_hole_request(
&mut self,
addr: SocketAddr,
ph: PunchHoleRequest,
key: &str,
ws: bool,
) -> ResultType<()> {
let (msg, to_addr) = self.handle_punch_hole_request(addr, ph, key, ws).await?;
if let Some(addr) = to_addr {
self.tx.send(Data::Msg(msg, addr))?;
} else {
self.send_to_tcp_sync(msg, addr).await?;
}
Ok(())
}
#[inline]
async fn handle_udp_punch_hole_request(
&mut self,
addr: SocketAddr,
ph: PunchHoleRequest,
key: &str,
) -> ResultType<()> {
let (msg, to_addr) = self.handle_punch_hole_request(addr, ph, key, false).await?;
self.tx.send(Data::Msg(
msg,
match to_addr {
Some(addr) => addr,
None => addr,
},
))?;
Ok(())
}
async fn check_ip_blocker(&self, ip: &str, id: &str) -> bool {
let mut lock = IP_BLOCKER.lock().await;
let now = Instant::now();
if let Some(old) = lock.get_mut(ip) {
let counter = &mut old.0;
if counter.1.elapsed().as_secs() > IP_BLOCK_DUR {
counter.0 = 0;
} else if counter.0 > 30 {
return false;
}
counter.0 += 1;
counter.1 = now;
let counter = &mut old.1;
let is_new = counter.0.get(id).is_none();
if counter.1.elapsed().as_secs() > DAY_SECONDS {
counter.0.clear();
} else if counter.0.len() > 300 {
return !is_new;
}
if is_new {
counter.0.insert(id.to_owned());
}
counter.1 = now;
} else {
lock.insert(ip.to_owned(), ((0, now), (Default::default(), now)));
}
true
}
fn parse_relay_servers(&mut self, relay_servers: &str) {
let rs = get_servers(relay_servers, "relay-servers");
self.relay_servers0 = Arc::new(rs);
self.relay_servers = self.relay_servers0.clone();
}
fn get_relay_server(&self, pa: IpAddr, pb: IpAddr) -> String {
if self.relay_servers.is_empty() {
return "".to_owned();
} else if self.relay_servers.len() == 1 {
return self.relay_servers[0].clone();
}
let i = unsafe {
ROTATION_RELAY_SERVER += 1;
ROTATION_RELAY_SERVER % self.relay_servers.len()
};
self.relay_servers[i].clone()
}
async fn check_cmd(&self, cmd: &str) -> String {
let mut res = "".to_owned();
let mut fds = cmd.trim().split(" ");
match fds.next() {
Some("h") => {
res = format!(
"{}\n{}\n{}\n{}\n{}\n{}\n",
"relay-servers(rs) <separated by ,>",
"reload-geo(rg)",
"ip-blocker(ib) [<ip>|<number>] [-]",
"ip-changes(ic) [<id>|<number>] [-]",
"always-use-relay(aur)",
"test-geo(tg) <ip1> <ip2>"
)
}
Some("relay-servers" | "rs") => {
if let Some(rs) = fds.next() {
self.tx.send(Data::RelayServers0(rs.to_owned())).ok();
} else {
for ip in self.relay_servers.iter() {
res += &format!("{}\n", ip);
}
}
}
Some("ip-blocker" | "ib") => {
let mut lock = IP_BLOCKER.lock().await;
lock.retain(|&_, (a, b)| {
a.1.elapsed().as_secs() <= IP_BLOCK_DUR
|| b.1.elapsed().as_secs() <= DAY_SECONDS
});
res = format!("{}\n", lock.len());
let ip = fds.next();
let mut start = ip.map(|x| x.parse::<i32>().unwrap_or(-1)).unwrap_or(-1);
if start < 0 {
if let Some(ip) = ip {
if let Some((a, b)) = lock.get(ip) {
res += &format!(
"{}/{}s {}/{}s\n",
a.0,
a.1.elapsed().as_secs(),
b.0.len(),
b.1.elapsed().as_secs()
);
}
if fds.next() == Some("-") {
lock.remove(ip);
}
} else {
start = 0;
}
}
if start >= 0 {
let mut it = lock.iter();
for i in 0..(start + 10) {
let x = it.next();
if x.is_none() {
break;
}
if i < start {
continue;
}
if let Some((ip, (a, b))) = x {
res += &format!(
"{}: {}/{}s {}/{}s\n",
ip,
a.0,
a.1.elapsed().as_secs(),
b.0.len(),
b.1.elapsed().as_secs()
);
}
}
}
}
Some("ip-changes" | "ic") => {
let mut lock = IP_CHANGES.lock().await;
lock.retain(|&_, v| v.0.elapsed().as_secs() < IP_CHANGE_DUR_X2 && v.1.len() > 1);
res = format!("{}\n", lock.len());
let id = fds.next();
let mut start = id.map(|x| x.parse::<i32>().unwrap_or(-1)).unwrap_or(-1);
if start < 0 || start > 10_000_000 {
if let Some(id) = id {
if let Some((tm, ips)) = lock.get(id) {
res += &format!("{}s {:?}\n", tm.elapsed().as_secs(), ips);
}
if fds.next() == Some("-") {
lock.remove(id);
}
} else {
start = 0;
}
}
if start >= 0 {
let mut it = lock.iter();
for i in 0..(start + 10) {
let x = it.next();
if x.is_none() {
break;
}
if i < start {
continue;
}
if let Some((id, (tm, ips))) = x {
res += &format!("{}: {}s {:?}\n", id, tm.elapsed().as_secs(), ips,);
}
}
}
}
Some("always-use-relay" | "aur") => {
if let Some(rs) = fds.next() {
if rs.to_uppercase() == "Y" {
unsafe { ALWAYS_USE_RELAY = true };
} else {
unsafe { ALWAYS_USE_RELAY = false };
}
self.tx.send(Data::RelayServers0(rs.to_owned())).ok();
} else {
res += &format!("ALWAYS_USE_RELAY: {:?}\n", unsafe { ALWAYS_USE_RELAY });
}
}
Some("test-geo" | "tg") => {
if let Some(rs) = fds.next() {
if let Ok(a) = rs.parse::<IpAddr>() {
if let Some(rs) = fds.next() {
if let Ok(b) = rs.parse::<IpAddr>() {
res = format!("{:?}", self.get_relay_server(a, b));
}
} else {
res = format!("{:?}", self.get_relay_server(a, a));
}
}
}
}
_ => {}
}
res
}
async fn handle_listener2(&self, stream: TcpStream, addr: SocketAddr) {
if addr.ip().to_string() == "127.0.0.1" {
let rs = self.clone();
tokio::spawn(async move {
let mut stream = stream;
let mut buffer = [0; 64];
if let Ok(Ok(n)) = timeout(1000, stream.read(&mut buffer[..])).await {
if let Ok(data) = std::str::from_utf8(&buffer[..n]) {
let res = rs.check_cmd(data).await;
stream.write(res.as_bytes()).await.ok();
}
}
});
return;
}
let stream = FramedStream::from(stream, addr);
tokio::spawn(async move {
let mut stream = stream;
if let Some(Ok(bytes)) = stream.next_timeout(30_000).await {
if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) {
if let Some(rendezvous_message::Union::test_nat_request(_)) = msg_in.union {
let mut msg_out = RendezvousMessage::new();
msg_out.set_test_nat_response(TestNatResponse {
port: addr.port() as _,
..Default::default()
});
stream.send(&msg_out).await.ok();
}
}
}
});
}
async fn handle_listener(
&self,
stream: TcpStream,
addr: SocketAddr,
key: &str,
ws: bool,
) {
log::debug!("Tcp connection from {:?}, ws: {}", addr, ws);
let mut rs = self.clone();
let key = key.to_owned();
tokio::spawn(async move {
allow_err!(
rs.handle_listener_inner(stream, addr, &key, ws)
.await
);
});
}
#[inline]
async fn handle_listener_inner(
&mut self,
stream: TcpStream,
addr: SocketAddr,
key: &str,
ws: bool,
) -> ResultType<()> {
let mut sink;
if ws {
let ws_stream = tokio_tungstenite::accept_async(stream).await?;
let (a, mut b) = ws_stream.split();
sink = Some(Sink::Ws(a));
while let Ok(Some(Ok(msg))) = timeout(30_000, b.next()).await {
match msg {
tungstenite::Message::Binary(bytes) => {
if !self
.handle_tcp(&bytes, &mut sink, addr, key, ws)
.await
{
break;
}
}
_ => {}
}
}
} else {
let (a, mut b) = Framed::new(stream, BytesCodec::new()).split();
sink = Some(Sink::TcpStream(a));
while let Ok(Some(Ok(bytes))) = timeout(30_000, b.next()).await {
if !self
.handle_tcp(&bytes, &mut sink, addr, key, ws)
.await
{
break;
}
}
}
if sink.is_none() {
self.tcp_punch.lock().await.remove(&addr);
}
log::debug!("Tcp connection from {:?} closed", addr);
Ok(())
}
#[inline]
async fn get_pk(&mut self, version: &str, id: String) -> Vec<u8> {
if version.is_empty() || self.sk.is_none() {
Vec::new()
} else {
match self.pm.get(&id).await {
Some(peer) => {
let pk = peer.read().await.pk.clone();
sign::sign(
&hbb_common::message_proto::IdPk {
id,
pk,
..Default::default()
}
.write_to_bytes()
.unwrap_or_default(),
&self.sk.as_ref().unwrap(),
)
}
_ => Vec::new(),
}
}
}
#[inline]
fn get_server_sk(&mut self, key: &str) -> String {
let mut key = key.to_owned();
if let Ok(sk) = base64::decode(&key) {
if sk.len() == sign::SECRETKEYBYTES {
log::info!("The key is a crypto private key");
key = base64::encode(&sk[(sign::SECRETKEYBYTES / 2)..]);
let mut tmp = [0u8; sign::SECRETKEYBYTES];
tmp[..].copy_from_slice(&sk);
self.sk = Some(sign::SecretKey(tmp));
}
}
if key.is_empty() || key == "-" || key == "_" {
let (pk, sk) = crate::common::gen_sk();
self.sk = sk;
if !key.is_empty() {
key = pk;
} else {
std::env::set_var("KEY_FOR_API", pk);
}
}
if !key.is_empty() {
log::info!("Key: {}", key);
std::env::set_var("KEY_FOR_API", key.clone());
}
key
}
}
async fn check_relay_servers(rs0: Arc<RelayServers>, tx: Sender) {
let mut futs = Vec::new();
let rs = Arc::new(Mutex::new(Vec::new()));
for x in rs0.iter() {
let mut host = x.to_owned();
if !host.contains(":") {
host = format!("{}:{}", host, hbb_common::config::RELAY_PORT);
}
let rs = rs.clone();
let x = x.clone();
futs.push(tokio::spawn(async move {
if FramedStream::new(&host, "0.0.0.0:0", CHECK_RELAY_TIMEOUT)
.await
.is_ok()
{
rs.lock().await.push(x);
}
}));
}
join_all(futs).await;
log::debug!("check_relay_servers");
let rs = std::mem::replace(&mut *rs.lock().await, Default::default());
if !rs.is_empty() {
tx.send(Data::RelayServers(rs)).ok();
}
}
// temp solution to solve udp socket failure
async fn test_hbbs(addr: SocketAddr) -> ResultType<()> {
let mut socket = FramedSocket::new("0.0.0.0:0").await?;
let mut msg_out = RendezvousMessage::new();
msg_out.set_register_peer(RegisterPeer {
id: "(:test_hbbs:)".to_owned(),
..Default::default()
});
let mut last_time_recv = Instant::now();
let mut timer = interval(Duration::from_secs(1));
loop {
tokio::select! {
_ = timer.tick() => {
if last_time_recv.elapsed().as_secs() > 12 {
log::error!("Timeout of test_hbbs");
std::process::exit(1);
}
socket.send(&msg_out, addr).await?;
}
Some(Ok((bytes, _))) = socket.next() => {
if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) {
log::trace!("Recv {:?} of test_hbbs", msg_in);
last_time_recv = Instant::now();
}
}
}
}
}
#[inline]
fn distance(a: &(i32, i32), b: &(i32, i32)) -> i32 {
let dx = a.0 - b.0;
let dy = a.1 - b.1;
dx * dx + dy * dy
}
#[inline]
async fn send_rk_res(
socket: &mut FramedSocket,
addr: SocketAddr,
res: register_pk_response::Result,
) -> ResultType<()> {
let mut msg_out = RendezvousMessage::new();
msg_out.set_register_pk_response(RegisterPkResponse {
result: res.into(),
..Default::default()
});
socket.send(&msg_out, addr).await
}