mirror of https://github.com/stashapp/stash.git
1211 lines
28 KiB
Go
1211 lines
28 KiB
Go
package sqlite
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/doug-martin/goqu/v9"
|
|
"github.com/doug-martin/goqu/v9/exp"
|
|
"github.com/jmoiron/sqlx"
|
|
"gopkg.in/guregu/null.v4"
|
|
|
|
"github.com/stashapp/stash/pkg/logger"
|
|
"github.com/stashapp/stash/pkg/models"
|
|
"github.com/stashapp/stash/pkg/sliceutil"
|
|
)
|
|
|
|
type table struct {
|
|
table exp.IdentifierExpression
|
|
idColumn exp.IdentifierExpression
|
|
}
|
|
|
|
type NotFoundError struct {
|
|
ID int
|
|
Table string
|
|
}
|
|
|
|
func (e *NotFoundError) Error() string {
|
|
return fmt.Sprintf("id %d does not exist in %s", e.ID, e.Table)
|
|
}
|
|
|
|
func (t *table) insert(ctx context.Context, o interface{}) (sql.Result, error) {
|
|
q := dialect.Insert(t.table).Prepared(true).Rows(o)
|
|
ret, err := exec(ctx, q)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("inserting into %s: %w", t.table.GetTable(), err)
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (t *table) insertID(ctx context.Context, o interface{}) (int, error) {
|
|
result, err := t.insert(ctx, o)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
ret, err := result.LastInsertId()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return int(ret), nil
|
|
}
|
|
|
|
func (t *table) updateByID(ctx context.Context, id interface{}, o interface{}) error {
|
|
q := dialect.Update(t.table).Prepared(true).Set(o).Where(t.byID(id))
|
|
|
|
if _, err := exec(ctx, q); err != nil {
|
|
return fmt.Errorf("updating %s: %w", t.table.GetTable(), err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *table) byID(id interface{}) exp.Expression {
|
|
return t.idColumn.Eq(id)
|
|
}
|
|
|
|
func (t *table) byIDInts(ids ...int) exp.Expression {
|
|
ii := make([]interface{}, len(ids))
|
|
for i, id := range ids {
|
|
ii[i] = id
|
|
}
|
|
return t.idColumn.In(ii...)
|
|
}
|
|
|
|
func (t *table) idExists(ctx context.Context, id interface{}) (bool, error) {
|
|
q := dialect.Select(goqu.COUNT("*")).From(t.table).Where(t.byID(id))
|
|
|
|
var count int
|
|
if err := querySimple(ctx, q, &count); err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return count == 1, nil
|
|
}
|
|
|
|
func (t *table) checkIDExists(ctx context.Context, id int) error {
|
|
exists, err := t.idExists(ctx, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !exists {
|
|
return &NotFoundError{ID: id, Table: t.table.GetTable()}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *table) destroyExisting(ctx context.Context, ids []int) error {
|
|
for _, id := range ids {
|
|
exists, err := t.idExists(ctx, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !exists {
|
|
return &NotFoundError{
|
|
ID: id,
|
|
Table: t.table.GetTable(),
|
|
}
|
|
}
|
|
}
|
|
|
|
return t.destroy(ctx, ids)
|
|
}
|
|
|
|
func (t *table) destroy(ctx context.Context, ids []int) error {
|
|
q := dialect.Delete(t.table).Where(t.idColumn.In(ids))
|
|
|
|
if _, err := exec(ctx, q); err != nil {
|
|
return fmt.Errorf("destroying %s: %w", t.table.GetTable(), err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *table) join(j joiner, as string, parentIDCol string) {
|
|
tableName := t.table.GetTable()
|
|
tt := tableName
|
|
if as != "" {
|
|
tt = as
|
|
}
|
|
j.addLeftJoin(tableName, as, fmt.Sprintf("%s.%s = %s", tt, t.idColumn.GetCol(), parentIDCol))
|
|
}
|
|
|
|
// func (t *table) get(ctx context.Context, q *goqu.SelectDataset, dest interface{}) error {
|
|
// tx, err := getTx(ctx)
|
|
// if err != nil {
|
|
// return err
|
|
// }
|
|
|
|
// sql, args, err := q.ToSQL()
|
|
// if err != nil {
|
|
// return fmt.Errorf("generating sql: %w", err)
|
|
// }
|
|
|
|
// return tx.GetContext(ctx, dest, sql, args...)
|
|
// }
|
|
|
|
type joinTable struct {
|
|
table
|
|
fkColumn exp.IdentifierExpression
|
|
|
|
// required for ordering
|
|
foreignTable *table
|
|
orderBy exp.OrderedExpression
|
|
}
|
|
|
|
func (t *joinTable) invert() *joinTable {
|
|
return &joinTable{
|
|
table: table{
|
|
table: t.table.table,
|
|
idColumn: t.fkColumn,
|
|
},
|
|
fkColumn: t.table.idColumn,
|
|
foreignTable: t.foreignTable,
|
|
orderBy: t.orderBy,
|
|
}
|
|
}
|
|
|
|
func (t *joinTable) get(ctx context.Context, id int) ([]int, error) {
|
|
q := dialect.Select(t.fkColumn).From(t.table.table).Where(t.idColumn.Eq(id))
|
|
|
|
if t.orderBy != nil {
|
|
if t.foreignTable != nil {
|
|
q = q.InnerJoin(t.foreignTable.table, goqu.On(t.foreignTable.idColumn.Eq(t.fkColumn)))
|
|
}
|
|
q = q.Order(t.orderBy)
|
|
}
|
|
|
|
const single = false
|
|
var ret []int
|
|
if err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
|
|
var fk int
|
|
if err := rows.Scan(&fk); err != nil {
|
|
return err
|
|
}
|
|
|
|
ret = append(ret, fk)
|
|
|
|
return nil
|
|
}); err != nil {
|
|
return nil, fmt.Errorf("getting foreign keys from %s: %w", t.table.table.GetTable(), err)
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (t *joinTable) insertJoins(ctx context.Context, id int, foreignIDs []int) error {
|
|
// manually create SQL so that we can prepare once
|
|
// ignore duplicates
|
|
q := fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (?, ?) ON CONFLICT (%[2]s, %s) DO NOTHING", t.table.table.GetTable(), t.idColumn.GetCol(), t.fkColumn.GetCol())
|
|
|
|
stmt, err := dbWrapper.Prepare(ctx, q)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
// eliminate duplicates
|
|
foreignIDs = sliceutil.AppendUniques(nil, foreignIDs)
|
|
|
|
for _, fk := range foreignIDs {
|
|
if _, err := dbWrapper.ExecStmt(ctx, stmt, id, fk); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *joinTable) replaceJoins(ctx context.Context, id int, foreignIDs []int) error {
|
|
if err := t.destroy(ctx, []int{id}); err != nil {
|
|
return err
|
|
}
|
|
|
|
return t.insertJoins(ctx, id, foreignIDs)
|
|
}
|
|
|
|
func (t *joinTable) addJoins(ctx context.Context, id int, foreignIDs []int) error {
|
|
// get existing foreign keys
|
|
fks, err := t.get(ctx, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// only add foreign keys that are not already present
|
|
foreignIDs = sliceutil.Exclude(foreignIDs, fks)
|
|
return t.insertJoins(ctx, id, foreignIDs)
|
|
}
|
|
|
|
func (t *joinTable) destroyJoins(ctx context.Context, id int, foreignIDs []int) error {
|
|
q := dialect.Delete(t.table.table).Where(
|
|
t.idColumn.Eq(id),
|
|
t.fkColumn.In(foreignIDs),
|
|
)
|
|
|
|
if _, err := exec(ctx, q); err != nil {
|
|
return fmt.Errorf("destroying %s: %w", t.table.table.GetTable(), err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *joinTable) modifyJoins(ctx context.Context, id int, foreignIDs []int, mode models.RelationshipUpdateMode) error {
|
|
switch mode {
|
|
case models.RelationshipUpdateModeSet:
|
|
return t.replaceJoins(ctx, id, foreignIDs)
|
|
case models.RelationshipUpdateModeAdd:
|
|
return t.addJoins(ctx, id, foreignIDs)
|
|
case models.RelationshipUpdateModeRemove:
|
|
return t.destroyJoins(ctx, id, foreignIDs)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type stashIDTable struct {
|
|
table
|
|
}
|
|
|
|
type stashIDRow struct {
|
|
StashID null.String `db:"stash_id"`
|
|
Endpoint null.String `db:"endpoint"`
|
|
UpdatedAt Timestamp `db:"updated_at"`
|
|
}
|
|
|
|
func (r *stashIDRow) resolve() models.StashID {
|
|
return models.StashID{
|
|
StashID: r.StashID.String,
|
|
Endpoint: r.Endpoint.String,
|
|
UpdatedAt: r.UpdatedAt.Timestamp,
|
|
}
|
|
}
|
|
|
|
func (t *stashIDTable) get(ctx context.Context, id int) ([]models.StashID, error) {
|
|
q := dialect.Select("endpoint", "stash_id", "updated_at").From(t.table.table).Where(t.idColumn.Eq(id))
|
|
|
|
const single = false
|
|
var ret []models.StashID
|
|
if err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
|
|
var v stashIDRow
|
|
if err := rows.StructScan(&v); err != nil {
|
|
return err
|
|
}
|
|
|
|
ret = append(ret, v.resolve())
|
|
|
|
return nil
|
|
}); err != nil {
|
|
return nil, fmt.Errorf("getting stash ids from %s: %w", t.table.table.GetTable(), err)
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (t *stashIDTable) insertJoin(ctx context.Context, id int, v models.StashID) (sql.Result, error) {
|
|
var q = dialect.Insert(t.table.table).Cols(t.idColumn.GetCol(), "endpoint", "stash_id", "updated_at").Vals(
|
|
goqu.Vals{id, v.Endpoint, v.StashID, v.UpdatedAt},
|
|
)
|
|
ret, err := exec(ctx, q)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("inserting into %s: %w", t.table.table.GetTable(), err)
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (t *stashIDTable) insertJoins(ctx context.Context, id int, v []models.StashID) error {
|
|
for _, fk := range v {
|
|
if _, err := t.insertJoin(ctx, id, fk); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *stashIDTable) replaceJoins(ctx context.Context, id int, v []models.StashID) error {
|
|
if err := t.destroy(ctx, []int{id}); err != nil {
|
|
return err
|
|
}
|
|
|
|
return t.insertJoins(ctx, id, v)
|
|
}
|
|
|
|
func (t *stashIDTable) addJoins(ctx context.Context, id int, v []models.StashID) error {
|
|
// get existing foreign keys
|
|
fks, err := t.get(ctx, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// only add values that are not already present
|
|
var filtered []models.StashID
|
|
for _, vv := range v {
|
|
for _, e := range fks {
|
|
if vv.Endpoint == e.Endpoint {
|
|
continue
|
|
}
|
|
|
|
filtered = append(filtered, vv)
|
|
}
|
|
}
|
|
return t.insertJoins(ctx, id, filtered)
|
|
}
|
|
|
|
func (t *stashIDTable) destroyJoins(ctx context.Context, id int, v []models.StashID) error {
|
|
for _, vv := range v {
|
|
q := dialect.Delete(t.table.table).Where(
|
|
t.idColumn.Eq(id),
|
|
t.table.table.Col("endpoint").Eq(vv.Endpoint),
|
|
t.table.table.Col("stash_id").Eq(vv.StashID),
|
|
)
|
|
|
|
if _, err := exec(ctx, q); err != nil {
|
|
return fmt.Errorf("destroying %s: %w", t.table.table.GetTable(), err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *stashIDTable) modifyJoins(ctx context.Context, id int, v []models.StashID, mode models.RelationshipUpdateMode) error {
|
|
switch mode {
|
|
case models.RelationshipUpdateModeSet:
|
|
return t.replaceJoins(ctx, id, v)
|
|
case models.RelationshipUpdateModeAdd:
|
|
return t.addJoins(ctx, id, v)
|
|
case models.RelationshipUpdateModeRemove:
|
|
return t.destroyJoins(ctx, id, v)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type stringTable struct {
|
|
table
|
|
stringColumn exp.IdentifierExpression
|
|
}
|
|
|
|
func (t *stringTable) get(ctx context.Context, id int) ([]string, error) {
|
|
q := dialect.Select(t.stringColumn).From(t.table.table).Where(t.idColumn.Eq(id))
|
|
|
|
const single = false
|
|
var ret []string
|
|
if err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
|
|
var v string
|
|
if err := rows.Scan(&v); err != nil {
|
|
return err
|
|
}
|
|
|
|
ret = append(ret, v)
|
|
|
|
return nil
|
|
}); err != nil {
|
|
return nil, fmt.Errorf("getting stash ids from %s: %w", t.table.table.GetTable(), err)
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (t *stringTable) insertJoin(ctx context.Context, id int, v string) (sql.Result, error) {
|
|
q := dialect.Insert(t.table.table).Cols(t.idColumn.GetCol(), t.stringColumn.GetCol()).Vals(
|
|
goqu.Vals{id, v},
|
|
)
|
|
ret, err := exec(ctx, q)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("inserting into %s: %w", t.table.table.GetTable(), err)
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (t *stringTable) insertJoins(ctx context.Context, id int, v []string) error {
|
|
for _, fk := range v {
|
|
if _, err := t.insertJoin(ctx, id, fk); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *stringTable) replaceJoins(ctx context.Context, id int, v []string) error {
|
|
if err := t.destroy(ctx, []int{id}); err != nil {
|
|
return err
|
|
}
|
|
|
|
return t.insertJoins(ctx, id, v)
|
|
}
|
|
|
|
func (t *stringTable) addJoins(ctx context.Context, id int, v []string) error {
|
|
// get existing foreign keys
|
|
existing, err := t.get(ctx, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// only add values that are not already present
|
|
filtered := sliceutil.Exclude(v, existing)
|
|
return t.insertJoins(ctx, id, filtered)
|
|
}
|
|
|
|
func (t *stringTable) destroyJoins(ctx context.Context, id int, v []string) error {
|
|
for _, vv := range v {
|
|
q := dialect.Delete(t.table.table).Where(
|
|
t.idColumn.Eq(id),
|
|
t.stringColumn.Eq(vv),
|
|
)
|
|
|
|
if _, err := exec(ctx, q); err != nil {
|
|
return fmt.Errorf("destroying %s: %w", t.table.table.GetTable(), err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *stringTable) modifyJoins(ctx context.Context, id int, v []string, mode models.RelationshipUpdateMode) error {
|
|
switch mode {
|
|
case models.RelationshipUpdateModeSet:
|
|
return t.replaceJoins(ctx, id, v)
|
|
case models.RelationshipUpdateModeAdd:
|
|
return t.addJoins(ctx, id, v)
|
|
case models.RelationshipUpdateModeRemove:
|
|
return t.destroyJoins(ctx, id, v)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type orderedValueTable[T comparable] struct {
|
|
table
|
|
valueColumn exp.IdentifierExpression
|
|
}
|
|
|
|
func (t *orderedValueTable[T]) positionColumn() exp.IdentifierExpression {
|
|
const positionColumn = "position"
|
|
return t.table.table.Col(positionColumn)
|
|
}
|
|
|
|
func (t *orderedValueTable[T]) get(ctx context.Context, id int) ([]T, error) {
|
|
q := dialect.Select(t.valueColumn).From(t.table.table).Where(t.idColumn.Eq(id)).Order(t.positionColumn().Asc())
|
|
|
|
const single = false
|
|
var ret []T
|
|
if err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
|
|
var v T
|
|
if err := rows.Scan(&v); err != nil {
|
|
return err
|
|
}
|
|
|
|
ret = append(ret, v)
|
|
|
|
return nil
|
|
}); err != nil {
|
|
return nil, fmt.Errorf("getting stash ids from %s: %w", t.table.table.GetTable(), err)
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (t *orderedValueTable[T]) insertJoin(ctx context.Context, id int, position int, v T) (sql.Result, error) {
|
|
q := dialect.Insert(t.table.table).Cols(t.idColumn.GetCol(), t.positionColumn().GetCol(), t.valueColumn.GetCol()).Vals(
|
|
goqu.Vals{id, position, v},
|
|
)
|
|
ret, err := exec(ctx, q)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("inserting into %s: %w", t.table.table.GetTable(), err)
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (t *orderedValueTable[T]) insertJoins(ctx context.Context, id int, startPos int, v []T) error {
|
|
for i, fk := range v {
|
|
if _, err := t.insertJoin(ctx, id, i+startPos, fk); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *orderedValueTable[T]) replaceJoins(ctx context.Context, id int, v []T) error {
|
|
if err := t.destroy(ctx, []int{id}); err != nil {
|
|
return err
|
|
}
|
|
|
|
const startPos = 0
|
|
return t.insertJoins(ctx, id, startPos, v)
|
|
}
|
|
|
|
func (t *orderedValueTable[T]) addJoins(ctx context.Context, id int, v []T) error {
|
|
// get existing foreign keys
|
|
existing, err := t.get(ctx, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// only add values that are not already present
|
|
filtered := sliceutil.Exclude(v, existing)
|
|
|
|
if len(filtered) == 0 {
|
|
return nil
|
|
}
|
|
|
|
startPos := len(existing)
|
|
return t.insertJoins(ctx, id, startPos, filtered)
|
|
}
|
|
|
|
func (t *orderedValueTable[T]) destroyJoins(ctx context.Context, id int, v []T) error {
|
|
existing, err := t.get(ctx, id)
|
|
if err != nil {
|
|
return fmt.Errorf("getting existing %s: %w", t.table.table.GetTable(), err)
|
|
}
|
|
|
|
newValue := sliceutil.Exclude(existing, v)
|
|
if len(newValue) == len(existing) {
|
|
return nil
|
|
}
|
|
|
|
return t.replaceJoins(ctx, id, newValue)
|
|
}
|
|
|
|
func (t *orderedValueTable[T]) modifyJoins(ctx context.Context, id int, v []T, mode models.RelationshipUpdateMode) error {
|
|
switch mode {
|
|
case models.RelationshipUpdateModeSet:
|
|
return t.replaceJoins(ctx, id, v)
|
|
case models.RelationshipUpdateModeAdd:
|
|
return t.addJoins(ctx, id, v)
|
|
case models.RelationshipUpdateModeRemove:
|
|
return t.destroyJoins(ctx, id, v)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type scenesGroupsTable struct {
|
|
table
|
|
}
|
|
|
|
type groupsScenesRow struct {
|
|
SceneID null.Int `db:"scene_id"`
|
|
GroupID null.Int `db:"group_id"`
|
|
SceneIndex null.Int `db:"scene_index"`
|
|
}
|
|
|
|
func (r groupsScenesRow) resolve(sceneID int) models.GroupsScenes {
|
|
return models.GroupsScenes{
|
|
GroupID: int(r.GroupID.Int64),
|
|
SceneIndex: nullIntPtr(r.SceneIndex),
|
|
}
|
|
}
|
|
|
|
func (t *scenesGroupsTable) get(ctx context.Context, id int) ([]models.GroupsScenes, error) {
|
|
q := dialect.Select("group_id", "scene_index").From(t.table.table).Where(t.idColumn.Eq(id))
|
|
|
|
const single = false
|
|
var ret []models.GroupsScenes
|
|
if err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
|
|
var v groupsScenesRow
|
|
if err := rows.StructScan(&v); err != nil {
|
|
return err
|
|
}
|
|
|
|
ret = append(ret, v.resolve(id))
|
|
|
|
return nil
|
|
}); err != nil {
|
|
return nil, fmt.Errorf("getting scene groups from %s: %w", t.table.table.GetTable(), err)
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (t *scenesGroupsTable) insertJoin(ctx context.Context, id int, v models.GroupsScenes) (sql.Result, error) {
|
|
q := dialect.Insert(t.table.table).Cols(t.idColumn.GetCol(), "group_id", "scene_index").Vals(
|
|
goqu.Vals{id, v.GroupID, intFromPtr(v.SceneIndex)},
|
|
)
|
|
ret, err := exec(ctx, q)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("inserting into %s: %w", t.table.table.GetTable(), err)
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (t *scenesGroupsTable) insertJoins(ctx context.Context, id int, v []models.GroupsScenes) error {
|
|
for _, fk := range v {
|
|
if _, err := t.insertJoin(ctx, id, fk); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *scenesGroupsTable) replaceJoins(ctx context.Context, id int, v []models.GroupsScenes) error {
|
|
if err := t.destroy(ctx, []int{id}); err != nil {
|
|
return err
|
|
}
|
|
|
|
return t.insertJoins(ctx, id, v)
|
|
}
|
|
|
|
func (t *scenesGroupsTable) addJoins(ctx context.Context, id int, v []models.GroupsScenes) error {
|
|
// get existing foreign keys
|
|
fks, err := t.get(ctx, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// only add values that are not already present
|
|
var filtered []models.GroupsScenes
|
|
for _, vv := range v {
|
|
found := false
|
|
|
|
for _, e := range fks {
|
|
if vv.GroupID == e.GroupID {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
filtered = append(filtered, vv)
|
|
}
|
|
}
|
|
return t.insertJoins(ctx, id, filtered)
|
|
}
|
|
|
|
func (t *scenesGroupsTable) destroyJoins(ctx context.Context, id int, v []models.GroupsScenes) error {
|
|
for _, vv := range v {
|
|
q := dialect.Delete(t.table.table).Where(
|
|
t.idColumn.Eq(id),
|
|
t.table.table.Col("group_id").Eq(vv.GroupID),
|
|
)
|
|
|
|
if _, err := exec(ctx, q); err != nil {
|
|
return fmt.Errorf("destroying %s: %w", t.table.table.GetTable(), err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *scenesGroupsTable) modifyJoins(ctx context.Context, id int, v []models.GroupsScenes, mode models.RelationshipUpdateMode) error {
|
|
switch mode {
|
|
case models.RelationshipUpdateModeSet:
|
|
return t.replaceJoins(ctx, id, v)
|
|
case models.RelationshipUpdateModeAdd:
|
|
return t.addJoins(ctx, id, v)
|
|
case models.RelationshipUpdateModeRemove:
|
|
return t.destroyJoins(ctx, id, v)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type imageGalleriesTable struct {
|
|
joinTable
|
|
}
|
|
|
|
func (t *imageGalleriesTable) setCover(ctx context.Context, id int, galleryID int) error {
|
|
if err := t.resetCover(ctx, galleryID); err != nil {
|
|
return err
|
|
}
|
|
|
|
table := t.table.table
|
|
|
|
q := dialect.Update(table).Prepared(true).Set(goqu.Record{
|
|
"cover": true,
|
|
}).Where(t.idColumn.Eq(id), table.Col(galleryIDColumn).Eq(galleryID))
|
|
|
|
if _, err := exec(ctx, q); err != nil {
|
|
return fmt.Errorf("setting cover flag in %s: %w", t.table.table.GetTable(), err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *imageGalleriesTable) resetCover(ctx context.Context, galleryID int) error {
|
|
table := t.table.table
|
|
|
|
q := dialect.Update(table).Prepared(true).Set(goqu.Record{
|
|
"cover": false,
|
|
}).Where(
|
|
table.Col(galleryIDColumn).Eq(galleryID),
|
|
table.Col("cover").Eq(true),
|
|
)
|
|
|
|
if _, err := exec(ctx, q); err != nil {
|
|
return fmt.Errorf("unsetting cover flags in %s: %w", t.table.table.GetTable(), err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type relatedFilesTable struct {
|
|
table
|
|
}
|
|
|
|
// type scenesFilesRow struct {
|
|
// SceneID int `db:"scene_id"`
|
|
// Primary bool `db:"primary"`
|
|
// FileID models.FileID `db:"file_id"`
|
|
// }
|
|
|
|
// get returns the file IDs related to the provided scene ID
|
|
// the primary file is returned first
|
|
func (t *relatedFilesTable) get(ctx context.Context, id int) ([]models.FileID, error) {
|
|
q := dialect.Select("file_id").From(t.table.table).Where(t.idColumn.Eq(id)).Order(t.table.table.Col("primary").Desc())
|
|
|
|
const single = false
|
|
var ret []models.FileID
|
|
if err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
|
|
var v models.FileID
|
|
if err := rows.Scan(&v); err != nil {
|
|
return err
|
|
}
|
|
|
|
ret = append(ret, v)
|
|
|
|
return nil
|
|
}); err != nil {
|
|
return nil, fmt.Errorf("getting related files from %s: %w", t.table.table.GetTable(), err)
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (t *relatedFilesTable) insertJoin(ctx context.Context, id int, primary bool, fileID models.FileID) error {
|
|
q := dialect.Insert(t.table.table).Cols(t.idColumn.GetCol(), "primary", "file_id").Vals(
|
|
goqu.Vals{id, primary, fileID},
|
|
)
|
|
_, err := exec(ctx, q)
|
|
if err != nil {
|
|
return fmt.Errorf("inserting into %s: %w", t.table.table.GetTable(), err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *relatedFilesTable) insertJoins(ctx context.Context, id int, firstPrimary bool, fileIDs []models.FileID) error {
|
|
for i, fk := range fileIDs {
|
|
if err := t.insertJoin(ctx, id, firstPrimary && i == 0, fk); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *relatedFilesTable) replaceJoins(ctx context.Context, id int, fileIDs []models.FileID) error {
|
|
if err := t.destroy(ctx, []int{id}); err != nil {
|
|
return err
|
|
}
|
|
|
|
const firstPrimary = true
|
|
return t.insertJoins(ctx, id, firstPrimary, fileIDs)
|
|
}
|
|
|
|
// destroyJoins destroys all entries in the table with the provided fileIDs
|
|
func (t *relatedFilesTable) destroyJoins(ctx context.Context, fileIDs []models.FileID) error {
|
|
q := dialect.Delete(t.table.table).Where(t.table.table.Col("file_id").In(fileIDs))
|
|
|
|
if _, err := exec(ctx, q); err != nil {
|
|
return fmt.Errorf("destroying file joins in %s: %w", t.table.table.GetTable(), err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *relatedFilesTable) setPrimary(ctx context.Context, id int, fileID models.FileID) error {
|
|
table := t.table.table
|
|
|
|
q := dialect.Update(table).Prepared(true).Set(goqu.Record{
|
|
"primary": 0,
|
|
}).Where(t.idColumn.Eq(id), table.Col(fileIDColumn).Neq(fileID))
|
|
|
|
if _, err := exec(ctx, q); err != nil {
|
|
return fmt.Errorf("unsetting primary flags in %s: %w", t.table.table.GetTable(), err)
|
|
}
|
|
|
|
q = dialect.Update(table).Prepared(true).Set(goqu.Record{
|
|
"primary": 1,
|
|
}).Where(t.idColumn.Eq(id), table.Col(fileIDColumn).Eq(fileID))
|
|
|
|
if _, err := exec(ctx, q); err != nil {
|
|
return fmt.Errorf("setting primary flag in %s: %w", t.table.table.GetTable(), err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type viewHistoryTable struct {
|
|
table
|
|
dateColumn exp.IdentifierExpression
|
|
}
|
|
|
|
func (t *viewHistoryTable) getDates(ctx context.Context, id int) ([]time.Time, error) {
|
|
table := t.table.table
|
|
|
|
q := dialect.Select(
|
|
t.dateColumn,
|
|
).From(table).Where(
|
|
t.idColumn.Eq(id),
|
|
).Order(t.dateColumn.Desc())
|
|
|
|
const single = false
|
|
var ret []time.Time
|
|
if err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
|
|
var date Timestamp
|
|
if err := rows.Scan(&date); err != nil {
|
|
return err
|
|
}
|
|
ret = append(ret, date.Timestamp)
|
|
return nil
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (t *viewHistoryTable) getManyDates(ctx context.Context, ids []int) ([][]time.Time, error) {
|
|
table := t.table.table
|
|
|
|
q := dialect.Select(
|
|
t.idColumn,
|
|
t.dateColumn,
|
|
).From(table).Where(
|
|
t.idColumn.In(ids),
|
|
).Order(t.dateColumn.Desc())
|
|
|
|
ret := make([][]time.Time, len(ids))
|
|
idToIndex := idToIndexMap(ids)
|
|
|
|
const single = false
|
|
if err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
|
|
var id int
|
|
var date Timestamp
|
|
if err := rows.Scan(&id, &date); err != nil {
|
|
return err
|
|
}
|
|
|
|
idx := idToIndex[id]
|
|
ret[idx] = append(ret[idx], date.Timestamp)
|
|
return nil
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (t *viewHistoryTable) getLastDate(ctx context.Context, id int) (*time.Time, error) {
|
|
table := t.table.table
|
|
q := dialect.Select(t.dateColumn).From(table).Where(
|
|
t.idColumn.Eq(id),
|
|
).Order(t.dateColumn.Desc()).Limit(1)
|
|
|
|
var date NullTimestamp
|
|
if err := querySimple(ctx, q, &date); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return date.TimePtr(), nil
|
|
}
|
|
|
|
func (t *viewHistoryTable) getManyLastDate(ctx context.Context, ids []int) ([]*time.Time, error) {
|
|
table := t.table.table
|
|
|
|
q := dialect.Select(
|
|
t.idColumn,
|
|
goqu.MAX(t.dateColumn),
|
|
).From(table).Where(
|
|
t.idColumn.In(ids),
|
|
).GroupBy(t.idColumn)
|
|
|
|
ret := make([]*time.Time, len(ids))
|
|
idToIndex := idToIndexMap(ids)
|
|
|
|
const single = false
|
|
if err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
|
|
var id int
|
|
|
|
// MAX appears to return a string, so handle it manually
|
|
var dateString string
|
|
|
|
if err := rows.Scan(&id, &dateString); err != nil {
|
|
return err
|
|
}
|
|
|
|
t, err := time.Parse(TimestampFormat, dateString)
|
|
if err != nil {
|
|
return fmt.Errorf("parsing date %v: %w", dateString, err)
|
|
}
|
|
|
|
idx := idToIndex[id]
|
|
ret[idx] = &t
|
|
return nil
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (t *viewHistoryTable) getCount(ctx context.Context, id int) (int, error) {
|
|
table := t.table.table
|
|
q := dialect.Select(goqu.COUNT("*")).From(table).Where(t.idColumn.Eq(id))
|
|
|
|
const single = true
|
|
var ret int
|
|
if err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
|
|
if err := rows.Scan(&ret); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (t *viewHistoryTable) getManyCount(ctx context.Context, ids []int) ([]int, error) {
|
|
table := t.table.table
|
|
|
|
q := dialect.Select(
|
|
t.idColumn,
|
|
goqu.COUNT(t.dateColumn),
|
|
).From(table).Where(
|
|
t.idColumn.In(ids),
|
|
).GroupBy(t.idColumn)
|
|
|
|
ret := make([]int, len(ids))
|
|
idToIndex := idToIndexMap(ids)
|
|
|
|
const single = false
|
|
if err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
|
|
var id int
|
|
var count int
|
|
if err := rows.Scan(&id, &count); err != nil {
|
|
return err
|
|
}
|
|
|
|
idx := idToIndex[id]
|
|
ret[idx] = count
|
|
return nil
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (t *viewHistoryTable) getAllCount(ctx context.Context) (int, error) {
|
|
table := t.table.table
|
|
q := dialect.Select(goqu.COUNT("*")).From(table)
|
|
|
|
const single = true
|
|
var ret int
|
|
if err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
|
|
if err := rows.Scan(&ret); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (t *viewHistoryTable) getUniqueCount(ctx context.Context) (int, error) {
|
|
table := t.table.table
|
|
q := dialect.Select(goqu.COUNT(goqu.DISTINCT(t.idColumn))).From(table)
|
|
|
|
const single = true
|
|
var ret int
|
|
if err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
|
|
if err := rows.Scan(&ret); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (t *viewHistoryTable) addDates(ctx context.Context, id int, dates []time.Time) ([]time.Time, error) {
|
|
table := t.table.table
|
|
|
|
if len(dates) == 0 {
|
|
dates = []time.Time{time.Now()}
|
|
}
|
|
|
|
for _, d := range dates {
|
|
q := dialect.Insert(table).Cols(t.idColumn.GetCol(), t.dateColumn.GetCol()).Vals(
|
|
// convert all dates to UTC
|
|
goqu.Vals{id, UTCTimestamp{Timestamp{d}}},
|
|
)
|
|
|
|
if _, err := exec(ctx, q); err != nil {
|
|
return nil, fmt.Errorf("inserting into %s: %w", table.GetTable(), err)
|
|
}
|
|
}
|
|
|
|
return t.getDates(ctx, id)
|
|
}
|
|
|
|
func (t *viewHistoryTable) deleteDates(ctx context.Context, id int, dates []time.Time) ([]time.Time, error) {
|
|
table := t.table.table
|
|
|
|
mostRecent := false
|
|
if len(dates) == 0 {
|
|
mostRecent = true
|
|
dates = []time.Time{time.Now()}
|
|
}
|
|
|
|
for _, date := range dates {
|
|
var subquery *goqu.SelectDataset
|
|
if mostRecent {
|
|
// delete the most recent
|
|
subquery = dialect.Select("rowid").From(table).Where(
|
|
t.idColumn.Eq(id),
|
|
).Order(t.dateColumn.Desc()).Limit(1)
|
|
} else {
|
|
subquery = dialect.Select("rowid").From(table).Where(
|
|
t.idColumn.Eq(id),
|
|
t.dateColumn.Eq(UTCTimestamp{Timestamp{date}}),
|
|
).Limit(1)
|
|
}
|
|
|
|
q := dialect.Delete(table).Where(goqu.I("rowid").Eq(subquery))
|
|
|
|
if _, err := exec(ctx, q); err != nil {
|
|
return nil, fmt.Errorf("deleting from %s: %w", table.GetTable(), err)
|
|
}
|
|
}
|
|
|
|
return t.getDates(ctx, id)
|
|
}
|
|
|
|
func (t *viewHistoryTable) deleteAllDates(ctx context.Context, id int) (int, error) {
|
|
table := t.table.table
|
|
q := dialect.Delete(table).Where(t.idColumn.Eq(id))
|
|
|
|
if _, err := exec(ctx, q); err != nil {
|
|
return 0, fmt.Errorf("resetting dates for id %v: %w", id, err)
|
|
}
|
|
|
|
return t.getCount(ctx, id)
|
|
}
|
|
|
|
type sqler interface {
|
|
ToSQL() (sql string, params []interface{}, err error)
|
|
}
|
|
|
|
func exec(ctx context.Context, stmt sqler) (sql.Result, error) {
|
|
tx, err := getTx(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sql, args, err := stmt.ToSQL()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generating sql: %w", err)
|
|
}
|
|
|
|
logger.Tracef("SQL: %s [%v]", sql, args)
|
|
ret, err := tx.ExecContext(ctx, sql, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("executing `%s` [%v]: %w", sql, args, err)
|
|
}
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func count(ctx context.Context, q *goqu.SelectDataset) (int, error) {
|
|
var count int
|
|
if err := querySimple(ctx, q, &count); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
func queryFunc(ctx context.Context, query *goqu.SelectDataset, single bool, f func(rows *sqlx.Rows) error) error {
|
|
q, args, err := query.ToSQL()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
rows, err := dbWrapper.QueryxContext(ctx, q, args...)
|
|
|
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
return fmt.Errorf("querying `%s` [%v]: %w", q, args, err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
if err := f(rows); err != nil {
|
|
return err
|
|
}
|
|
if single {
|
|
break
|
|
}
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func querySimple(ctx context.Context, query *goqu.SelectDataset, out interface{}) error {
|
|
q, args, err := query.ToSQL()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
rows, err := dbWrapper.QueryxContext(ctx, q, args...)
|
|
if err != nil {
|
|
return fmt.Errorf("querying `%s` [%v]: %w", q, args, err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
if rows.Next() {
|
|
if err := rows.Scan(out); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// func cols(table exp.IdentifierExpression, cols []string) []interface{} {
|
|
// var ret []interface{}
|
|
// for _, c := range cols {
|
|
// ret = append(ret, table.Col(c))
|
|
// }
|
|
// return ret
|
|
// }
|