blob: 5261dd25f4cae1ec3393379ed47ca242ab1af37e [file] [log] [blame]
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the test_data_generation module.
"""
import os
import shutil
import tempfile
import unittest
from . import test_data_generation
from . import test_data_generation_factory
from . import signal_processing
class TestTestDataGenerators(unittest.TestCase):
    """Unit tests for the test_data_generation module.

    Every registered test data generator, with the exception of
    ReverberationTestDataGenerator, is run on a sample input signal; the
    file paths it exposes afterwards are then validated.
    """

    def setUp(self):
        """Creates the temporary folders and schedules their removal.

        Each removal is registered through addCleanup() right after the
        folder is created. This way no folder is leaked if a later step of
        setUp() fails, and a failure while removing one folder does not
        prevent the removal of the other one.
        """
        self._base_output_path = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self._base_output_path)
        self._input_noise_cache_path = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self._input_noise_cache_path)

    def testTestDataGenerators(self):
        """Runs the registered generators and checks what they produce."""
        # Preliminary check.
        self.assertTrue(os.path.exists(self._base_output_path))
        self.assertTrue(os.path.exists(self._input_noise_cache_path))

        # Check that there is at least one registered test data generator.
        registered_classes = (
            test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
        self.assertIsInstance(registered_classes, dict)
        self.assertGreater(len(registered_classes), 0)

        # Instance generators factory.
        generators_factory = (
            test_data_generation_factory.TestDataGeneratorFactory(
                aechen_ir_database_path=''))
        # TODO(alessiob): Replace with a mock of TestDataGeneratorFactory
        # that takes no arguments in the ctor. For those generators that
        # need parameters, it will return a mock generator (see the first
        # comment in the next for loop).

        # Use a sample input file as clean input signal. The path is
        # relative to the current working directory, hence the test must be
        # launched from the folder that contains 'probing_signals'.
        input_signal_filepath = os.path.join(
            os.getcwd(), 'probing_signals', 'tone-880.wav')
        self.assertTrue(
            os.path.exists(input_signal_filepath),
            'Input signal not found (looked up from the current working '
            'directory): %s' % input_signal_filepath)

        # Load input signal.
        input_signal = signal_processing.SignalProcessingUtils.LoadWav(
            input_signal_filepath)

        # Try each registered test data generator.
        for generator_name, generator_class in registered_classes.items():
            # Exclude ReverberationTestDataGenerator.
            # TODO(alessiob): Mock ReverberationTestDataGenerator, the mock
            # should rely on hard-coded impulse responses. This requires a
            # mock for TestDataGeneratorFactory. The latter knows whether
            # returning the actual generator or a mock object (as in the
            # case of ReverberationTestDataGenerator).
            if generator_name == (
                    test_data_generation.ReverberationTestDataGenerator.NAME):
                continue

            # Instance test data generator.
            generator = generators_factory.GetInstance(generator_class)

            # Generate the noisy input - reference pairs.
            generator.Generate(
                input_signal_filepath=input_signal_filepath,
                input_noise_cache_path=self._input_noise_cache_path,
                base_output_path=self._base_output_path)

            # Perform checks.
            self._CheckGeneratedPairsListSizes(generator)
            self._CheckGeneratedPairsSignalDurations(generator, input_signal)
            self._CheckGeneratedPairsOutputPaths(generator)

    def _CheckGeneratedPairsListSizes(self, generator):
        """Checks that the generator exposes one item per configuration.

        The number of noisy signal file paths, APM output paths and
        reference signal file paths must all be equal to the number of
        configuration names.

        Args:
          generator: TestDataGenerator instance.
        """
        number_of_pairs = len(generator.config_names)
        self.assertEqual(number_of_pairs,
                         len(generator.noisy_signal_filepaths))
        self.assertEqual(number_of_pairs,
                         len(generator.apm_output_paths))
        self.assertEqual(number_of_pairs,
                         len(generator.reference_signal_filepaths))

    def _CheckGeneratedPairsSignalDurations(
            self, generator, input_signal):
        """Checks duration of the generated signals.

        Checks that the noisy input and the reference tracks are audio files
        with duration equal to or greater than that of the input signal.

        Args:
          generator: TestDataGenerator instance.
          input_signal: AudioSegment instance.
        """
        utils = signal_processing.SignalProcessingUtils
        input_signal_length = utils.CountSamples(input_signal)

        # Iterate over the noisy signal - reference pairs.
        for config_name in generator.config_names:
            # For each pair, check the noisy input signal length first and
            # the reference signal length afterwards.
            for signal_filepaths in (generator.noisy_signal_filepaths,
                                     generator.reference_signal_filepaths):
                signal = utils.LoadWav(signal_filepaths[config_name])
                self.assertGreaterEqual(
                    utils.CountSamples(signal), input_signal_length)

    def _CheckGeneratedPairsOutputPaths(self, generator):
        """Checks that the output path created by the generator exists.

        Args:
          generator: TestDataGenerator instance.
        """
        # Iterate over the noisy signal - reference pairs.
        for config_name in generator.config_names:
            output_path = generator.apm_output_paths[config_name]
            self.assertTrue(os.path.exists(output_path))