Skip to content

Commit

Permalink
Wip
Browse files Browse the repository at this point in the history
Signed-off-by: Geoffroy Jamgotchian <[email protected]>
  • Loading branch information
geofjamg committed Nov 22, 2023
1 parent 97625a7 commit 22a848c
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 81 deletions.
8 changes: 7 additions & 1 deletion src/jniwrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,16 @@ class IntArray : public JniWrapper<jintArray> {
_ptr(nullptr) {
}

IntArray(JNIEnv* env, int length)
: IntArray(env, nullptr, length) {
}

IntArray(JNIEnv* env, int* ptr, int length) :
JniWrapper<jintArray>(env, env->NewIntArray(length)),
_ptr(nullptr) {
_env->SetIntArrayRegion(_obj, 0, length, (const jint*) ptr);
if (ptr) {
_env->SetIntArrayRegion(_obj, 0, length, (const jint*) ptr);
}
}

~IntArray() override {
Expand Down
106 changes: 31 additions & 75 deletions src/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void ComPowsyblMathSolverNewtonKrylovSolverContext::init(JNIEnv* env) {
_logError = env->GetMethodID(_cls, "logError", "(ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;)V");
_logInfo = env->GetMethodID(_cls, "logInfo", "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)V");
_updateFunc = env->GetMethodID(_cls, "updateFunc", "([D[D)V");
_updateJac = env->GetMethodID(_cls, "updateJac", "([D)V");
_updateJac = env->GetMethodID(_cls, "updateJac", "([D[I[I[D)V");
}

ComPowsyblMathSolverNewtonKrylovSolverContext::ComPowsyblMathSolverNewtonKrylovSolverContext(JNIEnv* env, jobject obj)
Expand Down Expand Up @@ -69,16 +69,22 @@ void copyToJava(JNIEnv* env, jdoubleArray ja, const std::vector<double>& v) {
std::memcpy(a.get(), v.data(), v.size() * sizeof(double));
}

void ComPowsyblMathSolverNewtonKrylovSolverContext::updateFunc(double* x, double* f, int length) {
DoubleArray jx(_env, x, length);
DoubleArray jf(_env, length);
void ComPowsyblMathSolverNewtonKrylovSolverContext::updateFunc(double* x, double* f, int n) {
DoubleArray jx(_env, x, n);
DoubleArray jf(_env, n);
_env->CallVoidMethod(_obj, _updateFunc, jx.obj(), jf.obj());
std::memcpy(f, jf.get(), length * sizeof(double));
std::memcpy(f, jf.get(), n * sizeof(double));
}

void ComPowsyblMathSolverNewtonKrylovSolverContext::updateJac(double* x, int length) {
DoubleArray jx(_env, x, length);
_env->CallVoidMethod(_obj, _updateJac, jx.obj());
void ComPowsyblMathSolverNewtonKrylovSolverContext::updateJac(double* x, int n, int* ap, int* ai, double* ax, int nnz) {
DoubleArray jx(_env, x, n);
IntArray jap(_env, n + 1);
IntArray jai(_env, nnz);
DoubleArray jax(_env, nnz);
_env->CallVoidMethod(_obj, _updateJac, jx.obj(), jap.obj(), jai.obj(), jax.obj());
std::memcpy(ap, jap.get(), (n + 1) * sizeof(int));
std::memcpy(ai, jai.get(), nnz * sizeof(int));
std::memcpy(ax, jax.get(), nnz * sizeof(double));
}

} // namespace jni
Expand All @@ -97,12 +103,12 @@ class NewtonKrylovSolverContext {
_delegate.logInfo(module, function, message);
}

void updateFunc(double* x, double* f, int length) {
_delegate.updateFunc(x, f, length);
void updateFunc(double* x, double* f, int n) {
_delegate.updateFunc(x, f, n);
}

void updateJac(double* x, int length) {
_delegate.updateJac(x, length);
void updateJac(double* x, int n, int* ap, int* ai, double* ax, int nnz) {
_delegate.updateJac(x, n, ap, ai, ax, nnz);
}

private:
Expand All @@ -113,44 +119,20 @@ static int evalFunc(N_Vector x, N_Vector f, void* user_data) {
NewtonKrylovSolverContext* solverContext = (NewtonKrylovSolverContext*) user_data;
double* xd = N_VGetArrayPointer(x);
double* fd = N_VGetArrayPointer(f);
int length = NV_LENGTH_S(x);
solverContext->updateFunc(xd, fd, length);
int n = NV_LENGTH_S(x);
solverContext->updateFunc(xd, fd, n);
return 0;
}

static int evalJac(N_Vector x, N_Vector f, SUNMatrix j, void* user_data, N_Vector tmp1, N_Vector tmp2) {
NewtonKrylovSolverContext* solverContext = (NewtonKrylovSolverContext*) user_data;
double* xd = N_VGetArrayPointer(x);
int length = NV_LENGTH_S(x);
solverContext->updateJac(xd, length);
double* xData = N_VGetArrayPointer(x);
double v2 = xData[0];
double ph2 = xData[1];

sunindextype* colPtrs = SUNSparseMatrix_IndexPointers(j);
sunindextype* rowVals = SUNSparseMatrix_IndexValues(j);
double* data = SUNSparseMatrix_Data(j);

// p2: 0 = 0.02 + v2 * 0.1 * sin(ph2)
// q2: 0 = 0.01 + v2 * 0.1 (-cos(ph2) + v2)
double dp2dv2 = 0.1 * std::sin(ph2);
double dp2dph2 = v2 * 0.1 * std::cos(ph2);
double dq2dv2 = - 0.1 * cos(ph2) + 2 * v2 * 0.1;
double dq2dph2 = v2 * 0.1 * std::sin(ph2);

SUNMatZero(j);

colPtrs[0] = 0;
colPtrs[1] = 2;
colPtrs[2] = 4;
data[0] = dp2dv2;
data[1] = dp2dph2;
data[2] = dq2dv2;
data[3] = dq2dph2;
rowVals[0] = 0;
rowVals[1] = 1;
rowVals[2] = 0;
rowVals[3] = 1;
int n = NV_LENGTH_S(x);
int* ap = SUNSparseMatrix_IndexPointers(j);
int* ai = SUNSparseMatrix_IndexValues(j);
double* ax = SUNSparseMatrix_Data(j);
int nnz = SM_NNZ_S(j);
solverContext->updateJac(xd, n, ap, ai, ax, nnz);
return 0;
}

Expand All @@ -164,27 +146,7 @@ static void infoHandler(const char* module, const char* function, char* msg, voi
solverContext->logInfo(module, function, msg);
}

SUNMatrix createSparseMatrix(int n, int nnz, int* ap, int* ai, double* ax, SUNContext& sunCtx) {
SUNMatrix m = SUNSparseMatrix(n, n, nnz, CSC_MAT, sunCtx);
// TODO could be optimized by full re-implementing SUNSparseMatrix so that we don't need to free this arrays
free(SM_INDEXPTRS_S(m));
free(SM_INDEXVALS_S(m));
free(SM_DATA_S(m));
SM_INDEXPTRS_S(m) = ap;
SM_INDEXVALS_S(m) = ai;
SM_DATA_S(m) = ax;
return m;
}

void destroySparseMatrix(SUNMatrix& m) {
// destruction of CSC internal structure is done of Java side
SM_INDEXPTRS_S(m) = nullptr;
SM_INDEXVALS_S(m) = nullptr;
SM_DATA_S(m) = nullptr;
SUNMatDestroy(m);
}

void solve(std::vector<double>& xd, int nnz, int* ap, int* ai, double* ax, powsybl::NewtonKrylovSolverContext& solverContext,
void solve(std::vector<double>& xd, int nnz, powsybl::NewtonKrylovSolverContext& solverContext,
int maxIter, bool lineSearch, int level) {
SUNContext sunCtx;
int error = SUNContext_Create(nullptr, &sunCtx);
Expand All @@ -195,7 +157,7 @@ void solve(std::vector<double>& xd, int nnz, int* ap, int* ai, double* ax, powsy
int n = xd.size();
N_Vector x = N_VMake_Serial(n, xd.data(), sunCtx);

SUNMatrix j = createSparseMatrix(n, nnz, ap, ai, ax, sunCtx);
SUNMatrix j = SUNSparseMatrix(n, n, nnz, CSC_MAT, sunCtx);

SUNLinearSolver ls = SUNLinSol_KLU(x, j, sunCtx);
if (!ls) {
Expand Down Expand Up @@ -267,7 +229,7 @@ void solve(std::vector<double>& xd, int nnz, int* ap, int* ai, double* ax, powsy
throw std::runtime_error("SUNLinSolFree_KLU error " + std::to_string(error));
}

destroySparseMatrix(j);
SUNMatDestroy(j);

N_VDestroy_Serial(x);

Expand All @@ -283,20 +245,14 @@ void solve(std::vector<double>& xd, int nnz, int* ap, int* ai, double* ax, powsy
extern "C" {
#endif

JNIEXPORT void JNICALL Java_com_powsybl_math_solver_NewtonKrylovSolver_solve(JNIEnv * env, jobject, jdoubleArray jx,
jintArray j_ap, jintArray j_ai, jdoubleArray j_ax, jobject jSolverContext) {
JNIEXPORT void JNICALL Java_com_powsybl_math_solver_NewtonKrylovSolver_solve(JNIEnv * env, jobject, jdoubleArray jx, jint nnz, jobject jSolverContext) {
try {
std::vector<double> x = powsybl::jni::copyFromJava(env, jx);
int n = x.size();
powsybl::jni::IntArray ap(env, j_ap);
powsybl::jni::IntArray ai(env, j_ai);
powsybl::jni::DoubleArray ax(env, j_ax);
int nnz = ax.length();
powsybl::NewtonKrylovSolverContext solverContext(env, jSolverContext);
int maxIter = 200;
bool lineSearch = false;
int level = 2;
powsybl::solve(x, nnz, ap.get(), ai.get(), ax.get(), solverContext, maxIter, lineSearch, level);
powsybl::solve(x, nnz, solverContext, maxIter, lineSearch, level);
powsybl::jni::copyToJava(env, jx, x);
} catch (const std::exception& e) {
powsybl::jni::throwMatrixException(env, e.what());
Expand Down
4 changes: 2 additions & 2 deletions src/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ class ComPowsyblMathSolverNewtonKrylovSolverContext : public JniWrapper<jobject>

void logInfo(const std::string& module, const std::string& function, const std::string& message);

void updateFunc(double* x, double* f, int length);
void updateFunc(double* x, double* f, int n);

void updateJac(double* x, int length);
void updateJac(double* x, int n, int* ap, int* ai, double* ax, int nnz);

private:
static jclass _cls;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ public class NewtonKrylovSolver {
MathNative.init();
}

public native void solve(double[] x, int[] ap, int[] ai, double[] ax, NewtonKrylovSolverContext context);
public native void solve(double[] x, int nnz, NewtonKrylovSolverContext context);
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public void updateFunc(double[] x, double[] f) {
f[1] = 0.01 + v2 * 0.1 * (-Math.cos(ph2) + v2);
}

public void updateJac(double[] x) {
public void updateJac(double[] x, int[] ap, int[] ai, double[] ax) {
double v2 = x[0];
double ph2 = x[1];
double dp2dv2 = 0.1 * Math.sin(ph2);
Expand Down
2 changes: 1 addition & 1 deletion src/test/java/com/powsybl/mathnative/SolverTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void test() {
int[] ai = new int[nnz];
double[] ax = new double[nnz];
NewtonKrylovSolver solver = new NewtonKrylovSolver();
solver.solve(x, ap, ai, ax, new NewtonKrylovSolverContext(ap, ai, ax));
solver.solve(x, nnz, new NewtonKrylovSolverContext(ap, ai, ax));
assertArrayEquals(new double[] {0.85545, -0.235992}, x, 1e-6);
}
}

0 comments on commit 22a848c

Please sign in to comment.