Update h264 sps parsers and sps vui rewriter to use BitstreamReader

The new version is subjectivly cleaner
and objectively generates smaller binary size

Bug: None
Change-Id: I8d845f56f13dbc7d34e4d685f735a448c5fe8f06
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/232001
Reviewed-by: Sergey Silkin <ssilkin@webrtc.org>
Reviewed-by: Erik Språng <sprang@webrtc.org>
Commit-Queue: Danil Chapovalov <danilchap@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#35011}
diff --git a/common_video/h264/sps_parser.cc b/common_video/h264/sps_parser.cc
index f505928..cfb0f24 100644
--- a/common_video/h264/sps_parser.cc
+++ b/common_video/h264/sps_parser.cc
@@ -14,16 +14,9 @@
 #include <vector>
 
 #include "common_video/h264/h264_common.h"
-#include "rtc_base/bit_buffer.h"
+#include "rtc_base/bitstream_reader.h"
 
 namespace {
-typedef absl::optional<webrtc::SpsParser::SpsState> OptionalSps;
-
-#define RETURN_EMPTY_ON_FAIL(x) \
-  if (!(x)) {                   \
-    return OptionalSps();       \
-  }
-
 constexpr int kScalingDeltaMin = -128;
 constexpr int kScaldingDeltaMax = 127;
 }  // namespace
@@ -42,13 +35,13 @@
 absl::optional<SpsParser::SpsState> SpsParser::ParseSps(const uint8_t* data,
                                                         size_t length) {
   std::vector<uint8_t> unpacked_buffer = H264::ParseRbsp(data, length);
-  rtc::BitBuffer bit_buffer(unpacked_buffer.data(), unpacked_buffer.size());
-  return ParseSpsUpToVui(&bit_buffer);
+  BitstreamReader reader(unpacked_buffer);
+  return ParseSpsUpToVui(reader);
 }
 
 absl::optional<SpsParser::SpsState> SpsParser::ParseSpsUpToVui(
-    rtc::BitBuffer* buffer) {
-  // Now, we need to use a bit buffer to parse through the actual AVC SPS
+    BitstreamReader& reader) {
+  // Now, we need to use a bitstream reader to parse through the actual AVC SPS
   // format. See Section 7.3.2.1.1 ("Sequence parameter set data syntax") of the
   // H.264 standard for a complete description.
   // Since we only care about resolution, we ignore the majority of fields, but
@@ -61,24 +54,18 @@
 
   SpsState sps;
 
-  // The golomb values we have to read, not just consume.
-  uint32_t golomb_ignored;
-
   // chroma_format_idc will be ChromaArrayType if separate_colour_plane_flag is
   // 0. It defaults to 1, when not specified.
   uint32_t chroma_format_idc = 1;
 
   // profile_idc: u(8). We need it to determine if we need to read/skip chroma
   // formats.
-  uint8_t profile_idc;
-  RETURN_EMPTY_ON_FAIL(buffer->ReadUInt8(profile_idc));
+  uint8_t profile_idc = reader.Read<uint8_t>();
   // constraint_set0_flag through constraint_set5_flag + reserved_zero_2bits
-  // 1 bit each for the flags + 2 bits = 8 bits = 1 byte.
-  RETURN_EMPTY_ON_FAIL(buffer->ConsumeBytes(1));
-  // level_idc: u(8)
-  RETURN_EMPTY_ON_FAIL(buffer->ConsumeBytes(1));
+  // 1 bit each for the flags + 2 bits + 8 bits for level_idc = 16 bits.
+  reader.ConsumeBits(16);
   // seq_parameter_set_id: ue(v)
-  RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(sps.id));
+  sps.id = reader.ReadExponentialGolomb();
   sps.separate_colour_plane_flag = 0;
   // See if profile_idc has chroma format information.
   if (profile_idc == 100 || profile_idc == 110 || profile_idc == 122 ||
@@ -86,42 +73,37 @@
       profile_idc == 86 || profile_idc == 118 || profile_idc == 128 ||
       profile_idc == 138 || profile_idc == 139 || profile_idc == 134) {
     // chroma_format_idc: ue(v)
-    RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(chroma_format_idc));
+    chroma_format_idc = reader.ReadExponentialGolomb();
     if (chroma_format_idc == 3) {
       // separate_colour_plane_flag: u(1)
-      RETURN_EMPTY_ON_FAIL(buffer->ReadBits(1, sps.separate_colour_plane_flag));
+      sps.separate_colour_plane_flag = reader.ReadBit();
     }
     // bit_depth_luma_minus8: ue(v)
-    RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(golomb_ignored));
+    reader.ReadExponentialGolomb();
     // bit_depth_chroma_minus8: ue(v)
-    RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(golomb_ignored));
+    reader.ReadExponentialGolomb();
     // qpprime_y_zero_transform_bypass_flag: u(1)
-    RETURN_EMPTY_ON_FAIL(buffer->ConsumeBits(1));
+    reader.ConsumeBits(1);
     // seq_scaling_matrix_present_flag: u(1)
-    uint32_t seq_scaling_matrix_present_flag;
-    RETURN_EMPTY_ON_FAIL(buffer->ReadBits(1, seq_scaling_matrix_present_flag));
-    if (seq_scaling_matrix_present_flag) {
+    if (reader.Read<bool>()) {
       // Process the scaling lists just enough to be able to properly
       // skip over them, so we can still read the resolution on streams
       // where this is included.
       int scaling_list_count = (chroma_format_idc == 3 ? 12 : 8);
       for (int i = 0; i < scaling_list_count; ++i) {
         // seq_scaling_list_present_flag[i]  : u(1)
-        uint32_t seq_scaling_list_present_flags;
-        RETURN_EMPTY_ON_FAIL(
-            buffer->ReadBits(1, seq_scaling_list_present_flags));
-        if (seq_scaling_list_present_flags != 0) {
+        if (reader.Read<bool>()) {
           int last_scale = 8;
           int next_scale = 8;
           int size_of_scaling_list = i < 6 ? 16 : 64;
           for (int j = 0; j < size_of_scaling_list; j++) {
             if (next_scale != 0) {
-              int32_t delta_scale;
               // delta_scale: se(v)
-              RETURN_EMPTY_ON_FAIL(
-                  buffer->ReadSignedExponentialGolomb(delta_scale));
-              RETURN_EMPTY_ON_FAIL(delta_scale >= kScalingDeltaMin &&
-                                   delta_scale <= kScaldingDeltaMax);
+              int delta_scale = reader.ReadSignedExponentialGolomb();
+              if (!reader.Ok() || delta_scale < kScalingDeltaMin ||
+                  delta_scale > kScaldingDeltaMax) {
+                return absl::nullopt;
+              }
               next_scale = (last_scale + delta_scale + 256) % 256;
             }
             if (next_scale != 0)
@@ -132,50 +114,49 @@
     }
   }
   // log2_max_frame_num and log2_max_pic_order_cnt_lsb are used with
-  // BitBuffer::ReadBits, which can read at most 32 bits at a time. We also have
-  // to avoid overflow when adding 4 to the on-wire golomb value, e.g., for evil
-  // input data, ReadExponentialGolomb might return 0xfffc.
+  // BitstreamReader::ReadBits, which can read at most 64 bits at a time. We
+  // also have to avoid overflow when adding 4 to the on-wire golomb value,
+  // e.g., for evil input data, ReadExponentialGolomb might return 0xfffc.
   const uint32_t kMaxLog2Minus4 = 32 - 4;
 
   // log2_max_frame_num_minus4: ue(v)
-  uint32_t log2_max_frame_num_minus4;
-  if (!buffer->ReadExponentialGolomb(log2_max_frame_num_minus4) ||
-      log2_max_frame_num_minus4 > kMaxLog2Minus4) {
-    return OptionalSps();
+  uint32_t log2_max_frame_num_minus4 = reader.ReadExponentialGolomb();
+  if (!reader.Ok() || log2_max_frame_num_minus4 > kMaxLog2Minus4) {
+    return absl::nullopt;
   }
   sps.log2_max_frame_num = log2_max_frame_num_minus4 + 4;
 
   // pic_order_cnt_type: ue(v)
-  RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(sps.pic_order_cnt_type));
+  sps.pic_order_cnt_type = reader.ReadExponentialGolomb();
   if (sps.pic_order_cnt_type == 0) {
     // log2_max_pic_order_cnt_lsb_minus4: ue(v)
-    uint32_t log2_max_pic_order_cnt_lsb_minus4;
-    if (!buffer->ReadExponentialGolomb(log2_max_pic_order_cnt_lsb_minus4) ||
-        log2_max_pic_order_cnt_lsb_minus4 > kMaxLog2Minus4) {
-      return OptionalSps();
+    uint32_t log2_max_pic_order_cnt_lsb_minus4 = reader.ReadExponentialGolomb();
+    if (!reader.Ok() || log2_max_pic_order_cnt_lsb_minus4 > kMaxLog2Minus4) {
+      return absl::nullopt;
     }
     sps.log2_max_pic_order_cnt_lsb = log2_max_pic_order_cnt_lsb_minus4 + 4;
   } else if (sps.pic_order_cnt_type == 1) {
     // delta_pic_order_always_zero_flag: u(1)
-    RETURN_EMPTY_ON_FAIL(
-        buffer->ReadBits(1, sps.delta_pic_order_always_zero_flag));
+    sps.delta_pic_order_always_zero_flag = reader.ReadBit();
     // offset_for_non_ref_pic: se(v)
-    RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(golomb_ignored));
+    reader.ReadExponentialGolomb();
     // offset_for_top_to_bottom_field: se(v)
-    RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(golomb_ignored));
+    reader.ReadExponentialGolomb();
     // num_ref_frames_in_pic_order_cnt_cycle: ue(v)
-    uint32_t num_ref_frames_in_pic_order_cnt_cycle;
-    RETURN_EMPTY_ON_FAIL(
-        buffer->ReadExponentialGolomb(num_ref_frames_in_pic_order_cnt_cycle));
+    uint32_t num_ref_frames_in_pic_order_cnt_cycle =
+        reader.ReadExponentialGolomb();
     for (size_t i = 0; i < num_ref_frames_in_pic_order_cnt_cycle; ++i) {
       // offset_for_ref_frame[i]: se(v)
-      RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(golomb_ignored));
+      reader.ReadExponentialGolomb();
+      if (!reader.Ok()) {
+        return absl::nullopt;
+      }
     }
   }
   // max_num_ref_frames: ue(v)
-  RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(sps.max_num_ref_frames));
+  sps.max_num_ref_frames = reader.ReadExponentialGolomb();
   // gaps_in_frame_num_value_allowed_flag: u(1)
-  RETURN_EMPTY_ON_FAIL(buffer->ConsumeBits(1));
+  reader.ConsumeBits(1);
   //
   // IMPORTANT ONES! Now we're getting to resolution. First we read the pic
   // width/height in macroblocks (16x16), which gives us the base resolution,
@@ -183,48 +164,41 @@
   // to signify resolutions that aren't multiples of 16.
   //
   // pic_width_in_mbs_minus1: ue(v)
-  uint32_t pic_width_in_mbs_minus1;
-  RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(pic_width_in_mbs_minus1));
+  sps.width = 16 * (reader.ReadExponentialGolomb() + 1);
   // pic_height_in_map_units_minus1: ue(v)
-  uint32_t pic_height_in_map_units_minus1;
-  RETURN_EMPTY_ON_FAIL(
-      buffer->ReadExponentialGolomb(pic_height_in_map_units_minus1));
+  uint32_t pic_height_in_map_units_minus1 = reader.ReadExponentialGolomb();
   // frame_mbs_only_flag: u(1)
-  RETURN_EMPTY_ON_FAIL(buffer->ReadBits(1, sps.frame_mbs_only_flag));
+  sps.frame_mbs_only_flag = reader.ReadBit();
   if (!sps.frame_mbs_only_flag) {
     // mb_adaptive_frame_field_flag: u(1)
-    RETURN_EMPTY_ON_FAIL(buffer->ConsumeBits(1));
+    reader.ConsumeBits(1);
   }
+  sps.height =
+      16 * (2 - sps.frame_mbs_only_flag) * (pic_height_in_map_units_minus1 + 1);
   // direct_8x8_inference_flag: u(1)
-  RETURN_EMPTY_ON_FAIL(buffer->ConsumeBits(1));
+  reader.ConsumeBits(1);
   //
   // MORE IMPORTANT ONES! Now we're at the frame crop information.
   //
-  // frame_cropping_flag: u(1)
-  uint32_t frame_cropping_flag;
   uint32_t frame_crop_left_offset = 0;
   uint32_t frame_crop_right_offset = 0;
   uint32_t frame_crop_top_offset = 0;
   uint32_t frame_crop_bottom_offset = 0;
-  RETURN_EMPTY_ON_FAIL(buffer->ReadBits(1, frame_cropping_flag));
-  if (frame_cropping_flag) {
+  // frame_cropping_flag: u(1)
+  if (reader.Read<bool>()) {
     // frame_crop_{left, right, top, bottom}_offset: ue(v)
-    RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(frame_crop_left_offset));
-    RETURN_EMPTY_ON_FAIL(
-        buffer->ReadExponentialGolomb(frame_crop_right_offset));
-    RETURN_EMPTY_ON_FAIL(buffer->ReadExponentialGolomb(frame_crop_top_offset));
-    RETURN_EMPTY_ON_FAIL(
-        buffer->ReadExponentialGolomb(frame_crop_bottom_offset));
+    frame_crop_left_offset = reader.ReadExponentialGolomb();
+    frame_crop_right_offset = reader.ReadExponentialGolomb();
+    frame_crop_top_offset = reader.ReadExponentialGolomb();
+    frame_crop_bottom_offset = reader.ReadExponentialGolomb();
   }
   // vui_parameters_present_flag: u(1)
-  RETURN_EMPTY_ON_FAIL(buffer->ReadBits(1, sps.vui_params_present));
+  sps.vui_params_present = reader.ReadBit();
 
   // Far enough! We don't use the rest of the SPS.
-
-  // Start with the resolution determined by the pic_width/pic_height fields.
-  sps.width = 16 * (pic_width_in_mbs_minus1 + 1);
-  sps.height =
-      16 * (2 - sps.frame_mbs_only_flag) * (pic_height_in_map_units_minus1 + 1);
+  if (!reader.Ok()) {
+    return absl::nullopt;
+  }
 
   // Figure out the crop units in pixels. That's based on the chroma format's
   // sampling, which is indicated by chroma_format_idc.
@@ -247,7 +221,7 @@
   sps.width -= (frame_crop_left_offset + frame_crop_right_offset);
   sps.height -= (frame_crop_top_offset + frame_crop_bottom_offset);
 
-  return OptionalSps(sps);
+  return sps;
 }
 
 }  // namespace webrtc
diff --git a/common_video/h264/sps_parser.h b/common_video/h264/sps_parser.h
index 76e627d..da328b4 100644
--- a/common_video/h264/sps_parser.h
+++ b/common_video/h264/sps_parser.h
@@ -12,10 +12,7 @@
 #define COMMON_VIDEO_H264_SPS_PARSER_H_
 
 #include "absl/types/optional.h"
-
-namespace rtc {
-class BitBuffer;
-}
+#include "rtc_base/bitstream_reader.h"
 
 namespace webrtc {
 
@@ -46,9 +43,9 @@
   static absl::optional<SpsState> ParseSps(const uint8_t* data, size_t length);
 
  protected:
-  // Parse the SPS state, up till the VUI part, for a bit buffer where RBSP
+  // Parse the SPS state, up till the VUI part, for a buffer where RBSP
   // decoding has already been performed.
-  static absl::optional<SpsState> ParseSpsUpToVui(rtc::BitBuffer* buffer);
+  static absl::optional<SpsState> ParseSpsUpToVui(BitstreamReader& reader);
 };
 
 }  // namespace webrtc
diff --git a/common_video/h264/sps_vui_rewriter.cc b/common_video/h264/sps_vui_rewriter.cc
index 856b012..117e92a 100644
--- a/common_video/h264/sps_vui_rewriter.cc
+++ b/common_video/h264/sps_vui_rewriter.cc
@@ -13,6 +13,7 @@
 
 #include <string.h>
 
+#include <algorithm>
 #include <cstdint>
 #include <vector>
 
@@ -20,9 +21,9 @@
 #include "common_video/h264/h264_common.h"
 #include "common_video/h264/sps_parser.h"
 #include "rtc_base/bit_buffer.h"
+#include "rtc_base/bitstream_reader.h"
 #include "rtc_base/checks.h"
 #include "rtc_base/logging.h"
-#include "rtc_base/numerics/safe_minmax.h"
 #include "system_wrappers/include/metrics.h"
 
 namespace webrtc {
@@ -53,46 +54,55 @@
     }                                                                  \
   } while (0)
 
-#define COPY_UINT8(src, dest, tmp)                   \
-  do {                                               \
-    RETURN_FALSE_ON_FAIL((src)->ReadUInt8(tmp));     \
-    if (dest)                                        \
-      RETURN_FALSE_ON_FAIL((dest)->WriteUInt8(tmp)); \
-  } while (0)
+uint8_t CopyUInt8(BitstreamReader& source, rtc::BitBufferWriter& destination) {
+  uint8_t tmp = source.Read<uint8_t>();
+  if (!destination.WriteUInt8(tmp)) {
+    source.Invalidate();
+  }
+  return tmp;
+}
 
-#define COPY_EXP_GOLOMB(src, dest, tmp)                          \
-  do {                                                           \
-    RETURN_FALSE_ON_FAIL((src)->ReadExponentialGolomb(tmp));     \
-    if (dest)                                                    \
-      RETURN_FALSE_ON_FAIL((dest)->WriteExponentialGolomb(tmp)); \
-  } while (0)
+uint32_t CopyExpGolomb(BitstreamReader& source,
+                       rtc::BitBufferWriter& destination) {
+  uint32_t tmp = source.ReadExponentialGolomb();
+  if (!destination.WriteExponentialGolomb(tmp)) {
+    source.Invalidate();
+  }
+  return tmp;
+}
 
-#define COPY_BITS(src, dest, tmp, bits)                   \
-  do {                                                    \
-    RETURN_FALSE_ON_FAIL((src)->ReadBits(bits, tmp));     \
-    if (dest)                                             \
-      RETURN_FALSE_ON_FAIL((dest)->WriteBits(tmp, bits)); \
-  } while (0)
+uint32_t CopyBits(int bits,
+                  BitstreamReader& source,
+                  rtc::BitBufferWriter& destination) {
+  RTC_DCHECK_GT(bits, 0);
+  RTC_DCHECK_LE(bits, 32);
+  uint64_t tmp = source.ReadBits(bits);
+  if (!destination.WriteBits(tmp, bits)) {
+    source.Invalidate();
+  }
+  return tmp;
+}
 
 bool CopyAndRewriteVui(const SpsParser::SpsState& sps,
-                       rtc::BitBuffer* source,
-                       rtc::BitBufferWriter* destination,
+                       BitstreamReader& source,
+                       rtc::BitBufferWriter& destination,
                        const webrtc::ColorSpace* color_space,
-                       SpsVuiRewriter::ParseResult* out_vui_rewritten);
-bool CopyHrdParameters(rtc::BitBuffer* source,
-                       rtc::BitBufferWriter* destination);
+                       SpsVuiRewriter::ParseResult& out_vui_rewritten);
+
+void CopyHrdParameters(BitstreamReader& source,
+                       rtc::BitBufferWriter& destination);
 bool AddBitstreamRestriction(rtc::BitBufferWriter* destination,
                              uint32_t max_num_ref_frames);
 bool IsDefaultColorSpace(const ColorSpace& color_space);
-bool AddVideoSignalTypeInfo(rtc::BitBufferWriter* destination,
+bool AddVideoSignalTypeInfo(rtc::BitBufferWriter& destination,
                             const ColorSpace& color_space);
 bool CopyOrRewriteVideoSignalTypeInfo(
-    rtc::BitBuffer* source,
-    rtc::BitBufferWriter* destination,
+    BitstreamReader& source,
+    rtc::BitBufferWriter& destination,
     const ColorSpace* color_space,
-    SpsVuiRewriter::ParseResult* out_vui_rewritten);
-bool CopyRemainingBits(rtc::BitBuffer* source,
-                       rtc::BitBufferWriter* destination);
+    SpsVuiRewriter::ParseResult& out_vui_rewritten);
+bool CopyRemainingBits(BitstreamReader& source,
+                       rtc::BitBufferWriter& destination);
 }  // namespace
 
 void SpsVuiRewriter::UpdateStats(ParseResult result, Direction direction) {
@@ -133,23 +143,25 @@
   // Create temporary RBSP decoded buffer of the payload (exlcuding the
   // leading nalu type header byte (the SpsParser uses only the payload).
   std::vector<uint8_t> rbsp_buffer = H264::ParseRbsp(buffer, length);
-  rtc::BitBuffer source_buffer(rbsp_buffer.data(), rbsp_buffer.size());
+  BitstreamReader source_buffer(rbsp_buffer);
   absl::optional<SpsParser::SpsState> sps_state =
-      SpsParser::ParseSpsUpToVui(&source_buffer);
+      SpsParser::ParseSpsUpToVui(source_buffer);
   if (!sps_state)
     return ParseResult::kFailure;
 
   *sps = sps_state;
 
-  // We're going to completely muck up alignment, so we need a BitBuffer to
-  // write with.
+  // We're going to completely muck up alignment, so we need a BitBufferWriter
+  // to write with.
   rtc::Buffer out_buffer(length + kMaxVuiSpsIncrease);
   rtc::BitBufferWriter sps_writer(out_buffer.data(), out_buffer.size());
 
   // Check how far the SpsParser has read, and copy that data in bulk.
-  size_t byte_offset;
-  size_t bit_offset;
-  source_buffer.GetCurrentOffset(&byte_offset, &bit_offset);
+  RTC_DCHECK(source_buffer.Ok());
+  size_t total_bit_offset =
+      rbsp_buffer.size() * 8 - source_buffer.RemainingBitCount();
+  size_t byte_offset = total_bit_offset / 8;
+  size_t bit_offset = total_bit_offset % 8;
   memcpy(out_buffer.data(), rbsp_buffer.data(),
          byte_offset + (bit_offset > 0 ? 1 : 0));  // OK to copy the last bits.
 
@@ -164,8 +176,8 @@
   sps_writer.Seek(byte_offset, bit_offset);
 
   ParseResult vui_updated;
-  if (!CopyAndRewriteVui(*sps_state, &source_buffer, &sps_writer, color_space,
-                         &vui_updated)) {
+  if (!CopyAndRewriteVui(*sps_state, source_buffer, sps_writer, color_space,
+                         vui_updated)) {
     RTC_LOG(LS_ERROR) << "Failed to parse/copy SPS VUI.";
     return ParseResult::kFailure;
   }
@@ -175,7 +187,7 @@
     return vui_updated;
   }
 
-  if (!CopyRemainingBits(&source_buffer, &sps_writer)) {
+  if (!CopyRemainingBits(source_buffer, sps_writer)) {
     RTC_LOG(LS_ERROR) << "Failed to parse/copy SPS VUI.";
     return ParseResult::kFailure;
   }
@@ -271,19 +283,16 @@
 
 namespace {
 bool CopyAndRewriteVui(const SpsParser::SpsState& sps,
-                       rtc::BitBuffer* source,
-                       rtc::BitBufferWriter* destination,
+                       BitstreamReader& source,
+                       rtc::BitBufferWriter& destination,
                        const webrtc::ColorSpace* color_space,
-                       SpsVuiRewriter::ParseResult* out_vui_rewritten) {
-  uint32_t golomb_tmp;
-  uint32_t bits_tmp;
-
-  *out_vui_rewritten = SpsVuiRewriter::ParseResult::kVuiOk;
+                       SpsVuiRewriter::ParseResult& out_vui_rewritten) {
+  out_vui_rewritten = SpsVuiRewriter::ParseResult::kVuiOk;
 
   //
   // vui_parameters_present_flag: u(1)
   //
-  RETURN_FALSE_ON_FAIL(destination->WriteBits(1, 1));
+  RETURN_FALSE_ON_FAIL(destination.WriteBits(1, 1));
 
   // ********* IMPORTANT! **********
   // Now we're at the VUI, so we want to (1) add it if it isn't present, and
@@ -292,154 +301,140 @@
     // Write a simple VUI with the parameters we want and 0 for all other flags.
 
     // aspect_ratio_info_present_flag, overscan_info_present_flag. Both u(1).
-    RETURN_FALSE_ON_FAIL(destination->WriteBits(0, 2));
+    RETURN_FALSE_ON_FAIL(destination.WriteBits(0, 2));
 
     uint32_t video_signal_type_present_flag =
         (color_space && !IsDefaultColorSpace(*color_space)) ? 1 : 0;
     RETURN_FALSE_ON_FAIL(
-        destination->WriteBits(video_signal_type_present_flag, 1));
+        destination.WriteBits(video_signal_type_present_flag, 1));
     if (video_signal_type_present_flag) {
       RETURN_FALSE_ON_FAIL(AddVideoSignalTypeInfo(destination, *color_space));
     }
     // chroma_loc_info_present_flag, timing_info_present_flag,
     // nal_hrd_parameters_present_flag, vcl_hrd_parameters_present_flag,
     // pic_struct_present_flag, All u(1)
-    RETURN_FALSE_ON_FAIL(destination->WriteBits(0, 5));
+    RETURN_FALSE_ON_FAIL(destination.WriteBits(0, 5));
     // bitstream_restriction_flag: u(1)
-    RETURN_FALSE_ON_FAIL(destination->WriteBits(1, 1));
+    RETURN_FALSE_ON_FAIL(destination.WriteBits(1, 1));
     RETURN_FALSE_ON_FAIL(
-        AddBitstreamRestriction(destination, sps.max_num_ref_frames));
+        AddBitstreamRestriction(&destination, sps.max_num_ref_frames));
 
-    *out_vui_rewritten = SpsVuiRewriter::ParseResult::kVuiRewritten;
+    out_vui_rewritten = SpsVuiRewriter::ParseResult::kVuiRewritten;
   } else {
     // Parse out the full VUI.
     // aspect_ratio_info_present_flag: u(1)
-    COPY_BITS(source, destination, bits_tmp, 1);
-    if (bits_tmp == 1) {
+    uint32_t aspect_ratio_info_present_flag = CopyBits(1, source, destination);
+    if (aspect_ratio_info_present_flag) {
       // aspect_ratio_idc: u(8)
-      COPY_BITS(source, destination, bits_tmp, 8);
-      if (bits_tmp == 255u) {  // Extended_SAR
+      uint8_t aspect_ratio_idc = CopyUInt8(source, destination);
+      if (aspect_ratio_idc == 255u) {  // Extended_SAR
         // sar_width/sar_height: u(16) each.
-        COPY_BITS(source, destination, bits_tmp, 32);
+        CopyBits(32, source, destination);
       }
     }
     // overscan_info_present_flag: u(1)
-    COPY_BITS(source, destination, bits_tmp, 1);
-    if (bits_tmp == 1) {
+    uint32_t overscan_info_present_flag = CopyBits(1, source, destination);
+    if (overscan_info_present_flag) {
       // overscan_appropriate_flag: u(1)
-      COPY_BITS(source, destination, bits_tmp, 1);
+      CopyBits(1, source, destination);
     }
 
     CopyOrRewriteVideoSignalTypeInfo(source, destination, color_space,
                                      out_vui_rewritten);
 
     // chroma_loc_info_present_flag: u(1)
-    COPY_BITS(source, destination, bits_tmp, 1);
-    if (bits_tmp == 1) {
+    uint32_t chroma_loc_info_present_flag = CopyBits(1, source, destination);
+    if (chroma_loc_info_present_flag == 1) {
       // chroma_sample_loc_type_(top|bottom)_field: ue(v) each.
-      COPY_EXP_GOLOMB(source, destination, golomb_tmp);
-      COPY_EXP_GOLOMB(source, destination, golomb_tmp);
+      CopyExpGolomb(source, destination);
+      CopyExpGolomb(source, destination);
     }
     // timing_info_present_flag: u(1)
-    COPY_BITS(source, destination, bits_tmp, 1);
-    if (bits_tmp == 1) {
+    uint32_t timing_info_present_flag = CopyBits(1, source, destination);
+    if (timing_info_present_flag == 1) {
       // num_units_in_tick, time_scale: u(32) each
-      COPY_BITS(source, destination, bits_tmp, 32);
-      COPY_BITS(source, destination, bits_tmp, 32);
+      CopyBits(32, source, destination);
+      CopyBits(32, source, destination);
       // fixed_frame_rate_flag: u(1)
-      COPY_BITS(source, destination, bits_tmp, 1);
+      CopyBits(1, source, destination);
     }
     // nal_hrd_parameters_present_flag: u(1)
-    uint32_t nal_hrd_parameters_present_flag;
-    COPY_BITS(source, destination, nal_hrd_parameters_present_flag, 1);
+    uint32_t nal_hrd_parameters_present_flag = CopyBits(1, source, destination);
     if (nal_hrd_parameters_present_flag == 1) {
-      RETURN_FALSE_ON_FAIL(CopyHrdParameters(source, destination));
+      CopyHrdParameters(source, destination);
     }
     // vcl_hrd_parameters_present_flag: u(1)
-    uint32_t vcl_hrd_parameters_present_flag;
-    COPY_BITS(source, destination, vcl_hrd_parameters_present_flag, 1);
+    uint32_t vcl_hrd_parameters_present_flag = CopyBits(1, source, destination);
     if (vcl_hrd_parameters_present_flag == 1) {
-      RETURN_FALSE_ON_FAIL(CopyHrdParameters(source, destination));
+      CopyHrdParameters(source, destination);
     }
     if (nal_hrd_parameters_present_flag == 1 ||
         vcl_hrd_parameters_present_flag == 1) {
       // low_delay_hrd_flag: u(1)
-      COPY_BITS(source, destination, bits_tmp, 1);
+      CopyBits(1, source, destination);
     }
     // pic_struct_present_flag: u(1)
-    COPY_BITS(source, destination, bits_tmp, 1);
+    CopyBits(1, source, destination);
 
     // bitstream_restriction_flag: u(1)
-    uint32_t bitstream_restriction_flag;
-    RETURN_FALSE_ON_FAIL(source->ReadBits(1, bitstream_restriction_flag));
-    RETURN_FALSE_ON_FAIL(destination->WriteBits(1, 1));
+    uint32_t bitstream_restriction_flag = source.ReadBit();
+    RETURN_FALSE_ON_FAIL(destination.WriteBits(1, 1));
     if (bitstream_restriction_flag == 0) {
       // We're adding one from scratch.
       RETURN_FALSE_ON_FAIL(
-          AddBitstreamRestriction(destination, sps.max_num_ref_frames));
-      *out_vui_rewritten = SpsVuiRewriter::ParseResult::kVuiRewritten;
+          AddBitstreamRestriction(&destination, sps.max_num_ref_frames));
+      out_vui_rewritten = SpsVuiRewriter::ParseResult::kVuiRewritten;
     } else {
       // We're replacing.
       // motion_vectors_over_pic_boundaries_flag: u(1)
-      COPY_BITS(source, destination, bits_tmp, 1);
+      CopyBits(1, source, destination);
       // max_bytes_per_pic_denom: ue(v)
-      COPY_EXP_GOLOMB(source, destination, golomb_tmp);
+      CopyExpGolomb(source, destination);
       // max_bits_per_mb_denom: ue(v)
-      COPY_EXP_GOLOMB(source, destination, golomb_tmp);
+      CopyExpGolomb(source, destination);
       // log2_max_mv_length_horizontal: ue(v)
-      COPY_EXP_GOLOMB(source, destination, golomb_tmp);
+      CopyExpGolomb(source, destination);
       // log2_max_mv_length_vertical: ue(v)
-      COPY_EXP_GOLOMB(source, destination, golomb_tmp);
+      CopyExpGolomb(source, destination);
       // ********* IMPORTANT! **********
       // The next two are the ones we need to set to low numbers:
       // max_num_reorder_frames: ue(v)
       // max_dec_frame_buffering: ue(v)
       // However, if they are already set to no greater than the numbers we
       // want, then we don't need to be rewriting.
-      uint32_t max_num_reorder_frames, max_dec_frame_buffering;
+      uint32_t max_num_reorder_frames = source.ReadExponentialGolomb();
+      uint32_t max_dec_frame_buffering = source.ReadExponentialGolomb();
+      RETURN_FALSE_ON_FAIL(destination.WriteExponentialGolomb(0));
       RETURN_FALSE_ON_FAIL(
-          source->ReadExponentialGolomb(max_num_reorder_frames));
-      RETURN_FALSE_ON_FAIL(
-          source->ReadExponentialGolomb(max_dec_frame_buffering));
-      RETURN_FALSE_ON_FAIL(destination->WriteExponentialGolomb(0));
-      RETURN_FALSE_ON_FAIL(
-          destination->WriteExponentialGolomb(sps.max_num_ref_frames));
+          destination.WriteExponentialGolomb(sps.max_num_ref_frames));
       if (max_num_reorder_frames != 0 ||
           max_dec_frame_buffering > sps.max_num_ref_frames) {
-        *out_vui_rewritten = SpsVuiRewriter::ParseResult::kVuiRewritten;
+        out_vui_rewritten = SpsVuiRewriter::ParseResult::kVuiRewritten;
       }
     }
   }
-  return true;
+  return source.Ok();
 }
 
 // Copies a VUI HRD parameters segment.
-bool CopyHrdParameters(rtc::BitBuffer* source,
-                       rtc::BitBufferWriter* destination) {
-  uint32_t golomb_tmp;
-  uint32_t bits_tmp;
-
+void CopyHrdParameters(BitstreamReader& source,
+                       rtc::BitBufferWriter& destination) {
   // cbp_cnt_minus1: ue(v)
-  uint32_t cbp_cnt_minus1;
-  COPY_EXP_GOLOMB(source, destination, cbp_cnt_minus1);
+  uint32_t cbp_cnt_minus1 = CopyExpGolomb(source, destination);
   // bit_rate_scale and cbp_size_scale: u(4) each
-  COPY_BITS(source, destination, bits_tmp, 8);
-  for (size_t i = 0; i <= cbp_cnt_minus1; ++i) {
+  CopyBits(8, source, destination);
+  for (size_t i = 0; source.Ok() && i <= cbp_cnt_minus1; ++i) {
     // bit_rate_value_minus1 and cbp_size_value_minus1: ue(v) each
-    COPY_EXP_GOLOMB(source, destination, golomb_tmp);
-    COPY_EXP_GOLOMB(source, destination, golomb_tmp);
+    CopyExpGolomb(source, destination);
+    CopyExpGolomb(source, destination);
     // cbr_flag: u(1)
-    COPY_BITS(source, destination, bits_tmp, 1);
+    CopyBits(1, source, destination);
   }
   // initial_cbp_removal_delay_length_minus1: u(5)
-  COPY_BITS(source, destination, bits_tmp, 5);
   // cbp_removal_delay_length_minus1: u(5)
-  COPY_BITS(source, destination, bits_tmp, 5);
   // dbp_output_delay_length_minus1: u(5)
-  COPY_BITS(source, destination, bits_tmp, 5);
   // time_offset_length: u(5)
-  COPY_BITS(source, destination, bits_tmp, 5);
-  return true;
+  CopyBits(5 * 4, source, destination);
 }
 
 // These functions are similar to webrtc::H264SpsParser::Parse, and based on the
@@ -479,51 +474,51 @@
          color_space.matrix() == ColorSpace::MatrixID::kUnspecified;
 }
 
-bool AddVideoSignalTypeInfo(rtc::BitBufferWriter* destination,
+bool AddVideoSignalTypeInfo(rtc::BitBufferWriter& destination,
                             const ColorSpace& color_space) {
   // video_format: u(3).
-  RETURN_FALSE_ON_FAIL(destination->WriteBits(5, 3));  // 5 = Unspecified
+  RETURN_FALSE_ON_FAIL(destination.WriteBits(5, 3));  // 5 = Unspecified
   // video_full_range_flag: u(1)
-  RETURN_FALSE_ON_FAIL(destination->WriteBits(
+  RETURN_FALSE_ON_FAIL(destination.WriteBits(
       color_space.range() == ColorSpace::RangeID::kFull ? 1 : 0, 1));
   // colour_description_present_flag: u(1)
-  RETURN_FALSE_ON_FAIL(destination->WriteBits(1, 1));
+  RETURN_FALSE_ON_FAIL(destination.WriteBits(1, 1));
   // colour_primaries: u(8)
   RETURN_FALSE_ON_FAIL(
-      destination->WriteUInt8(static_cast<uint8_t>(color_space.primaries())));
+      destination.WriteUInt8(static_cast<uint8_t>(color_space.primaries())));
   // transfer_characteristics: u(8)
   RETURN_FALSE_ON_FAIL(
-      destination->WriteUInt8(static_cast<uint8_t>(color_space.transfer())));
+      destination.WriteUInt8(static_cast<uint8_t>(color_space.transfer())));
   // matrix_coefficients: u(8)
   RETURN_FALSE_ON_FAIL(
-      destination->WriteUInt8(static_cast<uint8_t>(color_space.matrix())));
+      destination.WriteUInt8(static_cast<uint8_t>(color_space.matrix())));
   return true;
 }
 
 bool CopyOrRewriteVideoSignalTypeInfo(
-    rtc::BitBuffer* source,
-    rtc::BitBufferWriter* destination,
+    BitstreamReader& source,
+    rtc::BitBufferWriter& destination,
     const ColorSpace* color_space,
-    SpsVuiRewriter::ParseResult* out_vui_rewritten) {
+    SpsVuiRewriter::ParseResult& out_vui_rewritten) {
   // Read.
-  uint32_t video_signal_type_present_flag;
   uint32_t video_format = 5;           // H264 default: unspecified
   uint32_t video_full_range_flag = 0;  // H264 default: limited
   uint32_t colour_description_present_flag = 0;
   uint8_t colour_primaries = 3;          // H264 default: unspecified
   uint8_t transfer_characteristics = 3;  // H264 default: unspecified
   uint8_t matrix_coefficients = 3;       // H264 default: unspecified
-  RETURN_FALSE_ON_FAIL(source->ReadBits(1, video_signal_type_present_flag));
+  uint32_t video_signal_type_present_flag = source.ReadBit();
   if (video_signal_type_present_flag) {
-    RETURN_FALSE_ON_FAIL(source->ReadBits(3, video_format));
-    RETURN_FALSE_ON_FAIL(source->ReadBits(1, video_full_range_flag));
-    RETURN_FALSE_ON_FAIL(source->ReadBits(1, colour_description_present_flag));
+    video_format = source.ReadBits(3);
+    video_full_range_flag = source.ReadBit();
+    colour_description_present_flag = source.ReadBit();
     if (colour_description_present_flag) {
-      RETURN_FALSE_ON_FAIL(source->ReadUInt8(colour_primaries));
-      RETURN_FALSE_ON_FAIL(source->ReadUInt8(transfer_characteristics));
-      RETURN_FALSE_ON_FAIL(source->ReadUInt8(matrix_coefficients));
+      colour_primaries = source.Read<uint8_t>();
+      transfer_characteristics = source.Read<uint8_t>();
+      matrix_coefficients = source.Read<uint8_t>();
     }
   }
+  RETURN_FALSE_ON_FAIL(source.Ok());
 
   // Update.
   uint32_t video_signal_type_present_flag_override =
@@ -564,19 +559,19 @@
 
   // Write.
   RETURN_FALSE_ON_FAIL(
-      destination->WriteBits(video_signal_type_present_flag_override, 1));
+      destination.WriteBits(video_signal_type_present_flag_override, 1));
   if (video_signal_type_present_flag_override) {
-    RETURN_FALSE_ON_FAIL(destination->WriteBits(video_format_override, 3));
+    RETURN_FALSE_ON_FAIL(destination.WriteBits(video_format_override, 3));
     RETURN_FALSE_ON_FAIL(
-        destination->WriteBits(video_full_range_flag_override, 1));
+        destination.WriteBits(video_full_range_flag_override, 1));
     RETURN_FALSE_ON_FAIL(
-        destination->WriteBits(colour_description_present_flag_override, 1));
+        destination.WriteBits(colour_description_present_flag_override, 1));
     if (colour_description_present_flag_override) {
-      RETURN_FALSE_ON_FAIL(destination->WriteUInt8(colour_primaries_override));
+      RETURN_FALSE_ON_FAIL(destination.WriteUInt8(colour_primaries_override));
       RETURN_FALSE_ON_FAIL(
-          destination->WriteUInt8(transfer_characteristics_override));
+          destination.WriteUInt8(transfer_characteristics_override));
       RETURN_FALSE_ON_FAIL(
-          destination->WriteUInt8(matrix_coefficients_override));
+          destination.WriteUInt8(matrix_coefficients_override));
     }
   }
 
@@ -589,27 +584,26 @@
       colour_primaries_override != colour_primaries ||
       transfer_characteristics_override != transfer_characteristics ||
       matrix_coefficients_override != matrix_coefficients) {
-    *out_vui_rewritten = SpsVuiRewriter::ParseResult::kVuiRewritten;
+    out_vui_rewritten = SpsVuiRewriter::ParseResult::kVuiRewritten;
   }
 
   return true;
 }
 
-bool CopyRemainingBits(rtc::BitBuffer* source,
-                       rtc::BitBufferWriter* destination) {
-  uint32_t bits_tmp;
+bool CopyRemainingBits(BitstreamReader& source,
+                       rtc::BitBufferWriter& destination) {
   // Try to get at least the destination aligned.
-  if (source->RemainingBitCount() > 0 && source->RemainingBitCount() % 8 != 0) {
-    size_t misaligned_bits = source->RemainingBitCount() % 8;
-    COPY_BITS(source, destination, bits_tmp, misaligned_bits);
+  if (source.RemainingBitCount() > 0 && source.RemainingBitCount() % 8 != 0) {
+    size_t misaligned_bits = source.RemainingBitCount() % 8;
+    CopyBits(misaligned_bits, source, destination);
   }
-  while (source->RemainingBitCount() > 0) {
-    auto count = rtc::SafeMin<size_t>(32u, source->RemainingBitCount());
-    COPY_BITS(source, destination, bits_tmp, count);
+  while (source.RemainingBitCount() > 0) {
+    int count = std::min(32, source.RemainingBitCount());
+    CopyBits(count, source, destination);
   }
   // TODO(noahric): The last byte could be all zeroes now, which we should just
   // strip.
-  return true;
+  return source.Ok();
 }
 
 }  // namespace