Skip to content

Commit

Permalink
Merge pull request BVLC#2462 from longjon/correct-python-exceptions
Browse files Browse the repository at this point in the history
Handle Python layer exceptions correctly
  • Loading branch information
shelhamer committed Aug 6, 2015
2 parents d4aa5fe + 977023f commit ac6d4b6
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 26 deletions.
29 changes: 4 additions & 25 deletions include/caffe/python_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,44 +18,23 @@ class PythonLayer : public Layer<Dtype> {

virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
try {
self_.attr("setup")(bottom, top);
} catch (bp::error_already_set) {
PyErr_Print();
throw;
}
self_.attr("setup")(bottom, top);
}

virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
try {
self_.attr("reshape")(bottom, top);
} catch (bp::error_already_set) {
PyErr_Print();
throw;
}
self_.attr("reshape")(bottom, top);
}

virtual inline const char* type() const { return "Python"; }

protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
try {
self_.attr("forward")(bottom, top);
} catch (bp::error_already_set) {
PyErr_Print();
throw;
}
self_.attr("forward")(bottom, top);
}
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
try {
self_.attr("backward")(top, propagate_down, bottom);
} catch (bp::error_already_set) {
PyErr_Print();
throw;
}
self_.attr("backward")(top, propagate_down, bottom);
}

private:
Expand Down
22 changes: 22 additions & 0 deletions python/caffe/test/test_python_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ def backward(self, top, propagate_down, bottom):
bottom[0].diff[...] = 10 * top[0].diff


class ExceptionLayer(caffe.Layer):
"""A layer for checking exceptions from Python"""

def setup(self, bottom, top):
raise RuntimeError


def python_net_file():
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
f.write("""name: 'pythonnet' force_backward: true
Expand All @@ -35,6 +42,16 @@ def python_net_file():
return f.name


def exception_net_file():
with tempfile.NamedTemporaryFile(delete=False) as f:
f.write("""name: 'pythonnet' force_backward: true
input: 'data' input_shape { dim: 10 dim: 9 dim: 8 }
layer { type: 'Python' name: 'layer' bottom: 'data' top: 'top'
python_param { module: 'test_python_layer' layer: 'ExceptionLayer' } }
""")
return f.name


class TestPythonLayer(unittest.TestCase):
def setUp(self):
net_file = python_net_file()
Expand Down Expand Up @@ -62,3 +79,8 @@ def test_reshape(self):
for blob in six.itervalues(self.net.blobs):
for d in blob.data.shape:
self.assertEqual(s, d)

def test_exception(self):
net_file = exception_net_file()
self.assertRaises(RuntimeError, caffe.Net, net_file, caffe.TEST)
os.remove(net_file)
16 changes: 15 additions & 1 deletion tools/caffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
#include "boost/algorithm/string.hpp"
#include "caffe/caffe.hpp"

#ifdef WITH_PYTHON_LAYER
#include "boost/python.hpp"
namespace bp = boost::python;
#endif

using caffe::Blob;
using caffe::Caffe;
using caffe::Net;
Expand Down Expand Up @@ -304,7 +309,16 @@ int main(int argc, char** argv) {
// Run tool or show usage.
caffe::GlobalInit(&argc, &argv);
if (argc == 2) {
return GetBrewFunction(caffe::string(argv[1]))();
#ifdef WITH_PYTHON_LAYER
try {
#endif
return GetBrewFunction(caffe::string(argv[1]))();
#ifdef WITH_PYTHON_LAYER
} catch (bp::error_already_set) {
PyErr_Print();
return 1;
}
#endif
} else {
gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/caffe");
}
Expand Down

0 comments on commit ac6d4b6

Please sign in to comment.