From c1c6d416254b8ddd358e685782a25045b97f3abb Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Tue, 26 Aug 2025 12:10:36 -0700 Subject: [PATCH] check if provider is explicitly set --- di.go | 25 +++++++++++++++---------- di_test.go | 19 ++++++++++++++++++- 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/di.go b/di.go index 00660c60599fab8424fb75dfe8a6b9de1f4fcf78..5ef16c2e3f172b6ce6a6ca9aedf239c65d692723 100644 --- a/di.go +++ b/di.go @@ -93,13 +93,15 @@ func MustProvide[Out any](ctorOrVal any) Provide[Out] { } func NewProvide[Out any](ctorOrVal any) (Provide[Out], error) { + var zero Provide[Out] + out := Provide[Out]{set: true} outType := typeOf[Out]() if ctorOrVal == nil { if canBeNil(outType) { - return Provide[Out]{fOrV: nil}, nil + return out, nil } - return Provide[Out]{}, fmt.Errorf("Provide[%v]: nil not valid for non-nilable type", outType) + return zero, fmt.Errorf("Provide[%v]: nil not valid for non-nilable type", outType) } t := reflect.TypeOf(ctorOrVal) @@ -110,33 +112,36 @@ func NewProvide[Out any](ctorOrVal any) (Provide[Out], error) { switch nout { case 1: if !t.Out(0).AssignableTo(outType) { - return Provide[Out]{}, fmt.Errorf("Provide[%v]: function return %v is not assignable to %v", + return zero, fmt.Errorf("Provide[%v]: function return %v is not assignable to %v", outType, t.Out(0), outType) } - return Provide[Out]{fOrV: ctorOrVal}, nil + out.fOrV = ctorOrVal + return out, nil case 2: if !t.Out(0).AssignableTo(outType) { - return Provide[Out]{}, fmt.Errorf("Provide[%v]: first return %v is not assignable to %v", + return zero, fmt.Errorf("Provide[%v]: first return %v is not assignable to %v", outType, t.Out(0), outType) } if !isErrorType(t.Out(1)) { - return Provide[Out]{}, fmt.Errorf("Provide[%v]: second return must be error, got %v", + return zero, fmt.Errorf("Provide[%v]: second return must be error, got %v", outType, t.Out(1)) } - return Provide[Out]{fOrV: ctorOrVal}, nil + out.fOrV = ctorOrVal + return out, nil default: - return Provide[Out]{}, fmt.Errorf("Provide[%v]: function must return Out or (Out, error); got %d returns", + return zero, fmt.Errorf("Provide[%v]: function must return Out or (Out, error); got %d returns", outType, nout) } } // Case 2: value assignable to Out (covers interface satisfaction) if !t.AssignableTo(outType) { - return Provide[Out]{}, fmt.Errorf("Provide[%v]: value of type %v is not assignable to %v", outType, t, outType) + return zero, fmt.Errorf("Provide[%v]: value of type %v is not assignable to %v", outType, t, outType) } - return Provide[Out]{fOrV: ctorOrVal}, nil + out.fOrV = ctorOrVal + return out, nil } // Helpers diff --git a/di_test.go b/di_test.go index 1e5d0893541cf969e667cc00763c9fffa6015ebc..8db93312cd131e77b2bdf26dda54b2659326be37 100644 --- a/di_test.go +++ b/di_test.go @@ -474,6 +474,21 @@ func TestBuildErrors(t *testing.T) { } }, }, + { + name: "Zero value provider", + config: struct { + A Provide[*A] + B Provide[*B] + }{ + B: MustProvide[*B](func(*A) *B { + return &B{} + }), + }, + result: &struct { + B *B + }{}, + expectedErrors: []string{"Missing Provide[*di.A]"}, + }, } for _, tt := range tests { @@ -495,7 +510,9 @@ func TestBuildErrors(t *testing.T) { t.Fatalf("expected error to contain one of %v, got: %v", tt.expectedErrors, err) } - tt.verify(t, tt.result) + if tt.verify != nil { + tt.verify(t, tt.result) + } }) } }