Skip to content

Commit 97a8c55

Browse files
authored
Merge pull request #8 from ajschmidt8/fix-path
Fix dependency file paths
2 parents 54eb3c7 + 48c9005 commit 97a8c55

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

src/rapids_dependency_file_generator/rapids_dependency_file_generator.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,13 @@ def get_filename(file_type, file_prefix, matrix_combo):
9696
return f"{prefix}{file_prefix}_{suffix}".replace(".", "") + file_ext
9797

9898

99-
def get_output_path(file_type, file_config):
99+
def get_output_path(file_type, config_file_path, file_config):
100100
output_path = "."
101101
if file_type == str(GeneratorTypes.CONDA):
102102
output_path = file_config.get("conda_dir", default_conda_dir)
103103
if file_type == str(GeneratorTypes.REQUIREMENTS):
104104
output_path = file_config.get("requirements_dir", default_requirements_dir)
105-
return output_path
105+
return os.path.join(os.path.dirname(config_file_path), output_path)
106106

107107

108108
def should_use_specific_entry(matrix_combo, specific_entry_matrix):
@@ -112,8 +112,8 @@ def should_use_specific_entry(matrix_combo, specific_entry_matrix):
112112
return True
113113

114114

115-
def main(config_file, files):
116-
with open(config_file, "r") as f:
115+
def main(config_file_path, files=None):
116+
with open(config_file_path) as f:
117117
parsed_config = yaml.load(f, Loader=yaml.FullLoader)
118118

119119
channels = parsed_config.get("channels", default_channels) or default_channels
@@ -156,7 +156,7 @@ def main(config_file, files):
156156
make_dependency_file_factory = lambda output_path: make_dependency_file(
157157
file_type,
158158
full_file_name,
159-
config_file,
159+
config_file_path,
160160
output_path,
161161
channels,
162162
deduped_deps,
@@ -167,7 +167,12 @@ def main(config_file, files):
167167
contents = make_dependency_file_factory(output_path)
168168
print(contents)
169169
else:
170-
output_path = get_output_path(file_type, file_config)
170+
output_path = get_output_path(
171+
file_type, config_file_path, file_config
172+
)
171173
contents = make_dependency_file_factory(output_path)
172-
with open(os.path.join(output_path, full_file_name), "w") as f:
174+
with open(
175+
os.path.join(output_path, full_file_name),
176+
"w",
177+
) as f:
173178
f.write(contents)

0 commit comments

Comments
 (0)