about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--internal/server/server.go26
-rw-r--r--internal/server/tls.go28
2 files changed, 32 insertions, 22 deletions
diff --git a/internal/server/server.go b/internal/server/server.go
index ba5effe..b174c0c 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -46,10 +46,9 @@ type Config struct {
 
 type Server struct {
 	*http.Server
-	redirectServer *http.Server
-	runtimeConfig  *Config
-	config         *cfg.Config
-	log            *log.Logger
+	runtimeConfig *Config
+	config        *cfg.Config
+	log           *log.Logger
 }
 
 func applyDevModeOverrides(config *cfg.Config, runtimeConfig *Config) {
@@ -184,15 +183,13 @@ func New(runtimeConfig *Config, log *log.Logger) (*Server, error) {
 		return nil, errors.Wrap(err, "could not create website mux")
 	}
 
-	rMux := http.NewServeMux()
-	rMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
-		path, _ := website.CanonicalisePath(r.URL.Path)
-		newURL := config.BaseURL.JoinPath(path)
-		http.Redirect(w, r, newURL.String(), 301)
-	})
 	if runtimeConfig.Redirect {
 		loggingMux.Handle(config.BaseURL.Hostname()+"/", mux)
-		loggingMux.Handle("/", rMux)
+		loggingMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+			path, _ := website.CanonicalisePath(r.URL.Path)
+			newURL := config.BaseURL.JoinPath(path)
+			http.Redirect(w, r, newURL.String(), http.StatusMovedPermanently)
+		})
 	} else {
 		loggingMux.Handle("/", mux)
 	}
@@ -215,13 +212,6 @@ func New(runtimeConfig *Config, log *log.Logger) (*Server, error) {
 			IdleTimeout:       10 * time.Minute,
 			Handler:           top,
 		},
-		redirectServer: &http.Server{
-			ReadHeaderTimeout: 10 * time.Second,
-			ReadTimeout:       1 * time.Minute,
-			WriteTimeout:      2 * time.Minute,
-			IdleTimeout:       10 * time.Minute,
-			Handler:           rMux,
-		},
 		log:           log,
 		config:        config,
 		runtimeConfig: runtimeConfig,
diff --git a/internal/server/tls.go b/internal/server/tls.go
index 6b64a79..ebf76a2 100644
--- a/internal/server/tls.go
+++ b/internal/server/tls.go
@@ -97,12 +97,32 @@ func (s *Server) serveTLS() (err error) {
 		return errors.Wrap(err, "could not bind plain socket")
 	}
 
-	go func(ln net.Listener) {
-		s.redirectServer.Handler = issuer.HTTPChallengeHandler(s.redirectServer.Handler)
-		if err := s.redirectServer.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
+	go func(ln net.Listener, srv *http.Server) {
+		httpMux := http.NewServeMux()
+		httpMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+			if certmagic.LooksLikeHTTPChallenge(r) && issuer.HandleHTTPChallenge(w, r) {
+				return
+			}
+			url := r.URL
+			url.Scheme = "https"
+			host, _, err := net.SplitHostPort(r.Host)
+			if err != nil {
+				log.Warn("error splitting host and port", "error", err)
+				host = s.config.BaseURL.Hostname()
+			}
+			url.Host = net.JoinHostPort(host, s.config.BaseURL.Port())
+			http.Redirect(w, r, url.String(), http.StatusMovedPermanently)
+		})
+
+		if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
 			log.Error("error in http handler", "error", err)
 		}
-	}(ln)
+	}(ln, &http.Server{
+		ReadHeaderTimeout: s.ReadHeaderTimeout,
+		ReadTimeout:       s.ReadTimeout,
+		WriteTimeout:      s.WriteTimeout,
+		IdleTimeout:       s.IdleTimeout,
+	})
 
 	log.Debug(
 		"starting certmagic",