Skip to content

Commit e46af78

Browse files
authored
sdl (#103)
* more overflow checks, more specific exceptions, dead code removal * testing invalid arange
1 parent 0449e3f commit e46af78

25 files changed

+169
-292
lines changed

src/Creator.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ struct DeferredFull : public Deferred {
6666
} else if constexpr (std::is_integral_v<T>) {
6767
return ::imex::createInt(loc, builder, val._int, sizeof(T) * 8);
6868
}
69-
throw std::runtime_error("Unsupported dtype in dispatch");
69+
throw std::invalid_argument("Unsupported dtype in dispatch");
7070
return {};
7171
};
7272
};
@@ -131,7 +131,14 @@ struct DeferredArange : public Deferred {
131131
{static_cast<shape_type::value_type>(
132132
(end - start + step + (step < 0 ? 1 : -1)) / step)},
133133
device, team),
134-
_start(start), _end(end), _step(step) {}
134+
_start(start), _end(end), _step(step) {
135+
if (_start > _end && _step > -1ul) {
136+
throw std::invalid_argument("start > end and step > -1 in arange");
137+
}
138+
if (_start < _end && _step < 1) {
139+
throw std::invalid_argument("start < end and step < 1 in arange");
140+
}
141+
}
135142

136143
bool generate_mlir(::mlir::OpBuilder &builder, const ::mlir::Location &loc,
137144
jit::DepManager &dm) override {
@@ -255,7 +262,7 @@ std::pair<FutureArray *, bool> Creator::mk_future(const py::object &b,
255262
} else if (py::isinstance<py::float_>(b) || py::isinstance<py::int_>(b)) {
256263
return {Creator::full({}, b, dtype, device, team), true};
257264
}
258-
throw std::runtime_error(
265+
throw std::invalid_argument(
259266
"Invalid right operand to elementwise binary operation");
260267
};
261268

src/Deferred.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ Deferred::future_type Deferred::get_future() {
5252
Deferred::future_type defer_array(Runable::ptr_type &&_d, bool is_global) {
5353
Deferred *d = dynamic_cast<Deferred *>(_d.get());
5454
if (!d)
55-
throw std::runtime_error("Expected Deferred Array promise");
55+
throw std::invalid_argument("Expected Deferred Array promise");
5656
if (is_global) {
5757
_dist(d);
5858
if (d->guid() == Registry::NOGUID) {

src/EWBinOp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ static ::imex::ndarray::EWBinOpId sharpy2mlir(const EWBinOpId bop) {
7979
case __RXOR__:
8080
return ::imex::ndarray::BITWISE_XOR;
8181
default:
82-
throw std::runtime_error("Unknown/invalid elementwise binary operation");
82+
throw std::invalid_argument("Unknown/invalid elementwise binary operation");
8383
}
8484
}
8585

src/EWUnyOp.cpp

Lines changed: 1 addition & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -14,113 +14,6 @@
1414
#include <imex/Dialect/Dist/IR/DistOps.h>
1515

1616
namespace SHARPY {
17-
#if 0
18-
namespace x {
19-
20-
class EWUnyOp
21-
{
22-
public:
23-
using ptr_type = DNDArrayBaseX::ptr_type;
24-
25-
template<typename T>
26-
static ptr_type op(EWUnyOpId uop, const std::shared_ptr<DNDArrayX<T>> & a_ptr)
27-
{
28-
const auto & ax = a_ptr->xarray();
29-
if(a_ptr->is_sliced()) {
30-
const auto & av = xt::strided_view(ax, a_ptr->lslice());
31-
return do_op(uop, av, a_ptr);
32-
}
33-
return do_op(uop, ax, a_ptr);
34-
}
35-
36-
#pragma GCC diagnostic ignored "-Wswitch"
37-
template<typename T1, typename T>
38-
static ptr_type do_op(EWUnyOpId uop, const T1 & a, const std::shared_ptr<DNDArrayX<T>> & a_ptr)
39-
{
40-
switch(uop) {
41-
case __ABS__:
42-
case ABS:
43-
return operatorx<T>::mk_tx_(a_ptr, xt::abs(a));
44-
case ACOS:
45-
return operatorx<T>::mk_tx_(a_ptr, xt::acos(a));
46-
case ACOSH:
47-
return operatorx<T>::mk_tx_(a_ptr, xt::acosh(a));
48-
case ASIN:
49-
return operatorx<T>::mk_tx_(a_ptr, xt::asin(a));
50-
case ASINH:
51-
return operatorx<T>::mk_tx_(a_ptr, xt::asinh(a));
52-
case ATAN:
53-
return operatorx<T>::mk_tx_(a_ptr, xt::atan(a));
54-
case ATANH:
55-
return operatorx<T>::mk_tx_(a_ptr, xt::atanh(a));
56-
case CEIL:
57-
return operatorx<T>::mk_tx_(a_ptr, xt::ceil(a));
58-
case COS:
59-
return operatorx<T>::mk_tx_(a_ptr, xt::cos(a));
60-
case COSH:
61-
return operatorx<T>::mk_tx_(a_ptr, xt::cosh(a));
62-
case EXP:
63-
return operatorx<T>::mk_tx_(a_ptr, xt::exp(a));
64-
case EXPM1:
65-
return operatorx<T>::mk_tx_(a_ptr, xt::expm1(a));
66-
case FLOOR:
67-
return operatorx<T>::mk_tx_(a_ptr, xt::floor(a));
68-
case ISFINITE:
69-
return operatorx<T>::mk_tx_(a_ptr, xt::isfinite(a));
70-
case ISINF:
71-
return operatorx<T>::mk_tx_(a_ptr, xt::isinf(a));
72-
case ISNAN:
73-
return operatorx<T>::mk_tx_(a_ptr, xt::isnan(a));
74-
case LOG:
75-
return operatorx<T>::mk_tx_(a_ptr, xt::log(a));
76-
case LOG1P:
77-
return operatorx<T>::mk_tx_(a_ptr, xt::log1p(a));
78-
case LOG2:
79-
return operatorx<T>::mk_tx_(a_ptr, xt::log2(a));
80-
case LOG10:
81-
return operatorx<T>::mk_tx_(a_ptr, xt::log10(a));
82-
case ROUND:
83-
return operatorx<T>::mk_tx_(a_ptr, xt::round(a));
84-
case SIGN:
85-
return operatorx<T>::mk_tx_(a_ptr, xt::sign(a));
86-
case SIN:
87-
return operatorx<T>::mk_tx_(a_ptr, xt::sin(a));
88-
case SINH:
89-
return operatorx<T>::mk_tx_(a_ptr, xt::sinh(a));
90-
case SQUARE:
91-
return operatorx<T>::mk_tx_(a_ptr, xt::square(a));
92-
case SQRT:
93-
return operatorx<T>::mk_tx_(a_ptr, xt::sqrt(a));
94-
case TAN:
95-
return operatorx<T>::mk_tx_(a_ptr, xt::tan(a));
96-
case TANH:
97-
return operatorx<T>::mk_tx_(a_ptr, xt::tanh(a));
98-
case TRUNC:
99-
return operatorx<T>::mk_tx_(a_ptr, xt::trunc(a));
100-
case ERF:
101-
return operatorx<T>::mk_tx_(a_ptr, xt::erf(a));
102-
case __NEG__:
103-
case NEGATIVE:
104-
case __POS__:
105-
case POSITIVE:
106-
case LOGICAL_NOT:
107-
// FIXME
108-
throw std::runtime_error("Unary operation not implemented");
109-
}
110-
if constexpr (std::is_integral<T>::value) {
111-
switch(uop) {
112-
case __INVERT__:
113-
case BITWISE_INVERT:
114-
throw std::runtime_error("Unary operation not implemented");
115-
}
116-
}
117-
throw std::runtime_error("Unknown/invalid elementwise unary operation");
118-
}
119-
#pragma GCC diagnostic pop
120-
121-
};
122-
} //namespace x
123-
#endif // if 0
12417

12518
// convert id of our unary op to id of imex::ndarray unary op
12619
static ::imex::ndarray::EWUnyOpId sharpy(const EWUnyOpId uop) {
@@ -195,7 +88,7 @@ static ::imex::ndarray::EWUnyOpId sharpy(const EWUnyOpId uop) {
19588
case POSITIVE:
19689
return ::imex::ndarray::POSITIVE;
19790
default:
198-
throw std::runtime_error("Unknown/invalid elementwise unary operation");
91+
throw std::invalid_argument("Unknown/invalid elementwise unary operation");
19992
}
20093
}
20194

src/IEWBinOp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ static ::imex::ndarray::EWBinOpId sharpy2mlir(const IEWBinOpId bop) {
4949
case __IXOR__:
5050
return ::imex::ndarray::BITWISE_XOR;
5151
default:
52-
throw std::runtime_error(
52+
throw std::invalid_argument(
5353
"Unknown/invalid inplace elementwise binary operation");
5454
}
5555
}

src/IO.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ struct DeferredFromLocal : public Deferred {
5050
return FLOAT64;
5151
};
5252
};
53-
throw std::runtime_error("Unsupported dtype");
53+
throw std::invalid_argument("Unsupported dtype");
5454
}
5555

5656
void run() override {

src/MPIMediator.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ void send_to_workers(const Runable *dfrd, bool self, MPI_Comm comm) {
114114
}
115115
ser.adapter().flush();
116116
if (ser.adapter().writtenBytesCount() > INT_MAX) {
117-
throw std::runtime_error("Message too large for MPI (int count).");
117+
throw std::out_of_range("Message too large for MPI (int count).");
118118
}
119119
int cnt = static_cast<int>(ser.adapter().writtenBytesCount());
120120

src/ManipOp.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ struct DeferredAsType : public Deferred {
107107
auto mlirElType = ::imex::ndarray::toMLIR(builder, ndDType);
108108
auto arType = av.getType().dyn_cast<::imex::ndarray::NDArrayType>();
109109
if (!arType) {
110-
throw std::runtime_error(
110+
throw std::invalid_argument(
111111
"Encountered unexpected ndarray type in astype.");
112112
}
113113
auto outType = arType.cloneWith(std::nullopt, mlirElType);
@@ -158,7 +158,7 @@ struct DeferredToDevice : public Deferred {
158158

159159
auto srcType = av.getType().dyn_cast<::imex::ndarray::NDArrayType>();
160160
if (!srcType) {
161-
throw std::runtime_error(
161+
throw std::invalid_argument(
162162
"Encountered unexpected ndarray type in to_device.");
163163
}
164164
// copy envs, drop gpu env (if any)

src/ReduceOp.cpp

Lines changed: 1 addition & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -16,70 +16,6 @@
1616

1717
namespace SHARPY {
1818

19-
#if 0
20-
namespace x {
21-
22-
class ReduceOp
23-
{
24-
public:
25-
using ptr_type = DNDArrayBaseX::ptr_type;
26-
27-
template<typename X>
28-
static ptr_type dist_reduce(ReduceOpId rop, const PVSlice & slice, const dim_vec_type & dims, X && x)
29-
{
30-
xt::xarray<typename X::value_type> a = x;
31-
auto new_shape = reduce_shape(slice.shape(), dims);
32-
rank_type owner = NOOWNER;
33-
if(slice.need_reduce(dims)) {
34-
auto len = VPROD(new_shape);
35-
getTransceiver()->reduce_all(a.data(), DTYPE<typename X::value_type>::value, len, rop);
36-
if(len == 1) return operatorx<typename X::value_type>::mk_tx(a.data()[0], REPLICATED);
37-
owner = REPLICATED;
38-
}
39-
return operatorx<typename X::value_type>::mk_tx(std::move(new_shape), a, owner);
40-
}
41-
42-
template<typename T>
43-
static ptr_type op(ReduceOpId rop, const dim_vec_type & dims, const std::shared_ptr<DNDArrayX<T>> & a_ptr)
44-
{
45-
const auto & ax = a_ptr->xarray();
46-
if(a_ptr->is_sliced()) {
47-
const auto & av = xt::strided_view(ax, a_ptr->lslice());
48-
return do_op(rop, dims, av, a_ptr);
49-
}
50-
return do_op(rop, dims, ax, a_ptr);
51-
}
52-
53-
#pragma GCC diagnostic ignored "-Wswitch"
54-
template<typename T1, typename T>
55-
static ptr_type do_op(ReduceOpId rop, const dim_vec_type & dims, const T1 & a, const std::shared_ptr<DNDArrayX<T>> & a_ptr)
56-
{
57-
switch(rop) {
58-
case MEAN:
59-
return dist_reduce(rop, a_ptr->slice(), dims, xt::mean(a, dims));
60-
case PROD:
61-
return dist_reduce(rop, a_ptr->slice(), dims, xt::prod(a, dims));
62-
case SUM:
63-
return dist_reduce(rop, a_ptr->slice(), dims, xt::sum(a, dims));
64-
case STD:
65-
return dist_reduce(rop, a_ptr->slice(), dims, xt::stddev(a, dims));
66-
case VAR:
67-
return dist_reduce(rop, a_ptr->slice(), dims, xt::variance(a, dims));
68-
case MAX:
69-
return dist_reduce(rop, a_ptr->slice(), dims, xt::amax(a, dims));
70-
case MIN:
71-
return dist_reduce(rop, a_ptr->slice(), dims, xt::amin(a, dims));
72-
default:
73-
throw std::runtime_error("Unknown reduction operation");
74-
}
75-
}
76-
77-
#pragma GCC diagnostic pop
78-
79-
};
80-
} // namespace x
81-
#endif // if 0
82-
8319
// convert id of our reduction op to id of imex::ndarray reduction op
8420
static ::imex::ndarray::ReduceOpId sharpy2mlir(const ReduceOpId rop) {
8521
switch (rop) {
@@ -98,7 +34,7 @@ static ::imex::ndarray::ReduceOpId sharpy2mlir(const ReduceOpId rop) {
9834
case MIN:
9935
return ::imex::ndarray::MIN;
10036
default:
101-
throw std::runtime_error("Unknown reduction operation");
37+
throw std::invalid_argument("Unknown reduction operation");
10238
}
10339
}
10440

src/Service.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ struct DeferredService : public DeferredT<Service::service_promise_type,
3939
set_value(true);
4040
break;
4141
default:
42-
throw(std::runtime_error(
42+
throw(std::invalid_argument(
4343
"Execution of unkown service operation requested."));
4444
}
4545
}
@@ -57,7 +57,7 @@ struct DeferredService : public DeferredT<Service::service_promise_type,
5757
case RUN:
5858
return true;
5959
default:
60-
throw(std::runtime_error(
60+
throw(std::invalid_argument(
6161
"MLIR generation for unkown service operation requested."));
6262
}
6363

@@ -84,7 +84,7 @@ struct DeferredReplicate : public Deferred {
8484
const auto a = std::move(Registry::get(_a).get());
8585
auto ary = dynamic_cast<NDArray *>(a.get());
8686
if (!ary) {
87-
throw std::runtime_error("Expected NDArray in replicate.");
87+
throw std::invalid_argument("Expected NDArray in replicate.");
8888
}
8989
ary->replicate();
9090
set_value(a);

src/SetGetItem.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,13 @@ py::handle wrap(NDArray::ptr_type tnsr, const py::handle &handle) {
5252
auto tmp_shp = tnsr->local_shape();
5353
auto tmp_str = tnsr->local_strides();
5454
auto nd = tnsr->ndims();
55-
auto eSz = sizeof_dtype(tnsr->dtype());
55+
int64_t eSz = sizeof_dtype(tnsr->dtype());
5656
std::vector<ssize_t> strides(nd);
5757
for (auto i = 0; i < nd; ++i) {
5858
strides[i] = eSz * tmp_str[i];
59+
if (strides[i] / tmp_str[i] != eSz) {
60+
throw std::overflow_error("Fatal: Integer overflow.");
61+
}
5962
}
6063

6164
return dispatch<wrap_array>(tnsr->dtype(),
@@ -85,7 +88,7 @@ struct DeferredGetLocals
8588
auto aa = std::move(Registry::get(_a).get());
8689
auto a_ptr = std::dynamic_pointer_cast<NDArray>(aa);
8790
if (!a_ptr) {
88-
throw std::runtime_error("Expected NDArray in getlocals.");
91+
throw std::invalid_argument("Expected NDArray in getlocals.");
8992
}
9093
auto res = wrap(a_ptr, _handle);
9194
auto tpl = py::make_tuple(py::reinterpret_steal<py::object>(res));
@@ -122,7 +125,7 @@ struct DeferredGather
122125
auto aa = std::move(Registry::get(_a).get());
123126
auto a_ptr = std::dynamic_pointer_cast<NDArray>(aa);
124127
if (!a_ptr) {
125-
throw std::runtime_error("Expected NDArray in gather.");
128+
throw std::invalid_argument("Expected NDArray in gather.");
126129
}
127130
auto trscvr = a_ptr->transceiver();
128131
auto myrank = trscvr ? trscvr->rank() : 0;
@@ -209,7 +212,7 @@ struct DeferredMap : public Deferred {
209212
auto aa = std::move(Registry::get(_a).get());
210213
auto a_ptr = std::dynamic_pointer_cast<NDArray>(aa);
211214
if (!a_ptr) {
212-
throw std::runtime_error("Expected NDArray in map.");
215+
throw std::invalid_argument("Expected NDArray in map.");
213216
}
214217
auto nd = a_ptr->ndims();
215218
auto lOffs = a_ptr->local_offsets();
@@ -222,6 +225,9 @@ struct DeferredMap : public Deferred {
222225
[&](const std::vector<int64_t> &idx, auto *elPtr) {
223226
for (auto i = 0; i < nd; ++i) {
224227
gIdx[i] = lOffs.empty() ? idx[i] : idx[i] + lOffs[i];
228+
if (gIdx[i] < idx[i]) {
229+
throw std::overflow_error("Fatal: Integer overflow in map.");
230+
}
225231
}
226232
auto pyIdx = _make_tuple(gIdx);
227233
*elPtr =

0 commit comments

Comments
 (0)