blob: 293881a06f1d55bc3292f5d94001e2de83f312c2 [file] [log] [blame]
/*
* Copyright 2019 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "pc/sctp_transport.h"
#include <cstddef>
#include <functional>
#include <memory>
#include <optional>
#include <utility>
#include <vector>
#include "absl/memory/memory.h"
#include "api/dtls_transport_interface.h"
#include "api/make_ref_counted.h"
#include "api/priority.h"
#include "api/rtc_error.h"
#include "api/scoped_refptr.h"
#include "api/sctp_transport_interface.h"
#include "api/transport/data_channel_transport_interface.h"
#include "media/sctp/sctp_transport_internal.h"
#include "p2p/base/p2p_constants.h"
#include "p2p/base/packet_transport_internal.h"
#include "p2p/dtls/dtls_transport_internal.h"
#include "p2p/dtls/fake_dtls_transport.h"
#include "pc/dtls_transport.h"
#include "rtc_base/copy_on_write_buffer.h"
#include "rtc_base/gunit.h"
#include "rtc_base/thread.h"
#include "test/gmock.h"
#include "test/gtest.h"
// Upper bound, in milliseconds, passed to the *_WAIT assertions below.
constexpr int kDefaultTimeout{1000};
// Stream count the fake SCTP transport reports as its outbound limit; the
// inbound limit is set one higher so the minimum is this value.
constexpr int kTestMaxSctpStreams{1234};
using cricket::FakeDtlsTransport;
using ::testing::ElementsAre;
namespace webrtc {
namespace {
class FakeCricketSctpTransport : public cricket::SctpTransportInternal {
public:
void SetOnConnectedCallback(std::function<void()> callback) override {
on_connected_callback_ = std::move(callback);
}
void SetDataChannelSink(DataChannelSink* sink) override {}
void SetDtlsTransport(cricket::DtlsTransportInternal* transport) override {}
bool Start(int local_port, int remote_port, int max_message_size) override {
return true;
}
bool OpenStream(int sid, PriorityValue priority) override { return true; }
bool ResetStream(int sid) override { return true; }
RTCError SendData(int sid,
const SendDataParams& params,
const rtc::CopyOnWriteBuffer& payload) override {
return RTCError::OK();
}
bool ReadyToSendData() override { return true; }
void set_debug_name_for_testing(const char* debug_name) override {}
int max_message_size() const override { return 0; }
std::optional<int> max_outbound_streams() const override {
return max_outbound_streams_;
}
std::optional<int> max_inbound_streams() const override {
return max_inbound_streams_;
}
size_t buffered_amount(int sid) const override { return 0; }
size_t buffered_amount_low_threshold(int sid) const override { return 0; }
void SetBufferedAmountLowThreshold(int sid, size_t bytes) override {}
void SendSignalAssociationChangeCommunicationUp() {
ASSERT_TRUE(on_connected_callback_);
on_connected_callback_();
}
void set_max_outbound_streams(int streams) {
max_outbound_streams_ = streams;
}
void set_max_inbound_streams(int streams) { max_inbound_streams_ = streams; }
private:
std::optional<int> max_outbound_streams_;
std::optional<int> max_inbound_streams_;
std::function<void()> on_connected_callback_;
};
} // namespace
// Observer that records every SctpTransportInformation delivered through
// OnStateChange(), so tests can inspect both the latest snapshot and the
// full sequence of states that were signalled.
class TestSctpTransportObserver : public SctpTransportObserverInterface {
 public:
  TestSctpTransportObserver() : info_(SctpTransportState::kNew) {}

  void OnStateChange(SctpTransportInformation info) override {
    // Record the state first, then take ownership of the snapshot.
    states_.push_back(info.state());
    info_ = std::move(info);
  }

  // Returns the most recently signalled state, or kNew if no state change
  // has been observed yet.
  SctpTransportState State() const {
    return states_.empty() ? SctpTransportState::kNew : states_.back();
  }

  // All states signalled so far, in the order they were received.
  const std::vector<SctpTransportState>& States() const { return states_; }

  // The last information object received. Before any state change it is
  // the object constructed from kNew in the constructor.
  SctpTransportInformation LastReceivedInformation() const { return info_; }

 private:
  std::vector<SctpTransportState> states_;
  SctpTransportInformation info_;
};
class SctpTransportTest : public ::testing::Test {
public:
SctpTransport* transport() { return transport_.get(); }
SctpTransportObserverInterface* observer() { return &observer_; }
void CreateTransport() {
std::unique_ptr<cricket::DtlsTransportInternal> cricket_transport =
std::make_unique<FakeDtlsTransport>(
"audio", cricket::ICE_CANDIDATE_COMPONENT_RTP);
dtls_transport_ =
rtc::make_ref_counted<DtlsTransport>(std::move(cricket_transport));
auto cricket_sctp_transport =
absl::WrapUnique(new FakeCricketSctpTransport());
transport_ = rtc::make_ref_counted<SctpTransport>(
std::move(cricket_sctp_transport), dtls_transport_);
}
void CompleteSctpHandshake() {
// The computed MaxChannels shall be the minimum of the outgoing
// and incoming # of streams.
CricketSctpTransport()->set_max_outbound_streams(kTestMaxSctpStreams);
CricketSctpTransport()->set_max_inbound_streams(kTestMaxSctpStreams + 1);
CricketSctpTransport()->SendSignalAssociationChangeCommunicationUp();
}
FakeCricketSctpTransport* CricketSctpTransport() {
return static_cast<FakeCricketSctpTransport*>(transport_->internal());
}
rtc::AutoThread main_thread_;
rtc::scoped_refptr<SctpTransport> transport_;
rtc::scoped_refptr<DtlsTransport> dtls_transport_;
TestSctpTransportObserver observer_;
};
// An SctpTransport starts out in kConnecting with an internal transport.
// Clear() releases that internal transport and moves the state to kClosed.
TEST(SctpTransportSimpleTest, CreateClearDelete) {
  rtc::AutoThread main_thread;
  std::unique_ptr<cricket::DtlsTransportInternal> cricket_transport =
      std::make_unique<FakeDtlsTransport>("audio",
                                          cricket::ICE_CANDIDATE_COMPONENT_RTP);
  rtc::scoped_refptr<DtlsTransport> dtls_transport =
      rtc::make_ref_counted<DtlsTransport>(std::move(cricket_transport));
  std::unique_ptr<cricket::SctpTransportInternal> fake_cricket_sctp_transport =
      std::make_unique<FakeCricketSctpTransport>();
  rtc::scoped_refptr<SctpTransport> sctp_transport =
      rtc::make_ref_counted<SctpTransport>(
          std::move(fake_cricket_sctp_transport), dtls_transport);
  ASSERT_TRUE(sctp_transport->internal());
  ASSERT_EQ(SctpTransportState::kConnecting,
            sctp_transport->Information().state());
  sctp_transport->Clear();
  ASSERT_FALSE(sctp_transport->internal());
  ASSERT_EQ(SctpTransportState::kClosed, sctp_transport->Information().state());
}
// With the observer registered before the handshake, completing it must
// produce exactly one signalled state: kConnected.
TEST_F(SctpTransportTest, EventsObservedWhenConnecting) {
  CreateTransport();
  SctpTransport* const sctp = transport();
  sctp->RegisterObserver(&observer_);

  CompleteSctpHandshake();

  ASSERT_EQ_WAIT(SctpTransportState::kConnected, observer_.State(),
                 kDefaultTimeout);
  EXPECT_THAT(observer_.States(),
              ::testing::ElementsAre(SctpTransportState::kConnected));
}
// Calling Clear() on a connected transport must be reported to the observer
// as kClosed.
TEST_F(SctpTransportTest, CloseWhenClearing) {
  CreateTransport();
  SctpTransport* const sctp = transport();
  sctp->RegisterObserver(&observer_);

  // Bring the association up first.
  CompleteSctpHandshake();
  ASSERT_EQ_WAIT(SctpTransportState::kConnected, observer_.State(),
                 kDefaultTimeout);

  // Then tear it down and wait for the close to be signalled.
  sctp->Clear();
  ASSERT_EQ_WAIT(SctpTransportState::kClosed, observer_.State(),
                 kDefaultTimeout);
}
// MaxChannels is unset before the handshake. Afterwards, both the transport
// and the observer's last snapshot report kTestMaxSctpStreams.
TEST_F(SctpTransportTest, MaxChannelsSignalled) {
  CreateTransport();
  transport()->RegisterObserver(observer());
  EXPECT_FALSE(transport()->Information().MaxChannels());
  EXPECT_FALSE(observer_.LastReceivedInformation().MaxChannels());
  CompleteSctpHandshake();
  ASSERT_EQ_WAIT(SctpTransportState::kConnected, observer_.State(),
                 kDefaultTimeout);
  // Use ASSERT (fatal) rather than EXPECT for the presence checks. The next
  // line dereferences the optional, which is undefined behavior if it is
  // empty, so the test must stop here on failure.
  ASSERT_TRUE(transport()->Information().MaxChannels());
  EXPECT_EQ(kTestMaxSctpStreams, *(transport()->Information().MaxChannels()));
  ASSERT_TRUE(observer_.LastReceivedInformation().MaxChannels());
  EXPECT_EQ(kTestMaxSctpStreams,
            *(observer_.LastReceivedInformation().MaxChannels()));
}
// Moving the underlying DTLS transport to kClosed must be reported to the
// SCTP transport's observer as kClosed.
TEST_F(SctpTransportTest, CloseWhenTransportCloses) {
  CreateTransport();
  transport()->RegisterObserver(&observer_);
  CompleteSctpHandshake();
  ASSERT_EQ_WAIT(SctpTransportState::kConnected, observer_.State(),
                 kDefaultTimeout);

  // CreateTransport() wraps a FakeDtlsTransport, so this downcast is valid.
  auto* fake_dtls =
      static_cast<FakeDtlsTransport*>(dtls_transport_->internal());
  fake_dtls->SetDtlsState(DtlsTransportState::kClosed);

  ASSERT_EQ_WAIT(SctpTransportState::kClosed, observer_.State(),
                 kDefaultTimeout);
}
} // namespace webrtc