Customizable noise tracks path in APM-QA

This CL adds the possibility to specify a custom path for the noise tracks to use with
the addivitve noise test data generator (formerly called environmental noise generator).
It also includes a minor refactoring of ApmModuleSimulator to allow injection and remove
all the parameters that were forwarded to its dependencies.

Bug: webrtc:7494
Change-Id: I07bc359913c375a51bd3692822814d3ce8437268
Reviewed-on: https://webrtc-review.googlesource.com/5982
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Alex Loiko <aleloi@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#20163}
diff --git a/modules/audio_processing/test/py_quality_assessment/README.md b/modules/audio_processing/test/py_quality_assessment/README.md
index d893660..73d1d0d 100644
--- a/modules/audio_processing/test/py_quality_assessment/README.md
+++ b/modules/audio_processing/test/py_quality_assessment/README.md
@@ -51,9 +51,6 @@
 encoded in the 16 bit signed format (it is recommended that the tracks are
 converted and exported with Audacity).
 
-(*2) Adapt `EnvironmentalNoiseTestDataGenerator._NOISE_TRACKS` accordingly in
-`out/Default/py_quality_assessment/quality_assessment/test_data_generation.py`.
-
 ## Usage (scores computation)
  - Go to `out/Default/py_quality_assessment`
  - Check the `apm_quality_assessment.sh` as an example script to parallelize the
diff --git a/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py b/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py
index 1e4ecc0..cdb43d4 100755
--- a/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py
+++ b/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py
@@ -26,7 +26,10 @@
 import quality_assessment.echo_path_simulation as echo_path_simulation
 import quality_assessment.eval_scores as eval_scores
 import quality_assessment.evaluation as evaluation
+import quality_assessment.eval_scores_factory as eval_scores_factory
 import quality_assessment.test_data_generation as test_data_generation
+import quality_assessment.test_data_generation_factory as  \
+    test_data_generation_factory
 import quality_assessment.simulation as simulation
 
 _ECHO_PATH_SIMULATOR_NAMES = (
@@ -76,6 +79,12 @@
                       choices=_TEST_DATA_GENERATORS_NAMES,
                       default=_TEST_DATA_GENERATORS_NAMES)
 
+  parser.add_argument('--additive_noise_tracks_path', required=False,
+                      help='path to the wav files for the additive',
+                      default=test_data_generation.  \
+                              AdditiveNoiseTestDataGenerator.  \
+                              DEFAULT_NOISE_TRACKS_PATH)
+
   parser.add_argument('-e', '--eval_scores', nargs='+', required=False,
                       help='custom list of evaluation scores to use',
                       choices=_EVAL_SCORE_WORKER_NAMES,
@@ -93,32 +102,42 @@
   parser.add_argument('--air_db_path', required=True,
                       help='path to the Aechen IR database')
 
-  parser.add_argument(
-      '--apm_sim_path', required=False, help='path to the APM simulator tool',
-      default=audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
+  parser.add_argument('--apm_sim_path', required=False,
+                      help='path to the APM simulator tool',
+                      default=audioproc_wrapper.  \
+                              AudioProcWrapper.  \
+                              DEFAULT_APM_SIMULATOR_BIN_PATH)
 
   return parser
 
 
-def main():
-  # TODO(alessiob): level = logging.INFO once debugged.
-  logging.basicConfig(level=logging.DEBUG)
-
-  parser = _InstanceArgumentsParser()
-  args = parser.parse_args()
+def _ValidateArguments(args, parser):
   if args.capture_input_files and args.render_input_files and (
       len(args.capture_input_files) != len(args.render_input_files)):
     parser.error('--render_input_files and --capture_input_files must be lists '
                  'having the same length')
     sys.exit(1)
+
   if args.render_input_files and not args.echo_path_simulator:
     parser.error('when --render_input_files is set, --echo_path_simulator is '
                  'also required')
     sys.exit(1)
 
+
+def main():
+  # TODO(alessiob): level = logging.INFO once debugged.
+  logging.basicConfig(level=logging.DEBUG)
+  parser = _InstanceArgumentsParser()
+  args = parser.parse_args()
+  _ValidateArguments(args, parser)
+
   simulator = simulation.ApmModuleSimulator(
-      aechen_ir_database_path=args.air_db_path,
-      polqa_tool_bin_path=os.path.join(args.polqa_path, _POLQA_BIN_NAME),
+      test_data_generator_factory=(
+          test_data_generation_factory.TestDataGeneratorFactory(
+              aechen_ir_database_path=args.air_db_path,
+              noise_tracks_path=args.additive_noise_tracks_path)),
+      evaluation_score_factory=eval_scores_factory.EvaluationScoreWorkerFactory(
+          polqa_tool_bin_path=os.path.join(args.polqa_path, _POLQA_BIN_NAME)),
       ap_wrapper=audioproc_wrapper.AudioProcWrapper(args.apm_sim_path),
       evaluator=evaluation.ApmModuleEvaluator())
   simulator.Run(
@@ -129,7 +148,6 @@
       test_data_generator_names=args.test_data_generators,
       eval_score_names=args.eval_scores,
       output_dir=args.output_dir)
-
   sys.exit(0)
 
 
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_factory.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_factory.py
index c19e1f9..c2ef317 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_factory.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_factory.py
@@ -11,6 +11,7 @@
 
 import logging
 
+from . import exceptions
 from . import eval_scores
 
 
@@ -21,10 +22,13 @@
   workers.
   """
 
-  def __init__(self, score_filename_prefix, polqa_tool_bin_path):
-    self._score_filename_prefix = score_filename_prefix
+  def __init__(self, polqa_tool_bin_path):
+    self._score_filename_prefix = None
     self._polqa_tool_bin_path = polqa_tool_bin_path
 
+  def SetScoreFilenamePrefix(self, prefix):
+    self._score_filename_prefix = prefix
+
   def GetInstance(self, evaluation_score_class):
     """Creates an EvaluationScore instance given a class object.
 
@@ -34,8 +38,12 @@
     Returns:
       An EvaluationScore instance.
     """
+    if self._score_filename_prefix is None:
+      raise exceptions.InitializationException(
+          'The score file name prefix for evaluation score workers is not set')
     logging.debug(
         'factory producing a %s evaluation score', evaluation_score_class)
+
     if evaluation_score_class == eval_scores.PolqaScore:
       return eval_scores.PolqaScore(
           self._score_filename_prefix, self._polqa_tool_bin_path)
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py
index ce51051..ddb5d0b 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py
@@ -66,9 +66,9 @@
     # Instance evaluation score workers factory with fake dependencies.
     eval_score_workers_factory = (
         eval_scores_factory.EvaluationScoreWorkerFactory(
-            score_filename_prefix='scores-',
             polqa_tool_bin_path=os.path.join(
                 os.path.dirname(os.path.abspath(__file__)), 'fake_polqa')))
+    eval_score_workers_factory.SetScoreFilenamePrefix('scores-')
 
     # Try each registered evaluation score worker.
     for eval_score_name in registered_classes:
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py
index b13b35b..852e9e8 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py
@@ -11,30 +11,36 @@
 
 
 class FileNotFoundError(Exception):
-  """File not found exeception.
+  """File not found exception.
   """
   pass
 
 
 class SignalProcessingException(Exception):
-  """Signal processing exeception.
+  """Signal processing exception.
   """
   pass
 
 
 class InputMixerException(Exception):
-  """Input mixer exeception.
+  """Input mixer exception.
   """
   pass
 
 
 class InputSignalCreatorException(Exception):
-  """Input signal creator exeception.
+  """Input signal creator exception.
   """
   pass
 
 
 class EvaluationScoreException(Exception):
-  """Evaluation score exeception.
+  """Evaluation score exception.
+  """
+  pass
+
+
+class InitializationException(Exception):
+  """Initialization exception.
   """
   pass
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py
index b256940..6545e0e 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py
@@ -16,11 +16,9 @@
 from . import echo_path_simulation
 from . import echo_path_simulation_factory
 from . import eval_scores
-from . import eval_scores_factory
 from . import exceptions
 from . import input_mixer
 from . import test_data_generation
-from . import test_data_generation_factory
 
 
 class ApmModuleSimulator(object):
@@ -39,21 +37,18 @@
   _PREFIX_TEST_DATA_GEN_PARAMS = 'datagen_params-'
   _PREFIX_SCORE = 'score-'
 
-  def __init__(self, aechen_ir_database_path, polqa_tool_bin_path,
+  def __init__(self, test_data_generator_factory, evaluation_score_factory,
                ap_wrapper, evaluator):
-    # Init.
+    self._test_data_generator_factory = test_data_generator_factory
+    self._evaluation_score_factory = evaluation_score_factory
     self._audioproc_wrapper = ap_wrapper
     self._evaluator = evaluator
 
-    # Instance factory objects.
-    self._test_data_generator_factory = (
-        test_data_generation_factory.TestDataGeneratorFactory(
-            output_directory_prefix=self._PREFIX_TEST_DATA_GEN_PARAMS,
-            aechen_ir_database_path=aechen_ir_database_path))
-    self._evaluation_score_factory = (
-        eval_scores_factory.EvaluationScoreWorkerFactory(
-            score_filename_prefix=self._PREFIX_SCORE,
-            polqa_tool_bin_path=polqa_tool_bin_path))
+    # Init.
+    self._test_data_generator_factory.SetOutputDirectoryPrefix(
+        self._PREFIX_TEST_DATA_GEN_PARAMS)
+    self._evaluation_score_factory.SetScoreFilenamePrefix(
+        self._PREFIX_SCORE)
 
     # Properties for each run.
     self._base_output_path = None
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py
index 017a316..521f006 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py
@@ -24,9 +24,11 @@
 import pydub
 
 from . import audioproc_wrapper
+from . import eval_scores_factory
 from . import evaluation
 from . import signal_processing
 from . import simulation
+from . import test_data_generation_factory
 
 
 class TestApmModuleSimulator(unittest.TestCase):
@@ -51,18 +53,26 @@
     shutil.rmtree(self._tmp_path)
 
   def testSimulation(self):
-    # Instance dependencies to inject and mock.
+    # Instance dependencies to mock and inject.
     ap_wrapper = audioproc_wrapper.AudioProcWrapper(
         audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
     evaluator = evaluation.ApmModuleEvaluator()
     ap_wrapper.Run = mock.MagicMock(name='Run')
     evaluator.Run = mock.MagicMock(name='Run')
 
+    # Instance non-mocked dependencies.
+    test_data_generator_factory = (
+        test_data_generation_factory.TestDataGeneratorFactory(
+            aechen_ir_database_path='',
+            noise_tracks_path=''))
+    evaluation_score_factory = eval_scores_factory.EvaluationScoreWorkerFactory(
+        polqa_tool_bin_path=os.path.join(
+            os.path.dirname(__file__), 'fake_polqa'))
+
     # Instance simulator.
     simulator = simulation.ApmModuleSimulator(
-        aechen_ir_database_path='',
-        polqa_tool_bin_path=os.path.join(
-            os.path.dirname(__file__), 'fake_polqa'),
+        test_data_generator_factory=test_data_generator_factory,
+        evaluation_score_factory=evaluation_score_factory,
         ap_wrapper=ap_wrapper,
         evaluator=evaluator)
 
@@ -97,9 +107,14 @@
 
     # Instance simulator.
     simulator = simulation.ApmModuleSimulator(
-        aechen_ir_database_path='',
-        polqa_tool_bin_path=os.path.join(
-            os.path.dirname(__file__), 'fake_polqa'),
+        test_data_generator_factory=(
+            test_data_generation_factory.TestDataGeneratorFactory(
+                aechen_ir_database_path='',
+                noise_tracks_path='')),
+        evaluation_score_factory=(
+            eval_scores_factory.EvaluationScoreWorkerFactory(
+                polqa_tool_bin_path=os.path.join(
+                    os.path.dirname(__file__), 'fake_polqa'))),
         ap_wrapper=audioproc_wrapper.AudioProcWrapper(
             audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH),
         evaluator=evaluation.ApmModuleEvaluator())
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py
index 7c6885d..bd0efb7 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py
@@ -313,26 +313,20 @@
 
 
 @TestDataGenerator.RegisterClass
-class EnvironmentalNoiseTestDataGenerator(TestDataGenerator):
-  """Generator that adds environmental noise.
+class AdditiveNoiseTestDataGenerator(TestDataGenerator):
+  """Generator that adds noise loops.
 
-  TODO(alessiob): Make the class more generic e.g.,
-  MixNoiseTrackTestDataGenerator.
+  This generator uses all the wav files in a given path (default: noise_tracks/)
+  and mixes them to the clean speech with different target SNRs (hard-coded).
   """
 
-  NAME = 'environmental_noise'
+  NAME = 'additive_noise'
   _NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'
 
-  # TODO(alessiob): allow the user to store the noise tracks in a custom path.
-  _NOISE_TRACKS_PATH = os.path.join(
+  DEFAULT_NOISE_TRACKS_PATH = os.path.join(
       os.path.dirname(__file__), os.pardir, 'noise_tracks')
 
-  # TODO(alessiob): Allow the user to have custom noise tracks.
-  # TODO(alessiob): Exploit TestDataGeneratorFactory.GetInstance().
-  _NOISE_TRACKS = [
-      'city.wav'
-  ]
-
+  # TODO(alessiob): Make the list of SNR pairs customizable.
   # Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
   # The reference (second value of each pair) always has a lower amount of noise
   # - i.e., the SNR is 10 dB higher.
@@ -343,8 +337,15 @@
       [0, 10],  # Largest noise.
   ]
 
-  def __init__(self, output_directory_prefix):
+  def __init__(self, output_directory_prefix, noise_tracks_path):
     TestDataGenerator.__init__(self, output_directory_prefix)
+    self._noise_tracks_path = noise_tracks_path
+    self._noise_tracks_file_names = [n for n in os.listdir(
+        self._noise_tracks_path) if n.lower().endswith('.wav')]
+    if len(self._noise_tracks_file_names) == 0:
+      raise exceptions.InitializationException(
+          'No wav files found in the noise tracks path %s' % (
+              self._noise_tracks_path))
 
   def _Generate(
       self, input_signal_filepath, test_data_cache_path, base_output_path):
@@ -363,11 +364,11 @@
         input_signal_filepath)
 
     noisy_mix_filepaths = {}
-    for noise_track_filename in self._NOISE_TRACKS:
+    for noise_track_filename in self._noise_tracks_file_names:
       # Load the noise track.
       noise_track_name, _ = os.path.splitext(noise_track_filename)
       noise_track_filepath = os.path.join(
-          self._NOISE_TRACKS_PATH, noise_track_filename)
+          self._noise_tracks_path, noise_track_filename)
       if not os.path.exists(noise_track_filepath):
         logging.error('cannot find the <%s> noise track', noise_track_filename)
         raise exceptions.FileNotFoundError()
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_factory.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_factory.py
index b42d3af..fd7f3f7 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_factory.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_factory.py
@@ -11,6 +11,7 @@
 
 import logging
 
+from . import exceptions
 from . import test_data_generation
 
 
@@ -21,9 +22,13 @@
   generators will be produced.
   """
 
-  def __init__(self, output_directory_prefix, aechen_ir_database_path):
-    self._output_directory_prefix = output_directory_prefix
+  def __init__(self, aechen_ir_database_path, noise_tracks_path):
+    self._output_directory_prefix = None
     self._aechen_ir_database_path = aechen_ir_database_path
+    self._noise_tracks_path = noise_tracks_path
+
+  def SetOutputDirectoryPrefix(self, prefix):
+    self._output_directory_prefix = prefix
 
   def GetInstance(self, test_data_generators_class):
     """Creates an TestDataGenerator instance given a class object.
@@ -35,10 +40,18 @@
     Returns:
       TestDataGenerator instance.
     """
+    if self._output_directory_prefix is None:
+      raise exceptions.InitializationException(
+          'The output directory prefix for test data generators is not set')
     logging.debug('factory producing %s', test_data_generators_class)
+
     if test_data_generators_class == (
         test_data_generation.ReverberationTestDataGenerator):
       return test_data_generation.ReverberationTestDataGenerator(
           self._output_directory_prefix, self._aechen_ir_database_path)
+    elif test_data_generators_class == (
+        test_data_generation.AdditiveNoiseTestDataGenerator):
+      return test_data_generation.AdditiveNoiseTestDataGenerator(
+          self._output_directory_prefix, self._noise_tracks_path)
     else:
       return test_data_generators_class(self._output_directory_prefix)
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_unittest.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_unittest.py
index 6239d51..73ea45d 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_unittest.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_unittest.py
@@ -83,10 +83,12 @@
     self.assertGreater(len(registered_classes), 0)
 
     # Instance generators factory.
-    generators_factory = (
-        test_data_generation_factory.TestDataGeneratorFactory(
-            output_directory_prefix='datagen-',
-            aechen_ir_database_path=self._fake_air_db_path))
+    generators_factory = test_data_generation_factory.TestDataGeneratorFactory(
+        aechen_ir_database_path=self._fake_air_db_path,
+        noise_tracks_path=test_data_generation.  \
+                          AdditiveNoiseTestDataGenerator.  \
+                          DEFAULT_NOISE_TRACKS_PATH)
+    generators_factory.SetOutputDirectoryPrefix('datagen-')
 
     # Use a sample input file as clean input signal.
     input_signal_filepath = os.path.join(