diff --git a/internal/api/resolver.go b/internal/api/resolver.go index 8da246e26..50adea9ad 100644 --- a/internal/api/resolver.go +++ b/internal/api/resolver.go @@ -11,7 +11,7 @@ import ( "github.com/stashapp/stash/internal/manager" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/plugin" + "github.com/stashapp/stash/pkg/plugin/hook" "github.com/stashapp/stash/pkg/scraper" "github.com/stashapp/stash/pkg/scraper/stashbox" ) @@ -29,7 +29,7 @@ var ( ) type hookExecutor interface { - ExecutePostHooks(ctx context.Context, id int, hookType plugin.HookTriggerEnum, input interface{}, inputFields []string) + ExecutePostHooks(ctx context.Context, id int, hookType hook.TriggerEnum, input interface{}, inputFields []string) } type Resolver struct { diff --git a/internal/api/resolver_mutation_gallery.go b/internal/api/resolver_mutation_gallery.go index a72531152..2df6f1b77 100644 --- a/internal/api/resolver_mutation_gallery.go +++ b/internal/api/resolver_mutation_gallery.go @@ -13,6 +13,7 @@ import ( "github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/plugin" + "github.com/stashapp/stash/pkg/plugin/hook" "github.com/stashapp/stash/pkg/sliceutil/stringslice" "github.com/stashapp/stash/pkg/utils" ) @@ -90,7 +91,7 @@ func (r *mutationResolver) GalleryCreate(ctx context.Context, input GalleryCreat return nil, err } - r.hookExecutor.ExecutePostHooks(ctx, newGallery.ID, plugin.GalleryCreatePost, input, nil) + r.hookExecutor.ExecutePostHooks(ctx, newGallery.ID, hook.GalleryCreatePost, input, nil) return r.getGallery(ctx, newGallery.ID) } @@ -108,7 +109,7 @@ func (r *mutationResolver) GalleryUpdate(ctx context.Context, input models.Galle } // execute post hooks outside txn - r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.GalleryUpdatePost, input, translator.getFields()) + r.hookExecutor.ExecutePostHooks(ctx, ret.ID, hook.GalleryUpdatePost, input, translator.getFields()) return r.getGallery(ctx, ret.ID) } @@ -142,7 +143,7 @@ func (r *mutationResolver) GalleriesUpdate(ctx context.Context, input []*models. inputMap: inputMaps[i], } - r.hookExecutor.ExecutePostHooks(ctx, gallery.ID, plugin.GalleryUpdatePost, input, translator.getFields()) + r.hookExecutor.ExecutePostHooks(ctx, gallery.ID, hook.GalleryUpdatePost, input, translator.getFields()) gallery, err = r.getGallery(ctx, gallery.ID) if err != nil { @@ -313,7 +314,7 @@ func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input BulkGall // execute post hooks outside of txn var newRet []*models.Gallery for _, gallery := range ret { - r.hookExecutor.ExecutePostHooks(ctx, gallery.ID, plugin.GalleryUpdatePost, input, translator.getFields()) + r.hookExecutor.ExecutePostHooks(ctx, gallery.ID, hook.GalleryUpdatePost, input, translator.getFields()) gallery, err := r.getGallery(ctx, gallery.ID) if err != nil { @@ -386,9 +387,9 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall } } - // call post hook after performing the other actions + // call post hook after performing the other actionsa for _, gallery := range galleries { - r.hookExecutor.ExecutePostHooks(ctx, gallery.ID, plugin.GalleryDestroyPost, plugin.GalleryDestroyInput{ + r.hookExecutor.ExecutePostHooks(ctx, gallery.ID, hook.GalleryDestroyPost, plugin.GalleryDestroyInput{ GalleryDestroyInput: input, Checksum: gallery.PrimaryChecksum(), Path: gallery.Path, @@ -397,7 +398,7 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall // call image destroy post hook as well for _, img := range imgsDestroyed { - r.hookExecutor.ExecutePostHooks(ctx, img.ID, plugin.ImageDestroyPost, plugin.ImageDestroyInput{ + r.hookExecutor.ExecutePostHooks(ctx, img.ID, hook.ImageDestroyPost, plugin.ImageDestroyInput{ Checksum: img.Checksum, Path: img.Path, }, nil) @@ -518,7 +519,7 @@ func (r *mutationResolver) GalleryChapterCreate(ctx context.Context, input Galle return nil, err } - r.hookExecutor.ExecutePostHooks(ctx, newChapter.ID, plugin.GalleryChapterCreatePost, input, nil) + r.hookExecutor.ExecutePostHooks(ctx, newChapter.ID, hook.GalleryChapterCreatePost, input, nil) return r.getGalleryChapter(ctx, newChapter.ID) } @@ -584,7 +585,7 @@ func (r *mutationResolver) GalleryChapterUpdate(ctx context.Context, input Galle return nil, err } - r.hookExecutor.ExecutePostHooks(ctx, chapterID, plugin.GalleryChapterUpdatePost, input, translator.getFields()) + r.hookExecutor.ExecutePostHooks(ctx, chapterID, hook.GalleryChapterUpdatePost, input, translator.getFields()) return r.getGalleryChapter(ctx, chapterID) } @@ -612,7 +613,7 @@ func (r *mutationResolver) GalleryChapterDestroy(ctx context.Context, id string) return false, err } - r.hookExecutor.ExecutePostHooks(ctx, chapterID, plugin.GalleryChapterDestroyPost, id, nil) + r.hookExecutor.ExecutePostHooks(ctx, chapterID, hook.GalleryChapterDestroyPost, id, nil) return true, nil } diff --git a/internal/api/resolver_mutation_image.go b/internal/api/resolver_mutation_image.go index 0cd5d3487..fc337c464 100644 --- a/internal/api/resolver_mutation_image.go +++ b/internal/api/resolver_mutation_image.go @@ -10,6 +10,7 @@ import ( "github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/plugin" + "github.com/stashapp/stash/pkg/plugin/hook" "github.com/stashapp/stash/pkg/sliceutil" "github.com/stashapp/stash/pkg/sliceutil/stringslice" "github.com/stashapp/stash/pkg/utils" @@ -41,7 +42,7 @@ func (r *mutationResolver) ImageUpdate(ctx context.Context, input ImageUpdateInp } // execute post hooks outside txn - r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.ImageUpdatePost, input, translator.getFields()) + r.hookExecutor.ExecutePostHooks(ctx, ret.ID, hook.ImageUpdatePost, input, translator.getFields()) return r.getImage(ctx, ret.ID) } @@ -75,7 +76,7 @@ func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*ImageUpdat inputMap: inputMaps[i], } - r.hookExecutor.ExecutePostHooks(ctx, image.ID, plugin.ImageUpdatePost, input, translator.getFields()) + r.hookExecutor.ExecutePostHooks(ctx, image.ID, hook.ImageUpdatePost, input, translator.getFields()) image, err = r.getImage(ctx, image.ID) if err != nil { @@ -288,7 +289,7 @@ func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input BulkImageU // execute post hooks outside of txn var newRet []*models.Image for _, image := range ret { - r.hookExecutor.ExecutePostHooks(ctx, image.ID, plugin.ImageUpdatePost, input, translator.getFields()) + r.hookExecutor.ExecutePostHooks(ctx, image.ID, hook.ImageUpdatePost, input, translator.getFields()) image, err = r.getImage(ctx, image.ID) if err != nil { @@ -332,7 +333,7 @@ func (r *mutationResolver) ImageDestroy(ctx context.Context, input models.ImageD fileDeleter.Commit() // call post hook after performing the other actions - r.hookExecutor.ExecutePostHooks(ctx, i.ID, plugin.ImageDestroyPost, plugin.ImageDestroyInput{ + r.hookExecutor.ExecutePostHooks(ctx, i.ID, hook.ImageDestroyPost, plugin.ImageDestroyInput{ ImageDestroyInput: input, Checksum: i.Checksum, Path: i.Path, @@ -383,7 +384,7 @@ func (r *mutationResolver) ImagesDestroy(ctx context.Context, input models.Image for _, image := range images { // call post hook after performing the other actions - r.hookExecutor.ExecutePostHooks(ctx, image.ID, plugin.ImageDestroyPost, plugin.ImagesDestroyInput{ + r.hookExecutor.ExecutePostHooks(ctx, image.ID, hook.ImageDestroyPost, plugin.ImagesDestroyInput{ ImagesDestroyInput: input, Checksum: image.Checksum, Path: image.Path, diff --git a/internal/api/resolver_mutation_movie.go b/internal/api/resolver_mutation_movie.go index 227c3eaa7..cb4474654 100644 --- a/internal/api/resolver_mutation_movie.go +++ b/internal/api/resolver_mutation_movie.go @@ -7,7 +7,7 @@ import ( "github.com/stashapp/stash/internal/static" "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/plugin" + "github.com/stashapp/stash/pkg/plugin/hook" "github.com/stashapp/stash/pkg/sliceutil/stringslice" "github.com/stashapp/stash/pkg/utils" ) @@ -102,7 +102,7 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input MovieCreateInp return nil, err } - r.hookExecutor.ExecutePostHooks(ctx, newMovie.ID, plugin.MovieCreatePost, input, nil) + r.hookExecutor.ExecutePostHooks(ctx, newMovie.ID, hook.MovieCreatePost, input, nil) return r.getMovie(ctx, newMovie.ID) } @@ -181,7 +181,7 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input MovieUpdateInp return nil, err } - r.hookExecutor.ExecutePostHooks(ctx, movie.ID, plugin.MovieUpdatePost, input, translator.getFields()) + r.hookExecutor.ExecutePostHooks(ctx, movie.ID, hook.MovieUpdatePost, input, translator.getFields()) return r.getMovie(ctx, movie.ID) } @@ -227,7 +227,7 @@ func (r *mutationResolver) BulkMovieUpdate(ctx context.Context, input BulkMovieU var newRet []*models.Movie for _, movie := range ret { - r.hookExecutor.ExecutePostHooks(ctx, movie.ID, plugin.MovieUpdatePost, input, translator.getFields()) + r.hookExecutor.ExecutePostHooks(ctx, movie.ID, hook.MovieUpdatePost, input, translator.getFields()) movie, err = r.getMovie(ctx, movie.ID) if err != nil { @@ -252,7 +252,7 @@ func (r *mutationResolver) MovieDestroy(ctx context.Context, input MovieDestroyI return false, err } - r.hookExecutor.ExecutePostHooks(ctx, id, plugin.MovieDestroyPost, input, nil) + r.hookExecutor.ExecutePostHooks(ctx, id, hook.MovieDestroyPost, input, nil) return true, nil } @@ -277,7 +277,7 @@ func (r *mutationResolver) MoviesDestroy(ctx context.Context, movieIDs []string) } for _, id := range ids { - r.hookExecutor.ExecutePostHooks(ctx, id, plugin.MovieDestroyPost, movieIDs, nil) + r.hookExecutor.ExecutePostHooks(ctx, id, hook.MovieDestroyPost, movieIDs, nil) } return true, nil diff --git a/internal/api/resolver_mutation_performer.go b/internal/api/resolver_mutation_performer.go index 17c87dd15..202778e74 100644 --- a/internal/api/resolver_mutation_performer.go +++ b/internal/api/resolver_mutation_performer.go @@ -7,7 +7,7 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/performer" - "github.com/stashapp/stash/pkg/plugin" + "github.com/stashapp/stash/pkg/plugin/hook" "github.com/stashapp/stash/pkg/sliceutil/stringslice" "github.com/stashapp/stash/pkg/utils" ) @@ -108,7 +108,7 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input models.Per return nil, err } - r.hookExecutor.ExecutePostHooks(ctx, newPerformer.ID, plugin.PerformerCreatePost, input, nil) + r.hookExecutor.ExecutePostHooks(ctx, newPerformer.ID, hook.PerformerCreatePost, input, nil) return r.getPerformer(ctx, newPerformer.ID) } @@ -207,7 +207,7 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.Per return nil, err } - r.hookExecutor.ExecutePostHooks(ctx, performerID, plugin.PerformerUpdatePost, input, translator.getFields()) + r.hookExecutor.ExecutePostHooks(ctx, performerID, hook.PerformerUpdatePost, input, translator.getFields()) return r.getPerformer(ctx, performerID) } @@ -297,7 +297,7 @@ func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPe // execute post hooks outside of txn var newRet []*models.Performer for _, performer := range ret { - r.hookExecutor.ExecutePostHooks(ctx, performer.ID, plugin.PerformerUpdatePost, input, translator.getFields()) + r.hookExecutor.ExecutePostHooks(ctx, performer.ID, hook.PerformerUpdatePost, input, translator.getFields()) performer, err = r.getPerformer(ctx, performer.ID) if err != nil { @@ -322,7 +322,7 @@ func (r *mutationResolver) PerformerDestroy(ctx context.Context, input Performer return false, err } - r.hookExecutor.ExecutePostHooks(ctx, id, plugin.PerformerDestroyPost, input, nil) + r.hookExecutor.ExecutePostHooks(ctx, id, hook.PerformerDestroyPost, input, nil) return true, nil } @@ -347,7 +347,7 @@ func (r *mutationResolver) PerformersDestroy(ctx context.Context, performerIDs [ } for _, id := range ids { - r.hookExecutor.ExecutePostHooks(ctx, id, plugin.PerformerDestroyPost, performerIDs, nil) + r.hookExecutor.ExecutePostHooks(ctx, id, hook.PerformerDestroyPost, performerIDs, nil) } return true, nil diff --git a/internal/api/resolver_mutation_scene.go b/internal/api/resolver_mutation_scene.go index cc530815c..15bf45147 100644 --- a/internal/api/resolver_mutation_scene.go +++ b/internal/api/resolver_mutation_scene.go @@ -12,6 +12,7 @@ import ( "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/plugin" + "github.com/stashapp/stash/pkg/plugin/hook" "github.com/stashapp/stash/pkg/scene" "github.com/stashapp/stash/pkg/sliceutil" "github.com/stashapp/stash/pkg/sliceutil/stringslice" @@ -116,7 +117,7 @@ func (r *mutationResolver) SceneUpdate(ctx context.Context, input models.SceneUp return nil, err } - r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.SceneUpdatePost, input, translator.getFields()) + r.hookExecutor.ExecutePostHooks(ctx, ret.ID, hook.SceneUpdatePost, input, translator.getFields()) return r.getScene(ctx, ret.ID) } @@ -150,7 +151,7 @@ func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.Sce inputMap: inputMaps[i], } - r.hookExecutor.ExecutePostHooks(ctx, scene.ID, plugin.SceneUpdatePost, input, translator.getFields()) + r.hookExecutor.ExecutePostHooks(ctx, scene.ID, hook.SceneUpdatePost, input, translator.getFields()) scene, err = r.getScene(ctx, scene.ID) if err != nil { @@ -385,7 +386,7 @@ func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input BulkSceneU // execute post hooks outside of txn var newRet []*models.Scene for _, scene := range ret { - r.hookExecutor.ExecutePostHooks(ctx, scene.ID, plugin.SceneUpdatePost, input, translator.getFields()) + r.hookExecutor.ExecutePostHooks(ctx, scene.ID, hook.SceneUpdatePost, input, translator.getFields()) scene, err = r.getScene(ctx, scene.ID) if err != nil { @@ -441,7 +442,7 @@ func (r *mutationResolver) SceneDestroy(ctx context.Context, input models.SceneD fileDeleter.Commit() // call post hook after performing the other actions - r.hookExecutor.ExecutePostHooks(ctx, s.ID, plugin.SceneDestroyPost, plugin.SceneDestroyInput{ + r.hookExecutor.ExecutePostHooks(ctx, s.ID, hook.SceneDestroyPost, plugin.SceneDestroyInput{ SceneDestroyInput: input, Checksum: s.Checksum, OSHash: s.OSHash, @@ -502,7 +503,7 @@ func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.Scene for _, scene := range scenes { // call post hook after performing the other actions - r.hookExecutor.ExecutePostHooks(ctx, scene.ID, plugin.SceneDestroyPost, plugin.ScenesDestroyInput{ + r.hookExecutor.ExecutePostHooks(ctx, scene.ID, hook.SceneDestroyPost, plugin.ScenesDestroyInput{ ScenesDestroyInput: input, Checksum: scene.Checksum, OSHash: scene.OSHash, @@ -653,7 +654,7 @@ func (r *mutationResolver) SceneMarkerCreate(ctx context.Context, input SceneMar return nil, err } - r.hookExecutor.ExecutePostHooks(ctx, newMarker.ID, plugin.SceneMarkerCreatePost, input, nil) + r.hookExecutor.ExecutePostHooks(ctx, newMarker.ID, hook.SceneMarkerCreatePost, input, nil) return r.getSceneMarker(ctx, newMarker.ID) } @@ -751,7 +752,7 @@ func (r *mutationResolver) SceneMarkerUpdate(ctx context.Context, input SceneMar // perform the post-commit actions fileDeleter.Commit() - r.hookExecutor.ExecutePostHooks(ctx, markerID, plugin.SceneMarkerUpdatePost, input, translator.getFields()) + r.hookExecutor.ExecutePostHooks(ctx, markerID, hook.SceneMarkerUpdatePost, input, translator.getFields()) return r.getSceneMarker(ctx, markerID) } @@ -801,7 +802,7 @@ func (r *mutationResolver) SceneMarkerDestroy(ctx context.Context, id string) (b // perform the post-commit actions fileDeleter.Commit() - r.hookExecutor.ExecutePostHooks(ctx, markerID, plugin.SceneMarkerDestroyPost, id, nil) + r.hookExecutor.ExecutePostHooks(ctx, markerID, hook.SceneMarkerDestroyPost, id, nil) return true, nil } diff --git a/internal/api/resolver_mutation_studio.go b/internal/api/resolver_mutation_studio.go index c41efe9ff..74b6e6c20 100644 --- a/internal/api/resolver_mutation_studio.go +++ b/internal/api/resolver_mutation_studio.go @@ -6,7 +6,7 @@ import ( "strconv" "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/plugin" + "github.com/stashapp/stash/pkg/plugin/hook" "github.com/stashapp/stash/pkg/sliceutil/stringslice" "github.com/stashapp/stash/pkg/studio" "github.com/stashapp/stash/pkg/utils" @@ -81,7 +81,7 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio return nil, err } - r.hookExecutor.ExecutePostHooks(ctx, newStudio.ID, plugin.StudioCreatePost, input, nil) + r.hookExecutor.ExecutePostHooks(ctx, newStudio.ID, hook.StudioCreatePost, input, nil) return r.getStudio(ctx, newStudio.ID) } @@ -147,7 +147,7 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.Studio return nil, err } - r.hookExecutor.ExecutePostHooks(ctx, studioID, plugin.StudioUpdatePost, input, translator.getFields()) + r.hookExecutor.ExecutePostHooks(ctx, studioID, hook.StudioUpdatePost, input, translator.getFields()) return r.getStudio(ctx, studioID) } @@ -163,7 +163,7 @@ func (r *mutationResolver) StudioDestroy(ctx context.Context, input StudioDestro return false, err } - r.hookExecutor.ExecutePostHooks(ctx, id, plugin.StudioDestroyPost, input, nil) + r.hookExecutor.ExecutePostHooks(ctx, id, hook.StudioDestroyPost, input, nil) return true, nil } @@ -188,7 +188,7 @@ func (r *mutationResolver) StudiosDestroy(ctx context.Context, studioIDs []strin } for _, id := range ids { - r.hookExecutor.ExecutePostHooks(ctx, id, plugin.StudioDestroyPost, studioIDs, nil) + r.hookExecutor.ExecutePostHooks(ctx, id, hook.StudioDestroyPost, studioIDs, nil) } return true, nil diff --git a/internal/api/resolver_mutation_tag.go b/internal/api/resolver_mutation_tag.go index cec4a7772..d3a522e55 100644 --- a/internal/api/resolver_mutation_tag.go +++ b/internal/api/resolver_mutation_tag.go @@ -7,7 +7,7 @@ import ( "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/plugin" + "github.com/stashapp/stash/pkg/plugin/hook" "github.com/stashapp/stash/pkg/sliceutil/stringslice" "github.com/stashapp/stash/pkg/tag" "github.com/stashapp/stash/pkg/utils" @@ -119,7 +119,7 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input TagCreateInput) return nil, err } - r.hookExecutor.ExecutePostHooks(ctx, newTag.ID, plugin.TagCreatePost, input, nil) + r.hookExecutor.ExecutePostHooks(ctx, newTag.ID, hook.TagCreatePost, input, nil) return r.getTag(ctx, newTag.ID) } @@ -235,7 +235,7 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input TagUpdateInput) return nil, err } - r.hookExecutor.ExecutePostHooks(ctx, t.ID, plugin.TagUpdatePost, input, translator.getFields()) + r.hookExecutor.ExecutePostHooks(ctx, t.ID, hook.TagUpdatePost, input, translator.getFields()) return r.getTag(ctx, t.ID) } @@ -251,7 +251,7 @@ func (r *mutationResolver) TagDestroy(ctx context.Context, input TagDestroyInput return false, err } - r.hookExecutor.ExecutePostHooks(ctx, tagID, plugin.TagDestroyPost, input, nil) + r.hookExecutor.ExecutePostHooks(ctx, tagID, hook.TagDestroyPost, input, nil) return true, nil } @@ -276,7 +276,7 @@ func (r *mutationResolver) TagsDestroy(ctx context.Context, tagIDs []string) (bo } for _, id := range ids { - r.hookExecutor.ExecutePostHooks(ctx, id, plugin.TagDestroyPost, tagIDs, nil) + r.hookExecutor.ExecutePostHooks(ctx, id, hook.TagDestroyPost, tagIDs, nil) } return true, nil @@ -340,7 +340,7 @@ func (r *mutationResolver) TagsMerge(ctx context.Context, input TagsMergeInput) return nil, err } - r.hookExecutor.ExecutePostHooks(ctx, t.ID, plugin.TagMergePost, input, nil) + r.hookExecutor.ExecutePostHooks(ctx, t.ID, hook.TagMergePost, input, nil) return t, nil } diff --git a/internal/api/resolver_mutation_tag_test.go b/internal/api/resolver_mutation_tag_test.go index 5f94e1e91..836eb4648 100644 --- a/internal/api/resolver_mutation_tag_test.go +++ b/internal/api/resolver_mutation_tag_test.go @@ -7,7 +7,7 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/mocks" - "github.com/stashapp/stash/pkg/plugin" + "github.com/stashapp/stash/pkg/plugin/hook" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -35,7 +35,7 @@ var testCtx = context.Background() type mockHookExecutor struct{} -func (*mockHookExecutor) ExecutePostHooks(ctx context.Context, id int, hookType plugin.HookTriggerEnum, input interface{}, inputFields []string) { +func (*mockHookExecutor) ExecutePostHooks(ctx context.Context, id int, hookType hook.TriggerEnum, input interface{}, inputFields []string) { } func TestTagCreate(t *testing.T) { diff --git a/internal/manager/task_clean.go b/internal/manager/task_clean.go index d33ac1609..3b9227549 100644 --- a/internal/manager/task_clean.go +++ b/internal/manager/task_clean.go @@ -15,6 +15,7 @@ import ( "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/plugin" + "github.com/stashapp/stash/pkg/plugin/hook" "github.com/stashapp/stash/pkg/scene" ) @@ -129,7 +130,7 @@ func (j *cleanJob) deleteGallery(ctx context.Context, id int) { return err } - pluginCache.RegisterPostHooks(ctx, id, plugin.GalleryDestroyPost, plugin.GalleryDestroyInput{ + pluginCache.RegisterPostHooks(ctx, id, hook.GalleryDestroyPost, plugin.GalleryDestroyInput{ Checksum: g.PrimaryChecksum(), Path: g.Path, }, nil) @@ -302,7 +303,7 @@ func (h *cleanHandler) handleRelatedScenes(ctx context.Context, fileDeleter *fil return err } - mgr.PluginCache.RegisterPostHooks(ctx, scene.ID, plugin.SceneDestroyPost, plugin.SceneDestroyInput{ + mgr.PluginCache.RegisterPostHooks(ctx, scene.ID, hook.SceneDestroyPost, plugin.SceneDestroyInput{ Checksum: scene.Checksum, OSHash: scene.OSHash, Path: scene.Path, @@ -349,7 +350,7 @@ func (h *cleanHandler) handleRelatedGalleries(ctx context.Context, fileID models return err } - mgr.PluginCache.RegisterPostHooks(ctx, g.ID, plugin.GalleryDestroyPost, plugin.GalleryDestroyInput{ + mgr.PluginCache.RegisterPostHooks(ctx, g.ID, hook.GalleryDestroyPost, plugin.GalleryDestroyInput{ Checksum: g.PrimaryChecksum(), Path: g.Path, }, nil) @@ -389,7 +390,7 @@ func (h *cleanHandler) deleteRelatedFolderGalleries(ctx context.Context, folderI return err } - mgr.PluginCache.RegisterPostHooks(ctx, g.ID, plugin.GalleryDestroyPost, plugin.GalleryDestroyInput{ + mgr.PluginCache.RegisterPostHooks(ctx, g.ID, hook.GalleryDestroyPost, plugin.GalleryDestroyInput{ // No checksum for folders // Checksum: g.Checksum(), Path: g.Path, @@ -423,7 +424,7 @@ func (h *cleanHandler) handleRelatedImages(ctx context.Context, fileDeleter *fil return err } - mgr.PluginCache.RegisterPostHooks(ctx, i.ID, plugin.ImageDestroyPost, plugin.ImageDestroyInput{ + mgr.PluginCache.RegisterPostHooks(ctx, i.ID, hook.ImageDestroyPost, plugin.ImageDestroyInput{ Checksum: i.Checksum, Path: i.Path, }, nil) diff --git a/pkg/gallery/scan.go b/pkg/gallery/scan.go index f4a9adcc5..9d0313b17 100644 --- a/pkg/gallery/scan.go +++ b/pkg/gallery/scan.go @@ -9,6 +9,7 @@ import ( "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/plugin" + "github.com/stashapp/stash/pkg/plugin/hook" ) type ScanCreatorUpdater interface { @@ -83,7 +84,7 @@ func (h *ScanHandler) Handle(ctx context.Context, f models.File, oldFile models. return fmt.Errorf("creating new gallery: %w", err) } - h.PluginCache.RegisterPostHooks(ctx, newGallery.ID, plugin.GalleryCreatePost, nil, nil) + h.PluginCache.RegisterPostHooks(ctx, newGallery.ID, hook.GalleryCreatePost, nil, nil) // associate all the images in the zip file with the gallery for _, i := range images { @@ -138,7 +139,7 @@ func (h *ScanHandler) associateExisting(ctx context.Context, existing []*models. } if !found || updateExisting { - h.PluginCache.RegisterPostHooks(ctx, i.ID, plugin.GalleryUpdatePost, nil, nil) + h.PluginCache.RegisterPostHooks(ctx, i.ID, hook.GalleryUpdatePost, nil, nil) } } diff --git a/pkg/image/scan.go b/pkg/image/scan.go index 2311ccc94..b388a8145 100644 --- a/pkg/image/scan.go +++ b/pkg/image/scan.go @@ -11,6 +11,7 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/paths" "github.com/stashapp/stash/pkg/plugin" + "github.com/stashapp/stash/pkg/plugin/hook" "github.com/stashapp/stash/pkg/sliceutil" "github.com/stashapp/stash/pkg/txn" ) @@ -137,7 +138,7 @@ func (h *ScanHandler) Handle(ctx context.Context, f models.File, oldFile models. } } - h.PluginCache.RegisterPostHooks(ctx, newImage.ID, plugin.ImageCreatePost, nil, nil) + h.PluginCache.RegisterPostHooks(ctx, newImage.ID, hook.ImageCreatePost, nil, nil) existing = []*models.Image{&newImage} } @@ -228,7 +229,7 @@ func (h *ScanHandler) associateExisting(ctx context.Context, existing []*models. } if changed || updateExisting { - h.PluginCache.RegisterPostHooks(ctx, i.ID, plugin.ImageUpdatePost, nil, nil) + h.PluginCache.RegisterPostHooks(ctx, i.ID, hook.ImageUpdatePost, nil, nil) } } @@ -257,7 +258,7 @@ func (h *ScanHandler) getOrCreateFolderBasedGallery(ctx context.Context, f model return nil, fmt.Errorf("creating folder based gallery: %w", err) } - h.PluginCache.RegisterPostHooks(ctx, newGallery.ID, plugin.GalleryCreatePost, nil, nil) + h.PluginCache.RegisterPostHooks(ctx, newGallery.ID, hook.GalleryCreatePost, nil, nil) // it's possible that there are other images in the folder that // need to be added to the new gallery. Find and add them now. @@ -311,7 +312,7 @@ func (h *ScanHandler) getOrCreateZipBasedGallery(ctx context.Context, zipFile mo return nil, fmt.Errorf("creating zip-based gallery: %w", err) } - h.PluginCache.RegisterPostHooks(ctx, newGallery.ID, plugin.GalleryCreatePost, nil, nil) + h.PluginCache.RegisterPostHooks(ctx, newGallery.ID, hook.GalleryCreatePost, nil, nil) return &newGallery, nil } diff --git a/pkg/plugin/config.go b/pkg/plugin/config.go index 88e7e7324..20221e029 100644 --- a/pkg/plugin/config.go +++ b/pkg/plugin/config.go @@ -9,6 +9,7 @@ import ( "sort" "strings" + "github.com/stashapp/stash/pkg/plugin/hook" "github.com/stashapp/stash/pkg/utils" "gopkg.in/yaml.v2" ) @@ -195,7 +196,7 @@ func (c Config) getPluginHooks(includePlugin bool) []*PluginHook { return ret } -func convertHooks(hooks []HookTriggerEnum) []string { +func convertHooks(hooks []hook.TriggerEnum) []string { var ret []string for _, h := range hooks { ret = append(ret, h.String()) @@ -275,7 +276,7 @@ func (c Config) getTask(name string) *OperationConfig { return nil } -func (c Config) getHooks(hookType HookTriggerEnum) []*HookConfig { +func (c Config) getHooks(hookType hook.TriggerEnum) []*HookConfig { var ret []*HookConfig for _, h := range c.Hooks { for _, t := range h.TriggeredBy { @@ -399,7 +400,7 @@ type HookConfig struct { OperationConfig `yaml:",inline"` // A list of stash operations that will be used to trigger this hook operation. - TriggeredBy []HookTriggerEnum `yaml:"triggeredBy"` + TriggeredBy []hook.TriggerEnum `yaml:"triggeredBy"` } func loadPluginFromYAML(reader io.Reader) (*Config, error) { diff --git a/pkg/plugin/hook/hooks.go b/pkg/plugin/hook/hooks.go new file mode 100644 index 000000000..1b7d93be4 --- /dev/null +++ b/pkg/plugin/hook/hooks.go @@ -0,0 +1,131 @@ +package hook + +type TriggerEnum string + +// Scan-related hooks are current disabled until post-hook execution is +// integrated. + +const ( + SceneMarkerCreatePost TriggerEnum = "SceneMarker.Create.Post" + SceneMarkerUpdatePost TriggerEnum = "SceneMarker.Update.Post" + SceneMarkerDestroyPost TriggerEnum = "SceneMarker.Destroy.Post" + + SceneCreatePost TriggerEnum = "Scene.Create.Post" + SceneUpdatePost TriggerEnum = "Scene.Update.Post" + SceneDestroyPost TriggerEnum = "Scene.Destroy.Post" + + ImageCreatePost TriggerEnum = "Image.Create.Post" + ImageUpdatePost TriggerEnum = "Image.Update.Post" + ImageDestroyPost TriggerEnum = "Image.Destroy.Post" + + GalleryCreatePost TriggerEnum = "Gallery.Create.Post" + GalleryUpdatePost TriggerEnum = "Gallery.Update.Post" + GalleryDestroyPost TriggerEnum = "Gallery.Destroy.Post" + + GalleryChapterCreatePost TriggerEnum = "GalleryChapter.Create.Post" + GalleryChapterUpdatePost TriggerEnum = "GalleryChapter.Update.Post" + GalleryChapterDestroyPost TriggerEnum = "GalleryChapter.Destroy.Post" + + MovieCreatePost TriggerEnum = "Movie.Create.Post" + MovieUpdatePost TriggerEnum = "Movie.Update.Post" + MovieDestroyPost TriggerEnum = "Movie.Destroy.Post" + + PerformerCreatePost TriggerEnum = "Performer.Create.Post" + PerformerUpdatePost TriggerEnum = "Performer.Update.Post" + PerformerDestroyPost TriggerEnum = "Performer.Destroy.Post" + + StudioCreatePost TriggerEnum = "Studio.Create.Post" + StudioUpdatePost TriggerEnum = "Studio.Update.Post" + StudioDestroyPost TriggerEnum = "Studio.Destroy.Post" + + TagCreatePost TriggerEnum = "Tag.Create.Post" + TagUpdatePost TriggerEnum = "Tag.Update.Post" + TagMergePost TriggerEnum = "Tag.Merge.Post" + TagDestroyPost TriggerEnum = "Tag.Destroy.Post" +) + +var AllHookTriggerEnum = []TriggerEnum{ + SceneMarkerCreatePost, + SceneMarkerUpdatePost, + SceneMarkerDestroyPost, + + SceneCreatePost, + SceneUpdatePost, + SceneDestroyPost, + + ImageCreatePost, + ImageUpdatePost, + ImageDestroyPost, + + GalleryCreatePost, + GalleryUpdatePost, + GalleryDestroyPost, + + GalleryChapterCreatePost, + GalleryChapterUpdatePost, + GalleryChapterDestroyPost, + + MovieCreatePost, + MovieUpdatePost, + MovieDestroyPost, + + PerformerCreatePost, + PerformerUpdatePost, + PerformerDestroyPost, + + StudioCreatePost, + StudioUpdatePost, + StudioDestroyPost, + + TagCreatePost, + TagUpdatePost, + TagMergePost, + TagDestroyPost, +} + +func (e TriggerEnum) IsValid() bool { + + switch e { + case SceneMarkerCreatePost, + SceneMarkerUpdatePost, + SceneMarkerDestroyPost, + + SceneCreatePost, + SceneUpdatePost, + SceneDestroyPost, + + ImageCreatePost, + ImageUpdatePost, + ImageDestroyPost, + + GalleryCreatePost, + GalleryUpdatePost, + GalleryDestroyPost, + + GalleryChapterCreatePost, + GalleryChapterUpdatePost, + GalleryChapterDestroyPost, + + MovieCreatePost, + MovieUpdatePost, + MovieDestroyPost, + + PerformerCreatePost, + PerformerUpdatePost, + PerformerDestroyPost, + + StudioCreatePost, + StudioUpdatePost, + StudioDestroyPost, + + TagCreatePost, + TagUpdatePost, + TagDestroyPost: + return true + } + return false +} + +func (e TriggerEnum) String() string { + return string(e) +} diff --git a/pkg/plugin/hooks.go b/pkg/plugin/hooks.go index fc91765b8..1a40c52f1 100644 --- a/pkg/plugin/hooks.go +++ b/pkg/plugin/hooks.go @@ -12,136 +12,6 @@ type PluginHook struct { Plugin *Plugin `json:"plugin"` } -type HookTriggerEnum string - -// Scan-related hooks are current disabled until post-hook execution is -// integrated. - -const ( - SceneMarkerCreatePost HookTriggerEnum = "SceneMarker.Create.Post" - SceneMarkerUpdatePost HookTriggerEnum = "SceneMarker.Update.Post" - SceneMarkerDestroyPost HookTriggerEnum = "SceneMarker.Destroy.Post" - - SceneCreatePost HookTriggerEnum = "Scene.Create.Post" - SceneUpdatePost HookTriggerEnum = "Scene.Update.Post" - SceneDestroyPost HookTriggerEnum = "Scene.Destroy.Post" - - ImageCreatePost HookTriggerEnum = "Image.Create.Post" - ImageUpdatePost HookTriggerEnum = "Image.Update.Post" - ImageDestroyPost HookTriggerEnum = "Image.Destroy.Post" - - GalleryCreatePost HookTriggerEnum = "Gallery.Create.Post" - GalleryUpdatePost HookTriggerEnum = "Gallery.Update.Post" - GalleryDestroyPost HookTriggerEnum = "Gallery.Destroy.Post" - - GalleryChapterCreatePost HookTriggerEnum = "GalleryChapter.Create.Post" - GalleryChapterUpdatePost HookTriggerEnum = "GalleryChapter.Update.Post" - GalleryChapterDestroyPost HookTriggerEnum = "GalleryChapter.Destroy.Post" - - MovieCreatePost HookTriggerEnum = "Movie.Create.Post" - MovieUpdatePost HookTriggerEnum = "Movie.Update.Post" - MovieDestroyPost HookTriggerEnum = "Movie.Destroy.Post" - - PerformerCreatePost HookTriggerEnum = "Performer.Create.Post" - PerformerUpdatePost HookTriggerEnum = "Performer.Update.Post" - PerformerDestroyPost HookTriggerEnum = "Performer.Destroy.Post" - - StudioCreatePost HookTriggerEnum = "Studio.Create.Post" - StudioUpdatePost HookTriggerEnum = "Studio.Update.Post" - StudioDestroyPost HookTriggerEnum = "Studio.Destroy.Post" - - TagCreatePost HookTriggerEnum = "Tag.Create.Post" - TagUpdatePost HookTriggerEnum = "Tag.Update.Post" - TagMergePost HookTriggerEnum = "Tag.Merge.Post" - TagDestroyPost HookTriggerEnum = "Tag.Destroy.Post" -) - -var AllHookTriggerEnum = []HookTriggerEnum{ - SceneMarkerCreatePost, - SceneMarkerUpdatePost, - SceneMarkerDestroyPost, - - SceneCreatePost, - SceneUpdatePost, - SceneDestroyPost, - - ImageCreatePost, - ImageUpdatePost, - ImageDestroyPost, - - GalleryCreatePost, - GalleryUpdatePost, - GalleryDestroyPost, - - GalleryChapterCreatePost, - GalleryChapterUpdatePost, - GalleryChapterDestroyPost, - - MovieCreatePost, - MovieUpdatePost, - MovieDestroyPost, - - PerformerCreatePost, - PerformerUpdatePost, - PerformerDestroyPost, - - StudioCreatePost, - StudioUpdatePost, - StudioDestroyPost, - - TagCreatePost, - TagUpdatePost, - TagMergePost, - TagDestroyPost, -} - -func (e HookTriggerEnum) IsValid() bool { - - switch e { - case SceneMarkerCreatePost, - SceneMarkerUpdatePost, - SceneMarkerDestroyPost, - - SceneCreatePost, - SceneUpdatePost, - SceneDestroyPost, - - ImageCreatePost, - ImageUpdatePost, - ImageDestroyPost, - - GalleryCreatePost, - GalleryUpdatePost, - GalleryDestroyPost, - - GalleryChapterCreatePost, - GalleryChapterUpdatePost, - GalleryChapterDestroyPost, - - MovieCreatePost, - MovieUpdatePost, - MovieDestroyPost, - - PerformerCreatePost, - PerformerUpdatePost, - PerformerDestroyPost, - - StudioCreatePost, - StudioUpdatePost, - StudioDestroyPost, - - TagCreatePost, - TagUpdatePost, - TagDestroyPost: - return true - } - return false -} - -func (e HookTriggerEnum) String() string { - return string(e) -} - func addHookContext(argsMap common.ArgsMap, hookContext common.HookContext) { argsMap[common.HookContextKey] = hookContext } diff --git a/pkg/plugin/plugins.go b/pkg/plugin/plugins.go index dfdc05af2..0a44ca000 100644 --- a/pkg/plugin/plugins.go +++ b/pkg/plugin/plugins.go @@ -21,6 +21,7 @@ import ( "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/plugin/common" + "github.com/stashapp/stash/pkg/plugin/hook" "github.com/stashapp/stash/pkg/session" "github.com/stashapp/stash/pkg/sliceutil" "github.com/stashapp/stash/pkg/txn" @@ -356,7 +357,7 @@ func waitForTask(ctx context.Context, task Task) error { return nil } -func (c Cache) ExecutePostHooks(ctx context.Context, id int, hookType HookTriggerEnum, input interface{}, inputFields []string) { +func (c Cache) ExecutePostHooks(ctx context.Context, id int, hookType hook.TriggerEnum, input interface{}, inputFields []string) { if err := c.executePostHooks(ctx, hookType, common.HookContext{ ID: id, Type: hookType.String(), @@ -367,7 +368,7 @@ func (c Cache) ExecutePostHooks(ctx context.Context, id int, hookType HookTrigge } } -func (c Cache) RegisterPostHooks(ctx context.Context, id int, hookType HookTriggerEnum, input interface{}, inputFields []string) { +func (c Cache) RegisterPostHooks(ctx context.Context, id int, hookType hook.TriggerEnum, input interface{}, inputFields []string) { txn.AddPostCommitHook(ctx, func(ctx context.Context) { c.ExecutePostHooks(ctx, id, hookType, input, inputFields) }) @@ -379,23 +380,28 @@ func (c Cache) ExecuteSceneUpdatePostHooks(ctx context.Context, input models.Sce logger.Errorf("error converting id in SceneUpdatePostHooks: %v", err) return } - c.ExecutePostHooks(ctx, id, SceneUpdatePost, input, inputFields) + c.ExecutePostHooks(ctx, id, hook.SceneUpdatePost, input, inputFields) } -func (c Cache) executePostHooks(ctx context.Context, hookType HookTriggerEnum, hookContext common.HookContext) error { - visitedPlugins := session.GetVisitedPlugins(ctx) +// maxCyclicLoopDepth is the maximum number of identical plugin hook calls that +// can be made before a cyclic loop is detected. It is set to an arbitrary value +// that should not be hit under normal circumstances. +const maxCyclicLoopDepth = 10 + +func (c Cache) executePostHooks(ctx context.Context, hookType hook.TriggerEnum, hookContext common.HookContext) error { + visitedPluginHookCounts := getVisitedPluginHookCounts(ctx) for _, p := range c.enabledPlugins() { hooks := p.getHooks(hookType) // don't revisit a plugin we've already visited // only log if there's hooks that we're skipping - if len(hooks) > 0 && sliceutil.Contains(visitedPlugins, p.id) { - logger.Debugf("plugin ID '%s' already triggered, not re-triggering", p.id) + if len(hooks) > 0 && visitedPluginHookCounts.For(p.id, hookType) >= maxCyclicLoopDepth { + logger.Debugf("cyclic loop detected: plugin ID '%s' hook %s, not re-triggering", p.id, hookType) continue } for _, h := range hooks { - newCtx := session.AddVisitedPlugin(ctx, p.id) + newCtx := session.AddVisitedPluginHook(ctx, p.id, hookType) serverConnection := c.makeServerConnection(newCtx) pluginInput := buildPluginInput(&p, &h.OperationConfig, serverConnection, nil) @@ -434,6 +440,46 @@ func (c Cache) executePostHooks(ctx context.Context, hookType HookTriggerEnum, h return nil } +type visitedPluginHookCount struct { + session.VisitedPluginHook + Count int +} + +type visitedPluginHookCounts []visitedPluginHookCount + +func (v visitedPluginHookCounts) For(pluginID string, hookType hook.TriggerEnum) int { + for _, c := range v { + if c.VisitedPluginHook.PluginID == pluginID && c.VisitedPluginHook.HookType == hookType { + return c.Count + } + } + return 0 +} + +func getVisitedPluginHookCounts(ctx context.Context) visitedPluginHookCounts { + visitedPluginHooks := session.GetVisitedPluginHooks(ctx) + + visitedPluginHookCounts := make([]visitedPluginHookCount, 0) + for _, p := range visitedPluginHooks { + found := false + for i, v := range visitedPluginHookCounts { + if v.VisitedPluginHook == p { + visitedPluginHookCounts[i].Count++ + found = true + break + } + } + if !found { + visitedPluginHookCounts = append(visitedPluginHookCounts, visitedPluginHookCount{ + VisitedPluginHook: p, + Count: 1, + }) + } + } + + return visitedPluginHookCounts +} + func (c Cache) getPlugin(pluginID string) *Config { for _, s := range c.plugins { if s.id == pluginID { diff --git a/pkg/scene/create.go b/pkg/scene/create.go index 428c636a7..cd9234b5d 100644 --- a/pkg/scene/create.go +++ b/pkg/scene/create.go @@ -7,7 +7,7 @@ import ( "time" "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/plugin" + "github.com/stashapp/stash/pkg/plugin/hook" ) func (s *Service) Create(ctx context.Context, input *models.Scene, fileIDs []models.FileID, coverImage []byte) (*models.Scene, error) { @@ -54,7 +54,7 @@ func (s *Service) Create(ctx context.Context, input *models.Scene, fileIDs []mod } } - s.PluginCache.RegisterPostHooks(ctx, ret.ID, plugin.SceneCreatePost, nil, nil) + s.PluginCache.RegisterPostHooks(ctx, ret.ID, hook.SceneCreatePost, nil, nil) // re-find the scene so that it correctly returns file-related fields return ret, nil diff --git a/pkg/scene/scan.go b/pkg/scene/scan.go index 821485eb9..5676f6d4f 100644 --- a/pkg/scene/scan.go +++ b/pkg/scene/scan.go @@ -10,6 +10,7 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/paths" "github.com/stashapp/stash/pkg/plugin" + "github.com/stashapp/stash/pkg/plugin/hook" "github.com/stashapp/stash/pkg/txn" ) @@ -107,7 +108,7 @@ func (h *ScanHandler) Handle(ctx context.Context, f models.File, oldFile models. return fmt.Errorf("creating new scene: %w", err) } - h.PluginCache.RegisterPostHooks(ctx, newScene.ID, plugin.SceneCreatePost, nil, nil) + h.PluginCache.RegisterPostHooks(ctx, newScene.ID, hook.SceneCreatePost, nil, nil) existing = []*models.Scene{&newScene} } @@ -164,7 +165,7 @@ func (h *ScanHandler) associateExisting(ctx context.Context, existing []*models. } if !found || updateExisting { - h.PluginCache.RegisterPostHooks(ctx, s.ID, plugin.SceneUpdatePost, nil, nil) + h.PluginCache.RegisterPostHooks(ctx, s.ID, hook.SceneUpdatePost, nil, nil) } } diff --git a/pkg/session/plugin.go b/pkg/session/plugin.go new file mode 100644 index 000000000..7a57ca4b5 --- /dev/null +++ b/pkg/session/plugin.go @@ -0,0 +1,82 @@ +package session + +import ( + "context" + "encoding/gob" + "net/http" + + "github.com/gorilla/securecookie" + "github.com/gorilla/sessions" + "github.com/stashapp/stash/pkg/logger" + "github.com/stashapp/stash/pkg/plugin/hook" +) + +type VisitedPluginHook struct { + PluginID string + HookType hook.TriggerEnum +} + +func init() { + gob.Register([]VisitedPluginHook{}) +} + +func (s *Store) VisitedPluginHandler() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // get the visited plugins from the cookie and set in the context + session, err := s.sessionStore.Get(r, cookieName) + + // ignore errors + if err == nil { + val := session.Values[visitedPluginHooksKey] + + visitedPlugins, _ := val.([]VisitedPluginHook) + + ctx := setVisitedPluginHooks(r.Context(), visitedPlugins) + r = r.WithContext(ctx) + } + + next.ServeHTTP(w, r) + }) + } +} + +func GetVisitedPluginHooks(ctx context.Context) []VisitedPluginHook { + ctxVal := ctx.Value(contextVisitedPlugins) + if ctxVal != nil { + return ctxVal.([]VisitedPluginHook) + } + + return nil +} + +func AddVisitedPluginHook(ctx context.Context, pluginID string, hookType hook.TriggerEnum) context.Context { + curVal := GetVisitedPluginHooks(ctx) + curVal = append(curVal, VisitedPluginHook{PluginID: pluginID, HookType: hookType}) + return setVisitedPluginHooks(ctx, curVal) +} + +func setVisitedPluginHooks(ctx context.Context, visitedPlugins []VisitedPluginHook) context.Context { + return context.WithValue(ctx, contextVisitedPlugins, visitedPlugins) +} + +func (s *Store) MakePluginCookie(ctx context.Context) *http.Cookie { + currentUser := GetCurrentUserID(ctx) + visitedPlugins := GetVisitedPluginHooks(ctx) + + session := sessions.NewSession(s.sessionStore, cookieName) + if currentUser != nil { + session.Values[userIDKey] = *currentUser + } + + session.Values[visitedPluginHooksKey] = visitedPlugins + + encoded, err := securecookie.EncodeMulti(session.Name(), session.Values, + s.sessionStore.Codecs...) + if err != nil { + logger.Errorf("error creating session cookie: %s", err.Error()) + return nil + } + + return sessions.NewCookie(session.Name(), encoded, session.Options) +} diff --git a/pkg/session/session.go b/pkg/session/session.go index d5218155f..285c7cc3c 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -5,10 +5,8 @@ import ( "errors" "net/http" - "github.com/gorilla/securecookie" "github.com/gorilla/sessions" "github.com/stashapp/stash/pkg/logger" - "github.com/stashapp/stash/pkg/sliceutil" ) type key int @@ -19,8 +17,8 @@ const ( ) const ( - userIDKey = "userID" - visitedPluginsKey = "visitedPlugins" + userIDKey = "userID" + visitedPluginHooksKey = "visitedPluginsHooks" ) const ( @@ -147,67 +145,6 @@ func GetCurrentUserID(ctx context.Context) *string { return nil } -func (s *Store) VisitedPluginHandler() func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // get the visited plugins from the cookie and set in the context - session, err := s.sessionStore.Get(r, cookieName) - - // ignore errors - if err == nil { - val := session.Values[visitedPluginsKey] - - visitedPlugins, _ := val.([]string) - - ctx := setVisitedPlugins(r.Context(), visitedPlugins) - r = r.WithContext(ctx) - } - - next.ServeHTTP(w, r) - }) - } -} - -func GetVisitedPlugins(ctx context.Context) []string { - ctxVal := ctx.Value(contextVisitedPlugins) - if ctxVal != nil { - return ctxVal.([]string) - } - - return nil -} - -func AddVisitedPlugin(ctx context.Context, pluginID string) context.Context { - curVal := GetVisitedPlugins(ctx) - curVal = sliceutil.AppendUnique(curVal, pluginID) - return setVisitedPlugins(ctx, curVal) -} - -func setVisitedPlugins(ctx context.Context, visitedPlugins []string) context.Context { - return context.WithValue(ctx, contextVisitedPlugins, visitedPlugins) -} - -func (s *Store) MakePluginCookie(ctx context.Context) *http.Cookie { - currentUser := GetCurrentUserID(ctx) - visitedPlugins := GetVisitedPlugins(ctx) - - session := sessions.NewSession(s.sessionStore, cookieName) - if currentUser != nil { - session.Values[userIDKey] = *currentUser - } - - session.Values[visitedPluginsKey] = visitedPlugins - - encoded, err := securecookie.EncodeMulti(session.Name(), session.Values, - s.sessionStore.Codecs...) - if err != nil { - logger.Errorf("error creating session cookie: %s", err.Error()) - return nil - } - - return sessions.NewCookie(session.Name(), encoded, session.Options) -} - func (s *Store) Authenticate(w http.ResponseWriter, r *http.Request) (userID string, err error) { c := s.config