Skip to content

Commit a3229c3

Browse files
authored
SNOW-1963078 Port _upload / _download / _upload_stream / _download_st… (#2198)
1 parent 26cbdf9 commit a3229c3

File tree

5 files changed

+319
-0
lines changed

5 files changed

+319
-0
lines changed

src/snowflake/connector/connection.py

+5
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
PYTHON_VERSION,
8383
SNOWFLAKE_CONNECTOR_VERSION,
8484
)
85+
from .direct_file_operation_utils import FileOperationParser, StreamDownloader
8586
from .errorcode import (
8687
ER_CONNECTION_IS_CLOSED,
8788
ER_FAILED_PROCESSING_PYFORMAT,
@@ -500,6 +501,10 @@ def __init__(
500501
# check SNOW-1218851 for long term improvement plan to refactor ocsp code
501502
atexit.register(self._close_at_exit)
502503

504+
# Set up the file operation parser and stream downloader.
505+
self._file_operation_parser = FileOperationParser(self)
506+
self._stream_downloader = StreamDownloader(self)
507+
503508
# Deprecated
504509
@property
505510
def insecure_mode(self) -> bool:

src/snowflake/connector/cursor.py

+147
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
from ._utils import _TrackedQueryCancellationTimer
4343
from .bind_upload_agent import BindUploadAgent, BindUploadError
4444
from .constants import (
45+
CMD_TYPE_DOWNLOAD,
46+
CMD_TYPE_UPLOAD,
4547
FIELD_NAME_TO_ID,
4648
PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT,
4749
FileTransferType,
@@ -1730,6 +1732,151 @@ def get_result_batches(self) -> list[ResultBatch] | None:
17301732
)
17311733
return self._result_set.batches
17321734

1735+
def _download(
1736+
self,
1737+
stage_location: str,
1738+
target_directory: str,
1739+
options: dict[str, Any],
1740+
_do_reset: bool = True,
1741+
) -> None:
1742+
"""Downloads from the stage location to the target directory.
1743+
1744+
Args:
1745+
stage_location (str): The location of the stage to download from.
1746+
target_directory (str): The destination directory to download into.
1747+
options (dict[str, Any]): The download options.
1748+
_do_reset (bool, optional): Whether to reset the cursor before
1749+
downloading, by default we will reset the cursor.
1750+
"""
1751+
from .file_transfer_agent import SnowflakeFileTransferAgent
1752+
1753+
if _do_reset:
1754+
self.reset()
1755+
1756+
# Interpret the file operation.
1757+
ret = self.connection._file_operation_parser.parse_file_operation(
1758+
stage_location=stage_location,
1759+
local_file_name=None,
1760+
target_directory=target_directory,
1761+
command_type=CMD_TYPE_DOWNLOAD,
1762+
options=options,
1763+
)
1764+
1765+
# Execute the file operation based on the interpretation above.
1766+
file_transfer_agent = SnowflakeFileTransferAgent(
1767+
self,
1768+
"", # empty command because it is triggered by directly calling this util not by a SQL query
1769+
ret,
1770+
)
1771+
file_transfer_agent.execute()
1772+
self._init_result_and_meta(file_transfer_agent.result())
1773+
1774+
def _upload(
1775+
self,
1776+
local_file_name: str,
1777+
stage_location: str,
1778+
options: dict[str, Any],
1779+
_do_reset: bool = True,
1780+
) -> None:
1781+
"""Uploads the local file to the stage location.
1782+
1783+
Args:
1784+
local_file_name (str): The local file to be uploaded.
1785+
stage_location (str): The stage location to upload the local file to.
1786+
options (dict[str, Any]): The upload options.
1787+
_do_reset (bool, optional): Whether to reset the cursor before
1788+
uploading, by default we will reset the cursor.
1789+
"""
1790+
from .file_transfer_agent import SnowflakeFileTransferAgent
1791+
1792+
if _do_reset:
1793+
self.reset()
1794+
1795+
# Interpret the file operation.
1796+
ret = self.connection._file_operation_parser.parse_file_operation(
1797+
stage_location=stage_location,
1798+
local_file_name=local_file_name,
1799+
target_directory=None,
1800+
command_type=CMD_TYPE_UPLOAD,
1801+
options=options,
1802+
)
1803+
1804+
# Execute the file operation based on the interpretation above.
1805+
file_transfer_agent = SnowflakeFileTransferAgent(
1806+
self,
1807+
"", # empty command because it is triggered by directly calling this util not by a SQL query
1808+
ret,
1809+
)
1810+
file_transfer_agent.execute()
1811+
self._init_result_and_meta(file_transfer_agent.result())
1812+
1813+
def _download_stream(
1814+
self, stage_location: str, decompress: bool = False
1815+
) -> IO[bytes]:
1816+
"""Downloads from the stage location as a stream.
1817+
1818+
Args:
1819+
stage_location (str): The location of the stage to download from.
1820+
decompress (bool, optional): Whether to decompress the file, by
1821+
default we do not decompress.
1822+
1823+
Returns:
1824+
IO[bytes]: A stream to read from.
1825+
"""
1826+
# Interpret the file operation.
1827+
ret = self.connection._file_operation_parser.parse_file_operation(
1828+
stage_location=stage_location,
1829+
local_file_name=None,
1830+
target_directory=None,
1831+
command_type=CMD_TYPE_DOWNLOAD,
1832+
options=None,
1833+
has_source_from_stream=True,
1834+
)
1835+
1836+
# Set up stream downloading based on the interpretation and return the stream for reading.
1837+
return self.connection._stream_downloader.download_as_stream(ret, decompress)
1838+
1839+
def _upload_stream(
1840+
self,
1841+
input_stream: IO[bytes],
1842+
stage_location: str,
1843+
options: dict[str, Any],
1844+
_do_reset: bool = True,
1845+
) -> None:
1846+
"""Uploads content in the input stream to the stage location.
1847+
1848+
Args:
1849+
input_stream (IO[bytes]): A stream to read from.
1850+
stage_location (str): The location of the stage to upload to.
1851+
options (dict[str, Any]): The upload options.
1852+
_do_reset (bool, optional): Whether to reset the cursor before
1853+
uploading, by default we will reset the cursor.
1854+
"""
1855+
from .file_transfer_agent import SnowflakeFileTransferAgent
1856+
1857+
if _do_reset:
1858+
self.reset()
1859+
1860+
# Interpret the file operation.
1861+
ret = self.connection._file_operation_parser.parse_file_operation(
1862+
stage_location=stage_location,
1863+
local_file_name=None,
1864+
target_directory=None,
1865+
command_type=CMD_TYPE_UPLOAD,
1866+
options=options,
1867+
has_source_from_stream=input_stream,
1868+
)
1869+
1870+
# Execute the file operation based on the interpretation above.
1871+
file_transfer_agent = SnowflakeFileTransferAgent(
1872+
self,
1873+
"", # empty command because it is triggered by directly calling this util not by a SQL query
1874+
ret,
1875+
source_from_stream=input_stream,
1876+
)
1877+
file_transfer_agent.execute()
1878+
self._init_result_and_meta(file_transfer_agent.result())
1879+
17331880

17341881
class DictCursor(SnowflakeCursor):
17351882
"""Cursor returning results in a dictionary."""
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#
2+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3+
#
4+
5+
from __future__ import annotations
6+
7+
from abc import ABC, abstractmethod
8+
9+
10+
class FileOperationParserBase(ABC):
11+
"""The interface of internal utility functions for file operation parsing."""
12+
13+
@abstractmethod
14+
def __init__(self, connection):
15+
pass
16+
17+
@abstractmethod
18+
def parse_file_operation(
19+
self,
20+
stage_location,
21+
local_file_name,
22+
target_directory,
23+
command_type,
24+
options,
25+
has_source_from_stream=False,
26+
):
27+
"""Converts the file operation details into a SQL and returns the SQL parsing result."""
28+
pass
29+
30+
31+
class StreamDownloaderBase(ABC):
32+
"""The interface of internal utility functions for stream downloading of file."""
33+
34+
@abstractmethod
35+
def __init__(self, connection):
36+
pass
37+
38+
@abstractmethod
39+
def download_as_stream(self, ret, decompress=False):
40+
pass
41+
42+
43+
class FileOperationParser(FileOperationParserBase):
44+
def __init__(self, connection):
45+
pass
46+
47+
def parse_file_operation(
48+
self,
49+
stage_location,
50+
local_file_name,
51+
target_directory,
52+
command_type,
53+
options,
54+
has_source_from_stream=False,
55+
):
56+
raise NotImplementedError("parse_file_operation is not yet supported")
57+
58+
59+
class StreamDownloader(StreamDownloaderBase):
60+
def __init__(self, connection):
61+
pass
62+
63+
def download_as_stream(self, ret, decompress=False):
64+
raise NotImplementedError("download_as_stream is not yet supported")

test/integ/test_connection.py

+9
Original file line numberDiff line numberDiff line change
@@ -1597,3 +1597,12 @@ def test_no_auth_connection_negative_case():
15971597
# connection is not able to run any query
15981598
with pytest.raises(DatabaseError, match="Connection is closed"):
15991599
conn.execute_string("select 1")
1600+
1601+
1602+
# _file_operation_parser and _stream_downloader are newly introduced and
1603+
# therefore should not be tested on old drivers.
1604+
@pytest.mark.skipolddriver
1605+
def test_file_utils_sanity_check():
1606+
conn = create_connection("default")
1607+
assert hasattr(conn._file_operation_parser, "parse_file_operation")
1608+
assert hasattr(conn._stream_downloader, "download_as_stream")

test/unit/test_cursor.py

+94
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import time
8+
from unittest import TestCase
89
from unittest.mock import MagicMock, patch
910

1011
import pytest
@@ -99,3 +100,96 @@ def mock_cmd_query(*args, **kwargs):
99100

100101
# query cancel request should be sent upon timeout
101102
assert mockCancelQuery.called
103+
104+
105+
# The _upload/_download/_upload_stream/_download_stream are newly introduced
106+
# and therefore should not be tested in old drivers.
107+
@pytest.mark.skipolddriver
108+
class TestUploadDownloadMethods(TestCase):
109+
"""Test the _upload/_download/_upload_stream/_download_stream methods."""
110+
111+
@patch("snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent")
112+
def test_download(self, MockFileTransferAgent):
113+
cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks(
114+
MockFileTransferAgent
115+
)
116+
117+
# Call _download method
118+
cursor._download("@st", "/tmp/test.txt", {})
119+
120+
# In the process of _download execution, we expect these methods to be called
121+
# - parse_file_operation in connection._file_operation_parser
122+
# - execute in SnowflakeFileTransferAgent
123+
# And we do not expect this method to be involved
124+
# - download_as_stream of connection._stream_downloader
125+
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
126+
fake_conn._stream_downloader.download_as_stream.assert_not_called()
127+
mock_file_transfer_agent_instance.execute.assert_called_once()
128+
129+
@patch("snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent")
130+
def test_upload(self, MockFileTransferAgent):
131+
cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks(
132+
MockFileTransferAgent
133+
)
134+
135+
# Call _upload method
136+
cursor._upload("/tmp/test.txt", "@st", {})
137+
138+
# In the process of _upload execution, we expect these methods to be called
139+
# - parse_file_operation in connection._file_operation_parser
140+
# - execute in SnowflakeFileTransferAgent
141+
# And we do not expect this method to be involved
142+
# - download_as_stream of connection._stream_downloader
143+
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
144+
fake_conn._stream_downloader.download_as_stream.assert_not_called()
145+
mock_file_transfer_agent_instance.execute.assert_called_once()
146+
147+
@patch("snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent")
148+
def test_download_stream(self, MockFileTransferAgent):
149+
cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks(
150+
MockFileTransferAgent
151+
)
152+
153+
# Call _download_stream method
154+
cursor._download_stream("@st/test.txt", decompress=True)
155+
156+
# In the process of _download_stream execution, we expect these methods to be called
157+
# - parse_file_operation in connection._file_operation_parser
158+
# - download_as_stream of connection._stream_downloader
159+
# And we do not expect this method to be involved
160+
# - execute in SnowflakeFileTransferAgent
161+
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
162+
fake_conn._stream_downloader.download_as_stream.assert_called_once()
163+
mock_file_transfer_agent_instance.execute.assert_not_called()
164+
165+
@patch("snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent")
166+
def test_upload_stream(self, MockFileTransferAgent):
167+
cursor, fake_conn, mock_file_transfer_agent_instance = self._setup_mocks(
168+
MockFileTransferAgent
169+
)
170+
171+
# Call _upload_stream method
172+
fd = MagicMock()
173+
cursor._upload_stream(fd, "@st/test.txt", {})
174+
175+
# In the process of _upload_stream execution, we expect these methods to be called
176+
# - parse_file_operation in connection._file_operation_parser
177+
# - execute in SnowflakeFileTransferAgent
178+
# And we do not expect this method to be involved
179+
# - download_as_stream of connection._stream_downloader
180+
fake_conn._file_operation_parser.parse_file_operation.assert_called_once()
181+
fake_conn._stream_downloader.download_as_stream.assert_not_called()
182+
mock_file_transfer_agent_instance.execute.assert_called_once()
183+
184+
def _setup_mocks(self, MockFileTransferAgent):
185+
mock_file_transfer_agent_instance = MockFileTransferAgent.return_value
186+
mock_file_transfer_agent_instance.execute.return_value = None
187+
188+
fake_conn = FakeConnection()
189+
fake_conn._file_operation_parser = MagicMock()
190+
fake_conn._stream_downloader = MagicMock()
191+
192+
cursor = SnowflakeCursor(fake_conn)
193+
cursor.reset = MagicMock()
194+
cursor._init_result_and_meta = MagicMock()
195+
return cursor, fake_conn, mock_file_transfer_agent_instance

0 commit comments

Comments
 (0)