Skip to content

Commit

Permalink
fix(matrix): allow weird matrix format at the import
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBelthle committed Jan 20, 2025
1 parent d14454a commit de23945
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 15 deletions.
30 changes: 16 additions & 14 deletions antarest/study/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import base64
import collections
import contextlib
import csv
import http
import io
import logging
Expand All @@ -25,6 +24,7 @@
from uuid import uuid4

import numpy as np
import numpy.typing as npt
import pandas as pd
from antares.study.version import StudyVersion
from fastapi import HTTPException, UploadFile
Expand Down Expand Up @@ -197,6 +197,20 @@ def get_disk_usage(path: t.Union[str, Path]) -> int:
return total_size


def _imports_matrix_from_bytes(data: bytes) -> npt.NDArray[np.float64]:
"""Tries to convert bytes to a numpy array when importing a matrix"""
str_data = data.decode("utf-8")
if not str_data:
return np.zeros(shape=(0, 0))
for delimiter in [",", ";", "\t"]:
with contextlib.suppress(Exception):
df = pd.read_csv(io.BytesIO(data), delimiter=delimiter, header=None).replace(",", ".", regex=True)
df = df.dropna(axis=1, how="all") # We want to remove columns full of NaN at the import
matrix = df.to_numpy(dtype=np.float64)
return matrix
raise ValueError("Could not import the matrix")


def _get_path_inside_user_folder(
path: str, exception_class: t.Type[t.Union[FolderCreationNotAllowed, ResourceDeletionNotAllowed]]
) -> str:
Expand Down Expand Up @@ -1591,19 +1605,7 @@ def _create_edit_study_command(
elif isinstance(tree_node, InputSeriesMatrix):
if isinstance(data, bytes):
# noinspection PyTypeChecker
str_data = data.decode("utf-8")
if not str_data:
matrix = np.zeros(shape=(0, 0))
else:
size_to_check = min(len(str_data), 64) # sniff a chunk only to speed up the code
try:
delimiter = csv.Sniffer().sniff(str_data[:size_to_check], delimiters=r"[,;\t]").delimiter
except csv.Error:
# Can happen with data with only one column. In this case, we don't care about the delimiter.
delimiter = "\t"
df = pd.read_csv(io.BytesIO(data), delimiter=delimiter, header=None).replace(",", ".", regex=True)
df = df.dropna(axis=1, how="all") # We want to remove columns full of NaN at the import
matrix = df.to_numpy(dtype=np.float64)
matrix = _imports_matrix_from_bytes(data)
matrix = matrix.reshape((1, 0)) if matrix.size == 0 else matrix
return ReplaceMatrix(
target=url, matrix=matrix.tolist(), command_context=context, study_version=study_version
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,9 @@ def test_get_study_data(self, client: TestClient, user_access_token: str, intern
b"1;1;1;1;1\r1;1;1;1;1",
b"0,000000;0,000000;0,000000;0,000000\n0,000000;0,000000;0,000000;0,000000",
b"1;2;3;;;\n4;5;6;;;\n",
b"1;1;1;1\r\n1;1;1;1\r\n1;1;1;1\r\n1;1;1;1\r\n1;1;1;1\r\n1;1;1;1\r\n1;1;1;1\r\n1;1;1;1\r\n",
],
["\t", "\t", ",", "\t", ";", ";", ";", ";"],
["\t", "\t", ",", "\t", ";", ";", ";", ";", ";"],
):
res = client.put(raw_url, params={"path": matrix_path}, files={"file": io.BytesIO(content)})
assert res.status_code == 204, res.json()
Expand Down

0 comments on commit de23945

Please sign in to comment.