about summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorAlan Pearce2024-05-24 18:31:56 +0200
committerAlan Pearce2024-05-24 18:31:56 +0200
commite9eed3ddc4229db707cccb30beddde15044eff16 (patch)
treec586eae45a4aa99fd1a971c2bd29ad2e74d14975
parent2c1491de56d0c3e2f4cb0b0c1e33035510f72fc5 (diff)
downloadsearchix-e9eed3ddc4229db707cccb30beddde15044eff16.tar.lz
searchix-e9eed3ddc4229db707cccb30beddde15044eff16.tar.zst
searchix-e9eed3ddc4229db707cccb30beddde15044eff16.zip
refactor: split server cmd and module
It should now be possible to run the server from inside another go
application by importing the main module and running its Start() function
-rw-r--r--cmd/searchix-web/main.go82
-rw-r--r--internal/server/server.go37
-rw-r--r--justfile6
-rw-r--r--nix/modules/default.nix2
-rw-r--r--nix/package.nix3
-rw-r--r--searchix.go281
6 files changed, 249 insertions, 162 deletions
diff --git a/cmd/searchix-web/main.go b/cmd/searchix-web/main.go
new file mode 100644
index 0000000..91ecc7a
--- /dev/null
+++ b/cmd/searchix-web/main.go
@@ -0,0 +1,82 @@
+package main
+
+import (
+	"context"
+	"flag"
+	"fmt"
+	"log"
+	"log/slog"
+	"os"
+	"os/signal"
+
+	"searchix"
+	"searchix/internal/config"
+)
+
+var buildVersion string
+
+var (
+	configFile         = flag.String("config", "config.toml", "config `file` to use")
+	printDefaultConfig = flag.Bool(
+		"print-default-config",
+		false,
+		"print default configuration and exit",
+	)
+	liveReload = flag.Bool("live", false, "whether to enable live reloading (development)")
+	replace    = flag.Bool("replace", false, "replace existing index and exit")
+	update     = flag.Bool("update", false, "update index and exit")
+	version    = flag.Bool("version", false, "print version information")
+)
+
+func main() {
+	flag.Parse()
+	if *version {
+		fmt.Fprintf(os.Stderr, "searchix %s", buildVersion)
+		if buildVersion != config.CommitSHA && buildVersion != config.ShortSHA {
+			fmt.Fprintf(os.Stderr, " %s", config.CommitSHA)
+		}
+		_, err := fmt.Fprint(os.Stderr, "\n")
+		if err != nil {
+			panic("can't write to standard error?!")
+		}
+		os.Exit(0)
+	}
+	if *printDefaultConfig {
+		_, err := fmt.Print(config.GetDefaultConfig())
+		if err != nil {
+			panic("can't write to standard output?!")
+		}
+		os.Exit(0)
+	}
+
+	cfg, err := config.GetConfig(*configFile)
+	if err != nil {
+		// only use log functions after the config file has been read successfully
+		log.Fatalf("Failed to parse config file: %v", err)
+	}
+	s, err := searchix.New(cfg)
+	if err != nil {
+		log.Fatalf("Failed to initialise searchix: %v", err)
+	}
+
+	err = s.SetupIndex(*replace, *update)
+	if err != nil {
+		log.Fatalf("Failed to setup index: %v", err)
+	}
+
+	ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
+	defer cancel()
+
+	go func() {
+		err = s.Start(ctx, *liveReload)
+		if err != nil {
+			// Error starting or closing listener:
+			log.Fatalf("error: %v", err)
+		}
+	}()
+
+	<-ctx.Done()
+	slog.Debug("calling stop")
+	s.Stop()
+	slog.Debug("done")
+}
diff --git a/internal/server/server.go b/internal/server/server.go
index 6c4b732..262e9a7 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -2,7 +2,6 @@ package server
 
 import (
 	"context"
-	"log"
 	"log/slog"
 	"net"
 	"net/http"
@@ -15,7 +14,9 @@ import (
 )
 
 type Server struct {
-	*http.Server
+	cfg      *config.Config
+	server   *http.Server
+	listener net.Listener
 }
 
 func New(conf *config.Config, index *index.ReadIndex, liveReload bool) (*Server, error) {
@@ -23,11 +24,10 @@ func New(conf *config.Config, index *index.ReadIndex, liveReload bool) (*Server,
 	if err != nil {
 		return nil, err
 	}
-	listenAddress := net.JoinHostPort(conf.Web.ListenAddress, strconv.Itoa(conf.Web.Port))
 
 	return &Server{
-		&http.Server{
-			Addr:              listenAddress,
+		cfg: conf,
+		server: &http.Server{
 			Handler:           mux,
 			ReadHeaderTimeout: 20 * time.Second,
 		},
@@ -35,7 +35,27 @@ func New(conf *config.Config, index *index.ReadIndex, liveReload bool) (*Server,
 }
 
 func (s *Server) Start() error {
-	if err := s.ListenAndServe(); err != http.ErrServerClosed {
+	listenAddress := net.JoinHostPort(s.cfg.Web.ListenAddress, strconv.Itoa(s.cfg.Web.Port))
+	l, err := net.Listen("tcp", listenAddress)
+	if err != nil {
+		return errors.WithMessagef(
+			err,
+			"could not listen on address %s and port %d",
+			s.cfg.Web.ListenAddress,
+			s.cfg.Web.Port,
+		)
+	}
+	s.listener = l
+
+	if s.cfg.Web.Environment == "development" {
+		slog.Info(
+			"server listening on",
+			"url",
+			s.cfg.Web.BaseURL.String(),
+		)
+	}
+
+	if err := s.server.Serve(l); err != nil && err != http.ErrServerClosed {
 		return errors.WithMessage(err, "could not start server")
 	}
 
@@ -51,12 +71,13 @@ func (s *Server) Stop() chan struct{} {
 		slog.Debug("shutting down server")
 		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 		defer cancel()
-		err := s.Server.Shutdown(ctx)
+		err := s.server.Shutdown(ctx)
 		slog.Debug("server shut down")
 		if err != nil {
 			// Error from closing listeners, or context timeout:
-			log.Printf("HTTP server Shutdown: %v", err)
+			slog.Error("error shutting down server", "error", err)
 		}
+		s.listener.Close()
 		close(idleConnsClosed)
 	}()
 
diff --git a/justfile b/justfile
index 3e63590..f5c6682 100644
--- a/justfile
+++ b/justfile
@@ -21,10 +21,10 @@ precommit:
 	nix-build -A pre-commit-check
 
 dev:
-	watchexec --no-vcs-ignore --filter "**/*.go" --filter config.toml -r wgo run -exit ./ --live --config config.toml
+	watchexec --no-vcs-ignore --filter "**/*.go" --filter config.toml -r wgo run -exit ./cmd/searchix-web --live --config config.toml
 
 reindex:
-	wgo run --exit . --config config.toml --replace
+	wgo run --exit ./cmd/searchix-web --config config.toml --replace
 
 update:
-	wgo run --exit . --config config.toml --update
+	wgo run --exit ./cmd/searchix-web --config config.toml --update
diff --git a/nix/modules/default.nix b/nix/modules/default.nix
index 509a372..acfa7aa 100644
--- a/nix/modules/default.nix
+++ b/nix/modules/default.nix
@@ -194,7 +194,7 @@ in
       wantedBy = [ "multi-user.target" ];
       path = with pkgs; [ nix ];
       serviceConfig = defaultServiceConfig // {
-        ExecStart = "${package}/bin/searchix --config ${(settingsFormat.generate "searchix-config.toml" cfg.settings)}";
+        ExecStart = "${package}/bin/searchix-web --config ${(settingsFormat.generate "searchix-config.toml" cfg.settings)}";
       } // lib.optionalAttrs (cfg.settings.web.port < 1024) {
         AmbientCapabilities = [ "CAP_NET_BIND_SERVICE" ];
         CapabilityBoundingSet = [ "CAP_NET_BIND_SERVICE" ];
diff --git a/nix/package.nix b/nix/package.nix
index 104e09f..3426235 100644
--- a/nix/package.nix
+++ b/nix/package.nix
@@ -31,10 +31,13 @@ buildGoApplication {
         ../searchix.go
         ../internal
         ../frontend
+        ../cmd
       ])
       (maybeMissing ../frontend/static/base.css);
   };
 
+  subPackages = [ "cmd/searchix-web" ];
+
   patchPhase = ''
     rm -f frontend/static/base.css
     cp ${css} frontend/static/base.css
diff --git a/searchix.go b/searchix.go
index cf1a429..c10eb6a 100644
--- a/searchix.go
+++ b/searchix.go
@@ -1,12 +1,9 @@
-package main
+package searchix
 
 import (
-	"flag"
-	"fmt"
+	"context"
 	"log"
 	"log/slog"
-	"os"
-	"os/signal"
 	"slices"
 	"sync"
 	"time"
@@ -18,21 +15,7 @@ import (
 
 	"github.com/getsentry/sentry-go"
 	"github.com/pelletier/go-toml/v2"
-)
-
-var buildVersion string
-
-var (
-	configFile         = flag.String("config", "config.toml", "config `file` to use")
-	printDefaultConfig = flag.Bool(
-		"print-default-config",
-		false,
-		"print default configuration and exit",
-	)
-	liveReload = flag.Bool("live", false, "whether to enable live reloading (development)")
-	replace    = flag.Bool("replace", false, "replace existing index and exit")
-	update     = flag.Bool("update", false, "update index and exit")
-	version    = flag.Bool("version", false, "print version information")
+	"github.com/pkg/errors"
 )
 
 func nextOccurrenceOfLocalTime(t toml.LocalTime) time.Time {
@@ -55,84 +38,43 @@ func nextOccurrenceOfLocalTime(t toml.LocalTime) time.Time {
 	return nextRun
 }
 
-func main() {
-	flag.Parse()
-	if *version {
-		fmt.Fprintf(os.Stderr, "searchix %s", buildVersion)
-		if buildVersion != config.CommitSHA && buildVersion != config.ShortSHA {
-			fmt.Fprintf(os.Stderr, " %s", config.CommitSHA)
-		}
-		_, err := fmt.Fprint(os.Stderr, "\n")
-		if err != nil {
-			panic("can't write to standard error?!")
-		}
-		os.Exit(0)
-	}
-	if *printDefaultConfig {
-		_, err := fmt.Print(config.GetDefaultConfig())
-		if err != nil {
-			panic("can't write to standard output?!")
-		}
-		os.Exit(0)
-	}
-
-	cfg, err := config.GetConfig(*configFile)
-	if err != nil {
-		// only use log functions after the config file has been read successfully
-		fmt.Fprintf(os.Stderr, "error parsing configuration file: %v", err)
-		os.Exit(1)
-	}
-	slog.SetLogLoggerLevel(cfg.LogLevel)
-	if cfg.Web.Environment == "production" {
-		log.SetFlags(0)
-	} else {
-		log.SetFlags(log.LstdFlags)
-	}
-
-	err = sentry.Init(sentry.ClientOptions{
-		EnableTracing:    true,
-		TracesSampleRate: 1.0,
-		Dsn:              cfg.Web.SentryDSN,
-		Environment:      cfg.Web.Environment,
-	})
-	if err != nil {
-		slog.Warn("could not initialise sentry", "error", err)
-	}
-
+func (s *Server) SetupIndex(update bool, replace bool) error {
 	var i uint
-	cfgEnabledSources := make([]string, len(cfg.Importer.Sources))
-	for key := range cfg.Importer.Sources {
+	cfgEnabledSources := make([]string, len(s.cfg.Importer.Sources))
+	for key := range s.cfg.Importer.Sources {
 		cfgEnabledSources[i] = key
 		i++
 	}
 	slices.Sort(cfgEnabledSources)
 
-	read, write, exists, err := index.OpenOrCreate(cfg.DataPath, *replace)
+	read, write, exists, err := index.OpenOrCreate(s.cfg.DataPath, replace)
 	if err != nil {
-		log.Fatalf("Failed to open or create index: %v", err)
+		return errors.Wrap(err, "Failed to open or create index")
 	}
+	s.readIndex = read
+	s.writeIndex = write
 
-	if !exists || *replace || *update {
+	if !exists || replace || update {
 		slog.Info(
 			"Starting build job",
 			"new",
 			!exists,
 			"replace",
-			*replace,
+			replace,
 			"update",
-			*update,
+			update,
 		)
-		err = importer.Start(cfg, write, *replace, nil)
+		err = importer.Start(s.cfg, write, replace, nil)
 		if err != nil {
-			log.Fatalf("Failed to build index: %v", err)
+			return errors.Wrap(err, "Failed to build index")
 		}
-		if *replace || *update {
-			return
+		if replace || update {
+			return nil
 		}
 	} else {
 		indexedSources, err := read.GetEnabledSources()
 		if err != nil {
-			log.Fatalln("failed to get enabled sources from index")
+			return errors.Wrap(err, "Failed to get enabled sources from index")
 		}
 		slices.Sort(indexedSources)
 		if !slices.Equal(cfgEnabledSources, indexedSources) {
@@ -144,9 +86,9 @@ func main() {
 			})
 			if len(newSources) > 0 {
 				slog.Info("adding new sources", "sources", newSources)
-				err := importer.Start(cfg, write, false, &newSources)
+				err := importer.Start(s.cfg, write, false, &newSources)
 				if err != nil {
-					log.Fatalf("failed to update index with new sources: %v", err)
+					return errors.Wrap(err, "Failed to update index with new sources")
 				}
 			}
 			if len(retiredSources) > 0 {
@@ -154,94 +96,133 @@ func main() {
 				for _, s := range retiredSources {
 					err := write.DeleteBySource(s)
 					if err != nil {
-						log.Fatalf("failed to remove retired source %s from index: %v", s, err)
+						return errors.Wrapf(err, "Failed to remove retired source %s", s)
 					}
 				}
 			}
 		}
 	}
 
-	c := make(chan os.Signal, 2)
-	signal.Notify(c, os.Interrupt)
-	sv, err := server.New(cfg, read, *liveReload)
+	return nil
+}
+
+type Server struct {
+	sv         *server.Server
+	wg         *sync.WaitGroup
+	cfg        *config.Config
+	sentryHub  *sentry.Hub
+	readIndex  *index.ReadIndex
+	writeIndex *index.WriteIndex
+}
+
+func New(cfg *config.Config) (*Server, error) {
+	slog.SetLogLoggerLevel(cfg.LogLevel)
+	if cfg.Web.Environment == "production" {
+		log.SetFlags(0)
+	} else {
+		log.SetFlags(log.LstdFlags)
+	}
+
+	err := sentry.Init(sentry.ClientOptions{
+		EnableTracing:    true,
+		TracesSampleRate: 1.0,
+		Dsn:              cfg.Web.SentryDSN,
+		Environment:      cfg.Web.Environment,
+	})
 	if err != nil {
-		log.Fatalf("error setting up server: %v", err)
+		slog.Warn("could not initialise sentry", "error", err)
 	}
-	wg := &sync.WaitGroup{}
-	wg.Add(1)
-	go func() {
-		defer wg.Done()
-		sig := <-c
-		log.Printf("signal captured: %v", sig)
-		<-sv.Stop()
-		slog.Debug("server stopped")
-	}()
-
-	go func(localHub *sentry.Hub) {
-		const monitorSlug = "import"
-		localHub.WithScope(func(scope *sentry.Scope) {
-			scope.SetContext("monitor", sentry.Context{"slug": monitorSlug})
-			monitorConfig := &sentry.MonitorConfig{
-				Schedule: sentry.IntervalSchedule(1, sentry.MonitorScheduleUnitDay),
-				// minutes
-				MaxRuntime:    10,
-				CheckInMargin: 5,
-				Timezone:      time.Local.String(),
-			}
 
-			nextRun := nextOccurrenceOfLocalTime(cfg.Importer.UpdateAt.LocalTime)
-			for {
-				slog.Debug("scheduling next run", "next-run", nextRun)
-				<-time.After(time.Until(nextRun))
-				wg.Add(1)
-				slog.Info("updating index")
+	return &Server{
+		cfg:       cfg,
+		sentryHub: sentry.CurrentHub(),
+	}, nil
+}
 
-				eventID := localHub.CaptureCheckIn(&sentry.CheckIn{
+func (s *Server) startUpdateTimer(
+	ctx context.Context,
+	localHub *sentry.Hub,
+) {
+	const monitorSlug = "import"
+	localHub.WithScope(func(scope *sentry.Scope) {
+		var err error
+		scope.SetContext("monitor", sentry.Context{"slug": monitorSlug})
+		monitorConfig := &sentry.MonitorConfig{
+			Schedule: sentry.IntervalSchedule(1, sentry.MonitorScheduleUnitDay),
+			// minutes
+			MaxRuntime:    10,
+			CheckInMargin: 5,
+			Timezone:      time.Local.String(),
+		}
+
+		s.wg.Add(1)
+		nextRun := nextOccurrenceOfLocalTime(s.cfg.Importer.UpdateAt.LocalTime)
+		for {
+			slog.Debug("scheduling next run", "next-run", nextRun)
+			select {
+			case <-ctx.Done():
+				slog.Debug("stopping scheduler")
+				s.wg.Done()
+
+				return
+			case <-time.After(time.Until(nextRun)):
+			}
+			s.wg.Add(1)
+			slog.Info("updating index")
+
+			eventID := localHub.CaptureCheckIn(&sentry.CheckIn{
+				MonitorSlug: monitorSlug,
+				Status:      sentry.CheckInStatusInProgress,
+			}, monitorConfig)
+
+			err = importer.Start(s.cfg, s.writeIndex, false, nil)
+			s.wg.Done()
+			if err != nil {
+				slog.Warn("error updating index", "error", err)
+
+				localHub.CaptureException(err)
+				localHub.CaptureCheckIn(&sentry.CheckIn{
+					ID:          *eventID,
 					MonitorSlug: monitorSlug,
-					Status:      sentry.CheckInStatusInProgress,
+					Status:      sentry.CheckInStatusError,
 				}, monitorConfig)
+			} else {
+				slog.Info("update complete")
 
-				err = importer.Start(cfg, write, false, nil)
-				wg.Done()
-				if err != nil {
-					slog.Warn("error updating index", "error", err)
-
-					localHub.CaptureException(err)
-					localHub.CaptureCheckIn(&sentry.CheckIn{
-						ID:          *eventID,
-						MonitorSlug: monitorSlug,
-						Status:      sentry.CheckInStatusError,
-					}, monitorConfig)
-				} else {
-					slog.Info("update complete")
-
-					localHub.CaptureCheckIn(&sentry.CheckIn{
-						ID:          *eventID,
-						MonitorSlug: monitorSlug,
-						Status:      sentry.CheckInStatusOK,
-					}, monitorConfig)
-				}
-				nextRun = nextRun.AddDate(0, 0, 1)
+				localHub.CaptureCheckIn(&sentry.CheckIn{
+					ID:          *eventID,
+					MonitorSlug: monitorSlug,
+					Status:      sentry.CheckInStatusOK,
+				}, monitorConfig)
 			}
-		})
-	}(sentry.CurrentHub().Clone())
-
-	sErr := make(chan error)
-	wg.Add(1)
-	go func() {
-		defer wg.Done()
-		sErr <- sv.Start()
-	}()
-
-	if cfg.Web.Environment == "development" {
-		log.Printf("server listening on %s", cfg.Web.BaseURL.String())
+			nextRun = nextRun.AddDate(0, 0, 1)
+		}
+	})
+}
+
+func (s *Server) Start(ctx context.Context, liveReload bool) error {
+	var err error
+	s.sv, err = server.New(s.cfg, s.readIndex, liveReload)
+	if err != nil {
+		return errors.Wrap(err, "error setting up server")
 	}
 
-	err = <-sErr
+	s.wg = &sync.WaitGroup{}
+	go s.startUpdateTimer(ctx, sentry.CurrentHub().Clone())
+
+	s.wg.Add(1)
+	err = s.sv.Start()
 	if err != nil {
-		// Error starting or closing listener:
-		log.Fatalf("error: %v", err)
+		s.wg.Done()
+
+		return errors.Wrap(err, "error starting server")
 	}
-	sentry.Flush(2 * time.Second)
-	wg.Wait()
+
+	return nil
+}
+
+func (s *Server) Stop() {
+	<-s.sv.Stop()
+	defer s.wg.Done()
+	s.sentryHub.Flush(2 * time.Second)
 }