/*
* Copyright (c) 2025 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/aec3/neural_residual_echo_estimator_impl.h"
#include <algorithm>
#include <array>
#include <cstdarg>
#include <cstdio>
#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/string_view.h"
#include "api/array_view.h"
#include "api/audio/echo_canceller3_config.h"
#include "modules/audio_processing/aec3/aec3_common.h"
#include "modules/audio_processing/aec3/neural_feature_extractor.h"
#ifdef WEBRTC_ANDROID_PLATFORM_BUILD
#include "external/webrtc/webrtc/modules/audio_processing/aec3/neural_residual_echo_estimator.pb.h"
#else
#include "modules/audio_processing/aec3/neural_residual_echo_estimator.pb.h"
#endif
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "third_party/tflite/src/tensorflow/lite/error_reporter.h"
#include "third_party/tflite/src/tensorflow/lite/interpreter.h"
#include "third_party/tflite/src/tensorflow/lite/kernels/kernel_util.h"
#include "third_party/tflite/src/tensorflow/lite/kernels/register.h"
#include "third_party/tflite/src/tensorflow/lite/model_builder.h"

namespace webrtc {
namespace {
using ModelInputEnum = NeuralResidualEchoEstimatorImpl::ModelInputEnum;
using ModelOutputEnum = NeuralResidualEchoEstimatorImpl::ModelOutputEnum;
// A TFLite ErrorReporter that writes its messages to RTC_LOG.
class LoggingErrorReporter : public tflite::ErrorReporter {
int Report(const char* format, va_list args) override {
char buffer[2048];
const int result = vsnprintf(buffer, sizeof(buffer), format, args);
RTC_LOG(LS_ERROR) << buffer;
return result;
}
};
tflite::ErrorReporter* DefaultLoggingErrorReporter() {
static LoggingErrorReporter* instance = new LoggingErrorReporter();
return instance;
}
// Field under which the ML-REE metadata is stored in a TFLite model.
constexpr char kTfLiteMetadataKey[] = "REE_METADATA";
// Reads the model metadata from the TFLite model. If the metadata is not
// present, returns default metadata with version 1. If the metadata is
// present but cannot be parsed, returns nullopt.
std::optional<audioproc::ReeModelMetadata> ReadModelMetadata(
const tflite::FlatBufferModel* model) {
audioproc::ReeModelMetadata default_metadata;
default_metadata.set_version(1);
const auto metadata_records = model->ReadAllMetadata();
const auto metadata_field = metadata_records.find(kTfLiteMetadataKey);
if (metadata_field == metadata_records.end()) {
return default_metadata;
}
audioproc::ReeModelMetadata metadata;
if (metadata.ParseFromString(metadata_field->second)) {
return metadata;
}
return std::nullopt;
}
// Encapsulates all of NeuralResidualEchoEstimatorImpl's interaction with
// TFLite. This separates the rebuffering and similar AEC3-related
// bookkeeping from the TFLite-specific code, and makes the bookkeeping code
// easier to test by mocking the model runner.
class TfLiteModelRunner : public NeuralResidualEchoEstimatorImpl::ModelRunner {
public:
TfLiteModelRunner(std::string model_data,
std::unique_ptr<tflite::FlatBufferModel> tflite_model,
std::unique_ptr<tflite::Interpreter> tflite_interpreter,
audioproc::ReeModelMetadata metadata)
: model_data_(std::move(model_data)),
input_tensor_size_(static_cast<int>(
tflite::NumElements(tflite_interpreter->input_tensor(
static_cast<int>(ModelInputEnum::kMic))))),
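// For version 1 models, the input tensors hold time-domain samples, so the
// tensor size equals the frame size. For later models, the tensors hold a
// spectrum of frame_size / 2 + 1 bins, so the frame size is
// (tensor size - 1) * 2.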
frame_size_(metadata.version() == 1 ? input_tensor_size_
: (input_tensor_size_ - 1) * 2),
step_size_(frame_size_ / 2),
frame_size_by_2_plus_1_(frame_size_ / 2 + 1),
metadata_(metadata),
model_state_(tflite::NumElements(tflite_interpreter->input_tensor(
static_cast<int>(ModelInputEnum::kModelState))),
0.0f),
tflite_model_(std::move(tflite_model)),
tflite_interpreter_(std::move(tflite_interpreter)) {
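// Clear the signal input tensors so that the first inference starts from
// silence.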
for (const auto input_enum :
{ModelInputEnum::kMic, ModelInputEnum::kLinearAecOutput,
ModelInputEnum::kAecRef}) {
webrtc::ArrayView<float> input_tensor(
tflite_interpreter_->typed_input_tensor<float>(
static_cast<int>(input_enum)),
input_tensor_size_);
std::fill(input_tensor.begin(), input_tensor.end(), 0.0f);
}
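// Sanity-check the tensor geometry: all signal inputs share the same size,
// the state input matches the state output, and the echo mask spans
// frame_size / 2 + 1 bins.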
RTC_CHECK_EQ(frame_size_ % kBlockSize, 0);
RTC_CHECK_EQ(tflite::NumElements(tflite_interpreter_->input_tensor(
static_cast<int>(ModelInputEnum::kLinearAecOutput))),
input_tensor_size_);
RTC_CHECK_EQ(tflite::NumElements(tflite_interpreter_->input_tensor(
static_cast<int>(ModelInputEnum::kAecRef))),
input_tensor_size_);
RTC_CHECK_EQ(tflite::NumElements(tflite_interpreter_->input_tensor(
static_cast<int>(ModelInputEnum::kModelState))),
tflite::NumElements(tflite_interpreter_->output_tensor(
static_cast<int>(ModelOutputEnum::kModelState))));
RTC_CHECK_EQ(tflite::NumElements(tflite_interpreter_->output_tensor(
static_cast<int>(ModelOutputEnum::kEchoMask))),
frame_size_by_2_plus_1_);
}
~TfLiteModelRunner() override = default;
int StepSize() const override { return step_size_; }
webrtc::ArrayView<float> GetInput(ModelInputEnum input_enum) override {
int tensor_size = 0;
switch (input_enum) {
case ModelInputEnum::kMic: // fall-through
case ModelInputEnum::kLinearAecOutput: // fall-through
case ModelInputEnum::kAecRef:
tensor_size = input_tensor_size_;
break;
case ModelInputEnum::kModelState:
tensor_size = static_cast<int>(model_state_.size());
break;
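// kNumInputs is a sentinel holding the number of inputs, not a real tensor.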
case ModelInputEnum::kNumInputs:
RTC_CHECK(false);
}
return webrtc::ArrayView<float>(
tflite_interpreter_->typed_input_tensor<float>(
static_cast<int>(input_enum)),
tensor_size);
}
webrtc::ArrayView<const float> GetOutputEchoMask() override {
return webrtc::ArrayView<const float>(
tflite_interpreter_->typed_output_tensor<const float>(
static_cast<int>(ModelOutputEnum::kEchoMask)),
frame_size_by_2_plus_1_);
}
const audioproc::ReeModelMetadata& GetMetadata() const override {
return metadata_;
}
bool Invoke() override {
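// Feed the recurrent state carried over from the previous invocation into
// the model's state input tensor.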
auto input_state = GetInput(ModelInputEnum::kModelState);
std::copy(model_state_.begin(), model_state_.end(), input_state.begin());
const TfLiteStatus status = tflite_interpreter_->Invoke();
if (processing_error_log_counter_ > 0) {
--processing_error_log_counter_;
}
if (status != kTfLiteOk) {
if (processing_error_log_counter_ <= 0) {
RTC_LOG(LS_ERROR) << "TfLiteModelRunner::Invoke() failed, status="
<< status;
// Wait ~1 second before logging this error again.
processing_error_log_counter_ = 16000 / step_size_;
}
// A failed invocation must not report success, even while the error log
// is rate-limited.
return false;
}
auto output_state = webrtc::ArrayView<const float>(
tflite_interpreter_->typed_output_tensor<const float>(
static_cast<int>(ModelOutputEnum::kModelState)),
model_state_.size());
std::copy(output_state.begin(), output_state.end(), model_state_.begin());
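// Apply a slow exponential decay to the carried-over recurrent state.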
constexpr float kStateDecay = 0.999f;
for (float& state : model_state_) {
state *= kStateDecay;
}
return true;
}
private:
// Model data needs to be declared before `tflite_model_` to ensure that the
// data is destroyed after the tflite model.
const std::string model_data_;
// Size of the input tensors.
const int input_tensor_size_;
// Frame size of the model.
const int frame_size_;
// Step size.
const int step_size_;
// Size of the spectrum mask that is returned by the model.
const int frame_size_by_2_plus_1_;
// Metadata of the model.
const audioproc::ReeModelMetadata metadata_;
// LSTM states that carry over to the next inference invocation.
std::vector<float> model_state_;
// TFLite model for residual echo estimation.
// Must outlive `tflite_interpreter_`.
std::unique_ptr<tflite::FlatBufferModel> tflite_model_;
// Used to run inference with `tflite_model_`.
std::unique_ptr<tflite::Interpreter> tflite_interpreter_;
// Counter to avoid logging processing errors too often.
int processing_error_log_counter_ = 0;
};
} // namespace

std::unique_ptr<NeuralResidualEchoEstimatorImpl::ModelRunner>
NeuralResidualEchoEstimatorImpl::LoadTfLiteModel(
absl::string_view ml_ree_model_path) {
std::string model_data;
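// Stays empty when the model is loaded from a file; TfLiteModelRunner
// retains it so that in-memory model data, when used, outlives the
// FlatBufferModel.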
auto model = tflite::FlatBufferModel::BuildFromFile(
std::string(ml_ree_model_path).c_str(), DefaultLoggingErrorReporter());
if (!model) {
RTC_LOG(LS_ERROR) << "Error loading model from " << ml_ree_model_path;
return nullptr;
}
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::ops::builtin::BuiltinOpResolver resolver;
if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
RTC_LOG(LS_ERROR) << "Error creating interpreter";
return nullptr;
}
if (interpreter->AllocateTensors() != kTfLiteOk) {
RTC_LOG(LS_ERROR) << "Error allocating tensors";
return nullptr;
}
if (interpreter->inputs().size() !=
static_cast<int>(ModelInputEnum::kNumInputs)) {
RTC_LOG(LS_ERROR) << "Model input number mismatch, got "
<< interpreter->inputs().size() << " expected "
<< static_cast<int>(ModelInputEnum::kNumInputs);
return nullptr;
}
if (interpreter->outputs().size() !=
static_cast<int>(ModelOutputEnum::kNumOutputs)) {
RTC_LOG(LS_ERROR) << "Model output number mismatch, got "
<< interpreter->outputs().size() << " expected "
<< static_cast<int>(ModelOutputEnum::kNumOutputs);
return nullptr;
}
auto metadata = ReadModelMetadata(model.get());
if (!metadata.has_value()) {
RTC_LOG(LS_ERROR) << "Error reading model metadata";
return nullptr;
}
if (metadata->version() < 1 || metadata->version() > 2) {
RTC_LOG(LS_ERROR) << "Model version mismatch, got " << metadata->version()
<< " expected 1 or 2.";
return nullptr;
}
return std::make_unique<TfLiteModelRunner>(std::move(model_data),
std::move(model),
std::move(interpreter), *metadata);
}

int NeuralResidualEchoEstimatorImpl::instance_count_ = 0;

NeuralResidualEchoEstimatorImpl::NeuralResidualEchoEstimatorImpl(
std::unique_ptr<ModelRunner> model_runner)
: model_runner_(std::move(model_runner)),
data_dumper_(new ApmDataDumper(++instance_count_)) {
input_mic_buffer_.reserve(model_runner_->StepSize());
input_linear_aec_output_buffer_.reserve(model_runner_->StepSize());
input_aec_ref_buffer_.reserve(model_runner_->StepSize());
output_mask_.fill(0.0f);
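// Version 1 models consume raw time-domain frames; version 2 models consume
// frequency-domain features.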
if (model_runner_->GetMetadata().version() == 1) {
feature_extractor_ = std::make_unique<TimeDomainFeatureExtractor>();
} else {
feature_extractor_ = std::make_unique<FrequencyDomainFeatureExtractor>(
/*step_size=*/model_runner_->StepSize());
}
}

void NeuralResidualEchoEstimatorImpl::Estimate(
webrtc::ArrayView<const float> x,
webrtc::ArrayView<const std::array<float, kBlockSize>> y,
webrtc::ArrayView<const std::array<float, kBlockSize>> e,
webrtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> S2,
webrtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
webrtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
webrtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> R2,
webrtc::ArrayView<std::array<float, kFftLengthBy2Plus1>> R2_unbounded) {
// The input is buffered for model inference; multi-channel data is handled by
// summing the content of all channels.
input_mic_buffer_.insert(input_mic_buffer_.end(), y[0].begin(), y[0].end());
input_linear_aec_output_buffer_.insert(input_linear_aec_output_buffer_.end(),
e[0].begin(), e[0].end());
for (size_t ch = 1; ch < y.size(); ++ch) {
std::transform(y[ch].begin(), y[ch].end(),
input_mic_buffer_.end() - kBlockSize,
input_mic_buffer_.end() - kBlockSize, std::plus<float>());
std::transform(e[ch].begin(), e[ch].end(),
input_linear_aec_output_buffer_.end() - kBlockSize,
input_linear_aec_output_buffer_.end() - kBlockSize,
std::plus<float>());
}
input_aec_ref_buffer_.insert(input_aec_ref_buffer_.end(), x.begin(), x.end());
if (static_cast<int>(input_mic_buffer_.size()) == model_runner_->StepSize()) {
DumpInputs();
feature_extractor_->PushFeaturesToModelInput(
input_mic_buffer_, model_runner_->GetInput(ModelInputEnum::kMic));
feature_extractor_->PushFeaturesToModelInput(
input_linear_aec_output_buffer_,
model_runner_->GetInput(ModelInputEnum::kLinearAecOutput));
feature_extractor_->PushFeaturesToModelInput(
input_aec_ref_buffer_,
model_runner_->GetInput(ModelInputEnum::kAecRef));
if (model_runner_->Invoke()) {
// Downsample output mask to match the AEC3 frequency resolution.
webrtc::ArrayView<const float> output_mask =
model_runner_->GetOutputEchoMask();
const int kDownsampleFactor = (output_mask.size() - 1) / kFftLengthBy2;
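// Bin 0 (DC) maps directly; every other AEC3 bin takes the maximum over the
// kDownsampleFactor model bins it covers.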
output_mask_[0] = output_mask[0];
for (size_t i = 1; i < kFftLengthBy2Plus1; ++i) {
const auto* output_mask_ptr =
&output_mask[kDownsampleFactor * (i - 1) + 1];
output_mask_[i] = *std::max_element(
output_mask_ptr, output_mask_ptr + kDownsampleFactor);
}
// The model is trained to predict the nearend magnitude spectrum but
// exposes one minus that mask. The transformation below converts it into a
// mask that estimates the echo power spectrum, under the assumption that
// the nearend and echo power spectra sum to the power spectrum of the
// microphone signal.
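// With n = 1 - m denoting the nearend magnitude mask, the nearend power
// mask is n^2 and the echo power mask becomes 1 - n^2, i.e.,
// m_echo = 1 - (1 - m)^2.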
for (float& m : output_mask_) {
m = 1.0f - (1.0f - m) * (1.0f - m);
}
data_dumper_->DumpRaw("ml_ree_model_mask", output_mask);
data_dumper_->DumpRaw("ml_ree_output_mask", output_mask_);
}
}
// Use the latest output mask to produce output echo power estimates.
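// The bounded and unbounded estimates are identical since the model
// provides a single echo power estimate.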
for (size_t ch = 0; ch < E2.size(); ++ch) {
std::transform(E2[ch].begin(), E2[ch].end(), output_mask_.begin(),
R2[ch].begin(),
[](float power, float mask) { return power * mask; });
std::copy(R2[ch].begin(), R2[ch].end(), R2_unbounded[ch].begin());
}
}

EchoCanceller3Config NeuralResidualEchoEstimatorImpl::GetConfiguration(
bool multi_channel) const {
EchoCanceller3Config config;
EchoCanceller3Config::Suppressor::MaskingThresholds tuning_masking_thresholds(
/*enr_transparent=*/0.0f, /*enr_suppress=*/1.0f,
/*emr_transparent=*/0.3f);
EchoCanceller3Config::Suppressor::Tuning tuning(
/*mask_lf=*/tuning_masking_thresholds,
/*mask_hf=*/tuning_masking_thresholds, /*max_inc_factor=*/100.0f,
/*max_dec_factor_lf=*/0.0f);
config.filter.enable_coarse_filter_output_usage = false;
config.suppressor.nearend_average_blocks = 1;
config.suppressor.normal_tuning = tuning;
config.suppressor.nearend_tuning = tuning;
config.suppressor.dominant_nearend_detection.enr_threshold = 0.5f;
config.suppressor.dominant_nearend_detection.trigger_threshold = 2;
config.suppressor.high_frequency_suppression.limiting_gain_band = 24;
config.suppressor.high_frequency_suppression.bands_in_limiting_gain = 3;
return config;
}

void NeuralResidualEchoEstimatorImpl::DumpInputs() {
data_dumper_->DumpWav("ml_ree_mic_input", input_mic_buffer_, 16000, 1);
data_dumper_->DumpWav("ml_ree_linear_aec_output",
input_linear_aec_output_buffer_, 16000, 1);
data_dumper_->DumpWav("ml_ree_aec_ref", input_aec_ref_buffer_, 16000, 1);
}
} // namespace webrtc