diff --git a/api/api.go b/api/api.go index 19ffb0e..ed6a680 100644 --- a/api/api.go +++ b/api/api.go @@ -36,6 +36,12 @@ func Start() { panic("oh no") }) + r.Route("/whoami", func(r chi.Router) { + r.Use(SessionAuthMiddleware) + r.Use(LoginCtx) + r.Get("/", Whoami) + }) + r.Route("/messages", func(r chi.Router) { r.Use(SessionAuthMiddleware) // Protect with authentication diff --git a/api/user.go b/api/user.go index 47f57e2..63c0086 100644 --- a/api/user.go +++ b/api/user.go @@ -30,6 +30,38 @@ func UserCtx(next http.Handler) http.Handler { }) } +func Whoami(w http.ResponseWriter, r *http.Request) { + user, ok := r.Context().Value(userKey{}).(*User) + if !ok { + w.Write([]byte("undefined")) + return + } else { + w.Write([]byte(user.Name)) + return + } +} + +func LoginCtx(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var user *User + var err error + + if username := r.Context().Value(usernameKey).(string); username != "" { + user, err = dbGetUserByName(username) + } else { + render.Render(w, r, ErrNotFound) + return + } + if err != nil { + render.Render(w, r, ErrNotFound) + return + } + + ctx := context.WithValue(r.Context(), userKey{}, user) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + func GetUser(w http.ResponseWriter, r *http.Request) { user, ok := r.Context().Value(userKey{}).(*User) if !ok || user == nil {