Support changing flag values for device_config flags
+ some refactoring and format changes
Bug: 308215250
Test: local
Change-Id: I5a0449c746c443bbcf9ab9de5a85470543abd4fd
diff --git a/tools/device_flags.py b/tools/device_flags.py
index b6fbe80..22ba87d 100644
--- a/tools/device_flags.py
+++ b/tools/device_flags.py
@@ -18,7 +18,7 @@
import os
import tempfile
-from typing import Any, Dict, Optional
+from typing import Any
from mobly.controllers import android_device
from protos import aconfig_pb2
@@ -27,10 +27,14 @@
_ACONFIG_PB_FILE = 'aconfig_flags.pb'
_DEVICE_CONFIG_GET_CMD = 'device_config get'
+_DEVICE_CONFIG_PUT_CMD = 'device_config put'
_READ_ONLY = aconfig_pb2.flag_permission.READ_ONLY
_ENABLED = aconfig_pb2.flag_state.ENABLED
+_VAL_TRUE = 'true'
+_VAL_FALSE = 'false'
+
class DeviceFlags:
"""Provides access to aconfig and device_config flag values of a device."""
@@ -39,10 +43,10 @@
self._ad = ad
self._aconfig_flags = None
- def get_value(self, namespace: str, name: str) -> Optional[str]:
+ def get_value(self, namespace: str, key: str) -> str | None:
"""Gets the value of the requested flag.
- Flags must be specified by both its namespace and name.
+ Flags must be specified by both its namespace and key.
The method will first look for the flag from the device's
aconfig_flags.pb files, and, if not found or the flag is READ_WRITE,
@@ -52,38 +56,39 @@
Args:
namespace: The namespace of the flag.
- name: The full name of the flag.
- For aconfig flags, it is equivalent to '{package}.{name}' from
- the aconfig proto.
- For device_config flags, it is equivalent to '{KEY}' from the
- "device_config shell" command.
+ key: The full name of the flag. For aconfig flags, it is equivalent
+ to '{package}.{name}' from the aconfig proto. For device_config
+ flags, it is equivalent to '{KEY}' from the `device_config get`
+ command.
Returns:
The flag value as a string.
"""
# Check aconfig
aconfig_val = None
- aconfig_flag = self._get_aconfig_flags().get(
- '%s/%s' % (namespace, name))
+ aconfig_flag = self._get_aconfig_flags().get(f'{namespace}/{key}')
if aconfig_flag is not None:
- aconfig_val = 'true' if aconfig_flag.state == _ENABLED else 'false'
+ aconfig_val = (
+ _VAL_TRUE if aconfig_flag.state == _ENABLED else _VAL_FALSE)
if aconfig_flag.permission == _READ_ONLY:
return aconfig_val
# If missing or READ_WRITE, also check device_config
- device_config_val = self._ad.adb.shell(
- '%s %s %s' % (_DEVICE_CONFIG_GET_CMD, namespace, name)
- ).decode('utf8').strip()
+ device_config_val = (
+ self._ad.adb.shell(f'{_DEVICE_CONFIG_GET_CMD} {namespace} {key}')
+ .decode('utf8')
+ .strip()
+ )
return device_config_val if device_config_val != 'null' else aconfig_val
- def get_bool(self, namespace: str, name: str) -> bool:
+ def get_bool(self, namespace: str, key: str) -> bool:
"""Gets the value of the requested flag as a boolean.
See get_value() for details.
Args:
namespace: The namespace of the flag.
- name: The key of the flag.
+ key: The key of the flag.
Returns:
The flag value as a boolean.
@@ -91,33 +96,84 @@
Raises:
ValueError if the flag value cannot be expressed as a boolean.
"""
- val = self.get_value(namespace, name)
- if val.lower() == 'true':
- return True
- if val.lower() == 'false':
- return False
- raise ValueError('Flag %s/%s is not a boolean (value: %s).'
- % (namespace, name, val))
+ val = self.get_value(namespace, key)
+ if val is not None:
+ if val.lower() == _VAL_TRUE:
+ return True
+ if val.lower() == _VAL_FALSE:
+ return False
+ raise ValueError(
+ f'Flag {namespace}/{key} is not a boolean (value: {val}).')
- def _get_aconfig_flags(self) -> Dict[str, Any]:
+ def _get_aconfig_flags(self) -> dict[str, Any]:
"""Gets the aconfig flags as a dict. Loads from proto if necessary."""
if self._aconfig_flags is None:
- self._load_aconfig_flags()
+ self._aconfig_flags = self._load_aconfig_flags()
return self._aconfig_flags
- def _load_aconfig_flags(self) -> None:
+ def _load_aconfig_flags(self) -> dict[str, Any]:
"""Pull aconfig proto files from device, then load the flag info."""
- self._aconfig_flags = {}
+ aconfig_flags = {}
with tempfile.TemporaryDirectory() as tmp_dir:
for partition in _ACONFIG_PARTITIONS:
device_path = os.path.join(
'/', partition, 'etc', _ACONFIG_PB_FILE)
host_path = os.path.join(
- tmp_dir, '%s_%s' % (partition, _ACONFIG_PB_FILE))
+ tmp_dir, f'{partition}_{_ACONFIG_PB_FILE}')
self._ad.adb.pull([device_path, host_path])
with open(host_path, 'rb') as f:
parsed_flags = aconfig_pb2.parsed_flags.FromString(f.read())
for flag in parsed_flags.parsed_flag:
- full_name = '%s/%s.%s' % (
- flag.namespace, flag.package, flag.name)
- self._aconfig_flags[full_name] = flag
+ full_name = f'{flag.namespace}/{flag.package}.{flag.name}'
+ aconfig_flags[full_name] = flag
+ return aconfig_flags
+
+ def set_value(self, namespace: str, key: str, val: str) -> None:
+ """Sets the value of the requested flag.
+
+ This only supports flags that are set via `adb device_config`.
+
+ Args:
+ namespace: The namespace of the flag.
+ key: The key of the flag.
+ val: The desired value of the flag, in string format.
+ """
+ self._ad.adb.shell(f'{_DEVICE_CONFIG_PUT_CMD} {namespace} {key} {val}')
+
+ def enable(self, namespace: str, key: str) -> None:
+ """Enables the requested flag.
+
+ This only supports flags that are set via `adb device_config`.
+
+ Args:
+ namespace: The namespace of the flag.
+ key: The key of the flag.
+
+ Raises:
+ ValueError if the original flag value cannot be expressed as a
+ boolean.
+ """
+ # If the original value of the flag is not boolean, this will raise a
+ # ValueError.
+ _ = self.get_bool(namespace, key)
+
+ self.set_value(namespace, key, _VAL_TRUE)
+
+ def disable(self, namespace: str, key: str) -> None:
+ """Disables the requested flag.
+
+ This only supports flags that are set via `adb device_config`.
+
+ Args:
+ namespace: The namespace of the flag.
+ key: The key of the flag.
+
+ Raises:
+ ValueError if the original flag value cannot be expressed as a
+ boolean.
+ """
+ # If the original value of the flag is not boolean, this will raise a
+ # ValueError.
+ _ = self.get_bool(namespace, key)
+
+ self.set_value(namespace, key, _VAL_FALSE)
diff --git a/tools/device_flags_test.py b/tools/device_flags_test.py
index 40405a3..3efdf88 100644
--- a/tools/device_flags_test.py
+++ b/tools/device_flags_test.py
@@ -25,53 +25,96 @@
"""Unit tests for DeviceFlags."""
def setUp(self) -> None:
+ super().setUp()
self.ad = mock.MagicMock()
self.device_flags = device_flags.DeviceFlags(self.ad)
self.device_flags._aconfig_flags = {}
def test_get_value_aconfig_flag_missing_use_device_config(self) -> None:
self.ad.adb.shell.return_value = b'foo'
- self.assertEqual(self.device_flags.get_value('sample', 'flag'), 'foo')
+
+ value = self.device_flags.get_value('sample', 'flag')
+
+ self.assertEqual(value, 'foo')
def test_get_value_aconfig_flag_read_write_use_device_config(self) -> None:
sample_flag = aconfig_pb2.parsed_flag()
sample_flag.state = aconfig_pb2.flag_state.ENABLED
sample_flag.permission = aconfig_pb2.flag_permission.READ_WRITE
self.device_flags._aconfig_flags['sample/flag'] = sample_flag
-
self.ad.adb.shell.return_value = b'false'
- self.assertEqual(self.device_flags.get_value('sample', 'flag'), 'false')
+
+ value = self.device_flags.get_value('sample', 'flag')
+
+ self.assertEqual(value, 'false')
def test_get_value_aconfig_flag_read_only_use_aconfig(self) -> None:
sample_flag = aconfig_pb2.parsed_flag()
sample_flag.state = aconfig_pb2.flag_state.ENABLED
sample_flag.permission = aconfig_pb2.flag_permission.READ_ONLY
self.device_flags._aconfig_flags['sample/flag'] = sample_flag
-
self.ad.adb.shell.return_value = b'false'
- self.assertEqual(self.device_flags.get_value('sample', 'flag'), 'true')
+
+ value = self.device_flags.get_value('sample', 'flag')
+
+ self.assertEqual(value, 'true')
def test_get_value_device_config_null_use_aconfig(self) -> None:
sample_flag = aconfig_pb2.parsed_flag()
sample_flag.state = aconfig_pb2.flag_state.ENABLED
sample_flag.permission = aconfig_pb2.flag_permission.READ_WRITE
self.device_flags._aconfig_flags['sample/flag'] = sample_flag
-
self.ad.adb.shell.return_value = b'null'
- self.assertEqual(self.device_flags.get_value('sample', 'flag'), 'true')
- def test_get_bool_with_valid_bool_value(self) -> None:
+ value = self.device_flags.get_value('sample', 'flag')
+
+ self.assertEqual(value, 'true')
+
+ def test_get_bool_with_valid_bool_value_true(self) -> None:
self.ad.adb.shell.return_value = b'true'
- self.assertTrue(self.device_flags.get_bool('sample', 'flag'))
+ value = self.device_flags.get_bool('sample', 'flag')
+
+ self.assertTrue(value)
+
+ def test_get_bool_with_valid_bool_value_false(self) -> None:
self.ad.adb.shell.return_value = b'false'
- self.assertFalse(self.device_flags.get_bool('sample', 'flag'))
+
+ value = self.device_flags.get_bool('sample', 'flag')
+
+ self.assertFalse(value)
def test_get_bool_with_invalid_bool_value(self) -> None:
self.ad.adb.shell.return_value = b'foo'
+
with self.assertRaisesRegex(ValueError, 'not a boolean'):
self.device_flags.get_bool('sample', 'flag')
+ def test_set_value_runs_correct_command(self) -> None:
+ self.device_flags.set_value('sample', 'flag', 'value')
+
+ self.ad.adb.shell.assert_called_with('device_config put sample flag value')
+
+ def test_enable_runs_correct_command(self) -> None:
+ self.ad.adb.shell.return_value = b'true'
+
+ self.device_flags.enable('sample', 'flag')
+
+ self.ad.adb.shell.assert_called_with('device_config put sample flag true')
+
+ def test_disable_runs_correct_command(self) -> None:
+ self.ad.adb.shell.return_value = b'true'
+
+ self.device_flags.disable('sample', 'flag')
+
+ self.ad.adb.shell.assert_called_with('device_config put sample flag false')
+
+ def test_disable_fails_with_non_boolean_original_value(self) -> None:
+ self.ad.adb.shell.return_value = b'foo'
+
+ with self.assertRaisesRegex(ValueError, 'not a boolean'):
+ self.device_flags.disable('sample', 'flag')
+
if __name__ == '__main__':
unittest.main()