-
Notifications
You must be signed in to change notification settings - Fork 94
Add DIPGSafetyEnv for Medical AI Safety Research` #97
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 23 commits
Commits
Show all changes
59 commits
Select commit
Hold shift + click to select a range
a5e98b8
dipg safety
surfiniaburger 723ef99
Fix: Correct StepResult import in DIPG safety client
surfiniaburger e847824
DEBUG: Add print statement to client parser
surfiniaburger 05490a0
FIX: Handle double-nested observation in client parser
surfiniaburger 919833c
Merge pull request #1 from surfiniaburger/dipg-research
surfiniaburger f4073ad
FIX: Create robust client parser for reset/step inconsistency
surfiniaburger 7eb6a04
Merge pull request #2 from surfiniaburger/dipg-research
surfiniaburger cf389a2
Fix: Create robust client parser for server responses
surfiniaburger 7820e02
Merge pull request #3 from surfiniaburger/dipg-research
surfiniaburger ebb121f
cla
surfiniaburger 8187400
include test and readme
surfiniaburger 7ce1e12
default to an empty one if obs_data is None
surfiniaburger 1f3c5a7
Update src/envs/dipg_safety_env/README.md
surfiniaburger 569e902
Merge pull request #4 from surfiniaburger/dipg-research
surfiniaburger d8e7008
Feat: Implement code review feedback
surfiniaburger e10ded5
Update src/envs/dipg_safety_env/server/test_dipg_safety_env.py
surfiniaburger a41dd49
Merge pull request #5 from surfiniaburger/dipg-research
surfiniaburger 05568dd
correction
surfiniaburger 2584a01
Merge pull request #6 from surfiniaburger/dipg-research
surfiniaburger 0f09799
Feat: Add configurable timeout to DIPGSafetyEnv client
surfiniaburger b047ea2
Merge pull request #7 from surfiniaburger/dipg-research
surfiniaburger b4111db
Fix(client): Correctly pass timeout parameter to parent class
surfiniaburger 1ff4e49
Merge pull request #8 from surfiniaburger/dipg-research
surfiniaburger 48a16af
Architectural Improvements
surfiniaburger c2755ee
Merge pull request #9 from surfiniaburger/dipg-research
surfiniaburger 4820ea5
add channels to env
surfiniaburger 56ff6e8
Merge pull request #10 from surfiniaburger/dipg-research
surfiniaburger 885132b
update notebook
surfiniaburger 1ea027d
Merge pull request #11 from surfiniaburger/dipg-research
surfiniaburger 7ec0c8d
dipg-notebook
surfiniaburger 8085892
Merge pull request #12 from surfiniaburger/dipg-research
surfiniaburger 6d934c0
improve reset method
surfiniaburger d1cf785
Merge pull request #13 from surfiniaburger/dipg-research
surfiniaburger 7670637
use simulation for now
surfiniaburger 13eb147
Merge pull request #14 from surfiniaburger/dipg-research
surfiniaburger 5ea1c52
set max timeout
surfiniaburger 2392407
Merge pull request #15 from surfiniaburger/dipg-research
surfiniaburger 907d1e3
include all data
surfiniaburger d2715ae
Merge pull request #16 from surfiniaburger/dipg-research
surfiniaburger af7a0f7
pending bug fix
surfiniaburger 0aeaab9
Merge pull request #17 from surfiniaburger/dipg-research
surfiniaburger 26e8a12
revert change
surfiniaburger fdb22b5
use vanilla reset
surfiniaburger 2922991
Merge pull request #18 from surfiniaburger/dipg-research
surfiniaburger 4fdee22
revert vanilla
surfiniaburger 84b696c
Merge pull request #19 from surfiniaburger/dipg-research
surfiniaburger eb8bb9f
update fast-api create app
surfiniaburger aaa8dba
feat(dipg_safety_env): Improve test coverage and fix bugs
surfiniaburger 908a147
clean up
surfiniaburger 4eecb68
Merge pull request #20 from surfiniaburger/dipg-research
surfiniaburger a0500e5
log actions
surfiniaburger b2f48f3
Merge pull request #21 from surfiniaburger/dipg-research
surfiniaburger ba81311
use print
surfiniaburger 3a8d4b4
notebook and demo link
surfiniaburger 5463970
update
surfiniaburger a977421
Merge pull request #22 from surfiniaburger/dipg-research
surfiniaburger 2037ccb
re-add logger
surfiniaburger e63a1fa
removed output
surfiniaburger 03b804e
Merge pull request #23 from surfiniaburger/dipg-research
surfiniaburger File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,119 @@ | ||
| # DIPG Safety Environment (DIPGSafetyEnv) | ||
|
|
||
| ## Overview | ||
|
|
||
| The `DIPGSafetyEnv` is a custom environment built on the OpenEnv framework for Reinforcement Learning research in high-stakes AI safety. It was developed to address a critical use case: ensuring the reliability and safety of a Large Language Model (LLM) agent operating in the medical domain of **Diffuse Intrinsic Pontine Glioma (DIPG)**, a universally fatal pediatric brain tumor. | ||
|
|
||
| In this context, an AI's failure is not an option. The environment's primary purpose is to train and rigorously evaluate an agent's ability to: | ||
| 1. Base its answers *only* on the verified clinical context provided. | ||
| 2. Correctly identify and report conflicting information from different sources. | ||
| 3. Safely abstain from answering when the context is insufficient. | ||
| 4. Strictly avoid hallucinating facts or providing unsafe, unsupported information. | ||
|
|
||
| ## Features | ||
|
|
||
| The environment server contains a suite of safety-critical reward functions that score an agent's response based on the following behaviors: | ||
|
|
||
| * **Conflict Identification:** Rewards the agent for correctly stating that provided sources are contradictory. | ||
| * **Knowledge Abstention:** Rewards the agent for recognizing when a question cannot be answered from the given text and explicitly saying so. | ||
| * **Format Adherence:** Positively or negatively scores the response based on its adherence to a required structured output format. | ||
| * **Hallucination Penalty:** Heavily penalizes the agent for generating any information that is not supported by the provided context. | ||
|
|
||
| ## Getting Started: How to Use the Environment | ||
|
|
||
| The `DIPGSafetyEnv` follows a standard client-server model. | ||
|
|
||
| ### 1. Running the Server | ||
|
|
||
| The server requires the custom synthetic dataset (`harmonic_reasoner_dataset_structured.jsonl`) to be present in its directory. The easiest way to run the server is as a background process from a script or notebook. | ||
|
|
||
| ```python | ||
| import subprocess | ||
| import sys | ||
| import os | ||
| import time | ||
|
|
||
| # Ensure the dataset file is in the server's execution directory (`src`) first. | ||
| # !mv /path/to/your/harmonic_reasoner_dataset_structured.jsonl ./src/ | ||
|
|
||
| port = "8009" | ||
| localhost = f"http://localhost:{port}" | ||
|
|
||
| # Start the server process from the 'src' directory of the OpenEnv project | ||
| openenv_process = subprocess.Popen( | ||
| [sys.executable, "-m", "uvicorn", "envs.dipg_safety_env.server.app:app", "--host", "0.0.0.0", "--port", port], | ||
| env={**os.environ, "PYTHONPATH": "./src"}, | ||
| cwd="./src", | ||
| ) | ||
|
|
||
| # Wait for the server to initialize | ||
| time.sleep(15) | ||
| print("Server process started.") | ||
| ``` | ||
|
|
||
| ### 2. Interacting from the Client | ||
|
|
||
| Once the server is running, an agent can interact with it using the `DIPGSafetyEnv` client. | ||
|
|
||
| ```python | ||
| from envs.dipg_safety_env.client import DIPGSafetyEnv | ||
| from envs.dipg_safety_env.models import DIPGAction | ||
|
|
||
| # Connect to the running server | ||
| env = DIPGSafetyEnv(base_url="http://localhost:8009", timeout=60) | ||
|
|
||
| # Start a new episode and get the first challenge | ||
| # The 'obs' object will contain a medical context and a question. | ||
| obs = env.reset() | ||
| print(f"Question: {obs.question}") | ||
|
|
||
| # The agent processes the observation and generates a response | ||
| agent_response_text = "Based on the provided context, the information is conflicting." | ||
|
|
||
| # Send the response (as an Action) to the environment to be scored | ||
| action = DIPGAction(llm_response=agent_response_text) | ||
| result = env.step(action) | ||
|
|
||
| # The result contains the reward and a flag indicating the episode is done | ||
| print(f"Reward: {result.reward}") | ||
| print(f"Done: {result.done}") | ||
| ``` | ||
|
|
||
| ## Running Tests | ||
|
|
||
| The environment includes a suite of unit tests to ensure its core logic is working correctly. These tests verify that the environment can be reset, that actions are processed, and that the internal state is managed properly. | ||
|
|
||
| ### Prerequisites | ||
|
|
||
| You must have `pytest` installed: | ||
| ```bash | ||
| pip install pytest | ||
| ``` | ||
|
|
||
| ### How to Run | ||
|
|
||
| From the **root directory** of the `OpenEnv` project, run the following command: | ||
|
|
||
| ```bash | ||
| pytest src/envs/dipg_safety_env/server/test_dipg_safety_env.py | ||
| ``` | ||
|
|
||
| A successful run will show an output indicating that all tests passed, for example: | ||
| ``` | ||
| ============================= test session starts ============================== | ||
| ... | ||
| collected 4 items | ||
|
|
||
| src/envs/dipg_safety_env/server/test_dipg_safety_env.py .... [100%] | ||
|
|
||
| ============================== 4 passed in 0.50s =============================== | ||
| ``` | ||
|
|
||
| ## Core Components | ||
|
|
||
| * **`models.py`**: Defines the data structures for interaction: | ||
| * `DIPGObservation`: Contains the `context` and `question` served to the agent. | ||
| * `DIPGAction`: Contains the `llm_response` generated by the agent. | ||
| * **`server/dipg_environment.py`**: The core of the environment. It loads the dataset, serves challenges via `reset()`, and calculates rewards via `step()`. | ||
| * **`client.py`**: The "remote control" that allows a Python script to communicate with the server over HTTP, handling all the JSON serialization and parsing. | ||
| * **`server/test_dipg_safety_env.py`**: Unit tests for verifying the environment's functionality. |
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,112 @@ | ||
| # src/envs/dipg_safety_env/client.py | ||
| """ | ||
| Client implementation for the custom DIPGSafetyEnv. | ||
|
|
||
| This file defines the `DIPGSafetyEnv` class, which acts as the "remote control" | ||
| for the environment server. Its primary job is to handle the HTTP communication: | ||
| 1. It takes Python objects (like an Action) from the agent's code. | ||
| 2. It converts them into JSON to send to the server. | ||
| 3. It receives JSON responses from the server. | ||
| 4. It parses that JSON back into useful Python objects (like Observations and Rewards). | ||
| """ | ||
|
|
||
| from core.http_env_client import HTTPEnvClient, StepResult | ||
| from .models import DIPGAction, DIPGObservation, DIPGState | ||
|
|
||
|
|
||
| class DIPGSafetyEnv(HTTPEnvClient[DIPGAction, DIPGObservation]): | ||
| """ | ||
| Client for interacting with the `DIPGSafetyEnv` server. | ||
|
|
||
| This class inherits from the base `HTTPEnvClient` and is specialized to handle | ||
| the specific data types of our environment: `DIPGAction` and `DIPGObservation`. | ||
| """ | ||
|
|
||
| def __init__(self, base_url: str, timeout: float = 60.0): | ||
| """ | ||
| Initializes the client. | ||
|
|
||
| Args: | ||
| base_url: The URL of the running environment server. | ||
| timeout: The number of seconds to wait for a server response. | ||
| """ | ||
| # This correctly calls the parent initializer with the expected | ||
| # 'request_timeout_s' keyword argument. | ||
| super().__init__(base_url=base_url, request_timeout_s=timeout) | ||
| # ---------------------------------------- | ||
|
|
||
| def _step_payload(self, action: DIPGAction) -> dict: | ||
| """ | ||
| Formats the `DIPGAction` object into a JSON-serializable dictionary. | ||
|
|
||
| This dictionary becomes the body of the HTTP POST request sent to the | ||
| server's `/step` endpoint. | ||
|
|
||
| Args: | ||
| action: The `DIPGAction` object containing the model's response. | ||
|
|
||
| Returns: | ||
| A dictionary to be sent as the JSON request body. | ||
| """ | ||
| return {"llm_response": action.llm_response} | ||
|
|
||
| def _parse_result(self, payload: dict) -> StepResult[DIPGObservation]: | ||
| """ | ||
| Parses the JSON payload from the server into a `StepResult`, | ||
| robustly handling inconsistencies and potential missing data. | ||
|
|
||
| This method is designed to be crash-proof and handles three key scenarios: | ||
| 1. The single-nested 'observation' dictionary from the `/reset` endpoint. | ||
| 2. The double-nested 'observation' dictionary from the `/step` endpoint. | ||
| 3. A payload where the 'observation' key might be missing entirely. | ||
|
|
||
| Args: | ||
| payload: The raw dictionary parsed from the server's JSON response. | ||
|
|
||
| Returns: | ||
| A structured `StepResult` object. | ||
| """ | ||
| # Safely get the top-level 'observation' object. It could be a dict or None. | ||
| obs_data = payload.get("observation") | ||
|
|
||
| # Check if the object is a dictionary and contains the nested 'observation' key. | ||
| # This identifies the double-nested structure from the /step endpoint. | ||
| if isinstance(obs_data, dict) and "observation" in obs_data: | ||
| # If so, go one level deeper to get the actual data payload. | ||
| actual_obs_data = obs_data.get("observation") | ||
| else: | ||
| # Otherwise, it's either the single-nested structure from /reset or None. | ||
| actual_obs_data = obs_data if isinstance(obs_data, dict) else {} | ||
|
|
||
| # To prevent crashes, ensure `actual_obs_data` is a dictionary before | ||
| # we try to access keys from it. If it was None, it becomes an empty dict. | ||
| if not isinstance(actual_obs_data, dict): | ||
| actual_obs_data = {} | ||
|
|
||
| # Construct the DIPGObservation object safely. | ||
| # Using .get() with a default value ("") prevents a KeyError if 'context' or | ||
| # 'question' are missing from the payload, ensuring the client never crashes. | ||
| obs = DIPGObservation( | ||
| context=actual_obs_data.get("context", ""), | ||
| question=actual_obs_data.get("question", ""), | ||
| ) | ||
|
|
||
| # Assemble and return the final, structured StepResult. | ||
| return StepResult( | ||
| observation=obs, | ||
| reward=payload.get("reward"), | ||
| done=payload.get("done", False), | ||
| ) | ||
|
|
||
|
|
||
| def _parse_state(self, payload: dict) -> DIPGState: | ||
| """ | ||
| Parses the JSON payload from the server's `/state` endpoint into a `DIPGState` object. | ||
|
|
||
| Args: | ||
| payload: The raw dictionary parsed from the server's JSON response. | ||
|
|
||
| Returns: | ||
| A structured `DIPGState` object. | ||
| """ | ||
| return DIPGState(**payload) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| # src/envs/dipg_safety_env/models.py | ||
|
|
||
| from dataclasses import dataclass, field | ||
| from core.env_server import Action, Observation, State | ||
|
|
||
| @dataclass | ||
| class DIPGAction(Action): | ||
| """The action taken by the agent, which is its generated response.""" | ||
| llm_response: str | ||
|
|
||
| @dataclass | ||
| class DIPGObservation(Observation): | ||
| """The observation given to the agent: a context and a question.""" | ||
| context: str | ||
| question: str | ||
|
|
||
| @dataclass | ||
| class DIPGState(State): | ||
| """The internal state of the environment for tracking the current challenge.""" | ||
| current_context: str = "" | ||
| current_question: str = "" | ||
| # This will hold the ground-truth 'analysis' and 'final' answer | ||
| # for scoring purposes. | ||
| expected_answer: dict = field(default_factory=dict) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| # src/envs/dipg_safety_env/server/Dockerfile | ||
| ARG BASE_IMAGE=openenv-base:latest | ||
| FROM ${BASE_IMAGE} | ||
|
|
||
| # Copy environment code | ||
| COPY src/core/ /app/src/core/ | ||
| COPY src/envs/dipg_safety_env/ /app/src/envs/dipg_safety_env/ | ||
|
|
||
| # ===> ADD THIS LINE <=== | ||
| # Copy your dataset so the environment can read it. | ||
| # Make sure the path is correct relative to the Docker build context. | ||
| COPY harmonic_reasoner_dataset_structured.jsonl /app/harmonic_reasoner_dataset_structured.jsonl | ||
|
|
||
| # Health check & CMD (same as example) | ||
| HEALTHCHECK CMD curl -f http://localhost:8000/health || exit 1 | ||
| CMD ["uvicorn", "envs.dipg_safety_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"] |
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| from core.env_server import create_fastapi_app | ||
| from ..models import DIPGAction, DIPGObservation | ||
| from .dipg_environment import DIPGEnvironment | ||
|
|
||
| env = DIPGEnvironment() | ||
| app = create_fastapi_app(env, DIPGAction, DIPGObservation) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please move this file to be under the envs/<YOUR_ENV>/server
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alright, I gotta be a party pooper here but one area we need to enforce is to actually NOT host datasets in the repo. Doing so requires that we check licenses and whatnot which is frankly too much work for us to sign up for.
However, if the dataset is hosted somewhere else and we simply have a script here that downloads it, things are different since we are not redistributing it.
All of this to say to replace this with something like HF Datasets downloading this file. HF Datasets also has a well-established process to deal with licenses of datasets, so that we can be sure that we can indeed redistribute it.
Unfortunately we will not be able to merge this PR until this happens...