blob: 61af6f39f29c898e3c2a0db21ac5097c7381c60b [file] [log] [blame]
/*
* Copyright 2018 The WebRTC Project Authors. All rights reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "p2p/base/mdns_message.h"
#include "rtc_base/logging.h"
#include "rtc_base/nethelpers.h"
#include "rtc_base/stringencode.h"
namespace webrtc {
namespace {
// Header flag masks, RFC 1035, Section 4.1.1.
//
// QR bit: clear for a query, set for a response.
constexpr uint16_t kMDnsFlagMaskQueryOrResponse = 0x8000;
// AA (Authoritative Answer) bit.
constexpr uint16_t kMDnsFlagMaskAuthoritative = 0x0400;
// RFC 1035, Section 4.1.2, QCLASS and RFC 6762, Section 18.12, repurposing of
// the top bit of QCLASS as the unicast-response (QU) bit.
constexpr uint16_t kMDnsQClassMaskUnicastResponse = 0x8000;
// Size of the fixed message header (six 16-bit fields); the first name in a
// message can start no earlier than this offset.
constexpr size_t kMDnsHeaderSizeBytes = 12;
bool ReadDomainName(MessageBufferReader* buf, std::string* name) {
size_t name_start_pos = buf->CurrentOffset();
uint8_t label_length;
if (!buf->ReadUInt8(&label_length)) {
return false;
}
// RFC 1035, Section 4.1.4.
//
// If the first two bits of the length octet are ones, the name is compressed
// and the rest six bits with the next octet denotes its position in the
// message by the offset from the start of the message.
auto is_pointer = [](uint8_t octet) {
return (octet & 0x80) && (octet & 0x40);
};
while (label_length && !is_pointer(label_length)) {
// RFC 1035, Section 2.3.1, labels are restricted to 63 octets or less.
if (label_length > 63) {
return false;
}
std::string label;
if (!buf->ReadString(&label, label_length)) {
return false;
}
(*name) += label + ".";
if (!buf->ReadUInt8(&label_length)) {
return false;
}
}
if (is_pointer(label_length)) {
uint8_t next_octet;
if (!buf->ReadUInt8(&next_octet)) {
return false;
}
size_t pos_jump_to = ((label_length & 0x3f) << 8) | next_octet;
// A legitimate pointer only refers to a prior occurrence of the same name,
// and we should only move strictly backward to a prior name field after the
// header.
if (pos_jump_to >= name_start_pos || pos_jump_to < kMDnsHeaderSizeBytes) {
return false;
}
MessageBufferReader new_buf(buf->MessageData(), buf->MessageLength());
if (!new_buf.Consume(pos_jump_to)) {
return false;
}
return ReadDomainName(&new_buf, name);
}
return true;
}
void WriteDomainName(rtc::ByteBufferWriter* buf, const std::string& name) {
std::vector<std::string> labels;
rtc::tokenize(name, '.', &labels);
for (const auto& label : labels) {
buf->WriteUInt8(label.length());
buf->WriteString(label);
}
buf->WriteUInt8(0);
}
} // namespace
// Marks the message as a query (QR bit clear) or a response (QR bit set),
// leaving every other flag bit unchanged.
void MDnsHeader::SetQueryOrResponse(bool is_query) {
  const uint16_t flags_without_qr =
      static_cast<uint16_t>(flags & ~kMDnsFlagMaskQueryOrResponse);
  flags = is_query ? flags_without_qr
                   : static_cast<uint16_t>(flags_without_qr |
                                           kMDnsFlagMaskQueryOrResponse);
}
// Sets or clears the AA (authoritative answer) bit; other flag bits are kept.
void MDnsHeader::SetAuthoritative(bool is_authoritative) {
  // Start from the flags with AA cleared, then re-add it if requested.
  flags = static_cast<uint16_t>(flags & ~kMDnsFlagMaskAuthoritative);
  if (is_authoritative) {
    flags = static_cast<uint16_t>(flags | kMDnsFlagMaskAuthoritative);
  }
}
// True when the AA (authoritative answer) bit is set in the header flags.
bool MDnsHeader::IsAuthoritative() const {
  const bool aa_bit_set = (flags & kMDnsFlagMaskAuthoritative) != 0;
  return aa_bit_set;
}
// Reads the six 16-bit header fields in wire order (RFC 1035, Section 4.1.1).
// Stops at the first field that cannot be read, logs, and returns false.
bool MDnsHeader::Read(MessageBufferReader* buf) {
  uint16_t* const header_fields[] = {&id,      &flags,   &qdcount,
                                     &ancount, &nscount, &arcount};
  for (uint16_t* field : header_fields) {
    if (!buf->ReadUInt16(field)) {
      RTC_LOG(LS_ERROR) << "Invalid mDNS header.";
      return false;
    }
  }
  return true;
}
// Appends the six 16-bit header fields to |buf| in wire order
// (ID, flags, QDCOUNT, ANCOUNT, NSCOUNT, ARCOUNT).
void MDnsHeader::Write(rtc::ByteBufferWriter* buf) const {
  const uint16_t header_fields[] = {id,      flags,   qdcount,
                                    ancount, nscount, arcount};
  for (uint16_t value : header_fields) {
    buf->WriteUInt16(value);
  }
}
// A message is a query when the QR bit is clear.
bool MDnsHeader::IsQuery() const {
  const bool qr_bit_set = (flags & kMDnsFlagMaskQueryOrResponse) != 0;
  return !qr_bit_set;
}
// Special members are declared in the class definition and defaulted here,
// out of line (presumably per Chromium's out-of-line constructor style rule).
MDnsSectionEntry::MDnsSectionEntry() = default;
MDnsSectionEntry::~MDnsSectionEntry() = default;
MDnsSectionEntry::MDnsSectionEntry(const MDnsSectionEntry& other) = default;
// Stores the wire TYPE code for |type|: A is 1 (RFC 1035) and AAAA is 28
// (RFC 3596). Any other value is a programming error and leaves type_ as is.
void MDnsSectionEntry::SetType(SectionEntryType type) {
  if (type == SectionEntryType::kA) {
    type_ = 1;
  } else if (type == SectionEntryType::kAAAA) {
    type_ = 28;
  } else {
    RTC_NOTREACHED();
  }
}
// Maps the raw TYPE code back to the enum. Codes other than A (1) and AAAA
// (28) are reported as kUnsupported.
SectionEntryType MDnsSectionEntry::GetType() const {
  if (type_ == 1) {
    return SectionEntryType::kA;
  }
  if (type_ == 28) {
    return SectionEntryType::kAAAA;
  }
  return SectionEntryType::kUnsupported;
}
// Stores the wire CLASS code for |cls| (IN is 1) in the low 15 bits of class_.
//
// In mDNS the top bit of the class field is not part of the class value. It is
// the unicast-response (QU) bit in questions (RFC 6762, Section 18.12) and the
// cache-flush bit in resource records (RFC 6762, Section 10.2). That bit is
// preserved, so SetClass() and MDnsQuestion::SetUnicastResponse() give the
// same result in either call order; previously SetClass() silently cleared it.
void MDnsSectionEntry::SetClass(SectionEntryClass cls) {
  constexpr uint16_t kTopBitMask = 0x8000;
  switch (cls) {
    case SectionEntryClass::kIN:
      class_ = static_cast<uint16_t>((class_ & kTopBitMask) | 1);
      return;
    default:
      RTC_NOTREACHED();
  }
}
// Maps the class field to the enum, ignoring the top bit.
//
// In mDNS the top bit is the QU bit in questions (RFC 6762, Section 18.12) and
// the cache-flush bit in resource records (RFC 6762, Section 10.2). It is not
// part of the class value, so e.g. a QU question for IN (0x8001) reports kIN
// rather than kUnsupported.
SectionEntryClass MDnsSectionEntry::GetClass() const {
  constexpr uint16_t kClassValueMask = 0x7fff;
  switch (class_ & kClassValueMask) {
    case 1:
      return SectionEntryClass::kIN;
    default:
      return SectionEntryClass::kUnsupported;
  }
}
// Special members are declared in the class definition and defaulted here,
// out of line.
MDnsQuestion::MDnsQuestion() = default;
MDnsQuestion::MDnsQuestion(const MDnsQuestion& other) = default;
MDnsQuestion::~MDnsQuestion() = default;
// Parses one question entry (QNAME, QTYPE, QCLASS; RFC 1035, Section 4.1.2)
// at the current offset of |buf|. Logs and returns false on malformed input.
bool MDnsQuestion::Read(MessageBufferReader* buf) {
  // ReadDomainName() appends to its output. Reset name_ so that reading into a
  // reused question does not concatenate the old and new names.
  name_.clear();
  if (!ReadDomainName(buf, &name_)) {
    RTC_LOG(LS_ERROR) << "Invalid name.";
    return false;
  }
  if (!buf->ReadUInt16(&type_) || !buf->ReadUInt16(&class_)) {
    RTC_LOG(LS_ERROR) << "Invalid type and class.";
    return false;
  }
  return true;
}
// Serializes the question as QNAME (uncompressed) followed by QTYPE and
// QCLASS. Always reports success.
bool MDnsQuestion::Write(rtc::ByteBufferWriter* buf) const {
  WriteDomainName(buf, name_);
  const uint16_t qtype = type_;
  const uint16_t qclass = class_;
  buf->WriteUInt16(qtype);
  buf->WriteUInt16(qclass);
  return true;
}
// Sets or clears the QU bit (top bit of QCLASS), keeping the class value in
// the low 15 bits unchanged.
void MDnsQuestion::SetUnicastResponse(bool should_unicast) {
  const uint16_t class_value =
      static_cast<uint16_t>(class_ & ~kMDnsQClassMaskUnicastResponse);
  class_ = should_unicast ? static_cast<uint16_t>(
                                class_value | kMDnsQClassMaskUnicastResponse)
                          : class_value;
}
// True when this question has its QU (unicast-response) bit set.
bool MDnsQuestion::ShouldUnicastResponse() const {
  const bool qu_bit_set = (class_ & kMDnsQClassMaskUnicastResponse) != 0;
  return qu_bit_set;
}
// Special members are declared in the class definition and defaulted here,
// out of line.
MDnsResourceRecord::MDnsResourceRecord() = default;
MDnsResourceRecord::MDnsResourceRecord(const MDnsResourceRecord& other) =
default;
MDnsResourceRecord::~MDnsResourceRecord() = default;
// Parses one resource record (RFC 1035, Section 4.1.3) at the current offset
// of |buf|. Only A and AAAA records are supported; any other TYPE, or an
// RDLENGTH that does not match the fixed RDATA size of the TYPE, makes this
// return false.
bool MDnsResourceRecord::Read(MessageBufferReader* buf) {
  // Start from a clean record. ReadDomainName() appends to name_, and rdata_
  // is cleared too in case ReadString() appends rather than assigns.
  name_.clear();
  rdata_.clear();
  if (!ReadDomainName(buf, &name_)) {
    return false;
  }
  if (!buf->ReadUInt16(&type_) || !buf->ReadUInt16(&class_) ||
      !buf->ReadUInt32(&ttl_seconds_) || !buf->ReadUInt16(&rdlength_)) {
    return false;
  }
  // The RDATA readers consume a fixed number of bytes. If RDLENGTH disagreed,
  // every following record would be parsed from the wrong offset.
  switch (GetType()) {
    case SectionEntryType::kA:
      // RFC 1035, Section 3.4.1: a 32-bit IPv4 address.
      return rdlength_ == 4 && ReadARData(buf);
    case SectionEntryType::kAAAA:
      // RFC 3596, Section 2.2: a 128-bit IPv6 address.
      return rdlength_ == 16 && ReadQuadARData(buf);
    case SectionEntryType::kUnsupported:
      return false;
    default:
      RTC_NOTREACHED();
  }
  return false;
}
// Reads A RDATA into rdata_. The payload is one IPv4 address, i.e. 4 raw
// octets.
bool MDnsResourceRecord::ReadARData(MessageBufferReader* buf) {
  constexpr size_t kIPv4AddressOctets = 4;
  return buf->ReadString(&rdata_, kIPv4AddressOctets);
}
// Reads AAAA RDATA into rdata_. The payload is one IPv6 address, i.e. 16 raw
// octets.
bool MDnsResourceRecord::ReadQuadARData(MessageBufferReader* buf) {
  constexpr size_t kIPv6AddressOctets = 16;
  return buf->ReadString(&rdata_, kIPv6AddressOctets);
}
// Serializes the record as NAME (uncompressed), TYPE, CLASS, TTL, RDLENGTH and
// RDATA. Returns false, without writing anything, for records that cannot be
// serialized: a TYPE other than A/AAAA, or an RDLENGTH that differs from the
// number of RDATA bytes held in rdata_.
bool MDnsResourceRecord::Write(rtc::ByteBufferWriter* buf) const {
  // Validate before emitting anything, so a failure never leaves a partial
  // record in |buf| (previously the header fields were already written).
  const SectionEntryType type = GetType();
  if (type != SectionEntryType::kA && type != SectionEntryType::kAAAA) {
    return false;
  }
  // RDLENGTH is written from rdlength_, but the RDATA bytes come from rdata_.
  // A mismatch would make readers misparse everything after this record.
  if (rdata_.size() != rdlength_) {
    return false;
  }
  WriteDomainName(buf, name_);
  buf->WriteUInt16(type_);
  buf->WriteUInt16(class_);
  buf->WriteUInt32(ttl_seconds_);
  buf->WriteUInt16(rdlength_);
  if (type == SectionEntryType::kA) {
    WriteARData(buf);
  } else {
    WriteQuadARData(buf);
  }
  return true;
}
// Emits the A RDATA bytes held in rdata_ verbatim.
void MDnsResourceRecord::WriteARData(rtc::ByteBufferWriter* buf) const {
  const std::string& ipv4_octets = rdata_;
  buf->WriteString(ipv4_octets);
}
// Emits the AAAA RDATA bytes held in rdata_ verbatim.
void MDnsResourceRecord::WriteQuadARData(rtc::ByteBufferWriter* buf) const {
  const std::string& ipv6_octets = rdata_;
  buf->WriteString(ipv6_octets);
}
// Stores |address| as raw RDATA: 4 octets for IPv4, 16 for IPv6. rdlength_ is
// updated to match. The record TYPE is not changed; callers are expected to
// set A/AAAA to agree with the address family. Returns false for other
// families, or if the address cannot be converted.
bool MDnsResourceRecord::SetIPAddressInRecordData(
    const rtc::IPAddress& address) {
  int af = address.family();
  if (af != AF_INET && af != AF_INET6) {
    return false;
  }
  char out[16] = {0};
  // POSIX inet_pton() returns 1 on success, 0 for an unparsable string and -1
  // on error, so only 1 means |out| was filled. The old `!inet_pton(...)` test
  // treated -1 as success. (Assumes rtc::inet_pton keeps these semantics.)
  if (rtc::inet_pton(af, address.ToString().c_str(), out) != 1) {
    return false;
  }
  rdlength_ = (af == AF_INET) ? 4 : 16;
  rdata_ = std::string(out, rdlength_);
  return true;
}
// Converts the RDATA of an A or AAAA record into |address|. Returns false for
// other TYPEs, or when the RDATA size does not match the TYPE (4 octets for A,
// 16 for AAAA).
bool MDnsResourceRecord::GetIPAddressFromRecordData(
    rtc::IPAddress* address) const {
  int af;
  size_t expected_size;
  switch (GetType()) {
    case SectionEntryType::kA:
      af = AF_INET;
      expected_size = 4;
      break;
    case SectionEntryType::kAAAA:
      af = AF_INET6;
      expected_size = 16;
      break;
    default:
      return false;
  }
  // The size must match the TYPE exactly: inet_ntop() reads 16 bytes for
  // AF_INET6. Previously an AAAA record holding 4 bytes (e.g. after setting
  // an IPv4 address) passed the old "4 or 16" check and was read past its end.
  if (rdata_.size() != expected_size) {
    return false;
  }
  char out[INET6_ADDRSTRLEN] = {0};
  if (!rtc::inet_ntop(af, rdata_.data(), out, sizeof(out))) {
    return false;
  }
  return rtc::IPFromString(std::string(out), address);
}
// Special members are declared in the class definition and defaulted here,
// out of line.
MDnsMessage::MDnsMessage() = default;
MDnsMessage::~MDnsMessage() = default;
// Parses a complete mDNS message from the start of |buf|: the header, then
// QDCOUNT questions and ANCOUNT/NSCOUNT/ARCOUNT resource records. Returns
// false as soon as any part fails to parse; the sections may then hold only
// the entries parsed so far.
bool MDnsMessage::Read(MessageBufferReader* buf) {
  RTC_DCHECK_EQ(0u, buf->CurrentOffset());
  if (!header_.Read(buf)) {
    return false;
  }
  // Each section is rebuilt from scratch, for two reasons:
  //  * entries left over from an earlier Read() must not be reused, because
  //    their Read() would append to their existing names;
  //  * entries are added only as their bytes are actually parsed, so the
  //    untrusted header counts (up to 65535 each) cannot force large up-front
  //    allocations the way resize(count) did.
  auto read_question = [buf](std::vector<MDnsQuestion>* section,
                             uint16_t count) {
    section->clear();
    for (size_t i = 0; i < count; ++i) {
      section->emplace_back();
      if (!section->back().Read(buf)) {
        return false;
      }
    }
    return true;
  };
  auto read_rr = [buf](std::vector<MDnsResourceRecord>* section,
                       uint16_t count) {
    section->clear();
    for (size_t i = 0; i < count; ++i) {
      section->emplace_back();
      if (!section->back().Read(buf)) {
        return false;
      }
    }
    return true;
  };
  return read_question(&question_section_, header_.qdcount) &&
         read_rr(&answer_section_, header_.ancount) &&
         read_rr(&authority_section_, header_.nscount) &&
         read_rr(&additional_section_, header_.arcount);
}
// Serializes the header followed by the question, answer, authority and
// additional sections, in that order. Returns false if any entry fails to
// serialize. Whatever was written before the failure stays in |buf|.
bool MDnsMessage::Write(rtc::ByteBufferWriter* buf) const {
  header_.Write(buf);
  // Iterate by const reference. Entries own std::string fields, so the old
  // by-value loops copied every question/record just to serialize it.
  auto write_rr = [buf](const std::vector<MDnsResourceRecord>& section) {
    for (const auto& rr : section) {
      if (!rr.Write(buf)) {
        return false;
      }
    }
    return true;
  };
  for (const auto& question : question_section_) {
    if (!question.Write(buf)) {
      return false;
    }
  }
  return write_rr(answer_section_) && write_rr(authority_section_) &&
         write_rr(additional_section_);
}
bool MDnsMessage::ShouldUnicastResponse() const {
bool should_unicast = false;
for (const auto& question : question_section_) {
should_unicast |= question.ShouldUnicastResponse();
}
return should_unicast;
}
// Appends a copy of |question| and keeps the header's QDCOUNT equal to the
// section size (narrowed to the 16-bit wire field).
void MDnsMessage::AddQuestion(const MDnsQuestion& question) {
  question_section_.push_back(question);
  const size_t question_count = question_section_.size();
  header_.qdcount = static_cast<uint16_t>(question_count);
}
// Appends a copy of |answer| and keeps the header's ANCOUNT equal to the
// section size (narrowed to the 16-bit wire field).
void MDnsMessage::AddAnswerRecord(const MDnsResourceRecord& answer) {
  answer_section_.push_back(answer);
  const size_t answer_count = answer_section_.size();
  header_.ancount = static_cast<uint16_t>(answer_count);
}
} // namespace webrtc