Compensate for the IntelligibilityEnhancer processing delay in high bands

Before this CL, the IntelligibilityEnhancer introduced a processing delay in the lower band without compensating for it in the higher bands. This CL delays the higher bands by the same amount, so that all bands stay time-aligned.

BUG=b/30780909
R=henrik.lundin@webrtc.org, peah@webrtc.org

Review URL: https://codereview.webrtc.org/2320833002 .

Cr-Commit-Position: refs/heads/master@{#14311}
diff --git a/webrtc/common_audio/blocker.h b/webrtc/common_audio/blocker.h
index 07f9f1a..d941c2a 100644
--- a/webrtc/common_audio/blocker.h
+++ b/webrtc/common_audio/blocker.h
@@ -79,6 +79,8 @@
                     size_t num_output_channels,
                     float* const* output);
 
+  size_t initial_delay() const { return initial_delay_; }
+
  private:
   const size_t chunk_size_;
   const size_t block_size_;
diff --git a/webrtc/common_audio/lapped_transform.h b/webrtc/common_audio/lapped_transform.h
index 0d668d0..42a103a 100644
--- a/webrtc/common_audio/lapped_transform.h
+++ b/webrtc/common_audio/lapped_transform.h
@@ -86,6 +86,12 @@
   // constructor.
   size_t num_out_channels() const { return num_out_channels_; }
 
+  // Returns the initial delay.
+  //
+  // This is the delay introduced by the |blocker_| to be able to get and return
+  // chunks of |chunk_length|, but process blocks of |block_length|.
+  size_t initial_delay() const { return blocker_.initial_delay(); }
+
  private:
   // Internal middleware callback, given to the blocker. Transforms each block
   // and hands it over to the processing method given at construction time.
diff --git a/webrtc/modules/audio_processing/audio_processing_impl.cc b/webrtc/modules/audio_processing/audio_processing_impl.cc
index 300fc02..5e08b85 100644
--- a/webrtc/modules/audio_processing/audio_processing_impl.cc
+++ b/webrtc/modules/audio_processing/audio_processing_impl.cc
@@ -1118,8 +1118,7 @@
 #if WEBRTC_INTELLIGIBILITY_ENHANCER
   if (capture_nonlocked_.intelligibility_enabled) {
     public_submodules_->intelligibility_enhancer->ProcessRenderAudio(
-        render_buffer->split_channels_f(kBand0To8kHz),
-        capture_nonlocked_.split_rate, render_buffer->num_channels());
+        render_buffer);
   }
 #endif
 
@@ -1342,6 +1341,7 @@
     public_submodules_->intelligibility_enhancer.reset(
         new IntelligibilityEnhancer(capture_nonlocked_.split_rate,
                                     render_.render_audio->num_channels(),
+                                    render_.render_audio->num_bands(),
                                     NoiseSuppressionImpl::num_noise_bins()));
   }
 #endif
diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc
index f3d023e..f9d1c3c 100644
--- a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc
+++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc
@@ -68,6 +68,7 @@
 
 IntelligibilityEnhancer::IntelligibilityEnhancer(int sample_rate_hz,
                                                  size_t num_render_channels,
+                                                 size_t num_bands,
                                                  size_t num_noise_bins)
     : freqs_(RealFourier::ComplexLength(
           RealFourier::FftOrder(sample_rate_hz * kWindowSizeMs / 1000))),
@@ -110,14 +111,24 @@
   render_mangler_.reset(new LappedTransform(
       num_render_channels_, num_render_channels_, chunk_length_,
       kbd_window.data(), window_size, window_size / 2, this));
+
+  const size_t initial_delay = render_mangler_->initial_delay();
+  for (size_t i = 0u; i < num_bands - 1; ++i) {
+    high_bands_buffers_.push_back(std::unique_ptr<intelligibility::DelayBuffer>(
+        new intelligibility::DelayBuffer(initial_delay, num_render_channels_)));
+  }
 }
 
 IntelligibilityEnhancer::~IntelligibilityEnhancer() {
-  // Don't rely on this log, since the destructor isn't called when the app/tab
-  // is killed.
-  LOG(LS_INFO) << "Intelligibility Enhancer was active for "
-               << static_cast<float>(num_active_chunks_) / num_chunks_
-               << "% of the call.";
+  // Don't rely on this log, since the destructor isn't called when the
+  // app/tab is killed.
+  if (num_chunks_ > 0) {
+    LOG(LS_INFO) << "Intelligibility Enhancer was active for "
+                 << 100.f * static_cast<float>(num_active_chunks_) / num_chunks_
+                 << "% of the call.";
+  } else {
+    LOG(LS_INFO) << "Intelligibility Enhancer processed no chunk.";
+  }
 }
 
 void IntelligibilityEnhancer::SetCaptureNoiseEstimate(
@@ -132,16 +143,15 @@
   };
 }
 
-void IntelligibilityEnhancer::ProcessRenderAudio(float* const* audio,
-                                                 int sample_rate_hz,
-                                                 size_t num_channels) {
-  RTC_CHECK_EQ(sample_rate_hz_, sample_rate_hz);
-  RTC_CHECK_EQ(num_render_channels_, num_channels);
+void IntelligibilityEnhancer::ProcessRenderAudio(AudioBuffer* audio) {
+  RTC_DCHECK_EQ(num_render_channels_, audio->num_channels());
   while (noise_estimation_queue_.Remove(&noise_estimation_buffer_)) {
     noise_power_estimator_.Step(noise_estimation_buffer_.data());
   }
-  is_speech_ = IsSpeech(audio[0]);
-  render_mangler_->ProcessChunk(audio, audio);
+  float* const* low_band = audio->split_channels_f(kBand0To8kHz);
+  is_speech_ = IsSpeech(low_band[0]);
+  render_mangler_->ProcessChunk(low_band, low_band);
+  DelayHighBands(audio);
 }
 
 void IntelligibilityEnhancer::ProcessAudioBlock(
@@ -369,4 +379,12 @@
   return chunks_since_voice_ < kSpeechOffsetDelay;
 }
 
+void IntelligibilityEnhancer::DelayHighBands(AudioBuffer* audio) {
+  RTC_DCHECK_EQ(audio->num_bands(), high_bands_buffers_.size() + 1u);
+  for (size_t band = 1u; band <= high_bands_buffers_.size(); ++band) {
+    float* const* channels = audio->split_channels_f(static_cast<Band>(band));
+    high_bands_buffers_[band - 1u]->Delay(channels, chunk_length_);
+  }
+}
+
 }  // namespace webrtc
diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h
index 3af1190..f7306cb 100644
--- a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h
+++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h
@@ -16,8 +16,9 @@
 #include <vector>
 
 #include "webrtc/base/swap_queue.h"
-#include "webrtc/common_audio/lapped_transform.h"
 #include "webrtc/common_audio/channel_buffer.h"
+#include "webrtc/common_audio/lapped_transform.h"
+#include "webrtc/modules/audio_processing/audio_buffer.h"
 #include "webrtc/modules/audio_processing/intelligibility/intelligibility_utils.h"
 #include "webrtc/modules/audio_processing/render_queue_item_verifier.h"
 #include "webrtc/modules/audio_processing/vad/voice_activity_detector.h"
@@ -33,6 +34,7 @@
  public:
   IntelligibilityEnhancer(int sample_rate_hz,
                           size_t num_render_channels,
+                          size_t num_bands,
                           size_t num_noise_bins);
 
   ~IntelligibilityEnhancer() override;
@@ -41,9 +43,7 @@
   void SetCaptureNoiseEstimate(std::vector<float> noise, float gain);
 
   // Reads chunk of speech in time domain and updates with modified signal.
-  void ProcessRenderAudio(float* const* audio,
-                          int sample_rate_hz,
-                          size_t num_channels);
+  void ProcessRenderAudio(AudioBuffer* audio);
   bool active() const;
 
  protected:
@@ -56,10 +56,13 @@
                          std::complex<float>* const* out_block) override;
 
  private:
+  FRIEND_TEST_ALL_PREFIXES(IntelligibilityEnhancerTest, TestRenderUpdate);
   FRIEND_TEST_ALL_PREFIXES(IntelligibilityEnhancerTest, TestErbCreation);
   FRIEND_TEST_ALL_PREFIXES(IntelligibilityEnhancerTest, TestSolveForGains);
   FRIEND_TEST_ALL_PREFIXES(IntelligibilityEnhancerTest,
                            TestNoiseGainHasExpectedResult);
+  FRIEND_TEST_ALL_PREFIXES(IntelligibilityEnhancerTest,
+                           TestAllBandsHaveSameDelay);
 
   // Updates the SNR estimation and enables or disables this component using a
   // hysteresis.
@@ -84,6 +87,10 @@
   // Returns true if the audio is speech.
   bool IsSpeech(const float* audio);
 
+  // Delays the high bands to compensate for the processing delay in the low
+  // band.
+  void DelayHighBands(AudioBuffer* audio);
+
   static const size_t kMaxNumNoiseEstimatesToBuffer = 5;
 
   const size_t freqs_;         // Num frequencies in frequency domain.
@@ -120,6 +127,9 @@
   std::vector<float> noise_estimation_buffer_;
   SwapQueue<std::vector<float>, RenderQueueItemVerifier<float>>
       noise_estimation_queue_;
+
+  std::vector<std::unique_ptr<intelligibility::DelayBuffer>>
+      high_bands_buffers_;
 };
 
 }  // namespace webrtc
diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc
index 45f338c..0c56870 100644
--- a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc
+++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc
@@ -202,11 +202,12 @@
 const float kMaxTestError = 0.005f;
 
 // Enhancer initialization parameters.
-const int kSamples = 1000;
+const int kSamples = 10000;
 const int kSampleRate = 4000;
 const int kNumChannels = 1;
 const int kFragmentSize = kSampleRate / 100;
 const size_t kNumNoiseBins = 129;
+const size_t kNumBands = 1;
 
 // Number of frames to process in the bitexactness tests.
 const size_t kNumFramesToProcess = 1000;
@@ -228,10 +229,7 @@
     capture_audio_buffer->SplitIntoFrequencyBands();
   }
 
-  intelligibility_enhancer->ProcessRenderAudio(
-      render_audio_buffer->split_channels_f(kBand0To8kHz),
-      IntelligibilityEnhancerSampleRate(sample_rate_hz),
-      render_audio_buffer->num_channels());
+  intelligibility_enhancer->ProcessRenderAudio(render_audio_buffer);
 
   noise_suppressor->AnalyzeCaptureAudio(capture_audio_buffer);
   noise_suppressor->ProcessCaptureAudio(capture_audio_buffer);
@@ -276,7 +274,8 @@
 
   IntelligibilityEnhancer intelligibility_enhancer(
       IntelligibilityEnhancerSampleRate(sample_rate_hz),
-      render_config.num_channels(), NoiseSuppressionImpl::num_noise_bins());
+      render_config.num_channels(), kNumBands,
+      NoiseSuppressionImpl::num_noise_bins());
 
   for (size_t frame_no = 0u; frame_no < kNumFramesToProcess; ++frame_no) {
     ReadFloatSamplesFromStereoFile(render_buffer.num_frames(),
@@ -320,24 +319,34 @@
 class IntelligibilityEnhancerTest : public ::testing::Test {
  protected:
   IntelligibilityEnhancerTest()
-      : clear_data_(kSamples), noise_data_(kSamples), orig_data_(kSamples) {
+      : clear_buffer_(kFragmentSize,
+                      kNumChannels,
+                      kFragmentSize,
+                      kNumChannels,
+                      kFragmentSize),
+        stream_config_(kSampleRate, kNumChannels),
+        clear_data_(kSamples),
+        noise_data_(kNumNoiseBins),
+        orig_data_(kSamples) {
     std::srand(1);
-    enh_.reset(
-        new IntelligibilityEnhancer(kSampleRate, kNumChannels, kNumNoiseBins));
+    enh_.reset(new IntelligibilityEnhancer(kSampleRate, kNumChannels, kNumBands,
+                                           kNumNoiseBins));
   }
 
   bool CheckUpdate() {
-    enh_.reset(
-        new IntelligibilityEnhancer(kSampleRate, kNumChannels, kNumNoiseBins));
+    enh_.reset(new IntelligibilityEnhancer(kSampleRate, kNumChannels, kNumBands,
+                                           kNumNoiseBins));
     float* clear_cursor = clear_data_.data();
-    float* noise_cursor = noise_data_.data();
     for (int i = 0; i < kSamples; i += kFragmentSize) {
-      enh_->ProcessRenderAudio(&clear_cursor, kSampleRate, kNumChannels);
+      enh_->SetCaptureNoiseEstimate(noise_data_, 1);
+      clear_buffer_.CopyFrom(&clear_cursor, stream_config_);
+      enh_->ProcessRenderAudio(&clear_buffer_);
+      clear_buffer_.CopyTo(stream_config_, &clear_cursor);
       clear_cursor += kFragmentSize;
-      noise_cursor += kFragmentSize;
     }
-    for (int i = 0; i < kSamples; i++) {
-      if (std::fabs(clear_data_[i] - orig_data_[i]) > kMaxTestError) {
+    for (int i = initial_delay_; i < kSamples; i++) {
+      if (std::fabs(clear_data_[i] - orig_data_[i - initial_delay_]) >
+          kMaxTestError) {
         return true;
       }
     }
@@ -345,22 +354,30 @@
   }
 
   std::unique_ptr<IntelligibilityEnhancer> enh_;
+  // Render clean speech buffer.
+  AudioBuffer clear_buffer_;
+  StreamConfig stream_config_;
   std::vector<float> clear_data_;
   std::vector<float> noise_data_;
   std::vector<float> orig_data_;
+  size_t initial_delay_;
 };
 
 // For each class of generated data, tests that render stream is updated when
 // it should be.
 TEST_F(IntelligibilityEnhancerTest, TestRenderUpdate) {
+  initial_delay_ = enh_->render_mangler_->initial_delay();
   std::fill(noise_data_.begin(), noise_data_.end(), 0.f);
   std::fill(orig_data_.begin(), orig_data_.end(), 0.f);
   std::fill(clear_data_.begin(), clear_data_.end(), 0.f);
   EXPECT_FALSE(CheckUpdate());
-  std::generate(noise_data_.begin(), noise_data_.end(), float_rand);
+  std::generate(clear_data_.begin(), clear_data_.end(), float_rand);
+  orig_data_ = clear_data_;
   EXPECT_FALSE(CheckUpdate());
   std::generate(clear_data_.begin(), clear_data_.end(), float_rand);
   orig_data_ = clear_data_;
+  std::generate(noise_data_.begin(), noise_data_.end(), float_rand);
+  FloatToFloatS16(noise_data_.data(), noise_data_.size(), noise_data_.data());
   EXPECT_TRUE(CheckUpdate());
 }
 
@@ -418,7 +435,8 @@
   float* clear_cursor = clear_data_.data();
   for (size_t i = 0; i < kNumFramesToProcess; ++i) {
     enh_->SetCaptureNoiseEstimate(noise, kGain);
-    enh_->ProcessRenderAudio(&clear_cursor, kSampleRate, kNumChannels);
+    clear_buffer_.CopyFrom(&clear_cursor, stream_config_);
+    enh_->ProcessRenderAudio(&clear_buffer_);
   }
   const std::vector<float>& estimated_psd =
       enh_->noise_power_estimator_.power();
@@ -428,6 +446,41 @@
   }
 }
 
+TEST_F(IntelligibilityEnhancerTest, TestAllBandsHaveSameDelay) {
+  const int kTestSampleRate = AudioProcessing::kSampleRate32kHz;
+  const int kTestSplitRate = AudioProcessing::kSampleRate16kHz;
+  const size_t kTestNumBands =
+      rtc::CheckedDivExact(kTestSampleRate, kTestSplitRate);
+  const size_t kTestFragmentSize = rtc::CheckedDivExact(kTestSampleRate, 100);
+  const size_t kTestSplitFragmentSize =
+      rtc::CheckedDivExact(kTestSplitRate, 100);
+  enh_.reset(new IntelligibilityEnhancer(kTestSplitRate, kNumChannels,
+                                         kTestNumBands, kNumNoiseBins));
+  const size_t initial_delay = enh_->render_mangler_->initial_delay();
+  std::vector<float> rand_gen_buf(kTestFragmentSize);
+  AudioBuffer original_buffer(kTestFragmentSize, kNumChannels,
+                              kTestFragmentSize, kNumChannels,
+                              kTestFragmentSize);
+  AudioBuffer audio_buffer(kTestFragmentSize, kNumChannels, kTestFragmentSize,
+                           kNumChannels, kTestFragmentSize);
+  // Fill both buffers with identical random data.
+  std::generate(rand_gen_buf.begin(), rand_gen_buf.end(), float_rand);
+  original_buffer.split_data_f()->SetDataForTesting(rand_gen_buf.data(),
+                                                    rand_gen_buf.size());
+  audio_buffer.split_data_f()->SetDataForTesting(rand_gen_buf.data(),
+                                                 rand_gen_buf.size());
+  enh_->ProcessRenderAudio(&audio_buffer);
+  // NOTE(review): vacuous unless initial_delay < kTestSplitFragmentSize.
+  for (size_t i = 0u; i < kTestNumBands; ++i) {
+    const float* original_ptr = original_buffer.split_bands_const_f(0)[i];
+    const float* audio_ptr = audio_buffer.split_bands_const_f(0)[i];
+    for (size_t j = initial_delay; j < kTestSplitFragmentSize; ++j) {
+      EXPECT_LT(std::fabs(original_ptr[j - initial_delay] - audio_ptr[j]),
+                kMaxTestError);
+    }
+  }
+}
+
 TEST(IntelligibilityEnhancerBitExactnessTest, DISABLED_Mono8kHz) {
   const float kOutputReference[] = {-0.001892f, -0.003296f, -0.001953f};
 
diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_utils.cc b/webrtc/modules/audio_processing/intelligibility/intelligibility_utils.cc
index 6e641a2..fa8d170 100644
--- a/webrtc/modules/audio_processing/intelligibility/intelligibility_utils.cc
+++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_utils.cc
@@ -66,6 +66,27 @@
   }
 }
 
+DelayBuffer::DelayBuffer(size_t delay, size_t num_channels)
+    : buffer_(num_channels, std::vector<float>(delay, 0.f)), read_index_(0u) {}
+
+DelayBuffer::~DelayBuffer() {}
+
+// Delays |data| in place by |delay| samples; a zero delay is a no-op.
+void DelayBuffer::Delay(float* const* data, size_t length) {
+  size_t sample_index = read_index_;
+  for (size_t i = 0u; i < buffer_.size() && !buffer_[i].empty(); ++i) {
+    sample_index = read_index_;  // All channels share one start position.
+    for (size_t j = 0u; j < length; ++j) {
+      const float swap = data[i][j];
+      data[i][j] = buffer_[i][sample_index];
+      buffer_[i][sample_index] = swap;
+      // Wrap at the delay length (inner size), not at the channel count.
+      sample_index = (sample_index + 1u) % buffer_[i].size();
+    }
+  }
+  read_index_ = sample_index;
+}
+
 }  // namespace intelligibility
 
 }  // namespace webrtc
diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_utils.h b/webrtc/modules/audio_processing/intelligibility/intelligibility_utils.h
index b5cc075..2566616 100644
--- a/webrtc/modules/audio_processing/intelligibility/intelligibility_utils.h
+++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_utils.h
@@ -65,6 +65,20 @@
   std::vector<float> current_;
 };
 
+// Delays a multi-channel signal by a fixed integer number of samples.
+class DelayBuffer {
+ public:
+  DelayBuffer(size_t delay, size_t num_channels);  // |delay| is in samples.
+
+  ~DelayBuffer();
+
+  void Delay(float* const* data, size_t length);  // Delays |data| in place.
+
+ private:
+  std::vector<std::vector<float>> buffer_;  // |delay| samples per channel.
+  size_t read_index_;  // Where the next Delay() call resumes in |buffer_|.
+};
+
 }  // namespace intelligibility
 
 }  // namespace webrtc
diff --git a/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc b/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc
index abd10d8..c928124 100644
--- a/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc
+++ b/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc
@@ -39,7 +39,7 @@
                      in_file.num_channels());
   rtc::CriticalSection crit;
   NoiseSuppressionImpl ns(&crit);
-  IntelligibilityEnhancer enh(in_file.sample_rate(), in_file.num_channels(),
+  IntelligibilityEnhancer enh(in_file.sample_rate(), in_file.num_channels(), 1u,
                               NoiseSuppressionImpl::num_noise_bins());
   ns.Initialize(noise_file.num_channels(), noise_file.sample_rate());
   ns.Enable(true);
@@ -52,23 +52,29 @@
   AudioBuffer capture_audio(noise_samples, noise_file.num_channels(),
                             noise_samples, noise_file.num_channels(),
                             noise_samples);
-  StreamConfig stream_config(noise_file.sample_rate(),
-                             noise_file.num_channels());
+  AudioBuffer render_audio(in_samples, in_file.num_channels(), in_samples,
+                           in_file.num_channels(), in_samples);
+  StreamConfig noise_config(noise_file.sample_rate(),
+                            noise_file.num_channels());
+  StreamConfig in_config(in_file.sample_rate(), in_file.num_channels());
   while (in_file.ReadSamples(in.size(), in.data()) == in.size() &&
          noise_file.ReadSamples(noise.size(), noise.data()) == noise.size()) {
     FloatS16ToFloat(noise.data(), noise.size(), noise.data());
+    FloatS16ToFloat(in.data(), in.size(), in.data());
     Deinterleave(in.data(), in_buf.num_frames(), in_buf.num_channels(),
                  in_buf.channels());
     Deinterleave(noise.data(), noise_buf.num_frames(), noise_buf.num_channels(),
                  noise_buf.channels());
-    capture_audio.CopyFrom(noise_buf.channels(), stream_config);
+    capture_audio.CopyFrom(noise_buf.channels(), noise_config);
+    render_audio.CopyFrom(in_buf.channels(), in_config);
     ns.AnalyzeCaptureAudio(&capture_audio);
     ns.ProcessCaptureAudio(&capture_audio);
-    enh.SetCaptureNoiseEstimate(ns.NoiseEstimate(), 0);
-    enh.ProcessRenderAudio(in_buf.channels(), in_file.sample_rate(),
-                           in_file.num_channels());
+    enh.SetCaptureNoiseEstimate(ns.NoiseEstimate(), 1);
+    enh.ProcessRenderAudio(&render_audio);
+    render_audio.CopyTo(in_config, in_buf.channels());
     Interleave(in_buf.channels(), in_buf.num_frames(), in_buf.num_channels(),
                in.data());
+    FloatToFloatS16(in.data(), in.size(), in.data());
     out_file.WriteSamples(in.data(), in.size());
   }
 }