diff --git a/api/api.go b/api/api.go index d4623eb..19ffb0e 100644 --- a/api/api.go +++ b/api/api.go @@ -37,6 +37,8 @@ func Start() { }) r.Route("/messages", func(r chi.Router) { + r.Use(SessionAuthMiddleware) // Protect with authentication + r.Get("/", ListMessages) r.Route("/{messageID}", func(r chi.Router) { r.Use(MessageCtx) // Load message @@ -48,6 +50,8 @@ func Start() { }) r.Route("/users", func(r chi.Router) { + r.Use(SessionAuthMiddleware) // Protect with authentication + r.Get("/", ListUsers) r.Route("/{userID}", func(r chi.Router) { r.Use(UserCtx) // Load user diff --git a/api/user.go b/api/user.go index ca151c0..c7fe832 100644 --- a/api/user.go +++ b/api/user.go @@ -75,6 +75,10 @@ func NewUser(w http.ResponseWriter, r *http.Request) { } hashedPassword, err := hashPassword(password) + if err != nil { + http.Error(w, "Unable to hash password", http.StatusInternalServerError) + } + newUser := User{ ID: newUserID(), Name: newUserName, @@ -116,9 +120,57 @@ func Login(w http.ResponseWriter, r *http.Request) { return } + sessionToken := CreateSession(username) + + http.SetCookie(w, &http.Cookie{ + Name: "session_token", + Value: sessionToken, + Path: "/", + HttpOnly: true, + Secure: false, + }) + w.Write([]byte("Login successful")) } +var sessionStore = make(map[string]string) + +func CreateSession(username string) string { + sessionToken := uuid.New().String() + sessionStore[sessionToken] = username + return sessionToken +} + +func ValidateSession(sessionToken string) (string, bool) { + username, exists := sessionStore[sessionToken] + return username, exists +} + +type contextKey string + +const usernameKey contextKey = "username" + +func SessionAuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie("session_token") + if err != nil { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + sessionToken := cookie.Value + username, valid := ValidateSession(sessionToken) + if !valid { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + // Add username to request context + ctx := context.WithValue(r.Context(), usernameKey, username) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + func (u *UserPayload) Render(w http.ResponseWriter, r *http.Request) error { return nil }