# blob: 12e043320e66d337221a1026a75040c03ddd1202
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the eval_scores module.
"""
import os
import shutil
import tempfile
import unittest
import pydub
from . import data_access
from . import eval_scores
from . import eval_scores_factory
from . import signal_processing
class TestEvalScores(unittest.TestCase):
    """Unit tests for the eval_scores module.

    Every test gets a fresh temporary output folder holding two fake
    one-second white noise tracks: a reference one and a tested one.
    """

    def setUp(self):
        """Create temporary output folder and two audio track files."""
        self._output_path = tempfile.mkdtemp()
        # Register the removal right away instead of overriding tearDown():
        # unittest does not call tearDown() when setUp() raises, whereas
        # registered cleanups always run, so the folder is not leaked if
        # generating or saving the fake signals below fails.
        self.addCleanup(shutil.rmtree, self._output_path)

        # Create fake reference and tested (i.e., APM output) audio track
        # files.
        silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
        fake_reference_signal = (
            signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
                silence))
        fake_tested_signal = (
            signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
                silence))

        # Save fake audio tracks.
        self._fake_reference_signal_filepath = os.path.join(
            self._output_path, 'fake_ref.wav')
        signal_processing.SignalProcessingUtils.SaveWav(
            self._fake_reference_signal_filepath, fake_reference_signal)
        self._fake_tested_signal_filepath = os.path.join(
            self._output_path, 'fake_test.wav')
        signal_processing.SignalProcessingUtils.SaveWav(
            self._fake_tested_signal_filepath, fake_tested_signal)

    def testRegisteredClasses(self):
        """Run each registered evaluation score worker on the fake tracks.

        The 'thd' and 'echo_metric' scores are skipped here. Every other
        registered worker must write a score file that loads as a float.
        """
        # Evaluation score names to exclude (tested separately).
        exceptions = ('thd', 'echo_metric')

        # Preliminary check.
        self.assertTrue(os.path.exists(self._output_path))

        # Check that there is at least one registered evaluation score worker.
        registered_classes = eval_scores.EvaluationScore.REGISTERED_CLASSES
        self.assertIsInstance(registered_classes, dict)
        self.assertGreater(len(registered_classes), 0)

        # Instance evaluation score workers factory with fake dependencies.
        eval_score_workers_factory = (
            eval_scores_factory.EvaluationScoreWorkerFactory(
                polqa_tool_bin_path=os.path.join(
                    os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'),
                echo_metric_tool_bin_path=None))
        eval_score_workers_factory.SetScoreFilenamePrefix('scores-')

        # Try each registered evaluation score worker.
        for eval_score_name, eval_score_class in registered_classes.items():
            if eval_score_name in exceptions:
                continue

            # Instance evaluation score worker.
            eval_score_worker = eval_score_workers_factory.GetInstance(
                eval_score_class)

            # Set reference and test file paths, then run.
            eval_score_worker.SetReferenceSignalFilepath(
                self._fake_reference_signal_filepath)
            eval_score_worker.SetTestedSignalFilepath(
                self._fake_tested_signal_filepath)
            eval_score_worker.Run(self._output_path)

            # Check output.
            score = data_access.ScoreFile.Load(
                eval_score_worker.output_filepath)
            self.assertIsInstance(score, float)

    def testTotalHarmonicDistorsionScore(self):
        """Check that the THD score increases with the signal distortion.

        Three signals are scored in order: a pure tone, the same tone mixed
        with white noise, and white noise only. Each score must be strictly
        greater than the previous one.
        """
        # Init.
        pure_tone_freq = 5000.0
        eval_score_worker = eval_scores.TotalHarmonicDistorsionScore('scores-')
        eval_score_worker.SetInputSignalMetadata({
            'signal':
            'pure_tone',
            'frequency':
            pure_tone_freq,
            'test_data_gen_name':
            'identity',
            'test_data_gen_config':
            'default',
        })
        template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)

        # Create 3 test signals: pure tone, pure tone + white noise, white
        # noise only.
        pure_tone = signal_processing.SignalProcessingUtils.GeneratePureTone(
            template, pure_tone_freq)
        white_noise = (
            signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
                template))
        noisy_tone = signal_processing.SignalProcessingUtils.MixSignals(
            pure_tone, white_noise)

        # Compute scores for increasingly distorted pure tone signals. The
        # same file path is reused for every signal.
        tmp_filepath = os.path.join(self._output_path, 'tmp_thd.wav')
        labelled_signals = (
            ('pure tone', pure_tone),
            ('pure tone + white noise', noisy_tone),
            ('white noise', white_noise),
        )
        labelled_scores = []
        for label, tested_signal in labelled_signals:
            # Save signal.
            signal_processing.SignalProcessingUtils.SaveWav(
                tmp_filepath, tested_signal)

            # Compute score.
            eval_score_worker.SetTestedSignalFilepath(tmp_filepath)
            eval_score_worker.Run(self._output_path)
            labelled_scores.append((label, eval_score_worker.score))

            # Remove output file to avoid caching.
            os.remove(eval_score_worker.output_filepath)

        # Validate scores (lowest score with a pure tone): compare each
        # signal with the next, more distorted, one so that a failure names
        # the offending pair.
        for cleaner, noisier in zip(labelled_scores, labelled_scores[1:]):
            self.assertGreater(
                noisier[1], cleaner[1],
                'THD score for "%s" (%r) is not greater than for "%s" (%r)' %
                (noisier[0], noisier[1], cleaner[0], cleaner[1]))