Skip to content

Commit

Permalink
Apply linting
Browse files Browse the repository at this point in the history
  • Loading branch information
cthoyt committed Apr 4, 2022
1 parent 6d4fc85 commit e4a2a36
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 24 deletions.
7 changes: 2 additions & 5 deletions scripts/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

import click
import pandas as pd
from constants import COLLATED_PATH, MELTED_PATH, MODEL_DIRECTORY
from docdata import get_docdata
from pykeen.datasets import dataset_resolver

from constants import COLLATED_PATH, MELTED_PATH, MODEL_DIRECTORY


@click.command()
def collate():
Expand Down Expand Up @@ -54,9 +53,7 @@ def collate():
df,
id_vars=id_vars,
value_vars=[
v
for v in df.columns
if v not in id_vars and v not in ("evaluation", "training")
v for v in df.columns if v not in id_vars and v not in ("evaluation", "training")
],
).sort_values(by=["dataset", "model", "variable"])
melted_df.to_csv(MELTED_PATH, sep="\t", index=False)
Expand Down
9 changes: 2 additions & 7 deletions scripts/fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

import numpy
import pandas
from constants import COLLATED_PATH, MELTED_PATH
from pykeen.datasets import get_dataset
from pykeen.evaluation.evaluator import get_candidate_set_size
from pykeen.metrics import rank_based_metric_resolver

from constants import COLLATED_PATH, MELTED_PATH

df = pandas.read_csv(COLLATED_PATH, sep="\t")
pairs = {}
for column in df.columns:
Expand Down Expand Up @@ -59,10 +58,6 @@
melted_df = pandas.melt(
df,
id_vars=id_vars,
value_vars=[
v
for v in df.columns
if v not in id_vars and v not in ("evaluation", "training")
],
value_vars=[v for v in df.columns if v not in id_vars and v not in ("evaluation", "training")],
).sort_values(by=["dataset", "model", "variable"])
melted_df.to_csv(MELTED_PATH, sep="\t", index=False)
17 changes: 5 additions & 12 deletions scripts/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
import pandas as pd
import scipy.constants
import seaborn as sns

from constants import CHARTS_DIRECTORY, MELTED_PATH
from pykeen.datasets import dataset_resolver

SIGIL = r"\mathcal{T}_{train}"


def _lookup_key(d):
return dataset_resolver.docdata(d, 'statistics', 'training')
return dataset_resolver.docdata(d, "statistics", "training")


MODEL_TITLES = {
Expand All @@ -28,13 +27,11 @@ def _lookup_key(d):
"kinships": "Kinships",
}
DATASET_TITLES = {
key: f"{value} ($|{SIGIL}|={_lookup_key(key):,}$)"
for key, value in DATASET_TITLES.items()
key: f"{value} ($|{SIGIL}|={_lookup_key(key):,}$)" for key, value in DATASET_TITLES.items()
}
# show datasets in increasing order of entity size
DATASET_ORDER = [
v
for _, v in sorted(DATASET_TITLES.items(), key=lambda pair: _lookup_key(pair[0]))
v for _, v in sorted(DATASET_TITLES.items(), key=lambda pair: _lookup_key(pair[0]))
]
ORDER = [
"Original",
Expand Down Expand Up @@ -83,7 +80,7 @@ def _lookup_key(d):
],
"short": ["GMR", "AGMRI", "ZGMR"],
"has_negative_z": True,
}
},
}


Expand All @@ -96,10 +93,7 @@ def main():
for base_metric_key, metadata in METRICS.items():
metrics = metadata["metrics"]
df = melted_df[melted_df["variable"].isin(metrics)].copy()
metric_order = [
f"{order} ({short})"
for order, short in zip(ORDER, metadata["short"])
]
metric_order = [f"{order} ({short})" for order, short in zip(ORDER, metadata["short"])]
df.loc[:, "variable"] = df["variable"].map(dict(zip(metrics, metric_order)))
grid: sns.FacetGrid = sns.catplot(
data=df,
Expand Down Expand Up @@ -136,7 +130,6 @@ def main():
grid.set_xlabels(label="")
grid.set_titles(col_template="{col_name}", size=13)
grid.tight_layout()

chart_path_stub = CHARTS_DIRECTORY.joinpath(f"{base_metric_key}_plot")
grid.savefig(chart_path_stub.with_suffix(".pdf"))
grid.savefig(chart_path_stub.with_suffix(".svg"))
Expand Down
9 changes: 9 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,12 @@ deps =
pandas
seaborn
pykeen

[testenv:lint]
commands =
black . --line-length 100
isort . --profile=black
skip_install = true
deps =
black
isort

0 comments on commit e4a2a36

Please sign in to comment.