stash/pkg/sqlite/transaction.go

130 lines
2.7 KiB
Go

package sqlite
import (
"context"
"errors"
"fmt"
"runtime/debug"
"github.com/jmoiron/sqlx"
"github.com/mattn/go-sqlite3"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
)
// key is an unexported type for this package's context keys, so values stored
// under these keys cannot collide with keys defined by other packages.
type key int

// Context keys used by this package. Numbering starts at 1 (the zero value of
// key is skipped).
const (
	_ key = iota
	txnKey         // holds the *sqlx.Tx stored by Begin
	dbKey          // holds the database handle stored by WithDatabase
	hookManagerKey // NOTE(review): presumably used by hookManager (defined elsewhere) — confirm
)
// WithDatabase returns a context carrying this Database's underlying handle so
// that reads can be issued without an explicit transaction.
//
// If ctx already carries a transaction or a database handle, ctx is returned
// as-is so the existing reader keeps being used. The returned error is
// currently always nil.
func (db *Database) WithDatabase(ctx context.Context) (context.Context, error) {
	// The error from getDBReader only means "nothing stored yet", which is
	// exactly the case handled below, so it is intentionally ignored.
	existing, _ := getDBReader(ctx)
	if existing == nil {
		ctx = context.WithValue(ctx, dbKey, db.db)
	}
	return ctx, nil
}
// Begin starts a new transaction and returns a derived context that carries
// it, plus a freshly registered hook manager for that transaction.
//
// Nested transactions are not supported. If ctx already carries a
// transaction, the current stack trace is logged to help locate the offending
// caller, and an error is returned.
//
// On any error the caller's ctx is returned unchanged rather than nil. A
// caller that writes `ctx, err = db.Begin(ctx)` therefore never ends up
// holding a nil context.
//
// ctx is passed to BeginTxx. Per database/sql, cancelling it rolls the
// transaction back.
func (db *Database) Begin(ctx context.Context) (context.Context, error) {
	if tx, _ := getTx(ctx); tx != nil {
		// log the stack trace so we can see where the nested Begin came from
		logger.Error(string(debug.Stack()))
		return ctx, fmt.Errorf("already in transaction")
	}

	tx, err := db.db.BeginTxx(ctx, nil)
	if err != nil {
		return ctx, fmt.Errorf("beginning transaction: %w", err)
	}

	// each transaction gets its own hook manager; Commit/Rollback run its hooks
	hookMgr := &hookManager{}
	ctx = hookMgr.register(ctx)
	return context.WithValue(ctx, txnKey, tx), nil
}
// Commit commits the transaction carried by ctx. The post-commit hooks run
// only after a successful commit.
//
// It returns an error if ctx carries no transaction or if the commit fails.
// In both cases no hooks are run.
func (db *Database) Commit(ctx context.Context) error {
	tx, err := getTx(ctx)
	if err == nil {
		err = tx.Commit()
	}
	if err != nil {
		return err
	}

	db.executePostCommitHooks(ctx)
	return nil
}
// Rollback rolls back the transaction carried by ctx. The post-rollback hooks
// run only after a successful rollback.
//
// It returns an error if ctx carries no transaction or if the rollback fails.
// In both cases no hooks are run.
func (db *Database) Rollback(ctx context.Context) error {
	tx, err := getTx(ctx)
	if err == nil {
		err = tx.Rollback()
	}
	if err != nil {
		return err
	}

	db.executePostRollbackHooks(ctx)
	return nil
}
// getTx returns the transaction stored in ctx under txnKey by Begin. It
// returns an error if ctx carries no non-nil transaction.
func getTx(ctx context.Context) (*sqlx.Tx, error) {
	if found, ok := ctx.Value(txnKey).(*sqlx.Tx); ok && found != nil {
		return found, nil
	}
	return nil, errors.New("not in transaction")
}
// getDBReader returns the reader to run queries against. It prefers the
// transaction stored in ctx by Begin and falls back to the database handle
// stored by WithDatabase.
//
// It returns an error if ctx carries neither.
func getDBReader(ctx context.Context) (dbReader, error) {
	// an active transaction takes precedence over a plain database handle
	if tx, ok := ctx.Value(txnKey).(*sqlx.Tx); ok && tx != nil {
		return tx, nil
	}
	if handle, ok := ctx.Value(dbKey).(*sqlx.DB); ok && handle != nil {
		return handle, nil
	}
	// literal nil (not a typed nil pointer) so callers' nil checks work
	return nil, errors.New("not in transaction")
}
// IsLocked reports whether err, or any error it wraps, is a sqlite3.Error
// whose primary result code is sqlite3.ErrBusy (SQLITE_BUSY).
//
// A nil or non-SQLite error yields false.
func (db *Database) IsLocked(err error) bool {
	var sqliteErr sqlite3.Error
	return errors.As(err, &sqliteErr) && sqliteErr.Code == sqlite3.ErrBusy
}
// TxnRepository returns a models.Repository that uses this Database as its
// transaction manager. Its File, Folder, Gallery, Image and Scene stores come
// from this Database. The remaining readers/writers are the package-level
// instances.
func (db *Database) TxnRepository() models.Repository {
	var repo models.Repository
	repo.TxnManager = db

	// stores owned by the Database instance
	repo.File = db.File
	repo.Folder = db.Folder
	repo.Gallery = db.Gallery
	repo.Image = db.Image
	repo.Scene = db.Scene

	// package-level reader/writers
	repo.Movie = MovieReaderWriter
	repo.Performer = PerformerReaderWriter
	repo.SceneMarker = SceneMarkerReaderWriter
	repo.ScrapedItem = ScrapedItemReaderWriter
	repo.Studio = StudioReaderWriter
	repo.Tag = TagReaderWriter
	repo.SavedFilter = SavedFilterReaderWriter

	return repo
}