Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
298 changes: 283 additions & 15 deletions src/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,86 @@ package main

import (
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's decouple some of the logic from this file to other new or existing files

"context"
"encoding/json"
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unused import

"fmt"
"log"
"net/http"
"net/http/httputil"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"

"github.com/go-oauth2/oauth2/v4/manage"
"github.com/go-oauth2/oauth2/v4/models"
"github.com/go-oauth2/oauth2/v4/server"
"github.com/go-oauth2/oauth2/v4/store"
"github.com/redis/go-redis/v9"
"golang.org/x/crypto/bcrypt"
)

type LoginRequest struct {
Email string `json:"email"`
Password string `json:"password"`
}

type RegisterRequest struct {
Email string `json:"email"`
Password string `json:"password"`
}

type App struct {
redisClient *redis.Client
scheduler *Scheduler
supervisor *Supervisor
httpServer *http.Server
manager *manage.Manager
srv *server.Server
wg sync.WaitGroup
}

func NewApp(redisAddr, gpuType string) *App {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the purpose of gpuType here?

client := redis.NewClient(&redis.Options{Addr: redisAddr})
scheduler := NewScheduler(redisAddr)

consumerID := fmt.Sprintf("worker_%d", os.Getpid())
supervisor := NewSupervisor(redisAddr, consumerID, gpuType)
manager := manage.NewDefaultManager()
manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
manager.MustTokenStorage(store.NewMemoryTokenStore()) // TODO: move to redis?
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my understanding currently, if the API gateway crashes and restarts then users will need to get new tokens. Therefore, I would suggest moving to redis


clientStore := store.NewClientStore()
clientStore.Set("client", &models.Client{
ID: "client",
Secret: "secret", // replace this with actual secret
Domain: "http://localhost:3000", // replace with environment domain
})
manager.MapClientStorage(clientStore)

srv := server.NewDefaultServer(manager)
srv.SetAllowGetAccessRequest(true)
srv.SetClientInfoHandler(server.ClientFormHandler)

mux := http.NewServeMux()
a := &App{
redisClient: client,
scheduler: scheduler,
supervisor: supervisor,
manager: manager,
srv: srv,
httpServer: &http.Server{Addr: ":3000", Handler: mux},
}

// auth routes
mux.HandleFunc("/auth/register", a.register)
mux.HandleFunc("/auth/login", a.login)
mux.HandleFunc("/auth/refresh", a.refresh)

mux.HandleFunc("/oauth/authorize", a.authorize)
mux.HandleFunc("/oauth/token", a.token)

mux.HandleFunc("/jobs", a.enqueueJob)
mux.HandleFunc("/jobs/status", a.getJobStatus)

srv.UserAuthorizationHandler = a.UserAuthorizationHandler

return a
}

Expand All @@ -50,7 +90,7 @@ func (a *App) Start() error {
if err := a.redisClient.Ping(context.Background()).Err(); err != nil {
return fmt.Errorf("redis ping failed: %w", err)
}

// Launch HTTP server
a.wg.Add(1)
go func() {
Expand All @@ -72,8 +112,6 @@ func (a *App) Shutdown(ctx context.Context) error {
// Wait for ListenAndServe goroutine to finish
a.wg.Wait()

a.supervisor.Stop()

if err := a.scheduler.Close(); err != nil {
log.Printf("error closing scheduler: %v", err)
}
Expand Down Expand Up @@ -106,18 +144,248 @@ func main() {
log.Println("all services stopped cleanly")
}

func (a *App) register(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
a.jsonResponse(w, http.StatusMethodNotAllowed, APIResponse{
Success: false,
Error: "Method not allowed",
})
return
}

var req RegisterRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
a.jsonResponse(w, http.StatusBadRequest, APIResponse{
Success: false,
Error: "Invalid JSON",
})
return
}

if req.Email == "" || req.Password == "" {
a.jsonResponse(w, http.StatusBadRequest, APIResponse{
Success: false,
Error: "Email and password required",
})
return
}

// Check if user exists
if _, err := a.getUserByEmail(req.Email); err == nil {
a.jsonResponse(w, http.StatusConflict, APIResponse{
Success: false,
Error: "Email already exists",
})
return
}

user, err := a.createUser(req.Email, req.Password)
if err != nil {
log.Printf("Failed to create user: %v", err)
a.jsonResponse(w, http.StatusInternalServerError, APIResponse{
Success: false,
Error: "Failed to create user",
})
return
}

a.jsonResponse(w, http.StatusCreated, APIResponse{
Success: true,
Data: map[string]string{
"user_id": user.ID,
"email": user.Email,
},
})
}

func (a *App) authorize(w http.ResponseWriter, r *http.Request) {
if r.Method == "GET" {
sessionID := r.Header.Get("Authorization")
if strings.HasPrefix(sessionID, "Session ") {
sessionID = strings.TrimPrefix(sessionID, "Session ")
}
if cookie, err := r.Cookie("session"); err == nil {
sessionID = cookie.Value
}

log.Println(sessionID)

if sessionID == "" {
log.Println("a")
a.redirectToLogin(w, r)
return
}
if _, err := a.getSession(sessionID); err != nil {
log.Println("b")
a.redirectToLogin(w, r)
return
}

err := a.srv.HandleAuthorizeRequest(w, r)
if err != nil {
log.Printf("Authorize error: %v", err)
http.Error(w, err.Error(), http.StatusBadRequest)
}
return
}

http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}

func (a *App) token(w http.ResponseWriter, r *http.Request) {
err := a.srv.HandleTokenRequest(w, r)
if err != nil {
log.Printf("Token error: %v", err)
http.Error(w, err.Error(), http.StatusBadRequest)
}
}

func (a *App) login(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
val, err := a.redisClient.Get(ctx, "some:key").Result()
if err != nil || err != redis.Nil {
http.Error(w, "redis error", http.StatusInternalServerError)
if r.Method == "GET" {
if sessionID := a.getSessionFromRequest(r); sessionID != "" {
if _, err := a.getSession(sessionID); err == nil {
// user is already logged in, redirect them
redirectURL := r.URL.Query().Get("redirect")
if redirectURL == "" {
redirectURL = "/" // default
}

http.Redirect(w, r, redirectURL, http.StatusFound)
return
}
}


a.showLoginPage(w, r)
return
}
fmt.Fprintf(w, "login page; redis says: %q\n", val)

if r.Method != "POST" {
a.jsonResponse(w, http.StatusMethodNotAllowed, APIResponse{
Success: false,
Error: "Method not allowed",
})
return
}

// Check if this is a form submission or api
contentType := r.Header.Get("Content-Type")
isFormData := strings.Contains(contentType, "application/x-www-form-urlencoded") || contentType == ""

var email, password string
var err error

if isFormData {
email = r.FormValue("email")
password = r.FormValue("password")
if email == "" || password == "" {
a.showLoginPage(w, r)
return
}
} else {
var req LoginRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
a.jsonResponse(w, http.StatusBadRequest, APIResponse{
Success: false,
Error: "Invalid JSON",
})
return
}
email = req.Email
password = req.Password
}

user, err := a.getUserByEmail(email)
if err != nil {
if isFormData {
a.showLoginPage(w, r)
return
}
a.jsonResponse(w, http.StatusUnauthorized, APIResponse{
Success: false,
Error: "Invalid Email / Password",
})
return
}

err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
if err != nil {
if isFormData {
a.showLoginPage(w, r)
return
}
a.jsonResponse(w, http.StatusUnauthorized, APIResponse{
Success: false,
Error: "Invalid Email / Password",
})
return
}

sessionID, err := a.CreateSession(user.ID)
if err != nil {
if isFormData {
a.showLoginPage(w, r)
return
}
a.jsonResponse(w, http.StatusInternalServerError, APIResponse{
Success: false,
Error: "Error Creating Session",
})
return
}

if isFormData {
http.SetCookie(w, &http.Cookie{
Name: "session",
Path: "/",
Value: sessionID,
HttpOnly: true,
Secure: false, // TODO: change in prod
SameSite: http.SameSiteLaxMode,
})

redirectURL := r.FormValue("redirect")
log.Println(redirectURL)
if redirectURL == "" {
redirectURL = "/" // Default redirect
}
http.Redirect(w, r, redirectURL, http.StatusFound)
return
}

a.jsonResponse(w, http.StatusAccepted, APIResponse{
Success: true,
Data: map[string]any{
"session": sessionID,
},
})
}

func (a *App) refresh(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Hello, world!\n")
// TODO: move this to a file?
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move it to a file

func (a *App) showLoginPage(w http.ResponseWriter, r *http.Request) {
redirectURL := r.URL.Query().Get("redirect")

html := fmt.Sprintf(`
<!DOCTYPE html>
<html>
<body>
<form method="POST" action="/auth/login">
<input type="hidden" name="redirect" value="%s">

<div>
<label>Email:</label>
<input type="email" name="email" required>
</div>
<div>
<label>Password:</label>
<input type="password" name="password" required>
</div>
<button type="submit">Login</button>
</form>
</body>
</html>`, redirectURL)
w.Header().Set("Content-Type", "text/html")
fmt.Fprint(w, html)
}

func (a *App) enqueueJob(w http.ResponseWriter, r *http.Request) {
Expand Down
Loading