From 5cf28cf8af8200bf33a4ac60a5b233ef07e5eed4 Mon Sep 17 00:00:00 2001 From: WithoutPants <53250216+WithoutPants@users.noreply.github.com> Date: Sun, 14 Jan 2024 12:52:16 +1100 Subject: [PATCH] Fix studio name uniqueness validation (#4454) --- internal/api/resolver_mutation_studio.go | 8 +- internal/manager/task_stash_box_tag.go | 4 + pkg/studio/{update.go => validate.go} | 29 ++++++- pkg/studio/validate_test.go | 104 +++++++++++++++++++++++ 4 files changed, 137 insertions(+), 8 deletions(-) rename pkg/studio/{update.go => validate.go} (80%) create mode 100644 pkg/studio/validate_test.go diff --git a/internal/api/resolver_mutation_studio.go b/internal/api/resolver_mutation_studio.go index 21f75224a..c41efe9ff 100644 --- a/internal/api/resolver_mutation_studio.go +++ b/internal/api/resolver_mutation_studio.go @@ -61,16 +61,10 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio if err := r.withTxn(ctx, func(ctx context.Context) error { qb := r.repository.Studio - if err := studio.EnsureStudioNameUnique(ctx, 0, newStudio.Name, qb); err != nil { + if err := studio.ValidateCreate(ctx, newStudio, qb); err != nil { return err } - if len(input.Aliases) > 0 { - if err := studio.EnsureAliasesUnique(ctx, 0, input.Aliases, qb); err != nil { - return err - } - } - err = qb.Create(ctx, &newStudio) if err != nil { return err diff --git a/internal/manager/task_stash_box_tag.go b/internal/manager/task_stash_box_tag.go index 9d5d637c4..298b58e27 100644 --- a/internal/manager/task_stash_box_tag.go +++ b/internal/manager/task_stash_box_tag.go @@ -355,6 +355,10 @@ func (t *StashBoxBatchTagTask) processMatchedStudio(ctx context.Context, s *mode err = r.WithTxn(ctx, func(ctx context.Context) error { qb := r.Studio + if err := studio.ValidateCreate(ctx, *newStudio, qb); err != nil { + return err + } + if err := qb.Create(ctx, newStudio); err != nil { return err } diff --git a/pkg/studio/update.go b/pkg/studio/validate.go similarity index 80% rename from pkg/studio/update.go rename to pkg/studio/validate.go index 3125e674e..8a8676351 100644 --- a/pkg/studio/update.go +++ b/pkg/studio/validate.go @@ -9,6 +9,7 @@ import ( ) var ( + ErrNameMissing = errors.New("studio name must not be blank") ErrStudioOwnAncestor = errors.New("studio cannot be an ancestor of itself") ) @@ -70,6 +71,32 @@ func EnsureAliasesUnique(ctx context.Context, id int, aliases []string, qb model return nil } +func ValidateCreate(ctx context.Context, studio models.Studio, qb models.StudioQueryer) error { + if err := validateName(ctx, 0, studio.Name, qb); err != nil { + return err + } + + if studio.Aliases.Loaded() && len(studio.Aliases.List()) > 0 { + if err := EnsureAliasesUnique(ctx, 0, studio.Aliases.List(), qb); err != nil { + return err + } + } + + return nil +} + +func validateName(ctx context.Context, studioID int, name string, qb models.StudioQueryer) error { + if name == "" { + return ErrNameMissing + } + + if err := EnsureStudioNameUnique(ctx, studioID, name, qb); err != nil { + return err + } + + return nil +} + type ValidateModifyReader interface { models.StudioGetter models.StudioQueryer @@ -110,7 +137,7 @@ func ValidateModify(ctx context.Context, s models.StudioPartial, qb ValidateModi } if s.Name.Set && s.Name.Value != existing.Name { - if err := EnsureStudioNameUnique(ctx, 0, s.Name.Value, qb); err != nil { + if err := validateName(ctx, s.ID, s.Name.Value, qb); err != nil { return err } } diff --git a/pkg/studio/validate_test.go b/pkg/studio/validate_test.go new file mode 100644 index 000000000..6562dc5ca --- /dev/null +++ b/pkg/studio/validate_test.go @@ -0,0 +1,104 @@ +package studio + +import ( + "testing" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/models/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func nameFilter(n string) *models.StudioFilterType { + return &models.StudioFilterType{ + Name: &models.StringCriterionInput{ + Value: n, + Modifier: models.CriterionModifierEquals, + }, + } +} + +func TestValidateName(t *testing.T) { + db := mocks.NewDatabase() + + const ( + name1 = "name 1" + newName = "new name" + ) + + existing1 := models.Studio{ + ID: 1, + Name: name1, + } + + pp := 1 + findFilter := &models.FindFilterType{ + PerPage: &pp, + } + + db.Studio.On("Query", testCtx, nameFilter(name1), findFilter).Return([]*models.Studio{&existing1}, 1, nil) + db.Studio.On("Query", testCtx, mock.Anything, findFilter).Return(nil, 0, nil) + + tests := []struct { + tName string + name string + want error + }{ + {"missing name", "", ErrNameMissing}, + {"new name", newName, nil}, + {"existing name", name1, &NameExistsError{name1}}, + } + + for _, tt := range tests { + t.Run(tt.tName, func(t *testing.T) { + got := validateName(testCtx, 0, tt.name, db.Studio) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestValidateUpdateName(t *testing.T) { + db := mocks.NewDatabase() + + const ( + name1 = "name 1" + name2 = "name 2" + newName = "new name" + ) + + existing1 := models.Studio{ + ID: 1, + Name: name1, + } + existing2 := models.Studio{ + ID: 2, + Name: name2, + } + + pp := 1 + findFilter := &models.FindFilterType{ + PerPage: &pp, + } + + db.Studio.On("Query", testCtx, nameFilter(name1), findFilter).Return([]*models.Studio{&existing1}, 1, nil) + db.Studio.On("Query", testCtx, nameFilter(name2), findFilter).Return([]*models.Studio{&existing2}, 2, nil) + db.Studio.On("Query", testCtx, mock.Anything, findFilter).Return(nil, 0, nil) + + tests := []struct { + tName string + studio models.Studio + name string + want error + }{ + {"missing name", existing1, "", ErrNameMissing}, + {"same name", existing2, name2, nil}, + {"new name", existing1, newName, nil}, + } + + for _, tt := range tests { + t.Run(tt.tName, func(t *testing.T) { + got := validateName(testCtx, tt.studio.ID, tt.name, db.Studio) + assert.Equal(t, tt.want, got) + }) + } +}