File tcp_transport.hpp

File List > astutedds > rtps > tcp_transport.hpp

Go to the documentation of this file

//
// Copyright (c) 2026, Astute Systems PTY LTD
//
// This file is part of the Astute DDS developed by Astute Systems.
//
// See the commercial LICENSE file in the project root for full license details.
//
// @file tcp_transport.hpp
// @brief RTPS over TCP transport
//
// Implements TCP transport for firewall traversal and reliable delivery.
// Reference: DDSI-RTPS 2.5 Section 9.6 (Transport Layer)
//

#ifndef ASTUTEDDS_RTPS_TCP_TRANSPORT_HPP
#define ASTUTEDDS_RTPS_TCP_TRANSPORT_HPP

#include <astutedds/platform.hpp>
#include <astutedds/rtps/rtps_types.hpp>

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <deque>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

namespace astutedds::rtps
{

// ============================================================================
// TCP Locator Extension
// ============================================================================

/// Locator kind value used by this transport for TCP endpoints with an IPv4
/// address. `inline` (C++17) gives the constant a single definition shared by
/// every translation unit that includes this header, instead of one
/// internal-linkage copy per translation unit.
inline constexpr int32_t LOCATOR_KIND_TCPv4 = 4;
/// Locator kind value used by this transport for TCP endpoints with an IPv6
/// address.
/// NOTE(review): other RTPS implementations (e.g. Fast DDS) use 8 for TCPv6;
/// confirm 5 is intended before relying on cross-vendor interoperability.
inline constexpr int32_t LOCATOR_KIND_TCPv6 = 5;

// ============================================================================
// TCP Connection State
// ============================================================================

/// Lifecycle state of a TCPConnection, as reported by TCPConnection::state().
/// Values are spelled out explicitly; they match the implicit 0..4 numbering.
enum class TCPConnectionState
{
    DISCONNECTED = 0,  ///< Initial state (TCPConnection::state_ starts here).
    CONNECTING = 1,    ///< Connection attempt presumably in progress — transitions live in the implementation.
    CONNECTED = 2,     ///< The only state for which TCPConnection::is_connected() returns true.
    CLOSING = 3,       ///< Presumably shutdown in progress — confirm in the implementation.
    ERRORED = 4        ///< Failure state; not named ERROR because that clashes with the Windows ERROR macro.
};

/// Role a TCPConnection was created in (passed to its constructor).
enum class TCPConnectionMode
{
    CLIENT = 0,  ///< Presumably an outbound connection opened via TCPConnection::connect() — confirm.
    SERVER = 1   ///< Presumably an accepted socket handed over via TCPConnection::adopt_socket() — confirm.
};

// ============================================================================
// TCP Message Framing
// ============================================================================

/// 4-byte length prefix used to frame messages on the TCP byte stream.
///
/// The 32-bit length is encoded little-endian (least-significant byte first).
/// Only the prefix is handled here; the payload bytes that follow are not touched.
struct TCPMessageHeader
{
    /// Length value carried by the prefix. Presumably the payload size in bytes,
    /// excluding this header — confirm against the framing code in the implementation.
    uint32_t message_length{0};

    /// Encoded size of the header in bytes.
    static constexpr size_t SIZE = 4;

    /// Append the SIZE-byte little-endian encoding of message_length to @p buffer.
    /// Existing contents of @p buffer are preserved.
    void serialize(std::vector<uint8_t>& buffer) const
    {
        buffer.push_back(static_cast<uint8_t>(message_length & 0xFFu));
        buffer.push_back(static_cast<uint8_t>((message_length >> 8) & 0xFFu));
        buffer.push_back(static_cast<uint8_t>((message_length >> 16) & 0xFFu));
        buffer.push_back(static_cast<uint8_t>((message_length >> 24) & 0xFFu));
    }

    /// Decode message_length from the first SIZE bytes of @p data.
    /// @param data  Start of the encoded header; may be null only if @p len < SIZE.
    /// @param len   Number of readable bytes at @p data.
    /// @return false (leaving message_length unchanged) when fewer than SIZE bytes
    ///         are available or @p data is null; true otherwise.
    bool deserialize(const uint8_t* data, size_t len)
    {
        if (data == nullptr || len < SIZE)
            return false;
        // Widen each byte to uint32_t *before* shifting. A bare uint8_t promotes to
        // (signed) int, and `data[3] << 24` would then shift into the sign bit
        // whenever data[3] >= 0x80.
        message_length = static_cast<uint32_t>(data[0]) | (static_cast<uint32_t>(data[1]) << 8) |
                         (static_cast<uint32_t>(data[2]) << 16) | (static_cast<uint32_t>(data[3]) << 24);
        return true;
    }
};

// ============================================================================
// TCP Connection
// ============================================================================

/// Handler for inbound data: receives one buffer of received bytes.
/// Presumably one de-framed RTPS message per call — confirm in the implementation.
using TCPReceiveCallback = std::function<void(const std::vector<uint8_t>&)>;

/// Handler for connection state changes: receives the new TCPConnectionState.
using TCPStateCallback = std::function<void(TCPConnectionState)>;

/// A single TCP connection used by the RTPS-over-TCP transport.
///
/// Declaration only; the implementation is not in this header. From the members:
/// outgoing data is queued in send_queue_ (guarded by send_mutex_), inbound bytes
/// accumulate in receive_buffer_ with a read offset, and counters live in stats_
/// (guarded by stats_mutex_).
/// NOTE(review): receive_callback_, state_callback_, socket_fd_ and the port fields
/// have no visible lock here — confirm which threads may call the setters and
/// accessors concurrently with process_io().
class TCPConnection
{
public:
    /// Create a connection object in the given role. The state starts as DISCONNECTED
    /// (see the state_ initializer below).
    explicit TCPConnection(TCPConnectionMode mode);

    /// Presumably closes any open socket — confirm in the implementation.
    ~TCPConnection();

    // Non-copyable
    // (Declaring the copy operations also suppresses the implicit move operations,
    // so the type is non-movable too; it is shared via std::shared_ptr elsewhere.)
    TCPConnection(const TCPConnection&) = delete;
    TCPConnection& operator=(const TCPConnection&) = delete;

    /// Open an outbound connection to @p address : @p port, waiting up to @p timeout
    /// (default 5 s). @return presumably true once connected — confirm.
    bool connect(const std::string& address, uint16_t port,
                 std::chrono::milliseconds timeout = std::chrono::seconds(5));

    /// Take over an already-open socket and record its peer address/port.
    /// Presumably used for sockets accepted by TCPServer — confirm.
    void adopt_socket(platform::socket_t socket_fd, const std::string& remote_addr, uint16_t remote_port);

    /// Close the connection (exact state transitions are in the implementation).
    void close();

    /// Send @p data. Given send_queue_/send_pending(), this may enqueue and transmit
    /// later from process_io() — confirm. @return presumably false on failure.
    bool send(const std::vector<uint8_t>& data);

    /// Current state (atomic load of state_).
    TCPConnectionState state() const { return state_.load(); }

    /// True iff state() is TCPConnectionState::CONNECTED.
    bool is_connected() const { return state_.load() == TCPConnectionState::CONNECTED; }

    /// Locator describing the remote peer (built from remote_address_/remote_port_ — presumably).
    Locator_t remote_locator() const;

    /// Local port of the socket (plain read of local_port_; 0 until set by the implementation).
    uint16_t local_port() const { return local_port_; }

    /// Install the handler invoked for received data.
    void set_receive_callback(TCPReceiveCallback callback);

    /// Install the handler invoked on state changes.
    void set_state_callback(TCPStateCallback callback);

    /// Drive pending socket I/O. Presumably flushes the send queue and reads/dispatches
    /// received data via the private helpers below — confirm in the implementation.
    void process_io();

    /// Raw socket handle (platform::INVALID_SOCK when none). Ownership stays with this object.
    platform::socket_t socket_fd() const { return socket_fd_; }

    /// Per-connection traffic counters.
    struct Stats
    {
        uint64_t bytes_sent{0};
        uint64_t bytes_received{0};
        uint64_t messages_sent{0};
        uint64_t messages_received{0};
        // Default-constructed time_points hold the clock epoch until set.
        std::chrono::steady_clock::time_point connected_at;
        std::chrono::steady_clock::time_point last_activity;
    };
    /// Snapshot (copy) of the counters; presumably taken under stats_mutex_.
    Stats get_stats() const;

private:
    void set_state(TCPConnectionState new_state);
    bool send_pending();
    bool receive_data();
    bool process_receive_buffer();

    TCPConnectionMode mode_;
    std::atomic<TCPConnectionState> state_{TCPConnectionState::DISCONNECTED};

    platform::socket_t socket_fd_{platform::INVALID_SOCK};
    std::string remote_address_;
    uint16_t remote_port_{0};
    uint16_t local_port_{0};

    // Outgoing messages awaiting transmission; guarded by send_mutex_.
    mutable std::mutex send_mutex_;
    std::deque<std::vector<uint8_t>> send_queue_;

    // Inbound byte accumulator and read position (presumably for reassembling
    // TCPMessageHeader-framed messages across partial reads — confirm).
    std::vector<uint8_t> receive_buffer_;
    size_t receive_offset_{0};

    TCPReceiveCallback receive_callback_;
    TCPStateCallback state_callback_;

    // Traffic counters; guarded by stats_mutex_.
    Stats stats_;
    mutable std::mutex stats_mutex_;
};

// ============================================================================
// TCP Server
// ============================================================================

/// Handler receiving each newly accepted connection as a shared TCPConnection.
using TCPAcceptCallback = std::function<void(std::shared_ptr<TCPConnection>)>;

/// Listening TCP socket that hands accepted connections to a TCPAcceptCallback.
///
/// Declaration only. No thread member is declared here, so accepting is presumably
/// driven externally by calling accept_pending() (e.g. from TCPTransport's I/O
/// thread) — confirm in the implementation.
class TCPServer
{
public:
    TCPServer();
    /// Presumably stops listening / closes listen_fd_ — confirm.
    ~TCPServer();

    // Non-copyable
    TCPServer(const TCPServer&) = delete;
    TCPServer& operator=(const TCPServer&) = delete;

    /// Start listening on @p port with the given accept @p backlog (default 16).
    /// @return presumably true on success — confirm.
    bool listen(uint16_t port, int backlog = 16);

    /// Stop listening.
    void stop();

    /// Atomic read of the listening flag.
    bool is_listening() const { return listening_.load(); }

    /// Port recorded by listen() (0 before then). Presumably the actual bound port
    /// if 0 was requested — confirm.
    uint16_t port() const { return port_; }

    /// Install the handler for accepted connections.
    void set_accept_callback(TCPAcceptCallback callback);

    /// Accept whatever connections are pending and pass each to the accept callback
    /// (presumably non-blocking — confirm).
    void accept_pending();

    /// Raw listening socket handle (platform::INVALID_SOCK when not listening).
    platform::socket_t socket_fd() const { return listen_fd_; }

private:
    std::atomic<bool> listening_{false};
    platform::socket_t listen_fd_{platform::INVALID_SOCK};
    uint16_t port_{0};
    TCPAcceptCallback accept_callback_;
    // NOTE(review): which members this mutex guards is not visible here — confirm.
    std::mutex mutex_;
};

// ============================================================================
// TCP Transport Manager
// ============================================================================

/// Top-level RTPS-over-TCP transport: optional listening server plus a map of
/// connections keyed by a string derived from the peer locator.
///
/// Declaration only. Owns a background thread (io_thread_, presumably running
/// io_thread_func()) and a TCPServer when enabled — confirm in the implementation.
class TCPTransport
{
public:
    /// Construction-time settings. Defaults are the initializers shown.
    struct Config
    {
        uint16_t listen_port = 0;      ///< Server listen port (meaning of 0 — presumably "any" — set by the implementation).
        bool enable_server = true;     ///< Whether to create a listening TCPServer.
        bool enable_client = true;     ///< Whether outbound connects are allowed (presumably checked in get_or_connect).
        size_t max_connections = 256;  ///< Connection cap (enforcement is in the implementation — confirm).
        std::chrono::milliseconds connect_timeout = std::chrono::seconds(5);
        std::chrono::milliseconds keepalive_interval = std::chrono::seconds(30);
        bool tcp_nodelay = true;          ///< Presumably sets TCP_NODELAY on sockets.
        size_t send_buffer_size = 65536;  ///< Presumably the socket send buffer size in bytes.
        size_t recv_buffer_size = 65536;  ///< Presumably the socket receive buffer size in bytes.
    };

    /// Construct with a default-initialized Config.
    TCPTransport();
    /// Construct with the given settings (copied into config_).
    explicit TCPTransport(const Config& config);
    /// Presumably calls shutdown() / joins io_thread_ — confirm.
    ~TCPTransport();

    // Non-copyable
    TCPTransport(const TCPTransport&) = delete;
    TCPTransport& operator=(const TCPTransport&) = delete;

    /// Start the transport. @return presumably true on success — confirm.
    bool init();

    /// Stop the transport.
    void shutdown();

    /// Atomic read of the running flag.
    bool is_running() const { return running_.load(); }

    /// Return the existing connection for @p locator, or presumably open a new one
    /// (possibly null on failure — confirm).
    std::shared_ptr<TCPConnection> get_or_connect(const Locator_t& locator);

    /// Send @p data to the peer identified by @p locator. @return presumably false on failure.
    bool send(const Locator_t& locator, const std::vector<uint8_t>& data);

    /// Install the handler for data received on any connection.
    void set_receive_callback(TCPReceiveCallback callback);

    /// Locator advertising this transport's listening endpoint.
    Locator_t get_local_locator() const;

    /// Copy of the current connection pointers (presumably taken under connections_mutex_).
    std::vector<std::shared_ptr<TCPConnection>> get_connections() const;

    /// Aggregate transport counters.
    struct Stats
    {
        uint64_t total_bytes_sent{0};
        uint64_t total_bytes_received{0};
        uint64_t total_connections{0};
        uint64_t active_connections{0};
        uint64_t failed_connections{0};
    };
    /// Snapshot (copy) of the counters; presumably taken under stats_mutex_.
    Stats get_stats() const;

private:
    void io_thread_func();
    void cleanup_dead_connections();
    // Map-key conversion for connections_; presumably inverse operations — confirm.
    std::string locator_to_key(const Locator_t& locator) const;
    Locator_t key_to_locator(const std::string& key) const;

    Config config_;
    std::atomic<bool> running_{false};

    std::unique_ptr<TCPServer> server_;

    // Live connections keyed by locator_to_key(); guarded by connections_mutex_.
    mutable std::mutex connections_mutex_;
    std::map<std::string, std::shared_ptr<TCPConnection>> connections_;

    std::thread io_thread_;
    // NOTE(review): no dedicated mutex for io_cv_ is declared; presumably it waits on
    // connections_mutex_ — confirm in the implementation.
    std::condition_variable io_cv_;

    TCPReceiveCallback receive_callback_;

    Stats stats_;
    mutable std::mutex stats_mutex_;
};

// ============================================================================
// TCP Locator Utilities
// ============================================================================

/// Build a TCPv4 locator for @p address : @p port.
///
/// @p address is expected to start with a dotted-quad IPv4 literal "a.b.c.d" in
/// which each octet is a decimal value 0-255. The octets are stored in
/// address[12..15]; the other address bytes stay zero. Any text after the fourth
/// octet is ignored, as with the previous sscanf-based parse (e.g. "a.b.c.d:port"
/// still yields a.b.c.d).
/// If @p address does not start with a valid dotted quad (including host names,
/// signs, whitespace, or out-of-range octets), all address bytes are left zero
/// and only kind and port are set. Callers cannot distinguish that from "0.0.0.0".
inline Locator_t make_tcp_locator(const std::string& address, uint16_t port)
{
    Locator_t locator{};
    locator.kind = static_cast<LocatorKind_t>(LOCATOR_KIND_TCPv4);
    locator.port = port;

    // Strict per-octet parse instead of sscanf("%u.%u.%u.%u"). That sscanf call
    // passed uint32_t* to a %u (unsigned int*) conversion, silently truncated
    // out-of-range octets ("300" -> 44), accepted "-1", and needed <cstdio>,
    // which this header does not include.
    uint8_t octets[4]{};
    std::string::size_type pos = 0;
    for (int i = 0; i < 4; ++i)
    {
        if (i > 0)
        {
            if (pos >= address.size() || address[pos] != '.')
                return locator;  // missing separator
            ++pos;
        }

        const std::string::size_type start = pos;
        unsigned value = 0;
        while (pos < address.size() && address[pos] >= '0' && address[pos] <= '9')
        {
            value = value * 10u + static_cast<unsigned>(address[pos] - '0');
            if (value > 255u)
                return locator;  // octet out of range; bail before value can grow further
            ++pos;
        }
        if (pos == start)
            return locator;  // empty or non-numeric octet
        octets[i] = static_cast<uint8_t>(value);
    }

    locator.address[12] = octets[0];
    locator.address[13] = octets[1];
    locator.address[14] = octets[2];
    locator.address[15] = octets[3];

    return locator;
}

/// True when @p locator's kind is LOCATOR_KIND_TCPv4 or LOCATOR_KIND_TCPv6.
inline bool is_tcp_locator(const Locator_t& locator)
{
    switch (static_cast<int32_t>(locator.kind))
    {
        case LOCATOR_KIND_TCPv4:
        case LOCATOR_KIND_TCPv6:
            return true;
        default:
            return false;
    }
}

/// Format address bytes 12..15 of @p locator as a dotted-quad string "a.b.c.d".
///
/// Each byte is read as an unsigned octet (0-255) regardless of the element type
/// of Locator_t::address, so a signed-char element can never print as negative.
/// Built with std::to_string, so no <cstdio>/snprintf dependency is needed.
/// NOTE(review): the kind is not checked, so a TCPv6 locator also yields only
/// its last four bytes — confirm callers only pass IPv4 locators.
inline std::string locator_to_address(const Locator_t& locator)
{
    std::string text;
    text.reserve(15);  // longest form: "255.255.255.255"
    for (std::size_t i = 12; i < 16; ++i)
    {
        if (i != 12)
            text += '.';
        const auto octet = static_cast<uint8_t>(locator.address[i]);
        text += std::to_string(static_cast<unsigned>(octet));
    }
    return text;
}

}  // namespace astutedds::rtps

#endif  // ASTUTEDDS_RTPS_TCP_TRANSPORT_HPP