first commit

This commit is contained in:
rustdesk 2021-05-27 19:59:49 +08:00
parent a562ec712c
commit c18fa776ce
16 changed files with 2860 additions and 2 deletions

10
Cargo.toml Normal file
View File

@ -0,0 +1,10 @@
[package]
name = "rustdesk-server"
version = "1.0.0"
authors = ["rustdesk <info@rustdesk.com>"]
edition = "2018"
description = "A remote control software."
[dependencies]
hbb_common = { path = "libs/hbb_common" }

View File

@ -1,2 +1,2 @@
# A demo of RustDesk server implementation # A working demo of RustDesk server implementation
This is a super simple implementation with only one relay connection allowed, without NAT traverse, persistence, encryption and any other advanced features. But it can be your good starting point to write your own RustDesk server program. This is a super simple working demo implementation with only one relay connection allowed, without NAT traverse, persistence, encryption and any other advanced features. But it can be your good starting point to write your own RustDesk server program.

4
libs/hbb_common/.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
/target
**/*.rs.bk
Cargo.lock
src/protos/

View File

@ -0,0 +1,46 @@
[package]
name = "hbb_common"
version = "0.1.0"
authors = ["rustdesk<info@rustdesk.com>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
protobuf = { version = "3.0.0-pre", git = "https://github.com/stepancheg/rust-protobuf" }
tokio = { version = "0.2", features = ["full"] }
tokio-util = { version = "0.3", features = ["full"] }
futures = "0.3"
bytes = "0.5"
log = "0.4"
env_logger = "0.8"
socket2 = { version = "0.3", features = ["reuseport"] }
zstd = "0.5"
quinn = {version = "0.6", optional = true }
anyhow = "1.0"
futures-util = "0.3"
directories-next = "2.0"
rand = "0.7"
serde_derive = "1.0"
serde = "1.0"
lazy_static = "1.4"
confy = { git = "https://github.com/open-trade/confy" }
dirs-next = "2.0"
filetime = "0.2"
sodiumoxide = "0.2"
[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
mac_address = "1.1"
[features]
quic = ["quinn"]
[build-dependencies]
protobuf-codegen-pure = { version = "3.0.0-pre", git = "https://github.com/stepancheg/rust-protobuf" }
[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = ["winuser"] }
[dev-dependencies]
toml = "0.5"
serde_json = "1.0"

9
libs/hbb_common/build.rs Normal file
View File

@ -0,0 +1,9 @@
fn main() {
std::fs::create_dir_all("src/protos").unwrap();
protobuf_codegen_pure::Codegen::new()
.out_dir("src/protos")
.inputs(&["protos/rendezvous.proto", "protos/message.proto"])
.include("protos")
.run()
.expect("Codegen failed.");
}

View File

@ -0,0 +1,404 @@
syntax = "proto3";
package hbb;
message VP9 {
bytes data = 1;
bool key = 2;
int64 pts = 3;
}
message VP9s { repeated VP9 frames = 1; }
message RGB { bool compress = 1; }
// planes data send directly in binary for better use arraybuffer on web
message YUV {
bool compress = 1;
int32 stride = 2;
}
message VideoFrame {
oneof union {
VP9s vp9s = 6;
RGB rgb = 7;
YUV yuv = 8;
}
}
message DisplayInfo {
sint32 x = 1;
sint32 y = 2;
int32 width = 3;
int32 height = 4;
string name = 5;
bool online = 6;
}
message PortForward {
string host = 1;
int32 port = 2;
}
message FileTransfer {
string dir = 1;
bool show_hidden = 2;
}
message LoginRequest {
string username = 1;
bytes password = 2;
string my_id = 4;
string my_name = 5;
OptionMessage option = 6;
oneof union {
FileTransfer file_transfer = 7;
PortForward port_forward = 8;
}
}
message ChatMessage { string text = 1; }
message PeerInfo {
string username = 1;
string hostname = 2;
string platform = 3;
repeated DisplayInfo displays = 4;
int32 current_display = 5;
bool sas_enabled = 6;
string version = 7;
}
message LoginResponse {
oneof union {
string error = 1;
PeerInfo peer_info = 2;
}
}
message MouseEvent {
int32 mask = 1;
sint32 x = 2;
sint32 y = 3;
repeated ControlKey modifiers = 4;
}
enum ControlKey {
Alt = 1;
Backspace = 2;
CapsLock = 3;
Control = 4;
Delete = 5;
DownArrow = 6;
End = 7;
Escape = 8;
F1 = 9;
F10 = 10;
F11 = 11;
F12 = 12;
F2 = 13;
F3 = 14;
F4 = 15;
F5 = 16;
F6 = 17;
F7 = 18;
F8 = 19;
F9 = 20;
Home = 21;
LeftArrow = 22;
/// meta key (also known as "windows"; "super"; and "command")
Meta = 23;
/// option key on macOS (alt key on Linux and Windows)
Option = 24;
PageDown = 25;
PageUp = 26;
Return = 27;
RightArrow = 28;
Shift = 29;
Space = 30;
Tab = 31;
UpArrow = 32;
Numpad0 = 33;
Numpad1 = 34;
Numpad2 = 35;
Numpad3 = 36;
Numpad4 = 37;
Numpad5 = 38;
Numpad6 = 39;
Numpad7 = 40;
Numpad8 = 41;
Numpad9 = 42;
Cancel = 43;
Clear = 44;
Menu = 45; // deprecated, use Alt instead
Pause = 46;
Kana = 47;
Hangul = 48;
Junja = 49;
Final = 50;
Hanja = 51;
Kanji = 52;
Convert = 53;
Select = 54;
Print = 55;
Execute = 56;
Snapshot = 57;
Insert = 58;
Help = 59;
Sleep = 60;
Separator = 61;
Scroll = 62;
NumLock = 63;
RWin = 64;
Apps = 65;
Multiply = 66;
Add = 67;
Subtract = 68;
Decimal = 69;
Divide = 70;
Equals = 71;
NumpadEnter = 72;
RShift= 73;
RControl = 74;
RAlt = 75;
CtrlAltDel = 100;
LockScreen = 101;
}
message KeyEvent {
bool down = 1;
bool press = 2;
oneof union {
ControlKey control_key = 3;
uint32 chr = 4;
uint32 unicode = 5;
string seq = 6;
}
repeated ControlKey modifiers = 8;
}
message CursorData {
uint64 id = 1;
sint32 hotx = 2;
sint32 hoty = 3;
int32 width = 4;
int32 height = 5;
bytes colors = 6;
}
message CursorPosition {
sint32 x = 1;
sint32 y = 2;
}
message Hash {
string salt = 1;
string challenge = 2;
};
message Clipboard {
bool compress = 1;
bytes content = 2;
};
enum FileType {
Dir = 1;
DirLink = 2;
DirDrive = 3;
File = 4;
FileLink = 5;
}
message FileEntry {
FileType entry_type = 1;
string name = 2;
bool is_hidden = 3;
uint64 size = 4;
uint64 modified_time = 5;
}
message FileDirectory {
int32 id = 1;
string path = 2;
repeated FileEntry entries = 3;
}
message ReadDir {
string path = 1;
bool include_hidden = 2;
}
message ReadAllFiles {
int32 id = 1;
string path = 2;
bool include_hidden = 3;
}
message FileAction {
oneof union {
ReadDir read_dir = 1;
FileTransferSendRequest send = 2;
FileTransferReceiveRequest receive = 3;
FileDirCreate create = 4;
FileRemoveDir remove_dir = 5;
FileRemoveFile remove_file = 6;
ReadAllFiles all_files = 7;
FileTransferCancel cancel = 8;
}
}
message FileTransferCancel { int32 id = 1; }
message FileResponse {
oneof union {
FileDirectory dir = 1;
FileTransferBlock block = 2;
FileTransferError error = 3;
FileTransferDone done = 4;
}
}
message FileTransferBlock {
int32 id = 1;
sint32 file_num = 2;
bytes data = 3;
bool compressed = 4;
}
message FileTransferError {
int32 id = 1;
string error = 2;
sint32 file_num = 3;
}
message FileTransferSendRequest {
int32 id = 1;
string path = 2;
bool include_hidden = 3;
}
message FileTransferDone {
int32 id = 1;
sint32 file_num = 2;
}
message FileTransferReceiveRequest {
int32 id = 1;
string path = 2; // path written to
repeated FileEntry files = 3;
}
message FileRemoveDir {
int32 id = 1;
string path = 2;
bool recursive = 3;
}
message FileRemoveFile {
int32 id = 1;
string path = 2;
sint32 file_num = 3;
}
message FileDirCreate {
int32 id = 1;
string path = 2;
}
message SwitchDisplay {
int32 display = 1;
sint32 x = 2;
sint32 y = 3;
int32 width = 4;
int32 height = 5;
}
enum Permission {
Keyboard = 1;
Clipboard = 2;
Audio = 3;
}
message PermissionInfo {
Permission permission = 1;
bool enabled = 2;
}
enum ImageQuality {
NotSet = 0;
Low = 2;
Balanced = 3;
Best = 4;
}
enum BoolOption {
NotSet = 0;
No = 1;
Yes = 2;
}
message OptionMessage {
ImageQuality image_quality = 1;
BoolOption lock_after_session_end = 2;
BoolOption show_remote_cursor = 3;
BoolOption privacy_mode = 4;
BoolOption block_input = 5;
int32 custom_image_quality = 6;
BoolOption disable_audio = 7;
BoolOption disable_clipboard = 8;
}
message TestDelay {
int64 time = 1;
bool from_client = 2;
}
message PublicKey {
bytes asymmetric_value = 1;
bytes symmetric_value = 2;
}
message SignedId {
bytes id = 1;
bytes pk = 2;
}
message AudioFormat {
uint32 sample_rate = 1;
uint32 channels = 2;
}
message AudioFrame { bytes data = 1; }
message Misc {
oneof union {
ChatMessage chat_message = 4;
SwitchDisplay switch_display = 5;
PermissionInfo permission_info = 6;
OptionMessage option = 7;
AudioFormat audio_format = 8;
string close_reason = 9;
bool refresh_video = 10;
}
}
message Message {
oneof union {
SignedId signed_id = 3;
PublicKey public_key = 4;
TestDelay test_delay = 5;
VideoFrame video_frame = 6;
LoginRequest login_request = 7;
LoginResponse login_response = 8;
Hash hash = 9;
MouseEvent mouse_event = 10;
AudioFrame audio_frame = 11;
CursorData cursor_data = 12;
CursorPosition cursor_position = 13;
uint64 cursor_id = 14;
KeyEvent key_event = 15;
Clipboard clipboard = 16;
FileAction file_action = 17;
FileResponse file_response = 18;
Misc misc = 19;
}
}

View File

@ -0,0 +1,134 @@
syntax = "proto3";
package hbb;
message RegisterPeer {
string id = 1;
int32 serial = 2;
}
message RegisterPeerResponse { bool request_pk = 2; }
message PunchHoleRequest {
string id = 1;
NatType nat_type = 2;
}
message PunchHole {
bytes socket_addr = 1;
string relay_server = 2;
NatType nat_type = 3;
}
message TestNatRequest {
int32 serial = 1;
}
// per my test, uint/int has no difference in encoding, int not good for negative, use sint for negative
message TestNatResponse {
int32 port = 1;
ConfigUpdate cu = 2; // for mobile
}
enum NatType {
UNKNOWN_NAT = 0;
ASYMMETRIC = 1;
SYMMETRIC = 2;
}
message PunchHoleSent {
bytes socket_addr = 1;
string id = 2;
string relay_server = 3;
NatType nat_type = 4;
}
message RegisterPk {
string id = 1;
bytes uuid = 2;
bytes pk = 3;
}
message RegisterPkResponse {
enum Result {
OK = 1;
UUID_MISMATCH = 2;
}
Result result = 1;
}
message PunchHoleResponse {
bytes socket_addr = 1;
bytes pk = 2;
enum Failure {
ID_NOT_EXIST = 1;
OFFLINE = 2;
}
Failure failure = 3;
string relay_server = 4;
oneof union {
NatType nat_type = 5;
bool is_local = 6;
}
}
message ConfigUpdate {
int32 serial = 1;
repeated string rendezvous_servers = 2;
}
message RequestRelay {
string id = 1;
string uuid = 2;
bytes socket_addr = 3;
string relay_server = 4;
bool secure = 5;
}
message RelayResponse {
bytes socket_addr = 1;
string uuid = 2;
string relay_server = 3;
oneof union {
string id = 4;
bytes pk = 5;
}
}
message SoftwareUpdate { string url = 1; }
// if in same intranet, punch hole won't work both for udp and tcp,
// even some router has below connection error if we connect itself,
// { kind: Other, error: "could not resolve to any address" },
// so we request local address to connect.
message FetchLocalAddr {
bytes socket_addr = 1;
string relay_server = 2;
}
message LocalAddr {
bytes socket_addr = 1;
bytes local_addr = 2;
string relay_server = 3;
string id = 4;
}
message RendezvousMessage {
oneof union {
RegisterPeer register_peer = 6;
RegisterPeerResponse register_peer_response = 7;
PunchHoleRequest punch_hole_request = 8;
PunchHole punch_hole = 9;
PunchHoleSent punch_hole_sent = 10;
PunchHoleResponse punch_hole_response = 11;
FetchLocalAddr fetch_local_addr = 12;
LocalAddr local_addr = 13;
ConfigUpdate configure_update = 14;
RegisterPk register_pk = 15;
RegisterPkResponse register_pk_response = 16;
SoftwareUpdate software_update = 17;
RequestRelay request_relay = 18;
RelayResponse relay_response = 19;
TestNatRequest test_nat_request = 20;
TestNatResponse test_nat_response = 21;
}
}

View File

@ -0,0 +1,274 @@
use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::io;
use tokio_util::codec::{Decoder, Encoder};
#[derive(Debug, Clone, Copy)]
pub struct BytesCodec {
state: DecodeState,
raw: bool,
max_packet_length: usize,
}
#[derive(Debug, Clone, Copy)]
enum DecodeState {
Head,
Data(usize),
}
impl BytesCodec {
pub fn new() -> Self {
Self {
state: DecodeState::Head,
raw: false,
max_packet_length: usize::MAX,
}
}
pub fn set_raw(&mut self) {
self.raw = true;
}
pub fn set_max_packet_length(&mut self, n: usize) {
self.max_packet_length = n;
}
fn decode_head(&mut self, src: &mut BytesMut) -> io::Result<Option<usize>> {
if src.is_empty() {
return Ok(None);
}
let head_len = ((src[0] & 0x3) + 1) as usize;
if src.len() < head_len {
return Ok(None);
}
let mut n = src[0] as usize;
if head_len > 1 {
n |= (src[1] as usize) << 8;
}
if head_len > 2 {
n |= (src[2] as usize) << 16;
}
if head_len > 3 {
n |= (src[3] as usize) << 24;
}
n >>= 2;
if n > self.max_packet_length {
return Err(io::Error::new(io::ErrorKind::InvalidData, "Too big packet"));
}
src.advance(head_len);
src.reserve(n);
return Ok(Some(n));
}
fn decode_data(&self, n: usize, src: &mut BytesMut) -> io::Result<Option<BytesMut>> {
if src.len() < n {
return Ok(None);
}
Ok(Some(src.split_to(n)))
}
}
impl Decoder for BytesCodec {
type Item = BytesMut;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> {
if self.raw {
if !src.is_empty() {
let len = src.len();
return Ok(Some(src.split_to(len)));
} else {
return Ok(None);
}
}
let n = match self.state {
DecodeState::Head => match self.decode_head(src)? {
Some(n) => {
self.state = DecodeState::Data(n);
n
}
None => return Ok(None),
},
DecodeState::Data(n) => n,
};
match self.decode_data(n, src)? {
Some(data) => {
self.state = DecodeState::Head;
Ok(Some(data))
}
None => Ok(None),
}
}
}
impl Encoder<Bytes> for BytesCodec {
type Error = io::Error;
fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> {
if self.raw {
buf.reserve(data.len());
buf.put(data);
return Ok(());
}
if data.len() <= 0x3F {
buf.put_u8((data.len() << 2) as u8);
} else if data.len() <= 0x3FFF {
buf.put_u16_le((data.len() << 2) as u16 | 0x1);
} else if data.len() <= 0x3FFFFF {
let h = (data.len() << 2) as u32 | 0x2;
buf.put_u16_le((h & 0xFFFF) as u16);
buf.put_u8((h >> 16) as u8);
} else if data.len() <= 0x3FFFFFFF {
buf.put_u32_le((data.len() << 2) as u32 | 0x3);
} else {
return Err(io::Error::new(io::ErrorKind::InvalidInput, "Overflow"));
}
buf.extend(data);
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_codec1() {
let mut codec = BytesCodec::new();
let mut buf = BytesMut::new();
let mut bytes: Vec<u8> = Vec::new();
bytes.resize(0x3F, 1);
assert!(!codec.encode(bytes.into(), &mut buf).is_err());
let buf_saved = buf.clone();
assert_eq!(buf.len(), 0x3F + 1);
if let Ok(Some(res)) = codec.decode(&mut buf) {
assert_eq!(res.len(), 0x3F);
assert_eq!(res[0], 1);
} else {
assert!(false);
}
let mut codec2 = BytesCodec::new();
let mut buf2 = BytesMut::new();
if let Ok(None) = codec2.decode(&mut buf2) {
} else {
assert!(false);
}
buf2.extend(&buf_saved[0..1]);
if let Ok(None) = codec2.decode(&mut buf2) {
} else {
assert!(false);
}
buf2.extend(&buf_saved[1..]);
if let Ok(Some(res)) = codec2.decode(&mut buf2) {
assert_eq!(res.len(), 0x3F);
assert_eq!(res[0], 1);
} else {
assert!(false);
}
}
#[test]
fn test_codec2() {
let mut codec = BytesCodec::new();
let mut buf = BytesMut::new();
let mut bytes: Vec<u8> = Vec::new();
assert!(!codec.encode("".into(), &mut buf).is_err());
assert_eq!(buf.len(), 1);
bytes.resize(0x3F + 1, 2);
assert!(!codec.encode(bytes.into(), &mut buf).is_err());
assert_eq!(buf.len(), 0x3F + 2 + 2);
if let Ok(Some(res)) = codec.decode(&mut buf) {
assert_eq!(res.len(), 0);
} else {
assert!(false);
}
if let Ok(Some(res)) = codec.decode(&mut buf) {
assert_eq!(res.len(), 0x3F + 1);
assert_eq!(res[0], 2);
} else {
assert!(false);
}
}
#[test]
fn test_codec3() {
let mut codec = BytesCodec::new();
let mut buf = BytesMut::new();
let mut bytes: Vec<u8> = Vec::new();
bytes.resize(0x3F - 1, 3);
assert!(!codec.encode(bytes.into(), &mut buf).is_err());
assert_eq!(buf.len(), 0x3F + 1 - 1);
if let Ok(Some(res)) = codec.decode(&mut buf) {
assert_eq!(res.len(), 0x3F - 1);
assert_eq!(res[0], 3);
} else {
assert!(false);
}
}
#[test]
fn test_codec4() {
let mut codec = BytesCodec::new();
let mut buf = BytesMut::new();
let mut bytes: Vec<u8> = Vec::new();
bytes.resize(0x3FFF, 4);
assert!(!codec.encode(bytes.into(), &mut buf).is_err());
assert_eq!(buf.len(), 0x3FFF + 2);
if let Ok(Some(res)) = codec.decode(&mut buf) {
assert_eq!(res.len(), 0x3FFF);
assert_eq!(res[0], 4);
} else {
assert!(false);
}
}
#[test]
fn test_codec5() {
let mut codec = BytesCodec::new();
let mut buf = BytesMut::new();
let mut bytes: Vec<u8> = Vec::new();
bytes.resize(0x3FFFFF, 5);
assert!(!codec.encode(bytes.into(), &mut buf).is_err());
assert_eq!(buf.len(), 0x3FFFFF + 3);
if let Ok(Some(res)) = codec.decode(&mut buf) {
assert_eq!(res.len(), 0x3FFFFF);
assert_eq!(res[0], 5);
} else {
assert!(false);
}
}
#[test]
fn test_codec6() {
let mut codec = BytesCodec::new();
let mut buf = BytesMut::new();
let mut bytes: Vec<u8> = Vec::new();
bytes.resize(0x3FFFFF + 1, 6);
assert!(!codec.encode(bytes.into(), &mut buf).is_err());
let buf_saved = buf.clone();
assert_eq!(buf.len(), 0x3FFFFF + 4 + 1);
if let Ok(Some(res)) = codec.decode(&mut buf) {
assert_eq!(res.len(), 0x3FFFFF + 1);
assert_eq!(res[0], 6);
} else {
assert!(false);
}
let mut codec2 = BytesCodec::new();
let mut buf2 = BytesMut::new();
buf2.extend(&buf_saved[0..1]);
if let Ok(None) = codec2.decode(&mut buf2) {
} else {
assert!(false);
}
buf2.extend(&buf_saved[1..6]);
if let Ok(None) = codec2.decode(&mut buf2) {
} else {
assert!(false);
}
buf2.extend(&buf_saved[6..]);
if let Ok(Some(res)) = codec2.decode(&mut buf2) {
assert_eq!(res.len(), 0x3FFFFF + 1);
assert_eq!(res[0], 6);
} else {
assert!(false);
}
}
}

View File

@ -0,0 +1,50 @@
use std::cell::RefCell;
use zstd::block::{Compressor, Decompressor};
thread_local! {
static COMPRESSOR: RefCell<Compressor> = RefCell::new(Compressor::new());
static DECOMPRESSOR: RefCell<Decompressor> = RefCell::new(Decompressor::new());
}
/// The library supports regular compression levels from 1 up to ZSTD_maxCLevel(),
/// which is currently 22. Levels >= 20
/// Default level is ZSTD_CLEVEL_DEFAULT==3.
/// value 0 means default, which is controlled by ZSTD_CLEVEL_DEFAULT
pub fn compress(data: &[u8], level: i32) -> Vec<u8> {
let mut out = Vec::new();
COMPRESSOR.with(|c| {
if let Ok(mut c) = c.try_borrow_mut() {
match c.compress(data, level) {
Ok(res) => out = res,
Err(err) => {
crate::log::debug!("Failed to compress: {}", err);
}
}
}
});
out
}
pub fn decompress(data: &[u8]) -> Vec<u8> {
let mut out = Vec::new();
DECOMPRESSOR.with(|d| {
if let Ok(mut d) = d.try_borrow_mut() {
const MAX: usize = 1024 * 1024 * 64;
const MIN: usize = 1024 * 1024;
let mut n = 30 * data.len();
if n > MAX {
n = MAX;
}
if n < MIN {
n = MIN;
}
match d.decompress(data, n) {
Ok(res) => out = res,
Err(err) => {
crate::log::debug!("Failed to decompress: {}", err);
}
}
}
});
out
}

View File

@ -0,0 +1,698 @@
use crate::log;
use directories_next::ProjectDirs;
use rand::Rng;
use serde_derive::{Deserialize, Serialize};
use sodiumoxide::crypto::sign;
use std::{
collections::HashMap,
fs,
net::SocketAddr,
path::{Path, PathBuf},
sync::{Arc, Mutex, RwLock},
time::SystemTime,
};
pub const APP_NAME: &str = "RustDesk";
pub const BIND_INTERFACE: &str = "0.0.0.0";
pub const RENDEZVOUS_TIMEOUT: u64 = 12_000;
pub const CONNECT_TIMEOUT: u64 = 18_000;
pub const COMPRESS_LEVEL: i32 = 3;
const SERIAL: i32 = 0;
// 128x128
#[cfg(target_os = "macos")] // 128x128 on 160x160 canvas, then shrink to 128, mac looks better with padding
pub const ICON: &str = "
";
#[cfg(windows)] // windows, 32x32, bigger very ugly after shrink
pub const ICON: &str = "
";
#[cfg(target_os = "linux")] // 128x128 no padding
pub const ICON: &str = "
";
#[cfg(target_os = "macos")]
pub const ORG: &str = "com.carriez";
type Size = (i32, i32, i32, i32);
lazy_static::lazy_static! {
static ref CONFIG: Arc<RwLock<Config>> = Arc::new(RwLock::new(Config::load()));
static ref CONFIG2: Arc<RwLock<Config2>> = Arc::new(RwLock::new(Config2::load()));
pub static ref ONLINE: Arc<Mutex<HashMap<String, i64>>> = Default::default();
}
#[cfg(any(target_os = "android", target_os = "ios"))]
lazy_static::lazy_static! {
pub static ref APP_DIR: Arc<RwLock<String>> = Default::default();
}
const CHARS: &'static [char] = &[
'2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k',
'm', 'n', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
];
pub const RENDEZVOUS_SERVERS: &'static [&'static str] = &[
"rs-sg.rustdesk.com",
"rs-cn.rustdesk.com",
];
pub const RENDEZVOUS_PORT: i32 = 21116;
pub const RELAY_PORT: i32 = 21117;
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct Config {
#[serde(default)]
id: String,
#[serde(default)]
password: String,
#[serde(default)]
salt: String,
#[serde(default)]
key_pair: (Vec<u8>, Vec<u8>), // sk, pk
#[serde(default)]
key_confirmed: bool,
#[serde(default)]
keys_confirmed: HashMap<String, bool>,
}
// more variable configs
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct Config2 {
#[serde(default)]
remote_id: String, // latest used one
#[serde(default)]
size: Size,
#[serde(default)]
rendezvous_server: String,
#[serde(default)]
nat_type: i32,
#[serde(default)]
serial: i32,
// the other scalar value must before this
#[serde(default)]
pub options: HashMap<String, String>,
}
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct PeerConfig {
#[serde(default)]
pub password: Vec<u8>,
#[serde(default)]
pub size: Size,
#[serde(default)]
pub size_ft: Size,
#[serde(default)]
pub size_pf: Size,
#[serde(default)]
pub view_style: String, // original (default), scale
#[serde(default)]
pub image_quality: String,
#[serde(default)]
pub custom_image_quality: Vec<i32>,
#[serde(default)]
pub show_remote_cursor: bool,
#[serde(default)]
pub lock_after_session_end: bool,
#[serde(default)]
pub privacy_mode: bool,
#[serde(default)]
pub port_forwards: Vec<(i32, String, i32)>,
#[serde(default)]
pub direct_failures: i32,
#[serde(default)]
pub disable_audio: bool,
#[serde(default)]
pub disable_clipboard: bool,
// the other scalar value must before this
#[serde(default)]
pub options: HashMap<String, String>,
#[serde(default)]
pub info: PeerInfoSerde,
}
#[derive(Debug, PartialEq, Default, Serialize, Deserialize, Clone)]
pub struct PeerInfoSerde {
#[serde(default)]
pub username: String,
#[serde(default)]
pub hostname: String,
#[serde(default)]
pub platform: String,
}
fn patch(path: PathBuf) -> PathBuf {
if let Some(_tmp) = path.to_str() {
#[cfg(windows)]
return _tmp
.replace(
"system32\\config\\systemprofile",
"ServiceProfiles\\LocalService",
)
.into();
#[cfg(target_os = "macos")]
return _tmp.replace("Application Support", "Preferences").into();
#[cfg(target_os = "linux")]
{
if _tmp == "/root" {
if let Ok(output) = std::process::Command::new("whoami").output() {
let user = String::from_utf8_lossy(&output.stdout).to_string().trim().to_owned();
if user != "root" {
return format!("/home/{}", user).into();
}
}
}
}
}
path
}
impl Config2 {
fn load() -> Config2 {
Config::load_::<Config2>("2")
}
fn store(&self) {
Config::store_(self, "2");
}
}
impl Config {
fn load_<T: serde::Serialize + serde::de::DeserializeOwned + Default + std::fmt::Debug>(
suffix: &str,
) -> T {
let file = Self::file_(suffix);
log::debug!("Configuration path: {}", file.display());
let cfg = match confy::load_path(&file) {
Ok(config) => config,
Err(err) => {
log::error!("Failed to load config: {}", err);
T::default()
}
};
if suffix.is_empty() {
log::debug!("{:?}", cfg);
}
cfg
}
fn store_<T: serde::Serialize>(config: &T, suffix: &str) {
let file = Self::file_(suffix);
if let Err(err) = confy::store_path(file, config) {
log::error!("Failed to store config: {}", err);
}
}
fn load() -> Config {
Config::load_::<Config>("")
}
fn store(&self) {
Config::store_(self, "");
}
pub fn file() -> PathBuf {
Self::file_("")
}
pub fn import(from: &str) {
log::info!("import {}", from);
// load first to create path
Self::load();
crate::allow_err!(std::fs::copy(from, Self::file()));
crate::allow_err!(std::fs::copy(
from.replace(".toml", "2.toml"),
Self::file_("2")
));
}
pub fn save_tmp() -> String {
let _lock = CONFIG.read().unwrap(); // do not use let _, which will be dropped immediately
let path = Self::file_("2").to_str().unwrap_or("").to_owned();
let path2 = format!("{}_tmp", path);
crate::allow_err!(std::fs::copy(&path, &path2));
let path = Self::file().to_str().unwrap_or("").to_owned();
let path2 = format!("{}_tmp", path);
crate::allow_err!(std::fs::copy(&path, &path2));
path2
}
fn file_(suffix: &str) -> PathBuf {
let name = format!("{}{}", APP_NAME, suffix);
Self::path(name).with_extension("toml")
}
pub fn get_home() -> PathBuf {
#[cfg(any(target_os = "android", target_os = "ios"))]
return Self::path("");
if let Some(path) = dirs_next::home_dir() {
patch(path)
} else if let Ok(path) = std::env::current_dir() {
path
} else {
std::env::temp_dir()
}
}
fn path<P: AsRef<Path>>(p: P) -> PathBuf {
#[cfg(any(target_os = "android", target_os = "ios"))]
{
let mut path: PathBuf = APP_DIR.read().unwrap().clone().into();
path.push(p);
return path;
}
#[cfg(not(target_os = "macos"))]
let org = "";
#[cfg(target_os = "macos")]
let org = ORG;
// /var/root for root
if let Some(project) = ProjectDirs::from("", org, APP_NAME) {
let mut path = patch(project.config_dir().to_path_buf());
path.push(p);
return path;
}
return "".into();
}
pub fn log_path() -> PathBuf {
#[cfg(target_os = "macos")]
{
if let Some(path) = dirs_next::home_dir().as_mut() {
path.push(format!("Library/Logs/{}", APP_NAME));
return path.clone();
}
}
#[cfg(target_os = "linux")]
{
let mut path = Self::get_home();
path.push(format!(".local/share/logs/{}", APP_NAME));
std::fs::create_dir_all(&path).ok();
return path;
}
if let Some(path) = Self::path("").parent() {
let mut path: PathBuf = path.into();
path.push("log");
return path;
}
"".into()
}
pub fn ipc_path(postfix: &str) -> String {
#[cfg(windows)]
{
// \\ServerName\pipe\PipeName
// where ServerName is either the name of a remote computer or a period, to specify the local computer.
// https://docs.microsoft.com/en-us/windows/win32/ipc/pipe-names
format!("\\\\.\\pipe\\{}\\query{}", APP_NAME, postfix)
}
#[cfg(not(windows))]
{
use std::os::unix::fs::PermissionsExt;
let mut path: PathBuf = format!("/tmp/{}", APP_NAME).into();
fs::create_dir(&path).ok();
fs::set_permissions(&path, fs::Permissions::from_mode(0o0777)).ok();
path.push(format!("ipc{}", postfix));
path.to_str().unwrap_or("").to_owned()
}
}
pub fn icon_path() -> PathBuf {
let mut path = Self::path("icons");
if fs::create_dir_all(&path).is_err() {
path = std::env::temp_dir();
}
path
}
#[inline]
pub fn get_any_listen_addr() -> SocketAddr {
format!("{}:0", BIND_INTERFACE).parse().unwrap()
}
pub fn get_rendezvous_server() -> SocketAddr {
let mut rendezvous_server = Self::get_option("custom-rendezvous-server");
if rendezvous_server.is_empty() {
rendezvous_server = CONFIG2.write().unwrap().rendezvous_server.clone();
}
if rendezvous_server.is_empty() {
rendezvous_server = Self::get_rendezvous_servers()
.drain(..)
.next()
.unwrap_or("".to_owned());
}
if !rendezvous_server.contains(":") {
rendezvous_server = format!("{}:{}", rendezvous_server, RENDEZVOUS_PORT);
}
if let Ok(addr) = crate::to_socket_addr(&rendezvous_server) {
addr
} else {
Self::get_any_listen_addr()
}
}
pub fn get_rendezvous_servers() -> Vec<String> {
let s = Self::get_option("custom-rendezvous-server");
if !s.is_empty() {
return vec![s];
}
let serial_obsolute = CONFIG2.read().unwrap().serial > SERIAL;
if serial_obsolute {
let ss: Vec<String> = Self::get_option("rendezvous-servers")
.split(",")
.filter(|x| x.contains("."))
.map(|x| x.to_owned())
.collect();
if !ss.is_empty() {
return ss;
}
}
return RENDEZVOUS_SERVERS.iter().map(|x| x.to_string()).collect();
}
pub fn reset_online() {
*ONLINE.lock().unwrap() = Default::default();
}
pub fn update_latency(host: &str, latency: i64) {
ONLINE.lock().unwrap().insert(host.to_owned(), latency);
let mut host = "".to_owned();
let mut delay = i64::MAX;
for (tmp_host, tmp_delay) in ONLINE.lock().unwrap().iter() {
if tmp_delay > &0 && tmp_delay < &delay {
delay = tmp_delay.clone();
host = tmp_host.to_string();
}
}
if !host.is_empty() {
let mut config = CONFIG2.write().unwrap();
if host != config.rendezvous_server {
log::debug!("Update rendezvous_server in config to {}", host);
log::debug!("{:?}", *ONLINE.lock().unwrap());
config.rendezvous_server = host;
config.store();
}
}
}
pub fn set_id(id: &str) {
let mut config = CONFIG.write().unwrap();
if id == config.id {
return;
}
config.id = id.into();
config.store();
}
pub fn set_nat_type(nat_type: i32) {
let mut config = CONFIG2.write().unwrap();
if nat_type == config.nat_type {
return;
}
config.nat_type = nat_type;
config.store();
}
pub fn get_nat_type() -> i32 {
CONFIG2.read().unwrap().nat_type
}
pub fn set_serial(serial: i32) {
let mut config = CONFIG2.write().unwrap();
if serial == config.serial {
return;
}
config.serial = serial;
config.store();
}
pub fn get_serial() -> i32 {
std::cmp::max(CONFIG2.read().unwrap().serial, SERIAL)
}
fn get_auto_id() -> Option<String> {
#[cfg(any(target_os = "android", target_os = "ios"))]
return None;
let mut id = 0u32;
#[cfg(not(any(target_os = "android", target_os = "ios")))]
if let Ok(Some(ma)) = mac_address::get_mac_address() {
for x in &ma.bytes()[2..] {
id = (id << 8) | (*x as u32);
}
id = id & 0x1FFFFFFF;
Some(id.to_string())
} else {
None
}
}
pub fn get_auto_password() -> String {
let mut rng = rand::thread_rng();
(0..6)
.map(|_| CHARS[rng.gen::<usize>() % CHARS.len()])
.collect()
}
pub fn get_key_confirmed() -> bool {
CONFIG.read().unwrap().key_confirmed
}
pub fn set_key_confirmed(v: bool) {
let mut config = CONFIG.write().unwrap();
if config.key_confirmed == v {
return;
}
config.key_confirmed = v;
if !v {
config.keys_confirmed = Default::default();
}
config.store();
}
pub fn get_host_key_confirmed(host: &str) -> bool {
if let Some(true) = CONFIG.read().unwrap().keys_confirmed.get(host) {
true
} else {
false
}
}
pub fn set_host_key_confirmed(host: &str, v: bool) {
if Self::get_host_key_confirmed(host) == v {
return;
}
let mut config = CONFIG.write().unwrap();
config.keys_confirmed.insert(host.to_owned(), v);
config.store();
}
pub fn set_key_pair(pair: (Vec<u8>, Vec<u8>)) {
let mut config = CONFIG.write().unwrap();
config.key_pair = pair;
config.store();
}
pub fn get_key_pair() -> (Vec<u8>, Vec<u8>) {
// lock here to make sure no gen_keypair more than once
let mut config = CONFIG.write().unwrap();
if config.key_pair.0.is_empty() {
let (pk, sk) = sign::gen_keypair();
config.key_pair = (sk.0.to_vec(), pk.0.into());
config.store();
}
config.key_pair.clone()
}
pub fn get_id() -> String {
let mut id = CONFIG.read().unwrap().id.clone();
if id.is_empty() {
if let Some(tmp) = Config::get_auto_id() {
id = tmp;
Config::set_id(&id);
}
}
id
}
pub fn get_options() -> HashMap<String, String> {
CONFIG2.read().unwrap().options.clone()
}
pub fn set_options(v: HashMap<String, String>) {
let mut config = CONFIG2.write().unwrap();
config.options = v;
config.store();
}
pub fn get_option(k: &str) -> String {
if let Some(v) = CONFIG2.read().unwrap().options.get(k) {
v.clone()
} else {
"".to_owned()
}
}
pub fn set_option(k: String, v: String) {
let mut config = CONFIG2.write().unwrap();
if k == "custom-rendezvous-server" {
config.rendezvous_server = "".to_owned();
}
let v2 = if v.is_empty() { None } else { Some(&v) };
if v2 != config.options.get(&k) {
if v2.is_none() {
config.options.remove(&k);
} else {
config.options.insert(k, v);
}
config.store();
}
}
pub fn update_id() {
// to-do: how about if one ip register a lot of ids?
let id = Self::get_id();
let mut rng = rand::thread_rng();
let new_id = rng.gen_range(1_000_000_000, 2_000_000_000).to_string();
Config::set_id(&new_id);
log::info!("id updated from {} to {}", id, new_id);
}
pub fn set_password(password: &str) {
let mut config = CONFIG.write().unwrap();
if password == config.password {
return;
}
config.password = password.into();
config.store();
}
pub fn get_password() -> String {
let mut password = CONFIG.read().unwrap().password.clone();
if password.is_empty() {
password = Config::get_auto_password();
Config::set_password(&password);
}
password
}
pub fn set_salt(salt: &str) {
let mut config = CONFIG.write().unwrap();
if salt == config.salt {
return;
}
config.salt = salt.into();
config.store();
}
pub fn get_salt() -> String {
let mut salt = CONFIG.read().unwrap().salt.clone();
if salt.is_empty() {
salt = Config::get_auto_password();
Config::set_salt(&salt);
}
salt
}
pub fn get_size() -> Size {
CONFIG2.read().unwrap().size
}
pub fn set_size(x: i32, y: i32, w: i32, h: i32) {
let mut config = CONFIG2.write().unwrap();
let size = (x, y, w, h);
if size == config.size || size.2 < 300 || size.3 < 300 {
return;
}
config.size = size;
config.store();
}
pub fn set_remote_id(remote_id: &str) {
let mut config = CONFIG2.write().unwrap();
if remote_id == config.remote_id {
return;
}
config.remote_id = remote_id.into();
config.store();
}
pub fn get_remote_id() -> String {
CONFIG2.read().unwrap().remote_id.clone()
}
}
const PEERS: &str = "peers";
impl PeerConfig {
pub fn load(id: &str) -> PeerConfig {
let _ = CONFIG.read().unwrap(); // for lock
match confy::load_path(&Self::path(id)) {
Ok(config) => config,
Err(err) => {
log::error!("Failed to load config: {}", err);
Default::default()
}
}
}
pub fn store(&self, id: &str) {
let _ = CONFIG.read().unwrap(); // for lock
if let Err(err) = confy::store_path(Self::path(id), self) {
log::error!("Failed to store config: {}", err);
}
}
pub fn remove(id: &str) {
fs::remove_file(&Self::path(id)).ok();
}
fn path(id: &str) -> PathBuf {
let path: PathBuf = [PEERS, id].iter().collect();
Config::path(path).with_extension("toml")
}
pub fn peers() -> Vec<(String, SystemTime, PeerInfoSerde)> {
if let Ok(peers) = Config::path(PEERS).read_dir() {
if let Ok(peers) = peers
.map(|res| res.map(|e| e.path()))
.collect::<Result<Vec<_>, _>>()
{
let mut peers: Vec<_> = peers
.iter()
.filter(|p| {
p.is_file()
&& p.extension().map(|p| p.to_str().unwrap_or("")) == Some("toml")
})
.map(|p| {
let t = fs::metadata(p)
.map(|m| m.modified().unwrap_or(SystemTime::UNIX_EPOCH))
.unwrap_or(SystemTime::UNIX_EPOCH);
let id = p
.file_stem()
.map(|p| p.to_str().unwrap_or(""))
.unwrap_or("")
.to_owned();
let info = PeerConfig::load(&id).info;
if info.platform.is_empty() {
fs::remove_file(&p).ok();
}
(id, t, info)
})
.filter(|p| !p.2.platform.is_empty())
.collect();
peers.sort_unstable_by(|a, b| b.1.cmp(&a.1));
return peers;
}
}
Default::default()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize() {
let cfg: Config = Default::default();
let res = toml::to_string_pretty(&cfg);
assert!(res.is_ok());
let cfg: PeerConfig = Default::default();
let res = toml::to_string_pretty(&cfg);
assert!(res.is_ok());
}
}

554
libs/hbb_common/src/fs.rs Normal file
View File

@ -0,0 +1,554 @@
use crate::{bail, message_proto::*, ResultType};
use std::path::{Path, PathBuf};
// https://doc.rust-lang.org/std/os/windows/fs/trait.MetadataExt.html
use crate::{
compress::{compress, decompress},
config::{Config, COMPRESS_LEVEL},
};
#[cfg(windows)]
use std::os::windows::prelude::*;
use tokio::{fs::File, prelude::*};
pub fn read_dir(path: &PathBuf, include_hidden: bool) -> ResultType<FileDirectory> {
let mut dir = FileDirectory {
path: get_string(&path),
..Default::default()
};
#[cfg(windows)]
if "/" == &get_string(&path) {
let drives = unsafe { winapi::um::fileapi::GetLogicalDrives() };
for i in 0..32 {
if drives & (1 << i) != 0 {
let name = format!(
"{}:",
std::char::from_u32('A' as u32 + i as u32).unwrap_or('A')
);
dir.entries.push(FileEntry {
name,
entry_type: FileType::DirDrive.into(),
..Default::default()
});
}
}
return Ok(dir);
}
for entry in path.read_dir()? {
if let Ok(entry) = entry {
let p = entry.path();
let name = p
.file_name()
.map(|p| p.to_str().unwrap_or(""))
.unwrap_or("")
.to_owned();
if name.is_empty() {
continue;
}
let mut is_hidden = false;
let meta;
if let Ok(tmp) = std::fs::symlink_metadata(&p) {
meta = tmp;
} else {
continue;
}
// docs.microsoft.com/en-us/windows/win32/fileio/file-attribute-constants
#[cfg(windows)]
if meta.file_attributes() & 0x2 != 0 {
is_hidden = true;
}
#[cfg(not(windows))]
if name.find('.').unwrap_or(usize::MAX) == 0 {
is_hidden = true;
}
if is_hidden && !include_hidden {
continue;
}
let (entry_type, size) = {
if p.is_dir() {
if meta.file_type().is_symlink() {
(FileType::DirLink.into(), 0)
} else {
(FileType::Dir.into(), 0)
}
} else {
if meta.file_type().is_symlink() {
(FileType::FileLink.into(), 0)
} else {
(FileType::File.into(), meta.len())
}
}
};
let modified_time = meta
.modified()
.map(|x| {
x.duration_since(std::time::SystemTime::UNIX_EPOCH)
.map(|x| x.as_secs())
.unwrap_or(0)
})
.unwrap_or(0) as u64;
dir.entries.push(FileEntry {
name: get_file_name(&p),
entry_type,
is_hidden,
size,
modified_time,
..Default::default()
});
}
}
Ok(dir)
}
#[inline]
pub fn get_file_name(p: &PathBuf) -> String {
p.file_name()
.map(|p| p.to_str().unwrap_or(""))
.unwrap_or("")
.to_owned()
}
#[inline]
pub fn get_string(path: &PathBuf) -> String {
path.to_str().unwrap_or("").to_owned()
}
#[inline]
pub fn get_path(path: &str) -> PathBuf {
Path::new(path).to_path_buf()
}
#[inline]
pub fn get_home_as_string() -> String {
get_string(&Config::get_home())
}
fn read_dir_recursive(
path: &PathBuf,
prefix: &PathBuf,
include_hidden: bool,
) -> ResultType<Vec<FileEntry>> {
let mut files = Vec::new();
if path.is_dir() {
// to-do: symbol link handling, cp the link rather than the content
// to-do: file mode, for unix
let fd = read_dir(&path, include_hidden)?;
for entry in fd.entries.iter() {
match entry.entry_type.enum_value() {
Ok(FileType::File) => {
let mut entry = entry.clone();
entry.name = get_string(&prefix.join(entry.name));
files.push(entry);
}
Ok(FileType::Dir) => {
if let Ok(mut tmp) = read_dir_recursive(
&path.join(&entry.name),
&prefix.join(&entry.name),
include_hidden,
) {
for entry in tmp.drain(0..) {
files.push(entry);
}
}
}
_ => {}
}
}
Ok(files)
} else if path.is_file() {
let (size, modified_time) = if let Ok(meta) = std::fs::metadata(&path) {
(
meta.len(),
meta.modified()
.map(|x| {
x.duration_since(std::time::SystemTime::UNIX_EPOCH)
.map(|x| x.as_secs())
.unwrap_or(0)
})
.unwrap_or(0) as u64,
)
} else {
(0, 0)
};
files.push(FileEntry {
entry_type: FileType::File.into(),
size,
modified_time,
..Default::default()
});
Ok(files)
} else {
bail!("Not exists");
}
}
pub fn get_recursive_files(path: &str, include_hidden: bool) -> ResultType<Vec<FileEntry>> {
read_dir_recursive(&get_path(path), &get_path(""), include_hidden)
}
#[derive(Default)]
pub struct TransferJob {
id: i32,
path: PathBuf,
files: Vec<FileEntry>,
file_num: i32,
file: Option<File>,
total_size: u64,
finished_size: u64,
transfered: u64,
}
#[inline]
fn get_ext(name: &str) -> &str {
if let Some(i) = name.rfind(".") {
return &name[i + 1..];
}
""
}
#[inline]
fn is_compressed_file(name: &str) -> bool {
let ext = get_ext(name);
ext == "xz"
|| ext == "gz"
|| ext == "zip"
|| ext == "7z"
|| ext == "rar"
|| ext == "bz2"
|| ext == "tgz"
|| ext == "png"
|| ext == "jpg"
}
impl TransferJob {
pub fn new_write(id: i32, path: String, files: Vec<FileEntry>) -> Self {
let total_size = files.iter().map(|x| x.size as u64).sum();
Self {
id,
path: get_path(&path),
files,
total_size,
..Default::default()
}
}
pub fn new_read(id: i32, path: String, include_hidden: bool) -> ResultType<Self> {
let files = get_recursive_files(&path, include_hidden)?;
let total_size = files.iter().map(|x| x.size as u64).sum();
Ok(Self {
id,
path: get_path(&path),
files,
total_size,
..Default::default()
})
}
#[inline]
pub fn files(&self) -> &Vec<FileEntry> {
&self.files
}
#[inline]
pub fn set_files(&mut self, files: Vec<FileEntry>) {
self.files = files;
}
#[inline]
pub fn id(&self) -> i32 {
self.id
}
#[inline]
pub fn total_size(&self) -> u64 {
self.total_size
}
#[inline]
pub fn finished_size(&self) -> u64 {
self.finished_size
}
#[inline]
pub fn transfered(&self) -> u64 {
self.transfered
}
#[inline]
pub fn file_num(&self) -> i32 {
self.file_num
}
pub fn modify_time(&self) {
let file_num = self.file_num as usize;
if file_num < self.files.len() {
let entry = &self.files[file_num];
let path = self.join(&entry.name);
let download_path = format!("{}.download", get_string(&path));
std::fs::rename(&download_path, &path).ok();
filetime::set_file_mtime(
&path,
filetime::FileTime::from_unix_time(entry.modified_time as _, 0),
)
.ok();
}
}
pub fn remove_download_file(&self) {
let file_num = self.file_num as usize;
if file_num < self.files.len() {
let entry = &self.files[file_num];
let path = self.join(&entry.name);
let download_path = format!("{}.download", get_string(&path));
std::fs::remove_file(&download_path).ok();
}
}
pub async fn write(&mut self, block: FileTransferBlock) -> ResultType<()> {
if block.id != self.id {
bail!("Wrong id");
}
let file_num = block.file_num as usize;
if file_num >= self.files.len() {
bail!("Wrong file number");
}
if file_num != self.file_num as usize || self.file.is_none() {
self.modify_time();
if let Some(file) = self.file.as_mut() {
file.sync_all().await?;
}
self.file_num = block.file_num;
let entry = &self.files[file_num];
let path = self.join(&entry.name);
if let Some(p) = path.parent() {
std::fs::create_dir_all(p).ok();
}
let path = format!("{}.download", get_string(&path));
self.file = Some(File::create(&path).await?);
}
if block.compressed {
let tmp = decompress(&block.data);
self.file.as_mut().unwrap().write_all(&tmp).await?;
self.finished_size += tmp.len() as u64;
} else {
self.file.as_mut().unwrap().write_all(&block.data).await?;
self.finished_size += block.data.len() as u64;
}
self.transfered += block.data.len() as u64;
Ok(())
}
#[inline]
fn join(&self, name: &str) -> PathBuf {
if name.is_empty() {
self.path.clone()
} else {
self.path.join(name)
}
}
pub async fn read(&mut self) -> ResultType<Option<FileTransferBlock>> {
let file_num = self.file_num as usize;
if file_num >= self.files.len() {
self.file.take();
return Ok(None);
}
let name = &self.files[file_num].name;
if self.file.is_none() {
match File::open(self.join(&name)).await {
Ok(file) => {
self.file = Some(file);
}
Err(err) => {
self.file_num += 1;
return Err(err.into());
}
}
}
const BUF_SIZE: usize = 128 * 1024;
let mut buf: Vec<u8> = Vec::with_capacity(BUF_SIZE);
unsafe {
buf.set_len(BUF_SIZE);
}
let mut compressed = false;
let mut offset: usize = 0;
loop {
match self.file.as_mut().unwrap().read(&mut buf[offset..]).await {
Err(err) => {
self.file_num += 1;
self.file = None;
return Err(err.into());
}
Ok(n) => {
offset += n;
if n == 0 || offset == BUF_SIZE {
break;
}
}
}
}
unsafe { buf.set_len(offset) };
if offset == 0 {
self.file_num += 1;
self.file = None;
} else {
self.finished_size += offset as u64;
if !is_compressed_file(name) {
let tmp = compress(&buf, COMPRESS_LEVEL);
if tmp.len() < buf.len() {
buf = tmp;
compressed = true;
}
}
self.transfered += buf.len() as u64;
}
Ok(Some(FileTransferBlock {
id: self.id,
file_num: file_num as _,
data: buf.into(),
compressed,
..Default::default()
}))
}
}
#[inline]
pub fn new_error<T: std::string::ToString>(id: i32, err: T, file_num: i32) -> Message {
let mut resp = FileResponse::new();
resp.set_error(FileTransferError {
id,
error: err.to_string(),
file_num,
..Default::default()
});
let mut msg_out = Message::new();
msg_out.set_file_response(resp);
msg_out
}
#[inline]
pub fn new_dir(id: i32, files: Vec<FileEntry>) -> Message {
let mut resp = FileResponse::new();
resp.set_dir(FileDirectory {
id,
entries: files.into(),
..Default::default()
});
let mut msg_out = Message::new();
msg_out.set_file_response(resp);
msg_out
}
#[inline]
pub fn new_block(block: FileTransferBlock) -> Message {
let mut resp = FileResponse::new();
resp.set_block(block);
let mut msg_out = Message::new();
msg_out.set_file_response(resp);
msg_out
}
#[inline]
pub fn new_receive(id: i32, path: String, files: Vec<FileEntry>) -> Message {
let mut action = FileAction::new();
action.set_receive(FileTransferReceiveRequest {
id,
path,
files: files.into(),
..Default::default()
});
let mut msg_out = Message::new();
msg_out.set_file_action(action);
msg_out
}
#[inline]
pub fn new_send(id: i32, path: String, include_hidden: bool) -> Message {
let mut action = FileAction::new();
action.set_send(FileTransferSendRequest {
id,
path,
include_hidden,
..Default::default()
});
let mut msg_out = Message::new();
msg_out.set_file_action(action);
msg_out
}
#[inline]
pub fn new_done(id: i32, file_num: i32) -> Message {
let mut resp = FileResponse::new();
resp.set_done(FileTransferDone {
id,
file_num,
..Default::default()
});
let mut msg_out = Message::new();
msg_out.set_file_response(resp);
msg_out
}
#[inline]
pub fn remove_job(id: i32, jobs: &mut Vec<TransferJob>) {
*jobs = jobs.drain(0..).filter(|x| x.id() != id).collect();
}
#[inline]
pub fn get_job(id: i32, jobs: &mut Vec<TransferJob>) -> Option<&mut TransferJob> {
jobs.iter_mut().filter(|x| x.id() == id).next()
}
pub async fn handle_read_jobs(
jobs: &mut Vec<TransferJob>,
stream: &mut crate::Stream,
) -> ResultType<()> {
let mut finished = Vec::new();
for job in jobs.iter_mut() {
match job.read().await {
Err(err) => {
stream
.send(&new_error(job.id(), err, job.file_num()))
.await?;
}
Ok(Some(block)) => {
stream.send(&new_block(block)).await?;
}
Ok(None) => {
finished.push(job.id());
stream.send(&new_done(job.id(), job.file_num())).await?;
}
}
}
for id in finished {
remove_job(id, jobs);
}
Ok(())
}
pub fn remove_all_empty_dir(path: &PathBuf) -> ResultType<()> {
let fd = read_dir(path, true)?;
for entry in fd.entries.iter() {
match entry.entry_type.enum_value() {
Ok(FileType::Dir) => {
remove_all_empty_dir(&path.join(&entry.name)).ok();
}
Ok(FileType::DirLink) | Ok(FileType::FileLink) => {
std::fs::remove_file(&path.join(&entry.name)).ok();
}
_ => {}
}
}
std::fs::remove_dir(path).ok();
Ok(())
}
#[inline]
pub fn remove_file(file: &str) -> ResultType<()> {
std::fs::remove_file(get_path(file))?;
Ok(())
}
#[inline]
pub fn create_dir(dir: &str) -> ResultType<()> {
std::fs::create_dir_all(get_path(dir))?;
Ok(())
}

215
libs/hbb_common/src/lib.rs Normal file
View File

@ -0,0 +1,215 @@
pub mod compress;
#[path = "./protos/message.rs"]
pub mod message_proto;
#[path = "./protos/rendezvous.rs"]
pub mod rendezvous_proto;
pub use bytes;
pub use futures;
pub use protobuf;
use socket2::{Domain, Socket, Type};
use std::{
fs::File,
io::{self, BufRead},
net::{Ipv4Addr, SocketAddr, SocketAddrV4, ToSocketAddrs},
path::Path,
time::{self, SystemTime, UNIX_EPOCH},
};
pub use tokio;
pub use tokio_util;
pub mod tcp;
pub mod udp;
pub use env_logger;
pub use log;
pub mod bytes_codec;
#[cfg(feature = "quic")]
pub mod quic;
pub use anyhow::{self, bail};
pub use futures_util;
pub mod config;
pub mod fs;
pub use sodiumoxide;
#[cfg(feature = "quic")]
pub type Stream = quic::Connection;
#[cfg(not(feature = "quic"))]
pub type Stream = tcp::FramedStream;
#[inline]
pub async fn sleep(sec: f32) {
tokio::time::delay_for(time::Duration::from_secs_f32(sec)).await;
}
#[macro_export]
macro_rules! allow_err {
($e:expr) => {
if let Err(err) = $e {
log::debug!(
"{:?}, {}:{}:{}:{}",
err,
module_path!(),
file!(),
line!(),
column!()
);
} else {
}
};
}
#[inline]
pub fn timeout<T: std::future::Future>(ms: u64, future: T) -> tokio::time::Timeout<T> {
tokio::time::timeout(std::time::Duration::from_millis(ms), future)
}
fn new_socket(addr: SocketAddr, tcp: bool, reuse: bool) -> Result<Socket, std::io::Error> {
let stype = {
if tcp {
Type::stream()
} else {
Type::dgram()
}
};
let socket = match addr {
SocketAddr::V4(..) => Socket::new(Domain::ipv4(), stype, None),
SocketAddr::V6(..) => Socket::new(Domain::ipv6(), stype, None),
}?;
if reuse {
// windows has no reuse_port, but it's reuse_address
// almost equals to unix's reuse_port + reuse_address,
// though may introduce nondeterministic bahavior
#[cfg(unix)]
socket.set_reuse_port(true)?;
socket.set_reuse_address(true)?;
}
socket.bind(&addr.into())?;
Ok(socket)
}
pub type ResultType<F, E = anyhow::Error> = anyhow::Result<F, E>;
/// Certain router and firewalls scan the packet and if they
/// find an IP address belonging to their pool that they use to do the NAT mapping/translation, so here we mangle the ip address
pub struct AddrMangle();
impl AddrMangle {
pub fn encode(addr: SocketAddr) -> Vec<u8> {
match addr {
SocketAddr::V4(addr_v4) => {
let tm = (SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_micros() as u32) as u128;
let ip = u32::from_ne_bytes(addr_v4.ip().octets()) as u128;
let port = addr.port() as u128;
let v = ((ip + tm) << 49) | (tm << 17) | (port + (tm & 0xFFFF));
let bytes = v.to_ne_bytes();
let mut n_padding = 0;
for i in bytes.iter().rev() {
if i == &0u8 {
n_padding += 1;
} else {
break;
}
}
bytes[..(16 - n_padding)].to_vec()
}
_ => {
panic!("Only support ipv4");
}
}
}
pub fn decode(bytes: &[u8]) -> SocketAddr {
let mut padded = [0u8; 16];
padded[..bytes.len()].copy_from_slice(&bytes);
let number = u128::from_ne_bytes(padded);
let tm = (number >> 17) & (u32::max_value() as u128);
let ip = (((number >> 49) - tm) as u32).to_ne_bytes();
let port = (number & 0xFFFFFF) - (tm & 0xFFFF);
SocketAddr::V4(SocketAddrV4::new(
Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]),
port as u16,
))
}
}
pub fn get_version_from_url(url: &str) -> String {
let n = url.chars().count();
let a = url
.chars()
.rev()
.enumerate()
.filter(|(_, x)| x == &'-')
.next()
.map(|(i, _)| i);
if let Some(a) = a {
let b = url
.chars()
.rev()
.enumerate()
.filter(|(_, x)| x == &'.')
.next()
.map(|(i, _)| i);
if let Some(b) = b {
if a > b {
if url
.chars()
.skip(n - b)
.collect::<String>()
.parse::<i32>()
.is_ok()
{
return url.chars().skip(n - a).collect();
} else {
return url.chars().skip(n - a).take(a - b - 1).collect();
}
} else {
return url.chars().skip(n - a).collect();
}
}
}
"".to_owned()
}
pub fn to_socket_addr(host: &str) -> ResultType<SocketAddr> {
let addrs: Vec<SocketAddr> = host.to_socket_addrs()?.collect();
if addrs.is_empty() {
bail!("Failed to solve {}", host);
}
Ok(addrs[0])
}
pub fn gen_version() {
let mut file = File::create("./src/version.rs").unwrap();
for line in read_lines("Cargo.toml").unwrap() {
if let Ok(line) = line {
let ab: Vec<&str> = line.split("=").map(|x| x.trim()).collect();
if ab.len() == 2 && ab[0] == "version" {
use std::io::prelude::*;
file.write_all(format!("pub const VERSION: &str = {};", ab[1]).as_bytes())
.ok();
file.sync_all().ok();
break;
}
}
}
}
fn read_lines<P>(filename: P) -> io::Result<io::Lines<io::BufReader<File>>>
where
P: AsRef<Path>,
{
let file = File::open(filename)?;
Ok(io::BufReader::new(file).lines())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mangle() {
let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 16, 32), 21116));
assert_eq!(addr, AddrMangle::decode(&AddrMangle::encode(addr)));
}
}

135
libs/hbb_common/src/quic.rs Normal file
View File

@ -0,0 +1,135 @@
use crate::{allow_err, anyhow::anyhow, ResultType};
use protobuf::Message;
use std::{net::SocketAddr, sync::Arc};
use tokio::{self, stream::StreamExt, sync::mpsc};
const QUIC_HBB: &[&[u8]] = &[b"hbb"];
const SERVER_NAME: &str = "hbb";
type Sender = mpsc::UnboundedSender<Value>;
type Receiver = mpsc::UnboundedReceiver<Value>;
pub fn new_server(socket: std::net::UdpSocket) -> ResultType<(Server, SocketAddr)> {
let mut transport_config = quinn::TransportConfig::default();
transport_config.stream_window_uni(0);
let mut server_config = quinn::ServerConfig::default();
server_config.transport = Arc::new(transport_config);
let mut server_config = quinn::ServerConfigBuilder::new(server_config);
server_config.protocols(QUIC_HBB);
// server_config.enable_keylog();
// server_config.use_stateless_retry(true);
let mut endpoint = quinn::Endpoint::builder();
endpoint.listen(server_config.build());
let (end, incoming) = endpoint.with_socket(socket)?;
Ok((Server { incoming }, end.local_addr()?))
}
pub async fn new_client(local_addr: &SocketAddr, peer: &SocketAddr) -> ResultType<Connection> {
let mut endpoint = quinn::Endpoint::builder();
let mut client_config = quinn::ClientConfigBuilder::default();
client_config.protocols(QUIC_HBB);
//client_config.enable_keylog();
endpoint.default_client_config(client_config.build());
let (endpoint, _) = endpoint.bind(local_addr)?;
let new_conn = endpoint.connect(peer, SERVER_NAME)?.await?;
Connection::new_for_client(new_conn.connection).await
}
pub struct Server {
incoming: quinn::Incoming,
}
impl Server {
#[inline]
pub async fn next(&mut self) -> ResultType<Option<Connection>> {
Connection::new_for_server(&mut self.incoming).await
}
}
pub struct Connection {
conn: quinn::Connection,
tx: quinn::SendStream,
rx: Receiver,
}
type Value = ResultType<Vec<u8>>;
impl Connection {
async fn new_for_server(incoming: &mut quinn::Incoming) -> ResultType<Option<Self>> {
if let Some(conn) = incoming.next().await {
let quinn::NewConnection {
connection: conn,
// uni_streams,
mut bi_streams,
..
} = conn.await?;
let (tx, rx) = mpsc::unbounded_channel::<Value>();
tokio::spawn(async move {
loop {
let stream = bi_streams.next().await;
if let Some(stream) = stream {
let stream = match stream {
Err(e) => {
tx.send(Err(e.into())).ok();
break;
}
Ok(s) => s,
};
let cloned = tx.clone();
tokio::spawn(async move {
allow_err!(handle_request(stream.1, cloned).await);
});
} else {
tx.send(Err(anyhow!("Reset by the peer"))).ok();
break;
}
}
log::info!("Exit connection outer loop");
});
let tx = conn.open_uni().await?;
Ok(Some(Self { conn, tx, rx }))
} else {
Ok(None)
}
}
async fn new_for_client(conn: quinn::Connection) -> ResultType<Self> {
let (tx, rx_quic) = conn.open_bi().await?;
let (tx_mpsc, rx) = mpsc::unbounded_channel::<Value>();
tokio::spawn(async move {
allow_err!(handle_request(rx_quic, tx_mpsc).await);
});
Ok(Self { conn, tx, rx })
}
#[inline]
pub async fn next(&mut self) -> Option<Value> {
// None is returned when all Sender halves have dropped,
// indicating that no further values can be sent on the channel.
self.rx.recv().await
}
#[inline]
pub fn remote_address(&self) -> SocketAddr {
self.conn.remote_address()
}
#[inline]
pub async fn send_raw(&mut self, bytes: &[u8]) -> ResultType<()> {
self.tx.write_all(bytes).await?;
Ok(())
}
#[inline]
pub async fn send(&mut self, msg: &dyn Message) -> ResultType<()> {
match msg.write_to_bytes() {
Ok(bytes) => self.send_raw(&bytes).await?,
err => allow_err!(err),
}
Ok(())
}
}
async fn handle_request(rx: quinn::RecvStream, tx: Sender) -> ResultType<()> {
Ok(())
}

146
libs/hbb_common/src/tcp.rs Normal file
View File

@ -0,0 +1,146 @@
use crate::{bail, bytes_codec::BytesCodec, ResultType};
use bytes::{BufMut, Bytes, BytesMut};
use futures::SinkExt;
use protobuf::Message;
use sodiumoxide::crypto::secretbox::{self, Key, Nonce};
use std::{
io::{Error, ErrorKind},
ops::{Deref, DerefMut},
};
use tokio::{
net::{TcpListener, TcpStream, ToSocketAddrs},
stream::StreamExt,
};
use tokio_util::codec::Framed;
pub struct FramedStream(Framed<TcpStream, BytesCodec>, Option<(Key, u64, u64)>);
impl Deref for FramedStream {
type Target = Framed<TcpStream, BytesCodec>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for FramedStream {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl FramedStream {
pub async fn new<T: ToSocketAddrs, T2: ToSocketAddrs>(
remote_addr: T,
local_addr: T2,
ms_timeout: u64,
) -> ResultType<Self> {
for local_addr in local_addr.to_socket_addrs().await? {
for remote_addr in remote_addr.to_socket_addrs().await? {
if let Ok(stream) = super::timeout(
ms_timeout,
TcpStream::connect_std(
super::new_socket(local_addr, true, true)?.into_tcp_stream(),
&remote_addr,
),
)
.await?
{
return Ok(Self(Framed::new(stream, BytesCodec::new()), None));
}
}
}
bail!("could not resolve to any address");
}
pub fn from(stream: TcpStream) -> Self {
Self(Framed::new(stream, BytesCodec::new()), None)
}
pub fn set_raw(&mut self) {
self.0.codec_mut().set_raw();
self.1 = None;
}
pub fn is_secured(&self) -> bool {
self.1.is_some()
}
#[inline]
pub async fn send(&mut self, msg: &impl Message) -> ResultType<()> {
self.send_raw(msg.write_to_bytes()?).await
}
#[inline]
pub async fn send_raw(&mut self, msg: Vec<u8>) -> ResultType<()> {
let mut msg = msg;
if let Some(key) = self.1.as_mut() {
key.1 += 1;
let nonce = Self::get_nonce(key.1);
msg = secretbox::seal(&msg, &nonce, &key.0);
}
self.0.send(bytes::Bytes::from(msg)).await?;
Ok(())
}
pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> {
self.0.send(bytes).await?;
Ok(())
}
#[inline]
pub async fn next(&mut self) -> Option<Result<BytesMut, Error>> {
let mut res = self.0.next().await;
if let Some(key) = self.1.as_mut() {
if let Some(Ok(bytes)) = res.as_mut() {
key.2 += 1;
let nonce = Self::get_nonce(key.2);
match secretbox::open(&bytes, &nonce, &key.0) {
Ok(res) => {
bytes.clear();
bytes.put_slice(&res);
}
Err(()) => {
return Some(Err(Error::new(ErrorKind::Other, "decryption error")));
}
}
}
}
res
}
#[inline]
pub async fn next_timeout(&mut self, ms: u64) -> Option<Result<BytesMut, Error>> {
if let Ok(res) = super::timeout(ms, self.next()).await {
res
} else {
None
}
}
pub fn set_key(&mut self, key: Key) {
self.1 = Some((key, 0, 0));
}
fn get_nonce(seqnum: u64) -> Nonce {
let mut nonce = Nonce([0u8; secretbox::NONCEBYTES]);
nonce.0[..std::mem::size_of_val(&seqnum)].copy_from_slice(&seqnum.to_ne_bytes());
nonce
}
}
const DEFAULT_BACKLOG: i32 = 128;
#[allow(clippy::never_loop)]
pub async fn new_listener<T: ToSocketAddrs>(addr: T, reuse: bool) -> ResultType<TcpListener> {
if !reuse {
Ok(TcpListener::bind(addr).await?)
} else {
for addr in addr.to_socket_addrs().await? {
let socket = super::new_socket(addr, true, true)?;
socket.listen(DEFAULT_BACKLOG)?;
return Ok(TcpListener::from_std(socket.into_tcp_listener())?);
}
bail!("could not resolve to any address");
}
}

View File

@ -0,0 +1,75 @@
use crate::{bail, ResultType};
use bytes::BytesMut;
use futures::SinkExt;
use protobuf::Message;
use std::{
io::Error,
net::SocketAddr,
ops::{Deref, DerefMut},
};
use tokio::{net::ToSocketAddrs, net::UdpSocket, stream::StreamExt};
use tokio_util::{codec::BytesCodec, udp::UdpFramed};
pub struct FramedSocket(UdpFramed<BytesCodec>);
impl Deref for FramedSocket {
type Target = UdpFramed<BytesCodec>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for FramedSocket {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl FramedSocket {
pub async fn new<T: ToSocketAddrs>(addr: T) -> ResultType<Self> {
let socket = UdpSocket::bind(addr).await?;
Ok(Self(UdpFramed::new(socket, BytesCodec::new())))
}
#[allow(clippy::never_loop)]
pub async fn new_reuse<T: ToSocketAddrs>(addr: T) -> ResultType<Self> {
for addr in addr.to_socket_addrs().await? {
return Ok(Self(UdpFramed::new(
UdpSocket::from_std(super::new_socket(addr, false, true)?.into_udp_socket())?,
BytesCodec::new(),
)));
}
bail!("could not resolve to any address");
}
#[inline]
pub async fn send(&mut self, msg: &impl Message, addr: SocketAddr) -> ResultType<()> {
self.0
.send((bytes::Bytes::from(msg.write_to_bytes().unwrap()), addr))
.await?;
Ok(())
}
#[inline]
pub async fn send_raw(&mut self, msg: &'static [u8], addr: SocketAddr) -> ResultType<()> {
self.0.send((bytes::Bytes::from(msg), addr)).await?;
Ok(())
}
#[inline]
pub async fn next(&mut self) -> Option<Result<(BytesMut, SocketAddr), Error>> {
self.0.next().await
}
#[inline]
pub async fn next_timeout(&mut self, ms: u64) -> Option<Result<(BytesMut, SocketAddr), Error>> {
if let Ok(res) =
tokio::time::timeout(std::time::Duration::from_millis(ms), self.0.next()).await
{
res
} else {
None
}
}
}

104
src/main.rs Normal file
View File

@ -0,0 +1,104 @@
use hbb_common::{
protobuf::Message as _,
rendezvous_proto::*,
tcp::{new_listener, FramedStream},
tokio,
udp::FramedSocket,
};
#[tokio::main(basic_scheduler)]
async fn main() {
let mut socket = FramedSocket::new("0.0.0.0:21116").await.unwrap();
let mut listener = new_listener("0.0.0.0:21116", false).await.unwrap();
let mut rlistener = new_listener("0.0.0.0:21117", false).await.unwrap();
let mut id_map = std::collections::HashMap::new();
let relay_server = std::env::var("IP").unwrap();
let mut saved_stream = None;
loop {
tokio::select! {
Some(Ok((bytes, addr))) = socket.next() => {
if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) {
match msg_in.union {
Some(rendezvous_message::Union::register_peer(rp)) => {
id_map.insert(rp.id, addr);
}
Some(rendezvous_message::Union::register_pk(_)) => {
let mut msg_out = RendezvousMessage::new();
msg_out.set_register_pk_response(RegisterPkResponse {
result: register_pk_response::Result::OK.into(),
..Default::default()
});
socket.send(&msg_out, addr).await.ok();
}
_ => {}
}
}
}
Ok((stream, _)) = listener.accept() => {
let mut stream = FramedStream::from(stream);
if let Some(Ok(bytes)) = stream.next_timeout(3000).await {
if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) {
match msg_in.union {
Some(rendezvous_message::Union::punch_hole_request(ph)) => {
if let Some(addr) = id_map.get(&ph.id) {
let mut msg_out = RendezvousMessage::new();
msg_out.set_request_relay(RequestRelay {
relay_server: relay_server.clone(),
..Default::default()
});
socket.send(&msg_out, addr.clone()).await.ok();
saved_stream = Some(stream);
}
}
Some(rendezvous_message::Union::relay_response(_)) => {
let mut msg_out = RendezvousMessage::new();
msg_out.set_relay_response(RelayResponse {
relay_server: relay_server.clone(),
..Default::default()
});
if let Some(mut stream) = saved_stream.take() {
stream.send(&msg_out).await.ok();
if let Ok((stream_a, _)) = rlistener.accept().await {
let mut stream_a = FramedStream::from(stream_a);
stream_a.next_timeout(3_000).await;
if let Ok((stream_b, _)) = rlistener.accept().await {
let mut stream_b = FramedStream::from(stream_b);
stream_b.next_timeout(3_000).await;
relay(stream_a, stream_b).await;
}
}
}
}
_ => {}
}
}
}
}
}
}
}
async fn relay(stream: FramedStream, peer: FramedStream) {
let mut peer = peer;
let mut stream = stream;
peer.set_raw();
stream.set_raw();
loop {
tokio::select! {
res = peer.next() => {
if let Some(Ok(bytes)) = res {
stream.send_bytes(bytes.into()).await.ok();
} else {
break;
}
},
res = stream.next() => {
if let Some(Ok(bytes)) = res {
peer.send_bytes(bytes.into()).await.ok();
} else {
break;
}
},
}
}
}