Allow setting a custom randomness source.

This is useful in environments where OpenSSL may not be available.

Bug: webrtc:15240
Change-Id: I7ba29e44bd1d25231df13ee79dacc74f260ded67
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/308600
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Commit-Queue: Sameer Vijaykar <samvi@google.com>
Cr-Commit-Position: refs/heads/main@{#40293}
diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn
index 561cb9a..ac27ee8 100644
--- a/rtc_base/BUILD.gn
+++ b/rtc_base/BUILD.gn
@@ -1455,6 +1455,7 @@
     "../api/task_queue:pending_task_safety_flag",
     "../api/units:time_delta",
     "../system_wrappers:field_trial",
+    "synchronization:mutex",
     "system:rtc_export",
     "task_utils:repeating_task",
     "third_party/base64",
diff --git a/rtc_base/helpers.cc b/rtc_base/helpers.cc
index 3372398..84cbe5f 100644
--- a/rtc_base/helpers.cc
+++ b/rtc_base/helpers.cc
@@ -19,19 +19,14 @@
 #include "absl/strings/string_view.h"
 #include "rtc_base/checks.h"
 #include "rtc_base/logging.h"
+#include "rtc_base/synchronization/mutex.h"
 
 // Protect against max macro inclusion.
 #undef max
 
 namespace rtc {
 
-// Base class for RNG implementations.
-class RandomGenerator {
- public:
-  virtual ~RandomGenerator() {}
-  virtual bool Init(const void* seed, size_t len) = 0;
-  virtual bool Generate(void* buf, size_t len) = 0;
-};
+namespace {
 
 // The OpenSSL RNG.
 class SecureRandomGenerator : public RandomGenerator {
@@ -64,8 +59,6 @@
   int seed_;
 };
 
-namespace {
-
 // TODO: Use Base64::Base64Table instead.
 static const char kBase64[64] = {
     'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
@@ -79,6 +72,13 @@
 
 static const char kUuidDigit17[4] = {'8', '9', 'a', 'b'};
 
+// Lock for the global random generator, only needed to serialize changing the
+// generator.
+webrtc::Mutex& GetRandomGeneratorLock() {
+  static webrtc::Mutex& mutex = *new webrtc::Mutex();
+  return mutex;
+}
+
 // This round about way of creating a global RNG is to safe-guard against
 // indeterminant static initialization order.
 std::unique_ptr<RandomGenerator>& GetGlobalRng() {
@@ -94,7 +94,18 @@
 
 }  // namespace
 
+void SetDefaultRandomGenerator() {
+  webrtc::MutexLock lock(&GetRandomGeneratorLock());
+  GetGlobalRng().reset(new SecureRandomGenerator());
+}
+
+void SetRandomGenerator(std::unique_ptr<RandomGenerator> generator) {
+  webrtc::MutexLock lock(&GetRandomGeneratorLock());
+  GetGlobalRng() = std::move(generator);
+}
+
 void SetRandomTestMode(bool test) {
+  webrtc::MutexLock lock(&GetRandomGeneratorLock());
   if (!test) {
     GetGlobalRng().reset(new SecureRandomGenerator());
   } else {
diff --git a/rtc_base/helpers.h b/rtc_base/helpers.h
index c214f52..51ca672 100644
--- a/rtc_base/helpers.h
+++ b/rtc_base/helpers.h
@@ -14,6 +14,7 @@
 #include <stddef.h>
 #include <stdint.h>
 
+#include <memory>
 #include <string>
 
 #include "absl/strings/string_view.h"
@@ -21,6 +22,23 @@
 
 namespace rtc {
 
+// Interface for RNG implementations.
+class RandomGenerator {
+ public:
+  virtual ~RandomGenerator() {}
+  virtual bool Init(const void* seed, size_t len) = 0;
+  virtual bool Generate(void* buf, size_t len) = 0;
+};
+
+// Sets the default random generator as the source of randomness. The default
+// source uses the OpenSSL RNG and provides cryptographically secure randomness.
+void SetDefaultRandomGenerator();
+
+// Set a custom random generator. Results produced by CreateRandomXyz
+// are cryptographically random iff the output of the supplied generator is
+// cryptographically random.
+void SetRandomGenerator(std::unique_ptr<RandomGenerator> generator);
+
 // For testing, we can return predictable data.
 void SetRandomTestMode(bool test);
 
diff --git a/rtc_base/helpers_unittest.cc b/rtc_base/helpers_unittest.cc
index b855872..015b4d0 100644
--- a/rtc_base/helpers_unittest.cc
+++ b/rtc_base/helpers_unittest.cc
@@ -12,20 +12,30 @@
 
 #include <string.h>
 
+#include <cstring>
 #include <string>
 
 #include "rtc_base/buffer.h"
+#include "test/gmock.h"
 #include "test/gtest.h"
 
 namespace rtc {
+namespace {
 
-class RandomTest : public ::testing::Test {};
+using ::testing::_;
+using ::testing::DoAll;
+using ::testing::Invoke;
+using ::testing::IsEmpty;
+using ::testing::Not;
+using ::testing::Return;
+using ::testing::WithArg;
+using ::testing::WithArgs;
 
-TEST_F(RandomTest, TestCreateRandomId) {
+TEST(RandomTest, TestCreateRandomId) {
   CreateRandomId();
 }
 
-TEST_F(RandomTest, TestCreateRandomDouble) {
+TEST(RandomTest, TestCreateRandomDouble) {
   for (int i = 0; i < 100; ++i) {
     double r = CreateRandomDouble();
     EXPECT_GE(r, 0.0);
@@ -33,11 +43,11 @@
   }
 }
 
-TEST_F(RandomTest, TestCreateNonZeroRandomId) {
+TEST(RandomTest, TestCreateNonZeroRandomId) {
   EXPECT_NE(0U, CreateRandomNonZeroId());
 }
 
-TEST_F(RandomTest, TestCreateRandomString) {
+TEST(RandomTest, TestCreateRandomString) {
   std::string random = CreateRandomString(256);
   EXPECT_EQ(256U, random.size());
   std::string random2;
@@ -46,7 +56,7 @@
   EXPECT_EQ(256U, random2.size());
 }
 
-TEST_F(RandomTest, TestCreateRandomData) {
+TEST(RandomTest, TestCreateRandomData) {
   static size_t kRandomDataLength = 32;
   std::string random1;
   std::string random2;
@@ -57,7 +67,7 @@
   EXPECT_NE(0, memcmp(random1.data(), random2.data(), kRandomDataLength));
 }
 
-TEST_F(RandomTest, TestCreateRandomStringEvenlyDivideTable) {
+TEST(RandomTest, TestCreateRandomStringEvenlyDivideTable) {
   static std::string kUnbiasedTable("01234567");
   std::string random;
   EXPECT_TRUE(CreateRandomString(256, kUnbiasedTable, &random));
@@ -68,12 +78,12 @@
   EXPECT_EQ(0U, random.size());
 }
 
-TEST_F(RandomTest, TestCreateRandomUuid) {
+TEST(RandomTest, TestCreateRandomUuid) {
   std::string random = CreateRandomUuid();
   EXPECT_EQ(36U, random.size());
 }
 
-TEST_F(RandomTest, TestCreateRandomForTest) {
+TEST(RandomTest, TestCreateRandomForTest) {
   // Make sure we get the output we expect.
   SetRandomTestMode(true);
   EXPECT_EQ(2154761789U, CreateRandomId());
@@ -112,4 +122,50 @@
   SetRandomTestMode(false);
 }
 
+class MockRandomGenerator : public RandomGenerator {
+ public:
+  MOCK_METHOD(void, Die, ());
+  ~MockRandomGenerator() override { Die(); }
+
+  MOCK_METHOD(bool, Init, (const void* seed, size_t len), (override));
+  MOCK_METHOD(bool, Generate, (void* buf, size_t len), (override));
+};
+
+TEST(RandomTest, TestSetRandomGenerator) {
+  std::unique_ptr<MockRandomGenerator> will_move =
+      std::make_unique<MockRandomGenerator>();
+  MockRandomGenerator* generator = will_move.get();
+  SetRandomGenerator(std::move(will_move));
+
+  EXPECT_CALL(*generator, Init(_, sizeof(int))).WillOnce(Return(true));
+  EXPECT_TRUE(InitRandom(5));
+
+  std::string seed = "seed";
+  EXPECT_CALL(*generator, Init(seed.data(), seed.size()))
+      .WillOnce(Return(true));
+  EXPECT_TRUE(InitRandom(seed.data(), seed.size()));
+
+  uint32_t id = 4658;
+  EXPECT_CALL(*generator, Generate(_, sizeof(uint32_t)))
+      .WillOnce(DoAll(WithArg<0>(Invoke([&id](void* p) {
+                        std::memcpy(p, &id, sizeof(uint32_t));
+                      })),
+                      Return(true)));
+  EXPECT_EQ(CreateRandomId(), id);
+
+  EXPECT_CALL(*generator, Generate)
+      .WillOnce(DoAll(
+          WithArgs<0, 1>([](void* p, size_t len) { std::memset(p, 0, len); }),
+          Return(true)));
+  EXPECT_THAT(CreateRandomUuid(), Not(IsEmpty()));
+
+  // Set the default random generator, and expect that mock generator is
+  // not used beyond this point.
+  EXPECT_CALL(*generator, Die);
+  EXPECT_CALL(*generator, Generate).Times(0);
+  SetDefaultRandomGenerator();
+  EXPECT_THAT(CreateRandomUuid(), Not(IsEmpty()));
+}
+
+}  // namespace
 }  // namespace rtc