diff --git a/src/builder.rs b/src/builder.rs index ad0be06..e21057c 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -19,13 +19,19 @@ use std::collections::{HashMap, HashSet}; use crate::capture::{add_capture, Capture}; use crate::query::{NegativeQuery, QueryTree}; use crate::util::parse_number_literal; +use crate::RegexMap; use colored::Colorize; use tree_sitter::{Node, TreeCursor}; /// Translate a parsed and validated input source (specified by `source` and `cursor`) into a `QueryTree`. /// When `is_cpp` is set, C++ specific features are enabled. -pub fn build_query_tree(source: &str, cursor: &mut TreeCursor, is_cpp: bool) -> QueryTree { - _build_query_tree(source, cursor, 0, is_cpp, false, false) +pub fn build_query_tree( + source: &str, + cursor: &mut TreeCursor, + is_cpp: bool, + regex_constraints: Option, +) -> QueryTree { + _build_query_tree(source, cursor, 0, is_cpp, false, false, regex_constraints) } fn _build_query_tree( @@ -35,6 +41,7 @@ fn _build_query_tree( is_cpp: bool, is_multi_pattern: bool, strict_mode: bool, + regex_constraints: Option, ) -> QueryTree { let mut b = QueryBuilder { query_source: source.to_string(), @@ -42,6 +49,10 @@ fn _build_query_tree( negations: Vec::new(), id, cpp: is_cpp, + regex_constraints: match regex_constraints { + Some(r) => r, + None => RegexMap::new(HashMap::new()), + }, }; // Skip the root node if it's a translation_unit. @@ -147,7 +158,7 @@ fn process_captures( Capture::Check(s) => { sexp += &format!(r#"(#eq? @{} "{}")"#, (i + offset).to_string(), s); } - Capture::Variable(var) => { + Capture::Variable(var, _) => { vars.entry(var.clone()) .or_insert_with(Vec::new) .push(i + offset); @@ -165,7 +176,7 @@ fn process_captures( let a = vec[0].to_string(); for capture in vec.iter().skip(1) { let b = capture.to_string(); - sexp += &format!(r#"(#eq? @{} @{})"#, a.to_string(), b.to_string()); + sexp += &format!(r#"(#eq? 
@{} @{})"#, a, b); } } } @@ -180,6 +191,7 @@ struct QueryBuilder { negations: Vec, // all negative sub queries (not: ) id: usize, // a globally unique ID used for caching results see `query.rs` cpp: bool, // flag to enable C++ support + regex_constraints: RegexMap, } impl QueryBuilder { @@ -245,7 +257,7 @@ impl QueryBuilder { // Anonymous nodes are string constants like "+" or "+=". // We can simply copy them into the query. if !c.node().is_named() { - return format!(r#""{}""#, c.node().kind().to_string()); + return format!(r#""{}""#, c.node().kind()); } let kind = c.node().kind(); @@ -305,6 +317,7 @@ impl QueryBuilder { self.cpp, true, false, // limit strictness to current depth for now + Some(self.regex_constraints.clone()), ))); return "(compound_statement) @".to_string() + &add_capture(&mut self.captures, capture); @@ -353,7 +366,10 @@ impl QueryBuilder { let unquoted = &pattern[1..pattern.len() - 1]; if unquoted.starts_with('$') { - let c = Capture::Variable(unquoted.to_string()); + let c = Capture::Variable( + unquoted.to_string(), + self.regex_constraints.get(unquoted), + ); return format! {"(string_literal) @{}", &add_capture(&mut self.captures, c)}; } } @@ -371,7 +387,7 @@ impl QueryBuilder { let mut result = format!("({}", c.node().kind()); if !c.goto_first_child() { if !c.node().is_named() { - return format!(r#""{}""#, c.node().kind().to_string()); + return format!(r#""{}""#, c.node().kind()); } return result + ")"; } @@ -445,6 +461,7 @@ impl QueryBuilder { self.cpp, false, false, // TODO: should strict mode be supported in NOT queries? 
+ Some(self.regex_constraints.clone()), )), previous_capture_index: before, }); @@ -473,7 +490,7 @@ impl QueryBuilder { }; let capture = if pattern.starts_with('$') { - Capture::Variable(pattern.to_string()) + Capture::Variable(pattern.to_string(), self.regex_constraints.get(pattern)) } else { Capture::Check(pattern.to_string()) }; @@ -481,7 +498,7 @@ impl QueryBuilder { result += " @"; result += &add_capture(&mut self.captures, capture); - return result; + result } // Handle $foo() and _(). Returns None if the call does not need special handling. @@ -518,6 +535,7 @@ impl QueryBuilder { self.cpp, false, strict_mode, + Some(self.regex_constraints.clone()), ))); return Some("_ @".to_string() + &add_capture(&mut self.captures, capture)); } diff --git a/src/capture.rs b/src/capture.rs index e2c8d7a..5705ad6 100644 --- a/src/capture.rs +++ b/src/capture.rs @@ -1,18 +1,19 @@ /* - Copyright 2021 Google LLC +Copyright 2021 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ +use regex::Regex; /// We use captures as a way to extend tree-sitter's query mechanism. /// Variable captures correspond to a weggli variable ($foo) and we enforce @@ -23,7 +24,7 @@ #[derive(Debug)] pub enum Capture { Display, - Variable(String), + Variable(String, Option<(bool, Regex)>), Check(String), Number(i128), Subquery(Box), diff --git a/src/lib.rs b/src/lib.rs index 3691e86..b58d861 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,9 @@ See the License for the specific language governing permissions and limitations under the License. */ +use std::collections::{hash_map::Keys, HashMap}; + +use regex::Regex; use tree_sitter::{Language, Parser, Query, Tree}; #[macro_use] @@ -60,7 +63,7 @@ fn ts_query(sexpr: &str, cpp: bool) -> tree_sitter::Query { unsafe { tree_sitter_cpp() } }; - match Query::new(language, &sexpr) { + match Query::new(language, sexpr) { Ok(q) => q, Err(e) => { eprintln!( @@ -74,3 +77,25 @@ fn ts_query(sexpr: &str, cpp: bool) -> tree_sitter::Query { } } +/// Map from variable names to a positive/negative regex constraint +/// see --regex +#[derive(Clone)] +pub struct RegexMap(HashMap); + +impl RegexMap { + pub fn new(m: HashMap) -> RegexMap { + RegexMap(m) + } + + pub fn variables(&self) -> Keys { + self.0.keys() + } + + pub fn get(&self, variable: &str) -> Option<(bool, Regex)> { + if let Some((b, r)) = self.0.get(variable) { + Some((*b, r.to_owned())) + } else { + None + } + } +} diff --git a/src/main.rs b/src/main.rs index 7860233..2384e14 100644 --- a/src/main.rs +++ b/src/main.rs @@ -32,6 +32,7 @@ use std::{collections::HashSet, fs}; use std::{io::prelude::*, path::PathBuf}; use tree_sitter::Tree; use walkdir::WalkDir; +use weggli::RegexMap; use weggli::builder::build_query_tree; use weggli::query::QueryTree; @@ -51,6 +52,19 @@ fn main() { // Keep track of all variables used in the input pattern(s) let mut variables = 
HashSet::new(); + // Validate all regular expressions + let regex_constraints = process_regexes(&args.regexes).unwrap_or_else(|e| { + let msg = match e { + RegexError::InvalidArg(s) => format!( + "'{}' is not a valid argument of the form var=regex", + s.red() + ), + RegexError::InvalidRegex(s) => format!("Regex error {}", s), + }; + eprintln!("{}", msg); + std::process::exit(1) + }); + // Normalize all patterns and translate them into QueryTrees // We also extract the identifiers at this point // to use them for file filtering later on. @@ -61,7 +75,7 @@ fn main() { .pattern .iter() .map(|pattern| { - let qt = parse_search_pattern(pattern, args.cpp, args.force_query); + let qt = parse_search_pattern(pattern, args.cpp, args.force_query, &regex_constraints); let identifiers = qt.identifiers(); variables.extend(qt.variables()); @@ -69,21 +83,12 @@ fn main() { }) .collect(); - // Verify that regular expressions only refer to existing variables - let regex_constraints = process_regexes(&variables, &args.regexes).unwrap_or_else(|e| { - let msg = match e { - RegexError::InvalidArg(s) => format!( - "'{}' is not a valid argument of the form var=regex", - s.red() - ), - RegexError::InvalidVariable(s) => { - format!("'{}' is not a valid query variable", s.red()) - } - RegexError::InvalidRegex(s) => format!("Regex error {}", s), - }; - eprintln!("{}", msg); - std::process::exit(1) - }); + for v in regex_constraints.variables() { + if !variables.contains(v) { + eprintln!("'{}' is not a valid query variable", v.red()); + std::process::exit(1) + } + } // Verify that the --include and --exclude regexes are valid. let helper_regex = |v: &[String]| -> Vec { @@ -156,7 +161,7 @@ fn main() { // on the results. For single query executions, we can // directly print any remaining matches. 
For multi // query runs we forward them to our next worker function - s.spawn(move |_| execute_queries_worker(ast_rx, results_tx, w, &regex_constraints, &args)); + s.spawn(move |_| execute_queries_worker(ast_rx, results_tx, w, &args)); if w.len() > 1 { s.spawn(move |_| multi_query_worker(results_rx, w.len(), before, after)); @@ -179,14 +184,19 @@ const VALID_NODE_KINDS: &[&str] = &[ /// We support some basic normalization (adding { } around queries) and store the normalized form /// in `normalized_patterns` to avoid lifetime issues. /// For invalid patterns, validate_query will cause a process exit with a human-readable error message -fn parse_search_pattern(pattern: &String, is_cpp: bool, force_query: bool) -> QueryTree { +fn parse_search_pattern( + pattern: &str, + is_cpp: bool, + force_query: bool, + regex_constraints: &RegexMap, +) -> QueryTree { let mut tree = weggli::parse(pattern, is_cpp); let mut p = pattern; let temp_pattern; // Try to fix missing ';' at the end of a query. - // weggli 'memcpy(a,b,size)' should work. + // weggli 'memcpy(a,b,size)' should work. 
if tree.root_node().has_error() { if !pattern.ends_with(';') { temp_pattern = format!("{};", &p); @@ -218,9 +228,9 @@ fn parse_search_pattern(pattern: &String, is_cpp: bool, force_query: bool) -> Qu } } - let mut c = validate_query(&tree, &p, force_query); + let mut c = validate_query(&tree, p, force_query); - build_query_tree(&p, &mut c, is_cpp) + build_query_tree(p, &mut c, is_cpp, Some(regex_constraints.clone())) } /// Validates the user supplied search query and quits with an error message in case @@ -301,33 +311,8 @@ fn validate_query<'a>( c } -/// Map from variable names to a positive/negative regex constraint -/// see --regex -struct RegexMap(HashMap); - -impl RegexMap { - /// Returns true if the regex constraints in `self` allow the query result `m` - fn include_match(&self, m: &weggli::result::QueryResult, source: &str) -> bool { - if self.0.is_empty() { - return true; - } - - let mut skip = false; - for (v, r) in &self.0 { - let value = m.value(&v, &source).unwrap(); - skip = if r.1.is_match(value) { r.0 } else { !r.0 }; - - if skip { - break; - } - } - return !skip; - } -} - enum RegexError { InvalidArg(String), - InvalidVariable(String), InvalidRegex(regex::Error), } @@ -337,13 +322,9 @@ impl From for RegexError { } } -/// Validate all passed regexes against the set of query variables. -/// Returns an error if a regex for a non-existing variable is defined, -/// or if an invalid regex is supplied otherwise return a RegexMap -fn process_regexes( - variables: &HashSet, - regexes: &[String], -) -> Result { +/// Validate all passed regexes and compile them. +/// Returns an error if an invalid regex is supplied otherwise return a RegexMap +fn process_regexes(regexes: &[String]) -> Result { let mut result = HashMap::new(); for r in regexes { @@ -362,14 +343,10 @@ fn process_regexes( normalized_var.pop(); // remove ! 
} - if !variables.contains(&normalized_var) { - return Err(RegexError::InvalidVariable(var.to_string())); - } - let regex = Regex::new(raw_regex)?; result.insert(normalized_var, (negative, regex)); } - Ok(RegexMap(result)) + Ok(RegexMap::new(result)) } /// Recursively iterate through all files under `path` that match an ending listed in `extensions` @@ -415,7 +392,7 @@ struct WorkItem { fn parse_files_worker( files: Vec, sender: Sender<(Arc, Tree, String)>, - work: &Vec, + work: &[WorkItem], is_cpp: bool, ) { files @@ -465,8 +442,7 @@ struct ResultsCtx { fn execute_queries_worker( receiver: Receiver<(Arc, Tree, String)>, results_tx: Sender, - work: &Vec, - constraints: &RegexMap, + work: &[WorkItem], args: &cli::Args, ) { receiver.into_iter().par_bridge().for_each_with( @@ -483,23 +459,14 @@ fn execute_queries_worker( return; } - // Enforce RegEx constraints - let check_constraints = |m: &QueryResult| constraints.include_match(m, &source); - // Enforce --unique let check_unique = |m: &QueryResult| { if args.unique { let mut seen = HashSet::new(); - if !m - .vars + m.vars .keys() .map(|k| m.value(k, &source).unwrap()) .all(|x| seen.insert(x)) - { - false - } else { - true - } } else { true } @@ -541,7 +508,6 @@ fn execute_queries_worker( matches .into_iter() - .filter(check_constraints) .filter(check_unique) .filter(check_limit) .for_each(process_match); @@ -612,7 +578,7 @@ fn reset_signal_pipe_handler() { unsafe { let _ = signal::signal(signal::Signal::SIGPIPE, signal::SigHandler::SigDfl) - .map_err(|e| eprintln!("{}", e.to_string())); + .map_err(|e| eprintln!("{}", e)); } } } diff --git a/src/query.rs b/src/query.rs index aaaa975..9f43b72 100644 --- a/src/query.rs +++ b/src/query.rs @@ -78,7 +78,7 @@ impl<'a> QueryTree { let mut result = HashSet::new(); for c in &self.captures { match c { - Capture::Variable(s) => { + Capture::Variable(s, _) => { result.insert(s.to_string()); } Capture::Subquery(t) => { @@ -89,6 +89,10 @@ impl<'a> QueryTree { } } + for neg in 
&self.negations { + result.extend(neg.qt.variables()) + } + result } @@ -173,9 +177,9 @@ impl<'a> QueryTree { let negative_results = neg.qt.match_internal(root, source, cache); // check if any of its result are a valid match. - let negative_match = negative_results.into_iter().any(|n| { + negative_results.into_iter().any(|n| { // check if the negative match `m` is consistent with our result - if n.merge(&result, source, false).is_none() { + if n.merge(result, source, false).is_none() { return false; } @@ -196,10 +200,8 @@ impl<'a> QueryTree { } } - return true; - }); - - negative_match + true + }) }); !negative_query_matched @@ -207,7 +209,7 @@ impl<'a> QueryTree { .collect() } - // Process a single tree-sitter match and return all query results + // Process a single tree-sitter match and return all query results // This function is responsible for running all subqueries, // and veriyfing that negations don't match. fn process_match( @@ -237,7 +239,13 @@ impl<'a> QueryTree { } match capture { - Capture::Variable(s) => { + Capture::Variable(s, regex_constraint) => { + if let Some((negative, regex)) = regex_constraint { + let m = regex.is_match(&source[c.node.byte_range()]); + if (m && *negative) || (!m && !*negative) { + return vec![]; + } + } vars.insert(s.clone(), r.len() - 1); } Capture::Subquery(t) => { @@ -305,7 +313,7 @@ impl<'a> QueryTree { .flat_map(move |r| { sub_results .iter() - .filter_map(move |s| r.merge(&s, source, enforce_ordering)) + .filter_map(move |s| r.merge(s, source, enforce_ordering)) }) .collect() } diff --git a/src/result.rs b/src/result.rs index 3884cc4..5ab1db9 100644 --- a/src/result.rs +++ b/src/result.rs @@ -217,7 +217,7 @@ pub fn merge_results( .flat_map(|r| { sub_results .iter() - .filter_map(move |s| r.merge(&s, source, enforce_order)) + .filter_map(move |s| r.merge(s, source, enforce_order)) }) .collect() } diff --git a/src/util.rs b/src/util.rs index 349f599..f74afb8 100644 --- a/src/util.rs +++ b/src/util.rs @@ -45,7 +45,7 @@ pub 
fn parse_number_literal(input: &str) -> Option { if let Ok(v) = value { if negative { - Some(v * -1) + Some(-v) } else { Some(v) } diff --git a/tests/cli.rs b/tests/cli.rs index 7afcb9b..1f2b9b2 100644 --- a/tests/cli.rs +++ b/tests/cli.rs @@ -166,5 +166,64 @@ fn invalid_utf8() -> Result<(), Box> { cmd.assert().success().stdout(predicate::str::contains( "memcpy")); + Ok(()) +} + +#[test] +fn regex_constraint() -> Result<(), Box> { + let mut cmd = Command::cargo_bin("weggli")?; + + cmd.arg("char $buf[10];") + .arg("./third_party/examples/invalid-utf8.c") + .arg("-Rbuf=buf"); + cmd.assert().success().stdout(predicate::str::contains( + "char buf[10]")); + + let mut cmd = Command::cargo_bin("weggli")?; + + cmd.arg("char $buf[10];") + .arg("./third_party/examples/invalid-utf8.c") + .arg("-Rbuf=foo"); + cmd.assert().success().stdout(predicate::str::is_empty()); + + let mut cmd = Command::cargo_bin("weggli")?; + + cmd.arg("char $buf[10];") + .arg("./third_party/examples/invalid-utf8.c") + .arg("-Rruf=foo"); + cmd.assert().failure() + .stderr(predicate::str::contains("is not a valid query variable")); + + let mut cmd = Command::cargo_bin("weggli")?; + + cmd.arg("char $buf[10];") + .arg("./third_party/examples/invalid-utf8.c") + .arg("-Rbuf!=woof"); + cmd.assert().success().stdout(predicate::str::contains( + "char buf[10]")); + + let mut cmd = Command::cargo_bin("weggli")?; + + cmd.arg("{char buf[10]; not: memcpy($buf, _, _);}") + .arg("./third_party/examples/invalid-utf8.c") + .arg("-Rbuf=woof"); + cmd.assert().success().stdout(predicate::str::contains( + "char buf[10]")); + + let mut cmd = Command::cargo_bin("weggli")?; + + cmd.arg("{char buf[10]; not: memcpy($buf, _, _);}") + .arg("./third_party/examples/invalid-utf8.c") + .arg("-Rbuf=buf"); + cmd.assert().success().stdout(predicate::str::is_empty()); + + let mut cmd = Command::cargo_bin("weggli")?; + + cmd.arg("{char buf[10]; not: memcpy($buf, _, _);}") + .arg("./third_party/examples/invalid-utf8.c") + 
.arg("-Rbuf!=buf"); + cmd.assert().success().stdout(predicate::str::contains( + "char buf[10]")); + Ok(()) } \ No newline at end of file diff --git a/tests/query.rs b/tests/query.rs index bc6b6ba..9d77db9 100644 --- a/tests/query.rs +++ b/tests/query.rs @@ -28,7 +28,7 @@ fn parse_and_match_helper(needle: &str, source: &str, cpp: bool) -> Vec