diff --git a/lib/antlr/create_table.go b/lib/antlr/create_table.go index 9628fe18..3de1f93b 100644 --- a/lib/antlr/create_table.go +++ b/lib/antlr/create_table.go @@ -109,12 +109,12 @@ func processCopyTable(ctx *generated.CopyCreateTableContext) (Event, error) { return nil, fmt.Errorf("expected exactly 2 table names, got %d", len(tableNames)) } - tableName, err := getTextFromSingleNodeBranch(tableNames[0]) + tableName, err := getTableNameFromNode(tableNames[0]) if err != nil { return nil, err } - copiedFromTableName, err := getTextFromSingleNodeBranch(tableNames[1]) + copiedFromTableName, err := getTableNameFromNode(tableNames[1]) if err != nil { return nil, err } diff --git a/lib/antlr/create_table_test.go b/lib/antlr/create_table_test.go index 32e4d6e4..e2aa770d 100644 --- a/lib/antlr/create_table_test.go +++ b/lib/antlr/create_table_test.go @@ -9,16 +9,10 @@ import ( func TestCreateTable(t *testing.T) { { - // Create table LIKE - sameQueries := []string{ - "CREATE TABLE table_name LIKE other_table;", - "create table table_name (like other_table);", - } - - for _, query := range sameQueries { - events, err := Parse(query) + { + // Create table LIKE by specifying schema + events, err := Parse("CREATE TABLE db_name.table_name LIKE db_name.other_table;") assert.NoError(t, err) - assert.Len(t, events, 1) createTableEvent, isOk := events[0].(CopyTableEvent) assert.True(t, isOk) @@ -27,7 +21,26 @@ func TestCreateTable(t *testing.T) { assert.Len(t, createTableEvent.GetColumns(), 0) assert.Equal(t, "other_table", createTableEvent.GetCopyFromTableName()) } - + { + // Create table LIKE + sameQueries := []string{ + "CREATE TABLE table_name LIKE other_table;", + "create table table_name (like other_table);", + } + + for _, query := range sameQueries { + events, err := Parse(query) + assert.NoError(t, err) + assert.Len(t, events, 1) + + createTableEvent, isOk := events[0].(CopyTableEvent) + assert.True(t, isOk) + + assert.Equal(t, "table_name", createTableEvent.GetTable()) + assert.Len(t, createTableEvent.GetColumns(), 0) + assert.Equal(t, "other_table", createTableEvent.GetCopyFromTableName()) + } + } } { // Create table with column as CHARACTER SET and collation specified at the column level diff --git a/lib/antlr/rename_table.go b/lib/antlr/rename_table.go index b6f3450f..d82d3005 100644 --- a/lib/antlr/rename_table.go +++ b/lib/antlr/rename_table.go @@ -13,7 +13,12 @@ func processRenameTable(ctx *generated.RenameTableContext) ([]Event, error) { case *generated.RenameTableClauseContext: var allTableNames []string for _, tableName := range castedChild.AllTableName() { - allTableNames = append(allTableNames, tableName.GetText()) + parsedTableName, err := getTableNameFromNode(tableName) + if err != nil { + return nil, fmt.Errorf("failed to get table name: %w", err) + } + + allTableNames = append(allTableNames, parsedTableName) } // Must be at least two table names diff --git a/lib/antlr/rename_table_test.go b/lib/antlr/rename_table_test.go index afae455c..fc36e845 100644 --- a/lib/antlr/rename_table_test.go +++ b/lib/antlr/rename_table_test.go @@ -20,15 +20,15 @@ func TestRenameTable(t *testing.T) { } { // Another one table variant - events, err := Parse(`RENAME TABLE current_db.tbl_name TO other_db.tbl_name;`) + events, err := Parse(`RENAME TABLE current_db.tbl_name TO current_db.tbl_name;`) assert.NoError(t, err) assert.Len(t, events, 1) renameTableEvent, isOk := events[0].(RenameTableEvent) assert.True(t, isOk) - assert.Equal(t, "current_db.tbl_name", renameTableEvent.GetTable()) - assert.Equal(t, "other_db.tbl_name", renameTableEvent.GetNewTableName()) + assert.Equal(t, "tbl_name", renameTableEvent.GetTable()) + assert.Equal(t, "tbl_name", renameTableEvent.GetNewTableName()) } { // Multiple tables diff --git a/lib/antlr/util.go b/lib/antlr/util.go index 8b427f54..e4aec220 100644 --- a/lib/antlr/util.go +++ b/lib/antlr/util.go @@ -41,6 +41,7 @@ func getTextFromSingleNodeBranch(tree antlr.Tree) (string, error) { return getTextFromSingleNodeBranch(tree.GetChild(0)) } +// TODO: Extend this function to return the schema (if present) func getTableNameFromNode(ctx generated.ITableNameContext) (string, error) { children := ctx.GetChildren() if len(children) != 1 {