Support changing flag values for device_config flags

+ some refactoring and format changes

Bug: 308215250
Test: local
Change-Id: I5a0449c746c443bbcf9ab9de5a85470543abd4fd
diff --git a/tools/device_flags.py b/tools/device_flags.py
index b6fbe80..22ba87d 100644
--- a/tools/device_flags.py
+++ b/tools/device_flags.py
@@ -18,7 +18,7 @@
 
 import os
 import tempfile
-from typing import Any, Dict, Optional
+from typing import Any
 
 from mobly.controllers import android_device
 from protos import aconfig_pb2
@@ -27,10 +27,14 @@
 _ACONFIG_PB_FILE = 'aconfig_flags.pb'
 
 _DEVICE_CONFIG_GET_CMD = 'device_config get'
+_DEVICE_CONFIG_PUT_CMD = 'device_config put'
 
 _READ_ONLY = aconfig_pb2.flag_permission.READ_ONLY
 _ENABLED = aconfig_pb2.flag_state.ENABLED
 
+_VAL_TRUE = 'true'
+_VAL_FALSE = 'false'
+
 
 class DeviceFlags:
     """Provides access to aconfig and device_config flag values of a device."""
@@ -39,10 +43,10 @@
         self._ad = ad
         self._aconfig_flags = None
 
-    def get_value(self, namespace: str, name: str) -> Optional[str]:
+    def get_value(self, namespace: str, key: str) -> str | None:
         """Gets the value of the requested flag.
 
-        Flags must be specified by both its namespace and name.
+        Flags must be specified by both its namespace and key.
 
         The method will first look for the flag from the device's
         aconfig_flags.pb files, and, if not found or the flag is READ_WRITE,
@@ -52,38 +56,39 @@
 
         Args:
             namespace: The namespace of the flag.
-            name: The full name of the flag.
-                For aconfig flags, it is equivalent to '{package}.{name}' from
-                    the aconfig proto.
-                For device_config flags, it is equivalent to '{KEY}' from the
-                    "device_config shell" command.
+            key: The full name of the flag. For aconfig flags, it is equivalent
+              to '{package}.{name}' from the aconfig proto. For device_config
+              flags, it is equivalent to '{KEY}' from the `device_config get`
+              command.
 
         Returns:
             The flag value as a string.
         """
         # Check aconfig
         aconfig_val = None
-        aconfig_flag = self._get_aconfig_flags().get(
-            '%s/%s' % (namespace, name))
+        aconfig_flag = self._get_aconfig_flags().get(f'{namespace}/{key}')
         if aconfig_flag is not None:
-            aconfig_val = 'true' if aconfig_flag.state == _ENABLED else 'false'
+            aconfig_val = (
+                _VAL_TRUE if aconfig_flag.state == _ENABLED else _VAL_FALSE)
             if aconfig_flag.permission == _READ_ONLY:
                 return aconfig_val
 
         # If missing or READ_WRITE, also check device_config
-        device_config_val = self._ad.adb.shell(
-            '%s %s %s' % (_DEVICE_CONFIG_GET_CMD, namespace, name)
-        ).decode('utf8').strip()
+        device_config_val = (
+            self._ad.adb.shell(f'{_DEVICE_CONFIG_GET_CMD} {namespace} {key}')
+            .decode('utf8')
+            .strip()
+        )
         return device_config_val if device_config_val != 'null' else aconfig_val
 
-    def get_bool(self, namespace: str, name: str) -> bool:
+    def get_bool(self, namespace: str, key: str) -> bool:
         """Gets the value of the requested flag as a boolean.
 
         See get_value() for details.
 
         Args:
             namespace: The namespace of the flag.
-            name: The key of the flag.
+            key: The key of the flag.
 
         Returns:
             The flag value as a boolean.
@@ -91,33 +96,84 @@
         Raises:
             ValueError if the flag value cannot be expressed as a boolean.
         """
-        val = self.get_value(namespace, name)
-        if val.lower() == 'true':
-            return True
-        if val.lower() == 'false':
-            return False
-        raise ValueError('Flag %s/%s is not a boolean (value: %s).'
-                         % (namespace, name, val))
+        val = self.get_value(namespace, key)
+        if val is not None:
+            if val.lower() == _VAL_TRUE:
+                return True
+            if val.lower() == _VAL_FALSE:
+                return False
+        raise ValueError(
+            f'Flag {namespace}/{key} is not a boolean (value: {val}).')
 
-    def _get_aconfig_flags(self) -> Dict[str, Any]:
+    def _get_aconfig_flags(self) -> dict[str, Any]:
         """Gets the aconfig flags as a dict. Loads from proto if necessary."""
         if self._aconfig_flags is None:
-            self._load_aconfig_flags()
+            self._aconfig_flags = self._load_aconfig_flags()
         return self._aconfig_flags
 
-    def _load_aconfig_flags(self) -> None:
+    def _load_aconfig_flags(self) -> dict[str, Any]:
         """Pull aconfig proto files from device, then load the flag info."""
-        self._aconfig_flags = {}
+        aconfig_flags = {}
         with tempfile.TemporaryDirectory() as tmp_dir:
             for partition in _ACONFIG_PARTITIONS:
                 device_path = os.path.join(
                     '/', partition, 'etc', _ACONFIG_PB_FILE)
                 host_path = os.path.join(
-                    tmp_dir, '%s_%s' % (partition, _ACONFIG_PB_FILE))
+                    tmp_dir, f'{partition}_{_ACONFIG_PB_FILE}')
                 self._ad.adb.pull([device_path, host_path])
                 with open(host_path, 'rb') as f:
                     parsed_flags = aconfig_pb2.parsed_flags.FromString(f.read())
                 for flag in parsed_flags.parsed_flag:
-                    full_name = '%s/%s.%s' % (
-                        flag.namespace, flag.package, flag.name)
-                    self._aconfig_flags[full_name] = flag
+                    full_name = f'{flag.namespace}/{flag.package}.{flag.name}'
+                    aconfig_flags[full_name] = flag
+        return aconfig_flags
+
+    def set_value(self, namespace: str, key: str, val: str) -> None:
+        """Sets the value of the requested flag.
+
+        This only supports flags that are set via `adb device_config`.
+
+        Args:
+            namespace: The namespace of the flag.
+            key: The key of the flag.
+            val: The desired value of the flag, in string format.
+        """
+        self._ad.adb.shell(f'{_DEVICE_CONFIG_PUT_CMD} {namespace} {key} {val}')
+
+    def enable(self, namespace: str, key: str) -> None:
+        """Enables the requested flag.
+
+        This only supports flags that are set via `adb device_config`.
+
+        Args:
+            namespace: The namespace of the flag.
+            key: The key of the flag.
+
+        Raises:
+            ValueError if the original flag value cannot be expressed as a
+              boolean.
+        """
+        # If the original value of the flag is not boolean, this will raise a
+        # ValueError.
+        _ = self.get_bool(namespace, key)
+
+        self.set_value(namespace, key, _VAL_TRUE)
+
+    def disable(self, namespace: str, key: str) -> None:
+        """Disables the requested flag.
+
+        This only supports flags that are set via `adb device_config`.
+
+        Args:
+            namespace: The namespace of the flag.
+            key: The key of the flag.
+
+        Raises:
+            ValueError if the original flag value cannot be expressed as a
+              boolean.
+        """
+        # If the original value of the flag is not boolean, this will raise a
+        # ValueError.
+        _ = self.get_bool(namespace, key)
+
+        self.set_value(namespace, key, _VAL_FALSE)
diff --git a/tools/device_flags_test.py b/tools/device_flags_test.py
index 40405a3..3efdf88 100644
--- a/tools/device_flags_test.py
+++ b/tools/device_flags_test.py
@@ -25,53 +25,96 @@
     """Unit tests for DeviceFlags."""
 
     def setUp(self) -> None:
+        super().setUp()
         self.ad = mock.MagicMock()
         self.device_flags = device_flags.DeviceFlags(self.ad)
         self.device_flags._aconfig_flags = {}
 
     def test_get_value_aconfig_flag_missing_use_device_config(self) -> None:
         self.ad.adb.shell.return_value = b'foo'
-        self.assertEqual(self.device_flags.get_value('sample', 'flag'), 'foo')
+
+        value = self.device_flags.get_value('sample', 'flag')
+
+        self.assertEqual(value, 'foo')
 
     def test_get_value_aconfig_flag_read_write_use_device_config(self) -> None:
         sample_flag = aconfig_pb2.parsed_flag()
         sample_flag.state = aconfig_pb2.flag_state.ENABLED
         sample_flag.permission = aconfig_pb2.flag_permission.READ_WRITE
         self.device_flags._aconfig_flags['sample/flag'] = sample_flag
-
         self.ad.adb.shell.return_value = b'false'
-        self.assertEqual(self.device_flags.get_value('sample', 'flag'), 'false')
+
+        value = self.device_flags.get_value('sample', 'flag')
+
+        self.assertEqual(value, 'false')
 
     def test_get_value_aconfig_flag_read_only_use_aconfig(self) -> None:
         sample_flag = aconfig_pb2.parsed_flag()
         sample_flag.state = aconfig_pb2.flag_state.ENABLED
         sample_flag.permission = aconfig_pb2.flag_permission.READ_ONLY
         self.device_flags._aconfig_flags['sample/flag'] = sample_flag
-
         self.ad.adb.shell.return_value = b'false'
-        self.assertEqual(self.device_flags.get_value('sample', 'flag'), 'true')
+
+        value = self.device_flags.get_value('sample', 'flag')
+
+        self.assertEqual(value, 'true')
 
     def test_get_value_device_config_null_use_aconfig(self) -> None:
         sample_flag = aconfig_pb2.parsed_flag()
         sample_flag.state = aconfig_pb2.flag_state.ENABLED
         sample_flag.permission = aconfig_pb2.flag_permission.READ_WRITE
         self.device_flags._aconfig_flags['sample/flag'] = sample_flag
-
         self.ad.adb.shell.return_value = b'null'
-        self.assertEqual(self.device_flags.get_value('sample', 'flag'), 'true')
 
-    def test_get_bool_with_valid_bool_value(self) -> None:
+        value = self.device_flags.get_value('sample', 'flag')
+
+        self.assertEqual(value, 'true')
+
+    def test_get_bool_with_valid_bool_value_true(self) -> None:
         self.ad.adb.shell.return_value = b'true'
-        self.assertTrue(self.device_flags.get_bool('sample', 'flag'))
 
+        value = self.device_flags.get_bool('sample', 'flag')
+
+        self.assertTrue(value)
+
+    def test_get_bool_with_valid_bool_value_false(self) -> None:
         self.ad.adb.shell.return_value = b'false'
-        self.assertFalse(self.device_flags.get_bool('sample', 'flag'))
+
+        value = self.device_flags.get_bool('sample', 'flag')
+
+        self.assertFalse(value)
 
     def test_get_bool_with_invalid_bool_value(self) -> None:
         self.ad.adb.shell.return_value = b'foo'
+
         with self.assertRaisesRegex(ValueError, 'not a boolean'):
             self.device_flags.get_bool('sample', 'flag')
 
+    def test_set_value_runs_correct_command(self) -> None:
+        self.device_flags.set_value('sample', 'flag', 'value')
+
+        self.ad.adb.shell.assert_called_with('device_config put sample flag value')
+
+    def test_enable_runs_correct_command(self) -> None:
+        self.ad.adb.shell.return_value = b'true'
+
+        self.device_flags.enable('sample', 'flag')
+
+        self.ad.adb.shell.assert_called_with('device_config put sample flag true')
+
+    def test_disable_runs_correct_command(self) -> None:
+        self.ad.adb.shell.return_value = b'true'
+
+        self.device_flags.disable('sample', 'flag')
+
+        self.ad.adb.shell.assert_called_with('device_config put sample flag false')
+
+    def test_disable_fails_with_non_boolean_original_value(self) -> None:
+        self.ad.adb.shell.return_value = b'foo'
+
+        with self.assertRaisesRegex(ValueError, 'not a boolean'):
+            self.device_flags.disable('sample', 'flag')
+
 
 if __name__ == '__main__':
     unittest.main()