Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
bd6c0bf211
|
|||
|
8568b147bb
|
@@ -20,16 +20,30 @@ func Start() {
|
||||
})
|
||||
|
||||
r.Route("/whoami", func(r chi.Router) {
|
||||
r.Use(SessionAuthMiddleware)
|
||||
r.Use(LoginCtx)
|
||||
r.Get("/", Whoami)
|
||||
})
|
||||
|
||||
r.Route("/users", func(r chi.Router) {
|
||||
r.Use(SessionAuthMiddleware)
|
||||
|
||||
r.Get("/", ListUsers)
|
||||
r.Route("/{userID}", func(r chi.Router) {
|
||||
r.Get("/", GetUser)
|
||||
})
|
||||
})
|
||||
|
||||
r.Route("/login", func(r chi.Router) {
|
||||
r.Post("/", Login)
|
||||
})
|
||||
|
||||
r.Route("/logout", func(r chi.Router) {
|
||||
r.Use(SessionAuthMiddleware)
|
||||
|
||||
r.Post("/", Logout)
|
||||
})
|
||||
|
||||
r.Route("/register", func(r chi.Router) {
|
||||
r.Post("/", NewUser)
|
||||
})
|
||||
|
||||
+201
-1
@@ -1,6 +1,206 @@
|
||||
package api
|
||||
|
||||
import "golang.org/x/crypto/bcrypt"
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
var jwtSecret = []byte(os.Getenv("JWT_SECRET"))
|
||||
|
||||
func hashToken(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
func Login(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseMultipartForm(64 << 10)
|
||||
if err != nil {
|
||||
http.Error(w, "Unable to parse form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
username := r.FormValue("username")
|
||||
password := r.FormValue("password")
|
||||
if username == "" || password == "" {
|
||||
http.Error(w, "Username and password cannot be empty", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := dbGetUserByName(username)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid username or password", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validatePassword(user.Password, password); err != nil {
|
||||
http.Error(w, "Invalid username or password", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
sessionToken := CreateSession(user.ID)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: sessionToken,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: false,
|
||||
})
|
||||
|
||||
slog.Info("auth: login successful", "userid", user.ID, "username", user.Name)
|
||||
w.Write([]byte("Login successful"))
|
||||
}
|
||||
|
||||
func Logout(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("session_token")
|
||||
if err != nil {
|
||||
http.Error(w, "No session cookie found. You are already logged out", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
sessionToken := cookie.Value
|
||||
userID, valid := ValidateSession(sessionToken)
|
||||
if !valid {
|
||||
http.Error(w, "Session cookie could not be validated. You are already logged out", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := dbGetUser(userID.String())
|
||||
if err != nil {
|
||||
http.Error(w, "Session cookie validated but user could not be found", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
DeleteSession(sessionToken)
|
||||
|
||||
cookie.Expires = time.Now()
|
||||
http.SetCookie(w, cookie)
|
||||
|
||||
slog.Debug("auth: logout successful", "user ID", user.ID, "username", user.Name)
|
||||
w.Write([]byte(fmt.Sprintf("%v has been logged out", user.Name)))
|
||||
}
|
||||
|
||||
func ValidateSession(sessionToken string) (uuid.UUID, bool) {
|
||||
token, err := jwt.Parse(sessionToken, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return jwtSecret, nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
slog.Debug("auth: session token invalid, rejecting")
|
||||
return uuid.Nil, false
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
slog.Debug("auth: could not map claims from JWT")
|
||||
return uuid.Nil, false
|
||||
}
|
||||
|
||||
userIDStr, ok := claims["userid"].(string)
|
||||
if !ok {
|
||||
slog.Debug("auth: userID claim is not a string")
|
||||
return uuid.Nil, false
|
||||
}
|
||||
|
||||
userID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
slog.Debug("auth: failed to parse userID as uuid", "error", err)
|
||||
return uuid.Nil, false
|
||||
}
|
||||
|
||||
hashedToken := hashToken(sessionToken)
|
||||
|
||||
session, err := dbGetSession(hashedToken)
|
||||
if err != nil {
|
||||
slog.Debug("auth: failed to retrieve session from db", "error", err)
|
||||
return uuid.Nil, false
|
||||
}
|
||||
|
||||
slog.Debug("auth: session validated", "userID", session.UserID)
|
||||
return userID, true
|
||||
}
|
||||
|
||||
func DeleteSession(sessionToken string) bool {
|
||||
hashedToken := hashToken(sessionToken)
|
||||
|
||||
err := dbDeleteSession(hashedToken)
|
||||
if err != nil {
|
||||
slog.Error("auth: failed to delete session", "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
slog.Debug("auth: session deleted", "token", hashedToken)
|
||||
return true
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
|
||||
const userIDKey contextKey = "userID"
|
||||
|
||||
func SessionAuthMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("session_token")
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
sessionToken := cookie.Value
|
||||
userID, valid := ValidateSession(sessionToken)
|
||||
if !valid {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Add username to request context
|
||||
ctx := context.WithValue(r.Context(), userIDKey, userID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
Token string
|
||||
UserID uuid.UUID
|
||||
Expiry time.Time
|
||||
}
|
||||
|
||||
func CreateSession(userID uuid.UUID) string {
|
||||
expiry := time.Now().Add(7 * 24 * time.Hour)
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"userid": userID.String(),
|
||||
"exp": expiry.Unix(), // 7 day token
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString(jwtSecret)
|
||||
if err != nil {
|
||||
slog.Error("auth: failed to create JWT", "error", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
hashedToken := hashToken(tokenString)
|
||||
session := Session{
|
||||
Token: hashedToken,
|
||||
UserID: userID,
|
||||
Expiry: expiry,
|
||||
}
|
||||
dbAddSession(&session)
|
||||
|
||||
slog.Debug("auth: new session created", "userid", session.UserID)
|
||||
return tokenString
|
||||
}
|
||||
|
||||
func hashPassword(password string) (string, error) {
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password),
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
)
|
||||
|
||||
var ErrUserNotFound = errors.New("db: user not found")
|
||||
var ErrSessionNotFound = errors.New("db: session not found")
|
||||
|
||||
func dbGetUser(id string) (*User, error) {
|
||||
query := `SELECT id, name, password FROM users WHERE id = $1`
|
||||
@@ -86,3 +87,47 @@ func dbAddUser(user *User) error {
|
||||
slog.Debug("db: user added", "userid", user.ID, "username", user.Name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func dbAddSession(session *Session) error {
|
||||
query := `INSERT INTO sessions (jwttoken, userid, expiry) VALUES ($1, $2, $3)`
|
||||
_, err := db.Pool.Exec(context.Background(), query, session.Token, session.UserID, session.Expiry)
|
||||
if err != nil {
|
||||
slog.Error("db: failed to add session", "error", err)
|
||||
return fmt.Errorf("failed to add session")
|
||||
}
|
||||
|
||||
slog.Debug("db: session added", "userid", session.UserID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func dbGetSession(jwtToken string) (*Session, error) {
|
||||
query := `SELECT jwttoken, userid, expiry FROM sessions WHERE jwttoken = $1`
|
||||
|
||||
var session Session
|
||||
err := db.Pool.QueryRow(context.Background(), query, jwtToken).Scan(&session.Token, &session.UserID, &session.Expiry)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Debug("db: session not found")
|
||||
return nil, ErrSessionNotFound
|
||||
} else if err != nil {
|
||||
slog.Error("db: failed to query session", "error", err)
|
||||
return nil, fmt.Errorf("failed to query session")
|
||||
}
|
||||
|
||||
slog.Debug("db: session found", "userid", session.UserID)
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
func dbDeleteSession(jwtToken string) error {
|
||||
query := `DELETE FROM sessions WHERE jwttoken = $1`
|
||||
tag, err := db.Pool.Exec(context.Background(), query, jwtToken)
|
||||
if err != nil {
|
||||
slog.Error("db: failed to delete session", "error", err)
|
||||
return fmt.Errorf("failed to delete session")
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return ErrSessionNotFound
|
||||
}
|
||||
|
||||
slog.Debug("db: session deleted")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -38,4 +38,12 @@ func ErrRender(err error) render.Renderer {
|
||||
}
|
||||
}
|
||||
|
||||
func ErrInternal(err error) render.Renderer {
|
||||
return &ErrResponse{
|
||||
Err: err,
|
||||
HTTPStatusCode: 500,
|
||||
StatusText: "Internal server error.",
|
||||
}
|
||||
}
|
||||
|
||||
var ErrNotFound = &ErrResponse{HTTPStatusCode: 404, StatusText: "Resource not found."}
|
||||
|
||||
+33
-5
@@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
@@ -18,6 +19,33 @@ func Whoami(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("anonymous"))
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("user: returning username", "userid", user.ID, "username", user.Name)
|
||||
w.Write([]byte(user.Name))
|
||||
}
|
||||
|
||||
func LoginCtx(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("user: entering LoginCtx middleware")
|
||||
userID, ok := r.Context().Value(userIDKey).(uuid.UUID)
|
||||
if !ok || userID == uuid.Nil {
|
||||
slog.Debug("user: no user ID provided, assuming anonymous user")
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("user: fetching user by ID", "user ID", userID)
|
||||
user, err := dbGetUser(userID.String())
|
||||
if err != nil {
|
||||
slog.Error("user: failed to fetch user by ID", "user ID", userID, "error", err)
|
||||
render.Render(w, r, ErrNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("user: successfully fetched user", "user ID", user.ID, "username", user.Name)
|
||||
ctx := context.WithValue(r.Context(), userKey{}, user)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func ListUsers(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -25,14 +53,14 @@ func ListUsers(w http.ResponseWriter, r *http.Request) {
|
||||
dbUsers, err := dbGetAllUsers()
|
||||
if err != nil {
|
||||
slog.Error("user: failed to fetch users", "error", err)
|
||||
render.Render(w, r, ErrRender(err))
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("user: successfully fetched users", "count", len(dbUsers))
|
||||
if err := render.RenderList(w, r, NewUserListResponse(dbUsers)); err != nil {
|
||||
slog.Error("user: failed to render user list response", "error", err)
|
||||
render.Render(w, r, ErrRender(err))
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -53,7 +81,7 @@ func GetUser(w http.ResponseWriter, r *http.Request) {
|
||||
render.Render(w, r, ErrNotFound)
|
||||
} else {
|
||||
slog.Error("user: failed to fetch user", "userid", parsed.String(), "error", err)
|
||||
render.Render(w, r, ErrRender(err))
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -61,7 +89,7 @@ func GetUser(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("user: rendering user", "userid", user.ID, "username", user.Name)
|
||||
if err := render.Render(w, r, NewUserPayloadResponse(user)); err != nil {
|
||||
slog.Error("user: failed to render user", "userid", parsed.String(), "error", err)
|
||||
render.Render(w, r, ErrRender(err))
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,7 +132,7 @@ func NewUser(w http.ResponseWriter, r *http.Request) {
|
||||
err = dbAddUser(&newUser)
|
||||
if err != nil {
|
||||
slog.Error("user: failed to add new user", "userID", newUser.ID, "userName", newUser.Name, "error", err)
|
||||
render.Render(w, r, ErrRender(err))
|
||||
render.Render(w, r, ErrInternal(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ require (
|
||||
|
||||
require (
|
||||
github.com/ajg/form v1.5.1 // indirect
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
|
||||
github.com/golang/snappy v0.0.3 // indirect
|
||||
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
|
||||
@@ -14,6 +14,8 @@ github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4=
|
||||
github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0=
|
||||
github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus=
|
||||
github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
|
||||
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
|
||||
+1
-1
@@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
var REQUIRED_ENVS = [...]string{
|
||||
"DATABASE_URL",
|
||||
"DATABASE_URL", "JWT_SECRET",
|
||||
}
|
||||
|
||||
func checkEnvVars(keys []string) (bool, []string) {
|
||||
|
||||
Reference in New Issue
Block a user