APM Transient Suppressor (TS): wire-up RNN VAD, TS and AGC2

When the `WebRTC-Audio-TransientSuppressorVadMode-RnnVad` field trial
is set, APM now uses (i) its RNN VAD sub-module to compute the voice
probability, (ii) that probability for TS and (iii) a temporally
delayed version of it for AGC2 (the delay introduced by TS is taken
into account).

Bug: webrtc:13663
Change-Id: Ic0f245c3f00d318c19bb01d3dbc2d5176c90f851
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/266362
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Hanna Silen <silen@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#37291}
diff --git a/modules/audio_processing/audio_processing_impl.cc b/modules/audio_processing/audio_processing_impl.cc
index 5714d6b..31a6a14 100644
--- a/modules/audio_processing/audio_processing_impl.cc
+++ b/modules/audio_processing/audio_processing_impl.cc
@@ -1274,28 +1274,42 @@
                                       capture_buffer->num_frames()));
     }
 
+    absl::optional<float> voice_probability;
+    if (!!submodules_.voice_activity_detector) {
+      voice_probability = submodules_.voice_activity_detector->Analyze(
+          AudioFrameView<const float>(capture_buffer->channels(),
+                                      capture_buffer->num_channels(),
+                                      capture_buffer->num_frames()));
+    }
+
     if (submodules_.transient_suppressor) {
-      float voice_probability = 1.0f;
+      float transient_suppressor_voice_probability = 1.0f;
       switch (transient_suppressor_vad_mode_) {
         case TransientSuppressor::VadMode::kDefault:
           if (submodules_.agc_manager) {
-            voice_probability = submodules_.agc_manager->voice_probability();
+            transient_suppressor_voice_probability =
+                submodules_.agc_manager->voice_probability();
           }
           break;
         case TransientSuppressor::VadMode::kRnnVad:
-          // TODO(bugs.webrtc.org/13663): Use RNN VAD.
+          RTC_DCHECK(voice_probability.has_value());
+          transient_suppressor_voice_probability = *voice_probability;
           break;
         case TransientSuppressor::VadMode::kNoVad:
           // The transient suppressor will ignore `voice_probability`.
           break;
       }
-      submodules_.transient_suppressor->Suppress(
-          capture_buffer->channels()[0], capture_buffer->num_frames(),
-          capture_buffer->num_channels(),
-          capture_buffer->split_bands_const(0)[kBand0To8kHz],
-          capture_buffer->num_frames_per_band(),
-          /*reference_data=*/nullptr, /*reference_length=*/0, voice_probability,
-          capture_.key_pressed);
+      float delayed_voice_probability =
+          submodules_.transient_suppressor->Suppress(
+              capture_buffer->channels()[0], capture_buffer->num_frames(),
+              capture_buffer->num_channels(),
+              capture_buffer->split_bands_const(0)[kBand0To8kHz],
+              capture_buffer->num_frames_per_band(),
+              /*reference_data=*/nullptr, /*reference_length=*/0,
+              transient_suppressor_voice_probability, capture_.key_pressed);
+      if (voice_probability.has_value()) {
+        *voice_probability = delayed_voice_probability;
+      }
     }
 
     // Experimental APM sub-module that analyzes `capture_buffer`.
@@ -1303,19 +1317,10 @@
       submodules_.capture_analyzer->Analyze(capture_buffer);
     }
 
-    absl::optional<float> voice_activity_probability = absl::nullopt;
     if (submodules_.gain_controller2) {
       submodules_.gain_controller2->NotifyAnalogLevel(
           recommended_stream_analog_level_locked());
-      if (submodules_.voice_activity_detector) {
-        voice_activity_probability =
-            submodules_.voice_activity_detector->Analyze(
-                AudioFrameView<const float>(capture_buffer->channels(),
-                                            capture_buffer->num_channels(),
-                                            capture_buffer->num_frames()));
-      }
-      submodules_.gain_controller2->Process(voice_activity_probability,
-                                            capture_buffer);
+      submodules_.gain_controller2->Process(voice_probability, capture_buffer);
     }
 
     if (submodules_.capture_post_processor) {