From baa3947d3bc1bddb0a034f5054121078962a226a Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Fri, 22 Aug 2025 16:09:13 -0700 Subject: [PATCH] Initial Version --- .build.yml | 9 + LICENSE | 11 ++ README.md | 9 + di.go | 553 +++++++++++++++++++++++++++++++++++++++++++++++++++++ di_test.go | 362 +++++++++++++++++++++++++++++++++++ go.mod | 3 + 6 files changed, 947 insertions(+) create mode 100644 .build.yml create mode 100644 LICENSE create mode 100644 README.md create mode 100644 di.go create mode 100644 di_test.go create mode 100644 go.mod diff --git a/.build.yml b/.build.yml new file mode 100644 index 0000000000000000000000000000000000000000..b1ac015063bff09c18ee82638ee79085d4b7978c --- /dev/null +++ b/.build.yml @@ -0,0 +1,9 @@ +image: alpine/edge +packages: + - go +sources: + - https://git.sr.ht/~marcopolo/di +tasks: + - test: | + cd di + go test ./... diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..2b993be89b1dc9a1a29faea09aba4477f99d8830 --- /dev/null +++ b/LICENSE @@ -0,0 +1,11 @@ +Copyright 2025 Marco Munizaga + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4c04289d451c1d2f957bc3baca2ea269a840150e --- /dev/null +++ b/README.md @@ -0,0 +1,9 @@ +# di + +A small straightforward DI library for Go. + +## Usage + +## Examples + +## Why another DI library? diff --git a/di.go b/di.go new file mode 100644 index 0000000000000000000000000000000000000000..057138b050f66fd7b7bfee1cfadc1b17290b7d30 --- /dev/null +++ b/di.go @@ -0,0 +1,553 @@ +// Package di is a minimal reflection-based dependency injection helper. +// +// Usage pattern: +// +// type Config struct { +// Logger di.Provide[*slog.Logger] +// Server func(logger *slog.Logger) (*http.Server, error) +// } +// +// type Result struct { +// Server *http.Server +// } +// +// cfg := Config{ /* constructors here ... */ } +// var res Result +// if err := di.Build(cfg, &res); err != nil { ... }\ +// +// Or with New(cfg): +// +// server, err := di.New[*http.Server](cfg) +// +// See the doc comments for Build and New for more details +package di + +import ( + "errors" + "fmt" + "reflect" + "strings" +) + +type Provide[Out any] struct { + fOrV any +} + +// SideEffect is a sentinel value representing a constructor used only for side +// effects such as introducing two components together without a circular +// dependency. +// +// It's idiomatic to have your config declare a SideEffects field of type +// []di.MustProvide[di.SideEffect], and have a "_" field in your result struct +// of type []di.SideEffect. +type SideEffect struct{} + +func NewSideEffect(f any) (Provide[SideEffect], error) { + return NewProvide[SideEffect](f) +} + +func MustSideEffect(f any) Provide[SideEffect] { + return Must(NewSideEffect(f)) +} + +type provideI interface { + diOutType() reflect.Type + diPayload() any +} + +func (p Provide[Out]) diOutType() reflect.Type { return typeOf[Out]() } +func (p Provide[Out]) diPayload() any { return p.fOrV } + +func Must[T any](t T, err error) T { + if err != nil { + panic(err) + } + return t +} + +func MustProvide[Out any](ctorOrVal any) Provide[Out] { + return Must(NewProvide[Out](ctorOrVal)) +} + +func NewProvide[Out any](ctorOrVal any) (Provide[Out], error) { + outType := typeOf[Out]() + + if ctorOrVal == nil { + if canBeNil(outType) { + return Provide[Out]{fOrV: nil}, nil + } + return Provide[Out]{}, fmt.Errorf("Provide[%v]: nil not valid for non-nilable type", outType) + } + + t := reflect.TypeOf(ctorOrVal) + + // Case 1: function returning Out or (Out, error). Arbitrary args allowed. + if t.Kind() == reflect.Func { + nout := t.NumOut() + switch nout { + case 1: + if !t.Out(0).AssignableTo(outType) { + return Provide[Out]{}, fmt.Errorf("Provide[%v]: function return %v is not assignable to %v", + outType, t.Out(0), outType) + } + return Provide[Out]{fOrV: ctorOrVal}, nil + + case 2: + if !t.Out(0).AssignableTo(outType) { + return Provide[Out]{}, fmt.Errorf("Provide[%v]: first return %v is not assignable to %v", + outType, t.Out(0), outType) + } + if !isErrorType(t.Out(1)) { + return Provide[Out]{}, fmt.Errorf("Provide[%v]: second return must be error, got %v", + outType, t.Out(1)) + } + return Provide[Out]{fOrV: ctorOrVal}, nil + + default: + return Provide[Out]{}, fmt.Errorf("Provide[%v]: function must return Out or (Out, error); got %d returns", + outType, nout) + } + } + + // Case 2: value assignable to Out (covers interface satisfaction) + if !t.AssignableTo(outType) { + return Provide[Out]{}, fmt.Errorf("Provide[%v]: value of type %v is not assignable to %v", outType, t, outType) + } + return Provide[Out]{fOrV: ctorOrVal}, nil +} + +// Helpers +func typeOf[T any]() reflect.Type { + return reflect.TypeOf((*T)(nil)).Elem() +} + +func canBeNil(t reflect.Type) bool { + switch t.Kind() { + case reflect.Interface, reflect.Chan, reflect.Func, reflect.Map, reflect.Pointer, reflect.Slice: + return true + default: + return false + } +} + +type ctor struct { + name string + fn reflect.Value + out reflect.Type +} + +type collector struct { + // Singular instances & providers + values map[reflect.Type]reflect.Value // exact type -> instance + providers map[reflect.Type][]ctor // out type -> ctors + + // List providers for []T + listValues map[reflect.Type][]reflect.Value // elem dynamic type -> instances + listProviders map[reflect.Type][]ctor // elem out type -> ctors + listPresence map[reflect.Type]bool // elem T present explicitly (even if empty) + + // Cycle detection + resolving map[reflect.Type]bool +} + +func New[R any, C any](config C) (R, error) { + var r R + err := Build(config, &r) + return r, err +} + +// Build resolves only what's needed to populate exported fields in result. +// Supports arbitrarily nested provider namespaced structs inside config. +// +// # TODO rewrite this doc +// +// Supported providers in config (at any nesting depth): +// - Provide[T] // single constructor or value for T +// - []Provide[T] // list of constructors/values contributing to []T +// - func(...Deps) T / (T,error) // singular constructor for T +// - value of type T // preprovided singular value +// - []func(...Deps) T/(T,error) // contributes to []T +// - []T // contributes to []T +// +// Non-func, non-zero exported fields remain prebound instances. +// result must be a pointer to a struct or value. If it is a value the config +// must define how to construct it. +func Build[C any, R any](config C, result R) error { + cfgV := reflect.ValueOf(config) + for cfgV.IsValid() && cfgV.Kind() == reflect.Pointer { + if cfgV.IsNil() { + return errors.New("config pointer is nil") + } + cfgV = cfgV.Elem() + } + if !cfgV.IsValid() || cfgV.Kind() != reflect.Struct { + return errors.New("config must be a struct or pointer to struct") + } + + resV := reflect.ValueOf(result) + if !resV.IsValid() || resV.Kind() != reflect.Pointer { + return errors.New("result must be a pointer to struct or a pointer to a value") + } + + c := &collector{ + values: make(map[reflect.Type]reflect.Value), + providers: make(map[reflect.Type][]ctor), + listValues: make(map[reflect.Type][]reflect.Value), + listProviders: make(map[reflect.Type][]ctor), + listPresence: make(map[reflect.Type]bool), + resolving: make(map[reflect.Type]bool), + } + + // Provide access to the Config value itself + c.values[cfgV.Type()] = cfgV + if err := c.collect(cfgV, ""); err != nil { + return err + } + + // Can we just resolve the direct result type? + if v, err := c.resolve(resV.Elem().Type()); err == nil { + resV.Elem().Set(v) + return nil + } + + if resV.Elem().Kind() != reflect.Struct { + return fmt.Errorf("couldn't build result direct, and can not fill result as it is not a struct") + } + + // Populate result fields of the struct + var missing []string + resStruct := resV.Elem() + resT := resStruct.Type() + for i := 0; i < resT.NumField(); i++ { + sf := resT.Field(i) + if sf.Name != "_" && sf.PkgPath != "" { + continue + } + fv := resStruct.Field(i) + if !fv.IsZero() { + continue + } + v, err := c.resolve(sf.Type) + if err != nil { + missing = append(missing, fmt.Sprintf("%s (%s): %v", sf.Name, sf.Type, err)) + continue + } + + // Evaluate, but do not set underscore field names + if sf.Name != "_" { + fv.Set(v) + } + } + if len(missing) > 0 { + return fmt.Errorf("failed to build result fields:\n - %s", strings.Join(missing, "\n - ")) + } + return nil +} + +func (c *collector) collect(v reflect.Value, path string) error { + if !v.IsValid() { + return nil + } + // Deref pointers + for v.Kind() == reflect.Pointer { + if v.IsNil() { + return nil + } + v = v.Elem() + } + if v.Kind() != reflect.Struct { + return nil + } + + t := v.Type() + for i := 0; i < t.NumField(); i++ { + sf := t.Field(i) + if sf.PkgPath != "" { // unexported + continue + } + fv := v.Field(i) + name := sf.Name + if path != "" { + name = path + "." + sf.Name + } + + // Provide[T] (singular) must be recognized BEFORE treating structs as namespaces. + if pi, ok := asProvide(fv); ok { + outT := pi.diOutType() + payload := pi.diPayload() + if payload == nil { + c.values[outT] = reflect.Zero(outT) + continue + } + pt := reflect.TypeOf(payload) + if pt.Kind() == reflect.Func { + if err := validateCtorSignature(pt, name); err != nil { + return err + } + c.providers[pt.Out(0)] = append(c.providers[pt.Out(0)], ctor{name: name, fn: reflect.ValueOf(payload), out: pt.Out(0)}) + } else { + c.values[pt] = reflect.ValueOf(payload) + } + continue + } + + // Namespace recursion for embedded structs + if sf.Anonymous && sf.Type.Kind() == reflect.Struct { + // Provide access to the nested config value itself + c.values[sf.Type] = fv + + if err := c.collect(fv, name); err != nil { + return err + } + continue + } + + // []Provide[T] (list) + if fv.Kind() == reflect.Slice && fv.Type().Elem().Kind() == reflect.Struct { + if provideElem, ok := reflect.New(fv.Type().Elem()).Elem().Interface().(provideI); ok { + if fv.Len() == 0 && !fv.IsNil() { + // Set presence of empty list + outT := provideElem.diOutType() + c.listPresence[outT] = true + } + + for j := 0; j < fv.Len(); j++ { + pi := fv.Index(j).Interface().(provideI) + outT := pi.diOutType() + c.listPresence[outT] = true + payload := pi.diPayload() + if payload == nil { + c.listValues[outT] = append(c.listValues[outT], reflect.Zero(outT)) + continue + } + pt := reflect.TypeOf(payload) + if pt.Kind() == reflect.Func { + if err := validateCtorSignature(pt, fmt.Sprintf("%s[%d]", name, j)); err != nil { + return err + } + c.listProviders[pt.Out(0)] = append(c.listProviders[pt.Out(0)], ctor{ + name: fmt.Sprintf("%s[%d]", name, j), + fn: reflect.ValueOf(payload), + out: pt.Out(0), + }) + } else { + c.listValues[pt] = append(c.listValues[pt], reflect.ValueOf(payload)) + } + } + continue + } + } + + // Fallback to earlier forms (singular func/value; list of funcs/values) + switch sf.Type.Kind() { + case reflect.Func: + ft := fv.Type() + if err := validateCtorSignature(ft, name); err != nil { + return err + } + c.providers[ft.Out(0)] = append(c.providers[ft.Out(0)], ctor{name: name, fn: fv, out: ft.Out(0)}) + + case reflect.Slice: + elemT := sf.Type.Elem() + if elemT.Kind() == reflect.Func { + for j := 0; j < fv.Len(); j++ { + fn := fv.Index(j) + ft := fn.Type() + if err := validateCtorSignature(ft, fmt.Sprintf("%s[%d]", name, j)); err != nil { + return err + } + c.listProviders[ft.Out(0)] = append(c.listProviders[ft.Out(0)], ctor{ + name: fmt.Sprintf("%s[%d]", name, j), + fn: fn, out: ft.Out(0), + }) + } + } else { + if fv.Len() == 0 { + c.listPresence[elemT] = true // explicit empty list present + } + for j := 0; j < fv.Len(); j++ { + vj := fv.Index(j) + c.listValues[vj.Type()] = append(c.listValues[vj.Type()], vj) + } + } + + default: + // preprovided singular instance (non-zero only) + if !fv.IsZero() { + c.values[fv.Type()] = fv + } + } + } + return nil +} + +func (c *collector) resolve(t reflect.Type) (reflect.Value, error) { + // Cached exact? + if v, ok := c.values[t]; ok { + return v, nil + } + + // Slice resolution []T + if t.Kind() == reflect.Slice { + elem := t.Elem() + var elems []reflect.Value + found := false + + // Explicit list values (concrete types assignable to elem) + for haveT, vals := range c.listValues { + if isAssignableOrImpl(haveT, elem) { + found = true + elems = append(elems, vals...) + } + } + // From list constructors whose out is assignable to elem + for outT, ctors := range c.listProviders { + if !isAssignableOrImpl(outT, elem) { + continue + } + found = true + for _, ctor := range ctors { + if c.resolving[t] { + return reflect.Value{}, fmt.Errorf("dependency cycle detected at %s", t) + } + c.resolving[t] = true + ft := ctor.fn.Type() + args := make([]reflect.Value, ft.NumIn()) + for i := 0; i < ft.NumIn(); i++ { + paramT := ft.In(i) + arg, err := c.resolve(paramT) + if err != nil { + delete(c.resolving, t) + return reflect.Value{}, fmt.Errorf("%s depends on %s: %w", ctor.name, paramT, err) + } + if !isAssignableOrImpl(arg.Type(), paramT) { + delete(c.resolving, t) + return reflect.Value{}, fmt.Errorf("%s: cannot use %s as %s", ctor.name, arg.Type(), paramT) + } + args[i] = arg + } + outs := ctor.fn.Call(args) + delete(c.resolving, t) + if len(outs) == 2 && !outs[1].IsNil() { + return reflect.Value{}, fmt.Errorf("%s error: %w", ctor.name, outs[1].Interface().(error)) + } + elems = append(elems, outs[0]) + } + } + + // If an explicit provider for elem T exists (e.g., []Provide[T] or []T present but empty), + // we should still succeed with an empty slice. + if !found { + if c.listPresence[elem] { + c.values[t] = reflect.MakeSlice(t, 0, 0) + return c.values[t], nil + } + return reflect.Value{}, fmt.Errorf("no provider for %s", t) + } + + slice := reflect.MakeSlice(t, 0, len(elems)) + for _, e := range elems { + slice = reflect.Append(slice, e) + } + c.values[t] = slice + return slice, nil + } + + // Try existing instances for interface targets (singular) + if t.Kind() == reflect.Interface { + for haveT, v := range c.values { + if haveT.Implements(t) { + return v, nil + } + } + } + + // Cycle detection (singular) + if c.resolving[t] { + return reflect.Value{}, fmt.Errorf("dependency cycle detected at %s", t) + } + c.resolving[t] = true + defer delete(c.resolving, t) + + // Pick singular provider(s) + var candidates []ctor + if ps, ok := c.providers[t]; ok { + candidates = append(candidates, ps...) + } else if t.Kind() == reflect.Interface { + for outT, ps := range c.providers { + if outT.Implements(t) { + candidates = append(candidates, ps...) + } + } + } + + if len(candidates) == 0 { + return reflect.Value{}, fmt.Errorf("no provider for %s", t) + } + if len(candidates) > 1 { + var names []string + for _, candidate := range candidates { + names = append(names, candidate.name+" -> "+candidate.out.String()) + } + return reflect.Value{}, fmt.Errorf("ambiguous providers for %s: %s", t, strings.Join(names, ", ")) + } + + impl := candidates[0] + ft := impl.fn.Type() + args := make([]reflect.Value, ft.NumIn()) + for i := 0; i < ft.NumIn(); i++ { + paramT := ft.In(i) + arg, err := c.resolve(paramT) + if err != nil { + return reflect.Value{}, fmt.Errorf("%s depends on %s: %w", impl.name, paramT, err) + } + if !isAssignableOrImpl(arg.Type(), paramT) { + return reflect.Value{}, fmt.Errorf("%s: cannot use %s as %s", impl.name, arg.Type(), paramT) + } + args[i] = arg + } + outs := impl.fn.Call(args) + if len(outs) == 2 { + if !outs[1].IsNil() { + return reflect.Value{}, fmt.Errorf("%s error: %w", impl.name, outs[1].Interface().(error)) + } + c.values[impl.out] = outs[0] + return outs[0], nil + } + c.values[impl.out] = outs[0] + return outs[0], nil +} + +// --- helpers --- + +func validateCtorSignature(ft reflect.Type, name string) error { + if ft.IsVariadic() { + return fmt.Errorf("constructor %q: variadics not supported", name) + } + nout := ft.NumOut() + if nout == 1 { + return nil + } + if nout == 2 && isErrorType(ft.Out(1)) { + return nil + } + return fmt.Errorf("constructor %q: must return (T) or (T, error); got %d returns", name, nout) +} + +func isAssignableOrImpl(have, want reflect.Type) bool { + return have.AssignableTo(want) || (want.Kind() == reflect.Interface && have.Implements(want)) +} + +func isErrorType(t reflect.Type) bool { + return t == reflect.TypeOf((*error)(nil)).Elem() +} + +// asProvide tries to view v as a Provide[*]. Returns (iface, true) if so. +func asProvide(v reflect.Value) (provideI, bool) { + if !v.IsValid() { + return nil, false + } + x := v.Interface() + pi, ok := x.(provideI) + return pi, ok +} diff --git a/di_test.go b/di_test.go new file mode 100644 index 0000000000000000000000000000000000000000..41104c5441a07d4b14c2820793396c1fa5822381 --- /dev/null +++ b/di_test.go @@ -0,0 +1,362 @@ +package di + +import ( + "errors" + "strings" + "testing" +) + +func TestBuildSuccess(t *testing.T) { + type A struct { + val string + } + type C struct { + val int + } + type B struct { + a *A + cs []C + } + + type Config struct { + MakeA Provide[*A] + MakeB Provide[*B] + MakeCs []Provide[C] + } + + cfg := Config{ + MakeA: MustProvide[*A](func() (*A, error) { + return &A{val: "hello"}, nil + }), + MakeB: MustProvide[*B](func(a *A, cs []C) *B { + return &B{a: a, cs: cs} + }), + MakeCs: []Provide[C]{ + MustProvide[C](C{val: 1}), + MustProvide[C](func() (C, error) { + return C{val: 2}, nil + })}, + } + + type Result struct { + A *A + B *B + } + var res Result + err := Build(cfg, &res) + if err != nil { + t.Fatalf("Build failed: %v", err) + } + if res.A == nil { + t.Fatalf("expected res.A to be populated") + } + if res.B == nil { + t.Fatalf("expected res.B to be populated") + } + if res.B.a != res.A { + t.Fatalf("expected B.a to reference A instance") + } + if len(res.B.cs) != 2 { + t.Fatalf("wrong count. Saw %d", len(res.B.cs)) + } + if res.B.cs[0].val != 1 { + t.Fatalf("wrong value") + } + if res.B.cs[1].val != 2 { + t.Fatalf("wrong value") + } + if res.A.val != "hello" { + t.Fatalf("unexpected A value: %s", res.A.val) + } +} + +func TestBuildSuccess2(t *testing.T) { + type A struct { + val string + } + type B struct { + a *A + } + + type Config struct { + MakeA func() (*A, error) + MakeB func(*A) (*B, error) + } + + type Result struct { + A *A + } + + cfg := Config{ + MakeA: func() (*A, error) { + return &A{val: "hello"}, nil + }, + MakeB: func(a *A) (*B, error) { + panic("Unexpected call to MakeB") + // (removed unreachable code after panic) + }, + } + + var res Result + err := Build(cfg, &res) + if err != nil { + t.Fatalf("Build failed: %v", err) + } + if res.A == nil { + t.Fatalf("expected res.A to be populated") + } + if res.A.val != "hello" { + t.Fatalf("unexpected A value: %s", res.A.val) + } +} + +// Test that constructor error is propagated. +func TestBuildConstructorError(t *testing.T) { + type A struct{} + sentinel := errors.New("boom") + + type Config struct { + MakeA func() (*A, error) + } + type Result struct { + A *A + } + + cfg := Config{ + MakeA: func() (*A, error) { + return nil, sentinel + }, + } + var res Result + err := Build(cfg, &res) + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "MakeA") { + t.Fatalf("expected error to mention constructor name, got: %v", err) + } + if !strings.Contains(err.Error(), "boom") { + t.Fatalf("expected original error message, got: %v", err) + } + if res.A != nil { + t.Fatalf("result A should not be populated on constructor failure") + } +} + +// Test missing dependency (constructor requires *A but *A not provided). +func TestBuildMissingDependency(t *testing.T) { + type A struct{} + type B struct { + a *A + } + + type Config struct { + MakeB func(*A) (*B, error) + } + type Result struct { + B *B + } + + cfg := Config{ + MakeB: func(a *A) (*B, error) { + return &B{a: a}, nil + }, + } + var res Result + err := Build(cfg, &res) + if err == nil { + t.Fatalf("expected missing dependency error") + } + // Parameter type string should appear (may be *di.A). + if !strings.Contains(err.Error(), "*di.A") && !strings.Contains(err.Error(), "di.A") { + t.Fatalf("expected error to mention missing type *di.A, got: %v", err) + } + if res.B != nil { + t.Fatalf("result B should not be populated") + } +} + +// Test cycle detection between X and Y. +func TestBuildCycleDetection(t *testing.T) { + type X struct{} + type Y struct{} + + type Config struct { + MakeX func(*Y) *X + MakeY func(*X) *Y + } + type Result struct { + X *X + Y *Y + } + + cfg := Config{ + MakeX: func(y *Y) *X { + return &X{} + }, + MakeY: func(x *X) *Y { + return &Y{} + }, + } + + var res Result + err := Build(cfg, &res) + if err == nil { + t.Fatalf("expected cycle detection error") + } + // Both constructors should still be listed as remaining. + if !strings.Contains(err.Error(), "MakeX") || !strings.Contains(err.Error(), "MakeY") { + t.Fatalf("expected error to list remaining constructors MakeX and MakeY, got: %v", err) + } + if res.X != nil || res.Y != nil { + t.Fatalf("cycle should prevent any construction; got X=%v Y=%v", res.X, res.Y) + } +} + +// Ensure that providing pre-supplied value satisfies dependency without constructor for it. +func TestBuildWithPreSuppliedValue(t *testing.T) { + type A struct { + v int + } + type B struct { + a *A + } + + type Config struct { + // Only constructor for B; A provided directly in config. + A *A + MB func(*A) (*B, error) + } + type Result struct { + A *A + B *B + } + + cfg := Config{ + A: &A{v: 42}, + MB: func(a *A) (*B, error) { + return &B{a: a}, nil + }, + } + + var res Result + err := Build(cfg, &res) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.A == nil || res.A.v != 42 { + t.Fatalf("expected pre-supplied A (42), got %+v", res.A) + } + if res.B == nil || res.B.a != res.A { + t.Fatalf("expected B referencing A, got %+v", res.B) + } +} + +func TestTypeAlias(t *testing.T) { + type ANum int + type Config struct { + A ANum + B int + } + + type Result struct { + A ANum + } + var res Result + err := Build(Config{3, 4}, &res) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.A != 3 { + t.Fatalf("expected A=3, got %v", res.A) + } +} + +func TestReferenceConfig(t *testing.T) { + type NestedConfig struct { + OtherSetting bool + NestedDecision func(c NestedConfig) uint + } + + type Config struct { + NestedConfig + SomeSetting bool + Inner func(c Config) int + } + + type Result struct { + A int + B uint + } + var res Result + err := Build(Config{ + SomeSetting: true, + Inner: func(c Config) int { + if c.SomeSetting { + return 1 + } + return 0 + }, + NestedConfig: NestedConfig{ + OtherSetting: true, + NestedDecision: func(c NestedConfig) uint { + if c.OtherSetting { + return 1 + } + return 0 + }, + }, + }, &res) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.A != 1 { + t.Fatalf("expected A=1, got %v", res.A) + } + if res.B != 1 { + t.Fatalf("expected B=1, got %v", res.A) + } +} + +func TestNew(t *testing.T) { + type Config struct { + A int + } + + type Result struct { + A int + } + + cfg := Config{A: 42} + res, err := New[Result](cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.A != 42 { + t.Fatalf("expected A=42, got %v", res.A) + } +} + +func TestSpecificTypes(t *testing.T) { + type Config struct { + A int + } + + cfg := Config{A: 42} + res, err := New[int](cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res != 42 { + t.Fatalf("expected A=42, got %v", res) + } + + var res2 int + err = Build(Config{A: 42}, &res2) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res2 != 42 { + t.Fatalf("expected A=42, got %v", res2) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000000000000000000000000000000000000..b2f14743c93c9d15817b806b02de22ebbd266457 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module git.sr.ht/~marcopolo/di + +go 1.24.0