diff --git a/cmd/go_migrate/main.go b/cmd/go_migrate/main.go index 3313d99..e57fe3f 100644 --- a/cmd/go_migrate/main.go +++ b/cmd/go_migrate/main.go @@ -2,18 +2,17 @@ package main import ( "context" - "database/sql" "sync" "time" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config" - "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db" + "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/extractors" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/loaders" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/table_analyzers" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/transformers" - "github.com/jackc/pgx/v5/pgxpool" log "github.com/sirupsen/logrus" + "golang.org/x/sync/errgroup" ) func main() { @@ -33,11 +32,33 @@ func main() { log.Info("=== Starting migration ===") - sourceDb, targetDb, connError := connectToDatabases() - if connError != nil { - log.Fatal("Connection error: ", connError) - } + var wgConnect errgroup.Group + var sourceDb, targetDb dbwrapper.DbWrapper + wgConnect.Go(func() error { + var err error + sourceDb, err = connectWithTimeout(ctx, migrationConfig.SourceDbType, config.App.SourceDbUrl, 20*time.Second) + if err != nil { + return err + } + + return nil + }) + + wgConnect.Go(func() error { + var err error + targetDb, err = connectWithTimeout(ctx, migrationConfig.TargetDbType, config.App.TargetDbUrl, 20*time.Second) + if err != nil { + return err + } + + return nil + }) + + if err := wgConnect.Wait(); err != nil { + log.Error("Connection error: ", err) + return + } defer sourceDb.Close() defer targetDb.Close() @@ -70,8 +91,8 @@ func main() { func processMigrationJobs( ctx context.Context, - sourceDb *sql.DB, - targetDb *pgxpool.Pool, + sourceDb dbwrapper.DbWrapper, + targetDb dbwrapper.DbWrapper, jobs []config.Job, maxParallelWorkers int, ) []JobResult { @@ -94,7 +115,6 @@ func processMigrationJobs( chJobs := make(chan config.Job, len(jobs)) var wgJobs sync.WaitGroup - targetDbWrapper := db.NewPostgresDbWrapper(targetDb) sourceTableAnalyzer := table_analyzers.NewMssqlTableAnalyzer(sourceDb) targetTableAnalyzer := table_analyzers.NewPostgresTableAnalyzer(targetDb) extractor := extractors.NewMssqlExtractor(sourceDb) @@ -107,7 +127,7 @@ func processMigrationJobs( log.Infof("[worker %d] >>> Processing job: %s.%s <<<", i, job.SourceTable.Schema, job.SourceTable.Table) res := processMigrationJob( ctx, - targetDbWrapper, + targetDb, sourceTableAnalyzer, targetTableAnalyzer, extractor, @@ -138,3 +158,19 @@ func processMigrationJobs( return finalResults } + +func connectWithTimeout(ctx context.Context, dbType string, dbUrl string, timeout time.Duration) (dbwrapper.DbWrapper, error) { + localCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + sourceDb, err := dbwrapper.New(dbType) + if err != nil { + return nil, err + } + + if err = sourceDb.Connect(localCtx, dbUrl); err != nil { + return nil, err + } + + return sourceDb, nil +} diff --git a/cmd/go_migrate/process.go b/cmd/go_migrate/process.go index 0886ca2..cfc678b 100644 --- a/cmd/go_migrate/process.go +++ b/cmd/go_migrate/process.go @@ -8,7 +8,7 @@ import ( "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors" - "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db" + dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl/table_analyzers" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models" @@ -18,8 +18,7 @@ import ( func processMigrationJob( ctx context.Context, - // sourceDbWrapper db.DbWrapper, - targetDbWrapper db.DbWrapper, + targetDbWrapper dbwrapper.DbWrapper, sourceTableAnalyzer etl.TableAnalyzer, targetTableAnalyzer etl.TableAnalyzer, extractor etl.Extractor, diff --git a/internal/app/db-wrapper/mssql.go b/internal/app/db-wrapper/mssql.go index 2716d93..6de9611 100644 --- a/internal/app/db-wrapper/mssql.go +++ b/internal/app/db-wrapper/mssql.go @@ -15,15 +15,23 @@ func init() { } type mssqlRowResult struct { + row *sql.Row +} + +func (mr *mssqlRowResult) Scan(dest ...any) error { + return mr.row.Scan(dest...) +} + +type mssqlRowsResult struct { columns []string rows *sql.Rows } -func (mr *mssqlRowResult) Close() error { +func (mr *mssqlRowsResult) Close() error { return mr.rows.Close() } -func (mr *mssqlRowResult) Columns() ([]string, error) { +func (mr *mssqlRowsResult) Columns() ([]string, error) { if mr.columns != nil { return mr.columns, nil } @@ -31,19 +39,19 @@ func (mr *mssqlRowResult) Columns() ([]string, error) { return mr.rows.Columns() } -func (mr *mssqlRowResult) Err() error { +func (mr *mssqlRowsResult) Err() error { return mr.rows.Err() } -func (mr *mssqlRowResult) Next() bool { +func (mr *mssqlRowsResult) Next() bool { return mr.rows.Next() } -func (mr *mssqlRowResult) Scan(dest ...any) error { +func (mr *mssqlRowsResult) Scan(dest ...any) error { return mr.rows.Scan(dest...) } -func (mr *mssqlRowResult) Values() ([]any, error) { +func (mr *mssqlRowsResult) Values() ([]any, error) { columns, err := mr.Columns() if err != nil { return nil, err @@ -112,7 +120,12 @@ func (mw *mssqlDbWrapper) Query(ctx context.Context, query string, args ...any) return nil, err } - return &mssqlRowResult{columns: nil, rows: rows}, nil + return &mssqlRowsResult{columns: nil, rows: rows}, nil +} + +func (mw *mssqlDbWrapper) QueryRow(ctx context.Context, query string, args ...any) RowResult { + row := mw.db.QueryRowContext(ctx, query, args...) + return &mssqlRowResult{row: row} } func (mw *mssqlDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) { diff --git a/internal/app/db-wrapper/postgres.go b/internal/app/db-wrapper/postgres.go index 6f2e220..e6a1f37 100644 --- a/internal/app/db-wrapper/postgres.go +++ b/internal/app/db-wrapper/postgres.go @@ -15,16 +15,24 @@ func init() { } type postgresRowResult struct { + row pgx.Row +} + +func (pr *postgresRowResult) Scan(dest ...any) error { + return pr.row.Scan(dest...) +} + +type postgresRowsResult struct { columns []string rows pgx.Rows } -func (pr *postgresRowResult) Close() error { +func (pr *postgresRowsResult) Close() error { pr.rows.Close() return nil } -func (pr *postgresRowResult) Columns() ([]string, error) { +func (pr *postgresRowsResult) Columns() ([]string, error) { if pr.columns != nil { return pr.columns, nil } @@ -42,19 +50,19 @@ func (pr *postgresRowResult) Columns() ([]string, error) { return columns, nil } -func (pr *postgresRowResult) Err() error { +func (pr *postgresRowsResult) Err() error { return pr.rows.Err() } -func (pr *postgresRowResult) Next() bool { +func (pr *postgresRowsResult) Next() bool { return pr.rows.Next() } -func (pr *postgresRowResult) Scan(dest ...any) error { +func (pr *postgresRowsResult) Scan(dest ...any) error { return pr.rows.Scan(dest...) } -func (pr *postgresRowResult) Values() ([]any, error) { +func (pr *postgresRowsResult) Values() ([]any, error) { return pr.rows.Values() } @@ -102,7 +110,12 @@ func (pw *postgresDbWrapper) Query(ctx context.Context, query string, args ...an return nil, err } - return &postgresRowResult{columns: nil, rows: rows}, nil + return &postgresRowsResult{columns: nil, rows: rows}, nil +} + +func (pw *postgresDbWrapper) QueryRow(ctx context.Context, query string, args ...any) RowResult { + row := pw.db.QueryRow(ctx, query, args...) + return &postgresRowResult{row: row} } func (pw *postgresDbWrapper) SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) { diff --git a/internal/app/db-wrapper/types.go b/internal/app/db-wrapper/types.go index 421294e..8bd0145 100644 --- a/internal/app/db-wrapper/types.go +++ b/internal/app/db-wrapper/types.go @@ -20,11 +20,16 @@ type RowsResult interface { Values() ([]any, error) } +type RowResult interface { + Scan(dest ...any) error +} + type DbWrapper interface { Close() error Connect(ctx context.Context, dbUrl string) error Exec(ctx context.Context, query string, args ...any) (ExecResult, error) GetDialect() string Query(ctx context.Context, query string, args ...any) (RowsResult, error) + QueryRow(ctx context.Context, query string, args ...any) RowResult SaveMassive(ctx context.Context, schema string, table string, columnNames []string, rows [][]any) (int64, error) } diff --git a/internal/app/db/mssql.go b/internal/app/db/mssql.go deleted file mode 100644 index 2314c65..0000000 --- a/internal/app/db/mssql.go +++ /dev/null @@ -1,30 +0,0 @@ -package db - -import ( - "context" - "database/sql" -) - -type MssqlDbWrapper struct { - db *sql.DB -} - -func NewMssqlDbWrapper(db *sql.DB) DbWrapper { - return &MssqlDbWrapper{db: db} -} - -func (wrapper *MssqlDbWrapper) Exec(ctx context.Context, query string, args ...any) (DbWrapperResult, error) { - result, execErr := wrapper.db.ExecContext(ctx, query, args...) - if execErr != nil { - return DbWrapperResult{}, execErr - } - - affectedRows, err := result.RowsAffected() - if err != nil { - return DbWrapperResult{}, err - } - - return DbWrapperResult{ - AffectedRows: affectedRows, - }, nil -} diff --git a/internal/app/db/postgres.go b/internal/app/db/postgres.go deleted file mode 100644 index 81f3c58..0000000 --- a/internal/app/db/postgres.go +++ /dev/null @@ -1,47 +0,0 @@ -package db - -import ( - "context" - "fmt" - - "github.com/jackc/pgx/v5/pgxpool" -) - -func Connect(ctx context.Context, dbURL string) (*pgxpool.Pool, error) { - pool, err := pgxpool.New(ctx, dbURL) - if err != nil { - return nil, fmt.Errorf("unable to connect to database: %w", err) - } - - if err := pool.Ping(ctx); err != nil { - pool.Close() - return nil, fmt.Errorf("unable to ping database: %w", err) - } - - return pool, nil -} - -func Close(pool *pgxpool.Pool) { - if pool != nil { - pool.Close() - } -} - -type PostgresDbWrapper struct { - db *pgxpool.Pool -} - -func NewPostgresDbWrapper(db *pgxpool.Pool) DbWrapper { - return &PostgresDbWrapper{db: db} -} - -func (wrapper *PostgresDbWrapper) Exec(ctx context.Context, query string, args ...any) (DbWrapperResult, error) { - result, err := wrapper.db.Exec(ctx, query, args...) - if err != nil { - return DbWrapperResult{}, err - } - - return DbWrapperResult{ - AffectedRows: result.RowsAffected(), - }, nil -} diff --git a/internal/app/db/types.go b/internal/app/db/types.go deleted file mode 100644 index ea72117..0000000 --- a/internal/app/db/types.go +++ /dev/null @@ -1,11 +0,0 @@ -package db - -import "context" - -type DbWrapperResult struct { - AffectedRows int64 -} - -type DbWrapper interface { - Exec(ctx context.Context, query string, args ...any) (DbWrapperResult, error) -} diff --git a/internal/app/etl/extractors/mssql.go b/internal/app/etl/extractors/mssql.go index a62ba5c..d34447a 100644 --- a/internal/app/etl/extractors/mssql.go +++ b/internal/app/etl/extractors/mssql.go @@ -13,16 +13,17 @@ import ( "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/convert" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors" + dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models" "github.com/google/uuid" ) type MssqlExtractor struct { - db *sql.DB + db dbwrapper.DbWrapper } -func NewMssqlExtractor(db *sql.DB) etl.Extractor { +func NewMssqlExtractor(db dbwrapper.DbWrapper) etl.Extractor { return &MssqlExtractor{db: db} } @@ -118,7 +119,7 @@ func (mssqlEx *MssqlExtractor) ProcessPartition( } rowsRead := 0 - rows, err := mssqlEx.db.QueryContext(ctx, query, queryArgs...) + rows, err := mssqlEx.db.Query(ctx, query, queryArgs...) if err != nil { return rowsRead, &custom_errors.ExtractorError{Partition: partition, HasLastId: false, Msg: err.Error()} } diff --git a/internal/app/etl/extractors/postgres.go b/internal/app/etl/extractors/postgres.go index ad46d2e..8d494d3 100644 --- a/internal/app/etl/extractors/postgres.go +++ b/internal/app/etl/extractors/postgres.go @@ -9,18 +9,18 @@ import ( "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors" + dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models" "github.com/google/uuid" - "github.com/jackc/pgx/v5/pgxpool" ) type PostgresExtractor struct { - db *pgxpool.Pool + db dbwrapper.DbWrapper } -func NewPostgresExtractor(pool *pgxpool.Pool) etl.Extractor { - return &PostgresExtractor{db: pool} +func NewPostgresExtractor(db dbwrapper.DbWrapper) etl.Extractor { + return &PostgresExtractor{db: db} } func buildExtractQueryPostgres(sourceDbInfo config.SourceTableInfo, columns []models.ColumnType) string { diff --git a/internal/app/etl/extractors/types.go b/internal/app/etl/extractors/types.go deleted file mode 100644 index 85defa0..0000000 --- a/internal/app/etl/extractors/types.go +++ /dev/null @@ -1 +0,0 @@ -package extractors diff --git a/internal/app/etl/loaders/postgres.go b/internal/app/etl/loaders/postgres.go index 4560a7c..f9ed76f 100644 --- a/internal/app/etl/loaders/postgres.go +++ b/internal/app/etl/loaders/postgres.go @@ -9,19 +9,18 @@ import ( "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/custom_errors" + dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models" - "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgxpool" ) type PostgresLoader struct { - db *pgxpool.Pool + db dbwrapper.DbWrapper } -func NewPostgresLoader(pool *pgxpool.Pool) etl.Loader { - return &PostgresLoader{db: pool} +func NewPostgresLoader(db dbwrapper.DbWrapper) etl.Loader { + return &PostgresLoader{db: db} } func mapSlice[T any, V any](input []T, mapper func(T) V) []V { @@ -40,12 +39,12 @@ func (postgresLd *PostgresLoader) ProcessBatch( colNames []string, batch models.Batch, ) (int, error) { - tableId := pgx.Identifier{tableInfo.Schema, tableInfo.Table} - _, err := postgresLd.db.CopyFrom( + _, err := postgresLd.db.SaveMassive( ctx, - tableId, + tableInfo.Schema, + tableInfo.Table, colNames, - pgx.CopyFromRows(batch.Rows), + batch.Rows, ) if err != nil { @@ -54,7 +53,7 @@ func (postgresLd *PostgresLoader) ProcessBatch( if pgErr.Code == "23505" { return 0, &custom_errors.JobError{ ShouldCancelJob: true, - Msg: fmt.Sprintf("Fatal error in table %s", tableId.Sanitize()), + Msg: fmt.Sprintf("Fatal error in table %s.%s", tableInfo.Schema, tableInfo.Table), Prev: err, } } diff --git a/internal/app/etl/table_analyzers/mssql.go b/internal/app/etl/table_analyzers/mssql.go index 2f46aca..0ec0f12 100644 --- a/internal/app/etl/table_analyzers/mssql.go +++ b/internal/app/etl/table_analyzers/mssql.go @@ -8,16 +8,17 @@ import ( "time" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config" + dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models" "github.com/google/uuid" ) type MssqlTableAnalyzer struct { - db *sql.DB + db dbwrapper.DbWrapper } -func NewMssqlTableAnalyzer(db *sql.DB) etl.TableAnalyzer { +func NewMssqlTableAnalyzer(db dbwrapper.DbWrapper) etl.TableAnalyzer { return &MssqlTableAnalyzer{db: db} } @@ -142,7 +143,7 @@ func (ta *MssqlTableAnalyzer) QueryColumnTypes( localCtx, cancel := context.WithTimeout(ctx, 20*time.Second) defer cancel() - rows, err := ta.db.QueryContext(localCtx, mssqlColumnMetadataQuery, sql.Named("schema", tableInfo.Schema), sql.Named("table", tableInfo.Table)) + rows, err := ta.db.Query(localCtx, mssqlColumnMetadataQuery, sql.Named("schema", tableInfo.Schema), sql.Named("table", tableInfo.Table)) if err != nil { return nil, err } @@ -187,7 +188,7 @@ GROUP BY t.name` defer cancel() var rowsCount int64 - err := ta.db.QueryRowContext(ctxTimeout, query, sql.Named("schema", tableInfo.Schema), sql.Named("table", tableInfo.Table)).Scan(&rowsCount) + err := ta.db.QueryRow(ctxTimeout, query, sql.Named("schema", tableInfo.Schema), sql.Named("table", tableInfo.Table)).Scan(&rowsCount) if err != nil { return 0, err } @@ -218,7 +219,7 @@ ORDER BY batch_id`, ctxTimeout, cancel := context.WithTimeout(ctx, time.Second*20) defer cancel() - rows, err := ta.db.QueryContext(ctxTimeout, query, sql.Named("maxPartitions", maxPartitions)) + rows, err := ta.db.Query(ctxTimeout, query, sql.Named("maxPartitions", maxPartitions)) if err != nil { return nil, err } diff --git a/internal/app/etl/table_analyzers/postgres.go b/internal/app/etl/table_analyzers/postgres.go index b07eb5e..9489c78 100644 --- a/internal/app/etl/table_analyzers/postgres.go +++ b/internal/app/etl/table_analyzers/postgres.go @@ -6,16 +6,16 @@ import ( "time" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config" + dbwrapper "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db-wrapper" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/etl" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/models" - "github.com/jackc/pgx/v5/pgxpool" ) type PostgresTableAnalyzer struct { - db *pgxpool.Pool + db dbwrapper.DbWrapper } -func NewPostgresTableAnalyzer(db *pgxpool.Pool) etl.TableAnalyzer { +func NewPostgresTableAnalyzer(db dbwrapper.DbWrapper) etl.TableAnalyzer { return &PostgresTableAnalyzer{db: db} } diff --git a/scripts/pg-info-test/main.go b/scripts/pg-info-test/main.go index c69c726..0e34a02 100644 --- a/scripts/pg-info-test/main.go +++ b/scripts/pg-info-test/main.go @@ -8,13 +8,32 @@ import ( "time" "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/config" - "git.ksdemosapps.com/kylesoda/go-migrate/internal/app/db" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" log "github.com/sirupsen/logrus" ) +func Connect(ctx context.Context, dbURL string) (*pgxpool.Pool, error) { + pool, err := pgxpool.New(ctx, dbURL) + if err != nil { + return nil, fmt.Errorf("unable to connect to database: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("unable to ping database: %w", err) + } + + return pool, nil +} + +func Close(pool *pgxpool.Pool) { + if pool != nil { + pool.Close() + } +} + func main() { log.SetFormatter(&log.TextFormatter{ FullTimestamp: true, @@ -27,8 +46,8 @@ func main() { ctxSource, cancel := context.WithTimeout(context.Background(), 20*time.Second) defer cancel() - sourcePool, err := db.Connect(ctxSource, config.App.SourceDbUrl) - defer db.Close(sourcePool) + sourcePool, err := Connect(ctxSource, config.App.SourceDbUrl) + defer Close(sourcePool) if err != nil { log.Fatal(err) } @@ -37,8 +56,8 @@ func main() { ctxTarget, cancel := context.WithTimeout(context.Background(), 20*time.Second) defer cancel() - targetPool, err := db.Connect(ctxTarget, config.App.TargetDbUrl) - defer db.Close(targetPool) + targetPool, err := Connect(ctxTarget, config.App.TargetDbUrl) + defer Close(targetPool) if err != nil { log.Fatal(err) }