Allow enabling support for the lightweight report format without breaking support for non-lightweight tasks
Without this change, the client can only handle lightweight tasks when the enable_lightweight_client_report_wire_format flag is enabled. This change adds a separate flag, untie_lw_client_report_format_support_from_requiring_lw_report, that allows enable_lightweight_client_report_wire_format to be true without breaking support for non-lightweight tasks.
PiperOrigin-RevId: 638668272
diff --git a/fcp/client/fl_runner.cc b/fcp/client/fl_runner.cc
index 8b151d4..f039b9a 100644
--- a/fcp/client/fl_runner.cc
+++ b/fcp/client/fl_runner.cc
@@ -225,23 +225,51 @@
}
}
- if (flags->enable_lightweight_client_report_wire_format()) {
- // Reads string from federated_compute_checkpoint
- if (plan_result.federated_compute_checkpoint.empty()) {
- return absl::InvalidArgumentError("Empty federate compute checkpoint");
+ if (!flags
+ ->untie_lw_client_report_format_support_from_requiring_lw_report()) {
+ // Old buggy behavior
+ if (flags->enable_lightweight_client_report_wire_format()) {
+ // Reads string from federated_compute_checkpoint
+ if (plan_result.federated_compute_checkpoint.empty()) {
+ return absl::InvalidArgumentError("Empty federate compute checkpoint");
+ }
+ computation_results[kFederatedComputeCheckpoint] =
+ std::move(plan_result.federated_compute_checkpoint);
+ } else {
+ // Name of the TF checkpoint inside the aggregand map in the Checkpoint
+ // protobuf. This field name is ignored by the server.
+ if (!checkpoint_filename.empty()) {
+ FCP_ASSIGN_OR_RETURN(std::string tf_checkpoint,
+ fcp::ReadFileToString(checkpoint_filename));
+ computation_results[std::string(kTensorflowCheckpointAggregand)] =
+ std::move(tf_checkpoint);
+ }
}
- computation_results[kFederatedComputeCheckpoint] =
- std::move(plan_result.federated_compute_checkpoint);
} else {
- // Name of the TF checkpoint inside the aggregand map in the Checkpoint
- // protobuf. This field name is ignored by the server.
- if (!checkpoint_filename.empty()) {
+ if (!plan_result.federated_compute_checkpoint.empty()) {
+ if (flags->enable_lightweight_client_report_wire_format()) {
+ // Task produced a lightweight report, and the feature is enabled.
+ computation_results[kFederatedComputeCheckpoint] =
+ std::move(plan_result.federated_compute_checkpoint);
+ } else {
+ // Task produced a lightweight report, but the feature is disabled.
+ return absl::InternalError(
+ "Lightweight report produced but lightweight report feature is "
+ "disabled");
+ }
+ } else if (!checkpoint_filename.empty()) {
+ // Name of the TF checkpoint inside the aggregand map in the Checkpoint
+ // protobuf. This field name is ignored by the server.
FCP_ASSIGN_OR_RETURN(std::string tf_checkpoint,
fcp::ReadFileToString(checkpoint_filename));
computation_results[std::string(kTensorflowCheckpointAggregand)] =
std::move(tf_checkpoint);
+ } else {
+ // No lightweight report produced, and no TF checkpoint produced. For this
+ // computation, all outputs are aggregated with secagg.
}
}
+
return computation_results;
}
diff --git a/fcp/client/flags.h b/fcp/client/flags.h
index c63aba2..506886c 100644
--- a/fcp/client/flags.h
+++ b/fcp/client/flags.h
@@ -174,6 +174,16 @@
return false;
}
+ // If true, enabling the lightweight client report wire format will be untied
+ // from requiring a lightweight client report. This means that even if
+ // `enable_lightweight_client_report_wire_format` is true, the client is not
+ // required to produce a lightweight client report. This is a bugfix for the
+ // existing implementation.
+ virtual bool untie_lw_client_report_format_support_from_requiring_lw_report()
+ const {
+ return false;
+ }
+
// If true, OpStats logger enables PhaseStats logging.
virtual bool enable_phase_stats_logging() const { return false; }
diff --git a/fcp/client/http/http_federated_protocol.cc b/fcp/client/http/http_federated_protocol.cc
index ef5be61..6371596 100644
--- a/fcp/client/http/http_federated_protocol.cc
+++ b/fcp/client/http/http_federated_protocol.cc
@@ -1176,18 +1176,38 @@
" aggregands have unexpected results size."));
}
auto result = std::move(results.begin()->second);
+ bool untie_lw_client_report_format_support_from_requiring_lw_report =
+ flags_->untie_lw_client_report_format_support_from_requiring_lw_report();
bool enable_lightweight_client_report_wire_format =
flags_->enable_lightweight_client_report_wire_format();
- if (!enable_lightweight_client_report_wire_format &&
- !std::holds_alternative<TFCheckpoint>(result)) {
- return absl::InternalError(absl::StrCat(
- aggregation_type_readable, " aggregands have unexpected format."));
- }
- if (enable_lightweight_client_report_wire_format &&
- !std::holds_alternative<FCCheckpoint>(result)) {
- return absl::InternalError(
- absl::StrCat(aggregation_type_readable,
- " aggregands have unexpected format for FC Wire Format."));
+ if (!untie_lw_client_report_format_support_from_requiring_lw_report) {
+ // Legacy behavior: the wire-format flag alone dictates the required format.
+ if (!enable_lightweight_client_report_wire_format &&
+ !std::holds_alternative<TFCheckpoint>(result)) {
+ return absl::InternalError(absl::StrCat(
+ aggregation_type_readable, " aggregands have unexpected format."));
+ }
+ if (enable_lightweight_client_report_wire_format &&
+ !std::holds_alternative<FCCheckpoint>(result)) {
+ return absl::InternalError(absl::StrCat(
+ aggregation_type_readable,
+ " aggregands have unexpected format for FC Wire Format."));
+ }
+ } else {
+ if (!enable_lightweight_client_report_wire_format &&
+ std::holds_alternative<FCCheckpoint>(result)) {
+ return absl::InternalError(absl::StrCat(
+ aggregation_type_readable,
+ " computation produced FC Wire Format but this feature is "
+ "not enabled."));
+ }
+ // Mirror the legacy format check: reject anything that is neither a TF nor
+ // an FC checkpoint, so the std::get<TFCheckpoint> below cannot fail.
+ if (!std::holds_alternative<TFCheckpoint>(result) &&
+ !std::holds_alternative<FCCheckpoint>(result)) {
+ return absl::InternalError(absl::StrCat(
+ aggregation_type_readable, " aggregands have unexpected format."));
+ }
}
auto start_upload_status = HandleStartDataAggregationUploadOperationResponse(
PerformStartDataUploadRequestAndReportTaskResult(plan_duration,
@@ -1206,12 +1226,25 @@
<< aggregation_type_readable;
std::string result_data;
- if (enable_lightweight_client_report_wire_format) {
- // TODO: b/300128447 - avoid copying serialized checkpoint once http
- // federated protocol supports absl::Cord
- absl::CopyCordToString(std::get<FCCheckpoint>(result), &result_data);
+ if (!untie_lw_client_report_format_support_from_requiring_lw_report) {
+ if (enable_lightweight_client_report_wire_format) {
+ // TODO: b/300128447 - avoid copying serialized checkpoint once http
+ // federated protocol supports absl::Cord
+ absl::CopyCordToString(std::get<FCCheckpoint>(result), &result_data);
+ } else {
+ result_data = std::get<TFCheckpoint>(result);
+ }
} else {
- result_data = std::get<TFCheckpoint>(result);
+ bool should_report_lightweight_client_report_wire_format =
+ enable_lightweight_client_report_wire_format &&
+ std::holds_alternative<FCCheckpoint>(result);
+ if (should_report_lightweight_client_report_wire_format) {
+ // TODO: b/300128447 - avoid copying serialized checkpoint once http
+ // federated protocol supports absl::Cord
+ absl::CopyCordToString(std::get<FCCheckpoint>(result), &result_data);
+ } else {
+ result_data = std::get<TFCheckpoint>(result);
+ }
}
std::string data_to_upload;
diff --git a/fcp/client/http/http_federated_protocol_test.cc b/fcp/client/http/http_federated_protocol_test.cc
index a560adb..4722bc0 100644
--- a/fcp/client/http/http_federated_protocol_test.cc
+++ b/fcp/client/http/http_federated_protocol_test.cc
@@ -3035,7 +3035,7 @@
}
TEST_F(HttpFederatedProtocolTest,
- TestReportCompletedViaSimpleAggSuccessWithFCWireFormat) {
+ TestReportCompletedViaSimpleAggSuccessWithFCWireFormatBugFixDisabled) {
// Issue an eligibility eval checkin first.
ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
// Issue a regular checkin.
@@ -3043,6 +3043,52 @@
// Enables enable_lightweight_client_report_wire_format flag.
EXPECT_CALL(mock_flags_, enable_lightweight_client_report_wire_format())
.WillRepeatedly(Return(true));
+ // Reporting a lightweight result should still work even if the bugfix is
+ // disabled.
+ EXPECT_CALL(mock_flags_,
+ untie_lw_client_report_format_support_from_requiring_lw_report())
+ .WillRepeatedly(Return(false));
+ // Create a fake lightweight result format with 32 'X'.
+ std::string checkpoint_str(32, 'X');
+ absl::Cord checkpoint_cord(checkpoint_str);
+ ComputationResults results;
+ results.emplace("fc_checkpoint", checkpoint_cord);
+ absl::Duration plan_duration = absl::Minutes(5);
+
+ ExpectSuccessfulReportTaskResultRequest(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+ kAggregationSessionId, kTaskName, plan_duration);
+ ExpectSuccessfulStartAggregationDataUploadRequest(
+ "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
+ "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
+ kResourceName, kByteStreamTargetUri, kSecondStageAggregationTargetUri);
+ ExpectSuccessfulByteStreamUploadRequest(
+ "https://bytestream.uri/upload/v1/media/"
+ "CHECKPOINT_RESOURCE?upload_protocol=raw",
+ checkpoint_str);
+ ExpectSuccessfulSubmitAggregationResultRequest(
+ "https://aggregation.second.uri/v1/aggregations/"
+ "AGGREGATION_SESSION_ID/clients/CLIENT_TOKEN:submit?%24alt=proto");
+
+ EXPECT_OK(federated_protocol_->ReportCompleted(std::move(results),
+ plan_duration, std::nullopt));
+}
+
+TEST_F(HttpFederatedProtocolTest,
+ TestReportCompletedViaSimpleAggSuccessWithFCWireFormatBugFixEnabled) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ // Issue a regular checkin.
+ ASSERT_OK(RunSuccessfulCheckin());
+ // Enables enable_lightweight_client_report_wire_format flag.
+ EXPECT_CALL(mock_flags_, enable_lightweight_client_report_wire_format())
+ .WillRepeatedly(Return(true));
+ // Reporting a lightweight result should still work when the bugfix is
+ // enabled.
+ EXPECT_CALL(mock_flags_,
+ untie_lw_client_report_format_support_from_requiring_lw_report())
+ .WillRepeatedly(Return(true));
// Create a fake checkpoint with 32 'X'.
std::string checkpoint_str(32, 'X');
absl::Cord checkpoint_cord(checkpoint_str);
@@ -3070,6 +3116,142 @@
plan_duration, std::nullopt));
}
+TEST_F(
+ HttpFederatedProtocolTest,
+ TestReportCompletedWithLightweightWireFormatSupportDisabledBugFixEnabled) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ // Issue a regular checkin.
+ ASSERT_OK(RunSuccessfulCheckin());
+ // Disables enable_lightweight_client_report_wire_format flag.
+ EXPECT_CALL(mock_flags_, enable_lightweight_client_report_wire_format())
+ .WillRepeatedly(Return(false));
+ // Bugfix is enabled; an FC checkpoint must still be rejected here.
+ EXPECT_CALL(mock_flags_,
+ untie_lw_client_report_format_support_from_requiring_lw_report())
+ .WillRepeatedly(Return(true));
+ // Create a fake lightweight result format with 32 'X'.
+ std::string checkpoint_str(32, 'X');
+ absl::Cord checkpoint_cord(checkpoint_str);
+ ComputationResults results;
+ results.emplace("fc_checkpoint", checkpoint_cord);
+ absl::Duration plan_duration = absl::Minutes(5);
+
+ // Should fail because the lightweight wire format flag is disabled.
+ EXPECT_THAT(federated_protocol_->ReportCompleted(std::move(results),
+ plan_duration, std::nullopt),
+ IsCode(absl::StatusCode::kInternal));
+}
+
+TEST_F(
+ HttpFederatedProtocolTest,
+ TestReportCompletedSuccessWithTfCheckpointButSupportDisabledBugFixEnabled) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ // Issue a regular checkin.
+ ASSERT_OK(RunSuccessfulCheckin());
+ // Disables enable_lightweight_client_report_wire_format flag.
+ EXPECT_CALL(mock_flags_, enable_lightweight_client_report_wire_format())
+ .WillRepeatedly(Return(false));
+ // Enables the bugfix; a TF checkpoint must still be reported when
+ // enable_lightweight_client_report_wire_format is false.
+ EXPECT_CALL(mock_flags_,
+ untie_lw_client_report_format_support_from_requiring_lw_report())
+ .WillRepeatedly(Return(true));
+ // Create a fake tf checkpoint with 32 'X'.
+ std::string checkpoint_str(32, 'X');
+ ComputationResults results;
+ results.emplace("tensorflow_checkpoint", checkpoint_str);
+ absl::Duration plan_duration = absl::Minutes(5);
+
+ ExpectSuccessfulReportTaskResultRequest(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+ kAggregationSessionId, kTaskName, plan_duration);
+ ExpectSuccessfulStartAggregationDataUploadRequest(
+ "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
+ "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
+ kResourceName, kByteStreamTargetUri, kSecondStageAggregationTargetUri);
+ ExpectSuccessfulByteStreamUploadRequest(
+ "https://bytestream.uri/upload/v1/media/"
+ "CHECKPOINT_RESOURCE?upload_protocol=raw",
+ checkpoint_str);
+ ExpectSuccessfulSubmitAggregationResultRequest(
+ "https://aggregation.second.uri/v1/aggregations/"
+ "AGGREGATION_SESSION_ID/clients/CLIENT_TOKEN:submit?%24alt=proto");
+
+ EXPECT_OK(federated_protocol_->ReportCompleted(std::move(results),
+ plan_duration, std::nullopt));
+}
+
+TEST_F(HttpFederatedProtocolTest,
+ TestReportCompletedFailureWithTfCheckpointSupportEnabledBugFixDisabled) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ // Issue a regular checkin.
+ ASSERT_OK(RunSuccessfulCheckin());
+ // Enables enable_lightweight_client_report_wire_format flag.
+ EXPECT_CALL(mock_flags_, enable_lightweight_client_report_wire_format())
+ .WillRepeatedly(Return(true));
+ // The client should fail to report tf checkpoint since
+ // enable_lightweight_client_report_wire_format is true but the bugfix is
+ // disabled.
+ EXPECT_CALL(mock_flags_,
+ untie_lw_client_report_format_support_from_requiring_lw_report())
+ .WillRepeatedly(Return(false));
+ // Create a fake tf checkpoint with 32 'X'.
+ std::string checkpoint_str(32, 'X');
+ ComputationResults results;
+ results.emplace("tensorflow_checkpoint", checkpoint_str);
+ absl::Duration plan_duration = absl::Minutes(5);
+
+ // Should fail because enable_lightweight_client_report_wire_format is true
+ // and, without the bugfix, only an FC checkpoint is accepted.
+ EXPECT_THAT(federated_protocol_->ReportCompleted(std::move(results),
+ plan_duration, std::nullopt),
+ IsCode(absl::StatusCode::kInternal));
+}
+
+TEST_F(HttpFederatedProtocolTest,
+ TestReportWithTfCheckpointFcWireFormatEnabledAndBugFixEnabled) {
+ // Issue an eligibility eval checkin first.
+ ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
+ // Issue a regular checkin.
+ ASSERT_OK(RunSuccessfulCheckin());
+ // Enables enable_lightweight_client_report_wire_format flag.
+ EXPECT_CALL(mock_flags_, enable_lightweight_client_report_wire_format())
+ .WillRepeatedly(Return(true));
+ // Enables the bugfix so that the client can report tf checkpoint even if
+ // enable_lightweight_client_report_wire_format is true.
+ EXPECT_CALL(mock_flags_,
+ untie_lw_client_report_format_support_from_requiring_lw_report())
+ .WillRepeatedly(Return(true));
+ // Create a fake tf checkpoint with 32 'X'.
+ std::string checkpoint_str(32, 'X');
+ ComputationResults results;
+ results.emplace("tensorflow_checkpoint", checkpoint_str);
+ absl::Duration plan_duration = absl::Minutes(5);
+
+ ExpectSuccessfulReportTaskResultRequest(
+ "https://taskassignment.uri/v1/populations/TEST%2FPOPULATION/"
+ "taskassignments/CLIENT_SESSION_ID:reportresult?%24alt=proto",
+ kAggregationSessionId, kTaskName, plan_duration);
+ ExpectSuccessfulStartAggregationDataUploadRequest(
+ "https://aggregation.uri/v1/aggregations/AGGREGATION_SESSION_ID/"
+ "clients/AUTHORIZATION_TOKEN:startdataupload?%24alt=proto",
+ kResourceName, kByteStreamTargetUri, kSecondStageAggregationTargetUri);
+ ExpectSuccessfulByteStreamUploadRequest(
+ "https://bytestream.uri/upload/v1/media/"
+ "CHECKPOINT_RESOURCE?upload_protocol=raw",
+ checkpoint_str);
+ ExpectSuccessfulSubmitAggregationResultRequest(
+ "https://aggregation.second.uri/v1/aggregations/"
+ "AGGREGATION_SESSION_ID/clients/CLIENT_TOKEN:submit?%24alt=proto");
+
+ EXPECT_OK(federated_protocol_->ReportCompleted(std::move(results),
+ plan_duration, std::nullopt));
+}
+
// TODO(team): Remove this test once client_token is always populated in
// StartAggregationDataUploadResponse.
TEST_F(HttpFederatedProtocolTest,
diff --git a/fcp/client/test_helpers.h b/fcp/client/test_helpers.h
index d5082d9..bc79348 100644
--- a/fcp/client/test_helpers.h
+++ b/fcp/client/test_helpers.h
@@ -668,6 +668,9 @@
MOCK_METHOD(bool, enable_phase_stats_logging, (), (const, override));
MOCK_METHOD(bool, enable_lightweight_client_report_wire_format, (),
(const, override));
+ MOCK_METHOD(bool,
+ untie_lw_client_report_format_support_from_requiring_lw_report,
+ (), (const, override));
MOCK_METHOD(bool, enable_native_example_query_recording, (),
(const, override));
MOCK_METHOD(bool, enable_confidential_aggregation, (), (const, override));