diff --git a/src/fs.rs b/src/fs.rs index ef1e4a2f6..301d7efd3 100644 --- a/src/fs.rs +++ b/src/fs.rs @@ -398,7 +398,7 @@ pub struct TransferJob { pub is_resume: bool, pub file_num: i32, #[serde(skip_serializing)] - pub files: Vec, + files: Vec, pub conn_id: i32, // server only #[serde(skip_serializing)] @@ -457,25 +457,15 @@ fn is_compressed_file(name: &str) -> bool { compressed_exts.contains(&ext) } -#[inline] -fn validate_file_name_no_traversal(name: &str) -> ResultType<()> { +pub fn validate_file_name_no_traversal(name: &str) -> ResultType<()> { if name.bytes().any(|b| b == 0) { bail!("file name contains null bytes"); } - #[cfg(windows)] - if name - .split(|c| c == '/' || c == '\\') + let has_traversal = name + .split(|c: char| c == '/' || (cfg!(windows) && c == '\\')) .filter(|s| !s.is_empty()) - .any(|component| component == "..") - { - bail!("path traversal detected in file name"); - } - #[cfg(not(windows))] - if name - .split('/') - .filter(|s| !s.is_empty()) - .any(|component| component == "..") - { + .any(|component| component == ".."); + if has_traversal { bail!("path traversal detected in file name"); } #[cfg(windows)] @@ -497,8 +487,9 @@ fn validate_file_name_no_traversal(name: &str) -> ResultType<()> { Ok(()) } -#[inline] fn validate_transfer_file_names(files: &[FileEntry]) -> ResultType<()> { + // Single-file transfer may use an empty relative name, because + // the destination file path is carried by transfer metadata. if files.len() == 1 && files.first().map_or(false, |f| f.name.is_empty()) { return Ok(()); } @@ -511,7 +502,6 @@ fn validate_transfer_file_names(files: &[FileEntry]) -> ResultType<()> { Ok(()) } -#[inline] fn validate_no_symlink_components(base: &PathBuf, name: &str) -> ResultType<()> { if name.is_empty() { return Ok(()); @@ -521,6 +511,8 @@ fn validate_no_symlink_components(base: &PathBuf, name: &str) -> ResultType<()> match component { std::path::Component::Normal(seg) => { current.push(seg); + // Best-effort guard: path-based checks are inherently TOCTOU-prone + // if local filesystem state changes between validation and write. if let Ok(meta) = std::fs::symlink_metadata(¤t) { if meta.file_type().is_symlink() { bail!("symlink path component is not allowed"); @@ -536,7 +528,6 @@ fn validate_no_symlink_components(base: &PathBuf, name: &str) -> ResultType<()> Ok(()) } -#[inline] fn join_validated_path(base: &PathBuf, name: &str) -> ResultType { validate_file_name_no_traversal(name)?; validate_no_symlink_components(base, name)?; @@ -553,11 +544,9 @@ impl TransferJob { file_num: i32, show_hidden: bool, is_remote: bool, - files: Vec, enable_overwrite_detection: bool, ) -> Self { log::info!("new write {}", data_source); - let total_size = files.iter().map(|x| x.size).sum(); Self { id, r#type, @@ -566,13 +555,18 @@ impl TransferJob { file_num, show_hidden, is_remote, - files, - total_size, + files: Vec::new(), + total_size: 0, enable_overwrite_detection, ..Default::default() } } + pub fn with_files(mut self, files: Vec) -> ResultType { + self.set_files(files)?; + Ok(self) + } + pub fn new_read( id: i32, r#type: JobType, @@ -631,6 +625,7 @@ impl TransferJob { validate_no_symlink_components(base, &file.name)?; } } + self.total_size = files.iter().map(|x| x.size).sum(); self.files = files; Ok(()) } @@ -666,6 +661,20 @@ impl TransferJob { self.file_num } + fn resolve_entry_path(&self, base: &PathBuf, name: &str) -> Option { + if self.r#type == JobType::Generic { + match join_validated_path(base, name) { + Ok(path) => Some(path), + Err(err) => { + log::error!("Invalid file name in transfer job {}: {}", self.id, err); + None + } + } + } else { + Some(Self::join(base, name)) + } + } + pub fn modify_time(&self) { if self.r#type == JobType::Printer { return; @@ -674,16 +683,8 @@ impl TransferJob { let file_num = self.file_num as usize; if file_num < self.files.len() { let entry = &self.files[file_num]; - let path = if self.r#type == JobType::Generic { - match join_validated_path(p, &entry.name) { - Ok(path) => path, - Err(err) => { - log::error!("Invalid file name in transfer job {}: {}", self.id, err); - return; - } - } - } else { - Self::join(p, &entry.name) + let Some(path) = self.resolve_entry_path(p, &entry.name) else { + return; }; let download_path = format!("{}.download", get_string(&path)); let digest_path = format!("{}.digest", get_string(&path)); @@ -706,16 +707,8 @@ impl TransferJob { let file_num = self.file_num as usize; if file_num < self.files.len() { let entry = &self.files[file_num]; - let path = if self.r#type == JobType::Generic { - match join_validated_path(p, &entry.name) { - Ok(path) => path, - Err(err) => { - log::error!("Invalid file name in transfer job {}: {}", self.id, err); - return; - } - } - } else { - Self::join(p, &entry.name) + let Some(path) = self.resolve_entry_path(p, &entry.name) else { + return; }; let download_path = format!("{}.download", get_string(&path)); let digest_path = format!("{}.digest", get_string(&path)); @@ -1082,16 +1075,8 @@ impl TransferJob { async fn set_stream_offset(&mut self, file_num: usize, offset: u64) { if let DataSource::FilePath(p) = &self.data_source { let entry = &self.files[file_num]; - let path = if self.r#type == JobType::Generic { - match join_validated_path(p, &entry.name) { - Ok(path) => path, - Err(err) => { - log::error!("Invalid file name in transfer job {}: {}", self.id, err); - return; - } - } - } else { - Self::join(p, &entry.name) + let Some(path) = self.resolve_entry_path(p, &entry.name) else { + return; }; let file_path = get_string(&path); let download_path = format!("{}.download", &file_path); @@ -1529,13 +1514,12 @@ mod tests { 0, false, true, - Vec::new(), false, ) } - fn new_write_job(id: i32, download_dir: PathBuf, name: &str) -> TransferJob { - TransferJob::new_write( + fn new_write_job(id: i32, download_dir: PathBuf, name: &str) -> ResultType { + let job = TransferJob::new_write( id, JobType::Generic, "/fake/remote".to_string(), @@ -1543,17 +1527,10 @@ mod tests { 0, false, true, - vec![new_file_entry(name)], false, ) - } - - fn make_test_block(id: i32, payload: &[u8]) -> FileTransferBlock { - let mut block = FileTransferBlock::new(); - block.id = id; - block.file_num = 0; - block.data = payload.to_vec().into(); - block + .with_files(vec![new_file_entry(name)])?; + Ok(job) } fn assert_err_contains(err: anyhow::Error, expected: &str) { @@ -1565,17 +1542,13 @@ mod tests { ); } - #[tokio::test] - async fn path_traversal_e2e_write_rejects_relative_escape() { + #[test] + fn path_traversal_e2e_write_rejects_relative_escape() { let tmp_root = unique_temp_dir("rustdesk_e2e_relative"); let downloads = tmp_root.join("downloads"); std::fs::create_dir_all(&downloads).expect("create downloads dir"); - let mut job = new_write_job(1, downloads, "../traversal_proof.txt"); - let block = make_test_block(1, b"malicious payload"); - let err = job - .write(block) - .await + let err = new_write_job(1, downloads, "../traversal_proof.txt") .expect_err("relative path traversal must be rejected"); assert_err_contains(err, "path traversal"); assert!(!tmp_root.join("traversal_proof.txt").exists()); @@ -1583,18 +1556,14 @@ mod tests { let _ = std::fs::remove_dir_all(&tmp_root); } - #[tokio::test] - async fn path_traversal_e2e_write_rejects_absolute_path() { + #[test] + fn path_traversal_e2e_write_rejects_absolute_path() { let tmp_root = unique_temp_dir("rustdesk_e2e_absolute"); let downloads = tmp_root.join("downloads"); let absolute_target = tmp_root.join("fake_ssh").join("authorized_keys"); std::fs::create_dir_all(&downloads).expect("create downloads dir"); - let mut job = new_write_job(2, downloads, &absolute_target.to_string_lossy()); - let block = make_test_block(2, b"ssh key payload"); - let err = job - .write(block) - .await + let err = new_write_job(2, downloads, &absolute_target.to_string_lossy()) .expect_err("absolute path must be rejected"); assert_err_contains(err, "absolute path"); assert!(!absolute_target.exists()); @@ -1602,8 +1571,8 @@ mod tests { let _ = std::fs::remove_dir_all(&tmp_root); } - #[tokio::test] - async fn path_traversal_e2e_write_rejects_symlink_escape() { + #[test] + fn path_traversal_e2e_write_rejects_symlink_escape() { let tmp_root = unique_temp_dir("rustdesk_e2e_symlink"); let downloads = tmp_root.join("downloads"); let outside = tmp_root.join("outside"); @@ -1633,11 +1602,7 @@ mod tests { } } - let mut job = new_write_job(3, downloads, "link/escape.txt"); - let block = make_test_block(3, b"symlink escape payload"); - let err = job - .write(block) - .await + let err = new_write_job(3, downloads, "link/escape.txt") .expect_err("symlink traversal must be rejected"); assert_err_contains(err, "symlink"); assert!(!escaped_target.exists()); @@ -1744,7 +1709,6 @@ mod tests { 0, false, true, - Vec::new(), false, ); let err = job