From 83530a6c15d8fd3d6f1e617ffd4e67baab3ad427 Mon Sep 17 00:00:00 2001 From: DrSocalkwe3n Date: Tue, 6 Jan 2026 18:20:58 +0600 Subject: [PATCH] =?UTF-8?q?codex-5.2:=20=D0=BD=D0=BE=D0=B2=D0=B0=D1=8F=20?= =?UTF-8?q?=D0=BC=D0=BE=D0=B4=D0=B5=D0=BB=D1=8C=20=D0=BE=D0=B1=D1=8A=D0=B5?= =?UTF-8?q?=D0=BA=D1=82=D0=BE=D0=B2=20=D0=B2=D0=B7=D0=B0=D0=B8=D0=BC=D0=BE?= =?UTF-8?q?=D0=B4=D0=B5=D0=B9=D1=81=D1=82=D0=B2=D0=B8=D1=8F=20=D1=81=20?= =?UTF-8?q?=D1=81=D0=B5=D1=82=D1=8C=D1=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Src/Common/Net2.cpp | 483 ++++++++++++++++++++++++++++++++++++++++++++ Src/Common/Net2.hpp | 227 +++++++++++++++++++++ 2 files changed, 710 insertions(+) create mode 100644 Src/Common/Net2.cpp create mode 100644 Src/Common/Net2.hpp diff --git a/Src/Common/Net2.cpp b/Src/Common/Net2.cpp new file mode 100644 index 0000000..8eb2ab1 --- /dev/null +++ b/Src/Common/Net2.cpp @@ -0,0 +1,483 @@ +#include "Net2.hpp" + +#include +#include +#include +#include +#include +#include +#include + +namespace LV::Net2 { + +using namespace TOS; + +namespace { + +struct HeaderFields { + uint32_t size = 0; + uint16_t type = 0; + Priority priority = Priority::Normal; + FrameFlags flags = FrameFlags::None; + uint32_t streamId = 0; +}; + +std::array encodeHeader(const HeaderFields &h) { + std::array out{}; + uint32_t sizeNet = detail::toNetwork(h.size); + uint16_t typeNet = detail::toNetwork(h.type); + uint32_t streamNet = detail::toNetwork(h.streamId); + + std::memcpy(out.data(), &sizeNet, sizeof(sizeNet)); + std::memcpy(out.data() + 4, &typeNet, sizeof(typeNet)); + out[6] = std::byte(static_cast(h.priority)); + out[7] = std::byte(static_cast(h.flags)); + std::memcpy(out.data() + 8, &streamNet, sizeof(streamNet)); + return out; +} + +HeaderFields decodeHeader(const std::array &in) { + HeaderFields h{}; + std::memcpy(&h.size, in.data(), sizeof(h.size)); + std::memcpy(&h.type, in.data() + 4, sizeof(h.type)); + h.priority = static_cast(std::to_integer(in[6])); + h.flags = static_cast(std::to_integer(in[7])); + std::memcpy(&h.streamId, in.data() + 8, sizeof(h.streamId)); + + h.size = detail::fromNetwork(h.size); + h.type = detail::fromNetwork(h.type); + h.streamId = detail::fromNetwork(h.streamId); + return h; +} + +} // namespace + +PacketWriter& PacketWriter::writeBytes(std::span data) { + Buffer.insert(Buffer.end(), data.begin(), data.end()); + return *this; +} + +PacketWriter& PacketWriter::writeString(std::string_view str) { + write(static_cast(str.size())); + auto bytes = std::as_bytes(std::span(str.data(), str.size())); + Buffer.insert(Buffer.end(), bytes.begin(), bytes.end()); + return *this; +} + +std::vector PacketWriter::release() { + std::vector out = std::move(Buffer); + Buffer.clear(); + return out; +} + +void PacketWriter::clear() { + Buffer.clear(); +} + +PacketReader::PacketReader(std::span data) + : Data(data) +{ +} + +void PacketReader::readBytes(std::span out) { + require(out.size()); + std::memcpy(out.data(), Data.data() + Pos, out.size()); + Pos += out.size(); +} + +std::string PacketReader::readString() { + uint32_t size = read(); + require(size); + std::string out(size, '\0'); + std::memcpy(out.data(), Data.data() + Pos, size); + Pos += size; + return out; +} + +void PacketReader::require(size_t size) { + if(Data.size() - Pos < size) + MAKE_ERROR("Net2::PacketReader: not enough data"); +} + +SocketServer::SocketServer(asio::io_context &ioc, std::function(tcp::socket)> &&onConnect, uint16_t port) + : AsyncObject(ioc), Acceptor(ioc, tcp::endpoint(tcp::v4(), port)) +{ + assert(onConnect); + co_spawn(run(std::move(onConnect))); +} + +bool SocketServer::isStopped() const { + return !Acceptor.is_open(); +} + +uint16_t SocketServer::getPort() const { + return Acceptor.local_endpoint().port(); +} + +coro SocketServer::run(std::function(tcp::socket)> onConnect) { + while(true) { + try { + co_spawn(onConnect(co_await Acceptor.async_accept())); + } catch(const std::exception &exc) { + if(const boost::system::system_error *errc = dynamic_cast(&exc); + errc && (errc->code() == asio::error::operation_aborted || errc->code() == asio::error::bad_descriptor)) + break; + } + } +} + +AsyncSocket::SendQueue::SendQueue(asio::io_context &ioc) + : semaphore(ioc) +{ + semaphore.expires_at(std::chrono::steady_clock::time_point::max()); +} + +bool AsyncSocket::SendQueue::empty() const { + for(const auto &queue : queues) { + if(!queue.empty()) + return false; + } + return true; +} + +AsyncSocket::AsyncSocket(asio::io_context &ioc, tcp::socket &&socket, Limits limits) + : AsyncObject(ioc), LimitsCfg(limits), Socket(std::move(socket)), Outgoing(ioc) +{ + Context = std::make_shared(); + + boost::asio::socket_base::linger optionLinger(true, 4); + Socket.set_option(optionLinger); + boost::asio::ip::tcp::no_delay optionNoDelay(true); + Socket.set_option(optionNoDelay); + + co_spawn(sendLoop()); +} + +AsyncSocket::~AsyncSocket() { + if(Context) + Context->needShutdown.store(true); + + { + boost::lock_guard lock(Outgoing.mtx); + Outgoing.semaphore.cancel(); + WorkDeadline.cancel(); + } + + if(Socket.is_open()) + try { Socket.close(); } catch(...) {} +} + +void AsyncSocket::enqueue(OutgoingMessage &&msg) { + if(msg.payload.size() > LimitsCfg.maxMessageSize) { + setError("Net2::AsyncSocket: message too large"); + close(); + return; + } + + boost::unique_lock lock(Outgoing.mtx); + const size_t msgSize = msg.payload.size(); + const size_t lowIndex = static_cast(Priority::Low); + + if(msg.priority == Priority::Low) { + while(Outgoing.bytesInLow + msgSize > LimitsCfg.maxLowPriorityBytes && !Outgoing.queues[lowIndex].empty()) { + Outgoing.bytesInQueue -= Outgoing.queues[lowIndex].front().payload.size(); + Outgoing.bytesInLow -= Outgoing.queues[lowIndex].front().payload.size(); + Outgoing.queues[lowIndex].pop_front(); + } + if(Outgoing.bytesInLow + msgSize > LimitsCfg.maxLowPriorityBytes) { + return; + } + } + + if(Outgoing.bytesInQueue + msgSize > LimitsCfg.maxQueueBytes) { + dropLow(msgSize); + if(Outgoing.bytesInQueue + msgSize > LimitsCfg.maxQueueBytes) { + if(msg.dropIfOverloaded) + return; + setError("Net2::AsyncSocket: send queue overflow"); + close(); + return; + } + } + + const size_t idx = static_cast(msg.priority); + Outgoing.bytesInQueue += msgSize; + if(msg.priority == Priority::Low) + Outgoing.bytesInLow += msgSize; + Outgoing.queues[idx].push_back(std::move(msg)); + + if(Outgoing.waiting) { + Outgoing.waiting = false; + Outgoing.semaphore.cancel(); + Outgoing.semaphore.expires_at(std::chrono::steady_clock::time_point::max()); + } +} + +coro AsyncSocket::readMessage() { + while(true) { + std::array headerBytes{}; + co_await readExact(headerBytes.data(), headerBytes.size()); + HeaderFields header = decodeHeader(headerBytes); + + if(header.size > LimitsCfg.maxFrameSize) + MAKE_ERROR("Net2::AsyncSocket: frame too large"); + + std::vector chunk(header.size); + if(header.size) + co_await readExact(chunk.data(), chunk.size()); + + if(header.streamId != 0) { + if(Fragments.size() >= LimitsCfg.maxOpenStreams && !Fragments.contains(header.streamId)) + MAKE_ERROR("Net2::AsyncSocket: too many open streams"); + + FragmentState &state = Fragments[header.streamId]; + if(state.data.empty()) { + state.type = header.type; + state.priority = header.priority; + } + + if(state.data.size() + chunk.size() > LimitsCfg.maxMessageSize) + MAKE_ERROR("Net2::AsyncSocket: reassembled message too large"); + + state.data.insert(state.data.end(), chunk.begin(), chunk.end()); + + if(!hasFlag(header.flags, FrameFlags::HasMore)) { + IncomingMessage msg{state.type, state.priority, std::move(state.data)}; + Fragments.erase(header.streamId); + co_return msg; + } + + continue; + } + + if(hasFlag(header.flags, FrameFlags::HasMore)) + MAKE_ERROR("Net2::AsyncSocket: stream id missing for fragmented frame"); + + IncomingMessage msg{header.type, header.priority, std::move(chunk)}; + co_return msg; + } +} + +coro<> AsyncSocket::readLoop(std::function(IncomingMessage&&)> onMessage) { + while(isAlive()) { + IncomingMessage msg = co_await readMessage(); + co_await onMessage(std::move(msg)); + } +} + +void AsyncSocket::closeRead() { + if(Socket.is_open() && !Context->readClosed.exchange(true)) { + try { Socket.shutdown(boost::asio::socket_base::shutdown_receive); } catch(...) {} + } +} + +void AsyncSocket::close() { + if(Context) + Context->needShutdown.store(true); + if(Socket.is_open()) + try { Socket.close(); } catch(...) {} +} + +bool AsyncSocket::isAlive() const { + return Context && !Context->needShutdown.load() && !Context->senderStopped.load() && Socket.is_open(); +} + +std::string AsyncSocket::getError() const { + boost::lock_guard lock(Context->errorMtx); + return Context->error; +} + +coro<> AsyncSocket::sendLoop() { + try { + while(!Context->needShutdown.load()) { + OutgoingMessage msg; + { + boost::unique_lock lock(Outgoing.mtx); + if(Outgoing.empty()) { + Outgoing.waiting = true; + auto coroutine = Outgoing.semaphore.async_wait(); + lock.unlock(); + try { co_await std::move(coroutine); } catch(...) {} + continue; + } + + if(!popNext(msg)) + continue; + } + + co_await sendMessage(std::move(msg)); + } + } catch(const std::exception &exc) { + setError(exc.what()); + } catch(...) { + setError("Net2::AsyncSocket: send loop stopped"); + } + + Context->senderStopped.store(true); +} + +coro<> AsyncSocket::sendMessage(OutgoingMessage &&msg) { + const size_t total = msg.payload.size(); + if(total <= LimitsCfg.maxFrameSize) { + co_await sendFrame(msg.type, msg.priority, FrameFlags::None, 0, msg.payload); + co_return; + } + + if(!msg.allowFragment) { + setError("Net2::AsyncSocket: message requires fragmentation"); + close(); + co_return; + } + + uint32_t streamId = NextStreamId++; + if(streamId == 0) + streamId = NextStreamId++; + + size_t offset = 0; + while(offset < total) { + const size_t chunk = std::min(LimitsCfg.maxFrameSize, total - offset); + const bool more = (offset + chunk) < total; + FrameFlags flags = more ? FrameFlags::HasMore : FrameFlags::None; + std::span view(msg.payload.data() + offset, chunk); + co_await sendFrame(msg.type, msg.priority, flags, streamId, view); + offset += chunk; + } +} + +coro<> AsyncSocket::sendFrame(uint16_t type, Priority priority, FrameFlags flags, uint32_t streamId, + std::span payload) { + HeaderFields header{ + .size = static_cast(payload.size()), + .type = type, + .priority = priority, + .flags = flags, + .streamId = streamId + }; + auto headerBytes = encodeHeader(header); + std::array buffers{ + asio::buffer(headerBytes), + asio::buffer(payload.data(), payload.size()) + }; + if(payload.empty()) + co_await asio::async_write(Socket, asio::buffer(headerBytes)); + else + co_await asio::async_write(Socket, buffers); +} + +coro<> AsyncSocket::readExact(std::byte *data, size_t size) { + if(size == 0) + co_return; + co_await asio::async_read(Socket, asio::buffer(data, size)); +} + +bool AsyncSocket::popNext(OutgoingMessage &out) { + static constexpr int kWeights[4] = {8, 4, 2, 1}; + + for(int attempt = 0; attempt < 4; ++attempt) { + const uint8_t idx = static_cast((Outgoing.nextIndex + attempt) % 4); + auto &queue = Outgoing.queues[idx]; + if(queue.empty()) + continue; + + if(Outgoing.credits[idx] <= 0) + Outgoing.credits[idx] = kWeights[idx]; + + if(Outgoing.credits[idx] <= 0) + continue; + + out = std::move(queue.front()); + queue.pop_front(); + Outgoing.credits[idx]--; + Outgoing.nextIndex = idx; + + const size_t msgSize = out.payload.size(); + Outgoing.bytesInQueue -= msgSize; + if(idx == static_cast(Priority::Low)) + Outgoing.bytesInLow -= msgSize; + return true; + } + + for(int i = 0; i < 4; ++i) + Outgoing.credits[i] = kWeights[i]; + return false; +} + +void AsyncSocket::dropLow(size_t needBytes) { + const size_t lowIndex = static_cast(Priority::Low); + while(Outgoing.bytesInQueue + needBytes > LimitsCfg.maxQueueBytes && !Outgoing.queues[lowIndex].empty()) { + const size_t size = Outgoing.queues[lowIndex].front().payload.size(); + Outgoing.bytesInQueue -= size; + Outgoing.bytesInLow -= size; + Outgoing.queues[lowIndex].pop_front(); + } +} + +void AsyncSocket::setError(const std::string &msg) { + if(!Context) + return; + boost::lock_guard lock(Context->errorMtx); + Context->error = msg; +} + +coro asyncConnectTo(const std::string &address, + std::function onProgress) { + std::string progress; + auto addLog = [&](const std::string &msg) { + progress += '\n'; + progress += msg; + if(onProgress) + onProgress('\n' + msg); + }; + + auto ioc = co_await asio::this_coro::executor; + + addLog("Parsing address " + address); + auto re = Str::match(address, "((?:\\[[\\d\\w:]+\\])|(?:[\\d\\.]+))(?:\\:(\\d+))?"); + + std::vector> eps; + + if(!re) { + re = Str::match(address, "([-_\\.\\w\\d]+)(?:\\:(\\d+))?"); + if(!re) + MAKE_ERROR("Failed to parse address"); + + tcp::resolver resv{ioc}; + tcp::resolver::results_type result; + + addLog("Resolving name..."); + result = co_await resv.async_resolve(*re->at(1), re->at(2) ? *re->at(2) : "7890"); + + addLog("Got " + std::to_string(result.size()) + " endpoints"); + for(auto iter : result) { + std::string addr = iter.endpoint().address().to_string() + ':' + std::to_string(iter.endpoint().port()); + std::string hostname = iter.host_name(); + if(hostname == addr) + addLog("ep: " + addr); + else + addLog("ep: " + hostname + " (" + addr + ')'); + + eps.emplace_back(iter.endpoint(), iter.host_name()); + } + } else { + eps.emplace_back(tcp::endpoint{asio::ip::make_address(*re->at(1)), + static_cast(re->at(2) ? Str::toVal(*re->at(2)) : 7890)}, + *re->at(1)); + } + + for(auto [ep, hostname] : eps) { + addLog("Connecting to " + hostname + " (" + ep.address().to_string() + ':' + + std::to_string(ep.port()) + ")"); + try { + tcp::socket sock{ioc}; + co_await sock.async_connect(ep); + addLog("Connected"); + co_return sock; + } catch(const std::exception &exc) { + addLog(std::string("Connect failed: ") + exc.what()); + } + } + + MAKE_ERROR("Unable to connect to server"); +} + +} // namespace LV::Net2 diff --git a/Src/Common/Net2.hpp b/Src/Common/Net2.hpp new file mode 100644 index 0000000..8934d14 --- /dev/null +++ b/Src/Common/Net2.hpp @@ -0,0 +1,227 @@ +#pragma once + +#include "Async.hpp" +#include "TOSLib.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace LV::Net2 { + +namespace detail { + +constexpr bool kLittleEndian = (std::endian::native == std::endian::little); + +template +requires std::is_integral_v +inline T toNetwork(T value) { + if constexpr (kLittleEndian && sizeof(T) > 1) + return std::byteswap(value); + return value; +} + +template +requires std::is_floating_point_v +inline T toNetwork(T value) { + using U = std::conditional_t; + U u = std::bit_cast(value); + u = toNetwork(u); + return std::bit_cast(u); +} + +template +inline T fromNetwork(T value) { + return toNetwork(value); +} + +} // namespace detail + +enum class Priority : uint8_t { + Realtime = 0, + High = 1, + Normal = 2, + Low = 3 +}; + +enum class FrameFlags : uint8_t { + None = 0, + HasMore = 1 +}; + +inline FrameFlags operator|(FrameFlags a, FrameFlags b) { + return static_cast(static_cast(a) | static_cast(b)); +} + +inline bool hasFlag(FrameFlags value, FrameFlags flag) { + return (static_cast(value) & static_cast(flag)) != 0; +} + +struct Limits { + size_t maxFrameSize = 1 << 24; + size_t maxMessageSize = 1 << 26; + size_t maxQueueBytes = 1 << 27; + size_t maxLowPriorityBytes = 1 << 26; + size_t maxOpenStreams = 64; +}; + +struct OutgoingMessage { + uint16_t type = 0; + Priority priority = Priority::Normal; + bool dropIfOverloaded = false; + bool allowFragment = true; + std::vector payload; +}; + +struct IncomingMessage { + uint16_t type = 0; + Priority priority = Priority::Normal; + std::vector payload; +}; + +class PacketWriter { +public: + PacketWriter& writeBytes(std::span data); + + template + requires (std::is_integral_v || std::is_floating_point_v) + PacketWriter& write(T value) { + T net = detail::toNetwork(value); + std::array bytes{}; + std::memcpy(bytes.data(), &net, sizeof(T)); + Buffer.insert(Buffer.end(), bytes.begin(), bytes.end()); + return *this; + } + + PacketWriter& writeString(std::string_view str); + + const std::vector& data() const { return Buffer; } + std::vector release(); + void clear(); + +private: + std::vector Buffer; +}; + +class PacketReader { +public: + explicit PacketReader(std::span data); + + template + requires (std::is_integral_v || std::is_floating_point_v) + T read() { + require(sizeof(T)); + T net{}; + std::memcpy(&net, Data.data() + Pos, sizeof(T)); + Pos += sizeof(T); + return detail::fromNetwork(net); + } + + void readBytes(std::span out); + std::string readString(); + bool empty() const { return Pos >= Data.size(); } + size_t remaining() const { return Data.size() - Pos; } + +private: + void require(size_t size); + + size_t Pos = 0; + std::span Data; +}; + +class SocketServer : public AsyncObject { +public: + SocketServer(asio::io_context &ioc, std::function(tcp::socket)> &&onConnect, uint16_t port = 0); + bool isStopped() const; + uint16_t getPort() const; + +private: + coro run(std::function(tcp::socket)> onConnect); + + tcp::acceptor Acceptor; +}; + +class AsyncSocket : public AsyncObject { +public: + static constexpr size_t kHeaderSize = 12; + + AsyncSocket(asio::io_context &ioc, tcp::socket &&socket, Limits limits = {}); + ~AsyncSocket(); + + void enqueue(OutgoingMessage &&msg); + coro readMessage(); + coro<> readLoop(std::function(IncomingMessage&&)> onMessage); + + void closeRead(); + void close(); + bool isAlive() const; + std::string getError() const; + +private: + struct FragmentState { + uint16_t type = 0; + Priority priority = Priority::Normal; + std::vector data; + }; + + struct AsyncContext { + std::atomic_bool needShutdown{false}; + std::atomic_bool senderStopped{false}; + std::atomic_bool readClosed{false}; + boost::mutex errorMtx; + std::string error; + }; + + struct SendQueue { + boost::mutex mtx; + bool waiting = false; + asio::steady_timer semaphore; + std::deque queues[4]; + size_t bytesInQueue = 0; + size_t bytesInLow = 0; + uint8_t nextIndex = 0; + int credits[4] = {8, 4, 2, 1}; + + explicit SendQueue(asio::io_context &ioc); + bool empty() const; + }; + + coro<> sendLoop(); + coro<> sendMessage(OutgoingMessage &&msg); + coro<> sendFrame(uint16_t type, Priority priority, FrameFlags flags, uint32_t streamId, + std::span payload); + + coro<> readExact(std::byte *data, size_t size); + + bool popNext(OutgoingMessage &out); + void dropLow(size_t needBytes); + void setError(const std::string &msg); + + Limits LimitsCfg; + tcp::socket Socket; + SendQueue Outgoing; + std::shared_ptr Context; + std::unordered_map Fragments; + uint32_t NextStreamId = 1; +}; + +coro asyncConnectTo(const std::string &address, + std::function onProgress = nullptr); + +} // namespace LV::Net2