diff --git a/httputilx/header/example_test.go b/httputilx/header/example_test.go index 2a0605f..e38e751 100644 --- a/httputilx/header/example_test.go +++ b/httputilx/header/example_test.go @@ -10,11 +10,11 @@ func ExampleSetCSP() { static := "static.example.com" headers := make(http.Header) header.SetCSP(headers, header.CSPArgs{ // nolint - header.CSPDefaultSrc: {header.CSPSourceNone}, - header.CSPScriptSrc: {static}, - header.CSPStyleSrc: {static, header.CSPSourceUnsafeInline}, - header.CSPFormAction: {header.CSPSourceSelf}, - header.CSPReportURI: {"/csp"}, + {header.CSPDefaultSrc, header.CSPSourceNone}, + {header.CSPScriptSrc, static}, + {header.CSPStyleSrc, static, header.CSPSourceUnsafeInline}, + {header.CSPFormAction, header.CSPSourceSelf}, + {header.CSPReportURI, "/csp"}, }) // Output: diff --git a/httputilx/header/set.go b/httputilx/header/set.go index 874b301..59652ad 100644 --- a/httputilx/header/set.go +++ b/httputilx/header/set.go @@ -139,7 +139,7 @@ const ( ) // CSPArgs are arguments for SetCSP(). -type CSPArgs map[string][]string +type CSPArgs [][]string // SetCSP sets a Content-Security-Policy header. // @@ -166,13 +166,17 @@ func SetCSP(header http.Header, args CSPArgs) error { var b strings.Builder i := 1 - for k, v := range args { - b.WriteString(k) + for _, v := range args { + if len(v) < 2 { + return errors.New("expected pair of values") + } + + b.WriteString(v[0]) b.WriteString(" ") - for j := range v { - b.WriteString(v[j]) - if j != len(v)-1 { + for j := range v[1:] { + b.WriteString(v[j+1]) + if j != len(v)-2 { b.WriteString(" ") } } diff --git a/httputilx/header/set_test.go b/httputilx/header/set_test.go index f88e543..07414a3 100644 --- a/httputilx/header/set_test.go +++ b/httputilx/header/set_test.go @@ -57,17 +57,17 @@ func TestCSP(t *testing.T) { }{ {CSPArgs{}, ""}, { - CSPArgs{CSPDefaultSrc: {CSPSourceSelf}}, + CSPArgs{{CSPDefaultSrc, CSPSourceSelf}}, "default-src 'self'", }, { - CSPArgs{CSPDefaultSrc: {CSPSourceSelf, "https://example.com"}}, + CSPArgs{{CSPDefaultSrc, CSPSourceSelf, "https://example.com"}}, "default-src 'self' https://example.com", }, { CSPArgs{ - CSPDefaultSrc: {CSPSourceSelf, "https://example.com"}, - CSPConnectSrc: {"https://a.com", "https://b.com"}, + {CSPDefaultSrc, CSPSourceSelf, "https://example.com"}, + {CSPConnectSrc, "https://a.com", "https://b.com"}, }, "default-src 'self' https://example.com; connect-src https://a.com https://b.com", }, diff --git a/stringutil/stringutil.go b/stringutil/stringutil.go index 37c109b..d7b3fad 100644 --- a/stringutil/stringutil.go +++ b/stringutil/stringutil.go @@ -5,8 +5,22 @@ import ( "regexp" "strings" "unicode" + "unicode/utf8" ) +// Truncate returns the "n" left characters of the string. +func Truncate(s string, n int) string { + if n <= 0 { + return "" + } + + if utf8.RuneCountInString(s) <= n { + return s + } + + return string([]rune(s)[:n]) +} + // Left returns the "n" left characters of the string. // // If the string is shorter than "n" it will return the first "n" characters of diff --git a/stringutil/stringutil_test.go b/stringutil/stringutil_test.go index f2948d3..d297193 100644 --- a/stringutil/stringutil_test.go +++ b/stringutil/stringutil_test.go @@ -6,6 +6,33 @@ import ( "testing" ) +func TestTruncate(t *testing.T) { + cases := []struct { + in string + n int + want string + }{ + {"Hello", 100, "Hello"}, + {"Hello", 1, "H"}, + {"Hello", 5, "Hello"}, + {"Hello", 4, "Hell"}, + {"Hello", 0, ""}, + {"Hello", -2, ""}, + {"汉语漢語", 1, "汉"}, + {"汉语漢語", 3, "汉语漢"}, + {"汉语漢語", 4, "汉语漢語"}, + } + + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + out := Truncate(tc.in, tc.n) + if out != tc.want { + t.Errorf("\nout: %#v\nwant: %#v\n", out, tc.want) + } + }) + } +} + func TestLeft(t *testing.T) { cases := []struct { in string