diff --git a/.github/python-dev.yml b/.github/python-dev.yml
deleted file mode 100644
index 706c254..0000000
--- a/.github/python-dev.yml
+++ /dev/null
@@ -1,48 +0,0 @@
-name: Python Dev
-
-on:
-  push:
-    branches:
-      - main
-  pull_request:
-  workflow_dispatch:
-
-jobs:
-  build:
-    runs-on: ubuntu-latest
-
-    steps:
-    - uses: actions/checkout@v2
-    - name: Set up Python
-      uses: actions/setup-python@v2
-      with:
-        python-version: '3.8'
-
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip setuptools wheel
-        pip install -r requirements.txt
-        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
-        pip install pytest-cov
-        pip list
-    - name: Build package
-      run: |
-        python setup.py build_ext -q -j2
-        python -m pip install -e .
-    - name: Build Version
-      run: |
-        python -c "import shap_demo"
-    - name: Test with pytest
-      run: |
-        pytest -v --cov=src --cov-report=xml
-      continue-on-error: true
-    - name: Publish test results
-      uses: actions/upload-artifact@master
-      with:
-        name: Test results
-        path: coverage.xml
-      if: failure()
-    - name: Report Coverage
-      run: |
-        coverage report -m
-
diff --git a/src/offline_rl_ope/components/ImportanceSampler.py b/src/offline_rl_ope/components/ImportanceSampler.py
index ff41af0..1032018 100644
--- a/src/offline_rl_ope/components/ImportanceSampler.py
+++ b/src/offline_rl_ope/components/ImportanceSampler.py
@@ -137,7 +137,23 @@ def update(
             states=states, actions=actions, eval_policy=eval_policy)
         check_array_dim(_is_weights,2)
         check_array_dim(_weight_msk,2)
-        assert len(eval_policy.policy_actions) == len(actions)
+        len_act = len(actions)
+        len_eval_act = len(eval_policy.policy_actions)
+        _msg = f"""
+        Actions have length: {len_act}.
+        Evaluation policy predicted actions have length: {len_eval_act}
+        """
+        try:
+            assert len_act == len_eval_act, _msg
+        except Exception as e:
+            if len_eval_act == 0:
+                logger.info(
+                    """
+                    No actions associated with evaluation policy.
+                    Has collect_act been set to true in the Policy class?
+                    """
+                )
+            raise e
         self.policy_actions = eval_policy.policy_actions
         self.is_weights = _is_weights
         self.weight_msk = _weight_msk
diff --git a/src/offline_rl_ope/components/Policy.py b/src/offline_rl_ope/components/Policy.py
index 466988f..d2e2db7 100644
--- a/src/offline_rl_ope/components/Policy.py
+++ b/src/offline_rl_ope/components/Policy.py
@@ -72,7 +72,7 @@ def __call__(
         self,
         state:StateTensor,
         action:ActionTensor
-        )->ActionTensor:
+        )->Float[torch.Tensor, "traj_length 1"]:
         """Defines the probability of the given actions under the
         given states according to the policy defined by policy_func
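
Note on the Policy.py change: Float[torch.Tensor, "traj_length 1"] is a jaxtyping annotation, so the new return type documents a shape contract of (traj_length, 1) rather than just a tensor type. Below is a minimal, self-contained sketch of that contract; the stand-in policy_func, call_policy and the toy tensors are illustrative assumptions, not code from this repository.

# Sketch only: assumes torch and jaxtyping are installed; policy_func and
# call_policy are hypothetical stand-ins for the repository's Policy.__call__.
import torch
from jaxtyping import Float


def policy_func(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
    # Stand-in: pretend every action has probability 0.5 under the policy.
    return torch.full((state.shape[0], 1), 0.5)


def call_policy(
    state: torch.Tensor, action: torch.Tensor
) -> Float[torch.Tensor, "traj_length 1"]:
    # Mirrors the annotated contract: one probability per trajectory step,
    # returned as a (traj_length, 1) column tensor.
    probs = policy_func(state, action)
    assert probs.shape == (state.shape[0], 1)
    return probs


state = torch.randn(5, 3)   # traj_length = 5, state_dim = 3
action = torch.zeros(5, 1)  # one action per step
print(call_policy(state, action).shape)  # torch.Size([5, 1])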