Skip to content

Commit 3be37af

Browse files
committed
substantial mysql tests and sanitize parameters
1 parent db8b9b3 commit 3be37af

File tree

2 files changed

+155
-14
lines changed

2 files changed

+155
-14
lines changed

src/dbc.rs

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
use std::sync::Arc;
22

33
use serde::{Deserialize, Serialize};
4-
54
use serde_json;
65

76
mod mysql;
8-
mod sqlite;
97
mod postgres;
8+
mod sqlite;
109

1110
pub type Error = Box<dyn std::error::Error + Send + Sync>;
1211

@@ -47,15 +46,31 @@ impl Database {
4746
Ok(serde_json::to_vec(&result)?)
4847
}
4948

50-
pub fn execute_query_with_params(&mut self, query: &str, params: &[&str]) -> Result<QueryResult, Error> {
49+
pub fn execute_query_with_params(
50+
&mut self,
51+
query: &str,
52+
params: &[&str],
53+
) -> Result<QueryResult, Error> {
5154
let mut query = query.to_string();
5255
for param in params {
53-
query = query.replace("?", param);
56+
let quoted_param = if let Some(_) = param.strip_prefix('\'') {
57+
// If the parameter already has single quotes, don't add them again.
58+
// This avoids SQL injection vulnerabilities when the parameter contains a quote.
59+
param.to_string()
60+
} else {
61+
// If the parameter doesn't have single quotes, add them.
62+
format!("'{}'", param)
63+
};
64+
query = query.replacen("?", quoted_param.as_str(), 1);
5465
}
5566
self.execute_query(&query)
5667
}
5768

58-
pub fn execute_query_with_params_and_serialize(&mut self, query: &str, params: &[&str]) -> Result<String, Error> {
69+
pub fn execute_query_with_params_and_serialize(
70+
&mut self,
71+
query: &str,
72+
params: &[&str],
73+
) -> Result<String, Error> {
5974
let result = self.execute_query_with_params(query, params)?;
6075
Ok(serde_json::to_string(&result)?)
6176
}
@@ -89,6 +104,31 @@ pub struct Row {
89104
columns: Arc<[Column]>,
90105
}
91106

107+
impl Row {
108+
pub fn new(values: Vec<Value>, columns: Arc<[Column]>) -> Self {
109+
Row { values, columns }
110+
}
111+
112+
pub fn get_value(&self, index: usize) -> Option<&Value> {
113+
self.values.get(index)
114+
}
115+
116+
pub fn get_column(&self, index: usize) -> Option<&Column> {
117+
self.columns.get(index)
118+
}
119+
120+
pub fn get_value_by_name(&self, name: &str) -> Option<&Value> {
121+
self.columns
122+
.iter()
123+
.position(|column| column.name == name)
124+
.and_then(|index| self.values.get(index))
125+
}
126+
127+
pub fn get_column_by_name(&self, name: &str) -> Option<&Column> {
128+
self.columns.iter().find(|column| column.name == name)
129+
}
130+
}
131+
92132
#[derive(Serialize, Deserialize, Debug)]
93133
pub struct QueryResult {
94134
pub rows: Vec<Row>,

tests/mysql_test.rs

Lines changed: 110 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,32 @@
1+
use mysql;
2+
13
use rdbc2::dbc;
24

35
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
46

7+
fn _prepare_mysql_database(url: String) -> Result<(), Error> {
8+
let mut database = dbc::Database::new(url.as_str())?;
9+
10+
let query = "DROP DATABASE IF EXISTS test";
11+
database.execute_query(query)?;
12+
let query = "CREATE DATABASE IF NOT EXISTS test";
13+
database.execute_query(query)?;
14+
let query = "USE test";
15+
database.execute_query(query)?;
16+
let query = "CREATE TABLE IF NOT EXISTS test_table (id INT NOT NULL AUTO_INCREMENT, name VARCHAR(255) NOT NULL, PRIMARY KEY (id))";
17+
database.execute_query(query)?;
18+
19+
Ok(())
20+
}
21+
22+
fn _cleanup_mysql_database(url: String) -> Result<(), Error> {
23+
let mut database = dbc::Database::new(url.as_str())?;
24+
let query = "DROP DATABASE IF EXISTS test";
25+
database.execute_query(query)?;
26+
27+
Ok(())
28+
}
29+
530
fn _get_mysql_connection_url() -> String {
631
if std::env::var("MYSQL_DATABASE_URL").is_ok() {
732
std::env::var("MYSQL_DATABASE_URL").unwrap()
@@ -13,32 +38,108 @@ fn _get_mysql_connection_url() -> String {
1338
#[tokio::test]
1439
async fn test_mysql_simple_query() -> Result<(), Error> {
1540
let url = _get_mysql_connection_url();
41+
_prepare_mysql_database(url.clone())?;
42+
1643
let mut database = dbc::Database::new(url.as_str())?;
17-
let query = "SELECT 1";
18-
let result = database.execute_query(query)?;
19-
assert_eq!(result.rows.len(), 1);
44+
45+
// Use the test database
46+
let use_query = "USE test";
47+
database.execute_query(use_query)?;
48+
49+
// Insert two rows into the test_table
50+
let insert_query = "INSERT INTO test_table (name) VALUES ('test1'), ('test2')";
51+
database.execute_query(insert_query)?;
52+
53+
// Select all rows from test_table
54+
let select_query = "SELECT * FROM test_table";
55+
let result = database.execute_query(select_query)?;
56+
57+
assert_eq!(result.rows.len(), 2);
58+
59+
// Verify the data returned by the query
60+
let first_row = &result.rows[0];
61+
let second_row = &result.rows[1];
62+
assert_eq!(
63+
first_row.get_value_by_name("name"),
64+
Some(&dbc::Value::Bytes("test1".to_owned().into_bytes()))
65+
);
66+
assert_eq!(
67+
second_row.get_value_by_name("name"),
68+
Some(&dbc::Value::Bytes("test2".to_owned().into_bytes()))
69+
);
70+
71+
_cleanup_mysql_database(url.clone())?;
2072

2173
Ok(())
2274
}
2375

2476
#[tokio::test]
2577
async fn test_mysql_query_with_params() -> Result<(), Error> {
2678
let url = _get_mysql_connection_url();
79+
_prepare_mysql_database(url.clone())?;
80+
2781
let mut database = dbc::Database::new(url.as_str())?;
28-
let query = "SELECT ? + ?";
29-
let result = database.execute_query_with_params(query, &["1", "2"])?;
82+
83+
// Use the test database
84+
let use_query = "USE test";
85+
database.execute_query(use_query)?;
86+
87+
// Insert two rows into the test_table
88+
let insert_query = "INSERT INTO test_table (name) VALUES ('test1'), ('test2')";
89+
database.execute_query(insert_query)?;
90+
91+
// Select all rows from test_table
92+
let select_query = "SELECT * FROM test_table WHERE name = ?";
93+
let result = database.execute_query_with_params(select_query, &["test1"])?;
94+
3095
assert_eq!(result.rows.len(), 1);
3196

97+
// Verify the data returned by the query
98+
let first_row = &result.rows[0];
99+
assert_eq!(
100+
first_row.get_value_by_name("name"),
101+
Some(&dbc::Value::Bytes("test1".to_owned().into_bytes()))
102+
);
103+
104+
_cleanup_mysql_database(url.clone())?;
105+
32106
Ok(())
33107
}
34108

35109
#[tokio::test]
36110
async fn test_mysql_query_with_params_and_serialize() -> Result<(), Error> {
37111
let url = _get_mysql_connection_url();
112+
_prepare_mysql_database(url.clone())?;
113+
38114
let mut database = dbc::Database::new(url.as_str())?;
39-
let query = "SELECT ? + ?";
40-
let result = database.execute_query_with_params_and_serialize(query, &["1", "2"])?;
41-
assert_eq!(result, r#"{"rows":[{"values":[{"Bytes":[50]}],"columns":[{"name":"1 + 1","column_type":"LONGLONG"}]}]}"#);
115+
116+
// Use the test database
117+
let use_query = "USE test";
118+
database.execute_query(use_query)?;
119+
120+
// Insert two rows into the test_table
121+
let insert_query = "INSERT INTO test_table (name) VALUES ('test1'), ('test2')";
122+
database.execute_query(insert_query)?;
123+
124+
// Update the test_table to set the name of the first row to "updated"
125+
let update_query = "UPDATE test_table SET name = ? WHERE id = ?";
126+
let result = database.execute_query_with_params(update_query, &["updated", "1"])?;
127+
128+
// Select all rows from test_table
129+
let select_query = "SELECT * FROM test_table WHERE id = ?";
130+
let result = database.execute_query_with_params_and_serialize(select_query, &["1"])?;
131+
let expected_result = r#"{"rows":[{"values":[{"Bytes":[49]},{"Bytes":[117,112,100,97,116,101,100]}],"columns":[{"name":"id","column_type":"LONG"},{"name":"name","column_type":"STRING"}]}]}"#;
132+
assert_eq!(result, expected_result);
133+
134+
// deserialize the result and verify the data
135+
let deserialized_result: dbc::QueryResult = serde_json::from_str(&result)?;
136+
let first_row = &deserialized_result.rows[0];
137+
assert_eq!(
138+
first_row.get_value_by_name("name"),
139+
Some(&dbc::Value::Bytes("updated".to_owned().into_bytes()))
140+
);
141+
142+
_cleanup_mysql_database(url.clone())?;
42143

43144
Ok(())
44-
}
145+
}

0 commit comments

Comments
 (0)