Skip to content

Commit e24bd02

Browse files
artagelbnfinet
authored andcommitted
Add support for grabbing claims from the return of getuserinfo functions for all providers.
1 parent aa2a4a2 commit e24bd02

File tree

1 file changed

+60
-11
lines changed

1 file changed

+60
-11
lines changed

handlers/handlers.go

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ const (
4444
base64Bytes = 32
4545
)
4646

47+
// Temporary struct storing custom claims until JWT creation.
48+
var CustomClaims map[string]interface{}
49+
4750
var (
4851
// Templates
4952
indexTemplate = template.Must(template.ParseFiles("./templates/index.tmpl"))
@@ -398,7 +401,9 @@ func CallbackHandler(w http.ResponseWriter, r *http.Request) {
398401
}
399402

400403
user := structs.User{}
401-
if err := getUserInfo(r, &user); err != nil {
404+
customClaims := CustomClaims
405+
406+
if err := getUserInfo(r, &user, customClaims); err != nil {
402407
log.Error(err)
403408
http.Error(w, err.Error(), http.StatusBadRequest)
404409
return
@@ -438,7 +443,7 @@ func CallbackHandler(w http.ResponseWriter, r *http.Request) {
438443

439444
// TODO: put all getUserInfo logic into its own pkg
440445

441-
func getUserInfo(r *http.Request, user *structs.User) error {
446+
func getUserInfo(r *http.Request, user *structs.User, customClaims map[string]interface{}) error {
442447

443448
// indieauth sends the "me" setting in json back to the callback, so just pluck it from the callback
444449
if cfg.GenOAuth.Provider == cfg.Providers.IndieAuth {
@@ -459,20 +464,24 @@ func getUserInfo(r *http.Request, user *structs.User) error {
459464
} else if cfg.GenOAuth.Provider == cfg.Providers.GitHub {
460465
return getUserInfoFromGitHub(client, user, providerToken)
461466
} else if cfg.GenOAuth.Provider == cfg.Providers.OIDC {
462-
return getUserInfoFromOpenID(client, user, providerToken)
467+
return getUserInfoFromOpenID(client, user, customClaims, providerToken)
463468
}
464469
log.Error("we don't know how to look up the user info")
465470
return nil
466471
}
467472

468-
func getUserInfoFromOpenID(client *http.Client, user *structs.User, ptoken *oauth2.Token) error {
473+
func getUserInfoFromOpenID(client *http.Client, user *structs.User, customClaims map[string]interface{}, ptoken *oauth2.Token) error {
469474
userinfo, err := client.Get(cfg.GenOAuth.UserInfoURL)
470475
if err != nil {
471476
return err
472477
}
473478
defer userinfo.Body.Close()
474479
data, _ := ioutil.ReadAll(userinfo.Body)
475480
log.Infof("OpenID userinfo body: ", string(data))
481+
if err = mapClaims(data, customClaims); err != nil {
482+
log.Error(err)
483+
return err
484+
}
476485
if err = json.Unmarshal(data, user); err != nil {
477486
log.Error(err)
478487
return err
@@ -481,14 +490,18 @@ func getUserInfoFromOpenID(client *http.Client, user *structs.User, ptoken *oaut
481490
return nil
482491
}
483492

484-
func getUserInfoFromGoogle(client *http.Client, user *structs.User) error {
493+
func getUserInfoFromGoogle(client *http.Client, user *structs.User, customClaims map[string]interface{}) error {
485494
userinfo, err := client.Get(cfg.GenOAuth.UserInfoURL)
486495
if err != nil {
487496
return err
488497
}
489498
defer userinfo.Body.Close()
490499
data, _ := ioutil.ReadAll(userinfo.Body)
491500
log.Infof("google userinfo body: ", string(data))
501+
if err = mapClaims(data, customClaims); err != nil {
502+
log.Error(err)
503+
return err
504+
}
492505
if err = json.Unmarshal(data, user); err != nil {
493506
log.Error(err)
494507
return err
@@ -500,7 +513,7 @@ func getUserInfoFromGoogle(client *http.Client, user *structs.User) error {
500513

501514
// github
502515
// https://developer.github.com/apps/building-integrations/setting-up-and-registering-oauth-apps/about-authorization-options-for-oauth-apps/
503-
func getUserInfoFromGitHub(client *http.Client, user *structs.User, ptoken *oauth2.Token) error {
516+
func getUserInfoFromGitHub(client *http.Client, user *structs.User, ptoken *oauth2.Token, customClaims map[string]interface{}) error {
504517

505518
log.Errorf("ptoken.AccessToken: %s", ptoken.AccessToken)
506519
userinfo, err := client.Get(cfg.GenOAuth.UserInfoURL + ptoken.AccessToken)
@@ -511,6 +524,10 @@ func getUserInfoFromGitHub(client *http.Client, user *structs.User, ptoken *oaut
511524
defer userinfo.Body.Close()
512525
data, _ := ioutil.ReadAll(userinfo.Body)
513526
log.Infof("github userinfo body: ", string(data))
527+
if err = mapClaims(data, customClaims); err != nil {
528+
log.Error(err)
529+
return err
530+
}
514531
ghUser := structs.GitHubUser{}
515532
if err = json.Unmarshal(data, &ghUser); err != nil {
516533
log.Error(err)
@@ -533,7 +550,7 @@ func getUserInfoFromGitHub(client *http.Client, user *structs.User, ptoken *oaut
533550
return nil
534551
}
535552

536-
func getUserInfoFromIndieAuth(r *http.Request, user *structs.User) error {
553+
func getUserInfoFromIndieAuth(r *http.Request, user *structs.User, customClaims map[string]interface{}) error {
537554

538555
code := r.URL.Query().Get("code")
539556
log.Errorf("ptoken.AccessToken: %s", code)
@@ -579,6 +596,10 @@ func getUserInfoFromIndieAuth(r *http.Request, user *structs.User) error {
579596
defer userinfo.Body.Close()
580597
data, _ := ioutil.ReadAll(userinfo.Body)
581598
log.Infof("indieauth userinfo body: ", string(data))
599+
if err = mapClaims(data, customClaims); err != nil {
600+
log.Error(err)
601+
return err
602+
}
582603
iaUser := structs.IndieAuthUser{}
583604
if err = json.Unmarshal(data, &iaUser); err != nil {
584605
log.Error(err)
@@ -598,7 +619,7 @@ type adfsTokenRes struct {
598619
}
599620

600621
// More info: https://docs.microsoft.com/en-us/windows-server/identity/ad-fs/overview/ad-fs-scenarios-for-developers#supported-scenarios
601-
func getUserInfoFromADFS(r *http.Request, user *structs.User) error {
622+
func getUserInfoFromADFS(r *http.Request, user *structs.User, customClaims map[string]interface{}) error {
602623
code := r.URL.Query().Get("code")
603624
log.Errorf("code: %s", code)
604625

@@ -648,11 +669,15 @@ func getUserInfoFromADFS(r *http.Request, user *structs.User) error {
648669

649670
adfsUser := structs.ADFSUser{}
650671
json.Unmarshal([]byte(idToken), &adfsUser)
651-
log.Infof("adfs adfsUser: ", adfsUser)
652-
672+
log.Infof("adfs adfsUser: %+v", adfsUser)
673+
if err = mapClaims([]byte(idToken), user); err != nil {
674+
log.Error(err)
675+
return err
676+
}
653677
adfsUser.PrepareUserData()
654678
user.Username = adfsUser.Username
655-
log.Debug(user)
679+
user.Email = adfsUser.Email
680+
log.Debugf("User Obj: %+v", user)
656681
return nil
657682
}
658683

@@ -686,3 +711,27 @@ func ok200(w http.ResponseWriter, r *http.Request) {
686711
log.Error(err)
687712
}
688713
}
714+
715+
func mapClaims(claims []byte, customClaims map[string]interface{}) error {
716+
// Create a struct that contains the claims that we want to store from the config.
717+
var f interface{}
718+
err := json.Unmarshal(claims, &f)
719+
if err != nil {
720+
log.Error("Error unmarshaling claims")
721+
return err
722+
}
723+
m := f.(map[string]interface{})
724+
for k, _ := range m {
725+
var found = false
726+
for _, e := range cfg.Cfg.Headers.Claims {
727+
if k == e {
728+
found = true
729+
}
730+
}
731+
if found == false {
732+
delete(m, k)
733+
}
734+
}
735+
customClaims = m
736+
return nil
737+
}

0 commit comments

Comments
 (0)