Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow user to disable HTML escaping when marshalling to JSON. #45

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions json.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ func (om *OrderedMap[K, V]) MarshalJSON() ([]byte, error) { //nolint:funlen
return []byte("null"), nil
}

writer := jwriter.Writer{}
writer := jwriter.Writer{
NoEscapeHTML: om.disableHTMLEscape,
}
writer.RawByte('{')

for pair, firstIteration := om.Oldest(), true; pair != nil; pair = pair.Next() {
Expand Down Expand Up @@ -78,14 +80,26 @@ func (om *OrderedMap[K, V]) MarshalJSON() ([]byte, error) { //nolint:funlen

writer.RawByte(':')
// the error is checked at the end of the function
writer.Raw(json.Marshal(pair.Value))
writer.Raw(jsonMarshal(pair.Value, om.disableHTMLEscape))
}

writer.RawByte('}')

return dumpWriter(&writer)
}

func jsonMarshal(t interface{}, disableHTMLEscape bool) ([]byte, error) {
if disableHTMLEscape {
buffer := &bytes.Buffer{}
encoder := json.NewEncoder(buffer)
encoder.SetEscapeHTML(false)
err := encoder.Encode(t)
// Encode() adds an extra newline, strip it off to guarantee same behavior as json.Marshal
return bytes.TrimRight(buffer.Bytes(), "\n"), err
}
return json.Marshal(t)
}

func dumpWriter(writer *jwriter.Writer) ([]byte, error) {
if writer.Error != nil {
return nil, writer.Error
Expand All @@ -103,7 +117,7 @@ func dumpWriter(writer *jwriter.Writer) ([]byte, error) {
// UnmarshalJSON implements the json.Unmarshaler interface.
func (om *OrderedMap[K, V]) UnmarshalJSON(data []byte) error {
if om.list == nil {
om.initialize(0)
om.initialize(0, om.disableHTMLEscape)
}

return jsonparser.ObjectEach(
Expand Down
20 changes: 20 additions & 0 deletions json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,26 @@ func TestMarshalJSON(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, `{}`, string(b))
})

t.Run("HTML escaping enabled (default)", func(t *testing.T) {
om := New[marshallable, any]()
om.Set(marshallable(1), "hello <strong>this is bold</strong>")
om.Set(marshallable(28), "<?xml version=\"1.0\"?><catalog><book>some book</book></catalog>")

b, err := jsonMarshal(om, false)
assert.NoError(t, err)
assert.Equal(t, `{"#1#":"hello \u003cstrong\u003ethis is bold\u003c/strong\u003e","#28#":"\u003c?xml version=\"1.0\"?\u003e\u003ccatalog\u003e\u003cbook\u003esome book\u003c/book\u003e\u003c/catalog\u003e"}`, string(b))
})

t.Run("HTML escaping disabled", func(t *testing.T) {
om := New[marshallable, any](WithDisableHTMLEscape[marshallable, any]())
om.Set(marshallable(1), "hello <strong>this is bold</strong>")
om.Set(marshallable(28), "<?xml version=\"1.0\"?><catalog><book>some book</book></catalog>")

b, err := jsonMarshal(om, true /* we need to disable HTML escaping here also */)
assert.NoError(t, err)
assert.Equal(t, `{"#1#":"hello <strong>this is bold</strong>","#28#":"<?xml version=\"1.0\"?><catalog><book>some book</book></catalog>"}`, string(b))
})
}

func TestUnmarshallJSON(t *testing.T) {
Expand Down
27 changes: 21 additions & 6 deletions orderedmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ type Pair[K comparable, V any] struct {
}

type OrderedMap[K comparable, V any] struct {
pairs map[K]*Pair[K, V]
list *list.List[*Pair[K, V]]
pairs map[K]*Pair[K, V]
list *list.List[*Pair[K, V]]
disableHTMLEscape bool
}

type initConfig[K comparable, V any] struct {
capacity int
initialData []Pair[K, V]
capacity int
initialData []Pair[K, V]
disableHTMLEscape bool
}

type InitOption[K comparable, V any] func(config *initConfig[K, V])
Expand All @@ -49,6 +51,13 @@ func WithInitialData[K comparable, V any](initialData ...Pair[K, V]) InitOption[
}
}

// WithDisableHTMLEscape disables HTMl escaping when marshalling to JSON
func WithDisableHTMLEscape[K comparable, V any]() InitOption[K, V] {
return func(c *initConfig[K, V]) {
c.disableHTMLEscape = true
}
}

// New creates a new OrderedMap.
// options can either be one or several InitOption[K, V], or a single integer,
// which is then interpreted as a capacity hint, à la make(map[K]V, capacity).
Expand All @@ -63,6 +72,11 @@ func New[K comparable, V any](options ...any) *OrderedMap[K, V] {
invalidOption()
}
config.capacity = option
case bool:
if len(options) != 1 {
invalidOption()
}
config.disableHTMLEscape = option

case InitOption[K, V]:
option(&config)
Expand All @@ -72,7 +86,7 @@ func New[K comparable, V any](options ...any) *OrderedMap[K, V] {
}
}

orderedMap.initialize(config.capacity)
orderedMap.initialize(config.capacity, config.disableHTMLEscape)
orderedMap.AddPairs(config.initialData...)

return orderedMap
Expand All @@ -82,9 +96,10 @@ const invalidOptionMessage = `when using orderedmap.New[K,V]() with options, eit

func invalidOption() { panic(invalidOptionMessage) }

func (om *OrderedMap[K, V]) initialize(capacity int) {
func (om *OrderedMap[K, V]) initialize(capacity int, disableHTMLEscape bool) {
om.pairs = make(map[K]*Pair[K, V], capacity)
om.list = list.New[*Pair[K, V]]()
om.disableHTMLEscape = disableHTMLEscape
}

// Get looks for the given key, and returns the value associated with it,
Expand Down
2 changes: 1 addition & 1 deletion yaml.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func (om *OrderedMap[K, V]) UnmarshalYAML(value *yaml.Node) error {
}

if om.list == nil {
om.initialize(0)
om.initialize(0, om.disableHTMLEscape)
}

for index := 0; index < len(value.Content); index += 2 {
Expand Down