Implement a Neon optimized function to find the argmax element in an array.

Finding the array element with the largest argmax is a fairly common
operation, so it makes sense to have a Neon optimized version. The
implementation is done by first finding both the min and max value, and
then returning whichever has the largest argmax.

Bug: chromium:12355
Change-Id: I088bd4f7d469b2424a7265de10fffb42764567a1
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/201622
Commit-Queue: Ivo Creusen <ivoc@webrtc.org>
Reviewed-by: Per Ã…hgren <peah@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#33052}
diff --git a/common_audio/signal_processing/include/signal_processing_library.h b/common_audio/signal_processing/include/signal_processing_library.h
index 4ad92c4..0c13071 100644
--- a/common_audio/signal_processing/include/signal_processing_library.h
+++ b/common_audio/signal_processing/include/signal_processing_library.h
@@ -228,6 +228,25 @@
 int32_t WebRtcSpl_MinValueW32_mips(const int32_t* vector, size_t length);
 #endif
 
+// Returns both the minimum and maximum values of a 16-bit vector.
+//
+// Input:
+//      - vector : 16-bit input vector.
+//      - length : Number of samples in vector.
+// Ouput:
+//      - max_val : Maximum sample value in |vector|.
+//      - min_val : Minimum sample value in |vector|.
+void WebRtcSpl_MinMaxW16(const int16_t* vector,
+                         size_t length,
+                         int16_t* min_val,
+                         int16_t* max_val);
+#if defined(WEBRTC_HAS_NEON)
+void WebRtcSpl_MinMaxW16Neon(const int16_t* vector,
+                             size_t length,
+                             int16_t* min_val,
+                             int16_t* max_val);
+#endif
+
 // Returns the vector index to the largest absolute value of a 16-bit vector.
 //
 // Input:
@@ -240,6 +259,17 @@
 //                 -32768 presenting an int16 absolute value of 32767).
 size_t WebRtcSpl_MaxAbsIndexW16(const int16_t* vector, size_t length);
 
+// Returns the element with the largest absolute value of a 16-bit vector. Note
+// that this function can return a negative value.
+//
+// Input:
+//      - vector : 16-bit input vector.
+//      - length : Number of samples in vector.
+//
+// Return value  : The element with the largest absolute value. Note that this
+//                 may be a negative value.
+int16_t WebRtcSpl_MaxAbsElementW16(const int16_t* vector, size_t length);
+
 // Returns the vector index to the maximum sample value of a 16-bit vector.
 //
 // Input:
diff --git a/common_audio/signal_processing/min_max_operations.c b/common_audio/signal_processing/min_max_operations.c
index d249a02..1b9542e 100644
--- a/common_audio/signal_processing/min_max_operations.c
+++ b/common_audio/signal_processing/min_max_operations.c
@@ -155,6 +155,15 @@
   return index;
 }
 
+int16_t WebRtcSpl_MaxAbsElementW16(const int16_t* vector, size_t length) {
+  int16_t min_val, max_val;
+  WebRtcSpl_MinMaxW16(vector, length, &min_val, &max_val);
+  if (min_val == max_val || min_val < -max_val) {
+    return min_val;
+  }
+  return max_val;
+}
+
 // Index of maximum value in a word16 vector.
 size_t WebRtcSpl_MaxIndexW16(const int16_t* vector, size_t length) {
   size_t i = 0, index = 0;
@@ -222,3 +231,26 @@
 
   return index;
 }
+
+// Finds both the minimum and maximum elements in an array of 16-bit integers.
+void WebRtcSpl_MinMaxW16(const int16_t* vector, size_t length,
+                         int16_t* min_val, int16_t* max_val) {
+#if defined(WEBRTC_HAS_NEON)
+  return WebRtcSpl_MinMaxW16Neon(vector, length, min_val, max_val);
+#else
+  int16_t minimum = WEBRTC_SPL_WORD16_MAX;
+  int16_t maximum = WEBRTC_SPL_WORD16_MIN;
+  size_t i = 0;
+
+  RTC_DCHECK_GT(length, 0);
+
+  for (i = 0; i < length; i++) {
+    if (vector[i] < minimum)
+      minimum = vector[i];
+    if (vector[i] > maximum)
+      maximum = vector[i];
+  }
+  *min_val = minimum;
+  *max_val = maximum;
+#endif
+}
diff --git a/common_audio/signal_processing/min_max_operations_neon.c b/common_audio/signal_processing/min_max_operations_neon.c
index 53217df..e5b4b7c 100644
--- a/common_audio/signal_processing/min_max_operations_neon.c
+++ b/common_audio/signal_processing/min_max_operations_neon.c
@@ -281,3 +281,53 @@
   return minimum;
 }
 
+// Finds both the minimum and maximum elements in an array of 16-bit integers.
+void WebRtcSpl_MinMaxW16Neon(const int16_t* vector, size_t length,
+                             int16_t* min_val, int16_t* max_val) {
+  int16_t minimum = WEBRTC_SPL_WORD16_MAX;
+  int16_t maximum = WEBRTC_SPL_WORD16_MIN;
+  size_t i = 0;
+  size_t residual = length & 0x7;
+
+  RTC_DCHECK_GT(length, 0);
+
+  const int16_t* p_start = vector;
+  int16x8_t min16x8 = vdupq_n_s16(WEBRTC_SPL_WORD16_MAX);
+  int16x8_t max16x8 = vdupq_n_s16(WEBRTC_SPL_WORD16_MIN);
+
+  // First part, unroll the loop 8 times.
+  for (i = 0; i < length - residual; i += 8) {
+    int16x8_t in16x8 = vld1q_s16(p_start);
+    min16x8 = vminq_s16(min16x8, in16x8);
+    max16x8 = vmaxq_s16(max16x8, in16x8);
+    p_start += 8;
+  }
+
+#if defined(WEBRTC_ARCH_ARM64)
+  minimum = vminvq_s16(min16x8);
+  maximum = vmaxvq_s16(max16x8);
+#else
+  int16x4_t min16x4 = vmin_s16(vget_low_s16(min16x8), vget_high_s16(min16x8));
+  min16x4 = vpmin_s16(min16x4, min16x4);
+  min16x4 = vpmin_s16(min16x4, min16x4);
+
+  minimum = vget_lane_s16(min16x4, 0);
+
+  int16x4_t max16x4 = vmax_s16(vget_low_s16(max16x8), vget_high_s16(max16x8));
+  max16x4 = vpmax_s16(max16x4, max16x4);
+  max16x4 = vpmax_s16(max16x4, max16x4);
+
+  maximum = vget_lane_s16(max16x4, 0);
+#endif
+
+  // Second part, do the remaining iterations (if any).
+  for (i = residual; i > 0; i--) {
+    if (*p_start < minimum)
+      minimum = *p_start;
+    if (*p_start > maximum)
+      maximum = *p_start;
+    p_start++;
+  }
+  *min_val = minimum;
+  *max_val = maximum;
+}
diff --git a/common_audio/signal_processing/signal_processing_unittest.cc b/common_audio/signal_processing/signal_processing_unittest.cc
index 3106c47..9ec8590 100644
--- a/common_audio/signal_processing/signal_processing_unittest.cc
+++ b/common_audio/signal_processing/signal_processing_unittest.cc
@@ -289,6 +289,12 @@
             WebRtcSpl_MinValueW32(vector32, kVectorSize));
   EXPECT_EQ(kVectorSize - 1, WebRtcSpl_MinIndexW16(vector16, kVectorSize));
   EXPECT_EQ(kVectorSize - 1, WebRtcSpl_MinIndexW32(vector32, kVectorSize));
+  EXPECT_EQ(WEBRTC_SPL_WORD16_MIN,
+            WebRtcSpl_MaxAbsElementW16(vector16, kVectorSize));
+  int16_t min_value, max_value;
+  WebRtcSpl_MinMaxW16(vector16, kVectorSize, &min_value, &max_value);
+  EXPECT_EQ(WEBRTC_SPL_WORD16_MIN, min_value);
+  EXPECT_EQ(12334, max_value);
 
   // Test the cases where maximum values have to be caught
   // outside of the unrolled loops in ARM-Neon.
@@ -306,6 +312,11 @@
   EXPECT_EQ(kVectorSize - 1, WebRtcSpl_MaxAbsIndexW16(vector16, kVectorSize));
   EXPECT_EQ(kVectorSize - 1, WebRtcSpl_MaxIndexW16(vector16, kVectorSize));
   EXPECT_EQ(kVectorSize - 1, WebRtcSpl_MaxIndexW32(vector32, kVectorSize));
+  EXPECT_EQ(WEBRTC_SPL_WORD16_MAX,
+            WebRtcSpl_MaxAbsElementW16(vector16, kVectorSize));
+  WebRtcSpl_MinMaxW16(vector16, kVectorSize, &min_value, &max_value);
+  EXPECT_EQ(-29871, min_value);
+  EXPECT_EQ(WEBRTC_SPL_WORD16_MAX, max_value);
 
   // Test the cases where multiple maximum and minimum values are present.
   vector16[1] = WEBRTC_SPL_WORD16_MAX;
@@ -332,6 +343,43 @@
   EXPECT_EQ(1u, WebRtcSpl_MaxIndexW32(vector32, kVectorSize));
   EXPECT_EQ(6u, WebRtcSpl_MinIndexW16(vector16, kVectorSize));
   EXPECT_EQ(6u, WebRtcSpl_MinIndexW32(vector32, kVectorSize));
+  EXPECT_EQ(WEBRTC_SPL_WORD16_MIN,
+            WebRtcSpl_MaxAbsElementW16(vector16, kVectorSize));
+  WebRtcSpl_MinMaxW16(vector16, kVectorSize, &min_value, &max_value);
+  EXPECT_EQ(WEBRTC_SPL_WORD16_MIN, min_value);
+  EXPECT_EQ(WEBRTC_SPL_WORD16_MAX, max_value);
+
+  // Test a one-element vector.
+  int16_t single_element_vector = 0;
+  EXPECT_EQ(0, WebRtcSpl_MaxAbsValueW16(&single_element_vector, 1));
+  EXPECT_EQ(0, WebRtcSpl_MaxValueW16(&single_element_vector, 1));
+  EXPECT_EQ(0, WebRtcSpl_MinValueW16(&single_element_vector, 1));
+  EXPECT_EQ(0u, WebRtcSpl_MaxAbsIndexW16(&single_element_vector, 1));
+  EXPECT_EQ(0u, WebRtcSpl_MaxIndexW16(&single_element_vector, 1));
+  EXPECT_EQ(0u, WebRtcSpl_MinIndexW16(&single_element_vector, 1));
+  EXPECT_EQ(0, WebRtcSpl_MaxAbsElementW16(&single_element_vector, 1));
+  WebRtcSpl_MinMaxW16(&single_element_vector, 1, &min_value, &max_value);
+  EXPECT_EQ(0, min_value);
+  EXPECT_EQ(0, max_value);
+
+  // Test a two-element vector with the values WEBRTC_SPL_WORD16_MIN and
+  // WEBRTC_SPL_WORD16_MAX.
+  int16_t two_element_vector[2] = {WEBRTC_SPL_WORD16_MIN,
+                                   WEBRTC_SPL_WORD16_MAX};
+  EXPECT_EQ(WEBRTC_SPL_WORD16_MAX,
+            WebRtcSpl_MaxAbsValueW16(two_element_vector, 2));
+  EXPECT_EQ(WEBRTC_SPL_WORD16_MAX,
+            WebRtcSpl_MaxValueW16(two_element_vector, 2));
+  EXPECT_EQ(WEBRTC_SPL_WORD16_MIN,
+            WebRtcSpl_MinValueW16(two_element_vector, 2));
+  EXPECT_EQ(0u, WebRtcSpl_MaxAbsIndexW16(two_element_vector, 2));
+  EXPECT_EQ(1u, WebRtcSpl_MaxIndexW16(two_element_vector, 2));
+  EXPECT_EQ(0u, WebRtcSpl_MinIndexW16(two_element_vector, 2));
+  EXPECT_EQ(WEBRTC_SPL_WORD16_MIN,
+            WebRtcSpl_MaxAbsElementW16(two_element_vector, 2));
+  WebRtcSpl_MinMaxW16(two_element_vector, 2, &min_value, &max_value);
+  EXPECT_EQ(WEBRTC_SPL_WORD16_MIN, min_value);
+  EXPECT_EQ(WEBRTC_SPL_WORD16_MAX, max_value);
 }
 
 TEST(SplTest, VectorOperationsTest) {
diff --git a/modules/audio_coding/codecs/ilbc/enhancer_interface.c b/modules/audio_coding/codecs/ilbc/enhancer_interface.c
index 71436c2..ca23e19 100644
--- a/modules/audio_coding/codecs/ilbc/enhancer_interface.c
+++ b/modules/audio_coding/codecs/ilbc/enhancer_interface.c
@@ -205,9 +205,9 @@
 
     /* scaling */
     // Note that this is not abs-max, so we will take the absolute value below.
-    max16 = regressor[WebRtcSpl_MaxAbsIndexW16(regressor, plc_blockl + 3 - 1)];
+    max16 = WebRtcSpl_MaxAbsElementW16(regressor, plc_blockl + 3 - 1);
     const int16_t max_target =
-        target[WebRtcSpl_MaxAbsIndexW16(target, plc_blockl + 3 - 1)];
+        WebRtcSpl_MaxAbsElementW16(target, plc_blockl + 3 - 1);
     const int64_t max_val = plc_blockl * abs(max16 * max_target);
     const int32_t factor = max_val >> 31;
     shifts = factor == 0 ? 0 : 31 - WebRtcSpl_NormW32(factor);
diff --git a/modules/audio_coding/neteq/cross_correlation.cc b/modules/audio_coding/neteq/cross_correlation.cc
index 7ee867a..37ed937 100644
--- a/modules/audio_coding/neteq/cross_correlation.cc
+++ b/modules/audio_coding/neteq/cross_correlation.cc
@@ -25,22 +25,23 @@
                                   size_t cross_correlation_length,
                                   int cross_correlation_step,
                                   int32_t* cross_correlation) {
-  // Find the maximum absolute value of sequence_1 and 2.
-  const int32_t max_1 =
-      abs(sequence_1[WebRtcSpl_MaxAbsIndexW16(sequence_1, sequence_1_length)]);
+  // Find the element that has the maximum absolute value of sequence_1 and 2.
+  // Note that these values may be negative.
+  const int16_t max_1 =
+      WebRtcSpl_MaxAbsElementW16(sequence_1, sequence_1_length);
   const int sequence_2_shift =
       cross_correlation_step * (static_cast<int>(cross_correlation_length) - 1);
   const int16_t* sequence_2_start =
       sequence_2_shift >= 0 ? sequence_2 : sequence_2 + sequence_2_shift;
   const size_t sequence_2_length =
       sequence_1_length + std::abs(sequence_2_shift);
-  const int32_t max_2 = abs(sequence_2_start[WebRtcSpl_MaxAbsIndexW16(
-      sequence_2_start, sequence_2_length)]);
+  const int16_t max_2 =
+      WebRtcSpl_MaxAbsElementW16(sequence_2_start, sequence_2_length);
 
   // In order to avoid overflow when computing the sum we should scale the
   // samples so that (in_vector_length * max_1 * max_2) will not overflow.
   const int64_t max_value =
-      max_1 * max_2 * static_cast<int64_t>(sequence_1_length);
+      abs(max_1 * max_2) * static_cast<int64_t>(sequence_1_length);
   const int32_t factor = max_value >> 31;
   const int scaling = factor == 0 ? 0 : 31 - WebRtcSpl_NormW32(factor);