forked from sinaptik-ai/pandas-ai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_databricks.py
132 lines (110 loc) · 4.42 KB
/
test_databricks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import unittest
from unittest.mock import MagicMock, patch
from pandasai_databricks import (
load_from_databricks,
)
class TestDatabricksLoader(unittest.TestCase):
    """Exercise ``load_from_databricks`` against a mocked ``databricks.sql.connect``.

    No real connection is opened: each test patches the connect call and, where
    a result set is needed, feeds canned rows through a mock cursor.
    """

    # Connection settings common to every config built in the tests.
    BASE_CONFIG = {
        "host": "databricks_host",
        "http_path": "http_path",
        "token": "access_token",
    }
    # Canned result set handed back by the mock cursor.
    SAMPLE_COLUMNS = ("id", "name", "value")
    SAMPLE_ROWS = ((1, "Alice", 100), (2, "Bob", 200))

    def _install_cursor(self, mock_connect, rows=None):
        """Make ``mock_connect`` return a connection whose cursor is a mock.

        Args:
            mock_connect: The patched ``databricks.sql.connect`` mock.
            rows: Optional iterable of row tuples. When given (even if empty),
                the cursor's ``fetchall`` returns a fresh list of these rows
                and ``description`` lists the sample column names. When
                ``None``, the cursor is left unconfigured.

        Returns:
            The mock cursor, so tests can assert on ``execute``.
        """
        cursor = MagicMock()
        connection = MagicMock()
        connection.cursor.return_value = cursor
        mock_connect.return_value = connection
        if rows is not None:
            # Copy so the shared class-level fixture is never handed out directly.
            cursor.fetchall.return_value = list(rows)
            # One 1-tuple per column, with the column name first.
            cursor.description = [(name,) for name in self.SAMPLE_COLUMNS]
        return cursor

    def _assert_sample_frame(self, frame):
        """Assert ``frame`` has the canned result's shape and column names."""
        self.assertEqual(
            frame.shape, (len(self.SAMPLE_ROWS), len(self.SAMPLE_COLUMNS))
        )
        for column in self.SAMPLE_COLUMNS:
            # assertIn names the missing column and the container on failure,
            # unlike assertTrue(column in ...), which only says "False is not true".
            self.assertIn(column, frame.columns)

    @patch("databricks.sql.connect")
    def test_load_from_databricks_with_query(self, MockConnect):
        """A ``query`` in the config is executed as given and its rows returned."""
        cursor = self._install_cursor(MockConnect, rows=self.SAMPLE_ROWS)
        config = dict(self.BASE_CONFIG, query="SELECT * FROM sample_table")

        result = load_from_databricks(config)

        # The config keys are mapped onto the connector's keyword arguments.
        MockConnect.assert_called_once_with(
            server_hostname="databricks_host",
            http_path="http_path",
            access_token="access_token",
        )
        cursor.execute.assert_called_once_with("SELECT * FROM sample_table")
        self._assert_sample_frame(result)

    @patch("databricks.sql.connect")
    def test_load_from_databricks_with_table(self, MockConnect):
        """``database`` + ``table`` are turned into a ``SELECT *`` on that table."""
        cursor = self._install_cursor(MockConnect, rows=self.SAMPLE_ROWS)
        config = dict(self.BASE_CONFIG, database="test_db", table="sample_table")

        result = load_from_databricks(config)

        cursor.execute.assert_called_once_with("SELECT * FROM test_db.sample_table")
        self._assert_sample_frame(result)

    @patch("databricks.sql.connect")
    def test_load_from_databricks_no_query_or_table(self, MockConnect):
        """A config with neither ``query`` nor ``table`` raises ``ValueError``."""
        self._install_cursor(MockConnect)
        config = dict(self.BASE_CONFIG)

        with self.assertRaises(ValueError):
            load_from_databricks(config)

    @patch("databricks.sql.connect")
    def test_load_from_databricks_empty_result(self, MockConnect):
        """An empty ``fetchall`` result yields an empty DataFrame."""
        self._install_cursor(MockConnect, rows=[])
        config = dict(self.BASE_CONFIG, query="SELECT * FROM sample_table")

        result = load_from_databricks(config)

        self.assertTrue(result.empty)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()