about summary refs log tree commit diff stats
path: root/internal/server/tls.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server/tls.go')
-rw-r--r--internal/server/tls.go46
1 files changed, 28 insertions, 18 deletions
diff --git a/internal/server/tls.go b/internal/server/tls.go
index 848d97c..fa9e69a 100644
--- a/internal/server/tls.go
+++ b/internal/server/tls.go
@@ -25,12 +25,38 @@ type redisConfig struct {
 }
 
 func (s *Server) serveTLS() (err error) {
+	var issuer *certmagic.ACMEIssuer
+
 	cfg := certmagic.NewDefault()
 	cfg.DefaultServerName = s.config.Domains[0]
 
+	issuer = &certmagic.DefaultACME
 	certmagic.DefaultACME.Agreed = true
 	certmagic.DefaultACME.Email = s.config.Email
 
+	ln, err := listenfd.GetListener(
+		1,
+		net.JoinHostPort(s.runtimeConfig.ListenAddress, strconv.Itoa(s.runtimeConfig.Port)),
+	)
+	if err != nil {
+		return errors.Wrap(err, "could not bind plain socket")
+	}
+
+	go func(ln net.Listener) {
+		redirecter := http.NewServeMux()
+		redirecter.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+			if certmagic.LooksLikeHTTPChallenge(r) {
+				issuer.HandleHTTPChallenge(w, r)
+			} else {
+				s.redirectHandler(w, r)
+			}
+		})
+		err := http.Serve(ln, redirecter)
+		if err != nil && !errors.Is(err, http.ErrServerClosed) {
+			log.Error("error in http handler", "error", err)
+		}
+	}(ln)
+
 	if s.runtimeConfig.Development {
 		ca := s.runtimeConfig.ACMECA
 		if ca == "" {
@@ -55,7 +81,7 @@ func (s *Server) serveTLS() (err error) {
 			listenAddress = listenAddress[1 : len(listenAddress)-1]
 		}
 
-		cfg.Issuers[0] = certmagic.NewACMEIssuer(cfg, certmagic.ACMEIssuer{
+		issuer = certmagic.NewACMEIssuer(cfg, certmagic.ACMEIssuer{
 			CA:                      s.runtimeConfig.ACMECA,
 			TrustedRoots:            cp,
 			DisableTLSALPNChallenge: true,
@@ -63,6 +89,7 @@ func (s *Server) serveTLS() (err error) {
 			AltHTTPPort:             s.runtimeConfig.Port,
 			AltTLSALPNPort:          s.runtimeConfig.TLSPort,
 		})
+		cfg.Issuers[0] = issuer
 	} else {
 		rc := &redisConfig{}
 		_, err = conf.Parse("REDIS", rc)
@@ -109,22 +136,5 @@ func (s *Server) serveTLS() (err error) {
 		return errors.Wrap(err, "could not bind tls socket")
 	}
 
-	ln, err := listenfd.GetListener(
-		1,
-		net.JoinHostPort(s.runtimeConfig.ListenAddress, strconv.Itoa(s.runtimeConfig.Port)),
-	)
-	if err != nil {
-		return errors.Wrap(err, "could not bind plain socket")
-	}
-
-	go func(ln net.Listener) {
-		redirecter := http.NewServeMux()
-		redirecter.HandleFunc("/", s.redirectHandler)
-		err := http.Serve(ln, redirecter)
-		if err != nil && !errors.Is(err, http.ErrServerClosed) {
-			log.Error("error in http handler", "error", err)
-		}
-	}(ln)
-
 	return s.Serve(sln)
 }