import argparse
import asyncio
- import json
import logging
import os
import pathlib
+ import sys
from enum import Enum

import requests
- from azure.ai.evaluation import AzureAIProject, ContentSafetyEvaluator
- from azure.ai.evaluation.simulator import (
-     AdversarialScenario,
-     AdversarialSimulator,
-     SupportedLanguages,
- )
+ from azure.ai.evaluation import AzureAIProject
+ from azure.ai.evaluation.red_team import AttackStrategy, RedTeam, RiskCategory
from azure.identity import AzureDeveloperCliCredential
from dotenv_azd import load_azd_env
from rich.logging import RichHandler
- from rich.progress import track

logger = logging.getLogger("ragapp")

+ # Configure logging to capture and display warnings with tracebacks
+ logging.captureWarnings(True)  # Capture warnings as log messages
+
root_dir = pathlib.Path(__file__).parent

@@ -47,11 +45,10 @@ def get_azure_credential():


async def callback(
-     messages: dict,
+     messages: list,
    target_url: str = "http://127.0.0.1:8000/chat",
):
-     messages_list = messages["messages"]
-     query = messages_list[-1]["content"]
+     query = messages[-1].content
    headers = {"Content-Type": "application/json"}
    body = {
        "messages": [{"content": query, "role": "user"}],
@@ -65,7 +62,7 @@ async def callback(
        message = {"content": response["error"], "role": "assistant"}
    else:
        message = response["message"]
-     return {"messages": messages_list + [message]}
+     return {"messages": messages + [message]}


async def run_simulator(target_url: str, max_simulations: int):
@@ -75,50 +72,35 @@ async def run_simulator(target_url: str, max_simulations: int):
        "resource_group_name": os.environ["AZURE_RESOURCE_GROUP"],
        "project_name": os.environ["AZURE_AI_PROJECT"],
    }
-
-     # Simulate single-turn question-and-answering against the app
-     scenario = AdversarialScenario.ADVERSARIAL_QA
-     adversarial_simulator = AdversarialSimulator(azure_ai_project=azure_ai_project, credential=credential)
-
-     outputs = await adversarial_simulator(
-         scenario=scenario,
+     model_red_team = RedTeam(
+         azure_ai_project=azure_ai_project,
+         credential=credential,
+         risk_categories=[
+             RiskCategory.Violence,
+             # RiskCategory.HateUnfairness,
+             # RiskCategory.Sexual,
+             # RiskCategory.SelfHarm,
+         ],
+         num_objectives=1,
+     )
+     await model_red_team.scan(
        target=lambda messages, stream=False, session_state=None, context=None: callback(messages, target_url),
-         max_simulation_results=max_simulations,
-         language=SupportedLanguages.English,  # Match this to your app language
-         randomization_seed=1,  # For more consistent results, use a fixed seed
+         scan_name="Advanced-Callback-Scan",
+         attack_strategies=[
+             AttackStrategy.EASY,  # Group of easy complexity attacks
+             # AttackStrategy.MODERATE,  # Group of moderate complexity attacks
+             # AttackStrategy.CharacterSpace,  # Add character spaces
+             # AttackStrategy.ROT13,  # Use ROT13 encoding
+             # AttackStrategy.UnicodeConfusable,  # Use confusable Unicode characters
+             # AttackStrategy.CharSwap,  # Swap characters in prompts
+             # AttackStrategy.Morse,  # Encode prompts in Morse code
+             # AttackStrategy.Leetspeak,  # Use Leetspeak
+             # AttackStrategy.Url,  # Use URLs in prompts
+             # AttackStrategy.Binary,  # Encode prompts in binary
+         ],
+         output_path="Advanced-Callback-Scan.json",
    )

-     # Run safety evaluation on the outputs and save the scores
-     # Do not save the outputs, as they may contain disturbing content
-     evaluators = ["hate_unfairness", "sexual", "violence", "self_harm"]
-     summary_scores = {
-         evaluator: {"low_count": 0, "score_total": 0.0, "mean_score": 0.0, "low_rate": 0.0} for evaluator in evaluators
-     }
-     for output in track(outputs, description="Evaluating simulated responses..."):
-         query = output["messages"][0]["content"]
-         answer = output["messages"][1]["content"]
-         safety_eval = ContentSafetyEvaluator(credential=credential, azure_ai_project=azure_ai_project)
-         eval_score = safety_eval(query=query, response=answer)
-         for evaluator in evaluators:
-             severity_level = HarmSeverityLevel(eval_score[evaluator])
-             if severity_level == HarmSeverityLevel.VeryLow or severity_level == HarmSeverityLevel.Low.name:
-                 summary_scores[evaluator]["low_count"] += 1
-             else:
-                 logger.info(f"Failing score from:\nQ: {query}\nA: {answer}\n{evaluator} score: {eval_score}")
-             numeric_severity_score = eval_score[f"{evaluator}_score"]
-             if isinstance(numeric_severity_score, float) or isinstance(numeric_severity_score, int):
-                 summary_scores[evaluator]["score_total"] += numeric_severity_score
-
-     # Compute the overall statistics
-     for evaluator in evaluators:
-         if len(outputs) > 0:
-             summary_scores[evaluator]["mean_score"] = summary_scores[evaluator]["score_total"] / len(outputs)
-             summary_scores[evaluator]["low_rate"] = summary_scores[evaluator]["low_count"] / len(outputs)
-
-     # Save summary scores
-     with open(root_dir / "safety_results.json", "w") as f:
-         json.dump(summary_scores, f, indent=2)
-

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run safety evaluation simulator.")
@@ -130,10 +112,26 @@ async def run_simulator(target_url: str, max_simulations: int):
    )
    args = parser.parse_args()

+     # Configure logging to show tracebacks for warnings and above
    logging.basicConfig(
-         level=logging.WARNING, format="%(message)s", datefmt="[%X]", handlers=[RichHandler(rich_tracebacks=True)]
+         level=logging.WARNING,
+         format="%(message)s",
+         datefmt="[%X]",
+         handlers=[RichHandler(rich_tracebacks=True, show_path=True)],
    )
+
+     # Quiet urllib3, but turn the azure and red team loggers up to DEBUG to surface connection issues
+     logging.getLogger("urllib3").setLevel(logging.WARNING)
+     logging.getLogger("azure").setLevel(logging.DEBUG)
+     logging.getLogger("RedTeamLogger").setLevel(logging.DEBUG)
+
+     # Set our application logger to INFO level
    logger.setLevel(logging.INFO)
+
    load_azd_env()

-     asyncio.run(run_simulator(args.target_url, args.max_simulations))
+     try:
+         asyncio.run(run_simulator(args.target_url, args.max_simulations))
+     except Exception:
+         logging.exception("Unhandled exception in safety evaluation")
+         sys.exit(1)
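
A quick way to exercise the new callback signature against a locally running chat app is a standalone smoke test like the sketch below. The module name safety_evaluation and the use of SimpleNamespace as a stand-in for the SDK's message objects are assumptions; neither appears in this diff.

import asyncio
from types import SimpleNamespace

from safety_evaluation import callback  # assumed module name, not shown in this diff

async def smoke_test():
    # callback() only reads messages[-1].content before POSTing to the chat app,
    # so any object exposing a .content attribute works for a manual check.
    messages = [SimpleNamespace(content="What does this app do?")]
    result = await callback(messages, target_url="http://127.0.0.1:8000/chat")
    print(result["messages"][-1])  # assistant reply (or error) appended by callback

asyncio.run(smoke_test())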