Skip to content

Commit

Permalink
chore: ruff fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
marksweb committed Jan 11, 2024
1 parent 79b63d9 commit 942c54b
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 24 deletions.
11 changes: 7 additions & 4 deletions explorer/charts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.figure import Figure
except ImportError:
except ImportError as err:
raise ImproperlyConfigured(
"If `EXPLORER_CHARTS_ENABLED` is enabled, `matplotlib` and `seaborn` must be installed.")
"If `EXPLORER_CHARTS_ENABLED` is enabled, `matplotlib` and `seaborn` must be installed."
) from err

from .models import QueryResult

Expand Down Expand Up @@ -59,8 +60,10 @@ def get_line_chart(result: QueryResult) -> Optional[str]:
"""
if len(result.data) < 1:
return None
numeric_columns = [c for c in range(1, len(result.data[0]))
if all([isinstance(col[c], (int, float)) or col[c] is None for col in result.data])]
numeric_columns = [
c for c in range(1, len(result.data[0]))
if all([isinstance(col[c], (int, float)) or col[c] is None for col in result.data])
]
if len(numeric_columns) < 1:
return None
labels = [row[0] for row in result.data]
Expand Down
10 changes: 5 additions & 5 deletions explorer/exporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


def get_exporter_class(format):
class_str = dict(getattr(app_settings, "EXPLORER_DATA_EXPORTERS"))[format]
class_str = dict(app_settings.EXPLORER_DATA_EXPORTERS)[format]
return import_string(class_str)


Expand Down Expand Up @@ -61,7 +61,7 @@ def _get_output(self, res, **kwargs):
writer = csv.writer(csv_data, delimiter=delim)
writer.writerow(res.headers)
for row in res.data:
writer.writerow([s for s in row])
writer.writerow(row)
return csv_data


Expand Down Expand Up @@ -115,10 +115,10 @@ def _get_output(self, res, **kwargs):
# xlsxwriter can't handle timezone-aware datetimes or
# UUIDs, so we help out here and just cast it to a
# string
if isinstance(data, datetime) or isinstance(data, uuid.UUID):
if isinstance(data, (datetime, uuid.UUID)):
data = str(data)
# JSON and Array fields
if isinstance(data, dict) or isinstance(data, list):
if isinstance(data, (dict, list)):
data = json.dumps(data)
ws.write(row, col, data)
col += 1
Expand All @@ -129,7 +129,7 @@ def _get_output(self, res, **kwargs):
return output

def _format_title(self):
# XLSX writer wont allow sheet names > 31 characters or that
# XLSX writer won't allow sheet names > 31 characters or that
# contain invalid characters
# https://github.com/jmcnamara/XlsxWriter/blob/master/xlsxwriter/
# test/workbook/test_check_sheetname.py
Expand Down
5 changes: 4 additions & 1 deletion explorer/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,10 @@ def __init__(self, header, col):
ColumnStat("Avg", lambda x: float(sum(x)) / float(len(x))),
ColumnStat("Min", min),
ColumnStat("Max", max),
ColumnStat("NUL", lambda x: int(sum(map(lambda y: 1 if y is None else 0, x))), 0, True)
ColumnStat(
"NUL",
lambda x: int(sum(map(lambda y: 1 if y is None else 0, x))), 0, True
)
]
without_nulls = list(map(lambda x: 0 if x is None else x, col))

Expand Down
2 changes: 1 addition & 1 deletion explorer/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_get_run_count(self):
q = SimpleQueryFactory()
self.assertEqual(q.get_run_count(), 0)
expected = 4
for i in range(0, expected):
for _ in range(0, expected):
q.log()
self.assertEqual(q.get_run_count(), expected)

Expand Down
24 changes: 12 additions & 12 deletions explorer/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_select_containing_drop_in_word(self):
self.assertTrue(passes_blacklist(sql)[0])

def test_select_with_case(self):
sql = '''SELECT ProductNumber, Name, "Price Range" =
sql = """SELECT ProductNumber, Name, "Price Range" =
CASE
WHEN ListPrice = 0 THEN 'Mfg item - not for resale'
WHEN ListPrice < 50 THEN 'Under $50'
Expand All @@ -50,18 +50,18 @@ def test_select_with_case(self):
END
FROM Production.Product
ORDER BY ProductNumber ;
'''
"""
passes, words = passes_blacklist(sql)
self.assertTrue(passes)

def test_select_with_subselect(self):
sql = '''SELECT a.studentid, a.name, b.total_marks
sql = """SELECT a.studentid, a.name, b.total_marks
FROM student a, marks b
WHERE a.studentid = b.studentid AND b.total_marks >
(SELECT total_marks
FROM marks
WHERE studentid = 'V002');
'''
"""
passes, words = passes_blacklist(sql)
self.assertTrue(passes)

Expand All @@ -88,13 +88,13 @@ def test_dml_insert(self):
self.assertFalse(passes)

def test_dml_merge(self):
sql = '''MERGE INTO wines w
sql = """MERGE INTO wines w
USING (VALUES('Chateau Lafite 2003', '24')) v
ON v.column1 = w.winename
WHEN NOT MATCHED
INSERT VALUES(v.column1, v.column2)
WHEN MATCHED
UPDATE SET stock = stock + v.column2;'''
UPDATE SET stock = stock + v.column2;"""
passes, words = passes_blacklist(sql)
self.assertFalse(passes)

Expand All @@ -119,9 +119,9 @@ def test_dml_start(self):
self.assertFalse(passes)

def test_dml_update(self):
sql = '''UPDATE accounts SET (contact_first_name, contact_last_name) =
sql = """UPDATE accounts SET (contact_first_name, contact_last_name) =
(SELECT first_name, last_name FROM employees
WHERE employees.id = accounts.sales_person);'''
WHERE employees.id = accounts.sales_person);"""
passes, words = passes_blacklist(sql)
self.assertFalse(passes)

Expand All @@ -131,24 +131,24 @@ def test_dml_upsert(self):
self.assertFalse(passes)

def test_ddl_alter(self):
sql = '''ALTER TABLE foo
sql = """ALTER TABLE foo
ALTER COLUMN foo_timestamp DROP DEFAULT,
ALTER COLUMN foo_timestamp TYPE timestamp with time zone
USING
timestamp with time zone 'epoch' + foo_timestamp * interval '1 second',
ALTER COLUMN foo_timestamp SET DEFAULT now();'''
ALTER COLUMN foo_timestamp SET DEFAULT now();"""
passes, words = passes_blacklist(sql)
self.assertFalse(passes)

def test_ddl_create(self):
sql = '''CREATE TABLE Persons (
sql = """CREATE TABLE Persons (
PersonID int,
LastName varchar(255),
FirstName varchar(255),
Address varchar(255),
City varchar(255)
);
'''
"""
passes, words = passes_blacklist(sql)
self.assertFalse(passes)

Expand Down
2 changes: 1 addition & 1 deletion explorer/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_permissions_show_only_allowed_queries(self):

def test_run_count(self):
q = SimpleQueryFactory(title="foo - bar1")
for i in range(0, 4):
for _ in range(0, 4):
q.log()
resp = self.client.get(reverse("explorer_index"))
self.assertContains(resp, "<td>4</td>")
Expand Down

0 comments on commit 942c54b

Please sign in to comment.