Allow enabling support for the lightweight report format without breaking support for non-lightweight tasks

Without this change, the client can only handle lightweight tasks when the enable_lightweight_client_report_wire_format flag is enabled. This change adds a separate flag, untie_lw_client_report_format_support_from_requiring_lw_report, that allows enable_lightweight_client_report_wire_format to be true without breaking support for non-lightweight tasks.

PiperOrigin-RevId: 638668272
diff --git a/fcp/client/fl_runner.cc b/fcp/client/fl_runner.cc
index 8b151d4..f039b9a 100644
--- a/fcp/client/fl_runner.cc
+++ b/fcp/client/fl_runner.cc
@@ -225,23 +225,51 @@
     }
   }
 
-  if (flags->enable_lightweight_client_report_wire_format()) {
-    // Reads string from federated_compute_checkpoint
-    if (plan_result.federated_compute_checkpoint.empty()) {
-      return absl::InvalidArgumentError("Empty federate compute checkpoint");
+  if (!flags
+           ->untie_lw_client_report_format_support_from_requiring_lw_report()) {
+    // Old buggy behavior
+    if (flags->enable_lightweight_client_report_wire_format()) {
+      // Reads string from federated_compute_checkpoint
+      if (plan_result.federated_compute_checkpoint.empty()) {
+        return absl::InvalidArgumentError("Empty federate compute checkpoint");
+      }
+      computation_results[kFederatedComputeCheckpoint] =
+          std::move(plan_result.federated_compute_checkpoint);
+    } else {
+      // Name of the TF checkpoint inside the aggregand map in the Checkpoint
+      // protobuf. This field name is ignored by the server.
+      if (!checkpoint_filename.empty()) {
+        FCP_ASSIGN_OR_RETURN(std::string tf_checkpoint,
+                             fcp::ReadFileToString(checkpoint_filename));
+        computation_results[std::string(kTensorflowCheckpointAggregand)] =
+            std::move(tf_checkpoint);
+      }
     }
-    computation_results[kFederatedComputeCheckpoint] =
-        std::move(plan_result.federated_compute_checkpoint);
   } else {
-    // Name of the TF checkpoint inside the aggregand map in the Checkpoint
-    // protobuf. This field name is ignored by the server.
-    if (!checkpoint_filename.empty()) {
+    if (!plan_result.federated_compute_checkpoint.empty()) {
+      if (flags->enable_lightweight_client_report_wire_format()) {
+        // Task produced a lightweight report, and the feature is enabled.
+        computation_results[kFederatedComputeCheckpoint] =
+            std::move(plan_result.federated_compute_checkpoint);
+      } else {
+        // Task produced a lightweight report, but the feature is disabled.
+        return absl::InternalError(
+            "Lightweight report produced but lightweight report feature is "
+            "disabled");
+      }
+    } else if (!checkpoint_filename.empty()) {
+      // Name of the TF checkpoint inside the aggregand map in the Checkpoint
+      // protobuf. This field name is ignored by the server.
       FCP_ASSIGN_OR_RETURN(std::string tf_checkpoint,
                            fcp::ReadFileToString(checkpoint_filename));
       computation_results[std::string(kTensorflowCheckpointAggregand)] =
           std::move(tf_checkpoint);
+    } else {
+      // No lightweight report produced, and no TF checkpoint produced. For this
+      // computation, all outputs are aggregated with secagg.
     }
   }
+
   return computation_results;
 }
 
diff --git a/fcp/client/flags.h b/fcp/client/flags.h
index c63aba2..506886c 100644
--- a/fcp/client/flags.h
+++ b/fcp/client/flags.h
@@ -174,6 +174,16 @@
     return false;
   }
 
+  // Bugfix switch for the lightweight client report wire format. When true,
+  // turning on `enable_lightweight_client_report_wire_format` only allows a
+  // task to produce a lightweight client report instead of requiring one, so
+  // tasks that still produce a TF checkpoint keep working. Defaults to false,
+  // which keeps the legacy behavior.
+  virtual bool
+  untie_lw_client_report_format_support_from_requiring_lw_report() const {
+    return false;
+  }
+
   // If true, OpStats logger enables PhaseStats logging.
   virtual bool enable_phase_stats_logging() const { return false; }
 
diff --git a/fcp/client/http/http_federated_protocol.cc b/fcp/client/http/http_federated_protocol.cc
index ef5be61..6371596 100644
--- a/fcp/client/http/http_federated_protocol.cc
+++ b/fcp/client/http/http_federated_protocol.cc
@@ -1176,18 +1176,31 @@
                      " aggregands have unexpected results size."));
   }
   auto result = std::move(results.begin()->second);
+  bool untie_lw_client_report_format_support_from_requiring_lw_report =
+      flags_->untie_lw_client_report_format_support_from_requiring_lw_report();
   bool enable_lightweight_client_report_wire_format =
       flags_->enable_lightweight_client_report_wire_format();
-  if (!enable_lightweight_client_report_wire_format &&
-      !std::holds_alternative<TFCheckpoint>(result)) {
-    return absl::InternalError(absl::StrCat(
-        aggregation_type_readable, " aggregands have unexpected format."));
-  }
-  if (enable_lightweight_client_report_wire_format &&
-      !std::holds_alternative<FCCheckpoint>(result)) {
-    return absl::InternalError(
-        absl::StrCat(aggregation_type_readable,
-                     " aggregands have unexpected format for FC Wire Format."));
+  if (!untie_lw_client_report_format_support_from_requiring_lw_report) {
+    // Old incorrect behavior.
+    if (!enable_lightweight_client_report_wire_format &&
+        !std::holds_alternative<TFCheckpoint>(result)) {
+      return absl::InternalError(absl::StrCat(
+          aggregation_type_readable, " aggregands have unexpected format."));
+    }
+    if (enable_lightweight_client_report_wire_format &&
+        !std::holds_alternative<FCCheckpoint>(result)) {
+      return absl::InternalError(absl::StrCat(
+          aggregation_type_readable,
+          " aggregands have unexpected format for FC Wire Format."));
+    }
+  } else {
+    if (!enable_lightweight_client_report_wire_format &&
+        std::holds_alternative<FCCheckpoint>(result)) {
+      return absl::InternalError(absl::StrCat(
+          aggregation_type_readable,
+          " computation produced FC Wire Format but this feature is "
+          "not enabled."));
+    }
   }
   auto start_upload_status = HandleStartDataAggregationUploadOperationResponse(
       PerformStartDataUploadRequestAndReportTaskResult(plan_duration,
@@ -1206,12 +1219,25 @@
       << aggregation_type_readable;
 
   std::string result_data;
-  if (enable_lightweight_client_report_wire_format) {
-    // TODO: b/300128447 - avoid copying serialized checkpoint once http
-    // federated protocol supports absl::Cord
-    absl::CopyCordToString(std::get<FCCheckpoint>(result), &result_data);
+  if (!untie_lw_client_report_format_support_from_requiring_lw_report) {
+    if (enable_lightweight_client_report_wire_format) {
+      // TODO: b/300128447 - avoid copying serialized checkpoint once http
+      // federated protocol supports absl::Cord
+      absl::CopyCordToString(std::get<FCCheckpoint>(result), &result_data);
+    } else {
+      result_data = std::get<TFCheckpoint>(result);
+    }
   } else {
-    result_data = std::get<TFCheckpoint>(result);
+    bool should_report_lightweight_client_report_wire_format =
+        enable_lightweight_client_report_wire_format &&
+        std::holds_alternative<FCCheckpoint>(result);
+    if (should_report_lightweight_client_report_wire_format) {
+      // TODO: b/300128447 - avoid copying serialized checkpoint once http
+      // federated protocol supports absl::Cord
+      absl::CopyCordToString(std::get<FCCheckpoint>(result), &result_data);
+    } else {
+      result_data = std::get<TFCheckpoint>(result);
+    }
   }
 
   std::string data_to_upload;
diff --git a/fcp/client/http/http_federated_protocol_test.cc b/fcp/client/http/http_federated_protocol_test.cc
index a560adb..4722bc0 100644
--- a/fcp/client/http/http_federated_protocol_test.cc
+++ b/fcp/client/http/http_federated_protocol_test.cc
@@ -3035,7 +3035,7 @@
 }
 
 TEST_F(HttpFederatedProtocolTest,
-       TestReportCompletedViaSimpleAggSuccessWithFCWireFormat) {
+       TestReportCompletedViaSimpleAggSuccessWithFCWireFormatBugFixDisabled) {
   // Issue an eligibility eval checkin first.
   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
   // Issue a regular checkin.
@@ -3043,6 +3043,52 @@
   // Enables enable_lightweight_client_report_wire_format flag.
   EXPECT_CALL(mock_flags_, enable_lightweight_client_report_wire_format())
       .WillRepeatedly(Return(true));
+  // Reporting a lightweight result should still work even if the bugfix is
+  // disabled.
+  EXPECT_CALL(mock_flags_,
+              untie_lw_client_report_format_support_from_requiring_lw_report())
+      .WillRepeatedly(Return(false));
+  // Create a fake lightweight result format with 32 'X'.
+  std::string checkpoint_str(32, 'X');
+  absl::Cord checkpoint_cord(checkpoint_str);
+  ComputationResults results;
+  results.emplace("fc_checkpoint", checkpoint_cord);
+  absl::Duration plan_duration = absl::Minutes(5);
+
+  ExpectSuccessfulReportTaskResultRequest(
+      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+      kAggregationSessionId, kTaskName, plan_duration);
+  ExpectSuccessfulStartAggregationDataUploadRequest(
+      "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
+      "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
+      kResourceName, kByteStreamTargetUri, kSecondStageAggregationTargetUri);
+  ExpectSuccessfulByteStreamUploadRequest(
+      "https://bytestream.uri/upload/v1/media/"
+      "CHECKPOINT_RESOURCE?upload_protocol=raw",
+      checkpoint_str);
+  ExpectSuccessfulSubmitAggregationResultRequest(
+      "https://aggregation.second.uri/v1/aggregations/"
+      "AGGREGATION_SESSION_ID/clients/CLIENT_TOKEN:submit?%24alt=proto");
+
+  EXPECT_OK(federated_protocol_->ReportCompleted(std::move(results),
+                                                 plan_duration, std::nullopt));
+}
+
+TEST_F(HttpFederatedProtocolTest,
+       TestReportCompletedViaSimpleAggSuccessWithFCWireFormatBugFixEnabled) {
+  // Issue an eligibility eval checkin first.
+  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+  // Issue a regular checkin.
+  ASSERT_OK(RunSuccessfulCheckin());
+  // Enables enable_lightweight_client_report_wire_format flag.
+  EXPECT_CALL(mock_flags_, enable_lightweight_client_report_wire_format())
+      .WillRepeatedly(Return(true));
+  // Reporting a lightweight result should still work when the bugfix is
+  // enabled.
+  EXPECT_CALL(mock_flags_,
+              untie_lw_client_report_format_support_from_requiring_lw_report())
+      .WillRepeatedly(Return(true));
   // Create a fake checkpoint with 32 'X'.
   std::string checkpoint_str(32, 'X');
   absl::Cord checkpoint_cord(checkpoint_str);
@@ -3070,6 +3116,142 @@
                                                  plan_duration, std::nullopt));
 }
 
+TEST_F(
+    HttpFederatedProtocolTest,
+    TestReportCompletedWithLightweightWireFormatSupportDisabledBugFixEnabled) {
+  // Issue an eligibility eval checkin first.
+  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+  // Issue a regular checkin.
+  ASSERT_OK(RunSuccessfulCheckin());
+  // Disables enable_lightweight_client_report_wire_format flag.
+  EXPECT_CALL(mock_flags_, enable_lightweight_client_report_wire_format())
+      .WillRepeatedly(Return(false));
+  // Bugfix is enabled, so the new (untied) validation path is exercised.
+  EXPECT_CALL(mock_flags_,
+              untie_lw_client_report_format_support_from_requiring_lw_report())
+      .WillRepeatedly(Return(true));
+  // Create a fake lightweight (FC wire format) result with 32 'X'.
+  std::string checkpoint_str(32, 'X');
+  absl::Cord checkpoint_cord(checkpoint_str);
+  ComputationResults results;
+  results.emplace("fc_checkpoint", checkpoint_cord);
+  absl::Duration plan_duration = absl::Minutes(5);
+
+  // Should fail: an FC checkpoint was produced but the feature is disabled.
+  EXPECT_THAT(federated_protocol_->ReportCompleted(std::move(results),
+                                                   plan_duration, std::nullopt),
+              IsCode(absl::StatusCode::kInternal));
+}
+
+TEST_F(
+    HttpFederatedProtocolTest,
+    TestReportCompletedSuccessWithTfCheckpointButSupportDisabledBugFixEnabled) {
+  // Issue an eligibility eval checkin first.
+  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+  // Issue a regular checkin.
+  ASSERT_OK(RunSuccessfulCheckin());
+  // Disables enable_lightweight_client_report_wire_format flag.
+  EXPECT_CALL(mock_flags_, enable_lightweight_client_report_wire_format())
+      .WillRepeatedly(Return(false));
+  // Enables the bugfix; a tf checkpoint must still be reported successfully
+  // while enable_lightweight_client_report_wire_format is false.
+  EXPECT_CALL(mock_flags_,
+              untie_lw_client_report_format_support_from_requiring_lw_report())
+      .WillRepeatedly(Return(true));
+  // Create a fake tf checkpoint with 32 'X'.
+  std::string checkpoint_str(32, 'X');
+  ComputationResults results;
+  results.emplace("tensorflow_checkpoint", checkpoint_str);
+  absl::Duration plan_duration = absl::Minutes(5);
+
+  ExpectSuccessfulReportTaskResultRequest(
+      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+      kAggregationSessionId, kTaskName, plan_duration);
+  ExpectSuccessfulStartAggregationDataUploadRequest(
+      "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
+      "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
+      kResourceName, kByteStreamTargetUri, kSecondStageAggregationTargetUri);
+  ExpectSuccessfulByteStreamUploadRequest(
+      "https://bytestream.uri/upload/v1/media/"
+      "CHECKPOINT_RESOURCE?upload_protocol=raw",
+      checkpoint_str);
+  ExpectSuccessfulSubmitAggregationResultRequest(
+      "https://aggregation.second.uri/v1/aggregations/"
+      "AGGREGATION_SESSION_ID/clients/CLIENT_TOKEN:submit?%24alt=proto");
+
+  EXPECT_OK(federated_protocol_->ReportCompleted(std::move(results),
+                                                 plan_duration, std::nullopt));
+}
+
+TEST_F(HttpFederatedProtocolTest,
+       TestReportCompletedFailureWithTfCheckpointSupportEnabledBugFixDisabled) {
+  // Issue an eligibility eval checkin first.
+  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+  // Issue a regular checkin.
+  ASSERT_OK(RunSuccessfulCheckin());
+  // Enables enable_lightweight_client_report_wire_format flag.
+  EXPECT_CALL(mock_flags_, enable_lightweight_client_report_wire_format())
+      .WillRepeatedly(Return(true));
+  // The client should fail to report tf checkpoint since
+  // enable_lightweight_client_report_wire_format is true but the bugfix is
+  // disabled.
+  EXPECT_CALL(mock_flags_,
+              untie_lw_client_report_format_support_from_requiring_lw_report())
+      .WillRepeatedly(Return(false));
+  // Create a fake tf checkpoint with 32 'X'.
+  std::string checkpoint_str(32, 'X');
+  ComputationResults results;
+  results.emplace("tensorflow_checkpoint", checkpoint_str);
+  absl::Duration plan_duration = absl::Minutes(5);
+
+  // Should fail because, without the bugfix, enabling the wire format requires
+  // an FC checkpoint and this result is a tf checkpoint.
+  EXPECT_THAT(federated_protocol_->ReportCompleted(std::move(results),
+                                                   plan_duration, std::nullopt),
+              IsCode(absl::StatusCode::kInternal));
+}
+
+TEST_F(HttpFederatedProtocolTest,
+       TestReportWithTfCheckpointFcWireFormatEnabledAndBugFixEnabled) {
+  // Issue an eligibility eval checkin first.
+  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+  // Issue a regular checkin.
+  ASSERT_OK(RunSuccessfulCheckin());
+  // Enables enable_lightweight_client_report_wire_format flag.
+  EXPECT_CALL(mock_flags_, enable_lightweight_client_report_wire_format())
+      .WillRepeatedly(Return(true));
+  // Enables the bugfix so that a tf checkpoint is still accepted even though
+  // enable_lightweight_client_report_wire_format is true.
+  EXPECT_CALL(mock_flags_,
+              untie_lw_client_report_format_support_from_requiring_lw_report())
+      .WillRepeatedly(Return(true));
+  // Create a fake tf (non-lightweight) checkpoint with 32 'X'.
+  std::string checkpoint_str(32, 'X');
+  ComputationResults results;
+  results.emplace("tensorflow_checkpoint", checkpoint_str);
+  absl::Duration plan_duration = absl::Minutes(5);
+
+  ExpectSuccessfulReportTaskResultRequest(
+      "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+      "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+      kAggregationSessionId, kTaskName, plan_duration);
+  ExpectSuccessfulStartAggregationDataUploadRequest(
+      "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
+      "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
+      kResourceName, kByteStreamTargetUri, kSecondStageAggregationTargetUri);
+  ExpectSuccessfulByteStreamUploadRequest(
+      "https://bytestream.uri/upload/v1/media/"
+      "CHECKPOINT_RESOURCE?upload_protocol=raw",
+      checkpoint_str);
+  ExpectSuccessfulSubmitAggregationResultRequest(
+      "https://aggregation.second.uri/v1/aggregations/"
+      "AGGREGATION_SESSION_ID/clients/CLIENT_TOKEN:submit?%24alt=proto");
+
+  EXPECT_OK(federated_protocol_->ReportCompleted(std::move(results),
+                                                 plan_duration, std::nullopt));
+}
+
 // TODO(team): Remove this test once client_token is always populated in
 // StartAggregationDataUploadResponse.
 TEST_F(HttpFederatedProtocolTest,
diff --git a/fcp/client/test_helpers.h b/fcp/client/test_helpers.h
index d5082d9..bc79348 100644
--- a/fcp/client/test_helpers.h
+++ b/fcp/client/test_helpers.h
@@ -668,6 +668,9 @@
   MOCK_METHOD(bool, enable_phase_stats_logging, (), (const, override));
   MOCK_METHOD(bool, enable_lightweight_client_report_wire_format, (),
               (const, override));
+  MOCK_METHOD(bool,
+              untie_lw_client_report_format_support_from_requiring_lw_report,
+              (), (const, override));
   MOCK_METHOD(bool, enable_native_example_query_recording, (),
               (const, override));
   MOCK_METHOD(bool, enable_confidential_aggregation, (), (const, override));