From 93dcb98817b95f8dba9548f4b7c21a5beba05b68 Mon Sep 17 00:00:00 2001 From: "kayos@tcp.direct" Date: Wed, 26 Jun 2024 03:58:55 -0700 Subject: [PATCH] Fix: fix CLI arg <-> config <-> env parity --- go.mod | 1 + go.sum | 2 + internal/config/command_line.go | 73 ++++++++++++--- internal/config/defaults.go | 4 + internal/config/models.go | 4 +- internal/config/setup.go | 155 ++++++++++++++++++++++++++++---- internal/http/router.go | 6 +- 7 files changed, 214 insertions(+), 31 deletions(-) diff --git a/go.mod b/go.mod index 97fdf8b..f7f0b67 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( git.tcp.direct/kayos/common v0.9.7 github.com/fasthttp/router v1.5.1 github.com/knadh/koanf/parsers/toml v0.1.0 + github.com/knadh/koanf/providers/basicflag v1.0.0 github.com/knadh/koanf/providers/env v0.1.0 github.com/knadh/koanf/v2 v2.1.1 github.com/rs/zerolog v1.33.0 diff --git a/go.sum b/go.sum index 4ab95a0..dd8dbf7 100644 --- a/go.sum +++ b/go.sum @@ -15,6 +15,8 @@ github.com/knadh/koanf/maps v0.1.1 h1:G5TjmUh2D7G2YWf5SQQqSiHRJEjaicvU0KpypqB3NI github.com/knadh/koanf/maps v0.1.1/go.mod h1:npD/QZY3V6ghQDdcQzl1W4ICNVTkohC8E73eI2xW4yI= github.com/knadh/koanf/parsers/toml v0.1.0 h1:S2hLqS4TgWZYj4/7mI5m1CQQcWurxUz6ODgOub/6LCI= github.com/knadh/koanf/parsers/toml v0.1.0/go.mod h1:yUprhq6eo3GbyVXFFMdbfZSo928ksS+uo0FFqNMnO18= +github.com/knadh/koanf/providers/basicflag v1.0.0 h1:qB0es/9fYsLuYnrKazxNCuWtkv3JFX1lI1druUsDDvY= +github.com/knadh/koanf/providers/basicflag v1.0.0/go.mod h1:n0NlnaxXUCER/WIzRroT9q3Np+FiZ9pSjrC6A/OozI8= github.com/knadh/koanf/providers/env v0.1.0 h1:LqKteXqfOWyx5Ab9VfGHmjY9BvRXi+clwyZozgVRiKg= github.com/knadh/koanf/providers/env v0.1.0/go.mod h1:RE8K9GbACJkeEnkl8L/Qcj8p4ZyPXZIQ191HJi44ZaQ= github.com/knadh/koanf/v2 v2.1.1 h1:/R8eXqasSTsmDCsAyYj+81Wteg8AqrV9CP6gvsTsOmM= diff --git a/internal/config/command_line.go b/internal/config/command_line.go index c794ef6..c0f083f 100644 --- a/internal/config/command_line.go +++ b/internal/config/command_line.go @@ -4,38 +4,88 @@ import ( "flag" "io" "os" + "slices" "strings" "github.com/yunginnanet/HellPot/internal/extra" "github.com/yunginnanet/HellPot/internal/version" ) -var CLIFlags = flag.NewFlagSet("config", flag.ExitOnError) +var CLIFlags = flag.NewFlagSet("config", flag.ContinueOnError) + +var ( + sliceDefs = make(map[string][]string) + slicePtrs = make(map[string]*string) +) + +func addCLIFlags() { + parse := func(k string, v interface{}, nestedName string) { + switch casted := v.(type) { + case bool: + CLIFlags.Bool(nestedName, casted, "set "+k) + case string: + CLIFlags.String(nestedName, casted, "set "+k) + case int: + CLIFlags.Int(nestedName, casted, "set "+k) + case float64: + CLIFlags.Float64(nestedName, casted, "set "+k) + case []string: + sliceDefs[nestedName] = casted + joined := strings.Join(sliceDefs[nestedName], ",") + slicePtrs[nestedName] = CLIFlags.String(nestedName, joined, "set "+k) + } + } + + for key, val := range Defaults.val { + if _, ok := val.(map[string]interface{}); !ok { + parse(key, val, key) + continue + } + nested, ok := val.(map[string]interface{}) + if !ok { + // linter was confused by the above check + panic("unreachable, if you see this you have entered a real life HellPot") + } + for k, v := range nested { + nestedName := key + "." + k + parse(k, v, nestedName) + } + } +} + +var replacer = map[string][]string{ + "-h": {"-help"}, + "-v": {"-version"}, + "-c": {"-config"}, + "-g": {"-bespoke.enable_grimoire", "true", "-bespoke.grimoire_file"}, +} func InitCLI() { newArgs := make([]string, 0) for _, arg := range os.Args { + if repl, ok := replacer[arg]; ok { + newArgs = append(newArgs, repl...) + continue + } // check for unit test flags if !strings.HasPrefix(arg, "-test.") { newArgs = append(newArgs, arg) } } - CLIFlags.Bool("logger-debug", false, "force debug logging") - CLIFlags.Bool("logger-trace", false, "force trace logging") - CLIFlags.Bool("logger-nocolor", false, "force no color logging") - CLIFlags.String("bespoke-grimoire", "", "specify a custom file used for text generation") + newArgs = slices.Compact(newArgs) + CLIFlags.Bool("banner", false, "show banner and version then exit") CLIFlags.Bool("genconfig", false, "write default config to stdout then exit") - CLIFlags.Bool("h", false, "show this help and exit") CLIFlags.Bool("help", false, "show this help and exit") - CLIFlags.String("c", "", "specify config file") CLIFlags.String("config", "", "specify config file") CLIFlags.String("version", "", "show version and exit") - CLIFlags.String("v", "", "show version and exit") + + addCLIFlags() + if err := CLIFlags.Parse(newArgs[1:]); err != nil { println(err.Error()) - // flag.ExitOnError will call os.Exit(2) + os.Exit(2) } if os.Getenv("HELLPOT_CONFIG") != "" { if err := CLIFlags.Set("config", os.Getenv("HELLPOT_CONFIG")); err != nil { @@ -45,11 +95,11 @@ func InitCLI() { panic(err) } } - if CLIFlags.Lookup("h").Value.String() == "true" || CLIFlags.Lookup("help").Value.String() == "true" { + if CLIFlags.Lookup("help").Value.String() == "true" { CLIFlags.Usage() os.Exit(0) } - if CLIFlags.Lookup("version").Value.String() == "true" || CLIFlags.Lookup("v").Value.String() == "true" { + if CLIFlags.Lookup("version").Value.String() == "true" { _, _ = os.Stdout.WriteString("HellPot version: " + version.Version + "\n") os.Exit(0) } @@ -66,4 +116,5 @@ func InitCLI() { extra.Banner() os.Exit(0) } + } diff --git a/internal/config/defaults.go b/internal/config/defaults.go index 06d8e05..509d886 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -83,4 +83,8 @@ var defOpts = map[string]interface{}{ "deception": map[string]interface{}{ "server_name": "nginx", }, + "bespoke": map[string]interface{}{ + "grimoire_file": "", + "enable_grimoire": false, + }, } diff --git a/internal/config/models.go b/internal/config/models.go index c6b9dec..c3d9a32 100644 --- a/internal/config/models.go +++ b/internal/config/models.go @@ -85,6 +85,6 @@ type DevilsPlaythings struct { // Customization represents the configuration for the customizations. type Customization struct { - CustomHeffalump bool `koanf:"custom_heffalump"` - Grimoire string `koanf:"grimoire"` + CustomHeffalump bool `koanf:"enable_grimoire"` + Grimoire string `koanf:"grimoire_file"` } diff --git a/internal/config/setup.go b/internal/config/setup.go index 885da52..bdbf6a7 100644 --- a/internal/config/setup.go +++ b/internal/config/setup.go @@ -3,9 +3,11 @@ package config import ( "fmt" "io" + "slices" "strings" "github.com/knadh/koanf/parsers/toml" + flags "github.com/knadh/koanf/providers/basicflag" "github.com/knadh/koanf/providers/env" "github.com/knadh/koanf/v2" ) @@ -26,6 +28,134 @@ func (r *readerProvider) Read() (map[string]interface{}, error) { return toml.Parser().Unmarshal(b) //nolint:wrapcheck } +func normalizeMap(m map[string]interface{}) map[string]interface{} { + for k, v := range m { + ogk := k + k = strings.ToLower(k) + + var sslice []string + var sliceOK bool + + if sslice, sliceOK = v.([]string); !sliceOK { + goto justLower + } + for i, s := range sslice { + sslice[i] = strings.ToLower(s) + } + slices.Sort(sslice) + m[k] = sslice + justLower: + if k != ogk { + delete(m, ogk) + } + } + return m +} + +func (p *Parameters) merge(ogk *koanf.Koanf, newk *koanf.Koanf, friendlyName string) error { + if ogk == nil { + panic("original koanf is nil") + } + if newk == nil { + return nil + } + dirty := false + + newKeys := normalizeMap(newk.All()) + + if len(newk.All()) == 0 || len(newKeys) == 0 { + return nil + } + + for k, v := range newKeys { + if !ogk.Exists(k) { + if err := ogk.Set(k, v); err != nil { + panic(fmt.Sprintf("failed to set key %s: %v", k, err)) + } + dirty = true + continue + } + + ogv := ogk.Get(k) + if ogv == nil { + if err := ogk.Set(k, v); err != nil { + panic(fmt.Sprintf("failed to set key %s: %v", k, err)) + } + dirty = true + continue + } + + if _, hasDefault := Defaults.val[k]; !hasDefault { + continue + } + + if ogv == Defaults.val[k] && v != ogv { + if err := ogk.Set(k, v); err != nil { + panic(fmt.Sprintf("failed to set key %s: %v", k, err)) + } + dirty = true + } + } + + if !dirty { + return nil + } + + println("found configuration overrides in " + friendlyName) + + if err := ogk.Merge(newk); err != nil { + return fmt.Errorf("failed to merge env config: %w", err) + } + + return nil +} + +func (p *Parameters) LoadEnv(k *koanf.Koanf) error { + envK := koanf.New(".") + + envErr := envK.Load(env.Provider("HELLPOT_", ".", func(s string) string { + s = strings.TrimPrefix(s, "HELLPOT_") + s = strings.ToLower(s) + s = strings.ReplaceAll(s, "__", " ") + s = strings.ReplaceAll(s, "_", ".") + s = strings.ReplaceAll(s, " ", "_") + return s + }), nil) + + if envErr != nil { + return fmt.Errorf("failed to load env: %w", envErr) + } + + if err := p.merge(k, envK, "environment variables"); err != nil { + return err + } + + return nil +} + +func parseCLISlice(key string, value string) (string, interface{}) { + if _, ok := slicePtrs[key]; !ok { + return key, value + } + split := strings.Split(value, ",") + slices.Sort(split) + return key, split +} + +func (p *Parameters) LoadFlags(k *koanf.Koanf) error { + flagsK := koanf.New(".") + + if err := flagsK.Load(flags.ProviderWithValue(CLIFlags, ".", parseCLISlice), nil); err != nil { + return fmt.Errorf("failed to load flags: %w", err) + } + + if err := p.merge(k, flagsK, "cli arguments"); err != nil { + return err + } + + return nil +} + func Setup(source io.Reader) (*Parameters, error) { k := koanf.New(".") @@ -39,23 +169,6 @@ func Setup(source io.Reader) (*Parameters, error) { } } - envK := koanf.New(".") - - envErr := envK.Load(env.Provider("HELLPOT_", ".", func(s string) string { - s = strings.TrimPrefix(s, "HELLPOT_") - s = strings.ToLower(s) - s = strings.ReplaceAll(s, "__", " ") - s = strings.ReplaceAll(s, "_", ".") - s = strings.ReplaceAll(s, " ", "_") - return s - }), nil) - - if envErr == nil && envK != nil && len(envK.All()) > 0 { - if err := k.Merge(envK); err != nil { - return nil, fmt.Errorf("failed to merge env config: %w", err) - } - } - p := &Parameters{ source: k, } @@ -64,6 +177,14 @@ func Setup(source io.Reader) (*Parameters, error) { p.UsingDefaults = true } + if err := p.LoadFlags(k); err != nil { + return nil, err + } + + if err := p.LoadEnv(k); err != nil { + return nil, err + } + if err := k.Unmarshal("", p); err != nil { return nil, fmt.Errorf("failed to unmarshal config: %w", err) } diff --git a/internal/http/router.go b/internal/http/router.go index d2e06c8..e63167b 100644 --- a/internal/http/router.go +++ b/internal/http/router.go @@ -119,7 +119,7 @@ func getSrv(r *router.Router) fasthttp.Server { } } -func setupHeffalump(config *config.Parameters) error { +func SetupHeffalump(config *config.Parameters) error { switch config.Bespoke.CustomHeffalump { case true: content, err := os.ReadFile(config.Bespoke.Grimoire) @@ -151,6 +151,10 @@ func Serve(config *config.Parameters) error { log = config.GetLogger() runningConfig = config + if err := SetupHeffalump(config); err != nil { + return fmt.Errorf("failed to setup heffalump: %w", err) + } + l := config.HTTP.Bind + ":" + strconv.Itoa(int(config.HTTP.Port)) r := router.New()