dcsctp: Ensure packet size doesn't exceed MTU

Due to a previous refactoring, the SCTP packet header is only added when
the first chunk is written. This wasn't reflected in the
`bytes_remaining`, which made it add more than could fit within the MTU.

Additionally, the maximum packet size must be even divisible by four as
padding will be added to chunks that are not even divisble by four (up
to three bytes of padding). So compensate for that.

Bug: webrtc:12614
Change-Id: I6b57dfbf88d1fcfcbf443038915dd180e796191a
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/215145
Reviewed-by: Tommi <tommi@webrtc.org>
Reviewed-by: Florent Castelli <orphis@webrtc.org>
Commit-Queue: Victor Boivie <boivie@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#33760}
diff --git a/net/dcsctp/common/math.h b/net/dcsctp/common/math.h
index ee161d2..12f690e 100644
--- a/net/dcsctp/common/math.h
+++ b/net/dcsctp/common/math.h
@@ -16,7 +16,19 @@
 // used to e.g. pad chunks or parameters to an even 32-bit offset.
 template <typename IntType>
 IntType RoundUpTo4(IntType val) {
-  return (val + 3) & -4;
+  return (val + 3) & ~3;
+}
+
+// Similarly, rounds down `val` to the nearest value that is divisible by four.
+template <typename IntType>
+IntType RoundDownTo4(IntType val) {
+  return val & ~3;
+}
+
+// Returns true if `val` is divisible by four.
+template <typename IntType>
+bool IsDivisibleBy4(IntType val) {
+  return (val & 3) == 0;
 }
 
 }  // namespace dcsctp
diff --git a/net/dcsctp/common/math_test.cc b/net/dcsctp/common/math_test.cc
index 902aefa..f95dfbd 100644
--- a/net/dcsctp/common/math_test.cc
+++ b/net/dcsctp/common/math_test.cc
@@ -15,17 +15,101 @@
 namespace {
 
 TEST(MathUtilTest, CanRoundUpTo4) {
-  EXPECT_EQ(RoundUpTo4(0), 0);
-  EXPECT_EQ(RoundUpTo4(1), 4);
-  EXPECT_EQ(RoundUpTo4(2), 4);
-  EXPECT_EQ(RoundUpTo4(3), 4);
-  EXPECT_EQ(RoundUpTo4(4), 4);
-  EXPECT_EQ(RoundUpTo4(5), 8);
-  EXPECT_EQ(RoundUpTo4(6), 8);
-  EXPECT_EQ(RoundUpTo4(7), 8);
-  EXPECT_EQ(RoundUpTo4(8), 8);
-  EXPECT_EQ(RoundUpTo4(10000000000), 10000000000);
-  EXPECT_EQ(RoundUpTo4(10000000001), 10000000004);
+  // Signed numbers
+  EXPECT_EQ(RoundUpTo4(static_cast<int>(-5)), -4);
+  EXPECT_EQ(RoundUpTo4(static_cast<int>(-4)), -4);
+  EXPECT_EQ(RoundUpTo4(static_cast<int>(-3)), 0);
+  EXPECT_EQ(RoundUpTo4(static_cast<int>(-2)), 0);
+  EXPECT_EQ(RoundUpTo4(static_cast<int>(-1)), 0);
+  EXPECT_EQ(RoundUpTo4(static_cast<int>(0)), 0);
+  EXPECT_EQ(RoundUpTo4(static_cast<int>(1)), 4);
+  EXPECT_EQ(RoundUpTo4(static_cast<int>(2)), 4);
+  EXPECT_EQ(RoundUpTo4(static_cast<int>(3)), 4);
+  EXPECT_EQ(RoundUpTo4(static_cast<int>(4)), 4);
+  EXPECT_EQ(RoundUpTo4(static_cast<int>(5)), 8);
+  EXPECT_EQ(RoundUpTo4(static_cast<int>(6)), 8);
+  EXPECT_EQ(RoundUpTo4(static_cast<int>(7)), 8);
+  EXPECT_EQ(RoundUpTo4(static_cast<int>(8)), 8);
+  EXPECT_EQ(RoundUpTo4(static_cast<int64_t>(10000000000)), 10000000000);
+  EXPECT_EQ(RoundUpTo4(static_cast<int64_t>(10000000001)), 10000000004);
+
+  // Unsigned numbers
+  EXPECT_EQ(RoundUpTo4(static_cast<unsigned int>(0)), 0u);
+  EXPECT_EQ(RoundUpTo4(static_cast<unsigned int>(1)), 4u);
+  EXPECT_EQ(RoundUpTo4(static_cast<unsigned int>(2)), 4u);
+  EXPECT_EQ(RoundUpTo4(static_cast<unsigned int>(3)), 4u);
+  EXPECT_EQ(RoundUpTo4(static_cast<unsigned int>(4)), 4u);
+  EXPECT_EQ(RoundUpTo4(static_cast<unsigned int>(5)), 8u);
+  EXPECT_EQ(RoundUpTo4(static_cast<unsigned int>(6)), 8u);
+  EXPECT_EQ(RoundUpTo4(static_cast<unsigned int>(7)), 8u);
+  EXPECT_EQ(RoundUpTo4(static_cast<unsigned int>(8)), 8u);
+  EXPECT_EQ(RoundUpTo4(static_cast<uint64_t>(10000000000)), 10000000000u);
+  EXPECT_EQ(RoundUpTo4(static_cast<uint64_t>(10000000001)), 10000000004u);
+}
+
+TEST(MathUtilTest, CanRoundDownTo4) {
+  // Signed numbers
+  EXPECT_EQ(RoundDownTo4(static_cast<int>(-5)), -8);
+  EXPECT_EQ(RoundDownTo4(static_cast<int>(-4)), -4);
+  EXPECT_EQ(RoundDownTo4(static_cast<int>(-3)), -4);
+  EXPECT_EQ(RoundDownTo4(static_cast<int>(-2)), -4);
+  EXPECT_EQ(RoundDownTo4(static_cast<int>(-1)), -4);
+  EXPECT_EQ(RoundDownTo4(static_cast<int>(0)), 0);
+  EXPECT_EQ(RoundDownTo4(static_cast<int>(1)), 0);
+  EXPECT_EQ(RoundDownTo4(static_cast<int>(2)), 0);
+  EXPECT_EQ(RoundDownTo4(static_cast<int>(3)), 0);
+  EXPECT_EQ(RoundDownTo4(static_cast<int>(4)), 4);
+  EXPECT_EQ(RoundDownTo4(static_cast<int>(5)), 4);
+  EXPECT_EQ(RoundDownTo4(static_cast<int>(6)), 4);
+  EXPECT_EQ(RoundDownTo4(static_cast<int>(7)), 4);
+  EXPECT_EQ(RoundDownTo4(static_cast<int>(8)), 8);
+  EXPECT_EQ(RoundDownTo4(static_cast<int64_t>(10000000000)), 10000000000);
+  EXPECT_EQ(RoundDownTo4(static_cast<int64_t>(10000000001)), 10000000000);
+
+  // Unsigned numbers
+  EXPECT_EQ(RoundDownTo4(static_cast<unsigned int>(0)), 0u);
+  EXPECT_EQ(RoundDownTo4(static_cast<unsigned int>(1)), 0u);
+  EXPECT_EQ(RoundDownTo4(static_cast<unsigned int>(2)), 0u);
+  EXPECT_EQ(RoundDownTo4(static_cast<unsigned int>(3)), 0u);
+  EXPECT_EQ(RoundDownTo4(static_cast<unsigned int>(4)), 4u);
+  EXPECT_EQ(RoundDownTo4(static_cast<unsigned int>(5)), 4u);
+  EXPECT_EQ(RoundDownTo4(static_cast<unsigned int>(6)), 4u);
+  EXPECT_EQ(RoundDownTo4(static_cast<unsigned int>(7)), 4u);
+  EXPECT_EQ(RoundDownTo4(static_cast<unsigned int>(8)), 8u);
+  EXPECT_EQ(RoundDownTo4(static_cast<uint64_t>(10000000000)), 10000000000u);
+  EXPECT_EQ(RoundDownTo4(static_cast<uint64_t>(10000000001)), 10000000000u);
+}
+
+TEST(MathUtilTest, IsDivisibleBy4) {
+  // Signed numbers
+  EXPECT_EQ(IsDivisibleBy4(static_cast<int>(-4)), true);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<int>(-3)), false);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<int>(-2)), false);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<int>(-1)), false);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<int>(0)), true);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<int>(1)), false);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<int>(2)), false);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<int>(3)), false);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<int>(4)), true);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<int>(5)), false);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<int>(6)), false);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<int>(7)), false);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<int>(8)), true);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<int64_t>(10000000000)), true);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<int64_t>(10000000001)), false);
+
+  // Unsigned numbers
+  EXPECT_EQ(IsDivisibleBy4(static_cast<unsigned int>(0)), true);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<unsigned int>(1)), false);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<unsigned int>(2)), false);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<unsigned int>(3)), false);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<unsigned int>(4)), true);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<unsigned int>(5)), false);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<unsigned int>(6)), false);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<unsigned int>(7)), false);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<unsigned int>(8)), true);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<uint64_t>(10000000000)), true);
+  EXPECT_EQ(IsDivisibleBy4(static_cast<uint64_t>(10000000001)), false);
 }
 
 }  // namespace
diff --git a/net/dcsctp/packet/sctp_packet.cc b/net/dcsctp/packet/sctp_packet.cc
index 1e12367..da06ccf 100644
--- a/net/dcsctp/packet/sctp_packet.cc
+++ b/net/dcsctp/packet/sctp_packet.cc
@@ -52,11 +52,11 @@
     : verification_tag_(verification_tag),
       source_port_(options.local_port),
       dest_port_(options.remote_port),
-      max_mtu_(options.mtu) {}
+      max_packet_size_(RoundDownTo4(options.mtu)) {}
 
 SctpPacket::Builder& SctpPacket::Builder::Add(const Chunk& chunk) {
   if (out_.empty()) {
-    out_.reserve(max_mtu_);
+    out_.reserve(max_packet_size_);
     out_.resize(SctpPacket::kHeaderSize);
     BoundedByteWriter<kHeaderSize> buffer(out_);
     buffer.Store16<0>(source_port_);
@@ -64,14 +64,31 @@
     buffer.Store32<4>(*verification_tag_);
     // Checksum is at offset 8 - written when calling Build();
   }
+  RTC_DCHECK(IsDivisibleBy4(out_.size()));
+
   chunk.SerializeTo(out_);
   if (out_.size() % 4 != 0) {
     out_.resize(RoundUpTo4(out_.size()));
   }
 
+  RTC_DCHECK(out_.size() <= max_packet_size_)
+      << "Exceeded max size, data=" << out_.size()
+      << ", max_size=" << max_packet_size_;
   return *this;
 }
 
+size_t SctpPacket::Builder::bytes_remaining() const {
+  if (out_.empty()) {
+    // The packet header (CommonHeader) hasn't been written yet:
+    return max_packet_size_ - kHeaderSize;
+  } else if (out_.size() > max_packet_size_) {
+    RTC_DCHECK(false) << "Exceeded max size, data=" << out_.size()
+                      << ", max_size=" << max_packet_size_;
+    return 0;
+  }
+  return max_packet_size_ - out_.size();
+}
+
 std::vector<uint8_t> SctpPacket::Builder::Build() {
   std::vector<uint8_t> out;
   out_.swap(out);
@@ -80,6 +97,11 @@
     uint32_t crc = GenerateCrc32C(out);
     BoundedByteWriter<kHeaderSize>(out).Store32<8>(crc);
   }
+
+  RTC_DCHECK(out.size() <= max_packet_size_)
+      << "Exceeded max size, data=" << out.size()
+      << ", max_size=" << max_packet_size_;
+
   return out;
 }
 
diff --git a/net/dcsctp/packet/sctp_packet.h b/net/dcsctp/packet/sctp_packet.h
index 927b8db..2600caf 100644
--- a/net/dcsctp/packet/sctp_packet.h
+++ b/net/dcsctp/packet/sctp_packet.h
@@ -65,10 +65,9 @@
     // Adds a chunk to the to-be-built SCTP packet.
     Builder& Add(const Chunk& chunk);
 
-    // The number of bytes remaining in the packet, until the MTU is reached.
-    size_t bytes_remaining() const {
-      return out_.size() >= max_mtu_ ? 0 : max_mtu_ - out_.size();
-    }
+    // The number of bytes remaining in the packet for chunk storage until the
+    // packet reaches its maximum size.
+    size_t bytes_remaining() const;
 
     // Indicates if any packets have been added to the builder.
     bool empty() const { return out_.empty(); }
@@ -82,7 +81,9 @@
     VerificationTag verification_tag_;
     uint16_t source_port_;
     uint16_t dest_port_;
-    size_t max_mtu_;
+    // The maximum packet size is always even divisible by four, as chunks are
+    // always padded to a size even divisible by four.
+    size_t max_packet_size_;
     std::vector<uint8_t> out_;
   };
 
diff --git a/net/dcsctp/packet/sctp_packet_test.cc b/net/dcsctp/packet/sctp_packet_test.cc
index ece1b7b..7438315 100644
--- a/net/dcsctp/packet/sctp_packet_test.cc
+++ b/net/dcsctp/packet/sctp_packet_test.cc
@@ -15,6 +15,7 @@
 
 #include "api/array_view.h"
 #include "net/dcsctp/common/internal_types.h"
+#include "net/dcsctp/common/math.h"
 #include "net/dcsctp/packet/chunk/abort_chunk.h"
 #include "net/dcsctp/packet/chunk/cookie_ack_chunk.h"
 #include "net/dcsctp/packet/chunk/data_chunk.h"
@@ -24,6 +25,7 @@
 #include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h"
 #include "net/dcsctp/packet/parameter/parameter.h"
 #include "net/dcsctp/packet/tlv_trait.h"
+#include "net/dcsctp/public/dcsctp_options.h"
 #include "net/dcsctp/testing/testing_macros.h"
 #include "rtc_base/gunit.h"
 #include "test/gmock.h"
@@ -298,5 +300,43 @@
 
   EXPECT_FALSE(SctpPacket::Parse(data, true).has_value());
 }
+
+TEST(SctpPacketTest, ReturnsCorrectSpaceAvailableToStayWithinMTU) {
+  DcSctpOptions options;
+  options.mtu = 1191;
+
+  SctpPacket::Builder builder(VerificationTag(123), options);
+
+  // Chunks will be padded to an even 4 bytes, so the maximum packet size should
+  // be rounded down.
+  const size_t kMaxPacketSize = RoundDownTo4(options.mtu);
+  EXPECT_EQ(kMaxPacketSize, 1188u);
+
+  const size_t kSctpHeaderSize = 12;
+  EXPECT_EQ(builder.bytes_remaining(), kMaxPacketSize - kSctpHeaderSize);
+  EXPECT_EQ(builder.bytes_remaining(), 1176u);
+
+  // Add a smaller packet first.
+  DataChunk::Options data_options;
+
+  std::vector<uint8_t> payload1(183);
+  builder.Add(
+      DataChunk(TSN(1), StreamID(1), SSN(0), PPID(53), payload1, data_options));
+
+  size_t chunk1_size = RoundUpTo4(DataChunk::kHeaderSize + payload1.size());
+  EXPECT_EQ(builder.bytes_remaining(),
+            kMaxPacketSize - kSctpHeaderSize - chunk1_size);
+  EXPECT_EQ(builder.bytes_remaining(), 976u);  // Hand-calculated.
+
+  std::vector<uint8_t> payload2(957);
+  builder.Add(
+      DataChunk(TSN(1), StreamID(1), SSN(0), PPID(53), payload2, data_options));
+
+  size_t chunk2_size = RoundUpTo4(DataChunk::kHeaderSize + payload2.size());
+  EXPECT_EQ(builder.bytes_remaining(),
+            kMaxPacketSize - kSctpHeaderSize - chunk1_size - chunk2_size);
+  EXPECT_EQ(builder.bytes_remaining(), 0u);  // Hand-calculated.
+}
+
 }  // namespace
 }  // namespace dcsctp