Skip to content

Commit

Permalink
Merge pull request #4682 from broadinstitute/dag-triggger-ui
Browse files Browse the repository at this point in the history
Dag trigger UI
  • Loading branch information
hanars authored Mar 3, 2025
2 parents f7a6130 + 907240d commit b26bdb4
Show file tree
Hide file tree
Showing 12 changed files with 282 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from seqr.models import Project, Family, VariantTag, VariantTagType, Sample
from seqr.utils.search.utils import backend_specific_call
from seqr.views.utils.airflow_utils import trigger_airflow_delete_families, DagRunningException
from seqr.views.utils.airflow_utils import trigger_airflow_dag, DELETE_FAMILIES_DAG_NAME, DagRunningException

import logging
logger = logging.getLogger(__name__)
Expand All @@ -29,7 +29,7 @@ def _trigger_delete_families_dags(from_project, updated_family_dataset_types):

for dataset_type, family_guids in sorted(updated_families_by_dataset_type.items()):
try:
trigger_airflow_delete_families(dataset_type, family_guids, from_project)
trigger_airflow_dag(DELETE_FAMILIES_DAG_NAME, from_project, dataset_type, family_guids=sorted(family_guids))
logger.info(f'Successfully triggered DELETE_FAMILIES DAG for {len(family_guids)} {dataset_type} families')
except Exception as e:
logger_call = logger.warning if isinstance(e, DagRunningException) else logger.error
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def assert_airflow_delete_families_calls(self):
call_count_per_dag = 5
for i, dataset_type in enumerate(['MITO', 'SNV_INDEL', 'SV']):
offset = i * call_count_per_dag
self._assert_airflow_calls(self._get_dag_variables(dataset_type), call_count_per_dag, offset)
self.assert_airflow_calls(self._get_dag_variables(dataset_type), call_count_per_dag, offset)

def _assert_update_check_airflow_calls(self, call_count, offset, update_check_path):
variables_update_check_path = f'{self.MOCK_AIRFLOW_URL}/api/v1/variables/{self.DAG_NAME}'
Expand Down
3 changes: 2 additions & 1 deletion seqr/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@

from seqr.views.apis.data_manager_api import elasticsearch_status, upload_qc_pipeline_output, delete_index, \
update_rna_seq, load_rna_seq_sample_data, proxy_to_kibana, load_phenotype_prioritization_data, \
validate_callset, get_loaded_projects, load_data, loading_vcfs, proxy_to_luigi
validate_callset, get_loaded_projects, load_data, loading_vcfs, trigger_dag, proxy_to_luigi
from seqr.views.apis.report_api import \
anvil_export, \
family_metadata, \
Expand Down Expand Up @@ -338,6 +338,7 @@
'data_management/loaded_projects/(?P<genome_version>[^/]+)/(?P<sample_type>[^/]+)/(?P<dataset_type>[^/]+)': get_loaded_projects,
'data_management/load_data': load_data,
'data_management/add_igv': receive_bulk_igv_table_handler,
'data_management/trigger_dag/(?P<dag_id>[^/]+)': trigger_dag,

'summary_data/saved_variants/(?P<tag>[^/]+)': saved_variants_page,
'summary_data/hpo/(?P<hpo_id>[^/]+)': hpo_summary_data,
Expand Down
2 changes: 1 addition & 1 deletion seqr/utils/search/add_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def format_loading_pipeline_variables(
projects: list[Project], genome_version: str, dataset_type: str, sample_type: str = None, **kwargs
):
variables = {
'projects_to_run': sorted([p.guid for p in projects]),
'projects_to_run': sorted([p.guid for p in projects]) if projects else None,
'dataset_type': _dag_dataset_type(sample_type, dataset_type),
'reference_genome': GENOME_VERSION_LOOKUP[genome_version],
**kwargs
Expand Down
25 changes: 23 additions & 2 deletions seqr/views/apis/data_manager_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
from seqr.utils.middleware import ErrorsWarningsException
from seqr.utils.vcf_utils import validate_vcf_and_get_samples, get_vcf_list

from seqr.views.utils.airflow_utils import trigger_airflow_data_loading
from seqr.views.utils.airflow_utils import trigger_airflow_data_loading, trigger_airflow_dag, is_airflow_enabled
from seqr.views.utils.airtable_utils import AirtableSession, LOADABLE_PDO_STATUSES, AVAILABLE_PDO_STATUS
from seqr.views.utils.dataset_utils import load_rna_seq, load_phenotype_prioritization_data_file, RNA_DATA_TYPE_CONFIGS, \
post_process_rna_data, convert_django_meta_to_http_headers
from seqr.views.utils.file_utils import parse_file, get_temp_file_path, load_uploaded_file, persist_temp_file
from seqr.views.utils.json_utils import create_json_response
from seqr.views.utils.json_utils import create_json_response, _to_snake_case
from seqr.views.utils.json_to_orm_utils import update_model_from_json
from seqr.views.utils.pedigree_info_utils import get_validated_related_individuals, JsonConstants
from seqr.views.utils.permissions_utils import data_manager_required, pm_or_data_manager_required, get_internal_projects
Expand Down Expand Up @@ -618,6 +618,27 @@ def _get_valid_search_individuals(project, airtable_samples, vcf_samples, datase
return [i['id'] for i in search_individuals_by_id.values()] + loaded_individual_ids


@data_manager_required
def trigger_dag(request, dag_id):
    """Manually trigger the named Airflow DAG from the data management UI.

    The request body may specify a 'project' or 'family' guid to scope the DAG,
    plus arbitrary camelCased variables that are forwarded (snake_cased) to the DAG.
    Returns a 400 response with the error message if triggering fails.
    """
    # Only deployments with a configured Airflow webserver may trigger DAGs
    if not is_airflow_enabled():
        raise PermissionDenied()

    body = json.loads(request.body)
    project_guid = body.pop('project', None)
    family_guid = body.pop('family', None)
    # Remaining request fields become DAG variables, converted to snake_case
    dag_kwargs = {_to_snake_case(key): value for key, value in body.items()}

    if project_guid:
        project = Project.objects.get(guid=project_guid)
    elif family_guid:
        # A family-scoped DAG still runs against the family's parent project
        project = Project.objects.get(family__guid=family_guid)
        dag_kwargs['family_guids'] = [family_guid]
    else:
        project = None

    try:
        dag_variables = trigger_airflow_dag(dag_id, project, **dag_kwargs)
    except Exception as e:
        # Surface trigger failures (e.g. DAG already running) to the UI
        return create_json_response({'error': str(e)}, status=400)
    return create_json_response({'info': [f'Triggered DAG {dag_id} with variables: {json.dumps(dag_variables)}']})


# Hop-by-hop HTTP response headers shouldn't be forwarded.
# More info at: http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.5.1
EXCLUDE_HTTP_RESPONSE_HEADERS = {
Expand Down
91 changes: 90 additions & 1 deletion seqr/views/apis/data_manager_api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from seqr.utils.communication_utils import _set_bulk_notification_stream
from seqr.views.apis.data_manager_api import elasticsearch_status, upload_qc_pipeline_output, delete_index, \
update_rna_seq, load_rna_seq_sample_data, load_phenotype_prioritization_data, validate_callset, loading_vcfs, \
get_loaded_projects, load_data
get_loaded_projects, trigger_dag, load_data
from seqr.views.utils.orm_to_json_utils import _get_json_for_models
from seqr.views.utils.test_utils import AuthenticationTestCase, AirflowTestCase, AirtableTest
from seqr.utils.search.elasticsearch.es_utils_tests import urllib3_responses
Expand Down Expand Up @@ -1659,6 +1659,54 @@ def _test_no_affected_family(self, url, body):
})
Individual.objects.filter(guid='I000009_na20874').update(affected='A')

@responses.activate
def test_trigger_dag(self):
self.check_data_manager_login(reverse(trigger_dag, args=['some_dag']))

self._test_trigger_single_dag(
'DELETE_PROJECTS',
{'project': PROJECT_GUID, 'datasetType': 'SNV_INDEL'},
{
'projects_to_run': [PROJECT_GUID],
'dataset_type': 'SNV_INDEL',
'reference_genome': 'GRCh37',
}
)
self._test_trigger_single_dag(
'DELETE_FAMILIES',
{'family': 'F000012_12', 'datasetType': 'MITO'},
{
'projects_to_run': ['R0003_test'],
'dataset_type': 'MITO',
'reference_genome': 'GRCh37',
'family_guids': ['F000012_12'],
}
)

body = {'genomeVersion': '38', 'datasetType': 'SV'}
url = self._test_trigger_single_dag('UPDATE_REFERENCE_DATASETS',body,{
'projects_to_run': None,
'dataset_type': 'SV',
'reference_genome': 'GRCh38',
})

self._test_dag_trigger_errors(url, body)

def _test_trigger_single_dag(self, dag_id, body, dag_variables):
responses.calls.reset()
self.set_up_one_dag(dag_id, variables=dag_variables)

url = reverse(trigger_dag, args=[dag_id])
response = self.client.post(url, content_type='application/json', data=json.dumps(body))
self._assert_expected_dag_trigger(response, dag_id, dag_variables)
return url

    def _assert_expected_dag_trigger(self, response, dag_id, variables):
        # Base expectation: Airflow is not enabled in this test case, so the
        # endpoint rejects the request with 403 (PermissionDenied). Subclasses
        # with Airflow enabled override this to assert a successful trigger.
        self.assertEqual(response.status_code, 403)

    def _test_dag_trigger_errors(self, url, body):
        # No-op by default: trigger error responses only occur when Airflow is
        # enabled; overridden in the Airflow-backed test case.
        pass


class LocalDataManagerAPITest(AuthenticationTestCase, DataManagerAPITest):
fixtures = ['users', '1kg_project', 'reference_data']
Expand Down Expand Up @@ -1781,6 +1829,9 @@ def _assert_write_pedigree_error(self, response):
def _test_validate_dataset_type(self, url):
pass

    def set_up_one_dag(self, *args, **kwargs):
        # No-op: the local (non-Airflow) test case has no DAG API responses to mock.
        pass


@mock.patch('seqr.views.utils.permissions_utils.PM_USER_GROUP', 'project-managers')
class AnvilDataManagerAPITest(AirflowTestCase, DataManagerAPITest):
Expand Down Expand Up @@ -2038,3 +2089,41 @@ def _test_validate_dataset_type(self, url):
self.assertListEqual(response.json()['errors'], [f'Data file or path {self.CALLSET_DIR}/mito_callset.mt is not found.'])
self._set_file_not_found()

    def _add_update_check_dag_responses(self, variables=None, **kwargs):
        """Register mocked Airflow variable-fetch responses for the update-check polling.

        Without explicit ``variables``, defers to the base class behavior.
        Response registration order matters: the first GET returns stale (empty)
        variables so the code under test must poll again.
        """
        if not variables:
            super()._add_update_check_dag_responses(**kwargs)
            return

        # First poll: stale variables, forcing a re-check
        responses.add(responses.GET, f'{self.MOCK_AIRFLOW_URL}/api/v1/variables/{self.DAG_NAME}', json={
            'key': self.DAG_NAME,
            'value': '{}'
        })
        # Second poll: the updated variables, ending the wait
        responses.add(responses.GET, f'{self.MOCK_AIRFLOW_URL}/api/v1/variables/{self.DAG_NAME}', json={
            'key': self.DAG_NAME,
            'value': json.dumps(variables)
        })

    def _assert_update_check_airflow_calls(self, call_count, offset, update_check_path):
        # Non-loading DAGs store their variables under the DAG's own name, so the
        # expected polled URL differs from the caller-supplied LOADING_PIPELINE path.
        if self.DAG_NAME != 'LOADING_PIPELINE':
            update_check_path = f'{self.MOCK_AIRFLOW_URL}/api/v1/variables/{self.DAG_NAME}'
        super()._assert_update_check_airflow_calls(call_count, offset, update_check_path)

    def set_up_one_dag(self, dag_id=None, **kwargs):
        # When testing a DAG other than the default, rewrite the mocked DAG URL and
        # DAG_NAME in place before delegating to the base setup. NOTE(review): this
        # mutates instance state, so subsequent calls replace the previous dag_id.
        if dag_id:
            self._dag_url = self._dag_url.replace(self.DAG_NAME, dag_id)
            self.DAG_NAME = dag_id
        super().set_up_one_dag(**kwargs)

    def _assert_expected_dag_trigger(self, response, dag_id, variables):
        # Airflow is enabled in this test case, so the trigger succeeds and echoes
        # the DAG variables back in the response
        self.assertEqual(response.status_code, 200)
        self.assertDictEqual(response.json(), {'info': [f'Triggered DAG {dag_id} with variables: {json.dumps(variables)}']})

        # 5 Airflow API calls are expected per trigger — presumably running-state
        # check, variable update, two variable polls, and the trigger itself; confirm
        # against the base assert_airflow_calls implementation
        self.assert_airflow_calls(variables, 5)

    def _test_dag_trigger_errors(self, url, body):
        # An already-running DAG cannot be re-triggered: the endpoint surfaces the
        # DagRunningException message as a 400 error response
        self.set_dag_trigger_error_response()
        response = self.client.post(url, content_type='application/json', data=json.dumps(body))
        self.assertEqual(response.status_code, 400)
        self.assertDictEqual(response.json(), {'error': 'UPDATE_REFERENCE_DATASETS DAG is running and cannot be triggered again.'})
23 changes: 13 additions & 10 deletions seqr/views/utils/airflow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ class DagRunningException(Exception):
pass


def is_airflow_enabled():
    """Return True if an Airflow webserver URL is configured, i.e. DAGs can be triggered from this deployment."""
    return bool(AIRFLOW_WEBSERVER_URL)


def trigger_airflow_data_loading(*args, user: User, success_message: str, success_slack_channel: str,
error_message: str, is_internal: bool = False, **kwargs):
success = True
Expand All @@ -47,19 +51,18 @@ def trigger_airflow_data_loading(*args, user: User, success_message: str, succes
return success


def trigger_airflow_delete_families(
dataset_type: str, family_guids: list[str], from_project: Project,
):
def trigger_airflow_dag(dag_id: str, project: Project, dataset_type: str, genome_version: str = None, **kwargs):
    """Update the variables for the given Airflow DAG and trigger a run.

    When a project is given, its genome version is used and it becomes the sole
    entry in projects_to_run; otherwise the explicit genome_version applies.
    Returns the variables the DAG was triggered with.
    Raises DagRunningException if the DAG is already running.
    """
    if project:
        projects = [project]
        reference_genome = project.genome_version
    else:
        projects = []
        reference_genome = genome_version
    variables = format_loading_pipeline_variables(projects, reference_genome, dataset_type, **kwargs)

    # The DAG must not already be running, and the variable update must be
    # confirmed visible in Airflow before the run is kicked off
    _check_dag_running_state(dag_id)
    _update_variables(variables, dag_id)
    _wait_for_dag_variable_update(variables, dag_id)
    _trigger_dag(dag_id)
    return variables


def _send_load_data_slack_msg(messages: list[str], channel: str, dag: dict):
Expand Down
4 changes: 2 additions & 2 deletions seqr/views/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,14 +673,14 @@ def assert_airflow_loading_calls(self, trigger_error=False, additional_tasks_che
if dag_variable_overrides.get('skip_validation'):
dag_variables['skip_validation'] = True
dag_variables['sample_source'] = dag_variable_overrides['sample_source']
self._assert_airflow_calls(dag_variables, call_count, offset=offset)
self.assert_airflow_calls(dag_variables, call_count, offset=offset)

def _assert_call_counts(self, call_count):
self.mock_airflow_logger.info.assert_not_called()
self.assertEqual(len(responses.calls), call_count + self.ADDITIONAL_REQUEST_COUNT)
self.assertEqual(self.mock_authorized_session.call_count, call_count)

def _assert_airflow_calls(self, dag_variables, call_count, offset=0):
def assert_airflow_calls(self, dag_variables, call_count, offset=0):
self._assert_dag_running_state_calls(offset)

if call_count < 2:
Expand Down
19 changes: 14 additions & 5 deletions ui/pages/DataManagement/DataManagement.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import RnaSeq from './components/RnaSeq'
import SampleQc from './components/SampleQc'
import Users from './components/Users'
import PhenotypePrioritization from './components/PhenotypePrioritization'
import TRIGGER_DAG_PAGES from './components/TriggerDagPages'

const IFRAME_STYLE = { position: 'fixed', left: '0', top: '95px' }

Expand Down Expand Up @@ -45,16 +46,24 @@ const ES_DATA_MANAGEMENT_PAGES = [
...DATA_MANAGEMENT_PAGES,
]

const HAIL_SEARCH_DATA_MANAGEMENT_PAGES = [
const LOCAL_HAIL_SEARCH_DATA_MANAGEMENT_PAGES = [
...DATA_MANAGEMENT_PAGES,
{ path: 'pipeline_status', component: () => <IframePage title="Loading UI" src="/luigi_ui/static/visualiser/index.html" /> },
]

const dataManagementPages = (isDataManager, elasticsearchEnabled) => {
if (!isDataManager) {
const AIRFLOW_HAIL_SEARCH_DATA_MANAGEMENT_PAGES = [
...DATA_MANAGEMENT_PAGES,
...TRIGGER_DAG_PAGES,
]

const dataManagementPages = (user, elasticsearchEnabled) => {
if (!user.isDataManager) {
return PM_DATA_MANAGEMENT_PAGES
}
return elasticsearchEnabled ? ES_DATA_MANAGEMENT_PAGES : HAIL_SEARCH_DATA_MANAGEMENT_PAGES
if (elasticsearchEnabled) {
return ES_DATA_MANAGEMENT_PAGES
}
return user.isAnvil ? AIRFLOW_HAIL_SEARCH_DATA_MANAGEMENT_PAGES : LOCAL_HAIL_SEARCH_DATA_MANAGEMENT_PAGES
}

const DataManagement = ({ match, user, pages }) => (
Expand All @@ -78,7 +87,7 @@ export const mapStateToProps = (state) => {
const user = getUser(state)
return {
user,
pages: dataManagementPages(user.isDataManager, getElasticsearchEnabled(state)),
pages: dataManagementPages(user, getElasticsearchEnabled(state)),
}
}

Expand Down
82 changes: 82 additions & 0 deletions ui/pages/DataManagement/components/TriggerDagPages.jsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import React from 'react'
import PropTypes from 'prop-types'

import { validators } from 'shared/components/form/FormHelpers'
import { Select } from 'shared/components/form/Inputs'
import { AwesomeBarFormInput } from 'shared/components/page/AwesomeBar'
import SubmitFormPage from 'shared/components/page/SubmitFormPage'
import {
DATASET_TYPE_SNV_INDEL_CALLS,
DATASET_TYPE_SV_CALLS,
DATASET_TYPE_MITO_CALLS,
GENOME_VERSION_FIELD,
} from 'shared/utils/constants'

// Shared required dropdown for selecting which dataset type the DAG operates on
const DATASET_TYPE_FIELD = {
  name: 'datasetType',
  label: 'Dataset Type',
  component: Select,
  options: [
    DATASET_TYPE_SNV_INDEL_CALLS, DATASET_TYPE_MITO_CALLS, DATASET_TYPE_SV_CALLS,
  ].map(value => ({ value, name: value })),
  validate: validators.required,
}
// Builds a required AwesomeBar search field for a seqr entity. The two entity
// field literals previously duplicated this shape; the factory keeps them in sync.
const awesomeBarField = (name, label, category) => ({
  name,
  label,
  control: AwesomeBarFormInput,
  categories: [category],
  fluid: true,
  placeholder: `Search for a ${name}`,
  validate: validators.required,
})

// Fields for DAGs scoped to a single project (e.g. DELETE_PROJECTS)
const PROJECT_FIELDS = [
  awesomeBarField('project', 'Project', 'projects'),
  DATASET_TYPE_FIELD,
]
// Fields for DAGs scoped to a single family (e.g. DELETE_FAMILIES)
const FAMILY_FIELDS = [
  awesomeBarField('family', 'Family', 'families'),
  DATASET_TYPE_FIELD,
]
// Fields for DAGs that run without a project/family scope (e.g. UPDATE_REFERENCE_DATASETS):
// the genome version must be supplied explicitly since no project provides it
const REFERENCE_DATASET_FIELDS = [
  { ...GENOME_VERSION_FIELD, validate: validators.required },
  DATASET_TYPE_FIELD,
]

// Generic form page that POSTs the given fields to the backend trigger_dag endpoint
// for the named DAG
const TriggerDagForm = ({ dagName, fields }) => (
  <SubmitFormPage
    header={`Trigger ${dagName} DAG`}
    url={`/api/data_management/trigger_dag/${dagName}`}
    fields={fields}
  />
)

TriggerDagForm.propTypes = {
  dagName: PropTypes.string,
  fields: PropTypes.arrayOf(PropTypes.object),
}

// Per-DAG wrapper components binding each DAG name to its form fields

const TriggerDeleteProjectsDag = () => (
  <TriggerDagForm dagName="DELETE_PROJECTS" fields={PROJECT_FIELDS} />
)

const TriggerDeleteFamiliesDag = () => (
  <TriggerDagForm dagName="DELETE_FAMILIES" fields={FAMILY_FIELDS} />
)

const TriggerUpdateReferenceDatasetDag = () => (
  <TriggerDagForm dagName="UPDATE_REFERENCE_DATASETS" fields={REFERENCE_DATASET_FIELDS} />
)

// Route definitions consumed by the DataManagement page router
export default [
  { path: 'delete_search_projects', component: TriggerDeleteProjectsDag },
  { path: 'delete_search_families', component: TriggerDeleteFamiliesDag },
  { path: 'update_search_reference_data', component: TriggerUpdateReferenceDatasetDag },
]
Loading

0 comments on commit b26bdb4

Please sign in to comment.