Skip to content

Commit 1e922c5

Browse files
authored
Merge pull request #251 from adriendelsalle/pyarray-init-list
`pyarray` initializers lists work with all layouts
2 parents 062c8c2 + 66b81ae commit 1e922c5

File tree

2 files changed

+53
-16
lines changed

2 files changed

+53
-16
lines changed

include/xtensor-python/pyarray.hpp

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ namespace xt
224224
storage_type& storage_impl() noexcept;
225225
const storage_type& storage_impl() const noexcept;
226226

227+
layout_type default_dynamic_layout();
228+
227229
friend class xcontainer<pyarray<T, L>>;
228230
friend class pycontainer<pyarray<T, L>>;
229231
};
@@ -254,48 +256,48 @@ namespace xt
254256
inline pyarray<T, L>::pyarray(const value_type& t)
255257
: base_type()
256258
{
257-
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
259+
base_type::resize(xt::shape<shape_type>(t), default_dynamic_layout());
258260
nested_copy(m_storage.begin(), t);
259261
}
260262

261263
template <class T, layout_type L>
262264
inline pyarray<T, L>::pyarray(nested_initializer_list_t<T, 1> t)
263265
: base_type()
264266
{
265-
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
266-
nested_copy(m_storage.begin(), t);
267+
base_type::resize(xt::shape<shape_type>(t), default_dynamic_layout());
268+
L == layout_type::row_major ? nested_copy(m_storage.begin(), t) : nested_copy(this->template begin<layout_type::row_major>(), t);
267269
}
268270

269271
template <class T, layout_type L>
270272
inline pyarray<T, L>::pyarray(nested_initializer_list_t<T, 2> t)
271273
: base_type()
272274
{
273-
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
274-
nested_copy(m_storage.begin(), t);
275+
base_type::resize(xt::shape<shape_type>(t), default_dynamic_layout());
276+
L == layout_type::row_major ? nested_copy(m_storage.begin(), t) : nested_copy(this->template begin<layout_type::row_major>(), t);
275277
}
276278

277279
template <class T, layout_type L>
278280
inline pyarray<T, L>::pyarray(nested_initializer_list_t<T, 3> t)
279281
: base_type()
280282
{
281-
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
282-
nested_copy(m_storage.begin(), t);
283+
base_type::resize(xt::shape<shape_type>(t), default_dynamic_layout());
284+
L == layout_type::row_major ? nested_copy(m_storage.begin(), t) : nested_copy(this->template begin<layout_type::row_major>(), t);
283285
}
284286

285287
template <class T, layout_type L>
286288
inline pyarray<T, L>::pyarray(nested_initializer_list_t<T, 4> t)
287289
: base_type()
288290
{
289-
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
290-
nested_copy(m_storage.begin(), t);
291+
base_type::resize(xt::shape<shape_type>(t), default_dynamic_layout());
292+
L == layout_type::row_major ? nested_copy(m_storage.begin(), t) : nested_copy(this->template begin<layout_type::row_major>(), t);
291293
}
292294

293295
template <class T, layout_type L>
294296
inline pyarray<T, L>::pyarray(nested_initializer_list_t<T, 5> t)
295297
: base_type()
296298
{
297-
base_type::resize(xt::shape<shape_type>(t), layout_type::row_major);
298-
nested_copy(m_storage.begin(), t);
299+
base_type::resize(xt::shape<shape_type>(t), default_dynamic_layout());
300+
L == layout_type::row_major ? nested_copy(m_storage.begin(), t) : nested_copy(this->template begin<layout_type::row_major>(), t);
299301
}
300302

301303
template <class T, layout_type L>
@@ -443,7 +445,9 @@ namespace xt
443445
// TODO: prevent intermediary shape allocation
444446
shape_type shape = xtl::forward_sequence<shape_type, decltype(e.derived_cast().shape())>(e.derived_cast().shape());
445447
strides_type strides = xtl::make_sequence<strides_type>(shape.size(), size_type(0));
446-
compute_strides(shape, L, strides);
448+
layout_type layout = default_dynamic_layout();
449+
450+
compute_strides(shape, layout, strides);
447451
init_array(shape, strides);
448452
semantic_base::assign(e);
449453
}
@@ -559,6 +563,12 @@ namespace xt
559563
{
560564
return m_storage;
561565
}
566+
567+
template <class T, layout_type L>
568+
layout_type pyarray<T, L>::default_dynamic_layout()
569+
{
570+
return L == layout_type::dynamic ? layout_type::row_major : L;
571+
}
562572
}
563573

564574
#endif

test/test_pyarray.cpp

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,44 @@ namespace xt
3737

3838
TEST(pyarray, initializer_constructor)
3939
{
40-
pyarray<int> t
40+
pyarray<int> r
4141
{{{ 0, 1, 2},
4242
{ 3, 4, 5},
4343
{ 6, 7, 8}},
4444
{{ 9, 10, 11},
4545
{12, 13, 14},
4646
{15, 16, 17}}};
4747

48-
EXPECT_EQ(t.dimension(), 3);
49-
EXPECT_EQ(t(0, 0, 1), 1);
50-
EXPECT_EQ(t.shape()[0], 2);
48+
EXPECT_EQ(r.layout(), xt::layout_type::row_major);
49+
EXPECT_EQ(r.dimension(), 3);
50+
EXPECT_EQ(r(0, 0, 1), 1);
51+
EXPECT_EQ(r.shape()[0], 2);
52+
53+
pyarray<int, xt::layout_type::column_major> c
54+
{{{ 0, 1, 2},
55+
{ 3, 4, 5},
56+
{ 6, 7, 8}},
57+
{{ 9, 10, 11},
58+
{12, 13, 14},
59+
{15, 16, 17}}};
60+
61+
EXPECT_EQ(c.layout(), xt::layout_type::column_major);
62+
EXPECT_EQ(c.dimension(), 3);
63+
EXPECT_EQ(c(0, 0, 1), 1);
64+
EXPECT_EQ(c.shape()[0], 2);
65+
66+
pyarray<int, xt::layout_type::dynamic> d
67+
{{{ 0, 1, 2},
68+
{ 3, 4, 5},
69+
{ 6, 7, 8}},
70+
{{ 9, 10, 11},
71+
{12, 13, 14},
72+
{15, 16, 17}}};
73+
74+
EXPECT_EQ(d.layout(), xt::layout_type::row_major);
75+
EXPECT_EQ(d.dimension(), 3);
76+
EXPECT_EQ(d(0, 0, 1), 1);
77+
EXPECT_EQ(d.shape()[0], 2);
5178
}
5279

5380
TEST(pyarray, shaped_constructor)

0 commit comments

Comments
 (0)