Add a lock to NSSContext to fix data race
BUG=crbug/466784
R=juberti@webrtc.org, tommi@webrtc.org
Review URL: https://webrtc-codereview.appspot.com/44669005
Cr-Original-Commit-Position: refs/heads/master@{#8871}
Cr-Mirrored-From: https://chromium.googlesource.com/external/webrtc
Cr-Mirrored-Commit: bef8d2d0208768d006e639701658ad4c9e73a69f
diff --git a/base/BUILD.gn b/base/BUILD.gn
index fba03aa..47b05be 100644
--- a/base/BUILD.gn
+++ b/base/BUILD.gn
@@ -188,6 +188,7 @@
"cpumonitor.h",
"crc32.cc",
"crc32.h",
+ "criticalsection.cc",
"criticalsection.h",
"cryptstring.cc",
"cryptstring.h",
diff --git a/base/base.gyp b/base/base.gyp
index 5da97a6..ac23529 100644
--- a/base/base.gyp
+++ b/base/base.gyp
@@ -114,6 +114,7 @@
'cpumonitor.h',
'crc32.cc',
'crc32.h',
+ 'criticalsection.cc',
'criticalsection.h',
'cryptstring.cc',
'cryptstring.h',
diff --git a/base/criticalsection.cc b/base/criticalsection.cc
new file mode 100644
index 0000000..fcad5c3
--- /dev/null
+++ b/base/criticalsection.cc
@@ -0,0 +1,33 @@
+/*
+ * Copyright 2015 The WebRTC Project Authors. All rights reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "webrtc/base/criticalsection.h"
+
+#include "webrtc/base/checks.h"
+#include "webrtc/base/thread.h"
+
+namespace rtc {
+
+void GlobalLockPod::Lock() {
+ while (AtomicOps::CompareAndSwap(&lock_acquired, 0, 1)) {
+ Thread::SleepMs(0);
+ }
+}
+
+void GlobalLockPod::Unlock() {
+ int old_value = AtomicOps::CompareAndSwap(&lock_acquired, 1, 0);
+ DCHECK_EQ(1, old_value) << "Unlock called without calling Lock first";
+}
+
+GlobalLock::GlobalLock() {
+ lock_acquired = 0;
+}
+
+} // namespace rtc
diff --git a/base/criticalsection.h b/base/criticalsection.h
index 8d6ddbe..db197d2 100644
--- a/base/criticalsection.h
+++ b/base/criticalsection.h
@@ -169,6 +169,11 @@
static void Store(volatile int* i, int value) {
*i = value;
}
+ static int CompareAndSwap(volatile int* i, int old_value, int new_value) {
+ return ::InterlockedCompareExchange(reinterpret_cast<volatile LONG*>(i),
+ new_value,
+ old_value);
+ }
#else
static int Increment(volatile int* i) {
return __sync_add_and_fetch(i, 1);
@@ -184,9 +189,29 @@
__sync_synchronize();
*i = value;
}
+ static int CompareAndSwap(volatile int* i, int old_value, int new_value) {
+ return __sync_val_compare_and_swap(i, old_value, new_value);
+ }
#endif
};
+
+// A POD lock used to protect global variables. Do NOT use for other purposes.
+// No custom constructor or private data member should be added.
+class LOCKABLE GlobalLockPod {
+ public:
+ void Lock() EXCLUSIVE_LOCK_FUNCTION();
+
+ void Unlock() UNLOCK_FUNCTION();
+
+ volatile int lock_acquired;
+};
+
+class GlobalLock : public GlobalLockPod {
+ public:
+ GlobalLock();
+};
+
} // namespace rtc
#endif // WEBRTC_BASE_CRITICALSECTION_H__
diff --git a/base/criticalsection_unittest.cc b/base/criticalsection_unittest.cc
index 8c36e3d..6f3c7e9 100644
--- a/base/criticalsection_unittest.cc
+++ b/base/criticalsection_unittest.cc
@@ -26,16 +26,54 @@
const int kNumThreads = 16;
const int kOperationsToRun = 1000;
-template <class T>
-class AtomicOpRunner : public MessageHandler {
+class UniqueValueVerifier {
public:
- explicit AtomicOpRunner(int initial_value)
- : value_(initial_value),
- threads_active_(0),
- start_event_(true, false),
- done_event_(true, false) {}
+ void Verify(const std::vector<int>& values) {
+ for (size_t i = 0; i < values.size(); ++i) {
+ std::pair<std::set<int>::iterator, bool> result =
+ all_values_.insert(values[i]);
+ // Each value should only be taken by one thread, so if this value
+ // has already been added, something went wrong.
+ EXPECT_TRUE(result.second)
+ << " Thread=" << Thread::Current() << " value=" << values[i];
+ }
+ }
- int value() const { return value_; }
+ void Finalize() {}
+
+ private:
+ std::set<int> all_values_;
+};
+
+class CompareAndSwapVerifier {
+ public:
+ CompareAndSwapVerifier() : zero_count_(0) {}
+
+ void Verify(const std::vector<int>& values) {
+ for (auto v : values) {
+ if (v == 0) {
+ EXPECT_EQ(0, zero_count_) << "Thread=" << Thread::Current();
+ ++zero_count_;
+ } else {
+ EXPECT_EQ(1, v) << " Thread=" << Thread::Current();
+ }
+ }
+ }
+
+ void Finalize() {
+ EXPECT_EQ(1, zero_count_);
+ }
+ private:
+ int zero_count_;
+};
+
+class RunnerBase : public MessageHandler {
+ public:
+ explicit RunnerBase(int value)
+ : threads_active_(0),
+ start_event_(true, false),
+ done_event_(true, false),
+ shared_value_(value) {}
bool Run() {
// Signal all threads to start.
@@ -49,43 +87,101 @@
threads_active_ = count;
}
- virtual void OnMessage(Message* msg) {
+ int shared_value() const { return shared_value_; }
+
+ protected:
+ // Derived classes must override OnMessage, and call BeforeStart and AfterEnd
+ // at the beginning and the end of OnMessage respectively.
+ void BeforeStart() {
+ ASSERT_TRUE(start_event_.Wait(kLongTime));
+ }
+
+ // Returns true if all threads have finished.
+ bool AfterEnd() {
+ if (AtomicOps::Decrement(&threads_active_) == 0) {
+ done_event_.Set();
+ return true;
+ }
+ return false;
+ }
+
+ int threads_active_;
+ Event start_event_;
+ Event done_event_;
+ int shared_value_;
+};
+
+class LOCKABLE CriticalSectionLock {
+ public:
+ void Lock() EXCLUSIVE_LOCK_FUNCTION() {
+ cs_.Enter();
+ }
+ void Unlock() UNLOCK_FUNCTION() {
+ cs_.Leave();
+ }
+
+ private:
+ CriticalSection cs_;
+};
+
+template <class Lock>
+class LockRunner : public RunnerBase {
+ public:
+ LockRunner() : RunnerBase(0) {}
+
+ void OnMessage(Message* msg) override {
+ BeforeStart();
+
+ lock_.Lock();
+
+ EXPECT_EQ(0, shared_value_);
+ int old = shared_value_;
+
+ // Use a loop to increase the chance of race.
+ for (int i = 0; i < kOperationsToRun; ++i) {
+ ++shared_value_;
+ }
+ EXPECT_EQ(old + kOperationsToRun, shared_value_);
+ shared_value_ = 0;
+
+ lock_.Unlock();
+
+ AfterEnd();
+ }
+
+ private:
+ Lock lock_;
+};
+
+template <class Op, class Verifier>
+class AtomicOpRunner : public RunnerBase {
+ public:
+ explicit AtomicOpRunner(int initial_value) : RunnerBase(initial_value) {}
+
+ void OnMessage(Message* msg) override {
+ BeforeStart();
+
std::vector<int> values;
values.reserve(kOperationsToRun);
- // Wait to start.
- ASSERT_TRUE(start_event_.Wait(kLongTime));
-
- // Generate a bunch of values by updating value_ atomically.
+ // Generate a bunch of values by updating shared_value_ atomically.
for (int i = 0; i < kOperationsToRun; ++i) {
- values.push_back(T::AtomicOp(&value_));
+ values.push_back(Op::AtomicOp(&shared_value_));
}
{ // Add them all to the set.
CritScope cs(&all_values_crit_);
- for (size_t i = 0; i < values.size(); ++i) {
- std::pair<std::set<int>::iterator, bool> result =
- all_values_.insert(values[i]);
- // Each value should only be taken by one thread, so if this value
- // has already been added, something went wrong.
- EXPECT_TRUE(result.second)
- << "Thread=" << Thread::Current() << " value=" << values[i];
- }
+ verifier_.Verify(values);
}
- // Signal that we're done.
- if (AtomicOps::Decrement(&threads_active_) == 0) {
- done_event_.Set();
+ if (AfterEnd()) {
+ verifier_.Finalize();
}
}
private:
- int value_;
- int threads_active_;
CriticalSection all_values_crit_;
- std::set<int> all_values_;
- Event start_event_;
- Event done_event_;
+ Verifier verifier_;
};
struct IncrementOp {
@@ -96,6 +192,10 @@
static int AtomicOp(int* i) { return AtomicOps::Decrement(i); }
};
+struct CompareAndSwapOp {
+ static int AtomicOp(int* i) { return AtomicOps::CompareAndSwap(i, 0, 1); }
+};
+
void StartThreads(ScopedPtrCollection<Thread>* threads,
MessageHandler* handler) {
for (int i = 0; i < kNumThreads; ++i) {
@@ -122,26 +222,63 @@
TEST(AtomicOpsTest, Increment) {
// Create and start lots of threads.
- AtomicOpRunner<IncrementOp> runner(0);
+ AtomicOpRunner<IncrementOp, UniqueValueVerifier> runner(0);
ScopedPtrCollection<Thread> threads;
StartThreads(&threads, &runner);
runner.SetExpectedThreadCount(kNumThreads);
// Release the hounds!
EXPECT_TRUE(runner.Run());
- EXPECT_EQ(kOperationsToRun * kNumThreads, runner.value());
+ EXPECT_EQ(kOperationsToRun * kNumThreads, runner.shared_value());
}
TEST(AtomicOpsTest, Decrement) {
// Create and start lots of threads.
- AtomicOpRunner<DecrementOp> runner(kOperationsToRun * kNumThreads);
+ AtomicOpRunner<DecrementOp, UniqueValueVerifier> runner(
+ kOperationsToRun * kNumThreads);
ScopedPtrCollection<Thread> threads;
StartThreads(&threads, &runner);
runner.SetExpectedThreadCount(kNumThreads);
// Release the hounds!
EXPECT_TRUE(runner.Run());
- EXPECT_EQ(0, runner.value());
+ EXPECT_EQ(0, runner.shared_value());
+}
+
+TEST(AtomicOpsTest, CompareAndSwap) {
+ // Create and start lots of threads.
+ AtomicOpRunner<CompareAndSwapOp, CompareAndSwapVerifier> runner(0);
+ ScopedPtrCollection<Thread> threads;
+ StartThreads(&threads, &runner);
+ runner.SetExpectedThreadCount(kNumThreads);
+
+ // Release the hounds!
+ EXPECT_TRUE(runner.Run());
+ EXPECT_EQ(1, runner.shared_value());
+}
+
+TEST(GlobalLockTest, Basic) {
+ // Create and start lots of threads.
+ LockRunner<GlobalLock> runner;
+ ScopedPtrCollection<Thread> threads;
+ StartThreads(&threads, &runner);
+ runner.SetExpectedThreadCount(kNumThreads);
+
+ // Release the hounds!
+ EXPECT_TRUE(runner.Run());
+ EXPECT_EQ(0, runner.shared_value());
+}
+
+TEST(CriticalSectionTest, Basic) {
+ // Create and start lots of threads.
+ LockRunner<CriticalSectionLock> runner;
+ ScopedPtrCollection<Thread> threads;
+ StartThreads(&threads, &runner);
+ runner.SetExpectedThreadCount(kNumThreads);
+
+ // Release the hounds!
+ EXPECT_TRUE(runner.Run());
+ EXPECT_EQ(0, runner.shared_value());
}
} // namespace rtc
diff --git a/base/nssstreamadapter.cc b/base/nssstreamadapter.cc
index 044d00b..fe1692c 100644
--- a/base/nssstreamadapter.cc
+++ b/base/nssstreamadapter.cc
@@ -978,25 +978,27 @@
}
-bool NSSContext::initialized;
+GlobalLockPod NSSContext::lock;
NSSContext *NSSContext::global_nss_context;
// Static initialization and shutdown
NSSContext *NSSContext::Instance() {
+ lock.Lock();
if (!global_nss_context) {
- scoped_ptr<NSSContext> new_ctx(new NSSContext());
- new_ctx->slot_ = PK11_GetInternalSlot();
+ scoped_ptr<NSSContext> new_ctx(new NSSContext(PK11_GetInternalSlot()));
if (new_ctx->slot_)
global_nss_context = new_ctx.release();
}
+ lock.Unlock();
+
return global_nss_context;
}
-
-
bool NSSContext::InitializeSSL(VerificationCallback callback) {
ASSERT(!callback);
+ static bool initialized = false;
+
if (!initialized) {
SECStatus rv;
diff --git a/base/nssstreamadapter.h b/base/nssstreamadapter.h
index 8b58885..fcacb95 100644
--- a/base/nssstreamadapter.h
+++ b/base/nssstreamadapter.h
@@ -19,6 +19,7 @@
#include "secmodt.h"
#include "webrtc/base/buffer.h"
+#include "webrtc/base/criticalsection.h"
#include "webrtc/base/nssidentity.h"
#include "webrtc/base/ssladapter.h"
#include "webrtc/base/sslstreamadapter.h"
@@ -29,7 +30,7 @@
// Singleton
class NSSContext {
public:
- NSSContext() {}
+ explicit NSSContext(PK11SlotInfo* slot) : slot_(slot) {}
~NSSContext() {
}
@@ -44,7 +45,7 @@
private:
PK11SlotInfo *slot_; // The PKCS-11 slot
- static bool initialized; // Was this initialized?
+ static GlobalLockPod lock; // To protect the global context
static NSSContext *global_nss_context; // The global context
};