From 14a1b1c234c562c8dbd7a1e179779e2c55bb439f Mon Sep 17 00:00:00 2001 From: Elom Gomez Date: Fri, 26 Jun 2026 14:34:44 -0500 Subject: [PATCH 1/6] Add pscale import d1 for Cloudflare D1 offline migration. Introduces the D1 import command group with lint, plan, pgloader-based load, and verify, plus slim postgres connection helpers shared by the import path. Co-authored-by: Cursor --- go.mod | 5 +- go.sum | 10 +- internal/cmd/importcmd/d1.go | 437 +++++++++++++ internal/cmd/importcmd/import.go | 22 + internal/cmd/mcp/import_d1_handlers.go | 249 ++++++++ internal/cmd/mcp/server.go | 5 +- internal/cmd/root.go | 22 +- internal/cmdutil/errors.go | 5 +- internal/migrate/d1/constraints.go | 279 +++++++++ internal/migrate/d1/constraints_test.go | 35 ++ internal/migrate/d1/convert.go | 229 +++++++ internal/migrate/d1/doctor.go | 235 +++++++ internal/migrate/d1/doctor_test.go | 70 +++ internal/migrate/d1/errors.go | 86 +++ internal/migrate/d1/errors_test.go | 14 + internal/migrate/d1/export.go | 76 +++ internal/migrate/d1/import.go | 387 ++++++++++++ internal/migrate/d1/import_test.go | 14 + internal/migrate/d1/lint.go | 139 +++++ internal/migrate/d1/lint_test.go | 206 +++++++ internal/migrate/d1/orm_metadata.go | 157 +++++ internal/migrate/d1/orm_metadata_test.go | 62 ++ internal/migrate/d1/output.go | 144 +++++ internal/migrate/d1/parse.go | 354 +++++++++++ internal/migrate/d1/path.go | 31 + internal/migrate/d1/pgloader.go | 296 +++++++++ internal/migrate/d1/pgloader_errors.go | 26 + internal/migrate/d1/pgloader_errors_test.go | 46 ++ internal/migrate/d1/pgloader_test.go | 137 +++++ internal/migrate/d1/pgloader_transforms.lisp | 16 + internal/migrate/d1/plan.go | 236 +++++++ internal/migrate/d1/postgres.go | 12 + internal/migrate/d1/prepare.go | 203 ++++++ internal/migrate/d1/prepare_test.go | 85 +++ internal/migrate/d1/schema_reset.go | 153 +++++ internal/migrate/d1/schema_reset_test.go | 63 ++ internal/migrate/d1/sqlite_load.go | 186 ++++++ internal/migrate/d1/sqlite_load_test.go | 135 ++++ internal/migrate/d1/state.go | 161 +++++ internal/migrate/d1/state_test.go | 107 ++++ .../migrate/d1/testdata/sample_d1_export.sql | 71 +++ internal/migrate/d1/types.go | 189 ++++++ internal/migrate/d1/verify.go | 191 ++++++ internal/migrate/d1/verify_checks.go | 578 ++++++++++++++++++ internal/migrate/d1/verify_checks_test.go | 111 ++++ internal/postgres/postgres.go | 231 +++++++ internal/postgres/postgres_test.go | 210 +++++++ internal/postgres/psql.go | 67 ++ script/d1-import-test/README.md | 160 +++++ script/d1-import-test/bench-watch-imports.sh | 200 ++++++ script/d1-import-test/build-local-export.sh | 104 ++++ .../collect-benchmark-results.sh | 79 +++ script/d1-import-test/generate_seed.py | 448 ++++++++++++++ .../launch-benchmark-detached.sh | 93 +++ .../launch-storage-benchmark-detached.sh | 45 ++ script/d1-import-test/load-bulk.sh | 59 ++ script/d1-import-test/load.sh | 64 ++ script/d1-import-test/merge_seed_chunks.py | 173 ++++++ script/d1-import-test/prepare-demo-100mb.sh | 66 ++ script/d1-import-test/provision-database.sh | 66 ++ script/d1-import-test/reset.sql | 39 ++ script/d1-import-test/resume-chunks.sh | 65 ++ script/d1-import-test/run-9gb-benchmark.sh | 50 ++ script/d1-import-test/run-cli-import.sh | 233 +++++++ script/d1-import-test/run-local-import.sh | 187 ++++++ script/d1-import-test/run-size-benchmark.sh | 107 ++++ .../d1-import-test/run-storage-benchmark.sh | 118 ++++ script/d1-import-test/schema.sql | 362 +++++++++++ script/d1-import-test/time-export.sh | 39 ++ test_import_d1.sh | 79 +++ 70 files changed, 9607 insertions(+), 12 deletions(-) create mode 100644 internal/cmd/importcmd/d1.go create mode 100644 internal/cmd/importcmd/import.go create mode 100644 internal/cmd/mcp/import_d1_handlers.go create mode 100644 internal/migrate/d1/constraints.go create mode 100644 internal/migrate/d1/constraints_test.go create mode 100644 internal/migrate/d1/convert.go create mode 100644 internal/migrate/d1/doctor.go create mode 100644 internal/migrate/d1/doctor_test.go create mode 100644 internal/migrate/d1/errors.go create mode 100644 internal/migrate/d1/errors_test.go create mode 100644 internal/migrate/d1/export.go create mode 100644 internal/migrate/d1/import.go create mode 100644 internal/migrate/d1/import_test.go create mode 100644 internal/migrate/d1/lint.go create mode 100644 internal/migrate/d1/lint_test.go create mode 100644 internal/migrate/d1/orm_metadata.go create mode 100644 internal/migrate/d1/orm_metadata_test.go create mode 100644 internal/migrate/d1/output.go create mode 100644 internal/migrate/d1/parse.go create mode 100644 internal/migrate/d1/path.go create mode 100644 internal/migrate/d1/pgloader.go create mode 100644 internal/migrate/d1/pgloader_errors.go create mode 100644 internal/migrate/d1/pgloader_errors_test.go create mode 100644 internal/migrate/d1/pgloader_test.go create mode 100644 internal/migrate/d1/pgloader_transforms.lisp create mode 100644 internal/migrate/d1/plan.go create mode 100644 internal/migrate/d1/postgres.go create mode 100644 internal/migrate/d1/prepare.go create mode 100644 internal/migrate/d1/prepare_test.go create mode 100644 internal/migrate/d1/schema_reset.go create mode 100644 internal/migrate/d1/schema_reset_test.go create mode 100644 internal/migrate/d1/sqlite_load.go create mode 100644 internal/migrate/d1/sqlite_load_test.go create mode 100644 internal/migrate/d1/state.go create mode 100644 internal/migrate/d1/state_test.go create mode 100644 internal/migrate/d1/testdata/sample_d1_export.sql create mode 100644 internal/migrate/d1/types.go create mode 100644 internal/migrate/d1/verify.go create mode 100644 internal/migrate/d1/verify_checks.go create mode 100644 internal/migrate/d1/verify_checks_test.go create mode 100644 internal/postgres/postgres.go create mode 100644 internal/postgres/postgres_test.go create mode 100644 internal/postgres/psql.go create mode 100644 script/d1-import-test/README.md create mode 100755 script/d1-import-test/bench-watch-imports.sh create mode 100755 script/d1-import-test/build-local-export.sh create mode 100755 script/d1-import-test/collect-benchmark-results.sh create mode 100755 script/d1-import-test/generate_seed.py create mode 100755 script/d1-import-test/launch-benchmark-detached.sh create mode 100755 script/d1-import-test/launch-storage-benchmark-detached.sh create mode 100755 script/d1-import-test/load-bulk.sh create mode 100755 script/d1-import-test/load.sh create mode 100755 script/d1-import-test/merge_seed_chunks.py create mode 100755 script/d1-import-test/prepare-demo-100mb.sh create mode 100755 script/d1-import-test/provision-database.sh create mode 100644 script/d1-import-test/reset.sql create mode 100755 script/d1-import-test/resume-chunks.sh create mode 100755 script/d1-import-test/run-9gb-benchmark.sh create mode 100755 script/d1-import-test/run-cli-import.sh create mode 100755 script/d1-import-test/run-local-import.sh create mode 100755 script/d1-import-test/run-size-benchmark.sh create mode 100755 script/d1-import-test/run-storage-benchmark.sh create mode 100644 script/d1-import-test/schema.sql create mode 100755 script/d1-import-test/time-export.sh create mode 100755 test_import_d1.sh diff --git a/go.mod b/go.mod index 9b6bcbf8a..93d39667a 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/google/go-cmp v0.7.0 github.com/hashicorp/go-cleanhttp v0.5.2 github.com/hashicorp/go-version v1.8.0 + github.com/jackc/pgx/v5 v5.8.0 github.com/lensesio/tableprinter v0.0.0-20201125135848-89e81fc956e7 github.com/lib/pq v1.12.0 github.com/mark3labs/mcp-go v0.46.0 @@ -78,6 +79,9 @@ require ( github.com/google/uuid v1.6.0 // indirect github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/kataras/tablewriter v0.0.0-20180708051242-e063d29b7c23 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/klauspost/compress v1.18.2 // indirect @@ -115,7 +119,6 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect google.golang.org/grpc v1.79.3 // indirect google.golang.org/protobuf v1.36.10 // indirect - gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 144c8c65d..dda906e80 100644 --- a/go.sum +++ b/go.sum @@ -112,6 +112,14 @@ github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= +github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/kataras/tablewriter v0.0.0-20180708051242-e063d29b7c23 h1:M8exrBzuhWcU6aoHJlHWPe4qFjVKzkMGRal78f5jRRU= github.com/kataras/tablewriter v0.0.0-20180708051242-e063d29b7c23/go.mod h1:kBSna6b0/RzsOcOZf515vAXwSsXYusl2U7SA0XP09yI= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= @@ -120,7 +128,6 @@ github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uq github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/connect-compress/v2 v2.1.0 h1:8fM8QrVeHT69e5VVSh4yjDaQASYIvOp2uMZq7nVLj2U= github.com/klauspost/connect-compress/v2 v2.1.0/go.mod h1:Ayurh2wscMMx3AwdGGVL+ylSR5316WfApREDgsqHyH8= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -218,6 +225,7 @@ github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjb github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= diff --git a/internal/cmd/importcmd/d1.go b/internal/cmd/importcmd/d1.go new file mode 100644 index 000000000..622638567 --- /dev/null +++ b/internal/cmd/importcmd/d1.go @@ -0,0 +1,437 @@ +package importcmd + +import ( + "fmt" + + "github.com/spf13/cobra" + + "github.com/planetscale/cli/internal/cmdutil" + "github.com/planetscale/cli/internal/migrate/d1" + "github.com/planetscale/cli/internal/printer" +) + +func writeD1(ch *cmdutil.Helper, resp d1.Response) error { + if resp.Status == "error" { + switch ch.Printer.Format() { + case printer.JSON: + if err := ch.Printer.PrintJSON(resp); err != nil { + return err + } + return &cmdutil.Error{ + ExitCode: cmdutil.ActionRequestedExitCode, + Printed: true, + } + case printer.Human: + return humanD1Error(resp) + default: + return fmt.Errorf(`import d1 does not support output format %q (use human or json)`, ch.Printer.Format()) + } + } + + switch ch.Printer.Format() { + case printer.JSON: + return ch.Printer.PrintJSON(resp) + case printer.Human: + d1.PrintHumanResponse(ch.Printer, resp) + return nil + default: + return fmt.Errorf(`import d1 does not support output format %q (use human or json)`, ch.Printer.Format()) + } +} + +func humanD1Error(resp d1.Response) error { + if resp.Error == nil { + return fmt.Errorf("import d1 command failed") + } + if resp.Error.Remediation != "" { + return fmt.Errorf("%s\n%s", resp.Error.Message, resp.Error.Remediation) + } + return fmt.Errorf("%s", resp.Error.Message) +} + +// D1Cmd returns the import d1 subcommand group. +func D1Cmd(ch *cmdutil.Helper) *cobra.Command { + cmd := &cobra.Command{ + Use: "d1 ", + Short: "Import Cloudflare D1 into PlanetScale Postgres", + Long: `Offline import from Cloudflare D1 (SQLite) to PlanetScale Postgres. + +Export with wrangler, lint the dump, then start the import (use --dry-run to preview). +All commands support --format json for machine-readable output.`, + } + + cmd.AddCommand(d1DoctorCmd(ch)) + cmd.AddCommand(d1ExportCmd(ch)) + cmd.AddCommand(d1LintCmd(ch)) + cmd.AddCommand(d1ConvertSchemaCmd(ch)) + cmd.AddCommand(d1StartCmd(ch)) + cmd.AddCommand(d1VerifyCmd(ch)) + cmd.AddCommand(d1StatusCmd(ch)) + cmd.AddCommand(d1CompleteCmd(ch)) + + return cmd +} + +func d1DoctorCmd(ch *cmdutil.Helper) *cobra.Command { + cmd := &cobra.Command{ + Use: "doctor", + Short: "Check prerequisites for D1 migration", + RunE: func(cmd *cobra.Command, args []string) error { + result, err := d1.Doctor(cmd.Context()) + if err != nil { + return writeD1(ch, d1.ErrorResponse("doctor", err)) + } + if !result.Ready { + return writeD1(ch, d1.ErrorResponse("doctor", d1.DoctorReadinessError(result))) + } + return writeD1(ch, d1.OKResponse("doctor", result, d1.DoctorNextSteps(result))) + }, + } + + return cmd +} + +func d1ExportCmd(ch *cmdutil.Helper) *cobra.Command { + var flags struct { + d1Database string + output string + remote bool + table string + noData bool + } + + cmd := &cobra.Command{ + Use: "export", + Short: "Export a D1 database using wrangler", + Example: ` pscale import d1 export --d1-database my-app-db --remote --output ./d1-export.sql --format json`, + RunE: func(cmd *cobra.Command, args []string) error { + result, err := d1.Export(cmd.Context(), d1.ExportOptions{ + D1Database: flags.d1Database, + Output: flags.output, + Remote: flags.remote, + Table: flags.table, + NoData: flags.noData, + }) + if err != nil { + return writeD1(ch, d1.ErrorResponse("export", err)) + } + resp := d1.OKResponse("export", result, []d1.NextStep{ + {Tool: "import_d1_lint", Command: "pscale import d1 lint --input " + result.OutputPath, Reason: "Analyze export before import"}, + }) + return writeD1(ch, resp) + }, + } + + cmd.Flags().StringVar(&flags.d1Database, "d1-database", "", "Cloudflare D1 database name") + cmd.Flags().StringVar(&flags.output, "output", "", "Output SQL file path") + cmd.Flags().BoolVar(&flags.remote, "remote", false, "Export from remote D1 (not local dev)") + cmd.Flags().StringVar(&flags.table, "table", "", "Export a single table") + cmd.Flags().BoolVar(&flags.noData, "no-data", false, "Schema only export") + cmd.MarkFlagRequired("d1-database") + return cmd +} + +func d1LintCmd(ch *cmdutil.Helper) *cobra.Command { + var flags struct { + input string + } + + cmd := &cobra.Command{ + Use: "lint", + Short: "Analyze a D1 SQL export for migration issues", + Example: ` pscale import d1 lint --input ./d1-export.sql --format json`, + RunE: func(cmd *cobra.Command, args []string) error { + result, err := d1.Lint(flags.input) + if err != nil { + return writeD1(ch, d1.ErrorResponse("lint", err)) + } + resp := d1.OKResponse("lint", result, d1.LintNextSteps(result)) + resp.Issues = result.Issues + if result.ErrorCount > 0 { + resp.Status = "error" + } else if result.WarningCount > 0 { + resp.Status = "warning" + } + return writeD1(ch, resp) + }, + } + + cmd.Flags().StringVar(&flags.input, "input", "", "Path to D1 SQL export") + cmd.MarkFlagRequired("input") + return cmd +} + +func d1ConvertSchemaCmd(ch *cmdutil.Helper) *cobra.Command { + var flags struct { + input string + output string + } + + cmd := &cobra.Command{ + Use: "convert-schema", + Short: "Convert SQLite schema in a D1 export to PostgreSQL DDL", + RunE: func(cmd *cobra.Command, args []string) error { + if flags.output == "" { + flags.output = flags.input + ".postgres.sql" + } + count, err := d1.ConvertSchema(flags.input, flags.output) + if err != nil { + return writeD1(ch, d1.ErrorResponse("convert-schema", err)) + } + resp := d1.OKResponse("convert-schema", map[string]any{ + "input": flags.input, + "output": flags.output, + "table_count": count, + }, nil) + return writeD1(ch, resp) + }, + } + + cmd.Flags().StringVar(&flags.input, "input", "", "Path to D1 SQL export") + cmd.Flags().StringVar(&flags.output, "output", "", "Output PostgreSQL schema file") + cmd.MarkFlagRequired("input") + return cmd +} + +func d1StartCmd(ch *cmdutil.Helper) *cobra.Command { + var flags struct { + org string + database string + branch string + input string + method string + migrationID string + dbName string + dryRun bool + force bool + } + + cmd := &cobra.Command{ + Use: "start", + Short: "Start importing a D1 export (lint + plan, then load)", + Long: `Runs lint and builds an import plan, then loads data into PlanetScale Postgres. +Requires pgloader on PATH — run import d1 doctor to verify prerequisites. + +Use --dry-run to lint and save migration state without touching Postgres.`, + Example: ` # Preview lint + plan and get a migration ID + pscale import d1 start --org acme --database mydb --input ./d1-export.sql --dry-run --force --format json + + # Run the import + pscale import d1 start --org acme --database mydb --input ./d1-export.sql --method pgloader --force --format json`, + RunE: func(cmd *cobra.Command, args []string) error { + org := flags.org + if org == "" { + org = ch.Config.Organization + } + + importOpts := d1.ImportOptions{ + Org: org, + Database: flags.database, + Branch: flags.branch, + InputPath: flags.input, + Method: flags.method, + MigrationID: flags.migrationID, + DBName: flags.dbName, + DryRun: flags.dryRun, + } + + prepared, err := d1.PrepareImport(importOpts) + if err != nil { + return writeD1(ch, d1.ErrorResponse("start", err)) + } + + if !prepared.CanProceed { + return writeD1(ch, d1.BlockedStartResponse(prepared, flags.dryRun)) + } + + if !flags.force && !flags.dryRun && ch.Printer.Format() == printer.Human { + d1.PrintStartPreview(ch.Printer, prepared) + if err := ch.Printer.ConfirmCommand(prepared.MigrationID, "import d1 start", "start"); err != nil { + return err + } + } + + client, err := ch.Client() + if err != nil { + return err + } + + result, err := d1.Import(cmd.Context(), client, &d1.DefaultImportClient{Client: client}, importOpts, prepared) + if err != nil { + resp := d1.ErrorResponse("start", err) + if result != nil { + resp.Data = result + resp.Issues = result.Lint.Issues + } + resp.MigrationID = prepared.MigrationID + return writeD1(ch, resp) + } + resp := d1.OKResponse("start", result, d1.StartNextSteps(result.MigrationID, flags.database, result.Method, flags.dryRun)) + resp.MigrationID = result.MigrationID + resp.Issues = result.Lint.Issues + if flags.dryRun { + resp.Status = "dry_run" + } + return writeD1(ch, resp) + }, + } + + cmd.Flags().StringVar(&flags.org, "org", "", "PlanetScale organization") + cmd.Flags().StringVar(&flags.database, "database", "", "PlanetScale database name") + cmd.Flags().StringVar(&flags.branch, "branch", "main", "PlanetScale branch name") + cmd.Flags().StringVar(&flags.input, "input", "", "Path to D1 SQL export") + cmd.Flags().StringVar(&flags.method, "method", "", "Import method: pgloader (≥1GB) or psql (<1GB; schema via psql, data via pgloader)") + cmd.Flags().StringVar(&flags.migrationID, "migration-id", "", "Existing migration ID from a prior start --dry-run") + cmd.Flags().StringVar(&flags.dbName, "dbname", "postgres", "Destination PostgreSQL database name") + cmd.Flags().BoolVar(&flags.dryRun, "dry-run", false, "Lint and build import plan without loading Postgres") + cmd.Flags().BoolVar(&flags.force, "force", false, "Skip confirmation prompt") + cmd.MarkFlagRequired("database") + cmd.MarkFlagRequired("input") + return cmd +} + +func d1VerifyCmd(ch *cmdutil.Helper) *cobra.Command { + var flags struct { + org string + database string + branch string + migrationID string + input string + sqlite string + } + + cmd := &cobra.Command{ + Use: "verify", + Short: "Verify D1 import (row counts, sequences, coercion, content checks)", + RunE: func(cmd *cobra.Command, args []string) error { + org := flags.org + if org == "" { + org = ch.Config.Organization + } + + verifyOpts := d1.VerifyOptions{ + Org: org, + Database: flags.database, + Branch: flags.branch, + MigrationID: flags.migrationID, + InputPath: flags.input, + SQLitePath: flags.sqlite, + } + + client, err := ch.Client() + if err != nil { + return err + } + destURI, cleanup, err := d1.ResolveDestURI(cmd.Context(), client, d1.ImportOptions{ + Org: org, + Database: flags.database, + Branch: flags.branch, + }) + if err != nil { + return err + } + defer func() { _ = cleanup() }() + verifyOpts.DestURI = destURI + + result, err := d1.Verify(cmd.Context(), verifyOpts) + if err != nil { + resp := d1.ErrorResponse("verify", err) + if result != nil { + resp.Data = result + } + return writeD1(ch, resp) + } + resp := d1.OKResponse("verify", result, nil) + resp.MigrationID = flags.migrationID + return writeD1(ch, resp) + }, + } + + cmd.Flags().StringVar(&flags.org, "org", "", "PlanetScale organization") + cmd.Flags().StringVar(&flags.database, "database", "", "PlanetScale database name") + cmd.Flags().StringVar(&flags.branch, "branch", "main", "PlanetScale branch name") + cmd.Flags().StringVar(&flags.migrationID, "migration-id", "", "Migration ID from plan/import") + cmd.Flags().StringVar(&flags.input, "input", "", "Path to original D1 SQL export") + cmd.Flags().StringVar(&flags.sqlite, "sqlite", "", "Path to local SQLite file for source counts") + cmd.MarkFlagRequired("database") + return cmd +} + +func d1StatusCmd(ch *cmdutil.Helper) *cobra.Command { + var flags struct { + org string + database string + branch string + migrationID string + } + + cmd := &cobra.Command{ + Use: "status", + Short: "Show local migration state", + RunE: func(cmd *cobra.Command, args []string) error { + org := flags.org + if org == "" { + org = ch.Config.Organization + } + state, err := d1.Status(org, flags.database, flags.branch, flags.migrationID) + if err != nil { + return writeD1(ch, d1.ErrorResponse("status", err)) + } + resp := d1.OKResponse("status", state, nil) + resp.MigrationID = state.MigrationID + return writeD1(ch, resp) + }, + } + + cmd.Flags().StringVar(&flags.org, "org", "", "PlanetScale organization") + cmd.Flags().StringVar(&flags.database, "database", "", "PlanetScale database name") + cmd.Flags().StringVar(&flags.branch, "branch", "main", "PlanetScale branch name") + cmd.Flags().StringVar(&flags.migrationID, "migration-id", "", "Migration ID") + cmd.MarkFlagRequired("database") + cmd.MarkFlagRequired("migration-id") + return cmd +} + +func d1CompleteCmd(ch *cmdutil.Helper) *cobra.Command { + var flags struct { + org string + database string + branch string + migrationID string + force bool + } + + cmd := &cobra.Command{ + Use: "complete", + Aliases: []string{"teardown"}, + Short: "Mark a D1 migration as complete in local state", + RunE: func(cmd *cobra.Command, args []string) error { + org := flags.org + if org == "" { + org = ch.Config.Organization + } + if !flags.force { + if err := ch.Printer.ConfirmCommand(flags.migrationID, "import d1 complete", "complete"); err != nil { + return err + } + } + err := d1.Complete(org, flags.database, flags.branch, flags.migrationID) + if err != nil { + return writeD1(ch, d1.ErrorResponse("complete", err)) + } + return writeD1(ch, d1.OKResponse("complete", map[string]string{ + "migration_id": flags.migrationID, + "status": d1.PhaseComplete, + }, nil)) + }, + } + + cmd.Flags().StringVar(&flags.org, "org", "", "PlanetScale organization") + cmd.Flags().StringVar(&flags.database, "database", "", "PlanetScale database name") + cmd.Flags().StringVar(&flags.branch, "branch", "main", "PlanetScale branch name") + cmd.Flags().StringVar(&flags.migrationID, "migration-id", "", "Migration ID") + cmd.Flags().BoolVar(&flags.force, "force", false, "Skip confirmation prompt") + cmd.MarkFlagRequired("database") + cmd.MarkFlagRequired("migration-id") + return cmd +} diff --git a/internal/cmd/importcmd/import.go b/internal/cmd/importcmd/import.go new file mode 100644 index 000000000..0c6389bd7 --- /dev/null +++ b/internal/cmd/importcmd/import.go @@ -0,0 +1,22 @@ +package importcmd + +import ( + "github.com/planetscale/cli/internal/cmdutil" + "github.com/spf13/cobra" +) + +// ImportCmd returns the import command group. +func ImportCmd(ch *cmdutil.Helper) *cobra.Command { + cmd := &cobra.Command{ + Use: "import", + Short: "Import external databases into PlanetScale Postgres", + Long: `Import databases from external sources into PlanetScale Postgres. + +Available sources: + d1 Import from Cloudflare D1 using an offline SQLite export`, + } + + cmd.AddCommand(D1Cmd(ch)) + + return cmd +} diff --git a/internal/cmd/mcp/import_d1_handlers.go b/internal/cmd/mcp/import_d1_handlers.go new file mode 100644 index 000000000..bc4e25f05 --- /dev/null +++ b/internal/cmd/mcp/import_d1_handlers.go @@ -0,0 +1,249 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" + + "github.com/planetscale/cli/internal/cmdutil" + "github.com/planetscale/cli/internal/migrate/d1" +) + +func importD1ToolDefs() []ToolDef { + return []ToolDef{ + { + tool: mcp.NewTool("import_d1_doctor", + mcp.WithDescription("Check prerequisites for Cloudflare D1 to PlanetScale Postgres import"), + ), + handler: handleImportD1Doctor, + }, + { + tool: mcp.NewTool("import_d1_lint", + mcp.WithDescription("Analyze a D1 SQL export for import issues"), + mcp.WithString("input", mcp.Description("Path to D1 SQL export file"), mcp.Required()), + ), + handler: handleImportD1Lint, + }, + { + tool: mcp.NewTool("import_d1_start", + mcp.WithDescription("Start importing a D1 SQL export into PlanetScale Postgres (runs lint/plan first; dry_run previews without loading Postgres)"), + mcp.WithString("input", mcp.Description("Path to D1 SQL export file"), mcp.Required()), + mcp.WithString("org", mcp.Description("PlanetScale organization"), mcp.Required()), + mcp.WithString("database", mcp.Description("PlanetScale database name"), mcp.Required()), + mcp.WithString("branch", mcp.Description("PlanetScale branch name")), + mcp.WithString("method", mcp.Description("Import method: pgloader (≥1GB) or psql (<1GB; schema via psql, data via pgloader)")), + mcp.WithString("migration_id", mcp.Description("Migration ID from a prior start --dry-run")), + mcp.WithBoolean("dry_run", mcp.Description("Lint and build import plan without loading Postgres")), + mcp.WithBoolean("force", mcp.Description("Skip confirmations")), + ), + handler: handleImportD1Start, + }, + { + tool: mcp.NewTool("import_d1_status", + mcp.WithDescription("Show local D1 import state"), + mcp.WithString("org", mcp.Description("PlanetScale organization"), mcp.Required()), + mcp.WithString("database", mcp.Description("PlanetScale database name"), mcp.Required()), + mcp.WithString("branch", mcp.Description("PlanetScale branch name")), + mcp.WithString("migration_id", mcp.Description("Migration ID"), mcp.Required()), + ), + handler: handleImportD1Status, + }, + { + tool: mcp.NewTool("import_d1_verify", + mcp.WithDescription("Verify D1 import (row counts, sequences, coercion, content checks)"), + mcp.WithString("org", mcp.Description("PlanetScale organization"), mcp.Required()), + mcp.WithString("database", mcp.Description("PlanetScale database name"), mcp.Required()), + mcp.WithString("branch", mcp.Description("PlanetScale branch name")), + mcp.WithString("migration_id", mcp.Description("Migration ID")), + mcp.WithString("input", mcp.Description("Path to D1 SQL export")), + mcp.WithString("sqlite", mcp.Description("Path to local SQLite file")), + ), + handler: handleImportD1Verify, + }, + } +} + +func handleImportD1Doctor(ctx context.Context, request mcp.CallToolRequest, ch *cmdutil.Helper) (*mcp.CallToolResult, error) { + result, err := d1.Doctor(ctx) + if err != nil { + return importD1Error("doctor", err) + } + resp := d1.OKResponse("doctor", result, d1.DoctorNextSteps(result)) + if !result.Ready { + return importD1Error("doctor", d1.DoctorReadinessError(result)) + } + return importD1Result(resp) +} + +func handleImportD1Lint(ctx context.Context, request mcp.CallToolRequest, ch *cmdutil.Helper) (*mcp.CallToolResult, error) { + input, err := request.RequireString("input") + if err != nil { + return nil, err + } + result, err := d1.Lint(input) + if err != nil { + return importD1Error("lint", err) + } + resp := d1.OKResponse("lint", result, d1.LintNextSteps(result)) + resp.Issues = result.Issues + if result.ErrorCount > 0 { + resp.Status = "error" + } else if result.WarningCount > 0 { + resp.Status = "warning" + } + return importD1Result(resp) +} + + +func handleImportD1Start(ctx context.Context, request mcp.CallToolRequest, ch *cmdutil.Helper) (*mcp.CallToolResult, error) { + input, err := request.RequireString("input") + if err != nil { + return nil, err + } + org, err := request.RequireString("org") + if err != nil { + return nil, err + } + database, err := request.RequireString("database") + if err != nil { + return nil, err + } + branch := request.GetString("branch", "main") + method := request.GetString("method", "") + migrationID := request.GetString("migration_id", "") + dryRun := request.GetBool("dry_run", false) + + importOpts := d1.ImportOptions{ + Org: org, + Database: database, + Branch: branch, + InputPath: input, + Method: method, + MigrationID: migrationID, + DryRun: dryRun, + } + + prepared, err := d1.PrepareImport(importOpts) + if err != nil { + return importD1Error("start", err) + } + + if !prepared.CanProceed { + return importD1Result(d1.BlockedStartResponse(prepared, dryRun)) + } + + client, err := ch.Client() + if err != nil { + return nil, err + } + + result, err := d1.Import(ctx, client, &d1.DefaultImportClient{Client: client}, importOpts, prepared) + if err != nil { + resp := d1.ErrorResponse("start", err) + if result != nil { + resp.Data = result + if result.Lint != nil { + resp.Issues = result.Lint.Issues + } + } + resp.MigrationID = prepared.MigrationID + return importD1Result(resp) + } + resp := d1.OKResponse("start", result, d1.StartNextSteps(result.MigrationID, database, result.Method, dryRun)) + resp.MigrationID = result.MigrationID + if result.Lint != nil { + resp.Issues = result.Lint.Issues + } + if dryRun { + resp.Status = "dry_run" + } + return importD1Result(resp) +} + +func handleImportD1Status(ctx context.Context, request mcp.CallToolRequest, ch *cmdutil.Helper) (*mcp.CallToolResult, error) { + org, err := request.RequireString("org") + if err != nil { + return nil, err + } + database, err := request.RequireString("database") + if err != nil { + return nil, err + } + branch := request.GetString("branch", "main") + migrationID, err := request.RequireString("migration_id") + if err != nil { + return nil, err + } + + state, err := d1.Status(org, database, branch, migrationID) + if err != nil { + return importD1Error("status", err) + } + resp := d1.OKResponse("status", state, nil) + resp.MigrationID = state.MigrationID + return importD1Result(resp) +} + +func handleImportD1Verify(ctx context.Context, request mcp.CallToolRequest, ch *cmdutil.Helper) (*mcp.CallToolResult, error) { + org, err := request.RequireString("org") + if err != nil { + return nil, err + } + database, err := request.RequireString("database") + if err != nil { + return nil, err + } + branch := request.GetString("branch", "main") + migrationID := request.GetString("migration_id", "") + input := request.GetString("input", "") + sqlitePath := request.GetString("sqlite", "") + + client, err := ch.Client() + if err != nil { + return nil, err + } + destURI, cleanup, err := d1.ResolveDestURI(ctx, client, d1.ImportOptions{ + Org: org, + Database: database, + Branch: branch, + }) + if err != nil { + return nil, err + } + defer func() { _ = cleanup() }() + + result, err := d1.Verify(ctx, d1.VerifyOptions{ + Org: org, + Database: database, + Branch: branch, + MigrationID: migrationID, + InputPath: input, + SQLitePath: sqlitePath, + DestURI: destURI, + }) + if err != nil { + resp := d1.ErrorResponse("verify", err) + if result != nil { + resp.Data = result + } + return importD1Result(resp) + } + resp := d1.OKResponse("verify", result, nil) + resp.MigrationID = migrationID + return importD1Result(resp) +} + +func importD1Result(resp d1.Response) (*mcp.CallToolResult, error) { + b, err := json.Marshal(resp) + if err != nil { + return nil, fmt.Errorf("marshal response: %w", err) + } + return mcp.NewToolResultText(string(b)), nil +} + +func importD1Error(phase string, err error) (*mcp.CallToolResult, error) { + resp := d1.ErrorResponse(phase, err) + return importD1Result(resp) +} diff --git a/internal/cmd/mcp/server.go b/internal/cmd/mcp/server.go index 33e8c7711..bef5d2730 100644 --- a/internal/cmd/mcp/server.go +++ b/internal/cmd/mcp/server.go @@ -23,7 +23,7 @@ type ToolDef struct { // getToolDefinitions returns the list of all available MCP tools func getToolDefinitions() []ToolDef { namingBlurb := ". Two common naming conventions for PlanetScale databases are / and //. When the user provides a database identifier in either of these formats, automatically parse and use the org, database, and branch parameters directly - do not perform discovery steps like list_orgs or list_databases. Examples: `acme/widgets` -> org=acme, database=widgets. `acme/widgets/main` -> org=acme, database=widgets, branch=main. If the user provides an identifier like 'org/database' or 'org/database/branch', parse these components directly and skip organizational/database discovery steps." - return []ToolDef{ + tools := []ToolDef{ { tool: mcp.NewTool("list_orgs", mcp.WithDescription("List all available organizations"), @@ -165,6 +165,9 @@ func getToolDefinitions() []ToolDef { handler: HandleGetInsights, }, } + + tools = append(tools, importD1ToolDefs()...) + return tools } // ServerCmd returns a new cobra.Command for the mcp server command. diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 38f03e090..cc229d734 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -41,6 +41,7 @@ import ( "github.com/planetscale/cli/internal/cmd/database" "github.com/planetscale/cli/internal/cmd/dataimports" "github.com/planetscale/cli/internal/cmd/deployrequest" + "github.com/planetscale/cli/internal/cmd/importcmd" "github.com/planetscale/cli/internal/cmd/keyspace" "github.com/planetscale/cli/internal/cmd/org" "github.com/planetscale/cli/internal/cmd/password" @@ -118,16 +119,17 @@ func Execute(ctx context.Context, sigc chan os.Signal, signals []os.Signal, ver, return 0 } - // print any user specific messages first - switch format { - case printer.JSON: - fmt.Fprintf(os.Stderr, `{"error": "%s"}`, err) - default: - fmt.Fprintf(os.Stderr, "Error: %s\n", err) + var cmdErr *cmdutil.Error + printed := errors.As(err, &cmdErr) && cmdErr.Printed + if !printed { + switch format { + case printer.JSON: + fmt.Fprintf(os.Stderr, `{"error": "%s"}`, err) + default: + fmt.Fprintf(os.Stderr, "Error: %s\n", err) + } } - // check if a sub command wants to return a specific exit code - var cmdErr *cmdutil.Error if errors.As(err, &cmdErr) { return cmdErr.ExitCode } @@ -312,6 +314,10 @@ func runCmd(ctx context.Context, ver, commit, buildDate string, format *printer. shellCmd.GroupID = "database" rootCmd.AddCommand(shellCmd) + importCmd := importcmd.ImportCmd(ch) + importCmd.GroupID = "postgres" + rootCmd.AddCommand(importCmd) + workflowCmd := workflow.WorkflowCmd(ch) workflowCmd.GroupID = "vitess" rootCmd.AddCommand(workflowCmd) diff --git a/internal/cmdutil/errors.go b/internal/cmdutil/errors.go index 2dd48bafd..28f32c44f 100644 --- a/internal/cmdutil/errors.go +++ b/internal/cmdutil/errors.go @@ -18,8 +18,11 @@ var errExpiredAuthMessage = errors.New("the access token has expired. Please run // Error can be used by a command to change the exit status of the CLI. type Error struct { Msg string - // Status + // ExitCode is returned to the shell when the command fails. ExitCode int + // Printed indicates the error output was already written (e.g. to stdout); + // root should not print Msg to stderr again. + Printed bool } func (e *Error) Error() string { return e.Msg } diff --git a/internal/migrate/d1/constraints.go b/internal/migrate/d1/constraints.go new file mode 100644 index 000000000..a91dfc742 --- /dev/null +++ b/internal/migrate/d1/constraints.go @@ -0,0 +1,279 @@ +package d1 + +import ( + "regexp" + "strings" +) + +var ( + referencesClauseRe = regexp.MustCompile(`(?is)^REFERENCES\s+(?:"([^"]+)"|'([^']+)'|` + "`" + `([^` + "`" + `]+)` + "`" + `|([a-zA-Z_][\w]*))\s*\(\s*([^)]+)\)\s*(.*)$`) + foreignKeyConstraintRe = regexp.MustCompile(`(?is)^FOREIGN\s+KEY\s*\(\s*([^)]+)\)\s*(REFERENCES\s+.+)$`) + primaryKeyConstraintRe = regexp.MustCompile(`(?is)^PRIMARY\s+KEY\s*\(\s*([^)]+)\)\s*$`) + uniqueConstraintRe = regexp.MustCompile(`(?is)^UNIQUE\s*\(\s*([^)]+)\)\s*$`) + createIndexRe = regexp.MustCompile(`(?is)^CREATE\s+(UNIQUE\s+)?INDEX\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:"([^"]+)"|'([^']+)'|` + "`" + `([^` + "`" + `]+)` + "`" + `|([a-zA-Z_][\w]*))\s+ON\s+(?:"([^"]+)"|'([^']+)'|` + "`" + `([^` + "`" + `]+)` + "`" + `|([a-zA-Z_][\w]*))\s*\(\s*([^)]+)\)\s*;?\s*$`) +) + +// IndexSchema holds a parsed CREATE INDEX statement from a dump. +type IndexSchema struct { + Name string + Table string + Unique bool + Columns string + RawDDL string +} + +func isTableConstraint(part string) bool { + upper := strings.ToUpper(strings.TrimSpace(part)) + return strings.HasPrefix(upper, "PRIMARY KEY") || + strings.HasPrefix(upper, "FOREIGN KEY") || + strings.HasPrefix(upper, "UNIQUE(") || + strings.HasPrefix(upper, "UNIQUE (") || + strings.HasPrefix(upper, "CHECK(") || + strings.HasPrefix(upper, "CHECK (") || + strings.HasPrefix(upper, "CONSTRAINT ") +} + +func referencesClause(colDef string) string { + idx := indexOfIgnoreCase(colDef, "REFERENCES") + if idx < 0 { + return "" + } + return strings.TrimSpace(colDef[idx:]) +} + +func convertTableConstraint(clause string) string { + clause = strings.TrimSpace(clause) + clause = strings.TrimSuffix(clause, ",") + if clause == "" { + return "" + } + + upper := strings.ToUpper(clause) + switch { + case strings.HasPrefix(upper, "FOREIGN KEY"): + return convertForeignKeyConstraint(clause) + case strings.HasPrefix(upper, "PRIMARY KEY"): + return convertPrimaryKeyConstraint(clause) + case strings.HasPrefix(upper, "UNIQUE"): + return convertUniqueConstraint(clause) + default: + return clause + } +} + +func convertForeignKeyConstraint(clause string) string { + m := foreignKeyConstraintRe.FindStringSubmatch(clause) + if m == nil { + return clause + } + cols := quoteColumnList(m[1]) + refs := convertReferencesClause(strings.TrimSpace(m[2])) + return "FOREIGN KEY (" + cols + ") " + refs +} + +func convertPrimaryKeyConstraint(clause string) string { + m := primaryKeyConstraintRe.FindStringSubmatch(clause) + if m == nil { + return clause + } + return "PRIMARY KEY (" + quoteColumnList(m[1]) + ")" +} + +func convertUniqueConstraint(clause string) string { + m := uniqueConstraintRe.FindStringSubmatch(clause) + if m == nil { + return clause + } + return "UNIQUE (" + quoteColumnList(m[1]) + ")" +} + +func convertReferencesClause(refs string) string { + m := referencesClauseRe.FindStringSubmatch(refs) + if m == nil { + return refs + } + table := quoteIdent(firstNonEmpty(m[1], m[2], m[3], m[4])) + refCols := quoteColumnList(m[5]) + tail := strings.TrimSpace(m[6]) + if tail != "" { + return "REFERENCES " + table + " (" + refCols + ") " + tail + } + return "REFERENCES " + table + " (" + refCols + ")" +} + +func quoteColumnList(list string) string { + parts := splitCommaList(list) + quoted := make([]string, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + quoted = append(quoted, quoteIdent(strings.Trim(part, "`\"'"))) + } + return strings.Join(quoted, ", ") +} + +func splitCommaList(list string) []string { + var parts []string + var current strings.Builder + depth := 0 + inSingle := false + inDouble := false + + for _, r := range list { + switch r { + case '\'': + if !inDouble { + inSingle = !inSingle + } + current.WriteRune(r) + case '"': + if !inSingle { + inDouble = !inDouble + } + current.WriteRune(r) + case '(': + if !inSingle && !inDouble { + depth++ + } + current.WriteRune(r) + case ')': + if !inSingle && !inDouble { + depth-- + } + current.WriteRune(r) + case ',': + if depth == 0 && !inSingle && !inDouble { + parts = append(parts, current.String()) + current.Reset() + continue + } + current.WriteRune(r) + default: + current.WriteRune(r) + } + } + if current.Len() > 0 { + parts = append(parts, current.String()) + } + return parts +} + +func convertIndexDDL(raw string) string { + m := createIndexRe.FindStringSubmatch(raw) + if m == nil { + return raw + } + unique := strings.TrimSpace(m[1]) != "" + name := quoteIdent(firstNonEmpty(m[2], m[3], m[4], m[5])) + table := quoteIdent(firstNonEmpty(m[6], m[7], m[8], m[9])) + cols := quoteColumnList(m[10]) + prefix := "CREATE INDEX IF NOT EXISTS " + if unique { + prefix = "CREATE UNIQUE INDEX IF NOT EXISTS " + } + return prefix + name + " ON " + table + " (" + cols + ");" +} + +func isUUIDColumn(col ColumnSchema, table TableSchema, all []TableSchema) bool { + return isExplicitUUIDColumn(col) || columnReferencesUUIDKey(col, table, all) +} + +func isExplicitUUIDColumn(col ColumnSchema) bool { + name := strings.ToLower(col.Name) + t := strings.ToUpper(col.Type) + + if !isTextLikeType(t) { + return false + } + + if col.PrimaryKey && (name == "id" || name == "uuid") { + return true + } + if strings.HasSuffix(name, "_uuid") { + return true + } + return false +} + +func columnReferencesUUIDKey(col ColumnSchema, table TableSchema, all []TableSchema) bool { + refTable, refCol := columnFKTarget(col, table) + if refTable == "" { + return false + } + ref := tableByName(all, refTable) + if ref == nil { + return false + } + refColSchema := columnByName(*ref, refCol) + return isExplicitUUIDColumn(refColSchema) +} + +func columnFKTarget(col ColumnSchema, table TableSchema) (string, string) { + if col.ForeignKey != "" { + return parseReferencesTarget(col.ForeignKey) + } + for _, constraint := range table.Constraints { + cols, refs := parseTableLevelForeignKey(constraint) + for _, name := range cols { + if name == col.Name { + return parseReferencesTarget(refs) + } + } + } + return "", "" +} + +func parseTableLevelForeignKey(constraint string) ([]string, string) { + m := foreignKeyConstraintRe.FindStringSubmatch(constraint) + if m == nil { + return nil, "" + } + cols := make([]string, 0) + for _, part := range splitCommaList(m[1]) { + part = strings.Trim(strings.TrimSpace(part), "`\"'") + if part != "" { + cols = append(cols, part) + } + } + return cols, strings.TrimSpace(m[2]) +} + +func parseReferencesTarget(refs string) (string, string) { + m := referencesClauseRe.FindStringSubmatch(strings.TrimSpace(refs)) + if m == nil { + return "", "" + } + table := firstNonEmpty(m[1], m[2], m[3], m[4]) + refCols := splitCommaList(m[5]) + refCol := "" + if len(refCols) > 0 { + refCol = strings.Trim(strings.TrimSpace(refCols[0]), "`\"'") + } + return table, refCol +} + +func tableByName(all []TableSchema, name string) *TableSchema { + lower := strings.ToLower(name) + for i := range all { + if strings.ToLower(all[i].Name) == lower { + return &all[i] + } + } + return nil +} + +func columnByName(table TableSchema, name string) ColumnSchema { + lower := strings.ToLower(name) + for _, col := range table.Columns { + if strings.ToLower(col.Name) == lower { + return col + } + } + return ColumnSchema{} +} + +func isTextLikeType(t string) bool { + return t == "" || strings.Contains(t, "CHAR") || strings.Contains(t, "CLOB") || strings.Contains(t, "TEXT") +} diff --git a/internal/migrate/d1/constraints_test.go b/internal/migrate/d1/constraints_test.go new file mode 100644 index 000000000..35d4fc9c6 --- /dev/null +++ b/internal/migrate/d1/constraints_test.go @@ -0,0 +1,35 @@ +package d1 + +import "testing" + +func TestParseTableLevelForeignKey(t *testing.T) { + cols, refs := parseTableLevelForeignKey(`FOREIGN KEY (entity_id) REFERENCES external_entities(id)`) + if len(cols) != 1 || cols[0] != "entity_id" { + t.Fatalf("unexpected columns: %#v", cols) + } + refTable, refCol := parseReferencesTarget(refs) + if refTable != "external_entities" || refCol != "id" { + t.Fatalf("unexpected ref target: %s.%s", refTable, refCol) + } +} + +func TestColumnFKTargetUsesTableConstraint(t *testing.T) { + table := TableSchema{ + Name: "entity_links", + Columns: []ColumnSchema{ + {Name: "entity_id", Type: "TEXT", NotNull: true}, + {Name: "post_id", Type: "INTEGER", NotNull: true}, + }, + Constraints: []string{ + `PRIMARY KEY (entity_id, post_id)`, + `FOREIGN KEY (entity_id) REFERENCES external_entities(id)`, + `FOREIGN KEY (post_id) REFERENCES posts(id)`, + }, + } + col := table.Columns[0] + + refTable, refCol := columnFKTarget(col, table) + if refTable != "external_entities" || refCol != "id" { + t.Fatalf("got %s.%s", refTable, refCol) + } +} diff --git a/internal/migrate/d1/convert.go b/internal/migrate/d1/convert.go new file mode 100644 index 000000000..51dd1ffbe --- /dev/null +++ b/internal/migrate/d1/convert.go @@ -0,0 +1,229 @@ +package d1 + +import ( + "fmt" + "os" + "regexp" + "strings" +) + +var ( + sqliteTypeCleanup = regexp.MustCompile(`(?i)\s+PRIMARY\s+KEY\s+AUTOINCREMENT`) +) + +// SchemaParts holds table DDL and secondary index DDL separately so imports can +// load data before building indexes (much faster than maintaining indexes per row). +type SchemaParts struct { + Tables string + Indexes string +} + +// ConvertSchemaParts converts SQLite DDL into Postgres table and index SQL. +func ConvertSchemaParts(inputPath string) (SchemaParts, int, error) { + tables, err := ParseDump(inputPath) + if err != nil { + return SchemaParts{}, 0, err + } + + indexes, err := ParseIndexes(inputPath) + if err != nil { + return SchemaParts{}, 0, err + } + + var tableBuf strings.Builder + tableBuf.WriteString("-- Generated by pscale import d1 convert-schema (tables)\n") + tableBuf.WriteString("-- Source: " + inputPath + "\n\n") + + converted := 0 + tableByName := make(map[string]TableSchema, len(tables)) + for _, table := range tables { + tableByName[table.Name] = table + } + for _, name := range topologicalLoadOrder(tables) { + table, ok := tableByName[name] + if !ok { + continue + } + if IsORMMetadataTable(table.Name) { + continue + } + tableBuf.WriteString(convertTableDDL(table, tables)) + tableBuf.WriteString("\n\n") + converted++ + } + + var indexBuf strings.Builder + if len(indexes) > 0 { + indexBuf.WriteString("-- Generated by pscale import d1 convert-schema (indexes)\n") + indexBuf.WriteString("-- Source: " + inputPath + "\n\n") + indexBuf.WriteString("-- Indexes\n") + for _, idx := range indexes { + if IsORMMetadataTable(idx.Table) { + continue + } + indexBuf.WriteString(convertIndexDDL(idx.RawDDL)) + indexBuf.WriteString("\n") + } + indexBuf.WriteString("\n") + } + + return SchemaParts{ + Tables: tableBuf.String(), + Indexes: indexBuf.String(), + }, converted, nil +} + +// ConvertSchema converts SQLite CREATE TABLE statements to PostgreSQL DDL. +func ConvertSchema(inputPath, outputPath string) (int, error) { + parts, converted, err := ConvertSchemaParts(inputPath) + if err != nil { + return 0, err + } + + var b strings.Builder + b.WriteString(parts.Tables) + if parts.Indexes != "" { + b.WriteString(parts.Indexes) + } + + if err := os.WriteFile(outputPath, []byte(b.String()), 0o600); err != nil { + return 0, fmt.Errorf("write schema: %w", err) + } + + return converted, nil +} + +func convertTableDDL(table TableSchema, all []TableSchema) string { + var b strings.Builder + b.WriteString(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (\n", quoteIdent(table.Name))) + + var lines []string + for _, col := range table.Columns { + lines = append(lines, " "+convertColumn(col, table, all)) + } + for _, constraint := range table.Constraints { + if converted := convertTableConstraint(constraint); converted != "" { + lines = append(lines, " "+converted) + } + } + b.WriteString(strings.Join(lines, ",\n")) + b.WriteString("\n);\n") + + return b.String() +} + +func convertColumn(col ColumnSchema, table TableSchema, all []TableSchema) string { + pgType := sqliteTypeToPostgres(col, table, all) + + var parts []string + parts = append(parts, quoteIdent(col.Name), pgType) + + if col.AutoIncrement { + parts = append(parts, "GENERATED BY DEFAULT AS IDENTITY") + if col.PrimaryKey { + parts = append(parts, "PRIMARY KEY") + } + } else if col.PrimaryKey { + parts = append(parts, "PRIMARY KEY") + } + + if col.NotNull && !col.AutoIncrement { + parts = append(parts, "NOT NULL") + } + + if col.Unique { + parts = append(parts, "UNIQUE") + } + + if col.DefaultValue != "" && !col.AutoIncrement { + parts = append(parts, "DEFAULT", convertDefault(col.DefaultValue, pgType)) + } + + if col.ForeignKey != "" { + parts = append(parts, convertReferencesClause(col.ForeignKey)) + } + + return strings.Join(parts, " ") +} + +func sqliteTypeToPostgres(col ColumnSchema, table TableSchema, all []TableSchema) string { + if isUUIDColumn(col, table, all) { + return "UUID" + } + + t := strings.ToUpper(col.Type) + + if col.AutoIncrement { + if strings.Contains(t, "BIG") { + return "BIGINT" + } + return "INTEGER" + } + + switch { + case t == "" || t == "NUMERIC": + return "TEXT" + case strings.Contains(t, "INT"): + if isBooleanColumn(col) { + return "BOOLEAN" + } + return "BIGINT" + case strings.Contains(t, "CHAR") || strings.Contains(t, "CLOB") || strings.Contains(t, "TEXT"): + if isJSONText(col) { + return "JSONB" + } + if isTimestampText(col) { + return "TIMESTAMPTZ" + } + return "TEXT" + case strings.Contains(t, "BLOB"): + return "BYTEA" + case strings.Contains(t, "REAL") || strings.Contains(t, "FLOA") || strings.Contains(t, "DOUB"): + return "DOUBLE PRECISION" + case strings.Contains(t, "BOOL"): + return "BOOLEAN" + default: + return "TEXT" + } +} + +func convertDefault(def, pgType string) string { + def = strings.TrimSpace(def) + upper := strings.ToUpper(def) + if upper == "NULL" { + return "NULL" + } + if pgType == "BOOLEAN" && (def == "0" || def == "1") { + if def == "1" { + return "TRUE" + } + return "FALSE" + } + if pgType == "UUID" { + def = strings.Trim(def, "'\"") + return "'" + def + "'" + } + if strings.HasPrefix(def, "'") || strings.HasPrefix(def, `"`) { + return def + } + if pgType == "TEXT" || pgType == "TIMESTAMPTZ" { + return "'" + strings.Trim(def, "'\"") + "'" + } + return def +} + +func quoteIdent(name string) string { + escaped := strings.ReplaceAll(name, `"`, `""`) + return `"` + escaped + `"` +} + +// ConvertCreateStatement converts a raw SQLite CREATE TABLE line to Postgres (for tests). +func ConvertCreateStatement(sqliteDDL string) string { + ddl := sqliteTypeCleanup.ReplaceAllString(sqliteDDL, "") + ddl = regexp.MustCompile(`(?i)\bAUTOINCREMENT\b`).ReplaceAllString(ddl, "") + ddl = regexp.MustCompile(`(?i)\bINTEGER\b`).ReplaceAllStringFunc(ddl, func(s string) string { + return "BIGINT" + }) + ddl = regexp.MustCompile(`(?i)\bREAL\b`).ReplaceAllString(ddl, "DOUBLE PRECISION") + return ddl +} diff --git a/internal/migrate/d1/doctor.go b/internal/migrate/d1/doctor.go new file mode 100644 index 000000000..4a0980459 --- /dev/null +++ b/internal/migrate/d1/doctor.go @@ -0,0 +1,235 @@ +package d1 + +import ( + "context" + "fmt" + "os" + "strings" + + "github.com/planetscale/cli/internal/postgres" + execabs "golang.org/x/sys/execabs" +) + +const ( + checkOK = "ok" + checkWarn = "warn" + checkFail = "fail" + checkSkip = "skip" +) + +// Doctor runs prerequisite checks for D1 migration. +func Doctor(ctx context.Context) (*DoctorResult, error) { + checks := []DoctorCheck{ + checkWrangler(ctx), + checkPgloader(ctx), + checkPsql(), + checkSQLite3(ctx), + checkCloudflareEnv(), + } + + result := &DoctorResult{Checks: checks, Ready: true} + for _, c := range checks { + if c.Status == checkFail { + result.Ready = false + } + } + return result, nil +} + +func checkWrangler(ctx context.Context) DoctorCheck { + for _, cmd := range []string{"wrangler", "npx"} { + path, err := execabs.LookPath(cmd) + if err != nil { + continue + } + if cmd == "npx" { + c := execabs.CommandContext(ctx, path, "wrangler", "--version") + out, err := c.CombinedOutput() + if err == nil { + return DoctorCheck{ + Name: "wrangler", + Status: checkOK, + Version: strings.TrimSpace(string(out)), + } + } + continue + } + c := execabs.CommandContext(ctx, path, "--version") + out, err := c.CombinedOutput() + if err == nil { + return DoctorCheck{ + Name: "wrangler", + Status: checkOK, + Version: strings.TrimSpace(string(out)), + } + } + } + + return DoctorCheck{ + Name: "wrangler", + Status: checkWarn, + Message: "wrangler not found", + Remediation: wranglerMissingRemediation, + } +} + +func checkPgloader(ctx context.Context) DoctorCheck { + path, err := execabs.LookPath("pgloader") + if err != nil { + return DoctorCheck{ + Name: "pgloader", + Status: checkFail, + Message: "pgloader not found", + Remediation: pgloaderInstallRemediation, + } + } + c := execabs.CommandContext(ctx, path, "--version") + out, err := c.CombinedOutput() + if err != nil { + return DoctorCheck{ + Name: "pgloader", + Status: checkFail, + Message: "pgloader found but --version failed", + Remediation: "Reinstall pgloader", + } + } + return DoctorCheck{ + Name: "pgloader", + Status: checkOK, + Version: strings.TrimSpace(string(out)), + } +} + +func checkPsql() DoctorCheck { + major, minor, err := postgres.CheckPsqlVersion(10) + if err != nil { + return DoctorCheck{ + Name: "psql", + Status: checkFail, + Message: err.Error(), + Remediation: "Install PostgreSQL client tools: brew install postgresql@18", + } + } + return DoctorCheck{ + Name: "psql", + Status: checkOK, + Version: fmt.Sprintf("%d.%d", major, minor), + } +} + +func checkSQLite3(ctx context.Context) DoctorCheck { + path, err := execabs.LookPath("sqlite3") + if err != nil { + return DoctorCheck{ + Name: "sqlite3", + Status: checkSkip, + Message: "sqlite3 CLI not found", + Remediation: "Optional: install sqlite3 for verify and pgloader prep (brew install sqlite)", + } + } + c := execabs.CommandContext(ctx, path, "--version") + out, err := c.CombinedOutput() + if err != nil { + return DoctorCheck{ + Name: "sqlite3", + Status: checkSkip, + } + } + return DoctorCheck{ + Name: "sqlite3", + Status: checkOK, + Version: strings.TrimSpace(string(out)), + } +} + +func checkCloudflareEnv() DoctorCheck { + token := os.Getenv("CLOUDFLARE_API_TOKEN") + account := os.Getenv("CLOUDFLARE_ACCOUNT_ID") + if token != "" && account != "" { + return DoctorCheck{ + Name: "cloudflare_auth", + Status: checkOK, + } + } + return DoctorCheck{ + Name: "cloudflare_auth", + Status: checkWarn, + Message: "CLOUDFLARE_API_TOKEN and/or CLOUDFLARE_ACCOUNT_ID not set", + Remediation: "Set Cloudflare env vars for remote export, or pass --input with an existing dump", + } +} + +// DoctorReadinessError summarizes failed prerequisite checks for doctor/start. +func DoctorReadinessError(result *DoctorResult) error { + if result == nil || result.Ready { + return nil + } + + var parts []string + var remediations []string + for _, c := range result.Checks { + if c.Status != checkFail { + continue + } + msg := c.Name + if c.Message != "" { + msg += ": " + c.Message + } + parts = append(parts, msg) + if c.Remediation != "" { + remediations = append(remediations, c.Remediation) + } + } + + message := "prerequisites not met" + if len(parts) > 0 { + message = strings.Join(parts, "; ") + } + remediation := strings.Join(remediations, "; ") + if remediation == "" { + remediation = "Run `pscale import d1 doctor` and fix failed checks" + } + return newMigrationError(ErrCodePrereqFailed, message, remediation) +} + +// DoctorNextSteps suggests next actions after doctor. +func DoctorNextSteps(result *DoctorResult) []NextStep { + if !result.Ready { + return []NextStep{ + {Command: "pscale import d1 doctor", Reason: "Fix failed checks and re-run doctor"}, + } + } + return []NextStep{ + {Tool: "import_d1_export", Command: "pscale import d1 export --d1-database --remote", Reason: "Export D1 database remotely"}, + {Tool: "import_d1_lint", Command: "pscale import d1 lint --input ./d1-export.sql", Reason: "Lint an existing export file"}, + } +} + +// FindWrangler returns the wrangler command to execute. +func FindWrangler() (string, []string, error) { + if path, err := execabs.LookPath("wrangler"); err == nil { + return path, nil, nil + } + if path, err := execabs.LookPath("npx"); err == nil { + return path, []string{"wrangler"}, nil + } + return "", nil, errMissingTool("wrangler", wranglerMissingRemediation) +} + +// FindPgloader returns pgloader path. +func FindPgloader() (string, error) { + path, err := execabs.LookPath("pgloader") + if err != nil { + return "", errMissingTool("pgloader", pgloaderInstallRemediation) + } + return path, nil +} + +// FindSQLite3 returns sqlite3 path. +func FindSQLite3() (string, error) { + path, err := execabs.LookPath("sqlite3") + if err != nil { + return "", errMissingTool("sqlite3", "Install with: brew install sqlite") + } + return path, nil +} diff --git a/internal/migrate/d1/doctor_test.go b/internal/migrate/d1/doctor_test.go new file mode 100644 index 000000000..f0216d0f7 --- /dev/null +++ b/internal/migrate/d1/doctor_test.go @@ -0,0 +1,70 @@ +package d1 + +import ( + "context" + "testing" +) + +func TestDoctor_RequiresPgloader(t *testing.T) { + if _, err := FindPgloader(); err == nil { + t.Skip("pgloader installed") + } + + result, err := Doctor(context.Background()) + if err != nil { + t.Fatalf("Doctor: %v", err) + } + if result.Ready { + t.Fatal("expected doctor not ready without pgloader") + } + + var pgloaderCheck DoctorCheck + for _, c := range result.Checks { + if c.Name == "pgloader" { + pgloaderCheck = c + break + } + } + if pgloaderCheck.Status != checkFail { + t.Fatalf("pgloader check status = %q, want %q", pgloaderCheck.Status, checkFail) + } + + if err := DoctorReadinessError(result); err == nil { + t.Fatal("expected readiness error") + } else { + requireMigrationErr(t, err, ErrCodePrereqFailed) + } +} + +func TestPrepareImport_RequiresPgloader(t *testing.T) { + if _, err := FindPgloader(); err == nil { + t.Skip("pgloader installed") + } + + _, err := PrepareImport(ImportOptions{ + InputPath: testFixture(t), + Org: "acme", + Database: "mydb", + }) + if err == nil { + t.Fatal("expected missing pgloader error") + } + requireMigrationErr(t, err, ErrCodeMissingTool) +} + +func TestImport_RequiresPgloader(t *testing.T) { + if _, err := FindPgloader(); err == nil { + t.Skip("pgloader installed") + } + + _, err := Import(context.Background(), nil, nil, ImportOptions{ + InputPath: testFixture(t), + Org: "acme", + Database: "mydb", + DryRun: true, + }, nil) + if err == nil { + t.Fatal("expected missing pgloader error") + } + requireMigrationErr(t, err, ErrCodeMissingTool) +} diff --git a/internal/migrate/d1/errors.go b/internal/migrate/d1/errors.go new file mode 100644 index 000000000..b8a584513 --- /dev/null +++ b/internal/migrate/d1/errors.go @@ -0,0 +1,86 @@ +package d1 + +import ( + "errors" + "fmt" + "strings" +) + +// ErrCode constants for structured errors. +const ( + ErrCodeVirtualTable = "VIRTUAL_TABLE" + ErrCodeMissingInput = "MISSING_INPUT" + ErrCodeMissingTool = "MISSING_TOOL" + ErrCodeInvalidInput = "INVALID_INPUT" + ErrCodeImportFailed = "IMPORT_FAILED" + ErrCodeVerifyFailed = "VERIFY_FAILED" + ErrCodeNotFound = "NOT_FOUND" + ErrCodePrereqFailed = "PREREQ_FAILED" + ErrCodeLintBlocked = "LINT_BLOCKED" + ErrCodeDestinationConflict = "DESTINATION_CONFLICT" +) + +const ( + wranglerMissingRemediation = "Install wrangler, use npx wrangler d1 export, or pass --input if you already have a dump." + pgloaderInstallRemediation = "Install pgloader (brew install pgloader on macOS; see https://pgloader.readthedocs.io/en/latest/install.html for other platforms)" + lintBlockedRemediation = "Fix lint errors or run `pscale import d1 lint` for details; use `import d1 start --dry-run` for a read-only preview" +) + +type MigrationError struct { + Info ErrorInfo +} + +func (e *MigrationError) Error() string { + return e.Info.Message +} + +func migrationErr(err error) (*MigrationError, bool) { + var me *MigrationError + if errors.As(err, &me) { + return me, true + } + return nil, false +} + +func newMigrationError(code, message, remediation string) *MigrationError { + return &MigrationError{ + Info: ErrorInfo{ + Code: code, + Message: message, + Remediation: remediation, + }, + } +} + +func lintBlockedReason(errorCount int) string { + return fmt.Sprintf("lint reported %d error(s); fix or use import d1 lint for details", errorCount) +} + +func errMissingInput(path string) error { + return newMigrationError( + ErrCodeMissingInput, + fmt.Sprintf("input file not found: %s", path), + "Run `pscale import d1 export`, export with wrangler/npx, or pass an existing dump with --input", + ) +} + +func errMissingTool(name, remediation string) error { + return newMigrationError( + ErrCodeMissingTool, + fmt.Sprintf("required tool not found: %s", name), + remediation, + ) +} + +func errExistingImportTables(tables []string) error { + return newMigrationError( + ErrCodeDestinationConflict, + fmt.Sprintf("destination already has tables from this import: %s", strings.Join(tables, ", ")), + "Use a new branch, drop the conflicting tables, or choose a database without overlapping table names before importing", + ) +} + +// ErrLintBlocked returns a structured error when lint errors block import. +func ErrLintBlocked(reason string) error { + return newMigrationError(ErrCodeLintBlocked, reason, lintBlockedRemediation) +} diff --git a/internal/migrate/d1/errors_test.go b/internal/migrate/d1/errors_test.go new file mode 100644 index 000000000..b2c9a149f --- /dev/null +++ b/internal/migrate/d1/errors_test.go @@ -0,0 +1,14 @@ +package d1 + +import "testing" + +func requireMigrationErr(t *testing.T, err error, code string) { + t.Helper() + me, ok := migrationErr(err) + if !ok { + t.Fatalf("expected MigrationError, got %T: %v", err, err) + } + if me.Info.Code != code { + t.Fatalf("code = %q, want %q", me.Info.Code, code) + } +} diff --git a/internal/migrate/d1/export.go b/internal/migrate/d1/export.go new file mode 100644 index 000000000..67a4e6dc6 --- /dev/null +++ b/internal/migrate/d1/export.go @@ -0,0 +1,76 @@ +package d1 + +import ( + "context" + "fmt" + "os" + "path/filepath" + "time" + + execabs "golang.org/x/sys/execabs" +) + +// ExportOptions configures D1 export via wrangler. +type ExportOptions struct { + D1Database string + Output string + Remote bool + Table string + NoData bool +} + +// Export runs wrangler d1 export. +func Export(ctx context.Context, opts ExportOptions) (*ExportResult, error) { + if opts.D1Database == "" { + return nil, newMigrationError(ErrCodeInvalidInput, "d1 database name is required", "Pass --d1-database") + } + if opts.Output == "" { + opts.Output = fmt.Sprintf("d1-export-%s.sql", opts.D1Database) + } + + bin, prefix, err := FindWrangler() + if err != nil { + return nil, err + } + + args := append(prefix, "d1", "export", opts.D1Database, "--output", opts.Output) + if opts.Remote { + args = append(args, "--remote") + } + if opts.Table != "" { + args = append(args, "--table="+opts.Table) + } + if opts.NoData { + args = append(args, "--no-data=true") + } + + cmd := execabs.CommandContext(ctx, bin, args...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Env = os.Environ() + + if err := cmd.Run(); err != nil { + return nil, newMigrationError( + ErrCodeImportFailed, + fmt.Sprintf("wrangler export failed: %v", err), + "Ensure CLOUDFLARE_API_TOKEN and CLOUDFLARE_ACCOUNT_ID are set, or "+wranglerMissingRemediation, + ) + } + + size, _ := FileSize(opts.Output) + return &ExportResult{ + OutputPath: opts.Output, + Remote: opts.Remote, + Database: opts.D1Database, + ExportedAt: time.Now().UTC(), + SizeBytes: size, + }, nil +} + +// DefaultSQLitePath returns a sqlite path adjacent to the dump. +func DefaultSQLitePath(dumpPath string) string { + base := filepath.Base(dumpPath) + ext := filepath.Ext(base) + name := base[:len(base)-len(ext)] + return filepath.Join(filepath.Dir(dumpPath), name+".sqlite") +} diff --git a/internal/migrate/d1/import.go b/internal/migrate/d1/import.go new file mode 100644 index 000000000..dc3991616 --- /dev/null +++ b/internal/migrate/d1/import.go @@ -0,0 +1,387 @@ +package d1 + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + ps "github.com/planetscale/planetscale-go/planetscale" + + "github.com/planetscale/cli/internal/postgres" + "github.com/planetscale/cli/internal/roleutil" + execabs "golang.org/x/sys/execabs" +) + +// ImportOptions configures D1 import into PlanetScale Postgres. +type ImportOptions struct { + Org string + Database string + Branch string + InputPath string + Method string + MigrationID string + DBName string + DryRun bool + DestURI string // optional override for testing +} + +// ImportClient abstracts PlanetScale API access for import. +type ImportClient interface { + GetDatabase(ctx context.Context, org, database string) (*ps.Database, error) +} + +// DefaultImportClient wraps planetscale client. +type DefaultImportClient struct { + Client *ps.Client +} + +func (c *DefaultImportClient) GetDatabase(ctx context.Context, org, database string) (*ps.Database, error) { + return c.Client.Databases.Get(ctx, &ps.GetDatabaseRequest{ + Organization: org, + Database: database, + }) +} + +// Import loads a D1 SQLite dump into PlanetScale Postgres. +// Pass prepared when the caller already ran PrepareImport (e.g. human confirm flow). +func Import(ctx context.Context, psClient *ps.Client, client ImportClient, opts ImportOptions, prepared *ImportPrepareResult) (result *ImportResult, err error) { + if prepared == nil { + prepared, err = PrepareImport(opts) + if err != nil { + return nil, err + } + } + + opts.MigrationID = prepared.MigrationID + opts.Method = prepared.Method + + result = importResultFromPrepare(prepared, opts.DryRun) + + if !prepared.CanProceed { + return result, ErrLintBlocked(prepared.BlockedReason) + } + + if opts.DryRun { + return result, nil + } + + importStarted := false + defer func() { + if err != nil && importStarted { + _ = saveImportMigrationState(opts, PhaseFailed, "") + } + }() + + importStart := time.Now() + timings := &ImportTimings{} + + db, err := client.GetDatabase(ctx, opts.Org, opts.Database) + if err != nil { + return nil, fmt.Errorf("get database: %w", err) + } + if db.Kind != "postgresql" { + return nil, newMigrationError( + ErrCodeInvalidInput, + fmt.Sprintf("database %s is not PostgreSQL", opts.Database), + "Create a PostgreSQL database branch for D1 migration", + ) + } + + sqlitePath := DefaultSQLitePath(opts.InputPath) + if state, stateErr := LoadState(opts.Org, opts.Database, opts.Branch, opts.MigrationID); stateErr == nil { + if state.InputPath != "" && state.InputPath != opts.InputPath { + return nil, newMigrationError( + ErrCodeInvalidInput, + fmt.Sprintf("input path %q does not match migration state %q", opts.InputPath, state.InputPath), + "Use the same --input as the original import or omit --migration-id to start fresh", + ) + } + if state.SQLitePath != "" { + sqlitePath = state.SQLitePath + } + } + + importStarted = true + if err := saveImportMigrationState(opts, PhaseImporting, ""); err != nil { + return nil, err + } + + sqliteStart := time.Now() + if err := EnsureSQLiteFromDump(ctx, opts.InputPath, sqlitePath); err != nil { + return nil, err + } + timings.SQLiteStagingMs = time.Since(sqliteStart).Milliseconds() + + destURI, cleanup, err := ResolveDestURI(ctx, psClient, opts) + if err != nil { + return nil, err + } + if cleanup != nil { + defer cleanup() + } + + currentUser, err := usernameFromDestURI(destURI) + if err != nil { + return nil, err + } + if err := reassignStaleImportRoleObjects(ctx, psClient, opts, currentUser); err != nil { + return nil, err + } + + switch opts.Method { + case MethodPgloader: + if err := importWithPgloader(ctx, opts, destURI, sqlitePath, timings); err != nil { + return nil, err + } + case MethodPsql: + if err := importSmall(ctx, opts, destURI, sqlitePath); err != nil { + return nil, err + } + default: + return nil, newMigrationError(ErrCodeInvalidInput, "unknown import method: "+opts.Method, "Use pgloader (large dumps) or psql (small dumps; data loaded via pgloader)") + } + + tables, err := ParseDump(opts.InputPath) + if err == nil { + for _, table := range tables { + if !IsORMMetadataTable(table.Name) { + result.TablesLoaded++ + } + } + } + + timings.TotalMs = time.Since(importStart).Milliseconds() + result.Timings = timings + + state := &MigrationState{ + MigrationID: opts.MigrationID, + Org: opts.Org, + Database: opts.Database, + Branch: opts.Branch, + InputPath: opts.InputPath, + SQLitePath: sqlitePath, + Method: opts.Method, + Phase: PhaseImported, + } + if err := SaveState(state); err != nil { + return nil, err + } + + return result, nil +} + +func importWithPgloader(ctx context.Context, opts ImportOptions, destURI, sqlitePath string, timings *ImportTimings) error { + schemaStart := time.Now() + if err := applyPostgresSchema(ctx, opts, destURI); err != nil { + return err + } + timings.SchemaMs = time.Since(schemaStart).Milliseconds() + return loadTablesAndFinalize(ctx, opts, destURI, sqlitePath, timings) +} + +// importSmall loads dumps under 1GB: schema via psql, data via pgloader. +func importSmall(ctx context.Context, opts ImportOptions, destURI, sqlitePath string) error { + if err := applyPostgresSchema(ctx, opts, destURI); err != nil { + return err + } + return loadTablesAndFinalize(ctx, opts, destURI, sqlitePath, nil) +} + +func loadTablesAndFinalize(ctx context.Context, opts ImportOptions, destURI, sqlitePath string, timings *ImportTimings) error { + loadTables, err := PgloaderLoadTables(opts.InputPath) + if err != nil { + return err + } + + pgTimings, err := RunPgloader(ctx, PgloaderOptions{ + SQLitePath: sqlitePath, + DestURI: destURI, + InputPath: opts.InputPath, + DataOnly: true, + Tables: loadTables, + }) + if err != nil { + return err + } + if timings != nil { + timings.PgloaderMs = pgTimings.PgloaderMs + timings.TableLoads = pgTimings.TableLoads + } + + indexStart := time.Now() + if err := applyPostgresIndexes(ctx, opts, destURI); err != nil { + return err + } + if timings != nil { + timings.IndexBuildMs = time.Since(indexStart).Milliseconds() + } + + seqStart := time.Now() + if err := ResetImportedSequences(ctx, destURI, opts.InputPath); err != nil { + return err + } + if timings != nil { + timings.SequenceResetMs = time.Since(seqStart).Milliseconds() + } + return nil +} + +// ResolveDestURI creates a short-lived Postgres role and returns a connection string. +func ResolveDestURI(ctx context.Context, psClient *ps.Client, opts ImportOptions) (string, func() error, error) { + if opts.DestURI != "" { + return opts.DestURI, func() error { return nil }, nil + } + if psClient == nil { + return "", nil, fmt.Errorf("planetscale client required for import") + } + + roleName := fmt.Sprintf("d1-import-%d", time.Now().Unix()) + role, err := roleutil.New(ctx, psClient, roleutil.Options{ + Organization: opts.Org, + Database: opts.Database, + Branch: opts.Branch, + Name: roleName, + TTL: 2 * time.Hour, + InheritedRoles: []string{"postgres"}, + }) + if err != nil { + return "", nil, fmt.Errorf("create destination role: %w", err) + } + + dbName := opts.DBName + if dbName == "" { + dbName = "postgres" + } + + uri := postgres.BuildConnectionString(&postgres.Config{ + Host: role.Role.AccessHostURL, + Port: 5432, + User: role.Role.Username, + Password: role.Role.Password, + Database: dbName, + SSLMode: "require", + Options: map[string]string{}, + }) + + return uri, func() error { return role.Cleanup(ctx, "postgres") }, nil +} + +// ResetImportedSequences aligns identity sequences with MAX(column) after pgloader import. +// Per-table pgloader runs may leave sequences at their initial value; setval is idempotent. +func ResetImportedSequences(ctx context.Context, destURI, inputPath string) error { + tables, err := ParseDump(inputPath) + if err != nil { + return err + } + + db, err := OpenPostgres(destURI) + if err != nil { + return err + } + defer db.Close() + + for _, table := range tables { + if IsORMMetadataTable(table.Name) { + continue + } + for _, col := range table.Columns { + if !col.AutoIncrement { + continue + } + query := fmt.Sprintf( + `SELECT setval(pg_get_serial_sequence($1, $2), GREATEST(COALESCE((SELECT MAX(%s) FROM %s), 1), 1), true)`, + quoteIdent(col.Name), + quoteIdent(table.Name), + ) + if _, err := db.ExecContext(ctx, query, "public."+table.Name, col.Name); err != nil { + return fmt.Errorf("reset sequence %s.%s: %w", table.Name, col.Name, err) + } + } + } + return nil +} + +func applyPostgresSchema(ctx context.Context, opts ImportOptions, destURI string) error { + tables, err := ParseDump(opts.InputPath) + if err != nil { + return err + } + + importNames := importTableNames(tables) + existing, err := existingPublicTables(ctx, destURI, importNames) + if err != nil { + return err + } + if conflicts := conflictingImportTables(importNames, existing); len(conflicts) > 0 { + return errExistingImportTables(conflicts) + } + + workDir, err := os.MkdirTemp("", "pscale-d1-schema-*") + if err != nil { + return err + } + defer os.RemoveAll(workDir) + + var b strings.Builder + b.WriteString("-- Generated by pscale import d1\n") + b.WriteString("-- Source: ") + b.WriteString(opts.InputPath) + b.WriteString("\n\n") + b.WriteString(buildImportTablesSQL(tables)) + + combinedPath := filepath.Join(workDir, fmt.Sprintf("postgres-tables-%s.sql", opts.MigrationID)) + if err := os.WriteFile(combinedPath, []byte(b.String()), 0o600); err != nil { + return err + } + + return runPsqlFile(ctx, destURI, combinedPath) +} + +func applyPostgresIndexes(ctx context.Context, opts ImportOptions, destURI string) error { + parts, _, err := ConvertSchemaParts(opts.InputPath) + if err != nil { + return err + } + if strings.TrimSpace(parts.Indexes) == "" { + return nil + } + + workDir, err := os.MkdirTemp("", "pscale-d1-indexes-*") + if err != nil { + return err + } + defer os.RemoveAll(workDir) + + var b strings.Builder + b.WriteString("-- Generated by pscale import d1 (post-load indexes)\n") + b.WriteString(fmt.Sprintf("SET maintenance_work_mem TO '%s';\n", pgloaderIndexMaintenanceWorkMem)) + b.WriteString(parts.Indexes) + + indexPath := filepath.Join(workDir, fmt.Sprintf("postgres-indexes-%s.sql", opts.MigrationID)) + if err := os.WriteFile(indexPath, []byte(b.String()), 0o600); err != nil { + return err + } + + return runPsqlFile(ctx, destURI, indexPath) +} + +func runPsqlFile(ctx context.Context, destURI, path string) error { + psqlPath, err := postgres.FindPsqlPath() + if err != nil { + return err + } + + cmd := execabs.CommandContext(ctx, psqlPath, destURI, "-v", "ON_ERROR_STOP=1", "-f", path) + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("psql %s: %w: %s", filepath.Base(path), err, string(out)) + } + return nil +} + +// Status returns migration state for status polling. +func Status(org, database, branch, migrationID string) (*MigrationState, error) { + return LoadState(org, database, branch, migrationID) +} diff --git a/internal/migrate/d1/import_test.go b/internal/migrate/d1/import_test.go new file mode 100644 index 000000000..8c25b3965 --- /dev/null +++ b/internal/migrate/d1/import_test.go @@ -0,0 +1,14 @@ +package d1 + +import ( + "testing" +) + +func TestIsEphemeralImportRole(t *testing.T) { + if !isEphemeralImportRole("pscale_api_abc123") { + t.Fatal("expected pscale_api role to match") + } + if isEphemeralImportRole("postgres") { + t.Fatal("postgres should not match") + } +} diff --git a/internal/migrate/d1/lint.go b/internal/migrate/d1/lint.go new file mode 100644 index 000000000..cfd8a1f5e --- /dev/null +++ b/internal/migrate/d1/lint.go @@ -0,0 +1,139 @@ +package d1 + +import ( + "path/filepath" + "strings" +) + +// Lint analyzes a SQLite dump for migration issues. +func Lint(inputPath string) (*LintResult, error) { + tables, err := ParseDump(inputPath) + if err != nil { + return nil, err + } + + result := &LintResult{ + InputPath: inputPath, + TableCount: len(tables), + Issues: []Issue{}, + Tables: make([]string, 0, len(tables)), + } + + for _, table := range tables { + result.Tables = append(result.Tables, table.Name) + result.Issues = append(result.Issues, lintTable(table, tables)...) + } + + for _, issue := range result.Issues { + switch issue.Severity { + case SeverityError: + result.ErrorCount++ + case SeverityWarning: + result.WarningCount++ + } + } + + return result, nil +} + +func lintTable(table TableSchema, all []TableSchema) []Issue { + var issues []Issue + + issues = append(issues, lintORMMetadata(table)...) + + for _, col := range table.Columns { + if col.AutoIncrement { + issues = append(issues, Issue{ + Code: "AUTOINCREMENT", + Severity: SeverityWarning, + Table: table.Name, + Column: col.Name, + Remediation: "Will map to GENERATED BY DEFAULT AS IDENTITY", + }) + } + + if isBooleanColumn(col) { + issues = append(issues, Issue{ + Code: "BOOLEAN_AS_INTEGER", + Severity: SeverityWarning, + Table: table.Name, + Column: col.Name, + Remediation: "0/1 integer values will cast to boolean on import", + }) + } + + if isTimestampText(col) { + issues = append(issues, Issue{ + Code: "TEXT_TIMESTAMP", + Severity: SeverityWarning, + Table: table.Name, + Column: col.Name, + Remediation: "TEXT timestamps will cast to timestamptz using pgloader rules", + }) + } + + if isJSONText(col) { + issues = append(issues, Issue{ + Code: "JSON_IN_TEXT", + Severity: SeverityInfo, + Table: table.Name, + Column: col.Name, + Remediation: "TEXT JSON payloads can map to jsonb where detected", + }) + } + + if isUUIDColumn(col, table, all) { + issues = append(issues, Issue{ + Code: "TEXT_UUID", + Severity: SeverityInfo, + Table: table.Name, + Column: col.Name, + Remediation: "TEXT/UUID-style keys will map to UUID in Postgres", + }) + } + } + + return issues +} + +func isBooleanColumn(col ColumnSchema) bool { + name := strings.ToLower(col.Name) + if strings.Contains(name, "is_") || strings.HasSuffix(name, "_flag") || + name == "active" || name == "enabled" || name == "published" { + return col.Type == "INTEGER" || col.Type == "INT" + } + return false +} + +func isTimestampText(col ColumnSchema) bool { + name := strings.ToLower(col.Name) + if strings.Contains(name, "_at") || strings.Contains(name, "timestamp") || strings.Contains(name, "date") { + return col.Type == "TEXT" + } + return false +} + +func isJSONText(col ColumnSchema) bool { + name := strings.ToLower(col.Name) + return (strings.Contains(name, "json") || strings.Contains(name, "metadata") || strings.Contains(name, "payload")) && + col.Type == "TEXT" +} + +// LintNextSteps returns agent next steps based on lint results. +func LintNextSteps(result *LintResult) []NextStep { + steps := []NextStep{ + { + Tool: "import_d1_start", + Command: "pscale import d1 start --input " + filepath.Base(result.InputPath) + " --database --dry-run", + Reason: "Preview import plan and get a migration ID", + }, + } + if result.ErrorCount == 0 { + steps = append(steps, NextStep{ + Tool: "import_d1_start", + Command: "pscale import d1 start --input " + filepath.Base(result.InputPath) + " --database ", + Reason: "Run import after lint passes", + }) + } + return steps +} diff --git a/internal/migrate/d1/lint_test.go b/internal/migrate/d1/lint_test.go new file mode 100644 index 000000000..cd6e5895f --- /dev/null +++ b/internal/migrate/d1/lint_test.go @@ -0,0 +1,206 @@ +package d1 + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func testFixture(t *testing.T) string { + t.Helper() + return filepath.Join("testdata", "sample_d1_export.sql") +} + +func TestParseDump(t *testing.T) { + tables, err := ParseDump(testFixture(t)) + if err != nil { + t.Fatalf("ParseDump: %v", err) + } + if len(tables) != 6 { + t.Fatalf("expected 6 tables, got %d", len(tables)) + } + if tables[0].Name != "users" { + t.Fatalf("expected users table first, got %s", tables[0].Name) + } + + var teamMembers *TableSchema + for i := range tables { + if tables[i].Name == "entity_links" { + teamMembers = &tables[i] + break + } + } + if teamMembers == nil { + t.Fatal("expected entity_links table") + } + if len(teamMembers.Constraints) < 2 { + t.Fatalf("expected composite PK and FK constraints, got %v", teamMembers.Constraints) + } +} + +func TestParseIndexes(t *testing.T) { + indexes, err := ParseIndexes(testFixture(t)) + if err != nil { + t.Fatalf("ParseIndexes: %v", err) + } + if len(indexes) != 2 { + t.Fatalf("expected 2 indexes, got %d", len(indexes)) + } + if indexes[0].Name != "idx_users_email" { + t.Fatalf("unexpected first index: %s", indexes[0].Name) + } +} + +func TestLint(t *testing.T) { + result, err := Lint(testFixture(t)) + if err != nil { + t.Fatalf("Lint: %v", err) + } + if result.TableCount != 6 { + t.Fatalf("expected 6 tables, got %d", result.TableCount) + } + if result.ErrorCount != 0 { + t.Fatalf("expected no errors, got %d", result.ErrorCount) + } + if result.WarningCount == 0 { + t.Fatal("expected warnings for autoincrement/boolean columns") + } + + foundAutoincrement := false + foundDrizzle := false + for _, issue := range result.Issues { + if issue.Code == "AUTOINCREMENT" { + foundAutoincrement = true + } + if issue.Code == "DRIZZLE_MIGRATIONS" { + foundDrizzle = true + } + } + if !foundAutoincrement { + t.Fatal("expected AUTOINCREMENT issue") + } + if !foundDrizzle { + t.Fatal("expected DRIZZLE_MIGRATIONS issue") + } +} + +func TestPlan(t *testing.T) { + plan, err := Plan(PlanOptions{ + InputPath: testFixture(t), + Org: "acme", + Database: "mydb", + Branch: "main", + }) + if err != nil { + t.Fatalf("Plan: %v", err) + } + if plan.MigrationID == "" { + t.Fatal("expected migration id") + } + if len(plan.LoadOrder) != 6 { + t.Fatalf("expected load order length 6, got %d", len(plan.LoadOrder)) + } + if plan.RecommendedMethod == "" { + t.Fatal("expected recommended method") + } +} + +func TestConvertSchema(t *testing.T) { + out := t.TempDir() + "/schema.sql" + count, err := ConvertSchema(testFixture(t), out) + if err != nil { + t.Fatalf("ConvertSchema: %v", err) + } + if count != 4 { + t.Fatalf("expected 4 tables converted, got %d", count) + } + data, err := os.ReadFile(out) + if err != nil { + t.Fatal(err) + } + content := string(data) + checks := []string{ + "GENERATED BY DEFAULT AS IDENTITY", + "BOOLEAN", + "TIMESTAMPTZ", + `FOREIGN KEY ("user_id") REFERENCES "users" ("id")`, + `PRIMARY KEY ("entity_id", "post_id")`, + `"id" UUID PRIMARY KEY`, + `"entity_id" UUID NOT NULL`, + `CREATE INDEX IF NOT EXISTS "idx_users_email"`, + `CREATE UNIQUE INDEX IF NOT EXISTS "idx_entity_links_post"`, + `UNIQUE`, + `"id" INTEGER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY`, + } + for _, check := range checks { + if !strings.Contains(content, check) { + t.Fatalf("expected schema to contain %q\n%s", check, content) + } + } + if strings.Contains(content, "__drizzle_migrations") || strings.Contains(content, "_prisma_migrations") { + t.Fatal("ORM metadata tables should be skipped in schema output") + } +} + +func TestCountInsertRows(t *testing.T) { + counts, err := CountInsertRows(testFixture(t)) + if err != nil { + t.Fatalf("CountInsertRows: %v", err) + } + if counts["users"] != 2 { + t.Fatalf("expected 2 user rows, got %d", counts["users"]) + } + if counts["posts"] != 2 { + t.Fatalf("expected 2 post rows, got %d", counts["posts"]) + } +} + +func TestStateStore(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + store, err := NewStateStore() + if err != nil { + t.Fatal(err) + } + + state := &MigrationState{ + MigrationID: "test123", + Org: "acme", + Database: "mydb", + Branch: "main", + InputPath: testFixture(t), + Phase: PhasePlanned, + } + if err := store.Save(state); err != nil { + t.Fatal(err) + } + + loaded, err := store.Load("acme", "mydb", "main", "test123") + if err != nil { + t.Fatal(err) + } + if loaded.MigrationID != "test123" { + t.Fatalf("expected test123, got %s", loaded.MigrationID) + } + + if err := store.Delete("acme", "mydb", "main", "test123"); err != nil { + t.Fatal(err) + } +} + +func TestConvertTableConstraint(t *testing.T) { + got := convertTableConstraint("FOREIGN KEY (team_id, user_id) REFERENCES teams(id)") + want := `FOREIGN KEY ("team_id", "user_id") REFERENCES "teams" ("id")` + if got != want { + t.Fatalf("got %q want %q", got, want) + } +} + +func TestQuoteColumnList(t *testing.T) { + got := quoteColumnList("org_id, slug") + want := `"org_id", "slug"` + if got != want { + t.Fatalf("got %q want %q", got, want) + } +} diff --git a/internal/migrate/d1/orm_metadata.go b/internal/migrate/d1/orm_metadata.go new file mode 100644 index 000000000..a30b0019b --- /dev/null +++ b/internal/migrate/d1/orm_metadata.go @@ -0,0 +1,157 @@ +package d1 + +import ( + "strings" +) + +type ormMetadataRule struct { + code string + orm string + remediation string + match func(table string) bool +} + +var ormMetadataRules = []ormMetadataRule{ + { + code: "DRIZZLE_MIGRATIONS", + orm: "Drizzle", + remediation: "After import, baseline Drizzle on Postgres (e.g. drizzle-kit push or a fresh migrations folder); " + + "do not rely on SQLite __drizzle_migrations history", + match: func(table string) bool { + return strings.HasPrefix(strings.ToLower(table), "__drizzle") + }, + }, + { + code: "PRISMA_MIGRATIONS", + orm: "Prisma", + remediation: "After import, baseline Prisma on Postgres (e.g. prisma db pull then prisma migrate resolve / new initial migration); " + + "do not import _prisma_migrations from SQLite", + match: matchTableName("_prisma_migrations"), + }, + { + code: "KNEX_MIGRATIONS", + orm: "Knex", + remediation: "After import, re-baseline Knex migration history on Postgres; knex_migrations from SQLite is not valid on Postgres", + match: matchAnyTableName("knex_migrations", "knex_migrations_lock"), + }, + { + code: "SEQUELIZE_META", + orm: "Sequelize", + remediation: "After import, re-baseline Sequelize migration history on Postgres; SequelizeMeta from SQLite is not valid on Postgres", + match: matchTableName("sequelizemeta"), + }, + { + code: "RAILS_MIGRATIONS", + orm: "Rails ActiveRecord", + remediation: "After import, re-baseline Rails schema_migrations on Postgres; SQLite migration versions do not transfer cleanly", + match: matchAnyTableName("schema_migrations", "ar_internal_metadata"), + }, + { + code: "FLYWAY_MIGRATIONS", + orm: "Flyway", + remediation: "After import, baseline Flyway on Postgres; flyway_schema_history from SQLite must not be reused", + match: matchTableName("flyway_schema_history"), + }, + { + code: "LIQUIBASE_MIGRATIONS", + orm: "Liquibase", + remediation: "After import, baseline Liquibase on Postgres; databasechangelog tables from SQLite must not be reused", + match: matchAnyTableName("databasechangelog", "databasechangeloglock"), + }, + { + code: "DJANGO_MIGRATIONS", + orm: "Django", + remediation: "After import, run django migrate --fake-initial or otherwise baseline django_migrations on Postgres", + match: matchTableName("django_migrations"), + }, + { + code: "ALEMBIC_VERSION", + orm: "Alembic", + remediation: "After import, stamp Alembic to the correct Postgres revision; alembic_version from SQLite is not portable", + match: matchTableName("alembic_version"), + }, + { + code: "TYPEORM_METADATA", + orm: "TypeORM", + remediation: "After import, baseline TypeORM migrations on Postgres; typeorm_metadata from SQLite is not valid on Postgres", + match: matchTableName("typeorm_metadata"), + }, + { + code: "GOOSE_MIGRATIONS", + orm: "Goose", + remediation: "After import, re-baseline Goose version table on Postgres; goose_db_version from SQLite is not portable", + match: matchTableName("goose_db_version"), + }, +} + +func matchTableName(name string) func(string) bool { + lower := strings.ToLower(name) + return func(table string) bool { + return strings.ToLower(table) == lower + } +} + +func matchAnyTableName(names ...string) func(string) bool { + set := make(map[string]struct{}, len(names)) + for _, name := range names { + set[strings.ToLower(name)] = struct{}{} + } + return func(table string) bool { + _, ok := set[strings.ToLower(table)] + return ok + } +} + +// IsORMMetadataTable reports whether a table holds ORM/framework migration bookkeeping +// that should not be imported into Postgres. See PRE_DEPLOY.md for pre-ship UX work +// (post-import ORM baseline guidance for users). +func IsORMMetadataTable(name string) bool { + return ORMMetadataRule(name) != nil +} + +// ORMMetadataRule returns the matching ORM metadata rule, if any. +func ORMMetadataRule(name string) *ormMetadataRule { + for i := range ormMetadataRules { + if ormMetadataRules[i].match(name) { + return &ormMetadataRules[i] + } + } + return nil +} + +func lintORMMetadata(table TableSchema) []Issue { + rule := ORMMetadataRule(table.Name) + if rule == nil { + return nil + } + issues := []Issue{{ + Code: rule.code, + Severity: SeverityInfo, + Table: table.Name, + Message: rule.orm + " migration metadata table detected", + Remediation: rule.remediation, + }} + if strings.EqualFold(table.Name, "schema_migrations") && !looksLikeRailsSchemaMigrations(table) { + issues = append(issues, Issue{ + Code: "SCHEMA_MIGRATIONS_NAME_COLLISION", + Severity: SeverityWarning, + Table: table.Name, + Message: "table name matches Rails schema_migrations but column layout does not", + Remediation: "If this is application data, rename the table before import; ORM metadata skip will exclude it from Postgres", + }) + } + return issues +} + +func looksLikeRailsSchemaMigrations(table TableSchema) bool { + if len(table.Columns) != 1 { + return false + } + col := table.Columns[0] + name := strings.ToLower(col.Name) + if name != "version" { + return false + } + t := strings.ToUpper(col.Type) + return strings.Contains(t, "CHAR") || strings.Contains(t, "TEXT") || t == "" +} diff --git a/internal/migrate/d1/orm_metadata_test.go b/internal/migrate/d1/orm_metadata_test.go new file mode 100644 index 000000000..e9496fac7 --- /dev/null +++ b/internal/migrate/d1/orm_metadata_test.go @@ -0,0 +1,62 @@ +package d1 + +import "testing" + +func TestIsORMMetadataTable(t *testing.T) { + tests := []struct { + table string + want bool + code string + }{ + {"__drizzle_migrations", true, "DRIZZLE_MIGRATIONS"}, + {"__drizzle_migrations_journal", true, "DRIZZLE_MIGRATIONS"}, + {"_prisma_migrations", true, "PRISMA_MIGRATIONS"}, + {"knex_migrations", true, "KNEX_MIGRATIONS"}, + {"knex_migrations_lock", true, "KNEX_MIGRATIONS"}, + {"SequelizeMeta", true, "SEQUELIZE_META"}, + {"schema_migrations", true, "RAILS_MIGRATIONS"}, + {"ar_internal_metadata", true, "RAILS_MIGRATIONS"}, + {"flyway_schema_history", true, "FLYWAY_MIGRATIONS"}, + {"databasechangelog", true, "LIQUIBASE_MIGRATIONS"}, + {"django_migrations", true, "DJANGO_MIGRATIONS"}, + {"alembic_version", true, "ALEMBIC_VERSION"}, + {"typeorm_metadata", true, "TYPEORM_METADATA"}, + {"goose_db_version", true, "GOOSE_MIGRATIONS"}, + {"users", false, ""}, + {"migrations", false, ""}, + {"organizations", false, ""}, + } + + for _, tc := range tests { + got := IsORMMetadataTable(tc.table) + if got != tc.want { + t.Fatalf("IsORMMetadataTable(%q) = %v, want %v", tc.table, got, tc.want) + } + if tc.want { + rule := ORMMetadataRule(tc.table) + if rule == nil || rule.code != tc.code { + t.Fatalf("ORMMetadataRule(%q) = %v, want code %q", tc.table, rule, tc.code) + } + } + } +} + +func TestLintORMMetadataTables(t *testing.T) { + result, err := Lint(testFixture(t)) + if err != nil { + t.Fatalf("Lint: %v", err) + } + + found := map[string]bool{} + for _, issue := range result.Issues { + if issue.Code == "DRIZZLE_MIGRATIONS" || issue.Code == "PRISMA_MIGRATIONS" { + found[issue.Code] = true + } + } + if !found["DRIZZLE_MIGRATIONS"] { + t.Fatal("expected DRIZZLE_MIGRATIONS lint issue") + } + if !found["PRISMA_MIGRATIONS"] { + t.Fatal("expected PRISMA_MIGRATIONS lint issue") + } +} diff --git a/internal/migrate/d1/output.go b/internal/migrate/d1/output.go new file mode 100644 index 000000000..54ec13927 --- /dev/null +++ b/internal/migrate/d1/output.go @@ -0,0 +1,144 @@ +package d1 + +import ( + "fmt" + "time" + + "github.com/planetscale/cli/internal/printer" +) + +// PrintHumanResponse writes a human-readable success response via the shared printer. +func PrintHumanResponse(p *printer.Printer, resp Response) { + p.Printf("Status: %s", resp.Status) + if resp.Phase != "" { + p.Printf(" (%s)", resp.Phase) + } + p.Println() + + if resp.MigrationID != "" { + p.Printf("Migration ID: %s\n", resp.MigrationID) + } + + printHumanData(p, resp.Phase, resp.Data) + + if len(resp.Issues) > 0 { + p.Printf("\nIssues (%d):\n", len(resp.Issues)) + for _, issue := range resp.Issues { + loc := issue.Table + if issue.Column != "" { + loc += "." + issue.Column + } + p.Printf(" [%s] %s %s: %s\n", issue.Severity, issue.Code, loc, issue.Remediation) + } + } + + if len(resp.NextSteps) > 0 { + p.Println("\nNext steps:") + for _, step := range resp.NextSteps { + if step.Command != "" { + p.Printf(" - %s (%s)\n", step.Command, step.Reason) + } else { + p.Printf(" - %s (%s)\n", step.Tool, step.Reason) + } + } + } +} + +func printHumanData(p *printer.Printer, phase string, data any) { + if data == nil { + return + } + + switch phase { + case "doctor": + if r, ok := data.(DoctorResult); ok { + p.Println("\nChecks:") + for _, c := range r.Checks { + line := fmt.Sprintf(" %s: %s", c.Name, c.Status) + if c.Version != "" { + line += fmt.Sprintf(" (%s)", c.Version) + } + p.Println(line) + } + p.Printf("Ready: %v\n", r.Ready) + } + case "export": + if r, ok := data.(ExportResult); ok { + p.Printf("\nExported to %s (%d bytes)\n", r.OutputPath, r.SizeBytes) + } + case "lint": + if r, ok := data.(LintResult); ok { + p.Printf("\nTables: %d | Errors: %d | Warnings: %d\n", r.TableCount, r.ErrorCount, r.WarningCount) + } + case "start": + if r, ok := data.(ImportResult); ok { + p.Printf("\nMethod: %s", r.Method) + if r.DryRun { + p.Print(" (dry run)") + } + p.Println() + if r.Plan != nil { + sizeMB := float64(r.Plan.EstimatedSizeBytes) / (1024 * 1024) + p.Printf("Plan: %d tables, %.1f MB estimated\n", len(r.Plan.Tables), sizeMB) + } + if r.TablesLoaded > 0 { + p.Printf("Tables loaded: %d\n", r.TablesLoaded) + } + if r.Timings != nil && r.Timings.TotalMs > 0 { + p.Printf("Total time: %.1fs\n", float64(r.Timings.TotalMs)/1000) + } + } + case "verify": + if r, ok := data.(VerifyResult); ok { + matched := "no" + if r.Matched { + matched = "yes" + } + p.Printf("\nMatched: %s\n", matched) + } + case "status": + if r, ok := data.(MigrationState); ok { + p.Printf("\nPhase: %s | Updated: %s\n", r.Phase, r.UpdatedAt.Format(time.RFC3339)) + } + case "convert-schema": + if m, ok := data.(map[string]any); ok { + p.Println() + p.Printf(" Input: %v\n", m["input"]) + p.Printf(" Output: %v\n", m["output"]) + p.Printf(" Tables: %v\n", m["table_count"]) + } + case "complete": + if m, ok := data.(map[string]string); ok { + p.Println() + p.Printf(" Migration ID: %s\n", m["migration_id"]) + p.Printf(" Status: %s\n", m["status"]) + } + } +} + +// OKResponse builds a success response. +func OKResponse(phase string, data any, next []NextStep) Response { + return Response{ + Status: "ok", + Phase: phase, + Data: data, + NextSteps: next, + } +} + +// ErrorResponse builds an error response from an error. +func ErrorResponse(phase string, err error) Response { + resp := Response{ + Status: "error", + Phase: phase, + } + if me, ok := migrationErr(err); ok { + resp.Error = &me.Info + } else { + resp.Error = &ErrorInfo{ + Code: ErrCodeImportFailed, + Message: err.Error(), + } + } + return resp +} diff --git a/internal/migrate/d1/parse.go b/internal/migrate/d1/parse.go new file mode 100644 index 000000000..3dce7d6ac --- /dev/null +++ b/internal/migrate/d1/parse.go @@ -0,0 +1,354 @@ +package d1 + +import ( + "bufio" + "fmt" + "os" + "regexp" + "strings" +) + +var ( + createTableRe = regexp.MustCompile(`(?is)^CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:"([^"]+)"|'([^']+)'|` + "`" + `([^` + "`" + `]+)` + "`" + `|([a-zA-Z_][\w]*))\s*\(`) + virtualTableRe = regexp.MustCompile(`(?is)^CREATE\s+VIRTUAL\s+TABLE`) + autoincrementRe = regexp.MustCompile(`(?i)AUTOINCREMENT`) + insertRe = regexp.MustCompile(`(?is)^INSERT\s+INTO\s+(?:` + "`" + `([^` + "`" + `]+)` + "`" + `|"([^"]+)"|'([^']+)'|([a-zA-Z_][\w]*))`) + valueTupleSepRe = regexp.MustCompile(`\)\s*,\s*\(`) +) + +// TableSchema holds parsed SQLite table metadata from a dump file. +type TableSchema struct { + Name string + Columns []ColumnSchema + Constraints []string + RawDDL string +} + +// ColumnSchema holds parsed column metadata. +type ColumnSchema struct { + Name string + Type string + PrimaryKey bool + AutoIncrement bool + NotNull bool + Unique bool + DefaultValue string + ForeignKey string +} + +// ParseDump reads a SQLite SQL dump and extracts table definitions. +func ParseDump(path string) ([]TableSchema, error) { + clean, err := ValidateInputPath(path) + if err != nil { + return nil, err + } + f, err := os.Open(clean) + if err != nil { + return nil, err + } + defer f.Close() + + var tables []TableSchema + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 1024*1024), 10*1024*1024) + + var current *TableSchema + var ddlLines []string + parenDepth := 0 + + flush := func() { + if current == nil { + return + } + current.RawDDL = strings.Join(ddlLines, "\n") + current.Columns, current.Constraints = parseTableBody(current.RawDDL) + tables = append(tables, *current) + current = nil + ddlLines = nil + parenDepth = 0 + } + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "--") { + continue + } + + if virtualTableRe.MatchString(line) { + return nil, newMigrationError( + ErrCodeVirtualTable, + "dump contains CREATE VIRTUAL TABLE statements", + "Remove or recreate FTS5/virtual tables manually in Postgres after migration", + ) + } + + if current == nil { + m := createTableRe.FindStringSubmatch(line) + if m == nil { + continue + } + name := firstNonEmpty(m[1], m[2], m[3], m[4]) + current = &TableSchema{Name: name} + ddlLines = append(ddlLines, line) + parenDepth += strings.Count(line, "(") - strings.Count(line, ")") + if parenDepth <= 0 && strings.HasSuffix(line, ";") { + flush() + } + continue + } + + ddlLines = append(ddlLines, line) + parenDepth += strings.Count(line, "(") - strings.Count(line, ")") + if parenDepth <= 0 && strings.HasSuffix(line, ";") { + flush() + } + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("read dump: %w", err) + } + flush() + + if len(tables) == 0 { + return nil, newMigrationError( + ErrCodeInvalidInput, + "no CREATE TABLE statements found in dump", + "Ensure the input is a wrangler d1 export SQL file with schema definitions", + ) + } + + return tables, nil +} + +func parseTableBody(ddl string) ([]ColumnSchema, []string) { + start := strings.Index(ddl, "(") + end := strings.LastIndex(ddl, ")") + if start < 0 || end <= start { + return nil, nil + } + body := ddl[start+1 : end] + parts := splitColumnDefs(body) + cols := make([]ColumnSchema, 0, len(parts)) + var constraints []string + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + if isTableConstraint(part) { + constraints = append(constraints, part) + continue + } + col := parseColumn(part) + if col.Name != "" { + cols = append(cols, col) + } + } + return cols, constraints +} + +func parseColumn(def string) ColumnSchema { + def = strings.TrimSpace(def) + if def == "" { + return ColumnSchema{} + } + + // Strip trailing comma + def = strings.TrimSuffix(def, ",") + + tokens := strings.Fields(def) + if len(tokens) == 0 { + return ColumnSchema{} + } + + name := strings.Trim(tokens[0], "`\"'") + colType := "" + if len(tokens) > 1 { + colType = strings.ToUpper(tokens[1]) + } + + col := ColumnSchema{ + Name: name, + Type: colType, + } + + upper := strings.ToUpper(def) + if strings.Contains(upper, "NOT NULL") { + col.NotNull = true + } + if strings.Contains(upper, "PRIMARY KEY") { + col.PrimaryKey = true + } + if strings.Contains(upper, "UNIQUE") && !strings.HasPrefix(upper, "UNIQUE(") && !strings.HasPrefix(upper, "UNIQUE (") { + col.Unique = true + } + if autoincrementRe.MatchString(def) { + col.AutoIncrement = true + } + if idx := strings.Index(strings.ToUpper(def), "DEFAULT"); idx >= 0 { + col.DefaultValue = strings.TrimSpace(def[idx+7:]) + col.DefaultValue = strings.TrimSuffix(col.DefaultValue, ",") + if refIdx := indexOfIgnoreCase(col.DefaultValue, "REFERENCES"); refIdx >= 0 { + col.DefaultValue = strings.TrimSpace(col.DefaultValue[:refIdx]) + col.DefaultValue = strings.TrimSuffix(col.DefaultValue, ",") + } + } + if strings.Contains(upper, "REFERENCES") { + col.ForeignKey = referencesClause(def) + } + + return col +} + +// ParseIndexes extracts CREATE INDEX statements from a SQLite dump. +func ParseIndexes(path string) ([]IndexSchema, error) { + clean, err := ValidateInputPath(path) + if err != nil { + return nil, err + } + f, err := os.Open(clean) + if err != nil { + return nil, err + } + defer f.Close() + + var indexes []IndexSchema + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 1024*1024), 10*1024*1024) + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "--") { + continue + } + if !strings.HasPrefix(strings.ToUpper(line), "CREATE") { + continue + } + upper := strings.ToUpper(line) + if !strings.Contains(upper, " INDEX ") { + continue + } + m := createIndexRe.FindStringSubmatch(line) + if m == nil { + continue + } + indexes = append(indexes, IndexSchema{ + Name: firstNonEmpty(m[2], m[3], m[4], m[5]), + Table: firstNonEmpty(m[6], m[7], m[8], m[9]), + Unique: strings.TrimSpace(m[1]) != "", + Columns: m[10], + RawDDL: line, + }) + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("read dump indexes: %w", err) + } + return indexes, nil +} + +func splitColumnDefs(body string) []string { + var parts []string + var current strings.Builder + depth := 0 + for _, r := range body { + switch r { + case '(': + depth++ + current.WriteRune(r) + case ')': + depth-- + current.WriteRune(r) + case ',': + if depth == 0 { + parts = append(parts, current.String()) + current.Reset() + continue + } + current.WriteRune(r) + default: + current.WriteRune(r) + } + } + if current.Len() > 0 { + parts = append(parts, current.String()) + } + return parts +} + +// CountInsertRows estimates row counts per table from INSERT statements. +func CountInsertRows(path string) (map[string]int, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + counts := make(map[string]int) + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 1024*1024), 10*1024*1024) + + var pendingTable string + var pendingSQL strings.Builder + + flush := func() { + if pendingTable == "" { + return + } + sql := pendingSQL.String() + rows := len(valueTupleSepRe.FindAllString(sql, -1)) + 1 + counts[pendingTable] += rows + pendingTable = "" + pendingSQL.Reset() + } + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "--") { + continue + } + + m := insertRe.FindStringSubmatch(line) + if m != nil { + flush() + pendingTable = firstNonEmpty(m[1], m[2], m[3], m[4]) + pendingSQL.WriteString(line) + if strings.HasSuffix(line, ";") { + flush() + } + continue + } + + if pendingTable != "" { + pendingSQL.WriteString(" ") + pendingSQL.WriteString(line) + if strings.HasSuffix(line, ";") { + flush() + } + } + } + flush() + + if err := scanner.Err(); err != nil { + return nil, err + } + return counts, nil +} + +// FileSize returns the size of a file in bytes. +func FileSize(path string) (int64, error) { + info, err := os.Stat(path) + if err != nil { + return 0, err + } + return info.Size(), nil +} + +func firstNonEmpty(vals ...string) string { + for _, v := range vals { + if v != "" { + return v + } + } + return "" +} diff --git a/internal/migrate/d1/path.go b/internal/migrate/d1/path.go new file mode 100644 index 000000000..f8ce06952 --- /dev/null +++ b/internal/migrate/d1/path.go @@ -0,0 +1,31 @@ +package d1 + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +// ValidateInputPath ensures a user-supplied path is safe to read. +func ValidateInputPath(path string) (string, error) { + if path == "" { + return "", newMigrationError(ErrCodeMissingInput, "input path is required", "Pass --input with a D1 SQL export file") + } + if strings.ContainsAny(path, "\x00\n\r;") { + return "", newMigrationError(ErrCodeInvalidInput, "invalid characters in input path", "Use a simple file path without newlines or semicolons") + } + + clean := filepath.Clean(path) + info, err := os.Stat(clean) + if err != nil { + if os.IsNotExist(err) { + return "", errMissingInput(clean) + } + return "", fmt.Errorf("stat input: %w", err) + } + if info.IsDir() { + return "", newMigrationError(ErrCodeInvalidInput, "input path is a directory", "Pass a .sql export file path") + } + return clean, nil +} diff --git a/internal/migrate/d1/pgloader.go b/internal/migrate/d1/pgloader.go new file mode 100644 index 000000000..7f1d2f74e --- /dev/null +++ b/internal/migrate/d1/pgloader.go @@ -0,0 +1,296 @@ +package d1 + +import ( + "context" + _ "embed" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "github.com/planetscale/cli/internal/postgres" + execabs "golang.org/x/sys/execabs" +) + +//go:embed pgloader_transforms.lisp +var pgloaderTransformsLisp string + +const ( + pgloaderBatchSize = "20 MB" + pgloaderDynamicSpace = "4096" // MB per pgloader process (SBCL heap cap) + + pgloaderLargeTableRowThreshold = 100_000 + + // Fast profile: small/medium tables after indexes are deferred. + pgloaderFastPrefetchRows = 25000 + pgloaderFastBatchRows = 25000 + pgloaderFastWorkers = 8 + pgloaderFastConcurrency = 2 + + // Conservative profile: wide rows / large tables (e.g. attachments). + pgloaderSlowPrefetchRows = 5000 + pgloaderSlowBatchRows = 10000 + pgloaderSlowWorkers = 2 + pgloaderSlowConcurrency = 1 + + pgloaderLoadWorkMem = "256MB" + pgloaderLoadMaintenanceWorkMem = "512MB" + pgloaderIndexMaintenanceWorkMem = "2GB" +) + +// PgloaderOptions configures pgloader execution. +type PgloaderOptions struct { + SQLitePath string + DestURI string + InputPath string // dump path for column-level CAST rules + WorkDir string + DryRun bool + DataOnly bool + // Tables loads one table per pgloader invocation when set (recommended for + // large databases — avoids SBCL heap exhaustion from whole-catalog planning). + Tables []string +} + +type pgloaderMemoryProfile struct { + prefetchRows int + batchRows int + workers int + concurrency int +} + +func pgloaderProfileForTable(rowCount int) pgloaderMemoryProfile { + if rowCount >= pgloaderLargeTableRowThreshold { + return pgloaderMemoryProfile{ + prefetchRows: pgloaderSlowPrefetchRows, + batchRows: pgloaderSlowBatchRows, + workers: pgloaderSlowWorkers, + concurrency: pgloaderSlowConcurrency, + } + } + return pgloaderMemoryProfile{ + prefetchRows: pgloaderFastPrefetchRows, + batchRows: pgloaderFastBatchRows, + workers: pgloaderFastWorkers, + concurrency: pgloaderFastConcurrency, + } +} + +// PgloaderLoadTables returns non-ORM tables in FK-safe load order. +func PgloaderLoadTables(inputPath string) ([]string, error) { + tables, err := ParseDump(inputPath) + if err != nil { + return nil, err + } + ordered := topologicalLoadOrder(tables) + out := make([]string, 0, len(ordered)) + for _, name := range ordered { + if !IsORMMetadataTable(name) { + out = append(out, name) + } + } + return out, nil +} + +// RunPgloader loads SQLite into PostgreSQL using pgloader. +func RunPgloader(ctx context.Context, opts PgloaderOptions) (ImportTimings, error) { + var timings ImportTimings + pgloader, err := FindPgloader() + if err != nil { + return timings, err + } + + if opts.WorkDir == "" { + opts.WorkDir, err = os.MkdirTemp("", "pscale-d1-pgloader-*") + if err != nil { + return timings, err + } + defer os.RemoveAll(opts.WorkDir) + } + + tables := opts.Tables + tableSchemas, err := ParseDump(opts.InputPath) + if err != nil { + return timings, err + } + tableByName := make(map[string]TableSchema, len(tableSchemas)) + for _, t := range tableSchemas { + tableByName[t.Name] = t + } + + rowCounts, _ := CountInsertRows(opts.InputPath) + + if len(tables) == 0 { + pgStart := time.Now() + if err := runPgloaderScript(ctx, pgloader, opts, pgloaderScriptConfig{ + dataOnly: opts.DataOnly, + resetSequences: true, + profile: pgloaderProfileForTable(0), + }, TableSchema{}, tableSchemas); err != nil { + return timings, err + } + timings.PgloaderMs = time.Since(pgStart).Milliseconds() + return timings, nil + } + + pgStart := time.Now() + for i, name := range tables { + table, ok := tableByName[name] + if !ok { + return timings, fmt.Errorf("pgloader table %s: not found in dump schema", name) + } + profile := pgloaderProfileForTable(rowCounts[name]) + fmt.Fprintf(os.Stderr, "pgloader: loading table %d/%d %s (workers=%d prefetch=%d)\n", + i+1, len(tables), name, profile.workers, profile.prefetchRows) + tableStart := time.Now() + if err := runPgloaderScript(ctx, pgloader, opts, pgloaderScriptConfig{ + dataOnly: opts.DataOnly, + tableName: name, + resetSequences: true, + profile: profile, + }, table, tableSchemas); err != nil { + return timings, fmt.Errorf("pgloader table %s: %w", name, err) + } + timings.TableLoads = append(timings.TableLoads, TableLoadTiming{ + Table: name, + Ms: time.Since(tableStart).Milliseconds(), + }) + } + timings.PgloaderMs = time.Since(pgStart).Milliseconds() + return timings, nil +} + +type pgloaderScriptConfig struct { + dataOnly bool + tableName string + resetSequences bool + profile pgloaderMemoryProfile +} + +func runPgloaderScript(ctx context.Context, pgloader string, opts PgloaderOptions, cfg pgloaderScriptConfig, table TableSchema, allTables []TableSchema) error { + loadFile := filepath.Join(opts.WorkDir, "load.load") + if cfg.tableName != "" { + loadFile = filepath.Join(opts.WorkDir, "load-"+cfg.tableName+".load") + } + castTables := allTables + if table.Name != "" { + castTables = []TableSchema{table} + } + content := buildPgloaderScript(opts.SQLitePath, opts.DestURI, cfg, castTables, allTables) + if err := os.WriteFile(loadFile, []byte(content), 0o600); err != nil { + return err + } + + if opts.DryRun { + return nil + } + + transformsFile := filepath.Join(opts.WorkDir, "transforms.lisp") + if err := os.WriteFile(transformsFile, []byte(pgloaderTransformsLisp), 0o600); err != nil { + return err + } + + cmd := execabs.CommandContext(ctx, pgloader, "--load-lisp-file", transformsFile, loadFile) + cmd.Env = append(os.Environ(), + "SBCL_OPTIONS=--dynamic-space-size "+pgloaderDynamicSpace, + ) + out, err := cmd.CombinedOutput() + output := string(out) + if err != nil { + return fmt.Errorf("pgloader failed: %w: %s", err, output) + } + if strings.Contains(output, "FATAL") || strings.Contains(output, "KABOOM") || + strings.Contains(output, "ERROR Error while formatting") || + strings.Contains(output, "ERROR The value") || + strings.Contains(output, "Heap exhausted") || + pgloaderHadErrors(output) { + return fmt.Errorf("pgloader failed: %s", output) + } + fmt.Fprint(os.Stderr, output) + return nil +} + +func buildPgloaderScript(sqlitePath, destURI string, cfg pgloaderScriptConfig, castTables, allTables []TableSchema) string { + absSQLite, _ := filepath.Abs(sqlitePath) + src := "sqlite:///" + strings.ReplaceAll(absSQLite, " ", "%20") + target := destURI + if parsed, err := postgres.ParseConnectionURI(destURI); err == nil { + target = postgres.BuildConnectionURI(parsed) + } + + profile := cfg.profile + if profile.workers == 0 { + profile = pgloaderProfileForTable(0) + } + + var b strings.Builder + b.WriteString("LOAD DATABASE\n") + b.WriteString(" FROM " + src + "\n") + b.WriteString(" INTO " + target + "\n") + b.WriteString("\n") + + if cfg.dataOnly { + b.WriteString(" WITH data only, create no tables, create no indexes, truncate, disable triggers,\n") + if cfg.resetSequences { + b.WriteString(" reset sequences,\n") + } else { + b.WriteString(" reset no sequences,\n") + } + b.WriteString(fmt.Sprintf(" workers = %d, concurrency = %d,\n", profile.workers, profile.concurrency)) + b.WriteString(fmt.Sprintf(" batch rows = %d,\n", profile.batchRows)) + b.WriteString(" batch size = " + pgloaderBatchSize + ",\n") + b.WriteString(fmt.Sprintf(" prefetch rows = %d\n", profile.prefetchRows)) + } else { + b.WriteString(" WITH include drop, create tables, create indexes, reset sequences,\n") + b.WriteString(fmt.Sprintf(" workers = %d, concurrency = %d,\n", profile.workers, profile.concurrency)) + b.WriteString(fmt.Sprintf(" batch rows = %d,\n", profile.batchRows)) + b.WriteString(" batch size = " + pgloaderBatchSize + ",\n") + b.WriteString(fmt.Sprintf(" prefetch rows = %d\n", profile.prefetchRows)) + } + + if cfg.tableName != "" { + b.WriteString("\n") + b.WriteString(" INCLUDING ONLY TABLE NAMES LIKE " + pgloaderQuotePattern(cfg.tableName) + "\n") + } + + appendPgloaderCasts(&b, castTables, allTables) + + b.WriteString("\n") + b.WriteString(fmt.Sprintf(" SET work_mem to '%s', maintenance_work_mem to '%s', synchronous_commit to 'off';\n", + pgloaderLoadWorkMem, pgloaderLoadMaintenanceWorkMem)) + return b.String() +} + +func appendPgloaderCasts(b *strings.Builder, castTables, allTables []TableSchema) { + var rules []string + for _, table := range castTables { + for _, col := range table.Columns { + pgType := sqliteTypeToPostgres(col, table, allTables) + ref := fmt.Sprintf("column %s.%s", table.Name, col.Name) + switch pgType { + case "BOOLEAN": + rules = append(rules, ref+" to boolean using sqlite-int-to-boolean") + case "TIMESTAMPTZ": + rules = append(rules, ref+" to timestamptz using sqlite-timestamp-to-timestamp") + case "JSONB": + rules = append(rules, ref+" to jsonb using sqlite-text-to-jsonb") + } + } + } + if len(rules) == 0 { + return + } + b.WriteString("\n CAST ") + for i, rule := range rules { + if i > 0 { + b.WriteString(",\n ") + } else { + b.WriteString("\n ") + } + b.WriteString(rule) + } +} + +func pgloaderQuotePattern(name string) string { + return "'" + strings.ReplaceAll(name, "'", "''") + "'" +} diff --git a/internal/migrate/d1/pgloader_errors.go b/internal/migrate/d1/pgloader_errors.go new file mode 100644 index 000000000..6fc70aca0 --- /dev/null +++ b/internal/migrate/d1/pgloader_errors.go @@ -0,0 +1,26 @@ +package d1 + +import ( + "regexp" + "strings" +) + +var pgloaderSummaryErrorRe = regexp.MustCompile(`(?m)^\|\s+(\d+)\s+\|`) + +// pgloaderHadErrors inspects pgloader output for failures that do not set exit code. +func pgloaderHadErrors(output string) bool { + if strings.Contains(output, "Database error") || + strings.Contains(output, "INSUFFICIENT-PRIVILEGE") || + strings.Contains(output, "must be owner of table") { + return true + } + for _, match := range pgloaderSummaryErrorRe.FindAllStringSubmatch(output, -1) { + if len(match) < 2 { + continue + } + if match[1] != "0" { + return true + } + } + return false +} diff --git a/internal/migrate/d1/pgloader_errors_test.go b/internal/migrate/d1/pgloader_errors_test.go new file mode 100644 index 000000000..106a268a8 --- /dev/null +++ b/internal/migrate/d1/pgloader_errors_test.go @@ -0,0 +1,46 @@ +package d1 + +import "testing" + +func TestPgloaderHadErrors(t *testing.T) { + tests := []struct { + name string + output string + want bool + }{ + { + name: "clean summary", + output: ` +| errors | rows | bytes | total time +| 0 | 100 | 1 kB | 1.000 s +`, + want: false, + }, + { + name: "summary with errors", + output: ` +| errors | rows | bytes | total time +| 3 | 97 | 1 kB | 1.000 s +`, + want: true, + }, + { + name: "database error", + output: "Database error 42501: must be owner of table users", + want: true, + }, + { + name: "insufficient privilege", + output: "INSUFFICIENT-PRIVILEGE disable triggers", + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := pgloaderHadErrors(tt.output); got != tt.want { + t.Fatalf("pgloaderHadErrors() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/migrate/d1/pgloader_test.go b/internal/migrate/d1/pgloader_test.go new file mode 100644 index 000000000..809690788 --- /dev/null +++ b/internal/migrate/d1/pgloader_test.go @@ -0,0 +1,137 @@ +package d1 + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestBuildPgloaderScriptDataOnlyPerTable(t *testing.T) { + table := TableSchema{ + Name: "organizations", + Columns: []ColumnSchema{ + {Name: "id", Type: "INTEGER", PrimaryKey: true, AutoIncrement: true}, + {Name: "slug", Type: "TEXT", NotNull: true}, + {Name: "is_active", Type: "INTEGER", NotNull: true}, + {Name: "created_at", Type: "TEXT", NotNull: true}, + }, + } + script := buildPgloaderScript("/tmp/test.sqlite", "postgresql://u:p@host/db", pgloaderScriptConfig{ + dataOnly: true, + tableName: "organizations", + resetSequences: false, + profile: pgloaderProfileForTable(0), + }, []TableSchema{table}, []TableSchema{table}) + + checks := []string{ + "WITH data only, create no tables, create no indexes, truncate, disable triggers,", + "reset no sequences,", + "workers = 8, concurrency = 2,", + "batch rows = 25000,", + "batch size = 20 MB,", + "prefetch rows = 25000", + "INCLUDING ONLY TABLE NAMES LIKE 'organizations'", + "column organizations.is_active to boolean using sqlite-int-to-boolean", + "column organizations.created_at to timestamptz using sqlite-timestamp-to-timestamp", + "SET work_mem to '256MB'", + "synchronous_commit to 'off'", + } + for _, want := range checks { + if !strings.Contains(script, want) { + t.Fatalf("script missing %q\n%s", want, script) + } + } + for _, bad := range []string{ + "column organizations.id to boolean", + "column organizations.slug to timestamptz", + "type integer to boolean", + "type text to timestamptz", + } { + if strings.Contains(script, bad) { + t.Fatalf("script should not contain %q\n%s", bad, script) + } + } +} + +func TestBuildPgloaderScriptLargeTableProfile(t *testing.T) { + script := buildPgloaderScript("/tmp/test.sqlite", "postgresql://u:p@host/db", pgloaderScriptConfig{ + dataOnly: true, + tableName: "attachments", + resetSequences: true, + profile: pgloaderProfileForTable(pgloaderLargeTableRowThreshold), + }, nil, nil) + + for _, want := range []string{ + "workers = 2, concurrency = 1,", + "batch rows = 10000,", + "prefetch rows = 5000", + } { + if !strings.Contains(script, want) { + t.Fatalf("script missing %q\n%s", want, script) + } + } +} + +func TestBuildPgloaderScriptFullLoadResetsSequences(t *testing.T) { + script := buildPgloaderScript("/tmp/test.sqlite", "postgresql://u:p@host/db", pgloaderScriptConfig{ + dataOnly: true, + resetSequences: true, + profile: pgloaderProfileForTable(0), + }, nil, nil) + if !strings.Contains(script, "reset sequences,") { + t.Fatalf("expected reset sequences in final table script:\n%s", script) + } + if strings.Contains(script, "INCLUDING ONLY") { + t.Fatalf("did not expect table filter for full load:\n%s", script) + } +} + +func TestPgloaderLoadTablesSkipsORMMetadata(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "dump.sql") + if err := os.WriteFile(path, []byte(` +CREATE TABLE organizations (id INTEGER PRIMARY KEY); +CREATE TABLE __drizzle_migrations (id INTEGER PRIMARY KEY); +CREATE TABLE users (id INTEGER PRIMARY KEY, org_id INTEGER); +`), 0o600); err != nil { + t.Fatal(err) + } + + tables, err := PgloaderLoadTables(path) + if err != nil { + t.Fatalf("PgloaderLoadTables: %v", err) + } + if len(tables) != 2 { + t.Fatalf("tables = %v, want [organizations users]", tables) + } + if tables[0] != "organizations" || tables[1] != "users" { + t.Fatalf("load order = %v", tables) + } +} + +func TestPgloaderQuotePatternEscapesQuotes(t *testing.T) { + got := pgloaderQuotePattern("foo'bar") + if got != "'foo''bar'" { + t.Fatalf("pgloaderQuotePattern() = %q", got) + } +} + +func TestConvertSchemaPartsSplitsIndexes(t *testing.T) { + parts, count, err := ConvertSchemaParts(testFixture(t)) + if err != nil { + t.Fatalf("ConvertSchemaParts: %v", err) + } + if count != 4 { + t.Fatalf("expected 4 tables, got %d", count) + } + if !strings.Contains(parts.Tables, `CREATE TABLE IF NOT EXISTS "users"`) { + t.Fatalf("expected users table DDL") + } + if strings.Contains(parts.Tables, "CREATE INDEX") { + t.Fatalf("tables section should not contain indexes") + } + if !strings.Contains(parts.Indexes, `CREATE INDEX IF NOT EXISTS "idx_users_email"`) { + t.Fatalf("expected index DDL in indexes section:\n%s", parts.Indexes) + } +} diff --git a/internal/migrate/d1/pgloader_transforms.lisp b/internal/migrate/d1/pgloader_transforms.lisp new file mode 100644 index 000000000..80c85c620 --- /dev/null +++ b/internal/migrate/d1/pgloader_transforms.lisp @@ -0,0 +1,16 @@ +(in-package #:pgloader.transforms) + +(defun sqlite-int-to-boolean (val) + "SQLite stores booleans as INTEGER 0/1; PostgreSQL COPY expects boolean." + (cond + ((null val) :null) + ((and (integerp val) (zerop val)) "false") + ((and (stringp val) (string= val "0")) "false") + (t "true"))) + +(defun sqlite-text-to-jsonb (val) + "SQLite JSON lives in TEXT; pass valid JSON through to PostgreSQL JSONB." + (cond + ((null val) :null) + ((stringp val) val) + (t (format nil "~a" val)))) diff --git a/internal/migrate/d1/plan.go b/internal/migrate/d1/plan.go new file mode 100644 index 000000000..3e1c606bd --- /dev/null +++ b/internal/migrate/d1/plan.go @@ -0,0 +1,236 @@ +package d1 + +import ( + "fmt" + "regexp" + "sort" + "strings" + + gonanoid "github.com/matoous/go-nanoid/v2" +) + +const ( + MethodPgloader = "pgloader" + MethodPsql = "psql" // schema via psql; data via pgloader (dumps under 1GB) +) + +// PlanOptions configures migration planning. +type PlanOptions struct { + InputPath string + Org string + Database string + Branch string + Method string + MigrationID string // optional: reuse an existing migration ID from plan/start + Lint *LintResult // optional: skip re-lint when already computed +} + +// Plan builds a migration plan from a SQLite dump. +func Plan(opts PlanOptions) (*PlanResult, error) { + tables, err := ParseDump(opts.InputPath) + if err != nil { + return nil, err + } + + lintResult := opts.Lint + if lintResult == nil { + lintResult, err = Lint(opts.InputPath) + if err != nil { + return nil, err + } + } + + rowCounts, err := CountInsertRows(opts.InputPath) + if err != nil { + return nil, err + } + size, err := FileSize(opts.InputPath) + if err != nil { + return nil, err + } + + method := opts.Method + if method == "" { + method = recommendMethod(size) + } + + plan := &PlanResult{ + MigrationID: opts.planMigrationID(), + InputPath: opts.InputPath, + Org: opts.Org, + Database: opts.Database, + Branch: opts.Branch, + RecommendedMethod: method, + EstimatedSizeBytes: size, + Tables: make([]TablePlan, 0, len(tables)), + CastRules: defaultCastRules(), + LoadOrder: topologicalLoadOrder(tables), + Issues: lintResult.Issues, + } + + for _, table := range tables { + tp := TablePlan{ + Name: table.Name, + RowEstimate: rowCounts[table.Name], + } + for _, col := range table.Columns { + if col.ForeignKey != "" { + tp.HasFK = true + break + } + } + if !tp.HasFK { + for _, ref := range parseTableFKReferences(table.RawDDL) { + if ref != "" { + tp.HasFK = true + break + } + } + } + plan.Tables = append(plan.Tables, tp) + } + + return plan, nil +} + +func (opts PlanOptions) planMigrationID() string { + if opts.MigrationID != "" { + return opts.MigrationID + } + return gonanoid.MustGenerate("0123456789abcdefghijklmnopqrstuvwxyz", 12) +} + +func recommendMethod(sizeBytes int64) string { + const oneGB = 1024 * 1024 * 1024 + if sizeBytes > 0 && sizeBytes < oneGB { + return MethodPsql + } + return MethodPgloader +} + +func defaultCastRules() []CastRule { + return []CastRule{ + {SourceType: "integer", TargetType: "boolean", Using: "(= 1)", Tables: "match-columns-like '%active%'"}, + {SourceType: "text", TargetType: "timestamptz", Using: "sqlite-timestamp-to-timestamp"}, + {SourceType: "text", TargetType: "jsonb", Using: "sqlite-text-to-jsonb"}, + } +} + +func topologicalLoadOrder(tables []TableSchema) []string { + names := make([]string, 0, len(tables)) + deps := make(map[string][]string) + nameSet := make(map[string]bool) + + for _, t := range tables { + names = append(names, t.Name) + nameSet[t.Name] = true + for _, col := range t.Columns { + if ref := parseFKReference(col.ForeignKey); ref != "" && nameSet[ref] { + deps[t.Name] = append(deps[t.Name], ref) + } + } + for _, ref := range parseTableFKReferences(t.RawDDL) { + if nameSet[ref] { + deps[t.Name] = append(deps[t.Name], ref) + } + } + } + + sort.Strings(names) + + visited := make(map[string]bool) + var order []string + + var visit func(string) + visit = func(name string) { + if visited[name] { + return + } + visited[name] = true + for _, dep := range deps[name] { + if dep != "" { + visit(dep) + } + } + order = append(order, name) + } + + for _, name := range names { + visit(name) + } + + return order +} + +func parseFKReference(fk string) string { + if fk == "" { + return "" + } + idx := indexOfIgnoreCase(fk, "REFERENCES") + if idx < 0 { + return "" + } + rest := strings.TrimSpace(fk[idx+len("REFERENCES"):]) + parts := strings.Fields(rest) + if len(parts) == 0 { + return "" + } + return strings.Trim(parts[0], "`\"'") +} + +var tableFKRe = regexp.MustCompile(`(?i)FOREIGN\s+KEY[^)]*\)\s*REFERENCES\s+(?:` + "`" + `([^` + "`" + `]+)` + "`" + `|"([^"]+)"|'([^']+)'|([a-zA-Z_][\w]*))`) + +func parseTableFKReferences(ddl string) []string { + matches := tableFKRe.FindAllStringSubmatch(ddl, -1) + var refs []string + for _, m := range matches { + ref := firstNonEmpty(m[1], m[2], m[3], m[4]) + if ref != "" { + refs = append(refs, ref) + } + } + return refs +} + +func indexOfIgnoreCase(s, sub string) int { + return strings.Index(strings.ToUpper(s), strings.ToUpper(sub)) +} + +// SavePlan persists plan state for later import/verify. +func SavePlan(plan *PlanResult) error { + state := &MigrationState{ + MigrationID: plan.MigrationID, + Org: plan.Org, + Database: plan.Database, + Branch: plan.Branch, + InputPath: plan.InputPath, + Method: plan.RecommendedMethod, + Phase: PhasePlanned, + } + return SaveState(state) +} + +// StartNextSteps returns agent next steps after start or start --dry-run. +func StartNextSteps(migrationID, database, method string, dryRun bool) []NextStep { + if dryRun { + cmd := fmt.Sprintf("pscale import d1 start --migration-id %s --database %s", migrationID, database) + if method != "" { + cmd += fmt.Sprintf(" --method %s", method) + } + cmd += " --force" + return []NextStep{ + { + Tool: "import_d1_start", + Command: cmd, + Reason: "Run the import after preview", + }, + } + } + return []NextStep{ + { + Tool: "import_d1_verify", + Command: fmt.Sprintf("pscale import d1 verify --migration-id %s --database %s", migrationID, database), + Reason: "Verify row counts, sequences, and content after import", + }, + } +} diff --git a/internal/migrate/d1/postgres.go b/internal/migrate/d1/postgres.go new file mode 100644 index 000000000..6edc08bbf --- /dev/null +++ b/internal/migrate/d1/postgres.go @@ -0,0 +1,12 @@ +package d1 + +import ( + "database/sql" + + "github.com/planetscale/cli/internal/postgres" +) + +// OpenPostgres opens a PostgreSQL connection. +func OpenPostgres(uri string) (*sql.DB, error) { + return postgres.OpenConnection(uri) +} diff --git a/internal/migrate/d1/prepare.go b/internal/migrate/d1/prepare.go new file mode 100644 index 000000000..691e49ee5 --- /dev/null +++ b/internal/migrate/d1/prepare.go @@ -0,0 +1,203 @@ +package d1 + +import ( + "fmt" + + "github.com/planetscale/cli/internal/printer" +) + +// ImportPrepareResult is lint + plan output used before and during import. +type ImportPrepareResult struct { + MigrationID string `json:"migration_id"` + Method string `json:"method"` + Lint *LintResult `json:"lint"` + Plan *PlanResult `json:"plan"` + CanProceed bool `json:"can_proceed"` + BlockedReason string `json:"blocked_reason,omitempty"` +} + +// PrepareImport runs lint and resolves or creates a migration plan without touching Postgres. +func PrepareImport(opts ImportOptions) (*ImportPrepareResult, error) { + if _, err := ValidateInputPath(opts.InputPath); err != nil { + return nil, err + } + if _, err := FindPgloader(); err != nil { + return nil, err + } + + lintResult, err := Lint(opts.InputPath) + if err != nil { + return nil, err + } + + method := opts.Method + if method == "" { + size, err := FileSize(opts.InputPath) + if err != nil { + return nil, err + } + method = recommendMethod(size) + } + + plan, err := resolvePlan(opts, method, lintResult) + if err != nil { + return nil, err + } + + if opts.Method != "" { + plan.RecommendedMethod = opts.Method + } + method = plan.RecommendedMethod + + out := &ImportPrepareResult{ + MigrationID: plan.MigrationID, + Method: method, + Lint: lintResult, + Plan: plan, + CanProceed: lintResult.ErrorCount == 0, + } + if !out.CanProceed { + out.BlockedReason = lintBlockedReason(lintResult.ErrorCount) + } + return out, nil +} + +func resolvePlan(opts ImportOptions, method string, lint *LintResult) (*PlanResult, error) { + if opts.MigrationID == "" { + return createAndSavePlan(PlanOptions{ + InputPath: opts.InputPath, + Org: opts.Org, + Database: opts.Database, + Branch: opts.Branch, + Method: method, + Lint: lint, + }) + } + + state, err := LoadState(opts.Org, opts.Database, opts.Branch, opts.MigrationID) + if err != nil { + return createAndSavePlan(PlanOptions{ + InputPath: opts.InputPath, + Org: opts.Org, + Database: opts.Database, + Branch: opts.Branch, + Method: method, + MigrationID: opts.MigrationID, + Lint: lint, + }) + } + + if opts.InputPath != "" && state.InputPath != "" && state.InputPath != opts.InputPath { + return nil, newMigrationError( + ErrCodeInvalidInput, + fmt.Sprintf("input path %q does not match planned import %q", opts.InputPath, state.InputPath), + "Use the same --input as a prior start preview or omit --migration-id to start fresh", + ) + } + + inputPath := opts.InputPath + if inputPath == "" { + inputPath = state.InputPath + } + + plan, err := Plan(PlanOptions{ + InputPath: inputPath, + Org: opts.Org, + Database: opts.Database, + Branch: opts.Branch, + Method: method, + MigrationID: state.MigrationID, + Lint: lint, + }) + if err != nil { + return nil, err + } + if state.Method != "" { + plan.RecommendedMethod = state.Method + } + return plan, nil +} + +func createAndSavePlan(opts PlanOptions) (*PlanResult, error) { + plan, err := Plan(opts) + if err != nil { + return nil, err + } + if err := SavePlan(plan); err != nil { + return nil, err + } + return plan, nil +} + +func importResultFromPrepare(prepared *ImportPrepareResult, dryRun bool) *ImportResult { + return &ImportResult{ + MigrationID: prepared.MigrationID, + Method: prepared.Method, + DryRun: dryRun, + Lint: prepared.Lint, + Plan: prepared.Plan, + CanProceed: prepared.CanProceed, + } +} + +// BlockedStartResponse builds the start error envelope when lint blocks import. +func BlockedStartResponse(prepared *ImportPrepareResult, dryRun bool) Response { + resp := ErrorResponse("start", ErrLintBlocked(prepared.BlockedReason)) + if prepared.Lint != nil { + resp.Issues = prepared.Lint.Issues + } + resp.Data = ImportResult{ + MigrationID: prepared.MigrationID, + Method: prepared.Method, + DryRun: dryRun, + Lint: prepared.Lint, + Plan: prepared.Plan, + CanProceed: false, + } + resp.MigrationID = prepared.MigrationID + return resp +} + +// PrintStartPreview writes a human-readable lint/plan summary before import confirmation. +func PrintStartPreview(p *printer.Printer, prepared *ImportPrepareResult) { + if prepared == nil { + return + } + p.Println("\nImport preview") + if prepared.Lint != nil { + p.Printf(" Lint: %d error(s), %d warning(s)\n", prepared.Lint.ErrorCount, prepared.Lint.WarningCount) + for _, issue := range prepared.Lint.Issues { + if issue.Severity != SeverityError && issue.Severity != SeverityWarning { + continue + } + loc := issue.Table + if issue.Column != "" { + loc += "." + issue.Column + } + if loc != "" { + loc = " " + loc + } + p.Printf(" [%s] %s%s: %s\n", issue.Severity, issue.Code, loc, previewMessage(issue)) + } + } + if prepared.Plan != nil { + sizeMB := float64(prepared.Plan.EstimatedSizeBytes) / (1024 * 1024) + p.Printf(" Plan: migration_id %s, method %s, %.1f MB, %d tables\n", + prepared.Plan.MigrationID, + prepared.Plan.RecommendedMethod, + sizeMB, + len(prepared.Plan.Tables), + ) + } + if prepared.BlockedReason != "" { + p.Printf(" Blocked: %s\n", prepared.BlockedReason) + } + p.Println() +} + +func previewMessage(issue Issue) string { + if issue.Message != "" { + return issue.Message + } + return issue.Remediation +} diff --git a/internal/migrate/d1/prepare_test.go b/internal/migrate/d1/prepare_test.go new file mode 100644 index 000000000..5ed22f08a --- /dev/null +++ b/internal/migrate/d1/prepare_test.go @@ -0,0 +1,85 @@ +package d1 + +import ( + "bytes" + "context" + "strings" + "testing" + + "github.com/planetscale/cli/internal/printer" +) + +func TestPrepareImport(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + prepared, err := PrepareImport(ImportOptions{ + InputPath: testFixture(t), + Org: "acme", + Database: "mydb", + Branch: "main", + }) + if err != nil { + t.Fatalf("PrepareImport: %v", err) + } + if !prepared.CanProceed { + t.Fatalf("expected can proceed, blocked: %s", prepared.BlockedReason) + } + if prepared.MigrationID == "" { + t.Fatal("expected migration id") + } + if prepared.Lint == nil || prepared.Plan == nil { + t.Fatal("expected lint and plan in prepare result") + } + if prepared.Method != prepared.Plan.RecommendedMethod { + t.Fatalf("method mismatch: %q vs %q", prepared.Method, prepared.Plan.RecommendedMethod) + } +} + +func TestImport_BlocksOnLintErrors(t *testing.T) { + prepared := &ImportPrepareResult{ + MigrationID: "mig-test", + Method: MethodPgloader, + CanProceed: false, + BlockedReason: "lint reported 1 error(s); fix or use import d1 lint for details", + Lint: &LintResult{ + ErrorCount: 1, + Issues: []Issue{{ + Code: "TEST", + Severity: SeverityError, + Message: "blocked for test", + }}, + }, + } + + result, err := Import(context.Background(), nil, nil, ImportOptions{DryRun: true}, prepared) + if err == nil { + t.Fatal("expected lint blocked error") + } + requireMigrationErr(t, err, ErrCodeLintBlocked) + if result == nil || result.CanProceed { + t.Fatal("expected result with can_proceed false") + } +} + +func TestPrintStartPreview(t *testing.T) { + prepared, err := PrepareImport(ImportOptions{ + InputPath: testFixture(t), + Org: "acme", + Database: "mydb", + }) + if err != nil { + t.Fatalf("PrepareImport: %v", err) + } + + var buf bytes.Buffer + format := printer.Human + p := printer.NewPrinter(&format) + p.SetHumanOutput(&buf) + PrintStartPreview(p, prepared) + out := buf.String() + for _, want := range []string{"Import preview", "Lint:", "Plan:", prepared.MigrationID} { + if !strings.Contains(out, want) { + t.Fatalf("preview missing %q:\n%s", want, out) + } + } +} diff --git a/internal/migrate/d1/schema_reset.go b/internal/migrate/d1/schema_reset.go new file mode 100644 index 000000000..d93b13ede --- /dev/null +++ b/internal/migrate/d1/schema_reset.go @@ -0,0 +1,153 @@ +package d1 + +import ( + "context" + "fmt" + "strings" + + ps "github.com/planetscale/planetscale-go/planetscale" + + "github.com/planetscale/cli/internal/postgres" +) + +const ( + postgresRoleName = "postgres" + publicSchemaName = "public" +) + +func reassignStaleImportRoleObjects(ctx context.Context, psClient *ps.Client, opts ImportOptions, currentUsername string) error { + if psClient == nil || opts.DestURI != "" { + return nil + } + + roles, err := psClient.PostgresRoles.List(ctx, &ps.ListPostgresRolesRequest{ + Organization: opts.Org, + Database: opts.Database, + Branch: opts.Branch, + }) + if err != nil { + return fmt.Errorf("list postgres roles: %w", err) + } + + var firstErr error + for _, role := range roles { + if role == nil || role.Username == currentUsername { + continue + } + if !isEphemeralImportRole(role.Username) { + continue + } + err := psClient.PostgresRoles.ReassignObjects(ctx, &ps.ReassignPostgresRoleObjectsRequest{ + Organization: opts.Org, + Database: opts.Database, + Branch: opts.Branch, + RoleId: role.ID, + Successor: postgresRoleName, + }) + if err != nil && firstErr == nil { + firstErr = fmt.Errorf("reassign objects from role %q to %s: %w", role.Username, postgresRoleName, err) + } + } + return firstErr +} + +func isEphemeralImportRole(username string) bool { + return strings.HasPrefix(username, "pscale_api_") +} + +func usernameFromDestURI(destURI string) (string, error) { + cfg, err := postgres.ParseConnectionURI(destURI) + if err != nil { + return "", err + } + if cfg.User == "" { + return "", fmt.Errorf("destination URI missing user") + } + return cfg.User, nil +} + +func importTableNames(tables []TableSchema) []string { + names := make([]string, 0, len(tables)) + for _, table := range tables { + if IsORMMetadataTable(table.Name) { + continue + } + names = append(names, table.Name) + } + return names +} + +func existingPublicTables(ctx context.Context, destURI string, names []string) (map[string]struct{}, error) { + existing := make(map[string]struct{}) + if len(names) == 0 { + return existing, nil + } + + db, err := OpenPostgres(destURI) + if err != nil { + return nil, err + } + defer db.Close() + + placeholders := make([]string, len(names)) + args := make([]any, len(names)) + for i, name := range names { + placeholders[i] = fmt.Sprintf("$%d", i+1) + args[i] = name + } + query := fmt.Sprintf( + `SELECT table_name FROM information_schema.tables WHERE table_schema = '%s' AND table_name IN (%s)`, + publicSchemaName, + strings.Join(placeholders, ", "), + ) + + rows, err := db.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("list existing tables: %w", err) + } + defer rows.Close() + + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, fmt.Errorf("scan table name: %w", err) + } + existing[name] = struct{}{} + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("list existing tables: %w", err) + } + + return existing, nil +} + +func conflictingImportTables(importNames []string, existing map[string]struct{}) []string { + if len(existing) == 0 { + return nil + } + conflicts := make([]string, 0, len(importNames)) + for _, name := range importNames { + if _, found := existing[name]; found { + conflicts = append(conflicts, name) + } + } + return conflicts +} + +func buildImportTablesSQL(tables []TableSchema) string { + tableByName := make(map[string]TableSchema, len(tables)) + for _, table := range tables { + tableByName[table.Name] = table + } + + var b strings.Builder + for _, name := range topologicalLoadOrder(tables) { + table, ok := tableByName[name] + if !ok || IsORMMetadataTable(table.Name) { + continue + } + b.WriteString(convertTableDDL(table, tables)) + b.WriteString("\n\n") + } + return b.String() +} diff --git a/internal/migrate/d1/schema_reset_test.go b/internal/migrate/d1/schema_reset_test.go new file mode 100644 index 000000000..387dade78 --- /dev/null +++ b/internal/migrate/d1/schema_reset_test.go @@ -0,0 +1,63 @@ +package d1 + +import ( + "strings" + "testing" +) + +func TestConflictingImportTables(t *testing.T) { + existing := map[string]struct{}{ + "organizations": {}, + "posts": {}, + "other_app": {}, + } + conflicts := conflictingImportTables([]string{"organizations", "users", "posts"}, existing) + if len(conflicts) != 2 || conflicts[0] != "organizations" || conflicts[1] != "posts" { + t.Fatalf("conflicts = %v", conflicts) + } +} + +func TestErrExistingImportTables(t *testing.T) { + err := errExistingImportTables([]string{"users", "posts"}) + requireMigrationErr(t, err, ErrCodeDestinationConflict) + me, _ := migrationErr(err) + if !strings.Contains(me.Info.Message, "users, posts") { + t.Fatalf("message = %q", me.Info.Message) + } +} + +func TestBuildImportTablesSQLCreatesAllImportTables(t *testing.T) { + tables := []TableSchema{ + { + Name: "organizations", + Columns: []ColumnSchema{ + {Name: "id", Type: "INTEGER", PrimaryKey: true, AutoIncrement: true}, + }, + }, + { + Name: "users", + Columns: []ColumnSchema{ + {Name: "id", Type: "INTEGER", PrimaryKey: true, AutoIncrement: true}, + }, + }, + } + + sql := buildImportTablesSQL(tables) + if !strings.Contains(sql, `CREATE TABLE IF NOT EXISTS "organizations"`) { + t.Fatalf("expected organizations table DDL:\n%s", sql) + } + if !strings.Contains(sql, `CREATE TABLE IF NOT EXISTS "users"`) { + t.Fatalf("expected users table DDL:\n%s", sql) + } +} + +func TestImportTableNamesSkipsORMMetadata(t *testing.T) { + names := importTableNames([]TableSchema{ + {Name: "users"}, + {Name: "__drizzle_migrations"}, + {Name: "posts"}, + }) + if len(names) != 2 || names[0] != "users" || names[1] != "posts" { + t.Fatalf("names = %v", names) + } +} diff --git a/internal/migrate/d1/sqlite_load.go b/internal/migrate/d1/sqlite_load.go new file mode 100644 index 000000000..5e762d841 --- /dev/null +++ b/internal/migrate/d1/sqlite_load.go @@ -0,0 +1,186 @@ +package d1 + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + execabs "golang.org/x/sys/execabs" +) + +const defaultSQLiteChunkBytes = 64 << 20 // 64 MiB of SQL per .read batch + +// EnsureSQLiteFromDump loads dump SQL into sqlite unless a fresh-enough database already exists. +func EnsureSQLiteFromDump(ctx context.Context, dumpPath, sqlitePath string) error { + if canReuseSQLite(dumpPath, sqlitePath) { + return nil + } + return buildSQLiteFromDump(ctx, dumpPath, sqlitePath) +} + +// BuildSQLiteFromDump always rebuilds sqlite from the dump (tests and forced refresh). +func BuildSQLiteFromDump(ctx context.Context, dumpPath, sqlitePath string) error { + return buildSQLiteFromDump(ctx, dumpPath, sqlitePath) +} + +func buildSQLiteFromDump(ctx context.Context, dumpPath, sqlitePath string) error { + dumpPath, err := ValidateInputPath(dumpPath) + if err != nil { + return err + } + + sqlite3, err := FindSQLite3() + if err != nil { + return err + } + + if err := os.RemoveAll(sqlitePath); err != nil && !os.IsNotExist(err) { + return err + } + + dir := filepath.Dir(sqlitePath) + if err := os.MkdirAll(dir, 0o700); err != nil { + return err + } + + return loadSQLiteDumpChunked(ctx, sqlite3, dumpPath, sqlitePath, defaultSQLiteChunkBytes) +} + +func canReuseSQLite(dumpPath, sqlitePath string) bool { + dumpInfo, err := os.Stat(dumpPath) + if err != nil { + return false + } + sqliteInfo, err := os.Stat(sqlitePath) + if err != nil || sqliteInfo.Size() == 0 { + return false + } + // Reuse when sqlite is at least as new as the dump (same import input). + if sqliteInfo.ModTime().Before(dumpInfo.ModTime()) { + return false + } + return sqliteHasTables(sqlitePath) +} + +func sqliteHasTables(sqlitePath string) bool { + sqlite3, err := FindSQLite3() + if err != nil { + return false + } + out, err := execabs.Command(sqlite3, sqlitePath, "SELECT 1 FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' LIMIT 1;").CombinedOutput() + if err != nil { + return false + } + return strings.TrimSpace(string(out)) == "1" +} + +func loadSQLiteDumpChunked(ctx context.Context, sqlite3, dumpPath, sqlitePath string, chunkBytes int64) error { + dump, err := os.Open(dumpPath) + if err != nil { + return err + } + defer dump.Close() + + chunkDir, err := os.MkdirTemp("", "pscale-d1-sqlite-chunk-*") + if err != nil { + return err + } + defer os.RemoveAll(chunkDir) + + reader := bufio.NewReader(dump) + var ( + chunkIdx int + chunkFile *os.File + chunkPath string + chunkSize int64 + lineNo int + totalLines int + ) + + flushChunk := func() error { + if chunkFile == nil { + return nil + } + if err := chunkFile.Close(); err != nil { + return err + } + chunkFile = nil + + readPath := strings.ReplaceAll(chunkPath, "'", "''") + cmd := execabs.CommandContext(ctx, sqlite3, sqlitePath, fmt.Sprintf(".read %s", readPath)) + var stderr bytes.Buffer + cmd.Stdout = io.Discard + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf( + "sqlite3 chunk %d (through line %d): %w: %s", + chunkIdx, + lineNo, + err, + truncateLoadError(stderr.String(), 2048), + ) + } + return os.Remove(chunkPath) + } + + startChunk := func() error { + chunkIdx++ + chunkPath = filepath.Join(chunkDir, fmt.Sprintf("chunk-%04d.sql", chunkIdx)) + f, err := os.OpenFile(chunkPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600) + if err != nil { + return err + } + chunkFile = f + chunkSize = 0 + return nil + } + + for { + line, err := reader.ReadBytes('\n') + if len(line) > 0 { + lineNo++ + totalLines++ + if chunkFile == nil { + if err := startChunk(); err != nil { + return err + } + } + if _, werr := chunkFile.Write(line); werr != nil { + return werr + } + chunkSize += int64(len(line)) + if chunkSize >= chunkBytes { + if err := flushChunk(); err != nil { + return err + } + } + } + if err == io.EOF { + break + } + if err != nil { + return err + } + } + + if err := flushChunk(); err != nil { + return fmt.Errorf("sqlite3 load failed: %w", err) + } + if totalLines == 0 { + return fmt.Errorf("sqlite3 load failed: dump is empty") + } + return nil +} + +func truncateLoadError(msg string, max int) string { + msg = strings.TrimSpace(msg) + if len(msg) <= max { + return msg + } + return msg[:max] + "..." +} diff --git a/internal/migrate/d1/sqlite_load_test.go b/internal/migrate/d1/sqlite_load_test.go new file mode 100644 index 000000000..ce0552be5 --- /dev/null +++ b/internal/migrate/d1/sqlite_load_test.go @@ -0,0 +1,135 @@ +package d1 + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestBuildSQLiteFromDump(t *testing.T) { + dir := t.TempDir() + dumpPath := filepath.Join(dir, "dump.sql") + sqlitePath := filepath.Join(dir, "load.sqlite") + + var b strings.Builder + b.WriteString("PRAGMA defer_foreign_keys=TRUE;\n") + b.WriteString("CREATE TABLE attachments (\n") + b.WriteString(" id INTEGER PRIMARY KEY,\n") + b.WriteString(" payload BLOB\n") + b.WriteString(");\n") + hex := strings.Repeat("41", 48000) // ~96 KiB blob, similar to wrangler export lines + fmt.Fprintf(&b, "INSERT INTO attachments (id, payload) VALUES(1,X'%s');\n", hex) + b.WriteString("INSERT INTO attachments (id, payload) VALUES(2,NULL);\n") + + if err := os.WriteFile(dumpPath, []byte(b.String()), 0o600); err != nil { + t.Fatal(err) + } + + if err := BuildSQLiteFromDump(context.Background(), dumpPath, sqlitePath); err != nil { + t.Fatalf("BuildSQLiteFromDump: %v", err) + } + + counts, err := CountSQLiteRows(context.Background(), sqlitePath, []string{"attachments"}) + if err != nil { + t.Fatal(err) + } + if counts["attachments"] != 2 { + t.Fatalf("expected 2 rows, got %d", counts["attachments"]) + } +} + +func TestEnsureSQLiteFromDumpReusesExisting(t *testing.T) { + dir := t.TempDir() + dumpPath := filepath.Join(dir, "dump.sql") + sqlitePath := filepath.Join(dir, "load.sqlite") + + content := "PRAGMA defer_foreign_keys=TRUE;\nCREATE TABLE t (id INTEGER PRIMARY KEY);\nINSERT INTO t VALUES(1);\n" + if err := os.WriteFile(dumpPath, []byte(content), 0o600); err != nil { + t.Fatal(err) + } + if err := BuildSQLiteFromDump(context.Background(), dumpPath, sqlitePath); err != nil { + t.Fatal(err) + } + info1, err := os.Stat(sqlitePath) + if err != nil { + t.Fatal(err) + } + + time.Sleep(10 * time.Millisecond) + if err := os.WriteFile(dumpPath, []byte(content), 0o600); err != nil { + t.Fatal(err) + } + // Touch dump to be newer than sqlite — should rebuild. + dumpInfo, _ := os.Stat(dumpPath) + if err := os.Chtimes(dumpPath, dumpInfo.ModTime().Add(time.Second), dumpInfo.ModTime().Add(time.Second)); err != nil { + t.Fatal(err) + } + + if err := EnsureSQLiteFromDump(context.Background(), dumpPath, sqlitePath); err != nil { + t.Fatal(err) + } + info2, err := os.Stat(sqlitePath) + if err != nil { + t.Fatal(err) + } + if !info2.ModTime().After(info1.ModTime()) { + t.Fatal("expected rebuild when dump is newer than sqlite") + } + + // Dump must not be newer than sqlite for reuse. + now := time.Now() + if err := os.Chtimes(dumpPath, now.Add(-2*time.Minute), now.Add(-2*time.Minute)); err != nil { + t.Fatal(err) + } + if err := os.Chtimes(sqlitePath, now.Add(-time.Minute), now.Add(-time.Minute)); err != nil { + t.Fatal(err) + } + + if err := EnsureSQLiteFromDump(context.Background(), dumpPath, sqlitePath); err != nil { + t.Fatal(err) + } + info3, err := os.Stat(sqlitePath) + if err != nil { + t.Fatal(err) + } + if info3.ModTime().After(info2.ModTime()) { + t.Fatal("expected sqlite reuse without rebuild when dump is not newer") + } +} + +func TestLoadSQLiteDumpChunked(t *testing.T) { + dir := t.TempDir() + dumpPath := filepath.Join(dir, "multi.sql") + sqlitePath := filepath.Join(dir, "multi.sqlite") + + var b strings.Builder + b.WriteString("PRAGMA defer_foreign_keys=TRUE;\n") + b.WriteString("CREATE TABLE t (id INTEGER PRIMARY KEY, v TEXT);\n") + for i := 0; i < 200; i++ { + fmt.Fprintf(&b, "INSERT INTO t (id, v) VALUES(%d,'row');\n", i) + } + if err := os.WriteFile(dumpPath, []byte(b.String()), 0o600); err != nil { + t.Fatal(err) + } + + sqlite3, err := FindSQLite3() + if err != nil { + t.Fatal(err) + } + // Force many small chunks to exercise batching. + if err := loadSQLiteDumpChunked(context.Background(), sqlite3, dumpPath, sqlitePath, 256); err != nil { + t.Fatalf("loadSQLiteDumpChunked: %v", err) + } + + counts, err := CountSQLiteRows(context.Background(), sqlitePath, []string{"t"}) + if err != nil { + t.Fatal(err) + } + if counts["t"] != 200 { + t.Fatalf("expected 200 rows, got %d", counts["t"]) + } +} diff --git a/internal/migrate/d1/state.go b/internal/migrate/d1/state.go new file mode 100644 index 000000000..1dcd1ec83 --- /dev/null +++ b/internal/migrate/d1/state.go @@ -0,0 +1,161 @@ +package d1 + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/adrg/xdg" +) + +// StateStore manages local migration state. +type StateStore struct { + dir string +} + +// NewStateStore returns the default state store location. +func NewStateStore() (*StateStore, error) { + dir, err := xdg.ConfigFile("planetscale/import-d1") + if err != nil { + return nil, fmt.Errorf("state dir: %w", err) + } + if os.Getenv("PSCALE_TEST_MODE") == "1" { + dir = filepath.Join(os.TempDir(), "pscale-import-d1-test") + } + if err := os.MkdirAll(dir, 0o700); err != nil { + return nil, err + } + return &StateStore{dir: dir}, nil +} + +func (s *StateStore) statePath(org, database, branch, migrationID string) string { + key := fmt.Sprintf("%s_%s_%s_%s.json", sanitize(org), sanitize(database), sanitize(branch), sanitize(migrationID)) + return filepath.Join(s.dir, key) +} + +func sanitize(s string) string { + out := make([]byte, 0, len(s)) + for i := 0; i < len(s); i++ { + c := s[i] + if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '-' || c == '_' { + out = append(out, c) + } else { + out = append(out, '_') + } + } + return string(out) +} + +// Save persists migration state. +func (s *StateStore) Save(state *MigrationState) error { + if state.CreatedAt.IsZero() { + state.CreatedAt = time.Now().UTC() + } + state.UpdatedAt = time.Now().UTC() + data, err := json.MarshalIndent(state, "", " ") + if err != nil { + return err + } + path := s.statePath(state.Org, state.Database, state.Branch, state.MigrationID) + return os.WriteFile(path, data, 0o600) +} + +// Load retrieves migration state by ID. +func (s *StateStore) Load(org, database, branch, migrationID string) (*MigrationState, error) { + path := s.statePath(org, database, branch, migrationID) + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil, newMigrationError(ErrCodeNotFound, "migration state not found", "Run `import d1 start --dry-run` or `import d1 start` to create migration state") + } + return nil, err + } + var state MigrationState + if err := json.Unmarshal(data, &state); err != nil { + return nil, err + } + return &state, nil +} + +// Delete removes migration state. +func (s *StateStore) Delete(org, database, branch, migrationID string) error { + path := s.statePath(org, database, branch, migrationID) + err := os.Remove(path) + if os.IsNotExist(err) { + return nil + } + return err +} + +// SaveState is a package-level helper using the default store. +func SaveState(state *MigrationState) error { + store, err := NewStateStore() + if err != nil { + return err + } + return store.Save(state) +} + +// LoadState loads state using the default store. +func LoadState(org, database, branch, migrationID string) (*MigrationState, error) { + store, err := NewStateStore() + if err != nil { + return nil, err + } + return store.Load(org, database, branch, migrationID) +} + +// SetMigrationPhase updates the phase on existing migration state. +func SetMigrationPhase(org, database, branch, migrationID, phase string) error { + return updateMigrationState(org, database, branch, migrationID, func(state *MigrationState) { + state.Phase = phase + }) +} + +func updateMigrationState(org, database, branch, migrationID string, update func(*MigrationState)) error { + state, err := LoadState(org, database, branch, migrationID) + if err != nil { + return err + } + update(state) + return SaveState(state) +} + +func saveImportMigrationState(opts ImportOptions, phase, sqlitePath string) error { + state, err := LoadState(opts.Org, opts.Database, opts.Branch, opts.MigrationID) + if err != nil { + if me, ok := migrationErr(err); ok && me.Info.Code == ErrCodeNotFound { + state = &MigrationState{ + MigrationID: opts.MigrationID, + Org: opts.Org, + Database: opts.Database, + Branch: opts.Branch, + } + } else { + return err + } + } + state.Phase = phase + if opts.InputPath != "" { + state.InputPath = opts.InputPath + } + if opts.Method != "" { + state.Method = opts.Method + } + if sqlitePath != "" { + state.SQLitePath = sqlitePath + } + return SaveState(state) +} + +// Complete marks a migration as finished in local state. +func Complete(org, database, branch, migrationID string) error { + return SetMigrationPhase(org, database, branch, migrationID, PhaseComplete) +} + +// Teardown is deprecated; use Complete. +func Teardown(org, database, branch, migrationID string) error { + return Complete(org, database, branch, migrationID) +} diff --git a/internal/migrate/d1/state_test.go b/internal/migrate/d1/state_test.go new file mode 100644 index 000000000..c953223bb --- /dev/null +++ b/internal/migrate/d1/state_test.go @@ -0,0 +1,107 @@ +package d1 + +import ( + "testing" +) + +func TestMigrationPhaseTransitions(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + org, database, branch := "acme", "mydb", "main" + migrationID := "testphase123" + + plan := &PlanResult{ + MigrationID: migrationID, + Org: org, + Database: database, + Branch: branch, + InputPath: testFixture(t), + } + if err := SavePlan(plan); err != nil { + t.Fatalf("SavePlan: %v", err) + } + + state, err := LoadState(org, database, branch, migrationID) + if err != nil { + t.Fatalf("LoadState: %v", err) + } + if state.Phase != PhasePlanned { + t.Fatalf("phase = %q, want %q", state.Phase, PhasePlanned) + } + + opts := ImportOptions{ + Org: org, + Database: database, + Branch: branch, + MigrationID: migrationID, + InputPath: plan.InputPath, + Method: MethodPgloader, + } + if err := saveImportMigrationState(opts, PhaseImporting, ""); err != nil { + t.Fatalf("saveImportMigrationState importing: %v", err) + } + state, err = LoadState(org, database, branch, migrationID) + if err != nil { + t.Fatalf("LoadState importing: %v", err) + } + if state.Phase != PhaseImporting { + t.Fatalf("phase = %q, want %q", state.Phase, PhaseImporting) + } + + if err := SetMigrationPhase(org, database, branch, migrationID, PhaseImported); err != nil { + t.Fatalf("SetMigrationPhase imported: %v", err) + } + if err := SetMigrationPhase(org, database, branch, migrationID, PhaseVerified); err != nil { + t.Fatalf("SetMigrationPhase verified: %v", err) + } + if err := Complete(org, database, branch, migrationID); err != nil { + t.Fatalf("Complete: %v", err) + } + + state, err = LoadState(org, database, branch, migrationID) + if err != nil { + t.Fatalf("LoadState complete: %v", err) + } + if state.Phase != PhaseComplete { + t.Fatalf("phase = %q, want %q", state.Phase, PhaseComplete) + } +} + +func TestSaveImportMigrationStateFailed(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + org, database, branch := "acme", "mydb", "main" + migrationID := "testfailed456" + if err := SavePlan(&PlanResult{ + MigrationID: migrationID, + Org: org, + Database: database, + Branch: branch, + InputPath: testFixture(t), + }); err != nil { + t.Fatalf("SavePlan: %v", err) + } + + opts := ImportOptions{ + Org: org, + Database: database, + Branch: branch, + MigrationID: migrationID, + InputPath: testFixture(t), + Method: MethodPgloader, + } + if err := saveImportMigrationState(opts, PhaseFailed, "/tmp/test.sqlite"); err != nil { + t.Fatalf("saveImportMigrationState failed: %v", err) + } + + state, err := LoadState(org, database, branch, migrationID) + if err != nil { + t.Fatalf("LoadState: %v", err) + } + if state.Phase != PhaseFailed { + t.Fatalf("phase = %q, want %q", state.Phase, PhaseFailed) + } + if state.SQLitePath != "/tmp/test.sqlite" { + t.Fatalf("sqlite path = %q", state.SQLitePath) + } +} diff --git a/internal/migrate/d1/testdata/sample_d1_export.sql b/internal/migrate/d1/testdata/sample_d1_export.sql new file mode 100644 index 000000000..15c292f39 --- /dev/null +++ b/internal/migrate/d1/testdata/sample_d1_export.sql @@ -0,0 +1,71 @@ +-- Sample D1 export fixture for migrate d1 tests +PRAGMA foreign_keys=OFF; + +CREATE TABLE users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + email TEXT NOT NULL UNIQUE, + active INTEGER DEFAULT 1, + created_at TEXT NOT NULL +); + +CREATE TABLE posts ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + title TEXT NOT NULL, + body TEXT, + published INTEGER DEFAULT 0, + metadata TEXT, + FOREIGN KEY (user_id) REFERENCES users(id) +); + +CREATE TABLE external_entities ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + created_at TEXT NOT NULL +); + +CREATE TABLE entity_links ( + entity_id TEXT NOT NULL, + post_id INTEGER NOT NULL, + linked_at TEXT NOT NULL, + PRIMARY KEY (entity_id, post_id), + FOREIGN KEY (entity_id) REFERENCES external_entities(id), + FOREIGN KEY (post_id) REFERENCES posts(id) +); + +CREATE TABLE __drizzle_migrations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + hash TEXT NOT NULL, + created_at INTEGER +); + +CREATE TABLE _prisma_migrations ( + id TEXT PRIMARY KEY, + checksum TEXT NOT NULL, + finished_at TEXT, + migration_name TEXT NOT NULL, + logs TEXT, + rolled_back_at TEXT, + started_at TEXT NOT NULL, + applied_steps_count INTEGER NOT NULL +); + +INSERT INTO users (id, email, active, created_at) VALUES + (1, 'alice@example.com', 1, '2024-01-01T00:00:00Z'), + (2, 'bob@example.com', 0, '2024-01-02T00:00:00Z'); + +INSERT INTO posts (id, user_id, title, body, published, metadata) VALUES + (1, 1, 'Hello', 'World', 1, '{"tags":["intro"]}'), + (2, 1, 'Draft', 'Work in progress', 0, NULL); + +INSERT INTO external_entities (id, name, created_at) VALUES + ('550e8400-e29b-41d4-a716-446655440000', 'Webhook A', '2024-01-01T00:00:00Z'); + +INSERT INTO entity_links (entity_id, post_id, linked_at) VALUES + ('550e8400-e29b-41d4-a716-446655440000', 1, '2024-01-02T00:00:00Z'); + +INSERT INTO __drizzle_migrations (id, hash, created_at) VALUES + (1, 'abc123', 1700000000); + +CREATE INDEX idx_users_email ON users(email); +CREATE UNIQUE INDEX idx_entity_links_post ON entity_links(post_id); diff --git a/internal/migrate/d1/types.go b/internal/migrate/d1/types.go new file mode 100644 index 000000000..a2470cc4d --- /dev/null +++ b/internal/migrate/d1/types.go @@ -0,0 +1,189 @@ +package d1 + +import "time" + +// Severity levels for lint/plan issues. +const ( + SeverityError = "error" + SeverityWarning = "warning" + SeverityInfo = "info" +) + +// Issue describes a migration concern with agent-friendly remediation. +type Issue struct { + Code string `json:"code"` + Severity string `json:"severity"` + Table string `json:"table,omitempty"` + Column string `json:"column,omitempty"` + Message string `json:"message,omitempty"` + Remediation string `json:"remediation"` +} + +// NextStep guides agents to the next tool or command. +type NextStep struct { + Tool string `json:"tool,omitempty"` + Command string `json:"command,omitempty"` + Reason string `json:"reason"` +} + +// Response is the common JSON envelope for migrate d1 commands. +type Response struct { + Status string `json:"status"` + Phase string `json:"phase"` + MigrationID string `json:"migration_id,omitempty"` + Issues []Issue `json:"issues,omitempty"` + NextSteps []NextStep `json:"next_steps,omitempty"` + Data any `json:"data,omitempty"` + Error *ErrorInfo `json:"error,omitempty"` +} + +// ErrorInfo is a structured CLI/MCP error. +type ErrorInfo struct { + Code string `json:"code"` + Message string `json:"message"` + Remediation string `json:"remediation,omitempty"` +} + +// DoctorResult lists prerequisite checks. +type DoctorResult struct { + Checks []DoctorCheck `json:"checks"` + Ready bool `json:"ready"` +} + +// DoctorCheck is a single prerequisite check. +type DoctorCheck struct { + Name string `json:"name"` + Status string `json:"status"` + Version string `json:"version,omitempty"` + Message string `json:"message,omitempty"` + Remediation string `json:"remediation,omitempty"` +} + +// LintResult summarizes lint output. +type LintResult struct { + InputPath string `json:"input_path"` + TableCount int `json:"table_count"` + ErrorCount int `json:"error_count"` + WarningCount int `json:"warning_count"` + Issues []Issue `json:"issues"` + Tables []string `json:"tables"` +} + +// PlanResult is the migration plan JSON. +type PlanResult struct { + MigrationID string `json:"migration_id"` + InputPath string `json:"input_path"` + Org string `json:"org"` + Database string `json:"database"` + Branch string `json:"branch"` + RecommendedMethod string `json:"recommended_method"` + EstimatedSizeBytes int64 `json:"estimated_size_bytes,omitempty"` + Tables []TablePlan `json:"tables"` + CastRules []CastRule `json:"cast_rules"` + LoadOrder []string `json:"load_order"` + Issues []Issue `json:"issues"` +} + +// TablePlan describes a table in the migration plan. +type TablePlan struct { + Name string `json:"name"` + RowEstimate int `json:"row_estimate,omitempty"` + HasFK bool `json:"has_foreign_keys"` +} + +// CastRule maps SQLite types to Postgres casts for pgloader. +type CastRule struct { + SourceType string `json:"source_type"` + TargetType string `json:"target_type"` + Using string `json:"using,omitempty"` + Tables string `json:"tables,omitempty"` +} + +// ExportResult describes a D1 export. +type ExportResult struct { + OutputPath string `json:"output_path"` + Remote bool `json:"remote"` + Database string `json:"d1_database"` + ExportedAt time.Time `json:"exported_at"` + SizeBytes int64 `json:"size_bytes"` +} + +// ImportResult describes an import run. +type ImportResult struct { + MigrationID string `json:"migration_id"` + Method string `json:"method"` + DryRun bool `json:"dry_run"` + TablesLoaded int `json:"tables_loaded,omitempty"` + Timings *ImportTimings `json:"timings,omitempty"` + Lint *LintResult `json:"lint,omitempty"` + Plan *PlanResult `json:"plan,omitempty"` + CanProceed bool `json:"can_proceed"` +} + +// ImportTimings breaks down import wall-clock time by phase. +type ImportTimings struct { + TotalMs int64 `json:"total_ms"` + SQLiteStagingMs int64 `json:"sqlite_staging_ms,omitempty"` + SchemaMs int64 `json:"schema_ms,omitempty"` + PgloaderMs int64 `json:"pgloader_ms,omitempty"` + IndexBuildMs int64 `json:"index_build_ms,omitempty"` + SequenceResetMs int64 `json:"sequence_reset_ms,omitempty"` + TableLoads []TableLoadTiming `json:"table_loads,omitempty"` +} + +// TableLoadTiming is per-table pgloader duration. +type TableLoadTiming struct { + Table string `json:"table"` + Ms int64 `json:"ms"` +} + +// VerifyOptions configures post-import verification. +type VerifyOptions struct { + Org string + Database string + Branch string + MigrationID string + InputPath string + SQLitePath string + DestURI string +} + +// VerifyResult compares source and destination after import. +type VerifyResult struct { + MigrationID string `json:"migration_id"` + Matched bool `json:"matched"` + Tables []TableVerifyResult `json:"tables"` + Checks []VerifyCheckResult `json:"checks,omitempty"` +} + +// TableVerifyResult is per-table verification. +type TableVerifyResult struct { + Table string `json:"table"` + SourceRows int64 `json:"source_rows"` + DestRows int64 `json:"dest_rows"` + Match bool `json:"match"` +} + +// Migration phases persisted in local state. +const ( + PhasePlanned = "planned" + PhaseImporting = "importing" + PhaseImported = "imported" + PhaseVerified = "verified" + PhaseFailed = "failed" + PhaseComplete = "complete" +) + +// MigrationState is persisted local migration metadata. +type MigrationState struct { + MigrationID string `json:"migration_id"` + Org string `json:"org"` + Database string `json:"database"` + Branch string `json:"branch"` + InputPath string `json:"input_path"` + SQLitePath string `json:"sqlite_path,omitempty"` + Method string `json:"method,omitempty"` + Phase string `json:"phase"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} diff --git a/internal/migrate/d1/verify.go b/internal/migrate/d1/verify.go new file mode 100644 index 000000000..0fd7c30cb --- /dev/null +++ b/internal/migrate/d1/verify.go @@ -0,0 +1,191 @@ +package d1 + +import ( + "context" + "fmt" + + execabs "golang.org/x/sys/execabs" +) + +// Verify compares SQLite source data with PlanetScale Postgres after import. +func Verify(ctx context.Context, opts VerifyOptions) (*VerifyResult, error) { + if opts.DestURI == "" { + return nil, newMigrationError( + ErrCodeInvalidInput, + "destination database connection required for verify", + "Pass --database (and --org/--branch) so verify can compare against PlanetScale Postgres", + ) + } + + sqlitePath := opts.SQLitePath + if sqlitePath == "" { + state, err := LoadState(opts.Org, opts.Database, opts.Branch, opts.MigrationID) + if err != nil { + if opts.InputPath != "" { + sqlitePath = DefaultSQLitePath(opts.InputPath) + } else { + return nil, err + } + } else { + sqlitePath = state.SQLitePath + if opts.InputPath == "" { + opts.InputPath = state.InputPath + } + } + } + if opts.InputPath == "" { + return nil, newMigrationError( + ErrCodeMissingInput, + "input dump path required for verify", + "Pass --input or run verify with a migration-id from a prior import", + ) + } + + tables, err := ParseDump(opts.InputPath) + if err != nil { + return nil, err + } + + tableNames := make([]string, 0, len(tables)) + dataTables := make([]TableSchema, 0, len(tables)) + for _, t := range tables { + if IsORMMetadataTable(t.Name) { + continue + } + tableNames = append(tableNames, t.Name) + dataTables = append(dataTables, t) + } + + sourceCounts, err := CountSQLiteRows(ctx, sqlitePath, tableNames) + if err != nil { + insertCounts, insertErr := CountInsertRows(opts.InputPath) + if insertErr != nil { + return nil, err + } + sourceCounts = mapStringIntToInt64(insertCounts) + } + + destCounts, err := CountPostgresRows(ctx, opts.DestURI, tableNames) + if err != nil { + return nil, err + } + + result := &VerifyResult{ + MigrationID: opts.MigrationID, + Matched: true, + Checks: []VerifyCheckResult{}, + } + + var rowCountsOK bool + result.Tables, rowCountsOK = verifyRowCounts(tableNames, sourceCounts, destCounts) + if !rowCountsOK { + result.Matched = false + } + + db, err := OpenPostgres(opts.DestURI) + if err != nil { + return nil, err + } + defer db.Close() + + seqChecks, ok := verifyIdentitySequences(ctx, db, dataTables) + result.Checks = append(result.Checks, seqChecks...) + if !ok { + result.Matched = false + } + + boolChecks, ok, err := verifyBooleanColumns(ctx, db, sqlitePath, dataTables) + if err != nil { + return nil, err + } + result.Checks = append(result.Checks, boolChecks...) + if !ok { + result.Matched = false + } + + fpChecks, ok, err := verifyTableFingerprints(ctx, db, sqlitePath, dataTables) + if err != nil { + return nil, err + } + result.Checks = append(result.Checks, fpChecks...) + if !ok { + result.Matched = false + } + + sampleChecks, ok, err := verifySampleRows(ctx, db, sqlitePath, dataTables, 8, 3) + if err != nil { + return nil, err + } + result.Checks = append(result.Checks, sampleChecks...) + if !ok { + result.Matched = false + } + + if !result.Matched { + return result, newMigrationError( + ErrCodeVerifyFailed, + "import verification failed (row counts, sequences, coercion, or content checks)", + "Re-run import or inspect failing checks in verify JSON output", + ) + } + + if opts.MigrationID != "" { + if err := SetMigrationPhase(opts.Org, opts.Database, opts.Branch, opts.MigrationID, PhaseVerified); err != nil { + return result, err + } + } + + return result, nil +} + +func mapStringIntToInt64(in map[string]int) map[string]int64 { + out := make(map[string]int64, len(in)) + for k, v := range in { + out[k] = int64(v) + } + return out +} + +// CountSQLiteRows counts rows using sqlite3 CLI. +func CountSQLiteRows(ctx context.Context, sqlitePath string, tables []string) (map[string]int64, error) { + sqlite3, err := FindSQLite3() + if err != nil { + return nil, err + } + + counts := make(map[string]int64, len(tables)) + for _, table := range tables { + query := fmt.Sprintf("SELECT COUNT(*) FROM %q;", table) + cmd := execabs.CommandContext(ctx, sqlite3, sqlitePath, query) + out, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("sqlite count %s: %w", table, err) + } + var count int64 + if _, err := fmt.Sscanf(string(out), "%d", &count); err != nil { + return nil, err + } + counts[table] = count + } + return counts, nil +} + +// CountPostgresRows counts rows in public schema tables. +func CountPostgresRows(ctx context.Context, destURI string, tables []string) (map[string]int64, error) { + db, err := OpenPostgres(destURI) + if err != nil { + return nil, err + } + defer db.Close() + + counts := make(map[string]int64, len(tables)) + for _, table := range tables { + var count int64 + query := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, quoteIdent(table)) + if err := db.QueryRowContext(ctx, query).Scan(&count); err != nil { + return nil, fmt.Errorf("count %s: %w", table, err) + } + counts[table] = count + } + return counts, nil +} diff --git a/internal/migrate/d1/verify_checks.go b/internal/migrate/d1/verify_checks.go new file mode 100644 index 000000000..59ce6a8cc --- /dev/null +++ b/internal/migrate/d1/verify_checks.go @@ -0,0 +1,578 @@ +package d1 + +import ( + "context" + "database/sql" + "encoding/hex" + "encoding/json" + "fmt" + "strings" + + execabs "golang.org/x/sys/execabs" +) + +// VerifyCheckResult is a single post-import verification check. +type VerifyCheckResult struct { + Name string `json:"name"` + Table string `json:"table,omitempty"` + Column string `json:"column,omitempty"` + Matched bool `json:"matched"` + Message string `json:"message,omitempty"` + Source string `json:"source,omitempty"` + Dest string `json:"dest,omitempty"` +} + +type tableFingerprint struct { + RowCount int64 + IDSum int64 +} + +type booleanDistribution struct { + TrueCount int64 + FalseCount int64 + NullCount int64 +} + +func verifyRowCounts(tableNames []string, sourceCounts, destCounts map[string]int64) ([]TableVerifyResult, bool) { + results := make([]TableVerifyResult, 0, len(tableNames)) + matched := true + for _, name := range tableNames { + ok := sourceCounts[name] == destCounts[name] + if !ok { + matched = false + } + results = append(results, TableVerifyResult{ + Table: name, + SourceRows: sourceCounts[name], + DestRows: destCounts[name], + Match: ok, + }) + } + return results, matched +} + +func verifyIdentitySequences(ctx context.Context, db *sql.DB, tables []TableSchema) ([]VerifyCheckResult, bool) { + var checks []VerifyCheckResult + matched := true + + for _, table := range tables { + if IsORMMetadataTable(table.Name) { + continue + } + for _, col := range table.Columns { + if !col.AutoIncrement { + continue + } + check, ok, err := verifyTableSequence(ctx, db, table.Name, col.Name) + if err != nil { + checks = append(checks, VerifyCheckResult{ + Name: "sequences", + Table: table.Name, + Column: col.Name, + Matched: false, + Message: err.Error(), + }) + matched = false + continue + } + checks = append(checks, check) + if !ok { + matched = false + } + } + } + return checks, matched +} + +func verifyTableSequence(ctx context.Context, db *sql.DB, table, column string) (VerifyCheckResult, bool, error) { + check := VerifyCheckResult{ + Name: "sequences", + Table: table, + Column: column, + } + + var maxID sql.NullInt64 + maxQuery := fmt.Sprintf(`SELECT MAX(%s) FROM %s`, quoteIdent(column), quoteIdent(table)) + if err := db.QueryRowContext(ctx, maxQuery).Scan(&maxID); err != nil { + return check, false, fmt.Errorf("max %s.%s: %w", table, column, err) + } + if !maxID.Valid { + check.Matched = true + check.Message = "empty table" + return check, true, nil + } + + var seqName sql.NullString + if err := db.QueryRowContext(ctx, + `SELECT pg_get_serial_sequence($1, $2)`, + "public."+table, + column, + ).Scan(&seqName); err != nil { + return check, false, fmt.Errorf("sequence lookup %s.%s: %w", table, column, err) + } + if !seqName.Valid || seqName.String == "" { + check.Matched = true + check.Message = "no sequence attached (non-identity column)" + return check, true, nil + } + + var lastValue int64 + var isCalled bool + seqQuery := fmt.Sprintf(`SELECT last_value, is_called FROM %s`, seqName.String) + if err := db.QueryRowContext(ctx, seqQuery).Scan(&lastValue, &isCalled); err != nil { + return check, false, fmt.Errorf("read sequence %s: %w", seqName.String, err) + } + + nextValue := lastValue + if isCalled { + nextValue = lastValue + 1 + } + ok := maxID.Int64 < nextValue + check.Matched = ok + check.Source = fmt.Sprintf("max=%d", maxID.Int64) + check.Dest = fmt.Sprintf("next=%d (last_value=%d is_called=%t)", nextValue, lastValue, isCalled) + if !ok { + check.Message = "sequence next value would collide with existing rows" + } else { + check.Message = "sequence ready for new inserts" + } + return check, ok, nil +} + +func verifyBooleanColumns(ctx context.Context, db *sql.DB, sqlitePath string, tables []TableSchema) ([]VerifyCheckResult, bool, error) { + var checks []VerifyCheckResult + matched := true + + for _, table := range tables { + if IsORMMetadataTable(table.Name) { + continue + } + for _, col := range table.Columns { + if !isBooleanColumn(col) { + continue + } + src, err := sqliteBooleanDistribution(ctx, sqlitePath, table.Name, col.Name) + if err != nil { + return checks, false, err + } + dest, err := postgresBooleanDistribution(ctx, db, table.Name, col.Name) + if err != nil { + return checks, false, err + } + ok := src.TrueCount == dest.TrueCount && src.FalseCount == dest.FalseCount && src.NullCount == dest.NullCount + check := VerifyCheckResult{ + Name: "boolean_columns", + Table: table.Name, + Column: col.Name, + Matched: ok, + Source: fmt.Sprintf("true=%d false=%d null=%d", src.TrueCount, src.FalseCount, src.NullCount), + Dest: fmt.Sprintf("true=%d false=%d null=%d", dest.TrueCount, dest.FalseCount, dest.NullCount), + } + if !ok { + check.Message = "boolean value distribution mismatch after import" + matched = false + } else { + check.Message = "boolean coercion matches source 0/1 distribution" + } + checks = append(checks, check) + } + } + return checks, matched, nil +} + +func sqliteBooleanDistribution(ctx context.Context, sqlitePath, table, column string) (booleanDistribution, error) { + query := fmt.Sprintf( + `SELECT SUM(CASE WHEN %q = 1 THEN 1 ELSE 0 END), SUM(CASE WHEN %q = 0 THEN 1 ELSE 0 END), SUM(CASE WHEN %q IS NULL THEN 1 ELSE 0 END) FROM %q;`, + column, column, column, table, + ) + return querySQLiteDistribution(ctx, sqlitePath, query) +} + +func postgresBooleanDistribution(ctx context.Context, db *sql.DB, table, column string) (booleanDistribution, error) { + query := fmt.Sprintf( + `SELECT COUNT(*) FILTER (WHERE %s = TRUE), COUNT(*) FILTER (WHERE %s = FALSE), COUNT(*) FILTER (WHERE %s IS NULL) FROM %s`, + quoteIdent(column), quoteIdent(column), quoteIdent(column), quoteIdent(table), + ) + var dist booleanDistribution + if err := db.QueryRowContext(ctx, query).Scan(&dist.TrueCount, &dist.FalseCount, &dist.NullCount); err != nil { + return dist, err + } + return dist, nil +} + +func querySQLiteDistribution(ctx context.Context, sqlitePath, query string) (booleanDistribution, error) { + sqlite3, err := FindSQLite3() + if err != nil { + return booleanDistribution{}, err + } + out, err := runSQLiteQuery(ctx, sqlite3, sqlitePath, query) + if err != nil { + return booleanDistribution{}, err + } + parts := parseSQLiteCLIFields(out) + if len(parts) < 3 { + return booleanDistribution{}, fmt.Errorf("unexpected boolean count output: %q", string(out)) + } + var dist booleanDistribution + for i, ptr := range []*int64{&dist.TrueCount, &dist.FalseCount, &dist.NullCount} { + if parts[i] == "" || parts[i] == "NULL" { + continue + } + if _, err := fmt.Sscanf(parts[i], "%d", ptr); err != nil { + return booleanDistribution{}, err + } + } + return dist, nil +} + +// parseSQLiteCLIFields splits sqlite3 CLI output. Multi-column results use '|'. +func parseSQLiteCLIFields(out []byte) []string { + s := strings.TrimSpace(string(out)) + if s == "" { + return nil + } + if strings.Contains(s, "|") { + parts := strings.Split(s, "|") + for i := range parts { + parts[i] = strings.TrimSpace(parts[i]) + } + return parts + } + return strings.Fields(s) +} + +func verifyTableFingerprints(ctx context.Context, db *sql.DB, sqlitePath string, tables []TableSchema) ([]VerifyCheckResult, bool, error) { + var checks []VerifyCheckResult + matched := true + + for _, table := range tables { + if IsORMMetadataTable(table.Name) { + continue + } + pkCol := identityColumn(table) + src, err := tableFingerprintFromSQLite(ctx, sqlitePath, table, pkCol, tables) + if err != nil { + return checks, false, err + } + dest, err := tableFingerprintFromPostgres(ctx, db, table, pkCol, tables) + if err != nil { + return checks, false, err + } + ok := src.RowCount == dest.RowCount && src.IDSum == dest.IDSum + check := VerifyCheckResult{ + Name: "table_fingerprint", + Table: table.Name, + Matched: ok, + Source: fmt.Sprintf("rows=%d id_sum=%d", src.RowCount, src.IDSum), + Dest: fmt.Sprintf("rows=%d id_sum=%d", dest.RowCount, dest.IDSum), + } + if !ok { + check.Message = "aggregate fingerprint mismatch" + matched = false + } else if shouldFingerprintPKSum(table, pkCol, tables) { + check.Message = "row count and integer PK sum match" + } else { + check.Message = "row count match" + } + checks = append(checks, check) + } + return checks, matched, nil +} + +func identityColumn(table TableSchema) string { + for _, col := range table.Columns { + if col.AutoIncrement { + return col.Name + } + } + for _, col := range table.Columns { + if col.PrimaryKey { + return col.Name + } + } + return "" +} + +func shouldFingerprintPKSum(table TableSchema, pkCol string, all []TableSchema) bool { + if pkCol == "" { + return false + } + col := columnByName(table, pkCol) + if col.Name == "" { + return false + } + if isUUIDColumn(col, table, all) { + return false + } + upper := strings.ToUpper(col.Type) + return col.AutoIncrement || strings.Contains(upper, "INT") +} + +func tableFingerprintFromSQLite(ctx context.Context, sqlitePath string, table TableSchema, pkCol string, all []TableSchema) (tableFingerprint, error) { + var query string + if pkCol != "" && shouldFingerprintPKSum(table, pkCol, all) { + query = fmt.Sprintf(`SELECT COUNT(*), COALESCE(SUM(CAST(%q AS INTEGER)), 0) FROM %q;`, pkCol, table.Name) + } else { + query = fmt.Sprintf(`SELECT COUNT(*), 0 FROM %q;`, table.Name) + } + sqlite3, err := FindSQLite3() + if err != nil { + return tableFingerprint{}, err + } + out, err := runSQLiteQuery(ctx, sqlite3, sqlitePath, query) + if err != nil { + return tableFingerprint{}, fmt.Errorf("sqlite fingerprint %s: %w", table.Name, err) + } + var fp tableFingerprint + fields := parseSQLiteCLIFields(out) + if len(fields) < 2 { + return tableFingerprint{}, fmt.Errorf("sqlite fingerprint %s: unexpected output %q", table.Name, string(out)) + } + if _, err := fmt.Sscanf(fields[0], "%d", &fp.RowCount); err != nil { + return tableFingerprint{}, fmt.Errorf("sqlite fingerprint %s row count: %w", table.Name, err) + } + if _, err := fmt.Sscanf(fields[1], "%d", &fp.IDSum); err != nil { + return tableFingerprint{}, fmt.Errorf("sqlite fingerprint %s id sum: %w", table.Name, err) + } + return fp, nil +} + +func tableFingerprintFromPostgres(ctx context.Context, db *sql.DB, table TableSchema, pkCol string, all []TableSchema) (tableFingerprint, error) { + var fp tableFingerprint + var query string + if pkCol != "" && shouldFingerprintPKSum(table, pkCol, all) { + query = fmt.Sprintf(`SELECT COUNT(*), COALESCE(SUM(%s::bigint), 0) FROM %s`, quoteIdent(pkCol), quoteIdent(table.Name)) + } else { + query = fmt.Sprintf(`SELECT COUNT(*), 0 FROM %s`, quoteIdent(table.Name)) + } + if err := db.QueryRowContext(ctx, query).Scan(&fp.RowCount, &fp.IDSum); err != nil { + return fp, fmt.Errorf("postgres fingerprint %s: %w", table.Name, err) + } + return fp, nil +} + +func verifySampleRows(ctx context.Context, db *sql.DB, sqlitePath string, tables []TableSchema, maxTables, samplesPerTable int) ([]VerifyCheckResult, bool, error) { + var checks []VerifyCheckResult + matched := true + checked := 0 + + for _, table := range tables { + if IsORMMetadataTable(table.Name) { + continue + } + if checked >= maxTables { + break + } + pkCol := identityColumn(table) + if pkCol == "" { + continue + } + ids, err := samplePrimaryKeys(ctx, sqlitePath, table.Name, pkCol, samplesPerTable) + if err != nil { + return checks, false, err + } + if len(ids) == 0 { + continue + } + checked++ + + for _, id := range ids { + src, err := sqliteRowSignature(ctx, sqlitePath, table, pkCol, id) + if err != nil { + return checks, false, err + } + dest, err := postgresRowSignature(ctx, db, table, pkCol, id, tables) + if err != nil { + return checks, false, err + } + ok := rowSignaturesMatch(src, dest, table, tables) + check := VerifyCheckResult{ + Name: "sample_rows", + Table: table.Name, + Column: pkCol, + Matched: ok, + Source: src, + Dest: dest, + } + if !ok { + check.Message = fmt.Sprintf("row signature mismatch for %s=%s", pkCol, id) + matched = false + } else { + check.Message = fmt.Sprintf("row signature match for %s=%s", pkCol, id) + } + checks = append(checks, check) + } + } + return checks, matched, nil +} + +func samplePrimaryKeys(ctx context.Context, sqlitePath, table, pkCol string, limit int) ([]string, error) { + sqlite3, err := FindSQLite3() + if err != nil { + return nil, err + } + query := fmt.Sprintf(`SELECT %q FROM %q ORDER BY %q LIMIT %d;`, pkCol, table, pkCol, limit) + out, err := runSQLiteQuery(ctx, sqlite3, sqlitePath, query) + if err != nil { + return nil, err + } + lines := strings.Split(strings.TrimSpace(string(out)), "\n") + var ids []string + for _, line := range lines { + line = strings.TrimSpace(line) + if line != "" { + ids = append(ids, line) + } + } + return ids, nil +} + +func sqliteSignatureColumnExpr(col ColumnSchema) string { + if isBooleanColumn(col) { + return fmt.Sprintf(`CASE WHEN %q IN (1, '1') THEN '1' WHEN %q IN (0, '0') THEN '0' ELSE '' END`, col.Name, col.Name) + } + if isJSONText(col) { + return fmt.Sprintf(`COALESCE(json(%q), CAST(%q AS TEXT), '')`, col.Name, col.Name) + } + return fmt.Sprintf(`COALESCE(CAST(%q AS TEXT), '')`, col.Name) +} + +func postgresSignatureColumnExpr(col ColumnSchema, table TableSchema, all []TableSchema) string { + pgType := sqliteTypeToPostgres(col, table, all) + switch pgType { + case "BOOLEAN": + name := quoteIdent(col.Name) + return fmt.Sprintf(`CASE WHEN %s IS TRUE THEN '1' WHEN %s IS FALSE THEN '0' ELSE '' END`, name, name) + case "TIMESTAMPTZ": + name := quoteIdent(col.Name) + return fmt.Sprintf(`COALESCE(to_char(%s AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS"Z"'), '')`, name) + case "JSONB": + name := quoteIdent(col.Name) + return fmt.Sprintf(`COALESCE(%s::jsonb::text, '')`, name) + case "BYTEA": + name := quoteIdent(col.Name) + return fmt.Sprintf(`COALESCE(convert_from(%s, 'UTF8'), '')`, name) + default: + return fmt.Sprintf(`COALESCE(%s::text, '')`, quoteIdent(col.Name)) + } +} + +func rowSignaturesMatch(src, dest string, table TableSchema, all []TableSchema) bool { + srcParts := strings.Split(src, "|") + destParts := strings.Split(dest, "|") + if len(srcParts) != len(destParts) || len(srcParts) != len(table.Columns) { + return src == dest + } + for i, col := range table.Columns { + pgType := sqliteTypeToPostgres(col, table, all) + switch pgType { + case "JSONB": + if !jsonValuesEqual(srcParts[i], destParts[i]) { + return false + } + case "BYTEA": + if !byteaValuesEqual(srcParts[i], destParts[i]) { + return false + } + default: + if srcParts[i] != destParts[i] { + return false + } + } + } + return true +} + +func jsonValuesEqual(a, b string) bool { + ca, errA := canonicalJSON(a) + cb, errB := canonicalJSON(b) + if errA != nil || errB != nil { + return a == b + } + return ca == cb +} + +func canonicalJSON(s string) (string, error) { + if s == "" { + return "", nil + } + var v any + if err := json.Unmarshal([]byte(s), &v); err != nil { + return "", err + } + b, err := json.Marshal(v) + if err != nil { + return "", err + } + return string(b), nil +} + +func byteaValuesEqual(sqliteText, pgText string) bool { + if sqliteText == pgText { + return true + } + if strings.HasPrefix(pgText, `\x`) { + decoded, err := hex.DecodeString(strings.TrimPrefix(pgText, `\x`)) + if err == nil && sqliteText == string(decoded) { + return true + } + } + return false +} + +func sqliteRowSignature(ctx context.Context, sqlitePath string, table TableSchema, pkCol, pkVal string) (string, error) { + cols := make([]string, 0, len(table.Columns)) + for _, col := range table.Columns { + cols = append(cols, sqliteSignatureColumnExpr(col)) + } + query := fmt.Sprintf( + `SELECT %s FROM %q WHERE %q = %s LIMIT 1;`, + strings.Join(cols, " || '|' || "), + table.Name, + pkCol, + sqliteLiteral(pkVal), + ) + sqlite3, err := FindSQLite3() + if err != nil { + return "", err + } + out, err := runSQLiteQuery(ctx, sqlite3, sqlitePath, query) + if err != nil { + return "", err + } + return strings.TrimSpace(string(out)), nil +} + +func postgresRowSignature(ctx context.Context, db *sql.DB, table TableSchema, pkCol, pkVal string, all []TableSchema) (string, error) { + cols := make([]string, 0, len(table.Columns)) + for _, col := range table.Columns { + cols = append(cols, postgresSignatureColumnExpr(col, table, all)) + } + query := fmt.Sprintf( + `SELECT %s FROM %s WHERE %s = $1 LIMIT 1`, + strings.Join(cols, " || '|' || "), + quoteIdent(table.Name), + quoteIdent(pkCol), + ) + var sig sql.NullString + if err := db.QueryRowContext(ctx, query, pkVal).Scan(&sig); err != nil { + return "", err + } + if !sig.Valid { + return "", fmt.Errorf("row not found in %s where %s = %s", table.Name, pkCol, pkVal) + } + return sig.String, nil +} + +func sqliteLiteral(val string) string { + var n int64 + if _, err := fmt.Sscanf(val, "%d", &n); err == nil { + return val + } + return "'" + strings.ReplaceAll(val, "'", "''") + "'" +} + +func runSQLiteQuery(ctx context.Context, sqlite3, sqlitePath, query string) ([]byte, error) { + cmd := execabs.CommandContext(ctx, sqlite3, sqlitePath, query) + return cmd.Output() +} diff --git a/internal/migrate/d1/verify_checks_test.go b/internal/migrate/d1/verify_checks_test.go new file mode 100644 index 000000000..6b694bc24 --- /dev/null +++ b/internal/migrate/d1/verify_checks_test.go @@ -0,0 +1,111 @@ +package d1 + +import ( + "encoding/hex" + "testing" +) + +func TestVerifyRowCounts(t *testing.T) { + source := map[string]int64{"users": 2, "posts": 2} + dest := map[string]int64{"users": 2, "posts": 1} + + results, ok := verifyRowCounts([]string{"users", "posts"}, source, dest) + if ok { + t.Fatal("expected mismatch") + } + if len(results) != 2 { + t.Fatalf("expected 2 table results, got %d", len(results)) + } + if !results[0].Match || results[1].Match { + t.Fatalf("unexpected match flags: %+v", results) + } +} + +func TestColumnReferencesUUIDKey(t *testing.T) { + tables, err := ParseDump(testFixture(t)) + if err != nil { + t.Fatalf("ParseDump: %v", err) + } + + var entityLinks TableSchema + for _, table := range tables { + if table.Name == "entity_links" { + entityLinks = table + break + } + } + if entityLinks.Name == "" { + t.Fatal("missing entity_links table") + } + + var entityID ColumnSchema + for _, col := range entityLinks.Columns { + if col.Name == "entity_id" { + entityID = col + break + } + } + if entityID.Name == "" { + t.Fatal("missing entity_id column") + } + if !columnReferencesUUIDKey(entityID, entityLinks, tables) { + t.Fatal("expected entity_id to reference UUID primary key") + } + if isExplicitUUIDColumn(entityID) { + t.Fatal("entity_id should not be treated as explicit UUID column") + } +} + +func TestLooksLikeRailsSchemaMigrations(t *testing.T) { + rails := TableSchema{ + Name: "schema_migrations", + Columns: []ColumnSchema{{ + Name: "version", + Type: "VARCHAR(255)", + }}, + } + if !looksLikeRailsSchemaMigrations(rails) { + t.Fatal("expected rails-like schema_migrations") + } + + appTable := TableSchema{ + Name: "schema_migrations", + Columns: []ColumnSchema{ + {Name: "id", Type: "INTEGER", PrimaryKey: true}, + {Name: "name", Type: "TEXT"}, + }, + } + if looksLikeRailsSchemaMigrations(appTable) { + t.Fatal("expected app schema_migrations to differ from rails layout") + } +} + +func TestParseSQLiteCLIFields(t *testing.T) { + got := parseSQLiteCLIFields([]byte("120|0|0\n")) + if len(got) != 3 || got[0] != "120" || got[1] != "0" || got[2] != "0" { + t.Fatalf("parseSQLiteCLIFields() = %v", got) + } + got = parseSQLiteCLIFields([]byte("94400 123456\n")) + if len(got) != 2 || got[0] != "94400" { + t.Fatalf("parseSQLiteCLIFields() = %v", got) + } +} + +func TestJSONValuesEqual(t *testing.T) { + a := `{"priority": 0, "labels": ["seed"]}` + b := `{"labels": ["seed"], "priority": 0}` + if !jsonValuesEqual(a, b) { + t.Fatal("expected equivalent JSON objects to match") + } + if jsonValuesEqual(a, `{"priority": 1}`) { + t.Fatal("expected different JSON objects to mismatch") + } +} + +func TestByteaValuesEqual(t *testing.T) { + text := "attachment-1-payload" + hex := `\x` + hex.EncodeToString([]byte(text)) + if !byteaValuesEqual(text, hex) { + t.Fatalf("expected bytea hex %q to match text %q", hex, text) + } +} diff --git a/internal/postgres/postgres.go b/internal/postgres/postgres.go new file mode 100644 index 000000000..f3c7a86c4 --- /dev/null +++ b/internal/postgres/postgres.go @@ -0,0 +1,231 @@ +// Package postgres provides PostgreSQL connection utilities. +package postgres + +import ( + "database/sql" + "fmt" + "net/url" + "strconv" + "strings" + "time" + + _ "github.com/jackc/pgx/v5/stdlib" +) + +type Config struct { + Host string + Port int + User string + Password string + Database string + SSLMode string + Options map[string]string +} + +// ParseConnectionURI supports both URI and keyword/value formats. +func ParseConnectionURI(uri string) (*Config, error) { + // Handle postgresql:// or postgres:// URIs + if strings.HasPrefix(uri, "postgresql://") || strings.HasPrefix(uri, "postgres://") { + return parseURIFormat(uri) + } + + // Handle keyword/value format (host=localhost port=5432 ...) + return parseKeyValueFormat(uri) +} + +func parseURIFormat(uri string) (*Config, error) { + u, err := url.Parse(uri) + if err != nil { + return nil, fmt.Errorf("invalid connection URI: %w", err) + } + + cfg := &Config{ + Host: u.Hostname(), + Port: 5432, + Options: make(map[string]string), + } + + if portStr := u.Port(); portStr != "" { + port, err := strconv.Atoi(portStr) + if err != nil { + return nil, fmt.Errorf("invalid port: %w", err) + } + cfg.Port = port + } + + if u.User != nil { + cfg.User = u.User.Username() + cfg.Password, _ = u.User.Password() + } + + // Database name from path without leading / + cfg.Database = strings.TrimPrefix(u.Path, "/") + + for key, values := range u.Query() { + if len(values) > 0 { + switch key { + case "sslmode": + cfg.SSLMode = values[0] + default: + cfg.Options[key] = values[0] + } + } + } + + if cfg.SSLMode == "" { + cfg.SSLMode = "require" + } + + return cfg, nil +} + +func parseKeyValueFormat(connStr string) (*Config, error) { + cfg := &Config{ + Port: 5432, + SSLMode: "require", + Options: make(map[string]string), + } + + for _, pair := range strings.Fields(connStr) { + parts := strings.SplitN(pair, "=", 2) + if len(parts) != 2 { + continue + } + key := parts[0] + value := strings.Trim(parts[1], "'\"") + + switch key { + case "host": + cfg.Host = value + case "port": + port, err := strconv.Atoi(value) + if err != nil { + return nil, fmt.Errorf("invalid port: %w", err) + } + cfg.Port = port + case "user": + cfg.User = value + case "password": + cfg.Password = value + case "dbname": + cfg.Database = value + case "sslmode": + cfg.SSLMode = value + default: + cfg.Options[key] = value + } + } + + return cfg, nil +} + +func BuildConnectionString(cfg *Config) string { + var parts []string + + if cfg.Host != "" { + parts = append(parts, fmt.Sprintf("host=%s", cfg.Host)) + } + if cfg.Port != 0 { + parts = append(parts, fmt.Sprintf("port=%d", cfg.Port)) + } + if cfg.User != "" { + parts = append(parts, fmt.Sprintf("user=%s", cfg.User)) + } + if cfg.Password != "" { + parts = append(parts, fmt.Sprintf("password=%s", quoteValue(cfg.Password))) + } + if cfg.Database != "" { + parts = append(parts, fmt.Sprintf("dbname=%s", cfg.Database)) + } + if cfg.SSLMode != "" { + parts = append(parts, fmt.Sprintf("sslmode=%s", cfg.SSLMode)) + } + + for key, value := range cfg.Options { + parts = append(parts, fmt.Sprintf("%s=%s", key, quoteValue(value))) + } + + return strings.Join(parts, " ") +} + +// BuildConnectionURI returns a postgresql:// URI suitable for pgloader. +func BuildConnectionURI(cfg *Config) string { + host := cfg.Host + if cfg.Port != 0 { + host = fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) + } + + u := &url.URL{ + Scheme: "postgresql", + Host: host, + Path: "/" + cfg.Database, + } + + if cfg.User != "" { + if cfg.Password != "" { + u.User = url.UserPassword(cfg.User, cfg.Password) + } else { + u.User = url.User(cfg.User) + } + } + + q := url.Values{} + if cfg.SSLMode != "" { + q.Set("sslmode", cfg.SSLMode) + } + for key, value := range cfg.Options { + q.Set(key, value) + } + u.RawQuery = q.Encode() + + return u.String() +} + +func quoteValue(s string) string { + if strings.ContainsAny(s, " '\"\\") { + return "'" + strings.ReplaceAll(s, "'", "\\'") + "'" + } + return s +} + +// OpenConnection opens a PostgreSQL connection with sensible defaults. +func OpenConnection(connStr string) (*sql.DB, error) { + db, err := sql.Open("pgx", connStr) + if err != nil { + return nil, fmt.Errorf("failed to open connection: %w", err) + } + + db.SetMaxOpenConns(5) + db.SetMaxIdleConns(2) + db.SetConnMaxLifetime(5 * time.Minute) + + return db, nil +} + +// QuoteIdentifier escapes a PostgreSQL identifier. +func QuoteIdentifier(name string) string { + return `"` + strings.ReplaceAll(name, `"`, `""`) + `"` +} + +func RedactPassword(connStr string) string { + if strings.HasPrefix(connStr, "postgresql://") || strings.HasPrefix(connStr, "postgres://") { + u, err := url.Parse(connStr) + if err == nil && u.User != nil { + if _, hasPass := u.User.Password(); hasPass { + u.User = url.UserPassword(u.User.Username(), "****") + return u.String() + } + } + return connStr + } + + var result []string + for _, pair := range strings.Fields(connStr) { + if strings.HasPrefix(pair, "password=") { + result = append(result, "password=****") + } else { + result = append(result, pair) + } + } + return strings.Join(result, " ") +} diff --git a/internal/postgres/postgres_test.go b/internal/postgres/postgres_test.go new file mode 100644 index 000000000..8f5803e92 --- /dev/null +++ b/internal/postgres/postgres_test.go @@ -0,0 +1,210 @@ +package postgres + +import ( + "testing" +) + +func TestParseConnectionURI(t *testing.T) { + tests := []struct { + name string + uri string + want *Config + wantErr bool + }{ + { + name: "basic uri", + uri: "postgresql://user:pass@localhost:5432/mydb", + want: &Config{ + Host: "localhost", + Port: 5432, + User: "user", + Password: "pass", + Database: "mydb", + SSLMode: "require", + Options: make(map[string]string), + }, + }, + { + name: "uri with sslmode", + uri: "postgresql://user:pass@localhost:5432/mydb?sslmode=disable", + want: &Config{ + Host: "localhost", + Port: 5432, + User: "user", + Password: "pass", + Database: "mydb", + SSLMode: "disable", + Options: make(map[string]string), + }, + }, + { + name: "uri without password", + uri: "postgresql://user@localhost:5432/mydb", + want: &Config{ + Host: "localhost", + Port: 5432, + User: "user", + Password: "", + Database: "mydb", + SSLMode: "require", + Options: make(map[string]string), + }, + }, + { + name: "key-value format", + uri: "host=localhost port=5432 dbname=mydb", + want: &Config{ + Host: "localhost", + Port: 5432, + Database: "mydb", + SSLMode: "require", + Options: make(map[string]string), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseConnectionURI(tt.uri) + if (err != nil) != tt.wantErr { + t.Errorf("ParseConnectionURI() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + if got.Host != tt.want.Host { + t.Errorf("Host = %v, want %v", got.Host, tt.want.Host) + } + if got.Port != tt.want.Port { + t.Errorf("Port = %v, want %v", got.Port, tt.want.Port) + } + if got.User != tt.want.User { + t.Errorf("User = %v, want %v", got.User, tt.want.User) + } + if got.Password != tt.want.Password { + t.Errorf("Password = %v, want %v", got.Password, tt.want.Password) + } + if got.Database != tt.want.Database { + t.Errorf("Database = %v, want %v", got.Database, tt.want.Database) + } + if got.SSLMode != tt.want.SSLMode { + t.Errorf("SSLMode = %v, want %v", got.SSLMode, tt.want.SSLMode) + } + }) + } +} + +func TestBuildConnectionString(t *testing.T) { + tests := []struct { + name string + cfg *Config + want string + }{ + { + name: "basic config", + cfg: &Config{ + Host: "localhost", + Port: 5432, + User: "user", + Password: "pass", + Database: "mydb", + SSLMode: "require", + Options: make(map[string]string), + }, + want: "host=localhost port=5432 user=user password=pass dbname=mydb sslmode=require", + }, + { + name: "config without password", + cfg: &Config{ + Host: "localhost", + Port: 5432, + User: "user", + Database: "mydb", + SSLMode: "disable", + Options: make(map[string]string), + }, + want: "host=localhost port=5432 user=user dbname=mydb sslmode=disable", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := BuildConnectionString(tt.cfg) + if got != tt.want { + t.Errorf("BuildConnectionString() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestRedactPassword(t *testing.T) { + tests := []struct { + name string + connStr string + want string + }{ + { + name: "with password", + connStr: "host=localhost port=5432 user=user password=secret dbname=mydb", + want: "host=localhost port=5432 user=user password=**** dbname=mydb", + }, + { + name: "without password", + connStr: "host=localhost port=5432 user=user dbname=mydb", + want: "host=localhost port=5432 user=user dbname=mydb", + }, + { + name: "empty string", + connStr: "", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := RedactPassword(tt.connStr) + if got != tt.want { + t.Errorf("RedactPassword() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestQuoteIdentifier(t *testing.T) { + tests := []struct { + name string + id string + want string + }{ + { + name: "simple identifier", + id: "mytable", + want: `"mytable"`, + }, + { + name: "identifier with quotes", + id: `table"name`, + want: `"table""name"`, + }, + { + name: "identifier with multiple quotes", + id: `my"table"name`, + want: `"my""table""name"`, + }, + { + name: "empty string", + id: "", + want: `""`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := QuoteIdentifier(tt.id) + if got != tt.want { + t.Errorf("QuoteIdentifier() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/postgres/psql.go b/internal/postgres/psql.go new file mode 100644 index 000000000..3149841be --- /dev/null +++ b/internal/postgres/psql.go @@ -0,0 +1,67 @@ +package postgres + +import ( + "fmt" + "regexp" + "strconv" + "strings" + + exec "golang.org/x/sys/execabs" +) + +var psqlVersionRegex = regexp.MustCompile(`psql \(PostgreSQL\) (\d+)\.?(\d*)`) + +// FindPsqlPath locates a PostgreSQL psql client on PATH. +func FindPsqlPath() (string, error) { + for _, cmd := range []string{"psql-18", "psql-17", "psql-16", "psql-15", "psql"} { + path, err := exec.LookPath(cmd) + if err != nil { + continue + } + c := exec.Command(path, "--version") + out, err := c.Output() + if err != nil { + continue + } + if strings.Contains(string(out), "PostgreSQL") { + return path, nil + } + } + + return "", fmt.Errorf("couldn't find the 'psql' command-line tool required for PostgreSQL imports.\n" + + "To install, run: brew install postgresql@18") +} + +// CheckPsqlVersion verifies psql meets a minimum major version. +func CheckPsqlVersion(minMajor int) (major, minor int, err error) { + path, err := FindPsqlPath() + if err != nil { + return 0, 0, err + } + + c := exec.Command(path, "--version") + out, err := c.Output() + if err != nil { + return 0, 0, fmt.Errorf("failed to get psql version: %w", err) + } + + matches := psqlVersionRegex.FindStringSubmatch(string(out)) + if len(matches) < 2 { + return 0, 0, fmt.Errorf("could not parse psql version from: %s", string(out)) + } + + major, err = strconv.Atoi(matches[1]) + if err != nil { + return 0, 0, fmt.Errorf("could not parse psql major version: %w", err) + } + + if len(matches) > 2 && matches[2] != "" { + minor, _ = strconv.Atoi(matches[2]) + } + + if major < minMajor { + return major, minor, fmt.Errorf("psql version %d.%d is too old, minimum required is %d", major, minor, minMajor) + } + + return major, minor, nil +} diff --git a/script/d1-import-test/README.md b/script/d1-import-test/README.md new file mode 100644 index 000000000..f4126b11b --- /dev/null +++ b/script/d1-import-test/README.md @@ -0,0 +1,160 @@ +# D1 `import-test` stress database + +Synthetic schema + seed data for exercising `pscale import d1`. + +## CLI import test (recommended) + +Use **`run-cli-import.sh`** to exercise the full `pscale import d1` pipeline against an export file that is already on disk. This is the script to use for import timing and verification. + +**Smoke (~2 MB export):** + +```bash +wrangler d1 export import-test --remote --output /tmp/import-test-export.sql +IMPORT_PROFILE=smoke SKIP_DB_CREATE=true ./script/d1-import-test/run-cli-import.sh +``` + +**9 GB export (fresh database each run — default):** + +```bash +IMPORT_PROFILE=9gb ./script/d1-import-test/run-local-import.sh +``` + +**Reuse same DB name but wipe it first:** + +```bash +FRESH_DB=recreate PSCALE_DB=cf-d1-import-9gb ./script/d1-import-test/run-local-import.sh +``` + +**Provision DB + CLI import:** + +```bash +IMPORT_PROFILE=smoke ./script/d1-import-test/run-local-import.sh +``` + +| Variable | Default | Purpose | +|----------|---------|---------| +| `IMPORT_PROFILE` | `smoke` | `smoke` or `9gb` (sets export path + DB name) | +| `D1_EXPORT` | profile default | Path to wrangler SQL export | +| `PSCALE_DB` | profile default | Target PlanetScale Postgres database | +| `PSCALE_ORG` | `bb` | Organization | +| `IMPORT_METHOD` | `pgloader` | `pgloader` or `psql` | +| `IMPORT_RUN_DIR` | `/tmp/d1-cli-import-…` | JSON artifacts (preview/start/verify) | +| `FRESH_DB` | `new` | `new` (timestamped DB), `recreate` (delete+create same name), or `reuse` (reset public schema in place) | +| `SKIP_DB_CREATE` | — | Deprecated; use `FRESH_DB` instead | + +The start JSON includes `data.timings` (total, schema, pgloader per-table) when built from current `pscale-test`. + +## D1 dataset prep (separate from CLI import) + +Load remote D1 and/or time wrangler export — **not** the PlanetScale import: + +```bash +./script/d1-import-test/load-bulk.sh +./script/d1-import-test/time-export.sh + +# Optional: D1 prep then CLI import in one go +RUN_CLI_IMPORT=true SKIP_DB_CREATE=true ./script/d1-import-test/run-9gb-benchmark.sh +``` + +## Load (remote) + +**Smoke / small seeds** — one `wrangler d1 execute` per batch file: + +```bash +./script/d1-import-test/load.sh +``` + +**Multi‑MB / ~9 GB seeds** — merge batches into ~50 MB chunks first (Option B): + +```bash +./script/d1-import-test/load-bulk.sh +# or fresh 9 GB target: +D1_FRESH=true SEED_TARGET_GB=9 ./script/d1-import-test/load-bulk.sh +``` + +`load-bulk.sh` runs `generate_seed.py`, then `merge_seed_chunks.py` (concatenates complete SQL statements into `seed/chunks/chunk_*.sql`), then one `wrangler d1 execute --file` per chunk. D1 is single-threaded per database, so chunks upload sequentially, but ~180 wrangler calls for 9 GB beats ~188k per-row files. + +| Variable | Default | Purpose | +|----------|---------|---------| +| `CHUNK_TARGET_MB` | `50` | Target size per merged chunk file | +| `SEED_MAX_STATEMENT_BYTES` | `100000` | D1 max per SQL statement (do not exceed) | + +Set `D1_DATABASE=import-test` and `D1_REMOTE=true` (default). + +### Volume + +By default the seed generator targets **~9 GB** of data (under D1's 10 GB cap), mostly via 1 MiB attachment blobs. + +Quick local smoke test (~2 MB): + +```bash +SEED_TARGET_GB=0.002 ./script/d1-import-test/load.sh +``` + +Moderate import e2e (~30 MB, D1-safe blob batches): + +```bash +SEED_TARGET_GB=0.03 ./script/d1-import-test/load.sh +``` + +Blob INSERT batches are capped at ~100 KB per statement (`SEED_MAX_STATEMENT_BYTES`) to avoid D1 `SQLITE_TOOBIG`. + +Tune with: + +| Variable | Default | Purpose | +|----------|---------|---------| +| `SEED_TARGET_GB` | `9` | Approximate total DB size | +| `SEED_PAYLOAD_BYTES` | `1048576` | Attachment blob size (bytes) | +| `SEED_RESERVED_BYTES` | `536870912` | Headroom for non-blob rows | +| `SEED_TASKS_PER_PROJECT` | `16` | Tasks per project | +| `SEED_PROJECTS_PER_ORG` | `30` | Projects per org | + +A full ~9 GB remote load: use `load-bulk.sh` (~200 chunk uploads). Avoid `load.sh` at that scale (one wrangler call per batch file). + +## Reload from scratch + +```bash +wrangler d1 execute import-test --remote --file=script/d1-import-test/reset.sql +./script/d1-import-test/load.sh +``` + +## Export + lint + preview + +```bash +wrangler d1 export import-test --remote --output /tmp/import-test-export.sql +pscale import d1 lint --input /tmp/import-test-export.sql --format json +pscale import d1 start --input /tmp/import-test-export.sql --org ... --database ... --branch ... --dry-run --force --format json +``` + +The dry-run returns a `migration_id` and full import plan without loading Postgres. Run `start` again without `--dry-run` to import. + +## Schema coverage (31 tables) + +| Feature | Tables / columns | +|---------|------------------| +| Autoincrement PKs | most application tables | +| TEXT/UUID primary keys | `external_entities.id` | +| UUID foreign keys | `entity_links.entity_id` (table-level FK) | +| 0/1 booleans | `is_active`, `is_admin`, `is_public`, `is_read`, … | +| TEXT timestamps | `created_at`, `updated_at`, `due_at`, … | +| JSON in TEXT | `settings`, `metadata`, `profile_json`, `payload`, … | +| Multi-level FKs | org → user → project → task → comment → attachment | +| Junction tables | `team_members`, `project_tags`, `task_dependencies`, `entity_links` | +| Table-level FKs | composite PK tables, `entity_links`, … | +| Self-referential FKs | `tasks.parent_task_id`, `comments.parent_comment_id` | +| REAL columns | `invoices.subtotal`, `invoice_line_items.unit_price` | +| BLOB columns | `attachments.payload` (primary bulk data for large targets) | +| CHECK constraints | `tasks.status`, `tasks.priority` | +| UNIQUE constraints | column + composite (`projects`, `tags`, `external_entities`) | +| Secondary indexes | 12 indexes including composite + unique | +| ORM migration metadata | `__drizzle_*`, `_prisma_migrations`, Knex, Sequelize, Rails, Flyway, Liquibase, Django, Alembic, TypeORM, Goose (**skipped on import**) | + +Default relational seed (before blob budget): 25 users/org, 40 projects/org, 20 tasks/project, 2 comments/task. + +FTS5 virtual tables are **not** included — they cause `ParseDump` to fail on export. + +## Default seed volume (~9 GB) + +~8.5k attachments × 1 MiB blobs + relational rows. See `seed/SUMMARY.json` after generation. + +Edit `SEED_TARGET_GB` or constants in `generate_seed.py` to scale (respect D1 10 GB cap). diff --git a/script/d1-import-test/bench-watch-imports.sh b/script/d1-import-test/bench-watch-imports.sh new file mode 100755 index 000000000..97f96f97e --- /dev/null +++ b/script/d1-import-test/bench-watch-imports.sh @@ -0,0 +1,200 @@ +#!/usr/bin/env bash +# Watch a benchmark state dir and start imports as soon as export + DB are ready. +# +# Usage: +# BENCH_STATE_DIR=/tmp/d1-bench-20260625-204823 ./script/d1-import-test/bench-watch-imports.sh +# SIZES="5 9" BENCH_STATE_DIR=... ./script/d1-import-test/bench-watch-imports.sh +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +CLI="$(cd "$ROOT/../.." && pwd)" +STATE_DIR="${BENCH_STATE_DIR:-}" +POLL_SEC="${POLL_SEC:-15}" +SIZES=(${SIZES:-1 5 9}) +LOG="${BENCH_WATCH_LOG:-${STATE_DIR}/watch.log}" + +if [[ -z "$STATE_DIR" || ! -d "$STATE_DIR" ]]; then + echo "ERROR: BENCH_STATE_DIR must point to an existing state directory" >&2 + exit 1 +fi + +export PSCALE_DISABLE_DEV_WARNING=true +export PSCALE_TEST_MODE=1 +ORG="${PSCALE_ORG:-bb}" +export PSCALE_ORG="$ORG" + +PSCALE="${CLI}/pscale-test" +API_URL="${PSCALE_API_URL:-http://api.pscaledev.com:3000/v1}" +BRANCH="${PSCALE_BRANCH:-main}" + +mkdir -p "$STATE_DIR" +exec >> "$LOG" 2>&1 + +echo "==============================================" +echo "Import watcher started $(date -u +%Y-%m-%dT%H:%M:%SZ)" +echo "State: $STATE_DIR" +echo "Sizes: ${SIZES[*]}" +echo "Poll: ${POLL_SEC}s" +echo "==============================================" + +export_file_for() { + local size="$1" + case "$size" in + 9) echo "${D1_EXPORT_9GB:-/tmp/import-test-9gb-export.sql}" ;; + *) echo "/tmp/import-test-${size}gb-export.sql" ;; + esac +} + +db_name_for() { + local size="$1" + cat "$STATE_DIR/db-${size}gb.name" +} + +branch_ready() { + local db="$1" + local ready + ready="$("$PSCALE" --api-url "$API_URL" branch show "$db" "$BRANCH" --format json --org "$ORG" 2>/dev/null | python3 -c "import json,sys; print(json.load(sys.stdin).get('ready', False))" 2>/dev/null || echo False)" + [[ "$ready" == "True" ]] +} + +export_ready() { + local size="$1" + [[ -f "$STATE_DIR/export-${size}gb.ready" ]] +} + +import_done() { + local size="$1" + [[ -f "$STATE_DIR/import-${size}gb.done" ]] +} + +import_failed() { + local size="$1" + [[ -f "$STATE_DIR/import-${size}gb.failed" ]] +} + +import_running() { + local size="$1" + local pid_file="$STATE_DIR/import-${size}gb.pid" + if [[ ! -f "$pid_file" ]]; then + return 1 + fi + local pid + pid="$(cat "$pid_file")" + kill -0 "$pid" 2>/dev/null +} + +start_import() { + local size="$1" + local db export run_dir profile log pid + + db="$(db_name_for "$size")" + export="$(export_file_for "$size")" + profile="${size}gb" + run_dir="$STATE_DIR/import-${size}gb-$(date +%Y%m%d-%H%M%S)" + log="$STATE_DIR/import-${size}gb.log" + + if [[ ! -f "$export" ]]; then + echo "==> [watch ${size}gb] export file missing: $export" + return 1 + fi + + echo "==> [watch ${size}gb] Starting import db=$db export=$export" + touch "$STATE_DIR/import-${size}gb.started" + + ( + set -euo pipefail + start=$(date +%s) + if ! IMPORT_PROFILE="$profile" \ + D1_EXPORT="$export" \ + PSCALE_DB="$db" \ + IMPORT_RUN_DIR="$run_dir" \ + "$ROOT/run-cli-import.sh"; then + touch "$STATE_DIR/import-${size}gb.failed" + echo "==> [watch ${size}gb] Import FAILED" + exit 1 + fi + end=$(date +%s) + wall=$((end - start)) + echo "${size}gb:${wall}s:${run_dir}" >> "$STATE_DIR/results.txt" + touch "$STATE_DIR/import-${size}gb.done" + echo "==> [watch ${size}gb] Import complete: ${wall}s" + ) > "$log" 2>&1 & + pid=$! + echo "$pid" > "$STATE_DIR/import-${size}gb.pid" +} + +remaining=0 +for size in "${SIZES[@]}"; do + if ! import_done "$size" && ! import_failed "$size"; then + remaining=$((remaining + 1)) + fi +done + +while (( remaining > 0 )); do + for size in "${SIZES[@]}"; do + if import_done "$size" || import_failed "$size" || import_running "$size"; then + continue + fi + + db="$(db_name_for "$size")" + if export_ready "$size" && branch_ready "$db"; then + if start_import "$size"; then + : + else + touch "$STATE_DIR/import-${size}gb.failed" + remaining=$((remaining - 1)) + fi + else + exp="no"; br="no" + export_ready "$size" && exp="yes" + branch_ready "$db" && br="yes" + echo "==> [watch ${size}gb] waiting (export=$exp branch=$br db=$db)" + fi + done + + remaining=0 + for size in "${SIZES[@]}"; do + if import_done "$size"; then + continue + fi + if import_failed "$size"; then + continue + fi + if import_running "$size"; then + remaining=$((remaining + 1)) + continue + fi + remaining=$((remaining + 1)) + done + + if (( remaining > 0 )); then + sleep "$POLL_SEC" + fi +done + +echo "" +echo "==> Waiting for running imports to finish" +wait_fail=0 +for size in "${SIZES[@]}"; do + pid_file="$STATE_DIR/import-${size}gb.pid" + [[ -f "$pid_file" ]] || continue + pid="$(cat "$pid_file")" + if kill -0 "$pid" 2>/dev/null; then + if ! wait "$pid"; then + touch "$STATE_DIR/import-${size}gb.failed" + wait_fail=1 + fi + fi +done + +echo "" +echo "==============================================" +echo "Import watcher finished $(date -u +%Y-%m-%dT%H:%M:%SZ)" +if [[ -f "$STATE_DIR/results.txt" ]]; then + cat "$STATE_DIR/results.txt" +fi +if [[ "$wait_fail" -ne 0 ]]; then + echo "ERROR: one or more imports failed" >&2 + exit 1 +fi +echo "==============================================" diff --git a/script/d1-import-test/build-local-export.sh b/script/d1-import-test/build-local-export.sh new file mode 100755 index 000000000..1d3065d3f --- /dev/null +++ b/script/d1-import-test/build-local-export.sh @@ -0,0 +1,104 @@ +#!/usr/bin/env bash +# Build a wrangler-style SQL export locally (schema + generated seed). No D1/wrangler. +# +# Usage: +# ./script/d1-import-test/build-local-export.sh 1 +# SEED_DIR=/tmp/seed-5gb SEED_TARGET_GB=5 ./script/d1-import-test/build-local-export.sh 5 +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SIZE_GB="${1:-${SEED_TARGET_GB:-1}}" +OUT="${D1_EXPORT:-/tmp/import-test-${SIZE_GB}gb-export.sql}" +SEED_DIR="${SEED_DIR:-$ROOT/seed/local-${SIZE_GB}gb}" +REGENERATE="${REGENERATE_SEED:-true}" + +if [[ ! "$SIZE_GB" =~ ^[0-9]+([.][0-9]+)?$ ]]; then + echo "ERROR: size must be a number (GB), got: $SIZE_GB" >&2 + exit 1 +fi + +seed_order=( + organizations + users + teams + team_members + projects + tags + project_tags + tasks + task_dependencies + comments + attachments + audit_log + api_keys + sessions + notifications + invoices + line_items + external_entities + entity_links +) + +echo "==> [export ${SIZE_GB}gb] Building local SQL export (~${SIZE_GB} GB target)" +echo " output: $OUT" +echo " seed dir: $SEED_DIR" + +mkdir -p "$SEED_DIR" + +if [[ "$REGENERATE" == "true" ]]; then + echo "==> [export ${SIZE_GB}gb] Generating seed" + gen_start=$(date +%s) + find "$SEED_DIR" -maxdepth 1 -type f -name '*.sql' -delete + rm -f "$SEED_DIR/SUMMARY.json" + SEED_DIR="$SEED_DIR" SEED_TARGET_GB="$SIZE_GB" python3 "$ROOT/generate_seed.py" + gen_end=$(date +%s) + echo "==> [export ${SIZE_GB}gb] Seed generation: $((gen_end - gen_start))s" +fi + +if [[ ! -f "$SEED_DIR/SUMMARY.json" ]]; then + echo "ERROR: seed generation did not produce $SEED_DIR/SUMMARY.json" >&2 + exit 1 +fi + +python3 - "$SEED_DIR/SUMMARY.json" "$SIZE_GB" <<'PY' +import json, sys +summary = json.load(open(sys.argv[1])) +target_gb = float(sys.argv[2]) +target_bytes = int(target_gb * 1024**3) +actual = int(summary.get("estimated_blob_bytes", 0)) +min_bytes = int(target_bytes * 0.99) +if actual < min_bytes: + print( + f"ERROR: seed blob storage {actual} bytes ({actual/1024**3:.3f} GB) " + f"below target {target_bytes} bytes ({target_gb} GB)", + file=sys.stderr, + ) + sys.exit(1) +print(f"seed ok: {actual/1024**3:.3f} GB blob payload (target {target_gb} GB)") +PY + +tmp="${OUT}.tmp.$$" +{ + echo "PRAGMA foreign_keys=OFF;" + echo "-- Local D1-style export generated $(date -u +%Y-%m-%dT%H:%M:%SZ)" + echo "-- Target size: ~${SIZE_GB} GB" + echo + grep -v '^PRAGMA foreign_keys' "$ROOT/schema.sql" + echo + for prefix in "${seed_order[@]}"; do + mapfile -t files < <(find "$SEED_DIR" -maxdepth 1 -name "${prefix}_*.sql" | sort) + for file in "${files[@]}"; do + [[ -f "$file" ]] || continue + cat "$file" + echo + done + done +} > "$tmp" + +mv "$tmp" "$OUT" +bytes=$(stat -f%z "$OUT" 2>/dev/null || stat -c%s "$OUT") +gb=$(python3 -c "print(round($bytes / (1024**3), 3))") +echo "==> [export ${SIZE_GB}gb] Wrote $OUT (${gb} GB)" +if [[ -n "${EXPORT_READY_FILE:-}" ]]; then + touch "$EXPORT_READY_FILE" +fi diff --git a/script/d1-import-test/collect-benchmark-results.sh b/script/d1-import-test/collect-benchmark-results.sh new file mode 100755 index 000000000..7475c50fa --- /dev/null +++ b/script/d1-import-test/collect-benchmark-results.sh @@ -0,0 +1,79 @@ +#!/usr/bin/env bash +# Collect import timings and Postgres storage sizes into a report. +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +CLI="$(cd "$ROOT/../.." && pwd)" +STATE_DIR="${1:-}" +REPORT="${2:-${STATE_DIR}/report.txt}" + +if [[ -z "$STATE_DIR" || ! -d "$STATE_DIR" ]]; then + echo "ERROR: state dir required" >&2 + exit 1 +fi + +export PSCALE_DISABLE_DEV_WARNING=true +export PSCALE_TEST_MODE=1 +export PSCALE_ALLOW_NONINTERACTIVE_SHELL=1 +PSCALE="${CLI}/pscale-test" +API_URL="${PSCALE_API_URL:-http://api.pscaledev.com:3000/v1}" +ORG="${PSCALE_ORG:-bb}" + +{ + echo "D1 import storage benchmark report" + echo "Generated: $(date -u +%Y-%m-%dT%H:%M:%SZ)" + echo "State dir: $STATE_DIR" + echo "" + + for size in 1 5 9; do + db_file="$STATE_DIR/db-${size}gb.name" + [[ -f "$db_file" ]] || continue + db="$(cat "$db_file")" + echo "=== ${size} GB ===" + echo "database: $db" + + if [[ -f "$STATE_DIR/import-${size}gb.done" ]]; then + echo "import: SUCCESS" + elif [[ -f "$STATE_DIR/import-${size}gb.failed" ]]; then + echo "import: FAILED" + else + echo "import: incomplete" + fi + + grep "^${size}gb:" "$STATE_DIR/results.txt" 2>/dev/null || true + + run_dir="$(grep "^${size}gb:" "$STATE_DIR/results.txt" 2>/dev/null | cut -d: -f3- || true)" + if [[ -n "$run_dir" && -f "$run_dir/start.json" ]]; then + python3 - "$run_dir/start.json" <<'PY' +import json, sys +raw = open(sys.argv[1]).read() +idx = raw.rfind('{"status"') +if idx < 0: + sys.exit(0) +d = json.loads(raw[idx:]) +t = (d.get("data") or {}).get("timings") or {} +for k in ["total_ms", "schema_ms", "pgloader_ms", "index_build_ms", "sequence_reset_ms"]: + v = t.get(k) + if v is not None: + print(f" {k}: {v/1000:.1f}s") +loads = t.get("table_loads") or [] +if loads: + att = next((x for x in loads if x.get("table") == "attachments"), None) + if att: + print(f" attachments_pgloader: {att['ms']/1000:.1f}s") +PY + fi + + if echo "SELECT 1" | "$PSCALE" --api-url "$API_URL" shell "$db" main --org "$ORG" >/dev/null 2>&1; then + echo " postgres storage:" + echo "SELECT pg_size_pretty(sum(octet_length(payload))) AS attachments_payload, + pg_size_pretty(sum(pg_total_relation_size(format('%I.%I', schemaname, tablename)::regclass))) AS public_tables_on_disk, + pg_size_pretty(pg_database_size(current_database())) AS pg_database_size, + (SELECT count(*) FROM attachments) AS attachment_rows;" | \ + "$PSCALE" --api-url "$API_URL" shell "$db" main --org "$ORG" 2>/dev/null | grep -v "^$" | tail -4 | sed 's/^/ /' + fi + echo "" + done +} > "$REPORT" + +cat "$REPORT" diff --git a/script/d1-import-test/generate_seed.py b/script/d1-import-test/generate_seed.py new file mode 100755 index 000000000..29613273e --- /dev/null +++ b/script/d1-import-test/generate_seed.py @@ -0,0 +1,448 @@ +#!/usr/bin/env python3 +"""Generate batched seed SQL for the import-test D1 database.""" + +from __future__ import annotations + +import json +import os +import textwrap +import uuid +from datetime import datetime, timedelta, timezone +from pathlib import Path + +OUT_DIR = Path(os.environ.get("SEED_DIR", str(Path(__file__).resolve().parent / "seed"))) +DB_NAME = "import-test" + +# Target Postgres logical storage (bytes in bytea payloads + relational overhead). +# SEED_TARGET_GB=1 means ~1 GiB of attachment payload bytes land in Postgres. +TARGET_GB = float(os.environ.get("SEED_TARGET_GB", "9")) +TARGET_BYTES = int(TARGET_GB * 1024**3) +RESERVED_BYTES = int(os.environ.get("SEED_RESERVED_BYTES", str(64 * 1024**2))) # relational row headroom +# D1 rejects very large single statements (SQLITE_TOOBIG); hex blobs are ~2x raw size. +D1_MAX_STATEMENT_BYTES = int(os.environ.get("SEED_MAX_STATEMENT_BYTES", "100000")) + + +def payload_size() -> int: + if "SEED_PAYLOAD_BYTES" in os.environ: + requested = int(os.environ["SEED_PAYLOAD_BYTES"]) + elif TARGET_BYTES < 32 * 1024 * 1024: + requested = 16 * 1024 + else: + requested = 1024 * 1024 + # Hex literals are ~2x raw bytes; keep statements under D1_MAX_STATEMENT_BYTES. + max_raw = max((D1_MAX_STATEMENT_BYTES - 4096) // 2, 4096) + return min(requested, max_raw) + + +PAYLOAD_BYTES = payload_size() + +BATCH_SIZE = 50 +MAX_STATEMENT_BYTES = 90_000 +LARGE_ROW_BYTES = 32_000 + +USERS_PER_ORG = int(os.environ.get("SEED_USERS_PER_ORG", "25")) +TEAMS_PER_ORG = int(os.environ.get("SEED_TEAMS_PER_ORG", "5")) +TASKS_PER_PROJECT = int(os.environ.get("SEED_TASKS_PER_PROJECT", "20")) +COMMENTS_PER_TASK = int(os.environ.get("SEED_COMMENTS_PER_TASK", "2")) +PROJECTS_PER_ORG = int(os.environ.get("SEED_PROJECTS_PER_ORG", "40")) +INVOICES_PER_ORG = int(os.environ.get("SEED_INVOICES_PER_ORG", "5")) +NOTIFICATIONS_PER_USER = int(os.environ.get("SEED_NOTIFICATIONS_PER_USER", "3")) + + +def esc(value: str | None) -> str: + if value is None: + return "NULL" + return "'" + value.replace("'", "''") + "'" + + +def ts(base: datetime, minutes: int) -> str: + return (base + timedelta(minutes=minutes)).replace(tzinfo=timezone.utc).isoformat().replace("+00:00", "Z") + + +def payload_bytes(attachment_id: int) -> bytes: + prefix = f"attachment-{attachment_id}-".encode() + fill = max(PAYLOAD_BYTES - len(prefix), 0) + return prefix + (b"x" * fill) + + +def blob_literal(data: bytes) -> str: + return "X'" + data.hex() + "'" + + +def write_batches(name: str, header: str, rows: list[str], *, max_bytes: int = MAX_STATEMENT_BYTES) -> None: + OUT_DIR.mkdir(parents=True, exist_ok=True) + batch: list[str] = [] + batch_bytes = len(header.encode()) + file_idx = 0 + + def flush() -> None: + nonlocal batch, batch_bytes, file_idx + if not batch: + return + path = OUT_DIR / f"{name}_{file_idx:03d}.sql" + path.write_text(header + ",\n".join(batch) + ";\n", encoding="utf-8") + file_idx += 1 + batch = [] + batch_bytes = len(header.encode()) + + for row in rows: + row_bytes = len(row.encode()) + 2 + limit = max_bytes + if row_bytes > LARGE_ROW_BYTES: + limit = max(row_bytes + 2, limit) + if batch and batch_bytes + row_bytes > limit: + flush() + batch.append(row) + batch_bytes += row_bytes + flush() + + +def derive_volume() -> dict[str, int]: + # Full target goes to blob payloads — this is what shows up as Postgres storage. + blob_budget = TARGET_BYTES + attachment_target = max(blob_budget // max(PAYLOAD_BYTES, 1), 1) + + return { + "users_per_org": USERS_PER_ORG, + "teams_per_org": TEAMS_PER_ORG, + "projects_per_org": PROJECTS_PER_ORG, + "tasks_per_project": TASKS_PER_PROJECT, + "comments_per_task": COMMENTS_PER_TASK, + "attachment_target": attachment_target, + "target_bytes": TARGET_BYTES, + "payload_bytes": PAYLOAD_BYTES, + "reserved_bytes": RESERVED_BYTES, + } + + +def main() -> None: + if OUT_DIR.exists(): + for child in OUT_DIR.glob("*.sql"): + child.unlink() + else: + OUT_DIR.mkdir(parents=True) + + volume = derive_volume() + attachment_target = volume["attachment_target"] + + base = datetime(2024, 6, 1, 12, 0, 0, tzinfo=timezone.utc) + + bootstrap = OUT_DIR / "000_bootstrap.sql" + bootstrap.write_text( + textwrap.dedent( + f""" + INSERT INTO __drizzle_migrations (id, hash, created_at) VALUES + (1, '001_initial', 1710000000), + (2, '002_projects', 1710500000), + (3, '003_external_entities', 1711000000); + + INSERT INTO _prisma_migrations (id, checksum, finished_at, migration_name, logs, rolled_back_at, started_at, applied_steps_count) VALUES + ('seed-001', 'abc123', '{ts(base, 1)}', '20240101000000_init', NULL, NULL, '{ts(base, 0)}', 1); + + INSERT INTO knex_migrations (id, name, batch, migration_time) VALUES + (1, '001_initial.js', 1, 1710000000000); + + INSERT INTO knex_migrations_lock ("index", is_locked) VALUES + (1, 0); + + INSERT INTO sequelizemeta (name) VALUES + ('20240101000000-init.js'); + + INSERT INTO schema_migrations (version) VALUES + ('20240101000000'); + + INSERT INTO ar_internal_metadata (key, value, created_at, updated_at) VALUES + ('environment', 'production', '{ts(base, 0)}', '{ts(base, 0)}'); + + INSERT INTO flyway_schema_history (installed_rank, version, description, type, script, checksum, installed_by, installed_on, execution_time, success) VALUES + (1, '1', 'initial', 'SQL', 'V1__initial.sql', 12345, 'seed', '{ts(base, 0)}', 42, 1); + + INSERT INTO databasechangelog (id, author, filename, dateexecuted, orderexecuted, exectype, md5sum, description, comments, tag, liquibase, contexts, labels, deployment_id) VALUES + ('seed-001', 'seed', 'db/changelog/001.xml', '{ts(base, 0)}', 1, 'EXECUTED', 'abc', 'initial', NULL, NULL, '4.0', NULL, NULL, 'seed'); + + INSERT INTO databasechangeloglock (id, locked, lockgranted, lockedby) VALUES + (1, 0, NULL, NULL); + + INSERT INTO django_migrations (id, app, name, applied) VALUES + (1, 'core', '0001_initial', '{ts(base, 0)}'); + + INSERT INTO alembic_version (version_num) VALUES + ('001_initial'); + + INSERT INTO typeorm_metadata (type, "database", "schema", "table", name, value) VALUES + ('seed', 'import-test', 'main', 'organizations', 'version', '1'); + + INSERT INTO goose_db_version (id, version_id, is_applied, tstamp) VALUES + (1, 1, 1, '{ts(base, 0)}'); + + INSERT INTO organizations (id, slug, name, plan, is_active, settings, created_at, updated_at, deleted_at) VALUES + (1, 'acme', 'Acme Corp', 'enterprise', 1, '{json.dumps({"timezone": "UTC", "features": ["audit", "sso"]})}', '{ts(base, 0)}', '{ts(base, 5)}', NULL), + (2, 'globex', 'Globex Industries', 'pro', 1, '{json.dumps({"timezone": "America/New_York"})}', '{ts(base, 10)}', '{ts(base, 15)}', NULL); + + INSERT INTO users (id, org_id, email, display_name, role, is_active, is_admin, profile_json, last_login_at, created_at, updated_at) VALUES + (1, 1, 'alice@acme.test', 'Alice Admin', 'admin', 1, 1, '{json.dumps({"title": "Platform Lead"})}', '{ts(base, 60)}', '{ts(base, 20)}', '{ts(base, 60)}'), + (2, 1, 'bob@acme.test', 'Bob Builder', 'member', 1, 0, '{json.dumps({"title": "Engineer"})}', '{ts(base, 120)}', '{ts(base, 25)}', '{ts(base, 120)}'), + (3, 2, 'carol@globex.test', 'Carol CFO', 'admin', 1, 1, NULL, '{ts(base, 180)}', '{ts(base, 30)}', '{ts(base, 180)}'); + + INSERT INTO external_entities (id, org_id, name, kind, metadata, created_at) VALUES + ('550e8400-e29b-41d4-a716-446655440000', 1, 'Bootstrap Webhook', 'webhook', '{json.dumps({"seed": True})}', '{ts(base, 40)}'); + """ + ).strip() + + "\n", + encoding="utf-8", + ) + + org_rows: list[str] = [] + user_rows: list[str] = [] + team_rows: list[str] = [] + team_member_rows: list[str] = [] + project_rows: list[str] = [] + tag_rows: list[str] = [] + project_tag_rows: list[str] = [] + task_rows: list[str] = [] + task_dep_rows: list[str] = [] + comment_rows: list[str] = [] + attachment_rows: list[str] = [] + audit_rows: list[str] = [] + api_key_rows: list[str] = [] + session_rows: list[str] = [] + notification_rows: list[str] = [] + invoice_rows: list[str] = [] + line_item_rows: list[str] = [] + external_entity_rows: list[str] = [] + entity_link_rows: list[str] = [] + + org_id = 3 + user_id = 4 + team_id = 1 + project_id = 1 + tag_id = 1 + task_id = 1 + comment_id = 1 + attachment_id = 1 + audit_id = 1 + api_key_id = 1 + session_id = 1 + notification_id = 1 + invoice_id = 1 + line_item_id = 1 + + while attachment_id <= attachment_target: + slug = f"org-{org_id}" + org_rows.append( + f"({org_id}, {esc(slug)}, {esc(slug.replace('-', ' ').title())}, 'pro', 1, " + f"{esc(json.dumps({'seed': True, 'index': org_id}))}, {esc(ts(base, org_id))}, {esc(ts(base, org_id + 1))}, NULL)" + ) + + entity_uuid = str(uuid.uuid4()) + external_entity_rows.append( + f"({esc(entity_uuid)}, {org_id}, {esc(f'Entity {org_id}')}, 'integration', " + f"{esc(json.dumps({'org': org_id, 'seed': True}))}, {esc(ts(base, org_id + 2))})" + ) + + org_user_ids: list[int] = [] + for u in range(USERS_PER_ORG): + email = f"user{user_id}@{slug}.test" + org_user_ids.append(user_id) + user_rows.append( + f"({user_id}, {org_id}, {esc(email)}, {esc(f'User {user_id}')}, 'member', 1, {1 if u == 0 else 0}, " + f"{esc(json.dumps({'team': u % TEAMS_PER_ORG}))}, {esc(ts(base, user_id))}, {esc(ts(base, user_id))}, {esc(ts(base, user_id + 1))})" + ) + for n in range(NOTIFICATIONS_PER_USER): + notification_rows.append( + f"({notification_id}, {user_id}, 'mention', {esc(f'Mention {notification_id}')}, " + f"{esc('You were mentioned in a comment')}, {esc(json.dumps({'task_id': task_id}))}, " + f"{n % 2}, NULL, {esc(ts(base, notification_id))})" + ) + notification_id += 1 + api_key_rows.append( + f"({api_key_id}, {user_id}, {esc(f'key-{api_key_id}')}, {esc(f'pk_{api_key_id:04d}')}, 1, " + f"{esc(json.dumps(['read', 'write']))}, {esc(ts(base, api_key_id + 720))}, {esc(ts(base, api_key_id))})" + ) + api_key_id += 1 + session_rows.append( + f"({session_id}, {user_id}, {esc(f'tok_{session_id:08x}')}, '127.0.0.1', 'seed-script', " + f"{esc(ts(base, session_id))}, {esc(ts(base, session_id + 1440))}, {esc(ts(base, session_id))})" + ) + session_id += 1 + user_id += 1 + + org_team_ids: list[int] = [] + for t in range(TEAMS_PER_ORG): + org_team_ids.append(team_id) + team_rows.append( + f"({team_id}, {org_id}, {esc(f'Team {team_id}')}, 0, {esc(ts(base, team_id))})" + ) + for member_idx, member_user in enumerate(org_user_ids[:5]): + team_member_rows.append( + f"({team_id}, {member_user}, {esc('lead' if member_idx == 0 else 'member')}, {esc(ts(base, team_id + member_idx))})" + ) + team_id += 1 + + org_tag_ids: list[int] = [] + for label in ("backend", "frontend", "infra", "design"): + org_tag_ids.append(tag_id) + tag_rows.append(f"({tag_id}, {org_id}, {esc(label)}, {esc('#' + format(tag_id * 111111 % 0xFFFFFF, '06x'))})") + tag_id += 1 + + org_project_ids: list[int] = [] + for p in range(PROJECTS_PER_ORG): + if attachment_id > attachment_target: + break + owner = org_user_ids[p % len(org_user_ids)] + team = org_team_ids[p % len(org_team_ids)] + org_project_ids.append(project_id) + project_rows.append( + f"({project_id}, {org_id}, {owner}, {team}, {esc(f'project-{project_id}')}, {esc(f'Project {project_id}')}, " + f"{esc(f'Description for project {project_id}')}, {p % 3 == 0}, 0, " + f"{esc(json.dumps({'priority': p % 5, 'labels': ['seed']}))}, {esc(ts(base, project_id))}, {esc(ts(base, project_id + 2))})" + ) + for tg in org_tag_ids[:2]: + project_tag_rows.append(f"({project_id}, {tg})") + + prev_task_in_project: int | None = None + for tk in range(TASKS_PER_PROJECT): + if attachment_id > attachment_target: + break + assignee = org_user_ids[(p + tk) % len(org_user_ids)] + parent = prev_task_in_project if tk > 0 and tk % 4 == 0 else None + status = ("open", "in_progress", "done", "cancelled")[tk % 4] + task_rows.append( + f"({task_id}, {project_id}, {assignee}, {parent if parent else 'NULL'}, " + f"{esc(f'Task {task_id}')}, {esc(f'Body for task {task_id}')}, {esc(status)}, {(tk % 5) + 1}, " + f"{round(1.5 + (task_id % 7), 2)}, {1 if tk % 6 == 0 else 0}, " + f"{esc(json.dumps({'tags': ['seed', status]}))}, {esc(ts(base, task_id + 1000))}, " + f"{esc(ts(base, task_id + 2000)) if status == 'done' else 'NULL'}, " + f"{esc(ts(base, task_id))}, {esc(ts(base, task_id + 1))})" + ) + if prev_task_in_project is not None and tk % 3 == 0: + task_dep_rows.append(f"({task_id}, {prev_task_in_project})") + prev_task_in_project = task_id + + if tk == 0: + entity_link_rows.append( + f"({esc(entity_uuid)}, {task_id}, {esc(ts(base, task_id + 3))})" + ) + + for c in range(COMMENTS_PER_TASK): + if attachment_id > attachment_target: + break + author = org_user_ids[(c + tk) % len(org_user_ids)] + parent_comment = comment_id - 1 if c > 0 else None + comment_rows.append( + f"({comment_id}, {task_id}, {author}, {parent_comment if parent_comment else 'NULL'}, " + f"{esc(f'Comment {comment_id} on task {task_id}')}, {1 if c > 0 else 0}, " + f"{esc(ts(base, comment_id))}, {esc(ts(base, comment_id + 1))})" + ) + blob = payload_bytes(attachment_id) + attachment_rows.append( + f"({attachment_id}, {comment_id}, {esc(f'file-{attachment_id}.bin')}, 'application/octet-stream', " + f"{len(blob)}, {esc(f'sha256:{attachment_id:08x}')}, {blob_literal(blob)}, {esc(ts(base, attachment_id))})" + ) + attachment_id += 1 + comment_id += 1 + + task_id += 1 + project_id += 1 + + for inv in range(INVOICES_PER_ORG): + subtotal = round(1000 + inv * 250.5, 2) + tax = round(subtotal * 0.08, 2) + total = round(subtotal + tax, 2) + invoice_rows.append( + f"({invoice_id}, {org_id}, {esc(f'INV-{invoice_id:05d}')}, {subtotal}, {tax}, {total}, 'sent', " + f"{esc(ts(base, invoice_id + 5000))}, {esc(ts(base, invoice_id + 6000))}, NULL, {esc(ts(base, invoice_id))})" + ) + for line in range(3): + qty = line + 1 + unit = round(50.25 + line, 2) + amount = round(qty * unit, 2) + related_task = max(task_id - 1 - line, 1) + line_item_rows.append( + f"({line_item_id}, {invoice_id}, {related_task}, {esc(f'Line {line_item_id}')}, " + f"{qty}, {unit}, {amount})" + ) + line_item_id += 1 + invoice_id += 1 + + for a in range(10): + if not org_project_ids: + break + actor = org_user_ids[a % len(org_user_ids)] + audit_rows.append( + f"({audit_id}, {org_id}, {actor}, 'update', 'project', {org_project_ids[a % len(org_project_ids)]}, " + f"{esc(json.dumps({'field': 'name', 'seed': True}))}, {esc(ts(base, audit_id + 8000))})" + ) + audit_id += 1 + + org_id += 1 + + actual_attachments = attachment_id - 1 + estimated_blob_bytes = actual_attachments * PAYLOAD_BYTES + if estimated_blob_bytes < int(TARGET_BYTES * 0.99): + raise SystemExit( + f"seed under storage target: generated {estimated_blob_bytes} bytes, " + f"need >= {int(TARGET_BYTES * 0.99)} ({TARGET_GB} GiB Postgres payload target)" + ) + + write_batches("organizations", "INSERT INTO organizations (id, slug, name, plan, is_active, settings, created_at, updated_at, deleted_at) VALUES\n", org_rows) + write_batches("users", "INSERT INTO users (id, org_id, email, display_name, role, is_active, is_admin, profile_json, last_login_at, created_at, updated_at) VALUES\n", user_rows) + write_batches("teams", "INSERT INTO teams (id, org_id, name, is_archived, created_at) VALUES\n", team_rows) + write_batches("team_members", "INSERT INTO team_members (team_id, user_id, role, joined_at) VALUES\n", team_member_rows) + write_batches("projects", "INSERT INTO projects (id, org_id, owner_user_id, team_id, slug, name, description, is_public, is_archived, metadata, created_at, updated_at) VALUES\n", project_rows) + write_batches("tags", "INSERT INTO tags (id, org_id, label, color) VALUES\n", tag_rows) + write_batches("project_tags", "INSERT INTO project_tags (project_id, tag_id) VALUES\n", project_tag_rows) + write_batches("tasks", "INSERT INTO tasks (id, project_id, assignee_user_id, parent_task_id, title, body, status, priority, estimate_hours, is_blocked, labels_json, due_at, completed_at, created_at, updated_at) VALUES\n", task_rows) + write_batches("task_dependencies", "INSERT INTO task_dependencies (task_id, depends_on_task_id) VALUES\n", task_dep_rows) + write_batches("comments", "INSERT INTO comments (id, task_id, author_user_id, parent_comment_id, body, is_edited, created_at, updated_at) VALUES\n", comment_rows) + write_batches( + "attachments", + "INSERT INTO attachments (id, comment_id, filename, content_type, byte_size, checksum, payload, uploaded_at) VALUES\n", + attachment_rows, + max_bytes=min(D1_MAX_STATEMENT_BYTES, max(PAYLOAD_BYTES * 2 + 4096, MAX_STATEMENT_BYTES)), + ) + write_batches("audit_log", "INSERT INTO audit_log (id, org_id, actor_user_id, action, entity_type, entity_id, payload, created_at) VALUES\n", audit_rows) + write_batches("api_keys", "INSERT INTO api_keys (id, user_id, name, key_prefix, is_active, scopes_json, expires_at, created_at) VALUES\n", api_key_rows) + write_batches("sessions", "INSERT INTO sessions (id, user_id, token_hash, ip_address, user_agent, last_seen_at, expires_at, created_at) VALUES\n", session_rows) + write_batches("notifications", "INSERT INTO notifications (id, user_id, kind, title, body, payload, is_read, read_at, created_at) VALUES\n", notification_rows) + write_batches("invoices", "INSERT INTO invoices (id, org_id, invoice_number, subtotal, tax, total, status, issued_at, due_at, paid_at, created_at) VALUES\n", invoice_rows) + write_batches("line_items", "INSERT INTO invoice_line_items (id, invoice_id, task_id, description, quantity, unit_price, amount) VALUES\n", line_item_rows) + write_batches( + "external_entities", + "INSERT INTO external_entities (id, org_id, name, kind, metadata, created_at) VALUES\n", + external_entity_rows, + ) + write_batches("entity_links", "INSERT INTO entity_links (entity_id, task_id, linked_at) VALUES\n", entity_link_rows) + + estimated_blob_bytes = (attachment_id - 1) * PAYLOAD_BYTES + if attachment_id - 1 < attachment_target: + raise SystemExit( + f"ERROR: generated {attachment_id - 1} attachments, expected {attachment_target} " + f"({estimated_blob_bytes} bytes, target {TARGET_BYTES})" + ) + min_bytes = int(TARGET_BYTES * 0.99) + if estimated_blob_bytes < min_bytes: + raise SystemExit( + f"ERROR: blob bytes {estimated_blob_bytes} below target minimum {min_bytes}" + ) + summary = { + **volume, + "organizations": len(org_rows) + 2, + "users": len(user_rows) + 3, + "teams": len(team_rows), + "projects": len(project_rows), + "tasks": len(task_rows), + "comments": len(comment_rows), + "attachments": attachment_id - 1, + "external_entities": len(external_entity_rows) + 1, + "entity_links": len(entity_link_rows), + "estimated_blob_bytes": estimated_blob_bytes, + "estimated_total_bytes": estimated_blob_bytes + RESERVED_BYTES, + "seed_files": len(list(OUT_DIR.glob("*.sql"))), + } + (OUT_DIR / "SUMMARY.json").write_text(json.dumps(summary, indent=2) + "\n", encoding="utf-8") + print(json.dumps(summary, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/script/d1-import-test/launch-benchmark-detached.sh b/script/d1-import-test/launch-benchmark-detached.sh new file mode 100755 index 000000000..5c094463f --- /dev/null +++ b/script/d1-import-test/launch-benchmark-detached.sh @@ -0,0 +1,93 @@ +#!/usr/bin/env bash +# Launch benchmark jobs in a new session so they survive Cursor/agent shell teardown. +# +# Usage: +# BENCH_STATE_DIR=/tmp/d1-bench-20260625-204823 ./script/d1-import-test/launch-benchmark-detached.sh +# START_5GB_EXPORT=true ./script/d1-import-test/launch-benchmark-detached.sh +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +CLI="$(cd "$ROOT/../.." && pwd)" +STATE_DIR="${BENCH_STATE_DIR:-/tmp/d1-bench-$(date +%Y%m%d-%H%M%S)}" +SIZES="${SIZES:-1 5 9}" +START_5GB_EXPORT="${START_5GB_EXPORT:-true}" + +mkdir -p "$STATE_DIR" + +if [[ ! -x "$CLI/pscale-test" ]]; then + (cd "$CLI" && go build -o pscale-test ./cmd/pscale) +fi + +rm -f "$STATE_DIR/import-"{1,5,9}gb.{pid,started,failed} + +python3 - "$ROOT" "$CLI" "$STATE_DIR" "$SIZES" "$START_5GB_EXPORT" <<'PY' +import os, subprocess, sys, time +from pathlib import Path + +root, cli, state_dir, sizes, start_5gb = sys.argv[1:6] +state = Path(state_dir) +state.mkdir(parents=True, exist_ok=True) + +env = os.environ.copy() +env.update({ + "PSCALE_DISABLE_DEV_WARNING": "true", + "PSCALE_TEST_MODE": "1", + "PSCALE_ORG": env.get("PSCALE_ORG", "bb"), +}) + +def spawn(name, cmd, extra_env=None, log_name=None): + e = env.copy() + if extra_env: + e.update(extra_env) + log = state / (log_name or f"{name}.log") + log.parent.mkdir(parents=True, exist_ok=True) + fh = open(log, "a", buffering=1) + fh.write(f"\n==> detached launch {name} pid-pending {time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime())}\n") + fh.flush() + proc = subprocess.Popen( + cmd, + stdin=subprocess.DEVNULL, + stdout=fh, + stderr=subprocess.STDOUT, + env=e, + cwd=cli, + start_new_session=True, + close_fds=True, + ) + (state / f"{name}.pid").write_text(str(proc.pid)) + print(f"started {name}: pid={proc.pid} log={log}") + return proc.pid + +pids = [] + +if start_5gb.lower() == "true" and not (state / "export-5gb.ready").exists(): + pids.append(spawn( + "export-5gb", + ["bash", f"{root}/build-local-export.sh", "5"], + { + "SEED_DIR": f"{root}/seed/local-5gb", + "D1_EXPORT": "/tmp/import-test-5gb-export.sql", + "EXPORT_READY_FILE": str(state / "export-5gb.ready"), + }, + )) +else: + print("skip 5gb export (ready or disabled)") + +pids.append(spawn( + "watcher", + ["bash", f"{root}/bench-watch-imports.sh"], + { + "BENCH_STATE_DIR": state_dir, + "BENCH_WATCH_LOG": str(state / "watch.log"), + "SIZES": sizes, + "POLL_SEC": env.get("POLL_SEC", "15"), + }, + log_name="watch.log", +)) + +(state / "launcher.pids").write_text("\n".join(str(p) for p in pids) + "\n") +print(f"state_dir={state_dir}") +print("detached — safe to close Cursor agent shell") +PY + +chmod +x "$ROOT/bench-watch-imports.sh" "$ROOT/build-local-export.sh" diff --git a/script/d1-import-test/launch-storage-benchmark-detached.sh b/script/d1-import-test/launch-storage-benchmark-detached.sh new file mode 100755 index 000000000..ef6c5bdc6 --- /dev/null +++ b/script/d1-import-test/launch-storage-benchmark-detached.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash +# Run full storage benchmark in a detached session (survives Cursor shell teardown). +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +CLI="$(cd "$ROOT/../.." && pwd)" +STATE_DIR="${BENCH_STATE_DIR:-/tmp/d1-bench-$(date +%Y%m%d-%H%M%S)}" +LOG="${D1_BENCHMARK_LOG:-/tmp/d1-storage-benchmark.log}" + +mkdir -p "$STATE_DIR" + +python3 - "$ROOT" "$CLI" "$STATE_DIR" "$LOG" <<'PY' +import os, subprocess, sys, time +from pathlib import Path + +root, cli, state_dir, log = sys.argv[1:5] +state = Path(state_dir) +env = os.environ.copy() +env["BENCH_STATE_DIR"] = state_dir +env["D1_BENCHMARK_LOG"] = log +env["PSCALE_DISABLE_DEV_WARNING"] = "true" +env["PSCALE_TEST_MODE"] = "1" + +fh = open(log, "a", buffering=1) +fh.write(f"\n==> detached storage benchmark {time.strftime('%Y-%m-%dT%H:%M:%SZ')} state={state_dir}\n") +fh.flush() + +proc = subprocess.Popen( + ["bash", f"{root}/run-storage-benchmark.sh"], + stdin=subprocess.DEVNULL, + stdout=fh, + stderr=subprocess.STDOUT, + env=env, + cwd=cli, + start_new_session=True, + close_fds=True, +) +(state / "benchmark.pid").write_text(str(proc.pid)) +print(f"benchmark pid={proc.pid}") +print(f"state_dir={state_dir}") +print(f"log={log}") +print(f"report={state_dir}/report.txt (when complete)") +PY + +chmod +x "$ROOT/run-storage-benchmark.sh" "$ROOT/collect-benchmark-results.sh" "$ROOT/bench-watch-imports.sh" diff --git a/script/d1-import-test/load-bulk.sh b/script/d1-import-test/load-bulk.sh new file mode 100755 index 000000000..ee70c8385 --- /dev/null +++ b/script/d1-import-test/load-bulk.sh @@ -0,0 +1,59 @@ +#!/usr/bin/env bash +# Load import-test D1 via merged SQL chunks (~50 MB each) instead of one wrangler call +# per seed batch. Use for multi-GB seeds; keep load.sh for quick smoke tests. +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +DB="${D1_DATABASE:-import-test}" +REMOTE="${D1_REMOTE:-true}" +FRESH="${D1_FRESH:-false}" +CHUNK_TARGET_MB="${CHUNK_TARGET_MB:-50}" + +remote_flag=() +if [[ "$REMOTE" == "true" ]]; then + remote_flag=(--remote) +fi + +if [[ "$FRESH" == "true" ]]; then + echo "==> Resetting existing tables" + wrangler d1 execute "$DB" "${remote_flag[@]}" --file="$ROOT/reset.sql" + echo "==> Applying schema to D1 database: $DB" + wrangler d1 execute "$DB" "${remote_flag[@]}" --file="$ROOT/schema.sql" +else + echo "==> Skipping reset/schema (set D1_FRESH=true for clean load)" +fi + +if [[ "${SKIP_SEED_GENERATE:-false}" != "true" ]]; then + echo "==> Generating seed SQL batches" + python3 "$ROOT/generate_seed.py" +else + echo "==> Skipping seed generation (SKIP_SEED_GENERATE=true)" +fi + +if [[ "${SKIP_MERGE:-false}" != "true" ]]; then + echo "==> Merging seed batches into ~${CHUNK_TARGET_MB} MB chunks" + CHUNK_TARGET_MB="$CHUNK_TARGET_MB" python3 "$ROOT/merge_seed_chunks.py" +else + echo "==> Skipping merge (SKIP_MERGE=true)" +fi + +mapfile -t chunks < <(find "$ROOT/seed/chunks" -name 'chunk_*.sql' | sort) +total_chunks="${#chunks[@]}" +if [[ "$total_chunks" -eq 0 ]]; then + echo "ERROR: no chunk files under $ROOT/seed/chunks" >&2 + exit 1 +fi + +echo "==> Loading ${total_chunks} chunk(s) to D1 (sequential — D1 is single-threaded per DB)" +chunk_num=0 +for chunk in "${chunks[@]}"; do + chunk_num=$((chunk_num + 1)) + size_bytes=$(stat -f%z "$chunk" 2>/dev/null || stat -c%s "$chunk") + size_mb=$(python3 -c "print(round($size_bytes / (1024*1024), 1))") + echo " -> [${chunk_num}/${total_chunks}] $(basename "$chunk") (${size_mb} MB)" + wrangler d1 execute "$DB" "${remote_flag[@]}" --file="$chunk" +done + +echo "==> Done. Export with:" +echo " wrangler d1 export $DB --remote --output ./import-test-export.sql" +echo " pscale import d1 lint --input ./import-test-export.sql --format json" diff --git a/script/d1-import-test/load.sh b/script/d1-import-test/load.sh new file mode 100755 index 000000000..adab4756f --- /dev/null +++ b/script/d1-import-test/load.sh @@ -0,0 +1,64 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +DB="${D1_DATABASE:-import-test}" +REMOTE="${D1_REMOTE:-true}" +FRESH="${D1_FRESH:-false}" + +remote_flag=() +if [[ "$REMOTE" == "true" ]]; then + remote_flag=(--remote) +fi + +if [[ "$FRESH" == "true" ]]; then + echo "==> Resetting existing tables" + wrangler d1 execute "$DB" "${remote_flag[@]}" --file="$ROOT/reset.sql" +fi + +echo "==> Applying schema to D1 database: $DB" +wrangler d1 execute "$DB" "${remote_flag[@]}" --file="$ROOT/schema.sql" + +echo "==> Generating seed SQL batches" +python3 "$ROOT/generate_seed.py" + +echo "==> Loading seed data" +seed_order=( + 000_bootstrap.sql + organizations + users + teams + team_members + projects + tags + project_tags + tasks + task_dependencies + comments + attachments + audit_log + api_keys + sessions + notifications + invoices + line_items + external_entities + entity_links +) + +for prefix in "${seed_order[@]}"; do + if [[ "$prefix" == *.sql ]]; then + files=("$ROOT/seed/$prefix") + else + mapfile -t files < <(find "$ROOT/seed" -name "${prefix}_*.sql" | sort) + fi + for file in "${files[@]}"; do + [[ -f "$file" ]] || continue + echo " -> $(basename "$file")" + wrangler d1 execute "$DB" "${remote_flag[@]}" --file="$file" + done +done + +echo "==> Done. Export with:" +echo " wrangler d1 export $DB --remote --output ./import-test-export.sql" +echo " pscale import d1 lint --input ./import-test-export.sql --format json" diff --git a/script/d1-import-test/merge_seed_chunks.py b/script/d1-import-test/merge_seed_chunks.py new file mode 100755 index 000000000..0b5133ab8 --- /dev/null +++ b/script/d1-import-test/merge_seed_chunks.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python3 +"""Merge per-batch seed SQL files into large chunks for wrangler d1 execute --file. + +D1 limits (see Cloudflare docs): + - Max SQL statement length: 100 KB (each statement in a file) + - Max import file size via wrangler: 5 GB + +Each small seed file is kept intact (never split). Chunks are concatenations of +complete statements, target ~50 MB per chunk by default. +""" + +from __future__ import annotations + +import json +import os +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parent +SEED_DIR = ROOT / "seed" +CHUNKS_DIR = SEED_DIR / "chunks" + +# Keep in sync with load.sh / load-bulk.sh seed order. +SEED_ORDER: list[str] = [ + "000_bootstrap.sql", + "organizations", + "users", + "teams", + "team_members", + "projects", + "tags", + "project_tags", + "tasks", + "task_dependencies", + "comments", + "attachments", + "audit_log", + "api_keys", + "sessions", + "notifications", + "invoices", + "line_items", + "external_entities", + "entity_links", +] + +D1_MAX_STATEMENT_BYTES = int(os.environ.get("SEED_MAX_STATEMENT_BYTES", "100000")) +CHUNK_TARGET_BYTES = int(float(os.environ.get("CHUNK_TARGET_MB", "50")) * 1024 * 1024) + + +def ordered_seed_files() -> list[Path]: + files: list[Path] = [] + for prefix in SEED_ORDER: + if prefix.endswith(".sql"): + path = SEED_DIR / prefix + if path.is_file(): + files.append(path) + continue + for path in sorted(SEED_DIR.glob(f"{prefix}_*.sql")): + files.append(path) + return files + + +def validate_statements(files: list[Path]) -> None: + oversize: list[tuple[str, int]] = [] + for path in files: + size = path.stat().st_size + if size > D1_MAX_STATEMENT_BYTES: + oversize.append((path.name, size)) + if oversize: + sample = oversize[:5] + details = ", ".join(f"{name} ({size} bytes)" for name, size in sample) + extra = f" (+{len(oversize) - 5} more)" if len(oversize) > 5 else "" + raise SystemExit( + f"ERROR: {len(oversize)} seed file(s) exceed D1 max statement size " + f"({D1_MAX_STATEMENT_BYTES} bytes): {details}{extra}\n" + "Regenerate seed with a smaller SEED_PAYLOAD_BYTES or lower SEED_MAX_STATEMENT_BYTES." + ) + + +def merge_chunks(files: list[Path], *, dry_run: bool = False) -> dict: + if not files: + raise SystemExit(f"no seed SQL files found under {SEED_DIR}") + + validate_statements(files) + + if not dry_run: + if CHUNKS_DIR.exists(): + for child in CHUNKS_DIR.glob("chunk_*.sql"): + child.unlink() + else: + CHUNKS_DIR.mkdir(parents=True) + + chunk_idx = 0 + current_files: list[Path] = [] + current_bytes = 0 + chunk_manifest: list[dict] = [] + + def flush() -> None: + nonlocal chunk_idx, current_files, current_bytes + if not current_files: + return + chunk_idx += 1 + out_path = CHUNKS_DIR / f"chunk_{chunk_idx:04d}.sql" + file_count = len(current_files) + total_bytes = sum(f.stat().st_size for f in current_files) + entry = { + "chunk": chunk_idx, + "file": out_path.name, + "source_files": file_count, + "bytes": total_bytes, + } + chunk_manifest.append(entry) + if not dry_run: + with out_path.open("wb") as out: + for src in current_files: + data = src.read_bytes() + out.write(data) + if data and not data.endswith(b"\n"): + out.write(b"\n") + current_files = [] + current_bytes = 0 + + for path in files: + size = path.stat().st_size + if size > CHUNK_TARGET_BYTES: + flush() + chunk_idx += 1 + out_path = CHUNKS_DIR / f"chunk_{chunk_idx:04d}.sql" + chunk_manifest.append( + { + "chunk": chunk_idx, + "file": out_path.name, + "source_files": 1, + "bytes": size, + "note": "oversized single file kept alone", + } + ) + if not dry_run: + out_path.write_bytes(path.read_bytes()) + continue + if current_files and current_bytes + size > CHUNK_TARGET_BYTES: + flush() + current_files.append(path) + current_bytes += size + flush() + + total_bytes = sum(e["bytes"] for e in chunk_manifest) + return { + "chunk_target_bytes": CHUNK_TARGET_BYTES, + "d1_max_statement_bytes": D1_MAX_STATEMENT_BYTES, + "source_files": len(files), + "chunks": len(chunk_manifest), + "total_bytes": total_bytes, + "manifest": chunk_manifest, + } + + +def main() -> None: + dry_run = "--dry-run" in sys.argv + files = ordered_seed_files() + summary = merge_chunks(files, dry_run=dry_run) + summary_path = CHUNKS_DIR / "MANIFEST.json" + if not dry_run: + CHUNKS_DIR.mkdir(parents=True, exist_ok=True) + summary_path.write_text(json.dumps(summary, indent=2) + "\n", encoding="utf-8") + print(json.dumps(summary, indent=2)) + if dry_run: + print(f"(dry run — no files written under {CHUNKS_DIR})", file=sys.stderr) + + +if __name__ == "__main__": + main() diff --git a/script/d1-import-test/prepare-demo-100mb.sh b/script/d1-import-test/prepare-demo-100mb.sh new file mode 100755 index 000000000..0df01b005 --- /dev/null +++ b/script/d1-import-test/prepare-demo-100mb.sh @@ -0,0 +1,66 @@ +#!/usr/bin/env bash +# Prepare ~100 MiB D1 export + SQLite + fresh Postgres DB for demos. +# +# Usage: +# ./script/d1-import-test/prepare-demo-100mb.sh +# PSCALE_DB=cf-d1-import-demo ./script/d1-import-test/prepare-demo-100mb.sh +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +CLI="$(cd "$ROOT/../.." && pwd)" +DEMO_DIR="${DEMO_DIR:-/tmp/d1-demo-100mb}" +DB="${PSCALE_DB:-cf-d1-import-demo-$(date +%Y%m%d)}" +TARGET_GB="$(python3 -c "print(100*1024*1024/(1024**3))")" + +mkdir -p "$DEMO_DIR" + +export PSCALE_DISABLE_DEV_WARNING=true +export PSCALE_TEST_MODE=1 + +echo "==> Building ~100 MiB export (target ${TARGET_GB} GB payload)" +SEED_DIR="$ROOT/seed/demo-100mb" \ + D1_EXPORT="$DEMO_DIR/import-test-100mb-export.sql" \ + SEED_TARGET_GB="$TARGET_GB" \ + REGENERATE_SEED=true \ + "$ROOT/build-local-export.sh" "$TARGET_GB" + +echo "==> Building SQLite file" +SQLITE="$DEMO_DIR/import-test-100mb.sqlite" +rm -f "$SQLITE" +grep -v '^PRAGMA foreign_keys' "$DEMO_DIR/import-test-100mb-export.sql" | sqlite3 "$SQLITE" + +echo "==> Provisioning database $DB" +PROVISION_READY_FILE="$DEMO_DIR/provision.ready" \ + "$ROOT/provision-database.sh" "$DB" > "$DEMO_DIR/provision.log" 2>&1 + +echo "$DB" > "$DEMO_DIR/database.name" +cat > "$DEMO_DIR/demo.env" <&2 + exit 1 +fi + +PSCALE="${CLI}/pscale-test" +if [[ ! -x "$PSCALE" ]]; then + (cd "$CLI" && go build -o pscale-test ./cmd/pscale) +fi + +db_cmd() { + "$PSCALE" --api-url "$API_URL" database "$@" --org "$ORG" +} + +echo "==> [provision] Creating Postgres database $NAME" +db_cmd create "$NAME" \ + --engine postgresql \ + --region "$REGION" \ + --cluster-size "$CLUSTER_SIZE" \ + --replicas "$REPLICAS" \ + --major-version "$PG_MAJOR_VERSION" + +echo "==> [provision] Waiting for branch $BRANCH to be ready" +PSCALE_CMD=("$PSCALE" --api-url "$API_URL") +deadline=$((SECONDS + ${PROVISION_TIMEOUT_SEC:-900})) +while (( SECONDS < deadline )); do + ready="$("${PSCALE_CMD[@]}" branch show "$NAME" "$BRANCH" --format json --org "$ORG" 2>/dev/null | python3 -c "import json,sys; print(json.load(sys.stdin).get('ready', False))" 2>/dev/null || echo False)" + if [[ "$ready" == "True" ]]; then + break + fi + sleep 5 +done + +if [[ "$ready" != "True" ]]; then + echo "ERROR: branch not ready after ${PROVISION_TIMEOUT_SEC:-900}s" >&2 + exit 1 +fi + +echo "==> [provision] Database ready: $NAME (branch ready)" +if [[ -n "${PROVISION_READY_FILE:-}" ]]; then + touch "$PROVISION_READY_FILE" +fi +if [[ -n "${PSCALE_DB_FILE:-}" ]]; then + echo "$NAME" > "$PSCALE_DB_FILE" +fi +echo "$NAME" diff --git a/script/d1-import-test/reset.sql b/script/d1-import-test/reset.sql new file mode 100644 index 000000000..5625627dd --- /dev/null +++ b/script/d1-import-test/reset.sql @@ -0,0 +1,39 @@ +-- Drop all import-test tables for a clean reload (reverse dependency order). +PRAGMA foreign_keys = OFF; + +DROP TABLE IF EXISTS entity_links; +DROP TABLE IF EXISTS external_entities; +DROP TABLE IF EXISTS invoice_line_items; +DROP TABLE IF EXISTS invoices; +DROP TABLE IF EXISTS notifications; +DROP TABLE IF EXISTS sessions; +DROP TABLE IF EXISTS api_keys; +DROP TABLE IF EXISTS audit_log; +DROP TABLE IF EXISTS attachments; +DROP TABLE IF EXISTS comments; +DROP TABLE IF EXISTS task_dependencies; +DROP TABLE IF EXISTS tasks; +DROP TABLE IF EXISTS project_tags; +DROP TABLE IF EXISTS tags; +DROP TABLE IF EXISTS projects; +DROP TABLE IF EXISTS team_members; +DROP TABLE IF EXISTS teams; +DROP TABLE IF EXISTS users; +DROP TABLE IF EXISTS organizations; + +DROP TABLE IF EXISTS goose_db_version; +DROP TABLE IF EXISTS typeorm_metadata; +DROP TABLE IF EXISTS alembic_version; +DROP TABLE IF EXISTS django_migrations; +DROP TABLE IF EXISTS databasechangeloglock; +DROP TABLE IF EXISTS databasechangelog; +DROP TABLE IF EXISTS flyway_schema_history; +DROP TABLE IF EXISTS ar_internal_metadata; +DROP TABLE IF EXISTS schema_migrations; +DROP TABLE IF EXISTS sequelizemeta; +DROP TABLE IF EXISTS knex_migrations_lock; +DROP TABLE IF EXISTS knex_migrations; +DROP TABLE IF EXISTS _prisma_migrations; +DROP TABLE IF EXISTS __drizzle_migrations; + +PRAGMA foreign_keys = ON; diff --git a/script/d1-import-test/resume-chunks.sh b/script/d1-import-test/resume-chunks.sh new file mode 100755 index 000000000..f922cfed9 --- /dev/null +++ b/script/d1-import-test/resume-chunks.sh @@ -0,0 +1,65 @@ +#!/usr/bin/env bash +# Resume bulk chunk upload from a given chunk number (inclusive). +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +DB="${D1_DATABASE:-import-test}" +START="${CHUNK_START:-24}" +LOG="${D1_RESUME_LOG:-/tmp/d1-chunk-resume.log}" +EXPORT="${D1_EXPORT_OUTPUT:-/tmp/import-test-9gb-export.sql}" + +remote_flag=(--remote) + +exec > >(tee -a "$LOG") 2>&1 + +echo "==> Resume from chunk_${START} at $(date -u +%Y-%m-%dT%H:%M:%SZ)" +wrangler d1 list | grep "$DB" || wrangler d1 list + +mapfile -t chunks < <(find "$ROOT/seed/chunks" -name 'chunk_*.sql' | sort) +total="${#chunks[@]}" +done_count=0 +failures=0 +max_retries=3 + +for chunk in "${chunks[@]}"; do + num=$(basename "$chunk" .sql | sed 's/chunk_//') + num=$((10#$num)) + if (( num < START )); then + continue + fi + done_count=$((done_count + 1)) + remaining=$((total - num + 1)) + size_mb=$(python3 -c "import os; print(round(os.path.getsize('$chunk') / (1024*1024), 1))") + + attempt=0 + while true; do + attempt=$((attempt + 1)) + echo " -> [${num}/${total}] $(basename "$chunk") (${size_mb} MB) attempt ${attempt}" + if wrangler d1 execute "$DB" "${remote_flag[@]}" --file="$chunk"; then + failures=0 + break + fi + failures=$((failures + 1)) + if (( attempt >= max_retries )); then + echo "ERROR: failed chunk ${num} after ${max_retries} attempts" >&2 + exit 1 + fi + echo " !! retrying in 10s..." + sleep 10 + done + + if (( num % 10 == 0 )); then + wrangler d1 list | grep "$DB" || true + fi +done + +echo "" +echo "==> All chunks loaded at $(date -u +%Y-%m-%dT%H:%M:%SZ)" +wrangler d1 list | grep "$DB" || wrangler d1 list + +echo "" +echo "==> Timed export" +D1_EXPORT_OUTPUT="$EXPORT" "$ROOT/time-export.sh" + +echo "" +echo "==> Complete. Export: $EXPORT" diff --git a/script/d1-import-test/run-9gb-benchmark.sh b/script/d1-import-test/run-9gb-benchmark.sh new file mode 100755 index 000000000..c16768b5a --- /dev/null +++ b/script/d1-import-test/run-9gb-benchmark.sh @@ -0,0 +1,50 @@ +#!/usr/bin/env bash +# D1 9 GB dataset prep (load + export). CLI import is run separately via run-cli-import.sh. +# +# Usage: +# ./script/d1-import-test/run-9gb-benchmark.sh # load + export only +# RUN_CLI_IMPORT=true ./script/d1-import-test/run-9gb-benchmark.sh # then CLI import +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +EXPORT="${D1_EXPORT:-/tmp/import-test-9gb-export.sql}" +LOG="${D1_BENCHMARK_LOG:-/tmp/d1-9gb-benchmark.log}" + +exec > >(tee -a "$LOG") 2>&1 + +echo "==============================================" +echo "D1 9GB benchmark started $(date -u +%Y-%m-%dT%H:%M:%SZ)" +echo "Log: $LOG" +echo "==============================================" + +load_start=$(date +%s) +D1_FRESH=true SEED_TARGET_GB=9 SKIP_SEED_GENERATE=true SKIP_MERGE=true \ + "$ROOT/load-bulk.sh" +load_end=$(date +%s) +load_sec=$((load_end - load_start)) + +echo "" +echo "==> Load phase complete: ${load_sec}s ($(python3 -c "print(round($load_sec/60,1))") min)" +wrangler d1 list | grep import-test || wrangler d1 list + +echo "" +echo "==> Export benchmark" +export_start=$(date +%s) +D1_EXPORT_OUTPUT="$EXPORT" "$ROOT/time-export.sh" +export_end=$(date +%s) +export_sec=$((export_end - export_start)) + +echo "" +echo "==============================================" +echo "D1 prep summary $(date -u +%Y-%m-%dT%H:%M:%SZ)" +echo " Load (upload chunks): ${load_sec}s" +echo " Export (remote→local): ${export_sec}s" +echo " Export file: $EXPORT" +echo "==============================================" + +if [[ "${RUN_CLI_IMPORT:-false}" == "true" ]]; then + echo "" + echo "==> Running CLI import test" + IMPORT_PROFILE=9gb D1_EXPORT="$EXPORT" SKIP_DB_CREATE="${SKIP_DB_CREATE:-true}" \ + "$ROOT/run-local-import.sh" +fi diff --git a/script/d1-import-test/run-cli-import.sh b/script/d1-import-test/run-cli-import.sh new file mode 100755 index 000000000..4dfa028f9 --- /dev/null +++ b/script/d1-import-test/run-cli-import.sh @@ -0,0 +1,233 @@ +#!/usr/bin/env bash +# End-to-end pscale import d1 CLI test against an existing wrangler SQL export. +# +# Prerequisites: pscale auth, local dev stack (singularity), Postgres database, +# and export file on disk. Does not load D1 or run wrangler export. +# +# Usage: +# ./script/d1-import-test/run-cli-import.sh +# IMPORT_PROFILE=9gb PSCALE_DB=cf-d1-import-9gb ./script/d1-import-test/run-cli-import.sh +# D1_EXPORT=/tmp/my-export.sql PSCALE_DB=my-db ./script/d1-import-test/run-cli-import.sh +set -euo pipefail + +CLI="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" +ORG="${PSCALE_ORG:-bb}" +BRANCH="${PSCALE_BRANCH:-main}" +API_URL="${PSCALE_API_URL:-http://api.pscaledev.com:3000/v1}" + +case "${IMPORT_PROFILE:-smoke}" in + smoke) + DB_NAME="${PSCALE_DB:-d1-import-test}" + EXPORT="${D1_EXPORT:-/tmp/import-test-export.sql}" + ;; + 1gb) + DB_NAME="${PSCALE_DB:-cf-d1-import-1gb}" + EXPORT="${D1_EXPORT:-/tmp/import-test-1gb-export.sql}" + ;; + 5gb) + DB_NAME="${PSCALE_DB:-cf-d1-import-5gb}" + EXPORT="${D1_EXPORT:-/tmp/import-test-5gb-export.sql}" + ;; + 9gb) + DB_NAME="${PSCALE_DB:-cf-d1-import-9gb}" + EXPORT="${D1_EXPORT:-/tmp/import-test-9gb-export.sql}" + ;; + *) + echo "ERROR: unknown IMPORT_PROFILE=${IMPORT_PROFILE:-} (use smoke, 1gb, 5gb, or 9gb)" >&2 + exit 1 + ;; +esac + +METHOD="${IMPORT_METHOD:-pgloader}" +RUN_DIR="${IMPORT_RUN_DIR:-/tmp/d1-cli-import-$(date +%Y%m%d-%H%M%S)}" +mkdir -p "$RUN_DIR" + +export PSCALE_DISABLE_DEV_WARNING=true +export PSCALE_TEST_MODE=1 + +PSCALE="${CLI}/pscale-test" +if [[ ! -x "$PSCALE" ]]; then + echo "==> Building pscale-test" + (cd "$CLI" && go build -o pscale-test ./cmd/pscale) +fi + +pscale_cmd() { + "$PSCALE" --api-url "$API_URL" "$@" +} + +pscale_org_cmd() { + "$PSCALE" --api-url "$API_URL" "$@" --org "$ORG" +} + +json_field() { + local file="$1" path="$2" + python3 - "$file" "$path" <<'PY' +import json, sys +doc = json.load(open(sys.argv[1])) +path = sys.argv[2].split(".") +cur = doc +for part in path: + if not isinstance(cur, dict) or part not in cur: + cur = None + break + cur = cur[part] +if cur is None and path[0] in doc: + cur = doc[path[0]] +if cur is None: + sys.exit(1) +if isinstance(cur, bool): + print("true" if cur else "false") +else: + print(cur) +PY +} + +require_json_ok() { + local file="$1" label="$2" + local status + status="$(json_field "$file" status)" + if [[ "$status" != "ok" && "$status" != "dry_run" ]]; then + echo "ERROR: $label failed (status=$status). See $file" >&2 + exit 1 + fi +} + +require_json_ok_or_warning() { + local file="$1" label="$2" + local status + status="$(json_field "$file" status)" + if [[ "$status" != "ok" && "$status" != "dry_run" && "$status" != "warning" ]]; then + echo "ERROR: $label failed (status=$status). See $file" >&2 + exit 1 + fi +} + +require_verify_matched() { + local file="$1" + local matched + matched="$(json_field "$file" data.matched)" + if [[ "$matched" != "true" ]]; then + echo "ERROR: verify did not match. See $file" >&2 + exit 1 + fi +} + +print_timings() { + local file="$1" + python3 - "$file" <<'PY' +import json, sys +doc = json.load(open(sys.argv[1])) +data = doc.get("data") or {} +timings = data.get("timings") +if not timings: + print(" (no timings in CLI response — rebuild pscale-test)") + sys.exit(0) +total = timings.get("total_ms", 0) / 1000 +print(f" total: {total:.1f}s") +for key, label in [ + ("sqlite_staging_ms", "sqlite staging"), + ("schema_ms", "schema apply"), + ("pgloader_ms", "pgloader"), + ("index_build_ms", "index build"), + ("sequence_reset_ms", "sequence reset"), +]: + ms = timings.get(key) + if ms: + print(f" {label + ':':16} {ms/1000:.1f}s") +loads = timings.get("table_loads") or [] +if loads: + slow = sorted(loads, key=lambda x: x.get("ms", 0), reverse=True)[:5] + print(" slowest tables:") + for row in slow: + print(f" {row['table']}: {row['ms']/1000:.1f}s") +PY +} + +if [[ ! -f "$EXPORT" ]]; then + echo "ERROR: export not found: $EXPORT" >&2 + echo "Export first, e.g. wrangler d1 export import-test --remote --output $EXPORT" >&2 + exit 1 +fi + +if ! "$PSCALE" --api-url "$API_URL" auth check >/dev/null 2>&1; then + echo "ERROR: pscale auth required. Run: pscale auth login --api-url ${PSCALE_AUTH_URL:-http://auth.pscaledev.com:3000}" >&2 + exit 1 +fi + +echo "==> CLI import test" +echo " org: $ORG" +echo " database: $DB_NAME" +echo " branch: $BRANCH" +echo " export: $EXPORT ($(python3 -c "import os; print(round(os.path.getsize('$EXPORT')/(1024**3),2))") GB)" +echo " method: $METHOD" +echo " run dir: $RUN_DIR" + +IMPORT_START=$(date +%s) + +echo "==> import d1 doctor" +pscale_cmd import d1 doctor --format json | tee "$RUN_DIR/doctor.json" +require_json_ok_or_warning "$RUN_DIR/doctor.json" "doctor" + +echo "==> import d1 lint" +pscale_cmd import d1 lint --input "$EXPORT" --format json | tee "$RUN_DIR/lint.json" +require_json_ok_or_warning "$RUN_DIR/lint.json" "lint" + +echo "==> import d1 start --dry-run (preview)" +pscale_org_cmd import d1 start \ + --input "$EXPORT" \ + --database "$DB_NAME" \ + --branch "$BRANCH" \ + --dry-run \ + --force \ + --format json | tee "$RUN_DIR/preview.json" +require_json_ok "$RUN_DIR/preview.json" "preview" + +MIGRATION_ID="$(python3 -c " +import json +d = json.load(open('$RUN_DIR/preview.json')) +print(d.get('migration_id') or (d.get('data') or {}).get('migration_id', '')) +")" +if [[ -z "$MIGRATION_ID" ]]; then + echo "ERROR: could not read migration_id from $RUN_DIR/preview.json" >&2 + exit 1 +fi +echo " migration_id: $MIGRATION_ID" + +echo "==> import d1 start" +START_WALL=$(date +%s) +pscale_org_cmd import d1 start \ + --database "$DB_NAME" \ + --branch "$BRANCH" \ + --input "$EXPORT" \ + --migration-id "$MIGRATION_ID" \ + --method "$METHOD" \ + --force \ + --format json | tee "$RUN_DIR/start.json" +START_WALL_END=$(date +%s) +require_json_ok "$RUN_DIR/start.json" "start" + +echo "==> import d1 verify" +pscale_org_cmd import d1 verify \ + --database "$DB_NAME" \ + --branch "$BRANCH" \ + --migration-id "$MIGRATION_ID" \ + --input "$EXPORT" \ + --format json | tee "$RUN_DIR/verify.json" +require_json_ok "$RUN_DIR/verify.json" "verify" +require_verify_matched "$RUN_DIR/verify.json" + +IMPORT_END=$(date +%s) +WALL_SEC=$((IMPORT_END - IMPORT_START)) +START_SEC=$((START_WALL_END - START_WALL)) + +echo "" +echo "==============================================" +echo "CLI import passed" +echo " migration_id: $MIGRATION_ID" +echo " database: $ORG/$DB_NAME/$BRANCH" +echo " wall clock: ${WALL_SEC}s ($(python3 -c "print(round($WALL_SEC/60,1))") min)" +echo " start phase: ${START_SEC}s ($(python3 -c "print(round($START_SEC/60,1))") min)" +echo " artifacts: $RUN_DIR/" +echo " CLI timings:" +print_timings "$RUN_DIR/start.json" +echo "==============================================" diff --git a/script/d1-import-test/run-local-import.sh b/script/d1-import-test/run-local-import.sh new file mode 100755 index 000000000..6b2bcaa60 --- /dev/null +++ b/script/d1-import-test/run-local-import.sh @@ -0,0 +1,187 @@ +#!/usr/bin/env bash +# Provision a fresh PlanetScale Postgres DB, then run the CLI import test. +# +# By default each run targets a new empty database (FRESH_DB=new). +# For CLI-only testing against an already-provisioned DB, use run-cli-import.sh. +# +# Usage: +# IMPORT_PROFILE=9gb ./script/d1-import-test/run-local-import.sh +# FRESH_DB=recreate PSCALE_DB=cf-d1-import-9gb ./script/d1-import-test/run-local-import.sh +# FRESH_DB=reuse SKIP_DB_CREATE=true ./script/d1-import-test/run-cli-import.sh +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +CLI="$(cd "$ROOT/../.." && pwd)" +MONOREPO_ROOT="$(cd "$CLI/../.." && pwd)" + +ORG="${PSCALE_ORG:-bb}" +BRANCH="${PSCALE_BRANCH:-main}" +REGION="${PSCALE_REGION:-dev-aws-us-east-1-1}" +CLUSTER_SIZE="${PSCALE_CLUSTER_SIZE:-PS_10}" +PG_MAJOR_VERSION="${PSCALE_PG_MAJOR_VERSION:-17}" +REPLICAS="${PSCALE_REPLICAS:-0}" +API_URL="${PSCALE_API_URL:-http://api.pscaledev.com:3000/v1}" +AUTH_URL="${PSCALE_AUTH_URL:-http://auth.pscaledev.com:3000}" +FRESH_DB="${FRESH_DB:-new}" +AUTO_EXPORT="${AUTO_EXPORT:-false}" + +case "${IMPORT_PROFILE:-smoke}" in + smoke) + BASE_DB="${PSCALE_DB:-d1-import-test}" + EXPORT="${D1_EXPORT:-/tmp/import-test-export.sql}" + ;; + 1gb) + BASE_DB="${PSCALE_DB:-cf-d1-import-1gb}" + EXPORT="${D1_EXPORT:-/tmp/import-test-1gb-export.sql}" + ;; + 5gb) + BASE_DB="${PSCALE_DB:-cf-d1-import-5gb}" + EXPORT="${D1_EXPORT:-/tmp/import-test-5gb-export.sql}" + ;; + 9gb) + BASE_DB="${PSCALE_DB:-cf-d1-import-9gb}" + EXPORT="${D1_EXPORT:-/tmp/import-test-9gb-export.sql}" + ;; + *) + echo "ERROR: unknown IMPORT_PROFILE=${IMPORT_PROFILE:-} (use smoke, 1gb, 5gb, or 9gb)" >&2 + exit 1 + ;; +esac + +export PSCALE_DISABLE_DEV_WARNING=true +export PSCALE_TEST_MODE=1 +export PSCALE_ORG="$ORG" +export PSCALE_BRANCH="$BRANCH" +export D1_EXPORT="$EXPORT" + +PSCALE="${CLI}/pscale-test" +if [[ ! -x "$PSCALE" ]]; then + (cd "$CLI" && go build -o pscale-test ./cmd/pscale) +fi + +pscale_cmd() { + "$PSCALE" --api-url "$API_URL" "$@" --org "$ORG" +} + +db_cmd() { + "$PSCALE" --api-url "$API_URL" database "$@" --org "$ORG" +} + +db_exists() { + db_cmd show "$1" --format json >/dev/null 2>&1 +} + +region_pskube_alias() { + case "$1" in + dev-aws-us-east-1-1) echo dev-aws-fatih-useast1 ;; + dev-aws-us-east-1-2) echo dev-aws-noonan-useast1 ;; + dev-aws-us-east-1-3) echo dev-aws-shared-useast1 ;; + dev-aws-us-east-2-1) echo dev-aws-mdlayher-useast2 ;; + dev-aws-us-east-2-2) echo dev-aws-orch1-useast2 ;; + dev-aws-us-east-2-3) echo dev-aws-orch2-useast2 ;; + dev-aws-us-west-2-1) echo dev-aws-fatih-uswest2 ;; + dev-aws-us-west-2-4) echo dev-aws-shared-uswest2 ;; + dev-aws-us-west-2-5) echo dev-aws-rcrowley-uswest2 ;; + dev-aws-us-west-2-6) echo dev-aws-orch1-uswest2 ;; + dev-aws-eu-central-1-1) echo dev-aws-fatih-eucentral1 ;; + dev-aws-eu-central-1-2) echo dev-aws-amir-eucentral1 ;; + dev-aws-eu-west-2-1) echo dev-aws-shared-euwest2 ;; + dev-aws-eu-west-2-2) echo dev-aws-shared2-euwest2 ;; + dev-gcp-us-east4-1) echo dev-gcp-mdlayher-useast4 ;; + *) echo "$1" ;; + esac +} + +check_branch_pskube() { + local branch_id="$1" + local alias + alias="$(region_pskube_alias "$REGION")" + echo "==> pskube branch status (alias: $alias, branch: $branch_id)" + if ! command -v pskube >/dev/null 2>&1; then + echo "pskube not installed; skip k8s status" + return 0 + fi + pskube "$alias" get horizonclusters.hzdb.co -n hz-data "hzc-${branch_id}" -o wide 2>/dev/null || true + pskube "$alias" get horizoninstances.hzdb.co -n hz-data -l "hzdb.co/branch=${branch_id}" -o wide 2>/dev/null || true +} + +provision_database() { + local name="$1" + echo "==> Creating Postgres database $name (org: $ORG, region: $REGION, size: $CLUSTER_SIZE)" + db_cmd create "$name" \ + --engine postgresql \ + --region "$REGION" \ + --cluster-size "$CLUSTER_SIZE" \ + --replicas "$REPLICAS" \ + --major-version "$PG_MAJOR_VERSION" \ + --wait + + local branch_id + branch_id="$(pscale_cmd branch show "$name" "$BRANCH" --format json | python3 -c "import json,sys; print(json.load(sys.stdin)['id'])")" + check_branch_pskube "$branch_id" +} + +echo "==> Checking pscale auth" +if ! "$PSCALE" --api-url "$API_URL" auth check >/dev/null 2>&1; then + echo "Run interactively first:" + echo " pscale auth login --api-url $AUTH_URL" + exit 1 +fi + +echo "==> Checking singularity" +if ! curl -sS --connect-timeout 2 -o /dev/null http://127.0.0.1:8080/ 2>/dev/null; then + echo "Singularity not responding on :8080. Restart with:" + echo " cd $MONOREPO_ROOT && nix develop -c process-compose process restart singularity -p 8181 --address localhost" + exit 1 +fi + +if [[ ! -f "$EXPORT" ]]; then + if [[ "$AUTO_EXPORT" == "true" ]]; then + echo "==> Exporting D1 import-test -> $EXPORT" + wrangler d1 export import-test --remote --output "$EXPORT" + else + echo "ERROR: export not found: $EXPORT" >&2 + exit 1 + fi +fi + +DB_NAME="$BASE_DB" +if [[ -n "${PSCALE_DB:-}" ]]; then + DB_NAME="$PSCALE_DB" + echo "==> Using pre-provisioned database: $DB_NAME" +elif [[ "${SKIP_DB_CREATE:-false}" == "true" ]]; then + echo "ERROR: SKIP_DB_CREATE=true but PSCALE_DB is not set" >&2 + exit 1 +else +case "$FRESH_DB" in + new) + DB_NAME="${BASE_DB}-$(date +%Y%m%d-%H%M%S)" + echo "==> Fresh database (FRESH_DB=new): $DB_NAME" + provision_database "$DB_NAME" + ;; + recreate) + DB_NAME="$BASE_DB" + echo "==> Recreating database (FRESH_DB=recreate): $DB_NAME" + if db_exists "$DB_NAME"; then + echo " deleting existing database" + db_cmd delete "$DB_NAME" --force + fi + provision_database "$DB_NAME" + ;; + reuse) + DB_NAME="$BASE_DB" + echo "==> Reusing database (FRESH_DB=reuse): $DB_NAME" + echo " import will DROP SCHEMA public CASCADE before applying DDL" + if ! db_exists "$DB_NAME"; then + provision_database "$DB_NAME" + fi + ;; + *) + echo "ERROR: unknown FRESH_DB=$FRESH_DB (use new, recreate, or reuse)" >&2 + exit 1 + ;; +esac +fi + +export PSCALE_DB="$DB_NAME" +exec "$ROOT/run-cli-import.sh" diff --git a/script/d1-import-test/run-size-benchmark.sh b/script/d1-import-test/run-size-benchmark.sh new file mode 100755 index 000000000..649ada872 --- /dev/null +++ b/script/d1-import-test/run-size-benchmark.sh @@ -0,0 +1,107 @@ +#!/usr/bin/env bash +# Import benchmark: 1 GB and 5 GB from local SQL exports, 9 GB from existing export. +# Parallelizes export generation and DB provisioning; imports run when both are ready. +# +# Usage: +# ./script/d1-import-test/run-size-benchmark.sh +# BUILD_LOCAL=false ./script/d1-import-test/run-size-benchmark.sh +# SIZES="1 9" PARALLEL_IMPORTS=false ./script/d1-import-test/run-size-benchmark.sh +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +CLI="$(cd "$ROOT/../.." && pwd)" +MONOREPO_ROOT="$(cd "$CLI/../.." && pwd)" +LOG="${D1_BENCHMARK_LOG:-/tmp/d1-size-benchmark.log}" +STATE_DIR="${BENCH_STATE_DIR:-/tmp/d1-bench-$(date +%Y%m%d-%H%M%S)}" +BUILD_LOCAL="${BUILD_LOCAL:-true}" +PARALLEL_IMPORTS="${PARALLEL_IMPORTS:-true}" +SIZES=(${SIZES:-1 5 9}) + +mkdir -p "$STATE_DIR" +exec > >(tee -a "$LOG") 2>&1 + +echo "==============================================" +echo "D1 import size benchmark $(date -u +%Y-%m-%dT%H:%M:%SZ)" +echo "Log: $LOG" +echo "State dir: $STATE_DIR" +echo "Sizes: ${SIZES[*]}" +echo "==============================================" + +if [[ ! -x "$CLI/pscale-test" ]]; then + echo "==> Building pscale-test" + (cd "$CLI" && go build -o pscale-test ./cmd/pscale) +fi + +export PSCALE_DISABLE_DEV_WARNING=true +export PSCALE_TEST_MODE=1 +export PSCALE_ORG="${PSCALE_ORG:-bb}" + +PSCALE="${CLI}/pscale-test" +API_URL="${PSCALE_API_URL:-http://api.pscaledev.com:3000/v1}" +if ! "$PSCALE" --api-url "$API_URL" auth check >/dev/null 2>&1; then + echo "ERROR: pscale auth required" >&2 + exit 1 +fi +if ! curl -sS --connect-timeout 2 -o /dev/null http://127.0.0.1:8080/ 2>/dev/null; then + echo "ERROR: singularity not responding on :8080" >&2 + exit 1 +fi + +bench_ts="$(basename "$STATE_DIR" | sed 's/^d1-bench-//')" +declare -A EXPORT_FILE DB_NAME BASE_DB + +for size in "${SIZES[@]}"; do + BASE_DB["$size"]="cf-d1-import-${size}gb" + DB_NAME["$size"]="${BASE_DB[$size]}-${bench_ts}" + echo "${DB_NAME[$size]}" > "$STATE_DIR/db-${size}gb.name" + + if [[ "$size" == "9" ]]; then + EXPORT_FILE["$size"]="${D1_EXPORT_9GB:-/tmp/import-test-9gb-export.sql}" + if [[ ! -f "${EXPORT_FILE[$size]}" ]]; then + echo "ERROR: 9 GB export not found: ${EXPORT_FILE[$size]}" >&2 + exit 1 + fi + touch "$STATE_DIR/export-9gb.ready" + else + EXPORT_FILE["$size"]="/tmp/import-test-${size}gb-export.sql" + fi +done + +echo "" +echo "==> Phase 1: parallel export generation + DB provisioning" +pids=() + +for size in "${SIZES[@]}"; do + if [[ "$size" != "9" && "$BUILD_LOCAL" == "true" ]]; then + if [[ -f "${EXPORT_FILE[$size]}" && "${REGENERATE_EXPORTS:-true}" != "true" ]]; then + echo "==> [${size}gb] Reusing export ${EXPORT_FILE[$size]}" + touch "$STATE_DIR/export-${size}gb.ready" + else + ( + export SEED_DIR="$ROOT/seed/local-${size}gb" + export D1_EXPORT="${EXPORT_FILE[$size]}" + export EXPORT_READY_FILE="$STATE_DIR/export-${size}gb.ready" + "$ROOT/build-local-export.sh" "$size" + ) > "$STATE_DIR/export-${size}gb.log" 2>&1 & + pids+=("$!") + echo "==> [${size}gb] Export build started (pid $!)" + fi + elif [[ "$size" != "9" && -f "${EXPORT_FILE[$size]}" ]]; then + touch "$STATE_DIR/export-${size}gb.ready" + fi + + ( + PROVISION_READY_FILE="$STATE_DIR/provision-${size}gb.ready" \ + "$ROOT/provision-database.sh" "${DB_NAME[$size]}" + ) > "$STATE_DIR/provision-${size}gb.log" 2>&1 & + pids+=("$!") + echo "==> [${size}gb] DB provision started: ${DB_NAME[$size]} (pid $!)" +done + +echo "==> Phase 1 jobs launched; starting import watcher (exports + branch readiness)" +export BENCH_STATE_DIR="$STATE_DIR" +export BENCH_WATCH_LOG="$STATE_DIR/watch.log" +export SIZES="${SIZES[*]}" +chmod +x "$ROOT/bench-watch-imports.sh" +"$ROOT/bench-watch-imports.sh" +exit $? diff --git a/script/d1-import-test/run-storage-benchmark.sh b/script/d1-import-test/run-storage-benchmark.sh new file mode 100755 index 000000000..01202be52 --- /dev/null +++ b/script/d1-import-test/run-storage-benchmark.sh @@ -0,0 +1,118 @@ +#!/usr/bin/env bash +# Full 1/5/9 GB storage benchmark: build exports, provision DBs, import, collect results. +# SEED_TARGET_GB = Postgres logical blob storage (sum of attachment payload bytes). +# +# Usage: +# ./script/d1-import-test/run-storage-benchmark.sh +# BENCH_STATE_DIR=/tmp/d1-bench-manual ./script/d1-import-test/run-storage-benchmark.sh +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +CLI="$(cd "$ROOT/../.." && pwd)" +STATE_DIR="${BENCH_STATE_DIR:-/tmp/d1-bench-$(date +%Y%m%d-%H%M%S)}" +LOG="${D1_BENCHMARK_LOG:-/tmp/d1-storage-benchmark.log}" +SIZES=(${SIZES:-1 5 9}) +bench_ts="$(basename "$STATE_DIR" | sed 's/^d1-bench-//')" + +mkdir -p "$STATE_DIR" +exec >> "$LOG" 2>&1 + +echo "==============================================" +echo "Storage benchmark start $(date -u +%Y-%m-%dT%H:%M:%SZ)" +echo "State: $STATE_DIR" +echo "Sizes: ${SIZES[*]}" +echo "==============================================" + +export PSCALE_DISABLE_DEV_WARNING=true +export PSCALE_TEST_MODE=1 +export PSCALE_ORG="${PSCALE_ORG:-bb}" + +if [[ ! -x "$CLI/pscale-test" ]]; then + (cd "$CLI" && go build -o pscale-test ./cmd/pscale) +fi + +PSCALE="${CLI}/pscale-test" +API_URL="${PSCALE_API_URL:-http://api.pscaledev.com:3000/v1}" +if ! "$PSCALE" --api-url "$API_URL" auth check >/dev/null 2>&1; then + echo "ERROR: pscale auth required" >&2 + exit 1 +fi + +declare -A EXPORT_FILE DB_NAME +for size in "${SIZES[@]}"; do + DB_NAME["$size"]="cf-d1-import-${size}gb-${bench_ts}" + echo "${DB_NAME[$size]}" > "$STATE_DIR/db-${size}gb.name" + if [[ "$size" == "9" ]]; then + EXPORT_FILE["$size"]="${D1_EXPORT_9GB:-/tmp/import-test-9gb-export.sql}" + if [[ ! -f "${EXPORT_FILE[$size]}" ]]; then + echo "ERROR: 9 GB export missing: ${EXPORT_FILE[$size]}" >&2 + exit 1 + fi + touch "$STATE_DIR/export-9gb.ready" + else + EXPORT_FILE["$size"]="/tmp/import-test-${size}gb-export.sql" + fi +done + +echo "" +echo "==> Phase 1: build 1GB and 5GB exports (Postgres storage target)" +pids=() +for size in "${SIZES[@]}"; do + [[ "$size" == "9" ]] && continue + ( + export SEED_DIR="$ROOT/seed/local-${size}gb" + export D1_EXPORT="${EXPORT_FILE[$size]}" + export EXPORT_READY_FILE="$STATE_DIR/export-${size}gb.ready" + export REGENERATE_SEED=true + "$ROOT/build-local-export.sh" "$size" + ) > "$STATE_DIR/export-${size}gb.log" 2>&1 & + pids+=("$!") + echo "export ${size}gb pid $!" +done + +fail=0 +for pid in "${pids[@]}"; do + wait "$pid" || fail=1 +done +if [[ "$fail" -ne 0 ]]; then + echo "ERROR: export build failed; see $STATE_DIR/export-*.log" >&2 + exit 1 +fi + +echo "" +echo "==> Phase 2: provision Postgres databases" +pids=() +for size in "${SIZES[@]}"; do + ( + PROVISION_READY_FILE="$STATE_DIR/provision-${size}gb.ready" \ + "$ROOT/provision-database.sh" "${DB_NAME[$size]}" + ) > "$STATE_DIR/provision-${size}gb.log" 2>&1 & + pids+=("$!") + echo "provision ${size}gb: ${DB_NAME[$size]} pid $!" +done +fail=0 +for pid in "${pids[@]}"; do + wait "$pid" || fail=1 +done +if [[ "$fail" -ne 0 ]]; then + echo "ERROR: provision failed; see $STATE_DIR/provision-*.log" >&2 + exit 1 +fi + +echo "" +echo "==> Phase 3: imports" +export BENCH_STATE_DIR="$STATE_DIR" +export BENCH_WATCH_LOG="$STATE_DIR/watch.log" +export SIZES="${SIZES[*]}" +export POLL_SEC="${POLL_SEC:-15}" +"$ROOT/bench-watch-imports.sh" + +echo "" +echo "==> Phase 4: collect results" +"$ROOT/collect-benchmark-results.sh" "$STATE_DIR" + +echo "" +echo "==============================================" +echo "Storage benchmark complete $(date -u +%Y-%m-%dT%H:%M:%SZ)" +echo "Report: $STATE_DIR/report.txt" +echo "==============================================" diff --git a/script/d1-import-test/schema.sql b/script/d1-import-test/schema.sql new file mode 100644 index 000000000..f2caed1f1 --- /dev/null +++ b/script/d1-import-test/schema.sql @@ -0,0 +1,362 @@ +-- Stress schema for pscale import d1 testing (import-test D1 database). +-- Exercises: autoincrement PKs, 0/1 booleans, TEXT timestamps, JSON-in-TEXT, +-- multi-level FKs, junction tables, table-level FKs, self-referential FKs, +-- REAL columns, BLOB columns, CHECK constraints, indexes, __drizzle_migrations. +-- Note: FTS5 virtual tables are omitted — they break ParseDump on export. + +PRAGMA foreign_keys = ON; + +CREATE TABLE __drizzle_migrations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + hash TEXT NOT NULL, + created_at INTEGER NOT NULL +); + +CREATE TABLE _prisma_migrations ( + id TEXT PRIMARY KEY, + checksum TEXT NOT NULL, + finished_at TEXT, + migration_name TEXT NOT NULL, + logs TEXT, + rolled_back_at TEXT, + started_at TEXT NOT NULL, + applied_steps_count INTEGER NOT NULL +); + +CREATE TABLE knex_migrations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT, + batch INTEGER, + migration_time INTEGER +); + +CREATE TABLE knex_migrations_lock ( + "index" INTEGER PRIMARY KEY, + is_locked INTEGER +); + +CREATE TABLE sequelizemeta ( + name TEXT PRIMARY KEY +); + +CREATE TABLE schema_migrations ( + version TEXT PRIMARY KEY +); + +CREATE TABLE ar_internal_metadata ( + key TEXT PRIMARY KEY, + value TEXT, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL +); + +CREATE TABLE flyway_schema_history ( + installed_rank INTEGER PRIMARY KEY, + version TEXT, + description TEXT NOT NULL, + type TEXT NOT NULL, + script TEXT NOT NULL, + checksum INTEGER, + installed_by TEXT NOT NULL, + installed_on TEXT NOT NULL DEFAULT (datetime('now')), + execution_time INTEGER NOT NULL, + success INTEGER NOT NULL +); + +CREATE TABLE databasechangelog ( + id TEXT PRIMARY KEY, + author TEXT NOT NULL, + filename TEXT NOT NULL, + dateexecuted TEXT NOT NULL, + orderexecuted INTEGER NOT NULL, + exectype TEXT NOT NULL, + md5sum TEXT, + description TEXT, + comments TEXT, + tag TEXT, + liquibase TEXT, + contexts TEXT, + labels TEXT, + deployment_id TEXT +); + +CREATE TABLE databasechangeloglock ( + id INTEGER PRIMARY KEY, + locked INTEGER NOT NULL, + lockgranted TEXT, + lockedby TEXT +); + +CREATE TABLE django_migrations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + app TEXT NOT NULL, + name TEXT NOT NULL, + applied TEXT NOT NULL +); + +CREATE TABLE alembic_version ( + version_num TEXT PRIMARY KEY +); + +CREATE TABLE typeorm_metadata ( + type TEXT NOT NULL, + "database" TEXT, + "schema" TEXT, + "table" TEXT, + name TEXT, + value TEXT +); + +CREATE TABLE goose_db_version ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + version_id INTEGER NOT NULL, + is_applied INTEGER NOT NULL, + tstamp TEXT DEFAULT (datetime('now')) +); + +CREATE TABLE organizations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + slug TEXT NOT NULL UNIQUE, + name TEXT NOT NULL, + plan TEXT NOT NULL DEFAULT 'free', + is_active INTEGER NOT NULL DEFAULT 1, + settings TEXT, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + deleted_at TEXT +); + +CREATE TABLE users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + org_id INTEGER NOT NULL, + email TEXT NOT NULL, + display_name TEXT NOT NULL, + role TEXT NOT NULL DEFAULT 'member', + is_active INTEGER NOT NULL DEFAULT 1, + is_admin INTEGER NOT NULL DEFAULT 0, + profile_json TEXT, + last_login_at TEXT, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + FOREIGN KEY (org_id) REFERENCES organizations(id) +); + +CREATE TABLE teams ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + org_id INTEGER NOT NULL, + name TEXT NOT NULL, + is_archived INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL, + FOREIGN KEY (org_id) REFERENCES organizations(id) +); + +CREATE TABLE team_members ( + team_id INTEGER NOT NULL, + user_id INTEGER NOT NULL, + role TEXT NOT NULL DEFAULT 'member', + joined_at TEXT NOT NULL, + PRIMARY KEY (team_id, user_id), + FOREIGN KEY (team_id) REFERENCES teams(id), + FOREIGN KEY (user_id) REFERENCES users(id) +); + +CREATE TABLE projects ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + org_id INTEGER NOT NULL, + owner_user_id INTEGER NOT NULL, + team_id INTEGER, + slug TEXT NOT NULL, + name TEXT NOT NULL, + description TEXT, + is_public INTEGER NOT NULL DEFAULT 0, + is_archived INTEGER NOT NULL DEFAULT 0, + metadata TEXT, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + FOREIGN KEY (org_id) REFERENCES organizations(id), + FOREIGN KEY (owner_user_id) REFERENCES users(id), + FOREIGN KEY (team_id) REFERENCES teams(id), + UNIQUE (org_id, slug) +); + +CREATE TABLE tags ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + org_id INTEGER NOT NULL, + label TEXT NOT NULL, + color TEXT NOT NULL DEFAULT '#6366f1', + FOREIGN KEY (org_id) REFERENCES organizations(id), + UNIQUE (org_id, label) +); + +CREATE TABLE project_tags ( + project_id INTEGER NOT NULL, + tag_id INTEGER NOT NULL, + PRIMARY KEY (project_id, tag_id), + FOREIGN KEY (project_id) REFERENCES projects(id), + FOREIGN KEY (tag_id) REFERENCES tags(id) +); + +CREATE TABLE tasks ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + project_id INTEGER NOT NULL, + assignee_user_id INTEGER, + parent_task_id INTEGER, + title TEXT NOT NULL, + body TEXT, + status TEXT NOT NULL DEFAULT 'open' CHECK (status IN ('open', 'in_progress', 'done', 'cancelled')), + priority INTEGER NOT NULL DEFAULT 2 CHECK (priority BETWEEN 1 AND 5), + estimate_hours REAL, + is_blocked INTEGER NOT NULL DEFAULT 0, + labels_json TEXT, + due_at TEXT, + completed_at TEXT, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + FOREIGN KEY (project_id) REFERENCES projects(id), + FOREIGN KEY (assignee_user_id) REFERENCES users(id), + FOREIGN KEY (parent_task_id) REFERENCES tasks(id) +); + +CREATE TABLE task_dependencies ( + task_id INTEGER NOT NULL, + depends_on_task_id INTEGER NOT NULL, + PRIMARY KEY (task_id, depends_on_task_id), + FOREIGN KEY (task_id) REFERENCES tasks(id), + FOREIGN KEY (depends_on_task_id) REFERENCES tasks(id) +); + +CREATE TABLE comments ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + task_id INTEGER NOT NULL, + author_user_id INTEGER NOT NULL, + parent_comment_id INTEGER, + body TEXT NOT NULL, + is_edited INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + FOREIGN KEY (task_id) REFERENCES tasks(id), + FOREIGN KEY (author_user_id) REFERENCES users(id), + FOREIGN KEY (parent_comment_id) REFERENCES comments(id) +); + +CREATE TABLE attachments ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + comment_id INTEGER NOT NULL, + filename TEXT NOT NULL, + content_type TEXT NOT NULL, + byte_size INTEGER NOT NULL, + checksum TEXT NOT NULL, + payload BLOB, + uploaded_at TEXT NOT NULL, + FOREIGN KEY (comment_id) REFERENCES comments(id) +); + +CREATE TABLE audit_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + org_id INTEGER NOT NULL, + actor_user_id INTEGER, + action TEXT NOT NULL, + entity_type TEXT NOT NULL, + entity_id INTEGER NOT NULL, + payload TEXT, + created_at TEXT NOT NULL, + FOREIGN KEY (org_id) REFERENCES organizations(id), + FOREIGN KEY (actor_user_id) REFERENCES users(id) +); + +CREATE TABLE api_keys ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + name TEXT NOT NULL, + key_prefix TEXT NOT NULL, + is_active INTEGER NOT NULL DEFAULT 1, + scopes_json TEXT NOT NULL, + expires_at TEXT, + created_at TEXT NOT NULL, + FOREIGN KEY (user_id) REFERENCES users(id) +); + +CREATE TABLE sessions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + token_hash TEXT NOT NULL UNIQUE, + ip_address TEXT, + user_agent TEXT, + last_seen_at TEXT NOT NULL, + expires_at TEXT NOT NULL, + created_at TEXT NOT NULL, + FOREIGN KEY (user_id) REFERENCES users(id) +); + +CREATE TABLE notifications ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + kind TEXT NOT NULL, + title TEXT NOT NULL, + body TEXT, + payload TEXT, + is_read INTEGER NOT NULL DEFAULT 0, + read_at TEXT, + created_at TEXT NOT NULL, + FOREIGN KEY (user_id) REFERENCES users(id) +); + +CREATE TABLE invoices ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + org_id INTEGER NOT NULL, + invoice_number TEXT NOT NULL UNIQUE, + subtotal REAL NOT NULL, + tax REAL NOT NULL DEFAULT 0, + total REAL NOT NULL, + status TEXT NOT NULL DEFAULT 'draft', + issued_at TEXT, + due_at TEXT, + paid_at TEXT, + created_at TEXT NOT NULL, + FOREIGN KEY (org_id) REFERENCES organizations(id) +); + +CREATE TABLE invoice_line_items ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + invoice_id INTEGER NOT NULL, + task_id INTEGER, + description TEXT NOT NULL, + quantity REAL NOT NULL DEFAULT 1, + unit_price REAL NOT NULL, + amount REAL NOT NULL, + FOREIGN KEY (invoice_id) REFERENCES invoices(id), + FOREIGN KEY (task_id) REFERENCES tasks(id) +); + +CREATE TABLE external_entities ( + id TEXT PRIMARY KEY, + org_id INTEGER NOT NULL, + name TEXT NOT NULL, + kind TEXT NOT NULL DEFAULT 'webhook', + metadata TEXT, + created_at TEXT NOT NULL, + FOREIGN KEY (org_id) REFERENCES organizations(id) +); + +CREATE TABLE entity_links ( + entity_id TEXT NOT NULL, + task_id INTEGER NOT NULL, + linked_at TEXT NOT NULL, + PRIMARY KEY (entity_id, task_id), + FOREIGN KEY (entity_id) REFERENCES external_entities(id), + FOREIGN KEY (task_id) REFERENCES tasks(id) +); + +CREATE INDEX idx_users_org_id ON users(org_id); +CREATE INDEX idx_users_email ON users(email); +CREATE INDEX idx_teams_org_id ON teams(org_id); +CREATE INDEX idx_projects_org_id ON projects(org_id); +CREATE INDEX idx_projects_owner ON projects(owner_user_id); +CREATE INDEX idx_tasks_project_id ON tasks(project_id); +CREATE INDEX idx_tasks_assignee ON tasks(assignee_user_id); +CREATE INDEX idx_tasks_status ON tasks(status); +CREATE INDEX idx_comments_task_id ON comments(task_id); +CREATE INDEX idx_audit_log_org_created ON audit_log(org_id, created_at); +CREATE INDEX idx_notifications_user_unread ON notifications(user_id, is_read); +CREATE INDEX idx_external_entities_org_id ON external_entities(org_id); +CREATE UNIQUE INDEX idx_external_entities_org_name ON external_entities(org_id, name); diff --git a/script/d1-import-test/time-export.sh b/script/d1-import-test/time-export.sh new file mode 100755 index 000000000..ab4cfebe8 --- /dev/null +++ b/script/d1-import-test/time-export.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash +# Time wrangler d1 export from remote D1 to a local SQL file. +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +DB="${D1_DATABASE:-import-test}" +OUTPUT="${D1_EXPORT_OUTPUT:-/tmp/import-test-export.sql}" +REMOTE="${D1_REMOTE:-true}" + +remote_flag=() +if [[ "$REMOTE" == "true" ]]; then + remote_flag=(--remote) +fi + +echo "==> D1 database: $DB" +echo "==> Output file: $OUTPUT" +echo "==> Remote size (wrangler d1 list):" +wrangler d1 list 2>/dev/null | grep -E "$DB|file_size" || wrangler d1 list + +echo "" +echo "==> Starting export at $(date -u +%Y-%m-%dT%H:%M:%SZ)" +start_ns=$(python3 -c "import time; print(time.time_ns())") + +wrangler d1 export "$DB" "${remote_flag[@]}" --output "$OUTPUT" + +end_ns=$(python3 -c "import time; print(time.time_ns())") +elapsed_sec=$(python3 -c "print(round(($end_ns - $start_ns) / 1e9, 2))") + +bytes=$(stat -f%z "$OUTPUT" 2>/dev/null || stat -c%s "$OUTPUT") +mb=$(python3 -c "print(round($bytes / (1024*1024), 2))") +gb=$(python3 -c "print(round($bytes / (1024**3), 3))") +throughput=$(python3 -c "print(round($bytes / (1024*1024) / $elapsed_sec, 2)) if $elapsed_sec > 0 else 0") + +echo "" +echo "==> Finished at $(date -u +%Y-%m-%dT%H:%M:%SZ)" +echo " Elapsed: ${elapsed_sec}s" +echo " File size: ${mb} MB (${gb} GB)" +echo " Throughput: ${throughput} MB/s" +echo " Path: $OUTPUT" diff --git a/test_import_d1.sh b/test_import_d1.sh new file mode 100755 index 000000000..fdf59115a --- /dev/null +++ b/test_import_d1.sh @@ -0,0 +1,79 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT" + +export PSCALE_TEST_MODE=1 + +FIXTURE="$ROOT/internal/migrate/d1/testdata/sample_d1_export.sql" + +echo "==> Building pscale CLI..." +go build -o ./pscale-test ./cmd/pscale + +PSCALE="./pscale-test" + +echo "==> import d1 doctor" +$PSCALE import d1 doctor --format json | tee /tmp/d1-doctor.json +grep -q '"status"' /tmp/d1-doctor.json + +echo "==> import d1 lint" +$PSCALE import d1 lint --input "$FIXTURE" --format json | tee /tmp/d1-lint.json +grep -q '"phase": "lint"' /tmp/d1-lint.json +grep -q 'AUTOINCREMENT' /tmp/d1-lint.json + +echo "==> import d1 start --dry-run (preview)" +$PSCALE import d1 start \ + --input "$FIXTURE" \ + --org testorg \ + --database testdb \ + --branch main \ + --dry-run \ + --force \ + --format json | tee /tmp/d1-preview.json +MIGRATION_ID=$(grep -o '"migration_id": "[^"]*"' /tmp/d1-preview.json | head -1 | cut -d'"' -f4) +test -n "$MIGRATION_ID" +grep -q '"dry_run": true' /tmp/d1-preview.json + +echo "==> import d1 convert-schema" +SCHEMA_OUT="$(mktemp -t d1-schema.XXXXXX.sql)" +$PSCALE import d1 convert-schema --input "$FIXTURE" --output "$SCHEMA_OUT" --format json +grep -q 'GENERATED BY DEFAULT AS IDENTITY' "$SCHEMA_OUT" + +echo "==> import d1 status" +$PSCALE import d1 status \ + --org testorg \ + --database testdb \ + --branch main \ + --migration-id "$MIGRATION_ID" \ + --format json | tee /tmp/d1-status.json +grep -q '"phase": "planned"' /tmp/d1-status.json + +echo "==> import d1 verify (source-only, no dest URI)" +$PSCALE import d1 verify \ + --input "$FIXTURE" \ + --migration-id "$MIGRATION_ID" \ + --format json | tee /tmp/d1-verify.json || true + +echo "==> import d1 start dry-run" +$PSCALE import d1 start \ + --org testorg \ + --database testdb \ + --branch main \ + --input "$FIXTURE" \ + --migration-id "$MIGRATION_ID" \ + --dry-run \ + --force \ + --format json | tee /tmp/d1-start-dryrun.json +grep -q '"dry_run": true' /tmp/d1-start-dryrun.json + +echo "==> import d1 complete" +$PSCALE import d1 complete \ + --org testorg \ + --database testdb \ + --branch main \ + --migration-id "$MIGRATION_ID" \ + --force \ + --format json + +echo "==> All import d1 smoke tests passed" From e5fc41ce7a30758f7a0eee52c7aab98c27a9f734 Mon Sep 17 00:00:00 2001 From: Elom Gomez Date: Fri, 26 Jun 2026 14:44:32 -0500 Subject: [PATCH 2/6] Fix golangci-lint string builder warnings in D1 migrate code. Use fmt.Fprintf instead of WriteString with Sprintf or concatenation to satisfy QF1012 and writestring checks. Co-authored-by: Cursor --- internal/migrate/d1/convert.go | 6 +++--- internal/migrate/d1/import.go | 2 +- internal/migrate/d1/pgloader.go | 26 +++++++++++++------------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/internal/migrate/d1/convert.go b/internal/migrate/d1/convert.go index 51dd1ffbe..42ef44d09 100644 --- a/internal/migrate/d1/convert.go +++ b/internal/migrate/d1/convert.go @@ -32,7 +32,7 @@ func ConvertSchemaParts(inputPath string) (SchemaParts, int, error) { var tableBuf strings.Builder tableBuf.WriteString("-- Generated by pscale import d1 convert-schema (tables)\n") - tableBuf.WriteString("-- Source: " + inputPath + "\n\n") + fmt.Fprintf(&tableBuf, "-- Source: %s\n\n", inputPath) converted := 0 tableByName := make(map[string]TableSchema, len(tables)) @@ -55,7 +55,7 @@ func ConvertSchemaParts(inputPath string) (SchemaParts, int, error) { var indexBuf strings.Builder if len(indexes) > 0 { indexBuf.WriteString("-- Generated by pscale import d1 convert-schema (indexes)\n") - indexBuf.WriteString("-- Source: " + inputPath + "\n\n") + fmt.Fprintf(&indexBuf, "-- Source: %s\n\n", inputPath) indexBuf.WriteString("-- Indexes\n") for _, idx := range indexes { if IsORMMetadataTable(idx.Table) { @@ -95,7 +95,7 @@ func ConvertSchema(inputPath, outputPath string) (int, error) { func convertTableDDL(table TableSchema, all []TableSchema) string { var b strings.Builder - b.WriteString(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (\n", quoteIdent(table.Name))) + fmt.Fprintf(&b, "CREATE TABLE IF NOT EXISTS %s (\n", quoteIdent(table.Name)) var lines []string for _, col := range table.Columns { diff --git a/internal/migrate/d1/import.go b/internal/migrate/d1/import.go index dc3991616..ee4c47ccb 100644 --- a/internal/migrate/d1/import.go +++ b/internal/migrate/d1/import.go @@ -357,7 +357,7 @@ func applyPostgresIndexes(ctx context.Context, opts ImportOptions, destURI strin var b strings.Builder b.WriteString("-- Generated by pscale import d1 (post-load indexes)\n") - b.WriteString(fmt.Sprintf("SET maintenance_work_mem TO '%s';\n", pgloaderIndexMaintenanceWorkMem)) + fmt.Fprintf(&b, "SET maintenance_work_mem TO '%s';\n", pgloaderIndexMaintenanceWorkMem) b.WriteString(parts.Indexes) indexPath := filepath.Join(workDir, fmt.Sprintf("postgres-indexes-%s.sql", opts.MigrationID)) diff --git a/internal/migrate/d1/pgloader.go b/internal/migrate/d1/pgloader.go index 7f1d2f74e..84ff0cc62 100644 --- a/internal/migrate/d1/pgloader.go +++ b/internal/migrate/d1/pgloader.go @@ -225,8 +225,8 @@ func buildPgloaderScript(sqlitePath, destURI string, cfg pgloaderScriptConfig, c var b strings.Builder b.WriteString("LOAD DATABASE\n") - b.WriteString(" FROM " + src + "\n") - b.WriteString(" INTO " + target + "\n") + fmt.Fprintf(&b, " FROM %s\n", src) + fmt.Fprintf(&b, " INTO %s\n", target) b.WriteString("\n") if cfg.dataOnly { @@ -236,28 +236,28 @@ func buildPgloaderScript(sqlitePath, destURI string, cfg pgloaderScriptConfig, c } else { b.WriteString(" reset no sequences,\n") } - b.WriteString(fmt.Sprintf(" workers = %d, concurrency = %d,\n", profile.workers, profile.concurrency)) - b.WriteString(fmt.Sprintf(" batch rows = %d,\n", profile.batchRows)) - b.WriteString(" batch size = " + pgloaderBatchSize + ",\n") - b.WriteString(fmt.Sprintf(" prefetch rows = %d\n", profile.prefetchRows)) + fmt.Fprintf(&b, " workers = %d, concurrency = %d,\n", profile.workers, profile.concurrency) + fmt.Fprintf(&b, " batch rows = %d,\n", profile.batchRows) + fmt.Fprintf(&b, " batch size = %s,\n", pgloaderBatchSize) + fmt.Fprintf(&b, " prefetch rows = %d\n", profile.prefetchRows) } else { b.WriteString(" WITH include drop, create tables, create indexes, reset sequences,\n") - b.WriteString(fmt.Sprintf(" workers = %d, concurrency = %d,\n", profile.workers, profile.concurrency)) - b.WriteString(fmt.Sprintf(" batch rows = %d,\n", profile.batchRows)) - b.WriteString(" batch size = " + pgloaderBatchSize + ",\n") - b.WriteString(fmt.Sprintf(" prefetch rows = %d\n", profile.prefetchRows)) + fmt.Fprintf(&b, " workers = %d, concurrency = %d,\n", profile.workers, profile.concurrency) + fmt.Fprintf(&b, " batch rows = %d,\n", profile.batchRows) + fmt.Fprintf(&b, " batch size = %s,\n", pgloaderBatchSize) + fmt.Fprintf(&b, " prefetch rows = %d\n", profile.prefetchRows) } if cfg.tableName != "" { b.WriteString("\n") - b.WriteString(" INCLUDING ONLY TABLE NAMES LIKE " + pgloaderQuotePattern(cfg.tableName) + "\n") + fmt.Fprintf(&b, " INCLUDING ONLY TABLE NAMES LIKE %s\n", pgloaderQuotePattern(cfg.tableName)) } appendPgloaderCasts(&b, castTables, allTables) b.WriteString("\n") - b.WriteString(fmt.Sprintf(" SET work_mem to '%s', maintenance_work_mem to '%s', synchronous_commit to 'off';\n", - pgloaderLoadWorkMem, pgloaderLoadMaintenanceWorkMem)) + fmt.Fprintf(&b, " SET work_mem to '%s', maintenance_work_mem to '%s', synchronous_commit to 'off';\n", + pgloaderLoadWorkMem, pgloaderLoadMaintenanceWorkMem) return b.String() } From a2c54bcce7662033c8e287f1e4a3fff8faf69929 Mon Sep 17 00:00:00 2001 From: Elom Gomez Date: Fri, 26 Jun 2026 14:50:01 -0500 Subject: [PATCH 3/6] Fix bugbot issues --- internal/migrate/d1/prepare.go | 10 +----- internal/migrate/d1/prepare_test.go | 40 +++++++++++++++++++++++ internal/migrate/d1/verify_checks.go | 9 ++++- internal/migrate/d1/verify_checks_test.go | 24 ++++++++++++++ 4 files changed, 73 insertions(+), 10 deletions(-) diff --git a/internal/migrate/d1/prepare.go b/internal/migrate/d1/prepare.go index 691e49ee5..0591f3fa4 100644 --- a/internal/migrate/d1/prepare.go +++ b/internal/migrate/d1/prepare.go @@ -76,15 +76,7 @@ func resolvePlan(opts ImportOptions, method string, lint *LintResult) (*PlanResu state, err := LoadState(opts.Org, opts.Database, opts.Branch, opts.MigrationID) if err != nil { - return createAndSavePlan(PlanOptions{ - InputPath: opts.InputPath, - Org: opts.Org, - Database: opts.Database, - Branch: opts.Branch, - Method: method, - MigrationID: opts.MigrationID, - Lint: lint, - }) + return nil, err } if opts.InputPath != "" && state.InputPath != "" && state.InputPath != opts.InputPath { diff --git a/internal/migrate/d1/prepare_test.go b/internal/migrate/d1/prepare_test.go index 5ed22f08a..1550b55a5 100644 --- a/internal/migrate/d1/prepare_test.go +++ b/internal/migrate/d1/prepare_test.go @@ -3,6 +3,7 @@ package d1 import ( "bytes" "context" + "os" "strings" "testing" @@ -61,6 +62,45 @@ func TestImport_BlocksOnLintErrors(t *testing.T) { } } +func TestPrepareImportRejectsMissingMigrationState(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + _, err := PrepareImport(ImportOptions{ + InputPath: testFixture(t), + Org: "acme", + Database: "mydb", + Branch: "main", + MigrationID: "missing-migration-id", + }) + requireMigrationErr(t, err, ErrCodeNotFound) +} + +func TestPrepareImportRejectsCorruptMigrationState(t *testing.T) { + t.Setenv("PSCALE_TEST_MODE", "1") + + store, err := NewStateStore() + if err != nil { + t.Fatal(err) + } + migrationID := "corrupt-migration-id" + path := store.statePath("acme", "mydb", "main", migrationID) + if err := os.WriteFile(path, []byte("{not-json"), 0o600); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = os.Remove(path) }) + + _, err = PrepareImport(ImportOptions{ + InputPath: testFixture(t), + Org: "acme", + Database: "mydb", + Branch: "main", + MigrationID: migrationID, + }) + if err == nil { + t.Fatal("expected corrupt migration state to fail") + } +} + func TestPrintStartPreview(t *testing.T) { prepared, err := PrepareImport(ImportOptions{ InputPath: testFixture(t), diff --git a/internal/migrate/d1/verify_checks.go b/internal/migrate/d1/verify_checks.go index 59ce6a8cc..f7ced4d28 100644 --- a/internal/migrate/d1/verify_checks.go +++ b/internal/migrate/d1/verify_checks.go @@ -434,6 +434,9 @@ func sqliteSignatureColumnExpr(col ColumnSchema) string { if isJSONText(col) { return fmt.Sprintf(`COALESCE(json(%q), CAST(%q AS TEXT), '')`, col.Name, col.Name) } + if isBlobColumn(col) { + return fmt.Sprintf(`COALESCE(hex(%q), '')`, col.Name) + } return fmt.Sprintf(`COALESCE(CAST(%q AS TEXT), '')`, col.Name) } @@ -451,7 +454,7 @@ func postgresSignatureColumnExpr(col ColumnSchema, table TableSchema, all []Tabl return fmt.Sprintf(`COALESCE(%s::jsonb::text, '')`, name) case "BYTEA": name := quoteIdent(col.Name) - return fmt.Sprintf(`COALESCE(convert_from(%s, 'UTF8'), '')`, name) + return fmt.Sprintf(`COALESCE(encode(%s, 'hex'), '')`, name) default: return fmt.Sprintf(`COALESCE(%s::text, '')`, quoteIdent(col.Name)) } @@ -520,6 +523,10 @@ func byteaValuesEqual(sqliteText, pgText string) bool { return false } +func isBlobColumn(col ColumnSchema) bool { + return strings.Contains(strings.ToUpper(col.Type), "BLOB") +} + func sqliteRowSignature(ctx context.Context, sqlitePath string, table TableSchema, pkCol, pkVal string) (string, error) { cols := make([]string, 0, len(table.Columns)) for _, col := range table.Columns { diff --git a/internal/migrate/d1/verify_checks_test.go b/internal/migrate/d1/verify_checks_test.go index 6b694bc24..6d74601dc 100644 --- a/internal/migrate/d1/verify_checks_test.go +++ b/internal/migrate/d1/verify_checks_test.go @@ -2,6 +2,7 @@ package d1 import ( "encoding/hex" + "strings" "testing" ) @@ -109,3 +110,26 @@ func TestByteaValuesEqual(t *testing.T) { t.Fatalf("expected bytea hex %q to match text %q", hex, text) } } + +func TestByteaSignatureExprsUseHex(t *testing.T) { + col := ColumnSchema{Name: "payload", Type: "BLOB"} + table := TableSchema{Name: "attachments", Columns: []ColumnSchema{col}} + + sqliteExpr := sqliteSignatureColumnExpr(col) + if !strings.Contains(sqliteExpr, "hex(") { + t.Fatalf("sqlite blob signature should use hex(), got %q", sqliteExpr) + } + + pgExpr := postgresSignatureColumnExpr(col, table, nil) + if !strings.Contains(pgExpr, "encode(") || !strings.Contains(pgExpr, "'hex'") { + t.Fatalf("postgres bytea signature should use encode(..., 'hex'), got %q", pgExpr) + } +} + +func TestByteaValuesEqualBinaryHex(t *testing.T) { + raw := string([]byte{0x00, 0xff, 0xfe, 0x01}) + hexSig := hex.EncodeToString([]byte(raw)) + if !byteaValuesEqual(hexSig, hexSig) { + t.Fatalf("expected matching hex signatures for binary blob") + } +} From 270bceb52eb9fb9664096f34cfacea9b147b9476 Mon Sep 17 00:00:00 2001 From: Elom Gomez Date: Fri, 26 Jun 2026 14:57:17 -0500 Subject: [PATCH 4/6] Fix D1 error output consistency and apply go fmt. Unify JSON/human exit codes for status:error responses, populate lint error envelopes, and format affected Go sources with go fmt. Co-authored-by: Cursor --- internal/cmd/importcmd/d1.go | 40 +++++++-------- internal/cmd/importcmd/d1_test.go | 55 ++++++++++++++++++++ internal/cmd/mcp/import_d1_handlers.go | 9 +--- internal/migrate/d1/constraints.go | 6 +-- internal/migrate/d1/lint.go | 2 +- internal/migrate/d1/orm_metadata.go | 54 ++++++++++---------- internal/migrate/d1/output.go | 70 ++++++++++++++++++++------ internal/migrate/d1/output_test.go | 70 ++++++++++++++++++++++++++ internal/migrate/d1/parse.go | 6 +-- internal/migrate/d1/pgloader.go | 4 +- internal/migrate/d1/plan.go | 2 +- internal/migrate/d1/types.go | 52 +++++++++---------- internal/postgres/postgres_test.go | 16 +++--- 13 files changed, 269 insertions(+), 117 deletions(-) create mode 100644 internal/cmd/importcmd/d1_test.go create mode 100644 internal/migrate/d1/output_test.go diff --git a/internal/cmd/importcmd/d1.go b/internal/cmd/importcmd/d1.go index 622638567..0b6c9e353 100644 --- a/internal/cmd/importcmd/d1.go +++ b/internal/cmd/importcmd/d1.go @@ -17,15 +17,12 @@ func writeD1(ch *cmdutil.Helper, resp d1.Response) error { if err := ch.Printer.PrintJSON(resp); err != nil { return err } - return &cmdutil.Error{ - ExitCode: cmdutil.ActionRequestedExitCode, - Printed: true, - } case printer.Human: - return humanD1Error(resp) + d1.PrintHumanResponse(ch.Printer, resp) default: return fmt.Errorf(`import d1 does not support output format %q (use human or json)`, ch.Printer.Format()) } + return d1CommandError(resp) } switch ch.Printer.Format() { @@ -39,14 +36,19 @@ func writeD1(ch *cmdutil.Helper, resp d1.Response) error { } } -func humanD1Error(resp d1.Response) error { - if resp.Error == nil { - return fmt.Errorf("import d1 command failed") +func d1CommandError(resp d1.Response) error { + msg := "import d1 command failed" + if resp.Error != nil { + msg = resp.Error.Message + if resp.Error.Remediation != "" { + msg += "\n" + resp.Error.Remediation + } } - if resp.Error.Remediation != "" { - return fmt.Errorf("%s\n%s", resp.Error.Message, resp.Error.Remediation) + return &cmdutil.Error{ + Msg: msg, + ExitCode: cmdutil.ActionRequestedExitCode, + Printed: true, } - return fmt.Errorf("%s", resp.Error.Message) } // D1Cmd returns the import d1 subcommand group. @@ -101,8 +103,8 @@ func d1ExportCmd(ch *cmdutil.Helper) *cobra.Command { } cmd := &cobra.Command{ - Use: "export", - Short: "Export a D1 database using wrangler", + Use: "export", + Short: "Export a D1 database using wrangler", Example: ` pscale import d1 export --d1-database my-app-db --remote --output ./d1-export.sql --format json`, RunE: func(cmd *cobra.Command, args []string) error { result, err := d1.Export(cmd.Context(), d1.ExportOptions{ @@ -137,21 +139,15 @@ func d1LintCmd(ch *cmdutil.Helper) *cobra.Command { } cmd := &cobra.Command{ - Use: "lint", - Short: "Analyze a D1 SQL export for migration issues", + Use: "lint", + Short: "Analyze a D1 SQL export for migration issues", Example: ` pscale import d1 lint --input ./d1-export.sql --format json`, RunE: func(cmd *cobra.Command, args []string) error { result, err := d1.Lint(flags.input) if err != nil { return writeD1(ch, d1.ErrorResponse("lint", err)) } - resp := d1.OKResponse("lint", result, d1.LintNextSteps(result)) - resp.Issues = result.Issues - if result.ErrorCount > 0 { - resp.Status = "error" - } else if result.WarningCount > 0 { - resp.Status = "warning" - } + resp := d1.LintResponse(result) return writeD1(ch, resp) }, } diff --git a/internal/cmd/importcmd/d1_test.go b/internal/cmd/importcmd/d1_test.go new file mode 100644 index 000000000..536f89c15 --- /dev/null +++ b/internal/cmd/importcmd/d1_test.go @@ -0,0 +1,55 @@ +package importcmd + +import ( + "bytes" + "errors" + "testing" + + "github.com/planetscale/cli/internal/cmdutil" + "github.com/planetscale/cli/internal/migrate/d1" + "github.com/planetscale/cli/internal/printer" +) + +func TestWriteD1ErrorUsesConsistentExitCode(t *testing.T) { + resp := d1.LintResponse(&d1.LintResult{ + TableCount: 1, + ErrorCount: 1, + Issues: []d1.Issue{{ + Code: "VIRTUAL_TABLE", + Severity: d1.SeverityError, + Table: "fts", + Remediation: "Virtual tables are not supported", + }}, + }) + + for _, format := range []printer.Format{printer.Human, printer.JSON} { + t.Run(format.String(), func(t *testing.T) { + var buf bytes.Buffer + p := printer.NewPrinter(&format) + if format == printer.Human { + p.SetHumanOutput(&buf) + } else { + p.SetResourceOutput(&buf) + } + + err := writeD1(&cmdutil.Helper{Printer: p}, resp) + if err == nil { + t.Fatal("expected error") + } + + var cmdErr *cmdutil.Error + if !errors.As(err, &cmdErr) { + t.Fatalf("expected *cmdutil.Error, got %T: %v", err, err) + } + if cmdErr.ExitCode != cmdutil.ActionRequestedExitCode { + t.Fatalf("exit code = %d, want %d", cmdErr.ExitCode, cmdutil.ActionRequestedExitCode) + } + if !cmdErr.Printed { + t.Fatal("expected output to be marked printed") + } + if buf.Len() == 0 { + t.Fatal("expected response output") + } + }) + } +} diff --git a/internal/cmd/mcp/import_d1_handlers.go b/internal/cmd/mcp/import_d1_handlers.go index bc4e25f05..2bad17f16 100644 --- a/internal/cmd/mcp/import_d1_handlers.go +++ b/internal/cmd/mcp/import_d1_handlers.go @@ -86,17 +86,10 @@ func handleImportD1Lint(ctx context.Context, request mcp.CallToolRequest, ch *cm if err != nil { return importD1Error("lint", err) } - resp := d1.OKResponse("lint", result, d1.LintNextSteps(result)) - resp.Issues = result.Issues - if result.ErrorCount > 0 { - resp.Status = "error" - } else if result.WarningCount > 0 { - resp.Status = "warning" - } + resp := d1.LintResponse(result) return importD1Result(resp) } - func handleImportD1Start(ctx context.Context, request mcp.CallToolRequest, ch *cmdutil.Helper) (*mcp.CallToolResult, error) { input, err := request.RequireString("input") if err != nil { diff --git a/internal/migrate/d1/constraints.go b/internal/migrate/d1/constraints.go index a91dfc742..ca87b8ba6 100644 --- a/internal/migrate/d1/constraints.go +++ b/internal/migrate/d1/constraints.go @@ -6,11 +6,11 @@ import ( ) var ( - referencesClauseRe = regexp.MustCompile(`(?is)^REFERENCES\s+(?:"([^"]+)"|'([^']+)'|` + "`" + `([^` + "`" + `]+)` + "`" + `|([a-zA-Z_][\w]*))\s*\(\s*([^)]+)\)\s*(.*)$`) + referencesClauseRe = regexp.MustCompile(`(?is)^REFERENCES\s+(?:"([^"]+)"|'([^']+)'|` + "`" + `([^` + "`" + `]+)` + "`" + `|([a-zA-Z_][\w]*))\s*\(\s*([^)]+)\)\s*(.*)$`) foreignKeyConstraintRe = regexp.MustCompile(`(?is)^FOREIGN\s+KEY\s*\(\s*([^)]+)\)\s*(REFERENCES\s+.+)$`) primaryKeyConstraintRe = regexp.MustCompile(`(?is)^PRIMARY\s+KEY\s*\(\s*([^)]+)\)\s*$`) - uniqueConstraintRe = regexp.MustCompile(`(?is)^UNIQUE\s*\(\s*([^)]+)\)\s*$`) - createIndexRe = regexp.MustCompile(`(?is)^CREATE\s+(UNIQUE\s+)?INDEX\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:"([^"]+)"|'([^']+)'|` + "`" + `([^` + "`" + `]+)` + "`" + `|([a-zA-Z_][\w]*))\s+ON\s+(?:"([^"]+)"|'([^']+)'|` + "`" + `([^` + "`" + `]+)` + "`" + `|([a-zA-Z_][\w]*))\s*\(\s*([^)]+)\)\s*;?\s*$`) + uniqueConstraintRe = regexp.MustCompile(`(?is)^UNIQUE\s*\(\s*([^)]+)\)\s*$`) + createIndexRe = regexp.MustCompile(`(?is)^CREATE\s+(UNIQUE\s+)?INDEX\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:"([^"]+)"|'([^']+)'|` + "`" + `([^` + "`" + `]+)` + "`" + `|([a-zA-Z_][\w]*))\s+ON\s+(?:"([^"]+)"|'([^']+)'|` + "`" + `([^` + "`" + `]+)` + "`" + `|([a-zA-Z_][\w]*))\s*\(\s*([^)]+)\)\s*;?\s*$`) ) // IndexSchema holds a parsed CREATE INDEX statement from a dump. diff --git a/internal/migrate/d1/lint.go b/internal/migrate/d1/lint.go index cfd8a1f5e..3c59e2795 100644 --- a/internal/migrate/d1/lint.go +++ b/internal/migrate/d1/lint.go @@ -13,7 +13,7 @@ func Lint(inputPath string) (*LintResult, error) { } result := &LintResult{ - InputPath: inputPath, + InputPath: inputPath, TableCount: len(tables), Issues: []Issue{}, Tables: make([]string, 0, len(tables)), diff --git a/internal/migrate/d1/orm_metadata.go b/internal/migrate/d1/orm_metadata.go index a30b0019b..0c45b3aa2 100644 --- a/internal/migrate/d1/orm_metadata.go +++ b/internal/migrate/d1/orm_metadata.go @@ -29,58 +29,58 @@ var ormMetadataRules = []ormMetadataRule{ match: matchTableName("_prisma_migrations"), }, { - code: "KNEX_MIGRATIONS", - orm: "Knex", + code: "KNEX_MIGRATIONS", + orm: "Knex", remediation: "After import, re-baseline Knex migration history on Postgres; knex_migrations from SQLite is not valid on Postgres", - match: matchAnyTableName("knex_migrations", "knex_migrations_lock"), + match: matchAnyTableName("knex_migrations", "knex_migrations_lock"), }, { - code: "SEQUELIZE_META", - orm: "Sequelize", + code: "SEQUELIZE_META", + orm: "Sequelize", remediation: "After import, re-baseline Sequelize migration history on Postgres; SequelizeMeta from SQLite is not valid on Postgres", - match: matchTableName("sequelizemeta"), + match: matchTableName("sequelizemeta"), }, { - code: "RAILS_MIGRATIONS", - orm: "Rails ActiveRecord", + code: "RAILS_MIGRATIONS", + orm: "Rails ActiveRecord", remediation: "After import, re-baseline Rails schema_migrations on Postgres; SQLite migration versions do not transfer cleanly", - match: matchAnyTableName("schema_migrations", "ar_internal_metadata"), + match: matchAnyTableName("schema_migrations", "ar_internal_metadata"), }, { - code: "FLYWAY_MIGRATIONS", - orm: "Flyway", + code: "FLYWAY_MIGRATIONS", + orm: "Flyway", remediation: "After import, baseline Flyway on Postgres; flyway_schema_history from SQLite must not be reused", - match: matchTableName("flyway_schema_history"), + match: matchTableName("flyway_schema_history"), }, { - code: "LIQUIBASE_MIGRATIONS", - orm: "Liquibase", + code: "LIQUIBASE_MIGRATIONS", + orm: "Liquibase", remediation: "After import, baseline Liquibase on Postgres; databasechangelog tables from SQLite must not be reused", - match: matchAnyTableName("databasechangelog", "databasechangeloglock"), + match: matchAnyTableName("databasechangelog", "databasechangeloglock"), }, { - code: "DJANGO_MIGRATIONS", - orm: "Django", + code: "DJANGO_MIGRATIONS", + orm: "Django", remediation: "After import, run django migrate --fake-initial or otherwise baseline django_migrations on Postgres", - match: matchTableName("django_migrations"), + match: matchTableName("django_migrations"), }, { - code: "ALEMBIC_VERSION", - orm: "Alembic", + code: "ALEMBIC_VERSION", + orm: "Alembic", remediation: "After import, stamp Alembic to the correct Postgres revision; alembic_version from SQLite is not portable", - match: matchTableName("alembic_version"), + match: matchTableName("alembic_version"), }, { - code: "TYPEORM_METADATA", - orm: "TypeORM", + code: "TYPEORM_METADATA", + orm: "TypeORM", remediation: "After import, baseline TypeORM migrations on Postgres; typeorm_metadata from SQLite is not valid on Postgres", - match: matchTableName("typeorm_metadata"), + match: matchTableName("typeorm_metadata"), }, { - code: "GOOSE_MIGRATIONS", - orm: "Goose", + code: "GOOSE_MIGRATIONS", + orm: "Goose", remediation: "After import, re-baseline Goose version table on Postgres; goose_db_version from SQLite is not portable", - match: matchTableName("goose_db_version"), + match: matchTableName("goose_db_version"), }, } diff --git a/internal/migrate/d1/output.go b/internal/migrate/d1/output.go index 54ec13927..cad9d760a 100644 --- a/internal/migrate/d1/output.go +++ b/internal/migrate/d1/output.go @@ -21,6 +21,13 @@ func PrintHumanResponse(p *printer.Printer, resp Response) { printHumanData(p, resp.Phase, resp.Data) + if resp.Error != nil { + p.Printf("\nError [%s]: %s\n", resp.Error.Code, resp.Error.Message) + if resp.Error.Remediation != "" { + p.Printf("%s\n", resp.Error.Remediation) + } + } + if len(resp.Issues) > 0 { p.Printf("\nIssues (%d):\n", len(resp.Issues)) for _, issue := range resp.Issues { @@ -44,6 +51,24 @@ func PrintHumanResponse(p *printer.Printer, resp Response) { } } +func printImportResultHuman(p *printer.Printer, r ImportResult) { + p.Printf("\nMethod: %s", r.Method) + if r.DryRun { + p.Print(" (dry run)") + } + p.Println() + if r.Plan != nil { + sizeMB := float64(r.Plan.EstimatedSizeBytes) / (1024 * 1024) + p.Printf("Plan: %d tables, %.1f MB estimated\n", len(r.Plan.Tables), sizeMB) + } + if r.TablesLoaded > 0 { + p.Printf("Tables loaded: %d\n", r.TablesLoaded) + } + if r.Timings != nil && r.Timings.TotalMs > 0 { + p.Printf("Total time: %.1fs\n", float64(r.Timings.TotalMs)/1000) + } +} + func printHumanData(p *printer.Printer, phase string, data any) { if data == nil { return @@ -67,25 +92,21 @@ func printHumanData(p *printer.Printer, phase string, data any) { p.Printf("\nExported to %s (%d bytes)\n", r.OutputPath, r.SizeBytes) } case "lint": - if r, ok := data.(LintResult); ok { + switch r := data.(type) { + case LintResult: p.Printf("\nTables: %d | Errors: %d | Warnings: %d\n", r.TableCount, r.ErrorCount, r.WarningCount) + case *LintResult: + if r != nil { + p.Printf("\nTables: %d | Errors: %d | Warnings: %d\n", r.TableCount, r.ErrorCount, r.WarningCount) + } } case "start": - if r, ok := data.(ImportResult); ok { - p.Printf("\nMethod: %s", r.Method) - if r.DryRun { - p.Print(" (dry run)") - } - p.Println() - if r.Plan != nil { - sizeMB := float64(r.Plan.EstimatedSizeBytes) / (1024 * 1024) - p.Printf("Plan: %d tables, %.1f MB estimated\n", len(r.Plan.Tables), sizeMB) - } - if r.TablesLoaded > 0 { - p.Printf("Tables loaded: %d\n", r.TablesLoaded) - } - if r.Timings != nil && r.Timings.TotalMs > 0 { - p.Printf("Total time: %.1fs\n", float64(r.Timings.TotalMs)/1000) + switch r := data.(type) { + case ImportResult: + printImportResultHuman(p, r) + case *ImportResult: + if r != nil { + printImportResultHuman(p, *r) } } case "verify": @@ -142,3 +163,20 @@ func ErrorResponse(phase string, err error) Response { } return resp } + +// LintResponse builds the lint command envelope with status derived from issue severity. +func LintResponse(result *LintResult) Response { + resp := OKResponse("lint", result, LintNextSteps(result)) + resp.Issues = result.Issues + if result.ErrorCount > 0 { + resp.Status = "error" + resp.Error = &ErrorInfo{ + Code: ErrCodeLintBlocked, + Message: lintBlockedReason(result.ErrorCount), + Remediation: lintBlockedRemediation, + } + } else if result.WarningCount > 0 { + resp.Status = "warning" + } + return resp +} diff --git a/internal/migrate/d1/output_test.go b/internal/migrate/d1/output_test.go new file mode 100644 index 000000000..0bbc5aa2b --- /dev/null +++ b/internal/migrate/d1/output_test.go @@ -0,0 +1,70 @@ +package d1 + +import ( + "bytes" + "strings" + "testing" + + "github.com/planetscale/cli/internal/printer" +) + +func TestLintResponseSetsErrorEnvelope(t *testing.T) { + result := &LintResult{ + InputPath: "/tmp/export.sql", + TableCount: 1, + ErrorCount: 1, + WarningCount: 2, + Issues: []Issue{{ + Code: "VIRTUAL_TABLE", + Severity: SeverityError, + Table: "fts", + Remediation: "Virtual tables are not supported", + }}, + } + + resp := LintResponse(result) + if resp.Status != "error" { + t.Fatalf("status = %q, want error", resp.Status) + } + if resp.Error == nil { + t.Fatal("expected structured error") + } + if resp.Error.Code != ErrCodeLintBlocked { + t.Fatalf("error code = %q, want %q", resp.Error.Code, ErrCodeLintBlocked) + } + if len(resp.Issues) != 1 { + t.Fatalf("issues = %d, want 1", len(resp.Issues)) + } +} + +func TestPrintHumanResponseIncludesLintIssuesOnError(t *testing.T) { + resp := LintResponse(&LintResult{ + TableCount: 1, + ErrorCount: 1, + Issues: []Issue{{ + Code: "VIRTUAL_TABLE", + Severity: SeverityError, + Table: "fts", + Remediation: "Virtual tables are not supported", + }}, + }) + + var buf bytes.Buffer + format := printer.Human + p := printer.NewPrinter(&format) + p.SetHumanOutput(&buf) + PrintHumanResponse(p, resp) + + out := buf.String() + for _, want := range []string{ + "Status: error", + "Errors: 1", + "[error] VIRTUAL_TABLE", + "Virtual tables are not supported", + ErrCodeLintBlocked, + } { + if !strings.Contains(out, want) { + t.Fatalf("output missing %q:\n%s", want, out) + } + } +} diff --git a/internal/migrate/d1/parse.go b/internal/migrate/d1/parse.go index 3dce7d6ac..3061ce83b 100644 --- a/internal/migrate/d1/parse.go +++ b/internal/migrate/d1/parse.go @@ -9,10 +9,10 @@ import ( ) var ( - createTableRe = regexp.MustCompile(`(?is)^CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:"([^"]+)"|'([^']+)'|` + "`" + `([^` + "`" + `]+)` + "`" + `|([a-zA-Z_][\w]*))\s*\(`) - virtualTableRe = regexp.MustCompile(`(?is)^CREATE\s+VIRTUAL\s+TABLE`) + createTableRe = regexp.MustCompile(`(?is)^CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:"([^"]+)"|'([^']+)'|` + "`" + `([^` + "`" + `]+)` + "`" + `|([a-zA-Z_][\w]*))\s*\(`) + virtualTableRe = regexp.MustCompile(`(?is)^CREATE\s+VIRTUAL\s+TABLE`) autoincrementRe = regexp.MustCompile(`(?i)AUTOINCREMENT`) - insertRe = regexp.MustCompile(`(?is)^INSERT\s+INTO\s+(?:` + "`" + `([^` + "`" + `]+)` + "`" + `|"([^"]+)"|'([^']+)'|([a-zA-Z_][\w]*))`) + insertRe = regexp.MustCompile(`(?is)^INSERT\s+INTO\s+(?:` + "`" + `([^` + "`" + `]+)` + "`" + `|"([^"]+)"|'([^']+)'|([a-zA-Z_][\w]*))`) valueTupleSepRe = regexp.MustCompile(`\)\s*,\s*\(`) ) diff --git a/internal/migrate/d1/pgloader.go b/internal/migrate/d1/pgloader.go index 84ff0cc62..e5be10999 100644 --- a/internal/migrate/d1/pgloader.go +++ b/internal/migrate/d1/pgloader.go @@ -31,8 +31,8 @@ const ( // Conservative profile: wide rows / large tables (e.g. attachments). pgloaderSlowPrefetchRows = 5000 pgloaderSlowBatchRows = 10000 - pgloaderSlowWorkers = 2 - pgloaderSlowConcurrency = 1 + pgloaderSlowWorkers = 2 + pgloaderSlowConcurrency = 1 pgloaderLoadWorkMem = "256MB" pgloaderLoadMaintenanceWorkMem = "512MB" diff --git a/internal/migrate/d1/plan.go b/internal/migrate/d1/plan.go index 3e1c606bd..05e278243 100644 --- a/internal/migrate/d1/plan.go +++ b/internal/migrate/d1/plan.go @@ -21,7 +21,7 @@ type PlanOptions struct { Database string Branch string Method string - MigrationID string // optional: reuse an existing migration ID from plan/start + MigrationID string // optional: reuse an existing migration ID from plan/start Lint *LintResult // optional: skip re-lint when already computed } diff --git a/internal/migrate/d1/types.go b/internal/migrate/d1/types.go index a2470cc4d..a79149299 100644 --- a/internal/migrate/d1/types.go +++ b/internal/migrate/d1/types.go @@ -21,9 +21,9 @@ type Issue struct { // NextStep guides agents to the next tool or command. type NextStep struct { - Tool string `json:"tool,omitempty"` + Tool string `json:"tool,omitempty"` Command string `json:"command,omitempty"` - Reason string `json:"reason"` + Reason string `json:"reason"` } // Response is the common JSON envelope for migrate d1 commands. @@ -61,34 +61,34 @@ type DoctorCheck struct { // LintResult summarizes lint output. type LintResult struct { - InputPath string `json:"input_path"` - TableCount int `json:"table_count"` - ErrorCount int `json:"error_count"` - WarningCount int `json:"warning_count"` - Issues []Issue `json:"issues"` - Tables []string `json:"tables"` + InputPath string `json:"input_path"` + TableCount int `json:"table_count"` + ErrorCount int `json:"error_count"` + WarningCount int `json:"warning_count"` + Issues []Issue `json:"issues"` + Tables []string `json:"tables"` } // PlanResult is the migration plan JSON. type PlanResult struct { - MigrationID string `json:"migration_id"` - InputPath string `json:"input_path"` - Org string `json:"org"` - Database string `json:"database"` - Branch string `json:"branch"` - RecommendedMethod string `json:"recommended_method"` - EstimatedSizeBytes int64 `json:"estimated_size_bytes,omitempty"` - Tables []TablePlan `json:"tables"` - CastRules []CastRule `json:"cast_rules"` - LoadOrder []string `json:"load_order"` - Issues []Issue `json:"issues"` + MigrationID string `json:"migration_id"` + InputPath string `json:"input_path"` + Org string `json:"org"` + Database string `json:"database"` + Branch string `json:"branch"` + RecommendedMethod string `json:"recommended_method"` + EstimatedSizeBytes int64 `json:"estimated_size_bytes,omitempty"` + Tables []TablePlan `json:"tables"` + CastRules []CastRule `json:"cast_rules"` + LoadOrder []string `json:"load_order"` + Issues []Issue `json:"issues"` } // TablePlan describes a table in the migration plan. type TablePlan struct { - Name string `json:"name"` - RowEstimate int `json:"row_estimate,omitempty"` - HasFK bool `json:"has_foreign_keys"` + Name string `json:"name"` + RowEstimate int `json:"row_estimate,omitempty"` + HasFK bool `json:"has_foreign_keys"` } // CastRule maps SQLite types to Postgres casts for pgloader. @@ -158,10 +158,10 @@ type VerifyResult struct { // TableVerifyResult is per-table verification. type TableVerifyResult struct { - Table string `json:"table"` - SourceRows int64 `json:"source_rows"` - DestRows int64 `json:"dest_rows"` - Match bool `json:"match"` + Table string `json:"table"` + SourceRows int64 `json:"source_rows"` + DestRows int64 `json:"dest_rows"` + Match bool `json:"match"` } // Migration phases persisted in local state. diff --git a/internal/postgres/postgres_test.go b/internal/postgres/postgres_test.go index 8f5803e92..303776345 100644 --- a/internal/postgres/postgres_test.go +++ b/internal/postgres/postgres_test.go @@ -140,24 +140,24 @@ func TestBuildConnectionString(t *testing.T) { func TestRedactPassword(t *testing.T) { tests := []struct { - name string + name string connStr string - want string + want string }{ { - name: "with password", + name: "with password", connStr: "host=localhost port=5432 user=user password=secret dbname=mydb", - want: "host=localhost port=5432 user=user password=**** dbname=mydb", + want: "host=localhost port=5432 user=user password=**** dbname=mydb", }, { - name: "without password", + name: "without password", connStr: "host=localhost port=5432 user=user dbname=mydb", - want: "host=localhost port=5432 user=user dbname=mydb", + want: "host=localhost port=5432 user=user dbname=mydb", }, { - name: "empty string", + name: "empty string", connStr: "", - want: "", + want: "", }, } From 0128af4b71841c06fa9e150233968710a737fc83 Mon Sep 17 00:00:00 2001 From: Elom Gomez Date: Fri, 26 Jun 2026 15:01:18 -0500 Subject: [PATCH 5/6] Flush sqlite dump chunks only at SQL statement boundaries. Defer .read batch splits until a line completes a statement so multi-line CREATE TABLE blocks are never cut mid-statement on large exports. Co-authored-by: Cursor --- internal/migrate/d1/sqlite_load.go | 31 ++++++++++++- internal/migrate/d1/sqlite_load_test.go | 58 +++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 1 deletion(-) diff --git a/internal/migrate/d1/sqlite_load.go b/internal/migrate/d1/sqlite_load.go index 5e762d841..51c90add8 100644 --- a/internal/migrate/d1/sqlite_load.go +++ b/internal/migrate/d1/sqlite_load.go @@ -154,7 +154,7 @@ func loadSQLiteDumpChunked(ctx context.Context, sqlite3, dumpPath, sqlitePath st return werr } chunkSize += int64(len(line)) - if chunkSize >= chunkBytes { + if chunkSize >= chunkBytes && lineEndsSQLStatement(line) { if err := flushChunk(); err != nil { return err } @@ -184,3 +184,32 @@ func truncateLoadError(msg string, max int) string { } return msg[:max] + "..." } + +// lineEndsSQLStatement reports whether line completes a standalone SQL statement. +// Chunk flushes use this so multi-line CREATE TABLE blocks are never split. +func lineEndsSQLStatement(line []byte) bool { + s := strings.TrimSpace(string(line)) + if s == "" || strings.HasPrefix(s, "--") { + return false + } + return sqlEndsWithSemicolon(s) +} + +func sqlEndsWithSemicolon(s string) bool { + inSingle := false + for i := 0; i < len(s); i++ { + c := s[i] + if c == '\'' { + if inSingle && i+1 < len(s) && s[i+1] == '\'' { + i++ + continue + } + inSingle = !inSingle + continue + } + if c == ';' && !inSingle && strings.TrimSpace(s[i+1:]) == "" { + return true + } + } + return false +} diff --git a/internal/migrate/d1/sqlite_load_test.go b/internal/migrate/d1/sqlite_load_test.go index ce0552be5..fc74f4ccb 100644 --- a/internal/migrate/d1/sqlite_load_test.go +++ b/internal/migrate/d1/sqlite_load_test.go @@ -101,6 +101,64 @@ func TestEnsureSQLiteFromDumpReusesExisting(t *testing.T) { } } +func TestLoadSQLiteDumpChunkedMultiLineCreate(t *testing.T) { + dir := t.TempDir() + dumpPath := filepath.Join(dir, "create.sql") + sqlitePath := filepath.Join(dir, "create.sqlite") + + var b strings.Builder + b.WriteString("PRAGMA defer_foreign_keys=TRUE;\n") + b.WriteString("CREATE TABLE multi (\n") + for i := 0; i < 40; i++ { + fmt.Fprintf(&b, " col_%d TEXT,\n", i) + } + b.WriteString(" id INTEGER PRIMARY KEY\n") + b.WriteString(");\n") + for i := 0; i < 10; i++ { + fmt.Fprintf(&b, "INSERT INTO multi (id) VALUES(%d);\n", i) + } + if err := os.WriteFile(dumpPath, []byte(b.String()), 0o600); err != nil { + t.Fatal(err) + } + + sqlite3, err := FindSQLite3() + if err != nil { + t.Fatal(err) + } + // Small chunks force splits that would bisect CREATE TABLE without boundary-aware flushing. + if err := loadSQLiteDumpChunked(context.Background(), sqlite3, dumpPath, sqlitePath, 200); err != nil { + t.Fatalf("loadSQLiteDumpChunked: %v", err) + } + + counts, err := CountSQLiteRows(context.Background(), sqlitePath, []string{"multi"}) + if err != nil { + t.Fatal(err) + } + if counts["multi"] != 10 { + t.Fatalf("expected 10 rows, got %d", counts["multi"]) + } +} + +func TestSQLStatementBoundary(t *testing.T) { + tests := []struct { + line string + want bool + }{ + {"CREATE TABLE t (id INTEGER);\n", true}, + {" );\n", true}, + {" payload BLOB\n", false}, + {"INSERT INTO t VALUES('a;b');\n", true}, + {"INSERT INTO t VALUES('a;b\n", false}, + {"-- comment only\n", false}, + {"PRAGMA defer_foreign_keys=TRUE;\n", true}, + } + for _, tc := range tests { + if got := lineEndsSQLStatement([]byte(tc.line)); got != tc.want { + t.Fatalf("lineEndsSQLStatement(%q) = %v, want %v", tc.line, got, tc.want) + } + } +} + func TestLoadSQLiteDumpChunked(t *testing.T) { dir := t.TempDir() dumpPath := filepath.Join(dir, "multi.sql") From e5ad39c9707f23a7b1f7b040eee6391d3377272d Mon Sep 17 00:00:00 2001 From: Elom Gomez Date: Fri, 26 Jun 2026 15:06:19 -0500 Subject: [PATCH 6/6] Install pgloader and sqlite3 in CI for D1 tests. Validate migration state before pgloader when --migration-id is set so state errors surface even when pgloader is missing locally. Co-authored-by: Cursor --- .github/workflows/ci.yml | 3 +++ internal/migrate/d1/prepare.go | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c8680d6e5..f624087e2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,6 +18,9 @@ jobs: with: go-version-file: go.mod + - name: Install D1 test dependencies + run: sudo apt-get update && sudo apt-get install -y pgloader sqlite3 + - run: make verify-licenses: diff --git a/internal/migrate/d1/prepare.go b/internal/migrate/d1/prepare.go index 0591f3fa4..df7ba9e85 100644 --- a/internal/migrate/d1/prepare.go +++ b/internal/migrate/d1/prepare.go @@ -21,6 +21,11 @@ func PrepareImport(opts ImportOptions) (*ImportPrepareResult, error) { if _, err := ValidateInputPath(opts.InputPath); err != nil { return nil, err } + if opts.MigrationID != "" { + if _, err := LoadState(opts.Org, opts.Database, opts.Branch, opts.MigrationID); err != nil { + return nil, err + } + } if _, err := FindPgloader(); err != nil { return nil, err }