blob: 61e2cb711a51a5c1db31e2295d6166e43d2625f7 [file] [log] [blame]
// Inspired by Clang's clang-format-diff:
//
// https://github.com/llvm-mirror/clang/blob/master/tools/clang-format/clang-format-diff.py
#![deny(warnings)]
#[macro_use]
extern crate tracing;
use serde::{Deserialize, Serialize};
use serde_json as json;
use thiserror::Error;
use tracing_subscriber::EnvFilter;
use std::collections::HashSet;
use std::env;
use std::ffi::OsStr;
use std::io::{self, BufRead};
use std::process;
use regex::Regex;
use clap::{CommandFactory, Parser};
/// The default pattern of files to format.
///
/// We only want to format Rust files by default. The value is a regular
/// expression and is the default for the `--filter` option; `scan_diff`
/// wraps the filter in `^` and `$`, so it has to match a whole path.
const DEFAULT_PATTERN: &str = r".*\.rs";
/// Errors that can abort a run of this tool.
///
/// Every variant wraps an underlying error and displays it unchanged
/// (`{0}`); the `#[from]` attributes let `?` convert the wrapped error types
/// into this enum automatically.
#[derive(Error, Debug)]
enum FormatDiffError {
/// Wraps a `getopts::Fail`.
// NOTE(review): nothing visible in this file constructs this variant; the
// command line is parsed through clap. Presumably a leftover from an
// earlier getopts-based parser -- confirm before removing.
#[error("{0}")]
IncorrectOptions(#[from] getopts::Fail),
/// Wraps a `regex::Error`; `scan_diff` produces it when the file filter
/// does not compile as a regular expression.
#[error("{0}")]
IncorrectFilter(#[from] regex::Error),
/// Wraps an `io::Error`; `run_rustfmt` also uses it to report an
/// unsuccessful `rustfmt` exit status.
#[error("{0}")]
IoError(#[from] io::Error),
}
// Command-line options of `rustfmt-format-diff`, parsed through clap's derive
// API.
//
// NOTE: plain `//` comments are used in this struct on purpose. clap's derive
// turns `///` doc comments into help text, so adding or rewording a `///`
// comment here would change what the program prints for `--help`.
#[derive(Parser, Debug)]
#[clap(
name = "rustfmt-format-diff",
disable_version_flag = true,
next_line_help = true
)]
pub struct Opts {
// `-p` / `--skip-prefix`, default 0. `scan_diff` uses this as the number of
// `/`-terminated leading path components to drop from each `+++` line
// before the file name is captured.
/// Skip the smallest prefix containing NUMBER slashes
#[clap(
short = 'p',
long = "skip-prefix",
value_name = "NUMBER",
default_value = "0"
)]
skip_prefix: u32,
// `-f` / `--filter`, default `DEFAULT_PATTERN`. `scan_diff` compiles this
// as a regular expression anchored with `^` and `$`; hunks of files whose
// path does not match are ignored.
/// Custom pattern selecting file paths to reformat
#[clap(
short = 'f',
long = "filter",
value_name = "PATTERN",
default_value = DEFAULT_PATTERN
)]
filter: String,
}
/// Entry point: sets up logging, parses the command line and runs the tool.
///
/// Log output is filtered through the `RUSTFMT_LOG` environment variable. If
/// the run fails, the error is printed to stdout, followed by the help text,
/// and the process exits with status 1.
fn main() {
    let log_filter = EnvFilter::from_env("RUSTFMT_LOG");
    tracing_subscriber::fmt().with_env_filter(log_filter).init();

    match run(Opts::parse()) {
        Ok(()) => {}
        Err(error) => {
            println!("{error}");
            let mut usage = Opts::command();
            usage.print_help().expect("cannot write to stdout");
            process::exit(1);
        }
    }
}
/// A span of lines in one file that rustfmt should reformat.
///
/// `run_rustfmt` serializes a slice of these to JSON and passes the result to
/// rustfmt as the value of `--file-lines`, so the field names end up as keys
/// in that JSON and should not be renamed casually.
#[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
struct Range {
/// Path of the file, as captured from a `+++` line of the diff.
file: String,
/// `[start_line, end_line]` as computed by `scan_diff`; the end is
/// `start + count - 1`, i.e. both ends are inclusive.
range: [u32; 2],
}
fn run(opts: Opts) -> Result<(), FormatDiffError> {
let (files, ranges) = scan_diff(io::stdin(), opts.skip_prefix, &opts.filter)?;
run_rustfmt(&files, &ranges)
}
fn run_rustfmt(files: &HashSet<String>, ranges: &[Range]) -> Result<(), FormatDiffError> {
if files.is_empty() || ranges.is_empty() {
debug!("No files to format found");
return Ok(());
}
let ranges_as_json = json::to_string(ranges).unwrap();
debug!("Files: {:?}", files);
debug!("Ranges: {:?}", ranges);
let rustfmt_var = env::var_os("RUSTFMT");
let rustfmt = match &rustfmt_var {
Some(rustfmt) => rustfmt,
None => OsStr::new("rustfmt"),
};
let exit_status = process::Command::new(rustfmt)
.args(files)
.arg("--file-lines")
.arg(ranges_as_json)
.status()?;
if !exit_status.success() {
return Err(FormatDiffError::IoError(io::Error::new(
io::ErrorKind::Other,
format!("rustfmt failed with {exit_status}"),
)));
}
Ok(())
}
/// Scans a diff from `from`, and returns the set of files found, and the ranges
/// in those files.
fn scan_diff<R>(
from: R,
skip_prefix: u32,
file_filter: &str,
) -> Result<(HashSet<String>, Vec<Range>), FormatDiffError>
where
R: io::Read,
{
let diff_pattern = format!(r"^\+\+\+\s(?:.*?/){{{skip_prefix}}}(\S*)");
let diff_pattern = Regex::new(&diff_pattern).unwrap();
let lines_pattern = Regex::new(r"^@@.*\+(\d+)(,(\d+))?").unwrap();
let file_filter = Regex::new(&format!("^{file_filter}$"))?;
let mut current_file = None;
let mut files = HashSet::new();
let mut ranges = vec![];
for line in io::BufReader::new(from).lines() {
let line = line.unwrap();
if let Some(captures) = diff_pattern.captures(&line) {
current_file = Some(captures.get(1).unwrap().as_str().to_owned());
}
let file = match current_file {
Some(ref f) => &**f,
None => continue,
};
// FIXME(emilio): We could avoid this most of the time if needed, but
// it's not clear it's worth it.
if !file_filter.is_match(file) {
continue;
}
let lines_captures = match lines_pattern.captures(&line) {
Some(captures) => captures,
None => continue,
};
let start_line = lines_captures
.get(1)
.unwrap()
.as_str()
.parse::<u32>()
.unwrap();
let line_count = match lines_captures.get(3) {
Some(line_count) => line_count.as_str().parse::<u32>().unwrap(),
None => 1,
};
if line_count == 0 {
continue;
}
let end_line = start_line + line_count - 1;
files.insert(file.to_owned());
ranges.push(Range {
file: file.to_owned(),
range: [start_line, end_line],
});
}
Ok((files, ranges))
}
/// Runs `scan_diff` over the `test/bindgen.diff` fixture with one path
/// component skipped and the `.rs` filter, then checks both the set of
/// matched files and the exact list of ranges.
#[test]
fn scan_simple_git_diff() {
    const DIFF: &str = include_str!("test/bindgen.diff");
    let (files, ranges) = scan_diff(DIFF.as_bytes(), 1, r".*\.rs").expect("scan_diff failed?");

    let rust_file_found = files.contains("src/ir/traversal.rs");
    assert!(rust_file_found, "Should've matched the filter");

    let header_file_found = files.contains("tests/headers/anon_enum.hpp");
    assert!(!header_file_found, "Shouldn't have matched the filter");

    // (file, [first line, last line]) in the order the hunks appear.
    let expected = [
        ("src/ir/item.rs", [148, 158]),
        ("src/ir/item.rs", [160, 170]),
        ("src/ir/traversal.rs", [9, 16]),
        ("src/ir/traversal.rs", [35, 43]),
    ]
    .map(|(file, range)| Range {
        file: file.to_owned(),
        range,
    });
    assert_eq!(ranges, expected);
}
/// Tests for the command-line interface declared by `Opts`.
#[cfg(test)]
mod cmd_line_tests {
    use super::*;

    /// Returns `true` when the argument parser refuses `args`.
    fn is_rejected(args: &[&str]) -> bool {
        Opts::command().try_get_matches_from(args).is_err()
    }

    /// With no arguments at all, both options take their defaults.
    #[test]
    fn default_options() {
        let no_args: Vec<String> = Vec::new();
        let opts = Opts::parse_from(&no_args);
        assert_eq!(DEFAULT_PATTERN, opts.filter);
        assert_eq!(0, opts.skip_prefix);
    }

    /// Explicit `-p` and `-f` values are stored as given.
    #[test]
    fn good_options() {
        let args = ["test", "-p", "10", "-f", r".*\.hs"];
        let opts = Opts::parse_from(&args);
        assert_eq!(r".*\.hs", opts.filter);
        assert_eq!(10, opts.skip_prefix);
    }

    /// A positional argument is an error.
    #[test]
    fn unexpected_option() {
        assert!(is_rejected(&["test", "unexpected"]));
    }

    /// An unknown flag is an error.
    #[test]
    fn unexpected_flag() {
        assert!(is_rejected(&["test", "--flag"]));
    }

    /// Passing `-p` twice is an error.
    #[test]
    fn overridden_option() {
        assert!(is_rejected(&["test", "-p", "10", "-p", "20"]));
    }

    /// A negative value is an error. Despite the test's name, the value is
    /// passed to `-p` (the skip prefix), not to the filter.
    #[test]
    fn negative_filter() {
        assert!(is_rejected(&["test", "-p", "-1"]));
    }
}