-
Notifications
You must be signed in to change notification settings - Fork 1
Add Oauth2 Authorization Code + Login Page #23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,46 +2,86 @@ package main | |
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. unused import |
||
"fmt" | ||
"log" | ||
"net/http" | ||
"net/http/httputil" | ||
"os" | ||
"os/signal" | ||
"strings" | ||
"sync" | ||
"syscall" | ||
"time" | ||
|
||
"github.com/go-oauth2/oauth2/v4/manage" | ||
"github.com/go-oauth2/oauth2/v4/models" | ||
"github.com/go-oauth2/oauth2/v4/server" | ||
"github.com/go-oauth2/oauth2/v4/store" | ||
"github.com/redis/go-redis/v9" | ||
"golang.org/x/crypto/bcrypt" | ||
) | ||
|
||
type LoginRequest struct { | ||
Email string `json:"email"` | ||
Password string `json:"password"` | ||
} | ||
|
||
type RegisterRequest struct { | ||
Email string `json:"email"` | ||
Password string `json:"password"` | ||
} | ||
|
||
type App struct { | ||
redisClient *redis.Client | ||
scheduler *Scheduler | ||
supervisor *Supervisor | ||
httpServer *http.Server | ||
manager *manage.Manager | ||
srv *server.Server | ||
wg sync.WaitGroup | ||
} | ||
|
||
func NewApp(redisAddr, gpuType string) *App { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what's the purpose of gpuType here? |
||
client := redis.NewClient(&redis.Options{Addr: redisAddr}) | ||
scheduler := NewScheduler(redisAddr) | ||
|
||
consumerID := fmt.Sprintf("worker_%d", os.Getpid()) | ||
supervisor := NewSupervisor(redisAddr, consumerID, gpuType) | ||
manager := manage.NewDefaultManager() | ||
manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg) | ||
manager.MustTokenStorage(store.NewMemoryTokenStore()) // TODO: move to redis? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From my understanding currently, if the API gateway crashes and restarts then users will need to get new tokens. Therefore, I would suggest moving to redis |
||
|
||
clientStore := store.NewClientStore() | ||
clientStore.Set("client", &models.Client{ | ||
ID: "client", | ||
Secret: "secret", // replace this with actual secret | ||
Domain: "http://localhost:3000", // replace with environment domain | ||
}) | ||
manager.MapClientStorage(clientStore) | ||
|
||
srv := server.NewDefaultServer(manager) | ||
srv.SetAllowGetAccessRequest(true) | ||
srv.SetClientInfoHandler(server.ClientFormHandler) | ||
|
||
mux := http.NewServeMux() | ||
a := &App{ | ||
redisClient: client, | ||
scheduler: scheduler, | ||
supervisor: supervisor, | ||
manager: manager, | ||
srv: srv, | ||
httpServer: &http.Server{Addr: ":3000", Handler: mux}, | ||
} | ||
|
||
// auth routes | ||
mux.HandleFunc("/auth/register", a.register) | ||
mux.HandleFunc("/auth/login", a.login) | ||
mux.HandleFunc("/auth/refresh", a.refresh) | ||
|
||
mux.HandleFunc("/oauth/authorize", a.authorize) | ||
mux.HandleFunc("/oauth/token", a.token) | ||
|
||
mux.HandleFunc("/jobs", a.enqueueJob) | ||
mux.HandleFunc("/jobs/status", a.getJobStatus) | ||
|
||
srv.UserAuthorizationHandler = a.UserAuthorizationHandler | ||
|
||
return a | ||
} | ||
|
||
|
@@ -50,7 +90,7 @@ func (a *App) Start() error { | |
if err := a.redisClient.Ping(context.Background()).Err(); err != nil { | ||
return fmt.Errorf("redis ping failed: %w", err) | ||
} | ||
|
||
// Launch HTTP server | ||
a.wg.Add(1) | ||
go func() { | ||
|
@@ -72,8 +112,6 @@ func (a *App) Shutdown(ctx context.Context) error { | |
// Wait for ListenAndServe goroutine to finish | ||
a.wg.Wait() | ||
|
||
a.supervisor.Stop() | ||
|
||
if err := a.scheduler.Close(); err != nil { | ||
log.Printf("error closing scheduler: %v", err) | ||
} | ||
|
@@ -106,18 +144,248 @@ func main() { | |
log.Println("all services stopped cleanly") | ||
} | ||
|
||
func (a *App) register(w http.ResponseWriter, r *http.Request) { | ||
if r.Method != "POST" { | ||
a.jsonResponse(w, http.StatusMethodNotAllowed, APIResponse{ | ||
Success: false, | ||
Error: "Method not allowed", | ||
}) | ||
return | ||
} | ||
|
||
var req RegisterRequest | ||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { | ||
a.jsonResponse(w, http.StatusBadRequest, APIResponse{ | ||
Success: false, | ||
Error: "Invalid JSON", | ||
}) | ||
return | ||
} | ||
|
||
if req.Email == "" || req.Password == "" { | ||
a.jsonResponse(w, http.StatusBadRequest, APIResponse{ | ||
Success: false, | ||
Error: "Email and password required", | ||
}) | ||
return | ||
} | ||
|
||
// Check if user exists | ||
if _, err := a.getUserByEmail(req.Email); err == nil { | ||
a.jsonResponse(w, http.StatusConflict, APIResponse{ | ||
Success: false, | ||
Error: "Email already exists", | ||
}) | ||
return | ||
} | ||
|
||
user, err := a.createUser(req.Email, req.Password) | ||
if err != nil { | ||
log.Printf("Failed to create user: %v", err) | ||
a.jsonResponse(w, http.StatusInternalServerError, APIResponse{ | ||
Success: false, | ||
Error: "Failed to create user", | ||
}) | ||
return | ||
} | ||
|
||
a.jsonResponse(w, http.StatusCreated, APIResponse{ | ||
Success: true, | ||
Data: map[string]string{ | ||
"user_id": user.ID, | ||
"email": user.Email, | ||
}, | ||
}) | ||
} | ||
|
||
func (a *App) authorize(w http.ResponseWriter, r *http.Request) { | ||
if r.Method == "GET" { | ||
sessionID := r.Header.Get("Authorization") | ||
if strings.HasPrefix(sessionID, "Session ") { | ||
sessionID = strings.TrimPrefix(sessionID, "Session ") | ||
} | ||
if cookie, err := r.Cookie("session"); err == nil { | ||
sessionID = cookie.Value | ||
} | ||
|
||
log.Println(sessionID) | ||
|
||
if sessionID == "" { | ||
log.Println("a") | ||
a.redirectToLogin(w, r) | ||
return | ||
} | ||
if _, err := a.getSession(sessionID); err != nil { | ||
log.Println("b") | ||
a.redirectToLogin(w, r) | ||
return | ||
} | ||
|
||
err := a.srv.HandleAuthorizeRequest(w, r) | ||
if err != nil { | ||
log.Printf("Authorize error: %v", err) | ||
http.Error(w, err.Error(), http.StatusBadRequest) | ||
} | ||
return | ||
} | ||
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | ||
} | ||
|
||
func (a *App) token(w http.ResponseWriter, r *http.Request) { | ||
err := a.srv.HandleTokenRequest(w, r) | ||
if err != nil { | ||
log.Printf("Token error: %v", err) | ||
http.Error(w, err.Error(), http.StatusBadRequest) | ||
} | ||
} | ||
|
||
func (a *App) login(w http.ResponseWriter, r *http.Request) { | ||
ctx := r.Context() | ||
val, err := a.redisClient.Get(ctx, "some:key").Result() | ||
if err != nil || err != redis.Nil { | ||
http.Error(w, "redis error", http.StatusInternalServerError) | ||
if r.Method == "GET" { | ||
if sessionID := a.getSessionFromRequest(r); sessionID != "" { | ||
if _, err := a.getSession(sessionID); err == nil { | ||
// user is already logged in, redirect them | ||
redirectURL := r.URL.Query().Get("redirect") | ||
if redirectURL == "" { | ||
redirectURL = "/" // default | ||
} | ||
|
||
http.Redirect(w, r, redirectURL, http.StatusFound) | ||
return | ||
} | ||
} | ||
|
||
|
||
a.showLoginPage(w, r) | ||
return | ||
} | ||
fmt.Fprintf(w, "login page; redis says: %q\n", val) | ||
|
||
if r.Method != "POST" { | ||
a.jsonResponse(w, http.StatusMethodNotAllowed, APIResponse{ | ||
Success: false, | ||
Error: "Method not allowed", | ||
}) | ||
return | ||
} | ||
|
||
// Check if this is a form submission or api | ||
contentType := r.Header.Get("Content-Type") | ||
isFormData := strings.Contains(contentType, "application/x-www-form-urlencoded") || contentType == "" | ||
|
||
var email, password string | ||
var err error | ||
|
||
if isFormData { | ||
email = r.FormValue("email") | ||
password = r.FormValue("password") | ||
if email == "" || password == "" { | ||
a.showLoginPage(w, r) | ||
return | ||
} | ||
} else { | ||
var req LoginRequest | ||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { | ||
a.jsonResponse(w, http.StatusBadRequest, APIResponse{ | ||
Success: false, | ||
Error: "Invalid JSON", | ||
}) | ||
return | ||
} | ||
email = req.Email | ||
password = req.Password | ||
} | ||
|
||
user, err := a.getUserByEmail(email) | ||
if err != nil { | ||
if isFormData { | ||
a.showLoginPage(w, r) | ||
return | ||
} | ||
a.jsonResponse(w, http.StatusUnauthorized, APIResponse{ | ||
Success: false, | ||
Error: "Invalid Email / Password", | ||
}) | ||
return | ||
} | ||
|
||
err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) | ||
if err != nil { | ||
if isFormData { | ||
a.showLoginPage(w, r) | ||
return | ||
} | ||
a.jsonResponse(w, http.StatusUnauthorized, APIResponse{ | ||
Success: false, | ||
Error: "Invalid Email / Password", | ||
}) | ||
return | ||
} | ||
|
||
sessionID, err := a.CreateSession(user.ID) | ||
if err != nil { | ||
if isFormData { | ||
a.showLoginPage(w, r) | ||
return | ||
} | ||
a.jsonResponse(w, http.StatusInternalServerError, APIResponse{ | ||
Success: false, | ||
Error: "Error Creating Session", | ||
}) | ||
return | ||
} | ||
|
||
if isFormData { | ||
http.SetCookie(w, &http.Cookie{ | ||
Name: "session", | ||
Path: "/", | ||
Value: sessionID, | ||
HttpOnly: true, | ||
Secure: false, // TODO: change in prod | ||
SameSite: http.SameSiteLaxMode, | ||
}) | ||
|
||
redirectURL := r.FormValue("redirect") | ||
log.Println(redirectURL) | ||
if redirectURL == "" { | ||
redirectURL = "/" // Default redirect | ||
} | ||
http.Redirect(w, r, redirectURL, http.StatusFound) | ||
return | ||
} | ||
|
||
a.jsonResponse(w, http.StatusAccepted, APIResponse{ | ||
Success: true, | ||
Data: map[string]any{ | ||
"session": sessionID, | ||
}, | ||
}) | ||
} | ||
|
||
func (a *App) refresh(w http.ResponseWriter, r *http.Request) { | ||
fmt.Fprintf(w, "Hello, world!\n") | ||
// TODO: move this to a file? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. move it to a file |
||
func (a *App) showLoginPage(w http.ResponseWriter, r *http.Request) { | ||
redirectURL := r.URL.Query().Get("redirect") | ||
|
||
html := fmt.Sprintf(` | ||
<!DOCTYPE html> | ||
<html> | ||
<body> | ||
<form method="POST" action="/auth/login"> | ||
<input type="hidden" name="redirect" value="%s"> | ||
|
||
<div> | ||
<label>Email:</label> | ||
<input type="email" name="email" required> | ||
</div> | ||
<div> | ||
<label>Password:</label> | ||
<input type="password" name="password" required> | ||
</div> | ||
<button type="submit">Login</button> | ||
</form> | ||
</body> | ||
</html>`, redirectURL) | ||
w.Header().Set("Content-Type", "text/html") | ||
fmt.Fprint(w, html) | ||
} | ||
|
||
func (a *App) enqueueJob(w http.ResponseWriter, r *http.Request) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's decouple some of the logic from this file to other new or existing files