blob: a0d15ffc9229c6e29006cb9284569ca90110abab [file] [log] [blame]
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package android.platform.test.flag.junit;
import static org.junit.Assume.assumeFalse;
import android.platform.test.flag.util.FlagReadException;
import android.platform.test.flag.util.FlagSetException;
import com.google.common.base.CaseFormat;
import com.google.common.collect.Sets;
import org.junit.rules.TestRule;
import org.junit.runner.Description;
import org.junit.runners.model.Statement;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
/** A {@link TestRule} that helps to set flag values in unit tests. */
public final class SetFlagsRule implements TestRule {
// Simple names of the classes, field and methods that this rule accesses via reflection.
// They must be consistent with the names in the aconfig auto generated code.
private static final String FAKE_FEATURE_FLAGS_IMPL_CLASS_NAME = "FakeFeatureFlagsImpl";
private static final String FEATURE_FLAGS_CLASS_NAME = "FeatureFlags";
private static final String FEATURE_FLAGS_FIELD_NAME = "FEATURE_FLAGS";
private static final String FLAGS_CLASS_NAME = "Flags";
// NOTE(review): FLAG_CONSTANT_PREFIX does not appear to be referenced in this file - confirm
// whether it is still needed.
private static final String FLAG_CONSTANT_PREFIX = "FLAG_";
private static final String SET_FLAG_METHOD_NAME = "setFlag";
private static final String RESET_ALL_METHOD_NAME = "resetAll";
private static final String IS_FLAG_READ_ONLY_OPTIMIZED_METHOD_NAME = "isFlagReadOnlyOptimized";
// Store instances for entire life of a SetFlagsRule instance
private final Map<Class<?>, Object> mFlagsClassToFakeFlagsImpl = new HashMap<>();
private final Map<Class<?>, Object> mFlagsClassToRealFlagsImpl = new HashMap<>();
// Store classes that are currently mutated by this rule, i.e. whose FEATURE_FLAGS field
// currently holds the fake implementation; emptied again by resetFlags()
private final Set<Class<?>> mMutatedFlagsClasses = new HashSet<>();
// Any flags added to this set cannot be set imperatively (i.e. with enableFlags/disableFlags)
private final Set<String> mLockedFlagNames = new HashSet<>();
// Package name prefixes that are tried, in order, when looking for the Flags class of a flag.
// The empty prefix stands for the flag's own package.
// TODO(322377082): remove repackage prefix list
private static final String[] REPACKAGE_PREFIX_LIST =
new String[] {
"", "com.android.internal.hidden_from_bootclasspath.",
};
// Cache from a flag's package name to the (re)packaged package names in which a Flags class
// was found
private final Map<String, Set<String>> mPackageToRepackage = new HashMap<>();
// When true, a newly created fake implementation is constructed with the real implementation
// as its argument; otherwise it is constructed with null
private boolean mIsInitWithDefault = false;
// Optional parameterization whose overrides are applied when the rule is evaluated
private FlagsParameterization mFlagsParameterization;
// True only while the wrapped statement is being evaluated; enableFlags/disableFlags require it
private boolean mIsRuleEvaluating = false;
/**
 * Switches this rule to initialize flags with their default values.
 *
 * <p>Once this method has been called, flag values in the same Flags class are set to the
 * values of the current release configuration.
 *
 * <p>This method should be called before calling {@link #enableFlags} or
 * {@link #disableFlags}. Calling it again after defaults are already enabled has no effect.
 *
 * @throws IllegalStateException if defaults were not yet enabled and some flags have already
 *     been set through this rule
 * @deprecated Prefer constructing the rule with {@link #SetFlagsRule(DefaultInitValueType)}
 *     and {@link DefaultInitValueType#DEVICE_DEFAULT}.
 */
@Deprecated
public void initAllFlagsToReleaseConfigDefault() {
    if (mIsInitWithDefault) {
        // Defaults are already in effect; nothing needs to be verified or changed.
        return;
    }
    // If any flags have already been set, it is too late to change the defaults.
    ensureFlagsAreUnset();
    mIsInitWithDefault = true;
}
/** Selects how flag values are initialized before a test sets them explicitly. */
public enum DefaultInitValueType {
/**
* Initialize flag values as null.
*
* <p>A flag value needs to be set before it is used.
*/
NULL_DEFAULT,
/**
* Initialize flag values with the default values from the device.
*
* <p>If a flag value is not overridden by adb, then the default value is from the release
* configuration when the test is built.
*/
DEVICE_DEFAULT,
}
/**
* Creates a rule that uses {@link DefaultInitValueType#DEVICE_DEFAULT} and has no
* {@link FlagsParameterization}.
*/
public SetFlagsRule() {
this(DefaultInitValueType.DEVICE_DEFAULT);
}
/**
* Creates a rule with the given default initialization and no {@link FlagsParameterization}.
*
* @param defaultType how flag values are initialized before a test sets them
*/
public SetFlagsRule(DefaultInitValueType defaultType) {
this(defaultType, null);
}
/**
 * Creates a rule with the given default initialization and an optional parameterization.
 *
 * <p>Every flag named in the parameterization's overrides is locked, so it cannot be changed
 * with {@link #enableFlags} or {@link #disableFlags}.
 *
 * @param defaultType how flag values are initialized before a test sets them; must not be
 *     null
 * @param flagsParameterization the parameterization to apply when the rule is evaluated, or
 *     null for none
 */
public SetFlagsRule(
        DefaultInitValueType defaultType,
        @Nullable FlagsParameterization flagsParameterization) {
    // DEVICE_DEFAULT is the only type that turns default initialization on; every other type
    // keeps the field's initial value (false).
    if (Objects.requireNonNull(defaultType) == DefaultInitValueType.DEVICE_DEFAULT) {
        mIsInitWithDefault = true;
    }
    mFlagsParameterization = flagsParameterization;
    if (flagsParameterization == null) {
        return;
    }
    mLockedFlagNames.addAll(flagsParameterization.mOverrides.keySet());
}
/**
 * Installs the {@link FlagsParameterization} used for this test.
 *
 * <p>A parameterization can be installed at most once, only while the rule is not running,
 * and only before any flag has been set through this rule. All flags it overrides become
 * locked against {@link #enableFlags} and {@link #disableFlags}.
 *
 * @param flagsParameterization the parameterization to use; must not be null
 * @throws NullPointerException if {@code flagsParameterization} is null
 * @throws AssertionError if a parameterization is already present or the rule is running
 * @throws IllegalStateException if some flags have already been set
 */
public void setFlagsParameterization(@Nonnull FlagsParameterization flagsParameterization) {
    Objects.requireNonNull(flagsParameterization, "FlagsParameterization cannot be cleared");
    final boolean alreadyParameterized = mFlagsParameterization != null;
    if (alreadyParameterized) {
        throw new AssertionError("FlagsParameterization cannot be overridden");
    }
    if (mIsRuleEvaluating) {
        throw new AssertionError("Cannot set FlagsParameterization once the rule is running");
    }
    ensureFlagsAreUnset();

    mFlagsParameterization = flagsParameterization;
    // Parameterized flags must not be changed imperatively afterwards.
    mLockedFlagNames.addAll(flagsParameterization.mOverrides.keySet());
}
/**
 * Enables the given flags.
 *
 * @param fullFlagNames The name of the flags in the flag class with the format
 *     {packageName}.{flagName}
 * @throws IllegalStateException if called while the rule is not evaluating a test
 * @throws FlagSetException if one of the flags is locked or cannot be set
 */
public void enableFlags(String... fullFlagNames) {
    setFlagsImperatively(true, fullFlagNames);
}

/**
 * Disables the given flags.
 *
 * @param fullFlagNames The name of the flags in the flag class with the format
 *     {packageName}.{flagName}
 * @throws IllegalStateException if called while the rule is not evaluating a test
 * @throws FlagSetException if one of the flags is locked or cannot be set
 */
public void disableFlags(String... fullFlagNames) {
    setFlagsImperatively(false, fullFlagNames);
}

/**
 * Shared implementation of {@link #enableFlags} and {@link #disableFlags}.
 *
 * <p>Flags are processed in the given order; a locked flag aborts processing at that point,
 * so flags that precede it in the array have already been set.
 */
private void setFlagsImperatively(boolean value, String[] fullFlagNames) {
    if (!mIsRuleEvaluating) {
        throw new IllegalStateException("Not allowed to set flags outside test and setup code");
    }
    for (String fullFlagName : fullFlagNames) {
        if (mLockedFlagNames.contains(fullFlagName)) {
            throw new FlagSetException(fullFlagName, "Not allowed to change locked flags");
        }
        setFlagValue(fullFlagName, value);
    }
}
/**
 * Verifies that this rule has not created any fake flags implementation yet.
 *
 * @throws IllegalStateException if at least one fake implementation already exists
 */
private void ensureFlagsAreUnset() {
    if (mFlagsClassToFakeFlagsImpl.isEmpty()) {
        return;
    }
    throw new IllegalStateException("Some flags were set before the rule was initialized");
}
/**
 * Wraps {@code base} so that flag values are applied before it runs and reset afterwards.
 *
 * <p>Before the wrapped statement runs, the flag annotations of {@code description} are
 * checked against the parameterization, then the parameterization overrides and the
 * annotation-set flag values are applied, and all annotated flags are locked. After the
 * statement finishes (normally or not) the flags are reset. If the reset itself fails, that
 * failure is thrown with the earlier failure, if any, attached as suppressed.
 */
@Override
public Statement apply(Statement base, Description description) {
    return new Statement() {
        @Override
        public void evaluate() throws Throwable {
            Throwable failure = null;
            try {
                AnnotationsRetriever.FlagAnnotations annotations =
                        AnnotationsRetriever.getFlagAnnotations(description);
                assertAnnotationsMatchParameterization(annotations, mFlagsParameterization);
                annotations.assumeAllSetFlagsMatchParameterization(mFlagsParameterization);

                if (mFlagsParameterization != null) {
                    ensureFlagsAreUnset();
                    for (Map.Entry<String, Boolean> override :
                            mFlagsParameterization.mOverrides.entrySet()) {
                        setFlagValue(override.getKey(), override.getValue());
                    }
                }
                for (Map.Entry<String, Boolean> annotated :
                        annotations.mSetFlagValues.entrySet()) {
                    setFlagValue(annotated.getKey(), annotated.getValue());
                }

                // Annotated flags must not be changed imperatively by the test.
                mLockedFlagNames.addAll(annotations.mRequiredFlagValues.keySet());
                mLockedFlagNames.addAll(annotations.mSetFlagValues.keySet());

                mIsRuleEvaluating = true;
                base.evaluate();
            } catch (Throwable testFailure) {
                failure = testFailure;
            }

            // Every Throwable is captured above, so the cleanup below always runs.
            mIsRuleEvaluating = false;
            try {
                resetFlags();
            } catch (Throwable resetFailure) {
                if (failure != null) {
                    resetFailure.addSuppressed(failure);
                }
                failure = resetFailure;
            }

            if (failure != null) {
                throw failure;
            }
        }
    };
}
/**
 * Asserts that no flag with an annotation-required value is part of the parameterization.
 *
 * <p>Does nothing when {@code parameterization} is null.
 *
 * @throws AssertionError if at least one flag is both required by annotation and
 *     parameterized
 */
private static void assertAnnotationsMatchParameterization(
        AnnotationsRetriever.FlagAnnotations flagAnnotations,
        FlagsParameterization parameterization) {
    if (parameterization != null) {
        Set<String> conflictingFlags =
                Sets.intersection(
                        parameterization.mOverrides.keySet(),
                        flagAnnotations.mRequiredFlagValues.keySet());
        if (conflictingFlags.isEmpty()) {
            return;
        }
        throw new AssertionError(
                "The following flags have required values (per @RequiresFlagsEnabled or"
                        + " @RequiresFlagsDisabled) but they are part of the"
                        + " FlagParameterization: "
                        + conflictingFlags);
    }
}
/**
 * Sets the flag with the given full name to {@code value} in every package that contains a
 * Flags class for it.
 *
 * @param fullFlagName the flag name in the format {packageName}.{flagName}
 * @param value the value to set
 * @throws FlagSetException if the name contains no '.' or no matching Flags class is found
 */
private void setFlagValue(String fullFlagName, boolean value) {
    if (!fullFlagName.contains(".")) {
        throw new FlagSetException(
                fullFlagName, "Flag name is not the expected format {packageName}.{flagName}.");
    }
    // Get all packages containing Flags referencing the same fullFlagName.
    Set<String> packageSet = getPackagesContainsFlag(fullFlagName);
    for (String packageName : packageSet) {
        setFlagValue(Flag.createFlag(fullFlagName, packageName), value);
    }
}
/**
 * Resolves the packages whose Flags class can be loaded for the given flag.
 *
 * <p>Each prefix of {@code REPACKAGE_PREFIX_LIST} is prepended to the flag's package name and
 * the resulting Flags class is looked up without initializing it. A non-empty result is
 * returned from the cache on later calls for the same package.
 *
 * @param fullFlagName the flag name in the format {packageName}.{flagName}
 * @return the non-empty set of package names that contain a Flags class
 * @throws FlagSetException if no candidate package contains a Flags class
 */
private Set<String> getPackagesContainsFlag(String fullFlagName) {
    String packageName = Flag.createFlag(fullFlagName).packageName();

    Set<String> cached = mPackageToRepackage.get(packageName);
    if (cached != null && !cached.isEmpty()) {
        return cached;
    }

    Set<String> resolved = new HashSet<>();
    ClassLoader loader = getClass().getClassLoader();
    for (String prefix : REPACKAGE_PREFIX_LIST) {
        String candidatePackage = prefix + packageName;
        try {
            Class.forName(candidatePackage + "." + FLAGS_CLASS_NAME, false, loader);
            resolved.add(candidatePackage);
        } catch (ClassNotFoundException e) {
            // This candidate has no Flags class. That is fine as long as another candidate
            // has one; the empty case is reported below.
        }
    }
    mPackageToRepackage.put(packageName, resolved);

    if (resolved.isEmpty()) {
        throw new FlagSetException(
                fullFlagName,
                "Cannot find package containing Flags class referencing to this flag.");
    }
    return resolved;
}
/**
 * Sets {@code flag} to {@code value} in the fake implementation of its Flags class,
 * installing the fake into the Flags class first if this rule has not done so yet.
 *
 * <p>The test is skipped (via a failed assumption) when the fake reports the flag as
 * read-only in an optimized build.
 */
private void setFlagValue(Flag flag, boolean value) {
    final Class<?> flagsClass = getFlagClassFromFlag(flag);
    final Object fakeImpl = getOrCreateFakeFlagsImp(flagsClass);

    boolean alreadyMutated = mMutatedFlagsClasses.contains(flagsClass);
    if (!alreadyMutated) {
        // Swap the FeatureFlags held by the Flags class for the FakeFeatureFlagsImpl.
        replaceFlagsImpl(flagsClass, fakeImpl);
        mMutatedFlagsClasses.add(flagsClass);
    }

    // Setting a read_only flag in an optimized build is not a valid test case, so such a test
    // is skipped rather than failed: every read_write flag becomes read_only in the final
    // release configuration, and the test can still run under other release configurations.
    boolean readOnlyAndOptimized = verifyFlagReadOnlyAndOptimized(fakeImpl, flag);
    String skipMessage =
            String.format(
                    "Flag %s is read_only, and the code is optimized. "
                            + " The flag value should not be modified on this build"
                            + " Skip this test.",
                    flag.fullFlagName());
    assumeFalse(skipMessage, readOnlyAndOptimized);

    // Record the desired value in the FakeFeatureFlagsImpl.
    setFlagValueInFakeFeatureFlagsImpl(fakeImpl, flag, value);
}
/**
 * Loads the Flags class named by {@code flag}.
 *
 * @throws FlagSetException if the class cannot be found
 */
private Class<?> getFlagClassFromFlag(Flag flag) {
    final String flagsClassName = flag.flagsClassName();
    try {
        return Class.forName(flagsClassName);
    } catch (ClassNotFoundException notFound) {
        String reason =
                String.format(
                        "Can not load the Flags class %s to set its values. Please check the "
                                + "flag name and ensure that the aconfig auto generated "
                                + "library is in the dependency.",
                        flagsClassName);
        throw new FlagSetException(flag.fullFlagName(), reason, notFound);
    }
}
/**
 * Reads the boolean value of {@code flag} from {@code featureFlagsImpl} by reflectively
 * invoking the flag's no-argument accessor method.
 *
 * @param featureFlagsImpl the flags implementation instance to read from
 * @param flag the flag to read
 * @return the value returned by the accessor
 * @throws FlagReadException if the accessor does not exist, cannot be invoked, or returns
 *     something other than a Boolean (including null)
 */
private boolean getFlagValue(Object featureFlagsImpl, Flag flag) {
    // Must be consistent with method name in aconfig auto generated code.
    String methodName = getFlagMethodName(flag);
    String fullFlagName = flag.fullFlagName();
    try {
        Object result =
                featureFlagsImpl.getClass().getMethod(methodName).invoke(featureFlagsImpl);
        if (result instanceof Boolean) {
            return (Boolean) result;
        }
        // A reflective call yields null for void methods and for methods that return null;
        // report that as a wrong flag type instead of failing with a NullPointerException.
        String actualType = result == null ? "null" : result.getClass().getSimpleName();
        throw new FlagReadException(
                fullFlagName, String.format("Flag type is %s, not boolean", actualType));
    } catch (NoSuchMethodException e) {
        throw new FlagReadException(
                fullFlagName,
                String.format(
                        "No method %s in the Flags class %s to read the flag value. Please"
                                + " check the flag name.",
                        methodName, featureFlagsImpl.getClass().getName()),
                e);
    } catch (ReflectiveOperationException e) {
        throw new FlagReadException(
                fullFlagName,
                String.format(
                        "Fail to get value of flag %s from instance %s",
                        fullFlagName, featureFlagsImpl.getClass().getName()),
                e);
    }
}
/** Converts the flag's lower_underscore simple name into a lowerCamel method name. */
private String getFlagMethodName(Flag flag) {
    final String simpleName = flag.simpleFlagName();
    return CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, simpleName);
}
/**
 * Reflectively calls {@code setFlag(String, boolean)} on the given fake implementation.
 *
 * @throws FlagSetException if the instance has no such method or the call fails
 */
private void setFlagValueInFakeFeatureFlagsImpl(
        Object fakeFeatureFlagsImpl, Flag flag, boolean value) {
    final String flagName = flag.fullFlagName();
    try {
        Method setter =
                fakeFeatureFlagsImpl
                        .getClass()
                        .getMethod(SET_FLAG_METHOD_NAME, String.class, boolean.class);
        setter.invoke(fakeFeatureFlagsImpl, flagName, value);
    } catch (NoSuchMethodException missingSetter) {
        // Without a setFlag method the instance cannot be the fake implementation.
        String reason =
                String.format(
                        "Flag implementation %s is not fake implementation",
                        fakeFeatureFlagsImpl.getClass().getName());
        throw new FlagSetException(flagName, reason, missingSetter);
    } catch (ReflectiveOperationException failure) {
        throw new FlagSetException(flagName, failure);
    }
}
/**
 * Asks the fake implementation, via reflection, whether {@code flag} is read-only in an
 * optimized build.
 *
 * @return the result of {@code isFlagReadOnlyOptimized(String)}, or false when the method is
 *     missing on a class whose simple name is {@code FakeFeatureFlagsImpl}
 * @throws FlagSetException if the method is missing on any other class, or the call fails
 */
private boolean verifyFlagReadOnlyAndOptimized(Object fakeFeatureFlagsImpl, Flag flag) {
    final String flagName = flag.fullFlagName();
    final Class<?> implClass = fakeFeatureFlagsImpl.getClass();
    try {
        Method checker =
                implClass.getMethod(IS_FLAG_READ_ONLY_OPTIMIZED_METHOD_NAME, String.class);
        return (Boolean) checker.invoke(fakeFeatureFlagsImpl, flagName);
    } catch (NoSuchMethodException missingChecker) {
        // If the flag is generated under exported mode, then it doesn't have this method.
        boolean isFakeImpl =
                FAKE_FEATURE_FLAGS_IMPL_CLASS_NAME.equals(implClass.getSimpleName());
        if (!isFakeImpl) {
            throw new FlagSetException(
                    flagName,
                    String.format(
                            "Cannot check whether flag is optimized. "
                                    + "Flag implementation %s is not fake implementation",
                            implClass.getName()),
                    missingChecker);
        }
        return false;
    } catch (ReflectiveOperationException failure) {
        throw new FlagSetException(flagName, failure);
    }
}
/**
 * Returns the fake flags implementation for {@code flagsClass}, creating and caching it on
 * first use.
 *
 * <p>On creation, the implementation currently stored in the Flags class is remembered as the
 * real one. The fake is constructed with that real implementation when default initialization
 * is enabled, and with null otherwise.
 *
 * @throws UnsupportedOperationException if the current implementation cannot be read or the
 *     fake cannot be constructed
 */
@Nonnull
private Object getOrCreateFakeFlagsImp(Class<?> flagsClass) {
    Object cachedFake = mFlagsClassToFakeFlagsImpl.get(flagsClass);
    if (cachedFake != null) {
        return cachedFake;
    }

    final String packagePrefix = flagsClass.getPackageName() + ".";
    final String fakeClassName = packagePrefix + FAKE_FEATURE_FLAGS_IMPL_CLASS_NAME;
    final String interfaceName = packagePrefix + FEATURE_FLAGS_CLASS_NAME;

    // Capture the implementation that is currently installed so it can be restored later.
    Field flagsField = getFeatureFlagsField(flagsClass);
    flagsField.setAccessible(true);
    final Object realImpl;
    try {
        realImpl = flagsField.get(null);
    } catch (IllegalAccessException e) {
        throw new UnsupportedOperationException(
                String.format(
                        "Cannot get FeatureFlagsImpl from Flags class %s.",
                        flagsClass.getName()),
                e);
    }
    mFlagsClassToRealFlagsImpl.put(flagsClass, realImpl);

    final Object constructorArg = mIsInitWithDefault ? realImpl : null;
    final Object newFake;
    try {
        Class<?> fakeClass = Class.forName(fakeClassName);
        Class<?> flagsInterface = Class.forName(interfaceName);
        newFake = fakeClass.getConstructor(flagsInterface).newInstance(constructorArg);
    } catch (ReflectiveOperationException e) {
        throw new UnsupportedOperationException(
                String.format(
                        "Cannot create FakeFeatureFlagsImpl in Flags class %s.",
                        flagsClass.getName()),
                e);
    }
    mFlagsClassToFakeFlagsImpl.put(flagsClass, newFake);
    return newFake;
}
/**
 * Stores {@code flagsImplInstance} in the static FEATURE_FLAGS field of {@code flagsClass}.
 *
 * @param flagsClass the Flags class whose implementation is replaced
 * @param flagsImplInstance the implementation to install; may be null when the previously
 *     captured implementation was null
 * @throws UnsupportedOperationException if the field cannot be found or written
 */
private void replaceFlagsImpl(Class<?> flagsClass, Object flagsImplInstance) {
    Field featureFlagsField = getFeatureFlagsField(flagsClass);
    featureFlagsField.setAccessible(true);
    try {
        featureFlagsField.set(null, flagsImplInstance);
    } catch (IllegalAccessException e) {
        // Build the message null-safely so that a null instance does not hide the real
        // failure behind a NullPointerException.
        String implName =
                flagsImplInstance == null ? "null" : flagsImplInstance.getClass().getName();
        throw new UnsupportedOperationException(
                String.format("Cannot replace FeatureFlagsImpl to %s.", implName), e);
    }
}
/**
 * Looks up the FEATURE_FLAGS field declared by {@code flagsClass}.
 *
 * @throws UnsupportedOperationException if the lookup fails
 */
private Field getFeatureFlagsField(Class<?> flagsClass) {
    try {
        return flagsClass.getDeclaredField(FEATURE_FLAGS_FIELD_NAME);
    } catch (ReflectiveOperationException lookupFailure) {
        String reason =
                String.format(
                        "Cannot store FeatureFlagsImpl in Flag %s.", flagsClass.getName());
        throw new UnsupportedOperationException(reason, lookupFailure);
    }
}
/**
 * Restores the real implementation in every Flags class mutated by this rule and calls
 * {@code resetAll()} on the corresponding fake.
 *
 * <p>Each class is handled independently: a failure for one class does not prevent the
 * remaining classes from being restored. Classes that were reset successfully are removed from
 * the set of mutated classes; classes that failed stay in it.
 *
 * @throws FlagSetException for the first class that failed to reset, with the failures of any
 *     further classes attached as suppressed exceptions
 */
private void resetFlags() {
    Set<Class<?>> restoredClasses = new HashSet<>();
    FlagSetException firstFailure = null;
    for (Class<?> flagsClass : mMutatedFlagsClasses) {
        try {
            Object fakeFlagsImplInstance = mFlagsClassToFakeFlagsImpl.get(flagsClass);
            Object flagsImplInstance = mFlagsClassToRealFlagsImpl.get(flagsClass);
            // Replace FeatureFlags in Flags class with real FeatureFlagsImpl
            replaceFlagsImpl(flagsClass, flagsImplInstance);
            fakeFlagsImplInstance
                    .getClass()
                    .getMethod(RESET_ALL_METHOD_NAME)
                    .invoke(fakeFlagsImplInstance);
            restoredClasses.add(flagsClass);
        } catch (Exception e) {
            FlagSetException failure = new FlagSetException(flagsClass.getName(), e);
            if (firstFailure == null) {
                firstFailure = failure;
            } else {
                firstFailure.addSuppressed(failure);
            }
        }
    }
    // Removal happens after the loop to avoid modifying the set while iterating over it.
    mMutatedFlagsClasses.removeAll(restoredClasses);
    if (firstFailure != null) {
        throw firstFailure;
    }
}
/**
 * Immutable description of a flag: its full name, the package whose Flags class should be
 * used for it, and its simple name (the part of the full name after the last '.').
 */
private static class Flag {
    private static final String PACKAGE_NAME_SIMPLE_NAME_SEPARATOR = ".";
    private final String mFullFlagName;
    private final String mPackageName;
    private final String mSimpleFlagName;

    /**
     * Creates a flag whose package name is the part of {@code fullFlagName} before the last
     * separator.
     *
     * @throws IllegalArgumentException if {@code fullFlagName} contains no separator, or the
     *     derived package name contains no separator
     */
    public static Flag createFlag(String fullFlagName) {
        int index = fullFlagName.lastIndexOf(PACKAGE_NAME_SIMPLE_NAME_SEPARATOR);
        if (index < 0) {
            // Without this check substring(0, -1) would fail with an index error instead of
            // the format error that describes the actual problem.
            throw new IllegalArgumentException(invalidFormatMessage(fullFlagName));
        }
        String packageName = fullFlagName.substring(0, index);
        return createFlag(fullFlagName, packageName);
    }

    /**
     * Creates a flag with an explicitly given package name, which may differ from the package
     * prefix of {@code fullFlagName}.
     *
     * @throws IllegalArgumentException if either argument contains no separator
     */
    public static Flag createFlag(String fullFlagName, String packageName) {
        if (!fullFlagName.contains(PACKAGE_NAME_SIMPLE_NAME_SEPARATOR)
                || !packageName.contains(PACKAGE_NAME_SIMPLE_NAME_SEPARATOR)) {
            throw new IllegalArgumentException(invalidFormatMessage(fullFlagName));
        }
        int index = fullFlagName.lastIndexOf(PACKAGE_NAME_SIMPLE_NAME_SEPARATOR);
        String simpleFlagName = fullFlagName.substring(index + 1);
        return new Flag(fullFlagName, packageName, simpleFlagName);
    }

    /** Builds the error message used for every malformed flag name. */
    private static String invalidFormatMessage(String fullFlagName) {
        return String.format(
                "Flag %s is invalid. The format should be {packageName}" + ".{simpleFlagName}",
                fullFlagName);
    }

    private Flag(String fullFlagName, String packageName, String simpleFlagName) {
        this.mFullFlagName = fullFlagName;
        this.mPackageName = packageName;
        this.mSimpleFlagName = simpleFlagName;
    }

    /** Returns the full flag name as passed to the factory method. */
    public String fullFlagName() {
        return mFullFlagName;
    }

    /** Returns the package name associated with this flag. */
    public String packageName() {
        return mPackageName;
    }

    /** Returns the part of the full flag name after the last separator. */
    public String simpleFlagName() {
        return mSimpleFlagName;
    }

    /** Returns the fully qualified name of the Flags class in this flag's package. */
    public String flagsClassName() {
        return String.format("%s.%s", mPackageName, FLAGS_CLASS_NAME);
    }
}
}