123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562 |
- package inject
- import (
- "bytes"
- "fmt"
- "math/rand"
- "reflect"
- )
// Logger receives debug diagnostics from a Graph. A Graph only calls it when
// its Logger field is non-nil.
type Logger interface {
	// Debugf receives a printf-style format string followed by its arguments.
	Debugf(format string, args ...interface{})
}
- func Populate(values ...interface{}) error {
- var g Graph
- for _, v := range values {
- if err := g.Provide(&Object{Value: v}); err != nil {
- return err
- }
- }
- return g.Populate()
- }
// Object is one value registered with a Graph, plus the bookkeeping the graph
// keeps about it.
type Object struct {
	// Value is the object itself. Unnamed objects must be pointers to structs;
	// named objects may hold any value.
	Value interface{}

	// Name is optional. A named object is registered under this name and is
	// injected only into fields tagged `inject:"<Name>"`.
	Name string

	// Complete, when true, makes Populate leave this object's own fields
	// alone. The object can still be injected into other objects.
	Complete bool

	// Fields is filled in by the graph and maps each injected field name to
	// the Object assigned there. It must be nil when the object is provided.
	Fields map[string]*Object

	// Reflection of Value, cached by Graph.Provide.
	reflectType  reflect.Type
	reflectValue reflect.Value

	// private: never reused for other objects' fields and exempt from the
	// one-unnamed-object-per-type rule. created: the graph allocated this
	// value itself. embedded: provided for an anonymous inline struct field;
	// such objects are omitted from Graph.Objects.
	private, created, embedded bool
}
- func (o *Object) String() string {
- var buf bytes.Buffer
- fmt.Fprint(&buf, o.reflectType)
- if o.Name != "" {
- fmt.Fprintf(&buf, " named %s", o.Name)
- }
- return buf.String()
- }
- func (o *Object) addDep(field string, dep *Object) {
- if o.Fields == nil {
- o.Fields = make(map[string]*Object)
- }
- o.Fields[field] = dep
- }
- type Graph struct {
- Logger Logger
- unnamed []*Object
- unnamedType map[reflect.Type]bool
- named map[string]*Object
- }
- func (g *Graph) GetObjectByName(name string) (interface{}, error) {
- existing := g.named[name]
- if existing == nil {
- return nil, fmt.Errorf("did not find object named %s", name)
- }
- return existing.Value, nil
- }
- func (g *Graph) Provide(objects ...*Object) error {
- for _, o := range objects {
- o.reflectType = reflect.TypeOf(o.Value)
- o.reflectValue = reflect.ValueOf(o.Value)
- if o.Fields != nil {
- return fmt.Errorf(
- "fields were specified on object %s when it was provided",
- o,
- )
- }
- if o.Name == "" {
- if !isStructPtr(o.reflectType) {
- return fmt.Errorf(
- "expected unnamed object value to be a pointer to a struct but got type %s "+
- "with value %v",
- o.reflectType,
- o.Value,
- )
- }
- if !o.private {
- if g.unnamedType == nil {
- g.unnamedType = make(map[reflect.Type]bool)
- }
- if g.unnamedType[o.reflectType] {
- return fmt.Errorf(
- "provided two unnamed instances of type *%s.%s",
- o.reflectType.Elem().PkgPath(), o.reflectType.Elem().Name(),
- )
- }
- g.unnamedType[o.reflectType] = true
- }
- g.unnamed = append(g.unnamed, o)
- } else {
- if g.named == nil {
- g.named = make(map[string]*Object)
- }
- if g.named[o.Name] != nil {
- return fmt.Errorf("provided two instances named %s", o.Name)
- }
- g.named[o.Name] = o
- }
- if g.Logger != nil {
- if o.created {
- g.Logger.Debugf("created %s", o)
- } else if o.embedded {
- g.Logger.Debugf("provided embedded %s", o)
- } else {
- g.Logger.Debugf("provided %s", o)
- }
- }
- }
- return nil
- }
- func (g *Graph) Populate() error {
- for _, o := range g.named {
- if o.Complete {
- continue
- }
- if err := g.populateExplicit(o); err != nil {
- return err
- }
- }
-
-
- i := 0
- for {
- if i == len(g.unnamed) {
- break
- }
- o := g.unnamed[i]
- i++
- if o.Complete {
- continue
- }
- if err := g.populateExplicit(o); err != nil {
- return err
- }
- }
-
-
- for _, o := range g.unnamed {
- if o.Complete {
- continue
- }
- if err := g.populateUnnamedInterface(o); err != nil {
- return err
- }
- }
- for _, o := range g.named {
- if o.Complete {
- continue
- }
- if err := g.populateUnnamedInterface(o); err != nil {
- return err
- }
- }
- return nil
- }
- func (g *Graph) populateExplicit(o *Object) error {
-
- if o.Name != "" && !isStructPtr(o.reflectType) {
- return nil
- }
- StructLoop:
- for i := 0; i < o.reflectValue.Elem().NumField(); i++ {
- field := o.reflectValue.Elem().Field(i)
- fieldType := field.Type()
- fieldTag := o.reflectType.Elem().Field(i).Tag
- fieldName := o.reflectType.Elem().Field(i).Name
- tag, err := parseTag(string(fieldTag))
- if err != nil {
- return fmt.Errorf(
- "unexpected tag format `%s` for field %s in type %s",
- string(fieldTag),
- o.reflectType.Elem().Field(i).Name,
- o.reflectType,
- )
- }
-
- if tag == nil {
- continue
- }
-
- if !field.CanSet() {
- return fmt.Errorf(
- "inject requested on unexported field %s in type %s",
- o.reflectType.Elem().Field(i).Name,
- o.reflectType,
- )
- }
-
- if tag.Inline && fieldType.Kind() != reflect.Struct {
- return fmt.Errorf(
- "inline requested on non inlined field %s in type %s",
- o.reflectType.Elem().Field(i).Name,
- o.reflectType,
- )
- }
-
- if !isNilOrZero(field, fieldType) {
- continue
- }
-
- if tag.Name != "" {
- existing := g.named[tag.Name]
- if existing == nil {
- return fmt.Errorf(
- "did not find object named %s required by field %s in type %s",
- tag.Name,
- o.reflectType.Elem().Field(i).Name,
- o.reflectType,
- )
- }
- if !existing.reflectType.AssignableTo(fieldType) {
- return fmt.Errorf(
- "object named %s of type %s is not assignable to field %s (%s) in type %s",
- tag.Name,
- fieldType,
- o.reflectType.Elem().Field(i).Name,
- existing.reflectType,
- o.reflectType,
- )
- }
- field.Set(reflect.ValueOf(existing.Value))
- if g.Logger != nil {
- g.Logger.Debugf(
- "assigned %s to field %s in %s",
- existing,
- o.reflectType.Elem().Field(i).Name,
- o,
- )
- }
- o.addDep(fieldName, existing)
- continue StructLoop
- }
-
-
- if fieldType.Kind() == reflect.Struct {
- if tag.Private {
- return fmt.Errorf(
- "cannot use private inject on inline struct on field %s in type %s",
- o.reflectType.Elem().Field(i).Name,
- o.reflectType,
- )
- }
- if !tag.Inline {
- return fmt.Errorf(
- "inline struct on field %s in type %s requires an explicit \"inline\" tag",
- o.reflectType.Elem().Field(i).Name,
- o.reflectType,
- )
- }
- err := g.Provide(&Object{
- Value: field.Addr().Interface(),
- private: true,
- embedded: o.reflectType.Elem().Field(i).Anonymous,
- })
- if err != nil {
- return err
- }
- continue
- }
-
- if fieldType.Kind() == reflect.Interface {
- continue
- }
-
- if fieldType.Kind() == reflect.Map {
- if !tag.Private {
- return fmt.Errorf(
- "inject on map field %s in type %s must be named or private",
- o.reflectType.Elem().Field(i).Name,
- o.reflectType,
- )
- }
- field.Set(reflect.MakeMap(fieldType))
- if g.Logger != nil {
- g.Logger.Debugf(
- "made map for field %s in %s",
- o.reflectType.Elem().Field(i).Name,
- o,
- )
- }
- continue
- }
-
- if !isStructPtr(fieldType) {
- return fmt.Errorf(
- "found inject tag on unsupported field %s in type %s",
- o.reflectType.Elem().Field(i).Name,
- o.reflectType,
- )
- }
-
-
- if !tag.Private {
- for _, existing := range g.unnamed {
- if existing.private {
- continue
- }
- if existing.reflectType.AssignableTo(fieldType) {
- field.Set(reflect.ValueOf(existing.Value))
- if g.Logger != nil {
- g.Logger.Debugf(
- "assigned existing %s to field %s in %s",
- existing,
- o.reflectType.Elem().Field(i).Name,
- o,
- )
- }
- o.addDep(fieldName, existing)
- continue StructLoop
- }
- }
- }
- newValue := reflect.New(fieldType.Elem())
- newObject := &Object{
- Value: newValue.Interface(),
- private: tag.Private,
- created: true,
- }
-
- err = g.Provide(newObject)
- if err != nil {
- return err
- }
-
- field.Set(newValue)
- if g.Logger != nil {
- g.Logger.Debugf(
- "assigned newly created %s to field %s in %s",
- newObject,
- o.reflectType.Elem().Field(i).Name,
- o,
- )
- }
- o.addDep(fieldName, newObject)
- }
- return nil
- }
- func (g *Graph) populateUnnamedInterface(o *Object) error {
-
- if o.Name != "" && !isStructPtr(o.reflectType) {
- return nil
- }
- for i := 0; i < o.reflectValue.Elem().NumField(); i++ {
- field := o.reflectValue.Elem().Field(i)
- fieldType := field.Type()
- fieldTag := o.reflectType.Elem().Field(i).Tag
- fieldName := o.reflectType.Elem().Field(i).Name
- tag, err := parseTag(string(fieldTag))
- if err != nil {
- return fmt.Errorf(
- "unexpected tag format `%s` for field %s in type %s",
- string(fieldTag),
- o.reflectType.Elem().Field(i).Name,
- o.reflectType,
- )
- }
-
- if tag == nil {
- continue
- }
-
-
- if fieldType.Kind() != reflect.Interface {
- continue
- }
-
-
- if tag.Private {
- return fmt.Errorf(
- "found private inject tag on interface field %s in type %s",
- o.reflectType.Elem().Field(i).Name,
- o.reflectType,
- )
- }
-
- if !isNilOrZero(field, fieldType) {
- continue
- }
-
- if tag.Name != "" {
- panic(fmt.Sprintf("unhandled named instance with name %s", tag.Name))
- }
-
- var found *Object
- for _, existing := range g.unnamed {
- if existing.private {
- continue
- }
- if existing.reflectType.AssignableTo(fieldType) {
- if found != nil {
- return fmt.Errorf(
- "found two assignable values for field %s in type %s. one type "+
- "%s with value %v and another type %s with value %v",
- o.reflectType.Elem().Field(i).Name,
- o.reflectType,
- found.reflectType,
- found.Value,
- existing.reflectType,
- existing.reflectValue,
- )
- }
- found = existing
- field.Set(reflect.ValueOf(existing.Value))
- if g.Logger != nil {
- g.Logger.Debugf(
- "assigned existing %s to interface field %s in %s",
- existing,
- o.reflectType.Elem().Field(i).Name,
- o,
- )
- }
- o.addDep(fieldName, existing)
- }
- }
-
- if found == nil {
- return fmt.Errorf(
- "found no assignable value for field %s in type %s",
- o.reflectType.Elem().Field(i).Name,
- o.reflectType,
- )
- }
- }
- return nil
- }
- func (g *Graph) Objects() []*Object {
- objects := make([]*Object, 0, len(g.unnamed)+len(g.named))
- for _, o := range g.unnamed {
- if !o.embedded {
- objects = append(objects, o)
- }
- }
- for _, o := range g.named {
- if !o.embedded {
- objects = append(objects, o)
- }
- }
-
- for i := 0; i < len(objects); i++ {
- j := rand.Intn(i + 1)
- objects[i], objects[j] = objects[j], objects[i]
- }
- return objects
- }
// tag is the parsed form of an `inject:"..."` struct tag.
type tag struct {
	Name    string // object name to inject; empty for unnamed injection
	Inline  bool   // `inject:"inline"`: populate a struct-typed field in place
	Private bool   // `inject:"private"`: never share; always create a fresh value
}

// Shared results for the fixed tag values. They are returned by pointer, so
// they must never be modified.
var (
	injectOnly    = &tag{}
	injectPrivate = &tag{Private: true}
	injectInline  = &tag{Inline: true}
)
- func parseTag(t string) (*tag, error) {
- found, value, err := Extract("inject", t)
- if err != nil {
- return nil, err
- }
- if !found {
- return nil, nil
- }
- if value == "" {
- return injectOnly, nil
- }
- if value == "inline" {
- return injectInline, nil
- }
- if value == "private" {
- return injectPrivate, nil
- }
- return &tag{Name: value}, nil
- }
// isStructPtr reports whether t is a pointer to a struct type.
//
// A nil t is what reflect.TypeOf returns for a nil interface value, for
// example an Object provided with a nil Value. For a nil t it reports false
// instead of panicking on the method call.
func isStructPtr(t reflect.Type) bool {
	return t != nil && t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct
}
// isNilOrZero reports whether v (of type t) is still unset. Pointer and
// interface values are unset when nil. Every other kind is compared with
// t's zero value using reflect.DeepEqual.
// NOTE(review): the DeepEqual branch calls v.Interface(), which panics for
// values read from unexported fields; callers must pass settable fields.
func isNilOrZero(v reflect.Value, t reflect.Type) bool {
	if k := v.Kind(); k == reflect.Ptr || k == reflect.Interface {
		return v.IsNil()
	}
	return reflect.DeepEqual(v.Interface(), reflect.Zero(t).Interface())
}
|