diff --git a/include/Common.h b/include/Common.h index abf758a..313df2b 100644 --- a/include/Common.h +++ b/include/Common.h @@ -135,8 +135,9 @@ std::string ThreadName(bool DebugModeOverride = false); void RegisterThread(const std::string& str); #define RegisterThreadAuto() RegisterThread(__func__) -#define KB 1024 -#define MB (KB * 1024) +#define KB 1024llu +#define MB (KB * 1024llu) +#define GB (MB * 1024llu) #define SSU_UNRAW SECRET_SENTRY_URL #define _file_basename std::filesystem::path(__FILE__).filename().string() diff --git a/include/TNetwork.h b/include/TNetwork.h index 528aef4..87e4235 100644 --- a/include/TNetwork.h +++ b/include/TNetwork.h @@ -48,4 +48,5 @@ private: void SendFile(TClient& c, const std::string& Name); static bool TCPSendRaw(TClient& C, SOCKET socket, char* Data, int32_t Size); static void SplitLoad(TClient& c, size_t Sent, size_t Size, bool D, const std::string& Name); + static uint8_t* SendSplit(TClient& c, SOCKET Socket, uint8_t* DataPtr, size_t Size); }; diff --git a/src/TNetwork.cpp b/src/TNetwork.cpp index dfa8776..9c0e748 100644 --- a/src/TNetwork.cpp +++ b/src/TNetwork.cpp @@ -761,14 +761,63 @@ void TNetwork::SendFile(TClient& c, const std::string& UnsafeName) { } } +static std::pair SplitIntoChunks(size_t FullSize, size_t ChunkSize) { + if (FullSize < ChunkSize) { + return { 0, FullSize }; + } + size_t Count = FullSize / (FullSize / ChunkSize); + size_t LastChunkSize = FullSize - (Count * ChunkSize); + return { Count, LastChunkSize }; +} + +TEST_CASE("SplitIntoChunks") { + size_t FullSize; + size_t ChunkSize; + SUBCASE("Normal case") { + FullSize = 1234567; + ChunkSize = 1234; + } + SUBCASE("Zero original size") { + FullSize = 0; + ChunkSize = 100; + } + SUBCASE("Equal full size and chunk size") { + FullSize = 125; + ChunkSize = 125; + } + SUBCASE("Even split") { + FullSize = 10000; + ChunkSize = 100; + } + SUBCASE("Odd split") { + FullSize = 13; + ChunkSize = 2; + } + SUBCASE("Large sizes") { + FullSize = 10 * GB; + ChunkSize = 125 * MB; + } + auto [Count, LastSize] = SplitIntoChunks(FullSize, ChunkSize); + CHECK((Count * ChunkSize) + LastSize == FullSize); +} + +uint8_t* /* end ptr */ TNetwork::SendSplit(TClient& c, SOCKET Socket, uint8_t* DataPtr, size_t Size) { + if (TCPSendRaw(c, Socket, reinterpret_cast(DataPtr), Size)) { + return DataPtr + Size; + } else { + return nullptr; + } +} + void TNetwork::SplitLoad(TClient& c, size_t Sent, size_t Size, bool D, const std::string& Name) { std::ifstream f(Name.c_str(), std::ios::binary); + auto Buf = f.rdbuf(); uint32_t Split = 125 * MB; - char* Data; + std::vector Data; if (Size > Split) - Data = new char[Split]; + Data.resize(Split); else - Data = new char[Size]; + Data.resize(Size); SOCKET TCPSock; if (D) TCPSock = c.GetDownSock(); @@ -779,8 +828,8 @@ void TNetwork::SplitLoad(TClient& c, size_t Sent, size_t Size, bool D, const std size_t Diff = Size - Sent; if (Diff > Split) { f.seekg(Sent, std::ios_base::beg); - f.read(Data, Split); - if (!TCPSendRaw(c, TCPSock, Data, Split)) { + f.read(reinterpret_cast(Data.data()), Split); + if (!TCPSendRaw(c, TCPSock, reinterpret_cast(Data.data()), Split)) { if (c.GetStatus() > -1) c.SetStatus(-1); break; @@ -788,8 +837,8 @@ void TNetwork::SplitLoad(TClient& c, size_t Sent, size_t Size, bool D, const std Sent += Split; } else { f.seekg(Sent, std::ios_base::beg); - f.read(Data, Diff); - if (!TCPSendRaw(c, TCPSock, Data, int32_t(Diff))) { + f.read(reinterpret_cast(Data.data()), Diff); + if (!TCPSendRaw(c, TCPSock, reinterpret_cast(Data.data()), int32_t(Diff))) { if (c.GetStatus() > -1) c.SetStatus(-1); break; @@ -797,8 +846,6 @@ void TNetwork::SplitLoad(TClient& c, size_t Sent, size_t Size, bool D, const std Sent += Diff; } } - delete[] Data; - f.close(); } bool TNetwork::TCPSendRaw(TClient& C, SOCKET socket, char* Data, int32_t Size) {