stash/pkg/sqlite/group_relationships.go

458 lines
13 KiB
Go

package sqlite
import (
"context"
"fmt"
"github.com/doug-martin/goqu/v9"
"github.com/doug-martin/goqu/v9/exp"
"github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/models"
"gopkg.in/guregu/null.v4"
"gopkg.in/guregu/null.v4/zero"
)
type groupRelationshipRow struct {
ContainingID int `db:"containing_id"`
SubID int `db:"sub_id"`
OrderIndex int `db:"order_index"`
Description zero.String `db:"description"`
}
func (r groupRelationshipRow) resolve(useContainingID bool) models.GroupIDDescription {
id := r.ContainingID
if !useContainingID {
id = r.SubID
}
return models.GroupIDDescription{
GroupID: id,
Description: r.Description.String,
}
}
type groupRelationshipStore struct {
table *table
}
func (s *groupRelationshipStore) GetContainingGroupDescriptions(ctx context.Context, id int) ([]models.GroupIDDescription, error) {
const idIsContaining = false
return s.getGroupRelationships(ctx, id, idIsContaining)
}
func (s *groupRelationshipStore) GetSubGroupDescriptions(ctx context.Context, id int) ([]models.GroupIDDescription, error) {
const idIsContaining = true
return s.getGroupRelationships(ctx, id, idIsContaining)
}
func (s *groupRelationshipStore) getGroupRelationships(ctx context.Context, id int, idIsContaining bool) ([]models.GroupIDDescription, error) {
col := "containing_id"
if !idIsContaining {
col = "sub_id"
}
table := s.table.table
q := dialect.Select(table.All()).
From(table).
Where(table.Col(col).Eq(id)).
Order(table.Col("order_index").Asc())
const single = false
var ret []models.GroupIDDescription
if err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
var row groupRelationshipRow
if err := rows.StructScan(&row); err != nil {
return err
}
ret = append(ret, row.resolve(!idIsContaining))
return nil
}); err != nil {
return nil, fmt.Errorf("getting group relationships from %s: %w", table.GetTable(), err)
}
return ret, nil
}
// getMaxOrderIndex gets the maximum order index for the containing group with the given id
func (s *groupRelationshipStore) getMaxOrderIndex(ctx context.Context, containingID int) (int, error) {
idColumn := s.table.table.Col("containing_id")
q := dialect.Select(goqu.MAX("order_index")).
From(s.table.table).
Where(idColumn.Eq(containingID))
var maxOrderIndex zero.Int
if err := querySimple(ctx, q, &maxOrderIndex); err != nil {
return 0, fmt.Errorf("getting max order index: %w", err)
}
return int(maxOrderIndex.Int64), nil
}
// createRelationships creates relationships between a group and other groups.
// If idIsContaining is true, the provided id is the containing group.
func (s *groupRelationshipStore) createRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions, idIsContaining bool) error {
if d.Loaded() {
for i, v := range d.List() {
orderIndex := i + 1
r := groupRelationshipRow{
ContainingID: id,
SubID: v.GroupID,
OrderIndex: orderIndex,
Description: zero.StringFrom(v.Description),
}
if !idIsContaining {
// get the max order index of the containing groups sub groups
containingID := v.GroupID
maxOrderIndex, err := s.getMaxOrderIndex(ctx, containingID)
if err != nil {
return err
}
r.ContainingID = v.GroupID
r.SubID = id
r.OrderIndex = maxOrderIndex + 1
}
_, err := s.table.insert(ctx, r)
if err != nil {
return fmt.Errorf("inserting into %s: %w", s.table.table.GetTable(), err)
}
}
return nil
}
return nil
}
// createRelationships creates relationships between a group and other groups.
// If idIsContaining is true, the provided id is the containing group.
func (s *groupRelationshipStore) createContainingRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error {
const idIsContaining = false
return s.createRelationships(ctx, id, d, idIsContaining)
}
// createRelationships creates relationships between a group and other groups.
// If idIsContaining is true, the provided id is the containing group.
func (s *groupRelationshipStore) createSubRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error {
const idIsContaining = true
return s.createRelationships(ctx, id, d, idIsContaining)
}
func (s *groupRelationshipStore) replaceRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions, idIsContaining bool) error {
// always destroy the existing relationships even if the new list is empty
if err := s.destroyAllJoins(ctx, id, idIsContaining); err != nil {
return err
}
return s.createRelationships(ctx, id, d, idIsContaining)
}
func (s *groupRelationshipStore) replaceContainingRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error {
const idIsContaining = false
return s.replaceRelationships(ctx, id, d, idIsContaining)
}
func (s *groupRelationshipStore) replaceSubRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error {
const idIsContaining = true
return s.replaceRelationships(ctx, id, d, idIsContaining)
}
func (s *groupRelationshipStore) modifyRelationships(ctx context.Context, id int, v *models.UpdateGroupDescriptions, idIsContaining bool) error {
if v == nil {
return nil
}
switch v.Mode {
case models.RelationshipUpdateModeSet:
return s.replaceJoins(ctx, id, *v, idIsContaining)
case models.RelationshipUpdateModeAdd:
return s.addJoins(ctx, id, v.Groups, idIsContaining)
case models.RelationshipUpdateModeRemove:
toRemove := make([]int, len(v.Groups))
for i, vv := range v.Groups {
toRemove[i] = vv.GroupID
}
return s.destroyJoins(ctx, id, toRemove, idIsContaining)
}
return nil
}
func (s *groupRelationshipStore) modifyContainingRelationships(ctx context.Context, id int, v *models.UpdateGroupDescriptions) error {
const idIsContaining = false
return s.modifyRelationships(ctx, id, v, idIsContaining)
}
func (s *groupRelationshipStore) modifySubRelationships(ctx context.Context, id int, v *models.UpdateGroupDescriptions) error {
const idIsContaining = true
return s.modifyRelationships(ctx, id, v, idIsContaining)
}
func (s *groupRelationshipStore) addJoins(ctx context.Context, id int, groups []models.GroupIDDescription, idIsContaining bool) error {
// if we're adding to a containing group, get the max order index first
var maxOrderIndex int
if idIsContaining {
var err error
maxOrderIndex, err = s.getMaxOrderIndex(ctx, id)
if err != nil {
return err
}
}
for i, vv := range groups {
r := groupRelationshipRow{
Description: zero.StringFrom(vv.Description),
}
if idIsContaining {
r.ContainingID = id
r.SubID = vv.GroupID
r.OrderIndex = maxOrderIndex + (i + 1)
} else {
// get the max order index of the containing groups sub groups
containingMaxOrderIndex, err := s.getMaxOrderIndex(ctx, vv.GroupID)
if err != nil {
return err
}
r.ContainingID = vv.GroupID
r.SubID = id
r.OrderIndex = containingMaxOrderIndex + 1
}
_, err := s.table.insert(ctx, r)
if err != nil {
return fmt.Errorf("inserting into %s: %w", s.table.table.GetTable(), err)
}
}
return nil
}
func (s *groupRelationshipStore) destroyAllJoins(ctx context.Context, id int, idIsContaining bool) error {
table := s.table.table
idColumn := table.Col("containing_id")
if !idIsContaining {
idColumn = table.Col("sub_id")
}
q := dialect.Delete(table).Where(idColumn.Eq(id))
if _, err := exec(ctx, q); err != nil {
return fmt.Errorf("destroying %s: %w", table.GetTable(), err)
}
return nil
}
func (s *groupRelationshipStore) replaceJoins(ctx context.Context, id int, v models.UpdateGroupDescriptions, idIsContaining bool) error {
if err := s.destroyAllJoins(ctx, id, idIsContaining); err != nil {
return err
}
// convert to RelatedGroupDescriptions
rgd := models.NewRelatedGroupDescriptions(v.Groups)
return s.createRelationships(ctx, id, rgd, idIsContaining)
}
func (s *groupRelationshipStore) destroyJoins(ctx context.Context, id int, toRemove []int, idIsContaining bool) error {
table := s.table.table
idColumn := table.Col("containing_id")
fkColumn := table.Col("sub_id")
if !idIsContaining {
idColumn = table.Col("sub_id")
fkColumn = table.Col("containing_id")
}
q := dialect.Delete(table).Where(idColumn.Eq(id), fkColumn.In(toRemove))
if _, err := exec(ctx, q); err != nil {
return fmt.Errorf("destroying %s: %w", table.GetTable(), err)
}
return nil
}
func (s *groupRelationshipStore) getOrderIndexOfSubGroup(ctx context.Context, containingGroupID int, subGroupID int) (int, error) {
table := s.table.table
q := dialect.Select("order_index").
From(table).
Where(
table.Col("containing_id").Eq(containingGroupID),
table.Col("sub_id").Eq(subGroupID),
)
var orderIndex null.Int
if err := querySimple(ctx, q, &orderIndex); err != nil {
return 0, fmt.Errorf("getting order index: %w", err)
}
if !orderIndex.Valid {
return 0, fmt.Errorf("sub-group %d not found in containing group %d", subGroupID, containingGroupID)
}
return int(orderIndex.Int64), nil
}
func (s *groupRelationshipStore) getGroupIDAtOrderIndex(ctx context.Context, containingGroupID int, orderIndex int) (*int, error) {
table := s.table.table
q := dialect.Select(table.Col("sub_id")).From(table).Where(
table.Col("containing_id").Eq(containingGroupID),
table.Col("order_index").Eq(orderIndex),
)
var ret null.Int
if err := querySimple(ctx, q, &ret); err != nil {
return nil, fmt.Errorf("getting sub id for order index: %w", err)
}
if !ret.Valid {
return nil, nil
}
intRet := int(ret.Int64)
return &intRet, nil
}
func (s *groupRelationshipStore) getOrderIndexAfterOrderIndex(ctx context.Context, containingGroupID int, orderIndex int) (int, error) {
table := s.table.table
q := dialect.Select(goqu.MIN("order_index")).From(table).Where(
table.Col("containing_id").Eq(containingGroupID),
table.Col("order_index").Gt(orderIndex),
)
var ret null.Int
if err := querySimple(ctx, q, &ret); err != nil {
return 0, fmt.Errorf("getting order index: %w", err)
}
if !ret.Valid {
return orderIndex + 1, nil
}
return int(ret.Int64), nil
}
// incrementOrderIndexes increments the order_index value of all sub-groups in the containing group at or after the given index
func (s *groupRelationshipStore) incrementOrderIndexes(ctx context.Context, groupID int, indexBefore int) error {
table := s.table.table
// WORKAROUND - sqlite won't allow incrementing the value directly since it causes a
// unique constraint violation.
// Instead, we first set the order index to a negative value temporarily
// see https://stackoverflow.com/a/7703239/695786
q := dialect.Update(table).Set(exp.Record{
"order_index": goqu.L("-order_index"),
}).Where(
table.Col("containing_id").Eq(groupID),
table.Col("order_index").Gte(indexBefore),
)
if _, err := exec(ctx, q); err != nil {
return fmt.Errorf("updating %s: %w", table.GetTable(), err)
}
q = dialect.Update(table).Set(exp.Record{
"order_index": goqu.L("1-order_index"),
}).Where(
table.Col("containing_id").Eq(groupID),
table.Col("order_index").Lt(0),
)
if _, err := exec(ctx, q); err != nil {
return fmt.Errorf("updating %s: %w", table.GetTable(), err)
}
return nil
}
func (s *groupRelationshipStore) reorderSubGroup(ctx context.Context, groupID int, subGroupID int, insertPointID int, insertAfter bool) error {
insertPointIndex, err := s.getOrderIndexOfSubGroup(ctx, groupID, insertPointID)
if err != nil {
return err
}
// if we're setting before
if insertAfter {
insertPointIndex, err = s.getOrderIndexAfterOrderIndex(ctx, groupID, insertPointIndex)
if err != nil {
return err
}
}
// increment the order index of all sub-groups after and including the insertion point
if err := s.incrementOrderIndexes(ctx, groupID, int(insertPointIndex)); err != nil {
return err
}
// set the order index of the sub-group to the insertion point
table := s.table.table
q := dialect.Update(table).Set(exp.Record{
"order_index": insertPointIndex,
}).Where(
table.Col("containing_id").Eq(groupID),
table.Col("sub_id").Eq(subGroupID),
)
if _, err := exec(ctx, q); err != nil {
return fmt.Errorf("updating %s: %w", table.GetTable(), err)
}
return nil
}
func (s *groupRelationshipStore) AddSubGroups(ctx context.Context, groupID int, subGroups []models.GroupIDDescription, insertIndex *int) error {
const idIsContaining = true
if err := s.addJoins(ctx, groupID, subGroups, idIsContaining); err != nil {
return err
}
ids := make([]int, len(subGroups))
for i, v := range subGroups {
ids[i] = v.GroupID
}
if insertIndex != nil {
// get the id of the sub-group at the insert index
insertPointID, err := s.getGroupIDAtOrderIndex(ctx, groupID, *insertIndex)
if err != nil {
return err
}
if insertPointID == nil {
// if the insert index is out of bounds, just assume adding to the end
return nil
}
// reorder the sub-groups
const insertAfter = false
if err := s.ReorderSubGroups(ctx, groupID, ids, *insertPointID, insertAfter); err != nil {
return err
}
}
return nil
}
func (s *groupRelationshipStore) RemoveSubGroups(ctx context.Context, groupID int, subGroupIDs []int) error {
const idIsContaining = true
return s.destroyJoins(ctx, groupID, subGroupIDs, idIsContaining)
}
func (s *groupRelationshipStore) ReorderSubGroups(ctx context.Context, groupID int, subGroupIDs []int, insertPointID int, insertAfter bool) error {
for _, id := range subGroupIDs {
if err := s.reorderSubGroup(ctx, groupID, id, insertPointID, insertAfter); err != nil {
return err
}
}
return nil
}