From 4bc28d25a3bce122da79aca912f51d4aa2949288 Mon Sep 17 00:00:00 2001 From: liyulingyue <852433440@qq.com> Date: Sun, 1 Dec 2024 23:19:39 +0800 Subject: [PATCH 1/2] add logical to enable pir infer --- paddlespeech/t2s/exps/inference.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py index 3edc4b63b94..99d147958c1 100644 --- a/paddlespeech/t2s/exps/inference.py +++ b/paddlespeech/t2s/exps/inference.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import os from pathlib import Path import paddle @@ -126,6 +127,12 @@ def main(): paddle.set_device(args.device) + # model_suffix + if os.path.exists(args.am + ".json"): + model_suffix = ".json" + else: + model_suffix = ".pdmodel" + # frontend frontend = get_frontend( lang=args.lang, @@ -135,7 +142,7 @@ def main(): # am_predictor am_predictor = get_predictor( model_dir=args.inference_dir, - model_file=args.am + ".pdmodel", + model_file=args.am + model_suffix, params_file=args.am + ".pdiparams", device=args.device, use_trt=args.use_trt, @@ -148,7 +155,7 @@ def main(): # voc_predictor voc_predictor = get_predictor( model_dir=args.inference_dir, - model_file=args.voc + ".pdmodel", + model_file=args.voc + model_suffix, params_file=args.voc + ".pdiparams", device=args.device, use_trt=args.use_trt, From 1512b3b06598f8b6e593bf7f4a7ae79d8c77c067 Mon Sep 17 00:00:00 2001 From: liyulingyue <852433440@qq.com> Date: Sun, 1 Dec 2024 23:39:54 +0800 Subject: [PATCH 2/2] add logical to enable pir infer --- paddlespeech/t2s/exps/inference.py | 7 ++----- paddlespeech/t2s/exps/jets/inference.py | 6 +++++- paddlespeech/t2s/exps/vits/inference.py | 6 +++++- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py index 99d147958c1..80ecdb17023 100644 --- a/paddlespeech/t2s/exps/inference.py +++ b/paddlespeech/t2s/exps/inference.py @@ -127,11 +127,8 @@ def main(): paddle.set_device(args.device) - # model_suffix - if os.path.exists(args.am + ".json"): - model_suffix = ".json" - else: - model_suffix = ".pdmodel" + # set model_suffix + model_suffix = ".json" if os.path.exists(args.am + ".json") else ".pdmodel" # frontend frontend = get_frontend( diff --git a/paddlespeech/t2s/exps/jets/inference.py b/paddlespeech/t2s/exps/jets/inference.py index 4f6882eda2b..d83510a9403 100644 --- a/paddlespeech/t2s/exps/jets/inference.py +++ b/paddlespeech/t2s/exps/jets/inference.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import os from pathlib import Path import paddle @@ -96,13 +97,16 @@ def main(): paddle.set_device(args.device) + # set model_suffix + model_suffix = ".json" if os.path.exists(args.am + ".json") else ".pdmodel" + # frontend frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict) # am_predictor am_predictor = get_predictor( model_dir=args.inference_dir, - model_file=args.am + ".pdmodel", + model_file=args.am + model_suffix, params_file=args.am + ".pdiparams", device=args.device, use_trt=args.use_trt, diff --git a/paddlespeech/t2s/exps/vits/inference.py b/paddlespeech/t2s/exps/vits/inference.py index 08c1ac566db..ba9e2e8da4b 100644 --- a/paddlespeech/t2s/exps/vits/inference.py +++ b/paddlespeech/t2s/exps/vits/inference.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import os from pathlib import Path import paddle @@ -96,13 +97,16 @@ def main(): paddle.set_device(args.device) + # set model_suffix + model_suffix = ".json" if os.path.exists(args.am + ".json") else ".pdmodel" + # frontend frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict) # am_predictor am_predictor = get_predictor( model_dir=args.inference_dir, - model_file=args.am + ".pdmodel", + model_file=args.am + model_suffix, params_file=args.am + ".pdiparams", device=args.device, use_trt=args.use_trt,