2525from api .pcm_globals import set_auth_cookies_in_context , logger , auth_cookies
2626from api .security .csrf .constants import CSRF_COOKIE_NAME
2727from api .security .csrf .csrf import csrf_needed
28- from api .utils import disable_auth
28+ from api .utils import disable_auth , read_and_delete_ssm_output_from_cloudwatch
2929from api .validation import validated
3030from api .validation .schemas import PCProxyArgs , PCProxyBody
3131
3232USER_POOL_ID = os .getenv ("USER_POOL_ID" )
3333AUTH_PATH = os .getenv ("AUTH_PATH" )
3434API_BASE_URL = os .getenv ("API_BASE_URL" )
35- API_VERSION = os .getenv ("API_VERSION" , "3.1.0" )
35+ API_VERSION = sorted (set (os .getenv ("API_VERSION" , "3.1.0" ).strip ().split ("," )), key = lambda x : [- int (n ) for n in x .split ('.' )])
36+ # Default version must be highest version so that it can be used for read operations due to backwards compatibility
37+ DEFAULT_API_VERSION = API_VERSION [0 ]
3638API_USER_ROLE = os .getenv ("API_USER_ROLE" )
3739OIDC_PROVIDER = os .getenv ("OIDC_PROVIDER" )
3840CLIENT_ID = os .getenv ("CLIENT_ID" )
3941CLIENT_SECRET = os .getenv ("CLIENT_SECRET" )
4042SECRET_ID = os .getenv ("SECRET_ID" )
41- SITE_URL = os .getenv ("SITE_URL" , API_BASE_URL )
4243SCOPES_LIST = os .getenv ("SCOPES_LIST" )
4344REGION = os .getenv ("AWS_DEFAULT_REGION" )
4445TOKEN_URL = os .getenv ("TOKEN_URL" , f"{ AUTH_PATH } /oauth2/token" )
4748JWKS_URL = os .getenv ("JWKS_URL" )
4849AUDIENCE = os .getenv ("AUDIENCE" )
4950USER_ROLES_CLAIM = os .getenv ("USER_ROLES_CLAIM" , "cognito:groups" )
51+ SSM_LOG_GROUP_NAME = os .getenv ("SSM_LOG_GROUP_NAME" )
52+ ARG_VERSION = "version"
5053
5154try :
5255 if (not USER_POOL_ID or USER_POOL_ID == "" ) and SECRET_ID :
6265 JWKS_URL = os .getenv ("JWKS_URL" ,
6366 f"https://cognito-idp.{ REGION } .amazonaws.com/{ USER_POOL_ID } /" ".well-known/jwks.json" )
6467
68+ def create_url_map (url_list ):
69+ url_map = {}
70+ if url_list :
71+ for url in url_list .split ("," ):
72+ if url :
73+ pair = url .split ("=" )
74+ url_map [pair [0 ]] = pair [1 ]
75+ return url_map
76+
77+ API_BASE_URL_MAPPING = create_url_map (API_BASE_URL )
78+ SITE_URL = os .getenv ("SITE_URL" , API_BASE_URL_MAPPING .get (DEFAULT_API_VERSION ))
79+
80+
6581
6682def jwt_decode (token , audience = None , access_token = None ):
6783 return jwt .decode (
@@ -164,7 +180,7 @@ def authenticate(groups):
164180
165181 if (not groups ):
166182 return abort (403 )
167-
183+
168184 jwt_roles = set (decoded .get (USER_ROLES_CLAIM , []))
169185 groups_granted = groups .intersection (jwt_roles )
170186 if len (groups_granted ) == 0 :
@@ -190,7 +206,7 @@ def get_scopes_list():
190206
191207def get_redirect_uri ():
192208 return f"{ SITE_URL } /login"
193-
209+
194210# Local Endpoints
195211
196212
@@ -232,9 +248,9 @@ def ec2_action():
232248def get_cluster_config_text (cluster_name , region = None ):
233249 url = f"/v3/clusters/{ cluster_name } "
234250 if region :
235- info_resp = sigv4_request ("GET" , API_BASE_URL , url , params = {"region" : region })
251+ info_resp = sigv4_request ("GET" , get_base_url ( request ) , url , params = {"region" : region })
236252 else :
237- info_resp = sigv4_request ("GET" , API_BASE_URL , url )
253+ info_resp = sigv4_request ("GET" , get_base_url ( request ) , url )
238254 if info_resp .status_code != 200 :
239255 abort (info_resp .status_code )
240256
@@ -264,10 +280,16 @@ def ssm_command(region, instance_id, user, run_command):
264280 DocumentName = "AWS-RunShellScript" ,
265281 Comment = f"Run ssm command." ,
266282 Parameters = {"commands" : [command ]},
283+ CloudWatchOutputConfig = {
284+ 'CloudWatchLogGroupName' : SSM_LOG_GROUP_NAME ,
285+ 'CloudWatchOutputEnabled' : True
286+ },
267287 )
268288
269289 command_id = ssm_resp ["Command" ]["CommandId" ]
270290
291+ logger .info (f"Submitted SSM command { command_id } " )
292+
271293 # Wait for command to complete
272294 time .sleep (0.75 )
273295 while time .time () - start < 60 :
@@ -282,7 +304,13 @@ def ssm_command(region, instance_id, user, run_command):
282304 if status ["Status" ] != "Success" :
283305 raise Exception (status ["StandardErrorContent" ])
284306
285- output = status ["StandardOutputContent" ]
307+ output = read_and_delete_ssm_output_from_cloudwatch (
308+ region = region ,
309+ log_group_name = SSM_LOG_GROUP_NAME ,
310+ command_id = command_id ,
311+ instance_id = instance_id ,
312+ )
313+
286314 return output
287315
288316
@@ -352,7 +380,7 @@ def sacct():
352380 user ,
353381 f"sacct { sacct_args } --json "
354382 + "| jq -c .jobs[0:120]\\ |\\ map\\ ({name,user,partition,state,job_id,exit_code\\ }\\ )" ,
355- )
383+ )
356384 if type (accounting ) is tuple :
357385 return accounting
358386 else :
@@ -471,7 +499,7 @@ def get_dcv_session():
471499
472500
473501def get_custom_image_config ():
474- image_info = sigv4_request ("GET" , API_BASE_URL , f"/v3/images/custom/{ request .args .get ('image_id' )} " ).json ()
502+ image_info = sigv4_request ("GET" , get_base_url ( request ) , f"/v3/images/custom/{ request .args .get ('image_id' )} " ).json ()
475503 configuration = requests .get (image_info ["imageConfiguration" ]["url" ])
476504 return configuration .text
477505
@@ -553,13 +581,7 @@ def get_instance_types():
553581 ec2 = boto3 .client ("ec2" , config = config )
554582 else :
555583 ec2 = boto3 .client ("ec2" )
556- filters = [
557- {"Name" : "current-generation" , "Values" : ["true" ]},
558- {"Name" : "instance-type" ,
559- "Values" : [
560- "c5*" , "c6*" , "c7*" , "g4*" , "g5*" , "g6*" , "hpc*" , "p3*" , "p4*" , "p5*" , "t2*" , "t3*" , "m6*" , "m7*" , "r*"
561- ]},
562- ]
584+ filters = [{"Name" : "current-generation" , "Values" : ["true" ]}]
563585 instance_paginator = ec2 .get_paginator ("describe_instance_types" )
564586 instances_paginator = instance_paginator .paginate (Filters = filters )
565587 instance_types = []
@@ -583,9 +605,9 @@ def _get_identity_from_token(decoded, claims):
583605 identity ["username" ] = decoded ["username" ]
584606
585607 for claim in claims :
586- if claim in decoded :
587- identity ["attributes" ][claim ] = decoded [claim ]
588-
608+ if claim in decoded :
609+ identity ["attributes" ][claim ] = decoded [claim ]
610+
589611 return identity
590612
591613def get_identity ():
@@ -722,14 +744,20 @@ def _get_params(_request):
722744 params .pop ("path" )
723745 return params
724746
747+ def get_base_url (request ):
748+ version = request .args .get (ARG_VERSION )
749+ if version and str (version ) in API_VERSION :
750+ return API_BASE_URL_MAPPING [str (version )]
751+ return API_BASE_URL_MAPPING [DEFAULT_API_VERSION ]
752+
725753
726754pc = Blueprint ('pc' , __name__ )
727755
728756@pc .get ('/' , strict_slashes = False )
729757@authenticated ({'admin' })
730758@validated (params = PCProxyArgs )
731759def pc_proxy_get ():
732- response = sigv4_request (request .method , API_BASE_URL , request .args .get ("path" ), _get_params (request ))
760+ response = sigv4_request (request .method , get_base_url ( request ) , request .args .get ("path" ), _get_params (request ))
733761 return response .json (), response .status_code
734762
735763@pc .route ('/' , methods = ['POST' ,'PUT' ,'PATCH' ,'DELETE' ], strict_slashes = False )
@@ -743,5 +771,5 @@ def pc_proxy():
743771 except :
744772 pass
745773
746- response = sigv4_request (request .method , API_BASE_URL , request .args .get ("path" ), _get_params (request ), body = body )
774+ response = sigv4_request (request .method , get_base_url ( request ) , request .args .get ("path" ), _get_params (request ), body = body )
747775 return response .json (), response .status_code
0 commit comments