| # Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. |
| # |
| # Use of this source code is governed by a BSD-style license |
| # that can be found in the LICENSE file in the root of the source |
| # tree. An additional intellectual property rights grant can be found |
| # in the file PATENTS. All contributing project authors may |
| # be found in the AUTHORS file in the root of the source tree. |
| |
| """Unit tests for the test_data_generation module. |
| """ |
| |
| import os |
| import shutil |
| import tempfile |
| import unittest |
| |
| import numpy as np |
| import scipy.io |
| |
| from . import test_data_generation |
| from . import test_data_generation_factory |
| from . import signal_processing |
| |
| |
class TestTestDataGenerators(unittest.TestCase):
    """Unit tests for the test_data_generation module.
    """

    def setUp(self):
        """Create temporary folders and a fake AIR impulse response database.

        Each temporary folder is registered for removal with addCleanup() as
        soon as it is created. The folders are therefore deleted even when a
        later step of setUp() raises, a case in which tearDown() would not run.
        """
        self._base_output_path = self._MakeTempDir()
        self._test_data_cache_path = self._MakeTempDir()
        self._fake_air_db_path = self._MakeTempDir()

        # Fake AIR DB impulse responses.
        # TODO(alessiob): ReverberationTestDataGenerator will change to allow custom
        # impulse responses. When changed, the coupling below between
        # impulse_response_mat_file_names and
        # ReverberationTestDataGenerator._IMPULSE_RESPONSES can be removed.
        impulse_response_mat_file_names = [
            'air_binaural_lecture_0_0_1.mat',
            'air_binaural_booth_0_0_1.mat',
        ]
        for impulse_response_mat_file_name in impulse_response_mat_file_names:
            # Random little-endian float64 row vector stored under the 'h_air'
            # key of the .mat file.
            data = {'h_air': np.random.rand(1, 1000).astype('<f8')}
            scipy.io.savemat(os.path.join(
                self._fake_air_db_path, impulse_response_mat_file_name), data)

    def _MakeTempDir(self):
        """Creates a temporary folder that is recursively deleted after the test.

        Returns:
          Path of the newly created temporary folder.
        """
        path = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, path)
        return path

    def testInputSignalCreation(self):
        # Init.
        generator = test_data_generation.IdentityTestDataGenerator('tmp')
        input_signal_filepath = os.path.join(
            self._test_data_cache_path, 'pure_tone-440_1000.wav')

        # Check that the input signal is generated.
        self.assertFalse(os.path.exists(input_signal_filepath))
        generator.Generate(
            input_signal_filepath=input_signal_filepath,
            test_data_cache_path=self._test_data_cache_path,
            base_output_path=self._base_output_path)
        self.assertTrue(os.path.exists(input_signal_filepath))

        # Check input signal properties.
        input_signal = signal_processing.SignalProcessingUtils.LoadWav(
            input_signal_filepath)
        self.assertEqual(1000, len(input_signal))

    def testTestDataGenerators(self):
        # Preliminary check.
        self.assertTrue(os.path.exists(self._base_output_path))
        self.assertTrue(os.path.exists(self._test_data_cache_path))

        # Check that there is at least one registered test data generator.
        registered_classes = (
            test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
        self.assertIsInstance(registered_classes, dict)
        self.assertGreater(len(registered_classes), 0)

        # Instance generators factory.
        generators_factory = (
            test_data_generation_factory.TestDataGeneratorFactory(
                output_directory_prefix='datagen-',
                aechen_ir_database_path=self._fake_air_db_path))

        # Use a sample input file as clean input signal. The path is relative
        # to the current working directory, so the test must be run from the
        # folder that contains probing_signals/.
        input_signal_filepath = os.path.join(
            os.getcwd(), 'probing_signals', 'tone-880.wav')
        self.assertTrue(os.path.exists(input_signal_filepath),
                        'missing input file: %s' % input_signal_filepath)

        # Load input signal.
        input_signal = signal_processing.SignalProcessingUtils.LoadWav(
            input_signal_filepath)

        # Try each registered test data generator; only the class is needed,
        # so iterate the dict values directly.
        for generator_class in registered_classes.values():
            # Instance test data generator.
            generator = generators_factory.GetInstance(generator_class)

            # Generate the noisy input - reference pairs.
            generator.Generate(
                input_signal_filepath=input_signal_filepath,
                test_data_cache_path=self._test_data_cache_path,
                base_output_path=self._base_output_path)

            # Perform checks.
            self._CheckGeneratedPairsListSizes(generator)
            self._CheckGeneratedPairsSignalDurations(generator, input_signal)
            self._CheckGeneratedPairsOutputPaths(generator)

    def _CheckGeneratedPairsListSizes(self, generator):
        """Checks that the generator produced one entry per config name.

        Args:
          generator: TestDataGenerator instance.
        """
        number_of_pairs = len(generator.config_names)
        self.assertEqual(number_of_pairs,
                         len(generator.noisy_signal_filepaths))
        self.assertEqual(number_of_pairs,
                         len(generator.apm_output_paths))
        self.assertEqual(number_of_pairs,
                         len(generator.reference_signal_filepaths))

    def _CheckGeneratedPairsSignalDurations(
            self, generator, input_signal):
        """Checks duration of the generated signals.

        Checks that the noisy input and the reference tracks are audio files
        with duration equal to or greater than that of the input signal.

        Args:
          generator: TestDataGenerator instance.
          input_signal: AudioSegment instance.
        """
        input_signal_length = (
            signal_processing.SignalProcessingUtils.CountSamples(input_signal))

        # Iterate over the noisy signal - reference pairs.
        for config_name in generator.config_names:
            # Load the noisy input file.
            noisy_signal_filepath = generator.noisy_signal_filepaths[
                config_name]
            noisy_signal = signal_processing.SignalProcessingUtils.LoadWav(
                noisy_signal_filepath)

            # Check noisy input signal length.
            noisy_signal_length = (
                signal_processing.SignalProcessingUtils.CountSamples(
                    noisy_signal))
            self.assertGreaterEqual(noisy_signal_length, input_signal_length)

            # Load the reference file.
            reference_signal_filepath = generator.reference_signal_filepaths[
                config_name]
            reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
                reference_signal_filepath)

            # Check reference signal length.
            reference_signal_length = (
                signal_processing.SignalProcessingUtils.CountSamples(
                    reference_signal))
            self.assertGreaterEqual(reference_signal_length, input_signal_length)

    def _CheckGeneratedPairsOutputPaths(self, generator):
        """Checks that the output path created by the generator exists.

        Args:
          generator: TestDataGenerator instance.
        """
        # Iterate over the noisy signal - reference pairs.
        for config_name in generator.config_names:
            output_path = generator.apm_output_paths[config_name]
            self.assertTrue(os.path.exists(output_path),
                            'missing output path: %s' % output_path)