Stay writable after partial socket writes.

This CL fixes an issue where the "writable" flag didn't stay set after
::send or ::sendto only sent a partial buffer.

Also SocketTest::TcpInternal has been updated to use rtc::Buffer instead
of manually allocating data.

BUG=webrtc:4898

Review URL: https://codereview.webrtc.org/1616153007

Cr-Commit-Position: refs/heads/master@{#11480}
diff --git a/webrtc/base/physicalsocketserver.cc b/webrtc/base/physicalsocketserver.cc
index 67fbea0..4cea040 100644
--- a/webrtc/base/physicalsocketserver.cc
+++ b/webrtc/base/physicalsocketserver.cc
@@ -271,7 +271,8 @@
 }
 
 int PhysicalSocket::Send(const void* pv, size_t cb) {
-  int sent = ::send(s_, reinterpret_cast<const char *>(pv), (int)cb,
+  int sent = DoSend(s_, reinterpret_cast<const char *>(pv),
+      static_cast<int>(cb),
 #if defined(WEBRTC_LINUX) && !defined(WEBRTC_ANDROID)
       // Suppress SIGPIPE. Without this, attempting to send on a socket whose
       // other end is closed will result in a SIGPIPE signal being raised to
@@ -287,7 +288,8 @@
   MaybeRemapSendError();
   // We have seen minidumps where this may be false.
   ASSERT(sent <= static_cast<int>(cb));
-  if ((sent < 0) && IsBlockingError(GetError())) {
+  if ((sent > 0 && sent < static_cast<int>(cb)) ||
+      (sent < 0 && IsBlockingError(GetError()))) {
     enabled_events_ |= DE_WRITE;
   }
   return sent;
@@ -298,7 +300,7 @@
                            const SocketAddress& addr) {
   sockaddr_storage saddr;
   size_t len = addr.ToSockAddrStorage(&saddr);
-  int sent = ::sendto(
+  int sent = DoSendTo(
       s_, static_cast<const char *>(buffer), static_cast<int>(length),
 #if defined(WEBRTC_LINUX) && !defined(WEBRTC_ANDROID)
       // Suppress SIGPIPE. See above for explanation.
@@ -311,7 +313,8 @@
   MaybeRemapSendError();
   // We have seen minidumps where this may be false.
   ASSERT(sent <= static_cast<int>(length));
-  if ((sent < 0) && IsBlockingError(GetError())) {
+  if ((sent > 0 && sent < static_cast<int>(length)) ||
+      (sent < 0 && IsBlockingError(GetError()))) {
     enabled_events_ |= DE_WRITE;
   }
   return sent;
@@ -474,13 +477,25 @@
 #endif
 }
 
-
 SOCKET PhysicalSocket::DoAccept(SOCKET socket,
                                 sockaddr* addr,
                                 socklen_t* addrlen) {
   return ::accept(socket, addr, addrlen);
 }
 
+int PhysicalSocket::DoSend(SOCKET socket, const char* buf, int len, int flags) {
+  return ::send(socket, buf, len, flags);
+}
+
+int PhysicalSocket::DoSendTo(SOCKET socket,
+                             const char* buf,
+                             int len,
+                             int flags,
+                             const struct sockaddr* dest_addr,
+                             socklen_t addrlen) {
+  return ::sendto(socket, buf, len, flags, dest_addr, addrlen);
+}
+
 void PhysicalSocket::OnResolveResult(AsyncResolverInterface* resolver) {
   if (resolver != resolver_) {
     return;
diff --git a/webrtc/base/physicalsocketserver.h b/webrtc/base/physicalsocketserver.h
index cbe6580..583306c 100644
--- a/webrtc/base/physicalsocketserver.h
+++ b/webrtc/base/physicalsocketserver.h
@@ -68,8 +68,8 @@
   AsyncSocket* CreateAsyncSocket(int type) override;
   AsyncSocket* CreateAsyncSocket(int family, int type) override;
 
-  // Internal Factory for Accept
-  AsyncSocket* WrapSocket(SOCKET s);
+  // Internal Factory for Accept (virtual so it can be overwritten in tests).
+  virtual AsyncSocket* WrapSocket(SOCKET s);
 
   // SocketServer:
   bool Wait(int cms, bool process_io) override;
@@ -161,6 +161,13 @@
   // Make virtual so ::accept can be overwritten in tests.
   virtual SOCKET DoAccept(SOCKET socket, sockaddr* addr, socklen_t* addrlen);
 
+  // Make virtual so ::send can be overwritten in tests.
+  virtual int DoSend(SOCKET socket, const char* buf, int len, int flags);
+
+  // Make virtual so ::sendto can be overwritten in tests.
+  virtual int DoSendTo(SOCKET socket, const char* buf, int len, int flags,
+                       const struct sockaddr* dest_addr, socklen_t addrlen);
+
   void OnResolveResult(AsyncResolverInterface* resolver);
 
   void UpdateLastError();
diff --git a/webrtc/base/physicalsocketserver_unittest.cc b/webrtc/base/physicalsocketserver_unittest.cc
index a2fde80..c53441d 100644
--- a/webrtc/base/physicalsocketserver_unittest.cc
+++ b/webrtc/base/physicalsocketserver_unittest.cc
@@ -29,8 +29,15 @@
     : SocketDispatcher(ss) {
   }
 
+  FakeSocketDispatcher(SOCKET s, PhysicalSocketServer* ss)
+    : SocketDispatcher(s, ss) {
+  }
+
  protected:
   SOCKET DoAccept(SOCKET socket, sockaddr* addr, socklen_t* addrlen) override;
+  int DoSend(SOCKET socket, const char* buf, int len, int flags) override;
+  int DoSendTo(SOCKET socket, const char* buf, int len, int flags,
+               const struct sockaddr* dest_addr, socklen_t addrlen) override;
 };
 
 class FakePhysicalSocketServer : public PhysicalSocketServer {
@@ -41,22 +48,29 @@
 
   AsyncSocket* CreateAsyncSocket(int type) override {
     SocketDispatcher* dispatcher = new FakeSocketDispatcher(this);
-    if (dispatcher->Create(type)) {
-      return dispatcher;
-    } else {
+    if (!dispatcher->Create(type)) {
       delete dispatcher;
       return nullptr;
     }
+    return dispatcher;
   }
 
   AsyncSocket* CreateAsyncSocket(int family, int type) override {
     SocketDispatcher* dispatcher = new FakeSocketDispatcher(this);
-    if (dispatcher->Create(family, type)) {
-      return dispatcher;
-    } else {
+    if (!dispatcher->Create(family, type)) {
       delete dispatcher;
       return nullptr;
     }
+    return dispatcher;
+  }
+
+  AsyncSocket* WrapSocket(SOCKET s) override {
+    SocketDispatcher* dispatcher = new FakeSocketDispatcher(s, this);
+    if (!dispatcher->Initialize()) {
+      delete dispatcher;
+      return nullptr;
+    }
+    return dispatcher;
   }
 
   PhysicalSocketTest* GetTest() const { return test_; }
@@ -71,18 +85,25 @@
   void SetFailAccept(bool fail) { fail_accept_ = fail; }
   bool FailAccept() const { return fail_accept_; }
 
+  // Maximum size to ::send to a socket. Set to < 0 to disable limiting.
+  void SetMaxSendSize(int max_size) { max_send_size_ = max_size; }
+  int MaxSendSize() const { return max_send_size_; }
+
  protected:
   PhysicalSocketTest()
     : server_(new FakePhysicalSocketServer(this)),
       scope_(server_.get()),
-      fail_accept_(false) {
+      fail_accept_(false),
+      max_send_size_(-1) {
   }
 
   void ConnectInternalAcceptError(const IPAddress& loopback);
+  void WritableAfterPartialWrite(const IPAddress& loopback);
 
   rtc::scoped_ptr<FakePhysicalSocketServer> server_;
   SocketServerScope scope_;
   bool fail_accept_;
+  int max_send_size_;
 };
 
 SOCKET FakeSocketDispatcher::DoAccept(SOCKET socket,
@@ -97,6 +118,29 @@
   return SocketDispatcher::DoAccept(socket, addr, addrlen);
 }
 
+int FakeSocketDispatcher::DoSend(SOCKET socket, const char* buf, int len,
+    int flags) {
+  FakePhysicalSocketServer* ss =
+      static_cast<FakePhysicalSocketServer*>(socketserver());
+  if (ss->GetTest()->MaxSendSize() >= 0) {
+    len = std::min(len, ss->GetTest()->MaxSendSize());
+  }
+
+  return SocketDispatcher::DoSend(socket, buf, len, flags);
+}
+
+int FakeSocketDispatcher::DoSendTo(SOCKET socket, const char* buf, int len,
+    int flags, const struct sockaddr* dest_addr, socklen_t addrlen) {
+  FakePhysicalSocketServer* ss =
+      static_cast<FakePhysicalSocketServer*>(socketserver());
+  if (ss->GetTest()->MaxSendSize() >= 0) {
+    len = std::min(len, ss->GetTest()->MaxSendSize());
+  }
+
+  return SocketDispatcher::DoSendTo(socket, buf, len, flags, dest_addr,
+      addrlen);
+}
+
 TEST_F(PhysicalSocketTest, TestConnectIPv4) {
   SocketTest::TestConnectIPv4();
 }
@@ -209,6 +253,33 @@
   ConnectInternalAcceptError(kIPv6Loopback);
 }
 
+void PhysicalSocketTest::WritableAfterPartialWrite(const IPAddress& loopback) {
+  // Simulate a really small maximum send size.
+  const int kMaxSendSize = 128;
+  SetMaxSendSize(kMaxSendSize);
+
+  // Run the default send/receive socket tests with a smaller amount of data
+  // to avoid long running times due to the small maximum send size.
+  const size_t kDataSize = 128 * 1024;
+  TcpInternal(loopback, kDataSize, kMaxSendSize);
+}
+
+TEST_F(PhysicalSocketTest, TestWritableAfterPartialWriteIPv4) {
+  WritableAfterPartialWrite(kIPv4Loopback);
+}
+
+// Crashes on Linux. See webrtc:4923.
+#if defined(WEBRTC_LINUX)
+#define MAYBE_TestWritableAfterPartialWriteIPv6 \
+    DISABLED_TestWritableAfterPartialWriteIPv6
+#else
+#define MAYBE_TestWritableAfterPartialWriteIPv6 \
+    TestWritableAfterPartialWriteIPv6
+#endif
+TEST_F(PhysicalSocketTest, MAYBE_TestWritableAfterPartialWriteIPv6) {
+  WritableAfterPartialWrite(kIPv6Loopback);
+}
+
 // Crashes on Linux. See webrtc:4923.
 #if defined(WEBRTC_LINUX)
 #define MAYBE_TestConnectFailIPv6 DISABLED_TestConnectFailIPv6
diff --git a/webrtc/base/socket_unittest.cc b/webrtc/base/socket_unittest.cc
index 8143823..d1369e2 100644
--- a/webrtc/base/socket_unittest.cc
+++ b/webrtc/base/socket_unittest.cc
@@ -11,6 +11,7 @@
 #include "webrtc/base/socket_unittest.h"
 
 #include "webrtc/base/arraysize.h"
+#include "webrtc/base/buffer.h"
 #include "webrtc/base/asyncudpsocket.h"
 #include "webrtc/base/gunit.h"
 #include "webrtc/base/nethelpers.h"
@@ -21,6 +22,9 @@
 
 namespace rtc {
 
+// Data size to be used in TcpInternal tests.
+static const size_t kTcpInternalDataSize = 1024 * 1024;  // bytes
+
 #define MAYBE_SKIP_IPV6                             \
   if (!HasIPv6Enabled()) {                          \
     LOG(LS_INFO) << "No IPv6... skipping";          \
@@ -129,12 +133,12 @@
 }
 
 void SocketTest::TestTcpIPv4() {
-  TcpInternal(kIPv4Loopback);
+  TcpInternal(kIPv4Loopback, kTcpInternalDataSize, -1);
 }
 
 void SocketTest::TestTcpIPv6() {
   MAYBE_SKIP_IPV6;
-  TcpInternal(kIPv6Loopback);
+  TcpInternal(kIPv6Loopback, kTcpInternalDataSize, -1);
 }
 
 void SocketTest::TestSingleFlowControlCallbackIPv4() {
@@ -671,24 +675,15 @@
   EXPECT_LT(0, accepted->Recv(buf, 1024));
 }
 
-void SocketTest::TcpInternal(const IPAddress& loopback) {
+void SocketTest::TcpInternal(const IPAddress& loopback, size_t data_size,
+    ssize_t max_send_size) {
   testing::StreamSink sink;
   SocketAddress accept_addr;
 
-  // Create test data.
-  const size_t kDataSize = 1024 * 1024;
-  scoped_ptr<char[]> send_buffer(new char[kDataSize]);
-  scoped_ptr<char[]> recv_buffer(new char[kDataSize]);
-  size_t send_pos = 0, recv_pos = 0;
-  for (size_t i = 0; i < kDataSize; ++i) {
-    send_buffer[i] = static_cast<char>(i % 256);
-    recv_buffer[i] = 0;
-  }
-
-  // Create client.
-  scoped_ptr<AsyncSocket> client(
+  // Create receiving client.
+  scoped_ptr<AsyncSocket> receiver(
       ss_->CreateAsyncSocket(loopback.family(), SOCK_STREAM));
-  sink.Monitor(client.get());
+  sink.Monitor(receiver.get());
 
   // Create server and listen.
   scoped_ptr<AsyncSocket> server(
@@ -698,97 +693,115 @@
   EXPECT_EQ(0, server->Listen(5));
 
   // Attempt connection.
-  EXPECT_EQ(0, client->Connect(server->GetLocalAddress()));
+  EXPECT_EQ(0, receiver->Connect(server->GetLocalAddress()));
 
-  // Accept connection.
+  // Accept connection which will be used for sending.
   EXPECT_TRUE_WAIT((sink.Check(server.get(), testing::SSE_READ)), kTimeout);
-  scoped_ptr<AsyncSocket> accepted(server->Accept(&accept_addr));
-  ASSERT_TRUE(accepted);
-  sink.Monitor(accepted.get());
+  scoped_ptr<AsyncSocket> sender(server->Accept(&accept_addr));
+  ASSERT_TRUE(sender);
+  sink.Monitor(sender.get());
 
   // Both sides are now connected.
-  EXPECT_EQ_WAIT(AsyncSocket::CS_CONNECTED, client->GetState(), kTimeout);
-  EXPECT_TRUE(sink.Check(client.get(), testing::SSE_OPEN));
-  EXPECT_EQ(client->GetRemoteAddress(), accepted->GetLocalAddress());
-  EXPECT_EQ(accepted->GetRemoteAddress(), client->GetLocalAddress());
+  EXPECT_EQ_WAIT(AsyncSocket::CS_CONNECTED, receiver->GetState(), kTimeout);
+  EXPECT_TRUE(sink.Check(receiver.get(), testing::SSE_OPEN));
+  EXPECT_EQ(receiver->GetRemoteAddress(), sender->GetLocalAddress());
+  EXPECT_EQ(sender->GetRemoteAddress(), receiver->GetLocalAddress());
+
+  // Create test data.
+  rtc::Buffer send_buffer(0, data_size);
+  rtc::Buffer recv_buffer(0, data_size);
+  for (size_t i = 0; i < data_size; ++i) {
+    char ch = static_cast<char>(i % 256);
+    send_buffer.AppendData(&ch, sizeof(ch));
+  }
 
   // Send and receive a bunch of data.
-  bool send_waiting_for_writability = false;
-  bool send_expect_success = true;
-  bool recv_waiting_for_readability = true;
-  bool recv_expect_success = false;
-  int data_in_flight = 0;
-  while (recv_pos < kDataSize) {
-    // Send as much as we can if we've been cleared to send.
-    while (!send_waiting_for_writability && send_pos < kDataSize) {
-      int tosend = static_cast<int>(kDataSize - send_pos);
-      int sent = accepted->Send(send_buffer.get() + send_pos, tosend);
-      if (send_expect_success) {
+  size_t sent_size = 0;
+  bool writable = true;
+  bool send_called = false;
+  bool readable = false;
+  bool recv_called = false;
+  while (recv_buffer.size() < send_buffer.size()) {
+    // Send as much as we can while we're cleared to send.
+    while (writable && sent_size < send_buffer.size()) {
+      int unsent_size = static_cast<int>(send_buffer.size() - sent_size);
+      int sent = sender->Send(send_buffer.data() + sent_size, unsent_size);
+      if (!send_called) {
         // The first Send() after connecting or getting writability should
         // succeed and send some data.
         EXPECT_GT(sent, 0);
-        send_expect_success = false;
+        send_called = true;
       }
       if (sent >= 0) {
-        EXPECT_LE(sent, tosend);
-        send_pos += sent;
-        data_in_flight += sent;
+        EXPECT_LE(sent, unsent_size);
+        sent_size += sent;
+        if (max_send_size >= 0) {
+          EXPECT_LE(static_cast<ssize_t>(sent), max_send_size);
+          if (sent < unsent_size) {
+            // If max_send_size is limiting the amount to send per call such
+            // that the sent amount is less than the unsent amount, we simulate
+            // that the socket is no longer writable.
+            writable = false;
+          }
+        }
       } else {
-        ASSERT_TRUE(accepted->IsBlocking());
-        send_waiting_for_writability = true;
+        ASSERT_TRUE(sender->IsBlocking());
+        writable = false;
       }
     }
 
     // Read all the sent data.
-    while (data_in_flight > 0) {
-      if (recv_waiting_for_readability) {
+    while (recv_buffer.size() < sent_size) {
+      if (!readable) {
         // Wait until data is available.
-        EXPECT_TRUE_WAIT(sink.Check(client.get(), testing::SSE_READ), kTimeout);
-        recv_waiting_for_readability = false;
-        recv_expect_success = true;
+        EXPECT_TRUE_WAIT(sink.Check(receiver.get(), testing::SSE_READ),
+            kTimeout);
+        readable = true;
+        recv_called = false;
       }
 
       // Receive as much as we can get in a single recv call.
-      int rcvd = client->Recv(recv_buffer.get() + recv_pos,
-                              kDataSize - recv_pos);
+      char recved_data[data_size];
+      int recved_size = receiver->Recv(recved_data, data_size);
 
-      if (recv_expect_success) {
+      if (!recv_called) {
         // The first Recv() after getting readability should succeed and receive
         // some data.
         // TODO: The following line is disabled due to flakey pulse
         //     builds.  Re-enable if/when possible.
-        // EXPECT_GT(rcvd, 0);
-        recv_expect_success = false;
+        // EXPECT_GT(recved_size, 0);
+        recv_called = true;
       }
-      if (rcvd >= 0) {
-        EXPECT_LE(rcvd, data_in_flight);
-        recv_pos += rcvd;
-        data_in_flight -= rcvd;
+      if (recved_size >= 0) {
+        EXPECT_LE(static_cast<size_t>(recved_size),
+            sent_size - recv_buffer.size());
+        recv_buffer.AppendData(recved_data, recved_size);
       } else {
-        ASSERT_TRUE(client->IsBlocking());
-        recv_waiting_for_readability = true;
+        ASSERT_TRUE(receiver->IsBlocking());
+        readable = false;
       }
     }
 
-    // Once all that we've sent has been rcvd, expect to be able to send again.
-    if (send_waiting_for_writability) {
-      EXPECT_TRUE_WAIT(sink.Check(accepted.get(), testing::SSE_WRITE),
+    // Once all that we've sent has been received, expect to be able to send
+    // again.
+    if (!writable) {
+      EXPECT_TRUE_WAIT(sink.Check(sender.get(), testing::SSE_WRITE),
                        kTimeout);
-      send_waiting_for_writability = false;
-      send_expect_success = true;
+      writable = true;
+      send_called = false;
     }
   }
 
   // The received data matches the sent data.
-  EXPECT_EQ(kDataSize, send_pos);
-  EXPECT_EQ(kDataSize, recv_pos);
-  EXPECT_EQ(0, memcmp(recv_buffer.get(), send_buffer.get(), kDataSize));
+  EXPECT_EQ(data_size, sent_size);
+  EXPECT_EQ(data_size, recv_buffer.size());
+  EXPECT_EQ(recv_buffer, send_buffer);
 
   // Close down.
-  accepted->Close();
-  EXPECT_EQ_WAIT(AsyncSocket::CS_CLOSED, client->GetState(), kTimeout);
-  EXPECT_TRUE(sink.Check(client.get(), testing::SSE_CLOSE));
-  client->Close();
+  sender->Close();
+  EXPECT_EQ_WAIT(AsyncSocket::CS_CLOSED, receiver->GetState(), kTimeout);
+  EXPECT_TRUE(sink.Check(receiver.get(), testing::SSE_CLOSE));
+  receiver->Close();
 }
 
 void SocketTest::SingleFlowControlCallbackInternal(const IPAddress& loopback) {
diff --git a/webrtc/base/socket_unittest.h b/webrtc/base/socket_unittest.h
index e4a6b32..adc69f1 100644
--- a/webrtc/base/socket_unittest.h
+++ b/webrtc/base/socket_unittest.h
@@ -62,6 +62,10 @@
   const IPAddress kIPv4Loopback;
   const IPAddress kIPv6Loopback;
 
+ protected:
+  void TcpInternal(const IPAddress& loopback, size_t data_size,
+      ssize_t max_send_size);
+
  private:
   void ConnectInternal(const IPAddress& loopback);
   void ConnectWithDnsLookupInternal(const IPAddress& loopback,
@@ -76,7 +80,6 @@
   void ServerCloseInternal(const IPAddress& loopback);
   void CloseInClosedCallbackInternal(const IPAddress& loopback);
   void SocketServerWaitInternal(const IPAddress& loopback);
-  void TcpInternal(const IPAddress& loopback);
   void SingleFlowControlCallbackInternal(const IPAddress& loopback);
   void UdpInternal(const IPAddress& loopback);
   void UdpReadyToSend(const IPAddress& loopback);