about summary refs log tree commit diff stats
path: root/internal/server
diff options
context:
space:
mode:
authorAlan Pearce2024-06-27 11:04:06 +0200
committerAlan Pearce2024-06-27 11:07:17 +0200
commitdb122cd2fd0c7210acafc3752dcffe926370cc28 (patch)
treeff49c00b98d95277da0a5f2e6697190c844f3e12 /internal/server
parent765a227bbf42983a9edb3eaac6e48df7a43f2808 (diff)
downloadwebsite-db122cd2fd0c7210acafc3752dcffe926370cc28.tar.lz
website-db122cd2fd0c7210acafc3752dcffe926370cc28.tar.zst
website-db122cd2fd0c7210acafc3752dcffe926370cc28.zip
avoid redirect chains (http -> https, host1 -> host2)
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/server.go39
-rw-r--r--internal/server/tls.go64
2 files changed, 80 insertions, 23 deletions
diff --git a/internal/server/server.go b/internal/server/server.go
index 0f7701a..1512632 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -47,8 +47,9 @@ type Config struct {
 
 type Server struct {
 	*http.Server
-	runtimeConfig *Config
-	config        *cfg.Config
+	redirectHandler func(http.ResponseWriter, *http.Request)
+	runtimeConfig   *Config
+	config          *cfg.Config
 }
 
 func applyDevModeOverrides(config *cfg.Config, runtimeConfig *Config) {
@@ -59,9 +60,13 @@ func applyDevModeOverrides(config *cfg.Config, runtimeConfig *Config) {
 	} else {
 		config.Domains = []string{runtimeConfig.ListenAddress}
 	}
+	scheme := "http"
+	if runtimeConfig.TLS {
+		scheme = "https"
+	}
 	config.BaseURL = cfg.URL{
 		URL: &url.URL{
-			Scheme: "http",
+			Scheme: scheme,
 			Host:   runtimeConfig.ListenAddress,
 		},
 	}
@@ -76,9 +81,22 @@ func updateCSPHashes(config *cfg.Config, r *builder.Result) {
 
 func serverHeaderHandler(wrappedHandler http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		log.Debug(
+			"headers",
+			"proto",
+			r.Header.Get("X-Forwarded-Proto"),
+			"host",
+			r.Header.Get("X-Forwarded-Host"),
+			"scheme",
+			r.URL.Scheme,
+			"secure",
+			r.TLS != nil,
+		)
+		log.Debug("host", "request", r.Host, "header", r.Header.Get("Host"))
 		if r.ProtoMajor >= 2 && r.Header.Get("Host") != "" {
 			// net/http does this for HTTP/1.1, but not h2c
 			// TODO: check with HTTP/2.0 (i.e. with TLS)
+			log.Debug("host", "request", r.Host, "header", r.Header.Get("Host"))
 			r.Host = r.Header.Get("Host")
 			r.Header.Del("Host")
 		}
@@ -174,12 +192,14 @@ func New(runtimeConfig *Config) (*Server, error) {
 		return nil, errors.Wrap(err, "could not create website mux")
 	}
 
+	redirectHandler := func(w http.ResponseWriter, r *http.Request) {
+		path, _ := website.CanonicalisePath(r.URL.Path)
+		newURL := config.BaseURL.JoinPath(path)
+		http.Redirect(w, r, newURL.String(), 301)
+	}
 	if runtimeConfig.Redirect {
 		loggingMux.Handle(config.BaseURL.Hostname()+"/", mux)
-		loggingMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
-			newURL := config.BaseURL.JoinPath(r.URL.String())
-			http.Redirect(w, r, newURL.String(), 301)
-		})
+		loggingMux.HandleFunc("/", redirectHandler)
 	} else {
 		loggingMux.Handle("/", mux)
 	}
@@ -205,8 +225,9 @@ func New(runtimeConfig *Config) (*Server, error) {
 				},
 			), 0),
 		},
-		config:        config,
-		runtimeConfig: runtimeConfig,
+		redirectHandler: redirectHandler,
+		config:          config,
+		runtimeConfig:   runtimeConfig,
 	}, nil
 }
 
diff --git a/internal/server/tls.go b/internal/server/tls.go
index f6bc320..ce2e69d 100644
--- a/internal/server/tls.go
+++ b/internal/server/tls.go
@@ -2,7 +2,11 @@ package server
 
 import (
 	"context"
+	"crypto/tls"
 	"crypto/x509"
+	"net"
+	"net/http"
+	"strconv"
 	"website/internal/log"
 
 	"github.com/ardanlabs/conf/v3"
@@ -21,6 +25,12 @@ type redisConfig struct {
 }
 
 func (s *Server) serveTLS() (err error) {
+	cfg := certmagic.NewDefault()
+	cfg.DefaultServerName = s.config.Domains[0]
+
+	certmagic.DefaultACME.Agreed = true
+	certmagic.DefaultACME.Email = s.config.Email
+
 	if s.runtimeConfig.Development {
 		ca := s.runtimeConfig.ACMECA
 		if ca == "" {
@@ -33,20 +43,20 @@ func (s *Server) serveTLS() (err error) {
 			cp = x509.NewCertPool()
 		}
 
-		cacert := s.runtimeConfig.ACMECACert
-		if cacert != "" {
+		if cacert := s.runtimeConfig.ACMECACert; cacert != "" {
 			cp.AppendCertsFromPEM([]byte(cacert))
 		}
 
-		cfg := certmagic.NewDefault()
-		issuer := certmagic.NewACMEIssuer(cfg, certmagic.ACMEIssuer{
+		// caddy's ACME server (step-ca) doesn't specify an OCSP server
+		cfg.OCSP.DisableStapling = true
+
+		cfg.Issuers[0] = certmagic.NewACMEIssuer(cfg, certmagic.ACMEIssuer{
 			CA:                      s.runtimeConfig.ACMECA,
 			TrustedRoots:            cp,
 			DisableTLSALPNChallenge: true,
 			AltHTTPPort:             s.runtimeConfig.Port,
+			AltTLSALPNPort:          s.runtimeConfig.TLSPort,
 		})
-
-		certmagic.DefaultACME = *issuer
 	} else {
 		rc := &redisConfig{}
 		_, err = conf.Parse("REDIS", rc)
@@ -61,7 +71,7 @@ func (s *Server) serveTLS() (err error) {
 		rs.EncryptionKey = rc.EncryptionKey
 		rs.KeyPrefix = rc.KeyPrefix
 
-		certmagic.Default.Storage = rs
+		cfg.Storage = rs
 		err = rs.Provision(caddy.Context{
 			Context: context.Background(),
 		})
@@ -70,12 +80,6 @@ func (s *Server) serveTLS() (err error) {
 		}
 	}
 
-	certmagic.DefaultACME.Agreed = true
-	certmagic.DefaultACME.Email = s.config.Email
-	certmagic.Default.DefaultServerName = s.config.Domains[0]
-	certmagic.HTTPPort = s.runtimeConfig.Port
-	certmagic.HTTPSPort = s.runtimeConfig.TLSPort
-
 	log.Debug(
 		"starting certmagic",
 		"http_port",
@@ -83,6 +87,38 @@ func (s *Server) serveTLS() (err error) {
 		"https_port",
 		certmagic.HTTPSPort,
 	)
+	err = cfg.ManageSync(context.TODO(), s.config.Domains)
+	if err != nil {
+		return errors.Wrap(err, "could not enable TLS")
+	}
+	tlsConfig := cfg.TLSConfig()
+	tlsConfig.NextProtos = append([]string{"h2", "http/1.1"}, tlsConfig.NextProtos...)
+
+	sln, err := tls.Listen(
+		"tcp",
+		net.JoinHostPort(s.runtimeConfig.ListenAddress, strconv.Itoa(s.runtimeConfig.TLSPort)),
+		tlsConfig,
+	)
+	if err != nil {
+		return errors.Wrap(err, "could not bind tls socket")
+	}
+
+	ln, err := net.Listen(
+		"tcp",
+		net.JoinHostPort(s.runtimeConfig.ListenAddress, strconv.Itoa(s.runtimeConfig.Port)),
+	)
+	if err != nil {
+		return errors.Wrap(err, "could not bind plain socket")
+	}
+
+	go func(ln net.Listener) {
+		redirecter := http.NewServeMux()
+		redirecter.HandleFunc("/", s.redirectHandler)
+		err := http.Serve(ln, redirecter)
+		if err != nil && !errors.Is(err, http.ErrServerClosed) {
+			log.Error("error in http handler", "error", err)
+		}
+	}(ln)
 
-	return certmagic.HTTPS(s.config.Domains, s.Server.Handler)
+	return s.Serve(sln)
 }