Merge "Update google-java-format-diff" into main
diff --git a/common/google-java-format/google-java-format-diff.py b/common/google-java-format/google-java-format-diff.py
index 151ae33..7f52ed1 100755
--- a/common/google-java-format/google-java-format-diff.py
+++ b/common/google-java-format/google-java-format-diff.py
@@ -1,13 +1,13 @@
 #!/usr/bin/env python3
 #
-#===- google-java-format-diff.py - google-java-format Diff Reformatter -----===#
+# ===- google-java-format-diff.py - google-java-format Diff Reformatter -----===#
 #
 #                     The LLVM Compiler Infrastructure
 #
 # This file is distributed under the University of Illinois Open Source
 # License. See LICENSE.TXT for details.
 #
-#===------------------------------------------------------------------------===#
+# ===------------------------------------------------------------------------===#
 
 """
 google-java-format Diff Reformatter
@@ -33,8 +33,34 @@
 import subprocess
 import io
 import sys
+from concurrent.futures import ThreadPoolExecutor,wait,FIRST_EXCEPTION
 from shutil import which
 
+def _apply_format(filename, lines, base_command, args):
+  """Run base_command on filename; without args.i, print a diff of output."""
+  if args.i and args.verbose:
+    print('Formatting', filename)
+  # Copy: the same base_command list is passed to every concurrent call.
+  command = base_command[:]
+  command.extend(lines)
+  command.append(filename)
+  p = subprocess.Popen(command, stdout=subprocess.PIPE,
+                       stderr=None, stdin=subprocess.PIPE)
+  stdout, _ = p.communicate()
+  if p.returncode != 0:
+    sys.exit(p.returncode)  # main() reads the code back via args[0]
+  # Without -i the formatted source comes back on stdout; diff the file.
+  if not args.i:
+    with open(filename, encoding='utf-8') as f:  # same codec as stdout
+      code = f.readlines()
+    formatted_code = io.StringIO(stdout.decode('utf-8')).readlines()
+    diff = difflib.unified_diff(code, formatted_code,
+                                filename, filename,
+                                '(before formatting)', '(after formatting)')
+    diff_string = ''.join(diff)
+    if diff_string:
+      sys.stdout.write(diff_string)
+
 def main():
   parser = argparse.ArgumentParser(description=
                                    'Reformat changed lines in diff. Without -i '
@@ -75,7 +101,7 @@
   lines_by_file = {}
 
   for line in sys.stdin:
-    match = re.search('^\+\+\+\ (.*?/){%s}(\S*)' % args.p, line)
+    match = re.search(r'^\+\+\+\ (.*?/){%s}(\S*)' % args.p, line)
     if match:
       filename = match.group(2)
     if filename == None:
@@ -88,7 +114,7 @@
       if not re.match('^%s$' % args.iregex, filename, re.IGNORECASE):
         continue
 
-    match = re.search('^@@.*\+(\d+)(,(\d+))?', line)
+    match = re.search(r'^@@.*\+(\d+)(,(\d+))?', line)
     if match:
       start_line = int(match.group(1))
       line_count = 1
@@ -108,39 +134,29 @@
     binary = which('google-java-format') or '/usr/bin/google-java-format'
     base_command = [binary]
 
-  # Reformat files containing changes in place.
-  for filename, lines in lines_by_file.items():
-    if args.i and args.verbose:
-      print('Formatting', filename)
-    command = base_command[:]
-    if args.i:
-      command.append('-i')
-    if args.aosp:
-      command.append('--aosp')
-    if args.skip_sorting_imports:
-      command.append('--skip-sorting-imports')
-    if args.skip_removing_unused_imports:
-      command.append('--skip-removing-unused-imports')
-    if args.skip_javadoc_formatting:
-      command.append('--skip-javadoc-formatting')
-    command.extend(lines)
-    command.append(filename)
-    p = subprocess.Popen(command, stdout=subprocess.PIPE,
-                         stderr=None, stdin=subprocess.PIPE)
-    stdout, stderr = p.communicate()
-    if p.returncode != 0:
-      sys.exit(p.returncode);
+  if args.i:
+    base_command.append('-i')
+  if args.aosp:
+    base_command.append('--aosp')
+  if args.skip_sorting_imports:
+    base_command.append('--skip-sorting-imports')
+  if args.skip_removing_unused_imports:
+    base_command.append('--skip-removing-unused-imports')
+  if args.skip_javadoc_formatting:
+    base_command.append('--skip-javadoc-formatting')
 
-    if not args.i:
-      with open(filename) as f:
-        code = f.readlines()
-      formatted_code = io.StringIO(stdout.decode('utf-8')).readlines()
-      diff = difflib.unified_diff(code, formatted_code,
-                                  filename, filename,
-                                  '(before formatting)', '(after formatting)')
-      diff_string = ''.join(diff)
-      if len(diff_string) > 0:
-        sys.stdout.write(diff_string)
+  with ThreadPoolExecutor() as executor:
+    format_futures = []
+    for filename, lines in lines_by_file.items():
+      format_futures.append(
+          executor.submit(_apply_format, filename, lines, base_command, args)
+      )
+
+    done, _ = wait(format_futures, return_when=FIRST_EXCEPTION)
+    for future in done:
+      if exception := future.exception():
+        executor.shutdown(wait=True, cancel_futures=True)
+        sys.exit(exception.args[0])
 
 if __name__ == '__main__':
   main()