From d14490d6c90f8ca208266e1c3cc59962ca7d99ff Mon Sep 17 00:00:00 2001 From: Stefan Negele Date: Mon, 2 Oct 2023 13:12:14 +0200 Subject: [PATCH] Resolve references --- dataContract.go | 61 +++++++++++++++++++++++++++++++++++++-- dataContract_test.go | 24 +++++++++++++++ main_test.go | 8 ++++- test_resources/model.yaml | 10 +++++++ 4 files changed, 99 insertions(+), 4 deletions(-) create mode 100644 test_resources/model.yaml diff --git a/dataContract.go b/dataContract.go index 888488b9..c727a0c1 100644 --- a/dataContract.go +++ b/dataContract.go @@ -5,9 +5,13 @@ import ( "gopkg.in/yaml.v3" "io" "net/http" + "net/url" "os" + "strings" ) +const referencePrefix = "$ref:" + type DataContract = map[string]interface{} func ReadLocalDataContract(dataContractFileName string) (dataContractFile []byte, err error) { @@ -34,7 +38,7 @@ func GetValue(contract DataContract, path []string) (value interface{}, err erro } if len(path) == 1 { - return contract[fieldName], nil + return resolveValue(contract, fieldName) } next, ok := contract[fieldName].(map[string]interface{}) @@ -45,6 +49,57 @@ func GetValue(contract DataContract, path []string) (value interface{}, err erro return GetValue(next, path[1:]) } +func resolveValue(object map[string]interface{}, fieldName string) (value interface{}, err error) { + value = object[fieldName] + + if stringValue, isString := value.(string); isString && strings.HasPrefix(stringValue, referencePrefix) { + reference := strings.Trim(strings.TrimPrefix(stringValue, referencePrefix), " ") + + value, err = resolveReference(reference) + if err != nil { + return nil, err + } + } + + return value, nil +} + +func resolveReference(reference string) (_ string, err error) { + var bytes []byte + + if isURI(reference) { + bytes, err = resolveReferenceFromRemote(reference) + } else { + bytes, err = resolveReferenceLocally(reference) + } + + if err != nil { + return "", fmt.Errorf("can't resolve reference '%v': %w", reference, err) + } + + return string(bytes), nil +} + +func resolveReferenceLocally(reference string) ([]byte, error) { + return os.ReadFile(reference) +} + +func resolveReferenceFromRemote(reference string) ([]byte, error) { + response, err := http.Get(reference) + defer response.Body.Close() + + if err != nil { + return nil, err + } + + return io.ReadAll(response.Body) +} + +func isURI(reference string) bool { + _, err := url.ParseRequestURI(reference) + return err == nil +} + func FetchDataContract(url string) (result []byte, err error) { response, err := http.Get(url) if err != nil { @@ -53,9 +108,9 @@ func FetchDataContract(url string) (result []byte, err error) { defer response.Body.Close() - if otherContractData, err := io.ReadAll(response.Body); err != nil { + if contractData, err := io.ReadAll(response.Body); err != nil { return nil, fmt.Errorf("failed to read data contract to compare with: %w", err) } else { - return otherContractData, nil + return contractData, nil } } diff --git a/dataContract_test.go b/dataContract_test.go index bf93646c..d4be273f 100644 --- a/dataContract_test.go +++ b/dataContract_test.go @@ -1,11 +1,15 @@ package main import ( + "fmt" + "os" "reflect" "testing" ) func TestGetValue(t *testing.T) { + model, _ := os.ReadFile("./test_resources/model.yaml") + type args struct { contract DataContract path []string @@ -45,6 +49,26 @@ func TestGetValue(t *testing.T) { path: []string{"schema", "type"}}, wantErr: true, }, + { + name: "local reference", + args: args{ + contract: DataContract{"schema": map[string]interface{}{ + "specification": "$ref: test_resources/model.yaml", + }}, + path: []string{"schema", "specification"}}, + wantValue: string(model), + wantErr: false, + }, + { + name: "remote reference", + args: args{ + contract: DataContract{"schema": map[string]interface{}{ + "specification": fmt.Sprintf("$ref: %v/model.yaml", TestResourcesServer.URL), + }}, + path: []string{"schema", "specification"}}, + wantValue: string(model), + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/main_test.go b/main_test.go index b154cfe4..d0bb0a48 100644 --- a/main_test.go +++ b/main_test.go @@ -1,9 +1,15 @@ package main import ( + "net/http" + "net/http/httptest" "testing" ) -func TestMainMethod(t *testing.T) { +var TestResourcesServer = httptest.NewServer(http.FileServer(http.Dir("./test_resources"))) +func TestMain(m *testing.M) { + defer TestResourcesServer.Close() + + m.Run() } diff --git a/test_resources/model.yaml b/test_resources/model.yaml new file mode 100644 index 00000000..6b81e456 --- /dev/null +++ b/test_resources/model.yaml @@ -0,0 +1,10 @@ +version: 2 +models: + - name: my_table + description: "contains data" + config: + materialized: table + columns: + - name: my_column + data_type: text + description: "contains values"