Skip to content

Commit

Permalink
Add FromCtyValueTagged
Browse files Browse the repository at this point in the history
Like FromCtyValue, but you can pass in a different tag to decode with.
Update code using literal "cty" to use a passed in tag string. Put the
insides of FromCtyValue into FromCtyValueTagged, now FromCtyValue
calls FromCtyValueTagged with literal "cty".
  • Loading branch information
gastrodon committed Apr 8, 2024
1 parent 15a9d85 commit 32b3c5d
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 34 deletions.
4 changes: 2 additions & 2 deletions cty/gocty/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ var stringType = reflect.TypeOf("")
//
// This function will panic if two fields within the struct are tagged with
// the same cty attribute name.
func structTagIndices(st reflect.Type) map[string]int {
func structTagIndices(st reflect.Type, tag string) map[string]int {
ct := st.NumField()
ret := make(map[string]int, ct)

for i := 0; i < ct; i++ {
field := st.Field(i)
attrName := field.Tag.Get("cty")
attrName := field.Tag.Get(tag)
if attrName != "" {
ret[attrName] = i
}
Expand Down
2 changes: 1 addition & 1 deletion cty/gocty/in.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ func toCtyObject(val reflect.Value, attrTypes map[string]cty.Type, path cty.Path
// path to give us a place to put our GetAttr step.
path = append(path, cty.PathStep(nil))

attrFields := structTagIndices(val.Type())
attrFields := structTagIndices(val.Type(), "cty")

vals := make(map[string]cty.Value, len(attrTypes))
for k, at := range attrTypes {
Expand Down
48 changes: 26 additions & 22 deletions cty/gocty/out.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ import (
// The function will panic if given a non-pointer as the Go value target,
// since that is considered to be a bug in the calling program.
func FromCtyValue(val cty.Value, target interface{}) error {
return FromCtyValueTagged(val, target, "cty")
}

func FromCtyValueTagged(val cty.Value, target interface{}, tag string) error {
tVal := reflect.ValueOf(target)
if tVal.Kind() != reflect.Ptr {
panic("target value is not a pointer")
Expand All @@ -40,10 +44,10 @@ func FromCtyValue(val cty.Value, target interface{}) error {
// unused capacity on the end of it, depending on how deeply-recursive
// the given cty.Value is.
path := make(cty.Path, 0)
return fromCtyValue(val, tVal, path)
return fromCtyValue(val, tVal, path, tag)
}

func fromCtyValue(val cty.Value, target reflect.Value, path cty.Path) error {
func fromCtyValue(val cty.Value, target reflect.Value, path cty.Path, tag string) error {
ty := val.Type()

deepTarget := fromCtyPopulatePtr(target, false)
Expand Down Expand Up @@ -89,17 +93,17 @@ func fromCtyValue(val cty.Value, target reflect.Value, path cty.Path) error {

switch {
case ty.IsListType():
return fromCtyList(val, target, path)
return fromCtyList(val, target, path, tag)
case ty.IsMapType():
return fromCtyMap(val, target, path)
return fromCtyMap(val, target, path, tag)
case ty.IsSetType():
return fromCtySet(val, target, path)
return fromCtySet(val, target, path, tag)
case ty.IsObjectType():
return fromCtyObject(val, target, path)
return fromCtyObject(val, target, path, tag)
case ty.IsTupleType():
return fromCtyTuple(val, target, path)
return fromCtyTuple(val, target, path, tag)
case ty.IsCapsuleType():
return fromCtyCapsule(val, target, path)
return fromCtyCapsule(val, target, path, tag)
}

// We should never fall out here; reaching here indicates a bug in this
Expand Down Expand Up @@ -251,7 +255,7 @@ func fromCtyString(val cty.Value, target reflect.Value, path cty.Path) error {
}
}

func fromCtyList(val cty.Value, target reflect.Value, path cty.Path) error {
func fromCtyList(val cty.Value, target reflect.Value, path cty.Path, tag string) error {
switch target.Kind() {

case reflect.Slice:
Expand All @@ -273,7 +277,7 @@ func fromCtyList(val cty.Value, target reflect.Value, path cty.Path) error {
}

targetElem := tv.Index(i)
err = fromCtyValue(val, targetElem, path)
err = fromCtyValue(val, targetElem, path, tag)
if err != nil {
return true
}
Expand Down Expand Up @@ -310,7 +314,7 @@ func fromCtyList(val cty.Value, target reflect.Value, path cty.Path) error {
}

targetElem := target.Index(i)
err = fromCtyValue(val, targetElem, path)
err = fromCtyValue(val, targetElem, path, tag)
if err != nil {
return true
}
Expand All @@ -332,7 +336,7 @@ func fromCtyList(val cty.Value, target reflect.Value, path cty.Path) error {
}
}

func fromCtyMap(val cty.Value, target reflect.Value, path cty.Path) error {
func fromCtyMap(val cty.Value, target reflect.Value, path cty.Path, tag string) error {

switch target.Kind() {

Expand All @@ -356,7 +360,7 @@ func fromCtyMap(val cty.Value, target reflect.Value, path cty.Path) error {
ks := key.AsString()

targetElem := reflect.New(et)
err = fromCtyValue(val, targetElem, path)
err = fromCtyValue(val, targetElem, path, tag)

tv.SetMapIndex(reflect.ValueOf(ks), targetElem.Elem())

Expand All @@ -377,7 +381,7 @@ func fromCtyMap(val cty.Value, target reflect.Value, path cty.Path) error {
}
}

func fromCtySet(val cty.Value, target reflect.Value, path cty.Path) error {
func fromCtySet(val cty.Value, target reflect.Value, path cty.Path, tag string) error {
switch target.Kind() {

case reflect.Slice:
Expand All @@ -393,7 +397,7 @@ func fromCtySet(val cty.Value, target reflect.Value, path cty.Path) error {
var err error
val.ForEachElement(func(key cty.Value, val cty.Value) bool {
targetElem := tv.Index(i)
err = fromCtyValue(val, targetElem, path)
err = fromCtyValue(val, targetElem, path, tag)
if err != nil {
return true
}
Expand Down Expand Up @@ -422,7 +426,7 @@ func fromCtySet(val cty.Value, target reflect.Value, path cty.Path) error {
var err error
val.ForEachElement(func(key cty.Value, val cty.Value) bool {
targetElem := target.Index(i)
err = fromCtyValue(val, targetElem, path)
err = fromCtyValue(val, targetElem, path, tag)
if err != nil {
return true
}
Expand All @@ -444,14 +448,14 @@ func fromCtySet(val cty.Value, target reflect.Value, path cty.Path) error {
}
}

func fromCtyObject(val cty.Value, target reflect.Value, path cty.Path) error {
func fromCtyObject(val cty.Value, target reflect.Value, path cty.Path, tag string) error {

switch target.Kind() {

case reflect.Struct:

attrTypes := val.Type().AttributeTypes()
targetFields := structTagIndices(target.Type())
targetFields := structTagIndices(target.Type(), tag)

path = append(path, nil)

Expand Down Expand Up @@ -482,7 +486,7 @@ func fromCtyObject(val cty.Value, target reflect.Value, path cty.Path) error {
ev := val.GetAttr(k)

targetField := target.Field(fieldIdx)
err := fromCtyValue(ev, targetField, path)
err := fromCtyValue(ev, targetField, path, tag)
if err != nil {
return err
}
Expand All @@ -498,7 +502,7 @@ func fromCtyObject(val cty.Value, target reflect.Value, path cty.Path) error {
}
}

func fromCtyTuple(val cty.Value, target reflect.Value, path cty.Path) error {
func fromCtyTuple(val cty.Value, target reflect.Value, path cty.Path, tag string) error {

switch target.Kind() {

Expand All @@ -521,7 +525,7 @@ func fromCtyTuple(val cty.Value, target reflect.Value, path cty.Path) error {
ev := val.Index(cty.NumberIntVal(int64(i)))

targetField := target.Field(i)
err := fromCtyValue(ev, targetField, path)
err := fromCtyValue(ev, targetField, path, tag)
if err != nil {
return err
}
Expand All @@ -537,7 +541,7 @@ func fromCtyTuple(val cty.Value, target reflect.Value, path cty.Path) error {
}
}

func fromCtyCapsule(val cty.Value, target reflect.Value, path cty.Path) error {
func fromCtyCapsule(val cty.Value, target reflect.Value, path cty.Path, tag string) error {

if target.Kind() == reflect.Ptr {
// Walk through indirection until we get to the last pointer,
Expand Down
29 changes: 29 additions & 0 deletions cty/gocty/out_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,31 @@ func TestOut(t *testing.T) {
}
}

func TestOutOtherTags(t *testing.T) {
taggedCty, taggedOther := new(testStructManyTag), new(testStructManyTag)
err := FromCtyValueTagged(cty.ObjectVal(map[string]cty.Value{
"name": cty.StringVal("Eva"),
}), taggedCty, "cty")
if err != nil {
t.Fatalf("FromCtyValueTagged returned error: %s", err)
}

err = FromCtyValueTagged(cty.ObjectVal(map[string]cty.Value{
"another_name": cty.StringVal("Alice"),
}), taggedOther, "other")
if err != nil {
t.Fatalf("FromCtyValueTagged returned error: %s", err)
}

if taggedCty.Name != "Eva" {
t.Fatalf("taggedCty name mismatch: %s != Eva!", taggedCty.Name)
}

if taggedOther.Name != "Alice" {
t.Fatalf("taggedCty name mismatch: %s != Alice!", taggedOther.Name)
}
}

type testOutAssertFunc func(cty.Value, reflect.Type, interface{}, *testing.T)

func testOutAssertPtrVal(want interface{}) testOutAssertFunc {
Expand Down Expand Up @@ -409,6 +434,10 @@ type testStruct struct {
Number *int `cty:"number"`
}

type testStructManyTag struct {
Name string `cty:"name" other:"another_name"`
}

type testTupleStruct struct {
Name string
Number int
Expand Down
22 changes: 13 additions & 9 deletions cty/gocty/type_implied.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,20 @@ import (
// type, because it cannot know the capsule types supported by the calling
// program.
func ImpliedType(gv interface{}) (cty.Type, error) {
return ImpliedTypeTagged(gv, "cty")
}

func ImpliedTypeTagged(gv interface{}, tag string) (cty.Type, error) {
rt := reflect.TypeOf(gv)
var path cty.Path
return impliedType(rt, path)
return impliedType(rt, path, tag)
}

func impliedType(rt reflect.Type, path cty.Path) (cty.Type, error) {
func impliedType(rt reflect.Type, path cty.Path, tag string) (cty.Type, error) {
switch rt.Kind() {

case reflect.Ptr:
return impliedType(rt.Elem(), path)
return impliedType(rt.Elem(), path, tag)

// Primitive types
case reflect.Bool:
Expand All @@ -48,7 +52,7 @@ func impliedType(rt reflect.Type, path cty.Path) (cty.Type, error) {
// Collection types
case reflect.Slice:
path := append(path, cty.IndexStep{Key: cty.UnknownVal(cty.Number)})
ety, err := impliedType(rt.Elem(), path)
ety, err := impliedType(rt.Elem(), path, tag)
if err != nil {
return cty.NilType, err
}
Expand All @@ -58,29 +62,29 @@ func impliedType(rt reflect.Type, path cty.Path) (cty.Type, error) {
return cty.NilType, path.NewErrorf("no cty.Type for %s (must have string keys)", rt)
}
path := append(path, cty.IndexStep{Key: cty.UnknownVal(cty.String)})
ety, err := impliedType(rt.Elem(), path)
ety, err := impliedType(rt.Elem(), path, tag)
if err != nil {
return cty.NilType, err
}
return cty.Map(ety), nil

// Structural types
case reflect.Struct:
return impliedStructType(rt, path)
return impliedStructType(rt, path, tag)

default:
return cty.NilType, path.NewErrorf("no cty.Type for %s", rt)
}
}

func impliedStructType(rt reflect.Type, path cty.Path) (cty.Type, error) {
func impliedStructType(rt reflect.Type, path cty.Path, tag string) (cty.Type, error) {
if valueType.AssignableTo(rt) {
// Special case: cty.Value represents cty.DynamicPseudoType, for
// type conformance checking.
return cty.DynamicPseudoType, nil
}

fieldIdxs := structTagIndices(rt)
fieldIdxs := structTagIndices(rt, tag)
if len(fieldIdxs) == 0 {
return cty.NilType, path.NewErrorf("no cty.Type for %s (no cty field tags)", rt)
}
Expand All @@ -95,7 +99,7 @@ func impliedStructType(rt reflect.Type, path cty.Path) (cty.Type, error) {
path[len(path)-1] = cty.GetAttrStep{Name: k}

ft := rt.Field(fi).Type
aty, err := impliedType(ft, path)
aty, err := impliedType(ft, path, tag)
if err != nil {
return cty.NilType, err
}
Expand Down

0 comments on commit 32b3c5d

Please sign in to comment.