File fragmentation.hpp

File List > astutedds > rtps > fragmentation.hpp

Go to the documentation of this file

//
// Copyright (c) 2026, Astute Systems PTY LTD
//
// This file is part of the Astute DDS developed by Astute Systems.
//
// See the commercial LICENSE file in the project root for full license details.
//
// @file fragmentation.hpp
// @brief RTPS message fragmentation support
//
// Implements DATA_FRAG, NACK_FRAG, and HEARTBEAT_FRAG submessages
// for handling large data that exceeds MTU size.
// Reference: DDSI-RTPS 2.5 Section 8.3.7
//

#ifndef ASTUTEDDS_RTPS_FRAGMENTATION_HPP
#define ASTUTEDDS_RTPS_FRAGMENTATION_HPP

#include <astutedds/rtps/rtps_types.hpp>

#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <vector>

namespace astutedds::rtps
{

// ============================================================================
// Fragmentation Constants
// ============================================================================

/// Default number of payload bytes carried per DATA_FRAG fragment.
///
/// Derived from a 1500-byte Ethernet MTU minus the IPv4 header (20 bytes),
/// the UDP header (8 bytes) and a 128-byte budget reserved for RTPS
/// message/submessage headers: 1500 - 20 - 8 - 128 = 1344.
/// `inline` (C++17) gives a single definition shared by every translation
/// unit that includes this header.
inline constexpr uint32_t DEFAULT_FRAGMENT_SIZE = 1500u - 20u - 8u - 128u;

/// Largest fragment number constant defined by this module (0xFFFFFF = 2^24 - 1).
inline constexpr uint32_t MAX_FRAGMENT_NUMBER = 0xFFFFFF;

// ============================================================================
// Fragment Number Handling
// ============================================================================

/// Fragment number type. In this module numbering starts at 1 by default
/// (see FragmentNumberSet::base and DataFragSubmessage::fragment_starting_num).
using FragmentNumber_t = uint32_t;

/// Growable bitmap of fragment numbers relative to a starting number.
///
/// Bit `i` represents fragment `base + i`. Bits are packed LSB-first into
/// 32-bit words: fragment `base + i` lives at `bits[i / 32] & (1u << (i % 32))`.
/// `num_bits` is one past the highest offset ever set. Offsets at or beyond
/// it read as unset.
/// NOTE(review): the DDSI-RTPS wire encoding of a FragmentNumberSet packs bits
/// MSB-first within each word. Presumably (de)serialization converts; confirm.
struct FragmentNumberSet
{
    FragmentNumber_t base{1};      ///< Fragment number represented by bit 0.
    uint32_t num_bits{0};          ///< Count of meaningful bits (offsets >= num_bits are unset).
    std::vector<uint32_t> bits{};  ///< Packed bitmap storage, 32 fragments per word.

    /// @brief Test whether @p frag_num is recorded in the set.
    /// @return false if @p frag_num is below base or outside the tracked range.
    bool is_set(FragmentNumber_t frag_num) const
    {
        if (frag_num < base)
            return false;
        const uint32_t offset = frag_num - base;
        if (offset >= num_bits)
            return false;

        const uint32_t long_idx = offset / 32;
        const uint32_t bit_idx = offset % 32;

        if (long_idx >= bits.size())
            return false;
        return (bits[long_idx] & (1u << bit_idx)) != 0;
    }

    /// @brief Record @p frag_num in the set, growing the storage as needed.
    ///
    /// Fragment numbers below base are ignored. An offset of UINT32_MAX is
    /// also ignored. Its required num_bits (offset + 1) is not representable,
    /// and it would wrap to 0, hiding every previously set bit. It would also
    /// force a ~512 MB resize.
    void set(FragmentNumber_t frag_num)
    {
        if (frag_num < base)
            return;
        const uint32_t offset = frag_num - base;
        if (offset == std::numeric_limits<uint32_t>::max())
            return;
        if (offset >= num_bits)
        {
            num_bits = offset + 1;
        }

        const uint32_t long_idx = offset / 32;
        const uint32_t bit_idx = offset % 32;

        if (long_idx >= bits.size())
        {
            bits.resize(long_idx + 1, 0);
        }
        bits[long_idx] |= (1u << bit_idx);
    }

    /// @brief List the fragment numbers in [base, total_frags] that are not set.
    /// @param total_frags Highest fragment number to consider (inclusive).
    /// @return Missing fragment numbers in ascending order. The result is
    ///         empty when total_frags < base.
    std::vector<FragmentNumber_t> get_missing(FragmentNumber_t total_frags) const
    {
        std::vector<FragmentNumber_t> missing;
        if (total_frags < base)
            return missing;

        // Iterate with a 64-bit counter. With a 32-bit counter, `i <= total_frags`
        // stays true forever when total_frags == UINT32_MAX, because ++i wraps to 0.
        for (uint64_t i = base; i <= total_frags; ++i)
        {
            const auto frag = static_cast<FragmentNumber_t>(i);
            if (!is_set(frag))
            {
                missing.push_back(frag);
            }
        }
        return missing;
    }
};

// ============================================================================
// DATA_FRAG Submessage (8.3.7.3)
// ============================================================================

/// DATA_FRAG submessage: carries one or more consecutive fragments of a
/// serialized sample that is too large for a single DATA submessage.
struct DataFragSubmessage
{
    // Submessage header flag bits.
    static constexpr uint8_t FLAG_ENDIANNESS = 0x01;
    static constexpr uint8_t FLAG_INLINE_QOS = 0x02;
    static constexpr uint8_t FLAG_KEY = 0x04;
    static constexpr uint8_t FLAG_NON_STANDARD_PAYLOAD = 0x08;

    EntityId_t reader_id{};
    EntityId_t writer_id{};
    SequenceNumber_t writer_sn{};
    FragmentNumber_t fragment_starting_num{1};      ///< Number of the first fragment in this submessage.
    uint16_t fragments_in_submessage{1};            ///< Count of consecutive fragments carried here.
    uint32_t fragment_size{DEFAULT_FRAGMENT_SIZE};  ///< Bytes per fragment (the last one may be shorter).
    uint32_t sample_size{0};                        ///< Size of the complete sample in bytes.
    ParameterList inline_qos{};
    std::vector<uint8_t> serialized_payload{};
    bool key_flag{false};

    /// @brief Number of fragments the complete sample is split into.
    /// @return ceil(sample_size / fragment_size), or 0 when fragment_size is 0.
    uint32_t total_fragments() const
    {
        if (fragment_size == 0)
            return 0;
        // Quotient plus remainder carry, rather than
        // (sample_size + fragment_size - 1) / fragment_size. That sum overflows
        // uint32_t for samples close to 4 GiB and yields a far too small count.
        return sample_size / fragment_size + (sample_size % fragment_size != 0 ? 1u : 0u);
    }

    /// @brief Append the wire encoding of this submessage to @p buffer.
    /// @return Presumably true on success (implementation not visible here).
    bool serialize(std::vector<uint8_t>& buffer) const;

    /// @brief Decode this submessage from @p buffer starting at @p offset.
    /// @return Presumably true on success. NOTE(review): confirm how @p offset
    ///         is advanced in the implementation.
    bool deserialize(const std::vector<uint8_t>& buffer, size_t& offset);
};

// ============================================================================
// NACK_FRAG Submessage (8.3.7.10)
// ============================================================================

/// NACK_FRAG submessage. In DDSI-RTPS it is the negative acknowledgement used
/// to request specific fragments of one sample. The wire encoding is produced
/// and parsed by serialize()/deserialize(), which are defined out of line.
struct NackFragSubmessage
{
    /// Endianness bit of the submessage header flags.
    static constexpr uint8_t FLAG_ENDIANNESS = 0x01;

    /// Entity id of the reader endpoint.
    EntityId_t reader_id{};
    /// Entity id of the writer endpoint.
    EntityId_t writer_id{};
    /// Sequence number of the fragmented sample this request refers to.
    SequenceNumber_t writer_sn{};
    /// Fragment numbers being requested.
    FragmentNumberSet fragment_number_state{};
    /// Submessage counter. Presumably incremented per NACK_FRAG so receivers can
    /// discard duplicates; confirm.
    Count_t count{0};

    /// @brief Append the wire encoding of this submessage to @p buffer.
    /// @return Presumably true on success (implementation not visible here).
    bool serialize(std::vector<uint8_t>& buffer) const;

    /// @brief Decode this submessage from @p buffer starting at @p offset.
    /// @return Presumably true on success. NOTE(review): confirm how @p offset
    ///         is advanced in the implementation.
    bool deserialize(const std::vector<uint8_t>& buffer, size_t& offset);
};

// ============================================================================
// HEARTBEAT_FRAG Submessage (8.3.7.5)
// ============================================================================

/// HEARTBEAT_FRAG submessage. In DDSI-RTPS a writer uses it to announce how
/// many fragments of a sample are available. The wire encoding is handled by
/// serialize()/deserialize(), which are defined out of line.
struct HeartbeatFragSubmessage
{
    /// Endianness bit of the submessage header flags.
    static constexpr uint8_t FLAG_ENDIANNESS = 0x01;

    /// Entity id of the reader endpoint.
    EntityId_t reader_id{};
    /// Entity id of the writer endpoint.
    EntityId_t writer_id{};
    /// Sequence number of the fragmented sample being announced.
    SequenceNumber_t writer_sn{};
    /// Highest fragment number reported as available (defaults to 0).
    FragmentNumber_t last_fragment_num{0};
    /// Submessage counter. Presumably incremented per HEARTBEAT_FRAG; confirm.
    Count_t count{0};

    /// @brief Append the wire encoding of this submessage to @p buffer.
    /// @return Presumably true on success (implementation not visible here).
    bool serialize(std::vector<uint8_t>& buffer) const;

    /// @brief Decode this submessage from @p buffer starting at @p offset.
    /// @return Presumably true on success. NOTE(review): confirm how @p offset
    ///         is advanced in the implementation.
    bool deserialize(const std::vector<uint8_t>& buffer, size_t& offset);
};

// ============================================================================
// Fragment Assembly
// ============================================================================

/// Reassembly state for a single fragmented sample.
///
/// Judging by its members, it records which fragments have arrived
/// (received_fragments_) and accumulates their bytes in assembled_data_.
/// Method bodies are defined out of line. The contract notes below are
/// hedged and should be confirmed against the implementation.
class FragmentedSample
{
public:
    /// @param sample_size   Total size of the sample in bytes.
    /// @param fragment_size Size of each fragment in bytes (presumably the last may be shorter).
    FragmentedSample(uint32_t sample_size, uint32_t fragment_size);

    /// @brief Add fragment data beginning at fragment number @p starting_num.
    /// @param data Presumably the bytes of one or more consecutive fragments; confirm.
    /// @return Presumably true if the data was accepted. Duplicate and
    ///         out-of-range handling is not visible here; confirm.
    bool add_fragment(FragmentNumber_t starting_num, const std::vector<uint8_t>& data);

    /// @brief Whether every fragment of the sample has been received.
    bool is_complete() const;

    /// @brief Reference to the reassembly buffer. Presumably only meaningful
    ///        once is_complete() returns true; confirm.
    const std::vector<uint8_t>& data() const { return assembled_data_; }

    /// @brief Fragment numbers that have not been received yet.
    std::vector<FragmentNumber_t> get_missing_fragments() const;

    /// @brief Number of fragments the sample is split into.
    uint32_t total_fragments() const { return total_fragments_; }

    /// @brief Number of fragments received so far (presumably excluding duplicates; confirm).
    uint32_t received_count() const { return received_count_; }

private:
    // Initialized by the constructor (defined out of line).
    uint32_t sample_size_;
    uint32_t fragment_size_;
    uint32_t total_fragments_;
    uint32_t received_count_{0};
    // Bitmap of fragment numbers seen so far.
    FragmentNumberSet received_fragments_;
    // Buffer returned by data().
    std::vector<uint8_t> assembled_data_;
};

// ============================================================================
// Fragmentation Manager
// ============================================================================

/// Callback receiving a reassembled sample from @p writer_guid with sequence
/// number @p seq_num. The payload is passed by rvalue reference so the
/// receiver may move it out. Presumably invoked by FragmentationManager once
/// a sample completes; confirm.
using FragmentAssemblyCallback =
    std::function<void(const GUID_t& writer_guid, const SequenceNumber_t& seq_num, std::vector<uint8_t>&& data)>;

/// Callback receiving the fragment numbers still @p missing for sample
/// @p seq_num from @p writer_guid. Presumably used to trigger a NACK_FRAG; confirm.
using FragmentNackCallback = std::function<void(const GUID_t& writer_guid, const SequenceNumber_t& seq_num,
                                                const std::vector<FragmentNumber_t>& missing)>;

/// Key identifying one fragmented sample: the writer that produced it plus
/// the sample's sequence number.
///
/// Ordering is lexicographic over (writer_guid, sequence_number.high,
/// sequence_number.low). This gives std::map the strict weak ordering it
/// needs, provided GUID_t's operator< is itself a strict weak ordering.
struct FragmentKey
{
    GUID_t writer_guid;
    SequenceNumber_t sequence_number;

    bool operator<(const FragmentKey& other) const
    {
        // Different writers: the GUID ordering alone decides.
        if (writer_guid < other.writer_guid || other.writer_guid < writer_guid)
        {
            return writer_guid < other.writer_guid;
        }

        // Same writer: compare the sequence number high word first, then the low word.
        const auto& lhs = sequence_number;
        const auto& rhs = other.sequence_number;
        return (lhs.high != rhs.high) ? (lhs.high < rhs.high) : (lhs.low < rhs.low);
    }
};

/// Reassembles DATA_FRAG submessages from remote writers into complete samples.
///
/// Partially received samples are kept in maps keyed by FragmentKey
/// (writer GUID + sequence number), with a steady_clock timestamp per sample.
/// The header must include <chrono> for std::chrono::steady_clock; it
/// previously relied on a transitive include.
/// All mutable state sits alongside `mutex_`. Presumably every public method
/// locks it; confirm in the implementation before relying on thread-safety.
class FragmentationManager
{
public:
    /// @param max_pending   Maximum number of partially assembled samples to
    ///                      track (default 256). Eviction policy is not visible here.
    /// @param nack_delay_ms Delay in milliseconds used when reporting missing
    ///                      fragments (default 100). Presumably measured from the
    ///                      sample's timestamp; confirm.
    explicit FragmentationManager(size_t max_pending = 256, uint32_t nack_delay_ms = 100);

    ~FragmentationManager();

    // Non-copyable
    FragmentationManager(const FragmentationManager&) = delete;
    FragmentationManager& operator=(const FragmentationManager&) = delete;

    /// @brief Feed one DATA_FRAG received from @p writer_guid.
    /// @return The reassembled sample if this submessage completed it.
    ///         Presumably std::nullopt otherwise; confirm.
    std::optional<std::vector<uint8_t>> process_data_frag(const GUID_t& writer_guid,
                                                          const DataFragSubmessage& data_frag);

    /// @brief Handle a HEARTBEAT_FRAG from @p writer_guid.
    /// @return Presumably the fragment numbers still missing for the announced
    ///         sample, or std::nullopt when nothing needs requesting; confirm.
    std::optional<std::vector<FragmentNumber_t>> process_heartbeat_frag(const GUID_t& writer_guid,
                                                                        const HeartbeatFragSubmessage& hb_frag);

    /// @brief Install the callback used to deliver reassembled samples.
    void set_assembly_callback(FragmentAssemblyCallback callback);

    /// @brief Install the callback used to report missing fragments.
    void set_nack_callback(FragmentNackCallback callback);

    /// @brief Periodic maintenance over pending samples (semantics defined out of line).
    void check_pending_samples();

    /// Running counters; exact increment points are defined in the implementation.
    struct Stats
    {
        uint64_t fragments_received{0};
        uint64_t samples_assembled{0};
        uint64_t nacks_sent{0};
        uint64_t duplicate_fragments{0};
    };
    /// @brief Snapshot of the counters, returned by value.
    Stats get_stats() const;

private:
    void cleanup_stale_samples();

    // `mutable` so const members such as get_stats() can lock it.
    mutable std::mutex mutex_;
    size_t max_pending_;
    uint32_t nack_delay_ms_;

    // Partial samples and their timestamps, both keyed by (writer GUID, sequence number).
    std::map<FragmentKey, std::unique_ptr<FragmentedSample>> pending_samples_;
    std::map<FragmentKey, std::chrono::steady_clock::time_point> sample_timestamps_;

    FragmentAssemblyCallback assembly_callback_;
    FragmentNackCallback nack_callback_;

    Stats stats_;
};

// ============================================================================
// Fragmentation Utilities
// ============================================================================

/// @brief Split a serialized sample into DATA_FRAG submessages.
/// @param sample        Serialized sample bytes.
/// @param writer_id     Writer entity id (presumably copied into each submessage).
/// @param reader_id     Reader entity id (presumably copied into each submessage).
/// @param seq_num       Sequence number of the sample.
/// @param fragment_size Bytes per fragment (defaults to DEFAULT_FRAGMENT_SIZE).
/// @return The DATA_FRAG submessages for the sample. How many fragments each
///         submessage carries is decided by the out-of-line implementation.
std::vector<DataFragSubmessage> fragment_sample(const std::vector<uint8_t>& sample, const EntityId_t& writer_id,
                                                const EntityId_t& reader_id, const SequenceNumber_t& seq_num,
                                                uint32_t fragment_size = DEFAULT_FRAGMENT_SIZE);

/// @brief Decide whether a payload must be split into DATA_FRAG submessages.
/// @param data_size        Size of the serialized payload in bytes.
/// @param max_message_size Largest payload that may be sent unfragmented
///                         (defaults to DEFAULT_FRAGMENT_SIZE).
/// @return true iff @p data_size is strictly greater than @p max_message_size.
inline bool needs_fragmentation(size_t data_size, size_t max_message_size = DEFAULT_FRAGMENT_SIZE)
{
    const bool fits_unfragmented = data_size <= max_message_size;
    return !fits_unfragmented;
}

/// @brief Compute a fragment payload size for a given link MTU.
/// @param mtu                       Link MTU in bytes (default 1500).
/// @param include_security_overhead Presumably reserves extra room for
///                                  security-related overhead when true; confirm
///                                  in the implementation.
/// @return Fragment size in bytes (computed out of line).
uint32_t calculate_fragment_size(uint32_t mtu = 1500, bool include_security_overhead = false);

}  // namespace astutedds::rtps

#endif  // ASTUTEDDS_RTPS_FRAGMENTATION_HPP