mirror of https://github.com/stashapp/stash.git
130 lines
2.7 KiB
Go
130 lines
2.7 KiB
Go
package sqlite
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"runtime/debug"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
"github.com/mattn/go-sqlite3"
|
|
"github.com/stashapp/stash/pkg/logger"
|
|
"github.com/stashapp/stash/pkg/models"
|
|
)
|
|
|
|
// key is the unexported type used for this package's context keys. Because
// the type is private, values stored under these keys cannot be read or
// overwritten by code in other packages using plain ints.
type key int

// Context keys used by this package. The blank identifier consumes the zero
// value, so the named keys are 1, 2 and 3 respectively.
const (
	_ key = iota

	// txnKey holds the *sqlx.Tx created by Database.Begin.
	txnKey

	// dbKey holds the *sqlx.DB attached by Database.WithDatabase.
	dbKey

	// hookManagerKey is reserved for the transaction hook manager
	// (presumably set by hookManager.register, which lives elsewhere).
	hookManagerKey
)
|
|
|
|
func (db *Database) WithDatabase(ctx context.Context) (context.Context, error) {
|
|
// if we are already in a transaction or have a database already, just use it
|
|
if tx, _ := getDBReader(ctx); tx != nil {
|
|
return ctx, nil
|
|
}
|
|
|
|
return context.WithValue(ctx, dbKey, db.db), nil
|
|
}
|
|
|
|
func (db *Database) Begin(ctx context.Context) (context.Context, error) {
|
|
if tx, _ := getTx(ctx); tx != nil {
|
|
// log the stack trace so we can see
|
|
logger.Error(string(debug.Stack()))
|
|
|
|
return nil, fmt.Errorf("already in transaction")
|
|
}
|
|
|
|
tx, err := db.db.BeginTxx(ctx, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("beginning transaction: %w", err)
|
|
}
|
|
|
|
hookMgr := &hookManager{}
|
|
ctx = hookMgr.register(ctx)
|
|
|
|
return context.WithValue(ctx, txnKey, tx), nil
|
|
}
|
|
|
|
func (db *Database) Commit(ctx context.Context) error {
|
|
tx, err := getTx(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// execute post-commit hooks
|
|
db.executePostCommitHooks(ctx)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (db *Database) Rollback(ctx context.Context) error {
|
|
tx, err := getTx(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := tx.Rollback(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// execute post-rollback hooks
|
|
db.executePostRollbackHooks(ctx)
|
|
|
|
return nil
|
|
}
|
|
|
|
func getTx(ctx context.Context) (*sqlx.Tx, error) {
|
|
tx, ok := ctx.Value(txnKey).(*sqlx.Tx)
|
|
if !ok || tx == nil {
|
|
return nil, fmt.Errorf("not in transaction")
|
|
}
|
|
return tx, nil
|
|
}
|
|
|
|
func getDBReader(ctx context.Context) (dbReader, error) {
|
|
// get transaction first if present
|
|
tx, ok := ctx.Value(txnKey).(*sqlx.Tx)
|
|
if !ok || tx == nil {
|
|
// try to get database if present
|
|
db, ok := ctx.Value(dbKey).(*sqlx.DB)
|
|
if !ok || db == nil {
|
|
return nil, fmt.Errorf("not in transaction")
|
|
}
|
|
return db, nil
|
|
}
|
|
return tx, nil
|
|
}
|
|
|
|
func (db *Database) IsLocked(err error) bool {
|
|
var sqliteError sqlite3.Error
|
|
if errors.As(err, &sqliteError) {
|
|
return sqliteError.Code == sqlite3.ErrBusy
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (db *Database) TxnRepository() models.Repository {
|
|
return models.Repository{
|
|
TxnManager: db,
|
|
File: db.File,
|
|
Folder: db.Folder,
|
|
Gallery: db.Gallery,
|
|
Image: db.Image,
|
|
Movie: MovieReaderWriter,
|
|
Performer: PerformerReaderWriter,
|
|
Scene: db.Scene,
|
|
SceneMarker: SceneMarkerReaderWriter,
|
|
ScrapedItem: ScrapedItemReaderWriter,
|
|
Studio: StudioReaderWriter,
|
|
Tag: TagReaderWriter,
|
|
SavedFilter: SavedFilterReaderWriter,
|
|
}
|
|
}
|