about summary refs log tree commit diff stats
path: root/internal/server/tls.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server/tls.go')
-rw-r--r--internal/server/tls.go83
1 files changed, 58 insertions, 25 deletions
diff --git a/internal/server/tls.go b/internal/server/tls.go
index f7f8c07..9481b6a 100644
--- a/internal/server/tls.go
+++ b/internal/server/tls.go
@@ -2,18 +2,18 @@ package server
 
 import (
 	"context"
-	"crypto/tls"
 	"crypto/x509"
 	"net"
 	"net/http"
 	"strconv"
-	"website/internal/log"
+
+	"go.alanpearce.eu/x/listenfd"
 
 	"github.com/ardanlabs/conf/v3"
 	"github.com/caddyserver/caddy/v2"
 	"github.com/caddyserver/certmagic"
 	certmagic_redis "github.com/pberkel/caddy-storage-redis"
-	"github.com/pkg/errors"
+	"gitlab.com/tozd/go/errors"
 )
 
 type redisConfig struct {
@@ -25,11 +25,17 @@ type redisConfig struct {
 }
 
 func (s *Server) serveTLS() (err error) {
+	log := s.log.Named("tls")
+
+	// setting cfg.Logger is too late somehow
+	certmagic.Default.Logger = log.GetLogger().Named("certmagic")
 	cfg := certmagic.NewDefault()
 	cfg.DefaultServerName = s.config.Domains[0]
 
+	issuer := &certmagic.DefaultACME
 	certmagic.DefaultACME.Agreed = true
 	certmagic.DefaultACME.Email = s.config.Email
+	certmagic.DefaultACME.Logger = certmagic.Default.Logger
 
 	if s.runtimeConfig.Development {
 		ca := s.runtimeConfig.ACMECA
@@ -50,7 +56,7 @@ func (s *Server) serveTLS() (err error) {
 		// caddy's ACME server (step-ca) doesn't specify an OCSP server
 		cfg.OCSP.DisableStapling = true
 
-		cfg.Issuers[0] = certmagic.NewACMEIssuer(cfg, certmagic.ACMEIssuer{
+		issuer = certmagic.NewACMEIssuer(cfg, certmagic.ACMEIssuer{
 			CA:                      s.runtimeConfig.ACMECA,
 			TrustedRoots:            cp,
 			DisableTLSALPNChallenge: true,
@@ -58,6 +64,7 @@ func (s *Server) serveTLS() (err error) {
 			AltHTTPPort:             s.runtimeConfig.Port,
 			AltTLSALPNPort:          s.runtimeConfig.TLSPort,
 		})
+		cfg.Issuers[0] = issuer
 	} else {
 		rc := &redisConfig{}
 		_, err = conf.Parse("REDIS", rc)
@@ -81,12 +88,54 @@ func (s *Server) serveTLS() (err error) {
 		}
 	}
 
+	ln, err := listenfd.GetListener(
+		1,
+		net.JoinHostPort(s.runtimeConfig.ListenAddress, strconv.Itoa(s.runtimeConfig.Port)),
+		log.Named("listenfd"),
+	)
+	if err != nil {
+		return errors.Wrap(err, "could not bind plain socket")
+	}
+
+	go func(ln net.Listener, srv *http.Server) {
+		httpMux := http.NewServeMux()
+		httpMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+			if certmagic.LooksLikeHTTPChallenge(r) && issuer.HandleHTTPChallenge(w, r) {
+				return
+			}
+			url := r.URL
+			url.Scheme = "https"
+			port := s.config.BaseURL.Port()
+			if port == "" {
+				url.Host = r.Host
+			} else {
+				host, _, err := net.SplitHostPort(r.Host)
+				if err != nil {
+					log.Warn("error splitting host and port", "error", err)
+					host = r.Host
+				}
+				url.Host = net.JoinHostPort(host, s.config.BaseURL.Port())
+			}
+			http.Redirect(w, r, url.String(), http.StatusMovedPermanently)
+		})
+		srv.Handler = httpMux
+
+		if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
+			log.Error("error in http handler", "error", err)
+		}
+	}(ln, &http.Server{
+		ReadHeaderTimeout: s.ReadHeaderTimeout,
+		ReadTimeout:       s.ReadTimeout,
+		WriteTimeout:      s.WriteTimeout,
+		IdleTimeout:       s.IdleTimeout,
+	})
+
 	log.Debug(
 		"starting certmagic",
 		"http_port",
-		certmagic.HTTPPort,
+		s.runtimeConfig.Port,
 		"https_port",
-		certmagic.HTTPSPort,
+		s.runtimeConfig.TLSPort,
 	)
 	err = cfg.ManageSync(context.TODO(), s.config.Domains)
 	if err != nil {
@@ -95,31 +144,15 @@ func (s *Server) serveTLS() (err error) {
 	tlsConfig := cfg.TLSConfig()
 	tlsConfig.NextProtos = append([]string{"h2", "http/1.1"}, tlsConfig.NextProtos...)
 
-	sln, err := tls.Listen(
-		"tcp",
+	sln, err := listenfd.GetListenerTLS(
+		0,
 		net.JoinHostPort(s.runtimeConfig.ListenAddress, strconv.Itoa(s.runtimeConfig.TLSPort)),
 		tlsConfig,
+		log.Named("listenfd"),
 	)
 	if err != nil {
 		return errors.Wrap(err, "could not bind tls socket")
 	}
 
-	ln, err := net.Listen(
-		"tcp",
-		net.JoinHostPort(s.runtimeConfig.ListenAddress, strconv.Itoa(s.runtimeConfig.Port)),
-	)
-	if err != nil {
-		return errors.Wrap(err, "could not bind plain socket")
-	}
-
-	go func(ln net.Listener) {
-		redirecter := http.NewServeMux()
-		redirecter.HandleFunc("/", s.redirectHandler)
-		err := http.Serve(ln, redirecter)
-		if err != nil && !errors.Is(err, http.ErrServerClosed) {
-			log.Error("error in http handler", "error", err)
-		}
-	}(ln)
-
 	return s.Serve(sln)
 }