stash/pkg/sqlite/group.go

704 lines
18 KiB
Go

package sqlite
import (
"context"
"database/sql"
"errors"
"fmt"
"slices"
"github.com/doug-martin/goqu/v9"
"github.com/doug-martin/goqu/v9/exp"
"github.com/jmoiron/sqlx"
"gopkg.in/guregu/null.v4"
"gopkg.in/guregu/null.v4/zero"
"github.com/stashapp/stash/pkg/models"
)
const (
groupTable = "groups"
groupIDColumn = "group_id"
groupFrontImageBlobColumn = "front_image_blob"
groupBackImageBlobColumn = "back_image_blob"
groupsTagsTable = "groups_tags"
groupURLsTable = "group_urls"
groupURLColumn = "url"
groupRelationsTable = "groups_relations"
)
type groupRow struct {
ID int `db:"id" goqu:"skipinsert"`
Name zero.String `db:"name"`
Aliases zero.String `db:"aliases"`
Duration null.Int `db:"duration"`
Date NullDate `db:"date"`
// expressed as 1-100
Rating null.Int `db:"rating"`
StudioID null.Int `db:"studio_id,omitempty"`
Director zero.String `db:"director"`
Description zero.String `db:"description"`
CreatedAt Timestamp `db:"created_at"`
UpdatedAt Timestamp `db:"updated_at"`
// not used in resolutions or updates
FrontImageBlob zero.String `db:"front_image_blob"`
BackImageBlob zero.String `db:"back_image_blob"`
}
func (r *groupRow) fromGroup(o models.Group) {
r.ID = o.ID
r.Name = zero.StringFrom(o.Name)
r.Aliases = zero.StringFrom(o.Aliases)
r.Duration = intFromPtr(o.Duration)
r.Date = NullDateFromDatePtr(o.Date)
r.Rating = intFromPtr(o.Rating)
r.StudioID = intFromPtr(o.StudioID)
r.Director = zero.StringFrom(o.Director)
r.Description = zero.StringFrom(o.Synopsis)
r.CreatedAt = Timestamp{Timestamp: o.CreatedAt}
r.UpdatedAt = Timestamp{Timestamp: o.UpdatedAt}
}
func (r *groupRow) resolve() *models.Group {
ret := &models.Group{
ID: r.ID,
Name: r.Name.String,
Aliases: r.Aliases.String,
Duration: nullIntPtr(r.Duration),
Date: r.Date.DatePtr(),
Rating: nullIntPtr(r.Rating),
StudioID: nullIntPtr(r.StudioID),
Director: r.Director.String,
Synopsis: r.Description.String,
CreatedAt: r.CreatedAt.Timestamp,
UpdatedAt: r.UpdatedAt.Timestamp,
}
return ret
}
type groupRowRecord struct {
updateRecord
}
func (r *groupRowRecord) fromPartial(o models.GroupPartial) {
r.setNullString("name", o.Name)
r.setNullString("aliases", o.Aliases)
r.setNullInt("duration", o.Duration)
r.setNullDate("date", o.Date)
r.setNullInt("rating", o.Rating)
r.setNullInt("studio_id", o.StudioID)
r.setNullString("director", o.Director)
r.setNullString("description", o.Synopsis)
r.setTimestamp("created_at", o.CreatedAt)
r.setTimestamp("updated_at", o.UpdatedAt)
}
type groupRepositoryType struct {
repository
scenes repository
tags joinRepository
}
var (
groupRepository = groupRepositoryType{
repository: repository{
tableName: groupTable,
idColumn: idColumn,
},
scenes: repository{
tableName: groupsScenesTable,
idColumn: groupIDColumn,
},
tags: joinRepository{
repository: repository{
tableName: groupsTagsTable,
idColumn: groupIDColumn,
},
fkColumn: tagIDColumn,
foreignTable: tagTable,
orderBy: "tags.name ASC",
},
}
)
type GroupStore struct {
blobJoinQueryBuilder
tagRelationshipStore
groupRelationshipStore
tableMgr *table
}
func NewGroupStore(blobStore *BlobStore) *GroupStore {
return &GroupStore{
blobJoinQueryBuilder: blobJoinQueryBuilder{
blobStore: blobStore,
joinTable: groupTable,
},
tagRelationshipStore: tagRelationshipStore{
idRelationshipStore: idRelationshipStore{
joinTable: groupsTagsTableMgr,
},
},
groupRelationshipStore: groupRelationshipStore{
table: groupRelationshipTableMgr,
},
tableMgr: groupTableMgr,
}
}
func (qb *GroupStore) table() exp.IdentifierExpression {
return qb.tableMgr.table
}
func (qb *GroupStore) selectDataset() *goqu.SelectDataset {
return dialect.From(qb.table()).Select(qb.table().All())
}
func (qb *GroupStore) Create(ctx context.Context, newObject *models.Group) error {
var r groupRow
r.fromGroup(*newObject)
id, err := qb.tableMgr.insertID(ctx, r)
if err != nil {
return err
}
if newObject.URLs.Loaded() {
const startPos = 0
if err := groupsURLsTableMgr.insertJoins(ctx, id, startPos, newObject.URLs.List()); err != nil {
return err
}
}
if err := qb.tagRelationshipStore.createRelationships(ctx, id, newObject.TagIDs); err != nil {
return err
}
if err := qb.groupRelationshipStore.createContainingRelationships(ctx, id, newObject.ContainingGroups); err != nil {
return err
}
if err := qb.groupRelationshipStore.createSubRelationships(ctx, id, newObject.SubGroups); err != nil {
return err
}
updated, err := qb.find(ctx, id)
if err != nil {
return fmt.Errorf("finding after create: %w", err)
}
*newObject = *updated
return nil
}
func (qb *GroupStore) UpdatePartial(ctx context.Context, id int, partial models.GroupPartial) (*models.Group, error) {
r := groupRowRecord{
updateRecord{
Record: make(exp.Record),
},
}
r.fromPartial(partial)
if len(r.Record) > 0 {
if err := qb.tableMgr.updateByID(ctx, id, r.Record); err != nil {
return nil, err
}
}
if partial.URLs != nil {
if err := groupsURLsTableMgr.modifyJoins(ctx, id, partial.URLs.Values, partial.URLs.Mode); err != nil {
return nil, err
}
}
if err := qb.tagRelationshipStore.modifyRelationships(ctx, id, partial.TagIDs); err != nil {
return nil, err
}
if err := qb.groupRelationshipStore.modifyContainingRelationships(ctx, id, partial.ContainingGroups); err != nil {
return nil, err
}
if err := qb.groupRelationshipStore.modifySubRelationships(ctx, id, partial.SubGroups); err != nil {
return nil, err
}
return qb.find(ctx, id)
}
func (qb *GroupStore) Update(ctx context.Context, updatedObject *models.Group) error {
var r groupRow
r.fromGroup(*updatedObject)
if err := qb.tableMgr.updateByID(ctx, updatedObject.ID, r); err != nil {
return err
}
if updatedObject.URLs.Loaded() {
if err := groupsURLsTableMgr.replaceJoins(ctx, updatedObject.ID, updatedObject.URLs.List()); err != nil {
return err
}
}
if err := qb.tagRelationshipStore.replaceRelationships(ctx, updatedObject.ID, updatedObject.TagIDs); err != nil {
return err
}
if err := qb.groupRelationshipStore.replaceContainingRelationships(ctx, updatedObject.ID, updatedObject.ContainingGroups); err != nil {
return err
}
if err := qb.groupRelationshipStore.replaceSubRelationships(ctx, updatedObject.ID, updatedObject.SubGroups); err != nil {
return err
}
return nil
}
func (qb *GroupStore) Destroy(ctx context.Context, id int) error {
// must handle image checksums manually
if err := qb.destroyImages(ctx, id); err != nil {
return err
}
return groupRepository.destroyExisting(ctx, []int{id})
}
// returns nil, nil if not found
func (qb *GroupStore) Find(ctx context.Context, id int) (*models.Group, error) {
ret, err := qb.find(ctx, id)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return ret, err
}
func (qb *GroupStore) FindMany(ctx context.Context, ids []int) ([]*models.Group, error) {
ret := make([]*models.Group, len(ids))
table := qb.table()
if err := batchExec(ids, defaultBatchSize, func(batch []int) error {
q := qb.selectDataset().Prepared(true).Where(table.Col(idColumn).In(batch))
unsorted, err := qb.getMany(ctx, q)
if err != nil {
return err
}
for _, s := range unsorted {
i := slices.Index(ids, s.ID)
ret[i] = s
}
return nil
}); err != nil {
return nil, err
}
for i := range ret {
if ret[i] == nil {
return nil, fmt.Errorf("group with id %d not found", ids[i])
}
}
return ret, nil
}
// returns nil, sql.ErrNoRows if not found
func (qb *GroupStore) find(ctx context.Context, id int) (*models.Group, error) {
q := qb.selectDataset().Where(qb.tableMgr.byID(id))
ret, err := qb.get(ctx, q)
if err != nil {
return nil, err
}
return ret, nil
}
// returns nil, sql.ErrNoRows if not found
func (qb *GroupStore) get(ctx context.Context, q *goqu.SelectDataset) (*models.Group, error) {
ret, err := qb.getMany(ctx, q)
if err != nil {
return nil, err
}
if len(ret) == 0 {
return nil, sql.ErrNoRows
}
return ret[0], nil
}
func (qb *GroupStore) getMany(ctx context.Context, q *goqu.SelectDataset) ([]*models.Group, error) {
const single = false
var ret []*models.Group
if err := queryFunc(ctx, q, single, func(r *sqlx.Rows) error {
var f groupRow
if err := r.StructScan(&f); err != nil {
return err
}
s := f.resolve()
ret = append(ret, s)
return nil
}); err != nil {
return nil, err
}
return ret, nil
}
func (qb *GroupStore) FindByName(ctx context.Context, name string, nocase bool) (*models.Group, error) {
// query := "SELECT * FROM groups WHERE name = ?"
// if nocase {
// query += " COLLATE NOCASE"
// }
// query += " LIMIT 1"
where := "name = ?"
if nocase {
where += " COLLATE NOCASE"
}
sq := qb.selectDataset().Prepared(true).Where(goqu.L(where, name)).Limit(1)
ret, err := qb.get(ctx, sq)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, err
}
return ret, nil
}
func (qb *GroupStore) FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Group, error) {
// query := "SELECT * FROM groups WHERE name"
// if nocase {
// query += " COLLATE NOCASE"
// }
// query += " IN " + getInBinding(len(names))
where := "name"
if nocase {
where += " COLLATE NOCASE"
}
where += " IN " + getInBinding(len(names))
var args []interface{}
for _, name := range names {
args = append(args, name)
}
sq := qb.selectDataset().Prepared(true).Where(goqu.L(where, args...))
ret, err := qb.getMany(ctx, sq)
if err != nil {
return nil, err
}
return ret, nil
}
func (qb *GroupStore) Count(ctx context.Context) (int, error) {
q := dialect.Select(goqu.COUNT("*")).From(qb.table())
return count(ctx, q)
}
func (qb *GroupStore) All(ctx context.Context) ([]*models.Group, error) {
table := qb.table()
return qb.getMany(ctx, qb.selectDataset().Order(
table.Col("name").Asc(),
table.Col(idColumn).Asc(),
))
}
func (qb *GroupStore) makeQuery(ctx context.Context, groupFilter *models.GroupFilterType, findFilter *models.FindFilterType) (*queryBuilder, error) {
if findFilter == nil {
findFilter = &models.FindFilterType{}
}
if groupFilter == nil {
groupFilter = &models.GroupFilterType{}
}
query := groupRepository.newQuery()
distinctIDs(&query, groupTable)
if q := findFilter.Q; q != nil && *q != "" {
searchColumns := []string{"groups.name", "groups.aliases"}
query.parseQueryString(searchColumns, *q)
}
filter := filterBuilderFromHandler(ctx, &groupFilterHandler{
groupFilter: groupFilter,
})
if err := query.addFilter(filter); err != nil {
return nil, err
}
if err := qb.setGroupSort(&query, findFilter); err != nil {
return nil, err
}
query.sortAndPagination += getPagination(findFilter)
return &query, nil
}
func (qb *GroupStore) Query(ctx context.Context, groupFilter *models.GroupFilterType, findFilter *models.FindFilterType) ([]*models.Group, int, error) {
query, err := qb.makeQuery(ctx, groupFilter, findFilter)
if err != nil {
return nil, 0, err
}
idsResult, countResult, err := query.executeFind(ctx)
if err != nil {
return nil, 0, err
}
groups, err := qb.FindMany(ctx, idsResult)
if err != nil {
return nil, 0, err
}
return groups, countResult, nil
}
func (qb *GroupStore) QueryCount(ctx context.Context, groupFilter *models.GroupFilterType, findFilter *models.FindFilterType) (int, error) {
query, err := qb.makeQuery(ctx, groupFilter, findFilter)
if err != nil {
return 0, err
}
return query.executeCount(ctx)
}
var groupSortOptions = sortOptions{
"created_at",
"date",
"duration",
"id",
"name",
"random",
"rating",
"scenes_count",
"sub_group_order",
"tag_count",
"updated_at",
}
func (qb *GroupStore) setGroupSort(query *queryBuilder, findFilter *models.FindFilterType) error {
var sort string
var direction string
if findFilter == nil {
sort = "name"
direction = "ASC"
} else {
sort = findFilter.GetSort("name")
direction = findFilter.GetDirection()
}
// CVE-2024-32231 - ensure sort is in the list of allowed sorts
if err := groupSortOptions.validateSort(sort); err != nil {
return err
}
switch sort {
case "sub_group_order":
// sub_group_order is a special sort that sorts by the order_index of the subgroups
if query.hasJoin("groups_parents") {
query.sortAndPagination += getSort("order_index", direction, "groups_parents")
} else {
// this will give unexpected results if the query is not filtered by a parent group and
// the group has multiple parents and order indexes
query.join(groupRelationsTable, "", "groups.id = groups_relations.sub_id")
query.sortAndPagination += getSort("order_index", direction, groupRelationsTable)
}
case "tag_count":
query.sortAndPagination += getCountSort(groupTable, groupsTagsTable, groupIDColumn, direction)
case "scenes_count": // generic getSort won't work for this
query.sortAndPagination += getCountSort(groupTable, groupsScenesTable, groupIDColumn, direction)
default:
query.sortAndPagination += getSort(sort, direction, "groups")
}
// Whatever the sorting, always use name/id as a final sort
query.sortAndPagination += ", COALESCE(groups.name, groups.id) COLLATE NATURAL_CI ASC"
return nil
}
func (qb *GroupStore) queryGroups(ctx context.Context, query string, args []interface{}) ([]*models.Group, error) {
const single = false
var ret []*models.Group
if err := groupRepository.queryFunc(ctx, query, args, single, func(r *sqlx.Rows) error {
var f groupRow
if err := r.StructScan(&f); err != nil {
return err
}
s := f.resolve()
ret = append(ret, s)
return nil
}); err != nil {
return nil, err
}
return ret, nil
}
func (qb *GroupStore) UpdateFrontImage(ctx context.Context, groupID int, frontImage []byte) error {
return qb.UpdateImage(ctx, groupID, groupFrontImageBlobColumn, frontImage)
}
func (qb *GroupStore) UpdateBackImage(ctx context.Context, groupID int, backImage []byte) error {
return qb.UpdateImage(ctx, groupID, groupBackImageBlobColumn, backImage)
}
func (qb *GroupStore) destroyImages(ctx context.Context, groupID int) error {
if err := qb.DestroyImage(ctx, groupID, groupFrontImageBlobColumn); err != nil {
return err
}
if err := qb.DestroyImage(ctx, groupID, groupBackImageBlobColumn); err != nil {
return err
}
return nil
}
func (qb *GroupStore) GetFrontImage(ctx context.Context, groupID int) ([]byte, error) {
return qb.GetImage(ctx, groupID, groupFrontImageBlobColumn)
}
func (qb *GroupStore) HasFrontImage(ctx context.Context, groupID int) (bool, error) {
return qb.HasImage(ctx, groupID, groupFrontImageBlobColumn)
}
func (qb *GroupStore) GetBackImage(ctx context.Context, groupID int) ([]byte, error) {
return qb.GetImage(ctx, groupID, groupBackImageBlobColumn)
}
func (qb *GroupStore) HasBackImage(ctx context.Context, groupID int) (bool, error) {
return qb.HasImage(ctx, groupID, groupBackImageBlobColumn)
}
func (qb *GroupStore) FindByPerformerID(ctx context.Context, performerID int) ([]*models.Group, error) {
query := `SELECT DISTINCT groups.*
FROM groups
INNER JOIN groups_scenes ON groups.id = groups_scenes.group_id
INNER JOIN performers_scenes ON performers_scenes.scene_id = groups_scenes.scene_id
WHERE performers_scenes.performer_id = ?
`
args := []interface{}{performerID}
return qb.queryGroups(ctx, query, args)
}
func (qb *GroupStore) CountByPerformerID(ctx context.Context, performerID int) (int, error) {
query := `SELECT COUNT(DISTINCT groups_scenes.group_id) AS count
FROM groups_scenes
INNER JOIN performers_scenes ON performers_scenes.scene_id = groups_scenes.scene_id
WHERE performers_scenes.performer_id = ?
`
args := []interface{}{performerID}
return groupRepository.runCountQuery(ctx, query, args)
}
func (qb *GroupStore) FindByStudioID(ctx context.Context, studioID int) ([]*models.Group, error) {
query := `SELECT groups.*
FROM groups
WHERE groups.studio_id = ?
`
args := []interface{}{studioID}
return qb.queryGroups(ctx, query, args)
}
func (qb *GroupStore) CountByStudioID(ctx context.Context, studioID int) (int, error) {
query := `SELECT COUNT(1) AS count
FROM groups
WHERE groups.studio_id = ?
`
args := []interface{}{studioID}
return groupRepository.runCountQuery(ctx, query, args)
}
func (qb *GroupStore) GetURLs(ctx context.Context, groupID int) ([]string, error) {
return groupsURLsTableMgr.get(ctx, groupID)
}
// FindSubGroupIDs returns a list of group IDs where a group in the ids list is a sub-group of the parent group
func (qb *GroupStore) FindSubGroupIDs(ctx context.Context, containingID int, ids []int) ([]int, error) {
/*
SELECT gr.sub_id FROM groups_relations gr
WHERE gr.containing_id = :parentID AND gr.sub_id IN (:ids);
*/
table := groupRelationshipTableMgr.table
q := dialect.From(table).Prepared(true).
Select(table.Col("sub_id")).Where(
table.Col("containing_id").Eq(containingID),
table.Col("sub_id").In(ids),
)
const single = false
var ret []int
if err := queryFunc(ctx, q, single, func(r *sqlx.Rows) error {
var id int
if err := r.Scan(&id); err != nil {
return err
}
ret = append(ret, id)
return nil
}); err != nil {
return nil, err
}
return ret, nil
}
// FindInAscestors returns a list of group IDs where a group in the ids list is an ascestor of the ancestor group IDs
func (qb *GroupStore) FindInAncestors(ctx context.Context, ascestorIDs []int, ids []int) ([]int, error) {
/*
WITH RECURSIVE ascestors AS (
SELECT g.id AS parent_id FROM groups g WHERE g.id IN (:ascestorIDs)
UNION
SELECT gr.containing_id FROM groups_relations gr INNER JOIN ascestors a ON a.parent_id = gr.sub_id
)
SELECT p.parent_id FROM ascestors p WHERE p.parent_id IN (:ids);
*/
table := qb.table()
const ascestors = "ancestors"
const parentID = "parent_id"
q := dialect.From(ascestors).Prepared(true).
WithRecursive(ascestors,
dialect.From(qb.table()).Select(table.Col(idColumn).As(parentID)).
Where(table.Col(idColumn).In(ascestorIDs)).
Union(
dialect.From(groupRelationsJoinTable).InnerJoin(
goqu.I(ascestors), goqu.On(goqu.I("parent_id").Eq(goqu.I("sub_id"))),
).Select("containing_id"),
),
).Select(parentID).Where(goqu.I(parentID).In(ids))
const single = false
var ret []int
if err := queryFunc(ctx, q, single, func(r *sqlx.Rows) error {
var id int
if err := r.Scan(&id); err != nil {
return err
}
ret = append(ret, id)
return nil
}); err != nil {
return nil, err
}
return ret, nil
}