Skip to content

Commit 544a5bb

Browse files
authored
Merge pull request #187 from JohanMabille/zerod
fixed 0d array initialization
2 parents b5aca15 + f0ce20f commit 544a5bb

File tree

4 files changed

+17
-5
lines changed

4 files changed

+17
-5
lines changed

include/xtensor-python/pyarray.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,7 @@ namespace xt
761761

762762
m_backstrides = backstrides_type(*this);
763763
m_storage = storage_type(reinterpret_cast<pointer>(PyArray_DATA(this->python_array())),
764-
this->get_min_stride() * static_cast<size_type>(PyArray_SIZE(this->python_array())));
764+
this->get_buffer_size());
765765
}
766766

767767
template <class T, layout_type L>

include/xtensor-python/pycontainer.hpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ namespace xt
122122
const derived_type& derived_cast() const;
123123

124124
PyArrayObject* python_array() const;
125-
size_type get_min_stride() const;
125+
size_type get_buffer_size() const;
126126
};
127127

128128
namespace detail
@@ -297,10 +297,15 @@ namespace xt
297297
}
298298

299299
template <class D>
300-
inline auto pycontainer<D>::get_min_stride() const -> size_type
300+
inline auto pycontainer<D>::get_buffer_size() const -> size_type
301301
{
302302
const size_type& (*min)(const size_type&, const size_type&) = std::min<size_type>;
303-
return std::max(size_type(1), std::accumulate(this->strides().cbegin(), this->strides().cend(), std::numeric_limits<size_type>::max(), min));
303+
size_type min_stride = this->strides().empty() ? size_type(1) :
304+
std::max(size_type(1), std::accumulate(this->strides().cbegin(),
305+
this->strides().cend(),
306+
std::numeric_limits<size_type>::max(),
307+
min));
308+
return min_stride * static_cast<size_type>(PyArray_SIZE(this->python_array()));
304309
}
305310

306311
template <class D>

include/xtensor-python/pytensor.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ namespace xt
463463
}
464464

465465
m_storage = storage_type(reinterpret_cast<pointer>(PyArray_DATA(this->python_array())),
466-
this->get_min_stride() * static_cast<size_type>(PyArray_SIZE(this->python_array())));
466+
this->get_buffer_size());
467467
}
468468

469469
template <class T, std::size_t N, layout_type L>

test/test_pyarray.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,4 +249,11 @@ namespace xt
249249
auto v = xt::view(arr, xt::all());
250250
EXPECT_EQ(v(0), 0.);
251251
}
252+
253+
TEST(pyarray, zerod_copy)
254+
{
255+
xt::pyarray<int> arr = 2;
256+
xt::pyarray<int> arr2(arr);
257+
EXPECT_EQ(arr(), arr2());
258+
}
252259
}

0 commit comments

Comments
 (0)