| /* |
| * Copyright 2020 The WebRTC Project Authors. All rights reserved. |
| * |
| * Use of this source code is governed by a BSD-style license |
| * that can be found in the LICENSE file in the root of the source |
| * tree. An additional intellectual property rights grant can be found |
| * in the file PATENTS. All contributing project authors may |
| * be found in the AUTHORS file in the root of the source tree. |
| */ |
| |
| #include "p2p/base/stun_dictionary.h" |
| |
| #include <algorithm> |
| #include <deque> |
| #include <utility> |
| |
| #include "rtc_base/logging.h" |
| |
| namespace cricket { |
| |
| namespace { |
| |
| StunAttributeValueType GetStunAttributeValueType(int value_type) { |
| switch (value_type) { |
| case STUN_VALUE_ADDRESS: |
| return STUN_VALUE_ADDRESS; |
| case STUN_VALUE_XOR_ADDRESS: |
| return STUN_VALUE_XOR_ADDRESS; |
| case STUN_VALUE_UINT32: |
| return STUN_VALUE_UINT32; |
| case STUN_VALUE_UINT64: |
| return STUN_VALUE_UINT64; |
| case STUN_VALUE_BYTE_STRING: |
| return STUN_VALUE_BYTE_STRING; |
| case STUN_VALUE_ERROR_CODE: |
| return STUN_VALUE_ERROR_CODE; |
| case STUN_VALUE_UINT16_LIST: |
| return STUN_VALUE_UINT16_LIST; |
| default: |
| return STUN_VALUE_UNKNOWN; |
| } |
| } |
| |
| } // namespace |
| |
| const StunAddressAttribute* StunDictionaryView::GetAddress(int key) const { |
| const StunAttribute* attr = GetOrNull(key, STUN_VALUE_ADDRESS); |
| if (attr == nullptr) { |
| return nullptr; |
| } |
| return reinterpret_cast<const StunAddressAttribute*>(attr); |
| } |
| |
| const StunUInt32Attribute* StunDictionaryView::GetUInt32(int key) const { |
| const StunAttribute* attr = GetOrNull(key, STUN_VALUE_UINT32); |
| if (attr == nullptr) { |
| return nullptr; |
| } |
| return reinterpret_cast<const StunUInt32Attribute*>(attr); |
| } |
| |
| const StunUInt64Attribute* StunDictionaryView::GetUInt64(int key) const { |
| const StunAttribute* attr = GetOrNull(key, STUN_VALUE_UINT64); |
| if (attr == nullptr) { |
| return nullptr; |
| } |
| return reinterpret_cast<const StunUInt64Attribute*>(attr); |
| } |
| |
| const StunByteStringAttribute* StunDictionaryView::GetByteString( |
| int key) const { |
| const StunAttribute* attr = GetOrNull(key, STUN_VALUE_BYTE_STRING); |
| if (attr == nullptr) { |
| return nullptr; |
| } |
| return reinterpret_cast<const StunByteStringAttribute*>(attr); |
| } |
| |
| const StunUInt16ListAttribute* StunDictionaryView::GetUInt16List( |
| int key) const { |
| const StunAttribute* attr = GetOrNull(key, STUN_VALUE_UINT16_LIST); |
| if (attr == nullptr) { |
| return nullptr; |
| } |
| return reinterpret_cast<const StunUInt16ListAttribute*>(attr); |
| } |
| |
| const StunAttribute* StunDictionaryView::GetOrNull( |
| int key, |
| absl::optional<StunAttributeValueType> type) const { |
| const auto it = attrs_.find(key); |
| if (it == attrs_.end()) { |
| return nullptr; |
| } |
| |
| if (type && it->second->value_type() != *type) { |
| RTC_LOG(LS_WARNING) << "Get key: " << key << " with type: " << *type |
| << " found different type: " |
| << it->second->value_type(); |
| return nullptr; |
| } |
| return (*it).second.get(); |
| } |
| |
| webrtc::RTCErrorOr< |
| std::pair<uint64_t, std::deque<std::unique_ptr<StunAttribute>>>> |
| StunDictionaryView::ParseDelta(const StunByteStringAttribute& delta) { |
| rtc::ByteBufferReader buf(delta.array_view()); |
| uint16_t magic; |
| if (!buf.ReadUInt16(&magic)) { |
| return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, |
| "Failed to read magic number"); |
| } |
| if (magic != kDeltaMagic) { |
| return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, |
| "Invalid magic number"); |
| } |
| |
| uint16_t delta_version; |
| if (!buf.ReadUInt16(&delta_version)) { |
| return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, |
| "Failed to read version"); |
| } |
| |
| if (delta_version != kDeltaVersion) { |
| return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, |
| "Unsupported delta version"); |
| } |
| |
| // Now read all the attributes |
| std::deque<std::unique_ptr<StunAttribute>> attrs; |
| while (buf.Length()) { |
| uint16_t key, length, value_type; |
| if (!buf.ReadUInt16(&key)) { |
| return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, |
| "Failed to read attribute key"); |
| } |
| if (!buf.ReadUInt16(&length)) { |
| return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, |
| "Failed to read attribute length"); |
| } |
| if (!buf.ReadUInt16(&value_type)) { |
| return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, |
| "Failed to read value type"); |
| } |
| |
| StunAttributeValueType value_type_enum = |
| GetStunAttributeValueType(value_type); |
| std::unique_ptr<StunAttribute> attr( |
| StunAttribute::Create(value_type_enum, key, length, nullptr)); |
| if (!attr) { |
| return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, |
| "Failed to create attribute"); |
| } |
| if (attr->length() != length) { |
| return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, |
| "Inconsistent attribute length"); |
| } |
| if (!attr->Read(&buf)) { |
| return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, |
| "Failed to read attribute content"); |
| } |
| attrs.push_back(std::move(attr)); |
| } |
| |
| // The first attribute should be the version... |
| if (attrs.empty()) { |
| return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, |
| "Empty delta!"); |
| } |
| |
| if (attrs[0]->type() != kVersionKey || |
| attrs[0]->value_type() != STUN_VALUE_UINT64) { |
| return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER, |
| "Missing version!"); |
| } |
| |
| uint64_t version_in_delta = |
| reinterpret_cast<const StunUInt64Attribute*>(attrs[0].get())->value(); |
| attrs.pop_front(); |
| |
| return std::make_pair(std::max(version_in_delta, version_in_delta), |
| std::move(attrs)); |
| } |
| |
| // Apply a delta return an StunUInt64Attribute to ack the update. |
| webrtc::RTCErrorOr< |
| std::pair<std::unique_ptr<StunUInt64Attribute>, std::vector<uint16_t>>> |
| StunDictionaryView::ApplyDelta(const StunByteStringAttribute& delta) { |
| auto parsed_delta = ParseDelta(delta); |
| if (!parsed_delta.ok()) { |
| return webrtc::RTCError(parsed_delta.error()); |
| } |
| |
| uint64_t version_in_delta = parsed_delta.value().first; |
| |
| // Check that update does not overflow max_bytes_stored_. |
| int new_bytes_stored = bytes_stored_; |
| for (auto& attr : parsed_delta.value().second) { |
| auto old_version = version_per_key_.find(attr->type()); |
| if (old_version == version_per_key_.end() || |
| version_in_delta > old_version->second) { |
| size_t new_length = attr->length(); |
| size_t old_length = GetLength(attr->type()); |
| if (old_version == version_per_key_.end()) { |
| new_length += sizeof(int64_t); |
| } |
| |
| new_bytes_stored = new_bytes_stored + new_length - old_length; |
| if (new_bytes_stored <= 0) { |
| RTC_LOG(LS_WARNING) |
| << "attr: " << attr->type() << " old_length: " << old_length |
| << " new_length: " << new_length |
| << " bytes_stored_: " << bytes_stored_ |
| << " new_bytes_stored: " << new_bytes_stored; |
| return webrtc::RTCError(webrtc::RTCErrorType::INVALID_PARAMETER); |
| } |
| if (new_bytes_stored > max_bytes_stored_) { |
| RTC_LOG(LS_INFO) << "attr: " << attr->type() |
| << " old_length: " << old_length |
| << " new_length: " << new_length |
| << " bytes_stored_: " << bytes_stored_ |
| << " new_bytes_stored: " << new_bytes_stored; |
| } |
| } |
| } |
| if (new_bytes_stored > max_bytes_stored_) { |
| RTC_LOG(LS_INFO) << " bytes_stored_: " << bytes_stored_ |
| << " new_bytes_stored: " << new_bytes_stored; |
| return webrtc::RTCError(webrtc::RTCErrorType::RESOURCE_EXHAUSTED); |
| } |
| |
| // Apply the update. |
| std::vector<uint16_t> keys; |
| for (auto& attr : parsed_delta.value().second) { |
| if (version_in_delta > version_per_key_[attr->type()]) { |
| version_per_key_[attr->type()] = version_in_delta; |
| keys.push_back(attr->type()); |
| if (attr->value_type() == STUN_VALUE_BYTE_STRING && attr->length() == 0) { |
| attrs_.erase(attr->type()); |
| } else { |
| int attribute_type = attr->type(); |
| attrs_[attribute_type] = std::move(attr); |
| } |
| } |
| } |
| bytes_stored_ = new_bytes_stored; |
| |
| return std::make_pair(std::make_unique<StunUInt64Attribute>( |
| STUN_ATTR_GOOG_DELTA_ACK, version_in_delta), |
| std::move(keys)); |
| } |
| |
| size_t StunDictionaryView::GetLength(int key) const { |
| auto attr = GetOrNull(key); |
| if (attr != nullptr) { |
| return attr->length(); |
| } |
| return 0; |
| } |
| |
| void StunDictionaryWriter::Disable() { |
| disabled_ = true; |
| } |
| |
| void StunDictionaryWriter::Delete(int key) { |
| if (disabled_) { |
| return; |
| } |
| |
| if (dictionary_) { |
| if (dictionary_->attrs_.find(key) == dictionary_->attrs_.end()) { |
| return; |
| } |
| } |
| |
| // remove any pending updates. |
| pending_.erase( |
| std::remove_if(pending_.begin(), pending_.end(), |
| [key](const auto& p) { return p.second->type() == key; }), |
| pending_.end()); |
| |
| // Create tombstone. |
| auto tombstone = std::make_unique<StunByteStringAttribute>(key); |
| |
| // add a pending entry. |
| pending_.push_back(std::make_pair(++version_, tombstone.get())); |
| |
| // store the tombstone. |
| tombstones_[key] = std::move(tombstone); |
| |
| if (dictionary_) { |
| // remove value |
| dictionary_->attrs_.erase(key); |
| } |
| } |
| |
| void StunDictionaryWriter::Set(std::unique_ptr<StunAttribute> attr) { |
| if (disabled_) { |
| return; |
| } |
| int key = attr->type(); |
| // remove any pending updates. |
| pending_.erase( |
| std::remove_if(pending_.begin(), pending_.end(), |
| [key](const auto& p) { return p.second->type() == key; }), |
| pending_.end()); |
| |
| // remove any existing key. |
| tombstones_.erase(key); |
| |
| // create pending entry. |
| pending_.push_back(std::make_pair(++version_, attr.get())); |
| |
| if (dictionary_) { |
| // store attribute. |
| dictionary_->attrs_[key] = std::move(attr); |
| } |
| } |
| |
| // Create an StunByteStringAttribute containing the pending (e.g not ack:ed) |
| // modifications. |
| std::unique_ptr<StunByteStringAttribute> StunDictionaryWriter::CreateDelta() { |
| if (disabled_) { |
| return nullptr; |
| } |
| if (pending_.empty()) { |
| return nullptr; |
| } |
| |
| rtc::ByteBufferWriter buf; |
| buf.WriteUInt16(StunDictionaryView::kDeltaMagic); // 0,1 |
| buf.WriteUInt16(StunDictionaryView::kDeltaVersion); // 2,3 |
| |
| // max version in Delta. |
| buf.WriteUInt16(StunDictionaryView::kVersionKey); // 4,5 |
| buf.WriteUInt16(8); // 6,7 |
| buf.WriteUInt16(STUN_VALUE_UINT64); // 8,9 |
| buf.WriteUInt64(pending_.back().first); // 10-17 |
| // attributes |
| for (const auto& attr : pending_) { |
| buf.WriteUInt16(attr.second->type()); |
| buf.WriteUInt16(static_cast<uint16_t>(attr.second->length())); |
| buf.WriteUInt16(attr.second->value_type()); |
| if (!attr.second->Write(&buf)) { |
| RTC_LOG(LS_ERROR) << "Failed to write key: " << attr.second->type(); |
| return nullptr; |
| } |
| } |
| return std::make_unique<StunByteStringAttribute>(STUN_ATTR_GOOG_DELTA, |
| buf.Data(), buf.Length()); |
| } |
| |
| // Apply a delta ack, i.e prune list of pending changes. |
| void StunDictionaryWriter::ApplyDeltaAck(const StunUInt64Attribute& ack) { |
| uint64_t acked_version = ack.value(); |
| auto entries_to_remove = std::remove_if( |
| pending_.begin(), pending_.end(), |
| [acked_version](const auto& p) { return p.first <= acked_version; }); |
| |
| // remove tombstones. |
| for (auto it = entries_to_remove; it != pending_.end(); ++it) { |
| tombstones_.erase((*it).second->type()); |
| } |
| pending_.erase(entries_to_remove, pending_.end()); |
| } |
| |
| // Check if a key has a pending change (i.e a change |
| // that has not been acked). |
| bool StunDictionaryWriter::Pending(int key) const { |
| for (const auto& attr : pending_) { |
| if (attr.second->type() == key) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| int StunDictionaryWriter::Pending() const { |
| return pending_.size(); |
| } |
| |
| } // namespace cricket |