Skip to content

Commit 267d929

Browse files
committed
wip
1 parent d06819d commit 267d929

File tree

4 files changed

+37
-31
lines changed

4 files changed

+37
-31
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ include_directories(
152152
${LLVM_INCLUDE_DIRS}
153153
${MLIR_INCLUDE_DIRS}
154154
${IMEX_INCLUDE_DIRS}
155-
"/export/users/yzhao/work/sharpy_ws/builds/debug/tools/Imex/include")
155+
"/localdisk2/yzhao/work/sharpy_ws/builds/debug/tools/Imex/include")
156156

157157
if (CMAKE_SYSTEM_NAME STREQUAL Linux)
158158
target_link_options(_sharpy PRIVATE "LINKER:--version-script=${CMAKE_CURRENT_SOURCE_DIR}/export.txt")

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,10 @@ def build_cmake(self, ext):
4343
build_args = ["--config", config]
4444

4545
os.chdir(str(build_temp))
46+
print('!!!!!!!!!!', ["cmake", str(cwd)] + cmake_args)
4647
self.spawn(["cmake", str(cwd)] + cmake_args)
4748
if not self.dry_run:
49+
print('!!!!!!!!!!', ["cmake", "--build", ".", f"-j{multiprocessing.cpu_count()}"] + build_args)
4850
self.spawn(
4951
["cmake", "--build", ".", f"-j{multiprocessing.cpu_count()}"]
5052
+ build_args

src/ManipOp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ struct DeferredPermuteDims : public Deferred {
219219
jit::DepManager &dm) override {
220220
auto arrayValue = dm.getDependent(builder, Registry::get(_array));
221221

222-
auto axesAttr = builder.getI64ArrayAttr(_axes);
222+
auto axesAttr = builder.getDenseI64ArrayAttr(_axes);
223223

224224
auto aTyp =
225225
::mlir::cast<::imex::ndarray::NDArrayType>(arrayValue.getType());

src/idtr.cpp

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -718,20 +718,26 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
718718
SHARPY::Transceiver *tc, int64_t iNDims,
719719
int64_t *iGShapePtr, int64_t *iOffsPtr,
720720
void *iDataPtr, int64_t *iDataShapePtr,
721-
int64_t *iDataStridesPtr, int64_t oNDims,
722-
int64_t *oGShapePtr, int64_t *oOffsPtr,
721+
int64_t *iDataStridesPtr, int64_t *oOffsPtr,
723722
void *oDataPtr, int64_t *oDataShapePtr,
724723
int64_t *oDataStridesPtr, int64_t *axesPtr) {
725724
#ifdef NO_TRANSCEIVER
726725
initMPIRuntime();
727726
tc = SHARPY::getTransceiver();
728727
#endif
729728
if (!iGShapePtr || !iOffsPtr || !iDataPtr || !iDataShapePtr ||
730-
!iDataStridesPtr || !oGShapePtr || !oOffsPtr || !oDataPtr ||
731-
!oDataShapePtr || !oDataStridesPtr || !tc) {
729+
!iDataStridesPtr || !oOffsPtr || !oDataPtr || !oDataShapePtr ||
730+
!oDataStridesPtr || !tc) {
732731
throw std::invalid_argument("Fatal: received nullptr in reshape");
733732
}
734733

734+
std::vector<int64_t> oGShape(iNDims);
735+
for (int64_t i = 0; i < iNDims; ++i) {
736+
oGShape[i] = iGShapePtr[axesPtr[i]];
737+
}
738+
auto *oGShapePtr = oGShape.data();
739+
const auto oNDims = iNDims;
740+
735741
assert(std::accumulate(&iGShapePtr[0], &iGShapePtr[iNDims], 1,
736742
std::multiplies<int64_t>()) ==
737743
std::accumulate(&oGShapePtr[0], &oGShapePtr[oNDims], 1,
@@ -817,21 +823,21 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
817823
});
818824

819825
int lastOffset = 0;
820-
for (size_t i = 0; i < nRanks; i++) {
821-
sendSizes[i] = sendRankBuffer[i].size();
822-
sendOffsets[i] = lastOffset;
823-
sendBuffer.insert(sendBuffer.end(), sendRankBuffer[i].begin(),
824-
sendRankBuffer[i].end());
825-
lastOffset += sendSizes[i];
826+
for (size_t rank = 0; rank < nRanks; rank++) {
827+
sendSizes[rank] = sendRankBuffer[rank].size();
828+
sendOffsets[rank] = lastOffset;
829+
sendBuffer.insert(sendBuffer.end(), sendRankBuffer[rank].begin(),
830+
sendRankBuffer[rank].end());
831+
lastOffset += sendSizes[rank];
826832
}
827833

828834
output.localIndices([&](const id &outputIndex) {
829835
id inputIndex = outputIndex.permute(axes);
830836
auto rank = getInputRank(parts, inputIndex[0]);
831837
++receiveSizes[rank];
832838
});
833-
for (size_t i = 1; i < nRanks; i++) {
834-
receiveOffsets[i] = receiveOffsets[i - 1] + receiveSizes[i - 1];
839+
for (size_t rank = 1; rank < nRanks; rank++) {
840+
receiveOffsets[rank] = receiveOffsets[rank - 1] + receiveSizes[rank - 1];
835841
}
836842
}
837843

@@ -842,7 +848,7 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
842848

843849
{
844850
std::vector<std::vector<T>> receiveRankBuffer(nRanks);
845-
for (int64_t rank = 0; rank < nRanks; ++rank) {
851+
for (size_t rank = 0; rank < nRanks; ++rank) {
846852
auto &rankBuffer = receiveRankBuffer[rank];
847853
rankBuffer.insert(
848854
rankBuffer.end(), receiveBuffer.begin() + receiveOffsets[rank],
@@ -866,12 +872,12 @@ template <typename T>
866872
WaitHandleBase *
867873
_idtr_copy_permute(SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr,
868874
int64_t iNOffs, void *iLOffsDescr, int64_t iNDims,
869-
void *iDataDescr, int64_t oNSzs, void *oGShapeDescr,
870-
int64_t oNOffs, void *oLOffsDescr, int64_t oNDims,
871-
void *oDataDescr, int64_t axesSzs, void *axesDescr) {
875+
void *iDataDescr, int64_t oNOffs, void *oLOffsDescr,
876+
int64_t oNDims, void *oDataDescr, int64_t axesSzs,
877+
void *axesDescr) {
872878

873-
if (!iGShapeDescr || !iLOffsDescr || !iDataDescr || !oGShapeDescr ||
874-
!oLOffsDescr || !oDataDescr || !axesDescr) {
879+
if (!iGShapeDescr || !iLOffsDescr || !iDataDescr || !oLOffsDescr ||
880+
!oDataDescr || !axesDescr) {
875881
throw std::invalid_argument(
876882
"Fatal error: received nullptr in update_halo.");
877883
}
@@ -882,15 +888,14 @@ _idtr_copy_permute(SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr,
882888
MRIdx1d iGShape(iNSzs, iGShapeDescr);
883889
MRIdx1d iOffs(iNOffs, iLOffsDescr);
884890
SHARPY::UnrankedMemRefType<T> iData(iNDims, iDataDescr);
885-
MRIdx1d oGShape(oNSzs, oGShapeDescr);
886891
MRIdx1d oOffs(oNOffs, oLOffsDescr);
887892
SHARPY::UnrankedMemRefType<T> oData(oNDims, oDataDescr);
888893
MRIdx1d axes(axesSzs, axesDescr);
889894

890-
return _idtr_copy_permute<T>(
891-
sharpyType, tc, iNDims, iGShape.data(), iOffs.data(), iData.data(),
892-
iData.sizes(), iData.strides(), oNDims, oGShape.data(), oOffs.data(),
893-
oData.data(), oData.sizes(), oData.strides(), axes.data());
895+
return _idtr_copy_permute<T>(sharpyType, tc, iNDims, iGShape.data(),
896+
iOffs.data(), iData.data(), iData.sizes(),
897+
iData.strides(), oOffs.data(), oData.data(),
898+
oData.sizes(), oData.strides(), axes.data());
894899
}
895900

896901
extern "C" {
@@ -919,12 +924,11 @@ TYPED_COPY_RESHAPE(i1, bool);
919924
void *_idtr_copy_permute_##_sfx( \
920925
SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr, \
921926
int64_t iNOffs, void *iLOffsDescr, int64_t iNDims, void *iLDescr, \
922-
int64_t oNSzs, void *oGShapeDescr, int64_t oNOffs, void *oLOffsDescr, \
923-
int64_t oNDims, void *oLDescr, int64_t axesSzs, void *axesDescr) { \
924-
return _idtr_copy_permute<_typ>(tc, iNSzs, iGShapeDescr, iNOffs, \
925-
iLOffsDescr, iNDims, iLDescr, oNSzs, \
926-
oGShapeDescr, oNOffs, oLOffsDescr, oNDims, \
927-
oLDescr, axesSzs, axesDescr); \
927+
int64_t oNOffs, void *oLOffsDescr, int64_t oNDims, void *oLDescr, \
928+
int64_t axesSzs, void *axesDescr) { \
929+
return _idtr_copy_permute<_typ>( \
930+
tc, iNSzs, iGShapeDescr, iNOffs, iLOffsDescr, iNDims, iLDescr, oNOffs, \
931+
oLOffsDescr, oNDims, oLDescr, axesSzs, axesDescr); \
928932
} \
929933
_Pragma(STRINGIFY(weak _mlir_ciface__idtr_copy_permute_##_sfx = \
930934
_idtr_copy_permute_##_sfx))

0 commit comments

Comments
 (0)