From 1177382c0bbbd682fd63d32d8e3e36138f4221e3 Mon Sep 17 00:00:00 2001
From: Antoine Popineau <antoine.popineau@checkmarble.com>
Date: Fri, 17 Jan 2025 11:35:13 +0100
Subject: [PATCH] Set up plumbing for org-level configuration.

---
 models/organization.go                        | 17 ++++++++++
 repositories/dbmodels/db_organization.go      |  2 ++
 repositories/opensanctions_repository.go      | 34 +++++++++++++++----
 .../evaluate_scenario/evaluate_scenario.go    |  6 ++--
 usecases/sanction_check_usecase.go            | 22 +++++++++---
 usecases/usecases_with_creds.go               |  6 ++--
 6 files changed, 71 insertions(+), 16 deletions(-)

diff --git a/models/organization.go b/models/organization.go
index f010f4ca9..274420f8c 100644
--- a/models/organization.go
+++ b/models/organization.go
@@ -20,6 +20,23 @@ type Organization struct {
 	// to a separate DB.
 	// TODO: clean this up when it's no longuer used.
 	UseMarbleDbSchemaAsDefault bool
+
+	OpenSanctionsConfig OrganizationOpenSanctionsConfig
+}
+
+// TODO: Add other organization-level configuration options
+type OrganizationOpenSanctionsConfig struct {
+	Datasets       []string
+	MatchThreshold int
+	MatchLimit     int
+}
+
+func DefaultOrganizationOpenSanctionsConfig() OrganizationOpenSanctionsConfig {
+	return OrganizationOpenSanctionsConfig{
+		Datasets:       []string{},
+		MatchThreshold: 70,
+		MatchLimit:     20,
+	}
 }
 
 type CreateOrganizationInput struct {
diff --git a/repositories/dbmodels/db_organization.go b/repositories/dbmodels/db_organization.go
index 54446e533..fa2508626 100644
--- a/repositories/dbmodels/db_organization.go
+++ b/repositories/dbmodels/db_organization.go
@@ -25,5 +25,7 @@ func AdaptOrganization(db DBOrganizationResult) (models.Organization, error) {
 		TransferCheckScenarioId:    db.TransferCheckScenarioId,
 		UseMarbleDbSchemaAsDefault: db.UseMarbleDbSchemaAsDefault,
 		DefaultScenarioTimezone:    db.DefaultScenarioTimezone,
+		// TODO: Actually get it from the database
+		OpenSanctionsConfig: models.DefaultOrganizationOpenSanctionsConfig(),
 	}, nil
 }
diff --git a/repositories/opensanctions_repository.go b/repositories/opensanctions_repository.go
index e8698da8d..46e22e944 100644
--- a/repositories/opensanctions_repository.go
+++ b/repositories/opensanctions_repository.go
@@ -29,10 +29,11 @@ type openSanctionsRequestQuery struct {
 	Properties models.OpenSanctionCheckFilter `json:"properties"`
 }
 
-func (repo OpenSanctionsRepository) Search(ctx context.Context, cfg models.SanctionCheckConfig,
+func (repo OpenSanctionsRepository) Search(ctx context.Context,
+	orgCfg models.OrganizationOpenSanctionsConfig, cfg models.SanctionCheckConfig,
 	query models.OpenSanctionsQuery,
 ) (models.SanctionCheckResult, error) {
-	req, err := repo.searchRequest(ctx, query)
+	req, err := repo.searchRequest(ctx, orgCfg, query)
 	if err != nil {
 		return models.SanctionCheckResult{}, err
 	}
@@ -61,7 +62,9 @@ func (repo OpenSanctionsRepository) Search(ctx context.Context, cfg models.Sanct
 	return httpmodels.AdaptOpenSanctionsResult(matches)
 }
 
-func (repo OpenSanctionsRepository) searchRequest(ctx context.Context, query models.OpenSanctionsQuery) (*http.Request, error) {
+func (repo OpenSanctionsRepository) searchRequest(ctx context.Context,
+	orgCfg models.OrganizationOpenSanctionsConfig, query models.OpenSanctionsQuery,
+) (*http.Request, error) {
 	q := openSanctionsRequest{
 		Queries: make(map[string]openSanctionsRequestQuery, len(query.Queries)),
 	}
@@ -81,10 +84,7 @@ func (repo OpenSanctionsRepository) searchRequest(ctx context.Context, query mod
 
 	requestUrl := fmt.Sprintf("%s/match/sanctions", repo.opensanctions.Host())
 
-	if len(repo.opensanctions.ApiKey()) > 0 {
-		qs := url.Values{}
-		qs.Set("api_key", repo.opensanctions.ApiKey())
-
+	if qs := repo.buildQueryString(orgCfg); len(qs) > 0 {
 		requestUrl = fmt.Sprintf("%s?%s", requestUrl, qs.Encode())
 	}
 
@@ -92,3 +92,23 @@ func (repo OpenSanctionsRepository) searchRequest(ctx context.Context, query mod
 
 	return req, err
 }
+
+func (repo OpenSanctionsRepository) buildQueryString(orgCfg models.OrganizationOpenSanctionsConfig) url.Values {
+	qs := url.Values{}
+
+	if len(repo.opensanctions.ApiKey()) > 0 {
+		qs.Set("api_key", repo.opensanctions.ApiKey())
+	}
+
+	if len(orgCfg.Datasets) > 0 {
+		qs["include_dataset"] = orgCfg.Datasets
+	}
+	if orgCfg.MatchLimit > 0 {
+		qs.Set("limit", fmt.Sprintf("%d", orgCfg.MatchLimit))
+	}
+	if orgCfg.MatchThreshold > 0 {
+		qs.Set("threshold", fmt.Sprintf("%.1f", float64(orgCfg.MatchThreshold)/100))
+	}
+
+	return qs
+}
diff --git a/usecases/evaluate_scenario/evaluate_scenario.go b/usecases/evaluate_scenario/evaluate_scenario.go
index a91a4097c..8be7b5156 100644
--- a/usecases/evaluate_scenario/evaluate_scenario.go
+++ b/usecases/evaluate_scenario/evaluate_scenario.go
@@ -33,7 +33,8 @@ type ScenarioEvaluationParameters struct {
 }
 
 type EvalSanctionCheckUsecase interface {
-	Execute(context.Context, models.SanctionCheckConfig, models.OpenSanctionsQuery) (models.SanctionCheckResult, error)
+	Execute(context.Context, string, models.SanctionCheckConfig,
+		models.OpenSanctionsQuery) (models.SanctionCheckResult, error)
 }
 
 type SnoozesForDecisionReader interface {
@@ -133,7 +134,8 @@ func processScenarioIteration(ctx context.Context, params ScenarioEvaluationPara
 			"name": []string{"obama"},
 		}}
 
-		result, err := repositories.EvalSanctionCheckUsecase.Execute(ctx, *iteration.SanctionCheckConfig, query)
+		result, err := repositories.EvalSanctionCheckUsecase.Execute(ctx,
+			params.Scenario.OrganizationId, *iteration.SanctionCheckConfig, query)
 		if err != nil {
 			return models.ScenarioExecution{}, errors.Wrap(err, "could not perform sanction check")
 		}
diff --git a/usecases/sanction_check_usecase.go b/usecases/sanction_check_usecase.go
index a79a4412a..037672169 100644
--- a/usecases/sanction_check_usecase.go
+++ b/usecases/sanction_check_usecase.go
@@ -4,10 +4,14 @@ import (
 	"context"
 
 	"github.com/checkmarble/marble-backend/models"
+	"github.com/checkmarble/marble-backend/repositories"
+	"github.com/checkmarble/marble-backend/usecases/executor_factory"
+	"github.com/pkg/errors"
 )
 
 type SanctionCheckProvider interface {
-	Search(context.Context, models.SanctionCheckConfig, models.OpenSanctionsQuery) (models.SanctionCheckResult, error)
+	Search(context.Context, models.OrganizationOpenSanctionsConfig, models.SanctionCheckConfig,
+		models.OpenSanctionsQuery) (models.SanctionCheckResult, error)
 }
 
 type SanctionCheckRepository interface {
@@ -15,14 +19,22 @@ type SanctionCheckRepository interface {
 }
 
 type SanctionCheckUsecase struct {
-	openSanctionsProvider SanctionCheckProvider
-	repository            SanctionCheckRepository
+	organizationRepository repositories.OrganizationRepository
+	openSanctionsProvider  SanctionCheckProvider
+	repository             SanctionCheckRepository
+	executorFactory        executor_factory.ExecutorFactory
 }
 
-func (uc SanctionCheckUsecase) Execute(ctx context.Context, cfg models.SanctionCheckConfig,
+func (uc SanctionCheckUsecase) Execute(ctx context.Context, orgId string, cfg models.SanctionCheckConfig,
 	query models.OpenSanctionsQuery,
 ) (models.SanctionCheckResult, error) {
-	matches, err := uc.openSanctionsProvider.Search(ctx, cfg, query)
+	org, err := uc.organizationRepository.GetOrganizationById(ctx,
+		uc.executorFactory.NewExecutor(), orgId)
+	if err != nil {
+		return models.SanctionCheckResult{}, errors.Wrap(err, "could not retrieve organization")
+	}
+
+	matches, err := uc.openSanctionsProvider.Search(ctx, org.OpenSanctionsConfig, cfg, query)
 	if err != nil {
 		return models.SanctionCheckResult{}, err
 	}
diff --git a/usecases/usecases_with_creds.go b/usecases/usecases_with_creds.go
index 46515f5e9..11bfc1bad 100644
--- a/usecases/usecases_with_creds.go
+++ b/usecases/usecases_with_creds.go
@@ -119,8 +119,10 @@ func (usecases *UsecasesWithCreds) NewDecisionUsecase() DecisionUsecase {
 
 func (usecases *UsecasesWithCreds) NewSanctionCheckUsecase() SanctionCheckUsecase {
 	return SanctionCheckUsecase{
-		openSanctionsProvider: usecases.Repositories.OpenSanctionsRepository,
-		repository:            &usecases.Repositories.MarbleDbRepository,
+		organizationRepository: usecases.Repositories.OrganizationRepository,
+		openSanctionsProvider:  usecases.Repositories.OpenSanctionsRepository,
+		repository:             &usecases.Repositories.MarbleDbRepository,
+		executorFactory:        usecases.NewExecutorFactory(),
 	}
 }