From 9a4fa0450032ce807ceb6b514c8508d4fe34d520 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Thu, 18 Sep 2025 17:26:35 -0700 Subject: [PATCH] add support for variadics --- di.go | 81 ++++++++++++------ di_test.go | 241 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 296 insertions(+), 26 deletions(-) diff --git a/di.go b/di.go index 5ef16c2e3f172b6ce6a6ca9aedf239c65d692723..43607ff817d2568005e70334df48b07f644d5f27 100644 --- a/di.go +++ b/di.go @@ -439,19 +439,10 @@ func (c *collector) resolve(t reflect.Type) (reflect.Value, error) { } c.resolving[t] = true ft := ctor.fn.Type() - args := make([]reflect.Value, ft.NumIn()) - for i := 0; i < ft.NumIn(); i++ { - paramT := ft.In(i) - arg, err := c.resolve(paramT) - if err != nil { - delete(c.resolving, t) - return reflect.Value{}, fmt.Errorf("%s depends on %s: %w", ctor.name, paramT, err) - } - if !isAssignableOrImpl(arg.Type(), paramT) { - delete(c.resolving, t) - return reflect.Value{}, fmt.Errorf("%s: cannot use %s as %s", ctor.name, arg.Type(), paramT) - } - args[i] = arg + args, err := c.prepareArgs(ft, ctor.name) + if err != nil { + delete(c.resolving, t) + return reflect.Value{}, err } outs := ctor.fn.Call(args) delete(c.resolving, t) @@ -521,17 +512,9 @@ func (c *collector) resolve(t reflect.Type) (reflect.Value, error) { impl := candidates[0] ft := impl.fn.Type() - args := make([]reflect.Value, ft.NumIn()) - for i := 0; i < ft.NumIn(); i++ { - paramT := ft.In(i) - arg, err := c.resolve(paramT) - if err != nil { - return reflect.Value{}, fmt.Errorf("%s depends on %s: %w", impl.name, paramT, err) - } - if !isAssignableOrImpl(arg.Type(), paramT) { - return reflect.Value{}, fmt.Errorf("%s: cannot use %s as %s", impl.name, arg.Type(), paramT) - } - args[i] = arg + args, err := c.prepareArgs(ft, impl.name) + if err != nil { + return reflect.Value{}, err } outs := impl.fn.Call(args) if len(outs) == 2 { @@ -545,10 +528,56 @@ func (c *collector) resolve(t reflect.Type) (reflect.Value, error) { return outs[0], nil } -func validateCtorSignature(ft reflect.Type, name string) error { +func (c *collector) prepareArgs(ft reflect.Type, name string) ([]reflect.Value, error) { + regularParamCount := ft.NumIn() + args := make([]reflect.Value, 0, regularParamCount) if ft.IsVariadic() { - return fmt.Errorf("constructor %q: variadics not supported", name) + regularParamCount-- } + + // Handle regular parameters + for i := range regularParamCount { + arg, err := c.resolveParam(ft.In(i), name) + if err != nil { + return nil, err + } + args = append(args, arg) + } + + if ft.IsVariadic() { + // Handle variadic parameter ...T by resolving []T + variadicParamT := ft.In(regularParamCount) // This is []T for ...T + sliceArg, err := c.resolve(variadicParamT) + if err != nil { + return nil, fmt.Errorf("%s depends on variadic %s: %w", name, variadicParamT, err) + } + + // Convert []T to individual arguments for variadic call + if sliceArg.Kind() != reflect.Slice { + return nil, fmt.Errorf("%s: variadic parameter resolved to non-slice type %s", name, sliceArg.Type()) + } + + // Expand slice elements as individual arguments + for i := range sliceArg.Len() { + args = append(args, sliceArg.Index(i)) + } + } + + return args, nil +} + +func (c *collector) resolveParam(paramT reflect.Type, ctorName string) (reflect.Value, error) { + arg, err := c.resolve(paramT) + if err != nil { + return reflect.Value{}, fmt.Errorf("%s depends on %s: %w", ctorName, paramT, err) + } + if !isAssignableOrImpl(arg.Type(), paramT) { + return reflect.Value{}, fmt.Errorf("%s: cannot use %s as %s", ctorName, arg.Type(), paramT) + } + return arg, nil +} + +func validateCtorSignature(ft reflect.Type, name string) error { nout := ft.NumOut() if nout == 1 { return nil diff --git a/di_test.go b/di_test.go index 8db93312cd131e77b2bdf26dda54b2659326be37..1be85359971761a077daed5e17d4442d2483004b 100644 --- a/di_test.go +++ b/di_test.go @@ -591,3 +591,244 @@ func TestBuildPrimitiveTypes(t *testing.T) { }) } } + +func TestVariadicSupport(t *testing.T) { + type Service struct { + ID int + Name string + } + + type Aggregator struct { + Services []Service + Total int + } + + tests := []struct { + name string + config any + result any + verify func(t *testing.T, result any) + }{ + { + name: "variadic constructor with multiple services", + config: struct { + Services []Service + Aggregator func(...Service) *Aggregator + }{ + Services: []Service{ + {ID: 1, Name: "service1"}, + {ID: 2, Name: "service2"}, + {ID: 3, Name: "service3"}, + }, + Aggregator: func(services ...Service) *Aggregator { + return &Aggregator{ + Services: services, + Total: len(services), + } + }, + }, + result: &struct { + Aggregator *Aggregator + }{}, + verify: func(t *testing.T, result any) { + res := result.(*struct { + Aggregator *Aggregator + }) + if res.Aggregator == nil { + t.Fatalf("expected Aggregator to be populated") + } + if res.Aggregator.Total != 3 { + t.Fatalf("expected Total=3, got %d", res.Aggregator.Total) + } + if len(res.Aggregator.Services) != 3 { + t.Fatalf("expected 3 services, got %d", len(res.Aggregator.Services)) + } + expectedServices := []Service{ + {ID: 1, Name: "service1"}, + {ID: 2, Name: "service2"}, + {ID: 3, Name: "service3"}, + } + for i, svc := range res.Aggregator.Services { + if svc != expectedServices[i] { + t.Fatalf("service[%d]: expected %+v, got %+v", i, expectedServices[i], svc) + } + } + }, + }, + { + name: "variadic constructor with empty slice", + config: struct { + Services []Service + Aggregator func(...Service) *Aggregator + }{ + Services: []Service{}, // explicit empty slice + Aggregator: func(services ...Service) *Aggregator { + return &Aggregator{ + Services: services, + Total: len(services), + } + }, + }, + result: &struct { + Aggregator *Aggregator + }{}, + verify: func(t *testing.T, result any) { + res := result.(*struct { + Aggregator *Aggregator + }) + if res.Aggregator == nil { + t.Fatalf("expected Aggregator to be populated") + } + if res.Aggregator.Total != 0 { + t.Fatalf("expected Total=0, got %d", res.Aggregator.Total) + } + if len(res.Aggregator.Services) != 0 { + t.Fatalf("expected 0 services, got %d", len(res.Aggregator.Services)) + } + }, + }, + { + name: "variadic constructor with Provide", + config: struct { + Services []Provide[Service] + Aggregator Provide[*Aggregator] + }{ + Services: []Provide[Service]{ + MustProvide[Service](Service{ID: 10, Name: "provided1"}), + MustProvide[Service](func() Service { + return Service{ID: 20, Name: "provided2"} + }), + }, + Aggregator: MustProvide[*Aggregator](func(services ...Service) *Aggregator { + return &Aggregator{ + Services: services, + Total: len(services), + } + }), + }, + result: &struct { + Aggregator *Aggregator + }{}, + verify: func(t *testing.T, result any) { + res := result.(*struct { + Aggregator *Aggregator + }) + if res.Aggregator == nil { + t.Fatalf("expected Aggregator to be populated") + } + if res.Aggregator.Total != 2 { + t.Fatalf("expected Total=2, got %d", res.Aggregator.Total) + } + if len(res.Aggregator.Services) != 2 { + t.Fatalf("expected 2 services, got %d", len(res.Aggregator.Services)) + } + if res.Aggregator.Services[0].ID != 10 || res.Aggregator.Services[0].Name != "provided1" { + t.Fatalf("unexpected first service: %+v", res.Aggregator.Services[0]) + } + if res.Aggregator.Services[1].ID != 20 || res.Aggregator.Services[1].Name != "provided2" { + t.Fatalf("unexpected second service: %+v", res.Aggregator.Services[1]) + } + }, + }, + { + name: "variadic constructor with mixed parameters", + config: struct { + Name Service // Use Service instead of string to avoid primitive type issue + Services []Service + Aggregator func(Service, ...Service) *Aggregator + }{ + Name: Service{ID: 0, Name: "MyAggregator"}, + Services: []Service{ + {ID: 1, Name: "service1"}, + }, + Aggregator: func(name Service, services ...Service) *Aggregator { + // Verify that the name parameter comes first + if name.Name != "MyAggregator" { + panic(fmt.Sprintf("expected name 'MyAggregator', got '%s'", name.Name)) + } + return &Aggregator{ + Services: services, + Total: len(services), + } + }, + }, + result: &struct { + Aggregator *Aggregator + }{}, + verify: func(t *testing.T, result any) { + res := result.(*struct { + Aggregator *Aggregator + }) + if res.Aggregator == nil { + t.Fatalf("expected Aggregator to be populated") + } + if res.Aggregator.Total != 1 { + t.Fatalf("expected Total=1, got %d", res.Aggregator.Total) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := Build(tt.config, tt.result) + if err != nil { + t.Fatalf("Build failed: %v", err) + } + tt.verify(t, tt.result) + }) + } +} + +func TestVariadicErrors(t *testing.T) { + type Service struct { + ID int + } + + type Result struct { + Name string + } + + tests := []struct { + name string + config any + result any + expectedErrors []string + }{ + { + name: "variadic constructor missing slice dependency", + config: struct { + MakeResult func(...Service) *Result + }{ + MakeResult: func(services ...Service) *Result { + return &Result{Name: "test"} + }, + }, + result: &struct { + Result *Result + }{}, + expectedErrors: []string{"no provider for []di.Service"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := Build(tt.config, tt.result) + if err == nil { + t.Fatalf("expected error") + } + + errorStr := err.Error() + var foundError bool + for _, expectedErr := range tt.expectedErrors { + if strings.Contains(errorStr, expectedErr) { + foundError = true + break + } + } + if !foundError { + t.Fatalf("expected error to contain one of %v, got: %v", tt.expectedErrors, err) + } + }) + } +}