~marcopolo/di

9a4fa0450032ce807ceb6b514c8508d4fe34d520 — Marco Munizaga 9 months ago c1c6d41
add support for variadics
2 files changed, 296 insertions(+), 26 deletions(-)

M di.go
M di_test.go
M di.go => di.go +55 -26
@@ 439,19 439,10 @@ func (c *collector) resolve(t reflect.Type) (reflect.Value, error) {
				}
				c.resolving[t] = true
				ft := ctor.fn.Type()
				args := make([]reflect.Value, ft.NumIn())
				for i := 0; i < ft.NumIn(); i++ {
					paramT := ft.In(i)
					arg, err := c.resolve(paramT)
					if err != nil {
						delete(c.resolving, t)
						return reflect.Value{}, fmt.Errorf("%s depends on %s: %w", ctor.name, paramT, err)
					}
					if !isAssignableOrImpl(arg.Type(), paramT) {
						delete(c.resolving, t)
						return reflect.Value{}, fmt.Errorf("%s: cannot use %s as %s", ctor.name, arg.Type(), paramT)
					}
					args[i] = arg
				args, err := c.prepareArgs(ft, ctor.name)
				if err != nil {
					delete(c.resolving, t)
					return reflect.Value{}, err
				}
				outs := ctor.fn.Call(args)
				delete(c.resolving, t)


@@ 521,17 512,9 @@ func (c *collector) resolve(t reflect.Type) (reflect.Value, error) {

	impl := candidates[0]
	ft := impl.fn.Type()
	args := make([]reflect.Value, ft.NumIn())
	for i := 0; i < ft.NumIn(); i++ {
		paramT := ft.In(i)
		arg, err := c.resolve(paramT)
		if err != nil {
			return reflect.Value{}, fmt.Errorf("%s depends on %s: %w", impl.name, paramT, err)
		}
		if !isAssignableOrImpl(arg.Type(), paramT) {
			return reflect.Value{}, fmt.Errorf("%s: cannot use %s as %s", impl.name, arg.Type(), paramT)
		}
		args[i] = arg
	args, err := c.prepareArgs(ft, impl.name)
	if err != nil {
		return reflect.Value{}, err
	}
	outs := impl.fn.Call(args)
	if len(outs) == 2 {


@@ 545,10 528,56 @@ func (c *collector) resolve(t reflect.Type) (reflect.Value, error) {
	return outs[0], nil
}

func validateCtorSignature(ft reflect.Type, name string) error {
func (c *collector) prepareArgs(ft reflect.Type, name string) ([]reflect.Value, error) {
	regularParamCount := ft.NumIn()
	args := make([]reflect.Value, 0, regularParamCount)
	if ft.IsVariadic() {
		return fmt.Errorf("constructor %q: variadics not supported", name)
		regularParamCount--
	}

	// Handle regular parameters
	for i := range regularParamCount {
		arg, err := c.resolveParam(ft.In(i), name)
		if err != nil {
			return nil, err
		}
		args = append(args, arg)
	}

	if ft.IsVariadic() {
		// Handle variadic parameter ...T by resolving []T
		variadicParamT := ft.In(regularParamCount) // This is []T for ...T
		sliceArg, err := c.resolve(variadicParamT)
		if err != nil {
			return nil, fmt.Errorf("%s depends on variadic %s: %w", name, variadicParamT, err)
		}

		// Convert []T to individual arguments for variadic call
		if sliceArg.Kind() != reflect.Slice {
			return nil, fmt.Errorf("%s: variadic parameter resolved to non-slice type %s", name, sliceArg.Type())
		}

		// Expand slice elements as individual arguments
		for i := range sliceArg.Len() {
			args = append(args, sliceArg.Index(i))
		}
	}

	return args, nil
}

func (c *collector) resolveParam(paramT reflect.Type, ctorName string) (reflect.Value, error) {
	arg, err := c.resolve(paramT)
	if err != nil {
		return reflect.Value{}, fmt.Errorf("%s depends on %s: %w", ctorName, paramT, err)
	}
	if !isAssignableOrImpl(arg.Type(), paramT) {
		return reflect.Value{}, fmt.Errorf("%s: cannot use %s as %s", ctorName, arg.Type(), paramT)
	}
	return arg, nil
}

func validateCtorSignature(ft reflect.Type, name string) error {
	nout := ft.NumOut()
	if nout == 1 {
		return nil

M di_test.go => di_test.go +241 -0
@@ 591,3 591,244 @@ func TestBuildPrimitiveTypes(t *testing.T) {
		})
	}
}

func TestVariadicSupport(t *testing.T) {
	type Service struct {
		ID   int
		Name string
	}

	type Aggregator struct {
		Services []Service
		Total    int
	}

	tests := []struct {
		name   string
		config any
		result any
		verify func(t *testing.T, result any)
	}{
		{
			name: "variadic constructor with multiple services",
			config: struct {
				Services   []Service
				Aggregator func(...Service) *Aggregator
			}{
				Services: []Service{
					{ID: 1, Name: "service1"},
					{ID: 2, Name: "service2"},
					{ID: 3, Name: "service3"},
				},
				Aggregator: func(services ...Service) *Aggregator {
					return &Aggregator{
						Services: services,
						Total:    len(services),
					}
				},
			},
			result: &struct {
				Aggregator *Aggregator
			}{},
			verify: func(t *testing.T, result any) {
				res := result.(*struct {
					Aggregator *Aggregator
				})
				if res.Aggregator == nil {
					t.Fatalf("expected Aggregator to be populated")
				}
				if res.Aggregator.Total != 3 {
					t.Fatalf("expected Total=3, got %d", res.Aggregator.Total)
				}
				if len(res.Aggregator.Services) != 3 {
					t.Fatalf("expected 3 services, got %d", len(res.Aggregator.Services))
				}
				expectedServices := []Service{
					{ID: 1, Name: "service1"},
					{ID: 2, Name: "service2"},
					{ID: 3, Name: "service3"},
				}
				for i, svc := range res.Aggregator.Services {
					if svc != expectedServices[i] {
						t.Fatalf("service[%d]: expected %+v, got %+v", i, expectedServices[i], svc)
					}
				}
			},
		},
		{
			name: "variadic constructor with empty slice",
			config: struct {
				Services   []Service
				Aggregator func(...Service) *Aggregator
			}{
				Services: []Service{}, // explicit empty slice
				Aggregator: func(services ...Service) *Aggregator {
					return &Aggregator{
						Services: services,
						Total:    len(services),
					}
				},
			},
			result: &struct {
				Aggregator *Aggregator
			}{},
			verify: func(t *testing.T, result any) {
				res := result.(*struct {
					Aggregator *Aggregator
				})
				if res.Aggregator == nil {
					t.Fatalf("expected Aggregator to be populated")
				}
				if res.Aggregator.Total != 0 {
					t.Fatalf("expected Total=0, got %d", res.Aggregator.Total)
				}
				if len(res.Aggregator.Services) != 0 {
					t.Fatalf("expected 0 services, got %d", len(res.Aggregator.Services))
				}
			},
		},
		{
			name: "variadic constructor with Provide",
			config: struct {
				Services   []Provide[Service]
				Aggregator Provide[*Aggregator]
			}{
				Services: []Provide[Service]{
					MustProvide[Service](Service{ID: 10, Name: "provided1"}),
					MustProvide[Service](func() Service {
						return Service{ID: 20, Name: "provided2"}
					}),
				},
				Aggregator: MustProvide[*Aggregator](func(services ...Service) *Aggregator {
					return &Aggregator{
						Services: services,
						Total:    len(services),
					}
				}),
			},
			result: &struct {
				Aggregator *Aggregator
			}{},
			verify: func(t *testing.T, result any) {
				res := result.(*struct {
					Aggregator *Aggregator
				})
				if res.Aggregator == nil {
					t.Fatalf("expected Aggregator to be populated")
				}
				if res.Aggregator.Total != 2 {
					t.Fatalf("expected Total=2, got %d", res.Aggregator.Total)
				}
				if len(res.Aggregator.Services) != 2 {
					t.Fatalf("expected 2 services, got %d", len(res.Aggregator.Services))
				}
				if res.Aggregator.Services[0].ID != 10 || res.Aggregator.Services[0].Name != "provided1" {
					t.Fatalf("unexpected first service: %+v", res.Aggregator.Services[0])
				}
				if res.Aggregator.Services[1].ID != 20 || res.Aggregator.Services[1].Name != "provided2" {
					t.Fatalf("unexpected second service: %+v", res.Aggregator.Services[1])
				}
			},
		},
		{
			name: "variadic constructor with mixed parameters",
			config: struct {
				Name       Service // Use Service instead of string to avoid primitive type issue
				Services   []Service
				Aggregator func(Service, ...Service) *Aggregator
			}{
				Name: Service{ID: 0, Name: "MyAggregator"},
				Services: []Service{
					{ID: 1, Name: "service1"},
				},
				Aggregator: func(name Service, services ...Service) *Aggregator {
					// Verify that the name parameter comes first
					if name.Name != "MyAggregator" {
						panic(fmt.Sprintf("expected name 'MyAggregator', got '%s'", name.Name))
					}
					return &Aggregator{
						Services: services,
						Total:    len(services),
					}
				},
			},
			result: &struct {
				Aggregator *Aggregator
			}{},
			verify: func(t *testing.T, result any) {
				res := result.(*struct {
					Aggregator *Aggregator
				})
				if res.Aggregator == nil {
					t.Fatalf("expected Aggregator to be populated")
				}
				if res.Aggregator.Total != 1 {
					t.Fatalf("expected Total=1, got %d", res.Aggregator.Total)
				}
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := Build(tt.config, tt.result)
			if err != nil {
				t.Fatalf("Build failed: %v", err)
			}
			tt.verify(t, tt.result)
		})
	}
}

func TestVariadicErrors(t *testing.T) {
	type Service struct {
		ID int
	}

	type Result struct {
		Name string
	}

	tests := []struct {
		name           string
		config         any
		result         any
		expectedErrors []string
	}{
		{
			name: "variadic constructor missing slice dependency",
			config: struct {
				MakeResult func(...Service) *Result
			}{
				MakeResult: func(services ...Service) *Result {
					return &Result{Name: "test"}
				},
			},
			result: &struct {
				Result *Result
			}{},
			expectedErrors: []string{"no provider for []di.Service"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := Build(tt.config, tt.result)
			if err == nil {
				t.Fatalf("expected error")
			}

			errorStr := err.Error()
			var foundError bool
			for _, expectedErr := range tt.expectedErrors {
				if strings.Contains(errorStr, expectedErr) {
					foundError = true
					break
				}
			}
			if !foundError {
				t.Fatalf("expected error to contain one of %v, got: %v", tt.expectedErrors, err)
			}
		})
	}
}