mirror of https://github.com/stashapp/stash.git
365 lines
12 KiB
Go
365 lines
12 KiB
Go
package sqlite
|
|
|
|
import (
|
|
"fmt"
|
|
"math/rand"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/stashapp/stash/pkg/models"
|
|
)
|
|
|
|
func selectAll(tableName string) string {
|
|
idColumn := getColumn(tableName, "*")
|
|
return "SELECT " + idColumn + " FROM " + tableName + " "
|
|
}
|
|
|
|
func distinctIDs(qb *queryBuilder, tableName string) {
|
|
qb.addColumn("DISTINCT " + getColumn(tableName, "id"))
|
|
qb.from = tableName
|
|
}
|
|
|
|
func selectIDs(qb *queryBuilder, tableName string) {
|
|
qb.addColumn(getColumn(tableName, "id"))
|
|
qb.from = tableName
|
|
}
|
|
|
|
// getColumn qualifies columnName with tableName, producing "table.column".
// No quoting or escaping is applied to either part.
func getColumn(tableName string, columnName string) string {
	const separator = "."
	return tableName + separator + columnName
}
|
|
|
|
func getPagination(findFilter *models.FindFilterType) string {
|
|
if findFilter == nil {
|
|
panic("nil find filter for pagination")
|
|
}
|
|
|
|
if findFilter.IsGetAll() {
|
|
return " "
|
|
}
|
|
|
|
return getPaginationSQL(findFilter.GetPage(), findFilter.GetPageSize())
|
|
}
|
|
|
|
// getPaginationSQL renders " LIMIT <perPage> OFFSET <offset> " for a 1-based
// page number, where offset = (page-1)*perPage. A page below 1 is not clamped
// and yields a negative offset.
// NOTE(review): SQLite treats a negative OFFSET as zero; confirm before
// relying on that with another backend.
func getPaginationSQL(page int, perPage int) string {
	limit := strconv.Itoa(perPage)
	offset := strconv.Itoa((page - 1) * perPage)
	return " LIMIT " + limit + " OFFSET " + offset + " "
}
|
|
|
|
// randomSeedPrefix marks a sort value of the form "random_<seed>". <seed> is
// a caller-supplied (per the original comment, UI-supplied) unsigned decimal
// integer, so the same seed produces the same pseudo-random ordering.
const randomSeedPrefix = "random_" // prefix for random sort

// sortOptions is the allow-list of plain sort values accepted by validateSort.
type sortOptions []string

// validateSort reports whether sort is acceptable. It accepts either
// randomSeedPrefix followed by a seed that parses as a base-10 uint64, or an
// exact member of o. Any other value yields a descriptive error.
func (o sortOptions) validateSort(sort string) error {
	if strings.HasPrefix(sort, randomSeedPrefix) {
		seedStr := strings.TrimPrefix(sort, randomSeedPrefix)
		if _, err := strconv.ParseUint(seedStr, 10, 64); err != nil {
			return fmt.Errorf("invalid random seed: %s", seedStr)
		}
		return nil
	}

	for i := range o {
		if o[i] == sort {
			return nil
		}
	}

	return fmt.Errorf("invalid sort: %s", sort)
}
|
|
|
|
// getSortDirection normalises direction to an SQL sort keyword. After
// trimming surrounding whitespace, the input is compared case-insensitively
// with "DESC". A match returns "DESC"; anything else (including "ASC",
// unknown values, and the empty string) returns "ASC".
//
// The result is always one of two fixed literals, so it is safe to splice
// into an ORDER BY clause even when direction comes from user input.
func getSortDirection(direction string) string {
	if strings.EqualFold(strings.TrimSpace(direction), "DESC") {
		return "DESC"
	}
	return "ASC"
}
|
|
func getSort(sort string, direction string, tableName string) string {
|
|
direction = getSortDirection(direction)
|
|
|
|
switch {
|
|
case strings.HasSuffix(sort, "_count"):
|
|
var relationTableName = strings.TrimSuffix(sort, "_count") // TODO: pluralize?
|
|
colName := getColumn(relationTableName, "id")
|
|
return " ORDER BY COUNT(distinct " + colName + ") " + direction
|
|
case strings.Compare(sort, "filesize") == 0:
|
|
colName := getColumn(tableName, "size")
|
|
return " ORDER BY " + colName + " " + direction
|
|
case strings.HasPrefix(sort, randomSeedPrefix):
|
|
// seed as a parameter from the UI
|
|
seedStr := sort[len(randomSeedPrefix):]
|
|
seed, err := strconv.ParseUint(seedStr, 10, 64)
|
|
if err != nil {
|
|
// fallback to a random seed
|
|
seed = rand.Uint64()
|
|
}
|
|
return getRandomSort(tableName, direction, seed)
|
|
case strings.Compare(sort, "random") == 0:
|
|
return getRandomSort(tableName, direction, rand.Uint64())
|
|
default:
|
|
colName := getColumn(tableName, sort)
|
|
if strings.Contains(sort, ".") {
|
|
colName = sort
|
|
}
|
|
if strings.Compare(sort, "name") == 0 {
|
|
return " ORDER BY " + colName + " COLLATE NATURAL_CI " + direction
|
|
}
|
|
if strings.Compare(sort, "title") == 0 {
|
|
return " ORDER BY " + colName + " COLLATE NATURAL_CI " + direction
|
|
}
|
|
|
|
return " ORDER BY " + colName + " " + direction
|
|
}
|
|
}
|
|
|
|
func getRandomSort(tableName string, direction string, seed uint64) string {
|
|
// cap seed at 10^8
|
|
seed %= 1e8
|
|
|
|
colName := getColumn(tableName, "id")
|
|
|
|
// https://stackoverflow.com/questions/21949795#comment33255354_21949859
|
|
// p1 := 52959209
|
|
// p2 := 1047483763
|
|
// p3 := 2147483647
|
|
// n := <colName>
|
|
// ORDER BY ((n+seed)*(n+seed)*p1 + (n+seed)*p2) % p3
|
|
// since sqlite converts overflowing numbers to reals, a custom db function that uses uints with overflow should be faster,
|
|
// however in practice the overhead of calling a custom function vastly outweighs the benefits
|
|
return fmt.Sprintf(" ORDER BY mod((%[1]s + %[2]d) * (%[1]s + %[2]d) * 52959209 + (%[1]s + %[2]d) * 1047483763, 2147483647) %[3]s", colName, seed, direction)
|
|
}
|
|
|
|
func getCountSort(primaryTable, joinTable, primaryFK, direction string) string {
|
|
return fmt.Sprintf(" ORDER BY (SELECT COUNT(*) FROM %s AS sort WHERE sort.%s = %s.id) %s", joinTable, primaryFK, primaryTable, getSortDirection(direction))
|
|
}
|
|
|
|
func getStringSearchClause(columns []string, q string, not bool) sqlClause {
|
|
var likeClauses []string
|
|
var args []interface{}
|
|
|
|
notStr := ""
|
|
binaryType := " OR "
|
|
if not {
|
|
notStr = " NOT"
|
|
binaryType = " AND "
|
|
}
|
|
q = strings.TrimSpace(q)
|
|
trimmedQuery := strings.Trim(q, "\"")
|
|
|
|
if trimmedQuery == q {
|
|
q = regexp.MustCompile(`\s+`).ReplaceAllString(q, " ")
|
|
queryWords := strings.Split(q, " ")
|
|
// Search for any word
|
|
for _, word := range queryWords {
|
|
for _, column := range columns {
|
|
likeClauses = append(likeClauses, column+notStr+" LIKE ?")
|
|
args = append(args, "%"+word+"%")
|
|
}
|
|
}
|
|
} else {
|
|
// Search the exact query
|
|
for _, column := range columns {
|
|
likeClauses = append(likeClauses, column+notStr+" LIKE ?")
|
|
args = append(args, "%"+trimmedQuery+"%")
|
|
}
|
|
}
|
|
likes := strings.Join(likeClauses, binaryType)
|
|
|
|
return makeClause("("+likes+")", args...)
|
|
}
|
|
|
|
func getEnumSearchClause(column string, enumVals []string, not bool) sqlClause {
|
|
var args []interface{}
|
|
|
|
notStr := ""
|
|
if not {
|
|
notStr = " NOT"
|
|
}
|
|
|
|
clause := fmt.Sprintf("(%s%s IN %s)", column, notStr, getInBinding(len(enumVals)))
|
|
for _, enumVal := range enumVals {
|
|
args = append(args, enumVal)
|
|
}
|
|
|
|
return makeClause(clause, args...)
|
|
}
|
|
|
|
// getInBinding returns a parenthesised, comma-separated list of length "?"
// placeholders for use after IN. For example, 3 gives "(?, ?, ?)" and 0
// gives "()". It panics if length is negative.
func getInBinding(length int) string {
	placeholders := make([]string, length)
	for i := range placeholders {
		placeholders[i] = "?"
	}
	return "(" + strings.Join(placeholders, ", ") + ")"
}
|
|
|
|
func getIntCriterionWhereClause(column string, input models.IntCriterionInput) (string, []interface{}) {
|
|
return getIntWhereClause(column, input.Modifier, input.Value, input.Value2)
|
|
}
|
|
|
|
func getIntWhereClause(column string, modifier models.CriterionModifier, value int, upper *int) (string, []interface{}) {
|
|
if upper == nil {
|
|
u := 0
|
|
upper = &u
|
|
}
|
|
|
|
args := []interface{}{value, *upper}
|
|
return getNumericWhereClause(column, modifier, args)
|
|
}
|
|
|
|
func getFloatCriterionWhereClause(column string, input models.FloatCriterionInput) (string, []interface{}) {
|
|
return getFloatWhereClause(column, input.Modifier, input.Value, input.Value2)
|
|
}
|
|
|
|
func getFloatWhereClause(column string, modifier models.CriterionModifier, value float64, upper *float64) (string, []interface{}) {
|
|
if upper == nil {
|
|
u := 0.0
|
|
upper = &u
|
|
}
|
|
|
|
args := []interface{}{value, *upper}
|
|
return getNumericWhereClause(column, modifier, args)
|
|
}
|
|
|
|
func getNumericWhereClause(column string, modifier models.CriterionModifier, args []interface{}) (string, []interface{}) {
|
|
singleArgs := args[0:1]
|
|
|
|
switch modifier {
|
|
case models.CriterionModifierIsNull:
|
|
return fmt.Sprintf("%s IS NULL", column), nil
|
|
case models.CriterionModifierNotNull:
|
|
return fmt.Sprintf("%s IS NOT NULL", column), nil
|
|
case models.CriterionModifierEquals:
|
|
return fmt.Sprintf("%s = ?", column), singleArgs
|
|
case models.CriterionModifierNotEquals:
|
|
return fmt.Sprintf("%s != ?", column), singleArgs
|
|
case models.CriterionModifierBetween:
|
|
return fmt.Sprintf("%s BETWEEN ? AND ?", column), args
|
|
case models.CriterionModifierNotBetween:
|
|
return fmt.Sprintf("%s NOT BETWEEN ? AND ?", column), args
|
|
case models.CriterionModifierLessThan:
|
|
return fmt.Sprintf("%s < ?", column), singleArgs
|
|
case models.CriterionModifierGreaterThan:
|
|
return fmt.Sprintf("%s > ?", column), singleArgs
|
|
}
|
|
|
|
panic("unsupported numeric modifier type " + modifier)
|
|
}
|
|
|
|
func getDateCriterionWhereClause(column string, input models.DateCriterionInput) (string, []interface{}) {
|
|
return getDateWhereClause(column, input.Modifier, input.Value, input.Value2)
|
|
}
|
|
|
|
func getDateWhereClause(column string, modifier models.CriterionModifier, value string, upper *string) (string, []interface{}) {
|
|
if upper == nil {
|
|
u := time.Now().AddDate(0, 0, 1).Format(time.RFC3339)
|
|
upper = &u
|
|
}
|
|
|
|
args := []interface{}{value}
|
|
betweenArgs := []interface{}{value, *upper}
|
|
|
|
switch modifier {
|
|
case models.CriterionModifierIsNull:
|
|
return fmt.Sprintf("(%s IS NULL OR %s = '')", column, column), nil
|
|
case models.CriterionModifierNotNull:
|
|
return fmt.Sprintf("(%s IS NOT NULL AND %s != '')", column, column), nil
|
|
case models.CriterionModifierEquals:
|
|
return fmt.Sprintf("%s = ?", column), args
|
|
case models.CriterionModifierNotEquals:
|
|
return fmt.Sprintf("%s != ?", column), args
|
|
case models.CriterionModifierBetween:
|
|
return fmt.Sprintf("%s BETWEEN ? AND ?", column), betweenArgs
|
|
case models.CriterionModifierNotBetween:
|
|
return fmt.Sprintf("%s NOT BETWEEN ? AND ?", column), betweenArgs
|
|
case models.CriterionModifierLessThan:
|
|
return fmt.Sprintf("%s < ?", column), args
|
|
case models.CriterionModifierGreaterThan:
|
|
return fmt.Sprintf("%s > ?", column), args
|
|
}
|
|
|
|
panic("unsupported date modifier type")
|
|
}
|
|
|
|
func getTimestampCriterionWhereClause(column string, input models.TimestampCriterionInput) (string, []interface{}) {
|
|
return getTimestampWhereClause(column, input.Modifier, input.Value, input.Value2)
|
|
}
|
|
|
|
func getTimestampWhereClause(column string, modifier models.CriterionModifier, value string, upper *string) (string, []interface{}) {
|
|
if upper == nil {
|
|
u := time.Now().AddDate(0, 0, 1).Format(time.RFC3339)
|
|
upper = &u
|
|
}
|
|
|
|
args := []interface{}{value}
|
|
betweenArgs := []interface{}{value, *upper}
|
|
|
|
switch modifier {
|
|
case models.CriterionModifierIsNull:
|
|
return fmt.Sprintf("%s IS NULL", column), nil
|
|
case models.CriterionModifierNotNull:
|
|
return fmt.Sprintf("%s IS NOT NULL", column), nil
|
|
case models.CriterionModifierEquals:
|
|
return fmt.Sprintf("%s = ?", column), args
|
|
case models.CriterionModifierNotEquals:
|
|
return fmt.Sprintf("%s != ?", column), args
|
|
case models.CriterionModifierBetween:
|
|
return fmt.Sprintf("%s BETWEEN ? AND ?", column), betweenArgs
|
|
case models.CriterionModifierNotBetween:
|
|
return fmt.Sprintf("%s NOT BETWEEN ? AND ?", column), betweenArgs
|
|
case models.CriterionModifierLessThan:
|
|
return fmt.Sprintf("%s < ?", column), args
|
|
case models.CriterionModifierGreaterThan:
|
|
return fmt.Sprintf("%s > ?", column), args
|
|
}
|
|
|
|
panic("unsupported date modifier type")
|
|
}
|
|
|
|
// returns where clause and having clause
|
|
func getMultiCriterionClause(primaryTable, foreignTable, joinTable, primaryFK, foreignFK string, criterion *models.MultiCriterionInput) (string, string) {
|
|
whereClause := ""
|
|
havingClause := ""
|
|
switch criterion.Modifier {
|
|
case models.CriterionModifierIncludes:
|
|
// includes any of the provided ids
|
|
if joinTable != "" {
|
|
whereClause = joinTable + "." + foreignFK + " IN " + getInBinding(len(criterion.Value))
|
|
} else {
|
|
whereClause = foreignTable + ".id IN " + getInBinding(len(criterion.Value))
|
|
}
|
|
case models.CriterionModifierIncludesAll:
|
|
// includes all of the provided ids
|
|
if joinTable != "" {
|
|
whereClause = joinTable + "." + foreignFK + " IN " + getInBinding(len(criterion.Value))
|
|
havingClause = "count(distinct " + joinTable + "." + foreignFK + ") IS " + strconv.Itoa(len(criterion.Value))
|
|
} else {
|
|
whereClause = foreignTable + ".id IN " + getInBinding(len(criterion.Value))
|
|
havingClause = "count(distinct " + foreignTable + ".id) IS " + strconv.Itoa(len(criterion.Value))
|
|
}
|
|
case models.CriterionModifierExcludes:
|
|
// excludes all of the provided ids
|
|
if joinTable != "" {
|
|
whereClause = primaryTable + ".id not in (select " + joinTable + "." + primaryFK + " from " + joinTable + " where " + joinTable + "." + foreignFK + " in " + getInBinding(len(criterion.Value)) + ")"
|
|
} else {
|
|
whereClause = "not exists (select s.id from " + primaryTable + " as s where s.id = " + primaryTable + ".id and s." + foreignFK + " in " + getInBinding(len(criterion.Value)) + ")"
|
|
}
|
|
}
|
|
|
|
return whereClause, havingClause
|
|
}
|
|
|
|
func getCountCriterionClause(primaryTable, joinTable, primaryFK string, criterion models.IntCriterionInput) (string, []interface{}) {
|
|
lhs := fmt.Sprintf("(SELECT COUNT(*) FROM %s s WHERE s.%s = %s.id)", joinTable, primaryFK, primaryTable)
|
|
return getIntCriterionWhereClause(lhs, criterion)
|
|
}
|
|
|
|
// coalesce wraps column as COALESCE(column, ''), so a NULL value evaluates to
// the empty string.
func coalesce(column string) string {
	const fallback = "''"
	return fmt.Sprintf("COALESCE(%s, %s)", column, fallback)
}
|
|
|
|
// like wraps v in '%' wildcards to form a "contains" LIKE pattern. Any '%' or
// '_' already in v is left unescaped.
func like(v string) string {
	const wildcard = "%"
	return wildcard + v + wildcard
}
|