8 changes: 4 additions & 4 deletions spacy_llm/tasks/entity_linker/task.py
```diff
@@ -102,8 +102,8 @@ def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]:
         ) = self._find_entity_candidates(docs)
         # Reset shard-wise candidate info. Will be set for each shard individually in _get_prompt_data(). We cannot
         # update it here, as we don't know yet what the shards will look like.
-        self._ents_cands_by_shard = [[] * len(self._ents_cands_by_doc)]
-        self._has_ent_cands_by_shard = [[] * len(self._ents_cands_by_doc)]
+        self._ents_cands_by_shard = [[]] * len(self._ents_cands_by_doc)
+        self._has_ent_cands_by_shard = [[]] * len(self._ents_cands_by_doc)
         self._n_shards = None
         return [
             EntityLinkerTask.highlight_ents_in_doc(doc, self._has_ent_cands_by_doc[i])
```
```diff
@@ -141,8 +141,8 @@ def _get_prompt_data(
         # shards. In this case we have to reset task state as well.
         if n_shards != self._n_shards:
             self._n_shards = n_shards
-            self._ents_cands_by_shard = [[] * len(self._ents_cands_by_doc)]
-            self._has_ent_cands_by_shard = [[] * len(self._ents_cands_by_doc)]
+            self._ents_cands_by_shard = [[]] * len(self._ents_cands_by_doc)
+            self._has_ent_cands_by_shard = [[]] * len(self._ents_cands_by_doc)
 
         # It's not ideal that we have to run candidate selection again here - but due to (1) us wanting to know whether
         # all entities have candidates before sharding and, more importantly, (2) some entities possibly being split up in
```
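For context on the one-character fix: in Python, `[] * n` always evaluates to `[]`, so the old expression `[[] * len(...)]` built a single-element list `[[]]` no matter how many docs there were, while the corrected `[[]] * len(...)` allocates one placeholder per doc. A minimal sketch of the difference (the final caveat assumes, per the comment in the diff, that the per-shard slots are later reassigned in `_get_prompt_data()` rather than mutated in place):

```python
n_docs = 3

# Buggy form: [] * n_docs is still [], so the outer list always has exactly one element.
buggy = [[] * n_docs]
assert buggy == [[]] and len(buggy) == 1

# Fixed form: one (initially empty) placeholder per doc.
fixed = [[]] * n_docs
assert fixed == [[], [], []] and len(fixed) == 3

# Caveat: [[]] * n repeats a reference to the *same* inner list. That is fine
# when each slot is later reassigned (fixed[i] = ...), but not when mutated:
fixed[0].append("x")
assert fixed == [["x"], ["x"], ["x"]]  # all three slots alias one list
```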
25 changes: 25 additions & 0 deletions spacy_llm/tests/tasks/test_entity_linker.py
```diff
@@ -792,3 +792,28 @@ def test_init_with_code():
     nlp.add_pipe("llm_entitylinker")
     with pytest.raises(ValueError, match="candidate_selector has to be provided"):
         nlp.initialize()
+
+
+def test_entity_linker_on_splitted_chunks(zeroshot_cfg_string, tmp_path):
+    config = Config().from_str(
+        zeroshot_cfg_string,
+        overrides={
+            "paths.el_nlp": str(tmp_path),
+            "paths.el_kb": str(tmp_path / "entity_linker" / "kb"),
+            "paths.el_desc": str(tmp_path / "desc.csv"),
+        },
+    )
+    build_el_pipeline(nlp_path=tmp_path, desc_path=tmp_path / "desc.csv")
+    nlp = assemble_from_config(config)
+    nlp_ner = spacy.load("en_core_web_md")
+    docs = [nlp_ner(text) for text in [
+        "Alice goes to Boston to see the Boston Celtics game.",
+        "Alice goes to New York to see the New York Knicks game.",
+        "I went to see Boston in concert yesterday",
+        "Thibeau Courtois plays for the Red Devils in New York",
+    ]]
+    docs = list(nlp.pipe(docs, batch_size=50))
+    data = [[(ent.text, ent.label_, ent.kb_id_) for ent in doc.ents] for doc in docs]
+    assert len(docs) == 4
+    assert docs[0].ents[1].text == "Boston"
+    assert docs[0].ents[1].kb_id_ == "Q100"
```
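Assuming `en_core_web_md` is installed, the new test can be run on its own with `pytest spacy_llm/tests/tasks/test_entity_linker.py::test_entity_linker_on_splitted_chunks`.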