Skip to content

Commit 4b4ceff

Browse files
committed
major refactor of Comprehend code to now use Comprehend__Response_Utils
1 parent 25fde93 commit 4b4ceff

File tree

6 files changed

+220
-228
lines changed

6 files changed

+220
-228
lines changed
Lines changed: 37 additions & 150 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,9 @@
11
from typing import List
2+
from botocore.client import BaseClient
23
from osbot_utils.helpers.duration.decorators.capture_duration import capture_duration
34
from osbot_utils.type_safe.type_safe_core.decorators.type_safe import type_safe
4-
from osbot_aws.aws.comprehend.Comprehend__Base import Comprehend__Base
5+
from osbot_utils.type_safe.Type_Safe import Type_Safe
6+
from osbot_aws.aws.comprehend.Comprehend__Response_Utils import Comprehend__Response_Utils
57
from osbot_aws.aws.comprehend.schemas.batch.Schema__Comprehend__Batch__Detect_Sentiment import Schema__Comprehend__Batch__Detect_Sentiment
68
from osbot_aws.aws.comprehend.schemas.batch.Schema__Comprehend__Batch__Detect_Key_Phrases import Schema__Comprehend__Batch__Detect_Key_Phrases
79
from osbot_aws.aws.comprehend.schemas.batch.Schema__Comprehend__Batch__Detect_Entities import Schema__Comprehend__Batch__Detect_Entities
@@ -11,185 +13,70 @@
1113
from osbot_aws.aws.comprehend.schemas.safe_str.Safe_Str__AWS_Comprehend__Text import Safe_Str__Comprehend__Text
1214

1315

14-
class Comprehend__Batch(Comprehend__Base):
1516

16-
17-
# ============================================================================
18-
# BATCH SENTIMENT DETECTION
19-
# ============================================================================
17+
class Comprehend__Batch(Type_Safe):
18+
client : BaseClient
19+
response_utils : Comprehend__Response_Utils
2020

2121
@type_safe
22-
def batch_detect_sentiment(self, text_list : List[Safe_Str__Comprehend__Text] ,
23-
language_code : Enum__Comprehend__Language_Code = Enum__Comprehend__Language_Code.ENGLISH ,
22+
def batch_detect_sentiment(self, text_list : List[Safe_Str__Comprehend__Text] ,
23+
language_code : Enum__Comprehend__Language_Code = Enum__Comprehend__Language_Code.ENGLISH ,
2424
) -> Schema__Comprehend__Batch__Detect_Sentiment:
25-
2625
with capture_duration() as duration:
27-
result = self.client().batch_detect_sentiment(TextList = text_list ,
28-
LanguageCode = language_code.value)
29-
30-
# Process successful results
31-
result_list = []
32-
for item in result.get('ResultList', []):
33-
sentiment_score = item.get('SentimentScore', {})
34-
result_list.append(dict(index = item.get('Index') ,
35-
sentiment = item.get('Sentiment') ,
36-
score = dict(mixed = sentiment_score.get('Mixed' ) ,
37-
neutral = sentiment_score.get('Neutral' ) ,
38-
negative = sentiment_score.get('Negative') ,
39-
positive = sentiment_score.get('Positive'))))
40-
41-
# Process errors
42-
error_list = []
43-
for error in result.get('ErrorList', []):
44-
error_list.append(dict(index = error.get('Index') ,
45-
error_code = error.get('ErrorCode') ,
46-
error_message = error.get('ErrorMessage')))
26+
result = self.client.batch_detect_sentiment(TextList = text_list ,
27+
LanguageCode = language_code.value)
4728

48-
return Schema__Comprehend__Batch__Detect_Sentiment(result_list = result_list ,
49-
error_list = error_list ,
50-
duration = duration.seconds)
51-
52-
# ============================================================================
53-
# BATCH ENTITY DETECTION
54-
# ============================================================================
29+
return Schema__Comprehend__Batch__Detect_Sentiment(result_list = [self.response_utils.process_batch_sentiment_result(item)
30+
for item in result.get('ResultList', [])] ,
31+
error_list = self.response_utils.process_batch_errors(result) ,
32+
duration = duration.seconds)
5533

5634
@type_safe
5735
def batch_detect_entities(self, text_list : List[Safe_Str__Comprehend__Text] ,
5836
language_code : Enum__Comprehend__Language_Code = Enum__Comprehend__Language_Code.ENGLISH ,
5937
) -> Schema__Comprehend__Batch__Detect_Entities:
6038
with capture_duration() as duration:
61-
result = self.client().batch_detect_entities(TextList = text_list ,
62-
LanguageCode = language_code.value)
63-
64-
# Process successful results
65-
result_list = []
66-
for item in result.get('ResultList', []):
67-
entities_list = []
68-
for entity in item.get('Entities', []):
69-
entities_list.append(dict(text = entity.get('Text' ) ,
70-
type = entity.get('Type' ) ,
71-
score = entity.get('Score' ) ,
72-
begin_offset = entity.get('BeginOffset') ,
73-
end_offset = entity.get('EndOffset' )))
74-
75-
result_list.append(dict(index = item.get('Index') ,
76-
entities = entities_list ))
77-
78-
# Process errors
79-
error_list = []
80-
for error in result.get('ErrorList', []):
81-
error_list.append(dict(index = error.get('Index') ,
82-
error_code = error.get('ErrorCode') ,
83-
error_message = error.get('ErrorMessage')))
84-
85-
return Schema__Comprehend__Batch__Detect_Entities(result_list = result_list ,
86-
error_list = error_list ,
87-
duration = duration.seconds)
39+
result = self.client.batch_detect_entities(TextList = text_list ,
40+
LanguageCode = language_code.value)
8841

89-
# ============================================================================
90-
# BATCH KEY PHRASE DETECTION
91-
# ============================================================================
42+
return Schema__Comprehend__Batch__Detect_Entities(result_list = [self.response_utils.process_batch_entities_result(item)
43+
for item in result.get('ResultList', [])] ,
44+
error_list = self.response_utils.process_batch_errors(result) ,
45+
duration = duration.seconds)
9246

9347
@type_safe
9448
def batch_detect_key_phrases(self, text_list : List[Safe_Str__Comprehend__Text] ,
9549
language_code : Enum__Comprehend__Language_Code = Enum__Comprehend__Language_Code.ENGLISH ,
9650
) -> Schema__Comprehend__Batch__Detect_Key_Phrases:
9751
with capture_duration() as duration:
98-
result = self.client().batch_detect_key_phrases(TextList = text_list ,
99-
LanguageCode = language_code.value)
100-
101-
# Process successful results
102-
result_list = []
103-
for item in result.get('ResultList', []):
104-
key_phrases_list = []
105-
for phrase in item.get('KeyPhrases', []):
106-
key_phrases_list.append(dict(text = phrase.get('Text' ) ,
107-
score = phrase.get('Score' ) ,
108-
begin_offset = phrase.get('BeginOffset') ,
109-
end_offset = phrase.get('EndOffset' )))
110-
111-
result_list.append(dict(index = item.get('Index') ,
112-
key_phrases = key_phrases_list ))
113-
114-
# Process errors
115-
error_list = []
116-
for error in result.get('ErrorList', []):
117-
error_list.append(dict(index = error.get('Index') ,
118-
error_code = error.get('ErrorCode') ,
119-
error_message = error.get('ErrorMessage')))
120-
121-
return Schema__Comprehend__Batch__Detect_Key_Phrases(result_list = result_list ,
122-
error_list = error_list ,
123-
duration = duration.seconds)
52+
result = self.client.batch_detect_key_phrases(TextList = text_list ,
53+
LanguageCode = language_code.value)
12454

125-
# ============================================================================
126-
# BATCH DOMINANT LANGUAGE DETECTION
127-
# ============================================================================
55+
return Schema__Comprehend__Batch__Detect_Key_Phrases(result_list = [self.response_utils.process_batch_key_phrases_result(item)
56+
for item in result.get('ResultList', [])] ,
57+
error_list = self.response_utils.process_batch_errors(result) ,
58+
duration = duration.seconds)
12859

12960
@type_safe
13061
def batch_detect_dominant_language(self, text_list : List[Safe_Str__Comprehend__Text]
13162
) -> Schema__Comprehend__Batch__Detect_Dominant_Language:
13263
with capture_duration() as duration:
133-
result = self.client().batch_detect_dominant_language(TextList = text_list)
134-
135-
# Process successful results
136-
result_list = []
137-
for item in result.get('ResultList', []):
138-
languages_list = []
139-
for language in item.get('Languages', []):
140-
languages_list.append(dict(language_code = language.get('LanguageCode') ,
141-
score = language.get('Score' )))
142-
143-
result_list.append(dict(index = item.get('Index') ,
144-
languages = languages_list ))
64+
result = self.client.batch_detect_dominant_language(TextList = text_list)
14565

146-
# Process errors
147-
error_list = []
148-
for error in result.get('ErrorList', []):
149-
error_list.append(dict(index = error.get('Index') ,
150-
error_code = error.get('ErrorCode') ,
151-
error_message = error.get('ErrorMessage')))
152-
153-
return Schema__Comprehend__Batch__Detect_Dominant_Language(result_list = result_list ,
154-
error_list = error_list ,
155-
duration = duration.seconds)
156-
157-
# ============================================================================
158-
# BATCH SYNTAX DETECTION
159-
# ============================================================================
66+
return Schema__Comprehend__Batch__Detect_Dominant_Language(result_list = [self.response_utils.process_batch_dominant_language_result(item)
67+
for item in result.get('ResultList', [])] ,
68+
error_list = self.response_utils.process_batch_errors(result) ,
69+
duration = duration.seconds)
16070

16171
@type_safe
16272
def batch_detect_syntax(self, text_list : List[Safe_Str__Comprehend__Text] ,
16373
language_code : Enum__Comprehend__Language_Code = Enum__Comprehend__Language_Code.ENGLISH ,
16474
) -> Schema__Comprehend__Batch__Detect_Syntax:
165-
16675
with capture_duration() as duration:
167-
result = self.client().batch_detect_syntax(TextList = text_list ,
168-
LanguageCode = language_code.value)
169-
170-
# Process successful results
171-
result_list = []
172-
for item in result.get('ResultList', []):
173-
tokens_list = []
174-
for token in item.get('SyntaxTokens', []):
175-
pos = token.get('PartOfSpeech', {})
176-
tokens_list.append(dict(text = token.get('Text' ) ,
177-
token_id = token.get('TokenId' ) ,
178-
begin_offset = token.get('BeginOffset') ,
179-
end_offset = token.get('EndOffset' ) ,
180-
part_of_speech = dict(tag = pos.get('Tag' ) ,
181-
score = pos.get('Score'))))
182-
183-
result_list.append(dict(index = item.get('Index') ,
184-
syntax_tokens = tokens_list ))
185-
186-
# Process errors
187-
error_list = []
188-
for error in result.get('ErrorList', []):
189-
error_list.append(dict(index = error.get('Index') ,
190-
error_code = error.get('ErrorCode') ,
191-
error_message = error.get('ErrorMessage')))
76+
result = self.client.batch_detect_syntax(TextList = text_list ,
77+
LanguageCode = language_code.value)
19278

193-
return Schema__Comprehend__Batch__Detect_Syntax(result_list = result_list ,
194-
error_list = error_list ,
195-
duration = duration.seconds)
79+
return Schema__Comprehend__Batch__Detect_Syntax(result_list = [self.response_utils.process_batch_syntax_result(item)
80+
for item in result.get('ResultList', [])] ,
81+
error_list = self.response_utils.process_batch_errors(result) ,
82+
duration = duration.seconds)

0 commit comments

Comments
 (0)