Skip to content

Commit 3e812b6

Browse files
committed
refactor: add origin information to Column
1 parent 25c755b commit 3e812b6

File tree

11 files changed

+243
-7
lines changed

11 files changed

+243
-7
lines changed

sqlx-core/src/column.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use crate::database::Database;
22
use crate::error::Error;
33

44
use std::fmt::Debug;
5+
use std::sync::Arc;
56

67
pub trait Column: 'static + Send + Sync + Debug {
78
type Database: Database<Column = Self>;
@@ -20,6 +21,59 @@ pub trait Column: 'static + Send + Sync + Debug {
2021

2122
/// Gets the type information for the column.
2223
fn type_info(&self) -> &<Self::Database as Database>::TypeInfo;
24+
25+
/// If this column comes from a table, return the table and original column name.
26+
///
27+
/// Returns [`ColumnOrigin::Expression`] if the column is the result of an expression
28+
/// or else the source table could not be determined.
29+
///
30+
/// Returns [`ColumnOrigin::Unknown`] if the database driver does not have that information,
31+
/// or has not overridden this method.
32+
// This method returns an owned value instead of a reference,
33+
// to give the implementor more flexibility.
34+
fn origin(&self) -> ColumnOrigin { ColumnOrigin::Unknown }
35+
}
36+
37+
/// A [`Column`] that originates from a table.
38+
#[derive(Debug, Clone)]
39+
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
40+
pub struct TableColumn {
41+
/// The name of the table (optionally schema-qualified) that the column comes from.
42+
pub table: Arc<str>,
43+
/// The original name of the column.
44+
pub name: Arc<str>,
45+
}
46+
47+
/// The possible statuses for our knowledge of the origin of a [`Column`].
48+
#[derive(Debug, Clone, Default)]
49+
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
50+
pub enum ColumnOrigin {
51+
/// The column is known to originate from a table.
52+
///
53+
/// Included is the table name and original column name.
54+
Table(TableColumn),
55+
/// The column originates from an expression, or else its origin could not be determined.
56+
Expression,
57+
/// The database driver does not know the column origin at this time.
58+
///
59+
/// This may happen if:
60+
/// * The connection is in the middle of executing a query,
61+
/// and cannot query the catalog to fetch this information.
62+
/// * The connection does not have access to the database catalog.
63+
/// * The implementation of [`Column`] did not override [`Column::origin()`].
64+
#[default]
65+
Unknown,
66+
}
67+
68+
impl ColumnOrigin {
69+
/// Returns the true column origin, if known.
70+
pub fn table_column(&self) -> Option<&TableColumn> {
71+
if let Self::Table(table_column) = self {
72+
Some(table_column)
73+
} else {
74+
None
75+
}
76+
}
2377
}
2478

2579
/// A type that can be used to index into a [`Row`] or [`Statement`].

sqlx-mysql/src/column.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ pub struct MySqlColumn {
1010
pub(crate) name: UStr,
1111
pub(crate) type_info: MySqlTypeInfo,
1212

13+
#[cfg_attr(feature = "offline", serde(default))]
14+
pub(crate) origin: ColumnOrigin,
15+
1316
#[cfg_attr(feature = "offline", serde(skip))]
1417
pub(crate) flags: Option<ColumnFlags>,
1518
}
@@ -28,4 +31,8 @@ impl Column for MySqlColumn {
2831
fn type_info(&self) -> &MySqlTypeInfo {
2932
&self.type_info
3033
}
34+
35+
fn origin(&self) -> ColumnOrigin {
36+
self.origin.clone()
37+
}
3138
}

sqlx-mysql/src/connection/executor.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use futures_core::stream::BoxStream;
2323
use futures_core::Stream;
2424
use futures_util::{pin_mut, TryStreamExt};
2525
use std::{borrow::Cow, sync::Arc};
26+
use sqlx_core::column::{ColumnOrigin, TableColumn};
2627

2728
impl MySqlConnection {
2829
async fn prepare_statement<'c>(
@@ -382,18 +383,38 @@ async fn recv_result_columns(
382383
fn recv_next_result_column(def: &ColumnDefinition, ordinal: usize) -> Result<MySqlColumn, Error> {
383384
// if the alias is empty, use the alias
384385
// only then use the name
386+
let column_name = def.name()?;
387+
385388
let name = match (def.name()?, def.alias()?) {
386389
(_, alias) if !alias.is_empty() => UStr::new(alias),
387390
(name, _) => UStr::new(name),
388391
};
389392

393+
let table = def.table()?;
394+
395+
let origin = if table.is_empty() {
396+
ColumnOrigin::Expression
397+
} else {
398+
let schema = def.schema()?;
399+
400+
ColumnOrigin::Table(TableColumn {
401+
table: if !schema.is_empty() {
402+
format!("{schema}.{table}").into()
403+
} else {
404+
table.into()
405+
},
406+
name: column_name.into(),
407+
})
408+
};
409+
390410
let type_info = MySqlTypeInfo::from_column(def);
391411

392412
Ok(MySqlColumn {
393413
name,
394414
type_info,
395415
ordinal,
396416
flags: Some(def.flags),
417+
origin,
397418
})
398419
}
399420

sqlx-mysql/src/protocol/text/column.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::str::from_utf8;
1+
use std::str;
22

33
use bitflags::bitflags;
44
use bytes::{Buf, Bytes};
@@ -104,11 +104,9 @@ pub enum ColumnType {
104104
pub(crate) struct ColumnDefinition {
105105
#[allow(unused)]
106106
catalog: Bytes,
107-
#[allow(unused)]
108107
schema: Bytes,
109108
#[allow(unused)]
110109
table_alias: Bytes,
111-
#[allow(unused)]
112110
table: Bytes,
113111
alias: Bytes,
114112
name: Bytes,
@@ -125,12 +123,20 @@ impl ColumnDefinition {
125123
// NOTE: strings in-protocol are transmitted according to the client character set
126124
// as this is UTF-8, all these strings should be UTF-8
127125

126+
pub(crate) fn schema(&self) -> Result<&str, Error> {
127+
str::from_utf8(&self.schema).map_err(Error::protocol)
128+
}
129+
130+
pub(crate) fn table(&self) -> Result<&str, Error> {
131+
str::from_utf8(&self.table).map_err(Error::protocol)
132+
}
133+
128134
pub(crate) fn name(&self) -> Result<&str, Error> {
129-
from_utf8(&self.name).map_err(Error::protocol)
135+
str::from_utf8(&self.name).map_err(Error::protocol)
130136
}
131137

132138
pub(crate) fn alias(&self) -> Result<&str, Error> {
133-
from_utf8(&self.alias).map_err(Error::protocol)
139+
str::from_utf8(&self.alias).map_err(Error::protocol)
134140
}
135141
}
136142

sqlx-postgres/src/column.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,18 @@ use crate::ext::ustr::UStr;
22
use crate::{PgTypeInfo, Postgres};
33

44
pub(crate) use sqlx_core::column::{Column, ColumnIndex};
5+
use sqlx_core::column::ColumnOrigin;
56

67
#[derive(Debug, Clone)]
78
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
89
pub struct PgColumn {
910
pub(crate) ordinal: usize,
1011
pub(crate) name: UStr,
1112
pub(crate) type_info: PgTypeInfo,
13+
14+
#[cfg_attr(feature = "offline", serde(default))]
15+
pub(crate) origin: ColumnOrigin,
16+
1217
#[cfg_attr(feature = "offline", serde(skip))]
1318
pub(crate) relation_id: Option<crate::types::Oid>,
1419
#[cfg_attr(feature = "offline", serde(skip))]
@@ -51,4 +56,8 @@ impl Column for PgColumn {
5156
fn type_info(&self) -> &PgTypeInfo {
5257
&self.type_info
5358
}
59+
60+
fn origin(&self) -> ColumnOrigin {
61+
self.origin.clone()
62+
}
5463
}

sqlx-postgres/src/connection/describe.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::collections::btree_map;
12
use crate::error::Error;
23
use crate::ext::ustr::UStr;
34
use crate::io::StatementId;
@@ -14,6 +15,9 @@ use futures_core::future::BoxFuture;
1415
use smallvec::SmallVec;
1516
use sqlx_core::query_builder::QueryBuilder;
1617
use std::sync::Arc;
18+
use sqlx_core::column::{ColumnOrigin, TableColumn};
19+
use sqlx_core::hash_map;
20+
use crate::connection::TableColumns;
1721

1822
/// Describes the type of the `pg_type.typtype` column
1923
///
@@ -122,13 +126,20 @@ impl PgConnection {
122126
let type_info = self
123127
.maybe_fetch_type_info_by_oid(field.data_type_id, should_fetch)
124128
.await?;
129+
130+
let origin = if let (Some(relation_oid), Some(attribute_no)) = (field.relation_id, field.relation_attribute_no) {
131+
self.maybe_fetch_column_origin(relation_oid, attribute_no, should_fetch).await?
132+
} else {
133+
ColumnOrigin::Expression
134+
};
125135

126136
let column = PgColumn {
127137
ordinal: index,
128138
name: name.clone(),
129139
type_info,
130140
relation_id: field.relation_id,
131141
relation_attribute_no: field.relation_attribute_no,
142+
origin,
132143
};
133144

134145
columns.push(column);
@@ -188,6 +199,54 @@ impl PgConnection {
188199
Ok(PgTypeInfo(PgType::DeclareWithOid(oid)))
189200
}
190201
}
202+
203+
async fn maybe_fetch_column_origin(
204+
&mut self,
205+
relation_id: Oid,
206+
attribute_no: i16,
207+
should_fetch: bool,
208+
) -> Result<ColumnOrigin, Error> {
209+
let mut table_columns = match self.cache_table_to_column_names.entry(relation_id) {
210+
hash_map::Entry::Occupied(table_columns) => {
211+
table_columns.into_mut()
212+
},
213+
hash_map::Entry::Vacant(vacant) => {
214+
if !should_fetch { return Ok(ColumnOrigin::Unknown); }
215+
216+
let table_name: String = query_scalar("SELECT $1::oid::regclass::text")
217+
.bind(relation_id)
218+
.fetch_one(&mut *self)
219+
.await?;
220+
221+
vacant.insert(TableColumns {
222+
table_name: table_name.into(),
223+
columns: Default::default(),
224+
})
225+
}
226+
};
227+
228+
let column_name = match table_columns.columns.entry(attribute_no) {
229+
btree_map::Entry::Occupied(occupied) => Arc::clone(occupied.get()),
230+
btree_map::Entry::Vacant(vacant) => {
231+
if !should_fetch { return Ok(ColumnOrigin::Unknown); }
232+
233+
let column_name: String = query_scalar(
234+
"SELECT attname FROM pg_attribute WHERE attrelid = $1 AND attnum = $2"
235+
)
236+
.bind(relation_id)
237+
.bind(attribute_no)
238+
.fetch_one(&mut *self)
239+
.await?;
240+
241+
Arc::clone(vacant.insert(column_name.into()))
242+
}
243+
};
244+
245+
Ok(ColumnOrigin::Table(TableColumn {
246+
table: table_columns.table_name.clone(),
247+
name: column_name
248+
}))
249+
}
191250

192251
fn fetch_type_by_oid(&mut self, oid: Oid) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
193252
Box::pin(async move {

sqlx-postgres/src/connection/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::collections::BTreeMap;
12
use std::fmt::{self, Debug, Formatter};
23
use std::sync::Arc;
34

@@ -57,6 +58,7 @@ pub struct PgConnection {
5758
cache_type_info: HashMap<Oid, PgTypeInfo>,
5859
cache_type_oid: HashMap<UStr, Oid>,
5960
cache_elem_type_to_array: HashMap<Oid, Oid>,
61+
cache_table_to_column_names: HashMap<Oid, TableColumns>,
6062

6163
// number of ReadyForQuery messages that we are currently expecting
6264
pub(crate) pending_ready_for_query_count: usize,
@@ -68,6 +70,12 @@ pub struct PgConnection {
6870
log_settings: LogSettings,
6971
}
7072

73+
pub(crate) struct TableColumns {
74+
table_name: Arc<str>,
75+
/// Attribute number -> name.
76+
columns: BTreeMap<i16, Arc<str>>,
77+
}
78+
7179
impl PgConnection {
7280
/// the version number of the server in `libpq` format
7381
pub fn server_version_num(&self) -> Option<u32> {

sqlx-sqlite/src/column.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ pub struct SqliteColumn {
99
pub(crate) name: UStr,
1010
pub(crate) ordinal: usize,
1111
pub(crate) type_info: SqliteTypeInfo,
12+
13+
#[cfg_attr(feature = "offline", serde(default))]
14+
pub(crate) origin: ColumnOrigin
1215
}
1316

1417
impl Column for SqliteColumn {
@@ -25,4 +28,8 @@ impl Column for SqliteColumn {
2528
fn type_info(&self) -> &SqliteTypeInfo {
2629
&self.type_info
2730
}
31+
32+
fn origin(&self) -> ColumnOrigin {
33+
self.origin.clone()
34+
}
2835
}

sqlx-sqlite/src/connection/describe.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ pub(crate) fn describe(conn: &mut ConnectionState, query: &str) -> Result<Descri
4949

5050
for col in 0..num {
5151
let name = stmt.handle.column_name(col).to_owned();
52+
53+
let origin = stmt.handle.column_origin(col);
5254

5355
let type_info = if let Some(ty) = stmt.handle.column_decltype(col) {
5456
ty
@@ -82,6 +84,7 @@ pub(crate) fn describe(conn: &mut ConnectionState, query: &str) -> Result<Descri
8284
name: name.into(),
8385
type_info,
8486
ordinal: col,
87+
origin,
8588
});
8689
}
8790
}

0 commit comments

Comments
 (0)