Skip to content

Commit 157292c

Browse files
committed
Add back jsontags
1 parent 524177f commit 157292c

File tree

1 file changed

+28
-2
lines changed

1 file changed

+28
-2
lines changed

nltk/tag/perceptron.py

Lines changed: 28 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
from os.path import join as path_join
1616
from tempfile import gettempdir
1717

18+
from nltk import jsontags
1819
from nltk.data import find, load
1920
from nltk.tag.api import TaggerI
2021

@@ -39,13 +40,16 @@ def lang_jsons(lang="eng"):
3940
TAGGER_JSONS = {lang: lang_jsons(lang) for lang in ["eng", "rus", "xxx"]}
4041

4142

43+
@jsontags.register_tag
4244
class AveragedPerceptron:
4345
"""An averaged perceptron, as implemented by Matthew Honnibal.
4446
4547
See more implementation details here:
4648
https://explosion.ai/blog/part-of-speech-pos-tagger-in-python
4749
"""
4850

51+
json_tag = "nltk.tag.perceptron.AveragedPerceptron"
52+
4953
def __init__(self, weights=None):
5054
# Each feature gets its own weight vector, so weights is a dict-of-dicts
5155
self.weights = weights if weights else {}
@@ -122,7 +126,15 @@ def load(self, path):
122126
with open(path) as fin:
123127
self.weights = json.load(fin)
124128

129+
def encode_json_obj(self):
130+
return self.weights
131+
132+
@classmethod
133+
def decode_json_obj(cls, obj):
134+
return cls(obj)
125135

136+
137+
@jsontags.register_tag
126138
class PerceptronTagger(TaggerI):
127139
"""
128140
Greedy Averaged Perceptron tagger, as implemented by Matthew Honnibal.
@@ -152,6 +164,8 @@ class PerceptronTagger(TaggerI):
152164
[('The', 'DT'), ('red', 'JJ'), ('cat', 'NN')]
153165
"""
154166

167+
json_tag = "nltk.tag.perceptron.PerceptronTagger"
168+
155169
START = ["-START-", "-START2-"]
156170
END = ["-END-", "-END2-"]
157171

@@ -257,7 +271,7 @@ def save_to_json(self, lang="xxx", loc=None):
257271
with open(path_join(loc, jsons["tagdict"]), "w") as fout:
258272
json.dump(self.tagdict, fout)
259273
with open(path_join(loc, jsons["classes"]), "w") as fout:
260-
json.dump(list(self.model.classes), fout)
274+
json.dump(list(self.classes), fout)
261275

262276
def load_from_json(self, lang="eng", loc=None):
263277
# Automatically find path to the tagger if location is not specified.
@@ -269,7 +283,19 @@ def load_from_json(self, lang="eng", loc=None):
269283
with open(loc + jsons["tagdict"]) as fin:
270284
self.tagdict = json.load(fin)
271285
with open(loc + jsons["classes"]) as fin:
272-
self.model.classes = set(json.load(fin))
286+
self.classes = set(json.load(fin))
287+
self.model.classes = self.classes
288+
289+
def encode_json_obj(self):
290+
return self.model.weights, self.tagdict, list(self.classes)
291+
292+
@classmethod
293+
def decode_json_obj(cls, obj):
294+
tagger = cls(load=False)
295+
tagger.model.weights, tagger.tagdict, tagger.classes = obj
296+
tagger.classes = set(tagger.classes)
297+
tagger.model.classes = tagger.classes
298+
return tagger
273299

274300
def normalize(self, word):
275301
"""

0 commit comments

Comments
 (0)