From 1177382c0bbbd682fd63d32d8e3e36138f4221e3 Mon Sep 17 00:00:00 2001 From: Antoine Popineau <antoine.popineau@checkmarble.com> Date: Fri, 17 Jan 2025 11:35:13 +0100 Subject: [PATCH] Set up plumbing for org-level configuration. --- models/organization.go | 17 ++++++++++ repositories/dbmodels/db_organization.go | 2 ++ repositories/opensanctions_repository.go | 34 +++++++++++++++---- .../evaluate_scenario/evaluate_scenario.go | 6 ++-- usecases/sanction_check_usecase.go | 22 +++++++++--- usecases/usecases_with_creds.go | 6 ++-- 6 files changed, 71 insertions(+), 16 deletions(-) diff --git a/models/organization.go b/models/organization.go index f010f4ca9..274420f8c 100644 --- a/models/organization.go +++ b/models/organization.go @@ -20,6 +20,23 @@ type Organization struct { // to a separate DB. // TODO: clean this up when it's no longuer used. UseMarbleDbSchemaAsDefault bool + + OpenSanctionsConfig OrganizationOpenSanctionsConfig +} + +// TODO: Add other organization-level configuration options +type OrganizationOpenSanctionsConfig struct { + Datasets []string + MatchThreshold int + MatchLimit int +} + +func DefaultOrganizationOpenSanctionsConfig() OrganizationOpenSanctionsConfig { + return OrganizationOpenSanctionsConfig{ + Datasets: []string{}, + MatchThreshold: 70, + MatchLimit: 20, + } } type CreateOrganizationInput struct { diff --git a/repositories/dbmodels/db_organization.go b/repositories/dbmodels/db_organization.go index 54446e533..fa2508626 100644 --- a/repositories/dbmodels/db_organization.go +++ b/repositories/dbmodels/db_organization.go @@ -25,5 +25,7 @@ func AdaptOrganization(db DBOrganizationResult) (models.Organization, error) { TransferCheckScenarioId: db.TransferCheckScenarioId, UseMarbleDbSchemaAsDefault: db.UseMarbleDbSchemaAsDefault, DefaultScenarioTimezone: db.DefaultScenarioTimezone, + // TODO: Actually get it from the database + OpenSanctionsConfig: models.DefaultOrganizationOpenSanctionsConfig(), }, nil } diff --git a/repositories/opensanctions_repository.go b/repositories/opensanctions_repository.go index e8698da8d..46e22e944 100644 --- a/repositories/opensanctions_repository.go +++ b/repositories/opensanctions_repository.go @@ -29,10 +29,11 @@ type openSanctionsRequestQuery struct { Properties models.OpenSanctionCheckFilter `json:"properties"` } -func (repo OpenSanctionsRepository) Search(ctx context.Context, cfg models.SanctionCheckConfig, +func (repo OpenSanctionsRepository) Search(ctx context.Context, + orgCfg models.OrganizationOpenSanctionsConfig, cfg models.SanctionCheckConfig, query models.OpenSanctionsQuery, ) (models.SanctionCheckResult, error) { - req, err := repo.searchRequest(ctx, query) + req, err := repo.searchRequest(ctx, orgCfg, query) if err != nil { return models.SanctionCheckResult{}, err } @@ -61,7 +62,9 @@ func (repo OpenSanctionsRepository) Search(ctx context.Context, cfg models.Sanct return httpmodels.AdaptOpenSanctionsResult(matches) } -func (repo OpenSanctionsRepository) searchRequest(ctx context.Context, query models.OpenSanctionsQuery) (*http.Request, error) { +func (repo OpenSanctionsRepository) searchRequest(ctx context.Context, + orgCfg models.OrganizationOpenSanctionsConfig, query models.OpenSanctionsQuery, +) (*http.Request, error) { q := openSanctionsRequest{ Queries: make(map[string]openSanctionsRequestQuery, len(query.Queries)), } @@ -81,10 +84,7 @@ func (repo OpenSanctionsRepository) searchRequest(ctx context.Context, query mod requestUrl := fmt.Sprintf("%s/match/sanctions", repo.opensanctions.Host()) - if len(repo.opensanctions.ApiKey()) > 0 { - qs := url.Values{} - qs.Set("api_key", repo.opensanctions.ApiKey()) - + if qs := repo.buildQueryString(orgCfg); len(qs) > 0 { requestUrl = fmt.Sprintf("%s?%s", requestUrl, qs.Encode()) } @@ -92,3 +92,23 @@ func (repo OpenSanctionsRepository) searchRequest(ctx context.Context, query mod return req, err } + +func (repo OpenSanctionsRepository) buildQueryString(orgCfg models.OrganizationOpenSanctionsConfig) url.Values { + qs := url.Values{} + + if len(repo.opensanctions.ApiKey()) > 0 { + qs.Set("api_key", repo.opensanctions.ApiKey()) + } + + if len(orgCfg.Datasets) > 0 { + qs["include_dataset"] = orgCfg.Datasets + } + if orgCfg.MatchLimit > 0 { + qs.Set("limit", fmt.Sprintf("%d", orgCfg.MatchLimit)) + } + if orgCfg.MatchThreshold > 0 { + qs.Set("threshold", fmt.Sprintf("%.1f", float64(orgCfg.MatchThreshold)/100)) + } + + return qs +} diff --git a/usecases/evaluate_scenario/evaluate_scenario.go b/usecases/evaluate_scenario/evaluate_scenario.go index a91a4097c..8be7b5156 100644 --- a/usecases/evaluate_scenario/evaluate_scenario.go +++ b/usecases/evaluate_scenario/evaluate_scenario.go @@ -33,7 +33,8 @@ type ScenarioEvaluationParameters struct { } type EvalSanctionCheckUsecase interface { - Execute(context.Context, models.SanctionCheckConfig, models.OpenSanctionsQuery) (models.SanctionCheckResult, error) + Execute(context.Context, string, models.SanctionCheckConfig, + models.OpenSanctionsQuery) (models.SanctionCheckResult, error) } type SnoozesForDecisionReader interface { @@ -133,7 +134,8 @@ func processScenarioIteration(ctx context.Context, params ScenarioEvaluationPara "name": []string{"obama"}, }} - result, err := repositories.EvalSanctionCheckUsecase.Execute(ctx, *iteration.SanctionCheckConfig, query) + result, err := repositories.EvalSanctionCheckUsecase.Execute(ctx, + params.Scenario.OrganizationId, *iteration.SanctionCheckConfig, query) if err != nil { return models.ScenarioExecution{}, errors.Wrap(err, "could not perform sanction check") } diff --git a/usecases/sanction_check_usecase.go b/usecases/sanction_check_usecase.go index a79a4412a..037672169 100644 --- a/usecases/sanction_check_usecase.go +++ b/usecases/sanction_check_usecase.go @@ -4,10 +4,14 @@ import ( "context" "github.com/checkmarble/marble-backend/models" + "github.com/checkmarble/marble-backend/repositories" + "github.com/checkmarble/marble-backend/usecases/executor_factory" + "github.com/pkg/errors" ) type SanctionCheckProvider interface { - Search(context.Context, models.SanctionCheckConfig, models.OpenSanctionsQuery) (models.SanctionCheckResult, error) + Search(context.Context, models.OrganizationOpenSanctionsConfig, models.SanctionCheckConfig, + models.OpenSanctionsQuery) (models.SanctionCheckResult, error) } type SanctionCheckRepository interface { @@ -15,14 +19,22 @@ type SanctionCheckRepository interface { } type SanctionCheckUsecase struct { - openSanctionsProvider SanctionCheckProvider - repository SanctionCheckRepository + organizationRepository repositories.OrganizationRepository + openSanctionsProvider SanctionCheckProvider + repository SanctionCheckRepository + executorFactory executor_factory.ExecutorFactory } -func (uc SanctionCheckUsecase) Execute(ctx context.Context, cfg models.SanctionCheckConfig, +func (uc SanctionCheckUsecase) Execute(ctx context.Context, orgId string, cfg models.SanctionCheckConfig, query models.OpenSanctionsQuery, ) (models.SanctionCheckResult, error) { - matches, err := uc.openSanctionsProvider.Search(ctx, cfg, query) + org, err := uc.organizationRepository.GetOrganizationById(ctx, + uc.executorFactory.NewExecutor(), orgId) + if err != nil { + return models.SanctionCheckResult{}, errors.Wrap(err, "could not retrieve organization") + } + + matches, err := uc.openSanctionsProvider.Search(ctx, org.OpenSanctionsConfig, cfg, query) if err != nil { return models.SanctionCheckResult{}, err } diff --git a/usecases/usecases_with_creds.go b/usecases/usecases_with_creds.go index 46515f5e9..11bfc1bad 100644 --- a/usecases/usecases_with_creds.go +++ b/usecases/usecases_with_creds.go @@ -119,8 +119,10 @@ func (usecases *UsecasesWithCreds) NewDecisionUsecase() DecisionUsecase { func (usecases *UsecasesWithCreds) NewSanctionCheckUsecase() SanctionCheckUsecase { return SanctionCheckUsecase{ - openSanctionsProvider: usecases.Repositories.OpenSanctionsRepository, - repository: &usecases.Repositories.MarbleDbRepository, + organizationRepository: usecases.Repositories.OrganizationRepository, + openSanctionsProvider: usecases.Repositories.OpenSanctionsRepository, + repository: &usecases.Repositories.MarbleDbRepository, + executorFactory: usecases.NewExecutorFactory(), } }