Skip to content

Commit

Permalink
TestSnapshot expects .h5 snapshots, explicitly checks history.
Browse files Browse the repository at this point in the history
  • Loading branch information
erictzeng committed Aug 7, 2015
1 parent ada055b commit 5c89c64
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion src/caffe/test/test_gradient_based_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
if (snapshot) {
ostringstream resume_file;
resume_file << snapshot_prefix_ << "/_iter_" << num_iters
<< ".solverstate";
<< ".solverstate.h5";
string resume_filename = resume_file.str();
return resume_filename;
}
Expand Down Expand Up @@ -394,6 +394,18 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
}
}

// Save the solver history
vector<shared_ptr<Blob<Dtype> > > history_copies;
const vector<shared_ptr<Blob<Dtype> > >& orig_history = solver_->history();
history_copies.resize(orig_history.size());
for (int i = 0; i < orig_history.size(); ++i) {
history_copies[i].reset(new Blob<Dtype>());
const bool kReshape = true;
for (int copy_diff = false; copy_diff <= true; ++copy_diff) {
history_copies[i]->CopyFrom(*orig_history[i], copy_diff, kReshape);
}
}

// Run the solver for num_iters iterations and snapshot.
snapshot = true;
string snapshot_name = RunLeastSquaresSolver(learning_rate, weight_decay,
Expand All @@ -414,6 +426,17 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
<< "param " << i << " diff differed at dim " << j;
}
}

// Check that history now matches.
const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
for (int i = 0; i < history.size(); ++i) {
for (int j = 0; j < history[i]->count(); ++j) {
EXPECT_EQ(history_copies[i]->cpu_data()[j], history[i]->cpu_data()[j])
<< "history blob " << i << " data differed at dim " << j;
EXPECT_EQ(history_copies[i]->cpu_diff()[j], history[i]->cpu_diff()[j])
<< "history blob " << i << " diff differed at dim " << j;
}
}
}
};

Expand Down

0 comments on commit 5c89c64

Please sign in to comment.