Skip to content

Commit

Permalink
Merge pull request BVLC#3755 from shelhamer/fix-upgrade-proto
Browse files Browse the repository at this point in the history
Fix Upgrade Net Tools
  • Loading branch information
shelhamer committed Mar 1, 2016
2 parents 358b60c + 7eaeb3a commit f561682
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 34 deletions.
59 changes: 34 additions & 25 deletions src/caffe/util/upgrade_proto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
namespace caffe {

bool NetNeedsUpgrade(const NetParameter& net_param) {
return NetNeedsV0ToV1Upgrade(net_param) || NetNeedsV1ToV2Upgrade(net_param);
return NetNeedsV0ToV1Upgrade(net_param) || NetNeedsV1ToV2Upgrade(net_param)
|| NetNeedsDataUpgrade(net_param) || NetNeedsInputUpgrade(net_param);
}

bool UpgradeNetAsNeeded(const string& param_file, NetParameter* param) {
Expand Down Expand Up @@ -655,12 +656,14 @@ void UpgradeNetDataTransformation(NetParameter* net_param) {
}

bool UpgradeV1Net(const NetParameter& v1_net_param, NetParameter* net_param) {
bool is_fully_compatible = true;
if (v1_net_param.layer_size() > 0) {
LOG(ERROR) << "Input NetParameter to be upgraded already specifies 'layer' "
<< "fields; these will be ignored for the upgrade.";
is_fully_compatible = false;
LOG(FATAL) << "Refusing to upgrade inconsistent NetParameter input; "
<< "the definition includes both 'layer' and 'layers' fields. "
<< "The current format defines 'layer' fields with string type like "
<< "layer { type: 'Layer' ... } and not layers { type: LAYER ... }. "
<< "Manually switch the definition to 'layer' format to continue.";
}
bool is_fully_compatible = true;
net_param->CopyFrom(v1_net_param);
net_param->clear_layers();
net_param->clear_layer();
Expand Down Expand Up @@ -952,29 +955,35 @@ bool NetNeedsInputUpgrade(const NetParameter& net_param) {
}

void UpgradeNetInput(NetParameter* net_param) {
LayerParameter* layer_param = net_param->add_layer();
layer_param->set_name("input");
layer_param->set_type("Input");
InputParameter* input_param = layer_param->mutable_input_param();
// Collect inputs and convert to Input layer definitions.
// If the NetParameter holds an input alone, without shape/dim, then
// it's a legacy caffemodel and simply stripping the input field is enough.
bool has_shape = net_param->input_shape_size() > 0;
// Convert input fields into a layer.
for (int i = 0; i < net_param->input_size(); ++i) {
layer_param->add_top(net_param->input(i));
if (has_shape) {
input_param->add_shape()->CopyFrom(net_param->input_shape(i));
} else {
// Turn legacy input dimensions into shape.
BlobShape* shape = input_param->add_shape();
int first_dim = i*4;
int last_dim = first_dim + 4;
for (int j = first_dim; j < last_dim; j++) {
shape->add_dim(net_param->input_dim(j));
bool has_dim = net_param->input_dim_size() > 0;
if (has_shape || has_dim) {
LayerParameter* layer_param = net_param->add_layer();
layer_param->set_name("input");
layer_param->set_type("Input");
InputParameter* input_param = layer_param->mutable_input_param();
// Convert input fields into a layer.
for (int i = 0; i < net_param->input_size(); ++i) {
layer_param->add_top(net_param->input(i));
if (has_shape) {
input_param->add_shape()->CopyFrom(net_param->input_shape(i));
} else {
// Turn legacy input dimensions into shape.
BlobShape* shape = input_param->add_shape();
int first_dim = i*4;
int last_dim = first_dim + 4;
for (int j = first_dim; j < last_dim; j++) {
shape->add_dim(net_param->input_dim(j));
}
}
}
}
// Swap input layer to beginning of net to satisfy layer dependencies.
for (int i = net_param->layer_size() - 1; i > 0; --i) {
net_param->mutable_layer(i-1)->Swap(net_param->mutable_layer(i));
// Swap input layer to beginning of net to satisfy layer dependencies.
for (int i = net_param->layer_size() - 1; i > 0; --i) {
net_param->mutable_layer(i-1)->Swap(net_param->mutable_layer(i));
}
}
// Clear inputs.
net_param->clear_input();
Expand Down
5 changes: 3 additions & 2 deletions tools/upgrade_net_proto_binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using std::ofstream;
using namespace caffe; // NOLINT(build/namespaces)

int main(int argc, char** argv) {
FLAGS_alsologtostderr = 1; // Print output to stderr (while still logging)
::google::InitGoogleLogging(argv[0]);
if (argc != 3) {
LOG(ERROR) << "Usage: "
Expand All @@ -39,11 +40,11 @@ int main(int argc, char** argv) {
<< "see details above.";
}
} else {
LOG(ERROR) << "File already in V1 proto format: " << argv[1];
LOG(ERROR) << "File already in latest proto format: " << input_filename;
}

WriteProtoToBinaryFile(net_param, argv[2]);

LOG(ERROR) << "Wrote upgraded NetParameter binary proto to " << argv[2];
LOG(INFO) << "Wrote upgraded NetParameter binary proto to " << argv[2];
return !success;
}
8 changes: 2 additions & 6 deletions tools/upgrade_net_proto_text.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using std::ofstream;
using namespace caffe; // NOLINT(build/namespaces)

int main(int argc, char** argv) {
FLAGS_alsologtostderr = 1; // Print output to stderr (while still logging)
::google::InitGoogleLogging(argv[0]);
if (argc != 3) {
LOG(ERROR) << "Usage: "
Expand All @@ -31,7 +32,6 @@ int main(int argc, char** argv) {
return 2;
}
bool need_upgrade = NetNeedsUpgrade(net_param);
bool need_data_upgrade = NetNeedsDataUpgrade(net_param);
bool success = true;
if (need_upgrade) {
success = UpgradeNetAsNeeded(input_filename, &net_param);
Expand All @@ -43,13 +43,9 @@ int main(int argc, char** argv) {
LOG(ERROR) << "File already in latest proto format: " << input_filename;
}

if (need_data_upgrade) {
UpgradeNetDataTransformation(&net_param);
}

// Save new format prototxt.
WriteProtoToTextFile(net_param, argv[2]);

LOG(ERROR) << "Wrote upgraded NetParameter text proto to " << argv[2];
LOG(INFO) << "Wrote upgraded NetParameter text proto to " << argv[2];
return !success;
}
3 changes: 2 additions & 1 deletion tools/upgrade_solver_proto_text.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using std::ofstream;
using namespace caffe; // NOLINT(build/namespaces)

int main(int argc, char** argv) {
FLAGS_alsologtostderr = 1; // Print output to stderr (while still logging)
::google::InitGoogleLogging(argv[0]);
if (argc != 3) {
LOG(ERROR) << "Usage: upgrade_solver_proto_text "
Expand Down Expand Up @@ -45,6 +46,6 @@ int main(int argc, char** argv) {
// Save new format prototxt.
WriteProtoToTextFile(solver_param, argv[2]);

LOG(ERROR) << "Wrote upgraded SolverParameter text proto to " << argv[2];
LOG(INFO) << "Wrote upgraded SolverParameter text proto to " << argv[2];
return !success;
}

0 comments on commit f561682

Please sign in to comment.