From b0bcb3460b45f1909586c2434aa1d58f01b88298 Mon Sep 17 00:00:00 2001 From: William P Date: Sun, 24 May 2026 16:48:52 +0000 Subject: [PATCH] server: implement local file management --- server/api/api.go | 10 +++ server/api/db.go | 27 +++++++ server/api/file.go | 156 +++++++++++++++++++++++++++++++++++++++++ server/api/response.go | 12 ++++ server/main.go | 2 +- 5 files changed, 206 insertions(+), 1 deletion(-) create mode 100644 server/api/file.go diff --git a/server/api/api.go b/server/api/api.go index bd050fb..3b4e926 100644 --- a/server/api/api.go +++ b/server/api/api.go @@ -13,6 +13,8 @@ func Start() { db.InitPostgres(ctx) defer db.ClosePostgres() + Store = initFileStore() + r := chi.NewRouter() r.Get("/", func(w http.ResponseWriter, r *http.Request) { @@ -45,6 +47,14 @@ func Start() { }) }) + r.Route("/files", func(r chi.Router) { + r.Use(SessionAuthMiddleware) + + r.Route("/{fileID}", func(r chi.Router) { + r.Get("/", ServeFile) + }) + }) + r.Route("/login", func(r chi.Router) { r.Post("/", Login) }) diff --git a/server/api/db.go b/server/api/db.go index a040bbb..ae7dc05 100644 --- a/server/api/db.go +++ b/server/api/db.go @@ -13,6 +13,7 @@ import ( var ErrUserNotFound = errors.New("db: user not found") var ErrSessionNotFound = errors.New("db: session not found") var ErrChannelNotFound = errors.New("db: channel not found") +var ErrFileNotFound = errors.New("db: file not found") func dbGetUser(id string) (*User, error) { query := `SELECT id, name, password FROM users WHERE id = $1` @@ -206,3 +207,29 @@ func dbDeleteChannel(id string) error { slog.Debug("db: channel deleted") return nil } + +func dbAddFile(file *File) error { + query := `INSERT INTO files (id, name, created, backend, path) VALUES ($1, $2, $3, $4, $5)` + _, err := db.Pool.Exec(context.Background(), query, file.ID, file.Name, file.Created, file.Backend, file.Path) + if err != nil { + slog.Error("db: failed to add file", "error", err, "fileid", file.ID) + return fmt.Errorf("failed to add file") + } + slog.Debug("db: file added", "fileid", file.ID, "filename", file.Name) + return nil +} + +func dbGetFile(id string) (*File, error) { + query := `SELECT id, name, created, backend, path FROM files WHERE id = $1` + var file File + err := db.Pool.QueryRow(context.Background(), query, id).Scan(&file.ID, &file.Name, &file.Created, &file.Backend, &file.Path) + if errors.Is(err, pgx.ErrNoRows) { + slog.Debug("db: file not found", "fileid", id) + return nil, ErrFileNotFound + } else if err != nil { + slog.Error("db: failed to query file", "error", err) + return nil, fmt.Errorf("failed to query file") + } + slog.Debug("db: file found", "fileid", file.ID, "filename", file.Name) + return &file, nil +} diff --git a/server/api/file.go b/server/api/file.go new file mode 100644 index 0000000..97b18fb --- /dev/null +++ b/server/api/file.go @@ -0,0 +1,156 @@ +package api + +import ( + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "os" + "path/filepath" + "time" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/render" + "github.com/google/uuid" +) + +func initFileStore() FileStore { + val, ok := os.LookupEnv("FILE_BACKEND") + if !ok { + slog.Error("FILE_BACKEND environment variable not set") + os.Exit(1) + } + switch FileBackend(val) { + case FileBackendLocal: + localFilePath, ok := os.LookupEnv("LOCAL_FILEPATH") + if !ok { + slog.Error("LOCAL_FILEPATH environment variable not set") + os.Exit(1) + } + return &LocalFileStore{BaseDir: localFilePath} + } + slog.Error("unsupported FILE_BACKEND", "value", val) + os.Exit(1) + return nil +} + +type File struct { + ID uuid.UUID + Name string + Created time.Time + Backend FileBackend + Path string +} + +type FileBackend string + +const ( + FileBackendLocal FileBackend = "local" + FileBackendS3 FileBackend = "s3" +) + +var Store FileStore + +type FileStore interface { + Save(name string, r io.Reader) (*File, error) + URL(file *File) (string, error) +} + +type LocalFileStore struct { + BaseDir string +} + +func (s *LocalFileStore) Save(name string, r io.Reader) (*File, error) { + id := uuid.New() + path := filepath.Join(s.BaseDir, id.String()) + + f, err := os.Create(path) + if err != nil { + return nil, fmt.Errorf("file(local): failed to create file: %w", err) + } + defer f.Close() + + if _, err := io.Copy(f, r); err != nil { + os.Remove(path) + return nil, fmt.Errorf("file(local): failed to write file: %w", err) + } + + return &File{ + ID: id, + Name: name, + Created: time.Now(), + Backend: FileBackendLocal, + Path: path, + }, nil +} + +func (s *LocalFileStore) URL(file *File) (string, error) { + return "/files/" + file.ID.String(), nil +} + +func ServeFile(w http.ResponseWriter, r *http.Request) { + slog.Debug("file: entering ServeFile handler") + + fileID := chi.URLParam(r, "fileID") + parsed, err := uuid.Parse(fileID) + if err != nil { + render.Render(w, r, ErrInvalidRequest(err)) + return + } + + file, err := dbGetFile(parsed.String()) + if err != nil { + if errors.Is(err, ErrFileNotFound) { + render.Render(w, r, ErrNotFound) + } else { + slog.Error("file: failed to fetch file", "fileid", parsed.String(), "error", err) + render.Render(w, r, ErrInternal(err)) + } + return + } + + f, err := os.Open(file.Path) + if err != nil { + slog.Error("file: failed to open file", "fileid", file.ID, "error", err) + render.Render(w, r, ErrInternal(err)) + return + } + defer f.Close() + + http.ServeContent(w, r, file.Name, file.Created, f) +} + +// UploadFile is a temporary handler for testing file uploads. +/* +func UploadFile(w http.ResponseWriter, r *http.Request) { + slog.Debug("file: entering UploadFile handler") + + if err := r.ParseMultipartForm(32 << 20); err != nil { + render.Render(w, r, ErrInvalidRequest(err)) + return + } + + f, header, err := r.FormFile("file") + if err != nil { + render.Render(w, r, ErrInvalidRequest(err)) + return + } + defer f.Close() + + file, err := Store.Save(header.Filename, f) + if err != nil { + slog.Error("file: failed to save file", "error", err) + render.Render(w, r, ErrInternal(err)) + return + } + + if err := dbAddFile(file); err != nil { + render.Render(w, r, ErrInternal(err)) + return + } + + slog.Debug("file: uploaded file", "fileid", file.ID, "filename", file.Name) + render.Render(w, r, NewFilePayloadResponse(file)) +} +*/ diff --git a/server/api/response.go b/server/api/response.go index 59890d6..bb7c8de 100644 --- a/server/api/response.go +++ b/server/api/response.go @@ -37,3 +37,15 @@ func NewChannelListResponse(channels []*Channel) []render.Renderer { func (c *ChannelPayload) Render(w http.ResponseWriter, r *http.Request) error { return nil } + +type FilePayload struct { + *File +} + +func NewFilePayloadResponse(file *File) *FilePayload { + return &FilePayload{File: file} +} + +func (f *FilePayload) Render(w http.ResponseWriter, r *http.Request) error { + return nil +} diff --git a/server/main.go b/server/main.go index 8383cac..cca09a4 100644 --- a/server/main.go +++ b/server/main.go @@ -9,7 +9,7 @@ import ( ) var REQUIRED_ENVS = [...]string{ - "DATABASE_URL", "JWT_SECRET", + "DATABASE_URL", "JWT_SECRET", "FILE_BACKEND", } func checkEnvVars(keys []string) (bool, []string) {