// stash/pkg/sqlite/transaction.go

package sqlite
import (
"context"
"fmt"
"github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/models"
)
// key is an unexported type for context keys defined in this package, so
// they cannot collide with keys defined by other packages.
type key int

// txnKey is the context key under which the active transaction is stored.
// Its value is 1, so the zero value of key never matches it.
const txnKey key = iota + 1
// Begin starts a new transaction on the underlying database and returns a
// context derived from ctx that carries it under txnKey.
//
// It returns an error, and a nil context, when ctx already carries a
// transaction or when the database fails to begin one.
func (db *Database) Begin(ctx context.Context) (context.Context, error) {
	// The error from getTx only reports that no transaction is present,
	// which is the expected state here, so it is intentionally discarded.
	existing, _ := getTx(ctx)
	if existing != nil {
		return nil, fmt.Errorf("already in transaction")
	}

	newTx, err := db.db.BeginTxx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("beginning transaction: %w", err)
	}

	txCtx := context.WithValue(ctx, txnKey, newTx)
	return txCtx, nil
}
// Commit commits the transaction carried by ctx.
//
// It returns the lookup error from getTx when ctx carries no transaction;
// otherwise it returns whatever the transaction's Commit returns.
func (db *Database) Commit(ctx context.Context) error {
	activeTx, lookupErr := getTx(ctx)
	if lookupErr != nil {
		return lookupErr
	}

	return activeTx.Commit()
}
// Rollback rolls back the transaction carried by ctx.
//
// It returns the lookup error from getTx when ctx carries no transaction;
// otherwise it returns whatever the transaction's Rollback returns.
func (db *Database) Rollback(ctx context.Context) error {
	activeTx, lookupErr := getTx(ctx)
	if lookupErr != nil {
		return lookupErr
	}

	return activeTx.Rollback()
}
// getTx extracts the transaction stored in ctx under txnKey.
//
// It returns a "not in transaction" error when no value is stored under
// that key, when the stored value is not a *sqlx.Tx, or when it is a nil
// *sqlx.Tx.
func getTx(ctx context.Context) (*sqlx.Tx, error) {
	if stored, isTx := ctx.Value(txnKey).(*sqlx.Tx); isTx && stored != nil {
		return stored, nil
	}

	return nil, fmt.Errorf("not in transaction")
}
// TxnRepository builds a models.Repository whose TxnManager is this
// Database and whose remaining fields are set to the package-level
// reader/writer values of the same name.
func (db *Database) TxnRepository() models.Repository {
	repo := models.Repository{TxnManager: db}

	repo.Gallery = GalleryReaderWriter
	repo.Image = ImageReaderWriter
	repo.Movie = MovieReaderWriter
	repo.Performer = PerformerReaderWriter
	repo.Scene = SceneReaderWriter
	repo.SceneMarker = SceneMarkerReaderWriter
	repo.ScrapedItem = ScrapedItemReaderWriter
	repo.Studio = StudioReaderWriter
	repo.Tag = TagReaderWriter
	repo.SavedFilter = SavedFilterReaderWriter

	return repo
}