diff --git a/cmd/templ/lspcmd/lsp_test.go b/cmd/templ/lspcmd/lsp_test.go index 041605a55..a1f22f144 100644 --- a/cmd/templ/lspcmd/lsp_test.go +++ b/cmd/templ/lspcmd/lsp_test.go @@ -319,6 +319,126 @@ func TestHover(t *testing.T) { } } +func TestDefinitions(t *testing.T) { + if testing.Short() { + return + } + + ctx, cancel := context.WithCancel(context.Background()) + log, _ := zap.NewProduction() + + ctx, appDir, _, server, teardown, err := Setup(ctx, log) + if err != nil { + t.Fatalf("failed to setup test: %v", err) + return + } + defer teardown(t) + defer cancel() + + log.Info("Calling Definitions") + + tests := []struct { + line int + character int + filename string + assert func(t *testing.T, l []protocol.Location) (msg string, ok bool) + }{ + { + line: 4, + character: 7, + filename: "/remoteParent.templ", + assert: func(t *testing.T, actual []protocol.Location) (msg string, ok bool) { + expectedReference := []protocol.Location{ + { + // This is the useage of the templ function in the main.go file. + URI: uri.URI("file://" + appDir + "/remoteChild.templ"), + Range: protocol.Range{ + Start: protocol.Position{ + Line: uint32(2), + Character: uint32(6), + }, + End: protocol.Position{ + Line: uint32(2), + Character: uint32(12), + }, + }, + }, + } + if diff := lspdiff.Definitions(expectedReference, actual); diff != "" { + return fmt.Sprintf("Expected: %+v\nActual: %+v", expectedReference, actual), false + } + return "", true + }, + }, + { + // This is the usage of the templ function in the main.go file. + line: 25, + character: 10, + filename: "/main.go", + assert: func(t *testing.T, actual []protocol.Location) (msg string, ok bool) { + expectedReference := []protocol.Location{ + { + // Creation of the templ component + URI: uri.URI("file://" + appDir + "/templates.templ"), + Range: protocol.Range{ + Start: protocol.Position{ + Line: uint32(4), + Character: uint32(6), + }, + End: protocol.Position{ + Line: uint32(4), + Character: uint32(10), + }, + }, + }, + } + if diff := lspdiff.Definitions(expectedReference, actual); diff != "" { + return fmt.Sprintf("Expected: %+v\nActual: %+v", expectedReference, actual), false + } + return "", true + }, + }, + } + + for i, test := range tests { + t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) { + // Give CI/CD pipeline executors some time because they're often quite slow. + var ok bool + var msg string + for i := 0; i < 3; i++ { + if err != nil { + t.Error(err) + return + } + actual, err := server.Definition(ctx, &protocol.DefinitionParams{ + TextDocumentPositionParams: protocol.TextDocumentPositionParams{ + TextDocument: protocol.TextDocumentIdentifier{ + URI: uri.URI("file://" + appDir + test.filename), + }, + // Positions are zero indexed. + Position: protocol.Position{ + Line: uint32(test.line - 1), + Character: uint32(test.character - 1), + }, + }, + }) + if err != nil { + t.Errorf("failed to get references: %v", err) + return + } + msg, ok = test.assert(t, actual) + if !ok { + break + } + time.Sleep(time.Millisecond * 500) + } + if !ok { + t.Error(msg) + } + }) + } +} + func TestReferences(t *testing.T) { if testing.Short() { return diff --git a/cmd/templ/lspcmd/lspdiff/lspdiff.go b/cmd/templ/lspcmd/lspdiff/lspdiff.go index 653ee9671..536acb2ce 100644 --- a/cmd/templ/lspcmd/lspdiff/lspdiff.go +++ b/cmd/templ/lspcmd/lspdiff/lspdiff.go @@ -29,6 +29,10 @@ func References(expected, actual []protocol.Location) string { return cmp.Diff(expected, actual) } +func Definitions(expected, actual []protocol.Location) string { + return cmp.Diff(expected, actual) +} + func CompletionListContainsText(cl *protocol.CompletionList, text string) bool { if cl == nil { return false diff --git a/cmd/templ/lspcmd/proxy/server.go b/cmd/templ/lspcmd/proxy/server.go index 02fc83571..3302178d7 100644 --- a/cmd/templ/lspcmd/proxy/server.go +++ b/cmd/templ/lspcmd/proxy/server.go @@ -613,11 +613,14 @@ func (p *Server) Definition(ctx context.Context, params *lsp.DefinitionParams) ( p.Log.Info("client -> server: Definition") defer p.Log.Info("client -> server: Definition end") // Rewrite the request. + originalRequestFromTempl, _ := convertTemplToGoURI(params.TextDocument.URI) templURI := params.TextDocument.URI var ok bool - ok, params.TextDocument.URI, params.Position = p.updatePosition(templURI, params.Position) - if !ok { - return result, nil + if originalRequestFromTempl { + ok, params.TextDocument.URI, params.Position = p.updatePosition(templURI, params.Position) + if !ok { + return result, nil + } } // Call gopls and get the result. result, err = p.Target.Definition(ctx, params) @@ -633,6 +636,17 @@ func (p *Server) Definition(ctx context.Context, params *lsp.DefinitionParams) ( result[i].Range = p.convertGoRangeToTemplRange(templURI, result[i].Range) } } + if !originalRequestFromTempl { + // If the requst came from outside of a templ file, we only care about the templ file references + // The attached gopls will return the others to the IDE + filteredResult := []lsp.Location{} + for i := 0; i < len(result); i++ { + if isTemplFile, _ := convertTemplToGoURI(result[i].URI); isTemplFile { + filteredResult = append(filteredResult, result[i]) + } + } + result = filteredResult + } return }