Restructure data layer (#2532)

* Add new txn manager interface
* Add txn management to sqlite
* Rename get to getByID
* Add contexts to repository methods
* Update query builders
* Add context to reader writer interfaces
* Use repository in resolver
* Tighten interfaces
* Tighten interfaces in dlna
* Tighten interfaces in match package
* Tighten interfaces in scraper package
* Tighten interfaces in scan code
* Tighten interfaces on autotag package
* Remove ReaderWriter usage
* Merge database package into sqlite
This commit is contained in:
WithoutPants 2022-05-19 17:49:32 +10:00
parent 7b5bd80515
commit 964b559309
244 changed files with 7377 additions and 6699 deletions

View File

@ -11,6 +11,7 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/scraper"
"github.com/stashapp/stash/pkg/txn"
)
var (
@ -30,7 +31,9 @@ type hookExecutor interface {
}
type Resolver struct {
txnManager models.TransactionManager
txnManager txn.Manager
repository models.Repository
hookExecutor hookExecutor
}
@ -85,17 +88,13 @@ type studioResolver struct{ *Resolver }
type movieResolver struct{ *Resolver }
type tagResolver struct{ *Resolver }
func (r *Resolver) withTxn(ctx context.Context, fn func(r models.Repository) error) error {
return r.txnManager.WithTxn(ctx, fn)
}
func (r *Resolver) withReadTxn(ctx context.Context, fn func(r models.ReaderRepository) error) error {
return r.txnManager.WithReadTxn(ctx, fn)
func (r *Resolver) withTxn(ctx context.Context, fn func(ctx context.Context) error) error {
return txn.WithTxn(ctx, r.txnManager, fn)
}
func (r *queryResolver) MarkerWall(ctx context.Context, q *string) (ret []*models.SceneMarker, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.SceneMarker().Wall(q)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.SceneMarker.Wall(ctx, q)
return err
}); err != nil {
return nil, err
@ -104,8 +103,8 @@ func (r *queryResolver) MarkerWall(ctx context.Context, q *string) (ret []*model
}
func (r *queryResolver) SceneWall(ctx context.Context, q *string) (ret []*models.Scene, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Scene().Wall(q)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Scene.Wall(ctx, q)
return err
}); err != nil {
return nil, err
@ -115,8 +114,8 @@ func (r *queryResolver) SceneWall(ctx context.Context, q *string) (ret []*models
}
func (r *queryResolver) MarkerStrings(ctx context.Context, q *string, sort *string) (ret []*models.MarkerStringsResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.SceneMarker().GetMarkerStrings(q, sort)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.SceneMarker.GetMarkerStrings(ctx, q, sort)
return err
}); err != nil {
return nil, err
@ -127,24 +126,25 @@ func (r *queryResolver) MarkerStrings(ctx context.Context, q *string, sort *stri
func (r *queryResolver) Stats(ctx context.Context) (*StatsResultType, error) {
var ret StatsResultType
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
scenesQB := repo.Scene()
imageQB := repo.Image()
galleryQB := repo.Gallery()
studiosQB := repo.Studio()
performersQB := repo.Performer()
moviesQB := repo.Movie()
tagsQB := repo.Tag()
scenesCount, _ := scenesQB.Count()
scenesSize, _ := scenesQB.Size()
scenesDuration, _ := scenesQB.Duration()
imageCount, _ := imageQB.Count()
imageSize, _ := imageQB.Size()
galleryCount, _ := galleryQB.Count()
performersCount, _ := performersQB.Count()
studiosCount, _ := studiosQB.Count()
moviesCount, _ := moviesQB.Count()
tagsCount, _ := tagsQB.Count()
if err := r.withTxn(ctx, func(ctx context.Context) error {
repo := r.repository
scenesQB := repo.Scene
imageQB := repo.Image
galleryQB := repo.Gallery
studiosQB := repo.Studio
performersQB := repo.Performer
moviesQB := repo.Movie
tagsQB := repo.Tag
scenesCount, _ := scenesQB.Count(ctx)
scenesSize, _ := scenesQB.Size(ctx)
scenesDuration, _ := scenesQB.Duration(ctx)
imageCount, _ := imageQB.Count(ctx)
imageSize, _ := imageQB.Size(ctx)
galleryCount, _ := galleryQB.Count(ctx)
performersCount, _ := performersQB.Count(ctx)
studiosCount, _ := studiosQB.Count(ctx)
moviesCount, _ := moviesQB.Count(ctx)
tagsCount, _ := tagsQB.Count(ctx)
ret = StatsResultType{
SceneCount: scenesCount,
@ -202,15 +202,15 @@ func (r *queryResolver) SceneMarkerTags(ctx context.Context, scene_id string) ([
var keys []int
tags := make(map[int]*SceneMarkerTag)
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
sceneMarkers, err := repo.SceneMarker().FindBySceneID(sceneID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
sceneMarkers, err := r.repository.SceneMarker.FindBySceneID(ctx, sceneID)
if err != nil {
return err
}
tqb := repo.Tag()
tqb := r.repository.Tag
for _, sceneMarker := range sceneMarkers {
markerPrimaryTag, err := tqb.Find(sceneMarker.PrimaryTagID)
markerPrimaryTag, err := tqb.Find(ctx, sceneMarker.PrimaryTagID)
if err != nil {
return err
}

View File

@ -24,12 +24,12 @@ func (r *galleryResolver) Title(ctx context.Context, obj *models.Gallery) (*stri
}
func (r *galleryResolver) Images(ctx context.Context, obj *models.Gallery) (ret []*models.Image, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error
// #2376 - sort images by path
// doing this via Query is really slow, so stick with FindByGalleryID
ret, err = repo.Image().FindByGalleryID(obj.ID)
ret, err = r.repository.Image.FindByGalleryID(ctx, obj.ID)
if err != nil {
return err
}
@ -43,9 +43,9 @@ func (r *galleryResolver) Images(ctx context.Context, obj *models.Gallery) (ret
}
func (r *galleryResolver) Cover(ctx context.Context, obj *models.Gallery) (ret *models.Image, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
if err := r.withTxn(ctx, func(ctx context.Context) error {
// doing this via Query is really slow, so stick with FindByGalleryID
imgs, err := repo.Image().FindByGalleryID(obj.ID)
imgs, err := r.repository.Image.FindByGalleryID(ctx, obj.ID)
if err != nil {
return err
}
@ -100,9 +100,9 @@ func (r *galleryResolver) Rating(ctx context.Context, obj *models.Gallery) (*int
}
func (r *galleryResolver) Scenes(ctx context.Context, obj *models.Gallery) (ret []*models.Scene, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error
ret, err = repo.Scene().FindByGalleryID(obj.ID)
ret, err = r.repository.Scene.FindByGalleryID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -116,9 +116,9 @@ func (r *galleryResolver) Studio(ctx context.Context, obj *models.Gallery) (ret
return nil, nil
}
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error
ret, err = repo.Studio().Find(int(obj.StudioID.Int64))
ret, err = r.repository.Studio.Find(ctx, int(obj.StudioID.Int64))
return err
}); err != nil {
return nil, err
@ -128,9 +128,9 @@ func (r *galleryResolver) Studio(ctx context.Context, obj *models.Gallery) (ret
}
func (r *galleryResolver) Tags(ctx context.Context, obj *models.Gallery) (ret []*models.Tag, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error
ret, err = repo.Tag().FindByGalleryID(obj.ID)
ret, err = r.repository.Tag.FindByGalleryID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -140,9 +140,9 @@ func (r *galleryResolver) Tags(ctx context.Context, obj *models.Gallery) (ret []
}
func (r *galleryResolver) Performers(ctx context.Context, obj *models.Gallery) (ret []*models.Performer, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error
ret, err = repo.Performer().FindByGalleryID(obj.ID)
ret, err = r.repository.Performer.FindByGalleryID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -152,9 +152,9 @@ func (r *galleryResolver) Performers(ctx context.Context, obj *models.Gallery) (
}
func (r *galleryResolver) ImageCount(ctx context.Context, obj *models.Gallery) (ret int, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error
ret, err = repo.Image().CountByGalleryID(obj.ID)
ret, err = r.repository.Image.CountByGalleryID(ctx, obj.ID)
return err
}); err != nil {
return 0, err

View File

@ -45,9 +45,9 @@ func (r *imageResolver) Paths(ctx context.Context, obj *models.Image) (*ImagePat
}
func (r *imageResolver) Galleries(ctx context.Context, obj *models.Image) (ret []*models.Gallery, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error
ret, err = repo.Gallery().FindByImageID(obj.ID)
ret, err = r.repository.Gallery.FindByImageID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -61,8 +61,8 @@ func (r *imageResolver) Studio(ctx context.Context, obj *models.Image) (ret *mod
return nil, nil
}
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Studio().Find(int(obj.StudioID.Int64))
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Studio.Find(ctx, int(obj.StudioID.Int64))
return err
}); err != nil {
return nil, err
@ -72,8 +72,8 @@ func (r *imageResolver) Studio(ctx context.Context, obj *models.Image) (ret *mod
}
func (r *imageResolver) Tags(ctx context.Context, obj *models.Image) (ret []*models.Tag, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Tag().FindByImageID(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Tag.FindByImageID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -83,8 +83,8 @@ func (r *imageResolver) Tags(ctx context.Context, obj *models.Image) (ret []*mod
}
func (r *imageResolver) Performers(ctx context.Context, obj *models.Image) (ret []*models.Performer, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Performer().FindByImageID(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Performer.FindByImageID(ctx, obj.ID)
return err
}); err != nil {
return nil, err

View File

@ -56,8 +56,8 @@ func (r *movieResolver) Rating(ctx context.Context, obj *models.Movie) (*int, er
func (r *movieResolver) Studio(ctx context.Context, obj *models.Movie) (ret *models.Studio, err error) {
if obj.StudioID.Valid {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Studio().Find(int(obj.StudioID.Int64))
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Studio.Find(ctx, int(obj.StudioID.Int64))
return err
}); err != nil {
return nil, err
@ -92,9 +92,9 @@ func (r *movieResolver) FrontImagePath(ctx context.Context, obj *models.Movie) (
func (r *movieResolver) BackImagePath(ctx context.Context, obj *models.Movie) (*string, error) {
// don't return any thing if there is no back image
var img []byte
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error
img, err = repo.Movie().GetBackImage(obj.ID)
img, err = r.repository.Movie.GetBackImage(ctx, obj.ID)
if err != nil {
return err
}
@ -115,8 +115,8 @@ func (r *movieResolver) BackImagePath(ctx context.Context, obj *models.Movie) (*
func (r *movieResolver) SceneCount(ctx context.Context, obj *models.Movie) (ret *int, err error) {
var res int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
res, err = repo.Scene().CountByMovieID(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
res, err = r.repository.Scene.CountByMovieID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -126,9 +126,9 @@ func (r *movieResolver) SceneCount(ctx context.Context, obj *models.Movie) (ret
}
func (r *movieResolver) Scenes(ctx context.Context, obj *models.Movie) (ret []*models.Scene, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error
ret, err = repo.Scene().FindByMovieID(obj.ID)
ret, err = r.repository.Scene.FindByMovieID(ctx, obj.ID)
return err
}); err != nil {
return nil, err

View File

@ -142,8 +142,8 @@ func (r *performerResolver) ImagePath(ctx context.Context, obj *models.Performer
}
func (r *performerResolver) Tags(ctx context.Context, obj *models.Performer) (ret []*models.Tag, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Tag().FindByPerformerID(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Tag.FindByPerformerID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -154,8 +154,8 @@ func (r *performerResolver) Tags(ctx context.Context, obj *models.Performer) (re
func (r *performerResolver) SceneCount(ctx context.Context, obj *models.Performer) (ret *int, err error) {
var res int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
res, err = repo.Scene().CountByPerformerID(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
res, err = r.repository.Scene.CountByPerformerID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -166,8 +166,8 @@ func (r *performerResolver) SceneCount(ctx context.Context, obj *models.Performe
func (r *performerResolver) ImageCount(ctx context.Context, obj *models.Performer) (ret *int, err error) {
var res int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
res, err = image.CountByPerformerID(repo.Image(), obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
res, err = image.CountByPerformerID(ctx, r.repository.Image, obj.ID)
return err
}); err != nil {
return nil, err
@ -178,8 +178,8 @@ func (r *performerResolver) ImageCount(ctx context.Context, obj *models.Performe
func (r *performerResolver) GalleryCount(ctx context.Context, obj *models.Performer) (ret *int, err error) {
var res int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
res, err = gallery.CountByPerformerID(repo.Gallery(), obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
res, err = gallery.CountByPerformerID(ctx, r.repository.Gallery, obj.ID)
return err
}); err != nil {
return nil, err
@ -189,8 +189,8 @@ func (r *performerResolver) GalleryCount(ctx context.Context, obj *models.Perfor
}
func (r *performerResolver) Scenes(ctx context.Context, obj *models.Performer) (ret []*models.Scene, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Scene().FindByPerformerID(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Scene.FindByPerformerID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -200,8 +200,8 @@ func (r *performerResolver) Scenes(ctx context.Context, obj *models.Performer) (
}
func (r *performerResolver) StashIds(ctx context.Context, obj *models.Performer) (ret []*models.StashID, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Performer().GetStashIDs(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Performer.GetStashIDs(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -256,8 +256,8 @@ func (r *performerResolver) UpdatedAt(ctx context.Context, obj *models.Performer
}
func (r *performerResolver) Movies(ctx context.Context, obj *models.Performer) (ret []*models.Movie, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Movie().FindByPerformerID(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Movie.FindByPerformerID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -268,8 +268,8 @@ func (r *performerResolver) Movies(ctx context.Context, obj *models.Performer) (
func (r *performerResolver) MovieCount(ctx context.Context, obj *models.Performer) (ret *int, err error) {
var res int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
res, err = repo.Movie().CountByPerformerID(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
res, err = r.repository.Movie.CountByPerformerID(ctx, obj.ID)
return err
}); err != nil {
return nil, err

View File

@ -116,8 +116,8 @@ func (r *sceneResolver) Paths(ctx context.Context, obj *models.Scene) (*ScenePat
}
func (r *sceneResolver) SceneMarkers(ctx context.Context, obj *models.Scene) (ret []*models.SceneMarker, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.SceneMarker().FindBySceneID(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.SceneMarker.FindBySceneID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -127,8 +127,8 @@ func (r *sceneResolver) SceneMarkers(ctx context.Context, obj *models.Scene) (re
}
func (r *sceneResolver) Captions(ctx context.Context, obj *models.Scene) (ret []*models.SceneCaption, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Scene().GetCaptions(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Scene.GetCaptions(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -138,8 +138,8 @@ func (r *sceneResolver) Captions(ctx context.Context, obj *models.Scene) (ret []
}
func (r *sceneResolver) Galleries(ctx context.Context, obj *models.Scene) (ret []*models.Gallery, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Gallery().FindBySceneID(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Gallery.FindBySceneID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -153,8 +153,8 @@ func (r *sceneResolver) Studio(ctx context.Context, obj *models.Scene) (ret *mod
return nil, nil
}
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Studio().Find(int(obj.StudioID.Int64))
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Studio.Find(ctx, int(obj.StudioID.Int64))
return err
}); err != nil {
return nil, err
@ -164,17 +164,17 @@ func (r *sceneResolver) Studio(ctx context.Context, obj *models.Scene) (ret *mod
}
func (r *sceneResolver) Movies(ctx context.Context, obj *models.Scene) (ret []*SceneMovie, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
qb := repo.Scene()
mqb := repo.Movie()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Scene
mqb := r.repository.Movie
sceneMovies, err := qb.GetMovies(obj.ID)
sceneMovies, err := qb.GetMovies(ctx, obj.ID)
if err != nil {
return err
}
for _, sm := range sceneMovies {
movie, err := mqb.Find(sm.MovieID)
movie, err := mqb.Find(ctx, sm.MovieID)
if err != nil {
return err
}
@ -200,8 +200,8 @@ func (r *sceneResolver) Movies(ctx context.Context, obj *models.Scene) (ret []*S
}
func (r *sceneResolver) Tags(ctx context.Context, obj *models.Scene) (ret []*models.Tag, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Tag().FindBySceneID(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Tag.FindBySceneID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -211,8 +211,8 @@ func (r *sceneResolver) Tags(ctx context.Context, obj *models.Scene) (ret []*mod
}
func (r *sceneResolver) Performers(ctx context.Context, obj *models.Scene) (ret []*models.Performer, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Performer().FindBySceneID(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Performer.FindBySceneID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -222,8 +222,8 @@ func (r *sceneResolver) Performers(ctx context.Context, obj *models.Scene) (ret
}
func (r *sceneResolver) StashIds(ctx context.Context, obj *models.Scene) (ret []*models.StashID, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Scene().GetStashIDs(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Scene.GetStashIDs(ctx, obj.ID)
return err
}); err != nil {
return nil, err

View File

@ -13,9 +13,9 @@ func (r *sceneMarkerResolver) Scene(ctx context.Context, obj *models.SceneMarker
panic("Invalid scene id")
}
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
if err := r.withTxn(ctx, func(ctx context.Context) error {
sceneID := int(obj.SceneID.Int64)
ret, err = repo.Scene().Find(sceneID)
ret, err = r.repository.Scene.Find(ctx, sceneID)
return err
}); err != nil {
return nil, err
@ -25,8 +25,8 @@ func (r *sceneMarkerResolver) Scene(ctx context.Context, obj *models.SceneMarker
}
func (r *sceneMarkerResolver) PrimaryTag(ctx context.Context, obj *models.SceneMarker) (ret *models.Tag, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Tag().Find(obj.PrimaryTagID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Tag.Find(ctx, obj.PrimaryTagID)
return err
}); err != nil {
return nil, err
@ -36,8 +36,8 @@ func (r *sceneMarkerResolver) PrimaryTag(ctx context.Context, obj *models.SceneM
}
func (r *sceneMarkerResolver) Tags(ctx context.Context, obj *models.SceneMarker) (ret []*models.Tag, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Tag().FindBySceneMarkerID(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Tag.FindBySceneMarkerID(ctx, obj.ID)
return err
}); err != nil {
return nil, err

View File

@ -29,9 +29,9 @@ func (r *studioResolver) ImagePath(ctx context.Context, obj *models.Studio) (*st
imagePath := urlbuilders.NewStudioURLBuilder(baseURL, obj).GetStudioImageURL()
var hasImage bool
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error
hasImage, err = repo.Studio().HasImage(obj.ID)
hasImage, err = r.repository.Studio.HasImage(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -46,8 +46,8 @@ func (r *studioResolver) ImagePath(ctx context.Context, obj *models.Studio) (*st
}
func (r *studioResolver) Aliases(ctx context.Context, obj *models.Studio) (ret []string, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Studio().GetAliases(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Studio.GetAliases(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -58,8 +58,8 @@ func (r *studioResolver) Aliases(ctx context.Context, obj *models.Studio) (ret [
func (r *studioResolver) SceneCount(ctx context.Context, obj *models.Studio) (ret *int, err error) {
var res int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
res, err = repo.Scene().CountByStudioID(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
res, err = r.repository.Scene.CountByStudioID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -70,8 +70,8 @@ func (r *studioResolver) SceneCount(ctx context.Context, obj *models.Studio) (re
func (r *studioResolver) ImageCount(ctx context.Context, obj *models.Studio) (ret *int, err error) {
var res int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
res, err = image.CountByStudioID(repo.Image(), obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
res, err = image.CountByStudioID(ctx, r.repository.Image, obj.ID)
return err
}); err != nil {
return nil, err
@ -82,8 +82,8 @@ func (r *studioResolver) ImageCount(ctx context.Context, obj *models.Studio) (re
func (r *studioResolver) GalleryCount(ctx context.Context, obj *models.Studio) (ret *int, err error) {
var res int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
res, err = gallery.CountByStudioID(repo.Gallery(), obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
res, err = gallery.CountByStudioID(ctx, r.repository.Gallery, obj.ID)
return err
}); err != nil {
return nil, err
@ -97,8 +97,8 @@ func (r *studioResolver) ParentStudio(ctx context.Context, obj *models.Studio) (
return nil, nil
}
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Studio().Find(int(obj.ParentID.Int64))
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Studio.Find(ctx, int(obj.ParentID.Int64))
return err
}); err != nil {
return nil, err
@ -108,8 +108,8 @@ func (r *studioResolver) ParentStudio(ctx context.Context, obj *models.Studio) (
}
func (r *studioResolver) ChildStudios(ctx context.Context, obj *models.Studio) (ret []*models.Studio, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Studio().FindChildren(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Studio.FindChildren(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -119,8 +119,8 @@ func (r *studioResolver) ChildStudios(ctx context.Context, obj *models.Studio) (
}
func (r *studioResolver) StashIds(ctx context.Context, obj *models.Studio) (ret []*models.StashID, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Studio().GetStashIDs(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Studio.GetStashIDs(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -153,8 +153,8 @@ func (r *studioResolver) UpdatedAt(ctx context.Context, obj *models.Studio) (*ti
}
func (r *studioResolver) Movies(ctx context.Context, obj *models.Studio) (ret []*models.Movie, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Movie().FindByStudioID(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Movie.FindByStudioID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -165,8 +165,8 @@ func (r *studioResolver) Movies(ctx context.Context, obj *models.Studio) (ret []
func (r *studioResolver) MovieCount(ctx context.Context, obj *models.Studio) (ret *int, err error) {
var res int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
res, err = repo.Movie().CountByStudioID(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
res, err = r.repository.Movie.CountByStudioID(ctx, obj.ID)
return err
}); err != nil {
return nil, err

View File

@ -11,8 +11,8 @@ import (
)
func (r *tagResolver) Parents(ctx context.Context, obj *models.Tag) (ret []*models.Tag, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Tag().FindByChildTagID(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Tag.FindByChildTagID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -22,8 +22,8 @@ func (r *tagResolver) Parents(ctx context.Context, obj *models.Tag) (ret []*mode
}
func (r *tagResolver) Children(ctx context.Context, obj *models.Tag) (ret []*models.Tag, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Tag().FindByParentTagID(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Tag.FindByParentTagID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -33,8 +33,8 @@ func (r *tagResolver) Children(ctx context.Context, obj *models.Tag) (ret []*mod
}
func (r *tagResolver) Aliases(ctx context.Context, obj *models.Tag) (ret []string, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Tag().GetAliases(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Tag.GetAliases(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -45,8 +45,8 @@ func (r *tagResolver) Aliases(ctx context.Context, obj *models.Tag) (ret []strin
func (r *tagResolver) SceneCount(ctx context.Context, obj *models.Tag) (ret *int, err error) {
var count int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
count, err = repo.Scene().CountByTagID(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
count, err = r.repository.Scene.CountByTagID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -57,8 +57,8 @@ func (r *tagResolver) SceneCount(ctx context.Context, obj *models.Tag) (ret *int
func (r *tagResolver) SceneMarkerCount(ctx context.Context, obj *models.Tag) (ret *int, err error) {
var count int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
count, err = repo.SceneMarker().CountByTagID(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
count, err = r.repository.SceneMarker.CountByTagID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@ -69,8 +69,8 @@ func (r *tagResolver) SceneMarkerCount(ctx context.Context, obj *models.Tag) (re
func (r *tagResolver) ImageCount(ctx context.Context, obj *models.Tag) (ret *int, err error) {
var res int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
res, err = image.CountByTagID(repo.Image(), obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
res, err = image.CountByTagID(ctx, r.repository.Image, obj.ID)
return err
}); err != nil {
return nil, err
@ -81,8 +81,8 @@ func (r *tagResolver) ImageCount(ctx context.Context, obj *models.Tag) (ret *int
func (r *tagResolver) GalleryCount(ctx context.Context, obj *models.Tag) (ret *int, err error) {
var res int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
res, err = gallery.CountByTagID(repo.Gallery(), obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
res, err = gallery.CountByTagID(ctx, r.repository.Gallery, obj.ID)
return err
}); err != nil {
return nil, err
@ -93,8 +93,8 @@ func (r *tagResolver) GalleryCount(ctx context.Context, obj *models.Tag) (ret *i
func (r *tagResolver) PerformerCount(ctx context.Context, obj *models.Tag) (ret *int, err error) {
var count int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
count, err = repo.Performer().CountByTagID(obj.ID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
count, err = r.repository.Performer.CountByTagID(ctx, obj.ID)
return err
}); err != nil {
return nil, err

View File

@ -132,7 +132,9 @@ func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input ConfigGen
}
// validate changing VideoFileNamingAlgorithm
if err := manager.ValidateVideoFileNamingAlgorithm(r.txnManager, *input.VideoFileNamingAlgorithm); err != nil {
if err := r.withTxn(context.TODO(), func(ctx context.Context) error {
return manager.ValidateVideoFileNamingAlgorithm(ctx, r.repository.Scene, *input.VideoFileNamingAlgorithm)
}); err != nil {
return makeConfigGeneralResult(), err
}

View File

@ -11,6 +11,7 @@ import (
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/pkg/file"
"github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/hash/md5"
"github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/models"
@ -21,8 +22,8 @@ import (
)
func (r *mutationResolver) getGallery(ctx context.Context, id int) (ret *models.Gallery, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Gallery().Find(id)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Gallery.Find(ctx, id)
return err
}); err != nil {
return nil, err
@ -80,26 +81,26 @@ func (r *mutationResolver) GalleryCreate(ctx context.Context, input GalleryCreat
// Start the transaction and save the gallery
var gallery *models.Gallery
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Gallery()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Gallery
var err error
gallery, err = qb.Create(newGallery)
gallery, err = qb.Create(ctx, newGallery)
if err != nil {
return err
}
// Save the performers
if err := r.updateGalleryPerformers(qb, gallery.ID, input.PerformerIds); err != nil {
if err := r.updateGalleryPerformers(ctx, qb, gallery.ID, input.PerformerIds); err != nil {
return err
}
// Save the tags
if err := r.updateGalleryTags(qb, gallery.ID, input.TagIds); err != nil {
if err := r.updateGalleryTags(ctx, qb, gallery.ID, input.TagIds); err != nil {
return err
}
// Save the scenes
if err := r.updateGalleryScenes(qb, gallery.ID, input.SceneIds); err != nil {
if err := r.updateGalleryScenes(ctx, qb, gallery.ID, input.SceneIds); err != nil {
return err
}
@ -112,28 +113,32 @@ func (r *mutationResolver) GalleryCreate(ctx context.Context, input GalleryCreat
return r.getGallery(ctx, gallery.ID)
}
func (r *mutationResolver) updateGalleryPerformers(qb models.GalleryReaderWriter, galleryID int, performerIDs []string) error {
func (r *mutationResolver) updateGalleryPerformers(ctx context.Context, qb gallery.PerformerUpdater, galleryID int, performerIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(performerIDs)
if err != nil {
return err
}
return qb.UpdatePerformers(galleryID, ids)
return qb.UpdatePerformers(ctx, galleryID, ids)
}
func (r *mutationResolver) updateGalleryTags(qb models.GalleryReaderWriter, galleryID int, tagIDs []string) error {
func (r *mutationResolver) updateGalleryTags(ctx context.Context, qb gallery.TagUpdater, galleryID int, tagIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(tagIDs)
if err != nil {
return err
}
return qb.UpdateTags(galleryID, ids)
return qb.UpdateTags(ctx, galleryID, ids)
}
func (r *mutationResolver) updateGalleryScenes(qb models.GalleryReaderWriter, galleryID int, sceneIDs []string) error {
type GallerySceneUpdater interface {
UpdateScenes(ctx context.Context, galleryID int, sceneIDs []int) error
}
func (r *mutationResolver) updateGalleryScenes(ctx context.Context, qb GallerySceneUpdater, galleryID int, sceneIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(sceneIDs)
if err != nil {
return err
}
return qb.UpdateScenes(galleryID, ids)
return qb.UpdateScenes(ctx, galleryID, ids)
}
func (r *mutationResolver) GalleryUpdate(ctx context.Context, input models.GalleryUpdateInput) (ret *models.Gallery, err error) {
@ -142,8 +147,8 @@ func (r *mutationResolver) GalleryUpdate(ctx context.Context, input models.Galle
}
// Start the transaction and save the gallery
if err := r.withTxn(ctx, func(repo models.Repository) error {
ret, err = r.galleryUpdate(input, translator, repo)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.galleryUpdate(ctx, input, translator)
return err
}); err != nil {
return nil, err
@ -158,13 +163,13 @@ func (r *mutationResolver) GalleriesUpdate(ctx context.Context, input []*models.
inputMaps := getUpdateInputMaps(ctx)
// Start the transaction and save the gallery
if err := r.withTxn(ctx, func(repo models.Repository) error {
if err := r.withTxn(ctx, func(ctx context.Context) error {
for i, gallery := range input {
translator := changesetTranslator{
inputMap: inputMaps[i],
}
thisGallery, err := r.galleryUpdate(*gallery, translator, repo)
thisGallery, err := r.galleryUpdate(ctx, *gallery, translator)
if err != nil {
return err
}
@ -196,8 +201,8 @@ func (r *mutationResolver) GalleriesUpdate(ctx context.Context, input []*models.
return newRet, nil
}
func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Gallery, error) {
qb := repo.Gallery()
func (r *mutationResolver) galleryUpdate(ctx context.Context, input models.GalleryUpdateInput, translator changesetTranslator) (*models.Gallery, error) {
qb := r.repository.Gallery
// Populate gallery from the input
galleryID, err := strconv.Atoi(input.ID)
@ -205,7 +210,7 @@ func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, transl
return nil, err
}
originalGallery, err := qb.Find(galleryID)
originalGallery, err := qb.Find(ctx, galleryID)
if err != nil {
return nil, err
}
@ -244,28 +249,28 @@ func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, transl
// gallery scene is set from the scene only
gallery, err := qb.UpdatePartial(updatedGallery)
gallery, err := qb.UpdatePartial(ctx, updatedGallery)
if err != nil {
return nil, err
}
// Save the performers
if translator.hasField("performer_ids") {
if err := r.updateGalleryPerformers(qb, galleryID, input.PerformerIds); err != nil {
if err := r.updateGalleryPerformers(ctx, qb, galleryID, input.PerformerIds); err != nil {
return nil, err
}
}
// Save the tags
if translator.hasField("tag_ids") {
if err := r.updateGalleryTags(qb, galleryID, input.TagIds); err != nil {
if err := r.updateGalleryTags(ctx, qb, galleryID, input.TagIds); err != nil {
return nil, err
}
}
// Save the scenes
if translator.hasField("scene_ids") {
if err := r.updateGalleryScenes(qb, galleryID, input.SceneIds); err != nil {
if err := r.updateGalleryScenes(ctx, qb, galleryID, input.SceneIds); err != nil {
return nil, err
}
}
@ -295,14 +300,14 @@ func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input BulkGall
ret := []*models.Gallery{}
// Start the transaction and save the galleries
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Gallery()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Gallery
for _, galleryIDStr := range input.Ids {
galleryID, _ := strconv.Atoi(galleryIDStr)
updatedGallery.ID = galleryID
gallery, err := qb.UpdatePartial(updatedGallery)
gallery, err := qb.UpdatePartial(ctx, updatedGallery)
if err != nil {
return err
}
@ -311,36 +316,36 @@ func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input BulkGall
// Save the performers
if translator.hasField("performer_ids") {
performerIDs, err := adjustGalleryPerformerIDs(qb, galleryID, *input.PerformerIds)
performerIDs, err := adjustGalleryPerformerIDs(ctx, qb, galleryID, *input.PerformerIds)
if err != nil {
return err
}
if err := qb.UpdatePerformers(galleryID, performerIDs); err != nil {
if err := qb.UpdatePerformers(ctx, galleryID, performerIDs); err != nil {
return err
}
}
// Save the tags
if translator.hasField("tag_ids") {
tagIDs, err := adjustGalleryTagIDs(qb, galleryID, *input.TagIds)
tagIDs, err := adjustGalleryTagIDs(ctx, qb, galleryID, *input.TagIds)
if err != nil {
return err
}
if err := qb.UpdateTags(galleryID, tagIDs); err != nil {
if err := qb.UpdateTags(ctx, galleryID, tagIDs); err != nil {
return err
}
}
// Save the scenes
if translator.hasField("scene_ids") {
sceneIDs, err := adjustGallerySceneIDs(qb, galleryID, *input.SceneIds)
sceneIDs, err := adjustGallerySceneIDs(ctx, qb, galleryID, *input.SceneIds)
if err != nil {
return err
}
if err := qb.UpdateScenes(galleryID, sceneIDs); err != nil {
if err := qb.UpdateScenes(ctx, galleryID, sceneIDs); err != nil {
return err
}
}
@ -367,8 +372,20 @@ func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input BulkGall
return newRet, nil
}
func adjustGalleryPerformerIDs(qb models.GalleryReader, galleryID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetPerformerIDs(galleryID)
type GalleryPerformerGetter interface {
GetPerformerIDs(ctx context.Context, galleryID int) ([]int, error)
}
type GalleryTagGetter interface {
GetTagIDs(ctx context.Context, galleryID int) ([]int, error)
}
type GallerySceneGetter interface {
GetSceneIDs(ctx context.Context, galleryID int) ([]int, error)
}
func adjustGalleryPerformerIDs(ctx context.Context, qb GalleryPerformerGetter, galleryID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetPerformerIDs(ctx, galleryID)
if err != nil {
return nil, err
}
@ -376,8 +393,8 @@ func adjustGalleryPerformerIDs(qb models.GalleryReader, galleryID int, ids BulkU
return adjustIDs(ret, ids), nil
}
func adjustGalleryTagIDs(qb models.GalleryReader, galleryID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetTagIDs(galleryID)
func adjustGalleryTagIDs(ctx context.Context, qb GalleryTagGetter, galleryID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetTagIDs(ctx, galleryID)
if err != nil {
return nil, err
}
@ -385,8 +402,8 @@ func adjustGalleryTagIDs(qb models.GalleryReader, galleryID int, ids BulkUpdateI
return adjustIDs(ret, ids), nil
}
func adjustGallerySceneIDs(qb models.GalleryReader, galleryID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetSceneIDs(galleryID)
func adjustGallerySceneIDs(ctx context.Context, qb GallerySceneGetter, galleryID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetSceneIDs(ctx, galleryID)
if err != nil {
return nil, err
}
@ -410,12 +427,12 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall
deleteGenerated := utils.IsTrue(input.DeleteGenerated)
deleteFile := utils.IsTrue(input.DeleteFile)
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Gallery()
iqb := repo.Image()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Gallery
iqb := r.repository.Image
for _, id := range galleryIDs {
gallery, err := qb.Find(id)
gallery, err := qb.Find(ctx, id)
if err != nil {
return err
}
@ -428,13 +445,13 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall
// if this is a zip-based gallery, delete the images as well first
if gallery.Zip {
imgs, err := iqb.FindByGalleryID(id)
imgs, err := iqb.FindByGalleryID(ctx, id)
if err != nil {
return err
}
for _, img := range imgs {
if err := image.Destroy(img, iqb, fileDeleter, deleteGenerated, false); err != nil {
if err := image.Destroy(ctx, img, iqb, fileDeleter, deleteGenerated, false); err != nil {
return err
}
@ -448,19 +465,19 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall
}
} else if deleteFile {
// Delete image if it is only attached to this gallery
imgs, err := iqb.FindByGalleryID(id)
imgs, err := iqb.FindByGalleryID(ctx, id)
if err != nil {
return err
}
for _, img := range imgs {
imgGalleries, err := qb.FindByImageID(img.ID)
imgGalleries, err := qb.FindByImageID(ctx, img.ID)
if err != nil {
return err
}
if len(imgGalleries) == 1 {
if err := image.Destroy(img, iqb, fileDeleter, deleteGenerated, deleteFile); err != nil {
if err := image.Destroy(ctx, img, iqb, fileDeleter, deleteGenerated, deleteFile); err != nil {
return err
}
@ -472,7 +489,7 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall
// don't do this with the file deleter
}
if err := qb.Destroy(id); err != nil {
if err := qb.Destroy(ctx, id); err != nil {
return err
}
}
@ -537,9 +554,9 @@ func (r *mutationResolver) AddGalleryImages(ctx context.Context, input GalleryAd
return false, err
}
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Gallery()
gallery, err := qb.Find(galleryID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Gallery
gallery, err := qb.Find(ctx, galleryID)
if err != nil {
return err
}
@ -552,13 +569,13 @@ func (r *mutationResolver) AddGalleryImages(ctx context.Context, input GalleryAd
return errors.New("cannot modify zip gallery images")
}
newIDs, err := qb.GetImageIDs(galleryID)
newIDs, err := qb.GetImageIDs(ctx, galleryID)
if err != nil {
return err
}
newIDs = intslice.IntAppendUniques(newIDs, imageIDs)
return qb.UpdateImages(galleryID, newIDs)
return qb.UpdateImages(ctx, galleryID, newIDs)
}); err != nil {
return false, err
}
@ -577,9 +594,9 @@ func (r *mutationResolver) RemoveGalleryImages(ctx context.Context, input Galler
return false, err
}
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Gallery()
gallery, err := qb.Find(galleryID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Gallery
gallery, err := qb.Find(ctx, galleryID)
if err != nil {
return err
}
@ -592,13 +609,13 @@ func (r *mutationResolver) RemoveGalleryImages(ctx context.Context, input Galler
return errors.New("cannot modify zip gallery images")
}
newIDs, err := qb.GetImageIDs(galleryID)
newIDs, err := qb.GetImageIDs(ctx, galleryID)
if err != nil {
return err
}
newIDs = intslice.IntExclude(newIDs, imageIDs)
return qb.UpdateImages(galleryID, newIDs)
return qb.UpdateImages(ctx, galleryID, newIDs)
}); err != nil {
return false, err
}

View File

@ -16,8 +16,8 @@ import (
)
func (r *mutationResolver) getImage(ctx context.Context, id int) (ret *models.Image, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Image().Find(id)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Image.Find(ctx, id)
return err
}); err != nil {
return nil, err
@ -32,8 +32,8 @@ func (r *mutationResolver) ImageUpdate(ctx context.Context, input ImageUpdateInp
}
// Start the transaction and save the image
if err := r.withTxn(ctx, func(repo models.Repository) error {
ret, err = r.imageUpdate(input, translator, repo)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.imageUpdate(ctx, input, translator)
return err
}); err != nil {
return nil, err
@ -48,13 +48,13 @@ func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*ImageUpdat
inputMaps := getUpdateInputMaps(ctx)
// Start the transaction and save the image
if err := r.withTxn(ctx, func(repo models.Repository) error {
if err := r.withTxn(ctx, func(ctx context.Context) error {
for i, image := range input {
translator := changesetTranslator{
inputMap: inputMaps[i],
}
thisImage, err := r.imageUpdate(*image, translator, repo)
thisImage, err := r.imageUpdate(ctx, *image, translator)
if err != nil {
return err
}
@ -86,7 +86,7 @@ func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*ImageUpdat
return newRet, nil
}
func (r *mutationResolver) imageUpdate(input ImageUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Image, error) {
func (r *mutationResolver) imageUpdate(ctx context.Context, input ImageUpdateInput, translator changesetTranslator) (*models.Image, error) {
// Populate image from the input
imageID, err := strconv.Atoi(input.ID)
if err != nil {
@ -104,28 +104,28 @@ func (r *mutationResolver) imageUpdate(input ImageUpdateInput, translator change
updatedImage.StudioID = translator.nullInt64FromString(input.StudioID, "studio_id")
updatedImage.Organized = input.Organized
qb := repo.Image()
image, err := qb.Update(updatedImage)
qb := r.repository.Image
image, err := qb.Update(ctx, updatedImage)
if err != nil {
return nil, err
}
if translator.hasField("gallery_ids") {
if err := r.updateImageGalleries(qb, imageID, input.GalleryIds); err != nil {
if err := r.updateImageGalleries(ctx, imageID, input.GalleryIds); err != nil {
return nil, err
}
}
// Save the performers
if translator.hasField("performer_ids") {
if err := r.updateImagePerformers(qb, imageID, input.PerformerIds); err != nil {
if err := r.updateImagePerformers(ctx, imageID, input.PerformerIds); err != nil {
return nil, err
}
}
// Save the tags
if translator.hasField("tag_ids") {
if err := r.updateImageTags(qb, imageID, input.TagIds); err != nil {
if err := r.updateImageTags(ctx, imageID, input.TagIds); err != nil {
return nil, err
}
}
@ -133,28 +133,28 @@ func (r *mutationResolver) imageUpdate(input ImageUpdateInput, translator change
return image, nil
}
func (r *mutationResolver) updateImageGalleries(qb models.ImageReaderWriter, imageID int, galleryIDs []string) error {
func (r *mutationResolver) updateImageGalleries(ctx context.Context, imageID int, galleryIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(galleryIDs)
if err != nil {
return err
}
return qb.UpdateGalleries(imageID, ids)
return r.repository.Image.UpdateGalleries(ctx, imageID, ids)
}
func (r *mutationResolver) updateImagePerformers(qb models.ImageReaderWriter, imageID int, performerIDs []string) error {
func (r *mutationResolver) updateImagePerformers(ctx context.Context, imageID int, performerIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(performerIDs)
if err != nil {
return err
}
return qb.UpdatePerformers(imageID, ids)
return r.repository.Image.UpdatePerformers(ctx, imageID, ids)
}
func (r *mutationResolver) updateImageTags(qb models.ImageReaderWriter, imageID int, tagsIDs []string) error {
func (r *mutationResolver) updateImageTags(ctx context.Context, imageID int, tagsIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(tagsIDs)
if err != nil {
return err
}
return qb.UpdateTags(imageID, ids)
return r.repository.Image.UpdateTags(ctx, imageID, ids)
}
func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input BulkImageUpdateInput) (ret []*models.Image, err error) {
@ -180,13 +180,13 @@ func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input BulkImageU
updatedImage.Organized = input.Organized
// Start the transaction and save the image marker
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Image()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Image
for _, imageID := range imageIDs {
updatedImage.ID = imageID
image, err := qb.Update(updatedImage)
image, err := qb.Update(ctx, updatedImage)
if err != nil {
return err
}
@ -195,36 +195,36 @@ func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input BulkImageU
// Save the galleries
if translator.hasField("gallery_ids") {
galleryIDs, err := adjustImageGalleryIDs(qb, imageID, *input.GalleryIds)
galleryIDs, err := r.adjustImageGalleryIDs(ctx, imageID, *input.GalleryIds)
if err != nil {
return err
}
if err := qb.UpdateGalleries(imageID, galleryIDs); err != nil {
if err := qb.UpdateGalleries(ctx, imageID, galleryIDs); err != nil {
return err
}
}
// Save the performers
if translator.hasField("performer_ids") {
performerIDs, err := adjustImagePerformerIDs(qb, imageID, *input.PerformerIds)
performerIDs, err := r.adjustImagePerformerIDs(ctx, imageID, *input.PerformerIds)
if err != nil {
return err
}
if err := qb.UpdatePerformers(imageID, performerIDs); err != nil {
if err := qb.UpdatePerformers(ctx, imageID, performerIDs); err != nil {
return err
}
}
// Save the tags
if translator.hasField("tag_ids") {
tagIDs, err := adjustImageTagIDs(qb, imageID, *input.TagIds)
tagIDs, err := r.adjustImageTagIDs(ctx, imageID, *input.TagIds)
if err != nil {
return err
}
if err := qb.UpdateTags(imageID, tagIDs); err != nil {
if err := qb.UpdateTags(ctx, imageID, tagIDs); err != nil {
return err
}
}
@ -251,8 +251,8 @@ func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input BulkImageU
return newRet, nil
}
func adjustImageGalleryIDs(qb models.ImageReader, imageID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetGalleryIDs(imageID)
func (r *mutationResolver) adjustImageGalleryIDs(ctx context.Context, imageID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = r.repository.Image.GetGalleryIDs(ctx, imageID)
if err != nil {
return nil, err
}
@ -260,8 +260,8 @@ func adjustImageGalleryIDs(qb models.ImageReader, imageID int, ids BulkUpdateIds
return adjustIDs(ret, ids), nil
}
func adjustImagePerformerIDs(qb models.ImageReader, imageID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetPerformerIDs(imageID)
func (r *mutationResolver) adjustImagePerformerIDs(ctx context.Context, imageID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = r.repository.Image.GetPerformerIDs(ctx, imageID)
if err != nil {
return nil, err
}
@ -269,8 +269,8 @@ func adjustImagePerformerIDs(qb models.ImageReader, imageID int, ids BulkUpdateI
return adjustIDs(ret, ids), nil
}
func adjustImageTagIDs(qb models.ImageReader, imageID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetTagIDs(imageID)
func (r *mutationResolver) adjustImageTagIDs(ctx context.Context, imageID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = r.repository.Image.GetTagIDs(ctx, imageID)
if err != nil {
return nil, err
}
@ -289,10 +289,10 @@ func (r *mutationResolver) ImageDestroy(ctx context.Context, input models.ImageD
Deleter: *file.NewDeleter(),
Paths: manager.GetInstance().Paths,
}
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Image()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Image
i, err = qb.Find(imageID)
i, err = r.repository.Image.Find(ctx, imageID)
if err != nil {
return err
}
@ -301,7 +301,7 @@ func (r *mutationResolver) ImageDestroy(ctx context.Context, input models.ImageD
return fmt.Errorf("image with id %d not found", imageID)
}
return image.Destroy(i, qb, fileDeleter, utils.IsTrue(input.DeleteGenerated), utils.IsTrue(input.DeleteFile))
return image.Destroy(ctx, i, qb, fileDeleter, utils.IsTrue(input.DeleteGenerated), utils.IsTrue(input.DeleteFile))
}); err != nil {
fileDeleter.Rollback()
return false, err
@ -331,12 +331,12 @@ func (r *mutationResolver) ImagesDestroy(ctx context.Context, input models.Image
Deleter: *file.NewDeleter(),
Paths: manager.GetInstance().Paths,
}
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Image()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Image
for _, imageID := range imageIDs {
i, err := qb.Find(imageID)
i, err := qb.Find(ctx, imageID)
if err != nil {
return err
}
@ -347,7 +347,7 @@ func (r *mutationResolver) ImagesDestroy(ctx context.Context, input models.Image
images = append(images, i)
if err := image.Destroy(i, qb, fileDeleter, utils.IsTrue(input.DeleteGenerated), utils.IsTrue(input.DeleteFile)); err != nil {
if err := image.Destroy(ctx, i, qb, fileDeleter, utils.IsTrue(input.DeleteGenerated), utils.IsTrue(input.DeleteFile)); err != nil {
return err
}
}
@ -379,10 +379,10 @@ func (r *mutationResolver) ImageIncrementO(ctx context.Context, id string) (ret
return 0, err
}
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Image()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Image
ret, err = qb.IncrementOCounter(imageID)
ret, err = qb.IncrementOCounter(ctx, imageID)
return err
}); err != nil {
return 0, err
@ -397,10 +397,10 @@ func (r *mutationResolver) ImageDecrementO(ctx context.Context, id string) (ret
return 0, err
}
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Image()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Image
ret, err = qb.DecrementOCounter(imageID)
ret, err = qb.DecrementOCounter(ctx, imageID)
return err
}); err != nil {
return 0, err
@ -415,10 +415,10 @@ func (r *mutationResolver) ImageResetO(ctx context.Context, id string) (ret int,
return 0, err
}
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Image()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Image
ret, err = qb.ResetOCounter(imageID)
ret, err = qb.ResetOCounter(ctx, imageID)
return err
}); err != nil {
return 0, err

View File

@ -12,7 +12,6 @@ import (
"github.com/stashapp/stash/internal/identify"
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/fsutil"
"github.com/stashapp/stash/pkg/logger"
)
@ -111,6 +110,7 @@ func (r *mutationResolver) BackupDatabase(ctx context.Context, input BackupDatab
// if download is true, then backup to temporary file and return a link
download := input.Download != nil && *input.Download
mgr := manager.GetInstance()
database := mgr.Database
var backupPath string
if download {
if err := fsutil.EnsureDir(mgr.Paths.Generated.Downloads); err != nil {
@ -127,7 +127,7 @@ func (r *mutationResolver) BackupDatabase(ctx context.Context, input BackupDatab
backupPath = database.DatabaseBackupPath()
}
err := database.Backup(database.DB, backupPath)
err := database.Backup(backupPath)
if err != nil {
return nil, err
}

View File

@ -15,8 +15,8 @@ import (
)
func (r *mutationResolver) getMovie(ctx context.Context, id int) (ret *models.Movie, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Movie().Find(id)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Movie.Find(ctx, id)
return err
}); err != nil {
return nil, err
@ -100,16 +100,16 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input MovieCreateInp
// Start the transaction and save the movie
var movie *models.Movie
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Movie()
movie, err = qb.Create(newMovie)
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Movie
movie, err = qb.Create(ctx, newMovie)
if err != nil {
return err
}
// update image table
if len(frontimageData) > 0 {
if err := qb.UpdateImages(movie.ID, frontimageData, backimageData); err != nil {
if err := qb.UpdateImages(ctx, movie.ID, frontimageData, backimageData); err != nil {
return err
}
}
@ -174,9 +174,9 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input MovieUpdateInp
// Start the transaction and save the movie
var movie *models.Movie
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Movie()
movie, err = qb.Update(updatedMovie)
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Movie
movie, err = qb.Update(ctx, updatedMovie)
if err != nil {
return err
}
@ -184,13 +184,13 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input MovieUpdateInp
// update image table
if frontImageIncluded || backImageIncluded {
if !frontImageIncluded {
frontimageData, err = qb.GetFrontImage(updatedMovie.ID)
frontimageData, err = qb.GetFrontImage(ctx, updatedMovie.ID)
if err != nil {
return err
}
}
if !backImageIncluded {
backimageData, err = qb.GetBackImage(updatedMovie.ID)
backimageData, err = qb.GetBackImage(ctx, updatedMovie.ID)
if err != nil {
return err
}
@ -198,7 +198,7 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input MovieUpdateInp
if len(frontimageData) == 0 && len(backimageData) == 0 {
// both images are being nulled. Destroy them.
if err := qb.DestroyImages(movie.ID); err != nil {
if err := qb.DestroyImages(ctx, movie.ID); err != nil {
return err
}
} else {
@ -208,7 +208,7 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input MovieUpdateInp
frontimageData, _ = utils.ProcessImageInput(ctx, models.DefaultMovieImage)
}
if err := qb.UpdateImages(movie.ID, frontimageData, backimageData); err != nil {
if err := qb.UpdateImages(ctx, movie.ID, frontimageData, backimageData); err != nil {
return err
}
}
@ -245,13 +245,13 @@ func (r *mutationResolver) BulkMovieUpdate(ctx context.Context, input BulkMovieU
ret := []*models.Movie{}
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Movie()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Movie
for _, movieID := range movieIDs {
updatedMovie.ID = movieID
existing, err := qb.Find(movieID)
existing, err := qb.Find(ctx, movieID)
if err != nil {
return err
}
@ -260,7 +260,7 @@ func (r *mutationResolver) BulkMovieUpdate(ctx context.Context, input BulkMovieU
return fmt.Errorf("movie with id %d not found", movieID)
}
movie, err := qb.Update(updatedMovie)
movie, err := qb.Update(ctx, updatedMovie)
if err != nil {
return err
}
@ -294,8 +294,8 @@ func (r *mutationResolver) MovieDestroy(ctx context.Context, input MovieDestroyI
return false, err
}
if err := r.withTxn(ctx, func(repo models.Repository) error {
return repo.Movie().Destroy(id)
if err := r.withTxn(ctx, func(ctx context.Context) error {
return r.repository.Movie.Destroy(ctx, id)
}); err != nil {
return false, err
}
@ -311,10 +311,10 @@ func (r *mutationResolver) MoviesDestroy(ctx context.Context, movieIDs []string)
return false, err
}
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Movie()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Movie
for _, id := range ids {
if err := qb.Destroy(id); err != nil {
if err := qb.Destroy(ctx, id); err != nil {
return err
}
}

View File

@ -16,8 +16,8 @@ import (
)
func (r *mutationResolver) getPerformer(ctx context.Context, id int) (ret *models.Performer, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Performer().Find(id)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Performer.Find(ctx, id)
return err
}); err != nil {
return nil, err
@ -129,23 +129,23 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input PerformerC
// Start the transaction and save the performer
var performer *models.Performer
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Performer()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Performer
performer, err = qb.Create(newPerformer)
performer, err = qb.Create(ctx, newPerformer)
if err != nil {
return err
}
if len(input.TagIds) > 0 {
if err := r.updatePerformerTags(qb, performer.ID, input.TagIds); err != nil {
if err := r.updatePerformerTags(ctx, performer.ID, input.TagIds); err != nil {
return err
}
}
// update image table
if len(imageData) > 0 {
if err := qb.UpdateImage(performer.ID, imageData); err != nil {
if err := qb.UpdateImage(ctx, performer.ID, imageData); err != nil {
return err
}
}
@ -153,7 +153,7 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input PerformerC
// Save the stash_ids
if input.StashIds != nil {
stashIDJoins := models.StashIDsFromInput(input.StashIds)
if err := qb.UpdateStashIDs(performer.ID, stashIDJoins); err != nil {
if err := qb.UpdateStashIDs(ctx, performer.ID, stashIDJoins); err != nil {
return err
}
}
@ -230,11 +230,11 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input PerformerU
// Start the transaction and save the p
var p *models.Performer
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Performer()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Performer
// need to get existing performer
existing, err := qb.Find(updatedPerformer.ID)
existing, err := qb.Find(ctx, updatedPerformer.ID)
if err != nil {
return err
}
@ -249,26 +249,26 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input PerformerU
}
}
p, err = qb.Update(updatedPerformer)
p, err = qb.Update(ctx, updatedPerformer)
if err != nil {
return err
}
// Save the tags
if translator.hasField("tag_ids") {
if err := r.updatePerformerTags(qb, p.ID, input.TagIds); err != nil {
if err := r.updatePerformerTags(ctx, p.ID, input.TagIds); err != nil {
return err
}
}
// update image table
if len(imageData) > 0 {
if err := qb.UpdateImage(p.ID, imageData); err != nil {
if err := qb.UpdateImage(ctx, p.ID, imageData); err != nil {
return err
}
} else if imageIncluded {
// must be unsetting
if err := qb.DestroyImage(p.ID); err != nil {
if err := qb.DestroyImage(ctx, p.ID); err != nil {
return err
}
}
@ -276,7 +276,7 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input PerformerU
// Save the stash_ids
if translator.hasField("stash_ids") {
stashIDJoins := models.StashIDsFromInput(input.StashIds)
if err := qb.UpdateStashIDs(performerID, stashIDJoins); err != nil {
if err := qb.UpdateStashIDs(ctx, performerID, stashIDJoins); err != nil {
return err
}
}
@ -290,12 +290,12 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input PerformerU
return r.getPerformer(ctx, p.ID)
}
func (r *mutationResolver) updatePerformerTags(qb models.PerformerReaderWriter, performerID int, tagsIDs []string) error {
func (r *mutationResolver) updatePerformerTags(ctx context.Context, performerID int, tagsIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(tagsIDs)
if err != nil {
return err
}
return qb.UpdateTags(performerID, ids)
return r.repository.Performer.UpdateTags(ctx, performerID, ids)
}
func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPerformerUpdateInput) ([]*models.Performer, error) {
@ -348,14 +348,14 @@ func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPe
ret := []*models.Performer{}
// Start the transaction and save the scene marker
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Performer()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Performer
for _, performerID := range performerIDs {
updatedPerformer.ID = performerID
// need to get existing performer
existing, err := qb.Find(performerID)
existing, err := qb.Find(ctx, performerID)
if err != nil {
return err
}
@ -368,7 +368,7 @@ func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPe
return err
}
performer, err := qb.Update(updatedPerformer)
performer, err := qb.Update(ctx, updatedPerformer)
if err != nil {
return err
}
@ -377,12 +377,12 @@ func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPe
// Save the tags
if translator.hasField("tag_ids") {
tagIDs, err := adjustTagIDs(qb, performerID, *input.TagIds)
tagIDs, err := adjustTagIDs(ctx, qb, performerID, *input.TagIds)
if err != nil {
return err
}
if err := qb.UpdateTags(performerID, tagIDs); err != nil {
if err := qb.UpdateTags(ctx, performerID, tagIDs); err != nil {
return err
}
}
@ -415,8 +415,8 @@ func (r *mutationResolver) PerformerDestroy(ctx context.Context, input Performer
return false, err
}
if err := r.withTxn(ctx, func(repo models.Repository) error {
return repo.Performer().Destroy(id)
if err := r.withTxn(ctx, func(ctx context.Context) error {
return r.repository.Performer.Destroy(ctx, id)
}); err != nil {
return false, err
}
@ -432,10 +432,10 @@ func (r *mutationResolver) PerformersDestroy(ctx context.Context, performerIDs [
return false, err
}
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Performer()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Performer
for _, id := range ids {
if err := qb.Destroy(id); err != nil {
if err := qb.Destroy(ctx, id); err != nil {
return err
}
}

View File

@ -23,17 +23,17 @@ func (r *mutationResolver) SaveFilter(ctx context.Context, input SaveFilterInput
id = &idv
}
if err := r.withTxn(ctx, func(repo models.Repository) error {
if err := r.withTxn(ctx, func(ctx context.Context) error {
f := models.SavedFilter{
Mode: input.Mode,
Name: input.Name,
Filter: input.Filter,
}
if id == nil {
ret, err = repo.SavedFilter().Create(f)
ret, err = r.repository.SavedFilter.Create(ctx, f)
} else {
f.ID = *id
ret, err = repo.SavedFilter().Update(f)
ret, err = r.repository.SavedFilter.Update(ctx, f)
}
return err
}); err != nil {
@ -48,8 +48,8 @@ func (r *mutationResolver) DestroySavedFilter(ctx context.Context, input Destroy
return false, err
}
if err := r.withTxn(ctx, func(repo models.Repository) error {
return repo.SavedFilter().Destroy(id)
if err := r.withTxn(ctx, func(ctx context.Context) error {
return r.repository.SavedFilter.Destroy(ctx, id)
}); err != nil {
return false, err
}
@ -58,24 +58,24 @@ func (r *mutationResolver) DestroySavedFilter(ctx context.Context, input Destroy
}
func (r *mutationResolver) SetDefaultFilter(ctx context.Context, input SetDefaultFilterInput) (bool, error) {
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.SavedFilter()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.SavedFilter
if input.Filter == nil {
// clearing
def, err := qb.FindDefault(input.Mode)
def, err := qb.FindDefault(ctx, input.Mode)
if err != nil {
return err
}
if def != nil {
return qb.Destroy(def.ID)
return qb.Destroy(ctx, def.ID)
}
return nil
}
_, err := qb.SetDefault(models.SavedFilter{
_, err := qb.SetDefault(ctx, models.SavedFilter{
Mode: input.Mode,
Filter: *input.Filter,
})

View File

@ -19,8 +19,8 @@ import (
)
func (r *mutationResolver) getScene(ctx context.Context, id int) (ret *models.Scene, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Scene().Find(id)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Scene.Find(ctx, id)
return err
}); err != nil {
return nil, err
@ -35,8 +35,8 @@ func (r *mutationResolver) SceneUpdate(ctx context.Context, input models.SceneUp
}
// Start the transaction and save the scene
if err := r.withTxn(ctx, func(repo models.Repository) error {
ret, err = r.sceneUpdate(ctx, input, translator, repo)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.sceneUpdate(ctx, input, translator)
return err
}); err != nil {
return nil, err
@ -50,13 +50,13 @@ func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.Sce
inputMaps := getUpdateInputMaps(ctx)
// Start the transaction and save the scene
if err := r.withTxn(ctx, func(repo models.Repository) error {
if err := r.withTxn(ctx, func(ctx context.Context) error {
for i, scene := range input {
translator := changesetTranslator{
inputMap: inputMaps[i],
}
thisScene, err := r.sceneUpdate(ctx, *scene, translator, repo)
thisScene, err := r.sceneUpdate(ctx, *scene, translator)
ret = append(ret, thisScene)
if err != nil {
@ -89,7 +89,7 @@ func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.Sce
return newRet, nil
}
func (r *mutationResolver) sceneUpdate(ctx context.Context, input models.SceneUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Scene, error) {
func (r *mutationResolver) sceneUpdate(ctx context.Context, input models.SceneUpdateInput, translator changesetTranslator) (*models.Scene, error) {
// Populate scene from the input
sceneID, err := strconv.Atoi(input.ID)
if err != nil {
@ -122,43 +122,43 @@ func (r *mutationResolver) sceneUpdate(ctx context.Context, input models.SceneUp
// update the cover after updating the scene
}
qb := repo.Scene()
s, err := qb.Update(updatedScene)
qb := r.repository.Scene
s, err := qb.Update(ctx, updatedScene)
if err != nil {
return nil, err
}
// update cover table
if len(coverImageData) > 0 {
if err := qb.UpdateCover(sceneID, coverImageData); err != nil {
if err := qb.UpdateCover(ctx, sceneID, coverImageData); err != nil {
return nil, err
}
}
// Save the performers
if translator.hasField("performer_ids") {
if err := r.updateScenePerformers(qb, sceneID, input.PerformerIds); err != nil {
if err := r.updateScenePerformers(ctx, sceneID, input.PerformerIds); err != nil {
return nil, err
}
}
// Save the movies
if translator.hasField("movies") {
if err := r.updateSceneMovies(qb, sceneID, input.Movies); err != nil {
if err := r.updateSceneMovies(ctx, sceneID, input.Movies); err != nil {
return nil, err
}
}
// Save the tags
if translator.hasField("tag_ids") {
if err := r.updateSceneTags(qb, sceneID, input.TagIds); err != nil {
if err := r.updateSceneTags(ctx, sceneID, input.TagIds); err != nil {
return nil, err
}
}
// Save the galleries
if translator.hasField("gallery_ids") {
if err := r.updateSceneGalleries(qb, sceneID, input.GalleryIds); err != nil {
if err := r.updateSceneGalleries(ctx, sceneID, input.GalleryIds); err != nil {
return nil, err
}
}
@ -166,7 +166,7 @@ func (r *mutationResolver) sceneUpdate(ctx context.Context, input models.SceneUp
// Save the stash_ids
if translator.hasField("stash_ids") {
stashIDJoins := models.StashIDsFromInput(input.StashIds)
if err := qb.UpdateStashIDs(sceneID, stashIDJoins); err != nil {
if err := qb.UpdateStashIDs(ctx, sceneID, stashIDJoins); err != nil {
return nil, err
}
}
@ -182,15 +182,15 @@ func (r *mutationResolver) sceneUpdate(ctx context.Context, input models.SceneUp
return s, nil
}
func (r *mutationResolver) updateScenePerformers(qb models.SceneReaderWriter, sceneID int, performerIDs []string) error {
func (r *mutationResolver) updateScenePerformers(ctx context.Context, sceneID int, performerIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(performerIDs)
if err != nil {
return err
}
return qb.UpdatePerformers(sceneID, ids)
return r.repository.Scene.UpdatePerformers(ctx, sceneID, ids)
}
func (r *mutationResolver) updateSceneMovies(qb models.SceneReaderWriter, sceneID int, movies []*models.SceneMovieInput) error {
func (r *mutationResolver) updateSceneMovies(ctx context.Context, sceneID int, movies []*models.SceneMovieInput) error {
var movieJoins []models.MoviesScenes
for _, movie := range movies {
@ -213,23 +213,23 @@ func (r *mutationResolver) updateSceneMovies(qb models.SceneReaderWriter, sceneI
movieJoins = append(movieJoins, movieJoin)
}
return qb.UpdateMovies(sceneID, movieJoins)
return r.repository.Scene.UpdateMovies(ctx, sceneID, movieJoins)
}
func (r *mutationResolver) updateSceneTags(qb models.SceneReaderWriter, sceneID int, tagsIDs []string) error {
func (r *mutationResolver) updateSceneTags(ctx context.Context, sceneID int, tagsIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(tagsIDs)
if err != nil {
return err
}
return qb.UpdateTags(sceneID, ids)
return r.repository.Scene.UpdateTags(ctx, sceneID, ids)
}
func (r *mutationResolver) updateSceneGalleries(qb models.SceneReaderWriter, sceneID int, galleryIDs []string) error {
func (r *mutationResolver) updateSceneGalleries(ctx context.Context, sceneID int, galleryIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(galleryIDs)
if err != nil {
return err
}
return qb.UpdateGalleries(sceneID, ids)
return r.repository.Scene.UpdateGalleries(ctx, sceneID, ids)
}
func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input BulkSceneUpdateInput) ([]*models.Scene, error) {
@ -260,13 +260,13 @@ func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input BulkSceneU
ret := []*models.Scene{}
// Start the transaction and save the scene marker
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Scene()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Scene
for _, sceneID := range sceneIDs {
updatedScene.ID = sceneID
scene, err := qb.Update(updatedScene)
scene, err := qb.Update(ctx, updatedScene)
if err != nil {
return err
}
@ -275,48 +275,48 @@ func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input BulkSceneU
// Save the performers
if translator.hasField("performer_ids") {
performerIDs, err := adjustScenePerformerIDs(qb, sceneID, *input.PerformerIds)
performerIDs, err := r.adjustScenePerformerIDs(ctx, sceneID, *input.PerformerIds)
if err != nil {
return err
}
if err := qb.UpdatePerformers(sceneID, performerIDs); err != nil {
if err := qb.UpdatePerformers(ctx, sceneID, performerIDs); err != nil {
return err
}
}
// Save the tags
if translator.hasField("tag_ids") {
tagIDs, err := adjustTagIDs(qb, sceneID, *input.TagIds)
tagIDs, err := adjustTagIDs(ctx, qb, sceneID, *input.TagIds)
if err != nil {
return err
}
if err := qb.UpdateTags(sceneID, tagIDs); err != nil {
if err := qb.UpdateTags(ctx, sceneID, tagIDs); err != nil {
return err
}
}
// Save the galleries
if translator.hasField("gallery_ids") {
galleryIDs, err := adjustSceneGalleryIDs(qb, sceneID, *input.GalleryIds)
galleryIDs, err := r.adjustSceneGalleryIDs(ctx, sceneID, *input.GalleryIds)
if err != nil {
return err
}
if err := qb.UpdateGalleries(sceneID, galleryIDs); err != nil {
if err := qb.UpdateGalleries(ctx, sceneID, galleryIDs); err != nil {
return err
}
}
// Save the movies
if translator.hasField("movie_ids") {
movies, err := adjustSceneMovieIDs(qb, sceneID, *input.MovieIds)
movies, err := r.adjustSceneMovieIDs(ctx, sceneID, *input.MovieIds)
if err != nil {
return err
}
if err := qb.UpdateMovies(sceneID, movies); err != nil {
if err := qb.UpdateMovies(ctx, sceneID, movies); err != nil {
return err
}
}
@ -380,8 +380,8 @@ func adjustIDs(existingIDs []int, updateIDs BulkUpdateIds) []int {
return existingIDs
}
func adjustScenePerformerIDs(qb models.SceneReader, sceneID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetPerformerIDs(sceneID)
func (r *mutationResolver) adjustScenePerformerIDs(ctx context.Context, sceneID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = r.repository.Scene.GetPerformerIDs(ctx, sceneID)
if err != nil {
return nil, err
}
@ -390,11 +390,11 @@ func adjustScenePerformerIDs(qb models.SceneReader, sceneID int, ids BulkUpdateI
}
type tagIDsGetter interface {
GetTagIDs(id int) ([]int, error)
GetTagIDs(ctx context.Context, id int) ([]int, error)
}
func adjustTagIDs(qb tagIDsGetter, sceneID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetTagIDs(sceneID)
func adjustTagIDs(ctx context.Context, qb tagIDsGetter, sceneID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetTagIDs(ctx, sceneID)
if err != nil {
return nil, err
}
@ -402,8 +402,8 @@ func adjustTagIDs(qb tagIDsGetter, sceneID int, ids BulkUpdateIds) (ret []int, e
return adjustIDs(ret, ids), nil
}
func adjustSceneGalleryIDs(qb models.SceneReader, sceneID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetGalleryIDs(sceneID)
func (r *mutationResolver) adjustSceneGalleryIDs(ctx context.Context, sceneID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = r.repository.Scene.GetGalleryIDs(ctx, sceneID)
if err != nil {
return nil, err
}
@ -411,8 +411,8 @@ func adjustSceneGalleryIDs(qb models.SceneReader, sceneID int, ids BulkUpdateIds
return adjustIDs(ret, ids), nil
}
func adjustSceneMovieIDs(qb models.SceneReader, sceneID int, updateIDs BulkUpdateIds) ([]models.MoviesScenes, error) {
existingMovies, err := qb.GetMovies(sceneID)
func (r *mutationResolver) adjustSceneMovieIDs(ctx context.Context, sceneID int, updateIDs BulkUpdateIds) ([]models.MoviesScenes, error) {
existingMovies, err := r.repository.Scene.GetMovies(ctx, sceneID)
if err != nil {
return nil, err
}
@ -471,10 +471,10 @@ func (r *mutationResolver) SceneDestroy(ctx context.Context, input models.SceneD
deleteGenerated := utils.IsTrue(input.DeleteGenerated)
deleteFile := utils.IsTrue(input.DeleteFile)
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Scene()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Scene
var err error
s, err = qb.Find(sceneID)
s, err = qb.Find(ctx, sceneID)
if err != nil {
return err
}
@ -486,7 +486,7 @@ func (r *mutationResolver) SceneDestroy(ctx context.Context, input models.SceneD
// kill any running encoders
manager.KillRunningStreams(s, fileNamingAlgo)
return scene.Destroy(s, repo, fileDeleter, deleteGenerated, deleteFile)
return scene.Destroy(ctx, s, r.repository.Scene, r.repository.SceneMarker, fileDeleter, deleteGenerated, deleteFile)
}); err != nil {
fileDeleter.Rollback()
return false, err
@ -519,13 +519,13 @@ func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.Scene
deleteGenerated := utils.IsTrue(input.DeleteGenerated)
deleteFile := utils.IsTrue(input.DeleteFile)
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Scene()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Scene
for _, id := range input.Ids {
sceneID, _ := strconv.Atoi(id)
s, err := qb.Find(sceneID)
s, err := qb.Find(ctx, sceneID)
if err != nil {
return err
}
@ -536,7 +536,7 @@ func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.Scene
// kill any running encoders
manager.KillRunningStreams(s, fileNamingAlgo)
if err := scene.Destroy(s, repo, fileDeleter, deleteGenerated, deleteFile); err != nil {
if err := scene.Destroy(ctx, s, r.repository.Scene, r.repository.SceneMarker, fileDeleter, deleteGenerated, deleteFile); err != nil {
return err
}
}
@ -564,8 +564,8 @@ func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.Scene
}
func (r *mutationResolver) getSceneMarker(ctx context.Context, id int) (ret *models.SceneMarker, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.SceneMarker().Find(id)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.SceneMarker.Find(ctx, id)
return err
}); err != nil {
return nil, err
@ -666,11 +666,11 @@ func (r *mutationResolver) SceneMarkerDestroy(ctx context.Context, id string) (b
Paths: manager.GetInstance().Paths,
}
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.SceneMarker()
sqb := repo.Scene()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.SceneMarker
sqb := r.repository.Scene
marker, err := qb.Find(markerID)
marker, err := qb.Find(ctx, markerID)
if err != nil {
return err
@ -680,12 +680,12 @@ func (r *mutationResolver) SceneMarkerDestroy(ctx context.Context, id string) (b
return fmt.Errorf("scene marker with id %d not found", markerID)
}
s, err := sqb.Find(int(marker.SceneID.Int64))
s, err := sqb.Find(ctx, int(marker.SceneID.Int64))
if err != nil {
return err
}
return scene.DestroyMarker(s, marker, qb, fileDeleter)
return scene.DestroyMarker(ctx, s, marker, qb, fileDeleter)
}); err != nil {
fileDeleter.Rollback()
return false, err
@ -713,26 +713,26 @@ func (r *mutationResolver) changeMarker(ctx context.Context, changeType int, cha
}
// Start the transaction and save the scene marker
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.SceneMarker()
sqb := repo.Scene()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.SceneMarker
sqb := r.repository.Scene
var err error
switch changeType {
case create:
sceneMarker, err = qb.Create(changedMarker)
sceneMarker, err = qb.Create(ctx, changedMarker)
case update:
// check to see if timestamp was changed
existingMarker, err = qb.Find(changedMarker.ID)
existingMarker, err = qb.Find(ctx, changedMarker.ID)
if err != nil {
return err
}
sceneMarker, err = qb.Update(changedMarker)
sceneMarker, err = qb.Update(ctx, changedMarker)
if err != nil {
return err
}
s, err = sqb.Find(int(existingMarker.SceneID.Int64))
s, err = sqb.Find(ctx, int(existingMarker.SceneID.Int64))
}
if err != nil {
return err
@ -749,7 +749,7 @@ func (r *mutationResolver) changeMarker(ctx context.Context, changeType int, cha
// Save the marker tags
// If this tag is the primary tag, then let's not add it.
tagIDs = intslice.IntExclude(tagIDs, []int{changedMarker.PrimaryTagID})
return qb.UpdateTags(sceneMarker.ID, tagIDs)
return qb.UpdateTags(ctx, sceneMarker.ID, tagIDs)
}); err != nil {
fileDeleter.Rollback()
return nil, err
@ -766,10 +766,10 @@ func (r *mutationResolver) SceneIncrementO(ctx context.Context, id string) (ret
return 0, err
}
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Scene()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Scene
ret, err = qb.IncrementOCounter(sceneID)
ret, err = qb.IncrementOCounter(ctx, sceneID)
return err
}); err != nil {
return 0, err
@ -784,10 +784,10 @@ func (r *mutationResolver) SceneDecrementO(ctx context.Context, id string) (ret
return 0, err
}
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Scene()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Scene
ret, err = qb.DecrementOCounter(sceneID)
ret, err = qb.DecrementOCounter(ctx, sceneID)
return err
}); err != nil {
return 0, err
@ -802,10 +802,10 @@ func (r *mutationResolver) SceneResetO(ctx context.Context, id string) (ret int,
return 0, err
}
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Scene()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Scene
ret, err = qb.ResetOCounter(sceneID)
ret, err = qb.ResetOCounter(ctx, sceneID)
return err
}); err != nil {
return 0, err

View File

@ -7,10 +7,18 @@ import (
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scraper/stashbox"
)
func (r *Resolver) stashboxRepository() stashbox.Repository {
return stashbox.Repository{
Scene: r.repository.Scene,
Performer: r.repository.Performer,
Tag: r.repository.Tag,
Studio: r.repository.Studio,
}
}
func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input StashBoxFingerprintSubmissionInput) (bool, error) {
boxes := config.GetInstance().GetStashBoxes()
@ -18,7 +26,7 @@ func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input
return false, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex)
}
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager)
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository())
return client.SubmitStashBoxFingerprints(ctx, input.SceneIds, boxes[input.StashBoxIndex].Endpoint)
}
@ -35,7 +43,7 @@ func (r *mutationResolver) SubmitStashBoxSceneDraft(ctx context.Context, input S
return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex)
}
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager)
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository())
id, err := strconv.Atoi(input.ID)
if err != nil {
@ -43,9 +51,9 @@ func (r *mutationResolver) SubmitStashBoxSceneDraft(ctx context.Context, input S
}
var res *string
err = r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
qb := repo.Scene()
scene, err := qb.Find(id)
err = r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Scene
scene, err := qb.Find(ctx, id)
if err != nil {
return err
}
@ -65,7 +73,7 @@ func (r *mutationResolver) SubmitStashBoxPerformerDraft(ctx context.Context, inp
return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex)
}
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager)
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository())
id, err := strconv.Atoi(input.ID)
if err != nil {
@ -73,9 +81,9 @@ func (r *mutationResolver) SubmitStashBoxPerformerDraft(ctx context.Context, inp
}
var res *string
err = r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
qb := repo.Performer()
performer, err := qb.Find(id)
err = r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Performer
performer, err := qb.Find(ctx, id)
if err != nil {
return err
}

View File

@ -17,8 +17,8 @@ import (
)
func (r *mutationResolver) getStudio(ctx context.Context, id int) (ret *models.Studio, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Studio().Find(id)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Studio.Find(ctx, id)
return err
}); err != nil {
return nil, err
@ -72,18 +72,18 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input StudioCreateI
// Start the transaction and save the studio
var s *models.Studio
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Studio()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Studio
var err error
s, err = qb.Create(newStudio)
s, err = qb.Create(ctx, newStudio)
if err != nil {
return err
}
// update image table
if len(imageData) > 0 {
if err := qb.UpdateImage(s.ID, imageData); err != nil {
if err := qb.UpdateImage(ctx, s.ID, imageData); err != nil {
return err
}
}
@ -91,17 +91,17 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input StudioCreateI
// Save the stash_ids
if input.StashIds != nil {
stashIDJoins := models.StashIDsFromInput(input.StashIds)
if err := qb.UpdateStashIDs(s.ID, stashIDJoins); err != nil {
if err := qb.UpdateStashIDs(ctx, s.ID, stashIDJoins); err != nil {
return err
}
}
if len(input.Aliases) > 0 {
if err := studio.EnsureAliasesUnique(s.ID, input.Aliases, qb); err != nil {
if err := studio.EnsureAliasesUnique(ctx, s.ID, input.Aliases, qb); err != nil {
return err
}
if err := qb.UpdateAliases(s.ID, input.Aliases); err != nil {
if err := qb.UpdateAliases(ctx, s.ID, input.Aliases); err != nil {
return err
}
}
@ -155,27 +155,27 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input StudioUpdateI
// Start the transaction and save the studio
var s *models.Studio
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Studio()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Studio
if err := manager.ValidateModifyStudio(updatedStudio, qb); err != nil {
if err := manager.ValidateModifyStudio(ctx, updatedStudio, qb); err != nil {
return err
}
var err error
s, err = qb.Update(updatedStudio)
s, err = qb.Update(ctx, updatedStudio)
if err != nil {
return err
}
// update image table
if len(imageData) > 0 {
if err := qb.UpdateImage(s.ID, imageData); err != nil {
if err := qb.UpdateImage(ctx, s.ID, imageData); err != nil {
return err
}
} else if imageIncluded {
// must be unsetting
if err := qb.DestroyImage(s.ID); err != nil {
if err := qb.DestroyImage(ctx, s.ID); err != nil {
return err
}
}
@ -183,17 +183,17 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input StudioUpdateI
// Save the stash_ids
if translator.hasField("stash_ids") {
stashIDJoins := models.StashIDsFromInput(input.StashIds)
if err := qb.UpdateStashIDs(studioID, stashIDJoins); err != nil {
if err := qb.UpdateStashIDs(ctx, studioID, stashIDJoins); err != nil {
return err
}
}
if translator.hasField("aliases") {
if err := studio.EnsureAliasesUnique(studioID, input.Aliases, qb); err != nil {
if err := studio.EnsureAliasesUnique(ctx, studioID, input.Aliases, qb); err != nil {
return err
}
if err := qb.UpdateAliases(studioID, input.Aliases); err != nil {
if err := qb.UpdateAliases(ctx, studioID, input.Aliases); err != nil {
return err
}
}
@ -213,8 +213,8 @@ func (r *mutationResolver) StudioDestroy(ctx context.Context, input StudioDestro
return false, err
}
if err := r.withTxn(ctx, func(repo models.Repository) error {
return repo.Studio().Destroy(id)
if err := r.withTxn(ctx, func(ctx context.Context) error {
return r.repository.Studio.Destroy(ctx, id)
}); err != nil {
return false, err
}
@ -230,10 +230,10 @@ func (r *mutationResolver) StudiosDestroy(ctx context.Context, studioIDs []strin
return false, err
}
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Studio()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Studio
for _, id := range ids {
if err := qb.Destroy(id); err != nil {
if err := qb.Destroy(ctx, id); err != nil {
return err
}
}

View File

@ -15,8 +15,8 @@ import (
)
func (r *mutationResolver) getTag(ctx context.Context, id int) (ret *models.Tag, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Tag().Find(id)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Tag.Find(ctx, id)
return err
}); err != nil {
return nil, err
@ -68,44 +68,44 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input TagCreateInput)
// Start the transaction and save the tag
var t *models.Tag
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Tag()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Tag
// ensure name is unique
if err := tag.EnsureTagNameUnique(0, newTag.Name, qb); err != nil {
if err := tag.EnsureTagNameUnique(ctx, 0, newTag.Name, qb); err != nil {
return err
}
t, err = qb.Create(newTag)
t, err = qb.Create(ctx, newTag)
if err != nil {
return err
}
// update image table
if len(imageData) > 0 {
if err := qb.UpdateImage(t.ID, imageData); err != nil {
if err := qb.UpdateImage(ctx, t.ID, imageData); err != nil {
return err
}
}
if len(input.Aliases) > 0 {
if err := tag.EnsureAliasesUnique(t.ID, input.Aliases, qb); err != nil {
if err := tag.EnsureAliasesUnique(ctx, t.ID, input.Aliases, qb); err != nil {
return err
}
if err := qb.UpdateAliases(t.ID, input.Aliases); err != nil {
if err := qb.UpdateAliases(ctx, t.ID, input.Aliases); err != nil {
return err
}
}
if len(parentIDs) > 0 {
if err := qb.UpdateParentTags(t.ID, parentIDs); err != nil {
if err := qb.UpdateParentTags(ctx, t.ID, parentIDs); err != nil {
return err
}
}
if len(childIDs) > 0 {
if err := qb.UpdateChildTags(t.ID, childIDs); err != nil {
if err := qb.UpdateChildTags(ctx, t.ID, childIDs); err != nil {
return err
}
}
@ -113,7 +113,7 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input TagCreateInput)
// FIXME: This should be called before any changes are made, but
// requires a rewrite of ValidateHierarchy.
if len(parentIDs) > 0 || len(childIDs) > 0 {
if err := tag.ValidateHierarchy(t, parentIDs, childIDs, qb); err != nil {
if err := tag.ValidateHierarchy(ctx, t, parentIDs, childIDs, qb); err != nil {
return err
}
}
@ -168,11 +168,11 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input TagUpdateInput)
// Start the transaction and save the tag
var t *models.Tag
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Tag()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Tag
// ensure name is unique
t, err = qb.Find(tagID)
t, err = qb.Find(ctx, tagID)
if err != nil {
return err
}
@ -188,48 +188,48 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input TagUpdateInput)
}
if input.Name != nil && t.Name != *input.Name {
if err := tag.EnsureTagNameUnique(tagID, *input.Name, qb); err != nil {
if err := tag.EnsureTagNameUnique(ctx, tagID, *input.Name, qb); err != nil {
return err
}
updatedTag.Name = input.Name
}
t, err = qb.Update(updatedTag)
t, err = qb.Update(ctx, updatedTag)
if err != nil {
return err
}
// update image table
if len(imageData) > 0 {
if err := qb.UpdateImage(tagID, imageData); err != nil {
if err := qb.UpdateImage(ctx, tagID, imageData); err != nil {
return err
}
} else if imageIncluded {
// must be unsetting
if err := qb.DestroyImage(tagID); err != nil {
if err := qb.DestroyImage(ctx, tagID); err != nil {
return err
}
}
if translator.hasField("aliases") {
if err := tag.EnsureAliasesUnique(tagID, input.Aliases, qb); err != nil {
if err := tag.EnsureAliasesUnique(ctx, tagID, input.Aliases, qb); err != nil {
return err
}
if err := qb.UpdateAliases(tagID, input.Aliases); err != nil {
if err := qb.UpdateAliases(ctx, tagID, input.Aliases); err != nil {
return err
}
}
if parentIDs != nil {
if err := qb.UpdateParentTags(tagID, parentIDs); err != nil {
if err := qb.UpdateParentTags(ctx, tagID, parentIDs); err != nil {
return err
}
}
if childIDs != nil {
if err := qb.UpdateChildTags(tagID, childIDs); err != nil {
if err := qb.UpdateChildTags(ctx, tagID, childIDs); err != nil {
return err
}
}
@ -237,7 +237,7 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input TagUpdateInput)
// FIXME: This should be called before any changes are made, but
// requires a rewrite of ValidateHierarchy.
if parentIDs != nil || childIDs != nil {
if err := tag.ValidateHierarchy(t, parentIDs, childIDs, qb); err != nil {
if err := tag.ValidateHierarchy(ctx, t, parentIDs, childIDs, qb); err != nil {
logger.Errorf("Error saving tag: %s", err)
return err
}
@ -258,8 +258,8 @@ func (r *mutationResolver) TagDestroy(ctx context.Context, input TagDestroyInput
return false, err
}
if err := r.withTxn(ctx, func(repo models.Repository) error {
return repo.Tag().Destroy(tagID)
if err := r.withTxn(ctx, func(ctx context.Context) error {
return r.repository.Tag.Destroy(ctx, tagID)
}); err != nil {
return false, err
}
@ -275,10 +275,10 @@ func (r *mutationResolver) TagsDestroy(ctx context.Context, tagIDs []string) (bo
return false, err
}
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Tag()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Tag
for _, id := range ids {
if err := qb.Destroy(id); err != nil {
if err := qb.Destroy(ctx, id); err != nil {
return err
}
}
@ -311,11 +311,11 @@ func (r *mutationResolver) TagsMerge(ctx context.Context, input TagsMergeInput)
}
var t *models.Tag
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.Tag()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Tag
var err error
t, err = qb.Find(destination)
t, err = qb.Find(ctx, destination)
if err != nil {
return err
}
@ -324,25 +324,25 @@ func (r *mutationResolver) TagsMerge(ctx context.Context, input TagsMergeInput)
return fmt.Errorf("Tag with ID %d not found", destination)
}
parents, children, err := tag.MergeHierarchy(destination, source, qb)
parents, children, err := tag.MergeHierarchy(ctx, destination, source, qb)
if err != nil {
return err
}
if err = qb.Merge(source, destination); err != nil {
if err = qb.Merge(ctx, source, destination); err != nil {
return err
}
err = qb.UpdateParentTags(destination, parents)
err = qb.UpdateParentTags(ctx, destination, parents)
if err != nil {
return err
}
err = qb.UpdateChildTags(destination, children)
err = qb.UpdateChildTags(ctx, destination, children)
if err != nil {
return err
}
err = tag.ValidateHierarchy(t, parents, children, qb)
err = tag.ValidateHierarchy(ctx, t, parents, children, qb)
if err != nil {
logger.Errorf("Error merging tag: %s", err)
return err

View File

@ -16,17 +16,23 @@ import (
// TODO - move this into a common area
func newResolver() *Resolver {
return &Resolver{
txnManager: mocks.NewTransactionManager(),
txnManager: &mocks.TxnManager{},
repository: mocks.NewTxnRepository(),
hookExecutor: &mockHookExecutor{},
}
}
const tagName = "tagName"
const errTagName = "errTagName"
const (
tagName = "tagName"
errTagName = "errTagName"
const existingTagID = 1
const existingTagName = "existingTagName"
const newTagID = 2
existingTagID = 1
existingTagName = "existingTagName"
newTagID = 2
)
var testCtx = context.Background()
type mockHookExecutor struct{}
@ -36,7 +42,7 @@ func (*mockHookExecutor) ExecutePostHooks(ctx context.Context, id int, hookType
func TestTagCreate(t *testing.T) {
r := newResolver()
tagRW := r.txnManager.(*mocks.TransactionManager).Tag().(*mocks.TagReaderWriter)
tagRW := r.repository.Tag.(*mocks.TagReaderWriter)
pp := 1
findFilter := &models.FindFilterType{
@ -61,25 +67,25 @@ func TestTagCreate(t *testing.T) {
}
}
tagRW.On("Query", tagFilterForName(existingTagName), findFilter).Return([]*models.Tag{
tagRW.On("Query", testCtx, tagFilterForName(existingTagName), findFilter).Return([]*models.Tag{
{
ID: existingTagID,
Name: existingTagName,
},
}, 1, nil).Once()
tagRW.On("Query", tagFilterForName(errTagName), findFilter).Return(nil, 0, nil).Once()
tagRW.On("Query", tagFilterForAlias(errTagName), findFilter).Return(nil, 0, nil).Once()
tagRW.On("Query", testCtx, tagFilterForName(errTagName), findFilter).Return(nil, 0, nil).Once()
tagRW.On("Query", testCtx, tagFilterForAlias(errTagName), findFilter).Return(nil, 0, nil).Once()
expectedErr := errors.New("TagCreate error")
tagRW.On("Create", mock.AnythingOfType("models.Tag")).Return(nil, expectedErr)
tagRW.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(nil, expectedErr)
_, err := r.Mutation().TagCreate(context.TODO(), TagCreateInput{
_, err := r.Mutation().TagCreate(testCtx, TagCreateInput{
Name: existingTagName,
})
assert.NotNil(t, err)
_, err = r.Mutation().TagCreate(context.TODO(), TagCreateInput{
_, err = r.Mutation().TagCreate(testCtx, TagCreateInput{
Name: errTagName,
})
@ -87,18 +93,18 @@ func TestTagCreate(t *testing.T) {
tagRW.AssertExpectations(t)
r = newResolver()
tagRW = r.txnManager.(*mocks.TransactionManager).Tag().(*mocks.TagReaderWriter)
tagRW = r.repository.Tag.(*mocks.TagReaderWriter)
tagRW.On("Query", tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once()
tagRW.On("Query", tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once()
tagRW.On("Query", testCtx, tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once()
tagRW.On("Query", testCtx, tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once()
newTag := &models.Tag{
ID: newTagID,
Name: tagName,
}
tagRW.On("Create", mock.AnythingOfType("models.Tag")).Return(newTag, nil)
tagRW.On("Find", newTagID).Return(newTag, nil)
tagRW.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(newTag, nil)
tagRW.On("Find", testCtx, newTagID).Return(newTag, nil)
tag, err := r.Mutation().TagCreate(context.TODO(), TagCreateInput{
tag, err := r.Mutation().TagCreate(testCtx, TagCreateInput{
Name: tagName,
})

View File

@ -222,7 +222,7 @@ func makeConfigUIResult() map[string]interface{} {
}
func (r *queryResolver) ValidateStashBoxCredentials(ctx context.Context, input config.StashBoxInput) (*StashBoxValidationResult, error) {
client := stashbox.NewClient(models.StashBox{Endpoint: input.Endpoint, APIKey: input.APIKey}, r.txnManager)
client := stashbox.NewClient(models.StashBox{Endpoint: input.Endpoint, APIKey: input.APIKey}, r.txnManager, r.stashboxRepository())
user, err := client.GetUser(ctx)
valid := user != nil && user.Me != nil

View File

@ -13,8 +13,8 @@ func (r *queryResolver) FindGallery(ctx context.Context, id string) (ret *models
return nil, err
}
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Gallery().Find(idInt)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Gallery.Find(ctx, idInt)
return err
}); err != nil {
return nil, err
@ -24,8 +24,8 @@ func (r *queryResolver) FindGallery(ctx context.Context, id string) (ret *models
}
func (r *queryResolver) FindGalleries(ctx context.Context, galleryFilter *models.GalleryFilterType, filter *models.FindFilterType) (ret *FindGalleriesResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
galleries, total, err := repo.Gallery().Query(galleryFilter, filter)
if err := r.withTxn(ctx, func(ctx context.Context) error {
galleries, total, err := r.repository.Gallery.Query(ctx, galleryFilter, filter)
if err != nil {
return err
}

View File

@ -12,8 +12,8 @@ import (
func (r *queryResolver) FindImage(ctx context.Context, id *string, checksum *string) (*models.Image, error) {
var image *models.Image
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
qb := repo.Image()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Image
var err error
if id != nil {
@ -22,12 +22,12 @@ func (r *queryResolver) FindImage(ctx context.Context, id *string, checksum *str
return err
}
image, err = qb.Find(idInt)
image, err = qb.Find(ctx, idInt)
if err != nil {
return err
}
} else if checksum != nil {
image, err = qb.FindByChecksum(*checksum)
image, err = qb.FindByChecksum(ctx, *checksum)
}
return err
@ -39,12 +39,12 @@ func (r *queryResolver) FindImage(ctx context.Context, id *string, checksum *str
}
func (r *queryResolver) FindImages(ctx context.Context, imageFilter *models.ImageFilterType, imageIds []int, filter *models.FindFilterType) (ret *FindImagesResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
qb := repo.Image()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Image
fields := graphql.CollectAllFields(ctx)
result, err := qb.Query(models.ImageQueryOptions{
result, err := qb.Query(ctx, models.ImageQueryOptions{
QueryOptions: models.QueryOptions{
FindFilter: filter,
Count: stringslice.StrInclude(fields, "count"),
@ -57,7 +57,7 @@ func (r *queryResolver) FindImages(ctx context.Context, imageFilter *models.Imag
return err
}
images, err := result.Resolve()
images, err := result.Resolve(ctx)
if err != nil {
return err
}

View File

@ -13,8 +13,8 @@ func (r *queryResolver) FindMovie(ctx context.Context, id string) (ret *models.M
return nil, err
}
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Movie().Find(idInt)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Movie.Find(ctx, idInt)
return err
}); err != nil {
return nil, err
@ -24,8 +24,8 @@ func (r *queryResolver) FindMovie(ctx context.Context, id string) (ret *models.M
}
func (r *queryResolver) FindMovies(ctx context.Context, movieFilter *models.MovieFilterType, filter *models.FindFilterType) (ret *FindMoviesResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
movies, total, err := repo.Movie().Query(movieFilter, filter)
if err := r.withTxn(ctx, func(ctx context.Context) error {
movies, total, err := r.repository.Movie.Query(ctx, movieFilter, filter)
if err != nil {
return err
}
@ -44,8 +44,8 @@ func (r *queryResolver) FindMovies(ctx context.Context, movieFilter *models.Movi
}
func (r *queryResolver) AllMovies(ctx context.Context) (ret []*models.Movie, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Movie().All()
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Movie.All(ctx)
return err
}); err != nil {
return nil, err

View File

@ -13,8 +13,8 @@ func (r *queryResolver) FindPerformer(ctx context.Context, id string) (ret *mode
return nil, err
}
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Performer().Find(idInt)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Performer.Find(ctx, idInt)
return err
}); err != nil {
return nil, err
@ -24,8 +24,8 @@ func (r *queryResolver) FindPerformer(ctx context.Context, id string) (ret *mode
}
func (r *queryResolver) FindPerformers(ctx context.Context, performerFilter *models.PerformerFilterType, filter *models.FindFilterType) (ret *FindPerformersResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
performers, total, err := repo.Performer().Query(performerFilter, filter)
if err := r.withTxn(ctx, func(ctx context.Context) error {
performers, total, err := r.repository.Performer.Query(ctx, performerFilter, filter)
if err != nil {
return err
}
@ -43,8 +43,8 @@ func (r *queryResolver) FindPerformers(ctx context.Context, performerFilter *mod
}
func (r *queryResolver) AllPerformers(ctx context.Context) (ret []*models.Performer, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Performer().All()
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Performer.All(ctx)
return err
}); err != nil {
return nil, err

View File

@ -13,8 +13,8 @@ func (r *queryResolver) FindSavedFilter(ctx context.Context, id string) (ret *mo
return nil, err
}
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.SavedFilter().Find(idInt)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.SavedFilter.Find(ctx, idInt)
return err
}); err != nil {
return nil, err
@ -23,11 +23,11 @@ func (r *queryResolver) FindSavedFilter(ctx context.Context, id string) (ret *mo
}
func (r *queryResolver) FindSavedFilters(ctx context.Context, mode *models.FilterMode) (ret []*models.SavedFilter, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
if err := r.withTxn(ctx, func(ctx context.Context) error {
if mode != nil {
ret, err = repo.SavedFilter().FindByMode(*mode)
ret, err = r.repository.SavedFilter.FindByMode(ctx, *mode)
} else {
ret, err = repo.SavedFilter().All()
ret, err = r.repository.SavedFilter.All(ctx)
}
return err
}); err != nil {
@ -37,8 +37,8 @@ func (r *queryResolver) FindSavedFilters(ctx context.Context, mode *models.Filte
}
func (r *queryResolver) FindDefaultFilter(ctx context.Context, mode models.FilterMode) (ret *models.SavedFilter, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.SavedFilter().FindDefault(mode)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.SavedFilter.FindDefault(ctx, mode)
return err
}); err != nil {
return nil, err

View File

@ -12,20 +12,20 @@ import (
func (r *queryResolver) FindScene(ctx context.Context, id *string, checksum *string) (*models.Scene, error) {
var scene *models.Scene
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
qb := repo.Scene()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Scene
var err error
if id != nil {
idInt, err := strconv.Atoi(*id)
if err != nil {
return err
}
scene, err = qb.Find(idInt)
scene, err = qb.Find(ctx, idInt)
if err != nil {
return err
}
} else if checksum != nil {
scene, err = qb.FindByChecksum(*checksum)
scene, err = qb.FindByChecksum(ctx, *checksum)
}
return err
@ -39,18 +39,18 @@ func (r *queryResolver) FindScene(ctx context.Context, id *string, checksum *str
func (r *queryResolver) FindSceneByHash(ctx context.Context, input SceneHashInput) (*models.Scene, error) {
var scene *models.Scene
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
qb := repo.Scene()
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Scene
var err error
if input.Checksum != nil {
scene, err = qb.FindByChecksum(*input.Checksum)
scene, err = qb.FindByChecksum(ctx, *input.Checksum)
if err != nil {
return err
}
}
if scene == nil && input.Oshash != nil {
scene, err = qb.FindByOSHash(*input.Oshash)
scene, err = qb.FindByOSHash(ctx, *input.Oshash)
if err != nil {
return err
}
@ -65,7 +65,7 @@ func (r *queryResolver) FindSceneByHash(ctx context.Context, input SceneHashInpu
}
func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.SceneFilterType, sceneIDs []int, filter *models.FindFilterType) (ret *FindScenesResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
if err := r.withTxn(ctx, func(ctx context.Context) error {
var scenes []*models.Scene
var err error
@ -73,7 +73,7 @@ func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.Scen
result := &models.SceneQueryResult{}
if len(sceneIDs) > 0 {
scenes, err = repo.Scene().FindMany(sceneIDs)
scenes, err = r.repository.Scene.FindMany(ctx, sceneIDs)
if err == nil {
result.Count = len(scenes)
for _, s := range scenes {
@ -83,7 +83,7 @@ func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.Scen
}
}
} else {
result, err = repo.Scene().Query(models.SceneQueryOptions{
result, err = r.repository.Scene.Query(ctx, models.SceneQueryOptions{
QueryOptions: models.QueryOptions{
FindFilter: filter,
Count: stringslice.StrInclude(fields, "count"),
@ -93,7 +93,7 @@ func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.Scen
TotalSize: stringslice.StrInclude(fields, "filesize"),
})
if err == nil {
scenes, err = result.Resolve()
scenes, err = result.Resolve(ctx)
}
}
@ -117,7 +117,7 @@ func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.Scen
}
func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *models.FindFilterType) (ret *FindScenesResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
if err := r.withTxn(ctx, func(ctx context.Context) error {
sceneFilter := &models.SceneFilterType{}
@ -138,7 +138,7 @@ func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *model
fields := graphql.CollectAllFields(ctx)
result, err := repo.Scene().Query(models.SceneQueryOptions{
result, err := r.repository.Scene.Query(ctx, models.SceneQueryOptions{
QueryOptions: models.QueryOptions{
FindFilter: queryFilter,
Count: stringslice.StrInclude(fields, "count"),
@ -151,7 +151,7 @@ func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *model
return err
}
scenes, err := result.Resolve()
scenes, err := result.Resolve(ctx)
if err != nil {
return err
}
@ -174,8 +174,14 @@ func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *model
func (r *queryResolver) ParseSceneFilenames(ctx context.Context, filter *models.FindFilterType, config manager.SceneParserInput) (ret *SceneParserResultType, err error) {
parser := manager.NewSceneFilenameParser(filter, config)
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
result, count, err := parser.Parse(repo)
if err := r.withTxn(ctx, func(ctx context.Context) error {
result, count, err := parser.Parse(ctx, manager.SceneFilenameParserRepository{
Scene: r.repository.Scene,
Performer: r.repository.Performer,
Studio: r.repository.Studio,
Movie: r.repository.Movie,
Tag: r.repository.Tag,
})
if err != nil {
return err
@ -199,8 +205,8 @@ func (r *queryResolver) FindDuplicateScenes(ctx context.Context, distance *int)
if distance != nil {
dist = *distance
}
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Scene().FindDuplicates(dist)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Scene.FindDuplicates(ctx, dist)
return err
}); err != nil {
return nil, err

View File

@ -7,8 +7,8 @@ import (
)
func (r *queryResolver) FindSceneMarkers(ctx context.Context, sceneMarkerFilter *models.SceneMarkerFilterType, filter *models.FindFilterType) (ret *FindSceneMarkersResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
sceneMarkers, total, err := repo.SceneMarker().Query(sceneMarkerFilter, filter)
if err := r.withTxn(ctx, func(ctx context.Context) error {
sceneMarkers, total, err := r.repository.SceneMarker.Query(ctx, sceneMarkerFilter, filter)
if err != nil {
return err
}

View File

@ -13,9 +13,9 @@ func (r *queryResolver) FindStudio(ctx context.Context, id string) (ret *models.
return nil, err
}
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error
ret, err = repo.Studio().Find(idInt)
ret, err = r.repository.Studio.Find(ctx, idInt)
return err
}); err != nil {
return nil, err
@ -25,8 +25,8 @@ func (r *queryResolver) FindStudio(ctx context.Context, id string) (ret *models.
}
func (r *queryResolver) FindStudios(ctx context.Context, studioFilter *models.StudioFilterType, filter *models.FindFilterType) (ret *FindStudiosResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
studios, total, err := repo.Studio().Query(studioFilter, filter)
if err := r.withTxn(ctx, func(ctx context.Context) error {
studios, total, err := r.repository.Studio.Query(ctx, studioFilter, filter)
if err != nil {
return err
}
@ -45,8 +45,8 @@ func (r *queryResolver) FindStudios(ctx context.Context, studioFilter *models.St
}
func (r *queryResolver) AllStudios(ctx context.Context) (ret []*models.Studio, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Studio().All()
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Studio.All(ctx)
return err
}); err != nil {
return nil, err

View File

@ -13,8 +13,8 @@ func (r *queryResolver) FindTag(ctx context.Context, id string) (ret *models.Tag
return nil, err
}
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Tag().Find(idInt)
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Tag.Find(ctx, idInt)
return err
}); err != nil {
return nil, err
@ -24,8 +24,8 @@ func (r *queryResolver) FindTag(ctx context.Context, id string) (ret *models.Tag
}
func (r *queryResolver) FindTags(ctx context.Context, tagFilter *models.TagFilterType, filter *models.FindFilterType) (ret *FindTagsResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
tags, total, err := repo.Tag().Query(tagFilter, filter)
if err := r.withTxn(ctx, func(ctx context.Context) error {
tags, total, err := r.repository.Tag.Query(ctx, tagFilter, filter)
if err != nil {
return err
}
@ -44,8 +44,8 @@ func (r *queryResolver) FindTags(ctx context.Context, tagFilter *models.TagFilte
}
func (r *queryResolver) AllTags(ctx context.Context) (ret []*models.Tag, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Tag().All()
if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.Tag.All(ctx)
return err
}); err != nil {
return nil, err

View File

@ -14,10 +14,10 @@ import (
func (r *queryResolver) SceneStreams(ctx context.Context, id *string) ([]*manager.SceneStreamEndpoint, error) {
// find the scene
var scene *models.Scene
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
if err := r.withTxn(ctx, func(ctx context.Context) error {
idInt, _ := strconv.Atoi(*id)
var err error
scene, err = repo.Scene().Find(idInt)
scene, err = r.repository.Scene.Find(ctx, idInt)
return err
}); err != nil {
return nil, err

View File

@ -234,7 +234,7 @@ func (r *queryResolver) getStashBoxClient(index int) (*stashbox.Client, error) {
return nil, fmt.Errorf("%w: invalid stash_box_index %d", ErrInput, index)
}
return stashbox.NewClient(*boxes[index], r.txnManager), nil
return stashbox.NewClient(*boxes[index], r.txnManager, r.stashboxRepository()), nil
}
func (r *queryResolver) ScrapeSingleScene(ctx context.Context, source scraper.Source, input ScrapeSingleSceneInput) ([]*scraper.ScrapedScene, error) {

View File

@ -13,17 +13,24 @@ import (
"github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
)
type ImageFinder interface {
Find(ctx context.Context, id int) (*models.Image, error)
FindByChecksum(ctx context.Context, checksum string) (*models.Image, error)
}
type imageRoutes struct {
txnManager models.TransactionManager
txnManager txn.Manager
imageFinder ImageFinder
}
func (rs imageRoutes) Routes() chi.Router {
r := chi.NewRouter()
r.Route("/{imageId}", func(r chi.Router) {
r.Use(ImageCtx)
r.Use(rs.ImageCtx)
r.Get("/image", rs.Image)
r.Get("/thumbnail", rs.Thumbnail)
@ -85,18 +92,18 @@ func (rs imageRoutes) Image(w http.ResponseWriter, r *http.Request) {
// endregion
func ImageCtx(next http.Handler) http.Handler {
func (rs imageRoutes) ImageCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
imageIdentifierQueryParam := chi.URLParam(r, "imageId")
imageID, _ := strconv.Atoi(imageIdentifierQueryParam)
var image *models.Image
readTxnErr := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
qb := repo.Image()
readTxnErr := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
qb := rs.imageFinder
if imageID == 0 {
image, _ = qb.FindByChecksum(imageIdentifierQueryParam)
image, _ = qb.FindByChecksum(ctx, imageIdentifierQueryParam)
} else {
image, _ = qb.Find(imageID)
image, _ = qb.Find(ctx, imageID)
}
return nil

View File

@ -6,21 +6,28 @@ import (
"strconv"
"github.com/go-chi/chi"
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
type MovieFinder interface {
GetFrontImage(ctx context.Context, movieID int) ([]byte, error)
GetBackImage(ctx context.Context, movieID int) ([]byte, error)
Find(ctx context.Context, id int) (*models.Movie, error)
}
type movieRoutes struct {
txnManager models.TransactionManager
txnManager txn.Manager
movieFinder MovieFinder
}
func (rs movieRoutes) Routes() chi.Router {
r := chi.NewRouter()
r.Route("/{movieId}", func(r chi.Router) {
r.Use(MovieCtx)
r.Use(rs.MovieCtx)
r.Get("/frontimage", rs.FrontImage)
r.Get("/backimage", rs.BackImage)
})
@ -33,8 +40,8 @@ func (rs movieRoutes) FrontImage(w http.ResponseWriter, r *http.Request) {
defaultParam := r.URL.Query().Get("default")
var image []byte
if defaultParam != "true" {
err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
image, _ = repo.Movie().GetFrontImage(movie.ID)
err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
image, _ = rs.movieFinder.GetFrontImage(ctx, movie.ID)
return nil
})
if err != nil {
@ -56,8 +63,8 @@ func (rs movieRoutes) BackImage(w http.ResponseWriter, r *http.Request) {
defaultParam := r.URL.Query().Get("default")
var image []byte
if defaultParam != "true" {
err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
image, _ = repo.Movie().GetBackImage(movie.ID)
err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
image, _ = rs.movieFinder.GetBackImage(ctx, movie.ID)
return nil
})
if err != nil {
@ -74,7 +81,7 @@ func (rs movieRoutes) BackImage(w http.ResponseWriter, r *http.Request) {
}
}
func MovieCtx(next http.Handler) http.Handler {
func (rs movieRoutes) MovieCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
movieID, err := strconv.Atoi(chi.URLParam(r, "movieId"))
if err != nil {
@ -83,9 +90,9 @@ func MovieCtx(next http.Handler) http.Handler {
}
var movie *models.Movie
if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error
movie, err = repo.Movie().Find(movieID)
movie, err = rs.movieFinder.Find(ctx, movieID)
return err
}); err != nil {
http.Error(w, http.StatusText(404), 404)

View File

@ -6,22 +6,28 @@ import (
"strconv"
"github.com/go-chi/chi"
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
type PerformerFinder interface {
Find(ctx context.Context, id int) (*models.Performer, error)
GetImage(ctx context.Context, performerID int) ([]byte, error)
}
type performerRoutes struct {
txnManager models.TransactionManager
txnManager txn.Manager
performerFinder PerformerFinder
}
func (rs performerRoutes) Routes() chi.Router {
r := chi.NewRouter()
r.Route("/{performerId}", func(r chi.Router) {
r.Use(PerformerCtx)
r.Use(rs.PerformerCtx)
r.Get("/image", rs.Image)
})
@ -34,8 +40,8 @@ func (rs performerRoutes) Image(w http.ResponseWriter, r *http.Request) {
var image []byte
if defaultParam != "true" {
readTxnErr := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
image, _ = repo.Performer().GetImage(performer.ID)
readTxnErr := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
image, _ = rs.performerFinder.GetImage(ctx, performer.ID)
return nil
})
if readTxnErr != nil {
@ -52,7 +58,7 @@ func (rs performerRoutes) Image(w http.ResponseWriter, r *http.Request) {
}
}
func PerformerCtx(next http.Handler) http.Handler {
func (rs performerRoutes) PerformerCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
performerID, err := strconv.Atoi(chi.URLParam(r, "performerId"))
if err != nil {
@ -61,9 +67,9 @@ func PerformerCtx(next http.Handler) http.Handler {
}
var performer *models.Performer
if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error
performer, err = repo.Performer().Find(performerID)
performer, err = rs.performerFinder.Find(ctx, performerID)
return err
}); err != nil {
http.Error(w, http.StatusText(404), 404)

View File

@ -15,18 +15,36 @@ import (
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
type SceneFinder interface {
manager.SceneCoverGetter
scene.IDFinder
FindByChecksum(ctx context.Context, checksum string) (*models.Scene, error)
FindByOSHash(ctx context.Context, oshash string) (*models.Scene, error)
GetCaptions(ctx context.Context, sceneID int) ([]*models.SceneCaption, error)
}
type SceneMarkerFinder interface {
Find(ctx context.Context, id int) (*models.SceneMarker, error)
FindBySceneID(ctx context.Context, sceneID int) ([]*models.SceneMarker, error)
}
type sceneRoutes struct {
txnManager models.TransactionManager
txnManager txn.Manager
sceneFinder SceneFinder
sceneMarkerFinder SceneMarkerFinder
tagFinder scene.MarkerTagFinder
}
func (rs sceneRoutes) Routes() chi.Router {
r := chi.NewRouter()
r.Route("/{sceneId}", func(r chi.Router) {
r.Use(SceneCtx)
r.Use(rs.SceneCtx)
// streaming endpoints
r.Get("/stream", rs.StreamDirect)
@ -48,8 +66,8 @@ func (rs sceneRoutes) Routes() chi.Router {
r.Get("/scene_marker/{sceneMarkerId}/preview", rs.SceneMarkerPreview)
r.Get("/scene_marker/{sceneMarkerId}/screenshot", rs.SceneMarkerScreenshot)
})
r.With(SceneCtx).Get("/{sceneId}_thumbs.vtt", rs.VttThumbs)
r.With(SceneCtx).Get("/{sceneId}_sprite.jpg", rs.VttSprite)
r.With(rs.SceneCtx).Get("/{sceneId}_thumbs.vtt", rs.VttThumbs)
r.With(rs.SceneCtx).Get("/{sceneId}_sprite.jpg", rs.VttSprite)
return r
}
@ -60,7 +78,8 @@ func (rs sceneRoutes) StreamDirect(w http.ResponseWriter, r *http.Request) {
scene := r.Context().Value(sceneKey).(*models.Scene)
ss := manager.SceneServer{
TXNManager: rs.txnManager,
TxnManager: rs.txnManager,
SceneCoverGetter: rs.sceneFinder,
}
ss.StreamSceneDirect(scene, w, r)
}
@ -190,7 +209,8 @@ func (rs sceneRoutes) Screenshot(w http.ResponseWriter, r *http.Request) {
scene := r.Context().Value(sceneKey).(*models.Scene)
ss := manager.SceneServer{
TXNManager: rs.txnManager,
TxnManager: rs.txnManager,
SceneCoverGetter: rs.sceneFinder,
}
ss.ServeScreenshot(scene, w, r)
}
@ -221,16 +241,16 @@ func (rs sceneRoutes) getChapterVttTitle(ctx context.Context, marker *models.Sce
}
var ret string
if err := rs.txnManager.WithReadTxn(ctx, func(repo models.ReaderRepository) error {
qb := repo.Tag()
primaryTag, err := qb.Find(marker.PrimaryTagID)
if err := txn.WithTxn(ctx, rs.txnManager, func(ctx context.Context) error {
qb := rs.tagFinder
primaryTag, err := qb.Find(ctx, marker.PrimaryTagID)
if err != nil {
return err
}
ret = primaryTag.Name
tags, err := qb.FindBySceneMarkerID(marker.ID)
tags, err := qb.FindBySceneMarkerID(ctx, marker.ID)
if err != nil {
return err
}
@ -250,9 +270,9 @@ func (rs sceneRoutes) getChapterVttTitle(ctx context.Context, marker *models.Sce
func (rs sceneRoutes) ChapterVtt(w http.ResponseWriter, r *http.Request) {
scene := r.Context().Value(sceneKey).(*models.Scene)
var sceneMarkers []*models.SceneMarker
if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error
sceneMarkers, err = repo.SceneMarker().FindBySceneID(scene.ID)
sceneMarkers, err = rs.sceneMarkerFinder.FindBySceneID(ctx, scene.ID)
return err
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@ -289,9 +309,9 @@ func (rs sceneRoutes) InteractiveHeatmap(w http.ResponseWriter, r *http.Request)
func (rs sceneRoutes) Caption(w http.ResponseWriter, r *http.Request, lang string, ext string) {
s := r.Context().Value(sceneKey).(*models.Scene)
if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error
captions, err := repo.Scene().GetCaptions(s.ID)
captions, err := rs.sceneFinder.GetCaptions(ctx, s.ID)
for _, caption := range captions {
if lang == caption.LanguageCode && ext == caption.CaptionType {
sub, err := scene.ReadSubs(caption.Path(s.Path))
@ -344,9 +364,9 @@ func (rs sceneRoutes) SceneMarkerStream(w http.ResponseWriter, r *http.Request)
scene := r.Context().Value(sceneKey).(*models.Scene)
sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId"))
var sceneMarker *models.SceneMarker
if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error
sceneMarker, err = repo.SceneMarker().Find(sceneMarkerID)
sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID)
return err
}); err != nil {
logger.Warnf("Error when getting scene marker for stream: %s", err.Error())
@ -367,9 +387,9 @@ func (rs sceneRoutes) SceneMarkerPreview(w http.ResponseWriter, r *http.Request)
scene := r.Context().Value(sceneKey).(*models.Scene)
sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId"))
var sceneMarker *models.SceneMarker
if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error
sceneMarker, err = repo.SceneMarker().Find(sceneMarkerID)
sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID)
return err
}); err != nil {
logger.Warnf("Error when getting scene marker for stream: %s", err.Error())
@ -400,9 +420,9 @@ func (rs sceneRoutes) SceneMarkerScreenshot(w http.ResponseWriter, r *http.Reque
scene := r.Context().Value(sceneKey).(*models.Scene)
sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId"))
var sceneMarker *models.SceneMarker
if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error
sceneMarker, err = repo.SceneMarker().Find(sceneMarkerID)
sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID)
return err
}); err != nil {
logger.Warnf("Error when getting scene marker for stream: %s", err.Error())
@ -431,23 +451,23 @@ func (rs sceneRoutes) SceneMarkerScreenshot(w http.ResponseWriter, r *http.Reque
// endregion
func SceneCtx(next http.Handler) http.Handler {
func (rs sceneRoutes) SceneCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sceneIdentifierQueryParam := chi.URLParam(r, "sceneId")
sceneID, _ := strconv.Atoi(sceneIdentifierQueryParam)
var scene *models.Scene
readTxnErr := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
qb := repo.Scene()
readTxnErr := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
qb := rs.sceneFinder
if sceneID == 0 {
// determine checksum/os by the length of the query param
if len(sceneIdentifierQueryParam) == 32 {
scene, _ = qb.FindByChecksum(sceneIdentifierQueryParam)
scene, _ = qb.FindByChecksum(ctx, sceneIdentifierQueryParam)
} else {
scene, _ = qb.FindByOSHash(sceneIdentifierQueryParam)
scene, _ = qb.FindByOSHash(ctx, sceneIdentifierQueryParam)
}
} else {
scene, _ = qb.Find(sceneID)
scene, _ = qb.Find(ctx, sceneID)
}
return nil

View File

@ -8,21 +8,28 @@ import (
"syscall"
"github.com/go-chi/chi"
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/studio"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
type StudioFinder interface {
studio.Finder
GetImage(ctx context.Context, studioID int) ([]byte, error)
}
type studioRoutes struct {
txnManager models.TransactionManager
txnManager txn.Manager
studioFinder StudioFinder
}
func (rs studioRoutes) Routes() chi.Router {
r := chi.NewRouter()
r.Route("/{studioId}", func(r chi.Router) {
r.Use(StudioCtx)
r.Use(rs.StudioCtx)
r.Get("/image", rs.Image)
})
@ -35,8 +42,8 @@ func (rs studioRoutes) Image(w http.ResponseWriter, r *http.Request) {
var image []byte
if defaultParam != "true" {
err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
image, _ = repo.Studio().GetImage(studio.ID)
err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
image, _ = rs.studioFinder.GetImage(ctx, studio.ID)
return nil
})
if err != nil {
@ -58,7 +65,7 @@ func (rs studioRoutes) Image(w http.ResponseWriter, r *http.Request) {
}
}
func StudioCtx(next http.Handler) http.Handler {
func (rs studioRoutes) StudioCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
studioID, err := strconv.Atoi(chi.URLParam(r, "studioId"))
if err != nil {
@ -67,9 +74,9 @@ func StudioCtx(next http.Handler) http.Handler {
}
var studio *models.Studio
if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error
studio, err = repo.Studio().Find(studioID)
studio, err = rs.studioFinder.Find(ctx, studioID)
return err
}); err != nil {
http.Error(w, http.StatusText(404), 404)

View File

@ -6,21 +6,28 @@ import (
"strconv"
"github.com/go-chi/chi"
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/tag"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
type TagFinder interface {
tag.Finder
GetImage(ctx context.Context, tagID int) ([]byte, error)
}
type tagRoutes struct {
txnManager models.TransactionManager
txnManager txn.Manager
tagFinder TagFinder
}
func (rs tagRoutes) Routes() chi.Router {
r := chi.NewRouter()
r.Route("/{tagId}", func(r chi.Router) {
r.Use(TagCtx)
r.Use(rs.TagCtx)
r.Get("/image", rs.Image)
})
@ -33,8 +40,8 @@ func (rs tagRoutes) Image(w http.ResponseWriter, r *http.Request) {
var image []byte
if defaultParam != "true" {
err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
image, _ = repo.Tag().GetImage(tag.ID)
err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
image, _ = rs.tagFinder.GetImage(ctx, tag.ID)
return nil
})
if err != nil {
@ -51,7 +58,7 @@ func (rs tagRoutes) Image(w http.ResponseWriter, r *http.Request) {
}
}
func TagCtx(next http.Handler) http.Handler {
func (rs tagRoutes) TagCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tagID, err := strconv.Atoi(chi.URLParam(r, "tagId"))
if err != nil {
@ -60,9 +67,9 @@ func TagCtx(next http.Handler) http.Handler {
}
var tag *models.Tag
if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error
tag, err = repo.Tag().Find(tagID)
tag, err = rs.tagFinder.Find(ctx, tagID)
return err
}); err != nil {
http.Error(w, http.StatusText(404), 404)

View File

@ -73,10 +73,11 @@ func Start() error {
return errors.New(message)
}
txnManager := manager.GetInstance().TxnManager
txnManager := manager.GetInstance().Repository
pluginCache := manager.GetInstance().PluginCache
resolver := &Resolver{
txnManager: txnManager,
repository: txnManager,
hookExecutor: pluginCache,
}
@ -118,22 +119,30 @@ func Start() error {
r.Get(loginEndPoint, getLoginHandler(loginUIBox))
r.Mount("/performer", performerRoutes{
txnManager: txnManager,
txnManager: txnManager,
performerFinder: txnManager.Performer,
}.Routes())
r.Mount("/scene", sceneRoutes{
txnManager: txnManager,
txnManager: txnManager,
sceneFinder: txnManager.Scene,
sceneMarkerFinder: txnManager.SceneMarker,
tagFinder: txnManager.Tag,
}.Routes())
r.Mount("/image", imageRoutes{
txnManager: txnManager,
txnManager: txnManager,
imageFinder: txnManager.Image,
}.Routes())
r.Mount("/studio", studioRoutes{
txnManager: txnManager,
txnManager: txnManager,
studioFinder: txnManager.Studio,
}.Routes())
r.Mount("/movie", movieRoutes{
txnManager: txnManager,
txnManager: txnManager,
movieFinder: txnManager.Movie,
}.Routes())
r.Mount("/tag", tagRoutes{
txnManager: txnManager,
tagFinder: txnManager.Tag,
}.Routes())
r.Mount("/downloads", downloadsRoutes{}.Routes())

View File

@ -1,6 +1,8 @@
package autotag
import (
"context"
"github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models"
@ -21,18 +23,18 @@ func getGalleryFileTagger(s *models.Gallery, cache *match.Cache) tagger {
}
// GalleryPerformers tags the provided gallery with performers whose name matches the gallery's path.
func GalleryPerformers(s *models.Gallery, rw models.GalleryReaderWriter, performerReader models.PerformerReader, cache *match.Cache) error {
func GalleryPerformers(ctx context.Context, s *models.Gallery, rw gallery.PerformerUpdater, performerReader match.PerformerAutoTagQueryer, cache *match.Cache) error {
t := getGalleryFileTagger(s, cache)
return t.tagPerformers(performerReader, func(subjectID, otherID int) (bool, error) {
return gallery.AddPerformer(rw, subjectID, otherID)
return t.tagPerformers(ctx, performerReader, func(subjectID, otherID int) (bool, error) {
return gallery.AddPerformer(ctx, rw, subjectID, otherID)
})
}
// GalleryStudios tags the provided gallery with the first studio whose name matches the gallery's path.
//
// Gallerys will not be tagged if studio is already set.
func GalleryStudios(s *models.Gallery, rw models.GalleryReaderWriter, studioReader models.StudioReader, cache *match.Cache) error {
func GalleryStudios(ctx context.Context, s *models.Gallery, rw GalleryFinderUpdater, studioReader match.StudioAutoTagQueryer, cache *match.Cache) error {
if s.StudioID.Valid {
// don't modify
return nil
@ -40,16 +42,16 @@ func GalleryStudios(s *models.Gallery, rw models.GalleryReaderWriter, studioRead
t := getGalleryFileTagger(s, cache)
return t.tagStudios(studioReader, func(subjectID, otherID int) (bool, error) {
return addGalleryStudio(rw, subjectID, otherID)
return t.tagStudios(ctx, studioReader, func(subjectID, otherID int) (bool, error) {
return addGalleryStudio(ctx, rw, subjectID, otherID)
})
}
// GalleryTags tags the provided gallery with tags whose name matches the gallery's path.
func GalleryTags(s *models.Gallery, rw models.GalleryReaderWriter, tagReader models.TagReader, cache *match.Cache) error {
func GalleryTags(ctx context.Context, s *models.Gallery, rw gallery.TagUpdater, tagReader match.TagAutoTagQueryer, cache *match.Cache) error {
t := getGalleryFileTagger(s, cache)
return t.tagTags(tagReader, func(subjectID, otherID int) (bool, error) {
return gallery.AddTag(rw, subjectID, otherID)
return t.tagTags(ctx, tagReader, func(subjectID, otherID int) (bool, error) {
return gallery.AddTag(ctx, rw, subjectID, otherID)
})
}

View File

@ -1,6 +1,7 @@
package autotag
import (
"context"
"testing"
"github.com/stashapp/stash/pkg/models"
@ -11,6 +12,8 @@ import (
const galleryExt = "zip"
var testCtx = context.Background()
func TestGalleryPerformers(t *testing.T) {
t.Parallel()
@ -37,19 +40,19 @@ func TestGalleryPerformers(t *testing.T) {
mockPerformerReader := &mocks.PerformerReaderWriter{}
mockGalleryReader := &mocks.GalleryReaderWriter{}
mockPerformerReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
mockPerformerReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
if test.Matches {
mockGalleryReader.On("GetPerformerIDs", galleryID).Return(nil, nil).Once()
mockGalleryReader.On("UpdatePerformers", galleryID, []int{performerID}).Return(nil).Once()
mockGalleryReader.On("GetPerformerIDs", testCtx, galleryID).Return(nil, nil).Once()
mockGalleryReader.On("UpdatePerformers", testCtx, galleryID, []int{performerID}).Return(nil).Once()
}
gallery := models.Gallery{
ID: galleryID,
Path: models.NullString(test.Path),
}
err := GalleryPerformers(&gallery, mockGalleryReader, mockPerformerReader, nil)
err := GalleryPerformers(testCtx, &gallery, mockGalleryReader, mockPerformerReader, nil)
assert.Nil(err)
mockPerformerReader.AssertExpectations(t)
@ -81,9 +84,9 @@ func TestGalleryStudios(t *testing.T) {
doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockGalleryReader *mocks.GalleryReaderWriter, test pathTestTable) {
if test.Matches {
mockGalleryReader.On("Find", galleryID).Return(&models.Gallery{}, nil).Once()
mockGalleryReader.On("Find", testCtx, galleryID).Return(&models.Gallery{}, nil).Once()
expectedStudioID := models.NullInt64(studioID)
mockGalleryReader.On("UpdatePartial", models.GalleryPartial{
mockGalleryReader.On("UpdatePartial", testCtx, models.GalleryPartial{
ID: galleryID,
StudioID: &expectedStudioID,
}).Return(nil, nil).Once()
@ -93,7 +96,7 @@ func TestGalleryStudios(t *testing.T) {
ID: galleryID,
Path: models.NullString(test.Path),
}
err := GalleryStudios(&gallery, mockGalleryReader, mockStudioReader, nil)
err := GalleryStudios(testCtx, &gallery, mockGalleryReader, mockStudioReader, nil)
assert.Nil(err)
mockStudioReader.AssertExpectations(t)
@ -104,9 +107,9 @@ func TestGalleryStudios(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{}
mockGalleryReader := &mocks.GalleryReaderWriter{}
mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe()
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockStudioReader, mockGalleryReader, test)
}
@ -119,12 +122,12 @@ func TestGalleryStudios(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{}
mockGalleryReader := &mocks.GalleryReaderWriter{}
mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", studioID).Return([]string{
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{
studioName,
}, nil).Once()
mockStudioReader.On("GetAliases", reversedStudioID).Return([]string{}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
doTest(mockStudioReader, mockGalleryReader, test)
}
@ -154,15 +157,15 @@ func TestGalleryTags(t *testing.T) {
doTest := func(mockTagReader *mocks.TagReaderWriter, mockGalleryReader *mocks.GalleryReaderWriter, test pathTestTable) {
if test.Matches {
mockGalleryReader.On("GetTagIDs", galleryID).Return(nil, nil).Once()
mockGalleryReader.On("UpdateTags", galleryID, []int{tagID}).Return(nil).Once()
mockGalleryReader.On("GetTagIDs", testCtx, galleryID).Return(nil, nil).Once()
mockGalleryReader.On("UpdateTags", testCtx, galleryID, []int{tagID}).Return(nil).Once()
}
gallery := models.Gallery{
ID: galleryID,
Path: models.NullString(test.Path),
}
err := GalleryTags(&gallery, mockGalleryReader, mockTagReader, nil)
err := GalleryTags(testCtx, &gallery, mockGalleryReader, mockTagReader, nil)
assert.Nil(err)
mockTagReader.AssertExpectations(t)
@ -173,9 +176,9 @@ func TestGalleryTags(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{}
mockGalleryReader := &mocks.GalleryReaderWriter{}
mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe()
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockTagReader, mockGalleryReader, test)
}
@ -187,12 +190,12 @@ func TestGalleryTags(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{}
mockGalleryReader := &mocks.GalleryReaderWriter{}
mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", tagID).Return([]string{
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{
tagName,
}, nil).Once()
mockTagReader.On("GetAliases", reversedTagID).Return([]string{}, nil).Once()
mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
doTest(mockTagReader, mockGalleryReader, test)
}

View File

@ -1,6 +1,8 @@
package autotag
import (
"context"
"github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models"
@ -17,18 +19,18 @@ func getImageFileTagger(s *models.Image, cache *match.Cache) tagger {
}
// ImagePerformers tags the provided image with performers whose name matches the image's path.
func ImagePerformers(s *models.Image, rw models.ImageReaderWriter, performerReader models.PerformerReader, cache *match.Cache) error {
func ImagePerformers(ctx context.Context, s *models.Image, rw image.PerformerUpdater, performerReader match.PerformerAutoTagQueryer, cache *match.Cache) error {
t := getImageFileTagger(s, cache)
return t.tagPerformers(performerReader, func(subjectID, otherID int) (bool, error) {
return image.AddPerformer(rw, subjectID, otherID)
return t.tagPerformers(ctx, performerReader, func(subjectID, otherID int) (bool, error) {
return image.AddPerformer(ctx, rw, subjectID, otherID)
})
}
// ImageStudios tags the provided image with the first studio whose name matches the image's path.
//
// Images will not be tagged if studio is already set.
func ImageStudios(s *models.Image, rw models.ImageReaderWriter, studioReader models.StudioReader, cache *match.Cache) error {
func ImageStudios(ctx context.Context, s *models.Image, rw ImageFinderUpdater, studioReader match.StudioAutoTagQueryer, cache *match.Cache) error {
if s.StudioID.Valid {
// don't modify
return nil
@ -36,16 +38,16 @@ func ImageStudios(s *models.Image, rw models.ImageReaderWriter, studioReader mod
t := getImageFileTagger(s, cache)
return t.tagStudios(studioReader, func(subjectID, otherID int) (bool, error) {
return addImageStudio(rw, subjectID, otherID)
return t.tagStudios(ctx, studioReader, func(subjectID, otherID int) (bool, error) {
return addImageStudio(ctx, rw, subjectID, otherID)
})
}
// ImageTags tags the provided image with tags whose name matches the image's path.
func ImageTags(s *models.Image, rw models.ImageReaderWriter, tagReader models.TagReader, cache *match.Cache) error {
func ImageTags(ctx context.Context, s *models.Image, rw image.TagUpdater, tagReader match.TagAutoTagQueryer, cache *match.Cache) error {
t := getImageFileTagger(s, cache)
return t.tagTags(tagReader, func(subjectID, otherID int) (bool, error) {
return image.AddTag(rw, subjectID, otherID)
return t.tagTags(ctx, tagReader, func(subjectID, otherID int) (bool, error) {
return image.AddTag(ctx, rw, subjectID, otherID)
})
}

View File

@ -37,19 +37,19 @@ func TestImagePerformers(t *testing.T) {
mockPerformerReader := &mocks.PerformerReaderWriter{}
mockImageReader := &mocks.ImageReaderWriter{}
mockPerformerReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
mockPerformerReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
if test.Matches {
mockImageReader.On("GetPerformerIDs", imageID).Return(nil, nil).Once()
mockImageReader.On("UpdatePerformers", imageID, []int{performerID}).Return(nil).Once()
mockImageReader.On("GetPerformerIDs", testCtx, imageID).Return(nil, nil).Once()
mockImageReader.On("UpdatePerformers", testCtx, imageID, []int{performerID}).Return(nil).Once()
}
image := models.Image{
ID: imageID,
Path: test.Path,
}
err := ImagePerformers(&image, mockImageReader, mockPerformerReader, nil)
err := ImagePerformers(testCtx, &image, mockImageReader, mockPerformerReader, nil)
assert.Nil(err)
mockPerformerReader.AssertExpectations(t)
@ -81,9 +81,9 @@ func TestImageStudios(t *testing.T) {
doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockImageReader *mocks.ImageReaderWriter, test pathTestTable) {
if test.Matches {
mockImageReader.On("Find", imageID).Return(&models.Image{}, nil).Once()
mockImageReader.On("Find", testCtx, imageID).Return(&models.Image{}, nil).Once()
expectedStudioID := models.NullInt64(studioID)
mockImageReader.On("Update", models.ImagePartial{
mockImageReader.On("Update", testCtx, models.ImagePartial{
ID: imageID,
StudioID: &expectedStudioID,
}).Return(nil, nil).Once()
@ -93,7 +93,7 @@ func TestImageStudios(t *testing.T) {
ID: imageID,
Path: test.Path,
}
err := ImageStudios(&image, mockImageReader, mockStudioReader, nil)
err := ImageStudios(testCtx, &image, mockImageReader, mockStudioReader, nil)
assert.Nil(err)
mockStudioReader.AssertExpectations(t)
@ -104,9 +104,9 @@ func TestImageStudios(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{}
mockImageReader := &mocks.ImageReaderWriter{}
mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe()
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockStudioReader, mockImageReader, test)
}
@ -119,12 +119,12 @@ func TestImageStudios(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{}
mockImageReader := &mocks.ImageReaderWriter{}
mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", studioID).Return([]string{
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{
studioName,
}, nil).Once()
mockStudioReader.On("GetAliases", reversedStudioID).Return([]string{}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
doTest(mockStudioReader, mockImageReader, test)
}
@ -154,15 +154,15 @@ func TestImageTags(t *testing.T) {
doTest := func(mockTagReader *mocks.TagReaderWriter, mockImageReader *mocks.ImageReaderWriter, test pathTestTable) {
if test.Matches {
mockImageReader.On("GetTagIDs", imageID).Return(nil, nil).Once()
mockImageReader.On("UpdateTags", imageID, []int{tagID}).Return(nil).Once()
mockImageReader.On("GetTagIDs", testCtx, imageID).Return(nil, nil).Once()
mockImageReader.On("UpdateTags", testCtx, imageID, []int{tagID}).Return(nil).Once()
}
image := models.Image{
ID: imageID,
Path: test.Path,
}
err := ImageTags(&image, mockImageReader, mockTagReader, nil)
err := ImageTags(testCtx, &image, mockImageReader, mockTagReader, nil)
assert.Nil(err)
mockTagReader.AssertExpectations(t)
@ -173,9 +173,9 @@ func TestImageTags(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{}
mockImageReader := &mocks.ImageReaderWriter{}
mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe()
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockTagReader, mockImageReader, test)
}
@ -188,12 +188,12 @@ func TestImageTags(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{}
mockImageReader := &mocks.ImageReaderWriter{}
mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", tagID).Return([]string{
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{
tagName,
}, nil).Once()
mockTagReader.On("GetAliases", reversedTagID).Return([]string{}, nil).Once()
mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
doTest(mockTagReader, mockImageReader, test)
}

View File

@ -10,10 +10,10 @@ import (
"os"
"testing"
"github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/hash/md5"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/sqlite"
"github.com/stashapp/stash/pkg/txn"
_ "github.com/golang-migrate/migrate/v4/database/sqlite3"
_ "github.com/golang-migrate/migrate/v4/source/file"
@ -28,8 +28,11 @@ const existingStudioGalleryName = testName + ".dontChangeStudio.mp4"
var existingStudioID int
var db *sqlite.Database
var r models.Repository
func testTeardown(databaseFile string) {
err := database.DB.Close()
err := db.Close()
if err != nil {
panic(err)
@ -50,10 +53,13 @@ func runTests(m *testing.M) int {
f.Close()
databaseFile := f.Name()
if err := database.Initialize(databaseFile); err != nil {
db = &sqlite.Database{}
if err := db.Open(databaseFile); err != nil {
panic(fmt.Sprintf("Could not initialize database: %s", err.Error()))
}
r = db.TxnRepository()
// defer close and delete the database
defer testTeardown(databaseFile)
@ -71,7 +77,7 @@ func TestMain(m *testing.M) {
os.Exit(ret)
}
func createPerformer(pqb models.PerformerWriter) error {
func createPerformer(ctx context.Context, pqb models.PerformerWriter) error {
// create the performer
performer := models.Performer{
Checksum: testName,
@ -79,7 +85,7 @@ func createPerformer(pqb models.PerformerWriter) error {
Favorite: sql.NullBool{Valid: true, Bool: false},
}
_, err := pqb.Create(performer)
_, err := pqb.Create(ctx, performer)
if err != nil {
return err
}
@ -87,23 +93,23 @@ func createPerformer(pqb models.PerformerWriter) error {
return nil
}
func createStudio(qb models.StudioWriter, name string) (*models.Studio, error) {
func createStudio(ctx context.Context, qb models.StudioWriter, name string) (*models.Studio, error) {
// create the studio
studio := models.Studio{
Checksum: name,
Name: sql.NullString{Valid: true, String: name},
}
return qb.Create(studio)
return qb.Create(ctx, studio)
}
func createTag(qb models.TagWriter) error {
func createTag(ctx context.Context, qb models.TagWriter) error {
// create the studio
tag := models.Tag{
Name: testName,
}
_, err := qb.Create(tag)
_, err := qb.Create(ctx, tag)
if err != nil {
return err
}
@ -111,18 +117,18 @@ func createTag(qb models.TagWriter) error {
return nil
}
func createScenes(sqb models.SceneReaderWriter) error {
func createScenes(ctx context.Context, sqb models.SceneReaderWriter) error {
// create the scenes
scenePatterns, falseScenePatterns := generateTestPaths(testName, sceneExt)
for _, fn := range scenePatterns {
err := createScene(sqb, makeScene(fn, true))
err := createScene(ctx, sqb, makeScene(fn, true))
if err != nil {
return err
}
}
for _, fn := range falseScenePatterns {
err := createScene(sqb, makeScene(fn, false))
err := createScene(ctx, sqb, makeScene(fn, false))
if err != nil {
return err
}
@ -132,7 +138,7 @@ func createScenes(sqb models.SceneReaderWriter) error {
for _, fn := range scenePatterns {
s := makeScene("organized"+fn, false)
s.Organized = true
err := createScene(sqb, s)
err := createScene(ctx, sqb, s)
if err != nil {
return err
}
@ -141,7 +147,7 @@ func createScenes(sqb models.SceneReaderWriter) error {
// create scene with existing studio io
studioScene := makeScene(existingStudioSceneName, true)
studioScene.StudioID = sql.NullInt64{Valid: true, Int64: int64(existingStudioID)}
err := createScene(sqb, studioScene)
err := createScene(ctx, sqb, studioScene)
if err != nil {
return err
}
@ -163,8 +169,8 @@ func makeScene(name string, expectedResult bool) *models.Scene {
return scene
}
func createScene(sqb models.SceneWriter, scene *models.Scene) error {
_, err := sqb.Create(*scene)
func createScene(ctx context.Context, sqb models.SceneWriter, scene *models.Scene) error {
_, err := sqb.Create(ctx, *scene)
if err != nil {
return fmt.Errorf("Failed to create scene with name '%s': %s", scene.Path, err.Error())
@ -173,18 +179,18 @@ func createScene(sqb models.SceneWriter, scene *models.Scene) error {
return nil
}
func createImages(sqb models.ImageReaderWriter) error {
func createImages(ctx context.Context, sqb models.ImageReaderWriter) error {
// create the images
imagePatterns, falseImagePatterns := generateTestPaths(testName, imageExt)
for _, fn := range imagePatterns {
err := createImage(sqb, makeImage(fn, true))
err := createImage(ctx, sqb, makeImage(fn, true))
if err != nil {
return err
}
}
for _, fn := range falseImagePatterns {
err := createImage(sqb, makeImage(fn, false))
err := createImage(ctx, sqb, makeImage(fn, false))
if err != nil {
return err
}
@ -194,7 +200,7 @@ func createImages(sqb models.ImageReaderWriter) error {
for _, fn := range imagePatterns {
s := makeImage("organized"+fn, false)
s.Organized = true
err := createImage(sqb, s)
err := createImage(ctx, sqb, s)
if err != nil {
return err
}
@ -203,7 +209,7 @@ func createImages(sqb models.ImageReaderWriter) error {
// create image with existing studio io
studioImage := makeImage(existingStudioImageName, true)
studioImage.StudioID = sql.NullInt64{Valid: true, Int64: int64(existingStudioID)}
err := createImage(sqb, studioImage)
err := createImage(ctx, sqb, studioImage)
if err != nil {
return err
}
@ -225,8 +231,8 @@ func makeImage(name string, expectedResult bool) *models.Image {
return image
}
func createImage(sqb models.ImageWriter, image *models.Image) error {
_, err := sqb.Create(*image)
func createImage(ctx context.Context, sqb models.ImageWriter, image *models.Image) error {
_, err := sqb.Create(ctx, *image)
if err != nil {
return fmt.Errorf("Failed to create image with name '%s': %s", image.Path, err.Error())
@ -235,18 +241,18 @@ func createImage(sqb models.ImageWriter, image *models.Image) error {
return nil
}
func createGalleries(sqb models.GalleryReaderWriter) error {
func createGalleries(ctx context.Context, sqb models.GalleryReaderWriter) error {
// create the galleries
galleryPatterns, falseGalleryPatterns := generateTestPaths(testName, galleryExt)
for _, fn := range galleryPatterns {
err := createGallery(sqb, makeGallery(fn, true))
err := createGallery(ctx, sqb, makeGallery(fn, true))
if err != nil {
return err
}
}
for _, fn := range falseGalleryPatterns {
err := createGallery(sqb, makeGallery(fn, false))
err := createGallery(ctx, sqb, makeGallery(fn, false))
if err != nil {
return err
}
@ -256,7 +262,7 @@ func createGalleries(sqb models.GalleryReaderWriter) error {
for _, fn := range galleryPatterns {
s := makeGallery("organized"+fn, false)
s.Organized = true
err := createGallery(sqb, s)
err := createGallery(ctx, sqb, s)
if err != nil {
return err
}
@ -265,7 +271,7 @@ func createGalleries(sqb models.GalleryReaderWriter) error {
// create gallery with existing studio io
studioGallery := makeGallery(existingStudioGalleryName, true)
studioGallery.StudioID = sql.NullInt64{Valid: true, Int64: int64(existingStudioID)}
err := createGallery(sqb, studioGallery)
err := createGallery(ctx, sqb, studioGallery)
if err != nil {
return err
}
@ -287,8 +293,8 @@ func makeGallery(name string, expectedResult bool) *models.Gallery {
return gallery
}
func createGallery(sqb models.GalleryWriter, gallery *models.Gallery) error {
_, err := sqb.Create(*gallery)
func createGallery(ctx context.Context, sqb models.GalleryWriter, gallery *models.Gallery) error {
_, err := sqb.Create(ctx, *gallery)
if err != nil {
return fmt.Errorf("Failed to create gallery with name '%s': %s", gallery.Path.String, err.Error())
@ -297,47 +303,46 @@ func createGallery(sqb models.GalleryWriter, gallery *models.Gallery) error {
return nil
}
func withTxn(f func(r models.Repository) error) error {
t := sqlite.NewTransactionManager()
return t.WithTxn(context.TODO(), f)
func withTxn(f func(ctx context.Context) error) error {
return txn.WithTxn(context.TODO(), db, f)
}
func populateDB() error {
if err := withTxn(func(r models.Repository) error {
err := createPerformer(r.Performer())
if err := withTxn(func(ctx context.Context) error {
err := createPerformer(ctx, r.Performer)
if err != nil {
return err
}
_, err = createStudio(r.Studio(), testName)
_, err = createStudio(ctx, r.Studio, testName)
if err != nil {
return err
}
// create existing studio
existingStudio, err := createStudio(r.Studio(), existingStudioName)
existingStudio, err := createStudio(ctx, r.Studio, existingStudioName)
if err != nil {
return err
}
existingStudioID = existingStudio.ID
err = createTag(r.Tag())
err = createTag(ctx, r.Tag)
if err != nil {
return err
}
err = createScenes(r.Scene())
err = createScenes(ctx, r.Scene)
if err != nil {
return err
}
err = createImages(r.Image())
err = createImages(ctx, r.Image)
if err != nil {
return err
}
err = createGalleries(r.Gallery())
err = createGalleries(ctx, r.Gallery)
if err != nil {
return err
}
@ -352,9 +357,9 @@ func populateDB() error {
func TestParsePerformerScenes(t *testing.T) {
var performers []*models.Performer
if err := withTxn(func(r models.Repository) error {
if err := withTxn(func(ctx context.Context) error {
var err error
performers, err = r.Performer().All()
performers, err = r.Performer.All(ctx)
return err
}); err != nil {
t.Errorf("Error getting performer: %s", err)
@ -362,24 +367,24 @@ func TestParsePerformerScenes(t *testing.T) {
}
for _, p := range performers {
if err := withTxn(func(r models.Repository) error {
return PerformerScenes(p, nil, r.Scene(), nil)
if err := withTxn(func(ctx context.Context) error {
return PerformerScenes(ctx, p, nil, r.Scene, nil)
}); err != nil {
t.Errorf("Error auto-tagging performers: %s", err)
}
}
// verify that scenes were tagged correctly
withTxn(func(r models.Repository) error {
pqb := r.Performer()
withTxn(func(ctx context.Context) error {
pqb := r.Performer
scenes, err := r.Scene().All()
scenes, err := r.Scene.All(ctx)
if err != nil {
t.Error(err.Error())
}
for _, scene := range scenes {
performers, err := pqb.FindBySceneID(scene.ID)
performers, err := pqb.FindBySceneID(ctx, scene.ID)
if err != nil {
t.Errorf("Error getting scene performers: %s", err.Error())
@ -399,9 +404,9 @@ func TestParsePerformerScenes(t *testing.T) {
func TestParseStudioScenes(t *testing.T) {
var studios []*models.Studio
if err := withTxn(func(r models.Repository) error {
if err := withTxn(func(ctx context.Context) error {
var err error
studios, err = r.Studio().All()
studios, err = r.Studio.All(ctx)
return err
}); err != nil {
t.Errorf("Error getting studio: %s", err)
@ -409,21 +414,21 @@ func TestParseStudioScenes(t *testing.T) {
}
for _, s := range studios {
if err := withTxn(func(r models.Repository) error {
aliases, err := r.Studio().GetAliases(s.ID)
if err := withTxn(func(ctx context.Context) error {
aliases, err := r.Studio.GetAliases(ctx, s.ID)
if err != nil {
return err
}
return StudioScenes(s, nil, aliases, r.Scene(), nil)
return StudioScenes(ctx, s, nil, aliases, r.Scene, nil)
}); err != nil {
t.Errorf("Error auto-tagging performers: %s", err)
}
}
// verify that scenes were tagged correctly
withTxn(func(r models.Repository) error {
scenes, err := r.Scene().All()
withTxn(func(ctx context.Context) error {
scenes, err := r.Scene.All(ctx)
if err != nil {
t.Error(err.Error())
}
@ -455,9 +460,9 @@ func TestParseStudioScenes(t *testing.T) {
func TestParseTagScenes(t *testing.T) {
var tags []*models.Tag
if err := withTxn(func(r models.Repository) error {
if err := withTxn(func(ctx context.Context) error {
var err error
tags, err = r.Tag().All()
tags, err = r.Tag.All(ctx)
return err
}); err != nil {
t.Errorf("Error getting performer: %s", err)
@ -465,29 +470,29 @@ func TestParseTagScenes(t *testing.T) {
}
for _, s := range tags {
if err := withTxn(func(r models.Repository) error {
aliases, err := r.Tag().GetAliases(s.ID)
if err := withTxn(func(ctx context.Context) error {
aliases, err := r.Tag.GetAliases(ctx, s.ID)
if err != nil {
return err
}
return TagScenes(s, nil, aliases, r.Scene(), nil)
return TagScenes(ctx, s, nil, aliases, r.Scene, nil)
}); err != nil {
t.Errorf("Error auto-tagging performers: %s", err)
}
}
// verify that scenes were tagged correctly
withTxn(func(r models.Repository) error {
scenes, err := r.Scene().All()
withTxn(func(ctx context.Context) error {
scenes, err := r.Scene.All(ctx)
if err != nil {
t.Error(err.Error())
}
tqb := r.Tag()
tqb := r.Tag
for _, scene := range scenes {
tags, err := tqb.FindBySceneID(scene.ID)
tags, err := tqb.FindBySceneID(ctx, scene.ID)
if err != nil {
t.Errorf("Error getting scene tags: %s", err.Error())
@ -507,9 +512,9 @@ func TestParseTagScenes(t *testing.T) {
func TestParsePerformerImages(t *testing.T) {
var performers []*models.Performer
if err := withTxn(func(r models.Repository) error {
if err := withTxn(func(ctx context.Context) error {
var err error
performers, err = r.Performer().All()
performers, err = r.Performer.All(ctx)
return err
}); err != nil {
t.Errorf("Error getting performer: %s", err)
@ -517,24 +522,24 @@ func TestParsePerformerImages(t *testing.T) {
}
for _, p := range performers {
if err := withTxn(func(r models.Repository) error {
return PerformerImages(p, nil, r.Image(), nil)
if err := withTxn(func(ctx context.Context) error {
return PerformerImages(ctx, p, nil, r.Image, nil)
}); err != nil {
t.Errorf("Error auto-tagging performers: %s", err)
}
}
// verify that images were tagged correctly
withTxn(func(r models.Repository) error {
pqb := r.Performer()
withTxn(func(ctx context.Context) error {
pqb := r.Performer
images, err := r.Image().All()
images, err := r.Image.All(ctx)
if err != nil {
t.Error(err.Error())
}
for _, image := range images {
performers, err := pqb.FindByImageID(image.ID)
performers, err := pqb.FindByImageID(ctx, image.ID)
if err != nil {
t.Errorf("Error getting image performers: %s", err.Error())
@ -554,9 +559,9 @@ func TestParsePerformerImages(t *testing.T) {
func TestParseStudioImages(t *testing.T) {
var studios []*models.Studio
if err := withTxn(func(r models.Repository) error {
if err := withTxn(func(ctx context.Context) error {
var err error
studios, err = r.Studio().All()
studios, err = r.Studio.All(ctx)
return err
}); err != nil {
t.Errorf("Error getting studio: %s", err)
@ -564,21 +569,21 @@ func TestParseStudioImages(t *testing.T) {
}
for _, s := range studios {
if err := withTxn(func(r models.Repository) error {
aliases, err := r.Studio().GetAliases(s.ID)
if err := withTxn(func(ctx context.Context) error {
aliases, err := r.Studio.GetAliases(ctx, s.ID)
if err != nil {
return err
}
return StudioImages(s, nil, aliases, r.Image(), nil)
return StudioImages(ctx, s, nil, aliases, r.Image, nil)
}); err != nil {
t.Errorf("Error auto-tagging performers: %s", err)
}
}
// verify that images were tagged correctly
withTxn(func(r models.Repository) error {
images, err := r.Image().All()
withTxn(func(ctx context.Context) error {
images, err := r.Image.All(ctx)
if err != nil {
t.Error(err.Error())
}
@ -610,9 +615,9 @@ func TestParseStudioImages(t *testing.T) {
func TestParseTagImages(t *testing.T) {
var tags []*models.Tag
if err := withTxn(func(r models.Repository) error {
if err := withTxn(func(ctx context.Context) error {
var err error
tags, err = r.Tag().All()
tags, err = r.Tag.All(ctx)
return err
}); err != nil {
t.Errorf("Error getting performer: %s", err)
@ -620,29 +625,29 @@ func TestParseTagImages(t *testing.T) {
}
for _, s := range tags {
if err := withTxn(func(r models.Repository) error {
aliases, err := r.Tag().GetAliases(s.ID)
if err := withTxn(func(ctx context.Context) error {
aliases, err := r.Tag.GetAliases(ctx, s.ID)
if err != nil {
return err
}
return TagImages(s, nil, aliases, r.Image(), nil)
return TagImages(ctx, s, nil, aliases, r.Image, nil)
}); err != nil {
t.Errorf("Error auto-tagging performers: %s", err)
}
}
// verify that images were tagged correctly
withTxn(func(r models.Repository) error {
images, err := r.Image().All()
withTxn(func(ctx context.Context) error {
images, err := r.Image.All(ctx)
if err != nil {
t.Error(err.Error())
}
tqb := r.Tag()
tqb := r.Tag
for _, image := range images {
tags, err := tqb.FindByImageID(image.ID)
tags, err := tqb.FindByImageID(ctx, image.ID)
if err != nil {
t.Errorf("Error getting image tags: %s", err.Error())
@ -662,9 +667,9 @@ func TestParseTagImages(t *testing.T) {
func TestParsePerformerGalleries(t *testing.T) {
var performers []*models.Performer
if err := withTxn(func(r models.Repository) error {
if err := withTxn(func(ctx context.Context) error {
var err error
performers, err = r.Performer().All()
performers, err = r.Performer.All(ctx)
return err
}); err != nil {
t.Errorf("Error getting performer: %s", err)
@ -672,24 +677,24 @@ func TestParsePerformerGalleries(t *testing.T) {
}
for _, p := range performers {
if err := withTxn(func(r models.Repository) error {
return PerformerGalleries(p, nil, r.Gallery(), nil)
if err := withTxn(func(ctx context.Context) error {
return PerformerGalleries(ctx, p, nil, r.Gallery, nil)
}); err != nil {
t.Errorf("Error auto-tagging performers: %s", err)
}
}
// verify that galleries were tagged correctly
withTxn(func(r models.Repository) error {
pqb := r.Performer()
withTxn(func(ctx context.Context) error {
pqb := r.Performer
galleries, err := r.Gallery().All()
galleries, err := r.Gallery.All(ctx)
if err != nil {
t.Error(err.Error())
}
for _, gallery := range galleries {
performers, err := pqb.FindByGalleryID(gallery.ID)
performers, err := pqb.FindByGalleryID(ctx, gallery.ID)
if err != nil {
t.Errorf("Error getting gallery performers: %s", err.Error())
@ -709,9 +714,9 @@ func TestParsePerformerGalleries(t *testing.T) {
func TestParseStudioGalleries(t *testing.T) {
var studios []*models.Studio
if err := withTxn(func(r models.Repository) error {
if err := withTxn(func(ctx context.Context) error {
var err error
studios, err = r.Studio().All()
studios, err = r.Studio.All(ctx)
return err
}); err != nil {
t.Errorf("Error getting studio: %s", err)
@ -719,21 +724,21 @@ func TestParseStudioGalleries(t *testing.T) {
}
for _, s := range studios {
if err := withTxn(func(r models.Repository) error {
aliases, err := r.Studio().GetAliases(s.ID)
if err := withTxn(func(ctx context.Context) error {
aliases, err := r.Studio.GetAliases(ctx, s.ID)
if err != nil {
return err
}
return StudioGalleries(s, nil, aliases, r.Gallery(), nil)
return StudioGalleries(ctx, s, nil, aliases, r.Gallery, nil)
}); err != nil {
t.Errorf("Error auto-tagging performers: %s", err)
}
}
// verify that galleries were tagged correctly
withTxn(func(r models.Repository) error {
galleries, err := r.Gallery().All()
withTxn(func(ctx context.Context) error {
galleries, err := r.Gallery.All(ctx)
if err != nil {
t.Error(err.Error())
}
@ -765,9 +770,9 @@ func TestParseStudioGalleries(t *testing.T) {
func TestParseTagGalleries(t *testing.T) {
var tags []*models.Tag
if err := withTxn(func(r models.Repository) error {
if err := withTxn(func(ctx context.Context) error {
var err error
tags, err = r.Tag().All()
tags, err = r.Tag.All(ctx)
return err
}); err != nil {
t.Errorf("Error getting performer: %s", err)
@ -775,29 +780,29 @@ func TestParseTagGalleries(t *testing.T) {
}
for _, s := range tags {
if err := withTxn(func(r models.Repository) error {
aliases, err := r.Tag().GetAliases(s.ID)
if err := withTxn(func(ctx context.Context) error {
aliases, err := r.Tag.GetAliases(ctx, s.ID)
if err != nil {
return err
}
return TagGalleries(s, nil, aliases, r.Gallery(), nil)
return TagGalleries(ctx, s, nil, aliases, r.Gallery, nil)
}); err != nil {
t.Errorf("Error auto-tagging performers: %s", err)
}
}
// verify that galleries were tagged correctly
withTxn(func(r models.Repository) error {
galleries, err := r.Gallery().All()
withTxn(func(ctx context.Context) error {
galleries, err := r.Gallery.All(ctx)
if err != nil {
t.Error(err.Error())
}
tqb := r.Tag()
tqb := r.Tag
for _, gallery := range galleries {
tags, err := tqb.FindByGalleryID(gallery.ID)
tags, err := tqb.FindByGalleryID(ctx, gallery.ID)
if err != nil {
t.Errorf("Error getting gallery tags: %s", err.Error())

View File

@ -1,6 +1,8 @@
package autotag
import (
"context"
"github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/match"
@ -8,6 +10,21 @@ import (
"github.com/stashapp/stash/pkg/scene"
)
type SceneQueryPerformerUpdater interface {
scene.Queryer
scene.PerformerUpdater
}
type ImageQueryPerformerUpdater interface {
image.Queryer
image.PerformerUpdater
}
type GalleryQueryPerformerUpdater interface {
gallery.Queryer
gallery.PerformerUpdater
}
func getPerformerTagger(p *models.Performer, cache *match.Cache) tagger {
return tagger{
ID: p.ID,
@ -18,28 +35,28 @@ func getPerformerTagger(p *models.Performer, cache *match.Cache) tagger {
}
// PerformerScenes searches for scenes whose path matches the provided performer name and tags the scene with the performer.
func PerformerScenes(p *models.Performer, paths []string, rw models.SceneReaderWriter, cache *match.Cache) error {
func PerformerScenes(ctx context.Context, p *models.Performer, paths []string, rw SceneQueryPerformerUpdater, cache *match.Cache) error {
t := getPerformerTagger(p, cache)
return t.tagScenes(paths, rw, func(subjectID, otherID int) (bool, error) {
return scene.AddPerformer(rw, otherID, subjectID)
return t.tagScenes(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
return scene.AddPerformer(ctx, rw, otherID, subjectID)
})
}
// PerformerImages searches for images whose path matches the provided performer name and tags the image with the performer.
func PerformerImages(p *models.Performer, paths []string, rw models.ImageReaderWriter, cache *match.Cache) error {
func PerformerImages(ctx context.Context, p *models.Performer, paths []string, rw ImageQueryPerformerUpdater, cache *match.Cache) error {
t := getPerformerTagger(p, cache)
return t.tagImages(paths, rw, func(subjectID, otherID int) (bool, error) {
return image.AddPerformer(rw, otherID, subjectID)
return t.tagImages(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
return image.AddPerformer(ctx, rw, otherID, subjectID)
})
}
// PerformerGalleries searches for galleries whose path matches the provided performer name and tags the gallery with the performer.
func PerformerGalleries(p *models.Performer, paths []string, rw models.GalleryReaderWriter, cache *match.Cache) error {
func PerformerGalleries(ctx context.Context, p *models.Performer, paths []string, rw GalleryQueryPerformerUpdater, cache *match.Cache) error {
t := getPerformerTagger(p, cache)
return t.tagGalleries(paths, rw, func(subjectID, otherID int) (bool, error) {
return gallery.AddPerformer(rw, otherID, subjectID)
return t.tagGalleries(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
return gallery.AddPerformer(ctx, rw, otherID, subjectID)
})
}

View File

@ -72,16 +72,16 @@ func testPerformerScenes(t *testing.T, performerName, expectedRegex string) {
PerPage: &perPage,
}
mockSceneReader.On("Query", scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)).
mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)).
Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
for i := range matchingPaths {
sceneID := i + 1
mockSceneReader.On("GetPerformerIDs", sceneID).Return(nil, nil).Once()
mockSceneReader.On("UpdatePerformers", sceneID, []int{performerID}).Return(nil).Once()
mockSceneReader.On("GetPerformerIDs", testCtx, sceneID).Return(nil, nil).Once()
mockSceneReader.On("UpdatePerformers", testCtx, sceneID, []int{performerID}).Return(nil).Once()
}
err := PerformerScenes(&performer, nil, mockSceneReader, nil)
err := PerformerScenes(testCtx, &performer, nil, mockSceneReader, nil)
assert := assert.New(t)
@ -147,16 +147,16 @@ func testPerformerImages(t *testing.T, performerName, expectedRegex string) {
PerPage: &perPage,
}
mockImageReader.On("Query", image.QueryOptions(expectedImageFilter, expectedFindFilter, false)).
mockImageReader.On("Query", testCtx, image.QueryOptions(expectedImageFilter, expectedFindFilter, false)).
Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
for i := range matchingPaths {
imageID := i + 1
mockImageReader.On("GetPerformerIDs", imageID).Return(nil, nil).Once()
mockImageReader.On("UpdatePerformers", imageID, []int{performerID}).Return(nil).Once()
mockImageReader.On("GetPerformerIDs", testCtx, imageID).Return(nil, nil).Once()
mockImageReader.On("UpdatePerformers", testCtx, imageID, []int{performerID}).Return(nil).Once()
}
err := PerformerImages(&performer, nil, mockImageReader, nil)
err := PerformerImages(testCtx, &performer, nil, mockImageReader, nil)
assert := assert.New(t)
@ -222,15 +222,15 @@ func testPerformerGalleries(t *testing.T, performerName, expectedRegex string) {
PerPage: &perPage,
}
mockGalleryReader.On("Query", expectedGalleryFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
mockGalleryReader.On("Query", testCtx, expectedGalleryFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
for i := range matchingPaths {
galleryID := i + 1
mockGalleryReader.On("GetPerformerIDs", galleryID).Return(nil, nil).Once()
mockGalleryReader.On("UpdatePerformers", galleryID, []int{performerID}).Return(nil).Once()
mockGalleryReader.On("GetPerformerIDs", testCtx, galleryID).Return(nil, nil).Once()
mockGalleryReader.On("UpdatePerformers", testCtx, galleryID, []int{performerID}).Return(nil).Once()
}
err := PerformerGalleries(&performer, nil, mockGalleryReader, nil)
err := PerformerGalleries(testCtx, &performer, nil, mockGalleryReader, nil)
assert := assert.New(t)

View File

@ -1,6 +1,8 @@
package autotag
import (
"context"
"github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
@ -17,18 +19,18 @@ func getSceneFileTagger(s *models.Scene, cache *match.Cache) tagger {
}
// ScenePerformers tags the provided scene with performers whose name matches the scene's path.
func ScenePerformers(s *models.Scene, rw models.SceneReaderWriter, performerReader models.PerformerReader, cache *match.Cache) error {
func ScenePerformers(ctx context.Context, s *models.Scene, rw scene.PerformerUpdater, performerReader match.PerformerAutoTagQueryer, cache *match.Cache) error {
t := getSceneFileTagger(s, cache)
return t.tagPerformers(performerReader, func(subjectID, otherID int) (bool, error) {
return scene.AddPerformer(rw, subjectID, otherID)
return t.tagPerformers(ctx, performerReader, func(subjectID, otherID int) (bool, error) {
return scene.AddPerformer(ctx, rw, subjectID, otherID)
})
}
// SceneStudios tags the provided scene with the first studio whose name matches the scene's path.
//
// Scenes will not be tagged if studio is already set.
func SceneStudios(s *models.Scene, rw models.SceneReaderWriter, studioReader models.StudioReader, cache *match.Cache) error {
func SceneStudios(ctx context.Context, s *models.Scene, rw SceneFinderUpdater, studioReader match.StudioAutoTagQueryer, cache *match.Cache) error {
if s.StudioID.Valid {
// don't modify
return nil
@ -36,16 +38,16 @@ func SceneStudios(s *models.Scene, rw models.SceneReaderWriter, studioReader mod
t := getSceneFileTagger(s, cache)
return t.tagStudios(studioReader, func(subjectID, otherID int) (bool, error) {
return addSceneStudio(rw, subjectID, otherID)
return t.tagStudios(ctx, studioReader, func(subjectID, otherID int) (bool, error) {
return addSceneStudio(ctx, rw, subjectID, otherID)
})
}
// SceneTags tags the provided scene with tags whose name matches the scene's path.
func SceneTags(s *models.Scene, rw models.SceneReaderWriter, tagReader models.TagReader, cache *match.Cache) error {
func SceneTags(ctx context.Context, s *models.Scene, rw scene.TagUpdater, tagReader match.TagAutoTagQueryer, cache *match.Cache) error {
t := getSceneFileTagger(s, cache)
return t.tagTags(tagReader, func(subjectID, otherID int) (bool, error) {
return scene.AddTag(rw, subjectID, otherID)
return t.tagTags(ctx, tagReader, func(subjectID, otherID int) (bool, error) {
return scene.AddTag(ctx, rw, subjectID, otherID)
})
}

View File

@ -173,19 +173,19 @@ func TestScenePerformers(t *testing.T) {
mockPerformerReader := &mocks.PerformerReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{}
mockPerformerReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
mockPerformerReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
if test.Matches {
mockSceneReader.On("GetPerformerIDs", sceneID).Return(nil, nil).Once()
mockSceneReader.On("UpdatePerformers", sceneID, []int{performerID}).Return(nil).Once()
mockSceneReader.On("GetPerformerIDs", testCtx, sceneID).Return(nil, nil).Once()
mockSceneReader.On("UpdatePerformers", testCtx, sceneID, []int{performerID}).Return(nil).Once()
}
scene := models.Scene{
ID: sceneID,
Path: test.Path,
}
err := ScenePerformers(&scene, mockSceneReader, mockPerformerReader, nil)
err := ScenePerformers(testCtx, &scene, mockSceneReader, mockPerformerReader, nil)
assert.Nil(err)
mockPerformerReader.AssertExpectations(t)
@ -217,9 +217,9 @@ func TestSceneStudios(t *testing.T) {
doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockSceneReader *mocks.SceneReaderWriter, test pathTestTable) {
if test.Matches {
mockSceneReader.On("Find", sceneID).Return(&models.Scene{}, nil).Once()
mockSceneReader.On("Find", testCtx, sceneID).Return(&models.Scene{}, nil).Once()
expectedStudioID := models.NullInt64(studioID)
mockSceneReader.On("Update", models.ScenePartial{
mockSceneReader.On("Update", testCtx, models.ScenePartial{
ID: sceneID,
StudioID: &expectedStudioID,
}).Return(nil, nil).Once()
@ -229,7 +229,7 @@ func TestSceneStudios(t *testing.T) {
ID: sceneID,
Path: test.Path,
}
err := SceneStudios(&scene, mockSceneReader, mockStudioReader, nil)
err := SceneStudios(testCtx, &scene, mockSceneReader, mockStudioReader, nil)
assert.Nil(err)
mockStudioReader.AssertExpectations(t)
@ -240,9 +240,9 @@ func TestSceneStudios(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{}
mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe()
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockStudioReader, mockSceneReader, test)
}
@ -255,12 +255,12 @@ func TestSceneStudios(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{}
mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", studioID).Return([]string{
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{
studioName,
}, nil).Once()
mockStudioReader.On("GetAliases", reversedStudioID).Return([]string{}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
doTest(mockStudioReader, mockSceneReader, test)
}
@ -290,15 +290,15 @@ func TestSceneTags(t *testing.T) {
doTest := func(mockTagReader *mocks.TagReaderWriter, mockSceneReader *mocks.SceneReaderWriter, test pathTestTable) {
if test.Matches {
mockSceneReader.On("GetTagIDs", sceneID).Return(nil, nil).Once()
mockSceneReader.On("UpdateTags", sceneID, []int{tagID}).Return(nil).Once()
mockSceneReader.On("GetTagIDs", testCtx, sceneID).Return(nil, nil).Once()
mockSceneReader.On("UpdateTags", testCtx, sceneID, []int{tagID}).Return(nil).Once()
}
scene := models.Scene{
ID: sceneID,
Path: test.Path,
}
err := SceneTags(&scene, mockSceneReader, mockTagReader, nil)
err := SceneTags(testCtx, &scene, mockSceneReader, mockTagReader, nil)
assert.Nil(err)
mockTagReader.AssertExpectations(t)
@ -309,9 +309,9 @@ func TestSceneTags(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{}
mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe()
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockTagReader, mockSceneReader, test)
}
@ -324,12 +324,12 @@ func TestSceneTags(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{}
mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", tagID).Return([]string{
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{
tagName,
}, nil).Once()
mockTagReader.On("GetAliases", reversedTagID).Return([]string{}, nil).Once()
mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
doTest(mockTagReader, mockSceneReader, test)
}

View File

@ -1,15 +1,19 @@
package autotag
import (
"context"
"database/sql"
"github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
)
func addSceneStudio(sceneWriter models.SceneReaderWriter, sceneID, studioID int) (bool, error) {
func addSceneStudio(ctx context.Context, sceneWriter SceneFinderUpdater, sceneID, studioID int) (bool, error) {
// don't set if already set
scene, err := sceneWriter.Find(sceneID)
scene, err := sceneWriter.Find(ctx, sceneID)
if err != nil {
return false, err
}
@ -25,15 +29,15 @@ func addSceneStudio(sceneWriter models.SceneReaderWriter, sceneID, studioID int)
StudioID: &s,
}
if _, err := sceneWriter.Update(scenePartial); err != nil {
if _, err := sceneWriter.Update(ctx, scenePartial); err != nil {
return false, err
}
return true, nil
}
func addImageStudio(imageWriter models.ImageReaderWriter, imageID, studioID int) (bool, error) {
func addImageStudio(ctx context.Context, imageWriter ImageFinderUpdater, imageID, studioID int) (bool, error) {
// don't set if already set
image, err := imageWriter.Find(imageID)
image, err := imageWriter.Find(ctx, imageID)
if err != nil {
return false, err
}
@ -49,15 +53,15 @@ func addImageStudio(imageWriter models.ImageReaderWriter, imageID, studioID int)
StudioID: &s,
}
if _, err := imageWriter.Update(imagePartial); err != nil {
if _, err := imageWriter.Update(ctx, imagePartial); err != nil {
return false, err
}
return true, nil
}
func addGalleryStudio(galleryWriter models.GalleryReaderWriter, galleryID, studioID int) (bool, error) {
func addGalleryStudio(ctx context.Context, galleryWriter GalleryFinderUpdater, galleryID, studioID int) (bool, error) {
// don't set if already set
gallery, err := galleryWriter.Find(galleryID)
gallery, err := galleryWriter.Find(ctx, galleryID)
if err != nil {
return false, err
}
@ -73,7 +77,7 @@ func addGalleryStudio(galleryWriter models.GalleryReaderWriter, galleryID, studi
StudioID: &s,
}
if _, err := galleryWriter.UpdatePartial(galleryPartial); err != nil {
if _, err := galleryWriter.UpdatePartial(ctx, galleryPartial); err != nil {
return false, err
}
return true, nil
@ -98,13 +102,19 @@ func getStudioTagger(p *models.Studio, aliases []string, cache *match.Cache) []t
return ret
}
type SceneFinderUpdater interface {
scene.Queryer
Find(ctx context.Context, id int) (*models.Scene, error)
Update(ctx context.Context, updatedScene models.ScenePartial) (*models.Scene, error)
}
// StudioScenes searches for scenes whose path matches the provided studio name and tags the scene with the studio, if studio is not already set on the scene.
func StudioScenes(p *models.Studio, paths []string, aliases []string, rw models.SceneReaderWriter, cache *match.Cache) error {
func StudioScenes(ctx context.Context, p *models.Studio, paths []string, aliases []string, rw SceneFinderUpdater, cache *match.Cache) error {
t := getStudioTagger(p, aliases, cache)
for _, tt := range t {
if err := tt.tagScenes(paths, rw, func(subjectID, otherID int) (bool, error) {
return addSceneStudio(rw, otherID, subjectID)
if err := tt.tagScenes(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
return addSceneStudio(ctx, rw, otherID, subjectID)
}); err != nil {
return err
}
@ -113,13 +123,19 @@ func StudioScenes(p *models.Studio, paths []string, aliases []string, rw models.
return nil
}
type ImageFinderUpdater interface {
image.Queryer
Find(ctx context.Context, id int) (*models.Image, error)
Update(ctx context.Context, updatedImage models.ImagePartial) (*models.Image, error)
}
// StudioImages searches for images whose path matches the provided studio name and tags the image with the studio, if studio is not already set on the image.
func StudioImages(p *models.Studio, paths []string, aliases []string, rw models.ImageReaderWriter, cache *match.Cache) error {
func StudioImages(ctx context.Context, p *models.Studio, paths []string, aliases []string, rw ImageFinderUpdater, cache *match.Cache) error {
t := getStudioTagger(p, aliases, cache)
for _, tt := range t {
if err := tt.tagImages(paths, rw, func(subjectID, otherID int) (bool, error) {
return addImageStudio(rw, otherID, subjectID)
if err := tt.tagImages(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
return addImageStudio(ctx, rw, otherID, subjectID)
}); err != nil {
return err
}
@ -128,13 +144,19 @@ func StudioImages(p *models.Studio, paths []string, aliases []string, rw models.
return nil
}
type GalleryFinderUpdater interface {
gallery.Queryer
Find(ctx context.Context, id int) (*models.Gallery, error)
UpdatePartial(ctx context.Context, updatedGallery models.GalleryPartial) (*models.Gallery, error)
}
// StudioGalleries searches for galleries whose path matches the provided studio name and tags the gallery with the studio, if studio is not already set on the gallery.
func StudioGalleries(p *models.Studio, paths []string, aliases []string, rw models.GalleryReaderWriter, cache *match.Cache) error {
func StudioGalleries(ctx context.Context, p *models.Studio, paths []string, aliases []string, rw GalleryFinderUpdater, cache *match.Cache) error {
t := getStudioTagger(p, aliases, cache)
for _, tt := range t {
if err := tt.tagGalleries(paths, rw, func(subjectID, otherID int) (bool, error) {
return addGalleryStudio(rw, otherID, subjectID)
if err := tt.tagGalleries(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
return addGalleryStudio(ctx, rw, otherID, subjectID)
}); err != nil {
return err
}

View File

@ -113,7 +113,7 @@ func testStudioScenes(t *testing.T, tc testStudioCase) {
}
// if alias provided, then don't find by name
onNameQuery := mockSceneReader.On("Query", scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false))
onNameQuery := mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false))
if aliasName == "" {
onNameQuery.Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
@ -128,21 +128,21 @@ func testStudioScenes(t *testing.T, tc testStudioCase) {
},
}
mockSceneReader.On("Query", scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
}
for i := range matchingPaths {
sceneID := i + 1
mockSceneReader.On("Find", sceneID).Return(&models.Scene{}, nil).Once()
mockSceneReader.On("Find", testCtx, sceneID).Return(&models.Scene{}, nil).Once()
expectedStudioID := models.NullInt64(studioID)
mockSceneReader.On("Update", models.ScenePartial{
mockSceneReader.On("Update", testCtx, models.ScenePartial{
ID: sceneID,
StudioID: &expectedStudioID,
}).Return(nil, nil).Once()
}
err := StudioScenes(&studio, nil, aliases, mockSceneReader, nil)
err := StudioScenes(testCtx, &studio, nil, aliases, mockSceneReader, nil)
assert := assert.New(t)
@ -206,7 +206,7 @@ func testStudioImages(t *testing.T, tc testStudioCase) {
}
// if alias provided, then don't find by name
onNameQuery := mockImageReader.On("Query", image.QueryOptions(expectedImageFilter, expectedFindFilter, false))
onNameQuery := mockImageReader.On("Query", testCtx, image.QueryOptions(expectedImageFilter, expectedFindFilter, false))
if aliasName == "" {
onNameQuery.Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
} else {
@ -220,21 +220,21 @@ func testStudioImages(t *testing.T, tc testStudioCase) {
},
}
mockImageReader.On("Query", image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
mockImageReader.On("Query", testCtx, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
}
for i := range matchingPaths {
imageID := i + 1
mockImageReader.On("Find", imageID).Return(&models.Image{}, nil).Once()
mockImageReader.On("Find", testCtx, imageID).Return(&models.Image{}, nil).Once()
expectedStudioID := models.NullInt64(studioID)
mockImageReader.On("Update", models.ImagePartial{
mockImageReader.On("Update", testCtx, models.ImagePartial{
ID: imageID,
StudioID: &expectedStudioID,
}).Return(nil, nil).Once()
}
err := StudioImages(&studio, nil, aliases, mockImageReader, nil)
err := StudioImages(testCtx, &studio, nil, aliases, mockImageReader, nil)
assert := assert.New(t)
@ -297,7 +297,7 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) {
}
// if alias provided, then don't find by name
onNameQuery := mockGalleryReader.On("Query", expectedGalleryFilter, expectedFindFilter)
onNameQuery := mockGalleryReader.On("Query", testCtx, expectedGalleryFilter, expectedFindFilter)
if aliasName == "" {
onNameQuery.Return(galleries, len(galleries), nil).Once()
} else {
@ -311,20 +311,20 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) {
},
}
mockGalleryReader.On("Query", expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
mockGalleryReader.On("Query", testCtx, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
}
for i := range matchingPaths {
galleryID := i + 1
mockGalleryReader.On("Find", galleryID).Return(&models.Gallery{}, nil).Once()
mockGalleryReader.On("Find", testCtx, galleryID).Return(&models.Gallery{}, nil).Once()
expectedStudioID := models.NullInt64(studioID)
mockGalleryReader.On("UpdatePartial", models.GalleryPartial{
mockGalleryReader.On("UpdatePartial", testCtx, models.GalleryPartial{
ID: galleryID,
StudioID: &expectedStudioID,
}).Return(nil, nil).Once()
}
err := StudioGalleries(&studio, nil, aliases, mockGalleryReader, nil)
err := StudioGalleries(testCtx, &studio, nil, aliases, mockGalleryReader, nil)
assert := assert.New(t)

View File

@ -1,6 +1,8 @@
package autotag
import (
"context"
"github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/match"
@ -8,6 +10,21 @@ import (
"github.com/stashapp/stash/pkg/scene"
)
type SceneQueryTagUpdater interface {
scene.Queryer
scene.TagUpdater
}
type ImageQueryTagUpdater interface {
image.Queryer
image.TagUpdater
}
type GalleryQueryTagUpdater interface {
gallery.Queryer
gallery.TagUpdater
}
func getTagTaggers(p *models.Tag, aliases []string, cache *match.Cache) []tagger {
ret := []tagger{{
ID: p.ID,
@ -29,12 +46,12 @@ func getTagTaggers(p *models.Tag, aliases []string, cache *match.Cache) []tagger
}
// TagScenes searches for scenes whose path matches the provided tag name and tags the scene with the tag.
func TagScenes(p *models.Tag, paths []string, aliases []string, rw models.SceneReaderWriter, cache *match.Cache) error {
func TagScenes(ctx context.Context, p *models.Tag, paths []string, aliases []string, rw SceneQueryTagUpdater, cache *match.Cache) error {
t := getTagTaggers(p, aliases, cache)
for _, tt := range t {
if err := tt.tagScenes(paths, rw, func(subjectID, otherID int) (bool, error) {
return scene.AddTag(rw, otherID, subjectID)
if err := tt.tagScenes(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
return scene.AddTag(ctx, rw, otherID, subjectID)
}); err != nil {
return err
}
@ -43,12 +60,12 @@ func TagScenes(p *models.Tag, paths []string, aliases []string, rw models.SceneR
}
// TagImages searches for images whose path matches the provided tag name and tags the image with the tag.
func TagImages(p *models.Tag, paths []string, aliases []string, rw models.ImageReaderWriter, cache *match.Cache) error {
func TagImages(ctx context.Context, p *models.Tag, paths []string, aliases []string, rw ImageQueryTagUpdater, cache *match.Cache) error {
t := getTagTaggers(p, aliases, cache)
for _, tt := range t {
if err := tt.tagImages(paths, rw, func(subjectID, otherID int) (bool, error) {
return image.AddTag(rw, otherID, subjectID)
if err := tt.tagImages(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
return image.AddTag(ctx, rw, otherID, subjectID)
}); err != nil {
return err
}
@ -57,12 +74,12 @@ func TagImages(p *models.Tag, paths []string, aliases []string, rw models.ImageR
}
// TagGalleries searches for galleries whose path matches the provided tag name and tags the gallery with the tag.
func TagGalleries(p *models.Tag, paths []string, aliases []string, rw models.GalleryReaderWriter, cache *match.Cache) error {
func TagGalleries(ctx context.Context, p *models.Tag, paths []string, aliases []string, rw GalleryQueryTagUpdater, cache *match.Cache) error {
t := getTagTaggers(p, aliases, cache)
for _, tt := range t {
if err := tt.tagGalleries(paths, rw, func(subjectID, otherID int) (bool, error) {
return gallery.AddTag(rw, otherID, subjectID)
if err := tt.tagGalleries(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
return gallery.AddTag(ctx, rw, otherID, subjectID)
}); err != nil {
return err
}

View File

@ -113,7 +113,7 @@ func testTagScenes(t *testing.T, tc testTagCase) {
}
// if alias provided, then don't find by name
onNameQuery := mockSceneReader.On("Query", scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false))
onNameQuery := mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false))
if aliasName == "" {
onNameQuery.Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
} else {
@ -127,17 +127,17 @@ func testTagScenes(t *testing.T, tc testTagCase) {
},
}
mockSceneReader.On("Query", scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
}
for i := range matchingPaths {
sceneID := i + 1
mockSceneReader.On("GetTagIDs", sceneID).Return(nil, nil).Once()
mockSceneReader.On("UpdateTags", sceneID, []int{tagID}).Return(nil).Once()
mockSceneReader.On("GetTagIDs", testCtx, sceneID).Return(nil, nil).Once()
mockSceneReader.On("UpdateTags", testCtx, sceneID, []int{tagID}).Return(nil).Once()
}
err := TagScenes(&tag, nil, aliases, mockSceneReader, nil)
err := TagScenes(testCtx, &tag, nil, aliases, mockSceneReader, nil)
assert := assert.New(t)
@ -201,7 +201,7 @@ func testTagImages(t *testing.T, tc testTagCase) {
}
// if alias provided, then don't find by name
onNameQuery := mockImageReader.On("Query", image.QueryOptions(expectedImageFilter, expectedFindFilter, false))
onNameQuery := mockImageReader.On("Query", testCtx, image.QueryOptions(expectedImageFilter, expectedFindFilter, false))
if aliasName == "" {
onNameQuery.Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
} else {
@ -215,17 +215,17 @@ func testTagImages(t *testing.T, tc testTagCase) {
},
}
mockImageReader.On("Query", image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
mockImageReader.On("Query", testCtx, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
}
for i := range matchingPaths {
imageID := i + 1
mockImageReader.On("GetTagIDs", imageID).Return(nil, nil).Once()
mockImageReader.On("UpdateTags", imageID, []int{tagID}).Return(nil).Once()
mockImageReader.On("GetTagIDs", testCtx, imageID).Return(nil, nil).Once()
mockImageReader.On("UpdateTags", testCtx, imageID, []int{tagID}).Return(nil).Once()
}
err := TagImages(&tag, nil, aliases, mockImageReader, nil)
err := TagImages(testCtx, &tag, nil, aliases, mockImageReader, nil)
assert := assert.New(t)
@ -289,7 +289,7 @@ func testTagGalleries(t *testing.T, tc testTagCase) {
}
// if alias provided, then don't find by name
onNameQuery := mockGalleryReader.On("Query", expectedGalleryFilter, expectedFindFilter)
onNameQuery := mockGalleryReader.On("Query", testCtx, expectedGalleryFilter, expectedFindFilter)
if aliasName == "" {
onNameQuery.Return(galleries, len(galleries), nil).Once()
} else {
@ -303,16 +303,16 @@ func testTagGalleries(t *testing.T, tc testTagCase) {
},
}
mockGalleryReader.On("Query", expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
mockGalleryReader.On("Query", testCtx, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
}
for i := range matchingPaths {
galleryID := i + 1
mockGalleryReader.On("GetTagIDs", galleryID).Return(nil, nil).Once()
mockGalleryReader.On("UpdateTags", galleryID, []int{tagID}).Return(nil).Once()
mockGalleryReader.On("GetTagIDs", testCtx, galleryID).Return(nil, nil).Once()
mockGalleryReader.On("UpdateTags", testCtx, galleryID, []int{tagID}).Return(nil).Once()
}
err := TagGalleries(&tag, nil, aliases, mockGalleryReader, nil)
err := TagGalleries(testCtx, &tag, nil, aliases, mockGalleryReader, nil)
assert := assert.New(t)

View File

@ -14,11 +14,14 @@
package autotag
import (
"context"
"fmt"
"github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
)
type tagger struct {
@ -41,8 +44,8 @@ func (t *tagger) addLog(otherType, otherName string) {
logger.Infof("Added %s '%s' to %s '%s'", otherType, otherName, t.Type, t.Name)
}
func (t *tagger) tagPerformers(performerReader models.PerformerReader, addFunc addLinkFunc) error {
others, err := match.PathToPerformers(t.Path, performerReader, t.cache, t.trimExt)
func (t *tagger) tagPerformers(ctx context.Context, performerReader match.PerformerAutoTagQueryer, addFunc addLinkFunc) error {
others, err := match.PathToPerformers(ctx, t.Path, performerReader, t.cache, t.trimExt)
if err != nil {
return err
}
@ -62,8 +65,8 @@ func (t *tagger) tagPerformers(performerReader models.PerformerReader, addFunc a
return nil
}
func (t *tagger) tagStudios(studioReader models.StudioReader, addFunc addLinkFunc) error {
studio, err := match.PathToStudio(t.Path, studioReader, t.cache, t.trimExt)
func (t *tagger) tagStudios(ctx context.Context, studioReader match.StudioAutoTagQueryer, addFunc addLinkFunc) error {
studio, err := match.PathToStudio(ctx, t.Path, studioReader, t.cache, t.trimExt)
if err != nil {
return err
}
@ -83,8 +86,8 @@ func (t *tagger) tagStudios(studioReader models.StudioReader, addFunc addLinkFun
return nil
}
func (t *tagger) tagTags(tagReader models.TagReader, addFunc addLinkFunc) error {
others, err := match.PathToTags(t.Path, tagReader, t.cache, t.trimExt)
func (t *tagger) tagTags(ctx context.Context, tagReader match.TagAutoTagQueryer, addFunc addLinkFunc) error {
others, err := match.PathToTags(ctx, t.Path, tagReader, t.cache, t.trimExt)
if err != nil {
return err
}
@ -104,8 +107,8 @@ func (t *tagger) tagTags(tagReader models.TagReader, addFunc addLinkFunc) error
return nil
}
func (t *tagger) tagScenes(paths []string, sceneReader models.SceneReader, addFunc addLinkFunc) error {
others, err := match.PathToScenes(t.Name, paths, sceneReader)
func (t *tagger) tagScenes(ctx context.Context, paths []string, sceneReader scene.Queryer, addFunc addLinkFunc) error {
others, err := match.PathToScenes(ctx, t.Name, paths, sceneReader)
if err != nil {
return err
}
@ -125,8 +128,8 @@ func (t *tagger) tagScenes(paths []string, sceneReader models.SceneReader, addFu
return nil
}
func (t *tagger) tagImages(paths []string, imageReader models.ImageReader, addFunc addLinkFunc) error {
others, err := match.PathToImages(t.Name, paths, imageReader)
func (t *tagger) tagImages(ctx context.Context, paths []string, imageReader image.Queryer, addFunc addLinkFunc) error {
others, err := match.PathToImages(ctx, t.Name, paths, imageReader)
if err != nil {
return err
}
@ -146,8 +149,8 @@ func (t *tagger) tagImages(paths []string, imageReader models.ImageReader, addFu
return nil
}
func (t *tagger) tagGalleries(paths []string, galleryReader models.GalleryReader, addFunc addLinkFunc) error {
others, err := match.PathToGalleries(t.Name, paths, galleryReader)
func (t *tagger) tagGalleries(ctx context.Context, paths []string, galleryReader gallery.Queryer, addFunc addLinkFunc) error {
others, err := match.PathToGalleries(ctx, t.Name, paths, galleryReader)
if err != nil {
return err
}

View File

@ -41,6 +41,7 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
"github.com/stashapp/stash/pkg/txn"
)
var pageSize = 100
@ -56,7 +57,6 @@ type browse struct {
type contentDirectoryService struct {
*Server
upnp.Eventing
txnManager models.TransactionManager
}
func formatDurationSexagesimal(d time.Duration) string {
@ -352,8 +352,8 @@ func (me *contentDirectoryService) handleBrowseMetadata(obj object, host string)
} else {
var scene *models.Scene
if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
scene, err = r.Scene().Find(sceneID)
if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
scene, err = me.repository.SceneFinder.Find(ctx, sceneID)
if err != nil {
return err
}
@ -431,14 +431,14 @@ func getRootObjects() []interface{} {
func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType, parentID string, host string) []interface{} {
var objs []interface{}
if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
sort := "title"
findFilter := &models.FindFilterType{
PerPage: &pageSize,
Sort: &sort,
}
scenes, total, err := scene.QueryWithCount(r.Scene(), sceneFilter, findFilter)
scenes, total, err := scene.QueryWithCount(ctx, me.repository.SceneFinder, sceneFilter, findFilter)
if err != nil {
return err
}
@ -449,7 +449,7 @@ func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType
parentID: parentID,
}
objs, err = pager.getPages(r, total)
objs, err = pager.getPages(ctx, me.repository.SceneFinder, total)
if err != nil {
return err
}
@ -470,14 +470,14 @@ func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType
func (me *contentDirectoryService) getPageVideos(sceneFilter *models.SceneFilterType, parentID string, page int, host string) []interface{} {
var objs []interface{}
if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
pager := scenePager{
sceneFilter: sceneFilter,
parentID: parentID,
}
var err error
objs, err = pager.getPageVideos(r, page, host)
objs, err = pager.getPageVideos(ctx, me.repository.SceneFinder, page, host)
if err != nil {
return err
}
@ -511,8 +511,8 @@ func (me *contentDirectoryService) getAllScenes(host string) []interface{} {
func (me *contentDirectoryService) getStudios() []interface{} {
var objs []interface{}
if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
studios, err := r.Studio().All()
if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
studios, err := me.repository.StudioFinder.All(ctx)
if err != nil {
return err
}
@ -550,8 +550,8 @@ func (me *contentDirectoryService) getStudioScenes(paths []string, host string)
func (me *contentDirectoryService) getTags() []interface{} {
var objs []interface{}
if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
tags, err := r.Tag().All()
if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
tags, err := me.repository.TagFinder.All(ctx)
if err != nil {
return err
}
@ -589,8 +589,8 @@ func (me *contentDirectoryService) getTagScenes(paths []string, host string) []i
func (me *contentDirectoryService) getPerformers() []interface{} {
var objs []interface{}
if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
performers, err := r.Performer().All()
if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
performers, err := me.repository.PerformerFinder.All(ctx)
if err != nil {
return err
}
@ -628,8 +628,8 @@ func (me *contentDirectoryService) getPerformerScenes(paths []string, host strin
func (me *contentDirectoryService) getMovies() []interface{} {
var objs []interface{}
if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
movies, err := r.Movie().All()
if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
movies, err := me.repository.MovieFinder.All(ctx)
if err != nil {
return err
}

View File

@ -31,7 +31,6 @@ import (
"strings"
"testing"
"github.com/stashapp/stash/pkg/models/mocks"
"github.com/stretchr/testify/assert"
)
@ -59,8 +58,7 @@ func TestRootParentObjectID(t *testing.T) {
func testHandleBrowse(argsXML string) (map[string]string, error) {
cds := contentDirectoryService{
Server: &Server{},
txnManager: mocks.NewTransactionManager(),
Server: &Server{},
}
r := &http.Request{}

View File

@ -48,8 +48,31 @@ import (
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/txn"
)
type SceneFinder interface {
scene.Queryer
scene.IDFinder
}
type StudioFinder interface {
All(ctx context.Context) ([]*models.Studio, error)
}
type TagFinder interface {
All(ctx context.Context) ([]*models.Tag, error)
}
type PerformerFinder interface {
All(ctx context.Context) ([]*models.Performer, error)
}
type MovieFinder interface {
All(ctx context.Context) ([]*models.Movie, error)
}
const (
serverField = "Linux/3.4 DLNADOC/1.50 UPnP/1.0 DMS/1.0"
rootDeviceType = "urn:schemas-upnp-org:device:MediaServer:1"
@ -249,7 +272,8 @@ type Server struct {
// Time interval between SSPD announces
NotifyInterval time.Duration
txnManager models.TransactionManager
txnManager txn.Manager
repository Repository
sceneServer sceneServer
ipWhitelistManager *ipWhitelistManager
}
@ -415,12 +439,12 @@ func (me *Server) serveIcon(w http.ResponseWriter, r *http.Request) {
}
var scene *models.Scene
err := me.txnManager.WithReadTxn(r.Context(), func(r models.ReaderRepository) error {
err := txn.WithTxn(r.Context(), me.txnManager, func(ctx context.Context) error {
idInt, err := strconv.Atoi(sceneId)
if err != nil {
return nil
}
scene, _ = r.Scene().Find(idInt)
scene, _ = me.repository.SceneFinder.Find(ctx, idInt)
return nil
})
if err != nil {
@ -555,12 +579,12 @@ func (me *Server) initMux(mux *http.ServeMux) {
mux.HandleFunc(resPath, func(w http.ResponseWriter, r *http.Request) {
sceneId := r.URL.Query().Get("scene")
var scene *models.Scene
err := me.txnManager.WithReadTxn(r.Context(), func(r models.ReaderRepository) error {
err := txn.WithTxn(r.Context(), me.txnManager, func(ctx context.Context) error {
sceneIdInt, err := strconv.Atoi(sceneId)
if err != nil {
return nil
}
scene, _ = r.Scene().Find(sceneIdInt)
scene, _ = me.repository.SceneFinder.Find(ctx, sceneIdInt)
return nil
})
if err != nil {
@ -595,8 +619,7 @@ func (me *Server) initMux(mux *http.ServeMux) {
func (me *Server) initServices() {
me.services = map[string]UPnPService{
"ContentDirectory": &contentDirectoryService{
Server: me,
txnManager: me.txnManager,
Server: me,
},
"ConnectionManager": &connectionManagerService{
Server: me,

View File

@ -1,6 +1,7 @@
package dlna
import (
"context"
"fmt"
"math"
"strconv"
@ -18,7 +19,7 @@ func (p *scenePager) getPageID(page int) string {
return p.parentID + "/page/" + strconv.Itoa(page)
}
func (p *scenePager) getPages(r models.ReaderRepository, total int) ([]interface{}, error) {
func (p *scenePager) getPages(ctx context.Context, r scene.Queryer, total int) ([]interface{}, error) {
var objs []interface{}
// get the first scene of each page to set an appropriate title
@ -37,7 +38,7 @@ func (p *scenePager) getPages(r models.ReaderRepository, total int) ([]interface
if pages <= 10 || (page-1)%(pages/10) == 0 {
thisPage := ((page - 1) * pageSize) + 1
findFilter.Page = &thisPage
scenes, err := scene.Query(r.Scene(), p.sceneFilter, findFilter)
scenes, err := scene.Query(ctx, r, p.sceneFilter, findFilter)
if err != nil {
return nil, err
}
@ -58,7 +59,7 @@ func (p *scenePager) getPages(r models.ReaderRepository, total int) ([]interface
return objs, nil
}
func (p *scenePager) getPageVideos(r models.ReaderRepository, page int, host string) ([]interface{}, error) {
func (p *scenePager) getPageVideos(ctx context.Context, r SceneFinder, page int, host string) ([]interface{}, error) {
var objs []interface{}
sort := "title"
@ -68,7 +69,7 @@ func (p *scenePager) getPageVideos(r models.ReaderRepository, page int, host str
Sort: &sort,
}
scenes, err := scene.Query(r.Scene(), p.sceneFilter, findFilter)
scenes, err := scene.Query(ctx, r, p.sceneFilter, findFilter)
if err != nil {
return nil, err
}

View File

@ -10,8 +10,17 @@ import (
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
)
type Repository struct {
SceneFinder SceneFinder
StudioFinder StudioFinder
TagFinder TagFinder
PerformerFinder PerformerFinder
MovieFinder MovieFinder
}
type Status struct {
Running bool `json:"running"`
// If not currently running, time until it will be started. If running, time until it will be stopped
@ -48,7 +57,8 @@ type Config interface {
}
type Service struct {
txnManager models.TransactionManager
txnManager txn.Manager
repository Repository
config Config
sceneServer sceneServer
ipWhitelistMgr *ipWhitelistManager
@ -121,6 +131,7 @@ func (s *Service) init() error {
s.server = &Server{
txnManager: s.txnManager,
sceneServer: s.sceneServer,
repository: s.repository,
ipWhitelistManager: s.ipWhitelistMgr,
Interfaces: interfaces,
HTTPConn: func() net.Listener {
@ -181,9 +192,10 @@ func (s *Service) init() error {
// }
// NewService initialises and returns a new DLNA service.
func NewService(txnManager models.TransactionManager, cfg Config, sceneServer sceneServer) *Service {
func NewService(txnManager txn.Manager, repo Repository, cfg Config, sceneServer sceneServer) *Service {
ret := &Service{
txnManager: txnManager,
repository: repo,
sceneServer: sceneServer,
config: cfg,
ipWhitelistMgr: &ipWhitelistManager{

View File

@ -9,6 +9,7 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/scraper"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
@ -28,13 +29,18 @@ type ScraperSource struct {
}
type SceneIdentifier struct {
SceneReaderUpdater SceneReaderUpdater
StudioCreator StudioCreator
PerformerCreator PerformerCreator
TagCreator TagCreator
DefaultOptions *MetadataOptions
Sources []ScraperSource
ScreenshotSetter scene.ScreenshotSetter
SceneUpdatePostHookExecutor SceneUpdatePostHookExecutor
}
func (t *SceneIdentifier) Identify(ctx context.Context, txnManager models.TransactionManager, scene *models.Scene) error {
func (t *SceneIdentifier) Identify(ctx context.Context, txnManager txn.Manager, scene *models.Scene) error {
result, err := t.scrapeScene(ctx, scene)
if err != nil {
return err
@ -80,7 +86,7 @@ func (t *SceneIdentifier) scrapeScene(ctx context.Context, scene *models.Scene)
return nil, nil
}
func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene, result *scrapeResult, repo models.Repository) (*scene.UpdateSet, error) {
func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene, result *scrapeResult) (*scene.UpdateSet, error) {
ret := &scene.UpdateSet{
ID: s.ID,
}
@ -106,15 +112,18 @@ func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene,
scraped := result.result
rel := sceneRelationships{
repo: repo,
scene: s,
result: result,
fieldOptions: fieldOptions,
sceneReader: t.SceneReaderUpdater,
studioCreator: t.StudioCreator,
performerCreator: t.PerformerCreator,
tagCreator: t.TagCreator,
scene: s,
result: result,
fieldOptions: fieldOptions,
}
ret.Partial = getScenePartial(s, scraped, fieldOptions, setOrganized)
studioID, err := rel.studio()
studioID, err := rel.studio(ctx)
if err != nil {
return nil, fmt.Errorf("error getting studio: %w", err)
}
@ -134,17 +143,17 @@ func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene,
}
}
ret.PerformerIDs, err = rel.performers(ignoreMale)
ret.PerformerIDs, err = rel.performers(ctx, ignoreMale)
if err != nil {
return nil, err
}
ret.TagIDs, err = rel.tags()
ret.TagIDs, err = rel.tags(ctx)
if err != nil {
return nil, err
}
ret.StashIDs, err = rel.stashIDs()
ret.StashIDs, err = rel.stashIDs(ctx)
if err != nil {
return nil, err
}
@ -167,11 +176,11 @@ func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene,
return ret, nil
}
func (t *SceneIdentifier) modifyScene(ctx context.Context, txnManager models.TransactionManager, s *models.Scene, result *scrapeResult) error {
func (t *SceneIdentifier) modifyScene(ctx context.Context, txnManager txn.Manager, s *models.Scene, result *scrapeResult) error {
var updater *scene.UpdateSet
if err := txnManager.WithTxn(ctx, func(repo models.Repository) error {
if err := txn.WithTxn(ctx, txnManager, func(ctx context.Context) error {
var err error
updater, err = t.getSceneUpdater(ctx, s, result, repo)
updater, err = t.getSceneUpdater(ctx, s, result)
if err != nil {
return err
}
@ -182,7 +191,7 @@ func (t *SceneIdentifier) modifyScene(ctx context.Context, txnManager models.Tra
return nil
}
_, err = updater.Update(repo.Scene(), t.ScreenshotSetter)
_, err = updater.Update(ctx, t.SceneReaderUpdater, t.ScreenshotSetter)
if err != nil {
return fmt.Errorf("error updating scene: %w", err)
}

View File

@ -13,6 +13,8 @@ import (
"github.com/stretchr/testify/mock"
)
var testCtx = context.Background()
type mockSceneScraper struct {
errIDs []int
results map[int]*scraper.ScrapedScene
@ -70,11 +72,12 @@ func TestSceneIdentifier_Identify(t *testing.T) {
},
}
repo := mocks.NewTransactionManager()
repo.Scene().(*mocks.SceneReaderWriter).On("Update", mock.MatchedBy(func(partial models.ScenePartial) bool {
mockSceneReaderWriter := &mocks.SceneReaderWriter{}
mockSceneReaderWriter.On("Update", testCtx, mock.MatchedBy(func(partial models.ScenePartial) bool {
return partial.ID != errUpdateID
})).Return(nil, nil)
repo.Scene().(*mocks.SceneReaderWriter).On("Update", mock.MatchedBy(func(partial models.ScenePartial) bool {
mockSceneReaderWriter.On("Update", testCtx, mock.MatchedBy(func(partial models.ScenePartial) bool {
return partial.ID == errUpdateID
})).Return(nil, errors.New("update error"))
@ -116,6 +119,7 @@ func TestSceneIdentifier_Identify(t *testing.T) {
}
identifier := SceneIdentifier{
SceneReaderUpdater: mockSceneReaderWriter,
DefaultOptions: defaultOptions,
Sources: sources,
SceneUpdatePostHookExecutor: mockHookExecutor{},
@ -126,7 +130,7 @@ func TestSceneIdentifier_Identify(t *testing.T) {
scene := &models.Scene{
ID: tt.sceneID,
}
if err := identifier.Identify(context.TODO(), repo, scene); (err != nil) != tt.wantErr {
if err := identifier.Identify(testCtx, &mocks.TxnManager{}, scene); (err != nil) != tt.wantErr {
t.Errorf("SceneIdentifier.Identify() error = %v, wantErr %v", err, tt.wantErr)
}
})
@ -134,7 +138,9 @@ func TestSceneIdentifier_Identify(t *testing.T) {
}
func TestSceneIdentifier_modifyScene(t *testing.T) {
repo := mocks.NewTransactionManager()
repo := models.Repository{
TxnManager: &mocks.TxnManager{},
}
tr := &SceneIdentifier{}
type args struct {
@ -159,7 +165,7 @@ func TestSceneIdentifier_modifyScene(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tr.modifyScene(context.TODO(), repo, tt.args.scene, tt.args.result); (err != nil) != tt.wantErr {
if err := tr.modifyScene(testCtx, repo, tt.args.scene, tt.args.result); (err != nil) != tt.wantErr {
t.Errorf("SceneIdentifier.modifyScene() error = %v, wantErr %v", err, tt.wantErr)
}
})

View File

@ -1,6 +1,7 @@
package identify
import (
"context"
"database/sql"
"fmt"
"strconv"
@ -10,7 +11,12 @@ import (
"github.com/stashapp/stash/pkg/models"
)
func getPerformerID(endpoint string, r models.Repository, p *models.ScrapedPerformer, createMissing bool) (*int, error) {
type PerformerCreator interface {
Create(ctx context.Context, newPerformer models.Performer) (*models.Performer, error)
UpdateStashIDs(ctx context.Context, performerID int, stashIDs []models.StashID) error
}
func getPerformerID(ctx context.Context, endpoint string, w PerformerCreator, p *models.ScrapedPerformer, createMissing bool) (*int, error) {
if p.StoredID != nil {
// existing performer, just add it
performerID, err := strconv.Atoi(*p.StoredID)
@ -20,20 +26,20 @@ func getPerformerID(endpoint string, r models.Repository, p *models.ScrapedPerfo
return &performerID, nil
} else if createMissing && p.Name != nil { // name is mandatory
return createMissingPerformer(endpoint, r, p)
return createMissingPerformer(ctx, endpoint, w, p)
}
return nil, nil
}
func createMissingPerformer(endpoint string, r models.Repository, p *models.ScrapedPerformer) (*int, error) {
created, err := r.Performer().Create(scrapedToPerformerInput(p))
func createMissingPerformer(ctx context.Context, endpoint string, w PerformerCreator, p *models.ScrapedPerformer) (*int, error) {
created, err := w.Create(ctx, scrapedToPerformerInput(p))
if err != nil {
return nil, fmt.Errorf("error creating performer: %w", err)
}
if endpoint != "" && p.RemoteSiteID != nil {
if err := r.Performer().UpdateStashIDs(created.ID, []models.StashID{
if err := w.UpdateStashIDs(ctx, created.ID, []models.StashID{
{
Endpoint: endpoint,
StashID: *p.RemoteSiteID,

View File

@ -23,8 +23,8 @@ func Test_getPerformerID(t *testing.T) {
validStoredID := 1
name := "name"
repo := mocks.NewTransactionManager()
repo.PerformerMock().On("Create", mock.Anything).Return(&models.Performer{
mockPerformerReaderWriter := mocks.PerformerReaderWriter{}
mockPerformerReaderWriter.On("Create", testCtx, mock.Anything).Return(&models.Performer{
ID: validStoredID,
}, nil)
@ -110,7 +110,7 @@ func Test_getPerformerID(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := getPerformerID(tt.args.endpoint, repo, tt.args.p, tt.args.createMissing)
got, err := getPerformerID(testCtx, tt.args.endpoint, &mockPerformerReaderWriter, tt.args.p, tt.args.createMissing)
if (err != nil) != tt.wantErr {
t.Errorf("getPerformerID() error = %v, wantErr %v", err, tt.wantErr)
return
@ -131,23 +131,23 @@ func Test_createMissingPerformer(t *testing.T) {
invalidName := "invalidName"
performerID := 1
repo := mocks.NewTransactionManager()
repo.PerformerMock().On("Create", mock.MatchedBy(func(p models.Performer) bool {
mockPerformerReaderWriter := mocks.PerformerReaderWriter{}
mockPerformerReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Performer) bool {
return p.Name.String == validName
})).Return(&models.Performer{
ID: performerID,
}, nil)
repo.PerformerMock().On("Create", mock.MatchedBy(func(p models.Performer) bool {
mockPerformerReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Performer) bool {
return p.Name.String == invalidName
})).Return(nil, errors.New("error creating performer"))
repo.PerformerMock().On("UpdateStashIDs", performerID, []models.StashID{
mockPerformerReaderWriter.On("UpdateStashIDs", testCtx, performerID, []models.StashID{
{
Endpoint: invalidEndpoint,
StashID: remoteSiteID,
},
}).Return(errors.New("error updating stash ids"))
repo.PerformerMock().On("UpdateStashIDs", performerID, []models.StashID{
mockPerformerReaderWriter.On("UpdateStashIDs", testCtx, performerID, []models.StashID{
{
Endpoint: validEndpoint,
StashID: remoteSiteID,
@ -213,7 +213,7 @@ func Test_createMissingPerformer(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := createMissingPerformer(tt.args.endpoint, repo, tt.args.p)
got, err := createMissingPerformer(testCtx, tt.args.endpoint, &mockPerformerReaderWriter, tt.args.p)
if (err != nil) != tt.wantErr {
t.Errorf("createMissingPerformer() error = %v, wantErr %v", err, tt.wantErr)
return

View File

@ -9,19 +9,35 @@ import (
"time"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/sliceutil"
"github.com/stashapp/stash/pkg/sliceutil/intslice"
"github.com/stashapp/stash/pkg/utils"
)
type sceneRelationships struct {
repo models.Repository
scene *models.Scene
result *scrapeResult
fieldOptions map[string]*FieldOptions
type SceneReaderUpdater interface {
GetPerformerIDs(ctx context.Context, sceneID int) ([]int, error)
GetTagIDs(ctx context.Context, sceneID int) ([]int, error)
GetStashIDs(ctx context.Context, sceneID int) ([]*models.StashID, error)
GetCover(ctx context.Context, sceneID int) ([]byte, error)
scene.Updater
}
func (g sceneRelationships) studio() (*int64, error) {
type TagCreator interface {
Create(ctx context.Context, newTag models.Tag) (*models.Tag, error)
}
type sceneRelationships struct {
sceneReader SceneReaderUpdater
studioCreator StudioCreator
performerCreator PerformerCreator
tagCreator TagCreator
scene *models.Scene
result *scrapeResult
fieldOptions map[string]*FieldOptions
}
func (g sceneRelationships) studio(ctx context.Context) (*int64, error) {
existingID := g.scene.StudioID
fieldStrategy := g.fieldOptions["studio"]
createMissing := fieldStrategy != nil && utils.IsTrue(fieldStrategy.CreateMissing)
@ -45,13 +61,13 @@ func (g sceneRelationships) studio() (*int64, error) {
return &studioID, nil
}
} else if createMissing {
return createMissingStudio(endpoint, g.repo, scraped)
return createMissingStudio(ctx, endpoint, g.studioCreator, scraped)
}
return nil, nil
}
func (g sceneRelationships) performers(ignoreMale bool) ([]int, error) {
func (g sceneRelationships) performers(ctx context.Context, ignoreMale bool) ([]int, error) {
fieldStrategy := g.fieldOptions["performers"]
scraped := g.result.result.Performers
@ -66,11 +82,10 @@ func (g sceneRelationships) performers(ignoreMale bool) ([]int, error) {
strategy = fieldStrategy.Strategy
}
repo := g.repo
endpoint := g.result.source.RemoteSite
var performerIDs []int
originalPerformerIDs, err := repo.Scene().GetPerformerIDs(g.scene.ID)
originalPerformerIDs, err := g.sceneReader.GetPerformerIDs(ctx, g.scene.ID)
if err != nil {
return nil, fmt.Errorf("error getting scene performers: %w", err)
}
@ -85,7 +100,7 @@ func (g sceneRelationships) performers(ignoreMale bool) ([]int, error) {
continue
}
performerID, err := getPerformerID(endpoint, repo, p, createMissing)
performerID, err := getPerformerID(ctx, endpoint, g.performerCreator, p, createMissing)
if err != nil {
return nil, err
}
@ -103,11 +118,10 @@ func (g sceneRelationships) performers(ignoreMale bool) ([]int, error) {
return performerIDs, nil
}
func (g sceneRelationships) tags() ([]int, error) {
func (g sceneRelationships) tags(ctx context.Context) ([]int, error) {
fieldStrategy := g.fieldOptions["tags"]
scraped := g.result.result.Tags
target := g.scene
r := g.repo
// just check if ignored
if len(scraped) == 0 || !shouldSetSingleValueField(fieldStrategy, false) {
@ -121,7 +135,7 @@ func (g sceneRelationships) tags() ([]int, error) {
}
var tagIDs []int
originalTagIDs, err := r.Scene().GetTagIDs(target.ID)
originalTagIDs, err := g.sceneReader.GetTagIDs(ctx, target.ID)
if err != nil {
return nil, fmt.Errorf("error getting scene tags: %w", err)
}
@ -142,7 +156,7 @@ func (g sceneRelationships) tags() ([]int, error) {
tagIDs = intslice.IntAppendUnique(tagIDs, int(tagID))
} else if createMissing {
now := time.Now()
created, err := r.Tag().Create(models.Tag{
created, err := g.tagCreator.Create(ctx, models.Tag{
Name: t.Name,
CreatedAt: models.SQLiteTimestamp{Timestamp: now},
UpdatedAt: models.SQLiteTimestamp{Timestamp: now},
@ -163,11 +177,10 @@ func (g sceneRelationships) tags() ([]int, error) {
return tagIDs, nil
}
func (g sceneRelationships) stashIDs() ([]models.StashID, error) {
func (g sceneRelationships) stashIDs(ctx context.Context) ([]models.StashID, error) {
remoteSiteID := g.result.result.RemoteSiteID
fieldStrategy := g.fieldOptions["stash_ids"]
target := g.scene
r := g.repo
endpoint := g.result.source.RemoteSite
@ -183,7 +196,7 @@ func (g sceneRelationships) stashIDs() ([]models.StashID, error) {
var originalStashIDs []models.StashID
var stashIDs []models.StashID
stashIDPtrs, err := r.Scene().GetStashIDs(target.ID)
stashIDPtrs, err := g.sceneReader.GetStashIDs(ctx, target.ID)
if err != nil {
return nil, fmt.Errorf("error getting scene tag: %w", err)
}
@ -227,14 +240,13 @@ func (g sceneRelationships) stashIDs() ([]models.StashID, error) {
func (g sceneRelationships) cover(ctx context.Context) ([]byte, error) {
scraped := g.result.result.Image
r := g.repo
if scraped == nil {
return nil, nil
}
// always overwrite if present
existingCover, err := r.Scene().GetCover(g.scene.ID)
existingCover, err := g.sceneReader.GetCover(ctx, g.scene.ID)
if err != nil {
return nil, fmt.Errorf("error getting scene cover: %w", err)
}

View File

@ -24,14 +24,14 @@ func Test_sceneRelationships_studio(t *testing.T) {
Strategy: FieldStrategyMerge,
}
repo := mocks.NewTransactionManager()
repo.StudioMock().On("Create", mock.Anything).Return(&models.Studio{
mockStudioReaderWriter := &mocks.StudioReaderWriter{}
mockStudioReaderWriter.On("Create", testCtx, mock.Anything).Return(&models.Studio{
ID: int(validStoredIDInt),
}, nil)
tr := sceneRelationships{
repo: repo,
fieldOptions: make(map[string]*FieldOptions),
studioCreator: mockStudioReaderWriter,
fieldOptions: make(map[string]*FieldOptions),
}
tests := []struct {
@ -124,7 +124,7 @@ func Test_sceneRelationships_studio(t *testing.T) {
},
}
got, err := tr.studio()
got, err := tr.studio(testCtx)
if (err != nil) != tt.wantErr {
t.Errorf("sceneRelationships.studio() error = %v, wantErr %v", err, tt.wantErr)
return
@ -156,13 +156,13 @@ func Test_sceneRelationships_performers(t *testing.T) {
Strategy: FieldStrategyMerge,
}
repo := mocks.NewTransactionManager()
repo.SceneMock().On("GetPerformerIDs", sceneID).Return(nil, nil)
repo.SceneMock().On("GetPerformerIDs", sceneWithPerformerID).Return([]int{existingPerformerID}, nil)
repo.SceneMock().On("GetPerformerIDs", errSceneID).Return(nil, errors.New("error getting IDs"))
mockSceneReaderWriter := &mocks.SceneReaderWriter{}
mockSceneReaderWriter.On("GetPerformerIDs", testCtx, sceneID).Return(nil, nil)
mockSceneReaderWriter.On("GetPerformerIDs", testCtx, sceneWithPerformerID).Return([]int{existingPerformerID}, nil)
mockSceneReaderWriter.On("GetPerformerIDs", testCtx, errSceneID).Return(nil, errors.New("error getting IDs"))
tr := sceneRelationships{
repo: repo,
sceneReader: mockSceneReaderWriter,
fieldOptions: make(map[string]*FieldOptions),
}
@ -316,7 +316,7 @@ func Test_sceneRelationships_performers(t *testing.T) {
},
}
got, err := tr.performers(tt.ignoreMale)
got, err := tr.performers(testCtx, tt.ignoreMale)
if (err != nil) != tt.wantErr {
t.Errorf("sceneRelationships.performers() error = %v, wantErr %v", err, tt.wantErr)
return
@ -347,22 +347,24 @@ func Test_sceneRelationships_tags(t *testing.T) {
Strategy: FieldStrategyMerge,
}
repo := mocks.NewTransactionManager()
repo.SceneMock().On("GetTagIDs", sceneID).Return(nil, nil)
repo.SceneMock().On("GetTagIDs", sceneWithTagID).Return([]int{existingID}, nil)
repo.SceneMock().On("GetTagIDs", errSceneID).Return(nil, errors.New("error getting IDs"))
mockSceneReaderWriter := &mocks.SceneReaderWriter{}
mockTagReaderWriter := &mocks.TagReaderWriter{}
mockSceneReaderWriter.On("GetTagIDs", testCtx, sceneID).Return(nil, nil)
mockSceneReaderWriter.On("GetTagIDs", testCtx, sceneWithTagID).Return([]int{existingID}, nil)
mockSceneReaderWriter.On("GetTagIDs", testCtx, errSceneID).Return(nil, errors.New("error getting IDs"))
repo.TagMock().On("Create", mock.MatchedBy(func(p models.Tag) bool {
mockTagReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Tag) bool {
return p.Name == validName
})).Return(&models.Tag{
ID: validStoredIDInt,
}, nil)
repo.TagMock().On("Create", mock.MatchedBy(func(p models.Tag) bool {
mockTagReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Tag) bool {
return p.Name == invalidName
})).Return(nil, errors.New("error creating tag"))
tr := sceneRelationships{
repo: repo,
sceneReader: mockSceneReaderWriter,
tagCreator: mockTagReaderWriter,
fieldOptions: make(map[string]*FieldOptions),
}
@ -505,7 +507,7 @@ func Test_sceneRelationships_tags(t *testing.T) {
},
}
got, err := tr.tags()
got, err := tr.tags(testCtx)
if (err != nil) != tt.wantErr {
t.Errorf("sceneRelationships.tags() error = %v, wantErr %v", err, tt.wantErr)
return
@ -534,18 +536,18 @@ func Test_sceneRelationships_stashIDs(t *testing.T) {
Strategy: FieldStrategyMerge,
}
repo := mocks.NewTransactionManager()
repo.SceneMock().On("GetStashIDs", sceneID).Return(nil, nil)
repo.SceneMock().On("GetStashIDs", sceneWithStashID).Return([]*models.StashID{
mockSceneReaderWriter := &mocks.SceneReaderWriter{}
mockSceneReaderWriter.On("GetStashIDs", testCtx, sceneID).Return(nil, nil)
mockSceneReaderWriter.On("GetStashIDs", testCtx, sceneWithStashID).Return([]*models.StashID{
{
StashID: remoteSiteID,
Endpoint: existingEndpoint,
},
}, nil)
repo.SceneMock().On("GetStashIDs", errSceneID).Return(nil, errors.New("error getting IDs"))
mockSceneReaderWriter.On("GetStashIDs", testCtx, errSceneID).Return(nil, errors.New("error getting IDs"))
tr := sceneRelationships{
repo: repo,
sceneReader: mockSceneReaderWriter,
fieldOptions: make(map[string]*FieldOptions),
}
@ -680,7 +682,7 @@ func Test_sceneRelationships_stashIDs(t *testing.T) {
},
}
got, err := tr.stashIDs()
got, err := tr.stashIDs(testCtx)
if (err != nil) != tt.wantErr {
t.Errorf("sceneRelationships.stashIDs() error = %v, wantErr %v", err, tt.wantErr)
return
@ -707,12 +709,12 @@ func Test_sceneRelationships_cover(t *testing.T) {
newDataEncoded := base64Prefix + utils.GetBase64StringFromData(newData)
invalidData := newDataEncoded + "!!!"
repo := mocks.NewTransactionManager()
repo.SceneMock().On("GetCover", sceneID).Return(existingData, nil)
repo.SceneMock().On("GetCover", errSceneID).Return(nil, errors.New("error getting cover"))
mockSceneReaderWriter := &mocks.SceneReaderWriter{}
mockSceneReaderWriter.On("GetCover", testCtx, sceneID).Return(existingData, nil)
mockSceneReaderWriter.On("GetCover", testCtx, errSceneID).Return(nil, errors.New("error getting cover"))
tr := sceneRelationships{
repo: repo,
sceneReader: mockSceneReaderWriter,
fieldOptions: make(map[string]*FieldOptions),
}

View File

@ -1,6 +1,7 @@
package identify
import (
"context"
"database/sql"
"fmt"
"time"
@ -9,14 +10,19 @@ import (
"github.com/stashapp/stash/pkg/models"
)
func createMissingStudio(endpoint string, repo models.Repository, studio *models.ScrapedStudio) (*int64, error) {
created, err := repo.Studio().Create(scrapedToStudioInput(studio))
type StudioCreator interface {
Create(ctx context.Context, newStudio models.Studio) (*models.Studio, error)
UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error
}
func createMissingStudio(ctx context.Context, endpoint string, w StudioCreator, studio *models.ScrapedStudio) (*int64, error) {
created, err := w.Create(ctx, scrapedToStudioInput(studio))
if err != nil {
return nil, fmt.Errorf("error creating studio: %w", err)
}
if endpoint != "" && studio.RemoteSiteID != nil {
if err := repo.Studio().UpdateStashIDs(created.ID, []models.StashID{
if err := w.UpdateStashIDs(ctx, created.ID, []models.StashID{
{
Endpoint: endpoint,
StashID: *studio.RemoteSiteID,

View File

@ -20,23 +20,24 @@ func Test_createMissingStudio(t *testing.T) {
createdID := 1
createdID64 := int64(createdID)
repo := mocks.NewTransactionManager()
repo.StudioMock().On("Create", mock.MatchedBy(func(p models.Studio) bool {
repo := mocks.NewTxnRepository()
mockStudioReaderWriter := repo.Studio.(*mocks.StudioReaderWriter)
mockStudioReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Studio) bool {
return p.Name.String == validName
})).Return(&models.Studio{
ID: createdID,
}, nil)
repo.StudioMock().On("Create", mock.MatchedBy(func(p models.Studio) bool {
mockStudioReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Studio) bool {
return p.Name.String == invalidName
})).Return(nil, errors.New("error creating performer"))
repo.StudioMock().On("UpdateStashIDs", createdID, []models.StashID{
mockStudioReaderWriter.On("UpdateStashIDs", testCtx, createdID, []models.StashID{
{
Endpoint: invalidEndpoint,
StashID: remoteSiteID,
},
}).Return(errors.New("error updating stash ids"))
repo.StudioMock().On("UpdateStashIDs", createdID, []models.StashID{
mockStudioReaderWriter.On("UpdateStashIDs", testCtx, createdID, []models.StashID{
{
Endpoint: validEndpoint,
StashID: remoteSiteID,
@ -102,7 +103,7 @@ func Test_createMissingStudio(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := createMissingStudio(tt.args.endpoint, repo, tt.args.studio)
got, err := createMissingStudio(testCtx, tt.args.endpoint, mockStudioReaderWriter, tt.args.studio)
if (err != nil) != tt.wantErr {
t.Errorf("createMissingStudio() error = %v, wantErr %v", err, tt.wantErr)
return

View File

@ -7,16 +7,21 @@ import (
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
)
func setInitialMD5Config(ctx context.Context, txnManager models.TransactionManager) {
type SceneCounter interface {
Count(ctx context.Context) (int, error)
}
func setInitialMD5Config(ctx context.Context, txnManager txn.Manager, counter SceneCounter) {
// if there are no scene files in the database, then default the
// VideoFileNamingAlgorithm config setting to oshash and calculateMD5 to
// false, otherwise set them to true for backwards compatibility purposes
var count int
if err := txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
if err := txn.WithTxn(ctx, txnManager, func(ctx context.Context) error {
var err error
count, err = r.Scene().Count()
count, err = counter.Count(ctx)
return err
}); err != nil {
logger.Errorf("Error while counting scenes: %s", err.Error())
@ -36,6 +41,11 @@ func setInitialMD5Config(ctx context.Context, txnManager models.TransactionManag
}
}
type SceneMissingHashCounter interface {
CountMissingChecksum(ctx context.Context) (int, error)
CountMissingOSHash(ctx context.Context) (int, error)
}
// ValidateVideoFileNamingAlgorithm validates changing the
// VideoFileNamingAlgorithm configuration flag.
//
@ -44,30 +54,27 @@ func setInitialMD5Config(ctx context.Context, txnManager models.TransactionManag
//
// Likewise, if VideoFileNamingAlgorithm is set to oshash, then this function
// will ensure that all oshash values are set on all scenes.
func ValidateVideoFileNamingAlgorithm(txnManager models.TransactionManager, newValue models.HashAlgorithm) error {
func ValidateVideoFileNamingAlgorithm(ctx context.Context, qb SceneMissingHashCounter, newValue models.HashAlgorithm) error {
// if algorithm is being set to MD5, then all checksums must be present
return txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
qb := r.Scene()
if newValue == models.HashAlgorithmMd5 {
missingMD5, err := qb.CountMissingChecksum()
if err != nil {
return err
}
if missingMD5 > 0 {
return errors.New("some checksums are missing on scenes. Run Scan with calculateMD5 set to true")
}
} else if newValue == models.HashAlgorithmOshash {
missingOSHash, err := qb.CountMissingOSHash()
if err != nil {
return err
}
if missingOSHash > 0 {
return errors.New("some oshash values are missing on scenes. Run Scan to populate")
}
if newValue == models.HashAlgorithmMd5 {
missingMD5, err := qb.CountMissingChecksum(ctx)
if err != nil {
return err
}
return nil
})
if missingMD5 > 0 {
return errors.New("some checksums are missing on scenes. Run Scan with calculateMD5 set to true")
}
} else if newValue == models.HashAlgorithmOshash {
missingOSHash, err := qb.CountMissingOSHash(ctx)
if err != nil {
return err
}
if missingOSHash > 0 {
return errors.New("some oshash values are missing on scenes. Run Scan to populate")
}
}
return nil
}

View File

@ -1,6 +1,7 @@
package manager
import (
"context"
"database/sql"
"errors"
"path/filepath"
@ -470,7 +471,15 @@ func (p *SceneFilenameParser) initWhiteSpaceRegex() {
}
}
func (p *SceneFilenameParser) Parse(repo models.ReaderRepository) ([]*SceneParserResult, int, error) {
type SceneFilenameParserRepository struct {
Scene scene.Queryer
Performer PerformerNamesFinder
Studio studio.Queryer
Movie MovieNameFinder
Tag tag.Queryer
}
func (p *SceneFilenameParser) Parse(ctx context.Context, repo SceneFilenameParserRepository) ([]*SceneParserResult, int, error) {
// perform the query to find the scenes
mapper, err := newParseMapper(p.Pattern, p.ParserInput.IgnoreWords)
@ -492,17 +501,17 @@ func (p *SceneFilenameParser) Parse(repo models.ReaderRepository) ([]*SceneParse
p.Filter.Q = nil
scenes, total, err := scene.QueryWithCount(repo.Scene(), sceneFilter, p.Filter)
scenes, total, err := scene.QueryWithCount(ctx, repo.Scene, sceneFilter, p.Filter)
if err != nil {
return nil, 0, err
}
ret := p.parseScenes(repo, scenes, mapper)
ret := p.parseScenes(ctx, repo, scenes, mapper)
return ret, total, nil
}
func (p *SceneFilenameParser) parseScenes(repo models.ReaderRepository, scenes []*models.Scene, mapper *parseMapper) []*SceneParserResult {
func (p *SceneFilenameParser) parseScenes(ctx context.Context, repo SceneFilenameParserRepository, scenes []*models.Scene, mapper *parseMapper) []*SceneParserResult {
var ret []*SceneParserResult
for _, scene := range scenes {
sceneHolder := mapper.parse(scene)
@ -511,7 +520,7 @@ func (p *SceneFilenameParser) parseScenes(repo models.ReaderRepository, scenes [
r := &SceneParserResult{
Scene: scene,
}
p.setParserResult(repo, *sceneHolder, r)
p.setParserResult(ctx, repo, *sceneHolder, r)
ret = append(ret, r)
}
@ -530,7 +539,11 @@ func (p SceneFilenameParser) replaceWhitespaceCharacters(value string) string {
return value
}
func (p *SceneFilenameParser) queryPerformer(qb models.PerformerReader, performerName string) *models.Performer {
type PerformerNamesFinder interface {
FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Performer, error)
}
func (p *SceneFilenameParser) queryPerformer(ctx context.Context, qb PerformerNamesFinder, performerName string) *models.Performer {
// massage the performer name
performerName = delimiterRE.ReplaceAllString(performerName, " ")
@ -540,7 +553,7 @@ func (p *SceneFilenameParser) queryPerformer(qb models.PerformerReader, performe
}
// perform an exact match and grab the first
performers, _ := qb.FindByNames([]string{performerName}, true)
performers, _ := qb.FindByNames(ctx, []string{performerName}, true)
var ret *models.Performer
if len(performers) > 0 {
@ -553,7 +566,7 @@ func (p *SceneFilenameParser) queryPerformer(qb models.PerformerReader, performe
return ret
}
func (p *SceneFilenameParser) queryStudio(qb models.StudioReader, studioName string) *models.Studio {
func (p *SceneFilenameParser) queryStudio(ctx context.Context, qb studio.Queryer, studioName string) *models.Studio {
// massage the performer name
studioName = delimiterRE.ReplaceAllString(studioName, " ")
@ -562,11 +575,11 @@ func (p *SceneFilenameParser) queryStudio(qb models.StudioReader, studioName str
return ret
}
ret, _ := studio.ByName(qb, studioName)
ret, _ := studio.ByName(ctx, qb, studioName)
// try to match on alias
if ret == nil {
ret, _ = studio.ByAlias(qb, studioName)
ret, _ = studio.ByAlias(ctx, qb, studioName)
}
// add result to cache
@ -575,7 +588,11 @@ func (p *SceneFilenameParser) queryStudio(qb models.StudioReader, studioName str
return ret
}
func (p *SceneFilenameParser) queryMovie(qb models.MovieReader, movieName string) *models.Movie {
type MovieNameFinder interface {
FindByName(ctx context.Context, name string, nocase bool) (*models.Movie, error)
}
func (p *SceneFilenameParser) queryMovie(ctx context.Context, qb MovieNameFinder, movieName string) *models.Movie {
// massage the movie name
movieName = delimiterRE.ReplaceAllString(movieName, " ")
@ -584,7 +601,7 @@ func (p *SceneFilenameParser) queryMovie(qb models.MovieReader, movieName string
return ret
}
ret, _ := qb.FindByName(movieName, true)
ret, _ := qb.FindByName(ctx, movieName, true)
// add result to cache
p.movieCache[movieName] = ret
@ -592,7 +609,7 @@ func (p *SceneFilenameParser) queryMovie(qb models.MovieReader, movieName string
return ret
}
func (p *SceneFilenameParser) queryTag(qb models.TagReader, tagName string) *models.Tag {
func (p *SceneFilenameParser) queryTag(ctx context.Context, qb tag.Queryer, tagName string) *models.Tag {
// massage the tag name
tagName = delimiterRE.ReplaceAllString(tagName, " ")
@ -602,11 +619,11 @@ func (p *SceneFilenameParser) queryTag(qb models.TagReader, tagName string) *mod
}
// match tag name exactly
ret, _ := tag.ByName(qb, tagName)
ret, _ := tag.ByName(ctx, qb, tagName)
// try to match on alias
if ret == nil {
ret, _ = tag.ByAlias(qb, tagName)
ret, _ = tag.ByAlias(ctx, qb, tagName)
}
// add result to cache
@ -615,12 +632,12 @@ func (p *SceneFilenameParser) queryTag(qb models.TagReader, tagName string) *mod
return ret
}
func (p *SceneFilenameParser) setPerformers(qb models.PerformerReader, h sceneHolder, result *SceneParserResult) {
func (p *SceneFilenameParser) setPerformers(ctx context.Context, qb PerformerNamesFinder, h sceneHolder, result *SceneParserResult) {
// query for each performer
performersSet := make(map[int]bool)
for _, performerName := range h.performers {
if performerName != "" {
performer := p.queryPerformer(qb, performerName)
performer := p.queryPerformer(ctx, qb, performerName)
if performer != nil {
if _, found := performersSet[performer.ID]; !found {
result.PerformerIds = append(result.PerformerIds, strconv.Itoa(performer.ID))
@ -631,12 +648,12 @@ func (p *SceneFilenameParser) setPerformers(qb models.PerformerReader, h sceneHo
}
}
func (p *SceneFilenameParser) setTags(qb models.TagReader, h sceneHolder, result *SceneParserResult) {
func (p *SceneFilenameParser) setTags(ctx context.Context, qb tag.Queryer, h sceneHolder, result *SceneParserResult) {
// query for each performer
tagsSet := make(map[int]bool)
for _, tagName := range h.tags {
if tagName != "" {
tag := p.queryTag(qb, tagName)
tag := p.queryTag(ctx, qb, tagName)
if tag != nil {
if _, found := tagsSet[tag.ID]; !found {
result.TagIds = append(result.TagIds, strconv.Itoa(tag.ID))
@ -647,10 +664,10 @@ func (p *SceneFilenameParser) setTags(qb models.TagReader, h sceneHolder, result
}
}
func (p *SceneFilenameParser) setStudio(qb models.StudioReader, h sceneHolder, result *SceneParserResult) {
func (p *SceneFilenameParser) setStudio(ctx context.Context, qb studio.Queryer, h sceneHolder, result *SceneParserResult) {
// query for each performer
if h.studio != "" {
studio := p.queryStudio(qb, h.studio)
studio := p.queryStudio(ctx, qb, h.studio)
if studio != nil {
studioID := strconv.Itoa(studio.ID)
result.StudioID = &studioID
@ -658,12 +675,12 @@ func (p *SceneFilenameParser) setStudio(qb models.StudioReader, h sceneHolder, r
}
}
func (p *SceneFilenameParser) setMovies(qb models.MovieReader, h sceneHolder, result *SceneParserResult) {
func (p *SceneFilenameParser) setMovies(ctx context.Context, qb MovieNameFinder, h sceneHolder, result *SceneParserResult) {
// query for each movie
moviesSet := make(map[int]bool)
for _, movieName := range h.movies {
if movieName != "" {
movie := p.queryMovie(qb, movieName)
movie := p.queryMovie(ctx, qb, movieName)
if movie != nil {
if _, found := moviesSet[movie.ID]; !found {
result.Movies = append(result.Movies, &SceneMovieID{
@ -676,7 +693,7 @@ func (p *SceneFilenameParser) setMovies(qb models.MovieReader, h sceneHolder, re
}
}
func (p *SceneFilenameParser) setParserResult(repo models.ReaderRepository, h sceneHolder, result *SceneParserResult) {
func (p *SceneFilenameParser) setParserResult(ctx context.Context, repo SceneFilenameParserRepository, h sceneHolder, result *SceneParserResult) {
if h.result.Title.Valid {
title := h.result.Title.String
title = p.replaceWhitespaceCharacters(title)
@ -698,15 +715,15 @@ func (p *SceneFilenameParser) setParserResult(repo models.ReaderRepository, h sc
}
if len(h.performers) > 0 {
p.setPerformers(repo.Performer(), h, result)
p.setPerformers(ctx, repo.Performer, h, result)
}
if len(h.tags) > 0 {
p.setTags(repo.Tag(), h, result)
p.setTags(ctx, repo.Tag, h, result)
}
p.setStudio(repo.Studio(), h, result)
p.setStudio(ctx, repo.Studio, h, result)
if len(h.movies) > 0 {
p.setMovies(repo.Movie(), h, result)
p.setMovies(ctx, repo.Movie, h, result)
}
}

View File

@ -1,6 +1,7 @@
package manager
import (
"context"
"fmt"
"io"
"strconv"
@ -52,22 +53,22 @@ func (e ImportDuplicateEnum) MarshalGQL(w io.Writer) {
}
type importer interface {
PreImport() error
PostImport(id int) error
PreImport(ctx context.Context) error
PostImport(ctx context.Context, id int) error
Name() string
FindExistingID() (*int, error)
Create() (*int, error)
Update(id int) error
FindExistingID(ctx context.Context) (*int, error)
Create(ctx context.Context) (*int, error)
Update(ctx context.Context, id int) error
}
func performImport(i importer, duplicateBehaviour ImportDuplicateEnum) error {
if err := i.PreImport(); err != nil {
func performImport(ctx context.Context, i importer, duplicateBehaviour ImportDuplicateEnum) error {
if err := i.PreImport(ctx); err != nil {
return err
}
// try to find an existing object with the same name
name := i.Name()
existing, err := i.FindExistingID()
existing, err := i.FindExistingID(ctx)
if err != nil {
return fmt.Errorf("error finding existing objects: %v", err)
}
@ -84,12 +85,12 @@ func performImport(i importer, duplicateBehaviour ImportDuplicateEnum) error {
// must be overwriting
id = *existing
if err := i.Update(id); err != nil {
if err := i.Update(ctx, id); err != nil {
return fmt.Errorf("error updating existing object: %v", err)
}
} else {
// creating
createdID, err := i.Create()
createdID, err := i.Create(ctx)
if err != nil {
return fmt.Errorf("error creating object: %v", err)
}
@ -97,7 +98,7 @@ func performImport(i importer, duplicateBehaviour ImportDuplicateEnum) error {
id = *createdID
}
if err := i.PostImport(id); err != nil {
if err := i.PostImport(ctx, id); err != nil {
return err
}

View File

@ -17,7 +17,6 @@ import (
"github.com/stashapp/stash/internal/dlna"
"github.com/stashapp/stash/internal/log"
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/ffmpeg"
"github.com/stashapp/stash/pkg/fsutil"
"github.com/stashapp/stash/pkg/job"
@ -115,7 +114,8 @@ type Manager struct {
DLNAService *dlna.Service
TxnManager models.TransactionManager
Database *sqlite.Database
Repository models.Repository
scanSubs *subscriptionManager
}
@ -150,6 +150,8 @@ func initialize() error {
l := initLog()
initProfiling(cfg.GetCPUProfilePath())
db := &sqlite.Database{}
instance = &Manager{
Config: cfg,
Logger: l,
@ -157,7 +159,20 @@ func initialize() error {
DownloadStore: NewDownloadStore(),
PluginCache: plugin.NewCache(cfg),
TxnManager: sqlite.NewTransactionManager(),
Database: db,
Repository: models.Repository{
TxnManager: db,
Gallery: sqlite.GalleryReaderWriter,
Image: sqlite.ImageReaderWriter,
Movie: sqlite.MovieReaderWriter,
Performer: sqlite.PerformerReaderWriter,
Scene: sqlite.SceneReaderWriter,
SceneMarker: sqlite.SceneMarkerReaderWriter,
ScrapedItem: sqlite.ScrapedItemReaderWriter,
Studio: sqlite.StudioReaderWriter,
Tag: sqlite.TagReaderWriter,
SavedFilter: sqlite.SavedFilterReaderWriter,
},
scanSubs: &subscriptionManager{},
}
@ -165,9 +180,17 @@ func initialize() error {
instance.JobManager = initJobManager()
sceneServer := SceneServer{
TXNManager: instance.TxnManager,
TxnManager: instance.Repository,
SceneCoverGetter: instance.Repository.Scene,
}
instance.DLNAService = dlna.NewService(instance.TxnManager, instance.Config, &sceneServer)
instance.DLNAService = dlna.NewService(instance.Repository, dlna.Repository{
SceneFinder: instance.Repository.Scene,
StudioFinder: instance.Repository.Studio,
TagFinder: instance.Repository.Tag,
PerformerFinder: instance.Repository.Performer,
MovieFinder: instance.Repository.Movie,
}, instance.Config, &sceneServer)
if !cfg.IsNewSystem() {
logger.Infof("using config file: %s", cfg.GetConfigFile())
@ -177,9 +200,14 @@ func initialize() error {
}
if err != nil {
return fmt.Errorf("error initializing configuration: %w", err)
panic(fmt.Sprintf("error initializing configuration: %s", err.Error()))
} else if err := instance.PostInit(ctx); err != nil {
return err
var migrationNeededErr *sqlite.MigrationNeededError
if errors.As(err, &migrationNeededErr) {
logger.Warn(err.Error())
} else {
panic(err)
}
}
initSecurity(cfg)
@ -352,7 +380,8 @@ func (s *Manager) PostInit(ctx context.Context) error {
})
}
if err := database.Initialize(s.Config.GetDatabasePath()); err != nil {
database := s.Database
if err := database.Open(s.Config.GetDatabasePath()); err != nil {
return err
}
@ -377,7 +406,14 @@ func writeStashIcon() {
// initScraperCache initializes a new scraper cache and returns it.
func (s *Manager) initScraperCache() *scraper.Cache {
ret, err := scraper.NewCache(config.GetInstance(), s.TxnManager)
ret, err := scraper.NewCache(config.GetInstance(), s.Repository, scraper.Repository{
SceneFinder: s.Repository.Scene,
GalleryFinder: s.Repository.Gallery,
TagFinder: s.Repository.Tag,
PerformerFinder: s.Repository.Performer,
MovieFinder: s.Repository.Movie,
StudioFinder: s.Repository.Studio,
})
if err != nil {
logger.Errorf("Error reading scraper configs: %s", err.Error())
@ -476,7 +512,12 @@ func (s *Manager) Setup(ctx context.Context, input SetupInput) error {
// initialise the database
if err := s.PostInit(ctx); err != nil {
return fmt.Errorf("error initializing the database: %v", err)
var migrationNeededErr *sqlite.MigrationNeededError
if errors.As(err, &migrationNeededErr) {
logger.Warn(err.Error())
} else {
return fmt.Errorf("error initializing the database: %v", err)
}
}
s.Config.FinalizeSetup()
@ -501,6 +542,8 @@ type MigrateInput struct {
}
func (s *Manager) Migrate(ctx context.Context, input MigrateInput) error {
database := s.Database
// always backup so that we can roll back to the previous version if
// migration fails
backupPath := input.BackupPath
@ -509,7 +552,7 @@ func (s *Manager) Migrate(ctx context.Context, input MigrateInput) error {
}
// perform database backup
if err := database.Backup(database.DB, backupPath); err != nil {
if err := database.Backup(backupPath); err != nil {
return fmt.Errorf("error backing up database: %s", err)
}
@ -541,6 +584,7 @@ func (s *Manager) Migrate(ctx context.Context, input MigrateInput) error {
}
func (s *Manager) GetSystemStatus() *SystemStatus {
database := s.Database
status := SystemStatusEnumOk
dbSchema := int(database.Version())
dbPath := database.DatabasePath()
@ -569,7 +613,7 @@ func (s *Manager) Shutdown(code int) {
// TODO: Each part of the manager needs to gracefully stop at some point
// for now, we just close the database.
err := database.Close()
err := s.Database.Close()
if err != nil {
logger.Errorf("Error closing database: %s", err)
if code == 0 {

View File

@ -84,7 +84,7 @@ func (s *Manager) Scan(ctx context.Context, input ScanMetadataInput) (int, error
}
scanJob := ScanJob{
txnManager: s.TxnManager,
txnManager: s.Repository,
input: input,
subscriptions: s.scanSubs,
}
@ -101,7 +101,7 @@ func (s *Manager) Import(ctx context.Context) (int, error) {
j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) {
task := ImportTask{
txnManager: s.TxnManager,
txnManager: s.Repository,
BaseDir: metadataPath,
Reset: true,
DuplicateBehaviour: ImportDuplicateEnumFail,
@ -125,7 +125,7 @@ func (s *Manager) Export(ctx context.Context) (int, error) {
var wg sync.WaitGroup
wg.Add(1)
task := ExportTask{
txnManager: s.TxnManager,
txnManager: s.Repository,
full: true,
fileNamingAlgorithm: config.GetVideoFileNamingAlgorithm(),
}
@ -156,7 +156,7 @@ func (s *Manager) Generate(ctx context.Context, input GenerateMetadataInput) (in
}
j := &GenerateJob{
txnManager: s.TxnManager,
txnManager: s.Repository,
input: input,
}
@ -185,9 +185,9 @@ func (s *Manager) generateScreenshot(ctx context.Context, sceneId string, at *fl
}
var scene *models.Scene
if err := s.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
var err error
scene, err = r.Scene().Find(sceneIdInt)
scene, err = s.Repository.Scene.Find(ctx, sceneIdInt)
return err
}); err != nil || scene == nil {
logger.Errorf("failed to get scene for generate: %s", err.Error())
@ -195,7 +195,7 @@ func (s *Manager) generateScreenshot(ctx context.Context, sceneId string, at *fl
}
task := GenerateScreenshotTask{
txnManager: s.TxnManager,
txnManager: s.Repository,
Scene: *scene,
ScreenshotAt: at,
fileNamingAlgorithm: config.GetInstance().GetVideoFileNamingAlgorithm(),
@ -222,7 +222,7 @@ type AutoTagMetadataInput struct {
func (s *Manager) AutoTag(ctx context.Context, input AutoTagMetadataInput) int {
j := autoTagJob{
txnManager: s.TxnManager,
txnManager: s.Repository,
input: input,
}
@ -237,7 +237,7 @@ type CleanMetadataInput struct {
func (s *Manager) Clean(ctx context.Context, input CleanMetadataInput) int {
j := cleanJob{
txnManager: s.TxnManager,
txnManager: s.Repository,
input: input,
scanSubs: s.scanSubs,
}
@ -251,9 +251,9 @@ func (s *Manager) MigrateHash(ctx context.Context) int {
logger.Infof("Migrating generated files for %s naming hash", fileNamingAlgo.String())
var scenes []*models.Scene
if err := s.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
var err error
scenes, err = r.Scene().All()
scenes, err = s.Repository.Scene.All(ctx)
return err
}); err != nil {
logger.Errorf("failed to fetch list of scenes for migration: %s", err.Error())
@ -327,15 +327,14 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
// This is why we mark this section nolint. In principle, we should look to
// rewrite the section at some point, to avoid the linter warning.
if len(input.PerformerIds) > 0 { //nolint:gocritic
if err := s.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
performerQuery := r.Performer()
if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
performerQuery := s.Repository.Performer
for _, performerID := range input.PerformerIds {
if id, err := strconv.Atoi(performerID); err == nil {
performer, err := performerQuery.Find(id)
performer, err := performerQuery.Find(ctx, id)
if err == nil {
tasks = append(tasks, StashBoxPerformerTagTask{
txnManager: s.TxnManager,
performer: performer,
refresh: input.Refresh,
box: box,
@ -354,7 +353,6 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
for i := range input.PerformerNames {
if len(input.PerformerNames[i]) > 0 {
tasks = append(tasks, StashBoxPerformerTagTask{
txnManager: s.TxnManager,
name: &input.PerformerNames[i],
refresh: input.Refresh,
box: box,
@ -367,14 +365,14 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
// However, this doesn't really help with readability of the current section. Mark it
// as nolint for now. In the future we'd like to rewrite this code by factoring some of
// this into separate functions.
if err := s.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
performerQuery := r.Performer()
if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
performerQuery := s.Repository.Performer
var performers []*models.Performer
var err error
if input.Refresh {
performers, err = performerQuery.FindByStashIDStatus(true, box.Endpoint)
performers, err = performerQuery.FindByStashIDStatus(ctx, true, box.Endpoint)
} else {
performers, err = performerQuery.FindByStashIDStatus(false, box.Endpoint)
performers, err = performerQuery.FindByStashIDStatus(ctx, false, box.Endpoint)
}
if err != nil {
return fmt.Errorf("error querying performers: %v", err)
@ -382,7 +380,6 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
for _, performer := range performers {
tasks = append(tasks, StashBoxPerformerTagTask{
txnManager: s.TxnManager,
performer: performer,
refresh: input.Refresh,
box: box,

View File

@ -4,5 +4,5 @@ import "context"
// PostMigrate is executed after migrations have been executed.
func (s *Manager) PostMigrate(ctx context.Context) {
setInitialMD5Config(ctx, s.TxnManager)
setInitialMD5Config(ctx, s.Repository, s.Repository.Scene)
}

View File

@ -8,6 +8,7 @@ import (
"github.com/stashapp/stash/pkg/fsutil"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
@ -49,8 +50,13 @@ func KillRunningStreams(scene *models.Scene, fileNamingAlgo models.HashAlgorithm
instance.ReadLockManager.Cancel(transcodePath)
}
type SceneCoverGetter interface {
GetCover(ctx context.Context, sceneID int) ([]byte, error)
}
type SceneServer struct {
TXNManager models.TransactionManager
TxnManager txn.Manager
SceneCoverGetter SceneCoverGetter
}
func (s *SceneServer) StreamSceneDirect(scene *models.Scene, w http.ResponseWriter, r *http.Request) {
@ -75,8 +81,8 @@ func (s *SceneServer) ServeScreenshot(scene *models.Scene, w http.ResponseWriter
http.ServeFile(w, r, filepath)
} else {
var cover []byte
err := s.TXNManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
cover, _ = repo.Scene().GetCover(scene.ID)
err := txn.WithTxn(r.Context(), s.TxnManager, func(ctx context.Context) error {
cover, _ = s.SceneCoverGetter.GetCover(ctx, scene.ID)
return nil
})
if err != nil {

View File

@ -1,13 +1,15 @@
package manager
import (
"context"
"errors"
"fmt"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/studio"
)
func ValidateModifyStudio(studio models.StudioPartial, qb models.StudioReader) error {
func ValidateModifyStudio(ctx context.Context, studio models.StudioPartial, qb studio.Finder) error {
if studio.ParentID == nil || !studio.ParentID.Valid {
return nil
}
@ -22,7 +24,7 @@ func ValidateModifyStudio(studio models.StudioPartial, qb models.StudioReader) e
return errors.New("studio cannot be an ancestor of itself")
}
currentStudio, err := qb.Find(int(currentParentID.Int64))
currentStudio, err := qb.Find(ctx, int(currentParentID.Int64))
if err != nil {
return fmt.Errorf("error finding parent studio: %v", err)
}

View File

@ -19,7 +19,7 @@ import (
)
type autoTagJob struct {
txnManager models.TransactionManager
txnManager models.Repository
input AutoTagMetadataInput
cache match.Cache
@ -73,27 +73,28 @@ func (j *autoTagJob) autoTagSpecific(ctx context.Context, progress *job.Progress
studioCount := len(studioIds)
tagCount := len(tagIds)
if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
performerQuery := r.Performer()
studioQuery := r.Studio()
tagQuery := r.Tag()
if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
r := j.txnManager
performerQuery := r.Performer
studioQuery := r.Studio
tagQuery := r.Tag
const wildcard = "*"
var err error
if performerCount == 1 && performerIds[0] == wildcard {
performerCount, err = performerQuery.Count()
performerCount, err = performerQuery.Count(ctx)
if err != nil {
return fmt.Errorf("error getting performer count: %v", err)
}
}
if studioCount == 1 && studioIds[0] == wildcard {
studioCount, err = studioQuery.Count()
studioCount, err = studioQuery.Count(ctx)
if err != nil {
return fmt.Errorf("error getting studio count: %v", err)
}
}
if tagCount == 1 && tagIds[0] == wildcard {
tagCount, err = tagQuery.Count()
tagCount, err = tagQuery.Count(ctx)
if err != nil {
return fmt.Errorf("error getting tag count: %v", err)
}
@ -123,14 +124,14 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre
for _, performerId := range performerIds {
var performers []*models.Performer
if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
performerQuery := r.Performer()
if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
performerQuery := j.txnManager.Performer
ignoreAutoTag := false
perPage := -1
if performerId == "*" {
var err error
performers, _, err = performerQuery.Query(&models.PerformerFilterType{
performers, _, err = performerQuery.Query(ctx, &models.PerformerFilterType{
IgnoreAutoTag: &ignoreAutoTag,
}, &models.FindFilterType{
PerPage: &perPage,
@ -144,7 +145,7 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre
return fmt.Errorf("error parsing performer id %s: %s", performerId, err.Error())
}
performer, err := performerQuery.Find(performerIdInt)
performer, err := performerQuery.Find(ctx, performerIdInt)
if err != nil {
return fmt.Errorf("error finding performer id %s: %s", performerId, err.Error())
}
@ -161,14 +162,15 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre
return nil
}
if err := j.txnManager.WithTxn(ctx, func(r models.Repository) error {
if err := autotag.PerformerScenes(performer, paths, r.Scene(), &j.cache); err != nil {
if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
r := j.txnManager
if err := autotag.PerformerScenes(ctx, performer, paths, r.Scene, &j.cache); err != nil {
return err
}
if err := autotag.PerformerImages(performer, paths, r.Image(), &j.cache); err != nil {
if err := autotag.PerformerImages(ctx, performer, paths, r.Image, &j.cache); err != nil {
return err
}
if err := autotag.PerformerGalleries(performer, paths, r.Gallery(), &j.cache); err != nil {
if err := autotag.PerformerGalleries(ctx, performer, paths, r.Gallery, &j.cache); err != nil {
return err
}
@ -193,16 +195,18 @@ func (j *autoTagJob) autoTagStudios(ctx context.Context, progress *job.Progress,
return
}
r := j.txnManager
for _, studioId := range studioIds {
var studios []*models.Studio
if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
studioQuery := r.Studio()
if err := r.WithTxn(ctx, func(ctx context.Context) error {
studioQuery := r.Studio
ignoreAutoTag := false
perPage := -1
if studioId == "*" {
var err error
studios, _, err = studioQuery.Query(&models.StudioFilterType{
studios, _, err = studioQuery.Query(ctx, &models.StudioFilterType{
IgnoreAutoTag: &ignoreAutoTag,
}, &models.FindFilterType{
PerPage: &perPage,
@ -216,7 +220,7 @@ func (j *autoTagJob) autoTagStudios(ctx context.Context, progress *job.Progress,
return fmt.Errorf("error parsing studio id %s: %s", studioId, err.Error())
}
studio, err := studioQuery.Find(studioIdInt)
studio, err := studioQuery.Find(ctx, studioIdInt)
if err != nil {
return fmt.Errorf("error finding studio id %s: %s", studioId, err.Error())
}
@ -234,19 +238,19 @@ func (j *autoTagJob) autoTagStudios(ctx context.Context, progress *job.Progress,
return nil
}
if err := j.txnManager.WithTxn(ctx, func(r models.Repository) error {
aliases, err := r.Studio().GetAliases(studio.ID)
if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
aliases, err := r.Studio.GetAliases(ctx, studio.ID)
if err != nil {
return err
}
if err := autotag.StudioScenes(studio, paths, aliases, r.Scene(), &j.cache); err != nil {
if err := autotag.StudioScenes(ctx, studio, paths, aliases, r.Scene, &j.cache); err != nil {
return err
}
if err := autotag.StudioImages(studio, paths, aliases, r.Image(), &j.cache); err != nil {
if err := autotag.StudioImages(ctx, studio, paths, aliases, r.Image, &j.cache); err != nil {
return err
}
if err := autotag.StudioGalleries(studio, paths, aliases, r.Gallery(), &j.cache); err != nil {
if err := autotag.StudioGalleries(ctx, studio, paths, aliases, r.Gallery, &j.cache); err != nil {
return err
}
@ -271,15 +275,17 @@ func (j *autoTagJob) autoTagTags(ctx context.Context, progress *job.Progress, pa
return
}
r := j.txnManager
for _, tagId := range tagIds {
var tags []*models.Tag
if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
tagQuery := r.Tag()
if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
tagQuery := r.Tag
ignoreAutoTag := false
perPage := -1
if tagId == "*" {
var err error
tags, _, err = tagQuery.Query(&models.TagFilterType{
tags, _, err = tagQuery.Query(ctx, &models.TagFilterType{
IgnoreAutoTag: &ignoreAutoTag,
}, &models.FindFilterType{
PerPage: &perPage,
@ -293,7 +299,7 @@ func (j *autoTagJob) autoTagTags(ctx context.Context, progress *job.Progress, pa
return fmt.Errorf("error parsing tag id %s: %s", tagId, err.Error())
}
tag, err := tagQuery.Find(tagIdInt)
tag, err := tagQuery.Find(ctx, tagIdInt)
if err != nil {
return fmt.Errorf("error finding tag id %s: %s", tagId, err.Error())
}
@ -306,19 +312,19 @@ func (j *autoTagJob) autoTagTags(ctx context.Context, progress *job.Progress, pa
return nil
}
if err := j.txnManager.WithTxn(ctx, func(r models.Repository) error {
aliases, err := r.Tag().GetAliases(tag.ID)
if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
aliases, err := r.Tag.GetAliases(ctx, tag.ID)
if err != nil {
return err
}
if err := autotag.TagScenes(tag, paths, aliases, r.Scene(), &j.cache); err != nil {
if err := autotag.TagScenes(ctx, tag, paths, aliases, r.Scene, &j.cache); err != nil {
return err
}
if err := autotag.TagImages(tag, paths, aliases, r.Image(), &j.cache); err != nil {
if err := autotag.TagImages(ctx, tag, paths, aliases, r.Image, &j.cache); err != nil {
return err
}
if err := autotag.TagGalleries(tag, paths, aliases, r.Gallery(), &j.cache); err != nil {
if err := autotag.TagGalleries(ctx, tag, paths, aliases, r.Gallery, &j.cache); err != nil {
return err
}
@ -345,7 +351,7 @@ type autoTagFilesTask struct {
tags bool
progress *job.Progress
txnManager models.TransactionManager
txnManager models.Repository
cache *match.Cache
}
@ -425,13 +431,13 @@ func (t *autoTagFilesTask) makeGalleryFilter() *models.GalleryFilterType {
return ret
}
func (t *autoTagFilesTask) getCount(r models.ReaderRepository) (int, error) {
func (t *autoTagFilesTask) getCount(ctx context.Context, r models.Repository) (int, error) {
pp := 0
findFilter := &models.FindFilterType{
PerPage: &pp,
}
sceneResults, err := r.Scene().Query(models.SceneQueryOptions{
sceneResults, err := r.Scene.Query(ctx, models.SceneQueryOptions{
QueryOptions: models.QueryOptions{
FindFilter: findFilter,
Count: true,
@ -444,7 +450,7 @@ func (t *autoTagFilesTask) getCount(r models.ReaderRepository) (int, error) {
sceneCount := sceneResults.Count
imageResults, err := r.Image().Query(models.ImageQueryOptions{
imageResults, err := r.Image.Query(ctx, models.ImageQueryOptions{
QueryOptions: models.QueryOptions{
FindFilter: findFilter,
Count: true,
@ -457,7 +463,7 @@ func (t *autoTagFilesTask) getCount(r models.ReaderRepository) (int, error) {
imageCount := imageResults.Count
_, galleryCount, err := r.Gallery().Query(t.makeGalleryFilter(), findFilter)
_, galleryCount, err := r.Gallery.Query(ctx, t.makeGalleryFilter(), findFilter)
if err != nil {
return 0, err
}
@ -465,7 +471,7 @@ func (t *autoTagFilesTask) getCount(r models.ReaderRepository) (int, error) {
return sceneCount + imageCount + galleryCount, nil
}
func (t *autoTagFilesTask) processScenes(ctx context.Context, r models.ReaderRepository) error {
func (t *autoTagFilesTask) processScenes(ctx context.Context, r models.Repository) error {
if job.IsCancelled(ctx) {
return nil
}
@ -477,7 +483,7 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context, r models.ReaderRep
more := true
for more {
scenes, err := scene.Query(r.Scene(), sceneFilter, findFilter)
scenes, err := scene.Query(ctx, r.Scene, sceneFilter, findFilter)
if err != nil {
return err
}
@ -518,7 +524,7 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context, r models.ReaderRep
return nil
}
func (t *autoTagFilesTask) processImages(ctx context.Context, r models.ReaderRepository) error {
func (t *autoTagFilesTask) processImages(ctx context.Context, r models.Repository) error {
if job.IsCancelled(ctx) {
return nil
}
@ -530,7 +536,7 @@ func (t *autoTagFilesTask) processImages(ctx context.Context, r models.ReaderRep
more := true
for more {
images, err := image.Query(r.Image(), imageFilter, findFilter)
images, err := image.Query(ctx, r.Image, imageFilter, findFilter)
if err != nil {
return err
}
@ -571,7 +577,7 @@ func (t *autoTagFilesTask) processImages(ctx context.Context, r models.ReaderRep
return nil
}
func (t *autoTagFilesTask) processGalleries(ctx context.Context, r models.ReaderRepository) error {
func (t *autoTagFilesTask) processGalleries(ctx context.Context, r models.Repository) error {
if job.IsCancelled(ctx) {
return nil
}
@ -583,7 +589,7 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context, r models.Reader
more := true
for more {
galleries, _, err := r.Gallery().Query(galleryFilter, findFilter)
galleries, _, err := r.Gallery.Query(ctx, galleryFilter, findFilter)
if err != nil {
return err
}
@ -625,8 +631,9 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context, r models.Reader
}
func (t *autoTagFilesTask) process(ctx context.Context) {
if err := t.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
total, err := t.getCount(r)
r := t.txnManager
if err := r.WithTxn(ctx, func(ctx context.Context) error {
total, err := t.getCount(ctx, t.txnManager)
if err != nil {
return err
}
@ -661,7 +668,7 @@ func (t *autoTagFilesTask) process(ctx context.Context) {
}
type autoTagSceneTask struct {
txnManager models.TransactionManager
txnManager models.Repository
scene *models.Scene
performers bool
@ -673,19 +680,20 @@ type autoTagSceneTask struct {
func (t *autoTagSceneTask) Start(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
r := t.txnManager
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
if t.performers {
if err := autotag.ScenePerformers(t.scene, r.Scene(), r.Performer(), t.cache); err != nil {
if err := autotag.ScenePerformers(ctx, t.scene, r.Scene, r.Performer, t.cache); err != nil {
return fmt.Errorf("error tagging scene performers for %s: %v", t.scene.Path, err)
}
}
if t.studios {
if err := autotag.SceneStudios(t.scene, r.Scene(), r.Studio(), t.cache); err != nil {
if err := autotag.SceneStudios(ctx, t.scene, r.Scene, r.Studio, t.cache); err != nil {
return fmt.Errorf("error tagging scene studio for %s: %v", t.scene.Path, err)
}
}
if t.tags {
if err := autotag.SceneTags(t.scene, r.Scene(), r.Tag(), t.cache); err != nil {
if err := autotag.SceneTags(ctx, t.scene, r.Scene, r.Tag, t.cache); err != nil {
return fmt.Errorf("error tagging scene tags for %s: %v", t.scene.Path, err)
}
}
@ -697,7 +705,7 @@ func (t *autoTagSceneTask) Start(ctx context.Context, wg *sync.WaitGroup) {
}
type autoTagImageTask struct {
txnManager models.TransactionManager
txnManager models.Repository
image *models.Image
performers bool
@ -709,19 +717,20 @@ type autoTagImageTask struct {
func (t *autoTagImageTask) Start(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
r := t.txnManager
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
if t.performers {
if err := autotag.ImagePerformers(t.image, r.Image(), r.Performer(), t.cache); err != nil {
if err := autotag.ImagePerformers(ctx, t.image, r.Image, r.Performer, t.cache); err != nil {
return fmt.Errorf("error tagging image performers for %s: %v", t.image.Path, err)
}
}
if t.studios {
if err := autotag.ImageStudios(t.image, r.Image(), r.Studio(), t.cache); err != nil {
if err := autotag.ImageStudios(ctx, t.image, r.Image, r.Studio, t.cache); err != nil {
return fmt.Errorf("error tagging image studio for %s: %v", t.image.Path, err)
}
}
if t.tags {
if err := autotag.ImageTags(t.image, r.Image(), r.Tag(), t.cache); err != nil {
if err := autotag.ImageTags(ctx, t.image, r.Image, r.Tag, t.cache); err != nil {
return fmt.Errorf("error tagging image tags for %s: %v", t.image.Path, err)
}
}
@ -733,7 +742,7 @@ func (t *autoTagImageTask) Start(ctx context.Context, wg *sync.WaitGroup) {
}
type autoTagGalleryTask struct {
txnManager models.TransactionManager
txnManager models.Repository
gallery *models.Gallery
performers bool
@ -745,19 +754,20 @@ type autoTagGalleryTask struct {
func (t *autoTagGalleryTask) Start(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
r := t.txnManager
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
if t.performers {
if err := autotag.GalleryPerformers(t.gallery, r.Gallery(), r.Performer(), t.cache); err != nil {
if err := autotag.GalleryPerformers(ctx, t.gallery, r.Gallery, r.Performer, t.cache); err != nil {
return fmt.Errorf("error tagging gallery performers for %s: %v", t.gallery.Path.String, err)
}
}
if t.studios {
if err := autotag.GalleryStudios(t.gallery, r.Gallery(), r.Studio(), t.cache); err != nil {
if err := autotag.GalleryStudios(ctx, t.gallery, r.Gallery, r.Studio, t.cache); err != nil {
return fmt.Errorf("error tagging gallery studio for %s: %v", t.gallery.Path.String, err)
}
}
if t.tags {
if err := autotag.GalleryTags(t.gallery, r.Gallery(), r.Tag(), t.cache); err != nil {
if err := autotag.GalleryTags(ctx, t.gallery, r.Gallery, r.Tag, t.cache); err != nil {
return fmt.Errorf("error tagging gallery tags for %s: %v", t.gallery.Path.String, err)
}
}

View File

@ -18,7 +18,7 @@ import (
)
type cleanJob struct {
txnManager models.TransactionManager
txnManager models.Repository
input CleanMetadataInput
scanSubs *subscriptionManager
}
@ -29,8 +29,10 @@ func (j *cleanJob) Execute(ctx context.Context, progress *job.Progress) {
logger.Infof("Running in Dry Mode")
}
if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
total, err := j.getCount(r)
r := j.txnManager
if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
total, err := j.getCount(ctx, r)
if err != nil {
return fmt.Errorf("error getting count: %w", err)
}
@ -41,13 +43,13 @@ func (j *cleanJob) Execute(ctx context.Context, progress *job.Progress) {
return nil
}
if err := j.processScenes(ctx, progress, r.Scene()); err != nil {
if err := j.processScenes(ctx, progress, r.Scene); err != nil {
return fmt.Errorf("error cleaning scenes: %w", err)
}
if err := j.processImages(ctx, progress, r.Image()); err != nil {
if err := j.processImages(ctx, progress, r.Image); err != nil {
return fmt.Errorf("error cleaning images: %w", err)
}
if err := j.processGalleries(ctx, progress, r.Gallery(), r.Image()); err != nil {
if err := j.processGalleries(ctx, progress, r.Gallery, r.Image); err != nil {
return fmt.Errorf("error cleaning galleries: %w", err)
}
@ -66,9 +68,9 @@ func (j *cleanJob) Execute(ctx context.Context, progress *job.Progress) {
logger.Info("Finished Cleaning")
}
func (j *cleanJob) getCount(r models.ReaderRepository) (int, error) {
func (j *cleanJob) getCount(ctx context.Context, r models.Repository) (int, error) {
sceneFilter := scene.PathsFilter(j.input.Paths)
sceneResult, err := r.Scene().Query(models.SceneQueryOptions{
sceneResult, err := r.Scene.Query(ctx, models.SceneQueryOptions{
QueryOptions: models.QueryOptions{
Count: true,
},
@ -78,12 +80,12 @@ func (j *cleanJob) getCount(r models.ReaderRepository) (int, error) {
return 0, err
}
imageCount, err := r.Image().QueryCount(image.PathsFilter(j.input.Paths), nil)
imageCount, err := r.Image.QueryCount(ctx, image.PathsFilter(j.input.Paths), nil)
if err != nil {
return 0, err
}
galleryCount, err := r.Gallery().QueryCount(gallery.PathsFilter(j.input.Paths), nil)
galleryCount, err := r.Gallery.QueryCount(ctx, gallery.PathsFilter(j.input.Paths), nil)
if err != nil {
return 0, err
}
@ -91,7 +93,7 @@ func (j *cleanJob) getCount(r models.ReaderRepository) (int, error) {
return sceneResult.Count + imageCount + galleryCount, nil
}
func (j *cleanJob) processScenes(ctx context.Context, progress *job.Progress, qb models.SceneReader) error {
func (j *cleanJob) processScenes(ctx context.Context, progress *job.Progress, qb scene.Queryer) error {
batchSize := 1000
findFilter := models.BatchFindFilter(batchSize)
@ -107,7 +109,7 @@ func (j *cleanJob) processScenes(ctx context.Context, progress *job.Progress, qb
return nil
}
scenes, err := scene.Query(qb, sceneFilter, findFilter)
scenes, err := scene.Query(ctx, qb, sceneFilter, findFilter)
if err != nil {
return fmt.Errorf("error querying for scenes: %w", err)
}
@ -154,7 +156,7 @@ func (j *cleanJob) processScenes(ctx context.Context, progress *job.Progress, qb
return nil
}
func (j *cleanJob) processGalleries(ctx context.Context, progress *job.Progress, qb models.GalleryReader, iqb models.ImageReader) error {
func (j *cleanJob) processGalleries(ctx context.Context, progress *job.Progress, qb gallery.Queryer, iqb models.ImageReader) error {
batchSize := 1000
findFilter := models.BatchFindFilter(batchSize)
@ -170,14 +172,14 @@ func (j *cleanJob) processGalleries(ctx context.Context, progress *job.Progress,
return nil
}
galleries, _, err := qb.Query(galleryFilter, findFilter)
galleries, _, err := qb.Query(ctx, galleryFilter, findFilter)
if err != nil {
return fmt.Errorf("error querying for galleries: %w", err)
}
for _, gallery := range galleries {
progress.ExecuteTask(fmt.Sprintf("Assessing gallery %s for clean", gallery.GetTitle()), func() {
if j.shouldCleanGallery(gallery, iqb) {
if j.shouldCleanGallery(ctx, gallery, iqb) {
toDelete = append(toDelete, gallery.ID)
} else {
// increment progress, no further processing
@ -215,7 +217,7 @@ func (j *cleanJob) processGalleries(ctx context.Context, progress *job.Progress,
return nil
}
func (j *cleanJob) processImages(ctx context.Context, progress *job.Progress, qb models.ImageReader) error {
func (j *cleanJob) processImages(ctx context.Context, progress *job.Progress, qb image.Queryer) error {
batchSize := 1000
findFilter := models.BatchFindFilter(batchSize)
@ -234,7 +236,7 @@ func (j *cleanJob) processImages(ctx context.Context, progress *job.Progress, qb
return nil
}
images, err := image.Query(qb, imageFilter, findFilter)
images, err := image.Query(ctx, qb, imageFilter, findFilter)
if err != nil {
return fmt.Errorf("error querying for images: %w", err)
}
@ -318,7 +320,7 @@ func (j *cleanJob) shouldCleanScene(s *models.Scene) bool {
return false
}
func (j *cleanJob) shouldCleanGallery(g *models.Gallery, qb models.ImageReader) bool {
func (j *cleanJob) shouldCleanGallery(ctx context.Context, g *models.Gallery, qb models.ImageReader) bool {
// never clean manually created galleries
if !g.Path.Valid {
return false
@ -348,7 +350,7 @@ func (j *cleanJob) shouldCleanGallery(g *models.Gallery, qb models.ImageReader)
}
} else {
// folder-based - delete if it has no images
count, err := qb.CountByGalleryID(g.ID)
count, err := qb.CountByGalleryID(ctx, g.ID)
if err != nil {
logger.Warnf("Error trying to count gallery images for %q: %v", path, err)
return false
@ -401,16 +403,17 @@ func (j *cleanJob) deleteScene(ctx context.Context, fileNamingAlgorithm models.H
Paths: GetInstance().Paths,
}
var s *models.Scene
if err := j.txnManager.WithTxn(ctx, func(repo models.Repository) error {
qb := repo.Scene()
if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
repo := j.txnManager
qb := repo.Scene
var err error
s, err = qb.Find(sceneID)
s, err = qb.Find(ctx, sceneID)
if err != nil {
return err
}
return scene.Destroy(s, repo, fileDeleter, true, false)
return scene.Destroy(ctx, s, repo.Scene, repo.SceneMarker, fileDeleter, true, false)
}); err != nil {
fileDeleter.Rollback()
@ -431,16 +434,16 @@ func (j *cleanJob) deleteScene(ctx context.Context, fileNamingAlgorithm models.H
func (j *cleanJob) deleteGallery(ctx context.Context, galleryID int) {
var g *models.Gallery
if err := j.txnManager.WithTxn(ctx, func(repo models.Repository) error {
qb := repo.Gallery()
if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
qb := j.txnManager.Gallery
var err error
g, err = qb.Find(galleryID)
g, err = qb.Find(ctx, galleryID)
if err != nil {
return err
}
return qb.Destroy(galleryID)
return qb.Destroy(ctx, galleryID)
}); err != nil {
logger.Errorf("Error deleting gallery from database: %s", err.Error())
return
@ -459,11 +462,11 @@ func (j *cleanJob) deleteImage(ctx context.Context, imageID int) {
}
var i *models.Image
if err := j.txnManager.WithTxn(ctx, func(repo models.Repository) error {
qb := repo.Image()
if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
qb := j.txnManager.Image
var err error
i, err = qb.Find(imageID)
i, err = qb.Find(ctx, imageID)
if err != nil {
return err
}
@ -472,7 +475,7 @@ func (j *cleanJob) deleteImage(ctx context.Context, imageID int) {
return fmt.Errorf("image not found: %d", imageID)
}
return image.Destroy(i, qb, fileDeleter, true, false)
return image.Destroy(ctx, i, qb, fileDeleter, true, false)
}); err != nil {
fileDeleter.Rollback()

View File

@ -32,7 +32,7 @@ import (
)
type ExportTask struct {
txnManager models.TransactionManager
txnManager models.Repository
full bool
baseDir string
@ -100,7 +100,7 @@ func CreateExportTask(a models.HashAlgorithm, input ExportObjectsInput) *ExportT
}
return &ExportTask{
txnManager: GetInstance().TxnManager,
txnManager: GetInstance().Repository,
fileNamingAlgorithm: a,
scenes: newExportSpec(input.Scenes),
images: newExportSpec(input.Images),
@ -146,30 +146,32 @@ func (t *ExportTask) Start(ctx context.Context, wg *sync.WaitGroup) {
paths.EnsureJSONDirs(t.baseDir)
txnErr := t.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
txnErr := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
r := t.txnManager
// include movie scenes and gallery images
if !t.full {
// only include movie scenes if includeDependencies is also set
if !t.scenes.all && t.includeDependencies {
t.populateMovieScenes(r)
t.populateMovieScenes(ctx, r)
}
// always export gallery images
if !t.images.all {
t.populateGalleryImages(r)
t.populateGalleryImages(ctx, r)
}
}
t.ExportScenes(workerCount, r)
t.ExportImages(workerCount, r)
t.ExportGalleries(workerCount, r)
t.ExportMovies(workerCount, r)
t.ExportPerformers(workerCount, r)
t.ExportStudios(workerCount, r)
t.ExportTags(workerCount, r)
t.ExportScenes(ctx, workerCount, r)
t.ExportImages(ctx, workerCount, r)
t.ExportGalleries(ctx, workerCount, r)
t.ExportMovies(ctx, workerCount, r)
t.ExportPerformers(ctx, workerCount, r)
t.ExportStudios(ctx, workerCount, r)
t.ExportTags(ctx, workerCount, r)
if t.full {
t.ExportScrapedItems(r)
t.ExportScrapedItems(ctx, r)
}
return nil
@ -284,17 +286,17 @@ func (t *ExportTask) zipFile(fn, outDir string, z *zip.Writer) error {
return nil
}
func (t *ExportTask) populateMovieScenes(repo models.ReaderRepository) {
reader := repo.Movie()
sceneReader := repo.Scene()
func (t *ExportTask) populateMovieScenes(ctx context.Context, repo models.Repository) {
reader := repo.Movie
sceneReader := repo.Scene
var movies []*models.Movie
var err error
all := t.full || (t.movies != nil && t.movies.all)
if all {
movies, err = reader.All()
movies, err = reader.All(ctx)
} else if t.movies != nil && len(t.movies.IDs) > 0 {
movies, err = reader.FindMany(t.movies.IDs)
movies, err = reader.FindMany(ctx, t.movies.IDs)
}
if err != nil {
@ -302,7 +304,7 @@ func (t *ExportTask) populateMovieScenes(repo models.ReaderRepository) {
}
for _, m := range movies {
scenes, err := sceneReader.FindByMovieID(m.ID)
scenes, err := sceneReader.FindByMovieID(ctx, m.ID)
if err != nil {
logger.Errorf("[movies] <%s> failed to fetch scenes for movie: %s", m.Checksum, err.Error())
continue
@ -314,17 +316,17 @@ func (t *ExportTask) populateMovieScenes(repo models.ReaderRepository) {
}
}
func (t *ExportTask) populateGalleryImages(repo models.ReaderRepository) {
reader := repo.Gallery()
imageReader := repo.Image()
func (t *ExportTask) populateGalleryImages(ctx context.Context, repo models.Repository) {
reader := repo.Gallery
imageReader := repo.Image
var galleries []*models.Gallery
var err error
all := t.full || (t.galleries != nil && t.galleries.all)
if all {
galleries, err = reader.All()
galleries, err = reader.All(ctx)
} else if t.galleries != nil && len(t.galleries.IDs) > 0 {
galleries, err = reader.FindMany(t.galleries.IDs)
galleries, err = reader.FindMany(ctx, t.galleries.IDs)
}
if err != nil {
@ -332,7 +334,7 @@ func (t *ExportTask) populateGalleryImages(repo models.ReaderRepository) {
}
for _, g := range galleries {
images, err := imageReader.FindByGalleryID(g.ID)
images, err := imageReader.FindByGalleryID(ctx, g.ID)
if err != nil {
logger.Errorf("[galleries] <%s> failed to fetch images for gallery: %s", g.Checksum, err.Error())
continue
@ -344,18 +346,18 @@ func (t *ExportTask) populateGalleryImages(repo models.ReaderRepository) {
}
}
func (t *ExportTask) ExportScenes(workers int, repo models.ReaderRepository) {
func (t *ExportTask) ExportScenes(ctx context.Context, workers int, repo models.Repository) {
var scenesWg sync.WaitGroup
sceneReader := repo.Scene()
sceneReader := repo.Scene
var scenes []*models.Scene
var err error
all := t.full || (t.scenes != nil && t.scenes.all)
if all {
scenes, err = sceneReader.All()
scenes, err = sceneReader.All(ctx)
} else if t.scenes != nil && len(t.scenes.IDs) > 0 {
scenes, err = sceneReader.FindMany(t.scenes.IDs)
scenes, err = sceneReader.FindMany(ctx, t.scenes.IDs)
}
if err != nil {
@ -369,7 +371,7 @@ func (t *ExportTask) ExportScenes(workers int, repo models.ReaderRepository) {
for w := 0; w < workers; w++ { // create export Scene workers
scenesWg.Add(1)
go exportScene(&scenesWg, jobCh, repo, t)
go exportScene(ctx, &scenesWg, jobCh, repo, t)
}
for i, scene := range scenes {
@ -388,32 +390,32 @@ func (t *ExportTask) ExportScenes(workers int, repo models.ReaderRepository) {
logger.Infof("[scenes] export complete in %s. %d workers used.", time.Since(startTime), workers)
}
func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.ReaderRepository, t *ExportTask) {
func exportScene(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.Repository, t *ExportTask) {
defer wg.Done()
sceneReader := repo.Scene()
studioReader := repo.Studio()
movieReader := repo.Movie()
galleryReader := repo.Gallery()
performerReader := repo.Performer()
tagReader := repo.Tag()
sceneMarkerReader := repo.SceneMarker()
sceneReader := repo.Scene
studioReader := repo.Studio
movieReader := repo.Movie
galleryReader := repo.Gallery
performerReader := repo.Performer
tagReader := repo.Tag
sceneMarkerReader := repo.SceneMarker
for s := range jobChan {
sceneHash := s.GetHash(t.fileNamingAlgorithm)
newSceneJSON, err := scene.ToBasicJSON(sceneReader, s)
newSceneJSON, err := scene.ToBasicJSON(ctx, sceneReader, s)
if err != nil {
logger.Errorf("[scenes] <%s> error getting scene JSON: %s", sceneHash, err.Error())
continue
}
newSceneJSON.Studio, err = scene.GetStudioName(studioReader, s)
newSceneJSON.Studio, err = scene.GetStudioName(ctx, studioReader, s)
if err != nil {
logger.Errorf("[scenes] <%s> error getting scene studio name: %s", sceneHash, err.Error())
continue
}
galleries, err := galleryReader.FindBySceneID(s.ID)
galleries, err := galleryReader.FindBySceneID(ctx, s.ID)
if err != nil {
logger.Errorf("[scenes] <%s> error getting scene gallery checksums: %s", sceneHash, err.Error())
continue
@ -421,7 +423,7 @@ func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.R
newSceneJSON.Galleries = gallery.GetChecksums(galleries)
performers, err := performerReader.FindBySceneID(s.ID)
performers, err := performerReader.FindBySceneID(ctx, s.ID)
if err != nil {
logger.Errorf("[scenes] <%s> error getting scene performer names: %s", sceneHash, err.Error())
continue
@ -429,19 +431,19 @@ func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.R
newSceneJSON.Performers = performer.GetNames(performers)
newSceneJSON.Tags, err = scene.GetTagNames(tagReader, s)
newSceneJSON.Tags, err = scene.GetTagNames(ctx, tagReader, s)
if err != nil {
logger.Errorf("[scenes] <%s> error getting scene tag names: %s", sceneHash, err.Error())
continue
}
newSceneJSON.Markers, err = scene.GetSceneMarkersJSON(sceneMarkerReader, tagReader, s)
newSceneJSON.Markers, err = scene.GetSceneMarkersJSON(ctx, sceneMarkerReader, tagReader, s)
if err != nil {
logger.Errorf("[scenes] <%s> error getting scene markers JSON: %s", sceneHash, err.Error())
continue
}
newSceneJSON.Movies, err = scene.GetSceneMoviesJSON(movieReader, sceneReader, s)
newSceneJSON.Movies, err = scene.GetSceneMoviesJSON(ctx, movieReader, sceneReader, s)
if err != nil {
logger.Errorf("[scenes] <%s> error getting scene movies JSON: %s", sceneHash, err.Error())
continue
@ -454,14 +456,14 @@ func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.R
t.galleries.IDs = intslice.IntAppendUniques(t.galleries.IDs, gallery.GetIDs(galleries))
tagIDs, err := scene.GetDependentTagIDs(tagReader, sceneMarkerReader, s)
tagIDs, err := scene.GetDependentTagIDs(ctx, tagReader, sceneMarkerReader, s)
if err != nil {
logger.Errorf("[scenes] <%s> error getting scene tags: %s", sceneHash, err.Error())
continue
}
t.tags.IDs = intslice.IntAppendUniques(t.tags.IDs, tagIDs)
movieIDs, err := scene.GetDependentMovieIDs(sceneReader, s)
movieIDs, err := scene.GetDependentMovieIDs(ctx, sceneReader, s)
if err != nil {
logger.Errorf("[scenes] <%s> error getting scene movies: %s", sceneHash, err.Error())
continue
@ -482,18 +484,18 @@ func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.R
}
}
func (t *ExportTask) ExportImages(workers int, repo models.ReaderRepository) {
func (t *ExportTask) ExportImages(ctx context.Context, workers int, repo models.Repository) {
var imagesWg sync.WaitGroup
imageReader := repo.Image()
imageReader := repo.Image
var images []*models.Image
var err error
all := t.full || (t.images != nil && t.images.all)
if all {
images, err = imageReader.All()
images, err = imageReader.All(ctx)
} else if t.images != nil && len(t.images.IDs) > 0 {
images, err = imageReader.FindMany(t.images.IDs)
images, err = imageReader.FindMany(ctx, t.images.IDs)
}
if err != nil {
@ -507,7 +509,7 @@ func (t *ExportTask) ExportImages(workers int, repo models.ReaderRepository) {
for w := 0; w < workers; w++ { // create export Image workers
imagesWg.Add(1)
go exportImage(&imagesWg, jobCh, repo, t)
go exportImage(ctx, &imagesWg, jobCh, repo, t)
}
for i, image := range images {
@ -526,12 +528,12 @@ func (t *ExportTask) ExportImages(workers int, repo models.ReaderRepository) {
logger.Infof("[images] export complete in %s. %d workers used.", time.Since(startTime), workers)
}
func exportImage(wg *sync.WaitGroup, jobChan <-chan *models.Image, repo models.ReaderRepository, t *ExportTask) {
func exportImage(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Image, repo models.Repository, t *ExportTask) {
defer wg.Done()
studioReader := repo.Studio()
galleryReader := repo.Gallery()
performerReader := repo.Performer()
tagReader := repo.Tag()
studioReader := repo.Studio
galleryReader := repo.Gallery
performerReader := repo.Performer
tagReader := repo.Tag
for s := range jobChan {
imageHash := s.Checksum
@ -539,13 +541,13 @@ func exportImage(wg *sync.WaitGroup, jobChan <-chan *models.Image, repo models.R
newImageJSON := image.ToBasicJSON(s)
var err error
newImageJSON.Studio, err = image.GetStudioName(studioReader, s)
newImageJSON.Studio, err = image.GetStudioName(ctx, studioReader, s)
if err != nil {
logger.Errorf("[images] <%s> error getting image studio name: %s", imageHash, err.Error())
continue
}
imageGalleries, err := galleryReader.FindByImageID(s.ID)
imageGalleries, err := galleryReader.FindByImageID(ctx, s.ID)
if err != nil {
logger.Errorf("[images] <%s> error getting image galleries: %s", imageHash, err.Error())
continue
@ -553,7 +555,7 @@ func exportImage(wg *sync.WaitGroup, jobChan <-chan *models.Image, repo models.R
newImageJSON.Galleries = t.getGalleryChecksums(imageGalleries)
performers, err := performerReader.FindByImageID(s.ID)
performers, err := performerReader.FindByImageID(ctx, s.ID)
if err != nil {
logger.Errorf("[images] <%s> error getting image performer names: %s", imageHash, err.Error())
continue
@ -561,7 +563,7 @@ func exportImage(wg *sync.WaitGroup, jobChan <-chan *models.Image, repo models.R
newImageJSON.Performers = performer.GetNames(performers)
tags, err := tagReader.FindByImageID(s.ID)
tags, err := tagReader.FindByImageID(ctx, s.ID)
if err != nil {
logger.Errorf("[images] <%s> error getting image tag names: %s", imageHash, err.Error())
continue
@ -597,18 +599,18 @@ func (t *ExportTask) getGalleryChecksums(galleries []*models.Gallery) (ret []str
return
}
func (t *ExportTask) ExportGalleries(workers int, repo models.ReaderRepository) {
func (t *ExportTask) ExportGalleries(ctx context.Context, workers int, repo models.Repository) {
var galleriesWg sync.WaitGroup
reader := repo.Gallery()
reader := repo.Gallery
var galleries []*models.Gallery
var err error
all := t.full || (t.galleries != nil && t.galleries.all)
if all {
galleries, err = reader.All()
galleries, err = reader.All(ctx)
} else if t.galleries != nil && len(t.galleries.IDs) > 0 {
galleries, err = reader.FindMany(t.galleries.IDs)
galleries, err = reader.FindMany(ctx, t.galleries.IDs)
}
if err != nil {
@ -622,7 +624,7 @@ func (t *ExportTask) ExportGalleries(workers int, repo models.ReaderRepository)
for w := 0; w < workers; w++ { // create export Scene workers
galleriesWg.Add(1)
go exportGallery(&galleriesWg, jobCh, repo, t)
go exportGallery(ctx, &galleriesWg, jobCh, repo, t)
}
for i, gallery := range galleries {
@ -646,11 +648,11 @@ func (t *ExportTask) ExportGalleries(workers int, repo models.ReaderRepository)
logger.Infof("[galleries] export complete in %s. %d workers used.", time.Since(startTime), workers)
}
func exportGallery(wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo models.ReaderRepository, t *ExportTask) {
func exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo models.Repository, t *ExportTask) {
defer wg.Done()
studioReader := repo.Studio()
performerReader := repo.Performer()
tagReader := repo.Tag()
studioReader := repo.Studio
performerReader := repo.Performer
tagReader := repo.Tag
for g := range jobChan {
galleryHash := g.Checksum
@ -661,13 +663,13 @@ func exportGallery(wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo mode
continue
}
newGalleryJSON.Studio, err = gallery.GetStudioName(studioReader, g)
newGalleryJSON.Studio, err = gallery.GetStudioName(ctx, studioReader, g)
if err != nil {
logger.Errorf("[galleries] <%s> error getting gallery studio name: %s", galleryHash, err.Error())
continue
}
performers, err := performerReader.FindByGalleryID(g.ID)
performers, err := performerReader.FindByGalleryID(ctx, g.ID)
if err != nil {
logger.Errorf("[galleries] <%s> error getting gallery performer names: %s", galleryHash, err.Error())
continue
@ -675,7 +677,7 @@ func exportGallery(wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo mode
newGalleryJSON.Performers = performer.GetNames(performers)
tags, err := tagReader.FindByGalleryID(g.ID)
tags, err := tagReader.FindByGalleryID(ctx, g.ID)
if err != nil {
logger.Errorf("[galleries] <%s> error getting gallery tag names: %s", galleryHash, err.Error())
continue
@ -703,17 +705,17 @@ func exportGallery(wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo mode
}
}
func (t *ExportTask) ExportPerformers(workers int, repo models.ReaderRepository) {
func (t *ExportTask) ExportPerformers(ctx context.Context, workers int, repo models.Repository) {
var performersWg sync.WaitGroup
reader := repo.Performer()
reader := repo.Performer
var performers []*models.Performer
var err error
all := t.full || (t.performers != nil && t.performers.all)
if all {
performers, err = reader.All()
performers, err = reader.All(ctx)
} else if t.performers != nil && len(t.performers.IDs) > 0 {
performers, err = reader.FindMany(t.performers.IDs)
performers, err = reader.FindMany(ctx, t.performers.IDs)
}
if err != nil {
@ -726,7 +728,7 @@ func (t *ExportTask) ExportPerformers(workers int, repo models.ReaderRepository)
for w := 0; w < workers; w++ { // create export Performer workers
performersWg.Add(1)
go t.exportPerformer(&performersWg, jobCh, repo)
go t.exportPerformer(ctx, &performersWg, jobCh, repo)
}
for i, performer := range performers {
@ -743,20 +745,20 @@ func (t *ExportTask) ExportPerformers(workers int, repo models.ReaderRepository)
logger.Infof("[performers] export complete in %s. %d workers used.", time.Since(startTime), workers)
}
func (t *ExportTask) exportPerformer(wg *sync.WaitGroup, jobChan <-chan *models.Performer, repo models.ReaderRepository) {
func (t *ExportTask) exportPerformer(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Performer, repo models.Repository) {
defer wg.Done()
performerReader := repo.Performer()
performerReader := repo.Performer
for p := range jobChan {
newPerformerJSON, err := performer.ToJSON(performerReader, p)
newPerformerJSON, err := performer.ToJSON(ctx, performerReader, p)
if err != nil {
logger.Errorf("[performers] <%s> error getting performer JSON: %s", p.Checksum, err.Error())
continue
}
tags, err := repo.Tag().FindByPerformerID(p.ID)
tags, err := repo.Tag.FindByPerformerID(ctx, p.ID)
if err != nil {
logger.Errorf("[performers] <%s> error getting performer tags: %s", p.Checksum, err.Error())
continue
@ -781,17 +783,17 @@ func (t *ExportTask) exportPerformer(wg *sync.WaitGroup, jobChan <-chan *models.
}
}
func (t *ExportTask) ExportStudios(workers int, repo models.ReaderRepository) {
func (t *ExportTask) ExportStudios(ctx context.Context, workers int, repo models.Repository) {
var studiosWg sync.WaitGroup
reader := repo.Studio()
reader := repo.Studio
var studios []*models.Studio
var err error
all := t.full || (t.studios != nil && t.studios.all)
if all {
studios, err = reader.All()
studios, err = reader.All(ctx)
} else if t.studios != nil && len(t.studios.IDs) > 0 {
studios, err = reader.FindMany(t.studios.IDs)
studios, err = reader.FindMany(ctx, t.studios.IDs)
}
if err != nil {
@ -805,7 +807,7 @@ func (t *ExportTask) ExportStudios(workers int, repo models.ReaderRepository) {
for w := 0; w < workers; w++ { // create export Studio workers
studiosWg.Add(1)
go t.exportStudio(&studiosWg, jobCh, repo)
go t.exportStudio(ctx, &studiosWg, jobCh, repo)
}
for i, studio := range studios {
@ -822,13 +824,13 @@ func (t *ExportTask) ExportStudios(workers int, repo models.ReaderRepository) {
logger.Infof("[studios] export complete in %s. %d workers used.", time.Since(startTime), workers)
}
func (t *ExportTask) exportStudio(wg *sync.WaitGroup, jobChan <-chan *models.Studio, repo models.ReaderRepository) {
func (t *ExportTask) exportStudio(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Studio, repo models.Repository) {
defer wg.Done()
studioReader := repo.Studio()
studioReader := repo.Studio
for s := range jobChan {
newStudioJSON, err := studio.ToJSON(studioReader, s)
newStudioJSON, err := studio.ToJSON(ctx, studioReader, s)
if err != nil {
logger.Errorf("[studios] <%s> error getting studio JSON: %s", s.Checksum, err.Error())
@ -846,17 +848,17 @@ func (t *ExportTask) exportStudio(wg *sync.WaitGroup, jobChan <-chan *models.Stu
}
}
func (t *ExportTask) ExportTags(workers int, repo models.ReaderRepository) {
func (t *ExportTask) ExportTags(ctx context.Context, workers int, repo models.Repository) {
var tagsWg sync.WaitGroup
reader := repo.Tag()
reader := repo.Tag
var tags []*models.Tag
var err error
all := t.full || (t.tags != nil && t.tags.all)
if all {
tags, err = reader.All()
tags, err = reader.All(ctx)
} else if t.tags != nil && len(t.tags.IDs) > 0 {
tags, err = reader.FindMany(t.tags.IDs)
tags, err = reader.FindMany(ctx, t.tags.IDs)
}
if err != nil {
@ -870,7 +872,7 @@ func (t *ExportTask) ExportTags(workers int, repo models.ReaderRepository) {
for w := 0; w < workers; w++ { // create export Tag workers
tagsWg.Add(1)
go t.exportTag(&tagsWg, jobCh, repo)
go t.exportTag(ctx, &tagsWg, jobCh, repo)
}
for i, tag := range tags {
@ -890,13 +892,13 @@ func (t *ExportTask) ExportTags(workers int, repo models.ReaderRepository) {
logger.Infof("[tags] export complete in %s. %d workers used.", time.Since(startTime), workers)
}
func (t *ExportTask) exportTag(wg *sync.WaitGroup, jobChan <-chan *models.Tag, repo models.ReaderRepository) {
func (t *ExportTask) exportTag(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Tag, repo models.Repository) {
defer wg.Done()
tagReader := repo.Tag()
tagReader := repo.Tag
for thisTag := range jobChan {
newTagJSON, err := tag.ToJSON(tagReader, thisTag)
newTagJSON, err := tag.ToJSON(ctx, tagReader, thisTag)
if err != nil {
logger.Errorf("[tags] <%s> error getting tag JSON: %s", thisTag.Name, err.Error())
@ -917,17 +919,17 @@ func (t *ExportTask) exportTag(wg *sync.WaitGroup, jobChan <-chan *models.Tag, r
}
}
func (t *ExportTask) ExportMovies(workers int, repo models.ReaderRepository) {
func (t *ExportTask) ExportMovies(ctx context.Context, workers int, repo models.Repository) {
var moviesWg sync.WaitGroup
reader := repo.Movie()
reader := repo.Movie
var movies []*models.Movie
var err error
all := t.full || (t.movies != nil && t.movies.all)
if all {
movies, err = reader.All()
movies, err = reader.All(ctx)
} else if t.movies != nil && len(t.movies.IDs) > 0 {
movies, err = reader.FindMany(t.movies.IDs)
movies, err = reader.FindMany(ctx, t.movies.IDs)
}
if err != nil {
@ -941,7 +943,7 @@ func (t *ExportTask) ExportMovies(workers int, repo models.ReaderRepository) {
for w := 0; w < workers; w++ { // create export Studio workers
moviesWg.Add(1)
go t.exportMovie(&moviesWg, jobCh, repo)
go t.exportMovie(ctx, &moviesWg, jobCh, repo)
}
for i, movie := range movies {
@ -958,14 +960,14 @@ func (t *ExportTask) ExportMovies(workers int, repo models.ReaderRepository) {
logger.Infof("[movies] export complete in %s. %d workers used.", time.Since(startTime), workers)
}
func (t *ExportTask) exportMovie(wg *sync.WaitGroup, jobChan <-chan *models.Movie, repo models.ReaderRepository) {
func (t *ExportTask) exportMovie(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Movie, repo models.Repository) {
defer wg.Done()
movieReader := repo.Movie()
studioReader := repo.Studio()
movieReader := repo.Movie
studioReader := repo.Studio
for m := range jobChan {
newMovieJSON, err := movie.ToJSON(movieReader, studioReader, m)
newMovieJSON, err := movie.ToJSON(ctx, movieReader, studioReader, m)
if err != nil {
logger.Errorf("[movies] <%s> error getting tag JSON: %s", m.Checksum, err.Error())
@ -991,10 +993,10 @@ func (t *ExportTask) exportMovie(wg *sync.WaitGroup, jobChan <-chan *models.Movi
}
}
func (t *ExportTask) ExportScrapedItems(repo models.ReaderRepository) {
qb := repo.ScrapedItem()
sqb := repo.Studio()
scrapedItems, err := qb.All()
func (t *ExportTask) ExportScrapedItems(ctx context.Context, repo models.Repository) {
qb := repo.ScrapedItem
sqb := repo.Studio
scrapedItems, err := qb.All(ctx)
if err != nil {
logger.Errorf("[scraped sites] failed to fetch all items: %s", err.Error())
}
@ -1009,7 +1011,7 @@ func (t *ExportTask) ExportScrapedItems(repo models.ReaderRepository) {
var studioName string
if scrapedItem.StudioID.Valid {
studio, _ := sqb.Find(int(scrapedItem.StudioID.Int64))
studio, _ := sqb.Find(ctx, int(scrapedItem.StudioID.Int64))
if studio != nil {
studioName = studio.Name.String
}

View File

@ -54,7 +54,7 @@ type GeneratePreviewOptionsInput struct {
const generateQueueSize = 200000
type GenerateJob struct {
txnManager models.TransactionManager
txnManager models.Repository
input GenerateMetadataInput
overwrite bool
@ -110,20 +110,20 @@ func (j *GenerateJob) Execute(ctx context.Context, progress *job.Progress) {
Overwrite: j.overwrite,
}
if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
qb := r.Scene()
if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
qb := j.txnManager.Scene
if len(j.input.SceneIDs) == 0 && len(j.input.MarkerIDs) == 0 {
totals = j.queueTasks(ctx, g, queue)
} else {
if len(j.input.SceneIDs) > 0 {
scenes, err = qb.FindMany(sceneIDs)
scenes, err = qb.FindMany(ctx, sceneIDs)
for _, s := range scenes {
j.queueSceneJobs(ctx, g, s, queue, &totals)
}
}
if len(j.input.MarkerIDs) > 0 {
markers, err = r.SceneMarker().FindMany(markerIDs)
markers, err = j.txnManager.SceneMarker.FindMany(ctx, markerIDs)
if err != nil {
return err
}
@ -192,13 +192,13 @@ func (j *GenerateJob) queueTasks(ctx context.Context, g *generate.Generator, que
findFilter := models.BatchFindFilter(batchSize)
if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
for more := true; more; {
if job.IsCancelled(ctx) {
return context.Canceled
}
scenes, err := scene.Query(r.Scene(), nil, findFilter)
scenes, err := scene.Query(ctx, j.txnManager.Scene, nil, findFilter)
if err != nil {
return err
}

View File

@ -15,7 +15,7 @@ type GenerateInteractiveHeatmapSpeedTask struct {
Scene models.Scene
Overwrite bool
fileNamingAlgorithm models.HashAlgorithm
TxnManager models.TransactionManager
TxnManager models.Repository
}
func (t *GenerateInteractiveHeatmapSpeedTask) GetDescription() string {
@ -47,22 +47,22 @@ func (t *GenerateInteractiveHeatmapSpeedTask) Start(ctx context.Context) {
var s *models.Scene
if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error
s, err = r.Scene().FindByPath(t.Scene.Path)
s, err = t.TxnManager.Scene.FindByPath(ctx, t.Scene.Path)
return err
}); err != nil {
logger.Error(err.Error())
return
}
if err := t.TxnManager.WithTxn(ctx, func(r models.Repository) error {
qb := r.Scene()
if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
qb := t.TxnManager.Scene
scenePartial := models.ScenePartial{
ID: s.ID,
InteractiveSpeed: &median,
}
_, err := qb.Update(scenePartial)
_, err := qb.Update(ctx, scenePartial)
return err
}); err != nil {
logger.Error(err.Error())

View File

@ -13,7 +13,7 @@ import (
)
type GenerateMarkersTask struct {
TxnManager models.TransactionManager
TxnManager models.Repository
Scene *models.Scene
Marker *models.SceneMarker
Overwrite bool
@ -42,9 +42,9 @@ func (t *GenerateMarkersTask) Start(ctx context.Context) {
if t.Marker != nil {
var scene *models.Scene
if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error
scene, err = r.Scene().Find(int(t.Marker.SceneID.Int64))
scene, err = t.TxnManager.Scene.Find(ctx, int(t.Marker.SceneID.Int64))
return err
}); err != nil {
logger.Errorf("error finding scene for marker: %s", err.Error())
@ -69,9 +69,9 @@ func (t *GenerateMarkersTask) Start(ctx context.Context) {
func (t *GenerateMarkersTask) generateSceneMarkers(ctx context.Context) {
var sceneMarkers []*models.SceneMarker
if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error
sceneMarkers, err = r.SceneMarker().FindBySceneID(t.Scene.ID)
sceneMarkers, err = t.TxnManager.SceneMarker.FindBySceneID(ctx, t.Scene.ID)
return err
}); err != nil {
logger.Errorf("error getting scene markers: %s", err.Error())
@ -134,9 +134,9 @@ func (t *GenerateMarkersTask) generateMarker(videoFile *ffmpeg.VideoFile, scene
func (t *GenerateMarkersTask) markersNeeded(ctx context.Context) int {
markers := 0
var sceneMarkers []*models.SceneMarker
if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error
sceneMarkers, err = r.SceneMarker().FindBySceneID(t.Scene.ID)
sceneMarkers, err = t.TxnManager.SceneMarker.FindBySceneID(ctx, t.Scene.ID)
return err
}); err != nil {
logger.Errorf("errror finding scene markers: %s", err.Error())

View File

@ -14,7 +14,7 @@ type GeneratePhashTask struct {
Scene models.Scene
Overwrite bool
fileNamingAlgorithm models.HashAlgorithm
txnManager models.TransactionManager
txnManager models.Repository
}
func (t *GeneratePhashTask) GetDescription() string {
@ -40,14 +40,14 @@ func (t *GeneratePhashTask) Start(ctx context.Context) {
return
}
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
qb := r.Scene()
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
qb := t.txnManager.Scene
hashValue := sql.NullInt64{Int64: int64(*hash), Valid: true}
scenePartial := models.ScenePartial{
ID: t.Scene.ID,
Phash: &hashValue,
}
_, err := qb.Update(scenePartial)
_, err := qb.Update(ctx, scenePartial)
return err
}); err != nil {
logger.Error(err.Error())

View File

@ -17,7 +17,7 @@ type GenerateScreenshotTask struct {
Scene models.Scene
ScreenshotAt *float64
fileNamingAlgorithm models.HashAlgorithm
txnManager models.TransactionManager
txnManager models.Repository
}
func (t *GenerateScreenshotTask) Start(ctx context.Context) {
@ -74,8 +74,8 @@ func (t *GenerateScreenshotTask) Start(ctx context.Context) {
return
}
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
qb := r.Scene()
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
qb := t.txnManager.Scene
updatedTime := time.Now()
updatedScene := models.ScenePartial{
ID: t.Scene.ID,
@ -87,12 +87,12 @@ func (t *GenerateScreenshotTask) Start(ctx context.Context) {
}
// update the scene cover table
if err := qb.UpdateCover(t.Scene.ID, coverImageData); err != nil {
if err := qb.UpdateCover(ctx, t.Scene.ID, coverImageData); err != nil {
return fmt.Errorf("error setting screenshot: %v", err)
}
// update the scene with the update date
_, err = qb.Update(updatedScene)
_, err = qb.Update(ctx, updatedScene)
if err != nil {
return fmt.Errorf("error updating scene: %v", err)
}

View File

@ -14,12 +14,12 @@ import (
"github.com/stashapp/stash/pkg/scraper"
"github.com/stashapp/stash/pkg/scraper/stashbox"
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
"github.com/stashapp/stash/pkg/txn"
)
var ErrInput = errors.New("invalid request input")
type IdentifyJob struct {
txnManager models.TransactionManager
postHookExecutor identify.SceneUpdatePostHookExecutor
input identify.Options
@ -29,7 +29,6 @@ type IdentifyJob struct {
func CreateIdentifyJob(input identify.Options) *IdentifyJob {
return &IdentifyJob{
txnManager: instance.TxnManager,
postHookExecutor: instance.PluginCache,
input: input,
stashBoxes: instance.Config.GetStashBoxes(),
@ -52,9 +51,9 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) {
// if scene ids provided, use those
// otherwise, batch query for all scenes - ordering by path
if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
if err := txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
if len(j.input.SceneIDs) == 0 {
return j.identifyAllScenes(ctx, r, sources)
return j.identifyAllScenes(ctx, sources)
}
sceneIDs, err := stringslice.StringSliceToIntSlice(j.input.SceneIDs)
@ -70,7 +69,7 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) {
// find the scene
var err error
scene, err := r.Scene().Find(id)
scene, err := instance.Repository.Scene.Find(ctx, id)
if err != nil {
return fmt.Errorf("error finding scene with id %d: %w", id, err)
}
@ -88,7 +87,7 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) {
}
}
func (j *IdentifyJob) identifyAllScenes(ctx context.Context, r models.ReaderRepository, sources []identify.ScraperSource) error {
func (j *IdentifyJob) identifyAllScenes(ctx context.Context, sources []identify.ScraperSource) error {
// exclude organised
organised := false
sceneFilter := scene.FilterFromPaths(j.input.Paths)
@ -102,7 +101,7 @@ func (j *IdentifyJob) identifyAllScenes(ctx context.Context, r models.ReaderRepo
// get the count
pp := 0
findFilter.PerPage = &pp
countResult, err := r.Scene().Query(models.SceneQueryOptions{
countResult, err := instance.Repository.Scene.Query(ctx, models.SceneQueryOptions{
QueryOptions: models.QueryOptions{
FindFilter: findFilter,
Count: true,
@ -115,7 +114,7 @@ func (j *IdentifyJob) identifyAllScenes(ctx context.Context, r models.ReaderRepo
j.progress.SetTotal(countResult.Count)
return scene.BatchProcess(ctx, r.Scene(), sceneFilter, findFilter, func(scene *models.Scene) error {
return scene.BatchProcess(ctx, instance.Repository.Scene, sceneFilter, findFilter, func(scene *models.Scene) error {
if job.IsCancelled(ctx) {
return nil
}
@ -133,6 +132,11 @@ func (j *IdentifyJob) identifyScene(ctx context.Context, s *models.Scene, source
var taskError error
j.progress.ExecuteTask("Identifying "+s.Path, func() {
task := identify.SceneIdentifier{
SceneReaderUpdater: instance.Repository.Scene,
StudioCreator: instance.Repository.Studio,
PerformerCreator: instance.Repository.Performer,
TagCreator: instance.Repository.Tag,
DefaultOptions: j.input.Options,
Sources: sources,
ScreenshotSetter: &scene.PathsScreenshotSetter{
@ -142,7 +146,7 @@ func (j *IdentifyJob) identifyScene(ctx context.Context, s *models.Scene, source
SceneUpdatePostHookExecutor: j.postHookExecutor,
}
taskError = task.Identify(ctx, j.txnManager, s)
taskError = task.Identify(ctx, instance.Repository, s)
})
if taskError != nil {
@ -166,7 +170,12 @@ func (j *IdentifyJob) getSources() ([]identify.ScraperSource, error) {
src = identify.ScraperSource{
Name: "stash-box: " + stashBox.Endpoint,
Scraper: stashboxSource{
stashbox.NewClient(*stashBox, j.txnManager),
stashbox.NewClient(*stashBox, instance.Repository, stashbox.Repository{
Scene: instance.Repository.Scene,
Performer: instance.Repository.Performer,
Tag: instance.Repository.Tag,
Studio: instance.Repository.Studio,
}),
stashBox.Endpoint,
},
RemoteSite: stashBox.Endpoint,

View File

@ -12,8 +12,6 @@ import (
"time"
"github.com/99designs/gqlgen/graphql"
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/fsutil"
"github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/image"
@ -30,7 +28,7 @@ import (
)
type ImportTask struct {
txnManager models.TransactionManager
txnManager models.Repository
json jsonUtils
BaseDir string
@ -73,7 +71,7 @@ func CreateImportTask(a models.HashAlgorithm, input ImportObjectsInput) (*Import
}
return &ImportTask{
txnManager: GetInstance().TxnManager,
txnManager: GetInstance().Repository,
BaseDir: baseDir,
TmpZip: tmpZip,
Reset: false,
@ -126,7 +124,7 @@ func (t *ImportTask) Start(ctx context.Context) {
t.scraped = scraped
if t.Reset {
err := database.Reset(config.GetInstance().GetDatabasePath())
err := t.txnManager.Reset()
if err != nil {
logger.Errorf("Error resetting database: %s", err.Error())
@ -211,15 +209,16 @@ func (t *ImportTask) ImportPerformers(ctx context.Context) {
logger.Progressf("[performers] %d of %d", index, len(t.mappings.Performers))
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
readerWriter := r.Performer()
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
r := t.txnManager
readerWriter := r.Performer
importer := &performer.Importer{
ReaderWriter: readerWriter,
TagWriter: r.Tag(),
TagWriter: r.Tag,
Input: *performerJSON,
}
return performImport(importer, t.DuplicateBehaviour)
return performImport(ctx, importer, t.DuplicateBehaviour)
}); err != nil {
logger.Errorf("[performers] <%s> import failed: %s", mappingJSON.Checksum, err.Error())
}
@ -243,8 +242,8 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
logger.Progressf("[studios] %d of %d", index, len(t.mappings.Studios))
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
return t.ImportStudio(studioJSON, pendingParent, r.Studio())
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
return t.ImportStudio(ctx, studioJSON, pendingParent, t.txnManager.Studio)
}); err != nil {
if errors.Is(err, studio.ErrParentStudioNotExist) {
// add to the pending parent list so that it is created after the parent
@ -265,8 +264,8 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
for _, s := range pendingParent {
for _, orphanStudioJSON := range s {
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
return t.ImportStudio(orphanStudioJSON, nil, r.Studio())
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
return t.ImportStudio(ctx, orphanStudioJSON, nil, t.txnManager.Studio)
}); err != nil {
logger.Errorf("[studios] <%s> failed to create: %s", orphanStudioJSON.Name, err.Error())
continue
@ -278,7 +277,7 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
logger.Info("[studios] import complete")
}
func (t *ImportTask) ImportStudio(studioJSON *jsonschema.Studio, pendingParent map[string][]*jsonschema.Studio, readerWriter models.StudioReaderWriter) error {
func (t *ImportTask) ImportStudio(ctx context.Context, studioJSON *jsonschema.Studio, pendingParent map[string][]*jsonschema.Studio, readerWriter studio.NameFinderCreatorUpdater) error {
importer := &studio.Importer{
ReaderWriter: readerWriter,
Input: *studioJSON,
@ -290,7 +289,7 @@ func (t *ImportTask) ImportStudio(studioJSON *jsonschema.Studio, pendingParent m
importer.MissingRefBehaviour = models.ImportMissingRefEnumFail
}
if err := performImport(importer, t.DuplicateBehaviour); err != nil {
if err := performImport(ctx, importer, t.DuplicateBehaviour); err != nil {
return err
}
@ -298,7 +297,7 @@ func (t *ImportTask) ImportStudio(studioJSON *jsonschema.Studio, pendingParent m
s := pendingParent[studioJSON.Name]
for _, childStudioJSON := range s {
// map is nil since we're not checking parent studios at this point
if err := t.ImportStudio(childStudioJSON, nil, readerWriter); err != nil {
if err := t.ImportStudio(ctx, childStudioJSON, nil, readerWriter); err != nil {
return fmt.Errorf("failed to create child studio <%s>: %s", childStudioJSON.Name, err.Error())
}
}
@ -322,9 +321,10 @@ func (t *ImportTask) ImportMovies(ctx context.Context) {
logger.Progressf("[movies] %d of %d", index, len(t.mappings.Movies))
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
readerWriter := r.Movie()
studioReaderWriter := r.Studio()
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
r := t.txnManager
readerWriter := r.Movie
studioReaderWriter := r.Studio
movieImporter := &movie.Importer{
ReaderWriter: readerWriter,
@ -333,7 +333,7 @@ func (t *ImportTask) ImportMovies(ctx context.Context) {
MissingRefBehaviour: t.MissingRefBehaviour,
}
return performImport(movieImporter, t.DuplicateBehaviour)
return performImport(ctx, movieImporter, t.DuplicateBehaviour)
}); err != nil {
logger.Errorf("[movies] <%s> import failed: %s", mappingJSON.Checksum, err.Error())
continue
@ -356,11 +356,12 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) {
logger.Progressf("[galleries] %d of %d", index, len(t.mappings.Galleries))
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
readerWriter := r.Gallery()
tagWriter := r.Tag()
performerWriter := r.Performer()
studioWriter := r.Studio()
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
r := t.txnManager
readerWriter := r.Gallery
tagWriter := r.Tag
performerWriter := r.Performer
studioWriter := r.Studio
galleryImporter := &gallery.Importer{
ReaderWriter: readerWriter,
@ -371,7 +372,7 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) {
MissingRefBehaviour: t.MissingRefBehaviour,
}
return performImport(galleryImporter, t.DuplicateBehaviour)
return performImport(ctx, galleryImporter, t.DuplicateBehaviour)
}); err != nil {
logger.Errorf("[galleries] <%s> import failed to commit: %s", mappingJSON.Checksum, err.Error())
continue
@ -395,8 +396,8 @@ func (t *ImportTask) ImportTags(ctx context.Context) {
logger.Progressf("[tags] %d of %d", index, len(t.mappings.Tags))
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
return t.ImportTag(tagJSON, pendingParent, false, r.Tag())
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
return t.ImportTag(ctx, tagJSON, pendingParent, false, t.txnManager.Tag)
}); err != nil {
var parentError tag.ParentTagNotExistError
if errors.As(err, &parentError) {
@ -411,8 +412,8 @@ func (t *ImportTask) ImportTags(ctx context.Context) {
for _, s := range pendingParent {
for _, orphanTagJSON := range s {
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
return t.ImportTag(orphanTagJSON, nil, true, r.Tag())
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
return t.ImportTag(ctx, orphanTagJSON, nil, true, t.txnManager.Tag)
}); err != nil {
logger.Errorf("[tags] <%s> failed to create: %s", orphanTagJSON.Name, err.Error())
continue
@ -423,7 +424,7 @@ func (t *ImportTask) ImportTags(ctx context.Context) {
logger.Info("[tags] import complete")
}
func (t *ImportTask) ImportTag(tagJSON *jsonschema.Tag, pendingParent map[string][]*jsonschema.Tag, fail bool, readerWriter models.TagReaderWriter) error {
func (t *ImportTask) ImportTag(ctx context.Context, tagJSON *jsonschema.Tag, pendingParent map[string][]*jsonschema.Tag, fail bool, readerWriter tag.NameFinderCreatorUpdater) error {
importer := &tag.Importer{
ReaderWriter: readerWriter,
Input: *tagJSON,
@ -435,12 +436,12 @@ func (t *ImportTask) ImportTag(tagJSON *jsonschema.Tag, pendingParent map[string
importer.MissingRefBehaviour = models.ImportMissingRefEnumFail
}
if err := performImport(importer, t.DuplicateBehaviour); err != nil {
if err := performImport(ctx, importer, t.DuplicateBehaviour); err != nil {
return err
}
for _, childTagJSON := range pendingParent[tagJSON.Name] {
if err := t.ImportTag(childTagJSON, pendingParent, fail, readerWriter); err != nil {
if err := t.ImportTag(ctx, childTagJSON, pendingParent, fail, readerWriter); err != nil {
var parentError tag.ParentTagNotExistError
if errors.As(err, &parentError) {
pendingParent[parentError.MissingParent()] = append(pendingParent[parentError.MissingParent()], tagJSON)
@ -457,10 +458,11 @@ func (t *ImportTask) ImportTag(tagJSON *jsonschema.Tag, pendingParent map[string
}
func (t *ImportTask) ImportScrapedItems(ctx context.Context) {
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
logger.Info("[scraped sites] importing")
qb := r.ScrapedItem()
sqb := r.Studio()
r := t.txnManager
qb := r.ScrapedItem
sqb := r.Studio
currentTime := time.Now()
for i, mappingJSON := range t.scraped {
@ -484,7 +486,7 @@ func (t *ImportTask) ImportScrapedItems(ctx context.Context) {
UpdatedAt: models.SQLiteTimestamp{Timestamp: t.getTimeFromJSONTime(mappingJSON.UpdatedAt)},
}
studio, err := sqb.FindByName(mappingJSON.Studio, false)
studio, err := sqb.FindByName(ctx, mappingJSON.Studio, false)
if err != nil {
logger.Errorf("[scraped sites] failed to fetch studio: %s", err.Error())
}
@ -492,7 +494,7 @@ func (t *ImportTask) ImportScrapedItems(ctx context.Context) {
newScrapedItem.StudioID = sql.NullInt64{Int64: int64(studio.ID), Valid: true}
}
_, err = qb.Create(newScrapedItem)
_, err = qb.Create(ctx, newScrapedItem)
if err != nil {
logger.Errorf("[scraped sites] <%s> failed to create: %s", newScrapedItem.Title.String, err.Error())
}
@ -522,14 +524,15 @@ func (t *ImportTask) ImportScenes(ctx context.Context) {
sceneHash := mappingJSON.Checksum
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
readerWriter := r.Scene()
tagWriter := r.Tag()
galleryWriter := r.Gallery()
movieWriter := r.Movie()
performerWriter := r.Performer()
studioWriter := r.Studio()
markerWriter := r.SceneMarker()
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
r := t.txnManager
readerWriter := r.Scene
tagWriter := r.Tag
galleryWriter := r.Gallery
movieWriter := r.Movie
performerWriter := r.Performer
studioWriter := r.Studio
markerWriter := r.SceneMarker
sceneImporter := &scene.Importer{
ReaderWriter: readerWriter,
@ -546,7 +549,7 @@ func (t *ImportTask) ImportScenes(ctx context.Context) {
TagWriter: tagWriter,
}
if err := performImport(sceneImporter, t.DuplicateBehaviour); err != nil {
if err := performImport(ctx, sceneImporter, t.DuplicateBehaviour); err != nil {
return err
}
@ -560,7 +563,7 @@ func (t *ImportTask) ImportScenes(ctx context.Context) {
TagWriter: tagWriter,
}
if err := performImport(markerImporter, t.DuplicateBehaviour); err != nil {
if err := performImport(ctx, markerImporter, t.DuplicateBehaviour); err != nil {
return err
}
}
@ -590,12 +593,13 @@ func (t *ImportTask) ImportImages(ctx context.Context) {
imageHash := mappingJSON.Checksum
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
readerWriter := r.Image()
tagWriter := r.Tag()
galleryWriter := r.Gallery()
performerWriter := r.Performer()
studioWriter := r.Studio()
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
r := t.txnManager
readerWriter := r.Image
tagWriter := r.Tag
galleryWriter := r.Gallery
performerWriter := r.Performer
studioWriter := r.Studio
imageImporter := &image.Importer{
ReaderWriter: readerWriter,
@ -610,7 +614,7 @@ func (t *ImportTask) ImportImages(ctx context.Context) {
TagWriter: tagWriter,
}
return performImport(imageImporter, t.DuplicateBehaviour)
return performImport(ctx, imageImporter, t.DuplicateBehaviour)
}); err != nil {
logger.Errorf("[images] <%s> import failed: %s", imageHash, err.Error())
}

View File

@ -24,7 +24,7 @@ import (
const scanQueueSize = 200000
type ScanJob struct {
txnManager models.TransactionManager
txnManager models.Repository
input ScanMetadataInput
subscriptions *subscriptionManager
}
@ -220,20 +220,21 @@ func (j *ScanJob) doesPathExist(ctx context.Context, path string) bool {
gExt := config.GetGalleryExtensions()
ret := false
txnErr := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
txnErr := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
r := j.txnManager
switch {
case fsutil.MatchExtension(path, gExt):
g, _ := r.Gallery().FindByPath(path)
g, _ := r.Gallery.FindByPath(ctx, path)
if g != nil {
ret = true
}
case fsutil.MatchExtension(path, vidExt):
s, _ := r.Scene().FindByPath(path)
s, _ := r.Scene.FindByPath(ctx, path)
if s != nil {
ret = true
}
case fsutil.MatchExtension(path, imgExt):
i, _ := r.Image().FindByPath(path)
i, _ := r.Image.FindByPath(ctx, path)
if i != nil {
ret = true
}
@ -249,7 +250,7 @@ func (j *ScanJob) doesPathExist(ctx context.Context, path string) bool {
}
type ScanTask struct {
TxnManager models.TransactionManager
TxnManager models.Repository
file file.SourceFile
UseFileMetadata bool
StripFileExtension bool

View File

@ -21,12 +21,12 @@ func (t *ScanTask) scanGallery(ctx context.Context) {
images := 0
scanImages := false
if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error
g, err = r.Gallery().FindByPath(path)
g, err = t.TxnManager.Gallery.FindByPath(ctx, path)
if g != nil && err == nil {
images, err = r.Image().CountByGalleryID(g.ID)
images, err = t.TxnManager.Image.CountByGalleryID(ctx, g.ID)
if err != nil {
return fmt.Errorf("error getting images for zip gallery %s: %s", path, err.Error())
}
@ -43,7 +43,7 @@ func (t *ScanTask) scanGallery(ctx context.Context) {
ImageExtensions: instance.Config.GetImageExtensions(),
StripFileExtension: t.StripFileExtension,
CaseSensitiveFs: t.CaseSensitiveFs,
TxnManager: t.TxnManager,
CreatorUpdater: t.TxnManager.Gallery,
Paths: instance.Paths,
PluginCache: instance.PluginCache,
MutexManager: t.mutexManager,
@ -79,10 +79,11 @@ func (t *ScanTask) scanGallery(ctx context.Context) {
// associates a gallery to a scene with the same basename
func (t *ScanTask) associateGallery(ctx context.Context, wg *sizedwaitgroup.SizedWaitGroup) {
path := t.file.Path()
if err := t.TxnManager.WithTxn(ctx, func(r models.Repository) error {
qb := r.Gallery()
sqb := r.Scene()
g, err := qb.FindByPath(path)
if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
r := t.TxnManager
qb := r.Gallery
sqb := r.Scene
g, err := qb.FindByPath(ctx, path)
if err != nil {
return err
}
@ -106,10 +107,10 @@ func (t *ScanTask) associateGallery(ctx context.Context, wg *sizedwaitgroup.Size
}
}
for _, scenePath := range relatedFiles {
scene, _ := sqb.FindByPath(scenePath)
scene, _ := sqb.FindByPath(ctx, scenePath)
// found related Scene
if scene != nil {
sceneGalleries, _ := sqb.FindByGalleryID(g.ID) // check if gallery is already associated to the scene
sceneGalleries, _ := sqb.FindByGalleryID(ctx, g.ID) // check if gallery is already associated to the scene
isAssoc := false
for _, sg := range sceneGalleries {
if scene.ID == sg.ID {
@ -119,7 +120,7 @@ func (t *ScanTask) associateGallery(ctx context.Context, wg *sizedwaitgroup.Size
}
if !isAssoc {
logger.Infof("associate: Gallery %s is related to scene: %d", path, scene.ID)
if err := sqb.UpdateGalleries(scene.ID, []int{g.ID}); err != nil {
if err := sqb.UpdateGalleries(ctx, scene.ID, []int{g.ID}); err != nil {
return err
}
}
@ -152,11 +153,11 @@ func (t *ScanTask) scanZipImages(ctx context.Context, zipGallery *models.Gallery
func (t *ScanTask) regenerateZipImages(ctx context.Context, zipGallery *models.Gallery) {
var images []*models.Image
if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
iqb := r.Image()
if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
iqb := t.TxnManager.Image
var err error
images, err = iqb.FindByGalleryID(zipGallery.ID)
images, err = iqb.FindByGalleryID(ctx, zipGallery.ID)
return err
}); err != nil {
logger.Warnf("failed to find gallery images: %s", err.Error())

View File

@ -23,9 +23,9 @@ func (t *ScanTask) scanImage(ctx context.Context) {
var i *models.Image
path := t.file.Path()
if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error
i, err = r.Image().FindByPath(path)
i, err = t.TxnManager.Image.FindByPath(ctx, path)
return err
}); err != nil {
logger.Error(err.Error())
@ -36,6 +36,8 @@ func (t *ScanTask) scanImage(ctx context.Context) {
Scanner: image.FileScanner(&file.FSHasher{}),
StripFileExtension: t.StripFileExtension,
TxnManager: t.TxnManager,
CreatorUpdater: t.TxnManager.Image,
CaseSensitiveFs: t.CaseSensitiveFs,
Paths: GetInstance().Paths,
PluginCache: instance.PluginCache,
MutexManager: t.mutexManager,
@ -58,8 +60,8 @@ func (t *ScanTask) scanImage(ctx context.Context) {
if i != nil {
if t.zipGallery != nil {
// associate with gallery
if err := t.TxnManager.WithTxn(ctx, func(r models.Repository) error {
return gallery.AddImage(r.Gallery(), t.zipGallery.ID, i.ID)
if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
return gallery.AddImage(ctx, t.TxnManager.Gallery, t.zipGallery.ID, i.ID)
}); err != nil {
logger.Error(err.Error())
return
@ -69,9 +71,9 @@ func (t *ScanTask) scanImage(ctx context.Context) {
logger.Infof("Associating image %s with folder gallery", i.Path)
var galleryID int
var isNewGallery bool
if err := t.TxnManager.WithTxn(ctx, func(r models.Repository) error {
if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error
galleryID, isNewGallery, err = t.associateImageWithFolderGallery(i.ID, r.Gallery())
galleryID, isNewGallery, err = t.associateImageWithFolderGallery(ctx, i.ID, t.TxnManager.Gallery)
return err
}); err != nil {
logger.Error(err.Error())
@ -90,11 +92,17 @@ func (t *ScanTask) scanImage(ctx context.Context) {
}
}
func (t *ScanTask) associateImageWithFolderGallery(imageID int, qb models.GalleryReaderWriter) (galleryID int, isNew bool, err error) {
type GalleryImageAssociator interface {
FindByPath(ctx context.Context, path string) (*models.Gallery, error)
Create(ctx context.Context, newGallery models.Gallery) (*models.Gallery, error)
gallery.ImageUpdater
}
func (t *ScanTask) associateImageWithFolderGallery(ctx context.Context, imageID int, qb GalleryImageAssociator) (galleryID int, isNew bool, err error) {
// find a gallery with the path specified
path := filepath.Dir(t.file.Path())
var g *models.Gallery
g, err = qb.FindByPath(path)
g, err = qb.FindByPath(ctx, path)
if err != nil {
return
}
@ -120,7 +128,7 @@ func (t *ScanTask) associateImageWithFolderGallery(imageID int, qb models.Galler
}
logger.Infof("Creating gallery for folder %s", path)
g, err = qb.Create(newGallery)
g, err = qb.Create(ctx, newGallery)
if err != nil {
return 0, false, err
}
@ -129,7 +137,7 @@ func (t *ScanTask) associateImageWithFolderGallery(imageID int, qb models.Galler
}
// associate image with gallery
err = gallery.AddImage(qb, g.ID, imageID)
err = gallery.AddImage(ctx, qb, g.ID, imageID)
galleryID = g.ID
return
}

View File

@ -34,9 +34,9 @@ func (t *ScanTask) scanScene(ctx context.Context) *models.Scene {
var retScene *models.Scene
var s *models.Scene
if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error
s, err = r.Scene().FindByPath(t.file.Path())
s, err = t.TxnManager.Scene.FindByPath(ctx, t.file.Path())
return err
}); err != nil {
logger.Error(err.Error())
@ -54,7 +54,9 @@ func (t *ScanTask) scanScene(ctx context.Context) *models.Scene {
StripFileExtension: t.StripFileExtension,
FileNamingAlgorithm: t.fileNamingAlgorithm,
TxnManager: t.TxnManager,
CreatorUpdater: t.TxnManager.Scene,
Paths: GetInstance().Paths,
CaseSensitiveFs: t.CaseSensitiveFs,
Screenshotter: &sceneScreenshotter{
g: g,
},
@ -88,12 +90,12 @@ func (t *ScanTask) associateCaptions(ctx context.Context) {
captionLang := scene.GetCaptionsLangFromPath(captionPath)
relatedFiles := scene.GenerateCaptionCandidates(captionPath, vExt)
if err := t.TxnManager.WithTxn(ctx, func(r models.Repository) error {
if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error
sqb := r.Scene()
sqb := t.TxnManager.Scene
for _, scenePath := range relatedFiles {
s, er := sqb.FindByPath(scenePath)
s, er := sqb.FindByPath(ctx, scenePath)
if er != nil {
logger.Errorf("Error searching for scene %s: %v", scenePath, er)
@ -101,7 +103,7 @@ func (t *ScanTask) associateCaptions(ctx context.Context) {
}
if s != nil { // found related Scene
logger.Debugf("Matched captions to scene %s", s.Path)
captions, er := sqb.GetCaptions(s.ID)
captions, er := sqb.GetCaptions(ctx, s.ID)
if er == nil {
fileExt := filepath.Ext(captionPath)
ext := fileExt[1:]
@ -112,7 +114,7 @@ func (t *ScanTask) associateCaptions(ctx context.Context) {
CaptionType: ext,
}
captions = append(captions, newCaption)
er = sqb.UpdateCaptions(s.ID, captions)
er = sqb.UpdateCaptions(ctx, s.ID, captions)
if er == nil {
logger.Debugf("Updated captions for scene %s. Added %s", s.Path, captionLang)
}

View File

@ -10,11 +10,11 @@ import (
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scraper/stashbox"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
type StashBoxPerformerTagTask struct {
txnManager models.TransactionManager
box *models.StashBox
name *string
performer *models.Performer
@ -41,12 +41,17 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) {
var performer *models.ScrapedPerformer
var err error
client := stashbox.NewClient(*t.box, t.txnManager)
client := stashbox.NewClient(*t.box, instance.Repository, stashbox.Repository{
Scene: instance.Repository.Scene,
Performer: instance.Repository.Performer,
Tag: instance.Repository.Tag,
Studio: instance.Repository.Studio,
})
if t.refresh {
var performerID string
txnErr := t.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
stashids, _ := r.Performer().GetStashIDs(t.performer.ID)
txnErr := txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
stashids, _ := instance.Repository.Performer.GetStashIDs(ctx, t.performer.ID)
for _, id := range stashids {
if id.Endpoint == t.box.Endpoint {
performerID = id.StashID
@ -156,11 +161,12 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) {
partial.URL = &value
}
txnErr := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
_, err := r.Performer().Update(partial)
txnErr := txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
r := instance.Repository
_, err := r.Performer.Update(ctx, partial)
if !t.refresh {
err = r.Performer().UpdateStashIDs(t.performer.ID, []models.StashID{
err = r.Performer.UpdateStashIDs(ctx, t.performer.ID, []models.StashID{
{
Endpoint: t.box.Endpoint,
StashID: *performer.RemoteSiteID,
@ -176,7 +182,7 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) {
if err != nil {
return err
}
err = r.Performer().UpdateImage(t.performer.ID, image)
err = r.Performer.UpdateImage(ctx, t.performer.ID, image)
if err != nil {
return err
}
@ -218,13 +224,14 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) {
URL: getNullString(performer.URL),
UpdatedAt: models.SQLiteTimestamp{Timestamp: currentTime},
}
err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
createdPerformer, err := r.Performer().Create(newPerformer)
err := txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
r := instance.Repository
createdPerformer, err := r.Performer.Create(ctx, newPerformer)
if err != nil {
return err
}
err = r.Performer().UpdateStashIDs(createdPerformer.ID, []models.StashID{
err = r.Performer.UpdateStashIDs(ctx, createdPerformer.ID, []models.StashID{
{
Endpoint: t.box.Endpoint,
StashID: *performer.RemoteSiteID,
@ -239,7 +246,7 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) {
if imageErr != nil {
return imageErr
}
err = r.Performer().UpdateImage(createdPerformer.ID, image)
err = r.Performer.UpdateImage(ctx, createdPerformer.ID, image)
}
return err
})

View File

@ -1,40 +0,0 @@
package database
import (
"context"
"github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/logger"
)
// WithTxn executes the provided function within a transaction. It rolls back
// the transaction if the function returns an error, otherwise the transaction
// is committed.
func WithTxn(fn func(tx *sqlx.Tx) error) error {
ctx := context.TODO()
tx := DB.MustBeginTx(ctx, nil)
var err error
defer func() {
if p := recover(); p != nil {
// a panic occurred, rollback and repanic
if err := tx.Rollback(); err != nil {
logger.Warnf("failure when performing transaction rollback: %v", err)
}
panic(p)
}
if err != nil {
// something went wrong, rollback
if err := tx.Rollback(); err != nil {
logger.Warnf("failure when performing transaction rollback: %v", err)
}
} else {
// all good, commit
err = tx.Commit()
}
}()
err = fn(tx)
return err
}

View File

@ -1,9 +1,12 @@
package gallery
import (
"context"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/json"
"github.com/stashapp/stash/pkg/models/jsonschema"
"github.com/stashapp/stash/pkg/studio"
"github.com/stashapp/stash/pkg/utils"
)
@ -52,9 +55,9 @@ func ToBasicJSON(gallery *models.Gallery) (*jsonschema.Gallery, error) {
// GetStudioName returns the name of the provided gallery's studio. It returns an
// empty string if there is no studio assigned to the gallery.
func GetStudioName(reader models.StudioReader, gallery *models.Gallery) (string, error) {
func GetStudioName(ctx context.Context, reader studio.Finder, gallery *models.Gallery) (string, error) {
if gallery.StudioID.Valid {
studio, err := reader.Find(int(gallery.StudioID.Int64))
studio, err := reader.Find(ctx, int(gallery.StudioID.Int64))
if err != nil {
return "", err
}

View File

@ -154,15 +154,15 @@ func TestGetStudioName(t *testing.T) {
studioErr := errors.New("error getting image")
mockStudioReader.On("Find", studioID).Return(&models.Studio{
mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{
Name: models.NullString(studioName),
}, nil).Once()
mockStudioReader.On("Find", missingStudioID).Return(nil, nil).Once()
mockStudioReader.On("Find", errStudioID).Return(nil, studioErr).Once()
mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
for i, s := range getStudioScenarios {
gallery := s.input
json, err := GetStudioName(mockStudioReader, &gallery)
json, err := GetStudioName(testCtx, mockStudioReader, &gallery)
switch {
case !s.err && err != nil:

View File

@ -1,20 +1,30 @@
package gallery
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/jsonschema"
"github.com/stashapp/stash/pkg/performer"
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
"github.com/stashapp/stash/pkg/studio"
"github.com/stashapp/stash/pkg/tag"
)
type FullCreatorUpdater interface {
FinderCreatorUpdater
UpdatePerformers(ctx context.Context, galleryID int, performerIDs []int) error
UpdateTags(ctx context.Context, galleryID int, tagIDs []int) error
}
type Importer struct {
ReaderWriter models.GalleryReaderWriter
StudioWriter models.StudioReaderWriter
PerformerWriter models.PerformerReaderWriter
TagWriter models.TagReaderWriter
ReaderWriter FullCreatorUpdater
StudioWriter studio.NameFinderCreator
PerformerWriter performer.NameFinderCreator
TagWriter tag.NameFinderCreator
Input jsonschema.Gallery
MissingRefBehaviour models.ImportMissingRefEnum
@ -23,18 +33,18 @@ type Importer struct {
tags []*models.Tag
}
func (i *Importer) PreImport() error {
func (i *Importer) PreImport(ctx context.Context) error {
i.gallery = i.galleryJSONToGallery(i.Input)
if err := i.populateStudio(); err != nil {
if err := i.populateStudio(ctx); err != nil {
return err
}
if err := i.populatePerformers(); err != nil {
if err := i.populatePerformers(ctx); err != nil {
return err
}
if err := i.populateTags(); err != nil {
if err := i.populateTags(ctx); err != nil {
return err
}
@ -74,9 +84,9 @@ func (i *Importer) galleryJSONToGallery(galleryJSON jsonschema.Gallery) models.G
return newGallery
}
func (i *Importer) populateStudio() error {
func (i *Importer) populateStudio(ctx context.Context) error {
if i.Input.Studio != "" {
studio, err := i.StudioWriter.FindByName(i.Input.Studio, false)
studio, err := i.StudioWriter.FindByName(ctx, i.Input.Studio, false)
if err != nil {
return fmt.Errorf("error finding studio by name: %v", err)
}
@ -91,7 +101,7 @@ func (i *Importer) populateStudio() error {
}
if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate {
studioID, err := i.createStudio(i.Input.Studio)
studioID, err := i.createStudio(ctx, i.Input.Studio)
if err != nil {
return err
}
@ -108,10 +118,10 @@ func (i *Importer) populateStudio() error {
return nil
}
func (i *Importer) createStudio(name string) (int, error) {
func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
newStudio := *models.NewStudio(name)
created, err := i.StudioWriter.Create(newStudio)
created, err := i.StudioWriter.Create(ctx, newStudio)
if err != nil {
return 0, err
}
@ -119,10 +129,10 @@ func (i *Importer) createStudio(name string) (int, error) {
return created.ID, nil
}
func (i *Importer) populatePerformers() error {
func (i *Importer) populatePerformers(ctx context.Context) error {
if len(i.Input.Performers) > 0 {
names := i.Input.Performers
performers, err := i.PerformerWriter.FindByNames(names, false)
performers, err := i.PerformerWriter.FindByNames(ctx, names, false)
if err != nil {
return err
}
@ -145,7 +155,7 @@ func (i *Importer) populatePerformers() error {
}
if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate {
createdPerformers, err := i.createPerformers(missingPerformers)
createdPerformers, err := i.createPerformers(ctx, missingPerformers)
if err != nil {
return fmt.Errorf("error creating gallery performers: %v", err)
}
@ -162,12 +172,12 @@ func (i *Importer) populatePerformers() error {
return nil
}
func (i *Importer) createPerformers(names []string) ([]*models.Performer, error) {
func (i *Importer) createPerformers(ctx context.Context, names []string) ([]*models.Performer, error) {
var ret []*models.Performer
for _, name := range names {
newPerformer := *models.NewPerformer(name)
created, err := i.PerformerWriter.Create(newPerformer)
created, err := i.PerformerWriter.Create(ctx, newPerformer)
if err != nil {
return nil, err
}
@ -178,10 +188,10 @@ func (i *Importer) createPerformers(names []string) ([]*models.Performer, error)
return ret, nil
}
func (i *Importer) populateTags() error {
func (i *Importer) populateTags(ctx context.Context) error {
if len(i.Input.Tags) > 0 {
names := i.Input.Tags
tags, err := i.TagWriter.FindByNames(names, false)
tags, err := i.TagWriter.FindByNames(ctx, names, false)
if err != nil {
return err
}
@ -201,7 +211,7 @@ func (i *Importer) populateTags() error {
}
if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate {
createdTags, err := i.createTags(missingTags)
createdTags, err := i.createTags(ctx, missingTags)
if err != nil {
return fmt.Errorf("error creating gallery tags: %v", err)
}
@ -218,12 +228,12 @@ func (i *Importer) populateTags() error {
return nil
}
func (i *Importer) createTags(names []string) ([]*models.Tag, error) {
func (i *Importer) createTags(ctx context.Context, names []string) ([]*models.Tag, error) {
var ret []*models.Tag
for _, name := range names {
newTag := *models.NewTag(name)
created, err := i.TagWriter.Create(newTag)
created, err := i.TagWriter.Create(ctx, newTag)
if err != nil {
return nil, err
}
@ -234,14 +244,14 @@ func (i *Importer) createTags(names []string) ([]*models.Tag, error) {
return ret, nil
}
func (i *Importer) PostImport(id int) error {
func (i *Importer) PostImport(ctx context.Context, id int) error {
if len(i.performers) > 0 {
var performerIDs []int
for _, performer := range i.performers {
performerIDs = append(performerIDs, performer.ID)
}
if err := i.ReaderWriter.UpdatePerformers(id, performerIDs); err != nil {
if err := i.ReaderWriter.UpdatePerformers(ctx, id, performerIDs); err != nil {
return fmt.Errorf("failed to associate performers: %v", err)
}
}
@ -251,7 +261,7 @@ func (i *Importer) PostImport(id int) error {
for _, t := range i.tags {
tagIDs = append(tagIDs, t.ID)
}
if err := i.ReaderWriter.UpdateTags(id, tagIDs); err != nil {
if err := i.ReaderWriter.UpdateTags(ctx, id, tagIDs); err != nil {
return fmt.Errorf("failed to associate tags: %v", err)
}
}
@ -263,8 +273,8 @@ func (i *Importer) Name() string {
return i.Input.Path
}
func (i *Importer) FindExistingID() (*int, error) {
existing, err := i.ReaderWriter.FindByChecksum(i.Input.Checksum)
func (i *Importer) FindExistingID(ctx context.Context) (*int, error) {
existing, err := i.ReaderWriter.FindByChecksum(ctx, i.Input.Checksum)
if err != nil {
return nil, err
}
@ -277,8 +287,8 @@ func (i *Importer) FindExistingID() (*int, error) {
return nil, nil
}
func (i *Importer) Create() (*int, error) {
created, err := i.ReaderWriter.Create(i.gallery)
func (i *Importer) Create(ctx context.Context) (*int, error) {
created, err := i.ReaderWriter.Create(ctx, i.gallery)
if err != nil {
return nil, fmt.Errorf("error creating gallery: %v", err)
}
@ -287,10 +297,10 @@ func (i *Importer) Create() (*int, error) {
return &id, nil
}
func (i *Importer) Update(id int) error {
func (i *Importer) Update(ctx context.Context, id int) error {
gallery := i.gallery
gallery.ID = id
_, err := i.ReaderWriter.Update(gallery)
_, err := i.ReaderWriter.Update(ctx, gallery)
if err != nil {
return fmt.Errorf("error updating existing gallery: %v", err)
}

View File

@ -1,6 +1,7 @@
package gallery
import (
"context"
"errors"
"testing"
"time"
@ -40,6 +41,8 @@ const (
errChecksum = "errChecksum"
)
var testCtx = context.Background()
var (
createdAt = time.Date(2001, time.January, 2, 1, 2, 3, 4, time.Local)
updatedAt = time.Date(2002, time.January, 2, 1, 2, 3, 4, time.Local)
@ -75,7 +78,7 @@ func TestImporterPreImport(t *testing.T) {
},
}
err := i.PreImport()
err := i.PreImport(testCtx)
assert.Nil(t, err)
expectedGallery := models.Gallery{
@ -112,17 +115,17 @@ func TestImporterPreImportWithStudio(t *testing.T) {
},
}
studioReaderWriter.On("FindByName", existingStudioName, false).Return(&models.Studio{
studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
ID: existingStudioID,
}, nil).Once()
studioReaderWriter.On("FindByName", existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
err := i.PreImport()
err := i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, int64(existingStudioID), i.gallery.StudioID.Int64)
i.Input.Studio = existingStudioErr
err = i.PreImport()
err = i.PreImport(testCtx)
assert.NotNil(t, err)
studioReaderWriter.AssertExpectations(t)
@ -140,20 +143,20 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
studioReaderWriter.On("FindByName", missingStudioName, false).Return(nil, nil).Times(3)
studioReaderWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(&models.Studio{
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Studio")).Return(&models.Studio{
ID: existingStudioID,
}, nil)
err := i.PreImport()
err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
err = i.PreImport()
err = i.PreImport(testCtx)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
err = i.PreImport()
err = i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, int64(existingStudioID), i.gallery.StudioID.Int64)
@ -172,10 +175,10 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
studioReaderWriter.On("FindByName", missingStudioName, false).Return(nil, nil).Once()
studioReaderWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error"))
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error"))
err := i.PreImport()
err := i.PreImport(testCtx)
assert.NotNil(t, err)
}
@ -193,20 +196,20 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
},
}
performerReaderWriter.On("FindByNames", []string{existingPerformerName}, false).Return([]*models.Performer{
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
{
ID: existingPerformerID,
Name: models.NullString(existingPerformerName),
},
}, nil).Once()
performerReaderWriter.On("FindByNames", []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport()
err := i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingPerformerID, i.performers[0].ID)
i.Input.Performers = []string{existingPerformerErr}
err = i.PreImport()
err = i.PreImport(testCtx)
assert.NotNil(t, err)
performerReaderWriter.AssertExpectations(t)
@ -226,20 +229,20 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
performerReaderWriter.On("FindByNames", []string{missingPerformerName}, false).Return(nil, nil).Times(3)
performerReaderWriter.On("Create", mock.AnythingOfType("models.Performer")).Return(&models.Performer{
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Performer")).Return(&models.Performer{
ID: existingPerformerID,
}, nil)
err := i.PreImport()
err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
err = i.PreImport()
err = i.PreImport(testCtx)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
err = i.PreImport()
err = i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingPerformerID, i.performers[0].ID)
@ -260,10 +263,10 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
performerReaderWriter.On("FindByNames", []string{missingPerformerName}, false).Return(nil, nil).Once()
performerReaderWriter.On("Create", mock.AnythingOfType("models.Performer")).Return(nil, errors.New("Create error"))
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Performer")).Return(nil, errors.New("Create error"))
err := i.PreImport()
err := i.PreImport(testCtx)
assert.NotNil(t, err)
}
@ -281,20 +284,20 @@ func TestImporterPreImportWithTag(t *testing.T) {
},
}
tagReaderWriter.On("FindByNames", []string{existingTagName}, false).Return([]*models.Tag{
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
{
ID: existingTagID,
Name: existingTagName,
},
}, nil).Once()
tagReaderWriter.On("FindByNames", []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport()
err := i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingTagID, i.tags[0].ID)
i.Input.Tags = []string{existingTagErr}
err = i.PreImport()
err = i.PreImport(testCtx)
assert.NotNil(t, err)
tagReaderWriter.AssertExpectations(t)
@ -314,20 +317,20 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
tagReaderWriter.On("FindByNames", []string{missingTagName}, false).Return(nil, nil).Times(3)
tagReaderWriter.On("Create", mock.AnythingOfType("models.Tag")).Return(&models.Tag{
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(&models.Tag{
ID: existingTagID,
}, nil)
err := i.PreImport()
err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
err = i.PreImport()
err = i.PreImport(testCtx)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
err = i.PreImport()
err = i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingTagID, i.tags[0].ID)
@ -348,10 +351,10 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
tagReaderWriter.On("FindByNames", []string{missingTagName}, false).Return(nil, nil).Once()
tagReaderWriter.On("Create", mock.AnythingOfType("models.Tag")).Return(nil, errors.New("Create error"))
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(nil, errors.New("Create error"))
err := i.PreImport()
err := i.PreImport(testCtx)
assert.NotNil(t, err)
}
@ -369,13 +372,13 @@ func TestImporterPostImportUpdatePerformers(t *testing.T) {
updateErr := errors.New("UpdatePerformers error")
galleryReaderWriter.On("UpdatePerformers", galleryID, []int{existingPerformerID}).Return(nil).Once()
galleryReaderWriter.On("UpdatePerformers", errPerformersID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
galleryReaderWriter.On("UpdatePerformers", testCtx, galleryID, []int{existingPerformerID}).Return(nil).Once()
galleryReaderWriter.On("UpdatePerformers", testCtx, errPerformersID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
err := i.PostImport(galleryID)
err := i.PostImport(testCtx, galleryID)
assert.Nil(t, err)
err = i.PostImport(errPerformersID)
err = i.PostImport(testCtx, errPerformersID)
assert.NotNil(t, err)
galleryReaderWriter.AssertExpectations(t)
@ -395,13 +398,13 @@ func TestImporterPostImportUpdateTags(t *testing.T) {
updateErr := errors.New("UpdateTags error")
galleryReaderWriter.On("UpdateTags", galleryID, []int{existingTagID}).Return(nil).Once()
galleryReaderWriter.On("UpdateTags", errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
galleryReaderWriter.On("UpdateTags", testCtx, galleryID, []int{existingTagID}).Return(nil).Once()
galleryReaderWriter.On("UpdateTags", testCtx, errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
err := i.PostImport(galleryID)
err := i.PostImport(testCtx, galleryID)
assert.Nil(t, err)
err = i.PostImport(errTagsID)
err = i.PostImport(testCtx, errTagsID)
assert.NotNil(t, err)
galleryReaderWriter.AssertExpectations(t)
@ -419,23 +422,23 @@ func TestImporterFindExistingID(t *testing.T) {
}
expectedErr := errors.New("FindBy* error")
readerWriter.On("FindByChecksum", missingChecksum).Return(nil, nil).Once()
readerWriter.On("FindByChecksum", checksum).Return(&models.Gallery{
readerWriter.On("FindByChecksum", testCtx, missingChecksum).Return(nil, nil).Once()
readerWriter.On("FindByChecksum", testCtx, checksum).Return(&models.Gallery{
ID: existingGalleryID,
}, nil).Once()
readerWriter.On("FindByChecksum", errChecksum).Return(nil, expectedErr).Once()
readerWriter.On("FindByChecksum", testCtx, errChecksum).Return(nil, expectedErr).Once()
id, err := i.FindExistingID()
id, err := i.FindExistingID(testCtx)
assert.Nil(t, id)
assert.Nil(t, err)
i.Input.Checksum = checksum
id, err = i.FindExistingID()
id, err = i.FindExistingID(testCtx)
assert.Equal(t, existingGalleryID, *id)
assert.Nil(t, err)
i.Input.Checksum = errChecksum
id, err = i.FindExistingID()
id, err = i.FindExistingID(testCtx)
assert.Nil(t, id)
assert.NotNil(t, err)
@ -459,17 +462,17 @@ func TestCreate(t *testing.T) {
}
errCreate := errors.New("Create error")
readerWriter.On("Create", gallery).Return(&models.Gallery{
readerWriter.On("Create", testCtx, gallery).Return(&models.Gallery{
ID: galleryID,
}, nil).Once()
readerWriter.On("Create", galleryErr).Return(nil, errCreate).Once()
readerWriter.On("Create", testCtx, galleryErr).Return(nil, errCreate).Once()
id, err := i.Create()
id, err := i.Create(testCtx)
assert.Equal(t, galleryID, *id)
assert.Nil(t, err)
i.gallery = galleryErr
id, err = i.Create()
id, err = i.Create(testCtx)
assert.Nil(t, id)
assert.NotNil(t, err)
@ -490,9 +493,9 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input
gallery.ID = galleryID
readerWriter.On("Update", gallery).Return(nil, nil).Once()
readerWriter.On("Update", testCtx, gallery).Return(nil, nil).Once()
err := i.Update(galleryID)
err := i.Update(testCtx, galleryID)
assert.Nil(t, err)
readerWriter.AssertExpectations(t)

View File

@ -1,12 +1,25 @@
package gallery
import (
"context"
"strconv"
"github.com/stashapp/stash/pkg/models"
)
func CountByPerformerID(r models.GalleryReader, id int) (int, error) {
type Queryer interface {
Query(ctx context.Context, galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) ([]*models.Gallery, int, error)
}
type CountQueryer interface {
QueryCount(ctx context.Context, galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) (int, error)
}
type ChecksumsFinder interface {
FindByChecksums(ctx context.Context, checksums []string) ([]*models.Gallery, error)
}
func CountByPerformerID(ctx context.Context, r CountQueryer, id int) (int, error) {
filter := &models.GalleryFilterType{
Performers: &models.MultiCriterionInput{
Value: []string{strconv.Itoa(id)},
@ -14,10 +27,10 @@ func CountByPerformerID(r models.GalleryReader, id int) (int, error) {
},
}
return r.QueryCount(filter, nil)
return r.QueryCount(ctx, filter, nil)
}
func CountByStudioID(r models.GalleryReader, id int) (int, error) {
func CountByStudioID(ctx context.Context, r CountQueryer, id int) (int, error) {
filter := &models.GalleryFilterType{
Studios: &models.HierarchicalMultiCriterionInput{
Value: []string{strconv.Itoa(id)},
@ -25,10 +38,10 @@ func CountByStudioID(r models.GalleryReader, id int) (int, error) {
},
}
return r.QueryCount(filter, nil)
return r.QueryCount(ctx, filter, nil)
}
func CountByTagID(r models.GalleryReader, id int) (int, error) {
func CountByTagID(ctx context.Context, r CountQueryer, id int) (int, error) {
filter := &models.GalleryFilterType{
Tags: &models.HierarchicalMultiCriterionInput{
Value: []string{strconv.Itoa(id)},
@ -36,5 +49,5 @@ func CountByTagID(r models.GalleryReader, id int) (int, error) {
},
}
return r.QueryCount(filter, nil)
return r.QueryCount(ctx, filter, nil)
}

View File

@ -14,18 +14,26 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/paths"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
const mutexType = "gallery"
type FinderCreatorUpdater interface {
FindByChecksum(ctx context.Context, checksum string) (*models.Gallery, error)
Create(ctx context.Context, newGallery models.Gallery) (*models.Gallery, error)
Update(ctx context.Context, updatedGallery models.Gallery) (*models.Gallery, error)
}
type Scanner struct {
file.Scanner
ImageExtensions []string
StripFileExtension bool
CaseSensitiveFs bool
TxnManager models.TransactionManager
TxnManager txn.Manager
CreatorUpdater FinderCreatorUpdater
Paths *paths.Paths
PluginCache *plugin.Cache
MutexManager *utils.MutexManager
@ -75,19 +83,19 @@ func (scanner *Scanner) ScanExisting(ctx context.Context, existing file.FileBase
done := make(chan struct{})
scanner.MutexManager.Claim(mutexType, scanned.New.Checksum, done)
if err := scanner.TxnManager.WithTxn(ctx, func(r models.Repository) error {
if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error {
// free the mutex once transaction is complete
defer close(done)
// ensure no clashes of hashes
if scanned.New.Checksum != "" && scanned.Old.Checksum != scanned.New.Checksum {
dupe, _ := r.Gallery().FindByChecksum(retGallery.Checksum)
dupe, _ := scanner.CreatorUpdater.FindByChecksum(ctx, retGallery.Checksum)
if dupe != nil {
return fmt.Errorf("MD5 for file %s is the same as that of %s", path, dupe.Path.String)
}
}
retGallery, err = r.Gallery().Update(*retGallery)
retGallery, err = scanner.CreatorUpdater.Update(ctx, *retGallery)
return err
}); err != nil {
return nil, false, err
@ -116,10 +124,10 @@ func (scanner *Scanner) ScanNew(ctx context.Context, file file.SourceFile) (retG
scanner.MutexManager.Claim(mutexType, checksum, done)
defer close(done)
if err := scanner.TxnManager.WithTxn(ctx, func(r models.Repository) error {
qb := r.Gallery()
if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error {
qb := scanner.CreatorUpdater
g, _ = qb.FindByChecksum(checksum)
g, _ = qb.FindByChecksum(ctx, checksum)
if g != nil {
exists, _ := fsutil.FileExists(g.Path.String)
if !scanner.CaseSensitiveFs {
@ -138,7 +146,7 @@ func (scanner *Scanner) ScanNew(ctx context.Context, file file.SourceFile) (retG
String: path,
Valid: true,
}
g, err = qb.Update(*g)
g, err = qb.Update(ctx, *g)
if err != nil {
return err
}
@ -167,7 +175,7 @@ func (scanner *Scanner) ScanNew(ctx context.Context, file file.SourceFile) (retG
}
logger.Infof("%s doesn't exist. Creating new item...", path)
g, err = qb.Create(*g)
g, err = qb.Create(ctx, *g)
if err != nil {
return err
}

View File

@ -1,29 +1,50 @@
package gallery
import (
"context"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/sliceutil/intslice"
)
func UpdateFileModTime(qb models.GalleryWriter, id int, modTime models.NullSQLiteTimestamp) (*models.Gallery, error) {
return qb.UpdatePartial(models.GalleryPartial{
type PartialUpdater interface {
UpdatePartial(ctx context.Context, updatedGallery models.GalleryPartial) (*models.Gallery, error)
}
type ImageUpdater interface {
GetImageIDs(ctx context.Context, galleryID int) ([]int, error)
UpdateImages(ctx context.Context, galleryID int, imageIDs []int) error
}
type PerformerUpdater interface {
GetPerformerIDs(ctx context.Context, galleryID int) ([]int, error)
UpdatePerformers(ctx context.Context, galleryID int, performerIDs []int) error
}
type TagUpdater interface {
GetTagIDs(ctx context.Context, galleryID int) ([]int, error)
UpdateTags(ctx context.Context, galleryID int, tagIDs []int) error
}
func UpdateFileModTime(ctx context.Context, qb PartialUpdater, id int, modTime models.NullSQLiteTimestamp) (*models.Gallery, error) {
return qb.UpdatePartial(ctx, models.GalleryPartial{
ID: id,
FileModTime: &modTime,
})
}
func AddImage(qb models.GalleryReaderWriter, galleryID int, imageID int) error {
imageIDs, err := qb.GetImageIDs(galleryID)
func AddImage(ctx context.Context, qb ImageUpdater, galleryID int, imageID int) error {
imageIDs, err := qb.GetImageIDs(ctx, galleryID)
if err != nil {
return err
}
imageIDs = intslice.IntAppendUnique(imageIDs, imageID)
return qb.UpdateImages(galleryID, imageIDs)
return qb.UpdateImages(ctx, galleryID, imageIDs)
}
func AddPerformer(qb models.GalleryReaderWriter, id int, performerID int) (bool, error) {
performerIDs, err := qb.GetPerformerIDs(id)
func AddPerformer(ctx context.Context, qb PerformerUpdater, id int, performerID int) (bool, error) {
performerIDs, err := qb.GetPerformerIDs(ctx, id)
if err != nil {
return false, err
}
@ -32,7 +53,7 @@ func AddPerformer(qb models.GalleryReaderWriter, id int, performerID int) (bool,
performerIDs = intslice.IntAppendUnique(performerIDs, performerID)
if len(performerIDs) != oldLen {
if err := qb.UpdatePerformers(id, performerIDs); err != nil {
if err := qb.UpdatePerformers(ctx, id, performerIDs); err != nil {
return false, err
}
@ -42,8 +63,8 @@ func AddPerformer(qb models.GalleryReaderWriter, id int, performerID int) (bool,
return false, nil
}
func AddTag(qb models.GalleryReaderWriter, id int, tagID int) (bool, error) {
tagIDs, err := qb.GetTagIDs(id)
func AddTag(ctx context.Context, qb TagUpdater, id int, tagID int) (bool, error) {
tagIDs, err := qb.GetTagIDs(ctx, id)
if err != nil {
return false, err
}
@ -52,7 +73,7 @@ func AddTag(qb models.GalleryReaderWriter, id int, tagID int) (bool, error) {
tagIDs = intslice.IntAppendUnique(tagIDs, tagID)
if len(tagIDs) != oldLen {
if err := qb.UpdateTags(id, tagIDs); err != nil {
if err := qb.UpdateTags(ctx, id, tagIDs); err != nil {
return false, err
}

View File

@ -1,6 +1,8 @@
package image
import (
"context"
"github.com/stashapp/stash/pkg/file"
"github.com/stashapp/stash/pkg/fsutil"
"github.com/stashapp/stash/pkg/models"
@ -8,7 +10,7 @@ import (
)
type Destroyer interface {
Destroy(id int) error
Destroy(ctx context.Context, id int) error
}
// FileDeleter is an extension of file.Deleter that handles deletion of image files.
@ -30,7 +32,7 @@ func (d *FileDeleter) MarkGeneratedFiles(image *models.Image) error {
}
// Destroy destroys an image, optionally marking the file and generated files for deletion.
func Destroy(i *models.Image, destroyer Destroyer, fileDeleter *FileDeleter, deleteGenerated, deleteFile bool) error {
func Destroy(ctx context.Context, i *models.Image, destroyer Destroyer, fileDeleter *FileDeleter, deleteGenerated, deleteFile bool) error {
// don't try to delete if the image is in a zip file
if deleteFile && !file.IsZipPath(i.Path) {
if err := fileDeleter.Files([]string{i.Path}); err != nil {
@ -44,5 +46,5 @@ func Destroy(i *models.Image, destroyer Destroyer, fileDeleter *FileDeleter, del
}
}
return destroyer.Destroy(i.ID)
return destroyer.Destroy(ctx, i.ID)
}

View File

@ -1,9 +1,12 @@
package image
import (
"context"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/json"
"github.com/stashapp/stash/pkg/models/jsonschema"
"github.com/stashapp/stash/pkg/studio"
)
// ToBasicJSON converts a image object into its JSON object equivalent. It
@ -56,9 +59,9 @@ func getImageFileJSON(image *models.Image) *jsonschema.ImageFile {
// GetStudioName returns the name of the provided image's studio. It returns an
// empty string if there is no studio assigned to the image.
func GetStudioName(reader models.StudioReader, image *models.Image) (string, error) {
func GetStudioName(ctx context.Context, reader studio.Finder, image *models.Image) (string, error) {
if image.StudioID.Valid {
studio, err := reader.Find(int(image.StudioID.Int64))
studio, err := reader.Find(ctx, int(image.StudioID.Int64))
if err != nil {
return "", err
}

Some files were not shown because too many files have changed in this diff Show More