diff --git a/datasette/app.py b/datasette/app.py index 0f417ec958..466a5c4812 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -1225,10 +1225,10 @@ def _prepare_connection(self, conn, database): for db_name, db in self.databases.items(): if count >= SQLITE_LIMIT_ATTACHED or db.is_memory: continue - sql = 'ATTACH DATABASE "file:{path}?{qs}" AS [{name}];'.format( + sql = 'ATTACH DATABASE "file:{path}?{qs}" AS {name};'.format( path=db.path, qs="mode=ro" if db.is_mutable else "immutable=1", - name=db_name, + name=escape_sqlite(db_name), ) conn.execute(sql) count += 1 diff --git a/datasette/database.py b/datasette/database.py index 8b824462a3..9781bd8ffc 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -17,6 +17,7 @@ detect_fts, detect_primary_keys, detect_spatialite, + escape_sqlite, get_all_foreign_keys, get_outbound_foreign_keys, md5_not_usedforsecurity, @@ -470,7 +471,7 @@ async def table_counts(self, limit=10): try: table_count = ( await self.execute( - f"select count(*) from (select * from [{table}] limit {self.count_limit + 1})", + f"select count(*) from (select * from {escape_sqlite(table)} limit {self.count_limit + 1})", custom_time_limit=limit, ) ).rows[0][0] diff --git a/datasette/facets.py b/datasette/facets.py index bc4b69049a..a957e69e31 100644 --- a/datasette/facets.py +++ b/datasette/facets.py @@ -85,7 +85,7 @@ def __init__( self.database = database # For foreign key expansion. Can be None for e.g. canned SQL queries: self.table = table - self.sql = sql or f"select * from [{table}]" + self.sql = sql or f"select * from {escape_sqlite(table)}" self.params = params or [] self.table_config = table_config # row_count can be None, in which case we calculate it ourselves: diff --git a/datasette/filters.py b/datasette/filters.py index 95cc5f3748..8ebf27a514 100644 --- a/datasette/filters.py +++ b/datasette/filters.py @@ -206,10 +206,16 @@ def where_clause(self, table, column, value, param_counter): if self.numeric and converted.isdigit(): converted = int(converted) if self.no_argument: - kwargs = {"c": column} + kwargs = {"c": column, "c_escaped": escape_sqlite(column)} converted = None else: - kwargs = {"c": column, "p": f"p{param_counter}", "t": table} + kwargs = { + "c": column, + "c_escaped": escape_sqlite(column), + "p": f"p{param_counter}", + "t": table, + "t_escaped": escape_sqlite(table), + } return self.sql_template.format(**kwargs), converted def human_clause(self, column, value): @@ -322,13 +328,13 @@ class Filters: TemplatedFilter( "arraycontains", "array contains", - """:{p} in (select value from json_each([{t}].[{c}]))""", + """:{p} in (select value from json_each({t_escaped}.{c_escaped}))""", '{c} contains "{v}"', ), TemplatedFilter( "arraynotcontains", "array does not contain", - """:{p} not in (select value from json_each([{t}].[{c}]))""", + """:{p} not in (select value from json_each({t_escaped}.{c_escaped}))""", '{c} does not contain "{v}"', ), ] diff --git a/datasette/utils/__init__.py b/datasette/utils/__init__.py index 1fea992ed9..c93869411b 100644 --- a/datasette/utils/__init__.py +++ b/datasette/utils/__init__.py @@ -405,7 +405,7 @@ def escape_sqlite(s): if _boring_keyword_re.match(s) and (s.lower() not in reserved_words): return s else: - return f"[{s}]" + return '"{}"'.format(s.replace('"', '""')) def make_dockerfile( @@ -583,7 +583,7 @@ def detect_primary_keys(conn, table): def get_outbound_foreign_keys(conn, table): - infos = conn.execute(f"PRAGMA foreign_key_list([{table}])").fetchall() + infos = conn.execute(f"PRAGMA foreign_key_list({escape_sqlite(table)})").fetchall() fks = [] for info in infos: if info is not None: diff --git a/datasette/utils/internal_db.py b/datasette/utils/internal_db.py index df1499283c..ca137df877 100644 --- a/datasette/utils/internal_db.py +++ b/datasette/utils/internal_db.py @@ -1,5 +1,5 @@ import textwrap -from datasette.utils import table_column_details +from datasette.utils import escape_sqlite, table_column_details async def init_internal_db(db): @@ -168,7 +168,7 @@ def collect_info(conn): for column in columns ) foreign_keys = conn.execute( - f"PRAGMA foreign_key_list([{table_name}])" + f"PRAGMA foreign_key_list({escape_sqlite(table_name)})" ).fetchall() foreign_keys_to_insert.extend( { @@ -177,7 +177,9 @@ def collect_info(conn): } for foreign_key in foreign_keys ) - indexes = conn.execute(f"PRAGMA index_list([{table_name}])").fetchall() + indexes = conn.execute( + f"PRAGMA index_list({escape_sqlite(table_name)})" + ).fetchall() indexes_to_insert.extend( { **{"database_name": database_name, "table_name": table_name}, diff --git a/datasette/views/table.py b/datasette/views/table.py index 5643858d4c..803e58a8e1 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -610,8 +610,10 @@ def insert_or_upsert_rows(conn): ) args = list(itertools.chain.from_iterable(row_pk_values_for_later)) fetched_rows = await db.execute( - "select {}* from [{}] where {}".format( - "rowid, " if pks == ["rowid"] else "", table_name, where_clause + "select {}* from {} where {}".format( + "rowid, " if pks == ["rowid"] else "", + escape_sqlite(table_name), + where_clause, ), args, ) @@ -822,7 +824,11 @@ async def post(self, request): "database": database_name, "table": table_name, "row_count": ( - await db.execute("select count(*) from [{}]".format(table_name)) + await db.execute( + "select count(*) from {}".format( + escape_sqlite(table_name) + ) + ) ).single_value(), "message": 'Pass "confirm": true to confirm', }, @@ -2091,10 +2097,13 @@ async def _next_value_and_url( except IndexError: # sort/sort_desc column missing from SELECT - look up value by PK instead prefix_where_clause = " and ".join( - "[{}] = :pk{}".format(pk, i) for i, pk in enumerate(pks) + "{} = :pk{}".format(escape_sqlite(pk), i) + for i, pk in enumerate(pks) ) - prefix_lookup_sql = "select [{}] from [{}] where {}".format( - sort or sort_desc, table_name, prefix_where_clause + prefix_lookup_sql = "select {} from {} where {}".format( + escape_sqlite(sort or sort_desc), + escape_sqlite(table_name), + prefix_where_clause, ) prefix = ( await db.execute( diff --git a/tests/test_api.py b/tests/test_api.py index 3676c1fb8d..ac254d70d3 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,5 +1,6 @@ from datasette.app import Datasette from datasette.plugins import DEFAULT_PLUGINS +from datasette.utils import tilde_encode from datasette.utils.sqlite import sqlite_version from datasette.version import __version__ from .fixtures import make_app_client, EXPECTED_PLUGINS @@ -379,6 +380,42 @@ async def test_row_strange_table_name(ds_client): assert response.json()["rows"] == [{"pk": "3", "content": "hey"}] +@pytest.mark.asyncio +async def test_table_name_with_closing_bracket_does_not_inject(): + malicious_name = 'users] UNION SELECT password FROM users--' + ds = Datasette() + db = ds.add_memory_database("fixtures") + await db.execute_write("CREATE TABLE users (id INTEGER PRIMARY KEY, password TEXT)") + await db.execute_write( + "INSERT INTO users (password) VALUES ('super_secret_password')" + ) + await db.execute_write( + f'CREATE TABLE "{malicious_name}" (id INTEGER PRIMARY KEY, content TEXT)' + ) + await db.execute_write( + f'INSERT INTO "{malicious_name}" (content) VALUES (\'ok\')' + ) + response = await ds.client.get( + f"/fixtures/{tilde_encode(malicious_name)}.json?_extra=count&_facet=content&_shape=objects" + ) + assert response.status_code == 200 + data = response.json() + assert data["count"] == 1 + assert data["rows"] == [{"id": 1, "content": "ok"}] + assert data["facet_results"]["results"]["content"]["results"] == [ + { + "value": "ok", + "label": "ok", + "count": 1, + "toggle_url": ( + f"http://localhost/fixtures/{tilde_encode(malicious_name)}.json" + "?_extra=count&_facet=content&_shape=objects&content=ok" + ), + "selected": False, + } + ] + + @pytest.mark.asyncio async def test_row_foreign_key_tables(ds_client): response = await ds_client.get( diff --git a/tests/test_crossdb.py b/tests/test_crossdb.py index 11e53224ef..e47a217d1c 100644 --- a/tests/test_crossdb.py +++ b/tests/test_crossdb.py @@ -1,5 +1,7 @@ +from datasette.app import Datasette from datasette.cli import cli from click.testing import CliRunner +import pytest import urllib import sqlite3 @@ -75,3 +77,22 @@ def test_crossdb_attached_database_list_display( '
  • extra database - 100")).first() diff --git a/tests/test_utils.py b/tests/test_utils.py index 3fcb623ee7..909826f87f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -226,6 +226,31 @@ def test_detect_fts_different_table_names(table): conn.close() +def test_escape_sqlite_prevents_closing_bracket_sql_injection(): + conn = utils.sqlite3.connect(":memory:") + conn.execute("CREATE TABLE users (id INTEGER, password TEXT)") + conn.execute("INSERT INTO users VALUES (1, 'super_secret_password')") + malicious_name = 'users] UNION SELECT password FROM users--' + conn.execute(f'CREATE TABLE "{malicious_name}" (id INTEGER)') + escaped = utils.escape_sqlite(malicious_name) + results = conn.execute(f"select count(*) from {escaped}").fetchall() + assert results == [(0,)] + conn.close() + + +@pytest.mark.parametrize( + "value,expected", + ( + ("simple", "simple"), + ("select", '"select"'), + ('has"quote', '"has""quote"'), + ("has]bracket", '"has]bracket"'), + ), +) +def test_escape_sqlite(value, expected): + assert utils.escape_sqlite(value) == expected + + @pytest.mark.parametrize( "url,expected", [