Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 44 additions & 136 deletions crates/lib-core/src/utils/analysis/query.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
use std::cell::RefCell;
use std::rc::Rc;

use smol_str::{SmolStr, StrExt, ToSmolStr};

use super::select::SelectStatementColumnsAndTables;
Expand Down Expand Up @@ -153,31 +150,21 @@ impl Selectable<'_> {
}

#[derive(Debug, Clone)]
pub struct Query<'me, T> {
pub inner: Rc<RefCell<QueryInner<'me, T>>>,
}

#[derive(Debug, Clone)]
pub struct QueryInner<'me, T> {
pub struct Query<'me> {
pub query_type: QueryType,
pub dialect: &'me Dialect,
pub selectables: Vec<Selectable<'me>>,
pub ctes: IndexMap<SmolStr, Query<'me, T>>,
pub parent: Option<Query<'me, T>>,
pub subqueries: Vec<Query<'me, T>>,
pub cte_definition_segment: Option<ErasedSegment>,
pub cte_name_segment: Option<ErasedSegment>,
pub payload: T,
pub ctes: IndexMap<SmolStr, ErasedSegment>,
}

impl<'me, T: Clone + Default> Query<'me, T> {
impl<'me> Query<'me> {
pub fn crawl_sources(
&self,
segment: ErasedSegment,

pop: bool,
_pop: bool,
lookup_cte: bool,
) -> Vec<Source<'me, T>> {
) -> Vec<Source<'me>> {
let mut acc = Vec::new();

for seg in segment.recursive_crawl(
Expand All @@ -195,19 +182,15 @@ impl<'me, T: Clone + Default> Query<'me, T> {
) {
if seg.is_type(SyntaxKind::TableReference) {
let _seg = seg.reference();
if !_seg.is_qualified()
&& lookup_cte
&& let Some(cte) = self.lookup_cte(seg.raw().as_ref(), pop)
{
acc.push(Source::Query(cte));
if !_seg.is_qualified() && lookup_cte {
if let Some(cte) = self.lookup_cte(seg.raw().as_ref()) {
acc.push(Source::Query(cte));
continue;
}
}
acc.push(Source::TableReference(seg.raw().clone()));
} else {
acc.push(Source::Query(Query::from_segment(
&seg,
self.inner.borrow().dialect,
Some(self.clone()),
)))
acc.push(Source::Query(Query::from_segment(&seg, self.dialect)))
}
}

Expand All @@ -222,54 +205,30 @@ impl<'me, T: Clone + Default> Query<'me, T> {
}

#[track_caller]
pub fn lookup_cte(&self, name: &str, pop: bool) -> Option<Query<'me, T>> {
let cte = if pop {
self.inner
.borrow_mut()
.ctes
.shift_remove(&name.to_uppercase_smolstr())
} else {
self.inner
.borrow()
.ctes
.get(&name.to_uppercase_smolstr())
.cloned()
};

cte.or_else(move || {
self.inner
.borrow_mut()
.parent
.as_mut()
.and_then(|it| it.lookup_cte(name, pop))
})
}

fn post_init(&self) {
let this = self.clone();

for subquery in &RefCell::borrow(&self.inner).subqueries {
RefCell::borrow_mut(&subquery.inner).parent = this.clone().into();
}

for cte in RefCell::borrow(&self.inner).ctes.values().cloned() {
RefCell::borrow_mut(&cte.inner).parent = this.clone().into();
}
pub fn lookup_cte(&self, name: &str) -> Option<Query<'me>> {
let key = name.to_uppercase_smolstr();
let seg = self.ctes.get(&key)?.clone();
// Dive into the CTE definition to find the inner query (SELECT/SET/WITH)
let inner = seg.recursive_crawl(
const { &SELECTABLE_TYPES.union(&SUBSELECT_TYPES) },
true,
&SyntaxSet::EMPTY,
true,
);
inner.first().map(|qseg| Query::from_segment(qseg, self.dialect))
}
}

impl<T: Default + Clone> Query<'_, T> {
impl Query<'_> {
pub fn children(&self) -> Vec<Self> {
self.inner
.borrow()
.ctes
.values()
.chain(self.inner.borrow().subqueries.iter())
.cloned()
.collect()
let mut acc = Vec::new();
for selectable in &self.selectables {
acc.extend(Self::extract_subqueries(selectable, self.dialect));
}
acc
}

fn extract_subqueries<'a>(selectable: &Selectable, dialect: &'a Dialect) -> Vec<Query<'a, T>> {
fn extract_subqueries<'a>(selectable: &Selectable, dialect: &'a Dialect) -> Vec<Query<'a>> {
let mut acc = Vec::new();

for subselect in selectable.selectable.recursive_crawl(
Expand All @@ -278,16 +237,13 @@ impl<T: Default + Clone> Query<'_, T> {
&SyntaxSet::EMPTY,
false,
) {
acc.push(Query::from_segment(&subselect, dialect, None));
acc.push(Query::from_segment(&subselect, dialect));
}

acc
}

pub fn from_root<'a>(
root_segment: &ErasedSegment,
dialect: &'a Dialect,
) -> Option<Query<'a, T>> {
pub fn from_root<'a>(root_segment: &ErasedSegment, dialect: &'a Dialect) -> Option<Query<'a>> {
let stmts = root_segment.recursive_crawl(
&SELECTABLE_TYPES,
true,
Expand All @@ -296,17 +252,12 @@ impl<T: Default + Clone> Query<'_, T> {
);
let selectable_segment = stmts.first()?;

Some(Query::from_segment(selectable_segment, dialect, None))
Some(Query::from_segment(selectable_segment, dialect))
}

pub fn from_segment<'a>(
segment: &ErasedSegment,
dialect: &'a Dialect,
parent: Option<Query<'a, T>>,
) -> Query<'a, T> {
pub fn from_segment<'a>(segment: &ErasedSegment, dialect: &'a Dialect) -> Query<'a> {
let mut selectables = Vec::new();
let mut subqueries = Vec::new();
let mut cte_defs: Vec<ErasedSegment> = Vec::new();
let mut cte_defs: IndexMap<SmolStr, ErasedSegment> = IndexMap::default();
let mut query_type = QueryType::Simple;

if segment.is_type(SyntaxKind::SelectStatement)
Expand Down Expand Up @@ -347,65 +298,22 @@ impl<T: Default + Clone> Query<'_, T> {
const { &SyntaxSet::single(SyntaxKind::WithCompoundStatement) },
true,
) {
cte_defs.push(seg);
let name_seg = seg.segments()[0].clone();
let name = name_seg.raw().to_uppercase_smolstr();
cte_defs.insert(name, seg);
}
}

for selectable in &selectables {
subqueries.extend(Self::extract_subqueries(selectable, dialect));
}

let outer_query = Query {
inner: Rc::new(RefCell::new(QueryInner {
query_type,
dialect,
selectables,
ctes: <_>::default(),
parent,
subqueries,
cte_definition_segment: None,
cte_name_segment: None,
payload: T::default(),
})),
};

outer_query.post_init();

if cte_defs.is_empty() {
return outer_query;
}

let mut ctes = IndexMap::default();
for cte in cte_defs {
let name_seg = cte.segments()[0].clone();
let name = name_seg.raw().to_uppercase_smolstr();

let queries = cte.recursive_crawl(
const { &SELECTABLE_TYPES.union(&SUBSELECT_TYPES) },
true,
&SyntaxSet::EMPTY,
true,
);

if queries.is_empty() {
continue;
};

let query = &queries[0];
let query = Self::from_segment(query, dialect, outer_query.clone().into());

RefCell::borrow_mut(&query.inner).cte_definition_segment = cte.into();
RefCell::borrow_mut(&query.inner).cte_name_segment = name_seg.into();

ctes.insert(name, query);
Query {
query_type,
dialect,
selectables,
ctes: cte_defs,
}

RefCell::borrow_mut(&outer_query.inner).ctes = ctes;
outer_query
}
}

pub enum Source<'a, T> {
pub enum Source<'a> {
TableReference(SmolStr),
Query(Query<'a, T>),
Query(Query<'a>),
}
57 changes: 23 additions & 34 deletions crates/lib/src/rules/aliasing/al05.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::cell::RefCell;

use ahash::{AHashMap, AHashSet};
use smol_str::{SmolStr, ToSmolStr};
use sqruff_lib_core::dialects::Dialect;
Expand All @@ -19,7 +17,7 @@ use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
use crate::core::rules::{Erased, ErasedRule, LintResult, Rule, RuleGroups};

#[derive(Default, Clone)]
struct AL05Query {
struct State {
aliases: Vec<AliasInfo>,
tbl_refs: Vec<SmolStr>,
}
Expand Down Expand Up @@ -86,14 +84,16 @@ FROM foo
return Vec::new();
}

let query = Query::from_segment(&context.segment, context.dialect, None);
self.analyze_table_aliases(query.clone(), context.dialect);
let query = Query::from_segment(&context.segment, context.dialect);
let mut state_stack = vec![State::default()];
self.analyze_table_aliases(query.clone(), context.dialect, &mut state_stack);
let state = state_stack.pop().unwrap();

if context.dialect.name == DialectKind::Redshift {
let mut references = AHashSet::default();
let mut aliases = AHashSet::default();

for alias in &query.inner.borrow().payload.aliases {
for alias in &state.aliases {
aliases.insert(alias.ref_str.clone());
if let Some(object_reference) = &alias.object_reference {
for seg in object_reference.segments() {
Expand All @@ -118,17 +118,12 @@ FROM foo
}
}

for alias in &RefCell::borrow(&query.inner).payload.aliases {
for alias in &state.aliases {
if Self::is_alias_required(&alias.from_expression_element, context.dialect.name) {
continue;
}

if alias.aliased
&& !RefCell::borrow(&query.inner)
.payload
.tbl_refs
.contains(&alias.ref_str)
{
if alias.aliased && !state.tbl_refs.contains(&alias.ref_str) {
let violation = self.report_unused_alias(alias.clone());
violations.push(violation);
}
Expand All @@ -148,46 +143,40 @@ FROM foo

impl RuleAL05 {
#[allow(clippy::only_used_in_recursion)]
fn analyze_table_aliases(&self, query: Query<AL05Query>, dialect: &Dialect) {
let selectables = std::mem::take(&mut RefCell::borrow_mut(&query.inner).selectables);
fn analyze_table_aliases(&self, query: Query, dialect: &Dialect, state_stack: &mut Vec<State>) {
let selectables = query.selectables.clone();

for selectable in &selectables {
if let Some(select_info) = selectable.select_info() {
RefCell::borrow_mut(&query.inner)
.payload
state_stack
.last_mut()
.unwrap()
.aliases
.extend(select_info.table_aliases);

for r in select_info.reference_buffer {
for tr in
r.extract_possible_references(ObjectReferenceLevel::Table, dialect.name)
{
Self::resolve_and_mark_reference(query.clone(), tr.part);
Self::resolve_and_mark_reference(state_stack, tr.part);
}
}
}
}

RefCell::borrow_mut(&query.inner).selectables = selectables;

for child in query.children() {
self.analyze_table_aliases(child, dialect);
state_stack.push(State::default());
self.analyze_table_aliases(child, dialect, state_stack);
state_stack.pop();
}
}

fn resolve_and_mark_reference(query: Query<AL05Query>, r#ref: String) {
if RefCell::borrow(&query.inner)
.payload
.aliases
.iter()
.any(|it| it.ref_str == r#ref)
{
RefCell::borrow_mut(&query.inner)
.payload
.tbl_refs
.push(r#ref.into());
} else if let Some(parent) = RefCell::borrow(&query.inner).parent.clone() {
Self::resolve_and_mark_reference(parent, r#ref);
fn resolve_and_mark_reference(state_stack: &mut Vec<State>, r#ref: String) {
for st in state_stack.iter_mut().rev() {
if st.aliases.iter().any(|it| it.ref_str == r#ref) {
st.tbl_refs.push(r#ref.into());
return;
}
}
}

Expand Down
Loading
Loading