Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions deepethogram/feature_extractor/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,11 @@
"and test dataloaders.",
)

plt.switch_backend("agg")

log = logging.getLogger(__name__)


def feature_extractor_train(cfg: DictConfig) -> nn.Module:
plt.switch_backend("agg")
"""Trains feature extractor models from a configuration.

Parameters
Expand Down
3 changes: 1 addition & 2 deletions deepethogram/flow_generator/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,11 @@

flow_generators = utils.get_models_from_module(models, get_function=False)

plt.switch_backend("agg")

log = logging.getLogger(__name__)


def flow_generator_train(cfg: DictConfig) -> nn.Module:
plt.switch_backend("agg")
"""Trains flow generator models from a configuration.

Parameters
Expand Down
7 changes: 6 additions & 1 deletion deepethogram/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,12 @@ def add_label_to_project(path_to_labels: Union[str, os.PathLike], path_to_video)
if os.path.isfile(label_dst):
warnings.warn("Label already exists in destination {}, overwriting...".format(label_dst))

df = pd.read_csv(path_to_labels, index_col=0)
df = pd.read_csv(path_to_labels)
# Drop unnamed index column if present (DEG-generated CSVs have one)
first_col = df.columns[0]
if first_col == "" or str(first_col).startswith("Unnamed"):
Comment thread
jbohnslav marked this conversation as resolved.
df = df.drop(columns=[first_col])

if "none" in list(df.columns):
df = df.rename(columns={"none": "background"})
if "background" not in list(df.columns):
Expand Down
3 changes: 1 addition & 2 deletions deepethogram/sequence/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@

log = logging.getLogger(__name__)

plt.switch_backend("agg")


def sequence_train(cfg: DictConfig) -> nn.Module:
plt.switch_backend("agg")
"""Trains sequence models from a configuration.

Parameters
Expand Down
75 changes: 75 additions & 0 deletions tests/test_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,5 +115,80 @@ def test_add_external_label():
projects.add_label_to_project(labelfile, videofile)


@pytest.mark.filterwarnings("ignore::UserWarning")
def test_add_label_deg_style_csv(tmp_path):
"""Test add_label_to_project with DEG-generated CSV (has unnamed index column)."""
make_project_from_archive()
mousedir = os.path.join(project_path, "DATA", "mouse06")
videofile = os.path.join(mousedir, "mouse06.h5")

# Create a DEG-style CSV with unnamed numeric index
csv_path = tmp_path / "labels_with_index.csv"
csv_path.write_text(
",background,behavior1,behavior2\n"
"0,1,0,0\n"
"1,0,1,0\n"
"2,0,0,1\n"
)

result = projects.add_label_to_project(str(csv_path), videofile)
df = pd.read_csv(result, index_col=0)
assert "background" in df.columns
assert "behavior1" in df.columns
assert "behavior2" in df.columns
assert df.shape[1] == 3 # background + 2 behaviors


@pytest.mark.filterwarnings("ignore::UserWarning")
def test_add_label_external_csv_no_index(tmp_path):
"""Test add_label_to_project with external CSV (no index column, no background).

Regression test for GitHub issue #116: the old code used index_col=0 which
silently ate the first data column when no index column was present.
"""
make_project_from_archive()
mousedir = os.path.join(project_path, "DATA", "mouse06")
videofile = os.path.join(mousedir, "mouse06.h5")

# Create a user-provided CSV without index or background column
csv_path = tmp_path / "labels_no_index.csv"
csv_path.write_text(
"behavior1,behavior2\n"
"0,0\n"
"1,0\n"
"0,1\n"
)

result = projects.add_label_to_project(str(csv_path), videofile)
df = pd.read_csv(result, index_col=0)
assert "background" in df.columns, "background column should be auto-inserted"
assert "behavior1" in df.columns, "behavior1 should NOT be eaten by index_col"
assert "behavior2" in df.columns
assert df.shape[1] == 3 # background + behavior1 + behavior2


@pytest.mark.filterwarnings("ignore::UserWarning")
def test_add_label_external_csv_with_background_no_index(tmp_path):
"""Test external CSV that has background but no index column."""
make_project_from_archive()
mousedir = os.path.join(project_path, "DATA", "mouse06")
videofile = os.path.join(mousedir, "mouse06.h5")

csv_path = tmp_path / "labels_bg_no_index.csv"
csv_path.write_text(
"background,behavior1,behavior2\n"
"1,0,0\n"
"0,1,0\n"
"0,0,1\n"
)

result = projects.add_label_to_project(str(csv_path), videofile)
df = pd.read_csv(result, index_col=0)
assert "background" in df.columns
assert "behavior1" in df.columns
assert "behavior2" in df.columns
assert df.shape[1] == 3


if __name__ == "__main__":
test_add_external_label()
Loading