diff options
Diffstat (limited to 'internal/server')
-rw-r--r-- | internal/server/filemap.go | 77 | ||||
-rw-r--r-- | internal/server/logging.go | 55 | ||||
-rw-r--r-- | internal/server/server.go | 221 |
3 files changed, 353 insertions, 0 deletions
diff --git a/internal/server/filemap.go b/internal/server/filemap.go new file mode 100644 index 0000000..466db49 --- /dev/null +++ b/internal/server/filemap.go @@ -0,0 +1,77 @@ +package server + +import ( + "fmt" + "hash/fnv" + "io" + "io/fs" + "log" + "log/slog" + "os" + "path/filepath" + "strings" + + "github.com/pkg/errors" +) + +type File struct { + filename string + etag string +} + +var files = map[string]File{} + +func hashFile(filename string) (string, error) { + f, err := os.Open(filename) + if err != nil { + return "", err + } + defer f.Close() + hash := fnv.New64a() + if _, err := io.Copy(hash, f); err != nil { + return "", err + } + return fmt.Sprintf(`W/"%x"`, hash.Sum(nil)), nil +} + +func registerFile(urlpath string, filepath string) error { + if files[urlpath] != (File{}) { + log.Printf("registerFile called with duplicate file, urlPath: %s", urlpath) + return nil + } + hash, err := hashFile(filepath) + if err != nil { + return err + } + files[urlpath] = File{ + filename: filepath, + etag: hash, + } + return nil +} + +func registerContentFiles(root string) error { + err := filepath.WalkDir(root, func(filePath string, f fs.DirEntry, err error) error { + if err != nil { + return errors.WithMessagef(err, "failed to access path %s", filePath) + } + relPath, err := filepath.Rel(root, filePath) + if err != nil { + return errors.WithMessagef(err, "failed to make path relative, path: %s", filePath) + } + urlPath, _ := strings.CutSuffix(relPath, "index.html") + if !f.IsDir() { + slog.Debug("registering file", "urlpath", "/"+urlPath) + return registerFile("/"+urlPath, filePath) + } + return nil + }) + if err != nil { + return err + } + return nil +} + +func GetFile(urlPath string) File { + return files[urlPath] +} diff --git a/internal/server/logging.go b/internal/server/logging.go new file mode 100644 index 0000000..135f06e --- /dev/null +++ b/internal/server/logging.go @@ -0,0 +1,55 @@ +package server + +import ( + "fmt" + "io" + "net/http" +) + +type loggingResponseWriter struct { + http.ResponseWriter + statusCode int +} + +func (lrw *loggingResponseWriter) WriteHeader(code int) { + lrw.statusCode = code + // avoids warning: superfluous response.WriteHeader call + if lrw.statusCode != http.StatusOK { + lrw.ResponseWriter.WriteHeader(code) + } +} + +func NewLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter { + return &loggingResponseWriter{w, http.StatusOK} +} + +type wrappedHandlerOptions struct { + defaultHostname string + logger io.Writer +} + +func wrapHandlerWithLogging(wrappedHandler http.Handler, opts wrappedHandlerOptions) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + scheme := r.Header.Get("X-Forwarded-Proto") + if scheme == "" { + scheme = "http" + } + host := r.Header.Get("Host") + if host == "" { + host = opts.defaultHostname + } + lw := NewLoggingResponseWriter(w) + wrappedHandler.ServeHTTP(lw, r) + statusCode := lw.statusCode + fmt.Fprintf( + opts.logger, + "%s %s %d %s %s %s\n", + scheme, + r.Method, + statusCode, + host, + r.URL.Path, + lw.Header().Get("Location"), + ) + }) +} diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..2e9796d --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,221 @@ +package server + +import ( + "context" + "fmt" + "io" + "log" + "log/slog" + "mime" + "net" + "net/http" + "os" + "path" + "slices" + "strings" + "time" + + cfg "website/internal/config" + + "github.com/getsentry/sentry-go" + sentryhttp "github.com/getsentry/sentry-go/http" + "github.com/pkg/errors" + "github.com/shengyanli1982/law" +) + +var config *cfg.Config + +type Config struct { + Production bool `conf:"default:false"` + InDevServer bool `conf:"default:false"` + Root string `conf:"default:website"` + ListenAddress string `conf:"default:localhost"` + Port string `conf:"default:3000,short:p"` + BaseURL cfg.URL `conf:"default:http://localhost:3000,short:b"` +} + +type HTTPError struct { + Error error + Message string + Code int +} + +type Server struct { + *http.Server +} + +func canonicalisePath(path string) (cPath string, differs bool) { + cPath = path + if strings.HasSuffix(path, "/index.html") { + cPath, differs = strings.CutSuffix(path, "index.html") + } else if !strings.HasSuffix(path, "/") && files[path+"/"] != (File{}) { + cPath, differs = path+"/", true + } + return cPath, differs +} + +func serveFile(w http.ResponseWriter, r *http.Request) *HTTPError { + urlPath, shouldRedirect := canonicalisePath(r.URL.Path) + if shouldRedirect { + http.Redirect(w, r, urlPath, 302) + return nil + } + file := GetFile(urlPath) + if file == (File{}) { + return &HTTPError{ + Message: "File not found", + Code: http.StatusNotFound, + } + } + w.Header().Add("ETag", file.etag) + w.Header().Add("Vary", "Accept-Encoding") + w.Header().Add("Content-Security-Policy", config.CSP.String()) + for k, v := range config.Extra.Headers { + w.Header().Add(k, v) + } + + http.ServeFile(w, r, files[urlPath].filename) + return nil +} + +type webHandler func(http.ResponseWriter, *http.Request) *HTTPError + +func (fn webHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer func() { + if fail := recover(); fail != nil { + w.WriteHeader(http.StatusInternalServerError) + slog.Error("runtime panic!", "error", fail) + } + }() + w.Header().Set("Server", fmt.Sprintf("website (%s)", ShortSHA)) + if err := fn(w, r); err != nil { + if strings.Contains(r.Header.Get("Accept"), "text/html") { + w.WriteHeader(err.Code) + notFoundPage := "website/private/404.html" + http.ServeFile(w, r, notFoundPage) + } else { + http.Error(w, err.Message, err.Code) + } + } +} + +var newMIMEs = map[string]string{ + ".xsl": "text/xsl", +} + +func fixupMIMETypes() { + for ext, newType := range newMIMEs { + if err := mime.AddExtensionType(ext, newType); err != nil { + slog.Error("could not update mime type", "ext", ext, "mime", newType) + } + } +} + +func applyDevModeOverrides(config *cfg.Config) { + config.CSP.ScriptSrc = slices.Insert(config.CSP.ScriptSrc, 0, "'unsafe-inline'") + config.CSP.ConnectSrc = slices.Insert(config.CSP.ConnectSrc, 0, "'self'") +} + +func New(runtimeConfig *Config) (*Server, error) { + fixupMIMETypes() + + var err error + config, err = cfg.GetConfig() + if err != nil { + return nil, errors.WithMessage(err, "error parsing configuration file") + } + if runtimeConfig.InDevServer { + applyDevModeOverrides(config) + } + + prefix := path.Join(runtimeConfig.Root, "public") + slog.Debug("registering content files", "prefix", prefix) + err = registerContentFiles(prefix) + if err != nil { + return nil, errors.WithMessagef(err, "registering content files") + } + + env := "development" + if runtimeConfig.Production { + env = "production" + } + err = sentry.Init(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + Dsn: os.Getenv("SENTRY_DSN"), + Release: CommitSHA, + Environment: env, + }) + if err != nil { + return nil, errors.WithMessage(err, "could not set up sentry") + } + defer sentry.Flush(2 * time.Second) + sentryHandler := sentryhttp.New(sentryhttp.Options{ + Repanic: true, + }) + + top := http.NewServeMux() + mux := http.NewServeMux() + slog.Debug("binding main handler to", "host", runtimeConfig.BaseURL.Hostname()+"/") + mux.Handle(runtimeConfig.BaseURL.Hostname()+"/", webHandler(serveFile)) + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + newURL := runtimeConfig.BaseURL.String() + r.URL.String() + http.Redirect(w, r, newURL, 301) + }) + + var logWriter io.Writer + if runtimeConfig.Production { + logWriter = law.NewWriteAsyncer(os.Stdout, nil) + } else { + logWriter = os.Stdout + } + top.Handle("/", + sentryHandler.Handle( + wrapHandlerWithLogging(mux, wrappedHandlerOptions{ + defaultHostname: runtimeConfig.BaseURL.Hostname(), + logger: logWriter, + }), + ), + ) + // no logging, no sentry + top.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + listenAddress := net.JoinHostPort(runtimeConfig.ListenAddress, runtimeConfig.Port) + return &Server{ + &http.Server{ + Addr: listenAddress, + Handler: top, + }, + }, nil +} + +func (s *Server) Start() error { + if err := s.ListenAndServe(); err != http.ErrServerClosed { + return err + } + return nil +} + +func (s *Server) Stop() chan struct{} { + slog.Debug("stop called") + + idleConnsClosed := make(chan struct{}) + + go func() { + slog.Debug("shutting down server") + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + err := s.Server.Shutdown(ctx) + slog.Debug("server shut down") + if err != nil { + // Error from closing listeners, or context timeout: + log.Printf("HTTP server Shutdown: %v", err) + } + close(idleConnsClosed) + }() + + return idleConnsClosed +} |