Skip to content

Commit

Permalink
Update main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ajdesh2000 authored Aug 10, 2023
1 parent 3e53024 commit 133db5c
Showing 1 changed file with 3 additions and 9 deletions.
12 changes: 3 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
parser.add_argument(
"--sparse", default=False, type=lambda x: (str(x).lower() == "true")
) #
parser.add_argument("--n_trials_unsup", default=5, type=int) #
parser.add_argument("--n_trials_sup", default=50, type=int) #
parser.add_argument("--n_trials_unsup", default=20, type=int) #
parser.add_argument("--n_trials_sup", default=60, type=int) #
parser.add_argument("--alpha_masks", default="-1", type=str) #
parser.add_argument("--lr_alphas", default=0.001, type=float) #
parser.add_argument("--alpha_activation", default="none", type=str) #
Expand All @@ -40,12 +40,6 @@
# get options
options = vars(parser.parse_args())

if options["algorithm"] == "dgi":
options["augmentation_type"] = "fsgnn"
options["augmentation_quantity"] = 0
options["augmentation_index"] = 0
options["augmentation_all"] = False

if options["algorithm"] == "mvgrl":
options["augmentation_type"] = "diff"
options["augmentation_quantity"] = 0
Expand Down Expand Up @@ -278,4 +272,4 @@ def train_supervised_wrapper(params):
results_dict = {**unsupervised_result_dict, **supervised_result_dict}
print("Results dict", results_dict)

print("Note that these results are for the following data splits:", dataset_splits)
print("Note that these results are for the following data splits:", dataset_splits)

0 comments on commit 133db5c

Please sign in to comment.