Allow setting a custom randomness source.
This is useful in environments where OpenSSL may not be available.
Bug: webrtc:15240
Change-Id: I7ba29e44bd1d25231df13ee79dacc74f260ded67
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/308600
Reviewed-by: Harald Alvestrand <hta@webrtc.org>
Commit-Queue: Sameer Vijaykar <samvi@google.com>
Cr-Commit-Position: refs/heads/main@{#40293}
diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn
index 561cb9a..ac27ee8 100644
--- a/rtc_base/BUILD.gn
+++ b/rtc_base/BUILD.gn
@@ -1455,6 +1455,7 @@
"../api/task_queue:pending_task_safety_flag",
"../api/units:time_delta",
"../system_wrappers:field_trial",
+ "synchronization:mutex",
"system:rtc_export",
"task_utils:repeating_task",
"third_party/base64",
diff --git a/rtc_base/helpers.cc b/rtc_base/helpers.cc
index 3372398..84cbe5f 100644
--- a/rtc_base/helpers.cc
+++ b/rtc_base/helpers.cc
@@ -19,19 +19,14 @@
#include "absl/strings/string_view.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
+#include "rtc_base/synchronization/mutex.h"
// Protect against max macro inclusion.
#undef max
namespace rtc {
-// Base class for RNG implementations.
-class RandomGenerator {
- public:
- virtual ~RandomGenerator() {}
- virtual bool Init(const void* seed, size_t len) = 0;
- virtual bool Generate(void* buf, size_t len) = 0;
-};
+namespace {
// The OpenSSL RNG.
class SecureRandomGenerator : public RandomGenerator {
@@ -64,8 +59,6 @@
int seed_;
};
-namespace {
-
// TODO: Use Base64::Base64Table instead.
static const char kBase64[64] = {
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
@@ -79,6 +72,13 @@
static const char kUuidDigit17[4] = {'8', '9', 'a', 'b'};
+// Lock for the global random generator, only needed to serialize changing the
+// generator.
+webrtc::Mutex& GetRandomGeneratorLock() {
+ static webrtc::Mutex& mutex = *new webrtc::Mutex();
+ return mutex;
+}
+
// This round about way of creating a global RNG is to safe-guard against
// indeterminant static initialization order.
std::unique_ptr<RandomGenerator>& GetGlobalRng() {
@@ -94,7 +94,18 @@
} // namespace
+void SetDefaultRandomGenerator() {
+ webrtc::MutexLock lock(&GetRandomGeneratorLock());
+ GetGlobalRng().reset(new SecureRandomGenerator());
+}
+
+void SetRandomGenerator(std::unique_ptr<RandomGenerator> generator) {
+ webrtc::MutexLock lock(&GetRandomGeneratorLock());
+ GetGlobalRng() = std::move(generator);
+}
+
void SetRandomTestMode(bool test) {
+ webrtc::MutexLock lock(&GetRandomGeneratorLock());
if (!test) {
GetGlobalRng().reset(new SecureRandomGenerator());
} else {
diff --git a/rtc_base/helpers.h b/rtc_base/helpers.h
index c214f52..51ca672 100644
--- a/rtc_base/helpers.h
+++ b/rtc_base/helpers.h
@@ -14,6 +14,7 @@
#include <stddef.h>
#include <stdint.h>
+#include <memory>
#include <string>
#include "absl/strings/string_view.h"
@@ -21,6 +22,23 @@
namespace rtc {
+// Interface for RNG implementations.
+class RandomGenerator {
+ public:
+ virtual ~RandomGenerator() {}
+ virtual bool Init(const void* seed, size_t len) = 0;
+ virtual bool Generate(void* buf, size_t len) = 0;
+};
+
+// Sets the default random generator as the source of randomness. The default
+// source uses the OpenSSL RNG and provides cryptographically secure randomness.
+void SetDefaultRandomGenerator();
+
+// Set a custom random generator. Results produced by CreateRandomXyz
+// are cryptographically random iff the output of the supplied generator is
+// cryptographically random.
+void SetRandomGenerator(std::unique_ptr<RandomGenerator> generator);
+
// For testing, we can return predictable data.
void SetRandomTestMode(bool test);
diff --git a/rtc_base/helpers_unittest.cc b/rtc_base/helpers_unittest.cc
index b855872..015b4d0 100644
--- a/rtc_base/helpers_unittest.cc
+++ b/rtc_base/helpers_unittest.cc
@@ -12,20 +12,30 @@
#include <string.h>
+#include <cstring>
#include <string>
#include "rtc_base/buffer.h"
+#include "test/gmock.h"
#include "test/gtest.h"
namespace rtc {
+namespace {
-class RandomTest : public ::testing::Test {};
+using ::testing::_;
+using ::testing::DoAll;
+using ::testing::Invoke;
+using ::testing::IsEmpty;
+using ::testing::Not;
+using ::testing::Return;
+using ::testing::WithArg;
+using ::testing::WithArgs;
-TEST_F(RandomTest, TestCreateRandomId) {
+TEST(RandomTest, TestCreateRandomId) {
CreateRandomId();
}
-TEST_F(RandomTest, TestCreateRandomDouble) {
+TEST(RandomTest, TestCreateRandomDouble) {
for (int i = 0; i < 100; ++i) {
double r = CreateRandomDouble();
EXPECT_GE(r, 0.0);
@@ -33,11 +43,11 @@
}
}
-TEST_F(RandomTest, TestCreateNonZeroRandomId) {
+TEST(RandomTest, TestCreateNonZeroRandomId) {
EXPECT_NE(0U, CreateRandomNonZeroId());
}
-TEST_F(RandomTest, TestCreateRandomString) {
+TEST(RandomTest, TestCreateRandomString) {
std::string random = CreateRandomString(256);
EXPECT_EQ(256U, random.size());
std::string random2;
@@ -46,7 +56,7 @@
EXPECT_EQ(256U, random2.size());
}
-TEST_F(RandomTest, TestCreateRandomData) {
+TEST(RandomTest, TestCreateRandomData) {
static size_t kRandomDataLength = 32;
std::string random1;
std::string random2;
@@ -57,7 +67,7 @@
EXPECT_NE(0, memcmp(random1.data(), random2.data(), kRandomDataLength));
}
-TEST_F(RandomTest, TestCreateRandomStringEvenlyDivideTable) {
+TEST(RandomTest, TestCreateRandomStringEvenlyDivideTable) {
static std::string kUnbiasedTable("01234567");
std::string random;
EXPECT_TRUE(CreateRandomString(256, kUnbiasedTable, &random));
@@ -68,12 +78,12 @@
EXPECT_EQ(0U, random.size());
}
-TEST_F(RandomTest, TestCreateRandomUuid) {
+TEST(RandomTest, TestCreateRandomUuid) {
std::string random = CreateRandomUuid();
EXPECT_EQ(36U, random.size());
}
-TEST_F(RandomTest, TestCreateRandomForTest) {
+TEST(RandomTest, TestCreateRandomForTest) {
// Make sure we get the output we expect.
SetRandomTestMode(true);
EXPECT_EQ(2154761789U, CreateRandomId());
@@ -112,4 +122,50 @@
SetRandomTestMode(false);
}
+class MockRandomGenerator : public RandomGenerator {
+ public:
+ MOCK_METHOD(void, Die, ());
+ ~MockRandomGenerator() override { Die(); }
+
+ MOCK_METHOD(bool, Init, (const void* seed, size_t len), (override));
+ MOCK_METHOD(bool, Generate, (void* buf, size_t len), (override));
+};
+
+TEST(RandomTest, TestSetRandomGenerator) {
+ std::unique_ptr<MockRandomGenerator> will_move =
+ std::make_unique<MockRandomGenerator>();
+ MockRandomGenerator* generator = will_move.get();
+ SetRandomGenerator(std::move(will_move));
+
+ EXPECT_CALL(*generator, Init(_, sizeof(int))).WillOnce(Return(true));
+ EXPECT_TRUE(InitRandom(5));
+
+ std::string seed = "seed";
+ EXPECT_CALL(*generator, Init(seed.data(), seed.size()))
+ .WillOnce(Return(true));
+ EXPECT_TRUE(InitRandom(seed.data(), seed.size()));
+
+ uint32_t id = 4658;
+ EXPECT_CALL(*generator, Generate(_, sizeof(uint32_t)))
+ .WillOnce(DoAll(WithArg<0>(Invoke([&id](void* p) {
+ std::memcpy(p, &id, sizeof(uint32_t));
+ })),
+ Return(true)));
+ EXPECT_EQ(CreateRandomId(), id);
+
+ EXPECT_CALL(*generator, Generate)
+ .WillOnce(DoAll(
+ WithArgs<0, 1>([](void* p, size_t len) { std::memset(p, 0, len); }),
+ Return(true)));
+ EXPECT_THAT(CreateRandomUuid(), Not(IsEmpty()));
+
+ // Set the default random generator, and expect that mock generator is
+ // not used beyond this point.
+ EXPECT_CALL(*generator, Die);
+ EXPECT_CALL(*generator, Generate).Times(0);
+ SetDefaultRandomGenerator();
+ EXPECT_THAT(CreateRandomUuid(), Not(IsEmpty()));
+}
+
+} // namespace
} // namespace rtc