67 changes: 67 additions & 0 deletions cellsem_agent/graphs/cxg_annotate/README.md
# CxG experiments

We have two experiments related to CxG annotation:
- **Old one**: `cellsem_agent/graphs/cxg_annotate/cxg_annotate_graph.py`
  This experiment uses a set of manually downloaded papers related to gut, retina, etc.
- **New one**: `cellsem_agent/graphs/cxg_annotate/cxg_annotate_graph_v2.py`
  This is the new experiment. An example input file is here: [cellsem_agent/graphs/cxg_annotate/resources/ac8619d0-4fff-4296-913a-819d1e361ba0_cxg_dataset_unique.tsv](cellsem_agent/graphs/cxg_annotate/resources/ac8619d0-4fff-4296-913a-819d1e361ba0_cxg_dataset_unique.tsv); the expected columns are shown below.
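
Based on `load_cxg_annotations` in the script, the input TSV must provide at least the columns below (the example row values are illustrative only):
```
reference	author_cell_type	CL_ID	CL_label
https://doi.org/10.1038/s41586-018-0698-6	Ent	CL:0000584	enterocyte
```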

## New experiment workflow

`cxg_annotate_graph_v2.py` has a `main` function that runs the whole experiment graph. The experiment graph is as follows:

```mermaid
---
title: validation_graph
---
stateDiagram-v2
PrepareData --> GetFullNames
GetFullNames --> GetGroundings
GetGroundings --> [*]
```

- **PrepareData**: Prepares the data for the experiment. It reads the input TSV file, converts it into the desired format, and stores it in `ctx.state`, the state shared among the workflow tasks. This step also downloads the required publications.
- **GetFullNames**: Uses ChatGPT to get the full names of the cell types based on the publication full text.
- **GetGroundings**: Uses Annotator Agent (`cellsem_agent/agents/annotator/annotator_agent.py`) to get the ontology groundings for the cell types.

You can run the script directly (its `main` function runs the whole graph), or drive the graph yourself as sketched below.
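
A minimal driver mirroring the script's `main` (run from the repo root; the node classes and `State` are imported from the module itself):

```python
# Minimal sketch: wire the three nodes into a pydantic_graph Graph and run it.
import asyncio

from dotenv import load_dotenv
from pydantic_graph import Graph

from cellsem_agent.graphs.cxg_annotate.cxg_annotate_graph_v2 import (
    GetFullNames, GetGroundings, PrepareData, State,
)

async def run_experiment():
    state = State(set(), list(), dict(), dict())
    graph = Graph(nodes=(PrepareData, GetFullNames, GetGroundings))
    result = await graph.run(PrepareData(), state=state)
    print(result.output)

load_dotenv()  # reads OPENAI_API_KEY from the .env file described below
asyncio.run(run_experiment())
```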

An environment file named `.env` is needed at the project root (`cellsem-agent/.env`) with the following variables:
```
OPENAI_API_KEY=
```

### Outputs:

When the pipeline is run, two main outputs are generated:
- `cellsem_agent/graphs/cxg_annotate/resources/groundings.tsv`: The main annotation results. `grounding_cl_id` and `grounding_cl_label` are found by the agent; `cl_id` and `cl_label` are the truth values from the input file.
- `cellsem_agent/graphs/cxg_annotate/resources/cell_type_annotations_un_filtered.tsv`: An intermediate file containing every grounding the agent found. The agent uses both the full name and the abbreviation to find groundings, so this file is useful if you want to optimize the prioritization logic. Currently the full name has the highest priority, and its grounding is returned first.

### Statistics:

A manual step is needed to calculate the metrics after the experiment is run. The metrics script is here: `cellsem_agent/graphs/nlm_annotate/grounding_statistics.py`. Update the script to point to the correct `groundings.tsv` file and run it.
The script prints something like this:

```
Truth table: TP=19, FP=13, FN=0, TN=0
Precision: 0.594
Recall: 1.000
F1 score: 0.745
```
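
The metrics are standard precision/recall/F1 over the agreement between `cl_id` and `grounding_cl_id`. A minimal sketch of the computation (assuming the `groundings.tsv` columns described above and counting missing groundings as false negatives; the real script may count differently):

```python
# Sketch of the metrics computation over groundings.tsv.
import pandas as pd

df = pd.read_csv("cellsem_agent/graphs/cxg_annotate/resources/groundings.tsv", sep="\t")

grounded = df["grounding_cl_id"].fillna("").astype(str) != ""
tp = int((grounded & (df["cl_id"] == df["grounding_cl_id"])).sum())  # correct grounding
fp = int((grounded & (df["cl_id"] != df["grounding_cl_id"])).sum())  # wrong grounding
fn = int((~grounded).sum())                                          # no grounding returned

precision = tp / (tp + fp) if (tp + fp) else 0.0
recall = tp / (tp + fn) if (tp + fn) else 0.0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0

print(f"Truth table: TP={tp}, FP={fp}, FN={fn}, TN=0")
print(f"Precision: {precision:.3f}")
print(f"Recall: {recall:.3f}")
print(f"F1 score: {f1:.3f}")
```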

### Running in test mode:

If you set `IS_TEST_MODE=True`, the experiment runs in test mode: only a small subset of the data (`TEST_ANNOTATIONS_COUNT=4` annotations) is processed, which allows for quick testing and debugging during development.

Set `IS_TEST_MODE=False` to run the full experiment.

### Beware of caching:

The experiment uses caching to store intermediate results, avoiding redundant computation and expensive ChatGPT calls. If you change the code or the input data, you may need to clear the cache to ensure that the experiment runs with the latest information.

Here are the cache directories used in the experiment:
- `cellsem_agent/graphs/cxg_annotate/resources/publications`: Publications downloaded in the `PrepareData` step are stored here in the format `DOI_10_1038_s41586-018-0698-6.txt`
- `cellsem_agent/graphs/cxg_annotate/resources/expansions`: Cache of the `GetFullNames` step. Example cache file name: `DOI_10_1038_s41586-018-0698-6_batch_0.json`
- `cellsem_agent/graphs/cxg_annotate/resources/cache`: Cache of the `GetGroundings` step. Example cache file name: `groundings_batch_0.json`

Delete these folders as needed to clear the cache and run a fresh (but $$$) experiment; the folders are recreated automatically when the script runs. A helper sketch for clearing them follows.
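
For convenience, a small sketch that clears all three cache folders at once (paths relative to the repo root):

```python
# Removes the cache folders listed above; they are recreated on the next run.
import shutil
from pathlib import Path

resources = Path("cellsem_agent/graphs/cxg_annotate/resources")
for folder in ("publications", "expansions", "cache"):
    shutil.rmtree(resources / folder, ignore_errors=True)
```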
318 changes: 318 additions & 0 deletions cellsem_agent/graphs/cxg_annotate/cxg_annotate_graph_v2.py
import asyncio
import os.path
import json
import pandas as pd

from dotenv import load_dotenv
from pydantic_graph import BaseNode, End, Graph, GraphRunContext

from cellsem_agent.agents.annotator.annotator_agent import annotator_agent
from cellsem_agent.agents.paper_celltype.paper_celltype_agent import celltype_agent, CellTypeEntry
from cellsem_agent.agents.annotator.annotator_agent import TextAnnotation
from cellsem_agent.utils.pubmed_utils import get_doi_text

from dataclasses import dataclass
import logfire
import logging

cxg_annotate_logger = logging.getLogger(__name__)
cxg_annotate_logger.setLevel(logging.INFO)
console = logging.StreamHandler()
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console.setFormatter(formatter)
cxg_annotate_logger.addHandler(console)

cxg_annotate_logger.propagate = True
logfire.configure()

# Number of cell type annotations sent to the agent per request in GetFullNames
ANNOTATIONS_BATCH_SIZE = 5

IS_TEST_MODE = False
TEST_ANNOTATIONS_COUNT = 4 # Number of annotations to process in test mode

CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
RESOURCES_DIR = os.path.join(CURRENT_DIR, "resources")
PUBLICATIONS_DIR = os.path.join(RESOURCES_DIR, "publications")
EXPANSIONS_DIR = os.path.join(RESOURCES_DIR, "expansions")


@dataclass
class Dataset:
name: str
publication_file_name: str
supplementary_file_name: str
data_file_name: str


@dataclass
class State:
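    """Shared state passed between the workflow nodes via ctx.state."""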
articles: set[str]
annotations: list[dict]
    article_to_annotations: dict[str, list[dict]]
paper_expansion: dict[str, CellTypeEntry]
is_test_mode: bool = IS_TEST_MODE


@dataclass
class GetGroundings(BaseNode[State, None, str]):
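    """Grounds each cell type via the Annotator Agent and writes groundings.tsv and cell_type_annotations_un_filtered.tsv."""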

    async def run(self, ctx: GraphRunContext[State]) -> End[str]:
annotations = ctx.state.annotations
cxg_annotate_logger.info(f"Total annotations to process: {len(annotations)}")

for annotation in annotations:
if 'enrichment' not in annotation:
annotation['enrichment'] = CellTypeEntry(
name=annotation['annotation_text'],
full_name="",
paper_synonyms="",
tissue_context=""
)
print(f"Warning: No enrichment found for annotation '{annotation['annotation_text']}', using blank entry.")
# delete tissue_context of all enrichments
annotation['enrichment'].tissue_context = ""

# Sort annotations by article_id_doi, then annotation_text
annotations.sort(key=lambda annot: (annot.get('article_id_doi') or "", annot.get('annotation_text') or ""))

cache_dir = os.path.join(RESOURCES_DIR, "cache")
os.makedirs(cache_dir, exist_ok=True)

batch_size = 4
all_groundings = []
for i in range(0, len(annotations), batch_size):
batch_index = i // batch_size
batch = annotations[i:i + batch_size]
batch_cache_path = os.path.join(cache_dir, f"groundings_batch_{batch_index}.json")

if os.path.exists(batch_cache_path):
print(f"Loading cached results for batch {batch_index}")
with open(batch_cache_path, "r") as f:
batch_groundings = [TextAnnotation(**entry) for entry in json.load(f)]
else:
print("Processing batch: ", i // batch_size + 1, " of ",
(len(annotations) + batch_size - 1) // batch_size)
expansions_json = json.dumps([annotation['enrichment'].model_dump() for annotation in batch], indent=2)
agent_response = await annotator_agent.run(expansions_json)
batch_groundings = agent_response.output.annotations
with open(batch_cache_path, "w") as f:
json.dump([entry.model_dump() for entry in batch_groundings], f, indent=2)

all_groundings.extend(batch_groundings)
# update batch annotations with grounding results
for annotation in batch:
                # convert enrichment to a plain dict so the DataFrame output is readable
annotation['enrichment'] = annotation['enrichment'].model_dump()
if "grounding_cl_id" not in annotation:
related_groundings = [gr for gr in batch_groundings if
gr.input_name == annotation['annotation_text']]
if related_groundings:
valid_grounding = next(
(g for g in related_groundings if "NO MATCH" not in g.cl_id), None)
if valid_grounding:
grounding_to_use = valid_grounding
else:
grounding_to_use = related_groundings[0]
annotation['grounding_cl_id'] = grounding_to_use.cl_id
annotation['grounding_cl_label'] = grounding_to_use.cl_label
else:
annotation['grounding_cl_id'] = ""
annotation['grounding_cl_label'] = ""


data = [entry.model_dump() for entry in all_groundings]
df = pd.DataFrame(data)
df.to_csv(os.path.join(RESOURCES_DIR, "cell_type_annotations_un_filtered.tsv"), sep='\t', index=False)

        # write annotations that have groundings as TSV (annotation_text, cl_id, grounding_cl_id, grounding_cl_label, article_id_doi)
df = pd.DataFrame(annotations)
        df_filtered = df[df['grounding_cl_id'].notna()].copy()  # copy to avoid SettingWithCopyWarning
        df_filtered['result'] = df_filtered['cl_id'].eq(df_filtered['grounding_cl_id']).map(
            {True: 'TRUE', False: 'FALSE'})
df_filtered.to_csv(os.path.join(RESOURCES_DIR, "groundings.tsv"), sep='\t', index=False)

return End("Report generated and saved to individual dataset folders.")


@dataclass
class GetFullNames(BaseNode[State, None, str]):
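    """Asks the cell type agent to expand abbreviated cell type labels into full names using the publication full text."""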

async def run(self, ctx: GraphRunContext[State]) -> GetGroundings:
print("Running GetFullNames node")
if not os.path.exists(EXPANSIONS_DIR):
os.makedirs(EXPANSIONS_DIR)
article_to_annotations = ctx.state.article_to_annotations
articles = sorted(str(a) if a is not None else "" for a in set(article_to_annotations.keys()))
index = 1
for article_pmc in articles:
print(f"Processing article: {article_pmc} - {index}/{len(articles)}")
index += 1
# get all annotations for this article
article_annotations = article_to_annotations[article_pmc]

for batch_index in range(0, len(article_annotations), ANNOTATIONS_BATCH_SIZE):
batch = article_annotations[batch_index:batch_index + ANNOTATIONS_BATCH_SIZE]
dataset_cache = os.path.join(EXPANSIONS_DIR,
f"{normalise_file_name(article_pmc)}_batch_{batch_index // ANNOTATIONS_BATCH_SIZE}.json")
cc_labels = [{"cc.label": ann['annotation_text']} for ann in batch]

if not os.path.exists(dataset_cache):
                    # full texts are downloaded by PrepareData into PUBLICATIONS_DIR, not EXPANSIONS_DIR
                    full_text_path = os.path.join(PUBLICATIONS_DIR, f"{normalise_file_name(article_pmc)}.txt")
if os.path.exists(full_text_path):
with open(full_text_path, 'r', encoding='utf-8') as f:
paper_full_text = f.read()

prompt_instructions = f"""
You are tasked with extracting cell type information from the provided academic paper content,
and the provided JSON data.

The JSON contains cell type annotations (cc.label column) from single-cell transcriptomic data.

Based on the following JSON data and academic paper content, generate a list of structured
cell type entries. Each entry must follow the `CellTypeEntry` schema.

--- JSON List Input Data:
{json.dumps(cc_labels, indent=2)}

--- Academic Paper Content (extracted from PDF):
{paper_full_text}

--- COLUMN DEFINITIONS AND LOGIC:
- `name`: The exact `cc.label` from the input JSON.
- `full_name`: Use the following logic:
1. If the full label (e.g., "SI_TA") is defined directly in the paper, use the exact definition.
2. If not, check if individual parts (e.g., prefixes, suffixes) are defined and reconstruct/assemble the `full_name` from the parts found (e.g., for "SI_TA", assemble "small intestine transit amplifying cell" if paper defines "SI" as "small intestine" and "TA" as "transit amplifying cell").
3. If the label begins with a defined prefix abbreviation (e.g., "RGC"), expand the prefix and append the remaining label (e.g., "RGC10" becomes "retinal ganglion cell 10").
4. If only one part is defined, use just that part.
5. If no parts are defined, leave this field blank.
- `paper_synonyms`: Use only synonyms mentioned in the paper using:
- Abbreviation lists
- Abbreviation definitions (e.g., "follicle-associated epithelium (FAE)")
- Patterns like “also known as”, “termed”, “referred to as”
- Include all found; separate with semicolons (;)
- `tissue_context`: Exact quoted tissue(s) or anatomical terms from the paper where the cell type was identified.

Process all `cc.label` entries from the JSON data automatically.
Do not ask for confirmation.
Provide the output as a JSON array of `CellTypeEntry` objects.
"""
agent_response = await celltype_agent.run(prompt_instructions)

for entry in agent_response.output.cell_type_annotations:
print(
f"Name: {entry.name}, Full Name: {entry.full_name}, Synonyms: {entry.paper_synonyms}, Tissue Context: {entry.tissue_context}")
# add entry to the related article_annotations
for ann in article_annotations:
if ann['annotation_text'] == entry.name:
ann['enrichment'] = entry
break

# ctx.state.paper_expansion[article_pmc] = agent_response.output.cell_type_annotations
expansions = agent_response.output.cell_type_annotations
print(f"Saving results to cache for article: {article_pmc}")
with open(dataset_cache, 'w') as cache_file:
json.dump(
[entry.model_dump() for entry in expansions],
cache_file, indent=2)
else:
print(f"Error: Full text file not found for article for name expansion: {article_pmc}")
else:
print(f"Using cached data for article: {article_pmc}")
with open(dataset_cache, 'r') as cache_file:
cached_data = json.load(cache_file)
for cached_entry in cached_data:
for ann in article_annotations:
if ann['annotation_text'] == cached_entry["name"]:
ann['enrichment'] = CellTypeEntry(**cached_entry)
print("Using cached enrichment data for annotation:", ann['annotation_text'])
break
# ctx.state.paper_expansion[article_pmc] = [CellTypeEntry(**entry) for entry in cached_data]
return GetGroundings()

@dataclass
class PrepareData(BaseNode[State, None, str]):
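    """Loads the input TSV annotations, downloads the publication full texts, and stores both in ctx.state."""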

async def run(self, ctx: GraphRunContext[State]) -> GetFullNames:
print("Running PrepareData node")
annotations, article_to_annotations = load_cxg_annotations()

if ctx.state.is_test_mode:
# only process a few annotations in test mode
annotations = list(annotations)[:TEST_ANNOTATIONS_COUNT]
# filter article_to_annotations to only include those in annotations
article_to_annotations = {k: v for k, v in article_to_annotations.items() if k in
{ann['article_id_doi'] for ann in annotations}}

unique_dois = set(article_to_annotations.keys())
print(f"Unique DOISs to download: {len(unique_dois)}")
articles = download_publication_texts(unique_dois)
print(f"Downloaded articles: {len(articles)}")

ctx.state.articles = articles
ctx.state.annotations = annotations
ctx.state.article_to_annotations = article_to_annotations

return GetFullNames()

def load_cxg_annotations():
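    """Reads the input TSV and returns (annotations, article_to_annotations) keyed by DOI."""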
    # resolve relative to this module so the script works from any working directory
    tsv_path = os.path.join(RESOURCES_DIR, "ac8619d0-4fff-4296-913a-819d1e361ba0_cxg_dataset_unique.tsv")
df = pd.read_csv(tsv_path, sep='\t')

annotations = []
article_to_annotations = {}

for _, row in df.iterrows():
paper_doi = str(row['reference']).replace("https://doi.org/", "DOI:")
annotation = {
'annotation_text': row['author_cell_type'],
'cl_id': row['CL_ID'],
'cl_label': row['CL_label'],
'article_id_doi': paper_doi
}
annotations.append(annotation)
article_to_annotations.setdefault(paper_doi, []).append(annotation)

return annotations, article_to_annotations

def download_publication_texts(dois, publications_dir=PUBLICATIONS_DIR):
"""
Download full text for each DOI using get_doi_text and save to publications_dir/pmc_id.txt.
Skips download if file already exists. Creates publications_dir if needed.
Args:
dois (Iterable[str]): Set or list of PMC IDs.
publications_dir (str): Directory to save text files.
"""
if not os.path.exists(publications_dir):
os.makedirs(publications_dir)
articles = set()
for doi in dois:
if doi:
file_path = os.path.join(publications_dir, f"{normalise_file_name(doi)}.txt")
if os.path.exists(file_path):
articles.add(doi)
continue
text = get_doi_text(doi)
if text:
with open(file_path, "w", encoding="utf-8") as f:
f.write(text)
articles.add(doi)
else:
print(f"Error: No full-text found for ID {doi}")
return articles

def normalise_file_name(doi: str) -> str:
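    """Makes a DOI usable as a file name, e.g. "DOI:10.1038/s41586-018-0698-6" -> "DOI_10_1038_s41586-018-0698-6"."""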
return doi.replace("/", "_").replace(":", "_").replace(".", "_")

async def main():
state = State(set(), list(), dict(), dict(), is_test_mode=IS_TEST_MODE)
validation_graph = Graph(nodes=(PrepareData, GetFullNames, GetGroundings))
result = await validation_graph.run(PrepareData(), state=state)
print(result.output)
# print(validation_graph.mermaid_code())


if __name__ == "__main__":
load_dotenv()
print(os.environ.get("OPENAI_API_KEY"))
asyncio.run(main())