Stay writable after partial socket writes.
This CL fixes an issue where the "writable" flag didn't stay set after
::send or ::sendto only sent a partial buffer.
Also SocketTest::TcpInternal has been updated to use rtc::Buffer instead
of manually allocating data.
BUG=webrtc:4898
Review URL: https://codereview.webrtc.org/1616153007
Cr-Commit-Position: refs/heads/master@{#11480}
diff --git a/webrtc/base/physicalsocketserver.cc b/webrtc/base/physicalsocketserver.cc
index 67fbea0..4cea040 100644
--- a/webrtc/base/physicalsocketserver.cc
+++ b/webrtc/base/physicalsocketserver.cc
@@ -271,7 +271,8 @@
}
int PhysicalSocket::Send(const void* pv, size_t cb) {
- int sent = ::send(s_, reinterpret_cast<const char *>(pv), (int)cb,
+ int sent = DoSend(s_, reinterpret_cast<const char *>(pv),
+ static_cast<int>(cb),
#if defined(WEBRTC_LINUX) && !defined(WEBRTC_ANDROID)
// Suppress SIGPIPE. Without this, attempting to send on a socket whose
// other end is closed will result in a SIGPIPE signal being raised to
@@ -287,7 +288,8 @@
MaybeRemapSendError();
// We have seen minidumps where this may be false.
ASSERT(sent <= static_cast<int>(cb));
- if ((sent < 0) && IsBlockingError(GetError())) {
+ if ((sent > 0 && sent < static_cast<int>(cb)) ||
+ (sent < 0 && IsBlockingError(GetError()))) {
enabled_events_ |= DE_WRITE;
}
return sent;
@@ -298,7 +300,7 @@
const SocketAddress& addr) {
sockaddr_storage saddr;
size_t len = addr.ToSockAddrStorage(&saddr);
- int sent = ::sendto(
+ int sent = DoSendTo(
s_, static_cast<const char *>(buffer), static_cast<int>(length),
#if defined(WEBRTC_LINUX) && !defined(WEBRTC_ANDROID)
// Suppress SIGPIPE. See above for explanation.
@@ -311,7 +313,8 @@
MaybeRemapSendError();
// We have seen minidumps where this may be false.
ASSERT(sent <= static_cast<int>(length));
- if ((sent < 0) && IsBlockingError(GetError())) {
+ if ((sent > 0 && sent < static_cast<int>(length)) ||
+ (sent < 0 && IsBlockingError(GetError()))) {
enabled_events_ |= DE_WRITE;
}
return sent;
@@ -474,13 +477,25 @@
#endif
}
-
SOCKET PhysicalSocket::DoAccept(SOCKET socket,
sockaddr* addr,
socklen_t* addrlen) {
return ::accept(socket, addr, addrlen);
}
+int PhysicalSocket::DoSend(SOCKET socket, const char* buf, int len, int flags) {
+ return ::send(socket, buf, len, flags);
+}
+
+int PhysicalSocket::DoSendTo(SOCKET socket,
+ const char* buf,
+ int len,
+ int flags,
+ const struct sockaddr* dest_addr,
+ socklen_t addrlen) {
+ return ::sendto(socket, buf, len, flags, dest_addr, addrlen);
+}
+
void PhysicalSocket::OnResolveResult(AsyncResolverInterface* resolver) {
if (resolver != resolver_) {
return;
diff --git a/webrtc/base/physicalsocketserver.h b/webrtc/base/physicalsocketserver.h
index cbe6580..583306c 100644
--- a/webrtc/base/physicalsocketserver.h
+++ b/webrtc/base/physicalsocketserver.h
@@ -68,8 +68,8 @@
AsyncSocket* CreateAsyncSocket(int type) override;
AsyncSocket* CreateAsyncSocket(int family, int type) override;
- // Internal Factory for Accept
- AsyncSocket* WrapSocket(SOCKET s);
+ // Internal Factory for Accept (virtual so it can be overwritten in tests).
+ virtual AsyncSocket* WrapSocket(SOCKET s);
// SocketServer:
bool Wait(int cms, bool process_io) override;
@@ -161,6 +161,13 @@
// Make virtual so ::accept can be overwritten in tests.
virtual SOCKET DoAccept(SOCKET socket, sockaddr* addr, socklen_t* addrlen);
+ // Make virtual so ::send can be overwritten in tests.
+ virtual int DoSend(SOCKET socket, const char* buf, int len, int flags);
+
+ // Make virtual so ::sendto can be overwritten in tests.
+ virtual int DoSendTo(SOCKET socket, const char* buf, int len, int flags,
+ const struct sockaddr* dest_addr, socklen_t addrlen);
+
void OnResolveResult(AsyncResolverInterface* resolver);
void UpdateLastError();
diff --git a/webrtc/base/physicalsocketserver_unittest.cc b/webrtc/base/physicalsocketserver_unittest.cc
index a2fde80..c53441d 100644
--- a/webrtc/base/physicalsocketserver_unittest.cc
+++ b/webrtc/base/physicalsocketserver_unittest.cc
@@ -29,8 +29,15 @@
: SocketDispatcher(ss) {
}
+ FakeSocketDispatcher(SOCKET s, PhysicalSocketServer* ss)
+ : SocketDispatcher(s, ss) {
+ }
+
protected:
SOCKET DoAccept(SOCKET socket, sockaddr* addr, socklen_t* addrlen) override;
+ int DoSend(SOCKET socket, const char* buf, int len, int flags) override;
+ int DoSendTo(SOCKET socket, const char* buf, int len, int flags,
+ const struct sockaddr* dest_addr, socklen_t addrlen) override;
};
class FakePhysicalSocketServer : public PhysicalSocketServer {
@@ -41,22 +48,29 @@
AsyncSocket* CreateAsyncSocket(int type) override {
SocketDispatcher* dispatcher = new FakeSocketDispatcher(this);
- if (dispatcher->Create(type)) {
- return dispatcher;
- } else {
+ if (!dispatcher->Create(type)) {
delete dispatcher;
return nullptr;
}
+ return dispatcher;
}
AsyncSocket* CreateAsyncSocket(int family, int type) override {
SocketDispatcher* dispatcher = new FakeSocketDispatcher(this);
- if (dispatcher->Create(family, type)) {
- return dispatcher;
- } else {
+ if (!dispatcher->Create(family, type)) {
delete dispatcher;
return nullptr;
}
+ return dispatcher;
+ }
+
+ AsyncSocket* WrapSocket(SOCKET s) override {
+ SocketDispatcher* dispatcher = new FakeSocketDispatcher(s, this);
+ if (!dispatcher->Initialize()) {
+ delete dispatcher;
+ return nullptr;
+ }
+ return dispatcher;
}
PhysicalSocketTest* GetTest() const { return test_; }
@@ -71,18 +85,25 @@
void SetFailAccept(bool fail) { fail_accept_ = fail; }
bool FailAccept() const { return fail_accept_; }
+ // Maximum size to ::send to a socket. Set to < 0 to disable limiting.
+ void SetMaxSendSize(int max_size) { max_send_size_ = max_size; }
+ int MaxSendSize() const { return max_send_size_; }
+
protected:
PhysicalSocketTest()
: server_(new FakePhysicalSocketServer(this)),
scope_(server_.get()),
- fail_accept_(false) {
+ fail_accept_(false),
+ max_send_size_(-1) {
}
void ConnectInternalAcceptError(const IPAddress& loopback);
+ void WritableAfterPartialWrite(const IPAddress& loopback);
rtc::scoped_ptr<FakePhysicalSocketServer> server_;
SocketServerScope scope_;
bool fail_accept_;
+ int max_send_size_;
};
SOCKET FakeSocketDispatcher::DoAccept(SOCKET socket,
@@ -97,6 +118,29 @@
return SocketDispatcher::DoAccept(socket, addr, addrlen);
}
+int FakeSocketDispatcher::DoSend(SOCKET socket, const char* buf, int len,
+ int flags) {
+ FakePhysicalSocketServer* ss =
+ static_cast<FakePhysicalSocketServer*>(socketserver());
+ if (ss->GetTest()->MaxSendSize() >= 0) {
+ len = std::min(len, ss->GetTest()->MaxSendSize());
+ }
+
+ return SocketDispatcher::DoSend(socket, buf, len, flags);
+}
+
+int FakeSocketDispatcher::DoSendTo(SOCKET socket, const char* buf, int len,
+ int flags, const struct sockaddr* dest_addr, socklen_t addrlen) {
+ FakePhysicalSocketServer* ss =
+ static_cast<FakePhysicalSocketServer*>(socketserver());
+ if (ss->GetTest()->MaxSendSize() >= 0) {
+ len = std::min(len, ss->GetTest()->MaxSendSize());
+ }
+
+ return SocketDispatcher::DoSendTo(socket, buf, len, flags, dest_addr,
+ addrlen);
+}
+
TEST_F(PhysicalSocketTest, TestConnectIPv4) {
SocketTest::TestConnectIPv4();
}
@@ -209,6 +253,33 @@
ConnectInternalAcceptError(kIPv6Loopback);
}
+void PhysicalSocketTest::WritableAfterPartialWrite(const IPAddress& loopback) {
+ // Simulate a really small maximum send size.
+ const int kMaxSendSize = 128;
+ SetMaxSendSize(kMaxSendSize);
+
+ // Run the default send/receive socket tests with a smaller amount of data
+ // to avoid long running times due to the small maximum send size.
+ const size_t kDataSize = 128 * 1024;
+ TcpInternal(loopback, kDataSize, kMaxSendSize);
+}
+
+TEST_F(PhysicalSocketTest, TestWritableAfterPartialWriteIPv4) {
+ WritableAfterPartialWrite(kIPv4Loopback);
+}
+
+// Crashes on Linux. See webrtc:4923.
+#if defined(WEBRTC_LINUX)
+#define MAYBE_TestWritableAfterPartialWriteIPv6 \
+ DISABLED_TestWritableAfterPartialWriteIPv6
+#else
+#define MAYBE_TestWritableAfterPartialWriteIPv6 \
+ TestWritableAfterPartialWriteIPv6
+#endif
+TEST_F(PhysicalSocketTest, MAYBE_TestWritableAfterPartialWriteIPv6) {
+ WritableAfterPartialWrite(kIPv6Loopback);
+}
+
// Crashes on Linux. See webrtc:4923.
#if defined(WEBRTC_LINUX)
#define MAYBE_TestConnectFailIPv6 DISABLED_TestConnectFailIPv6
diff --git a/webrtc/base/socket_unittest.cc b/webrtc/base/socket_unittest.cc
index 8143823..d1369e2 100644
--- a/webrtc/base/socket_unittest.cc
+++ b/webrtc/base/socket_unittest.cc
@@ -11,6 +11,7 @@
#include "webrtc/base/socket_unittest.h"
#include "webrtc/base/arraysize.h"
+#include "webrtc/base/buffer.h"
#include "webrtc/base/asyncudpsocket.h"
#include "webrtc/base/gunit.h"
#include "webrtc/base/nethelpers.h"
@@ -21,6 +22,9 @@
namespace rtc {
+// Data size to be used in TcpInternal tests.
+static const size_t kTcpInternalDataSize = 1024 * 1024; // bytes
+
#define MAYBE_SKIP_IPV6 \
if (!HasIPv6Enabled()) { \
LOG(LS_INFO) << "No IPv6... skipping"; \
@@ -129,12 +133,12 @@
}
void SocketTest::TestTcpIPv4() {
- TcpInternal(kIPv4Loopback);
+ TcpInternal(kIPv4Loopback, kTcpInternalDataSize, -1);
}
void SocketTest::TestTcpIPv6() {
MAYBE_SKIP_IPV6;
- TcpInternal(kIPv6Loopback);
+ TcpInternal(kIPv6Loopback, kTcpInternalDataSize, -1);
}
void SocketTest::TestSingleFlowControlCallbackIPv4() {
@@ -671,24 +675,15 @@
EXPECT_LT(0, accepted->Recv(buf, 1024));
}
-void SocketTest::TcpInternal(const IPAddress& loopback) {
+void SocketTest::TcpInternal(const IPAddress& loopback, size_t data_size,
+ ssize_t max_send_size) {
testing::StreamSink sink;
SocketAddress accept_addr;
- // Create test data.
- const size_t kDataSize = 1024 * 1024;
- scoped_ptr<char[]> send_buffer(new char[kDataSize]);
- scoped_ptr<char[]> recv_buffer(new char[kDataSize]);
- size_t send_pos = 0, recv_pos = 0;
- for (size_t i = 0; i < kDataSize; ++i) {
- send_buffer[i] = static_cast<char>(i % 256);
- recv_buffer[i] = 0;
- }
-
- // Create client.
- scoped_ptr<AsyncSocket> client(
+ // Create receiving client.
+ scoped_ptr<AsyncSocket> receiver(
ss_->CreateAsyncSocket(loopback.family(), SOCK_STREAM));
- sink.Monitor(client.get());
+ sink.Monitor(receiver.get());
// Create server and listen.
scoped_ptr<AsyncSocket> server(
@@ -698,97 +693,115 @@
EXPECT_EQ(0, server->Listen(5));
// Attempt connection.
- EXPECT_EQ(0, client->Connect(server->GetLocalAddress()));
+ EXPECT_EQ(0, receiver->Connect(server->GetLocalAddress()));
- // Accept connection.
+ // Accept connection which will be used for sending.
EXPECT_TRUE_WAIT((sink.Check(server.get(), testing::SSE_READ)), kTimeout);
- scoped_ptr<AsyncSocket> accepted(server->Accept(&accept_addr));
- ASSERT_TRUE(accepted);
- sink.Monitor(accepted.get());
+ scoped_ptr<AsyncSocket> sender(server->Accept(&accept_addr));
+ ASSERT_TRUE(sender);
+ sink.Monitor(sender.get());
// Both sides are now connected.
- EXPECT_EQ_WAIT(AsyncSocket::CS_CONNECTED, client->GetState(), kTimeout);
- EXPECT_TRUE(sink.Check(client.get(), testing::SSE_OPEN));
- EXPECT_EQ(client->GetRemoteAddress(), accepted->GetLocalAddress());
- EXPECT_EQ(accepted->GetRemoteAddress(), client->GetLocalAddress());
+ EXPECT_EQ_WAIT(AsyncSocket::CS_CONNECTED, receiver->GetState(), kTimeout);
+ EXPECT_TRUE(sink.Check(receiver.get(), testing::SSE_OPEN));
+ EXPECT_EQ(receiver->GetRemoteAddress(), sender->GetLocalAddress());
+ EXPECT_EQ(sender->GetRemoteAddress(), receiver->GetLocalAddress());
+
+ // Create test data.
+ rtc::Buffer send_buffer(0, data_size);
+ rtc::Buffer recv_buffer(0, data_size);
+ for (size_t i = 0; i < data_size; ++i) {
+ char ch = static_cast<char>(i % 256);
+ send_buffer.AppendData(&ch, sizeof(ch));
+ }
// Send and receive a bunch of data.
- bool send_waiting_for_writability = false;
- bool send_expect_success = true;
- bool recv_waiting_for_readability = true;
- bool recv_expect_success = false;
- int data_in_flight = 0;
- while (recv_pos < kDataSize) {
- // Send as much as we can if we've been cleared to send.
- while (!send_waiting_for_writability && send_pos < kDataSize) {
- int tosend = static_cast<int>(kDataSize - send_pos);
- int sent = accepted->Send(send_buffer.get() + send_pos, tosend);
- if (send_expect_success) {
+ size_t sent_size = 0;
+ bool writable = true;
+ bool send_called = false;
+ bool readable = false;
+ bool recv_called = false;
+ while (recv_buffer.size() < send_buffer.size()) {
+ // Send as much as we can while we're cleared to send.
+ while (writable && sent_size < send_buffer.size()) {
+ int unsent_size = static_cast<int>(send_buffer.size() - sent_size);
+ int sent = sender->Send(send_buffer.data() + sent_size, unsent_size);
+ if (!send_called) {
// The first Send() after connecting or getting writability should
// succeed and send some data.
EXPECT_GT(sent, 0);
- send_expect_success = false;
+ send_called = true;
}
if (sent >= 0) {
- EXPECT_LE(sent, tosend);
- send_pos += sent;
- data_in_flight += sent;
+ EXPECT_LE(sent, unsent_size);
+ sent_size += sent;
+ if (max_send_size >= 0) {
+ EXPECT_LE(static_cast<ssize_t>(sent), max_send_size);
+ if (sent < unsent_size) {
+ // If max_send_size is limiting the amount to send per call such
+ // that the sent amount is less than the unsent amount, we simulate
+ // that the socket is no longer writable.
+ writable = false;
+ }
+ }
} else {
- ASSERT_TRUE(accepted->IsBlocking());
- send_waiting_for_writability = true;
+ ASSERT_TRUE(sender->IsBlocking());
+ writable = false;
}
}
// Read all the sent data.
- while (data_in_flight > 0) {
- if (recv_waiting_for_readability) {
+ while (recv_buffer.size() < sent_size) {
+ if (!readable) {
// Wait until data is available.
- EXPECT_TRUE_WAIT(sink.Check(client.get(), testing::SSE_READ), kTimeout);
- recv_waiting_for_readability = false;
- recv_expect_success = true;
+ EXPECT_TRUE_WAIT(sink.Check(receiver.get(), testing::SSE_READ),
+ kTimeout);
+ readable = true;
+ recv_called = false;
}
// Receive as much as we can get in a single recv call.
- int rcvd = client->Recv(recv_buffer.get() + recv_pos,
- kDataSize - recv_pos);
+ char recved_data[data_size];
+ int recved_size = receiver->Recv(recved_data, data_size);
- if (recv_expect_success) {
+ if (!recv_called) {
// The first Recv() after getting readability should succeed and receive
// some data.
// TODO: The following line is disabled due to flakey pulse
// builds. Re-enable if/when possible.
- // EXPECT_GT(rcvd, 0);
- recv_expect_success = false;
+ // EXPECT_GT(recved_size, 0);
+ recv_called = true;
}
- if (rcvd >= 0) {
- EXPECT_LE(rcvd, data_in_flight);
- recv_pos += rcvd;
- data_in_flight -= rcvd;
+ if (recved_size >= 0) {
+ EXPECT_LE(static_cast<size_t>(recved_size),
+ sent_size - recv_buffer.size());
+ recv_buffer.AppendData(recved_data, recved_size);
} else {
- ASSERT_TRUE(client->IsBlocking());
- recv_waiting_for_readability = true;
+ ASSERT_TRUE(receiver->IsBlocking());
+ readable = false;
}
}
- // Once all that we've sent has been rcvd, expect to be able to send again.
- if (send_waiting_for_writability) {
- EXPECT_TRUE_WAIT(sink.Check(accepted.get(), testing::SSE_WRITE),
+ // Once all that we've sent has been received, expect to be able to send
+ // again.
+ if (!writable) {
+ EXPECT_TRUE_WAIT(sink.Check(sender.get(), testing::SSE_WRITE),
kTimeout);
- send_waiting_for_writability = false;
- send_expect_success = true;
+ writable = true;
+ send_called = false;
}
}
// The received data matches the sent data.
- EXPECT_EQ(kDataSize, send_pos);
- EXPECT_EQ(kDataSize, recv_pos);
- EXPECT_EQ(0, memcmp(recv_buffer.get(), send_buffer.get(), kDataSize));
+ EXPECT_EQ(data_size, sent_size);
+ EXPECT_EQ(data_size, recv_buffer.size());
+ EXPECT_EQ(recv_buffer, send_buffer);
// Close down.
- accepted->Close();
- EXPECT_EQ_WAIT(AsyncSocket::CS_CLOSED, client->GetState(), kTimeout);
- EXPECT_TRUE(sink.Check(client.get(), testing::SSE_CLOSE));
- client->Close();
+ sender->Close();
+ EXPECT_EQ_WAIT(AsyncSocket::CS_CLOSED, receiver->GetState(), kTimeout);
+ EXPECT_TRUE(sink.Check(receiver.get(), testing::SSE_CLOSE));
+ receiver->Close();
}
void SocketTest::SingleFlowControlCallbackInternal(const IPAddress& loopback) {
diff --git a/webrtc/base/socket_unittest.h b/webrtc/base/socket_unittest.h
index e4a6b32..adc69f1 100644
--- a/webrtc/base/socket_unittest.h
+++ b/webrtc/base/socket_unittest.h
@@ -62,6 +62,10 @@
const IPAddress kIPv4Loopback;
const IPAddress kIPv6Loopback;
+ protected:
+ void TcpInternal(const IPAddress& loopback, size_t data_size,
+ ssize_t max_send_size);
+
private:
void ConnectInternal(const IPAddress& loopback);
void ConnectWithDnsLookupInternal(const IPAddress& loopback,
@@ -76,7 +80,6 @@
void ServerCloseInternal(const IPAddress& loopback);
void CloseInClosedCallbackInternal(const IPAddress& loopback);
void SocketServerWaitInternal(const IPAddress& loopback);
- void TcpInternal(const IPAddress& loopback);
void SingleFlowControlCallbackInternal(const IPAddress& loopback);
void UdpInternal(const IPAddress& loopback);
void UdpReadyToSend(const IPAddress& loopback);