about summary refs log tree commit diff stats
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/config/config.go3
-rw-r--r--internal/listenfd/listenfd.go67
-rw-r--r--internal/server/server.go80
-rw-r--r--internal/server/tcp.go2
-rw-r--r--internal/server/tls.go75
-rw-r--r--internal/vcs/repository.go34
6 files changed, 136 insertions, 125 deletions
diff --git a/internal/config/config.go b/internal/config/config.go
index 47d5de8..7ccad85 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -42,7 +42,8 @@ type Config struct {
 	OriginalDomain   string `toml:"original_domain"`
 	GoatCounter      URL    `toml:"goatcounter"`
 	Domains          []string
-	OIDCHost         URL `toml:"oidc_host"`
+	WildcardDomain   string `toml:"wildcard_domain"`
+	OIDCHost         URL    `toml:"oidc_host"`
 	Taxonomies       []Taxonomy
 	CSP              *CSP `toml:"content-security-policy"`
 	Extra            struct {
diff --git a/internal/listenfd/listenfd.go b/internal/listenfd/listenfd.go
deleted file mode 100644
index 5287898..0000000
--- a/internal/listenfd/listenfd.go
+++ /dev/null
@@ -1,67 +0,0 @@
-package listenfd
-
-import (
-	"crypto/tls"
-	"net"
-	"os"
-	"strconv"
-
-	"go.alanpearce.eu/x/log"
-
-	"gitlab.com/tozd/go/errors"
-)
-
-const fdStart = 3
-
-func GetListener(i uint64, addr string, log *log.Logger) (l net.Listener, err error) {
-	l, err = getFDSocket(i)
-	if err != nil {
-		log.Warn("could not create listener from listenfd", "error", err)
-	}
-
-	log.Debug("listener from listenfd?", "passed", l != nil)
-	if l == nil {
-		l, err = net.Listen("tcp", addr)
-		if err != nil {
-			return nil, errors.Wrap(err, "could not create listener")
-		}
-	}
-
-	return
-}
-
-func GetListenerTLS(
-	i uint64,
-	addr string,
-	config *tls.Config,
-	log *log.Logger,
-) (l net.Listener, err error) {
-	l, err = GetListener(i, addr, log)
-	if err != nil {
-		return nil, err
-	}
-
-	return tls.NewListener(l, config), nil
-}
-
-func getFDSocket(i uint64) (net.Listener, error) {
-	lfds, present := os.LookupEnv("LISTEN_FDS")
-	if !present {
-		return nil, nil
-	}
-
-	fds, err := strconv.ParseUint(lfds, 10, 32)
-	if err != nil {
-		return nil, errors.Wrap(err, "could not parse LISTEN_FDS")
-	}
-	if i >= fds {
-		return nil, errors.Errorf("only %d fds available, requested index %d", fds, i)
-	}
-
-	l, err := net.FileListener(os.NewFile(uintptr(i+fdStart), ""))
-	if err != nil {
-		return nil, errors.Wrap(err, "could not create listener")
-	}
-
-	return l, nil
-}
diff --git a/internal/server/server.go b/internal/server/server.go
index 8523bc9..269ed9e 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -8,6 +8,7 @@ import (
 	"net/url"
 	"os"
 	"path/filepath"
+	"regexp"
 	"slices"
 	"strconv"
 	"strings"
@@ -46,10 +47,9 @@ type Config struct {
 
 type Server struct {
 	*http.Server
-	redirectServer *http.Server
-	runtimeConfig  *Config
-	config         *cfg.Config
-	log            *log.Logger
+	runtimeConfig *Config
+	config        *cfg.Config
+	log           *log.Logger
 }
 
 func applyDevModeOverrides(config *cfg.Config, runtimeConfig *Config) {
@@ -75,7 +75,6 @@ func applyDevModeOverrides(config *cfg.Config, runtimeConfig *Config) {
 }
 
 func updateCSPHashes(config *cfg.Config, r *builder.Result) {
-	clear(config.CSP.StyleSrc)
 	for i, h := range r.Hashes {
 		config.CSP.StyleSrc[i] = fmt.Sprintf("'%s'", h)
 	}
@@ -110,20 +109,24 @@ func New(runtimeConfig *Config, log *log.Logger) (*Server, error) {
 		if err != nil {
 			return nil, err
 		}
-		_, err = vcs.CloneOrUpdate(vcsConfig, log.Named("vcs"))
-		if err != nil {
-			return nil, err
-		}
-		err = os.Chdir(runtimeConfig.Root)
-		if err != nil {
-			return nil, err
-		}
+		if vcsConfig.LocalPath != "" {
+			_, err = vcs.CloneOrUpdate(vcsConfig, log.Named("vcs"))
+			if err != nil {
+				return nil, err
+			}
+			err = os.Chdir(runtimeConfig.Root)
+			if err != nil {
+				return nil, err
+			}
 
-		builderConfig.Source = vcsConfig.LocalPath
+			builderConfig.Source = vcsConfig.LocalPath
 
-		publicDir := filepath.Join(runtimeConfig.Root, "public")
-		builderConfig.Destination = publicDir
-		runtimeConfig.Root = publicDir
+			publicDir := filepath.Join(runtimeConfig.Root, "public")
+			builderConfig.Destination = publicDir
+			runtimeConfig.Root = publicDir
+		} else {
+			log.Warn("in production mode without VCS configuration")
+		}
 	}
 
 	config, err := cfg.GetConfig(builderConfig.Source, log.Named("config"))
@@ -180,24 +183,36 @@ func New(runtimeConfig *Config, log *log.Logger) (*Server, error) {
 		return nil, errors.Wrap(err, "could not create website mux")
 	}
 
-	rMux := http.NewServeMux()
-	rMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
-		path, _ := website.CanonicalisePath(r.URL.Path)
-		newURL := config.BaseURL.JoinPath(path)
-		http.Redirect(w, r, newURL.String(), 301)
-	})
 	if runtimeConfig.Redirect {
+		re := regexp.MustCompile(
+			"^(.*)\\." + strings.ReplaceAll(config.WildcardDomain, ".", `\.`) + "$",
+		)
+		replace := "${1}." + config.Domains[0]
 		loggingMux.Handle(config.BaseURL.Hostname()+"/", mux)
-		loggingMux.Handle("/", rMux)
+		loggingMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+			if slices.Contains(config.Domains, r.Host) {
+				path, _ := website.CanonicalisePath(r.URL.Path)
+				newURL := config.BaseURL.JoinPath(path)
+				http.Redirect(w, r, newURL.String(), http.StatusMovedPermanently)
+			} else {
+				url := config.BaseURL
+				url.Host = re.ReplaceAllString(r.Host, replace)
+				http.Redirect(w, r, url.String(), http.StatusTemporaryRedirect)
+			}
+		})
 	} else {
 		loggingMux.Handle("/", mux)
 	}
 
-	top.Handle("/",
-		serverHeaderHandler(
-			wrapHandlerWithLogging(loggingMux, log),
-		),
-	)
+	if runtimeConfig.Development {
+		top.Handle("/",
+			serverHeaderHandler(
+				wrapHandlerWithLogging(loggingMux, log),
+			),
+		)
+	} else {
+		top.Handle("/", serverHeaderHandler(loggingMux))
+	}
 
 	top.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) {
 		w.WriteHeader(http.StatusNoContent)
@@ -211,13 +226,6 @@ func New(runtimeConfig *Config, log *log.Logger) (*Server, error) {
 			IdleTimeout:       10 * time.Minute,
 			Handler:           top,
 		},
-		redirectServer: &http.Server{
-			ReadHeaderTimeout: 10 * time.Second,
-			ReadTimeout:       1 * time.Minute,
-			WriteTimeout:      2 * time.Minute,
-			IdleTimeout:       10 * time.Minute,
-			Handler:           rMux,
-		},
 		log:           log,
 		config:        config,
 		runtimeConfig: runtimeConfig,
diff --git a/internal/server/tcp.go b/internal/server/tcp.go
index 12fdeb2..1627854 100644
--- a/internal/server/tcp.go
+++ b/internal/server/tcp.go
@@ -1,7 +1,7 @@
 package server
 
 import (
-	"go.alanpearce.eu/website/internal/listenfd"
+	"go.alanpearce.eu/x/listenfd"
 )
 
 func (s *Server) serveTCP() error {
diff --git a/internal/server/tls.go b/internal/server/tls.go
index cd2bfb8..4d52b8d 100644
--- a/internal/server/tls.go
+++ b/internal/server/tls.go
@@ -7,11 +7,12 @@ import (
 	"net/http"
 	"strconv"
 
-	"go.alanpearce.eu/website/internal/listenfd"
+	"go.alanpearce.eu/x/listenfd"
 
 	"github.com/ardanlabs/conf/v3"
 	"github.com/caddyserver/caddy/v2"
 	"github.com/caddyserver/certmagic"
+	"github.com/libdns/acmedns"
 	certmagic_redis "github.com/pberkel/caddy-storage-redis"
 	"gitlab.com/tozd/go/errors"
 )
@@ -24,8 +25,14 @@ type redisConfig struct {
 	KeyPrefix     string `conf:"default:certmagic"`
 }
 
+type acmeConfig struct {
+	Username  string `conf:"required"`
+	Password  string `conf:"required"`
+	Subdomain string `conf:"required"`
+	ServerURL string `conf:"env:SERVER_URL,default:https://acme.alanpearce.eu"`
+}
+
 func (s *Server) serveTLS() (err error) {
-	var issuer *certmagic.ACMEIssuer
 	log := s.log.Named("tls")
 
 	// setting cfg.Logger is too late somehow
@@ -33,9 +40,7 @@ func (s *Server) serveTLS() (err error) {
 	cfg := certmagic.NewDefault()
 	cfg.DefaultServerName = s.config.Domains[0]
 
-	issuer = &certmagic.DefaultACME
-	certmagic.DefaultACME.Agreed = true
-	certmagic.DefaultACME.Email = s.config.Email
+	var issuer *certmagic.ACMEIssuer
 
 	if s.runtimeConfig.Development {
 		ca := s.runtimeConfig.ACMECA
@@ -63,8 +68,8 @@ func (s *Server) serveTLS() (err error) {
 			ListenHost:              s.runtimeConfig.ListenAddress,
 			AltHTTPPort:             s.runtimeConfig.Port,
 			AltTLSALPNPort:          s.runtimeConfig.TLSPort,
+			Logger:                  certmagic.Default.Logger,
 		})
-		cfg.Issuers[0] = issuer
 	} else {
 		rc := &redisConfig{}
 		_, err = conf.Parse("REDIS", rc)
@@ -72,6 +77,27 @@ func (s *Server) serveTLS() (err error) {
 			return errors.Wrap(err, "could not parse redis config")
 		}
 
+		acme := &acmedns.Provider{}
+		_, err = conf.Parse("ACME", acme)
+		if err != nil {
+			return errors.Wrap(err, "could not parse ACME config")
+		}
+
+		issuer = certmagic.NewACMEIssuer(cfg, certmagic.ACMEIssuer{
+			CA:     certmagic.LetsEncryptProductionCA,
+			Email:  s.config.Email,
+			Agreed: true,
+			Logger: certmagic.Default.Logger,
+			DNS01Solver: &certmagic.DNS01Solver{
+				DNSManager: certmagic.DNSManager{
+					DNSProvider: acme,
+					Logger:      certmagic.Default.Logger,
+				},
+			},
+		})
+
+		log.Info("acme", "username", acme.Username, "subdomain", acme.Subdomain, "server_url", acme.ServerURL)
+
 		rs := certmagic_redis.New()
 		rs.Address = []string{rc.Address}
 		rs.Username = rc.Username
@@ -87,6 +113,7 @@ func (s *Server) serveTLS() (err error) {
 			return errors.Wrap(err, "could not provision redis storage")
 		}
 	}
+	cfg.Issuers[0] = issuer
 
 	ln, err := listenfd.GetListener(
 		1,
@@ -97,12 +124,38 @@ func (s *Server) serveTLS() (err error) {
 		return errors.Wrap(err, "could not bind plain socket")
 	}
 
-	go func(ln net.Listener) {
-		s.redirectServer.Handler = issuer.HTTPChallengeHandler(s.redirectServer.Handler)
-		if err := s.redirectServer.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
+	go func(ln net.Listener, srv *http.Server) {
+		httpMux := http.NewServeMux()
+		httpMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+			if certmagic.LooksLikeHTTPChallenge(r) && issuer.HandleHTTPChallenge(w, r) {
+				return
+			}
+			url := r.URL
+			url.Scheme = "https"
+			port := s.config.BaseURL.Port()
+			if port == "" {
+				url.Host = r.Host
+			} else {
+				host, _, err := net.SplitHostPort(r.Host)
+				if err != nil {
+					log.Warn("error splitting host and port", "error", err)
+					host = r.Host
+				}
+				url.Host = net.JoinHostPort(host, s.config.BaseURL.Port())
+			}
+			http.Redirect(w, r, url.String(), http.StatusMovedPermanently)
+		})
+		srv.Handler = httpMux
+
+		if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
 			log.Error("error in http handler", "error", err)
 		}
-	}(ln)
+	}(ln, &http.Server{
+		ReadHeaderTimeout: s.ReadHeaderTimeout,
+		ReadTimeout:       s.ReadTimeout,
+		WriteTimeout:      s.WriteTimeout,
+		IdleTimeout:       s.IdleTimeout,
+	})
 
 	log.Debug(
 		"starting certmagic",
@@ -111,7 +164,7 @@ func (s *Server) serveTLS() (err error) {
 		"https_port",
 		s.runtimeConfig.TLSPort,
 	)
-	err = cfg.ManageSync(context.TODO(), s.config.Domains)
+	err = cfg.ManageAsync(context.TODO(), s.config.Domains)
 	if err != nil {
 		return errors.Wrap(err, "could not enable TLS")
 	}
diff --git a/internal/vcs/repository.go b/internal/vcs/repository.go
index e034ea4..5950e53 100644
--- a/internal/vcs/repository.go
+++ b/internal/vcs/repository.go
@@ -7,6 +7,7 @@ import (
 	"go.alanpearce.eu/x/log"
 
 	"github.com/go-git/go-git/v5"
+	"github.com/go-git/go-git/v5/plumbing"
 	"gitlab.com/tozd/go/errors"
 )
 
@@ -61,13 +62,8 @@ func (r *Repository) Update() (bool, error) {
 	}
 
 	r.log.Info("updating from", "rev", head.Hash().String())
-	wt, err := r.repo.Worktree()
-	if err != nil {
-		return false, err
-	}
-	err = wt.Pull(&git.PullOptions{
-		SingleBranch: true,
-		Force:        true,
+	err = r.repo.Fetch(&git.FetchOptions{
+		Prune: true,
 	})
 	if err != nil {
 		if errors.Is(err, git.NoErrAlreadyUpToDate) {
@@ -79,11 +75,31 @@ func (r *Repository) Update() (bool, error) {
 		return false, err
 	}
 
-	head, err = r.repo.Head()
+	rem, err := r.repo.Remote("origin")
 	if err != nil {
 		return false, err
 	}
-	r.log.Info("updated to", "rev", head.Hash().String())
+	refs, err := rem.List(&git.ListOptions{
+		Timeout: 5,
+	})
+
+	var hash plumbing.Hash
+	for _, ref := range refs {
+		if ref.Name() == plumbing.Main {
+			hash = ref.Hash()
+		}
+	}
+
+	wt, err := r.repo.Worktree()
+	if err != nil {
+		return false, err
+	}
+	wt.Checkout(&git.CheckoutOptions{
+		Hash:  hash,
+		Force: true,
+	})
+
+	r.log.Info("updated to", "rev", hash)
 
 	return true, r.Clean(wt)
 }