@@ -718,20 +718,26 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
718
718
SHARPY::Transceiver *tc, int64_t iNDims,
719
719
int64_t *iGShapePtr, int64_t *iOffsPtr,
720
720
void *iDataPtr, int64_t *iDataShapePtr,
721
- int64_t *iDataStridesPtr, int64_t oNDims,
722
- int64_t *oGShapePtr, int64_t *oOffsPtr,
721
+ int64_t *iDataStridesPtr, int64_t *oOffsPtr,
723
722
void *oDataPtr, int64_t *oDataShapePtr,
724
723
int64_t *oDataStridesPtr, int64_t *axesPtr) {
725
724
#ifdef NO_TRANSCEIVER
726
725
initMPIRuntime ();
727
726
tc = SHARPY::getTransceiver ();
728
727
#endif
729
728
if (!iGShapePtr || !iOffsPtr || !iDataPtr || !iDataShapePtr ||
730
- !iDataStridesPtr || !oGShapePtr || !oOffsPtr || !oDataPtr ||
731
- !oDataShapePtr || ! oDataStridesPtr || !tc) {
729
+ !iDataStridesPtr || !oOffsPtr || !oDataPtr || !oDataShapePtr ||
730
+ !oDataStridesPtr || !tc) {
732
731
throw std::invalid_argument (" Fatal: received nullptr in reshape" );
733
732
}
734
733
734
+ std::vector<int64_t > oGShape (iNDims);
735
+ for (int64_t i = 0 ; i < iNDims; ++i) {
736
+ oGShape[i] = iGShapePtr[axesPtr[i]];
737
+ }
738
+ auto *oGShapePtr = oGShape.data ();
739
+ const auto oNDims = iNDims;
740
+
735
741
assert (std::accumulate (&iGShapePtr[0 ], &iGShapePtr[iNDims], 1 ,
736
742
std::multiplies<int64_t >()) ==
737
743
std::accumulate (&oGShapePtr[0 ], &oGShapePtr[oNDims], 1 ,
@@ -817,21 +823,21 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
817
823
});
818
824
819
825
int lastOffset = 0 ;
820
- for (size_t i = 0 ; i < nRanks; i ++) {
821
- sendSizes[i ] = sendRankBuffer[i ].size ();
822
- sendOffsets[i ] = lastOffset;
823
- sendBuffer.insert (sendBuffer.end (), sendRankBuffer[i ].begin (),
824
- sendRankBuffer[i ].end ());
825
- lastOffset += sendSizes[i ];
826
+ for (size_t rank = 0 ; rank < nRanks; rank ++) {
827
+ sendSizes[rank ] = sendRankBuffer[rank ].size ();
828
+ sendOffsets[rank ] = lastOffset;
829
+ sendBuffer.insert (sendBuffer.end (), sendRankBuffer[rank ].begin (),
830
+ sendRankBuffer[rank ].end ());
831
+ lastOffset += sendSizes[rank ];
826
832
}
827
833
828
834
output.localIndices ([&](const id &outputIndex) {
829
835
id inputIndex = outputIndex.permute (axes);
830
836
auto rank = getInputRank (parts, inputIndex[0 ]);
831
837
++receiveSizes[rank];
832
838
});
833
- for (size_t i = 1 ; i < nRanks; i ++) {
834
- receiveOffsets[i ] = receiveOffsets[i - 1 ] + receiveSizes[i - 1 ];
839
+ for (size_t rank = 1 ; rank < nRanks; rank ++) {
840
+ receiveOffsets[rank ] = receiveOffsets[rank - 1 ] + receiveSizes[rank - 1 ];
835
841
}
836
842
}
837
843
@@ -842,7 +848,7 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
842
848
843
849
{
844
850
std::vector<std::vector<T>> receiveRankBuffer (nRanks);
845
- for (int64_t rank = 0 ; rank < nRanks; ++rank) {
851
+ for (size_t rank = 0 ; rank < nRanks; ++rank) {
846
852
auto &rankBuffer = receiveRankBuffer[rank];
847
853
rankBuffer.insert (
848
854
rankBuffer.end (), receiveBuffer.begin () + receiveOffsets[rank],
@@ -866,12 +872,12 @@ template <typename T>
866
872
WaitHandleBase *
867
873
_idtr_copy_permute (SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr,
868
874
int64_t iNOffs, void *iLOffsDescr, int64_t iNDims,
869
- void *iDataDescr, int64_t oNSzs , void *oGShapeDescr ,
870
- int64_t oNOffs , void *oLOffsDescr , int64_t oNDims ,
871
- void *oDataDescr, int64_t axesSzs, void * axesDescr) {
875
+ void *iDataDescr, int64_t oNOffs , void *oLOffsDescr ,
876
+ int64_t oNDims , void *oDataDescr , int64_t axesSzs ,
877
+ void *axesDescr) {
872
878
873
- if (!iGShapeDescr || !iLOffsDescr || !iDataDescr || !oGShapeDescr ||
874
- !oLOffsDescr || ! oDataDescr || !axesDescr) {
879
+ if (!iGShapeDescr || !iLOffsDescr || !iDataDescr || !oLOffsDescr ||
880
+ !oDataDescr || !axesDescr) {
875
881
throw std::invalid_argument (
876
882
" Fatal error: received nullptr in update_halo." );
877
883
}
@@ -882,15 +888,14 @@ _idtr_copy_permute(SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr,
882
888
MRIdx1d iGShape (iNSzs, iGShapeDescr);
883
889
MRIdx1d iOffs (iNOffs, iLOffsDescr);
884
890
SHARPY::UnrankedMemRefType<T> iData (iNDims, iDataDescr);
885
- MRIdx1d oGShape (oNSzs, oGShapeDescr);
886
891
MRIdx1d oOffs (oNOffs, oLOffsDescr);
887
892
SHARPY::UnrankedMemRefType<T> oData (oNDims, oDataDescr);
888
893
MRIdx1d axes (axesSzs, axesDescr);
889
894
890
- return _idtr_copy_permute<T>(
891
- sharpyType, tc, iNDims, iGShape .data (), iOffs .data (), iData.data (),
892
- iData. sizes (), iData.strides (), oNDims, oGShape .data (), oOffs .data (),
893
- oData. data (), oData.sizes (), oData.strides (), axes.data ());
895
+ return _idtr_copy_permute<T>(sharpyType, tc, iNDims, iGShape. data (),
896
+ iOffs .data (), iData .data (), iData.sizes (),
897
+ iData.strides (), oOffs .data (), oData .data (),
898
+ oData.sizes (), oData.strides (), axes.data ());
894
899
}
895
900
896
901
extern " C" {
@@ -919,12 +924,11 @@ TYPED_COPY_RESHAPE(i1, bool);
919
924
void *_idtr_copy_permute_##_sfx( \
920
925
SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr, \
921
926
int64_t iNOffs, void *iLOffsDescr, int64_t iNDims, void *iLDescr, \
922
- int64_t oNSzs, void *oGShapeDescr, int64_t oNOffs, void *oLOffsDescr, \
923
- int64_t oNDims, void *oLDescr, int64_t axesSzs, void *axesDescr) { \
924
- return _idtr_copy_permute<_typ>(tc, iNSzs, iGShapeDescr, iNOffs, \
925
- iLOffsDescr, iNDims, iLDescr, oNSzs, \
926
- oGShapeDescr, oNOffs, oLOffsDescr, oNDims, \
927
- oLDescr, axesSzs, axesDescr); \
927
+ int64_t oNOffs, void *oLOffsDescr, int64_t oNDims, void *oLDescr, \
928
+ int64_t axesSzs, void *axesDescr) { \
929
+ return _idtr_copy_permute<_typ>( \
930
+ tc, iNSzs, iGShapeDescr, iNOffs, iLOffsDescr, iNDims, iLDescr, oNOffs, \
931
+ oLOffsDescr, oNDims, oLDescr, axesSzs, axesDescr); \
928
932
} \
929
933
_Pragma (STRINGIFY(weak _mlir_ciface__idtr_copy_permute_##_sfx = \
930
934
_idtr_copy_permute_##_sfx))
0 commit comments