Skip to content

Commit

Permalink
Merge pull request #144 from calpoly-csai/listprops
Browse files Browse the repository at this point in the history
QA enhancements
  • Loading branch information
cameron-toy authored May 17, 2020
2 parents cc91cbd + 94eaada commit d329427
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 33 deletions.
109 changes: 80 additions & 29 deletions QA.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from Entity.Profs import Profs
from Entity.Clubs import Clubs
from Entity.Sections import Sections
from Entity.ProfessorSectionView import ProfessorSectionView
from database_wrapper import NimbusMySQLAlchemy
import itertools

Expand All @@ -21,6 +22,7 @@
"COURSE": Courses,
"SECRET_HIDEOUT": Locations,
"SECTION": Sections,
"PROF_SECTION": ProfessorSectionView
}


Expand Down Expand Up @@ -80,9 +82,15 @@ def string_sub(a_format):
return functools.partial(_string_sub, a_format)


def _get_property(prop: str, extracted_info: Extracted_Vars, db: NimbusMySQLAlchemy):
def _get_property(prop: str,
table: str,
extracted_info: Extracted_Vars,
db: NimbusMySQLAlchemy):
ent_string = extracted_info["normalized entity"]
ent = tag_lookup[extracted_info["tag"]]
if table is None:
ent = tag_lookup[extracted_info["tag"]]
else:
ent = tag_lookup[table]
try:
value = db.get_property_from_entity(
prop=prop, entity=ent, identifier=ent_string
Expand All @@ -93,8 +101,35 @@ def _get_property(prop: str, extracted_info: Extracted_Vars, db: NimbusMySQLAlch
return {f"db_{prop}": value}


def get_property(prop: str):
return functools.partial(_get_property, prop)
def get_property(prop: str, table: str = None):
    """Build a DB query callable that fetches a single property.

    Args:
        prop: Name of the property to fetch.
        table: Optional key into ``tag_lookup`` naming the entity to query.
            When ``None``, ``_get_property`` falls back to the tag found in
            the extracted variables.

    Returns:
        ``_get_property`` with ``prop`` and ``table`` pre-bound; the
        remaining arguments (``extracted_info`` and ``db``) are supplied
        by the caller.
    """
    bound_query = functools.partial(_get_property, prop, table)
    return bound_query


def _get_property_list(prop: str,
                       joiner: str,
                       table: str,
                       extracted_info: Extracted_Vars,
                       db: NimbusMySQLAlchemy):
    """Fetch every matching value of ``prop`` for the extracted entity.

    Args:
        prop: Name of the property to fetch; also used to build the
            ``db_<prop>`` key of the returned dict.
        joiner: Word meant to join the last two items of the list.
            NOTE(review): it is not forwarded to ``_grammatical_join``
            here; that helper's signature is not visible from this block,
            so confirm whether it accepts a joiner before wiring it in.
        table: Key into ``tag_lookup`` naming the entity to query, or
            ``None`` to use the tag stored in ``extracted_info``.
        extracted_info: Extracted variables; the ``"normalized entity"``
            and (when ``table`` is ``None``) ``"tag"`` keys are read.
        db: Database wrapper providing ``_get_property_from_entity``.

    Returns:
        ``{"db_<prop>": <joined values>}``, or ``{"db_<prop>": None}`` when
        the lookup raises ``IndexError`` or yields no results.
    """
    ent_string = extracted_info["normalized entity"]
    ent = tag_lookup[extracted_info["tag"] if table is None else table]

    try:
        values = db._get_property_from_entity(
            prop=prop, entity=ent, identifier=ent_string
        )
    except IndexError:
        return {f"db_{prop}": None}

    # The lookup has a ``return None`` path, and an empty result would make
    # ``get_all_exact_matches`` index past the end outside the ``try``.
    # Treat both cases like a failed lookup instead of crashing.
    if not values:
        return {f"db_{prop}": None}

    exact_matches = get_all_exact_matches(values)
    return {f"db_{prop}": _grammatical_join(exact_matches)}

def get_property_list(prop: str,
                      joiner: str,
                      table: str = None):
    """Build a DB query callable that fetches a list-valued property.

    Args:
        prop: Name of the property to fetch.
        joiner: Joining word passed through to ``_get_property_list``.
        table: Optional key into ``tag_lookup`` naming the entity to query.

    Returns:
        ``_get_property_list`` with ``prop``, ``joiner`` and ``table``
        pre-bound; the caller supplies ``extracted_info`` and ``db``.
    """
    bound_query = functools.partial(_get_property_list, prop, joiner, table)
    return bound_query


def _generic_answer_formatter(
Expand Down Expand Up @@ -207,39 +242,55 @@ def chain_db_access(fns: List[DB_Query]) -> DB_Query:
return functools.partial(_chain_db_access, fns)


def get_all_exact_matches(matches):
    """Return the values of every record that shares the last record's key.

    Each record in ``matches`` is indexed positionally: index 1 is compared
    and index 2 is collected. The last record supplies the reference key,
    and records are visited from last to first, so the result starts with
    the last record's value.

    Args:
        matches: Sequence of indexable records (at least three items each).

    Returns:
        List of the index-2 items whose index-1 item equals that of the
        last record; an empty list when ``matches`` is empty or ``None``.
    """
    # Guard: indexing ``matches[-1]`` on an empty or missing result would
    # raise instead of reporting "no matches".
    if not matches:
        return []

    exact = matches[-1][1]
    return [match[2] for match in reversed(matches) if match[1] == exact]


def generate_qa_pairs(qa_pairs: Tuple[str, str], db: NimbusMySQLAlchemy):
text_in_brackets = r"\[[^\[\]]*\]"
qa_objs = []
for pair in qa_pairs:
q = pair[0]
a = pair[1]
q, a = pair
db_access_fns = []
# Find all bracketed tokens ([PROF], [COURSE..units], etc)
matches = re.findall(text_in_brackets, a)
ents = []
for match in matches:
# If match is a property
if ".." in match:
ent, prop = match[1:-1].split("..", 1)
db_access_fns.append(get_property(prop))
# "db" prefix is used to disambiguate database and entity data
# in _string_sub and _generic_answer_formatter. See above.
a = a.replace(match, "{db_" + prop + "}")
# If match is an entity
tokens = a.split()
for i, token in enumerate(tokens):
# I get errors if I don't cast token to a string here, even though str.split() should
# return a list of strings
match = re.match(r"\[(.*?)\]", str(token))
if not match:
continue
else:
ents.append(match)
if len(ents) == 1:
a = a.replace(ents[0], "{ex}")
else:
for ent in ents:
# "ex" prefix is added for the same reason as above.
# Not necessary for current _string_sub function, but useful
# for when we extract multiple variables
a = a.replace(ent, "{ex_" + ent[1:-1] + "}")
subtokens = match.group(1).split("..")
# Match is an entity
if len(subtokens) == 1:
tokens[i] = "{ex}"
# Match is a single-item property
elif len(subtokens) == 2:
ent, prop = subtokens
db_access_fns.append(get_property(prop))
tokens[i] = "{db_" + prop + "}"
elif len(subtokens) == 3:
ent, prop, third = subtokens
if third in tag_lookup:
# third is a table name
db_access_fns.append(get_property(prop, third))
else:
# third is the string used to join the last two of a list of items
db_access_fns.append(get_property_list(prop, third))
tokens[i] = "{db_" + prop + "}"
elif len(subtokens) == 4:
ent, prop, table, joiner = subtokens
db_access_fns.append(get_property_list(prop, joiner, table))
tokens[i] = "{db_" + prop + "}"

o = QA(
q_format=q,
db_query=chain_db_access(db_access_fns),
format_answer=string_sub(a),
format_answer=string_sub(" ".join(tokens)),
db=db,
)
qa_objs.append(o)
Expand Down
13 changes: 11 additions & 2 deletions database_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def partial_fuzzy_match(self, tag_value, identifier):
def full_fuzzy_match(self, tag_value, identifier):
    """Score ``tag_value`` against ``identifier`` with ``fuzz.ratio``.

    Returns whatever ``fuzz.ratio`` returns for the two arguments.
    """
    score = fuzz.ratio(tag_value, identifier)
    return score

def get_property_from_entity(
def _get_property_from_entity(
self,
prop: str,
entity: UNION_ENTITIES,
Expand Down Expand Up @@ -555,7 +555,16 @@ def get_property_from_entity(
return None

sorted_results = sorted(results, key=lambda pair: pair[0])
return sorted_results[-1][2]
return sorted_results

def get_property_from_entity(self,
                             prop: str,
                             entity: UNION_ENTITIES,
                             identifier: str,
                             tag_column_map: dict = default_tag_column_dict):
    """Return the property value of the last (highest-sorted) match.

    Delegates to ``_get_property_from_entity`` and picks index 2 of the
    last record in its result.

    Args:
        prop: Name of the property to fetch.
        entity: Entity class to query.
        identifier: Text used to identify the entity.
        tag_column_map: Mapping forwarded unchanged to
            ``_get_property_from_entity``.

    Returns:
        Index 2 of the last result record, or ``None`` when the
        underlying lookup returns ``None``.

    Raises:
        IndexError: If the underlying lookup returns an empty sequence.
    """
    props = self._get_property_from_entity(prop, entity, identifier, tag_column_map)
    # The underlying lookup has a ``return None`` path; subscripting that
    # would raise TypeError, so pass the ``None`` through to the caller.
    if props is None:
        return None
    return props[-1][2]

def get_course_properties(
self, department: str, course_num: Union[str, int]
Expand Down
6 changes: 4 additions & 2 deletions nimbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def predict_question(self, question):
ve = VariableExtractor()
db = NimbusMySQLAlchemy()
qa_pairs = db.get_all_answerable_pairs()
qa_pairs.append(("What sections does [PROF] teach?",
"[PROF] teaches [PROF..section_name..PROF_SECTION..and]."))
qa_dict = create_qa_mapping(generate_qa_pairs(qa_pairs, db))
extracted = ve.extract_variables("How do I zoom Dr. Khosmood?")
print(qa_dict["How do I zoom [PROF]?"].answer(extracted))
extracted = ve.extract_variables("What sections does Khosmood teach?")
print(qa_dict["What sections does [PROF] teach?"].answer(extracted))

0 comments on commit d329427

Please sign in to comment.