blob: 78ca17f589a3b38aa3139ed926f7bc2b8fc25514 [file] [log] [blame]
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the simulation module.
"""
import logging
import os
import shutil
import tempfile
import unittest
import mock
import pydub
from . import audioproc_wrapper
from . import eval_scores_factory
from . import evaluation
from . import external_vad
from . import signal_processing
from . import simulation
from . import test_data_generation_factory
class TestApmModuleSimulator(unittest.TestCase):
"""Unit tests for the ApmModuleSimulator class.
"""
def setUp(self):
"""Create temporary folders and fake audio track."""
self._output_path = tempfile.mkdtemp()
self._tmp_path = tempfile.mkdtemp()
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
silence)
self._fake_audio_track_path = os.path.join(self._output_path,
'fake.wav')
signal_processing.SignalProcessingUtils.SaveWav(
self._fake_audio_track_path, fake_signal)
def tearDown(self):
"""Recursively delete temporary folders."""
shutil.rmtree(self._output_path)
shutil.rmtree(self._tmp_path)
def testSimulation(self):
# Instance dependencies to mock and inject.
ap_wrapper = audioproc_wrapper.AudioProcWrapper(
audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
evaluator = evaluation.ApmModuleEvaluator()
ap_wrapper.Run = mock.MagicMock(name='Run')
evaluator.Run = mock.MagicMock(name='Run')
# Instance non-mocked dependencies.
test_data_generator_factory = (
test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path='',
noise_tracks_path='',
copy_with_identity=False))
evaluation_score_factory = eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
'fake_polqa'),
echo_metric_tool_bin_path=None)
# Instance simulator.
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=test_data_generator_factory,
evaluation_score_factory=evaluation_score_factory,
ap_wrapper=ap_wrapper,
evaluator=evaluator,
external_vads={
'fake':
external_vad.ExternalVad(
os.path.join(os.path.dirname(__file__),
'fake_external_vad.py'), 'fake')
})
# What to simulate.
config_files = ['apm_configs/default.json']
input_files = [self._fake_audio_track_path]
test_data_generators = ['identity', 'white_noise']
eval_scores = ['audio_level_mean', 'polqa']
# Run all simulations.
simulator.Run(config_filepaths=config_files,
capture_input_filepaths=input_files,
test_data_generator_names=test_data_generators,
eval_score_names=eval_scores,
output_dir=self._output_path)
# Check.
# TODO(alessiob): Once the TestDataGenerator classes can be configured by
# the client code (e.g., number of SNR pairs for the white noise test data
# generator), the exact number of calls to ap_wrapper.Run and evaluator.Run
# is known; use that with assertEqual.
min_number_of_simulations = len(config_files) * len(input_files) * len(
test_data_generators)
self.assertGreaterEqual(len(ap_wrapper.Run.call_args_list),
min_number_of_simulations)
self.assertGreaterEqual(len(evaluator.Run.call_args_list),
min_number_of_simulations)
def testInputSignalCreation(self):
# Instance simulator.
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=(
test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path='',
noise_tracks_path='',
copy_with_identity=False)),
evaluation_score_factory=(
eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
'fake_polqa'),
echo_metric_tool_bin_path=None)),
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
audioproc_wrapper.AudioProcWrapper.
DEFAULT_APM_SIMULATOR_BIN_PATH),
evaluator=evaluation.ApmModuleEvaluator())
# Inexistent input files to be silently created.
input_files = [
os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
os.path.join(self._tmp_path, 'pure_tone-1000_500.wav'),
]
self.assertFalse(
any([os.path.exists(input_file) for input_file in (input_files)]))
# The input files are created during the simulation.
simulator.Run(config_filepaths=['apm_configs/default.json'],
capture_input_filepaths=input_files,
test_data_generator_names=['identity'],
eval_score_names=['audio_level_peak'],
output_dir=self._output_path)
self.assertTrue(
all([os.path.exists(input_file) for input_file in (input_files)]))
def testPureToneGenerationWithTotalHarmonicDistorsion(self):
logging.warning = mock.MagicMock(name='warning')
# Instance simulator.
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=(
test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path='',
noise_tracks_path='',
copy_with_identity=False)),
evaluation_score_factory=(
eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
'fake_polqa'),
echo_metric_tool_bin_path=None)),
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
audioproc_wrapper.AudioProcWrapper.
DEFAULT_APM_SIMULATOR_BIN_PATH),
evaluator=evaluation.ApmModuleEvaluator())
# What to simulate.
config_files = ['apm_configs/default.json']
input_files = [os.path.join(self._tmp_path, 'pure_tone-440_1000.wav')]
eval_scores = ['thd']
# Should work.
simulator.Run(config_filepaths=config_files,
capture_input_filepaths=input_files,
test_data_generator_names=['identity'],
eval_score_names=eval_scores,
output_dir=self._output_path)
self.assertFalse(logging.warning.called)
# Warning expected.
simulator.Run(
config_filepaths=config_files,
capture_input_filepaths=input_files,
test_data_generator_names=['white_noise'], # Not allowed with THD.
eval_score_names=eval_scores,
output_dir=self._output_path)
logging.warning.assert_called_with('the evaluation failed: %s', (
'The THD score cannot be used with any test data generator other than '
'"identity"'))
# # Init.
# generator = test_data_generation.IdentityTestDataGenerator('tmp')
# input_signal_filepath = os.path.join(
# self._test_data_cache_path, 'pure_tone-440_1000.wav')
# # Check that the input signal is generated.
# self.assertFalse(os.path.exists(input_signal_filepath))
# generator.Generate(
# input_signal_filepath=input_signal_filepath,
# test_data_cache_path=self._test_data_cache_path,
# base_output_path=self._base_output_path)
# self.assertTrue(os.path.exists(input_signal_filepath))
# # Check input signal properties.
# input_signal = signal_processing.SignalProcessingUtils.LoadWav(
# input_signal_filepath)
# self.assertEqual(1000, len(input_signal))