Snap for 10453563 from 26c3fa06288d0c97a3117f839a3360f3fce845f6 to mainline-tzdata5-release

Change-Id: Ieb9c08f513ddd732456c8d48915b4a174145500d
diff --git a/src/com/google/android/iwlan/IwlanDataService.java b/src/com/google/android/iwlan/IwlanDataService.java
index 2f3ed2b..04b6789 100644
--- a/src/com/google/android/iwlan/IwlanDataService.java
+++ b/src/com/google/android/iwlan/IwlanDataService.java
@@ -81,6 +81,7 @@
 import java.util.List;
 import java.util.LongSummaryStatistics;
 import java.util.Map;
+import java.util.Objects;
 import java.util.concurrent.ConcurrentHashMap;
 
 public class IwlanDataService extends DataService {
@@ -124,13 +125,6 @@
 
     private static Transport sDefaultDataTransport = Transport.UNSPECIFIED_NETWORK;
 
-    enum LinkProtocolType {
-        UNKNOWN,
-        IPV4,
-        IPV6,
-        IPV4V6
-    }
-
     // TODO: see if network monitor callback impl can be shared between dataservice and
     // networkservice
     // This callback runs in the same thread as IwlanDataServiceHandler
@@ -171,12 +165,12 @@
                 @NonNull Network network, @NonNull LinkProperties linkProperties) {
             Log.d(TAG, "onLinkPropertiesChanged: " + linkProperties);
 
-            if (!sNetwork.equals(network)) {
+            if (!network.equals(sNetwork)) {
                 Log.d(TAG, "Ignore LinkProperties changes for unused Network.");
                 return;
             }
 
-            if (!sLinkProperties.equals(linkProperties)) {
+            if (!linkProperties.equals(sLinkProperties)) {
                 for (IwlanDataServiceProvider dp : sIwlanDataServiceProviders.values()) {
                     dp.dnsPrefetchCheck();
                     sLinkProperties = linkProperties;
@@ -601,6 +595,7 @@
             events.add(IwlanEventListener.CARRIER_CONFIG_UNKNOWN_CARRIER_EVENT);
             events.add(IwlanEventListener.WIFI_CALLING_ENABLE_EVENT);
             events.add(IwlanEventListener.WIFI_CALLING_DISABLE_EVENT);
+            events.add(IwlanEventListener.CROSS_SIM_CALLING_ENABLE_EVENT);
             events.add(IwlanEventListener.CELLINFO_CHANGED_EVENT);
             events.add(IwlanEventListener.CALL_STATE_CHANGED_EVENT);
             IwlanEventListener.getInstance(mContext, slotIndex)
@@ -962,14 +957,14 @@
         }
 
         private void updateNetwork(
-                @NonNull Network network, @Nullable LinkProperties linkProperties) {
+                @Nullable Network network, @Nullable LinkProperties linkProperties) {
             if (mIwlanDataService.isNetworkConnected(
                     isActiveDataOnOtherSub(getSlotIndex()),
                     IwlanHelper.isCrossSimCallingEnabled(mContext, getSlotIndex()))) {
                 getTunnelManager().updateNetwork(network, linkProperties);
             }
 
-            if (network.equals(sNetwork)) {
+            if (Objects.equals(network, sNetwork)) {
                 return;
             }
             for (Map.Entry<String, TunnelState> entry : mTunnelStateForApn.entrySet()) {
@@ -1364,6 +1359,12 @@
                     iwlanDataServiceProvider.mWfcEnabled = false;
                     break;
 
+                case IwlanEventListener.CROSS_SIM_CALLING_ENABLE_EVENT:
+                    iwlanDataServiceProvider =
+                            (IwlanDataServiceProvider) getDataServiceProvider(msg.arg1);
+                    iwlanDataServiceProvider.updateNetwork(sNetwork, sLinkProperties);
+                    break;
+
                 case IwlanEventListener.CELLINFO_CHANGED_EVENT:
                     List<CellInfo> cellInfolist = (List<CellInfo>) msg.obj;
                     iwlanDataServiceProvider =
@@ -2045,6 +2046,8 @@
                 return "WIFI_CALLING_ENABLE_EVENT";
             case IwlanEventListener.WIFI_CALLING_DISABLE_EVENT:
                 return "WIFI_CALLING_DISABLE_EVENT";
+            case IwlanEventListener.CROSS_SIM_CALLING_ENABLE_EVENT:
+                return "CROSS_SIM_CALLING_ENABLE_EVENT";
             case IwlanEventListener.CELLINFO_CHANGED_EVENT:
                 return "CELLINFO_CHANGED_EVENT";
             case EVENT_TUNNEL_OPENED_METRICS:
diff --git a/src/com/google/android/iwlan/IwlanHelper.java b/src/com/google/android/iwlan/IwlanHelper.java
index efda5b7..ade7189 100644
--- a/src/com/google/android/iwlan/IwlanHelper.java
+++ b/src/com/google/android/iwlan/IwlanHelper.java
@@ -110,30 +110,8 @@
         return info;
     }
 
-    public static List<InetAddress> getAddressesForNetwork(Network network, Context context) {
-        ConnectivityManager connectivityManager =
-                context.getSystemService(ConnectivityManager.class);
-        List<InetAddress> gatewayList = new ArrayList<>();
-        if (network != null) {
-            LinkProperties linkProperties = connectivityManager.getLinkProperties(network);
-            if (linkProperties != null) {
-                for (LinkAddress linkAddr : linkProperties.getLinkAddresses()) {
-                    InetAddress inetAddr = linkAddr.getAddress();
-                    // skip linklocal and loopback addresses
-                    if (!inetAddr.isLoopbackAddress() && !inetAddr.isLinkLocalAddress()) {
-                        gatewayList.add(inetAddr);
-                    }
-                }
-                if (linkProperties.getNat64Prefix() != null) {
-                    mNat64Prefix = linkProperties.getNat64Prefix();
-                }
-            }
-        }
-        return gatewayList;
-    }
-
-    public static List<InetAddress> getStackedAddressesForNetwork(
-            Network network, Context context) {
+    // Non-loopback, non-link-local IPs of network (incl. stacked links); records NAT64 prefix.
+    public static List<InetAddress> getAllAddressesForNetwork(Network network, Context context) {
         ConnectivityManager connectivityManager =
                 context.getSystemService(ConnectivityManager.class);
         List<InetAddress> gatewayList = new ArrayList<>();
@@ -142,10 +120,14 @@
             if (linkProperties != null) {
                 for (LinkAddress linkAddr : linkProperties.getAllLinkAddresses()) {
                     InetAddress inetAddr = linkAddr.getAddress();
-                    if ((inetAddr instanceof Inet4Address)) {
+                    // skip linklocal and loopback addresses
+                    if (!inetAddr.isLoopbackAddress() && !inetAddr.isLinkLocalAddress()) {
                         gatewayList.add(inetAddr);
                     }
                 }
+                if (linkProperties.getNat64Prefix() != null) {
+                    mNat64Prefix = linkProperties.getNat64Prefix();
+                }
             }
         }
         return gatewayList;
@@ -162,22 +144,28 @@
     }

+    /** Returns true if any address in the list is an {@link Inet6Address}; null yields false. */
     public static boolean hasIpv6Address(List<InetAddress> localAddresses) {
+        if (localAddresses == null) {
+            return false;
+        }
         for (InetAddress address : localAddresses) {
             if (address instanceof Inet6Address) {
                 return true;
             }
         }
-
         return false;
     }

+    /** Returns true if any address in the list is an {@link Inet4Address}; null yields false. */
     public static boolean hasIpv4Address(List<InetAddress> localAddresses) {
+        if (localAddresses == null) {
+            return false;
+        }
         for (InetAddress address : localAddresses) {
             if (address instanceof Inet4Address) {
                 return true;
             }
         }
-
         return false;
     }

diff --git a/src/com/google/android/iwlan/epdg/EpdgSelector.java b/src/com/google/android/iwlan/epdg/EpdgSelector.java
index 52392fd..4534b81 100644
--- a/src/com/google/android/iwlan/epdg/EpdgSelector.java
+++ b/src/com/google/android/iwlan/epdg/EpdgSelector.java
@@ -84,10 +84,12 @@
     // IWLAN applies an internal timeout of 6 seconds, slightly longer than the default timeout
     private static final long DNS_RESOLVER_TIMEOUT_DURATION_SEC = 6L;
 
-    private static final long PARALLEL_DNS_RESOLVER_TIMEOUT_DURATION_SEC = 20L;
+    private static final long PARALLEL_STATIC_RESOLUTION_TIMEOUT_DURATION_SEC = 6L;
+    private static final long PARALLEL_PLMN_RESOLUTION_TIMEOUT_DURATION_SEC = 20L;
     private static final int NUM_EPDG_SELECTION_EXECUTORS = 2; // 1 each for normal selection, SOS.
     private static final int MAX_EPDG_SELECTION_THREADS = 2; // 1 each for prefetch, tunnel bringup.
     private static final int MAX_DNS_RESOLVER_THREADS = 25; // Do not expect > 25 FQDNs per carrier.
+    private static final String NO_DOMAIN = "NO_DOMAIN"; // getIP() key for numeric IP entries.
 
     BlockingQueue<Runnable> dnsResolutionQueue =
             new ArrayBlockingQueue<>(
@@ -206,7 +208,7 @@
     }
 
     private CompletableFuture<Map.Entry<String, List<InetAddress>>> submitDnsResolverQuery(
-            String domainName, Network network, Executor executor) {
+            String domainName, Network network, int queryType, Executor executor) {
         CompletableFuture<Map.Entry<String, List<InetAddress>>> result = new CompletableFuture();
 
         final DnsResolver.Callback<List<InetAddress>> cb =
@@ -234,7 +236,7 @@
                     }
                 };
         DnsResolver.getInstance()
-                .query(network, domainName, DnsResolver.FLAG_EMPTY, executor, null, cb);
+                .query(network, domainName, queryType, DnsResolver.FLAG_EMPTY, executor, null, cb);
         return result;
     }
 
@@ -280,6 +282,26 @@
                                 .collect(Collectors.<T>toList()));
     }

+    /** Returns true if {@code network} has a non-loopback, non-link-local IPv4 address. */
+    @VisibleForTesting
+    protected boolean hasIpv4Address(Network network) {
+        return IwlanHelper.hasIpv4Address(IwlanHelper.getAllAddressesForNetwork(network, mContext));
+    }
+
+    /** Returns true if {@code network} has a non-loopback, non-link-local IPv6 address. */
+    @VisibleForTesting
+    protected boolean hasIpv6Address(Network network) {
+        return IwlanHelper.hasIpv6Address(IwlanHelper.getAllAddressesForNetwork(network, mContext));
+    }
+
+    /** Logs, at debug level, each domain name with the IP addresses it resolved to. */
+    private void printParallelDnsResult(Map<String, List<InetAddress>> domainNameToIpAddresses) {
+        Log.d(TAG, "Parallel DNS resolution result:");
+        for (Map.Entry<String, List<InetAddress>> entry : domainNameToIpAddresses.entrySet()) {
+            Log.d(TAG, entry.getKey() + ": " + entry.getValue());
+        }
+    }
+
     /**
      * Returns a list of unique IP addresses corresponding to the given domain names, in the same
      * order of the input. Runs DNS resolution across parallel threads.
@@ -287,27 +305,47 @@
      * @param domainNames Domain names for which DNS resolution needs to be performed.
      * @param filter Selects for IPv4, IPv6 (or both) addresses from the resulting DNS records
      * @param network {@link Network} Network on which to run the DNS query.
+     * @param timeout Maximum time, in seconds, to wait for all DNS queries to complete.
      * @return List of unique IP addresses corresponding to the domainNames.
      */
     private LinkedHashMap<String, List<InetAddress>> getIP(
-            List<String> domainNames, int filter, Network network) {
+            List<String> domainNames, int filter, Network network, long timeout) {
         // LinkedHashMap preserves insertion order (and hence priority) of domain names passed in.
         LinkedHashMap<String, List<InetAddress>> domainNameToIpAddr = new LinkedHashMap<>();

         List<CompletableFuture<Map.Entry<String, List<InetAddress>>>> futuresList =
                 new ArrayList<>();
         for (String domainName : domainNames) {
+            if (InetAddresses.isNumericAddress(domainName)) {
+                Log.d(TAG, domainName + " is a numeric IP address!");
+                InetAddress inetAddr = InetAddresses.parseNumericAddress(domainName);
+                // All numeric entries share the NO_DOMAIN key, so append instead of overwriting.
+                domainNameToIpAddr.computeIfAbsent(NO_DOMAIN, k -> new ArrayList<>()).add(inetAddr);
+                continue;
+            }
+
             domainNameToIpAddr.put(domainName, new ArrayList<>());
-            futuresList.add(submitDnsResolverQuery(domainName, network, mDnsResolutionExecutor));
+            // Dispatches separate IPv4 and IPv6 queries to avoid being blocked on either result.
+            if (hasIpv4Address(network)) {
+                futuresList.add(
+                        submitDnsResolverQuery(
+                                domainName, network, DnsResolver.TYPE_A, mDnsResolutionExecutor));
+            }
+            if (hasIpv6Address(network)) {
+                futuresList.add(
+                        submitDnsResolverQuery(
+                                domainName,
+                                network,
+                                DnsResolver.TYPE_AAAA,
+                                mDnsResolutionExecutor));
+            }
         }
         CompletableFuture<List<Map.Entry<String, List<InetAddress>>>> allFuturesResult =
                 allOf(futuresList);

         List<Map.Entry<String, List<InetAddress>>> resultList = null;
         try {
-            resultList =
-                    allFuturesResult.get(
-                            PARALLEL_DNS_RESOLVER_TIMEOUT_DURATION_SEC, TimeUnit.SECONDS);
+            resultList = allFuturesResult.get(timeout, TimeUnit.SECONDS);
         } catch (ExecutionException e) {
             Log.e(TAG, "Cause of ExecutionException: ", e.getCause());
         } catch (InterruptedException e) {
@@ -320,8 +357,17 @@
                 Log.w(TAG, "No IP addresses in parallel DNS query!");
             } else {
                 for (Map.Entry<String, List<InetAddress>> entry : resultList) {
-                    domainNameToIpAddr.put(
-                            entry.getKey(), v4v6ProtocolFilter(entry.getValue(), filter));
+                    String resultDomainName = entry.getKey();
+                    List<InetAddress> resultIpAddr = v4v6ProtocolFilter(entry.getValue(), filter);
+
+                    if (!domainNameToIpAddr.containsKey(resultDomainName)) {
+                        Log.w(
+                                TAG,
+                                "Unexpected domain name in DnsResolver result: "
+                                        + resultDomainName);
+                        continue;
+                    }
+                    domainNameToIpAddr.get(resultDomainName).addAll(resultIpAddr);
                 }
             }
         }
@@ -345,12 +391,8 @@
         Log.d(TAG, "Input domainName : " + domainName);
 
         if (InetAddresses.isNumericAddress(domainName)) {
-            try {
-                Log.d(TAG, domainName + " is a numeric ip address");
-                ipList.add(InetAddress.getByName(domainName));
-            } catch (UnknownHostException e) {
-                Log.e(TAG, "Exception when resolving domainName : " + domainName + ".", e);
-            }
+            Log.d(TAG, domainName + " is a numeric IP address!");
+            ipList.add(InetAddresses.parseNumericAddress(domainName));
         } else {
             try {
                 CompletableFuture<List<InetAddress>> result = new CompletableFuture();
@@ -570,9 +612,14 @@
         }
 
         Log.d(TAG, "Static Domain Names: " + Arrays.toString(domainNames));
-        for (String domainName : domainNames) {
-            getIP(domainName, filter, validIpList, network);
-        }
+        LinkedHashMap<String, List<InetAddress>> domainNameToIpAddr =
+                getIP(
+                        Arrays.asList(domainNames),
+                        filter,
+                        network,
+                        PARALLEL_STATIC_RESOLUTION_TIMEOUT_DURATION_SEC);
+        printParallelDnsResult(domainNameToIpAddr);
+        domainNameToIpAddr.values().forEach(validIpList::addAll);
     }
 
     private String[] getDomainNames(String key) {
@@ -647,7 +694,8 @@
         }
 
         LinkedHashMap<String, List<InetAddress>> domainNameToIpAddr =
-                getIP(domainNames, filter, network);
+                getIP(domainNames, filter, network, PARALLEL_PLMN_RESOLUTION_TIMEOUT_DURATION_SEC);
+        printParallelDnsResult(domainNameToIpAddr);
         domainNameToIpAddr.values().forEach(validIpList::addAll);
         return domainNameToIpAddr;
     }
diff --git a/src/com/google/android/iwlan/epdg/EpdgTunnelManager.java b/src/com/google/android/iwlan/epdg/EpdgTunnelManager.java
index bfe4f95..b8b70e4 100644
--- a/src/com/google/android/iwlan/epdg/EpdgTunnelManager.java
+++ b/src/com/google/android/iwlan/epdg/EpdgTunnelManager.java
@@ -2283,7 +2283,7 @@
 
     @VisibleForTesting
     List<InetAddress> getAddressForNetwork(Network network, Context context) {
-        return IwlanHelper.getAddressesForNetwork(network, context);
+        return IwlanHelper.getAllAddressesForNetwork(network, context); // Incl. stacked links.
     }
 
     @VisibleForTesting
diff --git a/test/com/google/android/iwlan/IwlanDataServiceTest.java b/test/com/google/android/iwlan/IwlanDataServiceTest.java
index 98fd3eb..2f8f2c3 100644
--- a/test/com/google/android/iwlan/IwlanDataServiceTest.java
+++ b/test/com/google/android/iwlan/IwlanDataServiceTest.java
@@ -17,6 +17,7 @@
 package com.google.android.iwlan;
 
 import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
+import static android.net.NetworkCapabilities.TRANSPORT_ETHERNET;
 import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
 
 import static com.android.dx.mockito.inline.extended.ExtendedMockito.mockitoSession;
@@ -105,6 +106,7 @@
 public class IwlanDataServiceTest {
     private static final int DEFAULT_SLOT_INDEX = 0;
     private static final int DEFAULT_SUB_INDEX = 0;
+    private static final int INVALID_SUB_INDEX = -1; // For Wi-Fi, where subId is unused.
     private static final int LINK_MTU = 1280;
     private static final String TEST_APN_NAME = "ims";
     private static final String IP_ADDRESS = "192.0.2.1";
@@ -291,11 +293,11 @@
     }
 
     private void onSystemDefaultNetworkConnected(
-            Network network, LinkProperties linkProperties, int transportType) {
+            Network network, LinkProperties linkProperties, int transportType, int subId) {
         NetworkCapabilities nc =
                 prepareNetworkCapabilitiesForTest(
                         transportType,
-                        DEFAULT_SUB_INDEX /* unused if transportType is TRANSPORT_WIFI */,
+                        subId /* unused if transportType is TRANSPORT_WIFI */,
                         false /* isVcn */);
         NetworkCallback networkMonitorCallback = getNetworkMonitorCallback();
         networkMonitorCallback.onCapabilitiesChanged(network, nc);
@@ -305,7 +307,8 @@
 
     private void onSystemDefaultNetworkConnected(int transportType) {
         Network newNetwork = createMockNetwork(mLinkProperties);
-        onSystemDefaultNetworkConnected(newNetwork, mLinkProperties, transportType);
+        onSystemDefaultNetworkConnected( // subId is unused for TRANSPORT_WIFI.
+                newNetwork, mLinkProperties, transportType, DEFAULT_SUB_INDEX);
     }
 
     private void onSystemDefaultNetworkLost() {
@@ -315,7 +318,7 @@
     }
 
     @Test
-    public void testWifionConnected() {
+    public void testWifiOnConnected() {
         onSystemDefaultNetworkConnected(TRANSPORT_WIFI);
         assertTrue(
                 mIwlanDataService.isNetworkConnected(
@@ -337,22 +340,23 @@
     }
 
     @Test
-    public void testWifiOnReConnected() {
+    public void testWifiOnReconnected() {
         Network newNetwork = createMockNetwork(mLinkProperties);
-        onSystemDefaultNetworkConnected(newNetwork, mLinkProperties, TRANSPORT_WIFI);
+        onSystemDefaultNetworkConnected(
+                newNetwork, mLinkProperties, TRANSPORT_WIFI, INVALID_SUB_INDEX);
         verify(mMockEpdgTunnelManager, times(1)).updateNetwork(eq(newNetwork), eq(mLinkProperties));

         onSystemDefaultNetworkLost();
-
-        newNetwork = createMockNetwork(mLinkProperties);
-        onSystemDefaultNetworkConnected(newNetwork, mLinkProperties, TRANSPORT_WIFI);
-        verify(mMockEpdgTunnelManager, times(1)).updateNetwork(eq(newNetwork), eq(mLinkProperties));
+        onSystemDefaultNetworkConnected( // Same Network reconnects: a second update is expected.
+                newNetwork, mLinkProperties, TRANSPORT_WIFI, INVALID_SUB_INDEX);
+        verify(mMockEpdgTunnelManager, times(2)).updateNetwork(eq(newNetwork), eq(mLinkProperties));
     }
 
     @Test
     public void testOnLinkPropertiesChangedForConnectedNetwork() {
         NetworkCallback networkCallback = getNetworkMonitorCallback();
-        onSystemDefaultNetworkConnected(mMockNetwork, mLinkProperties, TRANSPORT_WIFI);
+        onSystemDefaultNetworkConnected(
+                mMockNetwork, mLinkProperties, TRANSPORT_WIFI, INVALID_SUB_INDEX);
 
         clearInvocations(mMockEpdgTunnelManager);
 
@@ -368,7 +372,8 @@
     @Test
     public void testOnLinkPropertiesChangedForNonConnectedNetwork() {
         NetworkCallback networkCallback = getNetworkMonitorCallback();
-        onSystemDefaultNetworkConnected(mMockNetwork, mLinkProperties, TRANSPORT_WIFI);
+        onSystemDefaultNetworkConnected(
+                mMockNetwork, mLinkProperties, TRANSPORT_WIFI, INVALID_SUB_INDEX);
 
         clearInvocations(mMockEpdgTunnelManager);
 
@@ -387,7 +392,8 @@
         NetworkCallback networkCallback = getNetworkMonitorCallback();
         mLinkProperties.setLinkAddresses(
                 new ArrayList<>(Collections.singletonList(mMockIPv6LinkAddress)));
-        onSystemDefaultNetworkConnected(mMockNetwork, mLinkProperties, TRANSPORT_WIFI);
+        onSystemDefaultNetworkConnected(
+                mMockNetwork, mLinkProperties, TRANSPORT_WIFI, INVALID_SUB_INDEX);
 
         clearInvocations(mMockEpdgTunnelManager);
 
@@ -412,7 +418,8 @@
         DataProfile dp = buildImsDataProfile();
 
         NetworkCallback networkCallback = getNetworkMonitorCallback();
-        onSystemDefaultNetworkConnected(mMockNetwork, mLinkProperties, TRANSPORT_WIFI);
+        onSystemDefaultNetworkConnected(
+                mMockNetwork, mLinkProperties, TRANSPORT_WIFI, INVALID_SUB_INDEX);
 
         clearInvocations(mMockEpdgTunnelManager);
 
@@ -436,7 +443,7 @@
     }
 
     @Test
-    public void testNetworkNotConnectedWithCellularAndCrossSimDisabled()
+    public void testNetworkNotConnectedWithCellularOnSameSubAndCrossSimEnabled()
             throws InterruptedException {
         NetworkCapabilities nc =
                 prepareNetworkCapabilitiesForTest(
@@ -453,7 +460,8 @@
     }
 
     @Test
-    public void testCrossSimNetworkConnectedWithTelephonyNetwork() throws InterruptedException {
+    public void testCrossSimNetworkConnectedWithCellularOnDifferentSub()
+            throws InterruptedException {
         NetworkCapabilities nc =
                 prepareNetworkCapabilitiesForTest(
                         TRANSPORT_CELLULAR, DEFAULT_SUB_INDEX + 1, false /* isVcn */);
@@ -469,7 +477,8 @@
     }
 
     @Test
-    public void testCrossSimNetworkConnectedWithVcn() throws InterruptedException {
+    public void testCrossSimNetworkConnectedWithVcnCellularOnDifferentSub()
+            throws InterruptedException {
         NetworkCapabilities nc =
                 prepareNetworkCapabilitiesForTest(
                         TRANSPORT_CELLULAR, DEFAULT_SUB_INDEX + 1, true /* isVcn */);
@@ -485,6 +494,67 @@
     }

     @Test
+    public void testOnCrossSimCallingEnable_doNotUpdateTunnelManagerIfCellularDataOnSameSub()
+            throws Exception {
+        when(mMockImsMmTelManager.isCrossSimCallingEnabled()).thenReturn(true);
+
+        Network newNetwork = createMockNetwork(mLinkProperties);
+        onSystemDefaultNetworkConnected(
+                newNetwork, mLinkProperties, TRANSPORT_CELLULAR, DEFAULT_SUB_INDEX);
+
+        sendCrossSimCallingEnableEvent(DEFAULT_SLOT_INDEX);
+        verify(mMockEpdgTunnelManager, never())
+                .updateNetwork(eq(newNetwork), any(LinkProperties.class));
+    }
+
+    @Test
+    public void testOnCrossSimCallingEnable_updateTunnelManagerIfCellularDataOnDifferentSub()
+            throws Exception {
+        when(mMockImsMmTelManager.isCrossSimCallingEnabled()).thenReturn(true);
+
+        Network newNetwork = createMockNetwork(mLinkProperties);
+        onSystemDefaultNetworkConnected(
+                newNetwork, mLinkProperties, TRANSPORT_CELLULAR, DEFAULT_SUB_INDEX + 1);
+        verify(mMockEpdgTunnelManager, times(1)).updateNetwork(eq(newNetwork), eq(mLinkProperties));
+
+        sendCrossSimCallingEnableEvent(DEFAULT_SLOT_INDEX);
+        verify(mMockEpdgTunnelManager, times(2)).updateNetwork(eq(newNetwork), eq(mLinkProperties));
+    }
+
+    @Test
+    public void testOnCrossSimCallingEnable_doNotUpdateTunnelManagerIfNoNetwork() throws Exception {
+        when(mMockImsMmTelManager.isCrossSimCallingEnabled()).thenReturn(true);
+        onSystemDefaultNetworkLost();
+        // Only updates triggered by the event itself are under test.
+        clearInvocations(mMockEpdgTunnelManager);
+
+        sendCrossSimCallingEnableEvent(DEFAULT_SLOT_INDEX);
+        // Unlike any(Class), any() also matches null, so an update with a null Network fails too.
+        verify(mMockEpdgTunnelManager, never()).updateNetwork(any(), any());
+    }
+
+    @Test
+    public void testOnEthernetConnection_doNotUpdateTunnelManager() throws Exception {
+        Network newNetwork = createMockNetwork(mLinkProperties);
+        onSystemDefaultNetworkConnected(
+                newNetwork, mLinkProperties, TRANSPORT_ETHERNET, DEFAULT_SUB_INDEX);
+        verify(mMockEpdgTunnelManager, never())
+                .updateNetwork(eq(newNetwork), any(LinkProperties.class));
+    }
+
+    /** Sends CROSS_SIM_CALLING_ENABLE_EVENT for {@code slotIndex} and drains the test looper. */
+    private void sendCrossSimCallingEnableEvent(int slotIndex) {
+        mIwlanDataService
+                .mIwlanDataServiceHandler
+                .obtainMessage(
+                        IwlanEventListener.CROSS_SIM_CALLING_ENABLE_EVENT,
+                        slotIndex,
+                        0 /* unused */)
+                .sendToTarget();
+        mTestLooper.dispatchAll();
+    }
+
+    @Test
     public void testAddDuplicateDataServiceProviderThrows() throws Exception {
         when(mMockIwlanDataServiceProvider.getSlotIndex()).thenReturn(DEFAULT_SLOT_INDEX);
         assertThrows(
@@ -640,7 +717,8 @@
         DataProfile dp = buildImsDataProfile();
 
         /* Wifi is connected */
-        onSystemDefaultNetworkConnected(mMockNetwork, mLinkProperties, TRANSPORT_WIFI);
+        onSystemDefaultNetworkConnected(
+                mMockNetwork, mLinkProperties, TRANSPORT_WIFI, INVALID_SUB_INDEX);
 
         mSpyIwlanDataServiceProvider.setupDataCall(
                 AccessNetworkType.IWLAN, /* AccessNetworkType */
@@ -678,7 +756,8 @@
         DataProfile dp = buildImsDataProfile();
 
         /* Wifi is connected */
-        onSystemDefaultNetworkConnected(mMockNetwork, mLinkProperties, TRANSPORT_WIFI);
+        onSystemDefaultNetworkConnected(
+                mMockNetwork, mLinkProperties, TRANSPORT_WIFI, INVALID_SUB_INDEX);
 
         mSpyIwlanDataServiceProvider.setupDataCall(
                 AccessNetworkType.IWLAN, /* AccessNetworkType */
@@ -1131,7 +1210,8 @@
     public void testDnsPrefetching() throws Exception {
         NetworkCallback networkCallback = getNetworkMonitorCallback();
         /* Wifi is connected */
-        onSystemDefaultNetworkConnected(mMockNetwork, mLinkProperties, TRANSPORT_WIFI);
+        onSystemDefaultNetworkConnected(
+                mMockNetwork, mLinkProperties, TRANSPORT_WIFI, INVALID_SUB_INDEX);
         networkCallback.onLinkPropertiesChanged(mMockNetwork, mLinkProperties);
 
         mIwlanDataService
@@ -1340,7 +1420,8 @@
     public void testIwlanTunnelStatsFailureCounts() {
         DataProfile dp = buildImsDataProfile();
 
-        onSystemDefaultNetworkConnected(mMockNetwork, mLinkProperties, TRANSPORT_WIFI);
+        onSystemDefaultNetworkConnected(
+                mMockNetwork, mLinkProperties, TRANSPORT_WIFI, INVALID_SUB_INDEX);
 
         when(ErrorPolicyManager.getInstance(eq(mMockContext), eq(DEFAULT_SLOT_INDEX)))
                 .thenReturn(mMockErrorPolicyManager);
@@ -1366,7 +1447,8 @@
         when(mMockErrorPolicyManager.getDataFailCause(eq(TEST_APN_NAME)))
                 .thenReturn(DataFailCause.ERROR_UNSPECIFIED);
 
-        onSystemDefaultNetworkConnected(mMockNetwork, mLinkProperties, TRANSPORT_WIFI);
+        onSystemDefaultNetworkConnected(
+                mMockNetwork, mLinkProperties, TRANSPORT_WIFI, INVALID_SUB_INDEX);
 
         long count = 3L;
         for (int i = 0; i < count; i++) {
@@ -1387,7 +1469,8 @@
         when(calendar.getTime()).thenAnswer(i -> new Date(mMockedCalendarTime));
 
         mSpyIwlanDataServiceProvider.setCalendar(calendar);
-        onSystemDefaultNetworkConnected(mMockNetwork, mLinkProperties, TRANSPORT_WIFI);
+        onSystemDefaultNetworkConnected(
+                mMockNetwork, mLinkProperties, TRANSPORT_WIFI, INVALID_SUB_INDEX);
 
         LongSummaryStatistics tunnelSetupSuccessStats = new LongSummaryStatistics();
         LongSummaryStatistics tunnelUpStats = new LongSummaryStatistics();
diff --git a/test/com/google/android/iwlan/epdg/EpdgSelectorTest.java b/test/com/google/android/iwlan/epdg/EpdgSelectorTest.java
index b2d514c..3f23dab 100644
--- a/test/com/google/android/iwlan/epdg/EpdgSelectorTest.java
+++ b/test/com/google/android/iwlan/epdg/EpdgSelectorTest.java
@@ -124,7 +124,7 @@
 
         when(ErrorPolicyManager.getInstance(mMockContext, DEFAULT_SLOT_INDEX))
                 .thenReturn(mMockErrorPolicyManager);
-        mEpdgSelector = new EpdgSelector(mMockContext, DEFAULT_SLOT_INDEX);
+        mEpdgSelector = spy(new EpdgSelector(mMockContext, DEFAULT_SLOT_INDEX));
 
         when(mMockContext.getSystemService(eq(SubscriptionManager.class)))
                 .thenReturn(mMockSubscriptionManager);
@@ -163,8 +163,6 @@
                 .thenReturn(mMockCarrierConfigManager);
         when(mMockCarrierConfigManager.getConfigForSubId(anyInt())).thenReturn(mTestBundle);
 
-        lenient().when(DnsResolver.getInstance()).thenReturn(mMockDnsResolver);
-
         mFakeDns = new FakeDns();
         mFakeDns.startMocking();
     }
@@ -177,6 +175,10 @@
 
     @Test
     public void testStaticMethodPass() throws Exception {
+        when(DnsResolver.getInstance()).thenReturn(mMockDnsResolver);
+        doReturn(true).when(mEpdgSelector).hasIpv4Address(mMockNetwork);
+        doReturn(true).when(mEpdgSelector).hasIpv6Address(mMockNetwork);
+
         // Set DnsResolver query mock
         final String testStaticAddress = "epdg.epc.mnc088.mcc888.pub.3gppnetwork.org";
         mFakeDns.setAnswer(testStaticAddress, new String[] {TEST_IP_ADDRESS}, TYPE_A);
@@ -193,12 +195,32 @@
 
         InetAddress expectedAddress = InetAddress.getByName(TEST_IP_ADDRESS);
 
-        assertEquals(testInetAddresses.size(), 1);
-        assertEquals(testInetAddresses.get(0), expectedAddress);
+        assertEquals(1, testInetAddresses.size());
+        assertEquals(expectedAddress, testInetAddresses.get(0));
+    }
+
+    @Test
+    public void testStaticMethodDirectIpAddress_noDnsResolution() throws Exception {
+        mTestBundle.putIntArray(
+                CarrierConfigManager.Iwlan.KEY_EPDG_ADDRESS_PRIORITY_INT_ARRAY,
+                new int[] {CarrierConfigManager.Iwlan.EPDG_ADDRESS_STATIC});
+        // Carrier config holds the ePDG IP literal itself; DnsResolver is not stubbed here.
+        mTestBundle.putString(
+                CarrierConfigManager.Iwlan.KEY_EPDG_STATIC_ADDRESS_STRING, TEST_IP_ADDRESS);
+
+        ArrayList<InetAddress> testInetAddresses =
+                getValidatedServerListWithDefaultParams(false /*isEmergency*/);
+
+        assertEquals(1, testInetAddresses.size());
+        assertEquals(InetAddresses.parseNumericAddress(TEST_IP_ADDRESS), testInetAddresses.get(0));
     }
 
     @Test
     public void testRoamStaticMethodPass() throws Exception {
+        when(DnsResolver.getInstance()).thenReturn(mMockDnsResolver);
+        doReturn(true).when(mEpdgSelector).hasIpv4Address(mMockNetwork);
+        doReturn(true).when(mEpdgSelector).hasIpv6Address(mMockNetwork);
+
         // Set DnsResolver query mock
         final String testRoamStaticAddress = "epdg.epc.mnc088.mcc888.pub.3gppnetwork.org";
         mFakeDns.setAnswer(testRoamStaticAddress, new String[] {TEST_IP_ADDRESS}, TYPE_A);
@@ -216,8 +238,8 @@
 
         InetAddress expectedAddress = InetAddress.getByName(TEST_IP_ADDRESS);
 
-        assertEquals(testInetAddresses.size(), 1);
-        assertEquals(testInetAddresses.get(0), expectedAddress);
+        assertEquals(1, testInetAddresses.size());
+        assertEquals(expectedAddress, testInetAddresses.get(0));
     }
 
     @Test
@@ -232,6 +254,10 @@
 
     @Test
     public void testPlmnResolutionMethodWithNoPlmnInCarrierConfig() throws Exception {
+        when(DnsResolver.getInstance()).thenReturn(mMockDnsResolver);
+        doReturn(true).when(mEpdgSelector).hasIpv4Address(mMockNetwork);
+        doReturn(true).when(mEpdgSelector).hasIpv6Address(mMockNetwork);
+
         // setUp() fills default values for mcc-mnc
         String expectedFqdnFromImsi = "epdg.epc.mnc120.mcc311.pub.3gppnetwork.org";
         String expectedFqdnFromEhplmn = "epdg.epc.mnc120.mcc300.pub.3gppnetwork.org";
@@ -242,12 +268,16 @@
         ArrayList<InetAddress> testInetAddresses =
                 getValidatedServerListWithDefaultParams(false /*isEmergency*/);
 
-        assertEquals(testInetAddresses.size(), 2);
+        assertEquals(2, testInetAddresses.size());
         assertTrue(testInetAddresses.contains(InetAddress.getByName(TEST_IP_ADDRESS_1)));
         assertTrue(testInetAddresses.contains(InetAddress.getByName(TEST_IP_ADDRESS_2)));
     }
 
     private void testPlmnResolutionMethod(boolean isEmergency) throws Exception {
+        when(DnsResolver.getInstance()).thenReturn(mMockDnsResolver);
+        doReturn(true).when(mEpdgSelector).hasIpv4Address(mMockNetwork);
+        doReturn(true).when(mEpdgSelector).hasIpv6Address(mMockNetwork);
+
         String expectedFqdnFromImsi = "epdg.epc.mnc120.mcc311.pub.3gppnetwork.org";
         String expectedFqdnFromRplmn = "epdg.epc.mnc121.mcc311.pub.3gppnetwork.org";
         String expectedFqdnFromEhplmn = "epdg.epc.mnc120.mcc300.pub.3gppnetwork.org";
@@ -295,6 +325,10 @@
 
     @Test
     public void testPlmnResolutionMethodWithDuplicatedImsiAndEhplmn() throws Exception {
+        when(DnsResolver.getInstance()).thenReturn(mMockDnsResolver);
+        doReturn(true).when(mEpdgSelector).hasIpv4Address(mMockNetwork);
+        doReturn(true).when(mEpdgSelector).hasIpv6Address(mMockNetwork);
+
         String fqdnFromEhplmn1 = "epdg.epc.mnc120.mcc300.pub.3gppnetwork.org";
         String fqdnFromEhplmn2AndImsi = "epdg.epc.mnc120.mcc311.pub.3gppnetwork.org";
         String fqdnFromEhplmn3 = "epdg.epc.mnc122.mcc300.pub.3gppnetwork.org";
@@ -330,6 +364,10 @@
 
     @Test
     public void testPlmnResolutionMethodWithInvalidLengthPlmns() throws Exception {
+        when(DnsResolver.getInstance()).thenReturn(mMockDnsResolver);
+        doReturn(true).when(mEpdgSelector).hasIpv4Address(mMockNetwork);
+        doReturn(true).when(mEpdgSelector).hasIpv6Address(mMockNetwork);
+
         when(mMockSubscriptionInfo.getMccString()).thenReturn("31");
         when(mMockSubscriptionInfo.getMncString()).thenReturn("12");
 
@@ -355,6 +393,10 @@
 
     @Test
     public void testPlmnResolutionMethodWithInvalidCharacterPlmns() throws Exception {
+        when(DnsResolver.getInstance()).thenReturn(mMockDnsResolver);
+        doReturn(true).when(mEpdgSelector).hasIpv4Address(mMockNetwork);
+        doReturn(true).when(mEpdgSelector).hasIpv6Address(mMockNetwork);
+
         when(mMockSubscriptionInfo.getMccString()).thenReturn("a b");
         when(mMockSubscriptionInfo.getMncString()).thenReturn("!@#");
 
@@ -381,6 +423,10 @@
 
     @Test
     public void testPlmnResolutionMethodWithEmptyPlmns() throws Exception {
+        when(DnsResolver.getInstance()).thenReturn(mMockDnsResolver);
+        doReturn(true).when(mEpdgSelector).hasIpv4Address(mMockNetwork);
+        doReturn(true).when(mEpdgSelector).hasIpv6Address(mMockNetwork);
+
         when(mMockSubscriptionInfo.getMccString()).thenReturn(null);
         when(mMockSubscriptionInfo.getMncString()).thenReturn(null);
 
@@ -405,6 +451,10 @@
 
     @Test
     public void testPlmnResolutionMethodWithFirstEhplmn() throws Exception {
+        when(DnsResolver.getInstance()).thenReturn(mMockDnsResolver);
+        doReturn(true).when(mEpdgSelector).hasIpv4Address(mMockNetwork);
+        doReturn(true).when(mEpdgSelector).hasIpv6Address(mMockNetwork);
+
         String fqdnFromEhplmn1 = "epdg.epc.mnc120.mcc300.pub.3gppnetwork.org";
         String fqdnFromEhplmn2 = "epdg.epc.mnc121.mcc300.pub.3gppnetwork.org";
         String fqdnFromEhplmn3 = "epdg.epc.mnc122.mcc300.pub.3gppnetwork.org";
@@ -434,6 +484,10 @@
 
     @Test
     public void testPlmnResolutionMethodWithRplmn() throws Exception {
+        when(DnsResolver.getInstance()).thenReturn(mMockDnsResolver);
+        doReturn(true).when(mEpdgSelector).hasIpv4Address(mMockNetwork);
+        doReturn(true).when(mEpdgSelector).hasIpv6Address(mMockNetwork);
+
         String fqdnFromRplmn = "epdg.epc.mnc122.mcc300.pub.3gppnetwork.org";
         String fqdnFromEhplmn1 = "epdg.epc.mnc120.mcc300.pub.3gppnetwork.org";
         String fqdnFromEhplmn2 = "epdg.epc.mnc121.mcc300.pub.3gppnetwork.org";
@@ -464,7 +518,11 @@
 
     @Test
     public void testCarrierConfigStaticAddressList() throws Exception {
-        // Set Network.getAllByName mock
+        when(DnsResolver.getInstance()).thenReturn(mMockDnsResolver);
+        doReturn(true).when(mEpdgSelector).hasIpv4Address(mMockNetwork);
+        doReturn(true).when(mEpdgSelector).hasIpv6Address(mMockNetwork);
+
+        // Set DnsResolver query mock
         final String addr1 = "epdg.epc.mnc480.mcc310.pub.3gppnetwork.org";
         final String addr2 = "epdg.epc.mnc120.mcc300.pub.3gppnetwork.org";
         final String addr3 = "epdg.epc.mnc120.mcc311.pub.3gppnetwork.org";
@@ -484,10 +542,10 @@
         ArrayList<InetAddress> testInetAddresses =
                 getValidatedServerListWithDefaultParams(false /*isEmergency*/);
 
-        assertEquals(testInetAddresses.size(), 3);
-        assertEquals(testInetAddresses.get(0), InetAddress.getByName(TEST_IP_ADDRESS_1));
-        assertEquals(testInetAddresses.get(1), InetAddress.getByName(TEST_IP_ADDRESS_2));
-        assertEquals(testInetAddresses.get(2), InetAddress.getByName(TEST_IP_ADDRESS));
+        assertEquals(3, testInetAddresses.size());
+        assertEquals(InetAddress.getByName(TEST_IP_ADDRESS_1), testInetAddresses.get(0));
+        assertEquals(InetAddress.getByName(TEST_IP_ADDRESS_2), testInetAddresses.get(1));
+        assertEquals(InetAddress.getByName(TEST_IP_ADDRESS), testInetAddresses.get(2));
     }
 
     private ArrayList<InetAddress> getValidatedServerListWithDefaultParams(boolean isEmergency)
@@ -567,7 +625,7 @@
         ArrayList<InetAddress> testInetAddresses =
                 getValidatedServerListWithDefaultParams(false /* isEmergency */);
 
-        assertEquals(testInetAddresses.size(), 2);
+        assertEquals(2, testInetAddresses.size());
         assertTrue(testInetAddresses.contains(InetAddress.getByName(TEST_IP_ADDRESS)));
         assertTrue(testInetAddresses.contains(InetAddress.getByName(TEST_IPV6_ADDRESS)));
     }
@@ -588,6 +646,8 @@
     }
 
     private void testCellularResolutionMethod(boolean isEmergency) throws Exception {
+        when(DnsResolver.getInstance()).thenReturn(mMockDnsResolver);
+
         int testMcc = 311;
         int testMnc = 120;
         String testMccString = "311";
@@ -639,10 +699,10 @@
         ArrayList<InetAddress> testInetAddresses =
                 getValidatedServerListWithDefaultParams(isEmergency);
 
-        assertEquals(testInetAddresses.size(), 3);
-        assertEquals(testInetAddresses.get(0), InetAddress.getByName(TEST_IP_ADDRESS));
-        assertEquals(testInetAddresses.get(1), InetAddress.getByName(TEST_IP_ADDRESS_1));
-        assertEquals(testInetAddresses.get(2), InetAddress.getByName(TEST_IP_ADDRESS_2));
+        assertEquals(3, testInetAddresses.size());
+        assertEquals(InetAddress.getByName(TEST_IP_ADDRESS), testInetAddresses.get(0));
+        assertEquals(InetAddress.getByName(TEST_IP_ADDRESS_1), testInetAddresses.get(1));
+        assertEquals(InetAddress.getByName(TEST_IP_ADDRESS_2), testInetAddresses.get(2));
     }
 
     private void setAnswerForCellularMethod(boolean isEmergency, int mcc, int mnc)
@@ -683,6 +743,10 @@
 
     @Test
     public void testGetValidatedServerListIpv4Preferred() throws Exception {
+        when(DnsResolver.getInstance()).thenReturn(mMockDnsResolver);
+        doReturn(true).when(mEpdgSelector).hasIpv4Address(mMockNetwork);
+        doReturn(true).when(mEpdgSelector).hasIpv6Address(mMockNetwork);
+
         final String addr1 = "epdg.epc.mnc120.mcc300.pub.3gppnetwork.org";
         final String addr2 = "epdg.epc.mnc120.mcc311.pub.3gppnetwork.org";
         final String testStaticAddress = addr1 + "," + addr2;
@@ -710,6 +774,10 @@
 
     @Test
     public void testGetValidatedServerListIpv6Preferred() throws Exception {
+        when(DnsResolver.getInstance()).thenReturn(mMockDnsResolver);
+        doReturn(true).when(mEpdgSelector).hasIpv4Address(mMockNetwork);
+        doReturn(true).when(mEpdgSelector).hasIpv6Address(mMockNetwork);
+
         final String addr1 = "epdg.epc.mnc120.mcc300.pub.3gppnetwork.org";
         final String addr2 = "epdg.epc.mnc120.mcc311.pub.3gppnetwork.org";
         final String testStaticAddress = addr1 + "," + addr2;
@@ -737,6 +805,10 @@
 
     @Test
     public void testGetValidatedServerListIpv4Only() throws Exception {
+        when(DnsResolver.getInstance()).thenReturn(mMockDnsResolver);
+        doReturn(true).when(mEpdgSelector).hasIpv4Address(mMockNetwork);
+        doReturn(true).when(mEpdgSelector).hasIpv6Address(mMockNetwork);
+
         final String addr1 = "epdg.epc.mnc120.mcc300.pub.3gppnetwork.org";
         final String addr2 = "epdg.epc.mnc120.mcc311.pub.3gppnetwork.org";
         final String testStaticAddress = addr1 + "," + addr2;
@@ -763,6 +835,10 @@
 
     @Test
     public void testGetValidatedServerListIpv4OnlyCongestion() throws Exception {
+        when(DnsResolver.getInstance()).thenReturn(mMockDnsResolver);
+        doReturn(true).when(mEpdgSelector).hasIpv4Address(mMockNetwork);
+        doReturn(true).when(mEpdgSelector).hasIpv6Address(mMockNetwork);
+
         when(mMockErrorPolicyManager.getMostRecentDataFailCause())
                 .thenReturn(DataFailCause.IWLAN_CONGESTION);
         when(mMockErrorPolicyManager.getCurrentFqdnIndex(anyInt())).thenReturn(0);
@@ -794,6 +870,10 @@
 
     @Test
     public void testGetValidatedServerListIpv6Only() throws Exception {
+        when(DnsResolver.getInstance()).thenReturn(mMockDnsResolver);
+        doReturn(true).when(mEpdgSelector).hasIpv4Address(mMockNetwork);
+        doReturn(true).when(mEpdgSelector).hasIpv6Address(mMockNetwork);
+
         final String addr1 = "epdg.epc.mnc120.mcc300.pub.3gppnetwork.org";
         final String addr2 = "epdg.epc.mnc120.mcc311.pub.3gppnetwork.org";
         final String testStaticAddress = addr1 + "," + addr2;
@@ -820,6 +900,10 @@
 
     @Test
     public void testGetValidatedServerListSystemPreferred() throws Exception {
+        when(DnsResolver.getInstance()).thenReturn(mMockDnsResolver);
+        doReturn(true).when(mEpdgSelector).hasIpv4Address(mMockNetwork);
+        doReturn(true).when(mEpdgSelector).hasIpv6Address(mMockNetwork);
+
         final String addr1 = "epdg.epc.mnc120.mcc300.pub.3gppnetwork.org";
         final String addr2 = "epdg.epc.mnc120.mcc311.pub.3gppnetwork.org";
         final String addr3 = "epdg.epc.mnc120.mcc312.pub.3gppnetwork.org";
@@ -873,20 +957,20 @@
             }
         }
 
-        private final ArrayList<DnsEntry> mAnswers = new ArrayList<DnsEntry>();
+        private final List<DnsEntry> mAnswers = new ArrayList<>();
 
         /** Clears all DNS entries. */
         private synchronized void clearAll() {
             mAnswers.clear();
         }
 
-        /** Returns the answer for a given name and type on the given mock network. */
-        private synchronized List<InetAddress> getAnswer(Object mock, String hostname, int type) {
+        /** Returns the answer for a given name and type. */
+        private synchronized List<InetAddress> getAnswer(String hostname, int type) {
             return mAnswers.stream()
                     .filter(e -> e.matches(hostname, type))
                     .map(answer -> answer.mAddresses)
                     .findFirst()
-                    .orElse(null);
+                    .orElse(List.of());
         }
 
         /** Sets the answer for a given name and type. */
@@ -907,10 +991,20 @@
         }
 
         // Regardless of the type, depends on what the responses contained in the network.
-        private List<InetAddress> queryAllTypes(Object mock, String hostname) {
+        private List<InetAddress> queryIpv4(String hostname) {
+            return getAnswer(hostname, TYPE_A);
+        }
+
+        // Returns only the TYPE_AAAA answers set for the hostname (empty list if none).
+        private List<InetAddress> queryIpv6(String hostname) {
+            return getAnswer(hostname, TYPE_AAAA);
+        }
+
+        // Regardless of the type, depends on what the responses contained in the network.
+        private List<InetAddress> queryAllTypes(String hostname) {
             List<InetAddress> answer = new ArrayList<>();
-            addAllIfNotNull(answer, getAnswer(mock, hostname, TYPE_A));
-            addAllIfNotNull(answer, getAnswer(mock, hostname, TYPE_AAAA));
+            answer.addAll(queryIpv4(hostname));
+            answer.addAll(queryIpv6(hostname));
             return answer;
         }
 
@@ -922,32 +1016,55 @@
 
         /** Starts mocking DNS queries. */
         private void startMocking() throws UnknownHostException {
+            // 6-arg DnsResolver.query() (no nsType): answers with both A and AAAA records.
             doAnswer(
                             invocation -> {
                                 return mockQuery(
                                         invocation,
                                         1 /* posHostname */,
+                                        -1 /* posType */,
                                         3 /* posExecutor */,
-                                        5 /* posCallback */,
-                                        -1 /* posType */);
+                                        5 /* posCallback */);
                             })
                     .when(mMockDnsResolver)
-                    .query(any(), any(), anyInt(), any(), any(), any());
+                    .query(any(), anyString(), anyInt(), any(), any(), any());
+
+            // 7-arg DnsResolver.query(); mockQuery() switches on the nsType value, not its index.
+            doAnswer(
+                            invocation -> {
+                                return mockQuery(
+                                        invocation,
+                                        1 /* posHostname */,
+                                        invocation.<Integer>getArgument(2) /* nsType */,
+                                        4 /* posExecutor */,
+                                        6 /* posCallback */);
+                            })
+                    .when(mMockDnsResolver)
+                    .query(any(), anyString(), anyInt(), anyInt(), any(), any(), any());
         }
 
         // Mocking queries on DnsResolver#query.
         private Answer mockQuery(
                 InvocationOnMock invocation,
                 int posHostname,
+                int posType,
                 int posExecutor,
-                int posCallback,
-                int posType) {
-            String hostname = (String) invocation.getArgument(posHostname);
-            Executor executor = (Executor) invocation.getArgument(posExecutor);
+                int posCallback) {
+            String hostname = invocation.getArgument(posHostname);
+            Executor executor = invocation.getArgument(posExecutor);
             DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(posCallback);
             List<InetAddress> answer;
 
-            answer = queryAllTypes(invocation.getMock(), hostname);
+            switch (posType) {
+                case TYPE_A:
+                    answer = queryIpv4(hostname);
+                    break;
+                case TYPE_AAAA:
+                    answer = queryIpv6(hostname);
+                    break;
+                default:
+                    answer = queryAllTypes(hostname);
+            }
 
             if (answer != null && answer.size() > 0) {
                 new Handler(Looper.getMainLooper())