stash/pkg/sqlite/studio.go

715 lines
20 KiB
Go

package sqlite
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/doug-martin/goqu/v9"
"github.com/doug-martin/goqu/v9/exp"
"github.com/jmoiron/sqlx"
"gopkg.in/guregu/null.v4"
"gopkg.in/guregu/null.v4/zero"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/sliceutil/intslice"
"github.com/stashapp/stash/pkg/studio"
)
const (
studioTable = "studios"
studioIDColumn = "studio_id"
studioAliasesTable = "studio_aliases"
studioAliasColumn = "alias"
studioParentIDColumn = "parent_id"
studioNameColumn = "name"
studioImageBlobColumn = "image_blob"
)
type studioRow struct {
ID int `db:"id" goqu:"skipinsert"`
Name zero.String `db:"name"`
URL zero.String `db:"url"`
ParentID null.Int `db:"parent_id,omitempty"`
CreatedAt Timestamp `db:"created_at"`
UpdatedAt Timestamp `db:"updated_at"`
// expressed as 1-100
Rating null.Int `db:"rating"`
Details zero.String `db:"details"`
IgnoreAutoTag bool `db:"ignore_auto_tag"`
// not used in resolutions or updates
ImageBlob zero.String `db:"image_blob"`
}
func (r *studioRow) fromStudio(o models.Studio) {
r.ID = o.ID
r.Name = zero.StringFrom(o.Name)
r.URL = zero.StringFrom(o.URL)
r.ParentID = intFromPtr(o.ParentID)
r.CreatedAt = Timestamp{Timestamp: o.CreatedAt}
r.UpdatedAt = Timestamp{Timestamp: o.UpdatedAt}
r.Rating = intFromPtr(o.Rating)
r.Details = zero.StringFrom(o.Details)
r.IgnoreAutoTag = o.IgnoreAutoTag
}
func (r *studioRow) resolve() *models.Studio {
ret := &models.Studio{
ID: r.ID,
Name: r.Name.String,
URL: r.URL.String,
ParentID: nullIntPtr(r.ParentID),
CreatedAt: r.CreatedAt.Timestamp,
UpdatedAt: r.UpdatedAt.Timestamp,
Rating: nullIntPtr(r.Rating),
Details: r.Details.String,
IgnoreAutoTag: r.IgnoreAutoTag,
}
return ret
}
type studioRowRecord struct {
updateRecord
}
func (r *studioRowRecord) fromPartial(o models.StudioPartial) {
r.setNullString("name", o.Name)
r.setNullString("url", o.URL)
r.setNullInt("parent_id", o.ParentID)
r.setTimestamp("created_at", o.CreatedAt)
r.setTimestamp("updated_at", o.UpdatedAt)
r.setNullInt("rating", o.Rating)
r.setNullString("details", o.Details)
r.setBool("ignore_auto_tag", o.IgnoreAutoTag)
}
type StudioStore struct {
repository
blobJoinQueryBuilder
tableMgr *table
}
func NewStudioStore(blobStore *BlobStore) *StudioStore {
return &StudioStore{
repository: repository{
tableName: studioTable,
idColumn: idColumn,
},
blobJoinQueryBuilder: blobJoinQueryBuilder{
blobStore: blobStore,
joinTable: studioTable,
},
tableMgr: studioTableMgr,
}
}
func (qb *StudioStore) table() exp.IdentifierExpression {
return qb.tableMgr.table
}
func (qb *StudioStore) selectDataset() *goqu.SelectDataset {
return dialect.From(qb.table()).Select(qb.table().All())
}
func (qb *StudioStore) Create(ctx context.Context, newObject *models.Studio) error {
var err error
var r studioRow
r.fromStudio(*newObject)
id, err := qb.tableMgr.insertID(ctx, r)
if err != nil {
return err
}
if newObject.Aliases.Loaded() {
if err := studio.EnsureAliasesUnique(ctx, id, newObject.Aliases.List(), qb); err != nil {
return err
}
if err := studiosAliasesTableMgr.insertJoins(ctx, id, newObject.Aliases.List()); err != nil {
return err
}
}
if newObject.StashIDs.Loaded() {
if err := studiosStashIDsTableMgr.insertJoins(ctx, id, newObject.StashIDs.List()); err != nil {
return err
}
}
updated, _ := qb.find(ctx, id)
if err != nil {
return fmt.Errorf("finding after create: %w", err)
}
*newObject = *updated
return nil
}
func (qb *StudioStore) UpdatePartial(ctx context.Context, input models.StudioPartial) (*models.Studio, error) {
r := studioRowRecord{
updateRecord{
Record: make(exp.Record),
},
}
r.fromPartial(input)
if len(r.Record) > 0 {
if err := qb.tableMgr.updateByID(ctx, input.ID, r.Record); err != nil {
return nil, err
}
}
if input.Aliases != nil {
if err := studio.EnsureAliasesUnique(ctx, input.ID, input.Aliases.Values, qb); err != nil {
return nil, err
}
if err := studiosAliasesTableMgr.modifyJoins(ctx, input.ID, input.Aliases.Values, input.Aliases.Mode); err != nil {
return nil, err
}
}
if input.StashIDs != nil {
if err := studiosStashIDsTableMgr.modifyJoins(ctx, input.ID, input.StashIDs.StashIDs, input.StashIDs.Mode); err != nil {
return nil, err
}
}
return qb.Find(ctx, input.ID)
}
// This is only used by the Import/Export functionality
func (qb *StudioStore) Update(ctx context.Context, updatedObject *models.Studio) error {
var r studioRow
r.fromStudio(*updatedObject)
if err := qb.tableMgr.updateByID(ctx, updatedObject.ID, r); err != nil {
return err
}
if updatedObject.Aliases.Loaded() {
if err := studiosAliasesTableMgr.replaceJoins(ctx, updatedObject.ID, updatedObject.Aliases.List()); err != nil {
return err
}
}
if updatedObject.StashIDs.Loaded() {
if err := studiosStashIDsTableMgr.replaceJoins(ctx, updatedObject.ID, updatedObject.StashIDs.List()); err != nil {
return err
}
}
return nil
}
func (qb *StudioStore) Destroy(ctx context.Context, id int) error {
// must handle image checksums manually
if err := qb.destroyImage(ctx, id); err != nil {
return err
}
return qb.destroyExisting(ctx, []int{id})
}
// returns nil, nil if not found
func (qb *StudioStore) Find(ctx context.Context, id int) (*models.Studio, error) {
ret, err := qb.find(ctx, id)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return ret, err
}
func (qb *StudioStore) FindMany(ctx context.Context, ids []int) ([]*models.Studio, error) {
ret := make([]*models.Studio, len(ids))
table := qb.table()
if err := batchExec(ids, defaultBatchSize, func(batch []int) error {
q := qb.selectDataset().Prepared(true).Where(table.Col(idColumn).In(batch))
unsorted, err := qb.getMany(ctx, q)
if err != nil {
return err
}
for _, s := range unsorted {
i := intslice.IntIndex(ids, s.ID)
ret[i] = s
}
return nil
}); err != nil {
return nil, err
}
for i := range ret {
if ret[i] == nil {
return nil, fmt.Errorf("studio with id %d not found", ids[i])
}
}
return ret, nil
}
// returns nil, sql.ErrNoRows if not found
func (qb *StudioStore) find(ctx context.Context, id int) (*models.Studio, error) {
q := qb.selectDataset().Where(qb.tableMgr.byID(id))
ret, err := qb.get(ctx, q)
if err != nil {
return nil, err
}
return ret, nil
}
// returns nil, sql.ErrNoRows if not found
func (qb *StudioStore) get(ctx context.Context, q *goqu.SelectDataset) (*models.Studio, error) {
ret, err := qb.getMany(ctx, q)
if err != nil {
return nil, err
}
if len(ret) == 0 {
return nil, sql.ErrNoRows
}
return ret[0], nil
}
func (qb *StudioStore) getMany(ctx context.Context, q *goqu.SelectDataset) ([]*models.Studio, error) {
const single = false
var ret []*models.Studio
if err := queryFunc(ctx, q, single, func(r *sqlx.Rows) error {
var f studioRow
if err := r.StructScan(&f); err != nil {
return err
}
s := f.resolve()
ret = append(ret, s)
return nil
}); err != nil {
return nil, err
}
return ret, nil
}
func (qb *StudioStore) findBySubquery(ctx context.Context, sq *goqu.SelectDataset) ([]*models.Studio, error) {
table := qb.table()
q := qb.selectDataset().Where(
table.Col(idColumn).Eq(
sq,
),
)
return qb.getMany(ctx, q)
}
func (qb *StudioStore) FindChildren(ctx context.Context, id int) ([]*models.Studio, error) {
// SELECT studios.* FROM studios WHERE studios.parent_id = ?
table := qb.table()
sq := qb.selectDataset().Where(table.Col(studioParentIDColumn).Eq(id))
ret, err := qb.getMany(ctx, sq)
if err != nil {
return nil, err
}
return ret, nil
}
func (qb *StudioStore) FindBySceneID(ctx context.Context, sceneID int) (*models.Studio, error) {
// SELECT studios.* FROM studios JOIN scenes ON studios.id = scenes.studio_id WHERE scenes.id = ? LIMIT 1
table := qb.table()
scenes := sceneTableMgr.table
sq := qb.selectDataset().Join(
scenes, goqu.On(table.Col(idColumn), scenes.Col(studioIDColumn)),
).Where(
scenes.Col(idColumn),
).Limit(1)
ret, err := qb.get(ctx, sq)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, err
}
return ret, nil
}
func (qb *StudioStore) FindByName(ctx context.Context, name string, nocase bool) (*models.Studio, error) {
// query := "SELECT * FROM studios WHERE name = ?"
// if nocase {
// query += " COLLATE NOCASE"
// }
// query += " LIMIT 1"
where := "name = ?"
if nocase {
where += " COLLATE NOCASE"
}
sq := qb.selectDataset().Prepared(true).Where(goqu.L(where, name)).Limit(1)
ret, err := qb.get(ctx, sq)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, err
}
return ret, nil
}
func (qb *StudioStore) FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Studio, error) {
sq := dialect.From(studiosStashIDsJoinTable).Select(studiosStashIDsJoinTable.Col(studioIDColumn)).Where(
studiosStashIDsJoinTable.Col("stash_id").Eq(stashID.StashID),
studiosStashIDsJoinTable.Col("endpoint").Eq(stashID.Endpoint),
)
ret, err := qb.findBySubquery(ctx, sq)
if err != nil {
return nil, fmt.Errorf("getting studios for stash ID %s: %w", stashID.StashID, err)
}
return ret, nil
}
func (qb *StudioStore) FindByStashIDStatus(ctx context.Context, hasStashID bool, stashboxEndpoint string) ([]*models.Studio, error) {
table := qb.table()
sq := dialect.From(table).LeftJoin(
studiosStashIDsJoinTable,
goqu.On(table.Col(idColumn).Eq(studiosStashIDsJoinTable.Col(studioIDColumn))),
).Select(table.Col(idColumn))
if hasStashID {
sq = sq.Where(
studiosStashIDsJoinTable.Col("stash_id").IsNotNull(),
studiosStashIDsJoinTable.Col("endpoint").Eq(stashboxEndpoint),
)
} else {
sq = sq.Where(
studiosStashIDsJoinTable.Col("stash_id").IsNull(),
)
}
ret, err := qb.findBySubquery(ctx, sq)
if err != nil {
return nil, fmt.Errorf("getting studios for stash-box endpoint %s: %w", stashboxEndpoint, err)
}
return ret, nil
}
func (qb *StudioStore) Count(ctx context.Context) (int, error) {
q := dialect.Select(goqu.COUNT("*")).From(qb.table())
return count(ctx, q)
}
func (qb *StudioStore) All(ctx context.Context) ([]*models.Studio, error) {
table := qb.table()
return qb.getMany(ctx, qb.selectDataset().Order(table.Col(studioNameColumn).Asc()))
}
func (qb *StudioStore) QueryForAutoTag(ctx context.Context, words []string) ([]*models.Studio, error) {
// TODO - Query needs to be changed to support queries of this type, and
// this method should be removed
table := qb.table()
sq := dialect.From(table).Select(table.Col(idColumn)).LeftJoin(
studiosAliasesJoinTable,
goqu.On(studiosAliasesJoinTable.Col(studioIDColumn).Eq(table.Col(idColumn))),
)
var whereClauses []exp.Expression
for _, w := range words {
whereClauses = append(whereClauses, table.Col(studioNameColumn).Like(w+"%"))
whereClauses = append(whereClauses, studiosAliasesJoinTable.Col("alias").Like(w+"%"))
}
sq = sq.Where(
goqu.Or(whereClauses...),
table.Col("ignore_auto_tag").Eq(0),
)
ret, err := qb.findBySubquery(ctx, sq)
if err != nil {
return nil, fmt.Errorf("getting performers for autotag: %w", err)
}
return ret, nil
}
func (qb *StudioStore) validateFilter(filter *models.StudioFilterType) error {
const and = "AND"
const or = "OR"
const not = "NOT"
if filter.And != nil {
if filter.Or != nil {
return illegalFilterCombination(and, or)
}
if filter.Not != nil {
return illegalFilterCombination(and, not)
}
return qb.validateFilter(filter.And)
}
if filter.Or != nil {
if filter.Not != nil {
return illegalFilterCombination(or, not)
}
return qb.validateFilter(filter.Or)
}
if filter.Not != nil {
return qb.validateFilter(filter.Not)
}
return nil
}
func (qb *StudioStore) makeFilter(ctx context.Context, studioFilter *models.StudioFilterType) *filterBuilder {
query := &filterBuilder{}
if studioFilter.And != nil {
query.and(qb.makeFilter(ctx, studioFilter.And))
}
if studioFilter.Or != nil {
query.or(qb.makeFilter(ctx, studioFilter.Or))
}
if studioFilter.Not != nil {
query.not(qb.makeFilter(ctx, studioFilter.Not))
}
query.handleCriterion(ctx, stringCriterionHandler(studioFilter.Name, studioTable+".name"))
query.handleCriterion(ctx, stringCriterionHandler(studioFilter.Details, studioTable+".details"))
query.handleCriterion(ctx, stringCriterionHandler(studioFilter.URL, studioTable+".url"))
query.handleCriterion(ctx, intCriterionHandler(studioFilter.Rating100, studioTable+".rating", nil))
// legacy rating handler
query.handleCriterion(ctx, rating5CriterionHandler(studioFilter.Rating, studioTable+".rating", nil))
query.handleCriterion(ctx, boolCriterionHandler(studioFilter.IgnoreAutoTag, studioTable+".ignore_auto_tag", nil))
query.handleCriterion(ctx, criterionHandlerFunc(func(ctx context.Context, f *filterBuilder) {
if studioFilter.StashID != nil {
qb.stashIDRepository().join(f, "studio_stash_ids", "studios.id")
stringCriterionHandler(studioFilter.StashID, "studio_stash_ids.stash_id")(ctx, f)
}
}))
query.handleCriterion(ctx, &stashIDCriterionHandler{
c: studioFilter.StashIDEndpoint,
stashIDRepository: qb.stashIDRepository(),
stashIDTableAs: "studio_stash_ids",
parentIDCol: "studios.id",
})
query.handleCriterion(ctx, studioIsMissingCriterionHandler(qb, studioFilter.IsMissing))
query.handleCriterion(ctx, studioSceneCountCriterionHandler(qb, studioFilter.SceneCount))
query.handleCriterion(ctx, studioImageCountCriterionHandler(qb, studioFilter.ImageCount))
query.handleCriterion(ctx, studioGalleryCountCriterionHandler(qb, studioFilter.GalleryCount))
query.handleCriterion(ctx, studioParentCriterionHandler(qb, studioFilter.Parents))
query.handleCriterion(ctx, studioAliasCriterionHandler(qb, studioFilter.Aliases))
query.handleCriterion(ctx, timestampCriterionHandler(studioFilter.CreatedAt, studioTable+".created_at"))
query.handleCriterion(ctx, timestampCriterionHandler(studioFilter.UpdatedAt, studioTable+".updated_at"))
return query
}
func (qb *StudioStore) makeQuery(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) (*queryBuilder, error) {
if studioFilter == nil {
studioFilter = &models.StudioFilterType{}
}
if findFilter == nil {
findFilter = &models.FindFilterType{}
}
query := qb.newQuery()
distinctIDs(&query, studioTable)
if q := findFilter.Q; q != nil && *q != "" {
query.join(studioAliasesTable, "", "studio_aliases.studio_id = studios.id")
searchColumns := []string{"studios.name", "studio_aliases.alias"}
query.parseQueryString(searchColumns, *q)
}
if err := qb.validateFilter(studioFilter); err != nil {
return nil, err
}
filter := qb.makeFilter(ctx, studioFilter)
if err := query.addFilter(filter); err != nil {
return nil, err
}
query.sortAndPagination = qb.getStudioSort(findFilter) + getPagination(findFilter)
return &query, nil
}
func (qb *StudioStore) Query(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) {
query, err := qb.makeQuery(ctx, studioFilter, findFilter)
if err != nil {
return nil, 0, err
}
idsResult, countResult, err := query.executeFind(ctx)
if err != nil {
return nil, 0, err
}
studios, err := qb.FindMany(ctx, idsResult)
if err != nil {
return nil, 0, err
}
return studios, countResult, nil
}
func studioIsMissingCriterionHandler(qb *StudioStore, isMissing *string) criterionHandlerFunc {
return func(ctx context.Context, f *filterBuilder) {
if isMissing != nil && *isMissing != "" {
switch *isMissing {
case "image":
f.addWhere("studios.image_blob IS NULL")
case "stash_id":
qb.stashIDRepository().join(f, "studio_stash_ids", "studios.id")
f.addWhere("studio_stash_ids.studio_id IS NULL")
default:
f.addWhere("(studios." + *isMissing + " IS NULL OR TRIM(studios." + *isMissing + ") = '')")
}
}
}
}
func studioSceneCountCriterionHandler(qb *StudioStore, sceneCount *models.IntCriterionInput) criterionHandlerFunc {
return func(ctx context.Context, f *filterBuilder) {
if sceneCount != nil {
f.addLeftJoin("scenes", "", "scenes.studio_id = studios.id")
clause, args := getIntCriterionWhereClause("count(distinct scenes.id)", *sceneCount)
f.addHaving(clause, args...)
}
}
}
func studioImageCountCriterionHandler(qb *StudioStore, imageCount *models.IntCriterionInput) criterionHandlerFunc {
return func(ctx context.Context, f *filterBuilder) {
if imageCount != nil {
f.addLeftJoin("images", "", "images.studio_id = studios.id")
clause, args := getIntCriterionWhereClause("count(distinct images.id)", *imageCount)
f.addHaving(clause, args...)
}
}
}
func studioGalleryCountCriterionHandler(qb *StudioStore, galleryCount *models.IntCriterionInput) criterionHandlerFunc {
return func(ctx context.Context, f *filterBuilder) {
if galleryCount != nil {
f.addLeftJoin("galleries", "", "galleries.studio_id = studios.id")
clause, args := getIntCriterionWhereClause("count(distinct galleries.id)", *galleryCount)
f.addHaving(clause, args...)
}
}
}
func studioParentCriterionHandler(qb *StudioStore, parents *models.MultiCriterionInput) criterionHandlerFunc {
addJoinsFunc := func(f *filterBuilder) {
f.addLeftJoin("studios", "parent_studio", "parent_studio.id = studios.parent_id")
}
h := multiCriterionHandlerBuilder{
primaryTable: studioTable,
foreignTable: "parent_studio",
joinTable: "",
primaryFK: studioIDColumn,
foreignFK: "parent_id",
addJoinsFunc: addJoinsFunc,
}
return h.handler(parents)
}
func studioAliasCriterionHandler(qb *StudioStore, alias *models.StringCriterionInput) criterionHandlerFunc {
h := stringListCriterionHandlerBuilder{
joinTable: studioAliasesTable,
stringColumn: studioAliasColumn,
addJoinTable: func(f *filterBuilder) {
studiosAliasesTableMgr.join(f, "", "studios.id")
},
}
return h.handler(alias)
}
func (qb *StudioStore) getStudioSort(findFilter *models.FindFilterType) string {
var sort string
var direction string
if findFilter == nil {
sort = "name"
direction = "ASC"
} else {
sort = findFilter.GetSort("name")
direction = findFilter.GetDirection()
}
sortQuery := ""
switch sort {
case "scenes_count":
sortQuery += getCountSort(studioTable, sceneTable, studioIDColumn, direction)
case "images_count":
sortQuery += getCountSort(studioTable, imageTable, studioIDColumn, direction)
case "galleries_count":
sortQuery += getCountSort(studioTable, galleryTable, studioIDColumn, direction)
default:
sortQuery += getSort(sort, direction, "studios")
}
// Whatever the sorting, always use name/id as a final sort
sortQuery += ", COALESCE(studios.name, studios.id) COLLATE NATURAL_CI ASC"
return sortQuery
}
func (qb *StudioStore) GetImage(ctx context.Context, studioID int) ([]byte, error) {
return qb.blobJoinQueryBuilder.GetImage(ctx, studioID, studioImageBlobColumn)
}
func (qb *StudioStore) HasImage(ctx context.Context, studioID int) (bool, error) {
return qb.blobJoinQueryBuilder.HasImage(ctx, studioID, studioImageBlobColumn)
}
func (qb *StudioStore) UpdateImage(ctx context.Context, studioID int, image []byte) error {
return qb.blobJoinQueryBuilder.UpdateImage(ctx, studioID, studioImageBlobColumn, image)
}
func (qb *StudioStore) destroyImage(ctx context.Context, studioID int) error {
return qb.blobJoinQueryBuilder.DestroyImage(ctx, studioID, studioImageBlobColumn)
}
func (qb *StudioStore) stashIDRepository() *stashIDRepository {
return &stashIDRepository{
repository{
tx: qb.tx,
tableName: "studio_stash_ids",
idColumn: studioIDColumn,
},
}
}
func (qb *StudioStore) GetStashIDs(ctx context.Context, studioID int) ([]models.StashID, error) {
return studiosStashIDsTableMgr.get(ctx, studioID)
}
func (qb *StudioStore) GetAliases(ctx context.Context, studioID int) ([]string, error) {
return studiosAliasesTableMgr.get(ctx, studioID)
}