From ac21236d4423eadd39afd0304a29c5d0e69cd7b9 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Thu, 18 Sep 2025 18:27:31 -0700 Subject: [PATCH] fix some bugs around providing slices []T --- di.go | 101 +++++++++++++++++++++++++++++------------------------ di_test.go | 38 ++++++++++++++++++++ 2 files changed, 94 insertions(+), 45 deletions(-) diff --git a/di.go b/di.go index 2aa1d5b3b8e82ce1427931ca35dfc059a353f2fb..6f78b100f38220ad8520c22d33bd6494bfa979b3 100644 --- a/di.go +++ b/di.go @@ -316,7 +316,11 @@ func (c *collector) collect(v reflect.Value, path string) error { if err := validateCtorSignature(pt, name); err != nil { return err } - c.providers[pt.Out(0)] = append(c.providers[pt.Out(0)], ctor{name: name, fn: reflect.ValueOf(payload), out: pt.Out(0)}) + // Validate that the function's return type matches the declared output type + if !pt.Out(0).AssignableTo(outT) { + return fmt.Errorf("constructor %q: function returns %v but Provide declares %v", name, pt.Out(0), outT) + } + c.providers[outT] = append(c.providers[outT], ctor{name: name, fn: reflect.ValueOf(payload), out: outT}) } else { c.values[pt] = reflect.ValueOf(payload) } @@ -418,59 +422,66 @@ func (c *collector) resolve(t reflect.Type) (reflect.Value, error) { // Slice resolution []T if t.Kind() == reflect.Slice { - elem := t.Elem() - var elems []reflect.Value - found := false - - // Explicit list values (concrete types assignable to elem) - for haveT, vals := range c.listValues { - if isAssignableOrImpl(haveT, elem) { - found = true - elems = append(elems, vals...) - } - } - // From list constructors whose out is assignable to elem - for outT, ctors := range c.listProviders { - if !isAssignableOrImpl(outT, elem) { - continue + // First check if there's a direct provider for the slice type + if ps, ok := c.providers[t]; ok && len(ps) > 0 { + // Use the singular provider logic for the slice + // (fall through to singular provider handling below) + } else { + // Try to build slice from individual elements + elem := t.Elem() + var elems []reflect.Value + found := false + + // Explicit list values (concrete types assignable to elem) + for haveT, vals := range c.listValues { + if isAssignableOrImpl(haveT, elem) { + found = true + elems = append(elems, vals...) + } } - found = true - for _, ctor := range ctors { - if c.resolving[t] { - return reflect.Value{}, fmt.Errorf("dependency cycle detected at %s", t) + // From list constructors whose out is assignable to elem + for outT, ctors := range c.listProviders { + if !isAssignableOrImpl(outT, elem) { + continue } - c.resolving[t] = true - ft := ctor.fn.Type() - args, err := c.prepareArgs(ft, ctor.name) - if err != nil { + found = true + for _, ctor := range ctors { + if c.resolving[t] { + return reflect.Value{}, fmt.Errorf("dependency cycle detected at %s", t) + } + c.resolving[t] = true + ft := ctor.fn.Type() + args, err := c.prepareArgs(ft, ctor.name) + if err != nil { + delete(c.resolving, t) + return reflect.Value{}, err + } + outs := ctor.fn.Call(args) delete(c.resolving, t) - return reflect.Value{}, err - } - outs := ctor.fn.Call(args) - delete(c.resolving, t) - if len(outs) == 2 && !outs[1].IsNil() { - return reflect.Value{}, fmt.Errorf("%s error: %w", ctor.name, outs[1].Interface().(error)) + if len(outs) == 2 && !outs[1].IsNil() { + return reflect.Value{}, fmt.Errorf("%s error: %w", ctor.name, outs[1].Interface().(error)) + } + elems = append(elems, outs[0]) } - elems = append(elems, outs[0]) } - } - // If an explicit provider for elem T exists (e.g., []Provide[T] or []T present but empty), - // we should still succeed with an empty slice. - if !found { - if c.listPresence[elem] { - c.values[t] = reflect.MakeSlice(t, 0, 0) - return c.values[t], nil + // If an explicit provider for elem T exists (e.g., []Provide[T] or []T present but empty), + // we should still succeed with an empty slice. + if !found { + if c.listPresence[elem] { + c.values[t] = reflect.MakeSlice(t, 0, 0) + return c.values[t], nil + } + return reflect.Value{}, fmt.Errorf("no provider for %s", t) } - return reflect.Value{}, fmt.Errorf("no provider for %s", t) - } - slice := reflect.MakeSlice(t, 0, len(elems)) - for _, e := range elems { - slice = reflect.Append(slice, e) + slice := reflect.MakeSlice(t, 0, len(elems)) + for _, e := range elems { + slice = reflect.Append(slice, e) + } + c.values[t] = slice + return slice, nil } - c.values[t] = slice - return slice, nil } // Try existing instances for interface targets (singular) diff --git a/di_test.go b/di_test.go index 1be85359971761a077daed5e17d4442d2483004b..847c8aa1ccc72439796cda680ad10f539bacc980 100644 --- a/di_test.go +++ b/di_test.go @@ -210,6 +210,44 @@ func TestBuildSuccess(t *testing.T) { } }, }, + { + name: "provide a slice of []T", + config: struct { + MakeAs Provide[[]A] + MakeB Provide[*B] + MakeCs []Provide[C] + }{ + MakeAs: MustProvide[[]A](func() ([]A, error) { + return []A{{val: "hello"}}, nil + }), + MakeB: MustProvide[*B](func(a []A, cs []C) *B { + return &B{a: &a[0], cs: cs} + }), + MakeCs: []Provide[C]{ + MustProvide[C](C{val: 1}), + MustProvide[C](func() (C, error) { + return C{val: 2}, nil + }), + }, + }, + result: &struct { + B *B + }{}, + verify: func(t *testing.T, result any) { + res := result.(*struct { + B *B + }) + if res.B == nil { + t.Fatalf("expected res.B to be populated") + } + if len(res.B.cs) != 2 { + t.Fatalf("wrong count. Saw %d", len(res.B.cs)) + } + if res.B.a.val != "hello" { + t.Fatalf("unexpected A value: %s", res.B.a.val) + } + }, + }, { name: "simple function constructors", config: struct {