Add a lock to NSSContext to fix data race
BUG=crbug/466784
R=juberti@webrtc.org, tommi@webrtc.org
Review URL: https://webrtc-codereview.appspot.com/44669005
Cr-Commit-Position: refs/heads/master@{#8871}
diff --git a/webrtc/base/criticalsection_unittest.cc b/webrtc/base/criticalsection_unittest.cc
index 8c36e3d..6f3c7e9 100644
--- a/webrtc/base/criticalsection_unittest.cc
+++ b/webrtc/base/criticalsection_unittest.cc
@@ -26,16 +26,54 @@
const int kNumThreads = 16;
const int kOperationsToRun = 1000;
-template <class T>
-class AtomicOpRunner : public MessageHandler {
+class UniqueValueVerifier {  // Fails the test if any value is produced twice across all Verify calls.
 public:
- explicit AtomicOpRunner(int initial_value)
- : value_(initial_value),
- threads_active_(0),
- start_event_(true, false),
- done_event_(true, false) {}
+ void Verify(const std::vector<int>& values) {  // Mutates all_values_ without locking; callers must serialize calls.
+ for (size_t i = 0; i < values.size(); ++i) {
+ std::pair<std::set<int>::iterator, bool> result =
+ all_values_.insert(values[i]);
+ // Each value should only be taken by one thread, so if this value
+ // has already been added, something went wrong.
+ EXPECT_TRUE(result.second)
+ << " Thread=" << Thread::Current() << " value=" << values[i];
+ }
+ }
- int value() const { return value_; }
+ void Finalize() {}  // Nothing to check at the end; uniqueness is enforced in Verify.
+
+ private:
+ std::set<int> all_values_;  // Every value seen so far, across all threads.
+};
+
+class CompareAndSwapVerifier {  // Checks the results of CompareAndSwap(i, 0, 1) across all threads.
+ public:
+ CompareAndSwapVerifier() : zero_count_(0) {}
+
+ void Verify(const std::vector<int>& values) {  // At most one result overall may be 0, the rest must be 1. Not locked: callers serialize calls.
+ for (auto v : values) {
+ if (v == 0) {
+ EXPECT_EQ(0, zero_count_) << " Thread=" << Thread::Current();
+ ++zero_count_;
+ } else {
+ EXPECT_EQ(1, v) << " Thread=" << Thread::Current();
+ }
+ }
+ }
+
+ void Finalize() {  // Call once after every Verify call: exactly one 0 must have been seen.
+ EXPECT_EQ(1, zero_count_);
+ }
+ private:
+ int zero_count_;  // Number of 0 results seen so far.
+};
+
+class RunnerBase : public MessageHandler {  // Shared start/finish plumbing for the multi-threaded test runners.
+ public:
+ explicit RunnerBase(int value)  // value: initial contents of shared_value_.
+ : threads_active_(0),
+ start_event_(true, false),
+ done_event_(true, false),
+ shared_value_(value) {}
bool Run() {
// Signal all threads to start.
@@ -49,43 +87,101 @@
threads_active_ = count;
}
- virtual void OnMessage(Message* msg) {
+ int shared_value() const { return shared_value_; }  // The tests read this after Run() returns.
+
+ protected:
+ // Derived classes must override OnMessage, and call BeforeStart and AfterEnd
+ // at the beginning and the end of OnMessage respectively.
+ void BeforeStart() {
+ ASSERT_TRUE(start_event_.Wait(kLongTime));  // NOTE(review): ASSERT_* only returns from BeforeStart, so the caller's OnMessage keeps running on timeout.
+ }
+
+ // Returns true if all threads have finished.
+ bool AfterEnd() {
+ if (AtomicOps::Decrement(&threads_active_) == 0) {  // True only for the last thread to finish.
+ done_event_.Set();  // NOTE(review): set before the caller's post-AfterEnd work (e.g. Finalize); presumably what Run() waits on.
+ return true;
+ }
+ return false;
+ }
+
+ int threads_active_;  // Threads that have not yet called AfterEnd.
+ Event start_event_;
+ Event done_event_;
+ int shared_value_;  // State the runner threads operate on concurrently.
+};
+
+class LOCKABLE CriticalSectionLock {  // Adapts CriticalSection to the Lock()/Unlock() interface LockRunner calls.
+ public:
+ void Lock() EXCLUSIVE_LOCK_FUNCTION() {
+ cs_.Enter();
+ }
+ void Unlock() UNLOCK_FUNCTION() {
+ cs_.Leave();
+ }
+
+ private:
+ CriticalSection cs_;
+};
+
+template <class Lock>  // Lock must be default-constructible and provide Lock() and Unlock().
+class LockRunner : public RunnerBase {  // Each thread mutates shared_value_ under lock_, then restores it to 0.
+ public:
+ LockRunner() : RunnerBase(0) {}
+
+ void OnMessage(Message* msg) override {
+ BeforeStart();
+
+ lock_.Lock();
+
+ EXPECT_EQ(0, shared_value_);  // The previous holder must have reset the value before unlocking.
+ int old = shared_value_;
+
+ // Use a loop to increase the chance of race.
+ for (int i = 0; i < kOperationsToRun; ++i) {
+ ++shared_value_;
+ }
+ EXPECT_EQ(old + kOperationsToRun, shared_value_);
+ shared_value_ = 0;  // Leave the value as we found it for the next holder.
+
+ lock_.Unlock();
+
+ AfterEnd();
+ }
+
+ private:
+ Lock lock_;
+};
+
+template <class Op, class Verifier>  // Op: static int AtomicOp(int*). Verifier: Verify(const std::vector<int>&) and Finalize().
+class AtomicOpRunner : public RunnerBase {  // Each thread applies Op kOperationsToRun times and verifies the results.
+ public:
+ explicit AtomicOpRunner(int initial_value) : RunnerBase(initial_value) {}
+
+ void OnMessage(Message* msg) override {
+ BeforeStart();
+
std::vector<int> values;
values.reserve(kOperationsToRun);
- // Wait to start.
- ASSERT_TRUE(start_event_.Wait(kLongTime));
-
- // Generate a bunch of values by updating value_ atomically.
+ // Generate a bunch of values by updating shared_value_ atomically.
for (int i = 0; i < kOperationsToRun; ++i) {
- values.push_back(T::AtomicOp(&value_));
+ values.push_back(Op::AtomicOp(&shared_value_));
}
{ // Add them all to the set.
CritScope cs(&all_values_crit_);
- for (size_t i = 0; i < values.size(); ++i) {
- std::pair<std::set<int>::iterator, bool> result =
- all_values_.insert(values[i]);
- // Each value should only be taken by one thread, so if this value
- // has already been added, something went wrong.
- EXPECT_TRUE(result.second)
- << "Thread=" << Thread::Current() << " value=" << values[i];
- }
+ verifier_.Verify(values);  // Serialized by all_values_crit_.
}
- // Signal that we're done.
- if (AtomicOps::Decrement(&threads_active_) == 0) {
- done_event_.Set();
+ if (AfterEnd()) {  // Only the last thread to finish runs Finalize.
+ verifier_.Finalize();
}
}
 private:
- int value_;
- int threads_active_;
CriticalSection all_values_crit_;
- std::set<int> all_values_;
- Event start_event_;
- Event done_event_;
+ Verifier verifier_;  // Verify calls are serialized by all_values_crit_.
};
struct IncrementOp {
@@ -96,6 +192,10 @@
static int AtomicOp(int* i) { return AtomicOps::Decrement(i); }
};
+struct CompareAndSwapOp {  // Tries to swap *i from 0 to 1; presumably returns the prior value (the verifier expects a single 0).
+ static int AtomicOp(int* i) { return AtomicOps::CompareAndSwap(i, 0, 1); }
+};
+
void StartThreads(ScopedPtrCollection<Thread>* threads,
MessageHandler* handler) {
for (int i = 0; i < kNumThreads; ++i) {
@@ -122,26 +222,63 @@
TEST(AtomicOpsTest, Increment) {
// Create and start lots of threads.
- AtomicOpRunner<IncrementOp> runner(0);
+ AtomicOpRunner<IncrementOp, UniqueValueVerifier> runner(0);  // Every increment result must be unique.
ScopedPtrCollection<Thread> threads;
StartThreads(&threads, &runner);
runner.SetExpectedThreadCount(kNumThreads);
// Release the hounds!
EXPECT_TRUE(runner.Run());
- EXPECT_EQ(kOperationsToRun * kNumThreads, runner.value());
+ EXPECT_EQ(kOperationsToRun * kNumThreads, runner.shared_value());  // No increment may be lost.
}
TEST(AtomicOpsTest, Decrement) {
// Create and start lots of threads.
- AtomicOpRunner<DecrementOp> runner(kOperationsToRun * kNumThreads);
+ AtomicOpRunner<DecrementOp, UniqueValueVerifier> runner(
+ kOperationsToRun * kNumThreads);  // Exactly the total number of decrements performed.
ScopedPtrCollection<Thread> threads;
StartThreads(&threads, &runner);
runner.SetExpectedThreadCount(kNumThreads);
// Release the hounds!
EXPECT_TRUE(runner.Run());
- EXPECT_EQ(0, runner.value());
+ EXPECT_EQ(0, runner.shared_value());  // All decrements together cancel the starting value.
+}
+
+TEST(AtomicOpsTest, CompareAndSwap) {
+ // Create and start lots of threads.
+ AtomicOpRunner<CompareAndSwapOp, CompareAndSwapVerifier> runner(0);  // Only one call may win the 0 -> 1 swap.
+ ScopedPtrCollection<Thread> threads;
+ StartThreads(&threads, &runner);
+ runner.SetExpectedThreadCount(kNumThreads);
+
+ // Release the hounds!
+ EXPECT_TRUE(runner.Run());
+ EXPECT_EQ(1, runner.shared_value());  // The single successful swap leaves 1 behind.
+}
+
+TEST(GlobalLockTest, Basic) {
+ // Create and start lots of threads.
+ LockRunner<GlobalLock> runner;  // Exercises GlobalLock through LockRunner's Lock()/Unlock() calls.
+ ScopedPtrCollection<Thread> threads;
+ StartThreads(&threads, &runner);
+ runner.SetExpectedThreadCount(kNumThreads);
+
+ // Release the hounds!
+ EXPECT_TRUE(runner.Run());
+ EXPECT_EQ(0, runner.shared_value());  // Every lock holder restored the value to 0.
+}
+
+TEST(CriticalSectionTest, Basic) {
+ // Create and start lots of threads.
+ LockRunner<CriticalSectionLock> runner;  // Exercises CriticalSection via the CriticalSectionLock adapter.
+ ScopedPtrCollection<Thread> threads;
+ StartThreads(&threads, &runner);
+ runner.SetExpectedThreadCount(kNumThreads);
+
+ // Release the hounds!
+ EXPECT_TRUE(runner.Run());
+ EXPECT_EQ(0, runner.shared_value());  // Every lock holder restored the value to 0.
}
} // namespace rtc