discard training using pytorch lightning, which causes gpu hangs
tiantiaf0627 committed Dec 8, 2021
1 parent 83144df commit d8e43ee
Showing 4 changed files with 311 additions and 179 deletions.
126 changes: 86 additions & 40 deletions train/federated_attribute_attack.py
@@ -1,21 +1,26 @@
import torch
import torch.multiprocessing
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning import seed_everything
import pytorch_lightning as pl
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from copy import deepcopy


from sklearn.model_selection import train_test_split
from pathlib import Path
import pandas as pd
import numpy as np
import torch.nn as nn
import sys, os, shutil, pickle, argparse, pdb

sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1]), 'model'))
sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1]), 'utils'))

from training_tools import EarlyStopping, seed_worker, result_summary
from attack_model import attack_model

EarlyStopping

# some general mapping for this script
gender_dict = {'F': 0, 'M': 1}
leak_layer_dict = {'full': ['w0', 'b0', 'w1', 'b1', 'w2', 'b2'],
@@ -31,28 +36,45 @@ def __len__(self):
return len(self.dict_keys)

def __getitem__(self, idx):

data_file_str = self.dict_keys[idx]
gender = gender_dict[self.data_dict[data_file_str]['gender']]

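# z-score normalize the leaked weights and bias with the pre-computed per-layer mean/std; the small constant avoids division by zero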
tmp_data = (self.data_dict[data_file_str][weight_name] - weight_norm_mean_dict[weight_name]) / (weight_norm_std_dict[weight_name] + 0.00001)
weights = torch.from_numpy(np.ascontiguousarray(tmp_data))
tmp_data = (self.data_dict[data_file_str][bias_name] - weight_norm_mean_dict[bias_name]) / (weight_norm_std_dict[bias_name] + 0.00001)
bias = torch.from_numpy(np.ascontiguousarray(tmp_data))

return weights, bias, gender

class AttackDataModule(pl.LightningDataModule):
def __init__(self, train, val):
super().__init__()
self.train = train
self.val = val

def train_dataloader(self):
return DataLoader(self.train, batch_size=20, num_workers=0, shuffle=True)
def run_one_epoch(model, data_loader, optimizer, scheduler, loss_func, epoch, mode='train'):

model.train() if mode == 'train' else model.eval()
step_outputs = []

for batch_idx, data_batch in enumerate(data_loader):
weights, bias, y = data_batch
weights, bias, y = weights.to(device), bias.to(device), y.to(device)
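# the leaked weight matrix gets an extra channel dimension (unsqueeze), presumably so the attack model can treat it as a single-channel 2-D input alongside the bias vector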
logits = model(weights.float().unsqueeze(dim=1), bias.float())
loss = loss_func(logits, y)

predictions = np.argmax(logits.detach().cpu().numpy(), axis=1)
pred_list = [predictions[pred_idx] for pred_idx in range(len(predictions))]
truth_list = [y.detach().cpu().numpy()[pred_idx] for pred_idx in range(len(predictions))]
step_outputs.append({'loss': loss.item(), 'pred': pred_list, 'truth': truth_list})

# in training mode, backpropagate the loss and update the model parameters
if mode == 'train':
optimizer.zero_grad()
loss.backward()
optimizer.step()
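# drop the batch references and release cached GPU memory after every step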
del data_batch, logits, loss
torch.cuda.empty_cache()
result_dict = result_summary(step_outputs, mode, epoch)

# in validation mode, step the LR scheduler on the mean validation loss
if mode == 'validate':
mean_loss = np.mean(result_dict['loss'])
scheduler.step(mean_loss)
return result_dict
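# For orientation only: a minimal sketch of what the result_summary helper imported from
# utils/training_tools (not shown in this diff) is assumed to return, based on the keys
# used in this file ('loss', 'acc', 'uar', 'conf'); the actual helper may differ.
def _result_summary_sketch(step_outputs, mode, epoch):
    # mode and epoch are presumably only used for logging in the real helper
    from sklearn.metrics import accuracy_score, recall_score, confusion_matrix
    preds = np.concatenate([np.asarray(step['pred']) for step in step_outputs])
    truths = np.concatenate([np.asarray(step['truth']) for step in step_outputs])
    return {'loss': [step['loss'] for step in step_outputs],       # per-batch losses
            'acc': accuracy_score(truths, preds),                  # overall accuracy
            'uar': recall_score(truths, preds, average='macro'),   # unweighted average recall
            'conf': confusion_matrix(truths, preds)}               # confusion matrix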

def val_dataloader(self):
return DataLoader(self.val, batch_size=20, num_workers=0, shuffle=False)

if __name__ == '__main__':

@@ -77,13 +99,16 @@ def val_dataloader(self):
parser.add_argument('--save_dir', default='/media/data/projects/speech-privacy')
args = parser.parse_args()

seed_everything(8, workers=True)
seed_worker(8)
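# note: seed_worker comes from utils/training_tools and replaces the pytorch_lightning seed_everything call shown above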
device = torch.device("cuda:"+str(args.device)) if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available(): print('GPU available, use GPU')

model_setting_str = 'local_epoch_'+str(args.local_epochs) if args.model_type == 'fed_avg' else 'local_epoch_1'
model_setting_str += '_dropout_' + str(args.dropout).replace('.', '')
model_setting_str += '_lr_' + str(args.learning_rate)[2:]

torch.cuda.empty_cache()
torch.multiprocessing.set_sharing_strategy('file_system')

# 1. compute per-layer normalization statistics (mean/std) over the shadow gradients
weight_norm_mean_dict, weight_norm_std_dict = {}, {}
@@ -110,12 +135,11 @@ def val_dataloader(self):
for speaker_id in adv_gradient_dict:
data_key = str(shadow_idx)+'_'+str(epoch)+'_'+speaker_id
gradients = adv_gradient_dict[speaker_id]['gradient']
gender = adv_gradient_dict[speaker_id]['gender']
shadow_training_sample_size += 1

# calculate running stats for computing std and mean
shadow_data_dict[data_key] = {}
shadow_data_dict[data_key]['gender'] = gender
shadow_data_dict[data_key]['gender'] = adv_gradient_dict[speaker_id]['gender']
shadow_data_dict[data_key][weight_name] = gradients[weight_idx]
shadow_data_dict[data_key][bias_name] = gradients[bias_idx]
for layer_name in leak_layer_dict[args.leak_layer]:
@@ -133,29 +157,53 @@
train_key_list, validate_key_list = train_test_split(list(shadow_data_dict.keys()), test_size=0.2, random_state=0)
model = attack_model(args.leak_layer, args.feature_type)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=float(args.model_learning_rate), weight_decay=1e-04, betas=(0.9, 0.98), eps=1e-9)
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.2, verbose=True, min_lr=1e-6)

# 2.2 define data loader
dataset_train = WeightDataGenerator(train_key_list, shadow_data_dict)
dataset_valid = WeightDataGenerator(validate_key_list, shadow_data_dict)
data_module = AttackDataModule(dataset_train, dataset_valid)

train_loader = DataLoader(dataset_train, batch_size=20, num_workers=0, shuffle=True)
validation_loader = DataLoader(dataset_valid, batch_size=20, num_workers=0, shuffle=False)

# 2.3 initialize the early_stopping object
early_stopping = EarlyStopping(monitor="val_loss", mode='min', patience=5, stopping_threshold=1e-4, check_finite=True)

early_stopping = EarlyStopping(patience=5, verbose=True)
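# note: this EarlyStopping is the plain helper imported from utils/training_tools, not the pytorch_lightning callback constructed above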
loss = nn.NLLLoss().to(device)

# 2.4 log saving path
attack_model_result_path = Path(os.path.realpath(__file__)).parents[1].joinpath('results', 'attack', args.leak_layer, args.model_type, args.feature_type, model_setting_str)
log_path = Path.joinpath(attack_model_result_path, 'log_private_' + str(args.dataset))
if log_path.exists(): shutil.rmtree(log_path)
Path.mkdir(log_path, parents=True, exist_ok=True)
mlf_logger = MLFlowLogger(experiment_name="ser", save_dir=str(log_path))

checkpoint_callback = ModelCheckpoint(monitor="val_acc_epoch", mode="max",
dirpath=str(attack_model_result_path),
filename='private_' + str(args.dataset) + '_model')
# 2.5 training using pytorch lightning framework
trainer = pl.Trainer(logger=mlf_logger, gpus=1, callbacks=[checkpoint_callback, early_stopping], max_epochs=50)
trainer.fit(model, data_module)

# 2.5 training attack model
result_dict, best_val_dict = {}, {}
for epoch in range(30):
# perform one epoch of training followed by validation
train_result = run_one_epoch(model, train_loader, optimizer, scheduler, loss, epoch, mode='train')
validate_result = run_one_epoch(model, validation_loader, optimizer, scheduler, loss, epoch, mode='validate')

# save the results for later
result_dict[epoch] = {}
result_dict[epoch]['train'], result_dict[epoch]['validate'] = train_result, validate_result

if len(best_val_dict) == 0: best_val_dict, best_epoch = validate_result, epoch
if validate_result['uar'] > best_val_dict['uar'] and epoch > 10:
best_val_dict, best_epoch = validate_result, epoch
best_model = deepcopy(model.state_dict())

# early_stopping needs the validation loss to check if it has decreased,
# and if it has, it will make a checkpoint of the current model
if epoch > 10: early_stopping(validate_result['loss'], model)

# print(final_acc, best_val_acc, best_epoch)
print('best epoch %d, best final acc %.2f, best final uar %.2f' % (best_epoch, best_val_dict['acc']*100, best_val_dict['uar']*100))
print(best_val_dict['conf'])

if early_stopping.early_stop and epoch > 10:
print("Early stopping")
break

# 3. we evaluate the attacker performance on the service provider's training data
save_result_df = pd.DataFrame()
# 3.1 we perform 5-fold evaluation, since the private data is also trained 5 times
@@ -179,14 +227,12 @@ def val_dataloader(self):
test_data_dict[data_key][bias_name] = gradients[bias_idx]

dataset_test = WeightDataGenerator(list(test_data_dict.keys()), test_data_dict)
dataloader_test = DataLoader(dataset_test, batch_size=20, num_workers=1, shuffle=False)

# model.freeze()
# trainer.test(test_dataloaders=data_module.train_dataloader())
result_dict = trainer.test(dataloaders=dataloader_test, ckpt_path='best')
row_df['acc'], row_df['uar'] = result_dict[0]['test_acc_epoch'], result_dict[0]['test_uar_epoch']
test_loader = DataLoader(dataset_test, batch_size=20, num_workers=0, shuffle=False)
test_result = run_one_epoch(model, test_loader, optimizer, scheduler, loss, best_epoch, mode='test')

row_df['acc'], row_df['uar'] = test_result['acc'], test_result['uar']
save_result_df = pd.concat([save_result_df, row_df])
del dataset_test, dataloader_test
del dataset_test, test_loader

row_df = pd.DataFrame(index=['average'])
row_df['acc'], row_df['uar'] = np.mean(save_result_df['acc']), np.mean(save_result_df['uar'])
