diff --git a/csrc/ops.hip b/csrc/ops.hip index 157e84629..5c0688b91 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -576,6 +576,7 @@ template int igemmlt(hipblasLtHandl if (returnedAlgoCount == 0) { has_error = 1; + printf("Error: Matmul Algo Heurisitic didn't return algorithms\n"); } else { @@ -614,18 +615,26 @@ template int igemmlt(hipblasLtHandl heuristicResult, &returnedAlgoCount)); - if(!SCALE_ROWS) + if (returnedAlgoCount == 0) { - float alpha = 1.0f, beta = 0.0f; - - has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0)); + has_error = 1; + printf("Error: Matmul Algo Heurisitic didn't return algorithms\n"); } else { - //has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, hipblasLt_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec))); - float beta = 0.0f; - - has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0)); + if(!SCALE_ROWS) + { + float alpha = 1.0f, beta = 0.0f; + + has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0)); + } + else + { + //has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, hipblasLt_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec))); + float beta = 0.0f; + + has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0)); + } } } @@ -635,7 +644,7 @@ template int igemmlt(hipblasLtHandl if (Adesc) has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(Adesc)); if (matmulDesc) has_error |= checkHipblasStatus(hipblasLtMatmulDescDestroy(matmulDesc)); if(has_error == 1) - printf("error detected"); + printf("error detected\n"); return has_error; #endif // NO_HIPBLASLT