blob: 1aa996c4a89e228dafdaeb9f9a80bd75a23f8843 [file] [log] [blame]
/*
* Copyright 2018 The WebRTC Project Authors. All rights reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "p2p/base/mdns_message.h"
#include "rtc_base/logging.h"
#include "rtc_base/net_helpers.h"
#include "rtc_base/string_encode.h"
namespace webrtc {
namespace {
// RFC 1035, Section 4.1.1.
//
// QR bit.
constexpr uint16_t kMdnsFlagMaskQueryOrResponse = 0x8000;
// AA bit.
constexpr uint16_t kMdnsFlagMaskAuthoritative = 0x0400;
// RFC 1035, Section 4.1.2, QCLASS and RFC 6762, Section 18.12, repurposing of
// top bit of QCLASS as the unicast response bit.
constexpr uint16_t kMdnsQClassMaskUnicastResponse = 0x8000;
// Size of the fixed message header: six 16-bit fields (id, flags, qdcount,
// ancount, nscount, arcount). Compression pointers are not allowed to point
// into this region.
constexpr size_t kMdnsHeaderSizeBytes = 12;
bool ReadDomainName(MessageBufferReader* buf, std::string* name) {
size_t name_start_pos = buf->CurrentOffset();
uint8_t label_length;
if (!buf->ReadUInt8(&label_length)) {
return false;
}
// RFC 1035, Section 4.1.4.
//
// If the first two bits of the length octet are ones, the name is compressed
// and the rest six bits with the next octet denotes its position in the
// message by the offset from the start of the message.
auto is_pointer = [](uint8_t octet) {
return (octet & 0x80) && (octet & 0x40);
};
while (label_length && !is_pointer(label_length)) {
// RFC 1035, Section 2.3.1, labels are restricted to 63 octets or less.
if (label_length > 63) {
return false;
}
std::string label;
if (!buf->ReadString(&label, label_length)) {
return false;
}
(*name) += label + ".";
if (!buf->ReadUInt8(&label_length)) {
return false;
}
}
if (is_pointer(label_length)) {
uint8_t next_octet;
if (!buf->ReadUInt8(&next_octet)) {
return false;
}
size_t pos_jump_to = ((label_length & 0x3f) << 8) | next_octet;
// A legitimate pointer only refers to a prior occurrence of the same name,
// and we should only move strictly backward to a prior name field after the
// header.
if (pos_jump_to >= name_start_pos || pos_jump_to < kMdnsHeaderSizeBytes) {
return false;
}
MessageBufferReader new_buf(buf->MessageData(), buf->MessageLength());
if (!new_buf.Consume(pos_jump_to)) {
return false;
}
return ReadDomainName(&new_buf, name);
}
return true;
}
// Serializes `name` as a sequence of length-prefixed labels followed by the
// zero-length root label. Labels are obtained by tokenizing `name` on '.'.
// NOTE(review): label lengths are narrowed to a single octet without
// validation; callers are presumably expected to pass labels of at most 63
// octets - confirm.
void WriteDomainName(rtc::ByteBufferWriter* buf, const std::string& name) {
  std::vector<std::string> parts;
  rtc::tokenize(name, '.', &parts);
  for (size_t i = 0; i < parts.size(); ++i) {
    const std::string& part = parts[i];
    buf->WriteUInt8(static_cast<uint8_t>(part.length()));
    buf->WriteString(part);
  }
  // Root label terminates the name.
  buf->WriteUInt8(0);
}
} // namespace
// Sets the QR bit of the header flags: cleared for a query, set for a
// response.
void MdnsHeader::SetQueryOrResponse(bool is_query) {
  if (!is_query) {
    flags |= kMdnsFlagMaskQueryOrResponse;
    return;
  }
  flags &= ~kMdnsFlagMaskQueryOrResponse;
}
// Sets or clears the AA (authoritative answer) bit of the header flags.
void MdnsHeader::SetAuthoritative(bool is_authoritative) {
  if (!is_authoritative) {
    flags &= ~kMdnsFlagMaskAuthoritative;
    return;
  }
  flags |= kMdnsFlagMaskAuthoritative;
}
// Returns true if the AA (authoritative answer) bit is set in the flags.
bool MdnsHeader::IsAuthoritative() const {
  const bool aa_bit_set = (flags & kMdnsFlagMaskAuthoritative) != 0;
  return aa_bit_set;
}
// Reads the six 16-bit header fields from `buf` in wire order: id, flags,
// qdcount, ancount, nscount and arcount. Stops at the first field that cannot
// be read, logs an error and returns false.
bool MdnsHeader::Read(MessageBufferReader* buf) {
  uint16_t* const fields_in_wire_order[] = {&id,      &flags,   &qdcount,
                                            &ancount, &nscount, &arcount};
  for (uint16_t* field : fields_in_wire_order) {
    if (!buf->ReadUInt16(field)) {
      RTC_LOG(LS_ERROR) << "Invalid mDNS header.";
      return false;
    }
  }
  return true;
}
// Appends the six 16-bit header fields to `buf` in wire order: id, flags,
// qdcount, ancount, nscount and arcount.
void MdnsHeader::Write(rtc::ByteBufferWriter* buf) const {
  const uint16_t fields_in_wire_order[] = {id,      flags,   qdcount,
                                           ancount, nscount, arcount};
  for (uint16_t field : fields_in_wire_order) {
    buf->WriteUInt16(field);
  }
}
// Returns true if the QR bit is clear, i.e. the message is a query rather
// than a response.
bool MdnsHeader::IsQuery() const {
  const bool qr_bit_clear = (flags & kMdnsFlagMaskQueryOrResponse) == 0;
  return qr_bit_clear;
}
// Default constructor, destructor and copy constructor are compiler-generated;
// they are defined out of line here rather than in the header.
MdnsSectionEntry::MdnsSectionEntry() = default;
MdnsSectionEntry::~MdnsSectionEntry() = default;
MdnsSectionEntry::MdnsSectionEntry(const MdnsSectionEntry& other) = default;
// Stores the numeric record type for `type`: 1 for A, 28 for AAAA. Any other
// value is a programming error and leaves the stored type untouched.
void MdnsSectionEntry::SetType(SectionEntryType type) {
  if (type == SectionEntryType::kA) {
    type_ = 1;
  } else if (type == SectionEntryType::kAAAA) {
    type_ = 28;
  } else {
    RTC_NOTREACHED();
  }
}
// Maps the stored numeric record type to the enum: 1 is A, 28 is AAAA, and
// everything else is reported as unsupported.
SectionEntryType MdnsSectionEntry::GetType() const {
  if (type_ == 1) {
    return SectionEntryType::kA;
  }
  if (type_ == 28) {
    return SectionEntryType::kAAAA;
  }
  return SectionEntryType::kUnsupported;
}
// Stores the numeric class for `cls`; only IN (1) is supported. Any other
// value is a programming error and leaves the stored class untouched.
// Note that this overwrites the whole 16-bit field, including its top bit.
void MdnsSectionEntry::SetClass(SectionEntryClass cls) {
  if (cls == SectionEntryClass::kIN) {
    class_ = 1;
    return;
  }
  RTC_NOTREACHED();
}
// Maps the stored numeric class to the enum; only IN (1) is supported.
//
// mDNS repurposes the top bit of the 16-bit class field (RFC 6762, Section
// 18.12 for the unicast-response bit in questions and Section 10.2 for the
// cache-flush bit in resource records), so that bit is masked off before the
// class value is interpreted. Without the mask an IN entry with the top bit
// set, e.g. after SetUnicastResponse(true), would be reported as unsupported.
SectionEntryClass MdnsSectionEntry::GetClass() const {
  constexpr uint16_t kClassValueMask = 0x7fff;
  switch (class_ & kClassValueMask) {
    case 1:
      return SectionEntryClass::kIN;
    default:
      return SectionEntryClass::kUnsupported;
  }
}
// Default constructor, copy constructor and destructor are compiler-generated;
// they are defined out of line here rather than in the header.
MdnsQuestion::MdnsQuestion() = default;
MdnsQuestion::MdnsQuestion(const MdnsQuestion& other) = default;
MdnsQuestion::~MdnsQuestion() = default;
// Parses one question entry from `buf`: the name, followed by the 16-bit type
// and class fields. Logs an error and returns false on the first part that
// cannot be read.
bool MdnsQuestion::Read(MessageBufferReader* buf) {
  const bool name_ok = ReadDomainName(buf, &name_);
  if (!name_ok) {
    RTC_LOG(LS_ERROR) << "Invalid name.";
    return false;
  }
  const bool type_and_class_ok =
      buf->ReadUInt16(&type_) && buf->ReadUInt16(&class_);
  if (!type_and_class_ok) {
    RTC_LOG(LS_ERROR) << "Invalid type and class.";
    return false;
  }
  return true;
}
// Appends this question to `buf`: the encoded name followed by the 16-bit
// type and class fields. Always returns true.
bool MdnsQuestion::Write(rtc::ByteBufferWriter* buf) const {
  WriteDomainName(buf, name_);
  const uint16_t type_then_class[] = {type_, class_};
  for (uint16_t field : type_then_class) {
    buf->WriteUInt16(field);
  }
  return true;
}
// Sets or clears the unicast-response bit, which is carried in the top bit of
// the class field of a question.
void MdnsQuestion::SetUnicastResponse(bool should_unicast) {
  if (!should_unicast) {
    class_ &= ~kMdnsQClassMaskUnicastResponse;
    return;
  }
  class_ |= kMdnsQClassMaskUnicastResponse;
}
// Returns true if the unicast-response bit (top bit of the class field) is
// set for this question.
bool MdnsQuestion::ShouldUnicastResponse() const {
  const bool unicast_bit_set = (class_ & kMdnsQClassMaskUnicastResponse) != 0;
  return unicast_bit_set;
}
// Default constructor, copy constructor and destructor are compiler-generated;
// they are defined out of line here rather than in the header.
MdnsResourceRecord::MdnsResourceRecord() = default;
MdnsResourceRecord::MdnsResourceRecord(const MdnsResourceRecord& other) =
default;
MdnsResourceRecord::~MdnsResourceRecord() = default;
// Parses one resource record from `buf`: name, 16-bit type, 16-bit class,
// 32-bit TTL, 16-bit RDLENGTH and the RDATA.
//
// Only A and AAAA records are supported. The RDLENGTH announced on the wire
// must match the fixed RDATA size of the record type (4 octets for A, 16 for
// AAAA); otherwise the record is rejected instead of being parsed with a
// length that disagrees with the message, which would desynchronize the
// parsing of everything that follows. Returns false on any failure.
bool MdnsResourceRecord::Read(MessageBufferReader* buf) {
  if (!ReadDomainName(buf, &name_)) {
    return false;
  }
  if (!buf->ReadUInt16(&type_) || !buf->ReadUInt16(&class_) ||
      !buf->ReadUInt32(&ttl_seconds_) || !buf->ReadUInt16(&rdlength_)) {
    return false;
  }
  constexpr uint16_t kARDataSizeBytes = 4;
  constexpr uint16_t kQuadARDataSizeBytes = 16;
  switch (GetType()) {
    case SectionEntryType::kA:
      return rdlength_ == kARDataSizeBytes && ReadARData(buf);
    case SectionEntryType::kAAAA:
      return rdlength_ == kQuadARDataSizeBytes && ReadQuadARData(buf);
    case SectionEntryType::kUnsupported:
      return false;
    default:
      RTC_NOTREACHED();
  }
  return false;
}
// Reads the RDATA of an A record into `rdata_`. Returns the result of the
// underlying read.
bool MdnsResourceRecord::ReadARData(MessageBufferReader* buf) {
  // A RDATA contains a 32-bit IPv4 address.
  constexpr size_t kIPv4AddressSizeBytes = 4;
  return buf->ReadString(&rdata_, kIPv4AddressSizeBytes);
}
// Reads the RDATA of an AAAA record into `rdata_`. Returns the result of the
// underlying read.
bool MdnsResourceRecord::ReadQuadARData(MessageBufferReader* buf) {
  // AAAA RDATA contains a 128-bit IPv6 address.
  constexpr size_t kIPv6AddressSizeBytes = 16;
  return buf->ReadString(&rdata_, kIPv6AddressSizeBytes);
}
// Appends this resource record to `buf`: name, 16-bit type, 16-bit class,
// 32-bit TTL, 16-bit RDLENGTH and the RDATA.
//
// Only A and AAAA records can be serialized. A record of unsupported type is
// rejected before anything is written, so a failed call leaves `buf`
// untouched instead of holding a truncated record.
bool MdnsResourceRecord::Write(rtc::ByteBufferWriter* buf) const {
  const SectionEntryType type = GetType();
  if (type == SectionEntryType::kUnsupported) {
    return false;
  }
  WriteDomainName(buf, name_);
  buf->WriteUInt16(type_);
  buf->WriteUInt16(class_);
  buf->WriteUInt32(ttl_seconds_);
  buf->WriteUInt16(rdlength_);
  switch (type) {
    case SectionEntryType::kA:
      WriteARData(buf);
      return true;
    case SectionEntryType::kAAAA:
      WriteQuadARData(buf);
      return true;
    case SectionEntryType::kUnsupported:
      // Already handled above.
    default:
      RTC_NOTREACHED();
  }
  return true;
}
// Appends the stored RDATA of an A record to `buf` as-is.
void MdnsResourceRecord::WriteARData(rtc::ByteBufferWriter* buf) const {
  const std::string& ipv4_rdata = rdata_;
  buf->WriteString(ipv4_rdata);
}
// Appends the stored RDATA of an AAAA record to `buf` as-is.
void MdnsResourceRecord::WriteQuadARData(rtc::ByteBufferWriter* buf) const {
  const std::string& ipv6_rdata = rdata_;
  buf->WriteString(ipv6_rdata);
}
// Stores `address` as the RDATA of this record in network byte order and
// updates RDLENGTH accordingly (4 for IPv4, 16 for IPv6). The record type is
// not modified. Returns false if the address family is neither AF_INET nor
// AF_INET6 or if the address cannot be converted.
bool MdnsResourceRecord::SetIPAddressInRecordData(
    const rtc::IPAddress& address) {
  const int af = address.family();
  if (af != AF_INET && af != AF_INET6) {
    return false;
  }
  // Large enough for the binary form of an IPv6 address.
  char out[16] = {0};
  // inet_pton reports success only with a return value of 1; 0 means the text
  // is not a valid address and -1 signals an error, so both must be rejected.
  if (rtc::inet_pton(af, address.ToString().c_str(), out) != 1) {
    return false;
  }
  rdlength_ = (af == AF_INET) ? 4 : 16;
  rdata_ = std::string(out, rdlength_);
  return true;
}
// Converts the stored RDATA into `address`. Succeeds only for A records
// holding exactly 4 octets and AAAA records holding exactly 16 octets.
//
// The RDATA size is checked against the record type (not merely against the
// set {4, 16}) because inet_ntop reads 16 octets for AF_INET6; an AAAA record
// holding only 4 octets would otherwise cause a read past the end of the
// payload.
bool MdnsResourceRecord::GetIPAddressFromRecordData(
    rtc::IPAddress* address) const {
  int af = 0;
  size_t expected_rdata_size = 0;
  switch (GetType()) {
    case SectionEntryType::kA:
      af = AF_INET;
      expected_rdata_size = 4;
      break;
    case SectionEntryType::kAAAA:
      af = AF_INET6;
      expected_rdata_size = 16;
      break;
    default:
      return false;
  }
  if (rdata_.size() != expected_rdata_size) {
    return false;
  }
  char out[INET6_ADDRSTRLEN] = {0};
  if (!rtc::inet_ntop(af, rdata_.data(), out, sizeof(out))) {
    return false;
  }
  return rtc::IPFromString(std::string(out), address);
}
// Default constructor and destructor are compiler-generated; they are defined
// out of line here rather than in the header.
MdnsMessage::MdnsMessage() = default;
MdnsMessage::~MdnsMessage() = default;
// Parses a complete message from `buf`, which must be positioned at offset 0:
// the header, then the question, answer, authority and additional sections
// with the entry counts announced in the header.
//
// Each section is cleared and then grown one entry at a time while parsing.
// The counts come straight from the (untrusted) header, so sizing the
// sections up front would let a tiny malformed packet force the allocation of
// tens of thousands of entries; growing incrementally bounds the memory by
// what the message actually contains. Clearing first also keeps a reused
// message from parsing into stale entries. Returns false as soon as any part
// fails to parse; the sections may then be partially filled.
bool MdnsMessage::Read(MessageBufferReader* buf) {
  RTC_DCHECK_EQ(0u, buf->CurrentOffset());
  if (!header_.Read(buf)) {
    return false;
  }
  auto read_question = [&buf](std::vector<MdnsQuestion>* section,
                              uint16_t count) {
    section->clear();
    for (size_t i = 0; i < count; ++i) {
      section->emplace_back();
      if (!section->back().Read(buf)) {
        return false;
      }
    }
    return true;
  };
  auto read_rr = [&buf](std::vector<MdnsResourceRecord>* section,
                        uint16_t count) {
    section->clear();
    for (size_t i = 0; i < count; ++i) {
      section->emplace_back();
      if (!section->back().Read(buf)) {
        return false;
      }
    }
    return true;
  };
  return read_question(&question_section_, header_.qdcount) &&
         read_rr(&answer_section_, header_.ancount) &&
         read_rr(&authority_section_, header_.nscount) &&
         read_rr(&additional_section_, header_.arcount);
}
// Appends the whole message to `buf`: header, questions, then the answer,
// authority and additional records, in that order. Returns false as soon as
// an entry fails to serialize; entries written before the failure remain in
// `buf`.
bool MdnsMessage::Write(rtc::ByteBufferWriter* buf) const {
  header_.Write(buf);
  for (const MdnsQuestion& question : question_section_) {
    if (!question.Write(buf)) {
      return false;
    }
  }
  const std::vector<MdnsResourceRecord>* const record_sections[] = {
      &answer_section_, &authority_section_, &additional_section_};
  for (const std::vector<MdnsResourceRecord>* section : record_sections) {
    for (const MdnsResourceRecord& record : *section) {
      if (!record.Write(buf)) {
        return false;
      }
    }
  }
  return true;
}
bool MdnsMessage::ShouldUnicastResponse() const {
bool should_unicast = false;
for (const auto& question : question_section_) {
should_unicast |= question.ShouldUnicastResponse();
}
return should_unicast;
}
// Appends a copy of `question` to the question section and updates the
// header's question count to the new section size (narrowed to 16 bits).
void MdnsMessage::AddQuestion(const MdnsQuestion& question) {
  question_section_.push_back(question);
  const size_t question_count = question_section_.size();
  header_.qdcount = static_cast<uint16_t>(question_count);
}
// Appends a copy of `answer` to the answer section and updates the header's
// answer count to the new section size (narrowed to 16 bits).
void MdnsMessage::AddAnswerRecord(const MdnsResourceRecord& answer) {
  answer_section_.push_back(answer);
  const size_t answer_count = answer_section_.size();
  header_.ancount = static_cast<uint16_t>(answer_count);
}
} // namespace webrtc