Compare commits

...

2 Commits

Author SHA1 Message Date
williamp bd6c0bf211 server: implement authentication 2026-05-17 03:22:03 +00:00
williamp 8568b147bb server: implement ISEs w/o public error outputs 2026-05-17 01:09:52 +00:00
8 changed files with 305 additions and 7 deletions
+14
View File
@@ -20,16 +20,30 @@ func Start() {
})
r.Route("/whoami", func(r chi.Router) {
r.Use(SessionAuthMiddleware)
r.Use(LoginCtx)
r.Get("/", Whoami)
})
r.Route("/users", func(r chi.Router) {
r.Use(SessionAuthMiddleware)
r.Get("/", ListUsers)
r.Route("/{userID}", func(r chi.Router) {
r.Get("/", GetUser)
})
})
r.Route("/login", func(r chi.Router) {
r.Post("/", Login)
})
r.Route("/logout", func(r chi.Router) {
r.Use(SessionAuthMiddleware)
r.Post("/", Logout)
})
r.Route("/register", func(r chi.Router) {
r.Post("/", NewUser)
})
+201 -1
View File
@@ -1,6 +1,206 @@
package api
import "golang.org/x/crypto/bcrypt"
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"log/slog"
"net/http"
"os"
"time"
"github.com/golang-jwt/jwt"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
var jwtSecret = []byte(os.Getenv("JWT_SECRET"))
func hashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}
func Login(w http.ResponseWriter, r *http.Request) {
err := r.ParseMultipartForm(64 << 10)
if err != nil {
http.Error(w, "Unable to parse form", http.StatusBadRequest)
return
}
username := r.FormValue("username")
password := r.FormValue("password")
if username == "" || password == "" {
http.Error(w, "Username and password cannot be empty", http.StatusBadRequest)
return
}
user, err := dbGetUserByName(username)
if err != nil {
http.Error(w, "Invalid username or password", http.StatusUnauthorized)
return
}
if err := validatePassword(user.Password, password); err != nil {
http.Error(w, "Invalid username or password", http.StatusUnauthorized)
return
}
sessionToken := CreateSession(user.ID)
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: sessionToken,
Path: "/",
HttpOnly: true,
Secure: false,
})
slog.Info("auth: login successful", "userid", user.ID, "username", user.Name)
w.Write([]byte("Login successful"))
}
func Logout(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("session_token")
if err != nil {
http.Error(w, "No session cookie found. You are already logged out", http.StatusBadRequest)
return
}
sessionToken := cookie.Value
userID, valid := ValidateSession(sessionToken)
if !valid {
http.Error(w, "Session cookie could not be validated. You are already logged out", http.StatusBadRequest)
return
}
user, err := dbGetUser(userID.String())
if err != nil {
http.Error(w, "Session cookie validated but user could not be found", http.StatusInternalServerError)
return
}
DeleteSession(sessionToken)
cookie.Expires = time.Now()
http.SetCookie(w, cookie)
slog.Debug("auth: logout successful", "user ID", user.ID, "username", user.Name)
w.Write([]byte(fmt.Sprintf("%v has been logged out", user.Name)))
}
func ValidateSession(sessionToken string) (uuid.UUID, bool) {
token, err := jwt.Parse(sessionToken, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return jwtSecret, nil
})
if err != nil || !token.Valid {
slog.Debug("auth: session token invalid, rejecting")
return uuid.Nil, false
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
slog.Debug("auth: could not map claims from JWT")
return uuid.Nil, false
}
userIDStr, ok := claims["userid"].(string)
if !ok {
slog.Debug("auth: userID claim is not a string")
return uuid.Nil, false
}
userID, err := uuid.Parse(userIDStr)
if err != nil {
slog.Debug("auth: failed to parse userID as uuid", "error", err)
return uuid.Nil, false
}
hashedToken := hashToken(sessionToken)
session, err := dbGetSession(hashedToken)
if err != nil {
slog.Debug("auth: failed to retrieve session from db", "error", err)
return uuid.Nil, false
}
slog.Debug("auth: session validated", "userID", session.UserID)
return userID, true
}
func DeleteSession(sessionToken string) bool {
hashedToken := hashToken(sessionToken)
err := dbDeleteSession(hashedToken)
if err != nil {
slog.Error("auth: failed to delete session", "error", err)
return false
}
slog.Debug("auth: session deleted", "token", hashedToken)
return true
}
type contextKey string
const userIDKey contextKey = "userID"
func SessionAuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("session_token")
if err != nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
sessionToken := cookie.Value
userID, valid := ValidateSession(sessionToken)
if !valid {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
// Add username to request context
ctx := context.WithValue(r.Context(), userIDKey, userID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
type Session struct {
Token string
UserID uuid.UUID
Expiry time.Time
}
func CreateSession(userID uuid.UUID) string {
expiry := time.Now().Add(7 * 24 * time.Hour)
claims := jwt.MapClaims{
"userid": userID.String(),
"exp": expiry.Unix(), // 7 day token
"iat": time.Now().Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString(jwtSecret)
if err != nil {
slog.Error("auth: failed to create JWT", "error", err)
return ""
}
hashedToken := hashToken(tokenString)
session := Session{
Token: hashedToken,
UserID: userID,
Expiry: expiry,
}
dbAddSession(&session)
slog.Debug("auth: new session created", "userid", session.UserID)
return tokenString
}
func hashPassword(password string) (string, error) {
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password),
+45
View File
@@ -11,6 +11,7 @@ import (
)
var ErrUserNotFound = errors.New("db: user not found")
var ErrSessionNotFound = errors.New("db: session not found")
func dbGetUser(id string) (*User, error) {
query := `SELECT id, name, password FROM users WHERE id = $1`
@@ -86,3 +87,47 @@ func dbAddUser(user *User) error {
slog.Debug("db: user added", "userid", user.ID, "username", user.Name)
return nil
}
func dbAddSession(session *Session) error {
query := `INSERT INTO sessions (jwttoken, userid, expiry) VALUES ($1, $2, $3)`
_, err := db.Pool.Exec(context.Background(), query, session.Token, session.UserID, session.Expiry)
if err != nil {
slog.Error("db: failed to add session", "error", err)
return fmt.Errorf("failed to add session")
}
slog.Debug("db: session added", "userid", session.UserID)
return nil
}
func dbGetSession(jwtToken string) (*Session, error) {
query := `SELECT jwttoken, userid, expiry FROM sessions WHERE jwttoken = $1`
var session Session
err := db.Pool.QueryRow(context.Background(), query, jwtToken).Scan(&session.Token, &session.UserID, &session.Expiry)
if errors.Is(err, pgx.ErrNoRows) {
slog.Debug("db: session not found")
return nil, ErrSessionNotFound
} else if err != nil {
slog.Error("db: failed to query session", "error", err)
return nil, fmt.Errorf("failed to query session")
}
slog.Debug("db: session found", "userid", session.UserID)
return &session, nil
}
func dbDeleteSession(jwtToken string) error {
query := `DELETE FROM sessions WHERE jwttoken = $1`
tag, err := db.Pool.Exec(context.Background(), query, jwtToken)
if err != nil {
slog.Error("db: failed to delete session", "error", err)
return fmt.Errorf("failed to delete session")
}
if tag.RowsAffected() == 0 {
return ErrSessionNotFound
}
slog.Debug("db: session deleted")
return nil
}
+8
View File
@@ -38,4 +38,12 @@ func ErrRender(err error) render.Renderer {
}
}
func ErrInternal(err error) render.Renderer {
return &ErrResponse{
Err: err,
HTTPStatusCode: 500,
StatusText: "Internal server error.",
}
}
var ErrNotFound = &ErrResponse{HTTPStatusCode: 404, StatusText: "Resource not found."}
+33 -5
View File
@@ -1,6 +1,7 @@
package api
import (
"context"
"errors"
"log/slog"
"net/http"
@@ -18,6 +19,33 @@ func Whoami(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("anonymous"))
return
}
slog.Debug("user: returning username", "userid", user.ID, "username", user.Name)
w.Write([]byte(user.Name))
}
func LoginCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
slog.Debug("user: entering LoginCtx middleware")
userID, ok := r.Context().Value(userIDKey).(uuid.UUID)
if !ok || userID == uuid.Nil {
slog.Debug("user: no user ID provided, assuming anonymous user")
next.ServeHTTP(w, r)
return
}
slog.Debug("user: fetching user by ID", "user ID", userID)
user, err := dbGetUser(userID.String())
if err != nil {
slog.Error("user: failed to fetch user by ID", "user ID", userID, "error", err)
render.Render(w, r, ErrNotFound)
return
}
slog.Debug("user: successfully fetched user", "user ID", user.ID, "username", user.Name)
ctx := context.WithValue(r.Context(), userKey{}, user)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func ListUsers(w http.ResponseWriter, r *http.Request) {
@@ -25,14 +53,14 @@ func ListUsers(w http.ResponseWriter, r *http.Request) {
dbUsers, err := dbGetAllUsers()
if err != nil {
slog.Error("user: failed to fetch users", "error", err)
render.Render(w, r, ErrRender(err))
render.Render(w, r, ErrInternal(err))
return
}
slog.Debug("user: successfully fetched users", "count", len(dbUsers))
if err := render.RenderList(w, r, NewUserListResponse(dbUsers)); err != nil {
slog.Error("user: failed to render user list response", "error", err)
render.Render(w, r, ErrRender(err))
render.Render(w, r, ErrInternal(err))
return
}
}
@@ -53,7 +81,7 @@ func GetUser(w http.ResponseWriter, r *http.Request) {
render.Render(w, r, ErrNotFound)
} else {
slog.Error("user: failed to fetch user", "userid", parsed.String(), "error", err)
render.Render(w, r, ErrRender(err))
render.Render(w, r, ErrInternal(err))
}
return
}
@@ -61,7 +89,7 @@ func GetUser(w http.ResponseWriter, r *http.Request) {
slog.Debug("user: rendering user", "userid", user.ID, "username", user.Name)
if err := render.Render(w, r, NewUserPayloadResponse(user)); err != nil {
slog.Error("user: failed to render user", "userid", parsed.String(), "error", err)
render.Render(w, r, ErrRender(err))
render.Render(w, r, ErrInternal(err))
}
}
@@ -104,7 +132,7 @@ func NewUser(w http.ResponseWriter, r *http.Request) {
err = dbAddUser(&newUser)
if err != nil {
slog.Error("user: failed to add new user", "userID", newUser.ID, "userName", newUser.Name, "error", err)
render.Render(w, r, ErrRender(err))
render.Render(w, r, ErrInternal(err))
return
}
+1
View File
@@ -14,6 +14,7 @@ require (
require (
github.com/ajg/form v1.5.1 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/golang/snappy v0.0.3 // indirect
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
+2
View File
@@ -14,6 +14,8 @@ github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4=
github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0=
github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus=
github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
+1 -1
View File
@@ -9,7 +9,7 @@ import (
)
var REQUIRED_ENVS = [...]string{
"DATABASE_URL",
"DATABASE_URL", "JWT_SECRET",
}
func checkEnvVars(keys []string) (bool, []string) {