[WGC] Wait longer for the first frame of a stream

There seem to be circumstances where obtaining the first frame of a
frame pool takes longer than expected, which results in screen capture
failures due to too many "kFrameDropped" errors. This CL leverages the
FrameArrived event to avoid polling the frame pool until the first
frame is available.

The new histograms will be added in
https://chromium-review.googlesource.com/c/chromium/src/+/6821716

Bug: chromium:433569403
Change-Id: Ib6f77331508772412a84eb5c63399e716b7d0049
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/402945
Reviewed-by: Ilya Nikolaevskiy <ilnik@webrtc.org>
Commit-Queue: Gabriel Brito <gabrielbrito@microsoft.com>
Reviewed-by: Alexander Cooper <alcooper@chromium.org>
Cr-Commit-Position: refs/heads/main@{#45336}
diff --git a/modules/desktop_capture/BUILD.gn b/modules/desktop_capture/BUILD.gn
index 1ba95f7..a3110d1 100644
--- a/modules/desktop_capture/BUILD.gn
+++ b/modules/desktop_capture/BUILD.gn
@@ -592,6 +592,7 @@
     ]
     libs += [ "dwmapi.lib" ]
     deps += [
+      "../../api/units:time_delta",
       "../../rtc_base:rtc_event",
       "../../rtc_base:threading",
       "../../rtc_base/win:hstring",
diff --git a/modules/desktop_capture/win/wgc_capture_session.cc b/modules/desktop_capture/win/wgc_capture_session.cc
index dc744d4..01ec96b 100644
--- a/modules/desktop_capture/win/wgc_capture_session.cc
+++ b/modules/desktop_capture/win/wgc_capture_session.cc
@@ -8,8 +8,6 @@
  *  be found in the AUTHORS file in the root of the source tree.
  */
 
-#include "modules/desktop_capture/win/wgc_capture_session.h"
-
 #include <DispatcherQueue.h>
 #include <windows.graphics.capture.interop.h>
 #include <windows.graphics.directX.direct3d11.interop.h>
@@ -21,15 +19,19 @@
 #include <memory>
 #include <utility>
 
+#include "api/make_ref_counted.h"
 #include "api/sequence_checker.h"
+#include "api/units/time_delta.h"
 #include "modules/desktop_capture/desktop_capture_options.h"
 #include "modules/desktop_capture/desktop_frame.h"
 #include "modules/desktop_capture/desktop_geometry.h"
 #include "modules/desktop_capture/shared_desktop_frame.h"
 #include "modules/desktop_capture/win/screen_capture_utils.h"
+#include "modules/desktop_capture/win/wgc_capture_session.h"
 #include "rtc_base/checks.h"
 #include "rtc_base/logging.h"
 #include "rtc_base/thread.h"
+#include "rtc_base/time_utils.h"
 #include "rtc_base/win/create_direct3d_device.h"
 #include "rtc_base/win/get_activation_factory.h"
 #include "rtc_base/win/windows_version.h"
@@ -46,6 +48,10 @@
 constexpr auto kPixelFormat = ABI::Windows::Graphics::DirectX::
     DirectXPixelFormat::DirectXPixelFormat_B8G8R8A8UIntNormalized;
 
+// We must wait a little longer for the first frame to avoid failing the
+// capture when there is a longer startup time.
+constexpr int kFirstFrameTimeoutMs = 5000;
+
 // These values are persisted to logs. Entries should not be renumbered and
 // numeric values should never be reused.
 enum class StartCaptureResult {
@@ -56,7 +62,7 @@
   kD3dDelayLoadFailed = 4,
   kD3dDeviceCreationFailed = 5,
   kFramePoolActivationFailed = 6,
-  // kFramePoolCastFailed = 7, (deprecated)
+  kFramePoolCastFailed = 7,
   // kGetItemSizeFailed = 8, (deprecated)
   kCreateFramePoolFailed = 9,
   kCreateCaptureSessionFailed = 10,
@@ -80,7 +86,17 @@
   kResizeMappedTextureFailed = 10,
   kRecreateFramePoolFailed = 11,
   kFramePoolEmpty = 12,
-  kMaxValue = kFramePoolEmpty
+  kWaitForFirstFrameFailed = 13,
+  kMaxValue = kWaitForFirstFrameFailed
+};
+
+enum class WaitForFirstFrameResult {
+  kSuccess = 0,
+  kTryGetNextFrameFailed = 1,
+  kAddFrameArrivedCallbackFailed = 2,
+  kWaitingTimedOut = 3,
+  kRemoveFrameArrivedCallbackFailed = 4,
+  kMaxValue = kRemoveFrameArrivedCallbackFailed
 };
 
 void RecordStartCaptureResult(StartCaptureResult error) {
@@ -95,6 +111,19 @@
       static_cast<int>(error), static_cast<int>(GetFrameResult::kMaxValue));
 }
 
+void RecordGetFirstFrameTime(int64_t elapsed_time_ms) {
+  RTC_HISTOGRAM_COUNTS(
+      "WebRTC.DesktopCapture.Win.WgcCaptureSessionTimeToFirstFrame",
+      elapsed_time_ms, /*min=*/1, /*max=*/5000, /*bucket_count=*/100);
+}
+
+void RecordWaitForFirstFrameResult(WaitForFirstFrameResult error) {
+  RTC_HISTOGRAM_ENUMERATION(
+      "WebRTC.DesktopCapture.Win.WgcCaptureSessionWaitForFirstFrameResult",
+      static_cast<int>(error),
+      static_cast<int>(WaitForFirstFrameResult::kMaxValue));
+}
+
 bool SizeHasChanged(ABI::Windows::Graphics::SizeInt32 size_new,
                     ABI::Windows::Graphics::SizeInt32 size_old) {
   return (size_new.Height != size_old.Height ||
@@ -107,6 +136,23 @@
 
 }  // namespace
 
+WgcCaptureSession::RefCountedEvent::RefCountedEvent(bool manual_reset,
+                                                    bool initially_signaled)
+    : Event(manual_reset, initially_signaled) {}
+
+WgcCaptureSession::RefCountedEvent::~RefCountedEvent() = default;
+
+WgcCaptureSession::AgileFrameArrivedHandler::AgileFrameArrivedHandler(
+    scoped_refptr<RefCountedEvent> event)
+    : frame_arrived_event_(event) {}
+
+IFACEMETHODIMP WgcCaptureSession::AgileFrameArrivedHandler::Invoke(
+    ABI::Windows::Graphics::Capture::IDirect3D11CaptureFramePool* sender,
+    IInspectable* args) {
+  frame_arrived_event_->Set();
+  return S_OK;
+}
+
 WgcCaptureSession::WgcCaptureSession(intptr_t source_id,
                                      ComPtr<ID3D11Device> d3d11_device,
                                      ComPtr<WGC::IGraphicsCaptureItem> item,
@@ -119,7 +165,7 @@
 }
 
 WgcCaptureSession::~WgcCaptureSession() {
-  RemoveEventHandler();
+  RemoveEventHandlers();
 }
 
 HRESULT WgcCaptureSession::StartCapture(const DesktopCaptureOptions& options) {
@@ -178,8 +224,20 @@
     return hr;
   }
 
-  hr = frame_pool_statics->Create(direct3d_device_.Get(), kPixelFormat,
-                                  kNumBuffers, size_, &frame_pool_);
+  // Cast to FramePoolStatics2 so we can use CreateFreeThreaded and avoid the
+  // need to have a DispatcherQueue. Obtaining the very first frame of a stream
+  // can sometimes take longer than usual. To avoid timeouts, CreateFreeThreaded
+  // is needed so that the frame processing done by WGC can happen on a
+  // different thread while the main thread is waiting for it.
+  ComPtr<WGC::IDirect3D11CaptureFramePoolStatics2> frame_pool_statics2;
+  hr = frame_pool_statics->QueryInterface(IID_PPV_ARGS(&frame_pool_statics2));
+  if (FAILED(hr)) {
+    RecordStartCaptureResult(StartCaptureResult::kFramePoolCastFailed);
+    return hr;
+  }
+
+  hr = frame_pool_statics2->CreateFreeThreaded(
+      direct3d_device_.Get(), kPixelFormat, kNumBuffers, size_, &frame_pool_);
   if (FAILED(hr)) {
     RecordStartCaptureResult(StartCaptureResult::kCreateFramePoolFailed);
     return hr;
@@ -227,7 +285,58 @@
   return hr;
 }
 
+bool WgcCaptureSession::WaitForFirstFrame() {
+  RTC_CHECK(!has_first_frame_arrived_);
+
+  ComPtr<WGC::IDirect3D11CaptureFrame> capture_frame = nullptr;
+  // Flush the `frame_pool_` buffers so that we can receive the most recent
+  // frames.
+  for (int i = 0; i < kNumBuffers; ++i) {
+    HRESULT hr = frame_pool_->TryGetNextFrame(&capture_frame);
+    if (FAILED(hr)) {
+      RTC_LOG(LS_ERROR) << "TryGetNextFrame failed: " << hr;
+      RecordWaitForFirstFrameResult(
+          WaitForFirstFrameResult::kTryGetNextFrameFailed);
+      return false;
+    }
+  }
+
+  if (FAILED(AddFrameArrivedEventHandler())) {
+    RecordWaitForFirstFrameResult(
+        WaitForFirstFrameResult::kAddFrameArrivedCallbackFailed);
+    return false;
+  }
+
+  RTC_CHECK(has_first_frame_arrived_event_);
+  int64_t first_frame_event_wait_start = TimeMillis();
+  // Only start the frame polling once the first frame becomes available.
+  if (!has_first_frame_arrived_event_->Wait(
+          TimeDelta::Millis(kFirstFrameTimeoutMs))) {
+    RecordGetFirstFrameTime(kFirstFrameTimeoutMs);
+    RecordWaitForFirstFrameResult(WaitForFirstFrameResult::kWaitingTimedOut);
+    RTC_LOG(LS_ERROR) << "Timed out after waiting " << kFirstFrameTimeoutMs
+                      << " ms for the first frame.";
+    return false;
+  }
+
+  RecordGetFirstFrameTime(TimeMillis() - first_frame_event_wait_start);
+  RecordWaitForFirstFrameResult(WaitForFirstFrameResult::kSuccess);
+  has_first_frame_arrived_ = true;
+  RemoveFrameArrivedEventHandler();
+  return true;
+}
+
 void WgcCaptureSession::EnsureFrame() {
+  // We need to wait for the first frame because it might take some extra time
+  // for the `frame_pool_` to be populated and capture may fail because of too
+  // many `kFrameDropped` errors.
+  if (!has_first_frame_arrived_) {
+    if (!WaitForFirstFrame()) {
+      RecordGetFrameResult(GetFrameResult::kWaitForFirstFrameFailed);
+      return;
+    }
+  }
+
   // Try to process the captured frame and copy it to the `queue_`.
   HRESULT hr = ProcessFrame();
   if (SUCCEEDED(hr)) {
@@ -283,8 +392,16 @@
   // if we know that the source will not be capturable. This can happen e.g.
   // when captured window is minimized and if EnsureFrame() was called in this
   // state a large amount of kFrameDropped errors would be logged.
-  if (source_should_be_capturable)
+  if (source_should_be_capturable) {
     EnsureFrame();
+  } else {
+    // If the source is not capturable, we must reset `has_first_frame_arrived_`
+    // so that the next time the source becomes capturable we can wait for the
+    // first frame again.
+    if (has_first_frame_arrived_) {
+      has_first_frame_arrived_ = false;
+    }
+  }
 
   // Return a NULL frame and false as `result` if we still don't have a valid
   // frame. This will lead to a DesktopCapturer::Result::ERROR_PERMANENT being
@@ -583,7 +700,7 @@
         // Mark resized frames as damaged.
         damage_region_.SetRect(DesktopRect::MakeSize(current_frame->size()));
       }
-    } else{
+    } else {
       // Mark a `damage_region_` even if there is no previous frame. This
       // condition does not create any increased overhead but is useful while
       // using FullScreenWindowDetector, where it would create a new
@@ -607,7 +724,7 @@
   RTC_LOG(LS_INFO) << "Capture target has been closed.";
   item_closed_ = true;
 
-  RemoveEventHandler();
+  RemoveItemClosedEventHandler();
 
   // Do not attempt to free resources in the OnItemClosed handler, as this
   // causes a race where we try to delete the item that is calling us. Removing
@@ -617,16 +734,53 @@
   return S_OK;
 }
 
-void WgcCaptureSession::RemoveEventHandler() {
+void WgcCaptureSession::RemoveEventHandlers() {
+  RemoveItemClosedEventHandler();
+  RemoveFrameArrivedEventHandler();
+}
+
+void WgcCaptureSession::RemoveItemClosedEventHandler() {
   HRESULT hr;
   if (item_ && item_closed_token_) {
     hr = item_->remove_Closed(*item_closed_token_);
     item_closed_token_.reset();
-    if (FAILED(hr))
+    if (FAILED(hr)) {
       RTC_LOG(LS_WARNING) << "Failed to remove Closed event handler: " << hr;
+    }
   }
 }
 
+void WgcCaptureSession::RemoveFrameArrivedEventHandler() {
+  RTC_DCHECK(frame_pool_);
+  if (frame_arrived_token_) {
+    HRESULT hr = frame_pool_->remove_FrameArrived(*frame_arrived_token_);
+    frame_arrived_token_.reset();
+    has_first_frame_arrived_event_ = nullptr;
+    if (FAILED(hr)) {
+      RTC_LOG(LS_WARNING) << "Failed to remove FrameArrived event handler: "
+                          << hr;
+    }
+  }
+}
+
+HRESULT WgcCaptureSession::AddFrameArrivedEventHandler() {
+  RTC_DCHECK(frame_pool_);
+  HRESULT hr = E_FAIL;
+  frame_arrived_token_ = std::make_unique<EventRegistrationToken>();
+  has_first_frame_arrived_event_ = make_ref_counted<RefCountedEvent>(
+      /*manual_reset=*/true, /*initially_signaled=*/false);
+  auto frame_arrived_handler = Microsoft::WRL::Make<AgileFrameArrivedHandler>(
+      has_first_frame_arrived_event_);
+  hr = frame_pool_->add_FrameArrived(frame_arrived_handler.Get(),
+                                     frame_arrived_token_.get());
+  if (FAILED(hr)) {
+    RTC_LOG(LS_WARNING) << "Failed to add FrameArrived event handler: " << hr;
+    frame_arrived_token_.reset();
+    has_first_frame_arrived_event_ = nullptr;
+  }
+  return hr;
+}
+
 bool WgcCaptureSession::FrameContentCanBeCompared() {
   DesktopFrame* current_frame = queue_.current_frame();
   DesktopFrame* previous_frame = queue_.previous_frame();
diff --git a/modules/desktop_capture/win/wgc_capture_session.h b/modules/desktop_capture/win/wgc_capture_session.h
index b62c352..0e5fbd3 100644
--- a/modules/desktop_capture/win/wgc_capture_session.h
+++ b/modules/desktop_capture/win/wgc_capture_session.h
@@ -16,17 +16,20 @@
 #include <windows.graphics.capture.h>
 #include <windows.graphics.h>
 #include <wrl/client.h>
+#include <wrl/implements.h>
 
 #include <cstdint>
 #include <memory>
 #include <optional>
 
+#include "api/scoped_refptr.h"
 #include "api/sequence_checker.h"
 #include "modules/desktop_capture/desktop_capture_options.h"
 #include "modules/desktop_capture/desktop_frame.h"
 #include "modules/desktop_capture/desktop_region.h"
 #include "modules/desktop_capture/screen_capture_frame_queue.h"
 #include "modules/desktop_capture/shared_desktop_frame.h"
+#include "rtc_base/event.h"
 
 namespace webrtc {
 
@@ -68,6 +71,41 @@
   static constexpr int kNumBuffers = 2;
 
  private:
+  class RefCountedEvent : public RefCountedNonVirtual<RefCountedEvent>,
+                          public Event {
+   public:
+    RefCountedEvent(bool manual_reset, bool initially_signaled);
+
+   private:
+    friend class RefCountedNonVirtual<RefCountedEvent>;
+    ~RefCountedEvent();
+  };
+
+  // Handles the arrival of new frames in the Direct3D11CaptureFramePool.
+  // Whenever `Direct3D11CaptureFramePool.FrameArrived` is raised,
+  // `AgileFrameArrivedHandler::Invoke` will be called. This class needs to
+  // implement the IAgileObject interface so that we can create a WGC frame pool
+  // with `Direct3D11CaptureFramePool::CreateFreeThreaded` and be able to call
+  // `Invoke` on a thread different from the one that created this class'
+  // instance.
+  class AgileFrameArrivedHandler
+      : public Microsoft::WRL::RuntimeClass<
+            Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
+            ABI::Windows::Foundation::ITypedEventHandler<
+                ABI::Windows::Graphics::Capture::Direct3D11CaptureFramePool*,
+                IInspectable*>,
+            IAgileObject> {
+   public:
+    AgileFrameArrivedHandler(scoped_refptr<RefCountedEvent> event);
+
+    IFACEMETHODIMP Invoke(
+        ABI::Windows::Graphics::Capture::IDirect3D11CaptureFramePool* sender,
+        IInspectable* args) override;
+
+   private:
+    scoped_refptr<RefCountedEvent> frame_arrived_event_;
+  };
+
   // Initializes `mapped_texture_` with the properties of the `src_texture`,
   // overrides the values of some necessary properties like the
   // D3D11_CPU_ACCESS_READ flag. Also has optional parameters for what size
@@ -83,6 +121,13 @@
       ABI::Windows::Graphics::Capture::IGraphicsCaptureItem* sender,
       IInspectable* event_args);
 
+  // Waits for the first frame to arrive in the `frame_pool_`. We should wait
+  // for a frame if either this is the first frame ever obtained from the
+  // `frame_pool_` or if this is the first frame obtained after a capture
+  // interruption - e.g. when a captured window is brought back after being
+  // minimized.
+  bool WaitForFirstFrame();
+
   // Wraps calls to ProcessFrame and deals with the uniqe start-up phase
   // ensuring that we always have one captured frame available.
   void EnsureFrame();
@@ -90,13 +135,17 @@
   // Process the captured frame and copy it to the `queue_`.
   HRESULT ProcessFrame();
 
-  void RemoveEventHandler();
+  void RemoveEventHandlers();
+  void RemoveItemClosedEventHandler();
+  void RemoveFrameArrivedEventHandler();
+  HRESULT AddFrameArrivedEventHandler();
 
   bool FrameContentCanBeCompared();
 
   bool allow_zero_hertz() const { return allow_zero_hertz_; }
 
   std::unique_ptr<EventRegistrationToken> item_closed_token_;
+  std::unique_ptr<EventRegistrationToken> frame_arrived_token_;
 
   // A Direct3D11 Device provided by the caller. We use this to create an
   // IDirect3DDevice, and also to create textures that will hold the image data.
@@ -166,6 +215,24 @@
   // screen.
   bool is_window_source_;
 
+  // To be shared between `WgcCaptureSession` and `AgileFrameArrivedHandler`.
+  // `AgileFrameArrivedHandler` will set this event in a WGC worker thread and
+  // `WgcCaptureSession` will check its state in desktopCaptureThread. This is
+  // necessary to avoid race conditions where the desktopCaptureThread preempts
+  // the WGC worker thread and destroys the `WgcCaptureSession` while a new
+  // frame is being processed. In this situation, `AgileFrameArrivedHandler`
+  // would end up accessing invalid memory, which was previously owned by
+  // `WgcCaptureSession`.
+  //
+  // Will be signaled when the first frame is available in the `frame_pool_` and
+  // should not be reset for the lifetime of `WgcCaptureSession`.
+  scoped_refptr<RefCountedEvent> has_first_frame_arrived_event_;
+
+  // Records whether the first frame of a stream has arrived. Will be reset if
+  // a source becomes momentarily non-capturable - e.g. a window that gets
+  // minimized.
+  bool has_first_frame_arrived_ = false;
+
   SequenceChecker sequence_checker_;
 };