From 166924427ea0cd01c28975cefedc1b5aa1949c7d Mon Sep 17 00:00:00 2001
From: Cameron Gutman <aicommander@gmail.com>
Date: Fri, 11 Aug 2023 22:17:38 -0500
Subject: [PATCH] Use existing H.264 SPS as a template rather than building it
 from scratch

---
 src/cbs.cpp | 136 +++++++++++++++-------------------------------------
 1 file changed, 38 insertions(+), 98 deletions(-)

diff --git a/src/cbs.cpp b/src/cbs.cpp
index c1d9df60..c1578825 100644
--- a/src/cbs.cpp
+++ b/src/cbs.cpp
@@ -87,92 +87,61 @@ namespace cbs {
     return write(cbs_ctx, nal, uh, codec_id);
   }
 
-  util::buffer_t<std::uint8_t>
-  make_sps_h264(const AVCodecContext *ctx) {
-    H264RawSPS sps {};
-
-    // b_per_p == ctx->max_b_frames for h264
-    // desired_b_depth == avoption("b_depth") == 1
-    // max_b_depth == std::min(av_log2(ctx->b_per_p) + 1, desired_b_depth) ==> 1
-    auto max_b_depth = 1;
-    auto dpb_frame = ctx->gop_size == 1 ? 0 : 1 + max_b_depth;
-    auto mb_width = (FFALIGN(ctx->width, 16) / 16) * 16;
-    auto mb_height = (FFALIGN(ctx->height, 16) / 16) * 16;
-
-    sps.nal_unit_header.nal_ref_idc = 3;
-    sps.nal_unit_header.nal_unit_type = H264_NAL_SPS;
-
-    sps.profile_idc = FF_PROFILE_H264_HIGH & 0xFF;
-
-    sps.constraint_set1_flag = 1;
-
-    if (ctx->level != FF_LEVEL_UNKNOWN) {
-      sps.level_idc = ctx->level;
-    }
-    else {
-      auto framerate = ctx->framerate;
-
-      auto level = ff_h264_guess_level(
-        sps.profile_idc,
-        ctx->bit_rate,
-        framerate.num / framerate.den,
-        mb_width,
-        mb_height,
-        dpb_frame);
-
-      if (!level) {
-        BOOST_LOG(error) << "Could not guess h264 level"sv;
-
-        return {};
-      }
-      sps.level_idc = level->level_idc;
+  h264_t
+  make_sps_h264(const AVCodecContext *avctx, const AVPacket *packet) {
+    cbs::ctx_t ctx;
+    if (ff_cbs_init(&ctx, AV_CODEC_ID_H264, nullptr)) {
+      return {};
     }
 
-    sps.seq_parameter_set_id = 0;
-    sps.chroma_format_idc = 1;
+    cbs::frag_t frag;
 
-    sps.log2_max_frame_num_minus4 = 3;  // 4;
-    sps.pic_order_cnt_type = 0;
-    sps.log2_max_pic_order_cnt_lsb_minus4 = 0;  // 4;
+    int err = ff_cbs_read_packet(ctx.get(), &frag, &*packet);
+    if (err < 0) {
+      char err_str[AV_ERROR_MAX_STRING_SIZE] { 0 };
+      BOOST_LOG(error) << "Couldn't read packet: "sv << av_make_error_string(err_str, AV_ERROR_MAX_STRING_SIZE, err);
 
-    sps.max_num_ref_frames = dpb_frame;
-
-    sps.pic_width_in_mbs_minus1 = mb_width / 16 - 1;
-    sps.pic_height_in_map_units_minus1 = mb_height / 16 - 1;
-
-    sps.frame_mbs_only_flag = 1;
-    sps.direct_8x8_inference_flag = 1;
-
-    if (ctx->width != mb_width || ctx->height != mb_height) {
-      sps.frame_cropping_flag = 1;
-      sps.frame_crop_left_offset = 0;
-      sps.frame_crop_top_offset = 0;
-      sps.frame_crop_right_offset = (mb_width - ctx->width) / 2;
-      sps.frame_crop_bottom_offset = (mb_height - ctx->height) / 2;
+      return {};
     }
 
-    sps.vui_parameters_present_flag = 1;
+    auto sps_p = ((CodedBitstreamH264Context *) ctx->priv_data)->active_sps;
 
-    auto &vui = sps.vui;
+    // This is a very large struct that cannot safely be stored on the stack
+    auto sps = std::make_unique<H264RawSPS>(*sps_p);
+
+    if (avctx->refs > 0) {
+      sps->max_num_ref_frames = avctx->refs;
+    }
+
+    sps->vui_parameters_present_flag = 1;
+
+    auto &vui = sps->vui;
+    std::memset(&vui, 0, sizeof(vui));
 
     vui.video_format = 5;
     vui.colour_description_present_flag = 1;
     vui.video_signal_type_present_flag = 1;
-    vui.video_full_range_flag = ctx->color_range == AVCOL_RANGE_JPEG;
-    vui.colour_primaries = ctx->color_primaries;
-    vui.transfer_characteristics = ctx->color_trc;
-    vui.matrix_coefficients = ctx->colorspace;
+    vui.video_full_range_flag = avctx->color_range == AVCOL_RANGE_JPEG;
+    vui.colour_primaries = avctx->color_primaries;
+    vui.transfer_characteristics = avctx->color_trc;
+    vui.matrix_coefficients = avctx->colorspace;
 
     vui.low_delay_hrd_flag = 1 - vui.fixed_frame_rate_flag;
 
     vui.bitstream_restriction_flag = 1;
     vui.motion_vectors_over_pic_boundaries_flag = 1;
-    vui.log2_max_mv_length_horizontal = 15;
-    vui.log2_max_mv_length_vertical = 15;
-    vui.max_num_reorder_frames = max_b_depth;
-    vui.max_dec_frame_buffering = max_b_depth + 1;
+    vui.log2_max_mv_length_horizontal = 16;
+    vui.log2_max_mv_length_vertical = 16;
+    vui.max_num_reorder_frames = 0;
+    vui.max_dec_frame_buffering = sps->max_num_ref_frames;
 
-    return write(sps.nal_unit_header.nal_unit_type, (void *) &sps.nal_unit_header, AV_CODEC_ID_H264);
+    cbs::ctx_t write_ctx;
+    ff_cbs_init(&write_ctx, AV_CODEC_ID_H264, nullptr);
+
+    return h264_t {
+      write(write_ctx, sps->nal_unit_header.nal_unit_type, (void *) &sps->nal_unit_header, AV_CODEC_ID_H264),
+      write(ctx, sps_p->nal_unit_header.nal_unit_type, (void *) &sps_p->nal_unit_header, AV_CODEC_ID_H264)
+    };
   }
 
   hevc_t
@@ -248,35 +217,6 @@ namespace cbs {
     };
   }
 
-  util::buffer_t<std::uint8_t>
-  read_sps_h264(const AVPacket *packet) {
-    cbs::ctx_t ctx;
-    if (ff_cbs_init(&ctx, AV_CODEC_ID_H264, nullptr)) {
-      return {};
-    }
-
-    cbs::frag_t frag;
-
-    int err = ff_cbs_read_packet(ctx.get(), &frag, &*packet);
-    if (err < 0) {
-      char err_str[AV_ERROR_MAX_STRING_SIZE] { 0 };
-      BOOST_LOG(error) << "Couldn't read packet: "sv << av_make_error_string(err_str, AV_ERROR_MAX_STRING_SIZE, err);
-
-      return {};
-    }
-
-    auto h264 = (H264RawNALUnitHeader *) ((CodedBitstreamH264Context *) ctx->priv_data)->active_sps;
-    return write(h264->nal_unit_type, (void *) h264, AV_CODEC_ID_H264);
-  }
-
-  h264_t
-  make_sps_h264(const AVCodecContext *ctx, const AVPacket *packet) {
-    return h264_t {
-      make_sps_h264(ctx),
-      read_sps_h264(packet),
-    };
-  }
-
   bool
   validate_sps(const AVPacket *packet, int codec_id) {
     cbs::ctx_t ctx;