From b999102007b5b278482398843cd290535d535af1 Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Wed, 11 Sep 2024 21:47:24 -0700
Subject: [PATCH] Set a default for num_layers based on the model type

- num_layers=2 seems good for the character classifier
---
 stanza/models/mwt_expander.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/stanza/models/mwt_expander.py b/stanza/models/mwt_expander.py
index c9e11bdac4..f6274211e4 100644
--- a/stanza/models/mwt_expander.py
+++ b/stanza/models/mwt_expander.py
@@ -53,7 +53,7 @@ def build_argparse():
     parser.add_argument('--hidden_dim', type=int, default=100)
     parser.add_argument('--emb_dim', type=int, default=50)
-    parser.add_argument('--num_layers', type=int, default=1)
+    parser.add_argument('--num_layers', type=int, default=None, help='Number of layers in model encoder. Defaults to 1 for seq2seq, 2 for classifier')
     parser.add_argument('--emb_dropout', type=float, default=0.5)
     parser.add_argument('--dropout', type=float, default=0.5)
     parser.add_argument('--max_dec_len', type=int, default=50)
@@ -153,6 +153,12 @@ def train(args):
     args['vocab_size'] = vocab.size
     dev_batch = BinaryDataLoader(dev_doc, args['batch_size'], args, vocab=vocab, evaluation=True)
 
+    if args['num_layers'] is None:
+        if args['force_exact_pieces']:
+            args['num_layers'] = 2
+        else:
+            args['num_layers'] = 1
+
     # train a dictionary-based MWT expander
     trainer = Trainer(args=args, vocab=vocab, device=args['device'])
     logger.info("Training dictionary-based MWT expander...")
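
Reviewer note: a minimal standalone sketch of the default-resolution logic this patch adds to train(). It assumes args is the parsed-argument dict and that force_exact_pieces is an existing boolean option selecting the character-classifier model (it is referenced but not defined in this patch); resolve_num_layers is a hypothetical helper for illustration, not part of the change.

```python
def resolve_num_layers(args):
    """Hypothetical helper mirroring the patch: pick a per-model default
    when --num_layers is left unset (None)."""
    if args['num_layers'] is None:
        # force_exact_pieces selects the character classifier path,
        # where a 2-layer encoder was found to work well; the seq2seq
        # path keeps the previous default of 1 layer
        args['num_layers'] = 2 if args['force_exact_pieces'] else 1
    return args

# An explicit --num_layers value is always respected; only None is rewritten.
assert resolve_num_layers({'num_layers': None, 'force_exact_pieces': True})['num_layers'] == 2
assert resolve_num_layers({'num_layers': None, 'force_exact_pieces': False})['num_layers'] == 1
assert resolve_num_layers({'num_layers': 3, 'force_exact_pieces': True})['num_layers'] == 3
```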