NV12 support for VP8 simulcast

Tested using video_loopback with generated NV12 frames.
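
Implementation note: the encoder now keeps its vpx_image_t buffers in the
pixel format of the incoming frames and reallocates them only when that
format changes. NV12 works without an extra conversion because its
interleaved UV plane can be exposed to libvpx as separate U and V planes
that alias the same buffer one byte apart. A minimal sketch of that
aliasing (ChromaView/ViewNv12Chroma are hypothetical illustration names,
not part of this change):

    #include <cstdint>

    struct ChromaView {
      const uint8_t* u;  // First U sample of the interleaved plane.
      const uint8_t* v;  // First V sample, one byte after the first U.
      int stride;        // Both views advance by the full UV stride.
    };

    ChromaView ViewNv12Chroma(const uint8_t* uv_plane, int uv_stride) {
      // NV12 stores chroma as U0 V0 U1 V1 ..., so the V "plane" is the same
      // buffer shifted by one byte, with an identical stride.
      return {uv_plane, uv_plane + 1, uv_stride};
    }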

Bug: webrtc:11635, webrtc:11975
Change-Id: I14b2d663c55a83d80e48e226fcf706cb18903193
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/186722
Commit-Queue: Evan Shrubsole <eshr@google.com>
Reviewed-by: Ilya Nikolaevskiy <ilnik@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32325}
diff --git a/modules/video_coding/codecs/vp8/libvpx_vp8_encoder.cc b/modules/video_coding/codecs/vp8/libvpx_vp8_encoder.cc
index ceb0a2e..926a993 100644
--- a/modules/video_coding/codecs/vp8/libvpx_vp8_encoder.cc
+++ b/modules/video_coding/codecs/vp8/libvpx_vp8_encoder.cc
@@ -497,6 +497,10 @@
     return WEBRTC_VIDEO_CODEC_ERR_PARAMETER;
   }
 
+  // Use the previous pixel format to avoid extra image allocations.
+  vpx_img_fmt_t pixel_format =
+      raw_images_.empty() ? VPX_IMG_FMT_I420 : raw_images_[0].fmt;
+
   int retVal = Release();
   if (retVal < 0) {
     return retVal;
@@ -650,8 +654,8 @@
   // Creating a wrapper to the image - setting image data to NULL.
   // Actual pointer will be set in encode. Setting align to 1, as it
   // is meaningless (no memory allocation is done here).
-  libvpx_->img_wrap(&raw_images_[0], VPX_IMG_FMT_I420, inst->width,
-                    inst->height, 1, NULL);
+  libvpx_->img_wrap(&raw_images_[0], pixel_format, inst->width, inst->height, 1,
+                    NULL);
 
   // Note the order we use is different from webm, we have lowest resolution
   // at position 0 and they have highest resolution at position 0.
@@ -699,10 +703,9 @@
     // Setting alignment to 32 - as that ensures at least 16 for all
     // planes (32 for Y, 16 for U,V). Libvpx sets the requested stride for
     // the y plane, but only half of it to the u and v planes.
-    libvpx_->img_alloc(&raw_images_[i], VPX_IMG_FMT_I420,
-                       inst->simulcastStream[stream_idx].width,
-                       inst->simulcastStream[stream_idx].height,
-                       kVp832ByteAlign);
+    libvpx_->img_alloc(
+        &raw_images_[i], pixel_format, inst->simulcastStream[stream_idx].width,
+        inst->simulcastStream[stream_idx].height, kVp832ByteAlign);
     SetStreamState(stream_bitrates[stream_idx] > 0, stream_idx);
     vpx_configs_[i].rc_target_bitrate = stream_bitrates[stream_idx];
     if (stream_bitrates[stream_idx] > 0) {
@@ -1014,26 +1017,12 @@
     flags[i] = send_key_frame ? VPX_EFLAG_FORCE_KF : EncodeFlags(tl_configs[i]);
   }
 
-  rtc::scoped_refptr<I420BufferInterface> input_image =
-      frame.video_frame_buffer()->ToI420();
-  // Since we are extracting raw pointers from |input_image| to
-  // |raw_images_[0]|, the resolution of these frames must match.
-  RTC_DCHECK_EQ(input_image->width(), raw_images_[0].d_w);
-  RTC_DCHECK_EQ(input_image->height(), raw_images_[0].d_h);
-
-  // Image in vpx_image_t format.
-  // Input image is const. VP8's raw image is not defined as const.
-  raw_images_[0].planes[VPX_PLANE_Y] =
-      const_cast<uint8_t*>(input_image->DataY());
-  raw_images_[0].planes[VPX_PLANE_U] =
-      const_cast<uint8_t*>(input_image->DataU());
-  raw_images_[0].planes[VPX_PLANE_V] =
-      const_cast<uint8_t*>(input_image->DataV());
-
-  raw_images_[0].stride[VPX_PLANE_Y] = input_image->StrideY();
-  raw_images_[0].stride[VPX_PLANE_U] = input_image->StrideU();
-  raw_images_[0].stride[VPX_PLANE_V] = input_image->StrideV();
-
+  rtc::scoped_refptr<VideoFrameBuffer> input_image = frame.video_frame_buffer();
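+  // VP8 can consume I420 and NV12 directly; convert any other buffer type
+  // (e.g. a native buffer) to I420.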
+  if (input_image->type() != VideoFrameBuffer::Type::kI420 &&
+      input_image->type() != VideoFrameBuffer::Type::kNV12) {
+    input_image = input_image->ToI420();
+  }
+  PrepareRawImagesForEncoding(input_image);
   struct CleanUpOnExit {
     explicit CleanUpOnExit(vpx_image_t& raw_image) : raw_image_(raw_image) {}
     ~CleanUpOnExit() {
@@ -1044,22 +1033,6 @@
     vpx_image_t& raw_image_;
   } clean_up_on_exit(raw_images_[0]);
 
-  for (size_t i = 1; i < encoders_.size(); ++i) {
-    // Scale the image down a number of times by downsampling factor
-    libyuv::I420Scale(
-        raw_images_[i - 1].planes[VPX_PLANE_Y],
-        raw_images_[i - 1].stride[VPX_PLANE_Y],
-        raw_images_[i - 1].planes[VPX_PLANE_U],
-        raw_images_[i - 1].stride[VPX_PLANE_U],
-        raw_images_[i - 1].planes[VPX_PLANE_V],
-        raw_images_[i - 1].stride[VPX_PLANE_V], raw_images_[i - 1].d_w,
-        raw_images_[i - 1].d_h, raw_images_[i].planes[VPX_PLANE_Y],
-        raw_images_[i].stride[VPX_PLANE_Y], raw_images_[i].planes[VPX_PLANE_U],
-        raw_images_[i].stride[VPX_PLANE_U], raw_images_[i].planes[VPX_PLANE_V],
-        raw_images_[i].stride[VPX_PLANE_V], raw_images_[i].d_w,
-        raw_images_[i].d_h, libyuv::kFilterBilinear);
-  }
-
   if (send_key_frame) {
     // Adapt the size of the key frame when in screenshare with 1 temporal
     // layer.
@@ -1309,6 +1282,105 @@
   return WEBRTC_VIDEO_CODEC_OK;
 }
 
+void LibvpxVp8Encoder::PrepareRawImagesForEncoding(
+    const rtc::scoped_refptr<VideoFrameBuffer>& frame) {
+  // Since we are extracting raw pointers from |frame| to |raw_images_[0]|,
+  // the resolution of these frames must match.
+  RTC_DCHECK_EQ(frame->width(), raw_images_[0].d_w);
+  RTC_DCHECK_EQ(frame->height(), raw_images_[0].d_h);
+  switch (frame->type()) {
+    case VideoFrameBuffer::Type::kI420:
+      return PrepareI420Image(frame->GetI420());
+    case VideoFrameBuffer::Type::kNV12:
+      return PrepareNV12Image(frame->GetNV12());
+    default:
+      RTC_NOTREACHED();
+  }
+}
+
+void LibvpxVp8Encoder::MaybeUpdatePixelFormat(vpx_img_fmt fmt) {
+  RTC_DCHECK(!raw_images_.empty());
+  if (raw_images_[0].fmt == fmt) {
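+    // All raw images are kept in the same pixel format, so checking the
+    // first one is sufficient.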
+    RTC_DCHECK(std::all_of(
+        std::next(raw_images_.begin()), raw_images_.end(),
+        [fmt](const vpx_image_t& raw_img) { return raw_img.fmt == fmt; }))
+        << "Not all raw images had the right format!";
+    return;
+  }
+  RTC_LOG(INFO) << "Updating vp8 encoder pixel format to "
+                << (fmt == VPX_IMG_FMT_NV12 ? "NV12" : "I420");
+  for (size_t i = 0; i < raw_images_.size(); ++i) {
+    vpx_image_t& img = raw_images_[i];
+    auto d_w = img.d_w;
+    auto d_h = img.d_h;
+    libvpx_->img_free(&img);
+    // The first image wraps the input frame; the rest are allocated.
+    if (i == 0) {
+      libvpx_->img_wrap(&img, fmt, d_w, d_h, 1, NULL);
+    } else {
+      libvpx_->img_alloc(&img, fmt, d_w, d_h, kVp832ByteAlign);
+    }
+  }
+}
+
+void LibvpxVp8Encoder::PrepareI420Image(const I420BufferInterface* frame) {
+  RTC_DCHECK(!raw_images_.empty());
+  MaybeUpdatePixelFormat(VPX_IMG_FMT_I420);
+  // Wrap the input image in vpx_image_t format.
+  // The input image is const, but VP8's raw image is not defined as const.
+  raw_images_[0].planes[VPX_PLANE_Y] = const_cast<uint8_t*>(frame->DataY());
+  raw_images_[0].planes[VPX_PLANE_U] = const_cast<uint8_t*>(frame->DataU());
+  raw_images_[0].planes[VPX_PLANE_V] = const_cast<uint8_t*>(frame->DataV());
+
+  raw_images_[0].stride[VPX_PLANE_Y] = frame->StrideY();
+  raw_images_[0].stride[VPX_PLANE_U] = frame->StrideU();
+  raw_images_[0].stride[VPX_PLANE_V] = frame->StrideV();
+
+  for (size_t i = 1; i < encoders_.size(); ++i) {
+    // Scale the image down by the downsampling factor between consecutive streams.
+    libyuv::I420Scale(
+        raw_images_[i - 1].planes[VPX_PLANE_Y],
+        raw_images_[i - 1].stride[VPX_PLANE_Y],
+        raw_images_[i - 1].planes[VPX_PLANE_U],
+        raw_images_[i - 1].stride[VPX_PLANE_U],
+        raw_images_[i - 1].planes[VPX_PLANE_V],
+        raw_images_[i - 1].stride[VPX_PLANE_V], raw_images_[i - 1].d_w,
+        raw_images_[i - 1].d_h, raw_images_[i].planes[VPX_PLANE_Y],
+        raw_images_[i].stride[VPX_PLANE_Y], raw_images_[i].planes[VPX_PLANE_U],
+        raw_images_[i].stride[VPX_PLANE_U], raw_images_[i].planes[VPX_PLANE_V],
+        raw_images_[i].stride[VPX_PLANE_V], raw_images_[i].d_w,
+        raw_images_[i].d_h, libyuv::kFilterBilinear);
+  }
+}
+
+void LibvpxVp8Encoder::PrepareNV12Image(const NV12BufferInterface* frame) {
+  RTC_DCHECK(!raw_images_.empty());
+  MaybeUpdatePixelFormat(VPX_IMG_FMT_NV12);
+  // Wrap the input image in vpx_image_t format.
+  // The input image is const, but VP8's raw image is not defined as const.
+  raw_images_[0].planes[VPX_PLANE_Y] = const_cast<uint8_t*>(frame->DataY());
+  raw_images_[0].planes[VPX_PLANE_U] = const_cast<uint8_t*>(frame->DataUV());
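+  // NV12 interleaves chroma as UVUV..., so the V plane aliases the UV buffer
+  // one byte past U and shares the UV stride.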
+  raw_images_[0].planes[VPX_PLANE_V] = raw_images_[0].planes[VPX_PLANE_U] + 1;
+  raw_images_[0].stride[VPX_PLANE_Y] = frame->StrideY();
+  raw_images_[0].stride[VPX_PLANE_U] = frame->StrideUV();
+  raw_images_[0].stride[VPX_PLANE_V] = frame->StrideUV();
+
+  for (size_t i = 1; i < encoders_.size(); ++i) {
+    // Scale the image down by the downsampling factor between consecutive streams.
+    libyuv::NV12Scale(
+        raw_images_[i - 1].planes[VPX_PLANE_Y],
+        raw_images_[i - 1].stride[VPX_PLANE_Y],
+        raw_images_[i - 1].planes[VPX_PLANE_U],
+        raw_images_[i - 1].stride[VPX_PLANE_U], raw_images_[i - 1].d_w,
+        raw_images_[i - 1].d_h, raw_images_[i].planes[VPX_PLANE_Y],
+        raw_images_[i].stride[VPX_PLANE_Y], raw_images_[i].planes[VPX_PLANE_U],
+        raw_images_[i].stride[VPX_PLANE_U], raw_images_[i].d_w,
+        raw_images_[i].d_h, libyuv::kFilterBilinear);
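+    // NV12Scale writes the interleaved UV plane; re-derive the V plane alias
+    // for the scaled image.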
+    raw_images_[i].planes[VPX_PLANE_V] = raw_images_[i].planes[VPX_PLANE_U] + 1;
+    raw_images_[i].stride[VPX_PLANE_V] = raw_images_[i].stride[VPX_PLANE_U];
+  }
+}
+
 // static
 LibvpxVp8Encoder::VariableFramerateExperiment
 LibvpxVp8Encoder::ParseVariableFramerateConfig(std::string group_name) {
diff --git a/modules/video_coding/codecs/vp8/libvpx_vp8_encoder.h b/modules/video_coding/codecs/vp8/libvpx_vp8_encoder.h
index 731a9a0..338dd40 100644
--- a/modules/video_coding/codecs/vp8/libvpx_vp8_encoder.h
+++ b/modules/video_coding/codecs/vp8/libvpx_vp8_encoder.h
@@ -93,6 +93,12 @@
 
   bool UpdateVpxConfiguration(size_t stream_index);
 
+  void PrepareRawImagesForEncoding(
+      const rtc::scoped_refptr<VideoFrameBuffer>& frame);
+  void MaybeUpdatePixelFormat(vpx_img_fmt fmt);
+  void PrepareI420Image(const I420BufferInterface* frame);
+  void PrepareNV12Image(const NV12BufferInterface* frame);
+
   const std::unique_ptr<LibvpxInterface> libvpx_;
 
   const CpuSpeedExperiment experimental_cpu_speed_config_arm_;
diff --git a/modules/video_coding/codecs/vp8/test/vp8_impl_unittest.cc b/modules/video_coding/codecs/vp8/test/vp8_impl_unittest.cc
index 342187b..4779572 100644
--- a/modules/video_coding/codecs/vp8/test/vp8_impl_unittest.cc
+++ b/modules/video_coding/codecs/vp8/test/vp8_impl_unittest.cc
@@ -266,6 +266,44 @@
             encoder_->Encode(NextInputFrame(), nullptr));
 }
 
+TEST_F(TestVp8Impl, EncodeNv12FrameSimulcast) {
+  EXPECT_EQ(WEBRTC_VIDEO_CODEC_OK, encoder_->Release());
+  EXPECT_EQ(WEBRTC_VIDEO_CODEC_OK,
+            encoder_->InitEncode(&codec_settings_, kSettings));
+
+  EncodedImage encoded_frame;
+  CodecSpecificInfo codec_specific_info;
+  input_frame_generator_ = test::CreateSquareFrameGenerator(
+      kWidth, kHeight, test::FrameGeneratorInterface::OutputType::kNV12,
+      absl::nullopt);
+  EncodeAndWaitForFrame(NextInputFrame(), &encoded_frame, &codec_specific_info);
+
+  EXPECT_EQ(WEBRTC_VIDEO_CODEC_OK, encoder_->Release());
+  EXPECT_EQ(WEBRTC_VIDEO_CODEC_UNINITIALIZED,
+            encoder_->Encode(NextInputFrame(), nullptr));
+}
+
+TEST_F(TestVp8Impl, EncodeI420FrameAfterNv12Frame) {
+  EXPECT_EQ(WEBRTC_VIDEO_CODEC_OK, encoder_->Release());
+  EXPECT_EQ(WEBRTC_VIDEO_CODEC_OK,
+            encoder_->InitEncode(&codec_settings_, kSettings));
+
+  EncodedImage encoded_frame;
+  CodecSpecificInfo codec_specific_info;
+  input_frame_generator_ = test::CreateSquareFrameGenerator(
+      kWidth, kHeight, test::FrameGeneratorInterface::OutputType::kNV12,
+      absl::nullopt);
+  EncodeAndWaitForFrame(NextInputFrame(), &encoded_frame, &codec_specific_info);
+  input_frame_generator_ = test::CreateSquareFrameGenerator(
+      kWidth, kHeight, test::FrameGeneratorInterface::OutputType::kI420,
+      absl::nullopt);
+  EncodeAndWaitForFrame(NextInputFrame(), &encoded_frame, &codec_specific_info);
+
+  EXPECT_EQ(WEBRTC_VIDEO_CODEC_OK, encoder_->Release());
+  EXPECT_EQ(WEBRTC_VIDEO_CODEC_UNINITIALIZED,
+            encoder_->Encode(NextInputFrame(), nullptr));
+}
+
 TEST_F(TestVp8Impl, InitDecode) {
   EXPECT_EQ(WEBRTC_VIDEO_CODEC_OK, decoder_->Release());
   EXPECT_EQ(WEBRTC_VIDEO_CODEC_OK,