diff --git a/crates/lib-core/src/utils/analysis/query.rs b/crates/lib-core/src/utils/analysis/query.rs index 583320c96..22eb18c47 100644 --- a/crates/lib-core/src/utils/analysis/query.rs +++ b/crates/lib-core/src/utils/analysis/query.rs @@ -1,6 +1,3 @@ -use std::cell::RefCell; -use std::rc::Rc; - use smol_str::{SmolStr, StrExt, ToSmolStr}; use super::select::SelectStatementColumnsAndTables; @@ -153,31 +150,21 @@ impl Selectable<'_> { } #[derive(Debug, Clone)] -pub struct Query<'me, T> { - pub inner: Rc>>, -} - -#[derive(Debug, Clone)] -pub struct QueryInner<'me, T> { +pub struct Query<'me> { pub query_type: QueryType, pub dialect: &'me Dialect, pub selectables: Vec>, - pub ctes: IndexMap>, - pub parent: Option>, - pub subqueries: Vec>, - pub cte_definition_segment: Option, - pub cte_name_segment: Option, - pub payload: T, + pub ctes: IndexMap, } -impl<'me, T: Clone + Default> Query<'me, T> { +impl<'me> Query<'me> { pub fn crawl_sources( &self, segment: ErasedSegment, - pop: bool, + _pop: bool, lookup_cte: bool, - ) -> Vec> { + ) -> Vec> { let mut acc = Vec::new(); for seg in segment.recursive_crawl( @@ -195,19 +182,15 @@ impl<'me, T: Clone + Default> Query<'me, T> { ) { if seg.is_type(SyntaxKind::TableReference) { let _seg = seg.reference(); - if !_seg.is_qualified() - && lookup_cte - && let Some(cte) = self.lookup_cte(seg.raw().as_ref(), pop) - { - acc.push(Source::Query(cte)); + if !_seg.is_qualified() && lookup_cte { + if let Some(cte) = self.lookup_cte(seg.raw().as_ref()) { + acc.push(Source::Query(cte)); + continue; + } } acc.push(Source::TableReference(seg.raw().clone())); } else { - acc.push(Source::Query(Query::from_segment( - &seg, - self.inner.borrow().dialect, - Some(self.clone()), - ))) + acc.push(Source::Query(Query::from_segment(&seg, self.dialect))) } } @@ -222,54 +205,30 @@ impl<'me, T: Clone + Default> Query<'me, T> { } #[track_caller] - pub fn lookup_cte(&self, name: &str, pop: bool) -> Option> { - let cte = if pop { - self.inner - .borrow_mut() - .ctes - .shift_remove(&name.to_uppercase_smolstr()) - } else { - self.inner - .borrow() - .ctes - .get(&name.to_uppercase_smolstr()) - .cloned() - }; - - cte.or_else(move || { - self.inner - .borrow_mut() - .parent - .as_mut() - .and_then(|it| it.lookup_cte(name, pop)) - }) - } - - fn post_init(&self) { - let this = self.clone(); - - for subquery in &RefCell::borrow(&self.inner).subqueries { - RefCell::borrow_mut(&subquery.inner).parent = this.clone().into(); - } - - for cte in RefCell::borrow(&self.inner).ctes.values().cloned() { - RefCell::borrow_mut(&cte.inner).parent = this.clone().into(); - } + pub fn lookup_cte(&self, name: &str) -> Option> { + let key = name.to_uppercase_smolstr(); + let seg = self.ctes.get(&key)?.clone(); + // Dive into the CTE definition to find the inner query (SELECT/SET/WITH) + let inner = seg.recursive_crawl( + const { &SELECTABLE_TYPES.union(&SUBSELECT_TYPES) }, + true, + &SyntaxSet::EMPTY, + true, + ); + inner.first().map(|qseg| Query::from_segment(qseg, self.dialect)) } } -impl Query<'_, T> { +impl Query<'_> { pub fn children(&self) -> Vec { - self.inner - .borrow() - .ctes - .values() - .chain(self.inner.borrow().subqueries.iter()) - .cloned() - .collect() + let mut acc = Vec::new(); + for selectable in &self.selectables { + acc.extend(Self::extract_subqueries(selectable, self.dialect)); + } + acc } - fn extract_subqueries<'a>(selectable: &Selectable, dialect: &'a Dialect) -> Vec> { + fn extract_subqueries<'a>(selectable: &Selectable, dialect: &'a Dialect) -> Vec> { let mut acc = Vec::new(); for subselect in selectable.selectable.recursive_crawl( @@ -278,16 +237,13 @@ impl Query<'_, T> { &SyntaxSet::EMPTY, false, ) { - acc.push(Query::from_segment(&subselect, dialect, None)); + acc.push(Query::from_segment(&subselect, dialect)); } acc } - pub fn from_root<'a>( - root_segment: &ErasedSegment, - dialect: &'a Dialect, - ) -> Option> { + pub fn from_root<'a>(root_segment: &ErasedSegment, dialect: &'a Dialect) -> Option> { let stmts = root_segment.recursive_crawl( &SELECTABLE_TYPES, true, @@ -296,17 +252,12 @@ impl Query<'_, T> { ); let selectable_segment = stmts.first()?; - Some(Query::from_segment(selectable_segment, dialect, None)) + Some(Query::from_segment(selectable_segment, dialect)) } - pub fn from_segment<'a>( - segment: &ErasedSegment, - dialect: &'a Dialect, - parent: Option>, - ) -> Query<'a, T> { + pub fn from_segment<'a>(segment: &ErasedSegment, dialect: &'a Dialect) -> Query<'a> { let mut selectables = Vec::new(); - let mut subqueries = Vec::new(); - let mut cte_defs: Vec = Vec::new(); + let mut cte_defs: IndexMap = IndexMap::default(); let mut query_type = QueryType::Simple; if segment.is_type(SyntaxKind::SelectStatement) @@ -347,65 +298,22 @@ impl Query<'_, T> { const { &SyntaxSet::single(SyntaxKind::WithCompoundStatement) }, true, ) { - cte_defs.push(seg); + let name_seg = seg.segments()[0].clone(); + let name = name_seg.raw().to_uppercase_smolstr(); + cte_defs.insert(name, seg); } } - for selectable in &selectables { - subqueries.extend(Self::extract_subqueries(selectable, dialect)); - } - - let outer_query = Query { - inner: Rc::new(RefCell::new(QueryInner { - query_type, - dialect, - selectables, - ctes: <_>::default(), - parent, - subqueries, - cte_definition_segment: None, - cte_name_segment: None, - payload: T::default(), - })), - }; - - outer_query.post_init(); - - if cte_defs.is_empty() { - return outer_query; - } - - let mut ctes = IndexMap::default(); - for cte in cte_defs { - let name_seg = cte.segments()[0].clone(); - let name = name_seg.raw().to_uppercase_smolstr(); - - let queries = cte.recursive_crawl( - const { &SELECTABLE_TYPES.union(&SUBSELECT_TYPES) }, - true, - &SyntaxSet::EMPTY, - true, - ); - - if queries.is_empty() { - continue; - }; - - let query = &queries[0]; - let query = Self::from_segment(query, dialect, outer_query.clone().into()); - - RefCell::borrow_mut(&query.inner).cte_definition_segment = cte.into(); - RefCell::borrow_mut(&query.inner).cte_name_segment = name_seg.into(); - - ctes.insert(name, query); + Query { + query_type, + dialect, + selectables, + ctes: cte_defs, } - - RefCell::borrow_mut(&outer_query.inner).ctes = ctes; - outer_query } } -pub enum Source<'a, T> { +pub enum Source<'a> { TableReference(SmolStr), - Query(Query<'a, T>), + Query(Query<'a>), } diff --git a/crates/lib/src/rules/aliasing/al05.rs b/crates/lib/src/rules/aliasing/al05.rs index 6f8befab9..89e287d86 100644 --- a/crates/lib/src/rules/aliasing/al05.rs +++ b/crates/lib/src/rules/aliasing/al05.rs @@ -1,5 +1,3 @@ -use std::cell::RefCell; - use ahash::{AHashMap, AHashSet}; use smol_str::{SmolStr, ToSmolStr}; use sqruff_lib_core::dialects::Dialect; @@ -19,7 +17,7 @@ use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler}; use crate::core::rules::{Erased, ErasedRule, LintResult, Rule, RuleGroups}; #[derive(Default, Clone)] -struct AL05Query { +struct State { aliases: Vec, tbl_refs: Vec, } @@ -86,14 +84,16 @@ FROM foo return Vec::new(); } - let query = Query::from_segment(&context.segment, context.dialect, None); - self.analyze_table_aliases(query.clone(), context.dialect); + let query = Query::from_segment(&context.segment, context.dialect); + let mut state_stack = vec![State::default()]; + self.analyze_table_aliases(query.clone(), context.dialect, &mut state_stack); + let state = state_stack.pop().unwrap(); if context.dialect.name == DialectKind::Redshift { let mut references = AHashSet::default(); let mut aliases = AHashSet::default(); - for alias in &query.inner.borrow().payload.aliases { + for alias in &state.aliases { aliases.insert(alias.ref_str.clone()); if let Some(object_reference) = &alias.object_reference { for seg in object_reference.segments() { @@ -118,17 +118,12 @@ FROM foo } } - for alias in &RefCell::borrow(&query.inner).payload.aliases { + for alias in &state.aliases { if Self::is_alias_required(&alias.from_expression_element, context.dialect.name) { continue; } - if alias.aliased - && !RefCell::borrow(&query.inner) - .payload - .tbl_refs - .contains(&alias.ref_str) - { + if alias.aliased && !state.tbl_refs.contains(&alias.ref_str) { let violation = self.report_unused_alias(alias.clone()); violations.push(violation); } @@ -148,13 +143,14 @@ FROM foo impl RuleAL05 { #[allow(clippy::only_used_in_recursion)] - fn analyze_table_aliases(&self, query: Query, dialect: &Dialect) { - let selectables = std::mem::take(&mut RefCell::borrow_mut(&query.inner).selectables); + fn analyze_table_aliases(&self, query: Query, dialect: &Dialect, state_stack: &mut Vec) { + let selectables = query.selectables.clone(); for selectable in &selectables { if let Some(select_info) = selectable.select_info() { - RefCell::borrow_mut(&query.inner) - .payload + state_stack + .last_mut() + .unwrap() .aliases .extend(select_info.table_aliases); @@ -162,32 +158,25 @@ impl RuleAL05 { for tr in r.extract_possible_references(ObjectReferenceLevel::Table, dialect.name) { - Self::resolve_and_mark_reference(query.clone(), tr.part); + Self::resolve_and_mark_reference(state_stack, tr.part); } } } } - RefCell::borrow_mut(&query.inner).selectables = selectables; - for child in query.children() { - self.analyze_table_aliases(child, dialect); + state_stack.push(State::default()); + self.analyze_table_aliases(child, dialect, state_stack); + state_stack.pop(); } } - fn resolve_and_mark_reference(query: Query, r#ref: String) { - if RefCell::borrow(&query.inner) - .payload - .aliases - .iter() - .any(|it| it.ref_str == r#ref) - { - RefCell::borrow_mut(&query.inner) - .payload - .tbl_refs - .push(r#ref.into()); - } else if let Some(parent) = RefCell::borrow(&query.inner).parent.clone() { - Self::resolve_and_mark_reference(parent, r#ref); + fn resolve_and_mark_reference(state_stack: &mut Vec, r#ref: String) { + for st in state_stack.iter_mut().rev() { + if st.aliases.iter().any(|it| it.ref_str == r#ref) { + st.tbl_refs.push(r#ref.into()); + return; + } } } diff --git a/crates/lib/src/rules/ambiguous/am04.rs b/crates/lib/src/rules/ambiguous/am04.rs index 155074632..237da7c5b 100644 --- a/crates/lib/src/rules/ambiguous/am04.rs +++ b/crates/lib/src/rules/ambiguous/am04.rs @@ -1,7 +1,9 @@ -use ahash::AHashMap; +use ahash::{AHashMap, AHashSet}; +use smol_str::{SmolStr, StrExt}; use sqruff_lib_core::dialects::common::AliasInfo; use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet}; use sqruff_lib_core::parser::segments::ErasedSegment; +use sqruff_lib_core::helpers::IndexMap; use sqruff_lib_core::utils::analysis::query::{Query, Selectable, Source}; use crate::core::config::Value; @@ -18,6 +20,14 @@ const START_TYPES: [SyntaxKind; 3] = [ SyntaxKind::WithCompoundStatement, ]; +// Types used to locate the inner query within a CTE definition, including VALUES. +const INNER_TYPES: [SyntaxKind; 4] = [ + SyntaxKind::WithCompoundStatement, + SyntaxKind::SetExpression, + SyntaxKind::SelectStatement, + SyntaxKind::ValuesClause, +]; + impl Rule for RuleAM04 { fn load_from_config(&self, _config: &AHashMap) -> Result { Ok(RuleAM04.erased()) @@ -72,9 +82,10 @@ SELECT a, b FROM t } fn eval(&self, rule_cx: &RuleContext) -> Vec { - let query = Query::from_segment(&rule_cx.segment, rule_cx.dialect, None); - - let result = self.analyze_result_columns(query); + let query = Query::from_segment(&rule_cx.segment, rule_cx.dialect); + let mut visited = AHashSet::new(); + let env = query.ctes.clone(); + let result = self.analyze_result_columns(query, &mut visited, &env); match result { Ok(_) => { vec![] @@ -94,32 +105,92 @@ SELECT a, b FROM t impl RuleAM04 { /// returns an anchor to the rule - fn analyze_result_columns(&self, query: Query<()>) -> Result<(), ErasedSegment> { - if query.inner.borrow().selectables.is_empty() { + fn analyze_result_columns( + &self, + query: Query, + visited: &mut AHashSet, + env: &IndexMap, + ) -> Result<(), ErasedSegment> { + if query.selectables.is_empty() { return Ok(()); } - let selectables = query.inner.borrow().selectables.clone(); + let selectables = query.selectables.clone(); for selectable in selectables { for wildcard in selectable.wildcard_info() { if !wildcard.tables.is_empty() { for wildcard_table in wildcard.tables { if let Some(alias_info) = selectable.find_alias(&wildcard_table) { - self.handle_alias(&selectable, alias_info, &query)?; + self.handle_alias(&selectable, alias_info, &query, visited, env)?; } else { - let Some(cte) = query.lookup_cte(&wildcard_table, true) else { + let key = wildcard_table.to_uppercase_smolstr(); + if let Some(seg) = env.get(&key) { + if visited.contains(&key) { + return Err(selectable.selectable); + } + // Build inner query from CTE definition + let inner = seg + .recursive_crawl( + const { &SyntaxSet::new(&INNER_TYPES) }, + true, + &SyntaxSet::EMPTY, + true, + ) + .first() + .cloned(); + if let Some(inner) = inner { + let child = Query::from_segment(&inner, query.dialect); + // Merge env with child's own ctes (child overrides) + let mut merged = env.clone(); + merged.extend(child.ctes.clone()); + visited.insert(key.clone()); + let res = self.analyze_result_columns(child, visited, &merged); + visited.remove(&key); + res?; + } else { + return Err(selectable.selectable); + } + } else { return Err(selectable.selectable); - }; - - self.analyze_result_columns(cte)?; + } } } } else { - let selectable = query.inner.borrow().selectables[0].selectable.clone(); + let selectable = query.selectables[0].selectable.clone(); for source in query.crawl_sources(selectable.clone(), false, true) { - if let Source::Query(query) = source { - self.analyze_result_columns(query)?; - return Ok(()); + match source { + Source::Query(q) => { + let mut merged = env.clone(); + merged.extend(q.ctes.clone()); + self.analyze_result_columns(q, visited, &merged)?; + return Ok(()); + } + Source::TableReference(name) => { + let key = name.to_uppercase_smolstr(); + if let Some(seg) = env.get(&key) { + if visited.contains(&key) { + return Err(selectable.clone()); + } + let inner = seg + .recursive_crawl( + const { &SyntaxSet::new(&INNER_TYPES) }, + true, + &SyntaxSet::EMPTY, + true, + ) + .first() + .cloned(); + if let Some(inner) = inner { + let child = Query::from_segment(&inner, query.dialect); + let mut merged = env.clone(); + merged.extend(child.ctes.clone()); + visited.insert(key.clone()); + self.analyze_result_columns(child, visited, &merged)?; + visited.remove(&key); + return Ok(()); + } + } + } } } @@ -135,7 +206,9 @@ impl RuleAM04 { &self, selectable: &Selectable, alias_info: AliasInfo, - query: &Query<'_, ()>, + query: &Query<'_>, + visited: &mut AHashSet, + env: &IndexMap, ) -> Result<(), ErasedSegment> { let select_info_target = query .crawl_sources(alias_info.from_expression_element, false, true) @@ -143,8 +216,41 @@ impl RuleAM04 { .next() .unwrap(); match select_info_target { - Source::TableReference(_) => Err(selectable.selectable.clone()), - Source::Query(query) => self.analyze_result_columns(query), + Source::TableReference(name) => { + let key = name.to_uppercase_smolstr(); + if let Some(seg) = env.get(&key) { + if visited.contains(&key) { + return Err(selectable.selectable.clone()); + } + let inner = seg + .recursive_crawl( + const { &SyntaxSet::new(&INNER_TYPES) }, + true, + &SyntaxSet::EMPTY, + true, + ) + .first() + .cloned(); + if let Some(inner) = inner { + let child = Query::from_segment(&inner, query.dialect); + let mut merged = env.clone(); + merged.extend(child.ctes.clone()); + visited.insert(key.clone()); + let res = self.analyze_result_columns(child, visited, &merged); + visited.remove(&key); + res + } else { + Err(selectable.selectable.clone()) + } + } else { + Err(selectable.selectable.clone()) + } + } + Source::Query(q) => { + let mut merged = env.clone(); + merged.extend(q.ctes.clone()); + self.analyze_result_columns(q, visited, &merged) + } } } } diff --git a/crates/lib/src/rules/ambiguous/am07.rs b/crates/lib/src/rules/ambiguous/am07.rs index 10feb8c89..56988ccc6 100644 --- a/crates/lib/src/rules/ambiguous/am07.rs +++ b/crates/lib/src/rules/ambiguous/am07.rs @@ -1,4 +1,6 @@ +use ahash::AHashSet; use ahash::{AHashMap, HashSet, HashSetExt}; +use smol_str::{SmolStr, StrExt}; use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet}; use sqruff_lib_core::utils::analysis::query::{Query, Selectable, Source, WildcardInfo}; @@ -84,7 +86,7 @@ FROM t } } - let query: Query<()> = Query::from_segment(root, context.dialect, None); + let query: Query = Query::from_segment(root, context.dialect); let (set_segment_select_sizes, resolve_wildcard) = self.get_select_target_counts(query); // if queries had different select target counts and all wildcards had been @@ -117,13 +119,15 @@ impl RuleAM07 { /// can't guarantee that we can always resolve any wildcards (*), so /// we also return a flag to indicate whether any present have been /// fully resolved. - fn get_select_target_counts(&self, query: Query<()>) -> (HashSet, bool) { + fn get_select_target_counts(&self, query: Query) -> (HashSet, bool) { let mut select_target_counts = HashSet::new(); let mut resolved_wildcard = true; + let mut visited = AHashSet::::new(); - let selectables = query.inner.borrow().selectables.clone(); + let selectables = query.selectables.clone(); for selectable in selectables { - let (cnt, res) = self.resolve_selectable(selectable.clone(), query.clone()); + let (cnt, res) = + self.resolve_selectable(selectable.clone(), query.clone(), &mut visited); if !res { resolved_wildcard = false; } @@ -137,7 +141,12 @@ impl RuleAM07 { /// /// The selectable may opr may not have (*) wildcard expressions. If it /// does, we attempt to resolve them. - fn resolve_selectable(&self, selectable: Selectable, root_query: Query<()>) -> (usize, bool) { + fn resolve_selectable( + &self, + selectable: Selectable, + root_query: Query, + visited: &mut AHashSet, + ) -> (usize, bool) { debug_assert!(selectable.select_info().is_some()); let wildcard_info = selectable.wildcard_info(); @@ -155,8 +164,12 @@ impl RuleAM07 { // If the set query contains one or more wildcards, attempt to resolve it to a // list of select targets that can be counted. for wildcard in wildcard_info { - let (_cols, _resolved) = - self.resolve_selectable_wildcard(wildcard, selectable.clone(), root_query.clone()); + let (_cols, _resolved) = self.resolve_selectable_wildcard( + wildcard, + selectable.clone(), + root_query.clone(), + visited, + ); resolved = resolved && _resolved; // Add on the number of columns which the wildcard resolves to. num_cols += _cols; @@ -176,14 +189,14 @@ impl RuleAM07 { /// only called on any subqueries (which may themselves be SELECT, /// WITH or set expressions) found during the resolution of any /// wildcards. - fn resolve_wild_query(&self, query: Query<()>) -> (usize, bool) { + fn resolve_wild_query(&self, query: Query, visited: &mut AHashSet) -> (usize, bool) { // if one of the source queries for a query within the set is a // set expression, just use the first query. If that first query isn't // reflective of the others, that will be caught when that segment // is processed. We'll know if we're in a set based on whether there // is more than one selectable. i.e. Just take the first selectable. - let selectable = query.inner.borrow().selectables[0].clone(); - self.resolve_selectable(selectable, query.clone()) + let selectable = query.selectables[0].clone(); + self.resolve_selectable(selectable, query.clone(), visited) } /// Attempt to resolve a single wildcard (*) within a Selectable. @@ -195,7 +208,8 @@ impl RuleAM07 { &self, wildcard: WildcardInfo, selectable: Selectable, - root_query: Query<()>, + root_query: Query, + visited: &mut AHashSet, ) -> (usize, bool) { let mut resolved = true; @@ -204,7 +218,7 @@ impl RuleAM07 { // Crawl the query looking for the subquery, problem in the FROM. for source in root_query.crawl_sources(selectable.selectable, false, true) { if let Source::Query(query) = source { - return self.resolve_wild_query(query); + return self.resolve_wild_query(query, visited); } } return (0, false); @@ -229,7 +243,7 @@ impl RuleAM07 { cte_name = name; } Source::Query(query) => { - let (_cols, _resolved) = self.resolve_wild_query(query); + let (_cols, _resolved) = self.resolve_wild_query(query, visited); num_columns += _cols; resolved = resolved && _resolved; continue; @@ -237,9 +251,17 @@ impl RuleAM07 { } } - let cte = root_query.lookup_cte(&cte_name, true); + let cte = root_query.lookup_cte(&cte_name); if let Some(cte) = cte { - let (cols, _resolved) = self.resolve_wild_query(cte); + let key = cte_name.to_uppercase_smolstr(); + if visited.contains(&key) { + // cyclic reference, cannot resolve + resolved = false; + continue; + } + visited.insert(key.clone()); + let (cols, _resolved) = self.resolve_wild_query(cte, visited); + visited.remove(&key); num_columns += cols; resolved = resolved && _resolved; } else { diff --git a/crates/lib/src/rules/references/rf01.rs b/crates/lib/src/rules/references/rf01.rs index bc0f0b01b..1f8019d1b 100644 --- a/crates/lib/src/rules/references/rf01.rs +++ b/crates/lib/src/rules/references/rf01.rs @@ -1,5 +1,3 @@ -use std::cell::RefCell; - use ahash::AHashMap; use itertools::Itertools; use smol_str::SmolStr; @@ -19,7 +17,7 @@ use crate::core::rules::reference::object_ref_matches_table; use crate::core::rules::{Erased, ErasedRule, LintResult, Rule, RuleGroups}; #[derive(Debug, Default, Clone)] -struct RF01Query { +struct State { aliases: Vec, standalone_aliases: Vec, } @@ -36,7 +34,7 @@ impl RuleRF01 { r: &ObjectReferenceSegment, tbl_refs: Vec<(ObjectReferencePart, Vec)>, dml_target_table: &[SmolStr], - query: Query, + state_stack: &[State], ) -> Option { let possible_references: Vec<_> = tbl_refs .clone() @@ -45,32 +43,31 @@ impl RuleRF01 { .collect(); let mut targets = vec![]; + for st in state_stack.iter().rev() { + for alias in &st.aliases { + if alias.aliased { + targets.push(vec![alias.ref_str.clone()]); + } - for alias in &RefCell::borrow(&query.inner).payload.aliases { - if alias.aliased { - targets.push(vec![alias.ref_str.clone()]); - } - - if let Some(object_reference) = &alias.object_reference { - let references = object_reference - .reference() - .iter_raw_references() - .into_iter() - .map(|it| it.part.into()) - .collect_vec(); + if let Some(object_reference) = &alias.object_reference { + let references = object_reference + .reference() + .iter_raw_references() + .into_iter() + .map(|it| it.part.into()) + .collect_vec(); - targets.push(references); + targets.push(references); + } } - } - for standalone_alias in &RefCell::borrow(&query.inner).payload.standalone_aliases { - targets.push(vec![standalone_alias.clone()]); + for standalone_alias in &st.standalone_aliases { + targets.push(vec![standalone_alias.clone()]); + } } if !object_ref_matches_table(&possible_references, &targets) { - if let Some(parent) = RefCell::borrow(&query.inner).parent.clone() { - return self.resolve_reference(r, tbl_refs.clone(), dml_target_table, parent); - } else if dml_target_table.is_empty() + if dml_target_table.is_empty() || !object_ref_matches_table(&possible_references, &[dml_target_table.to_vec()]) { return LintResult::new( @@ -120,41 +117,42 @@ impl RuleRF01 { fn analyze_table_references( &self, - query: Query, + query: Query, dml_target_table: &[SmolStr], violations: &mut Vec, + state_stack: &mut Vec, ) { - let selectables = std::mem::take(&mut RefCell::borrow_mut(&query.inner).selectables); + let selectables = query.selectables.clone(); + let mut state = State::default(); for selectable in &selectables { if let Some(select_info) = selectable.select_info() { - RefCell::borrow_mut(&query.inner) - .payload - .aliases - .extend(select_info.table_aliases); - RefCell::borrow_mut(&query.inner) - .payload + state.aliases.extend(select_info.table_aliases); + state .standalone_aliases .extend(select_info.standalone_aliases); for r in select_info.reference_buffer { if !self.should_ignore_reference(&r, selectable) { + let dialect = query.dialect; + state_stack.push(state.clone()); let violation = self.resolve_reference( &r, - self.get_table_refs(&r, RefCell::borrow(&query.inner).dialect), + self.get_table_refs(&r, dialect), dml_target_table, - query.clone(), + state_stack, ); + state_stack.pop(); violations.extend(violation); } } } } - RefCell::borrow_mut(&query.inner).selectables = selectables; - for child in query.children() { - self.analyze_table_references(child, dml_target_table, violations); + state_stack.push(state.clone()); + self.analyze_table_references(child, dml_target_table, violations, state_stack); + state_stack.pop(); } } @@ -233,7 +231,7 @@ FROM foo } fn eval(&self, context: &RuleContext) -> Vec { - let query = Query::from_segment(&context.segment, context.dialect, None); + let query = Query::from_segment(&context.segment, context.dialect); let mut violations = Vec::new(); let tmp; @@ -260,7 +258,8 @@ FROM foo &[] }; - self.analyze_table_references(query, dml_target_table, &mut violations); + let mut state_stack = Vec::new(); + self.analyze_table_references(query, dml_target_table, &mut violations, &mut state_stack); violations } diff --git a/crates/lib/src/rules/references/rf03.rs b/crates/lib/src/rules/references/rf03.rs index e68925eae..73655a380 100644 --- a/crates/lib/src/rules/references/rf03.rs +++ b/crates/lib/src/rules/references/rf03.rs @@ -1,5 +1,3 @@ -use std::cell::RefCell; - use ahash::{AHashMap, AHashSet}; use itertools::Itertools; use smol_str::SmolStr; @@ -28,14 +26,14 @@ impl RuleRF03 { tables: &Tables, single_table_references: &str, is_struct_dialect: bool, - query: Query<()>, + query: Query, _visited: &mut AHashSet, ) -> Vec { #[allow(unused_assignments)] let mut select_info = None; let mut acc = Vec::new(); - let selectables = &RefCell::borrow(&query.inner).selectables; + let selectables = &query.selectables; if !selectables.is_empty() { select_info = selectables[0].select_info(); @@ -47,8 +45,6 @@ impl RuleRF03 { let mut fixable = true; let possible_ref_tables = iter_available_targets(query.clone()); - if let Some(_parent) = &RefCell::borrow(&query.inner).parent {} - if possible_ref_tables.len() > 1 { fixable = false; } @@ -84,8 +80,8 @@ impl RuleRF03 { } } -fn iter_available_targets(query: Query<()>) -> Vec { - RefCell::borrow(&query.inner) +fn iter_available_targets(query: Query) -> Vec { + query .selectables .iter() .flat_map(|selectable| { @@ -367,7 +363,7 @@ FROM foo .unwrap() }); - let query: Query<()> = Query::from_segment(&context.segment, context.dialect, None); + let query: Query = Query::from_segment(&context.segment, context.dialect); let mut visited: AHashSet = AHashSet::new(); let is_struct_dialect = self.dialect_skip().contains(&context.dialect.name); diff --git a/crates/lib/src/rules/structure/st03.rs b/crates/lib/src/rules/structure/st03.rs index 79adbab65..9bc18ea67 100644 --- a/crates/lib/src/rules/structure/st03.rs +++ b/crates/lib/src/rules/structure/st03.rs @@ -1,5 +1,3 @@ -use std::cell::RefCell; - use ahash::AHashMap; use smol_str::StrExt; use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet}; @@ -69,9 +67,9 @@ FROM cte1 fn eval(&self, context: &RuleContext) -> Vec { let mut result = Vec::new(); - let query: Query<'_, ()> = Query::from_root(&context.segment, context.dialect).unwrap(); + let query: Query = Query::from_root(&context.segment, context.dialect).unwrap(); - let mut remaining_ctes: IndexMap<_, _> = RefCell::borrow(&query.inner) + let mut remaining_ctes: IndexMap<_, _> = query .ctes .keys() .map(|it| (it.to_uppercase_smolstr(), it.clone())) @@ -87,14 +85,14 @@ FROM cte1 } for name in remaining_ctes.values() { - let tmp = RefCell::borrow(&query.inner); - let cte = RefCell::borrow(&tmp.ctes[name].inner); + let cte_def = &query.ctes[name]; + let name_seg = cte_def.segments()[0].clone(); result.push(LintResult::new( - cte.cte_name_segment.clone(), + Some(name_seg.clone()), Vec::new(), Some(format!( "Query defines CTE \"{}\" but does not use it.", - cte.cte_name_segment.as_ref().unwrap().raw() + name_seg.raw() )), None, )); diff --git a/crates/lib/src/rules/structure/st05.rs b/crates/lib/src/rules/structure/st05.rs index ec8852c37..1f7f5ca14 100644 --- a/crates/lib/src/rules/structure/st05.rs +++ b/crates/lib/src/rules/structure/st05.rs @@ -40,7 +40,7 @@ fn config_mapping(key: &str) -> SyntaxSet { #[allow(dead_code)] struct NestedSubQuerySummary<'a> { - query: Query<'a, ()>, + query: Query<'a>, selectable: Selectable<'a>, table_alias: AliasInfo, select_source_names: AHashSet, @@ -118,11 +118,11 @@ join c using(x) return Vec::new(); } - let query: Query<'_, ()> = Query::from_segment(&context.segment, context.dialect, None); + let query: Query<'_> = Query::from_segment(&context.segment, context.dialect); let mut ctes = CTEBuilder::default(); - for cte in query.inner.borrow().ctes.values() { - ctes.insert_cte(cte.inner.borrow().cte_definition_segment.clone().unwrap()); + for cte_def in query.ctes.values() { + ctes.insert_cte(cte_def.clone()); } let is_with = segment.all(Some(|it: &ErasedSegment| { @@ -263,7 +263,7 @@ impl RuleST05 { &self, tables: &Tables, dialect: &'a Dialect, - query: Query<'a, ()>, + query: Query<'a>, ctes: &mut CTEBuilder, case_preference: Case, segment_clone_map: &SegmentCloneMap, @@ -290,11 +290,11 @@ impl RuleST05 { ctes.insert_cte(new_cte); - if nsq.query.inner.borrow().selectables.len() != 1 { + if nsq.query.selectables.len() != 1 { continue; } - let select = nsq.query.inner.borrow().selectables[0].clone().selectable; + let select = nsq.query.selectables[0].clone().selectable; let anchor = anchor.recursive_crawl( const { &SyntaxSet::new(&[ @@ -325,7 +325,7 @@ impl RuleST05 { res, nsq.table_alias.from_expression_element, alias_name.clone(), - nsq.query.inner.borrow().selectables[0].clone().selectable, + nsq.query.selectables[0].clone().selectable, )); } @@ -334,17 +334,23 @@ impl RuleST05 { fn nested_subqueries<'a>( &self, - query: Query<'a, ()>, + query: Query<'a>, dialect: &'a Dialect, ) -> Vec> { let mut acc = Vec::new(); let parent_types = config_mapping(&self.forbid_subquery_in); let mut queries = vec![query.clone()]; - queries.extend(query.inner.borrow().ctes.values().cloned()); + queries.extend( + query + .ctes + .values() + .cloned() + .map(|seg| Query::from_segment(&seg, dialect)), + ); for (i, q) in enumerate(queries) { - for selectable in &q.inner.borrow().selectables { + for selectable in &q.selectables { let Some(select_info) = selectable.select_info() else { continue; }; @@ -360,7 +366,7 @@ impl RuleST05 { } let Some(query) = - Query::<()>::from_root(&table_alias.from_expression_element, dialect) + Query::from_root(&table_alias.from_expression_element, dialect) else { continue; }; @@ -378,17 +384,7 @@ impl RuleST05 { } if is_correlated_subquery( - Segments::new( - query - .inner - .borrow() - .selectables - .first() - .unwrap() - .selectable - .clone(), - None, - ), + Segments::new(query.selectables.first().unwrap().selectable.clone(), None), &select_source_names, dialect, ) { diff --git a/crates/sqlinference/src/columns.rs b/crates/sqlinference/src/columns.rs index 63e9c4fcb..05da27bb7 100644 --- a/crates/sqlinference/src/columns.rs +++ b/crates/sqlinference/src/columns.rs @@ -19,8 +19,8 @@ pub fn get_columns_internal( let mut columns: Vec = vec![]; let mut unnamed: Vec = vec![]; - let query: Query<()> = Query::from_root(&ast, parser.dialect()).unwrap(); - let ast = query.inner.borrow().selectables[0].selectable.clone(); + let query: Query = Query::from_root(&ast, parser.dialect()).unwrap(); + let ast = query.selectables[0].selectable.clone(); for segment in ast.recursive_crawl( const { &SyntaxSet::new(&[SyntaxKind::SelectClauseElement]) }, diff --git a/crates/sqlinference/src/infer_tests.rs b/crates/sqlinference/src/infer_tests.rs index f3ca66131..9fa7f7d77 100644 --- a/crates/sqlinference/src/infer_tests.rs +++ b/crates/sqlinference/src/infer_tests.rs @@ -386,7 +386,7 @@ pub fn get_column_with_source( select_statement: &str, ) -> Result { let ast = parse_sql(parser, select_statement); - let query: Query<()> = Query::from_root(&ast, parser.dialect()).unwrap(); + let query: Query = Query::from_root(&ast, parser.dialect()).unwrap(); extract_select(&query) } @@ -409,29 +409,27 @@ type OperatedOn = HashMap; /// statement. The map in the result is from the final column name to the source /// column name and source table name. Also returns an array of unrecognized /// columns. -fn extract_select(query: &Query<'_, ()>) -> Result { - let with_extracted: Option> = - if query.inner.borrow().ctes.is_empty() { - None - } else { - query - .inner - .borrow() - .ctes - .iter() - .rev() - .map(|(name, query)| { - let select = extract_select(query)?; - Ok(Some((name.to_lowercase(), select))) - }) - .collect::>, String>>()? - }; +fn extract_select(query: &Query<'_>) -> Result { + let with_extracted: Option> = if query.ctes.is_empty() { + None + } else { + query + .ctes + .iter() + .rev() + .map(|(name, seg)| { + let q = Query::from_segment(seg, query.dialect); + let select = extract_select(&q)?; + Ok(Some((name.to_lowercase(), select))) + }) + .collect::>, String>>()? + }; - let main_extracted = if let Some(from_clause) = query.inner.borrow().selectables[0] + let main_extracted = if let Some(from_clause) = query.selectables[0] .selectable .child(const { &SyntaxSet::single(SyntaxKind::FromClause) }) { - let has_group_by = query.inner.borrow().selectables[0] + let has_group_by = query.selectables[0] .selectable .child(const { &SyntaxSet::single(SyntaxKind::GroupbyClause) }) .is_some(); @@ -447,7 +445,7 @@ fn extract_select(query: &Query<'_, ()>) -> Result { .next() .unwrap(); - let select_clause = query.inner.borrow().selectables[0] + let select_clause = query.selectables[0] .selectable .child(const { &SyntaxSet::single(SyntaxKind::SelectClause) }) .unwrap(); @@ -459,7 +457,7 @@ fn extract_select(query: &Query<'_, ()>) -> Result { false, ); - let extracted_table = extract_table(&relation, query.inner.borrow().dialect)?; + let extracted_table = extract_table(&relation, query.dialect)?; let mut extracted_tables = vec![extracted_table]; let joins = from_clause.recursive_crawl( @@ -469,7 +467,7 @@ fn extract_select(query: &Query<'_, ()>) -> Result { false, ); if !joins.is_empty() { - let extracted = extract_extracted_from_joins(joins, query.inner.borrow().dialect)?; + let extracted = extract_extracted_from_joins(joins, query.dialect)?; extracted_tables.extend(extracted); } @@ -1020,7 +1018,7 @@ fn extract_table(table_factor: &ErasedSegment, dialect: &Dialect) -> Result