diff --git a/internal/api/resolver.go b/internal/api/resolver.go index 5880f1e3c..e6289a218 100644 --- a/internal/api/resolver.go +++ b/internal/api/resolver.go @@ -11,6 +11,7 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/scraper" + "github.com/stashapp/stash/pkg/txn" ) var ( @@ -30,7 +31,9 @@ type hookExecutor interface { } type Resolver struct { - txnManager models.TransactionManager + txnManager txn.Manager + repository models.Repository + hookExecutor hookExecutor } @@ -85,17 +88,13 @@ type studioResolver struct{ *Resolver } type movieResolver struct{ *Resolver } type tagResolver struct{ *Resolver } -func (r *Resolver) withTxn(ctx context.Context, fn func(r models.Repository) error) error { - return r.txnManager.WithTxn(ctx, fn) -} - -func (r *Resolver) withReadTxn(ctx context.Context, fn func(r models.ReaderRepository) error) error { - return r.txnManager.WithReadTxn(ctx, fn) +func (r *Resolver) withTxn(ctx context.Context, fn func(ctx context.Context) error) error { + return txn.WithTxn(ctx, r.txnManager, fn) } func (r *queryResolver) MarkerWall(ctx context.Context, q *string) (ret []*models.SceneMarker, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.SceneMarker().Wall(q) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.SceneMarker.Wall(ctx, q) return err }); err != nil { return nil, err @@ -104,8 +103,8 @@ func (r *queryResolver) MarkerWall(ctx context.Context, q *string) (ret []*model } func (r *queryResolver) SceneWall(ctx context.Context, q *string) (ret []*models.Scene, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Scene().Wall(q) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Scene.Wall(ctx, q) return err }); err != nil { return nil, err @@ -115,8 +114,8 @@ func (r *queryResolver) SceneWall(ctx context.Context, q *string) (ret []*models } func (r *queryResolver) MarkerStrings(ctx context.Context, q *string, sort *string) (ret []*models.MarkerStringsResultType, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.SceneMarker().GetMarkerStrings(q, sort) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.SceneMarker.GetMarkerStrings(ctx, q, sort) return err }); err != nil { return nil, err @@ -127,24 +126,25 @@ func (r *queryResolver) MarkerStrings(ctx context.Context, q *string, sort *stri func (r *queryResolver) Stats(ctx context.Context) (*StatsResultType, error) { var ret StatsResultType - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - scenesQB := repo.Scene() - imageQB := repo.Image() - galleryQB := repo.Gallery() - studiosQB := repo.Studio() - performersQB := repo.Performer() - moviesQB := repo.Movie() - tagsQB := repo.Tag() - scenesCount, _ := scenesQB.Count() - scenesSize, _ := scenesQB.Size() - scenesDuration, _ := scenesQB.Duration() - imageCount, _ := imageQB.Count() - imageSize, _ := imageQB.Size() - galleryCount, _ := galleryQB.Count() - performersCount, _ := performersQB.Count() - studiosCount, _ := studiosQB.Count() - moviesCount, _ := moviesQB.Count() - tagsCount, _ := tagsQB.Count() + if err := r.withTxn(ctx, func(ctx context.Context) error { + repo := r.repository + scenesQB := repo.Scene + imageQB := repo.Image + galleryQB := repo.Gallery + studiosQB := repo.Studio + performersQB := repo.Performer + moviesQB := repo.Movie + tagsQB := repo.Tag + scenesCount, _ := scenesQB.Count(ctx) + scenesSize, _ := scenesQB.Size(ctx) + scenesDuration, _ := scenesQB.Duration(ctx) + imageCount, _ := imageQB.Count(ctx) + imageSize, _ := imageQB.Size(ctx) + galleryCount, _ := galleryQB.Count(ctx) + performersCount, _ := performersQB.Count(ctx) + studiosCount, _ := studiosQB.Count(ctx) + moviesCount, _ := moviesQB.Count(ctx) + tagsCount, _ := tagsQB.Count(ctx) ret = StatsResultType{ SceneCount: scenesCount, @@ -202,15 +202,15 @@ func (r *queryResolver) SceneMarkerTags(ctx context.Context, scene_id string) ([ var keys []int tags := make(map[int]*SceneMarkerTag) - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - sceneMarkers, err := repo.SceneMarker().FindBySceneID(sceneID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + sceneMarkers, err := r.repository.SceneMarker.FindBySceneID(ctx, sceneID) if err != nil { return err } - tqb := repo.Tag() + tqb := r.repository.Tag for _, sceneMarker := range sceneMarkers { - markerPrimaryTag, err := tqb.Find(sceneMarker.PrimaryTagID) + markerPrimaryTag, err := tqb.Find(ctx, sceneMarker.PrimaryTagID) if err != nil { return err } diff --git a/internal/api/resolver_model_gallery.go b/internal/api/resolver_model_gallery.go index 911cb8fe1..8e1d98dd4 100644 --- a/internal/api/resolver_model_gallery.go +++ b/internal/api/resolver_model_gallery.go @@ -24,12 +24,12 @@ func (r *galleryResolver) Title(ctx context.Context, obj *models.Gallery) (*stri } func (r *galleryResolver) Images(ctx context.Context, obj *models.Gallery) (ret []*models.Image, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + if err := r.withTxn(ctx, func(ctx context.Context) error { var err error // #2376 - sort images by path // doing this via Query is really slow, so stick with FindByGalleryID - ret, err = repo.Image().FindByGalleryID(obj.ID) + ret, err = r.repository.Image.FindByGalleryID(ctx, obj.ID) if err != nil { return err } @@ -43,9 +43,9 @@ func (r *galleryResolver) Images(ctx context.Context, obj *models.Gallery) (ret } func (r *galleryResolver) Cover(ctx context.Context, obj *models.Gallery) (ret *models.Image, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + if err := r.withTxn(ctx, func(ctx context.Context) error { // doing this via Query is really slow, so stick with FindByGalleryID - imgs, err := repo.Image().FindByGalleryID(obj.ID) + imgs, err := r.repository.Image.FindByGalleryID(ctx, obj.ID) if err != nil { return err } @@ -100,9 +100,9 @@ func (r *galleryResolver) Rating(ctx context.Context, obj *models.Gallery) (*int } func (r *galleryResolver) Scenes(ctx context.Context, obj *models.Gallery) (ret []*models.Scene, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + if err := r.withTxn(ctx, func(ctx context.Context) error { var err error - ret, err = repo.Scene().FindByGalleryID(obj.ID) + ret, err = r.repository.Scene.FindByGalleryID(ctx, obj.ID) return err }); err != nil { return nil, err @@ -116,9 +116,9 @@ func (r *galleryResolver) Studio(ctx context.Context, obj *models.Gallery) (ret return nil, nil } - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + if err := r.withTxn(ctx, func(ctx context.Context) error { var err error - ret, err = repo.Studio().Find(int(obj.StudioID.Int64)) + ret, err = r.repository.Studio.Find(ctx, int(obj.StudioID.Int64)) return err }); err != nil { return nil, err @@ -128,9 +128,9 @@ func (r *galleryResolver) Studio(ctx context.Context, obj *models.Gallery) (ret } func (r *galleryResolver) Tags(ctx context.Context, obj *models.Gallery) (ret []*models.Tag, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + if err := r.withTxn(ctx, func(ctx context.Context) error { var err error - ret, err = repo.Tag().FindByGalleryID(obj.ID) + ret, err = r.repository.Tag.FindByGalleryID(ctx, obj.ID) return err }); err != nil { return nil, err @@ -140,9 +140,9 @@ func (r *galleryResolver) Tags(ctx context.Context, obj *models.Gallery) (ret [] } func (r *galleryResolver) Performers(ctx context.Context, obj *models.Gallery) (ret []*models.Performer, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + if err := r.withTxn(ctx, func(ctx context.Context) error { var err error - ret, err = repo.Performer().FindByGalleryID(obj.ID) + ret, err = r.repository.Performer.FindByGalleryID(ctx, obj.ID) return err }); err != nil { return nil, err @@ -152,9 +152,9 @@ func (r *galleryResolver) Performers(ctx context.Context, obj *models.Gallery) ( } func (r *galleryResolver) ImageCount(ctx context.Context, obj *models.Gallery) (ret int, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + if err := r.withTxn(ctx, func(ctx context.Context) error { var err error - ret, err = repo.Image().CountByGalleryID(obj.ID) + ret, err = r.repository.Image.CountByGalleryID(ctx, obj.ID) return err }); err != nil { return 0, err diff --git a/internal/api/resolver_model_image.go b/internal/api/resolver_model_image.go index 0f47b6c3e..8c7222b97 100644 --- a/internal/api/resolver_model_image.go +++ b/internal/api/resolver_model_image.go @@ -45,9 +45,9 @@ func (r *imageResolver) Paths(ctx context.Context, obj *models.Image) (*ImagePat } func (r *imageResolver) Galleries(ctx context.Context, obj *models.Image) (ret []*models.Gallery, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + if err := r.withTxn(ctx, func(ctx context.Context) error { var err error - ret, err = repo.Gallery().FindByImageID(obj.ID) + ret, err = r.repository.Gallery.FindByImageID(ctx, obj.ID) return err }); err != nil { return nil, err @@ -61,8 +61,8 @@ func (r *imageResolver) Studio(ctx context.Context, obj *models.Image) (ret *mod return nil, nil } - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Studio().Find(int(obj.StudioID.Int64)) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Studio.Find(ctx, int(obj.StudioID.Int64)) return err }); err != nil { return nil, err @@ -72,8 +72,8 @@ func (r *imageResolver) Studio(ctx context.Context, obj *models.Image) (ret *mod } func (r *imageResolver) Tags(ctx context.Context, obj *models.Image) (ret []*models.Tag, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Tag().FindByImageID(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Tag.FindByImageID(ctx, obj.ID) return err }); err != nil { return nil, err @@ -83,8 +83,8 @@ func (r *imageResolver) Tags(ctx context.Context, obj *models.Image) (ret []*mod } func (r *imageResolver) Performers(ctx context.Context, obj *models.Image) (ret []*models.Performer, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Performer().FindByImageID(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Performer.FindByImageID(ctx, obj.ID) return err }); err != nil { return nil, err diff --git a/internal/api/resolver_model_movie.go b/internal/api/resolver_model_movie.go index 57e979828..9085eee82 100644 --- a/internal/api/resolver_model_movie.go +++ b/internal/api/resolver_model_movie.go @@ -56,8 +56,8 @@ func (r *movieResolver) Rating(ctx context.Context, obj *models.Movie) (*int, er func (r *movieResolver) Studio(ctx context.Context, obj *models.Movie) (ret *models.Studio, err error) { if obj.StudioID.Valid { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Studio().Find(int(obj.StudioID.Int64)) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Studio.Find(ctx, int(obj.StudioID.Int64)) return err }); err != nil { return nil, err @@ -92,9 +92,9 @@ func (r *movieResolver) FrontImagePath(ctx context.Context, obj *models.Movie) ( func (r *movieResolver) BackImagePath(ctx context.Context, obj *models.Movie) (*string, error) { // don't return any thing if there is no back image var img []byte - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + if err := r.withTxn(ctx, func(ctx context.Context) error { var err error - img, err = repo.Movie().GetBackImage(obj.ID) + img, err = r.repository.Movie.GetBackImage(ctx, obj.ID) if err != nil { return err } @@ -115,8 +115,8 @@ func (r *movieResolver) BackImagePath(ctx context.Context, obj *models.Movie) (* func (r *movieResolver) SceneCount(ctx context.Context, obj *models.Movie) (ret *int, err error) { var res int - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - res, err = repo.Scene().CountByMovieID(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + res, err = r.repository.Scene.CountByMovieID(ctx, obj.ID) return err }); err != nil { return nil, err @@ -126,9 +126,9 @@ func (r *movieResolver) SceneCount(ctx context.Context, obj *models.Movie) (ret } func (r *movieResolver) Scenes(ctx context.Context, obj *models.Movie) (ret []*models.Scene, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + if err := r.withTxn(ctx, func(ctx context.Context) error { var err error - ret, err = repo.Scene().FindByMovieID(obj.ID) + ret, err = r.repository.Scene.FindByMovieID(ctx, obj.ID) return err }); err != nil { return nil, err diff --git a/internal/api/resolver_model_performer.go b/internal/api/resolver_model_performer.go index 364c346ca..3a44a1cc0 100644 --- a/internal/api/resolver_model_performer.go +++ b/internal/api/resolver_model_performer.go @@ -142,8 +142,8 @@ func (r *performerResolver) ImagePath(ctx context.Context, obj *models.Performer } func (r *performerResolver) Tags(ctx context.Context, obj *models.Performer) (ret []*models.Tag, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Tag().FindByPerformerID(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Tag.FindByPerformerID(ctx, obj.ID) return err }); err != nil { return nil, err @@ -154,8 +154,8 @@ func (r *performerResolver) Tags(ctx context.Context, obj *models.Performer) (re func (r *performerResolver) SceneCount(ctx context.Context, obj *models.Performer) (ret *int, err error) { var res int - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - res, err = repo.Scene().CountByPerformerID(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + res, err = r.repository.Scene.CountByPerformerID(ctx, obj.ID) return err }); err != nil { return nil, err @@ -166,8 +166,8 @@ func (r *performerResolver) SceneCount(ctx context.Context, obj *models.Performe func (r *performerResolver) ImageCount(ctx context.Context, obj *models.Performer) (ret *int, err error) { var res int - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - res, err = image.CountByPerformerID(repo.Image(), obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + res, err = image.CountByPerformerID(ctx, r.repository.Image, obj.ID) return err }); err != nil { return nil, err @@ -178,8 +178,8 @@ func (r *performerResolver) ImageCount(ctx context.Context, obj *models.Performe func (r *performerResolver) GalleryCount(ctx context.Context, obj *models.Performer) (ret *int, err error) { var res int - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - res, err = gallery.CountByPerformerID(repo.Gallery(), obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + res, err = gallery.CountByPerformerID(ctx, r.repository.Gallery, obj.ID) return err }); err != nil { return nil, err @@ -189,8 +189,8 @@ func (r *performerResolver) GalleryCount(ctx context.Context, obj *models.Perfor } func (r *performerResolver) Scenes(ctx context.Context, obj *models.Performer) (ret []*models.Scene, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Scene().FindByPerformerID(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Scene.FindByPerformerID(ctx, obj.ID) return err }); err != nil { return nil, err @@ -200,8 +200,8 @@ func (r *performerResolver) Scenes(ctx context.Context, obj *models.Performer) ( } func (r *performerResolver) StashIds(ctx context.Context, obj *models.Performer) (ret []*models.StashID, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Performer().GetStashIDs(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Performer.GetStashIDs(ctx, obj.ID) return err }); err != nil { return nil, err @@ -256,8 +256,8 @@ func (r *performerResolver) UpdatedAt(ctx context.Context, obj *models.Performer } func (r *performerResolver) Movies(ctx context.Context, obj *models.Performer) (ret []*models.Movie, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Movie().FindByPerformerID(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Movie.FindByPerformerID(ctx, obj.ID) return err }); err != nil { return nil, err @@ -268,8 +268,8 @@ func (r *performerResolver) Movies(ctx context.Context, obj *models.Performer) ( func (r *performerResolver) MovieCount(ctx context.Context, obj *models.Performer) (ret *int, err error) { var res int - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - res, err = repo.Movie().CountByPerformerID(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + res, err = r.repository.Movie.CountByPerformerID(ctx, obj.ID) return err }); err != nil { return nil, err diff --git a/internal/api/resolver_model_scene.go b/internal/api/resolver_model_scene.go index 0b81beaa5..d9c783ac8 100644 --- a/internal/api/resolver_model_scene.go +++ b/internal/api/resolver_model_scene.go @@ -116,8 +116,8 @@ func (r *sceneResolver) Paths(ctx context.Context, obj *models.Scene) (*ScenePat } func (r *sceneResolver) SceneMarkers(ctx context.Context, obj *models.Scene) (ret []*models.SceneMarker, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.SceneMarker().FindBySceneID(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.SceneMarker.FindBySceneID(ctx, obj.ID) return err }); err != nil { return nil, err @@ -127,8 +127,8 @@ func (r *sceneResolver) SceneMarkers(ctx context.Context, obj *models.Scene) (re } func (r *sceneResolver) Captions(ctx context.Context, obj *models.Scene) (ret []*models.SceneCaption, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Scene().GetCaptions(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Scene.GetCaptions(ctx, obj.ID) return err }); err != nil { return nil, err @@ -138,8 +138,8 @@ func (r *sceneResolver) Captions(ctx context.Context, obj *models.Scene) (ret [] } func (r *sceneResolver) Galleries(ctx context.Context, obj *models.Scene) (ret []*models.Gallery, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Gallery().FindBySceneID(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Gallery.FindBySceneID(ctx, obj.ID) return err }); err != nil { return nil, err @@ -153,8 +153,8 @@ func (r *sceneResolver) Studio(ctx context.Context, obj *models.Scene) (ret *mod return nil, nil } - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Studio().Find(int(obj.StudioID.Int64)) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Studio.Find(ctx, int(obj.StudioID.Int64)) return err }); err != nil { return nil, err @@ -164,17 +164,17 @@ func (r *sceneResolver) Studio(ctx context.Context, obj *models.Scene) (ret *mod } func (r *sceneResolver) Movies(ctx context.Context, obj *models.Scene) (ret []*SceneMovie, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - qb := repo.Scene() - mqb := repo.Movie() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Scene + mqb := r.repository.Movie - sceneMovies, err := qb.GetMovies(obj.ID) + sceneMovies, err := qb.GetMovies(ctx, obj.ID) if err != nil { return err } for _, sm := range sceneMovies { - movie, err := mqb.Find(sm.MovieID) + movie, err := mqb.Find(ctx, sm.MovieID) if err != nil { return err } @@ -200,8 +200,8 @@ func (r *sceneResolver) Movies(ctx context.Context, obj *models.Scene) (ret []*S } func (r *sceneResolver) Tags(ctx context.Context, obj *models.Scene) (ret []*models.Tag, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Tag().FindBySceneID(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Tag.FindBySceneID(ctx, obj.ID) return err }); err != nil { return nil, err @@ -211,8 +211,8 @@ func (r *sceneResolver) Tags(ctx context.Context, obj *models.Scene) (ret []*mod } func (r *sceneResolver) Performers(ctx context.Context, obj *models.Scene) (ret []*models.Performer, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Performer().FindBySceneID(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Performer.FindBySceneID(ctx, obj.ID) return err }); err != nil { return nil, err @@ -222,8 +222,8 @@ func (r *sceneResolver) Performers(ctx context.Context, obj *models.Scene) (ret } func (r *sceneResolver) StashIds(ctx context.Context, obj *models.Scene) (ret []*models.StashID, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Scene().GetStashIDs(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Scene.GetStashIDs(ctx, obj.ID) return err }); err != nil { return nil, err diff --git a/internal/api/resolver_model_scene_marker.go b/internal/api/resolver_model_scene_marker.go index 64d418bd1..7a4d01be1 100644 --- a/internal/api/resolver_model_scene_marker.go +++ b/internal/api/resolver_model_scene_marker.go @@ -13,9 +13,9 @@ func (r *sceneMarkerResolver) Scene(ctx context.Context, obj *models.SceneMarker panic("Invalid scene id") } - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + if err := r.withTxn(ctx, func(ctx context.Context) error { sceneID := int(obj.SceneID.Int64) - ret, err = repo.Scene().Find(sceneID) + ret, err = r.repository.Scene.Find(ctx, sceneID) return err }); err != nil { return nil, err @@ -25,8 +25,8 @@ func (r *sceneMarkerResolver) Scene(ctx context.Context, obj *models.SceneMarker } func (r *sceneMarkerResolver) PrimaryTag(ctx context.Context, obj *models.SceneMarker) (ret *models.Tag, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Tag().Find(obj.PrimaryTagID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Tag.Find(ctx, obj.PrimaryTagID) return err }); err != nil { return nil, err @@ -36,8 +36,8 @@ func (r *sceneMarkerResolver) PrimaryTag(ctx context.Context, obj *models.SceneM } func (r *sceneMarkerResolver) Tags(ctx context.Context, obj *models.SceneMarker) (ret []*models.Tag, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Tag().FindBySceneMarkerID(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Tag.FindBySceneMarkerID(ctx, obj.ID) return err }); err != nil { return nil, err diff --git a/internal/api/resolver_model_studio.go b/internal/api/resolver_model_studio.go index d2b6b44c1..a7fc56442 100644 --- a/internal/api/resolver_model_studio.go +++ b/internal/api/resolver_model_studio.go @@ -29,9 +29,9 @@ func (r *studioResolver) ImagePath(ctx context.Context, obj *models.Studio) (*st imagePath := urlbuilders.NewStudioURLBuilder(baseURL, obj).GetStudioImageURL() var hasImage bool - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + if err := r.withTxn(ctx, func(ctx context.Context) error { var err error - hasImage, err = repo.Studio().HasImage(obj.ID) + hasImage, err = r.repository.Studio.HasImage(ctx, obj.ID) return err }); err != nil { return nil, err @@ -46,8 +46,8 @@ func (r *studioResolver) ImagePath(ctx context.Context, obj *models.Studio) (*st } func (r *studioResolver) Aliases(ctx context.Context, obj *models.Studio) (ret []string, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Studio().GetAliases(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Studio.GetAliases(ctx, obj.ID) return err }); err != nil { return nil, err @@ -58,8 +58,8 @@ func (r *studioResolver) Aliases(ctx context.Context, obj *models.Studio) (ret [ func (r *studioResolver) SceneCount(ctx context.Context, obj *models.Studio) (ret *int, err error) { var res int - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - res, err = repo.Scene().CountByStudioID(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + res, err = r.repository.Scene.CountByStudioID(ctx, obj.ID) return err }); err != nil { return nil, err @@ -70,8 +70,8 @@ func (r *studioResolver) SceneCount(ctx context.Context, obj *models.Studio) (re func (r *studioResolver) ImageCount(ctx context.Context, obj *models.Studio) (ret *int, err error) { var res int - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - res, err = image.CountByStudioID(repo.Image(), obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + res, err = image.CountByStudioID(ctx, r.repository.Image, obj.ID) return err }); err != nil { return nil, err @@ -82,8 +82,8 @@ func (r *studioResolver) ImageCount(ctx context.Context, obj *models.Studio) (re func (r *studioResolver) GalleryCount(ctx context.Context, obj *models.Studio) (ret *int, err error) { var res int - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - res, err = gallery.CountByStudioID(repo.Gallery(), obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + res, err = gallery.CountByStudioID(ctx, r.repository.Gallery, obj.ID) return err }); err != nil { return nil, err @@ -97,8 +97,8 @@ func (r *studioResolver) ParentStudio(ctx context.Context, obj *models.Studio) ( return nil, nil } - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Studio().Find(int(obj.ParentID.Int64)) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Studio.Find(ctx, int(obj.ParentID.Int64)) return err }); err != nil { return nil, err @@ -108,8 +108,8 @@ func (r *studioResolver) ParentStudio(ctx context.Context, obj *models.Studio) ( } func (r *studioResolver) ChildStudios(ctx context.Context, obj *models.Studio) (ret []*models.Studio, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Studio().FindChildren(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Studio.FindChildren(ctx, obj.ID) return err }); err != nil { return nil, err @@ -119,8 +119,8 @@ func (r *studioResolver) ChildStudios(ctx context.Context, obj *models.Studio) ( } func (r *studioResolver) StashIds(ctx context.Context, obj *models.Studio) (ret []*models.StashID, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Studio().GetStashIDs(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Studio.GetStashIDs(ctx, obj.ID) return err }); err != nil { return nil, err @@ -153,8 +153,8 @@ func (r *studioResolver) UpdatedAt(ctx context.Context, obj *models.Studio) (*ti } func (r *studioResolver) Movies(ctx context.Context, obj *models.Studio) (ret []*models.Movie, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Movie().FindByStudioID(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Movie.FindByStudioID(ctx, obj.ID) return err }); err != nil { return nil, err @@ -165,8 +165,8 @@ func (r *studioResolver) Movies(ctx context.Context, obj *models.Studio) (ret [] func (r *studioResolver) MovieCount(ctx context.Context, obj *models.Studio) (ret *int, err error) { var res int - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - res, err = repo.Movie().CountByStudioID(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + res, err = r.repository.Movie.CountByStudioID(ctx, obj.ID) return err }); err != nil { return nil, err diff --git a/internal/api/resolver_model_tag.go b/internal/api/resolver_model_tag.go index cb406e5fc..3592dd959 100644 --- a/internal/api/resolver_model_tag.go +++ b/internal/api/resolver_model_tag.go @@ -11,8 +11,8 @@ import ( ) func (r *tagResolver) Parents(ctx context.Context, obj *models.Tag) (ret []*models.Tag, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Tag().FindByChildTagID(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Tag.FindByChildTagID(ctx, obj.ID) return err }); err != nil { return nil, err @@ -22,8 +22,8 @@ func (r *tagResolver) Parents(ctx context.Context, obj *models.Tag) (ret []*mode } func (r *tagResolver) Children(ctx context.Context, obj *models.Tag) (ret []*models.Tag, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Tag().FindByParentTagID(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Tag.FindByParentTagID(ctx, obj.ID) return err }); err != nil { return nil, err @@ -33,8 +33,8 @@ func (r *tagResolver) Children(ctx context.Context, obj *models.Tag) (ret []*mod } func (r *tagResolver) Aliases(ctx context.Context, obj *models.Tag) (ret []string, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Tag().GetAliases(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Tag.GetAliases(ctx, obj.ID) return err }); err != nil { return nil, err @@ -45,8 +45,8 @@ func (r *tagResolver) Aliases(ctx context.Context, obj *models.Tag) (ret []strin func (r *tagResolver) SceneCount(ctx context.Context, obj *models.Tag) (ret *int, err error) { var count int - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - count, err = repo.Scene().CountByTagID(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + count, err = r.repository.Scene.CountByTagID(ctx, obj.ID) return err }); err != nil { return nil, err @@ -57,8 +57,8 @@ func (r *tagResolver) SceneCount(ctx context.Context, obj *models.Tag) (ret *int func (r *tagResolver) SceneMarkerCount(ctx context.Context, obj *models.Tag) (ret *int, err error) { var count int - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - count, err = repo.SceneMarker().CountByTagID(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + count, err = r.repository.SceneMarker.CountByTagID(ctx, obj.ID) return err }); err != nil { return nil, err @@ -69,8 +69,8 @@ func (r *tagResolver) SceneMarkerCount(ctx context.Context, obj *models.Tag) (re func (r *tagResolver) ImageCount(ctx context.Context, obj *models.Tag) (ret *int, err error) { var res int - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - res, err = image.CountByTagID(repo.Image(), obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + res, err = image.CountByTagID(ctx, r.repository.Image, obj.ID) return err }); err != nil { return nil, err @@ -81,8 +81,8 @@ func (r *tagResolver) ImageCount(ctx context.Context, obj *models.Tag) (ret *int func (r *tagResolver) GalleryCount(ctx context.Context, obj *models.Tag) (ret *int, err error) { var res int - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - res, err = gallery.CountByTagID(repo.Gallery(), obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + res, err = gallery.CountByTagID(ctx, r.repository.Gallery, obj.ID) return err }); err != nil { return nil, err @@ -93,8 +93,8 @@ func (r *tagResolver) GalleryCount(ctx context.Context, obj *models.Tag) (ret *i func (r *tagResolver) PerformerCount(ctx context.Context, obj *models.Tag) (ret *int, err error) { var count int - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - count, err = repo.Performer().CountByTagID(obj.ID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + count, err = r.repository.Performer.CountByTagID(ctx, obj.ID) return err }); err != nil { return nil, err diff --git a/internal/api/resolver_mutation_configure.go b/internal/api/resolver_mutation_configure.go index 48b1c6b9f..484903b54 100644 --- a/internal/api/resolver_mutation_configure.go +++ b/internal/api/resolver_mutation_configure.go @@ -132,7 +132,9 @@ func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input ConfigGen } // validate changing VideoFileNamingAlgorithm - if err := manager.ValidateVideoFileNamingAlgorithm(r.txnManager, *input.VideoFileNamingAlgorithm); err != nil { + if err := r.withTxn(context.TODO(), func(ctx context.Context) error { + return manager.ValidateVideoFileNamingAlgorithm(ctx, r.repository.Scene, *input.VideoFileNamingAlgorithm) + }); err != nil { return makeConfigGeneralResult(), err } diff --git a/internal/api/resolver_mutation_gallery.go b/internal/api/resolver_mutation_gallery.go index 816d6aba3..04a320365 100644 --- a/internal/api/resolver_mutation_gallery.go +++ b/internal/api/resolver_mutation_gallery.go @@ -11,6 +11,7 @@ import ( "github.com/stashapp/stash/internal/manager" "github.com/stashapp/stash/pkg/file" + "github.com/stashapp/stash/pkg/gallery" "github.com/stashapp/stash/pkg/hash/md5" "github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/models" @@ -21,8 +22,8 @@ import ( ) func (r *mutationResolver) getGallery(ctx context.Context, id int) (ret *models.Gallery, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Gallery().Find(id) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Gallery.Find(ctx, id) return err }); err != nil { return nil, err @@ -80,26 +81,26 @@ func (r *mutationResolver) GalleryCreate(ctx context.Context, input GalleryCreat // Start the transaction and save the gallery var gallery *models.Gallery - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Gallery() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Gallery var err error - gallery, err = qb.Create(newGallery) + gallery, err = qb.Create(ctx, newGallery) if err != nil { return err } // Save the performers - if err := r.updateGalleryPerformers(qb, gallery.ID, input.PerformerIds); err != nil { + if err := r.updateGalleryPerformers(ctx, qb, gallery.ID, input.PerformerIds); err != nil { return err } // Save the tags - if err := r.updateGalleryTags(qb, gallery.ID, input.TagIds); err != nil { + if err := r.updateGalleryTags(ctx, qb, gallery.ID, input.TagIds); err != nil { return err } // Save the scenes - if err := r.updateGalleryScenes(qb, gallery.ID, input.SceneIds); err != nil { + if err := r.updateGalleryScenes(ctx, qb, gallery.ID, input.SceneIds); err != nil { return err } @@ -112,28 +113,32 @@ func (r *mutationResolver) GalleryCreate(ctx context.Context, input GalleryCreat return r.getGallery(ctx, gallery.ID) } -func (r *mutationResolver) updateGalleryPerformers(qb models.GalleryReaderWriter, galleryID int, performerIDs []string) error { +func (r *mutationResolver) updateGalleryPerformers(ctx context.Context, qb gallery.PerformerUpdater, galleryID int, performerIDs []string) error { ids, err := stringslice.StringSliceToIntSlice(performerIDs) if err != nil { return err } - return qb.UpdatePerformers(galleryID, ids) + return qb.UpdatePerformers(ctx, galleryID, ids) } -func (r *mutationResolver) updateGalleryTags(qb models.GalleryReaderWriter, galleryID int, tagIDs []string) error { +func (r *mutationResolver) updateGalleryTags(ctx context.Context, qb gallery.TagUpdater, galleryID int, tagIDs []string) error { ids, err := stringslice.StringSliceToIntSlice(tagIDs) if err != nil { return err } - return qb.UpdateTags(galleryID, ids) + return qb.UpdateTags(ctx, galleryID, ids) } -func (r *mutationResolver) updateGalleryScenes(qb models.GalleryReaderWriter, galleryID int, sceneIDs []string) error { +type GallerySceneUpdater interface { + UpdateScenes(ctx context.Context, galleryID int, sceneIDs []int) error +} + +func (r *mutationResolver) updateGalleryScenes(ctx context.Context, qb GallerySceneUpdater, galleryID int, sceneIDs []string) error { ids, err := stringslice.StringSliceToIntSlice(sceneIDs) if err != nil { return err } - return qb.UpdateScenes(galleryID, ids) + return qb.UpdateScenes(ctx, galleryID, ids) } func (r *mutationResolver) GalleryUpdate(ctx context.Context, input models.GalleryUpdateInput) (ret *models.Gallery, err error) { @@ -142,8 +147,8 @@ func (r *mutationResolver) GalleryUpdate(ctx context.Context, input models.Galle } // Start the transaction and save the gallery - if err := r.withTxn(ctx, func(repo models.Repository) error { - ret, err = r.galleryUpdate(input, translator, repo) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.galleryUpdate(ctx, input, translator) return err }); err != nil { return nil, err @@ -158,13 +163,13 @@ func (r *mutationResolver) GalleriesUpdate(ctx context.Context, input []*models. inputMaps := getUpdateInputMaps(ctx) // Start the transaction and save the gallery - if err := r.withTxn(ctx, func(repo models.Repository) error { + if err := r.withTxn(ctx, func(ctx context.Context) error { for i, gallery := range input { translator := changesetTranslator{ inputMap: inputMaps[i], } - thisGallery, err := r.galleryUpdate(*gallery, translator, repo) + thisGallery, err := r.galleryUpdate(ctx, *gallery, translator) if err != nil { return err } @@ -196,8 +201,8 @@ func (r *mutationResolver) GalleriesUpdate(ctx context.Context, input []*models. return newRet, nil } -func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Gallery, error) { - qb := repo.Gallery() +func (r *mutationResolver) galleryUpdate(ctx context.Context, input models.GalleryUpdateInput, translator changesetTranslator) (*models.Gallery, error) { + qb := r.repository.Gallery // Populate gallery from the input galleryID, err := strconv.Atoi(input.ID) @@ -205,7 +210,7 @@ func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, transl return nil, err } - originalGallery, err := qb.Find(galleryID) + originalGallery, err := qb.Find(ctx, galleryID) if err != nil { return nil, err } @@ -244,28 +249,28 @@ func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, transl // gallery scene is set from the scene only - gallery, err := qb.UpdatePartial(updatedGallery) + gallery, err := qb.UpdatePartial(ctx, updatedGallery) if err != nil { return nil, err } // Save the performers if translator.hasField("performer_ids") { - if err := r.updateGalleryPerformers(qb, galleryID, input.PerformerIds); err != nil { + if err := r.updateGalleryPerformers(ctx, qb, galleryID, input.PerformerIds); err != nil { return nil, err } } // Save the tags if translator.hasField("tag_ids") { - if err := r.updateGalleryTags(qb, galleryID, input.TagIds); err != nil { + if err := r.updateGalleryTags(ctx, qb, galleryID, input.TagIds); err != nil { return nil, err } } // Save the scenes if translator.hasField("scene_ids") { - if err := r.updateGalleryScenes(qb, galleryID, input.SceneIds); err != nil { + if err := r.updateGalleryScenes(ctx, qb, galleryID, input.SceneIds); err != nil { return nil, err } } @@ -295,14 +300,14 @@ func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input BulkGall ret := []*models.Gallery{} // Start the transaction and save the galleries - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Gallery() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Gallery for _, galleryIDStr := range input.Ids { galleryID, _ := strconv.Atoi(galleryIDStr) updatedGallery.ID = galleryID - gallery, err := qb.UpdatePartial(updatedGallery) + gallery, err := qb.UpdatePartial(ctx, updatedGallery) if err != nil { return err } @@ -311,36 +316,36 @@ func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input BulkGall // Save the performers if translator.hasField("performer_ids") { - performerIDs, err := adjustGalleryPerformerIDs(qb, galleryID, *input.PerformerIds) + performerIDs, err := adjustGalleryPerformerIDs(ctx, qb, galleryID, *input.PerformerIds) if err != nil { return err } - if err := qb.UpdatePerformers(galleryID, performerIDs); err != nil { + if err := qb.UpdatePerformers(ctx, galleryID, performerIDs); err != nil { return err } } // Save the tags if translator.hasField("tag_ids") { - tagIDs, err := adjustGalleryTagIDs(qb, galleryID, *input.TagIds) + tagIDs, err := adjustGalleryTagIDs(ctx, qb, galleryID, *input.TagIds) if err != nil { return err } - if err := qb.UpdateTags(galleryID, tagIDs); err != nil { + if err := qb.UpdateTags(ctx, galleryID, tagIDs); err != nil { return err } } // Save the scenes if translator.hasField("scene_ids") { - sceneIDs, err := adjustGallerySceneIDs(qb, galleryID, *input.SceneIds) + sceneIDs, err := adjustGallerySceneIDs(ctx, qb, galleryID, *input.SceneIds) if err != nil { return err } - if err := qb.UpdateScenes(galleryID, sceneIDs); err != nil { + if err := qb.UpdateScenes(ctx, galleryID, sceneIDs); err != nil { return err } } @@ -367,8 +372,20 @@ func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input BulkGall return newRet, nil } -func adjustGalleryPerformerIDs(qb models.GalleryReader, galleryID int, ids BulkUpdateIds) (ret []int, err error) { - ret, err = qb.GetPerformerIDs(galleryID) +type GalleryPerformerGetter interface { + GetPerformerIDs(ctx context.Context, galleryID int) ([]int, error) +} + +type GalleryTagGetter interface { + GetTagIDs(ctx context.Context, galleryID int) ([]int, error) +} + +type GallerySceneGetter interface { + GetSceneIDs(ctx context.Context, galleryID int) ([]int, error) +} + +func adjustGalleryPerformerIDs(ctx context.Context, qb GalleryPerformerGetter, galleryID int, ids BulkUpdateIds) (ret []int, err error) { + ret, err = qb.GetPerformerIDs(ctx, galleryID) if err != nil { return nil, err } @@ -376,8 +393,8 @@ func adjustGalleryPerformerIDs(qb models.GalleryReader, galleryID int, ids BulkU return adjustIDs(ret, ids), nil } -func adjustGalleryTagIDs(qb models.GalleryReader, galleryID int, ids BulkUpdateIds) (ret []int, err error) { - ret, err = qb.GetTagIDs(galleryID) +func adjustGalleryTagIDs(ctx context.Context, qb GalleryTagGetter, galleryID int, ids BulkUpdateIds) (ret []int, err error) { + ret, err = qb.GetTagIDs(ctx, galleryID) if err != nil { return nil, err } @@ -385,8 +402,8 @@ func adjustGalleryTagIDs(qb models.GalleryReader, galleryID int, ids BulkUpdateI return adjustIDs(ret, ids), nil } -func adjustGallerySceneIDs(qb models.GalleryReader, galleryID int, ids BulkUpdateIds) (ret []int, err error) { - ret, err = qb.GetSceneIDs(galleryID) +func adjustGallerySceneIDs(ctx context.Context, qb GallerySceneGetter, galleryID int, ids BulkUpdateIds) (ret []int, err error) { + ret, err = qb.GetSceneIDs(ctx, galleryID) if err != nil { return nil, err } @@ -410,12 +427,12 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall deleteGenerated := utils.IsTrue(input.DeleteGenerated) deleteFile := utils.IsTrue(input.DeleteFile) - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Gallery() - iqb := repo.Image() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Gallery + iqb := r.repository.Image for _, id := range galleryIDs { - gallery, err := qb.Find(id) + gallery, err := qb.Find(ctx, id) if err != nil { return err } @@ -428,13 +445,13 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall // if this is a zip-based gallery, delete the images as well first if gallery.Zip { - imgs, err := iqb.FindByGalleryID(id) + imgs, err := iqb.FindByGalleryID(ctx, id) if err != nil { return err } for _, img := range imgs { - if err := image.Destroy(img, iqb, fileDeleter, deleteGenerated, false); err != nil { + if err := image.Destroy(ctx, img, iqb, fileDeleter, deleteGenerated, false); err != nil { return err } @@ -448,19 +465,19 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall } } else if deleteFile { // Delete image if it is only attached to this gallery - imgs, err := iqb.FindByGalleryID(id) + imgs, err := iqb.FindByGalleryID(ctx, id) if err != nil { return err } for _, img := range imgs { - imgGalleries, err := qb.FindByImageID(img.ID) + imgGalleries, err := qb.FindByImageID(ctx, img.ID) if err != nil { return err } if len(imgGalleries) == 1 { - if err := image.Destroy(img, iqb, fileDeleter, deleteGenerated, deleteFile); err != nil { + if err := image.Destroy(ctx, img, iqb, fileDeleter, deleteGenerated, deleteFile); err != nil { return err } @@ -472,7 +489,7 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall // don't do this with the file deleter } - if err := qb.Destroy(id); err != nil { + if err := qb.Destroy(ctx, id); err != nil { return err } } @@ -537,9 +554,9 @@ func (r *mutationResolver) AddGalleryImages(ctx context.Context, input GalleryAd return false, err } - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Gallery() - gallery, err := qb.Find(galleryID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Gallery + gallery, err := qb.Find(ctx, galleryID) if err != nil { return err } @@ -552,13 +569,13 @@ func (r *mutationResolver) AddGalleryImages(ctx context.Context, input GalleryAd return errors.New("cannot modify zip gallery images") } - newIDs, err := qb.GetImageIDs(galleryID) + newIDs, err := qb.GetImageIDs(ctx, galleryID) if err != nil { return err } newIDs = intslice.IntAppendUniques(newIDs, imageIDs) - return qb.UpdateImages(galleryID, newIDs) + return qb.UpdateImages(ctx, galleryID, newIDs) }); err != nil { return false, err } @@ -577,9 +594,9 @@ func (r *mutationResolver) RemoveGalleryImages(ctx context.Context, input Galler return false, err } - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Gallery() - gallery, err := qb.Find(galleryID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Gallery + gallery, err := qb.Find(ctx, galleryID) if err != nil { return err } @@ -592,13 +609,13 @@ func (r *mutationResolver) RemoveGalleryImages(ctx context.Context, input Galler return errors.New("cannot modify zip gallery images") } - newIDs, err := qb.GetImageIDs(galleryID) + newIDs, err := qb.GetImageIDs(ctx, galleryID) if err != nil { return err } newIDs = intslice.IntExclude(newIDs, imageIDs) - return qb.UpdateImages(galleryID, newIDs) + return qb.UpdateImages(ctx, galleryID, newIDs) }); err != nil { return false, err } diff --git a/internal/api/resolver_mutation_image.go b/internal/api/resolver_mutation_image.go index 2df016be2..72d98696a 100644 --- a/internal/api/resolver_mutation_image.go +++ b/internal/api/resolver_mutation_image.go @@ -16,8 +16,8 @@ import ( ) func (r *mutationResolver) getImage(ctx context.Context, id int) (ret *models.Image, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Image().Find(id) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Image.Find(ctx, id) return err }); err != nil { return nil, err @@ -32,8 +32,8 @@ func (r *mutationResolver) ImageUpdate(ctx context.Context, input ImageUpdateInp } // Start the transaction and save the image - if err := r.withTxn(ctx, func(repo models.Repository) error { - ret, err = r.imageUpdate(input, translator, repo) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.imageUpdate(ctx, input, translator) return err }); err != nil { return nil, err @@ -48,13 +48,13 @@ func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*ImageUpdat inputMaps := getUpdateInputMaps(ctx) // Start the transaction and save the image - if err := r.withTxn(ctx, func(repo models.Repository) error { + if err := r.withTxn(ctx, func(ctx context.Context) error { for i, image := range input { translator := changesetTranslator{ inputMap: inputMaps[i], } - thisImage, err := r.imageUpdate(*image, translator, repo) + thisImage, err := r.imageUpdate(ctx, *image, translator) if err != nil { return err } @@ -86,7 +86,7 @@ func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*ImageUpdat return newRet, nil } -func (r *mutationResolver) imageUpdate(input ImageUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Image, error) { +func (r *mutationResolver) imageUpdate(ctx context.Context, input ImageUpdateInput, translator changesetTranslator) (*models.Image, error) { // Populate image from the input imageID, err := strconv.Atoi(input.ID) if err != nil { @@ -104,28 +104,28 @@ func (r *mutationResolver) imageUpdate(input ImageUpdateInput, translator change updatedImage.StudioID = translator.nullInt64FromString(input.StudioID, "studio_id") updatedImage.Organized = input.Organized - qb := repo.Image() - image, err := qb.Update(updatedImage) + qb := r.repository.Image + image, err := qb.Update(ctx, updatedImage) if err != nil { return nil, err } if translator.hasField("gallery_ids") { - if err := r.updateImageGalleries(qb, imageID, input.GalleryIds); err != nil { + if err := r.updateImageGalleries(ctx, imageID, input.GalleryIds); err != nil { return nil, err } } // Save the performers if translator.hasField("performer_ids") { - if err := r.updateImagePerformers(qb, imageID, input.PerformerIds); err != nil { + if err := r.updateImagePerformers(ctx, imageID, input.PerformerIds); err != nil { return nil, err } } // Save the tags if translator.hasField("tag_ids") { - if err := r.updateImageTags(qb, imageID, input.TagIds); err != nil { + if err := r.updateImageTags(ctx, imageID, input.TagIds); err != nil { return nil, err } } @@ -133,28 +133,28 @@ func (r *mutationResolver) imageUpdate(input ImageUpdateInput, translator change return image, nil } -func (r *mutationResolver) updateImageGalleries(qb models.ImageReaderWriter, imageID int, galleryIDs []string) error { +func (r *mutationResolver) updateImageGalleries(ctx context.Context, imageID int, galleryIDs []string) error { ids, err := stringslice.StringSliceToIntSlice(galleryIDs) if err != nil { return err } - return qb.UpdateGalleries(imageID, ids) + return r.repository.Image.UpdateGalleries(ctx, imageID, ids) } -func (r *mutationResolver) updateImagePerformers(qb models.ImageReaderWriter, imageID int, performerIDs []string) error { +func (r *mutationResolver) updateImagePerformers(ctx context.Context, imageID int, performerIDs []string) error { ids, err := stringslice.StringSliceToIntSlice(performerIDs) if err != nil { return err } - return qb.UpdatePerformers(imageID, ids) + return r.repository.Image.UpdatePerformers(ctx, imageID, ids) } -func (r *mutationResolver) updateImageTags(qb models.ImageReaderWriter, imageID int, tagsIDs []string) error { +func (r *mutationResolver) updateImageTags(ctx context.Context, imageID int, tagsIDs []string) error { ids, err := stringslice.StringSliceToIntSlice(tagsIDs) if err != nil { return err } - return qb.UpdateTags(imageID, ids) + return r.repository.Image.UpdateTags(ctx, imageID, ids) } func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input BulkImageUpdateInput) (ret []*models.Image, err error) { @@ -180,13 +180,13 @@ func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input BulkImageU updatedImage.Organized = input.Organized // Start the transaction and save the image marker - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Image() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Image for _, imageID := range imageIDs { updatedImage.ID = imageID - image, err := qb.Update(updatedImage) + image, err := qb.Update(ctx, updatedImage) if err != nil { return err } @@ -195,36 +195,36 @@ func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input BulkImageU // Save the galleries if translator.hasField("gallery_ids") { - galleryIDs, err := adjustImageGalleryIDs(qb, imageID, *input.GalleryIds) + galleryIDs, err := r.adjustImageGalleryIDs(ctx, imageID, *input.GalleryIds) if err != nil { return err } - if err := qb.UpdateGalleries(imageID, galleryIDs); err != nil { + if err := qb.UpdateGalleries(ctx, imageID, galleryIDs); err != nil { return err } } // Save the performers if translator.hasField("performer_ids") { - performerIDs, err := adjustImagePerformerIDs(qb, imageID, *input.PerformerIds) + performerIDs, err := r.adjustImagePerformerIDs(ctx, imageID, *input.PerformerIds) if err != nil { return err } - if err := qb.UpdatePerformers(imageID, performerIDs); err != nil { + if err := qb.UpdatePerformers(ctx, imageID, performerIDs); err != nil { return err } } // Save the tags if translator.hasField("tag_ids") { - tagIDs, err := adjustImageTagIDs(qb, imageID, *input.TagIds) + tagIDs, err := r.adjustImageTagIDs(ctx, imageID, *input.TagIds) if err != nil { return err } - if err := qb.UpdateTags(imageID, tagIDs); err != nil { + if err := qb.UpdateTags(ctx, imageID, tagIDs); err != nil { return err } } @@ -251,8 +251,8 @@ func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input BulkImageU return newRet, nil } -func adjustImageGalleryIDs(qb models.ImageReader, imageID int, ids BulkUpdateIds) (ret []int, err error) { - ret, err = qb.GetGalleryIDs(imageID) +func (r *mutationResolver) adjustImageGalleryIDs(ctx context.Context, imageID int, ids BulkUpdateIds) (ret []int, err error) { + ret, err = r.repository.Image.GetGalleryIDs(ctx, imageID) if err != nil { return nil, err } @@ -260,8 +260,8 @@ func adjustImageGalleryIDs(qb models.ImageReader, imageID int, ids BulkUpdateIds return adjustIDs(ret, ids), nil } -func adjustImagePerformerIDs(qb models.ImageReader, imageID int, ids BulkUpdateIds) (ret []int, err error) { - ret, err = qb.GetPerformerIDs(imageID) +func (r *mutationResolver) adjustImagePerformerIDs(ctx context.Context, imageID int, ids BulkUpdateIds) (ret []int, err error) { + ret, err = r.repository.Image.GetPerformerIDs(ctx, imageID) if err != nil { return nil, err } @@ -269,8 +269,8 @@ func adjustImagePerformerIDs(qb models.ImageReader, imageID int, ids BulkUpdateI return adjustIDs(ret, ids), nil } -func adjustImageTagIDs(qb models.ImageReader, imageID int, ids BulkUpdateIds) (ret []int, err error) { - ret, err = qb.GetTagIDs(imageID) +func (r *mutationResolver) adjustImageTagIDs(ctx context.Context, imageID int, ids BulkUpdateIds) (ret []int, err error) { + ret, err = r.repository.Image.GetTagIDs(ctx, imageID) if err != nil { return nil, err } @@ -289,10 +289,10 @@ func (r *mutationResolver) ImageDestroy(ctx context.Context, input models.ImageD Deleter: *file.NewDeleter(), Paths: manager.GetInstance().Paths, } - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Image() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Image - i, err = qb.Find(imageID) + i, err = r.repository.Image.Find(ctx, imageID) if err != nil { return err } @@ -301,7 +301,7 @@ func (r *mutationResolver) ImageDestroy(ctx context.Context, input models.ImageD return fmt.Errorf("image with id %d not found", imageID) } - return image.Destroy(i, qb, fileDeleter, utils.IsTrue(input.DeleteGenerated), utils.IsTrue(input.DeleteFile)) + return image.Destroy(ctx, i, qb, fileDeleter, utils.IsTrue(input.DeleteGenerated), utils.IsTrue(input.DeleteFile)) }); err != nil { fileDeleter.Rollback() return false, err @@ -331,12 +331,12 @@ func (r *mutationResolver) ImagesDestroy(ctx context.Context, input models.Image Deleter: *file.NewDeleter(), Paths: manager.GetInstance().Paths, } - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Image() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Image for _, imageID := range imageIDs { - i, err := qb.Find(imageID) + i, err := qb.Find(ctx, imageID) if err != nil { return err } @@ -347,7 +347,7 @@ func (r *mutationResolver) ImagesDestroy(ctx context.Context, input models.Image images = append(images, i) - if err := image.Destroy(i, qb, fileDeleter, utils.IsTrue(input.DeleteGenerated), utils.IsTrue(input.DeleteFile)); err != nil { + if err := image.Destroy(ctx, i, qb, fileDeleter, utils.IsTrue(input.DeleteGenerated), utils.IsTrue(input.DeleteFile)); err != nil { return err } } @@ -379,10 +379,10 @@ func (r *mutationResolver) ImageIncrementO(ctx context.Context, id string) (ret return 0, err } - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Image() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Image - ret, err = qb.IncrementOCounter(imageID) + ret, err = qb.IncrementOCounter(ctx, imageID) return err }); err != nil { return 0, err @@ -397,10 +397,10 @@ func (r *mutationResolver) ImageDecrementO(ctx context.Context, id string) (ret return 0, err } - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Image() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Image - ret, err = qb.DecrementOCounter(imageID) + ret, err = qb.DecrementOCounter(ctx, imageID) return err }); err != nil { return 0, err @@ -415,10 +415,10 @@ func (r *mutationResolver) ImageResetO(ctx context.Context, id string) (ret int, return 0, err } - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Image() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Image - ret, err = qb.ResetOCounter(imageID) + ret, err = qb.ResetOCounter(ctx, imageID) return err }); err != nil { return 0, err diff --git a/internal/api/resolver_mutation_metadata.go b/internal/api/resolver_mutation_metadata.go index 0d1794e58..ff8635536 100644 --- a/internal/api/resolver_mutation_metadata.go +++ b/internal/api/resolver_mutation_metadata.go @@ -12,7 +12,6 @@ import ( "github.com/stashapp/stash/internal/identify" "github.com/stashapp/stash/internal/manager" "github.com/stashapp/stash/internal/manager/config" - "github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/fsutil" "github.com/stashapp/stash/pkg/logger" ) @@ -111,6 +110,7 @@ func (r *mutationResolver) BackupDatabase(ctx context.Context, input BackupDatab // if download is true, then backup to temporary file and return a link download := input.Download != nil && *input.Download mgr := manager.GetInstance() + database := mgr.Database var backupPath string if download { if err := fsutil.EnsureDir(mgr.Paths.Generated.Downloads); err != nil { @@ -127,7 +127,7 @@ func (r *mutationResolver) BackupDatabase(ctx context.Context, input BackupDatab backupPath = database.DatabaseBackupPath() } - err := database.Backup(database.DB, backupPath) + err := database.Backup(backupPath) if err != nil { return nil, err } diff --git a/internal/api/resolver_mutation_movie.go b/internal/api/resolver_mutation_movie.go index da6dfdfe7..0a22350b6 100644 --- a/internal/api/resolver_mutation_movie.go +++ b/internal/api/resolver_mutation_movie.go @@ -15,8 +15,8 @@ import ( ) func (r *mutationResolver) getMovie(ctx context.Context, id int) (ret *models.Movie, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Movie().Find(id) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Movie.Find(ctx, id) return err }); err != nil { return nil, err @@ -100,16 +100,16 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input MovieCreateInp // Start the transaction and save the movie var movie *models.Movie - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Movie() - movie, err = qb.Create(newMovie) + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Movie + movie, err = qb.Create(ctx, newMovie) if err != nil { return err } // update image table if len(frontimageData) > 0 { - if err := qb.UpdateImages(movie.ID, frontimageData, backimageData); err != nil { + if err := qb.UpdateImages(ctx, movie.ID, frontimageData, backimageData); err != nil { return err } } @@ -174,9 +174,9 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input MovieUpdateInp // Start the transaction and save the movie var movie *models.Movie - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Movie() - movie, err = qb.Update(updatedMovie) + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Movie + movie, err = qb.Update(ctx, updatedMovie) if err != nil { return err } @@ -184,13 +184,13 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input MovieUpdateInp // update image table if frontImageIncluded || backImageIncluded { if !frontImageIncluded { - frontimageData, err = qb.GetFrontImage(updatedMovie.ID) + frontimageData, err = qb.GetFrontImage(ctx, updatedMovie.ID) if err != nil { return err } } if !backImageIncluded { - backimageData, err = qb.GetBackImage(updatedMovie.ID) + backimageData, err = qb.GetBackImage(ctx, updatedMovie.ID) if err != nil { return err } @@ -198,7 +198,7 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input MovieUpdateInp if len(frontimageData) == 0 && len(backimageData) == 0 { // both images are being nulled. Destroy them. - if err := qb.DestroyImages(movie.ID); err != nil { + if err := qb.DestroyImages(ctx, movie.ID); err != nil { return err } } else { @@ -208,7 +208,7 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input MovieUpdateInp frontimageData, _ = utils.ProcessImageInput(ctx, models.DefaultMovieImage) } - if err := qb.UpdateImages(movie.ID, frontimageData, backimageData); err != nil { + if err := qb.UpdateImages(ctx, movie.ID, frontimageData, backimageData); err != nil { return err } } @@ -245,13 +245,13 @@ func (r *mutationResolver) BulkMovieUpdate(ctx context.Context, input BulkMovieU ret := []*models.Movie{} - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Movie() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Movie for _, movieID := range movieIDs { updatedMovie.ID = movieID - existing, err := qb.Find(movieID) + existing, err := qb.Find(ctx, movieID) if err != nil { return err } @@ -260,7 +260,7 @@ func (r *mutationResolver) BulkMovieUpdate(ctx context.Context, input BulkMovieU return fmt.Errorf("movie with id %d not found", movieID) } - movie, err := qb.Update(updatedMovie) + movie, err := qb.Update(ctx, updatedMovie) if err != nil { return err } @@ -294,8 +294,8 @@ func (r *mutationResolver) MovieDestroy(ctx context.Context, input MovieDestroyI return false, err } - if err := r.withTxn(ctx, func(repo models.Repository) error { - return repo.Movie().Destroy(id) + if err := r.withTxn(ctx, func(ctx context.Context) error { + return r.repository.Movie.Destroy(ctx, id) }); err != nil { return false, err } @@ -311,10 +311,10 @@ func (r *mutationResolver) MoviesDestroy(ctx context.Context, movieIDs []string) return false, err } - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Movie() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Movie for _, id := range ids { - if err := qb.Destroy(id); err != nil { + if err := qb.Destroy(ctx, id); err != nil { return err } } diff --git a/internal/api/resolver_mutation_performer.go b/internal/api/resolver_mutation_performer.go index 02c942484..a5fd19dea 100644 --- a/internal/api/resolver_mutation_performer.go +++ b/internal/api/resolver_mutation_performer.go @@ -16,8 +16,8 @@ import ( ) func (r *mutationResolver) getPerformer(ctx context.Context, id int) (ret *models.Performer, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Performer().Find(id) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Performer.Find(ctx, id) return err }); err != nil { return nil, err @@ -129,23 +129,23 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input PerformerC // Start the transaction and save the performer var performer *models.Performer - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Performer() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Performer - performer, err = qb.Create(newPerformer) + performer, err = qb.Create(ctx, newPerformer) if err != nil { return err } if len(input.TagIds) > 0 { - if err := r.updatePerformerTags(qb, performer.ID, input.TagIds); err != nil { + if err := r.updatePerformerTags(ctx, performer.ID, input.TagIds); err != nil { return err } } // update image table if len(imageData) > 0 { - if err := qb.UpdateImage(performer.ID, imageData); err != nil { + if err := qb.UpdateImage(ctx, performer.ID, imageData); err != nil { return err } } @@ -153,7 +153,7 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input PerformerC // Save the stash_ids if input.StashIds != nil { stashIDJoins := models.StashIDsFromInput(input.StashIds) - if err := qb.UpdateStashIDs(performer.ID, stashIDJoins); err != nil { + if err := qb.UpdateStashIDs(ctx, performer.ID, stashIDJoins); err != nil { return err } } @@ -230,11 +230,11 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input PerformerU // Start the transaction and save the p var p *models.Performer - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Performer() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Performer // need to get existing performer - existing, err := qb.Find(updatedPerformer.ID) + existing, err := qb.Find(ctx, updatedPerformer.ID) if err != nil { return err } @@ -249,26 +249,26 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input PerformerU } } - p, err = qb.Update(updatedPerformer) + p, err = qb.Update(ctx, updatedPerformer) if err != nil { return err } // Save the tags if translator.hasField("tag_ids") { - if err := r.updatePerformerTags(qb, p.ID, input.TagIds); err != nil { + if err := r.updatePerformerTags(ctx, p.ID, input.TagIds); err != nil { return err } } // update image table if len(imageData) > 0 { - if err := qb.UpdateImage(p.ID, imageData); err != nil { + if err := qb.UpdateImage(ctx, p.ID, imageData); err != nil { return err } } else if imageIncluded { // must be unsetting - if err := qb.DestroyImage(p.ID); err != nil { + if err := qb.DestroyImage(ctx, p.ID); err != nil { return err } } @@ -276,7 +276,7 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input PerformerU // Save the stash_ids if translator.hasField("stash_ids") { stashIDJoins := models.StashIDsFromInput(input.StashIds) - if err := qb.UpdateStashIDs(performerID, stashIDJoins); err != nil { + if err := qb.UpdateStashIDs(ctx, performerID, stashIDJoins); err != nil { return err } } @@ -290,12 +290,12 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input PerformerU return r.getPerformer(ctx, p.ID) } -func (r *mutationResolver) updatePerformerTags(qb models.PerformerReaderWriter, performerID int, tagsIDs []string) error { +func (r *mutationResolver) updatePerformerTags(ctx context.Context, performerID int, tagsIDs []string) error { ids, err := stringslice.StringSliceToIntSlice(tagsIDs) if err != nil { return err } - return qb.UpdateTags(performerID, ids) + return r.repository.Performer.UpdateTags(ctx, performerID, ids) } func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPerformerUpdateInput) ([]*models.Performer, error) { @@ -348,14 +348,14 @@ func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPe ret := []*models.Performer{} // Start the transaction and save the scene marker - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Performer() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Performer for _, performerID := range performerIDs { updatedPerformer.ID = performerID // need to get existing performer - existing, err := qb.Find(performerID) + existing, err := qb.Find(ctx, performerID) if err != nil { return err } @@ -368,7 +368,7 @@ func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPe return err } - performer, err := qb.Update(updatedPerformer) + performer, err := qb.Update(ctx, updatedPerformer) if err != nil { return err } @@ -377,12 +377,12 @@ func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPe // Save the tags if translator.hasField("tag_ids") { - tagIDs, err := adjustTagIDs(qb, performerID, *input.TagIds) + tagIDs, err := adjustTagIDs(ctx, qb, performerID, *input.TagIds) if err != nil { return err } - if err := qb.UpdateTags(performerID, tagIDs); err != nil { + if err := qb.UpdateTags(ctx, performerID, tagIDs); err != nil { return err } } @@ -415,8 +415,8 @@ func (r *mutationResolver) PerformerDestroy(ctx context.Context, input Performer return false, err } - if err := r.withTxn(ctx, func(repo models.Repository) error { - return repo.Performer().Destroy(id) + if err := r.withTxn(ctx, func(ctx context.Context) error { + return r.repository.Performer.Destroy(ctx, id) }); err != nil { return false, err } @@ -432,10 +432,10 @@ func (r *mutationResolver) PerformersDestroy(ctx context.Context, performerIDs [ return false, err } - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Performer() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Performer for _, id := range ids { - if err := qb.Destroy(id); err != nil { + if err := qb.Destroy(ctx, id); err != nil { return err } } diff --git a/internal/api/resolver_mutation_saved_filter.go b/internal/api/resolver_mutation_saved_filter.go index bf1b4106e..a995060ea 100644 --- a/internal/api/resolver_mutation_saved_filter.go +++ b/internal/api/resolver_mutation_saved_filter.go @@ -23,17 +23,17 @@ func (r *mutationResolver) SaveFilter(ctx context.Context, input SaveFilterInput id = &idv } - if err := r.withTxn(ctx, func(repo models.Repository) error { + if err := r.withTxn(ctx, func(ctx context.Context) error { f := models.SavedFilter{ Mode: input.Mode, Name: input.Name, Filter: input.Filter, } if id == nil { - ret, err = repo.SavedFilter().Create(f) + ret, err = r.repository.SavedFilter.Create(ctx, f) } else { f.ID = *id - ret, err = repo.SavedFilter().Update(f) + ret, err = r.repository.SavedFilter.Update(ctx, f) } return err }); err != nil { @@ -48,8 +48,8 @@ func (r *mutationResolver) DestroySavedFilter(ctx context.Context, input Destroy return false, err } - if err := r.withTxn(ctx, func(repo models.Repository) error { - return repo.SavedFilter().Destroy(id) + if err := r.withTxn(ctx, func(ctx context.Context) error { + return r.repository.SavedFilter.Destroy(ctx, id) }); err != nil { return false, err } @@ -58,24 +58,24 @@ func (r *mutationResolver) DestroySavedFilter(ctx context.Context, input Destroy } func (r *mutationResolver) SetDefaultFilter(ctx context.Context, input SetDefaultFilterInput) (bool, error) { - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.SavedFilter() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.SavedFilter if input.Filter == nil { // clearing - def, err := qb.FindDefault(input.Mode) + def, err := qb.FindDefault(ctx, input.Mode) if err != nil { return err } if def != nil { - return qb.Destroy(def.ID) + return qb.Destroy(ctx, def.ID) } return nil } - _, err := qb.SetDefault(models.SavedFilter{ + _, err := qb.SetDefault(ctx, models.SavedFilter{ Mode: input.Mode, Filter: *input.Filter, }) diff --git a/internal/api/resolver_mutation_scene.go b/internal/api/resolver_mutation_scene.go index fdaf1d64d..1a9901062 100644 --- a/internal/api/resolver_mutation_scene.go +++ b/internal/api/resolver_mutation_scene.go @@ -19,8 +19,8 @@ import ( ) func (r *mutationResolver) getScene(ctx context.Context, id int) (ret *models.Scene, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Scene().Find(id) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Scene.Find(ctx, id) return err }); err != nil { return nil, err @@ -35,8 +35,8 @@ func (r *mutationResolver) SceneUpdate(ctx context.Context, input models.SceneUp } // Start the transaction and save the scene - if err := r.withTxn(ctx, func(repo models.Repository) error { - ret, err = r.sceneUpdate(ctx, input, translator, repo) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.sceneUpdate(ctx, input, translator) return err }); err != nil { return nil, err @@ -50,13 +50,13 @@ func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.Sce inputMaps := getUpdateInputMaps(ctx) // Start the transaction and save the scene - if err := r.withTxn(ctx, func(repo models.Repository) error { + if err := r.withTxn(ctx, func(ctx context.Context) error { for i, scene := range input { translator := changesetTranslator{ inputMap: inputMaps[i], } - thisScene, err := r.sceneUpdate(ctx, *scene, translator, repo) + thisScene, err := r.sceneUpdate(ctx, *scene, translator) ret = append(ret, thisScene) if err != nil { @@ -89,7 +89,7 @@ func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.Sce return newRet, nil } -func (r *mutationResolver) sceneUpdate(ctx context.Context, input models.SceneUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Scene, error) { +func (r *mutationResolver) sceneUpdate(ctx context.Context, input models.SceneUpdateInput, translator changesetTranslator) (*models.Scene, error) { // Populate scene from the input sceneID, err := strconv.Atoi(input.ID) if err != nil { @@ -122,43 +122,43 @@ func (r *mutationResolver) sceneUpdate(ctx context.Context, input models.SceneUp // update the cover after updating the scene } - qb := repo.Scene() - s, err := qb.Update(updatedScene) + qb := r.repository.Scene + s, err := qb.Update(ctx, updatedScene) if err != nil { return nil, err } // update cover table if len(coverImageData) > 0 { - if err := qb.UpdateCover(sceneID, coverImageData); err != nil { + if err := qb.UpdateCover(ctx, sceneID, coverImageData); err != nil { return nil, err } } // Save the performers if translator.hasField("performer_ids") { - if err := r.updateScenePerformers(qb, sceneID, input.PerformerIds); err != nil { + if err := r.updateScenePerformers(ctx, sceneID, input.PerformerIds); err != nil { return nil, err } } // Save the movies if translator.hasField("movies") { - if err := r.updateSceneMovies(qb, sceneID, input.Movies); err != nil { + if err := r.updateSceneMovies(ctx, sceneID, input.Movies); err != nil { return nil, err } } // Save the tags if translator.hasField("tag_ids") { - if err := r.updateSceneTags(qb, sceneID, input.TagIds); err != nil { + if err := r.updateSceneTags(ctx, sceneID, input.TagIds); err != nil { return nil, err } } // Save the galleries if translator.hasField("gallery_ids") { - if err := r.updateSceneGalleries(qb, sceneID, input.GalleryIds); err != nil { + if err := r.updateSceneGalleries(ctx, sceneID, input.GalleryIds); err != nil { return nil, err } } @@ -166,7 +166,7 @@ func (r *mutationResolver) sceneUpdate(ctx context.Context, input models.SceneUp // Save the stash_ids if translator.hasField("stash_ids") { stashIDJoins := models.StashIDsFromInput(input.StashIds) - if err := qb.UpdateStashIDs(sceneID, stashIDJoins); err != nil { + if err := qb.UpdateStashIDs(ctx, sceneID, stashIDJoins); err != nil { return nil, err } } @@ -182,15 +182,15 @@ func (r *mutationResolver) sceneUpdate(ctx context.Context, input models.SceneUp return s, nil } -func (r *mutationResolver) updateScenePerformers(qb models.SceneReaderWriter, sceneID int, performerIDs []string) error { +func (r *mutationResolver) updateScenePerformers(ctx context.Context, sceneID int, performerIDs []string) error { ids, err := stringslice.StringSliceToIntSlice(performerIDs) if err != nil { return err } - return qb.UpdatePerformers(sceneID, ids) + return r.repository.Scene.UpdatePerformers(ctx, sceneID, ids) } -func (r *mutationResolver) updateSceneMovies(qb models.SceneReaderWriter, sceneID int, movies []*models.SceneMovieInput) error { +func (r *mutationResolver) updateSceneMovies(ctx context.Context, sceneID int, movies []*models.SceneMovieInput) error { var movieJoins []models.MoviesScenes for _, movie := range movies { @@ -213,23 +213,23 @@ func (r *mutationResolver) updateSceneMovies(qb models.SceneReaderWriter, sceneI movieJoins = append(movieJoins, movieJoin) } - return qb.UpdateMovies(sceneID, movieJoins) + return r.repository.Scene.UpdateMovies(ctx, sceneID, movieJoins) } -func (r *mutationResolver) updateSceneTags(qb models.SceneReaderWriter, sceneID int, tagsIDs []string) error { +func (r *mutationResolver) updateSceneTags(ctx context.Context, sceneID int, tagsIDs []string) error { ids, err := stringslice.StringSliceToIntSlice(tagsIDs) if err != nil { return err } - return qb.UpdateTags(sceneID, ids) + return r.repository.Scene.UpdateTags(ctx, sceneID, ids) } -func (r *mutationResolver) updateSceneGalleries(qb models.SceneReaderWriter, sceneID int, galleryIDs []string) error { +func (r *mutationResolver) updateSceneGalleries(ctx context.Context, sceneID int, galleryIDs []string) error { ids, err := stringslice.StringSliceToIntSlice(galleryIDs) if err != nil { return err } - return qb.UpdateGalleries(sceneID, ids) + return r.repository.Scene.UpdateGalleries(ctx, sceneID, ids) } func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input BulkSceneUpdateInput) ([]*models.Scene, error) { @@ -260,13 +260,13 @@ func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input BulkSceneU ret := []*models.Scene{} // Start the transaction and save the scene marker - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Scene() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Scene for _, sceneID := range sceneIDs { updatedScene.ID = sceneID - scene, err := qb.Update(updatedScene) + scene, err := qb.Update(ctx, updatedScene) if err != nil { return err } @@ -275,48 +275,48 @@ func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input BulkSceneU // Save the performers if translator.hasField("performer_ids") { - performerIDs, err := adjustScenePerformerIDs(qb, sceneID, *input.PerformerIds) + performerIDs, err := r.adjustScenePerformerIDs(ctx, sceneID, *input.PerformerIds) if err != nil { return err } - if err := qb.UpdatePerformers(sceneID, performerIDs); err != nil { + if err := qb.UpdatePerformers(ctx, sceneID, performerIDs); err != nil { return err } } // Save the tags if translator.hasField("tag_ids") { - tagIDs, err := adjustTagIDs(qb, sceneID, *input.TagIds) + tagIDs, err := adjustTagIDs(ctx, qb, sceneID, *input.TagIds) if err != nil { return err } - if err := qb.UpdateTags(sceneID, tagIDs); err != nil { + if err := qb.UpdateTags(ctx, sceneID, tagIDs); err != nil { return err } } // Save the galleries if translator.hasField("gallery_ids") { - galleryIDs, err := adjustSceneGalleryIDs(qb, sceneID, *input.GalleryIds) + galleryIDs, err := r.adjustSceneGalleryIDs(ctx, sceneID, *input.GalleryIds) if err != nil { return err } - if err := qb.UpdateGalleries(sceneID, galleryIDs); err != nil { + if err := qb.UpdateGalleries(ctx, sceneID, galleryIDs); err != nil { return err } } // Save the movies if translator.hasField("movie_ids") { - movies, err := adjustSceneMovieIDs(qb, sceneID, *input.MovieIds) + movies, err := r.adjustSceneMovieIDs(ctx, sceneID, *input.MovieIds) if err != nil { return err } - if err := qb.UpdateMovies(sceneID, movies); err != nil { + if err := qb.UpdateMovies(ctx, sceneID, movies); err != nil { return err } } @@ -380,8 +380,8 @@ func adjustIDs(existingIDs []int, updateIDs BulkUpdateIds) []int { return existingIDs } -func adjustScenePerformerIDs(qb models.SceneReader, sceneID int, ids BulkUpdateIds) (ret []int, err error) { - ret, err = qb.GetPerformerIDs(sceneID) +func (r *mutationResolver) adjustScenePerformerIDs(ctx context.Context, sceneID int, ids BulkUpdateIds) (ret []int, err error) { + ret, err = r.repository.Scene.GetPerformerIDs(ctx, sceneID) if err != nil { return nil, err } @@ -390,11 +390,11 @@ func adjustScenePerformerIDs(qb models.SceneReader, sceneID int, ids BulkUpdateI } type tagIDsGetter interface { - GetTagIDs(id int) ([]int, error) + GetTagIDs(ctx context.Context, id int) ([]int, error) } -func adjustTagIDs(qb tagIDsGetter, sceneID int, ids BulkUpdateIds) (ret []int, err error) { - ret, err = qb.GetTagIDs(sceneID) +func adjustTagIDs(ctx context.Context, qb tagIDsGetter, sceneID int, ids BulkUpdateIds) (ret []int, err error) { + ret, err = qb.GetTagIDs(ctx, sceneID) if err != nil { return nil, err } @@ -402,8 +402,8 @@ func adjustTagIDs(qb tagIDsGetter, sceneID int, ids BulkUpdateIds) (ret []int, e return adjustIDs(ret, ids), nil } -func adjustSceneGalleryIDs(qb models.SceneReader, sceneID int, ids BulkUpdateIds) (ret []int, err error) { - ret, err = qb.GetGalleryIDs(sceneID) +func (r *mutationResolver) adjustSceneGalleryIDs(ctx context.Context, sceneID int, ids BulkUpdateIds) (ret []int, err error) { + ret, err = r.repository.Scene.GetGalleryIDs(ctx, sceneID) if err != nil { return nil, err } @@ -411,8 +411,8 @@ func adjustSceneGalleryIDs(qb models.SceneReader, sceneID int, ids BulkUpdateIds return adjustIDs(ret, ids), nil } -func adjustSceneMovieIDs(qb models.SceneReader, sceneID int, updateIDs BulkUpdateIds) ([]models.MoviesScenes, error) { - existingMovies, err := qb.GetMovies(sceneID) +func (r *mutationResolver) adjustSceneMovieIDs(ctx context.Context, sceneID int, updateIDs BulkUpdateIds) ([]models.MoviesScenes, error) { + existingMovies, err := r.repository.Scene.GetMovies(ctx, sceneID) if err != nil { return nil, err } @@ -471,10 +471,10 @@ func (r *mutationResolver) SceneDestroy(ctx context.Context, input models.SceneD deleteGenerated := utils.IsTrue(input.DeleteGenerated) deleteFile := utils.IsTrue(input.DeleteFile) - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Scene() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Scene var err error - s, err = qb.Find(sceneID) + s, err = qb.Find(ctx, sceneID) if err != nil { return err } @@ -486,7 +486,7 @@ func (r *mutationResolver) SceneDestroy(ctx context.Context, input models.SceneD // kill any running encoders manager.KillRunningStreams(s, fileNamingAlgo) - return scene.Destroy(s, repo, fileDeleter, deleteGenerated, deleteFile) + return scene.Destroy(ctx, s, r.repository.Scene, r.repository.SceneMarker, fileDeleter, deleteGenerated, deleteFile) }); err != nil { fileDeleter.Rollback() return false, err @@ -519,13 +519,13 @@ func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.Scene deleteGenerated := utils.IsTrue(input.DeleteGenerated) deleteFile := utils.IsTrue(input.DeleteFile) - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Scene() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Scene for _, id := range input.Ids { sceneID, _ := strconv.Atoi(id) - s, err := qb.Find(sceneID) + s, err := qb.Find(ctx, sceneID) if err != nil { return err } @@ -536,7 +536,7 @@ func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.Scene // kill any running encoders manager.KillRunningStreams(s, fileNamingAlgo) - if err := scene.Destroy(s, repo, fileDeleter, deleteGenerated, deleteFile); err != nil { + if err := scene.Destroy(ctx, s, r.repository.Scene, r.repository.SceneMarker, fileDeleter, deleteGenerated, deleteFile); err != nil { return err } } @@ -564,8 +564,8 @@ func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.Scene } func (r *mutationResolver) getSceneMarker(ctx context.Context, id int) (ret *models.SceneMarker, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.SceneMarker().Find(id) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.SceneMarker.Find(ctx, id) return err }); err != nil { return nil, err @@ -666,11 +666,11 @@ func (r *mutationResolver) SceneMarkerDestroy(ctx context.Context, id string) (b Paths: manager.GetInstance().Paths, } - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.SceneMarker() - sqb := repo.Scene() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.SceneMarker + sqb := r.repository.Scene - marker, err := qb.Find(markerID) + marker, err := qb.Find(ctx, markerID) if err != nil { return err @@ -680,12 +680,12 @@ func (r *mutationResolver) SceneMarkerDestroy(ctx context.Context, id string) (b return fmt.Errorf("scene marker with id %d not found", markerID) } - s, err := sqb.Find(int(marker.SceneID.Int64)) + s, err := sqb.Find(ctx, int(marker.SceneID.Int64)) if err != nil { return err } - return scene.DestroyMarker(s, marker, qb, fileDeleter) + return scene.DestroyMarker(ctx, s, marker, qb, fileDeleter) }); err != nil { fileDeleter.Rollback() return false, err @@ -713,26 +713,26 @@ func (r *mutationResolver) changeMarker(ctx context.Context, changeType int, cha } // Start the transaction and save the scene marker - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.SceneMarker() - sqb := repo.Scene() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.SceneMarker + sqb := r.repository.Scene var err error switch changeType { case create: - sceneMarker, err = qb.Create(changedMarker) + sceneMarker, err = qb.Create(ctx, changedMarker) case update: // check to see if timestamp was changed - existingMarker, err = qb.Find(changedMarker.ID) + existingMarker, err = qb.Find(ctx, changedMarker.ID) if err != nil { return err } - sceneMarker, err = qb.Update(changedMarker) + sceneMarker, err = qb.Update(ctx, changedMarker) if err != nil { return err } - s, err = sqb.Find(int(existingMarker.SceneID.Int64)) + s, err = sqb.Find(ctx, int(existingMarker.SceneID.Int64)) } if err != nil { return err @@ -749,7 +749,7 @@ func (r *mutationResolver) changeMarker(ctx context.Context, changeType int, cha // Save the marker tags // If this tag is the primary tag, then let's not add it. tagIDs = intslice.IntExclude(tagIDs, []int{changedMarker.PrimaryTagID}) - return qb.UpdateTags(sceneMarker.ID, tagIDs) + return qb.UpdateTags(ctx, sceneMarker.ID, tagIDs) }); err != nil { fileDeleter.Rollback() return nil, err @@ -766,10 +766,10 @@ func (r *mutationResolver) SceneIncrementO(ctx context.Context, id string) (ret return 0, err } - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Scene() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Scene - ret, err = qb.IncrementOCounter(sceneID) + ret, err = qb.IncrementOCounter(ctx, sceneID) return err }); err != nil { return 0, err @@ -784,10 +784,10 @@ func (r *mutationResolver) SceneDecrementO(ctx context.Context, id string) (ret return 0, err } - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Scene() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Scene - ret, err = qb.DecrementOCounter(sceneID) + ret, err = qb.DecrementOCounter(ctx, sceneID) return err }); err != nil { return 0, err @@ -802,10 +802,10 @@ func (r *mutationResolver) SceneResetO(ctx context.Context, id string) (ret int, return 0, err } - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Scene() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Scene - ret, err = qb.ResetOCounter(sceneID) + ret, err = qb.ResetOCounter(ctx, sceneID) return err }); err != nil { return 0, err diff --git a/internal/api/resolver_mutation_stash_box.go b/internal/api/resolver_mutation_stash_box.go index 3d9163317..95e300f9d 100644 --- a/internal/api/resolver_mutation_stash_box.go +++ b/internal/api/resolver_mutation_stash_box.go @@ -7,10 +7,18 @@ import ( "github.com/stashapp/stash/internal/manager" "github.com/stashapp/stash/internal/manager/config" - "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/scraper/stashbox" ) +func (r *Resolver) stashboxRepository() stashbox.Repository { + return stashbox.Repository{ + Scene: r.repository.Scene, + Performer: r.repository.Performer, + Tag: r.repository.Tag, + Studio: r.repository.Studio, + } +} + func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input StashBoxFingerprintSubmissionInput) (bool, error) { boxes := config.GetInstance().GetStashBoxes() @@ -18,7 +26,7 @@ func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input return false, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex) } - client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager) + client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository()) return client.SubmitStashBoxFingerprints(ctx, input.SceneIds, boxes[input.StashBoxIndex].Endpoint) } @@ -35,7 +43,7 @@ func (r *mutationResolver) SubmitStashBoxSceneDraft(ctx context.Context, input S return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex) } - client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager) + client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository()) id, err := strconv.Atoi(input.ID) if err != nil { @@ -43,9 +51,9 @@ func (r *mutationResolver) SubmitStashBoxSceneDraft(ctx context.Context, input S } var res *string - err = r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - qb := repo.Scene() - scene, err := qb.Find(id) + err = r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Scene + scene, err := qb.Find(ctx, id) if err != nil { return err } @@ -65,7 +73,7 @@ func (r *mutationResolver) SubmitStashBoxPerformerDraft(ctx context.Context, inp return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex) } - client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager) + client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository()) id, err := strconv.Atoi(input.ID) if err != nil { @@ -73,9 +81,9 @@ func (r *mutationResolver) SubmitStashBoxPerformerDraft(ctx context.Context, inp } var res *string - err = r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - qb := repo.Performer() - performer, err := qb.Find(id) + err = r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Performer + performer, err := qb.Find(ctx, id) if err != nil { return err } diff --git a/internal/api/resolver_mutation_studio.go b/internal/api/resolver_mutation_studio.go index 8fcfd3b53..fde747e3e 100644 --- a/internal/api/resolver_mutation_studio.go +++ b/internal/api/resolver_mutation_studio.go @@ -17,8 +17,8 @@ import ( ) func (r *mutationResolver) getStudio(ctx context.Context, id int) (ret *models.Studio, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Studio().Find(id) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Studio.Find(ctx, id) return err }); err != nil { return nil, err @@ -72,18 +72,18 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input StudioCreateI // Start the transaction and save the studio var s *models.Studio - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Studio() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Studio var err error - s, err = qb.Create(newStudio) + s, err = qb.Create(ctx, newStudio) if err != nil { return err } // update image table if len(imageData) > 0 { - if err := qb.UpdateImage(s.ID, imageData); err != nil { + if err := qb.UpdateImage(ctx, s.ID, imageData); err != nil { return err } } @@ -91,17 +91,17 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input StudioCreateI // Save the stash_ids if input.StashIds != nil { stashIDJoins := models.StashIDsFromInput(input.StashIds) - if err := qb.UpdateStashIDs(s.ID, stashIDJoins); err != nil { + if err := qb.UpdateStashIDs(ctx, s.ID, stashIDJoins); err != nil { return err } } if len(input.Aliases) > 0 { - if err := studio.EnsureAliasesUnique(s.ID, input.Aliases, qb); err != nil { + if err := studio.EnsureAliasesUnique(ctx, s.ID, input.Aliases, qb); err != nil { return err } - if err := qb.UpdateAliases(s.ID, input.Aliases); err != nil { + if err := qb.UpdateAliases(ctx, s.ID, input.Aliases); err != nil { return err } } @@ -155,27 +155,27 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input StudioUpdateI // Start the transaction and save the studio var s *models.Studio - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Studio() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Studio - if err := manager.ValidateModifyStudio(updatedStudio, qb); err != nil { + if err := manager.ValidateModifyStudio(ctx, updatedStudio, qb); err != nil { return err } var err error - s, err = qb.Update(updatedStudio) + s, err = qb.Update(ctx, updatedStudio) if err != nil { return err } // update image table if len(imageData) > 0 { - if err := qb.UpdateImage(s.ID, imageData); err != nil { + if err := qb.UpdateImage(ctx, s.ID, imageData); err != nil { return err } } else if imageIncluded { // must be unsetting - if err := qb.DestroyImage(s.ID); err != nil { + if err := qb.DestroyImage(ctx, s.ID); err != nil { return err } } @@ -183,17 +183,17 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input StudioUpdateI // Save the stash_ids if translator.hasField("stash_ids") { stashIDJoins := models.StashIDsFromInput(input.StashIds) - if err := qb.UpdateStashIDs(studioID, stashIDJoins); err != nil { + if err := qb.UpdateStashIDs(ctx, studioID, stashIDJoins); err != nil { return err } } if translator.hasField("aliases") { - if err := studio.EnsureAliasesUnique(studioID, input.Aliases, qb); err != nil { + if err := studio.EnsureAliasesUnique(ctx, studioID, input.Aliases, qb); err != nil { return err } - if err := qb.UpdateAliases(studioID, input.Aliases); err != nil { + if err := qb.UpdateAliases(ctx, studioID, input.Aliases); err != nil { return err } } @@ -213,8 +213,8 @@ func (r *mutationResolver) StudioDestroy(ctx context.Context, input StudioDestro return false, err } - if err := r.withTxn(ctx, func(repo models.Repository) error { - return repo.Studio().Destroy(id) + if err := r.withTxn(ctx, func(ctx context.Context) error { + return r.repository.Studio.Destroy(ctx, id) }); err != nil { return false, err } @@ -230,10 +230,10 @@ func (r *mutationResolver) StudiosDestroy(ctx context.Context, studioIDs []strin return false, err } - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Studio() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Studio for _, id := range ids { - if err := qb.Destroy(id); err != nil { + if err := qb.Destroy(ctx, id); err != nil { return err } } diff --git a/internal/api/resolver_mutation_tag.go b/internal/api/resolver_mutation_tag.go index cff553d72..f5befeba7 100644 --- a/internal/api/resolver_mutation_tag.go +++ b/internal/api/resolver_mutation_tag.go @@ -15,8 +15,8 @@ import ( ) func (r *mutationResolver) getTag(ctx context.Context, id int) (ret *models.Tag, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Tag().Find(id) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Tag.Find(ctx, id) return err }); err != nil { return nil, err @@ -68,44 +68,44 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input TagCreateInput) // Start the transaction and save the tag var t *models.Tag - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Tag() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Tag // ensure name is unique - if err := tag.EnsureTagNameUnique(0, newTag.Name, qb); err != nil { + if err := tag.EnsureTagNameUnique(ctx, 0, newTag.Name, qb); err != nil { return err } - t, err = qb.Create(newTag) + t, err = qb.Create(ctx, newTag) if err != nil { return err } // update image table if len(imageData) > 0 { - if err := qb.UpdateImage(t.ID, imageData); err != nil { + if err := qb.UpdateImage(ctx, t.ID, imageData); err != nil { return err } } if len(input.Aliases) > 0 { - if err := tag.EnsureAliasesUnique(t.ID, input.Aliases, qb); err != nil { + if err := tag.EnsureAliasesUnique(ctx, t.ID, input.Aliases, qb); err != nil { return err } - if err := qb.UpdateAliases(t.ID, input.Aliases); err != nil { + if err := qb.UpdateAliases(ctx, t.ID, input.Aliases); err != nil { return err } } if len(parentIDs) > 0 { - if err := qb.UpdateParentTags(t.ID, parentIDs); err != nil { + if err := qb.UpdateParentTags(ctx, t.ID, parentIDs); err != nil { return err } } if len(childIDs) > 0 { - if err := qb.UpdateChildTags(t.ID, childIDs); err != nil { + if err := qb.UpdateChildTags(ctx, t.ID, childIDs); err != nil { return err } } @@ -113,7 +113,7 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input TagCreateInput) // FIXME: This should be called before any changes are made, but // requires a rewrite of ValidateHierarchy. if len(parentIDs) > 0 || len(childIDs) > 0 { - if err := tag.ValidateHierarchy(t, parentIDs, childIDs, qb); err != nil { + if err := tag.ValidateHierarchy(ctx, t, parentIDs, childIDs, qb); err != nil { return err } } @@ -168,11 +168,11 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input TagUpdateInput) // Start the transaction and save the tag var t *models.Tag - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Tag() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Tag // ensure name is unique - t, err = qb.Find(tagID) + t, err = qb.Find(ctx, tagID) if err != nil { return err } @@ -188,48 +188,48 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input TagUpdateInput) } if input.Name != nil && t.Name != *input.Name { - if err := tag.EnsureTagNameUnique(tagID, *input.Name, qb); err != nil { + if err := tag.EnsureTagNameUnique(ctx, tagID, *input.Name, qb); err != nil { return err } updatedTag.Name = input.Name } - t, err = qb.Update(updatedTag) + t, err = qb.Update(ctx, updatedTag) if err != nil { return err } // update image table if len(imageData) > 0 { - if err := qb.UpdateImage(tagID, imageData); err != nil { + if err := qb.UpdateImage(ctx, tagID, imageData); err != nil { return err } } else if imageIncluded { // must be unsetting - if err := qb.DestroyImage(tagID); err != nil { + if err := qb.DestroyImage(ctx, tagID); err != nil { return err } } if translator.hasField("aliases") { - if err := tag.EnsureAliasesUnique(tagID, input.Aliases, qb); err != nil { + if err := tag.EnsureAliasesUnique(ctx, tagID, input.Aliases, qb); err != nil { return err } - if err := qb.UpdateAliases(tagID, input.Aliases); err != nil { + if err := qb.UpdateAliases(ctx, tagID, input.Aliases); err != nil { return err } } if parentIDs != nil { - if err := qb.UpdateParentTags(tagID, parentIDs); err != nil { + if err := qb.UpdateParentTags(ctx, tagID, parentIDs); err != nil { return err } } if childIDs != nil { - if err := qb.UpdateChildTags(tagID, childIDs); err != nil { + if err := qb.UpdateChildTags(ctx, tagID, childIDs); err != nil { return err } } @@ -237,7 +237,7 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input TagUpdateInput) // FIXME: This should be called before any changes are made, but // requires a rewrite of ValidateHierarchy. if parentIDs != nil || childIDs != nil { - if err := tag.ValidateHierarchy(t, parentIDs, childIDs, qb); err != nil { + if err := tag.ValidateHierarchy(ctx, t, parentIDs, childIDs, qb); err != nil { logger.Errorf("Error saving tag: %s", err) return err } @@ -258,8 +258,8 @@ func (r *mutationResolver) TagDestroy(ctx context.Context, input TagDestroyInput return false, err } - if err := r.withTxn(ctx, func(repo models.Repository) error { - return repo.Tag().Destroy(tagID) + if err := r.withTxn(ctx, func(ctx context.Context) error { + return r.repository.Tag.Destroy(ctx, tagID) }); err != nil { return false, err } @@ -275,10 +275,10 @@ func (r *mutationResolver) TagsDestroy(ctx context.Context, tagIDs []string) (bo return false, err } - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Tag() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Tag for _, id := range ids { - if err := qb.Destroy(id); err != nil { + if err := qb.Destroy(ctx, id); err != nil { return err } } @@ -311,11 +311,11 @@ func (r *mutationResolver) TagsMerge(ctx context.Context, input TagsMergeInput) } var t *models.Tag - if err := r.withTxn(ctx, func(repo models.Repository) error { - qb := repo.Tag() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Tag var err error - t, err = qb.Find(destination) + t, err = qb.Find(ctx, destination) if err != nil { return err } @@ -324,25 +324,25 @@ func (r *mutationResolver) TagsMerge(ctx context.Context, input TagsMergeInput) return fmt.Errorf("Tag with ID %d not found", destination) } - parents, children, err := tag.MergeHierarchy(destination, source, qb) + parents, children, err := tag.MergeHierarchy(ctx, destination, source, qb) if err != nil { return err } - if err = qb.Merge(source, destination); err != nil { + if err = qb.Merge(ctx, source, destination); err != nil { return err } - err = qb.UpdateParentTags(destination, parents) + err = qb.UpdateParentTags(ctx, destination, parents) if err != nil { return err } - err = qb.UpdateChildTags(destination, children) + err = qb.UpdateChildTags(ctx, destination, children) if err != nil { return err } - err = tag.ValidateHierarchy(t, parents, children, qb) + err = tag.ValidateHierarchy(ctx, t, parents, children, qb) if err != nil { logger.Errorf("Error merging tag: %s", err) return err diff --git a/internal/api/resolver_mutation_tag_test.go b/internal/api/resolver_mutation_tag_test.go index c8d43f5f8..91b87794d 100644 --- a/internal/api/resolver_mutation_tag_test.go +++ b/internal/api/resolver_mutation_tag_test.go @@ -16,17 +16,23 @@ import ( // TODO - move this into a common area func newResolver() *Resolver { return &Resolver{ - txnManager: mocks.NewTransactionManager(), + txnManager: &mocks.TxnManager{}, + repository: mocks.NewTxnRepository(), hookExecutor: &mockHookExecutor{}, } } -const tagName = "tagName" -const errTagName = "errTagName" +const ( + tagName = "tagName" + errTagName = "errTagName" -const existingTagID = 1 -const existingTagName = "existingTagName" -const newTagID = 2 + existingTagID = 1 + existingTagName = "existingTagName" + + newTagID = 2 +) + +var testCtx = context.Background() type mockHookExecutor struct{} @@ -36,7 +42,7 @@ func (*mockHookExecutor) ExecutePostHooks(ctx context.Context, id int, hookType func TestTagCreate(t *testing.T) { r := newResolver() - tagRW := r.txnManager.(*mocks.TransactionManager).Tag().(*mocks.TagReaderWriter) + tagRW := r.repository.Tag.(*mocks.TagReaderWriter) pp := 1 findFilter := &models.FindFilterType{ @@ -61,25 +67,25 @@ func TestTagCreate(t *testing.T) { } } - tagRW.On("Query", tagFilterForName(existingTagName), findFilter).Return([]*models.Tag{ + tagRW.On("Query", testCtx, tagFilterForName(existingTagName), findFilter).Return([]*models.Tag{ { ID: existingTagID, Name: existingTagName, }, }, 1, nil).Once() - tagRW.On("Query", tagFilterForName(errTagName), findFilter).Return(nil, 0, nil).Once() - tagRW.On("Query", tagFilterForAlias(errTagName), findFilter).Return(nil, 0, nil).Once() + tagRW.On("Query", testCtx, tagFilterForName(errTagName), findFilter).Return(nil, 0, nil).Once() + tagRW.On("Query", testCtx, tagFilterForAlias(errTagName), findFilter).Return(nil, 0, nil).Once() expectedErr := errors.New("TagCreate error") - tagRW.On("Create", mock.AnythingOfType("models.Tag")).Return(nil, expectedErr) + tagRW.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(nil, expectedErr) - _, err := r.Mutation().TagCreate(context.TODO(), TagCreateInput{ + _, err := r.Mutation().TagCreate(testCtx, TagCreateInput{ Name: existingTagName, }) assert.NotNil(t, err) - _, err = r.Mutation().TagCreate(context.TODO(), TagCreateInput{ + _, err = r.Mutation().TagCreate(testCtx, TagCreateInput{ Name: errTagName, }) @@ -87,18 +93,18 @@ func TestTagCreate(t *testing.T) { tagRW.AssertExpectations(t) r = newResolver() - tagRW = r.txnManager.(*mocks.TransactionManager).Tag().(*mocks.TagReaderWriter) + tagRW = r.repository.Tag.(*mocks.TagReaderWriter) - tagRW.On("Query", tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once() - tagRW.On("Query", tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once() + tagRW.On("Query", testCtx, tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once() + tagRW.On("Query", testCtx, tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once() newTag := &models.Tag{ ID: newTagID, Name: tagName, } - tagRW.On("Create", mock.AnythingOfType("models.Tag")).Return(newTag, nil) - tagRW.On("Find", newTagID).Return(newTag, nil) + tagRW.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(newTag, nil) + tagRW.On("Find", testCtx, newTagID).Return(newTag, nil) - tag, err := r.Mutation().TagCreate(context.TODO(), TagCreateInput{ + tag, err := r.Mutation().TagCreate(testCtx, TagCreateInput{ Name: tagName, }) diff --git a/internal/api/resolver_query_configuration.go b/internal/api/resolver_query_configuration.go index 39c463260..f3469de97 100644 --- a/internal/api/resolver_query_configuration.go +++ b/internal/api/resolver_query_configuration.go @@ -222,7 +222,7 @@ func makeConfigUIResult() map[string]interface{} { } func (r *queryResolver) ValidateStashBoxCredentials(ctx context.Context, input config.StashBoxInput) (*StashBoxValidationResult, error) { - client := stashbox.NewClient(models.StashBox{Endpoint: input.Endpoint, APIKey: input.APIKey}, r.txnManager) + client := stashbox.NewClient(models.StashBox{Endpoint: input.Endpoint, APIKey: input.APIKey}, r.txnManager, r.stashboxRepository()) user, err := client.GetUser(ctx) valid := user != nil && user.Me != nil diff --git a/internal/api/resolver_query_find_gallery.go b/internal/api/resolver_query_find_gallery.go index bd6d07c0f..ee12471d1 100644 --- a/internal/api/resolver_query_find_gallery.go +++ b/internal/api/resolver_query_find_gallery.go @@ -13,8 +13,8 @@ func (r *queryResolver) FindGallery(ctx context.Context, id string) (ret *models return nil, err } - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Gallery().Find(idInt) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Gallery.Find(ctx, idInt) return err }); err != nil { return nil, err @@ -24,8 +24,8 @@ func (r *queryResolver) FindGallery(ctx context.Context, id string) (ret *models } func (r *queryResolver) FindGalleries(ctx context.Context, galleryFilter *models.GalleryFilterType, filter *models.FindFilterType) (ret *FindGalleriesResultType, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - galleries, total, err := repo.Gallery().Query(galleryFilter, filter) + if err := r.withTxn(ctx, func(ctx context.Context) error { + galleries, total, err := r.repository.Gallery.Query(ctx, galleryFilter, filter) if err != nil { return err } diff --git a/internal/api/resolver_query_find_image.go b/internal/api/resolver_query_find_image.go index 0f01b42a1..f1269dce8 100644 --- a/internal/api/resolver_query_find_image.go +++ b/internal/api/resolver_query_find_image.go @@ -12,8 +12,8 @@ import ( func (r *queryResolver) FindImage(ctx context.Context, id *string, checksum *string) (*models.Image, error) { var image *models.Image - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - qb := repo.Image() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Image var err error if id != nil { @@ -22,12 +22,12 @@ func (r *queryResolver) FindImage(ctx context.Context, id *string, checksum *str return err } - image, err = qb.Find(idInt) + image, err = qb.Find(ctx, idInt) if err != nil { return err } } else if checksum != nil { - image, err = qb.FindByChecksum(*checksum) + image, err = qb.FindByChecksum(ctx, *checksum) } return err @@ -39,12 +39,12 @@ func (r *queryResolver) FindImage(ctx context.Context, id *string, checksum *str } func (r *queryResolver) FindImages(ctx context.Context, imageFilter *models.ImageFilterType, imageIds []int, filter *models.FindFilterType) (ret *FindImagesResultType, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - qb := repo.Image() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Image fields := graphql.CollectAllFields(ctx) - result, err := qb.Query(models.ImageQueryOptions{ + result, err := qb.Query(ctx, models.ImageQueryOptions{ QueryOptions: models.QueryOptions{ FindFilter: filter, Count: stringslice.StrInclude(fields, "count"), @@ -57,7 +57,7 @@ func (r *queryResolver) FindImages(ctx context.Context, imageFilter *models.Imag return err } - images, err := result.Resolve() + images, err := result.Resolve(ctx) if err != nil { return err } diff --git a/internal/api/resolver_query_find_movie.go b/internal/api/resolver_query_find_movie.go index 16a0bbe3b..7505c7f36 100644 --- a/internal/api/resolver_query_find_movie.go +++ b/internal/api/resolver_query_find_movie.go @@ -13,8 +13,8 @@ func (r *queryResolver) FindMovie(ctx context.Context, id string) (ret *models.M return nil, err } - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Movie().Find(idInt) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Movie.Find(ctx, idInt) return err }); err != nil { return nil, err @@ -24,8 +24,8 @@ func (r *queryResolver) FindMovie(ctx context.Context, id string) (ret *models.M } func (r *queryResolver) FindMovies(ctx context.Context, movieFilter *models.MovieFilterType, filter *models.FindFilterType) (ret *FindMoviesResultType, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - movies, total, err := repo.Movie().Query(movieFilter, filter) + if err := r.withTxn(ctx, func(ctx context.Context) error { + movies, total, err := r.repository.Movie.Query(ctx, movieFilter, filter) if err != nil { return err } @@ -44,8 +44,8 @@ func (r *queryResolver) FindMovies(ctx context.Context, movieFilter *models.Movi } func (r *queryResolver) AllMovies(ctx context.Context) (ret []*models.Movie, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Movie().All() + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Movie.All(ctx) return err }); err != nil { return nil, err diff --git a/internal/api/resolver_query_find_performer.go b/internal/api/resolver_query_find_performer.go index 5d7bb2716..4314b0f69 100644 --- a/internal/api/resolver_query_find_performer.go +++ b/internal/api/resolver_query_find_performer.go @@ -13,8 +13,8 @@ func (r *queryResolver) FindPerformer(ctx context.Context, id string) (ret *mode return nil, err } - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Performer().Find(idInt) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Performer.Find(ctx, idInt) return err }); err != nil { return nil, err @@ -24,8 +24,8 @@ func (r *queryResolver) FindPerformer(ctx context.Context, id string) (ret *mode } func (r *queryResolver) FindPerformers(ctx context.Context, performerFilter *models.PerformerFilterType, filter *models.FindFilterType) (ret *FindPerformersResultType, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - performers, total, err := repo.Performer().Query(performerFilter, filter) + if err := r.withTxn(ctx, func(ctx context.Context) error { + performers, total, err := r.repository.Performer.Query(ctx, performerFilter, filter) if err != nil { return err } @@ -43,8 +43,8 @@ func (r *queryResolver) FindPerformers(ctx context.Context, performerFilter *mod } func (r *queryResolver) AllPerformers(ctx context.Context) (ret []*models.Performer, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Performer().All() + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Performer.All(ctx) return err }); err != nil { return nil, err diff --git a/internal/api/resolver_query_find_saved_filter.go b/internal/api/resolver_query_find_saved_filter.go index a28ef2f59..7b934f581 100644 --- a/internal/api/resolver_query_find_saved_filter.go +++ b/internal/api/resolver_query_find_saved_filter.go @@ -13,8 +13,8 @@ func (r *queryResolver) FindSavedFilter(ctx context.Context, id string) (ret *mo return nil, err } - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.SavedFilter().Find(idInt) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.SavedFilter.Find(ctx, idInt) return err }); err != nil { return nil, err @@ -23,11 +23,11 @@ func (r *queryResolver) FindSavedFilter(ctx context.Context, id string) (ret *mo } func (r *queryResolver) FindSavedFilters(ctx context.Context, mode *models.FilterMode) (ret []*models.SavedFilter, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + if err := r.withTxn(ctx, func(ctx context.Context) error { if mode != nil { - ret, err = repo.SavedFilter().FindByMode(*mode) + ret, err = r.repository.SavedFilter.FindByMode(ctx, *mode) } else { - ret, err = repo.SavedFilter().All() + ret, err = r.repository.SavedFilter.All(ctx) } return err }); err != nil { @@ -37,8 +37,8 @@ func (r *queryResolver) FindSavedFilters(ctx context.Context, mode *models.Filte } func (r *queryResolver) FindDefaultFilter(ctx context.Context, mode models.FilterMode) (ret *models.SavedFilter, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.SavedFilter().FindDefault(mode) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.SavedFilter.FindDefault(ctx, mode) return err }); err != nil { return nil, err diff --git a/internal/api/resolver_query_find_scene.go b/internal/api/resolver_query_find_scene.go index 486ca744a..823e86503 100644 --- a/internal/api/resolver_query_find_scene.go +++ b/internal/api/resolver_query_find_scene.go @@ -12,20 +12,20 @@ import ( func (r *queryResolver) FindScene(ctx context.Context, id *string, checksum *string) (*models.Scene, error) { var scene *models.Scene - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - qb := repo.Scene() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Scene var err error if id != nil { idInt, err := strconv.Atoi(*id) if err != nil { return err } - scene, err = qb.Find(idInt) + scene, err = qb.Find(ctx, idInt) if err != nil { return err } } else if checksum != nil { - scene, err = qb.FindByChecksum(*checksum) + scene, err = qb.FindByChecksum(ctx, *checksum) } return err @@ -39,18 +39,18 @@ func (r *queryResolver) FindScene(ctx context.Context, id *string, checksum *str func (r *queryResolver) FindSceneByHash(ctx context.Context, input SceneHashInput) (*models.Scene, error) { var scene *models.Scene - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - qb := repo.Scene() + if err := r.withTxn(ctx, func(ctx context.Context) error { + qb := r.repository.Scene var err error if input.Checksum != nil { - scene, err = qb.FindByChecksum(*input.Checksum) + scene, err = qb.FindByChecksum(ctx, *input.Checksum) if err != nil { return err } } if scene == nil && input.Oshash != nil { - scene, err = qb.FindByOSHash(*input.Oshash) + scene, err = qb.FindByOSHash(ctx, *input.Oshash) if err != nil { return err } @@ -65,7 +65,7 @@ func (r *queryResolver) FindSceneByHash(ctx context.Context, input SceneHashInpu } func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.SceneFilterType, sceneIDs []int, filter *models.FindFilterType) (ret *FindScenesResultType, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + if err := r.withTxn(ctx, func(ctx context.Context) error { var scenes []*models.Scene var err error @@ -73,7 +73,7 @@ func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.Scen result := &models.SceneQueryResult{} if len(sceneIDs) > 0 { - scenes, err = repo.Scene().FindMany(sceneIDs) + scenes, err = r.repository.Scene.FindMany(ctx, sceneIDs) if err == nil { result.Count = len(scenes) for _, s := range scenes { @@ -83,7 +83,7 @@ func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.Scen } } } else { - result, err = repo.Scene().Query(models.SceneQueryOptions{ + result, err = r.repository.Scene.Query(ctx, models.SceneQueryOptions{ QueryOptions: models.QueryOptions{ FindFilter: filter, Count: stringslice.StrInclude(fields, "count"), @@ -93,7 +93,7 @@ func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.Scen TotalSize: stringslice.StrInclude(fields, "filesize"), }) if err == nil { - scenes, err = result.Resolve() + scenes, err = result.Resolve(ctx) } } @@ -117,7 +117,7 @@ func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.Scen } func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *models.FindFilterType) (ret *FindScenesResultType, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + if err := r.withTxn(ctx, func(ctx context.Context) error { sceneFilter := &models.SceneFilterType{} @@ -138,7 +138,7 @@ func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *model fields := graphql.CollectAllFields(ctx) - result, err := repo.Scene().Query(models.SceneQueryOptions{ + result, err := r.repository.Scene.Query(ctx, models.SceneQueryOptions{ QueryOptions: models.QueryOptions{ FindFilter: queryFilter, Count: stringslice.StrInclude(fields, "count"), @@ -151,7 +151,7 @@ func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *model return err } - scenes, err := result.Resolve() + scenes, err := result.Resolve(ctx) if err != nil { return err } @@ -174,8 +174,14 @@ func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *model func (r *queryResolver) ParseSceneFilenames(ctx context.Context, filter *models.FindFilterType, config manager.SceneParserInput) (ret *SceneParserResultType, err error) { parser := manager.NewSceneFilenameParser(filter, config) - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - result, count, err := parser.Parse(repo) + if err := r.withTxn(ctx, func(ctx context.Context) error { + result, count, err := parser.Parse(ctx, manager.SceneFilenameParserRepository{ + Scene: r.repository.Scene, + Performer: r.repository.Performer, + Studio: r.repository.Studio, + Movie: r.repository.Movie, + Tag: r.repository.Tag, + }) if err != nil { return err @@ -199,8 +205,8 @@ func (r *queryResolver) FindDuplicateScenes(ctx context.Context, distance *int) if distance != nil { dist = *distance } - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Scene().FindDuplicates(dist) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Scene.FindDuplicates(ctx, dist) return err }); err != nil { return nil, err diff --git a/internal/api/resolver_query_find_scene_marker.go b/internal/api/resolver_query_find_scene_marker.go index 7b0f1ee12..03b9e261a 100644 --- a/internal/api/resolver_query_find_scene_marker.go +++ b/internal/api/resolver_query_find_scene_marker.go @@ -7,8 +7,8 @@ import ( ) func (r *queryResolver) FindSceneMarkers(ctx context.Context, sceneMarkerFilter *models.SceneMarkerFilterType, filter *models.FindFilterType) (ret *FindSceneMarkersResultType, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - sceneMarkers, total, err := repo.SceneMarker().Query(sceneMarkerFilter, filter) + if err := r.withTxn(ctx, func(ctx context.Context) error { + sceneMarkers, total, err := r.repository.SceneMarker.Query(ctx, sceneMarkerFilter, filter) if err != nil { return err } diff --git a/internal/api/resolver_query_find_studio.go b/internal/api/resolver_query_find_studio.go index 24591e053..0bd17b9ad 100644 --- a/internal/api/resolver_query_find_studio.go +++ b/internal/api/resolver_query_find_studio.go @@ -13,9 +13,9 @@ func (r *queryResolver) FindStudio(ctx context.Context, id string) (ret *models. return nil, err } - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + if err := r.withTxn(ctx, func(ctx context.Context) error { var err error - ret, err = repo.Studio().Find(idInt) + ret, err = r.repository.Studio.Find(ctx, idInt) return err }); err != nil { return nil, err @@ -25,8 +25,8 @@ func (r *queryResolver) FindStudio(ctx context.Context, id string) (ret *models. } func (r *queryResolver) FindStudios(ctx context.Context, studioFilter *models.StudioFilterType, filter *models.FindFilterType) (ret *FindStudiosResultType, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - studios, total, err := repo.Studio().Query(studioFilter, filter) + if err := r.withTxn(ctx, func(ctx context.Context) error { + studios, total, err := r.repository.Studio.Query(ctx, studioFilter, filter) if err != nil { return err } @@ -45,8 +45,8 @@ func (r *queryResolver) FindStudios(ctx context.Context, studioFilter *models.St } func (r *queryResolver) AllStudios(ctx context.Context) (ret []*models.Studio, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Studio().All() + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Studio.All(ctx) return err }); err != nil { return nil, err diff --git a/internal/api/resolver_query_find_tag.go b/internal/api/resolver_query_find_tag.go index 21aff9f4b..77bd57f98 100644 --- a/internal/api/resolver_query_find_tag.go +++ b/internal/api/resolver_query_find_tag.go @@ -13,8 +13,8 @@ func (r *queryResolver) FindTag(ctx context.Context, id string) (ret *models.Tag return nil, err } - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Tag().Find(idInt) + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Tag.Find(ctx, idInt) return err }); err != nil { return nil, err @@ -24,8 +24,8 @@ func (r *queryResolver) FindTag(ctx context.Context, id string) (ret *models.Tag } func (r *queryResolver) FindTags(ctx context.Context, tagFilter *models.TagFilterType, filter *models.FindFilterType) (ret *FindTagsResultType, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - tags, total, err := repo.Tag().Query(tagFilter, filter) + if err := r.withTxn(ctx, func(ctx context.Context) error { + tags, total, err := r.repository.Tag.Query(ctx, tagFilter, filter) if err != nil { return err } @@ -44,8 +44,8 @@ func (r *queryResolver) FindTags(ctx context.Context, tagFilter *models.TagFilte } func (r *queryResolver) AllTags(ctx context.Context) (ret []*models.Tag, err error) { - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { - ret, err = repo.Tag().All() + if err := r.withTxn(ctx, func(ctx context.Context) error { + ret, err = r.repository.Tag.All(ctx) return err }); err != nil { return nil, err diff --git a/internal/api/resolver_query_scene.go b/internal/api/resolver_query_scene.go index 7ff77d2a8..c1ba0edca 100644 --- a/internal/api/resolver_query_scene.go +++ b/internal/api/resolver_query_scene.go @@ -14,10 +14,10 @@ import ( func (r *queryResolver) SceneStreams(ctx context.Context, id *string) ([]*manager.SceneStreamEndpoint, error) { // find the scene var scene *models.Scene - if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + if err := r.withTxn(ctx, func(ctx context.Context) error { idInt, _ := strconv.Atoi(*id) var err error - scene, err = repo.Scene().Find(idInt) + scene, err = r.repository.Scene.Find(ctx, idInt) return err }); err != nil { return nil, err diff --git a/internal/api/resolver_query_scraper.go b/internal/api/resolver_query_scraper.go index 92c5200ce..85f47ee2c 100644 --- a/internal/api/resolver_query_scraper.go +++ b/internal/api/resolver_query_scraper.go @@ -234,7 +234,7 @@ func (r *queryResolver) getStashBoxClient(index int) (*stashbox.Client, error) { return nil, fmt.Errorf("%w: invalid stash_box_index %d", ErrInput, index) } - return stashbox.NewClient(*boxes[index], r.txnManager), nil + return stashbox.NewClient(*boxes[index], r.txnManager, r.stashboxRepository()), nil } func (r *queryResolver) ScrapeSingleScene(ctx context.Context, source scraper.Source, input ScrapeSingleSceneInput) ([]*scraper.ScrapedScene, error) { diff --git a/internal/api/routes_image.go b/internal/api/routes_image.go index 8ba2e50d5..d66ccf7cc 100644 --- a/internal/api/routes_image.go +++ b/internal/api/routes_image.go @@ -13,17 +13,24 @@ import ( "github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/txn" ) +type ImageFinder interface { + Find(ctx context.Context, id int) (*models.Image, error) + FindByChecksum(ctx context.Context, checksum string) (*models.Image, error) +} + type imageRoutes struct { - txnManager models.TransactionManager + txnManager txn.Manager + imageFinder ImageFinder } func (rs imageRoutes) Routes() chi.Router { r := chi.NewRouter() r.Route("/{imageId}", func(r chi.Router) { - r.Use(ImageCtx) + r.Use(rs.ImageCtx) r.Get("/image", rs.Image) r.Get("/thumbnail", rs.Thumbnail) @@ -85,18 +92,18 @@ func (rs imageRoutes) Image(w http.ResponseWriter, r *http.Request) { // endregion -func ImageCtx(next http.Handler) http.Handler { +func (rs imageRoutes) ImageCtx(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { imageIdentifierQueryParam := chi.URLParam(r, "imageId") imageID, _ := strconv.Atoi(imageIdentifierQueryParam) var image *models.Image - readTxnErr := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { - qb := repo.Image() + readTxnErr := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + qb := rs.imageFinder if imageID == 0 { - image, _ = qb.FindByChecksum(imageIdentifierQueryParam) + image, _ = qb.FindByChecksum(ctx, imageIdentifierQueryParam) } else { - image, _ = qb.Find(imageID) + image, _ = qb.Find(ctx, imageID) } return nil diff --git a/internal/api/routes_movie.go b/internal/api/routes_movie.go index 439b1e4d3..8fbccdb53 100644 --- a/internal/api/routes_movie.go +++ b/internal/api/routes_movie.go @@ -6,21 +6,28 @@ import ( "strconv" "github.com/go-chi/chi" - "github.com/stashapp/stash/internal/manager" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/utils" ) +type MovieFinder interface { + GetFrontImage(ctx context.Context, movieID int) ([]byte, error) + GetBackImage(ctx context.Context, movieID int) ([]byte, error) + Find(ctx context.Context, id int) (*models.Movie, error) +} + type movieRoutes struct { - txnManager models.TransactionManager + txnManager txn.Manager + movieFinder MovieFinder } func (rs movieRoutes) Routes() chi.Router { r := chi.NewRouter() r.Route("/{movieId}", func(r chi.Router) { - r.Use(MovieCtx) + r.Use(rs.MovieCtx) r.Get("/frontimage", rs.FrontImage) r.Get("/backimage", rs.BackImage) }) @@ -33,8 +40,8 @@ func (rs movieRoutes) FrontImage(w http.ResponseWriter, r *http.Request) { defaultParam := r.URL.Query().Get("default") var image []byte if defaultParam != "true" { - err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { - image, _ = repo.Movie().GetFrontImage(movie.ID) + err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + image, _ = rs.movieFinder.GetFrontImage(ctx, movie.ID) return nil }) if err != nil { @@ -56,8 +63,8 @@ func (rs movieRoutes) BackImage(w http.ResponseWriter, r *http.Request) { defaultParam := r.URL.Query().Get("default") var image []byte if defaultParam != "true" { - err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { - image, _ = repo.Movie().GetBackImage(movie.ID) + err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + image, _ = rs.movieFinder.GetBackImage(ctx, movie.ID) return nil }) if err != nil { @@ -74,7 +81,7 @@ func (rs movieRoutes) BackImage(w http.ResponseWriter, r *http.Request) { } } -func MovieCtx(next http.Handler) http.Handler { +func (rs movieRoutes) MovieCtx(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { movieID, err := strconv.Atoi(chi.URLParam(r, "movieId")) if err != nil { @@ -83,9 +90,9 @@ func MovieCtx(next http.Handler) http.Handler { } var movie *models.Movie - if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { var err error - movie, err = repo.Movie().Find(movieID) + movie, err = rs.movieFinder.Find(ctx, movieID) return err }); err != nil { http.Error(w, http.StatusText(404), 404) diff --git a/internal/api/routes_performer.go b/internal/api/routes_performer.go index e5c0bb862..15ad3c743 100644 --- a/internal/api/routes_performer.go +++ b/internal/api/routes_performer.go @@ -6,22 +6,28 @@ import ( "strconv" "github.com/go-chi/chi" - "github.com/stashapp/stash/internal/manager" "github.com/stashapp/stash/internal/manager/config" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/utils" ) +type PerformerFinder interface { + Find(ctx context.Context, id int) (*models.Performer, error) + GetImage(ctx context.Context, performerID int) ([]byte, error) +} + type performerRoutes struct { - txnManager models.TransactionManager + txnManager txn.Manager + performerFinder PerformerFinder } func (rs performerRoutes) Routes() chi.Router { r := chi.NewRouter() r.Route("/{performerId}", func(r chi.Router) { - r.Use(PerformerCtx) + r.Use(rs.PerformerCtx) r.Get("/image", rs.Image) }) @@ -34,8 +40,8 @@ func (rs performerRoutes) Image(w http.ResponseWriter, r *http.Request) { var image []byte if defaultParam != "true" { - readTxnErr := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { - image, _ = repo.Performer().GetImage(performer.ID) + readTxnErr := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + image, _ = rs.performerFinder.GetImage(ctx, performer.ID) return nil }) if readTxnErr != nil { @@ -52,7 +58,7 @@ func (rs performerRoutes) Image(w http.ResponseWriter, r *http.Request) { } } -func PerformerCtx(next http.Handler) http.Handler { +func (rs performerRoutes) PerformerCtx(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { performerID, err := strconv.Atoi(chi.URLParam(r, "performerId")) if err != nil { @@ -61,9 +67,9 @@ func PerformerCtx(next http.Handler) http.Handler { } var performer *models.Performer - if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { var err error - performer, err = repo.Performer().Find(performerID) + performer, err = rs.performerFinder.Find(ctx, performerID) return err }); err != nil { http.Error(w, http.StatusText(404), 404) diff --git a/internal/api/routes_scene.go b/internal/api/routes_scene.go index 3612da72d..069e90876 100644 --- a/internal/api/routes_scene.go +++ b/internal/api/routes_scene.go @@ -15,18 +15,36 @@ import ( "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/scene" + "github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/utils" ) +type SceneFinder interface { + manager.SceneCoverGetter + + scene.IDFinder + FindByChecksum(ctx context.Context, checksum string) (*models.Scene, error) + FindByOSHash(ctx context.Context, oshash string) (*models.Scene, error) + GetCaptions(ctx context.Context, sceneID int) ([]*models.SceneCaption, error) +} + +type SceneMarkerFinder interface { + Find(ctx context.Context, id int) (*models.SceneMarker, error) + FindBySceneID(ctx context.Context, sceneID int) ([]*models.SceneMarker, error) +} + type sceneRoutes struct { - txnManager models.TransactionManager + txnManager txn.Manager + sceneFinder SceneFinder + sceneMarkerFinder SceneMarkerFinder + tagFinder scene.MarkerTagFinder } func (rs sceneRoutes) Routes() chi.Router { r := chi.NewRouter() r.Route("/{sceneId}", func(r chi.Router) { - r.Use(SceneCtx) + r.Use(rs.SceneCtx) // streaming endpoints r.Get("/stream", rs.StreamDirect) @@ -48,8 +66,8 @@ func (rs sceneRoutes) Routes() chi.Router { r.Get("/scene_marker/{sceneMarkerId}/preview", rs.SceneMarkerPreview) r.Get("/scene_marker/{sceneMarkerId}/screenshot", rs.SceneMarkerScreenshot) }) - r.With(SceneCtx).Get("/{sceneId}_thumbs.vtt", rs.VttThumbs) - r.With(SceneCtx).Get("/{sceneId}_sprite.jpg", rs.VttSprite) + r.With(rs.SceneCtx).Get("/{sceneId}_thumbs.vtt", rs.VttThumbs) + r.With(rs.SceneCtx).Get("/{sceneId}_sprite.jpg", rs.VttSprite) return r } @@ -60,7 +78,8 @@ func (rs sceneRoutes) StreamDirect(w http.ResponseWriter, r *http.Request) { scene := r.Context().Value(sceneKey).(*models.Scene) ss := manager.SceneServer{ - TXNManager: rs.txnManager, + TxnManager: rs.txnManager, + SceneCoverGetter: rs.sceneFinder, } ss.StreamSceneDirect(scene, w, r) } @@ -190,7 +209,8 @@ func (rs sceneRoutes) Screenshot(w http.ResponseWriter, r *http.Request) { scene := r.Context().Value(sceneKey).(*models.Scene) ss := manager.SceneServer{ - TXNManager: rs.txnManager, + TxnManager: rs.txnManager, + SceneCoverGetter: rs.sceneFinder, } ss.ServeScreenshot(scene, w, r) } @@ -221,16 +241,16 @@ func (rs sceneRoutes) getChapterVttTitle(ctx context.Context, marker *models.Sce } var ret string - if err := rs.txnManager.WithReadTxn(ctx, func(repo models.ReaderRepository) error { - qb := repo.Tag() - primaryTag, err := qb.Find(marker.PrimaryTagID) + if err := txn.WithTxn(ctx, rs.txnManager, func(ctx context.Context) error { + qb := rs.tagFinder + primaryTag, err := qb.Find(ctx, marker.PrimaryTagID) if err != nil { return err } ret = primaryTag.Name - tags, err := qb.FindBySceneMarkerID(marker.ID) + tags, err := qb.FindBySceneMarkerID(ctx, marker.ID) if err != nil { return err } @@ -250,9 +270,9 @@ func (rs sceneRoutes) getChapterVttTitle(ctx context.Context, marker *models.Sce func (rs sceneRoutes) ChapterVtt(w http.ResponseWriter, r *http.Request) { scene := r.Context().Value(sceneKey).(*models.Scene) var sceneMarkers []*models.SceneMarker - if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { var err error - sceneMarkers, err = repo.SceneMarker().FindBySceneID(scene.ID) + sceneMarkers, err = rs.sceneMarkerFinder.FindBySceneID(ctx, scene.ID) return err }); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -289,9 +309,9 @@ func (rs sceneRoutes) InteractiveHeatmap(w http.ResponseWriter, r *http.Request) func (rs sceneRoutes) Caption(w http.ResponseWriter, r *http.Request, lang string, ext string) { s := r.Context().Value(sceneKey).(*models.Scene) - if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { var err error - captions, err := repo.Scene().GetCaptions(s.ID) + captions, err := rs.sceneFinder.GetCaptions(ctx, s.ID) for _, caption := range captions { if lang == caption.LanguageCode && ext == caption.CaptionType { sub, err := scene.ReadSubs(caption.Path(s.Path)) @@ -344,9 +364,9 @@ func (rs sceneRoutes) SceneMarkerStream(w http.ResponseWriter, r *http.Request) scene := r.Context().Value(sceneKey).(*models.Scene) sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId")) var sceneMarker *models.SceneMarker - if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { var err error - sceneMarker, err = repo.SceneMarker().Find(sceneMarkerID) + sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID) return err }); err != nil { logger.Warnf("Error when getting scene marker for stream: %s", err.Error()) @@ -367,9 +387,9 @@ func (rs sceneRoutes) SceneMarkerPreview(w http.ResponseWriter, r *http.Request) scene := r.Context().Value(sceneKey).(*models.Scene) sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId")) var sceneMarker *models.SceneMarker - if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { var err error - sceneMarker, err = repo.SceneMarker().Find(sceneMarkerID) + sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID) return err }); err != nil { logger.Warnf("Error when getting scene marker for stream: %s", err.Error()) @@ -400,9 +420,9 @@ func (rs sceneRoutes) SceneMarkerScreenshot(w http.ResponseWriter, r *http.Reque scene := r.Context().Value(sceneKey).(*models.Scene) sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId")) var sceneMarker *models.SceneMarker - if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { var err error - sceneMarker, err = repo.SceneMarker().Find(sceneMarkerID) + sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID) return err }); err != nil { logger.Warnf("Error when getting scene marker for stream: %s", err.Error()) @@ -431,23 +451,23 @@ func (rs sceneRoutes) SceneMarkerScreenshot(w http.ResponseWriter, r *http.Reque // endregion -func SceneCtx(next http.Handler) http.Handler { +func (rs sceneRoutes) SceneCtx(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { sceneIdentifierQueryParam := chi.URLParam(r, "sceneId") sceneID, _ := strconv.Atoi(sceneIdentifierQueryParam) var scene *models.Scene - readTxnErr := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { - qb := repo.Scene() + readTxnErr := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + qb := rs.sceneFinder if sceneID == 0 { // determine checksum/os by the length of the query param if len(sceneIdentifierQueryParam) == 32 { - scene, _ = qb.FindByChecksum(sceneIdentifierQueryParam) + scene, _ = qb.FindByChecksum(ctx, sceneIdentifierQueryParam) } else { - scene, _ = qb.FindByOSHash(sceneIdentifierQueryParam) + scene, _ = qb.FindByOSHash(ctx, sceneIdentifierQueryParam) } } else { - scene, _ = qb.Find(sceneID) + scene, _ = qb.Find(ctx, sceneID) } return nil diff --git a/internal/api/routes_studio.go b/internal/api/routes_studio.go index 18f78b30c..e26499f04 100644 --- a/internal/api/routes_studio.go +++ b/internal/api/routes_studio.go @@ -8,21 +8,28 @@ import ( "syscall" "github.com/go-chi/chi" - "github.com/stashapp/stash/internal/manager" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/studio" + "github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/utils" ) +type StudioFinder interface { + studio.Finder + GetImage(ctx context.Context, studioID int) ([]byte, error) +} + type studioRoutes struct { - txnManager models.TransactionManager + txnManager txn.Manager + studioFinder StudioFinder } func (rs studioRoutes) Routes() chi.Router { r := chi.NewRouter() r.Route("/{studioId}", func(r chi.Router) { - r.Use(StudioCtx) + r.Use(rs.StudioCtx) r.Get("/image", rs.Image) }) @@ -35,8 +42,8 @@ func (rs studioRoutes) Image(w http.ResponseWriter, r *http.Request) { var image []byte if defaultParam != "true" { - err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { - image, _ = repo.Studio().GetImage(studio.ID) + err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + image, _ = rs.studioFinder.GetImage(ctx, studio.ID) return nil }) if err != nil { @@ -58,7 +65,7 @@ func (rs studioRoutes) Image(w http.ResponseWriter, r *http.Request) { } } -func StudioCtx(next http.Handler) http.Handler { +func (rs studioRoutes) StudioCtx(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { studioID, err := strconv.Atoi(chi.URLParam(r, "studioId")) if err != nil { @@ -67,9 +74,9 @@ func StudioCtx(next http.Handler) http.Handler { } var studio *models.Studio - if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { var err error - studio, err = repo.Studio().Find(studioID) + studio, err = rs.studioFinder.Find(ctx, studioID) return err }); err != nil { http.Error(w, http.StatusText(404), 404) diff --git a/internal/api/routes_tag.go b/internal/api/routes_tag.go index 8ffdc62c9..69c573cb2 100644 --- a/internal/api/routes_tag.go +++ b/internal/api/routes_tag.go @@ -6,21 +6,28 @@ import ( "strconv" "github.com/go-chi/chi" - "github.com/stashapp/stash/internal/manager" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/tag" + "github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/utils" ) +type TagFinder interface { + tag.Finder + GetImage(ctx context.Context, tagID int) ([]byte, error) +} + type tagRoutes struct { - txnManager models.TransactionManager + txnManager txn.Manager + tagFinder TagFinder } func (rs tagRoutes) Routes() chi.Router { r := chi.NewRouter() r.Route("/{tagId}", func(r chi.Router) { - r.Use(TagCtx) + r.Use(rs.TagCtx) r.Get("/image", rs.Image) }) @@ -33,8 +40,8 @@ func (rs tagRoutes) Image(w http.ResponseWriter, r *http.Request) { var image []byte if defaultParam != "true" { - err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { - image, _ = repo.Tag().GetImage(tag.ID) + err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { + image, _ = rs.tagFinder.GetImage(ctx, tag.ID) return nil }) if err != nil { @@ -51,7 +58,7 @@ func (rs tagRoutes) Image(w http.ResponseWriter, r *http.Request) { } } -func TagCtx(next http.Handler) http.Handler { +func (rs tagRoutes) TagCtx(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tagID, err := strconv.Atoi(chi.URLParam(r, "tagId")) if err != nil { @@ -60,9 +67,9 @@ func TagCtx(next http.Handler) http.Handler { } var tag *models.Tag - if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { var err error - tag, err = repo.Tag().Find(tagID) + tag, err = rs.tagFinder.Find(ctx, tagID) return err }); err != nil { http.Error(w, http.StatusText(404), 404) diff --git a/internal/api/server.go b/internal/api/server.go index 95ce7f775..48bf87764 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -73,10 +73,11 @@ func Start() error { return errors.New(message) } - txnManager := manager.GetInstance().TxnManager + txnManager := manager.GetInstance().Repository pluginCache := manager.GetInstance().PluginCache resolver := &Resolver{ txnManager: txnManager, + repository: txnManager, hookExecutor: pluginCache, } @@ -118,22 +119,30 @@ func Start() error { r.Get(loginEndPoint, getLoginHandler(loginUIBox)) r.Mount("/performer", performerRoutes{ - txnManager: txnManager, + txnManager: txnManager, + performerFinder: txnManager.Performer, }.Routes()) r.Mount("/scene", sceneRoutes{ - txnManager: txnManager, + txnManager: txnManager, + sceneFinder: txnManager.Scene, + sceneMarkerFinder: txnManager.SceneMarker, + tagFinder: txnManager.Tag, }.Routes()) r.Mount("/image", imageRoutes{ - txnManager: txnManager, + txnManager: txnManager, + imageFinder: txnManager.Image, }.Routes()) r.Mount("/studio", studioRoutes{ - txnManager: txnManager, + txnManager: txnManager, + studioFinder: txnManager.Studio, }.Routes()) r.Mount("/movie", movieRoutes{ - txnManager: txnManager, + txnManager: txnManager, + movieFinder: txnManager.Movie, }.Routes()) r.Mount("/tag", tagRoutes{ txnManager: txnManager, + tagFinder: txnManager.Tag, }.Routes()) r.Mount("/downloads", downloadsRoutes{}.Routes()) diff --git a/internal/autotag/gallery.go b/internal/autotag/gallery.go index 3bdfd3c15..7f90c7e76 100644 --- a/internal/autotag/gallery.go +++ b/internal/autotag/gallery.go @@ -1,6 +1,8 @@ package autotag import ( + "context" + "github.com/stashapp/stash/pkg/gallery" "github.com/stashapp/stash/pkg/match" "github.com/stashapp/stash/pkg/models" @@ -21,18 +23,18 @@ func getGalleryFileTagger(s *models.Gallery, cache *match.Cache) tagger { } // GalleryPerformers tags the provided gallery with performers whose name matches the gallery's path. -func GalleryPerformers(s *models.Gallery, rw models.GalleryReaderWriter, performerReader models.PerformerReader, cache *match.Cache) error { +func GalleryPerformers(ctx context.Context, s *models.Gallery, rw gallery.PerformerUpdater, performerReader match.PerformerAutoTagQueryer, cache *match.Cache) error { t := getGalleryFileTagger(s, cache) - return t.tagPerformers(performerReader, func(subjectID, otherID int) (bool, error) { - return gallery.AddPerformer(rw, subjectID, otherID) + return t.tagPerformers(ctx, performerReader, func(subjectID, otherID int) (bool, error) { + return gallery.AddPerformer(ctx, rw, subjectID, otherID) }) } // GalleryStudios tags the provided gallery with the first studio whose name matches the gallery's path. // // Gallerys will not be tagged if studio is already set. -func GalleryStudios(s *models.Gallery, rw models.GalleryReaderWriter, studioReader models.StudioReader, cache *match.Cache) error { +func GalleryStudios(ctx context.Context, s *models.Gallery, rw GalleryFinderUpdater, studioReader match.StudioAutoTagQueryer, cache *match.Cache) error { if s.StudioID.Valid { // don't modify return nil @@ -40,16 +42,16 @@ func GalleryStudios(s *models.Gallery, rw models.GalleryReaderWriter, studioRead t := getGalleryFileTagger(s, cache) - return t.tagStudios(studioReader, func(subjectID, otherID int) (bool, error) { - return addGalleryStudio(rw, subjectID, otherID) + return t.tagStudios(ctx, studioReader, func(subjectID, otherID int) (bool, error) { + return addGalleryStudio(ctx, rw, subjectID, otherID) }) } // GalleryTags tags the provided gallery with tags whose name matches the gallery's path. -func GalleryTags(s *models.Gallery, rw models.GalleryReaderWriter, tagReader models.TagReader, cache *match.Cache) error { +func GalleryTags(ctx context.Context, s *models.Gallery, rw gallery.TagUpdater, tagReader match.TagAutoTagQueryer, cache *match.Cache) error { t := getGalleryFileTagger(s, cache) - return t.tagTags(tagReader, func(subjectID, otherID int) (bool, error) { - return gallery.AddTag(rw, subjectID, otherID) + return t.tagTags(ctx, tagReader, func(subjectID, otherID int) (bool, error) { + return gallery.AddTag(ctx, rw, subjectID, otherID) }) } diff --git a/internal/autotag/gallery_test.go b/internal/autotag/gallery_test.go index 6d744400a..a50dc8ac4 100644 --- a/internal/autotag/gallery_test.go +++ b/internal/autotag/gallery_test.go @@ -1,6 +1,7 @@ package autotag import ( + "context" "testing" "github.com/stashapp/stash/pkg/models" @@ -11,6 +12,8 @@ import ( const galleryExt = "zip" +var testCtx = context.Background() + func TestGalleryPerformers(t *testing.T) { t.Parallel() @@ -37,19 +40,19 @@ func TestGalleryPerformers(t *testing.T) { mockPerformerReader := &mocks.PerformerReaderWriter{} mockGalleryReader := &mocks.GalleryReaderWriter{} - mockPerformerReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) - mockPerformerReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once() + mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once() if test.Matches { - mockGalleryReader.On("GetPerformerIDs", galleryID).Return(nil, nil).Once() - mockGalleryReader.On("UpdatePerformers", galleryID, []int{performerID}).Return(nil).Once() + mockGalleryReader.On("GetPerformerIDs", testCtx, galleryID).Return(nil, nil).Once() + mockGalleryReader.On("UpdatePerformers", testCtx, galleryID, []int{performerID}).Return(nil).Once() } gallery := models.Gallery{ ID: galleryID, Path: models.NullString(test.Path), } - err := GalleryPerformers(&gallery, mockGalleryReader, mockPerformerReader, nil) + err := GalleryPerformers(testCtx, &gallery, mockGalleryReader, mockPerformerReader, nil) assert.Nil(err) mockPerformerReader.AssertExpectations(t) @@ -81,9 +84,9 @@ func TestGalleryStudios(t *testing.T) { doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockGalleryReader *mocks.GalleryReaderWriter, test pathTestTable) { if test.Matches { - mockGalleryReader.On("Find", galleryID).Return(&models.Gallery{}, nil).Once() + mockGalleryReader.On("Find", testCtx, galleryID).Return(&models.Gallery{}, nil).Once() expectedStudioID := models.NullInt64(studioID) - mockGalleryReader.On("UpdatePartial", models.GalleryPartial{ + mockGalleryReader.On("UpdatePartial", testCtx, models.GalleryPartial{ ID: galleryID, StudioID: &expectedStudioID, }).Return(nil, nil).Once() @@ -93,7 +96,7 @@ func TestGalleryStudios(t *testing.T) { ID: galleryID, Path: models.NullString(test.Path), } - err := GalleryStudios(&gallery, mockGalleryReader, mockStudioReader, nil) + err := GalleryStudios(testCtx, &gallery, mockGalleryReader, mockStudioReader, nil) assert.Nil(err) mockStudioReader.AssertExpectations(t) @@ -104,9 +107,9 @@ func TestGalleryStudios(t *testing.T) { mockStudioReader := &mocks.StudioReaderWriter{} mockGalleryReader := &mocks.GalleryReaderWriter{} - mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) - mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() - mockStudioReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe() + mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() + mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() doTest(mockStudioReader, mockGalleryReader, test) } @@ -119,12 +122,12 @@ func TestGalleryStudios(t *testing.T) { mockStudioReader := &mocks.StudioReaderWriter{} mockGalleryReader := &mocks.GalleryReaderWriter{} - mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) - mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() - mockStudioReader.On("GetAliases", studioID).Return([]string{ + mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() + mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{ studioName, }, nil).Once() - mockStudioReader.On("GetAliases", reversedStudioID).Return([]string{}, nil).Once() + mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once() doTest(mockStudioReader, mockGalleryReader, test) } @@ -154,15 +157,15 @@ func TestGalleryTags(t *testing.T) { doTest := func(mockTagReader *mocks.TagReaderWriter, mockGalleryReader *mocks.GalleryReaderWriter, test pathTestTable) { if test.Matches { - mockGalleryReader.On("GetTagIDs", galleryID).Return(nil, nil).Once() - mockGalleryReader.On("UpdateTags", galleryID, []int{tagID}).Return(nil).Once() + mockGalleryReader.On("GetTagIDs", testCtx, galleryID).Return(nil, nil).Once() + mockGalleryReader.On("UpdateTags", testCtx, galleryID, []int{tagID}).Return(nil).Once() } gallery := models.Gallery{ ID: galleryID, Path: models.NullString(test.Path), } - err := GalleryTags(&gallery, mockGalleryReader, mockTagReader, nil) + err := GalleryTags(testCtx, &gallery, mockGalleryReader, mockTagReader, nil) assert.Nil(err) mockTagReader.AssertExpectations(t) @@ -173,9 +176,9 @@ func TestGalleryTags(t *testing.T) { mockTagReader := &mocks.TagReaderWriter{} mockGalleryReader := &mocks.GalleryReaderWriter{} - mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) - mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() - mockTagReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe() + mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() + mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() doTest(mockTagReader, mockGalleryReader, test) } @@ -187,12 +190,12 @@ func TestGalleryTags(t *testing.T) { mockTagReader := &mocks.TagReaderWriter{} mockGalleryReader := &mocks.GalleryReaderWriter{} - mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) - mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() - mockTagReader.On("GetAliases", tagID).Return([]string{ + mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() + mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{ tagName, }, nil).Once() - mockTagReader.On("GetAliases", reversedTagID).Return([]string{}, nil).Once() + mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once() doTest(mockTagReader, mockGalleryReader, test) } diff --git a/internal/autotag/image.go b/internal/autotag/image.go index 516f30181..17d0d1816 100644 --- a/internal/autotag/image.go +++ b/internal/autotag/image.go @@ -1,6 +1,8 @@ package autotag import ( + "context" + "github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/match" "github.com/stashapp/stash/pkg/models" @@ -17,18 +19,18 @@ func getImageFileTagger(s *models.Image, cache *match.Cache) tagger { } // ImagePerformers tags the provided image with performers whose name matches the image's path. -func ImagePerformers(s *models.Image, rw models.ImageReaderWriter, performerReader models.PerformerReader, cache *match.Cache) error { +func ImagePerformers(ctx context.Context, s *models.Image, rw image.PerformerUpdater, performerReader match.PerformerAutoTagQueryer, cache *match.Cache) error { t := getImageFileTagger(s, cache) - return t.tagPerformers(performerReader, func(subjectID, otherID int) (bool, error) { - return image.AddPerformer(rw, subjectID, otherID) + return t.tagPerformers(ctx, performerReader, func(subjectID, otherID int) (bool, error) { + return image.AddPerformer(ctx, rw, subjectID, otherID) }) } // ImageStudios tags the provided image with the first studio whose name matches the image's path. // // Images will not be tagged if studio is already set. -func ImageStudios(s *models.Image, rw models.ImageReaderWriter, studioReader models.StudioReader, cache *match.Cache) error { +func ImageStudios(ctx context.Context, s *models.Image, rw ImageFinderUpdater, studioReader match.StudioAutoTagQueryer, cache *match.Cache) error { if s.StudioID.Valid { // don't modify return nil @@ -36,16 +38,16 @@ func ImageStudios(s *models.Image, rw models.ImageReaderWriter, studioReader mod t := getImageFileTagger(s, cache) - return t.tagStudios(studioReader, func(subjectID, otherID int) (bool, error) { - return addImageStudio(rw, subjectID, otherID) + return t.tagStudios(ctx, studioReader, func(subjectID, otherID int) (bool, error) { + return addImageStudio(ctx, rw, subjectID, otherID) }) } // ImageTags tags the provided image with tags whose name matches the image's path. -func ImageTags(s *models.Image, rw models.ImageReaderWriter, tagReader models.TagReader, cache *match.Cache) error { +func ImageTags(ctx context.Context, s *models.Image, rw image.TagUpdater, tagReader match.TagAutoTagQueryer, cache *match.Cache) error { t := getImageFileTagger(s, cache) - return t.tagTags(tagReader, func(subjectID, otherID int) (bool, error) { - return image.AddTag(rw, subjectID, otherID) + return t.tagTags(ctx, tagReader, func(subjectID, otherID int) (bool, error) { + return image.AddTag(ctx, rw, subjectID, otherID) }) } diff --git a/internal/autotag/image_test.go b/internal/autotag/image_test.go index 130ce51af..67eedb689 100644 --- a/internal/autotag/image_test.go +++ b/internal/autotag/image_test.go @@ -37,19 +37,19 @@ func TestImagePerformers(t *testing.T) { mockPerformerReader := &mocks.PerformerReaderWriter{} mockImageReader := &mocks.ImageReaderWriter{} - mockPerformerReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) - mockPerformerReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once() + mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once() if test.Matches { - mockImageReader.On("GetPerformerIDs", imageID).Return(nil, nil).Once() - mockImageReader.On("UpdatePerformers", imageID, []int{performerID}).Return(nil).Once() + mockImageReader.On("GetPerformerIDs", testCtx, imageID).Return(nil, nil).Once() + mockImageReader.On("UpdatePerformers", testCtx, imageID, []int{performerID}).Return(nil).Once() } image := models.Image{ ID: imageID, Path: test.Path, } - err := ImagePerformers(&image, mockImageReader, mockPerformerReader, nil) + err := ImagePerformers(testCtx, &image, mockImageReader, mockPerformerReader, nil) assert.Nil(err) mockPerformerReader.AssertExpectations(t) @@ -81,9 +81,9 @@ func TestImageStudios(t *testing.T) { doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockImageReader *mocks.ImageReaderWriter, test pathTestTable) { if test.Matches { - mockImageReader.On("Find", imageID).Return(&models.Image{}, nil).Once() + mockImageReader.On("Find", testCtx, imageID).Return(&models.Image{}, nil).Once() expectedStudioID := models.NullInt64(studioID) - mockImageReader.On("Update", models.ImagePartial{ + mockImageReader.On("Update", testCtx, models.ImagePartial{ ID: imageID, StudioID: &expectedStudioID, }).Return(nil, nil).Once() @@ -93,7 +93,7 @@ func TestImageStudios(t *testing.T) { ID: imageID, Path: test.Path, } - err := ImageStudios(&image, mockImageReader, mockStudioReader, nil) + err := ImageStudios(testCtx, &image, mockImageReader, mockStudioReader, nil) assert.Nil(err) mockStudioReader.AssertExpectations(t) @@ -104,9 +104,9 @@ func TestImageStudios(t *testing.T) { mockStudioReader := &mocks.StudioReaderWriter{} mockImageReader := &mocks.ImageReaderWriter{} - mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) - mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() - mockStudioReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe() + mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() + mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() doTest(mockStudioReader, mockImageReader, test) } @@ -119,12 +119,12 @@ func TestImageStudios(t *testing.T) { mockStudioReader := &mocks.StudioReaderWriter{} mockImageReader := &mocks.ImageReaderWriter{} - mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) - mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() - mockStudioReader.On("GetAliases", studioID).Return([]string{ + mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() + mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{ studioName, }, nil).Once() - mockStudioReader.On("GetAliases", reversedStudioID).Return([]string{}, nil).Once() + mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once() doTest(mockStudioReader, mockImageReader, test) } @@ -154,15 +154,15 @@ func TestImageTags(t *testing.T) { doTest := func(mockTagReader *mocks.TagReaderWriter, mockImageReader *mocks.ImageReaderWriter, test pathTestTable) { if test.Matches { - mockImageReader.On("GetTagIDs", imageID).Return(nil, nil).Once() - mockImageReader.On("UpdateTags", imageID, []int{tagID}).Return(nil).Once() + mockImageReader.On("GetTagIDs", testCtx, imageID).Return(nil, nil).Once() + mockImageReader.On("UpdateTags", testCtx, imageID, []int{tagID}).Return(nil).Once() } image := models.Image{ ID: imageID, Path: test.Path, } - err := ImageTags(&image, mockImageReader, mockTagReader, nil) + err := ImageTags(testCtx, &image, mockImageReader, mockTagReader, nil) assert.Nil(err) mockTagReader.AssertExpectations(t) @@ -173,9 +173,9 @@ func TestImageTags(t *testing.T) { mockTagReader := &mocks.TagReaderWriter{} mockImageReader := &mocks.ImageReaderWriter{} - mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) - mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() - mockTagReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe() + mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() + mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() doTest(mockTagReader, mockImageReader, test) } @@ -188,12 +188,12 @@ func TestImageTags(t *testing.T) { mockTagReader := &mocks.TagReaderWriter{} mockImageReader := &mocks.ImageReaderWriter{} - mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) - mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() - mockTagReader.On("GetAliases", tagID).Return([]string{ + mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() + mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{ tagName, }, nil).Once() - mockTagReader.On("GetAliases", reversedTagID).Return([]string{}, nil).Once() + mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once() doTest(mockTagReader, mockImageReader, test) } diff --git a/internal/autotag/integration_test.go b/internal/autotag/integration_test.go index 9ca176d4b..5465d20c8 100644 --- a/internal/autotag/integration_test.go +++ b/internal/autotag/integration_test.go @@ -10,10 +10,10 @@ import ( "os" "testing" - "github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/hash/md5" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/sqlite" + "github.com/stashapp/stash/pkg/txn" _ "github.com/golang-migrate/migrate/v4/database/sqlite3" _ "github.com/golang-migrate/migrate/v4/source/file" @@ -28,8 +28,11 @@ const existingStudioGalleryName = testName + ".dontChangeStudio.mp4" var existingStudioID int +var db *sqlite.Database +var r models.Repository + func testTeardown(databaseFile string) { - err := database.DB.Close() + err := db.Close() if err != nil { panic(err) @@ -50,10 +53,13 @@ func runTests(m *testing.M) int { f.Close() databaseFile := f.Name() - if err := database.Initialize(databaseFile); err != nil { + db = &sqlite.Database{} + if err := db.Open(databaseFile); err != nil { panic(fmt.Sprintf("Could not initialize database: %s", err.Error())) } + r = db.TxnRepository() + // defer close and delete the database defer testTeardown(databaseFile) @@ -71,7 +77,7 @@ func TestMain(m *testing.M) { os.Exit(ret) } -func createPerformer(pqb models.PerformerWriter) error { +func createPerformer(ctx context.Context, pqb models.PerformerWriter) error { // create the performer performer := models.Performer{ Checksum: testName, @@ -79,7 +85,7 @@ func createPerformer(pqb models.PerformerWriter) error { Favorite: sql.NullBool{Valid: true, Bool: false}, } - _, err := pqb.Create(performer) + _, err := pqb.Create(ctx, performer) if err != nil { return err } @@ -87,23 +93,23 @@ func createPerformer(pqb models.PerformerWriter) error { return nil } -func createStudio(qb models.StudioWriter, name string) (*models.Studio, error) { +func createStudio(ctx context.Context, qb models.StudioWriter, name string) (*models.Studio, error) { // create the studio studio := models.Studio{ Checksum: name, Name: sql.NullString{Valid: true, String: name}, } - return qb.Create(studio) + return qb.Create(ctx, studio) } -func createTag(qb models.TagWriter) error { +func createTag(ctx context.Context, qb models.TagWriter) error { // create the studio tag := models.Tag{ Name: testName, } - _, err := qb.Create(tag) + _, err := qb.Create(ctx, tag) if err != nil { return err } @@ -111,18 +117,18 @@ func createTag(qb models.TagWriter) error { return nil } -func createScenes(sqb models.SceneReaderWriter) error { +func createScenes(ctx context.Context, sqb models.SceneReaderWriter) error { // create the scenes scenePatterns, falseScenePatterns := generateTestPaths(testName, sceneExt) for _, fn := range scenePatterns { - err := createScene(sqb, makeScene(fn, true)) + err := createScene(ctx, sqb, makeScene(fn, true)) if err != nil { return err } } for _, fn := range falseScenePatterns { - err := createScene(sqb, makeScene(fn, false)) + err := createScene(ctx, sqb, makeScene(fn, false)) if err != nil { return err } @@ -132,7 +138,7 @@ func createScenes(sqb models.SceneReaderWriter) error { for _, fn := range scenePatterns { s := makeScene("organized"+fn, false) s.Organized = true - err := createScene(sqb, s) + err := createScene(ctx, sqb, s) if err != nil { return err } @@ -141,7 +147,7 @@ func createScenes(sqb models.SceneReaderWriter) error { // create scene with existing studio io studioScene := makeScene(existingStudioSceneName, true) studioScene.StudioID = sql.NullInt64{Valid: true, Int64: int64(existingStudioID)} - err := createScene(sqb, studioScene) + err := createScene(ctx, sqb, studioScene) if err != nil { return err } @@ -163,8 +169,8 @@ func makeScene(name string, expectedResult bool) *models.Scene { return scene } -func createScene(sqb models.SceneWriter, scene *models.Scene) error { - _, err := sqb.Create(*scene) +func createScene(ctx context.Context, sqb models.SceneWriter, scene *models.Scene) error { + _, err := sqb.Create(ctx, *scene) if err != nil { return fmt.Errorf("Failed to create scene with name '%s': %s", scene.Path, err.Error()) @@ -173,18 +179,18 @@ func createScene(sqb models.SceneWriter, scene *models.Scene) error { return nil } -func createImages(sqb models.ImageReaderWriter) error { +func createImages(ctx context.Context, sqb models.ImageReaderWriter) error { // create the images imagePatterns, falseImagePatterns := generateTestPaths(testName, imageExt) for _, fn := range imagePatterns { - err := createImage(sqb, makeImage(fn, true)) + err := createImage(ctx, sqb, makeImage(fn, true)) if err != nil { return err } } for _, fn := range falseImagePatterns { - err := createImage(sqb, makeImage(fn, false)) + err := createImage(ctx, sqb, makeImage(fn, false)) if err != nil { return err } @@ -194,7 +200,7 @@ func createImages(sqb models.ImageReaderWriter) error { for _, fn := range imagePatterns { s := makeImage("organized"+fn, false) s.Organized = true - err := createImage(sqb, s) + err := createImage(ctx, sqb, s) if err != nil { return err } @@ -203,7 +209,7 @@ func createImages(sqb models.ImageReaderWriter) error { // create image with existing studio io studioImage := makeImage(existingStudioImageName, true) studioImage.StudioID = sql.NullInt64{Valid: true, Int64: int64(existingStudioID)} - err := createImage(sqb, studioImage) + err := createImage(ctx, sqb, studioImage) if err != nil { return err } @@ -225,8 +231,8 @@ func makeImage(name string, expectedResult bool) *models.Image { return image } -func createImage(sqb models.ImageWriter, image *models.Image) error { - _, err := sqb.Create(*image) +func createImage(ctx context.Context, sqb models.ImageWriter, image *models.Image) error { + _, err := sqb.Create(ctx, *image) if err != nil { return fmt.Errorf("Failed to create image with name '%s': %s", image.Path, err.Error()) @@ -235,18 +241,18 @@ func createImage(sqb models.ImageWriter, image *models.Image) error { return nil } -func createGalleries(sqb models.GalleryReaderWriter) error { +func createGalleries(ctx context.Context, sqb models.GalleryReaderWriter) error { // create the galleries galleryPatterns, falseGalleryPatterns := generateTestPaths(testName, galleryExt) for _, fn := range galleryPatterns { - err := createGallery(sqb, makeGallery(fn, true)) + err := createGallery(ctx, sqb, makeGallery(fn, true)) if err != nil { return err } } for _, fn := range falseGalleryPatterns { - err := createGallery(sqb, makeGallery(fn, false)) + err := createGallery(ctx, sqb, makeGallery(fn, false)) if err != nil { return err } @@ -256,7 +262,7 @@ func createGalleries(sqb models.GalleryReaderWriter) error { for _, fn := range galleryPatterns { s := makeGallery("organized"+fn, false) s.Organized = true - err := createGallery(sqb, s) + err := createGallery(ctx, sqb, s) if err != nil { return err } @@ -265,7 +271,7 @@ func createGalleries(sqb models.GalleryReaderWriter) error { // create gallery with existing studio io studioGallery := makeGallery(existingStudioGalleryName, true) studioGallery.StudioID = sql.NullInt64{Valid: true, Int64: int64(existingStudioID)} - err := createGallery(sqb, studioGallery) + err := createGallery(ctx, sqb, studioGallery) if err != nil { return err } @@ -287,8 +293,8 @@ func makeGallery(name string, expectedResult bool) *models.Gallery { return gallery } -func createGallery(sqb models.GalleryWriter, gallery *models.Gallery) error { - _, err := sqb.Create(*gallery) +func createGallery(ctx context.Context, sqb models.GalleryWriter, gallery *models.Gallery) error { + _, err := sqb.Create(ctx, *gallery) if err != nil { return fmt.Errorf("Failed to create gallery with name '%s': %s", gallery.Path.String, err.Error()) @@ -297,47 +303,46 @@ func createGallery(sqb models.GalleryWriter, gallery *models.Gallery) error { return nil } -func withTxn(f func(r models.Repository) error) error { - t := sqlite.NewTransactionManager() - return t.WithTxn(context.TODO(), f) +func withTxn(f func(ctx context.Context) error) error { + return txn.WithTxn(context.TODO(), db, f) } func populateDB() error { - if err := withTxn(func(r models.Repository) error { - err := createPerformer(r.Performer()) + if err := withTxn(func(ctx context.Context) error { + err := createPerformer(ctx, r.Performer) if err != nil { return err } - _, err = createStudio(r.Studio(), testName) + _, err = createStudio(ctx, r.Studio, testName) if err != nil { return err } // create existing studio - existingStudio, err := createStudio(r.Studio(), existingStudioName) + existingStudio, err := createStudio(ctx, r.Studio, existingStudioName) if err != nil { return err } existingStudioID = existingStudio.ID - err = createTag(r.Tag()) + err = createTag(ctx, r.Tag) if err != nil { return err } - err = createScenes(r.Scene()) + err = createScenes(ctx, r.Scene) if err != nil { return err } - err = createImages(r.Image()) + err = createImages(ctx, r.Image) if err != nil { return err } - err = createGalleries(r.Gallery()) + err = createGalleries(ctx, r.Gallery) if err != nil { return err } @@ -352,9 +357,9 @@ func populateDB() error { func TestParsePerformerScenes(t *testing.T) { var performers []*models.Performer - if err := withTxn(func(r models.Repository) error { + if err := withTxn(func(ctx context.Context) error { var err error - performers, err = r.Performer().All() + performers, err = r.Performer.All(ctx) return err }); err != nil { t.Errorf("Error getting performer: %s", err) @@ -362,24 +367,24 @@ func TestParsePerformerScenes(t *testing.T) { } for _, p := range performers { - if err := withTxn(func(r models.Repository) error { - return PerformerScenes(p, nil, r.Scene(), nil) + if err := withTxn(func(ctx context.Context) error { + return PerformerScenes(ctx, p, nil, r.Scene, nil) }); err != nil { t.Errorf("Error auto-tagging performers: %s", err) } } // verify that scenes were tagged correctly - withTxn(func(r models.Repository) error { - pqb := r.Performer() + withTxn(func(ctx context.Context) error { + pqb := r.Performer - scenes, err := r.Scene().All() + scenes, err := r.Scene.All(ctx) if err != nil { t.Error(err.Error()) } for _, scene := range scenes { - performers, err := pqb.FindBySceneID(scene.ID) + performers, err := pqb.FindBySceneID(ctx, scene.ID) if err != nil { t.Errorf("Error getting scene performers: %s", err.Error()) @@ -399,9 +404,9 @@ func TestParsePerformerScenes(t *testing.T) { func TestParseStudioScenes(t *testing.T) { var studios []*models.Studio - if err := withTxn(func(r models.Repository) error { + if err := withTxn(func(ctx context.Context) error { var err error - studios, err = r.Studio().All() + studios, err = r.Studio.All(ctx) return err }); err != nil { t.Errorf("Error getting studio: %s", err) @@ -409,21 +414,21 @@ func TestParseStudioScenes(t *testing.T) { } for _, s := range studios { - if err := withTxn(func(r models.Repository) error { - aliases, err := r.Studio().GetAliases(s.ID) + if err := withTxn(func(ctx context.Context) error { + aliases, err := r.Studio.GetAliases(ctx, s.ID) if err != nil { return err } - return StudioScenes(s, nil, aliases, r.Scene(), nil) + return StudioScenes(ctx, s, nil, aliases, r.Scene, nil) }); err != nil { t.Errorf("Error auto-tagging performers: %s", err) } } // verify that scenes were tagged correctly - withTxn(func(r models.Repository) error { - scenes, err := r.Scene().All() + withTxn(func(ctx context.Context) error { + scenes, err := r.Scene.All(ctx) if err != nil { t.Error(err.Error()) } @@ -455,9 +460,9 @@ func TestParseStudioScenes(t *testing.T) { func TestParseTagScenes(t *testing.T) { var tags []*models.Tag - if err := withTxn(func(r models.Repository) error { + if err := withTxn(func(ctx context.Context) error { var err error - tags, err = r.Tag().All() + tags, err = r.Tag.All(ctx) return err }); err != nil { t.Errorf("Error getting performer: %s", err) @@ -465,29 +470,29 @@ func TestParseTagScenes(t *testing.T) { } for _, s := range tags { - if err := withTxn(func(r models.Repository) error { - aliases, err := r.Tag().GetAliases(s.ID) + if err := withTxn(func(ctx context.Context) error { + aliases, err := r.Tag.GetAliases(ctx, s.ID) if err != nil { return err } - return TagScenes(s, nil, aliases, r.Scene(), nil) + return TagScenes(ctx, s, nil, aliases, r.Scene, nil) }); err != nil { t.Errorf("Error auto-tagging performers: %s", err) } } // verify that scenes were tagged correctly - withTxn(func(r models.Repository) error { - scenes, err := r.Scene().All() + withTxn(func(ctx context.Context) error { + scenes, err := r.Scene.All(ctx) if err != nil { t.Error(err.Error()) } - tqb := r.Tag() + tqb := r.Tag for _, scene := range scenes { - tags, err := tqb.FindBySceneID(scene.ID) + tags, err := tqb.FindBySceneID(ctx, scene.ID) if err != nil { t.Errorf("Error getting scene tags: %s", err.Error()) @@ -507,9 +512,9 @@ func TestParseTagScenes(t *testing.T) { func TestParsePerformerImages(t *testing.T) { var performers []*models.Performer - if err := withTxn(func(r models.Repository) error { + if err := withTxn(func(ctx context.Context) error { var err error - performers, err = r.Performer().All() + performers, err = r.Performer.All(ctx) return err }); err != nil { t.Errorf("Error getting performer: %s", err) @@ -517,24 +522,24 @@ func TestParsePerformerImages(t *testing.T) { } for _, p := range performers { - if err := withTxn(func(r models.Repository) error { - return PerformerImages(p, nil, r.Image(), nil) + if err := withTxn(func(ctx context.Context) error { + return PerformerImages(ctx, p, nil, r.Image, nil) }); err != nil { t.Errorf("Error auto-tagging performers: %s", err) } } // verify that images were tagged correctly - withTxn(func(r models.Repository) error { - pqb := r.Performer() + withTxn(func(ctx context.Context) error { + pqb := r.Performer - images, err := r.Image().All() + images, err := r.Image.All(ctx) if err != nil { t.Error(err.Error()) } for _, image := range images { - performers, err := pqb.FindByImageID(image.ID) + performers, err := pqb.FindByImageID(ctx, image.ID) if err != nil { t.Errorf("Error getting image performers: %s", err.Error()) @@ -554,9 +559,9 @@ func TestParsePerformerImages(t *testing.T) { func TestParseStudioImages(t *testing.T) { var studios []*models.Studio - if err := withTxn(func(r models.Repository) error { + if err := withTxn(func(ctx context.Context) error { var err error - studios, err = r.Studio().All() + studios, err = r.Studio.All(ctx) return err }); err != nil { t.Errorf("Error getting studio: %s", err) @@ -564,21 +569,21 @@ func TestParseStudioImages(t *testing.T) { } for _, s := range studios { - if err := withTxn(func(r models.Repository) error { - aliases, err := r.Studio().GetAliases(s.ID) + if err := withTxn(func(ctx context.Context) error { + aliases, err := r.Studio.GetAliases(ctx, s.ID) if err != nil { return err } - return StudioImages(s, nil, aliases, r.Image(), nil) + return StudioImages(ctx, s, nil, aliases, r.Image, nil) }); err != nil { t.Errorf("Error auto-tagging performers: %s", err) } } // verify that images were tagged correctly - withTxn(func(r models.Repository) error { - images, err := r.Image().All() + withTxn(func(ctx context.Context) error { + images, err := r.Image.All(ctx) if err != nil { t.Error(err.Error()) } @@ -610,9 +615,9 @@ func TestParseStudioImages(t *testing.T) { func TestParseTagImages(t *testing.T) { var tags []*models.Tag - if err := withTxn(func(r models.Repository) error { + if err := withTxn(func(ctx context.Context) error { var err error - tags, err = r.Tag().All() + tags, err = r.Tag.All(ctx) return err }); err != nil { t.Errorf("Error getting performer: %s", err) @@ -620,29 +625,29 @@ func TestParseTagImages(t *testing.T) { } for _, s := range tags { - if err := withTxn(func(r models.Repository) error { - aliases, err := r.Tag().GetAliases(s.ID) + if err := withTxn(func(ctx context.Context) error { + aliases, err := r.Tag.GetAliases(ctx, s.ID) if err != nil { return err } - return TagImages(s, nil, aliases, r.Image(), nil) + return TagImages(ctx, s, nil, aliases, r.Image, nil) }); err != nil { t.Errorf("Error auto-tagging performers: %s", err) } } // verify that images were tagged correctly - withTxn(func(r models.Repository) error { - images, err := r.Image().All() + withTxn(func(ctx context.Context) error { + images, err := r.Image.All(ctx) if err != nil { t.Error(err.Error()) } - tqb := r.Tag() + tqb := r.Tag for _, image := range images { - tags, err := tqb.FindByImageID(image.ID) + tags, err := tqb.FindByImageID(ctx, image.ID) if err != nil { t.Errorf("Error getting image tags: %s", err.Error()) @@ -662,9 +667,9 @@ func TestParseTagImages(t *testing.T) { func TestParsePerformerGalleries(t *testing.T) { var performers []*models.Performer - if err := withTxn(func(r models.Repository) error { + if err := withTxn(func(ctx context.Context) error { var err error - performers, err = r.Performer().All() + performers, err = r.Performer.All(ctx) return err }); err != nil { t.Errorf("Error getting performer: %s", err) @@ -672,24 +677,24 @@ func TestParsePerformerGalleries(t *testing.T) { } for _, p := range performers { - if err := withTxn(func(r models.Repository) error { - return PerformerGalleries(p, nil, r.Gallery(), nil) + if err := withTxn(func(ctx context.Context) error { + return PerformerGalleries(ctx, p, nil, r.Gallery, nil) }); err != nil { t.Errorf("Error auto-tagging performers: %s", err) } } // verify that galleries were tagged correctly - withTxn(func(r models.Repository) error { - pqb := r.Performer() + withTxn(func(ctx context.Context) error { + pqb := r.Performer - galleries, err := r.Gallery().All() + galleries, err := r.Gallery.All(ctx) if err != nil { t.Error(err.Error()) } for _, gallery := range galleries { - performers, err := pqb.FindByGalleryID(gallery.ID) + performers, err := pqb.FindByGalleryID(ctx, gallery.ID) if err != nil { t.Errorf("Error getting gallery performers: %s", err.Error()) @@ -709,9 +714,9 @@ func TestParsePerformerGalleries(t *testing.T) { func TestParseStudioGalleries(t *testing.T) { var studios []*models.Studio - if err := withTxn(func(r models.Repository) error { + if err := withTxn(func(ctx context.Context) error { var err error - studios, err = r.Studio().All() + studios, err = r.Studio.All(ctx) return err }); err != nil { t.Errorf("Error getting studio: %s", err) @@ -719,21 +724,21 @@ func TestParseStudioGalleries(t *testing.T) { } for _, s := range studios { - if err := withTxn(func(r models.Repository) error { - aliases, err := r.Studio().GetAliases(s.ID) + if err := withTxn(func(ctx context.Context) error { + aliases, err := r.Studio.GetAliases(ctx, s.ID) if err != nil { return err } - return StudioGalleries(s, nil, aliases, r.Gallery(), nil) + return StudioGalleries(ctx, s, nil, aliases, r.Gallery, nil) }); err != nil { t.Errorf("Error auto-tagging performers: %s", err) } } // verify that galleries were tagged correctly - withTxn(func(r models.Repository) error { - galleries, err := r.Gallery().All() + withTxn(func(ctx context.Context) error { + galleries, err := r.Gallery.All(ctx) if err != nil { t.Error(err.Error()) } @@ -765,9 +770,9 @@ func TestParseStudioGalleries(t *testing.T) { func TestParseTagGalleries(t *testing.T) { var tags []*models.Tag - if err := withTxn(func(r models.Repository) error { + if err := withTxn(func(ctx context.Context) error { var err error - tags, err = r.Tag().All() + tags, err = r.Tag.All(ctx) return err }); err != nil { t.Errorf("Error getting performer: %s", err) @@ -775,29 +780,29 @@ func TestParseTagGalleries(t *testing.T) { } for _, s := range tags { - if err := withTxn(func(r models.Repository) error { - aliases, err := r.Tag().GetAliases(s.ID) + if err := withTxn(func(ctx context.Context) error { + aliases, err := r.Tag.GetAliases(ctx, s.ID) if err != nil { return err } - return TagGalleries(s, nil, aliases, r.Gallery(), nil) + return TagGalleries(ctx, s, nil, aliases, r.Gallery, nil) }); err != nil { t.Errorf("Error auto-tagging performers: %s", err) } } // verify that galleries were tagged correctly - withTxn(func(r models.Repository) error { - galleries, err := r.Gallery().All() + withTxn(func(ctx context.Context) error { + galleries, err := r.Gallery.All(ctx) if err != nil { t.Error(err.Error()) } - tqb := r.Tag() + tqb := r.Tag for _, gallery := range galleries { - tags, err := tqb.FindByGalleryID(gallery.ID) + tags, err := tqb.FindByGalleryID(ctx, gallery.ID) if err != nil { t.Errorf("Error getting gallery tags: %s", err.Error()) diff --git a/internal/autotag/performer.go b/internal/autotag/performer.go index a6c89466a..ea42667e3 100644 --- a/internal/autotag/performer.go +++ b/internal/autotag/performer.go @@ -1,6 +1,8 @@ package autotag import ( + "context" + "github.com/stashapp/stash/pkg/gallery" "github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/match" @@ -8,6 +10,21 @@ import ( "github.com/stashapp/stash/pkg/scene" ) +type SceneQueryPerformerUpdater interface { + scene.Queryer + scene.PerformerUpdater +} + +type ImageQueryPerformerUpdater interface { + image.Queryer + image.PerformerUpdater +} + +type GalleryQueryPerformerUpdater interface { + gallery.Queryer + gallery.PerformerUpdater +} + func getPerformerTagger(p *models.Performer, cache *match.Cache) tagger { return tagger{ ID: p.ID, @@ -18,28 +35,28 @@ func getPerformerTagger(p *models.Performer, cache *match.Cache) tagger { } // PerformerScenes searches for scenes whose path matches the provided performer name and tags the scene with the performer. -func PerformerScenes(p *models.Performer, paths []string, rw models.SceneReaderWriter, cache *match.Cache) error { +func PerformerScenes(ctx context.Context, p *models.Performer, paths []string, rw SceneQueryPerformerUpdater, cache *match.Cache) error { t := getPerformerTagger(p, cache) - return t.tagScenes(paths, rw, func(subjectID, otherID int) (bool, error) { - return scene.AddPerformer(rw, otherID, subjectID) + return t.tagScenes(ctx, paths, rw, func(subjectID, otherID int) (bool, error) { + return scene.AddPerformer(ctx, rw, otherID, subjectID) }) } // PerformerImages searches for images whose path matches the provided performer name and tags the image with the performer. -func PerformerImages(p *models.Performer, paths []string, rw models.ImageReaderWriter, cache *match.Cache) error { +func PerformerImages(ctx context.Context, p *models.Performer, paths []string, rw ImageQueryPerformerUpdater, cache *match.Cache) error { t := getPerformerTagger(p, cache) - return t.tagImages(paths, rw, func(subjectID, otherID int) (bool, error) { - return image.AddPerformer(rw, otherID, subjectID) + return t.tagImages(ctx, paths, rw, func(subjectID, otherID int) (bool, error) { + return image.AddPerformer(ctx, rw, otherID, subjectID) }) } // PerformerGalleries searches for galleries whose path matches the provided performer name and tags the gallery with the performer. -func PerformerGalleries(p *models.Performer, paths []string, rw models.GalleryReaderWriter, cache *match.Cache) error { +func PerformerGalleries(ctx context.Context, p *models.Performer, paths []string, rw GalleryQueryPerformerUpdater, cache *match.Cache) error { t := getPerformerTagger(p, cache) - return t.tagGalleries(paths, rw, func(subjectID, otherID int) (bool, error) { - return gallery.AddPerformer(rw, otherID, subjectID) + return t.tagGalleries(ctx, paths, rw, func(subjectID, otherID int) (bool, error) { + return gallery.AddPerformer(ctx, rw, otherID, subjectID) }) } diff --git a/internal/autotag/performer_test.go b/internal/autotag/performer_test.go index 31befd76a..54a98958a 100644 --- a/internal/autotag/performer_test.go +++ b/internal/autotag/performer_test.go @@ -72,16 +72,16 @@ func testPerformerScenes(t *testing.T, performerName, expectedRegex string) { PerPage: &perPage, } - mockSceneReader.On("Query", scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)). + mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)). Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once() for i := range matchingPaths { sceneID := i + 1 - mockSceneReader.On("GetPerformerIDs", sceneID).Return(nil, nil).Once() - mockSceneReader.On("UpdatePerformers", sceneID, []int{performerID}).Return(nil).Once() + mockSceneReader.On("GetPerformerIDs", testCtx, sceneID).Return(nil, nil).Once() + mockSceneReader.On("UpdatePerformers", testCtx, sceneID, []int{performerID}).Return(nil).Once() } - err := PerformerScenes(&performer, nil, mockSceneReader, nil) + err := PerformerScenes(testCtx, &performer, nil, mockSceneReader, nil) assert := assert.New(t) @@ -147,16 +147,16 @@ func testPerformerImages(t *testing.T, performerName, expectedRegex string) { PerPage: &perPage, } - mockImageReader.On("Query", image.QueryOptions(expectedImageFilter, expectedFindFilter, false)). + mockImageReader.On("Query", testCtx, image.QueryOptions(expectedImageFilter, expectedFindFilter, false)). Return(mocks.ImageQueryResult(images, len(images)), nil).Once() for i := range matchingPaths { imageID := i + 1 - mockImageReader.On("GetPerformerIDs", imageID).Return(nil, nil).Once() - mockImageReader.On("UpdatePerformers", imageID, []int{performerID}).Return(nil).Once() + mockImageReader.On("GetPerformerIDs", testCtx, imageID).Return(nil, nil).Once() + mockImageReader.On("UpdatePerformers", testCtx, imageID, []int{performerID}).Return(nil).Once() } - err := PerformerImages(&performer, nil, mockImageReader, nil) + err := PerformerImages(testCtx, &performer, nil, mockImageReader, nil) assert := assert.New(t) @@ -222,15 +222,15 @@ func testPerformerGalleries(t *testing.T, performerName, expectedRegex string) { PerPage: &perPage, } - mockGalleryReader.On("Query", expectedGalleryFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once() + mockGalleryReader.On("Query", testCtx, expectedGalleryFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once() for i := range matchingPaths { galleryID := i + 1 - mockGalleryReader.On("GetPerformerIDs", galleryID).Return(nil, nil).Once() - mockGalleryReader.On("UpdatePerformers", galleryID, []int{performerID}).Return(nil).Once() + mockGalleryReader.On("GetPerformerIDs", testCtx, galleryID).Return(nil, nil).Once() + mockGalleryReader.On("UpdatePerformers", testCtx, galleryID, []int{performerID}).Return(nil).Once() } - err := PerformerGalleries(&performer, nil, mockGalleryReader, nil) + err := PerformerGalleries(testCtx, &performer, nil, mockGalleryReader, nil) assert := assert.New(t) diff --git a/internal/autotag/scene.go b/internal/autotag/scene.go index cfdcaf393..6c6aeb875 100644 --- a/internal/autotag/scene.go +++ b/internal/autotag/scene.go @@ -1,6 +1,8 @@ package autotag import ( + "context" + "github.com/stashapp/stash/pkg/match" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/scene" @@ -17,18 +19,18 @@ func getSceneFileTagger(s *models.Scene, cache *match.Cache) tagger { } // ScenePerformers tags the provided scene with performers whose name matches the scene's path. -func ScenePerformers(s *models.Scene, rw models.SceneReaderWriter, performerReader models.PerformerReader, cache *match.Cache) error { +func ScenePerformers(ctx context.Context, s *models.Scene, rw scene.PerformerUpdater, performerReader match.PerformerAutoTagQueryer, cache *match.Cache) error { t := getSceneFileTagger(s, cache) - return t.tagPerformers(performerReader, func(subjectID, otherID int) (bool, error) { - return scene.AddPerformer(rw, subjectID, otherID) + return t.tagPerformers(ctx, performerReader, func(subjectID, otherID int) (bool, error) { + return scene.AddPerformer(ctx, rw, subjectID, otherID) }) } // SceneStudios tags the provided scene with the first studio whose name matches the scene's path. // // Scenes will not be tagged if studio is already set. -func SceneStudios(s *models.Scene, rw models.SceneReaderWriter, studioReader models.StudioReader, cache *match.Cache) error { +func SceneStudios(ctx context.Context, s *models.Scene, rw SceneFinderUpdater, studioReader match.StudioAutoTagQueryer, cache *match.Cache) error { if s.StudioID.Valid { // don't modify return nil @@ -36,16 +38,16 @@ func SceneStudios(s *models.Scene, rw models.SceneReaderWriter, studioReader mod t := getSceneFileTagger(s, cache) - return t.tagStudios(studioReader, func(subjectID, otherID int) (bool, error) { - return addSceneStudio(rw, subjectID, otherID) + return t.tagStudios(ctx, studioReader, func(subjectID, otherID int) (bool, error) { + return addSceneStudio(ctx, rw, subjectID, otherID) }) } // SceneTags tags the provided scene with tags whose name matches the scene's path. -func SceneTags(s *models.Scene, rw models.SceneReaderWriter, tagReader models.TagReader, cache *match.Cache) error { +func SceneTags(ctx context.Context, s *models.Scene, rw scene.TagUpdater, tagReader match.TagAutoTagQueryer, cache *match.Cache) error { t := getSceneFileTagger(s, cache) - return t.tagTags(tagReader, func(subjectID, otherID int) (bool, error) { - return scene.AddTag(rw, subjectID, otherID) + return t.tagTags(ctx, tagReader, func(subjectID, otherID int) (bool, error) { + return scene.AddTag(ctx, rw, subjectID, otherID) }) } diff --git a/internal/autotag/scene_test.go b/internal/autotag/scene_test.go index 578b9e7f6..6e66482fc 100644 --- a/internal/autotag/scene_test.go +++ b/internal/autotag/scene_test.go @@ -173,19 +173,19 @@ func TestScenePerformers(t *testing.T) { mockPerformerReader := &mocks.PerformerReaderWriter{} mockSceneReader := &mocks.SceneReaderWriter{} - mockPerformerReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) - mockPerformerReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once() + mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once() if test.Matches { - mockSceneReader.On("GetPerformerIDs", sceneID).Return(nil, nil).Once() - mockSceneReader.On("UpdatePerformers", sceneID, []int{performerID}).Return(nil).Once() + mockSceneReader.On("GetPerformerIDs", testCtx, sceneID).Return(nil, nil).Once() + mockSceneReader.On("UpdatePerformers", testCtx, sceneID, []int{performerID}).Return(nil).Once() } scene := models.Scene{ ID: sceneID, Path: test.Path, } - err := ScenePerformers(&scene, mockSceneReader, mockPerformerReader, nil) + err := ScenePerformers(testCtx, &scene, mockSceneReader, mockPerformerReader, nil) assert.Nil(err) mockPerformerReader.AssertExpectations(t) @@ -217,9 +217,9 @@ func TestSceneStudios(t *testing.T) { doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockSceneReader *mocks.SceneReaderWriter, test pathTestTable) { if test.Matches { - mockSceneReader.On("Find", sceneID).Return(&models.Scene{}, nil).Once() + mockSceneReader.On("Find", testCtx, sceneID).Return(&models.Scene{}, nil).Once() expectedStudioID := models.NullInt64(studioID) - mockSceneReader.On("Update", models.ScenePartial{ + mockSceneReader.On("Update", testCtx, models.ScenePartial{ ID: sceneID, StudioID: &expectedStudioID, }).Return(nil, nil).Once() @@ -229,7 +229,7 @@ func TestSceneStudios(t *testing.T) { ID: sceneID, Path: test.Path, } - err := SceneStudios(&scene, mockSceneReader, mockStudioReader, nil) + err := SceneStudios(testCtx, &scene, mockSceneReader, mockStudioReader, nil) assert.Nil(err) mockStudioReader.AssertExpectations(t) @@ -240,9 +240,9 @@ func TestSceneStudios(t *testing.T) { mockStudioReader := &mocks.StudioReaderWriter{} mockSceneReader := &mocks.SceneReaderWriter{} - mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) - mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() - mockStudioReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe() + mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() + mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() doTest(mockStudioReader, mockSceneReader, test) } @@ -255,12 +255,12 @@ func TestSceneStudios(t *testing.T) { mockStudioReader := &mocks.StudioReaderWriter{} mockSceneReader := &mocks.SceneReaderWriter{} - mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) - mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() - mockStudioReader.On("GetAliases", studioID).Return([]string{ + mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() + mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{ studioName, }, nil).Once() - mockStudioReader.On("GetAliases", reversedStudioID).Return([]string{}, nil).Once() + mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once() doTest(mockStudioReader, mockSceneReader, test) } @@ -290,15 +290,15 @@ func TestSceneTags(t *testing.T) { doTest := func(mockTagReader *mocks.TagReaderWriter, mockSceneReader *mocks.SceneReaderWriter, test pathTestTable) { if test.Matches { - mockSceneReader.On("GetTagIDs", sceneID).Return(nil, nil).Once() - mockSceneReader.On("UpdateTags", sceneID, []int{tagID}).Return(nil).Once() + mockSceneReader.On("GetTagIDs", testCtx, sceneID).Return(nil, nil).Once() + mockSceneReader.On("UpdateTags", testCtx, sceneID, []int{tagID}).Return(nil).Once() } scene := models.Scene{ ID: sceneID, Path: test.Path, } - err := SceneTags(&scene, mockSceneReader, mockTagReader, nil) + err := SceneTags(testCtx, &scene, mockSceneReader, mockTagReader, nil) assert.Nil(err) mockTagReader.AssertExpectations(t) @@ -309,9 +309,9 @@ func TestSceneTags(t *testing.T) { mockTagReader := &mocks.TagReaderWriter{} mockSceneReader := &mocks.SceneReaderWriter{} - mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) - mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() - mockTagReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe() + mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() + mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() doTest(mockTagReader, mockSceneReader, test) } @@ -324,12 +324,12 @@ func TestSceneTags(t *testing.T) { mockTagReader := &mocks.TagReaderWriter{} mockSceneReader := &mocks.SceneReaderWriter{} - mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) - mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() - mockTagReader.On("GetAliases", tagID).Return([]string{ + mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) + mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() + mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{ tagName, }, nil).Once() - mockTagReader.On("GetAliases", reversedTagID).Return([]string{}, nil).Once() + mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once() doTest(mockTagReader, mockSceneReader, test) } diff --git a/internal/autotag/studio.go b/internal/autotag/studio.go index 4a02e7305..79cb22586 100644 --- a/internal/autotag/studio.go +++ b/internal/autotag/studio.go @@ -1,15 +1,19 @@ package autotag import ( + "context" "database/sql" + "github.com/stashapp/stash/pkg/gallery" + "github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/match" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/scene" ) -func addSceneStudio(sceneWriter models.SceneReaderWriter, sceneID, studioID int) (bool, error) { +func addSceneStudio(ctx context.Context, sceneWriter SceneFinderUpdater, sceneID, studioID int) (bool, error) { // don't set if already set - scene, err := sceneWriter.Find(sceneID) + scene, err := sceneWriter.Find(ctx, sceneID) if err != nil { return false, err } @@ -25,15 +29,15 @@ func addSceneStudio(sceneWriter models.SceneReaderWriter, sceneID, studioID int) StudioID: &s, } - if _, err := sceneWriter.Update(scenePartial); err != nil { + if _, err := sceneWriter.Update(ctx, scenePartial); err != nil { return false, err } return true, nil } -func addImageStudio(imageWriter models.ImageReaderWriter, imageID, studioID int) (bool, error) { +func addImageStudio(ctx context.Context, imageWriter ImageFinderUpdater, imageID, studioID int) (bool, error) { // don't set if already set - image, err := imageWriter.Find(imageID) + image, err := imageWriter.Find(ctx, imageID) if err != nil { return false, err } @@ -49,15 +53,15 @@ func addImageStudio(imageWriter models.ImageReaderWriter, imageID, studioID int) StudioID: &s, } - if _, err := imageWriter.Update(imagePartial); err != nil { + if _, err := imageWriter.Update(ctx, imagePartial); err != nil { return false, err } return true, nil } -func addGalleryStudio(galleryWriter models.GalleryReaderWriter, galleryID, studioID int) (bool, error) { +func addGalleryStudio(ctx context.Context, galleryWriter GalleryFinderUpdater, galleryID, studioID int) (bool, error) { // don't set if already set - gallery, err := galleryWriter.Find(galleryID) + gallery, err := galleryWriter.Find(ctx, galleryID) if err != nil { return false, err } @@ -73,7 +77,7 @@ func addGalleryStudio(galleryWriter models.GalleryReaderWriter, galleryID, studi StudioID: &s, } - if _, err := galleryWriter.UpdatePartial(galleryPartial); err != nil { + if _, err := galleryWriter.UpdatePartial(ctx, galleryPartial); err != nil { return false, err } return true, nil @@ -98,13 +102,19 @@ func getStudioTagger(p *models.Studio, aliases []string, cache *match.Cache) []t return ret } +type SceneFinderUpdater interface { + scene.Queryer + Find(ctx context.Context, id int) (*models.Scene, error) + Update(ctx context.Context, updatedScene models.ScenePartial) (*models.Scene, error) +} + // StudioScenes searches for scenes whose path matches the provided studio name and tags the scene with the studio, if studio is not already set on the scene. -func StudioScenes(p *models.Studio, paths []string, aliases []string, rw models.SceneReaderWriter, cache *match.Cache) error { +func StudioScenes(ctx context.Context, p *models.Studio, paths []string, aliases []string, rw SceneFinderUpdater, cache *match.Cache) error { t := getStudioTagger(p, aliases, cache) for _, tt := range t { - if err := tt.tagScenes(paths, rw, func(subjectID, otherID int) (bool, error) { - return addSceneStudio(rw, otherID, subjectID) + if err := tt.tagScenes(ctx, paths, rw, func(subjectID, otherID int) (bool, error) { + return addSceneStudio(ctx, rw, otherID, subjectID) }); err != nil { return err } @@ -113,13 +123,19 @@ func StudioScenes(p *models.Studio, paths []string, aliases []string, rw models. return nil } +type ImageFinderUpdater interface { + image.Queryer + Find(ctx context.Context, id int) (*models.Image, error) + Update(ctx context.Context, updatedImage models.ImagePartial) (*models.Image, error) +} + // StudioImages searches for images whose path matches the provided studio name and tags the image with the studio, if studio is not already set on the image. -func StudioImages(p *models.Studio, paths []string, aliases []string, rw models.ImageReaderWriter, cache *match.Cache) error { +func StudioImages(ctx context.Context, p *models.Studio, paths []string, aliases []string, rw ImageFinderUpdater, cache *match.Cache) error { t := getStudioTagger(p, aliases, cache) for _, tt := range t { - if err := tt.tagImages(paths, rw, func(subjectID, otherID int) (bool, error) { - return addImageStudio(rw, otherID, subjectID) + if err := tt.tagImages(ctx, paths, rw, func(subjectID, otherID int) (bool, error) { + return addImageStudio(ctx, rw, otherID, subjectID) }); err != nil { return err } @@ -128,13 +144,19 @@ func StudioImages(p *models.Studio, paths []string, aliases []string, rw models. return nil } +type GalleryFinderUpdater interface { + gallery.Queryer + Find(ctx context.Context, id int) (*models.Gallery, error) + UpdatePartial(ctx context.Context, updatedGallery models.GalleryPartial) (*models.Gallery, error) +} + // StudioGalleries searches for galleries whose path matches the provided studio name and tags the gallery with the studio, if studio is not already set on the gallery. -func StudioGalleries(p *models.Studio, paths []string, aliases []string, rw models.GalleryReaderWriter, cache *match.Cache) error { +func StudioGalleries(ctx context.Context, p *models.Studio, paths []string, aliases []string, rw GalleryFinderUpdater, cache *match.Cache) error { t := getStudioTagger(p, aliases, cache) for _, tt := range t { - if err := tt.tagGalleries(paths, rw, func(subjectID, otherID int) (bool, error) { - return addGalleryStudio(rw, otherID, subjectID) + if err := tt.tagGalleries(ctx, paths, rw, func(subjectID, otherID int) (bool, error) { + return addGalleryStudio(ctx, rw, otherID, subjectID) }); err != nil { return err } diff --git a/internal/autotag/studio_test.go b/internal/autotag/studio_test.go index 76d7e7db5..861740612 100644 --- a/internal/autotag/studio_test.go +++ b/internal/autotag/studio_test.go @@ -113,7 +113,7 @@ func testStudioScenes(t *testing.T, tc testStudioCase) { } // if alias provided, then don't find by name - onNameQuery := mockSceneReader.On("Query", scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)) + onNameQuery := mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)) if aliasName == "" { onNameQuery.Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once() @@ -128,21 +128,21 @@ func testStudioScenes(t *testing.T, tc testStudioCase) { }, } - mockSceneReader.On("Query", scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). + mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once() } for i := range matchingPaths { sceneID := i + 1 - mockSceneReader.On("Find", sceneID).Return(&models.Scene{}, nil).Once() + mockSceneReader.On("Find", testCtx, sceneID).Return(&models.Scene{}, nil).Once() expectedStudioID := models.NullInt64(studioID) - mockSceneReader.On("Update", models.ScenePartial{ + mockSceneReader.On("Update", testCtx, models.ScenePartial{ ID: sceneID, StudioID: &expectedStudioID, }).Return(nil, nil).Once() } - err := StudioScenes(&studio, nil, aliases, mockSceneReader, nil) + err := StudioScenes(testCtx, &studio, nil, aliases, mockSceneReader, nil) assert := assert.New(t) @@ -206,7 +206,7 @@ func testStudioImages(t *testing.T, tc testStudioCase) { } // if alias provided, then don't find by name - onNameQuery := mockImageReader.On("Query", image.QueryOptions(expectedImageFilter, expectedFindFilter, false)) + onNameQuery := mockImageReader.On("Query", testCtx, image.QueryOptions(expectedImageFilter, expectedFindFilter, false)) if aliasName == "" { onNameQuery.Return(mocks.ImageQueryResult(images, len(images)), nil).Once() } else { @@ -220,21 +220,21 @@ func testStudioImages(t *testing.T, tc testStudioCase) { }, } - mockImageReader.On("Query", image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). + mockImageReader.On("Query", testCtx, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). Return(mocks.ImageQueryResult(images, len(images)), nil).Once() } for i := range matchingPaths { imageID := i + 1 - mockImageReader.On("Find", imageID).Return(&models.Image{}, nil).Once() + mockImageReader.On("Find", testCtx, imageID).Return(&models.Image{}, nil).Once() expectedStudioID := models.NullInt64(studioID) - mockImageReader.On("Update", models.ImagePartial{ + mockImageReader.On("Update", testCtx, models.ImagePartial{ ID: imageID, StudioID: &expectedStudioID, }).Return(nil, nil).Once() } - err := StudioImages(&studio, nil, aliases, mockImageReader, nil) + err := StudioImages(testCtx, &studio, nil, aliases, mockImageReader, nil) assert := assert.New(t) @@ -297,7 +297,7 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) { } // if alias provided, then don't find by name - onNameQuery := mockGalleryReader.On("Query", expectedGalleryFilter, expectedFindFilter) + onNameQuery := mockGalleryReader.On("Query", testCtx, expectedGalleryFilter, expectedFindFilter) if aliasName == "" { onNameQuery.Return(galleries, len(galleries), nil).Once() } else { @@ -311,20 +311,20 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) { }, } - mockGalleryReader.On("Query", expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once() + mockGalleryReader.On("Query", testCtx, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once() } for i := range matchingPaths { galleryID := i + 1 - mockGalleryReader.On("Find", galleryID).Return(&models.Gallery{}, nil).Once() + mockGalleryReader.On("Find", testCtx, galleryID).Return(&models.Gallery{}, nil).Once() expectedStudioID := models.NullInt64(studioID) - mockGalleryReader.On("UpdatePartial", models.GalleryPartial{ + mockGalleryReader.On("UpdatePartial", testCtx, models.GalleryPartial{ ID: galleryID, StudioID: &expectedStudioID, }).Return(nil, nil).Once() } - err := StudioGalleries(&studio, nil, aliases, mockGalleryReader, nil) + err := StudioGalleries(testCtx, &studio, nil, aliases, mockGalleryReader, nil) assert := assert.New(t) diff --git a/internal/autotag/tag.go b/internal/autotag/tag.go index f0d080871..4c66573b3 100644 --- a/internal/autotag/tag.go +++ b/internal/autotag/tag.go @@ -1,6 +1,8 @@ package autotag import ( + "context" + "github.com/stashapp/stash/pkg/gallery" "github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/match" @@ -8,6 +10,21 @@ import ( "github.com/stashapp/stash/pkg/scene" ) +type SceneQueryTagUpdater interface { + scene.Queryer + scene.TagUpdater +} + +type ImageQueryTagUpdater interface { + image.Queryer + image.TagUpdater +} + +type GalleryQueryTagUpdater interface { + gallery.Queryer + gallery.TagUpdater +} + func getTagTaggers(p *models.Tag, aliases []string, cache *match.Cache) []tagger { ret := []tagger{{ ID: p.ID, @@ -29,12 +46,12 @@ func getTagTaggers(p *models.Tag, aliases []string, cache *match.Cache) []tagger } // TagScenes searches for scenes whose path matches the provided tag name and tags the scene with the tag. -func TagScenes(p *models.Tag, paths []string, aliases []string, rw models.SceneReaderWriter, cache *match.Cache) error { +func TagScenes(ctx context.Context, p *models.Tag, paths []string, aliases []string, rw SceneQueryTagUpdater, cache *match.Cache) error { t := getTagTaggers(p, aliases, cache) for _, tt := range t { - if err := tt.tagScenes(paths, rw, func(subjectID, otherID int) (bool, error) { - return scene.AddTag(rw, otherID, subjectID) + if err := tt.tagScenes(ctx, paths, rw, func(subjectID, otherID int) (bool, error) { + return scene.AddTag(ctx, rw, otherID, subjectID) }); err != nil { return err } @@ -43,12 +60,12 @@ func TagScenes(p *models.Tag, paths []string, aliases []string, rw models.SceneR } // TagImages searches for images whose path matches the provided tag name and tags the image with the tag. -func TagImages(p *models.Tag, paths []string, aliases []string, rw models.ImageReaderWriter, cache *match.Cache) error { +func TagImages(ctx context.Context, p *models.Tag, paths []string, aliases []string, rw ImageQueryTagUpdater, cache *match.Cache) error { t := getTagTaggers(p, aliases, cache) for _, tt := range t { - if err := tt.tagImages(paths, rw, func(subjectID, otherID int) (bool, error) { - return image.AddTag(rw, otherID, subjectID) + if err := tt.tagImages(ctx, paths, rw, func(subjectID, otherID int) (bool, error) { + return image.AddTag(ctx, rw, otherID, subjectID) }); err != nil { return err } @@ -57,12 +74,12 @@ func TagImages(p *models.Tag, paths []string, aliases []string, rw models.ImageR } // TagGalleries searches for galleries whose path matches the provided tag name and tags the gallery with the tag. -func TagGalleries(p *models.Tag, paths []string, aliases []string, rw models.GalleryReaderWriter, cache *match.Cache) error { +func TagGalleries(ctx context.Context, p *models.Tag, paths []string, aliases []string, rw GalleryQueryTagUpdater, cache *match.Cache) error { t := getTagTaggers(p, aliases, cache) for _, tt := range t { - if err := tt.tagGalleries(paths, rw, func(subjectID, otherID int) (bool, error) { - return gallery.AddTag(rw, otherID, subjectID) + if err := tt.tagGalleries(ctx, paths, rw, func(subjectID, otherID int) (bool, error) { + return gallery.AddTag(ctx, rw, otherID, subjectID) }); err != nil { return err } diff --git a/internal/autotag/tag_test.go b/internal/autotag/tag_test.go index a1eed1eab..c49f580e3 100644 --- a/internal/autotag/tag_test.go +++ b/internal/autotag/tag_test.go @@ -113,7 +113,7 @@ func testTagScenes(t *testing.T, tc testTagCase) { } // if alias provided, then don't find by name - onNameQuery := mockSceneReader.On("Query", scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)) + onNameQuery := mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)) if aliasName == "" { onNameQuery.Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once() } else { @@ -127,17 +127,17 @@ func testTagScenes(t *testing.T, tc testTagCase) { }, } - mockSceneReader.On("Query", scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). + mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once() } for i := range matchingPaths { sceneID := i + 1 - mockSceneReader.On("GetTagIDs", sceneID).Return(nil, nil).Once() - mockSceneReader.On("UpdateTags", sceneID, []int{tagID}).Return(nil).Once() + mockSceneReader.On("GetTagIDs", testCtx, sceneID).Return(nil, nil).Once() + mockSceneReader.On("UpdateTags", testCtx, sceneID, []int{tagID}).Return(nil).Once() } - err := TagScenes(&tag, nil, aliases, mockSceneReader, nil) + err := TagScenes(testCtx, &tag, nil, aliases, mockSceneReader, nil) assert := assert.New(t) @@ -201,7 +201,7 @@ func testTagImages(t *testing.T, tc testTagCase) { } // if alias provided, then don't find by name - onNameQuery := mockImageReader.On("Query", image.QueryOptions(expectedImageFilter, expectedFindFilter, false)) + onNameQuery := mockImageReader.On("Query", testCtx, image.QueryOptions(expectedImageFilter, expectedFindFilter, false)) if aliasName == "" { onNameQuery.Return(mocks.ImageQueryResult(images, len(images)), nil).Once() } else { @@ -215,17 +215,17 @@ func testTagImages(t *testing.T, tc testTagCase) { }, } - mockImageReader.On("Query", image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). + mockImageReader.On("Query", testCtx, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). Return(mocks.ImageQueryResult(images, len(images)), nil).Once() } for i := range matchingPaths { imageID := i + 1 - mockImageReader.On("GetTagIDs", imageID).Return(nil, nil).Once() - mockImageReader.On("UpdateTags", imageID, []int{tagID}).Return(nil).Once() + mockImageReader.On("GetTagIDs", testCtx, imageID).Return(nil, nil).Once() + mockImageReader.On("UpdateTags", testCtx, imageID, []int{tagID}).Return(nil).Once() } - err := TagImages(&tag, nil, aliases, mockImageReader, nil) + err := TagImages(testCtx, &tag, nil, aliases, mockImageReader, nil) assert := assert.New(t) @@ -289,7 +289,7 @@ func testTagGalleries(t *testing.T, tc testTagCase) { } // if alias provided, then don't find by name - onNameQuery := mockGalleryReader.On("Query", expectedGalleryFilter, expectedFindFilter) + onNameQuery := mockGalleryReader.On("Query", testCtx, expectedGalleryFilter, expectedFindFilter) if aliasName == "" { onNameQuery.Return(galleries, len(galleries), nil).Once() } else { @@ -303,16 +303,16 @@ func testTagGalleries(t *testing.T, tc testTagCase) { }, } - mockGalleryReader.On("Query", expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once() + mockGalleryReader.On("Query", testCtx, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once() } for i := range matchingPaths { galleryID := i + 1 - mockGalleryReader.On("GetTagIDs", galleryID).Return(nil, nil).Once() - mockGalleryReader.On("UpdateTags", galleryID, []int{tagID}).Return(nil).Once() + mockGalleryReader.On("GetTagIDs", testCtx, galleryID).Return(nil, nil).Once() + mockGalleryReader.On("UpdateTags", testCtx, galleryID, []int{tagID}).Return(nil).Once() } - err := TagGalleries(&tag, nil, aliases, mockGalleryReader, nil) + err := TagGalleries(testCtx, &tag, nil, aliases, mockGalleryReader, nil) assert := assert.New(t) diff --git a/internal/autotag/tagger.go b/internal/autotag/tagger.go index 4ea1fbc01..dae5cdc07 100644 --- a/internal/autotag/tagger.go +++ b/internal/autotag/tagger.go @@ -14,11 +14,14 @@ package autotag import ( + "context" "fmt" + "github.com/stashapp/stash/pkg/gallery" + "github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/match" - "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/scene" ) type tagger struct { @@ -41,8 +44,8 @@ func (t *tagger) addLog(otherType, otherName string) { logger.Infof("Added %s '%s' to %s '%s'", otherType, otherName, t.Type, t.Name) } -func (t *tagger) tagPerformers(performerReader models.PerformerReader, addFunc addLinkFunc) error { - others, err := match.PathToPerformers(t.Path, performerReader, t.cache, t.trimExt) +func (t *tagger) tagPerformers(ctx context.Context, performerReader match.PerformerAutoTagQueryer, addFunc addLinkFunc) error { + others, err := match.PathToPerformers(ctx, t.Path, performerReader, t.cache, t.trimExt) if err != nil { return err } @@ -62,8 +65,8 @@ func (t *tagger) tagPerformers(performerReader models.PerformerReader, addFunc a return nil } -func (t *tagger) tagStudios(studioReader models.StudioReader, addFunc addLinkFunc) error { - studio, err := match.PathToStudio(t.Path, studioReader, t.cache, t.trimExt) +func (t *tagger) tagStudios(ctx context.Context, studioReader match.StudioAutoTagQueryer, addFunc addLinkFunc) error { + studio, err := match.PathToStudio(ctx, t.Path, studioReader, t.cache, t.trimExt) if err != nil { return err } @@ -83,8 +86,8 @@ func (t *tagger) tagStudios(studioReader models.StudioReader, addFunc addLinkFun return nil } -func (t *tagger) tagTags(tagReader models.TagReader, addFunc addLinkFunc) error { - others, err := match.PathToTags(t.Path, tagReader, t.cache, t.trimExt) +func (t *tagger) tagTags(ctx context.Context, tagReader match.TagAutoTagQueryer, addFunc addLinkFunc) error { + others, err := match.PathToTags(ctx, t.Path, tagReader, t.cache, t.trimExt) if err != nil { return err } @@ -104,8 +107,8 @@ func (t *tagger) tagTags(tagReader models.TagReader, addFunc addLinkFunc) error return nil } -func (t *tagger) tagScenes(paths []string, sceneReader models.SceneReader, addFunc addLinkFunc) error { - others, err := match.PathToScenes(t.Name, paths, sceneReader) +func (t *tagger) tagScenes(ctx context.Context, paths []string, sceneReader scene.Queryer, addFunc addLinkFunc) error { + others, err := match.PathToScenes(ctx, t.Name, paths, sceneReader) if err != nil { return err } @@ -125,8 +128,8 @@ func (t *tagger) tagScenes(paths []string, sceneReader models.SceneReader, addFu return nil } -func (t *tagger) tagImages(paths []string, imageReader models.ImageReader, addFunc addLinkFunc) error { - others, err := match.PathToImages(t.Name, paths, imageReader) +func (t *tagger) tagImages(ctx context.Context, paths []string, imageReader image.Queryer, addFunc addLinkFunc) error { + others, err := match.PathToImages(ctx, t.Name, paths, imageReader) if err != nil { return err } @@ -146,8 +149,8 @@ func (t *tagger) tagImages(paths []string, imageReader models.ImageReader, addFu return nil } -func (t *tagger) tagGalleries(paths []string, galleryReader models.GalleryReader, addFunc addLinkFunc) error { - others, err := match.PathToGalleries(t.Name, paths, galleryReader) +func (t *tagger) tagGalleries(ctx context.Context, paths []string, galleryReader gallery.Queryer, addFunc addLinkFunc) error { + others, err := match.PathToGalleries(ctx, t.Name, paths, galleryReader) if err != nil { return err } diff --git a/internal/dlna/cds.go b/internal/dlna/cds.go index 4544b8759..6faa312b8 100644 --- a/internal/dlna/cds.go +++ b/internal/dlna/cds.go @@ -41,6 +41,7 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/scene" "github.com/stashapp/stash/pkg/sliceutil/stringslice" + "github.com/stashapp/stash/pkg/txn" ) var pageSize = 100 @@ -56,7 +57,6 @@ type browse struct { type contentDirectoryService struct { *Server upnp.Eventing - txnManager models.TransactionManager } func formatDurationSexagesimal(d time.Duration) string { @@ -352,8 +352,8 @@ func (me *contentDirectoryService) handleBrowseMetadata(obj object, host string) } else { var scene *models.Scene - if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { - scene, err = r.Scene().Find(sceneID) + if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error { + scene, err = me.repository.SceneFinder.Find(ctx, sceneID) if err != nil { return err } @@ -431,14 +431,14 @@ func getRootObjects() []interface{} { func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType, parentID string, host string) []interface{} { var objs []interface{} - if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error { sort := "title" findFilter := &models.FindFilterType{ PerPage: &pageSize, Sort: &sort, } - scenes, total, err := scene.QueryWithCount(r.Scene(), sceneFilter, findFilter) + scenes, total, err := scene.QueryWithCount(ctx, me.repository.SceneFinder, sceneFilter, findFilter) if err != nil { return err } @@ -449,7 +449,7 @@ func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType parentID: parentID, } - objs, err = pager.getPages(r, total) + objs, err = pager.getPages(ctx, me.repository.SceneFinder, total) if err != nil { return err } @@ -470,14 +470,14 @@ func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType func (me *contentDirectoryService) getPageVideos(sceneFilter *models.SceneFilterType, parentID string, page int, host string) []interface{} { var objs []interface{} - if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error { pager := scenePager{ sceneFilter: sceneFilter, parentID: parentID, } var err error - objs, err = pager.getPageVideos(r, page, host) + objs, err = pager.getPageVideos(ctx, me.repository.SceneFinder, page, host) if err != nil { return err } @@ -511,8 +511,8 @@ func (me *contentDirectoryService) getAllScenes(host string) []interface{} { func (me *contentDirectoryService) getStudios() []interface{} { var objs []interface{} - if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { - studios, err := r.Studio().All() + if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error { + studios, err := me.repository.StudioFinder.All(ctx) if err != nil { return err } @@ -550,8 +550,8 @@ func (me *contentDirectoryService) getStudioScenes(paths []string, host string) func (me *contentDirectoryService) getTags() []interface{} { var objs []interface{} - if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { - tags, err := r.Tag().All() + if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error { + tags, err := me.repository.TagFinder.All(ctx) if err != nil { return err } @@ -589,8 +589,8 @@ func (me *contentDirectoryService) getTagScenes(paths []string, host string) []i func (me *contentDirectoryService) getPerformers() []interface{} { var objs []interface{} - if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { - performers, err := r.Performer().All() + if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error { + performers, err := me.repository.PerformerFinder.All(ctx) if err != nil { return err } @@ -628,8 +628,8 @@ func (me *contentDirectoryService) getPerformerScenes(paths []string, host strin func (me *contentDirectoryService) getMovies() []interface{} { var objs []interface{} - if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { - movies, err := r.Movie().All() + if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error { + movies, err := me.repository.MovieFinder.All(ctx) if err != nil { return err } diff --git a/internal/dlna/cds_test.go b/internal/dlna/cds_test.go index b52ca3b88..592f8f818 100644 --- a/internal/dlna/cds_test.go +++ b/internal/dlna/cds_test.go @@ -31,7 +31,6 @@ import ( "strings" "testing" - "github.com/stashapp/stash/pkg/models/mocks" "github.com/stretchr/testify/assert" ) @@ -59,8 +58,7 @@ func TestRootParentObjectID(t *testing.T) { func testHandleBrowse(argsXML string) (map[string]string, error) { cds := contentDirectoryService{ - Server: &Server{}, - txnManager: mocks.NewTransactionManager(), + Server: &Server{}, } r := &http.Request{} diff --git a/internal/dlna/dms.go b/internal/dlna/dms.go index a1ea8ceac..d5e7cc84e 100644 --- a/internal/dlna/dms.go +++ b/internal/dlna/dms.go @@ -48,8 +48,31 @@ import ( "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/scene" + "github.com/stashapp/stash/pkg/txn" ) +type SceneFinder interface { + scene.Queryer + scene.IDFinder +} + +type StudioFinder interface { + All(ctx context.Context) ([]*models.Studio, error) +} + +type TagFinder interface { + All(ctx context.Context) ([]*models.Tag, error) +} + +type PerformerFinder interface { + All(ctx context.Context) ([]*models.Performer, error) +} + +type MovieFinder interface { + All(ctx context.Context) ([]*models.Movie, error) +} + const ( serverField = "Linux/3.4 DLNADOC/1.50 UPnP/1.0 DMS/1.0" rootDeviceType = "urn:schemas-upnp-org:device:MediaServer:1" @@ -249,7 +272,8 @@ type Server struct { // Time interval between SSPD announces NotifyInterval time.Duration - txnManager models.TransactionManager + txnManager txn.Manager + repository Repository sceneServer sceneServer ipWhitelistManager *ipWhitelistManager } @@ -415,12 +439,12 @@ func (me *Server) serveIcon(w http.ResponseWriter, r *http.Request) { } var scene *models.Scene - err := me.txnManager.WithReadTxn(r.Context(), func(r models.ReaderRepository) error { + err := txn.WithTxn(r.Context(), me.txnManager, func(ctx context.Context) error { idInt, err := strconv.Atoi(sceneId) if err != nil { return nil } - scene, _ = r.Scene().Find(idInt) + scene, _ = me.repository.SceneFinder.Find(ctx, idInt) return nil }) if err != nil { @@ -555,12 +579,12 @@ func (me *Server) initMux(mux *http.ServeMux) { mux.HandleFunc(resPath, func(w http.ResponseWriter, r *http.Request) { sceneId := r.URL.Query().Get("scene") var scene *models.Scene - err := me.txnManager.WithReadTxn(r.Context(), func(r models.ReaderRepository) error { + err := txn.WithTxn(r.Context(), me.txnManager, func(ctx context.Context) error { sceneIdInt, err := strconv.Atoi(sceneId) if err != nil { return nil } - scene, _ = r.Scene().Find(sceneIdInt) + scene, _ = me.repository.SceneFinder.Find(ctx, sceneIdInt) return nil }) if err != nil { @@ -595,8 +619,7 @@ func (me *Server) initMux(mux *http.ServeMux) { func (me *Server) initServices() { me.services = map[string]UPnPService{ "ContentDirectory": &contentDirectoryService{ - Server: me, - txnManager: me.txnManager, + Server: me, }, "ConnectionManager": &connectionManagerService{ Server: me, diff --git a/internal/dlna/paging.go b/internal/dlna/paging.go index 6f2afda8e..e5f65f96a 100644 --- a/internal/dlna/paging.go +++ b/internal/dlna/paging.go @@ -1,6 +1,7 @@ package dlna import ( + "context" "fmt" "math" "strconv" @@ -18,7 +19,7 @@ func (p *scenePager) getPageID(page int) string { return p.parentID + "/page/" + strconv.Itoa(page) } -func (p *scenePager) getPages(r models.ReaderRepository, total int) ([]interface{}, error) { +func (p *scenePager) getPages(ctx context.Context, r scene.Queryer, total int) ([]interface{}, error) { var objs []interface{} // get the first scene of each page to set an appropriate title @@ -37,7 +38,7 @@ func (p *scenePager) getPages(r models.ReaderRepository, total int) ([]interface if pages <= 10 || (page-1)%(pages/10) == 0 { thisPage := ((page - 1) * pageSize) + 1 findFilter.Page = &thisPage - scenes, err := scene.Query(r.Scene(), p.sceneFilter, findFilter) + scenes, err := scene.Query(ctx, r, p.sceneFilter, findFilter) if err != nil { return nil, err } @@ -58,7 +59,7 @@ func (p *scenePager) getPages(r models.ReaderRepository, total int) ([]interface return objs, nil } -func (p *scenePager) getPageVideos(r models.ReaderRepository, page int, host string) ([]interface{}, error) { +func (p *scenePager) getPageVideos(ctx context.Context, r SceneFinder, page int, host string) ([]interface{}, error) { var objs []interface{} sort := "title" @@ -68,7 +69,7 @@ func (p *scenePager) getPageVideos(r models.ReaderRepository, page int, host str Sort: &sort, } - scenes, err := scene.Query(r.Scene(), p.sceneFilter, findFilter) + scenes, err := scene.Query(ctx, r, p.sceneFilter, findFilter) if err != nil { return nil, err } diff --git a/internal/dlna/service.go b/internal/dlna/service.go index 961e3f230..261a2ab62 100644 --- a/internal/dlna/service.go +++ b/internal/dlna/service.go @@ -10,8 +10,17 @@ import ( "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/txn" ) +type Repository struct { + SceneFinder SceneFinder + StudioFinder StudioFinder + TagFinder TagFinder + PerformerFinder PerformerFinder + MovieFinder MovieFinder +} + type Status struct { Running bool `json:"running"` // If not currently running, time until it will be started. If running, time until it will be stopped @@ -48,7 +57,8 @@ type Config interface { } type Service struct { - txnManager models.TransactionManager + txnManager txn.Manager + repository Repository config Config sceneServer sceneServer ipWhitelistMgr *ipWhitelistManager @@ -121,6 +131,7 @@ func (s *Service) init() error { s.server = &Server{ txnManager: s.txnManager, sceneServer: s.sceneServer, + repository: s.repository, ipWhitelistManager: s.ipWhitelistMgr, Interfaces: interfaces, HTTPConn: func() net.Listener { @@ -181,9 +192,10 @@ func (s *Service) init() error { // } // NewService initialises and returns a new DLNA service. -func NewService(txnManager models.TransactionManager, cfg Config, sceneServer sceneServer) *Service { +func NewService(txnManager txn.Manager, repo Repository, cfg Config, sceneServer sceneServer) *Service { ret := &Service{ txnManager: txnManager, + repository: repo, sceneServer: sceneServer, config: cfg, ipWhitelistMgr: &ipWhitelistManager{ diff --git a/internal/identify/identify.go b/internal/identify/identify.go index 4d0d0afa2..0c34cce96 100644 --- a/internal/identify/identify.go +++ b/internal/identify/identify.go @@ -9,6 +9,7 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/scene" "github.com/stashapp/stash/pkg/scraper" + "github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/utils" ) @@ -28,13 +29,18 @@ type ScraperSource struct { } type SceneIdentifier struct { + SceneReaderUpdater SceneReaderUpdater + StudioCreator StudioCreator + PerformerCreator PerformerCreator + TagCreator TagCreator + DefaultOptions *MetadataOptions Sources []ScraperSource ScreenshotSetter scene.ScreenshotSetter SceneUpdatePostHookExecutor SceneUpdatePostHookExecutor } -func (t *SceneIdentifier) Identify(ctx context.Context, txnManager models.TransactionManager, scene *models.Scene) error { +func (t *SceneIdentifier) Identify(ctx context.Context, txnManager txn.Manager, scene *models.Scene) error { result, err := t.scrapeScene(ctx, scene) if err != nil { return err @@ -80,7 +86,7 @@ func (t *SceneIdentifier) scrapeScene(ctx context.Context, scene *models.Scene) return nil, nil } -func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene, result *scrapeResult, repo models.Repository) (*scene.UpdateSet, error) { +func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene, result *scrapeResult) (*scene.UpdateSet, error) { ret := &scene.UpdateSet{ ID: s.ID, } @@ -106,15 +112,18 @@ func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene, scraped := result.result rel := sceneRelationships{ - repo: repo, - scene: s, - result: result, - fieldOptions: fieldOptions, + sceneReader: t.SceneReaderUpdater, + studioCreator: t.StudioCreator, + performerCreator: t.PerformerCreator, + tagCreator: t.TagCreator, + scene: s, + result: result, + fieldOptions: fieldOptions, } ret.Partial = getScenePartial(s, scraped, fieldOptions, setOrganized) - studioID, err := rel.studio() + studioID, err := rel.studio(ctx) if err != nil { return nil, fmt.Errorf("error getting studio: %w", err) } @@ -134,17 +143,17 @@ func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene, } } - ret.PerformerIDs, err = rel.performers(ignoreMale) + ret.PerformerIDs, err = rel.performers(ctx, ignoreMale) if err != nil { return nil, err } - ret.TagIDs, err = rel.tags() + ret.TagIDs, err = rel.tags(ctx) if err != nil { return nil, err } - ret.StashIDs, err = rel.stashIDs() + ret.StashIDs, err = rel.stashIDs(ctx) if err != nil { return nil, err } @@ -167,11 +176,11 @@ func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene, return ret, nil } -func (t *SceneIdentifier) modifyScene(ctx context.Context, txnManager models.TransactionManager, s *models.Scene, result *scrapeResult) error { +func (t *SceneIdentifier) modifyScene(ctx context.Context, txnManager txn.Manager, s *models.Scene, result *scrapeResult) error { var updater *scene.UpdateSet - if err := txnManager.WithTxn(ctx, func(repo models.Repository) error { + if err := txn.WithTxn(ctx, txnManager, func(ctx context.Context) error { var err error - updater, err = t.getSceneUpdater(ctx, s, result, repo) + updater, err = t.getSceneUpdater(ctx, s, result) if err != nil { return err } @@ -182,7 +191,7 @@ func (t *SceneIdentifier) modifyScene(ctx context.Context, txnManager models.Tra return nil } - _, err = updater.Update(repo.Scene(), t.ScreenshotSetter) + _, err = updater.Update(ctx, t.SceneReaderUpdater, t.ScreenshotSetter) if err != nil { return fmt.Errorf("error updating scene: %w", err) } diff --git a/internal/identify/identify_test.go b/internal/identify/identify_test.go index c5588c78a..88be638df 100644 --- a/internal/identify/identify_test.go +++ b/internal/identify/identify_test.go @@ -13,6 +13,8 @@ import ( "github.com/stretchr/testify/mock" ) +var testCtx = context.Background() + type mockSceneScraper struct { errIDs []int results map[int]*scraper.ScrapedScene @@ -70,11 +72,12 @@ func TestSceneIdentifier_Identify(t *testing.T) { }, } - repo := mocks.NewTransactionManager() - repo.Scene().(*mocks.SceneReaderWriter).On("Update", mock.MatchedBy(func(partial models.ScenePartial) bool { + mockSceneReaderWriter := &mocks.SceneReaderWriter{} + + mockSceneReaderWriter.On("Update", testCtx, mock.MatchedBy(func(partial models.ScenePartial) bool { return partial.ID != errUpdateID })).Return(nil, nil) - repo.Scene().(*mocks.SceneReaderWriter).On("Update", mock.MatchedBy(func(partial models.ScenePartial) bool { + mockSceneReaderWriter.On("Update", testCtx, mock.MatchedBy(func(partial models.ScenePartial) bool { return partial.ID == errUpdateID })).Return(nil, errors.New("update error")) @@ -116,6 +119,7 @@ func TestSceneIdentifier_Identify(t *testing.T) { } identifier := SceneIdentifier{ + SceneReaderUpdater: mockSceneReaderWriter, DefaultOptions: defaultOptions, Sources: sources, SceneUpdatePostHookExecutor: mockHookExecutor{}, @@ -126,7 +130,7 @@ func TestSceneIdentifier_Identify(t *testing.T) { scene := &models.Scene{ ID: tt.sceneID, } - if err := identifier.Identify(context.TODO(), repo, scene); (err != nil) != tt.wantErr { + if err := identifier.Identify(testCtx, &mocks.TxnManager{}, scene); (err != nil) != tt.wantErr { t.Errorf("SceneIdentifier.Identify() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -134,7 +138,9 @@ func TestSceneIdentifier_Identify(t *testing.T) { } func TestSceneIdentifier_modifyScene(t *testing.T) { - repo := mocks.NewTransactionManager() + repo := models.Repository{ + TxnManager: &mocks.TxnManager{}, + } tr := &SceneIdentifier{} type args struct { @@ -159,7 +165,7 @@ func TestSceneIdentifier_modifyScene(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := tr.modifyScene(context.TODO(), repo, tt.args.scene, tt.args.result); (err != nil) != tt.wantErr { + if err := tr.modifyScene(testCtx, repo, tt.args.scene, tt.args.result); (err != nil) != tt.wantErr { t.Errorf("SceneIdentifier.modifyScene() error = %v, wantErr %v", err, tt.wantErr) } }) diff --git a/internal/identify/performer.go b/internal/identify/performer.go index 495c3eb8e..435524cc4 100644 --- a/internal/identify/performer.go +++ b/internal/identify/performer.go @@ -1,6 +1,7 @@ package identify import ( + "context" "database/sql" "fmt" "strconv" @@ -10,7 +11,12 @@ import ( "github.com/stashapp/stash/pkg/models" ) -func getPerformerID(endpoint string, r models.Repository, p *models.ScrapedPerformer, createMissing bool) (*int, error) { +type PerformerCreator interface { + Create(ctx context.Context, newPerformer models.Performer) (*models.Performer, error) + UpdateStashIDs(ctx context.Context, performerID int, stashIDs []models.StashID) error +} + +func getPerformerID(ctx context.Context, endpoint string, w PerformerCreator, p *models.ScrapedPerformer, createMissing bool) (*int, error) { if p.StoredID != nil { // existing performer, just add it performerID, err := strconv.Atoi(*p.StoredID) @@ -20,20 +26,20 @@ func getPerformerID(endpoint string, r models.Repository, p *models.ScrapedPerfo return &performerID, nil } else if createMissing && p.Name != nil { // name is mandatory - return createMissingPerformer(endpoint, r, p) + return createMissingPerformer(ctx, endpoint, w, p) } return nil, nil } -func createMissingPerformer(endpoint string, r models.Repository, p *models.ScrapedPerformer) (*int, error) { - created, err := r.Performer().Create(scrapedToPerformerInput(p)) +func createMissingPerformer(ctx context.Context, endpoint string, w PerformerCreator, p *models.ScrapedPerformer) (*int, error) { + created, err := w.Create(ctx, scrapedToPerformerInput(p)) if err != nil { return nil, fmt.Errorf("error creating performer: %w", err) } if endpoint != "" && p.RemoteSiteID != nil { - if err := r.Performer().UpdateStashIDs(created.ID, []models.StashID{ + if err := w.UpdateStashIDs(ctx, created.ID, []models.StashID{ { Endpoint: endpoint, StashID: *p.RemoteSiteID, diff --git a/internal/identify/performer_test.go b/internal/identify/performer_test.go index ebe8e49fe..eeed8a1e7 100644 --- a/internal/identify/performer_test.go +++ b/internal/identify/performer_test.go @@ -23,8 +23,8 @@ func Test_getPerformerID(t *testing.T) { validStoredID := 1 name := "name" - repo := mocks.NewTransactionManager() - repo.PerformerMock().On("Create", mock.Anything).Return(&models.Performer{ + mockPerformerReaderWriter := mocks.PerformerReaderWriter{} + mockPerformerReaderWriter.On("Create", testCtx, mock.Anything).Return(&models.Performer{ ID: validStoredID, }, nil) @@ -110,7 +110,7 @@ func Test_getPerformerID(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := getPerformerID(tt.args.endpoint, repo, tt.args.p, tt.args.createMissing) + got, err := getPerformerID(testCtx, tt.args.endpoint, &mockPerformerReaderWriter, tt.args.p, tt.args.createMissing) if (err != nil) != tt.wantErr { t.Errorf("getPerformerID() error = %v, wantErr %v", err, tt.wantErr) return @@ -131,23 +131,23 @@ func Test_createMissingPerformer(t *testing.T) { invalidName := "invalidName" performerID := 1 - repo := mocks.NewTransactionManager() - repo.PerformerMock().On("Create", mock.MatchedBy(func(p models.Performer) bool { + mockPerformerReaderWriter := mocks.PerformerReaderWriter{} + mockPerformerReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Performer) bool { return p.Name.String == validName })).Return(&models.Performer{ ID: performerID, }, nil) - repo.PerformerMock().On("Create", mock.MatchedBy(func(p models.Performer) bool { + mockPerformerReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Performer) bool { return p.Name.String == invalidName })).Return(nil, errors.New("error creating performer")) - repo.PerformerMock().On("UpdateStashIDs", performerID, []models.StashID{ + mockPerformerReaderWriter.On("UpdateStashIDs", testCtx, performerID, []models.StashID{ { Endpoint: invalidEndpoint, StashID: remoteSiteID, }, }).Return(errors.New("error updating stash ids")) - repo.PerformerMock().On("UpdateStashIDs", performerID, []models.StashID{ + mockPerformerReaderWriter.On("UpdateStashIDs", testCtx, performerID, []models.StashID{ { Endpoint: validEndpoint, StashID: remoteSiteID, @@ -213,7 +213,7 @@ func Test_createMissingPerformer(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := createMissingPerformer(tt.args.endpoint, repo, tt.args.p) + got, err := createMissingPerformer(testCtx, tt.args.endpoint, &mockPerformerReaderWriter, tt.args.p) if (err != nil) != tt.wantErr { t.Errorf("createMissingPerformer() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/internal/identify/scene.go b/internal/identify/scene.go index 3e6fd4a38..4e7f4d3cc 100644 --- a/internal/identify/scene.go +++ b/internal/identify/scene.go @@ -9,19 +9,35 @@ import ( "time" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/scene" "github.com/stashapp/stash/pkg/sliceutil" "github.com/stashapp/stash/pkg/sliceutil/intslice" "github.com/stashapp/stash/pkg/utils" ) -type sceneRelationships struct { - repo models.Repository - scene *models.Scene - result *scrapeResult - fieldOptions map[string]*FieldOptions +type SceneReaderUpdater interface { + GetPerformerIDs(ctx context.Context, sceneID int) ([]int, error) + GetTagIDs(ctx context.Context, sceneID int) ([]int, error) + GetStashIDs(ctx context.Context, sceneID int) ([]*models.StashID, error) + GetCover(ctx context.Context, sceneID int) ([]byte, error) + scene.Updater } -func (g sceneRelationships) studio() (*int64, error) { +type TagCreator interface { + Create(ctx context.Context, newTag models.Tag) (*models.Tag, error) +} + +type sceneRelationships struct { + sceneReader SceneReaderUpdater + studioCreator StudioCreator + performerCreator PerformerCreator + tagCreator TagCreator + scene *models.Scene + result *scrapeResult + fieldOptions map[string]*FieldOptions +} + +func (g sceneRelationships) studio(ctx context.Context) (*int64, error) { existingID := g.scene.StudioID fieldStrategy := g.fieldOptions["studio"] createMissing := fieldStrategy != nil && utils.IsTrue(fieldStrategy.CreateMissing) @@ -45,13 +61,13 @@ func (g sceneRelationships) studio() (*int64, error) { return &studioID, nil } } else if createMissing { - return createMissingStudio(endpoint, g.repo, scraped) + return createMissingStudio(ctx, endpoint, g.studioCreator, scraped) } return nil, nil } -func (g sceneRelationships) performers(ignoreMale bool) ([]int, error) { +func (g sceneRelationships) performers(ctx context.Context, ignoreMale bool) ([]int, error) { fieldStrategy := g.fieldOptions["performers"] scraped := g.result.result.Performers @@ -66,11 +82,10 @@ func (g sceneRelationships) performers(ignoreMale bool) ([]int, error) { strategy = fieldStrategy.Strategy } - repo := g.repo endpoint := g.result.source.RemoteSite var performerIDs []int - originalPerformerIDs, err := repo.Scene().GetPerformerIDs(g.scene.ID) + originalPerformerIDs, err := g.sceneReader.GetPerformerIDs(ctx, g.scene.ID) if err != nil { return nil, fmt.Errorf("error getting scene performers: %w", err) } @@ -85,7 +100,7 @@ func (g sceneRelationships) performers(ignoreMale bool) ([]int, error) { continue } - performerID, err := getPerformerID(endpoint, repo, p, createMissing) + performerID, err := getPerformerID(ctx, endpoint, g.performerCreator, p, createMissing) if err != nil { return nil, err } @@ -103,11 +118,10 @@ func (g sceneRelationships) performers(ignoreMale bool) ([]int, error) { return performerIDs, nil } -func (g sceneRelationships) tags() ([]int, error) { +func (g sceneRelationships) tags(ctx context.Context) ([]int, error) { fieldStrategy := g.fieldOptions["tags"] scraped := g.result.result.Tags target := g.scene - r := g.repo // just check if ignored if len(scraped) == 0 || !shouldSetSingleValueField(fieldStrategy, false) { @@ -121,7 +135,7 @@ func (g sceneRelationships) tags() ([]int, error) { } var tagIDs []int - originalTagIDs, err := r.Scene().GetTagIDs(target.ID) + originalTagIDs, err := g.sceneReader.GetTagIDs(ctx, target.ID) if err != nil { return nil, fmt.Errorf("error getting scene tags: %w", err) } @@ -142,7 +156,7 @@ func (g sceneRelationships) tags() ([]int, error) { tagIDs = intslice.IntAppendUnique(tagIDs, int(tagID)) } else if createMissing { now := time.Now() - created, err := r.Tag().Create(models.Tag{ + created, err := g.tagCreator.Create(ctx, models.Tag{ Name: t.Name, CreatedAt: models.SQLiteTimestamp{Timestamp: now}, UpdatedAt: models.SQLiteTimestamp{Timestamp: now}, @@ -163,11 +177,10 @@ func (g sceneRelationships) tags() ([]int, error) { return tagIDs, nil } -func (g sceneRelationships) stashIDs() ([]models.StashID, error) { +func (g sceneRelationships) stashIDs(ctx context.Context) ([]models.StashID, error) { remoteSiteID := g.result.result.RemoteSiteID fieldStrategy := g.fieldOptions["stash_ids"] target := g.scene - r := g.repo endpoint := g.result.source.RemoteSite @@ -183,7 +196,7 @@ func (g sceneRelationships) stashIDs() ([]models.StashID, error) { var originalStashIDs []models.StashID var stashIDs []models.StashID - stashIDPtrs, err := r.Scene().GetStashIDs(target.ID) + stashIDPtrs, err := g.sceneReader.GetStashIDs(ctx, target.ID) if err != nil { return nil, fmt.Errorf("error getting scene tag: %w", err) } @@ -227,14 +240,13 @@ func (g sceneRelationships) stashIDs() ([]models.StashID, error) { func (g sceneRelationships) cover(ctx context.Context) ([]byte, error) { scraped := g.result.result.Image - r := g.repo if scraped == nil { return nil, nil } // always overwrite if present - existingCover, err := r.Scene().GetCover(g.scene.ID) + existingCover, err := g.sceneReader.GetCover(ctx, g.scene.ID) if err != nil { return nil, fmt.Errorf("error getting scene cover: %w", err) } diff --git a/internal/identify/scene_test.go b/internal/identify/scene_test.go index 2487e6808..bdef0c864 100644 --- a/internal/identify/scene_test.go +++ b/internal/identify/scene_test.go @@ -24,14 +24,14 @@ func Test_sceneRelationships_studio(t *testing.T) { Strategy: FieldStrategyMerge, } - repo := mocks.NewTransactionManager() - repo.StudioMock().On("Create", mock.Anything).Return(&models.Studio{ + mockStudioReaderWriter := &mocks.StudioReaderWriter{} + mockStudioReaderWriter.On("Create", testCtx, mock.Anything).Return(&models.Studio{ ID: int(validStoredIDInt), }, nil) tr := sceneRelationships{ - repo: repo, - fieldOptions: make(map[string]*FieldOptions), + studioCreator: mockStudioReaderWriter, + fieldOptions: make(map[string]*FieldOptions), } tests := []struct { @@ -124,7 +124,7 @@ func Test_sceneRelationships_studio(t *testing.T) { }, } - got, err := tr.studio() + got, err := tr.studio(testCtx) if (err != nil) != tt.wantErr { t.Errorf("sceneRelationships.studio() error = %v, wantErr %v", err, tt.wantErr) return @@ -156,13 +156,13 @@ func Test_sceneRelationships_performers(t *testing.T) { Strategy: FieldStrategyMerge, } - repo := mocks.NewTransactionManager() - repo.SceneMock().On("GetPerformerIDs", sceneID).Return(nil, nil) - repo.SceneMock().On("GetPerformerIDs", sceneWithPerformerID).Return([]int{existingPerformerID}, nil) - repo.SceneMock().On("GetPerformerIDs", errSceneID).Return(nil, errors.New("error getting IDs")) + mockSceneReaderWriter := &mocks.SceneReaderWriter{} + mockSceneReaderWriter.On("GetPerformerIDs", testCtx, sceneID).Return(nil, nil) + mockSceneReaderWriter.On("GetPerformerIDs", testCtx, sceneWithPerformerID).Return([]int{existingPerformerID}, nil) + mockSceneReaderWriter.On("GetPerformerIDs", testCtx, errSceneID).Return(nil, errors.New("error getting IDs")) tr := sceneRelationships{ - repo: repo, + sceneReader: mockSceneReaderWriter, fieldOptions: make(map[string]*FieldOptions), } @@ -316,7 +316,7 @@ func Test_sceneRelationships_performers(t *testing.T) { }, } - got, err := tr.performers(tt.ignoreMale) + got, err := tr.performers(testCtx, tt.ignoreMale) if (err != nil) != tt.wantErr { t.Errorf("sceneRelationships.performers() error = %v, wantErr %v", err, tt.wantErr) return @@ -347,22 +347,24 @@ func Test_sceneRelationships_tags(t *testing.T) { Strategy: FieldStrategyMerge, } - repo := mocks.NewTransactionManager() - repo.SceneMock().On("GetTagIDs", sceneID).Return(nil, nil) - repo.SceneMock().On("GetTagIDs", sceneWithTagID).Return([]int{existingID}, nil) - repo.SceneMock().On("GetTagIDs", errSceneID).Return(nil, errors.New("error getting IDs")) + mockSceneReaderWriter := &mocks.SceneReaderWriter{} + mockTagReaderWriter := &mocks.TagReaderWriter{} + mockSceneReaderWriter.On("GetTagIDs", testCtx, sceneID).Return(nil, nil) + mockSceneReaderWriter.On("GetTagIDs", testCtx, sceneWithTagID).Return([]int{existingID}, nil) + mockSceneReaderWriter.On("GetTagIDs", testCtx, errSceneID).Return(nil, errors.New("error getting IDs")) - repo.TagMock().On("Create", mock.MatchedBy(func(p models.Tag) bool { + mockTagReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Tag) bool { return p.Name == validName })).Return(&models.Tag{ ID: validStoredIDInt, }, nil) - repo.TagMock().On("Create", mock.MatchedBy(func(p models.Tag) bool { + mockTagReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Tag) bool { return p.Name == invalidName })).Return(nil, errors.New("error creating tag")) tr := sceneRelationships{ - repo: repo, + sceneReader: mockSceneReaderWriter, + tagCreator: mockTagReaderWriter, fieldOptions: make(map[string]*FieldOptions), } @@ -505,7 +507,7 @@ func Test_sceneRelationships_tags(t *testing.T) { }, } - got, err := tr.tags() + got, err := tr.tags(testCtx) if (err != nil) != tt.wantErr { t.Errorf("sceneRelationships.tags() error = %v, wantErr %v", err, tt.wantErr) return @@ -534,18 +536,18 @@ func Test_sceneRelationships_stashIDs(t *testing.T) { Strategy: FieldStrategyMerge, } - repo := mocks.NewTransactionManager() - repo.SceneMock().On("GetStashIDs", sceneID).Return(nil, nil) - repo.SceneMock().On("GetStashIDs", sceneWithStashID).Return([]*models.StashID{ + mockSceneReaderWriter := &mocks.SceneReaderWriter{} + mockSceneReaderWriter.On("GetStashIDs", testCtx, sceneID).Return(nil, nil) + mockSceneReaderWriter.On("GetStashIDs", testCtx, sceneWithStashID).Return([]*models.StashID{ { StashID: remoteSiteID, Endpoint: existingEndpoint, }, }, nil) - repo.SceneMock().On("GetStashIDs", errSceneID).Return(nil, errors.New("error getting IDs")) + mockSceneReaderWriter.On("GetStashIDs", testCtx, errSceneID).Return(nil, errors.New("error getting IDs")) tr := sceneRelationships{ - repo: repo, + sceneReader: mockSceneReaderWriter, fieldOptions: make(map[string]*FieldOptions), } @@ -680,7 +682,7 @@ func Test_sceneRelationships_stashIDs(t *testing.T) { }, } - got, err := tr.stashIDs() + got, err := tr.stashIDs(testCtx) if (err != nil) != tt.wantErr { t.Errorf("sceneRelationships.stashIDs() error = %v, wantErr %v", err, tt.wantErr) return @@ -707,12 +709,12 @@ func Test_sceneRelationships_cover(t *testing.T) { newDataEncoded := base64Prefix + utils.GetBase64StringFromData(newData) invalidData := newDataEncoded + "!!!" - repo := mocks.NewTransactionManager() - repo.SceneMock().On("GetCover", sceneID).Return(existingData, nil) - repo.SceneMock().On("GetCover", errSceneID).Return(nil, errors.New("error getting cover")) + mockSceneReaderWriter := &mocks.SceneReaderWriter{} + mockSceneReaderWriter.On("GetCover", testCtx, sceneID).Return(existingData, nil) + mockSceneReaderWriter.On("GetCover", testCtx, errSceneID).Return(nil, errors.New("error getting cover")) tr := sceneRelationships{ - repo: repo, + sceneReader: mockSceneReaderWriter, fieldOptions: make(map[string]*FieldOptions), } diff --git a/internal/identify/studio.go b/internal/identify/studio.go index 86cb6b737..923a0322a 100644 --- a/internal/identify/studio.go +++ b/internal/identify/studio.go @@ -1,6 +1,7 @@ package identify import ( + "context" "database/sql" "fmt" "time" @@ -9,14 +10,19 @@ import ( "github.com/stashapp/stash/pkg/models" ) -func createMissingStudio(endpoint string, repo models.Repository, studio *models.ScrapedStudio) (*int64, error) { - created, err := repo.Studio().Create(scrapedToStudioInput(studio)) +type StudioCreator interface { + Create(ctx context.Context, newStudio models.Studio) (*models.Studio, error) + UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error +} + +func createMissingStudio(ctx context.Context, endpoint string, w StudioCreator, studio *models.ScrapedStudio) (*int64, error) { + created, err := w.Create(ctx, scrapedToStudioInput(studio)) if err != nil { return nil, fmt.Errorf("error creating studio: %w", err) } if endpoint != "" && studio.RemoteSiteID != nil { - if err := repo.Studio().UpdateStashIDs(created.ID, []models.StashID{ + if err := w.UpdateStashIDs(ctx, created.ID, []models.StashID{ { Endpoint: endpoint, StashID: *studio.RemoteSiteID, diff --git a/internal/identify/studio_test.go b/internal/identify/studio_test.go index 2ba0b840e..1900259ce 100644 --- a/internal/identify/studio_test.go +++ b/internal/identify/studio_test.go @@ -20,23 +20,24 @@ func Test_createMissingStudio(t *testing.T) { createdID := 1 createdID64 := int64(createdID) - repo := mocks.NewTransactionManager() - repo.StudioMock().On("Create", mock.MatchedBy(func(p models.Studio) bool { + repo := mocks.NewTxnRepository() + mockStudioReaderWriter := repo.Studio.(*mocks.StudioReaderWriter) + mockStudioReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Studio) bool { return p.Name.String == validName })).Return(&models.Studio{ ID: createdID, }, nil) - repo.StudioMock().On("Create", mock.MatchedBy(func(p models.Studio) bool { + mockStudioReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Studio) bool { return p.Name.String == invalidName })).Return(nil, errors.New("error creating performer")) - repo.StudioMock().On("UpdateStashIDs", createdID, []models.StashID{ + mockStudioReaderWriter.On("UpdateStashIDs", testCtx, createdID, []models.StashID{ { Endpoint: invalidEndpoint, StashID: remoteSiteID, }, }).Return(errors.New("error updating stash ids")) - repo.StudioMock().On("UpdateStashIDs", createdID, []models.StashID{ + mockStudioReaderWriter.On("UpdateStashIDs", testCtx, createdID, []models.StashID{ { Endpoint: validEndpoint, StashID: remoteSiteID, @@ -102,7 +103,7 @@ func Test_createMissingStudio(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := createMissingStudio(tt.args.endpoint, repo, tt.args.studio) + got, err := createMissingStudio(testCtx, tt.args.endpoint, mockStudioReaderWriter, tt.args.studio) if (err != nil) != tt.wantErr { t.Errorf("createMissingStudio() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/internal/manager/checksum.go b/internal/manager/checksum.go index 469f2c47f..53f368913 100644 --- a/internal/manager/checksum.go +++ b/internal/manager/checksum.go @@ -7,16 +7,21 @@ import ( "github.com/stashapp/stash/internal/manager/config" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/txn" ) -func setInitialMD5Config(ctx context.Context, txnManager models.TransactionManager) { +type SceneCounter interface { + Count(ctx context.Context) (int, error) +} + +func setInitialMD5Config(ctx context.Context, txnManager txn.Manager, counter SceneCounter) { // if there are no scene files in the database, then default the // VideoFileNamingAlgorithm config setting to oshash and calculateMD5 to // false, otherwise set them to true for backwards compatibility purposes var count int - if err := txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { + if err := txn.WithTxn(ctx, txnManager, func(ctx context.Context) error { var err error - count, err = r.Scene().Count() + count, err = counter.Count(ctx) return err }); err != nil { logger.Errorf("Error while counting scenes: %s", err.Error()) @@ -36,6 +41,11 @@ func setInitialMD5Config(ctx context.Context, txnManager models.TransactionManag } } +type SceneMissingHashCounter interface { + CountMissingChecksum(ctx context.Context) (int, error) + CountMissingOSHash(ctx context.Context) (int, error) +} + // ValidateVideoFileNamingAlgorithm validates changing the // VideoFileNamingAlgorithm configuration flag. // @@ -44,30 +54,27 @@ func setInitialMD5Config(ctx context.Context, txnManager models.TransactionManag // // Likewise, if VideoFileNamingAlgorithm is set to oshash, then this function // will ensure that all oshash values are set on all scenes. -func ValidateVideoFileNamingAlgorithm(txnManager models.TransactionManager, newValue models.HashAlgorithm) error { +func ValidateVideoFileNamingAlgorithm(ctx context.Context, qb SceneMissingHashCounter, newValue models.HashAlgorithm) error { // if algorithm is being set to MD5, then all checksums must be present - return txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { - qb := r.Scene() - if newValue == models.HashAlgorithmMd5 { - missingMD5, err := qb.CountMissingChecksum() - if err != nil { - return err - } - - if missingMD5 > 0 { - return errors.New("some checksums are missing on scenes. Run Scan with calculateMD5 set to true") - } - } else if newValue == models.HashAlgorithmOshash { - missingOSHash, err := qb.CountMissingOSHash() - if err != nil { - return err - } - - if missingOSHash > 0 { - return errors.New("some oshash values are missing on scenes. Run Scan to populate") - } + if newValue == models.HashAlgorithmMd5 { + missingMD5, err := qb.CountMissingChecksum(ctx) + if err != nil { + return err } - return nil - }) + if missingMD5 > 0 { + return errors.New("some checksums are missing on scenes. Run Scan with calculateMD5 set to true") + } + } else if newValue == models.HashAlgorithmOshash { + missingOSHash, err := qb.CountMissingOSHash(ctx) + if err != nil { + return err + } + + if missingOSHash > 0 { + return errors.New("some oshash values are missing on scenes. Run Scan to populate") + } + } + + return nil } diff --git a/internal/manager/filename_parser.go b/internal/manager/filename_parser.go index 3bd856e69..ecff4ea7a 100644 --- a/internal/manager/filename_parser.go +++ b/internal/manager/filename_parser.go @@ -1,6 +1,7 @@ package manager import ( + "context" "database/sql" "errors" "path/filepath" @@ -470,7 +471,15 @@ func (p *SceneFilenameParser) initWhiteSpaceRegex() { } } -func (p *SceneFilenameParser) Parse(repo models.ReaderRepository) ([]*SceneParserResult, int, error) { +type SceneFilenameParserRepository struct { + Scene scene.Queryer + Performer PerformerNamesFinder + Studio studio.Queryer + Movie MovieNameFinder + Tag tag.Queryer +} + +func (p *SceneFilenameParser) Parse(ctx context.Context, repo SceneFilenameParserRepository) ([]*SceneParserResult, int, error) { // perform the query to find the scenes mapper, err := newParseMapper(p.Pattern, p.ParserInput.IgnoreWords) @@ -492,17 +501,17 @@ func (p *SceneFilenameParser) Parse(repo models.ReaderRepository) ([]*SceneParse p.Filter.Q = nil - scenes, total, err := scene.QueryWithCount(repo.Scene(), sceneFilter, p.Filter) + scenes, total, err := scene.QueryWithCount(ctx, repo.Scene, sceneFilter, p.Filter) if err != nil { return nil, 0, err } - ret := p.parseScenes(repo, scenes, mapper) + ret := p.parseScenes(ctx, repo, scenes, mapper) return ret, total, nil } -func (p *SceneFilenameParser) parseScenes(repo models.ReaderRepository, scenes []*models.Scene, mapper *parseMapper) []*SceneParserResult { +func (p *SceneFilenameParser) parseScenes(ctx context.Context, repo SceneFilenameParserRepository, scenes []*models.Scene, mapper *parseMapper) []*SceneParserResult { var ret []*SceneParserResult for _, scene := range scenes { sceneHolder := mapper.parse(scene) @@ -511,7 +520,7 @@ func (p *SceneFilenameParser) parseScenes(repo models.ReaderRepository, scenes [ r := &SceneParserResult{ Scene: scene, } - p.setParserResult(repo, *sceneHolder, r) + p.setParserResult(ctx, repo, *sceneHolder, r) ret = append(ret, r) } @@ -530,7 +539,11 @@ func (p SceneFilenameParser) replaceWhitespaceCharacters(value string) string { return value } -func (p *SceneFilenameParser) queryPerformer(qb models.PerformerReader, performerName string) *models.Performer { +type PerformerNamesFinder interface { + FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Performer, error) +} + +func (p *SceneFilenameParser) queryPerformer(ctx context.Context, qb PerformerNamesFinder, performerName string) *models.Performer { // massage the performer name performerName = delimiterRE.ReplaceAllString(performerName, " ") @@ -540,7 +553,7 @@ func (p *SceneFilenameParser) queryPerformer(qb models.PerformerReader, performe } // perform an exact match and grab the first - performers, _ := qb.FindByNames([]string{performerName}, true) + performers, _ := qb.FindByNames(ctx, []string{performerName}, true) var ret *models.Performer if len(performers) > 0 { @@ -553,7 +566,7 @@ func (p *SceneFilenameParser) queryPerformer(qb models.PerformerReader, performe return ret } -func (p *SceneFilenameParser) queryStudio(qb models.StudioReader, studioName string) *models.Studio { +func (p *SceneFilenameParser) queryStudio(ctx context.Context, qb studio.Queryer, studioName string) *models.Studio { // massage the performer name studioName = delimiterRE.ReplaceAllString(studioName, " ") @@ -562,11 +575,11 @@ func (p *SceneFilenameParser) queryStudio(qb models.StudioReader, studioName str return ret } - ret, _ := studio.ByName(qb, studioName) + ret, _ := studio.ByName(ctx, qb, studioName) // try to match on alias if ret == nil { - ret, _ = studio.ByAlias(qb, studioName) + ret, _ = studio.ByAlias(ctx, qb, studioName) } // add result to cache @@ -575,7 +588,11 @@ func (p *SceneFilenameParser) queryStudio(qb models.StudioReader, studioName str return ret } -func (p *SceneFilenameParser) queryMovie(qb models.MovieReader, movieName string) *models.Movie { +type MovieNameFinder interface { + FindByName(ctx context.Context, name string, nocase bool) (*models.Movie, error) +} + +func (p *SceneFilenameParser) queryMovie(ctx context.Context, qb MovieNameFinder, movieName string) *models.Movie { // massage the movie name movieName = delimiterRE.ReplaceAllString(movieName, " ") @@ -584,7 +601,7 @@ func (p *SceneFilenameParser) queryMovie(qb models.MovieReader, movieName string return ret } - ret, _ := qb.FindByName(movieName, true) + ret, _ := qb.FindByName(ctx, movieName, true) // add result to cache p.movieCache[movieName] = ret @@ -592,7 +609,7 @@ func (p *SceneFilenameParser) queryMovie(qb models.MovieReader, movieName string return ret } -func (p *SceneFilenameParser) queryTag(qb models.TagReader, tagName string) *models.Tag { +func (p *SceneFilenameParser) queryTag(ctx context.Context, qb tag.Queryer, tagName string) *models.Tag { // massage the tag name tagName = delimiterRE.ReplaceAllString(tagName, " ") @@ -602,11 +619,11 @@ func (p *SceneFilenameParser) queryTag(qb models.TagReader, tagName string) *mod } // match tag name exactly - ret, _ := tag.ByName(qb, tagName) + ret, _ := tag.ByName(ctx, qb, tagName) // try to match on alias if ret == nil { - ret, _ = tag.ByAlias(qb, tagName) + ret, _ = tag.ByAlias(ctx, qb, tagName) } // add result to cache @@ -615,12 +632,12 @@ func (p *SceneFilenameParser) queryTag(qb models.TagReader, tagName string) *mod return ret } -func (p *SceneFilenameParser) setPerformers(qb models.PerformerReader, h sceneHolder, result *SceneParserResult) { +func (p *SceneFilenameParser) setPerformers(ctx context.Context, qb PerformerNamesFinder, h sceneHolder, result *SceneParserResult) { // query for each performer performersSet := make(map[int]bool) for _, performerName := range h.performers { if performerName != "" { - performer := p.queryPerformer(qb, performerName) + performer := p.queryPerformer(ctx, qb, performerName) if performer != nil { if _, found := performersSet[performer.ID]; !found { result.PerformerIds = append(result.PerformerIds, strconv.Itoa(performer.ID)) @@ -631,12 +648,12 @@ func (p *SceneFilenameParser) setPerformers(qb models.PerformerReader, h sceneHo } } -func (p *SceneFilenameParser) setTags(qb models.TagReader, h sceneHolder, result *SceneParserResult) { +func (p *SceneFilenameParser) setTags(ctx context.Context, qb tag.Queryer, h sceneHolder, result *SceneParserResult) { // query for each performer tagsSet := make(map[int]bool) for _, tagName := range h.tags { if tagName != "" { - tag := p.queryTag(qb, tagName) + tag := p.queryTag(ctx, qb, tagName) if tag != nil { if _, found := tagsSet[tag.ID]; !found { result.TagIds = append(result.TagIds, strconv.Itoa(tag.ID)) @@ -647,10 +664,10 @@ func (p *SceneFilenameParser) setTags(qb models.TagReader, h sceneHolder, result } } -func (p *SceneFilenameParser) setStudio(qb models.StudioReader, h sceneHolder, result *SceneParserResult) { +func (p *SceneFilenameParser) setStudio(ctx context.Context, qb studio.Queryer, h sceneHolder, result *SceneParserResult) { // query for each performer if h.studio != "" { - studio := p.queryStudio(qb, h.studio) + studio := p.queryStudio(ctx, qb, h.studio) if studio != nil { studioID := strconv.Itoa(studio.ID) result.StudioID = &studioID @@ -658,12 +675,12 @@ func (p *SceneFilenameParser) setStudio(qb models.StudioReader, h sceneHolder, r } } -func (p *SceneFilenameParser) setMovies(qb models.MovieReader, h sceneHolder, result *SceneParserResult) { +func (p *SceneFilenameParser) setMovies(ctx context.Context, qb MovieNameFinder, h sceneHolder, result *SceneParserResult) { // query for each movie moviesSet := make(map[int]bool) for _, movieName := range h.movies { if movieName != "" { - movie := p.queryMovie(qb, movieName) + movie := p.queryMovie(ctx, qb, movieName) if movie != nil { if _, found := moviesSet[movie.ID]; !found { result.Movies = append(result.Movies, &SceneMovieID{ @@ -676,7 +693,7 @@ func (p *SceneFilenameParser) setMovies(qb models.MovieReader, h sceneHolder, re } } -func (p *SceneFilenameParser) setParserResult(repo models.ReaderRepository, h sceneHolder, result *SceneParserResult) { +func (p *SceneFilenameParser) setParserResult(ctx context.Context, repo SceneFilenameParserRepository, h sceneHolder, result *SceneParserResult) { if h.result.Title.Valid { title := h.result.Title.String title = p.replaceWhitespaceCharacters(title) @@ -698,15 +715,15 @@ func (p *SceneFilenameParser) setParserResult(repo models.ReaderRepository, h sc } if len(h.performers) > 0 { - p.setPerformers(repo.Performer(), h, result) + p.setPerformers(ctx, repo.Performer, h, result) } if len(h.tags) > 0 { - p.setTags(repo.Tag(), h, result) + p.setTags(ctx, repo.Tag, h, result) } - p.setStudio(repo.Studio(), h, result) + p.setStudio(ctx, repo.Studio, h, result) if len(h.movies) > 0 { - p.setMovies(repo.Movie(), h, result) + p.setMovies(ctx, repo.Movie, h, result) } } diff --git a/internal/manager/import.go b/internal/manager/import.go index 8e3140577..0762096c2 100644 --- a/internal/manager/import.go +++ b/internal/manager/import.go @@ -1,6 +1,7 @@ package manager import ( + "context" "fmt" "io" "strconv" @@ -52,22 +53,22 @@ func (e ImportDuplicateEnum) MarshalGQL(w io.Writer) { } type importer interface { - PreImport() error - PostImport(id int) error + PreImport(ctx context.Context) error + PostImport(ctx context.Context, id int) error Name() string - FindExistingID() (*int, error) - Create() (*int, error) - Update(id int) error + FindExistingID(ctx context.Context) (*int, error) + Create(ctx context.Context) (*int, error) + Update(ctx context.Context, id int) error } -func performImport(i importer, duplicateBehaviour ImportDuplicateEnum) error { - if err := i.PreImport(); err != nil { +func performImport(ctx context.Context, i importer, duplicateBehaviour ImportDuplicateEnum) error { + if err := i.PreImport(ctx); err != nil { return err } // try to find an existing object with the same name name := i.Name() - existing, err := i.FindExistingID() + existing, err := i.FindExistingID(ctx) if err != nil { return fmt.Errorf("error finding existing objects: %v", err) } @@ -84,12 +85,12 @@ func performImport(i importer, duplicateBehaviour ImportDuplicateEnum) error { // must be overwriting id = *existing - if err := i.Update(id); err != nil { + if err := i.Update(ctx, id); err != nil { return fmt.Errorf("error updating existing object: %v", err) } } else { // creating - createdID, err := i.Create() + createdID, err := i.Create(ctx) if err != nil { return fmt.Errorf("error creating object: %v", err) } @@ -97,7 +98,7 @@ func performImport(i importer, duplicateBehaviour ImportDuplicateEnum) error { id = *createdID } - if err := i.PostImport(id); err != nil { + if err := i.PostImport(ctx, id); err != nil { return err } diff --git a/internal/manager/manager.go b/internal/manager/manager.go index 56221327b..4f292750e 100644 --- a/internal/manager/manager.go +++ b/internal/manager/manager.go @@ -17,7 +17,6 @@ import ( "github.com/stashapp/stash/internal/dlna" "github.com/stashapp/stash/internal/log" "github.com/stashapp/stash/internal/manager/config" - "github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/ffmpeg" "github.com/stashapp/stash/pkg/fsutil" "github.com/stashapp/stash/pkg/job" @@ -115,7 +114,8 @@ type Manager struct { DLNAService *dlna.Service - TxnManager models.TransactionManager + Database *sqlite.Database + Repository models.Repository scanSubs *subscriptionManager } @@ -150,6 +150,8 @@ func initialize() error { l := initLog() initProfiling(cfg.GetCPUProfilePath()) + db := &sqlite.Database{} + instance = &Manager{ Config: cfg, Logger: l, @@ -157,7 +159,20 @@ func initialize() error { DownloadStore: NewDownloadStore(), PluginCache: plugin.NewCache(cfg), - TxnManager: sqlite.NewTransactionManager(), + Database: db, + Repository: models.Repository{ + TxnManager: db, + Gallery: sqlite.GalleryReaderWriter, + Image: sqlite.ImageReaderWriter, + Movie: sqlite.MovieReaderWriter, + Performer: sqlite.PerformerReaderWriter, + Scene: sqlite.SceneReaderWriter, + SceneMarker: sqlite.SceneMarkerReaderWriter, + ScrapedItem: sqlite.ScrapedItemReaderWriter, + Studio: sqlite.StudioReaderWriter, + Tag: sqlite.TagReaderWriter, + SavedFilter: sqlite.SavedFilterReaderWriter, + }, scanSubs: &subscriptionManager{}, } @@ -165,9 +180,17 @@ func initialize() error { instance.JobManager = initJobManager() sceneServer := SceneServer{ - TXNManager: instance.TxnManager, + TxnManager: instance.Repository, + SceneCoverGetter: instance.Repository.Scene, } - instance.DLNAService = dlna.NewService(instance.TxnManager, instance.Config, &sceneServer) + + instance.DLNAService = dlna.NewService(instance.Repository, dlna.Repository{ + SceneFinder: instance.Repository.Scene, + StudioFinder: instance.Repository.Studio, + TagFinder: instance.Repository.Tag, + PerformerFinder: instance.Repository.Performer, + MovieFinder: instance.Repository.Movie, + }, instance.Config, &sceneServer) if !cfg.IsNewSystem() { logger.Infof("using config file: %s", cfg.GetConfigFile()) @@ -177,9 +200,14 @@ func initialize() error { } if err != nil { - return fmt.Errorf("error initializing configuration: %w", err) + panic(fmt.Sprintf("error initializing configuration: %s", err.Error())) } else if err := instance.PostInit(ctx); err != nil { - return err + var migrationNeededErr *sqlite.MigrationNeededError + if errors.As(err, &migrationNeededErr) { + logger.Warn(err.Error()) + } else { + panic(err) + } } initSecurity(cfg) @@ -352,7 +380,8 @@ func (s *Manager) PostInit(ctx context.Context) error { }) } - if err := database.Initialize(s.Config.GetDatabasePath()); err != nil { + database := s.Database + if err := database.Open(s.Config.GetDatabasePath()); err != nil { return err } @@ -377,7 +406,14 @@ func writeStashIcon() { // initScraperCache initializes a new scraper cache and returns it. func (s *Manager) initScraperCache() *scraper.Cache { - ret, err := scraper.NewCache(config.GetInstance(), s.TxnManager) + ret, err := scraper.NewCache(config.GetInstance(), s.Repository, scraper.Repository{ + SceneFinder: s.Repository.Scene, + GalleryFinder: s.Repository.Gallery, + TagFinder: s.Repository.Tag, + PerformerFinder: s.Repository.Performer, + MovieFinder: s.Repository.Movie, + StudioFinder: s.Repository.Studio, + }) if err != nil { logger.Errorf("Error reading scraper configs: %s", err.Error()) @@ -476,7 +512,12 @@ func (s *Manager) Setup(ctx context.Context, input SetupInput) error { // initialise the database if err := s.PostInit(ctx); err != nil { - return fmt.Errorf("error initializing the database: %v", err) + var migrationNeededErr *sqlite.MigrationNeededError + if errors.As(err, &migrationNeededErr) { + logger.Warn(err.Error()) + } else { + return fmt.Errorf("error initializing the database: %v", err) + } } s.Config.FinalizeSetup() @@ -501,6 +542,8 @@ type MigrateInput struct { } func (s *Manager) Migrate(ctx context.Context, input MigrateInput) error { + database := s.Database + // always backup so that we can roll back to the previous version if // migration fails backupPath := input.BackupPath @@ -509,7 +552,7 @@ func (s *Manager) Migrate(ctx context.Context, input MigrateInput) error { } // perform database backup - if err := database.Backup(database.DB, backupPath); err != nil { + if err := database.Backup(backupPath); err != nil { return fmt.Errorf("error backing up database: %s", err) } @@ -541,6 +584,7 @@ func (s *Manager) Migrate(ctx context.Context, input MigrateInput) error { } func (s *Manager) GetSystemStatus() *SystemStatus { + database := s.Database status := SystemStatusEnumOk dbSchema := int(database.Version()) dbPath := database.DatabasePath() @@ -569,7 +613,7 @@ func (s *Manager) Shutdown(code int) { // TODO: Each part of the manager needs to gracefully stop at some point // for now, we just close the database. - err := database.Close() + err := s.Database.Close() if err != nil { logger.Errorf("Error closing database: %s", err) if code == 0 { diff --git a/internal/manager/manager_tasks.go b/internal/manager/manager_tasks.go index 3178f7846..95f5c935f 100644 --- a/internal/manager/manager_tasks.go +++ b/internal/manager/manager_tasks.go @@ -84,7 +84,7 @@ func (s *Manager) Scan(ctx context.Context, input ScanMetadataInput) (int, error } scanJob := ScanJob{ - txnManager: s.TxnManager, + txnManager: s.Repository, input: input, subscriptions: s.scanSubs, } @@ -101,7 +101,7 @@ func (s *Manager) Import(ctx context.Context) (int, error) { j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) { task := ImportTask{ - txnManager: s.TxnManager, + txnManager: s.Repository, BaseDir: metadataPath, Reset: true, DuplicateBehaviour: ImportDuplicateEnumFail, @@ -125,7 +125,7 @@ func (s *Manager) Export(ctx context.Context) (int, error) { var wg sync.WaitGroup wg.Add(1) task := ExportTask{ - txnManager: s.TxnManager, + txnManager: s.Repository, full: true, fileNamingAlgorithm: config.GetVideoFileNamingAlgorithm(), } @@ -156,7 +156,7 @@ func (s *Manager) Generate(ctx context.Context, input GenerateMetadataInput) (in } j := &GenerateJob{ - txnManager: s.TxnManager, + txnManager: s.Repository, input: input, } @@ -185,9 +185,9 @@ func (s *Manager) generateScreenshot(ctx context.Context, sceneId string, at *fl } var scene *models.Scene - if err := s.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { + if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error { var err error - scene, err = r.Scene().Find(sceneIdInt) + scene, err = s.Repository.Scene.Find(ctx, sceneIdInt) return err }); err != nil || scene == nil { logger.Errorf("failed to get scene for generate: %s", err.Error()) @@ -195,7 +195,7 @@ func (s *Manager) generateScreenshot(ctx context.Context, sceneId string, at *fl } task := GenerateScreenshotTask{ - txnManager: s.TxnManager, + txnManager: s.Repository, Scene: *scene, ScreenshotAt: at, fileNamingAlgorithm: config.GetInstance().GetVideoFileNamingAlgorithm(), @@ -222,7 +222,7 @@ type AutoTagMetadataInput struct { func (s *Manager) AutoTag(ctx context.Context, input AutoTagMetadataInput) int { j := autoTagJob{ - txnManager: s.TxnManager, + txnManager: s.Repository, input: input, } @@ -237,7 +237,7 @@ type CleanMetadataInput struct { func (s *Manager) Clean(ctx context.Context, input CleanMetadataInput) int { j := cleanJob{ - txnManager: s.TxnManager, + txnManager: s.Repository, input: input, scanSubs: s.scanSubs, } @@ -251,9 +251,9 @@ func (s *Manager) MigrateHash(ctx context.Context) int { logger.Infof("Migrating generated files for %s naming hash", fileNamingAlgo.String()) var scenes []*models.Scene - if err := s.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { + if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error { var err error - scenes, err = r.Scene().All() + scenes, err = s.Repository.Scene.All(ctx) return err }); err != nil { logger.Errorf("failed to fetch list of scenes for migration: %s", err.Error()) @@ -327,15 +327,14 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB // This is why we mark this section nolint. In principle, we should look to // rewrite the section at some point, to avoid the linter warning. if len(input.PerformerIds) > 0 { //nolint:gocritic - if err := s.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - performerQuery := r.Performer() + if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error { + performerQuery := s.Repository.Performer for _, performerID := range input.PerformerIds { if id, err := strconv.Atoi(performerID); err == nil { - performer, err := performerQuery.Find(id) + performer, err := performerQuery.Find(ctx, id) if err == nil { tasks = append(tasks, StashBoxPerformerTagTask{ - txnManager: s.TxnManager, performer: performer, refresh: input.Refresh, box: box, @@ -354,7 +353,6 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB for i := range input.PerformerNames { if len(input.PerformerNames[i]) > 0 { tasks = append(tasks, StashBoxPerformerTagTask{ - txnManager: s.TxnManager, name: &input.PerformerNames[i], refresh: input.Refresh, box: box, @@ -367,14 +365,14 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB // However, this doesn't really help with readability of the current section. Mark it // as nolint for now. In the future we'd like to rewrite this code by factoring some of // this into separate functions. - if err := s.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - performerQuery := r.Performer() + if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error { + performerQuery := s.Repository.Performer var performers []*models.Performer var err error if input.Refresh { - performers, err = performerQuery.FindByStashIDStatus(true, box.Endpoint) + performers, err = performerQuery.FindByStashIDStatus(ctx, true, box.Endpoint) } else { - performers, err = performerQuery.FindByStashIDStatus(false, box.Endpoint) + performers, err = performerQuery.FindByStashIDStatus(ctx, false, box.Endpoint) } if err != nil { return fmt.Errorf("error querying performers: %v", err) @@ -382,7 +380,6 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB for _, performer := range performers { tasks = append(tasks, StashBoxPerformerTagTask{ - txnManager: s.TxnManager, performer: performer, refresh: input.Refresh, box: box, diff --git a/internal/manager/post_migrate.go b/internal/manager/post_migrate.go index 1db1aac40..acc93ae69 100644 --- a/internal/manager/post_migrate.go +++ b/internal/manager/post_migrate.go @@ -4,5 +4,5 @@ import "context" // PostMigrate is executed after migrations have been executed. func (s *Manager) PostMigrate(ctx context.Context) { - setInitialMD5Config(ctx, s.TxnManager) + setInitialMD5Config(ctx, s.Repository, s.Repository.Scene) } diff --git a/internal/manager/running_streams.go b/internal/manager/running_streams.go index 41c196462..9d43d26d2 100644 --- a/internal/manager/running_streams.go +++ b/internal/manager/running_streams.go @@ -8,6 +8,7 @@ import ( "github.com/stashapp/stash/pkg/fsutil" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/utils" ) @@ -49,8 +50,13 @@ func KillRunningStreams(scene *models.Scene, fileNamingAlgo models.HashAlgorithm instance.ReadLockManager.Cancel(transcodePath) } +type SceneCoverGetter interface { + GetCover(ctx context.Context, sceneID int) ([]byte, error) +} + type SceneServer struct { - TXNManager models.TransactionManager + TxnManager txn.Manager + SceneCoverGetter SceneCoverGetter } func (s *SceneServer) StreamSceneDirect(scene *models.Scene, w http.ResponseWriter, r *http.Request) { @@ -75,8 +81,8 @@ func (s *SceneServer) ServeScreenshot(scene *models.Scene, w http.ResponseWriter http.ServeFile(w, r, filepath) } else { var cover []byte - err := s.TXNManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { - cover, _ = repo.Scene().GetCover(scene.ID) + err := txn.WithTxn(r.Context(), s.TxnManager, func(ctx context.Context) error { + cover, _ = s.SceneCoverGetter.GetCover(ctx, scene.ID) return nil }) if err != nil { diff --git a/internal/manager/studio.go b/internal/manager/studio.go index 3b0d81ceb..6b517af6f 100644 --- a/internal/manager/studio.go +++ b/internal/manager/studio.go @@ -1,13 +1,15 @@ package manager import ( + "context" "errors" "fmt" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/studio" ) -func ValidateModifyStudio(studio models.StudioPartial, qb models.StudioReader) error { +func ValidateModifyStudio(ctx context.Context, studio models.StudioPartial, qb studio.Finder) error { if studio.ParentID == nil || !studio.ParentID.Valid { return nil } @@ -22,7 +24,7 @@ func ValidateModifyStudio(studio models.StudioPartial, qb models.StudioReader) e return errors.New("studio cannot be an ancestor of itself") } - currentStudio, err := qb.Find(int(currentParentID.Int64)) + currentStudio, err := qb.Find(ctx, int(currentParentID.Int64)) if err != nil { return fmt.Errorf("error finding parent studio: %v", err) } diff --git a/internal/manager/task_autotag.go b/internal/manager/task_autotag.go index a912dcedc..674fdfe64 100644 --- a/internal/manager/task_autotag.go +++ b/internal/manager/task_autotag.go @@ -19,7 +19,7 @@ import ( ) type autoTagJob struct { - txnManager models.TransactionManager + txnManager models.Repository input AutoTagMetadataInput cache match.Cache @@ -73,27 +73,28 @@ func (j *autoTagJob) autoTagSpecific(ctx context.Context, progress *job.Progress studioCount := len(studioIds) tagCount := len(tagIds) - if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - performerQuery := r.Performer() - studioQuery := r.Studio() - tagQuery := r.Tag() + if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error { + r := j.txnManager + performerQuery := r.Performer + studioQuery := r.Studio + tagQuery := r.Tag const wildcard = "*" var err error if performerCount == 1 && performerIds[0] == wildcard { - performerCount, err = performerQuery.Count() + performerCount, err = performerQuery.Count(ctx) if err != nil { return fmt.Errorf("error getting performer count: %v", err) } } if studioCount == 1 && studioIds[0] == wildcard { - studioCount, err = studioQuery.Count() + studioCount, err = studioQuery.Count(ctx) if err != nil { return fmt.Errorf("error getting studio count: %v", err) } } if tagCount == 1 && tagIds[0] == wildcard { - tagCount, err = tagQuery.Count() + tagCount, err = tagQuery.Count(ctx) if err != nil { return fmt.Errorf("error getting tag count: %v", err) } @@ -123,14 +124,14 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre for _, performerId := range performerIds { var performers []*models.Performer - if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - performerQuery := r.Performer() + if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error { + performerQuery := j.txnManager.Performer ignoreAutoTag := false perPage := -1 if performerId == "*" { var err error - performers, _, err = performerQuery.Query(&models.PerformerFilterType{ + performers, _, err = performerQuery.Query(ctx, &models.PerformerFilterType{ IgnoreAutoTag: &ignoreAutoTag, }, &models.FindFilterType{ PerPage: &perPage, @@ -144,7 +145,7 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre return fmt.Errorf("error parsing performer id %s: %s", performerId, err.Error()) } - performer, err := performerQuery.Find(performerIdInt) + performer, err := performerQuery.Find(ctx, performerIdInt) if err != nil { return fmt.Errorf("error finding performer id %s: %s", performerId, err.Error()) } @@ -161,14 +162,15 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre return nil } - if err := j.txnManager.WithTxn(ctx, func(r models.Repository) error { - if err := autotag.PerformerScenes(performer, paths, r.Scene(), &j.cache); err != nil { + if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error { + r := j.txnManager + if err := autotag.PerformerScenes(ctx, performer, paths, r.Scene, &j.cache); err != nil { return err } - if err := autotag.PerformerImages(performer, paths, r.Image(), &j.cache); err != nil { + if err := autotag.PerformerImages(ctx, performer, paths, r.Image, &j.cache); err != nil { return err } - if err := autotag.PerformerGalleries(performer, paths, r.Gallery(), &j.cache); err != nil { + if err := autotag.PerformerGalleries(ctx, performer, paths, r.Gallery, &j.cache); err != nil { return err } @@ -193,16 +195,18 @@ func (j *autoTagJob) autoTagStudios(ctx context.Context, progress *job.Progress, return } + r := j.txnManager + for _, studioId := range studioIds { var studios []*models.Studio - if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - studioQuery := r.Studio() + if err := r.WithTxn(ctx, func(ctx context.Context) error { + studioQuery := r.Studio ignoreAutoTag := false perPage := -1 if studioId == "*" { var err error - studios, _, err = studioQuery.Query(&models.StudioFilterType{ + studios, _, err = studioQuery.Query(ctx, &models.StudioFilterType{ IgnoreAutoTag: &ignoreAutoTag, }, &models.FindFilterType{ PerPage: &perPage, @@ -216,7 +220,7 @@ func (j *autoTagJob) autoTagStudios(ctx context.Context, progress *job.Progress, return fmt.Errorf("error parsing studio id %s: %s", studioId, err.Error()) } - studio, err := studioQuery.Find(studioIdInt) + studio, err := studioQuery.Find(ctx, studioIdInt) if err != nil { return fmt.Errorf("error finding studio id %s: %s", studioId, err.Error()) } @@ -234,19 +238,19 @@ func (j *autoTagJob) autoTagStudios(ctx context.Context, progress *job.Progress, return nil } - if err := j.txnManager.WithTxn(ctx, func(r models.Repository) error { - aliases, err := r.Studio().GetAliases(studio.ID) + if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error { + aliases, err := r.Studio.GetAliases(ctx, studio.ID) if err != nil { return err } - if err := autotag.StudioScenes(studio, paths, aliases, r.Scene(), &j.cache); err != nil { + if err := autotag.StudioScenes(ctx, studio, paths, aliases, r.Scene, &j.cache); err != nil { return err } - if err := autotag.StudioImages(studio, paths, aliases, r.Image(), &j.cache); err != nil { + if err := autotag.StudioImages(ctx, studio, paths, aliases, r.Image, &j.cache); err != nil { return err } - if err := autotag.StudioGalleries(studio, paths, aliases, r.Gallery(), &j.cache); err != nil { + if err := autotag.StudioGalleries(ctx, studio, paths, aliases, r.Gallery, &j.cache); err != nil { return err } @@ -271,15 +275,17 @@ func (j *autoTagJob) autoTagTags(ctx context.Context, progress *job.Progress, pa return } + r := j.txnManager + for _, tagId := range tagIds { var tags []*models.Tag - if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - tagQuery := r.Tag() + if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error { + tagQuery := r.Tag ignoreAutoTag := false perPage := -1 if tagId == "*" { var err error - tags, _, err = tagQuery.Query(&models.TagFilterType{ + tags, _, err = tagQuery.Query(ctx, &models.TagFilterType{ IgnoreAutoTag: &ignoreAutoTag, }, &models.FindFilterType{ PerPage: &perPage, @@ -293,7 +299,7 @@ func (j *autoTagJob) autoTagTags(ctx context.Context, progress *job.Progress, pa return fmt.Errorf("error parsing tag id %s: %s", tagId, err.Error()) } - tag, err := tagQuery.Find(tagIdInt) + tag, err := tagQuery.Find(ctx, tagIdInt) if err != nil { return fmt.Errorf("error finding tag id %s: %s", tagId, err.Error()) } @@ -306,19 +312,19 @@ func (j *autoTagJob) autoTagTags(ctx context.Context, progress *job.Progress, pa return nil } - if err := j.txnManager.WithTxn(ctx, func(r models.Repository) error { - aliases, err := r.Tag().GetAliases(tag.ID) + if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error { + aliases, err := r.Tag.GetAliases(ctx, tag.ID) if err != nil { return err } - if err := autotag.TagScenes(tag, paths, aliases, r.Scene(), &j.cache); err != nil { + if err := autotag.TagScenes(ctx, tag, paths, aliases, r.Scene, &j.cache); err != nil { return err } - if err := autotag.TagImages(tag, paths, aliases, r.Image(), &j.cache); err != nil { + if err := autotag.TagImages(ctx, tag, paths, aliases, r.Image, &j.cache); err != nil { return err } - if err := autotag.TagGalleries(tag, paths, aliases, r.Gallery(), &j.cache); err != nil { + if err := autotag.TagGalleries(ctx, tag, paths, aliases, r.Gallery, &j.cache); err != nil { return err } @@ -345,7 +351,7 @@ type autoTagFilesTask struct { tags bool progress *job.Progress - txnManager models.TransactionManager + txnManager models.Repository cache *match.Cache } @@ -425,13 +431,13 @@ func (t *autoTagFilesTask) makeGalleryFilter() *models.GalleryFilterType { return ret } -func (t *autoTagFilesTask) getCount(r models.ReaderRepository) (int, error) { +func (t *autoTagFilesTask) getCount(ctx context.Context, r models.Repository) (int, error) { pp := 0 findFilter := &models.FindFilterType{ PerPage: &pp, } - sceneResults, err := r.Scene().Query(models.SceneQueryOptions{ + sceneResults, err := r.Scene.Query(ctx, models.SceneQueryOptions{ QueryOptions: models.QueryOptions{ FindFilter: findFilter, Count: true, @@ -444,7 +450,7 @@ func (t *autoTagFilesTask) getCount(r models.ReaderRepository) (int, error) { sceneCount := sceneResults.Count - imageResults, err := r.Image().Query(models.ImageQueryOptions{ + imageResults, err := r.Image.Query(ctx, models.ImageQueryOptions{ QueryOptions: models.QueryOptions{ FindFilter: findFilter, Count: true, @@ -457,7 +463,7 @@ func (t *autoTagFilesTask) getCount(r models.ReaderRepository) (int, error) { imageCount := imageResults.Count - _, galleryCount, err := r.Gallery().Query(t.makeGalleryFilter(), findFilter) + _, galleryCount, err := r.Gallery.Query(ctx, t.makeGalleryFilter(), findFilter) if err != nil { return 0, err } @@ -465,7 +471,7 @@ func (t *autoTagFilesTask) getCount(r models.ReaderRepository) (int, error) { return sceneCount + imageCount + galleryCount, nil } -func (t *autoTagFilesTask) processScenes(ctx context.Context, r models.ReaderRepository) error { +func (t *autoTagFilesTask) processScenes(ctx context.Context, r models.Repository) error { if job.IsCancelled(ctx) { return nil } @@ -477,7 +483,7 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context, r models.ReaderRep more := true for more { - scenes, err := scene.Query(r.Scene(), sceneFilter, findFilter) + scenes, err := scene.Query(ctx, r.Scene, sceneFilter, findFilter) if err != nil { return err } @@ -518,7 +524,7 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context, r models.ReaderRep return nil } -func (t *autoTagFilesTask) processImages(ctx context.Context, r models.ReaderRepository) error { +func (t *autoTagFilesTask) processImages(ctx context.Context, r models.Repository) error { if job.IsCancelled(ctx) { return nil } @@ -530,7 +536,7 @@ func (t *autoTagFilesTask) processImages(ctx context.Context, r models.ReaderRep more := true for more { - images, err := image.Query(r.Image(), imageFilter, findFilter) + images, err := image.Query(ctx, r.Image, imageFilter, findFilter) if err != nil { return err } @@ -571,7 +577,7 @@ func (t *autoTagFilesTask) processImages(ctx context.Context, r models.ReaderRep return nil } -func (t *autoTagFilesTask) processGalleries(ctx context.Context, r models.ReaderRepository) error { +func (t *autoTagFilesTask) processGalleries(ctx context.Context, r models.Repository) error { if job.IsCancelled(ctx) { return nil } @@ -583,7 +589,7 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context, r models.Reader more := true for more { - galleries, _, err := r.Gallery().Query(galleryFilter, findFilter) + galleries, _, err := r.Gallery.Query(ctx, galleryFilter, findFilter) if err != nil { return err } @@ -625,8 +631,9 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context, r models.Reader } func (t *autoTagFilesTask) process(ctx context.Context) { - if err := t.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - total, err := t.getCount(r) + r := t.txnManager + if err := r.WithTxn(ctx, func(ctx context.Context) error { + total, err := t.getCount(ctx, t.txnManager) if err != nil { return err } @@ -661,7 +668,7 @@ func (t *autoTagFilesTask) process(ctx context.Context) { } type autoTagSceneTask struct { - txnManager models.TransactionManager + txnManager models.Repository scene *models.Scene performers bool @@ -673,19 +680,20 @@ type autoTagSceneTask struct { func (t *autoTagSceneTask) Start(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() - if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { + r := t.txnManager + if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { if t.performers { - if err := autotag.ScenePerformers(t.scene, r.Scene(), r.Performer(), t.cache); err != nil { + if err := autotag.ScenePerformers(ctx, t.scene, r.Scene, r.Performer, t.cache); err != nil { return fmt.Errorf("error tagging scene performers for %s: %v", t.scene.Path, err) } } if t.studios { - if err := autotag.SceneStudios(t.scene, r.Scene(), r.Studio(), t.cache); err != nil { + if err := autotag.SceneStudios(ctx, t.scene, r.Scene, r.Studio, t.cache); err != nil { return fmt.Errorf("error tagging scene studio for %s: %v", t.scene.Path, err) } } if t.tags { - if err := autotag.SceneTags(t.scene, r.Scene(), r.Tag(), t.cache); err != nil { + if err := autotag.SceneTags(ctx, t.scene, r.Scene, r.Tag, t.cache); err != nil { return fmt.Errorf("error tagging scene tags for %s: %v", t.scene.Path, err) } } @@ -697,7 +705,7 @@ func (t *autoTagSceneTask) Start(ctx context.Context, wg *sync.WaitGroup) { } type autoTagImageTask struct { - txnManager models.TransactionManager + txnManager models.Repository image *models.Image performers bool @@ -709,19 +717,20 @@ type autoTagImageTask struct { func (t *autoTagImageTask) Start(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() - if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { + r := t.txnManager + if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { if t.performers { - if err := autotag.ImagePerformers(t.image, r.Image(), r.Performer(), t.cache); err != nil { + if err := autotag.ImagePerformers(ctx, t.image, r.Image, r.Performer, t.cache); err != nil { return fmt.Errorf("error tagging image performers for %s: %v", t.image.Path, err) } } if t.studios { - if err := autotag.ImageStudios(t.image, r.Image(), r.Studio(), t.cache); err != nil { + if err := autotag.ImageStudios(ctx, t.image, r.Image, r.Studio, t.cache); err != nil { return fmt.Errorf("error tagging image studio for %s: %v", t.image.Path, err) } } if t.tags { - if err := autotag.ImageTags(t.image, r.Image(), r.Tag(), t.cache); err != nil { + if err := autotag.ImageTags(ctx, t.image, r.Image, r.Tag, t.cache); err != nil { return fmt.Errorf("error tagging image tags for %s: %v", t.image.Path, err) } } @@ -733,7 +742,7 @@ func (t *autoTagImageTask) Start(ctx context.Context, wg *sync.WaitGroup) { } type autoTagGalleryTask struct { - txnManager models.TransactionManager + txnManager models.Repository gallery *models.Gallery performers bool @@ -745,19 +754,20 @@ type autoTagGalleryTask struct { func (t *autoTagGalleryTask) Start(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() - if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { + r := t.txnManager + if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { if t.performers { - if err := autotag.GalleryPerformers(t.gallery, r.Gallery(), r.Performer(), t.cache); err != nil { + if err := autotag.GalleryPerformers(ctx, t.gallery, r.Gallery, r.Performer, t.cache); err != nil { return fmt.Errorf("error tagging gallery performers for %s: %v", t.gallery.Path.String, err) } } if t.studios { - if err := autotag.GalleryStudios(t.gallery, r.Gallery(), r.Studio(), t.cache); err != nil { + if err := autotag.GalleryStudios(ctx, t.gallery, r.Gallery, r.Studio, t.cache); err != nil { return fmt.Errorf("error tagging gallery studio for %s: %v", t.gallery.Path.String, err) } } if t.tags { - if err := autotag.GalleryTags(t.gallery, r.Gallery(), r.Tag(), t.cache); err != nil { + if err := autotag.GalleryTags(ctx, t.gallery, r.Gallery, r.Tag, t.cache); err != nil { return fmt.Errorf("error tagging gallery tags for %s: %v", t.gallery.Path.String, err) } } diff --git a/internal/manager/task_clean.go b/internal/manager/task_clean.go index 2cbd168fb..d165a9eba 100644 --- a/internal/manager/task_clean.go +++ b/internal/manager/task_clean.go @@ -18,7 +18,7 @@ import ( ) type cleanJob struct { - txnManager models.TransactionManager + txnManager models.Repository input CleanMetadataInput scanSubs *subscriptionManager } @@ -29,8 +29,10 @@ func (j *cleanJob) Execute(ctx context.Context, progress *job.Progress) { logger.Infof("Running in Dry Mode") } - if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - total, err := j.getCount(r) + r := j.txnManager + + if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error { + total, err := j.getCount(ctx, r) if err != nil { return fmt.Errorf("error getting count: %w", err) } @@ -41,13 +43,13 @@ func (j *cleanJob) Execute(ctx context.Context, progress *job.Progress) { return nil } - if err := j.processScenes(ctx, progress, r.Scene()); err != nil { + if err := j.processScenes(ctx, progress, r.Scene); err != nil { return fmt.Errorf("error cleaning scenes: %w", err) } - if err := j.processImages(ctx, progress, r.Image()); err != nil { + if err := j.processImages(ctx, progress, r.Image); err != nil { return fmt.Errorf("error cleaning images: %w", err) } - if err := j.processGalleries(ctx, progress, r.Gallery(), r.Image()); err != nil { + if err := j.processGalleries(ctx, progress, r.Gallery, r.Image); err != nil { return fmt.Errorf("error cleaning galleries: %w", err) } @@ -66,9 +68,9 @@ func (j *cleanJob) Execute(ctx context.Context, progress *job.Progress) { logger.Info("Finished Cleaning") } -func (j *cleanJob) getCount(r models.ReaderRepository) (int, error) { +func (j *cleanJob) getCount(ctx context.Context, r models.Repository) (int, error) { sceneFilter := scene.PathsFilter(j.input.Paths) - sceneResult, err := r.Scene().Query(models.SceneQueryOptions{ + sceneResult, err := r.Scene.Query(ctx, models.SceneQueryOptions{ QueryOptions: models.QueryOptions{ Count: true, }, @@ -78,12 +80,12 @@ func (j *cleanJob) getCount(r models.ReaderRepository) (int, error) { return 0, err } - imageCount, err := r.Image().QueryCount(image.PathsFilter(j.input.Paths), nil) + imageCount, err := r.Image.QueryCount(ctx, image.PathsFilter(j.input.Paths), nil) if err != nil { return 0, err } - galleryCount, err := r.Gallery().QueryCount(gallery.PathsFilter(j.input.Paths), nil) + galleryCount, err := r.Gallery.QueryCount(ctx, gallery.PathsFilter(j.input.Paths), nil) if err != nil { return 0, err } @@ -91,7 +93,7 @@ func (j *cleanJob) getCount(r models.ReaderRepository) (int, error) { return sceneResult.Count + imageCount + galleryCount, nil } -func (j *cleanJob) processScenes(ctx context.Context, progress *job.Progress, qb models.SceneReader) error { +func (j *cleanJob) processScenes(ctx context.Context, progress *job.Progress, qb scene.Queryer) error { batchSize := 1000 findFilter := models.BatchFindFilter(batchSize) @@ -107,7 +109,7 @@ func (j *cleanJob) processScenes(ctx context.Context, progress *job.Progress, qb return nil } - scenes, err := scene.Query(qb, sceneFilter, findFilter) + scenes, err := scene.Query(ctx, qb, sceneFilter, findFilter) if err != nil { return fmt.Errorf("error querying for scenes: %w", err) } @@ -154,7 +156,7 @@ func (j *cleanJob) processScenes(ctx context.Context, progress *job.Progress, qb return nil } -func (j *cleanJob) processGalleries(ctx context.Context, progress *job.Progress, qb models.GalleryReader, iqb models.ImageReader) error { +func (j *cleanJob) processGalleries(ctx context.Context, progress *job.Progress, qb gallery.Queryer, iqb models.ImageReader) error { batchSize := 1000 findFilter := models.BatchFindFilter(batchSize) @@ -170,14 +172,14 @@ func (j *cleanJob) processGalleries(ctx context.Context, progress *job.Progress, return nil } - galleries, _, err := qb.Query(galleryFilter, findFilter) + galleries, _, err := qb.Query(ctx, galleryFilter, findFilter) if err != nil { return fmt.Errorf("error querying for galleries: %w", err) } for _, gallery := range galleries { progress.ExecuteTask(fmt.Sprintf("Assessing gallery %s for clean", gallery.GetTitle()), func() { - if j.shouldCleanGallery(gallery, iqb) { + if j.shouldCleanGallery(ctx, gallery, iqb) { toDelete = append(toDelete, gallery.ID) } else { // increment progress, no further processing @@ -215,7 +217,7 @@ func (j *cleanJob) processGalleries(ctx context.Context, progress *job.Progress, return nil } -func (j *cleanJob) processImages(ctx context.Context, progress *job.Progress, qb models.ImageReader) error { +func (j *cleanJob) processImages(ctx context.Context, progress *job.Progress, qb image.Queryer) error { batchSize := 1000 findFilter := models.BatchFindFilter(batchSize) @@ -234,7 +236,7 @@ func (j *cleanJob) processImages(ctx context.Context, progress *job.Progress, qb return nil } - images, err := image.Query(qb, imageFilter, findFilter) + images, err := image.Query(ctx, qb, imageFilter, findFilter) if err != nil { return fmt.Errorf("error querying for images: %w", err) } @@ -318,7 +320,7 @@ func (j *cleanJob) shouldCleanScene(s *models.Scene) bool { return false } -func (j *cleanJob) shouldCleanGallery(g *models.Gallery, qb models.ImageReader) bool { +func (j *cleanJob) shouldCleanGallery(ctx context.Context, g *models.Gallery, qb models.ImageReader) bool { // never clean manually created galleries if !g.Path.Valid { return false @@ -348,7 +350,7 @@ func (j *cleanJob) shouldCleanGallery(g *models.Gallery, qb models.ImageReader) } } else { // folder-based - delete if it has no images - count, err := qb.CountByGalleryID(g.ID) + count, err := qb.CountByGalleryID(ctx, g.ID) if err != nil { logger.Warnf("Error trying to count gallery images for %q: %v", path, err) return false @@ -401,16 +403,17 @@ func (j *cleanJob) deleteScene(ctx context.Context, fileNamingAlgorithm models.H Paths: GetInstance().Paths, } var s *models.Scene - if err := j.txnManager.WithTxn(ctx, func(repo models.Repository) error { - qb := repo.Scene() + if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error { + repo := j.txnManager + qb := repo.Scene var err error - s, err = qb.Find(sceneID) + s, err = qb.Find(ctx, sceneID) if err != nil { return err } - return scene.Destroy(s, repo, fileDeleter, true, false) + return scene.Destroy(ctx, s, repo.Scene, repo.SceneMarker, fileDeleter, true, false) }); err != nil { fileDeleter.Rollback() @@ -431,16 +434,16 @@ func (j *cleanJob) deleteScene(ctx context.Context, fileNamingAlgorithm models.H func (j *cleanJob) deleteGallery(ctx context.Context, galleryID int) { var g *models.Gallery - if err := j.txnManager.WithTxn(ctx, func(repo models.Repository) error { - qb := repo.Gallery() + if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error { + qb := j.txnManager.Gallery var err error - g, err = qb.Find(galleryID) + g, err = qb.Find(ctx, galleryID) if err != nil { return err } - return qb.Destroy(galleryID) + return qb.Destroy(ctx, galleryID) }); err != nil { logger.Errorf("Error deleting gallery from database: %s", err.Error()) return @@ -459,11 +462,11 @@ func (j *cleanJob) deleteImage(ctx context.Context, imageID int) { } var i *models.Image - if err := j.txnManager.WithTxn(ctx, func(repo models.Repository) error { - qb := repo.Image() + if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error { + qb := j.txnManager.Image var err error - i, err = qb.Find(imageID) + i, err = qb.Find(ctx, imageID) if err != nil { return err } @@ -472,7 +475,7 @@ func (j *cleanJob) deleteImage(ctx context.Context, imageID int) { return fmt.Errorf("image not found: %d", imageID) } - return image.Destroy(i, qb, fileDeleter, true, false) + return image.Destroy(ctx, i, qb, fileDeleter, true, false) }); err != nil { fileDeleter.Rollback() diff --git a/internal/manager/task_export.go b/internal/manager/task_export.go index 512b7e42f..3219252cb 100644 --- a/internal/manager/task_export.go +++ b/internal/manager/task_export.go @@ -32,7 +32,7 @@ import ( ) type ExportTask struct { - txnManager models.TransactionManager + txnManager models.Repository full bool baseDir string @@ -100,7 +100,7 @@ func CreateExportTask(a models.HashAlgorithm, input ExportObjectsInput) *ExportT } return &ExportTask{ - txnManager: GetInstance().TxnManager, + txnManager: GetInstance().Repository, fileNamingAlgorithm: a, scenes: newExportSpec(input.Scenes), images: newExportSpec(input.Images), @@ -146,30 +146,32 @@ func (t *ExportTask) Start(ctx context.Context, wg *sync.WaitGroup) { paths.EnsureJSONDirs(t.baseDir) - txnErr := t.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { + txnErr := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { + r := t.txnManager + // include movie scenes and gallery images if !t.full { // only include movie scenes if includeDependencies is also set if !t.scenes.all && t.includeDependencies { - t.populateMovieScenes(r) + t.populateMovieScenes(ctx, r) } // always export gallery images if !t.images.all { - t.populateGalleryImages(r) + t.populateGalleryImages(ctx, r) } } - t.ExportScenes(workerCount, r) - t.ExportImages(workerCount, r) - t.ExportGalleries(workerCount, r) - t.ExportMovies(workerCount, r) - t.ExportPerformers(workerCount, r) - t.ExportStudios(workerCount, r) - t.ExportTags(workerCount, r) + t.ExportScenes(ctx, workerCount, r) + t.ExportImages(ctx, workerCount, r) + t.ExportGalleries(ctx, workerCount, r) + t.ExportMovies(ctx, workerCount, r) + t.ExportPerformers(ctx, workerCount, r) + t.ExportStudios(ctx, workerCount, r) + t.ExportTags(ctx, workerCount, r) if t.full { - t.ExportScrapedItems(r) + t.ExportScrapedItems(ctx, r) } return nil @@ -284,17 +286,17 @@ func (t *ExportTask) zipFile(fn, outDir string, z *zip.Writer) error { return nil } -func (t *ExportTask) populateMovieScenes(repo models.ReaderRepository) { - reader := repo.Movie() - sceneReader := repo.Scene() +func (t *ExportTask) populateMovieScenes(ctx context.Context, repo models.Repository) { + reader := repo.Movie + sceneReader := repo.Scene var movies []*models.Movie var err error all := t.full || (t.movies != nil && t.movies.all) if all { - movies, err = reader.All() + movies, err = reader.All(ctx) } else if t.movies != nil && len(t.movies.IDs) > 0 { - movies, err = reader.FindMany(t.movies.IDs) + movies, err = reader.FindMany(ctx, t.movies.IDs) } if err != nil { @@ -302,7 +304,7 @@ func (t *ExportTask) populateMovieScenes(repo models.ReaderRepository) { } for _, m := range movies { - scenes, err := sceneReader.FindByMovieID(m.ID) + scenes, err := sceneReader.FindByMovieID(ctx, m.ID) if err != nil { logger.Errorf("[movies] <%s> failed to fetch scenes for movie: %s", m.Checksum, err.Error()) continue @@ -314,17 +316,17 @@ func (t *ExportTask) populateMovieScenes(repo models.ReaderRepository) { } } -func (t *ExportTask) populateGalleryImages(repo models.ReaderRepository) { - reader := repo.Gallery() - imageReader := repo.Image() +func (t *ExportTask) populateGalleryImages(ctx context.Context, repo models.Repository) { + reader := repo.Gallery + imageReader := repo.Image var galleries []*models.Gallery var err error all := t.full || (t.galleries != nil && t.galleries.all) if all { - galleries, err = reader.All() + galleries, err = reader.All(ctx) } else if t.galleries != nil && len(t.galleries.IDs) > 0 { - galleries, err = reader.FindMany(t.galleries.IDs) + galleries, err = reader.FindMany(ctx, t.galleries.IDs) } if err != nil { @@ -332,7 +334,7 @@ func (t *ExportTask) populateGalleryImages(repo models.ReaderRepository) { } for _, g := range galleries { - images, err := imageReader.FindByGalleryID(g.ID) + images, err := imageReader.FindByGalleryID(ctx, g.ID) if err != nil { logger.Errorf("[galleries] <%s> failed to fetch images for gallery: %s", g.Checksum, err.Error()) continue @@ -344,18 +346,18 @@ func (t *ExportTask) populateGalleryImages(repo models.ReaderRepository) { } } -func (t *ExportTask) ExportScenes(workers int, repo models.ReaderRepository) { +func (t *ExportTask) ExportScenes(ctx context.Context, workers int, repo models.Repository) { var scenesWg sync.WaitGroup - sceneReader := repo.Scene() + sceneReader := repo.Scene var scenes []*models.Scene var err error all := t.full || (t.scenes != nil && t.scenes.all) if all { - scenes, err = sceneReader.All() + scenes, err = sceneReader.All(ctx) } else if t.scenes != nil && len(t.scenes.IDs) > 0 { - scenes, err = sceneReader.FindMany(t.scenes.IDs) + scenes, err = sceneReader.FindMany(ctx, t.scenes.IDs) } if err != nil { @@ -369,7 +371,7 @@ func (t *ExportTask) ExportScenes(workers int, repo models.ReaderRepository) { for w := 0; w < workers; w++ { // create export Scene workers scenesWg.Add(1) - go exportScene(&scenesWg, jobCh, repo, t) + go exportScene(ctx, &scenesWg, jobCh, repo, t) } for i, scene := range scenes { @@ -388,32 +390,32 @@ func (t *ExportTask) ExportScenes(workers int, repo models.ReaderRepository) { logger.Infof("[scenes] export complete in %s. %d workers used.", time.Since(startTime), workers) } -func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.ReaderRepository, t *ExportTask) { +func exportScene(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.Repository, t *ExportTask) { defer wg.Done() - sceneReader := repo.Scene() - studioReader := repo.Studio() - movieReader := repo.Movie() - galleryReader := repo.Gallery() - performerReader := repo.Performer() - tagReader := repo.Tag() - sceneMarkerReader := repo.SceneMarker() + sceneReader := repo.Scene + studioReader := repo.Studio + movieReader := repo.Movie + galleryReader := repo.Gallery + performerReader := repo.Performer + tagReader := repo.Tag + sceneMarkerReader := repo.SceneMarker for s := range jobChan { sceneHash := s.GetHash(t.fileNamingAlgorithm) - newSceneJSON, err := scene.ToBasicJSON(sceneReader, s) + newSceneJSON, err := scene.ToBasicJSON(ctx, sceneReader, s) if err != nil { logger.Errorf("[scenes] <%s> error getting scene JSON: %s", sceneHash, err.Error()) continue } - newSceneJSON.Studio, err = scene.GetStudioName(studioReader, s) + newSceneJSON.Studio, err = scene.GetStudioName(ctx, studioReader, s) if err != nil { logger.Errorf("[scenes] <%s> error getting scene studio name: %s", sceneHash, err.Error()) continue } - galleries, err := galleryReader.FindBySceneID(s.ID) + galleries, err := galleryReader.FindBySceneID(ctx, s.ID) if err != nil { logger.Errorf("[scenes] <%s> error getting scene gallery checksums: %s", sceneHash, err.Error()) continue @@ -421,7 +423,7 @@ func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.R newSceneJSON.Galleries = gallery.GetChecksums(galleries) - performers, err := performerReader.FindBySceneID(s.ID) + performers, err := performerReader.FindBySceneID(ctx, s.ID) if err != nil { logger.Errorf("[scenes] <%s> error getting scene performer names: %s", sceneHash, err.Error()) continue @@ -429,19 +431,19 @@ func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.R newSceneJSON.Performers = performer.GetNames(performers) - newSceneJSON.Tags, err = scene.GetTagNames(tagReader, s) + newSceneJSON.Tags, err = scene.GetTagNames(ctx, tagReader, s) if err != nil { logger.Errorf("[scenes] <%s> error getting scene tag names: %s", sceneHash, err.Error()) continue } - newSceneJSON.Markers, err = scene.GetSceneMarkersJSON(sceneMarkerReader, tagReader, s) + newSceneJSON.Markers, err = scene.GetSceneMarkersJSON(ctx, sceneMarkerReader, tagReader, s) if err != nil { logger.Errorf("[scenes] <%s> error getting scene markers JSON: %s", sceneHash, err.Error()) continue } - newSceneJSON.Movies, err = scene.GetSceneMoviesJSON(movieReader, sceneReader, s) + newSceneJSON.Movies, err = scene.GetSceneMoviesJSON(ctx, movieReader, sceneReader, s) if err != nil { logger.Errorf("[scenes] <%s> error getting scene movies JSON: %s", sceneHash, err.Error()) continue @@ -454,14 +456,14 @@ func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.R t.galleries.IDs = intslice.IntAppendUniques(t.galleries.IDs, gallery.GetIDs(galleries)) - tagIDs, err := scene.GetDependentTagIDs(tagReader, sceneMarkerReader, s) + tagIDs, err := scene.GetDependentTagIDs(ctx, tagReader, sceneMarkerReader, s) if err != nil { logger.Errorf("[scenes] <%s> error getting scene tags: %s", sceneHash, err.Error()) continue } t.tags.IDs = intslice.IntAppendUniques(t.tags.IDs, tagIDs) - movieIDs, err := scene.GetDependentMovieIDs(sceneReader, s) + movieIDs, err := scene.GetDependentMovieIDs(ctx, sceneReader, s) if err != nil { logger.Errorf("[scenes] <%s> error getting scene movies: %s", sceneHash, err.Error()) continue @@ -482,18 +484,18 @@ func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.R } } -func (t *ExportTask) ExportImages(workers int, repo models.ReaderRepository) { +func (t *ExportTask) ExportImages(ctx context.Context, workers int, repo models.Repository) { var imagesWg sync.WaitGroup - imageReader := repo.Image() + imageReader := repo.Image var images []*models.Image var err error all := t.full || (t.images != nil && t.images.all) if all { - images, err = imageReader.All() + images, err = imageReader.All(ctx) } else if t.images != nil && len(t.images.IDs) > 0 { - images, err = imageReader.FindMany(t.images.IDs) + images, err = imageReader.FindMany(ctx, t.images.IDs) } if err != nil { @@ -507,7 +509,7 @@ func (t *ExportTask) ExportImages(workers int, repo models.ReaderRepository) { for w := 0; w < workers; w++ { // create export Image workers imagesWg.Add(1) - go exportImage(&imagesWg, jobCh, repo, t) + go exportImage(ctx, &imagesWg, jobCh, repo, t) } for i, image := range images { @@ -526,12 +528,12 @@ func (t *ExportTask) ExportImages(workers int, repo models.ReaderRepository) { logger.Infof("[images] export complete in %s. %d workers used.", time.Since(startTime), workers) } -func exportImage(wg *sync.WaitGroup, jobChan <-chan *models.Image, repo models.ReaderRepository, t *ExportTask) { +func exportImage(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Image, repo models.Repository, t *ExportTask) { defer wg.Done() - studioReader := repo.Studio() - galleryReader := repo.Gallery() - performerReader := repo.Performer() - tagReader := repo.Tag() + studioReader := repo.Studio + galleryReader := repo.Gallery + performerReader := repo.Performer + tagReader := repo.Tag for s := range jobChan { imageHash := s.Checksum @@ -539,13 +541,13 @@ func exportImage(wg *sync.WaitGroup, jobChan <-chan *models.Image, repo models.R newImageJSON := image.ToBasicJSON(s) var err error - newImageJSON.Studio, err = image.GetStudioName(studioReader, s) + newImageJSON.Studio, err = image.GetStudioName(ctx, studioReader, s) if err != nil { logger.Errorf("[images] <%s> error getting image studio name: %s", imageHash, err.Error()) continue } - imageGalleries, err := galleryReader.FindByImageID(s.ID) + imageGalleries, err := galleryReader.FindByImageID(ctx, s.ID) if err != nil { logger.Errorf("[images] <%s> error getting image galleries: %s", imageHash, err.Error()) continue @@ -553,7 +555,7 @@ func exportImage(wg *sync.WaitGroup, jobChan <-chan *models.Image, repo models.R newImageJSON.Galleries = t.getGalleryChecksums(imageGalleries) - performers, err := performerReader.FindByImageID(s.ID) + performers, err := performerReader.FindByImageID(ctx, s.ID) if err != nil { logger.Errorf("[images] <%s> error getting image performer names: %s", imageHash, err.Error()) continue @@ -561,7 +563,7 @@ func exportImage(wg *sync.WaitGroup, jobChan <-chan *models.Image, repo models.R newImageJSON.Performers = performer.GetNames(performers) - tags, err := tagReader.FindByImageID(s.ID) + tags, err := tagReader.FindByImageID(ctx, s.ID) if err != nil { logger.Errorf("[images] <%s> error getting image tag names: %s", imageHash, err.Error()) continue @@ -597,18 +599,18 @@ func (t *ExportTask) getGalleryChecksums(galleries []*models.Gallery) (ret []str return } -func (t *ExportTask) ExportGalleries(workers int, repo models.ReaderRepository) { +func (t *ExportTask) ExportGalleries(ctx context.Context, workers int, repo models.Repository) { var galleriesWg sync.WaitGroup - reader := repo.Gallery() + reader := repo.Gallery var galleries []*models.Gallery var err error all := t.full || (t.galleries != nil && t.galleries.all) if all { - galleries, err = reader.All() + galleries, err = reader.All(ctx) } else if t.galleries != nil && len(t.galleries.IDs) > 0 { - galleries, err = reader.FindMany(t.galleries.IDs) + galleries, err = reader.FindMany(ctx, t.galleries.IDs) } if err != nil { @@ -622,7 +624,7 @@ func (t *ExportTask) ExportGalleries(workers int, repo models.ReaderRepository) for w := 0; w < workers; w++ { // create export Scene workers galleriesWg.Add(1) - go exportGallery(&galleriesWg, jobCh, repo, t) + go exportGallery(ctx, &galleriesWg, jobCh, repo, t) } for i, gallery := range galleries { @@ -646,11 +648,11 @@ func (t *ExportTask) ExportGalleries(workers int, repo models.ReaderRepository) logger.Infof("[galleries] export complete in %s. %d workers used.", time.Since(startTime), workers) } -func exportGallery(wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo models.ReaderRepository, t *ExportTask) { +func exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo models.Repository, t *ExportTask) { defer wg.Done() - studioReader := repo.Studio() - performerReader := repo.Performer() - tagReader := repo.Tag() + studioReader := repo.Studio + performerReader := repo.Performer + tagReader := repo.Tag for g := range jobChan { galleryHash := g.Checksum @@ -661,13 +663,13 @@ func exportGallery(wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo mode continue } - newGalleryJSON.Studio, err = gallery.GetStudioName(studioReader, g) + newGalleryJSON.Studio, err = gallery.GetStudioName(ctx, studioReader, g) if err != nil { logger.Errorf("[galleries] <%s> error getting gallery studio name: %s", galleryHash, err.Error()) continue } - performers, err := performerReader.FindByGalleryID(g.ID) + performers, err := performerReader.FindByGalleryID(ctx, g.ID) if err != nil { logger.Errorf("[galleries] <%s> error getting gallery performer names: %s", galleryHash, err.Error()) continue @@ -675,7 +677,7 @@ func exportGallery(wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo mode newGalleryJSON.Performers = performer.GetNames(performers) - tags, err := tagReader.FindByGalleryID(g.ID) + tags, err := tagReader.FindByGalleryID(ctx, g.ID) if err != nil { logger.Errorf("[galleries] <%s> error getting gallery tag names: %s", galleryHash, err.Error()) continue @@ -703,17 +705,17 @@ func exportGallery(wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo mode } } -func (t *ExportTask) ExportPerformers(workers int, repo models.ReaderRepository) { +func (t *ExportTask) ExportPerformers(ctx context.Context, workers int, repo models.Repository) { var performersWg sync.WaitGroup - reader := repo.Performer() + reader := repo.Performer var performers []*models.Performer var err error all := t.full || (t.performers != nil && t.performers.all) if all { - performers, err = reader.All() + performers, err = reader.All(ctx) } else if t.performers != nil && len(t.performers.IDs) > 0 { - performers, err = reader.FindMany(t.performers.IDs) + performers, err = reader.FindMany(ctx, t.performers.IDs) } if err != nil { @@ -726,7 +728,7 @@ func (t *ExportTask) ExportPerformers(workers int, repo models.ReaderRepository) for w := 0; w < workers; w++ { // create export Performer workers performersWg.Add(1) - go t.exportPerformer(&performersWg, jobCh, repo) + go t.exportPerformer(ctx, &performersWg, jobCh, repo) } for i, performer := range performers { @@ -743,20 +745,20 @@ func (t *ExportTask) ExportPerformers(workers int, repo models.ReaderRepository) logger.Infof("[performers] export complete in %s. %d workers used.", time.Since(startTime), workers) } -func (t *ExportTask) exportPerformer(wg *sync.WaitGroup, jobChan <-chan *models.Performer, repo models.ReaderRepository) { +func (t *ExportTask) exportPerformer(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Performer, repo models.Repository) { defer wg.Done() - performerReader := repo.Performer() + performerReader := repo.Performer for p := range jobChan { - newPerformerJSON, err := performer.ToJSON(performerReader, p) + newPerformerJSON, err := performer.ToJSON(ctx, performerReader, p) if err != nil { logger.Errorf("[performers] <%s> error getting performer JSON: %s", p.Checksum, err.Error()) continue } - tags, err := repo.Tag().FindByPerformerID(p.ID) + tags, err := repo.Tag.FindByPerformerID(ctx, p.ID) if err != nil { logger.Errorf("[performers] <%s> error getting performer tags: %s", p.Checksum, err.Error()) continue @@ -781,17 +783,17 @@ func (t *ExportTask) exportPerformer(wg *sync.WaitGroup, jobChan <-chan *models. } } -func (t *ExportTask) ExportStudios(workers int, repo models.ReaderRepository) { +func (t *ExportTask) ExportStudios(ctx context.Context, workers int, repo models.Repository) { var studiosWg sync.WaitGroup - reader := repo.Studio() + reader := repo.Studio var studios []*models.Studio var err error all := t.full || (t.studios != nil && t.studios.all) if all { - studios, err = reader.All() + studios, err = reader.All(ctx) } else if t.studios != nil && len(t.studios.IDs) > 0 { - studios, err = reader.FindMany(t.studios.IDs) + studios, err = reader.FindMany(ctx, t.studios.IDs) } if err != nil { @@ -805,7 +807,7 @@ func (t *ExportTask) ExportStudios(workers int, repo models.ReaderRepository) { for w := 0; w < workers; w++ { // create export Studio workers studiosWg.Add(1) - go t.exportStudio(&studiosWg, jobCh, repo) + go t.exportStudio(ctx, &studiosWg, jobCh, repo) } for i, studio := range studios { @@ -822,13 +824,13 @@ func (t *ExportTask) ExportStudios(workers int, repo models.ReaderRepository) { logger.Infof("[studios] export complete in %s. %d workers used.", time.Since(startTime), workers) } -func (t *ExportTask) exportStudio(wg *sync.WaitGroup, jobChan <-chan *models.Studio, repo models.ReaderRepository) { +func (t *ExportTask) exportStudio(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Studio, repo models.Repository) { defer wg.Done() - studioReader := repo.Studio() + studioReader := repo.Studio for s := range jobChan { - newStudioJSON, err := studio.ToJSON(studioReader, s) + newStudioJSON, err := studio.ToJSON(ctx, studioReader, s) if err != nil { logger.Errorf("[studios] <%s> error getting studio JSON: %s", s.Checksum, err.Error()) @@ -846,17 +848,17 @@ func (t *ExportTask) exportStudio(wg *sync.WaitGroup, jobChan <-chan *models.Stu } } -func (t *ExportTask) ExportTags(workers int, repo models.ReaderRepository) { +func (t *ExportTask) ExportTags(ctx context.Context, workers int, repo models.Repository) { var tagsWg sync.WaitGroup - reader := repo.Tag() + reader := repo.Tag var tags []*models.Tag var err error all := t.full || (t.tags != nil && t.tags.all) if all { - tags, err = reader.All() + tags, err = reader.All(ctx) } else if t.tags != nil && len(t.tags.IDs) > 0 { - tags, err = reader.FindMany(t.tags.IDs) + tags, err = reader.FindMany(ctx, t.tags.IDs) } if err != nil { @@ -870,7 +872,7 @@ func (t *ExportTask) ExportTags(workers int, repo models.ReaderRepository) { for w := 0; w < workers; w++ { // create export Tag workers tagsWg.Add(1) - go t.exportTag(&tagsWg, jobCh, repo) + go t.exportTag(ctx, &tagsWg, jobCh, repo) } for i, tag := range tags { @@ -890,13 +892,13 @@ func (t *ExportTask) ExportTags(workers int, repo models.ReaderRepository) { logger.Infof("[tags] export complete in %s. %d workers used.", time.Since(startTime), workers) } -func (t *ExportTask) exportTag(wg *sync.WaitGroup, jobChan <-chan *models.Tag, repo models.ReaderRepository) { +func (t *ExportTask) exportTag(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Tag, repo models.Repository) { defer wg.Done() - tagReader := repo.Tag() + tagReader := repo.Tag for thisTag := range jobChan { - newTagJSON, err := tag.ToJSON(tagReader, thisTag) + newTagJSON, err := tag.ToJSON(ctx, tagReader, thisTag) if err != nil { logger.Errorf("[tags] <%s> error getting tag JSON: %s", thisTag.Name, err.Error()) @@ -917,17 +919,17 @@ func (t *ExportTask) exportTag(wg *sync.WaitGroup, jobChan <-chan *models.Tag, r } } -func (t *ExportTask) ExportMovies(workers int, repo models.ReaderRepository) { +func (t *ExportTask) ExportMovies(ctx context.Context, workers int, repo models.Repository) { var moviesWg sync.WaitGroup - reader := repo.Movie() + reader := repo.Movie var movies []*models.Movie var err error all := t.full || (t.movies != nil && t.movies.all) if all { - movies, err = reader.All() + movies, err = reader.All(ctx) } else if t.movies != nil && len(t.movies.IDs) > 0 { - movies, err = reader.FindMany(t.movies.IDs) + movies, err = reader.FindMany(ctx, t.movies.IDs) } if err != nil { @@ -941,7 +943,7 @@ func (t *ExportTask) ExportMovies(workers int, repo models.ReaderRepository) { for w := 0; w < workers; w++ { // create export Studio workers moviesWg.Add(1) - go t.exportMovie(&moviesWg, jobCh, repo) + go t.exportMovie(ctx, &moviesWg, jobCh, repo) } for i, movie := range movies { @@ -958,14 +960,14 @@ func (t *ExportTask) ExportMovies(workers int, repo models.ReaderRepository) { logger.Infof("[movies] export complete in %s. %d workers used.", time.Since(startTime), workers) } -func (t *ExportTask) exportMovie(wg *sync.WaitGroup, jobChan <-chan *models.Movie, repo models.ReaderRepository) { +func (t *ExportTask) exportMovie(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Movie, repo models.Repository) { defer wg.Done() - movieReader := repo.Movie() - studioReader := repo.Studio() + movieReader := repo.Movie + studioReader := repo.Studio for m := range jobChan { - newMovieJSON, err := movie.ToJSON(movieReader, studioReader, m) + newMovieJSON, err := movie.ToJSON(ctx, movieReader, studioReader, m) if err != nil { logger.Errorf("[movies] <%s> error getting tag JSON: %s", m.Checksum, err.Error()) @@ -991,10 +993,10 @@ func (t *ExportTask) exportMovie(wg *sync.WaitGroup, jobChan <-chan *models.Movi } } -func (t *ExportTask) ExportScrapedItems(repo models.ReaderRepository) { - qb := repo.ScrapedItem() - sqb := repo.Studio() - scrapedItems, err := qb.All() +func (t *ExportTask) ExportScrapedItems(ctx context.Context, repo models.Repository) { + qb := repo.ScrapedItem + sqb := repo.Studio + scrapedItems, err := qb.All(ctx) if err != nil { logger.Errorf("[scraped sites] failed to fetch all items: %s", err.Error()) } @@ -1009,7 +1011,7 @@ func (t *ExportTask) ExportScrapedItems(repo models.ReaderRepository) { var studioName string if scrapedItem.StudioID.Valid { - studio, _ := sqb.Find(int(scrapedItem.StudioID.Int64)) + studio, _ := sqb.Find(ctx, int(scrapedItem.StudioID.Int64)) if studio != nil { studioName = studio.Name.String } diff --git a/internal/manager/task_generate.go b/internal/manager/task_generate.go index bc385ab22..a8f71f7c6 100644 --- a/internal/manager/task_generate.go +++ b/internal/manager/task_generate.go @@ -54,7 +54,7 @@ type GeneratePreviewOptionsInput struct { const generateQueueSize = 200000 type GenerateJob struct { - txnManager models.TransactionManager + txnManager models.Repository input GenerateMetadataInput overwrite bool @@ -110,20 +110,20 @@ func (j *GenerateJob) Execute(ctx context.Context, progress *job.Progress) { Overwrite: j.overwrite, } - if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - qb := r.Scene() + if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error { + qb := j.txnManager.Scene if len(j.input.SceneIDs) == 0 && len(j.input.MarkerIDs) == 0 { totals = j.queueTasks(ctx, g, queue) } else { if len(j.input.SceneIDs) > 0 { - scenes, err = qb.FindMany(sceneIDs) + scenes, err = qb.FindMany(ctx, sceneIDs) for _, s := range scenes { j.queueSceneJobs(ctx, g, s, queue, &totals) } } if len(j.input.MarkerIDs) > 0 { - markers, err = r.SceneMarker().FindMany(markerIDs) + markers, err = j.txnManager.SceneMarker.FindMany(ctx, markerIDs) if err != nil { return err } @@ -192,13 +192,13 @@ func (j *GenerateJob) queueTasks(ctx context.Context, g *generate.Generator, que findFilter := models.BatchFindFilter(batchSize) - if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { + if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error { for more := true; more; { if job.IsCancelled(ctx) { return context.Canceled } - scenes, err := scene.Query(r.Scene(), nil, findFilter) + scenes, err := scene.Query(ctx, j.txnManager.Scene, nil, findFilter) if err != nil { return err } diff --git a/internal/manager/task_generate_interactive_heatmap_speed.go b/internal/manager/task_generate_interactive_heatmap_speed.go index f6ca0a04e..f9a1e8360 100644 --- a/internal/manager/task_generate_interactive_heatmap_speed.go +++ b/internal/manager/task_generate_interactive_heatmap_speed.go @@ -15,7 +15,7 @@ type GenerateInteractiveHeatmapSpeedTask struct { Scene models.Scene Overwrite bool fileNamingAlgorithm models.HashAlgorithm - TxnManager models.TransactionManager + TxnManager models.Repository } func (t *GenerateInteractiveHeatmapSpeedTask) GetDescription() string { @@ -47,22 +47,22 @@ func (t *GenerateInteractiveHeatmapSpeedTask) Start(ctx context.Context) { var s *models.Scene - if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { + if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error { var err error - s, err = r.Scene().FindByPath(t.Scene.Path) + s, err = t.TxnManager.Scene.FindByPath(ctx, t.Scene.Path) return err }); err != nil { logger.Error(err.Error()) return } - if err := t.TxnManager.WithTxn(ctx, func(r models.Repository) error { - qb := r.Scene() + if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error { + qb := t.TxnManager.Scene scenePartial := models.ScenePartial{ ID: s.ID, InteractiveSpeed: &median, } - _, err := qb.Update(scenePartial) + _, err := qb.Update(ctx, scenePartial) return err }); err != nil { logger.Error(err.Error()) diff --git a/internal/manager/task_generate_markers.go b/internal/manager/task_generate_markers.go index 3ef53ddd0..59ddefe63 100644 --- a/internal/manager/task_generate_markers.go +++ b/internal/manager/task_generate_markers.go @@ -13,7 +13,7 @@ import ( ) type GenerateMarkersTask struct { - TxnManager models.TransactionManager + TxnManager models.Repository Scene *models.Scene Marker *models.SceneMarker Overwrite bool @@ -42,9 +42,9 @@ func (t *GenerateMarkersTask) Start(ctx context.Context) { if t.Marker != nil { var scene *models.Scene - if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { + if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error { var err error - scene, err = r.Scene().Find(int(t.Marker.SceneID.Int64)) + scene, err = t.TxnManager.Scene.Find(ctx, int(t.Marker.SceneID.Int64)) return err }); err != nil { logger.Errorf("error finding scene for marker: %s", err.Error()) @@ -69,9 +69,9 @@ func (t *GenerateMarkersTask) Start(ctx context.Context) { func (t *GenerateMarkersTask) generateSceneMarkers(ctx context.Context) { var sceneMarkers []*models.SceneMarker - if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { + if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error { var err error - sceneMarkers, err = r.SceneMarker().FindBySceneID(t.Scene.ID) + sceneMarkers, err = t.TxnManager.SceneMarker.FindBySceneID(ctx, t.Scene.ID) return err }); err != nil { logger.Errorf("error getting scene markers: %s", err.Error()) @@ -134,9 +134,9 @@ func (t *GenerateMarkersTask) generateMarker(videoFile *ffmpeg.VideoFile, scene func (t *GenerateMarkersTask) markersNeeded(ctx context.Context) int { markers := 0 var sceneMarkers []*models.SceneMarker - if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { + if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error { var err error - sceneMarkers, err = r.SceneMarker().FindBySceneID(t.Scene.ID) + sceneMarkers, err = t.TxnManager.SceneMarker.FindBySceneID(ctx, t.Scene.ID) return err }); err != nil { logger.Errorf("errror finding scene markers: %s", err.Error()) diff --git a/internal/manager/task_generate_phash.go b/internal/manager/task_generate_phash.go index 880bb7794..b4350cc8b 100644 --- a/internal/manager/task_generate_phash.go +++ b/internal/manager/task_generate_phash.go @@ -14,7 +14,7 @@ type GeneratePhashTask struct { Scene models.Scene Overwrite bool fileNamingAlgorithm models.HashAlgorithm - txnManager models.TransactionManager + txnManager models.Repository } func (t *GeneratePhashTask) GetDescription() string { @@ -40,14 +40,14 @@ func (t *GeneratePhashTask) Start(ctx context.Context) { return } - if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { - qb := r.Scene() + if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { + qb := t.txnManager.Scene hashValue := sql.NullInt64{Int64: int64(*hash), Valid: true} scenePartial := models.ScenePartial{ ID: t.Scene.ID, Phash: &hashValue, } - _, err := qb.Update(scenePartial) + _, err := qb.Update(ctx, scenePartial) return err }); err != nil { logger.Error(err.Error()) diff --git a/internal/manager/task_generate_screenshot.go b/internal/manager/task_generate_screenshot.go index 80ef9e40d..9b941de8e 100644 --- a/internal/manager/task_generate_screenshot.go +++ b/internal/manager/task_generate_screenshot.go @@ -17,7 +17,7 @@ type GenerateScreenshotTask struct { Scene models.Scene ScreenshotAt *float64 fileNamingAlgorithm models.HashAlgorithm - txnManager models.TransactionManager + txnManager models.Repository } func (t *GenerateScreenshotTask) Start(ctx context.Context) { @@ -74,8 +74,8 @@ func (t *GenerateScreenshotTask) Start(ctx context.Context) { return } - if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { - qb := r.Scene() + if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { + qb := t.txnManager.Scene updatedTime := time.Now() updatedScene := models.ScenePartial{ ID: t.Scene.ID, @@ -87,12 +87,12 @@ func (t *GenerateScreenshotTask) Start(ctx context.Context) { } // update the scene cover table - if err := qb.UpdateCover(t.Scene.ID, coverImageData); err != nil { + if err := qb.UpdateCover(ctx, t.Scene.ID, coverImageData); err != nil { return fmt.Errorf("error setting screenshot: %v", err) } // update the scene with the update date - _, err = qb.Update(updatedScene) + _, err = qb.Update(ctx, updatedScene) if err != nil { return fmt.Errorf("error updating scene: %v", err) } diff --git a/internal/manager/task_identify.go b/internal/manager/task_identify.go index 457c59dd1..beec6fca9 100644 --- a/internal/manager/task_identify.go +++ b/internal/manager/task_identify.go @@ -14,12 +14,12 @@ import ( "github.com/stashapp/stash/pkg/scraper" "github.com/stashapp/stash/pkg/scraper/stashbox" "github.com/stashapp/stash/pkg/sliceutil/stringslice" + "github.com/stashapp/stash/pkg/txn" ) var ErrInput = errors.New("invalid request input") type IdentifyJob struct { - txnManager models.TransactionManager postHookExecutor identify.SceneUpdatePostHookExecutor input identify.Options @@ -29,7 +29,6 @@ type IdentifyJob struct { func CreateIdentifyJob(input identify.Options) *IdentifyJob { return &IdentifyJob{ - txnManager: instance.TxnManager, postHookExecutor: instance.PluginCache, input: input, stashBoxes: instance.Config.GetStashBoxes(), @@ -52,9 +51,9 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) { // if scene ids provided, use those // otherwise, batch query for all scenes - ordering by path - if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { + if err := txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error { if len(j.input.SceneIDs) == 0 { - return j.identifyAllScenes(ctx, r, sources) + return j.identifyAllScenes(ctx, sources) } sceneIDs, err := stringslice.StringSliceToIntSlice(j.input.SceneIDs) @@ -70,7 +69,7 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) { // find the scene var err error - scene, err := r.Scene().Find(id) + scene, err := instance.Repository.Scene.Find(ctx, id) if err != nil { return fmt.Errorf("error finding scene with id %d: %w", id, err) } @@ -88,7 +87,7 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) { } } -func (j *IdentifyJob) identifyAllScenes(ctx context.Context, r models.ReaderRepository, sources []identify.ScraperSource) error { +func (j *IdentifyJob) identifyAllScenes(ctx context.Context, sources []identify.ScraperSource) error { // exclude organised organised := false sceneFilter := scene.FilterFromPaths(j.input.Paths) @@ -102,7 +101,7 @@ func (j *IdentifyJob) identifyAllScenes(ctx context.Context, r models.ReaderRepo // get the count pp := 0 findFilter.PerPage = &pp - countResult, err := r.Scene().Query(models.SceneQueryOptions{ + countResult, err := instance.Repository.Scene.Query(ctx, models.SceneQueryOptions{ QueryOptions: models.QueryOptions{ FindFilter: findFilter, Count: true, @@ -115,7 +114,7 @@ func (j *IdentifyJob) identifyAllScenes(ctx context.Context, r models.ReaderRepo j.progress.SetTotal(countResult.Count) - return scene.BatchProcess(ctx, r.Scene(), sceneFilter, findFilter, func(scene *models.Scene) error { + return scene.BatchProcess(ctx, instance.Repository.Scene, sceneFilter, findFilter, func(scene *models.Scene) error { if job.IsCancelled(ctx) { return nil } @@ -133,6 +132,11 @@ func (j *IdentifyJob) identifyScene(ctx context.Context, s *models.Scene, source var taskError error j.progress.ExecuteTask("Identifying "+s.Path, func() { task := identify.SceneIdentifier{ + SceneReaderUpdater: instance.Repository.Scene, + StudioCreator: instance.Repository.Studio, + PerformerCreator: instance.Repository.Performer, + TagCreator: instance.Repository.Tag, + DefaultOptions: j.input.Options, Sources: sources, ScreenshotSetter: &scene.PathsScreenshotSetter{ @@ -142,7 +146,7 @@ func (j *IdentifyJob) identifyScene(ctx context.Context, s *models.Scene, source SceneUpdatePostHookExecutor: j.postHookExecutor, } - taskError = task.Identify(ctx, j.txnManager, s) + taskError = task.Identify(ctx, instance.Repository, s) }) if taskError != nil { @@ -166,7 +170,12 @@ func (j *IdentifyJob) getSources() ([]identify.ScraperSource, error) { src = identify.ScraperSource{ Name: "stash-box: " + stashBox.Endpoint, Scraper: stashboxSource{ - stashbox.NewClient(*stashBox, j.txnManager), + stashbox.NewClient(*stashBox, instance.Repository, stashbox.Repository{ + Scene: instance.Repository.Scene, + Performer: instance.Repository.Performer, + Tag: instance.Repository.Tag, + Studio: instance.Repository.Studio, + }), stashBox.Endpoint, }, RemoteSite: stashBox.Endpoint, diff --git a/internal/manager/task_import.go b/internal/manager/task_import.go index 8175bfc59..dccc98354 100644 --- a/internal/manager/task_import.go +++ b/internal/manager/task_import.go @@ -12,8 +12,6 @@ import ( "time" "github.com/99designs/gqlgen/graphql" - "github.com/stashapp/stash/internal/manager/config" - "github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/fsutil" "github.com/stashapp/stash/pkg/gallery" "github.com/stashapp/stash/pkg/image" @@ -30,7 +28,7 @@ import ( ) type ImportTask struct { - txnManager models.TransactionManager + txnManager models.Repository json jsonUtils BaseDir string @@ -73,7 +71,7 @@ func CreateImportTask(a models.HashAlgorithm, input ImportObjectsInput) (*Import } return &ImportTask{ - txnManager: GetInstance().TxnManager, + txnManager: GetInstance().Repository, BaseDir: baseDir, TmpZip: tmpZip, Reset: false, @@ -126,7 +124,7 @@ func (t *ImportTask) Start(ctx context.Context) { t.scraped = scraped if t.Reset { - err := database.Reset(config.GetInstance().GetDatabasePath()) + err := t.txnManager.Reset() if err != nil { logger.Errorf("Error resetting database: %s", err.Error()) @@ -211,15 +209,16 @@ func (t *ImportTask) ImportPerformers(ctx context.Context) { logger.Progressf("[performers] %d of %d", index, len(t.mappings.Performers)) - if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { - readerWriter := r.Performer() + if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { + r := t.txnManager + readerWriter := r.Performer importer := &performer.Importer{ ReaderWriter: readerWriter, - TagWriter: r.Tag(), + TagWriter: r.Tag, Input: *performerJSON, } - return performImport(importer, t.DuplicateBehaviour) + return performImport(ctx, importer, t.DuplicateBehaviour) }); err != nil { logger.Errorf("[performers] <%s> import failed: %s", mappingJSON.Checksum, err.Error()) } @@ -243,8 +242,8 @@ func (t *ImportTask) ImportStudios(ctx context.Context) { logger.Progressf("[studios] %d of %d", index, len(t.mappings.Studios)) - if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { - return t.ImportStudio(studioJSON, pendingParent, r.Studio()) + if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { + return t.ImportStudio(ctx, studioJSON, pendingParent, t.txnManager.Studio) }); err != nil { if errors.Is(err, studio.ErrParentStudioNotExist) { // add to the pending parent list so that it is created after the parent @@ -265,8 +264,8 @@ func (t *ImportTask) ImportStudios(ctx context.Context) { for _, s := range pendingParent { for _, orphanStudioJSON := range s { - if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { - return t.ImportStudio(orphanStudioJSON, nil, r.Studio()) + if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { + return t.ImportStudio(ctx, orphanStudioJSON, nil, t.txnManager.Studio) }); err != nil { logger.Errorf("[studios] <%s> failed to create: %s", orphanStudioJSON.Name, err.Error()) continue @@ -278,7 +277,7 @@ func (t *ImportTask) ImportStudios(ctx context.Context) { logger.Info("[studios] import complete") } -func (t *ImportTask) ImportStudio(studioJSON *jsonschema.Studio, pendingParent map[string][]*jsonschema.Studio, readerWriter models.StudioReaderWriter) error { +func (t *ImportTask) ImportStudio(ctx context.Context, studioJSON *jsonschema.Studio, pendingParent map[string][]*jsonschema.Studio, readerWriter studio.NameFinderCreatorUpdater) error { importer := &studio.Importer{ ReaderWriter: readerWriter, Input: *studioJSON, @@ -290,7 +289,7 @@ func (t *ImportTask) ImportStudio(studioJSON *jsonschema.Studio, pendingParent m importer.MissingRefBehaviour = models.ImportMissingRefEnumFail } - if err := performImport(importer, t.DuplicateBehaviour); err != nil { + if err := performImport(ctx, importer, t.DuplicateBehaviour); err != nil { return err } @@ -298,7 +297,7 @@ func (t *ImportTask) ImportStudio(studioJSON *jsonschema.Studio, pendingParent m s := pendingParent[studioJSON.Name] for _, childStudioJSON := range s { // map is nil since we're not checking parent studios at this point - if err := t.ImportStudio(childStudioJSON, nil, readerWriter); err != nil { + if err := t.ImportStudio(ctx, childStudioJSON, nil, readerWriter); err != nil { return fmt.Errorf("failed to create child studio <%s>: %s", childStudioJSON.Name, err.Error()) } } @@ -322,9 +321,10 @@ func (t *ImportTask) ImportMovies(ctx context.Context) { logger.Progressf("[movies] %d of %d", index, len(t.mappings.Movies)) - if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { - readerWriter := r.Movie() - studioReaderWriter := r.Studio() + if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { + r := t.txnManager + readerWriter := r.Movie + studioReaderWriter := r.Studio movieImporter := &movie.Importer{ ReaderWriter: readerWriter, @@ -333,7 +333,7 @@ func (t *ImportTask) ImportMovies(ctx context.Context) { MissingRefBehaviour: t.MissingRefBehaviour, } - return performImport(movieImporter, t.DuplicateBehaviour) + return performImport(ctx, movieImporter, t.DuplicateBehaviour) }); err != nil { logger.Errorf("[movies] <%s> import failed: %s", mappingJSON.Checksum, err.Error()) continue @@ -356,11 +356,12 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) { logger.Progressf("[galleries] %d of %d", index, len(t.mappings.Galleries)) - if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { - readerWriter := r.Gallery() - tagWriter := r.Tag() - performerWriter := r.Performer() - studioWriter := r.Studio() + if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { + r := t.txnManager + readerWriter := r.Gallery + tagWriter := r.Tag + performerWriter := r.Performer + studioWriter := r.Studio galleryImporter := &gallery.Importer{ ReaderWriter: readerWriter, @@ -371,7 +372,7 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) { MissingRefBehaviour: t.MissingRefBehaviour, } - return performImport(galleryImporter, t.DuplicateBehaviour) + return performImport(ctx, galleryImporter, t.DuplicateBehaviour) }); err != nil { logger.Errorf("[galleries] <%s> import failed to commit: %s", mappingJSON.Checksum, err.Error()) continue @@ -395,8 +396,8 @@ func (t *ImportTask) ImportTags(ctx context.Context) { logger.Progressf("[tags] %d of %d", index, len(t.mappings.Tags)) - if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { - return t.ImportTag(tagJSON, pendingParent, false, r.Tag()) + if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { + return t.ImportTag(ctx, tagJSON, pendingParent, false, t.txnManager.Tag) }); err != nil { var parentError tag.ParentTagNotExistError if errors.As(err, &parentError) { @@ -411,8 +412,8 @@ func (t *ImportTask) ImportTags(ctx context.Context) { for _, s := range pendingParent { for _, orphanTagJSON := range s { - if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { - return t.ImportTag(orphanTagJSON, nil, true, r.Tag()) + if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { + return t.ImportTag(ctx, orphanTagJSON, nil, true, t.txnManager.Tag) }); err != nil { logger.Errorf("[tags] <%s> failed to create: %s", orphanTagJSON.Name, err.Error()) continue @@ -423,7 +424,7 @@ func (t *ImportTask) ImportTags(ctx context.Context) { logger.Info("[tags] import complete") } -func (t *ImportTask) ImportTag(tagJSON *jsonschema.Tag, pendingParent map[string][]*jsonschema.Tag, fail bool, readerWriter models.TagReaderWriter) error { +func (t *ImportTask) ImportTag(ctx context.Context, tagJSON *jsonschema.Tag, pendingParent map[string][]*jsonschema.Tag, fail bool, readerWriter tag.NameFinderCreatorUpdater) error { importer := &tag.Importer{ ReaderWriter: readerWriter, Input: *tagJSON, @@ -435,12 +436,12 @@ func (t *ImportTask) ImportTag(tagJSON *jsonschema.Tag, pendingParent map[string importer.MissingRefBehaviour = models.ImportMissingRefEnumFail } - if err := performImport(importer, t.DuplicateBehaviour); err != nil { + if err := performImport(ctx, importer, t.DuplicateBehaviour); err != nil { return err } for _, childTagJSON := range pendingParent[tagJSON.Name] { - if err := t.ImportTag(childTagJSON, pendingParent, fail, readerWriter); err != nil { + if err := t.ImportTag(ctx, childTagJSON, pendingParent, fail, readerWriter); err != nil { var parentError tag.ParentTagNotExistError if errors.As(err, &parentError) { pendingParent[parentError.MissingParent()] = append(pendingParent[parentError.MissingParent()], tagJSON) @@ -457,10 +458,11 @@ func (t *ImportTask) ImportTag(tagJSON *jsonschema.Tag, pendingParent map[string } func (t *ImportTask) ImportScrapedItems(ctx context.Context) { - if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { + if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { logger.Info("[scraped sites] importing") - qb := r.ScrapedItem() - sqb := r.Studio() + r := t.txnManager + qb := r.ScrapedItem + sqb := r.Studio currentTime := time.Now() for i, mappingJSON := range t.scraped { @@ -484,7 +486,7 @@ func (t *ImportTask) ImportScrapedItems(ctx context.Context) { UpdatedAt: models.SQLiteTimestamp{Timestamp: t.getTimeFromJSONTime(mappingJSON.UpdatedAt)}, } - studio, err := sqb.FindByName(mappingJSON.Studio, false) + studio, err := sqb.FindByName(ctx, mappingJSON.Studio, false) if err != nil { logger.Errorf("[scraped sites] failed to fetch studio: %s", err.Error()) } @@ -492,7 +494,7 @@ func (t *ImportTask) ImportScrapedItems(ctx context.Context) { newScrapedItem.StudioID = sql.NullInt64{Int64: int64(studio.ID), Valid: true} } - _, err = qb.Create(newScrapedItem) + _, err = qb.Create(ctx, newScrapedItem) if err != nil { logger.Errorf("[scraped sites] <%s> failed to create: %s", newScrapedItem.Title.String, err.Error()) } @@ -522,14 +524,15 @@ func (t *ImportTask) ImportScenes(ctx context.Context) { sceneHash := mappingJSON.Checksum - if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { - readerWriter := r.Scene() - tagWriter := r.Tag() - galleryWriter := r.Gallery() - movieWriter := r.Movie() - performerWriter := r.Performer() - studioWriter := r.Studio() - markerWriter := r.SceneMarker() + if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { + r := t.txnManager + readerWriter := r.Scene + tagWriter := r.Tag + galleryWriter := r.Gallery + movieWriter := r.Movie + performerWriter := r.Performer + studioWriter := r.Studio + markerWriter := r.SceneMarker sceneImporter := &scene.Importer{ ReaderWriter: readerWriter, @@ -546,7 +549,7 @@ func (t *ImportTask) ImportScenes(ctx context.Context) { TagWriter: tagWriter, } - if err := performImport(sceneImporter, t.DuplicateBehaviour); err != nil { + if err := performImport(ctx, sceneImporter, t.DuplicateBehaviour); err != nil { return err } @@ -560,7 +563,7 @@ func (t *ImportTask) ImportScenes(ctx context.Context) { TagWriter: tagWriter, } - if err := performImport(markerImporter, t.DuplicateBehaviour); err != nil { + if err := performImport(ctx, markerImporter, t.DuplicateBehaviour); err != nil { return err } } @@ -590,12 +593,13 @@ func (t *ImportTask) ImportImages(ctx context.Context) { imageHash := mappingJSON.Checksum - if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { - readerWriter := r.Image() - tagWriter := r.Tag() - galleryWriter := r.Gallery() - performerWriter := r.Performer() - studioWriter := r.Studio() + if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { + r := t.txnManager + readerWriter := r.Image + tagWriter := r.Tag + galleryWriter := r.Gallery + performerWriter := r.Performer + studioWriter := r.Studio imageImporter := &image.Importer{ ReaderWriter: readerWriter, @@ -610,7 +614,7 @@ func (t *ImportTask) ImportImages(ctx context.Context) { TagWriter: tagWriter, } - return performImport(imageImporter, t.DuplicateBehaviour) + return performImport(ctx, imageImporter, t.DuplicateBehaviour) }); err != nil { logger.Errorf("[images] <%s> import failed: %s", imageHash, err.Error()) } diff --git a/internal/manager/task_scan.go b/internal/manager/task_scan.go index 042ff80ac..99fafceb0 100644 --- a/internal/manager/task_scan.go +++ b/internal/manager/task_scan.go @@ -24,7 +24,7 @@ import ( const scanQueueSize = 200000 type ScanJob struct { - txnManager models.TransactionManager + txnManager models.Repository input ScanMetadataInput subscriptions *subscriptionManager } @@ -220,20 +220,21 @@ func (j *ScanJob) doesPathExist(ctx context.Context, path string) bool { gExt := config.GetGalleryExtensions() ret := false - txnErr := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { + txnErr := j.txnManager.WithTxn(ctx, func(ctx context.Context) error { + r := j.txnManager switch { case fsutil.MatchExtension(path, gExt): - g, _ := r.Gallery().FindByPath(path) + g, _ := r.Gallery.FindByPath(ctx, path) if g != nil { ret = true } case fsutil.MatchExtension(path, vidExt): - s, _ := r.Scene().FindByPath(path) + s, _ := r.Scene.FindByPath(ctx, path) if s != nil { ret = true } case fsutil.MatchExtension(path, imgExt): - i, _ := r.Image().FindByPath(path) + i, _ := r.Image.FindByPath(ctx, path) if i != nil { ret = true } @@ -249,7 +250,7 @@ func (j *ScanJob) doesPathExist(ctx context.Context, path string) bool { } type ScanTask struct { - TxnManager models.TransactionManager + TxnManager models.Repository file file.SourceFile UseFileMetadata bool StripFileExtension bool diff --git a/internal/manager/task_scan_gallery.go b/internal/manager/task_scan_gallery.go index 8c3f5c550..2a2669e28 100644 --- a/internal/manager/task_scan_gallery.go +++ b/internal/manager/task_scan_gallery.go @@ -21,12 +21,12 @@ func (t *ScanTask) scanGallery(ctx context.Context) { images := 0 scanImages := false - if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { + if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error { var err error - g, err = r.Gallery().FindByPath(path) + g, err = t.TxnManager.Gallery.FindByPath(ctx, path) if g != nil && err == nil { - images, err = r.Image().CountByGalleryID(g.ID) + images, err = t.TxnManager.Image.CountByGalleryID(ctx, g.ID) if err != nil { return fmt.Errorf("error getting images for zip gallery %s: %s", path, err.Error()) } @@ -43,7 +43,7 @@ func (t *ScanTask) scanGallery(ctx context.Context) { ImageExtensions: instance.Config.GetImageExtensions(), StripFileExtension: t.StripFileExtension, CaseSensitiveFs: t.CaseSensitiveFs, - TxnManager: t.TxnManager, + CreatorUpdater: t.TxnManager.Gallery, Paths: instance.Paths, PluginCache: instance.PluginCache, MutexManager: t.mutexManager, @@ -79,10 +79,11 @@ func (t *ScanTask) scanGallery(ctx context.Context) { // associates a gallery to a scene with the same basename func (t *ScanTask) associateGallery(ctx context.Context, wg *sizedwaitgroup.SizedWaitGroup) { path := t.file.Path() - if err := t.TxnManager.WithTxn(ctx, func(r models.Repository) error { - qb := r.Gallery() - sqb := r.Scene() - g, err := qb.FindByPath(path) + if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error { + r := t.TxnManager + qb := r.Gallery + sqb := r.Scene + g, err := qb.FindByPath(ctx, path) if err != nil { return err } @@ -106,10 +107,10 @@ func (t *ScanTask) associateGallery(ctx context.Context, wg *sizedwaitgroup.Size } } for _, scenePath := range relatedFiles { - scene, _ := sqb.FindByPath(scenePath) + scene, _ := sqb.FindByPath(ctx, scenePath) // found related Scene if scene != nil { - sceneGalleries, _ := sqb.FindByGalleryID(g.ID) // check if gallery is already associated to the scene + sceneGalleries, _ := sqb.FindByGalleryID(ctx, g.ID) // check if gallery is already associated to the scene isAssoc := false for _, sg := range sceneGalleries { if scene.ID == sg.ID { @@ -119,7 +120,7 @@ func (t *ScanTask) associateGallery(ctx context.Context, wg *sizedwaitgroup.Size } if !isAssoc { logger.Infof("associate: Gallery %s is related to scene: %d", path, scene.ID) - if err := sqb.UpdateGalleries(scene.ID, []int{g.ID}); err != nil { + if err := sqb.UpdateGalleries(ctx, scene.ID, []int{g.ID}); err != nil { return err } } @@ -152,11 +153,11 @@ func (t *ScanTask) scanZipImages(ctx context.Context, zipGallery *models.Gallery func (t *ScanTask) regenerateZipImages(ctx context.Context, zipGallery *models.Gallery) { var images []*models.Image - if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - iqb := r.Image() + if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error { + iqb := t.TxnManager.Image var err error - images, err = iqb.FindByGalleryID(zipGallery.ID) + images, err = iqb.FindByGalleryID(ctx, zipGallery.ID) return err }); err != nil { logger.Warnf("failed to find gallery images: %s", err.Error()) diff --git a/internal/manager/task_scan_image.go b/internal/manager/task_scan_image.go index 36aff5a04..20bd78224 100644 --- a/internal/manager/task_scan_image.go +++ b/internal/manager/task_scan_image.go @@ -23,9 +23,9 @@ func (t *ScanTask) scanImage(ctx context.Context) { var i *models.Image path := t.file.Path() - if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { + if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error { var err error - i, err = r.Image().FindByPath(path) + i, err = t.TxnManager.Image.FindByPath(ctx, path) return err }); err != nil { logger.Error(err.Error()) @@ -36,6 +36,8 @@ func (t *ScanTask) scanImage(ctx context.Context) { Scanner: image.FileScanner(&file.FSHasher{}), StripFileExtension: t.StripFileExtension, TxnManager: t.TxnManager, + CreatorUpdater: t.TxnManager.Image, + CaseSensitiveFs: t.CaseSensitiveFs, Paths: GetInstance().Paths, PluginCache: instance.PluginCache, MutexManager: t.mutexManager, @@ -58,8 +60,8 @@ func (t *ScanTask) scanImage(ctx context.Context) { if i != nil { if t.zipGallery != nil { // associate with gallery - if err := t.TxnManager.WithTxn(ctx, func(r models.Repository) error { - return gallery.AddImage(r.Gallery(), t.zipGallery.ID, i.ID) + if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error { + return gallery.AddImage(ctx, t.TxnManager.Gallery, t.zipGallery.ID, i.ID) }); err != nil { logger.Error(err.Error()) return @@ -69,9 +71,9 @@ func (t *ScanTask) scanImage(ctx context.Context) { logger.Infof("Associating image %s with folder gallery", i.Path) var galleryID int var isNewGallery bool - if err := t.TxnManager.WithTxn(ctx, func(r models.Repository) error { + if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error { var err error - galleryID, isNewGallery, err = t.associateImageWithFolderGallery(i.ID, r.Gallery()) + galleryID, isNewGallery, err = t.associateImageWithFolderGallery(ctx, i.ID, t.TxnManager.Gallery) return err }); err != nil { logger.Error(err.Error()) @@ -90,11 +92,17 @@ func (t *ScanTask) scanImage(ctx context.Context) { } } -func (t *ScanTask) associateImageWithFolderGallery(imageID int, qb models.GalleryReaderWriter) (galleryID int, isNew bool, err error) { +type GalleryImageAssociator interface { + FindByPath(ctx context.Context, path string) (*models.Gallery, error) + Create(ctx context.Context, newGallery models.Gallery) (*models.Gallery, error) + gallery.ImageUpdater +} + +func (t *ScanTask) associateImageWithFolderGallery(ctx context.Context, imageID int, qb GalleryImageAssociator) (galleryID int, isNew bool, err error) { // find a gallery with the path specified path := filepath.Dir(t.file.Path()) var g *models.Gallery - g, err = qb.FindByPath(path) + g, err = qb.FindByPath(ctx, path) if err != nil { return } @@ -120,7 +128,7 @@ func (t *ScanTask) associateImageWithFolderGallery(imageID int, qb models.Galler } logger.Infof("Creating gallery for folder %s", path) - g, err = qb.Create(newGallery) + g, err = qb.Create(ctx, newGallery) if err != nil { return 0, false, err } @@ -129,7 +137,7 @@ func (t *ScanTask) associateImageWithFolderGallery(imageID int, qb models.Galler } // associate image with gallery - err = gallery.AddImage(qb, g.ID, imageID) + err = gallery.AddImage(ctx, qb, g.ID, imageID) galleryID = g.ID return } diff --git a/internal/manager/task_scan_scene.go b/internal/manager/task_scan_scene.go index 218a2e012..295a0c7ef 100644 --- a/internal/manager/task_scan_scene.go +++ b/internal/manager/task_scan_scene.go @@ -34,9 +34,9 @@ func (t *ScanTask) scanScene(ctx context.Context) *models.Scene { var retScene *models.Scene var s *models.Scene - if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { + if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error { var err error - s, err = r.Scene().FindByPath(t.file.Path()) + s, err = t.TxnManager.Scene.FindByPath(ctx, t.file.Path()) return err }); err != nil { logger.Error(err.Error()) @@ -54,7 +54,9 @@ func (t *ScanTask) scanScene(ctx context.Context) *models.Scene { StripFileExtension: t.StripFileExtension, FileNamingAlgorithm: t.fileNamingAlgorithm, TxnManager: t.TxnManager, + CreatorUpdater: t.TxnManager.Scene, Paths: GetInstance().Paths, + CaseSensitiveFs: t.CaseSensitiveFs, Screenshotter: &sceneScreenshotter{ g: g, }, @@ -88,12 +90,12 @@ func (t *ScanTask) associateCaptions(ctx context.Context) { captionLang := scene.GetCaptionsLangFromPath(captionPath) relatedFiles := scene.GenerateCaptionCandidates(captionPath, vExt) - if err := t.TxnManager.WithTxn(ctx, func(r models.Repository) error { + if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error { var err error - sqb := r.Scene() + sqb := t.TxnManager.Scene for _, scenePath := range relatedFiles { - s, er := sqb.FindByPath(scenePath) + s, er := sqb.FindByPath(ctx, scenePath) if er != nil { logger.Errorf("Error searching for scene %s: %v", scenePath, er) @@ -101,7 +103,7 @@ func (t *ScanTask) associateCaptions(ctx context.Context) { } if s != nil { // found related Scene logger.Debugf("Matched captions to scene %s", s.Path) - captions, er := sqb.GetCaptions(s.ID) + captions, er := sqb.GetCaptions(ctx, s.ID) if er == nil { fileExt := filepath.Ext(captionPath) ext := fileExt[1:] @@ -112,7 +114,7 @@ func (t *ScanTask) associateCaptions(ctx context.Context) { CaptionType: ext, } captions = append(captions, newCaption) - er = sqb.UpdateCaptions(s.ID, captions) + er = sqb.UpdateCaptions(ctx, s.ID, captions) if er == nil { logger.Debugf("Updated captions for scene %s. Added %s", s.Path, captionLang) } diff --git a/internal/manager/task_stash_box_tag.go b/internal/manager/task_stash_box_tag.go index 2932a3167..cf7add510 100644 --- a/internal/manager/task_stash_box_tag.go +++ b/internal/manager/task_stash_box_tag.go @@ -10,11 +10,11 @@ import ( "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/scraper/stashbox" + "github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/utils" ) type StashBoxPerformerTagTask struct { - txnManager models.TransactionManager box *models.StashBox name *string performer *models.Performer @@ -41,12 +41,17 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) { var performer *models.ScrapedPerformer var err error - client := stashbox.NewClient(*t.box, t.txnManager) + client := stashbox.NewClient(*t.box, instance.Repository, stashbox.Repository{ + Scene: instance.Repository.Scene, + Performer: instance.Repository.Performer, + Tag: instance.Repository.Tag, + Studio: instance.Repository.Studio, + }) if t.refresh { var performerID string - txnErr := t.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - stashids, _ := r.Performer().GetStashIDs(t.performer.ID) + txnErr := txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error { + stashids, _ := instance.Repository.Performer.GetStashIDs(ctx, t.performer.ID) for _, id := range stashids { if id.Endpoint == t.box.Endpoint { performerID = id.StashID @@ -156,11 +161,12 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) { partial.URL = &value } - txnErr := t.txnManager.WithTxn(ctx, func(r models.Repository) error { - _, err := r.Performer().Update(partial) + txnErr := txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error { + r := instance.Repository + _, err := r.Performer.Update(ctx, partial) if !t.refresh { - err = r.Performer().UpdateStashIDs(t.performer.ID, []models.StashID{ + err = r.Performer.UpdateStashIDs(ctx, t.performer.ID, []models.StashID{ { Endpoint: t.box.Endpoint, StashID: *performer.RemoteSiteID, @@ -176,7 +182,7 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) { if err != nil { return err } - err = r.Performer().UpdateImage(t.performer.ID, image) + err = r.Performer.UpdateImage(ctx, t.performer.ID, image) if err != nil { return err } @@ -218,13 +224,14 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) { URL: getNullString(performer.URL), UpdatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, } - err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { - createdPerformer, err := r.Performer().Create(newPerformer) + err := txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error { + r := instance.Repository + createdPerformer, err := r.Performer.Create(ctx, newPerformer) if err != nil { return err } - err = r.Performer().UpdateStashIDs(createdPerformer.ID, []models.StashID{ + err = r.Performer.UpdateStashIDs(ctx, createdPerformer.ID, []models.StashID{ { Endpoint: t.box.Endpoint, StashID: *performer.RemoteSiteID, @@ -239,7 +246,7 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) { if imageErr != nil { return imageErr } - err = r.Performer().UpdateImage(createdPerformer.ID, image) + err = r.Performer.UpdateImage(ctx, createdPerformer.ID, image) } return err }) diff --git a/pkg/database/transaction.go b/pkg/database/transaction.go deleted file mode 100644 index d8c23fb3b..000000000 --- a/pkg/database/transaction.go +++ /dev/null @@ -1,40 +0,0 @@ -package database - -import ( - "context" - - "github.com/jmoiron/sqlx" - "github.com/stashapp/stash/pkg/logger" -) - -// WithTxn executes the provided function within a transaction. It rolls back -// the transaction if the function returns an error, otherwise the transaction -// is committed. -func WithTxn(fn func(tx *sqlx.Tx) error) error { - ctx := context.TODO() - tx := DB.MustBeginTx(ctx, nil) - - var err error - defer func() { - if p := recover(); p != nil { - // a panic occurred, rollback and repanic - if err := tx.Rollback(); err != nil { - logger.Warnf("failure when performing transaction rollback: %v", err) - } - panic(p) - } - - if err != nil { - // something went wrong, rollback - if err := tx.Rollback(); err != nil { - logger.Warnf("failure when performing transaction rollback: %v", err) - } - } else { - // all good, commit - err = tx.Commit() - } - }() - - err = fn(tx) - return err -} diff --git a/pkg/gallery/export.go b/pkg/gallery/export.go index f24660e60..296929b35 100644 --- a/pkg/gallery/export.go +++ b/pkg/gallery/export.go @@ -1,9 +1,12 @@ package gallery import ( + "context" + "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/json" "github.com/stashapp/stash/pkg/models/jsonschema" + "github.com/stashapp/stash/pkg/studio" "github.com/stashapp/stash/pkg/utils" ) @@ -52,9 +55,9 @@ func ToBasicJSON(gallery *models.Gallery) (*jsonschema.Gallery, error) { // GetStudioName returns the name of the provided gallery's studio. It returns an // empty string if there is no studio assigned to the gallery. -func GetStudioName(reader models.StudioReader, gallery *models.Gallery) (string, error) { +func GetStudioName(ctx context.Context, reader studio.Finder, gallery *models.Gallery) (string, error) { if gallery.StudioID.Valid { - studio, err := reader.Find(int(gallery.StudioID.Int64)) + studio, err := reader.Find(ctx, int(gallery.StudioID.Int64)) if err != nil { return "", err } diff --git a/pkg/gallery/export_test.go b/pkg/gallery/export_test.go index 80418d7e0..fe371fad7 100644 --- a/pkg/gallery/export_test.go +++ b/pkg/gallery/export_test.go @@ -154,15 +154,15 @@ func TestGetStudioName(t *testing.T) { studioErr := errors.New("error getting image") - mockStudioReader.On("Find", studioID).Return(&models.Studio{ + mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{ Name: models.NullString(studioName), }, nil).Once() - mockStudioReader.On("Find", missingStudioID).Return(nil, nil).Once() - mockStudioReader.On("Find", errStudioID).Return(nil, studioErr).Once() + mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once() + mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once() for i, s := range getStudioScenarios { gallery := s.input - json, err := GetStudioName(mockStudioReader, &gallery) + json, err := GetStudioName(testCtx, mockStudioReader, &gallery) switch { case !s.err && err != nil: diff --git a/pkg/gallery/import.go b/pkg/gallery/import.go index f82cff13b..85c90e3f0 100644 --- a/pkg/gallery/import.go +++ b/pkg/gallery/import.go @@ -1,20 +1,30 @@ package gallery import ( + "context" "database/sql" "fmt" "strings" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/jsonschema" + "github.com/stashapp/stash/pkg/performer" "github.com/stashapp/stash/pkg/sliceutil/stringslice" + "github.com/stashapp/stash/pkg/studio" + "github.com/stashapp/stash/pkg/tag" ) +type FullCreatorUpdater interface { + FinderCreatorUpdater + UpdatePerformers(ctx context.Context, galleryID int, performerIDs []int) error + UpdateTags(ctx context.Context, galleryID int, tagIDs []int) error +} + type Importer struct { - ReaderWriter models.GalleryReaderWriter - StudioWriter models.StudioReaderWriter - PerformerWriter models.PerformerReaderWriter - TagWriter models.TagReaderWriter + ReaderWriter FullCreatorUpdater + StudioWriter studio.NameFinderCreator + PerformerWriter performer.NameFinderCreator + TagWriter tag.NameFinderCreator Input jsonschema.Gallery MissingRefBehaviour models.ImportMissingRefEnum @@ -23,18 +33,18 @@ type Importer struct { tags []*models.Tag } -func (i *Importer) PreImport() error { +func (i *Importer) PreImport(ctx context.Context) error { i.gallery = i.galleryJSONToGallery(i.Input) - if err := i.populateStudio(); err != nil { + if err := i.populateStudio(ctx); err != nil { return err } - if err := i.populatePerformers(); err != nil { + if err := i.populatePerformers(ctx); err != nil { return err } - if err := i.populateTags(); err != nil { + if err := i.populateTags(ctx); err != nil { return err } @@ -74,9 +84,9 @@ func (i *Importer) galleryJSONToGallery(galleryJSON jsonschema.Gallery) models.G return newGallery } -func (i *Importer) populateStudio() error { +func (i *Importer) populateStudio(ctx context.Context) error { if i.Input.Studio != "" { - studio, err := i.StudioWriter.FindByName(i.Input.Studio, false) + studio, err := i.StudioWriter.FindByName(ctx, i.Input.Studio, false) if err != nil { return fmt.Errorf("error finding studio by name: %v", err) } @@ -91,7 +101,7 @@ func (i *Importer) populateStudio() error { } if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate { - studioID, err := i.createStudio(i.Input.Studio) + studioID, err := i.createStudio(ctx, i.Input.Studio) if err != nil { return err } @@ -108,10 +118,10 @@ func (i *Importer) populateStudio() error { return nil } -func (i *Importer) createStudio(name string) (int, error) { +func (i *Importer) createStudio(ctx context.Context, name string) (int, error) { newStudio := *models.NewStudio(name) - created, err := i.StudioWriter.Create(newStudio) + created, err := i.StudioWriter.Create(ctx, newStudio) if err != nil { return 0, err } @@ -119,10 +129,10 @@ func (i *Importer) createStudio(name string) (int, error) { return created.ID, nil } -func (i *Importer) populatePerformers() error { +func (i *Importer) populatePerformers(ctx context.Context) error { if len(i.Input.Performers) > 0 { names := i.Input.Performers - performers, err := i.PerformerWriter.FindByNames(names, false) + performers, err := i.PerformerWriter.FindByNames(ctx, names, false) if err != nil { return err } @@ -145,7 +155,7 @@ func (i *Importer) populatePerformers() error { } if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate { - createdPerformers, err := i.createPerformers(missingPerformers) + createdPerformers, err := i.createPerformers(ctx, missingPerformers) if err != nil { return fmt.Errorf("error creating gallery performers: %v", err) } @@ -162,12 +172,12 @@ func (i *Importer) populatePerformers() error { return nil } -func (i *Importer) createPerformers(names []string) ([]*models.Performer, error) { +func (i *Importer) createPerformers(ctx context.Context, names []string) ([]*models.Performer, error) { var ret []*models.Performer for _, name := range names { newPerformer := *models.NewPerformer(name) - created, err := i.PerformerWriter.Create(newPerformer) + created, err := i.PerformerWriter.Create(ctx, newPerformer) if err != nil { return nil, err } @@ -178,10 +188,10 @@ func (i *Importer) createPerformers(names []string) ([]*models.Performer, error) return ret, nil } -func (i *Importer) populateTags() error { +func (i *Importer) populateTags(ctx context.Context) error { if len(i.Input.Tags) > 0 { names := i.Input.Tags - tags, err := i.TagWriter.FindByNames(names, false) + tags, err := i.TagWriter.FindByNames(ctx, names, false) if err != nil { return err } @@ -201,7 +211,7 @@ func (i *Importer) populateTags() error { } if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate { - createdTags, err := i.createTags(missingTags) + createdTags, err := i.createTags(ctx, missingTags) if err != nil { return fmt.Errorf("error creating gallery tags: %v", err) } @@ -218,12 +228,12 @@ func (i *Importer) populateTags() error { return nil } -func (i *Importer) createTags(names []string) ([]*models.Tag, error) { +func (i *Importer) createTags(ctx context.Context, names []string) ([]*models.Tag, error) { var ret []*models.Tag for _, name := range names { newTag := *models.NewTag(name) - created, err := i.TagWriter.Create(newTag) + created, err := i.TagWriter.Create(ctx, newTag) if err != nil { return nil, err } @@ -234,14 +244,14 @@ func (i *Importer) createTags(names []string) ([]*models.Tag, error) { return ret, nil } -func (i *Importer) PostImport(id int) error { +func (i *Importer) PostImport(ctx context.Context, id int) error { if len(i.performers) > 0 { var performerIDs []int for _, performer := range i.performers { performerIDs = append(performerIDs, performer.ID) } - if err := i.ReaderWriter.UpdatePerformers(id, performerIDs); err != nil { + if err := i.ReaderWriter.UpdatePerformers(ctx, id, performerIDs); err != nil { return fmt.Errorf("failed to associate performers: %v", err) } } @@ -251,7 +261,7 @@ func (i *Importer) PostImport(id int) error { for _, t := range i.tags { tagIDs = append(tagIDs, t.ID) } - if err := i.ReaderWriter.UpdateTags(id, tagIDs); err != nil { + if err := i.ReaderWriter.UpdateTags(ctx, id, tagIDs); err != nil { return fmt.Errorf("failed to associate tags: %v", err) } } @@ -263,8 +273,8 @@ func (i *Importer) Name() string { return i.Input.Path } -func (i *Importer) FindExistingID() (*int, error) { - existing, err := i.ReaderWriter.FindByChecksum(i.Input.Checksum) +func (i *Importer) FindExistingID(ctx context.Context) (*int, error) { + existing, err := i.ReaderWriter.FindByChecksum(ctx, i.Input.Checksum) if err != nil { return nil, err } @@ -277,8 +287,8 @@ func (i *Importer) FindExistingID() (*int, error) { return nil, nil } -func (i *Importer) Create() (*int, error) { - created, err := i.ReaderWriter.Create(i.gallery) +func (i *Importer) Create(ctx context.Context) (*int, error) { + created, err := i.ReaderWriter.Create(ctx, i.gallery) if err != nil { return nil, fmt.Errorf("error creating gallery: %v", err) } @@ -287,10 +297,10 @@ func (i *Importer) Create() (*int, error) { return &id, nil } -func (i *Importer) Update(id int) error { +func (i *Importer) Update(ctx context.Context, id int) error { gallery := i.gallery gallery.ID = id - _, err := i.ReaderWriter.Update(gallery) + _, err := i.ReaderWriter.Update(ctx, gallery) if err != nil { return fmt.Errorf("error updating existing gallery: %v", err) } diff --git a/pkg/gallery/import_test.go b/pkg/gallery/import_test.go index d50fd16d1..6f111aa4b 100644 --- a/pkg/gallery/import_test.go +++ b/pkg/gallery/import_test.go @@ -1,6 +1,7 @@ package gallery import ( + "context" "errors" "testing" "time" @@ -40,6 +41,8 @@ const ( errChecksum = "errChecksum" ) +var testCtx = context.Background() + var ( createdAt = time.Date(2001, time.January, 2, 1, 2, 3, 4, time.Local) updatedAt = time.Date(2002, time.January, 2, 1, 2, 3, 4, time.Local) @@ -75,7 +78,7 @@ func TestImporterPreImport(t *testing.T) { }, } - err := i.PreImport() + err := i.PreImport(testCtx) assert.Nil(t, err) expectedGallery := models.Gallery{ @@ -112,17 +115,17 @@ func TestImporterPreImportWithStudio(t *testing.T) { }, } - studioReaderWriter.On("FindByName", existingStudioName, false).Return(&models.Studio{ + studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{ ID: existingStudioID, }, nil).Once() - studioReaderWriter.On("FindByName", existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once() + studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once() - err := i.PreImport() + err := i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, int64(existingStudioID), i.gallery.StudioID.Int64) i.Input.Studio = existingStudioErr - err = i.PreImport() + err = i.PreImport(testCtx) assert.NotNil(t, err) studioReaderWriter.AssertExpectations(t) @@ -140,20 +143,20 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - studioReaderWriter.On("FindByName", missingStudioName, false).Return(nil, nil).Times(3) - studioReaderWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(&models.Studio{ + studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) + studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Studio")).Return(&models.Studio{ ID: existingStudioID, }, nil) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumCreate - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, int64(existingStudioID), i.gallery.StudioID.Int64) @@ -172,10 +175,10 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - studioReaderWriter.On("FindByName", missingStudioName, false).Return(nil, nil).Once() - studioReaderWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error")) + studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() + studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error")) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) } @@ -193,20 +196,20 @@ func TestImporterPreImportWithPerformer(t *testing.T) { }, } - performerReaderWriter.On("FindByNames", []string{existingPerformerName}, false).Return([]*models.Performer{ + performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{ { ID: existingPerformerID, Name: models.NullString(existingPerformerName), }, }, nil).Once() - performerReaderWriter.On("FindByNames", []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() + performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() - err := i.PreImport() + err := i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, existingPerformerID, i.performers[0].ID) i.Input.Performers = []string{existingPerformerErr} - err = i.PreImport() + err = i.PreImport(testCtx) assert.NotNil(t, err) performerReaderWriter.AssertExpectations(t) @@ -226,20 +229,20 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - performerReaderWriter.On("FindByNames", []string{missingPerformerName}, false).Return(nil, nil).Times(3) - performerReaderWriter.On("Create", mock.AnythingOfType("models.Performer")).Return(&models.Performer{ + performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3) + performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Performer")).Return(&models.Performer{ ID: existingPerformerID, }, nil) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumCreate - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, existingPerformerID, i.performers[0].ID) @@ -260,10 +263,10 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - performerReaderWriter.On("FindByNames", []string{missingPerformerName}, false).Return(nil, nil).Once() - performerReaderWriter.On("Create", mock.AnythingOfType("models.Performer")).Return(nil, errors.New("Create error")) + performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once() + performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Performer")).Return(nil, errors.New("Create error")) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) } @@ -281,20 +284,20 @@ func TestImporterPreImportWithTag(t *testing.T) { }, } - tagReaderWriter.On("FindByNames", []string{existingTagName}, false).Return([]*models.Tag{ + tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{ { ID: existingTagID, Name: existingTagName, }, }, nil).Once() - tagReaderWriter.On("FindByNames", []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once() + tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once() - err := i.PreImport() + err := i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, existingTagID, i.tags[0].ID) i.Input.Tags = []string{existingTagErr} - err = i.PreImport() + err = i.PreImport(testCtx) assert.NotNil(t, err) tagReaderWriter.AssertExpectations(t) @@ -314,20 +317,20 @@ func TestImporterPreImportWithMissingTag(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - tagReaderWriter.On("FindByNames", []string{missingTagName}, false).Return(nil, nil).Times(3) - tagReaderWriter.On("Create", mock.AnythingOfType("models.Tag")).Return(&models.Tag{ + tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3) + tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(&models.Tag{ ID: existingTagID, }, nil) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumCreate - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, existingTagID, i.tags[0].ID) @@ -348,10 +351,10 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - tagReaderWriter.On("FindByNames", []string{missingTagName}, false).Return(nil, nil).Once() - tagReaderWriter.On("Create", mock.AnythingOfType("models.Tag")).Return(nil, errors.New("Create error")) + tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once() + tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(nil, errors.New("Create error")) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) } @@ -369,13 +372,13 @@ func TestImporterPostImportUpdatePerformers(t *testing.T) { updateErr := errors.New("UpdatePerformers error") - galleryReaderWriter.On("UpdatePerformers", galleryID, []int{existingPerformerID}).Return(nil).Once() - galleryReaderWriter.On("UpdatePerformers", errPerformersID, mock.AnythingOfType("[]int")).Return(updateErr).Once() + galleryReaderWriter.On("UpdatePerformers", testCtx, galleryID, []int{existingPerformerID}).Return(nil).Once() + galleryReaderWriter.On("UpdatePerformers", testCtx, errPerformersID, mock.AnythingOfType("[]int")).Return(updateErr).Once() - err := i.PostImport(galleryID) + err := i.PostImport(testCtx, galleryID) assert.Nil(t, err) - err = i.PostImport(errPerformersID) + err = i.PostImport(testCtx, errPerformersID) assert.NotNil(t, err) galleryReaderWriter.AssertExpectations(t) @@ -395,13 +398,13 @@ func TestImporterPostImportUpdateTags(t *testing.T) { updateErr := errors.New("UpdateTags error") - galleryReaderWriter.On("UpdateTags", galleryID, []int{existingTagID}).Return(nil).Once() - galleryReaderWriter.On("UpdateTags", errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once() + galleryReaderWriter.On("UpdateTags", testCtx, galleryID, []int{existingTagID}).Return(nil).Once() + galleryReaderWriter.On("UpdateTags", testCtx, errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once() - err := i.PostImport(galleryID) + err := i.PostImport(testCtx, galleryID) assert.Nil(t, err) - err = i.PostImport(errTagsID) + err = i.PostImport(testCtx, errTagsID) assert.NotNil(t, err) galleryReaderWriter.AssertExpectations(t) @@ -419,23 +422,23 @@ func TestImporterFindExistingID(t *testing.T) { } expectedErr := errors.New("FindBy* error") - readerWriter.On("FindByChecksum", missingChecksum).Return(nil, nil).Once() - readerWriter.On("FindByChecksum", checksum).Return(&models.Gallery{ + readerWriter.On("FindByChecksum", testCtx, missingChecksum).Return(nil, nil).Once() + readerWriter.On("FindByChecksum", testCtx, checksum).Return(&models.Gallery{ ID: existingGalleryID, }, nil).Once() - readerWriter.On("FindByChecksum", errChecksum).Return(nil, expectedErr).Once() + readerWriter.On("FindByChecksum", testCtx, errChecksum).Return(nil, expectedErr).Once() - id, err := i.FindExistingID() + id, err := i.FindExistingID(testCtx) assert.Nil(t, id) assert.Nil(t, err) i.Input.Checksum = checksum - id, err = i.FindExistingID() + id, err = i.FindExistingID(testCtx) assert.Equal(t, existingGalleryID, *id) assert.Nil(t, err) i.Input.Checksum = errChecksum - id, err = i.FindExistingID() + id, err = i.FindExistingID(testCtx) assert.Nil(t, id) assert.NotNil(t, err) @@ -459,17 +462,17 @@ func TestCreate(t *testing.T) { } errCreate := errors.New("Create error") - readerWriter.On("Create", gallery).Return(&models.Gallery{ + readerWriter.On("Create", testCtx, gallery).Return(&models.Gallery{ ID: galleryID, }, nil).Once() - readerWriter.On("Create", galleryErr).Return(nil, errCreate).Once() + readerWriter.On("Create", testCtx, galleryErr).Return(nil, errCreate).Once() - id, err := i.Create() + id, err := i.Create(testCtx) assert.Equal(t, galleryID, *id) assert.Nil(t, err) i.gallery = galleryErr - id, err = i.Create() + id, err = i.Create(testCtx) assert.Nil(t, id) assert.NotNil(t, err) @@ -490,9 +493,9 @@ func TestUpdate(t *testing.T) { // id needs to be set for the mock input gallery.ID = galleryID - readerWriter.On("Update", gallery).Return(nil, nil).Once() + readerWriter.On("Update", testCtx, gallery).Return(nil, nil).Once() - err := i.Update(galleryID) + err := i.Update(testCtx, galleryID) assert.Nil(t, err) readerWriter.AssertExpectations(t) diff --git a/pkg/gallery/query.go b/pkg/gallery/query.go index f15e480f2..34065435b 100644 --- a/pkg/gallery/query.go +++ b/pkg/gallery/query.go @@ -1,12 +1,25 @@ package gallery import ( + "context" "strconv" "github.com/stashapp/stash/pkg/models" ) -func CountByPerformerID(r models.GalleryReader, id int) (int, error) { +type Queryer interface { + Query(ctx context.Context, galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) ([]*models.Gallery, int, error) +} + +type CountQueryer interface { + QueryCount(ctx context.Context, galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) (int, error) +} + +type ChecksumsFinder interface { + FindByChecksums(ctx context.Context, checksums []string) ([]*models.Gallery, error) +} + +func CountByPerformerID(ctx context.Context, r CountQueryer, id int) (int, error) { filter := &models.GalleryFilterType{ Performers: &models.MultiCriterionInput{ Value: []string{strconv.Itoa(id)}, @@ -14,10 +27,10 @@ func CountByPerformerID(r models.GalleryReader, id int) (int, error) { }, } - return r.QueryCount(filter, nil) + return r.QueryCount(ctx, filter, nil) } -func CountByStudioID(r models.GalleryReader, id int) (int, error) { +func CountByStudioID(ctx context.Context, r CountQueryer, id int) (int, error) { filter := &models.GalleryFilterType{ Studios: &models.HierarchicalMultiCriterionInput{ Value: []string{strconv.Itoa(id)}, @@ -25,10 +38,10 @@ func CountByStudioID(r models.GalleryReader, id int) (int, error) { }, } - return r.QueryCount(filter, nil) + return r.QueryCount(ctx, filter, nil) } -func CountByTagID(r models.GalleryReader, id int) (int, error) { +func CountByTagID(ctx context.Context, r CountQueryer, id int) (int, error) { filter := &models.GalleryFilterType{ Tags: &models.HierarchicalMultiCriterionInput{ Value: []string{strconv.Itoa(id)}, @@ -36,5 +49,5 @@ func CountByTagID(r models.GalleryReader, id int) (int, error) { }, } - return r.QueryCount(filter, nil) + return r.QueryCount(ctx, filter, nil) } diff --git a/pkg/gallery/scan.go b/pkg/gallery/scan.go index f45a26d77..643ba8988 100644 --- a/pkg/gallery/scan.go +++ b/pkg/gallery/scan.go @@ -14,18 +14,26 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/paths" "github.com/stashapp/stash/pkg/plugin" + "github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/utils" ) const mutexType = "gallery" +type FinderCreatorUpdater interface { + FindByChecksum(ctx context.Context, checksum string) (*models.Gallery, error) + Create(ctx context.Context, newGallery models.Gallery) (*models.Gallery, error) + Update(ctx context.Context, updatedGallery models.Gallery) (*models.Gallery, error) +} + type Scanner struct { file.Scanner ImageExtensions []string StripFileExtension bool CaseSensitiveFs bool - TxnManager models.TransactionManager + TxnManager txn.Manager + CreatorUpdater FinderCreatorUpdater Paths *paths.Paths PluginCache *plugin.Cache MutexManager *utils.MutexManager @@ -75,19 +83,19 @@ func (scanner *Scanner) ScanExisting(ctx context.Context, existing file.FileBase done := make(chan struct{}) scanner.MutexManager.Claim(mutexType, scanned.New.Checksum, done) - if err := scanner.TxnManager.WithTxn(ctx, func(r models.Repository) error { + if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error { // free the mutex once transaction is complete defer close(done) // ensure no clashes of hashes if scanned.New.Checksum != "" && scanned.Old.Checksum != scanned.New.Checksum { - dupe, _ := r.Gallery().FindByChecksum(retGallery.Checksum) + dupe, _ := scanner.CreatorUpdater.FindByChecksum(ctx, retGallery.Checksum) if dupe != nil { return fmt.Errorf("MD5 for file %s is the same as that of %s", path, dupe.Path.String) } } - retGallery, err = r.Gallery().Update(*retGallery) + retGallery, err = scanner.CreatorUpdater.Update(ctx, *retGallery) return err }); err != nil { return nil, false, err @@ -116,10 +124,10 @@ func (scanner *Scanner) ScanNew(ctx context.Context, file file.SourceFile) (retG scanner.MutexManager.Claim(mutexType, checksum, done) defer close(done) - if err := scanner.TxnManager.WithTxn(ctx, func(r models.Repository) error { - qb := r.Gallery() + if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error { + qb := scanner.CreatorUpdater - g, _ = qb.FindByChecksum(checksum) + g, _ = qb.FindByChecksum(ctx, checksum) if g != nil { exists, _ := fsutil.FileExists(g.Path.String) if !scanner.CaseSensitiveFs { @@ -138,7 +146,7 @@ func (scanner *Scanner) ScanNew(ctx context.Context, file file.SourceFile) (retG String: path, Valid: true, } - g, err = qb.Update(*g) + g, err = qb.Update(ctx, *g) if err != nil { return err } @@ -167,7 +175,7 @@ func (scanner *Scanner) ScanNew(ctx context.Context, file file.SourceFile) (retG } logger.Infof("%s doesn't exist. Creating new item...", path) - g, err = qb.Create(*g) + g, err = qb.Create(ctx, *g) if err != nil { return err } diff --git a/pkg/gallery/update.go b/pkg/gallery/update.go index 4c16793ca..1c94faea6 100644 --- a/pkg/gallery/update.go +++ b/pkg/gallery/update.go @@ -1,29 +1,50 @@ package gallery import ( + "context" + "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/sliceutil/intslice" ) -func UpdateFileModTime(qb models.GalleryWriter, id int, modTime models.NullSQLiteTimestamp) (*models.Gallery, error) { - return qb.UpdatePartial(models.GalleryPartial{ +type PartialUpdater interface { + UpdatePartial(ctx context.Context, updatedGallery models.GalleryPartial) (*models.Gallery, error) +} + +type ImageUpdater interface { + GetImageIDs(ctx context.Context, galleryID int) ([]int, error) + UpdateImages(ctx context.Context, galleryID int, imageIDs []int) error +} + +type PerformerUpdater interface { + GetPerformerIDs(ctx context.Context, galleryID int) ([]int, error) + UpdatePerformers(ctx context.Context, galleryID int, performerIDs []int) error +} + +type TagUpdater interface { + GetTagIDs(ctx context.Context, galleryID int) ([]int, error) + UpdateTags(ctx context.Context, galleryID int, tagIDs []int) error +} + +func UpdateFileModTime(ctx context.Context, qb PartialUpdater, id int, modTime models.NullSQLiteTimestamp) (*models.Gallery, error) { + return qb.UpdatePartial(ctx, models.GalleryPartial{ ID: id, FileModTime: &modTime, }) } -func AddImage(qb models.GalleryReaderWriter, galleryID int, imageID int) error { - imageIDs, err := qb.GetImageIDs(galleryID) +func AddImage(ctx context.Context, qb ImageUpdater, galleryID int, imageID int) error { + imageIDs, err := qb.GetImageIDs(ctx, galleryID) if err != nil { return err } imageIDs = intslice.IntAppendUnique(imageIDs, imageID) - return qb.UpdateImages(galleryID, imageIDs) + return qb.UpdateImages(ctx, galleryID, imageIDs) } -func AddPerformer(qb models.GalleryReaderWriter, id int, performerID int) (bool, error) { - performerIDs, err := qb.GetPerformerIDs(id) +func AddPerformer(ctx context.Context, qb PerformerUpdater, id int, performerID int) (bool, error) { + performerIDs, err := qb.GetPerformerIDs(ctx, id) if err != nil { return false, err } @@ -32,7 +53,7 @@ func AddPerformer(qb models.GalleryReaderWriter, id int, performerID int) (bool, performerIDs = intslice.IntAppendUnique(performerIDs, performerID) if len(performerIDs) != oldLen { - if err := qb.UpdatePerformers(id, performerIDs); err != nil { + if err := qb.UpdatePerformers(ctx, id, performerIDs); err != nil { return false, err } @@ -42,8 +63,8 @@ func AddPerformer(qb models.GalleryReaderWriter, id int, performerID int) (bool, return false, nil } -func AddTag(qb models.GalleryReaderWriter, id int, tagID int) (bool, error) { - tagIDs, err := qb.GetTagIDs(id) +func AddTag(ctx context.Context, qb TagUpdater, id int, tagID int) (bool, error) { + tagIDs, err := qb.GetTagIDs(ctx, id) if err != nil { return false, err } @@ -52,7 +73,7 @@ func AddTag(qb models.GalleryReaderWriter, id int, tagID int) (bool, error) { tagIDs = intslice.IntAppendUnique(tagIDs, tagID) if len(tagIDs) != oldLen { - if err := qb.UpdateTags(id, tagIDs); err != nil { + if err := qb.UpdateTags(ctx, id, tagIDs); err != nil { return false, err } diff --git a/pkg/image/delete.go b/pkg/image/delete.go index 35ab3704b..8e2ca8237 100644 --- a/pkg/image/delete.go +++ b/pkg/image/delete.go @@ -1,6 +1,8 @@ package image import ( + "context" + "github.com/stashapp/stash/pkg/file" "github.com/stashapp/stash/pkg/fsutil" "github.com/stashapp/stash/pkg/models" @@ -8,7 +10,7 @@ import ( ) type Destroyer interface { - Destroy(id int) error + Destroy(ctx context.Context, id int) error } // FileDeleter is an extension of file.Deleter that handles deletion of image files. @@ -30,7 +32,7 @@ func (d *FileDeleter) MarkGeneratedFiles(image *models.Image) error { } // Destroy destroys an image, optionally marking the file and generated files for deletion. -func Destroy(i *models.Image, destroyer Destroyer, fileDeleter *FileDeleter, deleteGenerated, deleteFile bool) error { +func Destroy(ctx context.Context, i *models.Image, destroyer Destroyer, fileDeleter *FileDeleter, deleteGenerated, deleteFile bool) error { // don't try to delete if the image is in a zip file if deleteFile && !file.IsZipPath(i.Path) { if err := fileDeleter.Files([]string{i.Path}); err != nil { @@ -44,5 +46,5 @@ func Destroy(i *models.Image, destroyer Destroyer, fileDeleter *FileDeleter, del } } - return destroyer.Destroy(i.ID) + return destroyer.Destroy(ctx, i.ID) } diff --git a/pkg/image/export.go b/pkg/image/export.go index 3938a39bf..da7306bdb 100644 --- a/pkg/image/export.go +++ b/pkg/image/export.go @@ -1,9 +1,12 @@ package image import ( + "context" + "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/json" "github.com/stashapp/stash/pkg/models/jsonschema" + "github.com/stashapp/stash/pkg/studio" ) // ToBasicJSON converts a image object into its JSON object equivalent. It @@ -56,9 +59,9 @@ func getImageFileJSON(image *models.Image) *jsonschema.ImageFile { // GetStudioName returns the name of the provided image's studio. It returns an // empty string if there is no studio assigned to the image. -func GetStudioName(reader models.StudioReader, image *models.Image) (string, error) { +func GetStudioName(ctx context.Context, reader studio.Finder, image *models.Image) (string, error) { if image.StudioID.Valid { - studio, err := reader.Find(int(image.StudioID.Int64)) + studio, err := reader.Find(ctx, int(image.StudioID.Int64)) if err != nil { return "", err } diff --git a/pkg/image/export_test.go b/pkg/image/export_test.go index 0a449c443..2aacac5ad 100644 --- a/pkg/image/export_test.go +++ b/pkg/image/export_test.go @@ -156,15 +156,15 @@ func TestGetStudioName(t *testing.T) { studioErr := errors.New("error getting image") - mockStudioReader.On("Find", studioID).Return(&models.Studio{ + mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{ Name: models.NullString(studioName), }, nil).Once() - mockStudioReader.On("Find", missingStudioID).Return(nil, nil).Once() - mockStudioReader.On("Find", errStudioID).Return(nil, studioErr).Once() + mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once() + mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once() for i, s := range getStudioScenarios { image := s.input - json, err := GetStudioName(mockStudioReader, &image) + json, err := GetStudioName(testCtx, mockStudioReader, &image) switch { case !s.err && err != nil: diff --git a/pkg/image/import.go b/pkg/image/import.go index 78b60c4b1..d1de6b2a5 100644 --- a/pkg/image/import.go +++ b/pkg/image/import.go @@ -1,21 +1,33 @@ package image import ( + "context" "database/sql" "fmt" "strings" + "github.com/stashapp/stash/pkg/gallery" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/jsonschema" + "github.com/stashapp/stash/pkg/performer" "github.com/stashapp/stash/pkg/sliceutil/stringslice" + "github.com/stashapp/stash/pkg/studio" + "github.com/stashapp/stash/pkg/tag" ) +type FullCreatorUpdater interface { + FinderCreatorUpdater + UpdatePerformers(ctx context.Context, imageID int, performerIDs []int) error + UpdateTags(ctx context.Context, imageID int, tagIDs []int) error + UpdateGalleries(ctx context.Context, imageID int, galleryIDs []int) error +} + type Importer struct { - ReaderWriter models.ImageReaderWriter - StudioWriter models.StudioReaderWriter - GalleryWriter models.GalleryReaderWriter - PerformerWriter models.PerformerReaderWriter - TagWriter models.TagReaderWriter + ReaderWriter FullCreatorUpdater + StudioWriter studio.NameFinderCreator + GalleryWriter gallery.ChecksumsFinder + PerformerWriter performer.NameFinderCreator + TagWriter tag.NameFinderCreator Input jsonschema.Image Path string MissingRefBehaviour models.ImportMissingRefEnum @@ -27,22 +39,22 @@ type Importer struct { tags []*models.Tag } -func (i *Importer) PreImport() error { +func (i *Importer) PreImport(ctx context.Context) error { i.image = i.imageJSONToImage(i.Input) - if err := i.populateStudio(); err != nil { + if err := i.populateStudio(ctx); err != nil { return err } - if err := i.populateGalleries(); err != nil { + if err := i.populateGalleries(ctx); err != nil { return err } - if err := i.populatePerformers(); err != nil { + if err := i.populatePerformers(ctx); err != nil { return err } - if err := i.populateTags(); err != nil { + if err := i.populateTags(ctx); err != nil { return err } @@ -82,9 +94,9 @@ func (i *Importer) imageJSONToImage(imageJSON jsonschema.Image) models.Image { return newImage } -func (i *Importer) populateStudio() error { +func (i *Importer) populateStudio(ctx context.Context) error { if i.Input.Studio != "" { - studio, err := i.StudioWriter.FindByName(i.Input.Studio, false) + studio, err := i.StudioWriter.FindByName(ctx, i.Input.Studio, false) if err != nil { return fmt.Errorf("error finding studio by name: %v", err) } @@ -99,7 +111,7 @@ func (i *Importer) populateStudio() error { } if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate { - studioID, err := i.createStudio(i.Input.Studio) + studioID, err := i.createStudio(ctx, i.Input.Studio) if err != nil { return err } @@ -116,10 +128,10 @@ func (i *Importer) populateStudio() error { return nil } -func (i *Importer) createStudio(name string) (int, error) { +func (i *Importer) createStudio(ctx context.Context, name string) (int, error) { newStudio := *models.NewStudio(name) - created, err := i.StudioWriter.Create(newStudio) + created, err := i.StudioWriter.Create(ctx, newStudio) if err != nil { return 0, err } @@ -127,14 +139,14 @@ func (i *Importer) createStudio(name string) (int, error) { return created.ID, nil } -func (i *Importer) populateGalleries() error { +func (i *Importer) populateGalleries(ctx context.Context) error { for _, checksum := range i.Input.Galleries { - gallery, err := i.GalleryWriter.FindByChecksum(checksum) + gallery, err := i.GalleryWriter.FindByChecksums(ctx, []string{checksum}) if err != nil { return fmt.Errorf("error finding gallery: %v", err) } - if gallery == nil { + if len(gallery) == 0 { if i.MissingRefBehaviour == models.ImportMissingRefEnumFail { return fmt.Errorf("image gallery '%s' not found", i.Input.Studio) } @@ -144,17 +156,17 @@ func (i *Importer) populateGalleries() error { continue } } else { - i.galleries = append(i.galleries, gallery) + i.galleries = append(i.galleries, gallery[0]) } } return nil } -func (i *Importer) populatePerformers() error { +func (i *Importer) populatePerformers(ctx context.Context) error { if len(i.Input.Performers) > 0 { names := i.Input.Performers - performers, err := i.PerformerWriter.FindByNames(names, false) + performers, err := i.PerformerWriter.FindByNames(ctx, names, false) if err != nil { return err } @@ -177,7 +189,7 @@ func (i *Importer) populatePerformers() error { } if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate { - createdPerformers, err := i.createPerformers(missingPerformers) + createdPerformers, err := i.createPerformers(ctx, missingPerformers) if err != nil { return fmt.Errorf("error creating image performers: %v", err) } @@ -194,12 +206,12 @@ func (i *Importer) populatePerformers() error { return nil } -func (i *Importer) createPerformers(names []string) ([]*models.Performer, error) { +func (i *Importer) createPerformers(ctx context.Context, names []string) ([]*models.Performer, error) { var ret []*models.Performer for _, name := range names { newPerformer := *models.NewPerformer(name) - created, err := i.PerformerWriter.Create(newPerformer) + created, err := i.PerformerWriter.Create(ctx, newPerformer) if err != nil { return nil, err } @@ -210,10 +222,10 @@ func (i *Importer) createPerformers(names []string) ([]*models.Performer, error) return ret, nil } -func (i *Importer) populateTags() error { +func (i *Importer) populateTags(ctx context.Context) error { if len(i.Input.Tags) > 0 { - tags, err := importTags(i.TagWriter, i.Input.Tags, i.MissingRefBehaviour) + tags, err := importTags(ctx, i.TagWriter, i.Input.Tags, i.MissingRefBehaviour) if err != nil { return err } @@ -224,14 +236,14 @@ func (i *Importer) populateTags() error { return nil } -func (i *Importer) PostImport(id int) error { +func (i *Importer) PostImport(ctx context.Context, id int) error { if len(i.galleries) > 0 { var galleryIDs []int for _, g := range i.galleries { galleryIDs = append(galleryIDs, g.ID) } - if err := i.ReaderWriter.UpdateGalleries(id, galleryIDs); err != nil { + if err := i.ReaderWriter.UpdateGalleries(ctx, id, galleryIDs); err != nil { return fmt.Errorf("failed to associate galleries: %v", err) } } @@ -242,7 +254,7 @@ func (i *Importer) PostImport(id int) error { performerIDs = append(performerIDs, performer.ID) } - if err := i.ReaderWriter.UpdatePerformers(id, performerIDs); err != nil { + if err := i.ReaderWriter.UpdatePerformers(ctx, id, performerIDs); err != nil { return fmt.Errorf("failed to associate performers: %v", err) } } @@ -252,7 +264,7 @@ func (i *Importer) PostImport(id int) error { for _, t := range i.tags { tagIDs = append(tagIDs, t.ID) } - if err := i.ReaderWriter.UpdateTags(id, tagIDs); err != nil { + if err := i.ReaderWriter.UpdateTags(ctx, id, tagIDs); err != nil { return fmt.Errorf("failed to associate tags: %v", err) } } @@ -264,10 +276,10 @@ func (i *Importer) Name() string { return i.Path } -func (i *Importer) FindExistingID() (*int, error) { +func (i *Importer) FindExistingID(ctx context.Context) (*int, error) { var existing *models.Image var err error - existing, err = i.ReaderWriter.FindByChecksum(i.Input.Checksum) + existing, err = i.ReaderWriter.FindByChecksum(ctx, i.Input.Checksum) if err != nil { return nil, err @@ -281,8 +293,8 @@ func (i *Importer) FindExistingID() (*int, error) { return nil, nil } -func (i *Importer) Create() (*int, error) { - created, err := i.ReaderWriter.Create(i.image) +func (i *Importer) Create(ctx context.Context) (*int, error) { + created, err := i.ReaderWriter.Create(ctx, i.image) if err != nil { return nil, fmt.Errorf("error creating image: %v", err) } @@ -292,11 +304,11 @@ func (i *Importer) Create() (*int, error) { return &id, nil } -func (i *Importer) Update(id int) error { +func (i *Importer) Update(ctx context.Context, id int) error { image := i.image image.ID = id i.ID = id - _, err := i.ReaderWriter.UpdateFull(image) + _, err := i.ReaderWriter.UpdateFull(ctx, image) if err != nil { return fmt.Errorf("error updating existing image: %v", err) } @@ -304,8 +316,8 @@ func (i *Importer) Update(id int) error { return nil } -func importTags(tagWriter models.TagReaderWriter, names []string, missingRefBehaviour models.ImportMissingRefEnum) ([]*models.Tag, error) { - tags, err := tagWriter.FindByNames(names, false) +func importTags(ctx context.Context, tagWriter tag.NameFinderCreator, names []string, missingRefBehaviour models.ImportMissingRefEnum) ([]*models.Tag, error) { + tags, err := tagWriter.FindByNames(ctx, names, false) if err != nil { return nil, err } @@ -325,7 +337,7 @@ func importTags(tagWriter models.TagReaderWriter, names []string, missingRefBeha } if missingRefBehaviour == models.ImportMissingRefEnumCreate { - createdTags, err := createTags(tagWriter, missingTags) + createdTags, err := createTags(ctx, tagWriter, missingTags) if err != nil { return nil, fmt.Errorf("error creating tags: %v", err) } @@ -339,12 +351,12 @@ func importTags(tagWriter models.TagReaderWriter, names []string, missingRefBeha return tags, nil } -func createTags(tagWriter models.TagWriter, names []string) ([]*models.Tag, error) { +func createTags(ctx context.Context, tagWriter tag.NameFinderCreator, names []string) ([]*models.Tag, error) { var ret []*models.Tag for _, name := range names { newTag := *models.NewTag(name) - created, err := tagWriter.Create(newTag) + created, err := tagWriter.Create(ctx, newTag) if err != nil { return nil, err } diff --git a/pkg/image/import_test.go b/pkg/image/import_test.go index 156ec96d2..856c338c1 100644 --- a/pkg/image/import_test.go +++ b/pkg/image/import_test.go @@ -1,6 +1,7 @@ package image import ( + "context" "errors" "testing" @@ -47,6 +48,8 @@ const ( errChecksum = "errChecksum" ) +var testCtx = context.Background() + func TestImporterName(t *testing.T) { i := Importer{ Path: path, @@ -61,7 +64,7 @@ func TestImporterPreImport(t *testing.T) { Path: path, } - err := i.PreImport() + err := i.PreImport(testCtx) assert.Nil(t, err) } @@ -76,17 +79,17 @@ func TestImporterPreImportWithStudio(t *testing.T) { }, } - studioReaderWriter.On("FindByName", existingStudioName, false).Return(&models.Studio{ + studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{ ID: existingStudioID, }, nil).Once() - studioReaderWriter.On("FindByName", existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once() + studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once() - err := i.PreImport() + err := i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, int64(existingStudioID), i.image.StudioID.Int64) i.Input.Studio = existingStudioErr - err = i.PreImport() + err = i.PreImport(testCtx) assert.NotNil(t, err) studioReaderWriter.AssertExpectations(t) @@ -104,20 +107,20 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - studioReaderWriter.On("FindByName", missingStudioName, false).Return(nil, nil).Times(3) - studioReaderWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(&models.Studio{ + studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) + studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Studio")).Return(&models.Studio{ ID: existingStudioID, }, nil) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumCreate - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, int64(existingStudioID), i.image.StudioID.Int64) @@ -136,10 +139,10 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - studioReaderWriter.On("FindByName", missingStudioName, false).Return(nil, nil).Once() - studioReaderWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error")) + studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() + studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error")) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) } @@ -156,12 +159,12 @@ func TestImporterPreImportWithGallery(t *testing.T) { }, } - galleryReaderWriter.On("FindByChecksum", existingGalleryChecksum).Return(&models.Gallery{ + galleryReaderWriter.On("FindByChecksums", testCtx, []string{existingGalleryChecksum}).Return([]*models.Gallery{{ ID: existingGalleryID, - }, nil).Once() - galleryReaderWriter.On("FindByChecksum", existingGalleryErr).Return(nil, errors.New("FindByChecksum error")).Once() + }}, nil).Once() + galleryReaderWriter.On("FindByChecksums", testCtx, []string{existingGalleryErr}).Return(nil, errors.New("FindByChecksum error")).Once() - err := i.PreImport() + err := i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, existingGalleryID, i.galleries[0].ID) @@ -169,7 +172,7 @@ func TestImporterPreImportWithGallery(t *testing.T) { existingGalleryErr, } - err = i.PreImport() + err = i.PreImport(testCtx) assert.NotNil(t, err) galleryReaderWriter.AssertExpectations(t) @@ -189,18 +192,18 @@ func TestImporterPreImportWithMissingGallery(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - galleryReaderWriter.On("FindByChecksum", missingGalleryChecksum).Return(nil, nil).Times(3) + galleryReaderWriter.On("FindByChecksums", testCtx, []string{missingGalleryChecksum}).Return(nil, nil).Times(3) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) assert.Nil(t, i.galleries) i.MissingRefBehaviour = models.ImportMissingRefEnumCreate - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) assert.Nil(t, i.galleries) @@ -221,20 +224,20 @@ func TestImporterPreImportWithPerformer(t *testing.T) { }, } - performerReaderWriter.On("FindByNames", []string{existingPerformerName}, false).Return([]*models.Performer{ + performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{ { ID: existingPerformerID, Name: models.NullString(existingPerformerName), }, }, nil).Once() - performerReaderWriter.On("FindByNames", []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() + performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() - err := i.PreImport() + err := i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, existingPerformerID, i.performers[0].ID) i.Input.Performers = []string{existingPerformerErr} - err = i.PreImport() + err = i.PreImport(testCtx) assert.NotNil(t, err) performerReaderWriter.AssertExpectations(t) @@ -254,20 +257,20 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - performerReaderWriter.On("FindByNames", []string{missingPerformerName}, false).Return(nil, nil).Times(3) - performerReaderWriter.On("Create", mock.AnythingOfType("models.Performer")).Return(&models.Performer{ + performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3) + performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Performer")).Return(&models.Performer{ ID: existingPerformerID, }, nil) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumCreate - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, existingPerformerID, i.performers[0].ID) @@ -288,10 +291,10 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - performerReaderWriter.On("FindByNames", []string{missingPerformerName}, false).Return(nil, nil).Once() - performerReaderWriter.On("Create", mock.AnythingOfType("models.Performer")).Return(nil, errors.New("Create error")) + performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once() + performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Performer")).Return(nil, errors.New("Create error")) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) } @@ -309,20 +312,20 @@ func TestImporterPreImportWithTag(t *testing.T) { }, } - tagReaderWriter.On("FindByNames", []string{existingTagName}, false).Return([]*models.Tag{ + tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{ { ID: existingTagID, Name: existingTagName, }, }, nil).Once() - tagReaderWriter.On("FindByNames", []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once() + tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once() - err := i.PreImport() + err := i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, existingTagID, i.tags[0].ID) i.Input.Tags = []string{existingTagErr} - err = i.PreImport() + err = i.PreImport(testCtx) assert.NotNil(t, err) tagReaderWriter.AssertExpectations(t) @@ -342,20 +345,20 @@ func TestImporterPreImportWithMissingTag(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - tagReaderWriter.On("FindByNames", []string{missingTagName}, false).Return(nil, nil).Times(3) - tagReaderWriter.On("Create", mock.AnythingOfType("models.Tag")).Return(&models.Tag{ + tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3) + tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(&models.Tag{ ID: existingTagID, }, nil) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumCreate - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, existingTagID, i.tags[0].ID) @@ -376,10 +379,10 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - tagReaderWriter.On("FindByNames", []string{missingTagName}, false).Return(nil, nil).Once() - tagReaderWriter.On("Create", mock.AnythingOfType("models.Tag")).Return(nil, errors.New("Create error")) + tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once() + tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(nil, errors.New("Create error")) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) } @@ -397,13 +400,13 @@ func TestImporterPostImportUpdateGallery(t *testing.T) { updateErr := errors.New("UpdateGalleries error") - readerWriter.On("UpdateGalleries", imageID, []int{existingGalleryID}).Return(nil).Once() - readerWriter.On("UpdateGalleries", errGalleriesID, mock.AnythingOfType("[]int")).Return(updateErr).Once() + readerWriter.On("UpdateGalleries", testCtx, imageID, []int{existingGalleryID}).Return(nil).Once() + readerWriter.On("UpdateGalleries", testCtx, errGalleriesID, mock.AnythingOfType("[]int")).Return(updateErr).Once() - err := i.PostImport(imageID) + err := i.PostImport(testCtx, imageID) assert.Nil(t, err) - err = i.PostImport(errGalleriesID) + err = i.PostImport(testCtx, errGalleriesID) assert.NotNil(t, err) readerWriter.AssertExpectations(t) @@ -423,13 +426,13 @@ func TestImporterPostImportUpdatePerformers(t *testing.T) { updateErr := errors.New("UpdatePerformers error") - readerWriter.On("UpdatePerformers", imageID, []int{existingPerformerID}).Return(nil).Once() - readerWriter.On("UpdatePerformers", errPerformersID, mock.AnythingOfType("[]int")).Return(updateErr).Once() + readerWriter.On("UpdatePerformers", testCtx, imageID, []int{existingPerformerID}).Return(nil).Once() + readerWriter.On("UpdatePerformers", testCtx, errPerformersID, mock.AnythingOfType("[]int")).Return(updateErr).Once() - err := i.PostImport(imageID) + err := i.PostImport(testCtx, imageID) assert.Nil(t, err) - err = i.PostImport(errPerformersID) + err = i.PostImport(testCtx, errPerformersID) assert.NotNil(t, err) readerWriter.AssertExpectations(t) @@ -449,13 +452,13 @@ func TestImporterPostImportUpdateTags(t *testing.T) { updateErr := errors.New("UpdateTags error") - readerWriter.On("UpdateTags", imageID, []int{existingTagID}).Return(nil).Once() - readerWriter.On("UpdateTags", errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once() + readerWriter.On("UpdateTags", testCtx, imageID, []int{existingTagID}).Return(nil).Once() + readerWriter.On("UpdateTags", testCtx, errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once() - err := i.PostImport(imageID) + err := i.PostImport(testCtx, imageID) assert.Nil(t, err) - err = i.PostImport(errTagsID) + err = i.PostImport(testCtx, errTagsID) assert.NotNil(t, err) readerWriter.AssertExpectations(t) @@ -473,23 +476,23 @@ func TestImporterFindExistingID(t *testing.T) { } expectedErr := errors.New("FindBy* error") - readerWriter.On("FindByChecksum", missingChecksum).Return(nil, nil).Once() - readerWriter.On("FindByChecksum", checksum).Return(&models.Image{ + readerWriter.On("FindByChecksum", testCtx, missingChecksum).Return(nil, nil).Once() + readerWriter.On("FindByChecksum", testCtx, checksum).Return(&models.Image{ ID: existingImageID, }, nil).Once() - readerWriter.On("FindByChecksum", errChecksum).Return(nil, expectedErr).Once() + readerWriter.On("FindByChecksum", testCtx, errChecksum).Return(nil, expectedErr).Once() - id, err := i.FindExistingID() + id, err := i.FindExistingID(testCtx) assert.Nil(t, id) assert.Nil(t, err) i.Input.Checksum = checksum - id, err = i.FindExistingID() + id, err = i.FindExistingID(testCtx) assert.Equal(t, existingImageID, *id) assert.Nil(t, err) i.Input.Checksum = errChecksum - id, err = i.FindExistingID() + id, err = i.FindExistingID(testCtx) assert.Nil(t, id) assert.NotNil(t, err) @@ -513,18 +516,18 @@ func TestCreate(t *testing.T) { } errCreate := errors.New("Create error") - readerWriter.On("Create", image).Return(&models.Image{ + readerWriter.On("Create", testCtx, image).Return(&models.Image{ ID: imageID, }, nil).Once() - readerWriter.On("Create", imageErr).Return(nil, errCreate).Once() + readerWriter.On("Create", testCtx, imageErr).Return(nil, errCreate).Once() - id, err := i.Create() + id, err := i.Create(testCtx) assert.Equal(t, imageID, *id) assert.Nil(t, err) assert.Equal(t, imageID, i.ID) i.image = imageErr - id, err = i.Create() + id, err = i.Create(testCtx) assert.Nil(t, id) assert.NotNil(t, err) @@ -551,9 +554,9 @@ func TestUpdate(t *testing.T) { // id needs to be set for the mock input image.ID = imageID - readerWriter.On("UpdateFull", image).Return(nil, nil).Once() + readerWriter.On("UpdateFull", testCtx, image).Return(nil, nil).Once() - err := i.Update(imageID) + err := i.Update(testCtx, imageID) assert.Nil(t, err) assert.Equal(t, imageID, i.ID) @@ -561,9 +564,9 @@ func TestUpdate(t *testing.T) { // need to set id separately imageErr.ID = errImageID - readerWriter.On("UpdateFull", imageErr).Return(nil, errUpdate).Once() + readerWriter.On("UpdateFull", testCtx, imageErr).Return(nil, errUpdate).Once() - err = i.Update(errImageID) + err = i.Update(testCtx, errImageID) assert.NotNil(t, err) readerWriter.AssertExpectations(t) diff --git a/pkg/image/query.go b/pkg/image/query.go index 058d0a842..36ed3a8c3 100644 --- a/pkg/image/query.go +++ b/pkg/image/query.go @@ -1,13 +1,18 @@ package image import ( + "context" "strconv" "github.com/stashapp/stash/pkg/models" ) type Queryer interface { - Query(options models.ImageQueryOptions) (*models.ImageQueryResult, error) + Query(ctx context.Context, options models.ImageQueryOptions) (*models.ImageQueryResult, error) +} + +type CountQueryer interface { + QueryCount(ctx context.Context, imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) (int, error) } // QueryOptions returns a ImageQueryResult populated with the provided filters. @@ -22,13 +27,13 @@ func QueryOptions(imageFilter *models.ImageFilterType, findFilter *models.FindFi } // Query queries for images using the provided filters. -func Query(qb Queryer, imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) ([]*models.Image, error) { - result, err := qb.Query(QueryOptions(imageFilter, findFilter, false)) +func Query(ctx context.Context, qb Queryer, imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) ([]*models.Image, error) { + result, err := qb.Query(ctx, QueryOptions(imageFilter, findFilter, false)) if err != nil { return nil, err } - images, err := result.Resolve() + images, err := result.Resolve(ctx) if err != nil { return nil, err } @@ -36,7 +41,7 @@ func Query(qb Queryer, imageFilter *models.ImageFilterType, findFilter *models.F return images, nil } -func CountByPerformerID(r models.ImageReader, id int) (int, error) { +func CountByPerformerID(ctx context.Context, r CountQueryer, id int) (int, error) { filter := &models.ImageFilterType{ Performers: &models.MultiCriterionInput{ Value: []string{strconv.Itoa(id)}, @@ -44,10 +49,10 @@ func CountByPerformerID(r models.ImageReader, id int) (int, error) { }, } - return r.QueryCount(filter, nil) + return r.QueryCount(ctx, filter, nil) } -func CountByStudioID(r models.ImageReader, id int) (int, error) { +func CountByStudioID(ctx context.Context, r CountQueryer, id int) (int, error) { filter := &models.ImageFilterType{ Studios: &models.HierarchicalMultiCriterionInput{ Value: []string{strconv.Itoa(id)}, @@ -55,10 +60,10 @@ func CountByStudioID(r models.ImageReader, id int) (int, error) { }, } - return r.QueryCount(filter, nil) + return r.QueryCount(ctx, filter, nil) } -func CountByTagID(r models.ImageReader, id int) (int, error) { +func CountByTagID(ctx context.Context, r CountQueryer, id int) (int, error) { filter := &models.ImageFilterType{ Tags: &models.HierarchicalMultiCriterionInput{ Value: []string{strconv.Itoa(id)}, @@ -66,10 +71,10 @@ func CountByTagID(r models.ImageReader, id int) (int, error) { }, } - return r.QueryCount(filter, nil) + return r.QueryCount(ctx, filter, nil) } -func FindByGalleryID(r models.ImageReader, galleryID int, sortBy string, sortDir models.SortDirectionEnum) ([]*models.Image, error) { +func FindByGalleryID(ctx context.Context, r Queryer, galleryID int, sortBy string, sortDir models.SortDirectionEnum) ([]*models.Image, error) { perPage := -1 findFilter := models.FindFilterType{ @@ -84,7 +89,7 @@ func FindByGalleryID(r models.ImageReader, galleryID int, sortBy string, sortDir findFilter.Direction = &sortDir } - return Query(r, &models.ImageFilterType{ + return Query(ctx, r, &models.ImageFilterType{ Galleries: &models.MultiCriterionInput{ Value: []string{strconv.Itoa(galleryID)}, Modifier: models.CriterionModifierIncludes, diff --git a/pkg/image/scan.go b/pkg/image/scan.go index 8fa2f24a6..751f41000 100644 --- a/pkg/image/scan.go +++ b/pkg/image/scan.go @@ -12,18 +12,27 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/paths" "github.com/stashapp/stash/pkg/plugin" + "github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/utils" ) const mutexType = "image" +type FinderCreatorUpdater interface { + FindByChecksum(ctx context.Context, checksum string) (*models.Image, error) + Create(ctx context.Context, newImage models.Image) (*models.Image, error) + UpdateFull(ctx context.Context, updatedImage models.Image) (*models.Image, error) + Update(ctx context.Context, updatedImage models.ImagePartial) (*models.Image, error) +} + type Scanner struct { file.Scanner StripFileExtension bool CaseSensitiveFs bool - TxnManager models.TransactionManager + TxnManager txn.Manager + CreatorUpdater FinderCreatorUpdater Paths *paths.Paths PluginCache *plugin.Cache MutexManager *utils.MutexManager @@ -71,20 +80,20 @@ func (scanner *Scanner) ScanExisting(ctx context.Context, existing file.FileBase done := make(chan struct{}) scanner.MutexManager.Claim(mutexType, scanned.New.Checksum, done) - if err := scanner.TxnManager.WithTxn(ctx, func(r models.Repository) error { + if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error { // free the mutex once transaction is complete defer close(done) var err error // ensure no clashes of hashes if scanned.New.Checksum != "" && scanned.Old.Checksum != scanned.New.Checksum { - dupe, _ := r.Image().FindByChecksum(i.Checksum) + dupe, _ := scanner.CreatorUpdater.FindByChecksum(ctx, i.Checksum) if dupe != nil { return fmt.Errorf("MD5 for file %s is the same as that of %s", path, dupe.Path) } } - retImage, err = r.Image().UpdateFull(*i) + retImage, err = scanner.CreatorUpdater.UpdateFull(ctx, *i) return err }); err != nil { return nil, err @@ -121,9 +130,9 @@ func (scanner *Scanner) ScanNew(ctx context.Context, f file.SourceFile) (retImag // check for image by checksum var existingImage *models.Image - if err := scanner.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { + if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error { var err error - existingImage, err = r.Image().FindByChecksum(checksum) + existingImage, err = scanner.CreatorUpdater.FindByChecksum(ctx, checksum) return err }); err != nil { return nil, err @@ -151,8 +160,8 @@ func (scanner *Scanner) ScanNew(ctx context.Context, f file.SourceFile) (retImag Path: &path, } - if err := scanner.TxnManager.WithTxn(ctx, func(r models.Repository) error { - retImage, err = r.Image().Update(imagePartial) + if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error { + retImage, err = scanner.CreatorUpdater.Update(ctx, imagePartial) return err }); err != nil { return nil, err @@ -176,9 +185,9 @@ func (scanner *Scanner) ScanNew(ctx context.Context, f file.SourceFile) (retImag return nil, err } - if err := scanner.TxnManager.WithTxn(ctx, func(r models.Repository) error { + if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error { var err error - retImage, err = r.Image().Create(newImage) + retImage, err = scanner.CreatorUpdater.Create(ctx, newImage) return err }); err != nil { return nil, err diff --git a/pkg/image/update.go b/pkg/image/update.go index f8c43f965..1b1b22535 100644 --- a/pkg/image/update.go +++ b/pkg/image/update.go @@ -1,19 +1,35 @@ package image import ( + "context" + "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/sliceutil/intslice" ) -func UpdateFileModTime(qb models.ImageWriter, id int, modTime models.NullSQLiteTimestamp) (*models.Image, error) { - return qb.Update(models.ImagePartial{ +type PartialUpdater interface { + Update(ctx context.Context, updatedImage models.ImagePartial) (*models.Image, error) +} + +type PerformerUpdater interface { + GetPerformerIDs(ctx context.Context, imageID int) ([]int, error) + UpdatePerformers(ctx context.Context, imageID int, performerIDs []int) error +} + +type TagUpdater interface { + GetTagIDs(ctx context.Context, imageID int) ([]int, error) + UpdateTags(ctx context.Context, imageID int, tagIDs []int) error +} + +func UpdateFileModTime(ctx context.Context, qb PartialUpdater, id int, modTime models.NullSQLiteTimestamp) (*models.Image, error) { + return qb.Update(ctx, models.ImagePartial{ ID: id, FileModTime: &modTime, }) } -func AddPerformer(qb models.ImageReaderWriter, id int, performerID int) (bool, error) { - performerIDs, err := qb.GetPerformerIDs(id) +func AddPerformer(ctx context.Context, qb PerformerUpdater, id int, performerID int) (bool, error) { + performerIDs, err := qb.GetPerformerIDs(ctx, id) if err != nil { return false, err } @@ -22,7 +38,7 @@ func AddPerformer(qb models.ImageReaderWriter, id int, performerID int) (bool, e performerIDs = intslice.IntAppendUnique(performerIDs, performerID) if len(performerIDs) != oldLen { - if err := qb.UpdatePerformers(id, performerIDs); err != nil { + if err := qb.UpdatePerformers(ctx, id, performerIDs); err != nil { return false, err } @@ -32,8 +48,8 @@ func AddPerformer(qb models.ImageReaderWriter, id int, performerID int) (bool, e return false, nil } -func AddTag(qb models.ImageReaderWriter, id int, tagID int) (bool, error) { - tagIDs, err := qb.GetTagIDs(id) +func AddTag(ctx context.Context, qb TagUpdater, id int, tagID int) (bool, error) { + tagIDs, err := qb.GetTagIDs(ctx, id) if err != nil { return false, err } @@ -42,7 +58,7 @@ func AddTag(qb models.ImageReaderWriter, id int, tagID int) (bool, error) { tagIDs = intslice.IntAppendUnique(tagIDs, tagID) if len(tagIDs) != oldLen { - if err := qb.UpdateTags(id, tagIDs); err != nil { + if err := qb.UpdateTags(ctx, id, tagIDs); err != nil { return false, err } diff --git a/pkg/match/cache.go b/pkg/match/cache.go index 6d7238809..06237c7f6 100644 --- a/pkg/match/cache.go +++ b/pkg/match/cache.go @@ -1,6 +1,10 @@ package match -import "github.com/stashapp/stash/pkg/models" +import ( + "context" + + "github.com/stashapp/stash/pkg/models" +) const singleFirstCharacterRegex = `^[\p{L}][.\-_ ]` @@ -16,14 +20,14 @@ type Cache struct { // against. This means that performers with single-letter words in their names could potentially // be missed. // This query is expensive, so it's queried once and cached, if the cache if provided. -func getSingleLetterPerformers(c *Cache, reader models.PerformerReader) ([]*models.Performer, error) { +func getSingleLetterPerformers(ctx context.Context, c *Cache, reader PerformerAutoTagQueryer) ([]*models.Performer, error) { if c == nil { c = &Cache{} } if c.singleCharPerformers == nil { pp := -1 - performers, _, err := reader.Query(&models.PerformerFilterType{ + performers, _, err := reader.Query(ctx, &models.PerformerFilterType{ Name: &models.StringCriterionInput{ Value: singleFirstCharacterRegex, Modifier: models.CriterionModifierMatchesRegex, @@ -49,14 +53,14 @@ func getSingleLetterPerformers(c *Cache, reader models.PerformerReader) ([]*mode // getSingleLetterStudios returns all studios with names that start with single character words. // See getSingleLetterPerformers for details. -func getSingleLetterStudios(c *Cache, reader models.StudioReader) ([]*models.Studio, error) { +func getSingleLetterStudios(ctx context.Context, c *Cache, reader StudioAutoTagQueryer) ([]*models.Studio, error) { if c == nil { c = &Cache{} } if c.singleCharStudios == nil { pp := -1 - studios, _, err := reader.Query(&models.StudioFilterType{ + studios, _, err := reader.Query(ctx, &models.StudioFilterType{ Name: &models.StringCriterionInput{ Value: singleFirstCharacterRegex, Modifier: models.CriterionModifierMatchesRegex, @@ -82,14 +86,14 @@ func getSingleLetterStudios(c *Cache, reader models.StudioReader) ([]*models.Stu // getSingleLetterTags returns all tags with names that start with single character words. // See getSingleLetterPerformers for details. -func getSingleLetterTags(c *Cache, reader models.TagReader) ([]*models.Tag, error) { +func getSingleLetterTags(ctx context.Context, c *Cache, reader TagAutoTagQueryer) ([]*models.Tag, error) { if c == nil { c = &Cache{} } if c.singleCharTags == nil { pp := -1 - tags, _, err := reader.Query(&models.TagFilterType{ + tags, _, err := reader.Query(ctx, &models.TagFilterType{ Name: &models.StringCriterionInput{ Value: singleFirstCharacterRegex, Modifier: models.CriterionModifierMatchesRegex, diff --git a/pkg/match/path.go b/pkg/match/path.go index 4f20423dd..a20678834 100644 --- a/pkg/match/path.go +++ b/pkg/match/path.go @@ -1,6 +1,7 @@ package match import ( + "context" "fmt" "path/filepath" "regexp" @@ -12,6 +13,8 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/scene" "github.com/stashapp/stash/pkg/sliceutil/stringslice" + "github.com/stashapp/stash/pkg/studio" + "github.com/stashapp/stash/pkg/tag" ) const ( @@ -24,6 +27,23 @@ const ( var separatorRE = regexp.MustCompile(separatorPattern) +type PerformerAutoTagQueryer interface { + Query(ctx context.Context, performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) ([]*models.Performer, int, error) + QueryForAutoTag(ctx context.Context, words []string) ([]*models.Performer, error) +} + +type StudioAutoTagQueryer interface { + QueryForAutoTag(ctx context.Context, words []string) ([]*models.Studio, error) + studio.Queryer + GetAliases(ctx context.Context, studioID int) ([]string, error) +} + +type TagAutoTagQueryer interface { + QueryForAutoTag(ctx context.Context, words []string) ([]*models.Tag, error) + tag.Queryer + GetAliases(ctx context.Context, tagID int) ([]string, error) +} + func getPathQueryRegex(name string) string { // escape specific regex characters name = regexp.QuoteMeta(name) @@ -124,13 +144,13 @@ func regexpMatchesPath(r *regexp.Regexp, path string) int { return found[len(found)-1][0] } -func getPerformers(words []string, performerReader models.PerformerReader, cache *Cache) ([]*models.Performer, error) { - performers, err := performerReader.QueryForAutoTag(words) +func getPerformers(ctx context.Context, words []string, performerReader PerformerAutoTagQueryer, cache *Cache) ([]*models.Performer, error) { + performers, err := performerReader.QueryForAutoTag(ctx, words) if err != nil { return nil, err } - swPerformers, err := getSingleLetterPerformers(cache, performerReader) + swPerformers, err := getSingleLetterPerformers(ctx, cache, performerReader) if err != nil { return nil, err } @@ -138,10 +158,10 @@ func getPerformers(words []string, performerReader models.PerformerReader, cache return append(performers, swPerformers...), nil } -func PathToPerformers(path string, reader models.PerformerReader, cache *Cache, trimExt bool) ([]*models.Performer, error) { +func PathToPerformers(ctx context.Context, path string, reader PerformerAutoTagQueryer, cache *Cache, trimExt bool) ([]*models.Performer, error) { words := getPathWords(path, trimExt) - performers, err := getPerformers(words, reader, cache) + performers, err := getPerformers(ctx, words, reader, cache) if err != nil { return nil, err } @@ -157,13 +177,13 @@ func PathToPerformers(path string, reader models.PerformerReader, cache *Cache, return ret, nil } -func getStudios(words []string, reader models.StudioReader, cache *Cache) ([]*models.Studio, error) { - studios, err := reader.QueryForAutoTag(words) +func getStudios(ctx context.Context, words []string, reader StudioAutoTagQueryer, cache *Cache) ([]*models.Studio, error) { + studios, err := reader.QueryForAutoTag(ctx, words) if err != nil { return nil, err } - swStudios, err := getSingleLetterStudios(cache, reader) + swStudios, err := getSingleLetterStudios(ctx, cache, reader) if err != nil { return nil, err } @@ -174,9 +194,9 @@ func getStudios(words []string, reader models.StudioReader, cache *Cache) ([]*mo // PathToStudio returns the Studio that matches the given path. // Where multiple matching studios are found, the one that matches the latest // position in the path is returned. -func PathToStudio(path string, reader models.StudioReader, cache *Cache, trimExt bool) (*models.Studio, error) { +func PathToStudio(ctx context.Context, path string, reader StudioAutoTagQueryer, cache *Cache, trimExt bool) (*models.Studio, error) { words := getPathWords(path, trimExt) - candidates, err := getStudios(words, reader, cache) + candidates, err := getStudios(ctx, words, reader, cache) if err != nil { return nil, err @@ -191,7 +211,7 @@ func PathToStudio(path string, reader models.StudioReader, cache *Cache, trimExt index = matchIndex } - aliases, err := reader.GetAliases(c.ID) + aliases, err := reader.GetAliases(ctx, c.ID) if err != nil { return nil, err } @@ -208,13 +228,13 @@ func PathToStudio(path string, reader models.StudioReader, cache *Cache, trimExt return ret, nil } -func getTags(words []string, reader models.TagReader, cache *Cache) ([]*models.Tag, error) { - tags, err := reader.QueryForAutoTag(words) +func getTags(ctx context.Context, words []string, reader TagAutoTagQueryer, cache *Cache) ([]*models.Tag, error) { + tags, err := reader.QueryForAutoTag(ctx, words) if err != nil { return nil, err } - swTags, err := getSingleLetterTags(cache, reader) + swTags, err := getSingleLetterTags(ctx, cache, reader) if err != nil { return nil, err } @@ -222,9 +242,9 @@ func getTags(words []string, reader models.TagReader, cache *Cache) ([]*models.T return append(tags, swTags...), nil } -func PathToTags(path string, reader models.TagReader, cache *Cache, trimExt bool) ([]*models.Tag, error) { +func PathToTags(ctx context.Context, path string, reader TagAutoTagQueryer, cache *Cache, trimExt bool) ([]*models.Tag, error) { words := getPathWords(path, trimExt) - tags, err := getTags(words, reader, cache) + tags, err := getTags(ctx, words, reader, cache) if err != nil { return nil, err @@ -238,7 +258,7 @@ func PathToTags(path string, reader models.TagReader, cache *Cache, trimExt bool } if !matches { - aliases, err := reader.GetAliases(t.ID) + aliases, err := reader.GetAliases(ctx, t.ID) if err != nil { return nil, err } @@ -258,7 +278,7 @@ func PathToTags(path string, reader models.TagReader, cache *Cache, trimExt bool return ret, nil } -func PathToScenes(name string, paths []string, sceneReader models.SceneReader) ([]*models.Scene, error) { +func PathToScenes(ctx context.Context, name string, paths []string, sceneReader scene.Queryer) ([]*models.Scene, error) { regex := getPathQueryRegex(name) organized := false filter := models.SceneFilterType{ @@ -272,7 +292,7 @@ func PathToScenes(name string, paths []string, sceneReader models.SceneReader) ( filter.And = scene.PathsFilter(paths) pp := models.PerPageAll - scenes, err := scene.Query(sceneReader, &filter, &models.FindFilterType{ + scenes, err := scene.Query(ctx, sceneReader, &filter, &models.FindFilterType{ PerPage: &pp, }) @@ -295,7 +315,7 @@ func PathToScenes(name string, paths []string, sceneReader models.SceneReader) ( return ret, nil } -func PathToImages(name string, paths []string, imageReader models.ImageReader) ([]*models.Image, error) { +func PathToImages(ctx context.Context, name string, paths []string, imageReader image.Queryer) ([]*models.Image, error) { regex := getPathQueryRegex(name) organized := false filter := models.ImageFilterType{ @@ -309,7 +329,7 @@ func PathToImages(name string, paths []string, imageReader models.ImageReader) ( filter.And = image.PathsFilter(paths) pp := models.PerPageAll - images, err := image.Query(imageReader, &filter, &models.FindFilterType{ + images, err := image.Query(ctx, imageReader, &filter, &models.FindFilterType{ PerPage: &pp, }) @@ -332,7 +352,7 @@ func PathToImages(name string, paths []string, imageReader models.ImageReader) ( return ret, nil } -func PathToGalleries(name string, paths []string, galleryReader models.GalleryReader) ([]*models.Gallery, error) { +func PathToGalleries(ctx context.Context, name string, paths []string, galleryReader gallery.Queryer) ([]*models.Gallery, error) { regex := getPathQueryRegex(name) organized := false filter := models.GalleryFilterType{ @@ -346,7 +366,7 @@ func PathToGalleries(name string, paths []string, galleryReader models.GalleryRe filter.And = gallery.PathsFilter(paths) pp := models.PerPageAll - gallerys, _, err := galleryReader.Query(&filter, &models.FindFilterType{ + gallerys, _, err := galleryReader.Query(ctx, &filter, &models.FindFilterType{ PerPage: &pp, }) diff --git a/pkg/match/scraped.go b/pkg/match/scraped.go index 1e9de81e1..d1182a329 100644 --- a/pkg/match/scraped.go +++ b/pkg/match/scraped.go @@ -1,6 +1,7 @@ package match import ( + "context" "strconv" "github.com/stashapp/stash/pkg/models" @@ -8,16 +9,25 @@ import ( "github.com/stashapp/stash/pkg/tag" ) +type PerformerFinder interface { + FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Performer, error) + FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Performer, error) +} + +type MovieNamesFinder interface { + FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Movie, error) +} + // ScrapedPerformer matches the provided performer with the // performers in the database and sets the ID field if one is found. -func ScrapedPerformer(qb models.PerformerReader, p *models.ScrapedPerformer, stashBoxEndpoint *string) error { +func ScrapedPerformer(ctx context.Context, qb PerformerFinder, p *models.ScrapedPerformer, stashBoxEndpoint *string) error { if p.StoredID != nil || p.Name == nil { return nil } // Check if a performer with the StashID already exists if stashBoxEndpoint != nil && p.RemoteSiteID != nil { - performers, err := qb.FindByStashID(models.StashID{ + performers, err := qb.FindByStashID(ctx, models.StashID{ StashID: *p.RemoteSiteID, Endpoint: *stashBoxEndpoint, }) @@ -31,7 +41,7 @@ func ScrapedPerformer(qb models.PerformerReader, p *models.ScrapedPerformer, sta } } - performers, err := qb.FindByNames([]string{*p.Name}, true) + performers, err := qb.FindByNames(ctx, []string{*p.Name}, true) if err != nil { return err @@ -47,16 +57,21 @@ func ScrapedPerformer(qb models.PerformerReader, p *models.ScrapedPerformer, sta return nil } +type StudioFinder interface { + studio.Queryer + FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Studio, error) +} + // ScrapedStudio matches the provided studio with the studios // in the database and sets the ID field if one is found. -func ScrapedStudio(qb models.StudioReader, s *models.ScrapedStudio, stashBoxEndpoint *string) error { +func ScrapedStudio(ctx context.Context, qb StudioFinder, s *models.ScrapedStudio, stashBoxEndpoint *string) error { if s.StoredID != nil { return nil } // Check if a studio with the StashID already exists if stashBoxEndpoint != nil && s.RemoteSiteID != nil { - studios, err := qb.FindByStashID(models.StashID{ + studios, err := qb.FindByStashID(ctx, models.StashID{ StashID: *s.RemoteSiteID, Endpoint: *stashBoxEndpoint, }) @@ -70,7 +85,7 @@ func ScrapedStudio(qb models.StudioReader, s *models.ScrapedStudio, stashBoxEndp } } - st, err := studio.ByName(qb, s.Name) + st, err := studio.ByName(ctx, qb, s.Name) if err != nil { return err @@ -78,7 +93,7 @@ func ScrapedStudio(qb models.StudioReader, s *models.ScrapedStudio, stashBoxEndp if st == nil { // try matching by alias - st, err = studio.ByAlias(qb, s.Name) + st, err = studio.ByAlias(ctx, qb, s.Name) if err != nil { return err } @@ -96,12 +111,12 @@ func ScrapedStudio(qb models.StudioReader, s *models.ScrapedStudio, stashBoxEndp // ScrapedMovie matches the provided movie with the movies // in the database and sets the ID field if one is found. -func ScrapedMovie(qb models.MovieReader, m *models.ScrapedMovie) error { +func ScrapedMovie(ctx context.Context, qb MovieNamesFinder, m *models.ScrapedMovie) error { if m.StoredID != nil || m.Name == nil { return nil } - movies, err := qb.FindByNames([]string{*m.Name}, true) + movies, err := qb.FindByNames(ctx, []string{*m.Name}, true) if err != nil { return err @@ -119,12 +134,12 @@ func ScrapedMovie(qb models.MovieReader, m *models.ScrapedMovie) error { // ScrapedTag matches the provided tag with the tags // in the database and sets the ID field if one is found. -func ScrapedTag(qb models.TagReader, s *models.ScrapedTag) error { +func ScrapedTag(ctx context.Context, qb tag.Queryer, s *models.ScrapedTag) error { if s.StoredID != nil { return nil } - t, err := tag.ByName(qb, s.Name) + t, err := tag.ByName(ctx, qb, s.Name) if err != nil { return err @@ -132,7 +147,7 @@ func ScrapedTag(qb models.TagReader, s *models.ScrapedTag) error { if t == nil { // try matching by alias - t, err = tag.ByAlias(qb, s.Name) + t, err = tag.ByAlias(ctx, qb, s.Name) if err != nil { return err } diff --git a/pkg/models/gallery.go b/pkg/models/gallery.go index 3c85e7193..676b61937 100644 --- a/pkg/models/gallery.go +++ b/pkg/models/gallery.go @@ -1,5 +1,7 @@ package models +import "context" + type GalleryFilterType struct { And *GalleryFilterType `json:"AND"` Or *GalleryFilterType `json:"OR"` @@ -67,33 +69,33 @@ type GalleryDestroyInput struct { } type GalleryReader interface { - Find(id int) (*Gallery, error) - FindMany(ids []int) ([]*Gallery, error) - FindByChecksum(checksum string) (*Gallery, error) - FindByChecksums(checksums []string) ([]*Gallery, error) - FindByPath(path string) (*Gallery, error) - FindBySceneID(sceneID int) ([]*Gallery, error) - FindByImageID(imageID int) ([]*Gallery, error) - Count() (int, error) - All() ([]*Gallery, error) - Query(galleryFilter *GalleryFilterType, findFilter *FindFilterType) ([]*Gallery, int, error) - QueryCount(galleryFilter *GalleryFilterType, findFilter *FindFilterType) (int, error) - GetPerformerIDs(galleryID int) ([]int, error) - GetTagIDs(galleryID int) ([]int, error) - GetSceneIDs(galleryID int) ([]int, error) - GetImageIDs(galleryID int) ([]int, error) + Find(ctx context.Context, id int) (*Gallery, error) + FindMany(ctx context.Context, ids []int) ([]*Gallery, error) + FindByChecksum(ctx context.Context, checksum string) (*Gallery, error) + FindByChecksums(ctx context.Context, checksums []string) ([]*Gallery, error) + FindByPath(ctx context.Context, path string) (*Gallery, error) + FindBySceneID(ctx context.Context, sceneID int) ([]*Gallery, error) + FindByImageID(ctx context.Context, imageID int) ([]*Gallery, error) + Count(ctx context.Context) (int, error) + All(ctx context.Context) ([]*Gallery, error) + Query(ctx context.Context, galleryFilter *GalleryFilterType, findFilter *FindFilterType) ([]*Gallery, int, error) + QueryCount(ctx context.Context, galleryFilter *GalleryFilterType, findFilter *FindFilterType) (int, error) + GetPerformerIDs(ctx context.Context, galleryID int) ([]int, error) + GetTagIDs(ctx context.Context, galleryID int) ([]int, error) + GetSceneIDs(ctx context.Context, galleryID int) ([]int, error) + GetImageIDs(ctx context.Context, galleryID int) ([]int, error) } type GalleryWriter interface { - Create(newGallery Gallery) (*Gallery, error) - Update(updatedGallery Gallery) (*Gallery, error) - UpdatePartial(updatedGallery GalleryPartial) (*Gallery, error) - UpdateFileModTime(id int, modTime NullSQLiteTimestamp) error - Destroy(id int) error - UpdatePerformers(galleryID int, performerIDs []int) error - UpdateTags(galleryID int, tagIDs []int) error - UpdateScenes(galleryID int, sceneIDs []int) error - UpdateImages(galleryID int, imageIDs []int) error + Create(ctx context.Context, newGallery Gallery) (*Gallery, error) + Update(ctx context.Context, updatedGallery Gallery) (*Gallery, error) + UpdatePartial(ctx context.Context, updatedGallery GalleryPartial) (*Gallery, error) + UpdateFileModTime(ctx context.Context, id int, modTime NullSQLiteTimestamp) error + Destroy(ctx context.Context, id int) error + UpdatePerformers(ctx context.Context, galleryID int, performerIDs []int) error + UpdateTags(ctx context.Context, galleryID int, tagIDs []int) error + UpdateScenes(ctx context.Context, galleryID int, sceneIDs []int) error + UpdateImages(ctx context.Context, galleryID int, imageIDs []int) error } type GalleryReaderWriter interface { diff --git a/pkg/models/image.go b/pkg/models/image.go index 4b28c4c36..4509ef709 100644 --- a/pkg/models/image.go +++ b/pkg/models/image.go @@ -1,5 +1,7 @@ package models +import "context" + type ImageFilterType struct { And *ImageFilterType `json:"AND"` Or *ImageFilterType `json:"OR"` @@ -73,54 +75,54 @@ func NewImageQueryResult(finder ImageFinder) *ImageQueryResult { } } -func (r *ImageQueryResult) Resolve() ([]*Image, error) { +func (r *ImageQueryResult) Resolve(ctx context.Context) ([]*Image, error) { // cache results if r.images == nil && r.resolveErr == nil { - r.images, r.resolveErr = r.finder.FindMany(r.IDs) + r.images, r.resolveErr = r.finder.FindMany(ctx, r.IDs) } return r.images, r.resolveErr } type ImageFinder interface { // TODO - rename to Find and remove existing method - FindMany(ids []int) ([]*Image, error) + FindMany(ctx context.Context, ids []int) ([]*Image, error) } type ImageReader interface { ImageFinder // TODO - remove this in another PR - Find(id int) (*Image, error) - FindByChecksum(checksum string) (*Image, error) - FindByGalleryID(galleryID int) ([]*Image, error) - CountByGalleryID(galleryID int) (int, error) - FindByPath(path string) (*Image, error) + Find(ctx context.Context, id int) (*Image, error) + FindByChecksum(ctx context.Context, checksum string) (*Image, error) + FindByGalleryID(ctx context.Context, galleryID int) ([]*Image, error) + CountByGalleryID(ctx context.Context, galleryID int) (int, error) + FindByPath(ctx context.Context, path string) (*Image, error) // FindByPerformerID(performerID int) ([]*Image, error) // CountByPerformerID(performerID int) (int, error) // FindByStudioID(studioID int) ([]*Image, error) - Count() (int, error) - Size() (float64, error) + Count(ctx context.Context) (int, error) + Size(ctx context.Context) (float64, error) // SizeCount() (string, error) // CountByStudioID(studioID int) (int, error) // CountByTagID(tagID int) (int, error) - All() ([]*Image, error) - Query(options ImageQueryOptions) (*ImageQueryResult, error) - QueryCount(imageFilter *ImageFilterType, findFilter *FindFilterType) (int, error) - GetGalleryIDs(imageID int) ([]int, error) - GetTagIDs(imageID int) ([]int, error) - GetPerformerIDs(imageID int) ([]int, error) + All(ctx context.Context) ([]*Image, error) + Query(ctx context.Context, options ImageQueryOptions) (*ImageQueryResult, error) + QueryCount(ctx context.Context, imageFilter *ImageFilterType, findFilter *FindFilterType) (int, error) + GetGalleryIDs(ctx context.Context, imageID int) ([]int, error) + GetTagIDs(ctx context.Context, imageID int) ([]int, error) + GetPerformerIDs(ctx context.Context, imageID int) ([]int, error) } type ImageWriter interface { - Create(newImage Image) (*Image, error) - Update(updatedImage ImagePartial) (*Image, error) - UpdateFull(updatedImage Image) (*Image, error) - IncrementOCounter(id int) (int, error) - DecrementOCounter(id int) (int, error) - ResetOCounter(id int) (int, error) - Destroy(id int) error - UpdateGalleries(imageID int, galleryIDs []int) error - UpdatePerformers(imageID int, performerIDs []int) error - UpdateTags(imageID int, tagIDs []int) error + Create(ctx context.Context, newImage Image) (*Image, error) + Update(ctx context.Context, updatedImage ImagePartial) (*Image, error) + UpdateFull(ctx context.Context, updatedImage Image) (*Image, error) + IncrementOCounter(ctx context.Context, id int) (int, error) + DecrementOCounter(ctx context.Context, id int) (int, error) + ResetOCounter(ctx context.Context, id int) (int, error) + Destroy(ctx context.Context, id int) error + UpdateGalleries(ctx context.Context, imageID int, galleryIDs []int) error + UpdatePerformers(ctx context.Context, imageID int, performerIDs []int) error + UpdateTags(ctx context.Context, imageID int, tagIDs []int) error } type ImageReaderWriter interface { diff --git a/pkg/models/mocks/GalleryReaderWriter.go b/pkg/models/mocks/GalleryReaderWriter.go index 9731147fe..ee8ec643d 100644 --- a/pkg/models/mocks/GalleryReaderWriter.go +++ b/pkg/models/mocks/GalleryReaderWriter.go @@ -3,6 +3,8 @@ package mocks import ( + context "context" + models "github.com/stashapp/stash/pkg/models" mock "github.com/stretchr/testify/mock" ) @@ -12,13 +14,13 @@ type GalleryReaderWriter struct { mock.Mock } -// All provides a mock function with given fields: -func (_m *GalleryReaderWriter) All() ([]*models.Gallery, error) { - ret := _m.Called() +// All provides a mock function with given fields: ctx +func (_m *GalleryReaderWriter) All(ctx context.Context) ([]*models.Gallery, error) { + ret := _m.Called(ctx) var r0 []*models.Gallery - if rf, ok := ret.Get(0).(func() []*models.Gallery); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) []*models.Gallery); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Gallery) @@ -26,8 +28,8 @@ func (_m *GalleryReaderWriter) All() ([]*models.Gallery, error) { } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -35,20 +37,20 @@ func (_m *GalleryReaderWriter) All() ([]*models.Gallery, error) { return r0, r1 } -// Count provides a mock function with given fields: -func (_m *GalleryReaderWriter) Count() (int, error) { - ret := _m.Called() +// Count provides a mock function with given fields: ctx +func (_m *GalleryReaderWriter) Count(ctx context.Context) (int, error) { + ret := _m.Called(ctx) var r0 int - if rf, ok := ret.Get(0).(func() int); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) int); ok { + r0 = rf(ctx) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -56,13 +58,13 @@ func (_m *GalleryReaderWriter) Count() (int, error) { return r0, r1 } -// Create provides a mock function with given fields: newGallery -func (_m *GalleryReaderWriter) Create(newGallery models.Gallery) (*models.Gallery, error) { - ret := _m.Called(newGallery) +// Create provides a mock function with given fields: ctx, newGallery +func (_m *GalleryReaderWriter) Create(ctx context.Context, newGallery models.Gallery) (*models.Gallery, error) { + ret := _m.Called(ctx, newGallery) var r0 *models.Gallery - if rf, ok := ret.Get(0).(func(models.Gallery) *models.Gallery); ok { - r0 = rf(newGallery) + if rf, ok := ret.Get(0).(func(context.Context, models.Gallery) *models.Gallery); ok { + r0 = rf(ctx, newGallery) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Gallery) @@ -70,8 +72,8 @@ func (_m *GalleryReaderWriter) Create(newGallery models.Gallery) (*models.Galler } var r1 error - if rf, ok := ret.Get(1).(func(models.Gallery) error); ok { - r1 = rf(newGallery) + if rf, ok := ret.Get(1).(func(context.Context, models.Gallery) error); ok { + r1 = rf(ctx, newGallery) } else { r1 = ret.Error(1) } @@ -79,13 +81,13 @@ func (_m *GalleryReaderWriter) Create(newGallery models.Gallery) (*models.Galler return r0, r1 } -// Destroy provides a mock function with given fields: id -func (_m *GalleryReaderWriter) Destroy(id int) error { - ret := _m.Called(id) +// Destroy provides a mock function with given fields: ctx, id +func (_m *GalleryReaderWriter) Destroy(ctx context.Context, id int) error { + ret := _m.Called(ctx, id) var r0 error - if rf, ok := ret.Get(0).(func(int) error); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) error); ok { + r0 = rf(ctx, id) } else { r0 = ret.Error(0) } @@ -93,13 +95,13 @@ func (_m *GalleryReaderWriter) Destroy(id int) error { return r0 } -// Find provides a mock function with given fields: id -func (_m *GalleryReaderWriter) Find(id int) (*models.Gallery, error) { - ret := _m.Called(id) +// Find provides a mock function with given fields: ctx, id +func (_m *GalleryReaderWriter) Find(ctx context.Context, id int) (*models.Gallery, error) { + ret := _m.Called(ctx, id) var r0 *models.Gallery - if rf, ok := ret.Get(0).(func(int) *models.Gallery); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) *models.Gallery); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Gallery) @@ -107,8 +109,8 @@ func (_m *GalleryReaderWriter) Find(id int) (*models.Gallery, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -116,13 +118,13 @@ func (_m *GalleryReaderWriter) Find(id int) (*models.Gallery, error) { return r0, r1 } -// FindByChecksum provides a mock function with given fields: checksum -func (_m *GalleryReaderWriter) FindByChecksum(checksum string) (*models.Gallery, error) { - ret := _m.Called(checksum) +// FindByChecksum provides a mock function with given fields: ctx, checksum +func (_m *GalleryReaderWriter) FindByChecksum(ctx context.Context, checksum string) (*models.Gallery, error) { + ret := _m.Called(ctx, checksum) var r0 *models.Gallery - if rf, ok := ret.Get(0).(func(string) *models.Gallery); ok { - r0 = rf(checksum) + if rf, ok := ret.Get(0).(func(context.Context, string) *models.Gallery); ok { + r0 = rf(ctx, checksum) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Gallery) @@ -130,8 +132,8 @@ func (_m *GalleryReaderWriter) FindByChecksum(checksum string) (*models.Gallery, } var r1 error - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(checksum) + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, checksum) } else { r1 = ret.Error(1) } @@ -139,13 +141,13 @@ func (_m *GalleryReaderWriter) FindByChecksum(checksum string) (*models.Gallery, return r0, r1 } -// FindByChecksums provides a mock function with given fields: checksums -func (_m *GalleryReaderWriter) FindByChecksums(checksums []string) ([]*models.Gallery, error) { - ret := _m.Called(checksums) +// FindByChecksums provides a mock function with given fields: ctx, checksums +func (_m *GalleryReaderWriter) FindByChecksums(ctx context.Context, checksums []string) ([]*models.Gallery, error) { + ret := _m.Called(ctx, checksums) var r0 []*models.Gallery - if rf, ok := ret.Get(0).(func([]string) []*models.Gallery); ok { - r0 = rf(checksums) + if rf, ok := ret.Get(0).(func(context.Context, []string) []*models.Gallery); ok { + r0 = rf(ctx, checksums) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Gallery) @@ -153,8 +155,8 @@ func (_m *GalleryReaderWriter) FindByChecksums(checksums []string) ([]*models.Ga } var r1 error - if rf, ok := ret.Get(1).(func([]string) error); ok { - r1 = rf(checksums) + if rf, ok := ret.Get(1).(func(context.Context, []string) error); ok { + r1 = rf(ctx, checksums) } else { r1 = ret.Error(1) } @@ -162,13 +164,13 @@ func (_m *GalleryReaderWriter) FindByChecksums(checksums []string) ([]*models.Ga return r0, r1 } -// FindByImageID provides a mock function with given fields: imageID -func (_m *GalleryReaderWriter) FindByImageID(imageID int) ([]*models.Gallery, error) { - ret := _m.Called(imageID) +// FindByImageID provides a mock function with given fields: ctx, imageID +func (_m *GalleryReaderWriter) FindByImageID(ctx context.Context, imageID int) ([]*models.Gallery, error) { + ret := _m.Called(ctx, imageID) var r0 []*models.Gallery - if rf, ok := ret.Get(0).(func(int) []*models.Gallery); ok { - r0 = rf(imageID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Gallery); ok { + r0 = rf(ctx, imageID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Gallery) @@ -176,8 +178,8 @@ func (_m *GalleryReaderWriter) FindByImageID(imageID int) ([]*models.Gallery, er } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(imageID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, imageID) } else { r1 = ret.Error(1) } @@ -185,13 +187,13 @@ func (_m *GalleryReaderWriter) FindByImageID(imageID int) ([]*models.Gallery, er return r0, r1 } -// FindByPath provides a mock function with given fields: path -func (_m *GalleryReaderWriter) FindByPath(path string) (*models.Gallery, error) { - ret := _m.Called(path) +// FindByPath provides a mock function with given fields: ctx, path +func (_m *GalleryReaderWriter) FindByPath(ctx context.Context, path string) (*models.Gallery, error) { + ret := _m.Called(ctx, path) var r0 *models.Gallery - if rf, ok := ret.Get(0).(func(string) *models.Gallery); ok { - r0 = rf(path) + if rf, ok := ret.Get(0).(func(context.Context, string) *models.Gallery); ok { + r0 = rf(ctx, path) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Gallery) @@ -199,8 +201,8 @@ func (_m *GalleryReaderWriter) FindByPath(path string) (*models.Gallery, error) } var r1 error - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(path) + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, path) } else { r1 = ret.Error(1) } @@ -208,13 +210,13 @@ func (_m *GalleryReaderWriter) FindByPath(path string) (*models.Gallery, error) return r0, r1 } -// FindBySceneID provides a mock function with given fields: sceneID -func (_m *GalleryReaderWriter) FindBySceneID(sceneID int) ([]*models.Gallery, error) { - ret := _m.Called(sceneID) +// FindBySceneID provides a mock function with given fields: ctx, sceneID +func (_m *GalleryReaderWriter) FindBySceneID(ctx context.Context, sceneID int) ([]*models.Gallery, error) { + ret := _m.Called(ctx, sceneID) var r0 []*models.Gallery - if rf, ok := ret.Get(0).(func(int) []*models.Gallery); ok { - r0 = rf(sceneID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Gallery); ok { + r0 = rf(ctx, sceneID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Gallery) @@ -222,8 +224,8 @@ func (_m *GalleryReaderWriter) FindBySceneID(sceneID int) ([]*models.Gallery, er } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(sceneID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, sceneID) } else { r1 = ret.Error(1) } @@ -231,13 +233,13 @@ func (_m *GalleryReaderWriter) FindBySceneID(sceneID int) ([]*models.Gallery, er return r0, r1 } -// FindMany provides a mock function with given fields: ids -func (_m *GalleryReaderWriter) FindMany(ids []int) ([]*models.Gallery, error) { - ret := _m.Called(ids) +// FindMany provides a mock function with given fields: ctx, ids +func (_m *GalleryReaderWriter) FindMany(ctx context.Context, ids []int) ([]*models.Gallery, error) { + ret := _m.Called(ctx, ids) var r0 []*models.Gallery - if rf, ok := ret.Get(0).(func([]int) []*models.Gallery); ok { - r0 = rf(ids) + if rf, ok := ret.Get(0).(func(context.Context, []int) []*models.Gallery); ok { + r0 = rf(ctx, ids) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Gallery) @@ -245,8 +247,8 @@ func (_m *GalleryReaderWriter) FindMany(ids []int) ([]*models.Gallery, error) { } var r1 error - if rf, ok := ret.Get(1).(func([]int) error); ok { - r1 = rf(ids) + if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok { + r1 = rf(ctx, ids) } else { r1 = ret.Error(1) } @@ -254,13 +256,13 @@ func (_m *GalleryReaderWriter) FindMany(ids []int) ([]*models.Gallery, error) { return r0, r1 } -// GetImageIDs provides a mock function with given fields: galleryID -func (_m *GalleryReaderWriter) GetImageIDs(galleryID int) ([]int, error) { - ret := _m.Called(galleryID) +// GetImageIDs provides a mock function with given fields: ctx, galleryID +func (_m *GalleryReaderWriter) GetImageIDs(ctx context.Context, galleryID int) ([]int, error) { + ret := _m.Called(ctx, galleryID) var r0 []int - if rf, ok := ret.Get(0).(func(int) []int); ok { - r0 = rf(galleryID) + if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok { + r0 = rf(ctx, galleryID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]int) @@ -268,8 +270,8 @@ func (_m *GalleryReaderWriter) GetImageIDs(galleryID int) ([]int, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(galleryID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, galleryID) } else { r1 = ret.Error(1) } @@ -277,13 +279,13 @@ func (_m *GalleryReaderWriter) GetImageIDs(galleryID int) ([]int, error) { return r0, r1 } -// GetPerformerIDs provides a mock function with given fields: galleryID -func (_m *GalleryReaderWriter) GetPerformerIDs(galleryID int) ([]int, error) { - ret := _m.Called(galleryID) +// GetPerformerIDs provides a mock function with given fields: ctx, galleryID +func (_m *GalleryReaderWriter) GetPerformerIDs(ctx context.Context, galleryID int) ([]int, error) { + ret := _m.Called(ctx, galleryID) var r0 []int - if rf, ok := ret.Get(0).(func(int) []int); ok { - r0 = rf(galleryID) + if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok { + r0 = rf(ctx, galleryID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]int) @@ -291,8 +293,8 @@ func (_m *GalleryReaderWriter) GetPerformerIDs(galleryID int) ([]int, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(galleryID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, galleryID) } else { r1 = ret.Error(1) } @@ -300,13 +302,13 @@ func (_m *GalleryReaderWriter) GetPerformerIDs(galleryID int) ([]int, error) { return r0, r1 } -// GetSceneIDs provides a mock function with given fields: galleryID -func (_m *GalleryReaderWriter) GetSceneIDs(galleryID int) ([]int, error) { - ret := _m.Called(galleryID) +// GetSceneIDs provides a mock function with given fields: ctx, galleryID +func (_m *GalleryReaderWriter) GetSceneIDs(ctx context.Context, galleryID int) ([]int, error) { + ret := _m.Called(ctx, galleryID) var r0 []int - if rf, ok := ret.Get(0).(func(int) []int); ok { - r0 = rf(galleryID) + if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok { + r0 = rf(ctx, galleryID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]int) @@ -314,8 +316,8 @@ func (_m *GalleryReaderWriter) GetSceneIDs(galleryID int) ([]int, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(galleryID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, galleryID) } else { r1 = ret.Error(1) } @@ -323,13 +325,13 @@ func (_m *GalleryReaderWriter) GetSceneIDs(galleryID int) ([]int, error) { return r0, r1 } -// GetTagIDs provides a mock function with given fields: galleryID -func (_m *GalleryReaderWriter) GetTagIDs(galleryID int) ([]int, error) { - ret := _m.Called(galleryID) +// GetTagIDs provides a mock function with given fields: ctx, galleryID +func (_m *GalleryReaderWriter) GetTagIDs(ctx context.Context, galleryID int) ([]int, error) { + ret := _m.Called(ctx, galleryID) var r0 []int - if rf, ok := ret.Get(0).(func(int) []int); ok { - r0 = rf(galleryID) + if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok { + r0 = rf(ctx, galleryID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]int) @@ -337,8 +339,8 @@ func (_m *GalleryReaderWriter) GetTagIDs(galleryID int) ([]int, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(galleryID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, galleryID) } else { r1 = ret.Error(1) } @@ -346,13 +348,13 @@ func (_m *GalleryReaderWriter) GetTagIDs(galleryID int) ([]int, error) { return r0, r1 } -// Query provides a mock function with given fields: galleryFilter, findFilter -func (_m *GalleryReaderWriter) Query(galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) ([]*models.Gallery, int, error) { - ret := _m.Called(galleryFilter, findFilter) +// Query provides a mock function with given fields: ctx, galleryFilter, findFilter +func (_m *GalleryReaderWriter) Query(ctx context.Context, galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) ([]*models.Gallery, int, error) { + ret := _m.Called(ctx, galleryFilter, findFilter) var r0 []*models.Gallery - if rf, ok := ret.Get(0).(func(*models.GalleryFilterType, *models.FindFilterType) []*models.Gallery); ok { - r0 = rf(galleryFilter, findFilter) + if rf, ok := ret.Get(0).(func(context.Context, *models.GalleryFilterType, *models.FindFilterType) []*models.Gallery); ok { + r0 = rf(ctx, galleryFilter, findFilter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Gallery) @@ -360,15 +362,15 @@ func (_m *GalleryReaderWriter) Query(galleryFilter *models.GalleryFilterType, fi } var r1 int - if rf, ok := ret.Get(1).(func(*models.GalleryFilterType, *models.FindFilterType) int); ok { - r1 = rf(galleryFilter, findFilter) + if rf, ok := ret.Get(1).(func(context.Context, *models.GalleryFilterType, *models.FindFilterType) int); ok { + r1 = rf(ctx, galleryFilter, findFilter) } else { r1 = ret.Get(1).(int) } var r2 error - if rf, ok := ret.Get(2).(func(*models.GalleryFilterType, *models.FindFilterType) error); ok { - r2 = rf(galleryFilter, findFilter) + if rf, ok := ret.Get(2).(func(context.Context, *models.GalleryFilterType, *models.FindFilterType) error); ok { + r2 = rf(ctx, galleryFilter, findFilter) } else { r2 = ret.Error(2) } @@ -376,20 +378,20 @@ func (_m *GalleryReaderWriter) Query(galleryFilter *models.GalleryFilterType, fi return r0, r1, r2 } -// QueryCount provides a mock function with given fields: galleryFilter, findFilter -func (_m *GalleryReaderWriter) QueryCount(galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) (int, error) { - ret := _m.Called(galleryFilter, findFilter) +// QueryCount provides a mock function with given fields: ctx, galleryFilter, findFilter +func (_m *GalleryReaderWriter) QueryCount(ctx context.Context, galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) (int, error) { + ret := _m.Called(ctx, galleryFilter, findFilter) var r0 int - if rf, ok := ret.Get(0).(func(*models.GalleryFilterType, *models.FindFilterType) int); ok { - r0 = rf(galleryFilter, findFilter) + if rf, ok := ret.Get(0).(func(context.Context, *models.GalleryFilterType, *models.FindFilterType) int); ok { + r0 = rf(ctx, galleryFilter, findFilter) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func(*models.GalleryFilterType, *models.FindFilterType) error); ok { - r1 = rf(galleryFilter, findFilter) + if rf, ok := ret.Get(1).(func(context.Context, *models.GalleryFilterType, *models.FindFilterType) error); ok { + r1 = rf(ctx, galleryFilter, findFilter) } else { r1 = ret.Error(1) } @@ -397,13 +399,13 @@ func (_m *GalleryReaderWriter) QueryCount(galleryFilter *models.GalleryFilterTyp return r0, r1 } -// Update provides a mock function with given fields: updatedGallery -func (_m *GalleryReaderWriter) Update(updatedGallery models.Gallery) (*models.Gallery, error) { - ret := _m.Called(updatedGallery) +// Update provides a mock function with given fields: ctx, updatedGallery +func (_m *GalleryReaderWriter) Update(ctx context.Context, updatedGallery models.Gallery) (*models.Gallery, error) { + ret := _m.Called(ctx, updatedGallery) var r0 *models.Gallery - if rf, ok := ret.Get(0).(func(models.Gallery) *models.Gallery); ok { - r0 = rf(updatedGallery) + if rf, ok := ret.Get(0).(func(context.Context, models.Gallery) *models.Gallery); ok { + r0 = rf(ctx, updatedGallery) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Gallery) @@ -411,8 +413,8 @@ func (_m *GalleryReaderWriter) Update(updatedGallery models.Gallery) (*models.Ga } var r1 error - if rf, ok := ret.Get(1).(func(models.Gallery) error); ok { - r1 = rf(updatedGallery) + if rf, ok := ret.Get(1).(func(context.Context, models.Gallery) error); ok { + r1 = rf(ctx, updatedGallery) } else { r1 = ret.Error(1) } @@ -420,13 +422,13 @@ func (_m *GalleryReaderWriter) Update(updatedGallery models.Gallery) (*models.Ga return r0, r1 } -// UpdateFileModTime provides a mock function with given fields: id, modTime -func (_m *GalleryReaderWriter) UpdateFileModTime(id int, modTime models.NullSQLiteTimestamp) error { - ret := _m.Called(id, modTime) +// UpdateFileModTime provides a mock function with given fields: ctx, id, modTime +func (_m *GalleryReaderWriter) UpdateFileModTime(ctx context.Context, id int, modTime models.NullSQLiteTimestamp) error { + ret := _m.Called(ctx, id, modTime) var r0 error - if rf, ok := ret.Get(0).(func(int, models.NullSQLiteTimestamp) error); ok { - r0 = rf(id, modTime) + if rf, ok := ret.Get(0).(func(context.Context, int, models.NullSQLiteTimestamp) error); ok { + r0 = rf(ctx, id, modTime) } else { r0 = ret.Error(0) } @@ -434,13 +436,13 @@ func (_m *GalleryReaderWriter) UpdateFileModTime(id int, modTime models.NullSQLi return r0 } -// UpdateImages provides a mock function with given fields: galleryID, imageIDs -func (_m *GalleryReaderWriter) UpdateImages(galleryID int, imageIDs []int) error { - ret := _m.Called(galleryID, imageIDs) +// UpdateImages provides a mock function with given fields: ctx, galleryID, imageIDs +func (_m *GalleryReaderWriter) UpdateImages(ctx context.Context, galleryID int, imageIDs []int) error { + ret := _m.Called(ctx, galleryID, imageIDs) var r0 error - if rf, ok := ret.Get(0).(func(int, []int) error); ok { - r0 = rf(galleryID, imageIDs) + if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok { + r0 = rf(ctx, galleryID, imageIDs) } else { r0 = ret.Error(0) } @@ -448,13 +450,13 @@ func (_m *GalleryReaderWriter) UpdateImages(galleryID int, imageIDs []int) error return r0 } -// UpdatePartial provides a mock function with given fields: updatedGallery -func (_m *GalleryReaderWriter) UpdatePartial(updatedGallery models.GalleryPartial) (*models.Gallery, error) { - ret := _m.Called(updatedGallery) +// UpdatePartial provides a mock function with given fields: ctx, updatedGallery +func (_m *GalleryReaderWriter) UpdatePartial(ctx context.Context, updatedGallery models.GalleryPartial) (*models.Gallery, error) { + ret := _m.Called(ctx, updatedGallery) var r0 *models.Gallery - if rf, ok := ret.Get(0).(func(models.GalleryPartial) *models.Gallery); ok { - r0 = rf(updatedGallery) + if rf, ok := ret.Get(0).(func(context.Context, models.GalleryPartial) *models.Gallery); ok { + r0 = rf(ctx, updatedGallery) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Gallery) @@ -462,8 +464,8 @@ func (_m *GalleryReaderWriter) UpdatePartial(updatedGallery models.GalleryPartia } var r1 error - if rf, ok := ret.Get(1).(func(models.GalleryPartial) error); ok { - r1 = rf(updatedGallery) + if rf, ok := ret.Get(1).(func(context.Context, models.GalleryPartial) error); ok { + r1 = rf(ctx, updatedGallery) } else { r1 = ret.Error(1) } @@ -471,13 +473,13 @@ func (_m *GalleryReaderWriter) UpdatePartial(updatedGallery models.GalleryPartia return r0, r1 } -// UpdatePerformers provides a mock function with given fields: galleryID, performerIDs -func (_m *GalleryReaderWriter) UpdatePerformers(galleryID int, performerIDs []int) error { - ret := _m.Called(galleryID, performerIDs) +// UpdatePerformers provides a mock function with given fields: ctx, galleryID, performerIDs +func (_m *GalleryReaderWriter) UpdatePerformers(ctx context.Context, galleryID int, performerIDs []int) error { + ret := _m.Called(ctx, galleryID, performerIDs) var r0 error - if rf, ok := ret.Get(0).(func(int, []int) error); ok { - r0 = rf(galleryID, performerIDs) + if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok { + r0 = rf(ctx, galleryID, performerIDs) } else { r0 = ret.Error(0) } @@ -485,13 +487,13 @@ func (_m *GalleryReaderWriter) UpdatePerformers(galleryID int, performerIDs []in return r0 } -// UpdateScenes provides a mock function with given fields: galleryID, sceneIDs -func (_m *GalleryReaderWriter) UpdateScenes(galleryID int, sceneIDs []int) error { - ret := _m.Called(galleryID, sceneIDs) +// UpdateScenes provides a mock function with given fields: ctx, galleryID, sceneIDs +func (_m *GalleryReaderWriter) UpdateScenes(ctx context.Context, galleryID int, sceneIDs []int) error { + ret := _m.Called(ctx, galleryID, sceneIDs) var r0 error - if rf, ok := ret.Get(0).(func(int, []int) error); ok { - r0 = rf(galleryID, sceneIDs) + if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok { + r0 = rf(ctx, galleryID, sceneIDs) } else { r0 = ret.Error(0) } @@ -499,13 +501,13 @@ func (_m *GalleryReaderWriter) UpdateScenes(galleryID int, sceneIDs []int) error return r0 } -// UpdateTags provides a mock function with given fields: galleryID, tagIDs -func (_m *GalleryReaderWriter) UpdateTags(galleryID int, tagIDs []int) error { - ret := _m.Called(galleryID, tagIDs) +// UpdateTags provides a mock function with given fields: ctx, galleryID, tagIDs +func (_m *GalleryReaderWriter) UpdateTags(ctx context.Context, galleryID int, tagIDs []int) error { + ret := _m.Called(ctx, galleryID, tagIDs) var r0 error - if rf, ok := ret.Get(0).(func(int, []int) error); ok { - r0 = rf(galleryID, tagIDs) + if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok { + r0 = rf(ctx, galleryID, tagIDs) } else { r0 = ret.Error(0) } diff --git a/pkg/models/mocks/ImageReaderWriter.go b/pkg/models/mocks/ImageReaderWriter.go index 5a13ad986..9660849f1 100644 --- a/pkg/models/mocks/ImageReaderWriter.go +++ b/pkg/models/mocks/ImageReaderWriter.go @@ -3,6 +3,8 @@ package mocks import ( + context "context" + models "github.com/stashapp/stash/pkg/models" mock "github.com/stretchr/testify/mock" ) @@ -12,13 +14,13 @@ type ImageReaderWriter struct { mock.Mock } -// All provides a mock function with given fields: -func (_m *ImageReaderWriter) All() ([]*models.Image, error) { - ret := _m.Called() +// All provides a mock function with given fields: ctx +func (_m *ImageReaderWriter) All(ctx context.Context) ([]*models.Image, error) { + ret := _m.Called(ctx) var r0 []*models.Image - if rf, ok := ret.Get(0).(func() []*models.Image); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) []*models.Image); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Image) @@ -26,8 +28,8 @@ func (_m *ImageReaderWriter) All() ([]*models.Image, error) { } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -35,20 +37,20 @@ func (_m *ImageReaderWriter) All() ([]*models.Image, error) { return r0, r1 } -// Count provides a mock function with given fields: -func (_m *ImageReaderWriter) Count() (int, error) { - ret := _m.Called() +// Count provides a mock function with given fields: ctx +func (_m *ImageReaderWriter) Count(ctx context.Context) (int, error) { + ret := _m.Called(ctx) var r0 int - if rf, ok := ret.Get(0).(func() int); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) int); ok { + r0 = rf(ctx) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -56,20 +58,20 @@ func (_m *ImageReaderWriter) Count() (int, error) { return r0, r1 } -// CountByGalleryID provides a mock function with given fields: galleryID -func (_m *ImageReaderWriter) CountByGalleryID(galleryID int) (int, error) { - ret := _m.Called(galleryID) +// CountByGalleryID provides a mock function with given fields: ctx, galleryID +func (_m *ImageReaderWriter) CountByGalleryID(ctx context.Context, galleryID int) (int, error) { + ret := _m.Called(ctx, galleryID) var r0 int - if rf, ok := ret.Get(0).(func(int) int); ok { - r0 = rf(galleryID) + if rf, ok := ret.Get(0).(func(context.Context, int) int); ok { + r0 = rf(ctx, galleryID) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(galleryID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, galleryID) } else { r1 = ret.Error(1) } @@ -77,13 +79,13 @@ func (_m *ImageReaderWriter) CountByGalleryID(galleryID int) (int, error) { return r0, r1 } -// Create provides a mock function with given fields: newImage -func (_m *ImageReaderWriter) Create(newImage models.Image) (*models.Image, error) { - ret := _m.Called(newImage) +// Create provides a mock function with given fields: ctx, newImage +func (_m *ImageReaderWriter) Create(ctx context.Context, newImage models.Image) (*models.Image, error) { + ret := _m.Called(ctx, newImage) var r0 *models.Image - if rf, ok := ret.Get(0).(func(models.Image) *models.Image); ok { - r0 = rf(newImage) + if rf, ok := ret.Get(0).(func(context.Context, models.Image) *models.Image); ok { + r0 = rf(ctx, newImage) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Image) @@ -91,8 +93,8 @@ func (_m *ImageReaderWriter) Create(newImage models.Image) (*models.Image, error } var r1 error - if rf, ok := ret.Get(1).(func(models.Image) error); ok { - r1 = rf(newImage) + if rf, ok := ret.Get(1).(func(context.Context, models.Image) error); ok { + r1 = rf(ctx, newImage) } else { r1 = ret.Error(1) } @@ -100,20 +102,20 @@ func (_m *ImageReaderWriter) Create(newImage models.Image) (*models.Image, error return r0, r1 } -// DecrementOCounter provides a mock function with given fields: id -func (_m *ImageReaderWriter) DecrementOCounter(id int) (int, error) { - ret := _m.Called(id) +// DecrementOCounter provides a mock function with given fields: ctx, id +func (_m *ImageReaderWriter) DecrementOCounter(ctx context.Context, id int) (int, error) { + ret := _m.Called(ctx, id) var r0 int - if rf, ok := ret.Get(0).(func(int) int); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) int); ok { + r0 = rf(ctx, id) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -121,13 +123,13 @@ func (_m *ImageReaderWriter) DecrementOCounter(id int) (int, error) { return r0, r1 } -// Destroy provides a mock function with given fields: id -func (_m *ImageReaderWriter) Destroy(id int) error { - ret := _m.Called(id) +// Destroy provides a mock function with given fields: ctx, id +func (_m *ImageReaderWriter) Destroy(ctx context.Context, id int) error { + ret := _m.Called(ctx, id) var r0 error - if rf, ok := ret.Get(0).(func(int) error); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) error); ok { + r0 = rf(ctx, id) } else { r0 = ret.Error(0) } @@ -135,13 +137,13 @@ func (_m *ImageReaderWriter) Destroy(id int) error { return r0 } -// Find provides a mock function with given fields: id -func (_m *ImageReaderWriter) Find(id int) (*models.Image, error) { - ret := _m.Called(id) +// Find provides a mock function with given fields: ctx, id +func (_m *ImageReaderWriter) Find(ctx context.Context, id int) (*models.Image, error) { + ret := _m.Called(ctx, id) var r0 *models.Image - if rf, ok := ret.Get(0).(func(int) *models.Image); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) *models.Image); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Image) @@ -149,8 +151,8 @@ func (_m *ImageReaderWriter) Find(id int) (*models.Image, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -158,13 +160,13 @@ func (_m *ImageReaderWriter) Find(id int) (*models.Image, error) { return r0, r1 } -// FindByChecksum provides a mock function with given fields: checksum -func (_m *ImageReaderWriter) FindByChecksum(checksum string) (*models.Image, error) { - ret := _m.Called(checksum) +// FindByChecksum provides a mock function with given fields: ctx, checksum +func (_m *ImageReaderWriter) FindByChecksum(ctx context.Context, checksum string) (*models.Image, error) { + ret := _m.Called(ctx, checksum) var r0 *models.Image - if rf, ok := ret.Get(0).(func(string) *models.Image); ok { - r0 = rf(checksum) + if rf, ok := ret.Get(0).(func(context.Context, string) *models.Image); ok { + r0 = rf(ctx, checksum) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Image) @@ -172,8 +174,8 @@ func (_m *ImageReaderWriter) FindByChecksum(checksum string) (*models.Image, err } var r1 error - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(checksum) + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, checksum) } else { r1 = ret.Error(1) } @@ -181,13 +183,13 @@ func (_m *ImageReaderWriter) FindByChecksum(checksum string) (*models.Image, err return r0, r1 } -// FindByGalleryID provides a mock function with given fields: galleryID -func (_m *ImageReaderWriter) FindByGalleryID(galleryID int) ([]*models.Image, error) { - ret := _m.Called(galleryID) +// FindByGalleryID provides a mock function with given fields: ctx, galleryID +func (_m *ImageReaderWriter) FindByGalleryID(ctx context.Context, galleryID int) ([]*models.Image, error) { + ret := _m.Called(ctx, galleryID) var r0 []*models.Image - if rf, ok := ret.Get(0).(func(int) []*models.Image); ok { - r0 = rf(galleryID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Image); ok { + r0 = rf(ctx, galleryID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Image) @@ -195,8 +197,8 @@ func (_m *ImageReaderWriter) FindByGalleryID(galleryID int) ([]*models.Image, er } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(galleryID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, galleryID) } else { r1 = ret.Error(1) } @@ -204,13 +206,13 @@ func (_m *ImageReaderWriter) FindByGalleryID(galleryID int) ([]*models.Image, er return r0, r1 } -// FindByPath provides a mock function with given fields: path -func (_m *ImageReaderWriter) FindByPath(path string) (*models.Image, error) { - ret := _m.Called(path) +// FindByPath provides a mock function with given fields: ctx, path +func (_m *ImageReaderWriter) FindByPath(ctx context.Context, path string) (*models.Image, error) { + ret := _m.Called(ctx, path) var r0 *models.Image - if rf, ok := ret.Get(0).(func(string) *models.Image); ok { - r0 = rf(path) + if rf, ok := ret.Get(0).(func(context.Context, string) *models.Image); ok { + r0 = rf(ctx, path) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Image) @@ -218,8 +220,8 @@ func (_m *ImageReaderWriter) FindByPath(path string) (*models.Image, error) { } var r1 error - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(path) + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, path) } else { r1 = ret.Error(1) } @@ -227,13 +229,13 @@ func (_m *ImageReaderWriter) FindByPath(path string) (*models.Image, error) { return r0, r1 } -// FindMany provides a mock function with given fields: ids -func (_m *ImageReaderWriter) FindMany(ids []int) ([]*models.Image, error) { - ret := _m.Called(ids) +// FindMany provides a mock function with given fields: ctx, ids +func (_m *ImageReaderWriter) FindMany(ctx context.Context, ids []int) ([]*models.Image, error) { + ret := _m.Called(ctx, ids) var r0 []*models.Image - if rf, ok := ret.Get(0).(func([]int) []*models.Image); ok { - r0 = rf(ids) + if rf, ok := ret.Get(0).(func(context.Context, []int) []*models.Image); ok { + r0 = rf(ctx, ids) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Image) @@ -241,8 +243,8 @@ func (_m *ImageReaderWriter) FindMany(ids []int) ([]*models.Image, error) { } var r1 error - if rf, ok := ret.Get(1).(func([]int) error); ok { - r1 = rf(ids) + if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok { + r1 = rf(ctx, ids) } else { r1 = ret.Error(1) } @@ -250,13 +252,13 @@ func (_m *ImageReaderWriter) FindMany(ids []int) ([]*models.Image, error) { return r0, r1 } -// GetGalleryIDs provides a mock function with given fields: imageID -func (_m *ImageReaderWriter) GetGalleryIDs(imageID int) ([]int, error) { - ret := _m.Called(imageID) +// GetGalleryIDs provides a mock function with given fields: ctx, imageID +func (_m *ImageReaderWriter) GetGalleryIDs(ctx context.Context, imageID int) ([]int, error) { + ret := _m.Called(ctx, imageID) var r0 []int - if rf, ok := ret.Get(0).(func(int) []int); ok { - r0 = rf(imageID) + if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok { + r0 = rf(ctx, imageID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]int) @@ -264,8 +266,8 @@ func (_m *ImageReaderWriter) GetGalleryIDs(imageID int) ([]int, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(imageID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, imageID) } else { r1 = ret.Error(1) } @@ -273,13 +275,13 @@ func (_m *ImageReaderWriter) GetGalleryIDs(imageID int) ([]int, error) { return r0, r1 } -// GetPerformerIDs provides a mock function with given fields: imageID -func (_m *ImageReaderWriter) GetPerformerIDs(imageID int) ([]int, error) { - ret := _m.Called(imageID) +// GetPerformerIDs provides a mock function with given fields: ctx, imageID +func (_m *ImageReaderWriter) GetPerformerIDs(ctx context.Context, imageID int) ([]int, error) { + ret := _m.Called(ctx, imageID) var r0 []int - if rf, ok := ret.Get(0).(func(int) []int); ok { - r0 = rf(imageID) + if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok { + r0 = rf(ctx, imageID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]int) @@ -287,8 +289,8 @@ func (_m *ImageReaderWriter) GetPerformerIDs(imageID int) ([]int, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(imageID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, imageID) } else { r1 = ret.Error(1) } @@ -296,13 +298,13 @@ func (_m *ImageReaderWriter) GetPerformerIDs(imageID int) ([]int, error) { return r0, r1 } -// GetTagIDs provides a mock function with given fields: imageID -func (_m *ImageReaderWriter) GetTagIDs(imageID int) ([]int, error) { - ret := _m.Called(imageID) +// GetTagIDs provides a mock function with given fields: ctx, imageID +func (_m *ImageReaderWriter) GetTagIDs(ctx context.Context, imageID int) ([]int, error) { + ret := _m.Called(ctx, imageID) var r0 []int - if rf, ok := ret.Get(0).(func(int) []int); ok { - r0 = rf(imageID) + if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok { + r0 = rf(ctx, imageID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]int) @@ -310,8 +312,8 @@ func (_m *ImageReaderWriter) GetTagIDs(imageID int) ([]int, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(imageID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, imageID) } else { r1 = ret.Error(1) } @@ -319,20 +321,20 @@ func (_m *ImageReaderWriter) GetTagIDs(imageID int) ([]int, error) { return r0, r1 } -// IncrementOCounter provides a mock function with given fields: id -func (_m *ImageReaderWriter) IncrementOCounter(id int) (int, error) { - ret := _m.Called(id) +// IncrementOCounter provides a mock function with given fields: ctx, id +func (_m *ImageReaderWriter) IncrementOCounter(ctx context.Context, id int) (int, error) { + ret := _m.Called(ctx, id) var r0 int - if rf, ok := ret.Get(0).(func(int) int); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) int); ok { + r0 = rf(ctx, id) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -340,13 +342,13 @@ func (_m *ImageReaderWriter) IncrementOCounter(id int) (int, error) { return r0, r1 } -// Query provides a mock function with given fields: options -func (_m *ImageReaderWriter) Query(options models.ImageQueryOptions) (*models.ImageQueryResult, error) { - ret := _m.Called(options) +// Query provides a mock function with given fields: ctx, options +func (_m *ImageReaderWriter) Query(ctx context.Context, options models.ImageQueryOptions) (*models.ImageQueryResult, error) { + ret := _m.Called(ctx, options) var r0 *models.ImageQueryResult - if rf, ok := ret.Get(0).(func(models.ImageQueryOptions) *models.ImageQueryResult); ok { - r0 = rf(options) + if rf, ok := ret.Get(0).(func(context.Context, models.ImageQueryOptions) *models.ImageQueryResult); ok { + r0 = rf(ctx, options) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.ImageQueryResult) @@ -354,8 +356,8 @@ func (_m *ImageReaderWriter) Query(options models.ImageQueryOptions) (*models.Im } var r1 error - if rf, ok := ret.Get(1).(func(models.ImageQueryOptions) error); ok { - r1 = rf(options) + if rf, ok := ret.Get(1).(func(context.Context, models.ImageQueryOptions) error); ok { + r1 = rf(ctx, options) } else { r1 = ret.Error(1) } @@ -363,20 +365,20 @@ func (_m *ImageReaderWriter) Query(options models.ImageQueryOptions) (*models.Im return r0, r1 } -// QueryCount provides a mock function with given fields: imageFilter, findFilter -func (_m *ImageReaderWriter) QueryCount(imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) (int, error) { - ret := _m.Called(imageFilter, findFilter) +// QueryCount provides a mock function with given fields: ctx, imageFilter, findFilter +func (_m *ImageReaderWriter) QueryCount(ctx context.Context, imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) (int, error) { + ret := _m.Called(ctx, imageFilter, findFilter) var r0 int - if rf, ok := ret.Get(0).(func(*models.ImageFilterType, *models.FindFilterType) int); ok { - r0 = rf(imageFilter, findFilter) + if rf, ok := ret.Get(0).(func(context.Context, *models.ImageFilterType, *models.FindFilterType) int); ok { + r0 = rf(ctx, imageFilter, findFilter) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func(*models.ImageFilterType, *models.FindFilterType) error); ok { - r1 = rf(imageFilter, findFilter) + if rf, ok := ret.Get(1).(func(context.Context, *models.ImageFilterType, *models.FindFilterType) error); ok { + r1 = rf(ctx, imageFilter, findFilter) } else { r1 = ret.Error(1) } @@ -384,20 +386,20 @@ func (_m *ImageReaderWriter) QueryCount(imageFilter *models.ImageFilterType, fin return r0, r1 } -// ResetOCounter provides a mock function with given fields: id -func (_m *ImageReaderWriter) ResetOCounter(id int) (int, error) { - ret := _m.Called(id) +// ResetOCounter provides a mock function with given fields: ctx, id +func (_m *ImageReaderWriter) ResetOCounter(ctx context.Context, id int) (int, error) { + ret := _m.Called(ctx, id) var r0 int - if rf, ok := ret.Get(0).(func(int) int); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) int); ok { + r0 = rf(ctx, id) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -405,20 +407,20 @@ func (_m *ImageReaderWriter) ResetOCounter(id int) (int, error) { return r0, r1 } -// Size provides a mock function with given fields: -func (_m *ImageReaderWriter) Size() (float64, error) { - ret := _m.Called() +// Size provides a mock function with given fields: ctx +func (_m *ImageReaderWriter) Size(ctx context.Context) (float64, error) { + ret := _m.Called(ctx) var r0 float64 - if rf, ok := ret.Get(0).(func() float64); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) float64); ok { + r0 = rf(ctx) } else { r0 = ret.Get(0).(float64) } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -426,13 +428,13 @@ func (_m *ImageReaderWriter) Size() (float64, error) { return r0, r1 } -// Update provides a mock function with given fields: updatedImage -func (_m *ImageReaderWriter) Update(updatedImage models.ImagePartial) (*models.Image, error) { - ret := _m.Called(updatedImage) +// Update provides a mock function with given fields: ctx, updatedImage +func (_m *ImageReaderWriter) Update(ctx context.Context, updatedImage models.ImagePartial) (*models.Image, error) { + ret := _m.Called(ctx, updatedImage) var r0 *models.Image - if rf, ok := ret.Get(0).(func(models.ImagePartial) *models.Image); ok { - r0 = rf(updatedImage) + if rf, ok := ret.Get(0).(func(context.Context, models.ImagePartial) *models.Image); ok { + r0 = rf(ctx, updatedImage) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Image) @@ -440,8 +442,8 @@ func (_m *ImageReaderWriter) Update(updatedImage models.ImagePartial) (*models.I } var r1 error - if rf, ok := ret.Get(1).(func(models.ImagePartial) error); ok { - r1 = rf(updatedImage) + if rf, ok := ret.Get(1).(func(context.Context, models.ImagePartial) error); ok { + r1 = rf(ctx, updatedImage) } else { r1 = ret.Error(1) } @@ -449,13 +451,13 @@ func (_m *ImageReaderWriter) Update(updatedImage models.ImagePartial) (*models.I return r0, r1 } -// UpdateFull provides a mock function with given fields: updatedImage -func (_m *ImageReaderWriter) UpdateFull(updatedImage models.Image) (*models.Image, error) { - ret := _m.Called(updatedImage) +// UpdateFull provides a mock function with given fields: ctx, updatedImage +func (_m *ImageReaderWriter) UpdateFull(ctx context.Context, updatedImage models.Image) (*models.Image, error) { + ret := _m.Called(ctx, updatedImage) var r0 *models.Image - if rf, ok := ret.Get(0).(func(models.Image) *models.Image); ok { - r0 = rf(updatedImage) + if rf, ok := ret.Get(0).(func(context.Context, models.Image) *models.Image); ok { + r0 = rf(ctx, updatedImage) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Image) @@ -463,8 +465,8 @@ func (_m *ImageReaderWriter) UpdateFull(updatedImage models.Image) (*models.Imag } var r1 error - if rf, ok := ret.Get(1).(func(models.Image) error); ok { - r1 = rf(updatedImage) + if rf, ok := ret.Get(1).(func(context.Context, models.Image) error); ok { + r1 = rf(ctx, updatedImage) } else { r1 = ret.Error(1) } @@ -472,13 +474,13 @@ func (_m *ImageReaderWriter) UpdateFull(updatedImage models.Image) (*models.Imag return r0, r1 } -// UpdateGalleries provides a mock function with given fields: imageID, galleryIDs -func (_m *ImageReaderWriter) UpdateGalleries(imageID int, galleryIDs []int) error { - ret := _m.Called(imageID, galleryIDs) +// UpdateGalleries provides a mock function with given fields: ctx, imageID, galleryIDs +func (_m *ImageReaderWriter) UpdateGalleries(ctx context.Context, imageID int, galleryIDs []int) error { + ret := _m.Called(ctx, imageID, galleryIDs) var r0 error - if rf, ok := ret.Get(0).(func(int, []int) error); ok { - r0 = rf(imageID, galleryIDs) + if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok { + r0 = rf(ctx, imageID, galleryIDs) } else { r0 = ret.Error(0) } @@ -486,13 +488,13 @@ func (_m *ImageReaderWriter) UpdateGalleries(imageID int, galleryIDs []int) erro return r0 } -// UpdatePerformers provides a mock function with given fields: imageID, performerIDs -func (_m *ImageReaderWriter) UpdatePerformers(imageID int, performerIDs []int) error { - ret := _m.Called(imageID, performerIDs) +// UpdatePerformers provides a mock function with given fields: ctx, imageID, performerIDs +func (_m *ImageReaderWriter) UpdatePerformers(ctx context.Context, imageID int, performerIDs []int) error { + ret := _m.Called(ctx, imageID, performerIDs) var r0 error - if rf, ok := ret.Get(0).(func(int, []int) error); ok { - r0 = rf(imageID, performerIDs) + if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok { + r0 = rf(ctx, imageID, performerIDs) } else { r0 = ret.Error(0) } @@ -500,13 +502,13 @@ func (_m *ImageReaderWriter) UpdatePerformers(imageID int, performerIDs []int) e return r0 } -// UpdateTags provides a mock function with given fields: imageID, tagIDs -func (_m *ImageReaderWriter) UpdateTags(imageID int, tagIDs []int) error { - ret := _m.Called(imageID, tagIDs) +// UpdateTags provides a mock function with given fields: ctx, imageID, tagIDs +func (_m *ImageReaderWriter) UpdateTags(ctx context.Context, imageID int, tagIDs []int) error { + ret := _m.Called(ctx, imageID, tagIDs) var r0 error - if rf, ok := ret.Get(0).(func(int, []int) error); ok { - r0 = rf(imageID, tagIDs) + if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok { + r0 = rf(ctx, imageID, tagIDs) } else { r0 = ret.Error(0) } diff --git a/pkg/models/mocks/MovieReaderWriter.go b/pkg/models/mocks/MovieReaderWriter.go index 288eb6fd4..c125fc7b1 100644 --- a/pkg/models/mocks/MovieReaderWriter.go +++ b/pkg/models/mocks/MovieReaderWriter.go @@ -3,6 +3,8 @@ package mocks import ( + context "context" + models "github.com/stashapp/stash/pkg/models" mock "github.com/stretchr/testify/mock" ) @@ -12,13 +14,13 @@ type MovieReaderWriter struct { mock.Mock } -// All provides a mock function with given fields: -func (_m *MovieReaderWriter) All() ([]*models.Movie, error) { - ret := _m.Called() +// All provides a mock function with given fields: ctx +func (_m *MovieReaderWriter) All(ctx context.Context) ([]*models.Movie, error) { + ret := _m.Called(ctx) var r0 []*models.Movie - if rf, ok := ret.Get(0).(func() []*models.Movie); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) []*models.Movie); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Movie) @@ -26,8 +28,8 @@ func (_m *MovieReaderWriter) All() ([]*models.Movie, error) { } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -35,20 +37,20 @@ func (_m *MovieReaderWriter) All() ([]*models.Movie, error) { return r0, r1 } -// Count provides a mock function with given fields: -func (_m *MovieReaderWriter) Count() (int, error) { - ret := _m.Called() +// Count provides a mock function with given fields: ctx +func (_m *MovieReaderWriter) Count(ctx context.Context) (int, error) { + ret := _m.Called(ctx) var r0 int - if rf, ok := ret.Get(0).(func() int); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) int); ok { + r0 = rf(ctx) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -56,20 +58,20 @@ func (_m *MovieReaderWriter) Count() (int, error) { return r0, r1 } -// CountByPerformerID provides a mock function with given fields: performerID -func (_m *MovieReaderWriter) CountByPerformerID(performerID int) (int, error) { - ret := _m.Called(performerID) +// CountByPerformerID provides a mock function with given fields: ctx, performerID +func (_m *MovieReaderWriter) CountByPerformerID(ctx context.Context, performerID int) (int, error) { + ret := _m.Called(ctx, performerID) var r0 int - if rf, ok := ret.Get(0).(func(int) int); ok { - r0 = rf(performerID) + if rf, ok := ret.Get(0).(func(context.Context, int) int); ok { + r0 = rf(ctx, performerID) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(performerID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, performerID) } else { r1 = ret.Error(1) } @@ -77,20 +79,20 @@ func (_m *MovieReaderWriter) CountByPerformerID(performerID int) (int, error) { return r0, r1 } -// CountByStudioID provides a mock function with given fields: studioID -func (_m *MovieReaderWriter) CountByStudioID(studioID int) (int, error) { - ret := _m.Called(studioID) +// CountByStudioID provides a mock function with given fields: ctx, studioID +func (_m *MovieReaderWriter) CountByStudioID(ctx context.Context, studioID int) (int, error) { + ret := _m.Called(ctx, studioID) var r0 int - if rf, ok := ret.Get(0).(func(int) int); ok { - r0 = rf(studioID) + if rf, ok := ret.Get(0).(func(context.Context, int) int); ok { + r0 = rf(ctx, studioID) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(studioID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, studioID) } else { r1 = ret.Error(1) } @@ -98,13 +100,13 @@ func (_m *MovieReaderWriter) CountByStudioID(studioID int) (int, error) { return r0, r1 } -// Create provides a mock function with given fields: newMovie -func (_m *MovieReaderWriter) Create(newMovie models.Movie) (*models.Movie, error) { - ret := _m.Called(newMovie) +// Create provides a mock function with given fields: ctx, newMovie +func (_m *MovieReaderWriter) Create(ctx context.Context, newMovie models.Movie) (*models.Movie, error) { + ret := _m.Called(ctx, newMovie) var r0 *models.Movie - if rf, ok := ret.Get(0).(func(models.Movie) *models.Movie); ok { - r0 = rf(newMovie) + if rf, ok := ret.Get(0).(func(context.Context, models.Movie) *models.Movie); ok { + r0 = rf(ctx, newMovie) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Movie) @@ -112,8 +114,8 @@ func (_m *MovieReaderWriter) Create(newMovie models.Movie) (*models.Movie, error } var r1 error - if rf, ok := ret.Get(1).(func(models.Movie) error); ok { - r1 = rf(newMovie) + if rf, ok := ret.Get(1).(func(context.Context, models.Movie) error); ok { + r1 = rf(ctx, newMovie) } else { r1 = ret.Error(1) } @@ -121,13 +123,13 @@ func (_m *MovieReaderWriter) Create(newMovie models.Movie) (*models.Movie, error return r0, r1 } -// Destroy provides a mock function with given fields: id -func (_m *MovieReaderWriter) Destroy(id int) error { - ret := _m.Called(id) +// Destroy provides a mock function with given fields: ctx, id +func (_m *MovieReaderWriter) Destroy(ctx context.Context, id int) error { + ret := _m.Called(ctx, id) var r0 error - if rf, ok := ret.Get(0).(func(int) error); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) error); ok { + r0 = rf(ctx, id) } else { r0 = ret.Error(0) } @@ -135,13 +137,13 @@ func (_m *MovieReaderWriter) Destroy(id int) error { return r0 } -// DestroyImages provides a mock function with given fields: movieID -func (_m *MovieReaderWriter) DestroyImages(movieID int) error { - ret := _m.Called(movieID) +// DestroyImages provides a mock function with given fields: ctx, movieID +func (_m *MovieReaderWriter) DestroyImages(ctx context.Context, movieID int) error { + ret := _m.Called(ctx, movieID) var r0 error - if rf, ok := ret.Get(0).(func(int) error); ok { - r0 = rf(movieID) + if rf, ok := ret.Get(0).(func(context.Context, int) error); ok { + r0 = rf(ctx, movieID) } else { r0 = ret.Error(0) } @@ -149,13 +151,13 @@ func (_m *MovieReaderWriter) DestroyImages(movieID int) error { return r0 } -// Find provides a mock function with given fields: id -func (_m *MovieReaderWriter) Find(id int) (*models.Movie, error) { - ret := _m.Called(id) +// Find provides a mock function with given fields: ctx, id +func (_m *MovieReaderWriter) Find(ctx context.Context, id int) (*models.Movie, error) { + ret := _m.Called(ctx, id) var r0 *models.Movie - if rf, ok := ret.Get(0).(func(int) *models.Movie); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) *models.Movie); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Movie) @@ -163,8 +165,8 @@ func (_m *MovieReaderWriter) Find(id int) (*models.Movie, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -172,13 +174,13 @@ func (_m *MovieReaderWriter) Find(id int) (*models.Movie, error) { return r0, r1 } -// FindByName provides a mock function with given fields: name, nocase -func (_m *MovieReaderWriter) FindByName(name string, nocase bool) (*models.Movie, error) { - ret := _m.Called(name, nocase) +// FindByName provides a mock function with given fields: ctx, name, nocase +func (_m *MovieReaderWriter) FindByName(ctx context.Context, name string, nocase bool) (*models.Movie, error) { + ret := _m.Called(ctx, name, nocase) var r0 *models.Movie - if rf, ok := ret.Get(0).(func(string, bool) *models.Movie); ok { - r0 = rf(name, nocase) + if rf, ok := ret.Get(0).(func(context.Context, string, bool) *models.Movie); ok { + r0 = rf(ctx, name, nocase) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Movie) @@ -186,8 +188,8 @@ func (_m *MovieReaderWriter) FindByName(name string, nocase bool) (*models.Movie } var r1 error - if rf, ok := ret.Get(1).(func(string, bool) error); ok { - r1 = rf(name, nocase) + if rf, ok := ret.Get(1).(func(context.Context, string, bool) error); ok { + r1 = rf(ctx, name, nocase) } else { r1 = ret.Error(1) } @@ -195,13 +197,13 @@ func (_m *MovieReaderWriter) FindByName(name string, nocase bool) (*models.Movie return r0, r1 } -// FindByNames provides a mock function with given fields: names, nocase -func (_m *MovieReaderWriter) FindByNames(names []string, nocase bool) ([]*models.Movie, error) { - ret := _m.Called(names, nocase) +// FindByNames provides a mock function with given fields: ctx, names, nocase +func (_m *MovieReaderWriter) FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Movie, error) { + ret := _m.Called(ctx, names, nocase) var r0 []*models.Movie - if rf, ok := ret.Get(0).(func([]string, bool) []*models.Movie); ok { - r0 = rf(names, nocase) + if rf, ok := ret.Get(0).(func(context.Context, []string, bool) []*models.Movie); ok { + r0 = rf(ctx, names, nocase) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Movie) @@ -209,8 +211,8 @@ func (_m *MovieReaderWriter) FindByNames(names []string, nocase bool) ([]*models } var r1 error - if rf, ok := ret.Get(1).(func([]string, bool) error); ok { - r1 = rf(names, nocase) + if rf, ok := ret.Get(1).(func(context.Context, []string, bool) error); ok { + r1 = rf(ctx, names, nocase) } else { r1 = ret.Error(1) } @@ -218,13 +220,13 @@ func (_m *MovieReaderWriter) FindByNames(names []string, nocase bool) ([]*models return r0, r1 } -// FindByPerformerID provides a mock function with given fields: performerID -func (_m *MovieReaderWriter) FindByPerformerID(performerID int) ([]*models.Movie, error) { - ret := _m.Called(performerID) +// FindByPerformerID provides a mock function with given fields: ctx, performerID +func (_m *MovieReaderWriter) FindByPerformerID(ctx context.Context, performerID int) ([]*models.Movie, error) { + ret := _m.Called(ctx, performerID) var r0 []*models.Movie - if rf, ok := ret.Get(0).(func(int) []*models.Movie); ok { - r0 = rf(performerID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Movie); ok { + r0 = rf(ctx, performerID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Movie) @@ -232,8 +234,8 @@ func (_m *MovieReaderWriter) FindByPerformerID(performerID int) ([]*models.Movie } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(performerID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, performerID) } else { r1 = ret.Error(1) } @@ -241,13 +243,13 @@ func (_m *MovieReaderWriter) FindByPerformerID(performerID int) ([]*models.Movie return r0, r1 } -// FindByStudioID provides a mock function with given fields: studioID -func (_m *MovieReaderWriter) FindByStudioID(studioID int) ([]*models.Movie, error) { - ret := _m.Called(studioID) +// FindByStudioID provides a mock function with given fields: ctx, studioID +func (_m *MovieReaderWriter) FindByStudioID(ctx context.Context, studioID int) ([]*models.Movie, error) { + ret := _m.Called(ctx, studioID) var r0 []*models.Movie - if rf, ok := ret.Get(0).(func(int) []*models.Movie); ok { - r0 = rf(studioID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Movie); ok { + r0 = rf(ctx, studioID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Movie) @@ -255,8 +257,8 @@ func (_m *MovieReaderWriter) FindByStudioID(studioID int) ([]*models.Movie, erro } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(studioID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, studioID) } else { r1 = ret.Error(1) } @@ -264,13 +266,13 @@ func (_m *MovieReaderWriter) FindByStudioID(studioID int) ([]*models.Movie, erro return r0, r1 } -// FindMany provides a mock function with given fields: ids -func (_m *MovieReaderWriter) FindMany(ids []int) ([]*models.Movie, error) { - ret := _m.Called(ids) +// FindMany provides a mock function with given fields: ctx, ids +func (_m *MovieReaderWriter) FindMany(ctx context.Context, ids []int) ([]*models.Movie, error) { + ret := _m.Called(ctx, ids) var r0 []*models.Movie - if rf, ok := ret.Get(0).(func([]int) []*models.Movie); ok { - r0 = rf(ids) + if rf, ok := ret.Get(0).(func(context.Context, []int) []*models.Movie); ok { + r0 = rf(ctx, ids) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Movie) @@ -278,8 +280,8 @@ func (_m *MovieReaderWriter) FindMany(ids []int) ([]*models.Movie, error) { } var r1 error - if rf, ok := ret.Get(1).(func([]int) error); ok { - r1 = rf(ids) + if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok { + r1 = rf(ctx, ids) } else { r1 = ret.Error(1) } @@ -287,13 +289,13 @@ func (_m *MovieReaderWriter) FindMany(ids []int) ([]*models.Movie, error) { return r0, r1 } -// GetBackImage provides a mock function with given fields: movieID -func (_m *MovieReaderWriter) GetBackImage(movieID int) ([]byte, error) { - ret := _m.Called(movieID) +// GetBackImage provides a mock function with given fields: ctx, movieID +func (_m *MovieReaderWriter) GetBackImage(ctx context.Context, movieID int) ([]byte, error) { + ret := _m.Called(ctx, movieID) var r0 []byte - if rf, ok := ret.Get(0).(func(int) []byte); ok { - r0 = rf(movieID) + if rf, ok := ret.Get(0).(func(context.Context, int) []byte); ok { + r0 = rf(ctx, movieID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]byte) @@ -301,8 +303,8 @@ func (_m *MovieReaderWriter) GetBackImage(movieID int) ([]byte, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(movieID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, movieID) } else { r1 = ret.Error(1) } @@ -310,13 +312,13 @@ func (_m *MovieReaderWriter) GetBackImage(movieID int) ([]byte, error) { return r0, r1 } -// GetFrontImage provides a mock function with given fields: movieID -func (_m *MovieReaderWriter) GetFrontImage(movieID int) ([]byte, error) { - ret := _m.Called(movieID) +// GetFrontImage provides a mock function with given fields: ctx, movieID +func (_m *MovieReaderWriter) GetFrontImage(ctx context.Context, movieID int) ([]byte, error) { + ret := _m.Called(ctx, movieID) var r0 []byte - if rf, ok := ret.Get(0).(func(int) []byte); ok { - r0 = rf(movieID) + if rf, ok := ret.Get(0).(func(context.Context, int) []byte); ok { + r0 = rf(ctx, movieID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]byte) @@ -324,8 +326,8 @@ func (_m *MovieReaderWriter) GetFrontImage(movieID int) ([]byte, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(movieID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, movieID) } else { r1 = ret.Error(1) } @@ -333,13 +335,13 @@ func (_m *MovieReaderWriter) GetFrontImage(movieID int) ([]byte, error) { return r0, r1 } -// Query provides a mock function with given fields: movieFilter, findFilter -func (_m *MovieReaderWriter) Query(movieFilter *models.MovieFilterType, findFilter *models.FindFilterType) ([]*models.Movie, int, error) { - ret := _m.Called(movieFilter, findFilter) +// Query provides a mock function with given fields: ctx, movieFilter, findFilter +func (_m *MovieReaderWriter) Query(ctx context.Context, movieFilter *models.MovieFilterType, findFilter *models.FindFilterType) ([]*models.Movie, int, error) { + ret := _m.Called(ctx, movieFilter, findFilter) var r0 []*models.Movie - if rf, ok := ret.Get(0).(func(*models.MovieFilterType, *models.FindFilterType) []*models.Movie); ok { - r0 = rf(movieFilter, findFilter) + if rf, ok := ret.Get(0).(func(context.Context, *models.MovieFilterType, *models.FindFilterType) []*models.Movie); ok { + r0 = rf(ctx, movieFilter, findFilter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Movie) @@ -347,15 +349,15 @@ func (_m *MovieReaderWriter) Query(movieFilter *models.MovieFilterType, findFilt } var r1 int - if rf, ok := ret.Get(1).(func(*models.MovieFilterType, *models.FindFilterType) int); ok { - r1 = rf(movieFilter, findFilter) + if rf, ok := ret.Get(1).(func(context.Context, *models.MovieFilterType, *models.FindFilterType) int); ok { + r1 = rf(ctx, movieFilter, findFilter) } else { r1 = ret.Get(1).(int) } var r2 error - if rf, ok := ret.Get(2).(func(*models.MovieFilterType, *models.FindFilterType) error); ok { - r2 = rf(movieFilter, findFilter) + if rf, ok := ret.Get(2).(func(context.Context, *models.MovieFilterType, *models.FindFilterType) error); ok { + r2 = rf(ctx, movieFilter, findFilter) } else { r2 = ret.Error(2) } @@ -363,13 +365,13 @@ func (_m *MovieReaderWriter) Query(movieFilter *models.MovieFilterType, findFilt return r0, r1, r2 } -// Update provides a mock function with given fields: updatedMovie -func (_m *MovieReaderWriter) Update(updatedMovie models.MoviePartial) (*models.Movie, error) { - ret := _m.Called(updatedMovie) +// Update provides a mock function with given fields: ctx, updatedMovie +func (_m *MovieReaderWriter) Update(ctx context.Context, updatedMovie models.MoviePartial) (*models.Movie, error) { + ret := _m.Called(ctx, updatedMovie) var r0 *models.Movie - if rf, ok := ret.Get(0).(func(models.MoviePartial) *models.Movie); ok { - r0 = rf(updatedMovie) + if rf, ok := ret.Get(0).(func(context.Context, models.MoviePartial) *models.Movie); ok { + r0 = rf(ctx, updatedMovie) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Movie) @@ -377,8 +379,8 @@ func (_m *MovieReaderWriter) Update(updatedMovie models.MoviePartial) (*models.M } var r1 error - if rf, ok := ret.Get(1).(func(models.MoviePartial) error); ok { - r1 = rf(updatedMovie) + if rf, ok := ret.Get(1).(func(context.Context, models.MoviePartial) error); ok { + r1 = rf(ctx, updatedMovie) } else { r1 = ret.Error(1) } @@ -386,13 +388,13 @@ func (_m *MovieReaderWriter) Update(updatedMovie models.MoviePartial) (*models.M return r0, r1 } -// UpdateFull provides a mock function with given fields: updatedMovie -func (_m *MovieReaderWriter) UpdateFull(updatedMovie models.Movie) (*models.Movie, error) { - ret := _m.Called(updatedMovie) +// UpdateFull provides a mock function with given fields: ctx, updatedMovie +func (_m *MovieReaderWriter) UpdateFull(ctx context.Context, updatedMovie models.Movie) (*models.Movie, error) { + ret := _m.Called(ctx, updatedMovie) var r0 *models.Movie - if rf, ok := ret.Get(0).(func(models.Movie) *models.Movie); ok { - r0 = rf(updatedMovie) + if rf, ok := ret.Get(0).(func(context.Context, models.Movie) *models.Movie); ok { + r0 = rf(ctx, updatedMovie) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Movie) @@ -400,8 +402,8 @@ func (_m *MovieReaderWriter) UpdateFull(updatedMovie models.Movie) (*models.Movi } var r1 error - if rf, ok := ret.Get(1).(func(models.Movie) error); ok { - r1 = rf(updatedMovie) + if rf, ok := ret.Get(1).(func(context.Context, models.Movie) error); ok { + r1 = rf(ctx, updatedMovie) } else { r1 = ret.Error(1) } @@ -409,13 +411,13 @@ func (_m *MovieReaderWriter) UpdateFull(updatedMovie models.Movie) (*models.Movi return r0, r1 } -// UpdateImages provides a mock function with given fields: movieID, frontImage, backImage -func (_m *MovieReaderWriter) UpdateImages(movieID int, frontImage []byte, backImage []byte) error { - ret := _m.Called(movieID, frontImage, backImage) +// UpdateImages provides a mock function with given fields: ctx, movieID, frontImage, backImage +func (_m *MovieReaderWriter) UpdateImages(ctx context.Context, movieID int, frontImage []byte, backImage []byte) error { + ret := _m.Called(ctx, movieID, frontImage, backImage) var r0 error - if rf, ok := ret.Get(0).(func(int, []byte, []byte) error); ok { - r0 = rf(movieID, frontImage, backImage) + if rf, ok := ret.Get(0).(func(context.Context, int, []byte, []byte) error); ok { + r0 = rf(ctx, movieID, frontImage, backImage) } else { r0 = ret.Error(0) } diff --git a/pkg/models/mocks/PerformerReaderWriter.go b/pkg/models/mocks/PerformerReaderWriter.go index 485f75170..2f97b66eb 100644 --- a/pkg/models/mocks/PerformerReaderWriter.go +++ b/pkg/models/mocks/PerformerReaderWriter.go @@ -3,6 +3,8 @@ package mocks import ( + context "context" + models "github.com/stashapp/stash/pkg/models" mock "github.com/stretchr/testify/mock" ) @@ -12,13 +14,13 @@ type PerformerReaderWriter struct { mock.Mock } -// All provides a mock function with given fields: -func (_m *PerformerReaderWriter) All() ([]*models.Performer, error) { - ret := _m.Called() +// All provides a mock function with given fields: ctx +func (_m *PerformerReaderWriter) All(ctx context.Context) ([]*models.Performer, error) { + ret := _m.Called(ctx) var r0 []*models.Performer - if rf, ok := ret.Get(0).(func() []*models.Performer); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) []*models.Performer); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Performer) @@ -26,8 +28,8 @@ func (_m *PerformerReaderWriter) All() ([]*models.Performer, error) { } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -35,20 +37,20 @@ func (_m *PerformerReaderWriter) All() ([]*models.Performer, error) { return r0, r1 } -// Count provides a mock function with given fields: -func (_m *PerformerReaderWriter) Count() (int, error) { - ret := _m.Called() +// Count provides a mock function with given fields: ctx +func (_m *PerformerReaderWriter) Count(ctx context.Context) (int, error) { + ret := _m.Called(ctx) var r0 int - if rf, ok := ret.Get(0).(func() int); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) int); ok { + r0 = rf(ctx) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -56,20 +58,20 @@ func (_m *PerformerReaderWriter) Count() (int, error) { return r0, r1 } -// CountByTagID provides a mock function with given fields: tagID -func (_m *PerformerReaderWriter) CountByTagID(tagID int) (int, error) { - ret := _m.Called(tagID) +// CountByTagID provides a mock function with given fields: ctx, tagID +func (_m *PerformerReaderWriter) CountByTagID(ctx context.Context, tagID int) (int, error) { + ret := _m.Called(ctx, tagID) var r0 int - if rf, ok := ret.Get(0).(func(int) int); ok { - r0 = rf(tagID) + if rf, ok := ret.Get(0).(func(context.Context, int) int); ok { + r0 = rf(ctx, tagID) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(tagID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, tagID) } else { r1 = ret.Error(1) } @@ -77,13 +79,13 @@ func (_m *PerformerReaderWriter) CountByTagID(tagID int) (int, error) { return r0, r1 } -// Create provides a mock function with given fields: newPerformer -func (_m *PerformerReaderWriter) Create(newPerformer models.Performer) (*models.Performer, error) { - ret := _m.Called(newPerformer) +// Create provides a mock function with given fields: ctx, newPerformer +func (_m *PerformerReaderWriter) Create(ctx context.Context, newPerformer models.Performer) (*models.Performer, error) { + ret := _m.Called(ctx, newPerformer) var r0 *models.Performer - if rf, ok := ret.Get(0).(func(models.Performer) *models.Performer); ok { - r0 = rf(newPerformer) + if rf, ok := ret.Get(0).(func(context.Context, models.Performer) *models.Performer); ok { + r0 = rf(ctx, newPerformer) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Performer) @@ -91,8 +93,8 @@ func (_m *PerformerReaderWriter) Create(newPerformer models.Performer) (*models. } var r1 error - if rf, ok := ret.Get(1).(func(models.Performer) error); ok { - r1 = rf(newPerformer) + if rf, ok := ret.Get(1).(func(context.Context, models.Performer) error); ok { + r1 = rf(ctx, newPerformer) } else { r1 = ret.Error(1) } @@ -100,13 +102,13 @@ func (_m *PerformerReaderWriter) Create(newPerformer models.Performer) (*models. return r0, r1 } -// Destroy provides a mock function with given fields: id -func (_m *PerformerReaderWriter) Destroy(id int) error { - ret := _m.Called(id) +// Destroy provides a mock function with given fields: ctx, id +func (_m *PerformerReaderWriter) Destroy(ctx context.Context, id int) error { + ret := _m.Called(ctx, id) var r0 error - if rf, ok := ret.Get(0).(func(int) error); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) error); ok { + r0 = rf(ctx, id) } else { r0 = ret.Error(0) } @@ -114,13 +116,13 @@ func (_m *PerformerReaderWriter) Destroy(id int) error { return r0 } -// DestroyImage provides a mock function with given fields: performerID -func (_m *PerformerReaderWriter) DestroyImage(performerID int) error { - ret := _m.Called(performerID) +// DestroyImage provides a mock function with given fields: ctx, performerID +func (_m *PerformerReaderWriter) DestroyImage(ctx context.Context, performerID int) error { + ret := _m.Called(ctx, performerID) var r0 error - if rf, ok := ret.Get(0).(func(int) error); ok { - r0 = rf(performerID) + if rf, ok := ret.Get(0).(func(context.Context, int) error); ok { + r0 = rf(ctx, performerID) } else { r0 = ret.Error(0) } @@ -128,13 +130,13 @@ func (_m *PerformerReaderWriter) DestroyImage(performerID int) error { return r0 } -// Find provides a mock function with given fields: id -func (_m *PerformerReaderWriter) Find(id int) (*models.Performer, error) { - ret := _m.Called(id) +// Find provides a mock function with given fields: ctx, id +func (_m *PerformerReaderWriter) Find(ctx context.Context, id int) (*models.Performer, error) { + ret := _m.Called(ctx, id) var r0 *models.Performer - if rf, ok := ret.Get(0).(func(int) *models.Performer); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) *models.Performer); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Performer) @@ -142,8 +144,8 @@ func (_m *PerformerReaderWriter) Find(id int) (*models.Performer, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -151,13 +153,13 @@ func (_m *PerformerReaderWriter) Find(id int) (*models.Performer, error) { return r0, r1 } -// FindByGalleryID provides a mock function with given fields: galleryID -func (_m *PerformerReaderWriter) FindByGalleryID(galleryID int) ([]*models.Performer, error) { - ret := _m.Called(galleryID) +// FindByGalleryID provides a mock function with given fields: ctx, galleryID +func (_m *PerformerReaderWriter) FindByGalleryID(ctx context.Context, galleryID int) ([]*models.Performer, error) { + ret := _m.Called(ctx, galleryID) var r0 []*models.Performer - if rf, ok := ret.Get(0).(func(int) []*models.Performer); ok { - r0 = rf(galleryID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Performer); ok { + r0 = rf(ctx, galleryID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Performer) @@ -165,8 +167,8 @@ func (_m *PerformerReaderWriter) FindByGalleryID(galleryID int) ([]*models.Perfo } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(galleryID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, galleryID) } else { r1 = ret.Error(1) } @@ -174,13 +176,13 @@ func (_m *PerformerReaderWriter) FindByGalleryID(galleryID int) ([]*models.Perfo return r0, r1 } -// FindByImageID provides a mock function with given fields: imageID -func (_m *PerformerReaderWriter) FindByImageID(imageID int) ([]*models.Performer, error) { - ret := _m.Called(imageID) +// FindByImageID provides a mock function with given fields: ctx, imageID +func (_m *PerformerReaderWriter) FindByImageID(ctx context.Context, imageID int) ([]*models.Performer, error) { + ret := _m.Called(ctx, imageID) var r0 []*models.Performer - if rf, ok := ret.Get(0).(func(int) []*models.Performer); ok { - r0 = rf(imageID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Performer); ok { + r0 = rf(ctx, imageID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Performer) @@ -188,8 +190,8 @@ func (_m *PerformerReaderWriter) FindByImageID(imageID int) ([]*models.Performer } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(imageID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, imageID) } else { r1 = ret.Error(1) } @@ -197,13 +199,13 @@ func (_m *PerformerReaderWriter) FindByImageID(imageID int) ([]*models.Performer return r0, r1 } -// FindByNames provides a mock function with given fields: names, nocase -func (_m *PerformerReaderWriter) FindByNames(names []string, nocase bool) ([]*models.Performer, error) { - ret := _m.Called(names, nocase) +// FindByNames provides a mock function with given fields: ctx, names, nocase +func (_m *PerformerReaderWriter) FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Performer, error) { + ret := _m.Called(ctx, names, nocase) var r0 []*models.Performer - if rf, ok := ret.Get(0).(func([]string, bool) []*models.Performer); ok { - r0 = rf(names, nocase) + if rf, ok := ret.Get(0).(func(context.Context, []string, bool) []*models.Performer); ok { + r0 = rf(ctx, names, nocase) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Performer) @@ -211,8 +213,8 @@ func (_m *PerformerReaderWriter) FindByNames(names []string, nocase bool) ([]*mo } var r1 error - if rf, ok := ret.Get(1).(func([]string, bool) error); ok { - r1 = rf(names, nocase) + if rf, ok := ret.Get(1).(func(context.Context, []string, bool) error); ok { + r1 = rf(ctx, names, nocase) } else { r1 = ret.Error(1) } @@ -220,13 +222,13 @@ func (_m *PerformerReaderWriter) FindByNames(names []string, nocase bool) ([]*mo return r0, r1 } -// FindBySceneID provides a mock function with given fields: sceneID -func (_m *PerformerReaderWriter) FindBySceneID(sceneID int) ([]*models.Performer, error) { - ret := _m.Called(sceneID) +// FindBySceneID provides a mock function with given fields: ctx, sceneID +func (_m *PerformerReaderWriter) FindBySceneID(ctx context.Context, sceneID int) ([]*models.Performer, error) { + ret := _m.Called(ctx, sceneID) var r0 []*models.Performer - if rf, ok := ret.Get(0).(func(int) []*models.Performer); ok { - r0 = rf(sceneID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Performer); ok { + r0 = rf(ctx, sceneID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Performer) @@ -234,8 +236,8 @@ func (_m *PerformerReaderWriter) FindBySceneID(sceneID int) ([]*models.Performer } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(sceneID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, sceneID) } else { r1 = ret.Error(1) } @@ -243,13 +245,13 @@ func (_m *PerformerReaderWriter) FindBySceneID(sceneID int) ([]*models.Performer return r0, r1 } -// FindByStashID provides a mock function with given fields: stashID -func (_m *PerformerReaderWriter) FindByStashID(stashID models.StashID) ([]*models.Performer, error) { - ret := _m.Called(stashID) +// FindByStashID provides a mock function with given fields: ctx, stashID +func (_m *PerformerReaderWriter) FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Performer, error) { + ret := _m.Called(ctx, stashID) var r0 []*models.Performer - if rf, ok := ret.Get(0).(func(models.StashID) []*models.Performer); ok { - r0 = rf(stashID) + if rf, ok := ret.Get(0).(func(context.Context, models.StashID) []*models.Performer); ok { + r0 = rf(ctx, stashID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Performer) @@ -257,8 +259,8 @@ func (_m *PerformerReaderWriter) FindByStashID(stashID models.StashID) ([]*model } var r1 error - if rf, ok := ret.Get(1).(func(models.StashID) error); ok { - r1 = rf(stashID) + if rf, ok := ret.Get(1).(func(context.Context, models.StashID) error); ok { + r1 = rf(ctx, stashID) } else { r1 = ret.Error(1) } @@ -266,13 +268,13 @@ func (_m *PerformerReaderWriter) FindByStashID(stashID models.StashID) ([]*model return r0, r1 } -// FindByStashIDStatus provides a mock function with given fields: hasStashID, stashboxEndpoint -func (_m *PerformerReaderWriter) FindByStashIDStatus(hasStashID bool, stashboxEndpoint string) ([]*models.Performer, error) { - ret := _m.Called(hasStashID, stashboxEndpoint) +// FindByStashIDStatus provides a mock function with given fields: ctx, hasStashID, stashboxEndpoint +func (_m *PerformerReaderWriter) FindByStashIDStatus(ctx context.Context, hasStashID bool, stashboxEndpoint string) ([]*models.Performer, error) { + ret := _m.Called(ctx, hasStashID, stashboxEndpoint) var r0 []*models.Performer - if rf, ok := ret.Get(0).(func(bool, string) []*models.Performer); ok { - r0 = rf(hasStashID, stashboxEndpoint) + if rf, ok := ret.Get(0).(func(context.Context, bool, string) []*models.Performer); ok { + r0 = rf(ctx, hasStashID, stashboxEndpoint) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Performer) @@ -280,8 +282,8 @@ func (_m *PerformerReaderWriter) FindByStashIDStatus(hasStashID bool, stashboxEn } var r1 error - if rf, ok := ret.Get(1).(func(bool, string) error); ok { - r1 = rf(hasStashID, stashboxEndpoint) + if rf, ok := ret.Get(1).(func(context.Context, bool, string) error); ok { + r1 = rf(ctx, hasStashID, stashboxEndpoint) } else { r1 = ret.Error(1) } @@ -289,13 +291,13 @@ func (_m *PerformerReaderWriter) FindByStashIDStatus(hasStashID bool, stashboxEn return r0, r1 } -// FindMany provides a mock function with given fields: ids -func (_m *PerformerReaderWriter) FindMany(ids []int) ([]*models.Performer, error) { - ret := _m.Called(ids) +// FindMany provides a mock function with given fields: ctx, ids +func (_m *PerformerReaderWriter) FindMany(ctx context.Context, ids []int) ([]*models.Performer, error) { + ret := _m.Called(ctx, ids) var r0 []*models.Performer - if rf, ok := ret.Get(0).(func([]int) []*models.Performer); ok { - r0 = rf(ids) + if rf, ok := ret.Get(0).(func(context.Context, []int) []*models.Performer); ok { + r0 = rf(ctx, ids) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Performer) @@ -303,8 +305,8 @@ func (_m *PerformerReaderWriter) FindMany(ids []int) ([]*models.Performer, error } var r1 error - if rf, ok := ret.Get(1).(func([]int) error); ok { - r1 = rf(ids) + if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok { + r1 = rf(ctx, ids) } else { r1 = ret.Error(1) } @@ -312,13 +314,13 @@ func (_m *PerformerReaderWriter) FindMany(ids []int) ([]*models.Performer, error return r0, r1 } -// FindNamesBySceneID provides a mock function with given fields: sceneID -func (_m *PerformerReaderWriter) FindNamesBySceneID(sceneID int) ([]*models.Performer, error) { - ret := _m.Called(sceneID) +// FindNamesBySceneID provides a mock function with given fields: ctx, sceneID +func (_m *PerformerReaderWriter) FindNamesBySceneID(ctx context.Context, sceneID int) ([]*models.Performer, error) { + ret := _m.Called(ctx, sceneID) var r0 []*models.Performer - if rf, ok := ret.Get(0).(func(int) []*models.Performer); ok { - r0 = rf(sceneID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Performer); ok { + r0 = rf(ctx, sceneID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Performer) @@ -326,8 +328,8 @@ func (_m *PerformerReaderWriter) FindNamesBySceneID(sceneID int) ([]*models.Perf } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(sceneID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, sceneID) } else { r1 = ret.Error(1) } @@ -335,13 +337,13 @@ func (_m *PerformerReaderWriter) FindNamesBySceneID(sceneID int) ([]*models.Perf return r0, r1 } -// GetImage provides a mock function with given fields: performerID -func (_m *PerformerReaderWriter) GetImage(performerID int) ([]byte, error) { - ret := _m.Called(performerID) +// GetImage provides a mock function with given fields: ctx, performerID +func (_m *PerformerReaderWriter) GetImage(ctx context.Context, performerID int) ([]byte, error) { + ret := _m.Called(ctx, performerID) var r0 []byte - if rf, ok := ret.Get(0).(func(int) []byte); ok { - r0 = rf(performerID) + if rf, ok := ret.Get(0).(func(context.Context, int) []byte); ok { + r0 = rf(ctx, performerID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]byte) @@ -349,8 +351,8 @@ func (_m *PerformerReaderWriter) GetImage(performerID int) ([]byte, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(performerID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, performerID) } else { r1 = ret.Error(1) } @@ -358,13 +360,13 @@ func (_m *PerformerReaderWriter) GetImage(performerID int) ([]byte, error) { return r0, r1 } -// GetStashIDs provides a mock function with given fields: performerID -func (_m *PerformerReaderWriter) GetStashIDs(performerID int) ([]*models.StashID, error) { - ret := _m.Called(performerID) +// GetStashIDs provides a mock function with given fields: ctx, performerID +func (_m *PerformerReaderWriter) GetStashIDs(ctx context.Context, performerID int) ([]*models.StashID, error) { + ret := _m.Called(ctx, performerID) var r0 []*models.StashID - if rf, ok := ret.Get(0).(func(int) []*models.StashID); ok { - r0 = rf(performerID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.StashID); ok { + r0 = rf(ctx, performerID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.StashID) @@ -372,8 +374,8 @@ func (_m *PerformerReaderWriter) GetStashIDs(performerID int) ([]*models.StashID } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(performerID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, performerID) } else { r1 = ret.Error(1) } @@ -381,13 +383,13 @@ func (_m *PerformerReaderWriter) GetStashIDs(performerID int) ([]*models.StashID return r0, r1 } -// GetTagIDs provides a mock function with given fields: performerID -func (_m *PerformerReaderWriter) GetTagIDs(performerID int) ([]int, error) { - ret := _m.Called(performerID) +// GetTagIDs provides a mock function with given fields: ctx, performerID +func (_m *PerformerReaderWriter) GetTagIDs(ctx context.Context, performerID int) ([]int, error) { + ret := _m.Called(ctx, performerID) var r0 []int - if rf, ok := ret.Get(0).(func(int) []int); ok { - r0 = rf(performerID) + if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok { + r0 = rf(ctx, performerID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]int) @@ -395,8 +397,8 @@ func (_m *PerformerReaderWriter) GetTagIDs(performerID int) ([]int, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(performerID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, performerID) } else { r1 = ret.Error(1) } @@ -404,13 +406,13 @@ func (_m *PerformerReaderWriter) GetTagIDs(performerID int) ([]int, error) { return r0, r1 } -// Query provides a mock function with given fields: performerFilter, findFilter -func (_m *PerformerReaderWriter) Query(performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) ([]*models.Performer, int, error) { - ret := _m.Called(performerFilter, findFilter) +// Query provides a mock function with given fields: ctx, performerFilter, findFilter +func (_m *PerformerReaderWriter) Query(ctx context.Context, performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) ([]*models.Performer, int, error) { + ret := _m.Called(ctx, performerFilter, findFilter) var r0 []*models.Performer - if rf, ok := ret.Get(0).(func(*models.PerformerFilterType, *models.FindFilterType) []*models.Performer); ok { - r0 = rf(performerFilter, findFilter) + if rf, ok := ret.Get(0).(func(context.Context, *models.PerformerFilterType, *models.FindFilterType) []*models.Performer); ok { + r0 = rf(ctx, performerFilter, findFilter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Performer) @@ -418,15 +420,15 @@ func (_m *PerformerReaderWriter) Query(performerFilter *models.PerformerFilterTy } var r1 int - if rf, ok := ret.Get(1).(func(*models.PerformerFilterType, *models.FindFilterType) int); ok { - r1 = rf(performerFilter, findFilter) + if rf, ok := ret.Get(1).(func(context.Context, *models.PerformerFilterType, *models.FindFilterType) int); ok { + r1 = rf(ctx, performerFilter, findFilter) } else { r1 = ret.Get(1).(int) } var r2 error - if rf, ok := ret.Get(2).(func(*models.PerformerFilterType, *models.FindFilterType) error); ok { - r2 = rf(performerFilter, findFilter) + if rf, ok := ret.Get(2).(func(context.Context, *models.PerformerFilterType, *models.FindFilterType) error); ok { + r2 = rf(ctx, performerFilter, findFilter) } else { r2 = ret.Error(2) } @@ -434,13 +436,13 @@ func (_m *PerformerReaderWriter) Query(performerFilter *models.PerformerFilterTy return r0, r1, r2 } -// QueryForAutoTag provides a mock function with given fields: words -func (_m *PerformerReaderWriter) QueryForAutoTag(words []string) ([]*models.Performer, error) { - ret := _m.Called(words) +// QueryForAutoTag provides a mock function with given fields: ctx, words +func (_m *PerformerReaderWriter) QueryForAutoTag(ctx context.Context, words []string) ([]*models.Performer, error) { + ret := _m.Called(ctx, words) var r0 []*models.Performer - if rf, ok := ret.Get(0).(func([]string) []*models.Performer); ok { - r0 = rf(words) + if rf, ok := ret.Get(0).(func(context.Context, []string) []*models.Performer); ok { + r0 = rf(ctx, words) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Performer) @@ -448,8 +450,8 @@ func (_m *PerformerReaderWriter) QueryForAutoTag(words []string) ([]*models.Perf } var r1 error - if rf, ok := ret.Get(1).(func([]string) error); ok { - r1 = rf(words) + if rf, ok := ret.Get(1).(func(context.Context, []string) error); ok { + r1 = rf(ctx, words) } else { r1 = ret.Error(1) } @@ -457,13 +459,13 @@ func (_m *PerformerReaderWriter) QueryForAutoTag(words []string) ([]*models.Perf return r0, r1 } -// Update provides a mock function with given fields: updatedPerformer -func (_m *PerformerReaderWriter) Update(updatedPerformer models.PerformerPartial) (*models.Performer, error) { - ret := _m.Called(updatedPerformer) +// Update provides a mock function with given fields: ctx, updatedPerformer +func (_m *PerformerReaderWriter) Update(ctx context.Context, updatedPerformer models.PerformerPartial) (*models.Performer, error) { + ret := _m.Called(ctx, updatedPerformer) var r0 *models.Performer - if rf, ok := ret.Get(0).(func(models.PerformerPartial) *models.Performer); ok { - r0 = rf(updatedPerformer) + if rf, ok := ret.Get(0).(func(context.Context, models.PerformerPartial) *models.Performer); ok { + r0 = rf(ctx, updatedPerformer) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Performer) @@ -471,8 +473,8 @@ func (_m *PerformerReaderWriter) Update(updatedPerformer models.PerformerPartial } var r1 error - if rf, ok := ret.Get(1).(func(models.PerformerPartial) error); ok { - r1 = rf(updatedPerformer) + if rf, ok := ret.Get(1).(func(context.Context, models.PerformerPartial) error); ok { + r1 = rf(ctx, updatedPerformer) } else { r1 = ret.Error(1) } @@ -480,13 +482,13 @@ func (_m *PerformerReaderWriter) Update(updatedPerformer models.PerformerPartial return r0, r1 } -// UpdateFull provides a mock function with given fields: updatedPerformer -func (_m *PerformerReaderWriter) UpdateFull(updatedPerformer models.Performer) (*models.Performer, error) { - ret := _m.Called(updatedPerformer) +// UpdateFull provides a mock function with given fields: ctx, updatedPerformer +func (_m *PerformerReaderWriter) UpdateFull(ctx context.Context, updatedPerformer models.Performer) (*models.Performer, error) { + ret := _m.Called(ctx, updatedPerformer) var r0 *models.Performer - if rf, ok := ret.Get(0).(func(models.Performer) *models.Performer); ok { - r0 = rf(updatedPerformer) + if rf, ok := ret.Get(0).(func(context.Context, models.Performer) *models.Performer); ok { + r0 = rf(ctx, updatedPerformer) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Performer) @@ -494,8 +496,8 @@ func (_m *PerformerReaderWriter) UpdateFull(updatedPerformer models.Performer) ( } var r1 error - if rf, ok := ret.Get(1).(func(models.Performer) error); ok { - r1 = rf(updatedPerformer) + if rf, ok := ret.Get(1).(func(context.Context, models.Performer) error); ok { + r1 = rf(ctx, updatedPerformer) } else { r1 = ret.Error(1) } @@ -503,13 +505,13 @@ func (_m *PerformerReaderWriter) UpdateFull(updatedPerformer models.Performer) ( return r0, r1 } -// UpdateImage provides a mock function with given fields: performerID, image -func (_m *PerformerReaderWriter) UpdateImage(performerID int, image []byte) error { - ret := _m.Called(performerID, image) +// UpdateImage provides a mock function with given fields: ctx, performerID, image +func (_m *PerformerReaderWriter) UpdateImage(ctx context.Context, performerID int, image []byte) error { + ret := _m.Called(ctx, performerID, image) var r0 error - if rf, ok := ret.Get(0).(func(int, []byte) error); ok { - r0 = rf(performerID, image) + if rf, ok := ret.Get(0).(func(context.Context, int, []byte) error); ok { + r0 = rf(ctx, performerID, image) } else { r0 = ret.Error(0) } @@ -517,13 +519,13 @@ func (_m *PerformerReaderWriter) UpdateImage(performerID int, image []byte) erro return r0 } -// UpdateStashIDs provides a mock function with given fields: performerID, stashIDs -func (_m *PerformerReaderWriter) UpdateStashIDs(performerID int, stashIDs []models.StashID) error { - ret := _m.Called(performerID, stashIDs) +// UpdateStashIDs provides a mock function with given fields: ctx, performerID, stashIDs +func (_m *PerformerReaderWriter) UpdateStashIDs(ctx context.Context, performerID int, stashIDs []models.StashID) error { + ret := _m.Called(ctx, performerID, stashIDs) var r0 error - if rf, ok := ret.Get(0).(func(int, []models.StashID) error); ok { - r0 = rf(performerID, stashIDs) + if rf, ok := ret.Get(0).(func(context.Context, int, []models.StashID) error); ok { + r0 = rf(ctx, performerID, stashIDs) } else { r0 = ret.Error(0) } @@ -531,13 +533,13 @@ func (_m *PerformerReaderWriter) UpdateStashIDs(performerID int, stashIDs []mode return r0 } -// UpdateTags provides a mock function with given fields: performerID, tagIDs -func (_m *PerformerReaderWriter) UpdateTags(performerID int, tagIDs []int) error { - ret := _m.Called(performerID, tagIDs) +// UpdateTags provides a mock function with given fields: ctx, performerID, tagIDs +func (_m *PerformerReaderWriter) UpdateTags(ctx context.Context, performerID int, tagIDs []int) error { + ret := _m.Called(ctx, performerID, tagIDs) var r0 error - if rf, ok := ret.Get(0).(func(int, []int) error); ok { - r0 = rf(performerID, tagIDs) + if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok { + r0 = rf(ctx, performerID, tagIDs) } else { r0 = ret.Error(0) } diff --git a/pkg/models/mocks/SavedFilterReaderWriter.go b/pkg/models/mocks/SavedFilterReaderWriter.go index 952497be2..8f9e6e553 100644 --- a/pkg/models/mocks/SavedFilterReaderWriter.go +++ b/pkg/models/mocks/SavedFilterReaderWriter.go @@ -3,6 +3,8 @@ package mocks import ( + context "context" + models "github.com/stashapp/stash/pkg/models" mock "github.com/stretchr/testify/mock" ) @@ -12,13 +14,13 @@ type SavedFilterReaderWriter struct { mock.Mock } -// All provides a mock function with given fields: -func (_m *SavedFilterReaderWriter) All() ([]*models.SavedFilter, error) { - ret := _m.Called() +// All provides a mock function with given fields: ctx +func (_m *SavedFilterReaderWriter) All(ctx context.Context) ([]*models.SavedFilter, error) { + ret := _m.Called(ctx) var r0 []*models.SavedFilter - if rf, ok := ret.Get(0).(func() []*models.SavedFilter); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) []*models.SavedFilter); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.SavedFilter) @@ -26,8 +28,8 @@ func (_m *SavedFilterReaderWriter) All() ([]*models.SavedFilter, error) { } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -35,13 +37,13 @@ func (_m *SavedFilterReaderWriter) All() ([]*models.SavedFilter, error) { return r0, r1 } -// Create provides a mock function with given fields: obj -func (_m *SavedFilterReaderWriter) Create(obj models.SavedFilter) (*models.SavedFilter, error) { - ret := _m.Called(obj) +// Create provides a mock function with given fields: ctx, obj +func (_m *SavedFilterReaderWriter) Create(ctx context.Context, obj models.SavedFilter) (*models.SavedFilter, error) { + ret := _m.Called(ctx, obj) var r0 *models.SavedFilter - if rf, ok := ret.Get(0).(func(models.SavedFilter) *models.SavedFilter); ok { - r0 = rf(obj) + if rf, ok := ret.Get(0).(func(context.Context, models.SavedFilter) *models.SavedFilter); ok { + r0 = rf(ctx, obj) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.SavedFilter) @@ -49,8 +51,8 @@ func (_m *SavedFilterReaderWriter) Create(obj models.SavedFilter) (*models.Saved } var r1 error - if rf, ok := ret.Get(1).(func(models.SavedFilter) error); ok { - r1 = rf(obj) + if rf, ok := ret.Get(1).(func(context.Context, models.SavedFilter) error); ok { + r1 = rf(ctx, obj) } else { r1 = ret.Error(1) } @@ -58,13 +60,13 @@ func (_m *SavedFilterReaderWriter) Create(obj models.SavedFilter) (*models.Saved return r0, r1 } -// Destroy provides a mock function with given fields: id -func (_m *SavedFilterReaderWriter) Destroy(id int) error { - ret := _m.Called(id) +// Destroy provides a mock function with given fields: ctx, id +func (_m *SavedFilterReaderWriter) Destroy(ctx context.Context, id int) error { + ret := _m.Called(ctx, id) var r0 error - if rf, ok := ret.Get(0).(func(int) error); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) error); ok { + r0 = rf(ctx, id) } else { r0 = ret.Error(0) } @@ -72,13 +74,13 @@ func (_m *SavedFilterReaderWriter) Destroy(id int) error { return r0 } -// Find provides a mock function with given fields: id -func (_m *SavedFilterReaderWriter) Find(id int) (*models.SavedFilter, error) { - ret := _m.Called(id) +// Find provides a mock function with given fields: ctx, id +func (_m *SavedFilterReaderWriter) Find(ctx context.Context, id int) (*models.SavedFilter, error) { + ret := _m.Called(ctx, id) var r0 *models.SavedFilter - if rf, ok := ret.Get(0).(func(int) *models.SavedFilter); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) *models.SavedFilter); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.SavedFilter) @@ -86,8 +88,8 @@ func (_m *SavedFilterReaderWriter) Find(id int) (*models.SavedFilter, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -95,13 +97,13 @@ func (_m *SavedFilterReaderWriter) Find(id int) (*models.SavedFilter, error) { return r0, r1 } -// FindByMode provides a mock function with given fields: mode -func (_m *SavedFilterReaderWriter) FindByMode(mode models.FilterMode) ([]*models.SavedFilter, error) { - ret := _m.Called(mode) +// FindByMode provides a mock function with given fields: ctx, mode +func (_m *SavedFilterReaderWriter) FindByMode(ctx context.Context, mode models.FilterMode) ([]*models.SavedFilter, error) { + ret := _m.Called(ctx, mode) var r0 []*models.SavedFilter - if rf, ok := ret.Get(0).(func(models.FilterMode) []*models.SavedFilter); ok { - r0 = rf(mode) + if rf, ok := ret.Get(0).(func(context.Context, models.FilterMode) []*models.SavedFilter); ok { + r0 = rf(ctx, mode) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.SavedFilter) @@ -109,8 +111,8 @@ func (_m *SavedFilterReaderWriter) FindByMode(mode models.FilterMode) ([]*models } var r1 error - if rf, ok := ret.Get(1).(func(models.FilterMode) error); ok { - r1 = rf(mode) + if rf, ok := ret.Get(1).(func(context.Context, models.FilterMode) error); ok { + r1 = rf(ctx, mode) } else { r1 = ret.Error(1) } @@ -118,13 +120,13 @@ func (_m *SavedFilterReaderWriter) FindByMode(mode models.FilterMode) ([]*models return r0, r1 } -// FindDefault provides a mock function with given fields: mode -func (_m *SavedFilterReaderWriter) FindDefault(mode models.FilterMode) (*models.SavedFilter, error) { - ret := _m.Called(mode) +// FindDefault provides a mock function with given fields: ctx, mode +func (_m *SavedFilterReaderWriter) FindDefault(ctx context.Context, mode models.FilterMode) (*models.SavedFilter, error) { + ret := _m.Called(ctx, mode) var r0 *models.SavedFilter - if rf, ok := ret.Get(0).(func(models.FilterMode) *models.SavedFilter); ok { - r0 = rf(mode) + if rf, ok := ret.Get(0).(func(context.Context, models.FilterMode) *models.SavedFilter); ok { + r0 = rf(ctx, mode) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.SavedFilter) @@ -132,8 +134,8 @@ func (_m *SavedFilterReaderWriter) FindDefault(mode models.FilterMode) (*models. } var r1 error - if rf, ok := ret.Get(1).(func(models.FilterMode) error); ok { - r1 = rf(mode) + if rf, ok := ret.Get(1).(func(context.Context, models.FilterMode) error); ok { + r1 = rf(ctx, mode) } else { r1 = ret.Error(1) } @@ -141,13 +143,13 @@ func (_m *SavedFilterReaderWriter) FindDefault(mode models.FilterMode) (*models. return r0, r1 } -// FindMany provides a mock function with given fields: ids, ignoreNotFound -func (_m *SavedFilterReaderWriter) FindMany(ids []int, ignoreNotFound bool) ([]*models.SavedFilter, error) { - ret := _m.Called(ids, ignoreNotFound) +// FindMany provides a mock function with given fields: ctx, ids, ignoreNotFound +func (_m *SavedFilterReaderWriter) FindMany(ctx context.Context, ids []int, ignoreNotFound bool) ([]*models.SavedFilter, error) { + ret := _m.Called(ctx, ids, ignoreNotFound) var r0 []*models.SavedFilter - if rf, ok := ret.Get(0).(func([]int, bool) []*models.SavedFilter); ok { - r0 = rf(ids, ignoreNotFound) + if rf, ok := ret.Get(0).(func(context.Context, []int, bool) []*models.SavedFilter); ok { + r0 = rf(ctx, ids, ignoreNotFound) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.SavedFilter) @@ -155,8 +157,8 @@ func (_m *SavedFilterReaderWriter) FindMany(ids []int, ignoreNotFound bool) ([]* } var r1 error - if rf, ok := ret.Get(1).(func([]int, bool) error); ok { - r1 = rf(ids, ignoreNotFound) + if rf, ok := ret.Get(1).(func(context.Context, []int, bool) error); ok { + r1 = rf(ctx, ids, ignoreNotFound) } else { r1 = ret.Error(1) } @@ -164,13 +166,13 @@ func (_m *SavedFilterReaderWriter) FindMany(ids []int, ignoreNotFound bool) ([]* return r0, r1 } -// SetDefault provides a mock function with given fields: obj -func (_m *SavedFilterReaderWriter) SetDefault(obj models.SavedFilter) (*models.SavedFilter, error) { - ret := _m.Called(obj) +// SetDefault provides a mock function with given fields: ctx, obj +func (_m *SavedFilterReaderWriter) SetDefault(ctx context.Context, obj models.SavedFilter) (*models.SavedFilter, error) { + ret := _m.Called(ctx, obj) var r0 *models.SavedFilter - if rf, ok := ret.Get(0).(func(models.SavedFilter) *models.SavedFilter); ok { - r0 = rf(obj) + if rf, ok := ret.Get(0).(func(context.Context, models.SavedFilter) *models.SavedFilter); ok { + r0 = rf(ctx, obj) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.SavedFilter) @@ -178,8 +180,8 @@ func (_m *SavedFilterReaderWriter) SetDefault(obj models.SavedFilter) (*models.S } var r1 error - if rf, ok := ret.Get(1).(func(models.SavedFilter) error); ok { - r1 = rf(obj) + if rf, ok := ret.Get(1).(func(context.Context, models.SavedFilter) error); ok { + r1 = rf(ctx, obj) } else { r1 = ret.Error(1) } @@ -187,13 +189,13 @@ func (_m *SavedFilterReaderWriter) SetDefault(obj models.SavedFilter) (*models.S return r0, r1 } -// Update provides a mock function with given fields: obj -func (_m *SavedFilterReaderWriter) Update(obj models.SavedFilter) (*models.SavedFilter, error) { - ret := _m.Called(obj) +// Update provides a mock function with given fields: ctx, obj +func (_m *SavedFilterReaderWriter) Update(ctx context.Context, obj models.SavedFilter) (*models.SavedFilter, error) { + ret := _m.Called(ctx, obj) var r0 *models.SavedFilter - if rf, ok := ret.Get(0).(func(models.SavedFilter) *models.SavedFilter); ok { - r0 = rf(obj) + if rf, ok := ret.Get(0).(func(context.Context, models.SavedFilter) *models.SavedFilter); ok { + r0 = rf(ctx, obj) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.SavedFilter) @@ -201,8 +203,8 @@ func (_m *SavedFilterReaderWriter) Update(obj models.SavedFilter) (*models.Saved } var r1 error - if rf, ok := ret.Get(1).(func(models.SavedFilter) error); ok { - r1 = rf(obj) + if rf, ok := ret.Get(1).(func(context.Context, models.SavedFilter) error); ok { + r1 = rf(ctx, obj) } else { r1 = ret.Error(1) } diff --git a/pkg/models/mocks/SceneMarkerReaderWriter.go b/pkg/models/mocks/SceneMarkerReaderWriter.go index 2e6fea3a0..695a54391 100644 --- a/pkg/models/mocks/SceneMarkerReaderWriter.go +++ b/pkg/models/mocks/SceneMarkerReaderWriter.go @@ -3,6 +3,8 @@ package mocks import ( + context "context" + models "github.com/stashapp/stash/pkg/models" mock "github.com/stretchr/testify/mock" ) @@ -12,20 +14,20 @@ type SceneMarkerReaderWriter struct { mock.Mock } -// CountByTagID provides a mock function with given fields: tagID -func (_m *SceneMarkerReaderWriter) CountByTagID(tagID int) (int, error) { - ret := _m.Called(tagID) +// CountByTagID provides a mock function with given fields: ctx, tagID +func (_m *SceneMarkerReaderWriter) CountByTagID(ctx context.Context, tagID int) (int, error) { + ret := _m.Called(ctx, tagID) var r0 int - if rf, ok := ret.Get(0).(func(int) int); ok { - r0 = rf(tagID) + if rf, ok := ret.Get(0).(func(context.Context, int) int); ok { + r0 = rf(ctx, tagID) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(tagID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, tagID) } else { r1 = ret.Error(1) } @@ -33,13 +35,13 @@ func (_m *SceneMarkerReaderWriter) CountByTagID(tagID int) (int, error) { return r0, r1 } -// Create provides a mock function with given fields: newSceneMarker -func (_m *SceneMarkerReaderWriter) Create(newSceneMarker models.SceneMarker) (*models.SceneMarker, error) { - ret := _m.Called(newSceneMarker) +// Create provides a mock function with given fields: ctx, newSceneMarker +func (_m *SceneMarkerReaderWriter) Create(ctx context.Context, newSceneMarker models.SceneMarker) (*models.SceneMarker, error) { + ret := _m.Called(ctx, newSceneMarker) var r0 *models.SceneMarker - if rf, ok := ret.Get(0).(func(models.SceneMarker) *models.SceneMarker); ok { - r0 = rf(newSceneMarker) + if rf, ok := ret.Get(0).(func(context.Context, models.SceneMarker) *models.SceneMarker); ok { + r0 = rf(ctx, newSceneMarker) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.SceneMarker) @@ -47,8 +49,8 @@ func (_m *SceneMarkerReaderWriter) Create(newSceneMarker models.SceneMarker) (*m } var r1 error - if rf, ok := ret.Get(1).(func(models.SceneMarker) error); ok { - r1 = rf(newSceneMarker) + if rf, ok := ret.Get(1).(func(context.Context, models.SceneMarker) error); ok { + r1 = rf(ctx, newSceneMarker) } else { r1 = ret.Error(1) } @@ -56,13 +58,13 @@ func (_m *SceneMarkerReaderWriter) Create(newSceneMarker models.SceneMarker) (*m return r0, r1 } -// Destroy provides a mock function with given fields: id -func (_m *SceneMarkerReaderWriter) Destroy(id int) error { - ret := _m.Called(id) +// Destroy provides a mock function with given fields: ctx, id +func (_m *SceneMarkerReaderWriter) Destroy(ctx context.Context, id int) error { + ret := _m.Called(ctx, id) var r0 error - if rf, ok := ret.Get(0).(func(int) error); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) error); ok { + r0 = rf(ctx, id) } else { r0 = ret.Error(0) } @@ -70,13 +72,13 @@ func (_m *SceneMarkerReaderWriter) Destroy(id int) error { return r0 } -// Find provides a mock function with given fields: id -func (_m *SceneMarkerReaderWriter) Find(id int) (*models.SceneMarker, error) { - ret := _m.Called(id) +// Find provides a mock function with given fields: ctx, id +func (_m *SceneMarkerReaderWriter) Find(ctx context.Context, id int) (*models.SceneMarker, error) { + ret := _m.Called(ctx, id) var r0 *models.SceneMarker - if rf, ok := ret.Get(0).(func(int) *models.SceneMarker); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) *models.SceneMarker); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.SceneMarker) @@ -84,8 +86,8 @@ func (_m *SceneMarkerReaderWriter) Find(id int) (*models.SceneMarker, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -93,13 +95,13 @@ func (_m *SceneMarkerReaderWriter) Find(id int) (*models.SceneMarker, error) { return r0, r1 } -// FindBySceneID provides a mock function with given fields: sceneID -func (_m *SceneMarkerReaderWriter) FindBySceneID(sceneID int) ([]*models.SceneMarker, error) { - ret := _m.Called(sceneID) +// FindBySceneID provides a mock function with given fields: ctx, sceneID +func (_m *SceneMarkerReaderWriter) FindBySceneID(ctx context.Context, sceneID int) ([]*models.SceneMarker, error) { + ret := _m.Called(ctx, sceneID) var r0 []*models.SceneMarker - if rf, ok := ret.Get(0).(func(int) []*models.SceneMarker); ok { - r0 = rf(sceneID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.SceneMarker); ok { + r0 = rf(ctx, sceneID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.SceneMarker) @@ -107,8 +109,8 @@ func (_m *SceneMarkerReaderWriter) FindBySceneID(sceneID int) ([]*models.SceneMa } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(sceneID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, sceneID) } else { r1 = ret.Error(1) } @@ -116,13 +118,13 @@ func (_m *SceneMarkerReaderWriter) FindBySceneID(sceneID int) ([]*models.SceneMa return r0, r1 } -// FindMany provides a mock function with given fields: ids -func (_m *SceneMarkerReaderWriter) FindMany(ids []int) ([]*models.SceneMarker, error) { - ret := _m.Called(ids) +// FindMany provides a mock function with given fields: ctx, ids +func (_m *SceneMarkerReaderWriter) FindMany(ctx context.Context, ids []int) ([]*models.SceneMarker, error) { + ret := _m.Called(ctx, ids) var r0 []*models.SceneMarker - if rf, ok := ret.Get(0).(func([]int) []*models.SceneMarker); ok { - r0 = rf(ids) + if rf, ok := ret.Get(0).(func(context.Context, []int) []*models.SceneMarker); ok { + r0 = rf(ctx, ids) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.SceneMarker) @@ -130,8 +132,8 @@ func (_m *SceneMarkerReaderWriter) FindMany(ids []int) ([]*models.SceneMarker, e } var r1 error - if rf, ok := ret.Get(1).(func([]int) error); ok { - r1 = rf(ids) + if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok { + r1 = rf(ctx, ids) } else { r1 = ret.Error(1) } @@ -139,13 +141,13 @@ func (_m *SceneMarkerReaderWriter) FindMany(ids []int) ([]*models.SceneMarker, e return r0, r1 } -// GetMarkerStrings provides a mock function with given fields: q, sort -func (_m *SceneMarkerReaderWriter) GetMarkerStrings(q *string, sort *string) ([]*models.MarkerStringsResultType, error) { - ret := _m.Called(q, sort) +// GetMarkerStrings provides a mock function with given fields: ctx, q, sort +func (_m *SceneMarkerReaderWriter) GetMarkerStrings(ctx context.Context, q *string, sort *string) ([]*models.MarkerStringsResultType, error) { + ret := _m.Called(ctx, q, sort) var r0 []*models.MarkerStringsResultType - if rf, ok := ret.Get(0).(func(*string, *string) []*models.MarkerStringsResultType); ok { - r0 = rf(q, sort) + if rf, ok := ret.Get(0).(func(context.Context, *string, *string) []*models.MarkerStringsResultType); ok { + r0 = rf(ctx, q, sort) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.MarkerStringsResultType) @@ -153,8 +155,8 @@ func (_m *SceneMarkerReaderWriter) GetMarkerStrings(q *string, sort *string) ([] } var r1 error - if rf, ok := ret.Get(1).(func(*string, *string) error); ok { - r1 = rf(q, sort) + if rf, ok := ret.Get(1).(func(context.Context, *string, *string) error); ok { + r1 = rf(ctx, q, sort) } else { r1 = ret.Error(1) } @@ -162,13 +164,13 @@ func (_m *SceneMarkerReaderWriter) GetMarkerStrings(q *string, sort *string) ([] return r0, r1 } -// GetTagIDs provides a mock function with given fields: imageID -func (_m *SceneMarkerReaderWriter) GetTagIDs(imageID int) ([]int, error) { - ret := _m.Called(imageID) +// GetTagIDs provides a mock function with given fields: ctx, imageID +func (_m *SceneMarkerReaderWriter) GetTagIDs(ctx context.Context, imageID int) ([]int, error) { + ret := _m.Called(ctx, imageID) var r0 []int - if rf, ok := ret.Get(0).(func(int) []int); ok { - r0 = rf(imageID) + if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok { + r0 = rf(ctx, imageID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]int) @@ -176,8 +178,8 @@ func (_m *SceneMarkerReaderWriter) GetTagIDs(imageID int) ([]int, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(imageID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, imageID) } else { r1 = ret.Error(1) } @@ -185,13 +187,13 @@ func (_m *SceneMarkerReaderWriter) GetTagIDs(imageID int) ([]int, error) { return r0, r1 } -// Query provides a mock function with given fields: sceneMarkerFilter, findFilter -func (_m *SceneMarkerReaderWriter) Query(sceneMarkerFilter *models.SceneMarkerFilterType, findFilter *models.FindFilterType) ([]*models.SceneMarker, int, error) { - ret := _m.Called(sceneMarkerFilter, findFilter) +// Query provides a mock function with given fields: ctx, sceneMarkerFilter, findFilter +func (_m *SceneMarkerReaderWriter) Query(ctx context.Context, sceneMarkerFilter *models.SceneMarkerFilterType, findFilter *models.FindFilterType) ([]*models.SceneMarker, int, error) { + ret := _m.Called(ctx, sceneMarkerFilter, findFilter) var r0 []*models.SceneMarker - if rf, ok := ret.Get(0).(func(*models.SceneMarkerFilterType, *models.FindFilterType) []*models.SceneMarker); ok { - r0 = rf(sceneMarkerFilter, findFilter) + if rf, ok := ret.Get(0).(func(context.Context, *models.SceneMarkerFilterType, *models.FindFilterType) []*models.SceneMarker); ok { + r0 = rf(ctx, sceneMarkerFilter, findFilter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.SceneMarker) @@ -199,15 +201,15 @@ func (_m *SceneMarkerReaderWriter) Query(sceneMarkerFilter *models.SceneMarkerFi } var r1 int - if rf, ok := ret.Get(1).(func(*models.SceneMarkerFilterType, *models.FindFilterType) int); ok { - r1 = rf(sceneMarkerFilter, findFilter) + if rf, ok := ret.Get(1).(func(context.Context, *models.SceneMarkerFilterType, *models.FindFilterType) int); ok { + r1 = rf(ctx, sceneMarkerFilter, findFilter) } else { r1 = ret.Get(1).(int) } var r2 error - if rf, ok := ret.Get(2).(func(*models.SceneMarkerFilterType, *models.FindFilterType) error); ok { - r2 = rf(sceneMarkerFilter, findFilter) + if rf, ok := ret.Get(2).(func(context.Context, *models.SceneMarkerFilterType, *models.FindFilterType) error); ok { + r2 = rf(ctx, sceneMarkerFilter, findFilter) } else { r2 = ret.Error(2) } @@ -215,13 +217,13 @@ func (_m *SceneMarkerReaderWriter) Query(sceneMarkerFilter *models.SceneMarkerFi return r0, r1, r2 } -// Update provides a mock function with given fields: updatedSceneMarker -func (_m *SceneMarkerReaderWriter) Update(updatedSceneMarker models.SceneMarker) (*models.SceneMarker, error) { - ret := _m.Called(updatedSceneMarker) +// Update provides a mock function with given fields: ctx, updatedSceneMarker +func (_m *SceneMarkerReaderWriter) Update(ctx context.Context, updatedSceneMarker models.SceneMarker) (*models.SceneMarker, error) { + ret := _m.Called(ctx, updatedSceneMarker) var r0 *models.SceneMarker - if rf, ok := ret.Get(0).(func(models.SceneMarker) *models.SceneMarker); ok { - r0 = rf(updatedSceneMarker) + if rf, ok := ret.Get(0).(func(context.Context, models.SceneMarker) *models.SceneMarker); ok { + r0 = rf(ctx, updatedSceneMarker) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.SceneMarker) @@ -229,8 +231,8 @@ func (_m *SceneMarkerReaderWriter) Update(updatedSceneMarker models.SceneMarker) } var r1 error - if rf, ok := ret.Get(1).(func(models.SceneMarker) error); ok { - r1 = rf(updatedSceneMarker) + if rf, ok := ret.Get(1).(func(context.Context, models.SceneMarker) error); ok { + r1 = rf(ctx, updatedSceneMarker) } else { r1 = ret.Error(1) } @@ -238,13 +240,13 @@ func (_m *SceneMarkerReaderWriter) Update(updatedSceneMarker models.SceneMarker) return r0, r1 } -// UpdateTags provides a mock function with given fields: markerID, tagIDs -func (_m *SceneMarkerReaderWriter) UpdateTags(markerID int, tagIDs []int) error { - ret := _m.Called(markerID, tagIDs) +// UpdateTags provides a mock function with given fields: ctx, markerID, tagIDs +func (_m *SceneMarkerReaderWriter) UpdateTags(ctx context.Context, markerID int, tagIDs []int) error { + ret := _m.Called(ctx, markerID, tagIDs) var r0 error - if rf, ok := ret.Get(0).(func(int, []int) error); ok { - r0 = rf(markerID, tagIDs) + if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok { + r0 = rf(ctx, markerID, tagIDs) } else { r0 = ret.Error(0) } @@ -252,13 +254,13 @@ func (_m *SceneMarkerReaderWriter) UpdateTags(markerID int, tagIDs []int) error return r0 } -// Wall provides a mock function with given fields: q -func (_m *SceneMarkerReaderWriter) Wall(q *string) ([]*models.SceneMarker, error) { - ret := _m.Called(q) +// Wall provides a mock function with given fields: ctx, q +func (_m *SceneMarkerReaderWriter) Wall(ctx context.Context, q *string) ([]*models.SceneMarker, error) { + ret := _m.Called(ctx, q) var r0 []*models.SceneMarker - if rf, ok := ret.Get(0).(func(*string) []*models.SceneMarker); ok { - r0 = rf(q) + if rf, ok := ret.Get(0).(func(context.Context, *string) []*models.SceneMarker); ok { + r0 = rf(ctx, q) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.SceneMarker) @@ -266,8 +268,8 @@ func (_m *SceneMarkerReaderWriter) Wall(q *string) ([]*models.SceneMarker, error } var r1 error - if rf, ok := ret.Get(1).(func(*string) error); ok { - r1 = rf(q) + if rf, ok := ret.Get(1).(func(context.Context, *string) error); ok { + r1 = rf(ctx, q) } else { r1 = ret.Error(1) } diff --git a/pkg/models/mocks/SceneReaderWriter.go b/pkg/models/mocks/SceneReaderWriter.go index 0635fd200..a9ab69097 100644 --- a/pkg/models/mocks/SceneReaderWriter.go +++ b/pkg/models/mocks/SceneReaderWriter.go @@ -3,6 +3,8 @@ package mocks import ( + context "context" + models "github.com/stashapp/stash/pkg/models" mock "github.com/stretchr/testify/mock" ) @@ -12,13 +14,13 @@ type SceneReaderWriter struct { mock.Mock } -// All provides a mock function with given fields: -func (_m *SceneReaderWriter) All() ([]*models.Scene, error) { - ret := _m.Called() +// All provides a mock function with given fields: ctx +func (_m *SceneReaderWriter) All(ctx context.Context) ([]*models.Scene, error) { + ret := _m.Called(ctx) var r0 []*models.Scene - if rf, ok := ret.Get(0).(func() []*models.Scene); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) []*models.Scene); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Scene) @@ -26,8 +28,8 @@ func (_m *SceneReaderWriter) All() ([]*models.Scene, error) { } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -35,20 +37,20 @@ func (_m *SceneReaderWriter) All() ([]*models.Scene, error) { return r0, r1 } -// Count provides a mock function with given fields: -func (_m *SceneReaderWriter) Count() (int, error) { - ret := _m.Called() +// Count provides a mock function with given fields: ctx +func (_m *SceneReaderWriter) Count(ctx context.Context) (int, error) { + ret := _m.Called(ctx) var r0 int - if rf, ok := ret.Get(0).(func() int); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) int); ok { + r0 = rf(ctx) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -56,20 +58,20 @@ func (_m *SceneReaderWriter) Count() (int, error) { return r0, r1 } -// CountByMovieID provides a mock function with given fields: movieID -func (_m *SceneReaderWriter) CountByMovieID(movieID int) (int, error) { - ret := _m.Called(movieID) +// CountByMovieID provides a mock function with given fields: ctx, movieID +func (_m *SceneReaderWriter) CountByMovieID(ctx context.Context, movieID int) (int, error) { + ret := _m.Called(ctx, movieID) var r0 int - if rf, ok := ret.Get(0).(func(int) int); ok { - r0 = rf(movieID) + if rf, ok := ret.Get(0).(func(context.Context, int) int); ok { + r0 = rf(ctx, movieID) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(movieID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, movieID) } else { r1 = ret.Error(1) } @@ -77,20 +79,20 @@ func (_m *SceneReaderWriter) CountByMovieID(movieID int) (int, error) { return r0, r1 } -// CountByPerformerID provides a mock function with given fields: performerID -func (_m *SceneReaderWriter) CountByPerformerID(performerID int) (int, error) { - ret := _m.Called(performerID) +// CountByPerformerID provides a mock function with given fields: ctx, performerID +func (_m *SceneReaderWriter) CountByPerformerID(ctx context.Context, performerID int) (int, error) { + ret := _m.Called(ctx, performerID) var r0 int - if rf, ok := ret.Get(0).(func(int) int); ok { - r0 = rf(performerID) + if rf, ok := ret.Get(0).(func(context.Context, int) int); ok { + r0 = rf(ctx, performerID) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(performerID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, performerID) } else { r1 = ret.Error(1) } @@ -98,20 +100,20 @@ func (_m *SceneReaderWriter) CountByPerformerID(performerID int) (int, error) { return r0, r1 } -// CountByStudioID provides a mock function with given fields: studioID -func (_m *SceneReaderWriter) CountByStudioID(studioID int) (int, error) { - ret := _m.Called(studioID) +// CountByStudioID provides a mock function with given fields: ctx, studioID +func (_m *SceneReaderWriter) CountByStudioID(ctx context.Context, studioID int) (int, error) { + ret := _m.Called(ctx, studioID) var r0 int - if rf, ok := ret.Get(0).(func(int) int); ok { - r0 = rf(studioID) + if rf, ok := ret.Get(0).(func(context.Context, int) int); ok { + r0 = rf(ctx, studioID) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(studioID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, studioID) } else { r1 = ret.Error(1) } @@ -119,20 +121,20 @@ func (_m *SceneReaderWriter) CountByStudioID(studioID int) (int, error) { return r0, r1 } -// CountByTagID provides a mock function with given fields: tagID -func (_m *SceneReaderWriter) CountByTagID(tagID int) (int, error) { - ret := _m.Called(tagID) +// CountByTagID provides a mock function with given fields: ctx, tagID +func (_m *SceneReaderWriter) CountByTagID(ctx context.Context, tagID int) (int, error) { + ret := _m.Called(ctx, tagID) var r0 int - if rf, ok := ret.Get(0).(func(int) int); ok { - r0 = rf(tagID) + if rf, ok := ret.Get(0).(func(context.Context, int) int); ok { + r0 = rf(ctx, tagID) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(tagID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, tagID) } else { r1 = ret.Error(1) } @@ -140,20 +142,20 @@ func (_m *SceneReaderWriter) CountByTagID(tagID int) (int, error) { return r0, r1 } -// CountMissingChecksum provides a mock function with given fields: -func (_m *SceneReaderWriter) CountMissingChecksum() (int, error) { - ret := _m.Called() +// CountMissingChecksum provides a mock function with given fields: ctx +func (_m *SceneReaderWriter) CountMissingChecksum(ctx context.Context) (int, error) { + ret := _m.Called(ctx) var r0 int - if rf, ok := ret.Get(0).(func() int); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) int); ok { + r0 = rf(ctx) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -161,20 +163,20 @@ func (_m *SceneReaderWriter) CountMissingChecksum() (int, error) { return r0, r1 } -// CountMissingOSHash provides a mock function with given fields: -func (_m *SceneReaderWriter) CountMissingOSHash() (int, error) { - ret := _m.Called() +// CountMissingOSHash provides a mock function with given fields: ctx +func (_m *SceneReaderWriter) CountMissingOSHash(ctx context.Context) (int, error) { + ret := _m.Called(ctx) var r0 int - if rf, ok := ret.Get(0).(func() int); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) int); ok { + r0 = rf(ctx) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -182,13 +184,13 @@ func (_m *SceneReaderWriter) CountMissingOSHash() (int, error) { return r0, r1 } -// Create provides a mock function with given fields: newScene -func (_m *SceneReaderWriter) Create(newScene models.Scene) (*models.Scene, error) { - ret := _m.Called(newScene) +// Create provides a mock function with given fields: ctx, newScene +func (_m *SceneReaderWriter) Create(ctx context.Context, newScene models.Scene) (*models.Scene, error) { + ret := _m.Called(ctx, newScene) var r0 *models.Scene - if rf, ok := ret.Get(0).(func(models.Scene) *models.Scene); ok { - r0 = rf(newScene) + if rf, ok := ret.Get(0).(func(context.Context, models.Scene) *models.Scene); ok { + r0 = rf(ctx, newScene) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Scene) @@ -196,8 +198,8 @@ func (_m *SceneReaderWriter) Create(newScene models.Scene) (*models.Scene, error } var r1 error - if rf, ok := ret.Get(1).(func(models.Scene) error); ok { - r1 = rf(newScene) + if rf, ok := ret.Get(1).(func(context.Context, models.Scene) error); ok { + r1 = rf(ctx, newScene) } else { r1 = ret.Error(1) } @@ -205,20 +207,20 @@ func (_m *SceneReaderWriter) Create(newScene models.Scene) (*models.Scene, error return r0, r1 } -// DecrementOCounter provides a mock function with given fields: id -func (_m *SceneReaderWriter) DecrementOCounter(id int) (int, error) { - ret := _m.Called(id) +// DecrementOCounter provides a mock function with given fields: ctx, id +func (_m *SceneReaderWriter) DecrementOCounter(ctx context.Context, id int) (int, error) { + ret := _m.Called(ctx, id) var r0 int - if rf, ok := ret.Get(0).(func(int) int); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) int); ok { + r0 = rf(ctx, id) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -226,13 +228,13 @@ func (_m *SceneReaderWriter) DecrementOCounter(id int) (int, error) { return r0, r1 } -// Destroy provides a mock function with given fields: id -func (_m *SceneReaderWriter) Destroy(id int) error { - ret := _m.Called(id) +// Destroy provides a mock function with given fields: ctx, id +func (_m *SceneReaderWriter) Destroy(ctx context.Context, id int) error { + ret := _m.Called(ctx, id) var r0 error - if rf, ok := ret.Get(0).(func(int) error); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) error); ok { + r0 = rf(ctx, id) } else { r0 = ret.Error(0) } @@ -240,13 +242,13 @@ func (_m *SceneReaderWriter) Destroy(id int) error { return r0 } -// DestroyCover provides a mock function with given fields: sceneID -func (_m *SceneReaderWriter) DestroyCover(sceneID int) error { - ret := _m.Called(sceneID) +// DestroyCover provides a mock function with given fields: ctx, sceneID +func (_m *SceneReaderWriter) DestroyCover(ctx context.Context, sceneID int) error { + ret := _m.Called(ctx, sceneID) var r0 error - if rf, ok := ret.Get(0).(func(int) error); ok { - r0 = rf(sceneID) + if rf, ok := ret.Get(0).(func(context.Context, int) error); ok { + r0 = rf(ctx, sceneID) } else { r0 = ret.Error(0) } @@ -254,20 +256,20 @@ func (_m *SceneReaderWriter) DestroyCover(sceneID int) error { return r0 } -// Duration provides a mock function with given fields: -func (_m *SceneReaderWriter) Duration() (float64, error) { - ret := _m.Called() +// Duration provides a mock function with given fields: ctx +func (_m *SceneReaderWriter) Duration(ctx context.Context) (float64, error) { + ret := _m.Called(ctx) var r0 float64 - if rf, ok := ret.Get(0).(func() float64); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) float64); ok { + r0 = rf(ctx) } else { r0 = ret.Get(0).(float64) } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -275,13 +277,13 @@ func (_m *SceneReaderWriter) Duration() (float64, error) { return r0, r1 } -// Find provides a mock function with given fields: id -func (_m *SceneReaderWriter) Find(id int) (*models.Scene, error) { - ret := _m.Called(id) +// Find provides a mock function with given fields: ctx, id +func (_m *SceneReaderWriter) Find(ctx context.Context, id int) (*models.Scene, error) { + ret := _m.Called(ctx, id) var r0 *models.Scene - if rf, ok := ret.Get(0).(func(int) *models.Scene); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) *models.Scene); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Scene) @@ -289,8 +291,8 @@ func (_m *SceneReaderWriter) Find(id int) (*models.Scene, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -298,13 +300,13 @@ func (_m *SceneReaderWriter) Find(id int) (*models.Scene, error) { return r0, r1 } -// FindByChecksum provides a mock function with given fields: checksum -func (_m *SceneReaderWriter) FindByChecksum(checksum string) (*models.Scene, error) { - ret := _m.Called(checksum) +// FindByChecksum provides a mock function with given fields: ctx, checksum +func (_m *SceneReaderWriter) FindByChecksum(ctx context.Context, checksum string) (*models.Scene, error) { + ret := _m.Called(ctx, checksum) var r0 *models.Scene - if rf, ok := ret.Get(0).(func(string) *models.Scene); ok { - r0 = rf(checksum) + if rf, ok := ret.Get(0).(func(context.Context, string) *models.Scene); ok { + r0 = rf(ctx, checksum) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Scene) @@ -312,8 +314,8 @@ func (_m *SceneReaderWriter) FindByChecksum(checksum string) (*models.Scene, err } var r1 error - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(checksum) + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, checksum) } else { r1 = ret.Error(1) } @@ -321,13 +323,13 @@ func (_m *SceneReaderWriter) FindByChecksum(checksum string) (*models.Scene, err return r0, r1 } -// FindByGalleryID provides a mock function with given fields: performerID -func (_m *SceneReaderWriter) FindByGalleryID(performerID int) ([]*models.Scene, error) { - ret := _m.Called(performerID) +// FindByGalleryID provides a mock function with given fields: ctx, performerID +func (_m *SceneReaderWriter) FindByGalleryID(ctx context.Context, performerID int) ([]*models.Scene, error) { + ret := _m.Called(ctx, performerID) var r0 []*models.Scene - if rf, ok := ret.Get(0).(func(int) []*models.Scene); ok { - r0 = rf(performerID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Scene); ok { + r0 = rf(ctx, performerID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Scene) @@ -335,8 +337,8 @@ func (_m *SceneReaderWriter) FindByGalleryID(performerID int) ([]*models.Scene, } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(performerID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, performerID) } else { r1 = ret.Error(1) } @@ -344,13 +346,13 @@ func (_m *SceneReaderWriter) FindByGalleryID(performerID int) ([]*models.Scene, return r0, r1 } -// FindByMovieID provides a mock function with given fields: movieID -func (_m *SceneReaderWriter) FindByMovieID(movieID int) ([]*models.Scene, error) { - ret := _m.Called(movieID) +// FindByMovieID provides a mock function with given fields: ctx, movieID +func (_m *SceneReaderWriter) FindByMovieID(ctx context.Context, movieID int) ([]*models.Scene, error) { + ret := _m.Called(ctx, movieID) var r0 []*models.Scene - if rf, ok := ret.Get(0).(func(int) []*models.Scene); ok { - r0 = rf(movieID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Scene); ok { + r0 = rf(ctx, movieID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Scene) @@ -358,8 +360,8 @@ func (_m *SceneReaderWriter) FindByMovieID(movieID int) ([]*models.Scene, error) } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(movieID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, movieID) } else { r1 = ret.Error(1) } @@ -367,13 +369,13 @@ func (_m *SceneReaderWriter) FindByMovieID(movieID int) ([]*models.Scene, error) return r0, r1 } -// FindByOSHash provides a mock function with given fields: oshash -func (_m *SceneReaderWriter) FindByOSHash(oshash string) (*models.Scene, error) { - ret := _m.Called(oshash) +// FindByOSHash provides a mock function with given fields: ctx, oshash +func (_m *SceneReaderWriter) FindByOSHash(ctx context.Context, oshash string) (*models.Scene, error) { + ret := _m.Called(ctx, oshash) var r0 *models.Scene - if rf, ok := ret.Get(0).(func(string) *models.Scene); ok { - r0 = rf(oshash) + if rf, ok := ret.Get(0).(func(context.Context, string) *models.Scene); ok { + r0 = rf(ctx, oshash) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Scene) @@ -381,8 +383,8 @@ func (_m *SceneReaderWriter) FindByOSHash(oshash string) (*models.Scene, error) } var r1 error - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(oshash) + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, oshash) } else { r1 = ret.Error(1) } @@ -390,13 +392,13 @@ func (_m *SceneReaderWriter) FindByOSHash(oshash string) (*models.Scene, error) return r0, r1 } -// FindByPath provides a mock function with given fields: path -func (_m *SceneReaderWriter) FindByPath(path string) (*models.Scene, error) { - ret := _m.Called(path) +// FindByPath provides a mock function with given fields: ctx, path +func (_m *SceneReaderWriter) FindByPath(ctx context.Context, path string) (*models.Scene, error) { + ret := _m.Called(ctx, path) var r0 *models.Scene - if rf, ok := ret.Get(0).(func(string) *models.Scene); ok { - r0 = rf(path) + if rf, ok := ret.Get(0).(func(context.Context, string) *models.Scene); ok { + r0 = rf(ctx, path) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Scene) @@ -404,8 +406,8 @@ func (_m *SceneReaderWriter) FindByPath(path string) (*models.Scene, error) { } var r1 error - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(path) + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, path) } else { r1 = ret.Error(1) } @@ -413,13 +415,13 @@ func (_m *SceneReaderWriter) FindByPath(path string) (*models.Scene, error) { return r0, r1 } -// FindByPerformerID provides a mock function with given fields: performerID -func (_m *SceneReaderWriter) FindByPerformerID(performerID int) ([]*models.Scene, error) { - ret := _m.Called(performerID) +// FindByPerformerID provides a mock function with given fields: ctx, performerID +func (_m *SceneReaderWriter) FindByPerformerID(ctx context.Context, performerID int) ([]*models.Scene, error) { + ret := _m.Called(ctx, performerID) var r0 []*models.Scene - if rf, ok := ret.Get(0).(func(int) []*models.Scene); ok { - r0 = rf(performerID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Scene); ok { + r0 = rf(ctx, performerID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Scene) @@ -427,8 +429,8 @@ func (_m *SceneReaderWriter) FindByPerformerID(performerID int) ([]*models.Scene } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(performerID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, performerID) } else { r1 = ret.Error(1) } @@ -436,13 +438,13 @@ func (_m *SceneReaderWriter) FindByPerformerID(performerID int) ([]*models.Scene return r0, r1 } -// FindDuplicates provides a mock function with given fields: distance -func (_m *SceneReaderWriter) FindDuplicates(distance int) ([][]*models.Scene, error) { - ret := _m.Called(distance) +// FindDuplicates provides a mock function with given fields: ctx, distance +func (_m *SceneReaderWriter) FindDuplicates(ctx context.Context, distance int) ([][]*models.Scene, error) { + ret := _m.Called(ctx, distance) var r0 [][]*models.Scene - if rf, ok := ret.Get(0).(func(int) [][]*models.Scene); ok { - r0 = rf(distance) + if rf, ok := ret.Get(0).(func(context.Context, int) [][]*models.Scene); ok { + r0 = rf(ctx, distance) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([][]*models.Scene) @@ -450,8 +452,8 @@ func (_m *SceneReaderWriter) FindDuplicates(distance int) ([][]*models.Scene, er } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(distance) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, distance) } else { r1 = ret.Error(1) } @@ -459,13 +461,13 @@ func (_m *SceneReaderWriter) FindDuplicates(distance int) ([][]*models.Scene, er return r0, r1 } -// FindMany provides a mock function with given fields: ids -func (_m *SceneReaderWriter) FindMany(ids []int) ([]*models.Scene, error) { - ret := _m.Called(ids) +// FindMany provides a mock function with given fields: ctx, ids +func (_m *SceneReaderWriter) FindMany(ctx context.Context, ids []int) ([]*models.Scene, error) { + ret := _m.Called(ctx, ids) var r0 []*models.Scene - if rf, ok := ret.Get(0).(func([]int) []*models.Scene); ok { - r0 = rf(ids) + if rf, ok := ret.Get(0).(func(context.Context, []int) []*models.Scene); ok { + r0 = rf(ctx, ids) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Scene) @@ -473,8 +475,8 @@ func (_m *SceneReaderWriter) FindMany(ids []int) ([]*models.Scene, error) { } var r1 error - if rf, ok := ret.Get(1).(func([]int) error); ok { - r1 = rf(ids) + if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok { + r1 = rf(ctx, ids) } else { r1 = ret.Error(1) } @@ -482,13 +484,13 @@ func (_m *SceneReaderWriter) FindMany(ids []int) ([]*models.Scene, error) { return r0, r1 } -// GetCaptions provides a mock function with given fields: sceneID -func (_m *SceneReaderWriter) GetCaptions(sceneID int) ([]*models.SceneCaption, error) { - ret := _m.Called(sceneID) +// GetCaptions provides a mock function with given fields: ctx, sceneID +func (_m *SceneReaderWriter) GetCaptions(ctx context.Context, sceneID int) ([]*models.SceneCaption, error) { + ret := _m.Called(ctx, sceneID) var r0 []*models.SceneCaption - if rf, ok := ret.Get(0).(func(int) []*models.SceneCaption); ok { - r0 = rf(sceneID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.SceneCaption); ok { + r0 = rf(ctx, sceneID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.SceneCaption) @@ -496,8 +498,8 @@ func (_m *SceneReaderWriter) GetCaptions(sceneID int) ([]*models.SceneCaption, e } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(sceneID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, sceneID) } else { r1 = ret.Error(1) } @@ -505,13 +507,13 @@ func (_m *SceneReaderWriter) GetCaptions(sceneID int) ([]*models.SceneCaption, e return r0, r1 } -// GetCover provides a mock function with given fields: sceneID -func (_m *SceneReaderWriter) GetCover(sceneID int) ([]byte, error) { - ret := _m.Called(sceneID) +// GetCover provides a mock function with given fields: ctx, sceneID +func (_m *SceneReaderWriter) GetCover(ctx context.Context, sceneID int) ([]byte, error) { + ret := _m.Called(ctx, sceneID) var r0 []byte - if rf, ok := ret.Get(0).(func(int) []byte); ok { - r0 = rf(sceneID) + if rf, ok := ret.Get(0).(func(context.Context, int) []byte); ok { + r0 = rf(ctx, sceneID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]byte) @@ -519,8 +521,8 @@ func (_m *SceneReaderWriter) GetCover(sceneID int) ([]byte, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(sceneID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, sceneID) } else { r1 = ret.Error(1) } @@ -528,13 +530,13 @@ func (_m *SceneReaderWriter) GetCover(sceneID int) ([]byte, error) { return r0, r1 } -// GetGalleryIDs provides a mock function with given fields: sceneID -func (_m *SceneReaderWriter) GetGalleryIDs(sceneID int) ([]int, error) { - ret := _m.Called(sceneID) +// GetGalleryIDs provides a mock function with given fields: ctx, sceneID +func (_m *SceneReaderWriter) GetGalleryIDs(ctx context.Context, sceneID int) ([]int, error) { + ret := _m.Called(ctx, sceneID) var r0 []int - if rf, ok := ret.Get(0).(func(int) []int); ok { - r0 = rf(sceneID) + if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok { + r0 = rf(ctx, sceneID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]int) @@ -542,8 +544,8 @@ func (_m *SceneReaderWriter) GetGalleryIDs(sceneID int) ([]int, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(sceneID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, sceneID) } else { r1 = ret.Error(1) } @@ -551,13 +553,13 @@ func (_m *SceneReaderWriter) GetGalleryIDs(sceneID int) ([]int, error) { return r0, r1 } -// GetMovies provides a mock function with given fields: sceneID -func (_m *SceneReaderWriter) GetMovies(sceneID int) ([]models.MoviesScenes, error) { - ret := _m.Called(sceneID) +// GetMovies provides a mock function with given fields: ctx, sceneID +func (_m *SceneReaderWriter) GetMovies(ctx context.Context, sceneID int) ([]models.MoviesScenes, error) { + ret := _m.Called(ctx, sceneID) var r0 []models.MoviesScenes - if rf, ok := ret.Get(0).(func(int) []models.MoviesScenes); ok { - r0 = rf(sceneID) + if rf, ok := ret.Get(0).(func(context.Context, int) []models.MoviesScenes); ok { + r0 = rf(ctx, sceneID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]models.MoviesScenes) @@ -565,8 +567,8 @@ func (_m *SceneReaderWriter) GetMovies(sceneID int) ([]models.MoviesScenes, erro } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(sceneID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, sceneID) } else { r1 = ret.Error(1) } @@ -574,13 +576,13 @@ func (_m *SceneReaderWriter) GetMovies(sceneID int) ([]models.MoviesScenes, erro return r0, r1 } -// GetPerformerIDs provides a mock function with given fields: sceneID -func (_m *SceneReaderWriter) GetPerformerIDs(sceneID int) ([]int, error) { - ret := _m.Called(sceneID) +// GetPerformerIDs provides a mock function with given fields: ctx, sceneID +func (_m *SceneReaderWriter) GetPerformerIDs(ctx context.Context, sceneID int) ([]int, error) { + ret := _m.Called(ctx, sceneID) var r0 []int - if rf, ok := ret.Get(0).(func(int) []int); ok { - r0 = rf(sceneID) + if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok { + r0 = rf(ctx, sceneID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]int) @@ -588,8 +590,8 @@ func (_m *SceneReaderWriter) GetPerformerIDs(sceneID int) ([]int, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(sceneID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, sceneID) } else { r1 = ret.Error(1) } @@ -597,13 +599,13 @@ func (_m *SceneReaderWriter) GetPerformerIDs(sceneID int) ([]int, error) { return r0, r1 } -// GetStashIDs provides a mock function with given fields: sceneID -func (_m *SceneReaderWriter) GetStashIDs(sceneID int) ([]*models.StashID, error) { - ret := _m.Called(sceneID) +// GetStashIDs provides a mock function with given fields: ctx, sceneID +func (_m *SceneReaderWriter) GetStashIDs(ctx context.Context, sceneID int) ([]*models.StashID, error) { + ret := _m.Called(ctx, sceneID) var r0 []*models.StashID - if rf, ok := ret.Get(0).(func(int) []*models.StashID); ok { - r0 = rf(sceneID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.StashID); ok { + r0 = rf(ctx, sceneID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.StashID) @@ -611,8 +613,8 @@ func (_m *SceneReaderWriter) GetStashIDs(sceneID int) ([]*models.StashID, error) } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(sceneID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, sceneID) } else { r1 = ret.Error(1) } @@ -620,13 +622,13 @@ func (_m *SceneReaderWriter) GetStashIDs(sceneID int) ([]*models.StashID, error) return r0, r1 } -// GetTagIDs provides a mock function with given fields: sceneID -func (_m *SceneReaderWriter) GetTagIDs(sceneID int) ([]int, error) { - ret := _m.Called(sceneID) +// GetTagIDs provides a mock function with given fields: ctx, sceneID +func (_m *SceneReaderWriter) GetTagIDs(ctx context.Context, sceneID int) ([]int, error) { + ret := _m.Called(ctx, sceneID) var r0 []int - if rf, ok := ret.Get(0).(func(int) []int); ok { - r0 = rf(sceneID) + if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok { + r0 = rf(ctx, sceneID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]int) @@ -634,8 +636,8 @@ func (_m *SceneReaderWriter) GetTagIDs(sceneID int) ([]int, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(sceneID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, sceneID) } else { r1 = ret.Error(1) } @@ -643,20 +645,20 @@ func (_m *SceneReaderWriter) GetTagIDs(sceneID int) ([]int, error) { return r0, r1 } -// IncrementOCounter provides a mock function with given fields: id -func (_m *SceneReaderWriter) IncrementOCounter(id int) (int, error) { - ret := _m.Called(id) +// IncrementOCounter provides a mock function with given fields: ctx, id +func (_m *SceneReaderWriter) IncrementOCounter(ctx context.Context, id int) (int, error) { + ret := _m.Called(ctx, id) var r0 int - if rf, ok := ret.Get(0).(func(int) int); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) int); ok { + r0 = rf(ctx, id) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -664,13 +666,13 @@ func (_m *SceneReaderWriter) IncrementOCounter(id int) (int, error) { return r0, r1 } -// Query provides a mock function with given fields: options -func (_m *SceneReaderWriter) Query(options models.SceneQueryOptions) (*models.SceneQueryResult, error) { - ret := _m.Called(options) +// Query provides a mock function with given fields: ctx, options +func (_m *SceneReaderWriter) Query(ctx context.Context, options models.SceneQueryOptions) (*models.SceneQueryResult, error) { + ret := _m.Called(ctx, options) var r0 *models.SceneQueryResult - if rf, ok := ret.Get(0).(func(models.SceneQueryOptions) *models.SceneQueryResult); ok { - r0 = rf(options) + if rf, ok := ret.Get(0).(func(context.Context, models.SceneQueryOptions) *models.SceneQueryResult); ok { + r0 = rf(ctx, options) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.SceneQueryResult) @@ -678,8 +680,8 @@ func (_m *SceneReaderWriter) Query(options models.SceneQueryOptions) (*models.Sc } var r1 error - if rf, ok := ret.Get(1).(func(models.SceneQueryOptions) error); ok { - r1 = rf(options) + if rf, ok := ret.Get(1).(func(context.Context, models.SceneQueryOptions) error); ok { + r1 = rf(ctx, options) } else { r1 = ret.Error(1) } @@ -687,20 +689,20 @@ func (_m *SceneReaderWriter) Query(options models.SceneQueryOptions) (*models.Sc return r0, r1 } -// ResetOCounter provides a mock function with given fields: id -func (_m *SceneReaderWriter) ResetOCounter(id int) (int, error) { - ret := _m.Called(id) +// ResetOCounter provides a mock function with given fields: ctx, id +func (_m *SceneReaderWriter) ResetOCounter(ctx context.Context, id int) (int, error) { + ret := _m.Called(ctx, id) var r0 int - if rf, ok := ret.Get(0).(func(int) int); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) int); ok { + r0 = rf(ctx, id) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -708,20 +710,20 @@ func (_m *SceneReaderWriter) ResetOCounter(id int) (int, error) { return r0, r1 } -// Size provides a mock function with given fields: -func (_m *SceneReaderWriter) Size() (float64, error) { - ret := _m.Called() +// Size provides a mock function with given fields: ctx +func (_m *SceneReaderWriter) Size(ctx context.Context) (float64, error) { + ret := _m.Called(ctx) var r0 float64 - if rf, ok := ret.Get(0).(func() float64); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) float64); ok { + r0 = rf(ctx) } else { r0 = ret.Get(0).(float64) } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -729,13 +731,13 @@ func (_m *SceneReaderWriter) Size() (float64, error) { return r0, r1 } -// Update provides a mock function with given fields: updatedScene -func (_m *SceneReaderWriter) Update(updatedScene models.ScenePartial) (*models.Scene, error) { - ret := _m.Called(updatedScene) +// Update provides a mock function with given fields: ctx, updatedScene +func (_m *SceneReaderWriter) Update(ctx context.Context, updatedScene models.ScenePartial) (*models.Scene, error) { + ret := _m.Called(ctx, updatedScene) var r0 *models.Scene - if rf, ok := ret.Get(0).(func(models.ScenePartial) *models.Scene); ok { - r0 = rf(updatedScene) + if rf, ok := ret.Get(0).(func(context.Context, models.ScenePartial) *models.Scene); ok { + r0 = rf(ctx, updatedScene) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Scene) @@ -743,8 +745,8 @@ func (_m *SceneReaderWriter) Update(updatedScene models.ScenePartial) (*models.S } var r1 error - if rf, ok := ret.Get(1).(func(models.ScenePartial) error); ok { - r1 = rf(updatedScene) + if rf, ok := ret.Get(1).(func(context.Context, models.ScenePartial) error); ok { + r1 = rf(ctx, updatedScene) } else { r1 = ret.Error(1) } @@ -752,13 +754,13 @@ func (_m *SceneReaderWriter) Update(updatedScene models.ScenePartial) (*models.S return r0, r1 } -// UpdateCaptions provides a mock function with given fields: id, captions -func (_m *SceneReaderWriter) UpdateCaptions(id int, captions []*models.SceneCaption) error { - ret := _m.Called(id, captions) +// UpdateCaptions provides a mock function with given fields: ctx, id, captions +func (_m *SceneReaderWriter) UpdateCaptions(ctx context.Context, id int, captions []*models.SceneCaption) error { + ret := _m.Called(ctx, id, captions) var r0 error - if rf, ok := ret.Get(0).(func(int, []*models.SceneCaption) error); ok { - r0 = rf(id, captions) + if rf, ok := ret.Get(0).(func(context.Context, int, []*models.SceneCaption) error); ok { + r0 = rf(ctx, id, captions) } else { r0 = ret.Error(0) } @@ -766,13 +768,13 @@ func (_m *SceneReaderWriter) UpdateCaptions(id int, captions []*models.SceneCapt return r0 } -// UpdateCover provides a mock function with given fields: sceneID, cover -func (_m *SceneReaderWriter) UpdateCover(sceneID int, cover []byte) error { - ret := _m.Called(sceneID, cover) +// UpdateCover provides a mock function with given fields: ctx, sceneID, cover +func (_m *SceneReaderWriter) UpdateCover(ctx context.Context, sceneID int, cover []byte) error { + ret := _m.Called(ctx, sceneID, cover) var r0 error - if rf, ok := ret.Get(0).(func(int, []byte) error); ok { - r0 = rf(sceneID, cover) + if rf, ok := ret.Get(0).(func(context.Context, int, []byte) error); ok { + r0 = rf(ctx, sceneID, cover) } else { r0 = ret.Error(0) } @@ -780,13 +782,13 @@ func (_m *SceneReaderWriter) UpdateCover(sceneID int, cover []byte) error { return r0 } -// UpdateFileModTime provides a mock function with given fields: id, modTime -func (_m *SceneReaderWriter) UpdateFileModTime(id int, modTime models.NullSQLiteTimestamp) error { - ret := _m.Called(id, modTime) +// UpdateFileModTime provides a mock function with given fields: ctx, id, modTime +func (_m *SceneReaderWriter) UpdateFileModTime(ctx context.Context, id int, modTime models.NullSQLiteTimestamp) error { + ret := _m.Called(ctx, id, modTime) var r0 error - if rf, ok := ret.Get(0).(func(int, models.NullSQLiteTimestamp) error); ok { - r0 = rf(id, modTime) + if rf, ok := ret.Get(0).(func(context.Context, int, models.NullSQLiteTimestamp) error); ok { + r0 = rf(ctx, id, modTime) } else { r0 = ret.Error(0) } @@ -794,13 +796,13 @@ func (_m *SceneReaderWriter) UpdateFileModTime(id int, modTime models.NullSQLite return r0 } -// UpdateFull provides a mock function with given fields: updatedScene -func (_m *SceneReaderWriter) UpdateFull(updatedScene models.Scene) (*models.Scene, error) { - ret := _m.Called(updatedScene) +// UpdateFull provides a mock function with given fields: ctx, updatedScene +func (_m *SceneReaderWriter) UpdateFull(ctx context.Context, updatedScene models.Scene) (*models.Scene, error) { + ret := _m.Called(ctx, updatedScene) var r0 *models.Scene - if rf, ok := ret.Get(0).(func(models.Scene) *models.Scene); ok { - r0 = rf(updatedScene) + if rf, ok := ret.Get(0).(func(context.Context, models.Scene) *models.Scene); ok { + r0 = rf(ctx, updatedScene) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Scene) @@ -808,8 +810,8 @@ func (_m *SceneReaderWriter) UpdateFull(updatedScene models.Scene) (*models.Scen } var r1 error - if rf, ok := ret.Get(1).(func(models.Scene) error); ok { - r1 = rf(updatedScene) + if rf, ok := ret.Get(1).(func(context.Context, models.Scene) error); ok { + r1 = rf(ctx, updatedScene) } else { r1 = ret.Error(1) } @@ -817,13 +819,13 @@ func (_m *SceneReaderWriter) UpdateFull(updatedScene models.Scene) (*models.Scen return r0, r1 } -// UpdateGalleries provides a mock function with given fields: sceneID, galleryIDs -func (_m *SceneReaderWriter) UpdateGalleries(sceneID int, galleryIDs []int) error { - ret := _m.Called(sceneID, galleryIDs) +// UpdateGalleries provides a mock function with given fields: ctx, sceneID, galleryIDs +func (_m *SceneReaderWriter) UpdateGalleries(ctx context.Context, sceneID int, galleryIDs []int) error { + ret := _m.Called(ctx, sceneID, galleryIDs) var r0 error - if rf, ok := ret.Get(0).(func(int, []int) error); ok { - r0 = rf(sceneID, galleryIDs) + if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok { + r0 = rf(ctx, sceneID, galleryIDs) } else { r0 = ret.Error(0) } @@ -831,13 +833,13 @@ func (_m *SceneReaderWriter) UpdateGalleries(sceneID int, galleryIDs []int) erro return r0 } -// UpdateMovies provides a mock function with given fields: sceneID, movies -func (_m *SceneReaderWriter) UpdateMovies(sceneID int, movies []models.MoviesScenes) error { - ret := _m.Called(sceneID, movies) +// UpdateMovies provides a mock function with given fields: ctx, sceneID, movies +func (_m *SceneReaderWriter) UpdateMovies(ctx context.Context, sceneID int, movies []models.MoviesScenes) error { + ret := _m.Called(ctx, sceneID, movies) var r0 error - if rf, ok := ret.Get(0).(func(int, []models.MoviesScenes) error); ok { - r0 = rf(sceneID, movies) + if rf, ok := ret.Get(0).(func(context.Context, int, []models.MoviesScenes) error); ok { + r0 = rf(ctx, sceneID, movies) } else { r0 = ret.Error(0) } @@ -845,13 +847,13 @@ func (_m *SceneReaderWriter) UpdateMovies(sceneID int, movies []models.MoviesSce return r0 } -// UpdatePerformers provides a mock function with given fields: sceneID, performerIDs -func (_m *SceneReaderWriter) UpdatePerformers(sceneID int, performerIDs []int) error { - ret := _m.Called(sceneID, performerIDs) +// UpdatePerformers provides a mock function with given fields: ctx, sceneID, performerIDs +func (_m *SceneReaderWriter) UpdatePerformers(ctx context.Context, sceneID int, performerIDs []int) error { + ret := _m.Called(ctx, sceneID, performerIDs) var r0 error - if rf, ok := ret.Get(0).(func(int, []int) error); ok { - r0 = rf(sceneID, performerIDs) + if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok { + r0 = rf(ctx, sceneID, performerIDs) } else { r0 = ret.Error(0) } @@ -859,13 +861,13 @@ func (_m *SceneReaderWriter) UpdatePerformers(sceneID int, performerIDs []int) e return r0 } -// UpdateStashIDs provides a mock function with given fields: sceneID, stashIDs -func (_m *SceneReaderWriter) UpdateStashIDs(sceneID int, stashIDs []models.StashID) error { - ret := _m.Called(sceneID, stashIDs) +// UpdateStashIDs provides a mock function with given fields: ctx, sceneID, stashIDs +func (_m *SceneReaderWriter) UpdateStashIDs(ctx context.Context, sceneID int, stashIDs []models.StashID) error { + ret := _m.Called(ctx, sceneID, stashIDs) var r0 error - if rf, ok := ret.Get(0).(func(int, []models.StashID) error); ok { - r0 = rf(sceneID, stashIDs) + if rf, ok := ret.Get(0).(func(context.Context, int, []models.StashID) error); ok { + r0 = rf(ctx, sceneID, stashIDs) } else { r0 = ret.Error(0) } @@ -873,13 +875,13 @@ func (_m *SceneReaderWriter) UpdateStashIDs(sceneID int, stashIDs []models.Stash return r0 } -// UpdateTags provides a mock function with given fields: sceneID, tagIDs -func (_m *SceneReaderWriter) UpdateTags(sceneID int, tagIDs []int) error { - ret := _m.Called(sceneID, tagIDs) +// UpdateTags provides a mock function with given fields: ctx, sceneID, tagIDs +func (_m *SceneReaderWriter) UpdateTags(ctx context.Context, sceneID int, tagIDs []int) error { + ret := _m.Called(ctx, sceneID, tagIDs) var r0 error - if rf, ok := ret.Get(0).(func(int, []int) error); ok { - r0 = rf(sceneID, tagIDs) + if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok { + r0 = rf(ctx, sceneID, tagIDs) } else { r0 = ret.Error(0) } @@ -887,13 +889,13 @@ func (_m *SceneReaderWriter) UpdateTags(sceneID int, tagIDs []int) error { return r0 } -// Wall provides a mock function with given fields: q -func (_m *SceneReaderWriter) Wall(q *string) ([]*models.Scene, error) { - ret := _m.Called(q) +// Wall provides a mock function with given fields: ctx, q +func (_m *SceneReaderWriter) Wall(ctx context.Context, q *string) ([]*models.Scene, error) { + ret := _m.Called(ctx, q) var r0 []*models.Scene - if rf, ok := ret.Get(0).(func(*string) []*models.Scene); ok { - r0 = rf(q) + if rf, ok := ret.Get(0).(func(context.Context, *string) []*models.Scene); ok { + r0 = rf(ctx, q) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Scene) @@ -901,8 +903,8 @@ func (_m *SceneReaderWriter) Wall(q *string) ([]*models.Scene, error) { } var r1 error - if rf, ok := ret.Get(1).(func(*string) error); ok { - r1 = rf(q) + if rf, ok := ret.Get(1).(func(context.Context, *string) error); ok { + r1 = rf(ctx, q) } else { r1 = ret.Error(1) } diff --git a/pkg/models/mocks/ScrapedItemReaderWriter.go b/pkg/models/mocks/ScrapedItemReaderWriter.go index e06b7451d..7157ab855 100644 --- a/pkg/models/mocks/ScrapedItemReaderWriter.go +++ b/pkg/models/mocks/ScrapedItemReaderWriter.go @@ -3,6 +3,8 @@ package mocks import ( + context "context" + models "github.com/stashapp/stash/pkg/models" mock "github.com/stretchr/testify/mock" ) @@ -12,13 +14,13 @@ type ScrapedItemReaderWriter struct { mock.Mock } -// All provides a mock function with given fields: -func (_m *ScrapedItemReaderWriter) All() ([]*models.ScrapedItem, error) { - ret := _m.Called() +// All provides a mock function with given fields: ctx +func (_m *ScrapedItemReaderWriter) All(ctx context.Context) ([]*models.ScrapedItem, error) { + ret := _m.Called(ctx) var r0 []*models.ScrapedItem - if rf, ok := ret.Get(0).(func() []*models.ScrapedItem); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) []*models.ScrapedItem); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.ScrapedItem) @@ -26,8 +28,8 @@ func (_m *ScrapedItemReaderWriter) All() ([]*models.ScrapedItem, error) { } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -35,13 +37,13 @@ func (_m *ScrapedItemReaderWriter) All() ([]*models.ScrapedItem, error) { return r0, r1 } -// Create provides a mock function with given fields: newObject -func (_m *ScrapedItemReaderWriter) Create(newObject models.ScrapedItem) (*models.ScrapedItem, error) { - ret := _m.Called(newObject) +// Create provides a mock function with given fields: ctx, newObject +func (_m *ScrapedItemReaderWriter) Create(ctx context.Context, newObject models.ScrapedItem) (*models.ScrapedItem, error) { + ret := _m.Called(ctx, newObject) var r0 *models.ScrapedItem - if rf, ok := ret.Get(0).(func(models.ScrapedItem) *models.ScrapedItem); ok { - r0 = rf(newObject) + if rf, ok := ret.Get(0).(func(context.Context, models.ScrapedItem) *models.ScrapedItem); ok { + r0 = rf(ctx, newObject) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.ScrapedItem) @@ -49,8 +51,8 @@ func (_m *ScrapedItemReaderWriter) Create(newObject models.ScrapedItem) (*models } var r1 error - if rf, ok := ret.Get(1).(func(models.ScrapedItem) error); ok { - r1 = rf(newObject) + if rf, ok := ret.Get(1).(func(context.Context, models.ScrapedItem) error); ok { + r1 = rf(ctx, newObject) } else { r1 = ret.Error(1) } diff --git a/pkg/models/mocks/StudioReaderWriter.go b/pkg/models/mocks/StudioReaderWriter.go index c15c73719..bc8891983 100644 --- a/pkg/models/mocks/StudioReaderWriter.go +++ b/pkg/models/mocks/StudioReaderWriter.go @@ -3,6 +3,8 @@ package mocks import ( + context "context" + models "github.com/stashapp/stash/pkg/models" mock "github.com/stretchr/testify/mock" ) @@ -12,13 +14,13 @@ type StudioReaderWriter struct { mock.Mock } -// All provides a mock function with given fields: -func (_m *StudioReaderWriter) All() ([]*models.Studio, error) { - ret := _m.Called() +// All provides a mock function with given fields: ctx +func (_m *StudioReaderWriter) All(ctx context.Context) ([]*models.Studio, error) { + ret := _m.Called(ctx) var r0 []*models.Studio - if rf, ok := ret.Get(0).(func() []*models.Studio); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) []*models.Studio); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Studio) @@ -26,8 +28,8 @@ func (_m *StudioReaderWriter) All() ([]*models.Studio, error) { } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -35,20 +37,20 @@ func (_m *StudioReaderWriter) All() ([]*models.Studio, error) { return r0, r1 } -// Count provides a mock function with given fields: -func (_m *StudioReaderWriter) Count() (int, error) { - ret := _m.Called() +// Count provides a mock function with given fields: ctx +func (_m *StudioReaderWriter) Count(ctx context.Context) (int, error) { + ret := _m.Called(ctx) var r0 int - if rf, ok := ret.Get(0).(func() int); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) int); ok { + r0 = rf(ctx) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -56,13 +58,13 @@ func (_m *StudioReaderWriter) Count() (int, error) { return r0, r1 } -// Create provides a mock function with given fields: newStudio -func (_m *StudioReaderWriter) Create(newStudio models.Studio) (*models.Studio, error) { - ret := _m.Called(newStudio) +// Create provides a mock function with given fields: ctx, newStudio +func (_m *StudioReaderWriter) Create(ctx context.Context, newStudio models.Studio) (*models.Studio, error) { + ret := _m.Called(ctx, newStudio) var r0 *models.Studio - if rf, ok := ret.Get(0).(func(models.Studio) *models.Studio); ok { - r0 = rf(newStudio) + if rf, ok := ret.Get(0).(func(context.Context, models.Studio) *models.Studio); ok { + r0 = rf(ctx, newStudio) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Studio) @@ -70,8 +72,8 @@ func (_m *StudioReaderWriter) Create(newStudio models.Studio) (*models.Studio, e } var r1 error - if rf, ok := ret.Get(1).(func(models.Studio) error); ok { - r1 = rf(newStudio) + if rf, ok := ret.Get(1).(func(context.Context, models.Studio) error); ok { + r1 = rf(ctx, newStudio) } else { r1 = ret.Error(1) } @@ -79,13 +81,13 @@ func (_m *StudioReaderWriter) Create(newStudio models.Studio) (*models.Studio, e return r0, r1 } -// Destroy provides a mock function with given fields: id -func (_m *StudioReaderWriter) Destroy(id int) error { - ret := _m.Called(id) +// Destroy provides a mock function with given fields: ctx, id +func (_m *StudioReaderWriter) Destroy(ctx context.Context, id int) error { + ret := _m.Called(ctx, id) var r0 error - if rf, ok := ret.Get(0).(func(int) error); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) error); ok { + r0 = rf(ctx, id) } else { r0 = ret.Error(0) } @@ -93,13 +95,13 @@ func (_m *StudioReaderWriter) Destroy(id int) error { return r0 } -// DestroyImage provides a mock function with given fields: studioID -func (_m *StudioReaderWriter) DestroyImage(studioID int) error { - ret := _m.Called(studioID) +// DestroyImage provides a mock function with given fields: ctx, studioID +func (_m *StudioReaderWriter) DestroyImage(ctx context.Context, studioID int) error { + ret := _m.Called(ctx, studioID) var r0 error - if rf, ok := ret.Get(0).(func(int) error); ok { - r0 = rf(studioID) + if rf, ok := ret.Get(0).(func(context.Context, int) error); ok { + r0 = rf(ctx, studioID) } else { r0 = ret.Error(0) } @@ -107,13 +109,13 @@ func (_m *StudioReaderWriter) DestroyImage(studioID int) error { return r0 } -// Find provides a mock function with given fields: id -func (_m *StudioReaderWriter) Find(id int) (*models.Studio, error) { - ret := _m.Called(id) +// Find provides a mock function with given fields: ctx, id +func (_m *StudioReaderWriter) Find(ctx context.Context, id int) (*models.Studio, error) { + ret := _m.Called(ctx, id) var r0 *models.Studio - if rf, ok := ret.Get(0).(func(int) *models.Studio); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) *models.Studio); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Studio) @@ -121,8 +123,8 @@ func (_m *StudioReaderWriter) Find(id int) (*models.Studio, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -130,13 +132,13 @@ func (_m *StudioReaderWriter) Find(id int) (*models.Studio, error) { return r0, r1 } -// FindByName provides a mock function with given fields: name, nocase -func (_m *StudioReaderWriter) FindByName(name string, nocase bool) (*models.Studio, error) { - ret := _m.Called(name, nocase) +// FindByName provides a mock function with given fields: ctx, name, nocase +func (_m *StudioReaderWriter) FindByName(ctx context.Context, name string, nocase bool) (*models.Studio, error) { + ret := _m.Called(ctx, name, nocase) var r0 *models.Studio - if rf, ok := ret.Get(0).(func(string, bool) *models.Studio); ok { - r0 = rf(name, nocase) + if rf, ok := ret.Get(0).(func(context.Context, string, bool) *models.Studio); ok { + r0 = rf(ctx, name, nocase) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Studio) @@ -144,8 +146,8 @@ func (_m *StudioReaderWriter) FindByName(name string, nocase bool) (*models.Stud } var r1 error - if rf, ok := ret.Get(1).(func(string, bool) error); ok { - r1 = rf(name, nocase) + if rf, ok := ret.Get(1).(func(context.Context, string, bool) error); ok { + r1 = rf(ctx, name, nocase) } else { r1 = ret.Error(1) } @@ -153,13 +155,13 @@ func (_m *StudioReaderWriter) FindByName(name string, nocase bool) (*models.Stud return r0, r1 } -// FindByStashID provides a mock function with given fields: stashID -func (_m *StudioReaderWriter) FindByStashID(stashID models.StashID) ([]*models.Studio, error) { - ret := _m.Called(stashID) +// FindByStashID provides a mock function with given fields: ctx, stashID +func (_m *StudioReaderWriter) FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Studio, error) { + ret := _m.Called(ctx, stashID) var r0 []*models.Studio - if rf, ok := ret.Get(0).(func(models.StashID) []*models.Studio); ok { - r0 = rf(stashID) + if rf, ok := ret.Get(0).(func(context.Context, models.StashID) []*models.Studio); ok { + r0 = rf(ctx, stashID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Studio) @@ -167,8 +169,8 @@ func (_m *StudioReaderWriter) FindByStashID(stashID models.StashID) ([]*models.S } var r1 error - if rf, ok := ret.Get(1).(func(models.StashID) error); ok { - r1 = rf(stashID) + if rf, ok := ret.Get(1).(func(context.Context, models.StashID) error); ok { + r1 = rf(ctx, stashID) } else { r1 = ret.Error(1) } @@ -176,13 +178,13 @@ func (_m *StudioReaderWriter) FindByStashID(stashID models.StashID) ([]*models.S return r0, r1 } -// FindChildren provides a mock function with given fields: id -func (_m *StudioReaderWriter) FindChildren(id int) ([]*models.Studio, error) { - ret := _m.Called(id) +// FindChildren provides a mock function with given fields: ctx, id +func (_m *StudioReaderWriter) FindChildren(ctx context.Context, id int) ([]*models.Studio, error) { + ret := _m.Called(ctx, id) var r0 []*models.Studio - if rf, ok := ret.Get(0).(func(int) []*models.Studio); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Studio); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Studio) @@ -190,8 +192,8 @@ func (_m *StudioReaderWriter) FindChildren(id int) ([]*models.Studio, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -199,13 +201,13 @@ func (_m *StudioReaderWriter) FindChildren(id int) ([]*models.Studio, error) { return r0, r1 } -// FindMany provides a mock function with given fields: ids -func (_m *StudioReaderWriter) FindMany(ids []int) ([]*models.Studio, error) { - ret := _m.Called(ids) +// FindMany provides a mock function with given fields: ctx, ids +func (_m *StudioReaderWriter) FindMany(ctx context.Context, ids []int) ([]*models.Studio, error) { + ret := _m.Called(ctx, ids) var r0 []*models.Studio - if rf, ok := ret.Get(0).(func([]int) []*models.Studio); ok { - r0 = rf(ids) + if rf, ok := ret.Get(0).(func(context.Context, []int) []*models.Studio); ok { + r0 = rf(ctx, ids) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Studio) @@ -213,8 +215,8 @@ func (_m *StudioReaderWriter) FindMany(ids []int) ([]*models.Studio, error) { } var r1 error - if rf, ok := ret.Get(1).(func([]int) error); ok { - r1 = rf(ids) + if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok { + r1 = rf(ctx, ids) } else { r1 = ret.Error(1) } @@ -222,13 +224,13 @@ func (_m *StudioReaderWriter) FindMany(ids []int) ([]*models.Studio, error) { return r0, r1 } -// GetAliases provides a mock function with given fields: studioID -func (_m *StudioReaderWriter) GetAliases(studioID int) ([]string, error) { - ret := _m.Called(studioID) +// GetAliases provides a mock function with given fields: ctx, studioID +func (_m *StudioReaderWriter) GetAliases(ctx context.Context, studioID int) ([]string, error) { + ret := _m.Called(ctx, studioID) var r0 []string - if rf, ok := ret.Get(0).(func(int) []string); ok { - r0 = rf(studioID) + if rf, ok := ret.Get(0).(func(context.Context, int) []string); ok { + r0 = rf(ctx, studioID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]string) @@ -236,8 +238,8 @@ func (_m *StudioReaderWriter) GetAliases(studioID int) ([]string, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(studioID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, studioID) } else { r1 = ret.Error(1) } @@ -245,13 +247,13 @@ func (_m *StudioReaderWriter) GetAliases(studioID int) ([]string, error) { return r0, r1 } -// GetImage provides a mock function with given fields: studioID -func (_m *StudioReaderWriter) GetImage(studioID int) ([]byte, error) { - ret := _m.Called(studioID) +// GetImage provides a mock function with given fields: ctx, studioID +func (_m *StudioReaderWriter) GetImage(ctx context.Context, studioID int) ([]byte, error) { + ret := _m.Called(ctx, studioID) var r0 []byte - if rf, ok := ret.Get(0).(func(int) []byte); ok { - r0 = rf(studioID) + if rf, ok := ret.Get(0).(func(context.Context, int) []byte); ok { + r0 = rf(ctx, studioID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]byte) @@ -259,8 +261,8 @@ func (_m *StudioReaderWriter) GetImage(studioID int) ([]byte, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(studioID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, studioID) } else { r1 = ret.Error(1) } @@ -268,13 +270,13 @@ func (_m *StudioReaderWriter) GetImage(studioID int) ([]byte, error) { return r0, r1 } -// GetStashIDs provides a mock function with given fields: studioID -func (_m *StudioReaderWriter) GetStashIDs(studioID int) ([]*models.StashID, error) { - ret := _m.Called(studioID) +// GetStashIDs provides a mock function with given fields: ctx, studioID +func (_m *StudioReaderWriter) GetStashIDs(ctx context.Context, studioID int) ([]*models.StashID, error) { + ret := _m.Called(ctx, studioID) var r0 []*models.StashID - if rf, ok := ret.Get(0).(func(int) []*models.StashID); ok { - r0 = rf(studioID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.StashID); ok { + r0 = rf(ctx, studioID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.StashID) @@ -282,8 +284,8 @@ func (_m *StudioReaderWriter) GetStashIDs(studioID int) ([]*models.StashID, erro } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(studioID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, studioID) } else { r1 = ret.Error(1) } @@ -291,20 +293,20 @@ func (_m *StudioReaderWriter) GetStashIDs(studioID int) ([]*models.StashID, erro return r0, r1 } -// HasImage provides a mock function with given fields: studioID -func (_m *StudioReaderWriter) HasImage(studioID int) (bool, error) { - ret := _m.Called(studioID) +// HasImage provides a mock function with given fields: ctx, studioID +func (_m *StudioReaderWriter) HasImage(ctx context.Context, studioID int) (bool, error) { + ret := _m.Called(ctx, studioID) var r0 bool - if rf, ok := ret.Get(0).(func(int) bool); ok { - r0 = rf(studioID) + if rf, ok := ret.Get(0).(func(context.Context, int) bool); ok { + r0 = rf(ctx, studioID) } else { r0 = ret.Get(0).(bool) } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(studioID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, studioID) } else { r1 = ret.Error(1) } @@ -312,13 +314,13 @@ func (_m *StudioReaderWriter) HasImage(studioID int) (bool, error) { return r0, r1 } -// Query provides a mock function with given fields: studioFilter, findFilter -func (_m *StudioReaderWriter) Query(studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) { - ret := _m.Called(studioFilter, findFilter) +// Query provides a mock function with given fields: ctx, studioFilter, findFilter +func (_m *StudioReaderWriter) Query(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) { + ret := _m.Called(ctx, studioFilter, findFilter) var r0 []*models.Studio - if rf, ok := ret.Get(0).(func(*models.StudioFilterType, *models.FindFilterType) []*models.Studio); ok { - r0 = rf(studioFilter, findFilter) + if rf, ok := ret.Get(0).(func(context.Context, *models.StudioFilterType, *models.FindFilterType) []*models.Studio); ok { + r0 = rf(ctx, studioFilter, findFilter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Studio) @@ -326,15 +328,15 @@ func (_m *StudioReaderWriter) Query(studioFilter *models.StudioFilterType, findF } var r1 int - if rf, ok := ret.Get(1).(func(*models.StudioFilterType, *models.FindFilterType) int); ok { - r1 = rf(studioFilter, findFilter) + if rf, ok := ret.Get(1).(func(context.Context, *models.StudioFilterType, *models.FindFilterType) int); ok { + r1 = rf(ctx, studioFilter, findFilter) } else { r1 = ret.Get(1).(int) } var r2 error - if rf, ok := ret.Get(2).(func(*models.StudioFilterType, *models.FindFilterType) error); ok { - r2 = rf(studioFilter, findFilter) + if rf, ok := ret.Get(2).(func(context.Context, *models.StudioFilterType, *models.FindFilterType) error); ok { + r2 = rf(ctx, studioFilter, findFilter) } else { r2 = ret.Error(2) } @@ -342,13 +344,13 @@ func (_m *StudioReaderWriter) Query(studioFilter *models.StudioFilterType, findF return r0, r1, r2 } -// QueryForAutoTag provides a mock function with given fields: words -func (_m *StudioReaderWriter) QueryForAutoTag(words []string) ([]*models.Studio, error) { - ret := _m.Called(words) +// QueryForAutoTag provides a mock function with given fields: ctx, words +func (_m *StudioReaderWriter) QueryForAutoTag(ctx context.Context, words []string) ([]*models.Studio, error) { + ret := _m.Called(ctx, words) var r0 []*models.Studio - if rf, ok := ret.Get(0).(func([]string) []*models.Studio); ok { - r0 = rf(words) + if rf, ok := ret.Get(0).(func(context.Context, []string) []*models.Studio); ok { + r0 = rf(ctx, words) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Studio) @@ -356,8 +358,8 @@ func (_m *StudioReaderWriter) QueryForAutoTag(words []string) ([]*models.Studio, } var r1 error - if rf, ok := ret.Get(1).(func([]string) error); ok { - r1 = rf(words) + if rf, ok := ret.Get(1).(func(context.Context, []string) error); ok { + r1 = rf(ctx, words) } else { r1 = ret.Error(1) } @@ -365,13 +367,13 @@ func (_m *StudioReaderWriter) QueryForAutoTag(words []string) ([]*models.Studio, return r0, r1 } -// Update provides a mock function with given fields: updatedStudio -func (_m *StudioReaderWriter) Update(updatedStudio models.StudioPartial) (*models.Studio, error) { - ret := _m.Called(updatedStudio) +// Update provides a mock function with given fields: ctx, updatedStudio +func (_m *StudioReaderWriter) Update(ctx context.Context, updatedStudio models.StudioPartial) (*models.Studio, error) { + ret := _m.Called(ctx, updatedStudio) var r0 *models.Studio - if rf, ok := ret.Get(0).(func(models.StudioPartial) *models.Studio); ok { - r0 = rf(updatedStudio) + if rf, ok := ret.Get(0).(func(context.Context, models.StudioPartial) *models.Studio); ok { + r0 = rf(ctx, updatedStudio) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Studio) @@ -379,8 +381,8 @@ func (_m *StudioReaderWriter) Update(updatedStudio models.StudioPartial) (*model } var r1 error - if rf, ok := ret.Get(1).(func(models.StudioPartial) error); ok { - r1 = rf(updatedStudio) + if rf, ok := ret.Get(1).(func(context.Context, models.StudioPartial) error); ok { + r1 = rf(ctx, updatedStudio) } else { r1 = ret.Error(1) } @@ -388,13 +390,13 @@ func (_m *StudioReaderWriter) Update(updatedStudio models.StudioPartial) (*model return r0, r1 } -// UpdateAliases provides a mock function with given fields: studioID, aliases -func (_m *StudioReaderWriter) UpdateAliases(studioID int, aliases []string) error { - ret := _m.Called(studioID, aliases) +// UpdateAliases provides a mock function with given fields: ctx, studioID, aliases +func (_m *StudioReaderWriter) UpdateAliases(ctx context.Context, studioID int, aliases []string) error { + ret := _m.Called(ctx, studioID, aliases) var r0 error - if rf, ok := ret.Get(0).(func(int, []string) error); ok { - r0 = rf(studioID, aliases) + if rf, ok := ret.Get(0).(func(context.Context, int, []string) error); ok { + r0 = rf(ctx, studioID, aliases) } else { r0 = ret.Error(0) } @@ -402,13 +404,13 @@ func (_m *StudioReaderWriter) UpdateAliases(studioID int, aliases []string) erro return r0 } -// UpdateFull provides a mock function with given fields: updatedStudio -func (_m *StudioReaderWriter) UpdateFull(updatedStudio models.Studio) (*models.Studio, error) { - ret := _m.Called(updatedStudio) +// UpdateFull provides a mock function with given fields: ctx, updatedStudio +func (_m *StudioReaderWriter) UpdateFull(ctx context.Context, updatedStudio models.Studio) (*models.Studio, error) { + ret := _m.Called(ctx, updatedStudio) var r0 *models.Studio - if rf, ok := ret.Get(0).(func(models.Studio) *models.Studio); ok { - r0 = rf(updatedStudio) + if rf, ok := ret.Get(0).(func(context.Context, models.Studio) *models.Studio); ok { + r0 = rf(ctx, updatedStudio) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Studio) @@ -416,8 +418,8 @@ func (_m *StudioReaderWriter) UpdateFull(updatedStudio models.Studio) (*models.S } var r1 error - if rf, ok := ret.Get(1).(func(models.Studio) error); ok { - r1 = rf(updatedStudio) + if rf, ok := ret.Get(1).(func(context.Context, models.Studio) error); ok { + r1 = rf(ctx, updatedStudio) } else { r1 = ret.Error(1) } @@ -425,13 +427,13 @@ func (_m *StudioReaderWriter) UpdateFull(updatedStudio models.Studio) (*models.S return r0, r1 } -// UpdateImage provides a mock function with given fields: studioID, image -func (_m *StudioReaderWriter) UpdateImage(studioID int, image []byte) error { - ret := _m.Called(studioID, image) +// UpdateImage provides a mock function with given fields: ctx, studioID, image +func (_m *StudioReaderWriter) UpdateImage(ctx context.Context, studioID int, image []byte) error { + ret := _m.Called(ctx, studioID, image) var r0 error - if rf, ok := ret.Get(0).(func(int, []byte) error); ok { - r0 = rf(studioID, image) + if rf, ok := ret.Get(0).(func(context.Context, int, []byte) error); ok { + r0 = rf(ctx, studioID, image) } else { r0 = ret.Error(0) } @@ -439,13 +441,13 @@ func (_m *StudioReaderWriter) UpdateImage(studioID int, image []byte) error { return r0 } -// UpdateStashIDs provides a mock function with given fields: studioID, stashIDs -func (_m *StudioReaderWriter) UpdateStashIDs(studioID int, stashIDs []models.StashID) error { - ret := _m.Called(studioID, stashIDs) +// UpdateStashIDs provides a mock function with given fields: ctx, studioID, stashIDs +func (_m *StudioReaderWriter) UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error { + ret := _m.Called(ctx, studioID, stashIDs) var r0 error - if rf, ok := ret.Get(0).(func(int, []models.StashID) error); ok { - r0 = rf(studioID, stashIDs) + if rf, ok := ret.Get(0).(func(context.Context, int, []models.StashID) error); ok { + r0 = rf(ctx, studioID, stashIDs) } else { r0 = ret.Error(0) } diff --git a/pkg/models/mocks/TagReaderWriter.go b/pkg/models/mocks/TagReaderWriter.go index 64a8088a6..1a53adf05 100644 --- a/pkg/models/mocks/TagReaderWriter.go +++ b/pkg/models/mocks/TagReaderWriter.go @@ -3,6 +3,8 @@ package mocks import ( + context "context" + models "github.com/stashapp/stash/pkg/models" mock "github.com/stretchr/testify/mock" ) @@ -12,13 +14,13 @@ type TagReaderWriter struct { mock.Mock } -// All provides a mock function with given fields: -func (_m *TagReaderWriter) All() ([]*models.Tag, error) { - ret := _m.Called() +// All provides a mock function with given fields: ctx +func (_m *TagReaderWriter) All(ctx context.Context) ([]*models.Tag, error) { + ret := _m.Called(ctx) var r0 []*models.Tag - if rf, ok := ret.Get(0).(func() []*models.Tag); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) []*models.Tag); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Tag) @@ -26,8 +28,8 @@ func (_m *TagReaderWriter) All() ([]*models.Tag, error) { } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -35,20 +37,20 @@ func (_m *TagReaderWriter) All() ([]*models.Tag, error) { return r0, r1 } -// Count provides a mock function with given fields: -func (_m *TagReaderWriter) Count() (int, error) { - ret := _m.Called() +// Count provides a mock function with given fields: ctx +func (_m *TagReaderWriter) Count(ctx context.Context) (int, error) { + ret := _m.Called(ctx) var r0 int - if rf, ok := ret.Get(0).(func() int); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) int); ok { + r0 = rf(ctx) } else { r0 = ret.Get(0).(int) } var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -56,13 +58,13 @@ func (_m *TagReaderWriter) Count() (int, error) { return r0, r1 } -// Create provides a mock function with given fields: newTag -func (_m *TagReaderWriter) Create(newTag models.Tag) (*models.Tag, error) { - ret := _m.Called(newTag) +// Create provides a mock function with given fields: ctx, newTag +func (_m *TagReaderWriter) Create(ctx context.Context, newTag models.Tag) (*models.Tag, error) { + ret := _m.Called(ctx, newTag) var r0 *models.Tag - if rf, ok := ret.Get(0).(func(models.Tag) *models.Tag); ok { - r0 = rf(newTag) + if rf, ok := ret.Get(0).(func(context.Context, models.Tag) *models.Tag); ok { + r0 = rf(ctx, newTag) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Tag) @@ -70,8 +72,8 @@ func (_m *TagReaderWriter) Create(newTag models.Tag) (*models.Tag, error) { } var r1 error - if rf, ok := ret.Get(1).(func(models.Tag) error); ok { - r1 = rf(newTag) + if rf, ok := ret.Get(1).(func(context.Context, models.Tag) error); ok { + r1 = rf(ctx, newTag) } else { r1 = ret.Error(1) } @@ -79,13 +81,13 @@ func (_m *TagReaderWriter) Create(newTag models.Tag) (*models.Tag, error) { return r0, r1 } -// Destroy provides a mock function with given fields: id -func (_m *TagReaderWriter) Destroy(id int) error { - ret := _m.Called(id) +// Destroy provides a mock function with given fields: ctx, id +func (_m *TagReaderWriter) Destroy(ctx context.Context, id int) error { + ret := _m.Called(ctx, id) var r0 error - if rf, ok := ret.Get(0).(func(int) error); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) error); ok { + r0 = rf(ctx, id) } else { r0 = ret.Error(0) } @@ -93,13 +95,13 @@ func (_m *TagReaderWriter) Destroy(id int) error { return r0 } -// DestroyImage provides a mock function with given fields: tagID -func (_m *TagReaderWriter) DestroyImage(tagID int) error { - ret := _m.Called(tagID) +// DestroyImage provides a mock function with given fields: ctx, tagID +func (_m *TagReaderWriter) DestroyImage(ctx context.Context, tagID int) error { + ret := _m.Called(ctx, tagID) var r0 error - if rf, ok := ret.Get(0).(func(int) error); ok { - r0 = rf(tagID) + if rf, ok := ret.Get(0).(func(context.Context, int) error); ok { + r0 = rf(ctx, tagID) } else { r0 = ret.Error(0) } @@ -107,13 +109,13 @@ func (_m *TagReaderWriter) DestroyImage(tagID int) error { return r0 } -// Find provides a mock function with given fields: id -func (_m *TagReaderWriter) Find(id int) (*models.Tag, error) { - ret := _m.Called(id) +// Find provides a mock function with given fields: ctx, id +func (_m *TagReaderWriter) Find(ctx context.Context, id int) (*models.Tag, error) { + ret := _m.Called(ctx, id) var r0 *models.Tag - if rf, ok := ret.Get(0).(func(int) *models.Tag); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int) *models.Tag); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Tag) @@ -121,8 +123,8 @@ func (_m *TagReaderWriter) Find(id int) (*models.Tag, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -130,13 +132,13 @@ func (_m *TagReaderWriter) Find(id int) (*models.Tag, error) { return r0, r1 } -// FindAllAncestors provides a mock function with given fields: tagID, excludeIDs -func (_m *TagReaderWriter) FindAllAncestors(tagID int, excludeIDs []int) ([]*models.TagPath, error) { - ret := _m.Called(tagID, excludeIDs) +// FindAllAncestors provides a mock function with given fields: ctx, tagID, excludeIDs +func (_m *TagReaderWriter) FindAllAncestors(ctx context.Context, tagID int, excludeIDs []int) ([]*models.TagPath, error) { + ret := _m.Called(ctx, tagID, excludeIDs) var r0 []*models.TagPath - if rf, ok := ret.Get(0).(func(int, []int) []*models.TagPath); ok { - r0 = rf(tagID, excludeIDs) + if rf, ok := ret.Get(0).(func(context.Context, int, []int) []*models.TagPath); ok { + r0 = rf(ctx, tagID, excludeIDs) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.TagPath) @@ -144,8 +146,8 @@ func (_m *TagReaderWriter) FindAllAncestors(tagID int, excludeIDs []int) ([]*mod } var r1 error - if rf, ok := ret.Get(1).(func(int, []int) error); ok { - r1 = rf(tagID, excludeIDs) + if rf, ok := ret.Get(1).(func(context.Context, int, []int) error); ok { + r1 = rf(ctx, tagID, excludeIDs) } else { r1 = ret.Error(1) } @@ -153,13 +155,13 @@ func (_m *TagReaderWriter) FindAllAncestors(tagID int, excludeIDs []int) ([]*mod return r0, r1 } -// FindAllDescendants provides a mock function with given fields: tagID, excludeIDs -func (_m *TagReaderWriter) FindAllDescendants(tagID int, excludeIDs []int) ([]*models.TagPath, error) { - ret := _m.Called(tagID, excludeIDs) +// FindAllDescendants provides a mock function with given fields: ctx, tagID, excludeIDs +func (_m *TagReaderWriter) FindAllDescendants(ctx context.Context, tagID int, excludeIDs []int) ([]*models.TagPath, error) { + ret := _m.Called(ctx, tagID, excludeIDs) var r0 []*models.TagPath - if rf, ok := ret.Get(0).(func(int, []int) []*models.TagPath); ok { - r0 = rf(tagID, excludeIDs) + if rf, ok := ret.Get(0).(func(context.Context, int, []int) []*models.TagPath); ok { + r0 = rf(ctx, tagID, excludeIDs) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.TagPath) @@ -167,8 +169,8 @@ func (_m *TagReaderWriter) FindAllDescendants(tagID int, excludeIDs []int) ([]*m } var r1 error - if rf, ok := ret.Get(1).(func(int, []int) error); ok { - r1 = rf(tagID, excludeIDs) + if rf, ok := ret.Get(1).(func(context.Context, int, []int) error); ok { + r1 = rf(ctx, tagID, excludeIDs) } else { r1 = ret.Error(1) } @@ -176,13 +178,13 @@ func (_m *TagReaderWriter) FindAllDescendants(tagID int, excludeIDs []int) ([]*m return r0, r1 } -// FindByChildTagID provides a mock function with given fields: childID -func (_m *TagReaderWriter) FindByChildTagID(childID int) ([]*models.Tag, error) { - ret := _m.Called(childID) +// FindByChildTagID provides a mock function with given fields: ctx, childID +func (_m *TagReaderWriter) FindByChildTagID(ctx context.Context, childID int) ([]*models.Tag, error) { + ret := _m.Called(ctx, childID) var r0 []*models.Tag - if rf, ok := ret.Get(0).(func(int) []*models.Tag); ok { - r0 = rf(childID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Tag); ok { + r0 = rf(ctx, childID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Tag) @@ -190,8 +192,8 @@ func (_m *TagReaderWriter) FindByChildTagID(childID int) ([]*models.Tag, error) } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(childID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, childID) } else { r1 = ret.Error(1) } @@ -199,13 +201,13 @@ func (_m *TagReaderWriter) FindByChildTagID(childID int) ([]*models.Tag, error) return r0, r1 } -// FindByGalleryID provides a mock function with given fields: galleryID -func (_m *TagReaderWriter) FindByGalleryID(galleryID int) ([]*models.Tag, error) { - ret := _m.Called(galleryID) +// FindByGalleryID provides a mock function with given fields: ctx, galleryID +func (_m *TagReaderWriter) FindByGalleryID(ctx context.Context, galleryID int) ([]*models.Tag, error) { + ret := _m.Called(ctx, galleryID) var r0 []*models.Tag - if rf, ok := ret.Get(0).(func(int) []*models.Tag); ok { - r0 = rf(galleryID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Tag); ok { + r0 = rf(ctx, galleryID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Tag) @@ -213,8 +215,8 @@ func (_m *TagReaderWriter) FindByGalleryID(galleryID int) ([]*models.Tag, error) } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(galleryID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, galleryID) } else { r1 = ret.Error(1) } @@ -222,13 +224,13 @@ func (_m *TagReaderWriter) FindByGalleryID(galleryID int) ([]*models.Tag, error) return r0, r1 } -// FindByImageID provides a mock function with given fields: imageID -func (_m *TagReaderWriter) FindByImageID(imageID int) ([]*models.Tag, error) { - ret := _m.Called(imageID) +// FindByImageID provides a mock function with given fields: ctx, imageID +func (_m *TagReaderWriter) FindByImageID(ctx context.Context, imageID int) ([]*models.Tag, error) { + ret := _m.Called(ctx, imageID) var r0 []*models.Tag - if rf, ok := ret.Get(0).(func(int) []*models.Tag); ok { - r0 = rf(imageID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Tag); ok { + r0 = rf(ctx, imageID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Tag) @@ -236,8 +238,8 @@ func (_m *TagReaderWriter) FindByImageID(imageID int) ([]*models.Tag, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(imageID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, imageID) } else { r1 = ret.Error(1) } @@ -245,13 +247,13 @@ func (_m *TagReaderWriter) FindByImageID(imageID int) ([]*models.Tag, error) { return r0, r1 } -// FindByName provides a mock function with given fields: name, nocase -func (_m *TagReaderWriter) FindByName(name string, nocase bool) (*models.Tag, error) { - ret := _m.Called(name, nocase) +// FindByName provides a mock function with given fields: ctx, name, nocase +func (_m *TagReaderWriter) FindByName(ctx context.Context, name string, nocase bool) (*models.Tag, error) { + ret := _m.Called(ctx, name, nocase) var r0 *models.Tag - if rf, ok := ret.Get(0).(func(string, bool) *models.Tag); ok { - r0 = rf(name, nocase) + if rf, ok := ret.Get(0).(func(context.Context, string, bool) *models.Tag); ok { + r0 = rf(ctx, name, nocase) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Tag) @@ -259,8 +261,8 @@ func (_m *TagReaderWriter) FindByName(name string, nocase bool) (*models.Tag, er } var r1 error - if rf, ok := ret.Get(1).(func(string, bool) error); ok { - r1 = rf(name, nocase) + if rf, ok := ret.Get(1).(func(context.Context, string, bool) error); ok { + r1 = rf(ctx, name, nocase) } else { r1 = ret.Error(1) } @@ -268,13 +270,13 @@ func (_m *TagReaderWriter) FindByName(name string, nocase bool) (*models.Tag, er return r0, r1 } -// FindByNames provides a mock function with given fields: names, nocase -func (_m *TagReaderWriter) FindByNames(names []string, nocase bool) ([]*models.Tag, error) { - ret := _m.Called(names, nocase) +// FindByNames provides a mock function with given fields: ctx, names, nocase +func (_m *TagReaderWriter) FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Tag, error) { + ret := _m.Called(ctx, names, nocase) var r0 []*models.Tag - if rf, ok := ret.Get(0).(func([]string, bool) []*models.Tag); ok { - r0 = rf(names, nocase) + if rf, ok := ret.Get(0).(func(context.Context, []string, bool) []*models.Tag); ok { + r0 = rf(ctx, names, nocase) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Tag) @@ -282,8 +284,8 @@ func (_m *TagReaderWriter) FindByNames(names []string, nocase bool) ([]*models.T } var r1 error - if rf, ok := ret.Get(1).(func([]string, bool) error); ok { - r1 = rf(names, nocase) + if rf, ok := ret.Get(1).(func(context.Context, []string, bool) error); ok { + r1 = rf(ctx, names, nocase) } else { r1 = ret.Error(1) } @@ -291,13 +293,13 @@ func (_m *TagReaderWriter) FindByNames(names []string, nocase bool) ([]*models.T return r0, r1 } -// FindByParentTagID provides a mock function with given fields: parentID -func (_m *TagReaderWriter) FindByParentTagID(parentID int) ([]*models.Tag, error) { - ret := _m.Called(parentID) +// FindByParentTagID provides a mock function with given fields: ctx, parentID +func (_m *TagReaderWriter) FindByParentTagID(ctx context.Context, parentID int) ([]*models.Tag, error) { + ret := _m.Called(ctx, parentID) var r0 []*models.Tag - if rf, ok := ret.Get(0).(func(int) []*models.Tag); ok { - r0 = rf(parentID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Tag); ok { + r0 = rf(ctx, parentID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Tag) @@ -305,8 +307,8 @@ func (_m *TagReaderWriter) FindByParentTagID(parentID int) ([]*models.Tag, error } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(parentID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, parentID) } else { r1 = ret.Error(1) } @@ -314,13 +316,13 @@ func (_m *TagReaderWriter) FindByParentTagID(parentID int) ([]*models.Tag, error return r0, r1 } -// FindByPerformerID provides a mock function with given fields: performerID -func (_m *TagReaderWriter) FindByPerformerID(performerID int) ([]*models.Tag, error) { - ret := _m.Called(performerID) +// FindByPerformerID provides a mock function with given fields: ctx, performerID +func (_m *TagReaderWriter) FindByPerformerID(ctx context.Context, performerID int) ([]*models.Tag, error) { + ret := _m.Called(ctx, performerID) var r0 []*models.Tag - if rf, ok := ret.Get(0).(func(int) []*models.Tag); ok { - r0 = rf(performerID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Tag); ok { + r0 = rf(ctx, performerID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Tag) @@ -328,8 +330,8 @@ func (_m *TagReaderWriter) FindByPerformerID(performerID int) ([]*models.Tag, er } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(performerID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, performerID) } else { r1 = ret.Error(1) } @@ -337,13 +339,13 @@ func (_m *TagReaderWriter) FindByPerformerID(performerID int) ([]*models.Tag, er return r0, r1 } -// FindBySceneID provides a mock function with given fields: sceneID -func (_m *TagReaderWriter) FindBySceneID(sceneID int) ([]*models.Tag, error) { - ret := _m.Called(sceneID) +// FindBySceneID provides a mock function with given fields: ctx, sceneID +func (_m *TagReaderWriter) FindBySceneID(ctx context.Context, sceneID int) ([]*models.Tag, error) { + ret := _m.Called(ctx, sceneID) var r0 []*models.Tag - if rf, ok := ret.Get(0).(func(int) []*models.Tag); ok { - r0 = rf(sceneID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Tag); ok { + r0 = rf(ctx, sceneID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Tag) @@ -351,8 +353,8 @@ func (_m *TagReaderWriter) FindBySceneID(sceneID int) ([]*models.Tag, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(sceneID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, sceneID) } else { r1 = ret.Error(1) } @@ -360,13 +362,13 @@ func (_m *TagReaderWriter) FindBySceneID(sceneID int) ([]*models.Tag, error) { return r0, r1 } -// FindBySceneMarkerID provides a mock function with given fields: sceneMarkerID -func (_m *TagReaderWriter) FindBySceneMarkerID(sceneMarkerID int) ([]*models.Tag, error) { - ret := _m.Called(sceneMarkerID) +// FindBySceneMarkerID provides a mock function with given fields: ctx, sceneMarkerID +func (_m *TagReaderWriter) FindBySceneMarkerID(ctx context.Context, sceneMarkerID int) ([]*models.Tag, error) { + ret := _m.Called(ctx, sceneMarkerID) var r0 []*models.Tag - if rf, ok := ret.Get(0).(func(int) []*models.Tag); ok { - r0 = rf(sceneMarkerID) + if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Tag); ok { + r0 = rf(ctx, sceneMarkerID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Tag) @@ -374,8 +376,8 @@ func (_m *TagReaderWriter) FindBySceneMarkerID(sceneMarkerID int) ([]*models.Tag } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(sceneMarkerID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, sceneMarkerID) } else { r1 = ret.Error(1) } @@ -383,13 +385,13 @@ func (_m *TagReaderWriter) FindBySceneMarkerID(sceneMarkerID int) ([]*models.Tag return r0, r1 } -// FindMany provides a mock function with given fields: ids -func (_m *TagReaderWriter) FindMany(ids []int) ([]*models.Tag, error) { - ret := _m.Called(ids) +// FindMany provides a mock function with given fields: ctx, ids +func (_m *TagReaderWriter) FindMany(ctx context.Context, ids []int) ([]*models.Tag, error) { + ret := _m.Called(ctx, ids) var r0 []*models.Tag - if rf, ok := ret.Get(0).(func([]int) []*models.Tag); ok { - r0 = rf(ids) + if rf, ok := ret.Get(0).(func(context.Context, []int) []*models.Tag); ok { + r0 = rf(ctx, ids) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Tag) @@ -397,8 +399,8 @@ func (_m *TagReaderWriter) FindMany(ids []int) ([]*models.Tag, error) { } var r1 error - if rf, ok := ret.Get(1).(func([]int) error); ok { - r1 = rf(ids) + if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok { + r1 = rf(ctx, ids) } else { r1 = ret.Error(1) } @@ -406,13 +408,13 @@ func (_m *TagReaderWriter) FindMany(ids []int) ([]*models.Tag, error) { return r0, r1 } -// GetAliases provides a mock function with given fields: tagID -func (_m *TagReaderWriter) GetAliases(tagID int) ([]string, error) { - ret := _m.Called(tagID) +// GetAliases provides a mock function with given fields: ctx, tagID +func (_m *TagReaderWriter) GetAliases(ctx context.Context, tagID int) ([]string, error) { + ret := _m.Called(ctx, tagID) var r0 []string - if rf, ok := ret.Get(0).(func(int) []string); ok { - r0 = rf(tagID) + if rf, ok := ret.Get(0).(func(context.Context, int) []string); ok { + r0 = rf(ctx, tagID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]string) @@ -420,8 +422,8 @@ func (_m *TagReaderWriter) GetAliases(tagID int) ([]string, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(tagID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, tagID) } else { r1 = ret.Error(1) } @@ -429,13 +431,13 @@ func (_m *TagReaderWriter) GetAliases(tagID int) ([]string, error) { return r0, r1 } -// GetImage provides a mock function with given fields: tagID -func (_m *TagReaderWriter) GetImage(tagID int) ([]byte, error) { - ret := _m.Called(tagID) +// GetImage provides a mock function with given fields: ctx, tagID +func (_m *TagReaderWriter) GetImage(ctx context.Context, tagID int) ([]byte, error) { + ret := _m.Called(ctx, tagID) var r0 []byte - if rf, ok := ret.Get(0).(func(int) []byte); ok { - r0 = rf(tagID) + if rf, ok := ret.Get(0).(func(context.Context, int) []byte); ok { + r0 = rf(ctx, tagID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]byte) @@ -443,8 +445,8 @@ func (_m *TagReaderWriter) GetImage(tagID int) ([]byte, error) { } var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(tagID) + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, tagID) } else { r1 = ret.Error(1) } @@ -452,13 +454,13 @@ func (_m *TagReaderWriter) GetImage(tagID int) ([]byte, error) { return r0, r1 } -// Merge provides a mock function with given fields: source, destination -func (_m *TagReaderWriter) Merge(source []int, destination int) error { - ret := _m.Called(source, destination) +// Merge provides a mock function with given fields: ctx, source, destination +func (_m *TagReaderWriter) Merge(ctx context.Context, source []int, destination int) error { + ret := _m.Called(ctx, source, destination) var r0 error - if rf, ok := ret.Get(0).(func([]int, int) error); ok { - r0 = rf(source, destination) + if rf, ok := ret.Get(0).(func(context.Context, []int, int) error); ok { + r0 = rf(ctx, source, destination) } else { r0 = ret.Error(0) } @@ -466,13 +468,13 @@ func (_m *TagReaderWriter) Merge(source []int, destination int) error { return r0 } -// Query provides a mock function with given fields: tagFilter, findFilter -func (_m *TagReaderWriter) Query(tagFilter *models.TagFilterType, findFilter *models.FindFilterType) ([]*models.Tag, int, error) { - ret := _m.Called(tagFilter, findFilter) +// Query provides a mock function with given fields: ctx, tagFilter, findFilter +func (_m *TagReaderWriter) Query(ctx context.Context, tagFilter *models.TagFilterType, findFilter *models.FindFilterType) ([]*models.Tag, int, error) { + ret := _m.Called(ctx, tagFilter, findFilter) var r0 []*models.Tag - if rf, ok := ret.Get(0).(func(*models.TagFilterType, *models.FindFilterType) []*models.Tag); ok { - r0 = rf(tagFilter, findFilter) + if rf, ok := ret.Get(0).(func(context.Context, *models.TagFilterType, *models.FindFilterType) []*models.Tag); ok { + r0 = rf(ctx, tagFilter, findFilter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Tag) @@ -480,15 +482,15 @@ func (_m *TagReaderWriter) Query(tagFilter *models.TagFilterType, findFilter *mo } var r1 int - if rf, ok := ret.Get(1).(func(*models.TagFilterType, *models.FindFilterType) int); ok { - r1 = rf(tagFilter, findFilter) + if rf, ok := ret.Get(1).(func(context.Context, *models.TagFilterType, *models.FindFilterType) int); ok { + r1 = rf(ctx, tagFilter, findFilter) } else { r1 = ret.Get(1).(int) } var r2 error - if rf, ok := ret.Get(2).(func(*models.TagFilterType, *models.FindFilterType) error); ok { - r2 = rf(tagFilter, findFilter) + if rf, ok := ret.Get(2).(func(context.Context, *models.TagFilterType, *models.FindFilterType) error); ok { + r2 = rf(ctx, tagFilter, findFilter) } else { r2 = ret.Error(2) } @@ -496,13 +498,13 @@ func (_m *TagReaderWriter) Query(tagFilter *models.TagFilterType, findFilter *mo return r0, r1, r2 } -// QueryForAutoTag provides a mock function with given fields: words -func (_m *TagReaderWriter) QueryForAutoTag(words []string) ([]*models.Tag, error) { - ret := _m.Called(words) +// QueryForAutoTag provides a mock function with given fields: ctx, words +func (_m *TagReaderWriter) QueryForAutoTag(ctx context.Context, words []string) ([]*models.Tag, error) { + ret := _m.Called(ctx, words) var r0 []*models.Tag - if rf, ok := ret.Get(0).(func([]string) []*models.Tag); ok { - r0 = rf(words) + if rf, ok := ret.Get(0).(func(context.Context, []string) []*models.Tag); ok { + r0 = rf(ctx, words) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*models.Tag) @@ -510,8 +512,8 @@ func (_m *TagReaderWriter) QueryForAutoTag(words []string) ([]*models.Tag, error } var r1 error - if rf, ok := ret.Get(1).(func([]string) error); ok { - r1 = rf(words) + if rf, ok := ret.Get(1).(func(context.Context, []string) error); ok { + r1 = rf(ctx, words) } else { r1 = ret.Error(1) } @@ -519,13 +521,13 @@ func (_m *TagReaderWriter) QueryForAutoTag(words []string) ([]*models.Tag, error return r0, r1 } -// Update provides a mock function with given fields: updateTag -func (_m *TagReaderWriter) Update(updateTag models.TagPartial) (*models.Tag, error) { - ret := _m.Called(updateTag) +// Update provides a mock function with given fields: ctx, updateTag +func (_m *TagReaderWriter) Update(ctx context.Context, updateTag models.TagPartial) (*models.Tag, error) { + ret := _m.Called(ctx, updateTag) var r0 *models.Tag - if rf, ok := ret.Get(0).(func(models.TagPartial) *models.Tag); ok { - r0 = rf(updateTag) + if rf, ok := ret.Get(0).(func(context.Context, models.TagPartial) *models.Tag); ok { + r0 = rf(ctx, updateTag) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Tag) @@ -533,8 +535,8 @@ func (_m *TagReaderWriter) Update(updateTag models.TagPartial) (*models.Tag, err } var r1 error - if rf, ok := ret.Get(1).(func(models.TagPartial) error); ok { - r1 = rf(updateTag) + if rf, ok := ret.Get(1).(func(context.Context, models.TagPartial) error); ok { + r1 = rf(ctx, updateTag) } else { r1 = ret.Error(1) } @@ -542,13 +544,13 @@ func (_m *TagReaderWriter) Update(updateTag models.TagPartial) (*models.Tag, err return r0, r1 } -// UpdateAliases provides a mock function with given fields: tagID, aliases -func (_m *TagReaderWriter) UpdateAliases(tagID int, aliases []string) error { - ret := _m.Called(tagID, aliases) +// UpdateAliases provides a mock function with given fields: ctx, tagID, aliases +func (_m *TagReaderWriter) UpdateAliases(ctx context.Context, tagID int, aliases []string) error { + ret := _m.Called(ctx, tagID, aliases) var r0 error - if rf, ok := ret.Get(0).(func(int, []string) error); ok { - r0 = rf(tagID, aliases) + if rf, ok := ret.Get(0).(func(context.Context, int, []string) error); ok { + r0 = rf(ctx, tagID, aliases) } else { r0 = ret.Error(0) } @@ -556,13 +558,13 @@ func (_m *TagReaderWriter) UpdateAliases(tagID int, aliases []string) error { return r0 } -// UpdateChildTags provides a mock function with given fields: tagID, parentIDs -func (_m *TagReaderWriter) UpdateChildTags(tagID int, parentIDs []int) error { - ret := _m.Called(tagID, parentIDs) +// UpdateChildTags provides a mock function with given fields: ctx, tagID, parentIDs +func (_m *TagReaderWriter) UpdateChildTags(ctx context.Context, tagID int, parentIDs []int) error { + ret := _m.Called(ctx, tagID, parentIDs) var r0 error - if rf, ok := ret.Get(0).(func(int, []int) error); ok { - r0 = rf(tagID, parentIDs) + if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok { + r0 = rf(ctx, tagID, parentIDs) } else { r0 = ret.Error(0) } @@ -570,13 +572,13 @@ func (_m *TagReaderWriter) UpdateChildTags(tagID int, parentIDs []int) error { return r0 } -// UpdateFull provides a mock function with given fields: updatedTag -func (_m *TagReaderWriter) UpdateFull(updatedTag models.Tag) (*models.Tag, error) { - ret := _m.Called(updatedTag) +// UpdateFull provides a mock function with given fields: ctx, updatedTag +func (_m *TagReaderWriter) UpdateFull(ctx context.Context, updatedTag models.Tag) (*models.Tag, error) { + ret := _m.Called(ctx, updatedTag) var r0 *models.Tag - if rf, ok := ret.Get(0).(func(models.Tag) *models.Tag); ok { - r0 = rf(updatedTag) + if rf, ok := ret.Get(0).(func(context.Context, models.Tag) *models.Tag); ok { + r0 = rf(ctx, updatedTag) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*models.Tag) @@ -584,8 +586,8 @@ func (_m *TagReaderWriter) UpdateFull(updatedTag models.Tag) (*models.Tag, error } var r1 error - if rf, ok := ret.Get(1).(func(models.Tag) error); ok { - r1 = rf(updatedTag) + if rf, ok := ret.Get(1).(func(context.Context, models.Tag) error); ok { + r1 = rf(ctx, updatedTag) } else { r1 = ret.Error(1) } @@ -593,13 +595,13 @@ func (_m *TagReaderWriter) UpdateFull(updatedTag models.Tag) (*models.Tag, error return r0, r1 } -// UpdateImage provides a mock function with given fields: tagID, image -func (_m *TagReaderWriter) UpdateImage(tagID int, image []byte) error { - ret := _m.Called(tagID, image) +// UpdateImage provides a mock function with given fields: ctx, tagID, image +func (_m *TagReaderWriter) UpdateImage(ctx context.Context, tagID int, image []byte) error { + ret := _m.Called(ctx, tagID, image) var r0 error - if rf, ok := ret.Get(0).(func(int, []byte) error); ok { - r0 = rf(tagID, image) + if rf, ok := ret.Get(0).(func(context.Context, int, []byte) error); ok { + r0 = rf(ctx, tagID, image) } else { r0 = ret.Error(0) } @@ -607,13 +609,13 @@ func (_m *TagReaderWriter) UpdateImage(tagID int, image []byte) error { return r0 } -// UpdateParentTags provides a mock function with given fields: tagID, parentIDs -func (_m *TagReaderWriter) UpdateParentTags(tagID int, parentIDs []int) error { - ret := _m.Called(tagID, parentIDs) +// UpdateParentTags provides a mock function with given fields: ctx, tagID, parentIDs +func (_m *TagReaderWriter) UpdateParentTags(ctx context.Context, tagID int, parentIDs []int) error { + ret := _m.Called(ctx, tagID, parentIDs) var r0 error - if rf, ok := ret.Get(0).(func(int, []int) error); ok { - r0 = rf(tagID, parentIDs) + if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok { + r0 = rf(ctx, tagID, parentIDs) } else { r0 = ret.Error(0) } diff --git a/pkg/models/mocks/query.go b/pkg/models/mocks/query.go index 152335fc2..346bd1e55 100644 --- a/pkg/models/mocks/query.go +++ b/pkg/models/mocks/query.go @@ -1,16 +1,20 @@ package mocks -import "github.com/stashapp/stash/pkg/models" +import ( + context "context" + + "github.com/stashapp/stash/pkg/models" +) type sceneResolver struct { scenes []*models.Scene } -func (s *sceneResolver) Find(id int) (*models.Scene, error) { +func (s *sceneResolver) Find(ctx context.Context, id int) (*models.Scene, error) { panic("not implemented") } -func (s *sceneResolver) FindMany(ids []int) ([]*models.Scene, error) { +func (s *sceneResolver) FindMany(ctx context.Context, ids []int) ([]*models.Scene, error) { return s.scenes, nil } @@ -27,7 +31,7 @@ type imageResolver struct { images []*models.Image } -func (s *imageResolver) FindMany(ids []int) ([]*models.Image, error) { +func (s *imageResolver) FindMany(ctx context.Context, ids []int) ([]*models.Image, error) { return s.images, nil } diff --git a/pkg/models/mocks/transaction.go b/pkg/models/mocks/transaction.go index 886fef7d6..ab5c7dba3 100644 --- a/pkg/models/mocks/transaction.go +++ b/pkg/models/mocks/transaction.go @@ -1,167 +1,41 @@ package mocks import ( - "context" + context "context" - models "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/models" ) -type TransactionManager struct { - gallery *GalleryReaderWriter - image *ImageReaderWriter - movie *MovieReaderWriter - performer *PerformerReaderWriter - scene *SceneReaderWriter - sceneMarker *SceneMarkerReaderWriter - scrapedItem *ScrapedItemReaderWriter - studio *StudioReaderWriter - tag *TagReaderWriter - savedFilter *SavedFilterReaderWriter +type TxnManager struct{} + +func (*TxnManager) Begin(ctx context.Context) (context.Context, error) { + return ctx, nil } -func NewTransactionManager() *TransactionManager { - return &TransactionManager{ - gallery: &GalleryReaderWriter{}, - image: &ImageReaderWriter{}, - movie: &MovieReaderWriter{}, - performer: &PerformerReaderWriter{}, - scene: &SceneReaderWriter{}, - sceneMarker: &SceneMarkerReaderWriter{}, - scrapedItem: &ScrapedItemReaderWriter{}, - studio: &StudioReaderWriter{}, - tag: &TagReaderWriter{}, - savedFilter: &SavedFilterReaderWriter{}, +func (*TxnManager) Commit(ctx context.Context) error { + return nil +} + +func (*TxnManager) Rollback(ctx context.Context) error { + return nil +} + +func (*TxnManager) Reset() error { + return nil +} + +func NewTxnRepository() models.Repository { + return models.Repository{ + TxnManager: &TxnManager{}, + Gallery: &GalleryReaderWriter{}, + Image: &ImageReaderWriter{}, + Movie: &MovieReaderWriter{}, + Performer: &PerformerReaderWriter{}, + Scene: &SceneReaderWriter{}, + SceneMarker: &SceneMarkerReaderWriter{}, + ScrapedItem: &ScrapedItemReaderWriter{}, + Studio: &StudioReaderWriter{}, + Tag: &TagReaderWriter{}, + SavedFilter: &SavedFilterReaderWriter{}, } } - -func (t *TransactionManager) WithTxn(ctx context.Context, fn func(r models.Repository) error) error { - return fn(t) -} - -func (t *TransactionManager) GalleryMock() *GalleryReaderWriter { - return t.gallery -} - -func (t *TransactionManager) ImageMock() *ImageReaderWriter { - return t.image -} - -func (t *TransactionManager) MovieMock() *MovieReaderWriter { - return t.movie -} - -func (t *TransactionManager) PerformerMock() *PerformerReaderWriter { - return t.performer -} - -func (t *TransactionManager) SceneMarkerMock() *SceneMarkerReaderWriter { - return t.sceneMarker -} - -func (t *TransactionManager) SceneMock() *SceneReaderWriter { - return t.scene -} - -func (t *TransactionManager) ScrapedItemMock() *ScrapedItemReaderWriter { - return t.scrapedItem -} - -func (t *TransactionManager) StudioMock() *StudioReaderWriter { - return t.studio -} - -func (t *TransactionManager) TagMock() *TagReaderWriter { - return t.tag -} - -func (t *TransactionManager) SavedFilterMock() *SavedFilterReaderWriter { - return t.savedFilter -} - -func (t *TransactionManager) Gallery() models.GalleryReaderWriter { - return t.GalleryMock() -} - -func (t *TransactionManager) Image() models.ImageReaderWriter { - return t.ImageMock() -} - -func (t *TransactionManager) Movie() models.MovieReaderWriter { - return t.MovieMock() -} - -func (t *TransactionManager) Performer() models.PerformerReaderWriter { - return t.PerformerMock() -} - -func (t *TransactionManager) SceneMarker() models.SceneMarkerReaderWriter { - return t.SceneMarkerMock() -} - -func (t *TransactionManager) Scene() models.SceneReaderWriter { - return t.SceneMock() -} - -func (t *TransactionManager) ScrapedItem() models.ScrapedItemReaderWriter { - return t.ScrapedItemMock() -} - -func (t *TransactionManager) Studio() models.StudioReaderWriter { - return t.StudioMock() -} - -func (t *TransactionManager) Tag() models.TagReaderWriter { - return t.TagMock() -} - -func (t *TransactionManager) SavedFilter() models.SavedFilterReaderWriter { - return t.SavedFilterMock() -} - -type ReadTransaction struct { - *TransactionManager -} - -func (t *TransactionManager) WithReadTxn(ctx context.Context, fn func(r models.ReaderRepository) error) error { - return fn(&ReadTransaction{t}) -} - -func (r *ReadTransaction) Gallery() models.GalleryReader { - return r.GalleryMock() -} - -func (r *ReadTransaction) Image() models.ImageReader { - return r.ImageMock() -} - -func (r *ReadTransaction) Movie() models.MovieReader { - return r.MovieMock() -} - -func (r *ReadTransaction) Performer() models.PerformerReader { - return r.PerformerMock() -} - -func (r *ReadTransaction) SceneMarker() models.SceneMarkerReader { - return r.SceneMarkerMock() -} - -func (r *ReadTransaction) Scene() models.SceneReader { - return r.SceneMock() -} - -func (r *ReadTransaction) ScrapedItem() models.ScrapedItemReader { - return r.ScrapedItemMock() -} - -func (r *ReadTransaction) Studio() models.StudioReader { - return r.StudioMock() -} - -func (r *ReadTransaction) Tag() models.TagReader { - return r.TagMock() -} - -func (r *ReadTransaction) SavedFilter() models.SavedFilterReader { - return r.SavedFilterMock() -} diff --git a/pkg/models/movie.go b/pkg/models/movie.go index 8b68217ce..3fc1890a6 100644 --- a/pkg/models/movie.go +++ b/pkg/models/movie.go @@ -1,5 +1,7 @@ package models +import "context" + type MovieFilterType struct { Name *StringCriterionInput `json:"name"` Director *StringCriterionInput `json:"director"` @@ -19,29 +21,29 @@ type MovieFilterType struct { } type MovieReader interface { - Find(id int) (*Movie, error) - FindMany(ids []int) ([]*Movie, error) + Find(ctx context.Context, id int) (*Movie, error) + FindMany(ctx context.Context, ids []int) ([]*Movie, error) // FindBySceneID(sceneID int) ([]*Movie, error) - FindByName(name string, nocase bool) (*Movie, error) - FindByNames(names []string, nocase bool) ([]*Movie, error) - All() ([]*Movie, error) - Count() (int, error) - Query(movieFilter *MovieFilterType, findFilter *FindFilterType) ([]*Movie, int, error) - GetFrontImage(movieID int) ([]byte, error) - GetBackImage(movieID int) ([]byte, error) - FindByPerformerID(performerID int) ([]*Movie, error) - CountByPerformerID(performerID int) (int, error) - FindByStudioID(studioID int) ([]*Movie, error) - CountByStudioID(studioID int) (int, error) + FindByName(ctx context.Context, name string, nocase bool) (*Movie, error) + FindByNames(ctx context.Context, names []string, nocase bool) ([]*Movie, error) + All(ctx context.Context) ([]*Movie, error) + Count(ctx context.Context) (int, error) + Query(ctx context.Context, movieFilter *MovieFilterType, findFilter *FindFilterType) ([]*Movie, int, error) + GetFrontImage(ctx context.Context, movieID int) ([]byte, error) + GetBackImage(ctx context.Context, movieID int) ([]byte, error) + FindByPerformerID(ctx context.Context, performerID int) ([]*Movie, error) + CountByPerformerID(ctx context.Context, performerID int) (int, error) + FindByStudioID(ctx context.Context, studioID int) ([]*Movie, error) + CountByStudioID(ctx context.Context, studioID int) (int, error) } type MovieWriter interface { - Create(newMovie Movie) (*Movie, error) - Update(updatedMovie MoviePartial) (*Movie, error) - UpdateFull(updatedMovie Movie) (*Movie, error) - Destroy(id int) error - UpdateImages(movieID int, frontImage []byte, backImage []byte) error - DestroyImages(movieID int) error + Create(ctx context.Context, newMovie Movie) (*Movie, error) + Update(ctx context.Context, updatedMovie MoviePartial) (*Movie, error) + UpdateFull(ctx context.Context, updatedMovie Movie) (*Movie, error) + Destroy(ctx context.Context, id int) error + UpdateImages(ctx context.Context, movieID int, frontImage []byte, backImage []byte) error + DestroyImages(ctx context.Context, movieID int) error } type MovieReaderWriter interface { diff --git a/pkg/models/performer.go b/pkg/models/performer.go index d50c360a4..1bf3ec918 100644 --- a/pkg/models/performer.go +++ b/pkg/models/performer.go @@ -1,6 +1,7 @@ package models import ( + "context" "fmt" "io" "strconv" @@ -125,36 +126,36 @@ type PerformerFilterType struct { } type PerformerReader interface { - Find(id int) (*Performer, error) - FindMany(ids []int) ([]*Performer, error) - FindBySceneID(sceneID int) ([]*Performer, error) - FindNamesBySceneID(sceneID int) ([]*Performer, error) - FindByImageID(imageID int) ([]*Performer, error) - FindByGalleryID(galleryID int) ([]*Performer, error) - FindByNames(names []string, nocase bool) ([]*Performer, error) - FindByStashID(stashID StashID) ([]*Performer, error) - FindByStashIDStatus(hasStashID bool, stashboxEndpoint string) ([]*Performer, error) - CountByTagID(tagID int) (int, error) - Count() (int, error) - All() ([]*Performer, error) + Find(ctx context.Context, id int) (*Performer, error) + FindMany(ctx context.Context, ids []int) ([]*Performer, error) + FindBySceneID(ctx context.Context, sceneID int) ([]*Performer, error) + FindNamesBySceneID(ctx context.Context, sceneID int) ([]*Performer, error) + FindByImageID(ctx context.Context, imageID int) ([]*Performer, error) + FindByGalleryID(ctx context.Context, galleryID int) ([]*Performer, error) + FindByNames(ctx context.Context, names []string, nocase bool) ([]*Performer, error) + FindByStashID(ctx context.Context, stashID StashID) ([]*Performer, error) + FindByStashIDStatus(ctx context.Context, hasStashID bool, stashboxEndpoint string) ([]*Performer, error) + CountByTagID(ctx context.Context, tagID int) (int, error) + Count(ctx context.Context) (int, error) + All(ctx context.Context) ([]*Performer, error) // TODO - this interface is temporary until the filter schema can fully // support the query needed - QueryForAutoTag(words []string) ([]*Performer, error) - Query(performerFilter *PerformerFilterType, findFilter *FindFilterType) ([]*Performer, int, error) - GetImage(performerID int) ([]byte, error) - GetStashIDs(performerID int) ([]*StashID, error) - GetTagIDs(performerID int) ([]int, error) + QueryForAutoTag(ctx context.Context, words []string) ([]*Performer, error) + Query(ctx context.Context, performerFilter *PerformerFilterType, findFilter *FindFilterType) ([]*Performer, int, error) + GetImage(ctx context.Context, performerID int) ([]byte, error) + GetStashIDs(ctx context.Context, performerID int) ([]*StashID, error) + GetTagIDs(ctx context.Context, performerID int) ([]int, error) } type PerformerWriter interface { - Create(newPerformer Performer) (*Performer, error) - Update(updatedPerformer PerformerPartial) (*Performer, error) - UpdateFull(updatedPerformer Performer) (*Performer, error) - Destroy(id int) error - UpdateImage(performerID int, image []byte) error - DestroyImage(performerID int) error - UpdateStashIDs(performerID int, stashIDs []StashID) error - UpdateTags(performerID int, tagIDs []int) error + Create(ctx context.Context, newPerformer Performer) (*Performer, error) + Update(ctx context.Context, updatedPerformer PerformerPartial) (*Performer, error) + UpdateFull(ctx context.Context, updatedPerformer Performer) (*Performer, error) + Destroy(ctx context.Context, id int) error + UpdateImage(ctx context.Context, performerID int, image []byte) error + DestroyImage(ctx context.Context, performerID int) error + UpdateStashIDs(ctx context.Context, performerID int, stashIDs []StashID) error + UpdateTags(ctx context.Context, performerID int, tagIDs []int) error } type PerformerReaderWriter interface { diff --git a/pkg/models/repository.go b/pkg/models/repository.go index 2686d7c3a..0056ccad3 100644 --- a/pkg/models/repository.go +++ b/pkg/models/repository.go @@ -1,27 +1,31 @@ package models -type Repository interface { - Gallery() GalleryReaderWriter - Image() ImageReaderWriter - Movie() MovieReaderWriter - Performer() PerformerReaderWriter - Scene() SceneReaderWriter - SceneMarker() SceneMarkerReaderWriter - ScrapedItem() ScrapedItemReaderWriter - Studio() StudioReaderWriter - Tag() TagReaderWriter - SavedFilter() SavedFilterReaderWriter +import ( + "context" + + "github.com/stashapp/stash/pkg/txn" +) + +type TxnManager interface { + txn.Manager + Reset() error } -type ReaderRepository interface { - Gallery() GalleryReader - Image() ImageReader - Movie() MovieReader - Performer() PerformerReader - Scene() SceneReader - SceneMarker() SceneMarkerReader - ScrapedItem() ScrapedItemReader - Studio() StudioReader - Tag() TagReader - SavedFilter() SavedFilterReader +type Repository struct { + TxnManager + + Gallery GalleryReaderWriter + Image ImageReaderWriter + Movie MovieReaderWriter + Performer PerformerReaderWriter + Scene SceneReaderWriter + SceneMarker SceneMarkerReaderWriter + ScrapedItem ScrapedItemReaderWriter + Studio StudioReaderWriter + Tag TagReaderWriter + SavedFilter SavedFilterReaderWriter +} + +func (r *Repository) WithTxn(ctx context.Context, fn txn.TxnFunc) error { + return txn.WithTxn(ctx, r, fn) } diff --git a/pkg/models/saved_filter.go b/pkg/models/saved_filter.go index e6cd2f8e0..10dd4af36 100644 --- a/pkg/models/saved_filter.go +++ b/pkg/models/saved_filter.go @@ -1,18 +1,20 @@ package models +import "context" + type SavedFilterReader interface { - All() ([]*SavedFilter, error) - Find(id int) (*SavedFilter, error) - FindMany(ids []int, ignoreNotFound bool) ([]*SavedFilter, error) - FindByMode(mode FilterMode) ([]*SavedFilter, error) - FindDefault(mode FilterMode) (*SavedFilter, error) + All(ctx context.Context) ([]*SavedFilter, error) + Find(ctx context.Context, id int) (*SavedFilter, error) + FindMany(ctx context.Context, ids []int, ignoreNotFound bool) ([]*SavedFilter, error) + FindByMode(ctx context.Context, mode FilterMode) ([]*SavedFilter, error) + FindDefault(ctx context.Context, mode FilterMode) (*SavedFilter, error) } type SavedFilterWriter interface { - Create(obj SavedFilter) (*SavedFilter, error) - Update(obj SavedFilter) (*SavedFilter, error) - SetDefault(obj SavedFilter) (*SavedFilter, error) - Destroy(id int) error + Create(ctx context.Context, obj SavedFilter) (*SavedFilter, error) + Update(ctx context.Context, obj SavedFilter) (*SavedFilter, error) + SetDefault(ctx context.Context, obj SavedFilter) (*SavedFilter, error) + Destroy(ctx context.Context, id int) error } type SavedFilterReaderWriter interface { diff --git a/pkg/models/scene.go b/pkg/models/scene.go index 6940abe84..4c6c3aabe 100644 --- a/pkg/models/scene.go +++ b/pkg/models/scene.go @@ -1,5 +1,7 @@ package models +import "context" + type PHashDuplicationCriterionInput struct { Duplicated *bool `json:"duplicated"` // Currently unimplemented @@ -102,70 +104,70 @@ func NewSceneQueryResult(finder SceneFinder) *SceneQueryResult { } } -func (r *SceneQueryResult) Resolve() ([]*Scene, error) { +func (r *SceneQueryResult) Resolve(ctx context.Context) ([]*Scene, error) { // cache results if r.scenes == nil && r.resolveErr == nil { - r.scenes, r.resolveErr = r.finder.FindMany(r.IDs) + r.scenes, r.resolveErr = r.finder.FindMany(ctx, r.IDs) } return r.scenes, r.resolveErr } type SceneFinder interface { // TODO - rename this to Find and remove existing method - FindMany(ids []int) ([]*Scene, error) + FindMany(ctx context.Context, ids []int) ([]*Scene, error) } type SceneReader interface { SceneFinder // TODO - remove this in another PR - Find(id int) (*Scene, error) - FindByChecksum(checksum string) (*Scene, error) - FindByOSHash(oshash string) (*Scene, error) - FindByPath(path string) (*Scene, error) - FindByPerformerID(performerID int) ([]*Scene, error) - FindByGalleryID(performerID int) ([]*Scene, error) - FindDuplicates(distance int) ([][]*Scene, error) - CountByPerformerID(performerID int) (int, error) + Find(ctx context.Context, id int) (*Scene, error) + FindByChecksum(ctx context.Context, checksum string) (*Scene, error) + FindByOSHash(ctx context.Context, oshash string) (*Scene, error) + FindByPath(ctx context.Context, path string) (*Scene, error) + FindByPerformerID(ctx context.Context, performerID int) ([]*Scene, error) + FindByGalleryID(ctx context.Context, performerID int) ([]*Scene, error) + FindDuplicates(ctx context.Context, distance int) ([][]*Scene, error) + CountByPerformerID(ctx context.Context, performerID int) (int, error) // FindByStudioID(studioID int) ([]*Scene, error) - FindByMovieID(movieID int) ([]*Scene, error) - CountByMovieID(movieID int) (int, error) - Count() (int, error) - Size() (float64, error) - Duration() (float64, error) + FindByMovieID(ctx context.Context, movieID int) ([]*Scene, error) + CountByMovieID(ctx context.Context, movieID int) (int, error) + Count(ctx context.Context) (int, error) + Size(ctx context.Context) (float64, error) + Duration(ctx context.Context) (float64, error) // SizeCount() (string, error) - CountByStudioID(studioID int) (int, error) - CountByTagID(tagID int) (int, error) - CountMissingChecksum() (int, error) - CountMissingOSHash() (int, error) - Wall(q *string) ([]*Scene, error) - All() ([]*Scene, error) - Query(options SceneQueryOptions) (*SceneQueryResult, error) - GetCaptions(sceneID int) ([]*SceneCaption, error) - GetCover(sceneID int) ([]byte, error) - GetMovies(sceneID int) ([]MoviesScenes, error) - GetTagIDs(sceneID int) ([]int, error) - GetGalleryIDs(sceneID int) ([]int, error) - GetPerformerIDs(sceneID int) ([]int, error) - GetStashIDs(sceneID int) ([]*StashID, error) + CountByStudioID(ctx context.Context, studioID int) (int, error) + CountByTagID(ctx context.Context, tagID int) (int, error) + CountMissingChecksum(ctx context.Context) (int, error) + CountMissingOSHash(ctx context.Context) (int, error) + Wall(ctx context.Context, q *string) ([]*Scene, error) + All(ctx context.Context) ([]*Scene, error) + Query(ctx context.Context, options SceneQueryOptions) (*SceneQueryResult, error) + GetCaptions(ctx context.Context, sceneID int) ([]*SceneCaption, error) + GetCover(ctx context.Context, sceneID int) ([]byte, error) + GetMovies(ctx context.Context, sceneID int) ([]MoviesScenes, error) + GetTagIDs(ctx context.Context, sceneID int) ([]int, error) + GetGalleryIDs(ctx context.Context, sceneID int) ([]int, error) + GetPerformerIDs(ctx context.Context, sceneID int) ([]int, error) + GetStashIDs(ctx context.Context, sceneID int) ([]*StashID, error) } type SceneWriter interface { - Create(newScene Scene) (*Scene, error) - Update(updatedScene ScenePartial) (*Scene, error) - UpdateFull(updatedScene Scene) (*Scene, error) - IncrementOCounter(id int) (int, error) - DecrementOCounter(id int) (int, error) - ResetOCounter(id int) (int, error) - UpdateFileModTime(id int, modTime NullSQLiteTimestamp) error - Destroy(id int) error - UpdateCaptions(id int, captions []*SceneCaption) error - UpdateCover(sceneID int, cover []byte) error - DestroyCover(sceneID int) error - UpdatePerformers(sceneID int, performerIDs []int) error - UpdateTags(sceneID int, tagIDs []int) error - UpdateGalleries(sceneID int, galleryIDs []int) error - UpdateMovies(sceneID int, movies []MoviesScenes) error - UpdateStashIDs(sceneID int, stashIDs []StashID) error + Create(ctx context.Context, newScene Scene) (*Scene, error) + Update(ctx context.Context, updatedScene ScenePartial) (*Scene, error) + UpdateFull(ctx context.Context, updatedScene Scene) (*Scene, error) + IncrementOCounter(ctx context.Context, id int) (int, error) + DecrementOCounter(ctx context.Context, id int) (int, error) + ResetOCounter(ctx context.Context, id int) (int, error) + UpdateFileModTime(ctx context.Context, id int, modTime NullSQLiteTimestamp) error + Destroy(ctx context.Context, id int) error + UpdateCaptions(ctx context.Context, id int, captions []*SceneCaption) error + UpdateCover(ctx context.Context, sceneID int, cover []byte) error + DestroyCover(ctx context.Context, sceneID int) error + UpdatePerformers(ctx context.Context, sceneID int, performerIDs []int) error + UpdateTags(ctx context.Context, sceneID int, tagIDs []int) error + UpdateGalleries(ctx context.Context, sceneID int, galleryIDs []int) error + UpdateMovies(ctx context.Context, sceneID int, movies []MoviesScenes) error + UpdateStashIDs(ctx context.Context, sceneID int, stashIDs []StashID) error } type SceneReaderWriter interface { diff --git a/pkg/models/scene_marker.go b/pkg/models/scene_marker.go index ac7501672..dd0b786f6 100644 --- a/pkg/models/scene_marker.go +++ b/pkg/models/scene_marker.go @@ -1,5 +1,7 @@ package models +import "context" + type SceneMarkerFilterType struct { // Filter to only include scene markers with this tag TagID *string `json:"tag_id"` @@ -18,21 +20,21 @@ type MarkerStringsResultType struct { } type SceneMarkerReader interface { - Find(id int) (*SceneMarker, error) - FindMany(ids []int) ([]*SceneMarker, error) - FindBySceneID(sceneID int) ([]*SceneMarker, error) - CountByTagID(tagID int) (int, error) - GetMarkerStrings(q *string, sort *string) ([]*MarkerStringsResultType, error) - Wall(q *string) ([]*SceneMarker, error) - Query(sceneMarkerFilter *SceneMarkerFilterType, findFilter *FindFilterType) ([]*SceneMarker, int, error) - GetTagIDs(imageID int) ([]int, error) + Find(ctx context.Context, id int) (*SceneMarker, error) + FindMany(ctx context.Context, ids []int) ([]*SceneMarker, error) + FindBySceneID(ctx context.Context, sceneID int) ([]*SceneMarker, error) + CountByTagID(ctx context.Context, tagID int) (int, error) + GetMarkerStrings(ctx context.Context, q *string, sort *string) ([]*MarkerStringsResultType, error) + Wall(ctx context.Context, q *string) ([]*SceneMarker, error) + Query(ctx context.Context, sceneMarkerFilter *SceneMarkerFilterType, findFilter *FindFilterType) ([]*SceneMarker, int, error) + GetTagIDs(ctx context.Context, imageID int) ([]int, error) } type SceneMarkerWriter interface { - Create(newSceneMarker SceneMarker) (*SceneMarker, error) - Update(updatedSceneMarker SceneMarker) (*SceneMarker, error) - Destroy(id int) error - UpdateTags(markerID int, tagIDs []int) error + Create(ctx context.Context, newSceneMarker SceneMarker) (*SceneMarker, error) + Update(ctx context.Context, updatedSceneMarker SceneMarker) (*SceneMarker, error) + Destroy(ctx context.Context, id int) error + UpdateTags(ctx context.Context, markerID int, tagIDs []int) error } type SceneMarkerReaderWriter interface { diff --git a/pkg/models/scraped.go b/pkg/models/scraped.go index f57a8409a..be424147b 100644 --- a/pkg/models/scraped.go +++ b/pkg/models/scraped.go @@ -1,15 +1,18 @@ package models -import "errors" +import ( + "context" + "errors" +) var ErrScraperSource = errors.New("invalid ScraperSource") type ScrapedItemReader interface { - All() ([]*ScrapedItem, error) + All(ctx context.Context) ([]*ScrapedItem, error) } type ScrapedItemWriter interface { - Create(newObject ScrapedItem) (*ScrapedItem, error) + Create(ctx context.Context, newObject ScrapedItem) (*ScrapedItem, error) } type ScrapedItemReaderWriter interface { diff --git a/pkg/models/studio.go b/pkg/models/studio.go index fb9f7c415..c1f077ce7 100644 --- a/pkg/models/studio.go +++ b/pkg/models/studio.go @@ -1,5 +1,7 @@ package models +import "context" + type StudioFilterType struct { And *StudioFilterType `json:"AND"` Or *StudioFilterType `json:"OR"` @@ -29,32 +31,32 @@ type StudioFilterType struct { } type StudioReader interface { - Find(id int) (*Studio, error) - FindMany(ids []int) ([]*Studio, error) - FindChildren(id int) ([]*Studio, error) - FindByName(name string, nocase bool) (*Studio, error) - FindByStashID(stashID StashID) ([]*Studio, error) - Count() (int, error) - All() ([]*Studio, error) + Find(ctx context.Context, id int) (*Studio, error) + FindMany(ctx context.Context, ids []int) ([]*Studio, error) + FindChildren(ctx context.Context, id int) ([]*Studio, error) + FindByName(ctx context.Context, name string, nocase bool) (*Studio, error) + FindByStashID(ctx context.Context, stashID StashID) ([]*Studio, error) + Count(ctx context.Context) (int, error) + All(ctx context.Context) ([]*Studio, error) // TODO - this interface is temporary until the filter schema can fully // support the query needed - QueryForAutoTag(words []string) ([]*Studio, error) - Query(studioFilter *StudioFilterType, findFilter *FindFilterType) ([]*Studio, int, error) - GetImage(studioID int) ([]byte, error) - HasImage(studioID int) (bool, error) - GetStashIDs(studioID int) ([]*StashID, error) - GetAliases(studioID int) ([]string, error) + QueryForAutoTag(ctx context.Context, words []string) ([]*Studio, error) + Query(ctx context.Context, studioFilter *StudioFilterType, findFilter *FindFilterType) ([]*Studio, int, error) + GetImage(ctx context.Context, studioID int) ([]byte, error) + HasImage(ctx context.Context, studioID int) (bool, error) + GetStashIDs(ctx context.Context, studioID int) ([]*StashID, error) + GetAliases(ctx context.Context, studioID int) ([]string, error) } type StudioWriter interface { - Create(newStudio Studio) (*Studio, error) - Update(updatedStudio StudioPartial) (*Studio, error) - UpdateFull(updatedStudio Studio) (*Studio, error) - Destroy(id int) error - UpdateImage(studioID int, image []byte) error - DestroyImage(studioID int) error - UpdateStashIDs(studioID int, stashIDs []StashID) error - UpdateAliases(studioID int, aliases []string) error + Create(ctx context.Context, newStudio Studio) (*Studio, error) + Update(ctx context.Context, updatedStudio StudioPartial) (*Studio, error) + UpdateFull(ctx context.Context, updatedStudio Studio) (*Studio, error) + Destroy(ctx context.Context, id int) error + UpdateImage(ctx context.Context, studioID int, image []byte) error + DestroyImage(ctx context.Context, studioID int) error + UpdateStashIDs(ctx context.Context, studioID int, stashIDs []StashID) error + UpdateAliases(ctx context.Context, studioID int, aliases []string) error } type StudioReaderWriter interface { diff --git a/pkg/models/tag.go b/pkg/models/tag.go index a7f7518bf..33ff859c6 100644 --- a/pkg/models/tag.go +++ b/pkg/models/tag.go @@ -1,5 +1,7 @@ package models +import "context" + type TagFilterType struct { And *TagFilterType `json:"AND"` Or *TagFilterType `json:"OR"` @@ -33,40 +35,40 @@ type TagFilterType struct { } type TagReader interface { - Find(id int) (*Tag, error) - FindMany(ids []int) ([]*Tag, error) - FindBySceneID(sceneID int) ([]*Tag, error) - FindByPerformerID(performerID int) ([]*Tag, error) - FindBySceneMarkerID(sceneMarkerID int) ([]*Tag, error) - FindByImageID(imageID int) ([]*Tag, error) - FindByGalleryID(galleryID int) ([]*Tag, error) - FindByName(name string, nocase bool) (*Tag, error) - FindByNames(names []string, nocase bool) ([]*Tag, error) - FindByParentTagID(parentID int) ([]*Tag, error) - FindByChildTagID(childID int) ([]*Tag, error) - Count() (int, error) - All() ([]*Tag, error) + Find(ctx context.Context, id int) (*Tag, error) + FindMany(ctx context.Context, ids []int) ([]*Tag, error) + FindBySceneID(ctx context.Context, sceneID int) ([]*Tag, error) + FindByPerformerID(ctx context.Context, performerID int) ([]*Tag, error) + FindBySceneMarkerID(ctx context.Context, sceneMarkerID int) ([]*Tag, error) + FindByImageID(ctx context.Context, imageID int) ([]*Tag, error) + FindByGalleryID(ctx context.Context, galleryID int) ([]*Tag, error) + FindByName(ctx context.Context, name string, nocase bool) (*Tag, error) + FindByNames(ctx context.Context, names []string, nocase bool) ([]*Tag, error) + FindByParentTagID(ctx context.Context, parentID int) ([]*Tag, error) + FindByChildTagID(ctx context.Context, childID int) ([]*Tag, error) + Count(ctx context.Context) (int, error) + All(ctx context.Context) ([]*Tag, error) // TODO - this interface is temporary until the filter schema can fully // support the query needed - QueryForAutoTag(words []string) ([]*Tag, error) - Query(tagFilter *TagFilterType, findFilter *FindFilterType) ([]*Tag, int, error) - GetImage(tagID int) ([]byte, error) - GetAliases(tagID int) ([]string, error) - FindAllAncestors(tagID int, excludeIDs []int) ([]*TagPath, error) - FindAllDescendants(tagID int, excludeIDs []int) ([]*TagPath, error) + QueryForAutoTag(ctx context.Context, words []string) ([]*Tag, error) + Query(ctx context.Context, tagFilter *TagFilterType, findFilter *FindFilterType) ([]*Tag, int, error) + GetImage(ctx context.Context, tagID int) ([]byte, error) + GetAliases(ctx context.Context, tagID int) ([]string, error) + FindAllAncestors(ctx context.Context, tagID int, excludeIDs []int) ([]*TagPath, error) + FindAllDescendants(ctx context.Context, tagID int, excludeIDs []int) ([]*TagPath, error) } type TagWriter interface { - Create(newTag Tag) (*Tag, error) - Update(updateTag TagPartial) (*Tag, error) - UpdateFull(updatedTag Tag) (*Tag, error) - Destroy(id int) error - UpdateImage(tagID int, image []byte) error - DestroyImage(tagID int) error - UpdateAliases(tagID int, aliases []string) error - Merge(source []int, destination int) error - UpdateParentTags(tagID int, parentIDs []int) error - UpdateChildTags(tagID int, parentIDs []int) error + Create(ctx context.Context, newTag Tag) (*Tag, error) + Update(ctx context.Context, updateTag TagPartial) (*Tag, error) + UpdateFull(ctx context.Context, updatedTag Tag) (*Tag, error) + Destroy(ctx context.Context, id int) error + UpdateImage(ctx context.Context, tagID int, image []byte) error + DestroyImage(ctx context.Context, tagID int) error + UpdateAliases(ctx context.Context, tagID int, aliases []string) error + Merge(ctx context.Context, source []int, destination int) error + UpdateParentTags(ctx context.Context, tagID int, parentIDs []int) error + UpdateChildTags(ctx context.Context, tagID int, parentIDs []int) error } type TagReaderWriter interface { diff --git a/pkg/models/transaction.go b/pkg/models/transaction.go deleted file mode 100644 index 291038b0c..000000000 --- a/pkg/models/transaction.go +++ /dev/null @@ -1,86 +0,0 @@ -package models - -import ( - "context" - - "github.com/stashapp/stash/pkg/logger" -) - -type Transaction interface { - Begin() error - Rollback() error - Commit() error - Repository() Repository -} - -type ReadTransaction interface { - Begin() error - Rollback() error - Commit() error - Repository() ReaderRepository -} - -type TransactionManager interface { - WithTxn(ctx context.Context, fn func(r Repository) error) error - WithReadTxn(ctx context.Context, fn func(r ReaderRepository) error) error -} - -func WithTxn(txn Transaction, fn func(r Repository) error) error { - err := txn.Begin() - if err != nil { - return err - } - - defer func() { - if p := recover(); p != nil { - // a panic occurred, rollback and repanic - if err := txn.Rollback(); err != nil { - logger.Warnf("error while trying to roll back transaction: %v", err) - } - panic(p) - } - - if err != nil { - // something went wrong, rollback - if err := txn.Rollback(); err != nil { - logger.Warnf("error while trying to roll back transaction: %v", err) - } - } else { - // all good, commit - err = txn.Commit() - } - }() - - err = fn(txn.Repository()) - return err -} - -func WithROTxn(txn ReadTransaction, fn func(r ReaderRepository) error) error { - err := txn.Begin() - if err != nil { - return err - } - - defer func() { - if p := recover(); p != nil { - // a panic occurred, rollback and repanic - if err := txn.Rollback(); err != nil { - logger.Warnf("error while trying to roll back RO transaction: %v", err) - } - panic(p) - } - - if err != nil { - // something went wrong, rollback - if err := txn.Rollback(); err != nil { - logger.Warnf("error while trying to roll back RO transaction: %v", err) - } - } else { - // all good, commit - err = txn.Commit() - } - }() - - err = fn(txn.Repository()) - return err -} diff --git a/pkg/movie/export.go b/pkg/movie/export.go index a70e30290..2af697a49 100644 --- a/pkg/movie/export.go +++ b/pkg/movie/export.go @@ -1,16 +1,23 @@ package movie import ( + "context" "fmt" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/json" "github.com/stashapp/stash/pkg/models/jsonschema" + "github.com/stashapp/stash/pkg/studio" "github.com/stashapp/stash/pkg/utils" ) +type ImageGetter interface { + GetFrontImage(ctx context.Context, movieID int) ([]byte, error) + GetBackImage(ctx context.Context, movieID int) ([]byte, error) +} + // ToJSON converts a Movie into its JSON equivalent. -func ToJSON(reader models.MovieReader, studioReader models.StudioReader, movie *models.Movie) (*jsonschema.Movie, error) { +func ToJSON(ctx context.Context, reader ImageGetter, studioReader studio.Finder, movie *models.Movie) (*jsonschema.Movie, error) { newMovieJSON := jsonschema.Movie{ CreatedAt: json.JSONTime{Time: movie.CreatedAt.Timestamp}, UpdatedAt: json.JSONTime{Time: movie.UpdatedAt.Timestamp}, @@ -45,7 +52,7 @@ func ToJSON(reader models.MovieReader, studioReader models.StudioReader, movie * } if movie.StudioID.Valid { - studio, err := studioReader.Find(int(movie.StudioID.Int64)) + studio, err := studioReader.Find(ctx, int(movie.StudioID.Int64)) if err != nil { return nil, fmt.Errorf("error getting movie studio: %v", err) } @@ -55,7 +62,7 @@ func ToJSON(reader models.MovieReader, studioReader models.StudioReader, movie * } } - frontImage, err := reader.GetFrontImage(movie.ID) + frontImage, err := reader.GetFrontImage(ctx, movie.ID) if err != nil { return nil, fmt.Errorf("error getting movie front image: %v", err) } @@ -64,7 +71,7 @@ func ToJSON(reader models.MovieReader, studioReader models.StudioReader, movie * newMovieJSON.FrontImage = utils.GetBase64StringFromData(frontImage) } - backImage, err := reader.GetBackImage(movie.ID) + backImage, err := reader.GetBackImage(ctx, movie.ID) if err != nil { return nil, fmt.Errorf("error getting movie back image: %v", err) } diff --git a/pkg/movie/export_test.go b/pkg/movie/export_test.go index 11be97b7b..007383902 100644 --- a/pkg/movie/export_test.go +++ b/pkg/movie/export_test.go @@ -55,7 +55,7 @@ var ( backImageBytes = []byte("backImageBytes") ) -var studio models.Studio = models.Studio{ +var movieStudio models.Studio = models.Studio{ Name: models.NullString(studioName), } @@ -189,30 +189,30 @@ func TestToJSON(t *testing.T) { imageErr := errors.New("error getting image") - mockMovieReader.On("GetFrontImage", movieID).Return(frontImageBytes, nil).Once() - mockMovieReader.On("GetFrontImage", missingStudioMovieID).Return(frontImageBytes, nil).Once() - mockMovieReader.On("GetFrontImage", emptyID).Return(nil, nil).Once().Maybe() - mockMovieReader.On("GetFrontImage", errFrontImageID).Return(nil, imageErr).Once() - mockMovieReader.On("GetFrontImage", errBackImageID).Return(frontImageBytes, nil).Once() + mockMovieReader.On("GetFrontImage", testCtx, movieID).Return(frontImageBytes, nil).Once() + mockMovieReader.On("GetFrontImage", testCtx, missingStudioMovieID).Return(frontImageBytes, nil).Once() + mockMovieReader.On("GetFrontImage", testCtx, emptyID).Return(nil, nil).Once().Maybe() + mockMovieReader.On("GetFrontImage", testCtx, errFrontImageID).Return(nil, imageErr).Once() + mockMovieReader.On("GetFrontImage", testCtx, errBackImageID).Return(frontImageBytes, nil).Once() - mockMovieReader.On("GetBackImage", movieID).Return(backImageBytes, nil).Once() - mockMovieReader.On("GetBackImage", missingStudioMovieID).Return(backImageBytes, nil).Once() - mockMovieReader.On("GetBackImage", emptyID).Return(nil, nil).Once() - mockMovieReader.On("GetBackImage", errBackImageID).Return(nil, imageErr).Once() - mockMovieReader.On("GetBackImage", errFrontImageID).Return(backImageBytes, nil).Maybe() - mockMovieReader.On("GetBackImage", errStudioMovieID).Return(backImageBytes, nil).Maybe() + mockMovieReader.On("GetBackImage", testCtx, movieID).Return(backImageBytes, nil).Once() + mockMovieReader.On("GetBackImage", testCtx, missingStudioMovieID).Return(backImageBytes, nil).Once() + mockMovieReader.On("GetBackImage", testCtx, emptyID).Return(nil, nil).Once() + mockMovieReader.On("GetBackImage", testCtx, errBackImageID).Return(nil, imageErr).Once() + mockMovieReader.On("GetBackImage", testCtx, errFrontImageID).Return(backImageBytes, nil).Maybe() + mockMovieReader.On("GetBackImage", testCtx, errStudioMovieID).Return(backImageBytes, nil).Maybe() mockStudioReader := &mocks.StudioReaderWriter{} studioErr := errors.New("error getting studio") - mockStudioReader.On("Find", studioID).Return(&studio, nil) - mockStudioReader.On("Find", missingStudioID).Return(nil, nil) - mockStudioReader.On("Find", errStudioID).Return(nil, studioErr) + mockStudioReader.On("Find", testCtx, studioID).Return(&movieStudio, nil) + mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil) + mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr) for i, s := range scenarios { movie := s.movie - json, err := ToJSON(mockMovieReader, mockStudioReader, &movie) + json, err := ToJSON(testCtx, mockMovieReader, mockStudioReader, &movie) switch { case !s.err && err != nil: diff --git a/pkg/movie/import.go b/pkg/movie/import.go index 6afdef8b9..461df0f84 100644 --- a/pkg/movie/import.go +++ b/pkg/movie/import.go @@ -1,18 +1,26 @@ package movie import ( + "context" "database/sql" "fmt" "github.com/stashapp/stash/pkg/hash/md5" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/jsonschema" + "github.com/stashapp/stash/pkg/studio" "github.com/stashapp/stash/pkg/utils" ) +type NameFinderCreatorUpdater interface { + NameFinderCreator + UpdateFull(ctx context.Context, updatedMovie models.Movie) (*models.Movie, error) + UpdateImages(ctx context.Context, movieID int, frontImage []byte, backImage []byte) error +} + type Importer struct { - ReaderWriter models.MovieReaderWriter - StudioWriter models.StudioReaderWriter + ReaderWriter NameFinderCreatorUpdater + StudioWriter studio.NameFinderCreator Input jsonschema.Movie MissingRefBehaviour models.ImportMissingRefEnum @@ -21,10 +29,10 @@ type Importer struct { backImageData []byte } -func (i *Importer) PreImport() error { +func (i *Importer) PreImport(ctx context.Context) error { i.movie = i.movieJSONToMovie(i.Input) - if err := i.populateStudio(); err != nil { + if err := i.populateStudio(ctx); err != nil { return err } @@ -71,9 +79,9 @@ func (i *Importer) movieJSONToMovie(movieJSON jsonschema.Movie) models.Movie { return newMovie } -func (i *Importer) populateStudio() error { +func (i *Importer) populateStudio(ctx context.Context) error { if i.Input.Studio != "" { - studio, err := i.StudioWriter.FindByName(i.Input.Studio, false) + studio, err := i.StudioWriter.FindByName(ctx, i.Input.Studio, false) if err != nil { return fmt.Errorf("error finding studio by name: %v", err) } @@ -88,7 +96,7 @@ func (i *Importer) populateStudio() error { } if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate { - studioID, err := i.createStudio(i.Input.Studio) + studioID, err := i.createStudio(ctx, i.Input.Studio) if err != nil { return err } @@ -105,10 +113,10 @@ func (i *Importer) populateStudio() error { return nil } -func (i *Importer) createStudio(name string) (int, error) { +func (i *Importer) createStudio(ctx context.Context, name string) (int, error) { newStudio := *models.NewStudio(name) - created, err := i.StudioWriter.Create(newStudio) + created, err := i.StudioWriter.Create(ctx, newStudio) if err != nil { return 0, err } @@ -116,9 +124,9 @@ func (i *Importer) createStudio(name string) (int, error) { return created.ID, nil } -func (i *Importer) PostImport(id int) error { +func (i *Importer) PostImport(ctx context.Context, id int) error { if len(i.frontImageData) > 0 { - if err := i.ReaderWriter.UpdateImages(id, i.frontImageData, i.backImageData); err != nil { + if err := i.ReaderWriter.UpdateImages(ctx, id, i.frontImageData, i.backImageData); err != nil { return fmt.Errorf("error setting movie images: %v", err) } } @@ -130,9 +138,9 @@ func (i *Importer) Name() string { return i.Input.Name } -func (i *Importer) FindExistingID() (*int, error) { +func (i *Importer) FindExistingID(ctx context.Context) (*int, error) { const nocase = false - existing, err := i.ReaderWriter.FindByName(i.Name(), nocase) + existing, err := i.ReaderWriter.FindByName(ctx, i.Name(), nocase) if err != nil { return nil, err } @@ -145,8 +153,8 @@ func (i *Importer) FindExistingID() (*int, error) { return nil, nil } -func (i *Importer) Create() (*int, error) { - created, err := i.ReaderWriter.Create(i.movie) +func (i *Importer) Create(ctx context.Context) (*int, error) { + created, err := i.ReaderWriter.Create(ctx, i.movie) if err != nil { return nil, fmt.Errorf("error creating movie: %v", err) } @@ -155,10 +163,10 @@ func (i *Importer) Create() (*int, error) { return &id, nil } -func (i *Importer) Update(id int) error { +func (i *Importer) Update(ctx context.Context, id int) error { movie := i.movie movie.ID = id - _, err := i.ReaderWriter.UpdateFull(movie) + _, err := i.ReaderWriter.UpdateFull(ctx, movie) if err != nil { return fmt.Errorf("error updating existing movie: %v", err) } diff --git a/pkg/movie/import_test.go b/pkg/movie/import_test.go index 7aff71d47..26b6d9f27 100644 --- a/pkg/movie/import_test.go +++ b/pkg/movie/import_test.go @@ -1,6 +1,7 @@ package movie import ( + "context" "errors" "testing" @@ -27,6 +28,8 @@ const ( errImageID = 3 ) +var testCtx = context.Background() + func TestImporterName(t *testing.T) { i := Importer{ Input: jsonschema.Movie{ @@ -45,23 +48,23 @@ func TestImporterPreImport(t *testing.T) { }, } - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) i.Input.FrontImage = frontImage i.Input.BackImage = invalidImage - err = i.PreImport() + err = i.PreImport(testCtx) assert.NotNil(t, err) i.Input.BackImage = "" - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) i.Input.BackImage = backImage - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) } @@ -79,17 +82,17 @@ func TestImporterPreImportWithStudio(t *testing.T) { }, } - studioReaderWriter.On("FindByName", existingStudioName, false).Return(&models.Studio{ + studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{ ID: existingStudioID, }, nil).Once() - studioReaderWriter.On("FindByName", existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once() + studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once() - err := i.PreImport() + err := i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, int64(existingStudioID), i.movie.StudioID.Int64) i.Input.Studio = existingStudioErr - err = i.PreImport() + err = i.PreImport(testCtx) assert.NotNil(t, err) studioReaderWriter.AssertExpectations(t) @@ -108,20 +111,20 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - studioReaderWriter.On("FindByName", missingStudioName, false).Return(nil, nil).Times(3) - studioReaderWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(&models.Studio{ + studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) + studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Studio")).Return(&models.Studio{ ID: existingStudioID, }, nil) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumCreate - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, int64(existingStudioID), i.movie.StudioID.Int64) @@ -141,10 +144,10 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - studioReaderWriter.On("FindByName", missingStudioName, false).Return(nil, nil).Once() - studioReaderWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error")) + studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() + studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error")) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) } @@ -159,13 +162,13 @@ func TestImporterPostImport(t *testing.T) { updateMovieImageErr := errors.New("UpdateImages error") - readerWriter.On("UpdateImages", movieID, frontImageBytes, backImageBytes).Return(nil).Once() - readerWriter.On("UpdateImages", errImageID, frontImageBytes, backImageBytes).Return(updateMovieImageErr).Once() + readerWriter.On("UpdateImages", testCtx, movieID, frontImageBytes, backImageBytes).Return(nil).Once() + readerWriter.On("UpdateImages", testCtx, errImageID, frontImageBytes, backImageBytes).Return(updateMovieImageErr).Once() - err := i.PostImport(movieID) + err := i.PostImport(testCtx, movieID) assert.Nil(t, err) - err = i.PostImport(errImageID) + err = i.PostImport(testCtx, errImageID) assert.NotNil(t, err) readerWriter.AssertExpectations(t) @@ -182,23 +185,23 @@ func TestImporterFindExistingID(t *testing.T) { } errFindByName := errors.New("FindByName error") - readerWriter.On("FindByName", movieName, false).Return(nil, nil).Once() - readerWriter.On("FindByName", existingMovieName, false).Return(&models.Movie{ + readerWriter.On("FindByName", testCtx, movieName, false).Return(nil, nil).Once() + readerWriter.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{ ID: existingMovieID, }, nil).Once() - readerWriter.On("FindByName", movieNameErr, false).Return(nil, errFindByName).Once() + readerWriter.On("FindByName", testCtx, movieNameErr, false).Return(nil, errFindByName).Once() - id, err := i.FindExistingID() + id, err := i.FindExistingID(testCtx) assert.Nil(t, id) assert.Nil(t, err) i.Input.Name = existingMovieName - id, err = i.FindExistingID() + id, err = i.FindExistingID(testCtx) assert.Equal(t, existingMovieID, *id) assert.Nil(t, err) i.Input.Name = movieNameErr - id, err = i.FindExistingID() + id, err = i.FindExistingID(testCtx) assert.Nil(t, id) assert.NotNil(t, err) @@ -222,17 +225,17 @@ func TestCreate(t *testing.T) { } errCreate := errors.New("Create error") - readerWriter.On("Create", movie).Return(&models.Movie{ + readerWriter.On("Create", testCtx, movie).Return(&models.Movie{ ID: movieID, }, nil).Once() - readerWriter.On("Create", movieErr).Return(nil, errCreate).Once() + readerWriter.On("Create", testCtx, movieErr).Return(nil, errCreate).Once() - id, err := i.Create() + id, err := i.Create(testCtx) assert.Equal(t, movieID, *id) assert.Nil(t, err) i.movie = movieErr - id, err = i.Create() + id, err = i.Create(testCtx) assert.Nil(t, id) assert.NotNil(t, err) @@ -259,18 +262,18 @@ func TestUpdate(t *testing.T) { // id needs to be set for the mock input movie.ID = movieID - readerWriter.On("UpdateFull", movie).Return(nil, nil).Once() + readerWriter.On("UpdateFull", testCtx, movie).Return(nil, nil).Once() - err := i.Update(movieID) + err := i.Update(testCtx, movieID) assert.Nil(t, err) i.movie = movieErr // need to set id separately movieErr.ID = errImageID - readerWriter.On("UpdateFull", movieErr).Return(nil, errUpdate).Once() + readerWriter.On("UpdateFull", testCtx, movieErr).Return(nil, errUpdate).Once() - err = i.Update(errImageID) + err = i.Update(testCtx, errImageID) assert.NotNil(t, err) readerWriter.AssertExpectations(t) diff --git a/pkg/movie/update.go b/pkg/movie/update.go new file mode 100644 index 000000000..48dc9c123 --- /dev/null +++ b/pkg/movie/update.go @@ -0,0 +1,12 @@ +package movie + +import ( + "context" + + "github.com/stashapp/stash/pkg/models" +) + +type NameFinderCreator interface { + FindByName(ctx context.Context, name string, nocase bool) (*models.Movie, error) + Create(ctx context.Context, newMovie models.Movie) (*models.Movie, error) +} diff --git a/pkg/performer/export.go b/pkg/performer/export.go index 240d0fc28..a15df7e99 100644 --- a/pkg/performer/export.go +++ b/pkg/performer/export.go @@ -1,6 +1,7 @@ package performer import ( + "context" "fmt" "github.com/stashapp/stash/pkg/models" @@ -9,8 +10,13 @@ import ( "github.com/stashapp/stash/pkg/utils" ) +type ImageStashIDGetter interface { + GetImage(ctx context.Context, performerID int) ([]byte, error) + GetStashIDs(ctx context.Context, performerID int) ([]*models.StashID, error) +} + // ToJSON converts a Performer object into its JSON equivalent. -func ToJSON(reader models.PerformerReader, performer *models.Performer) (*jsonschema.Performer, error) { +func ToJSON(ctx context.Context, reader ImageStashIDGetter, performer *models.Performer) (*jsonschema.Performer, error) { newPerformerJSON := jsonschema.Performer{ IgnoreAutoTag: performer.IgnoreAutoTag, CreatedAt: json.JSONTime{Time: performer.CreatedAt.Timestamp}, @@ -84,7 +90,7 @@ func ToJSON(reader models.PerformerReader, performer *models.Performer) (*jsonsc newPerformerJSON.Weight = int(performer.Weight.Int64) } - image, err := reader.GetImage(performer.ID) + image, err := reader.GetImage(ctx, performer.ID) if err != nil { return nil, fmt.Errorf("error getting performers image: %v", err) } @@ -93,7 +99,7 @@ func ToJSON(reader models.PerformerReader, performer *models.Performer) (*jsonsc newPerformerJSON.Image = utils.GetBase64StringFromData(image) } - stashIDs, _ := reader.GetStashIDs(performer.ID) + stashIDs, _ := reader.GetStashIDs(ctx, performer.ID) var ret []models.StashID for _, stashID := range stashIDs { newJoin := models.StashID{ diff --git a/pkg/performer/export_test.go b/pkg/performer/export_test.go index 7cfdabb7f..e83d0e189 100644 --- a/pkg/performer/export_test.go +++ b/pkg/performer/export_test.go @@ -208,16 +208,16 @@ func TestToJSON(t *testing.T) { imageErr := errors.New("error getting image") - mockPerformerReader.On("GetImage", performerID).Return(imageBytes, nil).Once() - mockPerformerReader.On("GetImage", noImageID).Return(nil, nil).Once() - mockPerformerReader.On("GetImage", errImageID).Return(nil, imageErr).Once() + mockPerformerReader.On("GetImage", testCtx, performerID).Return(imageBytes, nil).Once() + mockPerformerReader.On("GetImage", testCtx, noImageID).Return(nil, nil).Once() + mockPerformerReader.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once() - mockPerformerReader.On("GetStashIDs", performerID).Return(stashIDs, nil).Once() - mockPerformerReader.On("GetStashIDs", noImageID).Return(nil, nil).Once() + mockPerformerReader.On("GetStashIDs", testCtx, performerID).Return(stashIDs, nil).Once() + mockPerformerReader.On("GetStashIDs", testCtx, noImageID).Return(nil, nil).Once() for i, s := range scenarios { tag := s.input - json, err := ToJSON(mockPerformerReader, &tag) + json, err := ToJSON(testCtx, mockPerformerReader, &tag) switch { case !s.err && err != nil: diff --git a/pkg/performer/import.go b/pkg/performer/import.go index 9e4ec77f7..7c673fb34 100644 --- a/pkg/performer/import.go +++ b/pkg/performer/import.go @@ -1,6 +1,7 @@ package performer import ( + "context" "database/sql" "fmt" "strings" @@ -9,12 +10,21 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/jsonschema" "github.com/stashapp/stash/pkg/sliceutil/stringslice" + "github.com/stashapp/stash/pkg/tag" "github.com/stashapp/stash/pkg/utils" ) +type NameFinderCreatorUpdater interface { + NameFinderCreator + UpdateFull(ctx context.Context, updatedPerformer models.Performer) (*models.Performer, error) + UpdateTags(ctx context.Context, performerID int, tagIDs []int) error + UpdateImage(ctx context.Context, performerID int, image []byte) error + UpdateStashIDs(ctx context.Context, performerID int, stashIDs []models.StashID) error +} + type Importer struct { - ReaderWriter models.PerformerReaderWriter - TagWriter models.TagReaderWriter + ReaderWriter NameFinderCreatorUpdater + TagWriter tag.NameFinderCreator Input jsonschema.Performer MissingRefBehaviour models.ImportMissingRefEnum @@ -25,10 +35,10 @@ type Importer struct { tags []*models.Tag } -func (i *Importer) PreImport() error { +func (i *Importer) PreImport(ctx context.Context) error { i.performer = performerJSONToPerformer(i.Input) - if err := i.populateTags(); err != nil { + if err := i.populateTags(ctx); err != nil { return err } @@ -43,10 +53,10 @@ func (i *Importer) PreImport() error { return nil } -func (i *Importer) populateTags() error { +func (i *Importer) populateTags(ctx context.Context) error { if len(i.Input.Tags) > 0 { - tags, err := importTags(i.TagWriter, i.Input.Tags, i.MissingRefBehaviour) + tags, err := importTags(ctx, i.TagWriter, i.Input.Tags, i.MissingRefBehaviour) if err != nil { return err } @@ -57,8 +67,8 @@ func (i *Importer) populateTags() error { return nil } -func importTags(tagWriter models.TagReaderWriter, names []string, missingRefBehaviour models.ImportMissingRefEnum) ([]*models.Tag, error) { - tags, err := tagWriter.FindByNames(names, false) +func importTags(ctx context.Context, tagWriter tag.NameFinderCreator, names []string, missingRefBehaviour models.ImportMissingRefEnum) ([]*models.Tag, error) { + tags, err := tagWriter.FindByNames(ctx, names, false) if err != nil { return nil, err } @@ -78,7 +88,7 @@ func importTags(tagWriter models.TagReaderWriter, names []string, missingRefBeha } if missingRefBehaviour == models.ImportMissingRefEnumCreate { - createdTags, err := createTags(tagWriter, missingTags) + createdTags, err := createTags(ctx, tagWriter, missingTags) if err != nil { return nil, fmt.Errorf("error creating tags: %v", err) } @@ -92,12 +102,12 @@ func importTags(tagWriter models.TagReaderWriter, names []string, missingRefBeha return tags, nil } -func createTags(tagWriter models.TagWriter, names []string) ([]*models.Tag, error) { +func createTags(ctx context.Context, tagWriter tag.NameFinderCreator, names []string) ([]*models.Tag, error) { var ret []*models.Tag for _, name := range names { newTag := *models.NewTag(name) - created, err := tagWriter.Create(newTag) + created, err := tagWriter.Create(ctx, newTag) if err != nil { return nil, err } @@ -108,25 +118,25 @@ func createTags(tagWriter models.TagWriter, names []string) ([]*models.Tag, erro return ret, nil } -func (i *Importer) PostImport(id int) error { +func (i *Importer) PostImport(ctx context.Context, id int) error { if len(i.tags) > 0 { var tagIDs []int for _, t := range i.tags { tagIDs = append(tagIDs, t.ID) } - if err := i.ReaderWriter.UpdateTags(id, tagIDs); err != nil { + if err := i.ReaderWriter.UpdateTags(ctx, id, tagIDs); err != nil { return fmt.Errorf("failed to associate tags: %v", err) } } if len(i.imageData) > 0 { - if err := i.ReaderWriter.UpdateImage(id, i.imageData); err != nil { + if err := i.ReaderWriter.UpdateImage(ctx, id, i.imageData); err != nil { return fmt.Errorf("error setting performer image: %v", err) } } if len(i.Input.StashIDs) > 0 { - if err := i.ReaderWriter.UpdateStashIDs(id, i.Input.StashIDs); err != nil { + if err := i.ReaderWriter.UpdateStashIDs(ctx, id, i.Input.StashIDs); err != nil { return fmt.Errorf("error setting stash id: %v", err) } } @@ -138,9 +148,9 @@ func (i *Importer) Name() string { return i.Input.Name } -func (i *Importer) FindExistingID() (*int, error) { +func (i *Importer) FindExistingID(ctx context.Context) (*int, error) { const nocase = false - existing, err := i.ReaderWriter.FindByNames([]string{i.Name()}, nocase) + existing, err := i.ReaderWriter.FindByNames(ctx, []string{i.Name()}, nocase) if err != nil { return nil, err } @@ -153,8 +163,8 @@ func (i *Importer) FindExistingID() (*int, error) { return nil, nil } -func (i *Importer) Create() (*int, error) { - created, err := i.ReaderWriter.Create(i.performer) +func (i *Importer) Create(ctx context.Context) (*int, error) { + created, err := i.ReaderWriter.Create(ctx, i.performer) if err != nil { return nil, fmt.Errorf("error creating performer: %v", err) } @@ -163,10 +173,10 @@ func (i *Importer) Create() (*int, error) { return &id, nil } -func (i *Importer) Update(id int) error { +func (i *Importer) Update(ctx context.Context, id int) error { performer := i.performer performer.ID = id - _, err := i.ReaderWriter.UpdateFull(performer) + _, err := i.ReaderWriter.UpdateFull(ctx, performer) if err != nil { return fmt.Errorf("error updating existing performer: %v", err) } diff --git a/pkg/performer/import_test.go b/pkg/performer/import_test.go index 30ddbae5e..4f80a67c0 100644 --- a/pkg/performer/import_test.go +++ b/pkg/performer/import_test.go @@ -1,6 +1,7 @@ package performer import ( + "context" "errors" "github.com/stretchr/testify/mock" @@ -29,6 +30,8 @@ const ( missingTagName = "missingTagName" ) +var testCtx = context.Background() + func TestImporterName(t *testing.T) { i := Importer{ Input: jsonschema.Performer{ @@ -47,13 +50,13 @@ func TestImporterPreImport(t *testing.T) { }, } - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) i.Input = *createFullJSONPerformer(performerName, image) - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) expectedPerformer := *createFullPerformer(0, performerName) @@ -74,20 +77,20 @@ func TestImporterPreImportWithTag(t *testing.T) { }, } - tagReaderWriter.On("FindByNames", []string{existingTagName}, false).Return([]*models.Tag{ + tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{ { ID: existingTagID, Name: existingTagName, }, }, nil).Once() - tagReaderWriter.On("FindByNames", []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once() + tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once() - err := i.PreImport() + err := i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, existingTagID, i.tags[0].ID) i.Input.Tags = []string{existingTagErr} - err = i.PreImport() + err = i.PreImport(testCtx) assert.NotNil(t, err) tagReaderWriter.AssertExpectations(t) @@ -106,20 +109,20 @@ func TestImporterPreImportWithMissingTag(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - tagReaderWriter.On("FindByNames", []string{missingTagName}, false).Return(nil, nil).Times(3) - tagReaderWriter.On("Create", mock.AnythingOfType("models.Tag")).Return(&models.Tag{ + tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3) + tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(&models.Tag{ ID: existingTagID, }, nil) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumCreate - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, existingTagID, i.tags[0].ID) @@ -139,10 +142,10 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - tagReaderWriter.On("FindByNames", []string{missingTagName}, false).Return(nil, nil).Once() - tagReaderWriter.On("Create", mock.AnythingOfType("models.Tag")).Return(nil, errors.New("Create error")) + tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once() + tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(nil, errors.New("Create error")) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) } @@ -156,13 +159,13 @@ func TestImporterPostImport(t *testing.T) { updatePerformerImageErr := errors.New("UpdateImage error") - readerWriter.On("UpdateImage", performerID, imageBytes).Return(nil).Once() - readerWriter.On("UpdateImage", errImageID, imageBytes).Return(updatePerformerImageErr).Once() + readerWriter.On("UpdateImage", testCtx, performerID, imageBytes).Return(nil).Once() + readerWriter.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updatePerformerImageErr).Once() - err := i.PostImport(performerID) + err := i.PostImport(testCtx, performerID) assert.Nil(t, err) - err = i.PostImport(errImageID) + err = i.PostImport(testCtx, errImageID) assert.NotNil(t, err) readerWriter.AssertExpectations(t) @@ -179,25 +182,25 @@ func TestImporterFindExistingID(t *testing.T) { } errFindByNames := errors.New("FindByNames error") - readerWriter.On("FindByNames", []string{performerName}, false).Return(nil, nil).Once() - readerWriter.On("FindByNames", []string{existingPerformerName}, false).Return([]*models.Performer{ + readerWriter.On("FindByNames", testCtx, []string{performerName}, false).Return(nil, nil).Once() + readerWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{ { ID: existingPerformerID, }, }, nil).Once() - readerWriter.On("FindByNames", []string{performerNameErr}, false).Return(nil, errFindByNames).Once() + readerWriter.On("FindByNames", testCtx, []string{performerNameErr}, false).Return(nil, errFindByNames).Once() - id, err := i.FindExistingID() + id, err := i.FindExistingID(testCtx) assert.Nil(t, id) assert.Nil(t, err) i.Input.Name = existingPerformerName - id, err = i.FindExistingID() + id, err = i.FindExistingID(testCtx) assert.Equal(t, existingPerformerID, *id) assert.Nil(t, err) i.Input.Name = performerNameErr - id, err = i.FindExistingID() + id, err = i.FindExistingID(testCtx) assert.Nil(t, id) assert.NotNil(t, err) @@ -218,13 +221,13 @@ func TestImporterPostImportUpdateTags(t *testing.T) { updateErr := errors.New("UpdateTags error") - readerWriter.On("UpdateTags", performerID, []int{existingTagID}).Return(nil).Once() - readerWriter.On("UpdateTags", errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once() + readerWriter.On("UpdateTags", testCtx, performerID, []int{existingTagID}).Return(nil).Once() + readerWriter.On("UpdateTags", testCtx, errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once() - err := i.PostImport(performerID) + err := i.PostImport(testCtx, performerID) assert.Nil(t, err) - err = i.PostImport(errTagsID) + err = i.PostImport(testCtx, errTagsID) assert.NotNil(t, err) readerWriter.AssertExpectations(t) @@ -247,17 +250,17 @@ func TestCreate(t *testing.T) { } errCreate := errors.New("Create error") - readerWriter.On("Create", performer).Return(&models.Performer{ + readerWriter.On("Create", testCtx, performer).Return(&models.Performer{ ID: performerID, }, nil).Once() - readerWriter.On("Create", performerErr).Return(nil, errCreate).Once() + readerWriter.On("Create", testCtx, performerErr).Return(nil, errCreate).Once() - id, err := i.Create() + id, err := i.Create(testCtx) assert.Equal(t, performerID, *id) assert.Nil(t, err) i.performer = performerErr - id, err = i.Create() + id, err = i.Create(testCtx) assert.Nil(t, id) assert.NotNil(t, err) @@ -284,18 +287,18 @@ func TestUpdate(t *testing.T) { // id needs to be set for the mock input performer.ID = performerID - readerWriter.On("UpdateFull", performer).Return(nil, nil).Once() + readerWriter.On("UpdateFull", testCtx, performer).Return(nil, nil).Once() - err := i.Update(performerID) + err := i.Update(testCtx, performerID) assert.Nil(t, err) i.performer = performerErr // need to set id separately performerErr.ID = errImageID - readerWriter.On("UpdateFull", performerErr).Return(nil, errUpdate).Once() + readerWriter.On("UpdateFull", testCtx, performerErr).Return(nil, errUpdate).Once() - err = i.Update(errImageID) + err = i.Update(testCtx, errImageID) assert.NotNil(t, err) readerWriter.AssertExpectations(t) diff --git a/pkg/performer/update.go b/pkg/performer/update.go new file mode 100644 index 000000000..5974a5eab --- /dev/null +++ b/pkg/performer/update.go @@ -0,0 +1,12 @@ +package performer + +import ( + "context" + + "github.com/stashapp/stash/pkg/models" +) + +type NameFinderCreator interface { + FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Performer, error) + Create(ctx context.Context, newPerformer models.Performer) (*models.Performer, error) +} diff --git a/pkg/scene/delete.go b/pkg/scene/delete.go index 3a31d6f60..7347d68fd 100644 --- a/pkg/scene/delete.go +++ b/pkg/scene/delete.go @@ -1,6 +1,7 @@ package scene import ( + "context" "path/filepath" "github.com/stashapp/stash/pkg/file" @@ -114,19 +115,25 @@ func (d *FileDeleter) MarkMarkerFiles(scene *models.Scene, seconds int) error { return d.Files(files) } +type Destroyer interface { + Destroy(ctx context.Context, id int) error +} + +type MarkerDestroyer interface { + FindBySceneID(ctx context.Context, sceneID int) ([]*models.SceneMarker, error) + Destroy(ctx context.Context, id int) error +} + // Destroy deletes a scene and its associated relationships from the // database. -func Destroy(scene *models.Scene, repo models.Repository, fileDeleter *FileDeleter, deleteGenerated, deleteFile bool) error { - qb := repo.Scene() - mqb := repo.SceneMarker() - - markers, err := mqb.FindBySceneID(scene.ID) +func Destroy(ctx context.Context, scene *models.Scene, qb Destroyer, mqb MarkerDestroyer, fileDeleter *FileDeleter, deleteGenerated, deleteFile bool) error { + markers, err := mqb.FindBySceneID(ctx, scene.ID) if err != nil { return err } for _, m := range markers { - if err := DestroyMarker(scene, m, mqb, fileDeleter); err != nil { + if err := DestroyMarker(ctx, scene, m, mqb, fileDeleter); err != nil { return err } } @@ -151,7 +158,7 @@ func Destroy(scene *models.Scene, repo models.Repository, fileDeleter *FileDelet } } - if err := qb.Destroy(scene.ID); err != nil { + if err := qb.Destroy(ctx, scene.ID); err != nil { return err } @@ -161,8 +168,8 @@ func Destroy(scene *models.Scene, repo models.Repository, fileDeleter *FileDelet // DestroyMarker deletes the scene marker from the database and returns a // function that removes the generated files, to be executed after the // transaction is successfully committed. -func DestroyMarker(scene *models.Scene, sceneMarker *models.SceneMarker, qb models.SceneMarkerWriter, fileDeleter *FileDeleter) error { - if err := qb.Destroy(sceneMarker.ID); err != nil { +func DestroyMarker(ctx context.Context, scene *models.Scene, sceneMarker *models.SceneMarker, qb MarkerDestroyer, fileDeleter *FileDeleter) error { + if err := qb.Destroy(ctx, sceneMarker.ID); err != nil { return err } diff --git a/pkg/scene/export.go b/pkg/scene/export.go index c5bda2c47..57557f11a 100644 --- a/pkg/scene/export.go +++ b/pkg/scene/export.go @@ -1,6 +1,7 @@ package scene import ( + "context" "fmt" "math" "strconv" @@ -9,13 +10,38 @@ import ( "github.com/stashapp/stash/pkg/models/json" "github.com/stashapp/stash/pkg/models/jsonschema" "github.com/stashapp/stash/pkg/sliceutil/intslice" + "github.com/stashapp/stash/pkg/studio" + "github.com/stashapp/stash/pkg/tag" "github.com/stashapp/stash/pkg/utils" ) +type CoverStashIDGetter interface { + GetCover(ctx context.Context, sceneID int) ([]byte, error) + GetStashIDs(ctx context.Context, sceneID int) ([]*models.StashID, error) +} + +type MovieGetter interface { + GetMovies(ctx context.Context, sceneID int) ([]models.MoviesScenes, error) +} + +type MarkerTagFinder interface { + tag.Finder + TagFinder + FindBySceneMarkerID(ctx context.Context, sceneMarkerID int) ([]*models.Tag, error) +} + +type MarkerFinder interface { + FindBySceneID(ctx context.Context, sceneID int) ([]*models.SceneMarker, error) +} + +type TagFinder interface { + FindBySceneID(ctx context.Context, sceneID int) ([]*models.Tag, error) +} + // ToBasicJSON converts a scene object into its JSON object equivalent. It // does not convert the relationships to other objects, with the exception // of cover image. -func ToBasicJSON(reader models.SceneReader, scene *models.Scene) (*jsonschema.Scene, error) { +func ToBasicJSON(ctx context.Context, reader CoverStashIDGetter, scene *models.Scene) (*jsonschema.Scene, error) { newSceneJSON := jsonschema.Scene{ CreatedAt: json.JSONTime{Time: scene.CreatedAt.Timestamp}, UpdatedAt: json.JSONTime{Time: scene.UpdatedAt.Timestamp}, @@ -58,7 +84,7 @@ func ToBasicJSON(reader models.SceneReader, scene *models.Scene) (*jsonschema.Sc newSceneJSON.File = getSceneFileJSON(scene) - cover, err := reader.GetCover(scene.ID) + cover, err := reader.GetCover(ctx, scene.ID) if err != nil { return nil, fmt.Errorf("error getting scene cover: %v", err) } @@ -67,7 +93,7 @@ func ToBasicJSON(reader models.SceneReader, scene *models.Scene) (*jsonschema.Sc newSceneJSON.Cover = utils.GetBase64StringFromData(cover) } - stashIDs, _ := reader.GetStashIDs(scene.ID) + stashIDs, _ := reader.GetStashIDs(ctx, scene.ID) var ret []models.StashID for _, stashID := range stashIDs { newJoin := models.StashID{ @@ -130,9 +156,9 @@ func getSceneFileJSON(scene *models.Scene) *jsonschema.SceneFile { // GetStudioName returns the name of the provided scene's studio. It returns an // empty string if there is no studio assigned to the scene. -func GetStudioName(reader models.StudioReader, scene *models.Scene) (string, error) { +func GetStudioName(ctx context.Context, reader studio.Finder, scene *models.Scene) (string, error) { if scene.StudioID.Valid { - studio, err := reader.Find(int(scene.StudioID.Int64)) + studio, err := reader.Find(ctx, int(scene.StudioID.Int64)) if err != nil { return "", err } @@ -147,8 +173,8 @@ func GetStudioName(reader models.StudioReader, scene *models.Scene) (string, err // GetTagNames returns a slice of tag names corresponding to the provided // scene's tags. -func GetTagNames(reader models.TagReader, scene *models.Scene) ([]string, error) { - tags, err := reader.FindBySceneID(scene.ID) +func GetTagNames(ctx context.Context, reader TagFinder, scene *models.Scene) ([]string, error) { + tags, err := reader.FindBySceneID(ctx, scene.ID) if err != nil { return nil, fmt.Errorf("error getting scene tags: %v", err) } @@ -168,10 +194,10 @@ func getTagNames(tags []*models.Tag) []string { } // GetDependentTagIDs returns a slice of unique tag IDs that this scene references. -func GetDependentTagIDs(tags models.TagReader, markerReader models.SceneMarkerReader, scene *models.Scene) ([]int, error) { +func GetDependentTagIDs(ctx context.Context, tags MarkerTagFinder, markerReader MarkerFinder, scene *models.Scene) ([]int, error) { var ret []int - t, err := tags.FindBySceneID(scene.ID) + t, err := tags.FindBySceneID(ctx, scene.ID) if err != nil { return nil, err } @@ -180,14 +206,14 @@ func GetDependentTagIDs(tags models.TagReader, markerReader models.SceneMarkerRe ret = intslice.IntAppendUnique(ret, tt.ID) } - sm, err := markerReader.FindBySceneID(scene.ID) + sm, err := markerReader.FindBySceneID(ctx, scene.ID) if err != nil { return nil, err } for _, smm := range sm { ret = intslice.IntAppendUnique(ret, smm.PrimaryTagID) - smmt, err := tags.FindBySceneMarkerID(smm.ID) + smmt, err := tags.FindBySceneMarkerID(ctx, smm.ID) if err != nil { return nil, fmt.Errorf("invalid tags for scene marker: %v", err) } @@ -200,17 +226,21 @@ func GetDependentTagIDs(tags models.TagReader, markerReader models.SceneMarkerRe return ret, nil } +type MovieFinder interface { + Find(ctx context.Context, id int) (*models.Movie, error) +} + // GetSceneMoviesJSON returns a slice of SceneMovie JSON representation objects // corresponding to the provided scene's scene movie relationships. -func GetSceneMoviesJSON(movieReader models.MovieReader, sceneReader models.SceneReader, scene *models.Scene) ([]jsonschema.SceneMovie, error) { - sceneMovies, err := sceneReader.GetMovies(scene.ID) +func GetSceneMoviesJSON(ctx context.Context, movieReader MovieFinder, sceneReader MovieGetter, scene *models.Scene) ([]jsonschema.SceneMovie, error) { + sceneMovies, err := sceneReader.GetMovies(ctx, scene.ID) if err != nil { return nil, fmt.Errorf("error getting scene movies: %v", err) } var results []jsonschema.SceneMovie for _, sceneMovie := range sceneMovies { - movie, err := movieReader.Find(sceneMovie.MovieID) + movie, err := movieReader.Find(ctx, sceneMovie.MovieID) if err != nil { return nil, fmt.Errorf("error getting movie: %v", err) } @@ -228,10 +258,10 @@ func GetSceneMoviesJSON(movieReader models.MovieReader, sceneReader models.Scene } // GetDependentMovieIDs returns a slice of movie IDs that this scene references. -func GetDependentMovieIDs(sceneReader models.SceneReader, scene *models.Scene) ([]int, error) { +func GetDependentMovieIDs(ctx context.Context, sceneReader MovieGetter, scene *models.Scene) ([]int, error) { var ret []int - m, err := sceneReader.GetMovies(scene.ID) + m, err := sceneReader.GetMovies(ctx, scene.ID) if err != nil { return nil, err } @@ -245,8 +275,8 @@ func GetDependentMovieIDs(sceneReader models.SceneReader, scene *models.Scene) ( // GetSceneMarkersJSON returns a slice of SceneMarker JSON representation // objects corresponding to the provided scene's markers. -func GetSceneMarkersJSON(markerReader models.SceneMarkerReader, tagReader models.TagReader, scene *models.Scene) ([]jsonschema.SceneMarker, error) { - sceneMarkers, err := markerReader.FindBySceneID(scene.ID) +func GetSceneMarkersJSON(ctx context.Context, markerReader MarkerFinder, tagReader MarkerTagFinder, scene *models.Scene) ([]jsonschema.SceneMarker, error) { + sceneMarkers, err := markerReader.FindBySceneID(ctx, scene.ID) if err != nil { return nil, fmt.Errorf("error getting scene markers: %v", err) } @@ -254,12 +284,12 @@ func GetSceneMarkersJSON(markerReader models.SceneMarkerReader, tagReader models var results []jsonschema.SceneMarker for _, sceneMarker := range sceneMarkers { - primaryTag, err := tagReader.Find(sceneMarker.PrimaryTagID) + primaryTag, err := tagReader.Find(ctx, sceneMarker.PrimaryTagID) if err != nil { return nil, fmt.Errorf("invalid primary tag for scene marker: %v", err) } - sceneMarkerTags, err := tagReader.FindBySceneMarkerID(sceneMarker.ID) + sceneMarkerTags, err := tagReader.FindBySceneMarkerID(ctx, sceneMarker.ID) if err != nil { return nil, fmt.Errorf("invalid tags for scene marker: %v", err) } diff --git a/pkg/scene/export_test.go b/pkg/scene/export_test.go index aa8b7fb52..ae6efc725 100644 --- a/pkg/scene/export_test.go +++ b/pkg/scene/export_test.go @@ -230,16 +230,16 @@ func TestToJSON(t *testing.T) { imageErr := errors.New("error getting image") - mockSceneReader.On("GetCover", sceneID).Return(imageBytes, nil).Once() - mockSceneReader.On("GetCover", noImageID).Return(nil, nil).Once() - mockSceneReader.On("GetCover", errImageID).Return(nil, imageErr).Once() + mockSceneReader.On("GetCover", testCtx, sceneID).Return(imageBytes, nil).Once() + mockSceneReader.On("GetCover", testCtx, noImageID).Return(nil, nil).Once() + mockSceneReader.On("GetCover", testCtx, errImageID).Return(nil, imageErr).Once() - mockSceneReader.On("GetStashIDs", sceneID).Return(stashIDs, nil).Once() - mockSceneReader.On("GetStashIDs", noImageID).Return(nil, nil).Once() + mockSceneReader.On("GetStashIDs", testCtx, sceneID).Return(stashIDs, nil).Once() + mockSceneReader.On("GetStashIDs", testCtx, noImageID).Return(nil, nil).Once() for i, s := range scenarios { scene := s.input - json, err := ToBasicJSON(mockSceneReader, &scene) + json, err := ToBasicJSON(testCtx, mockSceneReader, &scene) switch { case !s.err && err != nil: @@ -289,15 +289,15 @@ func TestGetStudioName(t *testing.T) { studioErr := errors.New("error getting image") - mockStudioReader.On("Find", studioID).Return(&models.Studio{ + mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{ Name: models.NullString(studioName), }, nil).Once() - mockStudioReader.On("Find", missingStudioID).Return(nil, nil).Once() - mockStudioReader.On("Find", errStudioID).Return(nil, studioErr).Once() + mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once() + mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once() for i, s := range getStudioScenarios { scene := s.input - json, err := GetStudioName(mockStudioReader, &scene) + json, err := GetStudioName(testCtx, mockStudioReader, &scene) switch { case !s.err && err != nil: @@ -352,13 +352,13 @@ func TestGetTagNames(t *testing.T) { tagErr := errors.New("error getting tag") - mockTagReader.On("FindBySceneID", sceneID).Return(getTags(names), nil).Once() - mockTagReader.On("FindBySceneID", noTagsID).Return(nil, nil).Once() - mockTagReader.On("FindBySceneID", errTagsID).Return(nil, tagErr).Once() + mockTagReader.On("FindBySceneID", testCtx, sceneID).Return(getTags(names), nil).Once() + mockTagReader.On("FindBySceneID", testCtx, noTagsID).Return(nil, nil).Once() + mockTagReader.On("FindBySceneID", testCtx, errTagsID).Return(nil, tagErr).Once() for i, s := range getTagNamesScenarios { scene := s.input - json, err := GetTagNames(mockTagReader, &scene) + json, err := GetTagNames(testCtx, mockTagReader, &scene) switch { case !s.err && err != nil: @@ -436,22 +436,22 @@ func TestGetSceneMoviesJSON(t *testing.T) { joinErr := errors.New("error getting scene movies") movieErr := errors.New("error getting movie") - mockSceneReader.On("GetMovies", sceneID).Return(validMovies, nil).Once() - mockSceneReader.On("GetMovies", noMoviesID).Return(nil, nil).Once() - mockSceneReader.On("GetMovies", errMoviesID).Return(nil, joinErr).Once() - mockSceneReader.On("GetMovies", errFindMovieID).Return(invalidMovies, nil).Once() + mockSceneReader.On("GetMovies", testCtx, sceneID).Return(validMovies, nil).Once() + mockSceneReader.On("GetMovies", testCtx, noMoviesID).Return(nil, nil).Once() + mockSceneReader.On("GetMovies", testCtx, errMoviesID).Return(nil, joinErr).Once() + mockSceneReader.On("GetMovies", testCtx, errFindMovieID).Return(invalidMovies, nil).Once() - mockMovieReader.On("Find", validMovie1).Return(&models.Movie{ + mockMovieReader.On("Find", testCtx, validMovie1).Return(&models.Movie{ Name: models.NullString(movie1Name), }, nil).Once() - mockMovieReader.On("Find", validMovie2).Return(&models.Movie{ + mockMovieReader.On("Find", testCtx, validMovie2).Return(&models.Movie{ Name: models.NullString(movie2Name), }, nil).Once() - mockMovieReader.On("Find", invalidMovie).Return(nil, movieErr).Once() + mockMovieReader.On("Find", testCtx, invalidMovie).Return(nil, movieErr).Once() for i, s := range getSceneMoviesJSONScenarios { scene := s.input - json, err := GetSceneMoviesJSON(mockMovieReader, mockSceneReader, &scene) + json, err := GetSceneMoviesJSON(testCtx, mockMovieReader, mockSceneReader, &scene) switch { case !s.err && err != nil: @@ -603,21 +603,21 @@ func TestGetSceneMarkersJSON(t *testing.T) { markersErr := errors.New("error getting scene markers") tagErr := errors.New("error getting tags") - mockMarkerReader.On("FindBySceneID", sceneID).Return(validMarkers, nil).Once() - mockMarkerReader.On("FindBySceneID", noMarkersID).Return(nil, nil).Once() - mockMarkerReader.On("FindBySceneID", errMarkersID).Return(nil, markersErr).Once() - mockMarkerReader.On("FindBySceneID", errFindPrimaryTagID).Return(invalidMarkers1, nil).Once() - mockMarkerReader.On("FindBySceneID", errFindByMarkerID).Return(invalidMarkers2, nil).Once() + mockMarkerReader.On("FindBySceneID", testCtx, sceneID).Return(validMarkers, nil).Once() + mockMarkerReader.On("FindBySceneID", testCtx, noMarkersID).Return(nil, nil).Once() + mockMarkerReader.On("FindBySceneID", testCtx, errMarkersID).Return(nil, markersErr).Once() + mockMarkerReader.On("FindBySceneID", testCtx, errFindPrimaryTagID).Return(invalidMarkers1, nil).Once() + mockMarkerReader.On("FindBySceneID", testCtx, errFindByMarkerID).Return(invalidMarkers2, nil).Once() - mockTagReader.On("Find", validTagID1).Return(&models.Tag{ + mockTagReader.On("Find", testCtx, validTagID1).Return(&models.Tag{ Name: validTagName1, }, nil) - mockTagReader.On("Find", validTagID2).Return(&models.Tag{ + mockTagReader.On("Find", testCtx, validTagID2).Return(&models.Tag{ Name: validTagName2, }, nil) - mockTagReader.On("Find", invalidTagID).Return(nil, tagErr) + mockTagReader.On("Find", testCtx, invalidTagID).Return(nil, tagErr) - mockTagReader.On("FindBySceneMarkerID", validMarkerID1).Return([]*models.Tag{ + mockTagReader.On("FindBySceneMarkerID", testCtx, validMarkerID1).Return([]*models.Tag{ { Name: validTagName1, }, @@ -625,16 +625,16 @@ func TestGetSceneMarkersJSON(t *testing.T) { Name: validTagName2, }, }, nil) - mockTagReader.On("FindBySceneMarkerID", validMarkerID2).Return([]*models.Tag{ + mockTagReader.On("FindBySceneMarkerID", testCtx, validMarkerID2).Return([]*models.Tag{ { Name: validTagName2, }, }, nil) - mockTagReader.On("FindBySceneMarkerID", invalidMarkerID2).Return(nil, tagErr).Once() + mockTagReader.On("FindBySceneMarkerID", testCtx, invalidMarkerID2).Return(nil, tagErr).Once() for i, s := range getSceneMarkersJSONScenarios { scene := s.input - json, err := GetSceneMarkersJSON(mockMarkerReader, mockTagReader, &scene) + json, err := GetSceneMarkersJSON(testCtx, mockMarkerReader, mockTagReader, &scene) switch { case !s.err && err != nil: diff --git a/pkg/scene/import.go b/pkg/scene/import.go index 103be88fd..d7b59cf8b 100644 --- a/pkg/scene/import.go +++ b/pkg/scene/import.go @@ -1,24 +1,37 @@ package scene import ( + "context" "database/sql" "fmt" "strconv" "strings" + "github.com/stashapp/stash/pkg/gallery" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/jsonschema" + "github.com/stashapp/stash/pkg/movie" + "github.com/stashapp/stash/pkg/performer" "github.com/stashapp/stash/pkg/sliceutil/stringslice" + "github.com/stashapp/stash/pkg/studio" + "github.com/stashapp/stash/pkg/tag" "github.com/stashapp/stash/pkg/utils" ) +type FullCreatorUpdater interface { + CreatorUpdater + Updater + UpdateGalleries(ctx context.Context, sceneID int, galleryIDs []int) error + UpdateMovies(ctx context.Context, sceneID int, movies []models.MoviesScenes) error +} + type Importer struct { - ReaderWriter models.SceneReaderWriter - StudioWriter models.StudioReaderWriter - GalleryWriter models.GalleryReaderWriter - PerformerWriter models.PerformerReaderWriter - MovieWriter models.MovieReaderWriter - TagWriter models.TagReaderWriter + ReaderWriter FullCreatorUpdater + StudioWriter studio.NameFinderCreator + GalleryWriter gallery.ChecksumsFinder + PerformerWriter performer.NameFinderCreator + MovieWriter movie.NameFinderCreator + TagWriter tag.NameFinderCreator Input jsonschema.Scene Path string MissingRefBehaviour models.ImportMissingRefEnum @@ -33,26 +46,26 @@ type Importer struct { coverImageData []byte } -func (i *Importer) PreImport() error { +func (i *Importer) PreImport(ctx context.Context) error { i.scene = i.sceneJSONToScene(i.Input) - if err := i.populateStudio(); err != nil { + if err := i.populateStudio(ctx); err != nil { return err } - if err := i.populateGalleries(); err != nil { + if err := i.populateGalleries(ctx); err != nil { return err } - if err := i.populatePerformers(); err != nil { + if err := i.populatePerformers(ctx); err != nil { return err } - if err := i.populateTags(); err != nil { + if err := i.populateTags(ctx); err != nil { return err } - if err := i.populateMovies(); err != nil { + if err := i.populateMovies(ctx); err != nil { return err } @@ -135,9 +148,9 @@ func (i *Importer) sceneJSONToScene(sceneJSON jsonschema.Scene) models.Scene { return newScene } -func (i *Importer) populateStudio() error { +func (i *Importer) populateStudio(ctx context.Context) error { if i.Input.Studio != "" { - studio, err := i.StudioWriter.FindByName(i.Input.Studio, false) + studio, err := i.StudioWriter.FindByName(ctx, i.Input.Studio, false) if err != nil { return fmt.Errorf("error finding studio by name: %v", err) } @@ -152,7 +165,7 @@ func (i *Importer) populateStudio() error { } if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate { - studioID, err := i.createStudio(i.Input.Studio) + studioID, err := i.createStudio(ctx, i.Input.Studio) if err != nil { return err } @@ -169,10 +182,10 @@ func (i *Importer) populateStudio() error { return nil } -func (i *Importer) createStudio(name string) (int, error) { +func (i *Importer) createStudio(ctx context.Context, name string) (int, error) { newStudio := *models.NewStudio(name) - created, err := i.StudioWriter.Create(newStudio) + created, err := i.StudioWriter.Create(ctx, newStudio) if err != nil { return 0, err } @@ -180,10 +193,10 @@ func (i *Importer) createStudio(name string) (int, error) { return created.ID, nil } -func (i *Importer) populateGalleries() error { +func (i *Importer) populateGalleries(ctx context.Context) error { if len(i.Input.Galleries) > 0 { checksums := i.Input.Galleries - galleries, err := i.GalleryWriter.FindByChecksums(checksums) + galleries, err := i.GalleryWriter.FindByChecksums(ctx, checksums) if err != nil { return err } @@ -211,10 +224,10 @@ func (i *Importer) populateGalleries() error { return nil } -func (i *Importer) populatePerformers() error { +func (i *Importer) populatePerformers(ctx context.Context) error { if len(i.Input.Performers) > 0 { names := i.Input.Performers - performers, err := i.PerformerWriter.FindByNames(names, false) + performers, err := i.PerformerWriter.FindByNames(ctx, names, false) if err != nil { return err } @@ -237,7 +250,7 @@ func (i *Importer) populatePerformers() error { } if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate { - createdPerformers, err := i.createPerformers(missingPerformers) + createdPerformers, err := i.createPerformers(ctx, missingPerformers) if err != nil { return fmt.Errorf("error creating scene performers: %v", err) } @@ -254,12 +267,12 @@ func (i *Importer) populatePerformers() error { return nil } -func (i *Importer) createPerformers(names []string) ([]*models.Performer, error) { +func (i *Importer) createPerformers(ctx context.Context, names []string) ([]*models.Performer, error) { var ret []*models.Performer for _, name := range names { newPerformer := *models.NewPerformer(name) - created, err := i.PerformerWriter.Create(newPerformer) + created, err := i.PerformerWriter.Create(ctx, newPerformer) if err != nil { return nil, err } @@ -270,10 +283,10 @@ func (i *Importer) createPerformers(names []string) ([]*models.Performer, error) return ret, nil } -func (i *Importer) populateMovies() error { +func (i *Importer) populateMovies(ctx context.Context) error { if len(i.Input.Movies) > 0 { for _, inputMovie := range i.Input.Movies { - movie, err := i.MovieWriter.FindByName(inputMovie.MovieName, false) + movie, err := i.MovieWriter.FindByName(ctx, inputMovie.MovieName, false) if err != nil { return fmt.Errorf("error finding scene movie: %v", err) } @@ -284,7 +297,7 @@ func (i *Importer) populateMovies() error { } if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate { - movie, err = i.createMovie(inputMovie.MovieName) + movie, err = i.createMovie(ctx, inputMovie.MovieName) if err != nil { return fmt.Errorf("error creating scene movie: %v", err) } @@ -314,10 +327,10 @@ func (i *Importer) populateMovies() error { return nil } -func (i *Importer) createMovie(name string) (*models.Movie, error) { +func (i *Importer) createMovie(ctx context.Context, name string) (*models.Movie, error) { newMovie := *models.NewMovie(name) - created, err := i.MovieWriter.Create(newMovie) + created, err := i.MovieWriter.Create(ctx, newMovie) if err != nil { return nil, err } @@ -325,10 +338,10 @@ func (i *Importer) createMovie(name string) (*models.Movie, error) { return created, nil } -func (i *Importer) populateTags() error { +func (i *Importer) populateTags(ctx context.Context) error { if len(i.Input.Tags) > 0 { - tags, err := importTags(i.TagWriter, i.Input.Tags, i.MissingRefBehaviour) + tags, err := importTags(ctx, i.TagWriter, i.Input.Tags, i.MissingRefBehaviour) if err != nil { return err } @@ -339,9 +352,9 @@ func (i *Importer) populateTags() error { return nil } -func (i *Importer) PostImport(id int) error { +func (i *Importer) PostImport(ctx context.Context, id int) error { if len(i.coverImageData) > 0 { - if err := i.ReaderWriter.UpdateCover(id, i.coverImageData); err != nil { + if err := i.ReaderWriter.UpdateCover(ctx, id, i.coverImageData); err != nil { return fmt.Errorf("error setting scene images: %v", err) } } @@ -352,7 +365,7 @@ func (i *Importer) PostImport(id int) error { galleryIDs = append(galleryIDs, gallery.ID) } - if err := i.ReaderWriter.UpdateGalleries(id, galleryIDs); err != nil { + if err := i.ReaderWriter.UpdateGalleries(ctx, id, galleryIDs); err != nil { return fmt.Errorf("failed to associate galleries: %v", err) } } @@ -363,7 +376,7 @@ func (i *Importer) PostImport(id int) error { performerIDs = append(performerIDs, performer.ID) } - if err := i.ReaderWriter.UpdatePerformers(id, performerIDs); err != nil { + if err := i.ReaderWriter.UpdatePerformers(ctx, id, performerIDs); err != nil { return fmt.Errorf("failed to associate performers: %v", err) } } @@ -372,7 +385,7 @@ func (i *Importer) PostImport(id int) error { for index := range i.movies { i.movies[index].SceneID = id } - if err := i.ReaderWriter.UpdateMovies(id, i.movies); err != nil { + if err := i.ReaderWriter.UpdateMovies(ctx, id, i.movies); err != nil { return fmt.Errorf("failed to associate movies: %v", err) } } @@ -382,13 +395,13 @@ func (i *Importer) PostImport(id int) error { for _, t := range i.tags { tagIDs = append(tagIDs, t.ID) } - if err := i.ReaderWriter.UpdateTags(id, tagIDs); err != nil { + if err := i.ReaderWriter.UpdateTags(ctx, id, tagIDs); err != nil { return fmt.Errorf("failed to associate tags: %v", err) } } if len(i.Input.StashIDs) > 0 { - if err := i.ReaderWriter.UpdateStashIDs(id, i.Input.StashIDs); err != nil { + if err := i.ReaderWriter.UpdateStashIDs(ctx, id, i.Input.StashIDs); err != nil { return fmt.Errorf("error setting stash id: %v", err) } } @@ -400,15 +413,15 @@ func (i *Importer) Name() string { return i.Path } -func (i *Importer) FindExistingID() (*int, error) { +func (i *Importer) FindExistingID(ctx context.Context) (*int, error) { var existing *models.Scene var err error switch i.FileNamingAlgorithm { case models.HashAlgorithmMd5: - existing, err = i.ReaderWriter.FindByChecksum(i.Input.Checksum) + existing, err = i.ReaderWriter.FindByChecksum(ctx, i.Input.Checksum) case models.HashAlgorithmOshash: - existing, err = i.ReaderWriter.FindByOSHash(i.Input.OSHash) + existing, err = i.ReaderWriter.FindByOSHash(ctx, i.Input.OSHash) default: panic("unknown file naming algorithm") } @@ -425,8 +438,8 @@ func (i *Importer) FindExistingID() (*int, error) { return nil, nil } -func (i *Importer) Create() (*int, error) { - created, err := i.ReaderWriter.Create(i.scene) +func (i *Importer) Create(ctx context.Context) (*int, error) { + created, err := i.ReaderWriter.Create(ctx, i.scene) if err != nil { return nil, fmt.Errorf("error creating scene: %v", err) } @@ -436,11 +449,11 @@ func (i *Importer) Create() (*int, error) { return &id, nil } -func (i *Importer) Update(id int) error { +func (i *Importer) Update(ctx context.Context, id int) error { scene := i.scene scene.ID = id i.ID = id - _, err := i.ReaderWriter.UpdateFull(scene) + _, err := i.ReaderWriter.UpdateFull(ctx, scene) if err != nil { return fmt.Errorf("error updating existing scene: %v", err) } @@ -448,8 +461,8 @@ func (i *Importer) Update(id int) error { return nil } -func importTags(tagWriter models.TagReaderWriter, names []string, missingRefBehaviour models.ImportMissingRefEnum) ([]*models.Tag, error) { - tags, err := tagWriter.FindByNames(names, false) +func importTags(ctx context.Context, tagWriter tag.NameFinderCreator, names []string, missingRefBehaviour models.ImportMissingRefEnum) ([]*models.Tag, error) { + tags, err := tagWriter.FindByNames(ctx, names, false) if err != nil { return nil, err } @@ -469,7 +482,7 @@ func importTags(tagWriter models.TagReaderWriter, names []string, missingRefBeha } if missingRefBehaviour == models.ImportMissingRefEnumCreate { - createdTags, err := createTags(tagWriter, missingTags) + createdTags, err := createTags(ctx, tagWriter, missingTags) if err != nil { return nil, fmt.Errorf("error creating tags: %v", err) } @@ -483,12 +496,12 @@ func importTags(tagWriter models.TagReaderWriter, names []string, missingRefBeha return tags, nil } -func createTags(tagWriter models.TagWriter, names []string) ([]*models.Tag, error) { +func createTags(ctx context.Context, tagWriter tag.NameFinderCreator, names []string) ([]*models.Tag, error) { var ret []*models.Tag for _, name := range names { newTag := *models.NewTag(name) - created, err := tagWriter.Create(newTag) + created, err := tagWriter.Create(ctx, newTag) if err != nil { return nil, err } diff --git a/pkg/scene/import_test.go b/pkg/scene/import_test.go index 499f27299..75dab2200 100644 --- a/pkg/scene/import_test.go +++ b/pkg/scene/import_test.go @@ -1,6 +1,7 @@ package scene import ( + "context" "errors" "testing" @@ -55,6 +56,8 @@ const ( errOSHash = "errOSHash" ) +var testCtx = context.Background() + func TestImporterName(t *testing.T) { i := Importer{ Path: path, @@ -72,17 +75,18 @@ func TestImporterPreImport(t *testing.T) { }, } - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) i.Input.Cover = imageBase64 - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) } func TestImporterPreImportWithStudio(t *testing.T) { studioReaderWriter := &mocks.StudioReaderWriter{} + testCtx := context.Background() i := Importer{ StudioWriter: studioReaderWriter, @@ -92,17 +96,17 @@ func TestImporterPreImportWithStudio(t *testing.T) { }, } - studioReaderWriter.On("FindByName", existingStudioName, false).Return(&models.Studio{ + studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{ ID: existingStudioID, }, nil).Once() - studioReaderWriter.On("FindByName", existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once() + studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once() - err := i.PreImport() + err := i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, int64(existingStudioID), i.scene.StudioID.Int64) i.Input.Studio = existingStudioErr - err = i.PreImport() + err = i.PreImport(testCtx) assert.NotNil(t, err) studioReaderWriter.AssertExpectations(t) @@ -120,20 +124,20 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - studioReaderWriter.On("FindByName", missingStudioName, false).Return(nil, nil).Times(3) - studioReaderWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(&models.Studio{ + studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) + studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Studio")).Return(&models.Studio{ ID: existingStudioID, }, nil) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumCreate - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, int64(existingStudioID), i.scene.StudioID.Int64) @@ -152,10 +156,10 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - studioReaderWriter.On("FindByName", missingStudioName, false).Return(nil, nil).Once() - studioReaderWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error")) + studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() + studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error")) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) } @@ -173,21 +177,21 @@ func TestImporterPreImportWithGallery(t *testing.T) { }, } - galleryReaderWriter.On("FindByChecksums", []string{existingGalleryChecksum}).Return([]*models.Gallery{ + galleryReaderWriter.On("FindByChecksums", testCtx, []string{existingGalleryChecksum}).Return([]*models.Gallery{ { ID: existingGalleryID, Checksum: existingGalleryChecksum, }, }, nil).Once() - galleryReaderWriter.On("FindByChecksums", []string{existingGalleryErr}).Return(nil, errors.New("FindByChecksums error")).Once() + galleryReaderWriter.On("FindByChecksums", testCtx, []string{existingGalleryErr}).Return(nil, errors.New("FindByChecksums error")).Once() - err := i.PreImport() + err := i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, existingGalleryID, i.galleries[0].ID) i.Input.Galleries = []string{existingGalleryErr} - err = i.PreImport() + err = i.PreImport(testCtx) assert.NotNil(t, err) galleryReaderWriter.AssertExpectations(t) @@ -207,17 +211,17 @@ func TestImporterPreImportWithMissingGallery(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - galleryReaderWriter.On("FindByChecksums", []string{missingGalleryChecksum}).Return(nil, nil).Times(3) + galleryReaderWriter.On("FindByChecksums", testCtx, []string{missingGalleryChecksum}).Return(nil, nil).Times(3) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumCreate - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) galleryReaderWriter.AssertExpectations(t) @@ -237,20 +241,20 @@ func TestImporterPreImportWithPerformer(t *testing.T) { }, } - performerReaderWriter.On("FindByNames", []string{existingPerformerName}, false).Return([]*models.Performer{ + performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{ { ID: existingPerformerID, Name: models.NullString(existingPerformerName), }, }, nil).Once() - performerReaderWriter.On("FindByNames", []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() + performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() - err := i.PreImport() + err := i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, existingPerformerID, i.performers[0].ID) i.Input.Performers = []string{existingPerformerErr} - err = i.PreImport() + err = i.PreImport(testCtx) assert.NotNil(t, err) performerReaderWriter.AssertExpectations(t) @@ -270,20 +274,20 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - performerReaderWriter.On("FindByNames", []string{missingPerformerName}, false).Return(nil, nil).Times(3) - performerReaderWriter.On("Create", mock.AnythingOfType("models.Performer")).Return(&models.Performer{ + performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3) + performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Performer")).Return(&models.Performer{ ID: existingPerformerID, }, nil) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumCreate - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, existingPerformerID, i.performers[0].ID) @@ -304,15 +308,16 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - performerReaderWriter.On("FindByNames", []string{missingPerformerName}, false).Return(nil, nil).Once() - performerReaderWriter.On("Create", mock.AnythingOfType("models.Performer")).Return(nil, errors.New("Create error")) + performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once() + performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Performer")).Return(nil, errors.New("Create error")) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) } func TestImporterPreImportWithMovie(t *testing.T) { movieReaderWriter := &mocks.MovieReaderWriter{} + testCtx := context.Background() i := Importer{ MovieWriter: movieReaderWriter, @@ -328,18 +333,18 @@ func TestImporterPreImportWithMovie(t *testing.T) { }, } - movieReaderWriter.On("FindByName", existingMovieName, false).Return(&models.Movie{ + movieReaderWriter.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{ ID: existingMovieID, Name: models.NullString(existingMovieName), }, nil).Once() - movieReaderWriter.On("FindByName", existingMovieErr, false).Return(nil, errors.New("FindByName error")).Once() + movieReaderWriter.On("FindByName", testCtx, existingMovieErr, false).Return(nil, errors.New("FindByName error")).Once() - err := i.PreImport() + err := i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, existingMovieID, i.movies[0].MovieID) i.Input.Movies[0].MovieName = existingMovieErr - err = i.PreImport() + err = i.PreImport(testCtx) assert.NotNil(t, err) movieReaderWriter.AssertExpectations(t) @@ -347,6 +352,7 @@ func TestImporterPreImportWithMovie(t *testing.T) { func TestImporterPreImportWithMissingMovie(t *testing.T) { movieReaderWriter := &mocks.MovieReaderWriter{} + testCtx := context.Background() i := Importer{ Path: path, @@ -361,20 +367,20 @@ func TestImporterPreImportWithMissingMovie(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - movieReaderWriter.On("FindByName", missingMovieName, false).Return(nil, nil).Times(3) - movieReaderWriter.On("Create", mock.AnythingOfType("models.Movie")).Return(&models.Movie{ + movieReaderWriter.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Times(3) + movieReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Movie")).Return(&models.Movie{ ID: existingMovieID, }, nil) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumCreate - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, existingMovieID, i.movies[0].MovieID) @@ -397,10 +403,10 @@ func TestImporterPreImportWithMissingMovieCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - movieReaderWriter.On("FindByName", missingMovieName, false).Return(nil, nil).Once() - movieReaderWriter.On("Create", mock.AnythingOfType("models.Movie")).Return(nil, errors.New("Create error")) + movieReaderWriter.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Once() + movieReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Movie")).Return(nil, errors.New("Create error")) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) } @@ -418,20 +424,20 @@ func TestImporterPreImportWithTag(t *testing.T) { }, } - tagReaderWriter.On("FindByNames", []string{existingTagName}, false).Return([]*models.Tag{ + tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{ { ID: existingTagID, Name: existingTagName, }, }, nil).Once() - tagReaderWriter.On("FindByNames", []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once() + tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once() - err := i.PreImport() + err := i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, existingTagID, i.tags[0].ID) i.Input.Tags = []string{existingTagErr} - err = i.PreImport() + err = i.PreImport(testCtx) assert.NotNil(t, err) tagReaderWriter.AssertExpectations(t) @@ -451,20 +457,20 @@ func TestImporterPreImportWithMissingTag(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - tagReaderWriter.On("FindByNames", []string{missingTagName}, false).Return(nil, nil).Times(3) - tagReaderWriter.On("Create", mock.AnythingOfType("models.Tag")).Return(&models.Tag{ + tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3) + tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(&models.Tag{ ID: existingTagID, }, nil) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumCreate - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) assert.Equal(t, existingTagID, i.tags[0].ID) @@ -485,10 +491,10 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - tagReaderWriter.On("FindByNames", []string{missingTagName}, false).Return(nil, nil).Once() - tagReaderWriter.On("Create", mock.AnythingOfType("models.Tag")).Return(nil, errors.New("Create error")) + tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once() + tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(nil, errors.New("Create error")) - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) } @@ -502,13 +508,13 @@ func TestImporterPostImport(t *testing.T) { updateSceneImageErr := errors.New("UpdateCover error") - readerWriter.On("UpdateCover", sceneID, imageBytes).Return(nil).Once() - readerWriter.On("UpdateCover", errImageID, imageBytes).Return(updateSceneImageErr).Once() + readerWriter.On("UpdateCover", testCtx, sceneID, imageBytes).Return(nil).Once() + readerWriter.On("UpdateCover", testCtx, errImageID, imageBytes).Return(updateSceneImageErr).Once() - err := i.PostImport(sceneID) + err := i.PostImport(testCtx, sceneID) assert.Nil(t, err) - err = i.PostImport(errImageID) + err = i.PostImport(testCtx, errImageID) assert.NotNil(t, err) readerWriter.AssertExpectations(t) @@ -528,13 +534,13 @@ func TestImporterPostImportUpdateGalleries(t *testing.T) { updateErr := errors.New("UpdateGalleries error") - sceneReaderWriter.On("UpdateGalleries", sceneID, []int{existingGalleryID}).Return(nil).Once() - sceneReaderWriter.On("UpdateGalleries", errGalleriesID, mock.AnythingOfType("[]int")).Return(updateErr).Once() + sceneReaderWriter.On("UpdateGalleries", testCtx, sceneID, []int{existingGalleryID}).Return(nil).Once() + sceneReaderWriter.On("UpdateGalleries", testCtx, errGalleriesID, mock.AnythingOfType("[]int")).Return(updateErr).Once() - err := i.PostImport(sceneID) + err := i.PostImport(testCtx, sceneID) assert.Nil(t, err) - err = i.PostImport(errGalleriesID) + err = i.PostImport(testCtx, errGalleriesID) assert.NotNil(t, err) sceneReaderWriter.AssertExpectations(t) @@ -554,13 +560,13 @@ func TestImporterPostImportUpdatePerformers(t *testing.T) { updateErr := errors.New("UpdatePerformers error") - sceneReaderWriter.On("UpdatePerformers", sceneID, []int{existingPerformerID}).Return(nil).Once() - sceneReaderWriter.On("UpdatePerformers", errPerformersID, mock.AnythingOfType("[]int")).Return(updateErr).Once() + sceneReaderWriter.On("UpdatePerformers", testCtx, sceneID, []int{existingPerformerID}).Return(nil).Once() + sceneReaderWriter.On("UpdatePerformers", testCtx, errPerformersID, mock.AnythingOfType("[]int")).Return(updateErr).Once() - err := i.PostImport(sceneID) + err := i.PostImport(testCtx, sceneID) assert.Nil(t, err) - err = i.PostImport(errPerformersID) + err = i.PostImport(testCtx, errPerformersID) assert.NotNil(t, err) sceneReaderWriter.AssertExpectations(t) @@ -580,18 +586,18 @@ func TestImporterPostImportUpdateMovies(t *testing.T) { updateErr := errors.New("UpdateMovies error") - sceneReaderWriter.On("UpdateMovies", sceneID, []models.MoviesScenes{ + sceneReaderWriter.On("UpdateMovies", testCtx, sceneID, []models.MoviesScenes{ { MovieID: existingMovieID, SceneID: sceneID, }, }).Return(nil).Once() - sceneReaderWriter.On("UpdateMovies", errMoviesID, mock.AnythingOfType("[]models.MoviesScenes")).Return(updateErr).Once() + sceneReaderWriter.On("UpdateMovies", testCtx, errMoviesID, mock.AnythingOfType("[]models.MoviesScenes")).Return(updateErr).Once() - err := i.PostImport(sceneID) + err := i.PostImport(testCtx, sceneID) assert.Nil(t, err) - err = i.PostImport(errMoviesID) + err = i.PostImport(testCtx, errMoviesID) assert.NotNil(t, err) sceneReaderWriter.AssertExpectations(t) @@ -611,13 +617,13 @@ func TestImporterPostImportUpdateTags(t *testing.T) { updateErr := errors.New("UpdateTags error") - sceneReaderWriter.On("UpdateTags", sceneID, []int{existingTagID}).Return(nil).Once() - sceneReaderWriter.On("UpdateTags", errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once() + sceneReaderWriter.On("UpdateTags", testCtx, sceneID, []int{existingTagID}).Return(nil).Once() + sceneReaderWriter.On("UpdateTags", testCtx, errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once() - err := i.PostImport(sceneID) + err := i.PostImport(testCtx, sceneID) assert.Nil(t, err) - err = i.PostImport(errTagsID) + err = i.PostImport(testCtx, errTagsID) assert.NotNil(t, err) sceneReaderWriter.AssertExpectations(t) @@ -637,44 +643,44 @@ func TestImporterFindExistingID(t *testing.T) { } expectedErr := errors.New("FindBy* error") - readerWriter.On("FindByChecksum", missingChecksum).Return(nil, nil).Once() - readerWriter.On("FindByChecksum", checksum).Return(&models.Scene{ + readerWriter.On("FindByChecksum", testCtx, missingChecksum).Return(nil, nil).Once() + readerWriter.On("FindByChecksum", testCtx, checksum).Return(&models.Scene{ ID: existingSceneID, }, nil).Once() - readerWriter.On("FindByChecksum", errChecksum).Return(nil, expectedErr).Once() + readerWriter.On("FindByChecksum", testCtx, errChecksum).Return(nil, expectedErr).Once() - readerWriter.On("FindByOSHash", missingOSHash).Return(nil, nil).Once() - readerWriter.On("FindByOSHash", oshash).Return(&models.Scene{ + readerWriter.On("FindByOSHash", testCtx, missingOSHash).Return(nil, nil).Once() + readerWriter.On("FindByOSHash", testCtx, oshash).Return(&models.Scene{ ID: existingSceneID, }, nil).Once() - readerWriter.On("FindByOSHash", errOSHash).Return(nil, expectedErr).Once() + readerWriter.On("FindByOSHash", testCtx, errOSHash).Return(nil, expectedErr).Once() - id, err := i.FindExistingID() + id, err := i.FindExistingID(testCtx) assert.Nil(t, id) assert.Nil(t, err) i.Input.Checksum = checksum - id, err = i.FindExistingID() + id, err = i.FindExistingID(testCtx) assert.Equal(t, existingSceneID, *id) assert.Nil(t, err) i.Input.Checksum = errChecksum - id, err = i.FindExistingID() + id, err = i.FindExistingID(testCtx) assert.Nil(t, id) assert.NotNil(t, err) i.FileNamingAlgorithm = models.HashAlgorithmOshash - id, err = i.FindExistingID() + id, err = i.FindExistingID(testCtx) assert.Nil(t, id) assert.Nil(t, err) i.Input.OSHash = oshash - id, err = i.FindExistingID() + id, err = i.FindExistingID(testCtx) assert.Equal(t, existingSceneID, *id) assert.Nil(t, err) i.Input.OSHash = errOSHash - id, err = i.FindExistingID() + id, err = i.FindExistingID(testCtx) assert.Nil(t, id) assert.NotNil(t, err) @@ -698,18 +704,18 @@ func TestCreate(t *testing.T) { } errCreate := errors.New("Create error") - readerWriter.On("Create", scene).Return(&models.Scene{ + readerWriter.On("Create", testCtx, scene).Return(&models.Scene{ ID: sceneID, }, nil).Once() - readerWriter.On("Create", sceneErr).Return(nil, errCreate).Once() + readerWriter.On("Create", testCtx, sceneErr).Return(nil, errCreate).Once() - id, err := i.Create() + id, err := i.Create(testCtx) assert.Equal(t, sceneID, *id) assert.Nil(t, err) assert.Equal(t, sceneID, i.ID) i.scene = sceneErr - id, err = i.Create() + id, err = i.Create(testCtx) assert.Nil(t, id) assert.NotNil(t, err) @@ -736,9 +742,9 @@ func TestUpdate(t *testing.T) { // id needs to be set for the mock input scene.ID = sceneID - readerWriter.On("UpdateFull", scene).Return(nil, nil).Once() + readerWriter.On("UpdateFull", testCtx, scene).Return(nil, nil).Once() - err := i.Update(sceneID) + err := i.Update(testCtx, sceneID) assert.Nil(t, err) assert.Equal(t, sceneID, i.ID) @@ -746,9 +752,9 @@ func TestUpdate(t *testing.T) { // need to set id separately sceneErr.ID = errImageID - readerWriter.On("UpdateFull", sceneErr).Return(nil, errUpdate).Once() + readerWriter.On("UpdateFull", testCtx, sceneErr).Return(nil, errUpdate).Once() - err = i.Update(errImageID) + err = i.Update(testCtx, errImageID) assert.NotNil(t, err) readerWriter.AssertExpectations(t) diff --git a/pkg/scene/marker_import.go b/pkg/scene/marker_import.go index 530d025ea..32f6deb65 100644 --- a/pkg/scene/marker_import.go +++ b/pkg/scene/marker_import.go @@ -1,18 +1,27 @@ package scene import ( + "context" "database/sql" "fmt" "strconv" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/jsonschema" + "github.com/stashapp/stash/pkg/tag" ) +type MarkerCreatorUpdater interface { + Create(ctx context.Context, newSceneMarker models.SceneMarker) (*models.SceneMarker, error) + Update(ctx context.Context, updatedSceneMarker models.SceneMarker) (*models.SceneMarker, error) + FindBySceneID(ctx context.Context, sceneID int) ([]*models.SceneMarker, error) + UpdateTags(ctx context.Context, markerID int, tagIDs []int) error +} + type MarkerImporter struct { SceneID int - ReaderWriter models.SceneMarkerReaderWriter - TagWriter models.TagReaderWriter + ReaderWriter MarkerCreatorUpdater + TagWriter tag.NameFinderCreator Input jsonschema.SceneMarker MissingRefBehaviour models.ImportMissingRefEnum @@ -20,7 +29,7 @@ type MarkerImporter struct { marker models.SceneMarker } -func (i *MarkerImporter) PreImport() error { +func (i *MarkerImporter) PreImport(ctx context.Context) error { seconds, _ := strconv.ParseFloat(i.Input.Seconds, 64) i.marker = models.SceneMarker{ Title: i.Input.Title, @@ -30,21 +39,21 @@ func (i *MarkerImporter) PreImport() error { UpdatedAt: models.SQLiteTimestamp{Timestamp: i.Input.UpdatedAt.GetTime()}, } - if err := i.populateTags(); err != nil { + if err := i.populateTags(ctx); err != nil { return err } return nil } -func (i *MarkerImporter) populateTags() error { +func (i *MarkerImporter) populateTags(ctx context.Context) error { // primary tag cannot be ignored mrb := i.MissingRefBehaviour if mrb == models.ImportMissingRefEnumIgnore { mrb = models.ImportMissingRefEnumFail } - primaryTag, err := importTags(i.TagWriter, []string{i.Input.PrimaryTag}, mrb) + primaryTag, err := importTags(ctx, i.TagWriter, []string{i.Input.PrimaryTag}, mrb) if err != nil { return err } @@ -52,7 +61,7 @@ func (i *MarkerImporter) populateTags() error { i.marker.PrimaryTagID = primaryTag[0].ID if len(i.Input.Tags) > 0 { - tags, err := importTags(i.TagWriter, i.Input.Tags, i.MissingRefBehaviour) + tags, err := importTags(ctx, i.TagWriter, i.Input.Tags, i.MissingRefBehaviour) if err != nil { return err } @@ -63,13 +72,13 @@ func (i *MarkerImporter) populateTags() error { return nil } -func (i *MarkerImporter) PostImport(id int) error { +func (i *MarkerImporter) PostImport(ctx context.Context, id int) error { if len(i.tags) > 0 { var tagIDs []int for _, t := range i.tags { tagIDs = append(tagIDs, t.ID) } - if err := i.ReaderWriter.UpdateTags(id, tagIDs); err != nil { + if err := i.ReaderWriter.UpdateTags(ctx, id, tagIDs); err != nil { return fmt.Errorf("failed to associate tags: %v", err) } } @@ -81,8 +90,8 @@ func (i *MarkerImporter) Name() string { return fmt.Sprintf("%s (%s)", i.Input.Title, i.Input.Seconds) } -func (i *MarkerImporter) FindExistingID() (*int, error) { - existingMarkers, err := i.ReaderWriter.FindBySceneID(i.SceneID) +func (i *MarkerImporter) FindExistingID(ctx context.Context) (*int, error) { + existingMarkers, err := i.ReaderWriter.FindBySceneID(ctx, i.SceneID) if err != nil { return nil, err @@ -98,8 +107,8 @@ func (i *MarkerImporter) FindExistingID() (*int, error) { return nil, nil } -func (i *MarkerImporter) Create() (*int, error) { - created, err := i.ReaderWriter.Create(i.marker) +func (i *MarkerImporter) Create(ctx context.Context) (*int, error) { + created, err := i.ReaderWriter.Create(ctx, i.marker) if err != nil { return nil, fmt.Errorf("error creating marker: %v", err) } @@ -108,10 +117,10 @@ func (i *MarkerImporter) Create() (*int, error) { return &id, nil } -func (i *MarkerImporter) Update(id int) error { +func (i *MarkerImporter) Update(ctx context.Context, id int) error { marker := i.marker marker.ID = id - _, err := i.ReaderWriter.Update(marker) + _, err := i.ReaderWriter.Update(ctx, marker) if err != nil { return fmt.Errorf("error updating existing marker: %v", err) } diff --git a/pkg/scene/marker_import_test.go b/pkg/scene/marker_import_test.go index 0aa72a08b..f34d6b266 100644 --- a/pkg/scene/marker_import_test.go +++ b/pkg/scene/marker_import_test.go @@ -1,6 +1,7 @@ package scene import ( + "context" "errors" "testing" @@ -30,6 +31,7 @@ func TestMarkerImporterName(t *testing.T) { func TestMarkerImporterPreImportWithTag(t *testing.T) { tagReaderWriter := &mocks.TagReaderWriter{} + ctx := context.Background() i := MarkerImporter{ TagWriter: tagReaderWriter, @@ -39,32 +41,32 @@ func TestMarkerImporterPreImportWithTag(t *testing.T) { }, } - tagReaderWriter.On("FindByNames", []string{existingTagName}, false).Return([]*models.Tag{ + tagReaderWriter.On("FindByNames", ctx, []string{existingTagName}, false).Return([]*models.Tag{ { ID: existingTagID, Name: existingTagName, }, }, nil).Times(4) - tagReaderWriter.On("FindByNames", []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Times(2) + tagReaderWriter.On("FindByNames", ctx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Times(2) - err := i.PreImport() + err := i.PreImport(ctx) assert.Nil(t, err) assert.Equal(t, existingTagID, i.marker.PrimaryTagID) i.Input.PrimaryTag = existingTagErr - err = i.PreImport() + err = i.PreImport(ctx) assert.NotNil(t, err) i.Input.PrimaryTag = existingTagName i.Input.Tags = []string{ existingTagName, } - err = i.PreImport() + err = i.PreImport(ctx) assert.Nil(t, err) assert.Equal(t, existingTagID, i.tags[0].ID) i.Input.Tags[0] = existingTagErr - err = i.PreImport() + err = i.PreImport(ctx) assert.NotNil(t, err) tagReaderWriter.AssertExpectations(t) @@ -72,6 +74,7 @@ func TestMarkerImporterPreImportWithTag(t *testing.T) { func TestMarkerImporterPostImportUpdateTags(t *testing.T) { sceneMarkerReaderWriter := &mocks.SceneMarkerReaderWriter{} + ctx := context.Background() i := MarkerImporter{ ReaderWriter: sceneMarkerReaderWriter, @@ -84,13 +87,13 @@ func TestMarkerImporterPostImportUpdateTags(t *testing.T) { updateErr := errors.New("UpdateTags error") - sceneMarkerReaderWriter.On("UpdateTags", sceneID, []int{existingTagID}).Return(nil).Once() - sceneMarkerReaderWriter.On("UpdateTags", errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once() + sceneMarkerReaderWriter.On("UpdateTags", ctx, sceneID, []int{existingTagID}).Return(nil).Once() + sceneMarkerReaderWriter.On("UpdateTags", ctx, errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once() - err := i.PostImport(sceneID) + err := i.PostImport(ctx, sceneID) assert.Nil(t, err) - err = i.PostImport(errTagsID) + err = i.PostImport(ctx, errTagsID) assert.NotNil(t, err) sceneMarkerReaderWriter.AssertExpectations(t) @@ -98,6 +101,7 @@ func TestMarkerImporterPostImportUpdateTags(t *testing.T) { func TestMarkerImporterFindExistingID(t *testing.T) { readerWriter := &mocks.SceneMarkerReaderWriter{} + ctx := context.Background() i := MarkerImporter{ ReaderWriter: readerWriter, @@ -108,25 +112,25 @@ func TestMarkerImporterFindExistingID(t *testing.T) { } expectedErr := errors.New("FindBy* error") - readerWriter.On("FindBySceneID", sceneID).Return([]*models.SceneMarker{ + readerWriter.On("FindBySceneID", ctx, sceneID).Return([]*models.SceneMarker{ { ID: existingSceneID, Seconds: secondsFloat, }, }, nil).Times(2) - readerWriter.On("FindBySceneID", errSceneID).Return(nil, expectedErr).Once() + readerWriter.On("FindBySceneID", ctx, errSceneID).Return(nil, expectedErr).Once() - id, err := i.FindExistingID() + id, err := i.FindExistingID(ctx) assert.Equal(t, existingSceneID, *id) assert.Nil(t, err) i.marker.Seconds++ - id, err = i.FindExistingID() + id, err = i.FindExistingID(ctx) assert.Nil(t, id) assert.Nil(t, err) i.SceneID = errSceneID - id, err = i.FindExistingID() + id, err = i.FindExistingID(ctx) assert.Nil(t, id) assert.NotNil(t, err) @@ -135,6 +139,7 @@ func TestMarkerImporterFindExistingID(t *testing.T) { func TestMarkerImporterCreate(t *testing.T) { readerWriter := &mocks.SceneMarkerReaderWriter{} + ctx := context.Background() scene := models.SceneMarker{ Title: title, @@ -150,17 +155,17 @@ func TestMarkerImporterCreate(t *testing.T) { } errCreate := errors.New("Create error") - readerWriter.On("Create", scene).Return(&models.SceneMarker{ + readerWriter.On("Create", ctx, scene).Return(&models.SceneMarker{ ID: sceneID, }, nil).Once() - readerWriter.On("Create", sceneErr).Return(nil, errCreate).Once() + readerWriter.On("Create", ctx, sceneErr).Return(nil, errCreate).Once() - id, err := i.Create() + id, err := i.Create(ctx) assert.Equal(t, sceneID, *id) assert.Nil(t, err) i.marker = sceneErr - id, err = i.Create() + id, err = i.Create(ctx) assert.Nil(t, id) assert.NotNil(t, err) @@ -169,6 +174,7 @@ func TestMarkerImporterCreate(t *testing.T) { func TestMarkerImporterUpdate(t *testing.T) { readerWriter := &mocks.SceneMarkerReaderWriter{} + ctx := context.Background() scene := models.SceneMarker{ Title: title, @@ -187,18 +193,18 @@ func TestMarkerImporterUpdate(t *testing.T) { // id needs to be set for the mock input scene.ID = sceneID - readerWriter.On("Update", scene).Return(nil, nil).Once() + readerWriter.On("Update", ctx, scene).Return(nil, nil).Once() - err := i.Update(sceneID) + err := i.Update(ctx, sceneID) assert.Nil(t, err) i.marker = sceneErr // need to set id separately sceneErr.ID = errImageID - readerWriter.On("Update", sceneErr).Return(nil, errUpdate).Once() + readerWriter.On("Update", ctx, sceneErr).Return(nil, errUpdate).Once() - err = i.Update(errImageID) + err = i.Update(ctx, errImageID) assert.NotNil(t, err) readerWriter.AssertExpectations(t) diff --git a/pkg/scene/query.go b/pkg/scene/query.go index f560430c3..928270f38 100644 --- a/pkg/scene/query.go +++ b/pkg/scene/query.go @@ -11,7 +11,11 @@ import ( ) type Queryer interface { - Query(options models.SceneQueryOptions) (*models.SceneQueryResult, error) + Query(ctx context.Context, options models.SceneQueryOptions) (*models.SceneQueryResult, error) +} + +type IDFinder interface { + Find(ctx context.Context, id int) (*models.Scene, error) } // QueryOptions returns a SceneQueryOptions populated with the provided filters. @@ -26,15 +30,15 @@ func QueryOptions(sceneFilter *models.SceneFilterType, findFilter *models.FindFi } // QueryWithCount queries for scenes, returning the scene objects and the total count. -func QueryWithCount(qb Queryer, sceneFilter *models.SceneFilterType, findFilter *models.FindFilterType) ([]*models.Scene, int, error) { +func QueryWithCount(ctx context.Context, qb Queryer, sceneFilter *models.SceneFilterType, findFilter *models.FindFilterType) ([]*models.Scene, int, error) { // this was moved from the queryBuilder code // left here so that calling functions can reference this instead - result, err := qb.Query(QueryOptions(sceneFilter, findFilter, true)) + result, err := qb.Query(ctx, QueryOptions(sceneFilter, findFilter, true)) if err != nil { return nil, 0, err } - scenes, err := result.Resolve() + scenes, err := result.Resolve(ctx) if err != nil { return nil, 0, err } @@ -43,13 +47,13 @@ func QueryWithCount(qb Queryer, sceneFilter *models.SceneFilterType, findFilter } // Query queries for scenes using the provided filters. -func Query(qb Queryer, sceneFilter *models.SceneFilterType, findFilter *models.FindFilterType) ([]*models.Scene, error) { - result, err := qb.Query(QueryOptions(sceneFilter, findFilter, false)) +func Query(ctx context.Context, qb Queryer, sceneFilter *models.SceneFilterType, findFilter *models.FindFilterType) ([]*models.Scene, error) { + result, err := qb.Query(ctx, QueryOptions(sceneFilter, findFilter, false)) if err != nil { return nil, err } - scenes, err := result.Resolve() + scenes, err := result.Resolve(ctx) if err != nil { return nil, err } @@ -57,7 +61,7 @@ func Query(qb Queryer, sceneFilter *models.SceneFilterType, findFilter *models.F return scenes, nil } -func BatchProcess(ctx context.Context, reader models.SceneReader, sceneFilter *models.SceneFilterType, findFilter *models.FindFilterType, fn func(scene *models.Scene) error) error { +func BatchProcess(ctx context.Context, reader Queryer, sceneFilter *models.SceneFilterType, findFilter *models.FindFilterType, fn func(scene *models.Scene) error) error { const batchSize = 1000 if findFilter == nil { @@ -74,7 +78,7 @@ func BatchProcess(ctx context.Context, reader models.SceneReader, sceneFilter *m return nil } - scenes, err := Query(reader, sceneFilter, findFilter) + scenes, err := Query(ctx, reader, sceneFilter, findFilter) if err != nil { return fmt.Errorf("error querying for scenes: %w", err) } diff --git a/pkg/scene/scan.go b/pkg/scene/scan.go index 1f33fa9ff..e5d1ec739 100644 --- a/pkg/scene/scan.go +++ b/pkg/scene/scan.go @@ -17,11 +17,23 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/paths" "github.com/stashapp/stash/pkg/plugin" + "github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/utils" ) const mutexType = "scene" +type CreatorUpdater interface { + FindByChecksum(ctx context.Context, checksum string) (*models.Scene, error) + FindByOSHash(ctx context.Context, oshash string) (*models.Scene, error) + Create(ctx context.Context, newScene models.Scene) (*models.Scene, error) + UpdateFull(ctx context.Context, updatedScene models.Scene) (*models.Scene, error) + Update(ctx context.Context, updatedScene models.ScenePartial) (*models.Scene, error) + + GetCaptions(ctx context.Context, sceneID int) ([]*models.SceneCaption, error) + UpdateCaptions(ctx context.Context, id int, captions []*models.SceneCaption) error +} + type videoFileCreator interface { NewVideoFile(path string) (*ffmpeg.VideoFile, error) } @@ -34,7 +46,8 @@ type Scanner struct { FileNamingAlgorithm models.HashAlgorithm CaseSensitiveFs bool - TxnManager models.TransactionManager + TxnManager txn.Manager + CreatorUpdater CreatorUpdater Paths *paths.Paths Screenshotter screenshotter VideoFileCreator videoFileCreator @@ -105,16 +118,17 @@ func (scanner *Scanner) ScanExisting(ctx context.Context, existing file.FileBase changed = true } - if err := scanner.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error { - var err error - sqb := r.Scene() + qb := scanner.CreatorUpdater - captions, er := sqb.GetCaptions(s.ID) + if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error { + var err error + + captions, er := qb.GetCaptions(ctx, s.ID) if er == nil { if len(captions) > 0 { clean, altered := CleanCaptions(s.Path, captions) if altered { - er = sqb.UpdateCaptions(s.ID, clean) + er = qb.UpdateCaptions(ctx, s.ID, clean) if er == nil { logger.Debugf("Captions for %s cleaned: %s -> %s", path, captions, clean) } @@ -136,20 +150,20 @@ func (scanner *Scanner) ScanExisting(ctx context.Context, existing file.FileBase scanner.MutexManager.Claim(mutexType, scanned.New.Checksum, done) } - if err := scanner.TxnManager.WithTxn(ctx, func(r models.Repository) error { + if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error { defer close(done) - qb := r.Scene() + qb := scanner.CreatorUpdater // ensure no clashes of hashes if scanned.New.Checksum != "" && scanned.Old.Checksum != scanned.New.Checksum { - dupe, _ := qb.FindByChecksum(s.Checksum.String) + dupe, _ := qb.FindByChecksum(ctx, s.Checksum.String) if dupe != nil { return fmt.Errorf("MD5 for file %s is the same as that of %s", path, dupe.Path) } } if scanned.New.OSHash != "" && scanned.Old.OSHash != scanned.New.OSHash { - dupe, _ := qb.FindByOSHash(scanned.New.OSHash) + dupe, _ := qb.FindByOSHash(ctx, scanned.New.OSHash) if dupe != nil { return fmt.Errorf("OSHash for file %s is the same as that of %s", path, dupe.Path) } @@ -158,7 +172,7 @@ func (scanner *Scanner) ScanExisting(ctx context.Context, existing file.FileBase s.Interactive = interactive s.UpdatedAt = models.SQLiteTimestamp{Timestamp: time.Now()} - _, err := qb.UpdateFull(*s) + _, err := qb.UpdateFull(ctx, *s) return err }); err != nil { return err @@ -204,14 +218,14 @@ func (scanner *Scanner) ScanNew(ctx context.Context, file file.SourceFile) (retS // check for scene by checksum and oshash - MD5 should be // redundant, but check both var s *models.Scene - if err := scanner.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - qb := r.Scene() + if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error { + qb := scanner.CreatorUpdater if checksum != "" { - s, _ = qb.FindByChecksum(checksum) + s, _ = qb.FindByChecksum(ctx, checksum) } if s == nil { - s, _ = qb.FindByOSHash(oshash) + s, _ = qb.FindByOSHash(ctx, oshash) } return nil @@ -246,8 +260,8 @@ func (scanner *Scanner) ScanNew(ctx context.Context, file file.SourceFile) (retS Path: &path, Interactive: &interactive, } - if err := scanner.TxnManager.WithTxn(ctx, func(r models.Repository) error { - _, err := r.Scene().Update(scenePartial) + if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error { + _, err := scanner.CreatorUpdater.Update(ctx, scenePartial) return err }); err != nil { return nil, err @@ -297,9 +311,9 @@ func (scanner *Scanner) ScanNew(ctx context.Context, file file.SourceFile) (retS _ = newScene.Date.Scan(videoFile.CreationTime) } - if err := scanner.TxnManager.WithTxn(ctx, func(r models.Repository) error { + if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error { var err error - retScene, err = r.Scene().Create(newScene) + retScene, err = scanner.CreatorUpdater.Create(ctx, newScene) return err }); err != nil { return nil, err diff --git a/pkg/scene/update.go b/pkg/scene/update.go index e1155c368..8486ac5e1 100644 --- a/pkg/scene/update.go +++ b/pkg/scene/update.go @@ -1,6 +1,7 @@ package scene import ( + "context" "database/sql" "errors" "fmt" @@ -11,6 +12,33 @@ import ( "github.com/stashapp/stash/pkg/utils" ) +type Updater interface { + PartialUpdater + UpdatePerformers(ctx context.Context, sceneID int, performerIDs []int) error + UpdateTags(ctx context.Context, sceneID int, tagIDs []int) error + UpdateStashIDs(ctx context.Context, sceneID int, stashIDs []models.StashID) error + UpdateCover(ctx context.Context, sceneID int, cover []byte) error +} + +type PartialUpdater interface { + Update(ctx context.Context, updatedScene models.ScenePartial) (*models.Scene, error) +} + +type PerformerUpdater interface { + GetPerformerIDs(ctx context.Context, sceneID int) ([]int, error) + UpdatePerformers(ctx context.Context, sceneID int, performerIDs []int) error +} + +type TagUpdater interface { + GetTagIDs(ctx context.Context, sceneID int) ([]int, error) + UpdateTags(ctx context.Context, sceneID int, tagIDs []int) error +} + +type GalleryUpdater interface { + GetGalleryIDs(ctx context.Context, sceneID int) ([]int, error) + UpdateGalleries(ctx context.Context, sceneID int, galleryIDs []int) error +} + var ErrEmptyUpdater = errors.New("no fields have been set") // UpdateSet is used to update a scene and its relationships. @@ -47,7 +75,7 @@ func (u *UpdateSet) IsEmpty() bool { // Update updates a scene by updating the fields in the Partial field, then // updates non-nil relationships. Returns an error if there is no work to // be done. -func (u *UpdateSet) Update(qb models.SceneWriter, screenshotSetter ScreenshotSetter) (*models.Scene, error) { +func (u *UpdateSet) Update(ctx context.Context, qb Updater, screenshotSetter ScreenshotSetter) (*models.Scene, error) { if u.IsEmpty() { return nil, ErrEmptyUpdater } @@ -58,31 +86,31 @@ func (u *UpdateSet) Update(qb models.SceneWriter, screenshotSetter ScreenshotSet Timestamp: time.Now(), } - ret, err := qb.Update(partial) + ret, err := qb.Update(ctx, partial) if err != nil { return nil, fmt.Errorf("error updating scene: %w", err) } if u.PerformerIDs != nil { - if err := qb.UpdatePerformers(u.ID, u.PerformerIDs); err != nil { + if err := qb.UpdatePerformers(ctx, u.ID, u.PerformerIDs); err != nil { return nil, fmt.Errorf("error updating scene performers: %w", err) } } if u.TagIDs != nil { - if err := qb.UpdateTags(u.ID, u.TagIDs); err != nil { + if err := qb.UpdateTags(ctx, u.ID, u.TagIDs); err != nil { return nil, fmt.Errorf("error updating scene tags: %w", err) } } if u.StashIDs != nil { - if err := qb.UpdateStashIDs(u.ID, u.StashIDs); err != nil { + if err := qb.UpdateStashIDs(ctx, u.ID, u.StashIDs); err != nil { return nil, fmt.Errorf("error updating scene stash_ids: %w", err) } } if u.CoverImage != nil { - if err := qb.UpdateCover(u.ID, u.CoverImage); err != nil { + if err := qb.UpdateCover(ctx, u.ID, u.CoverImage); err != nil { return nil, fmt.Errorf("error updating scene cover: %w", err) } @@ -124,8 +152,8 @@ func (u UpdateSet) UpdateInput() models.SceneUpdateInput { return ret } -func UpdateFormat(qb models.SceneWriter, id int, format string) (*models.Scene, error) { - return qb.Update(models.ScenePartial{ +func UpdateFormat(ctx context.Context, qb PartialUpdater, id int, format string) (*models.Scene, error) { + return qb.Update(ctx, models.ScenePartial{ ID: id, Format: &sql.NullString{ String: format, @@ -134,8 +162,8 @@ func UpdateFormat(qb models.SceneWriter, id int, format string) (*models.Scene, }) } -func UpdateOSHash(qb models.SceneWriter, id int, oshash string) (*models.Scene, error) { - return qb.Update(models.ScenePartial{ +func UpdateOSHash(ctx context.Context, qb PartialUpdater, id int, oshash string) (*models.Scene, error) { + return qb.Update(ctx, models.ScenePartial{ ID: id, OSHash: &sql.NullString{ String: oshash, @@ -144,8 +172,8 @@ func UpdateOSHash(qb models.SceneWriter, id int, oshash string) (*models.Scene, }) } -func UpdateChecksum(qb models.SceneWriter, id int, checksum string) (*models.Scene, error) { - return qb.Update(models.ScenePartial{ +func UpdateChecksum(ctx context.Context, qb PartialUpdater, id int, checksum string) (*models.Scene, error) { + return qb.Update(ctx, models.ScenePartial{ ID: id, Checksum: &sql.NullString{ String: checksum, @@ -154,15 +182,15 @@ func UpdateChecksum(qb models.SceneWriter, id int, checksum string) (*models.Sce }) } -func UpdateFileModTime(qb models.SceneWriter, id int, modTime models.NullSQLiteTimestamp) (*models.Scene, error) { - return qb.Update(models.ScenePartial{ +func UpdateFileModTime(ctx context.Context, qb PartialUpdater, id int, modTime models.NullSQLiteTimestamp) (*models.Scene, error) { + return qb.Update(ctx, models.ScenePartial{ ID: id, FileModTime: &modTime, }) } -func AddPerformer(qb models.SceneReaderWriter, id int, performerID int) (bool, error) { - performerIDs, err := qb.GetPerformerIDs(id) +func AddPerformer(ctx context.Context, qb PerformerUpdater, id int, performerID int) (bool, error) { + performerIDs, err := qb.GetPerformerIDs(ctx, id) if err != nil { return false, err } @@ -171,7 +199,7 @@ func AddPerformer(qb models.SceneReaderWriter, id int, performerID int) (bool, e performerIDs = intslice.IntAppendUnique(performerIDs, performerID) if len(performerIDs) != oldLen { - if err := qb.UpdatePerformers(id, performerIDs); err != nil { + if err := qb.UpdatePerformers(ctx, id, performerIDs); err != nil { return false, err } @@ -181,8 +209,8 @@ func AddPerformer(qb models.SceneReaderWriter, id int, performerID int) (bool, e return false, nil } -func AddTag(qb models.SceneReaderWriter, id int, tagID int) (bool, error) { - tagIDs, err := qb.GetTagIDs(id) +func AddTag(ctx context.Context, qb TagUpdater, id int, tagID int) (bool, error) { + tagIDs, err := qb.GetTagIDs(ctx, id) if err != nil { return false, err } @@ -191,7 +219,7 @@ func AddTag(qb models.SceneReaderWriter, id int, tagID int) (bool, error) { tagIDs = intslice.IntAppendUnique(tagIDs, tagID) if len(tagIDs) != oldLen { - if err := qb.UpdateTags(id, tagIDs); err != nil { + if err := qb.UpdateTags(ctx, id, tagIDs); err != nil { return false, err } @@ -201,8 +229,8 @@ func AddTag(qb models.SceneReaderWriter, id int, tagID int) (bool, error) { return false, nil } -func AddGallery(qb models.SceneReaderWriter, id int, galleryID int) (bool, error) { - galleryIDs, err := qb.GetGalleryIDs(id) +func AddGallery(ctx context.Context, qb GalleryUpdater, id int, galleryID int) (bool, error) { + galleryIDs, err := qb.GetGalleryIDs(ctx, id) if err != nil { return false, err } @@ -211,7 +239,7 @@ func AddGallery(qb models.SceneReaderWriter, id int, galleryID int) (bool, error galleryIDs = intslice.IntAppendUnique(galleryIDs, galleryID) if len(galleryIDs) != oldLen { - if err := qb.UpdateGalleries(id, galleryIDs); err != nil { + if err := qb.UpdateGalleries(ctx, id, galleryIDs); err != nil { return false, err } diff --git a/pkg/scene/update_test.go b/pkg/scene/update_test.go index 4619fd137..c605f98a2 100644 --- a/pkg/scene/update_test.go +++ b/pkg/scene/update_test.go @@ -1,6 +1,7 @@ package scene import ( + "context" "errors" "strconv" "testing" @@ -104,6 +105,8 @@ func TestUpdater_Update(t *testing.T) { tagID ) + ctx := context.Background() + performerIDs := []int{performerID} tagIDs := []int{tagID} stashID := "stashID" @@ -123,22 +126,22 @@ func TestUpdater_Update(t *testing.T) { updateErr := errors.New("error updating") qb := mocks.SceneReaderWriter{} - qb.On("Update", mock.MatchedBy(func(s models.ScenePartial) bool { + qb.On("Update", ctx, mock.MatchedBy(func(s models.ScenePartial) bool { return s.ID != badUpdateID })).Return(validScene, nil) - qb.On("Update", mock.MatchedBy(func(s models.ScenePartial) bool { + qb.On("Update", ctx, mock.MatchedBy(func(s models.ScenePartial) bool { return s.ID == badUpdateID })).Return(nil, updateErr) - qb.On("UpdatePerformers", sceneID, performerIDs).Return(nil).Once() - qb.On("UpdateTags", sceneID, tagIDs).Return(nil).Once() - qb.On("UpdateStashIDs", sceneID, stashIDs).Return(nil).Once() - qb.On("UpdateCover", sceneID, cover).Return(nil).Once() + qb.On("UpdatePerformers", ctx, sceneID, performerIDs).Return(nil).Once() + qb.On("UpdateTags", ctx, sceneID, tagIDs).Return(nil).Once() + qb.On("UpdateStashIDs", ctx, sceneID, stashIDs).Return(nil).Once() + qb.On("UpdateCover", ctx, sceneID, cover).Return(nil).Once() - qb.On("UpdatePerformers", badPerformersID, performerIDs).Return(updateErr).Once() - qb.On("UpdateTags", badTagsID, tagIDs).Return(updateErr).Once() - qb.On("UpdateStashIDs", badStashIDsID, stashIDs).Return(updateErr).Once() - qb.On("UpdateCover", badCoverID, cover).Return(updateErr).Once() + qb.On("UpdatePerformers", ctx, badPerformersID, performerIDs).Return(updateErr).Once() + qb.On("UpdateTags", ctx, badTagsID, tagIDs).Return(updateErr).Once() + qb.On("UpdateStashIDs", ctx, badStashIDsID, stashIDs).Return(updateErr).Once() + qb.On("UpdateCover", ctx, badCoverID, cover).Return(updateErr).Once() tests := []struct { name string @@ -232,7 +235,7 @@ func TestUpdater_Update(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.u.Update(&qb, &mockScreenshotSetter{}) + got, err := tt.u.Update(ctx, &qb, &mockScreenshotSetter{}) if (err != nil) != tt.wantErr { t.Errorf("Updater.Update() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/pkg/scraper/action.go b/pkg/scraper/action.go index e674cb2a2..0011441fb 100644 --- a/pkg/scraper/action.go +++ b/pkg/scraper/action.go @@ -33,16 +33,16 @@ type scraperActionImpl interface { scrapeGalleryByGallery(ctx context.Context, gallery *models.Gallery) (*ScrapedGallery, error) } -func (c config) getScraper(scraper scraperTypeConfig, client *http.Client, txnManager models.TransactionManager, globalConfig GlobalConfig) scraperActionImpl { +func (c config) getScraper(scraper scraperTypeConfig, client *http.Client, globalConfig GlobalConfig) scraperActionImpl { switch scraper.Action { case scraperActionScript: return newScriptScraper(scraper, c, globalConfig) case scraperActionStash: - return newStashScraper(scraper, client, txnManager, c, globalConfig) + return newStashScraper(scraper, client, c, globalConfig) case scraperActionXPath: - return newXpathScraper(scraper, client, txnManager, c, globalConfig) + return newXpathScraper(scraper, client, c, globalConfig) case scraperActionJson: - return newJsonScraper(scraper, client, txnManager, c, globalConfig) + return newJsonScraper(scraper, client, c, globalConfig) } panic("unknown scraper action: " + scraper.Action) diff --git a/pkg/scraper/autotag.go b/pkg/scraper/autotag.go index 52f3ce4af..ba10ace3f 100644 --- a/pkg/scraper/autotag.go +++ b/pkg/scraper/autotag.go @@ -8,6 +8,7 @@ import ( "github.com/stashapp/stash/pkg/match" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/txn" ) // autoTagScraperID is the scraper ID for the built-in AutoTag scraper @@ -17,12 +18,17 @@ const ( ) type autotagScraper struct { - txnManager models.TransactionManager + // repository models.Repository + txnManager txn.Manager + performerReader match.PerformerAutoTagQueryer + studioReader match.StudioAutoTagQueryer + tagReader match.TagAutoTagQueryer + globalConfig GlobalConfig } -func autotagMatchPerformers(path string, performerReader models.PerformerReader, trimExt bool) ([]*models.ScrapedPerformer, error) { - p, err := match.PathToPerformers(path, performerReader, nil, trimExt) +func autotagMatchPerformers(ctx context.Context, path string, performerReader match.PerformerAutoTagQueryer, trimExt bool) ([]*models.ScrapedPerformer, error) { + p, err := match.PathToPerformers(ctx, path, performerReader, nil, trimExt) if err != nil { return nil, fmt.Errorf("error matching performers: %w", err) } @@ -45,8 +51,8 @@ func autotagMatchPerformers(path string, performerReader models.PerformerReader, return ret, nil } -func autotagMatchStudio(path string, studioReader models.StudioReader, trimExt bool) (*models.ScrapedStudio, error) { - studio, err := match.PathToStudio(path, studioReader, nil, trimExt) +func autotagMatchStudio(ctx context.Context, path string, studioReader match.StudioAutoTagQueryer, trimExt bool) (*models.ScrapedStudio, error) { + studio, err := match.PathToStudio(ctx, path, studioReader, nil, trimExt) if err != nil { return nil, fmt.Errorf("error matching studios: %w", err) } @@ -62,8 +68,8 @@ func autotagMatchStudio(path string, studioReader models.StudioReader, trimExt b return nil, nil } -func autotagMatchTags(path string, tagReader models.TagReader, trimExt bool) ([]*models.ScrapedTag, error) { - t, err := match.PathToTags(path, tagReader, nil, trimExt) +func autotagMatchTags(ctx context.Context, path string, tagReader match.TagAutoTagQueryer, trimExt bool) ([]*models.ScrapedTag, error) { + t, err := match.PathToTags(ctx, path, tagReader, nil, trimExt) if err != nil { return nil, fmt.Errorf("error matching tags: %w", err) } @@ -88,18 +94,18 @@ func (s autotagScraper) viaScene(ctx context.Context, _client *http.Client, scen const trimExt = false // populate performers, studio and tags based on scene path - if err := s.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { + if err := txn.WithTxn(ctx, s.txnManager, func(ctx context.Context) error { path := scene.Path - performers, err := autotagMatchPerformers(path, r.Performer(), trimExt) + performers, err := autotagMatchPerformers(ctx, path, s.performerReader, trimExt) if err != nil { return fmt.Errorf("autotag scraper viaScene: %w", err) } - studio, err := autotagMatchStudio(path, r.Studio(), trimExt) + studio, err := autotagMatchStudio(ctx, path, s.studioReader, trimExt) if err != nil { return fmt.Errorf("autotag scraper viaScene: %w", err) } - tags, err := autotagMatchTags(path, r.Tag(), trimExt) + tags, err := autotagMatchTags(ctx, path, s.tagReader, trimExt) if err != nil { return fmt.Errorf("autotag scraper viaScene: %w", err) } @@ -132,18 +138,18 @@ func (s autotagScraper) viaGallery(ctx context.Context, _client *http.Client, ga var ret *ScrapedGallery // populate performers, studio and tags based on scene path - if err := s.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { + if err := txn.WithTxn(ctx, s.txnManager, func(ctx context.Context) error { path := gallery.Path.String - performers, err := autotagMatchPerformers(path, r.Performer(), trimExt) + performers, err := autotagMatchPerformers(ctx, path, s.performerReader, trimExt) if err != nil { return fmt.Errorf("autotag scraper viaGallery: %w", err) } - studio, err := autotagMatchStudio(path, r.Studio(), trimExt) + studio, err := autotagMatchStudio(ctx, path, s.studioReader, trimExt) if err != nil { return fmt.Errorf("autotag scraper viaGallery: %w", err) } - tags, err := autotagMatchTags(path, r.Tag(), trimExt) + tags, err := autotagMatchTags(ctx, path, s.tagReader, trimExt) if err != nil { return fmt.Errorf("autotag scraper viaGallery: %w", err) } @@ -196,10 +202,13 @@ func (s autotagScraper) spec() Scraper { } } -func getAutoTagScraper(txnManager models.TransactionManager, globalConfig GlobalConfig) scraper { +func getAutoTagScraper(txnManager txn.Manager, repo Repository, globalConfig GlobalConfig) scraper { base := autotagScraper{ - txnManager: txnManager, - globalConfig: globalConfig, + txnManager: txnManager, + performerReader: repo.PerformerFinder, + studioReader: repo.StudioFinder, + tagReader: repo.TagFinder, + globalConfig: globalConfig, } return base diff --git a/pkg/scraper/cache.go b/pkg/scraper/cache.go index 5357f8b94..decf71cb2 100644 --- a/pkg/scraper/cache.go +++ b/pkg/scraper/cache.go @@ -13,7 +13,11 @@ import ( "github.com/stashapp/stash/pkg/fsutil" "github.com/stashapp/stash/pkg/logger" + "github.com/stashapp/stash/pkg/match" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/scene" + "github.com/stashapp/stash/pkg/tag" + "github.com/stashapp/stash/pkg/txn" ) const ( @@ -47,12 +51,42 @@ func isCDPPathWS(c GlobalConfig) bool { return strings.HasPrefix(c.GetScraperCDPPath(), "ws://") } +type PerformerFinder interface { + match.PerformerAutoTagQueryer + match.PerformerFinder +} + +type StudioFinder interface { + match.StudioAutoTagQueryer + match.StudioFinder +} + +type TagFinder interface { + match.TagAutoTagQueryer + tag.Queryer +} + +type GalleryFinder interface { + Find(ctx context.Context, id int) (*models.Gallery, error) +} + +type Repository struct { + SceneFinder scene.IDFinder + GalleryFinder GalleryFinder + TagFinder TagFinder + PerformerFinder PerformerFinder + MovieFinder match.MovieNamesFinder + StudioFinder StudioFinder +} + // Cache stores the database of scrapers type Cache struct { client *http.Client scrapers map[string]scraper // Scraper ID -> Scraper globalConfig GlobalConfig - txnManager models.TransactionManager + txnManager txn.Manager + + repository Repository } // newClient creates a scraper-local http client we use throughout the scraper subsystem. @@ -81,30 +115,33 @@ func newClient(gc GlobalConfig) *http.Client { // // Scraper configurations are loaded from yml files in the provided scrapers // directory and any subdirectories. -func NewCache(globalConfig GlobalConfig, txnManager models.TransactionManager) (*Cache, error) { +func NewCache(globalConfig GlobalConfig, txnManager txn.Manager, repo Repository) (*Cache, error) { // HTTP Client setup client := newClient(globalConfig) - scrapers, err := loadScrapers(globalConfig, txnManager) + ret := &Cache{ + client: client, + globalConfig: globalConfig, + txnManager: txnManager, + repository: repo, + } + + var err error + ret.scrapers, err = ret.loadScrapers() if err != nil { return nil, err } - return &Cache{ - client: client, - globalConfig: globalConfig, - scrapers: scrapers, - txnManager: txnManager, - }, nil + return ret, nil } -func loadScrapers(globalConfig GlobalConfig, txnManager models.TransactionManager) (map[string]scraper, error) { - path := globalConfig.GetScrapersPath() +func (c *Cache) loadScrapers() (map[string]scraper, error) { + path := c.globalConfig.GetScrapersPath() scrapers := make(map[string]scraper) // Add built-in scrapers - freeOnes := getFreeonesScraper(txnManager, globalConfig) - autoTag := getAutoTagScraper(txnManager, globalConfig) + freeOnes := getFreeonesScraper(c.globalConfig) + autoTag := getAutoTagScraper(c.txnManager, c.repository, c.globalConfig) scrapers[freeOnes.spec().ID] = freeOnes scrapers[autoTag.spec().ID] = autoTag @@ -113,11 +150,11 @@ func loadScrapers(globalConfig GlobalConfig, txnManager models.TransactionManage scraperFiles := []string{} err := fsutil.SymWalk(path, func(fp string, f os.FileInfo, err error) error { if filepath.Ext(fp) == ".yml" { - c, err := loadConfigFromYAMLFile(fp) + conf, err := loadConfigFromYAMLFile(fp) if err != nil { logger.Errorf("Error loading scraper %s: %v", fp, err) } else { - scraper := newGroupScraper(*c, txnManager, globalConfig) + scraper := newGroupScraper(*conf, c.globalConfig) scrapers[scraper.spec().ID] = scraper } scraperFiles = append(scraperFiles, fp) @@ -137,7 +174,7 @@ func loadScrapers(globalConfig GlobalConfig, txnManager models.TransactionManage // In the event of an error during loading, the cache will be left empty. func (c *Cache) ReloadScrapers() error { c.scrapers = nil - scrapers, err := loadScrapers(c.globalConfig, c.txnManager) + scrapers, err := c.loadScrapers() if err != nil { return err } @@ -269,7 +306,7 @@ func (c Cache) ScrapeID(ctx context.Context, scraperID string, id int, ty Scrape return nil, fmt.Errorf("%w: cannot use scraper %s as a scene scraper", ErrNotSupported, scraperID) } - scene, err := getScene(ctx, id, c.txnManager) + scene, err := c.getScene(ctx, id) if err != nil { return nil, fmt.Errorf("scraper %s: unable to load scene id %v: %w", scraperID, id, err) } @@ -290,7 +327,7 @@ func (c Cache) ScrapeID(ctx context.Context, scraperID string, id int, ty Scrape return nil, fmt.Errorf("%w: cannot use scraper %s as a gallery scraper", ErrNotSupported, scraperID) } - gallery, err := getGallery(ctx, id, c.txnManager) + gallery, err := c.getGallery(ctx, id) if err != nil { return nil, fmt.Errorf("scraper %s: unable to load gallery id %v: %w", scraperID, id, err) } @@ -309,3 +346,27 @@ func (c Cache) ScrapeID(ctx context.Context, scraperID string, id int, ty Scrape return c.postScrape(ctx, ret) } + +func (c Cache) getScene(ctx context.Context, sceneID int) (*models.Scene, error) { + var ret *models.Scene + if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error { + var err error + ret, err = c.repository.SceneFinder.Find(ctx, sceneID) + return err + }); err != nil { + return nil, err + } + return ret, nil +} + +func (c Cache) getGallery(ctx context.Context, galleryID int) (*models.Gallery, error) { + var ret *models.Gallery + if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error { + var err error + ret, err = c.repository.GalleryFinder.Find(ctx, galleryID) + return err + }); err != nil { + return nil, err + } + return ret, nil +} diff --git a/pkg/scraper/freeones.go b/pkg/scraper/freeones.go index 7b6c81649..9a8eb4859 100644 --- a/pkg/scraper/freeones.go +++ b/pkg/scraper/freeones.go @@ -4,7 +4,6 @@ import ( "strings" "github.com/stashapp/stash/pkg/logger" - "github.com/stashapp/stash/pkg/models" ) // FreeonesScraperID is the scraper ID for the built-in Freeones scraper @@ -123,7 +122,7 @@ xPathScrapers: # Last updated April 13, 2021 ` -func getFreeonesScraper(txnManager models.TransactionManager, globalConfig GlobalConfig) scraper { +func getFreeonesScraper(globalConfig GlobalConfig) scraper { yml := freeonesScraperConfig c, err := loadConfigFromYAML(FreeonesScraperID, strings.NewReader(yml)) @@ -131,5 +130,5 @@ func getFreeonesScraper(txnManager models.TransactionManager, globalConfig Globa logger.Fatalf("Error loading builtin freeones scraper: %s", err.Error()) } - return newGroupScraper(*c, txnManager, globalConfig) + return newGroupScraper(*c, globalConfig) } diff --git a/pkg/scraper/group.go b/pkg/scraper/group.go index b423883c4..bbf0a680a 100644 --- a/pkg/scraper/group.go +++ b/pkg/scraper/group.go @@ -11,14 +11,12 @@ import ( type group struct { config config - txnManager models.TransactionManager globalConf GlobalConfig } -func newGroupScraper(c config, txnManager models.TransactionManager, globalConfig GlobalConfig) scraper { +func newGroupScraper(c config, globalConfig GlobalConfig) scraper { return group{ config: c, - txnManager: txnManager, globalConf: globalConfig, } } @@ -55,7 +53,7 @@ func (g group) viaFragment(ctx context.Context, client *http.Client, input Input return nil, ErrNotSupported } - s := g.config.getScraper(*stc, client, g.txnManager, g.globalConf) + s := g.config.getScraper(*stc, client, g.globalConf) return s.scrapeByFragment(ctx, input) } @@ -64,7 +62,7 @@ func (g group) viaScene(ctx context.Context, client *http.Client, scene *models. return nil, ErrNotSupported } - s := g.config.getScraper(*g.config.SceneByFragment, client, g.txnManager, g.globalConf) + s := g.config.getScraper(*g.config.SceneByFragment, client, g.globalConf) return s.scrapeSceneByScene(ctx, scene) } @@ -73,7 +71,7 @@ func (g group) viaGallery(ctx context.Context, client *http.Client, gallery *mod return nil, ErrNotSupported } - s := g.config.getScraper(*g.config.GalleryByFragment, client, g.txnManager, g.globalConf) + s := g.config.getScraper(*g.config.GalleryByFragment, client, g.globalConf) return s.scrapeGalleryByGallery(ctx, gallery) } @@ -96,7 +94,7 @@ func (g group) viaURL(ctx context.Context, client *http.Client, url string, ty S candidates := loadUrlCandidates(g.config, ty) for _, scraper := range candidates { if scraper.matchesURL(url) { - s := g.config.getScraper(scraper.scraperTypeConfig, client, g.txnManager, g.globalConf) + s := g.config.getScraper(scraper.scraperTypeConfig, client, g.globalConf) ret, err := s.scrapeByURL(ctx, url, ty) if err != nil { return nil, err @@ -118,14 +116,14 @@ func (g group) viaName(ctx context.Context, client *http.Client, name string, ty break } - s := g.config.getScraper(*g.config.PerformerByName, client, g.txnManager, g.globalConf) + s := g.config.getScraper(*g.config.PerformerByName, client, g.globalConf) return s.scrapeByName(ctx, name, ty) case ScrapeContentTypeScene: if g.config.SceneByName == nil { break } - s := g.config.getScraper(*g.config.SceneByName, client, g.txnManager, g.globalConf) + s := g.config.getScraper(*g.config.SceneByName, client, g.globalConf) return s.scrapeByName(ctx, name, ty) } diff --git a/pkg/scraper/json.go b/pkg/scraper/json.go index dbcba38ef..1d6358a92 100644 --- a/pkg/scraper/json.go +++ b/pkg/scraper/json.go @@ -19,16 +19,14 @@ type jsonScraper struct { config config globalConfig GlobalConfig client *http.Client - txnManager models.TransactionManager } -func newJsonScraper(scraper scraperTypeConfig, client *http.Client, txnManager models.TransactionManager, config config, globalConfig GlobalConfig) *jsonScraper { +func newJsonScraper(scraper scraperTypeConfig, client *http.Client, config config, globalConfig GlobalConfig) *jsonScraper { return &jsonScraper{ scraper: scraper, config: config, client: client, globalConfig: globalConfig, - txnManager: txnManager, } } diff --git a/pkg/scraper/postprocessing.go b/pkg/scraper/postprocessing.go index ded3bc816..6351ebccb 100644 --- a/pkg/scraper/postprocessing.go +++ b/pkg/scraper/postprocessing.go @@ -6,6 +6,8 @@ import ( "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/match" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/tag" + "github.com/stashapp/stash/pkg/txn" ) // postScrape handles post-processing of scraped content. If the content @@ -45,10 +47,10 @@ func (c Cache) postScrape(ctx context.Context, content ScrapedContent) (ScrapedC } func (c Cache) postScrapePerformer(ctx context.Context, p models.ScrapedPerformer) (ScrapedContent, error) { - if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - tqb := r.Tag() + if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error { + tqb := c.repository.TagFinder - tags, err := postProcessTags(tqb, p.Tags) + tags, err := postProcessTags(ctx, tqb, p.Tags) if err != nil { return err } @@ -69,8 +71,8 @@ func (c Cache) postScrapePerformer(ctx context.Context, p models.ScrapedPerforme func (c Cache) postScrapeMovie(ctx context.Context, m models.ScrapedMovie) (ScrapedContent, error) { if m.Studio != nil { - if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - return match.ScrapedStudio(r.Studio(), m.Studio, nil) + if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error { + return match.ScrapedStudio(ctx, c.repository.StudioFinder, m.Studio, nil) }); err != nil { return nil, err } @@ -88,10 +90,10 @@ func (c Cache) postScrapeMovie(ctx context.Context, m models.ScrapedMovie) (Scra } func (c Cache) postScrapeScenePerformer(ctx context.Context, p models.ScrapedPerformer) error { - if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - tqb := r.Tag() + if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error { + tqb := c.repository.TagFinder - tags, err := postProcessTags(tqb, p.Tags) + tags, err := postProcessTags(ctx, tqb, p.Tags) if err != nil { return err } @@ -106,11 +108,11 @@ func (c Cache) postScrapeScenePerformer(ctx context.Context, p models.ScrapedPer } func (c Cache) postScrapeScene(ctx context.Context, scene ScrapedScene) (ScrapedContent, error) { - if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - pqb := r.Performer() - mqb := r.Movie() - tqb := r.Tag() - sqb := r.Studio() + if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error { + pqb := c.repository.PerformerFinder + mqb := c.repository.MovieFinder + tqb := c.repository.TagFinder + sqb := c.repository.StudioFinder for _, p := range scene.Performers { if p == nil { @@ -121,26 +123,26 @@ func (c Cache) postScrapeScene(ctx context.Context, scene ScrapedScene) (Scraped return err } - if err := match.ScrapedPerformer(pqb, p, nil); err != nil { + if err := match.ScrapedPerformer(ctx, pqb, p, nil); err != nil { return err } } for _, p := range scene.Movies { - err := match.ScrapedMovie(mqb, p) + err := match.ScrapedMovie(ctx, mqb, p) if err != nil { return err } } - tags, err := postProcessTags(tqb, scene.Tags) + tags, err := postProcessTags(ctx, tqb, scene.Tags) if err != nil { return err } scene.Tags = tags if scene.Studio != nil { - err := match.ScrapedStudio(sqb, scene.Studio, nil) + err := match.ScrapedStudio(ctx, sqb, scene.Studio, nil) if err != nil { return err } @@ -160,26 +162,26 @@ func (c Cache) postScrapeScene(ctx context.Context, scene ScrapedScene) (Scraped } func (c Cache) postScrapeGallery(ctx context.Context, g ScrapedGallery) (ScrapedContent, error) { - if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - pqb := r.Performer() - tqb := r.Tag() - sqb := r.Studio() + if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error { + pqb := c.repository.PerformerFinder + tqb := c.repository.TagFinder + sqb := c.repository.StudioFinder for _, p := range g.Performers { - err := match.ScrapedPerformer(pqb, p, nil) + err := match.ScrapedPerformer(ctx, pqb, p, nil) if err != nil { return err } } - tags, err := postProcessTags(tqb, g.Tags) + tags, err := postProcessTags(ctx, tqb, g.Tags) if err != nil { return err } g.Tags = tags if g.Studio != nil { - err := match.ScrapedStudio(sqb, g.Studio, nil) + err := match.ScrapedStudio(ctx, sqb, g.Studio, nil) if err != nil { return err } @@ -193,11 +195,11 @@ func (c Cache) postScrapeGallery(ctx context.Context, g ScrapedGallery) (Scraped return g, nil } -func postProcessTags(tqb models.TagReader, scrapedTags []*models.ScrapedTag) ([]*models.ScrapedTag, error) { +func postProcessTags(ctx context.Context, tqb tag.Queryer, scrapedTags []*models.ScrapedTag) ([]*models.ScrapedTag, error) { var ret []*models.ScrapedTag for _, t := range scrapedTags { - err := match.ScrapedTag(tqb, t) + err := match.ScrapedTag(ctx, tqb, t) if err != nil { return nil, err } diff --git a/pkg/scraper/stash.go b/pkg/scraper/stash.go index b6e0e7696..7095ab711 100644 --- a/pkg/scraper/stash.go +++ b/pkg/scraper/stash.go @@ -18,16 +18,14 @@ type stashScraper struct { config config globalConfig GlobalConfig client *http.Client - txnManager models.TransactionManager } -func newStashScraper(scraper scraperTypeConfig, client *http.Client, txnManager models.TransactionManager, config config, globalConfig GlobalConfig) *stashScraper { +func newStashScraper(scraper scraperTypeConfig, client *http.Client, config config, globalConfig GlobalConfig) *stashScraper { return &stashScraper{ scraper: scraper, config: config, client: client, globalConfig: globalConfig, - txnManager: txnManager, } } @@ -308,18 +306,6 @@ func (s *stashScraper) scrapeByURL(_ context.Context, _ string, _ ScrapeContentT return nil, ErrNotSupported } -func getScene(ctx context.Context, sceneID int, txnManager models.TransactionManager) (*models.Scene, error) { - var ret *models.Scene - if err := txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - var err error - ret, err = r.Scene().Find(sceneID) - return err - }); err != nil { - return nil, err - } - return ret, nil -} - func sceneToUpdateInput(scene *models.Scene) models.SceneUpdateInput { toStringPtr := func(s sql.NullString) *string { if s.Valid { @@ -346,18 +332,6 @@ func sceneToUpdateInput(scene *models.Scene) models.SceneUpdateInput { } } -func getGallery(ctx context.Context, galleryID int, txnManager models.TransactionManager) (*models.Gallery, error) { - var ret *models.Gallery - if err := txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - var err error - ret, err = r.Gallery().Find(galleryID) - return err - }); err != nil { - return nil, err - } - return ret, nil -} - func galleryToUpdateInput(gallery *models.Gallery) models.GalleryUpdateInput { toStringPtr := func(s sql.NullString) *string { if s.Valid { diff --git a/pkg/scraper/stashbox/stash_box.go b/pkg/scraper/stashbox/stash_box.go index 923cd4482..7e2017aef 100644 --- a/pkg/scraper/stashbox/stash_box.go +++ b/pkg/scraper/stashbox/stash_box.go @@ -24,18 +24,52 @@ import ( "github.com/stashapp/stash/pkg/scraper" "github.com/stashapp/stash/pkg/scraper/stashbox/graphql" "github.com/stashapp/stash/pkg/sliceutil/stringslice" + "github.com/stashapp/stash/pkg/studio" + "github.com/stashapp/stash/pkg/tag" + "github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/utils" ) +type SceneReader interface { + Find(ctx context.Context, id int) (*models.Scene, error) + GetStashIDs(ctx context.Context, sceneID int) ([]*models.StashID, error) +} + +type PerformerReader interface { + match.PerformerFinder + Find(ctx context.Context, id int) (*models.Performer, error) + FindBySceneID(ctx context.Context, sceneID int) ([]*models.Performer, error) + GetStashIDs(ctx context.Context, performerID int) ([]*models.StashID, error) + GetImage(ctx context.Context, performerID int) ([]byte, error) +} + +type StudioReader interface { + match.StudioFinder + studio.Finder + GetStashIDs(ctx context.Context, studioID int) ([]*models.StashID, error) +} +type TagFinder interface { + tag.Queryer + FindBySceneID(ctx context.Context, sceneID int) ([]*models.Tag, error) +} + +type Repository struct { + Scene SceneReader + Performer PerformerReader + Tag TagFinder + Studio StudioReader +} + // Client represents the client interface to a stash-box server instance. type Client struct { client *graphql.Client - txnManager models.TransactionManager + txnManager txn.Manager + repository Repository box models.StashBox } // NewClient returns a new instance of a stash-box client. -func NewClient(box models.StashBox, txnManager models.TransactionManager) *Client { +func NewClient(box models.StashBox, txnManager txn.Manager, repo Repository) *Client { authHeader := func(req *http.Request) { req.Header.Set("ApiKey", box.APIKey) } @@ -47,6 +81,7 @@ func NewClient(box models.StashBox, txnManager models.TransactionManager) *Clien return &Client{ client: client, txnManager: txnManager, + repository: repo, box: box, } } @@ -92,11 +127,11 @@ func (c Client) FindStashBoxSceneByFingerprints(ctx context.Context, sceneID int func (c Client) FindStashBoxScenesByFingerprints(ctx context.Context, ids []int) ([][]*scraper.ScrapedScene, error) { var fingerprints [][]*graphql.FingerprintQueryInput - if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - qb := r.Scene() + if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error { + qb := c.repository.Scene for _, sceneID := range ids { - scene, err := qb.Find(sceneID) + scene, err := qb.Find(ctx, sceneID) if err != nil { return err } @@ -177,11 +212,11 @@ func (c Client) SubmitStashBoxFingerprints(ctx context.Context, sceneIDs []strin var fingerprints []graphql.FingerprintSubmission - if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - qb := r.Scene() + if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error { + qb := c.repository.Scene for _, sceneID := range ids { - scene, err := qb.Find(sceneID) + scene, err := qb.Find(ctx, sceneID) if err != nil { return err } @@ -190,7 +225,7 @@ func (c Client) SubmitStashBoxFingerprints(ctx context.Context, sceneIDs []strin continue } - stashIDs, err := qb.GetStashIDs(sceneID) + stashIDs, err := qb.GetStashIDs(ctx, sceneID) if err != nil { return err } @@ -307,11 +342,11 @@ func (c Client) FindStashBoxPerformersByNames(ctx context.Context, performerIDs var performers []*models.Performer - if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - qb := r.Performer() + if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error { + qb := c.repository.Performer for _, performerID := range ids { - performer, err := qb.Find(performerID) + performer, err := qb.Find(ctx, performerID) if err != nil { return err } @@ -341,11 +376,11 @@ func (c Client) FindStashBoxPerformersByPerformerNames(ctx context.Context, perf var performers []*models.Performer - if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - qb := r.Performer() + if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error { + qb := c.repository.Performer for _, performerID := range ids { - performer, err := qb.Find(performerID) + performer, err := qb.Find(ctx, performerID) if err != nil { return err } @@ -622,9 +657,9 @@ func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.Scen ss.Image = getFirstImage(ctx, c.getHTTPClient(), s.Images) } - if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - pqb := r.Performer() - tqb := r.Tag() + if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error { + pqb := c.repository.Performer + tqb := c.repository.Tag if s.Studio != nil { studioID := s.Studio.ID @@ -634,7 +669,7 @@ func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.Scen RemoteSiteID: &studioID, } - err := match.ScrapedStudio(r.Studio(), ss.Studio, &c.box.Endpoint) + err := match.ScrapedStudio(ctx, c.repository.Studio, ss.Studio, &c.box.Endpoint) if err != nil { return err } @@ -643,7 +678,7 @@ func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.Scen for _, p := range s.Performers { sp := performerFragmentToScrapedScenePerformer(p.Performer) - err := match.ScrapedPerformer(pqb, sp, &c.box.Endpoint) + err := match.ScrapedPerformer(ctx, pqb, sp, &c.box.Endpoint) if err != nil { return err } @@ -656,7 +691,7 @@ func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.Scen Name: t.Name, } - err := match.ScrapedTag(tqb, st) + err := match.ScrapedTag(ctx, tqb, st) if err != nil { return err } @@ -705,12 +740,13 @@ func (c Client) GetUser(ctx context.Context) (*graphql.Me, error) { func (c Client) SubmitSceneDraft(ctx context.Context, sceneID int, endpoint string, imagePath string) (*string, error) { draft := graphql.SceneDraftInput{} var image *os.File - if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - qb := r.Scene() - pqb := r.Performer() - sqb := r.Studio() + if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error { + r := c.repository + qb := r.Scene + pqb := r.Performer + sqb := r.Studio - scene, err := qb.Find(sceneID) + scene, err := qb.Find(ctx, sceneID) if err != nil { return err } @@ -730,7 +766,7 @@ func (c Client) SubmitSceneDraft(ctx context.Context, sceneID int, endpoint stri } if scene.StudioID.Valid { - studio, err := sqb.Find(int(scene.StudioID.Int64)) + studio, err := sqb.Find(ctx, int(scene.StudioID.Int64)) if err != nil { return err } @@ -738,7 +774,7 @@ func (c Client) SubmitSceneDraft(ctx context.Context, sceneID int, endpoint stri Name: studio.Name.String, } - stashIDs, err := sqb.GetStashIDs(studio.ID) + stashIDs, err := sqb.GetStashIDs(ctx, studio.ID) if err != nil { return err } @@ -780,7 +816,7 @@ func (c Client) SubmitSceneDraft(ctx context.Context, sceneID int, endpoint stri } draft.Fingerprints = fingerprints - scenePerformers, err := pqb.FindBySceneID(sceneID) + scenePerformers, err := pqb.FindBySceneID(ctx, sceneID) if err != nil { return err } @@ -791,7 +827,7 @@ func (c Client) SubmitSceneDraft(ctx context.Context, sceneID int, endpoint stri Name: p.Name.String, } - stashIDs, err := pqb.GetStashIDs(p.ID) + stashIDs, err := pqb.GetStashIDs(ctx, p.ID) if err != nil { return err } @@ -808,7 +844,7 @@ func (c Client) SubmitSceneDraft(ctx context.Context, sceneID int, endpoint stri draft.Performers = performers var tags []*graphql.DraftEntityInput - sceneTags, err := r.Tag().FindBySceneID(scene.ID) + sceneTags, err := r.Tag.FindBySceneID(ctx, scene.ID) if err != nil { return err } @@ -825,7 +861,7 @@ func (c Client) SubmitSceneDraft(ctx context.Context, sceneID int, endpoint stri } } - stashIDs, err := qb.GetStashIDs(sceneID) + stashIDs, err := qb.GetStashIDs(ctx, sceneID) if err != nil { return err } @@ -862,9 +898,9 @@ func (c Client) SubmitSceneDraft(ctx context.Context, sceneID int, endpoint stri func (c Client) SubmitPerformerDraft(ctx context.Context, performer *models.Performer, endpoint string) (*string, error) { draft := graphql.PerformerDraftInput{} var image io.Reader - if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { - pqb := r.Performer() - img, _ := pqb.GetImage(performer.ID) + if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error { + pqb := c.repository.Performer + img, _ := pqb.GetImage(ctx, performer.ID) if img != nil { image = bytes.NewReader(img) } @@ -923,7 +959,7 @@ func (c Client) SubmitPerformerDraft(ctx context.Context, performer *models.Perf draft.Urls = urls } - stashIDs, err := pqb.GetStashIDs(performer.ID) + stashIDs, err := pqb.GetStashIDs(ctx, performer.ID) if err != nil { return err } diff --git a/pkg/scraper/xpath.go b/pkg/scraper/xpath.go index 092968b5e..29a4b0a19 100644 --- a/pkg/scraper/xpath.go +++ b/pkg/scraper/xpath.go @@ -23,16 +23,14 @@ type xpathScraper struct { config config globalConfig GlobalConfig client *http.Client - txnManager models.TransactionManager } -func newXpathScraper(scraper scraperTypeConfig, client *http.Client, txnManager models.TransactionManager, config config, globalConfig GlobalConfig) *xpathScraper { +func newXpathScraper(scraper scraperTypeConfig, client *http.Client, config config, globalConfig GlobalConfig) *xpathScraper { return &xpathScraper{ scraper: scraper, config: config, globalConfig: globalConfig, client: client, - txnManager: txnManager, } } diff --git a/pkg/scraper/xpath_test.go b/pkg/scraper/xpath_test.go index 9aef91a23..7120f8574 100644 --- a/pkg/scraper/xpath_test.go +++ b/pkg/scraper/xpath_test.go @@ -885,7 +885,7 @@ xPathScrapers: client := &http.Client{} ctx := context.Background() - s := newGroupScraper(*c, nil, globalConfig) + s := newGroupScraper(*c, globalConfig) us, ok := s.(urlScraper) if !ok { t.Error("couldn't convert scraper into url scraper") diff --git a/pkg/database/custom_migrations.go b/pkg/sqlite/custom_migrations.go similarity index 81% rename from pkg/database/custom_migrations.go rename to pkg/sqlite/custom_migrations.go index 340ffba55..768317707 100644 --- a/pkg/database/custom_migrations.go +++ b/pkg/sqlite/custom_migrations.go @@ -1,27 +1,33 @@ -package database +package sqlite import ( + "context" "database/sql" "errors" "fmt" "strings" - "github.com/jmoiron/sqlx" "github.com/stashapp/stash/pkg/logger" + "github.com/stashapp/stash/pkg/txn" ) -func runCustomMigrations() error { - if err := createImagesChecksumIndex(); err != nil { +func (db *Database) runCustomMigrations() error { + if err := db.createImagesChecksumIndex(); err != nil { return err } return nil } -func createImagesChecksumIndex() error { - return WithTxn(func(tx *sqlx.Tx) error { +func (db *Database) createImagesChecksumIndex() error { + return txn.WithTxn(context.Background(), db, func(ctx context.Context) error { + tx, err := getTx(ctx) + if err != nil { + return err + } + row := tx.QueryRow("SELECT 1 AS found FROM sqlite_master WHERE type = 'index' AND name = 'images_checksum_unique'") - err := row.Err() + err = row.Err() if err != nil && !errors.Is(err, sql.ErrNoRows) { return err } diff --git a/pkg/database/database.go b/pkg/sqlite/database.go similarity index 54% rename from pkg/database/database.go rename to pkg/sqlite/database.go index 3fa260716..5897e844a 100644 --- a/pkg/database/database.go +++ b/pkg/sqlite/database.go @@ -1,4 +1,4 @@ -package database +package sqlite import ( "database/sql" @@ -20,99 +20,135 @@ import ( "github.com/stashapp/stash/pkg/logger" ) -var DB *sqlx.DB -var WriteMu sync.Mutex -var dbPath string var appSchemaVersion uint = 31 -var databaseSchemaVersion uint //go:embed migrations/*.sql var migrationsBox embed.FS var ( - // ErrMigrationNeeded indicates that a database migration is needed - // before the database can be initialized - ErrMigrationNeeded = errors.New("database migration required") - // ErrDatabaseNotInitialized indicates that the database is not // initialized, usually due to an incomplete configuration. ErrDatabaseNotInitialized = errors.New("database not initialized") ) -const sqlite3Driver = "sqlite3ex" - -// Ready returns an error if the database is not ready to begin transactions. -func Ready() error { - if DB == nil { - return ErrDatabaseNotInitialized - } - - return nil +// ErrMigrationNeeded indicates that a database migration is needed +// before the database can be initialized +type MigrationNeededError struct { + CurrentSchemaVersion uint + RequiredSchemaVersion uint } +func (e *MigrationNeededError) Error() string { + return fmt.Sprintf("database schema version %d does not match required schema version %d", e.CurrentSchemaVersion, e.RequiredSchemaVersion) +} + +type MismatchedSchemaVersionError struct { + CurrentSchemaVersion uint + RequiredSchemaVersion uint +} + +func (e *MismatchedSchemaVersionError) Error() string { + return fmt.Sprintf("schema version %d is incompatible with required schema version %d", e.CurrentSchemaVersion, e.RequiredSchemaVersion) +} + +const sqlite3Driver = "sqlite3ex" + func init() { // register custom driver with regexp function registerCustomDriver() } -// Initialize initializes the database. If the database is new, then it +type Database struct { + db *sqlx.DB + dbPath string + + schemaVersion uint + + writeMu sync.Mutex +} + +// Ready returns an error if the database is not ready to begin transactions. +func (db *Database) Ready() error { + if db.db == nil { + return ErrDatabaseNotInitialized + } + + return nil +} + +// Open initializes the database. If the database is new, then it // performs a full migration to the latest schema version. Otherwise, any // necessary migrations must be run separately using RunMigrations. // Returns true if the database is new. -func Initialize(databasePath string) error { - dbPath = databasePath +func (db *Database) Open(dbPath string) error { + db.writeMu.Lock() + defer db.writeMu.Unlock() - if err := getDatabaseSchemaVersion(); err != nil { - return fmt.Errorf("error getting database schema version: %v", err) + db.dbPath = dbPath + + databaseSchemaVersion, err := db.getDatabaseSchemaVersion() + if err != nil { + return fmt.Errorf("getting database schema version: %w", err) } + db.schemaVersion = databaseSchemaVersion + if databaseSchemaVersion == 0 { // new database, just run the migrations - if err := RunMigrations(); err != nil { + if err := db.RunMigrations(); err != nil { return fmt.Errorf("error running initial schema migrations: %v", err) } - // RunMigrations calls Initialise. Just return - return nil } else { if databaseSchemaVersion > appSchemaVersion { - panic(fmt.Sprintf("Database schema version %d is incompatible with required schema version %d", databaseSchemaVersion, appSchemaVersion)) + return &MismatchedSchemaVersionError{ + CurrentSchemaVersion: databaseSchemaVersion, + RequiredSchemaVersion: appSchemaVersion, + } } // if migration is needed, then don't open the connection - if NeedsMigration() { - logger.Warnf("Database schema version %d does not match required schema version %d.", databaseSchemaVersion, appSchemaVersion) - return nil + if db.needsMigration() { + return &MigrationNeededError{ + CurrentSchemaVersion: databaseSchemaVersion, + RequiredSchemaVersion: appSchemaVersion, + } } } - const disableForeignKeys = false - DB = open(databasePath, disableForeignKeys) + // RunMigrations may have opened a connection already + if db.db == nil { + const disableForeignKeys = false + db.db, err = db.open(disableForeignKeys) + if err != nil { + return err + } + } - if err := runCustomMigrations(); err != nil { + if err := db.runCustomMigrations(); err != nil { return err } return nil } -func Close() error { - WriteMu.Lock() - defer WriteMu.Unlock() +func (db *Database) Close() error { + db.writeMu.Lock() + defer db.writeMu.Unlock() - if DB != nil { - if err := DB.Close(); err != nil { + if db.db != nil { + if err := db.db.Close(); err != nil { return err } - DB = nil + db.db = nil } return nil } -func open(databasePath string, disableForeignKeys bool) *sqlx.DB { +func (db *Database) open(disableForeignKeys bool) (*sqlx.DB, error) { // https://github.com/mattn/go-sqlite3 - url := "file:" + databasePath + "?_journal=WAL&_sync=NORMAL" + url := "file:" + db.dbPath + "?_journal=WAL&_sync=NORMAL" if !disableForeignKeys { url += "&_fk=true" } @@ -122,14 +158,15 @@ func open(databasePath string, disableForeignKeys bool) *sqlx.DB { conn.SetMaxIdleConns(4) conn.SetConnMaxLifetime(30 * time.Second) if err != nil { - logger.Fatalf("db.Open(): %q\n", err) + return nil, fmt.Errorf("db.Open(): %w", err) } - return conn + return conn, nil } -func Reset(databasePath string) error { - err := DB.Close() +func (db *Database) Reset() error { + databasePath := db.dbPath + err := db.Close() if err != nil { return errors.New("Error closing database: " + err.Error()) @@ -151,7 +188,7 @@ func Reset(databasePath string) error { } } - if err := Initialize(databasePath); err != nil { + if err := db.Open(databasePath); err != nil { return fmt.Errorf("[reset DB] unable to initialize: %w", err) } @@ -160,18 +197,19 @@ func Reset(databasePath string) error { // Backup the database. If db is nil, then uses the existing database // connection. -func Backup(db *sqlx.DB, backupPath string) error { - if db == nil { +func (db *Database) Backup(backupPath string) error { + thisDB := db.db + if thisDB == nil { var err error - db, err = sqlx.Connect(sqlite3Driver, "file:"+dbPath+"?_fk=true") + thisDB, err = sqlx.Connect(sqlite3Driver, "file:"+db.dbPath+"?_fk=true") if err != nil { - return fmt.Errorf("open database %s failed: %v", dbPath, err) + return fmt.Errorf("open database %s failed: %v", db.dbPath, err) } - defer db.Close() + defer thisDB.Close() } logger.Infof("Backing up database into: %s", backupPath) - _, err := db.Exec(`VACUUM INTO "` + backupPath + `"`) + _, err := thisDB.Exec(`VACUUM INTO "` + backupPath + `"`) if err != nil { return fmt.Errorf("vacuum failed: %v", err) } @@ -179,40 +217,43 @@ func Backup(db *sqlx.DB, backupPath string) error { return nil } -func RestoreFromBackup(backupPath string) error { - logger.Infof("Restoring backup database %s into %s", backupPath, dbPath) - return os.Rename(backupPath, dbPath) +func (db *Database) RestoreFromBackup(backupPath string) error { + logger.Infof("Restoring backup database %s into %s", backupPath, db.dbPath) + return os.Rename(backupPath, db.dbPath) } // Migrate the database -func NeedsMigration() bool { - return databaseSchemaVersion != appSchemaVersion +func (db *Database) needsMigration() bool { + return db.schemaVersion != appSchemaVersion } -func AppSchemaVersion() uint { +func (db *Database) AppSchemaVersion() uint { return appSchemaVersion } -func DatabasePath() string { - return dbPath +func (db *Database) DatabasePath() string { + return db.dbPath } -func DatabaseBackupPath() string { - return fmt.Sprintf("%s.%d.%s", dbPath, databaseSchemaVersion, time.Now().Format("20060102_150405")) +func (db *Database) DatabaseBackupPath() string { + return fmt.Sprintf("%s.%d.%s", db.dbPath, db.schemaVersion, time.Now().Format("20060102_150405")) } -func Version() uint { - return databaseSchemaVersion +func (db *Database) Version() uint { + return db.schemaVersion } -func getMigrate() (*migrate.Migrate, error) { +func (db *Database) getMigrate() (*migrate.Migrate, error) { migrations, err := iofs.New(migrationsBox, "migrations") if err != nil { panic(err.Error()) } const disableForeignKeys = true - conn := open(dbPath, disableForeignKeys) + conn, err := db.open(disableForeignKeys) + if err != nil { + return nil, err + } driver, err := sqlite3mig.WithInstance(conn.DB, &sqlite3mig.Config{}) if err != nil { @@ -223,31 +264,31 @@ func getMigrate() (*migrate.Migrate, error) { return migrate.NewWithInstance( "iofs", migrations, - dbPath, + db.dbPath, driver, ) } -func getDatabaseSchemaVersion() error { - m, err := getMigrate() +func (db *Database) getDatabaseSchemaVersion() (uint, error) { + m, err := db.getMigrate() if err != nil { - return err - } - - databaseSchemaVersion, _, _ = m.Version() - m.Close() - return nil -} - -// Migrate the database -func RunMigrations() error { - m, err := getMigrate() - if err != nil { - panic(err.Error()) + return 0, err } defer m.Close() - databaseSchemaVersion, _, _ = m.Version() + ret, _, _ := m.Version() + return ret, nil +} + +// Migrate the database +func (db *Database) RunMigrations() error { + m, err := db.getMigrate() + if err != nil { + return err + } + defer m.Close() + + databaseSchemaVersion, _, _ := m.Version() stepNumber := appSchemaVersion - databaseSchemaVersion if stepNumber != 0 { logger.Infof("Migrating database from version %d to %d", databaseSchemaVersion, appSchemaVersion) @@ -258,14 +299,19 @@ func RunMigrations() error { } } + // update the schema version + db.schemaVersion, _, _ = m.Version() + // re-initialise the database - if err = Initialize(dbPath); err != nil { - logger.Warnf("Error re-initializing the database: %v", err) + const disableForeignKeys = false + db.db, err = db.open(disableForeignKeys) + if err != nil { + return fmt.Errorf("re-initializing the database: %w", err) } // run a vacuum on the database logger.Info("Performing vacuum on database") - _, err = DB.Exec("VACUUM") + _, err = db.db.Exec("VACUUM") if err != nil { logger.Warnf("error while performing post-migration vacuum: %v", err) } diff --git a/pkg/sqlite/filter.go b/pkg/sqlite/filter.go index 79af98efd..af8578dbe 100644 --- a/pkg/sqlite/filter.go +++ b/pkg/sqlite/filter.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "errors" "fmt" "regexp" @@ -26,13 +27,13 @@ func makeClause(sql string, args ...interface{}) sqlClause { } type criterionHandler interface { - handle(f *filterBuilder) + handle(ctx context.Context, f *filterBuilder) } -type criterionHandlerFunc func(f *filterBuilder) +type criterionHandlerFunc func(ctx context.Context, f *filterBuilder) -func (h criterionHandlerFunc) handle(f *filterBuilder) { - h(f) +func (h criterionHandlerFunc) handle(ctx context.Context, f *filterBuilder) { + h(ctx, f) } type join struct { @@ -331,8 +332,8 @@ func (f *filterBuilder) getError() error { // handleCriterion calls the handle function on the provided criterionHandler, // providing itself. -func (f *filterBuilder) handleCriterion(handler criterionHandler) { - handler.handle(f) +func (f *filterBuilder) handleCriterion(ctx context.Context, handler criterionHandler) { + handler.handle(ctx, f) } func (f *filterBuilder) setError(e error) { @@ -361,7 +362,7 @@ func (f *filterBuilder) andClauses(input []sqlClause) (string, []interface{}) { } func stringCriterionHandler(c *models.StringCriterionInput, column string) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if c != nil { if modifier := c.Modifier; c.Modifier.IsValid() { switch modifier { @@ -400,7 +401,7 @@ func stringCriterionHandler(c *models.StringCriterionInput, column string) crite } func intCriterionHandler(c *models.IntCriterionInput, column string) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if c != nil { clause, args := getIntCriterionWhereClause(column, *c) f.addWhere(clause, args...) @@ -409,7 +410,7 @@ func intCriterionHandler(c *models.IntCriterionInput, column string) criterionHa } func boolCriterionHandler(c *bool, column string) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if c != nil { var v string if *c { @@ -441,7 +442,7 @@ type joinedMultiCriterionHandlerBuilder struct { } func (m *joinedMultiCriterionHandlerBuilder) handler(criterion *models.MultiCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if criterion != nil { joinAlias := m.joinAs if joinAlias == "" { @@ -511,7 +512,7 @@ type multiCriterionHandlerBuilder struct { } func (m *multiCriterionHandlerBuilder) handler(criterion *models.MultiCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if criterion != nil { if criterion.Modifier == models.CriterionModifierIsNull || criterion.Modifier == models.CriterionModifierNotNull { var notClause string @@ -556,7 +557,7 @@ type countCriterionHandlerBuilder struct { } func (m *countCriterionHandlerBuilder) handler(criterion *models.IntCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if criterion != nil { clause, args := getCountCriterionClause(m.primaryTable, m.joinTable, m.primaryFK, *criterion) @@ -576,11 +577,11 @@ type stringListCriterionHandlerBuilder struct { } func (m *stringListCriterionHandlerBuilder) handler(criterion *models.StringCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if criterion != nil && len(criterion.Value) > 0 { m.addJoinTable(f) - stringCriterionHandler(criterion, m.joinTable+"."+m.stringColumn)(f) + stringCriterionHandler(criterion, m.joinTable+"."+m.stringColumn)(ctx, f) } } } @@ -597,7 +598,7 @@ type hierarchicalMultiCriterionHandlerBuilder struct { relationsTable string } -func getHierarchicalValues(tx dbi, values []string, table, relationsTable, parentFK string, depth *int) string { +func getHierarchicalValues(ctx context.Context, tx dbi, values []string, table, relationsTable, parentFK string, depth *int) string { var args []interface{} depthVal := 0 @@ -670,7 +671,7 @@ WHERE id in {inBinding} query := fmt.Sprintf("WITH RECURSIVE %s SELECT 'VALUES' || GROUP_CONCAT('(' || root_id || ', ' || item_id || ')') AS val FROM items", withClause) var valuesClause string - err := tx.Get(&valuesClause, query, args...) + err := tx.Get(ctx, &valuesClause, query, args...) if err != nil { logger.Error(err) // return record which never matches so we don't have to handle error here @@ -693,7 +694,7 @@ func addHierarchicalConditionClauses(f *filterBuilder, criterion *models.Hierarc } func (m *hierarchicalMultiCriterionHandlerBuilder) handler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if criterion != nil { if criterion.Modifier == models.CriterionModifierIsNull || criterion.Modifier == models.CriterionModifierNotNull { var notClause string @@ -713,7 +714,7 @@ func (m *hierarchicalMultiCriterionHandlerBuilder) handler(criterion *models.Hie return } - valuesClause := getHierarchicalValues(m.tx, criterion.Value, m.foreignTable, m.relationsTable, m.parentFK, criterion.Depth) + valuesClause := getHierarchicalValues(ctx, m.tx, criterion.Value, m.foreignTable, m.relationsTable, m.parentFK, criterion.Depth) f.addLeftJoin("(SELECT column1 AS root_id, column2 AS item_id FROM ("+valuesClause+"))", m.derivedTable, fmt.Sprintf("%s.item_id = %s.%s", m.derivedTable, m.primaryTable, m.foreignFK)) @@ -738,7 +739,7 @@ type joinedHierarchicalMultiCriterionHandlerBuilder struct { } func (m *joinedHierarchicalMultiCriterionHandlerBuilder) handler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if criterion != nil { joinAlias := m.joinAs @@ -762,7 +763,7 @@ func (m *joinedHierarchicalMultiCriterionHandlerBuilder) handler(criterion *mode return } - valuesClause := getHierarchicalValues(m.tx, criterion.Value, m.foreignTable, m.relationsTable, m.parentFK, criterion.Depth) + valuesClause := getHierarchicalValues(ctx, m.tx, criterion.Value, m.foreignTable, m.relationsTable, m.parentFK, criterion.Depth) joinTable := utils.StrFormat(`( SELECT j.*, d.column1 AS root_id, d.column2 AS item_id FROM {joinTable} AS j diff --git a/pkg/sqlite/filter_internal_test.go b/pkg/sqlite/filter_internal_test.go index e9f173de0..f416b661c 100644 --- a/pkg/sqlite/filter_internal_test.go +++ b/pkg/sqlite/filter_internal_test.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "errors" "fmt" "testing" @@ -9,6 +10,8 @@ import ( "github.com/stretchr/testify/assert" ) +var testCtx = context.Background() + func TestJoinsAddJoin(t *testing.T) { var joins joins @@ -462,7 +465,7 @@ func TestStringCriterionHandlerIncludes(t *testing.T) { const quotedValue = `"two words"` f := &filterBuilder{} - f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{ + f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{ Modifier: models.CriterionModifierIncludes, Value: value1, }, column)) @@ -474,7 +477,7 @@ func TestStringCriterionHandlerIncludes(t *testing.T) { assert.Equal("%words%", f.whereClauses[0].args[1]) f = &filterBuilder{} - f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{ + f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{ Modifier: models.CriterionModifierIncludes, Value: quotedValue, }, column)) @@ -493,7 +496,7 @@ func TestStringCriterionHandlerExcludes(t *testing.T) { const quotedValue = `"two words"` f := &filterBuilder{} - f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{ + f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{ Modifier: models.CriterionModifierExcludes, Value: value1, }, column)) @@ -505,7 +508,7 @@ func TestStringCriterionHandlerExcludes(t *testing.T) { assert.Equal("%words%", f.whereClauses[0].args[1]) f = &filterBuilder{} - f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{ + f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{ Modifier: models.CriterionModifierExcludes, Value: quotedValue, }, column)) @@ -523,7 +526,7 @@ func TestStringCriterionHandlerEquals(t *testing.T) { const value1 = "two words" f := &filterBuilder{} - f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{ + f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{ Modifier: models.CriterionModifierEquals, Value: value1, }, column)) @@ -541,7 +544,7 @@ func TestStringCriterionHandlerNotEquals(t *testing.T) { const value1 = "two words" f := &filterBuilder{} - f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{ + f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{ Modifier: models.CriterionModifierNotEquals, Value: value1, }, column)) @@ -560,7 +563,7 @@ func TestStringCriterionHandlerMatchesRegex(t *testing.T) { const invalidValue = "*two words" f := &filterBuilder{} - f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{ + f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{ Modifier: models.CriterionModifierMatchesRegex, Value: validValue, }, column)) @@ -572,7 +575,7 @@ func TestStringCriterionHandlerMatchesRegex(t *testing.T) { // ensure invalid regex sets error state f = &filterBuilder{} - f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{ + f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{ Modifier: models.CriterionModifierMatchesRegex, Value: invalidValue, }, column)) @@ -588,7 +591,7 @@ func TestStringCriterionHandlerNotMatchesRegex(t *testing.T) { const invalidValue = "*two words" f := &filterBuilder{} - f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{ + f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{ Modifier: models.CriterionModifierNotMatchesRegex, Value: validValue, }, column)) @@ -600,7 +603,7 @@ func TestStringCriterionHandlerNotMatchesRegex(t *testing.T) { // ensure invalid regex sets error state f = &filterBuilder{} - f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{ + f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{ Modifier: models.CriterionModifierNotMatchesRegex, Value: invalidValue, }, column)) @@ -614,7 +617,7 @@ func TestStringCriterionHandlerIsNull(t *testing.T) { const column = "column" f := &filterBuilder{} - f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{ + f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{ Modifier: models.CriterionModifierIsNull, }, column)) @@ -629,7 +632,7 @@ func TestStringCriterionHandlerNotNull(t *testing.T) { const column = "column" f := &filterBuilder{} - f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{ + f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{ Modifier: models.CriterionModifierNotNull, }, column)) diff --git a/pkg/database/functions.go b/pkg/sqlite/functions.go similarity index 96% rename from pkg/database/functions.go rename to pkg/sqlite/functions.go index 2971f1e22..29e93aa22 100644 --- a/pkg/database/functions.go +++ b/pkg/sqlite/functions.go @@ -1,4 +1,4 @@ -package database +package sqlite import ( "strconv" diff --git a/pkg/sqlite/gallery.go b/pkg/sqlite/gallery.go index b7f9276ac..bb94fa1f0 100644 --- a/pkg/sqlite/gallery.go +++ b/pkg/sqlite/gallery.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "database/sql" "errors" "fmt" @@ -20,62 +21,59 @@ type galleryQueryBuilder struct { repository } -func NewGalleryReaderWriter(tx dbi) *galleryQueryBuilder { - return &galleryQueryBuilder{ - repository{ - tx: tx, - tableName: galleryTable, - idColumn: idColumn, - }, - } +var GalleryReaderWriter = &galleryQueryBuilder{ + repository{ + tableName: galleryTable, + idColumn: idColumn, + }, } -func (qb *galleryQueryBuilder) Create(newObject models.Gallery) (*models.Gallery, error) { +func (qb *galleryQueryBuilder) Create(ctx context.Context, newObject models.Gallery) (*models.Gallery, error) { var ret models.Gallery - if err := qb.insertObject(newObject, &ret); err != nil { + if err := qb.insertObject(ctx, newObject, &ret); err != nil { return nil, err } return &ret, nil } -func (qb *galleryQueryBuilder) Update(updatedObject models.Gallery) (*models.Gallery, error) { +func (qb *galleryQueryBuilder) Update(ctx context.Context, updatedObject models.Gallery) (*models.Gallery, error) { const partial = false - if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil { return nil, err } - return qb.Find(updatedObject.ID) + return qb.Find(ctx, updatedObject.ID) } -func (qb *galleryQueryBuilder) UpdatePartial(updatedObject models.GalleryPartial) (*models.Gallery, error) { +func (qb *galleryQueryBuilder) UpdatePartial(ctx context.Context, updatedObject models.GalleryPartial) (*models.Gallery, error) { const partial = true - if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil { return nil, err } - return qb.Find(updatedObject.ID) + return qb.Find(ctx, updatedObject.ID) } -func (qb *galleryQueryBuilder) UpdateChecksum(id int, checksum string) error { - return qb.updateMap(id, map[string]interface{}{ +func (qb *galleryQueryBuilder) UpdateChecksum(ctx context.Context, id int, checksum string) error { + return qb.updateMap(ctx, id, map[string]interface{}{ "checksum": checksum, }) } -func (qb *galleryQueryBuilder) UpdateFileModTime(id int, modTime models.NullSQLiteTimestamp) error { - return qb.updateMap(id, map[string]interface{}{ +func (qb *galleryQueryBuilder) UpdateFileModTime(ctx context.Context, id int, modTime models.NullSQLiteTimestamp) error { + return qb.updateMap(ctx, id, map[string]interface{}{ "file_mod_time": modTime, }) } -func (qb *galleryQueryBuilder) Destroy(id int) error { - return qb.destroyExisting([]int{id}) +func (qb *galleryQueryBuilder) Destroy(ctx context.Context, id int) error { + return qb.destroyExisting(ctx, []int{id}) } -func (qb *galleryQueryBuilder) Find(id int) (*models.Gallery, error) { +func (qb *galleryQueryBuilder) Find(ctx context.Context, id int) (*models.Gallery, error) { var ret models.Gallery - if err := qb.get(id, &ret); err != nil { + if err := qb.getByID(ctx, id, &ret); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -84,10 +82,10 @@ func (qb *galleryQueryBuilder) Find(id int) (*models.Gallery, error) { return &ret, nil } -func (qb *galleryQueryBuilder) FindMany(ids []int) ([]*models.Gallery, error) { +func (qb *galleryQueryBuilder) FindMany(ctx context.Context, ids []int) ([]*models.Gallery, error) { var galleries []*models.Gallery for _, id := range ids { - gallery, err := qb.Find(id) + gallery, err := qb.Find(ctx, id) if err != nil { return nil, err } @@ -102,61 +100,61 @@ func (qb *galleryQueryBuilder) FindMany(ids []int) ([]*models.Gallery, error) { return galleries, nil } -func (qb *galleryQueryBuilder) FindByChecksum(checksum string) (*models.Gallery, error) { +func (qb *galleryQueryBuilder) FindByChecksum(ctx context.Context, checksum string) (*models.Gallery, error) { query := "SELECT * FROM galleries WHERE checksum = ? LIMIT 1" args := []interface{}{checksum} - return qb.queryGallery(query, args) + return qb.queryGallery(ctx, query, args) } -func (qb *galleryQueryBuilder) FindByChecksums(checksums []string) ([]*models.Gallery, error) { +func (qb *galleryQueryBuilder) FindByChecksums(ctx context.Context, checksums []string) ([]*models.Gallery, error) { query := "SELECT * FROM galleries WHERE checksum IN " + getInBinding(len(checksums)) var args []interface{} for _, checksum := range checksums { args = append(args, checksum) } - return qb.queryGalleries(query, args) + return qb.queryGalleries(ctx, query, args) } -func (qb *galleryQueryBuilder) FindByPath(path string) (*models.Gallery, error) { +func (qb *galleryQueryBuilder) FindByPath(ctx context.Context, path string) (*models.Gallery, error) { query := "SELECT * FROM galleries WHERE path = ? LIMIT 1" args := []interface{}{path} - return qb.queryGallery(query, args) + return qb.queryGallery(ctx, query, args) } -func (qb *galleryQueryBuilder) FindBySceneID(sceneID int) ([]*models.Gallery, error) { +func (qb *galleryQueryBuilder) FindBySceneID(ctx context.Context, sceneID int) ([]*models.Gallery, error) { query := selectAll(galleryTable) + ` LEFT JOIN scenes_galleries as scenes_join on scenes_join.gallery_id = galleries.id WHERE scenes_join.scene_id = ? GROUP BY galleries.id ` args := []interface{}{sceneID} - return qb.queryGalleries(query, args) + return qb.queryGalleries(ctx, query, args) } -func (qb *galleryQueryBuilder) FindByImageID(imageID int) ([]*models.Gallery, error) { +func (qb *galleryQueryBuilder) FindByImageID(ctx context.Context, imageID int) ([]*models.Gallery, error) { query := selectAll(galleryTable) + ` INNER JOIN galleries_images as images_join on images_join.gallery_id = galleries.id WHERE images_join.image_id = ? GROUP BY galleries.id ` args := []interface{}{imageID} - return qb.queryGalleries(query, args) + return qb.queryGalleries(ctx, query, args) } -func (qb *galleryQueryBuilder) CountByImageID(imageID int) (int, error) { +func (qb *galleryQueryBuilder) CountByImageID(ctx context.Context, imageID int) (int, error) { query := `SELECT image_id FROM galleries_images WHERE image_id = ? GROUP BY gallery_id` args := []interface{}{imageID} - return qb.runCountQuery(qb.buildCountQuery(query), args) + return qb.runCountQuery(ctx, qb.buildCountQuery(query), args) } -func (qb *galleryQueryBuilder) Count() (int, error) { - return qb.runCountQuery(qb.buildCountQuery("SELECT galleries.id FROM galleries"), nil) +func (qb *galleryQueryBuilder) Count(ctx context.Context) (int, error) { + return qb.runCountQuery(ctx, qb.buildCountQuery("SELECT galleries.id FROM galleries"), nil) } -func (qb *galleryQueryBuilder) All() ([]*models.Gallery, error) { - return qb.queryGalleries(selectAll("galleries")+qb.getGallerySort(nil), nil) +func (qb *galleryQueryBuilder) All(ctx context.Context) ([]*models.Gallery, error) { + return qb.queryGalleries(ctx, selectAll("galleries")+qb.getGallerySort(nil), nil) } func (qb *galleryQueryBuilder) validateFilter(galleryFilter *models.GalleryFilterType) error { @@ -190,43 +188,43 @@ func (qb *galleryQueryBuilder) validateFilter(galleryFilter *models.GalleryFilte return nil } -func (qb *galleryQueryBuilder) makeFilter(galleryFilter *models.GalleryFilterType) *filterBuilder { +func (qb *galleryQueryBuilder) makeFilter(ctx context.Context, galleryFilter *models.GalleryFilterType) *filterBuilder { query := &filterBuilder{} if galleryFilter.And != nil { - query.and(qb.makeFilter(galleryFilter.And)) + query.and(qb.makeFilter(ctx, galleryFilter.And)) } if galleryFilter.Or != nil { - query.or(qb.makeFilter(galleryFilter.Or)) + query.or(qb.makeFilter(ctx, galleryFilter.Or)) } if galleryFilter.Not != nil { - query.not(qb.makeFilter(galleryFilter.Not)) + query.not(qb.makeFilter(ctx, galleryFilter.Not)) } - query.handleCriterion(stringCriterionHandler(galleryFilter.Title, "galleries.title")) - query.handleCriterion(stringCriterionHandler(galleryFilter.Details, "galleries.details")) - query.handleCriterion(stringCriterionHandler(galleryFilter.Checksum, "galleries.checksum")) - query.handleCriterion(boolCriterionHandler(galleryFilter.IsZip, "galleries.zip")) - query.handleCriterion(stringCriterionHandler(galleryFilter.Path, "galleries.path")) - query.handleCriterion(intCriterionHandler(galleryFilter.Rating, "galleries.rating")) - query.handleCriterion(stringCriterionHandler(galleryFilter.URL, "galleries.url")) - query.handleCriterion(boolCriterionHandler(galleryFilter.Organized, "galleries.organized")) - query.handleCriterion(galleryIsMissingCriterionHandler(qb, galleryFilter.IsMissing)) - query.handleCriterion(galleryTagsCriterionHandler(qb, galleryFilter.Tags)) - query.handleCriterion(galleryTagCountCriterionHandler(qb, galleryFilter.TagCount)) - query.handleCriterion(galleryPerformersCriterionHandler(qb, galleryFilter.Performers)) - query.handleCriterion(galleryPerformerCountCriterionHandler(qb, galleryFilter.PerformerCount)) - query.handleCriterion(galleryStudioCriterionHandler(qb, galleryFilter.Studios)) - query.handleCriterion(galleryPerformerTagsCriterionHandler(qb, galleryFilter.PerformerTags)) - query.handleCriterion(galleryAverageResolutionCriterionHandler(qb, galleryFilter.AverageResolution)) - query.handleCriterion(galleryImageCountCriterionHandler(qb, galleryFilter.ImageCount)) - query.handleCriterion(galleryPerformerFavoriteCriterionHandler(galleryFilter.PerformerFavorite)) - query.handleCriterion(galleryPerformerAgeCriterionHandler(galleryFilter.PerformerAge)) + query.handleCriterion(ctx, stringCriterionHandler(galleryFilter.Title, "galleries.title")) + query.handleCriterion(ctx, stringCriterionHandler(galleryFilter.Details, "galleries.details")) + query.handleCriterion(ctx, stringCriterionHandler(galleryFilter.Checksum, "galleries.checksum")) + query.handleCriterion(ctx, boolCriterionHandler(galleryFilter.IsZip, "galleries.zip")) + query.handleCriterion(ctx, stringCriterionHandler(galleryFilter.Path, "galleries.path")) + query.handleCriterion(ctx, intCriterionHandler(galleryFilter.Rating, "galleries.rating")) + query.handleCriterion(ctx, stringCriterionHandler(galleryFilter.URL, "galleries.url")) + query.handleCriterion(ctx, boolCriterionHandler(galleryFilter.Organized, "galleries.organized")) + query.handleCriterion(ctx, galleryIsMissingCriterionHandler(qb, galleryFilter.IsMissing)) + query.handleCriterion(ctx, galleryTagsCriterionHandler(qb, galleryFilter.Tags)) + query.handleCriterion(ctx, galleryTagCountCriterionHandler(qb, galleryFilter.TagCount)) + query.handleCriterion(ctx, galleryPerformersCriterionHandler(qb, galleryFilter.Performers)) + query.handleCriterion(ctx, galleryPerformerCountCriterionHandler(qb, galleryFilter.PerformerCount)) + query.handleCriterion(ctx, galleryStudioCriterionHandler(qb, galleryFilter.Studios)) + query.handleCriterion(ctx, galleryPerformerTagsCriterionHandler(qb, galleryFilter.PerformerTags)) + query.handleCriterion(ctx, galleryAverageResolutionCriterionHandler(qb, galleryFilter.AverageResolution)) + query.handleCriterion(ctx, galleryImageCountCriterionHandler(qb, galleryFilter.ImageCount)) + query.handleCriterion(ctx, galleryPerformerFavoriteCriterionHandler(galleryFilter.PerformerFavorite)) + query.handleCriterion(ctx, galleryPerformerAgeCriterionHandler(galleryFilter.PerformerAge)) return query } -func (qb *galleryQueryBuilder) makeQuery(galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) (*queryBuilder, error) { +func (qb *galleryQueryBuilder) makeQuery(ctx context.Context, galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) (*queryBuilder, error) { if galleryFilter == nil { galleryFilter = &models.GalleryFilterType{} } @@ -245,7 +243,7 @@ func (qb *galleryQueryBuilder) makeQuery(galleryFilter *models.GalleryFilterType if err := qb.validateFilter(galleryFilter); err != nil { return nil, err } - filter := qb.makeFilter(galleryFilter) + filter := qb.makeFilter(ctx, galleryFilter) query.addFilter(filter) @@ -254,20 +252,20 @@ func (qb *galleryQueryBuilder) makeQuery(galleryFilter *models.GalleryFilterType return &query, nil } -func (qb *galleryQueryBuilder) Query(galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) ([]*models.Gallery, int, error) { - query, err := qb.makeQuery(galleryFilter, findFilter) +func (qb *galleryQueryBuilder) Query(ctx context.Context, galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) ([]*models.Gallery, int, error) { + query, err := qb.makeQuery(ctx, galleryFilter, findFilter) if err != nil { return nil, 0, err } - idsResult, countResult, err := query.executeFind() + idsResult, countResult, err := query.executeFind(ctx) if err != nil { return nil, 0, err } var galleries []*models.Gallery for _, id := range idsResult { - gallery, err := qb.Find(id) + gallery, err := qb.Find(ctx, id) if err != nil { return nil, 0, err } @@ -278,17 +276,17 @@ func (qb *galleryQueryBuilder) Query(galleryFilter *models.GalleryFilterType, fi return galleries, countResult, nil } -func (qb *galleryQueryBuilder) QueryCount(galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) (int, error) { - query, err := qb.makeQuery(galleryFilter, findFilter) +func (qb *galleryQueryBuilder) QueryCount(ctx context.Context, galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) (int, error) { + query, err := qb.makeQuery(ctx, galleryFilter, findFilter) if err != nil { return 0, err } - return query.executeCount() + return query.executeCount(ctx) } func galleryIsMissingCriterionHandler(qb *galleryQueryBuilder, isMissing *string) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if isMissing != nil && *isMissing != "" { switch *isMissing { case "scenes": @@ -389,7 +387,7 @@ func galleryStudioCriterionHandler(qb *galleryQueryBuilder, studios *models.Hier } func galleryPerformerTagsCriterionHandler(qb *galleryQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if tags != nil { if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull { var notClause string @@ -408,7 +406,7 @@ func galleryPerformerTagsCriterionHandler(qb *galleryQueryBuilder, tags *models. return } - valuesClause := getHierarchicalValues(qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth) + valuesClause := getHierarchicalValues(ctx, qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth) f.addWith(`performer_tags AS ( SELECT pg.gallery_id, t.column1 AS root_tag_id FROM performers_galleries pg @@ -424,7 +422,7 @@ INNER JOIN (` + valuesClause + `) t ON t.column2 = pt.tag_id } func galleryPerformerFavoriteCriterionHandler(performerfavorite *bool) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if performerfavorite != nil { f.addLeftJoin("performers_galleries", "", "galleries.id = performers_galleries.gallery_id") @@ -444,7 +442,7 @@ GROUP BY performers_galleries.gallery_id HAVING SUM(performers.favorite) = 0)`, } func galleryPerformerAgeCriterionHandler(performerAge *models.IntCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if performerAge != nil { f.addInnerJoin("performers_galleries", "", "galleries.id = performers_galleries.gallery_id") f.addInnerJoin("performers", "", "performers_galleries.performer_id = performers.id") @@ -461,7 +459,7 @@ func galleryPerformerAgeCriterionHandler(performerAge *models.IntCriterionInput) } func galleryAverageResolutionCriterionHandler(qb *galleryQueryBuilder, resolution *models.ResolutionCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if resolution != nil && resolution.Value.IsValid() { qb.imagesRepository().join(f, "images_join", "galleries.id") f.addLeftJoin("images", "", "images_join.image_id = images.id") @@ -505,17 +503,17 @@ func (qb *galleryQueryBuilder) getGallerySort(findFilter *models.FindFilterType) } } -func (qb *galleryQueryBuilder) queryGallery(query string, args []interface{}) (*models.Gallery, error) { - results, err := qb.queryGalleries(query, args) +func (qb *galleryQueryBuilder) queryGallery(ctx context.Context, query string, args []interface{}) (*models.Gallery, error) { + results, err := qb.queryGalleries(ctx, query, args) if err != nil || len(results) < 1 { return nil, err } return results[0], nil } -func (qb *galleryQueryBuilder) queryGalleries(query string, args []interface{}) ([]*models.Gallery, error) { +func (qb *galleryQueryBuilder) queryGalleries(ctx context.Context, query string, args []interface{}) ([]*models.Gallery, error) { var ret models.Galleries - if err := qb.query(query, args, &ret); err != nil { + if err := qb.query(ctx, query, args, &ret); err != nil { return nil, err } @@ -533,13 +531,13 @@ func (qb *galleryQueryBuilder) performersRepository() *joinRepository { } } -func (qb *galleryQueryBuilder) GetPerformerIDs(galleryID int) ([]int, error) { - return qb.performersRepository().getIDs(galleryID) +func (qb *galleryQueryBuilder) GetPerformerIDs(ctx context.Context, galleryID int) ([]int, error) { + return qb.performersRepository().getIDs(ctx, galleryID) } -func (qb *galleryQueryBuilder) UpdatePerformers(galleryID int, performerIDs []int) error { +func (qb *galleryQueryBuilder) UpdatePerformers(ctx context.Context, galleryID int, performerIDs []int) error { // Delete the existing joins and then create new ones - return qb.performersRepository().replace(galleryID, performerIDs) + return qb.performersRepository().replace(ctx, galleryID, performerIDs) } func (qb *galleryQueryBuilder) tagsRepository() *joinRepository { @@ -553,13 +551,13 @@ func (qb *galleryQueryBuilder) tagsRepository() *joinRepository { } } -func (qb *galleryQueryBuilder) GetTagIDs(galleryID int) ([]int, error) { - return qb.tagsRepository().getIDs(galleryID) +func (qb *galleryQueryBuilder) GetTagIDs(ctx context.Context, galleryID int) ([]int, error) { + return qb.tagsRepository().getIDs(ctx, galleryID) } -func (qb *galleryQueryBuilder) UpdateTags(galleryID int, tagIDs []int) error { +func (qb *galleryQueryBuilder) UpdateTags(ctx context.Context, galleryID int, tagIDs []int) error { // Delete the existing joins and then create new ones - return qb.tagsRepository().replace(galleryID, tagIDs) + return qb.tagsRepository().replace(ctx, galleryID, tagIDs) } func (qb *galleryQueryBuilder) imagesRepository() *joinRepository { @@ -573,13 +571,13 @@ func (qb *galleryQueryBuilder) imagesRepository() *joinRepository { } } -func (qb *galleryQueryBuilder) GetImageIDs(galleryID int) ([]int, error) { - return qb.imagesRepository().getIDs(galleryID) +func (qb *galleryQueryBuilder) GetImageIDs(ctx context.Context, galleryID int) ([]int, error) { + return qb.imagesRepository().getIDs(ctx, galleryID) } -func (qb *galleryQueryBuilder) UpdateImages(galleryID int, imageIDs []int) error { +func (qb *galleryQueryBuilder) UpdateImages(ctx context.Context, galleryID int, imageIDs []int) error { // Delete the existing joins and then create new ones - return qb.imagesRepository().replace(galleryID, imageIDs) + return qb.imagesRepository().replace(ctx, galleryID, imageIDs) } func (qb *galleryQueryBuilder) scenesRepository() *joinRepository { @@ -593,11 +591,11 @@ func (qb *galleryQueryBuilder) scenesRepository() *joinRepository { } } -func (qb *galleryQueryBuilder) GetSceneIDs(galleryID int) ([]int, error) { - return qb.scenesRepository().getIDs(galleryID) +func (qb *galleryQueryBuilder) GetSceneIDs(ctx context.Context, galleryID int) ([]int, error) { + return qb.scenesRepository().getIDs(ctx, galleryID) } -func (qb *galleryQueryBuilder) UpdateScenes(galleryID int, sceneIDs []int) error { +func (qb *galleryQueryBuilder) UpdateScenes(ctx context.Context, galleryID int, sceneIDs []int) error { // Delete the existing joins and then create new ones - return qb.scenesRepository().replace(galleryID, sceneIDs) + return qb.scenesRepository().replace(ctx, galleryID, sceneIDs) } diff --git a/pkg/sqlite/gallery_test.go b/pkg/sqlite/gallery_test.go index f9aa9ef5e..ae2cbe21b 100644 --- a/pkg/sqlite/gallery_test.go +++ b/pkg/sqlite/gallery_test.go @@ -4,6 +4,7 @@ package sqlite_test import ( + "context" "math" "strconv" "testing" @@ -11,14 +12,15 @@ import ( "github.com/stretchr/testify/assert" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/sqlite" ) func TestGalleryFind(t *testing.T) { - withTxn(func(r models.Repository) error { - gqb := r.Gallery() + withTxn(func(ctx context.Context) error { + gqb := sqlite.GalleryReaderWriter const galleryIdx = 0 - gallery, err := gqb.Find(galleryIDs[galleryIdx]) + gallery, err := gqb.Find(ctx, galleryIDs[galleryIdx]) if err != nil { t.Errorf("Error finding gallery: %s", err.Error()) @@ -26,7 +28,7 @@ func TestGalleryFind(t *testing.T) { assert.Equal(t, getGalleryStringValue(galleryIdx, "Path"), gallery.Path.String) - gallery, err = gqb.Find(0) + gallery, err = gqb.Find(ctx, 0) if err != nil { t.Errorf("Error finding gallery: %s", err.Error()) @@ -39,12 +41,12 @@ func TestGalleryFind(t *testing.T) { } func TestGalleryFindByChecksum(t *testing.T) { - withTxn(func(r models.Repository) error { - gqb := r.Gallery() + withTxn(func(ctx context.Context) error { + gqb := sqlite.GalleryReaderWriter const galleryIdx = 0 galleryChecksum := getGalleryStringValue(galleryIdx, "Checksum") - gallery, err := gqb.FindByChecksum(galleryChecksum) + gallery, err := gqb.FindByChecksum(ctx, galleryChecksum) if err != nil { t.Errorf("Error finding gallery: %s", err.Error()) @@ -53,7 +55,7 @@ func TestGalleryFindByChecksum(t *testing.T) { assert.Equal(t, getGalleryStringValue(galleryIdx, "Path"), gallery.Path.String) galleryChecksum = "not exist" - gallery, err = gqb.FindByChecksum(galleryChecksum) + gallery, err = gqb.FindByChecksum(ctx, galleryChecksum) if err != nil { t.Errorf("Error finding gallery: %s", err.Error()) @@ -66,12 +68,12 @@ func TestGalleryFindByChecksum(t *testing.T) { } func TestGalleryFindByPath(t *testing.T) { - withTxn(func(r models.Repository) error { - gqb := r.Gallery() + withTxn(func(ctx context.Context) error { + gqb := sqlite.GalleryReaderWriter const galleryIdx = 0 galleryPath := getGalleryStringValue(galleryIdx, "Path") - gallery, err := gqb.FindByPath(galleryPath) + gallery, err := gqb.FindByPath(ctx, galleryPath) if err != nil { t.Errorf("Error finding gallery: %s", err.Error()) @@ -80,7 +82,7 @@ func TestGalleryFindByPath(t *testing.T) { assert.Equal(t, galleryPath, gallery.Path.String) galleryPath = "not exist" - gallery, err = gqb.FindByPath(galleryPath) + gallery, err = gqb.FindByPath(ctx, galleryPath) if err != nil { t.Errorf("Error finding gallery: %s", err.Error()) @@ -93,11 +95,11 @@ func TestGalleryFindByPath(t *testing.T) { } func TestGalleryFindBySceneID(t *testing.T) { - withTxn(func(r models.Repository) error { - gqb := r.Gallery() + withTxn(func(ctx context.Context) error { + gqb := sqlite.GalleryReaderWriter sceneID := sceneIDs[sceneIdxWithGallery] - galleries, err := gqb.FindBySceneID(sceneID) + galleries, err := gqb.FindBySceneID(ctx, sceneID) if err != nil { t.Errorf("Error finding gallery: %s", err.Error()) @@ -105,7 +107,7 @@ func TestGalleryFindBySceneID(t *testing.T) { assert.Equal(t, getGalleryStringValue(galleryIdxWithScene, "Path"), galleries[0].Path.String) - galleries, err = gqb.FindBySceneID(0) + galleries, err = gqb.FindBySceneID(ctx, 0) if err != nil { t.Errorf("Error finding gallery: %s", err.Error()) @@ -118,24 +120,24 @@ func TestGalleryFindBySceneID(t *testing.T) { } func TestGalleryQueryQ(t *testing.T) { - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { const galleryIdx = 0 q := getGalleryStringValue(galleryIdx, pathField) - sqb := r.Gallery() + sqb := sqlite.GalleryReaderWriter - galleryQueryQ(t, sqb, q, galleryIdx) + galleryQueryQ(ctx, t, sqb, q, galleryIdx) return nil }) } -func galleryQueryQ(t *testing.T, qb models.GalleryReader, q string, expectedGalleryIdx int) { +func galleryQueryQ(ctx context.Context, t *testing.T, qb models.GalleryReader, q string, expectedGalleryIdx int) { filter := models.FindFilterType{ Q: &q, } - galleries, _, err := qb.Query(nil, &filter) + galleries, _, err := qb.Query(ctx, nil, &filter) if err != nil { t.Errorf("Error querying gallery: %s", err.Error()) } @@ -146,7 +148,7 @@ func galleryQueryQ(t *testing.T, qb models.GalleryReader, q string, expectedGall // no Q should return all results filter.Q = nil - galleries, _, err = qb.Query(nil, &filter) + galleries, _, err = qb.Query(ctx, nil, &filter) if err != nil { t.Errorf("Error querying gallery: %s", err.Error()) } @@ -155,7 +157,7 @@ func galleryQueryQ(t *testing.T, qb models.GalleryReader, q string, expectedGall } func TestGalleryQueryPath(t *testing.T) { - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { const galleryIdx = 1 galleryPath := getGalleryStringValue(galleryIdx, "Path") @@ -164,28 +166,28 @@ func TestGalleryQueryPath(t *testing.T) { Modifier: models.CriterionModifierEquals, } - verifyGalleriesPath(t, r.Gallery(), pathCriterion) + verifyGalleriesPath(ctx, t, sqlite.GalleryReaderWriter, pathCriterion) pathCriterion.Modifier = models.CriterionModifierNotEquals - verifyGalleriesPath(t, r.Gallery(), pathCriterion) + verifyGalleriesPath(ctx, t, sqlite.GalleryReaderWriter, pathCriterion) pathCriterion.Modifier = models.CriterionModifierMatchesRegex pathCriterion.Value = "gallery.*1_Path" - verifyGalleriesPath(t, r.Gallery(), pathCriterion) + verifyGalleriesPath(ctx, t, sqlite.GalleryReaderWriter, pathCriterion) pathCriterion.Modifier = models.CriterionModifierNotMatchesRegex - verifyGalleriesPath(t, r.Gallery(), pathCriterion) + verifyGalleriesPath(ctx, t, sqlite.GalleryReaderWriter, pathCriterion) return nil }) } -func verifyGalleriesPath(t *testing.T, sqb models.GalleryReader, pathCriterion models.StringCriterionInput) { +func verifyGalleriesPath(ctx context.Context, t *testing.T, sqb models.GalleryReader, pathCriterion models.StringCriterionInput) { galleryFilter := models.GalleryFilterType{ Path: &pathCriterion, } - galleries, _, err := sqb.Query(&galleryFilter, nil) + galleries, _, err := sqb.Query(ctx, &galleryFilter, nil) if err != nil { t.Errorf("Error querying gallery: %s", err.Error()) } @@ -215,10 +217,10 @@ func TestGalleryQueryPathOr(t *testing.T) { }, } - withTxn(func(r models.Repository) error { - sqb := r.Gallery() + withTxn(func(ctx context.Context) error { + sqb := sqlite.GalleryReaderWriter - galleries := queryGallery(t, sqb, &galleryFilter, nil) + galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil) assert.Len(t, galleries, 2) assert.Equal(t, gallery1Path, galleries[0].Path.String) @@ -246,10 +248,10 @@ func TestGalleryQueryPathAndRating(t *testing.T) { }, } - withTxn(func(r models.Repository) error { - sqb := r.Gallery() + withTxn(func(ctx context.Context) error { + sqb := sqlite.GalleryReaderWriter - galleries := queryGallery(t, sqb, &galleryFilter, nil) + galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil) assert.Len(t, galleries, 1) assert.Equal(t, galleryPath, galleries[0].Path.String) @@ -281,10 +283,10 @@ func TestGalleryQueryPathNotRating(t *testing.T) { }, } - withTxn(func(r models.Repository) error { - sqb := r.Gallery() + withTxn(func(ctx context.Context) error { + sqb := sqlite.GalleryReaderWriter - galleries := queryGallery(t, sqb, &galleryFilter, nil) + galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil) for _, gallery := range galleries { verifyNullString(t, gallery.Path, pathCriterion) @@ -312,20 +314,20 @@ func TestGalleryIllegalQuery(t *testing.T) { Or: &subFilter, } - withTxn(func(r models.Repository) error { - sqb := r.Gallery() + withTxn(func(ctx context.Context) error { + sqb := sqlite.GalleryReaderWriter - _, _, err := sqb.Query(galleryFilter, nil) + _, _, err := sqb.Query(ctx, galleryFilter, nil) assert.NotNil(err) galleryFilter.Or = nil galleryFilter.Not = &subFilter - _, _, err = sqb.Query(galleryFilter, nil) + _, _, err = sqb.Query(ctx, galleryFilter, nil) assert.NotNil(err) galleryFilter.And = nil galleryFilter.Or = &subFilter - _, _, err = sqb.Query(galleryFilter, nil) + _, _, err = sqb.Query(ctx, galleryFilter, nil) assert.NotNil(err) return nil @@ -371,11 +373,11 @@ func TestGalleryQueryURL(t *testing.T) { } func verifyGalleryQuery(t *testing.T, filter models.GalleryFilterType, verifyFn func(s *models.Gallery)) { - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { t.Helper() - sqb := r.Gallery() + sqb := sqlite.GalleryReaderWriter - galleries := queryGallery(t, sqb, &filter, nil) + galleries := queryGallery(ctx, t, sqb, &filter, nil) // assume it should find at least one assert.Greater(t, len(galleries), 0) @@ -414,13 +416,13 @@ func TestGalleryQueryRating(t *testing.T) { } func verifyGalleriesRating(t *testing.T, ratingCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Gallery() + withTxn(func(ctx context.Context) error { + sqb := sqlite.GalleryReaderWriter galleryFilter := models.GalleryFilterType{ Rating: &ratingCriterion, } - galleries, _, err := sqb.Query(&galleryFilter, nil) + galleries, _, err := sqb.Query(ctx, &galleryFilter, nil) if err != nil { t.Errorf("Error querying gallery: %s", err.Error()) } @@ -434,8 +436,8 @@ func verifyGalleriesRating(t *testing.T, ratingCriterion models.IntCriterionInpu } func TestGalleryQueryIsMissingScene(t *testing.T) { - withTxn(func(r models.Repository) error { - qb := r.Gallery() + withTxn(func(ctx context.Context) error { + qb := sqlite.GalleryReaderWriter isMissing := "scenes" galleryFilter := models.GalleryFilterType{ IsMissing: &isMissing, @@ -446,7 +448,7 @@ func TestGalleryQueryIsMissingScene(t *testing.T) { Q: &q, } - galleries, _, err := qb.Query(&galleryFilter, &findFilter) + galleries, _, err := qb.Query(ctx, &galleryFilter, &findFilter) if err != nil { t.Errorf("Error querying gallery: %s", err.Error()) } @@ -454,7 +456,7 @@ func TestGalleryQueryIsMissingScene(t *testing.T) { assert.Len(t, galleries, 0) findFilter.Q = nil - galleries, _, err = qb.Query(&galleryFilter, &findFilter) + galleries, _, err = qb.Query(ctx, &galleryFilter, &findFilter) if err != nil { t.Errorf("Error querying gallery: %s", err.Error()) } @@ -468,8 +470,8 @@ func TestGalleryQueryIsMissingScene(t *testing.T) { }) } -func queryGallery(t *testing.T, sqb models.GalleryReader, galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) []*models.Gallery { - galleries, _, err := sqb.Query(galleryFilter, findFilter) +func queryGallery(ctx context.Context, t *testing.T, sqb models.GalleryReader, galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) []*models.Gallery { + galleries, _, err := sqb.Query(ctx, galleryFilter, findFilter) if err != nil { t.Errorf("Error querying gallery: %s", err.Error()) } @@ -478,8 +480,8 @@ func queryGallery(t *testing.T, sqb models.GalleryReader, galleryFilter *models. } func TestGalleryQueryIsMissingStudio(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Gallery() + withTxn(func(ctx context.Context) error { + sqb := sqlite.GalleryReaderWriter isMissing := "studio" galleryFilter := models.GalleryFilterType{ IsMissing: &isMissing, @@ -490,12 +492,12 @@ func TestGalleryQueryIsMissingStudio(t *testing.T) { Q: &q, } - galleries := queryGallery(t, sqb, &galleryFilter, &findFilter) + galleries := queryGallery(ctx, t, sqb, &galleryFilter, &findFilter) assert.Len(t, galleries, 0) findFilter.Q = nil - galleries = queryGallery(t, sqb, &galleryFilter, &findFilter) + galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter) // ensure non of the ids equal the one with studio for _, gallery := range galleries { @@ -507,8 +509,8 @@ func TestGalleryQueryIsMissingStudio(t *testing.T) { } func TestGalleryQueryIsMissingPerformers(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Gallery() + withTxn(func(ctx context.Context) error { + sqb := sqlite.GalleryReaderWriter isMissing := "performers" galleryFilter := models.GalleryFilterType{ IsMissing: &isMissing, @@ -519,12 +521,12 @@ func TestGalleryQueryIsMissingPerformers(t *testing.T) { Q: &q, } - galleries := queryGallery(t, sqb, &galleryFilter, &findFilter) + galleries := queryGallery(ctx, t, sqb, &galleryFilter, &findFilter) assert.Len(t, galleries, 0) findFilter.Q = nil - galleries = queryGallery(t, sqb, &galleryFilter, &findFilter) + galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter) assert.True(t, len(galleries) > 0) @@ -538,8 +540,8 @@ func TestGalleryQueryIsMissingPerformers(t *testing.T) { } func TestGalleryQueryIsMissingTags(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Gallery() + withTxn(func(ctx context.Context) error { + sqb := sqlite.GalleryReaderWriter isMissing := "tags" galleryFilter := models.GalleryFilterType{ IsMissing: &isMissing, @@ -550,12 +552,12 @@ func TestGalleryQueryIsMissingTags(t *testing.T) { Q: &q, } - galleries := queryGallery(t, sqb, &galleryFilter, &findFilter) + galleries := queryGallery(ctx, t, sqb, &galleryFilter, &findFilter) assert.Len(t, galleries, 0) findFilter.Q = nil - galleries = queryGallery(t, sqb, &galleryFilter, &findFilter) + galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter) assert.True(t, len(galleries) > 0) @@ -564,14 +566,14 @@ func TestGalleryQueryIsMissingTags(t *testing.T) { } func TestGalleryQueryIsMissingDate(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Gallery() + withTxn(func(ctx context.Context) error { + sqb := sqlite.GalleryReaderWriter isMissing := "date" galleryFilter := models.GalleryFilterType{ IsMissing: &isMissing, } - galleries := queryGallery(t, sqb, &galleryFilter, nil) + galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil) // three in four scenes have no date assert.Len(t, galleries, int(math.Ceil(float64(totalGalleries)/4*3))) @@ -586,8 +588,8 @@ func TestGalleryQueryIsMissingDate(t *testing.T) { } func TestGalleryQueryPerformers(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Gallery() + withTxn(func(ctx context.Context) error { + sqb := sqlite.GalleryReaderWriter performerCriterion := models.MultiCriterionInput{ Value: []string{ strconv.Itoa(performerIDs[performerIdxWithGallery]), @@ -600,7 +602,7 @@ func TestGalleryQueryPerformers(t *testing.T) { Performers: &performerCriterion, } - galleries := queryGallery(t, sqb, &galleryFilter, nil) + galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil) assert.Len(t, galleries, 2) @@ -617,7 +619,7 @@ func TestGalleryQueryPerformers(t *testing.T) { Modifier: models.CriterionModifierIncludesAll, } - galleries = queryGallery(t, sqb, &galleryFilter, nil) + galleries = queryGallery(ctx, t, sqb, &galleryFilter, nil) assert.Len(t, galleries, 1) assert.Equal(t, galleryIDs[galleryIdxWithTwoPerformers], galleries[0].ID) @@ -634,7 +636,7 @@ func TestGalleryQueryPerformers(t *testing.T) { Q: &q, } - galleries = queryGallery(t, sqb, &galleryFilter, &findFilter) + galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter) assert.Len(t, galleries, 0) return nil @@ -642,8 +644,8 @@ func TestGalleryQueryPerformers(t *testing.T) { } func TestGalleryQueryTags(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Gallery() + withTxn(func(ctx context.Context) error { + sqb := sqlite.GalleryReaderWriter tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithGallery]), @@ -656,7 +658,7 @@ func TestGalleryQueryTags(t *testing.T) { Tags: &tagCriterion, } - galleries := queryGallery(t, sqb, &galleryFilter, nil) + galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil) assert.Len(t, galleries, 2) // ensure ids are correct @@ -672,7 +674,7 @@ func TestGalleryQueryTags(t *testing.T) { Modifier: models.CriterionModifierIncludesAll, } - galleries = queryGallery(t, sqb, &galleryFilter, nil) + galleries = queryGallery(ctx, t, sqb, &galleryFilter, nil) assert.Len(t, galleries, 1) assert.Equal(t, galleryIDs[galleryIdxWithTwoTags], galleries[0].ID) @@ -689,7 +691,7 @@ func TestGalleryQueryTags(t *testing.T) { Q: &q, } - galleries = queryGallery(t, sqb, &galleryFilter, &findFilter) + galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter) assert.Len(t, galleries, 0) return nil @@ -697,8 +699,8 @@ func TestGalleryQueryTags(t *testing.T) { } func TestGalleryQueryStudio(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Gallery() + withTxn(func(ctx context.Context) error { + sqb := sqlite.GalleryReaderWriter studioCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(studioIDs[studioIdxWithGallery]), @@ -710,7 +712,7 @@ func TestGalleryQueryStudio(t *testing.T) { Studios: &studioCriterion, } - galleries := queryGallery(t, sqb, &galleryFilter, nil) + galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil) assert.Len(t, galleries, 1) @@ -729,7 +731,7 @@ func TestGalleryQueryStudio(t *testing.T) { Q: &q, } - galleries = queryGallery(t, sqb, &galleryFilter, &findFilter) + galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter) assert.Len(t, galleries, 0) return nil @@ -737,8 +739,8 @@ func TestGalleryQueryStudio(t *testing.T) { } func TestGalleryQueryStudioDepth(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Gallery() + withTxn(func(ctx context.Context) error { + sqb := sqlite.GalleryReaderWriter depth := 2 studioCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ @@ -752,16 +754,16 @@ func TestGalleryQueryStudioDepth(t *testing.T) { Studios: &studioCriterion, } - galleries := queryGallery(t, sqb, &galleryFilter, nil) + galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil) assert.Len(t, galleries, 1) depth = 1 - galleries = queryGallery(t, sqb, &galleryFilter, nil) + galleries = queryGallery(ctx, t, sqb, &galleryFilter, nil) assert.Len(t, galleries, 0) studioCriterion.Value = []string{strconv.Itoa(studioIDs[studioIdxWithParentAndChild])} - galleries = queryGallery(t, sqb, &galleryFilter, nil) + galleries = queryGallery(ctx, t, sqb, &galleryFilter, nil) assert.Len(t, galleries, 1) // ensure id is correct @@ -782,15 +784,15 @@ func TestGalleryQueryStudioDepth(t *testing.T) { Q: &q, } - galleries = queryGallery(t, sqb, &galleryFilter, &findFilter) + galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter) assert.Len(t, galleries, 0) depth = 1 - galleries = queryGallery(t, sqb, &galleryFilter, &findFilter) + galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter) assert.Len(t, galleries, 1) studioCriterion.Value = []string{strconv.Itoa(studioIDs[studioIdxWithParentAndChild])} - galleries = queryGallery(t, sqb, &galleryFilter, &findFilter) + galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter) assert.Len(t, galleries, 0) return nil @@ -798,8 +800,8 @@ func TestGalleryQueryStudioDepth(t *testing.T) { } func TestGalleryQueryPerformerTags(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Gallery() + withTxn(func(ctx context.Context) error { + sqb := sqlite.GalleryReaderWriter tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithPerformer]), @@ -812,7 +814,7 @@ func TestGalleryQueryPerformerTags(t *testing.T) { PerformerTags: &tagCriterion, } - galleries := queryGallery(t, sqb, &galleryFilter, nil) + galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil) assert.Len(t, galleries, 2) // ensure ids are correct @@ -828,7 +830,7 @@ func TestGalleryQueryPerformerTags(t *testing.T) { Modifier: models.CriterionModifierIncludesAll, } - galleries = queryGallery(t, sqb, &galleryFilter, nil) + galleries = queryGallery(ctx, t, sqb, &galleryFilter, nil) assert.Len(t, galleries, 1) assert.Equal(t, galleryIDs[galleryIdxWithPerformerTwoTags], galleries[0].ID) @@ -845,7 +847,7 @@ func TestGalleryQueryPerformerTags(t *testing.T) { Q: &q, } - galleries = queryGallery(t, sqb, &galleryFilter, &findFilter) + galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter) assert.Len(t, galleries, 0) tagCriterion = models.HierarchicalMultiCriterionInput{ @@ -853,22 +855,22 @@ func TestGalleryQueryPerformerTags(t *testing.T) { } q = getGalleryStringValue(galleryIdx1WithImage, titleField) - galleries = queryGallery(t, sqb, &galleryFilter, &findFilter) + galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter) assert.Len(t, galleries, 1) assert.Equal(t, galleryIDs[galleryIdx1WithImage], galleries[0].ID) q = getGalleryStringValue(galleryIdxWithPerformerTag, titleField) - galleries = queryGallery(t, sqb, &galleryFilter, &findFilter) + galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter) assert.Len(t, galleries, 0) tagCriterion.Modifier = models.CriterionModifierNotNull - galleries = queryGallery(t, sqb, &galleryFilter, &findFilter) + galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter) assert.Len(t, galleries, 1) assert.Equal(t, galleryIDs[galleryIdxWithPerformerTag], galleries[0].ID) q = getGalleryStringValue(galleryIdx1WithImage, titleField) - galleries = queryGallery(t, sqb, &galleryFilter, &findFilter) + galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter) assert.Len(t, galleries, 0) return nil @@ -895,17 +897,17 @@ func TestGalleryQueryTagCount(t *testing.T) { } func verifyGalleriesTagCount(t *testing.T, tagCountCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Gallery() + withTxn(func(ctx context.Context) error { + sqb := sqlite.GalleryReaderWriter galleryFilter := models.GalleryFilterType{ TagCount: &tagCountCriterion, } - galleries := queryGallery(t, sqb, &galleryFilter, nil) + galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil) assert.Greater(t, len(galleries), 0) for _, gallery := range galleries { - ids, err := sqb.GetTagIDs(gallery.ID) + ids, err := sqb.GetTagIDs(ctx, gallery.ID) if err != nil { return err } @@ -936,17 +938,17 @@ func TestGalleryQueryPerformerCount(t *testing.T) { } func verifyGalleriesPerformerCount(t *testing.T, performerCountCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Gallery() + withTxn(func(ctx context.Context) error { + sqb := sqlite.GalleryReaderWriter galleryFilter := models.GalleryFilterType{ PerformerCount: &performerCountCriterion, } - galleries := queryGallery(t, sqb, &galleryFilter, nil) + galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil) assert.Greater(t, len(galleries), 0) for _, gallery := range galleries { - ids, err := sqb.GetPerformerIDs(gallery.ID) + ids, err := sqb.GetPerformerIDs(ctx, gallery.ID) if err != nil { return err } @@ -958,8 +960,8 @@ func verifyGalleriesPerformerCount(t *testing.T, performerCountCriterion models. } func TestGalleryQueryAverageResolution(t *testing.T) { - withTxn(func(r models.Repository) error { - qb := r.Gallery() + withTxn(func(ctx context.Context) error { + qb := sqlite.GalleryReaderWriter resolution := models.ResolutionEnumLow galleryFilter := models.GalleryFilterType{ AverageResolution: &models.ResolutionCriterionInput{ @@ -969,7 +971,7 @@ func TestGalleryQueryAverageResolution(t *testing.T) { } // not verifying average - just ensure we get at least one - galleries := queryGallery(t, qb, &galleryFilter, nil) + galleries := queryGallery(ctx, t, qb, &galleryFilter, nil) assert.Greater(t, len(galleries), 0) return nil @@ -996,19 +998,19 @@ func TestGalleryQueryImageCount(t *testing.T) { } func verifyGalleriesImageCount(t *testing.T, imageCountCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Gallery() + withTxn(func(ctx context.Context) error { + sqb := sqlite.GalleryReaderWriter galleryFilter := models.GalleryFilterType{ ImageCount: &imageCountCriterion, } - galleries := queryGallery(t, sqb, &galleryFilter, nil) + galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil) assert.Greater(t, len(galleries), -1) for _, gallery := range galleries { pp := 0 - result, err := r.Image().Query(models.ImageQueryOptions{ + result, err := sqlite.ImageReaderWriter.Query(ctx, models.ImageQueryOptions{ QueryOptions: models.QueryOptions{ FindFilter: &models.FindFilterType{ PerPage: &pp, diff --git a/pkg/sqlite/image.go b/pkg/sqlite/image.go index d2b3adb8f..3238595d7 100644 --- a/pkg/sqlite/image.go +++ b/pkg/sqlite/image.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "database/sql" "errors" "fmt" @@ -29,45 +30,42 @@ type imageQueryBuilder struct { repository } -func NewImageReaderWriter(tx dbi) *imageQueryBuilder { - return &imageQueryBuilder{ - repository{ - tx: tx, - tableName: imageTable, - idColumn: idColumn, - }, - } +var ImageReaderWriter = &imageQueryBuilder{ + repository{ + tableName: imageTable, + idColumn: idColumn, + }, } -func (qb *imageQueryBuilder) Create(newObject models.Image) (*models.Image, error) { +func (qb *imageQueryBuilder) Create(ctx context.Context, newObject models.Image) (*models.Image, error) { var ret models.Image - if err := qb.insertObject(newObject, &ret); err != nil { + if err := qb.insertObject(ctx, newObject, &ret); err != nil { return nil, err } return &ret, nil } -func (qb *imageQueryBuilder) Update(updatedObject models.ImagePartial) (*models.Image, error) { +func (qb *imageQueryBuilder) Update(ctx context.Context, updatedObject models.ImagePartial) (*models.Image, error) { const partial = true - if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil { return nil, err } - return qb.find(updatedObject.ID) + return qb.find(ctx, updatedObject.ID) } -func (qb *imageQueryBuilder) UpdateFull(updatedObject models.Image) (*models.Image, error) { +func (qb *imageQueryBuilder) UpdateFull(ctx context.Context, updatedObject models.Image) (*models.Image, error) { const partial = false - if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil { return nil, err } - return qb.find(updatedObject.ID) + return qb.find(ctx, updatedObject.ID) } -func (qb *imageQueryBuilder) IncrementOCounter(id int) (int, error) { - _, err := qb.tx.Exec( +func (qb *imageQueryBuilder) IncrementOCounter(ctx context.Context, id int) (int, error) { + _, err := qb.tx.Exec(ctx, `UPDATE `+imageTable+` SET o_counter = o_counter + 1 WHERE `+imageTable+`.id = ?`, id, ) @@ -75,7 +73,7 @@ func (qb *imageQueryBuilder) IncrementOCounter(id int) (int, error) { return 0, err } - image, err := qb.find(id) + image, err := qb.find(ctx, id) if err != nil { return 0, err } @@ -83,8 +81,8 @@ func (qb *imageQueryBuilder) IncrementOCounter(id int) (int, error) { return image.OCounter, nil } -func (qb *imageQueryBuilder) DecrementOCounter(id int) (int, error) { - _, err := qb.tx.Exec( +func (qb *imageQueryBuilder) DecrementOCounter(ctx context.Context, id int) (int, error) { + _, err := qb.tx.Exec(ctx, `UPDATE `+imageTable+` SET o_counter = o_counter - 1 WHERE `+imageTable+`.id = ? and `+imageTable+`.o_counter > 0`, id, ) @@ -92,7 +90,7 @@ func (qb *imageQueryBuilder) DecrementOCounter(id int) (int, error) { return 0, err } - image, err := qb.find(id) + image, err := qb.find(ctx, id) if err != nil { return 0, err } @@ -100,8 +98,8 @@ func (qb *imageQueryBuilder) DecrementOCounter(id int) (int, error) { return image.OCounter, nil } -func (qb *imageQueryBuilder) ResetOCounter(id int) (int, error) { - _, err := qb.tx.Exec( +func (qb *imageQueryBuilder) ResetOCounter(ctx context.Context, id int) (int, error) { + _, err := qb.tx.Exec(ctx, `UPDATE `+imageTable+` SET o_counter = 0 WHERE `+imageTable+`.id = ?`, id, ) @@ -109,7 +107,7 @@ func (qb *imageQueryBuilder) ResetOCounter(id int) (int, error) { return 0, err } - image, err := qb.find(id) + image, err := qb.find(ctx, id) if err != nil { return 0, err } @@ -117,18 +115,18 @@ func (qb *imageQueryBuilder) ResetOCounter(id int) (int, error) { return image.OCounter, nil } -func (qb *imageQueryBuilder) Destroy(id int) error { - return qb.destroyExisting([]int{id}) +func (qb *imageQueryBuilder) Destroy(ctx context.Context, id int) error { + return qb.destroyExisting(ctx, []int{id}) } -func (qb *imageQueryBuilder) Find(id int) (*models.Image, error) { - return qb.find(id) +func (qb *imageQueryBuilder) Find(ctx context.Context, id int) (*models.Image, error) { + return qb.find(ctx, id) } -func (qb *imageQueryBuilder) FindMany(ids []int) ([]*models.Image, error) { +func (qb *imageQueryBuilder) FindMany(ctx context.Context, ids []int) ([]*models.Image, error) { var images []*models.Image for _, id := range ids { - image, err := qb.Find(id) + image, err := qb.Find(ctx, id) if err != nil { return nil, err } @@ -143,9 +141,9 @@ func (qb *imageQueryBuilder) FindMany(ids []int) ([]*models.Image, error) { return images, nil } -func (qb *imageQueryBuilder) find(id int) (*models.Image, error) { +func (qb *imageQueryBuilder) find(ctx context.Context, id int) (*models.Image, error) { var ret models.Image - if err := qb.get(id, &ret); err != nil { + if err := qb.getByID(ctx, id, &ret); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -154,43 +152,43 @@ func (qb *imageQueryBuilder) find(id int) (*models.Image, error) { return &ret, nil } -func (qb *imageQueryBuilder) FindByChecksum(checksum string) (*models.Image, error) { +func (qb *imageQueryBuilder) FindByChecksum(ctx context.Context, checksum string) (*models.Image, error) { query := "SELECT * FROM images WHERE checksum = ? LIMIT 1" args := []interface{}{checksum} - return qb.queryImage(query, args) + return qb.queryImage(ctx, query, args) } -func (qb *imageQueryBuilder) FindByPath(path string) (*models.Image, error) { +func (qb *imageQueryBuilder) FindByPath(ctx context.Context, path string) (*models.Image, error) { query := selectAll(imageTable) + "WHERE path = ? LIMIT 1" args := []interface{}{path} - return qb.queryImage(query, args) + return qb.queryImage(ctx, query, args) } -func (qb *imageQueryBuilder) FindByGalleryID(galleryID int) ([]*models.Image, error) { +func (qb *imageQueryBuilder) FindByGalleryID(ctx context.Context, galleryID int) ([]*models.Image, error) { args := []interface{}{galleryID} sort := "path" sortDir := models.SortDirectionEnumAsc - return qb.queryImages(imagesForGalleryQuery+qb.getImageSort(&models.FindFilterType{ + return qb.queryImages(ctx, imagesForGalleryQuery+qb.getImageSort(&models.FindFilterType{ Sort: &sort, Direction: &sortDir, }), args) } -func (qb *imageQueryBuilder) CountByGalleryID(galleryID int) (int, error) { +func (qb *imageQueryBuilder) CountByGalleryID(ctx context.Context, galleryID int) (int, error) { args := []interface{}{galleryID} - return qb.runCountQuery(qb.buildCountQuery(countImagesForGalleryQuery), args) + return qb.runCountQuery(ctx, qb.buildCountQuery(countImagesForGalleryQuery), args) } -func (qb *imageQueryBuilder) Count() (int, error) { - return qb.runCountQuery(qb.buildCountQuery("SELECT images.id FROM images"), nil) +func (qb *imageQueryBuilder) Count(ctx context.Context) (int, error) { + return qb.runCountQuery(ctx, qb.buildCountQuery("SELECT images.id FROM images"), nil) } -func (qb *imageQueryBuilder) Size() (float64, error) { - return qb.runSumQuery("SELECT SUM(cast(size as double)) as sum FROM images", nil) +func (qb *imageQueryBuilder) Size(ctx context.Context) (float64, error) { + return qb.runSumQuery(ctx, "SELECT SUM(cast(size as double)) as sum FROM images", nil) } -func (qb *imageQueryBuilder) All() ([]*models.Image, error) { - return qb.queryImages(selectAll(imageTable)+qb.getImageSort(nil), nil) +func (qb *imageQueryBuilder) All(ctx context.Context) ([]*models.Image, error) { + return qb.queryImages(ctx, selectAll(imageTable)+qb.getImageSort(nil), nil) } func (qb *imageQueryBuilder) validateFilter(imageFilter *models.ImageFilterType) error { @@ -224,41 +222,41 @@ func (qb *imageQueryBuilder) validateFilter(imageFilter *models.ImageFilterType) return nil } -func (qb *imageQueryBuilder) makeFilter(imageFilter *models.ImageFilterType) *filterBuilder { +func (qb *imageQueryBuilder) makeFilter(ctx context.Context, imageFilter *models.ImageFilterType) *filterBuilder { query := &filterBuilder{} if imageFilter.And != nil { - query.and(qb.makeFilter(imageFilter.And)) + query.and(qb.makeFilter(ctx, imageFilter.And)) } if imageFilter.Or != nil { - query.or(qb.makeFilter(imageFilter.Or)) + query.or(qb.makeFilter(ctx, imageFilter.Or)) } if imageFilter.Not != nil { - query.not(qb.makeFilter(imageFilter.Not)) + query.not(qb.makeFilter(ctx, imageFilter.Not)) } - query.handleCriterion(stringCriterionHandler(imageFilter.Checksum, "images.checksum")) - query.handleCriterion(stringCriterionHandler(imageFilter.Title, "images.title")) - query.handleCriterion(stringCriterionHandler(imageFilter.Path, "images.path")) - query.handleCriterion(intCriterionHandler(imageFilter.Rating, "images.rating")) - query.handleCriterion(intCriterionHandler(imageFilter.OCounter, "images.o_counter")) - query.handleCriterion(boolCriterionHandler(imageFilter.Organized, "images.organized")) - query.handleCriterion(resolutionCriterionHandler(imageFilter.Resolution, "images.height", "images.width")) - query.handleCriterion(imageIsMissingCriterionHandler(qb, imageFilter.IsMissing)) + query.handleCriterion(ctx, stringCriterionHandler(imageFilter.Checksum, "images.checksum")) + query.handleCriterion(ctx, stringCriterionHandler(imageFilter.Title, "images.title")) + query.handleCriterion(ctx, stringCriterionHandler(imageFilter.Path, "images.path")) + query.handleCriterion(ctx, intCriterionHandler(imageFilter.Rating, "images.rating")) + query.handleCriterion(ctx, intCriterionHandler(imageFilter.OCounter, "images.o_counter")) + query.handleCriterion(ctx, boolCriterionHandler(imageFilter.Organized, "images.organized")) + query.handleCriterion(ctx, resolutionCriterionHandler(imageFilter.Resolution, "images.height", "images.width")) + query.handleCriterion(ctx, imageIsMissingCriterionHandler(qb, imageFilter.IsMissing)) - query.handleCriterion(imageTagsCriterionHandler(qb, imageFilter.Tags)) - query.handleCriterion(imageTagCountCriterionHandler(qb, imageFilter.TagCount)) - query.handleCriterion(imageGalleriesCriterionHandler(qb, imageFilter.Galleries)) - query.handleCriterion(imagePerformersCriterionHandler(qb, imageFilter.Performers)) - query.handleCriterion(imagePerformerCountCriterionHandler(qb, imageFilter.PerformerCount)) - query.handleCriterion(imageStudioCriterionHandler(qb, imageFilter.Studios)) - query.handleCriterion(imagePerformerTagsCriterionHandler(qb, imageFilter.PerformerTags)) - query.handleCriterion(imagePerformerFavoriteCriterionHandler(imageFilter.PerformerFavorite)) + query.handleCriterion(ctx, imageTagsCriterionHandler(qb, imageFilter.Tags)) + query.handleCriterion(ctx, imageTagCountCriterionHandler(qb, imageFilter.TagCount)) + query.handleCriterion(ctx, imageGalleriesCriterionHandler(qb, imageFilter.Galleries)) + query.handleCriterion(ctx, imagePerformersCriterionHandler(qb, imageFilter.Performers)) + query.handleCriterion(ctx, imagePerformerCountCriterionHandler(qb, imageFilter.PerformerCount)) + query.handleCriterion(ctx, imageStudioCriterionHandler(qb, imageFilter.Studios)) + query.handleCriterion(ctx, imagePerformerTagsCriterionHandler(qb, imageFilter.PerformerTags)) + query.handleCriterion(ctx, imagePerformerFavoriteCriterionHandler(imageFilter.PerformerFavorite)) return query } -func (qb *imageQueryBuilder) makeQuery(imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) (*queryBuilder, error) { +func (qb *imageQueryBuilder) makeQuery(ctx context.Context, imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) (*queryBuilder, error) { if imageFilter == nil { imageFilter = &models.ImageFilterType{} } @@ -277,7 +275,7 @@ func (qb *imageQueryBuilder) makeQuery(imageFilter *models.ImageFilterType, find if err := qb.validateFilter(imageFilter); err != nil { return nil, err } - filter := qb.makeFilter(imageFilter) + filter := qb.makeFilter(ctx, imageFilter) query.addFilter(filter) @@ -286,18 +284,18 @@ func (qb *imageQueryBuilder) makeQuery(imageFilter *models.ImageFilterType, find return &query, nil } -func (qb *imageQueryBuilder) Query(options models.ImageQueryOptions) (*models.ImageQueryResult, error) { - query, err := qb.makeQuery(options.ImageFilter, options.FindFilter) +func (qb *imageQueryBuilder) Query(ctx context.Context, options models.ImageQueryOptions) (*models.ImageQueryResult, error) { + query, err := qb.makeQuery(ctx, options.ImageFilter, options.FindFilter) if err != nil { return nil, err } - result, err := qb.queryGroupedFields(options, *query) + result, err := qb.queryGroupedFields(ctx, options, *query) if err != nil { return nil, fmt.Errorf("error querying aggregate fields: %w", err) } - idsResult, err := query.findIDs() + idsResult, err := query.findIDs(ctx) if err != nil { return nil, fmt.Errorf("error finding IDs: %w", err) } @@ -306,7 +304,7 @@ func (qb *imageQueryBuilder) Query(options models.ImageQueryOptions) (*models.Im return result, nil } -func (qb *imageQueryBuilder) queryGroupedFields(options models.ImageQueryOptions, query queryBuilder) (*models.ImageQueryResult, error) { +func (qb *imageQueryBuilder) queryGroupedFields(ctx context.Context, options models.ImageQueryOptions, query queryBuilder) (*models.ImageQueryResult, error) { if !options.Count && !options.Megapixels && !options.TotalSize { // nothing to do - return empty result return models.NewImageQueryResult(qb), nil @@ -336,7 +334,7 @@ func (qb *imageQueryBuilder) queryGroupedFields(options models.ImageQueryOptions Megapixels float64 Size float64 }{} - if err := qb.repository.queryStruct(aggregateQuery.toSQL(includeSortPagination), query.args, &out); err != nil { + if err := qb.repository.queryStruct(ctx, aggregateQuery.toSQL(includeSortPagination), query.args, &out); err != nil { return nil, err } @@ -347,17 +345,17 @@ func (qb *imageQueryBuilder) queryGroupedFields(options models.ImageQueryOptions return ret, nil } -func (qb *imageQueryBuilder) QueryCount(imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) (int, error) { - query, err := qb.makeQuery(imageFilter, findFilter) +func (qb *imageQueryBuilder) QueryCount(ctx context.Context, imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) (int, error) { + query, err := qb.makeQuery(ctx, imageFilter, findFilter) if err != nil { return 0, err } - return query.executeCount() + return query.executeCount(ctx) } func imageIsMissingCriterionHandler(qb *imageQueryBuilder, isMissing *string) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if isMissing != nil && *isMissing != "" { switch *isMissing { case "studio": @@ -453,7 +451,7 @@ func imagePerformerCountCriterionHandler(qb *imageQueryBuilder, performerCount * } func imagePerformerFavoriteCriterionHandler(performerfavorite *bool) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if performerfavorite != nil { f.addLeftJoin("performers_images", "", "images.id = performers_images.image_id") @@ -487,7 +485,7 @@ func imageStudioCriterionHandler(qb *imageQueryBuilder, studios *models.Hierarch } func imagePerformerTagsCriterionHandler(qb *imageQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if tags != nil { if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull { var notClause string @@ -506,7 +504,7 @@ func imagePerformerTagsCriterionHandler(qb *imageQueryBuilder, tags *models.Hier return } - valuesClause := getHierarchicalValues(qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth) + valuesClause := getHierarchicalValues(ctx, qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth) f.addWith(`performer_tags AS ( SELECT pi.image_id, t.column1 AS root_tag_id FROM performers_images pi @@ -538,17 +536,17 @@ func (qb *imageQueryBuilder) getImageSort(findFilter *models.FindFilterType) str } } -func (qb *imageQueryBuilder) queryImage(query string, args []interface{}) (*models.Image, error) { - results, err := qb.queryImages(query, args) +func (qb *imageQueryBuilder) queryImage(ctx context.Context, query string, args []interface{}) (*models.Image, error) { + results, err := qb.queryImages(ctx, query, args) if err != nil || len(results) < 1 { return nil, err } return results[0], nil } -func (qb *imageQueryBuilder) queryImages(query string, args []interface{}) ([]*models.Image, error) { +func (qb *imageQueryBuilder) queryImages(ctx context.Context, query string, args []interface{}) ([]*models.Image, error) { var ret models.Images - if err := qb.query(query, args, &ret); err != nil { + if err := qb.query(ctx, query, args, &ret); err != nil { return nil, err } @@ -566,13 +564,13 @@ func (qb *imageQueryBuilder) galleriesRepository() *joinRepository { } } -func (qb *imageQueryBuilder) GetGalleryIDs(imageID int) ([]int, error) { - return qb.galleriesRepository().getIDs(imageID) +func (qb *imageQueryBuilder) GetGalleryIDs(ctx context.Context, imageID int) ([]int, error) { + return qb.galleriesRepository().getIDs(ctx, imageID) } -func (qb *imageQueryBuilder) UpdateGalleries(imageID int, galleryIDs []int) error { +func (qb *imageQueryBuilder) UpdateGalleries(ctx context.Context, imageID int, galleryIDs []int) error { // Delete the existing joins and then create new ones - return qb.galleriesRepository().replace(imageID, galleryIDs) + return qb.galleriesRepository().replace(ctx, imageID, galleryIDs) } func (qb *imageQueryBuilder) performersRepository() *joinRepository { @@ -586,13 +584,13 @@ func (qb *imageQueryBuilder) performersRepository() *joinRepository { } } -func (qb *imageQueryBuilder) GetPerformerIDs(imageID int) ([]int, error) { - return qb.performersRepository().getIDs(imageID) +func (qb *imageQueryBuilder) GetPerformerIDs(ctx context.Context, imageID int) ([]int, error) { + return qb.performersRepository().getIDs(ctx, imageID) } -func (qb *imageQueryBuilder) UpdatePerformers(imageID int, performerIDs []int) error { +func (qb *imageQueryBuilder) UpdatePerformers(ctx context.Context, imageID int, performerIDs []int) error { // Delete the existing joins and then create new ones - return qb.performersRepository().replace(imageID, performerIDs) + return qb.performersRepository().replace(ctx, imageID, performerIDs) } func (qb *imageQueryBuilder) tagsRepository() *joinRepository { @@ -606,11 +604,11 @@ func (qb *imageQueryBuilder) tagsRepository() *joinRepository { } } -func (qb *imageQueryBuilder) GetTagIDs(imageID int) ([]int, error) { - return qb.tagsRepository().getIDs(imageID) +func (qb *imageQueryBuilder) GetTagIDs(ctx context.Context, imageID int) ([]int, error) { + return qb.tagsRepository().getIDs(ctx, imageID) } -func (qb *imageQueryBuilder) UpdateTags(imageID int, tagIDs []int) error { +func (qb *imageQueryBuilder) UpdateTags(ctx context.Context, imageID int, tagIDs []int) error { // Delete the existing joins and then create new ones - return qb.tagsRepository().replace(imageID, tagIDs) + return qb.tagsRepository().replace(ctx, imageID, tagIDs) } diff --git a/pkg/sqlite/image_test.go b/pkg/sqlite/image_test.go index 552db2cdf..3c131ed56 100644 --- a/pkg/sqlite/image_test.go +++ b/pkg/sqlite/image_test.go @@ -4,6 +4,7 @@ package sqlite_test import ( + "context" "database/sql" "strconv" "testing" @@ -11,16 +12,17 @@ import ( "github.com/stretchr/testify/assert" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/sqlite" ) func TestImageFind(t *testing.T) { - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { // assume that the first image is imageWithGalleryPath - sqb := r.Image() + sqb := sqlite.ImageReaderWriter const imageIdx = 0 imageID := imageIDs[imageIdx] - image, err := sqb.Find(imageID) + image, err := sqb.Find(ctx, imageID) if err != nil { t.Errorf("Error finding image: %s", err.Error()) @@ -29,7 +31,7 @@ func TestImageFind(t *testing.T) { assert.Equal(t, getImageStringValue(imageIdx, "Path"), image.Path) imageID = 0 - image, err = sqb.Find(imageID) + image, err = sqb.Find(ctx, imageID) if err != nil { t.Errorf("Error finding image: %s", err.Error()) @@ -42,12 +44,12 @@ func TestImageFind(t *testing.T) { } func TestImageFindByPath(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter const imageIdx = 1 imagePath := getImageStringValue(imageIdx, "Path") - image, err := sqb.FindByPath(imagePath) + image, err := sqb.FindByPath(ctx, imagePath) if err != nil { t.Errorf("Error finding image: %s", err.Error()) @@ -57,7 +59,7 @@ func TestImageFindByPath(t *testing.T) { assert.Equal(t, imagePath, image.Path) imagePath = "not exist" - image, err = sqb.FindByPath(imagePath) + image, err = sqb.FindByPath(ctx, imagePath) if err != nil { t.Errorf("Error finding image: %s", err.Error()) @@ -70,10 +72,10 @@ func TestImageFindByPath(t *testing.T) { } func TestImageFindByGalleryID(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter - images, err := sqb.FindByGalleryID(galleryIDs[galleryIdxWithTwoImages]) + images, err := sqb.FindByGalleryID(ctx, galleryIDs[galleryIdxWithTwoImages]) if err != nil { t.Errorf("Error finding images: %s", err.Error()) @@ -83,7 +85,7 @@ func TestImageFindByGalleryID(t *testing.T) { assert.Equal(t, imageIDs[imageIdx1WithGallery], images[0].ID) assert.Equal(t, imageIDs[imageIdx2WithGallery], images[1].ID) - images, err = sqb.FindByGalleryID(galleryIDs[galleryIdxWithScene]) + images, err = sqb.FindByGalleryID(ctx, galleryIDs[galleryIdxWithScene]) if err != nil { t.Errorf("Error finding images: %s", err.Error()) @@ -96,21 +98,21 @@ func TestImageFindByGalleryID(t *testing.T) { } func TestImageQueryQ(t *testing.T) { - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { const imageIdx = 2 q := getImageStringValue(imageIdx, titleField) - sqb := r.Image() + sqb := sqlite.ImageReaderWriter - imageQueryQ(t, sqb, q, imageIdx) + imageQueryQ(ctx, t, sqb, q, imageIdx) return nil }) } -func queryImagesWithCount(sqb models.ImageReader, imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) ([]*models.Image, int, error) { - result, err := sqb.Query(models.ImageQueryOptions{ +func queryImagesWithCount(ctx context.Context, sqb models.ImageReader, imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) ([]*models.Image, int, error) { + result, err := sqb.Query(ctx, models.ImageQueryOptions{ QueryOptions: models.QueryOptions{ FindFilter: findFilter, Count: true, @@ -121,7 +123,7 @@ func queryImagesWithCount(sqb models.ImageReader, imageFilter *models.ImageFilte return nil, 0, err } - images, err := result.Resolve() + images, err := result.Resolve(ctx) if err != nil { return nil, 0, err } @@ -129,17 +131,17 @@ func queryImagesWithCount(sqb models.ImageReader, imageFilter *models.ImageFilte return images, result.Count, nil } -func imageQueryQ(t *testing.T, sqb models.ImageReader, q string, expectedImageIdx int) { +func imageQueryQ(ctx context.Context, t *testing.T, sqb models.ImageReader, q string, expectedImageIdx int) { filter := models.FindFilterType{ Q: &q, } - images := queryImages(t, sqb, nil, &filter) + images := queryImages(ctx, t, sqb, nil, &filter) assert.Len(t, images, 1) image := images[0] assert.Equal(t, imageIDs[expectedImageIdx], image.ID) - count, err := sqb.QueryCount(nil, &filter) + count, err := sqb.QueryCount(ctx, nil, &filter) if err != nil { t.Errorf("Error querying image: %s", err.Error()) } @@ -147,7 +149,7 @@ func imageQueryQ(t *testing.T, sqb models.ImageReader, q string, expectedImageId // no Q should return all results filter.Q = nil - images = queryImages(t, sqb, nil, &filter) + images = queryImages(ctx, t, sqb, nil, &filter) assert.Len(t, images, totalImages) } @@ -175,13 +177,13 @@ func TestImageQueryPath(t *testing.T) { } func verifyImagePath(t *testing.T, pathCriterion models.StringCriterionInput, expected int) { - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter imageFilter := models.ImageFilterType{ Path: &pathCriterion, } - images := queryImages(t, sqb, &imageFilter, nil) + images := queryImages(ctx, t, sqb, &imageFilter, nil) assert.Equal(t, expected, len(images), "number of returned images") @@ -213,10 +215,10 @@ func TestImageQueryPathOr(t *testing.T) { }, } - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter - images := queryImages(t, sqb, &imageFilter, nil) + images := queryImages(ctx, t, sqb, &imageFilter, nil) assert.Len(t, images, 2) assert.Equal(t, image1Path, images[0].Path) @@ -244,10 +246,10 @@ func TestImageQueryPathAndRating(t *testing.T) { }, } - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter - images := queryImages(t, sqb, &imageFilter, nil) + images := queryImages(ctx, t, sqb, &imageFilter, nil) assert.Len(t, images, 1) assert.Equal(t, imagePath, images[0].Path) @@ -279,10 +281,10 @@ func TestImageQueryPathNotRating(t *testing.T) { }, } - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter - images := queryImages(t, sqb, &imageFilter, nil) + images := queryImages(ctx, t, sqb, &imageFilter, nil) for _, image := range images { verifyString(t, image.Path, pathCriterion) @@ -310,20 +312,20 @@ func TestImageIllegalQuery(t *testing.T) { Or: &subFilter, } - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter - _, _, err := queryImagesWithCount(sqb, imageFilter, nil) + _, _, err := queryImagesWithCount(ctx, sqb, imageFilter, nil) assert.NotNil(err) imageFilter.Or = nil imageFilter.Not = &subFilter - _, _, err = queryImagesWithCount(sqb, imageFilter, nil) + _, _, err = queryImagesWithCount(ctx, sqb, imageFilter, nil) assert.NotNil(err) imageFilter.And = nil imageFilter.Or = &subFilter - _, _, err = queryImagesWithCount(sqb, imageFilter, nil) + _, _, err = queryImagesWithCount(ctx, sqb, imageFilter, nil) assert.NotNil(err) return nil @@ -356,13 +358,13 @@ func TestImageQueryRating(t *testing.T) { } func verifyImagesRating(t *testing.T, ratingCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter imageFilter := models.ImageFilterType{ Rating: &ratingCriterion, } - images, _, err := queryImagesWithCount(sqb, &imageFilter, nil) + images, _, err := queryImagesWithCount(ctx, sqb, &imageFilter, nil) if err != nil { t.Errorf("Error querying image: %s", err.Error()) } @@ -395,13 +397,13 @@ func TestImageQueryOCounter(t *testing.T) { } func verifyImagesOCounter(t *testing.T, oCounterCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter imageFilter := models.ImageFilterType{ OCounter: &oCounterCriterion, } - images, _, err := queryImagesWithCount(sqb, &imageFilter, nil) + images, _, err := queryImagesWithCount(ctx, sqb, &imageFilter, nil) if err != nil { t.Errorf("Error querying image: %s", err.Error()) } @@ -424,8 +426,8 @@ func TestImageQueryResolution(t *testing.T) { } func verifyImagesResolution(t *testing.T, resolution models.ResolutionEnum) { - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter imageFilter := models.ImageFilterType{ Resolution: &models.ResolutionCriterionInput{ Value: resolution, @@ -433,7 +435,7 @@ func verifyImagesResolution(t *testing.T, resolution models.ResolutionEnum) { }, } - images, _, err := queryImagesWithCount(sqb, &imageFilter, nil) + images, _, err := queryImagesWithCount(ctx, sqb, &imageFilter, nil) if err != nil { t.Errorf("Error querying image: %s", err.Error()) } @@ -465,8 +467,8 @@ func verifyImageResolution(t *testing.T, height sql.NullInt64, resolution models } func TestImageQueryIsMissingGalleries(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter isMissing := "galleries" imageFilter := models.ImageFilterType{ IsMissing: &isMissing, @@ -477,7 +479,7 @@ func TestImageQueryIsMissingGalleries(t *testing.T) { Q: &q, } - images, _, err := queryImagesWithCount(sqb, &imageFilter, &findFilter) + images, _, err := queryImagesWithCount(ctx, sqb, &imageFilter, &findFilter) if err != nil { t.Errorf("Error querying image: %s", err.Error()) } @@ -485,7 +487,7 @@ func TestImageQueryIsMissingGalleries(t *testing.T) { assert.Len(t, images, 0) findFilter.Q = nil - images, _, err = queryImagesWithCount(sqb, &imageFilter, &findFilter) + images, _, err = queryImagesWithCount(ctx, sqb, &imageFilter, &findFilter) if err != nil { t.Errorf("Error querying image: %s", err.Error()) } @@ -502,8 +504,8 @@ func TestImageQueryIsMissingGalleries(t *testing.T) { } func TestImageQueryIsMissingStudio(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter isMissing := "studio" imageFilter := models.ImageFilterType{ IsMissing: &isMissing, @@ -514,7 +516,7 @@ func TestImageQueryIsMissingStudio(t *testing.T) { Q: &q, } - images, _, err := queryImagesWithCount(sqb, &imageFilter, &findFilter) + images, _, err := queryImagesWithCount(ctx, sqb, &imageFilter, &findFilter) if err != nil { t.Errorf("Error querying image: %s", err.Error()) } @@ -522,7 +524,7 @@ func TestImageQueryIsMissingStudio(t *testing.T) { assert.Len(t, images, 0) findFilter.Q = nil - images, _, err = queryImagesWithCount(sqb, &imageFilter, &findFilter) + images, _, err = queryImagesWithCount(ctx, sqb, &imageFilter, &findFilter) if err != nil { t.Errorf("Error querying image: %s", err.Error()) } @@ -537,8 +539,8 @@ func TestImageQueryIsMissingStudio(t *testing.T) { } func TestImageQueryIsMissingPerformers(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter isMissing := "performers" imageFilter := models.ImageFilterType{ IsMissing: &isMissing, @@ -549,7 +551,7 @@ func TestImageQueryIsMissingPerformers(t *testing.T) { Q: &q, } - images, _, err := queryImagesWithCount(sqb, &imageFilter, &findFilter) + images, _, err := queryImagesWithCount(ctx, sqb, &imageFilter, &findFilter) if err != nil { t.Errorf("Error querying image: %s", err.Error()) } @@ -557,7 +559,7 @@ func TestImageQueryIsMissingPerformers(t *testing.T) { assert.Len(t, images, 0) findFilter.Q = nil - images, _, err = queryImagesWithCount(sqb, &imageFilter, &findFilter) + images, _, err = queryImagesWithCount(ctx, sqb, &imageFilter, &findFilter) if err != nil { t.Errorf("Error querying image: %s", err.Error()) } @@ -574,8 +576,8 @@ func TestImageQueryIsMissingPerformers(t *testing.T) { } func TestImageQueryIsMissingTags(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter isMissing := "tags" imageFilter := models.ImageFilterType{ IsMissing: &isMissing, @@ -586,7 +588,7 @@ func TestImageQueryIsMissingTags(t *testing.T) { Q: &q, } - images, _, err := queryImagesWithCount(sqb, &imageFilter, &findFilter) + images, _, err := queryImagesWithCount(ctx, sqb, &imageFilter, &findFilter) if err != nil { t.Errorf("Error querying image: %s", err.Error()) } @@ -594,7 +596,7 @@ func TestImageQueryIsMissingTags(t *testing.T) { assert.Len(t, images, 0) findFilter.Q = nil - images, _, err = queryImagesWithCount(sqb, &imageFilter, &findFilter) + images, _, err = queryImagesWithCount(ctx, sqb, &imageFilter, &findFilter) if err != nil { t.Errorf("Error querying image: %s", err.Error()) } @@ -606,14 +608,14 @@ func TestImageQueryIsMissingTags(t *testing.T) { } func TestImageQueryIsMissingRating(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter isMissing := "rating" imageFilter := models.ImageFilterType{ IsMissing: &isMissing, } - images, _, err := queryImagesWithCount(sqb, &imageFilter, nil) + images, _, err := queryImagesWithCount(ctx, sqb, &imageFilter, nil) if err != nil { t.Errorf("Error querying image: %s", err.Error()) } @@ -630,8 +632,8 @@ func TestImageQueryIsMissingRating(t *testing.T) { } func TestImageQueryGallery(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter galleryCriterion := models.MultiCriterionInput{ Value: []string{ strconv.Itoa(galleryIDs[galleryIdxWithImage]), @@ -643,7 +645,7 @@ func TestImageQueryGallery(t *testing.T) { Galleries: &galleryCriterion, } - images := queryImages(t, sqb, &imageFilter, nil) + images := queryImages(ctx, t, sqb, &imageFilter, nil) assert.Len(t, images, 1) // ensure ids are correct @@ -659,7 +661,7 @@ func TestImageQueryGallery(t *testing.T) { Modifier: models.CriterionModifierIncludesAll, } - images = queryImages(t, sqb, &imageFilter, nil) + images = queryImages(ctx, t, sqb, &imageFilter, nil) assert.Len(t, images, 1) assert.Equal(t, imageIDs[imageIdxWithTwoGalleries], images[0].ID) @@ -676,11 +678,11 @@ func TestImageQueryGallery(t *testing.T) { Q: &q, } - images = queryImages(t, sqb, &imageFilter, &findFilter) + images = queryImages(ctx, t, sqb, &imageFilter, &findFilter) assert.Len(t, images, 0) q = getImageStringValue(imageIdxWithPerformer, titleField) - images = queryImages(t, sqb, &imageFilter, &findFilter) + images = queryImages(ctx, t, sqb, &imageFilter, &findFilter) assert.Len(t, images, 1) return nil @@ -688,8 +690,8 @@ func TestImageQueryGallery(t *testing.T) { } func TestImageQueryPerformers(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter performerCriterion := models.MultiCriterionInput{ Value: []string{ strconv.Itoa(performerIDs[performerIdxWithImage]), @@ -702,7 +704,7 @@ func TestImageQueryPerformers(t *testing.T) { Performers: &performerCriterion, } - images := queryImages(t, sqb, &imageFilter, nil) + images := queryImages(ctx, t, sqb, &imageFilter, nil) assert.Len(t, images, 2) // ensure ids are correct @@ -718,7 +720,7 @@ func TestImageQueryPerformers(t *testing.T) { Modifier: models.CriterionModifierIncludesAll, } - images = queryImages(t, sqb, &imageFilter, nil) + images = queryImages(ctx, t, sqb, &imageFilter, nil) assert.Len(t, images, 1) assert.Equal(t, imageIDs[imageIdxWithTwoPerformers], images[0].ID) @@ -734,7 +736,7 @@ func TestImageQueryPerformers(t *testing.T) { Q: &q, } - images = queryImages(t, sqb, &imageFilter, &findFilter) + images = queryImages(ctx, t, sqb, &imageFilter, &findFilter) assert.Len(t, images, 0) performerCriterion = models.MultiCriterionInput{ @@ -742,22 +744,22 @@ func TestImageQueryPerformers(t *testing.T) { } q = getImageStringValue(imageIdxWithGallery, titleField) - images = queryImages(t, sqb, &imageFilter, &findFilter) + images = queryImages(ctx, t, sqb, &imageFilter, &findFilter) assert.Len(t, images, 1) assert.Equal(t, imageIDs[imageIdxWithGallery], images[0].ID) q = getImageStringValue(imageIdxWithPerformerTag, titleField) - images = queryImages(t, sqb, &imageFilter, &findFilter) + images = queryImages(ctx, t, sqb, &imageFilter, &findFilter) assert.Len(t, images, 0) performerCriterion.Modifier = models.CriterionModifierNotNull - images = queryImages(t, sqb, &imageFilter, &findFilter) + images = queryImages(ctx, t, sqb, &imageFilter, &findFilter) assert.Len(t, images, 1) assert.Equal(t, imageIDs[imageIdxWithPerformerTag], images[0].ID) q = getImageStringValue(imageIdxWithGallery, titleField) - images = queryImages(t, sqb, &imageFilter, &findFilter) + images = queryImages(ctx, t, sqb, &imageFilter, &findFilter) assert.Len(t, images, 0) return nil @@ -765,8 +767,8 @@ func TestImageQueryPerformers(t *testing.T) { } func TestImageQueryTags(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithImage]), @@ -779,7 +781,7 @@ func TestImageQueryTags(t *testing.T) { Tags: &tagCriterion, } - images := queryImages(t, sqb, &imageFilter, nil) + images := queryImages(ctx, t, sqb, &imageFilter, nil) assert.Len(t, images, 2) // ensure ids are correct @@ -795,7 +797,7 @@ func TestImageQueryTags(t *testing.T) { Modifier: models.CriterionModifierIncludesAll, } - images = queryImages(t, sqb, &imageFilter, nil) + images = queryImages(ctx, t, sqb, &imageFilter, nil) assert.Len(t, images, 1) assert.Equal(t, imageIDs[imageIdxWithTwoTags], images[0].ID) @@ -811,7 +813,7 @@ func TestImageQueryTags(t *testing.T) { Q: &q, } - images = queryImages(t, sqb, &imageFilter, &findFilter) + images = queryImages(ctx, t, sqb, &imageFilter, &findFilter) assert.Len(t, images, 0) tagCriterion = models.HierarchicalMultiCriterionInput{ @@ -819,22 +821,22 @@ func TestImageQueryTags(t *testing.T) { } q = getImageStringValue(imageIdxWithGallery, titleField) - images = queryImages(t, sqb, &imageFilter, &findFilter) + images = queryImages(ctx, t, sqb, &imageFilter, &findFilter) assert.Len(t, images, 1) assert.Equal(t, imageIDs[imageIdxWithGallery], images[0].ID) q = getImageStringValue(imageIdxWithTag, titleField) - images = queryImages(t, sqb, &imageFilter, &findFilter) + images = queryImages(ctx, t, sqb, &imageFilter, &findFilter) assert.Len(t, images, 0) tagCriterion.Modifier = models.CriterionModifierNotNull - images = queryImages(t, sqb, &imageFilter, &findFilter) + images = queryImages(ctx, t, sqb, &imageFilter, &findFilter) assert.Len(t, images, 1) assert.Equal(t, imageIDs[imageIdxWithTag], images[0].ID) q = getImageStringValue(imageIdxWithGallery, titleField) - images = queryImages(t, sqb, &imageFilter, &findFilter) + images = queryImages(ctx, t, sqb, &imageFilter, &findFilter) assert.Len(t, images, 0) return nil @@ -842,8 +844,8 @@ func TestImageQueryTags(t *testing.T) { } func TestImageQueryStudio(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter studioCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(studioIDs[studioIdxWithImage]), @@ -855,7 +857,7 @@ func TestImageQueryStudio(t *testing.T) { Studios: &studioCriterion, } - images, _, err := queryImagesWithCount(sqb, &imageFilter, nil) + images, _, err := queryImagesWithCount(ctx, sqb, &imageFilter, nil) if err != nil { t.Errorf("Error querying image: %s", err.Error()) } @@ -877,7 +879,7 @@ func TestImageQueryStudio(t *testing.T) { Q: &q, } - images, _, err = queryImagesWithCount(sqb, &imageFilter, &findFilter) + images, _, err = queryImagesWithCount(ctx, sqb, &imageFilter, &findFilter) if err != nil { t.Errorf("Error querying image: %s", err.Error()) } @@ -888,8 +890,8 @@ func TestImageQueryStudio(t *testing.T) { } func TestImageQueryStudioDepth(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter depth := 2 studioCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ @@ -903,16 +905,16 @@ func TestImageQueryStudioDepth(t *testing.T) { Studios: &studioCriterion, } - images := queryImages(t, sqb, &imageFilter, nil) + images := queryImages(ctx, t, sqb, &imageFilter, nil) assert.Len(t, images, 1) depth = 1 - images = queryImages(t, sqb, &imageFilter, nil) + images = queryImages(ctx, t, sqb, &imageFilter, nil) assert.Len(t, images, 0) studioCriterion.Value = []string{strconv.Itoa(studioIDs[studioIdxWithParentAndChild])} - images = queryImages(t, sqb, &imageFilter, nil) + images = queryImages(ctx, t, sqb, &imageFilter, nil) assert.Len(t, images, 1) // ensure id is correct @@ -933,23 +935,23 @@ func TestImageQueryStudioDepth(t *testing.T) { Q: &q, } - images = queryImages(t, sqb, &imageFilter, &findFilter) + images = queryImages(ctx, t, sqb, &imageFilter, &findFilter) assert.Len(t, images, 0) depth = 1 - images = queryImages(t, sqb, &imageFilter, &findFilter) + images = queryImages(ctx, t, sqb, &imageFilter, &findFilter) assert.Len(t, images, 1) studioCriterion.Value = []string{strconv.Itoa(studioIDs[studioIdxWithParentAndChild])} - images = queryImages(t, sqb, &imageFilter, &findFilter) + images = queryImages(ctx, t, sqb, &imageFilter, &findFilter) assert.Len(t, images, 0) return nil }) } -func queryImages(t *testing.T, sqb models.ImageReader, imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) []*models.Image { - images, _, err := queryImagesWithCount(sqb, imageFilter, findFilter) +func queryImages(ctx context.Context, t *testing.T, sqb models.ImageReader, imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) []*models.Image { + images, _, err := queryImagesWithCount(ctx, sqb, imageFilter, findFilter) if err != nil { t.Errorf("Error querying images: %s", err.Error()) } @@ -958,8 +960,8 @@ func queryImages(t *testing.T, sqb models.ImageReader, imageFilter *models.Image } func TestImageQueryPerformerTags(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithPerformer]), @@ -972,7 +974,7 @@ func TestImageQueryPerformerTags(t *testing.T) { PerformerTags: &tagCriterion, } - images := queryImages(t, sqb, &imageFilter, nil) + images := queryImages(ctx, t, sqb, &imageFilter, nil) assert.Len(t, images, 2) // ensure ids are correct @@ -988,7 +990,7 @@ func TestImageQueryPerformerTags(t *testing.T) { Modifier: models.CriterionModifierIncludesAll, } - images = queryImages(t, sqb, &imageFilter, nil) + images = queryImages(ctx, t, sqb, &imageFilter, nil) assert.Len(t, images, 1) assert.Equal(t, imageIDs[imageIdxWithPerformerTwoTags], images[0].ID) @@ -1005,7 +1007,7 @@ func TestImageQueryPerformerTags(t *testing.T) { Q: &q, } - images = queryImages(t, sqb, &imageFilter, &findFilter) + images = queryImages(ctx, t, sqb, &imageFilter, &findFilter) assert.Len(t, images, 0) tagCriterion = models.HierarchicalMultiCriterionInput{ @@ -1013,22 +1015,22 @@ func TestImageQueryPerformerTags(t *testing.T) { } q = getImageStringValue(imageIdxWithGallery, titleField) - images = queryImages(t, sqb, &imageFilter, &findFilter) + images = queryImages(ctx, t, sqb, &imageFilter, &findFilter) assert.Len(t, images, 1) assert.Equal(t, imageIDs[imageIdxWithGallery], images[0].ID) q = getImageStringValue(imageIdxWithPerformerTag, titleField) - images = queryImages(t, sqb, &imageFilter, &findFilter) + images = queryImages(ctx, t, sqb, &imageFilter, &findFilter) assert.Len(t, images, 0) tagCriterion.Modifier = models.CriterionModifierNotNull - images = queryImages(t, sqb, &imageFilter, &findFilter) + images = queryImages(ctx, t, sqb, &imageFilter, &findFilter) assert.Len(t, images, 1) assert.Equal(t, imageIDs[imageIdxWithPerformerTag], images[0].ID) q = getImageStringValue(imageIdxWithGallery, titleField) - images = queryImages(t, sqb, &imageFilter, &findFilter) + images = queryImages(ctx, t, sqb, &imageFilter, &findFilter) assert.Len(t, images, 0) return nil @@ -1055,17 +1057,17 @@ func TestImageQueryTagCount(t *testing.T) { } func verifyImagesTagCount(t *testing.T, tagCountCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter imageFilter := models.ImageFilterType{ TagCount: &tagCountCriterion, } - images := queryImages(t, sqb, &imageFilter, nil) + images := queryImages(ctx, t, sqb, &imageFilter, nil) assert.Greater(t, len(images), 0) for _, image := range images { - ids, err := sqb.GetTagIDs(image.ID) + ids, err := sqb.GetTagIDs(ctx, image.ID) if err != nil { return err } @@ -1096,17 +1098,17 @@ func TestImageQueryPerformerCount(t *testing.T) { } func verifyImagesPerformerCount(t *testing.T, performerCountCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Image() + withTxn(func(ctx context.Context) error { + sqb := sqlite.ImageReaderWriter imageFilter := models.ImageFilterType{ PerformerCount: &performerCountCriterion, } - images := queryImages(t, sqb, &imageFilter, nil) + images := queryImages(ctx, t, sqb, &imageFilter, nil) assert.Greater(t, len(images), 0) for _, image := range images { - ids, err := sqb.GetPerformerIDs(image.ID) + ids, err := sqb.GetPerformerIDs(ctx, image.ID) if err != nil { return err } @@ -1118,7 +1120,7 @@ func verifyImagesPerformerCount(t *testing.T, performerCountCriterion models.Int } func TestImageQuerySorting(t *testing.T) { - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { sort := titleField direction := models.SortDirectionEnumAsc findFilter := models.FindFilterType{ @@ -1126,8 +1128,8 @@ func TestImageQuerySorting(t *testing.T) { Direction: &direction, } - sqb := r.Image() - images, _, err := queryImagesWithCount(sqb, nil, &findFilter) + sqb := sqlite.ImageReaderWriter + images, _, err := queryImagesWithCount(ctx, sqb, nil, &findFilter) if err != nil { t.Errorf("Error querying image: %s", err.Error()) } @@ -1142,7 +1144,7 @@ func TestImageQuerySorting(t *testing.T) { // sort in descending order direction = models.SortDirectionEnumDesc - images, _, err = queryImagesWithCount(sqb, nil, &findFilter) + images, _, err = queryImagesWithCount(ctx, sqb, nil, &findFilter) if err != nil { t.Errorf("Error querying image: %s", err.Error()) } @@ -1157,14 +1159,14 @@ func TestImageQuerySorting(t *testing.T) { } func TestImageQueryPagination(t *testing.T) { - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { perPage := 1 findFilter := models.FindFilterType{ PerPage: &perPage, } - sqb := r.Image() - images, _, err := queryImagesWithCount(sqb, nil, &findFilter) + sqb := sqlite.ImageReaderWriter + images, _, err := queryImagesWithCount(ctx, sqb, nil, &findFilter) if err != nil { t.Errorf("Error querying image: %s", err.Error()) } @@ -1175,7 +1177,7 @@ func TestImageQueryPagination(t *testing.T) { page := 2 findFilter.Page = &page - images, _, err = queryImagesWithCount(sqb, nil, &findFilter) + images, _, err = queryImagesWithCount(ctx, sqb, nil, &findFilter) if err != nil { t.Errorf("Error querying image: %s", err.Error()) } @@ -1187,7 +1189,7 @@ func TestImageQueryPagination(t *testing.T) { perPage = 2 page = 1 - images, _, err = queryImagesWithCount(sqb, nil, &findFilter) + images, _, err = queryImagesWithCount(ctx, sqb, nil, &findFilter) if err != nil { t.Errorf("Error querying image: %s", err.Error()) } diff --git a/pkg/database/migrations/10_image_tables.up.sql b/pkg/sqlite/migrations/10_image_tables.up.sql similarity index 100% rename from pkg/database/migrations/10_image_tables.up.sql rename to pkg/sqlite/migrations/10_image_tables.up.sql diff --git a/pkg/database/migrations/11_tag_image.up.sql b/pkg/sqlite/migrations/11_tag_image.up.sql similarity index 100% rename from pkg/database/migrations/11_tag_image.up.sql rename to pkg/sqlite/migrations/11_tag_image.up.sql diff --git a/pkg/database/migrations/12_oshash.up.sql b/pkg/sqlite/migrations/12_oshash.up.sql similarity index 100% rename from pkg/database/migrations/12_oshash.up.sql rename to pkg/sqlite/migrations/12_oshash.up.sql diff --git a/pkg/database/migrations/13_images.up.sql b/pkg/sqlite/migrations/13_images.up.sql similarity index 100% rename from pkg/database/migrations/13_images.up.sql rename to pkg/sqlite/migrations/13_images.up.sql diff --git a/pkg/database/migrations/14_stash_box_ids.up.sql b/pkg/sqlite/migrations/14_stash_box_ids.up.sql similarity index 100% rename from pkg/database/migrations/14_stash_box_ids.up.sql rename to pkg/sqlite/migrations/14_stash_box_ids.up.sql diff --git a/pkg/database/migrations/15_file_mod_time.up.sql b/pkg/sqlite/migrations/15_file_mod_time.up.sql similarity index 100% rename from pkg/database/migrations/15_file_mod_time.up.sql rename to pkg/sqlite/migrations/15_file_mod_time.up.sql diff --git a/pkg/database/migrations/16_organized_flag.up.sql b/pkg/sqlite/migrations/16_organized_flag.up.sql similarity index 100% rename from pkg/database/migrations/16_organized_flag.up.sql rename to pkg/sqlite/migrations/16_organized_flag.up.sql diff --git a/pkg/database/migrations/17_reset_scene_size.up.sql b/pkg/sqlite/migrations/17_reset_scene_size.up.sql similarity index 100% rename from pkg/database/migrations/17_reset_scene_size.up.sql rename to pkg/sqlite/migrations/17_reset_scene_size.up.sql diff --git a/pkg/database/migrations/18_scene_galleries.up.sql b/pkg/sqlite/migrations/18_scene_galleries.up.sql similarity index 100% rename from pkg/database/migrations/18_scene_galleries.up.sql rename to pkg/sqlite/migrations/18_scene_galleries.up.sql diff --git a/pkg/database/migrations/19_performer_tags.up.sql b/pkg/sqlite/migrations/19_performer_tags.up.sql similarity index 100% rename from pkg/database/migrations/19_performer_tags.up.sql rename to pkg/sqlite/migrations/19_performer_tags.up.sql diff --git a/pkg/database/migrations/1_initial.down.sql b/pkg/sqlite/migrations/1_initial.down.sql similarity index 100% rename from pkg/database/migrations/1_initial.down.sql rename to pkg/sqlite/migrations/1_initial.down.sql diff --git a/pkg/database/migrations/1_initial.up.sql b/pkg/sqlite/migrations/1_initial.up.sql similarity index 100% rename from pkg/database/migrations/1_initial.up.sql rename to pkg/sqlite/migrations/1_initial.up.sql diff --git a/pkg/database/migrations/20_phash.up.sql b/pkg/sqlite/migrations/20_phash.up.sql similarity index 100% rename from pkg/database/migrations/20_phash.up.sql rename to pkg/sqlite/migrations/20_phash.up.sql diff --git a/pkg/database/migrations/21_performers_studios_details.up.sql b/pkg/sqlite/migrations/21_performers_studios_details.up.sql similarity index 100% rename from pkg/database/migrations/21_performers_studios_details.up.sql rename to pkg/sqlite/migrations/21_performers_studios_details.up.sql diff --git a/pkg/database/migrations/22_performers_studios_rating.up.sql b/pkg/sqlite/migrations/22_performers_studios_rating.up.sql similarity index 100% rename from pkg/database/migrations/22_performers_studios_rating.up.sql rename to pkg/sqlite/migrations/22_performers_studios_rating.up.sql diff --git a/pkg/database/migrations/23_scenes_interactive.up.sql b/pkg/sqlite/migrations/23_scenes_interactive.up.sql similarity index 100% rename from pkg/database/migrations/23_scenes_interactive.up.sql rename to pkg/sqlite/migrations/23_scenes_interactive.up.sql diff --git a/pkg/database/migrations/24_tag_aliases.up.sql b/pkg/sqlite/migrations/24_tag_aliases.up.sql similarity index 100% rename from pkg/database/migrations/24_tag_aliases.up.sql rename to pkg/sqlite/migrations/24_tag_aliases.up.sql diff --git a/pkg/database/migrations/25_saved_filters.up.sql b/pkg/sqlite/migrations/25_saved_filters.up.sql similarity index 100% rename from pkg/database/migrations/25_saved_filters.up.sql rename to pkg/sqlite/migrations/25_saved_filters.up.sql diff --git a/pkg/database/migrations/26_tag_hierarchy.up.sql b/pkg/sqlite/migrations/26_tag_hierarchy.up.sql similarity index 100% rename from pkg/database/migrations/26_tag_hierarchy.up.sql rename to pkg/sqlite/migrations/26_tag_hierarchy.up.sql diff --git a/pkg/database/migrations/27_studio_aliases.up.sql b/pkg/sqlite/migrations/27_studio_aliases.up.sql similarity index 100% rename from pkg/database/migrations/27_studio_aliases.up.sql rename to pkg/sqlite/migrations/27_studio_aliases.up.sql diff --git a/pkg/database/migrations/28_images_indexes.up.sql b/pkg/sqlite/migrations/28_images_indexes.up.sql similarity index 100% rename from pkg/database/migrations/28_images_indexes.up.sql rename to pkg/sqlite/migrations/28_images_indexes.up.sql diff --git a/pkg/database/migrations/29_interactive_speed.up.sql b/pkg/sqlite/migrations/29_interactive_speed.up.sql similarity index 100% rename from pkg/database/migrations/29_interactive_speed.up.sql rename to pkg/sqlite/migrations/29_interactive_speed.up.sql diff --git a/pkg/database/migrations/2_cover_image.up.sql b/pkg/sqlite/migrations/2_cover_image.up.sql similarity index 100% rename from pkg/database/migrations/2_cover_image.up.sql rename to pkg/sqlite/migrations/2_cover_image.up.sql diff --git a/pkg/database/migrations/30_ignore_autotag.up..sql b/pkg/sqlite/migrations/30_ignore_autotag.up..sql similarity index 100% rename from pkg/database/migrations/30_ignore_autotag.up..sql rename to pkg/sqlite/migrations/30_ignore_autotag.up..sql diff --git a/pkg/database/migrations/31_scenes_captions.up.sql b/pkg/sqlite/migrations/31_scenes_captions.up.sql similarity index 100% rename from pkg/database/migrations/31_scenes_captions.up.sql rename to pkg/sqlite/migrations/31_scenes_captions.up.sql diff --git a/pkg/database/migrations/3_o_counter.up.sql b/pkg/sqlite/migrations/3_o_counter.up.sql similarity index 100% rename from pkg/database/migrations/3_o_counter.up.sql rename to pkg/sqlite/migrations/3_o_counter.up.sql diff --git a/pkg/database/migrations/4_movie.up.sql b/pkg/sqlite/migrations/4_movie.up.sql similarity index 100% rename from pkg/database/migrations/4_movie.up.sql rename to pkg/sqlite/migrations/4_movie.up.sql diff --git a/pkg/database/migrations/5_performer_gender.down.sql b/pkg/sqlite/migrations/5_performer_gender.down.sql similarity index 100% rename from pkg/database/migrations/5_performer_gender.down.sql rename to pkg/sqlite/migrations/5_performer_gender.down.sql diff --git a/pkg/database/migrations/5_performer_gender.up.sql b/pkg/sqlite/migrations/5_performer_gender.up.sql similarity index 100% rename from pkg/database/migrations/5_performer_gender.up.sql rename to pkg/sqlite/migrations/5_performer_gender.up.sql diff --git a/pkg/database/migrations/6_scenes_format.up.sql b/pkg/sqlite/migrations/6_scenes_format.up.sql similarity index 100% rename from pkg/database/migrations/6_scenes_format.up.sql rename to pkg/sqlite/migrations/6_scenes_format.up.sql diff --git a/pkg/database/migrations/7_performer_optimization.up.sql b/pkg/sqlite/migrations/7_performer_optimization.up.sql similarity index 100% rename from pkg/database/migrations/7_performer_optimization.up.sql rename to pkg/sqlite/migrations/7_performer_optimization.up.sql diff --git a/pkg/database/migrations/8_movie_fix.up.sql b/pkg/sqlite/migrations/8_movie_fix.up.sql similarity index 100% rename from pkg/database/migrations/8_movie_fix.up.sql rename to pkg/sqlite/migrations/8_movie_fix.up.sql diff --git a/pkg/database/migrations/9_studios_parent_studio.up.sql b/pkg/sqlite/migrations/9_studios_parent_studio.up.sql similarity index 100% rename from pkg/database/migrations/9_studios_parent_studio.up.sql rename to pkg/sqlite/migrations/9_studios_parent_studio.up.sql diff --git a/pkg/sqlite/movies.go b/pkg/sqlite/movies.go index eac02ae54..c52556e15 100644 --- a/pkg/sqlite/movies.go +++ b/pkg/sqlite/movies.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "database/sql" "errors" "fmt" @@ -15,50 +16,47 @@ type movieQueryBuilder struct { repository } -func NewMovieReaderWriter(tx dbi) *movieQueryBuilder { - return &movieQueryBuilder{ - repository{ - tx: tx, - tableName: movieTable, - idColumn: idColumn, - }, - } +var MovieReaderWriter = &movieQueryBuilder{ + repository{ + tableName: movieTable, + idColumn: idColumn, + }, } -func (qb *movieQueryBuilder) Create(newObject models.Movie) (*models.Movie, error) { +func (qb *movieQueryBuilder) Create(ctx context.Context, newObject models.Movie) (*models.Movie, error) { var ret models.Movie - if err := qb.insertObject(newObject, &ret); err != nil { + if err := qb.insertObject(ctx, newObject, &ret); err != nil { return nil, err } return &ret, nil } -func (qb *movieQueryBuilder) Update(updatedObject models.MoviePartial) (*models.Movie, error) { +func (qb *movieQueryBuilder) Update(ctx context.Context, updatedObject models.MoviePartial) (*models.Movie, error) { const partial = true - if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil { return nil, err } - return qb.Find(updatedObject.ID) + return qb.Find(ctx, updatedObject.ID) } -func (qb *movieQueryBuilder) UpdateFull(updatedObject models.Movie) (*models.Movie, error) { +func (qb *movieQueryBuilder) UpdateFull(ctx context.Context, updatedObject models.Movie) (*models.Movie, error) { const partial = false - if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil { return nil, err } - return qb.Find(updatedObject.ID) + return qb.Find(ctx, updatedObject.ID) } -func (qb *movieQueryBuilder) Destroy(id int) error { - return qb.destroyExisting([]int{id}) +func (qb *movieQueryBuilder) Destroy(ctx context.Context, id int) error { + return qb.destroyExisting(ctx, []int{id}) } -func (qb *movieQueryBuilder) Find(id int) (*models.Movie, error) { +func (qb *movieQueryBuilder) Find(ctx context.Context, id int) (*models.Movie, error) { var ret models.Movie - if err := qb.get(id, &ret); err != nil { + if err := qb.getByID(ctx, id, &ret); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -67,10 +65,10 @@ func (qb *movieQueryBuilder) Find(id int) (*models.Movie, error) { return &ret, nil } -func (qb *movieQueryBuilder) FindMany(ids []int) ([]*models.Movie, error) { +func (qb *movieQueryBuilder) FindMany(ctx context.Context, ids []int) ([]*models.Movie, error) { var movies []*models.Movie for _, id := range ids { - movie, err := qb.Find(id) + movie, err := qb.Find(ctx, id) if err != nil { return nil, err } @@ -85,17 +83,17 @@ func (qb *movieQueryBuilder) FindMany(ids []int) ([]*models.Movie, error) { return movies, nil } -func (qb *movieQueryBuilder) FindByName(name string, nocase bool) (*models.Movie, error) { +func (qb *movieQueryBuilder) FindByName(ctx context.Context, name string, nocase bool) (*models.Movie, error) { query := "SELECT * FROM movies WHERE name = ?" if nocase { query += " COLLATE NOCASE" } query += " LIMIT 1" args := []interface{}{name} - return qb.queryMovie(query, args) + return qb.queryMovie(ctx, query, args) } -func (qb *movieQueryBuilder) FindByNames(names []string, nocase bool) ([]*models.Movie, error) { +func (qb *movieQueryBuilder) FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Movie, error) { query := "SELECT * FROM movies WHERE name" if nocase { query += " COLLATE NOCASE" @@ -105,34 +103,34 @@ func (qb *movieQueryBuilder) FindByNames(names []string, nocase bool) ([]*models for _, name := range names { args = append(args, name) } - return qb.queryMovies(query, args) + return qb.queryMovies(ctx, query, args) } -func (qb *movieQueryBuilder) Count() (int, error) { - return qb.runCountQuery(qb.buildCountQuery("SELECT movies.id FROM movies"), nil) +func (qb *movieQueryBuilder) Count(ctx context.Context) (int, error) { + return qb.runCountQuery(ctx, qb.buildCountQuery("SELECT movies.id FROM movies"), nil) } -func (qb *movieQueryBuilder) All() ([]*models.Movie, error) { - return qb.queryMovies(selectAll("movies")+qb.getMovieSort(nil), nil) +func (qb *movieQueryBuilder) All(ctx context.Context) ([]*models.Movie, error) { + return qb.queryMovies(ctx, selectAll("movies")+qb.getMovieSort(nil), nil) } -func (qb *movieQueryBuilder) makeFilter(movieFilter *models.MovieFilterType) *filterBuilder { +func (qb *movieQueryBuilder) makeFilter(ctx context.Context, movieFilter *models.MovieFilterType) *filterBuilder { query := &filterBuilder{} - query.handleCriterion(stringCriterionHandler(movieFilter.Name, "movies.name")) - query.handleCriterion(stringCriterionHandler(movieFilter.Director, "movies.director")) - query.handleCriterion(stringCriterionHandler(movieFilter.Synopsis, "movies.synopsis")) - query.handleCriterion(intCriterionHandler(movieFilter.Rating, "movies.rating")) - query.handleCriterion(durationCriterionHandler(movieFilter.Duration, "movies.duration")) - query.handleCriterion(movieIsMissingCriterionHandler(qb, movieFilter.IsMissing)) - query.handleCriterion(stringCriterionHandler(movieFilter.URL, "movies.url")) - query.handleCriterion(movieStudioCriterionHandler(qb, movieFilter.Studios)) - query.handleCriterion(moviePerformersCriterionHandler(qb, movieFilter.Performers)) + query.handleCriterion(ctx, stringCriterionHandler(movieFilter.Name, "movies.name")) + query.handleCriterion(ctx, stringCriterionHandler(movieFilter.Director, "movies.director")) + query.handleCriterion(ctx, stringCriterionHandler(movieFilter.Synopsis, "movies.synopsis")) + query.handleCriterion(ctx, intCriterionHandler(movieFilter.Rating, "movies.rating")) + query.handleCriterion(ctx, durationCriterionHandler(movieFilter.Duration, "movies.duration")) + query.handleCriterion(ctx, movieIsMissingCriterionHandler(qb, movieFilter.IsMissing)) + query.handleCriterion(ctx, stringCriterionHandler(movieFilter.URL, "movies.url")) + query.handleCriterion(ctx, movieStudioCriterionHandler(qb, movieFilter.Studios)) + query.handleCriterion(ctx, moviePerformersCriterionHandler(qb, movieFilter.Performers)) return query } -func (qb *movieQueryBuilder) Query(movieFilter *models.MovieFilterType, findFilter *models.FindFilterType) ([]*models.Movie, int, error) { +func (qb *movieQueryBuilder) Query(ctx context.Context, movieFilter *models.MovieFilterType, findFilter *models.FindFilterType) ([]*models.Movie, int, error) { if findFilter == nil { findFilter = &models.FindFilterType{} } @@ -148,19 +146,19 @@ func (qb *movieQueryBuilder) Query(movieFilter *models.MovieFilterType, findFilt query.parseQueryString(searchColumns, *q) } - filter := qb.makeFilter(movieFilter) + filter := qb.makeFilter(ctx, movieFilter) query.addFilter(filter) query.sortAndPagination = qb.getMovieSort(findFilter) + getPagination(findFilter) - idsResult, countResult, err := query.executeFind() + idsResult, countResult, err := query.executeFind(ctx) if err != nil { return nil, 0, err } var movies []*models.Movie for _, id := range idsResult { - movie, err := qb.Find(id) + movie, err := qb.Find(ctx, id) if err != nil { return nil, 0, err } @@ -172,7 +170,7 @@ func (qb *movieQueryBuilder) Query(movieFilter *models.MovieFilterType, findFilt } func movieIsMissingCriterionHandler(qb *movieQueryBuilder, isMissing *string) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if isMissing != nil && *isMissing != "" { switch *isMissing { case "front_image": @@ -206,7 +204,7 @@ func movieStudioCriterionHandler(qb *movieQueryBuilder, studios *models.Hierarch } func moviePerformersCriterionHandler(qb *movieQueryBuilder, performers *models.MultiCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if performers != nil { if performers.Modifier == models.CriterionModifierIsNull || performers.Modifier == models.CriterionModifierNotNull { var notClause string @@ -273,30 +271,30 @@ func (qb *movieQueryBuilder) getMovieSort(findFilter *models.FindFilterType) str } } -func (qb *movieQueryBuilder) queryMovie(query string, args []interface{}) (*models.Movie, error) { - results, err := qb.queryMovies(query, args) +func (qb *movieQueryBuilder) queryMovie(ctx context.Context, query string, args []interface{}) (*models.Movie, error) { + results, err := qb.queryMovies(ctx, query, args) if err != nil || len(results) < 1 { return nil, err } return results[0], nil } -func (qb *movieQueryBuilder) queryMovies(query string, args []interface{}) ([]*models.Movie, error) { +func (qb *movieQueryBuilder) queryMovies(ctx context.Context, query string, args []interface{}) ([]*models.Movie, error) { var ret models.Movies - if err := qb.query(query, args, &ret); err != nil { + if err := qb.query(ctx, query, args, &ret); err != nil { return nil, err } return []*models.Movie(ret), nil } -func (qb *movieQueryBuilder) UpdateImages(movieID int, frontImage []byte, backImage []byte) error { +func (qb *movieQueryBuilder) UpdateImages(ctx context.Context, movieID int, frontImage []byte, backImage []byte) error { // Delete the existing cover and then create new - if err := qb.DestroyImages(movieID); err != nil { + if err := qb.DestroyImages(ctx, movieID); err != nil { return err } - _, err := qb.tx.Exec( + _, err := qb.tx.Exec(ctx, `INSERT INTO movies_images (movie_id, front_image, back_image) VALUES (?, ?, ?)`, movieID, frontImage, @@ -306,26 +304,26 @@ func (qb *movieQueryBuilder) UpdateImages(movieID int, frontImage []byte, backIm return err } -func (qb *movieQueryBuilder) DestroyImages(movieID int) error { +func (qb *movieQueryBuilder) DestroyImages(ctx context.Context, movieID int) error { // Delete the existing joins - _, err := qb.tx.Exec("DELETE FROM movies_images WHERE movie_id = ?", movieID) + _, err := qb.tx.Exec(ctx, "DELETE FROM movies_images WHERE movie_id = ?", movieID) if err != nil { return err } return err } -func (qb *movieQueryBuilder) GetFrontImage(movieID int) ([]byte, error) { +func (qb *movieQueryBuilder) GetFrontImage(ctx context.Context, movieID int) ([]byte, error) { query := `SELECT front_image from movies_images WHERE movie_id = ?` - return getImage(qb.tx, query, movieID) + return getImage(ctx, qb.tx, query, movieID) } -func (qb *movieQueryBuilder) GetBackImage(movieID int) ([]byte, error) { +func (qb *movieQueryBuilder) GetBackImage(ctx context.Context, movieID int) ([]byte, error) { query := `SELECT back_image from movies_images WHERE movie_id = ?` - return getImage(qb.tx, query, movieID) + return getImage(ctx, qb.tx, query, movieID) } -func (qb *movieQueryBuilder) FindByPerformerID(performerID int) ([]*models.Movie, error) { +func (qb *movieQueryBuilder) FindByPerformerID(ctx context.Context, performerID int) ([]*models.Movie, error) { query := `SELECT DISTINCT movies.* FROM movies INNER JOIN movies_scenes ON movies.id = movies_scenes.movie_id @@ -333,33 +331,33 @@ INNER JOIN performers_scenes ON performers_scenes.scene_id = movies_scenes.scene WHERE performers_scenes.performer_id = ? ` args := []interface{}{performerID} - return qb.queryMovies(query, args) + return qb.queryMovies(ctx, query, args) } -func (qb *movieQueryBuilder) CountByPerformerID(performerID int) (int, error) { +func (qb *movieQueryBuilder) CountByPerformerID(ctx context.Context, performerID int) (int, error) { query := `SELECT COUNT(DISTINCT movies_scenes.movie_id) AS count FROM movies_scenes INNER JOIN performers_scenes ON performers_scenes.scene_id = movies_scenes.scene_id WHERE performers_scenes.performer_id = ? ` args := []interface{}{performerID} - return qb.runCountQuery(query, args) + return qb.runCountQuery(ctx, query, args) } -func (qb *movieQueryBuilder) FindByStudioID(studioID int) ([]*models.Movie, error) { +func (qb *movieQueryBuilder) FindByStudioID(ctx context.Context, studioID int) ([]*models.Movie, error) { query := `SELECT movies.* FROM movies WHERE movies.studio_id = ? ` args := []interface{}{studioID} - return qb.queryMovies(query, args) + return qb.queryMovies(ctx, query, args) } -func (qb *movieQueryBuilder) CountByStudioID(studioID int) (int, error) { +func (qb *movieQueryBuilder) CountByStudioID(ctx context.Context, studioID int) (int, error) { query := `SELECT COUNT(1) AS count FROM movies WHERE movies.studio_id = ? ` args := []interface{}{studioID} - return qb.runCountQuery(query, args) + return qb.runCountQuery(ctx, query, args) } diff --git a/pkg/sqlite/movies_test.go b/pkg/sqlite/movies_test.go index 75c6cc5bf..eff0cf50b 100644 --- a/pkg/sqlite/movies_test.go +++ b/pkg/sqlite/movies_test.go @@ -4,6 +4,7 @@ package sqlite_test import ( + "context" "database/sql" "fmt" "strconv" @@ -14,15 +15,16 @@ import ( "github.com/stashapp/stash/pkg/hash/md5" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/sqlite" ) func TestMovieFindByName(t *testing.T) { - withTxn(func(r models.Repository) error { - mqb := r.Movie() + withTxn(func(ctx context.Context) error { + mqb := sqlite.MovieReaderWriter name := movieNames[movieIdxWithScene] // find a movie by name - movie, err := mqb.FindByName(name, false) + movie, err := mqb.FindByName(ctx, name, false) if err != nil { t.Errorf("Error finding movies: %s", err.Error()) @@ -32,7 +34,7 @@ func TestMovieFindByName(t *testing.T) { name = movieNames[movieIdxWithDupName] // find a movie by name nocase - movie, err = mqb.FindByName(name, true) + movie, err = mqb.FindByName(ctx, name, true) if err != nil { t.Errorf("Error finding movies: %s", err.Error()) @@ -48,21 +50,21 @@ func TestMovieFindByName(t *testing.T) { } func TestMovieFindByNames(t *testing.T) { - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { var names []string - mqb := r.Movie() + mqb := sqlite.MovieReaderWriter names = append(names, movieNames[movieIdxWithScene]) // find movies by names - movies, err := mqb.FindByNames(names, false) + movies, err := mqb.FindByNames(ctx, names, false) if err != nil { t.Errorf("Error finding movies: %s", err.Error()) } assert.Len(t, movies, 1) assert.Equal(t, movieNames[movieIdxWithScene], movies[0].Name.String) - movies, err = mqb.FindByNames(names, true) // find movies by names nocase + movies, err = mqb.FindByNames(ctx, names, true) // find movies by names nocase if err != nil { t.Errorf("Error finding movies: %s", err.Error()) } @@ -75,8 +77,8 @@ func TestMovieFindByNames(t *testing.T) { } func TestMovieQueryStudio(t *testing.T) { - withTxn(func(r models.Repository) error { - mqb := r.Movie() + withTxn(func(ctx context.Context) error { + mqb := sqlite.MovieReaderWriter studioCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(studioIDs[studioIdxWithMovie]), @@ -88,7 +90,7 @@ func TestMovieQueryStudio(t *testing.T) { Studios: &studioCriterion, } - movies, _, err := mqb.Query(&movieFilter, nil) + movies, _, err := mqb.Query(ctx, &movieFilter, nil) if err != nil { t.Errorf("Error querying movie: %s", err.Error()) } @@ -110,7 +112,7 @@ func TestMovieQueryStudio(t *testing.T) { Q: &q, } - movies, _, err = mqb.Query(&movieFilter, &findFilter) + movies, _, err = mqb.Query(ctx, &movieFilter, &findFilter) if err != nil { t.Errorf("Error querying movie: %s", err.Error()) } @@ -159,11 +161,11 @@ func TestMovieQueryURL(t *testing.T) { } func verifyMovieQuery(t *testing.T, filter models.MovieFilterType, verifyFn func(s *models.Movie)) { - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { t.Helper() - sqb := r.Movie() + sqb := sqlite.MovieReaderWriter - movies := queryMovie(t, sqb, &filter, nil) + movies := queryMovie(ctx, t, sqb, &filter, nil) // assume it should find at least one assert.Greater(t, len(movies), 0) @@ -176,8 +178,8 @@ func verifyMovieQuery(t *testing.T, filter models.MovieFilterType, verifyFn func }) } -func queryMovie(t *testing.T, sqb models.MovieReader, movieFilter *models.MovieFilterType, findFilter *models.FindFilterType) []*models.Movie { - movies, _, err := sqb.Query(movieFilter, findFilter) +func queryMovie(ctx context.Context, t *testing.T, sqb models.MovieReader, movieFilter *models.MovieFilterType, findFilter *models.FindFilterType) []*models.Movie { + movies, _, err := sqb.Query(ctx, movieFilter, findFilter) if err != nil { t.Errorf("Error querying movie: %s", err.Error()) } @@ -193,9 +195,9 @@ func TestMovieQuerySorting(t *testing.T) { Direction: &direction, } - withTxn(func(r models.Repository) error { - sqb := r.Movie() - movies := queryMovie(t, sqb, nil, &findFilter) + withTxn(func(ctx context.Context) error { + sqb := sqlite.MovieReaderWriter + movies := queryMovie(ctx, t, sqb, nil, &findFilter) // scenes should be in same order as indexes firstMovie := movies[0] @@ -205,7 +207,7 @@ func TestMovieQuerySorting(t *testing.T) { // sort in descending order direction = models.SortDirectionEnumAsc - movies = queryMovie(t, sqb, nil, &findFilter) + movies = queryMovie(ctx, t, sqb, nil, &findFilter) lastMovie := movies[len(movies)-1] assert.Equal(t, movieIDs[movieIdxWithScene], lastMovie.ID) @@ -215,8 +217,8 @@ func TestMovieQuerySorting(t *testing.T) { } func TestMovieUpdateMovieImages(t *testing.T) { - if err := withTxn(func(r models.Repository) error { - mqb := r.Movie() + if err := withTxn(func(ctx context.Context) error { + mqb := sqlite.MovieReaderWriter // create movie to test against const name = "TestMovieUpdateMovieImages" @@ -224,26 +226,26 @@ func TestMovieUpdateMovieImages(t *testing.T) { Name: sql.NullString{String: name, Valid: true}, Checksum: md5.FromString(name), } - created, err := mqb.Create(movie) + created, err := mqb.Create(ctx, movie) if err != nil { return fmt.Errorf("Error creating movie: %s", err.Error()) } frontImage := []byte("frontImage") backImage := []byte("backImage") - err = mqb.UpdateImages(created.ID, frontImage, backImage) + err = mqb.UpdateImages(ctx, created.ID, frontImage, backImage) if err != nil { return fmt.Errorf("Error updating movie images: %s", err.Error()) } // ensure images are set - storedFront, err := mqb.GetFrontImage(created.ID) + storedFront, err := mqb.GetFrontImage(ctx, created.ID) if err != nil { return fmt.Errorf("Error getting front image: %s", err.Error()) } assert.Equal(t, storedFront, frontImage) - storedBack, err := mqb.GetBackImage(created.ID) + storedBack, err := mqb.GetBackImage(ctx, created.ID) if err != nil { return fmt.Errorf("Error getting back image: %s", err.Error()) } @@ -251,26 +253,26 @@ func TestMovieUpdateMovieImages(t *testing.T) { // set front image only newImage := []byte("newImage") - err = mqb.UpdateImages(created.ID, newImage, nil) + err = mqb.UpdateImages(ctx, created.ID, newImage, nil) if err != nil { return fmt.Errorf("Error updating movie images: %s", err.Error()) } - storedFront, err = mqb.GetFrontImage(created.ID) + storedFront, err = mqb.GetFrontImage(ctx, created.ID) if err != nil { return fmt.Errorf("Error getting front image: %s", err.Error()) } assert.Equal(t, storedFront, newImage) // back image should be nil - storedBack, err = mqb.GetBackImage(created.ID) + storedBack, err = mqb.GetBackImage(ctx, created.ID) if err != nil { return fmt.Errorf("Error getting back image: %s", err.Error()) } assert.Nil(t, nil) // set back image only - err = mqb.UpdateImages(created.ID, nil, newImage) + err = mqb.UpdateImages(ctx, created.ID, nil, newImage) if err == nil { return fmt.Errorf("Expected error setting nil front image") } @@ -282,8 +284,8 @@ func TestMovieUpdateMovieImages(t *testing.T) { } func TestMovieDestroyMovieImages(t *testing.T) { - if err := withTxn(func(r models.Repository) error { - mqb := r.Movie() + if err := withTxn(func(ctx context.Context) error { + mqb := sqlite.MovieReaderWriter // create movie to test against const name = "TestMovieDestroyMovieImages" @@ -291,32 +293,32 @@ func TestMovieDestroyMovieImages(t *testing.T) { Name: sql.NullString{String: name, Valid: true}, Checksum: md5.FromString(name), } - created, err := mqb.Create(movie) + created, err := mqb.Create(ctx, movie) if err != nil { return fmt.Errorf("Error creating movie: %s", err.Error()) } frontImage := []byte("frontImage") backImage := []byte("backImage") - err = mqb.UpdateImages(created.ID, frontImage, backImage) + err = mqb.UpdateImages(ctx, created.ID, frontImage, backImage) if err != nil { return fmt.Errorf("Error updating movie images: %s", err.Error()) } - err = mqb.DestroyImages(created.ID) + err = mqb.DestroyImages(ctx, created.ID) if err != nil { return fmt.Errorf("Error destroying movie images: %s", err.Error()) } // front image should be nil - storedFront, err := mqb.GetFrontImage(created.ID) + storedFront, err := mqb.GetFrontImage(ctx, created.ID) if err != nil { return fmt.Errorf("Error getting front image: %s", err.Error()) } assert.Nil(t, storedFront) // back image should be nil - storedBack, err := mqb.GetBackImage(created.ID) + storedBack, err := mqb.GetBackImage(ctx, created.ID) if err != nil { return fmt.Errorf("Error getting back image: %s", err.Error()) } diff --git a/pkg/sqlite/performer.go b/pkg/sqlite/performer.go index 142be42ff..d81170f0c 100644 --- a/pkg/sqlite/performer.go +++ b/pkg/sqlite/performer.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "database/sql" "errors" "fmt" @@ -25,66 +26,63 @@ type performerQueryBuilder struct { repository } -func NewPerformerReaderWriter(tx dbi) *performerQueryBuilder { - return &performerQueryBuilder{ - repository{ - tx: tx, - tableName: performerTable, - idColumn: idColumn, - }, - } +var PerformerReaderWriter = &performerQueryBuilder{ + repository{ + tableName: performerTable, + idColumn: idColumn, + }, } -func (qb *performerQueryBuilder) Create(newObject models.Performer) (*models.Performer, error) { +func (qb *performerQueryBuilder) Create(ctx context.Context, newObject models.Performer) (*models.Performer, error) { var ret models.Performer - if err := qb.insertObject(newObject, &ret); err != nil { + if err := qb.insertObject(ctx, newObject, &ret); err != nil { return nil, err } return &ret, nil } -func (qb *performerQueryBuilder) Update(updatedObject models.PerformerPartial) (*models.Performer, error) { +func (qb *performerQueryBuilder) Update(ctx context.Context, updatedObject models.PerformerPartial) (*models.Performer, error) { const partial = true - if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil { return nil, err } var ret models.Performer - if err := qb.get(updatedObject.ID, &ret); err != nil { + if err := qb.getByID(ctx, updatedObject.ID, &ret); err != nil { return nil, err } return &ret, nil } -func (qb *performerQueryBuilder) UpdateFull(updatedObject models.Performer) (*models.Performer, error) { +func (qb *performerQueryBuilder) UpdateFull(ctx context.Context, updatedObject models.Performer) (*models.Performer, error) { const partial = false - if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil { return nil, err } var ret models.Performer - if err := qb.get(updatedObject.ID, &ret); err != nil { + if err := qb.getByID(ctx, updatedObject.ID, &ret); err != nil { return nil, err } return &ret, nil } -func (qb *performerQueryBuilder) Destroy(id int) error { +func (qb *performerQueryBuilder) Destroy(ctx context.Context, id int) error { // TODO - add on delete cascade to performers_scenes - _, err := qb.tx.Exec("DELETE FROM performers_scenes WHERE performer_id = ?", id) + _, err := qb.tx.Exec(ctx, "DELETE FROM performers_scenes WHERE performer_id = ?", id) if err != nil { return err } - return qb.destroyExisting([]int{id}) + return qb.destroyExisting(ctx, []int{id}) } -func (qb *performerQueryBuilder) Find(id int) (*models.Performer, error) { +func (qb *performerQueryBuilder) Find(ctx context.Context, id int) (*models.Performer, error) { var ret models.Performer - if err := qb.get(id, &ret); err != nil { + if err := qb.getByID(ctx, id, &ret); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -93,10 +91,10 @@ func (qb *performerQueryBuilder) Find(id int) (*models.Performer, error) { return &ret, nil } -func (qb *performerQueryBuilder) FindMany(ids []int) ([]*models.Performer, error) { +func (qb *performerQueryBuilder) FindMany(ctx context.Context, ids []int) ([]*models.Performer, error) { var performers []*models.Performer for _, id := range ids { - performer, err := qb.Find(id) + performer, err := qb.Find(ctx, id) if err != nil { return nil, err } @@ -111,44 +109,44 @@ func (qb *performerQueryBuilder) FindMany(ids []int) ([]*models.Performer, error return performers, nil } -func (qb *performerQueryBuilder) FindBySceneID(sceneID int) ([]*models.Performer, error) { +func (qb *performerQueryBuilder) FindBySceneID(ctx context.Context, sceneID int) ([]*models.Performer, error) { query := selectAll("performers") + ` LEFT JOIN performers_scenes as scenes_join on scenes_join.performer_id = performers.id WHERE scenes_join.scene_id = ? ` args := []interface{}{sceneID} - return qb.queryPerformers(query, args) + return qb.queryPerformers(ctx, query, args) } -func (qb *performerQueryBuilder) FindByImageID(imageID int) ([]*models.Performer, error) { +func (qb *performerQueryBuilder) FindByImageID(ctx context.Context, imageID int) ([]*models.Performer, error) { query := selectAll("performers") + ` LEFT JOIN performers_images as images_join on images_join.performer_id = performers.id WHERE images_join.image_id = ? ` args := []interface{}{imageID} - return qb.queryPerformers(query, args) + return qb.queryPerformers(ctx, query, args) } -func (qb *performerQueryBuilder) FindByGalleryID(galleryID int) ([]*models.Performer, error) { +func (qb *performerQueryBuilder) FindByGalleryID(ctx context.Context, galleryID int) ([]*models.Performer, error) { query := selectAll("performers") + ` LEFT JOIN performers_galleries as galleries_join on galleries_join.performer_id = performers.id WHERE galleries_join.gallery_id = ? ` args := []interface{}{galleryID} - return qb.queryPerformers(query, args) + return qb.queryPerformers(ctx, query, args) } -func (qb *performerQueryBuilder) FindNamesBySceneID(sceneID int) ([]*models.Performer, error) { +func (qb *performerQueryBuilder) FindNamesBySceneID(ctx context.Context, sceneID int) ([]*models.Performer, error) { query := ` SELECT performers.name FROM performers LEFT JOIN performers_scenes as scenes_join on scenes_join.performer_id = performers.id WHERE scenes_join.scene_id = ? ` args := []interface{}{sceneID} - return qb.queryPerformers(query, args) + return qb.queryPerformers(ctx, query, args) } -func (qb *performerQueryBuilder) FindByNames(names []string, nocase bool) ([]*models.Performer, error) { +func (qb *performerQueryBuilder) FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Performer, error) { query := "SELECT * FROM performers WHERE name" if nocase { query += " COLLATE NOCASE" @@ -159,23 +157,23 @@ func (qb *performerQueryBuilder) FindByNames(names []string, nocase bool) ([]*mo for _, name := range names { args = append(args, name) } - return qb.queryPerformers(query, args) + return qb.queryPerformers(ctx, query, args) } -func (qb *performerQueryBuilder) CountByTagID(tagID int) (int, error) { +func (qb *performerQueryBuilder) CountByTagID(ctx context.Context, tagID int) (int, error) { args := []interface{}{tagID} - return qb.runCountQuery(qb.buildCountQuery(countPerformersForTagQuery), args) + return qb.runCountQuery(ctx, qb.buildCountQuery(countPerformersForTagQuery), args) } -func (qb *performerQueryBuilder) Count() (int, error) { - return qb.runCountQuery(qb.buildCountQuery("SELECT performers.id FROM performers"), nil) +func (qb *performerQueryBuilder) Count(ctx context.Context) (int, error) { + return qb.runCountQuery(ctx, qb.buildCountQuery("SELECT performers.id FROM performers"), nil) } -func (qb *performerQueryBuilder) All() ([]*models.Performer, error) { - return qb.queryPerformers(selectAll("performers")+qb.getPerformerSort(nil), nil) +func (qb *performerQueryBuilder) All(ctx context.Context) ([]*models.Performer, error) { + return qb.queryPerformers(ctx, selectAll("performers")+qb.getPerformerSort(nil), nil) } -func (qb *performerQueryBuilder) QueryForAutoTag(words []string) ([]*models.Performer, error) { +func (qb *performerQueryBuilder) QueryForAutoTag(ctx context.Context, words []string) ([]*models.Performer, error) { // TODO - Query needs to be changed to support queries of this type, and // this method should be removed query := selectAll(performerTable) @@ -196,7 +194,7 @@ func (qb *performerQueryBuilder) QueryForAutoTag(words []string) ([]*models.Perf "ignore_auto_tag = 0", whereOr, }, " AND ") - return qb.queryPerformers(query+" WHERE "+where, args) + return qb.queryPerformers(ctx, query+" WHERE "+where, args) } func (qb *performerQueryBuilder) validateFilter(filter *models.PerformerFilterType) error { @@ -230,74 +228,74 @@ func (qb *performerQueryBuilder) validateFilter(filter *models.PerformerFilterTy return nil } -func (qb *performerQueryBuilder) makeFilter(filter *models.PerformerFilterType) *filterBuilder { +func (qb *performerQueryBuilder) makeFilter(ctx context.Context, filter *models.PerformerFilterType) *filterBuilder { query := &filterBuilder{} if filter.And != nil { - query.and(qb.makeFilter(filter.And)) + query.and(qb.makeFilter(ctx, filter.And)) } if filter.Or != nil { - query.or(qb.makeFilter(filter.Or)) + query.or(qb.makeFilter(ctx, filter.Or)) } if filter.Not != nil { - query.not(qb.makeFilter(filter.Not)) + query.not(qb.makeFilter(ctx, filter.Not)) } const tableName = performerTable - query.handleCriterion(stringCriterionHandler(filter.Name, tableName+".name")) - query.handleCriterion(stringCriterionHandler(filter.Details, tableName+".details")) + query.handleCriterion(ctx, stringCriterionHandler(filter.Name, tableName+".name")) + query.handleCriterion(ctx, stringCriterionHandler(filter.Details, tableName+".details")) - query.handleCriterion(boolCriterionHandler(filter.FilterFavorites, tableName+".favorite")) - query.handleCriterion(boolCriterionHandler(filter.IgnoreAutoTag, tableName+".ignore_auto_tag")) + query.handleCriterion(ctx, boolCriterionHandler(filter.FilterFavorites, tableName+".favorite")) + query.handleCriterion(ctx, boolCriterionHandler(filter.IgnoreAutoTag, tableName+".ignore_auto_tag")) - query.handleCriterion(yearFilterCriterionHandler(filter.BirthYear, tableName+".birthdate")) - query.handleCriterion(yearFilterCriterionHandler(filter.DeathYear, tableName+".death_date")) + query.handleCriterion(ctx, yearFilterCriterionHandler(filter.BirthYear, tableName+".birthdate")) + query.handleCriterion(ctx, yearFilterCriterionHandler(filter.DeathYear, tableName+".death_date")) - query.handleCriterion(performerAgeFilterCriterionHandler(filter.Age)) + query.handleCriterion(ctx, performerAgeFilterCriterionHandler(filter.Age)) - query.handleCriterion(criterionHandlerFunc(func(f *filterBuilder) { + query.handleCriterion(ctx, criterionHandlerFunc(func(ctx context.Context, f *filterBuilder) { if gender := filter.Gender; gender != nil { f.addWhere(tableName+".gender = ?", gender.Value.String()) } })) - query.handleCriterion(performerIsMissingCriterionHandler(qb, filter.IsMissing)) - query.handleCriterion(stringCriterionHandler(filter.Ethnicity, tableName+".ethnicity")) - query.handleCriterion(stringCriterionHandler(filter.Country, tableName+".country")) - query.handleCriterion(stringCriterionHandler(filter.EyeColor, tableName+".eye_color")) - query.handleCriterion(stringCriterionHandler(filter.Height, tableName+".height")) - query.handleCriterion(stringCriterionHandler(filter.Measurements, tableName+".measurements")) - query.handleCriterion(stringCriterionHandler(filter.FakeTits, tableName+".fake_tits")) - query.handleCriterion(stringCriterionHandler(filter.CareerLength, tableName+".career_length")) - query.handleCriterion(stringCriterionHandler(filter.Tattoos, tableName+".tattoos")) - query.handleCriterion(stringCriterionHandler(filter.Piercings, tableName+".piercings")) - query.handleCriterion(intCriterionHandler(filter.Rating, tableName+".rating")) - query.handleCriterion(stringCriterionHandler(filter.HairColor, tableName+".hair_color")) - query.handleCriterion(stringCriterionHandler(filter.URL, tableName+".url")) - query.handleCriterion(intCriterionHandler(filter.Weight, tableName+".weight")) - query.handleCriterion(criterionHandlerFunc(func(f *filterBuilder) { + query.handleCriterion(ctx, performerIsMissingCriterionHandler(qb, filter.IsMissing)) + query.handleCriterion(ctx, stringCriterionHandler(filter.Ethnicity, tableName+".ethnicity")) + query.handleCriterion(ctx, stringCriterionHandler(filter.Country, tableName+".country")) + query.handleCriterion(ctx, stringCriterionHandler(filter.EyeColor, tableName+".eye_color")) + query.handleCriterion(ctx, stringCriterionHandler(filter.Height, tableName+".height")) + query.handleCriterion(ctx, stringCriterionHandler(filter.Measurements, tableName+".measurements")) + query.handleCriterion(ctx, stringCriterionHandler(filter.FakeTits, tableName+".fake_tits")) + query.handleCriterion(ctx, stringCriterionHandler(filter.CareerLength, tableName+".career_length")) + query.handleCriterion(ctx, stringCriterionHandler(filter.Tattoos, tableName+".tattoos")) + query.handleCriterion(ctx, stringCriterionHandler(filter.Piercings, tableName+".piercings")) + query.handleCriterion(ctx, intCriterionHandler(filter.Rating, tableName+".rating")) + query.handleCriterion(ctx, stringCriterionHandler(filter.HairColor, tableName+".hair_color")) + query.handleCriterion(ctx, stringCriterionHandler(filter.URL, tableName+".url")) + query.handleCriterion(ctx, intCriterionHandler(filter.Weight, tableName+".weight")) + query.handleCriterion(ctx, criterionHandlerFunc(func(ctx context.Context, f *filterBuilder) { if filter.StashID != nil { qb.stashIDRepository().join(f, "performer_stash_ids", "performers.id") - stringCriterionHandler(filter.StashID, "performer_stash_ids.stash_id")(f) + stringCriterionHandler(filter.StashID, "performer_stash_ids.stash_id")(ctx, f) } })) // TODO - need better handling of aliases - query.handleCriterion(stringCriterionHandler(filter.Aliases, tableName+".aliases")) + query.handleCriterion(ctx, stringCriterionHandler(filter.Aliases, tableName+".aliases")) - query.handleCriterion(performerTagsCriterionHandler(qb, filter.Tags)) + query.handleCriterion(ctx, performerTagsCriterionHandler(qb, filter.Tags)) - query.handleCriterion(performerStudiosCriterionHandler(qb, filter.Studios)) + query.handleCriterion(ctx, performerStudiosCriterionHandler(qb, filter.Studios)) - query.handleCriterion(performerTagCountCriterionHandler(qb, filter.TagCount)) - query.handleCriterion(performerSceneCountCriterionHandler(qb, filter.SceneCount)) - query.handleCriterion(performerImageCountCriterionHandler(qb, filter.ImageCount)) - query.handleCriterion(performerGalleryCountCriterionHandler(qb, filter.GalleryCount)) + query.handleCriterion(ctx, performerTagCountCriterionHandler(qb, filter.TagCount)) + query.handleCriterion(ctx, performerSceneCountCriterionHandler(qb, filter.SceneCount)) + query.handleCriterion(ctx, performerImageCountCriterionHandler(qb, filter.ImageCount)) + query.handleCriterion(ctx, performerGalleryCountCriterionHandler(qb, filter.GalleryCount)) return query } -func (qb *performerQueryBuilder) Query(performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) ([]*models.Performer, int, error) { +func (qb *performerQueryBuilder) Query(ctx context.Context, performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) ([]*models.Performer, int, error) { if performerFilter == nil { performerFilter = &models.PerformerFilterType{} } @@ -316,19 +314,19 @@ func (qb *performerQueryBuilder) Query(performerFilter *models.PerformerFilterTy if err := qb.validateFilter(performerFilter); err != nil { return nil, 0, err } - filter := qb.makeFilter(performerFilter) + filter := qb.makeFilter(ctx, performerFilter) query.addFilter(filter) query.sortAndPagination = qb.getPerformerSort(findFilter) + getPagination(findFilter) - idsResult, countResult, err := query.executeFind() + idsResult, countResult, err := query.executeFind(ctx) if err != nil { return nil, 0, err } var performers []*models.Performer for _, id := range idsResult { - performer, err := qb.Find(id) + performer, err := qb.Find(ctx, id) if err != nil { return nil, 0, err } @@ -339,7 +337,7 @@ func (qb *performerQueryBuilder) Query(performerFilter *models.PerformerFilterTy } func performerIsMissingCriterionHandler(qb *performerQueryBuilder, isMissing *string) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if isMissing != nil && *isMissing != "" { switch *isMissing { case "scenes": // Deprecated: use `scene_count == 0` filter instead @@ -359,7 +357,7 @@ func performerIsMissingCriterionHandler(qb *performerQueryBuilder, isMissing *st } func yearFilterCriterionHandler(year *models.IntCriterionInput, col string) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if year != nil && year.Modifier.IsValid() { clause, args := getIntCriterionWhereClause("cast(strftime('%Y', "+col+") as int)", *year) f.addWhere(clause, args...) @@ -368,7 +366,7 @@ func yearFilterCriterionHandler(year *models.IntCriterionInput, col string) crit } func performerAgeFilterCriterionHandler(age *models.IntCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if age != nil && age.Modifier.IsValid() { clause, args := getIntCriterionWhereClause( "cast(IFNULL(strftime('%Y.%m%d', performers.death_date), strftime('%Y.%m%d', 'now')) - strftime('%Y.%m%d', performers.birthdate) as int)", @@ -437,7 +435,7 @@ func performerGalleryCountCriterionHandler(qb *performerQueryBuilder, count *mod } func performerStudiosCriterionHandler(qb *performerQueryBuilder, studios *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if studios != nil { formatMaps := []utils.StrFormatMap{ { @@ -493,7 +491,7 @@ func performerStudiosCriterionHandler(qb *performerQueryBuilder, studios *models } const derivedPerformerStudioTable = "performer_studio" - valuesClause := getHierarchicalValues(qb.tx, studios.Value, studioTable, "", "parent_id", studios.Depth) + valuesClause := getHierarchicalValues(ctx, qb.tx, studios.Value, studioTable, "", "parent_id", studios.Depth) f.addWith("studio(root_id, item_id) AS (" + valuesClause + ")") templStr := `SELECT performer_id FROM {primaryTable} @@ -540,9 +538,9 @@ func (qb *performerQueryBuilder) getPerformerSort(findFilter *models.FindFilterT return getSort(sort, direction, "performers") } -func (qb *performerQueryBuilder) queryPerformers(query string, args []interface{}) ([]*models.Performer, error) { +func (qb *performerQueryBuilder) queryPerformers(ctx context.Context, query string, args []interface{}) ([]*models.Performer, error) { var ret models.Performers - if err := qb.query(query, args, &ret); err != nil { + if err := qb.query(ctx, query, args, &ret); err != nil { return nil, err } @@ -560,13 +558,13 @@ func (qb *performerQueryBuilder) tagsRepository() *joinRepository { } } -func (qb *performerQueryBuilder) GetTagIDs(id int) ([]int, error) { - return qb.tagsRepository().getIDs(id) +func (qb *performerQueryBuilder) GetTagIDs(ctx context.Context, id int) ([]int, error) { + return qb.tagsRepository().getIDs(ctx, id) } -func (qb *performerQueryBuilder) UpdateTags(id int, tagIDs []int) error { +func (qb *performerQueryBuilder) UpdateTags(ctx context.Context, id int, tagIDs []int) error { // Delete the existing joins and then create new ones - return qb.tagsRepository().replace(id, tagIDs) + return qb.tagsRepository().replace(ctx, id, tagIDs) } func (qb *performerQueryBuilder) imageRepository() *imageRepository { @@ -580,16 +578,16 @@ func (qb *performerQueryBuilder) imageRepository() *imageRepository { } } -func (qb *performerQueryBuilder) GetImage(performerID int) ([]byte, error) { - return qb.imageRepository().get(performerID) +func (qb *performerQueryBuilder) GetImage(ctx context.Context, performerID int) ([]byte, error) { + return qb.imageRepository().get(ctx, performerID) } -func (qb *performerQueryBuilder) UpdateImage(performerID int, image []byte) error { - return qb.imageRepository().replace(performerID, image) +func (qb *performerQueryBuilder) UpdateImage(ctx context.Context, performerID int, image []byte) error { + return qb.imageRepository().replace(ctx, performerID, image) } -func (qb *performerQueryBuilder) DestroyImage(performerID int) error { - return qb.imageRepository().destroy([]int{performerID}) +func (qb *performerQueryBuilder) DestroyImage(ctx context.Context, performerID int) error { + return qb.imageRepository().destroy(ctx, []int{performerID}) } func (qb *performerQueryBuilder) stashIDRepository() *stashIDRepository { @@ -602,25 +600,25 @@ func (qb *performerQueryBuilder) stashIDRepository() *stashIDRepository { } } -func (qb *performerQueryBuilder) GetStashIDs(performerID int) ([]*models.StashID, error) { - return qb.stashIDRepository().get(performerID) +func (qb *performerQueryBuilder) GetStashIDs(ctx context.Context, performerID int) ([]*models.StashID, error) { + return qb.stashIDRepository().get(ctx, performerID) } -func (qb *performerQueryBuilder) UpdateStashIDs(performerID int, stashIDs []models.StashID) error { - return qb.stashIDRepository().replace(performerID, stashIDs) +func (qb *performerQueryBuilder) UpdateStashIDs(ctx context.Context, performerID int, stashIDs []models.StashID) error { + return qb.stashIDRepository().replace(ctx, performerID, stashIDs) } -func (qb *performerQueryBuilder) FindByStashID(stashID models.StashID) ([]*models.Performer, error) { +func (qb *performerQueryBuilder) FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Performer, error) { query := selectAll("performers") + ` LEFT JOIN performer_stash_ids on performer_stash_ids.performer_id = performers.id WHERE performer_stash_ids.stash_id = ? AND performer_stash_ids.endpoint = ? ` args := []interface{}{stashID.StashID, stashID.Endpoint} - return qb.queryPerformers(query, args) + return qb.queryPerformers(ctx, query, args) } -func (qb *performerQueryBuilder) FindByStashIDStatus(hasStashID bool, stashboxEndpoint string) ([]*models.Performer, error) { +func (qb *performerQueryBuilder) FindByStashIDStatus(ctx context.Context, hasStashID bool, stashboxEndpoint string) ([]*models.Performer, error) { query := selectAll("performers") + ` LEFT JOIN performer_stash_ids on performer_stash_ids.performer_id = performers.id ` @@ -637,5 +635,5 @@ func (qb *performerQueryBuilder) FindByStashIDStatus(hasStashID bool, stashboxEn } args := []interface{}{stashboxEndpoint} - return qb.queryPerformers(query, args) + return qb.queryPerformers(ctx, query, args) } diff --git a/pkg/sqlite/performer_test.go b/pkg/sqlite/performer_test.go index a6839f573..7be6eb4fd 100644 --- a/pkg/sqlite/performer_test.go +++ b/pkg/sqlite/performer_test.go @@ -4,6 +4,7 @@ package sqlite_test import ( + "context" "database/sql" "fmt" "math" @@ -16,14 +17,15 @@ import ( "github.com/stashapp/stash/pkg/hash/md5" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/sqlite" ) func TestPerformerFindBySceneID(t *testing.T) { - withTxn(func(r models.Repository) error { - pqb := r.Performer() + withTxn(func(ctx context.Context) error { + pqb := sqlite.PerformerReaderWriter sceneID := sceneIDs[sceneIdxWithPerformer] - performers, err := pqb.FindBySceneID(sceneID) + performers, err := pqb.FindBySceneID(ctx, sceneID) if err != nil { t.Errorf("Error finding performer: %s", err.Error()) @@ -34,7 +36,7 @@ func TestPerformerFindBySceneID(t *testing.T) { assert.Equal(t, getPerformerStringValue(performerIdxWithScene, "Name"), performer.Name.String) - performers, err = pqb.FindBySceneID(0) + performers, err = pqb.FindBySceneID(ctx, 0) if err != nil { t.Errorf("Error finding performer: %s", err.Error()) @@ -55,21 +57,21 @@ func TestPerformerFindByNames(t *testing.T) { return ret } - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { var names []string - pqb := r.Performer() + pqb := sqlite.PerformerReaderWriter names = append(names, performerNames[performerIdxWithScene]) // find performers by names - performers, err := pqb.FindByNames(names, false) + performers, err := pqb.FindByNames(ctx, names, false) if err != nil { t.Errorf("Error finding performers: %s", err.Error()) } assert.Len(t, performers, 1) assert.Equal(t, performerNames[performerIdxWithScene], performers[0].Name.String) - performers, err = pqb.FindByNames(names, true) // find performers by names nocase + performers, err = pqb.FindByNames(ctx, names, true) // find performers by names nocase if err != nil { t.Errorf("Error finding performers: %s", err.Error()) } @@ -79,14 +81,14 @@ func TestPerformerFindByNames(t *testing.T) { names = append(names, performerNames[performerIdx1WithScene]) // find performers by names ( 2 names ) - performers, err = pqb.FindByNames(names, false) + performers, err = pqb.FindByNames(ctx, names, false) if err != nil { t.Errorf("Error finding performers: %s", err.Error()) } retNames := getNames(performers) assert.Equal(t, names, retNames) - performers, err = pqb.FindByNames(names, true) // find performers by names ( 2 names nocase) + performers, err = pqb.FindByNames(ctx, names, true) // find performers by names ( 2 names nocase) if err != nil { t.Errorf("Error finding performers: %s", err.Error()) } @@ -122,10 +124,10 @@ func TestPerformerQueryEthnicityOr(t *testing.T) { }, } - withTxn(func(r models.Repository) error { - sqb := r.Performer() + withTxn(func(ctx context.Context) error { + sqb := sqlite.PerformerReaderWriter - performers := queryPerformers(t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) assert.Len(t, performers, 2) assert.Equal(t, performer1Eth, performers[0].Ethnicity.String) @@ -153,10 +155,10 @@ func TestPerformerQueryEthnicityAndRating(t *testing.T) { }, } - withTxn(func(r models.Repository) error { - sqb := r.Performer() + withTxn(func(ctx context.Context) error { + sqb := sqlite.PerformerReaderWriter - performers := queryPerformers(t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) assert.Len(t, performers, 1) assert.Equal(t, performerEth, performers[0].Ethnicity.String) @@ -188,10 +190,10 @@ func TestPerformerQueryEthnicityNotRating(t *testing.T) { }, } - withTxn(func(r models.Repository) error { - sqb := r.Performer() + withTxn(func(ctx context.Context) error { + sqb := sqlite.PerformerReaderWriter - performers := queryPerformers(t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) for _, performer := range performers { verifyString(t, performer.Ethnicity.String, ethCriterion) @@ -219,20 +221,20 @@ func TestPerformerIllegalQuery(t *testing.T) { Or: &subFilter, } - withTxn(func(r models.Repository) error { - sqb := r.Performer() + withTxn(func(ctx context.Context) error { + sqb := sqlite.PerformerReaderWriter - _, _, err := sqb.Query(performerFilter, nil) + _, _, err := sqb.Query(ctx, performerFilter, nil) assert.NotNil(err) performerFilter.Or = nil performerFilter.Not = &subFilter - _, _, err = sqb.Query(performerFilter, nil) + _, _, err = sqb.Query(ctx, performerFilter, nil) assert.NotNil(err) performerFilter.And = nil performerFilter.Or = &subFilter - _, _, err = sqb.Query(performerFilter, nil) + _, _, err = sqb.Query(ctx, performerFilter, nil) assert.NotNil(err) return nil @@ -240,15 +242,15 @@ func TestPerformerIllegalQuery(t *testing.T) { } func TestPerformerQueryIgnoreAutoTag(t *testing.T) { - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { ignoreAutoTag := true performerFilter := models.PerformerFilterType{ IgnoreAutoTag: &ignoreAutoTag, } - sqb := r.Performer() + sqb := sqlite.PerformerReaderWriter - performers := queryPerformers(t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) assert.Len(t, performers, int(math.Ceil(float64(totalPerformers)/5))) for _, p := range performers { @@ -260,12 +262,12 @@ func TestPerformerQueryIgnoreAutoTag(t *testing.T) { } func TestPerformerQueryForAutoTag(t *testing.T) { - withTxn(func(r models.Repository) error { - tqb := r.Performer() + withTxn(func(ctx context.Context) error { + tqb := sqlite.PerformerReaderWriter name := performerNames[performerIdx1WithScene] // find a performer by name - performers, err := tqb.QueryForAutoTag([]string{name}) + performers, err := tqb.QueryForAutoTag(ctx, []string{name}) if err != nil { t.Errorf("Error finding performers: %s", err.Error()) @@ -280,8 +282,8 @@ func TestPerformerQueryForAutoTag(t *testing.T) { } func TestPerformerUpdatePerformerImage(t *testing.T) { - if err := withTxn(func(r models.Repository) error { - qb := r.Performer() + if err := withTxn(func(ctx context.Context) error { + qb := sqlite.PerformerReaderWriter // create performer to test against const name = "TestPerformerUpdatePerformerImage" @@ -290,26 +292,26 @@ func TestPerformerUpdatePerformerImage(t *testing.T) { Checksum: md5.FromString(name), Favorite: sql.NullBool{Bool: false, Valid: true}, } - created, err := qb.Create(performer) + created, err := qb.Create(ctx, performer) if err != nil { return fmt.Errorf("Error creating performer: %s", err.Error()) } image := []byte("image") - err = qb.UpdateImage(created.ID, image) + err = qb.UpdateImage(ctx, created.ID, image) if err != nil { return fmt.Errorf("Error updating performer image: %s", err.Error()) } // ensure image set - storedImage, err := qb.GetImage(created.ID) + storedImage, err := qb.GetImage(ctx, created.ID) if err != nil { return fmt.Errorf("Error getting image: %s", err.Error()) } assert.Equal(t, storedImage, image) // set nil image - err = qb.UpdateImage(created.ID, nil) + err = qb.UpdateImage(ctx, created.ID, nil) if err == nil { return fmt.Errorf("Expected error setting nil image") } @@ -321,8 +323,8 @@ func TestPerformerUpdatePerformerImage(t *testing.T) { } func TestPerformerDestroyPerformerImage(t *testing.T) { - if err := withTxn(func(r models.Repository) error { - qb := r.Performer() + if err := withTxn(func(ctx context.Context) error { + qb := sqlite.PerformerReaderWriter // create performer to test against const name = "TestPerformerDestroyPerformerImage" @@ -331,24 +333,24 @@ func TestPerformerDestroyPerformerImage(t *testing.T) { Checksum: md5.FromString(name), Favorite: sql.NullBool{Bool: false, Valid: true}, } - created, err := qb.Create(performer) + created, err := qb.Create(ctx, performer) if err != nil { return fmt.Errorf("Error creating performer: %s", err.Error()) } image := []byte("image") - err = qb.UpdateImage(created.ID, image) + err = qb.UpdateImage(ctx, created.ID, image) if err != nil { return fmt.Errorf("Error updating performer image: %s", err.Error()) } - err = qb.DestroyImage(created.ID) + err = qb.DestroyImage(ctx, created.ID) if err != nil { return fmt.Errorf("Error destroying performer image: %s", err.Error()) } // image should be nil - storedImage, err := qb.GetImage(created.ID) + storedImage, err := qb.GetImage(ctx, created.ID) if err != nil { return fmt.Errorf("Error getting image: %s", err.Error()) } @@ -380,13 +382,13 @@ func TestPerformerQueryAge(t *testing.T) { } func verifyPerformerAge(t *testing.T, ageCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - qb := r.Performer() + withTxn(func(ctx context.Context) error { + qb := sqlite.PerformerReaderWriter performerFilter := models.PerformerFilterType{ Age: &ageCriterion, } - performers, _, err := qb.Query(&performerFilter, nil) + performers, _, err := qb.Query(ctx, &performerFilter, nil) if err != nil { t.Errorf("Error querying performer: %s", err.Error()) } @@ -433,13 +435,13 @@ func TestPerformerQueryCareerLength(t *testing.T) { } func verifyPerformerCareerLength(t *testing.T, criterion models.StringCriterionInput) { - withTxn(func(r models.Repository) error { - qb := r.Performer() + withTxn(func(ctx context.Context) error { + qb := sqlite.PerformerReaderWriter performerFilter := models.PerformerFilterType{ CareerLength: &criterion, } - performers, _, err := qb.Query(&performerFilter, nil) + performers, _, err := qb.Query(ctx, &performerFilter, nil) if err != nil { t.Errorf("Error querying performer: %s", err.Error()) } @@ -492,11 +494,11 @@ func TestPerformerQueryURL(t *testing.T) { } func verifyPerformerQuery(t *testing.T, filter models.PerformerFilterType, verifyFn func(s *models.Performer)) { - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { t.Helper() - sqb := r.Performer() + sqb := sqlite.PerformerReaderWriter - performers := queryPerformers(t, sqb, &filter, nil) + performers := queryPerformers(ctx, t, sqb, &filter, nil) // assume it should find at least one assert.Greater(t, len(performers), 0) @@ -509,8 +511,8 @@ func verifyPerformerQuery(t *testing.T, filter models.PerformerFilterType, verif }) } -func queryPerformers(t *testing.T, qb models.PerformerReader, performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) []*models.Performer { - performers, _, err := qb.Query(performerFilter, findFilter) +func queryPerformers(ctx context.Context, t *testing.T, qb models.PerformerReader, performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) []*models.Performer { + performers, _, err := qb.Query(ctx, performerFilter, findFilter) if err != nil { t.Errorf("Error querying performers: %s", err.Error()) } @@ -519,8 +521,8 @@ func queryPerformers(t *testing.T, qb models.PerformerReader, performerFilter *m } func TestPerformerQueryTags(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Performer() + withTxn(func(ctx context.Context) error { + sqb := sqlite.PerformerReaderWriter tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithPerformer]), @@ -534,7 +536,7 @@ func TestPerformerQueryTags(t *testing.T) { } // ensure ids are correct - performers := queryPerformers(t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) assert.Len(t, performers, 2) for _, performer := range performers { assert.True(t, performer.ID == performerIDs[performerIdxWithTag] || performer.ID == performerIDs[performerIdxWithTwoTags]) @@ -548,7 +550,7 @@ func TestPerformerQueryTags(t *testing.T) { Modifier: models.CriterionModifierIncludesAll, } - performers = queryPerformers(t, sqb, &performerFilter, nil) + performers = queryPerformers(ctx, t, sqb, &performerFilter, nil) assert.Len(t, performers, 1) assert.Equal(t, sceneIDs[performerIdxWithTwoTags], performers[0].ID) @@ -565,7 +567,7 @@ func TestPerformerQueryTags(t *testing.T) { Q: &q, } - performers = queryPerformers(t, sqb, &performerFilter, &findFilter) + performers = queryPerformers(ctx, t, sqb, &performerFilter, &findFilter) assert.Len(t, performers, 0) return nil @@ -592,17 +594,17 @@ func TestPerformerQueryTagCount(t *testing.T) { } func verifyPerformersTagCount(t *testing.T, tagCountCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Performer() + withTxn(func(ctx context.Context) error { + sqb := sqlite.PerformerReaderWriter performerFilter := models.PerformerFilterType{ TagCount: &tagCountCriterion, } - performers := queryPerformers(t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) assert.Greater(t, len(performers), 0) for _, performer := range performers { - ids, err := sqb.GetTagIDs(performer.ID) + ids, err := sqb.GetTagIDs(ctx, performer.ID) if err != nil { return err } @@ -633,17 +635,17 @@ func TestPerformerQuerySceneCount(t *testing.T) { } func verifyPerformersSceneCount(t *testing.T, sceneCountCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Performer() + withTxn(func(ctx context.Context) error { + sqb := sqlite.PerformerReaderWriter performerFilter := models.PerformerFilterType{ SceneCount: &sceneCountCriterion, } - performers := queryPerformers(t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) assert.Greater(t, len(performers), 0) for _, performer := range performers { - ids, err := r.Scene().FindByPerformerID(performer.ID) + ids, err := sqlite.SceneReaderWriter.FindByPerformerID(ctx, performer.ID) if err != nil { return err } @@ -674,19 +676,19 @@ func TestPerformerQueryImageCount(t *testing.T) { } func verifyPerformersImageCount(t *testing.T, imageCountCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Performer() + withTxn(func(ctx context.Context) error { + sqb := sqlite.PerformerReaderWriter performerFilter := models.PerformerFilterType{ ImageCount: &imageCountCriterion, } - performers := queryPerformers(t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) assert.Greater(t, len(performers), 0) for _, performer := range performers { pp := 0 - result, err := r.Image().Query(models.ImageQueryOptions{ + result, err := sqlite.ImageReaderWriter.Query(ctx, models.ImageQueryOptions{ QueryOptions: models.QueryOptions{ FindFilter: &models.FindFilterType{ PerPage: &pp, @@ -730,19 +732,19 @@ func TestPerformerQueryGalleryCount(t *testing.T) { } func verifyPerformersGalleryCount(t *testing.T, galleryCountCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Performer() + withTxn(func(ctx context.Context) error { + sqb := sqlite.PerformerReaderWriter performerFilter := models.PerformerFilterType{ GalleryCount: &galleryCountCriterion, } - performers := queryPerformers(t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) assert.Greater(t, len(performers), 0) for _, performer := range performers { pp := 0 - _, count, err := r.Gallery().Query(&models.GalleryFilterType{ + _, count, err := sqlite.GalleryReaderWriter.Query(ctx, &models.GalleryFilterType{ Performers: &models.MultiCriterionInput{ Value: []string{strconv.Itoa(performer.ID)}, Modifier: models.CriterionModifierIncludes, @@ -761,7 +763,7 @@ func verifyPerformersGalleryCount(t *testing.T, galleryCountCriterion models.Int } func TestPerformerQueryStudio(t *testing.T) { - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { testCases := []struct { studioIndex int performerIndex int @@ -771,7 +773,7 @@ func TestPerformerQueryStudio(t *testing.T) { {studioIndex: studioIdxWithGalleryPerformer, performerIndex: performerIdxWithGalleryStudio}, } - sqb := r.Performer() + sqb := sqlite.PerformerReaderWriter for _, tc := range testCases { studioCriterion := models.HierarchicalMultiCriterionInput{ @@ -785,7 +787,7 @@ func TestPerformerQueryStudio(t *testing.T) { Studios: &studioCriterion, } - performers := queryPerformers(t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) assert.Len(t, performers, 1) @@ -804,7 +806,7 @@ func TestPerformerQueryStudio(t *testing.T) { Q: &q, } - performers = queryPerformers(t, sqb, &performerFilter, &findFilter) + performers = queryPerformers(ctx, t, sqb, &performerFilter, &findFilter) assert.Len(t, performers, 0) } @@ -819,21 +821,21 @@ func TestPerformerQueryStudio(t *testing.T) { Q: &q, } - performers := queryPerformers(t, sqb, performerFilter, findFilter) + performers := queryPerformers(ctx, t, sqb, performerFilter, findFilter) assert.Len(t, performers, 1) assert.Equal(t, imageIDs[performerIdx1WithImage], performers[0].ID) q = getPerformerStringValue(performerIdxWithSceneStudio, "Name") - performers = queryPerformers(t, sqb, performerFilter, findFilter) + performers = queryPerformers(ctx, t, sqb, performerFilter, findFilter) assert.Len(t, performers, 0) performerFilter.Studios.Modifier = models.CriterionModifierNotNull - performers = queryPerformers(t, sqb, performerFilter, findFilter) + performers = queryPerformers(ctx, t, sqb, performerFilter, findFilter) assert.Len(t, performers, 1) assert.Equal(t, imageIDs[performerIdxWithSceneStudio], performers[0].ID) q = getPerformerStringValue(performerIdx1WithImage, "Name") - performers = queryPerformers(t, sqb, performerFilter, findFilter) + performers = queryPerformers(ctx, t, sqb, performerFilter, findFilter) assert.Len(t, performers, 0) return nil @@ -841,8 +843,8 @@ func TestPerformerQueryStudio(t *testing.T) { } func TestPerformerStashIDs(t *testing.T) { - if err := withTxn(func(r models.Repository) error { - qb := r.Performer() + if err := withTxn(func(ctx context.Context) error { + qb := sqlite.PerformerReaderWriter // create performer to test against const name = "TestStashIDs" @@ -851,12 +853,12 @@ func TestPerformerStashIDs(t *testing.T) { Checksum: md5.FromString(name), Favorite: sql.NullBool{Bool: false, Valid: true}, } - created, err := qb.Create(performer) + created, err := qb.Create(ctx, performer) if err != nil { return fmt.Errorf("Error creating performer: %s", err.Error()) } - testStashIDReaderWriter(t, qb, created.ID) + testStashIDReaderWriter(ctx, t, qb, created.ID) return nil }); err != nil { t.Error(err.Error()) @@ -888,13 +890,13 @@ func TestPerformerQueryRating(t *testing.T) { } func verifyPerformersRating(t *testing.T, ratingCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Performer() + withTxn(func(ctx context.Context) error { + sqb := sqlite.PerformerReaderWriter performerFilter := models.PerformerFilterType{ Rating: &ratingCriterion, } - performers := queryPerformers(t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) for _, performer := range performers { verifyInt64(t, performer.Rating, ratingCriterion) @@ -905,14 +907,14 @@ func verifyPerformersRating(t *testing.T, ratingCriterion models.IntCriterionInp } func TestPerformerQueryIsMissingRating(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Performer() + withTxn(func(ctx context.Context) error { + sqb := sqlite.PerformerReaderWriter isMissing := "rating" performerFilter := models.PerformerFilterType{ IsMissing: &isMissing, } - performers := queryPerformers(t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) assert.True(t, len(performers) > 0) @@ -925,14 +927,14 @@ func TestPerformerQueryIsMissingRating(t *testing.T) { } func TestPerformerQueryIsMissingImage(t *testing.T) { - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { isMissing := "image" performerFilter := &models.PerformerFilterType{ IsMissing: &isMissing, } // ensure query does not error - performers, _, err := r.Performer().Query(performerFilter, nil) + performers, _, err := sqlite.PerformerReaderWriter.Query(ctx, performerFilter, nil) if err != nil { t.Errorf("Error querying performers: %s", err.Error()) } @@ -940,7 +942,7 @@ func TestPerformerQueryIsMissingImage(t *testing.T) { assert.True(t, len(performers) > 0) for _, performer := range performers { - img, err := r.Performer().GetImage(performer.ID) + img, err := sqlite.PerformerReaderWriter.GetImage(ctx, performer.ID) if err != nil { t.Errorf("error getting performer image: %s", err.Error()) } @@ -959,9 +961,9 @@ func TestPerformerQuerySortScenesCount(t *testing.T) { Direction: &direction, } - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { // just ensure it queries without error - performers, _, err := r.Performer().Query(nil, findFilter) + performers, _, err := sqlite.PerformerReaderWriter.Query(ctx, nil, findFilter) if err != nil { t.Errorf("Error querying performers: %s", err.Error()) } @@ -976,7 +978,7 @@ func TestPerformerQuerySortScenesCount(t *testing.T) { // sort in ascending order direction = models.SortDirectionEnumAsc - performers, _, err = r.Performer().Query(nil, findFilter) + performers, _, err = sqlite.PerformerReaderWriter.Query(ctx, nil, findFilter) if err != nil { t.Errorf("Error querying performers: %s", err.Error()) } diff --git a/pkg/sqlite/query.go b/pkg/sqlite/query.go index 27ce213b5..25cb2b2b1 100644 --- a/pkg/sqlite/query.go +++ b/pkg/sqlite/query.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "fmt" "strings" @@ -54,24 +55,24 @@ func (qb queryBuilder) toSQL(includeSortPagination bool) string { return body } -func (qb queryBuilder) findIDs() ([]int, error) { +func (qb queryBuilder) findIDs(ctx context.Context) ([]int, error) { const includeSortPagination = true sql := qb.toSQL(includeSortPagination) logger.Tracef("SQL: %s, args: %v", sql, qb.args) - return qb.repository.runIdsQuery(sql, qb.args) + return qb.repository.runIdsQuery(ctx, sql, qb.args) } -func (qb queryBuilder) executeFind() ([]int, int, error) { +func (qb queryBuilder) executeFind(ctx context.Context) ([]int, int, error) { if qb.err != nil { return nil, 0, qb.err } body := qb.body() - return qb.repository.executeFindQuery(body, qb.args, qb.sortAndPagination, qb.whereClauses, qb.havingClauses, qb.withClauses, qb.recursiveWith) + return qb.repository.executeFindQuery(ctx, body, qb.args, qb.sortAndPagination, qb.whereClauses, qb.havingClauses, qb.withClauses, qb.recursiveWith) } -func (qb queryBuilder) executeCount() (int, error) { +func (qb queryBuilder) executeCount(ctx context.Context) (int, error) { if qb.err != nil { return 0, qb.err } @@ -89,7 +90,7 @@ func (qb queryBuilder) executeCount() (int, error) { body = qb.repository.buildQueryBody(body, qb.whereClauses, qb.havingClauses) countQuery := withClause + qb.repository.buildCountQuery(body) - return qb.repository.runCountQuery(countQuery, qb.args) + return qb.repository.runCountQuery(ctx, countQuery, qb.args) } func (qb *queryBuilder) addWhere(clauses ...string) { diff --git a/pkg/database/regex.go b/pkg/sqlite/regex.go similarity index 98% rename from pkg/database/regex.go rename to pkg/sqlite/regex.go index dc7b5feb5..bbf713ae3 100644 --- a/pkg/database/regex.go +++ b/pkg/sqlite/regex.go @@ -1,4 +1,4 @@ -package database +package sqlite import ( "regexp" diff --git a/pkg/sqlite/repository.go b/pkg/sqlite/repository.go index f195d9c7e..5bc1fe782 100644 --- a/pkg/sqlite/repository.go +++ b/pkg/sqlite/repository.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "database/sql" "errors" "fmt" @@ -26,23 +27,23 @@ type repository struct { idColumn string } -func (r *repository) get(id int, dest interface{}) error { +func (r *repository) getByID(ctx context.Context, id int, dest interface{}) error { stmt := fmt.Sprintf("SELECT * FROM %s WHERE %s = ? LIMIT 1", r.tableName, r.idColumn) - return r.tx.Get(dest, stmt, id) + return r.tx.Get(ctx, dest, stmt, id) } -func (r *repository) getAll(id int, f func(rows *sqlx.Rows) error) error { +func (r *repository) getAll(ctx context.Context, id int, f func(rows *sqlx.Rows) error) error { stmt := fmt.Sprintf("SELECT * FROM %s WHERE %s = ?", r.tableName, r.idColumn) - return r.queryFunc(stmt, []interface{}{id}, false, f) + return r.queryFunc(ctx, stmt, []interface{}{id}, false, f) } -func (r *repository) insert(obj interface{}) (sql.Result, error) { +func (r *repository) insert(ctx context.Context, obj interface{}) (sql.Result, error) { stmt := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", r.tableName, listKeys(obj, false), listKeys(obj, true)) - return r.tx.NamedExec(stmt, obj) + return r.tx.NamedExec(ctx, stmt, obj) } -func (r *repository) insertObject(obj interface{}, out interface{}) error { - result, err := r.insert(obj) +func (r *repository) insertObject(ctx context.Context, obj interface{}, out interface{}) error { + result, err := r.insert(ctx, obj) if err != nil { return err } @@ -50,11 +51,11 @@ func (r *repository) insertObject(obj interface{}, out interface{}) error { if err != nil { return err } - return r.get(int(id), out) + return r.getByID(ctx, int(id), out) } -func (r *repository) update(id int, obj interface{}, partial bool) error { - exists, err := r.exists(id) +func (r *repository) update(ctx context.Context, id int, obj interface{}, partial bool) error { + exists, err := r.exists(ctx, id) if err != nil { return err } @@ -64,13 +65,13 @@ func (r *repository) update(id int, obj interface{}, partial bool) error { } stmt := fmt.Sprintf("UPDATE %s SET %s WHERE %s.%s = :id", r.tableName, updateSet(obj, partial), r.tableName, r.idColumn) - _, err = r.tx.NamedExec(stmt, obj) + _, err = r.tx.NamedExec(ctx, stmt, obj) return err } -func (r *repository) updateMap(id int, m map[string]interface{}) error { - exists, err := r.exists(id) +func (r *repository) updateMap(ctx context.Context, id int, m map[string]interface{}) error { + exists, err := r.exists(ctx, id) if err != nil { return err } @@ -80,14 +81,14 @@ func (r *repository) updateMap(id int, m map[string]interface{}) error { } stmt := fmt.Sprintf("UPDATE %s SET %s WHERE %s.%s = :id", r.tableName, updateSetMap(m), r.tableName, r.idColumn) - _, err = r.tx.NamedExec(stmt, m) + _, err = r.tx.NamedExec(ctx, stmt, m) return err } -func (r *repository) destroyExisting(ids []int) error { +func (r *repository) destroyExisting(ctx context.Context, ids []int) error { for _, id := range ids { - exists, err := r.exists(id) + exists, err := r.exists(ctx, id) if err != nil { return err } @@ -97,13 +98,13 @@ func (r *repository) destroyExisting(ids []int) error { } } - return r.destroy(ids) + return r.destroy(ctx, ids) } -func (r *repository) destroy(ids []int) error { +func (r *repository) destroy(ctx context.Context, ids []int) error { for _, id := range ids { stmt := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", r.tableName, r.idColumn) - if _, err := r.tx.Exec(stmt, id); err != nil { + if _, err := r.tx.Exec(ctx, stmt, id); err != nil { return err } } @@ -111,11 +112,11 @@ func (r *repository) destroy(ids []int) error { return nil } -func (r *repository) exists(id int) (bool, error) { +func (r *repository) exists(ctx context.Context, id int) (bool, error) { stmt := fmt.Sprintf("SELECT %s FROM %s WHERE %s = ? LIMIT 1", r.idColumn, r.tableName, r.idColumn) stmt = r.buildCountQuery(stmt) - c, err := r.runCountQuery(stmt, []interface{}{id}) + c, err := r.runCountQuery(ctx, stmt, []interface{}{id}) if err != nil { return false, err } @@ -127,25 +128,25 @@ func (r *repository) buildCountQuery(query string) string { return "SELECT COUNT(*) as count FROM (" + query + ") as temp" } -func (r *repository) runCountQuery(query string, args []interface{}) (int, error) { +func (r *repository) runCountQuery(ctx context.Context, query string, args []interface{}) (int, error) { result := struct { Int int `db:"count"` }{0} // Perform query and fetch result - if err := r.tx.Get(&result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) { + if err := r.tx.Get(ctx, &result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) { return 0, err } return result.Int, nil } -func (r *repository) runIdsQuery(query string, args []interface{}) ([]int, error) { +func (r *repository) runIdsQuery(ctx context.Context, query string, args []interface{}) ([]int, error) { var result []struct { Int int `db:"id"` } - if err := r.tx.Select(&result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) { + if err := r.tx.Select(ctx, &result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) { return []int{}, err } @@ -156,24 +157,24 @@ func (r *repository) runIdsQuery(query string, args []interface{}) ([]int, error return vsm, nil } -func (r *repository) runSumQuery(query string, args []interface{}) (float64, error) { +func (r *repository) runSumQuery(ctx context.Context, query string, args []interface{}) (float64, error) { // Perform query and fetch result result := struct { Float64 float64 `db:"sum"` }{0} // Perform query and fetch result - if err := r.tx.Get(&result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) { + if err := r.tx.Get(ctx, &result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) { return 0, err } return result.Float64, nil } -func (r *repository) queryFunc(query string, args []interface{}, single bool, f func(rows *sqlx.Rows) error) error { +func (r *repository) queryFunc(ctx context.Context, query string, args []interface{}, single bool, f func(rows *sqlx.Rows) error) error { logger.Tracef("SQL: %s, args: %v", query, args) - rows, err := r.tx.Queryx(query, args...) + rows, err := r.tx.Queryx(ctx, query, args...) if err != nil && !errors.Is(err, sql.ErrNoRows) { return err @@ -196,8 +197,8 @@ func (r *repository) queryFunc(query string, args []interface{}, single bool, f return nil } -func (r *repository) query(query string, args []interface{}, out objectList) error { - return r.queryFunc(query, args, false, func(rows *sqlx.Rows) error { +func (r *repository) query(ctx context.Context, query string, args []interface{}, out objectList) error { + return r.queryFunc(ctx, query, args, false, func(rows *sqlx.Rows) error { object := out.New() if err := rows.StructScan(object); err != nil { return err @@ -207,8 +208,8 @@ func (r *repository) query(query string, args []interface{}, out objectList) err }) } -func (r *repository) queryStruct(query string, args []interface{}, out interface{}) error { - return r.queryFunc(query, args, true, func(rows *sqlx.Rows) error { +func (r *repository) queryStruct(ctx context.Context, query string, args []interface{}, out interface{}) error { + return r.queryFunc(ctx, query, args, true, func(rows *sqlx.Rows) error { if err := rows.StructScan(out); err != nil { return err } @@ -216,8 +217,8 @@ func (r *repository) queryStruct(query string, args []interface{}, out interface }) } -func (r *repository) querySimple(query string, args []interface{}, out interface{}) error { - rows, err := r.tx.Queryx(query, args...) +func (r *repository) querySimple(ctx context.Context, query string, args []interface{}, out interface{}) error { + rows, err := r.tx.Queryx(ctx, query, args...) if err != nil && !errors.Is(err, sql.ErrNoRows) { return err @@ -249,7 +250,7 @@ func (r *repository) buildQueryBody(body string, whereClauses []string, havingCl return body } -func (r *repository) executeFindQuery(body string, args []interface{}, sortAndPagination string, whereClauses []string, havingClauses []string, withClauses []string, recursiveWith bool) ([]int, int, error) { +func (r *repository) executeFindQuery(ctx context.Context, body string, args []interface{}, sortAndPagination string, whereClauses []string, havingClauses []string, withClauses []string, recursiveWith bool) ([]int, int, error) { body = r.buildQueryBody(body, whereClauses, havingClauses) withClause := "" @@ -272,8 +273,8 @@ func (r *repository) executeFindQuery(body string, args []interface{}, sortAndPa var idsResult []int var idsErr error - countResult, countErr = r.runCountQuery(countQuery, args) - idsResult, idsErr = r.runIdsQuery(idsQuery, args) + countResult, countErr = r.runCountQuery(ctx, countQuery, args) + idsResult, idsErr = r.runIdsQuery(ctx, idsQuery, args) if countErr != nil { return nil, 0, fmt.Errorf("error executing count query with SQL: %s, args: %v, error: %s", countQuery, args, countErr.Error()) @@ -318,23 +319,23 @@ type joinRepository struct { fkColumn string } -func (r *joinRepository) getIDs(id int) ([]int, error) { +func (r *joinRepository) getIDs(ctx context.Context, id int) ([]int, error) { query := fmt.Sprintf(`SELECT %s as id from %s WHERE %s = ?`, r.fkColumn, r.tableName, r.idColumn) - return r.runIdsQuery(query, []interface{}{id}) + return r.runIdsQuery(ctx, query, []interface{}{id}) } -func (r *joinRepository) insert(id, foreignID int) (sql.Result, error) { +func (r *joinRepository) insert(ctx context.Context, id, foreignID int) (sql.Result, error) { stmt := fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (?, ?)", r.tableName, r.idColumn, r.fkColumn) - return r.tx.Exec(stmt, id, foreignID) + return r.tx.Exec(ctx, stmt, id, foreignID) } -func (r *joinRepository) replace(id int, foreignIDs []int) error { - if err := r.destroy([]int{id}); err != nil { +func (r *joinRepository) replace(ctx context.Context, id int, foreignIDs []int) error { + if err := r.destroy(ctx, []int{id}); err != nil { return err } for _, fk := range foreignIDs { - if _, err := r.insert(id, fk); err != nil { + if _, err := r.insert(ctx, id, fk); err != nil { return err } } @@ -347,20 +348,20 @@ type imageRepository struct { imageColumn string } -func (r *imageRepository) get(id int) ([]byte, error) { +func (r *imageRepository) get(ctx context.Context, id int) ([]byte, error) { query := fmt.Sprintf("SELECT %s from %s WHERE %s = ?", r.imageColumn, r.tableName, r.idColumn) var ret []byte - err := r.querySimple(query, []interface{}{id}, &ret) + err := r.querySimple(ctx, query, []interface{}{id}, &ret) return ret, err } -func (r *imageRepository) replace(id int, image []byte) error { - if err := r.destroy([]int{id}); err != nil { +func (r *imageRepository) replace(ctx context.Context, id int, image []byte) error { + if err := r.destroy(ctx, []int{id}); err != nil { return err } stmt := fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (?, ?)", r.tableName, r.idColumn, r.imageColumn) - _, err := r.tx.Exec(stmt, id, image) + _, err := r.tx.Exec(ctx, stmt, id, image) return err } @@ -369,10 +370,10 @@ type captionRepository struct { repository } -func (r *captionRepository) get(id int) ([]*models.SceneCaption, error) { +func (r *captionRepository) get(ctx context.Context, id int) ([]*models.SceneCaption, error) { query := fmt.Sprintf("SELECT %s, %s, %s from %s WHERE %s = ?", sceneCaptionCodeColumn, sceneCaptionFilenameColumn, sceneCaptionTypeColumn, r.tableName, r.idColumn) var ret []*models.SceneCaption - err := r.queryFunc(query, []interface{}{id}, false, func(rows *sqlx.Rows) error { + err := r.queryFunc(ctx, query, []interface{}{id}, false, func(rows *sqlx.Rows) error { var captionCode string var captionFilename string var captionType string @@ -392,18 +393,18 @@ func (r *captionRepository) get(id int) ([]*models.SceneCaption, error) { return ret, err } -func (r *captionRepository) insert(id int, caption *models.SceneCaption) (sql.Result, error) { +func (r *captionRepository) insert(ctx context.Context, id int, caption *models.SceneCaption) (sql.Result, error) { stmt := fmt.Sprintf("INSERT INTO %s (%s, %s, %s, %s) VALUES (?, ?, ?, ?)", r.tableName, r.idColumn, sceneCaptionCodeColumn, sceneCaptionFilenameColumn, sceneCaptionTypeColumn) - return r.tx.Exec(stmt, id, caption.LanguageCode, caption.Filename, caption.CaptionType) + return r.tx.Exec(ctx, stmt, id, caption.LanguageCode, caption.Filename, caption.CaptionType) } -func (r *captionRepository) replace(id int, captions []*models.SceneCaption) error { - if err := r.destroy([]int{id}); err != nil { +func (r *captionRepository) replace(ctx context.Context, id int, captions []*models.SceneCaption) error { + if err := r.destroy(ctx, []int{id}); err != nil { return err } for _, caption := range captions { - if _, err := r.insert(id, caption); err != nil { + if _, err := r.insert(ctx, id, caption); err != nil { return err } } @@ -416,10 +417,10 @@ type stringRepository struct { stringColumn string } -func (r *stringRepository) get(id int) ([]string, error) { +func (r *stringRepository) get(ctx context.Context, id int) ([]string, error) { query := fmt.Sprintf("SELECT %s from %s WHERE %s = ?", r.stringColumn, r.tableName, r.idColumn) var ret []string - err := r.queryFunc(query, []interface{}{id}, false, func(rows *sqlx.Rows) error { + err := r.queryFunc(ctx, query, []interface{}{id}, false, func(rows *sqlx.Rows) error { var out string if err := rows.Scan(&out); err != nil { return err @@ -431,18 +432,18 @@ func (r *stringRepository) get(id int) ([]string, error) { return ret, err } -func (r *stringRepository) insert(id int, s string) (sql.Result, error) { +func (r *stringRepository) insert(ctx context.Context, id int, s string) (sql.Result, error) { stmt := fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (?, ?)", r.tableName, r.idColumn, r.stringColumn) - return r.tx.Exec(stmt, id, s) + return r.tx.Exec(ctx, stmt, id, s) } -func (r *stringRepository) replace(id int, newStrings []string) error { - if err := r.destroy([]int{id}); err != nil { +func (r *stringRepository) replace(ctx context.Context, id int, newStrings []string) error { + if err := r.destroy(ctx, []int{id}); err != nil { return err } for _, s := range newStrings { - if _, err := r.insert(id, s); err != nil { + if _, err := r.insert(ctx, id, s); err != nil { return err } } @@ -464,21 +465,21 @@ func (s *stashIDs) New() interface{} { return &models.StashID{} } -func (r *stashIDRepository) get(id int) ([]*models.StashID, error) { +func (r *stashIDRepository) get(ctx context.Context, id int) ([]*models.StashID, error) { query := fmt.Sprintf("SELECT stash_id, endpoint from %s WHERE %s = ?", r.tableName, r.idColumn) var ret stashIDs - err := r.query(query, []interface{}{id}, &ret) + err := r.query(ctx, query, []interface{}{id}, &ret) return []*models.StashID(ret), err } -func (r *stashIDRepository) replace(id int, newIDs []models.StashID) error { - if err := r.destroy([]int{id}); err != nil { +func (r *stashIDRepository) replace(ctx context.Context, id int, newIDs []models.StashID) error { + if err := r.destroy(ctx, []int{id}); err != nil { return err } query := fmt.Sprintf("INSERT INTO %s (%s, endpoint, stash_id) VALUES (?, ?, ?)", r.tableName, r.idColumn) for _, stashID := range newIDs { - _, err := r.tx.Exec(query, id, stashID.Endpoint, stashID.StashID) + _, err := r.tx.Exec(ctx, query, id, stashID.Endpoint, stashID.StashID) if err != nil { return err } diff --git a/pkg/sqlite/saved_filter.go b/pkg/sqlite/saved_filter.go index 6c507bee3..54fddcbc8 100644 --- a/pkg/sqlite/saved_filter.go +++ b/pkg/sqlite/saved_filter.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "database/sql" "errors" "fmt" @@ -15,42 +16,39 @@ type savedFilterQueryBuilder struct { repository } -func NewSavedFilterReaderWriter(tx dbi) *savedFilterQueryBuilder { - return &savedFilterQueryBuilder{ - repository{ - tx: tx, - tableName: savedFilterTable, - idColumn: idColumn, - }, - } +var SavedFilterReaderWriter = &savedFilterQueryBuilder{ + repository{ + tableName: savedFilterTable, + idColumn: idColumn, + }, } -func (qb *savedFilterQueryBuilder) Create(newObject models.SavedFilter) (*models.SavedFilter, error) { +func (qb *savedFilterQueryBuilder) Create(ctx context.Context, newObject models.SavedFilter) (*models.SavedFilter, error) { var ret models.SavedFilter - if err := qb.insertObject(newObject, &ret); err != nil { + if err := qb.insertObject(ctx, newObject, &ret); err != nil { return nil, err } return &ret, nil } -func (qb *savedFilterQueryBuilder) Update(updatedObject models.SavedFilter) (*models.SavedFilter, error) { +func (qb *savedFilterQueryBuilder) Update(ctx context.Context, updatedObject models.SavedFilter) (*models.SavedFilter, error) { const partial = false - if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil { return nil, err } var ret models.SavedFilter - if err := qb.get(updatedObject.ID, &ret); err != nil { + if err := qb.getByID(ctx, updatedObject.ID, &ret); err != nil { return nil, err } return &ret, nil } -func (qb *savedFilterQueryBuilder) SetDefault(obj models.SavedFilter) (*models.SavedFilter, error) { +func (qb *savedFilterQueryBuilder) SetDefault(ctx context.Context, obj models.SavedFilter) (*models.SavedFilter, error) { // find the existing default - existing, err := qb.FindDefault(obj.Mode) + existing, err := qb.FindDefault(ctx, obj.Mode) if err != nil { return nil, err @@ -60,19 +58,19 @@ func (qb *savedFilterQueryBuilder) SetDefault(obj models.SavedFilter) (*models.S if existing != nil { obj.ID = existing.ID - return qb.Update(obj) + return qb.Update(ctx, obj) } - return qb.Create(obj) + return qb.Create(ctx, obj) } -func (qb *savedFilterQueryBuilder) Destroy(id int) error { - return qb.destroyExisting([]int{id}) +func (qb *savedFilterQueryBuilder) Destroy(ctx context.Context, id int) error { + return qb.destroyExisting(ctx, []int{id}) } -func (qb *savedFilterQueryBuilder) Find(id int) (*models.SavedFilter, error) { +func (qb *savedFilterQueryBuilder) Find(ctx context.Context, id int) (*models.SavedFilter, error) { var ret models.SavedFilter - if err := qb.get(id, &ret); err != nil { + if err := qb.getByID(ctx, id, &ret); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -81,10 +79,10 @@ func (qb *savedFilterQueryBuilder) Find(id int) (*models.SavedFilter, error) { return &ret, nil } -func (qb *savedFilterQueryBuilder) FindMany(ids []int, ignoreNotFound bool) ([]*models.SavedFilter, error) { +func (qb *savedFilterQueryBuilder) FindMany(ctx context.Context, ids []int, ignoreNotFound bool) ([]*models.SavedFilter, error) { var filters []*models.SavedFilter for _, id := range ids { - filter, err := qb.Find(id) + filter, err := qb.Find(ctx, id) if err != nil { return nil, err } @@ -99,24 +97,24 @@ func (qb *savedFilterQueryBuilder) FindMany(ids []int, ignoreNotFound bool) ([]* return filters, nil } -func (qb *savedFilterQueryBuilder) FindByMode(mode models.FilterMode) ([]*models.SavedFilter, error) { +func (qb *savedFilterQueryBuilder) FindByMode(ctx context.Context, mode models.FilterMode) ([]*models.SavedFilter, error) { // exclude empty-named filters - these are the internal default filters query := fmt.Sprintf(`SELECT * FROM %s WHERE mode = ? AND name != ?`, savedFilterTable) var ret models.SavedFilters - if err := qb.query(query, []interface{}{mode, savedFilterDefaultName}, &ret); err != nil { + if err := qb.query(ctx, query, []interface{}{mode, savedFilterDefaultName}, &ret); err != nil { return nil, err } return []*models.SavedFilter(ret), nil } -func (qb *savedFilterQueryBuilder) FindDefault(mode models.FilterMode) (*models.SavedFilter, error) { +func (qb *savedFilterQueryBuilder) FindDefault(ctx context.Context, mode models.FilterMode) (*models.SavedFilter, error) { query := fmt.Sprintf(`SELECT * FROM %s WHERE mode = ? AND name = ?`, savedFilterTable) var ret models.SavedFilters - if err := qb.query(query, []interface{}{mode, savedFilterDefaultName}, &ret); err != nil { + if err := qb.query(ctx, query, []interface{}{mode, savedFilterDefaultName}, &ret); err != nil { return nil, err } @@ -127,9 +125,9 @@ func (qb *savedFilterQueryBuilder) FindDefault(mode models.FilterMode) (*models. return nil, nil } -func (qb *savedFilterQueryBuilder) All() ([]*models.SavedFilter, error) { +func (qb *savedFilterQueryBuilder) All(ctx context.Context) ([]*models.SavedFilter, error) { var ret models.SavedFilters - if err := qb.query(selectAll(savedFilterTable), nil, &ret); err != nil { + if err := qb.query(ctx, selectAll(savedFilterTable), nil, &ret); err != nil { return nil, err } diff --git a/pkg/sqlite/saved_filter_test.go b/pkg/sqlite/saved_filter_test.go index 5ec049290..c22b374fb 100644 --- a/pkg/sqlite/saved_filter_test.go +++ b/pkg/sqlite/saved_filter_test.go @@ -4,15 +4,17 @@ package sqlite_test import ( + "context" "testing" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/sqlite" "github.com/stretchr/testify/assert" ) func TestSavedFilterFind(t *testing.T) { - withTxn(func(r models.Repository) error { - savedFilter, err := r.SavedFilter().Find(savedFilterIDs[savedFilterIdxImage]) + withTxn(func(ctx context.Context) error { + savedFilter, err := sqlite.SavedFilterReaderWriter.Find(ctx, savedFilterIDs[savedFilterIdxImage]) if err != nil { t.Errorf("Error finding saved filter: %s", err.Error()) @@ -25,8 +27,8 @@ func TestSavedFilterFind(t *testing.T) { } func TestSavedFilterFindByMode(t *testing.T) { - withTxn(func(r models.Repository) error { - savedFilters, err := r.SavedFilter().FindByMode(models.FilterModeScenes) + withTxn(func(ctx context.Context) error { + savedFilters, err := sqlite.SavedFilterReaderWriter.FindByMode(ctx, models.FilterModeScenes) if err != nil { t.Errorf("Error finding saved filters: %s", err.Error()) @@ -45,8 +47,8 @@ func TestSavedFilterDestroy(t *testing.T) { var id int // create the saved filter to destroy - withTxn(func(r models.Repository) error { - created, err := r.SavedFilter().Create(models.SavedFilter{ + withTxn(func(ctx context.Context) error { + created, err := sqlite.SavedFilterReaderWriter.Create(ctx, models.SavedFilter{ Name: filterName, Mode: models.FilterModeScenes, Filter: testFilter, @@ -59,15 +61,15 @@ func TestSavedFilterDestroy(t *testing.T) { return err }) - withTxn(func(r models.Repository) error { - qb := r.SavedFilter() + withTxn(func(ctx context.Context) error { + qb := sqlite.SavedFilterReaderWriter - return qb.Destroy(id) + return qb.Destroy(ctx, id) }) // now try to find it - withTxn(func(r models.Repository) error { - found, err := r.SavedFilter().Find(id) + withTxn(func(ctx context.Context) error { + found, err := sqlite.SavedFilterReaderWriter.Find(ctx, id) if err == nil { assert.Nil(t, found) } @@ -77,8 +79,8 @@ func TestSavedFilterDestroy(t *testing.T) { } func TestSavedFilterFindDefault(t *testing.T) { - withTxn(func(r models.Repository) error { - def, err := r.SavedFilter().FindDefault(models.FilterModeScenes) + withTxn(func(ctx context.Context) error { + def, err := sqlite.SavedFilterReaderWriter.FindDefault(ctx, models.FilterModeScenes) if err == nil { assert.Equal(t, savedFilterIDs[savedFilterIdxDefaultScene], def.ID) } @@ -90,8 +92,8 @@ func TestSavedFilterFindDefault(t *testing.T) { func TestSavedFilterSetDefault(t *testing.T) { const newFilter = "foo" - withTxn(func(r models.Repository) error { - _, err := r.SavedFilter().SetDefault(models.SavedFilter{ + withTxn(func(ctx context.Context) error { + _, err := sqlite.SavedFilterReaderWriter.SetDefault(ctx, models.SavedFilter{ Mode: models.FilterModeMovies, Filter: newFilter, }) @@ -100,8 +102,8 @@ func TestSavedFilterSetDefault(t *testing.T) { }) var defID int - withTxn(func(r models.Repository) error { - def, err := r.SavedFilter().FindDefault(models.FilterModeMovies) + withTxn(func(ctx context.Context) error { + def, err := sqlite.SavedFilterReaderWriter.FindDefault(ctx, models.FilterModeMovies) if err == nil { defID = def.ID assert.Equal(t, newFilter, def.Filter) @@ -111,8 +113,8 @@ func TestSavedFilterSetDefault(t *testing.T) { }) // destroy it again - withTxn(func(r models.Repository) error { - return r.SavedFilter().Destroy(defID) + withTxn(func(ctx context.Context) error { + return sqlite.SavedFilterReaderWriter.Destroy(ctx, defID) }) } diff --git a/pkg/sqlite/scene.go b/pkg/sqlite/scene.go index cb5085dfd..921e2f4c3 100644 --- a/pkg/sqlite/scene.go +++ b/pkg/sqlite/scene.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "database/sql" "errors" "fmt" @@ -89,45 +90,42 @@ type sceneQueryBuilder struct { repository } -func NewSceneReaderWriter(tx dbi) *sceneQueryBuilder { - return &sceneQueryBuilder{ - repository{ - tx: tx, - tableName: sceneTable, - idColumn: idColumn, - }, - } +var SceneReaderWriter = &sceneQueryBuilder{ + repository{ + tableName: sceneTable, + idColumn: idColumn, + }, } -func (qb *sceneQueryBuilder) Create(newObject models.Scene) (*models.Scene, error) { +func (qb *sceneQueryBuilder) Create(ctx context.Context, newObject models.Scene) (*models.Scene, error) { var ret models.Scene - if err := qb.insertObject(newObject, &ret); err != nil { + if err := qb.insertObject(ctx, newObject, &ret); err != nil { return nil, err } return &ret, nil } -func (qb *sceneQueryBuilder) Update(updatedObject models.ScenePartial) (*models.Scene, error) { +func (qb *sceneQueryBuilder) Update(ctx context.Context, updatedObject models.ScenePartial) (*models.Scene, error) { const partial = true - if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil { return nil, err } - return qb.find(updatedObject.ID) + return qb.find(ctx, updatedObject.ID) } -func (qb *sceneQueryBuilder) UpdateFull(updatedObject models.Scene) (*models.Scene, error) { +func (qb *sceneQueryBuilder) UpdateFull(ctx context.Context, updatedObject models.Scene) (*models.Scene, error) { const partial = false - if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil { return nil, err } - return qb.find(updatedObject.ID) + return qb.find(ctx, updatedObject.ID) } -func (qb *sceneQueryBuilder) UpdateFileModTime(id int, modTime models.NullSQLiteTimestamp) error { - return qb.updateMap(id, map[string]interface{}{ +func (qb *sceneQueryBuilder) UpdateFileModTime(ctx context.Context, id int, modTime models.NullSQLiteTimestamp) error { + return qb.updateMap(ctx, id, map[string]interface{}{ "file_mod_time": modTime, }) } @@ -142,17 +140,17 @@ func (qb *sceneQueryBuilder) captionRepository() *captionRepository { } } -func (qb *sceneQueryBuilder) GetCaptions(sceneID int) ([]*models.SceneCaption, error) { - return qb.captionRepository().get(sceneID) +func (qb *sceneQueryBuilder) GetCaptions(ctx context.Context, sceneID int) ([]*models.SceneCaption, error) { + return qb.captionRepository().get(ctx, sceneID) } -func (qb *sceneQueryBuilder) UpdateCaptions(sceneID int, captions []*models.SceneCaption) error { - return qb.captionRepository().replace(sceneID, captions) +func (qb *sceneQueryBuilder) UpdateCaptions(ctx context.Context, sceneID int, captions []*models.SceneCaption) error { + return qb.captionRepository().replace(ctx, sceneID, captions) } -func (qb *sceneQueryBuilder) IncrementOCounter(id int) (int, error) { - _, err := qb.tx.Exec( +func (qb *sceneQueryBuilder) IncrementOCounter(ctx context.Context, id int) (int, error) { + _, err := qb.tx.Exec(ctx, `UPDATE scenes SET o_counter = o_counter + 1 WHERE scenes.id = ?`, id, ) @@ -160,7 +158,7 @@ func (qb *sceneQueryBuilder) IncrementOCounter(id int) (int, error) { return 0, err } - scene, err := qb.find(id) + scene, err := qb.find(ctx, id) if err != nil { return 0, err } @@ -168,8 +166,8 @@ func (qb *sceneQueryBuilder) IncrementOCounter(id int) (int, error) { return scene.OCounter, nil } -func (qb *sceneQueryBuilder) DecrementOCounter(id int) (int, error) { - _, err := qb.tx.Exec( +func (qb *sceneQueryBuilder) DecrementOCounter(ctx context.Context, id int) (int, error) { + _, err := qb.tx.Exec(ctx, `UPDATE scenes SET o_counter = o_counter - 1 WHERE scenes.id = ? and scenes.o_counter > 0`, id, ) @@ -177,7 +175,7 @@ func (qb *sceneQueryBuilder) DecrementOCounter(id int) (int, error) { return 0, err } - scene, err := qb.find(id) + scene, err := qb.find(ctx, id) if err != nil { return 0, err } @@ -185,8 +183,8 @@ func (qb *sceneQueryBuilder) DecrementOCounter(id int) (int, error) { return scene.OCounter, nil } -func (qb *sceneQueryBuilder) ResetOCounter(id int) (int, error) { - _, err := qb.tx.Exec( +func (qb *sceneQueryBuilder) ResetOCounter(ctx context.Context, id int) (int, error) { + _, err := qb.tx.Exec(ctx, `UPDATE scenes SET o_counter = 0 WHERE scenes.id = ?`, id, ) @@ -194,7 +192,7 @@ func (qb *sceneQueryBuilder) ResetOCounter(id int) (int, error) { return 0, err } - scene, err := qb.find(id) + scene, err := qb.find(ctx, id) if err != nil { return 0, err } @@ -202,27 +200,27 @@ func (qb *sceneQueryBuilder) ResetOCounter(id int) (int, error) { return scene.OCounter, nil } -func (qb *sceneQueryBuilder) Destroy(id int) error { +func (qb *sceneQueryBuilder) Destroy(ctx context.Context, id int) error { // delete all related table rows // TODO - this should be handled by a delete cascade - if err := qb.performersRepository().destroy([]int{id}); err != nil { + if err := qb.performersRepository().destroy(ctx, []int{id}); err != nil { return err } // scene markers should be handled prior to calling destroy // galleries should be handled prior to calling destroy - return qb.destroyExisting([]int{id}) + return qb.destroyExisting(ctx, []int{id}) } -func (qb *sceneQueryBuilder) Find(id int) (*models.Scene, error) { - return qb.find(id) +func (qb *sceneQueryBuilder) Find(ctx context.Context, id int) (*models.Scene, error) { + return qb.find(ctx, id) } -func (qb *sceneQueryBuilder) FindMany(ids []int) ([]*models.Scene, error) { +func (qb *sceneQueryBuilder) FindMany(ctx context.Context, ids []int) ([]*models.Scene, error) { var scenes []*models.Scene for _, id := range ids { - scene, err := qb.Find(id) + scene, err := qb.Find(ctx, id) if err != nil { return nil, err } @@ -237,9 +235,9 @@ func (qb *sceneQueryBuilder) FindMany(ids []int) ([]*models.Scene, error) { return scenes, nil } -func (qb *sceneQueryBuilder) find(id int) (*models.Scene, error) { +func (qb *sceneQueryBuilder) find(ctx context.Context, id int) (*models.Scene, error) { var ret models.Scene - if err := qb.get(id, &ret); err != nil { + if err := qb.getByID(ctx, id, &ret); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -248,92 +246,92 @@ func (qb *sceneQueryBuilder) find(id int) (*models.Scene, error) { return &ret, nil } -func (qb *sceneQueryBuilder) FindByChecksum(checksum string) (*models.Scene, error) { +func (qb *sceneQueryBuilder) FindByChecksum(ctx context.Context, checksum string) (*models.Scene, error) { query := "SELECT * FROM scenes WHERE checksum = ? LIMIT 1" args := []interface{}{checksum} - return qb.queryScene(query, args) + return qb.queryScene(ctx, query, args) } -func (qb *sceneQueryBuilder) FindByOSHash(oshash string) (*models.Scene, error) { +func (qb *sceneQueryBuilder) FindByOSHash(ctx context.Context, oshash string) (*models.Scene, error) { query := "SELECT * FROM scenes WHERE oshash = ? LIMIT 1" args := []interface{}{oshash} - return qb.queryScene(query, args) + return qb.queryScene(ctx, query, args) } -func (qb *sceneQueryBuilder) FindByPath(path string) (*models.Scene, error) { +func (qb *sceneQueryBuilder) FindByPath(ctx context.Context, path string) (*models.Scene, error) { query := selectAll(sceneTable) + "WHERE path = ? LIMIT 1" args := []interface{}{path} - return qb.queryScene(query, args) + return qb.queryScene(ctx, query, args) } -func (qb *sceneQueryBuilder) FindByPerformerID(performerID int) ([]*models.Scene, error) { +func (qb *sceneQueryBuilder) FindByPerformerID(ctx context.Context, performerID int) ([]*models.Scene, error) { args := []interface{}{performerID} - return qb.queryScenes(scenesForPerformerQuery, args) + return qb.queryScenes(ctx, scenesForPerformerQuery, args) } -func (qb *sceneQueryBuilder) FindByGalleryID(galleryID int) ([]*models.Scene, error) { +func (qb *sceneQueryBuilder) FindByGalleryID(ctx context.Context, galleryID int) ([]*models.Scene, error) { args := []interface{}{galleryID} - return qb.queryScenes(scenesForGalleryQuery, args) + return qb.queryScenes(ctx, scenesForGalleryQuery, args) } -func (qb *sceneQueryBuilder) CountByPerformerID(performerID int) (int, error) { +func (qb *sceneQueryBuilder) CountByPerformerID(ctx context.Context, performerID int) (int, error) { args := []interface{}{performerID} - return qb.runCountQuery(qb.buildCountQuery(countScenesForPerformerQuery), args) + return qb.runCountQuery(ctx, qb.buildCountQuery(countScenesForPerformerQuery), args) } -func (qb *sceneQueryBuilder) FindByMovieID(movieID int) ([]*models.Scene, error) { +func (qb *sceneQueryBuilder) FindByMovieID(ctx context.Context, movieID int) ([]*models.Scene, error) { args := []interface{}{movieID} - return qb.queryScenes(scenesForMovieQuery, args) + return qb.queryScenes(ctx, scenesForMovieQuery, args) } -func (qb *sceneQueryBuilder) CountByMovieID(movieID int) (int, error) { +func (qb *sceneQueryBuilder) CountByMovieID(ctx context.Context, movieID int) (int, error) { args := []interface{}{movieID} - return qb.runCountQuery(qb.buildCountQuery(scenesForMovieQuery), args) + return qb.runCountQuery(ctx, qb.buildCountQuery(scenesForMovieQuery), args) } -func (qb *sceneQueryBuilder) Count() (int, error) { - return qb.runCountQuery(qb.buildCountQuery("SELECT scenes.id FROM scenes"), nil) +func (qb *sceneQueryBuilder) Count(ctx context.Context) (int, error) { + return qb.runCountQuery(ctx, qb.buildCountQuery("SELECT scenes.id FROM scenes"), nil) } -func (qb *sceneQueryBuilder) Size() (float64, error) { - return qb.runSumQuery("SELECT SUM(cast(size as double)) as sum FROM scenes", nil) +func (qb *sceneQueryBuilder) Size(ctx context.Context) (float64, error) { + return qb.runSumQuery(ctx, "SELECT SUM(cast(size as double)) as sum FROM scenes", nil) } -func (qb *sceneQueryBuilder) Duration() (float64, error) { - return qb.runSumQuery("SELECT SUM(cast(duration as double)) as sum FROM scenes", nil) +func (qb *sceneQueryBuilder) Duration(ctx context.Context) (float64, error) { + return qb.runSumQuery(ctx, "SELECT SUM(cast(duration as double)) as sum FROM scenes", nil) } -func (qb *sceneQueryBuilder) CountByStudioID(studioID int) (int, error) { +func (qb *sceneQueryBuilder) CountByStudioID(ctx context.Context, studioID int) (int, error) { args := []interface{}{studioID} - return qb.runCountQuery(qb.buildCountQuery(scenesForStudioQuery), args) + return qb.runCountQuery(ctx, qb.buildCountQuery(scenesForStudioQuery), args) } -func (qb *sceneQueryBuilder) CountByTagID(tagID int) (int, error) { +func (qb *sceneQueryBuilder) CountByTagID(ctx context.Context, tagID int) (int, error) { args := []interface{}{tagID} - return qb.runCountQuery(qb.buildCountQuery(countScenesForTagQuery), args) + return qb.runCountQuery(ctx, qb.buildCountQuery(countScenesForTagQuery), args) } // CountMissingChecksum returns the number of scenes missing a checksum value. -func (qb *sceneQueryBuilder) CountMissingChecksum() (int, error) { - return qb.runCountQuery(qb.buildCountQuery(countScenesForMissingChecksumQuery), []interface{}{}) +func (qb *sceneQueryBuilder) CountMissingChecksum(ctx context.Context) (int, error) { + return qb.runCountQuery(ctx, qb.buildCountQuery(countScenesForMissingChecksumQuery), []interface{}{}) } // CountMissingOSHash returns the number of scenes missing an oshash value. -func (qb *sceneQueryBuilder) CountMissingOSHash() (int, error) { - return qb.runCountQuery(qb.buildCountQuery(countScenesForMissingOSHashQuery), []interface{}{}) +func (qb *sceneQueryBuilder) CountMissingOSHash(ctx context.Context) (int, error) { + return qb.runCountQuery(ctx, qb.buildCountQuery(countScenesForMissingOSHashQuery), []interface{}{}) } -func (qb *sceneQueryBuilder) Wall(q *string) ([]*models.Scene, error) { +func (qb *sceneQueryBuilder) Wall(ctx context.Context, q *string) ([]*models.Scene, error) { s := "" if q != nil { s = *q } query := selectAll(sceneTable) + "WHERE scenes.details LIKE '%" + s + "%' ORDER BY RANDOM() LIMIT 80" - return qb.queryScenes(query, nil) + return qb.queryScenes(ctx, query, nil) } -func (qb *sceneQueryBuilder) All() ([]*models.Scene, error) { - return qb.queryScenes(selectAll(sceneTable)+qb.getDefaultSceneSort(), nil) +func (qb *sceneQueryBuilder) All(ctx context.Context) ([]*models.Scene, error) { + return qb.queryScenes(ctx, selectAll(sceneTable)+qb.getDefaultSceneSort(), nil) } func illegalFilterCombination(type1, type2 string) error { @@ -371,61 +369,61 @@ func (qb *sceneQueryBuilder) validateFilter(sceneFilter *models.SceneFilterType) return nil } -func (qb *sceneQueryBuilder) makeFilter(sceneFilter *models.SceneFilterType) *filterBuilder { +func (qb *sceneQueryBuilder) makeFilter(ctx context.Context, sceneFilter *models.SceneFilterType) *filterBuilder { query := &filterBuilder{} if sceneFilter.And != nil { - query.and(qb.makeFilter(sceneFilter.And)) + query.and(qb.makeFilter(ctx, sceneFilter.And)) } if sceneFilter.Or != nil { - query.or(qb.makeFilter(sceneFilter.Or)) + query.or(qb.makeFilter(ctx, sceneFilter.Or)) } if sceneFilter.Not != nil { - query.not(qb.makeFilter(sceneFilter.Not)) + query.not(qb.makeFilter(ctx, sceneFilter.Not)) } - query.handleCriterion(stringCriterionHandler(sceneFilter.Path, "scenes.path")) - query.handleCriterion(stringCriterionHandler(sceneFilter.Title, "scenes.title")) - query.handleCriterion(stringCriterionHandler(sceneFilter.Details, "scenes.details")) - query.handleCriterion(stringCriterionHandler(sceneFilter.Oshash, "scenes.oshash")) - query.handleCriterion(stringCriterionHandler(sceneFilter.Checksum, "scenes.checksum")) - query.handleCriterion(phashCriterionHandler(sceneFilter.Phash)) - query.handleCriterion(intCriterionHandler(sceneFilter.Rating, "scenes.rating")) - query.handleCriterion(intCriterionHandler(sceneFilter.OCounter, "scenes.o_counter")) - query.handleCriterion(boolCriterionHandler(sceneFilter.Organized, "scenes.organized")) - query.handleCriterion(durationCriterionHandler(sceneFilter.Duration, "scenes.duration")) - query.handleCriterion(resolutionCriterionHandler(sceneFilter.Resolution, "scenes.height", "scenes.width")) - query.handleCriterion(hasMarkersCriterionHandler(sceneFilter.HasMarkers)) - query.handleCriterion(sceneIsMissingCriterionHandler(qb, sceneFilter.IsMissing)) - query.handleCriterion(stringCriterionHandler(sceneFilter.URL, "scenes.url")) + query.handleCriterion(ctx, stringCriterionHandler(sceneFilter.Path, "scenes.path")) + query.handleCriterion(ctx, stringCriterionHandler(sceneFilter.Title, "scenes.title")) + query.handleCriterion(ctx, stringCriterionHandler(sceneFilter.Details, "scenes.details")) + query.handleCriterion(ctx, stringCriterionHandler(sceneFilter.Oshash, "scenes.oshash")) + query.handleCriterion(ctx, stringCriterionHandler(sceneFilter.Checksum, "scenes.checksum")) + query.handleCriterion(ctx, phashCriterionHandler(sceneFilter.Phash)) + query.handleCriterion(ctx, intCriterionHandler(sceneFilter.Rating, "scenes.rating")) + query.handleCriterion(ctx, intCriterionHandler(sceneFilter.OCounter, "scenes.o_counter")) + query.handleCriterion(ctx, boolCriterionHandler(sceneFilter.Organized, "scenes.organized")) + query.handleCriterion(ctx, durationCriterionHandler(sceneFilter.Duration, "scenes.duration")) + query.handleCriterion(ctx, resolutionCriterionHandler(sceneFilter.Resolution, "scenes.height", "scenes.width")) + query.handleCriterion(ctx, hasMarkersCriterionHandler(sceneFilter.HasMarkers)) + query.handleCriterion(ctx, sceneIsMissingCriterionHandler(qb, sceneFilter.IsMissing)) + query.handleCriterion(ctx, stringCriterionHandler(sceneFilter.URL, "scenes.url")) - query.handleCriterion(criterionHandlerFunc(func(f *filterBuilder) { + query.handleCriterion(ctx, criterionHandlerFunc(func(ctx context.Context, f *filterBuilder) { if sceneFilter.StashID != nil { qb.stashIDRepository().join(f, "scene_stash_ids", "scenes.id") - stringCriterionHandler(sceneFilter.StashID, "scene_stash_ids.stash_id")(f) + stringCriterionHandler(sceneFilter.StashID, "scene_stash_ids.stash_id")(ctx, f) } })) - query.handleCriterion(boolCriterionHandler(sceneFilter.Interactive, "scenes.interactive")) - query.handleCriterion(intCriterionHandler(sceneFilter.InteractiveSpeed, "scenes.interactive_speed")) + query.handleCriterion(ctx, boolCriterionHandler(sceneFilter.Interactive, "scenes.interactive")) + query.handleCriterion(ctx, intCriterionHandler(sceneFilter.InteractiveSpeed, "scenes.interactive_speed")) - query.handleCriterion(sceneCaptionCriterionHandler(qb, sceneFilter.Captions)) + query.handleCriterion(ctx, sceneCaptionCriterionHandler(qb, sceneFilter.Captions)) - query.handleCriterion(sceneTagsCriterionHandler(qb, sceneFilter.Tags)) - query.handleCriterion(sceneTagCountCriterionHandler(qb, sceneFilter.TagCount)) - query.handleCriterion(scenePerformersCriterionHandler(qb, sceneFilter.Performers)) - query.handleCriterion(scenePerformerCountCriterionHandler(qb, sceneFilter.PerformerCount)) - query.handleCriterion(sceneStudioCriterionHandler(qb, sceneFilter.Studios)) - query.handleCriterion(sceneMoviesCriterionHandler(qb, sceneFilter.Movies)) - query.handleCriterion(scenePerformerTagsCriterionHandler(qb, sceneFilter.PerformerTags)) - query.handleCriterion(scenePerformerFavoriteCriterionHandler(sceneFilter.PerformerFavorite)) - query.handleCriterion(scenePerformerAgeCriterionHandler(sceneFilter.PerformerAge)) - query.handleCriterion(scenePhashDuplicatedCriterionHandler(sceneFilter.Duplicated)) + query.handleCriterion(ctx, sceneTagsCriterionHandler(qb, sceneFilter.Tags)) + query.handleCriterion(ctx, sceneTagCountCriterionHandler(qb, sceneFilter.TagCount)) + query.handleCriterion(ctx, scenePerformersCriterionHandler(qb, sceneFilter.Performers)) + query.handleCriterion(ctx, scenePerformerCountCriterionHandler(qb, sceneFilter.PerformerCount)) + query.handleCriterion(ctx, sceneStudioCriterionHandler(qb, sceneFilter.Studios)) + query.handleCriterion(ctx, sceneMoviesCriterionHandler(qb, sceneFilter.Movies)) + query.handleCriterion(ctx, scenePerformerTagsCriterionHandler(qb, sceneFilter.PerformerTags)) + query.handleCriterion(ctx, scenePerformerFavoriteCriterionHandler(sceneFilter.PerformerFavorite)) + query.handleCriterion(ctx, scenePerformerAgeCriterionHandler(sceneFilter.PerformerAge)) + query.handleCriterion(ctx, scenePhashDuplicatedCriterionHandler(sceneFilter.Duplicated)) return query } -func (qb *sceneQueryBuilder) Query(options models.SceneQueryOptions) (*models.SceneQueryResult, error) { +func (qb *sceneQueryBuilder) Query(ctx context.Context, options models.SceneQueryOptions) (*models.SceneQueryResult, error) { sceneFilter := options.SceneFilter findFilter := options.FindFilter @@ -448,19 +446,19 @@ func (qb *sceneQueryBuilder) Query(options models.SceneQueryOptions) (*models.Sc if err := qb.validateFilter(sceneFilter); err != nil { return nil, err } - filter := qb.makeFilter(sceneFilter) + filter := qb.makeFilter(ctx, sceneFilter) query.addFilter(filter) qb.setSceneSort(&query, findFilter) query.sortAndPagination += getPagination(findFilter) - result, err := qb.queryGroupedFields(options, query) + result, err := qb.queryGroupedFields(ctx, options, query) if err != nil { return nil, fmt.Errorf("error querying aggregate fields: %w", err) } - idsResult, err := query.findIDs() + idsResult, err := query.findIDs(ctx) if err != nil { return nil, fmt.Errorf("error finding IDs: %w", err) } @@ -469,7 +467,7 @@ func (qb *sceneQueryBuilder) Query(options models.SceneQueryOptions) (*models.Sc return result, nil } -func (qb *sceneQueryBuilder) queryGroupedFields(options models.SceneQueryOptions, query queryBuilder) (*models.SceneQueryResult, error) { +func (qb *sceneQueryBuilder) queryGroupedFields(ctx context.Context, options models.SceneQueryOptions, query queryBuilder) (*models.SceneQueryResult, error) { if !options.Count && !options.TotalDuration && !options.TotalSize { // nothing to do - return empty result return models.NewSceneQueryResult(qb), nil @@ -499,7 +497,7 @@ func (qb *sceneQueryBuilder) queryGroupedFields(options models.SceneQueryOptions Duration float64 Size float64 }{} - if err := qb.repository.queryStruct(aggregateQuery.toSQL(includeSortPagination), query.args, &out); err != nil { + if err := qb.repository.queryStruct(ctx, aggregateQuery.toSQL(includeSortPagination), query.args, &out); err != nil { return nil, err } @@ -511,7 +509,7 @@ func (qb *sceneQueryBuilder) queryGroupedFields(options models.SceneQueryOptions } func phashCriterionHandler(phashFilter *models.StringCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if phashFilter != nil { // convert value to int from hex // ignore errors @@ -534,7 +532,7 @@ func phashCriterionHandler(phashFilter *models.StringCriterionInput) criterionHa } func scenePhashDuplicatedCriterionHandler(duplicatedFilter *models.PHashDuplicationCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { // TODO: Wishlist item: Implement Distance matching if duplicatedFilter != nil { var v string @@ -549,7 +547,7 @@ func scenePhashDuplicatedCriterionHandler(duplicatedFilter *models.PHashDuplicat } func durationCriterionHandler(durationFilter *models.IntCriterionInput, column string) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if durationFilter != nil { clause, args := getIntCriterionWhereClause("cast("+column+" as int)", *durationFilter) f.addWhere(clause, args...) @@ -558,7 +556,7 @@ func durationCriterionHandler(durationFilter *models.IntCriterionInput, column s } func resolutionCriterionHandler(resolution *models.ResolutionCriterionInput, heightColumn string, widthColumn string) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if resolution != nil && resolution.Value.IsValid() { min := resolution.Value.GetMinResolution() max := resolution.Value.GetMaxResolution() @@ -580,7 +578,7 @@ func resolutionCriterionHandler(resolution *models.ResolutionCriterionInput, hei } func hasMarkersCriterionHandler(hasMarkers *string) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if hasMarkers != nil { f.addLeftJoin("scene_markers", "", "scene_markers.scene_id = scenes.id") if *hasMarkers == "true" { @@ -593,7 +591,7 @@ func hasMarkersCriterionHandler(hasMarkers *string) criterionHandlerFunc { } func sceneIsMissingCriterionHandler(qb *sceneQueryBuilder, isMissing *string) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if isMissing != nil && *isMissing != "" { switch *isMissing { case "galleries": @@ -699,7 +697,7 @@ func scenePerformerCountCriterionHandler(qb *sceneQueryBuilder, performerCount * } func scenePerformerFavoriteCriterionHandler(performerfavorite *bool) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if performerfavorite != nil { f.addLeftJoin("performers_scenes", "", "scenes.id = performers_scenes.scene_id") @@ -719,7 +717,7 @@ GROUP BY performers_scenes.scene_id HAVING SUM(performers.favorite) = 0)`, "nofa } func scenePerformerAgeCriterionHandler(performerAge *models.IntCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if performerAge != nil { f.addInnerJoin("performers_scenes", "", "scenes.id = performers_scenes.scene_id") f.addInnerJoin("performers", "", "performers_scenes.performer_id = performers.id") @@ -759,7 +757,7 @@ func sceneMoviesCriterionHandler(qb *sceneQueryBuilder, movies *models.MultiCrit } func scenePerformerTagsCriterionHandler(qb *sceneQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if tags != nil { if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull { var notClause string @@ -778,7 +776,7 @@ func scenePerformerTagsCriterionHandler(qb *sceneQueryBuilder, tags *models.Hier return } - valuesClause := getHierarchicalValues(qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth) + valuesClause := getHierarchicalValues(ctx, qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth) f.addWith(`performer_tags AS ( SELECT ps.scene_id, t.column1 AS root_tag_id FROM performers_scenes ps @@ -816,17 +814,17 @@ func (qb *sceneQueryBuilder) setSceneSort(query *queryBuilder, findFilter *model } } -func (qb *sceneQueryBuilder) queryScene(query string, args []interface{}) (*models.Scene, error) { - results, err := qb.queryScenes(query, args) +func (qb *sceneQueryBuilder) queryScene(ctx context.Context, query string, args []interface{}) (*models.Scene, error) { + results, err := qb.queryScenes(ctx, query, args) if err != nil || len(results) < 1 { return nil, err } return results[0], nil } -func (qb *sceneQueryBuilder) queryScenes(query string, args []interface{}) ([]*models.Scene, error) { +func (qb *sceneQueryBuilder) queryScenes(ctx context.Context, query string, args []interface{}) ([]*models.Scene, error) { var ret models.Scenes - if err := qb.query(query, args, &ret); err != nil { + if err := qb.query(ctx, query, args, &ret); err != nil { return nil, err } @@ -844,16 +842,16 @@ func (qb *sceneQueryBuilder) imageRepository() *imageRepository { } } -func (qb *sceneQueryBuilder) GetCover(sceneID int) ([]byte, error) { - return qb.imageRepository().get(sceneID) +func (qb *sceneQueryBuilder) GetCover(ctx context.Context, sceneID int) ([]byte, error) { + return qb.imageRepository().get(ctx, sceneID) } -func (qb *sceneQueryBuilder) UpdateCover(sceneID int, image []byte) error { - return qb.imageRepository().replace(sceneID, image) +func (qb *sceneQueryBuilder) UpdateCover(ctx context.Context, sceneID int, image []byte) error { + return qb.imageRepository().replace(ctx, sceneID, image) } -func (qb *sceneQueryBuilder) DestroyCover(sceneID int) error { - return qb.imageRepository().destroy([]int{sceneID}) +func (qb *sceneQueryBuilder) DestroyCover(ctx context.Context, sceneID int) error { + return qb.imageRepository().destroy(ctx, []int{sceneID}) } func (qb *sceneQueryBuilder) moviesRepository() *repository { @@ -864,8 +862,8 @@ func (qb *sceneQueryBuilder) moviesRepository() *repository { } } -func (qb *sceneQueryBuilder) GetMovies(id int) (ret []models.MoviesScenes, err error) { - if err := qb.moviesRepository().getAll(id, func(rows *sqlx.Rows) error { +func (qb *sceneQueryBuilder) GetMovies(ctx context.Context, id int) (ret []models.MoviesScenes, err error) { + if err := qb.moviesRepository().getAll(ctx, id, func(rows *sqlx.Rows) error { var ms models.MoviesScenes if err := rows.StructScan(&ms); err != nil { return err @@ -880,16 +878,16 @@ func (qb *sceneQueryBuilder) GetMovies(id int) (ret []models.MoviesScenes, err e return ret, nil } -func (qb *sceneQueryBuilder) UpdateMovies(sceneID int, movies []models.MoviesScenes) error { +func (qb *sceneQueryBuilder) UpdateMovies(ctx context.Context, sceneID int, movies []models.MoviesScenes) error { // destroy existing joins r := qb.moviesRepository() - if err := r.destroy([]int{sceneID}); err != nil { + if err := r.destroy(ctx, []int{sceneID}); err != nil { return err } for _, m := range movies { m.SceneID = sceneID - if _, err := r.insert(m); err != nil { + if _, err := r.insert(ctx, m); err != nil { return err } } @@ -908,13 +906,13 @@ func (qb *sceneQueryBuilder) performersRepository() *joinRepository { } } -func (qb *sceneQueryBuilder) GetPerformerIDs(id int) ([]int, error) { - return qb.performersRepository().getIDs(id) +func (qb *sceneQueryBuilder) GetPerformerIDs(ctx context.Context, id int) ([]int, error) { + return qb.performersRepository().getIDs(ctx, id) } -func (qb *sceneQueryBuilder) UpdatePerformers(id int, performerIDs []int) error { +func (qb *sceneQueryBuilder) UpdatePerformers(ctx context.Context, id int, performerIDs []int) error { // Delete the existing joins and then create new ones - return qb.performersRepository().replace(id, performerIDs) + return qb.performersRepository().replace(ctx, id, performerIDs) } func (qb *sceneQueryBuilder) tagsRepository() *joinRepository { @@ -928,13 +926,13 @@ func (qb *sceneQueryBuilder) tagsRepository() *joinRepository { } } -func (qb *sceneQueryBuilder) GetTagIDs(id int) ([]int, error) { - return qb.tagsRepository().getIDs(id) +func (qb *sceneQueryBuilder) GetTagIDs(ctx context.Context, id int) ([]int, error) { + return qb.tagsRepository().getIDs(ctx, id) } -func (qb *sceneQueryBuilder) UpdateTags(id int, tagIDs []int) error { +func (qb *sceneQueryBuilder) UpdateTags(ctx context.Context, id int, tagIDs []int) error { // Delete the existing joins and then create new ones - return qb.tagsRepository().replace(id, tagIDs) + return qb.tagsRepository().replace(ctx, id, tagIDs) } func (qb *sceneQueryBuilder) galleriesRepository() *joinRepository { @@ -948,13 +946,13 @@ func (qb *sceneQueryBuilder) galleriesRepository() *joinRepository { } } -func (qb *sceneQueryBuilder) GetGalleryIDs(id int) ([]int, error) { - return qb.galleriesRepository().getIDs(id) +func (qb *sceneQueryBuilder) GetGalleryIDs(ctx context.Context, id int) ([]int, error) { + return qb.galleriesRepository().getIDs(ctx, id) } -func (qb *sceneQueryBuilder) UpdateGalleries(id int, galleryIDs []int) error { +func (qb *sceneQueryBuilder) UpdateGalleries(ctx context.Context, id int, galleryIDs []int) error { // Delete the existing joins and then create new ones - return qb.galleriesRepository().replace(id, galleryIDs) + return qb.galleriesRepository().replace(ctx, id, galleryIDs) } func (qb *sceneQueryBuilder) stashIDRepository() *stashIDRepository { @@ -967,19 +965,19 @@ func (qb *sceneQueryBuilder) stashIDRepository() *stashIDRepository { } } -func (qb *sceneQueryBuilder) GetStashIDs(sceneID int) ([]*models.StashID, error) { - return qb.stashIDRepository().get(sceneID) +func (qb *sceneQueryBuilder) GetStashIDs(ctx context.Context, sceneID int) ([]*models.StashID, error) { + return qb.stashIDRepository().get(ctx, sceneID) } -func (qb *sceneQueryBuilder) UpdateStashIDs(sceneID int, stashIDs []models.StashID) error { - return qb.stashIDRepository().replace(sceneID, stashIDs) +func (qb *sceneQueryBuilder) UpdateStashIDs(ctx context.Context, sceneID int, stashIDs []models.StashID) error { + return qb.stashIDRepository().replace(ctx, sceneID, stashIDs) } -func (qb *sceneQueryBuilder) FindDuplicates(distance int) ([][]*models.Scene, error) { +func (qb *sceneQueryBuilder) FindDuplicates(ctx context.Context, distance int) ([][]*models.Scene, error) { var dupeIds [][]int if distance == 0 { var ids []string - if err := qb.tx.Select(&ids, findExactDuplicateQuery); err != nil { + if err := qb.tx.Select(ctx, &ids, findExactDuplicateQuery); err != nil { return nil, err } @@ -996,7 +994,7 @@ func (qb *sceneQueryBuilder) FindDuplicates(distance int) ([][]*models.Scene, er } else { var hashes []*utils.Phash - if err := qb.queryFunc(findAllPhashesQuery, nil, false, func(rows *sqlx.Rows) error { + if err := qb.queryFunc(ctx, findAllPhashesQuery, nil, false, func(rows *sqlx.Rows) error { phash := utils.Phash{ Bucket: -1, } @@ -1015,7 +1013,7 @@ func (qb *sceneQueryBuilder) FindDuplicates(distance int) ([][]*models.Scene, er var duplicates [][]*models.Scene for _, sceneIds := range dupeIds { - if scenes, err := qb.FindMany(sceneIds); err == nil { + if scenes, err := qb.FindMany(ctx, sceneIds); err == nil { duplicates = append(duplicates, scenes) } } diff --git a/pkg/sqlite/scene_marker.go b/pkg/sqlite/scene_marker.go index c79c1dc16..c8091d2a5 100644 --- a/pkg/sqlite/scene_marker.go +++ b/pkg/sqlite/scene_marker.go @@ -1,11 +1,11 @@ package sqlite import ( + "context" "database/sql" "errors" "fmt" - "github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/models" ) @@ -22,57 +22,54 @@ type sceneMarkerQueryBuilder struct { repository } -func NewSceneMarkerReaderWriter(tx dbi) *sceneMarkerQueryBuilder { - return &sceneMarkerQueryBuilder{ - repository{ - tx: tx, - tableName: sceneMarkerTable, - idColumn: idColumn, - }, - } +var SceneMarkerReaderWriter = &sceneMarkerQueryBuilder{ + repository{ + tableName: sceneMarkerTable, + idColumn: idColumn, + }, } -func (qb *sceneMarkerQueryBuilder) Create(newObject models.SceneMarker) (*models.SceneMarker, error) { +func (qb *sceneMarkerQueryBuilder) Create(ctx context.Context, newObject models.SceneMarker) (*models.SceneMarker, error) { var ret models.SceneMarker - if err := qb.insertObject(newObject, &ret); err != nil { + if err := qb.insertObject(ctx, newObject, &ret); err != nil { return nil, err } return &ret, nil } -func (qb *sceneMarkerQueryBuilder) Update(updatedObject models.SceneMarker) (*models.SceneMarker, error) { +func (qb *sceneMarkerQueryBuilder) Update(ctx context.Context, updatedObject models.SceneMarker) (*models.SceneMarker, error) { const partial = false - if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil { return nil, err } var ret models.SceneMarker - if err := qb.get(updatedObject.ID, &ret); err != nil { + if err := qb.getByID(ctx, updatedObject.ID, &ret); err != nil { return nil, err } return &ret, nil } -func (qb *sceneMarkerQueryBuilder) Destroy(id int) error { - return qb.destroyExisting([]int{id}) +func (qb *sceneMarkerQueryBuilder) Destroy(ctx context.Context, id int) error { + return qb.destroyExisting(ctx, []int{id}) } -func (qb *sceneMarkerQueryBuilder) Find(id int) (*models.SceneMarker, error) { +func (qb *sceneMarkerQueryBuilder) Find(ctx context.Context, id int) (*models.SceneMarker, error) { query := "SELECT * FROM scene_markers WHERE id = ? LIMIT 1" args := []interface{}{id} - results, err := qb.querySceneMarkers(query, args) + results, err := qb.querySceneMarkers(ctx, query, args) if err != nil || len(results) < 1 { return nil, err } return results[0], nil } -func (qb *sceneMarkerQueryBuilder) FindMany(ids []int) ([]*models.SceneMarker, error) { +func (qb *sceneMarkerQueryBuilder) FindMany(ctx context.Context, ids []int) ([]*models.SceneMarker, error) { var markers []*models.SceneMarker for _, id := range ids { - marker, err := qb.Find(id) + marker, err := qb.Find(ctx, id) if err != nil { return nil, err } @@ -87,7 +84,7 @@ func (qb *sceneMarkerQueryBuilder) FindMany(ids []int) ([]*models.SceneMarker, e return markers, nil } -func (qb *sceneMarkerQueryBuilder) FindBySceneID(sceneID int) ([]*models.SceneMarker, error) { +func (qb *sceneMarkerQueryBuilder) FindBySceneID(ctx context.Context, sceneID int) ([]*models.SceneMarker, error) { query := ` SELECT scene_markers.* FROM scene_markers WHERE scene_markers.scene_id = ? @@ -95,15 +92,15 @@ func (qb *sceneMarkerQueryBuilder) FindBySceneID(sceneID int) ([]*models.SceneMa ORDER BY scene_markers.seconds ASC ` args := []interface{}{sceneID} - return qb.querySceneMarkers(query, args) + return qb.querySceneMarkers(ctx, query, args) } -func (qb *sceneMarkerQueryBuilder) CountByTagID(tagID int) (int, error) { +func (qb *sceneMarkerQueryBuilder) CountByTagID(ctx context.Context, tagID int) (int, error) { args := []interface{}{tagID, tagID} - return qb.runCountQuery(qb.buildCountQuery(countSceneMarkersForTagQuery), args) + return qb.runCountQuery(ctx, qb.buildCountQuery(countSceneMarkersForTagQuery), args) } -func (qb *sceneMarkerQueryBuilder) GetMarkerStrings(q *string, sort *string) ([]*models.MarkerStringsResultType, error) { +func (qb *sceneMarkerQueryBuilder) GetMarkerStrings(ctx context.Context, q *string, sort *string) ([]*models.MarkerStringsResultType, error) { query := "SELECT count(*) as `count`, scene_markers.id as id, scene_markers.title as title FROM scene_markers" if q != nil { query += " WHERE title LIKE '%" + *q + "%'" @@ -115,30 +112,30 @@ func (qb *sceneMarkerQueryBuilder) GetMarkerStrings(q *string, sort *string) ([] query += " ORDER BY title ASC" } var args []interface{} - return qb.queryMarkerStringsResultType(query, args) + return qb.queryMarkerStringsResultType(ctx, query, args) } -func (qb *sceneMarkerQueryBuilder) Wall(q *string) ([]*models.SceneMarker, error) { +func (qb *sceneMarkerQueryBuilder) Wall(ctx context.Context, q *string) ([]*models.SceneMarker, error) { s := "" if q != nil { s = *q } query := "SELECT scene_markers.* FROM scene_markers WHERE scene_markers.title LIKE '%" + s + "%' ORDER BY RANDOM() LIMIT 80" - return qb.querySceneMarkers(query, nil) + return qb.querySceneMarkers(ctx, query, nil) } -func (qb *sceneMarkerQueryBuilder) makeFilter(sceneMarkerFilter *models.SceneMarkerFilterType) *filterBuilder { +func (qb *sceneMarkerQueryBuilder) makeFilter(ctx context.Context, sceneMarkerFilter *models.SceneMarkerFilterType) *filterBuilder { query := &filterBuilder{} - query.handleCriterion(sceneMarkerTagIDCriterionHandler(qb, sceneMarkerFilter.TagID)) - query.handleCriterion(sceneMarkerTagsCriterionHandler(qb, sceneMarkerFilter.Tags)) - query.handleCriterion(sceneMarkerSceneTagsCriterionHandler(qb, sceneMarkerFilter.SceneTags)) - query.handleCriterion(sceneMarkerPerformersCriterionHandler(qb, sceneMarkerFilter.Performers)) + query.handleCriterion(ctx, sceneMarkerTagIDCriterionHandler(qb, sceneMarkerFilter.TagID)) + query.handleCriterion(ctx, sceneMarkerTagsCriterionHandler(qb, sceneMarkerFilter.Tags)) + query.handleCriterion(ctx, sceneMarkerSceneTagsCriterionHandler(qb, sceneMarkerFilter.SceneTags)) + query.handleCriterion(ctx, sceneMarkerPerformersCriterionHandler(qb, sceneMarkerFilter.Performers)) return query } -func (qb *sceneMarkerQueryBuilder) Query(sceneMarkerFilter *models.SceneMarkerFilterType, findFilter *models.FindFilterType) ([]*models.SceneMarker, int, error) { +func (qb *sceneMarkerQueryBuilder) Query(ctx context.Context, sceneMarkerFilter *models.SceneMarkerFilterType, findFilter *models.FindFilterType) ([]*models.SceneMarker, int, error) { if sceneMarkerFilter == nil { sceneMarkerFilter = &models.SceneMarkerFilterType{} } @@ -154,19 +151,19 @@ func (qb *sceneMarkerQueryBuilder) Query(sceneMarkerFilter *models.SceneMarkerFi query.parseQueryString(searchColumns, *q) } - filter := qb.makeFilter(sceneMarkerFilter) + filter := qb.makeFilter(ctx, sceneMarkerFilter) query.addFilter(filter) query.sortAndPagination = qb.getSceneMarkerSort(&query, findFilter) + getPagination(findFilter) - idsResult, countResult, err := query.executeFind() + idsResult, countResult, err := query.executeFind(ctx) if err != nil { return nil, 0, err } var sceneMarkers []*models.SceneMarker for _, id := range idsResult { - sceneMarker, err := qb.Find(id) + sceneMarker, err := qb.Find(ctx, id) if err != nil { return nil, 0, err } @@ -178,7 +175,7 @@ func (qb *sceneMarkerQueryBuilder) Query(sceneMarkerFilter *models.SceneMarkerFi } func sceneMarkerTagIDCriterionHandler(qb *sceneMarkerQueryBuilder, tagID *string) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if tagID != nil { f.addLeftJoin("scene_markers_tags", "", "scene_markers_tags.scene_marker_id = scene_markers.id") @@ -188,7 +185,7 @@ func sceneMarkerTagIDCriterionHandler(qb *sceneMarkerQueryBuilder, tagID *string } func sceneMarkerTagsCriterionHandler(qb *sceneMarkerQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if tags != nil { if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull { var notClause string @@ -205,7 +202,7 @@ func sceneMarkerTagsCriterionHandler(qb *sceneMarkerQueryBuilder, tags *models.H if len(tags.Value) == 0 { return } - valuesClause := getHierarchicalValues(qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth) + valuesClause := getHierarchicalValues(ctx, qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth) f.addWith(`marker_tags AS ( SELECT mt.scene_marker_id, t.column1 AS root_tag_id FROM scene_markers_tags mt @@ -223,7 +220,7 @@ INNER JOIN (` + valuesClause + `) t ON t.column2 = m.primary_tag_id } func sceneMarkerSceneTagsCriterionHandler(qb *sceneMarkerQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if tags != nil { if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull { var notClause string @@ -241,7 +238,7 @@ func sceneMarkerSceneTagsCriterionHandler(qb *sceneMarkerQueryBuilder, tags *mod return } - valuesClause := getHierarchicalValues(qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth) + valuesClause := getHierarchicalValues(ctx, qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth) f.addWith(`scene_tags AS ( SELECT st.scene_id, t.column1 AS root_tag_id FROM scenes_tags st @@ -269,10 +266,10 @@ func sceneMarkerPerformersCriterionHandler(qb *sceneMarkerQueryBuilder, performe } handler := h.handler(performers) - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { // Make sure scenes is included, otherwise excludes filter fails f.addLeftJoin(sceneTable, "", "scenes.id = scene_markers.scene_id") - handler(f) + handler(ctx, f) } } @@ -289,17 +286,17 @@ func (qb *sceneMarkerQueryBuilder) getSceneMarkerSort(query *queryBuilder, findF return getSort(sort, direction, tableName) } -func (qb *sceneMarkerQueryBuilder) querySceneMarkers(query string, args []interface{}) ([]*models.SceneMarker, error) { +func (qb *sceneMarkerQueryBuilder) querySceneMarkers(ctx context.Context, query string, args []interface{}) ([]*models.SceneMarker, error) { var ret models.SceneMarkers - if err := qb.query(query, args, &ret); err != nil { + if err := qb.query(ctx, query, args, &ret); err != nil { return nil, err } return []*models.SceneMarker(ret), nil } -func (qb *sceneMarkerQueryBuilder) queryMarkerStringsResultType(query string, args []interface{}) ([]*models.MarkerStringsResultType, error) { - rows, err := database.DB.Queryx(query, args...) +func (qb *sceneMarkerQueryBuilder) queryMarkerStringsResultType(ctx context.Context, query string, args []interface{}) ([]*models.MarkerStringsResultType, error) { + rows, err := qb.tx.Queryx(ctx, query, args...) if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, err } @@ -332,11 +329,11 @@ func (qb *sceneMarkerQueryBuilder) tagsRepository() *joinRepository { } } -func (qb *sceneMarkerQueryBuilder) GetTagIDs(id int) ([]int, error) { - return qb.tagsRepository().getIDs(id) +func (qb *sceneMarkerQueryBuilder) GetTagIDs(ctx context.Context, id int) ([]int, error) { + return qb.tagsRepository().getIDs(ctx, id) } -func (qb *sceneMarkerQueryBuilder) UpdateTags(id int, tagIDs []int) error { +func (qb *sceneMarkerQueryBuilder) UpdateTags(ctx context.Context, id int, tagIDs []int) error { // Delete the existing joins and then create new ones - return qb.tagsRepository().replace(id, tagIDs) + return qb.tagsRepository().replace(ctx, id, tagIDs) } diff --git a/pkg/sqlite/scene_marker_test.go b/pkg/sqlite/scene_marker_test.go index 2fa0d7501..c0d29162e 100644 --- a/pkg/sqlite/scene_marker_test.go +++ b/pkg/sqlite/scene_marker_test.go @@ -4,18 +4,20 @@ package sqlite_test import ( + "context" "testing" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/sqlite" "github.com/stretchr/testify/assert" ) func TestMarkerFindBySceneID(t *testing.T) { - withTxn(func(r models.Repository) error { - mqb := r.SceneMarker() + withTxn(func(ctx context.Context) error { + mqb := sqlite.SceneMarkerReaderWriter sceneID := sceneIDs[sceneIdxWithMarkers] - markers, err := mqb.FindBySceneID(sceneID) + markers, err := mqb.FindBySceneID(ctx, sceneID) if err != nil { t.Errorf("Error finding markers: %s", err.Error()) @@ -26,7 +28,7 @@ func TestMarkerFindBySceneID(t *testing.T) { assert.Equal(t, sceneIDs[sceneIdxWithMarkers], int(marker.SceneID.Int64)) } - markers, err = mqb.FindBySceneID(0) + markers, err = mqb.FindBySceneID(ctx, 0) if err != nil { t.Errorf("Error finding marker: %s", err.Error()) @@ -39,10 +41,10 @@ func TestMarkerFindBySceneID(t *testing.T) { } func TestMarkerCountByTagID(t *testing.T) { - withTxn(func(r models.Repository) error { - mqb := r.SceneMarker() + withTxn(func(ctx context.Context) error { + mqb := sqlite.SceneMarkerReaderWriter - markerCount, err := mqb.CountByTagID(tagIDs[tagIdxWithPrimaryMarkers]) + markerCount, err := mqb.CountByTagID(ctx, tagIDs[tagIdxWithPrimaryMarkers]) if err != nil { t.Errorf("error calling CountByTagID: %s", err.Error()) @@ -50,7 +52,7 @@ func TestMarkerCountByTagID(t *testing.T) { assert.Equal(t, 3, markerCount) - markerCount, err = mqb.CountByTagID(tagIDs[tagIdxWithMarkers]) + markerCount, err = mqb.CountByTagID(ctx, tagIDs[tagIdxWithMarkers]) if err != nil { t.Errorf("error calling CountByTagID: %s", err.Error()) @@ -58,7 +60,7 @@ func TestMarkerCountByTagID(t *testing.T) { assert.Equal(t, 1, markerCount) - markerCount, err = mqb.CountByTagID(0) + markerCount, err = mqb.CountByTagID(ctx, 0) if err != nil { t.Errorf("error calling CountByTagID: %s", err.Error()) @@ -71,9 +73,9 @@ func TestMarkerCountByTagID(t *testing.T) { } func TestMarkerQuerySortBySceneUpdated(t *testing.T) { - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { sort := "scenes_updated_at" - _, _, err := r.SceneMarker().Query(nil, &models.FindFilterType{ + _, _, err := sqlite.SceneMarkerReaderWriter.Query(ctx, nil, &models.FindFilterType{ Sort: &sort, }) @@ -92,9 +94,9 @@ func TestMarkerQueryTags(t *testing.T) { findFilter *models.FindFilterType } - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { testTags := func(m *models.SceneMarker, markerFilter *models.SceneMarkerFilterType) { - tagIDs, err := r.SceneMarker().GetTagIDs(m.ID) + tagIDs, err := sqlite.SceneMarkerReaderWriter.GetTagIDs(ctx, m.ID) if err != nil { t.Errorf("error getting marker tag ids: %v", err) } @@ -129,7 +131,7 @@ func TestMarkerQueryTags(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - markers := queryMarkers(t, r.SceneMarker(), tc.markerFilter, tc.findFilter) + markers := queryMarkers(ctx, t, sqlite.SceneMarkerReaderWriter, tc.markerFilter, tc.findFilter) assert.Greater(t, len(markers), 0) for _, m := range markers { testTags(m, tc.markerFilter) @@ -148,9 +150,9 @@ func TestMarkerQuerySceneTags(t *testing.T) { findFilter *models.FindFilterType } - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { testTags := func(m *models.SceneMarker, markerFilter *models.SceneMarkerFilterType) { - tagIDs, err := r.Scene().GetTagIDs(int(m.SceneID.Int64)) + tagIDs, err := sqlite.SceneReaderWriter.GetTagIDs(ctx, int(m.SceneID.Int64)) if err != nil { t.Errorf("error getting marker tag ids: %v", err) } @@ -185,7 +187,7 @@ func TestMarkerQuerySceneTags(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - markers := queryMarkers(t, r.SceneMarker(), tc.markerFilter, tc.findFilter) + markers := queryMarkers(ctx, t, sqlite.SceneMarkerReaderWriter, tc.markerFilter, tc.findFilter) assert.Greater(t, len(markers), 0) for _, m := range markers { testTags(m, tc.markerFilter) @@ -197,9 +199,9 @@ func TestMarkerQuerySceneTags(t *testing.T) { }) } -func queryMarkers(t *testing.T, sqb models.SceneMarkerReader, markerFilter *models.SceneMarkerFilterType, findFilter *models.FindFilterType) []*models.SceneMarker { +func queryMarkers(ctx context.Context, t *testing.T, sqb models.SceneMarkerReader, markerFilter *models.SceneMarkerFilterType, findFilter *models.FindFilterType) []*models.SceneMarker { t.Helper() - result, _, err := sqb.Query(markerFilter, findFilter) + result, _, err := sqb.Query(ctx, markerFilter, findFilter) if err != nil { t.Errorf("Error querying markers: %v", err) } diff --git a/pkg/sqlite/scene_test.go b/pkg/sqlite/scene_test.go index dc70d7637..da88a0bdc 100644 --- a/pkg/sqlite/scene_test.go +++ b/pkg/sqlite/scene_test.go @@ -4,6 +4,7 @@ package sqlite_test import ( + "context" "database/sql" "fmt" "math" @@ -15,16 +16,17 @@ import ( "github.com/stashapp/stash/pkg/hash/md5" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/sqlite" ) func TestSceneFind(t *testing.T) { - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { // assume that the first scene is sceneWithGalleryPath - sqb := r.Scene() + sqb := sqlite.SceneReaderWriter const sceneIdx = 0 sceneID := sceneIDs[sceneIdx] - scene, err := sqb.Find(sceneID) + scene, err := sqb.Find(ctx, sceneID) if err != nil { t.Errorf("Error finding scene: %s", err.Error()) @@ -33,7 +35,7 @@ func TestSceneFind(t *testing.T) { assert.Equal(t, getSceneStringValue(sceneIdx, "Path"), scene.Path) sceneID = 0 - scene, err = sqb.Find(sceneID) + scene, err = sqb.Find(ctx, sceneID) if err != nil { t.Errorf("Error finding scene: %s", err.Error()) @@ -46,12 +48,12 @@ func TestSceneFind(t *testing.T) { } func TestSceneFindByPath(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter const sceneIdx = 1 scenePath := getSceneStringValue(sceneIdx, "Path") - scene, err := sqb.FindByPath(scenePath) + scene, err := sqb.FindByPath(ctx, scenePath) if err != nil { t.Errorf("Error finding scene: %s", err.Error()) @@ -61,7 +63,7 @@ func TestSceneFindByPath(t *testing.T) { assert.Equal(t, scenePath, scene.Path) scenePath = "not exist" - scene, err = sqb.FindByPath(scenePath) + scene, err = sqb.FindByPath(ctx, scenePath) if err != nil { t.Errorf("Error finding scene: %s", err.Error()) @@ -74,9 +76,9 @@ func TestSceneFindByPath(t *testing.T) { } func TestSceneCountByPerformerID(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() - count, err := sqb.CountByPerformerID(performerIDs[performerIdxWithScene]) + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter + count, err := sqb.CountByPerformerID(ctx, performerIDs[performerIdxWithScene]) if err != nil { t.Errorf("Error counting scenes: %s", err.Error()) @@ -84,7 +86,7 @@ func TestSceneCountByPerformerID(t *testing.T) { assert.Equal(t, 1, count) - count, err = sqb.CountByPerformerID(0) + count, err = sqb.CountByPerformerID(ctx, 0) if err != nil { t.Errorf("Error counting scenes: %s", err.Error()) @@ -97,12 +99,12 @@ func TestSceneCountByPerformerID(t *testing.T) { } func TestSceneWall(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter const sceneIdx = 2 wallQuery := getSceneStringValue(sceneIdx, "Details") - scenes, err := sqb.Wall(&wallQuery) + scenes, err := sqb.Wall(ctx, &wallQuery) if err != nil { t.Errorf("Error finding scenes: %s", err.Error()) @@ -114,7 +116,7 @@ func TestSceneWall(t *testing.T) { assert.Equal(t, getSceneStringValue(sceneIdx, "Path"), scene.Path) wallQuery = "not exist" - scenes, err = sqb.Wall(&wallQuery) + scenes, err = sqb.Wall(ctx, &wallQuery) if err != nil { t.Errorf("Error finding scene: %s", err.Error()) @@ -131,18 +133,18 @@ func TestSceneQueryQ(t *testing.T) { q := getSceneStringValue(sceneIdx, titleField) - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter - sceneQueryQ(t, sqb, q, sceneIdx) + sceneQueryQ(ctx, t, sqb, q, sceneIdx) return nil }) } -func queryScene(t *testing.T, sqb models.SceneReader, sceneFilter *models.SceneFilterType, findFilter *models.FindFilterType) []*models.Scene { +func queryScene(ctx context.Context, t *testing.T, sqb models.SceneReader, sceneFilter *models.SceneFilterType, findFilter *models.FindFilterType) []*models.Scene { t.Helper() - result, err := sqb.Query(models.SceneQueryOptions{ + result, err := sqb.Query(ctx, models.SceneQueryOptions{ QueryOptions: models.QueryOptions{ FindFilter: findFilter, }, @@ -152,7 +154,7 @@ func queryScene(t *testing.T, sqb models.SceneReader, sceneFilter *models.SceneF t.Errorf("Error querying scene: %v", err) } - scenes, err := result.Resolve() + scenes, err := result.Resolve(ctx) if err != nil { t.Errorf("Error resolving scenes: %v", err) } @@ -160,11 +162,11 @@ func queryScene(t *testing.T, sqb models.SceneReader, sceneFilter *models.SceneF return scenes } -func sceneQueryQ(t *testing.T, sqb models.SceneReader, q string, expectedSceneIdx int) { +func sceneQueryQ(ctx context.Context, t *testing.T, sqb models.SceneReader, q string, expectedSceneIdx int) { filter := models.FindFilterType{ Q: &q, } - scenes := queryScene(t, sqb, nil, &filter) + scenes := queryScene(ctx, t, sqb, nil, &filter) assert.Len(t, scenes, 1) scene := scenes[0] @@ -172,7 +174,7 @@ func sceneQueryQ(t *testing.T, sqb models.SceneReader, q string, expectedSceneId // no Q should return all results filter.Q = nil - scenes = queryScene(t, sqb, nil, &filter) + scenes = queryScene(ctx, t, sqb, nil, &filter) assert.Len(t, scenes, totalScenes) } @@ -257,10 +259,10 @@ func TestSceneQueryPathOr(t *testing.T) { }, } - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter - scenes := queryScene(t, sqb, &sceneFilter, nil) + scenes := queryScene(ctx, t, sqb, &sceneFilter, nil) assert.Len(t, scenes, 2) assert.Equal(t, scene1Path, scenes[0].Path) @@ -288,10 +290,10 @@ func TestSceneQueryPathAndRating(t *testing.T) { }, } - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter - scenes := queryScene(t, sqb, &sceneFilter, nil) + scenes := queryScene(ctx, t, sqb, &sceneFilter, nil) assert.Len(t, scenes, 1) assert.Equal(t, scenePath, scenes[0].Path) @@ -323,10 +325,10 @@ func TestSceneQueryPathNotRating(t *testing.T) { }, } - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter - scenes := queryScene(t, sqb, &sceneFilter, nil) + scenes := queryScene(ctx, t, sqb, &sceneFilter, nil) for _, scene := range scenes { verifyString(t, scene.Path, pathCriterion) @@ -354,24 +356,24 @@ func TestSceneIllegalQuery(t *testing.T) { Or: &subFilter, } - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter queryOptions := models.SceneQueryOptions{ SceneFilter: sceneFilter, } - _, err := sqb.Query(queryOptions) + _, err := sqb.Query(ctx, queryOptions) assert.NotNil(err) sceneFilter.Or = nil sceneFilter.Not = &subFilter - _, err = sqb.Query(queryOptions) + _, err = sqb.Query(ctx, queryOptions) assert.NotNil(err) sceneFilter.And = nil sceneFilter.Or = &subFilter - _, err = sqb.Query(queryOptions) + _, err = sqb.Query(ctx, queryOptions) assert.NotNil(err) return nil @@ -379,11 +381,11 @@ func TestSceneIllegalQuery(t *testing.T) { } func verifySceneQuery(t *testing.T, filter models.SceneFilterType, verifyFn func(s *models.Scene)) { - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { t.Helper() - sqb := r.Scene() + sqb := sqlite.SceneReaderWriter - scenes := queryScene(t, sqb, &filter, nil) + scenes := queryScene(ctx, t, sqb, &filter, nil) // assume it should find at least one assert.Greater(t, len(scenes), 0) @@ -397,13 +399,13 @@ func verifySceneQuery(t *testing.T, filter models.SceneFilterType, verifyFn func } func verifyScenesPath(t *testing.T, pathCriterion models.StringCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter sceneFilter := models.SceneFilterType{ Path: &pathCriterion, } - scenes := queryScene(t, sqb, &sceneFilter, nil) + scenes := queryScene(ctx, t, sqb, &sceneFilter, nil) for _, scene := range scenes { verifyString(t, scene.Path, pathCriterion) @@ -489,13 +491,13 @@ func TestSceneQueryRating(t *testing.T) { } func verifyScenesRating(t *testing.T, ratingCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter sceneFilter := models.SceneFilterType{ Rating: &ratingCriterion, } - scenes := queryScene(t, sqb, &sceneFilter, nil) + scenes := queryScene(ctx, t, sqb, &sceneFilter, nil) for _, scene := range scenes { verifyInt64(t, scene.Rating, ratingCriterion) @@ -548,13 +550,13 @@ func TestSceneQueryOCounter(t *testing.T) { } func verifyScenesOCounter(t *testing.T, oCounterCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter sceneFilter := models.SceneFilterType{ OCounter: &oCounterCriterion, } - scenes := queryScene(t, sqb, &sceneFilter, nil) + scenes := queryScene(ctx, t, sqb, &sceneFilter, nil) for _, scene := range scenes { verifyInt(t, scene.OCounter, oCounterCriterion) @@ -607,13 +609,13 @@ func TestSceneQueryDuration(t *testing.T) { } func verifyScenesDuration(t *testing.T, durationCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter sceneFilter := models.SceneFilterType{ Duration: &durationCriterion, } - scenes := queryScene(t, sqb, &sceneFilter, nil) + scenes := queryScene(ctx, t, sqb, &sceneFilter, nil) for _, scene := range scenes { if durationCriterion.Modifier == models.CriterionModifierEquals { @@ -661,8 +663,8 @@ func TestSceneQueryResolution(t *testing.T) { } func verifyScenesResolution(t *testing.T, resolution models.ResolutionEnum) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter sceneFilter := models.SceneFilterType{ Resolution: &models.ResolutionCriterionInput{ Value: resolution, @@ -670,7 +672,7 @@ func verifyScenesResolution(t *testing.T, resolution models.ResolutionEnum) { }, } - scenes := queryScene(t, sqb, &sceneFilter, nil) + scenes := queryScene(ctx, t, sqb, &sceneFilter, nil) for _, scene := range scenes { verifySceneResolution(t, scene.Height, resolution) @@ -706,20 +708,20 @@ func TestAllResolutionsHaveResolutionRange(t *testing.T) { } func TestSceneQueryResolutionModifiers(t *testing.T) { - if err := withRollbackTxn(func(r models.Repository) error { - qb := r.Scene() - sceneNoResolution, _ := createScene(qb, 0, 0) - firstScene540P, _ := createScene(qb, 960, 540) - secondScene540P, _ := createScene(qb, 1280, 719) - firstScene720P, _ := createScene(qb, 1280, 720) - secondScene720P, _ := createScene(qb, 1280, 721) - thirdScene720P, _ := createScene(qb, 1920, 1079) - scene1080P, _ := createScene(qb, 1920, 1080) + if err := withRollbackTxn(func(ctx context.Context) error { + qb := sqlite.SceneReaderWriter + sceneNoResolution, _ := createScene(ctx, qb, 0, 0) + firstScene540P, _ := createScene(ctx, qb, 960, 540) + secondScene540P, _ := createScene(ctx, qb, 1280, 719) + firstScene720P, _ := createScene(ctx, qb, 1280, 720) + secondScene720P, _ := createScene(ctx, qb, 1280, 721) + thirdScene720P, _ := createScene(ctx, qb, 1920, 1079) + scene1080P, _ := createScene(ctx, qb, 1920, 1080) - scenesEqualTo720P := queryScenes(t, qb, models.ResolutionEnumStandardHd, models.CriterionModifierEquals) - scenesNotEqualTo720P := queryScenes(t, qb, models.ResolutionEnumStandardHd, models.CriterionModifierNotEquals) - scenesGreaterThan720P := queryScenes(t, qb, models.ResolutionEnumStandardHd, models.CriterionModifierGreaterThan) - scenesLessThan720P := queryScenes(t, qb, models.ResolutionEnumStandardHd, models.CriterionModifierLessThan) + scenesEqualTo720P := queryScenes(ctx, t, qb, models.ResolutionEnumStandardHd, models.CriterionModifierEquals) + scenesNotEqualTo720P := queryScenes(ctx, t, qb, models.ResolutionEnumStandardHd, models.CriterionModifierNotEquals) + scenesGreaterThan720P := queryScenes(ctx, t, qb, models.ResolutionEnumStandardHd, models.CriterionModifierGreaterThan) + scenesLessThan720P := queryScenes(ctx, t, qb, models.ResolutionEnumStandardHd, models.CriterionModifierLessThan) assert.Subset(t, scenesEqualTo720P, []*models.Scene{firstScene720P, secondScene720P, thirdScene720P}) assert.NotSubset(t, scenesEqualTo720P, []*models.Scene{sceneNoResolution, firstScene540P, secondScene540P, scene1080P}) @@ -739,7 +741,7 @@ func TestSceneQueryResolutionModifiers(t *testing.T) { } } -func queryScenes(t *testing.T, queryBuilder models.SceneReaderWriter, resolution models.ResolutionEnum, modifier models.CriterionModifier) []*models.Scene { +func queryScenes(ctx context.Context, t *testing.T, queryBuilder models.SceneReaderWriter, resolution models.ResolutionEnum, modifier models.CriterionModifier) []*models.Scene { sceneFilter := models.SceneFilterType{ Resolution: &models.ResolutionCriterionInput{ Value: resolution, @@ -747,10 +749,10 @@ func queryScenes(t *testing.T, queryBuilder models.SceneReaderWriter, resolution }, } - return queryScene(t, queryBuilder, &sceneFilter, nil) + return queryScene(ctx, t, queryBuilder, &sceneFilter, nil) } -func createScene(queryBuilder models.SceneReaderWriter, width int64, height int64) (*models.Scene, error) { +func createScene(ctx context.Context, queryBuilder models.SceneReaderWriter, width int64, height int64) (*models.Scene, error) { name := fmt.Sprintf("TestSceneQueryResolutionModifiers %d %d", width, height) scene := models.Scene{ Path: name, @@ -765,12 +767,12 @@ func createScene(queryBuilder models.SceneReaderWriter, width int64, height int6 Checksum: sql.NullString{String: md5.FromString(name), Valid: true}, } - return queryBuilder.Create(scene) + return queryBuilder.Create(ctx, scene) } func TestSceneQueryHasMarkers(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter hasMarkers := "true" sceneFilter := models.SceneFilterType{ HasMarkers: &hasMarkers, @@ -781,17 +783,17 @@ func TestSceneQueryHasMarkers(t *testing.T) { Q: &q, } - scenes := queryScene(t, sqb, &sceneFilter, &findFilter) + scenes := queryScene(ctx, t, sqb, &sceneFilter, &findFilter) assert.Len(t, scenes, 1) assert.Equal(t, sceneIDs[sceneIdxWithMarkers], scenes[0].ID) hasMarkers = "false" - scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter) assert.Len(t, scenes, 0) findFilter.Q = nil - scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter) assert.NotEqual(t, 0, len(scenes)) @@ -805,8 +807,8 @@ func TestSceneQueryHasMarkers(t *testing.T) { } func TestSceneQueryIsMissingGallery(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter isMissing := "galleries" sceneFilter := models.SceneFilterType{ IsMissing: &isMissing, @@ -817,12 +819,12 @@ func TestSceneQueryIsMissingGallery(t *testing.T) { Q: &q, } - scenes := queryScene(t, sqb, &sceneFilter, &findFilter) + scenes := queryScene(ctx, t, sqb, &sceneFilter, &findFilter) assert.Len(t, scenes, 0) findFilter.Q = nil - scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter) // ensure non of the ids equal the one with gallery for _, scene := range scenes { @@ -834,8 +836,8 @@ func TestSceneQueryIsMissingGallery(t *testing.T) { } func TestSceneQueryIsMissingStudio(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter isMissing := "studio" sceneFilter := models.SceneFilterType{ IsMissing: &isMissing, @@ -846,12 +848,12 @@ func TestSceneQueryIsMissingStudio(t *testing.T) { Q: &q, } - scenes := queryScene(t, sqb, &sceneFilter, &findFilter) + scenes := queryScene(ctx, t, sqb, &sceneFilter, &findFilter) assert.Len(t, scenes, 0) findFilter.Q = nil - scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter) // ensure non of the ids equal the one with studio for _, scene := range scenes { @@ -863,8 +865,8 @@ func TestSceneQueryIsMissingStudio(t *testing.T) { } func TestSceneQueryIsMissingMovies(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter isMissing := "movie" sceneFilter := models.SceneFilterType{ IsMissing: &isMissing, @@ -875,12 +877,12 @@ func TestSceneQueryIsMissingMovies(t *testing.T) { Q: &q, } - scenes := queryScene(t, sqb, &sceneFilter, &findFilter) + scenes := queryScene(ctx, t, sqb, &sceneFilter, &findFilter) assert.Len(t, scenes, 0) findFilter.Q = nil - scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter) // ensure non of the ids equal the one with movies for _, scene := range scenes { @@ -892,8 +894,8 @@ func TestSceneQueryIsMissingMovies(t *testing.T) { } func TestSceneQueryIsMissingPerformers(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter isMissing := "performers" sceneFilter := models.SceneFilterType{ IsMissing: &isMissing, @@ -904,12 +906,12 @@ func TestSceneQueryIsMissingPerformers(t *testing.T) { Q: &q, } - scenes := queryScene(t, sqb, &sceneFilter, &findFilter) + scenes := queryScene(ctx, t, sqb, &sceneFilter, &findFilter) assert.Len(t, scenes, 0) findFilter.Q = nil - scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter) assert.True(t, len(scenes) > 0) @@ -923,14 +925,14 @@ func TestSceneQueryIsMissingPerformers(t *testing.T) { } func TestSceneQueryIsMissingDate(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter isMissing := "date" sceneFilter := models.SceneFilterType{ IsMissing: &isMissing, } - scenes := queryScene(t, sqb, &sceneFilter, nil) + scenes := queryScene(ctx, t, sqb, &sceneFilter, nil) // three in four scenes have no date assert.Len(t, scenes, int(math.Ceil(float64(totalScenes)/4*3))) @@ -945,8 +947,8 @@ func TestSceneQueryIsMissingDate(t *testing.T) { } func TestSceneQueryIsMissingTags(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter isMissing := "tags" sceneFilter := models.SceneFilterType{ IsMissing: &isMissing, @@ -957,12 +959,12 @@ func TestSceneQueryIsMissingTags(t *testing.T) { Q: &q, } - scenes := queryScene(t, sqb, &sceneFilter, &findFilter) + scenes := queryScene(ctx, t, sqb, &sceneFilter, &findFilter) assert.Len(t, scenes, 0) findFilter.Q = nil - scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter) assert.True(t, len(scenes) > 0) @@ -971,14 +973,14 @@ func TestSceneQueryIsMissingTags(t *testing.T) { } func TestSceneQueryIsMissingRating(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter isMissing := "rating" sceneFilter := models.SceneFilterType{ IsMissing: &isMissing, } - scenes := queryScene(t, sqb, &sceneFilter, nil) + scenes := queryScene(ctx, t, sqb, &sceneFilter, nil) assert.True(t, len(scenes) > 0) @@ -992,8 +994,8 @@ func TestSceneQueryIsMissingRating(t *testing.T) { } func TestSceneQueryPerformers(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter performerCriterion := models.MultiCriterionInput{ Value: []string{ strconv.Itoa(performerIDs[performerIdxWithScene]), @@ -1006,7 +1008,7 @@ func TestSceneQueryPerformers(t *testing.T) { Performers: &performerCriterion, } - scenes := queryScene(t, sqb, &sceneFilter, nil) + scenes := queryScene(ctx, t, sqb, &sceneFilter, nil) assert.Len(t, scenes, 2) @@ -1023,7 +1025,7 @@ func TestSceneQueryPerformers(t *testing.T) { Modifier: models.CriterionModifierIncludesAll, } - scenes = queryScene(t, sqb, &sceneFilter, nil) + scenes = queryScene(ctx, t, sqb, &sceneFilter, nil) assert.Len(t, scenes, 1) assert.Equal(t, sceneIDs[sceneIdxWithTwoPerformers], scenes[0].ID) @@ -1040,7 +1042,7 @@ func TestSceneQueryPerformers(t *testing.T) { Q: &q, } - scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter) assert.Len(t, scenes, 0) return nil @@ -1048,8 +1050,8 @@ func TestSceneQueryPerformers(t *testing.T) { } func TestSceneQueryTags(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithScene]), @@ -1062,7 +1064,7 @@ func TestSceneQueryTags(t *testing.T) { Tags: &tagCriterion, } - scenes := queryScene(t, sqb, &sceneFilter, nil) + scenes := queryScene(ctx, t, sqb, &sceneFilter, nil) assert.Len(t, scenes, 2) // ensure ids are correct @@ -1078,7 +1080,7 @@ func TestSceneQueryTags(t *testing.T) { Modifier: models.CriterionModifierIncludesAll, } - scenes = queryScene(t, sqb, &sceneFilter, nil) + scenes = queryScene(ctx, t, sqb, &sceneFilter, nil) assert.Len(t, scenes, 1) assert.Equal(t, sceneIDs[sceneIdxWithTwoTags], scenes[0].ID) @@ -1095,7 +1097,7 @@ func TestSceneQueryTags(t *testing.T) { Q: &q, } - scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter) assert.Len(t, scenes, 0) return nil @@ -1103,8 +1105,8 @@ func TestSceneQueryTags(t *testing.T) { } func TestSceneQueryPerformerTags(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithPerformer]), @@ -1117,7 +1119,7 @@ func TestSceneQueryPerformerTags(t *testing.T) { PerformerTags: &tagCriterion, } - scenes := queryScene(t, sqb, &sceneFilter, nil) + scenes := queryScene(ctx, t, sqb, &sceneFilter, nil) assert.Len(t, scenes, 2) // ensure ids are correct @@ -1133,7 +1135,7 @@ func TestSceneQueryPerformerTags(t *testing.T) { Modifier: models.CriterionModifierIncludesAll, } - scenes = queryScene(t, sqb, &sceneFilter, nil) + scenes = queryScene(ctx, t, sqb, &sceneFilter, nil) assert.Len(t, scenes, 1) assert.Equal(t, sceneIDs[sceneIdxWithPerformerTwoTags], scenes[0].ID) @@ -1150,7 +1152,7 @@ func TestSceneQueryPerformerTags(t *testing.T) { Q: &q, } - scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter) assert.Len(t, scenes, 0) tagCriterion = models.HierarchicalMultiCriterionInput{ @@ -1158,22 +1160,22 @@ func TestSceneQueryPerformerTags(t *testing.T) { } q = getSceneStringValue(sceneIdx1WithPerformer, titleField) - scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter) assert.Len(t, scenes, 1) assert.Equal(t, sceneIDs[sceneIdx1WithPerformer], scenes[0].ID) q = getSceneStringValue(sceneIdxWithPerformerTag, titleField) - scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter) assert.Len(t, scenes, 0) tagCriterion.Modifier = models.CriterionModifierNotNull - scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter) assert.Len(t, scenes, 1) assert.Equal(t, sceneIDs[sceneIdxWithPerformerTag], scenes[0].ID) q = getSceneStringValue(sceneIdx1WithPerformer, titleField) - scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter) assert.Len(t, scenes, 0) return nil @@ -1181,8 +1183,8 @@ func TestSceneQueryPerformerTags(t *testing.T) { } func TestSceneQueryStudio(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter studioCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(studioIDs[studioIdxWithScene]), @@ -1194,7 +1196,7 @@ func TestSceneQueryStudio(t *testing.T) { Studios: &studioCriterion, } - scenes := queryScene(t, sqb, &sceneFilter, nil) + scenes := queryScene(ctx, t, sqb, &sceneFilter, nil) assert.Len(t, scenes, 1) @@ -1213,7 +1215,7 @@ func TestSceneQueryStudio(t *testing.T) { Q: &q, } - scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter) assert.Len(t, scenes, 0) return nil @@ -1221,8 +1223,8 @@ func TestSceneQueryStudio(t *testing.T) { } func TestSceneQueryStudioDepth(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter depth := 2 studioCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ @@ -1236,16 +1238,16 @@ func TestSceneQueryStudioDepth(t *testing.T) { Studios: &studioCriterion, } - scenes := queryScene(t, sqb, &sceneFilter, nil) + scenes := queryScene(ctx, t, sqb, &sceneFilter, nil) assert.Len(t, scenes, 1) depth = 1 - scenes = queryScene(t, sqb, &sceneFilter, nil) + scenes = queryScene(ctx, t, sqb, &sceneFilter, nil) assert.Len(t, scenes, 0) studioCriterion.Value = []string{strconv.Itoa(studioIDs[studioIdxWithParentAndChild])} - scenes = queryScene(t, sqb, &sceneFilter, nil) + scenes = queryScene(ctx, t, sqb, &sceneFilter, nil) assert.Len(t, scenes, 1) // ensure id is correct @@ -1265,15 +1267,15 @@ func TestSceneQueryStudioDepth(t *testing.T) { Q: &q, } - scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter) assert.Len(t, scenes, 0) depth = 1 - scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter) assert.Len(t, scenes, 1) studioCriterion.Value = []string{strconv.Itoa(studioIDs[studioIdxWithParentAndChild])} - scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter) assert.Len(t, scenes, 0) return nil @@ -1281,8 +1283,8 @@ func TestSceneQueryStudioDepth(t *testing.T) { } func TestSceneQueryMovies(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter movieCriterion := models.MultiCriterionInput{ Value: []string{ strconv.Itoa(movieIDs[movieIdxWithScene]), @@ -1294,7 +1296,7 @@ func TestSceneQueryMovies(t *testing.T) { Movies: &movieCriterion, } - scenes := queryScene(t, sqb, &sceneFilter, nil) + scenes := queryScene(ctx, t, sqb, &sceneFilter, nil) assert.Len(t, scenes, 1) @@ -1313,7 +1315,7 @@ func TestSceneQueryMovies(t *testing.T) { Q: &q, } - scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter) assert.Len(t, scenes, 0) return nil @@ -1328,9 +1330,9 @@ func TestSceneQuerySorting(t *testing.T) { Direction: &direction, } - withTxn(func(r models.Repository) error { - sqb := r.Scene() - scenes := queryScene(t, sqb, nil, &findFilter) + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter + scenes := queryScene(ctx, t, sqb, nil, &findFilter) // scenes should be in same order as indexes firstScene := scenes[0] @@ -1342,7 +1344,7 @@ func TestSceneQuerySorting(t *testing.T) { // sort in descending order direction = models.SortDirectionEnumDesc - scenes = queryScene(t, sqb, nil, &findFilter) + scenes = queryScene(ctx, t, sqb, nil, &findFilter) firstScene = scenes[0] lastScene = scenes[len(scenes)-1] @@ -1359,9 +1361,9 @@ func TestSceneQueryPagination(t *testing.T) { PerPage: &perPage, } - withTxn(func(r models.Repository) error { - sqb := r.Scene() - scenes := queryScene(t, sqb, nil, &findFilter) + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter + scenes := queryScene(ctx, t, sqb, nil, &findFilter) assert.Len(t, scenes, 1) @@ -1369,7 +1371,7 @@ func TestSceneQueryPagination(t *testing.T) { page := 2 findFilter.Page = &page - scenes = queryScene(t, sqb, nil, &findFilter) + scenes = queryScene(ctx, t, sqb, nil, &findFilter) assert.Len(t, scenes, 1) secondID := scenes[0].ID @@ -1378,7 +1380,7 @@ func TestSceneQueryPagination(t *testing.T) { perPage = 2 page = 1 - scenes = queryScene(t, sqb, nil, &findFilter) + scenes = queryScene(ctx, t, sqb, nil, &findFilter) assert.Len(t, scenes, 2) assert.Equal(t, firstID, scenes[0].ID) assert.Equal(t, secondID, scenes[1].ID) @@ -1407,17 +1409,17 @@ func TestSceneQueryTagCount(t *testing.T) { } func verifyScenesTagCount(t *testing.T, tagCountCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter sceneFilter := models.SceneFilterType{ TagCount: &tagCountCriterion, } - scenes := queryScene(t, sqb, &sceneFilter, nil) + scenes := queryScene(ctx, t, sqb, &sceneFilter, nil) assert.Greater(t, len(scenes), 0) for _, scene := range scenes { - ids, err := sqb.GetTagIDs(scene.ID) + ids, err := sqb.GetTagIDs(ctx, scene.ID) if err != nil { return err } @@ -1448,17 +1450,17 @@ func TestSceneQueryPerformerCount(t *testing.T) { } func verifyScenesPerformerCount(t *testing.T, performerCountCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter sceneFilter := models.SceneFilterType{ PerformerCount: &performerCountCriterion, } - scenes := queryScene(t, sqb, &sceneFilter, nil) + scenes := queryScene(ctx, t, sqb, &sceneFilter, nil) assert.Greater(t, len(scenes), 0) for _, scene := range scenes { - ids, err := sqb.GetPerformerIDs(scene.ID) + ids, err := sqb.GetPerformerIDs(ctx, scene.ID) if err != nil { return err } @@ -1470,10 +1472,10 @@ func verifyScenesPerformerCount(t *testing.T, performerCountCriterion models.Int } func TestSceneCountByTagID(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter - sceneCount, err := sqb.CountByTagID(tagIDs[tagIdxWithScene]) + sceneCount, err := sqb.CountByTagID(ctx, tagIDs[tagIdxWithScene]) if err != nil { t.Errorf("error calling CountByTagID: %s", err.Error()) @@ -1481,7 +1483,7 @@ func TestSceneCountByTagID(t *testing.T) { assert.Equal(t, 1, sceneCount) - sceneCount, err = sqb.CountByTagID(0) + sceneCount, err = sqb.CountByTagID(ctx, 0) if err != nil { t.Errorf("error calling CountByTagID: %s", err.Error()) @@ -1494,10 +1496,10 @@ func TestSceneCountByTagID(t *testing.T) { } func TestSceneCountByMovieID(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter - sceneCount, err := sqb.CountByMovieID(movieIDs[movieIdxWithScene]) + sceneCount, err := sqb.CountByMovieID(ctx, movieIDs[movieIdxWithScene]) if err != nil { t.Errorf("error calling CountByMovieID: %s", err.Error()) @@ -1505,7 +1507,7 @@ func TestSceneCountByMovieID(t *testing.T) { assert.Equal(t, 1, sceneCount) - sceneCount, err = sqb.CountByMovieID(0) + sceneCount, err = sqb.CountByMovieID(ctx, 0) if err != nil { t.Errorf("error calling CountByMovieID: %s", err.Error()) @@ -1518,10 +1520,10 @@ func TestSceneCountByMovieID(t *testing.T) { } func TestSceneCountByStudioID(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter - sceneCount, err := sqb.CountByStudioID(studioIDs[studioIdxWithScene]) + sceneCount, err := sqb.CountByStudioID(ctx, studioIDs[studioIdxWithScene]) if err != nil { t.Errorf("error calling CountByStudioID: %s", err.Error()) @@ -1529,7 +1531,7 @@ func TestSceneCountByStudioID(t *testing.T) { assert.Equal(t, 1, sceneCount) - sceneCount, err = sqb.CountByStudioID(0) + sceneCount, err = sqb.CountByStudioID(ctx, 0) if err != nil { t.Errorf("error calling CountByStudioID: %s", err.Error()) @@ -1542,10 +1544,10 @@ func TestSceneCountByStudioID(t *testing.T) { } func TestFindByMovieID(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter - scenes, err := sqb.FindByMovieID(movieIDs[movieIdxWithScene]) + scenes, err := sqb.FindByMovieID(ctx, movieIDs[movieIdxWithScene]) if err != nil { t.Errorf("error calling FindByMovieID: %s", err.Error()) @@ -1554,7 +1556,7 @@ func TestFindByMovieID(t *testing.T) { assert.Len(t, scenes, 1) assert.Equal(t, sceneIDs[sceneIdxWithMovie], scenes[0].ID) - scenes, err = sqb.FindByMovieID(0) + scenes, err = sqb.FindByMovieID(ctx, 0) if err != nil { t.Errorf("error calling FindByMovieID: %s", err.Error()) @@ -1567,10 +1569,10 @@ func TestFindByMovieID(t *testing.T) { } func TestFindByPerformerID(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Scene() + withTxn(func(ctx context.Context) error { + sqb := sqlite.SceneReaderWriter - scenes, err := sqb.FindByPerformerID(performerIDs[performerIdxWithScene]) + scenes, err := sqb.FindByPerformerID(ctx, performerIDs[performerIdxWithScene]) if err != nil { t.Errorf("error calling FindByPerformerID: %s", err.Error()) @@ -1579,7 +1581,7 @@ func TestFindByPerformerID(t *testing.T) { assert.Len(t, scenes, 1) assert.Equal(t, sceneIDs[sceneIdxWithPerformer], scenes[0].ID) - scenes, err = sqb.FindByPerformerID(0) + scenes, err = sqb.FindByPerformerID(ctx, 0) if err != nil { t.Errorf("error calling FindByPerformerID: %s", err.Error()) @@ -1592,8 +1594,8 @@ func TestFindByPerformerID(t *testing.T) { } func TestSceneUpdateSceneCover(t *testing.T) { - if err := withTxn(func(r models.Repository) error { - qb := r.Scene() + if err := withTxn(func(ctx context.Context) error { + qb := sqlite.SceneReaderWriter // create performer to test against const name = "TestSceneUpdateSceneCover" @@ -1601,26 +1603,26 @@ func TestSceneUpdateSceneCover(t *testing.T) { Path: name, Checksum: sql.NullString{String: md5.FromString(name), Valid: true}, } - created, err := qb.Create(scene) + created, err := qb.Create(ctx, scene) if err != nil { return fmt.Errorf("Error creating scene: %s", err.Error()) } image := []byte("image") - err = qb.UpdateCover(created.ID, image) + err = qb.UpdateCover(ctx, created.ID, image) if err != nil { return fmt.Errorf("Error updating scene cover: %s", err.Error()) } // ensure image set - storedImage, err := qb.GetCover(created.ID) + storedImage, err := qb.GetCover(ctx, created.ID) if err != nil { return fmt.Errorf("Error getting image: %s", err.Error()) } assert.Equal(t, storedImage, image) // set nil image - err = qb.UpdateCover(created.ID, nil) + err = qb.UpdateCover(ctx, created.ID, nil) if err == nil { return fmt.Errorf("Expected error setting nil image") } @@ -1632,8 +1634,8 @@ func TestSceneUpdateSceneCover(t *testing.T) { } func TestSceneDestroySceneCover(t *testing.T) { - if err := withTxn(func(r models.Repository) error { - qb := r.Scene() + if err := withTxn(func(ctx context.Context) error { + qb := sqlite.SceneReaderWriter // create performer to test against const name = "TestSceneDestroySceneCover" @@ -1641,24 +1643,24 @@ func TestSceneDestroySceneCover(t *testing.T) { Path: name, Checksum: sql.NullString{String: md5.FromString(name), Valid: true}, } - created, err := qb.Create(scene) + created, err := qb.Create(ctx, scene) if err != nil { return fmt.Errorf("Error creating scene: %s", err.Error()) } image := []byte("image") - err = qb.UpdateCover(created.ID, image) + err = qb.UpdateCover(ctx, created.ID, image) if err != nil { return fmt.Errorf("Error updating scene image: %s", err.Error()) } - err = qb.DestroyCover(created.ID) + err = qb.DestroyCover(ctx, created.ID) if err != nil { return fmt.Errorf("Error destroying scene cover: %s", err.Error()) } // image should be nil - storedImage, err := qb.GetCover(created.ID) + storedImage, err := qb.GetCover(ctx, created.ID) if err != nil { return fmt.Errorf("Error getting image: %s", err.Error()) } @@ -1671,8 +1673,8 @@ func TestSceneDestroySceneCover(t *testing.T) { } func TestSceneStashIDs(t *testing.T) { - if err := withTxn(func(r models.Repository) error { - qb := r.Scene() + if err := withTxn(func(ctx context.Context) error { + qb := sqlite.SceneReaderWriter // create scene to test against const name = "TestSceneStashIDs" @@ -1680,12 +1682,12 @@ func TestSceneStashIDs(t *testing.T) { Path: name, Checksum: sql.NullString{String: md5.FromString(name), Valid: true}, } - created, err := qb.Create(scene) + created, err := qb.Create(ctx, scene) if err != nil { return fmt.Errorf("Error creating scene: %s", err.Error()) } - testStashIDReaderWriter(t, qb, created.ID) + testStashIDReaderWriter(ctx, t, qb, created.ID) return nil }); err != nil { t.Error(err.Error()) @@ -1693,8 +1695,8 @@ func TestSceneStashIDs(t *testing.T) { } func TestSceneQueryQTrim(t *testing.T) { - if err := withTxn(func(r models.Repository) error { - qb := r.Scene() + if err := withTxn(func(ctx context.Context) error { + qb := sqlite.SceneReaderWriter expectedID := sceneIDs[sceneIdxWithSpacedName] @@ -1717,7 +1719,7 @@ func TestSceneQueryQTrim(t *testing.T) { f := models.FindFilterType{ Q: &tst.query, } - scenes := queryScene(t, qb, nil, &f) + scenes := queryScene(ctx, t, qb, nil, &f) assert.Len(t, scenes, tst.count) if len(scenes) > 0 { @@ -1726,7 +1728,7 @@ func TestSceneQueryQTrim(t *testing.T) { } findFilter := models.FindFilterType{} - scenes := queryScene(t, qb, nil, &findFilter) + scenes := queryScene(ctx, t, qb, nil, &findFilter) assert.NotEqual(t, 0, len(scenes)) return nil diff --git a/pkg/sqlite/scraped_item.go b/pkg/sqlite/scraped_item.go index 1eafc98a5..1b8216dab 100644 --- a/pkg/sqlite/scraped_item.go +++ b/pkg/sqlite/scraped_item.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "database/sql" "errors" @@ -13,41 +14,38 @@ type scrapedItemQueryBuilder struct { repository } -func NewScrapedItemReaderWriter(tx dbi) *scrapedItemQueryBuilder { - return &scrapedItemQueryBuilder{ - repository{ - tx: tx, - tableName: scrapedItemTable, - idColumn: idColumn, - }, - } +var ScrapedItemReaderWriter = &scrapedItemQueryBuilder{ + repository{ + tableName: scrapedItemTable, + idColumn: idColumn, + }, } -func (qb *scrapedItemQueryBuilder) Create(newObject models.ScrapedItem) (*models.ScrapedItem, error) { +func (qb *scrapedItemQueryBuilder) Create(ctx context.Context, newObject models.ScrapedItem) (*models.ScrapedItem, error) { var ret models.ScrapedItem - if err := qb.insertObject(newObject, &ret); err != nil { + if err := qb.insertObject(ctx, newObject, &ret); err != nil { return nil, err } return &ret, nil } -func (qb *scrapedItemQueryBuilder) Update(updatedObject models.ScrapedItem) (*models.ScrapedItem, error) { +func (qb *scrapedItemQueryBuilder) Update(ctx context.Context, updatedObject models.ScrapedItem) (*models.ScrapedItem, error) { const partial = false - if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil { return nil, err } - return qb.find(updatedObject.ID) + return qb.find(ctx, updatedObject.ID) } -func (qb *scrapedItemQueryBuilder) Find(id int) (*models.ScrapedItem, error) { - return qb.find(id) +func (qb *scrapedItemQueryBuilder) Find(ctx context.Context, id int) (*models.ScrapedItem, error) { + return qb.find(ctx, id) } -func (qb *scrapedItemQueryBuilder) find(id int) (*models.ScrapedItem, error) { +func (qb *scrapedItemQueryBuilder) find(ctx context.Context, id int) (*models.ScrapedItem, error) { var ret models.ScrapedItem - if err := qb.get(id, &ret); err != nil { + if err := qb.getByID(ctx, id, &ret); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -56,8 +54,8 @@ func (qb *scrapedItemQueryBuilder) find(id int) (*models.ScrapedItem, error) { return &ret, nil } -func (qb *scrapedItemQueryBuilder) All() ([]*models.ScrapedItem, error) { - return qb.queryScrapedItems(selectAll("scraped_items")+qb.getScrapedItemsSort(nil), nil) +func (qb *scrapedItemQueryBuilder) All(ctx context.Context) ([]*models.ScrapedItem, error) { + return qb.queryScrapedItems(ctx, selectAll("scraped_items")+qb.getScrapedItemsSort(nil), nil) } func (qb *scrapedItemQueryBuilder) getScrapedItemsSort(findFilter *models.FindFilterType) string { @@ -73,9 +71,9 @@ func (qb *scrapedItemQueryBuilder) getScrapedItemsSort(findFilter *models.FindFi return getSort(sort, direction, "scraped_items") } -func (qb *scrapedItemQueryBuilder) queryScrapedItems(query string, args []interface{}) ([]*models.ScrapedItem, error) { +func (qb *scrapedItemQueryBuilder) queryScrapedItems(ctx context.Context, query string, args []interface{}) ([]*models.ScrapedItem, error) { var ret models.ScrapedItems - if err := qb.query(query, args, &ret); err != nil { + if err := qb.query(ctx, query, args, &ret); err != nil { return nil, err } diff --git a/pkg/sqlite/setup_test.go b/pkg/sqlite/setup_test.go index e07d8aebe..5cb123118 100644 --- a/pkg/sqlite/setup_test.go +++ b/pkg/sqlite/setup_test.go @@ -13,13 +13,13 @@ import ( "testing" "time" - "github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/gallery" "github.com/stashapp/stash/pkg/hash/md5" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/scene" "github.com/stashapp/stash/pkg/sliceutil/intslice" "github.com/stashapp/stash/pkg/sqlite" + "github.com/stashapp/stash/pkg/txn" ) const ( @@ -386,20 +386,21 @@ var ( } ) +var db *sqlite.Database + func TestMain(m *testing.M) { ret := runTests(m) os.Exit(ret) } -func withTxn(f func(r models.Repository) error) error { - t := sqlite.NewTransactionManager() - return t.WithTxn(context.TODO(), f) +func withTxn(f func(ctx context.Context) error) error { + return txn.WithTxn(context.Background(), db, f) } -func withRollbackTxn(f func(r models.Repository) error) error { +func withRollbackTxn(f func(ctx context.Context) error) error { var ret error - withTxn(func(repo models.Repository) error { - ret = f(repo) + withTxn(func(ctx context.Context) error { + ret = f(ctx) return errors.New("fake error for rollback") }) @@ -407,7 +408,7 @@ func withRollbackTxn(f func(r models.Repository) error) error { } func testTeardown(databaseFile string) { - err := database.DB.Close() + err := db.Close() if err != nil { panic(err) @@ -428,7 +429,9 @@ func runTests(m *testing.M) int { f.Close() databaseFile := f.Name() - if err := database.Initialize(databaseFile); err != nil { + db = &sqlite.Database{} + + if err := db.Open(databaseFile); err != nil { panic(fmt.Sprintf("Could not initialize database: %s", err.Error())) } @@ -445,109 +448,109 @@ func runTests(m *testing.M) int { } func populateDB() error { - if err := withTxn(func(r models.Repository) error { - if err := createScenes(r.Scene(), totalScenes); err != nil { + if err := withTxn(func(ctx context.Context) error { + if err := createScenes(ctx, sqlite.SceneReaderWriter, totalScenes); err != nil { return fmt.Errorf("error creating scenes: %s", err.Error()) } - if err := createImages(r.Image(), totalImages); err != nil { + if err := createImages(ctx, sqlite.ImageReaderWriter, totalImages); err != nil { return fmt.Errorf("error creating images: %s", err.Error()) } - if err := createGalleries(r.Gallery(), totalGalleries); err != nil { + if err := createGalleries(ctx, sqlite.GalleryReaderWriter, totalGalleries); err != nil { return fmt.Errorf("error creating galleries: %s", err.Error()) } - if err := createMovies(r.Movie(), moviesNameCase, moviesNameNoCase); err != nil { + if err := createMovies(ctx, sqlite.MovieReaderWriter, moviesNameCase, moviesNameNoCase); err != nil { return fmt.Errorf("error creating movies: %s", err.Error()) } - if err := createPerformers(r.Performer(), performersNameCase, performersNameNoCase); err != nil { + if err := createPerformers(ctx, sqlite.PerformerReaderWriter, performersNameCase, performersNameNoCase); err != nil { return fmt.Errorf("error creating performers: %s", err.Error()) } - if err := createTags(r.Tag(), tagsNameCase, tagsNameNoCase); err != nil { + if err := createTags(ctx, sqlite.TagReaderWriter, tagsNameCase, tagsNameNoCase); err != nil { return fmt.Errorf("error creating tags: %s", err.Error()) } - if err := addTagImage(r.Tag(), tagIdxWithCoverImage); err != nil { + if err := addTagImage(ctx, sqlite.TagReaderWriter, tagIdxWithCoverImage); err != nil { return fmt.Errorf("error adding tag image: %s", err.Error()) } - if err := createStudios(r.Studio(), studiosNameCase, studiosNameNoCase); err != nil { + if err := createStudios(ctx, sqlite.StudioReaderWriter, studiosNameCase, studiosNameNoCase); err != nil { return fmt.Errorf("error creating studios: %s", err.Error()) } - if err := createSavedFilters(r.SavedFilter(), totalSavedFilters); err != nil { + if err := createSavedFilters(ctx, sqlite.SavedFilterReaderWriter, totalSavedFilters); err != nil { return fmt.Errorf("error creating saved filters: %s", err.Error()) } - if err := linkPerformerTags(r.Performer()); err != nil { + if err := linkPerformerTags(ctx, sqlite.PerformerReaderWriter); err != nil { return fmt.Errorf("error linking performer tags: %s", err.Error()) } - if err := linkSceneGalleries(r.Scene()); err != nil { + if err := linkSceneGalleries(ctx, sqlite.SceneReaderWriter); err != nil { return fmt.Errorf("error linking scenes to galleries: %s", err.Error()) } - if err := linkSceneMovies(r.Scene()); err != nil { + if err := linkSceneMovies(ctx, sqlite.SceneReaderWriter); err != nil { return fmt.Errorf("error linking scenes to movies: %s", err.Error()) } - if err := linkScenePerformers(r.Scene()); err != nil { + if err := linkScenePerformers(ctx, sqlite.SceneReaderWriter); err != nil { return fmt.Errorf("error linking scene performers: %s", err.Error()) } - if err := linkSceneTags(r.Scene()); err != nil { + if err := linkSceneTags(ctx, sqlite.SceneReaderWriter); err != nil { return fmt.Errorf("error linking scene tags: %s", err.Error()) } - if err := linkSceneStudios(r.Scene()); err != nil { + if err := linkSceneStudios(ctx, sqlite.SceneReaderWriter); err != nil { return fmt.Errorf("error linking scene studios: %s", err.Error()) } - if err := linkImageGalleries(r.Gallery()); err != nil { + if err := linkImageGalleries(ctx, sqlite.GalleryReaderWriter); err != nil { return fmt.Errorf("error linking gallery images: %s", err.Error()) } - if err := linkImagePerformers(r.Image()); err != nil { + if err := linkImagePerformers(ctx, sqlite.ImageReaderWriter); err != nil { return fmt.Errorf("error linking image performers: %s", err.Error()) } - if err := linkImageTags(r.Image()); err != nil { + if err := linkImageTags(ctx, sqlite.ImageReaderWriter); err != nil { return fmt.Errorf("error linking image tags: %s", err.Error()) } - if err := linkImageStudios(r.Image()); err != nil { + if err := linkImageStudios(ctx, sqlite.ImageReaderWriter); err != nil { return fmt.Errorf("error linking image studio: %s", err.Error()) } - if err := linkMovieStudios(r.Movie()); err != nil { + if err := linkMovieStudios(ctx, sqlite.MovieReaderWriter); err != nil { return fmt.Errorf("error linking movie studios: %s", err.Error()) } - if err := linkStudiosParent(r.Studio()); err != nil { + if err := linkStudiosParent(ctx, sqlite.StudioReaderWriter); err != nil { return fmt.Errorf("error linking studios parent: %s", err.Error()) } - if err := linkGalleryPerformers(r.Gallery()); err != nil { + if err := linkGalleryPerformers(ctx, sqlite.GalleryReaderWriter); err != nil { return fmt.Errorf("error linking gallery performers: %s", err.Error()) } - if err := linkGalleryTags(r.Gallery()); err != nil { + if err := linkGalleryTags(ctx, sqlite.GalleryReaderWriter); err != nil { return fmt.Errorf("error linking gallery tags: %s", err.Error()) } - if err := linkGalleryStudios(r.Gallery()); err != nil { + if err := linkGalleryStudios(ctx, sqlite.GalleryReaderWriter); err != nil { return fmt.Errorf("error linking gallery studios: %s", err.Error()) } - if err := linkTagsParent(r.Tag()); err != nil { + if err := linkTagsParent(ctx, sqlite.TagReaderWriter); err != nil { return fmt.Errorf("error linking tags parent: %s", err.Error()) } for _, ms := range markerSpecs { - if err := createMarker(r.SceneMarker(), ms); err != nil { + if err := createMarker(ctx, sqlite.SceneMarkerReaderWriter, ms); err != nil { return fmt.Errorf("error creating scene marker: %s", err.Error()) } } @@ -642,7 +645,7 @@ func getObjectDate(index int) models.SQLiteDate { } } -func createScenes(sqb models.SceneReaderWriter, n int) error { +func createScenes(ctx context.Context, sqb models.SceneReaderWriter, n int) error { for i := 0; i < n; i++ { scene := models.Scene{ Path: getSceneStringValue(i, pathField), @@ -658,7 +661,7 @@ func createScenes(sqb models.SceneReaderWriter, n int) error { Date: getObjectDate(i), } - created, err := sqb.Create(scene) + created, err := sqb.Create(ctx, scene) if err != nil { return fmt.Errorf("Error creating scene %v+: %s", scene, err.Error()) @@ -683,7 +686,7 @@ func getImagePath(index int) string { return getImageStringValue(index, pathField) } -func createImages(qb models.ImageReaderWriter, n int) error { +func createImages(ctx context.Context, qb models.ImageReaderWriter, n int) error { for i := 0; i < n; i++ { image := models.Image{ Path: getImagePath(i), @@ -695,7 +698,7 @@ func createImages(qb models.ImageReaderWriter, n int) error { Width: getWidth(i), } - created, err := qb.Create(image) + created, err := qb.Create(ctx, image) if err != nil { return fmt.Errorf("Error creating image %v+: %s", image, err.Error()) @@ -715,7 +718,7 @@ func getGalleryNullStringValue(index int, field string) sql.NullString { return getPrefixedNullStringValue("gallery", index, field) } -func createGalleries(gqb models.GalleryReaderWriter, n int) error { +func createGalleries(ctx context.Context, gqb models.GalleryReaderWriter, n int) error { for i := 0; i < n; i++ { gallery := models.Gallery{ Path: models.NullString(getGalleryStringValue(i, pathField)), @@ -726,7 +729,7 @@ func createGalleries(gqb models.GalleryReaderWriter, n int) error { Date: getObjectDate(i), } - created, err := gqb.Create(gallery) + created, err := gqb.Create(ctx, gallery) if err != nil { return fmt.Errorf("Error creating gallery %v+: %s", gallery, err.Error()) @@ -747,7 +750,7 @@ func getMovieNullStringValue(index int, field string) sql.NullString { } // createMoviees creates n movies with plain Name and o movies with camel cased NaMe included -func createMovies(mqb models.MovieReaderWriter, n int, o int) error { +func createMovies(ctx context.Context, mqb models.MovieReaderWriter, n int, o int) error { const namePlain = "Name" const nameNoCase = "NaMe" @@ -768,7 +771,7 @@ func createMovies(mqb models.MovieReaderWriter, n int, o int) error { Checksum: md5.FromString(name), } - created, err := mqb.Create(movie) + created, err := mqb.Create(ctx, movie) if err != nil { return fmt.Errorf("Error creating movie [%d] %v+: %s", i, movie, err.Error()) @@ -828,7 +831,7 @@ func getIgnoreAutoTag(index int) bool { } // createPerformers creates n performers with plain Name and o performers with camel cased NaMe included -func createPerformers(pqb models.PerformerReaderWriter, n int, o int) error { +func createPerformers(ctx context.Context, pqb models.PerformerReaderWriter, n int, o int) error { const namePlain = "Name" const nameNoCase = "NaMe" @@ -864,7 +867,7 @@ func createPerformers(pqb models.PerformerReaderWriter, n int, o int) error { performer.CareerLength = models.NullString(*careerLength) } - created, err := pqb.Create(performer) + created, err := pqb.Create(ctx, performer) if err != nil { return fmt.Errorf("Error creating performer %v+: %s", performer, err.Error()) @@ -942,7 +945,7 @@ func getTagChildCount(id int) int { } //createTags creates n tags with plain Name and o tags with camel cased NaMe included -func createTags(tqb models.TagReaderWriter, n int, o int) error { +func createTags(ctx context.Context, tqb models.TagReaderWriter, n int, o int) error { const namePlain = "Name" const nameNoCase = "NaMe" @@ -962,7 +965,7 @@ func createTags(tqb models.TagReaderWriter, n int, o int) error { IgnoreAutoTag: getIgnoreAutoTag(i), } - created, err := tqb.Create(tag) + created, err := tqb.Create(ctx, tag) if err != nil { return fmt.Errorf("Error creating tag %v+: %s", tag, err.Error()) @@ -970,7 +973,7 @@ func createTags(tqb models.TagReaderWriter, n int, o int) error { // add alias alias := getTagStringValue(i, "Alias") - if err := tqb.UpdateAliases(created.ID, []string{alias}); err != nil { + if err := tqb.UpdateAliases(ctx, created.ID, []string{alias}); err != nil { return fmt.Errorf("error setting tag alias: %s", err.Error()) } @@ -989,7 +992,7 @@ func getStudioNullStringValue(index int, field string) sql.NullString { return getPrefixedNullStringValue("studio", index, field) } -func createStudio(sqb models.StudioReaderWriter, name string, parentID *int64) (*models.Studio, error) { +func createStudio(ctx context.Context, sqb models.StudioReaderWriter, name string, parentID *int64) (*models.Studio, error) { studio := models.Studio{ Name: sql.NullString{String: name, Valid: true}, Checksum: md5.FromString(name), @@ -999,11 +1002,11 @@ func createStudio(sqb models.StudioReaderWriter, name string, parentID *int64) ( studio.ParentID = sql.NullInt64{Int64: *parentID, Valid: true} } - return createStudioFromModel(sqb, studio) + return createStudioFromModel(ctx, sqb, studio) } -func createStudioFromModel(sqb models.StudioReaderWriter, studio models.Studio) (*models.Studio, error) { - created, err := sqb.Create(studio) +func createStudioFromModel(ctx context.Context, sqb models.StudioReaderWriter, studio models.Studio) (*models.Studio, error) { + created, err := sqb.Create(ctx, studio) if err != nil { return nil, fmt.Errorf("Error creating studio %v+: %s", studio, err.Error()) @@ -1013,7 +1016,7 @@ func createStudioFromModel(sqb models.StudioReaderWriter, studio models.Studio) } // createStudios creates n studios with plain Name and o studios with camel cased NaMe included -func createStudios(sqb models.StudioReaderWriter, n int, o int) error { +func createStudios(ctx context.Context, sqb models.StudioReaderWriter, n int, o int) error { const namePlain = "Name" const nameNoCase = "NaMe" @@ -1034,7 +1037,7 @@ func createStudios(sqb models.StudioReaderWriter, n int, o int) error { URL: getStudioNullStringValue(index, urlField), IgnoreAutoTag: getIgnoreAutoTag(i), } - created, err := createStudioFromModel(sqb, studio) + created, err := createStudioFromModel(ctx, sqb, studio) if err != nil { return err @@ -1042,7 +1045,7 @@ func createStudios(sqb models.StudioReaderWriter, n int, o int) error { // add alias alias := getStudioStringValue(i, "Alias") - if err := sqb.UpdateAliases(created.ID, []string{alias}); err != nil { + if err := sqb.UpdateAliases(ctx, created.ID, []string{alias}); err != nil { return fmt.Errorf("error setting studio alias: %s", err.Error()) } @@ -1053,13 +1056,13 @@ func createStudios(sqb models.StudioReaderWriter, n int, o int) error { return nil } -func createMarker(mqb models.SceneMarkerReaderWriter, markerSpec markerSpec) error { +func createMarker(ctx context.Context, mqb models.SceneMarkerReaderWriter, markerSpec markerSpec) error { marker := models.SceneMarker{ SceneID: sql.NullInt64{Int64: int64(sceneIDs[markerSpec.sceneIdx]), Valid: true}, PrimaryTagID: tagIDs[markerSpec.primaryTagIdx], } - created, err := mqb.Create(marker) + created, err := mqb.Create(ctx, marker) if err != nil { return fmt.Errorf("error creating marker %v+: %w", marker, err) @@ -1074,7 +1077,7 @@ func createMarker(mqb models.SceneMarkerReaderWriter, markerSpec markerSpec) err newTagIDs = append(newTagIDs, tagIDs[tagIdx]) } - if err := mqb.UpdateTags(created.ID, newTagIDs); err != nil { + if err := mqb.UpdateTags(ctx, created.ID, newTagIDs); err != nil { return fmt.Errorf("error creating marker/tag join: %w", err) } } @@ -1107,7 +1110,7 @@ func getSavedFilterName(index int) string { return getPrefixedStringValue("savedFilter", index, "Name") } -func createSavedFilters(qb models.SavedFilterReaderWriter, n int) error { +func createSavedFilters(ctx context.Context, qb models.SavedFilterReaderWriter, n int) error { for i := 0; i < n; i++ { savedFilter := models.SavedFilter{ Mode: getSavedFilterMode(i), @@ -1115,7 +1118,7 @@ func createSavedFilters(qb models.SavedFilterReaderWriter, n int) error { Filter: getPrefixedStringValue("savedFilter", i, "Filter"), } - created, err := qb.Create(savedFilter) + created, err := qb.Create(ctx, savedFilter) if err != nil { return fmt.Errorf("Error creating saved filter %v+: %s", savedFilter, err.Error()) @@ -1137,25 +1140,25 @@ func doLinks(links [][2]int, fn func(idx1, idx2 int) error) error { return nil } -func linkPerformerTags(qb models.PerformerReaderWriter) error { +func linkPerformerTags(ctx context.Context, qb models.PerformerReaderWriter) error { return doLinks(performerTagLinks, func(performerIndex, tagIndex int) error { performerID := performerIDs[performerIndex] tagID := tagIDs[tagIndex] - tagIDs, err := qb.GetTagIDs(performerID) + tagIDs, err := qb.GetTagIDs(ctx, performerID) if err != nil { return err } tagIDs = intslice.IntAppendUnique(tagIDs, tagID) - return qb.UpdateTags(performerID, tagIDs) + return qb.UpdateTags(ctx, performerID, tagIDs) }) } -func linkSceneMovies(qb models.SceneReaderWriter) error { +func linkSceneMovies(ctx context.Context, qb models.SceneReaderWriter) error { return doLinks(sceneMovieLinks, func(sceneIndex, movieIndex int) error { sceneID := sceneIDs[sceneIndex] - movies, err := qb.GetMovies(sceneID) + movies, err := qb.GetMovies(ctx, sceneID) if err != nil { return err } @@ -1164,157 +1167,157 @@ func linkSceneMovies(qb models.SceneReaderWriter) error { MovieID: movieIDs[movieIndex], SceneID: sceneID, }) - return qb.UpdateMovies(sceneID, movies) + return qb.UpdateMovies(ctx, sceneID, movies) }) } -func linkScenePerformers(qb models.SceneReaderWriter) error { +func linkScenePerformers(ctx context.Context, qb models.SceneReaderWriter) error { return doLinks(scenePerformerLinks, func(sceneIndex, performerIndex int) error { - _, err := scene.AddPerformer(qb, sceneIDs[sceneIndex], performerIDs[performerIndex]) + _, err := scene.AddPerformer(ctx, qb, sceneIDs[sceneIndex], performerIDs[performerIndex]) return err }) } -func linkSceneGalleries(qb models.SceneReaderWriter) error { +func linkSceneGalleries(ctx context.Context, qb models.SceneReaderWriter) error { return doLinks(sceneGalleryLinks, func(sceneIndex, galleryIndex int) error { - _, err := scene.AddGallery(qb, sceneIDs[sceneIndex], galleryIDs[galleryIndex]) + _, err := scene.AddGallery(ctx, qb, sceneIDs[sceneIndex], galleryIDs[galleryIndex]) return err }) } -func linkSceneTags(qb models.SceneReaderWriter) error { +func linkSceneTags(ctx context.Context, qb models.SceneReaderWriter) error { return doLinks(sceneTagLinks, func(sceneIndex, tagIndex int) error { - _, err := scene.AddTag(qb, sceneIDs[sceneIndex], tagIDs[tagIndex]) + _, err := scene.AddTag(ctx, qb, sceneIDs[sceneIndex], tagIDs[tagIndex]) return err }) } -func linkSceneStudios(sqb models.SceneWriter) error { +func linkSceneStudios(ctx context.Context, sqb models.SceneWriter) error { return doLinks(sceneStudioLinks, func(sceneIndex, studioIndex int) error { scene := models.ScenePartial{ ID: sceneIDs[sceneIndex], StudioID: &sql.NullInt64{Int64: int64(studioIDs[studioIndex]), Valid: true}, } - _, err := sqb.Update(scene) + _, err := sqb.Update(ctx, scene) return err }) } -func linkImageGalleries(gqb models.GalleryReaderWriter) error { +func linkImageGalleries(ctx context.Context, gqb models.GalleryReaderWriter) error { return doLinks(imageGalleryLinks, func(imageIndex, galleryIndex int) error { - return gallery.AddImage(gqb, galleryIDs[galleryIndex], imageIDs[imageIndex]) + return gallery.AddImage(ctx, gqb, galleryIDs[galleryIndex], imageIDs[imageIndex]) }) } -func linkImageTags(iqb models.ImageReaderWriter) error { +func linkImageTags(ctx context.Context, iqb models.ImageReaderWriter) error { return doLinks(imageTagLinks, func(imageIndex, tagIndex int) error { imageID := imageIDs[imageIndex] - tags, err := iqb.GetTagIDs(imageID) + tags, err := iqb.GetTagIDs(ctx, imageID) if err != nil { return err } tags = append(tags, tagIDs[tagIndex]) - return iqb.UpdateTags(imageID, tags) + return iqb.UpdateTags(ctx, imageID, tags) }) } -func linkImageStudios(qb models.ImageWriter) error { +func linkImageStudios(ctx context.Context, qb models.ImageWriter) error { return doLinks(imageStudioLinks, func(imageIndex, studioIndex int) error { image := models.ImagePartial{ ID: imageIDs[imageIndex], StudioID: &sql.NullInt64{Int64: int64(studioIDs[studioIndex]), Valid: true}, } - _, err := qb.Update(image) + _, err := qb.Update(ctx, image) return err }) } -func linkImagePerformers(qb models.ImageReaderWriter) error { +func linkImagePerformers(ctx context.Context, qb models.ImageReaderWriter) error { return doLinks(imagePerformerLinks, func(imageIndex, performerIndex int) error { imageID := imageIDs[imageIndex] - performers, err := qb.GetPerformerIDs(imageID) + performers, err := qb.GetPerformerIDs(ctx, imageID) if err != nil { return err } performers = append(performers, performerIDs[performerIndex]) - return qb.UpdatePerformers(imageID, performers) + return qb.UpdatePerformers(ctx, imageID, performers) }) } -func linkGalleryPerformers(qb models.GalleryReaderWriter) error { +func linkGalleryPerformers(ctx context.Context, qb models.GalleryReaderWriter) error { return doLinks(galleryPerformerLinks, func(galleryIndex, performerIndex int) error { galleryID := galleryIDs[galleryIndex] - performers, err := qb.GetPerformerIDs(galleryID) + performers, err := qb.GetPerformerIDs(ctx, galleryID) if err != nil { return err } performers = append(performers, performerIDs[performerIndex]) - return qb.UpdatePerformers(galleryID, performers) + return qb.UpdatePerformers(ctx, galleryID, performers) }) } -func linkGalleryStudios(qb models.GalleryReaderWriter) error { +func linkGalleryStudios(ctx context.Context, qb models.GalleryReaderWriter) error { return doLinks(galleryStudioLinks, func(galleryIndex, studioIndex int) error { gallery := models.GalleryPartial{ ID: galleryIDs[galleryIndex], StudioID: &sql.NullInt64{Int64: int64(studioIDs[studioIndex]), Valid: true}, } - _, err := qb.UpdatePartial(gallery) + _, err := qb.UpdatePartial(ctx, gallery) return err }) } -func linkGalleryTags(qb models.GalleryReaderWriter) error { +func linkGalleryTags(ctx context.Context, qb models.GalleryReaderWriter) error { return doLinks(galleryTagLinks, func(galleryIndex, tagIndex int) error { galleryID := galleryIDs[galleryIndex] - tags, err := qb.GetTagIDs(galleryID) + tags, err := qb.GetTagIDs(ctx, galleryID) if err != nil { return err } tags = append(tags, tagIDs[tagIndex]) - return qb.UpdateTags(galleryID, tags) + return qb.UpdateTags(ctx, galleryID, tags) }) } -func linkMovieStudios(mqb models.MovieWriter) error { +func linkMovieStudios(ctx context.Context, mqb models.MovieWriter) error { return doLinks(movieStudioLinks, func(movieIndex, studioIndex int) error { movie := models.MoviePartial{ ID: movieIDs[movieIndex], StudioID: &sql.NullInt64{Int64: int64(studioIDs[studioIndex]), Valid: true}, } - _, err := mqb.Update(movie) + _, err := mqb.Update(ctx, movie) return err }) } -func linkStudiosParent(qb models.StudioWriter) error { +func linkStudiosParent(ctx context.Context, qb models.StudioWriter) error { return doLinks(studioParentLinks, func(parentIndex, childIndex int) error { studio := models.StudioPartial{ ID: studioIDs[childIndex], ParentID: &sql.NullInt64{Int64: int64(studioIDs[parentIndex]), Valid: true}, } - _, err := qb.Update(studio) + _, err := qb.Update(ctx, studio) return err }) } -func linkTagsParent(qb models.TagReaderWriter) error { +func linkTagsParent(ctx context.Context, qb models.TagReaderWriter) error { return doLinks(tagParentLinks, func(parentIndex, childIndex int) error { tagID := tagIDs[childIndex] - parentTags, err := qb.FindByChildTagID(tagID) + parentTags, err := qb.FindByChildTagID(ctx, tagID) if err != nil { return err } @@ -1326,10 +1329,10 @@ func linkTagsParent(qb models.TagReaderWriter) error { parentIDs = append(parentIDs, tagIDs[parentIndex]) - return qb.UpdateParentTags(tagID, parentIDs) + return qb.UpdateParentTags(ctx, tagID, parentIDs) }) } -func addTagImage(qb models.TagWriter, tagIndex int) error { - return qb.UpdateImage(tagIDs[tagIndex], models.DefaultTagImage) +func addTagImage(ctx context.Context, qb models.TagWriter, tagIndex int) error { + return qb.UpdateImage(ctx, tagIDs[tagIndex], models.DefaultTagImage) } diff --git a/pkg/sqlite/sql.go b/pkg/sqlite/sql.go index a83612b0b..353e1a130 100644 --- a/pkg/sqlite/sql.go +++ b/pkg/sqlite/sql.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "database/sql" "errors" "fmt" @@ -225,8 +226,8 @@ func getCountCriterionClause(primaryTable, joinTable, primaryFK string, criterio return getIntCriterionWhereClause(lhs, criterion) } -func getImage(tx dbi, query string, args ...interface{}) ([]byte, error) { - rows, err := tx.Queryx(query, args...) +func getImage(ctx context.Context, tx dbi, query string, args ...interface{}) ([]byte, error) { + rows, err := tx.Queryx(ctx, query, args...) if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, err diff --git a/pkg/sqlite/stash_id_test.go b/pkg/sqlite/stash_id_test.go index 0f57bef19..44d08efbc 100644 --- a/pkg/sqlite/stash_id_test.go +++ b/pkg/sqlite/stash_id_test.go @@ -4,6 +4,7 @@ package sqlite_test import ( + "context" "testing" "github.com/stashapp/stash/pkg/models" @@ -11,16 +12,16 @@ import ( ) type stashIDReaderWriter interface { - GetStashIDs(performerID int) ([]*models.StashID, error) - UpdateStashIDs(performerID int, stashIDs []models.StashID) error + GetStashIDs(ctx context.Context, performerID int) ([]*models.StashID, error) + UpdateStashIDs(ctx context.Context, performerID int, stashIDs []models.StashID) error } -func testStashIDReaderWriter(t *testing.T, r stashIDReaderWriter, id int) { +func testStashIDReaderWriter(ctx context.Context, t *testing.T, r stashIDReaderWriter, id int) { // ensure no stash IDs to begin with - testNoStashIDs(t, r, id) + testNoStashIDs(ctx, t, r, id) // ensure GetStashIDs with non-existing also returns none - testNoStashIDs(t, r, -1) + testNoStashIDs(ctx, t, r, -1) // add stash ids const stashIDStr = "stashID" @@ -31,28 +32,28 @@ func testStashIDReaderWriter(t *testing.T, r stashIDReaderWriter, id int) { } // update stash ids and ensure was updated - if err := r.UpdateStashIDs(id, []models.StashID{stashID}); err != nil { + if err := r.UpdateStashIDs(ctx, id, []models.StashID{stashID}); err != nil { t.Error(err.Error()) } - testStashIDs(t, r, id, []*models.StashID{&stashID}) + testStashIDs(ctx, t, r, id, []*models.StashID{&stashID}) // update non-existing id - should return error - if err := r.UpdateStashIDs(-1, []models.StashID{stashID}); err == nil { + if err := r.UpdateStashIDs(ctx, -1, []models.StashID{stashID}); err == nil { t.Error("expected error when updating non-existing id") } // remove stash ids and ensure was updated - if err := r.UpdateStashIDs(id, []models.StashID{}); err != nil { + if err := r.UpdateStashIDs(ctx, id, []models.StashID{}); err != nil { t.Error(err.Error()) } - testNoStashIDs(t, r, id) + testNoStashIDs(ctx, t, r, id) } -func testNoStashIDs(t *testing.T, r stashIDReaderWriter, id int) { +func testNoStashIDs(ctx context.Context, t *testing.T, r stashIDReaderWriter, id int) { t.Helper() - stashIDs, err := r.GetStashIDs(id) + stashIDs, err := r.GetStashIDs(ctx, id) if err != nil { t.Error(err.Error()) return @@ -61,9 +62,9 @@ func testNoStashIDs(t *testing.T, r stashIDReaderWriter, id int) { assert.Len(t, stashIDs, 0) } -func testStashIDs(t *testing.T, r stashIDReaderWriter, id int, expected []*models.StashID) { +func testStashIDs(ctx context.Context, t *testing.T, r stashIDReaderWriter, id int, expected []*models.StashID) { t.Helper() - stashIDs, err := r.GetStashIDs(id) + stashIDs, err := r.GetStashIDs(ctx, id) if err != nil { t.Error(err.Error()) return diff --git a/pkg/sqlite/studio.go b/pkg/sqlite/studio.go index cc810fe0c..966257a1c 100644 --- a/pkg/sqlite/studio.go +++ b/pkg/sqlite/studio.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "database/sql" "errors" "fmt" @@ -18,57 +19,54 @@ type studioQueryBuilder struct { repository } -func NewStudioReaderWriter(tx dbi) *studioQueryBuilder { - return &studioQueryBuilder{ - repository{ - tx: tx, - tableName: studioTable, - idColumn: idColumn, - }, - } +var StudioReaderWriter = &studioQueryBuilder{ + repository{ + tableName: studioTable, + idColumn: idColumn, + }, } -func (qb *studioQueryBuilder) Create(newObject models.Studio) (*models.Studio, error) { +func (qb *studioQueryBuilder) Create(ctx context.Context, newObject models.Studio) (*models.Studio, error) { var ret models.Studio - if err := qb.insertObject(newObject, &ret); err != nil { + if err := qb.insertObject(ctx, newObject, &ret); err != nil { return nil, err } return &ret, nil } -func (qb *studioQueryBuilder) Update(updatedObject models.StudioPartial) (*models.Studio, error) { +func (qb *studioQueryBuilder) Update(ctx context.Context, updatedObject models.StudioPartial) (*models.Studio, error) { const partial = true - if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil { return nil, err } - return qb.Find(updatedObject.ID) + return qb.Find(ctx, updatedObject.ID) } -func (qb *studioQueryBuilder) UpdateFull(updatedObject models.Studio) (*models.Studio, error) { +func (qb *studioQueryBuilder) UpdateFull(ctx context.Context, updatedObject models.Studio) (*models.Studio, error) { const partial = false - if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil { return nil, err } - return qb.Find(updatedObject.ID) + return qb.Find(ctx, updatedObject.ID) } -func (qb *studioQueryBuilder) Destroy(id int) error { +func (qb *studioQueryBuilder) Destroy(ctx context.Context, id int) error { // TODO - set null on foreign key in scraped items // remove studio from scraped items - _, err := qb.tx.Exec("UPDATE scraped_items SET studio_id = null WHERE studio_id = ?", id) + _, err := qb.tx.Exec(ctx, "UPDATE scraped_items SET studio_id = null WHERE studio_id = ?", id) if err != nil { return err } - return qb.destroyExisting([]int{id}) + return qb.destroyExisting(ctx, []int{id}) } -func (qb *studioQueryBuilder) Find(id int) (*models.Studio, error) { +func (qb *studioQueryBuilder) Find(ctx context.Context, id int) (*models.Studio, error) { var ret models.Studio - if err := qb.get(id, &ret); err != nil { + if err := qb.getByID(ctx, id, &ret); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -77,10 +75,10 @@ func (qb *studioQueryBuilder) Find(id int) (*models.Studio, error) { return &ret, nil } -func (qb *studioQueryBuilder) FindMany(ids []int) ([]*models.Studio, error) { +func (qb *studioQueryBuilder) FindMany(ctx context.Context, ids []int) ([]*models.Studio, error) { var studios []*models.Studio for _, id := range ids { - studio, err := qb.Find(id) + studio, err := qb.Find(ctx, id) if err != nil { return nil, err } @@ -95,47 +93,47 @@ func (qb *studioQueryBuilder) FindMany(ids []int) ([]*models.Studio, error) { return studios, nil } -func (qb *studioQueryBuilder) FindChildren(id int) ([]*models.Studio, error) { +func (qb *studioQueryBuilder) FindChildren(ctx context.Context, id int) ([]*models.Studio, error) { query := "SELECT studios.* FROM studios WHERE studios.parent_id = ?" args := []interface{}{id} - return qb.queryStudios(query, args) + return qb.queryStudios(ctx, query, args) } -func (qb *studioQueryBuilder) FindBySceneID(sceneID int) (*models.Studio, error) { +func (qb *studioQueryBuilder) FindBySceneID(ctx context.Context, sceneID int) (*models.Studio, error) { query := "SELECT studios.* FROM studios JOIN scenes ON studios.id = scenes.studio_id WHERE scenes.id = ? LIMIT 1" args := []interface{}{sceneID} - return qb.queryStudio(query, args) + return qb.queryStudio(ctx, query, args) } -func (qb *studioQueryBuilder) FindByName(name string, nocase bool) (*models.Studio, error) { +func (qb *studioQueryBuilder) FindByName(ctx context.Context, name string, nocase bool) (*models.Studio, error) { query := "SELECT * FROM studios WHERE name = ?" if nocase { query += " COLLATE NOCASE" } query += " LIMIT 1" args := []interface{}{name} - return qb.queryStudio(query, args) + return qb.queryStudio(ctx, query, args) } -func (qb *studioQueryBuilder) FindByStashID(stashID models.StashID) ([]*models.Studio, error) { +func (qb *studioQueryBuilder) FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Studio, error) { query := selectAll("studios") + ` LEFT JOIN studio_stash_ids on studio_stash_ids.studio_id = studios.id WHERE studio_stash_ids.stash_id = ? AND studio_stash_ids.endpoint = ? ` args := []interface{}{stashID.StashID, stashID.Endpoint} - return qb.queryStudios(query, args) + return qb.queryStudios(ctx, query, args) } -func (qb *studioQueryBuilder) Count() (int, error) { - return qb.runCountQuery(qb.buildCountQuery("SELECT studios.id FROM studios"), nil) +func (qb *studioQueryBuilder) Count(ctx context.Context) (int, error) { + return qb.runCountQuery(ctx, qb.buildCountQuery("SELECT studios.id FROM studios"), nil) } -func (qb *studioQueryBuilder) All() ([]*models.Studio, error) { - return qb.queryStudios(selectAll("studios")+qb.getStudioSort(nil), nil) +func (qb *studioQueryBuilder) All(ctx context.Context) ([]*models.Studio, error) { + return qb.queryStudios(ctx, selectAll("studios")+qb.getStudioSort(nil), nil) } -func (qb *studioQueryBuilder) QueryForAutoTag(words []string) ([]*models.Studio, error) { +func (qb *studioQueryBuilder) QueryForAutoTag(ctx context.Context, words []string) ([]*models.Studio, error) { // TODO - Query needs to be changed to support queries of this type, and // this method should be removed query := selectAll(studioTable) @@ -159,7 +157,7 @@ func (qb *studioQueryBuilder) QueryForAutoTag(words []string) ([]*models.Studio, "studios.ignore_auto_tag = 0", whereOr, }, " AND ") - return qb.queryStudios(query+" WHERE "+where, args) + return qb.queryStudios(ctx, query+" WHERE "+where, args) } func (qb *studioQueryBuilder) validateFilter(filter *models.StudioFilterType) error { @@ -193,43 +191,43 @@ func (qb *studioQueryBuilder) validateFilter(filter *models.StudioFilterType) er return nil } -func (qb *studioQueryBuilder) makeFilter(studioFilter *models.StudioFilterType) *filterBuilder { +func (qb *studioQueryBuilder) makeFilter(ctx context.Context, studioFilter *models.StudioFilterType) *filterBuilder { query := &filterBuilder{} if studioFilter.And != nil { - query.and(qb.makeFilter(studioFilter.And)) + query.and(qb.makeFilter(ctx, studioFilter.And)) } if studioFilter.Or != nil { - query.or(qb.makeFilter(studioFilter.Or)) + query.or(qb.makeFilter(ctx, studioFilter.Or)) } if studioFilter.Not != nil { - query.not(qb.makeFilter(studioFilter.Not)) + query.not(qb.makeFilter(ctx, studioFilter.Not)) } - query.handleCriterion(stringCriterionHandler(studioFilter.Name, studioTable+".name")) - query.handleCriterion(stringCriterionHandler(studioFilter.Details, studioTable+".details")) - query.handleCriterion(stringCriterionHandler(studioFilter.URL, studioTable+".url")) - query.handleCriterion(intCriterionHandler(studioFilter.Rating, studioTable+".rating")) - query.handleCriterion(boolCriterionHandler(studioFilter.IgnoreAutoTag, studioTable+".ignore_auto_tag")) + query.handleCriterion(ctx, stringCriterionHandler(studioFilter.Name, studioTable+".name")) + query.handleCriterion(ctx, stringCriterionHandler(studioFilter.Details, studioTable+".details")) + query.handleCriterion(ctx, stringCriterionHandler(studioFilter.URL, studioTable+".url")) + query.handleCriterion(ctx, intCriterionHandler(studioFilter.Rating, studioTable+".rating")) + query.handleCriterion(ctx, boolCriterionHandler(studioFilter.IgnoreAutoTag, studioTable+".ignore_auto_tag")) - query.handleCriterion(criterionHandlerFunc(func(f *filterBuilder) { + query.handleCriterion(ctx, criterionHandlerFunc(func(ctx context.Context, f *filterBuilder) { if studioFilter.StashID != nil { qb.stashIDRepository().join(f, "studio_stash_ids", "studios.id") - stringCriterionHandler(studioFilter.StashID, "studio_stash_ids.stash_id")(f) + stringCriterionHandler(studioFilter.StashID, "studio_stash_ids.stash_id")(ctx, f) } })) - query.handleCriterion(studioIsMissingCriterionHandler(qb, studioFilter.IsMissing)) - query.handleCriterion(studioSceneCountCriterionHandler(qb, studioFilter.SceneCount)) - query.handleCriterion(studioImageCountCriterionHandler(qb, studioFilter.ImageCount)) - query.handleCriterion(studioGalleryCountCriterionHandler(qb, studioFilter.GalleryCount)) - query.handleCriterion(studioParentCriterionHandler(qb, studioFilter.Parents)) - query.handleCriterion(studioAliasCriterionHandler(qb, studioFilter.Aliases)) + query.handleCriterion(ctx, studioIsMissingCriterionHandler(qb, studioFilter.IsMissing)) + query.handleCriterion(ctx, studioSceneCountCriterionHandler(qb, studioFilter.SceneCount)) + query.handleCriterion(ctx, studioImageCountCriterionHandler(qb, studioFilter.ImageCount)) + query.handleCriterion(ctx, studioGalleryCountCriterionHandler(qb, studioFilter.GalleryCount)) + query.handleCriterion(ctx, studioParentCriterionHandler(qb, studioFilter.Parents)) + query.handleCriterion(ctx, studioAliasCriterionHandler(qb, studioFilter.Aliases)) return query } -func (qb *studioQueryBuilder) Query(studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) { +func (qb *studioQueryBuilder) Query(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) { if studioFilter == nil { studioFilter = &models.StudioFilterType{} } @@ -250,19 +248,19 @@ func (qb *studioQueryBuilder) Query(studioFilter *models.StudioFilterType, findF if err := qb.validateFilter(studioFilter); err != nil { return nil, 0, err } - filter := qb.makeFilter(studioFilter) + filter := qb.makeFilter(ctx, studioFilter) query.addFilter(filter) query.sortAndPagination = qb.getStudioSort(findFilter) + getPagination(findFilter) - idsResult, countResult, err := query.executeFind() + idsResult, countResult, err := query.executeFind(ctx) if err != nil { return nil, 0, err } var studios []*models.Studio for _, id := range idsResult { - studio, err := qb.Find(id) + studio, err := qb.Find(ctx, id) if err != nil { return nil, 0, err } @@ -274,7 +272,7 @@ func (qb *studioQueryBuilder) Query(studioFilter *models.StudioFilterType, findF } func studioIsMissingCriterionHandler(qb *studioQueryBuilder, isMissing *string) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if isMissing != nil && *isMissing != "" { switch *isMissing { case "image": @@ -291,7 +289,7 @@ func studioIsMissingCriterionHandler(qb *studioQueryBuilder, isMissing *string) } func studioSceneCountCriterionHandler(qb *studioQueryBuilder, sceneCount *models.IntCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if sceneCount != nil { f.addLeftJoin("scenes", "", "scenes.studio_id = studios.id") clause, args := getIntCriterionWhereClause("count(distinct scenes.id)", *sceneCount) @@ -302,7 +300,7 @@ func studioSceneCountCriterionHandler(qb *studioQueryBuilder, sceneCount *models } func studioImageCountCriterionHandler(qb *studioQueryBuilder, imageCount *models.IntCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if imageCount != nil { f.addLeftJoin("images", "", "images.studio_id = studios.id") clause, args := getIntCriterionWhereClause("count(distinct images.id)", *imageCount) @@ -313,7 +311,7 @@ func studioImageCountCriterionHandler(qb *studioQueryBuilder, imageCount *models } func studioGalleryCountCriterionHandler(qb *studioQueryBuilder, galleryCount *models.IntCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if galleryCount != nil { f.addLeftJoin("galleries", "", "galleries.studio_id = studios.id") clause, args := getIntCriterionWhereClause("count(distinct galleries.id)", *galleryCount) @@ -373,17 +371,17 @@ func (qb *studioQueryBuilder) getStudioSort(findFilter *models.FindFilterType) s } } -func (qb *studioQueryBuilder) queryStudio(query string, args []interface{}) (*models.Studio, error) { - results, err := qb.queryStudios(query, args) +func (qb *studioQueryBuilder) queryStudio(ctx context.Context, query string, args []interface{}) (*models.Studio, error) { + results, err := qb.queryStudios(ctx, query, args) if err != nil || len(results) < 1 { return nil, err } return results[0], nil } -func (qb *studioQueryBuilder) queryStudios(query string, args []interface{}) ([]*models.Studio, error) { +func (qb *studioQueryBuilder) queryStudios(ctx context.Context, query string, args []interface{}) ([]*models.Studio, error) { var ret models.Studios - if err := qb.query(query, args, &ret); err != nil { + if err := qb.query(ctx, query, args, &ret); err != nil { return nil, err } @@ -401,20 +399,20 @@ func (qb *studioQueryBuilder) imageRepository() *imageRepository { } } -func (qb *studioQueryBuilder) GetImage(studioID int) ([]byte, error) { - return qb.imageRepository().get(studioID) +func (qb *studioQueryBuilder) GetImage(ctx context.Context, studioID int) ([]byte, error) { + return qb.imageRepository().get(ctx, studioID) } -func (qb *studioQueryBuilder) HasImage(studioID int) (bool, error) { - return qb.imageRepository().exists(studioID) +func (qb *studioQueryBuilder) HasImage(ctx context.Context, studioID int) (bool, error) { + return qb.imageRepository().exists(ctx, studioID) } -func (qb *studioQueryBuilder) UpdateImage(studioID int, image []byte) error { - return qb.imageRepository().replace(studioID, image) +func (qb *studioQueryBuilder) UpdateImage(ctx context.Context, studioID int, image []byte) error { + return qb.imageRepository().replace(ctx, studioID, image) } -func (qb *studioQueryBuilder) DestroyImage(studioID int) error { - return qb.imageRepository().destroy([]int{studioID}) +func (qb *studioQueryBuilder) DestroyImage(ctx context.Context, studioID int) error { + return qb.imageRepository().destroy(ctx, []int{studioID}) } func (qb *studioQueryBuilder) stashIDRepository() *stashIDRepository { @@ -427,12 +425,12 @@ func (qb *studioQueryBuilder) stashIDRepository() *stashIDRepository { } } -func (qb *studioQueryBuilder) GetStashIDs(studioID int) ([]*models.StashID, error) { - return qb.stashIDRepository().get(studioID) +func (qb *studioQueryBuilder) GetStashIDs(ctx context.Context, studioID int) ([]*models.StashID, error) { + return qb.stashIDRepository().get(ctx, studioID) } -func (qb *studioQueryBuilder) UpdateStashIDs(studioID int, stashIDs []models.StashID) error { - return qb.stashIDRepository().replace(studioID, stashIDs) +func (qb *studioQueryBuilder) UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error { + return qb.stashIDRepository().replace(ctx, studioID, stashIDs) } func (qb *studioQueryBuilder) aliasRepository() *stringRepository { @@ -446,10 +444,10 @@ func (qb *studioQueryBuilder) aliasRepository() *stringRepository { } } -func (qb *studioQueryBuilder) GetAliases(studioID int) ([]string, error) { - return qb.aliasRepository().get(studioID) +func (qb *studioQueryBuilder) GetAliases(ctx context.Context, studioID int) ([]string, error) { + return qb.aliasRepository().get(ctx, studioID) } -func (qb *studioQueryBuilder) UpdateAliases(studioID int, aliases []string) error { - return qb.aliasRepository().replace(studioID, aliases) +func (qb *studioQueryBuilder) UpdateAliases(ctx context.Context, studioID int, aliases []string) error { + return qb.aliasRepository().replace(ctx, studioID, aliases) } diff --git a/pkg/sqlite/studio_test.go b/pkg/sqlite/studio_test.go index 08e6a30da..28b162328 100644 --- a/pkg/sqlite/studio_test.go +++ b/pkg/sqlite/studio_test.go @@ -4,6 +4,7 @@ package sqlite_test import ( + "context" "database/sql" "errors" "fmt" @@ -13,16 +14,17 @@ import ( "testing" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/sqlite" "github.com/stretchr/testify/assert" ) func TestStudioFindByName(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Studio() + withTxn(func(ctx context.Context) error { + sqb := sqlite.StudioReaderWriter name := studioNames[studioIdxWithScene] // find a studio by name - studio, err := sqb.FindByName(name, false) + studio, err := sqb.FindByName(ctx, name, false) if err != nil { t.Errorf("Error finding studios: %s", err.Error()) @@ -32,7 +34,7 @@ func TestStudioFindByName(t *testing.T) { name = studioNames[studioIdxWithDupName] // find a studio by name nocase - studio, err = sqb.FindByName(name, true) + studio, err = sqb.FindByName(ctx, name, true) if err != nil { t.Errorf("Error finding studios: %s", err.Error()) @@ -67,10 +69,10 @@ func TestStudioQueryNameOr(t *testing.T) { }, } - withTxn(func(r models.Repository) error { - sqb := r.Studio() + withTxn(func(ctx context.Context) error { + sqb := sqlite.StudioReaderWriter - studios := queryStudio(t, sqb, &studioFilter, nil) + studios := queryStudio(ctx, t, sqb, &studioFilter, nil) assert.Len(t, studios, 2) assert.Equal(t, studio1Name, studios[0].Name.String) @@ -98,10 +100,10 @@ func TestStudioQueryNameAndUrl(t *testing.T) { }, } - withTxn(func(r models.Repository) error { - sqb := r.Studio() + withTxn(func(ctx context.Context) error { + sqb := sqlite.StudioReaderWriter - studios := queryStudio(t, sqb, &studioFilter, nil) + studios := queryStudio(ctx, t, sqb, &studioFilter, nil) assert.Len(t, studios, 1) assert.Equal(t, studioName, studios[0].Name.String) @@ -133,10 +135,10 @@ func TestStudioQueryNameNotUrl(t *testing.T) { }, } - withTxn(func(r models.Repository) error { - sqb := r.Studio() + withTxn(func(ctx context.Context) error { + sqb := sqlite.StudioReaderWriter - studios := queryStudio(t, sqb, &studioFilter, nil) + studios := queryStudio(ctx, t, sqb, &studioFilter, nil) for _, studio := range studios { verifyString(t, studio.Name.String, nameCriterion) @@ -164,20 +166,20 @@ func TestStudioIllegalQuery(t *testing.T) { Or: &subFilter, } - withTxn(func(r models.Repository) error { - sqb := r.Studio() + withTxn(func(ctx context.Context) error { + sqb := sqlite.StudioReaderWriter - _, _, err := sqb.Query(studioFilter, nil) + _, _, err := sqb.Query(ctx, studioFilter, nil) assert.NotNil(err) studioFilter.Or = nil studioFilter.Not = &subFilter - _, _, err = sqb.Query(studioFilter, nil) + _, _, err = sqb.Query(ctx, studioFilter, nil) assert.NotNil(err) studioFilter.And = nil studioFilter.Or = &subFilter - _, _, err = sqb.Query(studioFilter, nil) + _, _, err = sqb.Query(ctx, studioFilter, nil) assert.NotNil(err) return nil @@ -185,15 +187,15 @@ func TestStudioIllegalQuery(t *testing.T) { } func TestStudioQueryIgnoreAutoTag(t *testing.T) { - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { ignoreAutoTag := true studioFilter := models.StudioFilterType{ IgnoreAutoTag: &ignoreAutoTag, } - sqb := r.Studio() + sqb := sqlite.StudioReaderWriter - studios := queryStudio(t, sqb, &studioFilter, nil) + studios := queryStudio(ctx, t, sqb, &studioFilter, nil) assert.Len(t, studios, int(math.Ceil(float64(totalStudios)/5))) for _, s := range studios { @@ -205,12 +207,12 @@ func TestStudioQueryIgnoreAutoTag(t *testing.T) { } func TestStudioQueryForAutoTag(t *testing.T) { - withTxn(func(r models.Repository) error { - tqb := r.Studio() + withTxn(func(ctx context.Context) error { + tqb := sqlite.StudioReaderWriter name := studioNames[studioIdxWithMovie] // find a studio by name - studios, err := tqb.QueryForAutoTag([]string{name}) + studios, err := tqb.QueryForAutoTag(ctx, []string{name}) if err != nil { t.Errorf("Error finding studios: %s", err.Error()) @@ -221,7 +223,7 @@ func TestStudioQueryForAutoTag(t *testing.T) { // find by alias name = getStudioStringValue(studioIdxWithMovie, "Alias") - studios, err = tqb.QueryForAutoTag([]string{name}) + studios, err = tqb.QueryForAutoTag(ctx, []string{name}) if err != nil { t.Errorf("Error finding studios: %s", err.Error()) @@ -235,8 +237,8 @@ func TestStudioQueryForAutoTag(t *testing.T) { } func TestStudioQueryParent(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Studio() + withTxn(func(ctx context.Context) error { + sqb := sqlite.StudioReaderWriter studioCriterion := models.MultiCriterionInput{ Value: []string{ strconv.Itoa(studioIDs[studioIdxWithChildStudio]), @@ -248,7 +250,7 @@ func TestStudioQueryParent(t *testing.T) { Parents: &studioCriterion, } - studios, _, err := sqb.Query(&studioFilter, nil) + studios, _, err := sqb.Query(ctx, &studioFilter, nil) if err != nil { t.Errorf("Error querying studio: %s", err.Error()) } @@ -270,7 +272,7 @@ func TestStudioQueryParent(t *testing.T) { Q: &q, } - studios, _, err = sqb.Query(&studioFilter, &findFilter) + studios, _, err = sqb.Query(ctx, &studioFilter, &findFilter) if err != nil { t.Errorf("Error querying studio: %s", err.Error()) } @@ -285,28 +287,28 @@ func TestStudioDestroyParent(t *testing.T) { const childName = "child" // create parent and child studios - if err := withTxn(func(r models.Repository) error { - createdParent, err := createStudio(r.Studio(), parentName, nil) + if err := withTxn(func(ctx context.Context) error { + createdParent, err := createStudio(ctx, sqlite.StudioReaderWriter, parentName, nil) if err != nil { return fmt.Errorf("Error creating parent studio: %s", err.Error()) } parentID := int64(createdParent.ID) - createdChild, err := createStudio(r.Studio(), childName, &parentID) + createdChild, err := createStudio(ctx, sqlite.StudioReaderWriter, childName, &parentID) if err != nil { return fmt.Errorf("Error creating child studio: %s", err.Error()) } - sqb := r.Studio() + sqb := sqlite.StudioReaderWriter // destroy the parent - err = sqb.Destroy(createdParent.ID) + err = sqb.Destroy(ctx, createdParent.ID) if err != nil { return fmt.Errorf("Error destroying parent studio: %s", err.Error()) } // destroy the child - err = sqb.Destroy(createdChild.ID) + err = sqb.Destroy(ctx, createdChild.ID) if err != nil { return fmt.Errorf("Error destroying child studio: %s", err.Error()) } @@ -318,10 +320,10 @@ func TestStudioDestroyParent(t *testing.T) { } func TestStudioFindChildren(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Studio() + withTxn(func(ctx context.Context) error { + sqb := sqlite.StudioReaderWriter - studios, err := sqb.FindChildren(studioIDs[studioIdxWithChildStudio]) + studios, err := sqb.FindChildren(ctx, studioIDs[studioIdxWithChildStudio]) if err != nil { t.Errorf("error calling FindChildren: %s", err.Error()) @@ -330,7 +332,7 @@ func TestStudioFindChildren(t *testing.T) { assert.Len(t, studios, 1) assert.Equal(t, studioIDs[studioIdxWithParentStudio], studios[0].ID) - studios, err = sqb.FindChildren(0) + studios, err = sqb.FindChildren(ctx, 0) if err != nil { t.Errorf("error calling FindChildren: %s", err.Error()) @@ -347,19 +349,19 @@ func TestStudioUpdateClearParent(t *testing.T) { const childName = "clearParent_child" // create parent and child studios - if err := withTxn(func(r models.Repository) error { - createdParent, err := createStudio(r.Studio(), parentName, nil) + if err := withTxn(func(ctx context.Context) error { + createdParent, err := createStudio(ctx, sqlite.StudioReaderWriter, parentName, nil) if err != nil { return fmt.Errorf("Error creating parent studio: %s", err.Error()) } parentID := int64(createdParent.ID) - createdChild, err := createStudio(r.Studio(), childName, &parentID) + createdChild, err := createStudio(ctx, sqlite.StudioReaderWriter, childName, &parentID) if err != nil { return fmt.Errorf("Error creating child studio: %s", err.Error()) } - sqb := r.Studio() + sqb := sqlite.StudioReaderWriter // clear the parent id from the child updatePartial := models.StudioPartial{ @@ -367,7 +369,7 @@ func TestStudioUpdateClearParent(t *testing.T) { ParentID: &sql.NullInt64{Valid: false}, } - updatedStudio, err := sqb.Update(updatePartial) + updatedStudio, err := sqb.Update(ctx, updatePartial) if err != nil { return fmt.Errorf("Error updated studio: %s", err.Error()) @@ -384,31 +386,31 @@ func TestStudioUpdateClearParent(t *testing.T) { } func TestStudioUpdateStudioImage(t *testing.T) { - if err := withTxn(func(r models.Repository) error { - qb := r.Studio() + if err := withTxn(func(ctx context.Context) error { + qb := sqlite.StudioReaderWriter // create performer to test against const name = "TestStudioUpdateStudioImage" - created, err := createStudio(r.Studio(), name, nil) + created, err := createStudio(ctx, sqlite.StudioReaderWriter, name, nil) if err != nil { return fmt.Errorf("Error creating studio: %s", err.Error()) } image := []byte("image") - err = qb.UpdateImage(created.ID, image) + err = qb.UpdateImage(ctx, created.ID, image) if err != nil { return fmt.Errorf("Error updating studio image: %s", err.Error()) } // ensure image set - storedImage, err := qb.GetImage(created.ID) + storedImage, err := qb.GetImage(ctx, created.ID) if err != nil { return fmt.Errorf("Error getting image: %s", err.Error()) } assert.Equal(t, storedImage, image) // set nil image - err = qb.UpdateImage(created.ID, nil) + err = qb.UpdateImage(ctx, created.ID, nil) if err == nil { return fmt.Errorf("Expected error setting nil image") } @@ -420,29 +422,29 @@ func TestStudioUpdateStudioImage(t *testing.T) { } func TestStudioDestroyStudioImage(t *testing.T) { - if err := withTxn(func(r models.Repository) error { - qb := r.Studio() + if err := withTxn(func(ctx context.Context) error { + qb := sqlite.StudioReaderWriter // create performer to test against const name = "TestStudioDestroyStudioImage" - created, err := createStudio(r.Studio(), name, nil) + created, err := createStudio(ctx, sqlite.StudioReaderWriter, name, nil) if err != nil { return fmt.Errorf("Error creating studio: %s", err.Error()) } image := []byte("image") - err = qb.UpdateImage(created.ID, image) + err = qb.UpdateImage(ctx, created.ID, image) if err != nil { return fmt.Errorf("Error updating studio image: %s", err.Error()) } - err = qb.DestroyImage(created.ID) + err = qb.DestroyImage(ctx, created.ID) if err != nil { return fmt.Errorf("Error destroying studio image: %s", err.Error()) } // image should be nil - storedImage, err := qb.GetImage(created.ID) + storedImage, err := qb.GetImage(ctx, created.ID) if err != nil { return fmt.Errorf("Error getting image: %s", err.Error()) } @@ -474,17 +476,17 @@ func TestStudioQuerySceneCount(t *testing.T) { } func verifyStudiosSceneCount(t *testing.T, sceneCountCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Studio() + withTxn(func(ctx context.Context) error { + sqb := sqlite.StudioReaderWriter studioFilter := models.StudioFilterType{ SceneCount: &sceneCountCriterion, } - studios := queryStudio(t, sqb, &studioFilter, nil) + studios := queryStudio(ctx, t, sqb, &studioFilter, nil) assert.Greater(t, len(studios), 0) for _, studio := range studios { - sceneCount, err := r.Scene().CountByStudioID(studio.ID) + sceneCount, err := sqlite.SceneReaderWriter.CountByStudioID(ctx, studio.ID) if err != nil { return err } @@ -515,19 +517,19 @@ func TestStudioQueryImageCount(t *testing.T) { } func verifyStudiosImageCount(t *testing.T, imageCountCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Studio() + withTxn(func(ctx context.Context) error { + sqb := sqlite.StudioReaderWriter studioFilter := models.StudioFilterType{ ImageCount: &imageCountCriterion, } - studios := queryStudio(t, sqb, &studioFilter, nil) + studios := queryStudio(ctx, t, sqb, &studioFilter, nil) assert.Greater(t, len(studios), 0) for _, studio := range studios { pp := 0 - result, err := r.Image().Query(models.ImageQueryOptions{ + result, err := sqlite.ImageReaderWriter.Query(ctx, models.ImageQueryOptions{ QueryOptions: models.QueryOptions{ FindFilter: &models.FindFilterType{ PerPage: &pp, @@ -571,19 +573,19 @@ func TestStudioQueryGalleryCount(t *testing.T) { } func verifyStudiosGalleryCount(t *testing.T, galleryCountCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Studio() + withTxn(func(ctx context.Context) error { + sqb := sqlite.StudioReaderWriter studioFilter := models.StudioFilterType{ GalleryCount: &galleryCountCriterion, } - studios := queryStudio(t, sqb, &studioFilter, nil) + studios := queryStudio(ctx, t, sqb, &studioFilter, nil) assert.Greater(t, len(studios), 0) for _, studio := range studios { pp := 0 - _, count, err := r.Gallery().Query(&models.GalleryFilterType{ + _, count, err := sqlite.GalleryReaderWriter.Query(ctx, &models.GalleryFilterType{ Studios: &models.HierarchicalMultiCriterionInput{ Value: []string{strconv.Itoa(studio.ID)}, Modifier: models.CriterionModifierIncludes, @@ -602,17 +604,17 @@ func verifyStudiosGalleryCount(t *testing.T, galleryCountCriterion models.IntCri } func TestStudioStashIDs(t *testing.T) { - if err := withTxn(func(r models.Repository) error { - qb := r.Studio() + if err := withTxn(func(ctx context.Context) error { + qb := sqlite.StudioReaderWriter // create studio to test against const name = "TestStudioStashIDs" - created, err := createStudio(r.Studio(), name, nil) + created, err := createStudio(ctx, sqlite.StudioReaderWriter, name, nil) if err != nil { return fmt.Errorf("Error creating studio: %s", err.Error()) } - testStashIDReaderWriter(t, qb, created.ID) + testStashIDReaderWriter(ctx, t, qb, created.ID) return nil }); err != nil { t.Error(err.Error()) @@ -632,7 +634,7 @@ func TestStudioQueryURL(t *testing.T) { URL: &urlCriterion, } - verifyFn := func(g *models.Studio, r models.Repository) { + verifyFn := func(ctx context.Context, g *models.Studio) { t.Helper() verifyNullString(t, g.URL, urlCriterion) } @@ -682,18 +684,18 @@ func TestStudioQueryRating(t *testing.T) { verifyStudiosRating(t, ratingCriterion) } -func verifyStudioQuery(t *testing.T, filter models.StudioFilterType, verifyFn func(s *models.Studio, r models.Repository)) { - withTxn(func(r models.Repository) error { +func verifyStudioQuery(t *testing.T, filter models.StudioFilterType, verifyFn func(ctx context.Context, s *models.Studio)) { + withTxn(func(ctx context.Context) error { t.Helper() - sqb := r.Studio() + sqb := sqlite.StudioReaderWriter - studios := queryStudio(t, sqb, &filter, nil) + studios := queryStudio(ctx, t, sqb, &filter, nil) // assume it should find at least one assert.Greater(t, len(studios), 0) for _, studio := range studios { - verifyFn(studio, r) + verifyFn(ctx, studio) } return nil @@ -701,13 +703,13 @@ func verifyStudioQuery(t *testing.T, filter models.StudioFilterType, verifyFn fu } func verifyStudiosRating(t *testing.T, ratingCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - sqb := r.Studio() + withTxn(func(ctx context.Context) error { + sqb := sqlite.StudioReaderWriter studioFilter := models.StudioFilterType{ Rating: &ratingCriterion, } - studios, _, err := sqb.Query(&studioFilter, nil) + studios, _, err := sqb.Query(ctx, &studioFilter, nil) if err != nil { t.Errorf("Error querying studio: %s", err.Error()) @@ -722,14 +724,14 @@ func verifyStudiosRating(t *testing.T, ratingCriterion models.IntCriterionInput) } func TestStudioQueryIsMissingRating(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Studio() + withTxn(func(ctx context.Context) error { + sqb := sqlite.StudioReaderWriter isMissing := "rating" studioFilter := models.StudioFilterType{ IsMissing: &isMissing, } - studios, _, err := sqb.Query(&studioFilter, nil) + studios, _, err := sqb.Query(ctx, &studioFilter, nil) if err != nil { t.Errorf("Error querying studio: %s", err.Error()) @@ -745,8 +747,8 @@ func TestStudioQueryIsMissingRating(t *testing.T) { }) } -func queryStudio(t *testing.T, sqb models.StudioReader, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) []*models.Studio { - studios, _, err := sqb.Query(studioFilter, findFilter) +func queryStudio(ctx context.Context, t *testing.T, sqb models.StudioReader, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) []*models.Studio { + studios, _, err := sqb.Query(ctx, studioFilter, findFilter) if err != nil { t.Errorf("Error querying studio: %s", err.Error()) } @@ -767,7 +769,7 @@ func TestStudioQueryName(t *testing.T) { Name: nameCriterion, } - verifyFn := func(studio *models.Studio, r models.Repository) { + verifyFn := func(ctx context.Context, studio *models.Studio) { verifyNullString(t, studio.Name, *nameCriterion) } @@ -797,8 +799,8 @@ func TestStudioQueryAlias(t *testing.T) { Aliases: aliasCriterion, } - verifyFn := func(studio *models.Studio, r models.Repository) { - aliases, err := r.Studio().GetAliases(studio.ID) + verifyFn := func(ctx context.Context, studio *models.Studio) { + aliases, err := sqlite.StudioReaderWriter.GetAliases(ctx, studio.ID) if err != nil { t.Errorf("Error querying studios: %s", err.Error()) } @@ -825,24 +827,24 @@ func TestStudioQueryAlias(t *testing.T) { } func TestStudioUpdateAlias(t *testing.T) { - if err := withTxn(func(r models.Repository) error { - qb := r.Studio() + if err := withTxn(func(ctx context.Context) error { + qb := sqlite.StudioReaderWriter // create studio to test against const name = "TestStudioUpdateAlias" - created, err := createStudio(qb, name, nil) + created, err := createStudio(ctx, qb, name, nil) if err != nil { return fmt.Errorf("Error creating studio: %s", err.Error()) } aliases := []string{"alias1", "alias2"} - err = qb.UpdateAliases(created.ID, aliases) + err = qb.UpdateAliases(ctx, created.ID, aliases) if err != nil { return fmt.Errorf("Error updating studio aliases: %s", err.Error()) } // ensure aliases set - storedAliases, err := qb.GetAliases(created.ID) + storedAliases, err := qb.GetAliases(ctx, created.ID) if err != nil { return fmt.Errorf("Error getting aliases: %s", err.Error()) } @@ -922,11 +924,11 @@ func TestStudioQueryFast(t *testing.T) { } - withTxn(func(r models.Repository) error { - sqb := r.Studio() + withTxn(func(ctx context.Context) error { + sqb := sqlite.StudioReaderWriter for _, f := range filters { for _, ff := range findFilters { - _, _, err := sqb.Query(&f, &ff) + _, _, err := sqb.Query(ctx, &f, &ff) if err != nil { t.Errorf("Error querying studio: %s", err.Error()) } diff --git a/pkg/sqlite/tag.go b/pkg/sqlite/tag.go index 9513a269b..02e853fdc 100644 --- a/pkg/sqlite/tag.go +++ b/pkg/sqlite/tag.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "database/sql" "errors" "fmt" @@ -18,53 +19,50 @@ type tagQueryBuilder struct { repository } -func NewTagReaderWriter(tx dbi) *tagQueryBuilder { - return &tagQueryBuilder{ - repository{ - tx: tx, - tableName: tagTable, - idColumn: idColumn, - }, - } +var TagReaderWriter = &tagQueryBuilder{ + repository{ + tableName: tagTable, + idColumn: idColumn, + }, } -func (qb *tagQueryBuilder) Create(newObject models.Tag) (*models.Tag, error) { +func (qb *tagQueryBuilder) Create(ctx context.Context, newObject models.Tag) (*models.Tag, error) { var ret models.Tag - if err := qb.insertObject(newObject, &ret); err != nil { + if err := qb.insertObject(ctx, newObject, &ret); err != nil { return nil, err } return &ret, nil } -func (qb *tagQueryBuilder) Update(updatedObject models.TagPartial) (*models.Tag, error) { +func (qb *tagQueryBuilder) Update(ctx context.Context, updatedObject models.TagPartial) (*models.Tag, error) { const partial = true - if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil { return nil, err } - return qb.Find(updatedObject.ID) + return qb.Find(ctx, updatedObject.ID) } -func (qb *tagQueryBuilder) UpdateFull(updatedObject models.Tag) (*models.Tag, error) { +func (qb *tagQueryBuilder) UpdateFull(ctx context.Context, updatedObject models.Tag) (*models.Tag, error) { const partial = false - if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil { return nil, err } - return qb.Find(updatedObject.ID) + return qb.Find(ctx, updatedObject.ID) } -func (qb *tagQueryBuilder) Destroy(id int) error { +func (qb *tagQueryBuilder) Destroy(ctx context.Context, id int) error { // TODO - add delete cascade to foreign key // delete tag from scenes and markers first - _, err := qb.tx.Exec("DELETE FROM scenes_tags WHERE tag_id = ?", id) + _, err := qb.tx.Exec(ctx, "DELETE FROM scenes_tags WHERE tag_id = ?", id) if err != nil { return err } // TODO - add delete cascade to foreign key - _, err = qb.tx.Exec("DELETE FROM scene_markers_tags WHERE tag_id = ?", id) + _, err = qb.tx.Exec(ctx, "DELETE FROM scene_markers_tags WHERE tag_id = ?", id) if err != nil { return err } @@ -72,7 +70,7 @@ func (qb *tagQueryBuilder) Destroy(id int) error { // cannot unset primary_tag_id in scene_markers because it is not nullable countQuery := "SELECT COUNT(*) as count FROM scene_markers where primary_tag_id = ?" args := []interface{}{id} - primaryMarkers, err := qb.runCountQuery(countQuery, args) + primaryMarkers, err := qb.runCountQuery(ctx, countQuery, args) if err != nil { return err } @@ -81,12 +79,12 @@ func (qb *tagQueryBuilder) Destroy(id int) error { return errors.New("cannot delete tag used as a primary tag in scene markers") } - return qb.destroyExisting([]int{id}) + return qb.destroyExisting(ctx, []int{id}) } -func (qb *tagQueryBuilder) Find(id int) (*models.Tag, error) { +func (qb *tagQueryBuilder) Find(ctx context.Context, id int) (*models.Tag, error) { var ret models.Tag - if err := qb.get(id, &ret); err != nil { + if err := qb.getByID(ctx, id, &ret); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -95,10 +93,10 @@ func (qb *tagQueryBuilder) Find(id int) (*models.Tag, error) { return &ret, nil } -func (qb *tagQueryBuilder) FindMany(ids []int) ([]*models.Tag, error) { +func (qb *tagQueryBuilder) FindMany(ctx context.Context, ids []int) ([]*models.Tag, error) { var tags []*models.Tag for _, id := range ids { - tag, err := qb.Find(id) + tag, err := qb.Find(ctx, id) if err != nil { return nil, err } @@ -113,7 +111,7 @@ func (qb *tagQueryBuilder) FindMany(ids []int) ([]*models.Tag, error) { return tags, nil } -func (qb *tagQueryBuilder) FindBySceneID(sceneID int) ([]*models.Tag, error) { +func (qb *tagQueryBuilder) FindBySceneID(ctx context.Context, sceneID int) ([]*models.Tag, error) { query := ` SELECT tags.* FROM tags LEFT JOIN scenes_tags as scenes_join on scenes_join.tag_id = tags.id @@ -122,10 +120,10 @@ func (qb *tagQueryBuilder) FindBySceneID(sceneID int) ([]*models.Tag, error) { ` query += qb.getDefaultTagSort() args := []interface{}{sceneID} - return qb.queryTags(query, args) + return qb.queryTags(ctx, query, args) } -func (qb *tagQueryBuilder) FindByPerformerID(performerID int) ([]*models.Tag, error) { +func (qb *tagQueryBuilder) FindByPerformerID(ctx context.Context, performerID int) ([]*models.Tag, error) { query := ` SELECT tags.* FROM tags LEFT JOIN performers_tags as performers_join on performers_join.tag_id = tags.id @@ -134,10 +132,10 @@ func (qb *tagQueryBuilder) FindByPerformerID(performerID int) ([]*models.Tag, er ` query += qb.getDefaultTagSort() args := []interface{}{performerID} - return qb.queryTags(query, args) + return qb.queryTags(ctx, query, args) } -func (qb *tagQueryBuilder) FindByImageID(imageID int) ([]*models.Tag, error) { +func (qb *tagQueryBuilder) FindByImageID(ctx context.Context, imageID int) ([]*models.Tag, error) { query := ` SELECT tags.* FROM tags LEFT JOIN images_tags as images_join on images_join.tag_id = tags.id @@ -146,10 +144,10 @@ func (qb *tagQueryBuilder) FindByImageID(imageID int) ([]*models.Tag, error) { ` query += qb.getDefaultTagSort() args := []interface{}{imageID} - return qb.queryTags(query, args) + return qb.queryTags(ctx, query, args) } -func (qb *tagQueryBuilder) FindByGalleryID(galleryID int) ([]*models.Tag, error) { +func (qb *tagQueryBuilder) FindByGalleryID(ctx context.Context, galleryID int) ([]*models.Tag, error) { query := ` SELECT tags.* FROM tags LEFT JOIN galleries_tags as galleries_join on galleries_join.tag_id = tags.id @@ -158,10 +156,10 @@ func (qb *tagQueryBuilder) FindByGalleryID(galleryID int) ([]*models.Tag, error) ` query += qb.getDefaultTagSort() args := []interface{}{galleryID} - return qb.queryTags(query, args) + return qb.queryTags(ctx, query, args) } -func (qb *tagQueryBuilder) FindBySceneMarkerID(sceneMarkerID int) ([]*models.Tag, error) { +func (qb *tagQueryBuilder) FindBySceneMarkerID(ctx context.Context, sceneMarkerID int) ([]*models.Tag, error) { query := ` SELECT tags.* FROM tags LEFT JOIN scene_markers_tags as scene_markers_join on scene_markers_join.tag_id = tags.id @@ -170,20 +168,20 @@ func (qb *tagQueryBuilder) FindBySceneMarkerID(sceneMarkerID int) ([]*models.Tag ` query += qb.getDefaultTagSort() args := []interface{}{sceneMarkerID} - return qb.queryTags(query, args) + return qb.queryTags(ctx, query, args) } -func (qb *tagQueryBuilder) FindByName(name string, nocase bool) (*models.Tag, error) { +func (qb *tagQueryBuilder) FindByName(ctx context.Context, name string, nocase bool) (*models.Tag, error) { query := "SELECT * FROM tags WHERE name = ?" if nocase { query += " COLLATE NOCASE" } query += " LIMIT 1" args := []interface{}{name} - return qb.queryTag(query, args) + return qb.queryTag(ctx, query, args) } -func (qb *tagQueryBuilder) FindByNames(names []string, nocase bool) ([]*models.Tag, error) { +func (qb *tagQueryBuilder) FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Tag, error) { query := "SELECT * FROM tags WHERE name" if nocase { query += " COLLATE NOCASE" @@ -193,10 +191,10 @@ func (qb *tagQueryBuilder) FindByNames(names []string, nocase bool) ([]*models.T for _, name := range names { args = append(args, name) } - return qb.queryTags(query, args) + return qb.queryTags(ctx, query, args) } -func (qb *tagQueryBuilder) FindByParentTagID(parentID int) ([]*models.Tag, error) { +func (qb *tagQueryBuilder) FindByParentTagID(ctx context.Context, parentID int) ([]*models.Tag, error) { query := ` SELECT tags.* FROM tags INNER JOIN tags_relations ON tags_relations.child_id = tags.id @@ -204,10 +202,10 @@ func (qb *tagQueryBuilder) FindByParentTagID(parentID int) ([]*models.Tag, error ` query += qb.getDefaultTagSort() args := []interface{}{parentID} - return qb.queryTags(query, args) + return qb.queryTags(ctx, query, args) } -func (qb *tagQueryBuilder) FindByChildTagID(parentID int) ([]*models.Tag, error) { +func (qb *tagQueryBuilder) FindByChildTagID(ctx context.Context, parentID int) ([]*models.Tag, error) { query := ` SELECT tags.* FROM tags INNER JOIN tags_relations ON tags_relations.parent_id = tags.id @@ -215,18 +213,18 @@ func (qb *tagQueryBuilder) FindByChildTagID(parentID int) ([]*models.Tag, error) ` query += qb.getDefaultTagSort() args := []interface{}{parentID} - return qb.queryTags(query, args) + return qb.queryTags(ctx, query, args) } -func (qb *tagQueryBuilder) Count() (int, error) { - return qb.runCountQuery(qb.buildCountQuery("SELECT tags.id FROM tags"), nil) +func (qb *tagQueryBuilder) Count(ctx context.Context) (int, error) { + return qb.runCountQuery(ctx, qb.buildCountQuery("SELECT tags.id FROM tags"), nil) } -func (qb *tagQueryBuilder) All() ([]*models.Tag, error) { - return qb.queryTags(selectAll("tags")+qb.getDefaultTagSort(), nil) +func (qb *tagQueryBuilder) All(ctx context.Context) ([]*models.Tag, error) { + return qb.queryTags(ctx, selectAll("tags")+qb.getDefaultTagSort(), nil) } -func (qb *tagQueryBuilder) QueryForAutoTag(words []string) ([]*models.Tag, error) { +func (qb *tagQueryBuilder) QueryForAutoTag(ctx context.Context, words []string) ([]*models.Tag, error) { // TODO - Query needs to be changed to support queries of this type, and // this method should be removed query := selectAll(tagTable) @@ -250,7 +248,7 @@ func (qb *tagQueryBuilder) QueryForAutoTag(words []string) ([]*models.Tag, error "tags.ignore_auto_tag = 0", whereOr, }, " AND ") - return qb.queryTags(query+" WHERE "+where, args) + return qb.queryTags(ctx, query+" WHERE "+where, args) } func (qb *tagQueryBuilder) validateFilter(tagFilter *models.TagFilterType) error { @@ -284,38 +282,38 @@ func (qb *tagQueryBuilder) validateFilter(tagFilter *models.TagFilterType) error return nil } -func (qb *tagQueryBuilder) makeFilter(tagFilter *models.TagFilterType) *filterBuilder { +func (qb *tagQueryBuilder) makeFilter(ctx context.Context, tagFilter *models.TagFilterType) *filterBuilder { query := &filterBuilder{} if tagFilter.And != nil { - query.and(qb.makeFilter(tagFilter.And)) + query.and(qb.makeFilter(ctx, tagFilter.And)) } if tagFilter.Or != nil { - query.or(qb.makeFilter(tagFilter.Or)) + query.or(qb.makeFilter(ctx, tagFilter.Or)) } if tagFilter.Not != nil { - query.not(qb.makeFilter(tagFilter.Not)) + query.not(qb.makeFilter(ctx, tagFilter.Not)) } - query.handleCriterion(stringCriterionHandler(tagFilter.Name, tagTable+".name")) - query.handleCriterion(tagAliasCriterionHandler(qb, tagFilter.Aliases)) - query.handleCriterion(boolCriterionHandler(tagFilter.IgnoreAutoTag, tagTable+".ignore_auto_tag")) + query.handleCriterion(ctx, stringCriterionHandler(tagFilter.Name, tagTable+".name")) + query.handleCriterion(ctx, tagAliasCriterionHandler(qb, tagFilter.Aliases)) + query.handleCriterion(ctx, boolCriterionHandler(tagFilter.IgnoreAutoTag, tagTable+".ignore_auto_tag")) - query.handleCriterion(tagIsMissingCriterionHandler(qb, tagFilter.IsMissing)) - query.handleCriterion(tagSceneCountCriterionHandler(qb, tagFilter.SceneCount)) - query.handleCriterion(tagImageCountCriterionHandler(qb, tagFilter.ImageCount)) - query.handleCriterion(tagGalleryCountCriterionHandler(qb, tagFilter.GalleryCount)) - query.handleCriterion(tagPerformerCountCriterionHandler(qb, tagFilter.PerformerCount)) - query.handleCriterion(tagMarkerCountCriterionHandler(qb, tagFilter.MarkerCount)) - query.handleCriterion(tagParentsCriterionHandler(qb, tagFilter.Parents)) - query.handleCriterion(tagChildrenCriterionHandler(qb, tagFilter.Children)) - query.handleCriterion(tagParentCountCriterionHandler(qb, tagFilter.ParentCount)) - query.handleCriterion(tagChildCountCriterionHandler(qb, tagFilter.ChildCount)) + query.handleCriterion(ctx, tagIsMissingCriterionHandler(qb, tagFilter.IsMissing)) + query.handleCriterion(ctx, tagSceneCountCriterionHandler(qb, tagFilter.SceneCount)) + query.handleCriterion(ctx, tagImageCountCriterionHandler(qb, tagFilter.ImageCount)) + query.handleCriterion(ctx, tagGalleryCountCriterionHandler(qb, tagFilter.GalleryCount)) + query.handleCriterion(ctx, tagPerformerCountCriterionHandler(qb, tagFilter.PerformerCount)) + query.handleCriterion(ctx, tagMarkerCountCriterionHandler(qb, tagFilter.MarkerCount)) + query.handleCriterion(ctx, tagParentsCriterionHandler(qb, tagFilter.Parents)) + query.handleCriterion(ctx, tagChildrenCriterionHandler(qb, tagFilter.Children)) + query.handleCriterion(ctx, tagParentCountCriterionHandler(qb, tagFilter.ParentCount)) + query.handleCriterion(ctx, tagChildCountCriterionHandler(qb, tagFilter.ChildCount)) return query } -func (qb *tagQueryBuilder) Query(tagFilter *models.TagFilterType, findFilter *models.FindFilterType) ([]*models.Tag, int, error) { +func (qb *tagQueryBuilder) Query(ctx context.Context, tagFilter *models.TagFilterType, findFilter *models.FindFilterType) ([]*models.Tag, int, error) { if tagFilter == nil { tagFilter = &models.TagFilterType{} } @@ -335,19 +333,19 @@ func (qb *tagQueryBuilder) Query(tagFilter *models.TagFilterType, findFilter *mo if err := qb.validateFilter(tagFilter); err != nil { return nil, 0, err } - filter := qb.makeFilter(tagFilter) + filter := qb.makeFilter(ctx, tagFilter) query.addFilter(filter) query.sortAndPagination = qb.getTagSort(&query, findFilter) + getPagination(findFilter) - idsResult, countResult, err := query.executeFind() + idsResult, countResult, err := query.executeFind(ctx) if err != nil { return nil, 0, err } var tags []*models.Tag for _, id := range idsResult { - tag, err := qb.Find(id) + tag, err := qb.Find(ctx, id) if err != nil { return nil, 0, err } @@ -370,7 +368,7 @@ func tagAliasCriterionHandler(qb *tagQueryBuilder, alias *models.StringCriterion } func tagIsMissingCriterionHandler(qb *tagQueryBuilder, isMissing *string) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if isMissing != nil && *isMissing != "" { switch *isMissing { case "image": @@ -384,7 +382,7 @@ func tagIsMissingCriterionHandler(qb *tagQueryBuilder, isMissing *string) criter } func tagSceneCountCriterionHandler(qb *tagQueryBuilder, sceneCount *models.IntCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if sceneCount != nil { f.addLeftJoin("scenes_tags", "", "scenes_tags.tag_id = tags.id") clause, args := getIntCriterionWhereClause("count(distinct scenes_tags.scene_id)", *sceneCount) @@ -395,7 +393,7 @@ func tagSceneCountCriterionHandler(qb *tagQueryBuilder, sceneCount *models.IntCr } func tagImageCountCriterionHandler(qb *tagQueryBuilder, imageCount *models.IntCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if imageCount != nil { f.addLeftJoin("images_tags", "", "images_tags.tag_id = tags.id") clause, args := getIntCriterionWhereClause("count(distinct images_tags.image_id)", *imageCount) @@ -406,7 +404,7 @@ func tagImageCountCriterionHandler(qb *tagQueryBuilder, imageCount *models.IntCr } func tagGalleryCountCriterionHandler(qb *tagQueryBuilder, galleryCount *models.IntCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if galleryCount != nil { f.addLeftJoin("galleries_tags", "", "galleries_tags.tag_id = tags.id") clause, args := getIntCriterionWhereClause("count(distinct galleries_tags.gallery_id)", *galleryCount) @@ -417,7 +415,7 @@ func tagGalleryCountCriterionHandler(qb *tagQueryBuilder, galleryCount *models.I } func tagPerformerCountCriterionHandler(qb *tagQueryBuilder, performerCount *models.IntCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if performerCount != nil { f.addLeftJoin("performers_tags", "", "performers_tags.tag_id = tags.id") clause, args := getIntCriterionWhereClause("count(distinct performers_tags.performer_id)", *performerCount) @@ -428,7 +426,7 @@ func tagPerformerCountCriterionHandler(qb *tagQueryBuilder, performerCount *mode } func tagMarkerCountCriterionHandler(qb *tagQueryBuilder, markerCount *models.IntCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if markerCount != nil { f.addLeftJoin("scene_markers_tags", "", "scene_markers_tags.tag_id = tags.id") f.addLeftJoin("scene_markers", "", "scene_markers_tags.scene_marker_id = scene_markers.id OR scene_markers.primary_tag_id = tags.id") @@ -440,7 +438,7 @@ func tagMarkerCountCriterionHandler(qb *tagQueryBuilder, markerCount *models.Int } func tagParentsCriterionHandler(qb *tagQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if tags != nil { if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull { var notClause string @@ -489,7 +487,7 @@ func tagParentsCriterionHandler(qb *tagQueryBuilder, tags *models.HierarchicalMu } func tagChildrenCriterionHandler(qb *tagQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if tags != nil { if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull { var notClause string @@ -538,7 +536,7 @@ func tagChildrenCriterionHandler(qb *tagQueryBuilder, tags *models.HierarchicalM } func tagParentCountCriterionHandler(qb *tagQueryBuilder, parentCount *models.IntCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if parentCount != nil { f.addLeftJoin("tags_relations", "parents_count", "parents_count.child_id = tags.id") clause, args := getIntCriterionWhereClause("count(distinct parents_count.parent_id)", *parentCount) @@ -549,7 +547,7 @@ func tagParentCountCriterionHandler(qb *tagQueryBuilder, parentCount *models.Int } func tagChildCountCriterionHandler(qb *tagQueryBuilder, childCount *models.IntCriterionInput) criterionHandlerFunc { - return func(f *filterBuilder) { + return func(ctx context.Context, f *filterBuilder) { if childCount != nil { f.addLeftJoin("tags_relations", "children_count", "children_count.parent_id = tags.id") clause, args := getIntCriterionWhereClause("count(distinct children_count.child_id)", *childCount) @@ -592,17 +590,17 @@ func (qb *tagQueryBuilder) getTagSort(query *queryBuilder, findFilter *models.Fi return getSort(sort, direction, "tags") } -func (qb *tagQueryBuilder) queryTag(query string, args []interface{}) (*models.Tag, error) { - results, err := qb.queryTags(query, args) +func (qb *tagQueryBuilder) queryTag(ctx context.Context, query string, args []interface{}) (*models.Tag, error) { + results, err := qb.queryTags(ctx, query, args) if err != nil || len(results) < 1 { return nil, err } return results[0], nil } -func (qb *tagQueryBuilder) queryTags(query string, args []interface{}) ([]*models.Tag, error) { +func (qb *tagQueryBuilder) queryTags(ctx context.Context, query string, args []interface{}) ([]*models.Tag, error) { var ret models.Tags - if err := qb.query(query, args, &ret); err != nil { + if err := qb.query(ctx, query, args, &ret); err != nil { return nil, err } @@ -620,20 +618,20 @@ func (qb *tagQueryBuilder) imageRepository() *imageRepository { } } -func (qb *tagQueryBuilder) GetImage(tagID int) ([]byte, error) { - return qb.imageRepository().get(tagID) +func (qb *tagQueryBuilder) GetImage(ctx context.Context, tagID int) ([]byte, error) { + return qb.imageRepository().get(ctx, tagID) } -func (qb *tagQueryBuilder) HasImage(tagID int) (bool, error) { - return qb.imageRepository().exists(tagID) +func (qb *tagQueryBuilder) HasImage(ctx context.Context, tagID int) (bool, error) { + return qb.imageRepository().exists(ctx, tagID) } -func (qb *tagQueryBuilder) UpdateImage(tagID int, image []byte) error { - return qb.imageRepository().replace(tagID, image) +func (qb *tagQueryBuilder) UpdateImage(ctx context.Context, tagID int, image []byte) error { + return qb.imageRepository().replace(ctx, tagID, image) } -func (qb *tagQueryBuilder) DestroyImage(tagID int) error { - return qb.imageRepository().destroy([]int{tagID}) +func (qb *tagQueryBuilder) DestroyImage(ctx context.Context, tagID int) error { + return qb.imageRepository().destroy(ctx, []int{tagID}) } func (qb *tagQueryBuilder) aliasRepository() *stringRepository { @@ -647,15 +645,15 @@ func (qb *tagQueryBuilder) aliasRepository() *stringRepository { } } -func (qb *tagQueryBuilder) GetAliases(tagID int) ([]string, error) { - return qb.aliasRepository().get(tagID) +func (qb *tagQueryBuilder) GetAliases(ctx context.Context, tagID int) ([]string, error) { + return qb.aliasRepository().get(ctx, tagID) } -func (qb *tagQueryBuilder) UpdateAliases(tagID int, aliases []string) error { - return qb.aliasRepository().replace(tagID, aliases) +func (qb *tagQueryBuilder) UpdateAliases(ctx context.Context, tagID int, aliases []string) error { + return qb.aliasRepository().replace(ctx, tagID, aliases) } -func (qb *tagQueryBuilder) Merge(source []int, destination int) error { +func (qb *tagQueryBuilder) Merge(ctx context.Context, source []int, destination int) error { if len(source) == 0 { return nil } @@ -680,7 +678,7 @@ func (qb *tagQueryBuilder) Merge(source []int, destination int) error { args = append(args, destination) for table, idColumn := range tagTables { - _, err := qb.tx.Exec(`UPDATE `+table+` + _, err := qb.tx.Exec(ctx, `UPDATE `+table+` SET tag_id = ? WHERE tag_id IN `+inBinding+` AND NOT EXISTS(SELECT 1 FROM `+table+` o WHERE o.`+idColumn+` = `+table+`.`+idColumn+` AND o.tag_id = ?)`, @@ -691,23 +689,23 @@ AND NOT EXISTS(SELECT 1 FROM `+table+` o WHERE o.`+idColumn+` = `+table+`.`+idCo } } - _, err := qb.tx.Exec("UPDATE "+sceneMarkerTable+" SET primary_tag_id = ? WHERE primary_tag_id IN "+inBinding, args...) + _, err := qb.tx.Exec(ctx, "UPDATE "+sceneMarkerTable+" SET primary_tag_id = ? WHERE primary_tag_id IN "+inBinding, args...) if err != nil { return err } - _, err = qb.tx.Exec("INSERT INTO "+tagAliasesTable+" (tag_id, alias) SELECT ?, name FROM "+tagTable+" WHERE id IN "+inBinding, args...) + _, err = qb.tx.Exec(ctx, "INSERT INTO "+tagAliasesTable+" (tag_id, alias) SELECT ?, name FROM "+tagTable+" WHERE id IN "+inBinding, args...) if err != nil { return err } - _, err = qb.tx.Exec("UPDATE "+tagAliasesTable+" SET tag_id = ? WHERE tag_id IN "+inBinding, args...) + _, err = qb.tx.Exec(ctx, "UPDATE "+tagAliasesTable+" SET tag_id = ? WHERE tag_id IN "+inBinding, args...) if err != nil { return err } for _, id := range source { - err = qb.Destroy(id) + err = qb.Destroy(ctx, id) if err != nil { return err } @@ -716,9 +714,9 @@ AND NOT EXISTS(SELECT 1 FROM `+table+` o WHERE o.`+idColumn+` = `+table+`.`+idCo return nil } -func (qb *tagQueryBuilder) UpdateParentTags(tagID int, parentIDs []int) error { +func (qb *tagQueryBuilder) UpdateParentTags(ctx context.Context, tagID int, parentIDs []int) error { tx := qb.tx - if _, err := tx.Exec("DELETE FROM tags_relations WHERE child_id = ?", tagID); err != nil { + if _, err := tx.Exec(ctx, "DELETE FROM tags_relations WHERE child_id = ?", tagID); err != nil { return err } @@ -731,7 +729,7 @@ func (qb *tagQueryBuilder) UpdateParentTags(tagID int, parentIDs []int) error { } query := "INSERT INTO tags_relations (parent_id, child_id) VALUES " + strings.Join(values, ", ") - if _, err := tx.Exec(query, args...); err != nil { + if _, err := tx.Exec(ctx, query, args...); err != nil { return err } } @@ -739,9 +737,9 @@ func (qb *tagQueryBuilder) UpdateParentTags(tagID int, parentIDs []int) error { return nil } -func (qb *tagQueryBuilder) UpdateChildTags(tagID int, childIDs []int) error { +func (qb *tagQueryBuilder) UpdateChildTags(ctx context.Context, tagID int, childIDs []int) error { tx := qb.tx - if _, err := tx.Exec("DELETE FROM tags_relations WHERE parent_id = ?", tagID); err != nil { + if _, err := tx.Exec(ctx, "DELETE FROM tags_relations WHERE parent_id = ?", tagID); err != nil { return err } @@ -754,7 +752,7 @@ func (qb *tagQueryBuilder) UpdateChildTags(tagID int, childIDs []int) error { } query := "INSERT INTO tags_relations (parent_id, child_id) VALUES " + strings.Join(values, ", ") - if _, err := tx.Exec(query, args...); err != nil { + if _, err := tx.Exec(ctx, query, args...); err != nil { return err } } @@ -764,7 +762,7 @@ func (qb *tagQueryBuilder) UpdateChildTags(tagID int, childIDs []int) error { // FindAllAncestors returns a slice of TagPath objects, representing all // ancestors of the tag with the provided id. -func (qb *tagQueryBuilder) FindAllAncestors(tagID int, excludeIDs []int) ([]*models.TagPath, error) { +func (qb *tagQueryBuilder) FindAllAncestors(ctx context.Context, tagID int, excludeIDs []int) ([]*models.TagPath, error) { inBinding := getInBinding(len(excludeIDs) + 1) query := `WITH RECURSIVE @@ -783,7 +781,7 @@ SELECT t.*, p.path FROM tags t INNER JOIN parents p ON t.id = p.parent_id } args := []interface{}{tagID} args = append(args, append(append(excludeArgs, excludeArgs...), excludeArgs...)...) - if err := qb.query(query, args, &ret); err != nil { + if err := qb.query(ctx, query, args, &ret); err != nil { return nil, err } @@ -792,7 +790,7 @@ SELECT t.*, p.path FROM tags t INNER JOIN parents p ON t.id = p.parent_id // FindAllDescendants returns a slice of TagPath objects, representing all // descendants of the tag with the provided id. -func (qb *tagQueryBuilder) FindAllDescendants(tagID int, excludeIDs []int) ([]*models.TagPath, error) { +func (qb *tagQueryBuilder) FindAllDescendants(ctx context.Context, tagID int, excludeIDs []int) ([]*models.TagPath, error) { inBinding := getInBinding(len(excludeIDs) + 1) query := `WITH RECURSIVE @@ -811,7 +809,7 @@ SELECT t.*, c.path FROM tags t INNER JOIN children c ON t.id = c.child_id } args := []interface{}{tagID} args = append(args, append(append(excludeArgs, excludeArgs...), excludeArgs...)...) - if err := qb.query(query, args, &ret); err != nil { + if err := qb.query(ctx, query, args, &ret); err != nil { return nil, err } diff --git a/pkg/sqlite/tag_test.go b/pkg/sqlite/tag_test.go index a5ed8d966..0bc0fb5c0 100644 --- a/pkg/sqlite/tag_test.go +++ b/pkg/sqlite/tag_test.go @@ -4,6 +4,7 @@ package sqlite_test import ( + "context" "database/sql" "fmt" "math" @@ -12,16 +13,17 @@ import ( "testing" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/sqlite" "github.com/stretchr/testify/assert" ) func TestMarkerFindBySceneMarkerID(t *testing.T) { - withTxn(func(r models.Repository) error { - tqb := r.Tag() + withTxn(func(ctx context.Context) error { + tqb := sqlite.TagReaderWriter markerID := markerIDs[markerIdxWithTag] - tags, err := tqb.FindBySceneMarkerID(markerID) + tags, err := tqb.FindBySceneMarkerID(ctx, markerID) if err != nil { t.Errorf("Error finding tags: %s", err.Error()) @@ -30,7 +32,7 @@ func TestMarkerFindBySceneMarkerID(t *testing.T) { assert.Len(t, tags, 1) assert.Equal(t, tagIDs[tagIdxWithMarkers], tags[0].ID) - tags, err = tqb.FindBySceneMarkerID(0) + tags, err = tqb.FindBySceneMarkerID(ctx, 0) if err != nil { t.Errorf("Error finding tags: %s", err.Error()) @@ -43,12 +45,12 @@ func TestMarkerFindBySceneMarkerID(t *testing.T) { } func TestTagFindByName(t *testing.T) { - withTxn(func(r models.Repository) error { - tqb := r.Tag() + withTxn(func(ctx context.Context) error { + tqb := sqlite.TagReaderWriter name := tagNames[tagIdxWithScene] // find a tag by name - tag, err := tqb.FindByName(name, false) + tag, err := tqb.FindByName(ctx, name, false) if err != nil { t.Errorf("Error finding tags: %s", err.Error()) @@ -58,7 +60,7 @@ func TestTagFindByName(t *testing.T) { name = tagNames[tagIdxWithDupName] // find a tag by name nocase - tag, err = tqb.FindByName(name, true) + tag, err = tqb.FindByName(ctx, name, true) if err != nil { t.Errorf("Error finding tags: %s", err.Error()) @@ -74,15 +76,15 @@ func TestTagFindByName(t *testing.T) { } func TestTagQueryIgnoreAutoTag(t *testing.T) { - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { ignoreAutoTag := true tagFilter := models.TagFilterType{ IgnoreAutoTag: &ignoreAutoTag, } - sqb := r.Tag() + sqb := sqlite.TagReaderWriter - tags := queryTags(t, sqb, &tagFilter, nil) + tags := queryTags(ctx, t, sqb, &tagFilter, nil) assert.Len(t, tags, int(math.Ceil(float64(totalTags)/5))) for _, s := range tags { @@ -94,12 +96,12 @@ func TestTagQueryIgnoreAutoTag(t *testing.T) { } func TestTagQueryForAutoTag(t *testing.T) { - withTxn(func(r models.Repository) error { - tqb := r.Tag() + withTxn(func(ctx context.Context) error { + tqb := sqlite.TagReaderWriter name := tagNames[tagIdx1WithScene] // find a tag by name - tags, err := tqb.QueryForAutoTag([]string{name}) + tags, err := tqb.QueryForAutoTag(ctx, []string{name}) if err != nil { t.Errorf("Error finding tags: %s", err.Error()) @@ -112,7 +114,7 @@ func TestTagQueryForAutoTag(t *testing.T) { // find by alias name = getTagStringValue(tagIdx1WithScene, "Alias") - tags, err = tqb.QueryForAutoTag([]string{name}) + tags, err = tqb.QueryForAutoTag(ctx, []string{name}) if err != nil { t.Errorf("Error finding tags: %s", err.Error()) @@ -128,19 +130,19 @@ func TestTagQueryForAutoTag(t *testing.T) { func TestTagFindByNames(t *testing.T) { var names []string - withTxn(func(r models.Repository) error { - tqb := r.Tag() + withTxn(func(ctx context.Context) error { + tqb := sqlite.TagReaderWriter names = append(names, tagNames[tagIdxWithScene]) // find tags by names - tags, err := tqb.FindByNames(names, false) + tags, err := tqb.FindByNames(ctx, names, false) if err != nil { t.Errorf("Error finding tags: %s", err.Error()) } assert.Len(t, tags, 1) assert.Equal(t, tagNames[tagIdxWithScene], tags[0].Name) - tags, err = tqb.FindByNames(names, true) // find tags by names nocase + tags, err = tqb.FindByNames(ctx, names, true) // find tags by names nocase if err != nil { t.Errorf("Error finding tags: %s", err.Error()) } @@ -150,7 +152,7 @@ func TestTagFindByNames(t *testing.T) { names = append(names, tagNames[tagIdx1WithScene]) // find tags by names ( 2 names ) - tags, err = tqb.FindByNames(names, false) + tags, err = tqb.FindByNames(ctx, names, false) if err != nil { t.Errorf("Error finding tags: %s", err.Error()) } @@ -158,7 +160,7 @@ func TestTagFindByNames(t *testing.T) { assert.Equal(t, tagNames[tagIdxWithScene], tags[0].Name) assert.Equal(t, tagNames[tagIdx1WithScene], tags[1].Name) - tags, err = tqb.FindByNames(names, true) // find tags by names ( 2 names nocase) + tags, err = tqb.FindByNames(ctx, names, true) // find tags by names ( 2 names nocase) if err != nil { t.Errorf("Error finding tags: %s", err.Error()) } @@ -173,8 +175,8 @@ func TestTagFindByNames(t *testing.T) { } func TestTagQuerySort(t *testing.T) { - withTxn(func(r models.Repository) error { - sqb := r.Tag() + withTxn(func(ctx context.Context) error { + sqb := sqlite.TagReaderWriter sortBy := "scenes_count" dir := models.SortDirectionEnumDesc @@ -183,24 +185,24 @@ func TestTagQuerySort(t *testing.T) { Direction: &dir, } - tags := queryTags(t, sqb, nil, findFilter) + tags := queryTags(ctx, t, sqb, nil, findFilter) assert := assert.New(t) assert.Equal(tagIDs[tagIdxWithScene], tags[0].ID) sortBy = "scene_markers_count" - tags = queryTags(t, sqb, nil, findFilter) + tags = queryTags(ctx, t, sqb, nil, findFilter) assert.Equal(tagIDs[tagIdxWithMarkers], tags[0].ID) sortBy = "images_count" - tags = queryTags(t, sqb, nil, findFilter) + tags = queryTags(ctx, t, sqb, nil, findFilter) assert.Equal(tagIDs[tagIdxWithImage], tags[0].ID) sortBy = "galleries_count" - tags = queryTags(t, sqb, nil, findFilter) + tags = queryTags(ctx, t, sqb, nil, findFilter) assert.Equal(tagIDs[tagIdxWithGallery], tags[0].ID) sortBy = "performers_count" - tags = queryTags(t, sqb, nil, findFilter) + tags = queryTags(ctx, t, sqb, nil, findFilter) assert.Equal(tagIDs[tagIdxWithPerformer], tags[0].ID) return nil @@ -220,7 +222,7 @@ func TestTagQueryName(t *testing.T) { Name: nameCriterion, } - verifyFn := func(tag *models.Tag, r models.Repository) { + verifyFn := func(ctx context.Context, tag *models.Tag) { verifyString(t, tag.Name, *nameCriterion) } @@ -250,8 +252,8 @@ func TestTagQueryAlias(t *testing.T) { Aliases: aliasCriterion, } - verifyFn := func(tag *models.Tag, r models.Repository) { - aliases, err := r.Tag().GetAliases(tag.ID) + verifyFn := func(ctx context.Context, tag *models.Tag) { + aliases, err := sqlite.TagReaderWriter.GetAliases(ctx, tag.ID) if err != nil { t.Errorf("Error querying tags: %s", err.Error()) } @@ -277,23 +279,23 @@ func TestTagQueryAlias(t *testing.T) { verifyTagQuery(t, tagFilter, nil, verifyFn) } -func verifyTagQuery(t *testing.T, tagFilter *models.TagFilterType, findFilter *models.FindFilterType, verifyFn func(t *models.Tag, r models.Repository)) { - withTxn(func(r models.Repository) error { - sqb := r.Tag() +func verifyTagQuery(t *testing.T, tagFilter *models.TagFilterType, findFilter *models.FindFilterType, verifyFn func(ctx context.Context, t *models.Tag)) { + withTxn(func(ctx context.Context) error { + sqb := sqlite.TagReaderWriter - tags := queryTags(t, sqb, tagFilter, findFilter) + tags := queryTags(ctx, t, sqb, tagFilter, findFilter) for _, tag := range tags { - verifyFn(tag, r) + verifyFn(ctx, tag) } return nil }) } -func queryTags(t *testing.T, qb models.TagReader, tagFilter *models.TagFilterType, findFilter *models.FindFilterType) []*models.Tag { +func queryTags(ctx context.Context, t *testing.T, qb models.TagReader, tagFilter *models.TagFilterType, findFilter *models.FindFilterType) []*models.Tag { t.Helper() - tags, _, err := qb.Query(tagFilter, findFilter) + tags, _, err := qb.Query(ctx, tagFilter, findFilter) if err != nil { t.Errorf("Error querying tags: %s", err.Error()) } @@ -302,8 +304,8 @@ func queryTags(t *testing.T, qb models.TagReader, tagFilter *models.TagFilterTyp } func TestTagQueryIsMissingImage(t *testing.T) { - withTxn(func(r models.Repository) error { - qb := r.Tag() + withTxn(func(ctx context.Context) error { + qb := sqlite.TagReaderWriter isMissing := "image" tagFilter := models.TagFilterType{ IsMissing: &isMissing, @@ -314,7 +316,7 @@ func TestTagQueryIsMissingImage(t *testing.T) { Q: &q, } - tags, _, err := qb.Query(&tagFilter, &findFilter) + tags, _, err := qb.Query(ctx, &tagFilter, &findFilter) if err != nil { t.Errorf("Error querying tag: %s", err.Error()) } @@ -322,7 +324,7 @@ func TestTagQueryIsMissingImage(t *testing.T) { assert.Len(t, tags, 0) findFilter.Q = nil - tags, _, err = qb.Query(&tagFilter, &findFilter) + tags, _, err = qb.Query(ctx, &tagFilter, &findFilter) if err != nil { t.Errorf("Error querying tag: %s", err.Error()) } @@ -356,13 +358,13 @@ func TestTagQuerySceneCount(t *testing.T) { } func verifyTagSceneCount(t *testing.T, sceneCountCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - qb := r.Tag() + withTxn(func(ctx context.Context) error { + qb := sqlite.TagReaderWriter tagFilter := models.TagFilterType{ SceneCount: &sceneCountCriterion, } - tags, _, err := qb.Query(&tagFilter, nil) + tags, _, err := qb.Query(ctx, &tagFilter, nil) if err != nil { t.Errorf("Error querying tag: %s", err.Error()) } @@ -398,13 +400,13 @@ func TestTagQueryMarkerCount(t *testing.T) { } func verifyTagMarkerCount(t *testing.T, markerCountCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - qb := r.Tag() + withTxn(func(ctx context.Context) error { + qb := sqlite.TagReaderWriter tagFilter := models.TagFilterType{ MarkerCount: &markerCountCriterion, } - tags, _, err := qb.Query(&tagFilter, nil) + tags, _, err := qb.Query(ctx, &tagFilter, nil) if err != nil { t.Errorf("Error querying tag: %s", err.Error()) } @@ -440,13 +442,13 @@ func TestTagQueryImageCount(t *testing.T) { } func verifyTagImageCount(t *testing.T, imageCountCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - qb := r.Tag() + withTxn(func(ctx context.Context) error { + qb := sqlite.TagReaderWriter tagFilter := models.TagFilterType{ ImageCount: &imageCountCriterion, } - tags, _, err := qb.Query(&tagFilter, nil) + tags, _, err := qb.Query(ctx, &tagFilter, nil) if err != nil { t.Errorf("Error querying tag: %s", err.Error()) } @@ -482,13 +484,13 @@ func TestTagQueryGalleryCount(t *testing.T) { } func verifyTagGalleryCount(t *testing.T, imageCountCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - qb := r.Tag() + withTxn(func(ctx context.Context) error { + qb := sqlite.TagReaderWriter tagFilter := models.TagFilterType{ GalleryCount: &imageCountCriterion, } - tags, _, err := qb.Query(&tagFilter, nil) + tags, _, err := qb.Query(ctx, &tagFilter, nil) if err != nil { t.Errorf("Error querying tag: %s", err.Error()) } @@ -524,13 +526,13 @@ func TestTagQueryPerformerCount(t *testing.T) { } func verifyTagPerformerCount(t *testing.T, imageCountCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - qb := r.Tag() + withTxn(func(ctx context.Context) error { + qb := sqlite.TagReaderWriter tagFilter := models.TagFilterType{ PerformerCount: &imageCountCriterion, } - tags, _, err := qb.Query(&tagFilter, nil) + tags, _, err := qb.Query(ctx, &tagFilter, nil) if err != nil { t.Errorf("Error querying tag: %s", err.Error()) } @@ -566,13 +568,13 @@ func TestTagQueryParentCount(t *testing.T) { } func verifyTagParentCount(t *testing.T, sceneCountCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - qb := r.Tag() + withTxn(func(ctx context.Context) error { + qb := sqlite.TagReaderWriter tagFilter := models.TagFilterType{ ParentCount: &sceneCountCriterion, } - tags := queryTags(t, qb, &tagFilter, nil) + tags := queryTags(ctx, t, qb, &tagFilter, nil) if len(tags) == 0 { t.Error("Expected at least one tag") @@ -609,13 +611,13 @@ func TestTagQueryChildCount(t *testing.T) { } func verifyTagChildCount(t *testing.T, sceneCountCriterion models.IntCriterionInput) { - withTxn(func(r models.Repository) error { - qb := r.Tag() + withTxn(func(ctx context.Context) error { + qb := sqlite.TagReaderWriter tagFilter := models.TagFilterType{ ChildCount: &sceneCountCriterion, } - tags := queryTags(t, qb, &tagFilter, nil) + tags := queryTags(ctx, t, qb, &tagFilter, nil) if len(tags) == 0 { t.Error("Expected at least one tag") @@ -633,9 +635,9 @@ func verifyTagChildCount(t *testing.T, sceneCountCriterion models.IntCriterionIn } func TestTagQueryParent(t *testing.T) { - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { const nameField = "Name" - sqb := r.Tag() + sqb := sqlite.TagReaderWriter tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithChildTag]), @@ -647,7 +649,7 @@ func TestTagQueryParent(t *testing.T) { Parents: &tagCriterion, } - tags := queryTags(t, sqb, &tagFilter, nil) + tags := queryTags(ctx, t, sqb, &tagFilter, nil) assert.Len(t, tags, 1) @@ -661,7 +663,7 @@ func TestTagQueryParent(t *testing.T) { Q: &q, } - tags = queryTags(t, sqb, &tagFilter, &findFilter) + tags = queryTags(ctx, t, sqb, &tagFilter, &findFilter) assert.Len(t, tags, 0) depth := -1 @@ -674,12 +676,12 @@ func TestTagQueryParent(t *testing.T) { Depth: &depth, } - tags = queryTags(t, sqb, &tagFilter, nil) + tags = queryTags(ctx, t, sqb, &tagFilter, nil) assert.Len(t, tags, 2) depth = 1 - tags = queryTags(t, sqb, &tagFilter, nil) + tags = queryTags(ctx, t, sqb, &tagFilter, nil) assert.Len(t, tags, 2) tagCriterion = models.HierarchicalMultiCriterionInput{ @@ -687,22 +689,22 @@ func TestTagQueryParent(t *testing.T) { } q = getTagStringValue(tagIdxWithGallery, nameField) - tags = queryTags(t, sqb, &tagFilter, &findFilter) + tags = queryTags(ctx, t, sqb, &tagFilter, &findFilter) assert.Len(t, tags, 1) assert.Equal(t, tagIDs[tagIdxWithGallery], tags[0].ID) q = getTagStringValue(tagIdxWithParentTag, nameField) - tags = queryTags(t, sqb, &tagFilter, &findFilter) + tags = queryTags(ctx, t, sqb, &tagFilter, &findFilter) assert.Len(t, tags, 0) tagCriterion.Modifier = models.CriterionModifierNotNull - tags = queryTags(t, sqb, &tagFilter, &findFilter) + tags = queryTags(ctx, t, sqb, &tagFilter, &findFilter) assert.Len(t, tags, 1) assert.Equal(t, tagIDs[tagIdxWithParentTag], tags[0].ID) q = getTagStringValue(tagIdxWithGallery, nameField) - tags = queryTags(t, sqb, &tagFilter, &findFilter) + tags = queryTags(ctx, t, sqb, &tagFilter, &findFilter) assert.Len(t, tags, 0) return nil @@ -710,10 +712,10 @@ func TestTagQueryParent(t *testing.T) { } func TestTagQueryChild(t *testing.T) { - withTxn(func(r models.Repository) error { + withTxn(func(ctx context.Context) error { const nameField = "Name" - sqb := r.Tag() + sqb := sqlite.TagReaderWriter tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithParentTag]), @@ -725,7 +727,7 @@ func TestTagQueryChild(t *testing.T) { Children: &tagCriterion, } - tags := queryTags(t, sqb, &tagFilter, nil) + tags := queryTags(ctx, t, sqb, &tagFilter, nil) assert.Len(t, tags, 1) @@ -739,7 +741,7 @@ func TestTagQueryChild(t *testing.T) { Q: &q, } - tags = queryTags(t, sqb, &tagFilter, &findFilter) + tags = queryTags(ctx, t, sqb, &tagFilter, &findFilter) assert.Len(t, tags, 0) depth := -1 @@ -752,12 +754,12 @@ func TestTagQueryChild(t *testing.T) { Depth: &depth, } - tags = queryTags(t, sqb, &tagFilter, nil) + tags = queryTags(ctx, t, sqb, &tagFilter, nil) assert.Len(t, tags, 2) depth = 1 - tags = queryTags(t, sqb, &tagFilter, nil) + tags = queryTags(ctx, t, sqb, &tagFilter, nil) assert.Len(t, tags, 2) tagCriterion = models.HierarchicalMultiCriterionInput{ @@ -765,22 +767,22 @@ func TestTagQueryChild(t *testing.T) { } q = getTagStringValue(tagIdxWithGallery, nameField) - tags = queryTags(t, sqb, &tagFilter, &findFilter) + tags = queryTags(ctx, t, sqb, &tagFilter, &findFilter) assert.Len(t, tags, 1) assert.Equal(t, tagIDs[tagIdxWithGallery], tags[0].ID) q = getTagStringValue(tagIdxWithChildTag, nameField) - tags = queryTags(t, sqb, &tagFilter, &findFilter) + tags = queryTags(ctx, t, sqb, &tagFilter, &findFilter) assert.Len(t, tags, 0) tagCriterion.Modifier = models.CriterionModifierNotNull - tags = queryTags(t, sqb, &tagFilter, &findFilter) + tags = queryTags(ctx, t, sqb, &tagFilter, &findFilter) assert.Len(t, tags, 1) assert.Equal(t, tagIDs[tagIdxWithChildTag], tags[0].ID) q = getTagStringValue(tagIdxWithGallery, nameField) - tags = queryTags(t, sqb, &tagFilter, &findFilter) + tags = queryTags(ctx, t, sqb, &tagFilter, &findFilter) assert.Len(t, tags, 0) return nil @@ -788,34 +790,34 @@ func TestTagQueryChild(t *testing.T) { } func TestTagUpdateTagImage(t *testing.T) { - if err := withTxn(func(r models.Repository) error { - qb := r.Tag() + if err := withTxn(func(ctx context.Context) error { + qb := sqlite.TagReaderWriter // create tag to test against const name = "TestTagUpdateTagImage" tag := models.Tag{ Name: name, } - created, err := qb.Create(tag) + created, err := qb.Create(ctx, tag) if err != nil { return fmt.Errorf("Error creating tag: %s", err.Error()) } image := []byte("image") - err = qb.UpdateImage(created.ID, image) + err = qb.UpdateImage(ctx, created.ID, image) if err != nil { return fmt.Errorf("Error updating studio image: %s", err.Error()) } // ensure image set - storedImage, err := qb.GetImage(created.ID) + storedImage, err := qb.GetImage(ctx, created.ID) if err != nil { return fmt.Errorf("Error getting image: %s", err.Error()) } assert.Equal(t, storedImage, image) // set nil image - err = qb.UpdateImage(created.ID, nil) + err = qb.UpdateImage(ctx, created.ID, nil) if err == nil { return fmt.Errorf("Expected error setting nil image") } @@ -827,32 +829,32 @@ func TestTagUpdateTagImage(t *testing.T) { } func TestTagDestroyTagImage(t *testing.T) { - if err := withTxn(func(r models.Repository) error { - qb := r.Tag() + if err := withTxn(func(ctx context.Context) error { + qb := sqlite.TagReaderWriter // create performer to test against const name = "TestTagDestroyTagImage" tag := models.Tag{ Name: name, } - created, err := qb.Create(tag) + created, err := qb.Create(ctx, tag) if err != nil { return fmt.Errorf("Error creating tag: %s", err.Error()) } image := []byte("image") - err = qb.UpdateImage(created.ID, image) + err = qb.UpdateImage(ctx, created.ID, image) if err != nil { return fmt.Errorf("Error updating studio image: %s", err.Error()) } - err = qb.DestroyImage(created.ID) + err = qb.DestroyImage(ctx, created.ID) if err != nil { return fmt.Errorf("Error destroying studio image: %s", err.Error()) } // image should be nil - storedImage, err := qb.GetImage(created.ID) + storedImage, err := qb.GetImage(ctx, created.ID) if err != nil { return fmt.Errorf("Error getting image: %s", err.Error()) } @@ -865,27 +867,27 @@ func TestTagDestroyTagImage(t *testing.T) { } func TestTagUpdateAlias(t *testing.T) { - if err := withTxn(func(r models.Repository) error { - qb := r.Tag() + if err := withTxn(func(ctx context.Context) error { + qb := sqlite.TagReaderWriter // create tag to test against const name = "TestTagUpdateAlias" tag := models.Tag{ Name: name, } - created, err := qb.Create(tag) + created, err := qb.Create(ctx, tag) if err != nil { return fmt.Errorf("Error creating tag: %s", err.Error()) } aliases := []string{"alias1", "alias2"} - err = qb.UpdateAliases(created.ID, aliases) + err = qb.UpdateAliases(ctx, created.ID, aliases) if err != nil { return fmt.Errorf("Error updating tag aliases: %s", err.Error()) } // ensure aliases set - storedAliases, err := qb.GetAliases(created.ID) + storedAliases, err := qb.GetAliases(ctx, created.ID) if err != nil { return fmt.Errorf("Error getting aliases: %s", err.Error()) } @@ -901,11 +903,11 @@ func TestTagMerge(t *testing.T) { assert := assert.New(t) // merge tests - perform these in a transaction that we'll rollback - if err := withRollbackTxn(func(r models.Repository) error { - qb := r.Tag() + if err := withRollbackTxn(func(ctx context.Context) error { + qb := sqlite.TagReaderWriter // try merging into same tag - err := qb.Merge([]int{tagIDs[tagIdx1WithScene]}, tagIDs[tagIdx1WithScene]) + err := qb.Merge(ctx, []int{tagIDs[tagIdx1WithScene]}, tagIDs[tagIdx1WithScene]) assert.NotNil(err) // merge everything into tagIdxWithScene @@ -931,13 +933,13 @@ func TestTagMerge(t *testing.T) { } destID := tagIDs[tagIdxWithScene] - if err = qb.Merge(srcIDs, destID); err != nil { + if err = qb.Merge(ctx, srcIDs, destID); err != nil { return err } // ensure other tags are deleted for _, tagId := range srcIDs { - t, err := qb.Find(tagId) + t, err := qb.Find(ctx, tagId) if err != nil { return err } @@ -946,7 +948,7 @@ func TestTagMerge(t *testing.T) { } // ensure aliases are set on the destination - destAliases, err := qb.GetAliases(destID) + destAliases, err := qb.GetAliases(ctx, destID) if err != nil { return err } @@ -955,7 +957,7 @@ func TestTagMerge(t *testing.T) { } // ensure scene points to new tag - sceneTagIDs, err := r.Scene().GetTagIDs(sceneIDs[sceneIdxWithTwoTags]) + sceneTagIDs, err := sqlite.SceneReaderWriter.GetTagIDs(ctx, sceneIDs[sceneIdxWithTwoTags]) if err != nil { return err } @@ -963,14 +965,14 @@ func TestTagMerge(t *testing.T) { assert.Contains(sceneTagIDs, destID) // ensure marker points to new tag - marker, err := r.SceneMarker().Find(markerIDs[markerIdxWithTag]) + marker, err := sqlite.SceneMarkerReaderWriter.Find(ctx, markerIDs[markerIdxWithTag]) if err != nil { return err } assert.Equal(destID, marker.PrimaryTagID) - markerTagIDs, err := r.SceneMarker().GetTagIDs(marker.ID) + markerTagIDs, err := sqlite.SceneMarkerReaderWriter.GetTagIDs(ctx, marker.ID) if err != nil { return err } @@ -978,7 +980,7 @@ func TestTagMerge(t *testing.T) { assert.Contains(markerTagIDs, destID) // ensure image points to new tag - imageTagIDs, err := r.Image().GetTagIDs(imageIDs[imageIdxWithTwoTags]) + imageTagIDs, err := sqlite.ImageReaderWriter.GetTagIDs(ctx, imageIDs[imageIdxWithTwoTags]) if err != nil { return err } @@ -986,7 +988,7 @@ func TestTagMerge(t *testing.T) { assert.Contains(imageTagIDs, destID) // ensure gallery points to new tag - galleryTagIDs, err := r.Gallery().GetTagIDs(galleryIDs[galleryIdxWithTwoTags]) + galleryTagIDs, err := sqlite.GalleryReaderWriter.GetTagIDs(ctx, galleryIDs[galleryIdxWithTwoTags]) if err != nil { return err } @@ -994,7 +996,7 @@ func TestTagMerge(t *testing.T) { assert.Contains(galleryTagIDs, destID) // ensure performer points to new tag - performerTagIDs, err := r.Gallery().GetTagIDs(performerIDs[performerIdxWithTwoTags]) + performerTagIDs, err := sqlite.GalleryReaderWriter.GetTagIDs(ctx, performerIDs[performerIdxWithTwoTags]) if err != nil { return err } diff --git a/pkg/sqlite/transaction.go b/pkg/sqlite/transaction.go index 50486d01e..23f9b27b3 100644 --- a/pkg/sqlite/transaction.go +++ b/pkg/sqlite/transaction.go @@ -2,209 +2,67 @@ package sqlite import ( "context" - "database/sql" - "errors" "fmt" "github.com/jmoiron/sqlx" - "github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/models" ) -type dbi interface { - Get(dest interface{}, query string, args ...interface{}) error - Select(dest interface{}, query string, args ...interface{}) error - Queryx(query string, args ...interface{}) (*sqlx.Rows, error) - NamedExec(query string, arg interface{}) (sql.Result, error) - Exec(query string, args ...interface{}) (sql.Result, error) -} +type key int -type transaction struct { - Ctx context.Context - tx *sqlx.Tx -} +const ( + txnKey key = iota + 1 +) -func (t *transaction) Begin() error { - if t.tx != nil { - return errors.New("transaction already begun") +func (db *Database) Begin(ctx context.Context) (context.Context, error) { + if tx, _ := getTx(ctx); tx != nil { + return nil, fmt.Errorf("already in transaction") } - if err := database.Ready(); err != nil { + tx, err := db.db.BeginTxx(ctx, nil) + if err != nil { + return nil, fmt.Errorf("beginning transaction: %w", err) + } + + return context.WithValue(ctx, txnKey, tx), nil +} + +func (db *Database) Commit(ctx context.Context) error { + tx, err := getTx(ctx) + if err != nil { return err } + return tx.Commit() +} - var err error - t.tx, err = database.DB.BeginTxx(t.Ctx, nil) +func (db *Database) Rollback(ctx context.Context) error { + tx, err := getTx(ctx) if err != nil { - return fmt.Errorf("error starting transaction: %v", err) - } - - return nil -} - -func (t *transaction) Rollback() error { - if t.tx == nil { - return errors.New("not in transaction") - } - - err := t.tx.Rollback() - if err != nil { - return fmt.Errorf("error rolling back transaction: %v", err) - } - t.tx = nil - - return nil -} - -func (t *transaction) Commit() error { - if t.tx == nil { - return errors.New("not in transaction") - } - - err := t.tx.Commit() - if err != nil { - return fmt.Errorf("error committing transaction: %v", err) - } - t.tx = nil - - return nil -} - -func (t *transaction) Repository() models.Repository { - return t -} - -func (t *transaction) ensureTx() { - if t.tx == nil { - panic("tx is nil") - } -} - -func (t *transaction) Gallery() models.GalleryReaderWriter { - t.ensureTx() - return NewGalleryReaderWriter(t.tx) -} - -func (t *transaction) Image() models.ImageReaderWriter { - t.ensureTx() - return NewImageReaderWriter(t.tx) -} - -func (t *transaction) Movie() models.MovieReaderWriter { - t.ensureTx() - return NewMovieReaderWriter(t.tx) -} - -func (t *transaction) Performer() models.PerformerReaderWriter { - t.ensureTx() - return NewPerformerReaderWriter(t.tx) -} - -func (t *transaction) SceneMarker() models.SceneMarkerReaderWriter { - t.ensureTx() - return NewSceneMarkerReaderWriter(t.tx) -} - -func (t *transaction) Scene() models.SceneReaderWriter { - t.ensureTx() - return NewSceneReaderWriter(t.tx) -} - -func (t *transaction) ScrapedItem() models.ScrapedItemReaderWriter { - t.ensureTx() - return NewScrapedItemReaderWriter(t.tx) -} - -func (t *transaction) Studio() models.StudioReaderWriter { - t.ensureTx() - return NewStudioReaderWriter(t.tx) -} - -func (t *transaction) Tag() models.TagReaderWriter { - t.ensureTx() - return NewTagReaderWriter(t.tx) -} - -func (t *transaction) SavedFilter() models.SavedFilterReaderWriter { - t.ensureTx() - return NewSavedFilterReaderWriter(t.tx) -} - -type ReadTransaction struct{} - -func (t *ReadTransaction) Begin() error { - if err := database.Ready(); err != nil { return err } - - return nil + return tx.Rollback() } -func (t *ReadTransaction) Rollback() error { - return nil +func getTx(ctx context.Context) (*sqlx.Tx, error) { + tx, ok := ctx.Value(txnKey).(*sqlx.Tx) + if !ok || tx == nil { + return nil, fmt.Errorf("not in transaction") + } + return tx, nil } -func (t *ReadTransaction) Commit() error { - return nil -} - -func (t *ReadTransaction) Repository() models.ReaderRepository { - return t -} - -func (t *ReadTransaction) Gallery() models.GalleryReader { - return NewGalleryReaderWriter(database.DB) -} - -func (t *ReadTransaction) Image() models.ImageReader { - return NewImageReaderWriter(database.DB) -} - -func (t *ReadTransaction) Movie() models.MovieReader { - return NewMovieReaderWriter(database.DB) -} - -func (t *ReadTransaction) Performer() models.PerformerReader { - return NewPerformerReaderWriter(database.DB) -} - -func (t *ReadTransaction) SceneMarker() models.SceneMarkerReader { - return NewSceneMarkerReaderWriter(database.DB) -} - -func (t *ReadTransaction) Scene() models.SceneReader { - return NewSceneReaderWriter(database.DB) -} - -func (t *ReadTransaction) ScrapedItem() models.ScrapedItemReader { - return NewScrapedItemReaderWriter(database.DB) -} - -func (t *ReadTransaction) Studio() models.StudioReader { - return NewStudioReaderWriter(database.DB) -} - -func (t *ReadTransaction) Tag() models.TagReader { - return NewTagReaderWriter(database.DB) -} - -func (t *ReadTransaction) SavedFilter() models.SavedFilterReader { - return NewSavedFilterReaderWriter(database.DB) -} - -type TransactionManager struct { -} - -func NewTransactionManager() *TransactionManager { - return &TransactionManager{} -} - -func (t *TransactionManager) WithTxn(ctx context.Context, fn func(r models.Repository) error) error { - database.WriteMu.Lock() - defer database.WriteMu.Unlock() - return models.WithTxn(&transaction{Ctx: ctx}, fn) -} - -func (t *TransactionManager) WithReadTxn(ctx context.Context, fn func(r models.ReaderRepository) error) error { - return models.WithROTxn(&ReadTransaction{}, fn) +func (db *Database) TxnRepository() models.Repository { + return models.Repository{ + TxnManager: db, + Gallery: GalleryReaderWriter, + Image: ImageReaderWriter, + Movie: MovieReaderWriter, + Performer: PerformerReaderWriter, + Scene: SceneReaderWriter, + SceneMarker: SceneMarkerReaderWriter, + ScrapedItem: ScrapedItemReaderWriter, + Studio: StudioReaderWriter, + Tag: TagReaderWriter, + SavedFilter: SavedFilterReaderWriter, + } } diff --git a/pkg/sqlite/tx.go b/pkg/sqlite/tx.go new file mode 100644 index 000000000..f12b1dd4a --- /dev/null +++ b/pkg/sqlite/tx.go @@ -0,0 +1,55 @@ +package sqlite + +import ( + "context" + "database/sql" + + "github.com/jmoiron/sqlx" +) + +type dbi struct{} + +func (*dbi) Get(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + tx, err := getTx(ctx) + if err != nil { + return err + } + + return tx.Get(dest, query, args...) +} + +func (*dbi) Select(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + tx, err := getTx(ctx) + if err != nil { + return err + } + + return tx.Select(dest, query, args...) +} + +func (*dbi) Queryx(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) { + tx, err := getTx(ctx) + if err != nil { + return nil, err + } + + return tx.Queryx(query, args...) +} + +func (*dbi) NamedExec(ctx context.Context, query string, arg interface{}) (sql.Result, error) { + tx, err := getTx(ctx) + if err != nil { + return nil, err + } + + return tx.NamedExec(query, arg) +} + +func (*dbi) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + tx, err := getTx(ctx) + if err != nil { + return nil, err + } + + return tx.Exec(query, args...) +} diff --git a/pkg/studio/export.go b/pkg/studio/export.go index 951a60417..21272ecc4 100644 --- a/pkg/studio/export.go +++ b/pkg/studio/export.go @@ -1,6 +1,7 @@ package studio import ( + "context" "fmt" "github.com/stashapp/stash/pkg/models" @@ -9,8 +10,15 @@ import ( "github.com/stashapp/stash/pkg/utils" ) +type FinderImageStashIDGetter interface { + Finder + GetAliases(ctx context.Context, studioID int) ([]string, error) + GetImage(ctx context.Context, studioID int) ([]byte, error) + GetStashIDs(ctx context.Context, studioID int) ([]*models.StashID, error) +} + // ToJSON converts a Studio object into its JSON equivalent. -func ToJSON(reader models.StudioReader, studio *models.Studio) (*jsonschema.Studio, error) { +func ToJSON(ctx context.Context, reader FinderImageStashIDGetter, studio *models.Studio) (*jsonschema.Studio, error) { newStudioJSON := jsonschema.Studio{ IgnoreAutoTag: studio.IgnoreAutoTag, CreatedAt: json.JSONTime{Time: studio.CreatedAt.Timestamp}, @@ -30,7 +38,7 @@ func ToJSON(reader models.StudioReader, studio *models.Studio) (*jsonschema.Stud } if studio.ParentID.Valid { - parent, err := reader.Find(int(studio.ParentID.Int64)) + parent, err := reader.Find(ctx, int(studio.ParentID.Int64)) if err != nil { return nil, fmt.Errorf("error getting parent studio: %v", err) } @@ -44,14 +52,14 @@ func ToJSON(reader models.StudioReader, studio *models.Studio) (*jsonschema.Stud newStudioJSON.Rating = int(studio.Rating.Int64) } - aliases, err := reader.GetAliases(studio.ID) + aliases, err := reader.GetAliases(ctx, studio.ID) if err != nil { return nil, fmt.Errorf("error getting studio aliases: %v", err) } newStudioJSON.Aliases = aliases - image, err := reader.GetImage(studio.ID) + image, err := reader.GetImage(ctx, studio.ID) if err != nil { return nil, fmt.Errorf("error getting studio image: %v", err) } @@ -60,7 +68,7 @@ func ToJSON(reader models.StudioReader, studio *models.Studio) (*jsonschema.Stud newStudioJSON.Image = utils.GetBase64StringFromData(image) } - stashIDs, _ := reader.GetStashIDs(studio.ID) + stashIDs, _ := reader.GetStashIDs(ctx, studio.ID) var ret []models.StashID for _, stashID := range stashIDs { newJoin := models.StashID{ diff --git a/pkg/studio/export_test.go b/pkg/studio/export_test.go index a1f261254..b15fbc018 100644 --- a/pkg/studio/export_test.go +++ b/pkg/studio/export_test.go @@ -1,6 +1,7 @@ package studio import ( + "context" "errors" "github.com/stashapp/stash/pkg/models" @@ -169,39 +170,40 @@ func initTestTable() { func TestToJSON(t *testing.T) { initTestTable() + ctx := context.Background() mockStudioReader := &mocks.StudioReaderWriter{} imageErr := errors.New("error getting image") - mockStudioReader.On("GetImage", studioID).Return(imageBytes, nil).Once() - mockStudioReader.On("GetImage", noImageID).Return(nil, nil).Once() - mockStudioReader.On("GetImage", errImageID).Return(nil, imageErr).Once() - mockStudioReader.On("GetImage", missingParentStudioID).Return(imageBytes, nil).Maybe() - mockStudioReader.On("GetImage", errStudioID).Return(imageBytes, nil).Maybe() - mockStudioReader.On("GetImage", errAliasID).Return(imageBytes, nil).Maybe() + mockStudioReader.On("GetImage", ctx, studioID).Return(imageBytes, nil).Once() + mockStudioReader.On("GetImage", ctx, noImageID).Return(nil, nil).Once() + mockStudioReader.On("GetImage", ctx, errImageID).Return(nil, imageErr).Once() + mockStudioReader.On("GetImage", ctx, missingParentStudioID).Return(imageBytes, nil).Maybe() + mockStudioReader.On("GetImage", ctx, errStudioID).Return(imageBytes, nil).Maybe() + mockStudioReader.On("GetImage", ctx, errAliasID).Return(imageBytes, nil).Maybe() parentStudioErr := errors.New("error getting parent studio") - mockStudioReader.On("Find", parentStudioID).Return(&parentStudio, nil) - mockStudioReader.On("Find", missingStudioID).Return(nil, nil) - mockStudioReader.On("Find", errParentStudioID).Return(nil, parentStudioErr) + mockStudioReader.On("Find", ctx, parentStudioID).Return(&parentStudio, nil) + mockStudioReader.On("Find", ctx, missingStudioID).Return(nil, nil) + mockStudioReader.On("Find", ctx, errParentStudioID).Return(nil, parentStudioErr) aliasErr := errors.New("error getting aliases") - mockStudioReader.On("GetAliases", studioID).Return([]string{"alias"}, nil).Once() - mockStudioReader.On("GetAliases", noImageID).Return(nil, nil).Once() - mockStudioReader.On("GetAliases", errImageID).Return(nil, nil).Once() - mockStudioReader.On("GetAliases", missingParentStudioID).Return(nil, nil).Once() - mockStudioReader.On("GetAliases", errAliasID).Return(nil, aliasErr).Once() + mockStudioReader.On("GetAliases", ctx, studioID).Return([]string{"alias"}, nil).Once() + mockStudioReader.On("GetAliases", ctx, noImageID).Return(nil, nil).Once() + mockStudioReader.On("GetAliases", ctx, errImageID).Return(nil, nil).Once() + mockStudioReader.On("GetAliases", ctx, missingParentStudioID).Return(nil, nil).Once() + mockStudioReader.On("GetAliases", ctx, errAliasID).Return(nil, aliasErr).Once() - mockStudioReader.On("GetStashIDs", studioID).Return(stashIDs, nil).Once() - mockStudioReader.On("GetStashIDs", noImageID).Return(nil, nil).Once() - mockStudioReader.On("GetStashIDs", missingParentStudioID).Return(stashIDs, nil).Once() + mockStudioReader.On("GetStashIDs", ctx, studioID).Return(stashIDs, nil).Once() + mockStudioReader.On("GetStashIDs", ctx, noImageID).Return(nil, nil).Once() + mockStudioReader.On("GetStashIDs", ctx, missingParentStudioID).Return(stashIDs, nil).Once() for i, s := range scenarios { studio := s.input - json, err := ToJSON(mockStudioReader, &studio) + json, err := ToJSON(ctx, mockStudioReader, &studio) switch { case !s.err && err != nil: diff --git a/pkg/studio/import.go b/pkg/studio/import.go index a44481982..627d81272 100644 --- a/pkg/studio/import.go +++ b/pkg/studio/import.go @@ -1,6 +1,7 @@ package studio import ( + "context" "database/sql" "errors" "fmt" @@ -11,10 +12,19 @@ import ( "github.com/stashapp/stash/pkg/utils" ) +type NameFinderCreatorUpdater interface { + FindByName(ctx context.Context, name string, nocase bool) (*models.Studio, error) + Create(ctx context.Context, newStudio models.Studio) (*models.Studio, error) + UpdateFull(ctx context.Context, updatedStudio models.Studio) (*models.Studio, error) + UpdateImage(ctx context.Context, studioID int, image []byte) error + UpdateAliases(ctx context.Context, studioID int, aliases []string) error + UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error +} + var ErrParentStudioNotExist = errors.New("parent studio does not exist") type Importer struct { - ReaderWriter models.StudioReaderWriter + ReaderWriter NameFinderCreatorUpdater Input jsonschema.Studio MissingRefBehaviour models.ImportMissingRefEnum @@ -22,7 +32,7 @@ type Importer struct { imageData []byte } -func (i *Importer) PreImport() error { +func (i *Importer) PreImport(ctx context.Context) error { checksum := md5.FromString(i.Input.Name) i.studio = models.Studio{ @@ -36,7 +46,7 @@ func (i *Importer) PreImport() error { Rating: sql.NullInt64{Int64: int64(i.Input.Rating), Valid: true}, } - if err := i.populateParentStudio(); err != nil { + if err := i.populateParentStudio(ctx); err != nil { return err } @@ -51,9 +61,9 @@ func (i *Importer) PreImport() error { return nil } -func (i *Importer) populateParentStudio() error { +func (i *Importer) populateParentStudio(ctx context.Context) error { if i.Input.ParentStudio != "" { - studio, err := i.ReaderWriter.FindByName(i.Input.ParentStudio, false) + studio, err := i.ReaderWriter.FindByName(ctx, i.Input.ParentStudio, false) if err != nil { return fmt.Errorf("error finding studio by name: %v", err) } @@ -68,7 +78,7 @@ func (i *Importer) populateParentStudio() error { } if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate { - parentID, err := i.createParentStudio(i.Input.ParentStudio) + parentID, err := i.createParentStudio(ctx, i.Input.ParentStudio) if err != nil { return err } @@ -85,10 +95,10 @@ func (i *Importer) populateParentStudio() error { return nil } -func (i *Importer) createParentStudio(name string) (int, error) { +func (i *Importer) createParentStudio(ctx context.Context, name string) (int, error) { newStudio := *models.NewStudio(name) - created, err := i.ReaderWriter.Create(newStudio) + created, err := i.ReaderWriter.Create(ctx, newStudio) if err != nil { return 0, err } @@ -96,20 +106,20 @@ func (i *Importer) createParentStudio(name string) (int, error) { return created.ID, nil } -func (i *Importer) PostImport(id int) error { +func (i *Importer) PostImport(ctx context.Context, id int) error { if len(i.imageData) > 0 { - if err := i.ReaderWriter.UpdateImage(id, i.imageData); err != nil { + if err := i.ReaderWriter.UpdateImage(ctx, id, i.imageData); err != nil { return fmt.Errorf("error setting studio image: %v", err) } } if len(i.Input.StashIDs) > 0 { - if err := i.ReaderWriter.UpdateStashIDs(id, i.Input.StashIDs); err != nil { + if err := i.ReaderWriter.UpdateStashIDs(ctx, id, i.Input.StashIDs); err != nil { return fmt.Errorf("error setting stash id: %v", err) } } - if err := i.ReaderWriter.UpdateAliases(id, i.Input.Aliases); err != nil { + if err := i.ReaderWriter.UpdateAliases(ctx, id, i.Input.Aliases); err != nil { return fmt.Errorf("error setting tag aliases: %v", err) } @@ -120,9 +130,9 @@ func (i *Importer) Name() string { return i.Input.Name } -func (i *Importer) FindExistingID() (*int, error) { +func (i *Importer) FindExistingID(ctx context.Context) (*int, error) { const nocase = false - existing, err := i.ReaderWriter.FindByName(i.Name(), nocase) + existing, err := i.ReaderWriter.FindByName(ctx, i.Name(), nocase) if err != nil { return nil, err } @@ -135,8 +145,8 @@ func (i *Importer) FindExistingID() (*int, error) { return nil, nil } -func (i *Importer) Create() (*int, error) { - created, err := i.ReaderWriter.Create(i.studio) +func (i *Importer) Create(ctx context.Context) (*int, error) { + created, err := i.ReaderWriter.Create(ctx, i.studio) if err != nil { return nil, fmt.Errorf("error creating studio: %v", err) } @@ -145,10 +155,10 @@ func (i *Importer) Create() (*int, error) { return &id, nil } -func (i *Importer) Update(id int) error { +func (i *Importer) Update(ctx context.Context, id int) error { studio := i.studio studio.ID = id - _, err := i.ReaderWriter.UpdateFull(studio) + _, err := i.ReaderWriter.UpdateFull(ctx, studio) if err != nil { return fmt.Errorf("error updating existing studio: %v", err) } diff --git a/pkg/studio/import_test.go b/pkg/studio/import_test.go index 87b22519b..fc2ae402b 100644 --- a/pkg/studio/import_test.go +++ b/pkg/studio/import_test.go @@ -1,6 +1,7 @@ package studio import ( + "context" "errors" "testing" @@ -43,21 +44,22 @@ func TestImporterPreImport(t *testing.T) { IgnoreAutoTag: autoTagIgnored, }, } + ctx := context.Background() - err := i.PreImport() + err := i.PreImport(ctx) assert.NotNil(t, err) i.Input.Image = image - err = i.PreImport() + err = i.PreImport(ctx) assert.Nil(t, err) i.Input = *createFullJSONStudio(studioName, image, []string{"alias"}) i.Input.ParentStudio = "" - err = i.PreImport() + err = i.PreImport(ctx) assert.Nil(t, err) expectedStudio := createFullStudio(0, 0) @@ -68,6 +70,7 @@ func TestImporterPreImport(t *testing.T) { func TestImporterPreImportWithParent(t *testing.T) { readerWriter := &mocks.StudioReaderWriter{} + ctx := context.Background() i := Importer{ ReaderWriter: readerWriter, @@ -78,17 +81,17 @@ func TestImporterPreImportWithParent(t *testing.T) { }, } - readerWriter.On("FindByName", existingParentStudioName, false).Return(&models.Studio{ + readerWriter.On("FindByName", ctx, existingParentStudioName, false).Return(&models.Studio{ ID: existingStudioID, }, nil).Once() - readerWriter.On("FindByName", existingParentStudioErr, false).Return(nil, errors.New("FindByName error")).Once() + readerWriter.On("FindByName", ctx, existingParentStudioErr, false).Return(nil, errors.New("FindByName error")).Once() - err := i.PreImport() + err := i.PreImport(ctx) assert.Nil(t, err) assert.Equal(t, int64(existingStudioID), i.studio.ParentID.Int64) i.Input.ParentStudio = existingParentStudioErr - err = i.PreImport() + err = i.PreImport(ctx) assert.NotNil(t, err) readerWriter.AssertExpectations(t) @@ -96,6 +99,7 @@ func TestImporterPreImportWithParent(t *testing.T) { func TestImporterPreImportWithMissingParent(t *testing.T) { readerWriter := &mocks.StudioReaderWriter{} + ctx := context.Background() i := Importer{ ReaderWriter: readerWriter, @@ -107,20 +111,20 @@ func TestImporterPreImportWithMissingParent(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumFail, } - readerWriter.On("FindByName", missingParentStudioName, false).Return(nil, nil).Times(3) - readerWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(&models.Studio{ + readerWriter.On("FindByName", ctx, missingParentStudioName, false).Return(nil, nil).Times(3) + readerWriter.On("Create", ctx, mock.AnythingOfType("models.Studio")).Return(&models.Studio{ ID: existingStudioID, }, nil) - err := i.PreImport() + err := i.PreImport(ctx) assert.NotNil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore - err = i.PreImport() + err = i.PreImport(ctx) assert.Nil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumCreate - err = i.PreImport() + err = i.PreImport(ctx) assert.Nil(t, err) assert.Equal(t, int64(existingStudioID), i.studio.ParentID.Int64) @@ -129,6 +133,7 @@ func TestImporterPreImportWithMissingParent(t *testing.T) { func TestImporterPreImportWithMissingParentCreateErr(t *testing.T) { readerWriter := &mocks.StudioReaderWriter{} + ctx := context.Background() i := Importer{ ReaderWriter: readerWriter, @@ -140,15 +145,16 @@ func TestImporterPreImportWithMissingParentCreateErr(t *testing.T) { MissingRefBehaviour: models.ImportMissingRefEnumCreate, } - readerWriter.On("FindByName", missingParentStudioName, false).Return(nil, nil).Once() - readerWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error")) + readerWriter.On("FindByName", ctx, missingParentStudioName, false).Return(nil, nil).Once() + readerWriter.On("Create", ctx, mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error")) - err := i.PreImport() + err := i.PreImport(ctx) assert.NotNil(t, err) } func TestImporterPostImport(t *testing.T) { readerWriter := &mocks.StudioReaderWriter{} + ctx := context.Background() i := Importer{ ReaderWriter: readerWriter, @@ -161,21 +167,21 @@ func TestImporterPostImport(t *testing.T) { updateStudioImageErr := errors.New("UpdateImage error") updateTagAliasErr := errors.New("UpdateAlias error") - readerWriter.On("UpdateImage", studioID, imageBytes).Return(nil).Once() - readerWriter.On("UpdateImage", errImageID, imageBytes).Return(updateStudioImageErr).Once() - readerWriter.On("UpdateImage", errAliasID, imageBytes).Return(nil).Once() + readerWriter.On("UpdateImage", ctx, studioID, imageBytes).Return(nil).Once() + readerWriter.On("UpdateImage", ctx, errImageID, imageBytes).Return(updateStudioImageErr).Once() + readerWriter.On("UpdateImage", ctx, errAliasID, imageBytes).Return(nil).Once() - readerWriter.On("UpdateAliases", studioID, i.Input.Aliases).Return(nil).Once() - readerWriter.On("UpdateAliases", errImageID, i.Input.Aliases).Return(nil).Maybe() - readerWriter.On("UpdateAliases", errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once() + readerWriter.On("UpdateAliases", ctx, studioID, i.Input.Aliases).Return(nil).Once() + readerWriter.On("UpdateAliases", ctx, errImageID, i.Input.Aliases).Return(nil).Maybe() + readerWriter.On("UpdateAliases", ctx, errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once() - err := i.PostImport(studioID) + err := i.PostImport(ctx, studioID) assert.Nil(t, err) - err = i.PostImport(errImageID) + err = i.PostImport(ctx, errImageID) assert.NotNil(t, err) - err = i.PostImport(errAliasID) + err = i.PostImport(ctx, errAliasID) assert.NotNil(t, err) readerWriter.AssertExpectations(t) @@ -183,6 +189,7 @@ func TestImporterPostImport(t *testing.T) { func TestImporterFindExistingID(t *testing.T) { readerWriter := &mocks.StudioReaderWriter{} + ctx := context.Background() i := Importer{ ReaderWriter: readerWriter, @@ -192,23 +199,23 @@ func TestImporterFindExistingID(t *testing.T) { } errFindByName := errors.New("FindByName error") - readerWriter.On("FindByName", studioName, false).Return(nil, nil).Once() - readerWriter.On("FindByName", existingStudioName, false).Return(&models.Studio{ + readerWriter.On("FindByName", ctx, studioName, false).Return(nil, nil).Once() + readerWriter.On("FindByName", ctx, existingStudioName, false).Return(&models.Studio{ ID: existingStudioID, }, nil).Once() - readerWriter.On("FindByName", studioNameErr, false).Return(nil, errFindByName).Once() + readerWriter.On("FindByName", ctx, studioNameErr, false).Return(nil, errFindByName).Once() - id, err := i.FindExistingID() + id, err := i.FindExistingID(ctx) assert.Nil(t, id) assert.Nil(t, err) i.Input.Name = existingStudioName - id, err = i.FindExistingID() + id, err = i.FindExistingID(ctx) assert.Equal(t, existingStudioID, *id) assert.Nil(t, err) i.Input.Name = studioNameErr - id, err = i.FindExistingID() + id, err = i.FindExistingID(ctx) assert.Nil(t, id) assert.NotNil(t, err) @@ -217,6 +224,7 @@ func TestImporterFindExistingID(t *testing.T) { func TestCreate(t *testing.T) { readerWriter := &mocks.StudioReaderWriter{} + ctx := context.Background() studio := models.Studio{ Name: models.NullString(studioName), @@ -232,17 +240,17 @@ func TestCreate(t *testing.T) { } errCreate := errors.New("Create error") - readerWriter.On("Create", studio).Return(&models.Studio{ + readerWriter.On("Create", ctx, studio).Return(&models.Studio{ ID: studioID, }, nil).Once() - readerWriter.On("Create", studioErr).Return(nil, errCreate).Once() + readerWriter.On("Create", ctx, studioErr).Return(nil, errCreate).Once() - id, err := i.Create() + id, err := i.Create(ctx) assert.Equal(t, studioID, *id) assert.Nil(t, err) i.studio = studioErr - id, err = i.Create() + id, err = i.Create(ctx) assert.Nil(t, id) assert.NotNil(t, err) @@ -251,6 +259,7 @@ func TestCreate(t *testing.T) { func TestUpdate(t *testing.T) { readerWriter := &mocks.StudioReaderWriter{} + ctx := context.Background() studio := models.Studio{ Name: models.NullString(studioName), @@ -269,18 +278,18 @@ func TestUpdate(t *testing.T) { // id needs to be set for the mock input studio.ID = studioID - readerWriter.On("UpdateFull", studio).Return(nil, nil).Once() + readerWriter.On("UpdateFull", ctx, studio).Return(nil, nil).Once() - err := i.Update(studioID) + err := i.Update(ctx, studioID) assert.Nil(t, err) i.studio = studioErr // need to set id separately studioErr.ID = errImageID - readerWriter.On("UpdateFull", studioErr).Return(nil, errUpdate).Once() + readerWriter.On("UpdateFull", ctx, studioErr).Return(nil, errUpdate).Once() - err = i.Update(errImageID) + err = i.Update(ctx, errImageID) assert.NotNil(t, err) readerWriter.AssertExpectations(t) diff --git a/pkg/studio/query.go b/pkg/studio/query.go index 5b2f68896..dee499a1b 100644 --- a/pkg/studio/query.go +++ b/pkg/studio/query.go @@ -1,8 +1,20 @@ package studio -import "github.com/stashapp/stash/pkg/models" +import ( + "context" -func ByName(qb models.StudioReader, name string) (*models.Studio, error) { + "github.com/stashapp/stash/pkg/models" +) + +type Finder interface { + Find(ctx context.Context, id int) (*models.Studio, error) +} + +type Queryer interface { + Query(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) +} + +func ByName(ctx context.Context, qb Queryer, name string) (*models.Studio, error) { f := &models.StudioFilterType{ Name: &models.StringCriterionInput{ Value: name, @@ -11,7 +23,7 @@ func ByName(qb models.StudioReader, name string) (*models.Studio, error) { } pp := 1 - ret, count, err := qb.Query(f, &models.FindFilterType{ + ret, count, err := qb.Query(ctx, f, &models.FindFilterType{ PerPage: &pp, }) @@ -26,7 +38,7 @@ func ByName(qb models.StudioReader, name string) (*models.Studio, error) { return nil, nil } -func ByAlias(qb models.StudioReader, alias string) (*models.Studio, error) { +func ByAlias(ctx context.Context, qb Queryer, alias string) (*models.Studio, error) { f := &models.StudioFilterType{ Aliases: &models.StringCriterionInput{ Value: alias, @@ -35,7 +47,7 @@ func ByAlias(qb models.StudioReader, alias string) (*models.Studio, error) { } pp := 1 - ret, count, err := qb.Query(f, &models.FindFilterType{ + ret, count, err := qb.Query(ctx, f, &models.FindFilterType{ PerPage: &pp, }) diff --git a/pkg/studio/update.go b/pkg/studio/update.go index 35a655a73..addae5c94 100644 --- a/pkg/studio/update.go +++ b/pkg/studio/update.go @@ -1,11 +1,17 @@ package studio import ( + "context" "fmt" "github.com/stashapp/stash/pkg/models" ) +type NameFinderCreator interface { + FindByName(ctx context.Context, name string, nocase bool) (*models.Studio, error) + Create(ctx context.Context, newStudio models.Studio) (*models.Studio, error) +} + type NameExistsError struct { Name string } @@ -25,9 +31,9 @@ func (e *NameUsedByAliasError) Error() string { // EnsureStudioNameUnique returns an error if the studio name provided // is used as a name or alias of another existing tag. -func EnsureStudioNameUnique(id int, name string, qb models.StudioReader) error { +func EnsureStudioNameUnique(ctx context.Context, id int, name string, qb Queryer) error { // ensure name is unique - sameNameStudio, err := ByName(qb, name) + sameNameStudio, err := ByName(ctx, qb, name) if err != nil { return err } @@ -39,7 +45,7 @@ func EnsureStudioNameUnique(id int, name string, qb models.StudioReader) error { } // query by alias - sameNameStudio, err = ByAlias(qb, name) + sameNameStudio, err = ByAlias(ctx, qb, name) if err != nil { return err } @@ -54,9 +60,9 @@ func EnsureStudioNameUnique(id int, name string, qb models.StudioReader) error { return nil } -func EnsureAliasesUnique(id int, aliases []string, qb models.StudioReader) error { +func EnsureAliasesUnique(ctx context.Context, id int, aliases []string, qb Queryer) error { for _, a := range aliases { - if err := EnsureStudioNameUnique(id, a, qb); err != nil { + if err := EnsureStudioNameUnique(ctx, id, a, qb); err != nil { return err } } diff --git a/pkg/tag/export.go b/pkg/tag/export.go index e70392379..20c1b4adc 100644 --- a/pkg/tag/export.go +++ b/pkg/tag/export.go @@ -1,6 +1,7 @@ package tag import ( + "context" "fmt" "github.com/stashapp/stash/pkg/models" @@ -9,8 +10,14 @@ import ( "github.com/stashapp/stash/pkg/utils" ) +type FinderAliasImageGetter interface { + GetAliases(ctx context.Context, studioID int) ([]string, error) + GetImage(ctx context.Context, tagID int) ([]byte, error) + FindByChildTagID(ctx context.Context, childID int) ([]*models.Tag, error) +} + // ToJSON converts a Tag object into its JSON equivalent. -func ToJSON(reader models.TagReader, tag *models.Tag) (*jsonschema.Tag, error) { +func ToJSON(ctx context.Context, reader FinderAliasImageGetter, tag *models.Tag) (*jsonschema.Tag, error) { newTagJSON := jsonschema.Tag{ Name: tag.Name, IgnoreAutoTag: tag.IgnoreAutoTag, @@ -18,14 +25,14 @@ func ToJSON(reader models.TagReader, tag *models.Tag) (*jsonschema.Tag, error) { UpdatedAt: json.JSONTime{Time: tag.UpdatedAt.Timestamp}, } - aliases, err := reader.GetAliases(tag.ID) + aliases, err := reader.GetAliases(ctx, tag.ID) if err != nil { return nil, fmt.Errorf("error getting tag aliases: %v", err) } newTagJSON.Aliases = aliases - image, err := reader.GetImage(tag.ID) + image, err := reader.GetImage(ctx, tag.ID) if err != nil { return nil, fmt.Errorf("error getting tag image: %v", err) } @@ -34,7 +41,7 @@ func ToJSON(reader models.TagReader, tag *models.Tag) (*jsonschema.Tag, error) { newTagJSON.Image = utils.GetBase64StringFromData(image) } - parents, err := reader.FindByChildTagID(tag.ID) + parents, err := reader.FindByChildTagID(ctx, tag.ID) if err != nil { return nil, fmt.Errorf("error getting parents: %v", err) } diff --git a/pkg/tag/export_test.go b/pkg/tag/export_test.go index 930c0fdb1..255c940dd 100644 --- a/pkg/tag/export_test.go +++ b/pkg/tag/export_test.go @@ -1,6 +1,7 @@ package tag import ( + "context" "errors" "github.com/stashapp/stash/pkg/models" @@ -106,33 +107,34 @@ func initTestTable() { func TestToJSON(t *testing.T) { initTestTable() + ctx := context.Background() mockTagReader := &mocks.TagReaderWriter{} imageErr := errors.New("error getting image") aliasErr := errors.New("error getting aliases") parentsErr := errors.New("error getting parents") - mockTagReader.On("GetAliases", tagID).Return([]string{"alias"}, nil).Once() - mockTagReader.On("GetAliases", noImageID).Return(nil, nil).Once() - mockTagReader.On("GetAliases", errImageID).Return(nil, nil).Once() - mockTagReader.On("GetAliases", errAliasID).Return(nil, aliasErr).Once() - mockTagReader.On("GetAliases", withParentsID).Return(nil, nil).Once() - mockTagReader.On("GetAliases", errParentsID).Return(nil, nil).Once() + mockTagReader.On("GetAliases", ctx, tagID).Return([]string{"alias"}, nil).Once() + mockTagReader.On("GetAliases", ctx, noImageID).Return(nil, nil).Once() + mockTagReader.On("GetAliases", ctx, errImageID).Return(nil, nil).Once() + mockTagReader.On("GetAliases", ctx, errAliasID).Return(nil, aliasErr).Once() + mockTagReader.On("GetAliases", ctx, withParentsID).Return(nil, nil).Once() + mockTagReader.On("GetAliases", ctx, errParentsID).Return(nil, nil).Once() - mockTagReader.On("GetImage", tagID).Return(imageBytes, nil).Once() - mockTagReader.On("GetImage", noImageID).Return(nil, nil).Once() - mockTagReader.On("GetImage", errImageID).Return(nil, imageErr).Once() - mockTagReader.On("GetImage", withParentsID).Return(imageBytes, nil).Once() - mockTagReader.On("GetImage", errParentsID).Return(nil, nil).Once() + mockTagReader.On("GetImage", ctx, tagID).Return(imageBytes, nil).Once() + mockTagReader.On("GetImage", ctx, noImageID).Return(nil, nil).Once() + mockTagReader.On("GetImage", ctx, errImageID).Return(nil, imageErr).Once() + mockTagReader.On("GetImage", ctx, withParentsID).Return(imageBytes, nil).Once() + mockTagReader.On("GetImage", ctx, errParentsID).Return(nil, nil).Once() - mockTagReader.On("FindByChildTagID", tagID).Return(nil, nil).Once() - mockTagReader.On("FindByChildTagID", noImageID).Return(nil, nil).Once() - mockTagReader.On("FindByChildTagID", withParentsID).Return([]*models.Tag{{Name: "parent"}}, nil).Once() - mockTagReader.On("FindByChildTagID", errParentsID).Return(nil, parentsErr).Once() + mockTagReader.On("FindByChildTagID", ctx, tagID).Return(nil, nil).Once() + mockTagReader.On("FindByChildTagID", ctx, noImageID).Return(nil, nil).Once() + mockTagReader.On("FindByChildTagID", ctx, withParentsID).Return([]*models.Tag{{Name: "parent"}}, nil).Once() + mockTagReader.On("FindByChildTagID", ctx, errParentsID).Return(nil, parentsErr).Once() for i, s := range scenarios { tag := s.tag - json, err := ToJSON(mockTagReader, &tag) + json, err := ToJSON(ctx, mockTagReader, &tag) switch { case !s.err && err != nil: diff --git a/pkg/tag/import.go b/pkg/tag/import.go index 66028946c..937ea2359 100644 --- a/pkg/tag/import.go +++ b/pkg/tag/import.go @@ -1,6 +1,7 @@ package tag import ( + "context" "fmt" "github.com/stashapp/stash/pkg/models" @@ -8,6 +9,15 @@ import ( "github.com/stashapp/stash/pkg/utils" ) +type NameFinderCreatorUpdater interface { + FindByName(ctx context.Context, name string, nocase bool) (*models.Tag, error) + Create(ctx context.Context, newTag models.Tag) (*models.Tag, error) + UpdateFull(ctx context.Context, updatedTag models.Tag) (*models.Tag, error) + UpdateImage(ctx context.Context, tagID int, image []byte) error + UpdateAliases(ctx context.Context, tagID int, aliases []string) error + UpdateParentTags(ctx context.Context, tagID int, parentIDs []int) error +} + type ParentTagNotExistError struct { missingParent string } @@ -21,7 +31,7 @@ func (e ParentTagNotExistError) MissingParent() string { } type Importer struct { - ReaderWriter models.TagReaderWriter + ReaderWriter NameFinderCreatorUpdater Input jsonschema.Tag MissingRefBehaviour models.ImportMissingRefEnum @@ -29,7 +39,7 @@ type Importer struct { imageData []byte } -func (i *Importer) PreImport() error { +func (i *Importer) PreImport(ctx context.Context) error { i.tag = models.Tag{ Name: i.Input.Name, IgnoreAutoTag: i.Input.IgnoreAutoTag, @@ -48,23 +58,23 @@ func (i *Importer) PreImport() error { return nil } -func (i *Importer) PostImport(id int) error { +func (i *Importer) PostImport(ctx context.Context, id int) error { if len(i.imageData) > 0 { - if err := i.ReaderWriter.UpdateImage(id, i.imageData); err != nil { + if err := i.ReaderWriter.UpdateImage(ctx, id, i.imageData); err != nil { return fmt.Errorf("error setting tag image: %v", err) } } - if err := i.ReaderWriter.UpdateAliases(id, i.Input.Aliases); err != nil { + if err := i.ReaderWriter.UpdateAliases(ctx, id, i.Input.Aliases); err != nil { return fmt.Errorf("error setting tag aliases: %v", err) } - parents, err := i.getParents() + parents, err := i.getParents(ctx) if err != nil { return err } - if err := i.ReaderWriter.UpdateParentTags(id, parents); err != nil { + if err := i.ReaderWriter.UpdateParentTags(ctx, id, parents); err != nil { return fmt.Errorf("error setting parents: %v", err) } @@ -75,9 +85,9 @@ func (i *Importer) Name() string { return i.Input.Name } -func (i *Importer) FindExistingID() (*int, error) { +func (i *Importer) FindExistingID(ctx context.Context) (*int, error) { const nocase = false - existing, err := i.ReaderWriter.FindByName(i.Name(), nocase) + existing, err := i.ReaderWriter.FindByName(ctx, i.Name(), nocase) if err != nil { return nil, err } @@ -90,8 +100,8 @@ func (i *Importer) FindExistingID() (*int, error) { return nil, nil } -func (i *Importer) Create() (*int, error) { - created, err := i.ReaderWriter.Create(i.tag) +func (i *Importer) Create(ctx context.Context) (*int, error) { + created, err := i.ReaderWriter.Create(ctx, i.tag) if err != nil { return nil, fmt.Errorf("error creating tag: %v", err) } @@ -100,10 +110,10 @@ func (i *Importer) Create() (*int, error) { return &id, nil } -func (i *Importer) Update(id int) error { +func (i *Importer) Update(ctx context.Context, id int) error { tag := i.tag tag.ID = id - _, err := i.ReaderWriter.UpdateFull(tag) + _, err := i.ReaderWriter.UpdateFull(ctx, tag) if err != nil { return fmt.Errorf("error updating existing tag: %v", err) } @@ -111,10 +121,10 @@ func (i *Importer) Update(id int) error { return nil } -func (i *Importer) getParents() ([]int, error) { +func (i *Importer) getParents(ctx context.Context) ([]int, error) { var parents []int for _, parent := range i.Input.Parents { - tag, err := i.ReaderWriter.FindByName(parent, false) + tag, err := i.ReaderWriter.FindByName(ctx, parent, false) if err != nil { return nil, fmt.Errorf("error finding parent by name: %v", err) } @@ -129,7 +139,7 @@ func (i *Importer) getParents() ([]int, error) { } if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate { - parentID, err := i.createParent(parent) + parentID, err := i.createParent(ctx, parent) if err != nil { return nil, err } @@ -143,10 +153,10 @@ func (i *Importer) getParents() ([]int, error) { return parents, nil } -func (i *Importer) createParent(name string) (int, error) { +func (i *Importer) createParent(ctx context.Context, name string) (int, error) { newTag := *models.NewTag(name) - created, err := i.ReaderWriter.Create(newTag) + created, err := i.ReaderWriter.Create(ctx, newTag) if err != nil { return 0, err } diff --git a/pkg/tag/import_test.go b/pkg/tag/import_test.go index fb6f3c58f..e4fb3ce8d 100644 --- a/pkg/tag/import_test.go +++ b/pkg/tag/import_test.go @@ -1,6 +1,7 @@ package tag import ( + "context" "errors" "testing" @@ -23,6 +24,8 @@ const ( existingTagID = 100 ) +var testCtx = context.Background() + func TestImporterName(t *testing.T) { i := Importer{ Input: jsonschema.Tag{ @@ -42,13 +45,13 @@ func TestImporterPreImport(t *testing.T) { }, } - err := i.PreImport() + err := i.PreImport(testCtx) assert.NotNil(t, err) i.Input.Image = image - err = i.PreImport() + err = i.PreImport(testCtx) assert.Nil(t, err) } @@ -68,38 +71,38 @@ func TestImporterPostImport(t *testing.T) { updateTagAliasErr := errors.New("UpdateAlias error") updateTagParentsErr := errors.New("UpdateParentTags error") - readerWriter.On("UpdateAliases", tagID, i.Input.Aliases).Return(nil).Once() - readerWriter.On("UpdateAliases", errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once() - readerWriter.On("UpdateAliases", withParentsID, i.Input.Aliases).Return(nil).Once() - readerWriter.On("UpdateAliases", errParentsID, i.Input.Aliases).Return(nil).Once() + readerWriter.On("UpdateAliases", testCtx, tagID, i.Input.Aliases).Return(nil).Once() + readerWriter.On("UpdateAliases", testCtx, errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once() + readerWriter.On("UpdateAliases", testCtx, withParentsID, i.Input.Aliases).Return(nil).Once() + readerWriter.On("UpdateAliases", testCtx, errParentsID, i.Input.Aliases).Return(nil).Once() - readerWriter.On("UpdateImage", tagID, imageBytes).Return(nil).Once() - readerWriter.On("UpdateImage", errAliasID, imageBytes).Return(nil).Once() - readerWriter.On("UpdateImage", errImageID, imageBytes).Return(updateTagImageErr).Once() - readerWriter.On("UpdateImage", withParentsID, imageBytes).Return(nil).Once() - readerWriter.On("UpdateImage", errParentsID, imageBytes).Return(nil).Once() + readerWriter.On("UpdateImage", testCtx, tagID, imageBytes).Return(nil).Once() + readerWriter.On("UpdateImage", testCtx, errAliasID, imageBytes).Return(nil).Once() + readerWriter.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updateTagImageErr).Once() + readerWriter.On("UpdateImage", testCtx, withParentsID, imageBytes).Return(nil).Once() + readerWriter.On("UpdateImage", testCtx, errParentsID, imageBytes).Return(nil).Once() var parentTags []int - readerWriter.On("UpdateParentTags", tagID, parentTags).Return(nil).Once() - readerWriter.On("UpdateParentTags", withParentsID, []int{100}).Return(nil).Once() - readerWriter.On("UpdateParentTags", errParentsID, []int{100}).Return(updateTagParentsErr).Once() + readerWriter.On("UpdateParentTags", testCtx, tagID, parentTags).Return(nil).Once() + readerWriter.On("UpdateParentTags", testCtx, withParentsID, []int{100}).Return(nil).Once() + readerWriter.On("UpdateParentTags", testCtx, errParentsID, []int{100}).Return(updateTagParentsErr).Once() - readerWriter.On("FindByName", "Parent", false).Return(&models.Tag{ID: 100}, nil) + readerWriter.On("FindByName", testCtx, "Parent", false).Return(&models.Tag{ID: 100}, nil) - err := i.PostImport(tagID) + err := i.PostImport(testCtx, tagID) assert.Nil(t, err) - err = i.PostImport(errImageID) + err = i.PostImport(testCtx, errImageID) assert.NotNil(t, err) - err = i.PostImport(errAliasID) + err = i.PostImport(testCtx, errAliasID) assert.NotNil(t, err) i.Input.Parents = []string{"Parent"} - err = i.PostImport(withParentsID) + err = i.PostImport(testCtx, withParentsID) assert.Nil(t, err) - err = i.PostImport(errParentsID) + err = i.PostImport(testCtx, errParentsID) assert.NotNil(t, err) readerWriter.AssertExpectations(t) @@ -129,70 +132,70 @@ func TestImporterPostImportParentMissing(t *testing.T) { var emptyParents []int - readerWriter.On("UpdateImage", mock.Anything, mock.Anything).Return(nil) - readerWriter.On("UpdateAliases", mock.Anything, mock.Anything).Return(nil) + readerWriter.On("UpdateImage", testCtx, mock.Anything, mock.Anything).Return(nil) + readerWriter.On("UpdateAliases", testCtx, mock.Anything, mock.Anything).Return(nil) - readerWriter.On("FindByName", "Create", false).Return(nil, nil).Once() - readerWriter.On("FindByName", "CreateError", false).Return(nil, nil).Once() - readerWriter.On("FindByName", "CreateFindError", false).Return(nil, findError).Once() - readerWriter.On("FindByName", "CreateFound", false).Return(&models.Tag{ID: 101}, nil).Once() - readerWriter.On("FindByName", "Fail", false).Return(nil, nil).Once() - readerWriter.On("FindByName", "FailFindError", false).Return(nil, findError) - readerWriter.On("FindByName", "FailFound", false).Return(&models.Tag{ID: 102}, nil).Once() - readerWriter.On("FindByName", "Ignore", false).Return(nil, nil).Once() - readerWriter.On("FindByName", "IgnoreFindError", false).Return(nil, findError) - readerWriter.On("FindByName", "IgnoreFound", false).Return(&models.Tag{ID: 103}, nil).Once() + readerWriter.On("FindByName", testCtx, "Create", false).Return(nil, nil).Once() + readerWriter.On("FindByName", testCtx, "CreateError", false).Return(nil, nil).Once() + readerWriter.On("FindByName", testCtx, "CreateFindError", false).Return(nil, findError).Once() + readerWriter.On("FindByName", testCtx, "CreateFound", false).Return(&models.Tag{ID: 101}, nil).Once() + readerWriter.On("FindByName", testCtx, "Fail", false).Return(nil, nil).Once() + readerWriter.On("FindByName", testCtx, "FailFindError", false).Return(nil, findError) + readerWriter.On("FindByName", testCtx, "FailFound", false).Return(&models.Tag{ID: 102}, nil).Once() + readerWriter.On("FindByName", testCtx, "Ignore", false).Return(nil, nil).Once() + readerWriter.On("FindByName", testCtx, "IgnoreFindError", false).Return(nil, findError) + readerWriter.On("FindByName", testCtx, "IgnoreFound", false).Return(&models.Tag{ID: 103}, nil).Once() - readerWriter.On("UpdateParentTags", createID, []int{100}).Return(nil).Once() - readerWriter.On("UpdateParentTags", createFoundID, []int{101}).Return(nil).Once() - readerWriter.On("UpdateParentTags", failFoundID, []int{102}).Return(nil).Once() - readerWriter.On("UpdateParentTags", ignoreID, emptyParents).Return(nil).Once() - readerWriter.On("UpdateParentTags", ignoreFoundID, []int{103}).Return(nil).Once() + readerWriter.On("UpdateParentTags", testCtx, createID, []int{100}).Return(nil).Once() + readerWriter.On("UpdateParentTags", testCtx, createFoundID, []int{101}).Return(nil).Once() + readerWriter.On("UpdateParentTags", testCtx, failFoundID, []int{102}).Return(nil).Once() + readerWriter.On("UpdateParentTags", testCtx, ignoreID, emptyParents).Return(nil).Once() + readerWriter.On("UpdateParentTags", testCtx, ignoreFoundID, []int{103}).Return(nil).Once() - readerWriter.On("Create", mock.MatchedBy(func(t models.Tag) bool { return t.Name == "Create" })).Return(&models.Tag{ID: 100}, nil).Once() - readerWriter.On("Create", mock.MatchedBy(func(t models.Tag) bool { return t.Name == "CreateError" })).Return(nil, errors.New("failed creating parent")).Once() + readerWriter.On("Create", testCtx, mock.MatchedBy(func(t models.Tag) bool { return t.Name == "Create" })).Return(&models.Tag{ID: 100}, nil).Once() + readerWriter.On("Create", testCtx, mock.MatchedBy(func(t models.Tag) bool { return t.Name == "CreateError" })).Return(nil, errors.New("failed creating parent")).Once() i.MissingRefBehaviour = models.ImportMissingRefEnumCreate i.Input.Parents = []string{"Create"} - err := i.PostImport(createID) + err := i.PostImport(testCtx, createID) assert.Nil(t, err) i.Input.Parents = []string{"CreateError"} - err = i.PostImport(createErrorID) + err = i.PostImport(testCtx, createErrorID) assert.NotNil(t, err) i.Input.Parents = []string{"CreateFindError"} - err = i.PostImport(createFindErrorID) + err = i.PostImport(testCtx, createFindErrorID) assert.NotNil(t, err) i.Input.Parents = []string{"CreateFound"} - err = i.PostImport(createFoundID) + err = i.PostImport(testCtx, createFoundID) assert.Nil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumFail i.Input.Parents = []string{"Fail"} - err = i.PostImport(failID) + err = i.PostImport(testCtx, failID) assert.NotNil(t, err) i.Input.Parents = []string{"FailFindError"} - err = i.PostImport(failFindErrorID) + err = i.PostImport(testCtx, failFindErrorID) assert.NotNil(t, err) i.Input.Parents = []string{"FailFound"} - err = i.PostImport(failFoundID) + err = i.PostImport(testCtx, failFoundID) assert.Nil(t, err) i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore i.Input.Parents = []string{"Ignore"} - err = i.PostImport(ignoreID) + err = i.PostImport(testCtx, ignoreID) assert.Nil(t, err) i.Input.Parents = []string{"IgnoreFindError"} - err = i.PostImport(ignoreFindErrorID) + err = i.PostImport(testCtx, ignoreFindErrorID) assert.NotNil(t, err) i.Input.Parents = []string{"IgnoreFound"} - err = i.PostImport(ignoreFoundID) + err = i.PostImport(testCtx, ignoreFoundID) assert.Nil(t, err) readerWriter.AssertExpectations(t) @@ -209,23 +212,23 @@ func TestImporterFindExistingID(t *testing.T) { } errFindByName := errors.New("FindByName error") - readerWriter.On("FindByName", tagName, false).Return(nil, nil).Once() - readerWriter.On("FindByName", existingTagName, false).Return(&models.Tag{ + readerWriter.On("FindByName", testCtx, tagName, false).Return(nil, nil).Once() + readerWriter.On("FindByName", testCtx, existingTagName, false).Return(&models.Tag{ ID: existingTagID, }, nil).Once() - readerWriter.On("FindByName", tagNameErr, false).Return(nil, errFindByName).Once() + readerWriter.On("FindByName", testCtx, tagNameErr, false).Return(nil, errFindByName).Once() - id, err := i.FindExistingID() + id, err := i.FindExistingID(testCtx) assert.Nil(t, id) assert.Nil(t, err) i.Input.Name = existingTagName - id, err = i.FindExistingID() + id, err = i.FindExistingID(testCtx) assert.Equal(t, existingTagID, *id) assert.Nil(t, err) i.Input.Name = tagNameErr - id, err = i.FindExistingID() + id, err = i.FindExistingID(testCtx) assert.Nil(t, id) assert.NotNil(t, err) @@ -249,17 +252,17 @@ func TestCreate(t *testing.T) { } errCreate := errors.New("Create error") - readerWriter.On("Create", tag).Return(&models.Tag{ + readerWriter.On("Create", testCtx, tag).Return(&models.Tag{ ID: tagID, }, nil).Once() - readerWriter.On("Create", tagErr).Return(nil, errCreate).Once() + readerWriter.On("Create", testCtx, tagErr).Return(nil, errCreate).Once() - id, err := i.Create() + id, err := i.Create(testCtx) assert.Equal(t, tagID, *id) assert.Nil(t, err) i.tag = tagErr - id, err = i.Create() + id, err = i.Create(testCtx) assert.Nil(t, id) assert.NotNil(t, err) @@ -286,18 +289,18 @@ func TestUpdate(t *testing.T) { // id needs to be set for the mock input tag.ID = tagID - readerWriter.On("UpdateFull", tag).Return(nil, nil).Once() + readerWriter.On("UpdateFull", testCtx, tag).Return(nil, nil).Once() - err := i.Update(tagID) + err := i.Update(testCtx, tagID) assert.Nil(t, err) i.tag = tagErr // need to set id separately tagErr.ID = errImageID - readerWriter.On("UpdateFull", tagErr).Return(nil, errUpdate).Once() + readerWriter.On("UpdateFull", testCtx, tagErr).Return(nil, errUpdate).Once() - err = i.Update(errImageID) + err = i.Update(testCtx, errImageID) assert.NotNil(t, err) readerWriter.AssertExpectations(t) diff --git a/pkg/tag/query.go b/pkg/tag/query.go index ce7406403..a048054d7 100644 --- a/pkg/tag/query.go +++ b/pkg/tag/query.go @@ -1,8 +1,20 @@ package tag -import "github.com/stashapp/stash/pkg/models" +import ( + "context" -func ByName(qb models.TagReader, name string) (*models.Tag, error) { + "github.com/stashapp/stash/pkg/models" +) + +type Finder interface { + Find(ctx context.Context, id int) (*models.Tag, error) +} + +type Queryer interface { + Query(ctx context.Context, tagFilter *models.TagFilterType, findFilter *models.FindFilterType) ([]*models.Tag, int, error) +} + +func ByName(ctx context.Context, qb Queryer, name string) (*models.Tag, error) { f := &models.TagFilterType{ Name: &models.StringCriterionInput{ Value: name, @@ -11,7 +23,7 @@ func ByName(qb models.TagReader, name string) (*models.Tag, error) { } pp := 1 - ret, count, err := qb.Query(f, &models.FindFilterType{ + ret, count, err := qb.Query(ctx, f, &models.FindFilterType{ PerPage: &pp, }) @@ -26,7 +38,7 @@ func ByName(qb models.TagReader, name string) (*models.Tag, error) { return nil, nil } -func ByAlias(qb models.TagReader, alias string) (*models.Tag, error) { +func ByAlias(ctx context.Context, qb Queryer, alias string) (*models.Tag, error) { f := &models.TagFilterType{ Aliases: &models.StringCriterionInput{ Value: alias, @@ -35,7 +47,7 @@ func ByAlias(qb models.TagReader, alias string) (*models.Tag, error) { } pp := 1 - ret, count, err := qb.Query(f, &models.FindFilterType{ + ret, count, err := qb.Query(ctx, f, &models.FindFilterType{ PerPage: &pp, }) diff --git a/pkg/tag/update.go b/pkg/tag/update.go index dfee55154..0c219b26c 100644 --- a/pkg/tag/update.go +++ b/pkg/tag/update.go @@ -1,11 +1,17 @@ package tag import ( + "context" "fmt" "github.com/stashapp/stash/pkg/models" ) +type NameFinderCreator interface { + FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Tag, error) + Create(ctx context.Context, newTag models.Tag) (*models.Tag, error) +} + type NameExistsError struct { Name string } @@ -37,9 +43,9 @@ func (e *InvalidTagHierarchyError) Error() string { // EnsureTagNameUnique returns an error if the tag name provided // is used as a name or alias of another existing tag. -func EnsureTagNameUnique(id int, name string, qb models.TagReader) error { +func EnsureTagNameUnique(ctx context.Context, id int, name string, qb Queryer) error { // ensure name is unique - sameNameTag, err := ByName(qb, name) + sameNameTag, err := ByName(ctx, qb, name) if err != nil { return err } @@ -51,7 +57,7 @@ func EnsureTagNameUnique(id int, name string, qb models.TagReader) error { } // query by alias - sameNameTag, err = ByAlias(qb, name) + sameNameTag, err = ByAlias(ctx, qb, name) if err != nil { return err } @@ -66,9 +72,9 @@ func EnsureTagNameUnique(id int, name string, qb models.TagReader) error { return nil } -func EnsureAliasesUnique(id int, aliases []string, qb models.TagReader) error { +func EnsureAliasesUnique(ctx context.Context, id int, aliases []string, qb Queryer) error { for _, a := range aliases { - if err := EnsureTagNameUnique(id, a, qb); err != nil { + if err := EnsureTagNameUnique(ctx, id, a, qb); err != nil { return err } } @@ -76,12 +82,19 @@ func EnsureAliasesUnique(id int, aliases []string, qb models.TagReader) error { return nil } -func ValidateHierarchy(tag *models.Tag, parentIDs, childIDs []int, qb models.TagReader) error { +type RelationshipGetter interface { + FindAllAncestors(ctx context.Context, tagID int, excludeIDs []int) ([]*models.TagPath, error) + FindAllDescendants(ctx context.Context, tagID int, excludeIDs []int) ([]*models.TagPath, error) + FindByChildTagID(ctx context.Context, childID int) ([]*models.Tag, error) + FindByParentTagID(ctx context.Context, parentID int) ([]*models.Tag, error) +} + +func ValidateHierarchy(ctx context.Context, tag *models.Tag, parentIDs, childIDs []int, qb RelationshipGetter) error { id := tag.ID allAncestors := make(map[int]*models.TagPath) allDescendants := make(map[int]*models.TagPath) - parentsAncestors, err := qb.FindAllAncestors(id, nil) + parentsAncestors, err := qb.FindAllAncestors(ctx, id, nil) if err != nil { return err } @@ -90,7 +103,7 @@ func ValidateHierarchy(tag *models.Tag, parentIDs, childIDs []int, qb models.Tag allAncestors[ancestorTag.ID] = ancestorTag } - childsDescendants, err := qb.FindAllDescendants(id, nil) + childsDescendants, err := qb.FindAllDescendants(ctx, id, nil) if err != nil { return err } @@ -128,7 +141,7 @@ func ValidateHierarchy(tag *models.Tag, parentIDs, childIDs []int, qb models.Tag } if parentIDs == nil { - parentTags, err := qb.FindByChildTagID(id) + parentTags, err := qb.FindByChildTagID(ctx, id) if err != nil { return err } @@ -139,7 +152,7 @@ func ValidateHierarchy(tag *models.Tag, parentIDs, childIDs []int, qb models.Tag } if childIDs == nil { - childTags, err := qb.FindByParentTagID(id) + childTags, err := qb.FindByParentTagID(ctx, id) if err != nil { return err } @@ -164,7 +177,7 @@ func ValidateHierarchy(tag *models.Tag, parentIDs, childIDs []int, qb models.Tag return nil } -func MergeHierarchy(destination int, sources []int, qb models.TagReader) ([]int, []int, error) { +func MergeHierarchy(ctx context.Context, destination int, sources []int, qb RelationshipGetter) ([]int, []int, error) { var mergedParents, mergedChildren []int allIds := append([]int{destination}, sources...) @@ -192,14 +205,14 @@ func MergeHierarchy(destination int, sources []int, qb models.TagReader) ([]int, } for _, id := range allIds { - parents, err := qb.FindByChildTagID(id) + parents, err := qb.FindByChildTagID(ctx, id) if err != nil { return nil, nil, err } mergedParents = addTo(mergedParents, parents) - children, err := qb.FindByParentTagID(id) + children, err := qb.FindByParentTagID(ctx, id) if err != nil { return nil, nil, err } diff --git a/pkg/tag/update_test.go b/pkg/tag/update_test.go index f7338da23..4cc14e961 100644 --- a/pkg/tag/update_test.go +++ b/pkg/tag/update_test.go @@ -1,6 +1,7 @@ package tag import ( + "context" "fmt" "testing" @@ -219,6 +220,7 @@ func TestEnsureHierarchy(t *testing.T) { func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents, queryChildren bool) { mockTagReader := &mocks.TagReaderWriter{} + ctx := context.Background() var parentIDs, childIDs []int find := make(map[int]*models.Tag) @@ -245,33 +247,33 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents, if queryParents { parentIDs = nil - mockTagReader.On("FindByChildTagID", tc.id).Return(tc.parents, nil).Once() + mockTagReader.On("FindByChildTagID", ctx, tc.id).Return(tc.parents, nil).Once() } if queryChildren { childIDs = nil - mockTagReader.On("FindByParentTagID", tc.id).Return(tc.children, nil).Once() + mockTagReader.On("FindByParentTagID", ctx, tc.id).Return(tc.children, nil).Once() } - mockTagReader.On("FindAllAncestors", mock.AnythingOfType("int"), []int(nil)).Return(func(tagID int, excludeIDs []int) []*models.TagPath { + mockTagReader.On("FindAllAncestors", ctx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath { return tc.onFindAllAncestors - }, func(tagID int, excludeIDs []int) error { + }, func(ctx context.Context, tagID int, excludeIDs []int) error { if tc.onFindAllAncestors != nil { return nil } return fmt.Errorf("undefined ancestors for: %d", tagID) }).Maybe() - mockTagReader.On("FindAllDescendants", mock.AnythingOfType("int"), []int(nil)).Return(func(tagID int, excludeIDs []int) []*models.TagPath { + mockTagReader.On("FindAllDescendants", ctx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath { return tc.onFindAllDescendants - }, func(tagID int, excludeIDs []int) error { + }, func(ctx context.Context, tagID int, excludeIDs []int) error { if tc.onFindAllDescendants != nil { return nil } return fmt.Errorf("undefined descendants for: %d", tagID) }).Maybe() - res := ValidateHierarchy(testUniqueHierarchyTags[tc.id], parentIDs, childIDs, mockTagReader) + res := ValidateHierarchy(ctx, testUniqueHierarchyTags[tc.id], parentIDs, childIDs, mockTagReader) assert := assert.New(t) diff --git a/pkg/txn/transaction.go b/pkg/txn/transaction.go new file mode 100644 index 000000000..6939828b4 --- /dev/null +++ b/pkg/txn/transaction.go @@ -0,0 +1,38 @@ +package txn + +import "context" + +type Manager interface { + Begin(ctx context.Context) (context.Context, error) + Commit(ctx context.Context) error + Rollback(ctx context.Context) error +} + +type TxnFunc func(ctx context.Context) error + +func WithTxn(ctx context.Context, m Manager, fn TxnFunc) error { + var err error + ctx, err = m.Begin(ctx) + if err != nil { + return err + } + + defer func() { + if p := recover(); p != nil { + // a panic occurred, rollback and repanic + _ = m.Rollback(ctx) + panic(p) + } + + if err != nil { + // something went wrong, rollback + _ = m.Rollback(ctx) + } else { + // all good, commit + err = m.Commit(ctx) + } + }() + + err = fn(ctx) + return err +}