diff --git a/runtime/drivers/duckdb/duckdb.go b/runtime/drivers/duckdb/duckdb.go index 78f57909285..046b61519d9 100644 --- a/runtime/drivers/duckdb/duckdb.go +++ b/runtime/drivers/duckdb/duckdb.go @@ -49,6 +49,14 @@ var spec = drivers.Spec{ Description: "DuckDB SQL query.", Placeholder: "select * from read_csv('data/file.csv', header=true);", }, + { + Key: "db", + Type: drivers.StringPropertyType, + Required: true, + DisplayName: "DB", + Description: "Path to external DuckDB database. Use md: for MotherDuck.", + Placeholder: "/path/to/main.db or md:main.db (for MotherDuck)", + }, }, ConfigProperties: []drivers.PropertySchema{ { @@ -58,18 +66,6 @@ var spec = drivers.Spec{ } var motherduckSpec = drivers.Spec{ - DisplayName: "MotherDuck", - Description: "Import data from MotherDuck.", - SourceProperties: []drivers.PropertySchema{ - { - Key: "sql", - Type: drivers.StringPropertyType, - Required: true, - DisplayName: "SQL", - Description: "Query to extract data from MotherDuck.", - Placeholder: "select * from my_db.my_table;", - }, - }, ConfigProperties: []drivers.PropertySchema{ { Key: "token", diff --git a/runtime/drivers/duckdb/transporter_duckDB_to_duckDB.go b/runtime/drivers/duckdb/transporter_duckDB_to_duckDB.go index 50d5b2db51c..1e8e1331cd0 100644 --- a/runtime/drivers/duckdb/transporter_duckDB_to_duckDB.go +++ b/runtime/drivers/duckdb/transporter_duckDB_to_duckDB.go @@ -2,9 +2,12 @@ package duckdb import ( "context" + "database/sql" "errors" "fmt" "net/url" + "path/filepath" + "strings" "github.com/rilldata/rill/runtime/drivers" "github.com/rilldata/rill/runtime/pkg/duckdbsql" @@ -37,6 +40,10 @@ func (t *duckDBToDuckDB) Transfer(ctx context.Context, srcProps, sinkProps map[s return err } + if srcCfg.Database != "" { // query to be run against an external DB + return t.transferFromExternalDB(ctx, srcCfg, sinkCfg, opts) + } + // We can't just pass the SQL statement to DuckDB outright. 
// We need to do some rewriting for certain table references (currently object stores and local files). @@ -90,16 +97,66 @@ func (t *duckDBToDuckDB) Transfer(ctx context.Context, srcProps, sinkProps map[s // If the path is a local file reference, rewrite to a safe and repo-relative path. if uri.Scheme == "" && uri.Host == "" { - sql, err := rewriteLocalPaths(ast, opts.RepoRoot, opts.AllowHostAccess) + rewrittenSQL, err := rewriteLocalPaths(ast, opts.RepoRoot, opts.AllowHostAccess) if err != nil { return fmt.Errorf("invalid local path: %w", err) } - srcCfg.SQL = sql + srcCfg.SQL = rewrittenSQL } return t.to.CreateTableAsSelect(ctx, sinkCfg.Table, false, srcCfg.SQL) } +func (t *duckDBToDuckDB) transferFromExternalDB(ctx context.Context, srcProps *dbSourceProperties, sinkProps *sinkProperties, opts *drivers.TransferOptions) error { + return t.to.WithConnection(ctx, 1, true, false, func(ctx, ensuredCtx context.Context, _ *sql.Conn) error { + res, err := t.to.Execute(ctx, &drivers.Statement{Query: "SELECT current_database(),current_schema();"}) + if err != nil { + return err + } + + var localDB, localSchema string + for res.Next() { + if err := res.Scan(&localDB, &localSchema); err != nil { + _ = res.Close() + return err + } + } + _ = res.Close() + + // duckdb considers everything before first . 
as db name + // alternative solution can be to query `SHOW DATABASES` before and after to identify db name + dbName, _, _ := strings.Cut(filepath.Base(srcProps.Database), ".") + if dbName == "main" { + return fmt.Errorf("`main` is a reserved db name") + } + + if err = t.to.Exec(ctx, &drivers.Statement{Query: fmt.Sprintf("ATTACH %s AS %s", safeSQLString(srcProps.Database), safeSQLName(dbName))}); err != nil { + return fmt.Errorf("failed to attach db %q: %w", srcProps.Database, err) + } + + defer func() { + if err = t.to.Exec(ensuredCtx, &drivers.Statement{Query: fmt.Sprintf("DETACH %s;", safeSQLName(dbName))}); err != nil { + t.logger.Error("failed to detach db", zap.Error(err)) + } + }() + + if err := t.to.Exec(ctx, &drivers.Statement{Query: fmt.Sprintf("USE %s;", safeName(dbName))}); err != nil { + return err + } + + defer func() { // switch back to the local db + if err = t.to.Exec(ensuredCtx, &drivers.Statement{Query: fmt.Sprintf("USE %s.%s;", safeName(localDB), safeName(localSchema))}); err != nil { + t.logger.Error("failed to switch to local database", zap.Error(err)) + } + }() + + userQuery := strings.TrimSpace(srcProps.SQL) + userQuery, _ = strings.CutSuffix(userQuery, ";") // trim trailing semicolon + query := fmt.Sprintf("CREATE OR REPLACE TABLE %s.%s.%s AS (%s\n);", safeName(localDB), safeName(localSchema), safeName(sinkProps.Table), userQuery) + return t.to.Exec(ctx, &drivers.Statement{Query: query}) + }) +} + // rewriteLocalPaths rewrites a DuckDB SQL statement such that relative paths become absolute paths relative to the basePath, // and if allowHostAccess is false, returns an error if any of the paths resolve to a path outside of the basePath. 
func rewriteLocalPaths(ast *duckdbsql.AST, basePath string, allowHostAccess bool) (string, error) { diff --git a/runtime/drivers/duckdb/transporter_duckDB_to_duckDB_test.go b/runtime/drivers/duckdb/transporter_duckDB_to_duckDB_test.go new file mode 100644 index 00000000000..a750ec637c9 --- /dev/null +++ b/runtime/drivers/duckdb/transporter_duckDB_to_duckDB_test.go @@ -0,0 +1,65 @@ +package duckdb + +import ( + "context" + "fmt" + "path/filepath" + "testing" + + "github.com/rilldata/rill/runtime/drivers" + activity "github.com/rilldata/rill/runtime/pkg/activity" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestDuckDBToDuckDBTransfer(t *testing.T) { + tempDir := t.TempDir() + conn, err := Driver{}.Open(map[string]any{"dsn": fmt.Sprintf("%s.db?access_mode=read_write", filepath.Join(tempDir, "tranfser"))}, false, activity.NewNoopClient(), zap.NewNop()) + require.NoError(t, err) + + olap, ok := conn.AsOLAP("") + require.True(t, ok) + + err = olap.Exec(context.Background(), &drivers.Statement{ + Query: "CREATE TABLE foo(bar VARCHAR, baz INTEGER)", + }) + require.NoError(t, err) + + err = olap.Exec(context.Background(), &drivers.Statement{ + Query: "INSERT INTO foo VALUES ('a', 1), ('a', 2), ('b', 3), ('c', 4)", + }) + require.NoError(t, err) + require.NoError(t, conn.Close()) + + to, err := Driver{}.Open(map[string]any{"dsn": ""}, false, activity.NewNoopClient(), zap.NewNop()) + require.NoError(t, err) + + olap, _ = to.AsOLAP("") + + tr := NewDuckDBToDuckDB(olap, zap.NewNop()) + + // transfer once + err = tr.Transfer(context.Background(), map[string]any{"sql": "SELECT * FROM foo", "db": filepath.Join(tempDir, "tranfser.db")}, map[string]any{"table": "test"}, &drivers.TransferOptions{Progress: drivers.NoOpProgress{}}) + require.NoError(t, err) + + rows, err := olap.Execute(context.Background(), &drivers.Statement{Query: "SELECT COUNT(*) FROM test"}) + require.NoError(t, err) + + var count int + rows.Next() + require.NoError(t, rows.Scan(&count)) 
+ require.Equal(t, 4, count) + require.NoError(t, rows.Close()) + + // transfer again + err = tr.Transfer(context.Background(), map[string]any{"sql": "SELECT * FROM foo", "db": filepath.Join(tempDir, "tranfser.db")}, map[string]any{"table": "test"}, &drivers.TransferOptions{Progress: drivers.NoOpProgress{}}) + require.NoError(t, err) + + rows, err = olap.Execute(context.Background(), &drivers.Statement{Query: "SELECT COUNT(*) FROM test"}) + require.NoError(t, err) + + rows.Next() + require.NoError(t, rows.Scan(&count)) + require.Equal(t, 4, count) + require.NoError(t, rows.Close()) +} diff --git a/runtime/drivers/duckdb/transporter_motherduck_to_duckDB.go b/runtime/drivers/duckdb/transporter_motherduck_to_duckDB.go index b97198fe127..20d32d4cbc7 100644 --- a/runtime/drivers/duckdb/transporter_motherduck_to_duckDB.go +++ b/runtime/drivers/duckdb/transporter_motherduck_to_duckDB.go @@ -61,16 +61,18 @@ func (t *motherduckToDuckDB) Transfer(ctx context.Context, srcProps, sinkProps m if err != nil { return err } - defer res.Close() - res.Next() var localDB, localSchema string - if err := res.Scan(&localDB, &localSchema); err != nil { - return err + for res.Next() { + if err := res.Scan(&localDB, &localSchema); err != nil { + _ = res.Close() + return err + } } + _ = res.Close() // get token - token := config["token"] + token, _ := config["token"].(string) if token == "" && config["allow_host_access"].(bool) { token = os.Getenv("motherduck_token") } diff --git a/web-common/src/features/sources/modal/AddSourceModal.svelte b/web-common/src/features/sources/modal/AddSourceModal.svelte index 5c599d6d805..bc6f96c6d9f 100644 --- a/web-common/src/features/sources/modal/AddSourceModal.svelte +++ b/web-common/src/features/sources/modal/AddSourceModal.svelte @@ -12,7 +12,7 @@ import Https from "../../../components/icons/connectors/HTTPS.svelte"; import LocalFile from "../../../components/icons/connectors/LocalFile.svelte"; import MicrosoftAzureBlobStorage from 
"../../../components/icons/connectors/MicrosoftAzureBlobStorage.svelte"; - import MotherDuck from "../../../components/icons/connectors/MotherDuck.svelte"; + import DuckDB from "../../../components/icons/connectors/DuckDB.svelte"; import Postgres from "../../../components/icons/connectors/Postgres.svelte"; import Snowflake from "../../../components/icons/connectors/Snowflake.svelte"; import SQLite from "../../../components/icons/connectors/SQLite.svelte"; @@ -40,7 +40,7 @@ // duckdb "bigquery", "athena", - "motherduck", + "duckdb", "postgres", "sqlite", "snowflake", @@ -55,7 +55,7 @@ // duckdb: DuckDB, bigquery: GoogleBigQuery, athena: AmazonAthena, - motherduck: MotherDuck, + duckdb: DuckDB, postgres: Postgres, sqlite: SQLite, snowflake: Snowflake, diff --git a/web-common/src/features/sources/modal/yupSchemas.ts b/web-common/src/features/sources/modal/yupSchemas.ts index 108f23d7b97..6412c5c3fb5 100644 --- a/web-common/src/features/sources/modal/yupSchemas.ts +++ b/web-common/src/features/sources/modal/yupSchemas.ts @@ -46,9 +46,10 @@ export function getYupSchema(connector: V1ConnectorSpec) { ) .required("Source name is required"), }); - case "motherduck": + case "duckdb": return yup.object().shape({ sql: yup.string().required("sql is required"), + db: yup.string().required("db is required"), sourceName: yup .string() .matches( diff --git a/web-common/src/features/sources/sourceUtils.ts b/web-common/src/features/sources/sourceUtils.ts index 7fec9dd0350..d96091fe266 100644 --- a/web-common/src/features/sources/sourceUtils.ts +++ b/web-common/src/features/sources/sourceUtils.ts @@ -39,6 +39,14 @@ export function compileCreateSourceYAML( delete values.db; delete values.table; break; + case "duckdb": { + const db = values.db as string; + if (db.startsWith("md:")) { + connectorName = "motherduck"; + values.db = db.replace("md:", ""); + } + break; + } } const compiledKeyValues = Object.entries(values)