File shm_transport.hpp

File List > astutedds > rtps > shm_transport.hpp

Go to the documentation of this file

//
// Copyright (c) 2026, Astute Systems PTY LTD
//
// This file is part of the Astute DDS developed by Astute Systems.
//
// See the commercial LICENSE file in the project root for full license details.
//
// @file shm_transport.hpp
// @brief Shared Memory Transport for high-performance local communication
//
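// Implements data transfer between participants on the same host using
// POSIX shared memory. The ring-buffer API exchanges payloads as
// std::vector<uint8_t>, so message bytes are copied into and out of the
// shared segment.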
//

#ifndef ASTUTEDDS_RTPS_SHM_TRANSPORT_HPP
#define ASTUTEDDS_RTPS_SHM_TRANSPORT_HPP

#include <astutedds/rtps/rtps_types.hpp>

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

namespace astutedds::rtps
{

// ============================================================================
// Shared Memory Constants
// ============================================================================

/// Locator kind tag stamped on shared-memory locators (see make_shm_locator).
constexpr int32_t LOCATOR_KIND_SHM = 16;

/// Default ShmTransport::Config::segment_size: 16 MiB.
constexpr size_t DEFAULT_SHM_SEGMENT_SIZE = size_t{16} << 20;

/// Default ShmTransport::Config::slot_count.
constexpr size_t DEFAULT_RING_BUFFER_SLOTS = 256;

/// Default ShmTransport::Config::max_message_size: 1 MiB.
constexpr size_t MAX_SHM_MESSAGE_SIZE = size_t{1} << 20;

// ============================================================================
// Shared Memory Segment Header
// ============================================================================

/// @brief Control block stored at the start of a shared-memory segment.
///
/// This struct is shared between processes, so its in-memory size must
/// equal SIZE, the constant describing the on-segment layout. Byte budget:
///   5 x uint32_t                 = 20 bytes (offsets 0..19)
///   implicit alignment padding   =  4 bytes (offsets 20..23)
///   2 x std::atomic<uint64_t>    = 16 bytes (offsets 24..39)
///   reserved                     = 24 bytes (offsets 40..63)
/// The static_assert after the struct enforces this. A larger `reserved`
/// would make the struct overrun SIZE and overlap whatever follows it.
///
/// NOTE(review): the atomics are accessed from several processes. That is
/// only well-defined when std::atomic<uint64_t> is lock-free on the target;
/// confirm for every supported platform.
struct ShmSegmentHeader
{
    uint32_t magic{0x534D4453};          // "SMDS" when read as big-endian ASCII
    uint32_t version{1};                 // layout version
    uint32_t segment_size{0};
    uint32_t slot_count{0};
    uint32_t slot_size{0};
    // 4 bytes of implicit padding here so the atomics below are 8-byte aligned.
    std::atomic<uint64_t> write_idx{0};
    std::atomic<uint64_t> read_idx{0};
    uint8_t reserved[24]{};              // pads the header to exactly SIZE bytes

    static constexpr size_t SIZE = 64;
};

static_assert(sizeof(ShmSegmentHeader) == ShmSegmentHeader::SIZE,
              "ShmSegmentHeader layout must match ShmSegmentHeader::SIZE");

struct ShmSlotHeader
{
    std::atomic<uint32_t> state{0};  
    uint32_t message_size{0};        
    uint64_t sequence_number{0};     
    GuidPrefix_t sender_prefix{};    
    uint8_t reserved[8]{};           

    static constexpr size_t SIZE = 32;
    static constexpr uint32_t STATE_FREE = 0;
    static constexpr uint32_t STATE_WRITING = 1;
    static constexpr uint32_t STATE_READY = 2;
};

// ============================================================================
// Shared Memory Segment
// ============================================================================

/// @brief Handle to one named shared-memory segment.
///
/// Stores the segment name, its size, a file descriptor (fd_) and the base
/// address of the mapping (base_ptr_). The class cannot be copied. Whether
/// it is the creating side is recorded in owner_; the exact cleanup done by
/// the destructor lives in the implementation file (presumably unmap/close,
/// plus unlink when owner_ is set — confirm there).
class ShmSegment
{
public:
    /// Open, or create when @p create is true, the segment @p name with @p size bytes.
    ShmSegment(const std::string& name, size_t size, bool create);

    ~ShmSegment();

    ShmSegment(const ShmSegment&) = delete;             // no copies: owns fd_/mapping
    ShmSegment& operator=(const ShmSegment&) = delete;

    /// @return true iff a base address is held (base_ptr_ is non-null).
    bool is_valid() const { return nullptr != base_ptr_; }

    /// @return the segment name given at construction.
    const std::string& name() const { return name_; }

    /// @return the segment size in bytes.
    size_t size() const { return size_; }

    /// @return the raw base address of the mapping (nullptr if invalid).
    void* base_ptr() { return base_ptr_; }
    const void* base_ptr() const { return base_ptr_; }

    /// Typed view of the segment's leading ShmSegmentHeader.
    ShmSegmentHeader* header();
    const ShmSegmentHeader* header() const;

    /// Remove the named segment @p name (static: needs no open handle).
    static void unlink(const std::string& name);

private:
    std::string name_;
    size_t size_ = 0;
    int fd_ = -1;
    void* base_ptr_ = nullptr;
    bool owner_ = false;
};

// ============================================================================
// Ring Buffer for Message Passing
// ============================================================================

/// @brief Slot-based message ring laid out inside a ShmSegment.
///
/// The segment is held by reference, so it must outlive this object.
/// Slot headers and slot payloads are reached through get_slot() and
/// get_slot_data(); data_offset_ records where slot storage begins (its
/// exact computation lives in the implementation).
class ShmRingBuffer
{
public:
    /// Bind to @p segment and use @p slot_count slots.
    ShmRingBuffer(ShmSegment& segment, size_t slot_count);

    /// Publish @p data tagged with @p sender (presumably returns false when no slot is available).
    bool write(const GuidPrefix_t& sender, const std::vector<uint8_t>& data);

    /// Take the next message for @p reader_id into @p sender / @p data (presumably returns false when empty).
    bool read(GuidPrefix_t& sender, std::vector<uint8_t>& data, uint32_t reader_id);

    /// @return whether a message is pending for @p reader_id.
    bool has_data(uint32_t reader_id) const;

    /// @return number of slots in the ring.
    size_t slot_count() const { return slot_count_; }

    /// @return bytes per slot.
    size_t slot_size() const { return slot_size_; }

private:
    ShmSlotHeader* get_slot(size_t index);
    uint8_t* get_slot_data(size_t index);

    ShmSegment& segment_;  // non-owning; see class comment
    size_t slot_count_;
    size_t slot_size_;
    size_t data_offset_;
};

// ============================================================================
// Shared Memory Port
// ============================================================================

/// @brief Endpoint that pairs a ShmSegment with the ShmRingBuffer laid over it.
///
/// Identified by (domain_id, participant_id). The class cannot be copied.
/// open_ is atomic so that is_open() can be queried without taking mutex_.
/// mutex_ is mutable so const members can lock it; which members lock it is
/// defined in the implementation (TODO confirm).
class ShmPort
{
public:
    ShmPort(uint32_t domain_id, uint32_t participant_id);

    ~ShmPort();

    ShmPort(const ShmPort&) = delete;
    ShmPort& operator=(const ShmPort&) = delete;

    /// Set up the underlying segment and ring buffer (presumably false on failure).
    bool init();

    /// Tear the port down.
    void close();

    /// @return the atomic open flag.
    bool is_open() const { return open_.load(std::memory_order_seq_cst); }

    /// @return this port's id.
    uint32_t port_id() const { return port_id_; }

    /// @return a locator describing this port.
    Locator_t locator() const;

    /// Deliver @p data from @p sender to the port identified by @p dest_port_id.
    bool send(uint32_t dest_port_id, const GuidPrefix_t& sender, const std::vector<uint8_t>& data);

    /// Take one pending message, if any, into @p sender / @p data.
    bool receive(GuidPrefix_t& sender, std::vector<uint8_t>& data);

    /// @return whether a message is pending.
    bool has_data() const;

private:
    /// Name of the shared-memory segment backing this port.
    std::string segment_name() const;

    uint32_t domain_id_;
    uint32_t participant_id_;
    uint32_t port_id_;
    uint32_t reader_id_;

    std::atomic<bool> open_{false};
    std::unique_ptr<ShmSegment> segment_;
    std::unique_ptr<ShmRingBuffer> ring_buffer_;

    mutable std::mutex mutex_;
};

// ============================================================================
// Shared Memory Transport
// ============================================================================

/// Callback invoked with the sender's GUID prefix and the received payload bytes.
using ShmReceiveCallback = std::function<void(const GuidPrefix_t&, const std::vector<uint8_t>&)>;

class ShmTransport
{
public:
    struct Config
    {
        size_t segment_size = DEFAULT_SHM_SEGMENT_SIZE;
        size_t slot_count = DEFAULT_RING_BUFFER_SLOTS;
        size_t max_message_size = MAX_SHM_MESSAGE_SIZE;
        bool cleanup_on_exit = true;
    };

    ShmTransport(uint32_t domain_id, uint32_t participant_id);

    ShmTransport(uint32_t domain_id, uint32_t participant_id, const Config& config);

    ~ShmTransport();

    // Non-copyable
    ShmTransport(const ShmTransport&) = delete;
    ShmTransport& operator=(const ShmTransport&) = delete;

    bool init();

    void shutdown();

    bool is_running() const { return running_.load(); }

    bool send(const Locator_t& locator, const std::vector<uint8_t>& data);

    bool send_to_participant(const GuidPrefix_t& dest_prefix, const std::vector<uint8_t>& data);

    size_t broadcast(const std::vector<uint8_t>& data);

    void set_receive_callback(ShmReceiveCallback callback);

    Locator_t get_local_locator() const;

    static bool is_shm_locator(const Locator_t& locator);

    struct Stats
    {
        uint64_t messages_sent{0};
        uint64_t messages_received{0};
        uint64_t bytes_sent{0};
        uint64_t bytes_received{0};
        uint64_t send_failures{0};
    };
    Stats get_stats() const;

    std::vector<uint32_t> discover_local_participants();

private:
    void receive_thread_func();
    uint32_t locator_to_port_id(const Locator_t& locator) const;
    Locator_t port_id_to_locator(uint32_t port_id) const;
    std::shared_ptr<ShmPort> get_or_create_send_port(uint32_t port_id);

    uint32_t domain_id_;
    uint32_t participant_id_;
    Config config_;
    GuidPrefix_t local_prefix_{};

    std::atomic<bool> running_{false};

    // Local receive port
    std::unique_ptr<ShmPort> receive_port_;

    // Send ports (one per remote participant)
    mutable std::mutex send_ports_mutex_;
    std::map<uint32_t, std::shared_ptr<ShmPort>> send_ports_;

    std::thread receive_thread_;
    ShmReceiveCallback receive_callback_;

    Stats stats_;
    mutable std::mutex stats_mutex_;
};

// ============================================================================
// Shared Memory Utilities
// ============================================================================

/// @brief Build a shared-memory locator for (@p domain_id, @p participant_id).
///
/// kind   = LOCATOR_KIND_SHM
/// port   = participant_id
/// address[0..3] = domain_id, most-significant byte first.
/// All other fields keep their value-initialized state.
/// parse_shm_locator() performs the inverse decoding.
inline Locator_t make_shm_locator(uint32_t domain_id, uint32_t participant_id)
{
    Locator_t result{};
    result.kind = static_cast<LocatorKind_t>(LOCATOR_KIND_SHM);
    result.port = participant_id;

    for (unsigned byte_index = 0; byte_index < 4u; ++byte_index)
    {
        const unsigned shift = 24u - 8u * byte_index;
        result.address[byte_index] = (domain_id >> shift) & 0xFF;
    }

    return result;
}

/// @brief Decode a locator produced by make_shm_locator().
///
/// @param locator         source locator
/// @param domain_id       [out] big-endian value of address[0..3]
/// @param participant_id  [out] the locator's port field
inline void parse_shm_locator(const Locator_t& locator, uint32_t& domain_id, uint32_t& participant_id)
{
    uint32_t accumulated = 0;
    for (int i = 0; i < 4; ++i)
    {
        accumulated = (accumulated << 8) | static_cast<uint32_t>(locator.address[i]);
    }

    domain_id = accumulated;
    participant_id = locator.port;
}

/// Clean up shared-memory resources associated with @p domain_id
/// (presumably unlinks that domain's segments — confirm against the definition).
void cleanup_shm_domain(std::uint32_t domain_id);

}  // namespace astutedds::rtps

#endif  // ASTUTEDDS_RTPS_SHM_TRANSPORT_HPP