From 942c54bf3d56d41cf0a23662e9b71385a78f88e4 Mon Sep 17 00:00:00 2001 From: Mark Walker Date: Thu, 11 Jan 2024 23:40:59 +0000 Subject: [PATCH] chore: ruff fixes --- explorer/charts.py | 11 +++++++---- explorer/exporters.py | 10 +++++----- explorer/models.py | 5 ++++- explorer/tests/test_models.py | 2 +- explorer/tests/test_utils.py | 24 ++++++++++++------------ explorer/tests/test_views.py | 2 +- 6 files changed, 30 insertions(+), 24 deletions(-) diff --git a/explorer/charts.py b/explorer/charts.py index 140fdaf6..a9362ded 100644 --- a/explorer/charts.py +++ b/explorer/charts.py @@ -11,9 +11,10 @@ import matplotlib.pyplot as plt import seaborn as sns from matplotlib.figure import Figure - except ImportError: + except ImportError as err: raise ImproperlyConfigured( - "If `EXPLORER_CHARTS_ENABLED` is enabled, `matplotlib` and `seaborn` must be installed.") + "If `EXPLORER_CHARTS_ENABLED` is enabled, `matplotlib` and `seaborn` must be installed." + ) from err from .models import QueryResult @@ -59,8 +60,10 @@ def get_line_chart(result: QueryResult) -> Optional[str]: """ if len(result.data) < 1: return None - numeric_columns = [c for c in range(1, len(result.data[0])) - if all([isinstance(col[c], (int, float)) or col[c] is None for col in result.data])] + numeric_columns = [ + c for c in range(1, len(result.data[0])) + if all([isinstance(col[c], (int, float)) or col[c] is None for col in result.data]) + ] if len(numeric_columns) < 1: return None labels = [row[0] for row in result.data] diff --git a/explorer/exporters.py b/explorer/exporters.py index 0e884039..22d1bc5b 100644 --- a/explorer/exporters.py +++ b/explorer/exporters.py @@ -13,7 +13,7 @@ def get_exporter_class(format): - class_str = dict(getattr(app_settings, "EXPLORER_DATA_EXPORTERS"))[format] + class_str = dict(app_settings.EXPLORER_DATA_EXPORTERS)[format] return import_string(class_str) @@ -61,7 +61,7 @@ def _get_output(self, res, **kwargs): writer = csv.writer(csv_data, delimiter=delim) writer.writerow(res.headers) for row in res.data: - writer.writerow([s for s in row]) + writer.writerow(row) return csv_data @@ -115,10 +115,10 @@ def _get_output(self, res, **kwargs): # xlsxwriter can't handle timezone-aware datetimes or # UUIDs, so we help out here and just cast it to a # string - if isinstance(data, datetime) or isinstance(data, uuid.UUID): + if isinstance(data, (datetime, uuid.UUID)): data = str(data) # JSON and Array fields - if isinstance(data, dict) or isinstance(data, list): + if isinstance(data, (dict, list)): data = json.dumps(data) ws.write(row, col, data) col += 1 @@ -129,7 +129,7 @@ def _get_output(self, res, **kwargs): return output def _format_title(self): - # XLSX writer wont allow sheet names > 31 characters or that + # XLSX writer won't allow sheet names > 31 characters or that # contain invalid characters # https://github.com/jmcnamara/XlsxWriter/blob/master/xlsxwriter/ # test/workbook/test_check_sheetname.py diff --git a/explorer/models.py b/explorer/models.py index 0d9f61ea..f8472131 100644 --- a/explorer/models.py +++ b/explorer/models.py @@ -374,7 +374,10 @@ def __init__(self, header, col): ColumnStat("Avg", lambda x: float(sum(x)) / float(len(x))), ColumnStat("Min", min), ColumnStat("Max", max), - ColumnStat("NUL", lambda x: int(sum(map(lambda y: 1 if y is None else 0, x))), 0, True) + ColumnStat( + "NUL", + lambda x: int(sum(map(lambda y: 1 if y is None else 0, x))), 0, True + ) ] without_nulls = list(map(lambda x: 0 if x is None else x, col)) diff --git a/explorer/tests/test_models.py b/explorer/tests/test_models.py index aac01b7b..b2ffebb7 100644 --- a/explorer/tests/test_models.py +++ b/explorer/tests/test_models.py @@ -66,7 +66,7 @@ def test_get_run_count(self): q = SimpleQueryFactory() self.assertEqual(q.get_run_count(), 0) expected = 4 - for i in range(0, expected): + for _ in range(0, expected): q.log() self.assertEqual(q.get_run_count(), expected) diff --git a/explorer/tests/test_utils.py b/explorer/tests/test_utils.py index 46722509..1b2ebf8b 100644 --- a/explorer/tests/test_utils.py +++ b/explorer/tests/test_utils.py @@ -40,7 +40,7 @@ def test_select_containing_drop_in_word(self): self.assertTrue(passes_blacklist(sql)[0]) def test_select_with_case(self): - sql = '''SELECT ProductNumber, Name, "Price Range" = + sql = """SELECT ProductNumber, Name, "Price Range" = CASE WHEN ListPrice = 0 THEN 'Mfg item - not for resale' WHEN ListPrice < 50 THEN 'Under $50' @@ -50,18 +50,18 @@ def test_select_with_case(self): END FROM Production.Product ORDER BY ProductNumber ; - ''' + """ passes, words = passes_blacklist(sql) self.assertTrue(passes) def test_select_with_subselect(self): - sql = '''SELECT a.studentid, a.name, b.total_marks + sql = """SELECT a.studentid, a.name, b.total_marks FROM student a, marks b WHERE a.studentid = b.studentid AND b.total_marks > (SELECT total_marks FROM marks WHERE studentid = 'V002'); - ''' + """ passes, words = passes_blacklist(sql) self.assertTrue(passes) @@ -88,13 +88,13 @@ def test_dml_insert(self): self.assertFalse(passes) def test_dml_merge(self): - sql = '''MERGE INTO wines w + sql = """MERGE INTO wines w USING (VALUES('Chateau Lafite 2003', '24')) v ON v.column1 = w.winename WHEN NOT MATCHED INSERT VALUES(v.column1, v.column2) WHEN MATCHED - UPDATE SET stock = stock + v.column2;''' + UPDATE SET stock = stock + v.column2;""" passes, words = passes_blacklist(sql) self.assertFalse(passes) @@ -119,9 +119,9 @@ def test_dml_start(self): self.assertFalse(passes) def test_dml_update(self): - sql = '''UPDATE accounts SET (contact_first_name, contact_last_name) = + sql = """UPDATE accounts SET (contact_first_name, contact_last_name) = (SELECT first_name, last_name FROM employees - WHERE employees.id = accounts.sales_person);''' + WHERE employees.id = accounts.sales_person);""" passes, words = passes_blacklist(sql) self.assertFalse(passes) @@ -131,24 +131,24 @@ def test_dml_upsert(self): self.assertFalse(passes) def test_ddl_alter(self): - sql = '''ALTER TABLE foo + sql = """ALTER TABLE foo ALTER COLUMN foo_timestamp DROP DEFAULT, ALTER COLUMN foo_timestamp TYPE timestamp with time zone USING timestamp with time zone 'epoch' + foo_timestamp * interval '1 second', - ALTER COLUMN foo_timestamp SET DEFAULT now();''' + ALTER COLUMN foo_timestamp SET DEFAULT now();""" passes, words = passes_blacklist(sql) self.assertFalse(passes) def test_ddl_create(self): - sql = '''CREATE TABLE Persons ( + sql = """CREATE TABLE Persons ( PersonID int, LastName varchar(255), FirstName varchar(255), Address varchar(255), City varchar(255) ); - ''' + """ passes, words = passes_blacklist(sql) self.assertFalse(passes) diff --git a/explorer/tests/test_views.py b/explorer/tests/test_views.py index 2bddda8b..df8574dc 100644 --- a/explorer/tests/test_views.py +++ b/explorer/tests/test_views.py @@ -65,7 +65,7 @@ def test_permissions_show_only_allowed_queries(self): def test_run_count(self): q = SimpleQueryFactory(title="foo - bar1") - for i in range(0, 4): + for _ in range(0, 4): q.log() resp = self.client.get(reverse("explorer_index")) self.assertContains(resp, "4")