Fix: fix CLI arg <-> config <-> env parity
This commit is contained in:
parent
23c65b94d7
commit
93dcb98817
1
go.mod
1
go.mod
|
@ -6,6 +6,7 @@ require (
|
|||
git.tcp.direct/kayos/common v0.9.7
|
||||
github.com/fasthttp/router v1.5.1
|
||||
github.com/knadh/koanf/parsers/toml v0.1.0
|
||||
github.com/knadh/koanf/providers/basicflag v1.0.0
|
||||
github.com/knadh/koanf/providers/env v0.1.0
|
||||
github.com/knadh/koanf/v2 v2.1.1
|
||||
github.com/rs/zerolog v1.33.0
|
||||
|
|
2
go.sum
2
go.sum
|
@ -15,6 +15,8 @@ github.com/knadh/koanf/maps v0.1.1 h1:G5TjmUh2D7G2YWf5SQQqSiHRJEjaicvU0KpypqB3NI
|
|||
github.com/knadh/koanf/maps v0.1.1/go.mod h1:npD/QZY3V6ghQDdcQzl1W4ICNVTkohC8E73eI2xW4yI=
|
||||
github.com/knadh/koanf/parsers/toml v0.1.0 h1:S2hLqS4TgWZYj4/7mI5m1CQQcWurxUz6ODgOub/6LCI=
|
||||
github.com/knadh/koanf/parsers/toml v0.1.0/go.mod h1:yUprhq6eo3GbyVXFFMdbfZSo928ksS+uo0FFqNMnO18=
|
||||
github.com/knadh/koanf/providers/basicflag v1.0.0 h1:qB0es/9fYsLuYnrKazxNCuWtkv3JFX1lI1druUsDDvY=
|
||||
github.com/knadh/koanf/providers/basicflag v1.0.0/go.mod h1:n0NlnaxXUCER/WIzRroT9q3Np+FiZ9pSjrC6A/OozI8=
|
||||
github.com/knadh/koanf/providers/env v0.1.0 h1:LqKteXqfOWyx5Ab9VfGHmjY9BvRXi+clwyZozgVRiKg=
|
||||
github.com/knadh/koanf/providers/env v0.1.0/go.mod h1:RE8K9GbACJkeEnkl8L/Qcj8p4ZyPXZIQ191HJi44ZaQ=
|
||||
github.com/knadh/koanf/v2 v2.1.1 h1:/R8eXqasSTsmDCsAyYj+81Wteg8AqrV9CP6gvsTsOmM=
|
||||
|
|
|
@ -4,38 +4,88 @@ import (
|
|||
"flag"
|
||||
"io"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/yunginnanet/HellPot/internal/extra"
|
||||
"github.com/yunginnanet/HellPot/internal/version"
|
||||
)
|
||||
|
||||
var CLIFlags = flag.NewFlagSet("config", flag.ExitOnError)
|
||||
var CLIFlags = flag.NewFlagSet("config", flag.ContinueOnError)
|
||||
|
||||
var (
|
||||
sliceDefs = make(map[string][]string)
|
||||
slicePtrs = make(map[string]*string)
|
||||
)
|
||||
|
||||
func addCLIFlags() {
|
||||
parse := func(k string, v interface{}, nestedName string) {
|
||||
switch casted := v.(type) {
|
||||
case bool:
|
||||
CLIFlags.Bool(nestedName, casted, "set "+k)
|
||||
case string:
|
||||
CLIFlags.String(nestedName, casted, "set "+k)
|
||||
case int:
|
||||
CLIFlags.Int(nestedName, casted, "set "+k)
|
||||
case float64:
|
||||
CLIFlags.Float64(nestedName, casted, "set "+k)
|
||||
case []string:
|
||||
sliceDefs[nestedName] = casted
|
||||
joined := strings.Join(sliceDefs[nestedName], ",")
|
||||
slicePtrs[nestedName] = CLIFlags.String(nestedName, joined, "set "+k)
|
||||
}
|
||||
}
|
||||
|
||||
for key, val := range Defaults.val {
|
||||
if _, ok := val.(map[string]interface{}); !ok {
|
||||
parse(key, val, key)
|
||||
continue
|
||||
}
|
||||
nested, ok := val.(map[string]interface{})
|
||||
if !ok {
|
||||
// linter was confused by the above check
|
||||
panic("unreachable, if you see this you have entered a real life HellPot")
|
||||
}
|
||||
for k, v := range nested {
|
||||
nestedName := key + "." + k
|
||||
parse(k, v, nestedName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var replacer = map[string][]string{
|
||||
"-h": {"-help"},
|
||||
"-v": {"-version"},
|
||||
"-c": {"-config"},
|
||||
"-g": {"-bespoke.enable_grimoire", "true", "-bespoke.grimoire_file"},
|
||||
}
|
||||
|
||||
func InitCLI() {
|
||||
newArgs := make([]string, 0)
|
||||
for _, arg := range os.Args {
|
||||
if repl, ok := replacer[arg]; ok {
|
||||
newArgs = append(newArgs, repl...)
|
||||
continue
|
||||
}
|
||||
// check for unit test flags
|
||||
if !strings.HasPrefix(arg, "-test.") {
|
||||
newArgs = append(newArgs, arg)
|
||||
}
|
||||
}
|
||||
|
||||
CLIFlags.Bool("logger-debug", false, "force debug logging")
|
||||
CLIFlags.Bool("logger-trace", false, "force trace logging")
|
||||
CLIFlags.Bool("logger-nocolor", false, "force no color logging")
|
||||
CLIFlags.String("bespoke-grimoire", "", "specify a custom file used for text generation")
|
||||
newArgs = slices.Compact(newArgs)
|
||||
|
||||
CLIFlags.Bool("banner", false, "show banner and version then exit")
|
||||
CLIFlags.Bool("genconfig", false, "write default config to stdout then exit")
|
||||
CLIFlags.Bool("h", false, "show this help and exit")
|
||||
CLIFlags.Bool("help", false, "show this help and exit")
|
||||
CLIFlags.String("c", "", "specify config file")
|
||||
CLIFlags.String("config", "", "specify config file")
|
||||
CLIFlags.String("version", "", "show version and exit")
|
||||
CLIFlags.String("v", "", "show version and exit")
|
||||
|
||||
addCLIFlags()
|
||||
|
||||
if err := CLIFlags.Parse(newArgs[1:]); err != nil {
|
||||
println(err.Error())
|
||||
// flag.ExitOnError will call os.Exit(2)
|
||||
os.Exit(2)
|
||||
}
|
||||
if os.Getenv("HELLPOT_CONFIG") != "" {
|
||||
if err := CLIFlags.Set("config", os.Getenv("HELLPOT_CONFIG")); err != nil {
|
||||
|
@ -45,11 +95,11 @@ func InitCLI() {
|
|||
panic(err)
|
||||
}
|
||||
}
|
||||
if CLIFlags.Lookup("h").Value.String() == "true" || CLIFlags.Lookup("help").Value.String() == "true" {
|
||||
if CLIFlags.Lookup("help").Value.String() == "true" {
|
||||
CLIFlags.Usage()
|
||||
os.Exit(0)
|
||||
}
|
||||
if CLIFlags.Lookup("version").Value.String() == "true" || CLIFlags.Lookup("v").Value.String() == "true" {
|
||||
if CLIFlags.Lookup("version").Value.String() == "true" {
|
||||
_, _ = os.Stdout.WriteString("HellPot version: " + version.Version + "\n")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
@ -66,4 +116,5 @@ func InitCLI() {
|
|||
extra.Banner()
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -83,4 +83,8 @@ var defOpts = map[string]interface{}{
|
|||
"deception": map[string]interface{}{
|
||||
"server_name": "nginx",
|
||||
},
|
||||
"bespoke": map[string]interface{}{
|
||||
"grimoire_file": "",
|
||||
"enable_grimoire": false,
|
||||
},
|
||||
}
|
||||
|
|
|
@ -85,6 +85,6 @@ type DevilsPlaythings struct {
|
|||
|
||||
// Customization represents the configuration for the customizations.
|
||||
type Customization struct {
|
||||
CustomHeffalump bool `koanf:"custom_heffalump"`
|
||||
Grimoire string `koanf:"grimoire"`
|
||||
CustomHeffalump bool `koanf:"enable_grimoire"`
|
||||
Grimoire string `koanf:"grimoire_file"`
|
||||
}
|
||||
|
|
|
@ -3,9 +3,11 @@ package config
|
|||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/knadh/koanf/parsers/toml"
|
||||
flags "github.com/knadh/koanf/providers/basicflag"
|
||||
"github.com/knadh/koanf/providers/env"
|
||||
"github.com/knadh/koanf/v2"
|
||||
)
|
||||
|
@ -26,6 +28,134 @@ func (r *readerProvider) Read() (map[string]interface{}, error) {
|
|||
return toml.Parser().Unmarshal(b) //nolint:wrapcheck
|
||||
}
|
||||
|
||||
func normalizeMap(m map[string]interface{}) map[string]interface{} {
|
||||
for k, v := range m {
|
||||
ogk := k
|
||||
k = strings.ToLower(k)
|
||||
|
||||
var sslice []string
|
||||
var sliceOK bool
|
||||
|
||||
if sslice, sliceOK = v.([]string); !sliceOK {
|
||||
goto justLower
|
||||
}
|
||||
for i, s := range sslice {
|
||||
sslice[i] = strings.ToLower(s)
|
||||
}
|
||||
slices.Sort(sslice)
|
||||
m[k] = sslice
|
||||
justLower:
|
||||
if k != ogk {
|
||||
delete(m, ogk)
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (p *Parameters) merge(ogk *koanf.Koanf, newk *koanf.Koanf, friendlyName string) error {
|
||||
if ogk == nil {
|
||||
panic("original koanf is nil")
|
||||
}
|
||||
if newk == nil {
|
||||
return nil
|
||||
}
|
||||
dirty := false
|
||||
|
||||
newKeys := normalizeMap(newk.All())
|
||||
|
||||
if len(newk.All()) == 0 || len(newKeys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for k, v := range newKeys {
|
||||
if !ogk.Exists(k) {
|
||||
if err := ogk.Set(k, v); err != nil {
|
||||
panic(fmt.Sprintf("failed to set key %s: %v", k, err))
|
||||
}
|
||||
dirty = true
|
||||
continue
|
||||
}
|
||||
|
||||
ogv := ogk.Get(k)
|
||||
if ogv == nil {
|
||||
if err := ogk.Set(k, v); err != nil {
|
||||
panic(fmt.Sprintf("failed to set key %s: %v", k, err))
|
||||
}
|
||||
dirty = true
|
||||
continue
|
||||
}
|
||||
|
||||
if _, hasDefault := Defaults.val[k]; !hasDefault {
|
||||
continue
|
||||
}
|
||||
|
||||
if ogv == Defaults.val[k] && v != ogv {
|
||||
if err := ogk.Set(k, v); err != nil {
|
||||
panic(fmt.Sprintf("failed to set key %s: %v", k, err))
|
||||
}
|
||||
dirty = true
|
||||
}
|
||||
}
|
||||
|
||||
if !dirty {
|
||||
return nil
|
||||
}
|
||||
|
||||
println("found configuration overrides in " + friendlyName)
|
||||
|
||||
if err := ogk.Merge(newk); err != nil {
|
||||
return fmt.Errorf("failed to merge env config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Parameters) LoadEnv(k *koanf.Koanf) error {
|
||||
envK := koanf.New(".")
|
||||
|
||||
envErr := envK.Load(env.Provider("HELLPOT_", ".", func(s string) string {
|
||||
s = strings.TrimPrefix(s, "HELLPOT_")
|
||||
s = strings.ToLower(s)
|
||||
s = strings.ReplaceAll(s, "__", " ")
|
||||
s = strings.ReplaceAll(s, "_", ".")
|
||||
s = strings.ReplaceAll(s, " ", "_")
|
||||
return s
|
||||
}), nil)
|
||||
|
||||
if envErr != nil {
|
||||
return fmt.Errorf("failed to load env: %w", envErr)
|
||||
}
|
||||
|
||||
if err := p.merge(k, envK, "environment variables"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseCLISlice(key string, value string) (string, interface{}) {
|
||||
if _, ok := slicePtrs[key]; !ok {
|
||||
return key, value
|
||||
}
|
||||
split := strings.Split(value, ",")
|
||||
slices.Sort(split)
|
||||
return key, split
|
||||
}
|
||||
|
||||
func (p *Parameters) LoadFlags(k *koanf.Koanf) error {
|
||||
flagsK := koanf.New(".")
|
||||
|
||||
if err := flagsK.Load(flags.ProviderWithValue(CLIFlags, ".", parseCLISlice), nil); err != nil {
|
||||
return fmt.Errorf("failed to load flags: %w", err)
|
||||
}
|
||||
|
||||
if err := p.merge(k, flagsK, "cli arguments"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Setup(source io.Reader) (*Parameters, error) {
|
||||
k := koanf.New(".")
|
||||
|
||||
|
@ -39,23 +169,6 @@ func Setup(source io.Reader) (*Parameters, error) {
|
|||
}
|
||||
}
|
||||
|
||||
envK := koanf.New(".")
|
||||
|
||||
envErr := envK.Load(env.Provider("HELLPOT_", ".", func(s string) string {
|
||||
s = strings.TrimPrefix(s, "HELLPOT_")
|
||||
s = strings.ToLower(s)
|
||||
s = strings.ReplaceAll(s, "__", " ")
|
||||
s = strings.ReplaceAll(s, "_", ".")
|
||||
s = strings.ReplaceAll(s, " ", "_")
|
||||
return s
|
||||
}), nil)
|
||||
|
||||
if envErr == nil && envK != nil && len(envK.All()) > 0 {
|
||||
if err := k.Merge(envK); err != nil {
|
||||
return nil, fmt.Errorf("failed to merge env config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
p := &Parameters{
|
||||
source: k,
|
||||
}
|
||||
|
@ -64,6 +177,14 @@ func Setup(source io.Reader) (*Parameters, error) {
|
|||
p.UsingDefaults = true
|
||||
}
|
||||
|
||||
if err := p.LoadFlags(k); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := p.LoadEnv(k); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := k.Unmarshal("", p); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
|
|
|
@ -119,7 +119,7 @@ func getSrv(r *router.Router) fasthttp.Server {
|
|||
}
|
||||
}
|
||||
|
||||
func setupHeffalump(config *config.Parameters) error {
|
||||
func SetupHeffalump(config *config.Parameters) error {
|
||||
switch config.Bespoke.CustomHeffalump {
|
||||
case true:
|
||||
content, err := os.ReadFile(config.Bespoke.Grimoire)
|
||||
|
@ -151,6 +151,10 @@ func Serve(config *config.Parameters) error {
|
|||
log = config.GetLogger()
|
||||
runningConfig = config
|
||||
|
||||
if err := SetupHeffalump(config); err != nil {
|
||||
return fmt.Errorf("failed to setup heffalump: %w", err)
|
||||
}
|
||||
|
||||
l := config.HTTP.Bind + ":" + strconv.Itoa(int(config.HTTP.Port))
|
||||
|
||||
r := router.New()
|
||||
|
|
Loading…
Reference in New Issue