mirror of https://github.com/stashapp/stash.git
538 lines
13 KiB
Go
538 lines
13 KiB
Go
package sqlite
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
|
|
"github.com/stashapp/stash/pkg/logger"
|
|
"github.com/stashapp/stash/pkg/models"
|
|
)
|
|
|
|
const idColumn = "id"
|
|
|
|
type objectList interface {
|
|
Append(o interface{})
|
|
New() interface{}
|
|
}
|
|
|
|
type repository struct {
|
|
tx dbi
|
|
tableName string
|
|
idColumn string
|
|
}
|
|
|
|
func (r *repository) get(id int, dest interface{}) error {
|
|
stmt := fmt.Sprintf("SELECT * FROM %s WHERE %s = ? LIMIT 1", r.tableName, r.idColumn)
|
|
return r.tx.Get(dest, stmt, id)
|
|
}
|
|
|
|
func (r *repository) getAll(id int, f func(rows *sqlx.Rows) error) error {
|
|
stmt := fmt.Sprintf("SELECT * FROM %s WHERE %s = ?", r.tableName, r.idColumn)
|
|
return r.queryFunc(stmt, []interface{}{id}, false, f)
|
|
}
|
|
|
|
func (r *repository) insert(obj interface{}) (sql.Result, error) {
|
|
stmt := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", r.tableName, listKeys(obj, false), listKeys(obj, true))
|
|
return r.tx.NamedExec(stmt, obj)
|
|
}
|
|
|
|
func (r *repository) insertObject(obj interface{}, out interface{}) error {
|
|
result, err := r.insert(obj)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
id, err := result.LastInsertId()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return r.get(int(id), out)
|
|
}
|
|
|
|
func (r *repository) update(id int, obj interface{}, partial bool) error {
|
|
exists, err := r.exists(id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !exists {
|
|
return fmt.Errorf("%s %d does not exist in %s", r.idColumn, id, r.tableName)
|
|
}
|
|
|
|
stmt := fmt.Sprintf("UPDATE %s SET %s WHERE %s.%s = :id", r.tableName, updateSet(obj, partial), r.tableName, r.idColumn)
|
|
_, err = r.tx.NamedExec(stmt, obj)
|
|
|
|
return err
|
|
}
|
|
|
|
func (r *repository) updateMap(id int, m map[string]interface{}) error {
|
|
exists, err := r.exists(id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !exists {
|
|
return fmt.Errorf("%s %d does not exist in %s", r.idColumn, id, r.tableName)
|
|
}
|
|
|
|
stmt := fmt.Sprintf("UPDATE %s SET %s WHERE %s.%s = :id", r.tableName, updateSetMap(m), r.tableName, r.idColumn)
|
|
_, err = r.tx.NamedExec(stmt, m)
|
|
|
|
return err
|
|
}
|
|
|
|
func (r *repository) destroyExisting(ids []int) error {
|
|
for _, id := range ids {
|
|
exists, err := r.exists(id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !exists {
|
|
return fmt.Errorf("%s %d does not exist in %s", r.idColumn, id, r.tableName)
|
|
}
|
|
}
|
|
|
|
return r.destroy(ids)
|
|
}
|
|
|
|
func (r *repository) destroy(ids []int) error {
|
|
for _, id := range ids {
|
|
stmt := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", r.tableName, r.idColumn)
|
|
if _, err := r.tx.Exec(stmt, id); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *repository) exists(id int) (bool, error) {
|
|
stmt := fmt.Sprintf("SELECT %s FROM %s WHERE %s = ? LIMIT 1", r.idColumn, r.tableName, r.idColumn)
|
|
stmt = r.buildCountQuery(stmt)
|
|
|
|
c, err := r.runCountQuery(stmt, []interface{}{id})
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return c == 1, nil
|
|
}
|
|
|
|
func (r *repository) buildCountQuery(query string) string {
|
|
return "SELECT COUNT(*) as count FROM (" + query + ") as temp"
|
|
}
|
|
|
|
func (r *repository) runCountQuery(query string, args []interface{}) (int, error) {
|
|
result := struct {
|
|
Int int `db:"count"`
|
|
}{0}
|
|
|
|
// Perform query and fetch result
|
|
if err := r.tx.Get(&result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
return 0, err
|
|
}
|
|
|
|
return result.Int, nil
|
|
}
|
|
|
|
func (r *repository) runIdsQuery(query string, args []interface{}) ([]int, error) {
|
|
var result []struct {
|
|
Int int `db:"id"`
|
|
}
|
|
|
|
if err := r.tx.Select(&result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
return []int{}, err
|
|
}
|
|
|
|
vsm := make([]int, len(result))
|
|
for i, v := range result {
|
|
vsm[i] = v.Int
|
|
}
|
|
return vsm, nil
|
|
}
|
|
|
|
func (r *repository) runSumQuery(query string, args []interface{}) (float64, error) {
|
|
// Perform query and fetch result
|
|
result := struct {
|
|
Float64 float64 `db:"sum"`
|
|
}{0}
|
|
|
|
// Perform query and fetch result
|
|
if err := r.tx.Get(&result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
return 0, err
|
|
}
|
|
|
|
return result.Float64, nil
|
|
}
|
|
|
|
func (r *repository) queryFunc(query string, args []interface{}, single bool, f func(rows *sqlx.Rows) error) error {
|
|
logger.Tracef("SQL: %s, args: %v", query, args)
|
|
|
|
rows, err := r.tx.Queryx(query, args...)
|
|
|
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
if err := f(rows); err != nil {
|
|
return err
|
|
}
|
|
if single {
|
|
break
|
|
}
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *repository) query(query string, args []interface{}, out objectList) error {
|
|
return r.queryFunc(query, args, false, func(rows *sqlx.Rows) error {
|
|
object := out.New()
|
|
if err := rows.StructScan(object); err != nil {
|
|
return err
|
|
}
|
|
out.Append(object)
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (r *repository) queryStruct(query string, args []interface{}, out interface{}) error {
|
|
return r.queryFunc(query, args, true, func(rows *sqlx.Rows) error {
|
|
if err := rows.StructScan(out); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (r *repository) querySimple(query string, args []interface{}, out interface{}) error {
|
|
rows, err := r.tx.Queryx(query, args...)
|
|
|
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
if rows.Next() {
|
|
if err := rows.Scan(out); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *repository) buildQueryBody(body string, whereClauses []string, havingClauses []string) string {
|
|
if len(whereClauses) > 0 {
|
|
body = body + " WHERE " + strings.Join(whereClauses, " AND ") // TODO handle AND or OR
|
|
}
|
|
if len(havingClauses) > 0 {
|
|
body = body + " GROUP BY " + r.tableName + ".id "
|
|
body = body + " HAVING " + strings.Join(havingClauses, " AND ") // TODO handle AND or OR
|
|
}
|
|
|
|
return body
|
|
}
|
|
|
|
func (r *repository) executeFindQuery(body string, args []interface{}, sortAndPagination string, whereClauses []string, havingClauses []string, withClauses []string, recursiveWith bool) ([]int, int, error) {
|
|
body = r.buildQueryBody(body, whereClauses, havingClauses)
|
|
|
|
withClause := ""
|
|
if len(withClauses) > 0 {
|
|
var recursive string
|
|
if recursiveWith {
|
|
recursive = " RECURSIVE "
|
|
}
|
|
withClause = "WITH " + recursive + strings.Join(withClauses, ", ") + " "
|
|
}
|
|
|
|
countQuery := withClause + r.buildCountQuery(body)
|
|
idsQuery := withClause + body + sortAndPagination
|
|
|
|
// Perform query and fetch result
|
|
logger.Tracef("SQL: %s, args: %v", idsQuery, args)
|
|
|
|
var countResult int
|
|
var countErr error
|
|
var idsResult []int
|
|
var idsErr error
|
|
|
|
countResult, countErr = r.runCountQuery(countQuery, args)
|
|
idsResult, idsErr = r.runIdsQuery(idsQuery, args)
|
|
|
|
if countErr != nil {
|
|
return nil, 0, fmt.Errorf("error executing count query with SQL: %s, args: %v, error: %s", countQuery, args, countErr.Error())
|
|
}
|
|
if idsErr != nil {
|
|
return nil, 0, fmt.Errorf("error executing find query with SQL: %s, args: %v, error: %s", idsQuery, args, idsErr.Error())
|
|
}
|
|
|
|
return idsResult, countResult, nil
|
|
}
|
|
|
|
func (r *repository) newQuery() queryBuilder {
|
|
return queryBuilder{
|
|
repository: r,
|
|
}
|
|
}
|
|
|
|
func (r *repository) join(j joiner, as string, parentIDCol string) {
|
|
t := r.tableName
|
|
if as != "" {
|
|
t = as
|
|
}
|
|
j.addLeftJoin(r.tableName, as, fmt.Sprintf("%s.%s = %s", t, r.idColumn, parentIDCol))
|
|
}
|
|
|
|
//nolint:golint,unused
|
|
func (r *repository) innerJoin(j joiner, as string, parentIDCol string) {
|
|
t := r.tableName
|
|
if as != "" {
|
|
t = as
|
|
}
|
|
j.addInnerJoin(r.tableName, as, fmt.Sprintf("%s.%s = %s", t, r.idColumn, parentIDCol))
|
|
}
|
|
|
|
type joiner interface {
|
|
addLeftJoin(table, as, onClause string)
|
|
addInnerJoin(table, as, onClause string)
|
|
}
|
|
|
|
type joinRepository struct {
|
|
repository
|
|
fkColumn string
|
|
}
|
|
|
|
func (r *joinRepository) getIDs(id int) ([]int, error) {
|
|
query := fmt.Sprintf(`SELECT %s as id from %s WHERE %s = ?`, r.fkColumn, r.tableName, r.idColumn)
|
|
return r.runIdsQuery(query, []interface{}{id})
|
|
}
|
|
|
|
func (r *joinRepository) insert(id, foreignID int) (sql.Result, error) {
|
|
stmt := fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (?, ?)", r.tableName, r.idColumn, r.fkColumn)
|
|
return r.tx.Exec(stmt, id, foreignID)
|
|
}
|
|
|
|
func (r *joinRepository) replace(id int, foreignIDs []int) error {
|
|
if err := r.destroy([]int{id}); err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, fk := range foreignIDs {
|
|
if _, err := r.insert(id, fk); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type imageRepository struct {
|
|
repository
|
|
imageColumn string
|
|
}
|
|
|
|
func (r *imageRepository) get(id int) ([]byte, error) {
|
|
query := fmt.Sprintf("SELECT %s from %s WHERE %s = ?", r.imageColumn, r.tableName, r.idColumn)
|
|
var ret []byte
|
|
err := r.querySimple(query, []interface{}{id}, &ret)
|
|
return ret, err
|
|
}
|
|
|
|
func (r *imageRepository) replace(id int, image []byte) error {
|
|
if err := r.destroy([]int{id}); err != nil {
|
|
return err
|
|
}
|
|
|
|
stmt := fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (?, ?)", r.tableName, r.idColumn, r.imageColumn)
|
|
_, err := r.tx.Exec(stmt, id, image)
|
|
|
|
return err
|
|
}
|
|
|
|
type captionRepository struct {
|
|
repository
|
|
}
|
|
|
|
func (r *captionRepository) get(id int) ([]*models.SceneCaption, error) {
|
|
query := fmt.Sprintf("SELECT %s, %s, %s from %s WHERE %s = ?", sceneCaptionCodeColumn, sceneCaptionFilenameColumn, sceneCaptionTypeColumn, r.tableName, r.idColumn)
|
|
var ret []*models.SceneCaption
|
|
err := r.queryFunc(query, []interface{}{id}, false, func(rows *sqlx.Rows) error {
|
|
var captionCode string
|
|
var captionFilename string
|
|
var captionType string
|
|
|
|
if err := rows.Scan(&captionCode, &captionFilename, &captionType); err != nil {
|
|
return err
|
|
}
|
|
|
|
caption := &models.SceneCaption{
|
|
LanguageCode: captionCode,
|
|
Filename: captionFilename,
|
|
CaptionType: captionType,
|
|
}
|
|
ret = append(ret, caption)
|
|
return nil
|
|
})
|
|
return ret, err
|
|
}
|
|
|
|
func (r *captionRepository) insert(id int, caption *models.SceneCaption) (sql.Result, error) {
|
|
stmt := fmt.Sprintf("INSERT INTO %s (%s, %s, %s, %s) VALUES (?, ?, ?, ?)", r.tableName, r.idColumn, sceneCaptionCodeColumn, sceneCaptionFilenameColumn, sceneCaptionTypeColumn)
|
|
return r.tx.Exec(stmt, id, caption.LanguageCode, caption.Filename, caption.CaptionType)
|
|
}
|
|
|
|
func (r *captionRepository) replace(id int, captions []*models.SceneCaption) error {
|
|
if err := r.destroy([]int{id}); err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, caption := range captions {
|
|
if _, err := r.insert(id, caption); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type stringRepository struct {
|
|
repository
|
|
stringColumn string
|
|
}
|
|
|
|
func (r *stringRepository) get(id int) ([]string, error) {
|
|
query := fmt.Sprintf("SELECT %s from %s WHERE %s = ?", r.stringColumn, r.tableName, r.idColumn)
|
|
var ret []string
|
|
err := r.queryFunc(query, []interface{}{id}, false, func(rows *sqlx.Rows) error {
|
|
var out string
|
|
if err := rows.Scan(&out); err != nil {
|
|
return err
|
|
}
|
|
|
|
ret = append(ret, out)
|
|
return nil
|
|
})
|
|
return ret, err
|
|
}
|
|
|
|
func (r *stringRepository) insert(id int, s string) (sql.Result, error) {
|
|
stmt := fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (?, ?)", r.tableName, r.idColumn, r.stringColumn)
|
|
return r.tx.Exec(stmt, id, s)
|
|
}
|
|
|
|
func (r *stringRepository) replace(id int, newStrings []string) error {
|
|
if err := r.destroy([]int{id}); err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, s := range newStrings {
|
|
if _, err := r.insert(id, s); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type stashIDRepository struct {
|
|
repository
|
|
}
|
|
|
|
type stashIDs []*models.StashID
|
|
|
|
func (s *stashIDs) Append(o interface{}) {
|
|
*s = append(*s, o.(*models.StashID))
|
|
}
|
|
|
|
func (s *stashIDs) New() interface{} {
|
|
return &models.StashID{}
|
|
}
|
|
|
|
func (r *stashIDRepository) get(id int) ([]*models.StashID, error) {
|
|
query := fmt.Sprintf("SELECT stash_id, endpoint from %s WHERE %s = ?", r.tableName, r.idColumn)
|
|
var ret stashIDs
|
|
err := r.query(query, []interface{}{id}, &ret)
|
|
return []*models.StashID(ret), err
|
|
}
|
|
|
|
func (r *stashIDRepository) replace(id int, newIDs []models.StashID) error {
|
|
if err := r.destroy([]int{id}); err != nil {
|
|
return err
|
|
}
|
|
|
|
query := fmt.Sprintf("INSERT INTO %s (%s, endpoint, stash_id) VALUES (?, ?, ?)", r.tableName, r.idColumn)
|
|
for _, stashID := range newIDs {
|
|
_, err := r.tx.Exec(query, id, stashID.Endpoint, stashID.StashID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func listKeys(i interface{}, addPrefix bool) string {
|
|
var query []string
|
|
v := reflect.ValueOf(i)
|
|
for i := 0; i < v.NumField(); i++ {
|
|
// Get key for struct tag
|
|
rawKey := v.Type().Field(i).Tag.Get("db")
|
|
key := strings.Split(rawKey, ",")[0]
|
|
if key == "id" {
|
|
continue
|
|
}
|
|
if addPrefix {
|
|
key = ":" + key
|
|
}
|
|
query = append(query, key)
|
|
}
|
|
return strings.Join(query, ", ")
|
|
}
|
|
|
|
func updateSet(i interface{}, partial bool) string {
|
|
var query []string
|
|
v := reflect.ValueOf(i)
|
|
for i := 0; i < v.NumField(); i++ {
|
|
// Get key for struct tag
|
|
rawKey := v.Type().Field(i).Tag.Get("db")
|
|
key := strings.Split(rawKey, ",")[0]
|
|
if key == "id" {
|
|
continue
|
|
}
|
|
|
|
add := true
|
|
if partial {
|
|
reflectValue := reflect.ValueOf(v.Field(i).Interface())
|
|
add = !reflectValue.IsNil()
|
|
}
|
|
|
|
if add {
|
|
query = append(query, fmt.Sprintf("%s=:%s", key, key))
|
|
}
|
|
}
|
|
return strings.Join(query, ", ")
|
|
}
|
|
|
|
func updateSetMap(m map[string]interface{}) string {
|
|
var query []string
|
|
for k := range m {
|
|
query = append(query, fmt.Sprintf("%s=:%s", k, k))
|
|
}
|
|
return strings.Join(query, ", ")
|
|
}
|