diff --git a/server/api/api.go b/server/api/api.go index e61aef9..f83f565 100644 --- a/server/api/api.go +++ b/server/api/api.go @@ -20,16 +20,30 @@ func Start() { }) r.Route("/whoami", func(r chi.Router) { + r.Use(SessionAuthMiddleware) + r.Use(LoginCtx) r.Get("/", Whoami) }) r.Route("/users", func(r chi.Router) { + r.Use(SessionAuthMiddleware) + r.Get("/", ListUsers) r.Route("/{userID}", func(r chi.Router) { r.Get("/", GetUser) }) }) + r.Route("/login", func(r chi.Router) { + r.Post("/", Login) + }) + + r.Route("/logout", func(r chi.Router) { + r.Use(SessionAuthMiddleware) + + r.Post("/", Logout) + }) + r.Route("/register", func(r chi.Router) { r.Post("/", NewUser) }) diff --git a/server/api/auth.go b/server/api/auth.go index bf10da3..c20d531 100644 --- a/server/api/auth.go +++ b/server/api/auth.go @@ -1,6 +1,206 @@ package api -import "golang.org/x/crypto/bcrypt" +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "log/slog" + "net/http" + "os" + "time" + + "github.com/golang-jwt/jwt" + "github.com/google/uuid" + "golang.org/x/crypto/bcrypt" +) + +var jwtSecret = []byte(os.Getenv("JWT_SECRET")) + +func hashToken(token string) string { + hash := sha256.Sum256([]byte(token)) + return hex.EncodeToString(hash[:]) +} + +func Login(w http.ResponseWriter, r *http.Request) { + err := r.ParseMultipartForm(64 << 10) + if err != nil { + http.Error(w, "Unable to parse form", http.StatusBadRequest) + return + } + + username := r.FormValue("username") + password := r.FormValue("password") + if username == "" || password == "" { + http.Error(w, "Username and password cannot be empty", http.StatusBadRequest) + return + } + + user, err := dbGetUserByName(username) + if err != nil { + http.Error(w, "Invalid username or password", http.StatusUnauthorized) + return + } + + if err := validatePassword(user.Password, password); err != nil { + http.Error(w, "Invalid username or password", http.StatusUnauthorized) + return + } + + sessionToken := CreateSession(user.ID) + http.SetCookie(w, &http.Cookie{ + Name: "session_token", + Value: sessionToken, + Path: "/", + HttpOnly: true, + Secure: false, + }) + + slog.Info("auth: login successful", "userid", user.ID, "username", user.Name) + w.Write([]byte("Login successful")) +} + +func Logout(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie("session_token") + if err != nil { + http.Error(w, "No session cookie found. You are already logged out", http.StatusBadRequest) + return + } + + sessionToken := cookie.Value + userID, valid := ValidateSession(sessionToken) + if !valid { + http.Error(w, "Session cookie could not be validated. You are already logged out", http.StatusBadRequest) + return + } + + user, err := dbGetUser(userID.String()) + if err != nil { + http.Error(w, "Session cookie validated but user could not be found", http.StatusInternalServerError) + return + } + + DeleteSession(sessionToken) + + cookie.Expires = time.Now() + http.SetCookie(w, cookie) + + slog.Debug("auth: logout successful", "user ID", user.ID, "username", user.Name) + w.Write([]byte(fmt.Sprintf("%v has been logged out", user.Name))) +} + +func ValidateSession(sessionToken string) (uuid.UUID, bool) { + token, err := jwt.Parse(sessionToken, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return jwtSecret, nil + }) + if err != nil || !token.Valid { + slog.Debug("auth: session token invalid, rejecting") + return uuid.Nil, false + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + slog.Debug("auth: could not map claims from JWT") + return uuid.Nil, false + } + + userIDStr, ok := claims["userid"].(string) + if !ok { + slog.Debug("auth: userID claim is not a string") + return uuid.Nil, false + } + + userID, err := uuid.Parse(userIDStr) + if err != nil { + slog.Debug("auth: failed to parse userID as uuid", "error", err) + return uuid.Nil, false + } + + hashedToken := hashToken(sessionToken) + + session, err := dbGetSession(hashedToken) + if err != nil { + slog.Debug("auth: failed to retrieve session from db", "error", err) + return uuid.Nil, false + } + + slog.Debug("auth: session validated", "userID", session.UserID) + return userID, true +} + +func DeleteSession(sessionToken string) bool { + hashedToken := hashToken(sessionToken) + + err := dbDeleteSession(hashedToken) + if err != nil { + slog.Error("auth: failed to delete session", "error", err) + return false + } + + slog.Debug("auth: session deleted", "token", hashedToken) + return true +} + +type contextKey string + +const userIDKey contextKey = "userID" + +func SessionAuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie("session_token") + if err != nil { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + sessionToken := cookie.Value + userID, valid := ValidateSession(sessionToken) + if !valid { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + // Add username to request context + ctx := context.WithValue(r.Context(), userIDKey, userID) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +type Session struct { + Token string + UserID uuid.UUID + Expiry time.Time +} + +func CreateSession(userID uuid.UUID) string { + expiry := time.Now().Add(7 * 24 * time.Hour) + + claims := jwt.MapClaims{ + "userid": userID.String(), + "exp": expiry.Unix(), // 7 day token + "iat": time.Now().Unix(), + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := token.SignedString(jwtSecret) + if err != nil { + slog.Error("auth: failed to create JWT", "error", err) + return "" + } + + hashedToken := hashToken(tokenString) + session := Session{ + Token: hashedToken, + UserID: userID, + Expiry: expiry, + } + dbAddSession(&session) + + slog.Debug("auth: new session created", "userid", session.UserID) + return tokenString +} func hashPassword(password string) (string, error) { hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), diff --git a/server/api/db.go b/server/api/db.go index 11bf2e2..a483cc4 100644 --- a/server/api/db.go +++ b/server/api/db.go @@ -11,6 +11,7 @@ import ( ) var ErrUserNotFound = errors.New("db: user not found") +var ErrSessionNotFound = errors.New("db: session not found") func dbGetUser(id string) (*User, error) { query := `SELECT id, name, password FROM users WHERE id = $1` @@ -86,3 +87,47 @@ func dbAddUser(user *User) error { slog.Debug("db: user added", "userid", user.ID, "username", user.Name) return nil } + +func dbAddSession(session *Session) error { + query := `INSERT INTO sessions (jwttoken, userid, expiry) VALUES ($1, $2, $3)` + _, err := db.Pool.Exec(context.Background(), query, session.Token, session.UserID, session.Expiry) + if err != nil { + slog.Error("db: failed to add session", "error", err) + return fmt.Errorf("failed to add session") + } + + slog.Debug("db: session added", "userid", session.UserID) + return nil +} + +func dbGetSession(jwtToken string) (*Session, error) { + query := `SELECT jwttoken, userid, expiry FROM sessions WHERE jwttoken = $1` + + var session Session + err := db.Pool.QueryRow(context.Background(), query, jwtToken).Scan(&session.Token, &session.UserID, &session.Expiry) + if errors.Is(err, pgx.ErrNoRows) { + slog.Debug("db: session not found") + return nil, ErrSessionNotFound + } else if err != nil { + slog.Error("db: failed to query session", "error", err) + return nil, fmt.Errorf("failed to query session") + } + + slog.Debug("db: session found", "userid", session.UserID) + return &session, nil +} + +func dbDeleteSession(jwtToken string) error { + query := `DELETE FROM sessions WHERE jwttoken = $1` + tag, err := db.Pool.Exec(context.Background(), query, jwtToken) + if err != nil { + slog.Error("db: failed to delete session", "error", err) + return fmt.Errorf("failed to delete session") + } + if tag.RowsAffected() == 0 { + return ErrSessionNotFound + } + + slog.Debug("db: session deleted") + return nil +} diff --git a/server/api/user.go b/server/api/user.go index 0bed69a..6531e9b 100644 --- a/server/api/user.go +++ b/server/api/user.go @@ -1,6 +1,7 @@ package api import ( + "context" "errors" "log/slog" "net/http" @@ -18,6 +19,33 @@ func Whoami(w http.ResponseWriter, r *http.Request) { w.Write([]byte("anonymous")) return } + + slog.Debug("user: returning username", "userid", user.ID, "username", user.Name) + w.Write([]byte(user.Name)) +} + +func LoginCtx(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + slog.Debug("user: entering LoginCtx middleware") + userID, ok := r.Context().Value(userIDKey).(uuid.UUID) + if !ok || userID == uuid.Nil { + slog.Debug("user: no user ID provided, assuming anonymous user") + next.ServeHTTP(w, r) + return + } + + slog.Debug("user: fetching user by ID", "user ID", userID) + user, err := dbGetUser(userID.String()) + if err != nil { + slog.Error("user: failed to fetch user by ID", "user ID", userID, "error", err) + render.Render(w, r, ErrNotFound) + return + } + + slog.Debug("user: successfully fetched user", "user ID", user.ID, "username", user.Name) + ctx := context.WithValue(r.Context(), userKey{}, user) + next.ServeHTTP(w, r.WithContext(ctx)) + }) } func ListUsers(w http.ResponseWriter, r *http.Request) { diff --git a/server/go.mod b/server/go.mod index 7dfb16f..af69483 100644 --- a/server/go.mod +++ b/server/go.mod @@ -14,6 +14,7 @@ require ( require ( github.com/ajg/form v1.5.1 // indirect + github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/golang/snappy v0.0.3 // indirect github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect github.com/jackc/pgpassfile v1.0.0 // indirect diff --git a/server/go.sum b/server/go.sum index e97f764..b8b1f9a 100644 --- a/server/go.sum +++ b/server/go.sum @@ -14,6 +14,8 @@ github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4= github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0= github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus= github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= diff --git a/server/main.go b/server/main.go index 801f78d..8383cac 100644 --- a/server/main.go +++ b/server/main.go @@ -9,7 +9,7 @@ import ( ) var REQUIRED_ENVS = [...]string{ - "DATABASE_URL", + "DATABASE_URL", "JWT_SECRET", } func checkEnvVars(keys []string) (bool, []string) {