stash/pkg/sqlite/transaction.go

136 lines
2.7 KiB
Go

package sqlite
import (
"context"
"errors"
"fmt"
"runtime/debug"
"github.com/jmoiron/sqlx"
"github.com/mattn/go-sqlite3"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
)
// key is the unexported type used for this package's context keys; because
// the type is unexported, keys from other packages cannot collide with them.
type key int

// Context keys for values this package stores on a context.Context:
// txnKey holds the *sqlx.Tx set by Begin, dbKey holds the database handle set
// by WithDatabase (read back as *sqlx.DB), and writableKey holds the bool
// writable flag set by Begin. Numbering starts at 1, so the zero value of key
// is never used as a real key.
const (
	_ key = iota
	txnKey
	dbKey
	writableKey
)
// WithDatabase ensures ctx carries something getDBReader can query through.
// If ctx already holds a transaction or a database handle, it is returned
// unchanged. Otherwise a derived context is returned that holds db.readDB
// under dbKey. The error result is currently always nil.
func (db *Database) WithDatabase(ctx context.Context) (context.Context, error) {
	// An existing transaction or handle takes priority over attaching a new one.
	if existing, _ := getDBReader(ctx); existing == nil {
		ctx = context.WithValue(ctx, dbKey, db.readDB)
	}

	return ctx, nil
}
// Begin opens a new transaction and returns a derived context. That context
// carries the transaction under txnKey and the writable flag under
// writableKey. The transaction is started on db.writeDB when writable is true
// and on db.readDB otherwise.
//
// Nested transactions are rejected. If ctx already carries a transaction,
// the calling goroutine's stack trace is logged (to help locate the offending
// caller) and an "already in transaction" error is returned with a nil
// context.
func (db *Database) Begin(ctx context.Context, writable bool) (context.Context, error) {
	if current, _ := getTx(ctx); current != nil {
		// Record where the nested Begin came from before refusing it.
		logger.Error(string(debug.Stack()))
		return nil, errors.New("already in transaction")
	}

	source := db.readDB
	if writable {
		source = db.writeDB
	}

	tx, err := source.BeginTxx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("beginning transaction: %w", err)
	}

	txCtx := context.WithValue(ctx, writableKey, writable)
	txCtx = context.WithValue(txCtx, txnKey, tx)
	return txCtx, nil
}
// Commit commits the transaction carried by ctx.
//
// If ctx holds no transaction, it returns the lookup error from getTx and
// does nothing else. Otherwise it returns the result of committing, and
// db.txnComplete runs after the commit attempt whether or not it succeeded.
func (db *Database) Commit(ctx context.Context) error {
	tx, err := getTx(ctx)
	if err != nil {
		return err
	}

	defer db.txnComplete(ctx)
	return tx.Commit()
}
// Rollback aborts the transaction carried by ctx.
//
// If ctx holds no transaction, it returns the lookup error from getTx and
// does nothing else. Otherwise it returns the result of rolling back, and
// db.txnComplete runs after the rollback attempt whether or not it succeeded.
func (db *Database) Rollback(ctx context.Context) error {
	tx, lookupErr := getTx(ctx)
	if lookupErr != nil {
		return lookupErr
	}
	defer db.txnComplete(ctx)

	return tx.Rollback()
}
// txnComplete is called by Commit and Rollback via defer. It therefore runs
// once a transaction has been committed or rolled back, whether or not that
// operation succeeded. It is currently a no-op; ctx is accepted but unused.
func (db *Database) txnComplete(ctx context.Context) {
}
// getTx returns the transaction that Begin stored in ctx under txnKey. It
// returns a "not in transaction" error when ctx holds no *sqlx.Tx under that
// key, or holds a nil one.
func getTx(ctx context.Context) (*sqlx.Tx, error) {
	if tx, ok := ctx.Value(txnKey).(*sqlx.Tx); ok && tx != nil {
		return tx, nil
	}

	return nil, errors.New("not in transaction")
}
// getDBReader returns the handle queries should run against for ctx.
//
// Lookup order:
//  1. A non-nil transaction stored under txnKey (by Begin) takes precedence.
//  2. Otherwise, a non-nil *sqlx.DB stored under dbKey (by WithDatabase) is used.
//
// If neither is present, it returns a nil reader and a "not in transaction"
// error.
func getDBReader(ctx context.Context) (dbReader, error) {
	if tx, ok := ctx.Value(txnKey).(*sqlx.Tx); ok && tx != nil {
		return tx, nil
	}

	if sqlDB, ok := ctx.Value(dbKey).(*sqlx.DB); ok && sqlDB != nil {
		return sqlDB, nil
	}

	return nil, errors.New("not in transaction")
}
// IsLocked reports whether err, or any error in its wrap chain, is a
// go-sqlite3 sqlite3.Error whose Code is sqlite3.ErrBusy. It returns false
// for nil errors and for errors that did not come from go-sqlite3.
func (db *Database) IsLocked(err error) bool {
	var sqliteErr sqlite3.Error
	return errors.As(err, &sqliteErr) && sqliteErr.Code == sqlite3.ErrBusy
}
// Repository returns a models.Repository wired to this Database. The
// Database itself serves as the TxnManager, and every store is taken from
// the matching member of db (note that Blob comes from db.Blobs).
func (db *Database) Repository() models.Repository {
	repo := models.Repository{TxnManager: db}

	repo.Blob = db.Blobs
	repo.File = db.File
	repo.Folder = db.Folder
	repo.Gallery = db.Gallery
	repo.GalleryChapter = db.GalleryChapter
	repo.Image = db.Image
	repo.Group = db.Group
	repo.Performer = db.Performer
	repo.Scene = db.Scene
	repo.SceneMarker = db.SceneMarker
	repo.Studio = db.Studio
	repo.Tag = db.Tag
	repo.SavedFilter = db.SavedFilter

	return repo
}