commit 49c6b24a7a8c39d4448e07b743007ef1a3febd43 Author: 21pages Date: Mon Jan 13 11:15:51 2025 +0800 init Signed-off-by: 21pages diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6936990 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/target +**/*.rs.bk +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..58f54ac --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,67 @@ +[package] +name = "hbb_common" +version = "0.1.0" +authors = ["open-trade "] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +flexi_logger = { version = "0.27", features = ["async"] } +protobuf = { version = "3.4", features = ["with-bytes"] } +tokio = { version = "1.38", features = ["full"] } +tokio-util = { version = "0.7", features = ["full"] } +futures = "0.3" +bytes = { version = "1.6", features = ["serde"] } +log = "0.4" +env_logger = "0.10" +socket2 = { version = "0.3", features = ["reuseport"] } +zstd = "0.13" +anyhow = "1.0" +futures-util = "0.3" +directories-next = "2.0" +rand = "0.8" +serde_derive = "1.0" +serde = "1.0" +serde_json = "1.0" +lazy_static = "1.4" +confy = { git = "https://github.com/rustdesk-org/confy" } +dirs-next = "2.0" +filetime = "0.2" +sodiumoxide = "0.2" +regex = "1.8" +tokio-socks = { git = "https://github.com/rustdesk-org/tokio-socks" } +chrono = "0.4" +backtrace = "0.3" +libc = "0.2" +dlopen = "0.1" +toml = "0.7" +uuid = { version = "1.3", features = ["v4"] } +# new sysinfo issue: https://github.com/rustdesk/rustdesk/pull/6330#issuecomment-2270871442 +sysinfo = { git = "https://github.com/rustdesk-org/sysinfo", branch = "rlim_max" } +thiserror = "1.0" +httparse = "1.5" +base64 = "0.22" +url = "2.2" +sha2 = "0.10" + +[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies] +mac_address = "1.1" +default_net = { git = "https://github.com/rustdesk-org/default_net" } +machine-uid = { git = "https://github.com/rustdesk-org/machine-uid" } +[target.'cfg(not(any(target_os = "macos", target_os = "windows")))'.dependencies] +tokio-rustls = { version = "0.26", features = ["logging", "tls12", "ring"], default-features = false } +rustls-platform-verifier = "0.3.1" +rustls-pki-types = "1.4" +[target.'cfg(any(target_os = "macos", target_os = "windows"))'.dependencies] +tokio-native-tls ="0.3" + +[build-dependencies] +protobuf-codegen = { version = "3.4" } + +[target.'cfg(target_os = "windows")'.dependencies] +winapi = { version = "0.3", features = ["winuser", "synchapi", "pdh", "memoryapi", "sysinfoapi"] } + +[target.'cfg(target_os = "macos")'.dependencies] +osascript = "0.3" + diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..5ebc3a2 --- /dev/null +++ b/build.rs @@ -0,0 +1,14 @@ +fn main() { + let out_dir = format!("{}/protos", std::env::var("OUT_DIR").unwrap()); + + std::fs::create_dir_all(&out_dir).unwrap(); + + protobuf_codegen::Codegen::new() + .pure() + .out_dir(out_dir) + .inputs(["protos/rendezvous.proto", "protos/message.proto"]) + .include("protos") + .customize(protobuf_codegen::Customize::default().tokio_bytes(true)) + .run() + .expect("Codegen failed."); +} diff --git a/examples/config.rs b/examples/config.rs new file mode 100644 index 0000000..95169df --- /dev/null +++ b/examples/config.rs @@ -0,0 +1,5 @@ +extern crate hbb_common; + +fn main() { + println!("{:?}", hbb_common::config::PeerConfig::load("455058072")); +} diff --git a/examples/system_message.rs b/examples/system_message.rs new file mode 100644 index 0000000..0be7884 --- /dev/null +++ b/examples/system_message.rs @@ -0,0 +1,20 @@ +extern crate hbb_common; +#[cfg(target_os = "linux")] +use hbb_common::platform::linux; +#[cfg(target_os = "macos")] +use hbb_common::platform::macos; + +fn main() { + #[cfg(target_os = "linux")] + let res = linux::system_message("test title", "test message", true); + #[cfg(target_os = "macos")] + let res = macos::alert( + "System Preferences".to_owned(), + "warning".to_owned(), + "test title".to_owned(), + "test message".to_owned(), + ["Ok".to_owned()].to_vec(), + ); + #[cfg(any(target_os = "linux", target_os = "macos"))] + println!("result {:?}", &res); +} diff --git a/protos/message.proto b/protos/message.proto new file mode 100644 index 0000000..d4601c0 --- /dev/null +++ b/protos/message.proto @@ -0,0 +1,861 @@ +syntax = "proto3"; +package hbb; + +message EncodedVideoFrame { + bytes data = 1; + bool key = 2; + int64 pts = 3; +} + +message EncodedVideoFrames { repeated EncodedVideoFrame frames = 1; } + +message RGB { bool compress = 1; } + +// planes data send directly in binary for better use arraybuffer on web +message YUV { + bool compress = 1; + int32 stride = 2; +} + +enum Chroma { + I420 = 0; + I444 = 1; +} + +message VideoFrame { + oneof union { + EncodedVideoFrames vp9s = 6; + RGB rgb = 7; + YUV yuv = 8; + EncodedVideoFrames h264s = 10; + EncodedVideoFrames h265s = 11; + EncodedVideoFrames vp8s = 12; + EncodedVideoFrames av1s = 13; + } + int32 display = 14; +} + +message IdPk { + string id = 1; + bytes pk = 2; +} + +message DisplayInfo { + sint32 x = 1; + sint32 y = 2; + int32 width = 3; + int32 height = 4; + string name = 5; + bool online = 6; + bool cursor_embedded = 7; + Resolution original_resolution = 8; + double scale = 9; +} + +message PortForward { + string host = 1; + int32 port = 2; +} + +message FileTransfer { + string dir = 1; + bool show_hidden = 2; +} + +message OSLogin { + string username = 1; + string password = 2; +} + +message LoginRequest { + string username = 1; + bytes password = 2; + string my_id = 4; + string my_name = 5; + OptionMessage option = 6; + oneof union { + FileTransfer file_transfer = 7; + PortForward port_forward = 8; + } + bool video_ack_required = 9; + uint64 session_id = 10; + string version = 11; + OSLogin os_login = 12; + string my_platform = 13; + bytes hwid = 14; +} + +message Auth2FA { + string code = 1; + bytes hwid = 2; +} + +message ChatMessage { string text = 1; } + +message Features { + bool privacy_mode = 1; +} + +message CodecAbility { + bool vp8 = 1; + bool vp9 = 2; + bool av1 = 3; + bool h264 = 4; + bool h265 = 5; +} + +message SupportedEncoding { + bool h264 = 1; + bool h265 = 2; + bool vp8 = 3; + bool av1 = 4; + CodecAbility i444 = 5; +} + +message PeerInfo { + string username = 1; + string hostname = 2; + string platform = 3; + repeated DisplayInfo displays = 4; + int32 current_display = 5; + bool sas_enabled = 6; + string version = 7; + Features features = 9; + SupportedEncoding encoding = 10; + SupportedResolutions resolutions = 11; + // Use JSON's key-value format which is friendly for peer to handle. + // NOTE: Only support one-level dictionaries (for peer to update), and the key is of type string. + string platform_additions = 12; + WindowsSessions windows_sessions = 13; +} + +message WindowsSession { + uint32 sid = 1; + string name = 2; +} + +message LoginResponse { + oneof union { + string error = 1; + PeerInfo peer_info = 2; + } + bool enable_trusted_devices = 3; +} + +message TouchScaleUpdate { + // The delta scale factor relative to the previous scale. + // delta * 1000 + // 0 means scale end + int32 scale = 1; +} + +message TouchPanStart { + int32 x = 1; + int32 y = 2; +} + +message TouchPanUpdate { + // The delta x position relative to the previous position. + int32 x = 1; + // The delta y position relative to the previous position. + int32 y = 2; +} + +message TouchPanEnd { + int32 x = 1; + int32 y = 2; +} + +message TouchEvent { + oneof union { + TouchScaleUpdate scale_update = 1; + TouchPanStart pan_start = 2; + TouchPanUpdate pan_update = 3; + TouchPanEnd pan_end = 4; + } +} + +message PointerDeviceEvent { + oneof union { + TouchEvent touch_event = 1; + } + repeated ControlKey modifiers = 2; +} + +message MouseEvent { + int32 mask = 1; + sint32 x = 2; + sint32 y = 3; + repeated ControlKey modifiers = 4; +} + +enum KeyboardMode{ + Legacy = 0; + Map = 1; + Translate = 2; + Auto = 3; +} + +enum ControlKey { + Unknown = 0; + Alt = 1; + Backspace = 2; + CapsLock = 3; + Control = 4; + Delete = 5; + DownArrow = 6; + End = 7; + Escape = 8; + F1 = 9; + F10 = 10; + F11 = 11; + F12 = 12; + F2 = 13; + F3 = 14; + F4 = 15; + F5 = 16; + F6 = 17; + F7 = 18; + F8 = 19; + F9 = 20; + Home = 21; + LeftArrow = 22; + /// meta key (also known as "windows"; "super"; and "command") + Meta = 23; + /// option key on macOS (alt key on Linux and Windows) + Option = 24; // deprecated, use Alt instead + PageDown = 25; + PageUp = 26; + Return = 27; + RightArrow = 28; + Shift = 29; + Space = 30; + Tab = 31; + UpArrow = 32; + Numpad0 = 33; + Numpad1 = 34; + Numpad2 = 35; + Numpad3 = 36; + Numpad4 = 37; + Numpad5 = 38; + Numpad6 = 39; + Numpad7 = 40; + Numpad8 = 41; + Numpad9 = 42; + Cancel = 43; + Clear = 44; + Menu = 45; // deprecated, use Alt instead + Pause = 46; + Kana = 47; + Hangul = 48; + Junja = 49; + Final = 50; + Hanja = 51; + Kanji = 52; + Convert = 53; + Select = 54; + Print = 55; + Execute = 56; + Snapshot = 57; + Insert = 58; + Help = 59; + Sleep = 60; + Separator = 61; + Scroll = 62; + NumLock = 63; + RWin = 64; + Apps = 65; + Multiply = 66; + Add = 67; + Subtract = 68; + Decimal = 69; + Divide = 70; + Equals = 71; + NumpadEnter = 72; + RShift = 73; + RControl = 74; + RAlt = 75; + VolumeMute = 76; // mainly used on mobile devices as controlled side + VolumeUp = 77; + VolumeDown = 78; + Power = 79; // mainly used on mobile devices as controlled side + CtrlAltDel = 100; + LockScreen = 101; +} + +message KeyEvent { + // `down` indicates the key's state(down or up). + bool down = 1; + // `press` indicates a click event(down and up). + bool press = 2; + oneof union { + ControlKey control_key = 3; + // position key code. win: scancode, linux: key code, macos: key code + uint32 chr = 4; + uint32 unicode = 5; + string seq = 6; + // high word. virtual keycode + // low word. unicode + uint32 win2win_hotkey = 7; + } + repeated ControlKey modifiers = 8; + KeyboardMode mode = 9; +} + +message CursorData { + uint64 id = 1; + sint32 hotx = 2; + sint32 hoty = 3; + int32 width = 4; + int32 height = 5; + bytes colors = 6; +} + +message CursorPosition { + sint32 x = 1; + sint32 y = 2; +} + +message Hash { + string salt = 1; + string challenge = 2; +} + +enum ClipboardFormat { + Text = 0; + Rtf = 1; + Html = 2; + ImageRgba = 21; + ImagePng = 22; + ImageSvg = 23; + Special = 31; +} + +message Clipboard { + bool compress = 1; + bytes content = 2; + int32 width = 3; + int32 height = 4; + ClipboardFormat format = 5; + // Special format name, only used when format is Special. + string special_name = 6; +} + +message MultiClipboards { repeated Clipboard clipboards = 1; } + +enum FileType { + Dir = 0; + DirLink = 2; + DirDrive = 3; + File = 4; + FileLink = 5; +} + +message FileEntry { + FileType entry_type = 1; + string name = 2; + bool is_hidden = 3; + uint64 size = 4; + uint64 modified_time = 5; +} + +message FileDirectory { + int32 id = 1; + string path = 2; + repeated FileEntry entries = 3; +} + +message ReadDir { + string path = 1; + bool include_hidden = 2; +} + +message ReadEmptyDirs { + string path = 1; + bool include_hidden = 2; +} + +message ReadEmptyDirsResponse { + string path = 1; + repeated FileDirectory empty_dirs = 2; +} + +message ReadAllFiles { + int32 id = 1; + string path = 2; + bool include_hidden = 3; +} + +message FileRename { + int32 id = 1; + string path = 2; + string new_name = 3; +} + +message FileAction { + oneof union { + ReadDir read_dir = 1; + FileTransferSendRequest send = 2; + FileTransferReceiveRequest receive = 3; + FileDirCreate create = 4; + FileRemoveDir remove_dir = 5; + FileRemoveFile remove_file = 6; + ReadAllFiles all_files = 7; + FileTransferCancel cancel = 8; + FileTransferSendConfirmRequest send_confirm = 9; + FileRename rename = 10; + ReadEmptyDirs read_empty_dirs = 11; + } +} + +message FileTransferCancel { int32 id = 1; } + +message FileResponse { + oneof union { + FileDirectory dir = 1; + FileTransferBlock block = 2; + FileTransferError error = 3; + FileTransferDone done = 4; + FileTransferDigest digest = 5; + ReadEmptyDirsResponse empty_dirs = 6; + } +} + +message FileTransferDigest { + int32 id = 1; + sint32 file_num = 2; + uint64 last_modified = 3; + uint64 file_size = 4; + bool is_upload = 5; + bool is_identical = 6; +} + +message FileTransferBlock { + int32 id = 1; + sint32 file_num = 2; + bytes data = 3; + bool compressed = 4; + uint32 blk_id = 5; +} + +message FileTransferError { + int32 id = 1; + string error = 2; + sint32 file_num = 3; +} + +message FileTransferSendRequest { + int32 id = 1; + string path = 2; + bool include_hidden = 3; + int32 file_num = 4; +} + +message FileTransferSendConfirmRequest { + int32 id = 1; + sint32 file_num = 2; + oneof union { + bool skip = 3; + uint32 offset_blk = 4; + } +} + +message FileTransferDone { + int32 id = 1; + sint32 file_num = 2; +} + +message FileTransferReceiveRequest { + int32 id = 1; + string path = 2; // path written to + repeated FileEntry files = 3; + int32 file_num = 4; + uint64 total_size = 5; +} + +message FileRemoveDir { + int32 id = 1; + string path = 2; + bool recursive = 3; +} + +message FileRemoveFile { + int32 id = 1; + string path = 2; + sint32 file_num = 3; +} + +message FileDirCreate { + int32 id = 1; + string path = 2; +} + +// main logic from freeRDP +message CliprdrMonitorReady { +} + +message CliprdrFormat { + int32 id = 2; + string format = 3; +} + +message CliprdrServerFormatList { + repeated CliprdrFormat formats = 2; +} + +message CliprdrServerFormatListResponse { + int32 msg_flags = 2; +} + +message CliprdrServerFormatDataRequest { + int32 requested_format_id = 2; +} + +message CliprdrServerFormatDataResponse { + int32 msg_flags = 2; + bytes format_data = 3; +} + +message CliprdrFileContentsRequest { + int32 stream_id = 2; + int32 list_index = 3; + int32 dw_flags = 4; + int32 n_position_low = 5; + int32 n_position_high = 6; + int32 cb_requested = 7; + bool have_clip_data_id = 8; + int32 clip_data_id = 9; +} + +message CliprdrFileContentsResponse { + int32 msg_flags = 3; + int32 stream_id = 4; + bytes requested_data = 5; +} + +message Cliprdr { + oneof union { + CliprdrMonitorReady ready = 1; + CliprdrServerFormatList format_list = 2; + CliprdrServerFormatListResponse format_list_response = 3; + CliprdrServerFormatDataRequest format_data_request = 4; + CliprdrServerFormatDataResponse format_data_response = 5; + CliprdrFileContentsRequest file_contents_request = 6; + CliprdrFileContentsResponse file_contents_response = 7; + } +} + +message Resolution { + int32 width = 1; + int32 height = 2; +} + +message DisplayResolution { + int32 display = 1; + Resolution resolution = 2; +} + +message SupportedResolutions { repeated Resolution resolutions = 1; } + +message SwitchDisplay { + int32 display = 1; + sint32 x = 2; + sint32 y = 3; + int32 width = 4; + int32 height = 5; + bool cursor_embedded = 6; + SupportedResolutions resolutions = 7; + // Do not care about the origin point for now. + Resolution original_resolution = 8; +} + +message CaptureDisplays { + repeated int32 add = 1; + repeated int32 sub = 2; + repeated int32 set = 3; +} + +message ToggleVirtualDisplay { + int32 display = 1; + bool on = 2; +} + +message TogglePrivacyMode { + string impl_key = 1; + bool on = 2; +} + +message PermissionInfo { + enum Permission { + Keyboard = 0; + Clipboard = 2; + Audio = 3; + File = 4; + Restart = 5; + Recording = 6; + BlockInput = 7; + } + + Permission permission = 1; + bool enabled = 2; +} + +enum ImageQuality { + NotSet = 0; + Low = 2; + Balanced = 3; + Best = 4; +} + +message SupportedDecoding { + enum PreferCodec { + Auto = 0; + VP9 = 1; + H264 = 2; + H265 = 3; + VP8 = 4; + AV1 = 5; + } + + int32 ability_vp9 = 1; + int32 ability_h264 = 2; + int32 ability_h265 = 3; + PreferCodec prefer = 4; + int32 ability_vp8 = 5; + int32 ability_av1 = 6; + CodecAbility i444 = 7; + Chroma prefer_chroma = 8; +} + +message OptionMessage { + enum BoolOption { + NotSet = 0; + No = 1; + Yes = 2; + } + ImageQuality image_quality = 1; + BoolOption lock_after_session_end = 2; + BoolOption show_remote_cursor = 3; + BoolOption privacy_mode = 4; + BoolOption block_input = 5; + int32 custom_image_quality = 6; + BoolOption disable_audio = 7; + BoolOption disable_clipboard = 8; + BoolOption enable_file_transfer = 9; + SupportedDecoding supported_decoding = 10; + int32 custom_fps = 11; + BoolOption disable_keyboard = 12; +// Position 13 is used for Resolution. Remove later. +// Resolution custom_resolution = 13; +// BoolOption support_windows_specific_session = 14; + // starting from 15 please, do not use removed fields + BoolOption follow_remote_cursor = 15; + BoolOption follow_remote_window = 16; +} + +message TestDelay { + int64 time = 1; + bool from_client = 2; + uint32 last_delay = 3; + uint32 target_bitrate = 4; +} + +message PublicKey { + bytes asymmetric_value = 1; + bytes symmetric_value = 2; +} + +message SignedId { bytes id = 1; } + +message AudioFormat { + uint32 sample_rate = 1; + uint32 channels = 2; +} + +message AudioFrame { + bytes data = 1; +} + +// Notify peer to show message box. +message MessageBox { + // Message type. Refer to flutter/lib/common.dart/msgBox(). + string msgtype = 1; + string title = 2; + // English + string text = 3; + // If not empty, msgbox provides a button to following the link. + // The link here can't be directly http url. + // It must be the key of http url configed in peer side or "rustdesk://*" (jump in app). + string link = 4; +} + +message BackNotification { + // no need to consider block input by someone else + enum BlockInputState { + BlkStateUnknown = 0; + BlkOnSucceeded = 2; + BlkOnFailed = 3; + BlkOffSucceeded = 4; + BlkOffFailed = 5; + } + enum PrivacyModeState { + PrvStateUnknown = 0; + // Privacy mode on by someone else + PrvOnByOther = 2; + // Privacy mode is not supported on the remote side + PrvNotSupported = 3; + // Privacy mode on by self + PrvOnSucceeded = 4; + // Privacy mode on by self, but denied + PrvOnFailedDenied = 5; + // Some plugins are not found + PrvOnFailedPlugin = 6; + // Privacy mode on by self, but failed + PrvOnFailed = 7; + // Privacy mode off by self + PrvOffSucceeded = 8; + // Ctrl + P + PrvOffByPeer = 9; + // Privacy mode off by self, but failed + PrvOffFailed = 10; + PrvOffUnknown = 11; + } + + oneof union { + PrivacyModeState privacy_mode_state = 1; + BlockInputState block_input_state = 2; + } + // Supplementary message, for "PrvOnFailed" and "PrvOffFailed" + string details = 3; + // The key of the implementation + string impl_key = 4; +} + +message ElevationRequestWithLogon { + string username = 1; + string password = 2; +} + +message ElevationRequest { + oneof union { + bool direct = 1; + ElevationRequestWithLogon logon = 2; + } +} + +message SwitchSidesRequest { + bytes uuid = 1; +} + +message SwitchSidesResponse { + bytes uuid = 1; + LoginRequest lr = 2; +} + +message SwitchBack {} + +message PluginRequest { + string id = 1; + bytes content = 2; +} + +message PluginFailure { + string id = 1; + string name = 2; + string msg = 3; +} + +message WindowsSessions { + repeated WindowsSession sessions = 1; + uint32 current_sid = 2; +} + +// Query messages from peer. +message MessageQuery { + // The SwitchDisplay message of the target display. + // If the target display is not found, the message will be ignored. + int32 switch_display = 1; +} + +message Misc { + oneof union { + ChatMessage chat_message = 4; + SwitchDisplay switch_display = 5; + PermissionInfo permission_info = 6; + OptionMessage option = 7; + AudioFormat audio_format = 8; + string close_reason = 9; + bool refresh_video = 10; + bool video_received = 12; + BackNotification back_notification = 13; + bool restart_remote_device = 14; + bool uac = 15; + bool foreground_window_elevated = 16; + bool stop_service = 17; + ElevationRequest elevation_request = 18; + string elevation_response = 19; + bool portable_service_running = 20; + SwitchSidesRequest switch_sides_request = 21; + SwitchBack switch_back = 22; + // Deprecated since 1.2.4, use `change_display_resolution` (36) instead. + // But we must keep it for compatibility when peer version < 1.2.4. + Resolution change_resolution = 24; + PluginRequest plugin_request = 25; + PluginFailure plugin_failure = 26; + uint32 full_speed_fps = 27; // deprecated + uint32 auto_adjust_fps = 28; + bool client_record_status = 29; + CaptureDisplays capture_displays = 30; + int32 refresh_video_display = 31; + ToggleVirtualDisplay toggle_virtual_display = 32; + TogglePrivacyMode toggle_privacy_mode = 33; + SupportedEncoding supported_encoding = 34; + uint32 selected_sid = 35; + DisplayResolution change_display_resolution = 36; + MessageQuery message_query = 37; + int32 follow_current_display = 38; + } +} + +message VoiceCallRequest { + int64 req_timestamp = 1; + // Indicates whether the request is a connect action or a disconnect action. + bool is_connect = 2; +} + +message VoiceCallResponse { + bool accepted = 1; + int64 req_timestamp = 2; // Should copy from [VoiceCallRequest::req_timestamp]. + int64 ack_timestamp = 3; +} + +message Message { + oneof union { + SignedId signed_id = 3; + PublicKey public_key = 4; + TestDelay test_delay = 5; + VideoFrame video_frame = 6; + LoginRequest login_request = 7; + LoginResponse login_response = 8; + Hash hash = 9; + MouseEvent mouse_event = 10; + AudioFrame audio_frame = 11; + CursorData cursor_data = 12; + CursorPosition cursor_position = 13; + uint64 cursor_id = 14; + KeyEvent key_event = 15; + Clipboard clipboard = 16; + FileAction file_action = 17; + FileResponse file_response = 18; + Misc misc = 19; + Cliprdr cliprdr = 20; + MessageBox message_box = 21; + SwitchSidesResponse switch_sides_response = 22; + VoiceCallRequest voice_call_request = 23; + VoiceCallResponse voice_call_response = 24; + PeerInfo peer_info = 25; + PointerDeviceEvent pointer_device_event = 26; + Auth2FA auth_2fa = 27; + MultiClipboards multi_clipboards = 28; + } +} diff --git a/protos/rendezvous.proto b/protos/rendezvous.proto new file mode 100644 index 0000000..2fc0d90 --- /dev/null +++ b/protos/rendezvous.proto @@ -0,0 +1,196 @@ +syntax = "proto3"; +package hbb; + +message RegisterPeer { + string id = 1; + int32 serial = 2; +} + +enum ConnType { + DEFAULT_CONN = 0; + FILE_TRANSFER = 1; + PORT_FORWARD = 2; + RDP = 3; +} + +message RegisterPeerResponse { bool request_pk = 2; } + +message PunchHoleRequest { + string id = 1; + NatType nat_type = 2; + string licence_key = 3; + ConnType conn_type = 4; + string token = 5; + string version = 6; +} + +message PunchHole { + bytes socket_addr = 1; + string relay_server = 2; + NatType nat_type = 3; +} + +message TestNatRequest { + int32 serial = 1; +} + +// per my test, uint/int has no difference in encoding, int not good for negative, use sint for negative +message TestNatResponse { + int32 port = 1; + ConfigUpdate cu = 2; // for mobile +} + +enum NatType { + UNKNOWN_NAT = 0; + ASYMMETRIC = 1; + SYMMETRIC = 2; +} + +message PunchHoleSent { + bytes socket_addr = 1; + string id = 2; + string relay_server = 3; + NatType nat_type = 4; + string version = 5; +} + +message RegisterPk { + string id = 1; + bytes uuid = 2; + bytes pk = 3; + string old_id = 4; +} + +message RegisterPkResponse { + enum Result { + OK = 0; + UUID_MISMATCH = 2; + ID_EXISTS = 3; + TOO_FREQUENT = 4; + INVALID_ID_FORMAT = 5; + NOT_SUPPORT = 6; + SERVER_ERROR = 7; + } + Result result = 1; + int32 keep_alive = 2; +} + +message PunchHoleResponse { + bytes socket_addr = 1; + bytes pk = 2; + enum Failure { + ID_NOT_EXIST = 0; + OFFLINE = 2; + LICENSE_MISMATCH = 3; + LICENSE_OVERUSE = 4; + } + Failure failure = 3; + string relay_server = 4; + oneof union { + NatType nat_type = 5; + bool is_local = 6; + } + string other_failure = 7; + int32 feedback = 8; +} + +message ConfigUpdate { + int32 serial = 1; + repeated string rendezvous_servers = 2; +} + +message RequestRelay { + string id = 1; + string uuid = 2; + bytes socket_addr = 3; + string relay_server = 4; + bool secure = 5; + string licence_key = 6; + ConnType conn_type = 7; + string token = 8; +} + +message RelayResponse { + bytes socket_addr = 1; + string uuid = 2; + string relay_server = 3; + oneof union { + string id = 4; + bytes pk = 5; + } + string refuse_reason = 6; + string version = 7; + int32 feedback = 9; +} + +message SoftwareUpdate { string url = 1; } + +// if in same intranet, punch hole won't work both for udp and tcp, +// even some router has below connection error if we connect itself, +// { kind: Other, error: "could not resolve to any address" }, +// so we request local address to connect. +message FetchLocalAddr { + bytes socket_addr = 1; + string relay_server = 2; +} + +message LocalAddr { + bytes socket_addr = 1; + bytes local_addr = 2; + string relay_server = 3; + string id = 4; + string version = 5; +} + +message PeerDiscovery { + string cmd = 1; + string mac = 2; + string id = 3; + string username = 4; + string hostname = 5; + string platform = 6; + string misc = 7; +} + +message OnlineRequest { + string id = 1; + repeated string peers = 2; +} + +message OnlineResponse { + bytes states = 1; +} + +message KeyExchange { + repeated bytes keys = 1; +} + +message HealthCheck { + string token = 1; +} + +message RendezvousMessage { + oneof union { + RegisterPeer register_peer = 6; + RegisterPeerResponse register_peer_response = 7; + PunchHoleRequest punch_hole_request = 8; + PunchHole punch_hole = 9; + PunchHoleSent punch_hole_sent = 10; + PunchHoleResponse punch_hole_response = 11; + FetchLocalAddr fetch_local_addr = 12; + LocalAddr local_addr = 13; + ConfigUpdate configure_update = 14; + RegisterPk register_pk = 15; + RegisterPkResponse register_pk_response = 16; + SoftwareUpdate software_update = 17; + RequestRelay request_relay = 18; + RelayResponse relay_response = 19; + TestNatRequest test_nat_request = 20; + TestNatResponse test_nat_response = 21; + PeerDiscovery peer_discovery = 22; + OnlineRequest online_request = 23; + OnlineResponse online_response = 24; + KeyExchange key_exchange = 25; + HealthCheck hc = 26; + } +} diff --git a/src/bytes_codec.rs b/src/bytes_codec.rs new file mode 100644 index 0000000..bfc7987 --- /dev/null +++ b/src/bytes_codec.rs @@ -0,0 +1,280 @@ +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::io; +use tokio_util::codec::{Decoder, Encoder}; + +#[derive(Debug, Clone, Copy)] +pub struct BytesCodec { + state: DecodeState, + raw: bool, + max_packet_length: usize, +} + +#[derive(Debug, Clone, Copy)] +enum DecodeState { + Head, + Data(usize), +} + +impl Default for BytesCodec { + fn default() -> Self { + Self::new() + } +} + +impl BytesCodec { + pub fn new() -> Self { + Self { + state: DecodeState::Head, + raw: false, + max_packet_length: usize::MAX, + } + } + + pub fn set_raw(&mut self) { + self.raw = true; + } + + pub fn set_max_packet_length(&mut self, n: usize) { + self.max_packet_length = n; + } + + fn decode_head(&mut self, src: &mut BytesMut) -> io::Result> { + if src.is_empty() { + return Ok(None); + } + let head_len = ((src[0] & 0x3) + 1) as usize; + if src.len() < head_len { + return Ok(None); + } + let mut n = src[0] as usize; + if head_len > 1 { + n |= (src[1] as usize) << 8; + } + if head_len > 2 { + n |= (src[2] as usize) << 16; + } + if head_len > 3 { + n |= (src[3] as usize) << 24; + } + n >>= 2; + if n > self.max_packet_length { + return Err(io::Error::new(io::ErrorKind::InvalidData, "Too big packet")); + } + src.advance(head_len); + src.reserve(n); + Ok(Some(n)) + } + + fn decode_data(&self, n: usize, src: &mut BytesMut) -> io::Result> { + if src.len() < n { + return Ok(None); + } + Ok(Some(src.split_to(n))) + } +} + +impl Decoder for BytesCodec { + type Item = BytesMut; + type Error = io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, io::Error> { + if self.raw { + if !src.is_empty() { + let len = src.len(); + return Ok(Some(src.split_to(len))); + } else { + return Ok(None); + } + } + let n = match self.state { + DecodeState::Head => match self.decode_head(src)? { + Some(n) => { + self.state = DecodeState::Data(n); + n + } + None => return Ok(None), + }, + DecodeState::Data(n) => n, + }; + + match self.decode_data(n, src)? { + Some(data) => { + self.state = DecodeState::Head; + Ok(Some(data)) + } + None => Ok(None), + } + } +} + +impl Encoder for BytesCodec { + type Error = io::Error; + + fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> { + if self.raw { + buf.reserve(data.len()); + buf.put(data); + return Ok(()); + } + if data.len() <= 0x3F { + buf.put_u8((data.len() << 2) as u8); + } else if data.len() <= 0x3FFF { + buf.put_u16_le((data.len() << 2) as u16 | 0x1); + } else if data.len() <= 0x3FFFFF { + let h = (data.len() << 2) as u32 | 0x2; + buf.put_u16_le((h & 0xFFFF) as u16); + buf.put_u8((h >> 16) as u8); + } else if data.len() <= 0x3FFFFFFF { + buf.put_u32_le((data.len() << 2) as u32 | 0x3); + } else { + return Err(io::Error::new(io::ErrorKind::InvalidInput, "Overflow")); + } + buf.extend(data); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_codec1() { + let mut codec = BytesCodec::new(); + let mut buf = BytesMut::new(); + let mut bytes: Vec = Vec::new(); + bytes.resize(0x3F, 1); + assert!(codec.encode(bytes.into(), &mut buf).is_ok()); + let buf_saved = buf.clone(); + assert_eq!(buf.len(), 0x3F + 1); + if let Ok(Some(res)) = codec.decode(&mut buf) { + assert_eq!(res.len(), 0x3F); + assert_eq!(res[0], 1); + } else { + panic!(); + } + let mut codec2 = BytesCodec::new(); + let mut buf2 = BytesMut::new(); + if let Ok(None) = codec2.decode(&mut buf2) { + } else { + panic!(); + } + buf2.extend(&buf_saved[0..1]); + if let Ok(None) = codec2.decode(&mut buf2) { + } else { + panic!(); + } + buf2.extend(&buf_saved[1..]); + if let Ok(Some(res)) = codec2.decode(&mut buf2) { + assert_eq!(res.len(), 0x3F); + assert_eq!(res[0], 1); + } else { + panic!(); + } + } + + #[test] + fn test_codec2() { + let mut codec = BytesCodec::new(); + let mut buf = BytesMut::new(); + let mut bytes: Vec = Vec::new(); + assert!(codec.encode("".into(), &mut buf).is_ok()); + assert_eq!(buf.len(), 1); + bytes.resize(0x3F + 1, 2); + assert!(codec.encode(bytes.into(), &mut buf).is_ok()); + assert_eq!(buf.len(), 0x3F + 2 + 2); + if let Ok(Some(res)) = codec.decode(&mut buf) { + assert_eq!(res.len(), 0); + } else { + panic!(); + } + if let Ok(Some(res)) = codec.decode(&mut buf) { + assert_eq!(res.len(), 0x3F + 1); + assert_eq!(res[0], 2); + } else { + panic!(); + } + } + + #[test] + fn test_codec3() { + let mut codec = BytesCodec::new(); + let mut buf = BytesMut::new(); + let mut bytes: Vec = Vec::new(); + bytes.resize(0x3F - 1, 3); + assert!(codec.encode(bytes.into(), &mut buf).is_ok()); + assert_eq!(buf.len(), 0x3F + 1 - 1); + if let Ok(Some(res)) = codec.decode(&mut buf) { + assert_eq!(res.len(), 0x3F - 1); + assert_eq!(res[0], 3); + } else { + panic!(); + } + } + #[test] + fn test_codec4() { + let mut codec = BytesCodec::new(); + let mut buf = BytesMut::new(); + let mut bytes: Vec = Vec::new(); + bytes.resize(0x3FFF, 4); + assert!(codec.encode(bytes.into(), &mut buf).is_ok()); + assert_eq!(buf.len(), 0x3FFF + 2); + if let Ok(Some(res)) = codec.decode(&mut buf) { + assert_eq!(res.len(), 0x3FFF); + assert_eq!(res[0], 4); + } else { + panic!(); + } + } + + #[test] + fn test_codec5() { + let mut codec = BytesCodec::new(); + let mut buf = BytesMut::new(); + let mut bytes: Vec = Vec::new(); + bytes.resize(0x3FFFFF, 5); + assert!(codec.encode(bytes.into(), &mut buf).is_ok()); + assert_eq!(buf.len(), 0x3FFFFF + 3); + if let Ok(Some(res)) = codec.decode(&mut buf) { + assert_eq!(res.len(), 0x3FFFFF); + assert_eq!(res[0], 5); + } else { + panic!(); + } + } + + #[test] + fn test_codec6() { + let mut codec = BytesCodec::new(); + let mut buf = BytesMut::new(); + let mut bytes: Vec = Vec::new(); + bytes.resize(0x3FFFFF + 1, 6); + assert!(codec.encode(bytes.into(), &mut buf).is_ok()); + let buf_saved = buf.clone(); + assert_eq!(buf.len(), 0x3FFFFF + 4 + 1); + if let Ok(Some(res)) = codec.decode(&mut buf) { + assert_eq!(res.len(), 0x3FFFFF + 1); + assert_eq!(res[0], 6); + } else { + panic!(); + } + let mut codec2 = BytesCodec::new(); + let mut buf2 = BytesMut::new(); + buf2.extend(&buf_saved[0..1]); + if let Ok(None) = codec2.decode(&mut buf2) { + } else { + panic!(); + } + buf2.extend(&buf_saved[1..6]); + if let Ok(None) = codec2.decode(&mut buf2) { + } else { + panic!(); + } + buf2.extend(&buf_saved[6..]); + if let Ok(Some(res)) = codec2.decode(&mut buf2) { + assert_eq!(res.len(), 0x3FFFFF + 1); + assert_eq!(res[0], 6); + } else { + panic!(); + } + } +} diff --git a/src/compress.rs b/src/compress.rs new file mode 100644 index 0000000..761d916 --- /dev/null +++ b/src/compress.rs @@ -0,0 +1,34 @@ +use std::{cell::RefCell, io}; +use zstd::bulk::Compressor; + +// The library supports regular compression levels from 1 up to ZSTD_maxCLevel(), +// which is currently 22. Levels >= 20 +// Default level is ZSTD_CLEVEL_DEFAULT==3. +// value 0 means default, which is controlled by ZSTD_CLEVEL_DEFAULT +thread_local! { + static COMPRESSOR: RefCell>> = RefCell::new(Compressor::new(crate::config::COMPRESS_LEVEL)); +} + +pub fn compress(data: &[u8]) -> Vec { + let mut out = Vec::new(); + COMPRESSOR.with(|c| { + if let Ok(mut c) = c.try_borrow_mut() { + match &mut *c { + Ok(c) => match c.compress(data) { + Ok(res) => out = res, + Err(err) => { + crate::log::debug!("Failed to compress: {}", err); + } + }, + Err(err) => { + crate::log::debug!("Failed to get compressor: {}", err); + } + } + } + }); + out +} + +pub fn decompress(data: &[u8]) -> Vec { + zstd::decode_all(data).unwrap_or_default() +} diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..dd4abaf --- /dev/null +++ b/src/config.rs @@ -0,0 +1,2692 @@ +use std::{ + collections::{HashMap, HashSet}, + fs, + io::{Read, Write}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + ops::{Deref, DerefMut}, + path::{Path, PathBuf}, + sync::{Mutex, RwLock}, + time::{Duration, Instant, SystemTime}, +}; + +use anyhow::Result; +use bytes::Bytes; +use rand::Rng; +use regex::Regex; +use serde as de; +use serde_derive::{Deserialize, Serialize}; +use serde_json; +use sodiumoxide::base64; +use sodiumoxide::crypto::sign; + +use crate::{ + compress::{compress, decompress}, + log, + password_security::{ + decrypt_str_or_original, decrypt_vec_or_original, encrypt_str_or_original, + encrypt_vec_or_original, symmetric_crypt, + }, +}; + +pub const RENDEZVOUS_TIMEOUT: u64 = 12_000; +pub const CONNECT_TIMEOUT: u64 = 18_000; +pub const READ_TIMEOUT: u64 = 18_000; +// https://github.com/quic-go/quic-go/issues/525#issuecomment-294531351 +// https://datatracker.ietf.org/doc/html/draft-hamilton-early-deployment-quic-00#section-6.10 +// 15 seconds is recommended by quic, though oneSIP recommend 25 seconds, +// https://www.onsip.com/voip-resources/voip-fundamentals/what-is-nat-keepalive +pub const REG_INTERVAL: i64 = 15_000; +pub const COMPRESS_LEVEL: i32 = 3; +const SERIAL: i32 = 3; +const PASSWORD_ENC_VERSION: &str = "00"; +pub const ENCRYPT_MAX_LEN: usize = 128; // used for password, pin, etc, not for all + +#[cfg(target_os = "macos")] +lazy_static::lazy_static! { + pub static ref ORG: RwLock = RwLock::new("com.carriez".to_owned()); +} + +type Size = (i32, i32, i32, i32); +type KeyPair = (Vec, Vec); + +lazy_static::lazy_static! { + static ref CONFIG: RwLock = RwLock::new(Config::load()); + static ref CONFIG2: RwLock = RwLock::new(Config2::load()); + static ref LOCAL_CONFIG: RwLock = RwLock::new(LocalConfig::load()); + static ref TRUSTED_DEVICES: RwLock<(Vec, bool)> = Default::default(); + static ref ONLINE: Mutex> = Default::default(); + pub static ref PROD_RENDEZVOUS_SERVER: RwLock = RwLock::new(match option_env!("RENDEZVOUS_SERVER") { + Some(key) if !key.is_empty() => key, + _ => "", + }.to_owned()); + pub static ref EXE_RENDEZVOUS_SERVER: RwLock = Default::default(); + pub static ref APP_NAME: RwLock = RwLock::new("RustDesk".to_owned()); + static ref KEY_PAIR: Mutex> = Default::default(); + static ref USER_DEFAULT_CONFIG: RwLock<(UserDefaultConfig, Instant)> = RwLock::new((UserDefaultConfig::load(), Instant::now())); + pub static ref NEW_STORED_PEER_CONFIG: Mutex> = Default::default(); + pub static ref DEFAULT_SETTINGS: RwLock> = Default::default(); + pub static ref OVERWRITE_SETTINGS: RwLock> = Default::default(); + pub static ref DEFAULT_DISPLAY_SETTINGS: RwLock> = Default::default(); + pub static ref OVERWRITE_DISPLAY_SETTINGS: RwLock> = Default::default(); + pub static ref DEFAULT_LOCAL_SETTINGS: RwLock> = Default::default(); + pub static ref OVERWRITE_LOCAL_SETTINGS: RwLock> = Default::default(); + pub static ref HARD_SETTINGS: RwLock> = Default::default(); + pub static ref BUILTIN_SETTINGS: RwLock> = Default::default(); +} + +lazy_static::lazy_static! { + pub static ref APP_DIR: RwLock = Default::default(); +} + +#[cfg(any(target_os = "android", target_os = "ios"))] +lazy_static::lazy_static! { + pub static ref APP_HOME_DIR: RwLock = Default::default(); +} + +pub const LINK_DOCS_HOME: &str = "https://rustdesk.com/docs/en/"; +pub const LINK_DOCS_X11_REQUIRED: &str = "https://rustdesk.com/docs/en/manual/linux/#x11-required"; +pub const LINK_HEADLESS_LINUX_SUPPORT: &str = + "https://github.com/rustdesk/rustdesk/wiki/Headless-Linux-Support"; +lazy_static::lazy_static! { + pub static ref HELPER_URL: HashMap<&'static str, &'static str> = HashMap::from([ + ("rustdesk docs home", LINK_DOCS_HOME), + ("rustdesk docs x11-required", LINK_DOCS_X11_REQUIRED), + ("rustdesk x11 headless", LINK_HEADLESS_LINUX_SUPPORT), + ]); +} + +const CHARS: &[char] = &[ + '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', + 'm', 'n', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', +]; + +pub const RENDEZVOUS_SERVERS: &[&str] = &["rs-ny.rustdesk.com"]; +pub const PUBLIC_RS_PUB_KEY: &str = "OeVuKk5nlHiXp+APNn0Y3pC1Iwpwn44JGqrQCsWqmBw="; + +pub const RS_PUB_KEY: &str = match option_env!("RS_PUB_KEY") { + Some(key) if !key.is_empty() => key, + _ => PUBLIC_RS_PUB_KEY, +}; + +pub const RENDEZVOUS_PORT: i32 = 21116; +pub const RELAY_PORT: i32 = 21117; + +macro_rules! serde_field_string { + ($default_func:ident, $de_func:ident, $default_expr:expr) => { + fn $default_func() -> String { + $default_expr + } + + fn $de_func<'de, D>(deserializer: D) -> Result + where + D: de::Deserializer<'de>, + { + let s: String = + de::Deserialize::deserialize(deserializer).unwrap_or(Self::$default_func()); + if s.is_empty() { + return Ok(Self::$default_func()); + } + Ok(s) + } + }; +} + +macro_rules! serde_field_bool { + ($struct_name: ident, $field_name: literal, $func: ident, $default: literal) => { + #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] + pub struct $struct_name { + #[serde(default = $default, rename = $field_name, deserialize_with = "deserialize_bool")] + pub v: bool, + } + impl Default for $struct_name { + fn default() -> Self { + Self { v: Self::$func() } + } + } + impl $struct_name { + pub fn $func() -> bool { + UserDefaultConfig::read($field_name) == "Y" + } + } + impl Deref for $struct_name { + type Target = bool; + + fn deref(&self) -> &Self::Target { + &self.v + } + } + impl DerefMut for $struct_name { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.v + } + } + }; +} + +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum NetworkType { + Direct, + ProxySocks, +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)] +pub struct Config { + #[serde( + default, + skip_serializing_if = "String::is_empty", + deserialize_with = "deserialize_string" + )] + pub id: String, // use + #[serde(default, deserialize_with = "deserialize_string")] + enc_id: String, // store + #[serde(default, deserialize_with = "deserialize_string")] + password: String, + #[serde(default, deserialize_with = "deserialize_string")] + salt: String, + #[serde(default, deserialize_with = "deserialize_keypair")] + key_pair: KeyPair, // sk, pk + #[serde(default, deserialize_with = "deserialize_bool")] + key_confirmed: bool, + #[serde(default, deserialize_with = "deserialize_hashmap_string_bool")] + keys_confirmed: HashMap, +} + +#[derive(Debug, Default, PartialEq, Serialize, Deserialize, Clone)] +pub struct Socks5Server { + #[serde(default, deserialize_with = "deserialize_string")] + pub proxy: String, + #[serde(default, deserialize_with = "deserialize_string")] + pub username: String, + #[serde(default, deserialize_with = "deserialize_string")] + pub password: String, +} + +// more variable configs +#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)] +pub struct Config2 { + #[serde(default, deserialize_with = "deserialize_string")] + rendezvous_server: String, + #[serde(default, deserialize_with = "deserialize_i32")] + nat_type: i32, + #[serde(default, deserialize_with = "deserialize_i32")] + serial: i32, + #[serde(default, deserialize_with = "deserialize_string")] + unlock_pin: String, + #[serde(default, deserialize_with = "deserialize_string")] + trusted_devices: String, + + #[serde(default)] + socks: Option, + + // the other scalar value must before this + #[serde(default, deserialize_with = "deserialize_hashmap_string_string")] + pub options: HashMap, +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)] +pub struct Resolution { + pub w: i32, + pub h: i32, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct PeerConfig { + #[serde(default, deserialize_with = "deserialize_vec_u8")] + pub password: Vec, + #[serde(default, deserialize_with = "deserialize_size")] + pub size: Size, + #[serde(default, deserialize_with = "deserialize_size")] + pub size_ft: Size, + #[serde(default, deserialize_with = "deserialize_size")] + pub size_pf: Size, + #[serde( + default = "PeerConfig::default_view_style", + deserialize_with = "PeerConfig::deserialize_view_style", + skip_serializing_if = "String::is_empty" + )] + pub view_style: String, + // Image scroll style, scrollbar or scroll auto + #[serde( + default = "PeerConfig::default_scroll_style", + deserialize_with = "PeerConfig::deserialize_scroll_style", + skip_serializing_if = "String::is_empty" + )] + pub scroll_style: String, + #[serde( + default = "PeerConfig::default_image_quality", + deserialize_with = "PeerConfig::deserialize_image_quality", + skip_serializing_if = "String::is_empty" + )] + pub image_quality: String, + #[serde( + default = "PeerConfig::default_custom_image_quality", + deserialize_with = "PeerConfig::deserialize_custom_image_quality", + skip_serializing_if = "Vec::is_empty" + )] + pub custom_image_quality: Vec, + #[serde(flatten)] + pub show_remote_cursor: ShowRemoteCursor, + #[serde(flatten)] + pub lock_after_session_end: LockAfterSessionEnd, + #[serde(flatten)] + pub privacy_mode: PrivacyMode, + #[serde(flatten)] + pub allow_swap_key: AllowSwapKey, + #[serde(default, deserialize_with = "deserialize_vec_i32_string_i32")] + pub port_forwards: Vec<(i32, String, i32)>, + #[serde(default, deserialize_with = "deserialize_i32")] + pub direct_failures: i32, + #[serde(flatten)] + pub disable_audio: DisableAudio, + #[serde(flatten)] + pub disable_clipboard: DisableClipboard, + #[serde(flatten)] + pub enable_file_copy_paste: EnableFileCopyPaste, + #[serde(flatten)] + pub show_quality_monitor: ShowQualityMonitor, + #[serde(flatten)] + pub follow_remote_cursor: FollowRemoteCursor, + #[serde(flatten)] + pub follow_remote_window: FollowRemoteWindow, + #[serde( + default, + deserialize_with = "deserialize_string", + skip_serializing_if = "String::is_empty" + )] + pub keyboard_mode: String, + #[serde(flatten)] + pub view_only: ViewOnly, + #[serde(flatten)] + pub sync_init_clipboard: SyncInitClipboard, + // Mouse wheel or touchpad scroll mode + #[serde( + default = "PeerConfig::default_reverse_mouse_wheel", + deserialize_with = "PeerConfig::deserialize_reverse_mouse_wheel", + skip_serializing_if = "String::is_empty" + )] + pub reverse_mouse_wheel: String, + #[serde( + default = "PeerConfig::default_displays_as_individual_windows", + deserialize_with = "PeerConfig::deserialize_displays_as_individual_windows", + skip_serializing_if = "String::is_empty" + )] + pub displays_as_individual_windows: String, + #[serde( + default = "PeerConfig::default_use_all_my_displays_for_the_remote_session", + deserialize_with = "PeerConfig::deserialize_use_all_my_displays_for_the_remote_session", + skip_serializing_if = "String::is_empty" + )] + pub use_all_my_displays_for_the_remote_session: String, + + #[serde( + default, + deserialize_with = "deserialize_hashmap_resolutions", + skip_serializing_if = "HashMap::is_empty" + )] + pub custom_resolutions: HashMap, + + // The other scalar value must before this + #[serde( + default, + deserialize_with = "deserialize_hashmap_string_string", + skip_serializing_if = "HashMap::is_empty" + )] + pub options: HashMap, // not use delete to represent default values + // Various data for flutter ui + #[serde(default, deserialize_with = "deserialize_hashmap_string_string")] + pub ui_flutter: HashMap, + #[serde(default)] + pub info: PeerInfoSerde, + #[serde(default)] + pub transfer: TransferSerde, +} + +impl Default for PeerConfig { + fn default() -> Self { + Self { + password: Default::default(), + size: Default::default(), + size_ft: Default::default(), + size_pf: Default::default(), + view_style: Self::default_view_style(), + scroll_style: Self::default_scroll_style(), + image_quality: Self::default_image_quality(), + custom_image_quality: Self::default_custom_image_quality(), + show_remote_cursor: Default::default(), + lock_after_session_end: Default::default(), + privacy_mode: Default::default(), + allow_swap_key: Default::default(), + port_forwards: Default::default(), + direct_failures: Default::default(), + disable_audio: Default::default(), + disable_clipboard: Default::default(), + enable_file_copy_paste: Default::default(), + show_quality_monitor: Default::default(), + follow_remote_cursor: Default::default(), + follow_remote_window: Default::default(), + keyboard_mode: Default::default(), + view_only: Default::default(), + reverse_mouse_wheel: Self::default_reverse_mouse_wheel(), + displays_as_individual_windows: Self::default_displays_as_individual_windows(), + use_all_my_displays_for_the_remote_session: + Self::default_use_all_my_displays_for_the_remote_session(), + custom_resolutions: Default::default(), + options: Self::default_options(), + ui_flutter: Default::default(), + info: Default::default(), + transfer: Default::default(), + sync_init_clipboard: Default::default(), + } + } +} + +#[derive(Debug, PartialEq, Default, Serialize, Deserialize, Clone)] +pub struct PeerInfoSerde { + #[serde(default, deserialize_with = "deserialize_string")] + pub username: String, + #[serde(default, deserialize_with = "deserialize_string")] + pub hostname: String, + #[serde(default, deserialize_with = "deserialize_string")] + pub platform: String, +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)] +pub struct TransferSerde { + #[serde(default, deserialize_with = "deserialize_vec_string")] + pub write_jobs: Vec, + #[serde(default, deserialize_with = "deserialize_vec_string")] + pub read_jobs: Vec, +} + +#[inline] +pub fn get_online_state() -> i64 { + *ONLINE.lock().unwrap().values().max().unwrap_or(&0) +} + +#[cfg(not(any(target_os = "android", target_os = "ios")))] +fn patch(path: PathBuf) -> PathBuf { + if let Some(_tmp) = path.to_str() { + #[cfg(windows)] + return _tmp + .replace( + "system32\\config\\systemprofile", + "ServiceProfiles\\LocalService", + ) + .into(); + #[cfg(target_os = "macos")] + return _tmp.replace("Application Support", "Preferences").into(); + #[cfg(target_os = "linux")] + { + if _tmp == "/root" { + if let Ok(user) = crate::platform::linux::run_cmds_trim_newline("whoami") { + if user != "root" { + let cmd = format!("getent passwd '{}' | awk -F':' '{{print $6}}'", user); + if let Ok(output) = crate::platform::linux::run_cmds_trim_newline(&cmd) { + return output.into(); + } + return format!("/home/{user}").into(); + } + } + } + } + } + path +} + +impl Config2 { + fn load() -> Config2 { + let mut config = Config::load_::("2"); + let mut store = false; + if let Some(mut socks) = config.socks { + let (password, _, store2) = + decrypt_str_or_original(&socks.password, PASSWORD_ENC_VERSION); + socks.password = password; + config.socks = Some(socks); + store |= store2; + } + let (unlock_pin, _, store2) = + decrypt_str_or_original(&config.unlock_pin, PASSWORD_ENC_VERSION); + config.unlock_pin = unlock_pin; + store |= store2; + if store { + config.store(); + } + config + } + + pub fn file() -> PathBuf { + Config::file_("2") + } + + fn store(&self) { + let mut config = self.clone(); + if let Some(mut socks) = config.socks { + socks.password = + encrypt_str_or_original(&socks.password, PASSWORD_ENC_VERSION, ENCRYPT_MAX_LEN); + config.socks = Some(socks); + } + config.unlock_pin = + encrypt_str_or_original(&config.unlock_pin, PASSWORD_ENC_VERSION, ENCRYPT_MAX_LEN); + Config::store_(&config, "2"); + } + + pub fn get() -> Config2 { + return CONFIG2.read().unwrap().clone(); + } + + pub fn set(cfg: Config2) -> bool { + let mut lock = CONFIG2.write().unwrap(); + if *lock == cfg { + return false; + } + *lock = cfg; + lock.store(); + true + } +} + +pub fn load_path( + file: PathBuf, +) -> T { + let cfg = match confy::load_path(&file) { + Ok(config) => config, + Err(err) => { + if let confy::ConfyError::GeneralLoadError(err) = &err { + if err.kind() == std::io::ErrorKind::NotFound { + return T::default(); + } + } + log::error!("Failed to load config '{}': {}", file.display(), err); + T::default() + } + }; + cfg +} + +#[inline] +pub fn store_path(path: PathBuf, cfg: T) -> crate::ResultType<()> { + #[cfg(not(windows))] + { + use std::os::unix::fs::PermissionsExt; + Ok(confy::store_path_perms( + path, + cfg, + fs::Permissions::from_mode(0o600), + )?) + } + #[cfg(windows)] + { + Ok(confy::store_path(path, cfg)?) + } +} + +impl Config { + fn load_( + suffix: &str, + ) -> T { + let file = Self::file_(suffix); + let cfg = load_path(file); + if suffix.is_empty() { + log::trace!("{:?}", cfg); + } + cfg + } + + fn store_(config: &T, suffix: &str) { + let file = Self::file_(suffix); + if let Err(err) = store_path(file, config) { + log::error!("Failed to store {suffix} config: {err}"); + } + } + + fn load() -> Config { + let mut config = Config::load_::(""); + let mut store = false; + let (password, _, store1) = decrypt_str_or_original(&config.password, PASSWORD_ENC_VERSION); + config.password = password; + store |= store1; + let mut id_valid = false; + let (id, encrypted, store2) = decrypt_str_or_original(&config.enc_id, PASSWORD_ENC_VERSION); + if encrypted { + config.id = id; + id_valid = true; + store |= store2; + } else if + // Comment out for forward compatible + // crate::get_modified_time(&Self::file_("")) + // .checked_sub(std::time::Duration::from_secs(30)) // allow modification during installation + // .unwrap_or_else(crate::get_exe_time) + // < crate::get_exe_time() + // && + !config.id.is_empty() + && config.enc_id.is_empty() + && !decrypt_str_or_original(&config.id, PASSWORD_ENC_VERSION).1 + { + id_valid = true; + store = true; + } + if !id_valid { + for _ in 0..3 { + if let Some(id) = Config::get_auto_id() { + config.id = id; + store = true; + break; + } else { + log::error!("Failed to generate new id"); + } + } + } + if store { + config.store(); + } + config + } + + fn store(&self) { + let mut config = self.clone(); + config.password = + encrypt_str_or_original(&config.password, PASSWORD_ENC_VERSION, ENCRYPT_MAX_LEN); + config.enc_id = encrypt_str_or_original(&config.id, PASSWORD_ENC_VERSION, ENCRYPT_MAX_LEN); + config.id = "".to_owned(); + Config::store_(&config, ""); + } + + pub fn file() -> PathBuf { + Self::file_("") + } + + fn file_(suffix: &str) -> PathBuf { + let name = format!("{}{}", *APP_NAME.read().unwrap(), suffix); + Config::with_extension(Self::path(name)) + } + + pub fn is_empty(&self) -> bool { + (self.id.is_empty() && self.enc_id.is_empty()) || self.key_pair.0.is_empty() + } + + pub fn get_home() -> PathBuf { + #[cfg(any(target_os = "android", target_os = "ios"))] + return PathBuf::from(APP_HOME_DIR.read().unwrap().as_str()); + #[cfg(not(any(target_os = "android", target_os = "ios")))] + { + if let Some(path) = dirs_next::home_dir() { + patch(path) + } else if let Ok(path) = std::env::current_dir() { + path + } else { + std::env::temp_dir() + } + } + } + + pub fn path>(p: P) -> PathBuf { + #[cfg(any(target_os = "android", target_os = "ios"))] + { + let mut path: PathBuf = APP_DIR.read().unwrap().clone().into(); + path.push(p); + return path; + } + #[cfg(not(any(target_os = "android", target_os = "ios")))] + { + #[cfg(not(target_os = "macos"))] + let org = "".to_owned(); + #[cfg(target_os = "macos")] + let org = ORG.read().unwrap().clone(); + // /var/root for root + if let Some(project) = + directories_next::ProjectDirs::from("", &org, &APP_NAME.read().unwrap()) + { + let mut path = patch(project.config_dir().to_path_buf()); + path.push(p); + return path; + } + "".into() + } + } + + #[allow(unreachable_code)] + pub fn log_path() -> PathBuf { + #[cfg(target_os = "macos")] + { + if let Some(path) = dirs_next::home_dir().as_mut() { + path.push(format!("Library/Logs/{}", *APP_NAME.read().unwrap())); + return path.clone(); + } + } + #[cfg(target_os = "linux")] + { + let mut path = Self::get_home(); + path.push(format!(".local/share/logs/{}", *APP_NAME.read().unwrap())); + std::fs::create_dir_all(&path).ok(); + return path; + } + #[cfg(target_os = "android")] + { + let mut path = Self::get_home(); + path.push(format!("{}/Logs", *APP_NAME.read().unwrap())); + std::fs::create_dir_all(&path).ok(); + return path; + } + if let Some(path) = Self::path("").parent() { + let mut path: PathBuf = path.into(); + path.push("log"); + return path; + } + "".into() + } + + pub fn ipc_path(postfix: &str) -> String { + #[cfg(windows)] + { + // \\ServerName\pipe\PipeName + // where ServerName is either the name of a remote computer or a period, to specify the local computer. + // https://docs.microsoft.com/en-us/windows/win32/ipc/pipe-names + format!( + "\\\\.\\pipe\\{}\\query{}", + *APP_NAME.read().unwrap(), + postfix + ) + } + #[cfg(not(windows))] + { + use std::os::unix::fs::PermissionsExt; + #[cfg(target_os = "android")] + let mut path: PathBuf = + format!("{}/{}", *APP_DIR.read().unwrap(), *APP_NAME.read().unwrap()).into(); + #[cfg(not(target_os = "android"))] + let mut path: PathBuf = format!("/tmp/{}", *APP_NAME.read().unwrap()).into(); + fs::create_dir(&path).ok(); + fs::set_permissions(&path, fs::Permissions::from_mode(0o0777)).ok(); + path.push(format!("ipc{postfix}")); + path.to_str().unwrap_or("").to_owned() + } + } + + pub fn icon_path() -> PathBuf { + let mut path = Self::path("icons"); + if fs::create_dir_all(&path).is_err() { + path = std::env::temp_dir(); + } + path + } + + #[inline] + pub fn get_any_listen_addr(is_ipv4: bool) -> SocketAddr { + if is_ipv4 { + SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0) + } else { + SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0) + } + } + + pub fn get_rendezvous_server() -> String { + let mut rendezvous_server = EXE_RENDEZVOUS_SERVER.read().unwrap().clone(); + if rendezvous_server.is_empty() { + rendezvous_server = Self::get_option("custom-rendezvous-server"); + } + if rendezvous_server.is_empty() { + rendezvous_server = PROD_RENDEZVOUS_SERVER.read().unwrap().clone(); + } + if rendezvous_server.is_empty() { + rendezvous_server = CONFIG2.read().unwrap().rendezvous_server.clone(); + } + if rendezvous_server.is_empty() { + rendezvous_server = Self::get_rendezvous_servers() + .drain(..) + .next() + .unwrap_or_default(); + } + if !rendezvous_server.contains(':') { + rendezvous_server = format!("{rendezvous_server}:{RENDEZVOUS_PORT}"); + } + rendezvous_server + } + + pub fn get_rendezvous_servers() -> Vec { + let s = EXE_RENDEZVOUS_SERVER.read().unwrap().clone(); + if !s.is_empty() { + return vec![s]; + } + let s = Self::get_option("custom-rendezvous-server"); + if !s.is_empty() { + return vec![s]; + } + let s = PROD_RENDEZVOUS_SERVER.read().unwrap().clone(); + if !s.is_empty() { + return vec![s]; + } + let serial_obsolute = CONFIG2.read().unwrap().serial > SERIAL; + if serial_obsolute { + let ss: Vec = Self::get_option("rendezvous-servers") + .split(',') + .filter(|x| x.contains('.')) + .map(|x| x.to_owned()) + .collect(); + if !ss.is_empty() { + return ss; + } + } + return RENDEZVOUS_SERVERS.iter().map(|x| x.to_string()).collect(); + } + + pub fn reset_online() { + *ONLINE.lock().unwrap() = Default::default(); + } + + pub fn update_latency(host: &str, latency: i64) { + ONLINE.lock().unwrap().insert(host.to_owned(), latency); + let mut host = "".to_owned(); + let mut delay = i64::MAX; + for (tmp_host, tmp_delay) in ONLINE.lock().unwrap().iter() { + if tmp_delay > &0 && tmp_delay < &delay { + delay = *tmp_delay; + host = tmp_host.to_string(); + } + } + if !host.is_empty() { + let mut config = CONFIG2.write().unwrap(); + if host != config.rendezvous_server { + log::debug!("Update rendezvous_server in config to {}", host); + log::debug!("{:?}", *ONLINE.lock().unwrap()); + config.rendezvous_server = host; + config.store(); + } + } + } + + pub fn set_id(id: &str) { + let mut config = CONFIG.write().unwrap(); + if id == config.id { + return; + } + config.id = id.into(); + config.store(); + } + + pub fn set_nat_type(nat_type: i32) { + let mut config = CONFIG2.write().unwrap(); + if nat_type == config.nat_type { + return; + } + config.nat_type = nat_type; + config.store(); + } + + pub fn get_nat_type() -> i32 { + CONFIG2.read().unwrap().nat_type + } + + pub fn set_serial(serial: i32) { + let mut config = CONFIG2.write().unwrap(); + if serial == config.serial { + return; + } + config.serial = serial; + config.store(); + } + + pub fn get_serial() -> i32 { + std::cmp::max(CONFIG2.read().unwrap().serial, SERIAL) + } + + fn get_auto_id() -> Option { + #[cfg(any(target_os = "android", target_os = "ios"))] + { + return Some( + rand::thread_rng() + .gen_range(1_000_000_000..2_000_000_000) + .to_string(), + ); + } + + #[cfg(not(any(target_os = "android", target_os = "ios")))] + { + let mut id = 0u32; + if let Ok(Some(ma)) = mac_address::get_mac_address() { + for x in &ma.bytes()[2..] { + id = (id << 8) | (*x as u32); + } + id &= 0x1FFFFFFF; + Some(id.to_string()) + } else { + None + } + } + } + + pub fn get_auto_password(length: usize) -> String { + let mut rng = rand::thread_rng(); + (0..length) + .map(|_| CHARS[rng.gen::() % CHARS.len()]) + .collect() + } + + pub fn get_key_confirmed() -> bool { + CONFIG.read().unwrap().key_confirmed + } + + pub fn set_key_confirmed(v: bool) { + let mut config = CONFIG.write().unwrap(); + if config.key_confirmed == v { + return; + } + config.key_confirmed = v; + if !v { + config.keys_confirmed = Default::default(); + } + config.store(); + } + + pub fn get_host_key_confirmed(host: &str) -> bool { + matches!(CONFIG.read().unwrap().keys_confirmed.get(host), Some(true)) + } + + pub fn set_host_key_confirmed(host: &str, v: bool) { + if Self::get_host_key_confirmed(host) == v { + return; + } + let mut config = CONFIG.write().unwrap(); + config.keys_confirmed.insert(host.to_owned(), v); + config.store(); + } + + pub fn get_key_pair() -> KeyPair { + // lock here to make sure no gen_keypair more than once + // no use of CONFIG directly here to ensure no recursive calling in Config::load because of password dec which calling this function + let mut lock = KEY_PAIR.lock().unwrap(); + if let Some(p) = lock.as_ref() { + return p.clone(); + } + let mut config = Config::load_::(""); + if config.key_pair.0.is_empty() { + log::info!("Generated new keypair for id: {}", config.id); + let (pk, sk) = sign::gen_keypair(); + let key_pair = (sk.0.to_vec(), pk.0.into()); + config.key_pair = key_pair.clone(); + std::thread::spawn(|| { + let mut config = CONFIG.write().unwrap(); + config.key_pair = key_pair; + config.store(); + }); + } + *lock = Some(config.key_pair.clone()); + config.key_pair + } + + pub fn get_id() -> String { + let mut id = CONFIG.read().unwrap().id.clone(); + if id.is_empty() { + if let Some(tmp) = Config::get_auto_id() { + id = tmp; + Config::set_id(&id); + } + } + id + } + + pub fn get_id_or(b: String) -> String { + let a = CONFIG.read().unwrap().id.clone(); + if a.is_empty() { + b + } else { + a + } + } + + pub fn get_options() -> HashMap { + let mut res = DEFAULT_SETTINGS.read().unwrap().clone(); + res.extend(CONFIG2.read().unwrap().options.clone()); + res.extend(OVERWRITE_SETTINGS.read().unwrap().clone()); + res + } + + #[inline] + fn purify_options(v: &mut HashMap) { + v.retain(|k, v| is_option_can_save(&OVERWRITE_SETTINGS, k, &DEFAULT_SETTINGS, v)); + } + + pub fn set_options(mut v: HashMap) { + Self::purify_options(&mut v); + let mut config = CONFIG2.write().unwrap(); + if config.options == v { + return; + } + config.options = v; + config.store(); + } + + pub fn get_option(k: &str) -> String { + get_or( + &OVERWRITE_SETTINGS, + &CONFIG2.read().unwrap().options, + &DEFAULT_SETTINGS, + k, + ) + .unwrap_or_default() + } + + pub fn get_bool_option(k: &str) -> bool { + option2bool(k, &Self::get_option(k)) + } + + pub fn set_option(k: String, v: String) { + if !is_option_can_save(&OVERWRITE_SETTINGS, &k, &DEFAULT_SETTINGS, &v) { + return; + } + let mut config = CONFIG2.write().unwrap(); + let v2 = if v.is_empty() { None } else { Some(&v) }; + if v2 != config.options.get(&k) { + if v2.is_none() { + config.options.remove(&k); + } else { + config.options.insert(k, v); + } + config.store(); + } + } + + pub fn update_id() { + // to-do: how about if one ip register a lot of ids? + let id = Self::get_id(); + let mut rng = rand::thread_rng(); + let new_id = rng.gen_range(1_000_000_000..2_000_000_000).to_string(); + Config::set_id(&new_id); + log::info!("id updated from {} to {}", id, new_id); + } + + pub fn set_permanent_password(password: &str) { + if HARD_SETTINGS + .read() + .unwrap() + .get("password") + .map_or(false, |v| v == password) + { + return; + } + let mut config = CONFIG.write().unwrap(); + if password == config.password { + return; + } + config.password = password.into(); + config.store(); + Self::clear_trusted_devices(); + } + + pub fn get_permanent_password() -> String { + let mut password = CONFIG.read().unwrap().password.clone(); + if password.is_empty() { + if let Some(v) = HARD_SETTINGS.read().unwrap().get("password") { + password = v.to_owned(); + } + } + password + } + + pub fn set_salt(salt: &str) { + let mut config = CONFIG.write().unwrap(); + if salt == config.salt { + return; + } + config.salt = salt.into(); + config.store(); + } + + pub fn get_salt() -> String { + let mut salt = CONFIG.read().unwrap().salt.clone(); + if salt.is_empty() { + salt = Config::get_auto_password(6); + Config::set_salt(&salt); + } + salt + } + + pub fn set_socks(socks: Option) { + let mut config = CONFIG2.write().unwrap(); + if config.socks == socks { + return; + } + config.socks = socks; + config.store(); + } + + #[inline] + fn get_socks_from_custom_client_advanced_settings( + settings: &HashMap, + ) -> Option { + let url = settings.get(keys::OPTION_PROXY_URL)?; + Some(Socks5Server { + proxy: url.to_owned(), + username: settings + .get(keys::OPTION_PROXY_USERNAME) + .map(|x| x.to_string()) + .unwrap_or_default(), + password: settings + .get(keys::OPTION_PROXY_PASSWORD) + .map(|x| x.to_string()) + .unwrap_or_default(), + }) + } + + pub fn get_socks() -> Option { + Self::get_socks_from_custom_client_advanced_settings(&OVERWRITE_SETTINGS.read().unwrap()) + .or(CONFIG2.read().unwrap().socks.clone()) + .or(Self::get_socks_from_custom_client_advanced_settings( + &DEFAULT_SETTINGS.read().unwrap(), + )) + } + + #[inline] + pub fn is_proxy() -> bool { + Self::get_network_type() != NetworkType::Direct + } + + pub fn get_network_type() -> NetworkType { + if OVERWRITE_SETTINGS + .read() + .unwrap() + .get(keys::OPTION_PROXY_URL) + .is_some() + { + return NetworkType::ProxySocks; + } + if CONFIG2.read().unwrap().socks.is_some() { + return NetworkType::ProxySocks; + } + if DEFAULT_SETTINGS + .read() + .unwrap() + .get(keys::OPTION_PROXY_URL) + .is_some() + { + return NetworkType::ProxySocks; + } + NetworkType::Direct + } + + pub fn get_unlock_pin() -> String { + CONFIG2.read().unwrap().unlock_pin.clone() + } + + pub fn set_unlock_pin(pin: &str) { + let mut config = CONFIG2.write().unwrap(); + if pin == config.unlock_pin { + return; + } + config.unlock_pin = pin.to_string(); + config.store(); + } + + pub fn get_trusted_devices_json() -> String { + serde_json::to_string(&Self::get_trusted_devices()).unwrap_or_default() + } + + pub fn get_trusted_devices() -> Vec { + let (devices, synced) = TRUSTED_DEVICES.read().unwrap().clone(); + if synced { + return devices; + } + let devices = CONFIG2.read().unwrap().trusted_devices.clone(); + let (devices, succ, store) = decrypt_str_or_original(&devices, PASSWORD_ENC_VERSION); + if succ { + let mut devices: Vec = + serde_json::from_str(&devices).unwrap_or_default(); + let len = devices.len(); + devices.retain(|d| !d.outdate()); + if store || devices.len() != len { + Self::set_trusted_devices(devices.clone()); + } + *TRUSTED_DEVICES.write().unwrap() = (devices.clone(), true); + devices + } else { + Default::default() + } + } + + fn set_trusted_devices(mut trusted_devices: Vec) { + trusted_devices.retain(|d| !d.outdate()); + let devices = serde_json::to_string(&trusted_devices).unwrap_or_default(); + let max_len = 1024 * 1024; + if devices.bytes().len() > max_len { + log::error!("Trusted devices too large: {}", devices.bytes().len()); + return; + } + let devices = encrypt_str_or_original(&devices, PASSWORD_ENC_VERSION, max_len); + let mut config = CONFIG2.write().unwrap(); + config.trusted_devices = devices; + config.store(); + *TRUSTED_DEVICES.write().unwrap() = (trusted_devices, true); + } + + pub fn add_trusted_device(device: TrustedDevice) { + let mut devices = Self::get_trusted_devices(); + devices.retain(|d| d.hwid != device.hwid); + devices.push(device); + Self::set_trusted_devices(devices); + } + + pub fn remove_trusted_devices(hwids: &Vec) { + let mut devices = Self::get_trusted_devices(); + devices.retain(|d| !hwids.contains(&d.hwid)); + Self::set_trusted_devices(devices); + } + + pub fn clear_trusted_devices() { + Self::set_trusted_devices(Default::default()); + } + + pub fn get() -> Config { + return CONFIG.read().unwrap().clone(); + } + + pub fn set(cfg: Config) -> bool { + let mut lock = CONFIG.write().unwrap(); + if *lock == cfg { + return false; + } + *lock = cfg; + lock.store(); + true + } + + fn with_extension(path: PathBuf) -> PathBuf { + let ext = path.extension(); + if let Some(ext) = ext { + let ext = format!("{}.toml", ext.to_string_lossy()); + path.with_extension(ext) + } else { + path.with_extension("toml") + } + } +} + +const PEERS: &str = "peers"; + +impl PeerConfig { + pub fn load(id: &str) -> PeerConfig { + let _lock = CONFIG.read().unwrap(); + match confy::load_path(Self::path(id)) { + Ok(config) => { + let mut config: PeerConfig = config; + let mut store = false; + let (password, _, store2) = + decrypt_vec_or_original(&config.password, PASSWORD_ENC_VERSION); + config.password = password; + store = store || store2; + for opt in ["rdp_password", "os-username", "os-password"] { + if let Some(v) = config.options.get_mut(opt) { + let (encrypted, _, store2) = + decrypt_str_or_original(v, PASSWORD_ENC_VERSION); + *v = encrypted; + store = store || store2; + } + } + if store { + config.store(id); + } + config + } + Err(err) => { + if let confy::ConfyError::GeneralLoadError(err) = &err { + if err.kind() == std::io::ErrorKind::NotFound { + return Default::default(); + } + } + log::error!("Failed to load peer config '{}': {}", id, err); + Default::default() + } + } + } + + pub fn store(&self, id: &str) { + let _lock = CONFIG.read().unwrap(); + let mut config = self.clone(); + config.password = + encrypt_vec_or_original(&config.password, PASSWORD_ENC_VERSION, ENCRYPT_MAX_LEN); + for opt in ["rdp_password", "os-username", "os-password"] { + if let Some(v) = config.options.get_mut(opt) { + *v = encrypt_str_or_original(v, PASSWORD_ENC_VERSION, ENCRYPT_MAX_LEN) + } + } + if let Err(err) = store_path(Self::path(id), config) { + log::error!("Failed to store config: {}", err); + } + NEW_STORED_PEER_CONFIG.lock().unwrap().insert(id.to_owned()); + } + + pub fn remove(id: &str) { + fs::remove_file(Self::path(id)).ok(); + } + + fn path(id: &str) -> PathBuf { + //If the id contains invalid chars, encode it + let forbidden_paths = Regex::new(r".*[<>:/\\|\?\*].*"); + let path: PathBuf; + if let Ok(forbidden_paths) = forbidden_paths { + let id_encoded = if forbidden_paths.is_match(id) { + "base64_".to_string() + base64::encode(id, base64::Variant::Original).as_str() + } else { + id.to_string() + }; + path = [PEERS, id_encoded.as_str()].iter().collect(); + } else { + log::warn!("Regex create failed: {:?}", forbidden_paths.err()); + // fallback for failing to create this regex. + path = [PEERS, id.replace(":", "_").as_str()].iter().collect(); + } + Config::with_extension(Config::path(path)) + } + + pub fn peers(id_filters: Option>) -> Vec<(String, SystemTime, PeerConfig)> { + if let Ok(peers) = Config::path(PEERS).read_dir() { + if let Ok(peers) = peers + .map(|res| res.map(|e| e.path())) + .collect::, _>>() + { + let mut peers: Vec<_> = peers + .iter() + .filter(|p| { + p.is_file() + && p.extension().map(|p| p.to_str().unwrap_or("")) == Some("toml") + }) + .map(|p| { + let id = p + .file_stem() + .map(|p| p.to_str().unwrap_or("")) + .unwrap_or("") + .to_owned(); + + let id_decoded_string = if id.starts_with("base64_") && id.len() != 7 { + let id_decoded = base64::decode(&id[7..], base64::Variant::Original) + .unwrap_or_default(); + String::from_utf8_lossy(&id_decoded).as_ref().to_owned() + } else { + id + }; + (id_decoded_string, p) + }) + .filter(|(id, _)| { + let Some(filters) = &id_filters else { + return true; + }; + filters.contains(id) + }) + .map(|(id, p)| { + let t = crate::get_modified_time(p); + let c = PeerConfig::load(&id); + if c.info.platform.is_empty() { + fs::remove_file(p).ok(); + } + (id, t, c) + }) + .filter(|p| !p.2.info.platform.is_empty()) + .collect(); + peers.sort_unstable_by(|a, b| b.1.cmp(&a.1)); + return peers; + } + } + Default::default() + } + + pub fn exists(id: &str) -> bool { + Self::path(id).exists() + } + + serde_field_string!( + default_view_style, + deserialize_view_style, + UserDefaultConfig::read(keys::OPTION_VIEW_STYLE) + ); + serde_field_string!( + default_scroll_style, + deserialize_scroll_style, + UserDefaultConfig::read(keys::OPTION_SCROLL_STYLE) + ); + serde_field_string!( + default_image_quality, + deserialize_image_quality, + UserDefaultConfig::read(keys::OPTION_IMAGE_QUALITY) + ); + serde_field_string!( + default_reverse_mouse_wheel, + deserialize_reverse_mouse_wheel, + UserDefaultConfig::read(keys::OPTION_REVERSE_MOUSE_WHEEL) + ); + serde_field_string!( + default_displays_as_individual_windows, + deserialize_displays_as_individual_windows, + UserDefaultConfig::read(keys::OPTION_DISPLAYS_AS_INDIVIDUAL_WINDOWS) + ); + serde_field_string!( + default_use_all_my_displays_for_the_remote_session, + deserialize_use_all_my_displays_for_the_remote_session, + UserDefaultConfig::read(keys::OPTION_USE_ALL_MY_DISPLAYS_FOR_THE_REMOTE_SESSION) + ); + + fn default_custom_image_quality() -> Vec { + let f: f64 = UserDefaultConfig::read(keys::OPTION_CUSTOM_IMAGE_QUALITY) + .parse() + .unwrap_or(50.0); + vec![f as _] + } + + fn deserialize_custom_image_quality<'de, D>(deserializer: D) -> Result, D::Error> + where + D: de::Deserializer<'de>, + { + let v: Vec = de::Deserialize::deserialize(deserializer)?; + if v.len() == 1 && v[0] >= 10 && v[0] <= 0xFFF { + Ok(v) + } else { + Ok(Self::default_custom_image_quality()) + } + } + + fn default_options() -> HashMap { + let mut mp: HashMap = Default::default(); + [ + keys::OPTION_CODEC_PREFERENCE, + keys::OPTION_CUSTOM_FPS, + keys::OPTION_ZOOM_CURSOR, + keys::OPTION_TOUCH_MODE, + keys::OPTION_I444, + keys::OPTION_SWAP_LEFT_RIGHT_MOUSE, + keys::OPTION_COLLAPSE_TOOLBAR, + ] + .map(|key| { + mp.insert(key.to_owned(), UserDefaultConfig::read(key)); + }); + mp + } +} + +serde_field_bool!( + ShowRemoteCursor, + "show_remote_cursor", + default_show_remote_cursor, + "ShowRemoteCursor::default_show_remote_cursor" +); +serde_field_bool!( + FollowRemoteCursor, + "follow_remote_cursor", + default_follow_remote_cursor, + "FollowRemoteCursor::default_follow_remote_cursor" +); + +serde_field_bool!( + FollowRemoteWindow, + "follow_remote_window", + default_follow_remote_window, + "FollowRemoteWindow::default_follow_remote_window" +); +serde_field_bool!( + ShowQualityMonitor, + "show_quality_monitor", + default_show_quality_monitor, + "ShowQualityMonitor::default_show_quality_monitor" +); +serde_field_bool!( + DisableAudio, + "disable_audio", + default_disable_audio, + "DisableAudio::default_disable_audio" +); +serde_field_bool!( + EnableFileCopyPaste, + "enable-file-copy-paste", + default_enable_file_copy_paste, + "EnableFileCopyPaste::default_enable_file_copy_paste" +); +serde_field_bool!( + DisableClipboard, + "disable_clipboard", + default_disable_clipboard, + "DisableClipboard::default_disable_clipboard" +); +serde_field_bool!( + LockAfterSessionEnd, + "lock_after_session_end", + default_lock_after_session_end, + "LockAfterSessionEnd::default_lock_after_session_end" +); +serde_field_bool!( + PrivacyMode, + "privacy_mode", + default_privacy_mode, + "PrivacyMode::default_privacy_mode" +); + +serde_field_bool!( + AllowSwapKey, + "allow_swap_key", + default_allow_swap_key, + "AllowSwapKey::default_allow_swap_key" +); + +serde_field_bool!( + ViewOnly, + "view_only", + default_view_only, + "ViewOnly::default_view_only" +); + +serde_field_bool!( + SyncInitClipboard, + "sync-init-clipboard", + default_sync_init_clipboard, + "SyncInitClipboard::default_sync_init_clipboard" +); + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct LocalConfig { + #[serde(default, deserialize_with = "deserialize_string")] + remote_id: String, // latest used one + #[serde(default, deserialize_with = "deserialize_string")] + kb_layout_type: String, + #[serde(default, deserialize_with = "deserialize_size")] + size: Size, + #[serde(default, deserialize_with = "deserialize_vec_string")] + pub fav: Vec, + #[serde(default, deserialize_with = "deserialize_hashmap_string_string")] + options: HashMap, + // Various data for flutter ui + #[serde(default, deserialize_with = "deserialize_hashmap_string_string")] + ui_flutter: HashMap, +} + +impl LocalConfig { + fn load() -> LocalConfig { + Config::load_::("_local") + } + + fn store(&self) { + Config::store_(self, "_local"); + } + + pub fn get_kb_layout_type() -> String { + LOCAL_CONFIG.read().unwrap().kb_layout_type.clone() + } + + pub fn set_kb_layout_type(kb_layout_type: String) { + let mut config = LOCAL_CONFIG.write().unwrap(); + config.kb_layout_type = kb_layout_type; + config.store(); + } + + pub fn get_size() -> Size { + LOCAL_CONFIG.read().unwrap().size + } + + pub fn set_size(x: i32, y: i32, w: i32, h: i32) { + let mut config = LOCAL_CONFIG.write().unwrap(); + let size = (x, y, w, h); + if size == config.size || size.2 < 300 || size.3 < 300 { + return; + } + config.size = size; + config.store(); + } + + pub fn set_remote_id(remote_id: &str) { + let mut config = LOCAL_CONFIG.write().unwrap(); + if remote_id == config.remote_id { + return; + } + config.remote_id = remote_id.into(); + config.store(); + } + + pub fn get_remote_id() -> String { + LOCAL_CONFIG.read().unwrap().remote_id.clone() + } + + pub fn set_fav(fav: Vec) { + let mut lock = LOCAL_CONFIG.write().unwrap(); + if lock.fav == fav { + return; + } + lock.fav = fav; + lock.store(); + } + + pub fn get_fav() -> Vec { + LOCAL_CONFIG.read().unwrap().fav.clone() + } + + pub fn get_option(k: &str) -> String { + get_or( + &OVERWRITE_LOCAL_SETTINGS, + &LOCAL_CONFIG.read().unwrap().options, + &DEFAULT_LOCAL_SETTINGS, + k, + ) + .unwrap_or_default() + } + + // Usually get_option should be used. + pub fn get_option_from_file(k: &str) -> String { + get_or( + &OVERWRITE_LOCAL_SETTINGS, + &Self::load().options, + &DEFAULT_LOCAL_SETTINGS, + k, + ) + .unwrap_or_default() + } + + pub fn get_bool_option(k: &str) -> bool { + option2bool(k, &Self::get_option(k)) + } + + pub fn set_option(k: String, v: String) { + if !is_option_can_save(&OVERWRITE_LOCAL_SETTINGS, &k, &DEFAULT_LOCAL_SETTINGS, &v) { + return; + } + let mut config = LOCAL_CONFIG.write().unwrap(); + // The custom client will explictly set "default" as the default language. + let is_custom_client_default_lang = k == keys::OPTION_LANGUAGE && v == "default"; + if is_custom_client_default_lang { + config.options.insert(k, "".to_owned()); + config.store(); + return; + } + let v2 = if v.is_empty() { None } else { Some(&v) }; + if v2 != config.options.get(&k) { + if v2.is_none() { + config.options.remove(&k); + } else { + config.options.insert(k, v); + } + config.store(); + } + } + + pub fn get_flutter_option(k: &str) -> String { + get_or( + &OVERWRITE_LOCAL_SETTINGS, + &LOCAL_CONFIG.read().unwrap().ui_flutter, + &DEFAULT_LOCAL_SETTINGS, + k, + ) + .unwrap_or_default() + } + + pub fn set_flutter_option(k: String, v: String) { + let mut config = LOCAL_CONFIG.write().unwrap(); + let v2 = if v.is_empty() { None } else { Some(&v) }; + if v2 != config.ui_flutter.get(&k) { + if v2.is_none() { + config.ui_flutter.remove(&k); + } else { + config.ui_flutter.insert(k, v); + } + config.store(); + } + } +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct DiscoveryPeer { + #[serde(default, deserialize_with = "deserialize_string")] + pub id: String, + #[serde(default, deserialize_with = "deserialize_string")] + pub username: String, + #[serde(default, deserialize_with = "deserialize_string")] + pub hostname: String, + #[serde(default, deserialize_with = "deserialize_string")] + pub platform: String, + #[serde(default, deserialize_with = "deserialize_bool")] + pub online: bool, + #[serde(default, deserialize_with = "deserialize_hashmap_string_string")] + pub ip_mac: HashMap, +} + +impl DiscoveryPeer { + pub fn is_same_peer(&self, other: &DiscoveryPeer) -> bool { + self.id == other.id && self.username == other.username + } +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct LanPeers { + #[serde(default, deserialize_with = "deserialize_vec_discoverypeer")] + pub peers: Vec, +} + +impl LanPeers { + pub fn load() -> LanPeers { + let _lock = CONFIG.read().unwrap(); + match confy::load_path(Config::file_("_lan_peers")) { + Ok(peers) => peers, + Err(err) => { + log::error!("Failed to load lan peers: {}", err); + Default::default() + } + } + } + + pub fn store(peers: &[DiscoveryPeer]) { + let f = LanPeers { + peers: peers.to_owned(), + }; + if let Err(err) = store_path(Config::file_("_lan_peers"), f) { + log::error!("Failed to store lan peers: {}", err); + } + } + + pub fn modify_time() -> crate::ResultType { + let p = Config::file_("_lan_peers"); + Ok(fs::metadata(p)? + .modified()? + .duration_since(SystemTime::UNIX_EPOCH)? + .as_millis() as _) + } +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct UserDefaultConfig { + #[serde(default, deserialize_with = "deserialize_hashmap_string_string")] + options: HashMap, +} + +impl UserDefaultConfig { + fn read(key: &str) -> String { + let mut cfg = USER_DEFAULT_CONFIG.write().unwrap(); + // we do so, because default config may changed in another process, but we don't sync it + // but no need to read every time, give a small interval to avoid too many redundant read waste + if cfg.1.elapsed() > Duration::from_secs(1) { + *cfg = (Self::load(), Instant::now()); + } + cfg.0.get(key) + } + + pub fn load() -> UserDefaultConfig { + Config::load_::("_default") + } + + #[inline] + fn store(&self) { + Config::store_(self, "_default"); + } + + pub fn get(&self, key: &str) -> String { + match key { + #[cfg(any(target_os = "android", target_os = "ios"))] + keys::OPTION_VIEW_STYLE => self.get_string(key, "adaptive", vec!["original"]), + #[cfg(not(any(target_os = "android", target_os = "ios")))] + keys::OPTION_VIEW_STYLE => self.get_string(key, "original", vec!["adaptive"]), + keys::OPTION_SCROLL_STYLE => self.get_string(key, "scrollauto", vec!["scrollbar"]), + keys::OPTION_IMAGE_QUALITY => { + self.get_string(key, "balanced", vec!["best", "low", "custom"]) + } + keys::OPTION_CODEC_PREFERENCE => { + self.get_string(key, "auto", vec!["vp8", "vp9", "av1", "h264", "h265"]) + } + keys::OPTION_CUSTOM_IMAGE_QUALITY => { + self.get_double_string(key, 50.0, 10.0, 0xFFF as f64) + } + keys::OPTION_CUSTOM_FPS => self.get_double_string(key, 30.0, 5.0, 120.0), + keys::OPTION_ENABLE_FILE_COPY_PASTE => self.get_string(key, "Y", vec!["", "N"]), + _ => self + .get_after(key) + .map(|v| v.to_string()) + .unwrap_or_default(), + } + } + + pub fn set(&mut self, key: String, value: String) { + if !is_option_can_save( + &OVERWRITE_DISPLAY_SETTINGS, + &key, + &DEFAULT_DISPLAY_SETTINGS, + &value, + ) { + return; + } + if value.is_empty() { + self.options.remove(&key); + } else { + self.options.insert(key, value); + } + self.store(); + } + + #[inline] + fn get_string(&self, key: &str, default: &str, others: Vec<&str>) -> String { + match self.get_after(key) { + Some(option) => { + if others.contains(&option.as_str()) { + option.to_owned() + } else { + default.to_owned() + } + } + None => default.to_owned(), + } + } + + #[inline] + fn get_double_string(&self, key: &str, default: f64, min: f64, max: f64) -> String { + match self.get_after(key) { + Some(option) => { + let v: f64 = option.parse().unwrap_or(default); + if v >= min && v <= max { + v.to_string() + } else { + default.to_string() + } + } + None => default.to_string(), + } + } + + fn get_after(&self, k: &str) -> Option { + get_or( + &OVERWRITE_DISPLAY_SETTINGS, + &self.options, + &DEFAULT_DISPLAY_SETTINGS, + k, + ) + } +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct AbPeer { + #[serde( + default, + deserialize_with = "deserialize_string", + skip_serializing_if = "String::is_empty" + )] + pub id: String, + #[serde( + default, + deserialize_with = "deserialize_string", + skip_serializing_if = "String::is_empty" + )] + pub hash: String, + #[serde( + default, + deserialize_with = "deserialize_string", + skip_serializing_if = "String::is_empty" + )] + pub username: String, + #[serde( + default, + deserialize_with = "deserialize_string", + skip_serializing_if = "String::is_empty" + )] + pub hostname: String, + #[serde( + default, + deserialize_with = "deserialize_string", + skip_serializing_if = "String::is_empty" + )] + pub platform: String, + #[serde( + default, + deserialize_with = "deserialize_string", + skip_serializing_if = "String::is_empty" + )] + pub alias: String, + #[serde(default, deserialize_with = "deserialize_vec_string")] + pub tags: Vec, +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct AbEntry { + #[serde( + default, + deserialize_with = "deserialize_string", + skip_serializing_if = "String::is_empty" + )] + pub guid: String, + #[serde( + default, + deserialize_with = "deserialize_string", + skip_serializing_if = "String::is_empty" + )] + pub name: String, + #[serde(default, deserialize_with = "deserialize_vec_abpeer")] + pub peers: Vec, + #[serde(default, deserialize_with = "deserialize_vec_string")] + pub tags: Vec, + #[serde( + default, + deserialize_with = "deserialize_string", + skip_serializing_if = "String::is_empty" + )] + pub tag_colors: String, +} + +impl AbEntry { + pub fn personal(&self) -> bool { + self.name == "My address book" || self.name == "Legacy address book" + } +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct Ab { + #[serde( + default, + deserialize_with = "deserialize_string", + skip_serializing_if = "String::is_empty" + )] + pub access_token: String, + #[serde(default, deserialize_with = "deserialize_vec_abentry")] + pub ab_entries: Vec, +} + +impl Ab { + fn path() -> PathBuf { + let filename = format!("{}_ab", APP_NAME.read().unwrap().clone()); + Config::path(filename) + } + + pub fn store(json: String) { + if let Ok(mut file) = std::fs::File::create(Self::path()) { + let data = compress(json.as_bytes()); + let max_len = 64 * 1024 * 1024; + if data.len() > max_len { + // maxlen of function decompress + log::error!("ab data too large, {} > {}", data.len(), max_len); + return; + } + if let Ok(data) = symmetric_crypt(&data, true) { + file.write_all(&data).ok(); + } + }; + } + + pub fn load() -> Ab { + if let Ok(mut file) = std::fs::File::open(Self::path()) { + let mut data = vec![]; + if file.read_to_end(&mut data).is_ok() { + if let Ok(data) = symmetric_crypt(&data, false) { + let data = decompress(&data); + if let Ok(ab) = serde_json::from_str::(&String::from_utf8_lossy(&data)) { + return ab; + } + } + } + }; + Self::remove(); + Ab::default() + } + + pub fn remove() { + std::fs::remove_file(Self::path()).ok(); + } +} + +// use default value when field type is wrong +macro_rules! deserialize_default { + ($func_name:ident, $return_type:ty) => { + fn $func_name<'de, D>(deserializer: D) -> Result<$return_type, D::Error> + where + D: de::Deserializer<'de>, + { + Ok(de::Deserialize::deserialize(deserializer).unwrap_or_default()) + } + }; +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct GroupPeer { + #[serde( + default, + deserialize_with = "deserialize_string", + skip_serializing_if = "String::is_empty" + )] + pub id: String, + #[serde( + default, + deserialize_with = "deserialize_string", + skip_serializing_if = "String::is_empty" + )] + pub username: String, + #[serde( + default, + deserialize_with = "deserialize_string", + skip_serializing_if = "String::is_empty" + )] + pub hostname: String, + #[serde( + default, + deserialize_with = "deserialize_string", + skip_serializing_if = "String::is_empty" + )] + pub platform: String, + #[serde( + default, + deserialize_with = "deserialize_string", + skip_serializing_if = "String::is_empty" + )] + pub login_name: String, +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct GroupUser { + #[serde( + default, + deserialize_with = "deserialize_string", + skip_serializing_if = "String::is_empty" + )] + pub name: String, +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct Group { + #[serde( + default, + deserialize_with = "deserialize_string", + skip_serializing_if = "String::is_empty" + )] + pub access_token: String, + #[serde(default, deserialize_with = "deserialize_vec_groupuser")] + pub users: Vec, + #[serde(default, deserialize_with = "deserialize_vec_grouppeer")] + pub peers: Vec, +} + +impl Group { + fn path() -> PathBuf { + let filename = format!("{}_group", APP_NAME.read().unwrap().clone()); + Config::path(filename) + } + + pub fn store(json: String) { + if let Ok(mut file) = std::fs::File::create(Self::path()) { + let data = compress(json.as_bytes()); + let max_len = 64 * 1024 * 1024; + if data.len() > max_len { + // maxlen of function decompress + return; + } + if let Ok(data) = symmetric_crypt(&data, true) { + file.write_all(&data).ok(); + } + }; + } + + pub fn load() -> Self { + if let Ok(mut file) = std::fs::File::open(Self::path()) { + let mut data = vec![]; + if file.read_to_end(&mut data).is_ok() { + if let Ok(data) = symmetric_crypt(&data, false) { + let data = decompress(&data); + if let Ok(group) = serde_json::from_str::(&String::from_utf8_lossy(&data)) + { + return group; + } + } + } + }; + Self::remove(); + Self::default() + } + + pub fn remove() { + std::fs::remove_file(Self::path()).ok(); + } +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct TrustedDevice { + pub hwid: Bytes, + pub time: i64, + pub id: String, + pub name: String, + pub platform: String, +} + +impl TrustedDevice { + pub fn outdate(&self) -> bool { + const DAYS_90: i64 = 90 * 24 * 60 * 60 * 1000; + self.time + DAYS_90 < crate::get_time() + } +} + +deserialize_default!(deserialize_string, String); +deserialize_default!(deserialize_bool, bool); +deserialize_default!(deserialize_i32, i32); +deserialize_default!(deserialize_vec_u8, Vec); +deserialize_default!(deserialize_vec_string, Vec); +deserialize_default!(deserialize_vec_i32_string_i32, Vec<(i32, String, i32)>); +deserialize_default!(deserialize_vec_discoverypeer, Vec); +deserialize_default!(deserialize_vec_abpeer, Vec); +deserialize_default!(deserialize_vec_abentry, Vec); +deserialize_default!(deserialize_vec_groupuser, Vec); +deserialize_default!(deserialize_vec_grouppeer, Vec); +deserialize_default!(deserialize_keypair, KeyPair); +deserialize_default!(deserialize_size, Size); +deserialize_default!(deserialize_hashmap_string_string, HashMap); +deserialize_default!(deserialize_hashmap_string_bool, HashMap); +deserialize_default!(deserialize_hashmap_resolutions, HashMap); + +#[inline] +fn get_or( + a: &RwLock>, + b: &HashMap, + c: &RwLock>, + k: &str, +) -> Option { + a.read() + .unwrap() + .get(k) + .or(b.get(k)) + .or(c.read().unwrap().get(k)) + .cloned() +} + +#[inline] +fn is_option_can_save( + overwrite: &RwLock>, + k: &str, + defaults: &RwLock>, + v: &str, +) -> bool { + if overwrite.read().unwrap().contains_key(k) + || defaults.read().unwrap().get(k).map_or(false, |x| x == v) + { + return false; + } + true +} + +#[inline] +pub fn is_incoming_only() -> bool { + HARD_SETTINGS + .read() + .unwrap() + .get("conn-type") + .map_or(false, |x| x == ("incoming")) +} + +#[inline] +pub fn is_outgoing_only() -> bool { + HARD_SETTINGS + .read() + .unwrap() + .get("conn-type") + .map_or(false, |x| x == ("outgoing")) +} + +#[inline] +fn is_some_hard_opton(name: &str) -> bool { + HARD_SETTINGS + .read() + .unwrap() + .get(name) + .map_or(false, |x| x == ("Y")) +} + +#[inline] +pub fn is_disable_tcp_listen() -> bool { + is_some_hard_opton("disable-tcp-listen") +} + +#[inline] +pub fn is_disable_settings() -> bool { + is_some_hard_opton("disable-settings") +} + +#[inline] +pub fn is_disable_ab() -> bool { + is_some_hard_opton("disable-ab") +} + +#[inline] +pub fn is_disable_account() -> bool { + is_some_hard_opton("disable-account") +} + +#[inline] +pub fn is_disable_installation() -> bool { + is_some_hard_opton("disable-installation") +} + +// This function must be kept the same as the one in flutter and sciter code. +// flutter: flutter/lib/common.dart -> option2bool() +// sciter: Does not have the function, but it should be kept the same. +pub fn option2bool(option: &str, value: &str) -> bool { + if option.starts_with("enable-") { + value != "N" + } else if option.starts_with("allow-") + || option == "stop-service" + || option == keys::OPTION_DIRECT_SERVER + || option == "force-always-relay" + { + value == "Y" + } else { + value != "N" + } +} + +pub mod keys { + pub const OPTION_VIEW_ONLY: &str = "view_only"; + pub const OPTION_SHOW_MONITORS_TOOLBAR: &str = "show_monitors_toolbar"; + pub const OPTION_COLLAPSE_TOOLBAR: &str = "collapse_toolbar"; + pub const OPTION_SHOW_REMOTE_CURSOR: &str = "show_remote_cursor"; + pub const OPTION_FOLLOW_REMOTE_CURSOR: &str = "follow_remote_cursor"; + pub const OPTION_FOLLOW_REMOTE_WINDOW: &str = "follow_remote_window"; + pub const OPTION_ZOOM_CURSOR: &str = "zoom-cursor"; + pub const OPTION_SHOW_QUALITY_MONITOR: &str = "show_quality_monitor"; + pub const OPTION_DISABLE_AUDIO: &str = "disable_audio"; + pub const OPTION_ENABLE_FILE_COPY_PASTE: &str = "enable-file-copy-paste"; + pub const OPTION_DISABLE_CLIPBOARD: &str = "disable_clipboard"; + pub const OPTION_LOCK_AFTER_SESSION_END: &str = "lock_after_session_end"; + pub const OPTION_PRIVACY_MODE: &str = "privacy_mode"; + pub const OPTION_TOUCH_MODE: &str = "touch-mode"; + pub const OPTION_I444: &str = "i444"; + pub const OPTION_REVERSE_MOUSE_WHEEL: &str = "reverse_mouse_wheel"; + pub const OPTION_SWAP_LEFT_RIGHT_MOUSE: &str = "swap-left-right-mouse"; + pub const OPTION_DISPLAYS_AS_INDIVIDUAL_WINDOWS: &str = "displays_as_individual_windows"; + pub const OPTION_USE_ALL_MY_DISPLAYS_FOR_THE_REMOTE_SESSION: &str = + "use_all_my_displays_for_the_remote_session"; + pub const OPTION_VIEW_STYLE: &str = "view_style"; + pub const OPTION_SCROLL_STYLE: &str = "scroll_style"; + pub const OPTION_IMAGE_QUALITY: &str = "image_quality"; + pub const OPTION_CUSTOM_IMAGE_QUALITY: &str = "custom_image_quality"; + pub const OPTION_CUSTOM_FPS: &str = "custom-fps"; + pub const OPTION_CODEC_PREFERENCE: &str = "codec-preference"; + pub const OPTION_SYNC_INIT_CLIPBOARD: &str = "sync-init-clipboard"; + pub const OPTION_THEME: &str = "theme"; + pub const OPTION_LANGUAGE: &str = "lang"; + pub const OPTION_REMOTE_MENUBAR_DRAG_LEFT: &str = "remote-menubar-drag-left"; + pub const OPTION_REMOTE_MENUBAR_DRAG_RIGHT: &str = "remote-menubar-drag-right"; + pub const OPTION_HIDE_AB_TAGS_PANEL: &str = "hideAbTagsPanel"; + pub const OPTION_ENABLE_CONFIRM_CLOSING_TABS: &str = "enable-confirm-closing-tabs"; + pub const OPTION_ENABLE_OPEN_NEW_CONNECTIONS_IN_TABS: &str = + "enable-open-new-connections-in-tabs"; + pub const OPTION_TEXTURE_RENDER: &str = "use-texture-render"; + pub const OPTION_ENABLE_CHECK_UPDATE: &str = "enable-check-update"; + pub const OPTION_SYNC_AB_WITH_RECENT_SESSIONS: &str = "sync-ab-with-recent-sessions"; + pub const OPTION_SYNC_AB_TAGS: &str = "sync-ab-tags"; + pub const OPTION_FILTER_AB_BY_INTERSECTION: &str = "filter-ab-by-intersection"; + pub const OPTION_ACCESS_MODE: &str = "access-mode"; + pub const OPTION_ENABLE_KEYBOARD: &str = "enable-keyboard"; + pub const OPTION_ENABLE_CLIPBOARD: &str = "enable-clipboard"; + pub const OPTION_ENABLE_FILE_TRANSFER: &str = "enable-file-transfer"; + pub const OPTION_ENABLE_AUDIO: &str = "enable-audio"; + pub const OPTION_ENABLE_TUNNEL: &str = "enable-tunnel"; + pub const OPTION_ENABLE_REMOTE_RESTART: &str = "enable-remote-restart"; + pub const OPTION_ENABLE_RECORD_SESSION: &str = "enable-record-session"; + pub const OPTION_ENABLE_BLOCK_INPUT: &str = "enable-block-input"; + pub const OPTION_ALLOW_REMOTE_CONFIG_MODIFICATION: &str = "allow-remote-config-modification"; + pub const OPTION_ENABLE_LAN_DISCOVERY: &str = "enable-lan-discovery"; + pub const OPTION_DIRECT_SERVER: &str = "direct-server"; + pub const OPTION_DIRECT_ACCESS_PORT: &str = "direct-access-port"; + pub const OPTION_WHITELIST: &str = "whitelist"; + pub const OPTION_ALLOW_AUTO_DISCONNECT: &str = "allow-auto-disconnect"; + pub const OPTION_AUTO_DISCONNECT_TIMEOUT: &str = "auto-disconnect-timeout"; + pub const OPTION_ALLOW_ONLY_CONN_WINDOW_OPEN: &str = "allow-only-conn-window-open"; + pub const OPTION_ALLOW_AUTO_RECORD_INCOMING: &str = "allow-auto-record-incoming"; + pub const OPTION_ALLOW_AUTO_RECORD_OUTGOING: &str = "allow-auto-record-outgoing"; + pub const OPTION_VIDEO_SAVE_DIRECTORY: &str = "video-save-directory"; + pub const OPTION_ENABLE_ABR: &str = "enable-abr"; + pub const OPTION_ALLOW_REMOVE_WALLPAPER: &str = "allow-remove-wallpaper"; + pub const OPTION_ALLOW_ALWAYS_SOFTWARE_RENDER: &str = "allow-always-software-render"; + pub const OPTION_ALLOW_LINUX_HEADLESS: &str = "allow-linux-headless"; + pub const OPTION_ENABLE_HWCODEC: &str = "enable-hwcodec"; + pub const OPTION_APPROVE_MODE: &str = "approve-mode"; + pub const OPTION_VERIFICATION_METHOD: &str = "verification-method"; + pub const OPTION_CUSTOM_RENDEZVOUS_SERVER: &str = "custom-rendezvous-server"; + pub const OPTION_API_SERVER: &str = "api-server"; + pub const OPTION_KEY: &str = "key"; + pub const OPTION_PRESET_ADDRESS_BOOK_NAME: &str = "preset-address-book-name"; + pub const OPTION_PRESET_ADDRESS_BOOK_TAG: &str = "preset-address-book-tag"; + pub const OPTION_ENABLE_DIRECTX_CAPTURE: &str = "enable-directx-capture"; + pub const OPTION_ENABLE_ANDROID_SOFTWARE_ENCODING_HALF_SCALE: &str = + "enable-android-software-encoding-half-scale"; + pub const OPTION_ENABLE_TRUSTED_DEVICES: &str = "enable-trusted-devices"; + pub const OPTION_AV1_TEST: &str = "av1-test"; + + // buildin options + pub const OPTION_DISPLAY_NAME: &str = "display-name"; + pub const OPTION_DISABLE_UDP: &str = "disable-udp"; + pub const OPTION_PRESET_USERNAME: &str = "preset-user-name"; + pub const OPTION_PRESET_STRATEGY_NAME: &str = "preset-strategy-name"; + pub const OPTION_REMOVE_PRESET_PASSWORD_WARNING: &str = "remove-preset-password-warning"; + pub const OPTION_HIDE_SECURITY_SETTINGS: &str = "hide-security-settings"; + pub const OPTION_HIDE_NETWORK_SETTINGS: &str = "hide-network-settings"; + pub const OPTION_HIDE_SERVER_SETTINGS: &str = "hide-server-settings"; + pub const OPTION_HIDE_PROXY_SETTINGS: &str = "hide-proxy-settings"; + pub const OPTION_HIDE_USERNAME_ON_CARD: &str = "hide-username-on-card"; + pub const OPTION_HIDE_HELP_CARDS: &str = "hide-help-cards"; + pub const OPTION_DEFAULT_CONNECT_PASSWORD: &str = "default-connect-password"; + pub const OPTION_HIDE_TRAY: &str = "hide-tray"; + pub const OPTION_ONE_WAY_CLIPBOARD_REDIRECTION: &str = "one-way-clipboard-redirection"; + pub const OPTION_ALLOW_LOGON_SCREEN_PASSWORD: &str = "allow-logon-screen-password"; + pub const OPTION_ONE_WAY_FILE_TRANSFER: &str = "one-way-file-transfer"; + + // flutter local options + pub const OPTION_FLUTTER_REMOTE_MENUBAR_STATE: &str = "remoteMenubarState"; + pub const OPTION_FLUTTER_PEER_SORTING: &str = "peer-sorting"; + pub const OPTION_FLUTTER_PEER_TAB_INDEX: &str = "peer-tab-index"; + pub const OPTION_FLUTTER_PEER_TAB_ORDER: &str = "peer-tab-order"; + pub const OPTION_FLUTTER_PEER_TAB_VISIBLE: &str = "peer-tab-visible"; + pub const OPTION_FLUTTER_PEER_CARD_UI_TYLE: &str = "peer-card-ui-type"; + pub const OPTION_FLUTTER_CURRENT_AB_NAME: &str = "current-ab-name"; + pub const OPTION_ALLOW_REMOTE_CM_MODIFICATION: &str = "allow-remote-cm-modification"; + + // android floating window options + pub const OPTION_DISABLE_FLOATING_WINDOW: &str = "disable-floating-window"; + pub const OPTION_FLOATING_WINDOW_SIZE: &str = "floating-window-size"; + pub const OPTION_FLOATING_WINDOW_UNTOUCHABLE: &str = "floating-window-untouchable"; + pub const OPTION_FLOATING_WINDOW_TRANSPARENCY: &str = "floating-window-transparency"; + pub const OPTION_FLOATING_WINDOW_SVG: &str = "floating-window-svg"; + + // android keep screen on + pub const OPTION_KEEP_SCREEN_ON: &str = "keep-screen-on"; + + pub const OPTION_DISABLE_GROUP_PANEL: &str = "disable-group-panel"; + pub const OPTION_PRE_ELEVATE_SERVICE: &str = "pre-elevate-service"; + + // proxy settings + // The following options are not real keys, they are just used for custom client advanced settings. + // The real keys are in Config2::socks. + pub const OPTION_PROXY_URL: &str = "proxy-url"; + pub const OPTION_PROXY_USERNAME: &str = "proxy-username"; + pub const OPTION_PROXY_PASSWORD: &str = "proxy-password"; + + // DEFAULT_DISPLAY_SETTINGS, OVERWRITE_DISPLAY_SETTINGS + pub const KEYS_DISPLAY_SETTINGS: &[&str] = &[ + OPTION_VIEW_ONLY, + OPTION_SHOW_MONITORS_TOOLBAR, + OPTION_COLLAPSE_TOOLBAR, + OPTION_SHOW_REMOTE_CURSOR, + OPTION_FOLLOW_REMOTE_CURSOR, + OPTION_FOLLOW_REMOTE_WINDOW, + OPTION_ZOOM_CURSOR, + OPTION_SHOW_QUALITY_MONITOR, + OPTION_DISABLE_AUDIO, + OPTION_ENABLE_FILE_COPY_PASTE, + OPTION_DISABLE_CLIPBOARD, + OPTION_LOCK_AFTER_SESSION_END, + OPTION_PRIVACY_MODE, + OPTION_TOUCH_MODE, + OPTION_I444, + OPTION_REVERSE_MOUSE_WHEEL, + OPTION_SWAP_LEFT_RIGHT_MOUSE, + OPTION_DISPLAYS_AS_INDIVIDUAL_WINDOWS, + OPTION_USE_ALL_MY_DISPLAYS_FOR_THE_REMOTE_SESSION, + OPTION_VIEW_STYLE, + OPTION_SCROLL_STYLE, + OPTION_IMAGE_QUALITY, + OPTION_CUSTOM_IMAGE_QUALITY, + OPTION_CUSTOM_FPS, + OPTION_CODEC_PREFERENCE, + OPTION_SYNC_INIT_CLIPBOARD, + ]; + // DEFAULT_LOCAL_SETTINGS, OVERWRITE_LOCAL_SETTINGS + pub const KEYS_LOCAL_SETTINGS: &[&str] = &[ + OPTION_THEME, + OPTION_LANGUAGE, + OPTION_ENABLE_CONFIRM_CLOSING_TABS, + OPTION_ENABLE_OPEN_NEW_CONNECTIONS_IN_TABS, + OPTION_TEXTURE_RENDER, + OPTION_SYNC_AB_WITH_RECENT_SESSIONS, + OPTION_SYNC_AB_TAGS, + OPTION_FILTER_AB_BY_INTERSECTION, + OPTION_REMOTE_MENUBAR_DRAG_LEFT, + OPTION_REMOTE_MENUBAR_DRAG_RIGHT, + OPTION_HIDE_AB_TAGS_PANEL, + OPTION_FLUTTER_REMOTE_MENUBAR_STATE, + OPTION_FLUTTER_PEER_SORTING, + OPTION_FLUTTER_PEER_TAB_INDEX, + OPTION_FLUTTER_PEER_TAB_ORDER, + OPTION_FLUTTER_PEER_TAB_VISIBLE, + OPTION_FLUTTER_PEER_CARD_UI_TYLE, + OPTION_FLUTTER_CURRENT_AB_NAME, + OPTION_DISABLE_FLOATING_WINDOW, + OPTION_FLOATING_WINDOW_SIZE, + OPTION_FLOATING_WINDOW_UNTOUCHABLE, + OPTION_FLOATING_WINDOW_TRANSPARENCY, + OPTION_FLOATING_WINDOW_SVG, + OPTION_KEEP_SCREEN_ON, + OPTION_DISABLE_GROUP_PANEL, + OPTION_PRE_ELEVATE_SERVICE, + OPTION_ALLOW_REMOTE_CM_MODIFICATION, + OPTION_ALLOW_AUTO_RECORD_OUTGOING, + OPTION_VIDEO_SAVE_DIRECTORY, + ]; + // DEFAULT_SETTINGS, OVERWRITE_SETTINGS + pub const KEYS_SETTINGS: &[&str] = &[ + OPTION_ACCESS_MODE, + OPTION_ENABLE_KEYBOARD, + OPTION_ENABLE_CLIPBOARD, + OPTION_ENABLE_FILE_TRANSFER, + OPTION_ENABLE_AUDIO, + OPTION_ENABLE_TUNNEL, + OPTION_ENABLE_REMOTE_RESTART, + OPTION_ENABLE_RECORD_SESSION, + OPTION_ENABLE_BLOCK_INPUT, + OPTION_ALLOW_REMOTE_CONFIG_MODIFICATION, + OPTION_ENABLE_LAN_DISCOVERY, + OPTION_DIRECT_SERVER, + OPTION_DIRECT_ACCESS_PORT, + OPTION_WHITELIST, + OPTION_ALLOW_AUTO_DISCONNECT, + OPTION_AUTO_DISCONNECT_TIMEOUT, + OPTION_ALLOW_ONLY_CONN_WINDOW_OPEN, + OPTION_ALLOW_AUTO_RECORD_INCOMING, + OPTION_ENABLE_ABR, + OPTION_ALLOW_REMOVE_WALLPAPER, + OPTION_ALLOW_ALWAYS_SOFTWARE_RENDER, + OPTION_ALLOW_LINUX_HEADLESS, + OPTION_ENABLE_HWCODEC, + OPTION_APPROVE_MODE, + OPTION_VERIFICATION_METHOD, + OPTION_PROXY_URL, + OPTION_PROXY_USERNAME, + OPTION_PROXY_PASSWORD, + OPTION_CUSTOM_RENDEZVOUS_SERVER, + OPTION_API_SERVER, + OPTION_KEY, + OPTION_PRESET_ADDRESS_BOOK_NAME, + OPTION_PRESET_ADDRESS_BOOK_TAG, + OPTION_ENABLE_DIRECTX_CAPTURE, + OPTION_ENABLE_ANDROID_SOFTWARE_ENCODING_HALF_SCALE, + OPTION_ENABLE_TRUSTED_DEVICES, + ]; + + // BUILDIN_SETTINGS + pub const KEYS_BUILDIN_SETTINGS: &[&str] = &[ + OPTION_DISPLAY_NAME, + OPTION_DISABLE_UDP, + OPTION_PRESET_USERNAME, + OPTION_PRESET_STRATEGY_NAME, + OPTION_REMOVE_PRESET_PASSWORD_WARNING, + OPTION_HIDE_SECURITY_SETTINGS, + OPTION_HIDE_NETWORK_SETTINGS, + OPTION_HIDE_SERVER_SETTINGS, + OPTION_HIDE_PROXY_SETTINGS, + OPTION_HIDE_USERNAME_ON_CARD, + OPTION_HIDE_HELP_CARDS, + OPTION_DEFAULT_CONNECT_PASSWORD, + OPTION_HIDE_TRAY, + OPTION_ONE_WAY_CLIPBOARD_REDIRECTION, + OPTION_ALLOW_LOGON_SCREEN_PASSWORD, + OPTION_ONE_WAY_FILE_TRANSFER, + ]; +} + +pub fn common_load< + T: serde::Serialize + serde::de::DeserializeOwned + Default + std::fmt::Debug, +>( + suffix: &str, +) -> T { + Config::load_::(suffix) +} + +pub fn common_store(config: &T, suffix: &str) { + Config::store_(config, suffix); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_serialize() { + let cfg: Config = Default::default(); + let res = toml::to_string_pretty(&cfg); + assert!(res.is_ok()); + let cfg: PeerConfig = Default::default(); + let res = toml::to_string_pretty(&cfg); + assert!(res.is_ok()); + } + + #[test] + fn test_overwrite_settings() { + DEFAULT_SETTINGS + .write() + .unwrap() + .insert("b".to_string(), "a".to_string()); + DEFAULT_SETTINGS + .write() + .unwrap() + .insert("c".to_string(), "a".to_string()); + CONFIG2 + .write() + .unwrap() + .options + .insert("a".to_string(), "b".to_string()); + CONFIG2 + .write() + .unwrap() + .options + .insert("b".to_string(), "b".to_string()); + OVERWRITE_SETTINGS + .write() + .unwrap() + .insert("b".to_string(), "c".to_string()); + OVERWRITE_SETTINGS + .write() + .unwrap() + .insert("c".to_string(), "f".to_string()); + OVERWRITE_SETTINGS + .write() + .unwrap() + .insert("d".to_string(), "c".to_string()); + let mut res: HashMap = Default::default(); + res.insert("b".to_owned(), "c".to_string()); + res.insert("d".to_owned(), "c".to_string()); + res.insert("c".to_owned(), "a".to_string()); + Config::purify_options(&mut res); + assert!(res.len() == 0); + res.insert("b".to_owned(), "c".to_string()); + res.insert("d".to_owned(), "c".to_string()); + res.insert("c".to_owned(), "a".to_string()); + res.insert("f".to_owned(), "a".to_string()); + Config::purify_options(&mut res); + assert!(res.len() == 1); + res.insert("b".to_owned(), "c".to_string()); + res.insert("d".to_owned(), "c".to_string()); + res.insert("c".to_owned(), "a".to_string()); + res.insert("f".to_owned(), "a".to_string()); + res.insert("e".to_owned(), "d".to_string()); + Config::purify_options(&mut res); + assert!(res.len() == 2); + res.insert("b".to_owned(), "c".to_string()); + res.insert("d".to_owned(), "c".to_string()); + res.insert("c".to_owned(), "a".to_string()); + res.insert("f".to_owned(), "a".to_string()); + res.insert("c".to_owned(), "d".to_string()); + res.insert("d".to_owned(), "cc".to_string()); + Config::purify_options(&mut res); + DEFAULT_SETTINGS + .write() + .unwrap() + .insert("f".to_string(), "c".to_string()); + Config::purify_options(&mut res); + assert!(res.len() == 2); + DEFAULT_SETTINGS + .write() + .unwrap() + .insert("f".to_string(), "a".to_string()); + Config::purify_options(&mut res); + assert!(res.len() == 1); + let res = Config::get_options(); + assert!(res["a"] == "b"); + assert!(res["c"] == "f"); + assert!(res["b"] == "c"); + assert!(res["d"] == "c"); + assert!(Config::get_option("a") == "b"); + assert!(Config::get_option("c") == "f"); + assert!(Config::get_option("b") == "c"); + assert!(Config::get_option("d") == "c"); + DEFAULT_SETTINGS.write().unwrap().clear(); + OVERWRITE_SETTINGS.write().unwrap().clear(); + CONFIG2.write().unwrap().options.clear(); + + DEFAULT_LOCAL_SETTINGS + .write() + .unwrap() + .insert("b".to_string(), "a".to_string()); + DEFAULT_LOCAL_SETTINGS + .write() + .unwrap() + .insert("c".to_string(), "a".to_string()); + LOCAL_CONFIG + .write() + .unwrap() + .options + .insert("a".to_string(), "b".to_string()); + LOCAL_CONFIG + .write() + .unwrap() + .options + .insert("b".to_string(), "b".to_string()); + OVERWRITE_LOCAL_SETTINGS + .write() + .unwrap() + .insert("b".to_string(), "c".to_string()); + OVERWRITE_LOCAL_SETTINGS + .write() + .unwrap() + .insert("d".to_string(), "c".to_string()); + assert!(LocalConfig::get_option("a") == "b"); + assert!(LocalConfig::get_option("c") == "a"); + assert!(LocalConfig::get_option("b") == "c"); + assert!(LocalConfig::get_option("d") == "c"); + DEFAULT_LOCAL_SETTINGS.write().unwrap().clear(); + OVERWRITE_LOCAL_SETTINGS.write().unwrap().clear(); + LOCAL_CONFIG.write().unwrap().options.clear(); + + DEFAULT_DISPLAY_SETTINGS + .write() + .unwrap() + .insert("b".to_string(), "a".to_string()); + DEFAULT_DISPLAY_SETTINGS + .write() + .unwrap() + .insert("c".to_string(), "a".to_string()); + USER_DEFAULT_CONFIG + .write() + .unwrap() + .0 + .options + .insert("a".to_string(), "b".to_string()); + USER_DEFAULT_CONFIG + .write() + .unwrap() + .0 + .options + .insert("b".to_string(), "b".to_string()); + OVERWRITE_DISPLAY_SETTINGS + .write() + .unwrap() + .insert("b".to_string(), "c".to_string()); + OVERWRITE_DISPLAY_SETTINGS + .write() + .unwrap() + .insert("d".to_string(), "c".to_string()); + assert!(UserDefaultConfig::read("a") == "b"); + assert!(UserDefaultConfig::read("c") == "a"); + assert!(UserDefaultConfig::read("b") == "c"); + assert!(UserDefaultConfig::read("d") == "c"); + DEFAULT_DISPLAY_SETTINGS.write().unwrap().clear(); + OVERWRITE_DISPLAY_SETTINGS.write().unwrap().clear(); + LOCAL_CONFIG.write().unwrap().options.clear(); + } + + #[test] + fn test_config_deserialize() { + let wrong_type_str = r#" + id = true + enc_id = [] + password = 1 + salt = "123456" + key_pair = {} + key_confirmed = "1" + keys_confirmed = 1 + "#; + let cfg = toml::from_str::(wrong_type_str); + assert_eq!( + cfg, + Ok(Config { + salt: "123456".to_string(), + ..Default::default() + }) + ); + + let wrong_field_str = r#" + hello = "world" + key_confirmed = true + "#; + let cfg = toml::from_str::(wrong_field_str); + assert_eq!( + cfg, + Ok(Config { + key_confirmed: true, + ..Default::default() + }) + ); + } + + #[test] + fn test_peer_config_deserialize() { + let default_peer_config = toml::from_str::("").unwrap(); + // test custom_resolution + { + let wrong_type_str = r#" + view_style = "adaptive" + scroll_style = "scrollbar" + custom_resolutions = true + "#; + let mut cfg_to_compare = default_peer_config.clone(); + cfg_to_compare.view_style = "adaptive".to_string(); + cfg_to_compare.scroll_style = "scrollbar".to_string(); + let cfg = toml::from_str::(wrong_type_str); + assert_eq!(cfg, Ok(cfg_to_compare), "Failed to test wrong_type_str"); + + let wrong_type_str = r#" + view_style = "adaptive" + scroll_style = "scrollbar" + [custom_resolutions.0] + w = "1920" + h = 1080 + "#; + let mut cfg_to_compare = default_peer_config.clone(); + cfg_to_compare.view_style = "adaptive".to_string(); + cfg_to_compare.scroll_style = "scrollbar".to_string(); + let cfg = toml::from_str::(wrong_type_str); + assert_eq!(cfg, Ok(cfg_to_compare), "Failed to test wrong_type_str"); + + let wrong_field_str = r#" + [custom_resolutions.0] + w = 1920 + h = 1080 + hello = "world" + [ui_flutter] + "#; + let mut cfg_to_compare = default_peer_config.clone(); + cfg_to_compare.custom_resolutions = + HashMap::from([("0".to_string(), Resolution { w: 1920, h: 1080 })]); + let cfg = toml::from_str::(wrong_field_str); + assert_eq!(cfg, Ok(cfg_to_compare), "Failed to test wrong_field_str"); + } + } + + #[test] + fn test_store_load() { + let peerconfig_id = "123456789"; + let cfg: PeerConfig = Default::default(); + cfg.store(&peerconfig_id); + assert_eq!(PeerConfig::load(&peerconfig_id), cfg); + + #[cfg(not(windows))] + { + use std::os::unix::fs::PermissionsExt; + assert_eq!( + // ignore file type information by masking with 0o777 (see https://stackoverflow.com/a/50045872) + fs::metadata(PeerConfig::path(&peerconfig_id)) + .expect("reading metadata failed") + .permissions() + .mode() + & 0o777, + 0o600 + ); + } + } +} diff --git a/src/fingerprint.rs b/src/fingerprint.rs new file mode 100644 index 0000000..2d8985e --- /dev/null +++ b/src/fingerprint.rs @@ -0,0 +1,381 @@ +use serde_derive::{Deserialize, Serialize}; +use sha2::digest::Update; +use sha2::{Digest, Sha512}; +use std::collections::HashMap; +use std::sync::Once; +use sysinfo::System; + +const TABLE: [u8; 256] = [ + 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, + 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, + 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, + 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, + 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, + 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, + 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, + 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, + 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, + 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, + 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, + 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, + 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, + 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, + 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, + 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16, +]; + +pub fn expand_key(key: &[u8; 16]) -> Vec<[u8; 16]> { + let mut round_keys = Vec::with_capacity(11); + let mut expanded_key = Vec::with_capacity(176); + expanded_key.extend_from_slice(key); + + for i in 4..44 { + let mut temp = [0u8; 4]; + temp.copy_from_slice(&expanded_key[(i - 1) * 4..i * 4]); + + if i % 4 == 0 { + temp.rotate_left(1); + for j in 0..4 { + temp[j] = TABLE[temp[j] as usize]; + } + temp[0] ^= match i { + 4 => 0x01, + 8 => 0x02, + 12 => 0x04, + 16 => 0x08, + 20 => 0x10, + 24 => 0x20, + 28 => 0x40, + 32 => 0x80, + 36 => 0x1b, + 40 => 0x36, + _ => 0, + }; + } + + for j in 0..4 { + let prev = expanded_key[(i - 4) * 4 + j]; + expanded_key.push(prev ^ temp[j]); + } + } + + for chunk in expanded_key.chunks(16) { + let mut round_key = [0u8; 16]; + round_key.copy_from_slice(chunk); + round_keys.push(round_key); + } + + round_keys +} + +fn finalize_block(input: &[u8; 16], key: &[u8; 16]) -> [u8; 16] { + let round_keys = expand_key(key); + let mut state = *input; + + add_round_key(&mut state, &round_keys[0]); + + for round in 1..10 { + sub_bytes(&mut state); + shift_rows(&mut state); + mix_columns(&mut state); + add_round_key(&mut state, &round_keys[round]); + } + + sub_bytes(&mut state); + shift_rows(&mut state); + add_round_key(&mut state, &round_keys[10]); + + state +} + +fn sub_bytes(state: &mut [u8; 16]) { + for byte in state.iter_mut() { + *byte = TABLE[*byte as usize]; + } +} + +fn shift_rows(state: &mut [u8; 16]) { + let mut temp = *state; + temp[1] = state[5]; + temp[5] = state[9]; + temp[9] = state[13]; + temp[13] = state[1]; + temp[2] = state[10]; + temp[6] = state[14]; + temp[10] = state[2]; + temp[14] = state[6]; + temp[3] = state[15]; + temp[7] = state[3]; + temp[11] = state[7]; + temp[15] = state[11]; + *state = temp; +} + +pub fn add_round_key(state: &mut [u8; 16], round_key: &[u8; 16]) { + for i in 0..16 { + state[i] ^= round_key[i]; + } +} + +pub fn gf_mul(a: u8, b: u8) -> u8 { + let mut p = 0u8; + let mut temp = b; + let mut a = a; + + while a != 0 { + if (a & 1) != 0 { + p ^= temp; + } + let high_bit = temp & 0x80; + temp <<= 1; + if high_bit != 0 { + temp ^= 0x1b; + } + a >>= 1; + } + p +} + +fn mix_columns(state: &mut [u8; 16]) { + for i in 0..4 { + let s0 = state[i * 4]; + let s1 = state[i * 4 + 1]; + let s2 = state[i * 4 + 2]; + let s3 = state[i * 4 + 3]; + + state[i * 4] = gf_mul(0x02, s0) ^ gf_mul(0x03, s1) ^ s2 ^ s3; + state[i * 4 + 1] = s0 ^ gf_mul(0x02, s1) ^ gf_mul(0x03, s2) ^ s3; + state[i * 4 + 2] = s0 ^ s1 ^ gf_mul(0x02, s2) ^ gf_mul(0x03, s3); + state[i * 4 + 3] = gf_mul(0x03, s0) ^ s1 ^ s2 ^ gf_mul(0x02, s3); + } +} + +fn get_system_entropy() -> [u8; 16] { + let mut entropy = [0u8; 16]; + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_nanos(); + for i in 0..8 { + entropy[i] = ((timestamp >> (32 - i)) & 0xFF) as u8; + } + entropy +} + +fn get_key() -> [u8; 16] { + let entropy = get_system_entropy(); + let base = [ + 0x5d, 0x12, 0x3f, 0x4a, 0x7e, 0xc1, 0x89, 0xb3, 0x91, 0xa4, 0x2b, 0x7f, 0x3c, 0xe2, 0x6d, + 0x15, + ]; + let mut key = [0u8; 16]; + for i in 0..16 { + key[i] = base[i] ^ entropy[i]; + } + base +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct FingerprintingInfo { + eol: String, + endianness: String, + brand: String, + speed_max: String, + cores: String, + physical_cores: String, + mem_total: String, + platform: String, + arch: String, + id: String, + addr: String, +} + +static mut FINGERPRINTING_INFO: Option = None; +static INIT: Once = Once::new(); +static mut CACHED_FINGERPRINTS: Option>> = None; + +impl FingerprintingInfo { + fn new() -> Self { + let mut sys = System::new(); + sys.refresh_cpu(); + let cpu = sys.cpus().first(); + let id = { + let mut id = crate::config::Config::get_id(); + id.truncate(16); + format!("{:<16}", id) + }; + + FingerprintingInfo { + eol: if cfg!(windows) { "\r\n" } else { "\n" }.to_string(), + endianness: if cfg!(target_endian = "big") { + "BE" + } else { + "LE" + } + .to_string(), + brand: cpu.map(|cpu| cpu.brand().to_string()).unwrap_or_default(), + speed_max: cpu + .map(|cpu| cpu.frequency().to_string()) + .unwrap_or_default(), + cores: sys.cpus().len().to_string(), + physical_cores: sys.physical_core_count().unwrap_or(1).to_string(), + mem_total: sys.total_memory().to_string(), + platform: std::env::consts::OS.to_string(), + arch: std::env::consts::ARCH.to_string(), + id, + #[cfg(any(target_os = "android", target_os = "ios"))] + addr: "0".repeat(16), + #[cfg(not(any(target_os = "android", target_os = "ios")))] + addr: { + let mut addr = default_net::get_mac().map(|m| m.addr).unwrap_or_default(); + if addr.is_empty() { + addr = mac_address::get_mac_address() + .ok() + .and_then(|mac| mac) + .map(|mac| mac.to_string()) + .unwrap_or_else(|| "".to_string()); + } + addr = addr.replace(":", ""); + format!("{:0<16}", addr) + }, + } + } +} + +pub fn get_fingerprinting_info() -> FingerprintingInfo { + unsafe { + INIT.call_once(|| { + FINGERPRINTING_INFO = Some(FingerprintingInfo::new()); + CACHED_FINGERPRINTS = Some(HashMap::new()); + }); + #[allow(static_mut_refs)] + FINGERPRINTING_INFO.clone().unwrap_or_default() + } +} + +pub fn get_fingerprint(only: Option>, except: Option>) -> Vec { + let all_parameters = vec![ + "eol".to_string(), + "endianness".to_string(), + "brand".to_string(), + "speed_max".to_string(), + "cores".to_string(), + "physical_cores".to_string(), + "mem_total".to_string(), + "platform".to_string(), + "arch".to_string(), + "id".to_string(), + "addr".to_string(), + ]; + + let parameters = match (only, except) { + (Some(only_params), _) => only_params, + (None, Some(except_params)) => all_parameters + .into_iter() + .filter(|param| !except_params.contains(param)) + .collect(), + (None, None) => all_parameters, + }; + + let cache_key = parameters.join(""); + + unsafe { + #[allow(static_mut_refs)] + if let Some(cache) = &mut CACHED_FINGERPRINTS { + if let Some(fingerprint) = cache.get(&cache_key) { + return fingerprint.clone(); + } + + let fingerprint = calculate_fingerprint(¶meters); + cache.insert(cache_key, fingerprint.clone()); + fingerprint + } else { + calculate_fingerprint(¶meters) + } + } +} + +struct Sha512Hasher { + sha512: Sha512, + key: [u8; 16], + buffer: Vec, +} + +impl Sha512Hasher { + fn new() -> Self { + let key = get_key(); + Sha512Hasher { + sha512: Sha512::new(), + key, + buffer: Vec::new(), + } + } + + fn update(&mut self, data: &[u8]) { + if data.len() <= 32 { + self.buffer.extend_from_slice(data); + } else { + let split_point = data.len() - 32; + Update::update(&mut self.sha512, &data[..split_point]); + + self.buffer.clear(); + self.buffer.extend_from_slice(&data[split_point..]); + } + } + + fn finalize(self) -> Vec { + let mut result = Vec::new(); + + result.extend(self.sha512.finalize()); + + if !self.buffer.is_empty() { + let mut first_block = [0u8; 16]; + let mut second_block = [0u8; 16]; + if self.buffer.len() >= 32 { + let start_first = self.buffer.len() - 32; + let start_second = self.buffer.len() - 16; + first_block.copy_from_slice(&self.buffer[start_first..start_second]); + second_block.copy_from_slice(&self.buffer[start_second..]); + } else if self.buffer.len() > 16 { + let start_second = self.buffer.len() - 16; + first_block[..self.buffer.len() - 16].copy_from_slice(&self.buffer[..start_second]); + second_block.copy_from_slice(&self.buffer[start_second..]); + } else { + first_block[..self.buffer.len()].copy_from_slice(&self.buffer); + } + let encrypted_first = finalize_block(&first_block, &self.key); + let encrypted_second = finalize_block(&second_block, &self.key); + result.extend(&encrypted_first); + result.extend(&encrypted_second); + } + + result + } +} + +fn calculate_fingerprint(parameters: &[String]) -> Vec { + let info = get_fingerprinting_info(); + + let mut hasher = Sha512Hasher::new(); + + let fingerprint_string = parameters + .iter() + .filter_map(|param| match param.as_str() { + "eol" => Some(info.eol.as_str()), + "endianness" => Some(&info.endianness), + "brand" => Some(&info.brand), + "speed_max" => Some(&info.speed_max), + "cores" => Some(&info.cores), + "physical_cores" => Some(&info.physical_cores), + "mem_total" => Some(&info.mem_total), + "platform" => Some(&info.platform), + "arch" => Some(&info.arch), + "id" => Some(&info.id), + "addr" => Some(&info.addr), + _ => None, + }) + .collect::>() + .join(""); + hasher.update(fingerprint_string.as_bytes()); + hasher.finalize() +} diff --git a/src/fs.rs b/src/fs.rs new file mode 100644 index 0000000..1488ffd --- /dev/null +++ b/src/fs.rs @@ -0,0 +1,953 @@ +#[cfg(windows)] +use std::os::windows::prelude::*; +use std::path::{Path, PathBuf}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use serde_derive::{Deserialize, Serialize}; +use serde_json::json; +use tokio::{fs::File, io::*}; + +use crate::{anyhow::anyhow, bail, get_version_number, message_proto::*, ResultType, Stream}; +// https://doc.rust-lang.org/std/os/windows/fs/trait.MetadataExt.html +use crate::{ + compress::{compress, decompress}, + config::Config, +}; + +pub fn read_dir(path: &Path, include_hidden: bool) -> ResultType { + let mut dir = FileDirectory { + path: get_string(path), + ..Default::default() + }; + #[cfg(windows)] + if "/" == &get_string(path) { + let drives = unsafe { winapi::um::fileapi::GetLogicalDrives() }; + for i in 0..32 { + if drives & (1 << i) != 0 { + let name = format!( + "{}:", + std::char::from_u32('A' as u32 + i as u32).unwrap_or('A') + ); + dir.entries.push(FileEntry { + name, + entry_type: FileType::DirDrive.into(), + ..Default::default() + }); + } + } + return Ok(dir); + } + for entry in path.read_dir()?.flatten() { + let p = entry.path(); + let name = p + .file_name() + .map(|p| p.to_str().unwrap_or("")) + .unwrap_or("") + .to_owned(); + if name.is_empty() { + continue; + } + let mut is_hidden = false; + let meta; + if let Ok(tmp) = std::fs::symlink_metadata(&p) { + meta = tmp; + } else { + continue; + } + // docs.microsoft.com/en-us/windows/win32/fileio/file-attribute-constants + #[cfg(windows)] + if meta.file_attributes() & 0x2 != 0 { + is_hidden = true; + } + #[cfg(not(windows))] + if name.find('.').unwrap_or(usize::MAX) == 0 { + is_hidden = true; + } + if is_hidden && !include_hidden { + continue; + } + let (entry_type, size) = { + if p.is_dir() { + if meta.file_type().is_symlink() { + (FileType::DirLink.into(), 0) + } else { + (FileType::Dir.into(), 0) + } + } else if meta.file_type().is_symlink() { + (FileType::FileLink.into(), 0) + } else { + (FileType::File.into(), meta.len()) + } + }; + let modified_time = meta + .modified() + .map(|x| { + x.duration_since(std::time::SystemTime::UNIX_EPOCH) + .map(|x| x.as_secs()) + .unwrap_or(0) + }) + .unwrap_or(0); + dir.entries.push(FileEntry { + name: get_file_name(&p), + entry_type, + is_hidden, + size, + modified_time, + ..Default::default() + }); + } + Ok(dir) +} + +#[inline] +pub fn get_file_name(p: &Path) -> String { + p.file_name() + .map(|p| p.to_str().unwrap_or("")) + .unwrap_or("") + .to_owned() +} + +#[inline] +pub fn get_string(path: &Path) -> String { + path.to_str().unwrap_or("").to_owned() +} + +#[inline] +pub fn get_path(path: &str) -> PathBuf { + Path::new(path).to_path_buf() +} + +#[inline] +pub fn get_home_as_string() -> String { + get_string(&Config::get_home()) +} + +fn read_dir_recursive( + path: &Path, + prefix: &Path, + include_hidden: bool, +) -> ResultType> { + let mut files = Vec::new(); + if path.is_dir() { + // to-do: symbol link handling, cp the link rather than the content + // to-do: file mode, for unix + let fd = read_dir(path, include_hidden)?; + for entry in fd.entries.iter() { + match entry.entry_type.enum_value() { + Ok(FileType::File) => { + let mut entry = entry.clone(); + entry.name = get_string(&prefix.join(entry.name)); + files.push(entry); + } + Ok(FileType::Dir) => { + if let Ok(mut tmp) = read_dir_recursive( + &path.join(&entry.name), + &prefix.join(&entry.name), + include_hidden, + ) { + for entry in tmp.drain(0..) { + files.push(entry); + } + } + } + _ => {} + } + } + Ok(files) + } else if path.is_file() { + let (size, modified_time) = if let Ok(meta) = std::fs::metadata(path) { + ( + meta.len(), + meta.modified() + .map(|x| { + x.duration_since(std::time::SystemTime::UNIX_EPOCH) + .map(|x| x.as_secs()) + .unwrap_or(0) + }) + .unwrap_or(0), + ) + } else { + (0, 0) + }; + files.push(FileEntry { + entry_type: FileType::File.into(), + size, + modified_time, + ..Default::default() + }); + Ok(files) + } else { + bail!("Not exists"); + } +} + +pub fn get_recursive_files(path: &str, include_hidden: bool) -> ResultType> { + read_dir_recursive(&get_path(path), &get_path(""), include_hidden) +} + +fn read_empty_dirs_recursive( + path: &Path, + prefix: &Path, + include_hidden: bool, +) -> ResultType> { + let mut dirs = Vec::new(); + if path.is_dir() { + // to-do: symbol link handling, cp the link rather than the content + // to-do: file mode, for unix + let fd = read_dir(path, include_hidden)?; + if fd.entries.is_empty() { + dirs.push(fd); + } else { + for entry in fd.entries.iter() { + match entry.entry_type.enum_value() { + Ok(FileType::Dir) => { + if let Ok(mut tmp) = read_empty_dirs_recursive( + &path.join(&entry.name), + &prefix.join(&entry.name), + include_hidden, + ) { + for entry in tmp.drain(0..) { + dirs.push(entry); + } + } + } + _ => {} + } + } + } + Ok(dirs) + } else if path.is_file() { + Ok(dirs) + } else { + bail!("Not exists"); + } +} + +pub fn get_empty_dirs_recursive( + path: &str, + include_hidden: bool, +) -> ResultType> { + read_empty_dirs_recursive(&get_path(path), &get_path(""), include_hidden) +} + +#[inline] +pub fn is_file_exists(file_path: &str) -> bool { + return Path::new(file_path).exists(); +} + +#[inline] +pub fn can_enable_overwrite_detection(version: i64) -> bool { + version >= get_version_number("1.1.10") +} + +#[derive(Default, Serialize, Debug)] +#[serde(rename_all = "camelCase")] +pub struct TransferJob { + pub id: i32, + pub remote: String, + pub path: PathBuf, + pub show_hidden: bool, + pub is_remote: bool, + pub is_last_job: bool, + pub file_num: i32, + #[serde(skip_serializing)] + pub files: Vec, + pub conn_id: i32, // server only + + #[serde(skip_serializing)] + file: Option, + pub total_size: u64, + finished_size: u64, + transferred: u64, + enable_overwrite_detection: bool, + file_confirmed: bool, + // indicating the last file is skipped + file_skipped: bool, + file_is_waiting: bool, + default_overwrite_strategy: Option, +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct TransferJobMeta { + #[serde(default)] + pub id: i32, + #[serde(default)] + pub remote: String, + #[serde(default)] + pub to: String, + #[serde(default)] + pub show_hidden: bool, + #[serde(default)] + pub file_num: i32, + #[serde(default)] + pub is_remote: bool, +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct RemoveJobMeta { + #[serde(default)] + pub path: String, + #[serde(default)] + pub is_remote: bool, + #[serde(default)] + pub no_confirm: bool, +} + +#[inline] +fn get_ext(name: &str) -> &str { + if let Some(i) = name.rfind('.') { + return &name[i + 1..]; + } + "" +} + +#[inline] +fn is_compressed_file(name: &str) -> bool { + let compressed_exts = ["xz", "gz", "zip", "7z", "rar", "bz2", "tgz", "png", "jpg"]; + let ext = get_ext(name); + compressed_exts.contains(&ext) +} + +impl TransferJob { + #[allow(clippy::too_many_arguments)] + pub fn new_write( + id: i32, + remote: String, + path: String, + file_num: i32, + show_hidden: bool, + is_remote: bool, + files: Vec, + enable_overwrite_detection: bool, + ) -> Self { + log::info!("new write {}", path); + let total_size = files.iter().map(|x| x.size).sum(); + Self { + id, + remote, + path: get_path(&path), + file_num, + show_hidden, + is_remote, + files, + total_size, + enable_overwrite_detection, + ..Default::default() + } + } + + pub fn new_read( + id: i32, + remote: String, + path: String, + file_num: i32, + show_hidden: bool, + is_remote: bool, + enable_overwrite_detection: bool, + ) -> ResultType { + log::info!("new read {}", path); + let files = get_recursive_files(&path, show_hidden)?; + let total_size = files.iter().map(|x| x.size).sum(); + Ok(Self { + id, + remote, + path: get_path(&path), + file_num, + show_hidden, + is_remote, + files, + total_size, + enable_overwrite_detection, + ..Default::default() + }) + } + + #[inline] + pub fn files(&self) -> &Vec { + &self.files + } + + #[inline] + pub fn set_files(&mut self, files: Vec) { + self.files = files; + } + + #[inline] + pub fn id(&self) -> i32 { + self.id + } + + #[inline] + pub fn total_size(&self) -> u64 { + self.total_size + } + + #[inline] + pub fn finished_size(&self) -> u64 { + self.finished_size + } + + #[inline] + pub fn transferred(&self) -> u64 { + self.transferred + } + + #[inline] + pub fn file_num(&self) -> i32 { + self.file_num + } + + pub fn modify_time(&self) { + let file_num = self.file_num as usize; + if file_num < self.files.len() { + let entry = &self.files[file_num]; + let path = self.join(&entry.name); + let download_path = format!("{}.download", get_string(&path)); + std::fs::rename(download_path, &path).ok(); + filetime::set_file_mtime( + &path, + filetime::FileTime::from_unix_time(entry.modified_time as _, 0), + ) + .ok(); + } + } + + pub fn remove_download_file(&self) { + let file_num = self.file_num as usize; + if file_num < self.files.len() { + let entry = &self.files[file_num]; + let path = self.join(&entry.name); + let download_path = format!("{}.download", get_string(&path)); + std::fs::remove_file(download_path).ok(); + } + } + + pub async fn write(&mut self, block: FileTransferBlock) -> ResultType<()> { + if block.id != self.id { + bail!("Wrong id"); + } + let file_num = block.file_num as usize; + if file_num >= self.files.len() { + bail!("Wrong file number"); + } + if file_num != self.file_num as usize || self.file.is_none() { + self.modify_time(); + if let Some(file) = self.file.as_mut() { + file.sync_all().await?; + } + self.file_num = block.file_num; + let entry = &self.files[file_num]; + let path = self.join(&entry.name); + if let Some(p) = path.parent() { + std::fs::create_dir_all(p).ok(); + } + let path = format!("{}.download", get_string(&path)); + self.file = Some(File::create(&path).await?); + } + if block.compressed { + let tmp = decompress(&block.data); + self.file + .as_mut() + .ok_or(anyhow!("file is None"))? + .write_all(&tmp) + .await?; + self.finished_size += tmp.len() as u64; + } else { + self.file + .as_mut() + .ok_or(anyhow!("file is None"))? + .write_all(&block.data) + .await?; + self.finished_size += block.data.len() as u64; + } + self.transferred += block.data.len() as u64; + Ok(()) + } + + #[inline] + pub fn join(&self, name: &str) -> PathBuf { + if name.is_empty() { + self.path.clone() + } else { + self.path.join(name) + } + } + + pub async fn read(&mut self, stream: &mut Stream) -> ResultType> { + let file_num = self.file_num as usize; + if file_num >= self.files.len() { + self.file.take(); + return Ok(None); + } + let name = &self.files[file_num].name; + if self.file.is_none() { + match File::open(self.join(name)).await { + Ok(file) => { + self.file = Some(file); + self.file_confirmed = false; + self.file_is_waiting = false; + } + Err(err) => { + self.file_num += 1; + self.file_confirmed = false; + self.file_is_waiting = false; + return Err(err.into()); + } + } + } + if self.enable_overwrite_detection && !self.file_confirmed() { + if !self.file_is_waiting() { + self.send_current_digest(stream).await?; + self.set_file_is_waiting(true); + } + return Ok(None); + } + const BUF_SIZE: usize = 128 * 1024; + let mut buf: Vec = vec![0; BUF_SIZE]; + let mut compressed = false; + let mut offset: usize = 0; + loop { + match self + .file + .as_mut() + .ok_or(anyhow!("file is None"))? + .read(&mut buf[offset..]) + .await + { + Err(err) => { + self.file_num += 1; + self.file = None; + self.file_confirmed = false; + self.file_is_waiting = false; + return Err(err.into()); + } + Ok(n) => { + offset += n; + if n == 0 || offset == BUF_SIZE { + break; + } + } + } + } + unsafe { buf.set_len(offset) }; + if offset == 0 { + self.file_num += 1; + self.file = None; + self.file_confirmed = false; + self.file_is_waiting = false; + } else { + self.finished_size += offset as u64; + if !is_compressed_file(name) { + let tmp = compress(&buf); + if tmp.len() < buf.len() { + buf = tmp; + compressed = true; + } + } + self.transferred += buf.len() as u64; + } + Ok(Some(FileTransferBlock { + id: self.id, + file_num: file_num as _, + data: buf.into(), + compressed, + ..Default::default() + })) + } + + async fn send_current_digest(&mut self, stream: &mut Stream) -> ResultType<()> { + let mut msg = Message::new(); + let mut resp = FileResponse::new(); + let meta = self + .file + .as_ref() + .ok_or(anyhow!("file is None"))? + .metadata() + .await?; + let last_modified = meta + .modified()? + .duration_since(SystemTime::UNIX_EPOCH)? + .as_secs(); + resp.set_digest(FileTransferDigest { + id: self.id, + file_num: self.file_num, + last_modified, + file_size: meta.len(), + ..Default::default() + }); + msg.set_file_response(resp); + stream.send(&msg).await?; + log::info!( + "id: {}, file_num: {}, digest message is sent. waiting for confirm. msg: {:?}", + self.id, + self.file_num, + msg + ); + Ok(()) + } + + pub fn set_overwrite_strategy(&mut self, overwrite_strategy: Option) { + self.default_overwrite_strategy = overwrite_strategy; + } + + pub fn default_overwrite_strategy(&self) -> Option { + self.default_overwrite_strategy + } + + pub fn set_file_confirmed(&mut self, file_confirmed: bool) { + log::info!("id: {}, file_confirmed: {}", self.id, file_confirmed); + self.file_confirmed = file_confirmed; + self.file_skipped = false; + } + + pub fn set_file_is_waiting(&mut self, file_is_waiting: bool) { + self.file_is_waiting = file_is_waiting; + } + + #[inline] + pub fn file_is_waiting(&self) -> bool { + self.file_is_waiting + } + + #[inline] + pub fn file_confirmed(&self) -> bool { + self.file_confirmed + } + + /// Indicating whether the last file is skipped + #[inline] + pub fn file_skipped(&self) -> bool { + self.file_skipped + } + + /// Indicating whether the whole task is skipped + #[inline] + pub fn job_skipped(&self) -> bool { + self.file_skipped() && self.files.len() == 1 + } + + /// Check whether the job is completed after `read` returns `None` + /// This is a helper function which gives additional lifecycle when the job reads `None`. + /// If returns `true`, it means we can delete the job automatically. `False` otherwise. + /// + /// [`Note`] + /// Conditions: + /// 1. Files are not waiting for confirmation by peers. + #[inline] + pub fn job_completed(&self) -> bool { + // has no error, Condition 2 + !self.enable_overwrite_detection || (!self.file_confirmed && !self.file_is_waiting) + } + + /// Get job error message, useful for getting status when job had finished + pub fn job_error(&self) -> Option { + if self.job_skipped() { + return Some("skipped".to_string()); + } + None + } + + pub fn set_file_skipped(&mut self) -> bool { + log::debug!("skip file {} in job {}", self.file_num, self.id); + self.file.take(); + self.set_file_confirmed(false); + self.set_file_is_waiting(false); + self.file_num += 1; + self.file_skipped = true; + true + } + + pub fn confirm(&mut self, r: &FileTransferSendConfirmRequest) -> bool { + if self.file_num() != r.file_num { + log::info!("file num truncated, ignoring"); + } else { + match r.union { + Some(file_transfer_send_confirm_request::Union::Skip(s)) => { + if s { + self.set_file_skipped(); + } else { + self.set_file_confirmed(true); + } + } + Some(file_transfer_send_confirm_request::Union::OffsetBlk(_offset)) => { + self.set_file_confirmed(true); + } + _ => {} + } + } + true + } + + #[inline] + pub fn gen_meta(&self) -> TransferJobMeta { + TransferJobMeta { + id: self.id, + remote: self.remote.to_string(), + to: self.path.to_string_lossy().to_string(), + file_num: self.file_num, + show_hidden: self.show_hidden, + is_remote: self.is_remote, + } + } +} + +#[inline] +pub fn new_error(id: i32, err: T, file_num: i32) -> Message { + let mut resp = FileResponse::new(); + resp.set_error(FileTransferError { + id, + error: err.to_string(), + file_num, + ..Default::default() + }); + let mut msg_out = Message::new(); + msg_out.set_file_response(resp); + msg_out +} + +#[inline] +pub fn new_dir(id: i32, path: String, files: Vec) -> Message { + let mut resp = FileResponse::new(); + resp.set_dir(FileDirectory { + id, + path, + entries: files, + ..Default::default() + }); + let mut msg_out = Message::new(); + msg_out.set_file_response(resp); + msg_out +} + +#[inline] +pub fn new_block(block: FileTransferBlock) -> Message { + let mut resp = FileResponse::new(); + resp.set_block(block); + let mut msg_out = Message::new(); + msg_out.set_file_response(resp); + msg_out +} + +#[inline] +pub fn new_send_confirm(r: FileTransferSendConfirmRequest) -> Message { + let mut msg_out = Message::new(); + let mut action = FileAction::new(); + action.set_send_confirm(r); + msg_out.set_file_action(action); + msg_out +} + +#[inline] +pub fn new_receive( + id: i32, + path: String, + file_num: i32, + files: Vec, + total_size: u64, +) -> Message { + let mut action = FileAction::new(); + action.set_receive(FileTransferReceiveRequest { + id, + path, + files, + file_num, + total_size, + ..Default::default() + }); + let mut msg_out = Message::new(); + msg_out.set_file_action(action); + msg_out +} + +#[inline] +pub fn new_send(id: i32, path: String, file_num: i32, include_hidden: bool) -> Message { + log::info!("new send: {}, id: {}", path, id); + let mut action = FileAction::new(); + action.set_send(FileTransferSendRequest { + id, + path, + include_hidden, + file_num, + ..Default::default() + }); + let mut msg_out = Message::new(); + msg_out.set_file_action(action); + msg_out +} + +#[inline] +pub fn new_done(id: i32, file_num: i32) -> Message { + let mut resp = FileResponse::new(); + resp.set_done(FileTransferDone { + id, + file_num, + ..Default::default() + }); + let mut msg_out = Message::new(); + msg_out.set_file_response(resp); + msg_out +} + +#[inline] +pub fn remove_job(id: i32, jobs: &mut Vec) { + *jobs = jobs.drain(0..).filter(|x| x.id() != id).collect(); +} + +#[inline] +pub fn get_job(id: i32, jobs: &mut [TransferJob]) -> Option<&mut TransferJob> { + jobs.iter_mut().find(|x| x.id() == id) +} + +#[inline] +pub fn get_job_immutable(id: i32, jobs: &[TransferJob]) -> Option<&TransferJob> { + jobs.iter().find(|x| x.id() == id) +} + +pub async fn handle_read_jobs( + jobs: &mut Vec, + stream: &mut crate::Stream, +) -> ResultType { + let mut job_log = Default::default(); + let mut finished = Vec::new(); + for job in jobs.iter_mut() { + if job.is_last_job { + continue; + } + match job.read(stream).await { + Err(err) => { + stream + .send(&new_error(job.id(), err, job.file_num())) + .await?; + } + Ok(Some(block)) => { + stream.send(&new_block(block)).await?; + } + Ok(None) => { + if job.job_completed() { + job_log = serialize_transfer_job(job, true, false, ""); + finished.push(job.id()); + match job.job_error() { + Some(err) => { + job_log = serialize_transfer_job(job, false, false, &err); + stream + .send(&new_error(job.id(), err, job.file_num())) + .await? + } + None => stream.send(&new_done(job.id(), job.file_num())).await?, + } + } else { + // waiting confirmation. + } + } + } + } + for id in finished { + remove_job(id, jobs); + } + Ok(job_log) +} + +pub fn remove_all_empty_dir(path: &Path) -> ResultType<()> { + let fd = read_dir(path, true)?; + for entry in fd.entries.iter() { + match entry.entry_type.enum_value() { + Ok(FileType::Dir) => { + remove_all_empty_dir(&path.join(&entry.name)).ok(); + } + Ok(FileType::DirLink) | Ok(FileType::FileLink) => { + std::fs::remove_file(path.join(&entry.name)).ok(); + } + _ => {} + } + } + std::fs::remove_dir(path).ok(); + Ok(()) +} + +#[inline] +pub fn remove_file(file: &str) -> ResultType<()> { + std::fs::remove_file(get_path(file))?; + Ok(()) +} + +#[inline] +pub fn create_dir(dir: &str) -> ResultType<()> { + std::fs::create_dir_all(get_path(dir))?; + Ok(()) +} + +#[inline] +pub fn rename_file(path: &str, new_name: &str) -> ResultType<()> { + let path = std::path::Path::new(&path); + if path.exists() { + let dir = path + .parent() + .ok_or(anyhow!("Parent directoy of {path:?} not exists"))?; + let new_path = dir.join(&new_name); + std::fs::rename(&path, &new_path)?; + Ok(()) + } else { + bail!("{path:?} not exists"); + } +} + +#[inline] +pub fn transform_windows_path(entries: &mut Vec) { + for entry in entries { + entry.name = entry.name.replace('\\', "/"); + } +} + +pub enum DigestCheckResult { + IsSame, + NeedConfirm(FileTransferDigest), + NoSuchFile, +} + +#[inline] +pub fn is_write_need_confirmation( + file_path: &str, + digest: &FileTransferDigest, +) -> ResultType { + let path = Path::new(file_path); + if path.exists() && path.is_file() { + let metadata = std::fs::metadata(path)?; + let modified_time = metadata.modified()?; + let remote_mt = Duration::from_secs(digest.last_modified); + let local_mt = modified_time.duration_since(UNIX_EPOCH)?; + // [Note] + // We decide to give the decision whether to override the existing file to users, + // which obey the behavior of the file manager in our system. + let mut is_identical = false; + if remote_mt == local_mt && digest.file_size == metadata.len() { + is_identical = true; + } + Ok(DigestCheckResult::NeedConfirm(FileTransferDigest { + id: digest.id, + file_num: digest.file_num, + last_modified: local_mt.as_secs(), + file_size: metadata.len(), + is_identical, + ..Default::default() + })) + } else { + Ok(DigestCheckResult::NoSuchFile) + } +} + +pub fn serialize_transfer_jobs(jobs: &[TransferJob]) -> String { + let mut v = vec![]; + for job in jobs { + let value = serde_json::to_value(job).unwrap_or_default(); + v.push(value); + } + serde_json::to_string(&v).unwrap_or_default() +} + +pub fn serialize_transfer_job(job: &TransferJob, done: bool, cancel: bool, error: &str) -> String { + let mut value = serde_json::to_value(job).unwrap_or_default(); + value["done"] = json!(done); + value["cancel"] = json!(cancel); + value["error"] = json!(error); + serde_json::to_string(&value).unwrap_or_default() +} diff --git a/src/keyboard.rs b/src/keyboard.rs new file mode 100644 index 0000000..10979f5 --- /dev/null +++ b/src/keyboard.rs @@ -0,0 +1,39 @@ +use std::{fmt, slice::Iter, str::FromStr}; + +use crate::protos::message::KeyboardMode; + +impl fmt::Display for KeyboardMode { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + KeyboardMode::Legacy => write!(f, "legacy"), + KeyboardMode::Map => write!(f, "map"), + KeyboardMode::Translate => write!(f, "translate"), + KeyboardMode::Auto => write!(f, "auto"), + } + } +} + +impl FromStr for KeyboardMode { + type Err = (); + fn from_str(s: &str) -> Result { + match s { + "legacy" => Ok(KeyboardMode::Legacy), + "map" => Ok(KeyboardMode::Map), + "translate" => Ok(KeyboardMode::Translate), + "auto" => Ok(KeyboardMode::Auto), + _ => Err(()), + } + } +} + +impl KeyboardMode { + pub fn iter() -> Iter<'static, KeyboardMode> { + static KEYBOARD_MODES: [KeyboardMode; 4] = [ + KeyboardMode::Legacy, + KeyboardMode::Map, + KeyboardMode::Translate, + KeyboardMode::Auto, + ]; + KEYBOARD_MODES.iter() + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..9414c1a --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,549 @@ +pub mod compress; +pub mod platform; +pub mod protos; +pub use bytes; +use config::Config; +pub use futures; +pub use protobuf; +pub use protos::message as message_proto; +pub use protos::rendezvous as rendezvous_proto; +use serde_derive::{Deserialize, Serialize}; +use std::{ + fs::File, + io::{self, BufRead}, + net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}, + path::Path, + time::{self, SystemTime, UNIX_EPOCH}, +}; +pub use tokio; +pub use tokio_util; +pub mod proxy; +pub mod socket_client; +pub mod tcp; +pub mod udp; +pub use env_logger; +pub use log; +pub mod bytes_codec; +pub use anyhow::{self, bail}; +pub use futures_util; +pub mod config; +pub mod fs; +pub mod mem; +pub use lazy_static; +#[cfg(not(any(target_os = "android", target_os = "ios")))] +pub use mac_address; +pub use rand; +pub use regex; +pub use sodiumoxide; +pub use tokio_socks; +pub use tokio_socks::IntoTargetAddr; +pub use tokio_socks::TargetAddr; +pub mod password_security; +pub use chrono; +pub use directories_next; +pub use libc; +pub mod keyboard; +pub use base64; +#[cfg(not(any(target_os = "android", target_os = "ios")))] +pub use dlopen; +#[cfg(not(any(target_os = "android", target_os = "ios")))] +pub use machine_uid; +pub use serde_derive; +pub use serde_json; +pub use sha2; +pub use sysinfo; +pub use thiserror; +pub use toml; +pub use uuid; +pub mod fingerprint; +pub use flexi_logger; + +pub type Stream = tcp::FramedStream; +pub type SessionID = uuid::Uuid; + +#[inline] +pub async fn sleep(sec: f32) { + tokio::time::sleep(time::Duration::from_secs_f32(sec)).await; +} + +#[macro_export] +macro_rules! allow_err { + ($e:expr) => { + if let Err(err) = $e { + log::debug!( + "{:?}, {}:{}:{}:{}", + err, + module_path!(), + file!(), + line!(), + column!() + ); + } else { + } + }; + + ($e:expr, $($arg:tt)*) => { + if let Err(err) = $e { + log::debug!( + "{:?}, {}, {}:{}:{}:{}", + err, + format_args!($($arg)*), + module_path!(), + file!(), + line!(), + column!() + ); + } else { + } + }; +} + +#[inline] +pub fn timeout(ms: u64, future: T) -> tokio::time::Timeout { + tokio::time::timeout(std::time::Duration::from_millis(ms), future) +} + +pub type ResultType = anyhow::Result; + +/// Certain router and firewalls scan the packet and if they +/// find an IP address belonging to their pool that they use to do the NAT mapping/translation, so here we mangle the ip address + +pub struct AddrMangle(); + +#[inline] +pub fn try_into_v4(addr: SocketAddr) -> SocketAddr { + match addr { + SocketAddr::V6(v6) if !addr.ip().is_loopback() => { + if let Some(v4) = v6.ip().to_ipv4() { + SocketAddr::new(IpAddr::V4(v4), addr.port()) + } else { + addr + } + } + _ => addr, + } +} + +impl AddrMangle { + pub fn encode(addr: SocketAddr) -> Vec { + // not work with [:1]: + let addr = try_into_v4(addr); + match addr { + SocketAddr::V4(addr_v4) => { + let tm = (SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or(std::time::Duration::ZERO) + .as_micros() as u32) as u128; + let ip = u32::from_le_bytes(addr_v4.ip().octets()) as u128; + let port = addr.port() as u128; + let v = ((ip + tm) << 49) | (tm << 17) | (port + (tm & 0xFFFF)); + let bytes = v.to_le_bytes(); + let mut n_padding = 0; + for i in bytes.iter().rev() { + if i == &0u8 { + n_padding += 1; + } else { + break; + } + } + bytes[..(16 - n_padding)].to_vec() + } + SocketAddr::V6(addr_v6) => { + let mut x = addr_v6.ip().octets().to_vec(); + let port: [u8; 2] = addr_v6.port().to_le_bytes(); + x.push(port[0]); + x.push(port[1]); + x + } + } + } + + pub fn decode(bytes: &[u8]) -> SocketAddr { + use std::convert::TryInto; + + if bytes.len() > 16 { + if bytes.len() != 18 { + return Config::get_any_listen_addr(false); + } + let tmp: [u8; 2] = bytes[16..].try_into().unwrap_or_default(); + let port = u16::from_le_bytes(tmp); + let tmp: [u8; 16] = bytes[..16].try_into().unwrap_or_default(); + let ip = std::net::Ipv6Addr::from(tmp); + return SocketAddr::new(IpAddr::V6(ip), port); + } + let mut padded = [0u8; 16]; + padded[..bytes.len()].copy_from_slice(bytes); + let number = u128::from_le_bytes(padded); + let tm = (number >> 17) & (u32::max_value() as u128); + let ip = (((number >> 49) - tm) as u32).to_le_bytes(); + let port = (number & 0xFFFFFF) - (tm & 0xFFFF); + SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]), + port as u16, + )) + } +} + +pub fn get_version_from_url(url: &str) -> String { + let n = url.chars().count(); + let a = url.chars().rev().position(|x| x == '-'); + if let Some(a) = a { + let b = url.chars().rev().position(|x| x == '.'); + if let Some(b) = b { + if a > b { + if url + .chars() + .skip(n - b) + .collect::() + .parse::() + .is_ok() + { + return url.chars().skip(n - a).collect(); + } else { + return url.chars().skip(n - a).take(a - b - 1).collect(); + } + } else { + return url.chars().skip(n - a).collect(); + } + } + } + "".to_owned() +} + +pub fn gen_version() { + println!("cargo:rerun-if-changed=Cargo.toml"); + use std::io::prelude::*; + let mut file = File::create("./src/version.rs").unwrap(); + for line in read_lines("Cargo.toml").unwrap().flatten() { + let ab: Vec<&str> = line.split('=').map(|x| x.trim()).collect(); + if ab.len() == 2 && ab[0] == "version" { + file.write_all(format!("pub const VERSION: &str = {};\n", ab[1]).as_bytes()) + .ok(); + break; + } + } + // generate build date + let build_date = format!("{}", chrono::Local::now().format("%Y-%m-%d %H:%M")); + file.write_all( + format!("#[allow(dead_code)]\npub const BUILD_DATE: &str = \"{build_date}\";\n").as_bytes(), + ) + .ok(); + file.sync_all().ok(); +} + +fn read_lines

(filename: P) -> io::Result>> +where + P: AsRef, +{ + let file = File::open(filename)?; + Ok(io::BufReader::new(file).lines()) +} + +pub fn is_valid_custom_id(id: &str) -> bool { + regex::Regex::new(r"^[a-zA-Z]\w{5,15}$") + .unwrap() + .is_match(id) +} + +// Support 1.1.10-1, the number after - is a patch version. +pub fn get_version_number(v: &str) -> i64 { + let mut versions = v.split('-'); + + let mut n = 0; + + // The first part is the version number. + // 1.1.10 -> 1001100, 1.2.3 -> 1001030, multiple the last number by 10 + // to leave space for patch version. + if let Some(v) = versions.next() { + let mut last = 0; + for x in v.split('.') { + last = x.parse::().unwrap_or(0); + n = n * 1000 + last; + } + n -= last; + n += last * 10; + } + + if let Some(v) = versions.next() { + n += v.parse::().unwrap_or(0); + } + + // Ignore the rest + + n +} + +pub fn get_modified_time(path: &std::path::Path) -> SystemTime { + std::fs::metadata(path) + .map(|m| m.modified().unwrap_or(UNIX_EPOCH)) + .unwrap_or(UNIX_EPOCH) +} + +pub fn get_created_time(path: &std::path::Path) -> SystemTime { + std::fs::metadata(path) + .map(|m| m.created().unwrap_or(UNIX_EPOCH)) + .unwrap_or(UNIX_EPOCH) +} + +pub fn get_exe_time() -> SystemTime { + std::env::current_exe().map_or(UNIX_EPOCH, |path| { + let m = get_modified_time(&path); + let c = get_created_time(&path); + if m > c { + m + } else { + c + } + }) +} + +pub fn get_uuid() -> Vec { + #[cfg(not(any(target_os = "android", target_os = "ios")))] + if let Ok(id) = machine_uid::get() { + return id.into(); + } + Config::get_key_pair().1 +} + +#[inline] +pub fn get_time() -> i64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_millis()) + .unwrap_or(0) as _ +} + +#[inline] +pub fn is_ipv4_str(id: &str) -> bool { + if let Ok(reg) = regex::Regex::new( + r"^(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)(:\d+)?$", + ) { + reg.is_match(id) + } else { + false + } +} + +#[inline] +pub fn is_ipv6_str(id: &str) -> bool { + if let Ok(reg) = regex::Regex::new( + r"^((([a-fA-F0-9]{1,4}:{1,2})+[a-fA-F0-9]{1,4})|(\[([a-fA-F0-9]{1,4}:{1,2})+[a-fA-F0-9]{1,4}\]:\d+))$", + ) { + reg.is_match(id) + } else { + false + } +} + +#[inline] +pub fn is_ip_str(id: &str) -> bool { + is_ipv4_str(id) || is_ipv6_str(id) +} + +#[inline] +pub fn is_domain_port_str(id: &str) -> bool { + // modified regex for RFC1123 hostname. check https://stackoverflow.com/a/106223 for original version for hostname. + // according to [TLD List](https://data.iana.org/TLD/tlds-alpha-by-domain.txt) version 2023011700, + // there is no digits in TLD, and length is 2~63. + if let Ok(reg) = regex::Regex::new( + r"(?i)^([a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?\.)+[a-z][a-z-]{0,61}[a-z]:\d{1,5}$", + ) { + reg.is_match(id) + } else { + false + } +} + +pub fn init_log(_is_async: bool, _name: &str) -> Option { + static INIT: std::sync::Once = std::sync::Once::new(); + #[allow(unused_mut)] + let mut logger_holder: Option = None; + INIT.call_once(|| { + #[cfg(debug_assertions)] + { + use env_logger::*; + init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "info")); + } + #[cfg(not(debug_assertions))] + { + // https://docs.rs/flexi_logger/latest/flexi_logger/error_info/index.html#write + // though async logger more efficient, but it also causes more problems, disable it for now + let mut path = config::Config::log_path(); + #[cfg(target_os = "android")] + if !config::Config::get_home().exists() { + return; + } + if !_name.is_empty() { + path.push(_name); + } + use flexi_logger::*; + if let Ok(x) = Logger::try_with_env_or_str("debug") { + logger_holder = x + .log_to_file(FileSpec::default().directory(path)) + .write_mode(if _is_async { + WriteMode::Async + } else { + WriteMode::Direct + }) + .format(opt_format) + .rotate( + Criterion::Age(Age::Day), + Naming::Timestamps, + Cleanup::KeepLogFiles(31), + ) + .start() + .ok(); + } + } + }); + logger_holder +} + +#[derive(Debug, Default, Deserialize, Serialize)] +pub struct VersionCheckRequest { + #[serde(default)] + pub os: String, + #[serde(default)] + pub os_version: String, + #[serde(default)] + pub arch: String, + #[serde(default)] + pub device_id: Vec, + #[serde(default)] + pub typ: String, +} + +#[derive(Debug, Default, Deserialize, Serialize)] +pub struct VersionCheckResponse { + #[serde(default)] + pub url: String, +} + +pub const VER_TYPE_RUSTDESK_CLIENT: &str = "rustdesk-client"; +pub const VER_TYPE_RUSTDESK_SERVER: &str = "rustdesk-server"; + +pub fn version_check_request(typ: String) -> (VersionCheckRequest, String) { + const URL: &str = "https://api.rustdesk.com/version/latest"; + + use sysinfo::System; + let system = System::new(); + let os = system.distribution_id(); + let os_version = system.os_version().unwrap_or_default(); + let arch = std::env::consts::ARCH.to_string(); + #[allow(deprecated)] + let device_id = fingerprint::get_fingerprint(None, None); + ( + VersionCheckRequest { + os, + os_version, + arch, + device_id, + typ, + }, + URL.to_string(), + ) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_mangle() { + let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 16, 32), 21116)); + assert_eq!(addr, AddrMangle::decode(&AddrMangle::encode(addr))); + + let addr = "[2001:db8::1]:8080".parse::().unwrap(); + assert_eq!(addr, AddrMangle::decode(&AddrMangle::encode(addr))); + + let addr = "[2001:db8:ff::1111]:80".parse::().unwrap(); + assert_eq!(addr, AddrMangle::decode(&AddrMangle::encode(addr))); + } + + #[test] + fn test_allow_err() { + allow_err!(Err("test err") as Result<(), &str>); + allow_err!( + Err("test err with msg") as Result<(), &str>, + "prompt {}", + "failed" + ); + } + + #[test] + fn test_ipv6() { + assert!(is_ipv6_str("1:2:3")); + assert!(is_ipv6_str("[ab:2:3]:12")); + assert!(is_ipv6_str("[ABEF:2a:3]:12")); + assert!(!is_ipv6_str("[ABEG:2a:3]:12")); + assert!(!is_ipv6_str("1[ab:2:3]:12")); + assert!(!is_ipv6_str("1.1.1.1")); + assert!(is_ip_str("1.1.1.1")); + assert!(!is_ipv6_str("1:2:")); + assert!(is_ipv6_str("1:2::0")); + assert!(is_ipv6_str("[1:2::0]:1")); + assert!(!is_ipv6_str("[1:2::0]:")); + assert!(!is_ipv6_str("1:2::0]:1")); + } + + #[test] + fn test_ipv4() { + assert!(is_ipv4_str("1.2.3.4")); + assert!(is_ipv4_str("1.2.3.4:90")); + assert!(is_ipv4_str("192.168.0.1")); + assert!(is_ipv4_str("0.0.0.0")); + assert!(is_ipv4_str("255.255.255.255")); + assert!(!is_ipv4_str("256.0.0.0")); + assert!(!is_ipv4_str("256.256.256.256")); + assert!(!is_ipv4_str("1:2:")); + assert!(!is_ipv4_str("192.168.0.256")); + assert!(!is_ipv4_str("192.168.0.1/24")); + assert!(!is_ipv4_str("192.168.0.")); + assert!(!is_ipv4_str("192.168..1")); + } + + #[test] + fn test_hostname_port() { + assert!(!is_domain_port_str("a:12")); + assert!(!is_domain_port_str("a.b.c:12")); + assert!(is_domain_port_str("test.com:12")); + assert!(is_domain_port_str("test-UPPER.com:12")); + assert!(is_domain_port_str("some-other.domain.com:12")); + assert!(!is_domain_port_str("under_score:12")); + assert!(!is_domain_port_str("a@bc:12")); + assert!(!is_domain_port_str("1.1.1.1:12")); + assert!(!is_domain_port_str("1.2.3:12")); + assert!(!is_domain_port_str("1.2.3.45:12")); + assert!(!is_domain_port_str("a.b.c:123456")); + assert!(!is_domain_port_str("---:12")); + assert!(!is_domain_port_str(".:12")); + // todo: should we also check for these edge cases? + // out-of-range port + assert!(is_domain_port_str("test.com:0")); + assert!(is_domain_port_str("test.com:98989")); + } + + #[test] + fn test_mangle2() { + let addr = "[::ffff:127.0.0.1]:8080".parse().unwrap(); + let addr_v4 = "127.0.0.1:8080".parse().unwrap(); + assert_eq!(AddrMangle::decode(&AddrMangle::encode(addr)), addr_v4); + assert_eq!( + AddrMangle::decode(&AddrMangle::encode("[::127.0.0.1]:8080".parse().unwrap())), + addr_v4 + ); + assert_eq!(AddrMangle::decode(&AddrMangle::encode(addr_v4)), addr_v4); + let addr_v6 = "[ef::fe]:8080".parse().unwrap(); + assert_eq!(AddrMangle::decode(&AddrMangle::encode(addr_v6)), addr_v6); + let addr_v6 = "[::1]:8080".parse().unwrap(); + assert_eq!(AddrMangle::decode(&AddrMangle::encode(addr_v6)), addr_v6); + } + + #[test] + fn test_get_version_number() { + assert_eq!(get_version_number("1.1.10"), 1001100); + assert_eq!(get_version_number("1.1.10-1"), 1001101); + assert_eq!(get_version_number("1.1.11-1"), 1001111); + assert_eq!(get_version_number("1.2.3"), 1002030); + } +} diff --git a/src/mem.rs b/src/mem.rs new file mode 100644 index 0000000..90a5d6d --- /dev/null +++ b/src/mem.rs @@ -0,0 +1,14 @@ +/// SAFETY: the returned Vec must not be resized or reserverd +pub unsafe fn aligned_u8_vec(cap: usize, align: usize) -> Vec { + use std::alloc::*; + + let layout = + Layout::from_size_align(cap, align).expect("invalid aligned value, must be power of 2"); + unsafe { + let ptr = alloc(layout); + if ptr.is_null() { + panic!("failed to allocate {} bytes", cap); + } + Vec::from_raw_parts(ptr, 0, cap) + } +} diff --git a/src/password_security.rs b/src/password_security.rs new file mode 100644 index 0000000..5c04cc9 --- /dev/null +++ b/src/password_security.rs @@ -0,0 +1,295 @@ +use crate::config::Config; +use sodiumoxide::base64; +use std::sync::{Arc, RwLock}; + +lazy_static::lazy_static! { + pub static ref TEMPORARY_PASSWORD:Arc> = Arc::new(RwLock::new(Config::get_auto_password(temporary_password_length()))); +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum VerificationMethod { + OnlyUseTemporaryPassword, + OnlyUsePermanentPassword, + UseBothPasswords, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ApproveMode { + Both, + Password, + Click, +} + +// Should only be called in server +pub fn update_temporary_password() { + *TEMPORARY_PASSWORD.write().unwrap() = Config::get_auto_password(temporary_password_length()); +} + +// Should only be called in server +pub fn temporary_password() -> String { + TEMPORARY_PASSWORD.read().unwrap().clone() +} + +fn verification_method() -> VerificationMethod { + let method = Config::get_option("verification-method"); + if method == "use-temporary-password" { + VerificationMethod::OnlyUseTemporaryPassword + } else if method == "use-permanent-password" { + VerificationMethod::OnlyUsePermanentPassword + } else { + VerificationMethod::UseBothPasswords // default + } +} + +pub fn temporary_password_length() -> usize { + let length = Config::get_option("temporary-password-length"); + if length == "8" { + 8 + } else if length == "10" { + 10 + } else { + 6 // default + } +} + +pub fn temporary_enabled() -> bool { + verification_method() != VerificationMethod::OnlyUsePermanentPassword +} + +pub fn permanent_enabled() -> bool { + verification_method() != VerificationMethod::OnlyUseTemporaryPassword +} + +pub fn has_valid_password() -> bool { + temporary_enabled() && !temporary_password().is_empty() + || permanent_enabled() && !Config::get_permanent_password().is_empty() +} + +pub fn approve_mode() -> ApproveMode { + let mode = Config::get_option("approve-mode"); + if mode == "password" { + ApproveMode::Password + } else if mode == "click" { + ApproveMode::Click + } else { + ApproveMode::Both + } +} + +pub fn hide_cm() -> bool { + approve_mode() == ApproveMode::Password + && verification_method() == VerificationMethod::OnlyUsePermanentPassword + && crate::config::option2bool("allow-hide-cm", &Config::get_option("allow-hide-cm")) +} + +const VERSION_LEN: usize = 2; + +pub fn encrypt_str_or_original(s: &str, version: &str, max_len: usize) -> String { + if decrypt_str_or_original(s, version).1 { + log::error!("Duplicate encryption!"); + return s.to_owned(); + } + if s.chars().count() > max_len { + return String::default(); + } + if version == "00" { + if let Ok(s) = encrypt(s.as_bytes()) { + return version.to_owned() + &s; + } + } + s.to_owned() +} + +// String: password +// bool: whether decryption is successful +// bool: whether should store to re-encrypt when load +// note: s.len() return length in bytes, s.chars().count() return char count +// &[..2] return the left 2 bytes, s.chars().take(2) return the left 2 chars +pub fn decrypt_str_or_original(s: &str, current_version: &str) -> (String, bool, bool) { + if s.len() > VERSION_LEN { + if s.starts_with("00") { + if let Ok(v) = decrypt(s[VERSION_LEN..].as_bytes()) { + return ( + String::from_utf8_lossy(&v).to_string(), + true, + "00" != current_version, + ); + } + } + } + + (s.to_owned(), false, !s.is_empty()) +} + +pub fn encrypt_vec_or_original(v: &[u8], version: &str, max_len: usize) -> Vec { + if decrypt_vec_or_original(v, version).1 { + log::error!("Duplicate encryption!"); + return v.to_owned(); + } + if v.len() > max_len { + return vec![]; + } + if version == "00" { + if let Ok(s) = encrypt(v) { + let mut version = version.to_owned().into_bytes(); + version.append(&mut s.into_bytes()); + return version; + } + } + v.to_owned() +} + +// Vec: password +// bool: whether decryption is successful +// bool: whether should store to re-encrypt when load +pub fn decrypt_vec_or_original(v: &[u8], current_version: &str) -> (Vec, bool, bool) { + if v.len() > VERSION_LEN { + let version = String::from_utf8_lossy(&v[..VERSION_LEN]); + if version == "00" { + if let Ok(v) = decrypt(&v[VERSION_LEN..]) { + return (v, true, version != current_version); + } + } + } + + (v.to_owned(), false, !v.is_empty()) +} + +fn encrypt(v: &[u8]) -> Result { + if !v.is_empty() { + symmetric_crypt(v, true).map(|v| base64::encode(v, base64::Variant::Original)) + } else { + Err(()) + } +} + +fn decrypt(v: &[u8]) -> Result, ()> { + if !v.is_empty() { + base64::decode(v, base64::Variant::Original).and_then(|v| symmetric_crypt(&v, false)) + } else { + Err(()) + } +} + +pub fn symmetric_crypt(data: &[u8], encrypt: bool) -> Result, ()> { + use sodiumoxide::crypto::secretbox; + use std::convert::TryInto; + + let mut keybuf = crate::get_uuid(); + keybuf.resize(secretbox::KEYBYTES, 0); + let key = secretbox::Key(keybuf.try_into().map_err(|_| ())?); + let nonce = secretbox::Nonce([0; secretbox::NONCEBYTES]); + + if encrypt { + Ok(secretbox::seal(data, &nonce, &key)) + } else { + secretbox::open(data, &nonce, &key) + } +} + +mod test { + + #[test] + fn test() { + use super::*; + use rand::{thread_rng, Rng}; + use std::time::Instant; + + let version = "00"; + let max_len = 128; + + println!("test str"); + let data = "1ü1111"; + let encrypted = encrypt_str_or_original(data, version, max_len); + let (decrypted, succ, store) = decrypt_str_or_original(&encrypted, version); + println!("data: {data}"); + println!("encrypted: {encrypted}"); + println!("decrypted: {decrypted}"); + assert_eq!(data, decrypted); + assert_eq!(version, &encrypted[..2]); + assert!(succ); + assert!(!store); + let (_, _, store) = decrypt_str_or_original(&encrypted, "99"); + assert!(store); + assert!(!decrypt_str_or_original(&decrypted, version).1); + assert_eq!( + encrypt_str_or_original(&encrypted, version, max_len), + encrypted + ); + + println!("test vec"); + let data: Vec = "1ü1111".as_bytes().to_vec(); + let encrypted = encrypt_vec_or_original(&data, version, max_len); + let (decrypted, succ, store) = decrypt_vec_or_original(&encrypted, version); + println!("data: {data:?}"); + println!("encrypted: {encrypted:?}"); + println!("decrypted: {decrypted:?}"); + assert_eq!(data, decrypted); + assert_eq!(version.as_bytes(), &encrypted[..2]); + assert!(!store); + assert!(succ); + let (_, _, store) = decrypt_vec_or_original(&encrypted, "99"); + assert!(store); + assert!(!decrypt_vec_or_original(&decrypted, version).1); + assert_eq!( + encrypt_vec_or_original(&encrypted, version, max_len), + encrypted + ); + + println!("test original"); + let data = version.to_string() + "Hello World"; + let (decrypted, succ, store) = decrypt_str_or_original(&data, version); + assert_eq!(data, decrypted); + assert!(store); + assert!(!succ); + let verbytes = version.as_bytes(); + let data: Vec = vec![verbytes[0], verbytes[1], 1, 2, 3, 4, 5, 6]; + let (decrypted, succ, store) = decrypt_vec_or_original(&data, version); + assert_eq!(data, decrypted); + assert!(store); + assert!(!succ); + let (_, succ, store) = decrypt_str_or_original("", version); + assert!(!store); + assert!(!succ); + let (_, succ, store) = decrypt_vec_or_original(&[], version); + assert!(!store); + assert!(!succ); + let data = "1ü1111"; + assert_eq!(decrypt_str_or_original(data, version).0, data); + let data: Vec = "1ü1111".as_bytes().to_vec(); + assert_eq!(decrypt_vec_or_original(&data, version).0, data); + + println!("test speed"); + let test_speed = |len: usize, name: &str| { + let mut data: Vec = vec![]; + let mut rng = thread_rng(); + for _ in 0..len { + data.push(rng.gen_range(0..255)); + } + let start: Instant = Instant::now(); + let encrypted = encrypt_vec_or_original(&data, version, len); + assert_ne!(data, decrypted); + let t1 = start.elapsed(); + let start = Instant::now(); + let (decrypted, _, _) = decrypt_vec_or_original(&encrypted, version); + let t2 = start.elapsed(); + assert_eq!(data, decrypted); + println!("{name}"); + println!("encrypt:{:?}, decrypt:{:?}", t1, t2); + + let start: Instant = Instant::now(); + let encrypted = base64::encode(&data, base64::Variant::Original); + let t1 = start.elapsed(); + let start = Instant::now(); + let decrypted = base64::decode(&encrypted, base64::Variant::Original).unwrap(); + let t2 = start.elapsed(); + assert_eq!(data, decrypted); + println!("base64, encrypt:{:?}, decrypt:{:?}", t1, t2,); + }; + test_speed(128, "128"); + test_speed(1024, "1k"); + test_speed(1024 * 1024, "1M"); + test_speed(10 * 1024 * 1024, "10M"); + test_speed(100 * 1024 * 1024, "100M"); + } +} diff --git a/src/platform/linux.rs b/src/platform/linux.rs new file mode 100644 index 0000000..60c8714 --- /dev/null +++ b/src/platform/linux.rs @@ -0,0 +1,300 @@ +use crate::ResultType; +use std::{collections::HashMap, process::Command}; + +lazy_static::lazy_static! { + pub static ref DISTRO: Distro = Distro::new(); +} + +pub const DISPLAY_SERVER_WAYLAND: &str = "wayland"; +pub const DISPLAY_SERVER_X11: &str = "x11"; +pub const DISPLAY_DESKTOP_KDE: &str = "KDE"; + +pub const XDG_CURRENT_DESKTOP: &str = "XDG_CURRENT_DESKTOP"; + +pub struct Distro { + pub name: String, + pub version_id: String, +} + +impl Distro { + fn new() -> Self { + let name = run_cmds("awk -F'=' '/^NAME=/ {print $2}' /etc/os-release") + .unwrap_or_default() + .trim() + .trim_matches('"') + .to_string(); + let version_id = run_cmds("awk -F'=' '/^VERSION_ID=/ {print $2}' /etc/os-release") + .unwrap_or_default() + .trim() + .trim_matches('"') + .to_string(); + Self { name, version_id } + } +} + +#[inline] +pub fn is_kde() -> bool { + if let Ok(env) = std::env::var(XDG_CURRENT_DESKTOP) { + env == DISPLAY_DESKTOP_KDE + } else { + false + } +} + +#[inline] +pub fn is_gdm_user(username: &str) -> bool { + username == "gdm" + // || username == "lightgdm" +} + +#[inline] +pub fn is_desktop_wayland() -> bool { + get_display_server() == DISPLAY_SERVER_WAYLAND +} + +#[inline] +pub fn is_x11_or_headless() -> bool { + !is_desktop_wayland() +} + +// -1 +const INVALID_SESSION: &str = "4294967295"; + +pub fn get_display_server() -> String { + // Check for forced display server environment variable first + if let Ok(forced_display) = std::env::var("RUSTDESK_FORCED_DISPLAY_SERVER") { + return forced_display; + } + + // Check if `loginctl` can be called successfully + if run_loginctl(None).is_err() { + return DISPLAY_SERVER_X11.to_owned(); + } + + let mut session = get_values_of_seat0(&[0])[0].clone(); + if session.is_empty() { + // loginctl has not given the expected output. try something else. + if let Ok(sid) = std::env::var("XDG_SESSION_ID") { + // could also execute "cat /proc/self/sessionid" + session = sid; + } + if session.is_empty() { + session = run_cmds("cat /proc/self/sessionid").unwrap_or_default(); + if session == INVALID_SESSION { + session = "".to_owned(); + } + } + } + if session.is_empty() { + std::env::var("XDG_SESSION_TYPE").unwrap_or("x11".to_owned()) + } else { + get_display_server_of_session(&session) + } +} + +pub fn get_display_server_of_session(session: &str) -> String { + let mut display_server = if let Ok(output) = + run_loginctl(Some(vec!["show-session", "-p", "Type", session])) + // Check session type of the session + { + String::from_utf8_lossy(&output.stdout) + .replace("Type=", "") + .trim_end() + .into() + } else { + "".to_owned() + }; + if display_server.is_empty() || display_server == "tty" { + if let Ok(sestype) = std::env::var("XDG_SESSION_TYPE") { + if !sestype.is_empty() { + return sestype.to_lowercase(); + } + } + display_server = "x11".to_owned(); + } + display_server.to_lowercase() +} + +#[inline] +fn line_values(indices: &[usize], line: &str) -> Vec { + indices + .into_iter() + .map(|idx| line.split_whitespace().nth(*idx).unwrap_or("").to_owned()) + .collect::>() +} + +#[inline] +pub fn get_values_of_seat0(indices: &[usize]) -> Vec { + _get_values_of_seat0(indices, true) +} + +#[inline] +pub fn get_values_of_seat0_with_gdm_wayland(indices: &[usize]) -> Vec { + _get_values_of_seat0(indices, false) +} + +// Ignore "3 sessions listed." +fn ignore_loginctl_line(line: &str) -> bool { + line.contains("sessions") || line.split(" ").count() < 4 +} + +fn _get_values_of_seat0(indices: &[usize], ignore_gdm_wayland: bool) -> Vec { + if let Ok(output) = run_loginctl(None) { + for line in String::from_utf8_lossy(&output.stdout).lines() { + if ignore_loginctl_line(line) { + continue; + } + if line.contains("seat0") { + if let Some(sid) = line.split_whitespace().next() { + if is_active(sid) { + if ignore_gdm_wayland { + if is_gdm_user(line.split_whitespace().nth(2).unwrap_or("")) + && get_display_server_of_session(sid) == DISPLAY_SERVER_WAYLAND + { + continue; + } + } + return line_values(indices, line); + } + } + } + } + + // some case, there is no seat0 https://github.com/rustdesk/rustdesk/issues/73 + for line in String::from_utf8_lossy(&output.stdout).lines() { + if ignore_loginctl_line(line) { + continue; + } + if let Some(sid) = line.split_whitespace().next() { + if is_active(sid) { + let d = get_display_server_of_session(sid); + if ignore_gdm_wayland { + if is_gdm_user(line.split_whitespace().nth(2).unwrap_or("")) + && d == DISPLAY_SERVER_WAYLAND + { + continue; + } + } + if d == "tty" { + continue; + } + return line_values(indices, line); + } + } + } + } + + line_values(indices, "") +} + +pub fn is_active(sid: &str) -> bool { + if let Ok(output) = run_loginctl(Some(vec!["show-session", "-p", "State", sid])) { + String::from_utf8_lossy(&output.stdout).contains("active") + } else { + false + } +} + +pub fn is_active_and_seat0(sid: &str) -> bool { + if let Ok(output) = run_loginctl(Some(vec!["show-session", sid])) { + String::from_utf8_lossy(&output.stdout).contains("State=active") + && String::from_utf8_lossy(&output.stdout).contains("Seat=seat0") + } else { + false + } +} + +// **Note** that the return value here, the last character is '\n'. +// Use `run_cmds_trim_newline()` if you want to remove '\n' at the end. +pub fn run_cmds(cmds: &str) -> ResultType { + let output = std::process::Command::new("sh") + .args(vec!["-c", cmds]) + .output()?; + Ok(String::from_utf8_lossy(&output.stdout).to_string()) +} + +pub fn run_cmds_trim_newline(cmds: &str) -> ResultType { + let output = std::process::Command::new("sh") + .args(vec!["-c", cmds]) + .output()?; + let out = String::from_utf8_lossy(&output.stdout); + Ok(if out.ends_with('\n') { + out[..out.len() - 1].to_string() + } else { + out.to_string() + }) +} + +fn run_loginctl(args: Option>) -> std::io::Result { + if std::env::var("FLATPAK_ID").is_ok() { + let mut l_args = String::from("loginctl"); + if let Some(a) = args.as_ref() { + l_args = format!("{} {}", l_args, a.join(" ")); + } + let res = std::process::Command::new("flatpak-spawn") + .args(vec![String::from("--host"), l_args]) + .output(); + if res.is_ok() { + return res; + } + } + let mut cmd = std::process::Command::new("loginctl"); + if let Some(a) = args { + return cmd.args(a).output(); + } + cmd.output() +} + +/// forever: may not work +#[cfg(target_os = "linux")] +pub fn system_message(title: &str, msg: &str, forever: bool) -> ResultType<()> { + let cmds: HashMap<&str, Vec<&str>> = HashMap::from([ + ("notify-send", [title, msg].to_vec()), + ( + "zenity", + [ + "--info", + "--timeout", + if forever { "0" } else { "3" }, + "--title", + title, + "--text", + msg, + ] + .to_vec(), + ), + ("kdialog", ["--title", title, "--msgbox", msg].to_vec()), + ( + "xmessage", + [ + "-center", + "-timeout", + if forever { "0" } else { "3" }, + title, + msg, + ] + .to_vec(), + ), + ]); + for (k, v) in cmds { + if Command::new(k).args(v).spawn().is_ok() { + return Ok(()); + } + } + crate::bail!("failed to post system message"); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_run_cmds_trim_newline() { + assert_eq!(run_cmds_trim_newline("echo -n 123").unwrap(), "123"); + assert_eq!(run_cmds_trim_newline("echo 123").unwrap(), "123"); + assert_eq!( + run_cmds_trim_newline("whoami").unwrap() + "\n", + run_cmds("whoami").unwrap() + ); + } +} diff --git a/src/platform/macos.rs b/src/platform/macos.rs new file mode 100644 index 0000000..dd83a87 --- /dev/null +++ b/src/platform/macos.rs @@ -0,0 +1,55 @@ +use crate::ResultType; +use osascript; +use serde_derive::{Deserialize, Serialize}; + +#[derive(Serialize)] +struct AlertParams { + title: String, + message: String, + alert_type: String, + buttons: Vec, +} + +#[derive(Deserialize)] +struct AlertResult { + #[serde(rename = "buttonReturned")] + button: String, +} + +/// Firstly run the specified app, then alert a dialog. Return the clicked button value. +/// +/// # Arguments +/// +/// * `app` - The app to execute the script. +/// * `alert_type` - Alert type. . informational, warning, critical +/// * `title` - The alert title. +/// * `message` - The alert message. +/// * `buttons` - The buttons to show. +pub fn alert( + app: String, + alert_type: String, + title: String, + message: String, + buttons: Vec, +) -> ResultType { + let script = osascript::JavaScript::new(&format!( + " + var App = Application('{}'); + App.includeStandardAdditions = true; + return App.displayAlert($params.title, {{ + message: $params.message, + 'as': $params.alert_type, + buttons: $params.buttons, + }}); + ", + app + )); + + let result: AlertResult = script.execute_with_params(AlertParams { + title, + message, + alert_type, + buttons, + })?; + Ok(result.button) +} diff --git a/src/platform/mod.rs b/src/platform/mod.rs new file mode 100644 index 0000000..5dc004a --- /dev/null +++ b/src/platform/mod.rs @@ -0,0 +1,81 @@ +#[cfg(target_os = "linux")] +pub mod linux; + +#[cfg(target_os = "macos")] +pub mod macos; + +#[cfg(target_os = "windows")] +pub mod windows; + +#[cfg(not(debug_assertions))] +use crate::{config::Config, log}; +#[cfg(not(debug_assertions))] +use std::process::exit; + +#[cfg(not(debug_assertions))] +static mut GLOBAL_CALLBACK: Option> = None; + +#[cfg(not(debug_assertions))] +extern "C" fn breakdown_signal_handler(sig: i32) { + let mut stack = vec![]; + backtrace::trace(|frame| { + backtrace::resolve_frame(frame, |symbol| { + if let Some(name) = symbol.name() { + stack.push(name.to_string()); + } + }); + true // keep going to the next frame + }); + let mut info = String::default(); + if stack.iter().any(|s| { + s.contains(&"nouveau_pushbuf_kick") + || s.to_lowercase().contains("nvidia") + || s.contains("gdk_window_end_draw_frame") + || s.contains("glGetString") + }) { + Config::set_option("allow-always-software-render".to_string(), "Y".to_string()); + info = "Always use software rendering will be set.".to_string(); + log::info!("{}", info); + } + if stack.iter().any(|s| { + s.to_lowercase().contains("nvidia") + || s.to_lowercase().contains("amf") + || s.to_lowercase().contains("mfx") + || s.contains("cuProfilerStop") + }) { + Config::set_option("enable-hwcodec".to_string(), "N".to_string()); + info = "Perhaps hwcodec causing the crash, disable it first".to_string(); + log::info!("{}", info); + } + log::error!( + "Got signal {} and exit. stack:\n{}", + sig, + stack.join("\n").to_string() + ); + if !info.is_empty() { + #[cfg(target_os = "linux")] + linux::system_message( + "RustDesk", + &format!("Got signal {} and exit.{}", sig, info), + true, + ) + .ok(); + } + unsafe { + if let Some(callback) = &GLOBAL_CALLBACK { + callback() + } + } + exit(0); +} + +#[cfg(not(debug_assertions))] +pub fn register_breakdown_handler(callback: T) +where + T: Fn() + 'static, +{ + unsafe { + GLOBAL_CALLBACK = Some(Box::new(callback)); + libc::signal(libc::SIGSEGV, breakdown_signal_handler as _); + } +} diff --git a/src/platform/windows.rs b/src/platform/windows.rs new file mode 100644 index 0000000..7481631 --- /dev/null +++ b/src/platform/windows.rs @@ -0,0 +1,198 @@ +use std::{ + collections::VecDeque, + sync::{Arc, Mutex}, + time::Instant, +}; +use winapi::{ + shared::minwindef::{DWORD, FALSE, TRUE}, + um::{ + handleapi::CloseHandle, + pdh::{ + PdhAddEnglishCounterA, PdhCloseQuery, PdhCollectQueryData, PdhCollectQueryDataEx, + PdhGetFormattedCounterValue, PdhOpenQueryA, PDH_FMT_COUNTERVALUE, PDH_FMT_DOUBLE, + PDH_HCOUNTER, PDH_HQUERY, + }, + synchapi::{CreateEventA, WaitForSingleObject}, + sysinfoapi::VerSetConditionMask, + winbase::{VerifyVersionInfoW, INFINITE, WAIT_OBJECT_0}, + winnt::{ + HANDLE, OSVERSIONINFOEXW, VER_BUILDNUMBER, VER_GREATER_EQUAL, VER_MAJORVERSION, + VER_MINORVERSION, VER_SERVICEPACKMAJOR, VER_SERVICEPACKMINOR, + }, + }, +}; + +lazy_static::lazy_static! { + static ref CPU_USAGE_ONE_MINUTE: Arc>> = Arc::new(Mutex::new(None)); +} + +// https://github.com/mgostIH/process_list/blob/master/src/windows/mod.rs +#[repr(transparent)] +pub struct RAIIHandle(pub HANDLE); + +impl Drop for RAIIHandle { + fn drop(&mut self) { + // This never gives problem except when running under a debugger. + unsafe { CloseHandle(self.0) }; + } +} + +#[repr(transparent)] +pub(self) struct RAIIPDHQuery(pub PDH_HQUERY); + +impl Drop for RAIIPDHQuery { + fn drop(&mut self) { + unsafe { PdhCloseQuery(self.0) }; + } +} + +pub fn start_cpu_performance_monitor() { + // Code from: + // https://learn.microsoft.com/en-us/windows/win32/perfctrs/collecting-performance-data + // https://learn.microsoft.com/en-us/windows/win32/api/pdh/nf-pdh-pdhcollectquerydataex + // Why value lower than taskManager: + // https://aaron-margosis.medium.com/task-managers-cpu-numbers-are-all-but-meaningless-2d165b421e43 + // Therefore we should compare with Precess Explorer rather than taskManager + + let f = || unsafe { + // load avg or cpu usage, test with prime95. + // Prefer cpu usage because we can get accurate value from Precess Explorer. + // const COUNTER_PATH: &'static str = "\\System\\Processor Queue Length\0"; + const COUNTER_PATH: &'static str = "\\Processor(_total)\\% Processor Time\0"; + const SAMPLE_INTERVAL: DWORD = 2; // 2 second + + let mut ret; + let mut query: PDH_HQUERY = std::mem::zeroed(); + ret = PdhOpenQueryA(std::ptr::null() as _, 0, &mut query); + if ret != 0 { + log::error!("PdhOpenQueryA failed: 0x{:X}", ret); + return; + } + let _query = RAIIPDHQuery(query); + let mut counter: PDH_HCOUNTER = std::mem::zeroed(); + ret = PdhAddEnglishCounterA(query, COUNTER_PATH.as_ptr() as _, 0, &mut counter); + if ret != 0 { + log::error!("PdhAddEnglishCounterA failed: 0x{:X}", ret); + return; + } + ret = PdhCollectQueryData(query); + if ret != 0 { + log::error!("PdhCollectQueryData failed: 0x{:X}", ret); + return; + } + let mut _counter_type: DWORD = 0; + let mut counter_value: PDH_FMT_COUNTERVALUE = std::mem::zeroed(); + let event = CreateEventA(std::ptr::null_mut(), FALSE, FALSE, std::ptr::null() as _); + if event.is_null() { + log::error!("CreateEventA failed"); + return; + } + let _event: RAIIHandle = RAIIHandle(event); + ret = PdhCollectQueryDataEx(query, SAMPLE_INTERVAL, event); + if ret != 0 { + log::error!("PdhCollectQueryDataEx failed: 0x{:X}", ret); + return; + } + + let mut queue: VecDeque = VecDeque::new(); + let mut recent_valid: VecDeque = VecDeque::new(); + loop { + // latest one minute + if queue.len() == 31 { + queue.pop_front(); + } + if recent_valid.len() == 31 { + recent_valid.pop_front(); + } + // allow get value within one minute + if queue.len() > 0 && recent_valid.iter().filter(|v| **v).count() > queue.len() / 2 { + let sum: f64 = queue.iter().map(|f| f.to_owned()).sum(); + let avg = sum / (queue.len() as f64); + *CPU_USAGE_ONE_MINUTE.lock().unwrap() = Some((avg, Instant::now())); + } else { + *CPU_USAGE_ONE_MINUTE.lock().unwrap() = None; + } + if WAIT_OBJECT_0 != WaitForSingleObject(event, INFINITE) { + recent_valid.push_back(false); + continue; + } + if PdhGetFormattedCounterValue( + counter, + PDH_FMT_DOUBLE, + &mut _counter_type, + &mut counter_value, + ) != 0 + || counter_value.CStatus != 0 + { + recent_valid.push_back(false); + continue; + } + queue.push_back(counter_value.u.doubleValue().clone()); + recent_valid.push_back(true); + } + }; + use std::sync::Once; + static ONCE: Once = Once::new(); + ONCE.call_once(|| { + std::thread::spawn(f); + }); +} + +pub fn cpu_uage_one_minute() -> Option { + let v = CPU_USAGE_ONE_MINUTE.lock().unwrap().clone(); + if let Some((v, instant)) = v { + if instant.elapsed().as_secs() < 30 { + return Some(v); + } + } + None +} + +pub fn sync_cpu_usage(cpu_usage: Option) { + let v = match cpu_usage { + Some(cpu_usage) => Some((cpu_usage, Instant::now())), + None => None, + }; + *CPU_USAGE_ONE_MINUTE.lock().unwrap() = v; + log::info!("cpu usage synced: {:?}", cpu_usage); +} + +// https://learn.microsoft.com/en-us/windows/win32/sysinfo/targeting-your-application-at-windows-8-1 +// https://github.com/nodejs/node-convergence-archive/blob/e11fe0c2777561827cdb7207d46b0917ef3c42a7/deps/uv/src/win/util.c#L780 +pub fn is_windows_version_or_greater( + os_major: u32, + os_minor: u32, + build_number: u32, + service_pack_major: u32, + service_pack_minor: u32, +) -> bool { + let mut osvi: OSVERSIONINFOEXW = unsafe { std::mem::zeroed() }; + osvi.dwOSVersionInfoSize = std::mem::size_of::() as DWORD; + osvi.dwMajorVersion = os_major as _; + osvi.dwMinorVersion = os_minor as _; + osvi.dwBuildNumber = build_number as _; + osvi.wServicePackMajor = service_pack_major as _; + osvi.wServicePackMinor = service_pack_minor as _; + + let result = unsafe { + let mut condition_mask = 0; + let op = VER_GREATER_EQUAL; + condition_mask = VerSetConditionMask(condition_mask, VER_MAJORVERSION, op); + condition_mask = VerSetConditionMask(condition_mask, VER_MINORVERSION, op); + condition_mask = VerSetConditionMask(condition_mask, VER_BUILDNUMBER, op); + condition_mask = VerSetConditionMask(condition_mask, VER_SERVICEPACKMAJOR, op); + condition_mask = VerSetConditionMask(condition_mask, VER_SERVICEPACKMINOR, op); + + VerifyVersionInfoW( + &mut osvi as *mut OSVERSIONINFOEXW, + VER_MAJORVERSION + | VER_MINORVERSION + | VER_BUILDNUMBER + | VER_SERVICEPACKMAJOR + | VER_SERVICEPACKMINOR, + condition_mask, + ) + }; + + result == TRUE +} diff --git a/src/protos/mod.rs b/src/protos/mod.rs new file mode 100644 index 0000000..57d9b68 --- /dev/null +++ b/src/protos/mod.rs @@ -0,0 +1 @@ +include!(concat!(env!("OUT_DIR"), "/protos/mod.rs")); diff --git a/src/proxy.rs b/src/proxy.rs new file mode 100644 index 0000000..34d2c51 --- /dev/null +++ b/src/proxy.rs @@ -0,0 +1,561 @@ +use std::{ + io::Error as IoError, + net::{SocketAddr, ToSocketAddrs}, +}; + +use base64::{engine::general_purpose, Engine}; +use httparse::{Error as HttpParseError, Response, EMPTY_HEADER}; +use log::info; +use thiserror::Error as ThisError; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufStream}; +#[cfg(any(target_os = "windows", target_os = "macos"))] +use tokio_native_tls::{native_tls, TlsConnector, TlsStream}; +#[cfg(not(any(target_os = "windows", target_os = "macos")))] +use tokio_rustls::{client::TlsStream, TlsConnector}; +use tokio_socks::{tcp::Socks5Stream, IntoTargetAddr}; +use tokio_util::codec::Framed; +use url::Url; + +use crate::{ + bytes_codec::BytesCodec, + config::Socks5Server, + tcp::{DynTcpStream, FramedStream}, + ResultType, +}; + +#[derive(Debug, ThisError)] +pub enum ProxyError { + #[error("IO Error: {0}")] + IoError(#[from] IoError), + #[error("Target parse error: {0}")] + TargetParseError(String), + #[error("HTTP parse error: {0}")] + HttpParseError(#[from] HttpParseError), + #[error("The maximum response header length is exceeded: {0}")] + MaximumResponseHeaderLengthExceeded(usize), + #[error("The end of file is reached")] + EndOfFile, + #[error("The url is error: {0}")] + UrlBadScheme(String), + #[error("The url parse error: {0}")] + UrlParseScheme(#[from] url::ParseError), + #[error("No HTTP code was found in the response")] + NoHttpCode, + #[error("The HTTP code is not equal 200: {0}")] + HttpCode200(u16), + #[error("The proxy address resolution failed: {0}")] + AddressResolutionFailed(String), + #[cfg(any(target_os = "windows", target_os = "macos"))] + #[error("The native tls error: {0}")] + NativeTlsError(#[from] tokio_native_tls::native_tls::Error), +} + +const MAXIMUM_RESPONSE_HEADER_LENGTH: usize = 4096; +/// The maximum HTTP Headers, which can be parsed. +const MAXIMUM_RESPONSE_HEADERS: usize = 16; +const DEFINE_TIME_OUT: u64 = 600; + +pub trait IntoUrl { + + // Besides parsing as a valid `Url`, the `Url` must be a valid + // `http::Uri`, in that it makes sense to use in a network request. + fn into_url(self) -> Result; + + fn as_str(&self) -> &str; +} + +impl IntoUrl for Url { + fn into_url(self) -> Result { + if self.has_host() { + Ok(self) + } else { + Err(ProxyError::UrlBadScheme(self.to_string())) + } + } + + fn as_str(&self) -> &str { + self.as_ref() + } +} + +impl<'a> IntoUrl for &'a str { + fn into_url(self) -> Result { + Url::parse(self) + .map_err(ProxyError::UrlParseScheme)? + .into_url() + } + + fn as_str(&self) -> &str { + self + } +} + +impl<'a> IntoUrl for &'a String { + fn into_url(self) -> Result { + (&**self).into_url() + } + + fn as_str(&self) -> &str { + self.as_ref() + } +} + +impl<'a> IntoUrl for String { + fn into_url(self) -> Result { + (&*self).into_url() + } + + fn as_str(&self) -> &str { + self.as_ref() + } +} + +#[derive(Clone)] +pub struct Auth { + user_name: String, + password: String, +} + +impl Auth { + fn get_proxy_authorization(&self) -> String { + format!( + "Proxy-Authorization: Basic {}\r\n", + self.get_basic_authorization() + ) + } + + pub fn get_basic_authorization(&self) -> String { + let authorization = format!("{}:{}", &self.user_name, &self.password); + general_purpose::STANDARD.encode(authorization.as_bytes()) + } +} + +#[derive(Clone)] +pub enum ProxyScheme { + Http { + auth: Option, + host: String, + }, + Https { + auth: Option, + host: String, + }, + Socks5 { + addr: SocketAddr, + auth: Option, + remote_dns: bool, + }, +} + +impl ProxyScheme { + pub fn maybe_auth(&self) -> Option<&Auth> { + match self { + ProxyScheme::Http { auth, .. } + | ProxyScheme::Https { auth, .. } + | ProxyScheme::Socks5 { auth, .. } => auth.as_ref(), + } + } + + fn socks5(addr: SocketAddr) -> Result { + Ok(ProxyScheme::Socks5 { + addr, + auth: None, + remote_dns: false, + }) + } + + fn http(host: &str) -> Result { + Ok(ProxyScheme::Http { + auth: None, + host: host.to_string(), + }) + } + fn https(host: &str) -> Result { + Ok(ProxyScheme::Https { + auth: None, + host: host.to_string(), + }) + } + + fn set_basic_auth, U: Into>(&mut self, username: T, password: U) { + let auth = Auth { + user_name: username.into(), + password: password.into(), + }; + match self { + ProxyScheme::Http { auth: a, .. } => *a = Some(auth), + ProxyScheme::Https { auth: a, .. } => *a = Some(auth), + ProxyScheme::Socks5 { auth: a, .. } => *a = Some(auth), + } + } + + fn parse(url: Url) -> Result { + use url::Position; + + // Resolve URL to a host and port + let to_addr = || { + let addrs = url.socket_addrs(|| match url.scheme() { + "socks5" => Some(1080), + _ => None, + })?; + addrs + .into_iter() + .next() + .ok_or_else(|| ProxyError::UrlParseScheme(url::ParseError::EmptyHost)) + }; + + let mut scheme: Self = match url.scheme() { + "http" => Self::http(&url[Position::BeforeHost..Position::AfterPort])?, + "https" => Self::https(&url[Position::BeforeHost..Position::AfterPort])?, + "socks5" => Self::socks5(to_addr()?)?, + e => return Err(ProxyError::UrlBadScheme(e.to_string())), + }; + + if let Some(pwd) = url.password() { + let username = url.username(); + scheme.set_basic_auth(username, pwd); + } + + Ok(scheme) + } + pub async fn socket_addrs(&self) -> Result { + info!("Resolving socket address"); + match self { + ProxyScheme::Http { host, .. } => self.resolve_host(host, 80).await, + ProxyScheme::Https { host, .. } => self.resolve_host(host, 443).await, + ProxyScheme::Socks5 { addr, .. } => Ok(addr.clone()), + } + } + + async fn resolve_host(&self, host: &str, default_port: u16) -> Result { + let (host_str, port) = match host.split_once(':') { + Some((h, p)) => (h, p.parse::().ok()), + None => (host, None), + }; + let addr = (host_str, port.unwrap_or(default_port)) + .to_socket_addrs()? + .next() + .ok_or_else(|| ProxyError::AddressResolutionFailed(host.to_string()))?; + Ok(addr) + } + + pub fn get_domain(&self) -> Result { + match self { + ProxyScheme::Http { host, .. } | ProxyScheme::Https { host, .. } => { + let domain = host + .split(':') + .next() + .ok_or_else(|| ProxyError::AddressResolutionFailed(host.clone()))?; + Ok(domain.to_string()) + } + ProxyScheme::Socks5 { addr, .. } => match addr { + SocketAddr::V4(addr_v4) => Ok(addr_v4.ip().to_string()), + SocketAddr::V6(addr_v6) => Ok(addr_v6.ip().to_string()), + }, + } + } + pub fn get_host_and_port(&self) -> Result { + match self { + ProxyScheme::Http { host, .. } => Ok(self.append_default_port(host, 80)), + ProxyScheme::Https { host, .. } => Ok(self.append_default_port(host, 443)), + ProxyScheme::Socks5 { addr, .. } => Ok(format!("{}", addr)), + } + } + fn append_default_port(&self, host: &str, default_port: u16) -> String { + if host.contains(':') { + host.to_string() + } else { + format!("{}:{}", host, default_port) + } + } +} + +pub trait IntoProxyScheme { + fn into_proxy_scheme(self) -> Result; +} + +impl IntoProxyScheme for S { + fn into_proxy_scheme(self) -> Result { + // validate the URL + let url = match self.as_str().into_url() { + Ok(ok) => ok, + Err(e) => { + match e { + // If the string does not contain protocol headers, try to parse it using the socks5 protocol + ProxyError::UrlParseScheme(_source) => { + let try_this = format!("socks5://{}", self.as_str()); + try_this.into_url()? + } + _ => { + return Err(e); + } + } + } + }; + ProxyScheme::parse(url) + } +} + +impl IntoProxyScheme for ProxyScheme { + fn into_proxy_scheme(self) -> Result { + Ok(self) + } +} + +#[derive(Clone)] +pub struct Proxy { + pub intercept: ProxyScheme, + ms_timeout: u64, +} + +impl Proxy { + pub fn new(proxy_scheme: U, ms_timeout: u64) -> Result { + Ok(Self { + intercept: proxy_scheme.into_proxy_scheme()?, + ms_timeout, + }) + } + + pub fn is_http_or_https(&self) -> bool { + return match self.intercept { + ProxyScheme::Socks5 { .. } => false, + _ => true, + }; + } + + pub fn from_conf(conf: &Socks5Server, ms_timeout: Option) -> Result { + let mut proxy; + match ms_timeout { + None => { + proxy = Self::new(&conf.proxy, DEFINE_TIME_OUT)?; + } + Some(time_out) => { + proxy = Self::new(&conf.proxy, time_out)?; + } + } + + if !conf.password.is_empty() && !conf.username.is_empty() { + proxy = proxy.basic_auth(&conf.username, &conf.password); + } + Ok(proxy) + } + + pub async fn proxy_addrs(&self) -> Result { + self.intercept.socket_addrs().await + } + + fn basic_auth(mut self, username: &str, password: &str) -> Proxy { + self.intercept.set_basic_auth(username, password); + self + } + + pub async fn connect<'t, T>( + self, + target: T, + local_addr: Option, + ) -> ResultType + where + T: IntoTargetAddr<'t>, + { + info!("Connect to proxy server"); + let proxy = self.proxy_addrs().await?; + + let local = if let Some(addr) = local_addr { + addr + } else { + crate::config::Config::get_any_listen_addr(proxy.is_ipv4()) + }; + + let stream = super::timeout( + self.ms_timeout, + crate::tcp::new_socket(local, true)?.connect(proxy), + ) + .await??; + stream.set_nodelay(true).ok(); + + let addr = stream.local_addr()?; + + return match self.intercept { + ProxyScheme::Http { .. } => { + info!("Connect to remote http proxy server: {}", proxy); + let stream = + super::timeout(self.ms_timeout, self.http_connect(stream, target)).await??; + Ok(FramedStream( + Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()), + addr, + None, + 0, + )) + } + ProxyScheme::Https { .. } => { + info!("Connect to remote https proxy server: {}", proxy); + let stream = + super::timeout(self.ms_timeout, self.https_connect(stream, target)).await??; + Ok(FramedStream( + Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()), + addr, + None, + 0, + )) + } + ProxyScheme::Socks5 { .. } => { + info!("Connect to remote socket5 proxy server: {}", proxy); + let stream = if let Some(auth) = self.intercept.maybe_auth() { + super::timeout( + self.ms_timeout, + Socks5Stream::connect_with_password_and_socket( + stream, + target, + &auth.user_name, + &auth.password, + ), + ) + .await?? + } else { + super::timeout( + self.ms_timeout, + Socks5Stream::connect_with_socket(stream, target), + ) + .await?? + }; + Ok(FramedStream( + Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()), + addr, + None, + 0, + )) + } + }; + } + + #[cfg(any(target_os = "windows", target_os = "macos"))] + pub async fn https_connect<'a, Input, T>( + self, + io: Input, + target: T, + ) -> Result>, ProxyError> + where + Input: AsyncRead + AsyncWrite + Unpin, + T: IntoTargetAddr<'a>, + { + let tls_connector = TlsConnector::from(native_tls::TlsConnector::new()?); + let stream = tls_connector + .connect(&self.intercept.get_domain()?, io) + .await?; + self.http_connect(stream, target).await + } + + #[cfg(not(any(target_os = "windows", target_os = "macos")))] + pub async fn https_connect<'a, Input, T>( + self, + io: Input, + target: T, + ) -> Result>, ProxyError> + where + Input: AsyncRead + AsyncWrite + Unpin, + T: IntoTargetAddr<'a>, + { + use std::convert::TryFrom; + let verifier = rustls_platform_verifier::tls_config(); + let url_domain = self.intercept.get_domain()?; + + let domain = rustls_pki_types::ServerName::try_from(url_domain.as_str()) + .map_err(|e| ProxyError::AddressResolutionFailed(e.to_string()))? + .to_owned(); + + let tls_connector = TlsConnector::from(std::sync::Arc::new(verifier)); + let stream = tls_connector.connect(domain, io).await?; + self.http_connect(stream, target).await + } + + pub async fn http_connect<'a, Input, T>( + self, + io: Input, + target: T, + ) -> Result, ProxyError> + where + Input: AsyncRead + AsyncWrite + Unpin, + T: IntoTargetAddr<'a>, + { + let mut stream = BufStream::new(io); + let (domain, port) = get_domain_and_port(target)?; + + let request = self.make_request(&domain, port); + stream.write_all(request.as_bytes()).await?; + stream.flush().await?; + recv_and_check_response(&mut stream).await?; + Ok(stream) + } + + fn make_request(&self, host: &str, port: u16) -> String { + let mut request = format!( + "CONNECT {host}:{port} HTTP/1.1\r\nHost: {host}:{port}\r\n", + host = host, + port = port + ); + + if let Some(auth) = self.intercept.maybe_auth() { + request = format!("{}{}", request, auth.get_proxy_authorization()); + } + + request.push_str("\r\n"); + request + } +} + +fn get_domain_and_port<'a, T: IntoTargetAddr<'a>>(target: T) -> Result<(String, u16), ProxyError> { + let target_addr = target + .into_target_addr() + .map_err(|e| ProxyError::TargetParseError(e.to_string()))?; + match target_addr { + tokio_socks::TargetAddr::Ip(addr) => Ok((addr.ip().to_string(), addr.port())), + tokio_socks::TargetAddr::Domain(name, port) => Ok((name.to_string(), port)), + } +} + +async fn get_response(stream: &mut BufStream) -> Result +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + use tokio::io::AsyncBufReadExt; + let mut response = String::new(); + + loop { + if stream.read_line(&mut response).await? == 0 { + return Err(ProxyError::EndOfFile); + } + + if MAXIMUM_RESPONSE_HEADER_LENGTH < response.len() { + return Err(ProxyError::MaximumResponseHeaderLengthExceeded( + response.len(), + )); + } + + if response.ends_with("\r\n\r\n") { + return Ok(response); + } + } +} + +async fn recv_and_check_response(stream: &mut BufStream) -> Result<(), ProxyError> +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + let response_string = get_response(stream).await?; + + let mut response_headers = [EMPTY_HEADER; MAXIMUM_RESPONSE_HEADERS]; + let mut response = Response::new(&mut response_headers); + let response_bytes = response_string.into_bytes(); + response.parse(&response_bytes)?; + + return match response.code { + Some(code) => { + if code == 200 { + Ok(()) + } else { + Err(ProxyError::HttpCode200(code)) + } + } + None => Err(ProxyError::NoHttpCode), + }; +} diff --git a/src/socket_client.rs b/src/socket_client.rs new file mode 100644 index 0000000..4cb0bf2 --- /dev/null +++ b/src/socket_client.rs @@ -0,0 +1,291 @@ +use crate::{ + config::{Config, NetworkType}, + tcp::FramedStream, + udp::FramedSocket, + ResultType, +}; +use anyhow::Context; +use std::net::SocketAddr; +use tokio::net::ToSocketAddrs; +use tokio_socks::{IntoTargetAddr, TargetAddr}; + +#[inline] +pub fn check_port(host: T, port: i32) -> String { + let host = host.to_string(); + if crate::is_ipv6_str(&host) { + if host.starts_with('[') { + return host; + } + return format!("[{host}]:{port}"); + } + if !host.contains(':') { + return format!("{host}:{port}"); + } + host +} + +#[inline] +pub fn increase_port(host: T, offset: i32) -> String { + let host = host.to_string(); + if crate::is_ipv6_str(&host) { + if host.starts_with('[') { + let tmp: Vec<&str> = host.split("]:").collect(); + if tmp.len() == 2 { + let port: i32 = tmp[1].parse().unwrap_or(0); + if port > 0 { + return format!("{}]:{}", tmp[0], port + offset); + } + } + } + } else if host.contains(':') { + let tmp: Vec<&str> = host.split(':').collect(); + if tmp.len() == 2 { + let port: i32 = tmp[1].parse().unwrap_or(0); + if port > 0 { + return format!("{}:{}", tmp[0], port + offset); + } + } + } + host +} + +pub fn test_if_valid_server(host: &str, test_with_proxy: bool) -> String { + let host = check_port(host, 0); + use std::net::ToSocketAddrs; + + if test_with_proxy && NetworkType::ProxySocks == Config::get_network_type() { + test_if_valid_server_for_proxy_(&host) + } else { + match host.to_socket_addrs() { + Err(err) => err.to_string(), + Ok(_) => "".to_owned(), + } + } +} + +#[inline] +pub fn test_if_valid_server_for_proxy_(host: &str) -> String { + // `&host.into_target_addr()` is defined in `tokio-socs`, but is a common pattern for testing, + // it can be used for both `socks` and `http` proxy. + match &host.into_target_addr() { + Err(err) => err.to_string(), + Ok(_) => "".to_owned(), + } +} + +pub trait IsResolvedSocketAddr { + fn resolve(&self) -> Option<&SocketAddr>; +} + +impl IsResolvedSocketAddr for SocketAddr { + fn resolve(&self) -> Option<&SocketAddr> { + Some(self) + } +} + +impl IsResolvedSocketAddr for String { + fn resolve(&self) -> Option<&SocketAddr> { + None + } +} + +impl IsResolvedSocketAddr for &str { + fn resolve(&self) -> Option<&SocketAddr> { + None + } +} + +#[inline] +pub async fn connect_tcp< + 't, + T: IntoTargetAddr<'t> + ToSocketAddrs + IsResolvedSocketAddr + std::fmt::Display, +>( + target: T, + ms_timeout: u64, +) -> ResultType { + connect_tcp_local(target, None, ms_timeout).await +} + +pub async fn connect_tcp_local< + 't, + T: IntoTargetAddr<'t> + ToSocketAddrs + IsResolvedSocketAddr + std::fmt::Display, +>( + target: T, + local: Option, + ms_timeout: u64, +) -> ResultType { + if let Some(conf) = Config::get_socks() { + return FramedStream::connect(target, local, &conf, ms_timeout).await; + } + if let Some(target) = target.resolve() { + if let Some(local) = local { + if local.is_ipv6() && target.is_ipv4() { + let target = query_nip_io(target).await?; + return FramedStream::new(target, Some(local), ms_timeout).await; + } + } + } + FramedStream::new(target, local, ms_timeout).await +} + +#[inline] +pub fn is_ipv4(target: &TargetAddr<'_>) -> bool { + match target { + TargetAddr::Ip(addr) => addr.is_ipv4(), + _ => true, + } +} + +#[inline] +pub async fn query_nip_io(addr: &SocketAddr) -> ResultType { + tokio::net::lookup_host(format!("{}.nip.io:{}", addr.ip(), addr.port())) + .await? + .find(|x| x.is_ipv6()) + .context("Failed to get ipv6 from nip.io") +} + +#[inline] +pub fn ipv4_to_ipv6(addr: String, ipv4: bool) -> String { + if !ipv4 && crate::is_ipv4_str(&addr) { + if let Some(ip) = addr.split(':').next() { + return addr.replace(ip, &format!("{ip}.nip.io")); + } + } + addr +} + +async fn test_target(target: &str) -> ResultType { + if let Ok(Ok(s)) = super::timeout(1000, tokio::net::TcpStream::connect(target)).await { + if let Ok(addr) = s.peer_addr() { + return Ok(addr); + } + } + tokio::net::lookup_host(target) + .await? + .next() + .context(format!("Failed to look up host for {target}")) +} + +#[inline] +pub async fn new_udp_for( + target: &str, + ms_timeout: u64, +) -> ResultType<(FramedSocket, TargetAddr<'static>)> { + let (ipv4, target) = if NetworkType::Direct == Config::get_network_type() { + let addr = test_target(target).await?; + (addr.is_ipv4(), addr.into_target_addr()?) + } else { + (true, target.into_target_addr()?) + }; + Ok(( + new_udp(Config::get_any_listen_addr(ipv4), ms_timeout).await?, + target.to_owned(), + )) +} + +async fn new_udp(local: T, ms_timeout: u64) -> ResultType { + match Config::get_socks() { + None => Ok(FramedSocket::new(local).await?), + Some(conf) => { + let socket = FramedSocket::new_proxy( + conf.proxy.as_str(), + local, + conf.username.as_str(), + conf.password.as_str(), + ms_timeout, + ) + .await?; + Ok(socket) + } + } +} + +pub async fn rebind_udp_for( + target: &str, +) -> ResultType)>> { + if Config::get_network_type() != NetworkType::Direct { + return Ok(None); + } + let addr = test_target(target).await?; + let v4 = addr.is_ipv4(); + Ok(Some(( + FramedSocket::new(Config::get_any_listen_addr(v4)).await?, + addr.into_target_addr()?.to_owned(), + ))) +} + +#[cfg(test)] +mod tests { + use std::net::ToSocketAddrs; + + use super::*; + + #[test] + fn test_nat64() { + test_nat64_async(); + } + + #[tokio::main(flavor = "current_thread")] + async fn test_nat64_async() { + assert_eq!(ipv4_to_ipv6("1.1.1.1".to_owned(), true), "1.1.1.1"); + assert_eq!(ipv4_to_ipv6("1.1.1.1".to_owned(), false), "1.1.1.1.nip.io"); + assert_eq!( + ipv4_to_ipv6("1.1.1.1:8080".to_owned(), false), + "1.1.1.1.nip.io:8080" + ); + assert_eq!( + ipv4_to_ipv6("rustdesk.com".to_owned(), false), + "rustdesk.com" + ); + if ("rustdesk.com:80") + .to_socket_addrs() + .unwrap() + .next() + .unwrap() + .is_ipv6() + { + assert!(query_nip_io(&"1.1.1.1:80".parse().unwrap()) + .await + .unwrap() + .is_ipv6()); + return; + } + assert!(query_nip_io(&"1.1.1.1:80".parse().unwrap()).await.is_err()); + } + + #[test] + fn test_test_if_valid_server() { + assert!(!test_if_valid_server("a", false).is_empty()); + // on Linux, "1" is resolved to "0.0.0.1" + assert!(test_if_valid_server("1.1.1.1", false).is_empty()); + assert!(test_if_valid_server("1.1.1.1:1", false).is_empty()); + assert!(test_if_valid_server("microsoft.com", false).is_empty()); + assert!(test_if_valid_server("microsoft.com:1", false).is_empty()); + + // with proxy + // `:0` indicates `let host = check_port(host, 0);` is called. + assert!(test_if_valid_server_for_proxy_("a:0").is_empty()); + assert!(test_if_valid_server_for_proxy_("1.1.1.1:0").is_empty()); + assert!(test_if_valid_server_for_proxy_("1.1.1.1:1").is_empty()); + assert!(test_if_valid_server_for_proxy_("abc.com:0").is_empty()); + assert!(test_if_valid_server_for_proxy_("abcd.com:1").is_empty()); + } + + #[test] + fn test_check_port() { + assert_eq!(check_port("[1:2]:12", 32), "[1:2]:12"); + assert_eq!(check_port("1:2", 32), "[1:2]:32"); + assert_eq!(check_port("z1:2", 32), "z1:2"); + assert_eq!(check_port("1.1.1.1", 32), "1.1.1.1:32"); + assert_eq!(check_port("1.1.1.1:32", 32), "1.1.1.1:32"); + assert_eq!(check_port("test.com:32", 0), "test.com:32"); + assert_eq!(increase_port("[1:2]:12", 1), "[1:2]:13"); + assert_eq!(increase_port("1.2.2.4:12", 1), "1.2.2.4:13"); + assert_eq!(increase_port("1.2.2.4", 1), "1.2.2.4"); + assert_eq!(increase_port("test.com", 1), "test.com"); + assert_eq!(increase_port("test.com:13", 4), "test.com:17"); + assert_eq!(increase_port("1:13", 4), "1:13"); + assert_eq!(increase_port("22:1:13", 4), "22:1:13"); + assert_eq!(increase_port("z1:2", 1), "z1:3"); + } +} diff --git a/src/tcp.rs b/src/tcp.rs new file mode 100644 index 0000000..17f360f --- /dev/null +++ b/src/tcp.rs @@ -0,0 +1,341 @@ +use crate::{bail, bytes_codec::BytesCodec, ResultType, config::Socks5Server, proxy::Proxy}; +use anyhow::Context as AnyhowCtx; +use bytes::{BufMut, Bytes, BytesMut}; +use futures::{SinkExt, StreamExt}; +use protobuf::Message; +use sodiumoxide::crypto::{ + box_, + secretbox::{self, Key, Nonce}, +}; +use std::{ + io::{self, Error, ErrorKind}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + ops::{Deref, DerefMut}, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf}, + net::{lookup_host, TcpListener, TcpSocket, ToSocketAddrs}, +}; +use tokio_socks::IntoTargetAddr; +use tokio_util::codec::Framed; + +pub trait TcpStreamTrait: AsyncRead + AsyncWrite + Unpin {} +pub struct DynTcpStream(pub(crate) Box); + +#[derive(Clone)] +pub struct Encrypt(Key, u64, u64); + +pub struct FramedStream( + pub(crate) Framed, + pub(crate) SocketAddr, + pub(crate) Option, + pub(crate) u64, +); + +impl Deref for FramedStream { + type Target = Framed; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for FramedStream { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Deref for DynTcpStream { + type Target = Box; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for DynTcpStream { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +pub(crate) fn new_socket(addr: std::net::SocketAddr, reuse: bool) -> Result { + let socket = match addr { + std::net::SocketAddr::V4(..) => TcpSocket::new_v4()?, + std::net::SocketAddr::V6(..) => TcpSocket::new_v6()?, + }; + if reuse { + // windows has no reuse_port, but it's reuse_address + // almost equals to unix's reuse_port + reuse_address, + // though may introduce nondeterministic behavior + #[cfg(unix)] + socket.set_reuseport(true).ok(); + socket.set_reuseaddr(true).ok(); + } + socket.bind(addr)?; + Ok(socket) +} + +impl FramedStream { + pub async fn new( + remote_addr: T, + local_addr: Option, + ms_timeout: u64, + ) -> ResultType { + for remote_addr in lookup_host(&remote_addr).await? { + let local = if let Some(addr) = local_addr { + addr + } else { + crate::config::Config::get_any_listen_addr(remote_addr.is_ipv4()) + }; + if let Ok(socket) = new_socket(local, true) { + if let Ok(Ok(stream)) = + super::timeout(ms_timeout, socket.connect(remote_addr)).await + { + stream.set_nodelay(true).ok(); + let addr = stream.local_addr()?; + return Ok(Self( + Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()), + addr, + None, + 0, + )); + } + } + } + bail!(format!("Failed to connect to {remote_addr}")); + } + + pub async fn connect<'t, T>( + target: T, + local_addr: Option, + proxy_conf: &Socks5Server, + ms_timeout: u64, + ) -> ResultType + where + T: IntoTargetAddr<'t>, + { + let proxy = Proxy::from_conf(proxy_conf, Some(ms_timeout))?; + proxy.connect::(target, local_addr).await + } + + pub fn local_addr(&self) -> SocketAddr { + self.1 + } + + pub fn set_send_timeout(&mut self, ms: u64) { + self.3 = ms; + } + + pub fn from(stream: impl TcpStreamTrait + Send + Sync + 'static, addr: SocketAddr) -> Self { + Self( + Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()), + addr, + None, + 0, + ) + } + + pub fn set_raw(&mut self) { + self.0.codec_mut().set_raw(); + self.2 = None; + } + + pub fn is_secured(&self) -> bool { + self.2.is_some() + } + + #[inline] + pub async fn send(&mut self, msg: &impl Message) -> ResultType<()> { + self.send_raw(msg.write_to_bytes()?).await + } + + #[inline] + pub async fn send_raw(&mut self, msg: Vec) -> ResultType<()> { + let mut msg = msg; + if let Some(key) = self.2.as_mut() { + msg = key.enc(&msg); + } + self.send_bytes(bytes::Bytes::from(msg)).await?; + Ok(()) + } + + #[inline] + pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { + if self.3 > 0 { + super::timeout(self.3, self.0.send(bytes)).await??; + } else { + self.0.send(bytes).await?; + } + Ok(()) + } + + #[inline] + pub async fn next(&mut self) -> Option> { + let mut res = self.0.next().await; + if let Some(Ok(bytes)) = res.as_mut() { + if let Some(key) = self.2.as_mut() { + if let Err(err) = key.dec(bytes) { + return Some(Err(err)); + } + } + } + res + } + + #[inline] + pub async fn next_timeout(&mut self, ms: u64) -> Option> { + if let Ok(res) = super::timeout(ms, self.next()).await { + res + } else { + None + } + } + + pub fn set_key(&mut self, key: Key) { + self.2 = Some(Encrypt::new(key)); + } + + fn get_nonce(seqnum: u64) -> Nonce { + let mut nonce = Nonce([0u8; secretbox::NONCEBYTES]); + nonce.0[..std::mem::size_of_val(&seqnum)].copy_from_slice(&seqnum.to_le_bytes()); + nonce + } +} + +const DEFAULT_BACKLOG: u32 = 128; + +pub async fn new_listener(addr: T, reuse: bool) -> ResultType { + if !reuse { + Ok(TcpListener::bind(addr).await?) + } else { + let addr = lookup_host(&addr) + .await? + .next() + .context("could not resolve to any address")?; + new_socket(addr, true)? + .listen(DEFAULT_BACKLOG) + .map_err(anyhow::Error::msg) + } +} + +pub async fn listen_any(port: u16) -> ResultType { + if let Ok(mut socket) = TcpSocket::new_v6() { + #[cfg(unix)] + { + socket.set_reuseport(true).ok(); + socket.set_reuseaddr(true).ok(); + use std::os::unix::io::{FromRawFd, IntoRawFd}; + let raw_fd = socket.into_raw_fd(); + let sock2 = unsafe { socket2::Socket::from_raw_fd(raw_fd) }; + sock2.set_only_v6(false).ok(); + socket = unsafe { TcpSocket::from_raw_fd(sock2.into_raw_fd()) }; + } + #[cfg(windows)] + { + use std::os::windows::prelude::{FromRawSocket, IntoRawSocket}; + let raw_socket = socket.into_raw_socket(); + let sock2 = unsafe { socket2::Socket::from_raw_socket(raw_socket) }; + sock2.set_only_v6(false).ok(); + socket = unsafe { TcpSocket::from_raw_socket(sock2.into_raw_socket()) }; + } + if socket + .bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port)) + .is_ok() + { + if let Ok(l) = socket.listen(DEFAULT_BACKLOG) { + return Ok(l); + } + } + } + Ok(new_socket( + SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port), + true, + )? + .listen(DEFAULT_BACKLOG)?) +} + +impl Unpin for DynTcpStream {} + +impl AsyncRead for DynTcpStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf) + } +} + +impl AsyncWrite for DynTcpStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncWrite::poll_flush(Pin::new(&mut self.0), cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx) + } +} + +impl TcpStreamTrait for R {} + +impl Encrypt { + pub fn new(key: Key) -> Self { + Self(key, 0, 0) + } + + pub fn dec(&mut self, bytes: &mut BytesMut) -> Result<(), Error> { + if bytes.len() <= 1 { + return Ok(()); + } + self.2 += 1; + let nonce = FramedStream::get_nonce(self.2); + match secretbox::open(bytes, &nonce, &self.0) { + Ok(res) => { + bytes.clear(); + bytes.put_slice(&res); + Ok(()) + } + Err(()) => Err(Error::new(ErrorKind::Other, "decryption error")), + } + } + + pub fn enc(&mut self, data: &[u8]) -> Vec { + self.1 += 1; + let nonce = FramedStream::get_nonce(self.1); + secretbox::seal(&data, &nonce, &self.0) + } + + pub fn decode( + symmetric_data: &[u8], + their_pk_b: &[u8], + our_sk_b: &box_::SecretKey, + ) -> ResultType { + if their_pk_b.len() != box_::PUBLICKEYBYTES { + anyhow::bail!("Handshake failed: pk length {}", their_pk_b.len()); + } + let nonce = box_::Nonce([0u8; box_::NONCEBYTES]); + let mut pk_ = [0u8; box_::PUBLICKEYBYTES]; + pk_[..].copy_from_slice(their_pk_b); + let their_pk_b = box_::PublicKey(pk_); + let symmetric_key = box_::open(symmetric_data, &nonce, &their_pk_b, &our_sk_b) + .map_err(|_| anyhow::anyhow!("Handshake failed: box decryption failure"))?; + if symmetric_key.len() != secretbox::KEYBYTES { + anyhow::bail!("Handshake failed: invalid secret key length from peer"); + } + let mut key = [0u8; secretbox::KEYBYTES]; + key[..].copy_from_slice(&symmetric_key); + Ok(Key(key)) + } +} diff --git a/src/udp.rs b/src/udp.rs new file mode 100644 index 0000000..68abd42 --- /dev/null +++ b/src/udp.rs @@ -0,0 +1,170 @@ +use crate::ResultType; +use anyhow::{anyhow, Context}; +use bytes::{Bytes, BytesMut}; +use futures::{SinkExt, StreamExt}; +use protobuf::Message; +use socket2::{Domain, Socket, Type}; +use std::net::SocketAddr; +use tokio::net::{lookup_host, ToSocketAddrs, UdpSocket}; +use tokio_socks::{udp::Socks5UdpFramed, IntoTargetAddr, TargetAddr, ToProxyAddrs}; +use tokio_util::{codec::BytesCodec, udp::UdpFramed}; + +pub enum FramedSocket { + Direct(UdpFramed), + ProxySocks(Socks5UdpFramed), +} + +fn new_socket(addr: SocketAddr, reuse: bool, buf_size: usize) -> Result { + let socket = match addr { + SocketAddr::V4(..) => Socket::new(Domain::ipv4(), Type::dgram(), None), + SocketAddr::V6(..) => Socket::new(Domain::ipv6(), Type::dgram(), None), + }?; + if reuse { + // windows has no reuse_port, but it's reuse_address + // almost equals to unix's reuse_port + reuse_address, + // though may introduce nondeterministic behavior + #[cfg(unix)] + socket.set_reuse_port(true).ok(); + socket.set_reuse_address(true).ok(); + } + // only nonblocking work with tokio, https://stackoverflow.com/questions/64649405/receiver-on-tokiompscchannel-only-receives-messages-when-buffer-is-full + socket.set_nonblocking(true)?; + if buf_size > 0 { + socket.set_recv_buffer_size(buf_size).ok(); + } + log::debug!( + "Receive buf size of udp {}: {:?}", + addr, + socket.recv_buffer_size() + ); + if addr.is_ipv6() && addr.ip().is_unspecified() && addr.port() > 0 { + socket.set_only_v6(false).ok(); + } + socket.bind(&addr.into())?; + Ok(socket) +} + +impl FramedSocket { + pub async fn new(addr: T) -> ResultType { + Self::new_reuse(addr, false, 0).await + } + + pub async fn new_reuse( + addr: T, + reuse: bool, + buf_size: usize, + ) -> ResultType { + let addr = lookup_host(&addr) + .await? + .next() + .context("could not resolve to any address")?; + Ok(Self::Direct(UdpFramed::new( + UdpSocket::from_std(new_socket(addr, reuse, buf_size)?.into_udp_socket())?, + BytesCodec::new(), + ))) + } + + pub async fn new_proxy<'a, 't, P: ToProxyAddrs, T: ToSocketAddrs>( + proxy: P, + local: T, + username: &'a str, + password: &'a str, + ms_timeout: u64, + ) -> ResultType { + let framed = if username.trim().is_empty() { + super::timeout(ms_timeout, Socks5UdpFramed::connect(proxy, Some(local))).await?? + } else { + super::timeout( + ms_timeout, + Socks5UdpFramed::connect_with_password(proxy, Some(local), username, password), + ) + .await?? + }; + log::trace!( + "Socks5 udp connected, local addr: {:?}, target addr: {}", + framed.local_addr(), + framed.socks_addr() + ); + Ok(Self::ProxySocks(framed)) + } + + #[inline] + pub async fn send( + &mut self, + msg: &impl Message, + addr: impl IntoTargetAddr<'_>, + ) -> ResultType<()> { + let addr = addr.into_target_addr()?.to_owned(); + let send_data = Bytes::from(msg.write_to_bytes()?); + match self { + Self::Direct(f) => { + if let TargetAddr::Ip(addr) = addr { + f.send((send_data, addr)).await? + } + } + Self::ProxySocks(f) => f.send((send_data, addr)).await?, + }; + Ok(()) + } + + // https://stackoverflow.com/a/68733302/1926020 + #[inline] + pub async fn send_raw( + &mut self, + msg: &'static [u8], + addr: impl IntoTargetAddr<'static>, + ) -> ResultType<()> { + let addr = addr.into_target_addr()?.to_owned(); + + match self { + Self::Direct(f) => { + if let TargetAddr::Ip(addr) = addr { + f.send((Bytes::from(msg), addr)).await? + } + } + Self::ProxySocks(f) => f.send((Bytes::from(msg), addr)).await?, + }; + Ok(()) + } + + #[inline] + pub async fn next(&mut self) -> Option)>> { + match self { + Self::Direct(f) => match f.next().await { + Some(Ok((data, addr))) => { + Some(Ok((data, addr.into_target_addr().ok()?.to_owned()))) + } + Some(Err(e)) => Some(Err(anyhow!(e))), + None => None, + }, + Self::ProxySocks(f) => match f.next().await { + Some(Ok((data, _))) => Some(Ok((data.data, data.dst_addr))), + Some(Err(e)) => Some(Err(anyhow!(e))), + None => None, + }, + } + } + + #[inline] + pub async fn next_timeout( + &mut self, + ms: u64, + ) -> Option)>> { + if let Ok(res) = + tokio::time::timeout(std::time::Duration::from_millis(ms), self.next()).await + { + res + } else { + None + } + } + + pub fn local_addr(&self) -> Option { + if let FramedSocket::Direct(x) = self { + if let Ok(v) = x.get_ref().local_addr() { + return Some(v); + } + } + None + } +}