server: implement authentication
This commit is contained in:
@@ -20,16 +20,30 @@ func Start() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
r.Route("/whoami", func(r chi.Router) {
|
r.Route("/whoami", func(r chi.Router) {
|
||||||
|
r.Use(SessionAuthMiddleware)
|
||||||
|
r.Use(LoginCtx)
|
||||||
r.Get("/", Whoami)
|
r.Get("/", Whoami)
|
||||||
})
|
})
|
||||||
|
|
||||||
r.Route("/users", func(r chi.Router) {
|
r.Route("/users", func(r chi.Router) {
|
||||||
|
r.Use(SessionAuthMiddleware)
|
||||||
|
|
||||||
r.Get("/", ListUsers)
|
r.Get("/", ListUsers)
|
||||||
r.Route("/{userID}", func(r chi.Router) {
|
r.Route("/{userID}", func(r chi.Router) {
|
||||||
r.Get("/", GetUser)
|
r.Get("/", GetUser)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
r.Route("/login", func(r chi.Router) {
|
||||||
|
r.Post("/", Login)
|
||||||
|
})
|
||||||
|
|
||||||
|
r.Route("/logout", func(r chi.Router) {
|
||||||
|
r.Use(SessionAuthMiddleware)
|
||||||
|
|
||||||
|
r.Post("/", Logout)
|
||||||
|
})
|
||||||
|
|
||||||
r.Route("/register", func(r chi.Router) {
|
r.Route("/register", func(r chi.Router) {
|
||||||
r.Post("/", NewUser)
|
r.Post("/", NewUser)
|
||||||
})
|
})
|
||||||
|
|||||||
+201
-1
@@ -1,6 +1,206 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import "golang.org/x/crypto/bcrypt"
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
var jwtSecret = []byte(os.Getenv("JWT_SECRET"))
|
||||||
|
|
||||||
|
func hashToken(token string) string {
|
||||||
|
hash := sha256.Sum256([]byte(token))
|
||||||
|
return hex.EncodeToString(hash[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func Login(w http.ResponseWriter, r *http.Request) {
|
||||||
|
err := r.ParseMultipartForm(64 << 10)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Unable to parse form", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
username := r.FormValue("username")
|
||||||
|
password := r.FormValue("password")
|
||||||
|
if username == "" || password == "" {
|
||||||
|
http.Error(w, "Username and password cannot be empty", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := dbGetUserByName(username)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Invalid username or password", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validatePassword(user.Password, password); err != nil {
|
||||||
|
http.Error(w, "Invalid username or password", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionToken := CreateSession(user.ID)
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "session_token",
|
||||||
|
Value: sessionToken,
|
||||||
|
Path: "/",
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: false,
|
||||||
|
})
|
||||||
|
|
||||||
|
slog.Info("auth: login successful", "userid", user.ID, "username", user.Name)
|
||||||
|
w.Write([]byte("Login successful"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Logout(w http.ResponseWriter, r *http.Request) {
|
||||||
|
cookie, err := r.Cookie("session_token")
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "No session cookie found. You are already logged out", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionToken := cookie.Value
|
||||||
|
userID, valid := ValidateSession(sessionToken)
|
||||||
|
if !valid {
|
||||||
|
http.Error(w, "Session cookie could not be validated. You are already logged out", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := dbGetUser(userID.String())
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Session cookie validated but user could not be found", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
DeleteSession(sessionToken)
|
||||||
|
|
||||||
|
cookie.Expires = time.Now()
|
||||||
|
http.SetCookie(w, cookie)
|
||||||
|
|
||||||
|
slog.Debug("auth: logout successful", "user ID", user.ID, "username", user.Name)
|
||||||
|
w.Write([]byte(fmt.Sprintf("%v has been logged out", user.Name)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateSession(sessionToken string) (uuid.UUID, bool) {
|
||||||
|
token, err := jwt.Parse(sessionToken, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||||
|
}
|
||||||
|
return jwtSecret, nil
|
||||||
|
})
|
||||||
|
if err != nil || !token.Valid {
|
||||||
|
slog.Debug("auth: session token invalid, rejecting")
|
||||||
|
return uuid.Nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, ok := token.Claims.(jwt.MapClaims)
|
||||||
|
if !ok {
|
||||||
|
slog.Debug("auth: could not map claims from JWT")
|
||||||
|
return uuid.Nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
userIDStr, ok := claims["userid"].(string)
|
||||||
|
if !ok {
|
||||||
|
slog.Debug("auth: userID claim is not a string")
|
||||||
|
return uuid.Nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, err := uuid.Parse(userIDStr)
|
||||||
|
if err != nil {
|
||||||
|
slog.Debug("auth: failed to parse userID as uuid", "error", err)
|
||||||
|
return uuid.Nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
hashedToken := hashToken(sessionToken)
|
||||||
|
|
||||||
|
session, err := dbGetSession(hashedToken)
|
||||||
|
if err != nil {
|
||||||
|
slog.Debug("auth: failed to retrieve session from db", "error", err)
|
||||||
|
return uuid.Nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Debug("auth: session validated", "userID", session.UserID)
|
||||||
|
return userID, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteSession(sessionToken string) bool {
|
||||||
|
hashedToken := hashToken(sessionToken)
|
||||||
|
|
||||||
|
err := dbDeleteSession(hashedToken)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("auth: failed to delete session", "error", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Debug("auth: session deleted", "token", hashedToken)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
type contextKey string
|
||||||
|
|
||||||
|
const userIDKey contextKey = "userID"
|
||||||
|
|
||||||
|
func SessionAuthMiddleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
cookie, err := r.Cookie("session_token")
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionToken := cookie.Value
|
||||||
|
userID, valid := ValidateSession(sessionToken)
|
||||||
|
if !valid {
|
||||||
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add username to request context
|
||||||
|
ctx := context.WithValue(r.Context(), userIDKey, userID)
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
Token string
|
||||||
|
UserID uuid.UUID
|
||||||
|
Expiry time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateSession(userID uuid.UUID) string {
|
||||||
|
expiry := time.Now().Add(7 * 24 * time.Hour)
|
||||||
|
|
||||||
|
claims := jwt.MapClaims{
|
||||||
|
"userid": userID.String(),
|
||||||
|
"exp": expiry.Unix(), // 7 day token
|
||||||
|
"iat": time.Now().Unix(),
|
||||||
|
}
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
tokenString, err := token.SignedString(jwtSecret)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("auth: failed to create JWT", "error", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
hashedToken := hashToken(tokenString)
|
||||||
|
session := Session{
|
||||||
|
Token: hashedToken,
|
||||||
|
UserID: userID,
|
||||||
|
Expiry: expiry,
|
||||||
|
}
|
||||||
|
dbAddSession(&session)
|
||||||
|
|
||||||
|
slog.Debug("auth: new session created", "userid", session.UserID)
|
||||||
|
return tokenString
|
||||||
|
}
|
||||||
|
|
||||||
func hashPassword(password string) (string, error) {
|
func hashPassword(password string) (string, error) {
|
||||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password),
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password),
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var ErrUserNotFound = errors.New("db: user not found")
|
var ErrUserNotFound = errors.New("db: user not found")
|
||||||
|
var ErrSessionNotFound = errors.New("db: session not found")
|
||||||
|
|
||||||
func dbGetUser(id string) (*User, error) {
|
func dbGetUser(id string) (*User, error) {
|
||||||
query := `SELECT id, name, password FROM users WHERE id = $1`
|
query := `SELECT id, name, password FROM users WHERE id = $1`
|
||||||
@@ -86,3 +87,47 @@ func dbAddUser(user *User) error {
|
|||||||
slog.Debug("db: user added", "userid", user.ID, "username", user.Name)
|
slog.Debug("db: user added", "userid", user.ID, "username", user.Name)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func dbAddSession(session *Session) error {
|
||||||
|
query := `INSERT INTO sessions (jwttoken, userid, expiry) VALUES ($1, $2, $3)`
|
||||||
|
_, err := db.Pool.Exec(context.Background(), query, session.Token, session.UserID, session.Expiry)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("db: failed to add session", "error", err)
|
||||||
|
return fmt.Errorf("failed to add session")
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Debug("db: session added", "userid", session.UserID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func dbGetSession(jwtToken string) (*Session, error) {
|
||||||
|
query := `SELECT jwttoken, userid, expiry FROM sessions WHERE jwttoken = $1`
|
||||||
|
|
||||||
|
var session Session
|
||||||
|
err := db.Pool.QueryRow(context.Background(), query, jwtToken).Scan(&session.Token, &session.UserID, &session.Expiry)
|
||||||
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
slog.Debug("db: session not found")
|
||||||
|
return nil, ErrSessionNotFound
|
||||||
|
} else if err != nil {
|
||||||
|
slog.Error("db: failed to query session", "error", err)
|
||||||
|
return nil, fmt.Errorf("failed to query session")
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Debug("db: session found", "userid", session.UserID)
|
||||||
|
return &session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func dbDeleteSession(jwtToken string) error {
|
||||||
|
query := `DELETE FROM sessions WHERE jwttoken = $1`
|
||||||
|
tag, err := db.Pool.Exec(context.Background(), query, jwtToken)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("db: failed to delete session", "error", err)
|
||||||
|
return fmt.Errorf("failed to delete session")
|
||||||
|
}
|
||||||
|
if tag.RowsAffected() == 0 {
|
||||||
|
return ErrSessionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Debug("db: session deleted")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -18,6 +19,33 @@ func Whoami(w http.ResponseWriter, r *http.Request) {
|
|||||||
w.Write([]byte("anonymous"))
|
w.Write([]byte("anonymous"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
slog.Debug("user: returning username", "userid", user.ID, "username", user.Name)
|
||||||
|
w.Write([]byte(user.Name))
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoginCtx(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
slog.Debug("user: entering LoginCtx middleware")
|
||||||
|
userID, ok := r.Context().Value(userIDKey).(uuid.UUID)
|
||||||
|
if !ok || userID == uuid.Nil {
|
||||||
|
slog.Debug("user: no user ID provided, assuming anonymous user")
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Debug("user: fetching user by ID", "user ID", userID)
|
||||||
|
user, err := dbGetUser(userID.String())
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("user: failed to fetch user by ID", "user ID", userID, "error", err)
|
||||||
|
render.Render(w, r, ErrNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Debug("user: successfully fetched user", "user ID", user.ID, "username", user.Name)
|
||||||
|
ctx := context.WithValue(r.Context(), userKey{}, user)
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListUsers(w http.ResponseWriter, r *http.Request) {
|
func ListUsers(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ require (
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/ajg/form v1.5.1 // indirect
|
github.com/ajg/form v1.5.1 // indirect
|
||||||
|
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
|
||||||
github.com/golang/snappy v0.0.3 // indirect
|
github.com/golang/snappy v0.0.3 // indirect
|
||||||
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect
|
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4=
|
|||||||
github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0=
|
github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0=
|
||||||
github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus=
|
github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus=
|
||||||
github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4=
|
github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4=
|
||||||
|
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||||
|
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||||
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
|
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
|
||||||
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
|||||||
+1
-1
@@ -9,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var REQUIRED_ENVS = [...]string{
|
var REQUIRED_ENVS = [...]string{
|
||||||
"DATABASE_URL",
|
"DATABASE_URL", "JWT_SECRET",
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkEnvVars(keys []string) (bool, []string) {
|
func checkEnvVars(keys []string) (bool, []string) {
|
||||||
|
|||||||
Reference in New Issue
Block a user