// Package di is a minimal reflection-based dependency injection helper. // // Usage pattern: // // type Config struct { // Logger di.Provide[*slog.Logger] // Server func(logger *slog.Logger) (*http.Server, error) // } // // type Result struct { // Server *http.Server // } // // cfg := Config{ /* constructors here ... */ } // var res Result // if err := di.Build(cfg, &res); err != nil { ... }\ // // Or with New(cfg): // // server, err := di.New[*http.Server](cfg) // // See the doc comments for Build and New for more details package di import ( "errors" "fmt" "reflect" "strings" ) type Provide[Out any] struct { fOrV any } // SideEffect is a sentinel value representing a constructor used only for side // effects such as introducing two components together without a circular // dependency. // // It's idiomatic to have your config declare a SideEffects field of type // []di.MustProvide[di.SideEffect], and have a "_" field in your result struct // of type []di.SideEffect. type SideEffect struct{} func NewSideEffect(f any) (Provide[SideEffect], error) { return NewProvide[SideEffect](f) } func MustSideEffect(f any) Provide[SideEffect] { return Must(NewSideEffect(f)) } type provideI interface { diOutType() reflect.Type diPayload() any } func (p Provide[Out]) diOutType() reflect.Type { return typeOf[Out]() } func (p Provide[Out]) diPayload() any { return p.fOrV } func Must[T any](t T, err error) T { if err != nil { panic(err) } return t } func MustProvide[Out any](ctorOrVal any) Provide[Out] { return Must(NewProvide[Out](ctorOrVal)) } func NewProvide[Out any](ctorOrVal any) (Provide[Out], error) { outType := typeOf[Out]() if ctorOrVal == nil { if canBeNil(outType) { return Provide[Out]{fOrV: nil}, nil } return Provide[Out]{}, fmt.Errorf("Provide[%v]: nil not valid for non-nilable type", outType) } t := reflect.TypeOf(ctorOrVal) // Case 1: function returning Out or (Out, error). Arbitrary args allowed. if t.Kind() == reflect.Func { nout := t.NumOut() switch nout { case 1: if !t.Out(0).AssignableTo(outType) { return Provide[Out]{}, fmt.Errorf("Provide[%v]: function return %v is not assignable to %v", outType, t.Out(0), outType) } return Provide[Out]{fOrV: ctorOrVal}, nil case 2: if !t.Out(0).AssignableTo(outType) { return Provide[Out]{}, fmt.Errorf("Provide[%v]: first return %v is not assignable to %v", outType, t.Out(0), outType) } if !isErrorType(t.Out(1)) { return Provide[Out]{}, fmt.Errorf("Provide[%v]: second return must be error, got %v", outType, t.Out(1)) } return Provide[Out]{fOrV: ctorOrVal}, nil default: return Provide[Out]{}, fmt.Errorf("Provide[%v]: function must return Out or (Out, error); got %d returns", outType, nout) } } // Case 2: value assignable to Out (covers interface satisfaction) if !t.AssignableTo(outType) { return Provide[Out]{}, fmt.Errorf("Provide[%v]: value of type %v is not assignable to %v", outType, t, outType) } return Provide[Out]{fOrV: ctorOrVal}, nil } // Helpers func typeOf[T any]() reflect.Type { return reflect.TypeOf((*T)(nil)).Elem() } func canBeNil(t reflect.Type) bool { switch t.Kind() { case reflect.Interface, reflect.Chan, reflect.Func, reflect.Map, reflect.Pointer, reflect.Slice: return true default: return false } } type ctor struct { name string fn reflect.Value out reflect.Type } type collector struct { // Singular instances & providers values map[reflect.Type]reflect.Value // exact type -> instance providers map[reflect.Type][]ctor // out type -> ctors // List providers for []T listValues map[reflect.Type][]reflect.Value // elem dynamic type -> instances listProviders map[reflect.Type][]ctor // elem out type -> ctors listPresence map[reflect.Type]bool // elem T present explicitly (even if empty) // Cycle detection resolving map[reflect.Type]bool } func New[R any, C any](config C) (R, error) { var r R err := Build(config, &r) return r, err } // Build resolves only what's needed to populate exported fields in result. // Supports arbitrarily nested provider namespaced structs inside config. // // # TODO rewrite this doc // // Supported providers in config (at any nesting depth): // - Provide[T] // single constructor or value for T // - []Provide[T] // list of constructors/values contributing to []T // - func(...Deps) T / (T,error) // singular constructor for T // - value of type T // preprovided singular value // - []func(...Deps) T/(T,error) // contributes to []T // - []T // contributes to []T // // Non-func, non-zero exported fields remain prebound instances. // result must be a pointer to a struct or value. If it is a value the config // must define how to construct it. func Build[C any, R any](config C, result R) error { cfgV := reflect.ValueOf(config) for cfgV.IsValid() && cfgV.Kind() == reflect.Pointer { if cfgV.IsNil() { return errors.New("config pointer is nil") } cfgV = cfgV.Elem() } if !cfgV.IsValid() || cfgV.Kind() != reflect.Struct { return errors.New("config must be a struct or pointer to struct") } resV := reflect.ValueOf(result) if !resV.IsValid() || resV.Kind() != reflect.Pointer { return errors.New("result must be a pointer to struct or a pointer to a value") } c := &collector{ values: make(map[reflect.Type]reflect.Value), providers: make(map[reflect.Type][]ctor), listValues: make(map[reflect.Type][]reflect.Value), listProviders: make(map[reflect.Type][]ctor), listPresence: make(map[reflect.Type]bool), resolving: make(map[reflect.Type]bool), } // Provide access to the Config value itself c.values[cfgV.Type()] = cfgV if err := c.collect(cfgV, ""); err != nil { return err } // Can we just resolve the direct result type? if v, err := c.resolve(resV.Elem().Type()); err == nil { resV.Elem().Set(v) return nil } if resV.Elem().Kind() != reflect.Struct { return fmt.Errorf("couldn't build result direct, and can not fill result as it is not a struct") } // Populate result fields of the struct var missing []string resStruct := resV.Elem() resT := resStruct.Type() for i := 0; i < resT.NumField(); i++ { sf := resT.Field(i) if sf.Name != "_" && sf.PkgPath != "" { continue } fv := resStruct.Field(i) if !fv.IsZero() { continue } v, err := c.resolve(sf.Type) if err != nil { missing = append(missing, fmt.Sprintf("%s (%s): %v", sf.Name, sf.Type, err)) continue } // Evaluate, but do not set underscore field names if sf.Name != "_" { fv.Set(v) } } if len(missing) > 0 { return fmt.Errorf("failed to build result fields:\n - %s", strings.Join(missing, "\n - ")) } return nil } func (c *collector) collect(v reflect.Value, path string) error { if !v.IsValid() { return nil } // Deref pointers for v.Kind() == reflect.Pointer { if v.IsNil() { return nil } v = v.Elem() } if v.Kind() != reflect.Struct { return nil } t := v.Type() for i := 0; i < t.NumField(); i++ { sf := t.Field(i) if sf.PkgPath != "" { // unexported continue } fv := v.Field(i) name := sf.Name if path != "" { name = path + "." + sf.Name } // Provide[T] (singular) must be recognized BEFORE treating structs as namespaces. if pi, ok := asProvide(fv); ok { outT := pi.diOutType() payload := pi.diPayload() if payload == nil { c.values[outT] = reflect.Zero(outT) continue } pt := reflect.TypeOf(payload) if pt.Kind() == reflect.Func { if err := validateCtorSignature(pt, name); err != nil { return err } c.providers[pt.Out(0)] = append(c.providers[pt.Out(0)], ctor{name: name, fn: reflect.ValueOf(payload), out: pt.Out(0)}) } else { c.values[pt] = reflect.ValueOf(payload) } continue } // Namespace recursion for embedded structs if sf.Anonymous && sf.Type.Kind() == reflect.Struct { // Provide access to the nested config value itself c.values[sf.Type] = fv if err := c.collect(fv, name); err != nil { return err } continue } // []Provide[T] (list) if fv.Kind() == reflect.Slice && fv.Type().Elem().Kind() == reflect.Struct { if provideElem, ok := reflect.New(fv.Type().Elem()).Elem().Interface().(provideI); ok { if fv.Len() == 0 && !fv.IsNil() { // Set presence of empty list outT := provideElem.diOutType() c.listPresence[outT] = true } for j := 0; j < fv.Len(); j++ { pi := fv.Index(j).Interface().(provideI) outT := pi.diOutType() c.listPresence[outT] = true payload := pi.diPayload() if payload == nil { c.listValues[outT] = append(c.listValues[outT], reflect.Zero(outT)) continue } pt := reflect.TypeOf(payload) if pt.Kind() == reflect.Func { if err := validateCtorSignature(pt, fmt.Sprintf("%s[%d]", name, j)); err != nil { return err } c.listProviders[pt.Out(0)] = append(c.listProviders[pt.Out(0)], ctor{ name: fmt.Sprintf("%s[%d]", name, j), fn: reflect.ValueOf(payload), out: pt.Out(0), }) } else { c.listValues[pt] = append(c.listValues[pt], reflect.ValueOf(payload)) } } continue } } // Fallback to earlier forms (singular func/value; list of funcs/values) switch sf.Type.Kind() { case reflect.Func: ft := fv.Type() if err := validateCtorSignature(ft, name); err != nil { return err } c.providers[ft.Out(0)] = append(c.providers[ft.Out(0)], ctor{name: name, fn: fv, out: ft.Out(0)}) case reflect.Slice: elemT := sf.Type.Elem() if elemT.Kind() == reflect.Func { for j := 0; j < fv.Len(); j++ { fn := fv.Index(j) ft := fn.Type() if err := validateCtorSignature(ft, fmt.Sprintf("%s[%d]", name, j)); err != nil { return err } c.listProviders[ft.Out(0)] = append(c.listProviders[ft.Out(0)], ctor{ name: fmt.Sprintf("%s[%d]", name, j), fn: fn, out: ft.Out(0), }) } } else { if fv.Len() == 0 { c.listPresence[elemT] = true // explicit empty list present } for j := 0; j < fv.Len(); j++ { vj := fv.Index(j) c.listValues[vj.Type()] = append(c.listValues[vj.Type()], vj) } } default: // preprovided singular instance (non-zero only) if !fv.IsZero() { c.values[fv.Type()] = fv } } } return nil } func (c *collector) resolve(t reflect.Type) (reflect.Value, error) { // Cached exact? if v, ok := c.values[t]; ok { return v, nil } // Slice resolution []T if t.Kind() == reflect.Slice { elem := t.Elem() var elems []reflect.Value found := false // Explicit list values (concrete types assignable to elem) for haveT, vals := range c.listValues { if isAssignableOrImpl(haveT, elem) { found = true elems = append(elems, vals...) } } // From list constructors whose out is assignable to elem for outT, ctors := range c.listProviders { if !isAssignableOrImpl(outT, elem) { continue } found = true for _, ctor := range ctors { if c.resolving[t] { return reflect.Value{}, fmt.Errorf("dependency cycle detected at %s", t) } c.resolving[t] = true ft := ctor.fn.Type() args := make([]reflect.Value, ft.NumIn()) for i := 0; i < ft.NumIn(); i++ { paramT := ft.In(i) arg, err := c.resolve(paramT) if err != nil { delete(c.resolving, t) return reflect.Value{}, fmt.Errorf("%s depends on %s: %w", ctor.name, paramT, err) } if !isAssignableOrImpl(arg.Type(), paramT) { delete(c.resolving, t) return reflect.Value{}, fmt.Errorf("%s: cannot use %s as %s", ctor.name, arg.Type(), paramT) } args[i] = arg } outs := ctor.fn.Call(args) delete(c.resolving, t) if len(outs) == 2 && !outs[1].IsNil() { return reflect.Value{}, fmt.Errorf("%s error: %w", ctor.name, outs[1].Interface().(error)) } elems = append(elems, outs[0]) } } // If an explicit provider for elem T exists (e.g., []Provide[T] or []T present but empty), // we should still succeed with an empty slice. if !found { if c.listPresence[elem] { c.values[t] = reflect.MakeSlice(t, 0, 0) return c.values[t], nil } return reflect.Value{}, fmt.Errorf("no provider for %s", t) } slice := reflect.MakeSlice(t, 0, len(elems)) for _, e := range elems { slice = reflect.Append(slice, e) } c.values[t] = slice return slice, nil } // Try existing instances for interface targets (singular) if t.Kind() == reflect.Interface { for haveT, v := range c.values { if haveT.Implements(t) { return v, nil } } } // Cycle detection (singular) if c.resolving[t] { return reflect.Value{}, fmt.Errorf("dependency cycle detected at %s", t) } c.resolving[t] = true defer delete(c.resolving, t) // Pick singular provider(s) var candidates []ctor if ps, ok := c.providers[t]; ok { candidates = append(candidates, ps...) } else if t.Kind() == reflect.Interface { for outT, ps := range c.providers { if outT.Implements(t) { candidates = append(candidates, ps...) } } } if len(candidates) == 0 { return reflect.Value{}, fmt.Errorf("no provider for %s", t) } if len(candidates) > 1 { var names []string for _, candidate := range candidates { names = append(names, candidate.name+" -> "+candidate.out.String()) } return reflect.Value{}, fmt.Errorf("ambiguous providers for %s: %s", t, strings.Join(names, ", ")) } impl := candidates[0] ft := impl.fn.Type() args := make([]reflect.Value, ft.NumIn()) for i := 0; i < ft.NumIn(); i++ { paramT := ft.In(i) arg, err := c.resolve(paramT) if err != nil { return reflect.Value{}, fmt.Errorf("%s depends on %s: %w", impl.name, paramT, err) } if !isAssignableOrImpl(arg.Type(), paramT) { return reflect.Value{}, fmt.Errorf("%s: cannot use %s as %s", impl.name, arg.Type(), paramT) } args[i] = arg } outs := impl.fn.Call(args) if len(outs) == 2 { if !outs[1].IsNil() { return reflect.Value{}, fmt.Errorf("%s error: %w", impl.name, outs[1].Interface().(error)) } c.values[impl.out] = outs[0] return outs[0], nil } c.values[impl.out] = outs[0] return outs[0], nil } // --- helpers --- func validateCtorSignature(ft reflect.Type, name string) error { if ft.IsVariadic() { return fmt.Errorf("constructor %q: variadics not supported", name) } nout := ft.NumOut() if nout == 1 { return nil } if nout == 2 && isErrorType(ft.Out(1)) { return nil } return fmt.Errorf("constructor %q: must return (T) or (T, error); got %d returns", name, nout) } func isAssignableOrImpl(have, want reflect.Type) bool { return have.AssignableTo(want) || (want.Kind() == reflect.Interface && have.Implements(want)) } func isErrorType(t reflect.Type) bool { return t == reflect.TypeOf((*error)(nil)).Elem() } // asProvide tries to view v as a Provide[*]. Returns (iface, true) if so. func asProvide(v reflect.Value) (provideI, bool) { if !v.IsValid() { return nil, false } x := v.Interface() pi, ok := x.(provideI) return pi, ok }