Support for external VAD programs in APM-QA

There is now an 'ExternalVad' class, used by the
AudioAnnotationsExtractor. The extractor takes a dict of these, keyed
by name, in addition to the built-in VADs. Each external VAD runs an
external program to generate its annotations. Annotations are saved to
and loaded from a compressed NumPy (.npz) format.
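
For illustration, a minimal sketch of how the new pieces fit together
(the '/path/to/...' locations and the 'my_vad' binary are hypothetical;
it assumes the quality_assessment package is importable and the WebRTC
VAD binaries have been built):

  import os
  import numpy as np
  from quality_assessment import annotations, external_vad

  # Wrap a hypothetical binary that accepts '-i <wav file> -o <output>'.
  vads = external_vad.ExternalVad.ConstructVadDict(
      ['/path/to/my_vad'], ['my_vad'])
  extractor = annotations.AudioAnnotationsExtractor(
      annotations.AudioAnnotationsExtractor.VadType.ENERGY_THRESHOLD, vads)
  extractor.Extract('/path/to/capture.wav')  # Must be 48 kHz, mono.
  extractor.Save('/path/to/output_dir')
  data = np.load(os.path.join('/path/to/output_dir',
                              extractor.GetOutputFileName()))
  print(data['extvad_conf-my_vad'])  # One float32 value per 10 ms.

The same external VADs can be plugged into apm_quality_assessment.py
via the new --external_vad_paths and --external_vad_names flags.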

Also made a small fix to the naming of mixed capture files so that
they are no longer overwritten: the mix file name now includes the
capture input file name in addition to the echo file name.

Also made some minor changes to the unit tests.
TBR=alessiob@webrtc.org

Bug: webrtc:7494
Change-Id: I7816b04466be16cd635ac6ceab18cd7aad5325a4
Reviewed-on: https://webrtc-review.googlesource.com/23623
Commit-Queue: Alex Loiko <aleloi@webrtc.org>
Reviewed-by: Alex Loiko <aleloi@webrtc.org>
Reviewed-by: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#20819}
diff --git a/modules/audio_processing/test/py_quality_assessment/BUILD.gn b/modules/audio_processing/test/py_quality_assessment/BUILD.gn
index eee58da..64e3a30 100644
--- a/modules/audio_processing/test/py_quality_assessment/BUILD.gn
+++ b/modules/audio_processing/test/py_quality_assessment/BUILD.gn
@@ -66,6 +66,7 @@
     "quality_assessment/exceptions.py",
     "quality_assessment/export.py",
     "quality_assessment/export_unittest.py",
+    "quality_assessment/external_vad.py",
     "quality_assessment/input_mixer.py",
     "quality_assessment/input_signal_creator.py",
     "quality_assessment/results.css",
@@ -149,6 +150,7 @@
     "quality_assessment/annotations_unittest.py",
     "quality_assessment/echo_path_simulation_unittest.py",
     "quality_assessment/eval_scores_unittest.py",
+    "quality_assessment/fake_external_vad.py",
     "quality_assessment/input_mixer_unittest.py",
     "quality_assessment/signal_processing_unittest.py",
     "quality_assessment/simulation_unittest.py",
diff --git a/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py b/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py
index 78ff5e9..a4cc5f0 100755
--- a/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py
+++ b/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py
@@ -27,6 +27,7 @@
 import quality_assessment.eval_scores as eval_scores
 import quality_assessment.evaluation as evaluation
 import quality_assessment.eval_scores_factory as eval_scores_factory
+import quality_assessment.external_vad as external_vad
 import quality_assessment.test_data_generation as test_data_generation
 import quality_assessment.test_data_generation_factory as  \
     test_data_generation_factory
@@ -113,6 +114,14 @@
                             'copy of the clean speech input file.'),
                       default=False)
 
+  parser.add_argument('--external_vad_paths', nargs='+', required=False,
+                      help=('Paths to external VAD programs. Each must take '
+                            '\'-i <wav file> -o <output>\' inputs'), default=[])
+
+  parser.add_argument('--external_vad_names', nargs='+', required=False,
+                      help=('Names for the external VADs. Must be unique and '
+                            'match the number of paths.'), default=[])
+
   return parser
 
 
@@ -128,6 +137,12 @@
                  'also required')
     sys.exit(1)
 
+  if len(args.external_vad_names) != len(args.external_vad_paths):
+    parser.error('If provided, --external_vad_paths and '
+                 '--external_vad_names must '
+                 'have the same number of arguments.')
+    sys.exit(1)
+
 
 def main():
   # TODO(alessiob): level = logging.INFO once debugged.
@@ -145,7 +160,9 @@
       evaluation_score_factory=eval_scores_factory.EvaluationScoreWorkerFactory(
           polqa_tool_bin_path=os.path.join(args.polqa_path, _POLQA_BIN_NAME)),
       ap_wrapper=audioproc_wrapper.AudioProcWrapper(args.apm_sim_path),
-      evaluator=evaluation.ApmModuleEvaluator())
+      evaluator=evaluation.ApmModuleEvaluator(),
+      external_vads=external_vad.ExternalVad.ConstructVadDict(
+          args.external_vad_paths, args.external_vad_names))
   simulator.Run(
       config_filepaths=args.config_files,
       capture_input_filepaths=args.capture_input_files,
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations.py
index 2f5daf1..a4e9097 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations.py
@@ -24,6 +24,7 @@
   logging.critical('Cannot import the third-party Python package numpy')
   sys.exit(1)
 
+from . import external_vad
 from . import exceptions
 from . import signal_processing
 
@@ -76,7 +77,7 @@
   _VAD_WEBRTC_APM_PATH = os.path.join(
       _VAD_WEBRTC_PATH, 'apm_vad')
 
-  def __init__(self, vad_type):
+  def __init__(self, vad_type, external_vads=None):
     self._signal = None
     self._level = None
     self._level_frame_size = None
@@ -92,6 +93,19 @@
     self._vad_type = self.VadType(vad_type)
     logging.info('VADs used for annotations: ' + str(self._vad_type))
 
+    if external_vads is None:
+      external_vads = {}
+    self._external_vads = external_vads
+
+    assert len(self._external_vads) == len(external_vads), (
+        'The external VAD names must be unique.')
+    for vad in external_vads.values():
+      if not isinstance(vad, external_vad.ExternalVad):
+        raise exceptions.InitializationException(
+            'Invalid vad type: ' + str(type(vad)))
+      logging.info('External VAD used for annotation: ' +
+                   str(vad.name))
+
     assert os.path.exists(self._VAD_WEBRTC_COMMON_AUDIO_PATH), \
       self._VAD_WEBRTC_COMMON_AUDIO_PATH
     assert os.path.exists(self._VAD_WEBRTC_APM_PATH), \
@@ -113,9 +127,9 @@
 
   def GetVadOutput(self, vad_type):
     if vad_type == self.VadType.ENERGY_THRESHOLD:
-      return (self._energy_vad, )
+      return self._energy_vad
     elif vad_type == self.VadType.WEBRTC_COMMON_AUDIO:
-      return (self._common_audio_vad, )
+      return self._common_audio_vad
     elif vad_type == self.VadType.WEBRTC_APM:
       return (self._apm_vad_probs, self._apm_vad_rms)
     else:
@@ -132,7 +146,7 @@
     # Load signal.
     self._signal = signal_processing.SignalProcessingUtils.LoadWav(filepath)
     if self._signal.channels != 1:
-      raise NotImplementedError('multiple-channel annotations not implemented')
+      raise NotImplementedError('Multiple-channel annotations not implemented')
 
     # Level estimation params.
     self._level_frame_size = int(self._signal.frame_rate / 1000 * (
@@ -160,8 +174,14 @@
     if self._vad_type.Contains(self.VadType.WEBRTC_APM):
       # WebRTC modules/audio_processing/ VAD.
       self._RunWebRtcApmVad(filepath)
+    for extvad_name in self._external_vads:
+      self._external_vads[extvad_name].Run(filepath)
 
   def Save(self, output_path):
+    ext_kwargs = {'extvad_conf-' + ext_vad:
+                  self._external_vads[ext_vad].GetVadOutput()
+                  for ext_vad in self._external_vads}
+    # pylint: disable=star-args
     np.savez_compressed(
         file=os.path.join(output_path, self._OUTPUT_FILENAME),
         level=self._level,
@@ -172,7 +192,8 @@
         vad_frame_size=self._vad_frame_size,
         vad_frame_size_ms=self._vad_frame_size_ms,
         vad_probs=self._apm_vad_probs,
-        vad_rms=self._apm_vad_rms
+        vad_rms=self._apm_vad_rms,
+        **ext_kwargs
     )
 
   def _LevelEstimation(self):
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations_unittest.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations_unittest.py
index 3f44edf..5fe5f5d 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations_unittest.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations_unittest.py
@@ -19,6 +19,7 @@
 import numpy as np
 
 from . import annotations
+from . import external_vad
 from . import input_signal_creator
 from . import signal_processing
 
@@ -29,6 +30,11 @@
 
   _CLEAN_TMP_OUTPUT = True
   _DEBUG_PLOT_VAD = False
+  _VAD_TYPE_CLASS = annotations.AudioAnnotationsExtractor.VadType
+  _ALL_VAD_TYPES = (_VAD_TYPE_CLASS.ENERGY_THRESHOLD |
+                    _VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO |
+                    _VAD_TYPE_CLASS.WEBRTC_APM)
+
 
   def setUp(self):
     """Create temporary folder."""
@@ -49,11 +55,7 @@
           self._tmp_path))
 
   def testFrameSizes(self):
-    vad_type_class = annotations.AudioAnnotationsExtractor.VadType
-    vad_type = (vad_type_class.ENERGY_THRESHOLD |
-                vad_type_class.WEBRTC_COMMON_AUDIO |
-                vad_type_class.WEBRTC_APM)
-    e = annotations.AudioAnnotationsExtractor(vad_type=vad_type)
+    e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
     e.Extract(self._wav_file_path)
     samples_to_ms = lambda n, sr: 1000 * n // sr
     self.assertEqual(samples_to_ms(e.GetLevelFrameSize(), self._sample_rate),
@@ -62,35 +64,31 @@
                      e.GetVadFrameSizeMs())
 
   def testVoiceActivityDetectors(self):
-    vad_type_class = annotations.AudioAnnotationsExtractor.VadType
-    max_vad_type = (vad_type_class.ENERGY_THRESHOLD |
-                vad_type_class.WEBRTC_COMMON_AUDIO |
-                vad_type_class.WEBRTC_APM)
-    for vad_type_value in range(0, max_vad_type+1):
-      vad_type = vad_type_class(vad_type_value)
+    for vad_type_value in range(0, self._ALL_VAD_TYPES+1):
+      vad_type = self._VAD_TYPE_CLASS(vad_type_value)
       e = annotations.AudioAnnotationsExtractor(vad_type=vad_type_value)
       e.Extract(self._wav_file_path)
-      if vad_type.Contains(vad_type_class.ENERGY_THRESHOLD):
-        # pylint: disable=unbalanced-tuple-unpacking
-        (vad_output, ) = e.GetVadOutput(vad_type_class.ENERGY_THRESHOLD)
+      if vad_type.Contains(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD):
+        # pylint: disable=unpacking-non-sequence
+        vad_output = e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD)
         self.assertGreater(len(vad_output), 0)
         self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output),
                                 0.95)
 
-      if vad_type.Contains(vad_type_class.WEBRTC_COMMON_AUDIO):
-        # pylint: disable=unbalanced-tuple-unpacking
-        (vad_output,) = e.GetVadOutput(vad_type_class.WEBRTC_COMMON_AUDIO)
+      if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO):
+        # pylint: disable=unpacking-non-sequence
+        vad_output = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO)
         self.assertGreater(len(vad_output), 0)
         self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output),
                                 0.95)
 
-      if vad_type.Contains(vad_type_class.WEBRTC_APM):
-        # pylint: disable=unbalanced-tuple-unpacking
-        (vad_probs, vad_rms) = e.GetVadOutput(vad_type_class.WEBRTC_APM)
+      if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_APM):
+        # pylint: disable=unpacking-non-sequence
+        (vad_probs, vad_rms) = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)
         self.assertGreater(len(vad_probs), 0)
         self.assertGreater(len(vad_rms), 0)
         self.assertGreaterEqual(float(np.sum(vad_probs)) / len(vad_probs),
-                                0.95)
+                                0.5)
         self.assertGreaterEqual(float(np.sum(vad_rms)) / len(vad_rms), 20000)
 
       if self._DEBUG_PLOT_VAD:
@@ -111,11 +109,7 @@
         plt.show()
 
   def testSaveLoad(self):
-    vad_type_class = annotations.AudioAnnotationsExtractor.VadType
-    vad_type = (vad_type_class.ENERGY_THRESHOLD |
-                vad_type_class.WEBRTC_COMMON_AUDIO |
-                vad_type_class.WEBRTC_APM)
-    e = annotations.AudioAnnotationsExtractor(vad_type)
+    e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
     e.Extract(self._wav_file_path)
     e.Save(self._tmp_path)
 
@@ -123,14 +117,37 @@
     np.testing.assert_array_equal(e.GetLevel(), data['level'])
     self.assertEqual(np.float32, data['level'].dtype)
     np.testing.assert_array_equal(
-        e.GetVadOutput(vad_type_class.ENERGY_THRESHOLD),
+        e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD),
         data['vad_energy_output'])
     np.testing.assert_array_equal(
-        e.GetVadOutput(vad_type_class.WEBRTC_COMMON_AUDIO), data['vad_output'])
+        e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO),
+        data['vad_output'])
     np.testing.assert_array_equal(
-        e.GetVadOutput(vad_type_class.WEBRTC_APM)[0], data['vad_probs'])
+        e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[0], data['vad_probs'])
     np.testing.assert_array_equal(
-        e.GetVadOutput(vad_type_class.WEBRTC_APM)[1], data['vad_rms'])
+        e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[1], data['vad_rms'])
     self.assertEqual(np.uint8, data['vad_energy_output'].dtype)
     self.assertEqual(np.float64, data['vad_probs'].dtype)
     self.assertEqual(np.float64, data['vad_rms'].dtype)
+
+  def testEmptyExternalShouldNotCrash(self):
+    for vad_type_value in range(0, self._ALL_VAD_TYPES+1):
+      annotations.AudioAnnotationsExtractor(vad_type_value, {})
+
+  def testFakeExternalSaveLoad(self):
+    def FakeExternalFactory():
+      return external_vad.ExternalVad(
+        os.path.join(
+            os.path.dirname(os.path.abspath(__file__)), 'fake_external_vad.py'),
+        'fake'
+      )
+    for vad_type_value in range(0, self._ALL_VAD_TYPES+1):
+      e = annotations.AudioAnnotationsExtractor(
+          vad_type_value,
+          {'fake': FakeExternalFactory()})
+      e.Extract(self._wav_file_path)
+      e.Save(self._tmp_path)
+      data = np.load(os.path.join(self._tmp_path, e.GetOutputFileName()))
+      self.assertEqual(np.float32, data['extvad_conf-fake'].dtype)
+      np.testing.assert_almost_equal(np.arange(100, dtype=np.float32),
+                                     data['extvad_conf-fake'])
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/external_vad.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/external_vad.py
new file mode 100644
index 0000000..01418d8
--- /dev/null
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/external_vad.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS.  All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
+
+from __future__ import division
+
+import logging
+import os
+import subprocess
+import shutil
+import sys
+import tempfile
+
+try:
+  import numpy as np
+except ImportError:
+  logging.critical('Cannot import the third-party Python package numpy')
+  sys.exit(1)
+
+from . import signal_processing
+
+class ExternalVad(object):
+
+  def __init__(self, path_to_binary, name):
+    """Args:
+       path_to_binary: path to an executable that accepts '-i <wav file>
+          -o <output file>' arguments and writes one float32 value per
+          10 ms of audio to the output file.
+       name: a name identifying the external VAD. Used when saving the
+          output, under the 'extvad_conf-<name>' key.
+    """
+    self._path_to_binary = path_to_binary
+    self.name = name
+    assert os.path.exists(self._path_to_binary), (
+        self._path_to_binary)
+    self._vad_output = None
+
+  def Run(self, wav_file_path):
+    _signal = signal_processing.SignalProcessingUtils.LoadWav(wav_file_path)
+    if _signal.channels != 1:
+      raise NotImplementedError('Multiple-channel'
+                                ' annotations not implemented')
+    if _signal.frame_rate != 48000:
+      raise NotImplementedError('Frame rates '
+                                'other than 48000 not implemented')
+
+    tmp_path = tempfile.mkdtemp()
+    try:
+      output_file_path = os.path.join(
+          tmp_path, self.name + '_vad.tmp')
+      subprocess.call([
+          self._path_to_binary,
+          '-i', wav_file_path,
+          '-o', output_file_path
+      ])
+      self._vad_output = np.fromfile(output_file_path, np.float32)
+    except Exception as e:
+      logging.error('Error while running the ' + self.name +
+                    ' VAD (' + str(e) + ')')
+    finally:
+      if os.path.exists(tmp_path):
+        shutil.rmtree(tmp_path)
+
+  def GetVadOutput(self):
+    assert self._vad_output is not None
+    return self._vad_output
+
+  @classmethod
+  def ConstructVadDict(cls, vad_paths, vad_names):
+    external_vads = {}
+    for path, name in zip(vad_paths, vad_names):
+      external_vads[name] = ExternalVad(path, name)
+    return external_vads
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/fake_external_vad.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/fake_external_vad.py
new file mode 100755
index 0000000..7c75e8f
--- /dev/null
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/fake_external_vad.py
@@ -0,0 +1,24 @@
+#!/usr/bin/python
+# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS.  All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
+import argparse
+import numpy as np
+
+def main():
+  parser = argparse.ArgumentParser()
+  parser.add_argument('-i', required=True)
+  parser.add_argument('-o', required=True)
+
+  args = parser.parse_args()
+
+  array = np.arange(100, dtype=np.float32)
+  array.tofile(open(args.o, 'wb'))
+
+
+if __name__ == '__main__':
+  main()
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_mixer.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_mixer.py
index 8f9e542..b1afe14 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_mixer.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_mixer.py
@@ -65,8 +65,10 @@
     # This ensures that if the internal parameters of the echo path simulator
     # change, no erroneous cache hit occurs.
     echo_file_name, _ = os.path.splitext(os.path.split(echo_filepath)[1])
-    mix_filepath = os.path.join(output_path, 'mix_capture_{}.wav'.format(
-        echo_file_name))
+    capture_input_file_name, _ = os.path.splitext(
+        os.path.split(capture_input_filepath)[1])
+    mix_filepath = os.path.join(output_path, 'mix_capture_{}_{}.wav'.format(
+        capture_input_file_name, echo_file_name))
 
     # Create the mix if not done yet.
     mix = None
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py
index 8e67291..f791ddd 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py
@@ -41,7 +41,9 @@
   _PREFIX_SCORE = 'score-'
 
   def __init__(self, test_data_generator_factory, evaluation_score_factory,
-               ap_wrapper, evaluator):
+               ap_wrapper, evaluator, external_vads=None):
+    if external_vads is None:
+      external_vads = {}
     self._test_data_generator_factory = test_data_generator_factory
     self._evaluation_score_factory = evaluation_score_factory
     self._audioproc_wrapper = ap_wrapper
@@ -49,7 +51,9 @@
     self._annotator = annotations.AudioAnnotationsExtractor(
         annotations.AudioAnnotationsExtractor.VadType.ENERGY_THRESHOLD |
         annotations.AudioAnnotationsExtractor.VadType.WEBRTC_COMMON_AUDIO |
-        annotations.AudioAnnotationsExtractor.VadType.WEBRTC_APM)
+        annotations.AudioAnnotationsExtractor.VadType.WEBRTC_APM,
+        external_vads
+    )
 
     # Init.
     self._test_data_generator_factory.SetOutputDirectoryPrefix(
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py
index cf9aac8..c7ebcbc 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py
@@ -26,6 +26,7 @@
 from . import audioproc_wrapper
 from . import eval_scores_factory
 from . import evaluation
+from . import external_vad
 from . import signal_processing
 from . import simulation
 from . import test_data_generation_factory
@@ -75,7 +76,10 @@
         test_data_generator_factory=test_data_generator_factory,
         evaluation_score_factory=evaluation_score_factory,
         ap_wrapper=ap_wrapper,
-        evaluator=evaluator)
+        evaluator=evaluator,
+        external_vads={'fake': external_vad.ExternalVad(os.path.join(
+            os.path.dirname(__file__), 'fake_external_vad.py'), 'fake')}
+    )
 
     # What to simulate.
     config_files = ['apm_configs/default.json']