diff --git a/internal/identify/studio.go b/internal/identify/studio.go index d05967bc4..51bcaf2ee 100644 --- a/internal/identify/studio.go +++ b/internal/identify/studio.go @@ -46,17 +46,17 @@ func createMissingStudio(ctx context.Context, endpoint string, w models.StudioRe return nil, err } - studioPartial := s.Parent.ToPartial(s.Parent.StoredID, endpoint, nil, existingStashIDs) + studioPartial := s.Parent.ToPartial(*s.Parent.StoredID, endpoint, nil, existingStashIDs) parentImage, err := s.Parent.GetImage(ctx, nil) if err != nil { return nil, err } - if err := studio.ValidateModify(ctx, *studioPartial, w); err != nil { + if err := studio.ValidateModify(ctx, studioPartial, w); err != nil { return nil, err } - _, err = w.UpdatePartial(ctx, *studioPartial) + _, err = w.UpdatePartial(ctx, studioPartial) if err != nil { return nil, err } diff --git a/internal/manager/task_stash_box_tag.go b/internal/manager/task_stash_box_tag.go index 298b58e27..8bb399601 100644 --- a/internal/manager/task_stash_box_tag.go +++ b/internal/manager/task_stash_box_tag.go @@ -311,13 +311,13 @@ func (t *StashBoxBatchTagTask) processMatchedStudio(ctx context.Context, s *mode return err } - partial := s.ToPartial(s.StoredID, t.box.Endpoint, excluded, existingStashIDs) + partial := s.ToPartial(*s.StoredID, t.box.Endpoint, excluded, existingStashIDs) - if err := studio.ValidateModify(ctx, *partial, qb); err != nil { + if err := studio.ValidateModify(ctx, partial, qb); err != nil { return err } - if _, err := qb.UpdatePartial(ctx, *partial); err != nil { + if _, err := qb.UpdatePartial(ctx, partial); err != nil { return err } @@ -435,13 +435,13 @@ func (t *StashBoxBatchTagTask) processParentStudio(ctx context.Context, parent * return err } - partial := parent.ToPartial(parent.StoredID, t.box.Endpoint, excluded, existingStashIDs) + partial := parent.ToPartial(*parent.StoredID, t.box.Endpoint, excluded, existingStashIDs) - if err := studio.ValidateModify(ctx, *partial, qb); err != nil { + if err := studio.ValidateModify(ctx, partial, qb); err != nil { return err } - if _, err := qb.UpdatePartial(ctx, *partial); err != nil { + if _, err := qb.UpdatePartial(ctx, partial); err != nil { return err } diff --git a/pkg/models/model_scraped_item.go b/pkg/models/model_scraped_item.go index 206f1109b..84c69d7e4 100644 --- a/pkg/models/model_scraped_item.go +++ b/pkg/models/model_scraped_item.go @@ -62,9 +62,9 @@ func (s *ScrapedStudio) GetImage(ctx context.Context, excluded map[string]bool) return nil, nil } -func (s *ScrapedStudio) ToPartial(id *string, endpoint string, excluded map[string]bool, existingStashIDs []StashID) *StudioPartial { +func (s *ScrapedStudio) ToPartial(id string, endpoint string, excluded map[string]bool, existingStashIDs []StashID) StudioPartial { ret := NewStudioPartial() - ret.ID, _ = strconv.Atoi(*id) + ret.ID, _ = strconv.Atoi(id) if s.Name != "" && !excluded["name"] { ret.Name = NewOptionalString(s.Name) @@ -82,8 +82,6 @@ func (s *ScrapedStudio) ToPartial(id *string, endpoint string, excluded map[stri ret.ParentID = NewOptionalInt(parentID) } } - } else { - ret.ParentID = NewOptionalIntPtr(nil) } if s.RemoteSiteID != nil && endpoint != "" { @@ -97,7 +95,7 @@ func (s *ScrapedStudio) ToPartial(id *string, endpoint string, excluded map[stri }) } - return &ret + return ret } // A performer from a scraping operation... diff --git a/pkg/models/model_scraped_item_test.go b/pkg/models/model_scraped_item_test.go index 50657188d..87ce2ad57 100644 --- a/pkg/models/model_scraped_item_test.go +++ b/pkg/models/model_scraped_item_test.go @@ -247,3 +247,123 @@ func Test_scrapedToPerformerInput(t *testing.T) { }) } } + +func TestScrapedStudio_ToPartial(t *testing.T) { + var ( + id = 1000 + idStr = strconv.Itoa(id) + storedID = "storedID" + parentStoredID = 2000 + parentStoredIDStr = strconv.Itoa(parentStoredID) + name = "name" + url = "url" + remoteSiteID = "remoteSiteID" + endpoint = "endpoint" + image = "image" + images = []string{image} + + existingEndpoint = "existingEndpoint" + existingStashID = StashID{"existingStashID", existingEndpoint} + existingStashIDs = []StashID{existingStashID} + ) + + fullStudio := ScrapedStudio{ + StoredID: &storedID, + Name: name, + URL: &url, + Parent: &ScrapedStudio{ + StoredID: &parentStoredIDStr, + }, + Image: &image, + Images: images, + RemoteSiteID: &remoteSiteID, + } + + type args struct { + id string + endpoint string + excluded map[string]bool + existingStashIDs []StashID + } + + stdArgs := args{ + id: idStr, + endpoint: endpoint, + excluded: map[string]bool{}, + existingStashIDs: existingStashIDs, + } + + excludeAll := map[string]bool{ + "name": true, + "url": true, + "parent": true, + } + + tests := []struct { + name string + o ScrapedStudio + args args + want StudioPartial + }{ + { + "full no exclusions", + fullStudio, + stdArgs, + StudioPartial{ + ID: id, + Name: NewOptionalString(name), + URL: NewOptionalString(url), + ParentID: NewOptionalInt(parentStoredID), + StashIDs: &UpdateStashIDs{ + StashIDs: append(existingStashIDs, StashID{ + Endpoint: endpoint, + StashID: remoteSiteID, + }), + Mode: RelationshipUpdateModeSet, + }, + }, + }, + { + "exclude all", + fullStudio, + args{ + id: idStr, + excluded: excludeAll, + }, + StudioPartial{ + ID: id, + }, + }, + { + "overwrite stash id", + fullStudio, + args{ + id: idStr, + excluded: excludeAll, + endpoint: existingEndpoint, + existingStashIDs: existingStashIDs, + }, + StudioPartial{ + ID: id, + StashIDs: &UpdateStashIDs{ + StashIDs: []StashID{{ + Endpoint: existingEndpoint, + StashID: remoteSiteID, + }}, + Mode: RelationshipUpdateModeSet, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := tt.o + got := s.ToPartial(tt.args.id, tt.args.endpoint, tt.args.excluded, tt.args.existingStashIDs) + + // unset updatedAt - we don't need to compare it + got.UpdatedAt = OptionalTime{} + + assert.Equal(t, tt.want, got) + }) + } +}