blob: 12ab92db347b44109ebe3b46839ba4b23cad18a5 [file] [log] [blame]
/*
* Copyright (C) 2022 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.rkpdapp.unittest;
import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;
import android.content.Context;
import android.content.pm.ApplicationInfo;
import android.content.pm.PackageManager;
import android.util.Base64;
import androidx.test.core.app.ApplicationProvider;
import com.android.rkpdapp.GeekResponse;
import com.android.rkpdapp.RkpdException;
import com.android.rkpdapp.interfaces.ServerInterface;
import com.android.rkpdapp.metrics.ProvisioningAttempt;
import com.android.rkpdapp.testutil.FakeRkpServer;
import com.android.rkpdapp.utils.CborUtils;
import com.android.rkpdapp.utils.Settings;
import org.junit.After;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.mockito.Mockito;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.List;
/**
 * Unit tests for {@link ServerInterface}.
 *
 * <p>Covers fetching the GEEK, requesting signed certificates, data-budget and connectivity
 * gating, parsing of server error bodies, regional connect timeouts, and network-consent
 * detection. Network-facing tests run against a local {@link FakeRkpServer}; error-parsing tests
 * use Mockito-mocked {@link HttpURLConnection}s.
 */
public class ServerInterfaceTest {
    /** Value passed as {@code expiringBy} when configuring the device for each test. */
    private static final Duration TIME_TO_REFRESH_HOURS = Duration.ofHours(2);
    private static Context sContext;
    private ServerInterface mServerInterface;

    /**
     * Wraps the application context in a Mockito spy so individual tests can stub it (e.g. via
     * {@code Utils.mockConnectivityState}).
     */
    @BeforeClass
    public static void init() {
        sContext = Mockito.spy(ApplicationProvider.getApplicationContext());
    }

    /**
     * Starts every test from cleared settings, a fresh {@link ServerInterface}, and a context
     * reporting a connected network.
     */
    @Before
    public void setUp() {
        Settings.clearPreferences(sContext);
        // NOTE(review): meaning of the boolean constructor argument is not visible here; confirm.
        mServerInterface = new ServerInterface(sContext, false);
        Utils.mockConnectivityState(sContext, Utils.ConnectivityState.CONNECTED);
    }

    /** Clears persisted settings and drops any stubbing applied to the context spy. */
    @After
    public void tearDown() {
        Settings.clearPreferences(sContext);
        Mockito.reset(sContext);
    }

    /**
     * A server that keeps answering with internal errors must make {@code fetchGeek} fail with
     * {@link RkpdException.ErrorCode#HTTP_SERVER_ERROR}.
     */
    @Test
    public void testRetryOnServerFailure() throws Exception {
        try (FakeRkpServer server = new FakeRkpServer(FakeRkpServer.Response.INTERNAL_ERROR,
                FakeRkpServer.Response.INTERNAL_ERROR)) {
            Settings.setDeviceConfig(sContext, 1 /* extraKeys */,
                    TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
            // Keep the overall request budget small so the retry loop ends quickly.
            // NOTE(review): the unit of this value (presumably milliseconds) is not visible here.
            Settings.setMaxRequestTime(sContext, 100);
            GeekResponse ignored = mServerInterface.fetchGeek(
                    ProvisioningAttempt.createScheduledAttemptMetrics(sContext));
            assertWithMessage("Expected RkpdException.").fail();
        } catch (RkpdException e) {
            assertThat(e.getErrorCode()).isEqualTo(RkpdException.ErrorCode.HTTP_SERVER_ERROR);
            assertThat(e).hasMessageThat().contains("HTTP error status encountered");
        }
    }

    /**
     * When the server reports RKP as disabled, the GEEK response asks for zero extra keys but
     * still carries a challenge and a GEEK chain.
     */
    @Test
    public void testFetchGeekRkpDisabled() throws Exception {
        try (FakeRkpServer server = new FakeRkpServer(
                FakeRkpServer.Response.FETCH_EEK_RKP_DISABLED,
                FakeRkpServer.Response.INTERNAL_ERROR)) {
            Settings.setDeviceConfig(sContext, 1 /* extraKeys */,
                    TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
            GeekResponse response = mServerInterface.fetchGeek(
                    ProvisioningAttempt.createScheduledAttemptMetrics(sContext));
            assertThat(response.numExtraAttestationKeys).isEqualTo(0);
            assertThat(response.getChallenge()).isNotNull();
            // NOTE(review): 2 is presumably CborUtils.EC_CURVE_25519; confirm and use the
            // named constant as testFetchGeekRkpEnabled does.
            assertThat(response.getGeekChain(2)).isNotNull();
        }
    }

    /**
     * A successful GEEK fetch exposes the server-supplied extra-key count, challenge, and both
     * the Ed25519 and P-256 GEEK chains byte-for-byte.
     */
    @Test
    public void testFetchGeekRkpEnabled() throws Exception {
        try (FakeRkpServer server = new FakeRkpServer(
                FakeRkpServer.Response.FETCH_EEK_OK,
                FakeRkpServer.Response.SIGN_CERTS_OK_VALID_CBOR)) {
            Settings.setDeviceConfig(sContext, 1 /* extraKeys */,
                    TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
            GeekResponse response = mServerInterface.fetchGeek(
                    ProvisioningAttempt.createScheduledAttemptMetrics(sContext));
            assertThat(response.numExtraAttestationKeys).isEqualTo(20);
            assertThat(response.getChallenge()).isNotNull();
            byte[] challenge = Base64.decode("AAABgEg1zGsBILStY/1VNI7st0AG9x2S/tba+H4=",
                    Base64.DEFAULT);
            assertThat(response.getChallenge()).isEqualTo(challenge);
            byte[] ed25519GeekChain = Base64.decode(
                    "g4RDoQEnoFgqpAEBAycgBiFYIJm57t1e5FL2hcZMYtw+YatXS"
                            + "H11NymtdoAy0rPLY1jZWEAeIghLpLekyNdOAw7+uK8UTKc7b6XN3Np5xitk"
                            + "/pk5r3bngPpmAIUNB5gqrJFcpyUUSQY0dcqKJ3rZ41pJ6wIDhEOhASegWCqk"
                            + "AQEDJyAGIVgg6i+FDp5qDFz3vdn6KDK/2lXpIKJRA8kDkxjOoBUp7NFYQIJr"
                            + "x12mNle3x3ESrRzCarMsIyrdFDDLghS2icXTHjG7uFAhSklNupEMbzNNg7xY"
                            + "Ky6E28VZD5hh4sHqifLQrgSEQ6EBJ6BYTqUBAQJYIG+S0QRtcdinjojY0VaB"
                            + "X5bReIPmMBuH7b8g0Uo7/mouAzgYIAQhWCC2XRxLmoM6nbUVWTehJvsP3+ec"
                            + "rAHVpOzIOikAiFglOVhAgLKf0DKenUr+sCXywtIiaEbGILCq6BasZKFFg5vM"
                            + "SVQlf6sWBVPwvTWT88a7WU5e+d4hBxSjtqSji4+Clpa6Aw==",
                    Base64.DEFAULT);
            byte[] p256GeekChain = Base64.decode(
                    "g4RDoQEmoFhNpQECAyYgASFYIPcUituX9MxT79JkEcTjdR9mH6Rx"
                            + "DGzP+glGgHSHVPKtIlggXn9b9uzk9hnM/xM3/Q+hyJPbGAZ2xF3m12p3hsMtr49YQC"
                            + "+XjkL7vgctlUeFR5NAsB/Um0ekxESp8qEHhxDHn8sR9L+f6Dvg5zRMFfx7w34zBfTR"
                            + "NDztAgRgehXgedOK/ySEQ6EBJqBYTaUBAgMmIAEhWCBRgKzPj5aM7A9Q4akbt5CGNI"
                            + "vjw6xlAk209jEOCEYyOSJYIFTrlJ3+trTkczolTi8fnZ29+mbBEYvploxD5DD22nar"
                            + "WECYOPs0OmXbc5ixJ6IVdPK+BueNIk7d8L/CAXTEtylrJBy12NJm+kTv9TAsBHTt6M"
                            + "Zg2s6fVlcndCHT3pOP47jNhEOhASagWHGmAQICWCCDn/j9EBwSn5JBx1uN5E70GROa"
                            + "xxttpw6V8mRTXacdwQM4GCABIVggFqRSEmOzhlZQ2N/yoKh9vNlup2hg6oxc8ZPllx"
                            + "kNrN4iWCCJvsxsP16wOTSvl7o40RYdocwdZNOMSE74coEbOz4x7lhA+trPLaulMAxz"
                            + "xeWrSZJZYET6xPIz5QSybBlk6RzjZDs0hgBlLfXdr6oBya+DyU74WpToZZNR4xgeOY"
                            + "CnaUszzQ==",
                    Base64.DEFAULT);
            assertThat(response.getGeekChain(CborUtils.EC_CURVE_25519)).isEqualTo(ed25519GeekChain);
            assertThat(response.getGeekChain(CborUtils.EC_CURVE_P256)).isEqualTo(p256GeekChain);
        }
    }

    /**
     * {@code fetchGeekAndUpdate} persists the extra-key count and expiring-by window into
     * {@link Settings}, overriding the locally configured values (2 keys / 2 hours).
     */
    @Test
    public void testFetchKeyAndUpdate() throws Exception {
        try (FakeRkpServer server = new FakeRkpServer(
                FakeRkpServer.Response.FETCH_EEK_OK,
                FakeRkpServer.Response.SIGN_CERTS_OK_VALID_CBOR)) {
            Settings.setDeviceConfig(sContext, 2 /* extraKeys */,
                    TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
            mServerInterface.fetchGeekAndUpdate(
                    ProvisioningAttempt.createScheduledAttemptMetrics(sContext));
            assertThat(Settings.getExtraSignedKeysAvailable(sContext)).isEqualTo(20);
            assertThat(Settings.getExpiringBy(sContext)).isEqualTo(Duration.ofHours(72));
        }
    }

    /** An unregistered-device response maps to {@code DEVICE_NOT_REGISTERED}. */
    @Test
    public void testRequestSignedCertUnregistered() throws Exception {
        try (FakeRkpServer server = new FakeRkpServer(
                FakeRkpServer.Response.FETCH_EEK_OK,
                FakeRkpServer.Response.SIGN_CERTS_DEVICE_UNREGISTERED)) {
            Settings.setDeviceConfig(sContext, 2 /* extraKeys */,
                    TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
            ProvisioningAttempt metrics = ProvisioningAttempt.createScheduledAttemptMetrics(
                    sContext);
            mServerInterface.requestSignedCertificates(new byte[0], new byte[0], metrics);
            assertWithMessage("Should fail due to unregistered device.").fail();
        } catch (RkpdException e) {
            assertThat(e.getErrorCode()).isEqualTo(RkpdException.ErrorCode.DEVICE_NOT_REGISTERED);
        }
    }

    /** An unauthorized (client-side) HTTP response maps to {@code HTTP_CLIENT_ERROR}. */
    @Test
    public void testRequestSignedCertClientError() throws Exception {
        try (FakeRkpServer server = new FakeRkpServer(
                FakeRkpServer.Response.FETCH_EEK_OK,
                FakeRkpServer.Response.SIGN_CERTS_USER_UNAUTHORIZED)) {
            Settings.setDeviceConfig(sContext, 2 /* extraKeys */,
                    TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
            Settings.setMaxRequestTime(sContext, 100);
            ProvisioningAttempt metrics = ProvisioningAttempt.createScheduledAttemptMetrics(
                    sContext);
            mServerInterface.requestSignedCertificates(new byte[0], new byte[0], metrics);
            assertWithMessage("Should fail due to client error.").fail();
        } catch (RkpdException e) {
            assertThat(e.getErrorCode()).isEqualTo(RkpdException.ErrorCode.HTTP_CLIENT_ERROR);
        }
    }

    /** A success status with an unparseable CBOR body maps to {@code INTERNAL_ERROR}. */
    @Test
    public void testRequestSignedCertCborError() throws Exception {
        try (FakeRkpServer server = new FakeRkpServer(
                FakeRkpServer.Response.FETCH_EEK_OK,
                FakeRkpServer.Response.SIGN_CERTS_OK_INVALID_CBOR)) {
            Settings.setDeviceConfig(sContext, 2 /* extraKeys */,
                    TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
            ProvisioningAttempt metrics = ProvisioningAttempt.createScheduledAttemptMetrics(
                    sContext);
            mServerInterface.requestSignedCertificates(new byte[0], new byte[0], metrics);
            assertWithMessage("Should fail due to invalid cbor.").fail();
        } catch (RkpdException e) {
            assertThat(e.getErrorCode()).isEqualTo(RkpdException.ErrorCode.INTERNAL_ERROR);
            assertThat(e).hasMessageThat().isEqualTo("Response failed to parse.");
        }
    }

    /** A valid CBOR response from the fake server yields a non-null, empty list of chains. */
    @Test
    public void testRequestSignedCertValid() throws Exception {
        try (FakeRkpServer server = new FakeRkpServer(
                FakeRkpServer.Response.FETCH_EEK_OK,
                FakeRkpServer.Response.SIGN_CERTS_OK_VALID_CBOR)) {
            Settings.setDeviceConfig(sContext, 2 /* extraKeys */,
                    TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
            ProvisioningAttempt metrics = ProvisioningAttempt.createScheduledAttemptMetrics(
                    sContext);
            List<byte[]> certChains = mServerInterface.requestSignedCertificates(new byte[0],
                    new byte[0], metrics);
            // Check for null first so a null result fails with a precise message rather than
            // inside isEmpty().
            assertThat(certChains).isNotNull();
            assertThat(certChains).isEmpty();
        }
    }

    /**
     * Once the error data budget is exhausted, {@code fetchGeek} must refuse to touch the network
     * even though connectivity is available.
     */
    @Test
    public void testDataBudgetEmptyFetchGeekNetworkConnected() throws Exception {
        try (FakeRkpServer server = new FakeRkpServer(
                FakeRkpServer.Response.FETCH_EEK_OK,
                FakeRkpServer.Response.SIGN_CERTS_OK_VALID_CBOR)) {
            Settings.setDeviceConfig(sContext, 2 /* extraKeys */,
                    TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
            // Check the data budget in order to initialize a rolling window.
            assertThat(Settings.hasErrDataBudget(sContext, null /* curTime */)).isTrue();
            Settings.consumeErrDataBudget(sContext, Settings.FAILURE_DATA_USAGE_MAX);
            ProvisioningAttempt metrics = ProvisioningAttempt.createScheduledAttemptMetrics(
                    sContext);
            mServerInterface.fetchGeek(metrics);
            assertWithMessage("Network transaction should not have proceeded.").fail();
        } catch (RkpdException e) {
            assertThat(e).hasMessageThat().contains("Out of data budget due to repeated errors");
            assertThat(e.getErrorCode()).isEqualTo(
                    RkpdException.ErrorCode.NETWORK_COMMUNICATION_ERROR);
        }
    }

    /** With no network reported, {@code fetchGeek} fails with {@code NO_NETWORK_CONNECTIVITY}. */
    @Test
    public void testNetworkDisconnected() throws Exception {
        try (FakeRkpServer server = new FakeRkpServer(
                FakeRkpServer.Response.FETCH_EEK_OK,
                FakeRkpServer.Response.SIGN_CERTS_OK_VALID_CBOR)) {
            Settings.setDeviceConfig(sContext, 2 /* extraKeys */,
                    TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
            ProvisioningAttempt metrics = ProvisioningAttempt.createScheduledAttemptMetrics(
                    sContext);
            // We are okay in mocking connectivity failure since network check is the first thing
            // to happen.
            Utils.mockConnectivityState(sContext, Utils.ConnectivityState.DISCONNECTED);
            mServerInterface.fetchGeek(metrics);
            assertWithMessage("Network transaction should not have proceeded.").fail();
        } catch (RkpdException e) {
            assertThat(e).hasMessageThat().contains("No network detected");
            assertThat(e.getErrorCode()).isEqualTo(RkpdException.ErrorCode.NO_NETWORK_CONNECTIVITY);
        }
    }

    /** A non-text, non-JSON content type is reported rather than read. */
    @Test
    public void testReadErrorInvalidContentType() {
        HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
        Mockito.when(connection.getContentType()).thenReturn("application/NOPE");
        assertThat(ServerInterface.readErrorFromConnection(connection))
                .isEqualTo("Unexpected content type from the server: application/NOPE");
    }

    /** If the input stream throws and there is no error stream, a fixed placeholder is returned. */
    @Test
    public void testReadTextErrorFromErrorStreamNoErrorData() throws Exception {
        final String expectedError = "No error data returned by server.";
        HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
        Mockito.when(connection.getContentType()).thenReturn("text");
        Mockito.when(connection.getInputStream()).thenThrow(new IOException());
        Mockito.when(connection.getErrorStream()).thenReturn(null);
        assertThat(ServerInterface.readErrorFromConnection(connection)).isEqualTo(expectedError);
    }

    /** If the input stream throws, the error text is read from the error stream instead. */
    @Test
    public void testReadTextErrorFromErrorStream() throws Exception {
        final String error = "Explanation for error goes here.";
        HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
        Mockito.when(connection.getContentType()).thenReturn("text");
        Mockito.when(connection.getInputStream()).thenThrow(new IOException());
        Mockito.when(connection.getErrorStream())
                .thenReturn(new ByteArrayInputStream(error.getBytes(StandardCharsets.UTF_8)));
        assertThat(ServerInterface.readErrorFromConnection(connection)).isEqualTo(error);
    }

    /** Any {@code text*} content type, with or without a charset, is read as plain text. */
    @Test
    public void testReadTextError() throws IOException {
        final String error = "This is an error. Oh No.";
        final String[] textContentTypes = new String[]{
                "text",
                "text/ANYTHING",
                "text/what-is-this; charset=unknown",
                "text/lowercase; charset=utf-8",
                "text/uppercase; charset=UTF-8",
                "text/yolo; charset=ASCII"
        };
        for (String contentType : textContentTypes) {
            HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
            Mockito.when(connection.getContentType()).thenReturn(contentType);
            Mockito.when(connection.getInputStream())
                    .thenReturn(new ByteArrayInputStream(error.getBytes(StandardCharsets.UTF_8)));
            // The value under test is the Truth subject so failure messages label actual vs.
            // expected correctly.
            assertWithMessage("Failed on content type '" + contentType + "'")
                    .that(ServerInterface.readErrorFromConnection(connection))
                    .isEqualTo(error);
        }
    }

    /** A JSON content type is returned as the raw body text (no JSON parsing is required). */
    @Test
    public void testReadJsonError() throws IOException {
        final String error = "Not really JSON.";
        HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
        Mockito.when(connection.getContentType()).thenReturn("application/json");
        Mockito.when(connection.getInputStream())
                .thenReturn(new ByteArrayInputStream(error.getBytes(StandardCharsets.UTF_8)));
        assertThat(ServerInterface.readErrorFromConnection(connection)).isEqualTo(error);
    }

    /** An {@link IOException} while reading the body is reported, not propagated. */
    @Test
    public void testReadErrorStreamThrowsException() throws IOException {
        InputStream stream = Mockito.mock(InputStream.class);
        Mockito.when(stream.read(Mockito.any())).thenThrow(new IOException());
        HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
        Mockito.when(connection.getContentType()).thenReturn("text");
        Mockito.when(connection.getInputStream()).thenReturn(stream);
        final String error = ServerInterface.readErrorFromConnection(connection);
        assertWithMessage("Error string: '" + error + "'")
                .that(error).startsWith("Error reading error string from server: ");
    }

    /** An empty body yields the fixed "no error data" placeholder. */
    @Test
    public void testReadErrorEmptyStream() throws IOException {
        HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
        Mockito.when(connection.getContentType()).thenReturn("text");
        Mockito.when(connection.getInputStream())
                .thenReturn(new ByteArrayInputStream(new byte[0]));
        assertThat(ServerInterface.readErrorFromConnection(connection))
                .isEqualTo("No error data returned by server.");
    }

    /** Oversized error bodies are truncated to their first 1024 characters. */
    @Test
    public void testReadErrorStreamTooLarge() throws IOException {
        final StringBuilder sb = new StringBuilder();
        for (int i = 0; i < 2048; ++i) {
            sb.append(i % 100);
        }
        final String bigString = sb.toString();
        HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
        Mockito.when(connection.getContentType()).thenReturn("text");
        Mockito.when(connection.getInputStream())
                .thenReturn(new ByteArrayInputStream(bigString.getBytes(StandardCharsets.UTF_8)));
        // Reuse the builder to produce the expected 1024-character prefix.
        sb.setLength(1024);
        assertThat(ServerInterface.readErrorFromConnection(connection)).isEqualTo(sb.toString());
    }

    /**
     * A regional property containing "cn" selects the restricted connect timeout; null, empty, or
     * other regions select the open timeout.
     */
    @Test
    public void testServerConnectionTimeout() {
        ServerInterface serverInterface = Mockito.spy(mServerInterface);
        Mockito.when(serverInterface.getRegionalProperty()).thenReturn("cn");
        assertThat(serverInterface.getConnectTimeoutMs()).isEqualTo(
                ServerInterface.SYNC_CONNECT_TIMEOUT_RETRICTED_MS);
        Mockito.when(serverInterface.getRegionalProperty()).thenReturn("cn,us");
        assertThat(serverInterface.getConnectTimeoutMs()).isEqualTo(
                ServerInterface.SYNC_CONNECT_TIMEOUT_RETRICTED_MS);
        Mockito.when(serverInterface.getRegionalProperty()).thenReturn(null);
        assertThat(serverInterface.getConnectTimeoutMs()).isEqualTo(
                ServerInterface.SYNC_CONNECT_TIMEOUT_OPEN_MS);
        Mockito.when(serverInterface.getRegionalProperty()).thenReturn("");
        assertThat(serverInterface.getConnectTimeoutMs()).isEqualTo(
                ServerInterface.SYNC_CONNECT_TIMEOUT_OPEN_MS);
        Mockito.when(serverInterface.getRegionalProperty()).thenReturn("us");
        assertThat(serverInterface.getConnectTimeoutMs()).isEqualTo(
                ServerInterface.SYNC_CONNECT_TIMEOUT_OPEN_MS);
    }

    /**
     * Network consent is assumed unconditionally without the {@code cn.google.services} feature.
     * With the feature present, it follows the looked-up package's {@code enabled} flag and is
     * denied when the package is not found.
     */
    @Test
    public void testConnectionConsent() throws Exception {
        String cnGmsFeature = "cn.google.services";
        PackageManager mockedPackageManager = Mockito.mock(PackageManager.class);
        Context mockedContext = Mockito.mock(Context.class);
        ApplicationInfo fakeApplicationInfo = new ApplicationInfo();
        Mockito.when(mockedContext.getPackageManager()).thenReturn(mockedPackageManager);
        Mockito.when(mockedPackageManager.hasSystemFeature(cnGmsFeature)).thenReturn(true);
        Mockito.when(mockedPackageManager.getApplicationInfo(Mockito.any(), Mockito.eq(0)))
                .thenReturn(fakeApplicationInfo);
        fakeApplicationInfo.enabled = false;
        assertThat(ServerInterface.assumeNetworkConsent(mockedContext)).isFalse();
        fakeApplicationInfo.enabled = true;
        assertThat(ServerInterface.assumeNetworkConsent(mockedContext)).isTrue();
        Mockito.when(mockedPackageManager.getApplicationInfo(Mockito.any(), Mockito.eq(0)))
                .thenThrow(new PackageManager.NameNotFoundException());
        assertThat(ServerInterface.assumeNetworkConsent(mockedContext)).isFalse();
        Mockito.when(mockedPackageManager.hasSystemFeature(cnGmsFeature)).thenReturn(false);
        assertThat(ServerInterface.assumeNetworkConsent(mockedContext)).isTrue();
        fakeApplicationInfo.enabled = false;
        assertThat(ServerInterface.assumeNetworkConsent(mockedContext)).isTrue();
    }
}