Total Harmonic Distortion plus noise (THD+n) score in APM-QA.
In order to compute a THD score, a pure tone must be used as input signal.
Also, its frequency must be known. For this reason, this CL adds a number of
changes in the APM-QA pipeline. In more detail, input signal metadata is loaded
and passed to the THD evaluation score instance. This makes the eval_scores
module less reusable, but it is fine since the module has been specifically
designed for the APM-QA module.
BUG=webrtc:7494
Review-Url: https://codereview.webrtc.org/3010413002
Cr-Commit-Position: refs/heads/master@{#19970}
diff --git a/modules/audio_processing/test/py_quality_assessment/README.md b/modules/audio_processing/test/py_quality_assessment/README.md
index e19a780..79e1650 100644
--- a/modules/audio_processing/test/py_quality_assessment/README.md
+++ b/modules/audio_processing/test/py_quality_assessment/README.md
@@ -33,6 +33,12 @@
- Go to `out/Default/py_quality_assessment` and check that
`apm_quality_assessment.py` exists
+## Unit tests
+
+ - Compile WebRTC
+ - Go to `out/Default/py_quality_assessment`
+ - Run `python -m unittest discover -p "*_unittest.py"`
+
## First time setup
- Deploy PolqaOem64 and set the `POLQA_PATH` environment variable
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/data_access.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/data_access.py
index 826a089..c488859 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/data_access.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/data_access.py
@@ -31,9 +31,34 @@
def __init__(self):
pass
+ _GENERIC_METADATA_SUFFIX = '.mdata'
_AUDIO_TEST_DATA_FILENAME = 'audio_test_data.json'
@classmethod
+ def LoadFileMetadata(cls, filepath):
+ """Loads generic metadata linked to a file.
+
+ Args:
+ filepath: path to the linked file (the metadata suffix is appended to it).
+
+ Returns:
+ A dict.
+ """
+ with open(filepath + cls._GENERIC_METADATA_SUFFIX) as f:
+ return json.load(f)
+
+ @classmethod
+ def SaveFileMetadata(cls, filepath, metadata):
+ """Saves generic metadata linked to a file.
+
+ Args:
+ filepath: path to the linked file (the metadata suffix is appended to it).
+ metadata: a dict.
+ """
+ with open(filepath + cls._GENERIC_METADATA_SUFFIX, 'w') as f:
+ json.dump(metadata, f)
+
+ @classmethod
def LoadAudioTestDataPaths(cls, metadata_path):
"""Loads the input and the reference audio track paths.
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores.py
index 78d0c18..420afd2 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores.py
@@ -14,6 +14,13 @@
import os
import re
import subprocess
+import sys
+
+try:
+ import numpy as np
+except ImportError:
+ logging.critical('Cannot import the third-party Python package numpy')
+ sys.exit(1)
from . import data_access
from . import exceptions
@@ -27,6 +34,7 @@
def __init__(self, score_filename_prefix):
self._score_filename_prefix = score_filename_prefix
+ self._input_signal_metadata = None
self._reference_signal = None
self._reference_signal_filepath = None
self._tested_signal = None
@@ -56,8 +64,16 @@
def score(self):
return self._score
+ def SetInputSignalMetadata(self, metadata):
+ """Sets input signal metadata.
+
+ Args:
+ metadata: dict instance.
+ """
+ self._input_signal_metadata = metadata
+
def SetReferenceSignalFilepath(self, filepath):
- """ Sets the path to the audio track used as reference signal.
+ """Sets the path to the audio track used as reference signal.
Args:
filepath: path to the reference audio track.
@@ -65,7 +81,7 @@
self._reference_signal_filepath = filepath
def SetTestedSignalFilepath(self, filepath):
- """ Sets the path to the audio track used as test signal.
+ """Sets the path to the audio track used as test signal.
Args:
filepath: path to the test audio track.
@@ -242,3 +258,84 @@
# Build and return a dictionary with field names (header) as keys and the
# corresponding field values as values.
return {data[0][index]: data[1][index] for index in range(number_of_fields)}
+
+
+@EvaluationScore.RegisterClass
+class TotalHarmonicDistorsionScore(EvaluationScore):
+ """Total harmonic distortion plus noise score.
+
+ Ratio of the distortion-plus-noise term to the fundamental amplitude.
+ See "https://en.wikipedia.org/wiki/Total_harmonic_distortion#THD.2BN".
+
+ Unit: -.
+ Ideal: 0.
+ Worst case: +inf
+ """
+
+ NAME = 'thd'
+
+ def __init__(self, score_filename_prefix):
+ EvaluationScore.__init__(self, score_filename_prefix)
+ self._input_frequency = None
+
+ def _Run(self, output_path):
+ # TODO(aleloi): Integrate changes made locally.
+ self._CheckInputSignal()
+
+ self._LoadTestedSignal()
+ if self._tested_signal.channels != 1:
+ raise exceptions.EvaluationScoreException(
+ 'unsupported number of channels')
+ samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
+ self._tested_signal)
+
+ # Init.
+ num_samples = len(samples)
+ duration = len(self._tested_signal) / 1000.0
+ scaling = 2.0 / num_samples
+ max_freq = self._tested_signal.frame_rate / 2
+ f0_freq = float(self._input_frequency)
+ t = np.linspace(0, duration, num_samples)
+
+ # Analyze harmonics.
+ b_terms = []
+ n = 1
+ while f0_freq * n < max_freq:
+ x_n = np.sum(samples * np.sin(2.0 * np.pi * n * f0_freq * t)) * scaling
+ y_n = np.sum(samples * np.cos(2.0 * np.pi * n * f0_freq * t)) * scaling
+ b_terms.append(np.sqrt(x_n**2 + y_n**2))
+ n += 1
+
+ output_without_fundamental = samples - b_terms[0] * np.sin(
+ 2.0 * np.pi * f0_freq * t)
+ distortion_and_noise = np.sqrt(np.sum(
+ output_without_fundamental**2) * np.pi * scaling)
+
+ # TODO(alessiob): Fix or remove if not needed.
+ # thd = np.sqrt(np.sum(b_terms[1:]**2)) / b_terms[0]
+
+ # TODO(alessiob): Check the range of |thd_plus_noise| and update the class
+ # docstring above accordingly.
+ thd_plus_noise = distortion_and_noise / b_terms[0]
+
+ self._score = thd_plus_noise
+ self._SaveScore()
+
+ def _CheckInputSignal(self):
+ # Check input signal and get properties.
+ try:
+ if self._input_signal_metadata['signal'] != 'pure_tone':
+ raise exceptions.EvaluationScoreException(
+ 'The THD score requires a pure tone as input signal')
+ self._input_frequency = self._input_signal_metadata['frequency']
+ if self._input_signal_metadata['test_data_gen_name'] != 'identity' or (
+ self._input_signal_metadata['test_data_gen_config'] != 'default'):
+ raise exceptions.EvaluationScoreException(
+ 'The THD score cannot be used with any test data generator other '
+ 'than "identity"')
+ except TypeError:
+ raise exceptions.EvaluationScoreException(
+ 'The THD score requires an input signal with associated metadata')
+ except KeyError:
+ raise exceptions.EvaluationScoreException(
+ 'Invalid input signal metadata to compute the THD score')
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py
index b3bd4f9..ce51051 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py
@@ -52,6 +52,9 @@
shutil.rmtree(self._output_path)
def testRegisteredClasses(self):
+ # Evaluation score names to exclude (tested separately).
+ exceptions = ['thd']
+
# Preliminary check.
self.assertTrue(os.path.exists(self._output_path))
@@ -69,11 +72,14 @@
# Try each registered evaluation score worker.
for eval_score_name in registered_classes:
+ if eval_score_name in exceptions:
+ continue
+
# Instance evaluation score worker.
eval_score_worker = eval_score_workers_factory.GetInstance(
registered_classes[eval_score_name])
- # Set reference and test, then run.
+ # Set fake input metadata and reference and test file paths, then run.
eval_score_worker.SetReferenceSignalFilepath(
self._fake_reference_signal_filepath)
eval_score_worker.SetTestedSignalFilepath(
@@ -83,3 +89,43 @@
# Check output.
score = data_access.ScoreFile.Load(eval_score_worker.output_filepath)
self.assertTrue(isinstance(score, float))
+
+ def testTotalHarmonicDistorsionScore(self):
+ # Init.
+ pure_tone_freq = 5000.0
+ eval_score_worker = eval_scores.TotalHarmonicDistorsionScore('scores-')
+ eval_score_worker.SetInputSignalMetadata({
+ 'signal': 'pure_tone',
+ 'frequency': pure_tone_freq,
+ 'test_data_gen_name': 'identity',
+ 'test_data_gen_config': 'default',
+ })
+ template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
+
+ # Create 3 test signals: pure tone, pure tone + white noise, white noise
+ # only.
+ pure_tone = signal_processing.SignalProcessingUtils.GeneratePureTone(
+ template, pure_tone_freq)
+ white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
+ template)
+ noisy_tone = signal_processing.SignalProcessingUtils.MixSignals(
+ pure_tone, white_noise)
+
+ # Compute scores for increasingly distorted pure tone signals.
+ scores = [None, None, None]
+ for index, tested_signal in enumerate([pure_tone, noisy_tone, white_noise]):
+ # Save signal.
+ tmp_filepath = os.path.join(self._output_path, 'tmp_thd.wav')
+ signal_processing.SignalProcessingUtils.SaveWav(
+ tmp_filepath, tested_signal)
+
+ # Compute score.
+ eval_score_worker.SetTestedSignalFilepath(tmp_filepath)
+ eval_score_worker.Run(self._output_path)
+ scores[index] = eval_score_worker.score
+
+ # Remove output file to avoid caching.
+ os.remove(eval_score_worker.output_filepath)
+
+ # Validate scores (lowest score with a pure tone).
+ self.assertTrue(all([scores[i + 1] > scores[i] for i in range(2)]))
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/evaluation.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/evaluation.py
index e18f193..09ded4c 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/evaluation.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/evaluation.py
@@ -20,14 +20,15 @@
pass
@classmethod
- def Run(cls, evaluation_score_workers, apm_output_filepath,
- reference_input_filepath, output_path):
+ def Run(cls, evaluation_score_workers, apm_input_metadata,
+ apm_output_filepath, reference_input_filepath, output_path):
"""Runs the evaluation.
Iterates over the given evaluation score workers.
Args:
evaluation_score_workers: list of EvaluationScore instances.
+ apm_input_metadata: dictionary with metadata of the APM input.
apm_output_filepath: path to the audio track file with the APM output.
reference_input_filepath: path to the reference audio track file.
output_path: output path.
@@ -40,6 +41,7 @@
for evaluation_score_worker in evaluation_score_workers:
logging.info(' computing <%s> score', evaluation_score_worker.NAME)
+ evaluation_score_worker.SetInputSignalMetadata(apm_input_metadata)
evaluation_score_worker.SetReferenceSignalFilepath(
reference_input_filepath)
evaluation_score_worker.SetTestedSignalFilepath(
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py
index 0f7716a..b13b35b 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py
@@ -32,3 +32,9 @@
"""Input signal creator exeception.
"""
pass
+
+
+class EvaluationScoreException(Exception):
+ """Evaluation score exception.
+ """
+ pass
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_signal_creator.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_signal_creator.py
index e2a720c..5d97c3b 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_signal_creator.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_signal_creator.py
@@ -18,26 +18,36 @@
"""
@classmethod
- def Create(cls, name, params):
- """Creates a input signal.
+ def Create(cls, name, raw_params):
+ """Creates a input signal and its metadata.
Args:
name: Input signal creator name.
- params: Tuple of parameters to pass to the specific signal creator.
+ raw_params: Tuple of parameters to pass to the specific signal creator.
Returns:
- AudioSegment instance.
+ (AudioSegment, dict) tuple.
"""
try:
+ signal = {}
+ params = {}
+
if name == 'pure_tone':
- return cls._CreatePureTone(float(params[0]), int(params[1]))
+ params['frequency'] = float(raw_params[0])
+ params['duration'] = int(raw_params[1])
+ signal = cls._CreatePureTone(params['frequency'], params['duration'])
+ else:
+ raise exceptions.InputSignalCreatorException(
+ 'Invalid input signal creator name')
+
+ # Complete metadata.
+ params['signal'] = name
+
+ return signal, params
except (TypeError, AssertionError) as e:
raise exceptions.InputSignalCreatorException(
'Invalid signal creator parameters: {}'.format(e))
- raise exceptions.InputSignalCreatorException(
- 'Invalid input signal creator name')
-
@classmethod
def _CreatePureTone(cls, frequency, duration):
"""
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing.py
index 9a1f279..5beb3fb 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing.py
@@ -149,6 +149,13 @@
volume=0.0)
@classmethod
+ def AudioSegmentToRawData(cls, signal):
+ samples = signal.get_array_of_samples()
+ if samples.typecode != 'h':
+ raise exceptions.SignalProcessingException('Unsupported samples type')
+ return np.array(signal.get_array_of_samples(), np.int16)
+
+ @classmethod
def DetectHardClipping(cls, signal, threshold=2):
"""Detects hard clipping.
@@ -169,13 +176,7 @@
if signal.sample_width != 2: # Note that signal.sample_width is in bytes.
raise exceptions.SignalProcessingException(
'hard-clipping detection only supported for 16 bit samples')
-
- # Get raw samples, check type, cast.
- samples = signal.get_array_of_samples()
- if samples.typecode != 'h':
- raise exceptions.SignalProcessingException(
- 'hard-clipping detection only supported for 16 bit samples')
- samples = np.array(signal.get_array_of_samples(), np.int16)
+ samples = cls.AudioSegmentToRawData(signal)
# Detect adjacent clipped samples.
samples_type_info = np.iinfo(samples.dtype)
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py
index 7023b6a..b256940 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py
@@ -17,6 +17,7 @@
from . import echo_path_simulation_factory
from . import eval_scores
from . import eval_scores_factory
+from . import exceptions
from . import input_mixer
from . import test_data_generation
from . import test_data_generation_factory
@@ -248,9 +249,20 @@
test_data_cache_path=test_data_cache_path,
base_output_path=output_path)
+ # Extract metadata linked to the clean input file (if any).
+ apm_input_metadata = None
+ try:
+ apm_input_metadata = data_access.Metadata.LoadFileMetadata(
+ clean_capture_input_filepath)
+ except IOError as e:
+ apm_input_metadata = {}
+ apm_input_metadata['test_data_gen_name'] = test_data_generators.NAME
+ apm_input_metadata['test_data_gen_config'] = None
+
# For each test data pair, simulate a call and evaluate.
for config_name in test_data_generators.config_names:
logging.info(' - test data generator config: <%s>', config_name)
+ apm_input_metadata['test_data_gen_config'] = config_name
# Paths to the test data generator output.
# Note that the reference signal does not depend on the render input
@@ -278,23 +290,28 @@
render_input_filepath=render_input_filepath,
output_path=evaluation_output_path)
- # Evaluate.
- self._evaluator.Run(
- evaluation_score_workers=self._evaluation_score_workers,
- apm_output_filepath=self._audioproc_wrapper.output_filepath,
- reference_input_filepath=reference_signal_filepath,
- output_path=evaluation_output_path)
+ try:
+ # Evaluate.
+ self._evaluator.Run(
+ evaluation_score_workers=self._evaluation_score_workers,
+ apm_input_metadata=apm_input_metadata,
+ apm_output_filepath=self._audioproc_wrapper.output_filepath,
+ reference_input_filepath=reference_signal_filepath,
+ output_path=evaluation_output_path)
- # Save simulation metadata.
- data_access.Metadata.SaveAudioTestDataPaths(
- output_path=evaluation_output_path,
- clean_capture_input_filepath=clean_capture_input_filepath,
- echo_free_capture_filepath=noisy_capture_input_filepath,
- echo_filepath=echo_path_filepath,
- render_filepath=render_input_filepath,
- capture_filepath=apm_input_filepath,
- apm_output_filepath=self._audioproc_wrapper.output_filepath,
- apm_reference_filepath=reference_signal_filepath)
+ # Save simulation metadata.
+ data_access.Metadata.SaveAudioTestDataPaths(
+ output_path=evaluation_output_path,
+ clean_capture_input_filepath=clean_capture_input_filepath,
+ echo_free_capture_filepath=noisy_capture_input_filepath,
+ echo_filepath=echo_path_filepath,
+ render_filepath=render_input_filepath,
+ capture_filepath=apm_input_filepath,
+ apm_output_filepath=self._audioproc_wrapper.output_filepath,
+ apm_reference_filepath=reference_signal_filepath)
+ except exceptions.EvaluationScoreException as e:
+ logging.warning('the evaluation failed: %s', e.message)
+ continue
def _SetTestInputSignalFilePaths(self, capture_input_filepaths,
render_input_filepaths):
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py
index 544ad97..33ee921 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py
@@ -9,6 +9,7 @@
"""Unit tests for the simulation module.
"""
+import logging
import os
import shutil
import sys
@@ -33,8 +34,9 @@
"""
def setUp(self):
- """Create temporary folder and fake audio track."""
+ """Create temporary folders and fake audio track."""
self._output_path = tempfile.mkdtemp()
+ self._tmp_path = tempfile.mkdtemp()
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
@@ -46,6 +48,7 @@
def tearDown(self):
"""Recursively delete temporary folders."""
shutil.rmtree(self._output_path)
+ shutil.rmtree(self._tmp_path)
def testSimulation(self):
# Instance dependencies to inject and mock.
@@ -87,3 +90,39 @@
min_number_of_simulations)
self.assertGreaterEqual(len(evaluator.Run.call_args_list),
min_number_of_simulations)
+
+ def testPureToneGenerationWithTotalHarmonicDistorsion(self):
+ logging.warning = mock.MagicMock(name='warning')
+
+ # Instance simulator.
+ simulator = simulation.ApmModuleSimulator(
+ aechen_ir_database_path='',
+ polqa_tool_bin_path=os.path.join(
+ os.path.dirname(__file__), 'fake_polqa'),
+ ap_wrapper=audioproc_wrapper.AudioProcWrapper(),
+ evaluator=evaluation.ApmModuleEvaluator())
+
+ # What to simulate.
+ config_files = ['apm_configs/default.json']
+ input_files = [os.path.join(self._tmp_path, 'pure_tone-440_1000.wav')]
+ eval_scores = ['thd']
+
+ # Should work.
+ simulator.Run(
+ config_filepaths=config_files,
+ capture_input_filepaths=input_files,
+ test_data_generator_names=['identity'],
+ eval_score_names=eval_scores,
+ output_dir=self._output_path)
+ self.assertFalse(logging.warning.called)
+
+ # Warning expected.
+ simulator.Run(
+ config_filepaths=config_files,
+ capture_input_filepaths=input_files,
+ test_data_generator_names=['white_noise'], # Not allowed with THD.
+ eval_score_names=eval_scores,
+ output_dir=self._output_path)
+ logging.warning.assert_called_with('the evaluation failed: %s', (
+ 'The THD score cannot be used with any test data generator other than '
+ '"identity"'))
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py
index 3d54da5..4153f73 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py
@@ -147,11 +147,12 @@
raise exceptions.InputSignalCreatorException(
'Cannot parse input signal file name')
- signal = input_signal_creator.InputSignalCreator.Create(
+ signal, metadata = input_signal_creator.InputSignalCreator.Create(
filename_parts[0], filename_parts[1].split('_'))
signal_processing.SignalProcessingUtils.SaveWav(
input_signal_filepath, signal)
+ data_access.Metadata.SaveFileMetadata(input_signal_filepath, metadata)
def _Generate(
self, input_signal_filepath, test_data_cache_path, base_output_path):