From b4fcb66cb89e4dcd609e2f5ac97f669a3c124a40 Mon Sep 17 00:00:00 2001 From: devzhk Date: Mon, 10 Jul 2023 16:25:50 -0700 Subject: [PATCH] add data --- download_data.py | 1 + inference.py | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/download_data.py b/download_data.py index 7ea92dd..b223181 100644 --- a/download_data.py +++ b/download_data.py @@ -12,6 +12,7 @@ 'NS-Re100Part0': 'https://hkzdata.s3.us-west-2.amazonaws.com/PINO/data/NS_fine_Re100_T128_part0.npy', 'burgers': 'https://hkzdata.s3.us-west-2.amazonaws.com/PINO/data/burgers_pino.mat', 'NS-Re500_T300_id0': 'https://hkzdata.s3.us-west-2.amazonaws.com/PINO/NS-Re500_T300_id0.npy', + 'NS-Re500_T300_id0-shuffle': 'https://hkzdata.s3.us-west-2.amazonaws.com/PINO/data/NS-Re500_T300_id0-shuffle.npy', 'darcy-train': 'https://hkzdata.s3.us-west-2.amazonaws.com/PINO/data/piececonst_r421_N1024_smooth1.mat', 'darcy-test': 'https://hkzdata.s3.us-west-2.amazonaws.com/PINO/data/piececonst_r421_N1024_smooth2.mat', 'cavity': 'https://hkzdata.s3.us-west-2.amazonaws.com/PINO/data/cavity.mat', diff --git a/inference.py b/inference.py index 71b6016..32908d0 100644 --- a/inference.py +++ b/inference.py @@ -1,7 +1,10 @@ +''' +This code generates the prediction on one instance. +Both the ground truth and the prediction are saved in a .pt file. +''' import os import yaml from argparse import ArgumentParser -import random import torch from torch.utils.data import DataLoader @@ -17,7 +20,8 @@ def get_pred(args): with open(args.config, 'r') as stream: config = yaml.load(stream, yaml.FullLoader) - + basedir = os.path.join('exp', config['log']['logdir']) + save_path = os.path.join(basedir, 'predictions','prediction.pt') device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # prepare data @@ -55,7 +59,7 @@ def get_pred(args): torch.save({ 'truth': u.cpu(), 'pred': out.cpu(), - }, 're500-1_8s-800-fno-50k-prediction.pt') + }, save_path) break