diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py index 3edc4b63b94..80ecdb17023 100644 --- a/paddlespeech/t2s/exps/inference.py +++ b/paddlespeech/t2s/exps/inference.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import os from pathlib import Path import paddle @@ -126,6 +127,9 @@ def main(): paddle.set_device(args.device) + # set model_suffix + model_suffix = ".json" if os.path.exists(args.am + ".json") else ".pdmodel" + # frontend frontend = get_frontend( lang=args.lang, @@ -135,7 +139,7 @@ def main(): # am_predictor am_predictor = get_predictor( model_dir=args.inference_dir, - model_file=args.am + ".pdmodel", + model_file=args.am + model_suffix, params_file=args.am + ".pdiparams", device=args.device, use_trt=args.use_trt, @@ -148,7 +152,7 @@ def main(): # voc_predictor voc_predictor = get_predictor( model_dir=args.inference_dir, - model_file=args.voc + ".pdmodel", + model_file=args.voc + model_suffix, params_file=args.voc + ".pdiparams", device=args.device, use_trt=args.use_trt, diff --git a/paddlespeech/t2s/exps/jets/inference.py b/paddlespeech/t2s/exps/jets/inference.py index 4f6882eda2b..d83510a9403 100644 --- a/paddlespeech/t2s/exps/jets/inference.py +++ b/paddlespeech/t2s/exps/jets/inference.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import os from pathlib import Path import paddle @@ -96,13 +97,16 @@ def main(): paddle.set_device(args.device) + # set model_suffix + model_suffix = ".json" if os.path.exists(args.am + ".json") else ".pdmodel" + # frontend frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict) # am_predictor am_predictor = get_predictor( model_dir=args.inference_dir, - model_file=args.am + ".pdmodel", + model_file=args.am + model_suffix, params_file=args.am + ".pdiparams", device=args.device, use_trt=args.use_trt, diff --git a/paddlespeech/t2s/exps/vits/inference.py b/paddlespeech/t2s/exps/vits/inference.py index 08c1ac566db..ba9e2e8da4b 100644 --- a/paddlespeech/t2s/exps/vits/inference.py +++ b/paddlespeech/t2s/exps/vits/inference.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import os from pathlib import Path import paddle @@ -96,13 +97,16 @@ def main(): paddle.set_device(args.device) + # set model_suffix + model_suffix = ".json" if os.path.exists(args.am + ".json") else ".pdmodel" + # frontend frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict) # am_predictor am_predictor = get_predictor( model_dir=args.inference_dir, - model_file=args.am + ".pdmodel", + model_file=args.am + model_suffix, params_file=args.am + ".pdiparams", device=args.device, use_trt=args.use_trt,