From c907ad8fb2085ccc654cf04782f85410527d1540 Mon Sep 17 00:00:00 2001 From: fufesou Date: Wed, 26 Mar 2025 14:34:24 +0800 Subject: [PATCH] refact: fs, buf stream Signed-off-by: fufesou --- src/fs.rs | 295 ++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 209 insertions(+), 86 deletions(-) diff --git a/src/fs.rs b/src/fs.rs index 135e617..d30e7bf 100644 --- a/src/fs.rs +++ b/src/fs.rs @@ -1,6 +1,8 @@ #[cfg(windows)] use std::os::windows::prelude::*; use std::{ + fmt::{Debug, Display}, + io::Cursor, path::{Path, PathBuf}, sync::atomic::{AtomicI32, Ordering}, time::{Duration, SystemTime, UNIX_EPOCH}, @@ -8,7 +10,10 @@ use std::{ use serde_derive::{Deserialize, Serialize}; use serde_json::json; -use tokio::{fs::File, io::*}; +use tokio::{ + fs::File, + io::{AsyncReadExt, AsyncWriteExt, BufStream as TokioBufStream}, +}; use crate::{anyhow::anyhow, bail, get_version_number, message_proto::*, ResultType, Stream}; // https://doc.rust-lang.org/std/os/windows/fs/trait.MetadataExt.html @@ -301,13 +306,86 @@ impl JobType { } } +#[derive(Debug)] +pub enum DataSource { + FilePath(PathBuf), + MemoryCursor(Cursor>), +} + +impl Default for DataSource { + fn default() -> Self { + DataSource::FilePath(PathBuf::new()) + } +} + +impl serde::Serialize for DataSource { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + match self { + DataSource::FilePath(p) => serializer.serialize_str(p.to_str().unwrap_or("")), + DataSource::MemoryCursor(_) => serializer.serialize_str(""), + } + } +} + +impl Display for DataSource { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + DataSource::FilePath(p) => write!(f, "File: {}", p.to_string_lossy().to_string()), + DataSource::MemoryCursor(_) => write!(f, "Bytes"), + } + } +} + +impl DataSource { + fn to_meta(&self) -> String { + match self { + DataSource::FilePath(p) => p.to_string_lossy().to_string(), + DataSource::MemoryCursor(_) => "".to_string(), + } + } +} + +enum DataStream { + FileStream(File), + BufStream(TokioBufStream>>), +} + +impl Debug for DataStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + DataStream::FileStream(fs) => write!(f, "{:?}", fs), + DataStream::BufStream(_) => write!(f, "BufStream"), + } + } +} + +impl DataStream { + async fn write_all(&mut self, buf: &[u8]) -> ResultType<()> { + match self { + DataStream::FileStream(fs) => fs.write_all(buf).await?, + DataStream::BufStream(bs) => bs.write_all(buf).await?, + } + Ok(()) + } + + async fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + match self { + DataStream::FileStream(fs) => fs.read(buf).await, + DataStream::BufStream(bs) => bs.read(buf).await, + } + } +} + #[derive(Default, Serialize, Debug)] #[serde(rename_all = "camelCase")] pub struct TransferJob { pub id: i32, pub r#type: JobType, pub remote: String, - pub path: PathBuf, + pub data_source: DataSource, pub show_hidden: bool, pub is_remote: bool, pub is_last_job: bool, @@ -317,7 +395,7 @@ pub struct TransferJob { pub conn_id: i32, // server only #[serde(skip_serializing)] - file: Option, + data_stream: Option, pub total_size: u64, finished_size: u64, transferred: u64, @@ -376,20 +454,20 @@ impl TransferJob { id: i32, r#type: JobType, remote: String, - path: String, + data_source: DataSource, file_num: i32, show_hidden: bool, is_remote: bool, files: Vec, enable_overwrite_detection: bool, ) -> Self { - log::info!("new write {}", path); + log::info!("new write {}", data_source); let total_size = files.iter().map(|x| x.size).sum(); Self { id, r#type, remote, - path: get_path(&path), + data_source, file_num, show_hidden, is_remote, @@ -404,20 +482,27 @@ impl TransferJob { id: i32, r#type: JobType, remote: String, - path: String, + data_source: DataSource, file_num: i32, show_hidden: bool, is_remote: bool, enable_overwrite_detection: bool, ) -> ResultType { - log::info!("new read {}", path); - let files = get_recursive_files(&path, show_hidden)?; - let total_size = files.iter().map(|x| x.size).sum(); + log::info!("new read {}", data_source); + let (files, total_size) = match &data_source { + DataSource::FilePath(p) => { + let p = p.to_str().ok_or(anyhow!("Invalid path"))?; + let files = get_recursive_files(p, show_hidden)?; + let total_size = files.iter().map(|x| x.size).sum(); + (files, total_size) + } + DataSource::MemoryCursor(c) => (Vec::new(), c.get_ref().len() as u64), + }; Ok(Self { id, r#type, remote, - path: get_path(&path), + data_source, file_num, show_hidden, is_remote, @@ -428,6 +513,13 @@ impl TransferJob { }) } + pub fn get_buf_data(self) -> Option> { + match self.data_stream { + Some(DataStream::BufStream(bs)) => Some(bs.into_inner().into_inner()), + _ => None, + } + } + #[inline] pub fn files(&self) -> &Vec { &self.files @@ -467,17 +559,19 @@ impl TransferJob { if self.r#type == JobType::Printer { return; } - let file_num = self.file_num as usize; - if file_num < self.files.len() { - let entry = &self.files[file_num]; - let path = self.join(&entry.name); - let download_path = format!("{}.download", get_string(&path)); - std::fs::rename(download_path, &path).ok(); - filetime::set_file_mtime( - &path, - filetime::FileTime::from_unix_time(entry.modified_time as _, 0), - ) - .ok(); + if let DataSource::FilePath(p) = &self.data_source { + let file_num = self.file_num as usize; + if file_num < self.files.len() { + let entry = &self.files[file_num]; + let path = Self::join(p, &entry.name); + let download_path = format!("{}.download", get_string(&path)); + std::fs::rename(download_path, &path).ok(); + filetime::set_file_mtime( + &path, + filetime::FileTime::from_unix_time(entry.modified_time as _, 0), + ) + .ok(); + } } } @@ -485,12 +579,14 @@ impl TransferJob { if self.r#type == JobType::Printer { return; } - let file_num = self.file_num as usize; - if file_num < self.files.len() { - let entry = &self.files[file_num]; - let path = self.join(&entry.name); - let download_path = format!("{}.download", get_string(&path)); - std::fs::remove_file(download_path).ok(); + if let DataSource::FilePath(p) = &self.data_source { + let file_num = self.file_num as usize; + if file_num < self.files.len() { + let entry = &self.files[file_num]; + let path = Self::join(p, &entry.name); + let download_path = format!("{}.download", get_string(&path)); + std::fs::remove_file(download_path).ok(); + } } } @@ -498,38 +594,47 @@ impl TransferJob { if block.id != self.id { bail!("Wrong id"); } - let file_num = block.file_num as usize; - if file_num >= self.files.len() { - bail!("Wrong file number"); - } - if file_num != self.file_num as usize || self.file.is_none() { - self.modify_time(); - if let Some(file) = self.file.as_mut() { - file.sync_all().await?; - } - self.file_num = block.file_num; - let entry = &self.files[file_num]; - let path = if self.r#type == JobType::Printer { - self.path.to_string_lossy().to_string() - } else { - let path = self.join(&entry.name); - if let Some(p) = path.parent() { - std::fs::create_dir_all(p).ok(); + match &self.data_source { + DataSource::FilePath(p) => { + let file_num = block.file_num as usize; + if file_num >= self.files.len() { + bail!("Wrong file number"); } - format!("{}.download", get_string(&path)) - }; - self.file = Some(File::create(&path).await?); + if file_num != self.file_num as usize || self.data_stream.is_none() { + self.modify_time(); + if let Some(DataStream::FileStream(file)) = self.data_stream.as_mut() { + file.sync_all().await?; + } + self.file_num = block.file_num; + let entry = &self.files[file_num]; + let path = if self.r#type == JobType::Printer { + p.to_string_lossy().to_string() + } else { + let path = Self::join(p, &entry.name); + if let Some(pp) = path.parent() { + std::fs::create_dir_all(pp).ok(); + } + format!("{}.download", get_string(&path)) + }; + self.data_stream = Some(DataStream::FileStream(File::create(&path).await?)); + } + } + DataSource::MemoryCursor(c) => { + if self.data_stream.is_none() { + self.data_stream = Some(DataStream::BufStream(TokioBufStream::new(c.clone()))); + } + } } if block.compressed { let tmp = decompress(&block.data); - self.file + self.data_stream .as_mut() - .ok_or(anyhow!("file is None"))? + .ok_or(anyhow!("data stream is None"))? .write_all(&tmp) .await?; self.finished_size += tmp.len() as u64; } else { - self.file + self.data_stream .as_mut() .ok_or(anyhow!("file is None"))? .write_all(&block.data) @@ -541,33 +646,46 @@ impl TransferJob { } #[inline] - pub fn join(&self, name: &str) -> PathBuf { + pub fn join(p: &PathBuf, name: &str) -> PathBuf { if name.is_empty() { - self.path.clone() + p.clone() } else { - self.path.join(name) + p.join(name) } } pub async fn read(&mut self, stream: &mut Stream) -> ResultType> { let file_num = self.file_num as usize; - if file_num >= self.files.len() { - self.file.take(); - return Ok(None); - } - let name = &self.files[file_num].name; - if self.file.is_none() { - match File::open(self.join(name)).await { - Ok(file) => { - self.file = Some(file); - self.file_confirmed = false; - self.file_is_waiting = false; + let name: &str; + match &mut self.data_source { + DataSource::FilePath(p) => { + if file_num >= self.files.len() { + self.data_stream.take(); + return Ok(None); + }; + name = &self.files[file_num].name; + if self.data_stream.is_none() { + match File::open(Self::join(p, name)).await { + Ok(file) => { + self.data_stream = Some(DataStream::FileStream(file)); + self.file_confirmed = false; + self.file_is_waiting = false; + } + Err(err) => { + self.file_num += 1; + self.file_confirmed = false; + self.file_is_waiting = false; + return Err(err.into()); + } + } } - Err(err) => { - self.file_num += 1; - self.file_confirmed = false; - self.file_is_waiting = false; - return Err(err.into()); + } + DataSource::MemoryCursor(c) => { + name = ""; + if self.data_stream.is_none() { + let mut t = std::io::Cursor::new(Vec::new()); + std::mem::swap(&mut t, c); + self.data_stream = Some(DataStream::BufStream(TokioBufStream::new(t))); } } } @@ -586,15 +704,15 @@ impl TransferJob { let mut offset: usize = 0; loop { match self - .file + .data_stream .as_mut() - .ok_or(anyhow!("file is None"))? + .ok_or(anyhow!("data stream is None"))? .read(&mut buf[offset..]) .await { Err(err) => { self.file_num += 1; - self.file = None; + self.data_stream = None; self.file_confirmed = false; self.file_is_waiting = false; return Err(err.into()); @@ -609,13 +727,17 @@ impl TransferJob { } unsafe { buf.set_len(offset) }; if offset == 0 { + if matches!(self.data_source, DataSource::MemoryCursor(_)) { + self.data_stream.take(); + return Ok(None); + } self.file_num += 1; - self.file = None; + self.data_stream = None; self.file_confirmed = false; self.file_is_waiting = false; } else { self.finished_size += offset as u64; - if !is_compressed_file(name) { + if matches!(self.data_source, DataSource::FilePath(_)) && !is_compressed_file(name) { let tmp = compress(&buf); if tmp.len() < buf.len() { buf = tmp; @@ -633,15 +755,14 @@ impl TransferJob { })) } + // Only for generic job and file stream async fn send_current_digest(&mut self, stream: &mut Stream) -> ResultType<()> { let mut msg = Message::new(); let mut resp = FileResponse::new(); - let meta = self - .file - .as_ref() - .ok_or(anyhow!("file is None"))? - .metadata() - .await?; + let meta = match self.data_stream.as_ref().ok_or(anyhow!("file is None"))? { + DataStream::FileStream(file) => file.metadata().await?, + DataStream::BufStream(_) => bail!("No need to send digest for buf stream"), + }; let last_modified = meta .modified()? .duration_since(SystemTime::UNIX_EPOCH)? @@ -727,7 +848,7 @@ impl TransferJob { pub fn set_file_skipped(&mut self) -> bool { log::debug!("skip file {} in job {}", self.file_num, self.id); - self.file.take(); + self.data_stream.take(); self.set_file_confirmed(false); self.set_file_is_waiting(false); self.file_num += 1; @@ -761,7 +882,7 @@ impl TransferJob { TransferJobMeta { id: self.id, remote: self.remote.to_string(), - to: self.path.to_string_lossy().to_string(), + to: self.data_source.to_meta(), file_num: self.file_num, show_hidden: self.show_hidden, is_remote: self.is_remote, @@ -875,8 +996,10 @@ pub fn new_done(id: i32, file_num: i32) -> Message { } #[inline] -pub fn remove_job(id: i32, jobs: &mut Vec) { - *jobs = jobs.drain(0..).filter(|x| x.id() != id).collect(); +pub fn remove_job(id: i32, jobs: &mut Vec) -> Option { + jobs.iter() + .position(|x| x.id() == id) + .map(|index| jobs.remove(index)) } #[inline] @@ -928,7 +1051,7 @@ pub async fn handle_read_jobs( } } for id in finished { - remove_job(id, jobs); + let _ = remove_job(id, jobs); } Ok(job_log) }