diff options
| author | Chris Lu <chris.lu@uber.com> | 2019-03-30 23:08:29 -0700 |
|---|---|---|
| committer | Chris Lu <chris.lu@uber.com> | 2019-03-30 23:08:29 -0700 |
| commit | 97406333a5ecc5b0d2cdaa74ff9901e3100e4bf2 (patch) | |
| tree | 04cb10ddb0fb87663ba1783a7e82397aa2c9c06f /weed/filer2/abstract_sql | |
| parent | 920b4e56aa76fbf37780363d5b345c2882d311b5 (diff) | |
| download | seaweedfs-97406333a5ecc5b0d2cdaa74ff9901e3100e4bf2.tar.xz seaweedfs-97406333a5ecc5b0d2cdaa74ff9901e3100e4bf2.zip | |
support atomic renaming for mysql/postgres filer store
Diffstat (limited to 'weed/filer2/abstract_sql')
| -rw-r--r-- | weed/filer2/abstract_sql/abstract_sql_store.go | 44 |
1 files changed, 39 insertions, 5 deletions
diff --git a/weed/filer2/abstract_sql/abstract_sql_store.go b/weed/filer2/abstract_sql/abstract_sql_store.go index 95ce9cb9f..9a3ee51c3 100644 --- a/weed/filer2/abstract_sql/abstract_sql_store.go +++ b/weed/filer2/abstract_sql/abstract_sql_store.go @@ -19,6 +19,40 @@ type AbstractSqlStore struct { SqlListInclusive string } +type TxOrDB interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) +} + +func (store *AbstractSqlStore) BeginTransaction(ctx context.Context) (context.Context, error) { + tx, err := store.DB.BeginTx(ctx, nil) + if err != nil { + return ctx, err + } + + return context.WithValue(ctx, "tx", tx), nil +} +func (store *AbstractSqlStore) CommitTransaction(ctx context.Context) error { + if tx, ok := ctx.Value("tx").(*sql.Tx); ok { + return tx.Commit() + } + return nil +} +func (store *AbstractSqlStore) RollbackTransaction(ctx context.Context) error { + if tx, ok := ctx.Value("tx").(*sql.Tx); ok { + return tx.Rollback() + } + return nil +} + +func (store *AbstractSqlStore) getTxOrDB(ctx context.Context) TxOrDB { + if tx, ok := ctx.Value("tx").(*sql.Tx); ok { + return tx + } + return store.DB +} + func (store *AbstractSqlStore) InsertEntry(ctx context.Context, entry *filer2.Entry) (err error) { dir, name := entry.FullPath.DirAndName() @@ -27,7 +61,7 @@ func (store *AbstractSqlStore) InsertEntry(ctx context.Context, entry *filer2.En return fmt.Errorf("encode %s: %s", entry.FullPath, err) } - res, err := store.DB.Exec(store.SqlInsert, hashToLong(dir), name, dir, meta) + res, err := store.getTxOrDB(ctx).ExecContext(ctx, store.SqlInsert, hashToLong(dir), name, dir, meta) if err != nil { return fmt.Errorf("insert %s: %s", entry.FullPath, err) } @@ -47,7 +81,7 @@ func (store *AbstractSqlStore) UpdateEntry(ctx context.Context, entry *filer2.En return fmt.Errorf("encode %s: %s", entry.FullPath, err) } - res, err := store.DB.Exec(store.SqlUpdate, meta, hashToLong(dir), name, dir) + res, err := store.getTxOrDB(ctx).ExecContext(ctx, store.SqlUpdate, meta, hashToLong(dir), name, dir) if err != nil { return fmt.Errorf("update %s: %s", entry.FullPath, err) } @@ -62,7 +96,7 @@ func (store *AbstractSqlStore) UpdateEntry(ctx context.Context, entry *filer2.En func (store *AbstractSqlStore) FindEntry(ctx context.Context, fullpath filer2.FullPath) (*filer2.Entry, error) { dir, name := fullpath.DirAndName() - row := store.DB.QueryRow(store.SqlFind, hashToLong(dir), name, dir) + row := store.getTxOrDB(ctx).QueryRowContext(ctx, store.SqlFind, hashToLong(dir), name, dir) var data []byte if err := row.Scan(&data); err != nil { return nil, filer2.ErrNotFound @@ -82,7 +116,7 @@ func (store *AbstractSqlStore) DeleteEntry(ctx context.Context, fullpath filer2. dir, name := fullpath.DirAndName() - res, err := store.DB.Exec(store.SqlDelete, hashToLong(dir), name, dir) + res, err := store.getTxOrDB(ctx).ExecContext(ctx, store.SqlDelete, hashToLong(dir), name, dir) if err != nil { return fmt.Errorf("delete %s: %s", fullpath, err) } @@ -102,7 +136,7 @@ func (store *AbstractSqlStore) ListDirectoryEntries(ctx context.Context, fullpat sqlText = store.SqlListInclusive } - rows, err := store.DB.Query(sqlText, hashToLong(string(fullpath)), startFileName, string(fullpath), limit) + rows, err := store.getTxOrDB(ctx).QueryContext(ctx, sqlText, hashToLong(string(fullpath)), startFileName, string(fullpath), limit) if err != nil { return nil, fmt.Errorf("list %s : %v", fullpath, err) } |
