Implement a Neon optimized function to find the max-abs element in an array.
Finding the array element with the largest absolute value is a fairly common
operation, so it makes sense to have a Neon optimized version. The
implementation is done by first finding both the min and max value, and
then returning whichever has the largest absolute value.
Bug: chromium:12355
Change-Id: I088bd4f7d469b2424a7265de10fffb42764567a1
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/201622
Commit-Queue: Ivo Creusen <ivoc@webrtc.org>
Reviewed-by: Per Åhgren <peah@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#33052}
diff --git a/common_audio/signal_processing/include/signal_processing_library.h b/common_audio/signal_processing/include/signal_processing_library.h
index 4ad92c4..0c13071 100644
--- a/common_audio/signal_processing/include/signal_processing_library.h
+++ b/common_audio/signal_processing/include/signal_processing_library.h
@@ -228,6 +228,25 @@
int32_t WebRtcSpl_MinValueW32_mips(const int32_t* vector, size_t length);
#endif
+// Returns both the minimum and maximum values of a 16-bit vector.
+//
+// Input:
+// - vector : 16-bit input vector.
+// - length : Number of samples in vector.
+// Output:
+// - max_val : Maximum sample value in |vector|.
+// - min_val : Minimum sample value in |vector|.
+void WebRtcSpl_MinMaxW16(const int16_t* vector,
+ size_t length,
+ int16_t* min_val,
+ int16_t* max_val);
+#if defined(WEBRTC_HAS_NEON)
+void WebRtcSpl_MinMaxW16Neon(const int16_t* vector,
+ size_t length,
+ int16_t* min_val,
+ int16_t* max_val);
+#endif
+
// Returns the vector index to the largest absolute value of a 16-bit vector.
//
// Input:
@@ -240,6 +259,17 @@
// -32768 presenting an int16 absolute value of 32767).
size_t WebRtcSpl_MaxAbsIndexW16(const int16_t* vector, size_t length);
+// Returns the element with the largest absolute value of a 16-bit vector. Note
+// that this function can return a negative value.
+//
+// Input:
+// - vector : 16-bit input vector.
+// - length : Number of samples in vector.
+//
+// Return value : The element with the largest absolute value. Note that this
+// may be a negative value.
+int16_t WebRtcSpl_MaxAbsElementW16(const int16_t* vector, size_t length);
+
// Returns the vector index to the maximum sample value of a 16-bit vector.
//
// Input:
diff --git a/common_audio/signal_processing/min_max_operations.c b/common_audio/signal_processing/min_max_operations.c
index d249a02..1b9542e 100644
--- a/common_audio/signal_processing/min_max_operations.c
+++ b/common_audio/signal_processing/min_max_operations.c
@@ -155,6 +155,15 @@
return index;
}
+int16_t WebRtcSpl_MaxAbsElementW16(const int16_t* vector, size_t length) {
+ int16_t min_val, max_val;
+ WebRtcSpl_MinMaxW16(vector, length, &min_val, &max_val);
+ if (min_val == max_val || min_val < -max_val) {
+ return min_val;
+ }
+ return max_val;
+}
+
// Index of maximum value in a word16 vector.
size_t WebRtcSpl_MaxIndexW16(const int16_t* vector, size_t length) {
size_t i = 0, index = 0;
@@ -222,3 +231,26 @@
return index;
}
+
+// Finds both the minimum and maximum elements in an array of 16-bit integers.
+void WebRtcSpl_MinMaxW16(const int16_t* vector, size_t length,
+ int16_t* min_val, int16_t* max_val) {
+#if defined(WEBRTC_HAS_NEON)
+ return WebRtcSpl_MinMaxW16Neon(vector, length, min_val, max_val);
+#else
+ int16_t minimum = WEBRTC_SPL_WORD16_MAX;
+ int16_t maximum = WEBRTC_SPL_WORD16_MIN;
+ size_t i = 0;
+
+ RTC_DCHECK_GT(length, 0);
+
+ for (i = 0; i < length; i++) {
+ if (vector[i] < minimum)
+ minimum = vector[i];
+ if (vector[i] > maximum)
+ maximum = vector[i];
+ }
+ *min_val = minimum;
+ *max_val = maximum;
+#endif
+}
diff --git a/common_audio/signal_processing/min_max_operations_neon.c b/common_audio/signal_processing/min_max_operations_neon.c
index 53217df..e5b4b7c 100644
--- a/common_audio/signal_processing/min_max_operations_neon.c
+++ b/common_audio/signal_processing/min_max_operations_neon.c
@@ -281,3 +281,53 @@
return minimum;
}
+// Finds both the minimum and maximum elements in an array of 16-bit integers.
+void WebRtcSpl_MinMaxW16Neon(const int16_t* vector, size_t length,
+ int16_t* min_val, int16_t* max_val) {
+ int16_t minimum = WEBRTC_SPL_WORD16_MAX;
+ int16_t maximum = WEBRTC_SPL_WORD16_MIN;
+ size_t i = 0;
+ size_t residual = length & 0x7;
+
+ RTC_DCHECK_GT(length, 0);
+
+ const int16_t* p_start = vector;
+ int16x8_t min16x8 = vdupq_n_s16(WEBRTC_SPL_WORD16_MAX);
+ int16x8_t max16x8 = vdupq_n_s16(WEBRTC_SPL_WORD16_MIN);
+
+ // First part, unroll the loop 8 times.
+ for (i = 0; i < length - residual; i += 8) {
+ int16x8_t in16x8 = vld1q_s16(p_start);
+ min16x8 = vminq_s16(min16x8, in16x8);
+ max16x8 = vmaxq_s16(max16x8, in16x8);
+ p_start += 8;
+ }
+
+#if defined(WEBRTC_ARCH_ARM64)
+ minimum = vminvq_s16(min16x8);
+ maximum = vmaxvq_s16(max16x8);
+#else
+ int16x4_t min16x4 = vmin_s16(vget_low_s16(min16x8), vget_high_s16(min16x8));
+ min16x4 = vpmin_s16(min16x4, min16x4);
+ min16x4 = vpmin_s16(min16x4, min16x4);
+
+ minimum = vget_lane_s16(min16x4, 0);
+
+ int16x4_t max16x4 = vmax_s16(vget_low_s16(max16x8), vget_high_s16(max16x8));
+ max16x4 = vpmax_s16(max16x4, max16x4);
+ max16x4 = vpmax_s16(max16x4, max16x4);
+
+ maximum = vget_lane_s16(max16x4, 0);
+#endif
+
+ // Second part, do the remaining iterations (if any).
+ for (i = residual; i > 0; i--) {
+ if (*p_start < minimum)
+ minimum = *p_start;
+ if (*p_start > maximum)
+ maximum = *p_start;
+ p_start++;
+ }
+ *min_val = minimum;
+ *max_val = maximum;
+}
diff --git a/common_audio/signal_processing/signal_processing_unittest.cc b/common_audio/signal_processing/signal_processing_unittest.cc
index 3106c47..9ec8590 100644
--- a/common_audio/signal_processing/signal_processing_unittest.cc
+++ b/common_audio/signal_processing/signal_processing_unittest.cc
@@ -289,6 +289,12 @@
WebRtcSpl_MinValueW32(vector32, kVectorSize));
EXPECT_EQ(kVectorSize - 1, WebRtcSpl_MinIndexW16(vector16, kVectorSize));
EXPECT_EQ(kVectorSize - 1, WebRtcSpl_MinIndexW32(vector32, kVectorSize));
+ EXPECT_EQ(WEBRTC_SPL_WORD16_MIN,
+ WebRtcSpl_MaxAbsElementW16(vector16, kVectorSize));
+ int16_t min_value, max_value;
+ WebRtcSpl_MinMaxW16(vector16, kVectorSize, &min_value, &max_value);
+ EXPECT_EQ(WEBRTC_SPL_WORD16_MIN, min_value);
+ EXPECT_EQ(12334, max_value);
// Test the cases where maximum values have to be caught
// outside of the unrolled loops in ARM-Neon.
@@ -306,6 +312,11 @@
EXPECT_EQ(kVectorSize - 1, WebRtcSpl_MaxAbsIndexW16(vector16, kVectorSize));
EXPECT_EQ(kVectorSize - 1, WebRtcSpl_MaxIndexW16(vector16, kVectorSize));
EXPECT_EQ(kVectorSize - 1, WebRtcSpl_MaxIndexW32(vector32, kVectorSize));
+ EXPECT_EQ(WEBRTC_SPL_WORD16_MAX,
+ WebRtcSpl_MaxAbsElementW16(vector16, kVectorSize));
+ WebRtcSpl_MinMaxW16(vector16, kVectorSize, &min_value, &max_value);
+ EXPECT_EQ(-29871, min_value);
+ EXPECT_EQ(WEBRTC_SPL_WORD16_MAX, max_value);
// Test the cases where multiple maximum and minimum values are present.
vector16[1] = WEBRTC_SPL_WORD16_MAX;
@@ -332,6 +343,43 @@
EXPECT_EQ(1u, WebRtcSpl_MaxIndexW32(vector32, kVectorSize));
EXPECT_EQ(6u, WebRtcSpl_MinIndexW16(vector16, kVectorSize));
EXPECT_EQ(6u, WebRtcSpl_MinIndexW32(vector32, kVectorSize));
+ EXPECT_EQ(WEBRTC_SPL_WORD16_MIN,
+ WebRtcSpl_MaxAbsElementW16(vector16, kVectorSize));
+ WebRtcSpl_MinMaxW16(vector16, kVectorSize, &min_value, &max_value);
+ EXPECT_EQ(WEBRTC_SPL_WORD16_MIN, min_value);
+ EXPECT_EQ(WEBRTC_SPL_WORD16_MAX, max_value);
+
+ // Test a one-element vector.
+ int16_t single_element_vector = 0;
+ EXPECT_EQ(0, WebRtcSpl_MaxAbsValueW16(&single_element_vector, 1));
+ EXPECT_EQ(0, WebRtcSpl_MaxValueW16(&single_element_vector, 1));
+ EXPECT_EQ(0, WebRtcSpl_MinValueW16(&single_element_vector, 1));
+ EXPECT_EQ(0u, WebRtcSpl_MaxAbsIndexW16(&single_element_vector, 1));
+ EXPECT_EQ(0u, WebRtcSpl_MaxIndexW16(&single_element_vector, 1));
+ EXPECT_EQ(0u, WebRtcSpl_MinIndexW16(&single_element_vector, 1));
+ EXPECT_EQ(0, WebRtcSpl_MaxAbsElementW16(&single_element_vector, 1));
+ WebRtcSpl_MinMaxW16(&single_element_vector, 1, &min_value, &max_value);
+ EXPECT_EQ(0, min_value);
+ EXPECT_EQ(0, max_value);
+
+ // Test a two-element vector with the values WEBRTC_SPL_WORD16_MIN and
+ // WEBRTC_SPL_WORD16_MAX.
+ int16_t two_element_vector[2] = {WEBRTC_SPL_WORD16_MIN,
+ WEBRTC_SPL_WORD16_MAX};
+ EXPECT_EQ(WEBRTC_SPL_WORD16_MAX,
+ WebRtcSpl_MaxAbsValueW16(two_element_vector, 2));
+ EXPECT_EQ(WEBRTC_SPL_WORD16_MAX,
+ WebRtcSpl_MaxValueW16(two_element_vector, 2));
+ EXPECT_EQ(WEBRTC_SPL_WORD16_MIN,
+ WebRtcSpl_MinValueW16(two_element_vector, 2));
+ EXPECT_EQ(0u, WebRtcSpl_MaxAbsIndexW16(two_element_vector, 2));
+ EXPECT_EQ(1u, WebRtcSpl_MaxIndexW16(two_element_vector, 2));
+ EXPECT_EQ(0u, WebRtcSpl_MinIndexW16(two_element_vector, 2));
+ EXPECT_EQ(WEBRTC_SPL_WORD16_MIN,
+ WebRtcSpl_MaxAbsElementW16(two_element_vector, 2));
+ WebRtcSpl_MinMaxW16(two_element_vector, 2, &min_value, &max_value);
+ EXPECT_EQ(WEBRTC_SPL_WORD16_MIN, min_value);
+ EXPECT_EQ(WEBRTC_SPL_WORD16_MAX, max_value);
}
TEST(SplTest, VectorOperationsTest) {
diff --git a/modules/audio_coding/codecs/ilbc/enhancer_interface.c b/modules/audio_coding/codecs/ilbc/enhancer_interface.c
index 71436c2..ca23e19 100644
--- a/modules/audio_coding/codecs/ilbc/enhancer_interface.c
+++ b/modules/audio_coding/codecs/ilbc/enhancer_interface.c
@@ -205,9 +205,9 @@
/* scaling */
// Note that this is not abs-max, so we will take the absolute value below.
- max16 = regressor[WebRtcSpl_MaxAbsIndexW16(regressor, plc_blockl + 3 - 1)];
+ max16 = WebRtcSpl_MaxAbsElementW16(regressor, plc_blockl + 3 - 1);
const int16_t max_target =
- target[WebRtcSpl_MaxAbsIndexW16(target, plc_blockl + 3 - 1)];
+ WebRtcSpl_MaxAbsElementW16(target, plc_blockl + 3 - 1);
const int64_t max_val = plc_blockl * abs(max16 * max_target);
const int32_t factor = max_val >> 31;
shifts = factor == 0 ? 0 : 31 - WebRtcSpl_NormW32(factor);
diff --git a/modules/audio_coding/neteq/cross_correlation.cc b/modules/audio_coding/neteq/cross_correlation.cc
index 7ee867a..37ed937 100644
--- a/modules/audio_coding/neteq/cross_correlation.cc
+++ b/modules/audio_coding/neteq/cross_correlation.cc
@@ -25,22 +25,23 @@
size_t cross_correlation_length,
int cross_correlation_step,
int32_t* cross_correlation) {
- // Find the maximum absolute value of sequence_1 and 2.
- const int32_t max_1 =
- abs(sequence_1[WebRtcSpl_MaxAbsIndexW16(sequence_1, sequence_1_length)]);
+ // Find the element that has the maximum absolute value of sequence_1 and 2.
+ // Note that these values may be negative.
+ const int16_t max_1 =
+ WebRtcSpl_MaxAbsElementW16(sequence_1, sequence_1_length);
const int sequence_2_shift =
cross_correlation_step * (static_cast<int>(cross_correlation_length) - 1);
const int16_t* sequence_2_start =
sequence_2_shift >= 0 ? sequence_2 : sequence_2 + sequence_2_shift;
const size_t sequence_2_length =
sequence_1_length + std::abs(sequence_2_shift);
- const int32_t max_2 = abs(sequence_2_start[WebRtcSpl_MaxAbsIndexW16(
- sequence_2_start, sequence_2_length)]);
+ const int16_t max_2 =
+ WebRtcSpl_MaxAbsElementW16(sequence_2_start, sequence_2_length);
// In order to avoid overflow when computing the sum we should scale the
// samples so that (in_vector_length * max_1 * max_2) will not overflow.
const int64_t max_value =
- max_1 * max_2 * static_cast<int64_t>(sequence_1_length);
+ abs(max_1 * max_2) * static_cast<int64_t>(sequence_1_length);
const int32_t factor = max_value >> 31;
const int scaling = factor == 0 ? 0 : 31 - WebRtcSpl_NormW32(factor);