aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/onsi/gomega/gstruct/fields.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/onsi/gomega/gstruct/fields.go')
-rw-r--r--vendor/github.com/onsi/gomega/gstruct/fields.go141
1 files changed, 141 insertions, 0 deletions
diff --git a/vendor/github.com/onsi/gomega/gstruct/fields.go b/vendor/github.com/onsi/gomega/gstruct/fields.go
new file mode 100644
index 0000000..f3c1575
--- /dev/null
+++ b/vendor/github.com/onsi/gomega/gstruct/fields.go
@@ -0,0 +1,141 @@
+package gstruct
+
+import (
+ "errors"
+ "fmt"
+ "reflect"
+ "runtime/debug"
+ "strings"
+
+ "github.com/onsi/gomega/format"
+ errorsutil "github.com/onsi/gomega/gstruct/errors"
+ "github.com/onsi/gomega/types"
+)
+
+//MatchAllFields succeeds if every field of a struct matches the field matcher associated with
+//it, and every element matcher is matched.
+// Expect([]string{"a", "b"}).To(MatchAllFields(idFn, gstruct.Fields{
+// "a": BeEqual("a"),
+// "b": BeEqual("b"),
+// })
+func MatchAllFields(fields Fields) types.GomegaMatcher {
+ return &FieldsMatcher{
+ Fields: fields,
+ }
+}
+
+//MatchFields succeeds if each element of a struct matches the field matcher associated with
+//it. It can ignore extra fields and/or missing fields.
+// Expect([]string{"a", "c"}).To(MatchFields(idFn, IgnoreMissing|IgnoreExtra, gstruct.Fields{
+// "a": BeEqual("a")
+// "b": BeEqual("b"),
+// })
+func MatchFields(options Options, fields Fields) types.GomegaMatcher {
+ return &FieldsMatcher{
+ Fields: fields,
+ IgnoreExtras: options&IgnoreExtras != 0,
+ IgnoreMissing: options&IgnoreMissing != 0,
+ }
+}
+
+type FieldsMatcher struct {
+ // Matchers for each field.
+ Fields Fields
+
+ // Whether to ignore extra elements or consider it an error.
+ IgnoreExtras bool
+ // Whether to ignore missing elements or consider it an error.
+ IgnoreMissing bool
+
+ // State.
+ failures []error
+}
+
+// Field name to matcher.
+type Fields map[string]types.GomegaMatcher
+
+func (m *FieldsMatcher) Match(actual interface{}) (success bool, err error) {
+ if reflect.TypeOf(actual).Kind() != reflect.Struct {
+ return false, fmt.Errorf("%v is type %T, expected struct", actual, actual)
+ }
+
+ m.failures = m.matchFields(actual)
+ if len(m.failures) > 0 {
+ return false, nil
+ }
+ return true, nil
+}
+
+func (m *FieldsMatcher) matchFields(actual interface{}) (errs []error) {
+ val := reflect.ValueOf(actual)
+ typ := val.Type()
+ fields := map[string]bool{}
+ for i := 0; i < val.NumField(); i++ {
+ fieldName := typ.Field(i).Name
+ fields[fieldName] = true
+
+ err := func() (err error) {
+ // This test relies heavily on reflect, which tends to panic.
+ // Recover here to provide more useful error messages in that case.
+ defer func() {
+ if r := recover(); r != nil {
+ err = fmt.Errorf("panic checking %+v: %v\n%s", actual, r, debug.Stack())
+ }
+ }()
+
+ matcher, expected := m.Fields[fieldName]
+ if !expected {
+ if !m.IgnoreExtras {
+ return fmt.Errorf("unexpected field %s: %+v", fieldName, actual)
+ }
+ return nil
+ }
+
+ var field interface{}
+ if val.Field(i).IsValid() {
+ field = val.Field(i).Interface()
+ } else {
+ field = reflect.Zero(typ.Field(i).Type)
+ }
+
+ match, err := matcher.Match(field)
+ if err != nil {
+ return err
+ } else if !match {
+ if nesting, ok := matcher.(errorsutil.NestingMatcher); ok {
+ return errorsutil.AggregateError(nesting.Failures())
+ }
+ return errors.New(matcher.FailureMessage(field))
+ }
+ return nil
+ }()
+ if err != nil {
+ errs = append(errs, errorsutil.Nest("."+fieldName, err))
+ }
+ }
+
+ for field := range m.Fields {
+ if !fields[field] && !m.IgnoreMissing {
+ errs = append(errs, fmt.Errorf("missing expected field %s", field))
+ }
+ }
+
+ return errs
+}
+
+func (m *FieldsMatcher) FailureMessage(actual interface{}) (message string) {
+ failures := make([]string, len(m.failures))
+ for i := range m.failures {
+ failures[i] = m.failures[i].Error()
+ }
+ return format.Message(reflect.TypeOf(actual).Name(),
+ fmt.Sprintf("to match fields: {\n%v\n}\n", strings.Join(failures, "\n")))
+}
+
+func (m *FieldsMatcher) NegatedFailureMessage(actual interface{}) (message string) {
+ return format.Message(actual, "not to match fields")
+}
+
+func (m *FieldsMatcher) Failures() []error {
+ return m.failures
+}