@@ 439,19 439,10 @@ func (c *collector) resolve(t reflect.Type) (reflect.Value, error) {
}
c.resolving[t] = true
ft := ctor.fn.Type()
- args := make([]reflect.Value, ft.NumIn())
- for i := 0; i < ft.NumIn(); i++ {
- paramT := ft.In(i)
- arg, err := c.resolve(paramT)
- if err != nil {
- delete(c.resolving, t)
- return reflect.Value{}, fmt.Errorf("%s depends on %s: %w", ctor.name, paramT, err)
- }
- if !isAssignableOrImpl(arg.Type(), paramT) {
- delete(c.resolving, t)
- return reflect.Value{}, fmt.Errorf("%s: cannot use %s as %s", ctor.name, arg.Type(), paramT)
- }
- args[i] = arg
+ args, err := c.prepareArgs(ft, ctor.name)
+ if err != nil {
+ delete(c.resolving, t)
+ return reflect.Value{}, err
}
outs := ctor.fn.Call(args)
delete(c.resolving, t)
@@ 521,17 512,9 @@ func (c *collector) resolve(t reflect.Type) (reflect.Value, error) {
impl := candidates[0]
ft := impl.fn.Type()
- args := make([]reflect.Value, ft.NumIn())
- for i := 0; i < ft.NumIn(); i++ {
- paramT := ft.In(i)
- arg, err := c.resolve(paramT)
- if err != nil {
- return reflect.Value{}, fmt.Errorf("%s depends on %s: %w", impl.name, paramT, err)
- }
- if !isAssignableOrImpl(arg.Type(), paramT) {
- return reflect.Value{}, fmt.Errorf("%s: cannot use %s as %s", impl.name, arg.Type(), paramT)
- }
- args[i] = arg
+ args, err := c.prepareArgs(ft, impl.name)
+ if err != nil {
+ return reflect.Value{}, err
}
outs := impl.fn.Call(args)
if len(outs) == 2 {
@@ 545,10 528,56 @@ func (c *collector) resolve(t reflect.Type) (reflect.Value, error) {
return outs[0], nil
}
-func validateCtorSignature(ft reflect.Type, name string) error {
+func (c *collector) prepareArgs(ft reflect.Type, name string) ([]reflect.Value, error) {
+ regularParamCount := ft.NumIn()
+ args := make([]reflect.Value, 0, regularParamCount)
if ft.IsVariadic() {
- return fmt.Errorf("constructor %q: variadics not supported", name)
+ regularParamCount--
}
+
+ // Handle regular parameters
+ for i := range regularParamCount {
+ arg, err := c.resolveParam(ft.In(i), name)
+ if err != nil {
+ return nil, err
+ }
+ args = append(args, arg)
+ }
+
+ if ft.IsVariadic() {
+ // Handle variadic parameter ...T by resolving []T
+ variadicParamT := ft.In(regularParamCount) // This is []T for ...T
+ sliceArg, err := c.resolve(variadicParamT)
+ if err != nil {
+ return nil, fmt.Errorf("%s depends on variadic %s: %w", name, variadicParamT, err)
+ }
+
+ // Convert []T to individual arguments for variadic call
+ if sliceArg.Kind() != reflect.Slice {
+ return nil, fmt.Errorf("%s: variadic parameter resolved to non-slice type %s", name, sliceArg.Type())
+ }
+
+ // Expand slice elements as individual arguments
+ for i := range sliceArg.Len() {
+ args = append(args, sliceArg.Index(i))
+ }
+ }
+
+ return args, nil
+}
+
+func (c *collector) resolveParam(paramT reflect.Type, ctorName string) (reflect.Value, error) {
+ arg, err := c.resolve(paramT)
+ if err != nil {
+ return reflect.Value{}, fmt.Errorf("%s depends on %s: %w", ctorName, paramT, err)
+ }
+ if !isAssignableOrImpl(arg.Type(), paramT) {
+ return reflect.Value{}, fmt.Errorf("%s: cannot use %s as %s", ctorName, arg.Type(), paramT)
+ }
+ return arg, nil
+}
+
+func validateCtorSignature(ft reflect.Type, name string) error {
nout := ft.NumOut()
if nout == 1 {
return nil
@@ 591,3 591,244 @@ func TestBuildPrimitiveTypes(t *testing.T) {
})
}
}
+
+func TestVariadicSupport(t *testing.T) {
+ type Service struct {
+ ID int
+ Name string
+ }
+
+ type Aggregator struct {
+ Services []Service
+ Total int
+ }
+
+ tests := []struct {
+ name string
+ config any
+ result any
+ verify func(t *testing.T, result any)
+ }{
+ {
+ name: "variadic constructor with multiple services",
+ config: struct {
+ Services []Service
+ Aggregator func(...Service) *Aggregator
+ }{
+ Services: []Service{
+ {ID: 1, Name: "service1"},
+ {ID: 2, Name: "service2"},
+ {ID: 3, Name: "service3"},
+ },
+ Aggregator: func(services ...Service) *Aggregator {
+ return &Aggregator{
+ Services: services,
+ Total: len(services),
+ }
+ },
+ },
+ result: &struct {
+ Aggregator *Aggregator
+ }{},
+ verify: func(t *testing.T, result any) {
+ res := result.(*struct {
+ Aggregator *Aggregator
+ })
+ if res.Aggregator == nil {
+ t.Fatalf("expected Aggregator to be populated")
+ }
+ if res.Aggregator.Total != 3 {
+ t.Fatalf("expected Total=3, got %d", res.Aggregator.Total)
+ }
+ if len(res.Aggregator.Services) != 3 {
+ t.Fatalf("expected 3 services, got %d", len(res.Aggregator.Services))
+ }
+ expectedServices := []Service{
+ {ID: 1, Name: "service1"},
+ {ID: 2, Name: "service2"},
+ {ID: 3, Name: "service3"},
+ }
+ for i, svc := range res.Aggregator.Services {
+ if svc != expectedServices[i] {
+ t.Fatalf("service[%d]: expected %+v, got %+v", i, expectedServices[i], svc)
+ }
+ }
+ },
+ },
+ {
+ name: "variadic constructor with empty slice",
+ config: struct {
+ Services []Service
+ Aggregator func(...Service) *Aggregator
+ }{
+ Services: []Service{}, // explicit empty slice
+ Aggregator: func(services ...Service) *Aggregator {
+ return &Aggregator{
+ Services: services,
+ Total: len(services),
+ }
+ },
+ },
+ result: &struct {
+ Aggregator *Aggregator
+ }{},
+ verify: func(t *testing.T, result any) {
+ res := result.(*struct {
+ Aggregator *Aggregator
+ })
+ if res.Aggregator == nil {
+ t.Fatalf("expected Aggregator to be populated")
+ }
+ if res.Aggregator.Total != 0 {
+ t.Fatalf("expected Total=0, got %d", res.Aggregator.Total)
+ }
+ if len(res.Aggregator.Services) != 0 {
+ t.Fatalf("expected 0 services, got %d", len(res.Aggregator.Services))
+ }
+ },
+ },
+ {
+ name: "variadic constructor with Provide",
+ config: struct {
+ Services []Provide[Service]
+ Aggregator Provide[*Aggregator]
+ }{
+ Services: []Provide[Service]{
+ MustProvide[Service](Service{ID: 10, Name: "provided1"}),
+ MustProvide[Service](func() Service {
+ return Service{ID: 20, Name: "provided2"}
+ }),
+ },
+ Aggregator: MustProvide[*Aggregator](func(services ...Service) *Aggregator {
+ return &Aggregator{
+ Services: services,
+ Total: len(services),
+ }
+ }),
+ },
+ result: &struct {
+ Aggregator *Aggregator
+ }{},
+ verify: func(t *testing.T, result any) {
+ res := result.(*struct {
+ Aggregator *Aggregator
+ })
+ if res.Aggregator == nil {
+ t.Fatalf("expected Aggregator to be populated")
+ }
+ if res.Aggregator.Total != 2 {
+ t.Fatalf("expected Total=2, got %d", res.Aggregator.Total)
+ }
+ if len(res.Aggregator.Services) != 2 {
+ t.Fatalf("expected 2 services, got %d", len(res.Aggregator.Services))
+ }
+ if res.Aggregator.Services[0].ID != 10 || res.Aggregator.Services[0].Name != "provided1" {
+ t.Fatalf("unexpected first service: %+v", res.Aggregator.Services[0])
+ }
+ if res.Aggregator.Services[1].ID != 20 || res.Aggregator.Services[1].Name != "provided2" {
+ t.Fatalf("unexpected second service: %+v", res.Aggregator.Services[1])
+ }
+ },
+ },
+ {
+ name: "variadic constructor with mixed parameters",
+ config: struct {
+ Name Service // Use Service instead of string to avoid primitive type issue
+ Services []Service
+ Aggregator func(Service, ...Service) *Aggregator
+ }{
+ Name: Service{ID: 0, Name: "MyAggregator"},
+ Services: []Service{
+ {ID: 1, Name: "service1"},
+ },
+ Aggregator: func(name Service, services ...Service) *Aggregator {
+ // Verify that the name parameter comes first
+ if name.Name != "MyAggregator" {
+ panic(fmt.Sprintf("expected name 'MyAggregator', got '%s'", name.Name))
+ }
+ return &Aggregator{
+ Services: services,
+ Total: len(services),
+ }
+ },
+ },
+ result: &struct {
+ Aggregator *Aggregator
+ }{},
+ verify: func(t *testing.T, result any) {
+ res := result.(*struct {
+ Aggregator *Aggregator
+ })
+ if res.Aggregator == nil {
+ t.Fatalf("expected Aggregator to be populated")
+ }
+ if res.Aggregator.Total != 1 {
+ t.Fatalf("expected Total=1, got %d", res.Aggregator.Total)
+ }
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := Build(tt.config, tt.result)
+ if err != nil {
+ t.Fatalf("Build failed: %v", err)
+ }
+ tt.verify(t, tt.result)
+ })
+ }
+}
+
+func TestVariadicErrors(t *testing.T) {
+ type Service struct {
+ ID int
+ }
+
+ type Result struct {
+ Name string
+ }
+
+ tests := []struct {
+ name string
+ config any
+ result any
+ expectedErrors []string
+ }{
+ {
+ name: "variadic constructor missing slice dependency",
+ config: struct {
+ MakeResult func(...Service) *Result
+ }{
+ MakeResult: func(services ...Service) *Result {
+ return &Result{Name: "test"}
+ },
+ },
+ result: &struct {
+ Result *Result
+ }{},
+ expectedErrors: []string{"no provider for []di.Service"},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := Build(tt.config, tt.result)
+ if err == nil {
+ t.Fatalf("expected error")
+ }
+
+ errorStr := err.Error()
+ var foundError bool
+ for _, expectedErr := range tt.expectedErrors {
+ if strings.Contains(errorStr, expectedErr) {
+ foundError = true
+ break
+ }
+ }
+ if !foundError {
+ t.Fatalf("expected error to contain one of %v, got: %v", tt.expectedErrors, err)
+ }
+ })
+ }
+}