aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/onsi/gomega/gstruct/fields.go
blob: f3c1575511d0aa2652571d4647f78ab467547d02 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
package gstruct

import (
	"errors"
	"fmt"
	"reflect"
	"runtime/debug"
	"sort"
	"strings"

	"github.com/onsi/gomega/format"
	errorsutil "github.com/onsi/gomega/gstruct/errors"
	"github.com/onsi/gomega/types"
)

// MatchAllFields succeeds if every field of a struct matches the field matcher
// associated with it. It is strict in both directions: a struct field with no
// matcher is reported as unexpected, and a matcher naming a field the struct
// does not have is reported as missing.
//
//	actual := struct {
//		A int
//		B bool
//	}{
//		A: 5,
//		B: true,
//	}
//	Expect(actual).To(MatchAllFields(Fields{
//		"A": Equal(5),
//		"B": Equal(true),
//	}))
func MatchAllFields(fields Fields) types.GomegaMatcher {
	return &FieldsMatcher{
		Fields: fields,
	}
}

// MatchFields succeeds if each field of a struct matches the field matcher
// associated with it. The options bitmask relaxes the strictness of
// MatchAllFields:
//   - IgnoreExtras: struct fields that have no matcher are not an error.
//   - IgnoreMissing: matchers naming a field the struct does not have are not
//     an error.
//
//	actual := struct {
//		A int
//		B bool
//		C string
//	}{
//		A: 5,
//		B: true,
//		C: "foo",
//	}
//	Expect(actual).To(MatchFields(IgnoreExtras, Fields{
//		"A": Equal(5),
//		"B": Equal(true),
//	}))
func MatchFields(options Options, fields Fields) types.GomegaMatcher {
	return &FieldsMatcher{
		Fields:        fields,
		IgnoreExtras:  options&IgnoreExtras != 0,
		IgnoreMissing: options&IgnoreMissing != 0,
	}
}

// FieldsMatcher matches a struct value field-by-field against a set of
// per-field matchers. It is normally constructed with MatchAllFields or
// MatchFields.
type FieldsMatcher struct {
	// Matchers for each field, keyed by struct field name.
	Fields Fields

	// Whether to ignore struct fields that have no matcher in Fields, rather
	// than reporting them as unexpected.
	IgnoreExtras bool
	// Whether to ignore entries in Fields that name a field the struct does
	// not have, rather than reporting them as missing.
	IgnoreMissing bool

	// State: per-field failures collected by Match and read back by
	// FailureMessage and Failures.
	failures []error
}

// Fields maps a struct field name to the matcher that the field's value must
// satisfy.
type Fields map[string]types.GomegaMatcher

// Match reports whether actual, which must be a struct value, satisfies the
// configured field matchers.
//
// It returns an error (and success=false) when actual is not a struct,
// including when actual is nil. Individual field mismatches are not errors:
// they are recorded and later surfaced through FailureMessage and Failures.
func (m *FieldsMatcher) Match(actual interface{}) (success bool, err error) {
	// Drop state from any previous call so stale failures are never reported.
	m.failures = nil

	// reflect.ValueOf(nil).Kind() is reflect.Invalid, so a nil actual is
	// rejected here; reflect.TypeOf(nil) would be a nil Type and calling
	// Kind() on it would panic.
	if reflect.ValueOf(actual).Kind() != reflect.Struct {
		return false, fmt.Errorf("%v is type %T, expected struct", actual, actual)
	}

	m.failures = m.matchFields(actual)
	return len(m.failures) == 0, nil
}

// matchFields checks each field of actual (a struct value; Match verifies the
// kind before calling this) against its matcher in m.Fields and returns one
// error per problem found:
//   - a struct field with no matcher is an error unless IgnoreExtras is set;
//   - a failing matcher contributes its FailureMessage, or the aggregated
//     Failures of a nesting matcher;
//   - a matcher whose field does not exist on the struct is an error unless
//     IgnoreMissing is set.
//
// Errors about fields present on the struct are prefixed with ".<FieldName>"
// via errorsutil.Nest. Missing-field errors are emitted in sorted field-name
// order so the failure output is deterministic.
func (m *FieldsMatcher) matchFields(actual interface{}) (errs []error) {
	val := reflect.ValueOf(actual)
	typ := val.Type()
	seen := make(map[string]bool, val.NumField())
	for i := 0; i < val.NumField(); i++ {
		fieldName := typ.Field(i).Name
		seen[fieldName] = true

		err := func() (err error) {
			// This test relies heavily on reflect, which tends to panic.
			// Recover here to provide more useful error messages in that case.
			defer func() {
				if r := recover(); r != nil {
					err = fmt.Errorf("panic checking %+v: %v\n%s", actual, r, debug.Stack())
				}
			}()

			matcher, expected := m.Fields[fieldName]
			if !expected {
				if !m.IgnoreExtras {
					return fmt.Errorf("unexpected field %s: %+v", fieldName, actual)
				}
				return nil
			}

			// Field(i) of a struct Value is always valid for 0 <= i < NumField,
			// so no zero-value fallback is needed. Unexported fields cannot be
			// turned back into an interface{} (Interface would panic), so
			// report them explicitly rather than via the recovered panic.
			fieldVal := val.Field(i)
			if !fieldVal.CanInterface() {
				return fmt.Errorf("cannot match unexported field %s", fieldName)
			}
			field := fieldVal.Interface()

			match, err := matcher.Match(field)
			if err != nil {
				return err
			}
			if !match {
				if nesting, ok := matcher.(errorsutil.NestingMatcher); ok {
					return errorsutil.AggregateError(nesting.Failures())
				}
				return errors.New(matcher.FailureMessage(field))
			}
			return nil
		}()
		if err != nil {
			errs = append(errs, errorsutil.Nest("."+fieldName, err))
		}
	}

	if !m.IgnoreMissing {
		var missing []string
		for name := range m.Fields {
			if !seen[name] {
				missing = append(missing, name)
			}
		}
		// Map iteration order is randomized; sort for stable failure output.
		sort.Strings(missing)
		for _, name := range missing {
			errs = append(errs, fmt.Errorf("missing expected field %s", name))
		}
	}

	return errs
}

// FailureMessage renders every failure recorded by Match, one per line inside
// a "to match fields: { ... }" block. The name of actual's type (not its full
// value) is what gets passed to format.Message as the actual.
func (m *FieldsMatcher) FailureMessage(actual interface{}) (message string) {
	lines := make([]string, 0, len(m.failures))
	for _, failure := range m.failures {
		lines = append(lines, failure.Error())
	}

	typeName := reflect.TypeOf(actual).Name()
	expectation := fmt.Sprintf("to match fields: {\n%v\n}\n", strings.Join(lines, "\n"))
	return format.Message(typeName, expectation)
}

// NegatedFailureMessage is reported when a negated assertion fails, i.e. when
// actual unexpectedly matched all configured fields.
func (m *FieldsMatcher) NegatedFailureMessage(actual interface{}) (message string) {
	const expectation = "not to match fields"
	message = format.Message(actual, expectation)
	return message
}

// Failures returns the per-field errors collected during matching. The
// returned slice is the matcher's own state, not a copy.
func (m *FieldsMatcher) Failures() []error {
	collected := m.failures
	return collected
}