Fix: fix CLI arg <-> config <-> env parity

This commit is contained in:
kayos@tcp.direct 2024-06-26 03:58:55 -07:00
parent 23c65b94d7
commit 93dcb98817
No known key found for this signature in database
GPG Key ID: 4B841471B4BEE979
7 changed files with 214 additions and 31 deletions

1
go.mod
View File

@ -6,6 +6,7 @@ require (
git.tcp.direct/kayos/common v0.9.7
github.com/fasthttp/router v1.5.1
github.com/knadh/koanf/parsers/toml v0.1.0
github.com/knadh/koanf/providers/basicflag v1.0.0
github.com/knadh/koanf/providers/env v0.1.0
github.com/knadh/koanf/v2 v2.1.1
github.com/rs/zerolog v1.33.0

2
go.sum
View File

@ -15,6 +15,8 @@ github.com/knadh/koanf/maps v0.1.1 h1:G5TjmUh2D7G2YWf5SQQqSiHRJEjaicvU0KpypqB3NI
github.com/knadh/koanf/maps v0.1.1/go.mod h1:npD/QZY3V6ghQDdcQzl1W4ICNVTkohC8E73eI2xW4yI=
github.com/knadh/koanf/parsers/toml v0.1.0 h1:S2hLqS4TgWZYj4/7mI5m1CQQcWurxUz6ODgOub/6LCI=
github.com/knadh/koanf/parsers/toml v0.1.0/go.mod h1:yUprhq6eo3GbyVXFFMdbfZSo928ksS+uo0FFqNMnO18=
github.com/knadh/koanf/providers/basicflag v1.0.0 h1:qB0es/9fYsLuYnrKazxNCuWtkv3JFX1lI1druUsDDvY=
github.com/knadh/koanf/providers/basicflag v1.0.0/go.mod h1:n0NlnaxXUCER/WIzRroT9q3Np+FiZ9pSjrC6A/OozI8=
github.com/knadh/koanf/providers/env v0.1.0 h1:LqKteXqfOWyx5Ab9VfGHmjY9BvRXi+clwyZozgVRiKg=
github.com/knadh/koanf/providers/env v0.1.0/go.mod h1:RE8K9GbACJkeEnkl8L/Qcj8p4ZyPXZIQ191HJi44ZaQ=
github.com/knadh/koanf/v2 v2.1.1 h1:/R8eXqasSTsmDCsAyYj+81Wteg8AqrV9CP6gvsTsOmM=

View File

@ -4,38 +4,88 @@ import (
"flag"
"io"
"os"
"slices"
"strings"
"github.com/yunginnanet/HellPot/internal/extra"
"github.com/yunginnanet/HellPot/internal/version"
)
var CLIFlags = flag.NewFlagSet("config", flag.ExitOnError)
var CLIFlags = flag.NewFlagSet("config", flag.ContinueOnError)
var (
sliceDefs = make(map[string][]string)
slicePtrs = make(map[string]*string)
)
func addCLIFlags() {
parse := func(k string, v interface{}, nestedName string) {
switch casted := v.(type) {
case bool:
CLIFlags.Bool(nestedName, casted, "set "+k)
case string:
CLIFlags.String(nestedName, casted, "set "+k)
case int:
CLIFlags.Int(nestedName, casted, "set "+k)
case float64:
CLIFlags.Float64(nestedName, casted, "set "+k)
case []string:
sliceDefs[nestedName] = casted
joined := strings.Join(sliceDefs[nestedName], ",")
slicePtrs[nestedName] = CLIFlags.String(nestedName, joined, "set "+k)
}
}
for key, val := range Defaults.val {
if _, ok := val.(map[string]interface{}); !ok {
parse(key, val, key)
continue
}
nested, ok := val.(map[string]interface{})
if !ok {
// linter was confused by the above check
panic("unreachable, if you see this you have entered a real life HellPot")
}
for k, v := range nested {
nestedName := key + "." + k
parse(k, v, nestedName)
}
}
}
var replacer = map[string][]string{
"-h": {"-help"},
"-v": {"-version"},
"-c": {"-config"},
"-g": {"-bespoke.enable_grimoire", "true", "-bespoke.grimoire_file"},
}
func InitCLI() {
newArgs := make([]string, 0)
for _, arg := range os.Args {
if repl, ok := replacer[arg]; ok {
newArgs = append(newArgs, repl...)
continue
}
// check for unit test flags
if !strings.HasPrefix(arg, "-test.") {
newArgs = append(newArgs, arg)
}
}
CLIFlags.Bool("logger-debug", false, "force debug logging")
CLIFlags.Bool("logger-trace", false, "force trace logging")
CLIFlags.Bool("logger-nocolor", false, "force no color logging")
CLIFlags.String("bespoke-grimoire", "", "specify a custom file used for text generation")
newArgs = slices.Compact(newArgs)
CLIFlags.Bool("banner", false, "show banner and version then exit")
CLIFlags.Bool("genconfig", false, "write default config to stdout then exit")
CLIFlags.Bool("h", false, "show this help and exit")
CLIFlags.Bool("help", false, "show this help and exit")
CLIFlags.String("c", "", "specify config file")
CLIFlags.String("config", "", "specify config file")
CLIFlags.String("version", "", "show version and exit")
CLIFlags.String("v", "", "show version and exit")
addCLIFlags()
if err := CLIFlags.Parse(newArgs[1:]); err != nil {
println(err.Error())
// flag.ExitOnError will call os.Exit(2)
os.Exit(2)
}
if os.Getenv("HELLPOT_CONFIG") != "" {
if err := CLIFlags.Set("config", os.Getenv("HELLPOT_CONFIG")); err != nil {
@ -45,11 +95,11 @@ func InitCLI() {
panic(err)
}
}
if CLIFlags.Lookup("h").Value.String() == "true" || CLIFlags.Lookup("help").Value.String() == "true" {
if CLIFlags.Lookup("help").Value.String() == "true" {
CLIFlags.Usage()
os.Exit(0)
}
if CLIFlags.Lookup("version").Value.String() == "true" || CLIFlags.Lookup("v").Value.String() == "true" {
if CLIFlags.Lookup("version").Value.String() == "true" {
_, _ = os.Stdout.WriteString("HellPot version: " + version.Version + "\n")
os.Exit(0)
}
@ -66,4 +116,5 @@ func InitCLI() {
extra.Banner()
os.Exit(0)
}
}

View File

@ -83,4 +83,8 @@ var defOpts = map[string]interface{}{
"deception": map[string]interface{}{
"server_name": "nginx",
},
"bespoke": map[string]interface{}{
"grimoire_file": "",
"enable_grimoire": false,
},
}

View File

@ -85,6 +85,6 @@ type DevilsPlaythings struct {
// Customization represents the configuration for the customizations.
type Customization struct {
CustomHeffalump bool `koanf:"custom_heffalump"`
Grimoire string `koanf:"grimoire"`
CustomHeffalump bool `koanf:"enable_grimoire"`
Grimoire string `koanf:"grimoire_file"`
}

View File

@ -3,9 +3,11 @@ package config
import (
"fmt"
"io"
"slices"
"strings"
"github.com/knadh/koanf/parsers/toml"
flags "github.com/knadh/koanf/providers/basicflag"
"github.com/knadh/koanf/providers/env"
"github.com/knadh/koanf/v2"
)
@ -26,6 +28,134 @@ func (r *readerProvider) Read() (map[string]interface{}, error) {
return toml.Parser().Unmarshal(b) //nolint:wrapcheck
}
func normalizeMap(m map[string]interface{}) map[string]interface{} {
for k, v := range m {
ogk := k
k = strings.ToLower(k)
var sslice []string
var sliceOK bool
if sslice, sliceOK = v.([]string); !sliceOK {
goto justLower
}
for i, s := range sslice {
sslice[i] = strings.ToLower(s)
}
slices.Sort(sslice)
m[k] = sslice
justLower:
if k != ogk {
delete(m, ogk)
}
}
return m
}
func (p *Parameters) merge(ogk *koanf.Koanf, newk *koanf.Koanf, friendlyName string) error {
if ogk == nil {
panic("original koanf is nil")
}
if newk == nil {
return nil
}
dirty := false
newKeys := normalizeMap(newk.All())
if len(newk.All()) == 0 || len(newKeys) == 0 {
return nil
}
for k, v := range newKeys {
if !ogk.Exists(k) {
if err := ogk.Set(k, v); err != nil {
panic(fmt.Sprintf("failed to set key %s: %v", k, err))
}
dirty = true
continue
}
ogv := ogk.Get(k)
if ogv == nil {
if err := ogk.Set(k, v); err != nil {
panic(fmt.Sprintf("failed to set key %s: %v", k, err))
}
dirty = true
continue
}
if _, hasDefault := Defaults.val[k]; !hasDefault {
continue
}
if ogv == Defaults.val[k] && v != ogv {
if err := ogk.Set(k, v); err != nil {
panic(fmt.Sprintf("failed to set key %s: %v", k, err))
}
dirty = true
}
}
if !dirty {
return nil
}
println("found configuration overrides in " + friendlyName)
if err := ogk.Merge(newk); err != nil {
return fmt.Errorf("failed to merge env config: %w", err)
}
return nil
}
func (p *Parameters) LoadEnv(k *koanf.Koanf) error {
envK := koanf.New(".")
envErr := envK.Load(env.Provider("HELLPOT_", ".", func(s string) string {
s = strings.TrimPrefix(s, "HELLPOT_")
s = strings.ToLower(s)
s = strings.ReplaceAll(s, "__", " ")
s = strings.ReplaceAll(s, "_", ".")
s = strings.ReplaceAll(s, " ", "_")
return s
}), nil)
if envErr != nil {
return fmt.Errorf("failed to load env: %w", envErr)
}
if err := p.merge(k, envK, "environment variables"); err != nil {
return err
}
return nil
}
func parseCLISlice(key string, value string) (string, interface{}) {
if _, ok := slicePtrs[key]; !ok {
return key, value
}
split := strings.Split(value, ",")
slices.Sort(split)
return key, split
}
func (p *Parameters) LoadFlags(k *koanf.Koanf) error {
flagsK := koanf.New(".")
if err := flagsK.Load(flags.ProviderWithValue(CLIFlags, ".", parseCLISlice), nil); err != nil {
return fmt.Errorf("failed to load flags: %w", err)
}
if err := p.merge(k, flagsK, "cli arguments"); err != nil {
return err
}
return nil
}
func Setup(source io.Reader) (*Parameters, error) {
k := koanf.New(".")
@ -39,23 +169,6 @@ func Setup(source io.Reader) (*Parameters, error) {
}
}
envK := koanf.New(".")
envErr := envK.Load(env.Provider("HELLPOT_", ".", func(s string) string {
s = strings.TrimPrefix(s, "HELLPOT_")
s = strings.ToLower(s)
s = strings.ReplaceAll(s, "__", " ")
s = strings.ReplaceAll(s, "_", ".")
s = strings.ReplaceAll(s, " ", "_")
return s
}), nil)
if envErr == nil && envK != nil && len(envK.All()) > 0 {
if err := k.Merge(envK); err != nil {
return nil, fmt.Errorf("failed to merge env config: %w", err)
}
}
p := &Parameters{
source: k,
}
@ -64,6 +177,14 @@ func Setup(source io.Reader) (*Parameters, error) {
p.UsingDefaults = true
}
if err := p.LoadFlags(k); err != nil {
return nil, err
}
if err := p.LoadEnv(k); err != nil {
return nil, err
}
if err := k.Unmarshal("", p); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}

View File

@ -119,7 +119,7 @@ func getSrv(r *router.Router) fasthttp.Server {
}
}
func setupHeffalump(config *config.Parameters) error {
func SetupHeffalump(config *config.Parameters) error {
switch config.Bespoke.CustomHeffalump {
case true:
content, err := os.ReadFile(config.Bespoke.Grimoire)
@ -151,6 +151,10 @@ func Serve(config *config.Parameters) error {
log = config.GetLogger()
runningConfig = config
if err := SetupHeffalump(config); err != nil {
return fmt.Errorf("failed to setup heffalump: %w", err)
}
l := config.HTTP.Bind + ":" + strconv.Itoa(int(config.HTTP.Port))
r := router.New()