about summary refs log tree commit diff stats
path: root/internal/server/tls.go
diff options
context:
space:
mode:
authorAlan Pearce2024-07-03 10:35:12 +0200
committerAlan Pearce2024-07-03 10:37:05 +0200
commitc4c67c0f3a07ebed224dfc9de4c93d10c47f149a (patch)
tree1d5b72522b18206d6b786d1ea6279eb47dd5bd9d /internal/server/tls.go
parent336ddaf703ec403661ee3d588512934019ff9b5c (diff)
downloadwebsite-c4c67c0f3a07ebed224dfc9de4c93d10c47f149a.tar.lz
website-c4c67c0f3a07ebed224dfc9de4c93d10c47f149a.tar.zst
website-c4c67c0f3a07ebed224dfc9de4c93d10c47f149a.zip
make HTTP->S redirects use same host only for HSTS
Diffstat (limited to 'internal/server/tls.go')
-rw-r--r--internal/server/tls.go28
1 files changed, 24 insertions, 4 deletions
diff --git a/internal/server/tls.go b/internal/server/tls.go
index 6b64a79..ebf76a2 100644
--- a/internal/server/tls.go
+++ b/internal/server/tls.go
@@ -97,12 +97,32 @@ func (s *Server) serveTLS() (err error) {
 		return errors.Wrap(err, "could not bind plain socket")
 	}
 
-	go func(ln net.Listener) {
-		s.redirectServer.Handler = issuer.HTTPChallengeHandler(s.redirectServer.Handler)
-		if err := s.redirectServer.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
+	go func(ln net.Listener, srv *http.Server) {
+		httpMux := http.NewServeMux()
+		httpMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+			if certmagic.LooksLikeHTTPChallenge(r) && issuer.HandleHTTPChallenge(w, r) {
+				return
+			}
+			url := r.URL
+			url.Scheme = "https"
+			host, _, err := net.SplitHostPort(r.Host)
+			if err != nil {
+				log.Warn("error splitting host and port", "error", err)
+				host = s.config.BaseURL.Hostname()
+			}
+			url.Host = net.JoinHostPort(host, s.config.BaseURL.Port())
+			http.Redirect(w, r, url.String(), http.StatusMovedPermanently)
+		})
+
+		if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
 			log.Error("error in http handler", "error", err)
 		}
-	}(ln)
+	}(ln, &http.Server{
+		ReadHeaderTimeout: s.ReadHeaderTimeout,
+		ReadTimeout:       s.ReadTimeout,
+		WriteTimeout:      s.WriteTimeout,
+		IdleTimeout:       s.IdleTimeout,
+	})
 
 	log.Debug(
 		"starting certmagic",