Skip to content

Commit

Permalink
Refactor app launching process
Browse files Browse the repository at this point in the history
  • Loading branch information
theskyinflames committed May 23, 2024
1 parent a6eb9d7 commit 4a37b66
Show file tree
Hide file tree
Showing 13 changed files with 150 additions and 2,231 deletions.
19 changes: 7 additions & 12 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@

lint:
golangci-lint run ./...
@golangci-lint run ./...

test-coverage:
go test -coverprofile=coverage.out ./...
go tool cover -func=coverage.out
@go test -coverprofile=coverage.out ./...
@go tool cover -func=coverage.out

run-oauth2:
docker run -p 8080:8080 --name keycloak -e KEYCLOAK_ADMIN=admin -e KEYCLOAK_ADMIN_PASSWORD=admin quay.io/keycloak/keycloak:latest start-dev
run:
@cd script && ./run.sh

shutdown-oauth2:
docker stop keycloak
docker rm keycloak
@docker stop keycloak
@docker rm keycloak

create-realm:
cd ./script && ./create-realm.sh

run-api:
./script/run-api.sh

28 changes: 9 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,13 @@ This is a simple example of protecting REST endpoints using an identity manager

## How to run it

### 1. Start the `keycloak` container in its own terminal
1. Start the identity manager and the API:

```sh
make run-oauth2
make run
```

### 2. Open a new terminal session and run following steps.

1. Create the realm, the client and the user

```sh
make create-realm
```

2. Run the API

```sh
make run-api
```

### 3. Open the browser and go to [protected endpoin](http://localhost:9000/protected)
2. Open the browser and go to [protected endpoint](http://localhost:9000/protected)

When you do it, you'll be redirected to the log in page of the identy manager. Then log in and you'll be redirected to protected endpoint with a valid authentication JWT token.

Expand All @@ -37,7 +23,7 @@ User credentials to log in:

**That's it !!! :-D**

# What authentication middleware does
## What authentication middleware does

The `http.AuthMiddleware` function in the provided code is a middleware function in Go that is used to check if a user is authenticated before allowing them to access certain routes or resources. Let's break down how it works:

Expand All @@ -61,7 +47,7 @@ The `http.AuthMiddleware` function in the provided code is a middleware function

In summary, the `AuthMiddleware` function is a middleware that checks if a user is authenticated by verifying the JWT token stored in a cookie. It ensures that only authenticated users with valid tokens and the required roles can access protected routes or resources.

# Auth diagram
# Authentication process diagram

```mermaid
sequenceDiagram
Expand Down Expand Up @@ -94,6 +80,10 @@ sequenceDiagram
end
```

## Pending

- Adding logout handler

## Used stack

- Go 1.22.3
Expand Down
7 changes: 0 additions & 7 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@ import (
httpx "theskyinflames/oauth2-example/pkg/http"
)

const (
dfclientID = "test-client"
dfclientSecret = "EPgv2q0H2fjG1VlHfrVkk5sVQPxLVzOW"
dfauthURL = "http://localhost:8080/realms/test-realm/protocol/openid-connect/auth"
dftokenURL = "http://localhost:8080/realms/test-realm/protocol/openid-connect/token"
)

func main() {
// Parse the needed parameters to set the OAuth2 configuration from environment variables
clientID := os.Getenv("CLIENT_ID")
Expand Down
Binary file added pkg/http/__debug_bin1715452224
Binary file not shown.
28 changes: 20 additions & 8 deletions pkg/http/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,33 @@ import (
// OAuthConfig is the OAuth2 configuration
func OAuthConfig(clientID, ClientSecret, authURL, tokenURL string) *oauth2.Config {
return &oauth2.Config{
ClientID: "test-client",
ClientSecret: "EPgv2q0H2fjG1VlHfrVkk5sVQPxLVzOW",
ClientID: clientID,
ClientSecret: ClientSecret,
// RedirectURL: "http://localhost:9000/callback",
// Scopes: []string{"openid", "profile", "email"},
Endpoint: oauth2.Endpoint{
AuthURL: "http://localhost:8080/realms/test-realm/protocol/openid-connect/auth",
TokenURL: "http://localhost:8080/realms/test-realm/protocol/openid-connect/token",
AuthURL: authURL,
TokenURL: tokenURL,
},
}
}

const (
authCookieName = "my-auth-cookie" // Name of the cookie to store the token

UserCtxKey = "user" // Key to store the user in the context
)

// User represents a user
type User struct {
Email Email
Roles []Role
}

func (u User) String() string {
return fmt.Sprintf("Email: %s, Roles: %#v", u.Email, u.Roles)
}

// AuthMiddleware is a middleware to check if the user is authenticated
func AuthMiddleware(rsaPublicKey []*rsa.PublicKey) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
Expand All @@ -52,14 +64,14 @@ func AuthMiddleware(rsaPublicKey []*rsa.PublicKey) echo.MiddlewareFunc {
}

// parse the token received in the auth cookie
_, roles, err := parseJWT(token, rsaPublicKey)
_, email, roles, err := parseJWT(token, rsaPublicKey)
if err != nil {
c.Logger().Errorf("Failed to parse token: %v", err)
return c.Redirect(http.StatusTemporaryRedirect, "/login")
}

// TODO: Check if the user has the required roles to access the resource
c.Set("roles", roles)
c.Set(UserCtxKey, User{Email: email, Roles: roles})

// Return the next handler
return next(c)
Expand Down Expand Up @@ -97,10 +109,10 @@ func CallbackHandler(f OAuthConfigExchangeFunc) echo.HandlerFunc {
}

// Extract the roles from the token
roles := extractRoles(accessToken)
email, roles := extractRoles(accessToken)

// Create a new token
newToken := NewCustomToken(accessToken, roles)
newToken := NewCustomToken(accessToken, email, roles)

fmt.Printf("New token: %v\n", newToken.Raw)

Expand Down
16 changes: 14 additions & 2 deletions pkg/http/handlers.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package http

import (
"fmt"
"net/http"
"strings"

"github.com/labstack/echo/v4"
)
Expand All @@ -14,8 +14,20 @@ func ProtectedHandler(c echo.Context) error {
for _, cookie := range cookies {
if cookie.Name == authCookieName {
msg = cookie.Value
break
}
}

return c.String(http.StatusOK, fmt.Sprintf("Protected endpoint: %s", msg))
user := c.Get(UserCtxKey).(User)

sb := strings.Builder{}
sb.WriteString("<h1>Protected endpoint</h1>")
sb.WriteString("<h2>User info</h2>")
sb.WriteString("<p>")
sb.WriteString(user.String())
sb.WriteString("</p>")
sb.WriteString("<h2>Token</h2>")
sb.WriteString("<p>JWT token: " + msg + "</p>")

return c.HTML(http.StatusOK, sb.String())
}
16 changes: 12 additions & 4 deletions pkg/http/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,23 @@ import (

func TestProtectedHandler(t *testing.T) {
e := echo.New()

req := httptest.NewRequest(http.MethodGet, "/protected", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
res := httptest.NewRecorder()

// Create a new context
c := e.NewContext(req, res)
c.Set(httpx.UserCtxKey, httpx.User{
Email: httpx.Email("email"),
Roles: []httpx.Role{
httpx.Role("admin"),
},
})

// Call the handler function
err := httpx.ProtectedHandler(c)
assert.NoError(t, err)

// Check the response
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "Protected endpoint: ", rec.Body.String())
assert.Equal(t, http.StatusOK, res.Code)
}
37 changes: 25 additions & 12 deletions pkg/http/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,33 @@ import (
"golang.org/x/oauth2"
)

// Role represents the role of a user
type Role string
type (
// Role represents the role of a user
Role string

// Email represents the email of a user
Email string
)

func (e Email) String() string {
return string(e)
}

const rolesClaimKey = "realm_access"

// extractRoles extracts the roles from the JWT token and returns them as a list of strings
func extractRoles(token *jwt.Token) []Role {
func extractRoles(token *jwt.Token) (Email, []Role) {
// Extract the claims from the token
tokenRoles := token.Claims.(jwt.MapClaims)[rolesClaimKey].(map[string]interface{})["roles"].([]interface{})
roles := make([]Role, len(tokenRoles))
for i, v := range tokenRoles {
roles[i] = Role(fmt.Sprintf("%s", v))
}
return roles

// Extract the email from the token
email := Email(token.Claims.(jwt.MapClaims)["email"].(string))

return email, roles
}

// KeyFunc returns a jwt.Keyfunc that can be used to parse the JWT token
Expand All @@ -37,7 +50,7 @@ var KeyFunc = func(pk *rsa.PublicKey) jwt.Keyfunc {
}

// parseJWT parses the JWT token and returns the token if it is valid along with the roles of the user
func parseJWT(receivedToken string, rsaPublicKey []*rsa.PublicKey) (*jwt.Token, []Role, error) {
func parseJWT(receivedToken string, rsaPublicKey []*rsa.PublicKey) (*jwt.Token, Email, []Role, error) {
// Parse the token
var (
token *jwt.Token
Expand All @@ -50,16 +63,16 @@ func parseJWT(receivedToken string, rsaPublicKey []*rsa.PublicKey) (*jwt.Token,
}
}
if err != nil {
return nil, nil, err
return nil, "", nil, err
}

if !token.Valid {
return nil, nil, fmt.Errorf("token is invalid")
return nil, "", nil, fmt.Errorf("token is invalid")
}

roles := extractRoles(token)
email, roles := extractRoles(token)

return token, roles, nil
return token, email, roles, nil
}

var _ jwt.Claims = &CustomToken{}
Expand All @@ -74,15 +87,15 @@ type CustomToken struct {
}

// NewCustomToken creates a new *jwt.Token from the given JWT token
func NewCustomToken(token *jwt.Token, roles []Role) *jwt.Token {
func NewCustomToken(token *jwt.Token, email Email, roles []Role) *jwt.Token {
expiresAt := time.Unix(int64(token.Claims.(jwt.MapClaims)["exp"].(float64)), 0)
fmt.Printf("Expires at: %v\n", expiresAt)
return &jwt.Token{
Header: token.Header,
Claims: &CustomToken{
Claims: token.Claims,
Username: token.Claims.(jwt.MapClaims)["preferred_username"].(string),
Email: token.Claims.(jwt.MapClaims)["email"].(string),
Username: email.String(),
Email: email.String(),
Roles: roles,
ExpiresAt: expiresAt,
},
Expand Down
10 changes: 8 additions & 2 deletions pkg/http/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,21 @@ func TestExtractRoles(t *testing.T) {
"realm_access": map[string]interface{}{
"roles": []interface{}{"admin", "user"},
},
"email": "[email protected]",
},
}

roles := extractRoles(token)
email, roles := extractRoles(token)

expectedRoles := []Role{"admin", "user"}
if !reflect.DeepEqual(roles, expectedRoles) {
t.Fatalf("Roles mismatch. Expected: %v, got: %v", expectedRoles, roles)
}

expectedEmail := Email("[email protected]")
if email != expectedEmail {
t.Fatalf("Email mismatch. Expected: %v, got: %v", expectedEmail, email)
}
}

func TestParseJWT(t *testing.T) {
Expand All @@ -41,7 +47,7 @@ func TestParseJWT(t *testing.T) {
}

// Parse the token
parsedToken, _, err := parseJWT(tokenString, []*rsa.PublicKey{pubKey})
parsedToken, _, _, err := parseJWT(tokenString, []*rsa.PublicKey{pubKey})
if err != nil {
t.Fatalf("Failed to parse JWT: %v", err)
}
Expand Down
Loading

0 comments on commit 4a37b66

Please sign in to comment.