Reapply "dcsctp: Negotiate zero checksum"

The handover state has been added with
commit daaa6ab5a8c74b87d9d6ded07342d8a2c50c73f7.

This reverts commit 014cbed9d2377ec0a0b15f2c48b21a562f770366.

Bug: webrtc:14997
Change-Id: Ie84f3184f3ea67aaa6438481634046ba18b497a6
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/320941
Reviewed-by: Florent Castelli <orphis@webrtc.org>
Reviewed-by: Jeremy Leconte <jleconte@webrtc.org>
Commit-Queue: Victor Boivie <boivie@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#40794}
diff --git a/net/dcsctp/public/dcsctp_socket.h b/net/dcsctp/public/dcsctp_socket.h
index 9fda56a..3cfb805 100644
--- a/net/dcsctp/public/dcsctp_socket.h
+++ b/net/dcsctp/public/dcsctp_socket.h
@@ -255,6 +255,10 @@
   // peers.
   bool uses_message_interleaving = false;
 
+  // Indicates if draft-tuexen-tsvwg-sctp-zero-checksum-00 has been negotiated
+  // by both peers.
+  bool uses_zero_checksum = false;
+
   // The number of negotiated incoming and outgoing streams, which is configured
   // locally as `DcSctpOptions::announced_maximum_incoming_streams` and
   // `DcSctpOptions::announced_maximum_outgoing_streams`, and which will be
diff --git a/net/dcsctp/socket/capabilities.h b/net/dcsctp/socket/capabilities.h
index fa3be37..286509a 100644
--- a/net/dcsctp/socket/capabilities.h
+++ b/net/dcsctp/socket/capabilities.h
@@ -21,6 +21,8 @@
   bool message_interleaving = false;
   // RFC6525 Stream Reconfiguration
   bool reconfig = false;
+  // https://datatracker.ietf.org/doc/draft-tuexen-tsvwg-sctp-zero-checksum/
+  bool zero_checksum = false;
   // Negotiated maximum incoming and outgoing stream count.
   uint16_t negotiated_maximum_incoming_streams = 0;
   uint16_t negotiated_maximum_outgoing_streams = 0;
diff --git a/net/dcsctp/socket/dcsctp_socket.cc b/net/dcsctp/socket/dcsctp_socket.cc
index 712fcea..a6845e3 100644
--- a/net/dcsctp/socket/dcsctp_socket.cc
+++ b/net/dcsctp/socket/dcsctp_socket.cc
@@ -56,6 +56,7 @@
 #include "net/dcsctp/packet/parameter/parameter.h"
 #include "net/dcsctp/packet/parameter/state_cookie_parameter.h"
 #include "net/dcsctp/packet/parameter/supported_extensions_parameter.h"
+#include "net/dcsctp/packet/parameter/zero_checksum_acceptable_chunk_parameter.h"
 #include "net/dcsctp/packet/sctp_packet.h"
 #include "net/dcsctp/packet/tlv_trait.h"
 #include "net/dcsctp/public/dcsctp_message.h"
@@ -117,6 +118,11 @@
     capabilities.reconfig = true;
   }
 
+  if (options.enable_zero_checksum &&
+      parameters.get<ZeroChecksumAcceptableChunkParameter>().has_value()) {
+    capabilities.zero_checksum = true;
+  }
+
   capabilities.negotiated_maximum_incoming_streams = std::min(
       options.announced_maximum_incoming_streams, peer_nbr_outbound_streams);
   capabilities.negotiated_maximum_outgoing_streams = std::min(
@@ -137,6 +143,9 @@
     chunk_types.push_back(IDataChunk::kType);
     chunk_types.push_back(IForwardTsnChunk::kType);
   }
+  if (options.enable_zero_checksum) {
+    builder.Add(ZeroChecksumAcceptableChunkParameter());
+  }
   builder.Add(SupportedExtensionsParameter(std::move(chunk_types)));
 }
 
@@ -279,7 +288,10 @@
                  connect_params_.initial_tsn, params_builder.Build());
   SctpPacket::Builder b(VerificationTag(0), options_);
   b.Add(init);
-  packet_sender_.Send(b);
+  // https://www.ietf.org/archive/id/draft-tuexen-tsvwg-sctp-zero-checksum-01.html#section-4.2
+  // "When an end point sends a packet containing an INIT chunk, it MUST include
+  // a correct CRC32c checksum in the packet containing the INIT chunk."
+  packet_sender_.Send(b, /*write_checksum=*/true);
 }
 
 void DcSctpSocket::MakeConnectionParameters() {
@@ -320,6 +332,7 @@
     size_t a_rwnd,
     TieTag tie_tag) {
   metrics_.uses_message_interleaving = capabilities.message_interleaving;
+  metrics_.uses_zero_checksum = capabilities.zero_checksum;
   metrics_.negotiated_maximum_incoming_streams =
       capabilities.negotiated_maximum_incoming_streams;
   metrics_.negotiated_maximum_outgoing_streams =
@@ -351,6 +364,7 @@
       capabilities.message_interleaving =
           state.capabilities.message_interleaving;
       capabilities.reconfig = state.capabilities.reconfig;
+      capabilities.zero_checksum = state.capabilities.zero_checksum;
       capabilities.negotiated_maximum_incoming_streams =
           state.capabilities.negotiated_maximum_incoming_streams;
       capabilities.negotiated_maximum_outgoing_streams =
@@ -1211,7 +1225,9 @@
                         options_.announced_maximum_incoming_streams,
                         connect_params_.initial_tsn, params_builder.Build());
   b.Add(init_ack);
-  packet_sender_.Send(b);
+  // If the peer has signaled that it supports zero checksum, INIT-ACK can then
+  // have its checksum as zero.
+  packet_sender_.Send(b, /*write_checksum=*/!capabilities.zero_checksum);
 }
 
 void DcSctpSocket::HandleInitAck(
diff --git a/net/dcsctp/socket/dcsctp_socket_test.cc b/net/dcsctp/socket/dcsctp_socket_test.cc
index c31c048..a7e2c7a 100644
--- a/net/dcsctp/socket/dcsctp_socket_test.cc
+++ b/net/dcsctp/socket/dcsctp_socket_test.cc
@@ -24,6 +24,7 @@
 #include "net/dcsctp/common/handover_testing.h"
 #include "net/dcsctp/common/math.h"
 #include "net/dcsctp/packet/chunk/chunk.h"
+#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h"
 #include "net/dcsctp/packet/chunk/cookie_echo_chunk.h"
 #include "net/dcsctp/packet/chunk/data_chunk.h"
 #include "net/dcsctp/packet/chunk/data_common.h"
@@ -31,6 +32,7 @@
 #include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h"
 #include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h"
 #include "net/dcsctp/packet/chunk/idata_chunk.h"
+#include "net/dcsctp/packet/chunk/init_ack_chunk.h"
 #include "net/dcsctp/packet/chunk/init_chunk.h"
 #include "net/dcsctp/packet/chunk/sack_chunk.h"
 #include "net/dcsctp/packet/chunk/shutdown_chunk.h"
@@ -59,8 +61,10 @@
 using ::testing::_;
 using ::testing::AllOf;
 using ::testing::ElementsAre;
+using ::testing::Eq;
 using ::testing::HasSubstr;
 using ::testing::IsEmpty;
+using ::testing::Not;
 using ::testing::SizeIs;
 using ::testing::UnorderedElementsAre;
 
@@ -2903,5 +2907,156 @@
   EXPECT_EQ(msg2->payload().size(), kSmallMessageSize);
 }
 
+TEST_P(DcSctpSocketParametrizedTest, ZeroChecksumMetricsAreSet) {
+  std::vector<std::pair<bool, bool>> combinations = {
+      {false, false}, {false, true}, {true, false}, {true, true}};
+  for (const auto& [a_enable, z_enable] : combinations) {
+    DcSctpOptions a_options = {.enable_zero_checksum = a_enable};
+    DcSctpOptions z_options = {.enable_zero_checksum = z_enable};
+
+    SocketUnderTest a("A", a_options);
+    auto z = std::make_unique<SocketUnderTest>("Z", z_options);
+
+    ConnectSockets(a, *z);
+    z = MaybeHandoverSocket(std::move(z));
+
+    EXPECT_EQ(a.socket.GetMetrics()->uses_zero_checksum, a_enable && z_enable);
+    EXPECT_EQ(z->socket.GetMetrics()->uses_zero_checksum, a_enable && z_enable);
+  }
+}
+
+TEST(DcSctpSocketTest, AlwaysSendsInitWithNonZeroChecksum) {
+  DcSctpOptions options = {.enable_zero_checksum = true};
+  SocketUnderTest a("A", options);
+
+  a.socket.Connect();
+  std::vector<uint8_t> data = a.cb.ConsumeSentPacket();
+  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet,
+                              SctpPacket::Parse(data, options));
+  EXPECT_THAT(packet.descriptors(),
+              ElementsAre(testing::Field(&SctpPacket::ChunkDescriptor::type,
+                                         InitChunk::kType)));
+  EXPECT_THAT(packet.common_header().checksum, Not(Eq(0u)));
+}
+
+TEST(DcSctpSocketTest, MaySendInitAckWithZeroChecksum) {
+  DcSctpOptions options = {.enable_zero_checksum = true};
+  SocketUnderTest a("A", options);
+  SocketUnderTest z("Z", options);
+
+  a.socket.Connect();
+  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());  // INIT
+
+  std::vector<uint8_t> data = z.cb.ConsumeSentPacket();
+  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet,
+                              SctpPacket::Parse(data, options));
+  EXPECT_THAT(packet.descriptors(),
+              ElementsAre(testing::Field(&SctpPacket::ChunkDescriptor::type,
+                                         InitAckChunk::kType)));
+  EXPECT_THAT(packet.common_header().checksum, 0u);
+}
+
+TEST(DcSctpSocketTest, AlwaysSendsCookieEchoWithNonZeroChecksum) {
+  DcSctpOptions options = {.enable_zero_checksum = true};
+  SocketUnderTest a("A", options);
+  SocketUnderTest z("Z", options);
+
+  a.socket.Connect();
+  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());  // INIT
+  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());  // INIT-ACK
+
+  std::vector<uint8_t> data = a.cb.ConsumeSentPacket();
+  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet,
+                              SctpPacket::Parse(data, options));
+  EXPECT_THAT(packet.descriptors(),
+              ElementsAre(testing::Field(&SctpPacket::ChunkDescriptor::type,
+                                         CookieEchoChunk::kType)));
+  EXPECT_THAT(packet.common_header().checksum, Not(Eq(0u)));
+}
+
+TEST(DcSctpSocketTest, SendsCookieAckWithZeroChecksum) {
+  DcSctpOptions options = {.enable_zero_checksum = true};
+  SocketUnderTest a("A", options);
+  SocketUnderTest z("Z", options);
+
+  a.socket.Connect();
+  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());  // INIT
+  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());  // INIT-ACK
+  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());  // COOKIE-ECHO
+
+  std::vector<uint8_t> data = z.cb.ConsumeSentPacket();
+  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet,
+                              SctpPacket::Parse(data, options));
+  EXPECT_THAT(packet.descriptors(),
+              ElementsAre(testing::Field(&SctpPacket::ChunkDescriptor::type,
+                                         CookieAckChunk::kType)));
+  EXPECT_THAT(packet.common_header().checksum, 0u);
+}
+
+TEST_P(DcSctpSocketParametrizedTest, SendsDataWithZeroChecksum) {
+  DcSctpOptions options = {.enable_zero_checksum = true};
+  SocketUnderTest a("A", options);
+  auto z = std::make_unique<SocketUnderTest>("Z", options);
+
+  ConnectSockets(a, *z);
+  z = MaybeHandoverSocket(std::move(z));
+
+  std::vector<uint8_t> payload(a.options.mtu - 100);
+  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
+
+  std::vector<uint8_t> data = a.cb.ConsumeSentPacket();
+  z->socket.ReceivePacket(data);
+  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet,
+                              SctpPacket::Parse(data, options));
+  EXPECT_THAT(packet.descriptors(),
+              ElementsAre(testing::Field(&SctpPacket::ChunkDescriptor::type,
+                                         DataChunk::kType)));
+  EXPECT_THAT(packet.common_header().checksum, 0u);
+
+  MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, AllPacketsAfterConnectHaveZeroChecksum) {
+  DcSctpOptions options = {.enable_zero_checksum = true};
+  SocketUnderTest a("A", options);
+  auto z = std::make_unique<SocketUnderTest>("Z", options);
+
+  ConnectSockets(a, *z);
+  z = MaybeHandoverSocket(std::move(z));
+
+  // Send large messages in both directions, and verify that they arrive and
+  // that every packet has zero checksum.
+  std::vector<uint8_t> payload(kLargeMessageSize);
+  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions);
+  z->socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions);
+
+  for (;;) {
+    if (auto data = a.cb.ConsumeSentPacket(); !data.empty()) {
+      ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet,
+                                  SctpPacket::Parse(data, options));
+      EXPECT_THAT(packet.common_header().checksum, 0u);
+      z->socket.ReceivePacket(std::move(data));
+
+    } else if (auto data = z->cb.ConsumeSentPacket(); !data.empty()) {
+      ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet,
+                                  SctpPacket::Parse(data, options));
+      EXPECT_THAT(packet.common_header().checksum, 0u);
+      a.socket.ReceivePacket(std::move(data));
+
+    } else {
+      break;
+    }
+  }
+
+  absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage();
+  ASSERT_TRUE(msg1.has_value());
+  EXPECT_THAT(msg1->payload(), SizeIs(kLargeMessageSize));
+
+  absl::optional<DcSctpMessage> msg2 = a.cb.ConsumeReceivedMessage();
+  ASSERT_TRUE(msg2.has_value());
+  EXPECT_THAT(msg2->payload(), SizeIs(kLargeMessageSize));
+
+  MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
 }  // namespace
 }  // namespace dcsctp
diff --git a/net/dcsctp/socket/packet_sender.cc b/net/dcsctp/socket/packet_sender.cc
index 85392e2..f0134ee 100644
--- a/net/dcsctp/socket/packet_sender.cc
+++ b/net/dcsctp/socket/packet_sender.cc
@@ -21,12 +21,12 @@
                                               SendPacketStatus)> on_sent_packet)
     : callbacks_(callbacks), on_sent_packet_(std::move(on_sent_packet)) {}
 
-bool PacketSender::Send(SctpPacket::Builder& builder) {
+bool PacketSender::Send(SctpPacket::Builder& builder, bool write_checksum) {
   if (builder.empty()) {
     return false;
   }
 
-  std::vector<uint8_t> payload = builder.Build();
+  std::vector<uint8_t> payload = builder.Build(write_checksum);
 
   SendPacketStatus status = callbacks_.SendPacketWithStatus(payload);
   on_sent_packet_(payload, status);
diff --git a/net/dcsctp/socket/packet_sender.h b/net/dcsctp/socket/packet_sender.h
index 7af4d3c..395c2ef 100644
--- a/net/dcsctp/socket/packet_sender.h
+++ b/net/dcsctp/socket/packet_sender.h
@@ -25,7 +25,7 @@
                                   SendPacketStatus)> on_sent_packet);
 
   // Sends the packet, and returns true if it was sent successfully.
-  bool Send(SctpPacket::Builder& builder);
+  bool Send(SctpPacket::Builder& builder, bool write_checksum = true);
 
  private:
   DcSctpSocketCallbacks& callbacks_;
diff --git a/net/dcsctp/socket/state_cookie.cc b/net/dcsctp/socket/state_cookie.cc
index 86be77a..624d783 100644
--- a/net/dcsctp/socket/state_cookie.cc
+++ b/net/dcsctp/socket/state_cookie.cc
@@ -42,6 +42,7 @@
   buffer.Store8<30>(capabilities_.reconfig);
   buffer.Store16<32>(capabilities_.negotiated_maximum_incoming_streams);
   buffer.Store16<34>(capabilities_.negotiated_maximum_outgoing_streams);
+  buffer.Store8<36>(capabilities_.zero_checksum);
   return cookie;
 }
 
@@ -74,6 +75,7 @@
   capabilities.reconfig = buffer.Load8<30>() != 0;
   capabilities.negotiated_maximum_incoming_streams = buffer.Load16<32>();
   capabilities.negotiated_maximum_outgoing_streams = buffer.Load16<34>();
+  capabilities.zero_checksum = buffer.Load8<36>() != 0;
 
   return StateCookie(verification_tag, initial_tsn, a_rwnd, tie_tag,
                      capabilities);
diff --git a/net/dcsctp/socket/state_cookie.h b/net/dcsctp/socket/state_cookie.h
index a26dbf8..34cd6d3 100644
--- a/net/dcsctp/socket/state_cookie.h
+++ b/net/dcsctp/socket/state_cookie.h
@@ -27,7 +27,7 @@
 // Do not trust anything in it; no pointers or anything like that.
 class StateCookie {
  public:
-  static constexpr size_t kCookieSize = 36;
+  static constexpr size_t kCookieSize = 37;
 
   StateCookie(VerificationTag initiate_tag,
               TSN initial_tsn,
diff --git a/net/dcsctp/socket/state_cookie_test.cc b/net/dcsctp/socket/state_cookie_test.cc
index 7d8e133..19be71a 100644
--- a/net/dcsctp/socket/state_cookie_test.cc
+++ b/net/dcsctp/socket/state_cookie_test.cc
@@ -21,6 +21,7 @@
   Capabilities capabilities = {.partial_reliability = true,
                                .message_interleaving = false,
                                .reconfig = true,
+                               .zero_checksum = true,
                                .negotiated_maximum_incoming_streams = 123,
                                .negotiated_maximum_outgoing_streams = 234};
   StateCookie cookie(VerificationTag(123), TSN(456),
@@ -36,6 +37,7 @@
   EXPECT_TRUE(deserialized.capabilities().partial_reliability);
   EXPECT_FALSE(deserialized.capabilities().message_interleaving);
   EXPECT_TRUE(deserialized.capabilities().reconfig);
+  EXPECT_TRUE(deserialized.capabilities().zero_checksum);
   EXPECT_EQ(deserialized.capabilities().negotiated_maximum_incoming_streams,
             123);
   EXPECT_EQ(deserialized.capabilities().negotiated_maximum_outgoing_streams,
diff --git a/net/dcsctp/socket/transmission_control_block.cc b/net/dcsctp/socket/transmission_control_block.cc
index 1dcf394..8bb1e8b 100644
--- a/net/dcsctp/socket/transmission_control_block.cc
+++ b/net/dcsctp/socket/transmission_control_block.cc
@@ -163,7 +163,7 @@
     } else {
       builder.Add(retransmission_queue_.CreateForwardTsn());
     }
-    packet_sender_.Send(builder);
+    Send(builder);
     // https://datatracker.ietf.org/doc/html/rfc3758
     // "IMPLEMENTATION NOTE: An implementation may wish to limit the number of
     // duplicate FORWARD TSN chunks it sends by ... waiting a full RTT before
@@ -198,7 +198,7 @@
       builder.Add(DataChunk(tsn, std::move(data), false));
     }
   }
-  packet_sender_.Send(builder);
+  Send(builder);
 }
 
 void TransmissionControlBlock::SendBufferedPackets(SctpPacket::Builder& builder,
@@ -245,7 +245,13 @@
       }
     }
 
-    if (!packet_sender_.Send(builder)) {
+    // https://www.ietf.org/archive/id/draft-tuexen-tsvwg-sctp-zero-checksum-02.html#section-4.2
+    // "When an end point sends a packet containing a COOKIE ECHO chunk, it MUST
+    // include a correct CRC32c checksum in the packet containing the COOKIE
+    // ECHO chunk."
+    bool write_checksum =
+        !capabilities_.zero_checksum || cookie_echo_chunk_.has_value();
+    if (!packet_sender_.Send(builder, write_checksum)) {
       break;
     }
 
@@ -274,6 +280,9 @@
   if (capabilities_.reconfig) {
     sb << "Reconfig,";
   }
+  if (capabilities_.zero_checksum) {
+    sb << "ZeroChecksum,";
+  }
   sb << " max_in=" << capabilities_.negotiated_maximum_incoming_streams;
   sb << " max_out=" << capabilities_.negotiated_maximum_outgoing_streams;
 
@@ -294,6 +303,7 @@
   state.capabilities.partial_reliability = capabilities_.partial_reliability;
   state.capabilities.message_interleaving = capabilities_.message_interleaving;
   state.capabilities.reconfig = capabilities_.reconfig;
+  state.capabilities.zero_checksum = capabilities_.zero_checksum;
   state.capabilities.negotiated_maximum_incoming_streams =
       capabilities_.negotiated_maximum_incoming_streams;
   state.capabilities.negotiated_maximum_outgoing_streams =
diff --git a/net/dcsctp/socket/transmission_control_block.h b/net/dcsctp/socket/transmission_control_block.h
index fc66fcc..46a39d5 100644
--- a/net/dcsctp/socket/transmission_control_block.h
+++ b/net/dcsctp/socket/transmission_control_block.h
@@ -80,7 +80,8 @@
     return tx_error_counter_.IsExhausted();
   }
   void Send(SctpPacket::Builder& builder) override {
-    packet_sender_.Send(builder);
+    packet_sender_.Send(builder,
+                        /*write_checksum=*/!capabilities_.zero_checksum);
   }
 
   // Other accessors
diff --git a/net/dcsctp/socket/transmission_control_block_test.cc b/net/dcsctp/socket/transmission_control_block_test.cc
index 40aea58..6106fbb 100644
--- a/net/dcsctp/socket/transmission_control_block_test.cc
+++ b/net/dcsctp/socket/transmission_control_block_test.cc
@@ -92,6 +92,7 @@
   capabilities_.negotiated_maximum_outgoing_streams = 2000;
   capabilities_.message_interleaving = true;
   capabilities_.partial_reliability = true;
+  capabilities_.zero_checksum = true;
   capabilities_.reconfig = true;
 
   TransmissionControlBlock tcb(
@@ -99,9 +100,10 @@
       kMyVerificationTag, kMyInitialTsn, kPeerVerificationTag, kPeerInitialTsn,
       kArwnd, kTieTag, sender_, on_connection_established.AsStdFunction());
 
-  EXPECT_EQ(tcb.ToString(),
-            "verification_tag=000001c8, last_cumulative_ack=999, "
-            "capabilities=PR,IL,Reconfig, max_in=1000 max_out=2000");
+  EXPECT_EQ(
+      tcb.ToString(),
+      "verification_tag=000001c8, last_cumulative_ack=999, "
+      "capabilities=PR,IL,Reconfig,ZeroChecksum, max_in=1000 max_out=2000");
 }
 
 TEST_F(TransmissionControlBlockTest, IsInitiallyHandoverReady) {