Total Harmonic Distorsion plus noise (THD+n) score in APM-QA.

In order to compute a THD score, a pure tone must be used as input signal.
Also, its frequency must be known. For this reason, this CL adds a number of
changes in the APM-QA pipeline. More in detail, input signal metadata is loaded
and passed to the THD evaluation score instance. This makes the eval_scores
module less reusable, but it is fine since the module has been specifically
designed for the APM-QA module.

BUG=webrtc:7494

Review-Url: https://codereview.webrtc.org/3010413002
Cr-Commit-Position: refs/heads/master@{#19970}
diff --git a/modules/audio_processing/test/py_quality_assessment/README.md b/modules/audio_processing/test/py_quality_assessment/README.md
index e19a780..79e1650 100644
--- a/modules/audio_processing/test/py_quality_assessment/README.md
+++ b/modules/audio_processing/test/py_quality_assessment/README.md
@@ -33,6 +33,12 @@
  - Go to `out/Default/py_quality_assessment` and check that
    `apm_quality_assessment.py` exists
 
+## Unit tests
+
+ - Compile WebRTC
+ - Go to `out/Default/py_quality_assessment`
+ - Run `python -m unittest -p "*_unittest.py" discover`
+
 ## First time setup
 
  - Deploy PolqaOem64 and set the `POLQA_PATH` environment variable
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/data_access.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/data_access.py
index 826a089..c488859 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/data_access.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/data_access.py
@@ -31,9 +31,34 @@
   def __init__(self):
     pass
 
+  _GENERIC_METADATA_SUFFIX = '.mdata'
   _AUDIO_TEST_DATA_FILENAME = 'audio_test_data.json'
 
   @classmethod
+  def LoadFileMetadata(cls, filepath):
+    """Loads generic metadata linked to a file.
+
+    Args:
+      filepath: path to the metadata file to read.
+
+    Returns:
+      A dict.
+    """
+    with open(filepath + cls._GENERIC_METADATA_SUFFIX) as f:
+      return json.load(f)
+
+  @classmethod
+  def SaveFileMetadata(cls, filepath, metadata):
+    """Saves generic metadata linked to a file.
+
+    Args:
+      filepath: path to the metadata file to write.
+      metadata: a dict.
+    """
+    with open(filepath + cls._GENERIC_METADATA_SUFFIX, 'w') as f:
+      json.dump(metadata, f)
+
+  @classmethod
   def LoadAudioTestDataPaths(cls, metadata_path):
     """Loads the input and the reference audio track paths.
 
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores.py
index 78d0c18..420afd2 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores.py
@@ -14,6 +14,13 @@
 import os
 import re
 import subprocess
+import sys
+
+try:
+  import numpy as np
+except ImportError:
+  logging.critical('Cannot import the third-party Python package numpy')
+  sys.exit(1)
 
 from . import data_access
 from . import exceptions
@@ -27,6 +34,7 @@
 
   def __init__(self, score_filename_prefix):
     self._score_filename_prefix = score_filename_prefix
+    self._input_signal_metadata = None
     self._reference_signal = None
     self._reference_signal_filepath = None
     self._tested_signal = None
@@ -56,8 +64,16 @@
   def score(self):
     return self._score
 
+  def SetInputSignalMetadata(self, metadata):
+    """Sets input signal metadata.
+
+    Args:
+      metadata: dict instance.
+    """
+    self._input_signal_metadata = metadata
+
   def SetReferenceSignalFilepath(self, filepath):
-    """ Sets the path to the audio track used as reference signal.
+    """Sets the path to the audio track used as reference signal.
 
     Args:
       filepath: path to the reference audio track.
@@ -65,7 +81,7 @@
     self._reference_signal_filepath = filepath
 
   def SetTestedSignalFilepath(self, filepath):
-    """ Sets the path to the audio track used as test signal.
+    """Sets the path to the audio track used as test signal.
 
     Args:
       filepath: path to the test audio track.
@@ -242,3 +258,84 @@
     # Build and return a dictionary with field names (header) as keys and the
     # corresponding field values as values.
     return {data[0][index]: data[1][index] for index in range(number_of_fields)}
+
+
+@EvaluationScore.RegisterClass
+class TotalHarmonicDistorsionScore(EvaluationScore):
+  """Total harmonic distorsion plus noise score.
+
+  Total harmonic distorsion plus noise score.
+  See "https://en.wikipedia.org/wiki/Total_harmonic_distortion#THD.2BN".
+
+  Unit: -.
+  Ideal: 0.
+  Worst case: +inf
+  """
+
+  NAME = 'thd'
+
+  def __init__(self, score_filename_prefix):
+    EvaluationScore.__init__(self, score_filename_prefix)
+    self._input_frequency = None
+
+  def _Run(self, output_path):
+    # TODO(aleloi): Integrate changes made locally.
+    self._CheckInputSignal()
+
+    self._LoadTestedSignal()
+    if self._tested_signal.channels != 1:
+      raise exceptions.EvaluationScoreException(
+          'unsupported number of channels')
+    samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
+        self._tested_signal)
+
+    # Init.
+    num_samples = len(samples)
+    duration = len(self._tested_signal) / 1000.0
+    scaling = 2.0 / num_samples
+    max_freq = self._tested_signal.frame_rate / 2
+    f0_freq = float(self._input_frequency)
+    t = np.linspace(0, duration, num_samples)
+
+    # Analyze harmonics.
+    b_terms = []
+    n = 1
+    while f0_freq * n < max_freq:
+      x_n = np.sum(samples * np.sin(2.0 * np.pi * n * f0_freq * t)) * scaling
+      y_n = np.sum(samples * np.cos(2.0 * np.pi * n * f0_freq * t)) * scaling
+      b_terms.append(np.sqrt(x_n**2 + y_n**2))
+      n += 1
+
+    output_without_fundamental = samples - b_terms[0] * np.sin(
+        2.0 * np.pi * f0_freq * t)
+    distortion_and_noise = np.sqrt(np.sum(
+        output_without_fundamental**2) * np.pi * scaling)
+
+    # TODO(alessiob): Fix or remove if not needed.
+    # thd = np.sqrt(np.sum(b_terms[1:]**2)) / b_terms[0]
+
+    # TODO(alessiob): Check the range of |thd_plus_noise| and update the class
+    # docstring above if accordingly.
+    thd_plus_noise = distortion_and_noise / b_terms[0]
+
+    self._score = thd_plus_noise
+    self._SaveScore()
+
+  def _CheckInputSignal(self):
+    # Check input signal and get properties.
+    try:
+      if self._input_signal_metadata['signal'] != 'pure_tone':
+        raise exceptions.EvaluationScoreException(
+            'The THD score requires a pure tone as input signal')
+      self._input_frequency = self._input_signal_metadata['frequency']
+      if self._input_signal_metadata['test_data_gen_name'] != 'identity' or (
+          self._input_signal_metadata['test_data_gen_config'] != 'default'):
+        raise exceptions.EvaluationScoreException(
+            'The THD score cannot be used with any test data generator other '
+            'than "identity"')
+    except TypeError:
+      raise exceptions.EvaluationScoreException(
+          'The THD score requires an input signal with associated metadata')
+    except KeyError:
+      raise exceptions.EvaluationScoreException(
+          'Invalid input signal metadata to compute the THD score')
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py
index b3bd4f9..ce51051 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py
@@ -52,6 +52,9 @@
     shutil.rmtree(self._output_path)
 
   def testRegisteredClasses(self):
+    # Evaluation score names to exclude (tested separately).
+    exceptions = ['thd']
+
     # Preliminary check.
     self.assertTrue(os.path.exists(self._output_path))
 
@@ -69,11 +72,14 @@
 
     # Try each registered evaluation score worker.
     for eval_score_name in registered_classes:
+      if eval_score_name in exceptions:
+        continue
+
       # Instance evaluation score worker.
       eval_score_worker = eval_score_workers_factory.GetInstance(
           registered_classes[eval_score_name])
 
-      # Set reference and test, then run.
+      # Set fake input metadata and reference and test file paths, then run.
       eval_score_worker.SetReferenceSignalFilepath(
           self._fake_reference_signal_filepath)
       eval_score_worker.SetTestedSignalFilepath(
@@ -83,3 +89,43 @@
       # Check output.
       score = data_access.ScoreFile.Load(eval_score_worker.output_filepath)
       self.assertTrue(isinstance(score, float))
+
+  def testTotalHarmonicDistorsionScore(self):
+    # Init.
+    pure_tone_freq = 5000.0
+    eval_score_worker = eval_scores.TotalHarmonicDistorsionScore('scores-')
+    eval_score_worker.SetInputSignalMetadata({
+        'signal': 'pure_tone',
+        'frequency': pure_tone_freq,
+        'test_data_gen_name': 'identity',
+        'test_data_gen_config': 'default',
+    })
+    template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
+
+    # Create 3 test signals: pure tone, pure tone + white noise, white noise
+    # only.
+    pure_tone = signal_processing.SignalProcessingUtils.GeneratePureTone(
+        template, pure_tone_freq)
+    white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
+        template)
+    noisy_tone = signal_processing.SignalProcessingUtils.MixSignals(
+        pure_tone, white_noise)
+
+    # Compute scores for increasingly distorted pure tone signals.
+    scores = [None, None, None]
+    for index, tested_signal in enumerate([pure_tone, noisy_tone, white_noise]):
+      # Save signal.
+      tmp_filepath = os.path.join(self._output_path, 'tmp_thd.wav')
+      signal_processing.SignalProcessingUtils.SaveWav(
+          tmp_filepath, tested_signal)
+
+      # Compute score.
+      eval_score_worker.SetTestedSignalFilepath(tmp_filepath)
+      eval_score_worker.Run(self._output_path)
+      scores[index] = eval_score_worker.score
+
+      # Remove output file to avoid caching.
+      os.remove(eval_score_worker.output_filepath)
+
+    # Validate scores (lowest score with a pure tone).
+    self.assertTrue(all([scores[i + 1] > scores[i] for i in range(2)]))
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/evaluation.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/evaluation.py
index e18f193..09ded4c 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/evaluation.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/evaluation.py
@@ -20,14 +20,15 @@
     pass
 
   @classmethod
-  def Run(cls, evaluation_score_workers, apm_output_filepath,
-          reference_input_filepath, output_path):
+  def Run(cls, evaluation_score_workers, apm_input_metadata,
+          apm_output_filepath, reference_input_filepath, output_path):
     """Runs the evaluation.
 
     Iterates over the given evaluation score workers.
 
     Args:
       evaluation_score_workers: list of EvaluationScore instances.
+      apm_input_metadata: dictionary with metadata of the APM input.
       apm_output_filepath: path to the audio track file with the APM output.
       reference_input_filepath: path to the reference audio track file.
       output_path: output path.
@@ -40,6 +41,7 @@
 
     for evaluation_score_worker in evaluation_score_workers:
       logging.info('   computing <%s> score', evaluation_score_worker.NAME)
+      evaluation_score_worker.SetInputSignalMetadata(apm_input_metadata)
       evaluation_score_worker.SetReferenceSignalFilepath(
           reference_input_filepath)
       evaluation_score_worker.SetTestedSignalFilepath(
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py
index 0f7716a..b13b35b 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py
@@ -32,3 +32,9 @@
   """Input signal creator exeception.
   """
   pass
+
+
+class EvaluationScoreException(Exception):
+  """Evaluation score exeception.
+  """
+  pass
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_signal_creator.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_signal_creator.py
index e2a720c..5d97c3b 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_signal_creator.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_signal_creator.py
@@ -18,26 +18,36 @@
   """
 
   @classmethod
-  def Create(cls, name, params):
-    """Creates a input signal.
+  def Create(cls, name, raw_params):
+    """Creates a input signal and its metadata.
 
     Args:
       name: Input signal creator name.
-      params: Tuple of parameters to pass to the specific signal creator.
+      raw_params: Tuple of parameters to pass to the specific signal creator.
 
     Returns:
-      AudioSegment instance.
+      (AudioSegment, dict) tuple.
     """
     try:
+      signal = {}
+      params = {}
+
       if name == 'pure_tone':
-        return cls._CreatePureTone(float(params[0]), int(params[1]))
+        params['frequency'] = float(raw_params[0])
+        params['duration'] = int(raw_params[1])
+        signal = cls._CreatePureTone(params['frequency'], params['duration'])
+      else:
+        raise exceptions.InputSignalCreatorException(
+            'Invalid input signal creator name')
+
+      # Complete metadata.
+      params['signal'] = name
+
+      return signal, params
     except (TypeError, AssertionError) as e:
       raise exceptions.InputSignalCreatorException(
           'Invalid signal creator parameters: {}'.format(e))
 
-    raise exceptions.InputSignalCreatorException(
-        'Invalid input signal creator name')
-
   @classmethod
   def _CreatePureTone(cls, frequency, duration):
     """
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing.py
index 9a1f279..5beb3fb 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing.py
@@ -149,6 +149,13 @@
         volume=0.0)
 
   @classmethod
+  def AudioSegmentToRawData(cls, signal):
+    samples = signal.get_array_of_samples()
+    if samples.typecode != 'h':
+      raise exceptions.SignalProcessingException('Unsupported samples type')
+    return np.array(signal.get_array_of_samples(), np.int16)
+
+  @classmethod
   def DetectHardClipping(cls, signal, threshold=2):
     """Detects hard clipping.
 
@@ -169,13 +176,7 @@
     if signal.sample_width != 2:  # Note that signal.sample_width is in bytes.
       raise exceptions.SignalProcessingException(
           'hard-clipping detection only supported for 16 bit samples')
-
-    # Get raw samples, check type, cast.
-    samples = signal.get_array_of_samples()
-    if samples.typecode != 'h':
-      raise exceptions.SignalProcessingException(
-          'hard-clipping detection only supported for 16 bit samples')
-    samples = np.array(signal.get_array_of_samples(), np.int16)
+    samples = cls.AudioSegmentToRawData(signal)
 
     # Detect adjacent clipped samples.
     samples_type_info = np.iinfo(samples.dtype)
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py
index 7023b6a..b256940 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py
@@ -17,6 +17,7 @@
 from . import echo_path_simulation_factory
 from . import eval_scores
 from . import eval_scores_factory
+from . import exceptions
 from . import input_mixer
 from . import test_data_generation
 from . import test_data_generation_factory
@@ -248,9 +249,20 @@
         test_data_cache_path=test_data_cache_path,
         base_output_path=output_path)
 
+    # Extract metadata linked to the clean input file (if any).
+    apm_input_metadata = None
+    try:
+      apm_input_metadata = data_access.Metadata.LoadFileMetadata(
+          clean_capture_input_filepath)
+    except IOError as e:
+      apm_input_metadata = {}
+    apm_input_metadata['test_data_gen_name'] = test_data_generators.NAME
+    apm_input_metadata['test_data_gen_config'] = None
+
     # For each test data pair, simulate a call and evaluate.
     for config_name in test_data_generators.config_names:
       logging.info(' - test data generator config: <%s>', config_name)
+      apm_input_metadata['test_data_gen_config'] = config_name
 
       # Paths to the test data generator output.
       # Note that the reference signal does not depend on the render input
@@ -278,23 +290,28 @@
           render_input_filepath=render_input_filepath,
           output_path=evaluation_output_path)
 
-      # Evaluate.
-      self._evaluator.Run(
-          evaluation_score_workers=self._evaluation_score_workers,
-          apm_output_filepath=self._audioproc_wrapper.output_filepath,
-          reference_input_filepath=reference_signal_filepath,
-          output_path=evaluation_output_path)
+      try:
+        # Evaluate.
+        self._evaluator.Run(
+            evaluation_score_workers=self._evaluation_score_workers,
+            apm_input_metadata=apm_input_metadata,
+            apm_output_filepath=self._audioproc_wrapper.output_filepath,
+            reference_input_filepath=reference_signal_filepath,
+            output_path=evaluation_output_path)
 
-      # Save simulation metadata.
-      data_access.Metadata.SaveAudioTestDataPaths(
-          output_path=evaluation_output_path,
-          clean_capture_input_filepath=clean_capture_input_filepath,
-          echo_free_capture_filepath=noisy_capture_input_filepath,
-          echo_filepath=echo_path_filepath,
-          render_filepath=render_input_filepath,
-          capture_filepath=apm_input_filepath,
-          apm_output_filepath=self._audioproc_wrapper.output_filepath,
-          apm_reference_filepath=reference_signal_filepath)
+        # Save simulation metadata.
+        data_access.Metadata.SaveAudioTestDataPaths(
+            output_path=evaluation_output_path,
+            clean_capture_input_filepath=clean_capture_input_filepath,
+            echo_free_capture_filepath=noisy_capture_input_filepath,
+            echo_filepath=echo_path_filepath,
+            render_filepath=render_input_filepath,
+            capture_filepath=apm_input_filepath,
+            apm_output_filepath=self._audioproc_wrapper.output_filepath,
+            apm_reference_filepath=reference_signal_filepath)
+      except exceptions.EvaluationScoreException as e:
+        logging.warning('the evaluation failed: %s', e.message)
+        continue
 
   def _SetTestInputSignalFilePaths(self, capture_input_filepaths,
                                    render_input_filepaths):
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py
index 544ad97..33ee921 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py
@@ -9,6 +9,7 @@
 """Unit tests for the simulation module.
 """
 
+import logging
 import os
 import shutil
 import sys
@@ -33,8 +34,9 @@
   """
 
   def setUp(self):
-    """Create temporary folder and fake audio track."""
+    """Create temporary folders and fake audio track."""
     self._output_path = tempfile.mkdtemp()
+    self._tmp_path = tempfile.mkdtemp()
 
     silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
     fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
@@ -46,6 +48,7 @@
   def tearDown(self):
     """Recursively delete temporary folders."""
     shutil.rmtree(self._output_path)
+    shutil.rmtree(self._tmp_path)
 
   def testSimulation(self):
     # Instance dependencies to inject and mock.
@@ -87,3 +90,39 @@
                             min_number_of_simulations)
     self.assertGreaterEqual(len(evaluator.Run.call_args_list),
                             min_number_of_simulations)
+
+  def testPureToneGenerationWithTotalHarmonicDistorsion(self):
+    logging.warning = mock.MagicMock(name='warning')
+
+    # Instance simulator.
+    simulator = simulation.ApmModuleSimulator(
+        aechen_ir_database_path='',
+        polqa_tool_bin_path=os.path.join(
+            os.path.dirname(__file__), 'fake_polqa'),
+        ap_wrapper=audioproc_wrapper.AudioProcWrapper(),
+        evaluator=evaluation.ApmModuleEvaluator())
+
+    # What to simulate.
+    config_files = ['apm_configs/default.json']
+    input_files = [os.path.join(self._tmp_path, 'pure_tone-440_1000.wav')]
+    eval_scores = ['thd']
+
+    # Should work.
+    simulator.Run(
+        config_filepaths=config_files,
+        capture_input_filepaths=input_files,
+        test_data_generator_names=['identity'],
+        eval_score_names=eval_scores,
+        output_dir=self._output_path)
+    self.assertFalse(logging.warning.called)
+
+    # Warning expected.
+    simulator.Run(
+        config_filepaths=config_files,
+        capture_input_filepaths=input_files,
+        test_data_generator_names=['white_noise'],  # Not allowed with THD.
+        eval_score_names=eval_scores,
+        output_dir=self._output_path)
+    logging.warning.assert_called_with('the evaluation failed: %s', (
+        'The THD score cannot be used with any test data generator other than '
+        '"identity"'))
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py
index 3d54da5..4153f73 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py
@@ -147,11 +147,12 @@
       raise exceptions.InputSignalCreatorException(
           'Cannot parse input signal file name')
 
-    signal = input_signal_creator.InputSignalCreator.Create(
+    signal, metadata = input_signal_creator.InputSignalCreator.Create(
         filename_parts[0], filename_parts[1].split('_'))
 
     signal_processing.SignalProcessingUtils.SaveWav(
         input_signal_filepath, signal)
+    data_access.Metadata.SaveFileMetadata(input_signal_filepath, metadata)
 
   def _Generate(
       self, input_signal_filepath, test_data_cache_path, base_output_path):